import os
import json
import gradio as gr
from gliner import GLiNER

# Load the common examples from the JSON file that sits next to this script.
# Resolving the path from __file__ (instead of the current working directory)
# keeps the app working when it is launched from a different directory.
_EXAMPLES_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "examples.json")
with open(_EXAMPLES_PATH, "r", encoding="utf-8") as f:
    # Each example is unpacked elsewhere as (text, labels, threshold, nested).
    common_examples = json.load(f)

# Utility function to merge adjacent entities (used in NuNER Zero)
def merge_entities(entities, text=None):
    """Merge runs of consecutive same-label entities into single spans.

    Two neighbouring entities are merged when they carry the same ``"entity"``
    label and the second one starts either exactly where the first ends
    (contiguous) or one character later (e.g. separated by a single space).

    Args:
        entities: list of dicts with ``"entity"``, ``"word"``, ``"start"`` and
            ``"end"`` keys. Assumed to be ordered by ``"start"`` -- TODO confirm
            the model always returns them sorted. The input dicts are never
            modified.
        text: optional source string that the ``start``/``end`` offsets index
            into. When given, the merged ``"word"`` is sliced from it, so the
            real separator (space, hyphen, ...) is preserved. Without it, pieces
            are joined with a single space across a one-character gap and with
            nothing when they are contiguous.

    Returns:
        A new list of (shallow-copied) entity dicts with adjacent runs merged.
    """
    if not entities:
        return []
    merged = []
    # Work on copies so the caller's entity dicts are not mutated in place.
    current = dict(entities[0])
    for next_entity in entities[1:]:
        gap = next_entity["start"] - current["end"]
        if next_entity["entity"] == current["entity"] and gap in (0, 1):
            if text is not None:
                current["word"] = text[current["start"]:next_entity["end"]]
            else:
                # Contiguous pieces had no separator in the source text.
                separator = "" if gap == 0 else " "
                current["word"] += separator + next_entity["word"]
            current["end"] = next_entity["end"]
        else:
            merged.append(current)
            current = dict(next_entity)
    merged.append(current)
    return merged

# Load the three GLiNER checkpoints once, at import time, in this order:
# NuNER Zero (token-based), multilingual PII, and GLiNER medium v2.1.
model_nuner, model_pii, model_med = (
    GLiNER.from_pretrained(checkpoint)
    for checkpoint in (
        "numind/NuZero_token",
        "urchade/gliner_multi_pii-v1",
        "urchade/gliner_medium-v2.1",
    )
)

# Define NER functions for each model
def _parse_labels(labels):
    """Split a comma-separated label string into stripped, non-empty labels.

    Empty pieces (from a trailing comma or ", ,") are dropped so they are never
    passed to the model as an empty entity type.
    """
    return [piece.strip() for piece in labels.split(",") if piece.strip()]


def _predict(model, text, labels, threshold, nested_ner):
    """Run ``model.predict_entities`` and reshape the result for ``gr.HighlightedText``.

    Args:
        model: a loaded ``GLiNER`` model.
        text: input text to tag.
        labels: comma-separated entity labels, as typed in the UI.
        threshold: forwarded unchanged as ``threshold=`` to ``predict_entities``.
        nested_ner: the UI's "Nested NER" toggle; forwarded as
            ``flat_ner=not nested_ner``.

    Returns:
        A list of ``{"entity", "word", "start", "end", "score"}`` dicts built
        from the model's ``label``/``text``/``start``/``end`` fields. Empty,
        without calling the model, when ``text`` is empty or no non-empty label
        was given.
    """
    label_list = _parse_labels(labels)
    if not text or not label_list:
        return []
    pred_entities = model.predict_entities(text, label_list, flat_ner=not nested_ner, threshold=threshold)
    # The model's own confidence is not forwarded; every entity gets score 0.
    return [
        {"entity": entity["label"], "word": entity["text"], "start": entity["start"], "end": entity["end"], "score": 0}
        for entity in pred_entities
    ]


def ner_nuner(text, labels, threshold, nested_ner):
    """Tag ``text`` with NuNER Zero and merge adjacent same-label spans.

    Returns ``{"text": text, "entities": [...]}`` for ``gr.HighlightedText``.
    """
    entities = _predict(model_nuner, text, labels, threshold, nested_ner)
    # NuNER Zero is described in the UI as token-based; merge_entities joins
    # neighbouring spans that share a label.
    return {"text": text, "entities": merge_entities(entities)}


def ner_pii(text, labels, threshold, nested_ner):
    """Tag ``text`` with the multilingual PII model.

    Returns ``{"text": text, "entities": [...]}`` for ``gr.HighlightedText``.
    """
    return {"text": text, "entities": _predict(model_pii, text, labels, threshold, nested_ner)}


def ner_med(text, labels, threshold, nested_ner):
    """Tag ``text`` with GLiNER medium v2.1.

    Returns ``{"text": text, "entities": [...]}`` for ``gr.HighlightedText``.
    """
    return {"text": text, "entities": _predict(model_med, text, labels, threshold, nested_ner)}

# Seed every tab's inputs from the first bundled example. Each example must be
# a 4-item sequence: (text, labels, threshold, nested).
(
    default_text,
    default_labels,
    default_threshold,
    default_nested,
) = common_examples[0]

# Build the combined Gradio app with three tabs
def _build_ner_tab(tab_title, heading, model_id, task, predict_fn):
    """Create one model tab and wire its events; call inside ``gr.Tabs()``.

    Args:
        tab_title: label shown on the tab itself.
        heading: text of the level-2 Markdown heading at the top of the tab.
        model_id: Hugging Face model id shown in the "run locally" instructions.
        task: phrase completing "call `predict_entities` to ..." in those
            instructions.
        predict_fn: callback ``(text, labels, threshold, nested_ner) -> dict``
            used by the examples and by every event trigger.

    Returns:
        Tuple ``(input_text, labels, threshold, nested_ner, output, submit_btn)``
        with the components created for this tab.
    """
    with gr.Tab(tab_title):
        gr.Markdown(f"## {heading}")
        with gr.Accordion("How to run this model locally", open=False):
            # gr.Markdown dedents its value, so the literal's indentation here
            # does not affect the rendered text.
            gr.Markdown(
                f"""
                **Installation:**
                ```
                !pip install gliner
                ```
                **Usage:**
                Load the model with `GLiNER.from_pretrained("{model_id}")`
                and call `predict_entities` to {task}.
                """
            )
            gr.Code(
                f'from gliner import GLiNER\nmodel = GLiNER.from_pretrained("{model_id}")',
                language="python",
            )
        input_text = gr.Textbox(value=default_text, label="Text input", placeholder="Enter your text here")
        with gr.Row():
            labels = gr.Textbox(value=default_labels, label="Labels", placeholder="Enter labels (comma separated)", scale=2)
            threshold = gr.Slider(0, 1, value=default_threshold, step=0.01, label="Threshold", info="Lower threshold to increase predictions", scale=1)
            nested_ner = gr.Checkbox(value=default_nested, label="Nested NER", info="Allow for nested NER?", scale=0)
        output = gr.HighlightedText(label="Predicted Entities")
        submit_btn = gr.Button("Submit")
        ner_inputs = [input_text, labels, threshold, nested_ner]
        gr.Examples(
            common_examples,
            fn=predict_fn,
            inputs=ner_inputs,
            outputs=output,
            cache_examples=False,
        )
        # Re-run prediction when text or labels are submitted, the threshold
        # slider is released, Submit is clicked, or nested NER is toggled.
        for trigger in (input_text.submit, labels.submit, threshold.release, submit_btn.click, nested_ner.change):
            trigger(predict_fn, inputs=ner_inputs, outputs=output)
    return input_text, labels, threshold, nested_ner, output, submit_btn


with gr.Blocks(title="GLiNER NER Testbed") as demo:
    gr.Markdown("# GLiNER NER Testbed")
    with gr.Accordion("This interface allows you to compare different zero-shot Named Entity Recognition models...", open=True):
        gr.Markdown(
            """        
            ## Models Available:
            - **GLiNER Medium v2.1**: The original GLiNER medium model
            - **GLiNER Multi PII**: Fine-tuned for detecting personally identifiable information across multiple languages
            - **NuNER Zero**: A specialized token-based NER model
            
            ## Features:
            - Select different models
            - Select examples based on different use cases
            - Toggle nested entity recognition
            - Entity merging is currently enabled for NuNER Zero only
            
            ## About GLiNER:

            **GLiNER** is a state-of-the-art Named Entity Recognition (NER) system that leverages a BERT-like bidirectional transformer encoder to identify a wide range of entity types in text. Unlike conventional NER models that are restricted to fixed entity categories, GLiNER supports flexible, zero-shot extraction, making it ideal for diverse real-world applications. It also provides a resource-efficient alternative to large language models (LLMs) for scenarios where cost and speed are critical. Distributed under the Apache 2.0 license, GLiNER is commercially friendly and readily deployable.

            **Useful Links**

            - **Model:** [gliner_medium-v2.1](https://huggingface.co/urchade/gliner_medium-v2.1)
            - **All GLiNER Models:** [Hugging Face GLiNER Models](https://huggingface.co/models?library=gliner)
            - **Research Paper:** [arXiv:2311.08526](https://arxiv.org/abs/2311.08526)
            - **Repository:** [GitHub - GLiNER](https://github.com/urchade/GLiNER)
            """
        )

    with gr.Tabs():
        # Tab for GLiNER-medium
        (
            input_text_med,
            labels_med,
            threshold_med,
            nested_ner_med,
            output_med,
            submit_btn_med,
        ) = _build_ner_tab("GLiNER-medium", "GLiNER-medium-v2.1", "urchade/gliner_medium-v2.1", "perform zero-shot NER", ner_med)

        # Tab for GLiNER-PII
        (
            input_text_pii,
            labels_pii,
            threshold_pii,
            nested_ner_pii,
            output_pii,
            submit_btn_pii,
        ) = _build_ner_tab("GLiNER-PII", "GLiNER-PII", "urchade/gliner_multi_pii-v1", "extract PII", ner_pii)

        # Tab for NuNER Zero
        (
            input_text_nuner,
            labels_nuner,
            threshold_nuner,
            nested_ner_nuner,
            output_nuner,
            submit_btn_nuner,
        ) = _build_ner_tab("NuNER Zero", "NuNER Zero", "numind/NuZero_token", "perform zero-shot NER", ner_nuner)


if __name__ == "__main__":
    # Enable queuing and launch the app only after the Blocks context has
    # closed, and only when run as a script (not when imported).
    demo.queue()
    demo.launch(debug=True)