import gradio as gr
import time
from transformers import pipeline
from datasets import load_dataset

# Загружаем датасет
dataset = load_dataset("Romjiik/Russian_bank_reviews", split="train")

# Примеры для few-shot
few_shot_examples = []
for row in dataset.select(range(3)):
    review = row["review"]
    category = row["category"] if "category" in row else "(Категория)"
    ex = f"Клиент: {review}\nКлассификация: {category}"
    few_shot_examples.append(ex)

# Инструкции
cot_instruction = (
    "Ты — помощник банка. Клиент задал вопрос. Проанализируй обращение шаг за шагом, "
    "выдели ключевые признаки и выдай итоговую категорию обращения."
)

simple_instruction = (
    "Ты — помощник банка. Определи категорию обращения клиента. Ответ должен быть кратким, без лишнего текста."
)

# Используемые модели
models = {
    "ChatGPT-like (ruGPT3small)": pipeline("text-generation", model="ai-forever/rugpt3small_based_on_gpt2", tokenizer="ai-forever/rugpt3small_based_on_gpt2", device=-1),
    "GigaChat-like (ruDialoGPT-medium)": pipeline("text-generation", model="t-bank-ai/ruDialoGPT-medium", tokenizer="t-bank-ai/ruDialoGPT-medium", device=-1),
    "DeepSeek-like (RuBERT-tiny2)": pipeline("text-classification", model="cointegrated/rubert-tiny2", tokenizer="cointegrated/rubert-tiny2", device=-1)
}

# Формирование промптов

def build_cot_prompt(user_input):
    examples = "\n\n".join(few_shot_examples)
    return (
        f"{cot_instruction}\n\n{examples}\n\nКлиент: {user_input}\nРассуждение и классификация:"
    )

def build_simple_prompt(user_input):
    examples = "\n\n".join(few_shot_examples)
    return (
        f"{simple_instruction}\n\n{examples}\n\nКлиент: {user_input}\nКлассификация:"
    )

# Генерация ответов

def generate_dual_answers(user_input):
    results = {}
    prompt_cot = build_cot_prompt(user_input)
    prompt_simple = build_simple_prompt(user_input)

    for name, pipe in models.items():
        if name.startswith("DeepSeek"):
            # классификация
            start = time.time()
            output = pipe(user_input)[0]
            end = round(time.time() - start, 2)
            results[name] = {
                "cot_answer": output['label'],
                "cot_time": end,
                "simple_answer": output['label'],
                "simple_time": end
            }
        else:
            # генерация CoT
            start_cot = time.time()
            out_cot = pipe(prompt_cot, max_new_tokens=100, do_sample=True, top_p=0.9, temperature=0.7)[0]["generated_text"]
            end_cot = round(time.time() - start_cot, 2)
            answer_cot = out_cot.split("Классификация:")[-1].strip()

            # генерация Simple
            start_simple = time.time()
            out_simple = pipe(prompt_simple, max_new_tokens=60, do_sample=True, top_p=0.9, temperature=0.7)[0]["generated_text"]
            end_simple = round(time.time() - start_simple, 2)
            answer_simple = out_simple.split("Классификация:")[-1].strip()

            results[name] = {
                "cot_answer": answer_cot,
                "cot_time": end_cot,
                "simple_answer": answer_simple,
                "simple_time": end_simple
            }

    return (
        results["ChatGPT-like (ruGPT3small)"]["cot_answer"], f"{results['ChatGPT-like (ruGPT3small)']['cot_time']} сек",
        results["ChatGPT-like (ruGPT3small)"]["simple_answer"], f"{results['ChatGPT-like (ruGPT3small)']['simple_time']} сек",
        results["GigaChat-like (ruDialoGPT-medium)"]["cot_answer"], f"{results['GigaChat-like (ruDialoGPT-medium)']['cot_time']} сек",
        results["GigaChat-like (ruDialoGPT-medium)"]["simple_answer"], f"{results['GigaChat-like (ruDialoGPT-medium)']['simple_time']} сек",
        results["DeepSeek-like (RuBERT-tiny2)"]["cot_answer"], f"{results['DeepSeek-like (RuBERT-tiny2)']['cot_time']} сек",
        results["DeepSeek-like (RuBERT-tiny2)"]["simple_answer"], f"{results['DeepSeek-like (RuBERT-tiny2)']['simple_time']} сек"
    )

# Gradio интерфейс
with gr.Blocks() as demo:
    gr.Markdown("## 🛡️ Детектирование мошеннических обращений")

    inp = gr.Textbox(label="Вопрос клиента", placeholder="Например: Я не могу войти в приложение — пишет, что пароль неверный", lines=2)
    btn = gr.Button("Классифицировать")

    gr.Markdown("### ChatGPT-like (ruGPT3small)")
    cot1 = gr.Textbox(label="CoT ответ")
    cot1_time = gr.Textbox(label="Время CoT")
    simple1 = gr.Textbox(label="Zero-shot ответ")
    simple1_time = gr.Textbox(label="Время Zero-shot")

    gr.Markdown("### GigaChat-like (ruDialoGPT-medium)")
    cot2 = gr.Textbox(label="CoT ответ")
    cot2_time = gr.Textbox(label="Время CoT")
    simple2 = gr.Textbox(label="Zero-shot ответ")
    simple2_time = gr.Textbox(label="Время Zero-shot")

    gr.Markdown("### DeepSeek-like (RuBERT-tiny2)")
    cot3 = gr.Textbox(label="CoT ответ")
    cot3_time = gr.Textbox(label="Время CoT")
    simple3 = gr.Textbox(label="Zero-shot ответ")
    simple3_time = gr.Textbox(label="Время Zero-shot")

    btn.click(generate_dual_answers, inputs=[inp], outputs=[
        cot1, cot1_time, simple1, simple1_time,
        cot2, cot2_time, simple2, simple2_time,
        cot3, cot3_time, simple3, simple3_time
    ])

demo.launch()