import torch
import torch.nn as nn
import math

# RMSNorm normalizes the input by its root mean square (no mean subtraction), with a small epsilon for numerical stability
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-5): # hidden_size: feature dimension of the input, eps: small constant to avoid division by zero
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size)) # weight is a learnable parameter that scales the input
        self.eps = eps

    def forward(self, x):
        norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) # reciprocal root mean square over the feature dimension, with eps inside the sqrt for stability
        return x * norm * self.weight # normalize the input, then scale by the learnable weight
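
# Quick sanity check (a minimal sketch; uncomment to run): after normalization,
# the mean square of each vector should be close to 1.0
# rms_norm = LlamaRMSNorm(hidden_size=8)
# sample = torch.randn(2, 4, 8)
# print(rms_norm(sample).pow(2).mean(-1)) # values close to 1.0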


# Rotary position embedding (RoPE) encodes token positions by rotating pairs of feature dimensions through fixed, position-dependent angles
class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, base=10000, device=None): # dim: feature dimension to embed, base: frequency base, device: where to create the buffer
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim)) # compute the inverse frequency for each pair of dimensions
        self.register_buffer("inv_freq", inv_freq) # store as a non-trainable buffer

    def forward(self, x, seq_len):
        t = torch.arange(seq_len, device=x.device) # positions 0..seq_len-1 on the same device as the input
        freqs = torch.einsum("i,j->ij", t, self.inv_freq) # outer product of positions and inverse frequencies -> [seq_len, dim/2] angles
        emb = torch.cat((freqs, freqs), dim=-1) # duplicate the angles so they cover the full dimension
        return emb
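
# A minimal sketch of how the angles above would typically be applied to the
# query/key tensors. This simplified model computes the embedding but does not
# apply it inside LlamaAttention; `rotate_half` is defined here only for
# illustration and assumes the module is built with the per-head dimension.
def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1) # split the last dimension into two halves
    return torch.cat((-x2, x1), dim=-1) # rotate pairs: (x1, x2) -> (-x2, x1)

# emb = rotary_emb(q, seq_len)            # [seq_len, head_dim] angles
# cos, sin = emb.cos(), emb.sin()
# q_rot = q * cos + rotate_half(q) * sin  # rotate the queries
# k_rot = k * cos + rotate_half(k) * sin  # rotate the keys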

class LlamaMLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) # create the gate projection layer with the input dimension and the hidden dimension
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False) # create the up projection layer with the input dimension and the hidden dimension
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False) # create the down projection layer with the hidden dimension and the output dimension
        self.act_fn = nn.SiLU() # create the activation function

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x)) # gate branch: project and apply SiLU
        hidden = self.up_proj(x) # up-projection branch (no activation)
        return self.down_proj(gated * hidden) # SwiGLU: multiply the branches, then project back down
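
# Quick shape check (a minimal sketch; uncomment to run): the MLP preserves the model dimension
# mlp = LlamaMLP(dim=576, hidden_dim=1536)
# print(mlp(torch.randn(2, 8, 576)).shape) # torch.Size([2, 8, 576])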
    
class LlamaAttention(nn.Module):
    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.o_proj = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        batch_size, seq_len, dim = x.size() # [batch_size, seq_len, dim] -> [4, 128, 576]
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)


        # Split heads
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # [batch_size, num_heads, seq_len, head_dim]
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention with a causal mask so each position
        # only attends to itself and earlier positions
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), float("-inf"), device=x.device), diagonal=1
        )
        scores = scores + causal_mask
        attention = torch.softmax(scores, dim=-1)
        context = torch.matmul(attention, v)

        # Combine heads
        context = context.transpose(1, 2).reshape(batch_size, seq_len, dim)
        return self.o_proj(context)
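
# Equivalent fused alternative (a sketch, assuming PyTorch >= 2.0): the explicit
# scores/softmax/matmul computation above can be replaced with a single call:
# context = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)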

class LlamaDecoderLayer(nn.Module):
    def __init__(self, dim, hidden_dim, num_heads):
        super().__init__()
        self.self_attn = LlamaAttention(dim, num_heads)
        self.mlp = LlamaMLP(dim, hidden_dim)
        self.input_layernorm = LlamaRMSNorm(dim)
        self.post_attention_layernorm = LlamaRMSNorm(dim)

    def forward(self, x):
        residual = x
        x = self.input_layernorm(x)
        x = self.self_attn(x)
        x = x + residual

        residual = x
        x = self.post_attention_layernorm(x)
        x = self.mlp(x)
        x = x + residual
        return x


class LlamaModel(nn.Module):
    def __init__(self, vocab_size, dim, num_layers, hidden_dim, num_heads):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, dim)
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(dim, hidden_dim, num_heads) for _ in range(num_layers)
        ])
        self.norm = LlamaRMSNorm(dim)
        self.rotary_emb = LlamaRotaryEmbedding(dim) # rotary angles are precomputed here but not applied inside LlamaAttention in this simplified implementation

    def forward(self, x):
        x = self.embed_tokens(x)
        for layer in self.layers:
            x = layer(x)
        return self.norm(x)

class LlamaForCausalLM(nn.Module):
    def __init__(self, vocab_size, dim, num_layers, hidden_dim, num_heads):
        super().__init__()
        self.model = LlamaModel(vocab_size, dim, num_layers, hidden_dim, num_heads)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)

    def forward(self, x):
        x = self.model(x)
        return self.lm_head(x)

def get_model(tokenizer):
    vocab_size = len(tokenizer)  # use the full tokenizer vocabulary, including any added special tokens
    return LlamaForCausalLM(
        vocab_size=vocab_size,
        dim=576,
        num_layers=30,
        hidden_dim=1536,
        num_heads=8
    )

# model = get_model(tokenizer)  # requires a tokenizer (see the training script below)
# print(model)
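# Quick smoke test (a sketch; requires a tokenizer such as the one loaded below):
# dummy_ids = torch.randint(0, len(tokenizer), (2, 16))
# print(get_model(tokenizer)(dummy_ids).shape) # torch.Size([2, 16, vocab_size])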

import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer, get_scheduler
from torch.optim import AdamW
#import wandb
import os
#from model import get_model

#wandb.init(project="smollm-training", name="llama-smollm-corpus")

BATCH_SIZE = 8
SEQ_LEN = 256
LEARNING_RATE = 1e-4
EPOCHS = 5
WARMUP_STEPS = 1000
GRADIENT_CLIP_VAL = 1.0
CHECKPOINT_DIR = "checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
DEVICE = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)


def generate_text(
    model, tokenizer, prompt, max_length=50, temperature=0.7, top_k=50, device=DEVICE
):
    model.eval()
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        for _ in range(max_length):
            outputs = model(input_ids)
            next_token_logits = outputs[:, -1, :] / temperature

            # Apply top-k sampling
            top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k, dim=-1)
            probs = torch.softmax(top_k_logits, dim=-1)

            # Sample from the filtered distribution
            next_token_idx = torch.multinomial(probs, num_samples=1)
            next_token = top_k_indices[0, next_token_idx[0]]

            if next_token.item() == tokenizer.eos_token_id:
                break

            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

    generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    model.train()
    return generated_text
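
# Example usage (a sketch; requires the `model` and `tokenizer` created below):
# print(generate_text(model, tokenizer, "Once upon a time", max_length=30))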


def save_checkpoint(model, optimizer, scheduler, epoch, step, loss, path):
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict() if scheduler else None,
            "loss": loss,
            "step": step,
        },
        path,
    )


def load_checkpoint(path, model, optimizer, scheduler):
    if os.path.exists(path):
        # path = './checkpoints/checkpoint_step_5000.pt'
        # print(f"Loading checkpoint from {path}")
        checkpoint = torch.load(path, weights_only=True)

        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        if scheduler and checkpoint["scheduler_state_dict"]:
            scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        return checkpoint["epoch"], checkpoint["step"]
    return 0, 0


def count_parameters(model):
    """Count the number of trainable parameters in the model"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
if tokenizer.pad_token is None:
    if tokenizer.eos_token:
        tokenizer.pad_token = tokenizer.eos_token
    else:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        # the model is created later with len(tokenizer), so no embedding resize is needed here

dataset = load_dataset(
    "HuggingFaceTB/smollm-corpus", "cosmopedia-v2", streaming=True, split="train"
)


def tokenize_function(examples):
    return tokenizer(
        examples["text"], truncation=True, max_length=SEQ_LEN, padding="max_length"
    )


tokenized_dataset = dataset.map(tokenize_function, batched=True)


def collate_fn(batch):
    input_ids = torch.tensor([item["input_ids"] for item in batch], dtype=torch.long)
    attention_mask = torch.tensor(
        [item["attention_mask"] for item in batch], dtype=torch.long
    )
    labels = input_ids.clone()
    labels[attention_mask == 0] = -100  # ignore padding positions when computing the loss
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


train_loader = DataLoader(
    tokenized_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn
)
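
# Quick check of one batch (a sketch; uncomment to run):
# first_batch = next(iter(train_loader))
# print(first_batch["input_ids"].shape) # torch.Size([BATCH_SIZE, SEQ_LEN])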

# Initialize model, optimizer, and scheduler
model = get_model(tokenizer)
model.to(DEVICE)

# Print model parameters
# total_params = count_parameters(model)
# print(f"\nModel Statistics:")
# print(f"Total Parameters: {total_params:,}")
# print(f"Model Size: {total_params * 4 / (1024 * 1024):.2f} MB")  # Assuming float32 (4 bytes)
# print(f"Device: {DEVICE}")
# print(f"Batch Size: {BATCH_SIZE}")
# print(f"Sequence Length: {SEQ_LEN}")
# print(f"Learning Rate: {LEARNING_RATE}")
# print("-" * 50 + "\n")

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=LEARNING_RATE,
    total_steps=10000,  # OneCycleLR raises an error if stepped more than total_steps times, so keep this >= the planned number of optimizer steps
    pct_start=0.1,
    anneal_strategy="cos",
    cycle_momentum=False,
)

# Load checkpoint if exists
start_epoch, global_step = load_checkpoint(
    f"{CHECKPOINT_DIR}/latest_checkpoint.pt", model, optimizer, lr_scheduler
)

# Sample prompts for evaluation
sample_prompts = [
    "The future of artificial intelligence",
    "The most important thing in life",
    "The best way to learn programming",
]

model.train()
try:
    for epoch in range(start_epoch, EPOCHS):
        print(f"Epoch {epoch + 1}/{EPOCHS}")
        for step, batch in enumerate(train_loader, start=global_step): # start offsets the step counter when resuming; it does not skip already-seen batches
            # Move batch to device
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            labels = batch["labels"].to(DEVICE)

            # Forward pass
            outputs = model(input_ids)

            # Shift so that the logits at position i predict the token at position i + 1
            shift_logits = outputs[:, :-1, :].contiguous().view(-1, outputs.size(-1))
            shift_labels = labels[:, 1:].contiguous().view(-1)

            # Calculate loss with label smoothing, ignoring padding positions
            loss = torch.nn.functional.cross_entropy(
                shift_logits, shift_labels, label_smoothing=0.1, ignore_index=-100
            )

            # Backward pass with gradient clipping
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP_VAL)
            optimizer.step()
            lr_scheduler.step()

            # Logging
            if step % 10 == 0:
                print(
                    f"Step {step}, Loss: {loss.item():.4f}, LR: {lr_scheduler.get_last_lr()[0]:.2e}"
                )
                #wandb.log(
                #    {
                #        "loss": loss.item(),
                #        "lr": lr_scheduler.get_last_lr()[0],
                #        "step": step,
                #        "epoch": epoch,
                #    }
                #)

            # Save checkpoint every 100 steps
            #if step % 100 == 0:
            #    save_checkpoint(
            #        model,
            #        optimizer,
            #        lr_scheduler,
            #        epoch,
            #        step,
            #        loss.item(),
            #        f"{CHECKPOINT_DIR}/latest_checkpoint.pt",
            #    )

            # Save a numbered checkpoint every 1000 steps
            if step % 1000 == 0:
                save_checkpoint(
                    model,
                    optimizer,
                    lr_scheduler,
                    epoch,
                    step,
                    loss.item(),
                    f"{CHECKPOINT_DIR}/checkpoint_step_{step}.pt",
                )

            # Generate sample text every 500 steps with different temperatures
            if step % 500 == 0:
                print("\n=== Generating Sample Texts ===")
                for temp in [0.7, 1.0]:  # Try different temperatures
                    for prompt in sample_prompts:
                        generated = generate_text(
                            model,
                            tokenizer,
                            prompt,
                            temperature=temp,
                            max_length=100,  # Increased max length
                        )
                        print(f"\nPrompt: {prompt}")
                        print(f"Temperature: {temp}")
                        print(f"Generated: {generated}")
                        #wandb.log(
                        #    {
                        #        f"generated_text_temp_{temp}_{prompt[:20]}": wandb.Html(
                        #            generated
                        #        )
                        #    }
                        #)
                print("\n=== End of Samples ===\n")
                model.train()

        # Save epoch checkpoint
        #save_checkpoint(
        #    model,
        #    optimizer,
        #    lr_scheduler,
        #    epoch,
        #    step,
        #    loss.item(),
        #    f"{CHECKPOINT_DIR}/checkpoint_epoch.pt",
        #)

except KeyboardInterrupt:
    print("\nTraining interrupted! Saving checkpoint...")
    save_checkpoint(
        model,
        optimizer,
        lr_scheduler,
        epoch,
        step,
        loss.item(),
        f"{CHECKPOINT_DIR}/interrupted_checkpoint.pt",
    )

print("Training complete!")
#wandb.finish()