import gradio as gr
import spaces
import torch
import faiss
import numpy as np

from datasets import load_dataset
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
    pipeline,
    BitsAndBytesConfig,
)

from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training, PeftModel

from sentence_transformers import SentenceTransformer

# Lazily-populated generation pipelines shared by the Gradio handlers,
# plus the cap on how many dataset rows the fine-tuning run uses.
TEXT_PIPELINE = COMPARISON_PIPELINE = None
NUM_EXAMPLES = 50

@spaces.GPU(duration=300)
def finetune_small_subset():
    """
    Fine-tunes the custom R1 model on a small subset of the ServiceNow-AI/R1-Distill-SFT dataset.

    Steps:
      1) Loads the model from "wuhp/myr1" (using files from the "myr1" subfolder via trust_remote_code).
      2) Applies 4-bit NF4 quantization and prepares the model for QLoRA training.
      3) Fine-tunes on at most NUM_EXAMPLES rows, formatting each as
         "Problem: ...\n\nSolution: ..." (truncated to 512 tokens).
      4) Saves the LoRA adapter and tokenizer to "finetuned_myr1".
      5) Frees the training model, reloads the base model with the adapter for
         inference, and installs it as the global TEXT_PIPELINE.

    Returns:
        A status string for the Gradio textbox.
    """
    import gc  # local import: only needed for the cleanup between training and inference

    global TEXT_PIPELINE

    # Specify the configuration ("v0" or "v1") explicitly.
    ds = load_dataset("ServiceNow-AI/R1-Distill-SFT", "v0", split="train")
    ds = ds.select(range(min(NUM_EXAMPLES, len(ds))))

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )

    # Load the custom model configuration from the repository.
    base_config = AutoConfig.from_pretrained(
        "wuhp/myr1",
        subfolder="myr1",
        trust_remote_code=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(
        "wuhp/myr1",
        subfolder="myr1",
        trust_remote_code=True
    )
    # DataCollatorForLanguageModeling pads each batch and masks pad ids out of
    # the labels; both fail if the tokenizer defines no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    base_model = AutoModelForCausalLM.from_pretrained(
        "wuhp/myr1",
        subfolder="myr1",
        config=base_config,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True
    )
    base_model = prepare_model_for_kbit_training(base_model)

    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        target_modules=["q_proj", "v_proj"],
        task_type=TaskType.CAUSAL_LM,
    )
    lora_model = get_peft_model(base_model, lora_config)

    def tokenize_fn(ex):
        # Prompt and target are concatenated into one causal-LM sequence.
        text = (
            f"Problem: {ex['problem']}\n\n"
            f"Solution: {ex['solution']}"
        )
        return tokenizer(text, truncation=True, max_length=512)

    ds = ds.map(tokenize_fn, batched=False, remove_columns=ds.column_names)
    ds.set_format("torch")

    # mlm=False -> labels are a copy of input_ids (next-token prediction).
    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir="finetuned_myr1",
        num_train_epochs=1,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=2,
        logging_steps=5,
        save_steps=999999,  # effectively disables intermediate checkpoints
        save_total_limit=1,
        fp16=False,
    )

    trainer = Trainer(
        model=lora_model,
        args=training_args,
        train_dataset=ds,
        data_collator=collator,
    )
    trainer.train()

    trainer.model.save_pretrained("finetuned_myr1")
    tokenizer.save_pretrained("finetuned_myr1")

    # Drop every reference to the training copy before loading the inference
    # copy, so two quantized models (plus optimizer state) are not resident at once.
    del trainer, lora_model, base_model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Inference copy: prepare_model_for_kbit_training is deliberately NOT applied
    # here; it is a training-time helper (fp32 upcasts, gradient checkpointing,
    # input-grad hooks) that only costs memory and speed during generation.
    base_model_2 = AutoModelForCausalLM.from_pretrained(
        "wuhp/myr1",
        subfolder="myr1",
        config=base_config,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True
    )

    lora_model_2 = PeftModel.from_pretrained(
        base_model_2,
        "finetuned_myr1",
    )
    lora_model_2.eval()

    TEXT_PIPELINE = pipeline("text-generation", model=lora_model_2, tokenizer=tokenizer)

    return "Finetuning complete. Model loaded for inference."

def ensure_pipeline():
    """
    Return the shared text-generation pipeline, building it on first use.

    When no pipeline exists yet (e.g. fine-tuning has not been run), the base
    "wuhp/myr1" model is loaded in 4-bit NF4 without any LoRA adapter and
    cached in the global TEXT_PIPELINE.
    """
    global TEXT_PIPELINE
    if TEXT_PIPELINE is not None:
        return TEXT_PIPELINE

    repo_kwargs = {"subfolder": "myr1", "trust_remote_code": True}
    quant = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )
    cfg = AutoConfig.from_pretrained("wuhp/myr1", **repo_kwargs)
    tok = AutoTokenizer.from_pretrained("wuhp/myr1", **repo_kwargs)
    model = AutoModelForCausalLM.from_pretrained(
        "wuhp/myr1",
        config=cfg,
        quantization_config=quant,
        device_map="auto",
        **repo_kwargs,
    )
    TEXT_PIPELINE = pipeline("text-generation", model=model, tokenizer=tok)
    return TEXT_PIPELINE

def ensure_comparison_pipeline():
    """
    Return the shared pipeline for the official DeepSeek-R1-Distill-Llama-8B
    model, loading it on first use and caching it in COMPARISON_PIPELINE.

    The weights are loaded in bfloat16. If no dtype is passed, ``from_pretrained``
    loads the weights in float32, which roughly doubles the memory needed for
    an 8B-parameter model.
    """
    global COMPARISON_PIPELINE
    if COMPARISON_PIPELINE is None:
        repo_id = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
        config = AutoConfig.from_pretrained(repo_id)
        tokenizer = AutoTokenizer.from_pretrained(repo_id)
        model = AutoModelForCausalLM.from_pretrained(
            repo_id,
            config=config,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )
        COMPARISON_PIPELINE = pipeline("text-generation", model=model, tokenizer=tokenizer)
    return COMPARISON_PIPELINE

@spaces.GPU(duration=120)
def predict(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
    """
    Direct generation without retrieval using the custom R1 model.

    Args:
        prompt: Text to continue.
        temperature: Sampling temperature; values <= 0 select greedy decoding.
        top_p: Nucleus-sampling threshold (used only when sampling).
        min_new_tokens: Minimum tokens to generate (clamped to max_new_tokens).
        max_new_tokens: Maximum tokens to generate.

    Returns:
        The pipeline's generated_text (prompt followed by the continuation).
    """
    pipe = ensure_pipeline()

    temperature = float(temperature)
    max_new_tokens = int(max_new_tokens)
    gen_kwargs = {
        # The two sliders are independent and can cross; keep min <= max.
        "min_new_tokens": min(int(min_new_tokens), max_new_tokens),
        "max_new_tokens": max_new_tokens,
    }
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=float(top_p))
    else:
        # Sampling requires a strictly positive temperature (transformers raises
        # otherwise); the UI allows 0.0, which we treat as greedy decoding.
        gen_kwargs["do_sample"] = False

    out = pipe(prompt, **gen_kwargs)
    return out[0]["generated_text"]

@spaces.GPU(duration=120)
def compare_models(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
    """
    Compare outputs between your custom R1 model and the official R1 model.

    Both pipelines receive the same prompt and generation settings.
    A temperature <= 0 selects greedy decoding, and min_new_tokens is clamped
    to max_new_tokens.

    Returns:
        A tuple (custom_text, official_text) of the pipelines' generated_text.
    """
    local_pipe = ensure_pipeline()
    comp_pipe = ensure_comparison_pipeline()

    temperature = float(temperature)
    max_new_tokens = int(max_new_tokens)
    gen_kwargs = {
        # The two sliders are independent and can cross; keep min <= max.
        "min_new_tokens": min(int(min_new_tokens), max_new_tokens),
        "max_new_tokens": max_new_tokens,
    }
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=float(top_p))
    else:
        # Sampling requires a strictly positive temperature (transformers raises
        # otherwise); the UI allows 0.0, which we treat as greedy decoding.
        gen_kwargs["do_sample"] = False

    local_out = local_pipe(prompt, **gen_kwargs)
    comp_out = comp_pipe(prompt, **gen_kwargs)
    return local_out[0]["generated_text"], comp_out[0]["generated_text"]

class ConversationRetriever:
    """
    A FAISS-based retriever using SentenceTransformer for embedding.

    Texts are embedded with ``model_name`` and stored in an exact
    ``faiss.IndexFlatL2`` index. ``search`` returns the closest stored texts
    together with their (squared) L2 distances: lower means closer.
    """
    def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2", embed_dim=384):
        self.embed_model = SentenceTransformer(model_name)
        # embed_dim must match the dimensionality of model_name's embeddings.
        self.embed_dim = embed_dim
        self.index = faiss.IndexFlatL2(embed_dim)
        self.texts = []     # stored texts; position i matches FAISS row i
        self.vectors = []   # float32 embeddings, parallel to self.texts
        self.ids = []       # sequential ids, parallel to self.texts
        self.id_counter = 0

    def add_text(self, text):
        """Embed ``text`` and add it to the index; empty/blank text is ignored."""
        if not text or not text.strip():
            return
        emb = self.embed_model.encode([text], convert_to_numpy=True)
        vec = emb[0].astype(np.float32)
        self.index.add(vec.reshape(1, -1))
        self.texts.append(text)
        self.vectors.append(vec)
        self.ids.append(self.id_counter)
        self.id_counter += 1

    def search(self, query, top_k=3):
        """
        Return up to ``top_k`` ``(text, distance)`` pairs nearest to ``query``.

        Results are ordered by increasing distance. Returns an empty list when
        nothing is stored or ``top_k`` is not positive.
        """
        if not self.texts or top_k <= 0:
            return []
        # FAISS pads results with index -1 when k exceeds the number of stored
        # vectors, so only ask for as many as actually exist.
        k = min(top_k, len(self.texts))
        q_emb = self.embed_model.encode([query], convert_to_numpy=True).astype(np.float32)
        q_vec = q_emb[0].reshape(1, -1)
        distances, indices = self.index.search(q_vec, k)
        results = []
        for dist, idx in zip(distances[0], indices[0]):
            # -1 means "no result"; a bare `idx < len(...)` check would let it
            # through and wrongly return self.texts[-1].
            if 0 <= idx < len(self.texts):
                results.append((self.texts[idx], float(dist)))
        return results

# Single module-level retriever. Every chat turn from every Gradio session is
# embedded into, and retrieved from, this one index.
# NOTE(review): this means conversations are not isolated between users.
retriever = ConversationRetriever(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    embed_dim=384,
)

def build_rag_prompt(user_query, retrieved_chunks):
    """
    Builds a prompt for retrieval-augmented generation.

    Args:
        user_query: The user's current message.
        retrieved_chunks: Iterable of ``(text, distance)`` pairs, e.g. as
            returned by ``ConversationRetriever.search``. The score is a FAISS
            L2 distance (lower means closer), not a similarity, so it is
            labelled accordingly in the prompt.

    Returns:
        The prompt string. It contains the query, then each chunk numbered from
        1 with its distance to two decimals, then a trailing "Assistant:" cue.
    """
    context_str = "".join(
        f"Chunk #{i} (distance ~ {dist:.2f}):\n{chunk}\n\n"
        for i, (chunk, dist) in enumerate(retrieved_chunks, start=1)
    )
    return (
        f"User's Query:\n{user_query}\n\n"
        f"Relevant Context:\n{context_str}"
        "Assistant:"
    )

@spaces.GPU(duration=120)
def chat_rag(user_input, history, temperature, top_p, min_new_tokens, max_new_tokens):
    """
    Chat with retrieval augmentation.

    Retrieves up to 3 earlier conversation snippets similar to ``user_input``,
    builds a RAG prompt, and generates a reply. Both the user message and the
    reply are then stored in the retriever.

    Args:
        user_input: The new user message; blank input leaves history unchanged.
        history: List of ``[user, assistant]`` pairs; appended to in place.
        temperature: Sampling temperature; values <= 0 select greedy decoding.
        top_p: Nucleus-sampling threshold (used only when sampling).
        min_new_tokens: Minimum tokens to generate (clamped to max_new_tokens).
        max_new_tokens: Maximum tokens to generate.

    Returns:
        ``(history, history)``, one copy for the State and one for the Chatbot.
    """
    if not user_input or not user_input.strip():
        return history, history

    pipe = ensure_pipeline()

    # Search *before* storing the new message. Otherwise the top hit is always
    # the query itself (distance 0), which crowds out genuine context.
    results = retriever.search(user_input, top_k=3)
    prompt = build_rag_prompt(user_input, results)

    temperature = float(temperature)
    max_new_tokens = int(max_new_tokens)
    gen_kwargs = {
        # The two sliders are independent and can cross; keep min <= max.
        "min_new_tokens": min(int(min_new_tokens), max_new_tokens),
        "max_new_tokens": max_new_tokens,
    }
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=float(top_p))
    else:
        # Sampling requires a strictly positive temperature (transformers raises
        # otherwise); the UI allows 0.0, which we treat as greedy decoding.
        gen_kwargs["do_sample"] = False

    # return_full_text=False makes the pipeline strip the prompt itself.
    # A `startswith(prompt)` check breaks whenever decode(encode(prompt)) != prompt.
    output = pipe(prompt, return_full_text=False, **gen_kwargs)[0]["generated_text"]
    assistant_reply = output.strip()

    retriever.add_text(f"User: {user_input}")
    retriever.add_text(f"Assistant: {assistant_reply}")
    history.append([user_input, assistant_reply])
    return history, history

# Build the Gradio interface. Components render in creation order, so the
# statement order below defines the page layout.
with gr.Blocks() as demo:
    gr.Markdown("# QLoRA Fine-tuning & RAG-based Chat Demo using Custom R1 Model")

    # -- Fine-tuning: one button plus a status line. --
    finetune_btn = gr.Button("Finetune 4-bit (QLoRA) on ServiceNow-AI/R1-Distill-SFT subset (up to 5 min)")
    status_box = gr.Textbox(label="Finetune Status")
    finetune_btn.click(fn=finetune_small_subset, outputs=status_box)

    # -- Direct generation. The sampling controls created here are reused by
    #    the comparison and chat sections below. --
    gr.Markdown("## Direct Generation (No Retrieval) using Custom R1")
    prompt_in = gr.Textbox(lines=3, label="Prompt")
    temperature = gr.Slider(minimum=0.0, maximum=1.5, step=0.1, value=0.7, label="Temperature")
    top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.9, label="Top-p")
    min_tokens = gr.Slider(minimum=1, maximum=2500, value=50, step=10, label="Min New Tokens")
    max_tokens = gr.Slider(minimum=1, maximum=2500, value=200, step=50, label="Max New Tokens")
    output_box = gr.Textbox(label="Custom R1 Output", lines=8)
    gen_btn = gr.Button("Generate with Custom R1")
    sampling_controls = [temperature, top_p, min_tokens, max_tokens]
    gen_btn.click(fn=predict, inputs=[prompt_in, *sampling_controls], outputs=output_box)

    # -- Side-by-side comparison against the official distilled model. --
    gr.Markdown("## Compare Custom R1 vs Official R1")
    compare_btn = gr.Button("Compare")
    out_custom = gr.Textbox(label="Custom R1 Output", lines=6)
    out_official = gr.Textbox(label="Official R1 Output", lines=6)
    compare_btn.click(
        fn=compare_models,
        inputs=[prompt_in, *sampling_controls],
        outputs=[out_custom, out_official],
    )

    # -- Retrieval-augmented chat; pressing Enter and clicking Send are equivalent. --
    gr.Markdown("## Chat with Retrieval-Augmented Memory")
    with gr.Row():
        with gr.Column():
            chatbot = gr.Chatbot(label="RAG Chat")
            chat_state = gr.State([])
            user_input = gr.Textbox(show_label=False, placeholder="Ask a question...", lines=2)
            send_btn = gr.Button("Send")
    for register in (user_input.submit, send_btn.click):
        register(
            fn=chat_rag,
            inputs=[user_input, chat_state, *sampling_controls],
            outputs=[chat_state, chatbot],
        )

demo.launch()