import evaluate
import numpy as np
import streamlit as st
import wandb
from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)

from guardrails_genie.utils import StreamlitProgressbarCallback


def train_binary_classifier(
    project_name: str,
    entity_name: str,
    run_name: str,
    dataset_repo: str = "geekyrakshit/prompt-injection-dataset",
    model_name: str = "distilbert/distilbert-base-uncased",
    prompt_column_name: str = "prompt",
    id2label: dict[int, str] = {0: "SAFE", 1: "INJECTION"},
    label2id: dict[str, int] = {"SAFE": 0, "INJECTION": 1},
    learning_rate: float = 1e-5,
    batch_size: int = 16,
    num_epochs: int = 2,
    weight_decay: float = 0.01,
    save_steps: int = 1000,
    streamlit_mode: bool = False,
):
    """
    Trains a binary classifier using a specified dataset and model architecture.

    This function sets up and trains a binary sequence classification model using
    the Hugging Face Transformers library. It integrates with Weights & Biases for
    experiment tracking and optionally displays a progress bar in a Streamlit app.

    Args:
        project_name (str): The name of the Weights & Biases project.
        entity_name (str): The Weights & Biases entity (user or team).
        run_name (str): The name of the Weights & Biases run.
        dataset_repo (str, optional): The Hugging Face dataset repository to load.
        model_name (str, optional): The pre-trained model to use.
        prompt_column_name (str, optional): The column name in the dataset containing
            the text prompts.
        id2label (dict[int, str], optional): Mapping from label IDs to label names.
        label2id (dict[str, int], optional): Mapping from label names to label IDs.
        learning_rate (float, optional): The learning rate for training.
        batch_size (int, optional): The batch size for training and evaluation.
        num_epochs (int, optional): The number of training epochs.
        weight_decay (float, optional): The weight decay for the optimizer.
        save_steps (int, optional): The number of training steps between
            evaluations and model checkpoints.
        streamlit_mode (bool, optional): If True, integrates with Streamlit to display
            a progress bar.

    Returns:
        transformers.trainer_utils.TrainOutput: The output of `trainer.train()`,
            containing the global step count, final training loss, and metrics.

    Raises:
        Exception: If an error occurs during training, the exception is re-raised
            after the Weights & Biases run is finished.
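
    Example:
        A minimal invocation; the project, entity, and run names below are
        illustrative placeholders, not real W&B resources:

        >>> result = train_binary_classifier(
        ...     project_name="prompt-injection-guardrails",
        ...     entity_name="my-wandb-team",
        ...     run_name="distilbert-binary-v1",
        ... )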
    """
    wandb.init(
        project=project_name,
        entity=entity_name,
        name=run_name,
        job_type="train-binary-classifier",
    )
    if streamlit_mode:
        st.markdown(
            f"Explore your training logs on [Weights & Biases]({wandb.run.url})"
        )
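    # Fetch the dataset and the tokenizer that matches the base model.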
    dataset = load_dataset(dataset_repo)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

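    # Tokenize without padding here; DataCollatorWithPadding pads each batch
    # dynamically to the longest sequence in that batch.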
    tokenized_datasets = dataset.map(
        lambda examples: tokenizer(examples[prompt_column_name], truncation=True),
        batched=True,
    )
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    accuracy = evaluate.load("accuracy")

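    # `eval_pred` packs (logits, labels); argmax over the class axis turns
    # logits into predicted label ids for the accuracy metric.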
    def compute_metrics(eval_pred):
        predictions, labels = eval_pred
        predictions = np.argmax(predictions, axis=1)
        return accuracy.compute(predictions=predictions, references=labels)

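    # Load the pretrained backbone with a fresh two-label classification head.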
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=2,
        id2label=id2label,
        label2id=label2id,
    )

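    # Assemble the Trainer; `report_to="wandb"` routes the per-step logs
    # (logging_steps=1) to the run started above.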
    trainer = Trainer(
        model=model,
        args=TrainingArguments(
            output_dir="binary-classifier",
            learning_rate=learning_rate,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            num_train_epochs=num_epochs,
            weight_decay=weight_decay,
            # `load_best_model_at_end=True` requires the eval and save strategies
            # to match, so evaluate and checkpoint on the same step interval.
            eval_strategy="steps",
            eval_steps=save_steps,
            save_strategy="steps",
            save_steps=save_steps,
            load_best_model_at_end=True,
            push_to_hub=False,
            report_to="wandb",
            logging_strategy="steps",
            logging_steps=1,
        ),
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["test"],
        processing_class=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        callbacks=[StreamlitProgressbarCallback()] if streamlit_mode else [],
    )
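    # Run training and always close the W&B run, whether training succeeds or fails.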
    try:
        training_output = trainer.train()
    finally:
        wandb.finish()
    return training_output