import gradio as gr
import torch
import time
import os

from huggingface_hub import hf_hub_download

import spaces
from dotenv import load_dotenv
from infer import (
    load_trained_model,
    find_answer_start,
    get_noising_schedule,
    noisify_answer,
    filter_logits,
    confidence_guided_noising,
    noisify_answer_without_remasking
)
from models import CustomTransformerModel
from model_config import CustomTransformerConfig

# Load .env only when running locally
if os.getenv("HF_TOKEN") is None:
    load_dotenv()

hf_token = os.getenv("HF_TOKEN")

if hf_token is None:
    raise ValueError("HF_TOKEN is not set")


@spaces.GPU
def generate_diffusion_text(input_ids, top_p, top_k, eos_bias=0.0):
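    """One denoising step: forward pass under fp16 autocast, optional EOS-logit
    bias, top-k/top-p filtering, then multinomial sampling. Returns the sampled
    token ids and the model's probability (confidence) for each sampled token."""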
    with torch.no_grad():
        input_tensor = torch.tensor([input_ids], dtype=torch.long, device=model.device)
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            logits = model(input_ids=input_tensor)["logits"]

        # Apply eos_bias
        if eos_bias != 0.0:
            logits[0, :, eos_token_id] += eos_bias

        logits = filter_logits(logits, top_k=top_k, top_p=top_p)
        # Clamp to keep logits finite for softmax (filtered entries may be -inf)
        logits = logits.clamp(min=-1e8, max=1e4)
        probs = torch.nn.functional.softmax(logits, dim=-1)[0]
        probs = torch.clamp(probs, min=1e-8, max=1.0)
        assert torch.all(torch.isfinite(probs)), "Non-finite values in probs!"
        assert (probs >= 0).all(), "Negative probs!"
        sampled = torch.multinomial(probs, num_samples=1).squeeze(-1).tolist()

        # Extract confidence of selected tokens
        conf = probs[range(len(sampled)), sampled].cpu().numpy()
    return sampled, conf 

def format_chat_prompt(question):
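    """Wrap the question in Llama-3-style header tokens (no <|eot_id|>
    terminators); the assistant header marks where the answer begins."""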
    return (
        "<|begin_of_text|>\n"
        "<|start_header_id|>system<|end_header_id|>\n"
        "You are a helpful assistant.\n"
        "<|start_header_id|>user<|end_header_id|>\n"
        f"{question}\n"
        "<|start_header_id|>assistant<|end_header_id|>\n"
    )

def render_html(label, text):
    return f"<b>{label}</b><br><div style='white-space: pre-wrap; line-height:1.8'>{text}</div>"

def highlight_tokens(token_ids, answer_start, changed_indices, color):
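    """Render token ids as HTML, skipping EOS tokens and coloring tokens whose
    absolute positions (answer_start + index) appear in changed_indices."""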
    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    highlighted = []
    for j, tok in enumerate(tokens):
        if tokenizer.convert_tokens_to_ids(tok) == eos_token_id:
            continue
        tok_str = tokenizer.convert_tokens_to_string([tok])
        if (answer_start + j) in changed_indices:
            highlighted.append(f'<span style="color:{color}">{tok_str}</span>')
        else:
            highlighted.append(tok_str)
    return "".join(highlighted)

def diffusion_chat(question, max_it, pause_length, eos_bias, sharpness,
                   noise_start, use_confidence_noising,
                   use_permanent_unmasking, noise_clipping, top_k,
                   top_p, added_tokens):
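    """Stream HTML snapshots of the diffusion process: start from a fully
    masked answer, then alternate model denoising with partial re-noising
    until the output converges or max_it is reached."""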

    # The slider is "generation length": higher values should lower the EOS
    # probability, so flip the sign before adding the bias to the EOS logit.
    eos_bias = -eos_bias
    if question.strip() == "":
        question = "What do you know about the city of Amsterdam?"

    prompt = format_chat_prompt(question)
    input_ids = tokenizer.encode(prompt, add_special_tokens=False)
    answer_start = find_answer_start(input_ids, assistant_marker_ids)
    if answer_start is None:
        yield render_html("Error", "Could not find Assistant marker in input.")
        return

    # Pad the prompt with mask tokens up to the fixed 256-token context
    # (truncate if it is already longer)
    input_ids = (input_ids + [mask_token_id] * (256 - len(input_ids)))[:256]
    ori_input_tokens = input_ids

    # Initial noising: fully mask the answer region (threshold=1.0, noise_start=1.0)
    current_tokens, just_noised_indices = noisify_answer(
        input_ids, answer_start, tokenizer, threshold=1.0, noise_start=1.0
    )
    yield render_html("Iteration 0 (initial noise)",
                      highlight_tokens(current_tokens[answer_start:], answer_start, just_noised_indices, color="red"))
    time.sleep(pause_length)

    last_tokens = []
    prev_decoded = []

    unmasked_mask = [False] * len(current_tokens)
    current_tokens = current_tokens[:answer_start]
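    # Semi-autoregressive loop: each iteration appends `added_tokens` masked
    # positions (capped at the 256-token context), denoises the full sequence,
    # then re-noises part of the answer according to the schedule.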
    for i in range(max_it):
        current_tokens = current_tokens + [mask_token_id] * added_tokens
        current_tokens = current_tokens[:256]  # Ensure we don't exceed the max length
        generated_tokens, confidences = generate_diffusion_text(current_tokens, top_p, top_k, eos_bias=eos_bias)
        # Keep the original prompt intact; take the model's output only for the answer region
        current_tokens = ori_input_tokens[:answer_start] + generated_tokens[answer_start:]

        # GREEN highlighting: compare to previous tokens
        new_decoded = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
        diff_indices = {
            answer_start + j for j, tok in enumerate(new_decoded)
            if j >= len(prev_decoded) or tok != prev_decoded[j]
        }
        prev_decoded = new_decoded

        yield render_html(f"Iteration {i+1}/{max_it} (after generation)",
                          highlight_tokens(current_tokens[answer_start:], answer_start, diff_indices, color="green"))
        time.sleep(pause_length)

        # Early stopping: halt once three consecutive iterations produce identical tokens
        last_tokens.append(current_tokens)
        if len(last_tokens) > 3:
            last_tokens.pop(0)
        if len(last_tokens) == 3 and last_tokens[0] == last_tokens[1] == last_tokens[2]:
            yield render_html("Stopped early", f"After {i+1} iterations.")
            break

        # Re-noising: on every iteration except the last, re-mask part of the answer
        if i < max_it - 1:
            threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
            if use_confidence_noising:
                noised_answer, just_noised_indices = confidence_guided_noising(
                    current_tokens, answer_start, tokenizer, confidences, noise_clipping,
                    threshold=threshold, noise_start=noise_start
                )
            elif use_permanent_unmasking:
                noised_answer, just_noised_indices = noisify_answer_without_remasking(
                    current_tokens, answer_start, tokenizer, threshold=threshold,
                    noise_start=noise_start, unmasked_mask=unmasked_mask
                )
            else:
                noised_answer, just_noised_indices = noisify_answer(
                    current_tokens, answer_start, tokenizer,
                    threshold=threshold, noise_start=noise_start
                )
            
            # Record positions that survived this noising pass so permanent
            # unmasking can avoid re-masking them in later iterations
            for idx in range(answer_start, len(current_tokens)):
                if noised_answer[idx] != mask_token_id:
                    unmasked_mask[idx] = True

            yield render_html(f"Iteration {i+1}/{max_it} (before noising)",
                              highlight_tokens(current_tokens[answer_start:], answer_start, just_noised_indices, color="red"))

            current_tokens = ori_input_tokens[:answer_start] + noised_answer[answer_start:]

    # Final output
    answer_ids = current_tokens[answer_start:]
    try:
        final_ids = answer_ids[:answer_ids.index(eos_token_id)]
    except ValueError:
        final_ids = answer_ids

    final_output = tokenizer.decode(final_ids, skip_special_tokens=True)
    yield render_html(f"Final Output ({len(final_ids)} tokens after {i+1} iterations)", final_output) # type: ignore


def is_running_on_spaces():
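    """True when running on Hugging Face Spaces, which sets SPACE_ID."""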
    return os.getenv("SPACE_ID") is not None

print("Loading model...")

if is_running_on_spaces():
    # Load from Hugging Face Hub
    ckpt_path = hf_hub_download(
        repo_id="ruurd/tini_model",
        filename="diffusion-model-8B.pth",
        token=os.getenv("HF_TOKEN")
    )
else:
    # Load from local path
    ckpt_path = "diffusion-model-3B.pth"  # change to your actual local path

model, tokenizer = load_trained_model(checkpoint_path=ckpt_path)
print("✅ Model loaded.")

# Special-token ids for the diffusion loop: EOS for final truncation, the
# literal string 'MASK' as the noise token, and the assistant header ids
# for locating the start of the answer region.
eos_token_id = tokenizer.eos_token_id
mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]
assistant_marker_ids = tokenizer.encode("<|start_header_id|>assistant<|end_header_id|>", add_special_tokens=False)

demo = gr.Interface(
    fn=diffusion_chat,
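    # Slider order below must match diffusion_chat's positional parameters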
    inputs=[
        gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of Amsterdam?"),
        gr.Slider(1, 512, value=64, step=1, label="Number of iterations: ↑ = more iterations"),
        gr.Slider(0.01, 5, value=0.01, step=0.01, label="Pause between iterations (seconds): ↑ = longer pause"),
        gr.Slider(-5.0, 5.0, value=0.0, step=0.1, label="Generation length: ↑ = more output tokens (lowers EOS token probability)"),
        gr.Slider(1.0, 20.0, value=1.0, step=0.5, label="Noise decay sharpness: ↓ = more noise in later iterations"),
        gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Noise start fraction: ↑ = more noise"),
        gr.Checkbox(value=False, label="Use confidence-guided noising"),
        gr.Checkbox(value=False, label="Use permanent unmasking"),
        gr.Slider(0.01, 1.0, value=0.01, step=0.01, label="Noise clipping: ↓ = more confidence guidance"),
        gr.Slider(1, 1000, value=3, step=1, label="Top-k: ↑ = more random answers"),
        gr.Slider(0.0, 1.0, value=1.0, step=0.01, label="Top-p: ↑ = more random answers"),
        gr.Slider(1, 256, value=256, step=1, label="Semi-autoregressive generation: number of added tokens per iteration"),
    ],
    outputs=[gr.HTML(label="Diffusion Output")],
    title="Diffusion Language Model Chat",
    theme="default",
    description="This interface runs a diffusion-based language model to generate answers progressively."
)

demo.launch(share=True, allowed_paths=["."], ssr_mode=False)