import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os

# --- Konfiguration ---
MODEL_ID = "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B"
HF_TOKEN = os.getenv("HF_TOKEN") # Optional: Für private Modelle oder Zugriffsbeschränkungen

# --- Lade Modell und Tokenizer (explizit auf CPU) ---
print(f"Lade Tokenizer: {MODEL_ID}")
# Stelle sicher, dass trust_remote_code=True gesetzt ist, da Qwen3 dies oft benötigt
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)

if tokenizer.pad_token is None:
    print("pad_token nicht gesetzt, verwende eos_token als pad_token.")
    tokenizer.pad_token = tokenizer.eos_token

print(f"Lade Modell: {MODEL_ID} auf CPU. Dies kann einige Zeit dauern...")
try:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
        trust_remote_code=True,
        token=HF_TOKEN
    )
except Exception as e:
    print(f"Fehler beim Laden mit bfloat16 ({e}), versuche float32...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,
        device_map="cpu",
        trust_remote_code=True,
        token=HF_TOKEN
    )

model.eval()
print("Modell und Tokenizer erfolgreich geladen.")

# --- Vorhersagefunktion für das ChatInterface ---
def predict(message, history):
    messages_for_template = []
    for user_msg, ai_msg in history: # history ist jetzt eine Liste von Listen/Tupeln
        messages_for_template.append({"role": "user", "content": user_msg})
        messages_for_template.append({"role": "assistant", "content": ai_msg})
    messages_for_template.append({"role": "user", "content": message})

    try:
        prompt = tokenizer.apply_chat_template(
            messages_for_template,
            tokenize=False,
            add_generation_prompt=True
        )
    except Exception as e:
        print(f"Fehler beim Anwenden des Chat-Templates: {e}")
        prompt_parts = []
        for turn in messages_for_template:
            prompt_parts.append(f"<|im_start|>{turn['role']}\n{turn['content']}<|im_end|>")
        prompt = "\n".join(prompt_parts) + "\n<|im_start|>assistant\n"

    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to("cpu")

    generation_kwargs = {
        "max_new_tokens": 512,
        "temperature": 0.7,
        "top_p": 0.9,
        "top_k": 50,
        "do_sample": True,
        "pad_token_id": tokenizer.eos_token_id,
    }

    print("Generiere Antwort...")
    with torch.no_grad():
        outputs = model.generate(**inputs, **generation_kwargs)

    response_ids = outputs[0][inputs.input_ids.shape[-1]:]
    response = tokenizer.decode(response_ids, skip_special_tokens=True)
    print(f"Antwort: {response}")
    return response

# --- Gradio UI ---
with gr.Blocks(theme=gr.themes.Soft(), title="DeepSeek Qwen3 8B (CPU)") as demo:
    gr.Markdown(
        """
        # DeepSeek Qwen3 8B Chat (CPU)
        Dies ist eine Demo des `deepseek-ai/DeepSeek-R1-0528-Qwen3-8B` Modells, das auf einer CPU läuft.
        **Achtung:** Antworten können aufgrund der CPU-Inferenz **sehr langsam** sein (mehrere Minuten pro Antwort sind möglich).
        Bitte habe Geduld.
        """
    )
    chatbot_interface = gr.ChatInterface(
        fn=predict,
        chatbot=gr.Chatbot(
            height=600,
            label="Chat",
            show_label=False,
            # bubble_full_width=False, # Entfernt, da veraltet
            # type="messages" # Wichtig, um die Warnung zu beheben, aber history-Format in predict() muss passen
                            # Da predict bereits die history als [[user, ai], [user, ai]] erwartet (Standard für ChatInterface),
                            # lassen wir type hier weg, damit es mit dem Format von predict harmoniert.
                            # Wenn predict `history` als [{"role": "user", ...}, {"role": "assistant", ...}] erwarten würde,
                            # dann wäre `type="messages"` hier richtig.
                            # Da die Warnung sich auf die Standardeinstellung bezieht, die bald "messages" sein wird,
                            # und unsere predict-Funktion bereits das "tuples"-Format verarbeitet, ist das OK für jetzt.
                            # Man könnte predict anpassen, um das "messages" Format direkt zu verarbeiten, wenn man type="messages" setzt.
        ),
        textbox=gr.Textbox(
            placeholder="Stelle mir eine Frage...",
            container=False,
            scale=7
        ),
        examples=[
            ["Hallo, wer bist du?"],
            ["Was ist die Hauptstadt von Frankreich?"],
            ["Schreibe ein kurzes Gedicht über KI."]
        ],
        # Entferne die nicht unterstützten Button-Argumente:
        # retry_btn="Wiederholen",
        # undo_btn="Letzte entfernen",
        # clear_btn="Chat löschen",
    )
    gr.Markdown("Modell von [deepseek-ai](https://huggingface.co/deepseek-ai) auf Hugging Face.")

if __name__ == "__main__":
    demo.launch()