#!/usr/bin/env python3
"""
This script is a simple text generator using the SmollmV2 model.
It uses Gradio to create a web interface for generating text.
"""
# Standard library imports
import os
import shutil
import warnings
from pathlib import Path

# Third-party imports
import torch
import torch.nn.functional as F
import gradio as gr
from transformers import GPT2Tokenizer
import spaces

# Local imports
from smollmv2 import SmollmV2
from config import SmollmConfig, DataConfig
from smollv2_lightning import LitSmollmv2

# Suppress TorchDynamo errors so torch.compile falls back to eager execution
# (works around a device-properties issue) and silence noisy user warnings
torch._dynamo.config.suppress_errors = True
warnings.filterwarnings('ignore', category=UserWarning)

def combine_model_parts(model_dir="split_models", output_file="checkpoints/last.ckpt"):
    """
    Combine split model parts into a single checkpoint file
    """
    # Create checkpoints directory if it doesn't exist
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    
    # Check if combined model already exists
    if os.path.exists(output_file):
        print(f"Model already combined at: {output_file}")
        return output_file
    
    # Ensure the model parts exist
    if not os.path.exists(model_dir):
        raise FileNotFoundError(f"Model directory {model_dir} not found")
    
    # Combine the parts
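    # NOTE: the lexicographic sort assumes zero-padded part suffixes
    # (e.g. last.ckpt.part_00, part_01, ...) so the parts concatenate in order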
    parts = sorted(Path(model_dir).glob("last.ckpt.part_*"))
    if not parts:
        raise FileNotFoundError("No model parts found")
    
    print("Combining model parts...")
    with open(output_file, 'wb') as outfile:
        for part in parts:
            print(f"Processing part: {part}")
            with open(part, 'rb') as infile:
                outfile.write(infile.read())
    
    print(f"Model combined successfully: {output_file}")
    return output_file
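
# NOTE: combine_model_parts() is not invoked automatically, and its default
# output path ("checkpoints/last.ckpt") differs from the "last.ckpt" path that
# load_model() expects in the repository root. If the checkpoint ships as split
# parts, call combine_model_parts(output_file="last.ckpt") before loading.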

def load_model():
    """
    Load the SmollmV2 model and tokenizer.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # Load model directly from checkpoint
    checkpoint_path = "last.ckpt"
    
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            f"Model checkpoint {checkpoint_path} not found. "
            "Please ensure the model checkpoint file 'last.ckpt' is present in the root directory."
        )
    
    try:
        # Load the model from checkpoint using Lightning module
        model = LitSmollmv2.load_from_checkpoint(
            checkpoint_path,
            model_config=SmollmConfig,
            strict=False
        )
        
        model.to(device)
        model.eval()
        
        # Initialize tokenizer
        tokenizer = GPT2Tokenizer.from_pretrained(DataConfig.tokenizer_path)
        tokenizer.pad_token = tokenizer.eos_token
        
        return model, tokenizer, device
    
    except Exception as e:
        raise RuntimeError(f"Error loading model: {e}") from e

# Load the model, tokenizer, and device once at import time so all requests reuse them
model, tokenizer, device = load_model()

@spaces.GPU(enable_queue=True)
def generate_text(prompt, num_tokens, temperature=0.8, top_p=0.9):
    """
    Generate text using the SmollmV2 model.
    :param prompt: The initial text prompt to start the generation from.
    :param num_tokens: The number of tokens to generate.
    :param temperature: The temperature parameter for controlling randomness.
    :param top_p: The top-p parameter for nucleus sampling.
    :return: The generated text.
    """
    try:
        # Ensure num_tokens doesn't exceed model's block size
        num_tokens = min(num_tokens, SmollmConfig.block_size)
        
        # Tokenize input prompt
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
        
        # Autocast dtype: float16 on CUDA, bfloat16 on CPU (CPU autocast expects bfloat16)
        autocast_dtype = torch.float16 if device == 'cuda' else torch.bfloat16

        # Generate tokens one at a time
        with torch.inference_mode():
            for _ in range(num_tokens):
                # Crop the context so it never exceeds the model's block size
                input_cond = input_ids[:, -SmollmConfig.block_size:]

                # Get the model's predictions under mixed precision
                with torch.autocast(device_type=device, dtype=autocast_dtype):
                    outputs = model(input_cond)
                    logits = outputs[0] if isinstance(outputs, tuple) else outputs
                
                # Get the next-token probabilities (cast to float32 for a stable softmax)
                logits = logits[:, -1, :].float() / temperature
                probs = F.softmax(logits, dim=-1)
                
                # Apply top-p (nucleus) sampling; skip when top_p >= 1, which keeps the full distribution
                if 0.0 < top_p < 1.0:
                    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                    cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
                    sorted_indices_to_keep = cumsum_probs <= top_p
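                    # Shift the mask right so the token that first pushes the cumulative
                    # probability past top_p is still kept (at least one token survives)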
                    sorted_indices_to_keep[..., 1:] = sorted_indices_to_keep[..., :-1].clone()
                    sorted_indices_to_keep[..., 0] = True  # always keep the most probable token
                    indices_to_keep = torch.zeros_like(probs, dtype=torch.bool)
                    indices_to_keep.scatter_(-1, sorted_indices, sorted_indices_to_keep)
                    probs = torch.where(indices_to_keep, probs, torch.zeros_like(probs))
                    # Renormalize so the kept probabilities sum to 1 before sampling
                    probs = probs / probs.sum(dim=-1, keepdim=True)
                
                # Sample next token
                next_token = torch.multinomial(probs, num_samples=1)
                
                # Append to input_ids
                input_ids = torch.cat([input_ids, next_token], dim=-1)
                
                # Stop if we generate an EOS token
                if next_token.item() == tokenizer.eos_token_id:
                    break
        
        # Decode and return the generated text
        generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
        return generated_text
    
    except Exception as e:
        return f"Error during text generation: {str(e)}"

# Create the Gradio interface
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Enter your prompt", value="Once upon a time"),
        gr.Slider(minimum=1, maximum=SmollmConfig.block_size//2, value=100, step=1, label="Number of tokens to generate"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature (higher = more random)"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top-p (nucleus sampling)")
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="SmoLLMv2 Text Generator",
    description="Generate text using the SmoLLMv2-135M model",
    allow_flagging="never",
    cache_examples=True
)

if __name__ == "__main__":
    demo.launch()