from collections import defaultdict, deque
from difflib import ndiff

import gradio as gr
import pandas as pd
from datasets import load_dataset, Dataset
from gradio_huggingfacehub_search import HuggingfaceHubSearch

from semhash import SemHash
from semhash.datamodels import DeduplicationResult

from model2vec import StaticModel

# Default parameters
# Values pre-filled into the UI widgets built further down in this file.
default_dataset_name = "SetFit/amazon_massive_scenario_en-US"
default_dataset1_split = "train"
default_dataset2_split = "test"
default_text_column = "text"
# Initial position of the similarity slider and the default `threshold`
# argument of perform_deduplication.
default_threshold = 0.9

# Load the model to use
# Static Model2Vec encoder, loaded once at import time and shared by every
# SemHash index built in this module.
model = StaticModel.from_pretrained("minishlab/potion-base-8M")


def display_word_differences(x: str, y: str) -> str:
    """
    Render the word-level diff between ``x`` and ``y`` as a fenced code block.

    The texts are split on whitespace and compared with ``difflib.ndiff``.
    Only added (``+``) and removed (``-``) words are kept. Unchanged words
    and ``?`` hint lines are dropped. Wrapping the result in a code fence
    keeps Markdown characters inside the words from being rendered.
    """
    changed_lines = [
        line
        for line in ndiff(x.split(), y.split())
        if line[:1] in ("+", "-")
    ]
    body = "\n".join(changed_lines)
    return "```\n" + body + "\n```"


def load_dataset_texts(
    dataset_name: str, dataset_split: str, text_column: str
) -> tuple[list[str], Dataset]:
    """
    Load one split of a Hub dataset and extract its text column.

    Args:
        dataset_name: Hub id of the dataset to load.
        dataset_split: Name of the split to load (e.g. ``"train"``).
        text_column: Column holding the texts to deduplicate.

    Returns:
        ``(texts, dataset)``: the values of ``text_column`` as a list in row
        order, and the loaded split itself.

    Raises:
        ValueError: if ``text_column`` is not a column of the loaded split.
    """
    ds = load_dataset(dataset_name, split=dataset_split)
    # Fail early with an actionable message. Otherwise a typo in the column
    # name surfaces as a bare KeyError from deep inside row iteration.
    if text_column not in ds.column_names:
        raise ValueError(
            f"Column '{text_column}' not found in {dataset_name} "
            f"[{dataset_split}]. Available columns: {', '.join(ds.column_names)}"
        )
    # Columnar access reads only the needed column instead of decoding every
    # field of every row.
    return list(ds[text_column]), ds


def deduplicate_single_dataset(
    texts: list[str], threshold: float
) -> DeduplicationResult:
    """
    Remove near-duplicates from a single collection of texts.

    Every text is indexed as a plain string record using the module-level
    ``model``, and the index is then deduplicated against itself.

    Args:
        texts: The texts to deduplicate.
        threshold: Similarity threshold forwarded to SemHash.

    Returns:
        The SemHash ``DeduplicationResult`` for the self-deduplication.
    """
    return SemHash.from_records(records=texts, model=model).self_deduplicate(
        threshold=threshold
    )


def deduplicate_two_datasets(
    texts1: list[str], texts2: list[str], threshold: float
) -> DeduplicationResult:
    """
    Remove from ``texts2`` every text that near-duplicates something in ``texts1``.

    ``texts1`` forms the reference index (built with the module-level
    ``model``), and ``texts2`` is checked against it. Both are treated as
    raw string records.

    Args:
        texts1: Reference texts (dataset 1).
        texts2: Texts to filter (dataset 2).
        threshold: Similarity threshold forwarded to SemHash.

    Returns:
        The SemHash ``DeduplicationResult`` describing ``texts2``.
    """
    reference_index = SemHash.from_records(records=texts1, model=model)
    return reference_index.deduplicate(records=texts2, threshold=threshold)


def create_deduplicated_dataset(
    original_dataset: Dataset, deduplicated_texts: list[str], text_column: str
) -> Dataset:
    """
    Build the subset of ``original_dataset`` whose texts survived deduplication.

    Each entry of ``deduplicated_texts`` is matched to a distinct row of
    ``original_dataset`` with the same value in ``text_column``. Matching is
    first-unused-row-first, so rows that share a text but differ in other
    columns are not collapsed into copies of one row. Texts with no matching
    row are skipped.

    Args:
        original_dataset: The split the texts were extracted from.
        deduplicated_texts: Texts kept by SemHash.
        text_column: Column of ``original_dataset`` holding the texts.

    Returns:
        A dataset of the selected rows, in the order of ``deduplicated_texts``.
    """
    # text -> queue of row indices holding that text, in original row order.
    positions: defaultdict[str, deque[int]] = defaultdict(deque)
    for idx, text in enumerate(original_dataset[text_column]):
        positions[text].append(idx)

    selected: list[int] = []
    for text in deduplicated_texts:
        candidates = positions.get(text)  # .get: don't grow the defaultdict
        if candidates:
            selected.append(candidates.popleft())

    # select() keeps the original features (e.g. ClassLabel names) and
    # schema, even when nothing is selected. Rebuilding rows with
    # Dataset.from_list would re-infer them.
    return original_dataset.select(selected)


def _truncate(text: str, limit: int = 200) -> str:
    """Shorten ``text`` to ``limit`` characters, appending "..." when it is cut."""
    return text[:limit] + "..." if len(text) > limit else text


def _sort_matches_by_score(result: DeduplicationResult) -> None:
    """Sort each removed record's matches in place by score, least similar first."""
    for duprec in result.duplicates:
        duprec.duplicates.sort(key=lambda match: match[1])


def _build_examples_table(
    result: DeduplicationResult, columns: list[str], max_examples: int = 5
) -> pd.DataFrame | None:
    """
    Tabulate the ``max_examples`` least similar (matched, removed, score) triples.

    Expects each record's matches to be sorted ascending by score already,
    so ``duprec.duplicates[0]`` is that record's least similar match.
    Returns ``None`` when no removed record carries a match to display.
    """
    with_matches = [duprec for duprec in result.duplicates if duprec.duplicates]
    if not with_matches:
        return None
    # Rank records by their least similar match, so the preview shows the
    # borderline cases closest to the threshold.
    with_matches.sort(key=lambda duprec: duprec.duplicates[0][1])
    rows = []
    for duprec in with_matches[:max_examples]:
        orig_text, score = duprec.duplicates[0]
        rows.append([_truncate(orig_text), _truncate(duprec.record), f"{score:.4f}"])
    return pd.DataFrame(rows, columns=columns)


def _examples_update(examples_table: pd.DataFrame | None):
    """Show the examples table when it has rows; otherwise hide it."""
    if examples_table is not None and not examples_table.empty:
        return gr.update(visible=True, value=examples_table)
    return gr.update(visible=False)


def perform_deduplication(
    deduplication_type: str,
    dataset1_name: str,
    dataset1_split: str,
    dataset1_text_column: str,
    dataset2_name: str = "",
    dataset2_split: str = "",
    dataset2_text_column: str = "",
    threshold: float = default_threshold,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
):
    """
    Deduplicate one dataset, or dataset 2 against dataset 1, using SemHash.

    Args:
        deduplication_type: ``"Single dataset"`` deduplicates dataset 1
            against itself. Any other value deduplicates dataset 2 against
            dataset 1.
        dataset1_name, dataset1_split, dataset1_text_column: Dataset 1 source.
        dataset2_name, dataset2_split, dataset2_text_column: Dataset 2 source
            (only used in cross-dataset mode).
        threshold: Similarity threshold (coerced to ``float``).
        progress: Gradio progress tracker. ``track_tqdm=True`` presumably
            mirrors tqdm bars from loading/encoding into the UI. It is not
            called directly here.

    Returns:
        ``(deduplicated_dataset, table_update)``. The first item is the
        deduplicated split: dataset 1 in single mode, dataset 2 in cross
        mode, or ``None`` on failure. The second is a ``gr.update`` that
        shows the examples table when there is something to show.
    """
    try:
        threshold = float(threshold)

        texts1, dataset1 = load_dataset_texts(
            dataset1_name, dataset1_split, dataset1_text_column
        )

        if deduplication_type == "Single dataset":
            result = deduplicate_single_dataset(texts1, threshold=threshold)
            source_dataset, text_column = dataset1, dataset1_text_column
            total_docs = len(texts1)
            columns = ["Original Text", "Duplicate Text", "Similarity Score"]
            location = ""
        else:
            texts2, dataset2 = load_dataset_texts(
                dataset2_name, dataset2_split, dataset2_text_column
            )
            result = deduplicate_two_datasets(texts1, texts2, threshold=threshold)
            source_dataset, text_column = dataset2, dataset2_text_column
            total_docs = len(texts2)
            columns = [
                "Original Text (Dataset 1)",
                "Duplicate Text (Dataset 2)",
                "Similarity Score",
            ]
            location = " in Dataset 2"

        _sort_matches_by_score(result)
        deduplicated_dataset = create_deduplicated_dataset(
            source_dataset, result.deduplicated, text_column
        )
        examples_table = _build_examples_table(result, columns)

        gr.Info(
            f"Deduplication completed! Found {len(result.duplicates)} duplicates{location}. "
            f"Dataset reduced from {total_docs} to {len(result.deduplicated)} unique documents."
        )
        return deduplicated_dataset, _examples_update(examples_table)

    except Exception as e:
        # A bare `gr.Error(...)` without `raise` is never shown to the user.
        # gr.Warning displays the message and still lets us reset the outputs,
        # so a dataset from a previous run cannot be pushed by mistake.
        gr.Warning(f"An error occurred during deduplication: {e}")
        return None, gr.update(visible=False)


def push_to_hub(
    deduplicated_dataset: Dataset,
    output_dataset_name: str,
    oauth_profile: gr.OAuthProfile | None,
    oauth_token: gr.OAuthToken | None,
    progress: gr.Progress = gr.Progress(),
) -> str:
    """
    Push the deduplicated dataset to the Hugging Face Hub as a public dataset.

    Args:
        deduplicated_dataset: Dataset produced by perform_deduplication.
        output_dataset_name: Target repo name. A bare name is prefixed with
            the logged-in user's namespace.
        oauth_profile: Logged-in user's profile, or ``None``.
        oauth_token: Logged-in user's OAuth token, or ``None``.
        progress: Gradio progress tracker.

    Returns:
        Markdown linking to the published dataset.

    Raises:
        gr.Error: if the user is not logged in, the name is blank, there is
            no dataset to push, or the upload fails.
    """
    if oauth_token is None:
        raise gr.Error("Please log in with Hugging Face to push datasets to the Hub.")

    # Use the stripped name everywhere. Validating the stripped value but
    # pushing the raw one would let stray whitespace into the repo id.
    dataset_name = (output_dataset_name or "").strip()
    if not dataset_name:
        raise gr.Error("Please provide a dataset name.")

    if deduplicated_dataset is None:
        raise gr.Error(
            "No deduplicated dataset available. Please run deduplication first."
        )

    progress(0.1, desc="Preparing dataset...")

    # Determine the full dataset name (username/dataset_name)
    username = oauth_profile.username if oauth_profile else None
    if "/" not in dataset_name and username:
        full_dataset_name = f"{username}/{dataset_name}"
    else:
        full_dataset_name = dataset_name

    progress(0.3, desc="Pushing to Hub...")

    try:
        deduplicated_dataset.push_to_hub(
            full_dataset_name, token=oauth_token.token, private=False
        )
    except Exception as e:
        # Chain the cause so the original traceback stays in the server logs.
        raise gr.Error(f"Failed to push dataset to Hub: {e}") from e

    progress(1.0, desc="Complete!")

    gr.Info(
        f"Successfully pushed deduplicated dataset with {len(deduplicated_dataset)} rows to the Hub!"
    )

    return (
        f"✅ **Dataset published:** [{full_dataset_name}]"
        f"(https://huggingface.co/datasets/{full_dataset_name})"
    )


def get_user_info(oauth_profile: gr.OAuthProfile | None) -> str:
    """
    Return a Markdown line describing the current login state.

    Args:
        oauth_profile: Logged-in user's profile, or ``None`` when logged out.
    """
    if oauth_profile is not None:
        return f"Logged in as: **{oauth_profile.username}**"
    return "Not logged in. Please log in to push datasets to the Hub."


def update_push_button_state(oauth_profile: gr.OAuthProfile | None):
    """
    Enable the "Push to Hub" button only when a user is logged in.

    Returns a ``gr.update`` that sets the button's ``interactive`` flag.
    """
    return gr.update(interactive=oauth_profile is not None)


# --- Gradio App ---
# Builds the UI, wires its events, and launches the server. All of this runs
# at import time.
with gr.Blocks(
    theme=gr.themes.Ocean(), css="#status_output { height: 50px; overflow: auto; }"
) as demo:
    gr.Markdown("# SemDedup-My-Dataset: Semantic Text Deduplication Using SemHash")
    gr.Markdown("""
    This demo showcases **semantic deduplication** using [SemHash](https://github.com/MinishLab/semhash) for HuggingFace datasets, using a [Model2Vec](https://github.com/MinishLab/model2vec) encoder.
    It can be used to identify duplicate texts within a **single dataset** or across **two datasets**.
    You can adjust the similarity threshold to control the strictness of the deduplication.

    """)

    # Chooses the branch taken in perform_deduplication. The Dataset 2
    # inputs are shown only for "Cross-dataset" (see update_visibility).
    deduplication_type = gr.Radio(
        choices=["Cross-dataset", "Single dataset"],
        label="Deduplication Type",
        value="Cross-dataset",  # default
    )

    # Dataset 1: deduplicated against itself in single mode, and used as the
    # reference in cross mode.
    with gr.Row():
        dataset1_name = HuggingfaceHubSearch(
            label="Dataset 1 Name",
            placeholder="Search for datasets on HuggingFace Hub",
            search_type="dataset",
            value=default_dataset_name,
        )
        dataset1_split = gr.Textbox(
            value=default_dataset1_split, label="Dataset 1 Split"
        )
        dataset1_text_column = gr.Textbox(
            value=default_text_column, label="Text Column Name"
        )

    # Dataset 2: the dataset filtered in cross mode. Starts visible because
    # the radio defaults to "Cross-dataset".
    dataset2_inputs = gr.Column(visible=True)
    with dataset2_inputs:
        with gr.Row():
            dataset2_name = HuggingfaceHubSearch(
                label="Dataset 2 Name",
                placeholder="Search for datasets on HuggingFace Hub",
                search_type="dataset",
                value=default_dataset_name,
            )
            dataset2_split = gr.Textbox(
                value=default_dataset2_split, label="Dataset 2 Split"
            )
            dataset2_text_column = gr.Textbox(
                value=default_text_column, label="Text Column Name"
            )

    threshold = gr.Slider(
        0.0, 1.0, value=default_threshold, label="Similarity Threshold"
    )

    with gr.Row():
        compute_button = gr.Button("Deduplicate", variant="primary")

    # NOTE(review): status_output is styled by the CSS above but is not an
    # output of any event in this file. Confirm whether it is still needed.
    status_output = gr.Markdown(elem_id="status_output")

    # Examples table
    # Filled, shown or hidden through the gr.update that
    # perform_deduplication returns as its second output.
    examples_table = gr.Dataframe(
        headers=["Original Text", "Duplicate Text", "Similarity Score"],
        datatype=["str", "str", "str"],
    )

    # Hidden state to store the deduplicated dataset
    # Written by perform_deduplication and read by push_to_hub.
    deduplicated_dataset_state = gr.State()

    # Output dataset configuration
    gr.Markdown("## Push Deduplicated Dataset to Hub")
    with gr.Row():
        with gr.Column():
            output_dataset_name = gr.Textbox(
                label="Output Dataset Name",
                placeholder="my-deduplicated-dataset",
                info="Will be saved as username/dataset-name",
            )
        with gr.Column():
            # Disabled until update_push_button_state sees a logged-in user.
            push_button = gr.Button(
                "Push to Hub", variant="secondary", interactive=False
            )
            login_button = gr.LoginButton()

    # Login status (user_info) and push result message (push_output).
    with gr.Row():
        user_info = gr.Markdown()
        push_output = gr.Markdown()

    # HACK: for some reason gradio wants this.
    # NOTE(review): the reason for this call is not documented. Confirm it is
    # still required with the installed Gradio version before removing it.
    login_button.activate()
    
    def update_visibility(choice: str):
        # Show the Dataset 2 inputs only in cross-dataset mode.
        return gr.update(visible=(choice == "Cross-dataset"))

    deduplication_type.change(
        update_visibility, inputs=deduplication_type, outputs=dataset2_inputs
    )

    # Update user info and button state when page loads or login status changes
    demo.load(get_user_info, inputs=None, outputs=user_info)
    demo.load(update_push_button_state, inputs=None, outputs=push_button)
    login_button.click(get_user_info, inputs=None, outputs=user_info)
    login_button.click(update_push_button_state, inputs=None, outputs=push_button)

    # Inputs are passed positionally in perform_deduplication's parameter order.
    compute_button.click(
        fn=perform_deduplication,
        inputs=[
            deduplication_type,
            dataset1_name,
            dataset1_split,
            dataset1_text_column,
            dataset2_name,
            dataset2_split,
            dataset2_text_column,
            threshold,
        ],
        outputs=[deduplicated_dataset_state, examples_table],
    )

    # oauth_profile / oauth_token are not listed as inputs. Gradio is expected
    # to inject them from the gr.OAuthProfile / gr.OAuthToken type hints on
    # push_to_hub.
    push_button.click(
        fn=push_to_hub,
        inputs=[
            deduplicated_dataset_state,
            output_dataset_name,
        ],
        outputs=push_output,
    )

# Start the server with Gradio's server-side rendering mode turned off.
demo.launch(ssr_mode=False)