import gradio as gr
from huggingface_hub import HfFolder
from transformers import MarianMTModel, MarianTokenizer
from indic_transliteration import sanscript
from indic_transliteration.sanscript import transliterate
import torch

# Global variables to store models and tokenizers
models = {}
tokenizers = {}
# Hugging Face token from the local cache (set via `huggingface-cli login`); None if not logged in
token = HfFolder.get_token()

# Model configurations
MODEL_CONFIGS = {
    "en-hi": {
        "model_path": "rooftopcoder/opus-mt-en-hi-samanantar-finetuned",
        "name": "English to Hindi"
    },
    "hi-en": {
        "model_path": "rooftopcoder/opus-mt-hi-en-samanantar-finetuned",
        "name": "Hindi to English"
    },
    "en-mr": {
        "model_path": "rooftopcoder/opus-mt-en-mr-samanantar-finetuned",
        "name": "English to Marathi"
    },
    "mr-en": {
        "model_path": "rooftopcoder/opus-mt-mr-en-samanantar-finetuned",
        "name": "Marathi to English"
    }
}

# Supported languages and their ISO codes
language_codes = {
    "English": "en",
    "Hindi": "hi",
    "Marathi": "mr"
}

# Reverse dictionary for display purposes
language_names = {v: k for k, v in language_codes.items()}

def load_models():
    try:
        print("Loading models from local storage...")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")

        for direction, config in MODEL_CONFIGS.items():
            print(f"Loading {config['name']} model...")
            tokenizers[direction] = MarianTokenizer.from_pretrained(config["model_path"], token=token)
            models[direction] = MarianMTModel.from_pretrained(config["model_path"], token=token).to(device)

        print("All models loaded successfully!")
        return True
    except Exception as e:
        print(f"Error loading models: {e}")
        return False

# Transliterate romanized (ITRANS) text into an Indic script
def transliterate_text(text, from_scheme=sanscript.ITRANS, to_scheme=sanscript.DEVANAGARI):
    """
    Transliterates text from one script to another.
    Default is ITRANS (romanized) to Devanagari, the script used by both Hindi and Marathi.
    """
    try:
        return transliterate(text, from_scheme, to_scheme)
    except Exception as e:
        print(f"Transliteration error: {e}")
        return text
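
# Illustrative usage (ITRANS input; exact output depends on indic_transliteration's scheme tables):
#   transliterate_text("namaste")  # -> "नमस्ते"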

# Function to perform translation with MarianMT
def translate(input_text, source_lang, target_lang):
    """
    Translates text using MarianMT models
    """
    direction = f"{source_lang}-{target_lang}"
    if direction not in models or direction not in tokenizers:
        return "Error: Unsupported language pair"

    if not input_text.strip():
        return "Error: Please enter some text to translate."

    try:
        # Tokenize and move the inputs to the same device as the model's weights
        device = next(models[direction].parameters()).device
        tokens = tokenizers[direction](input_text, return_tensors="pt", padding=True, truncation=True)
        tokens = {k: v.to(device) for k, v in tokens.items()}

        # Generate the translation and decode it back to text
        translated = models[direction].generate(**tokens)
        output = tokenizers[direction].batch_decode(translated.cpu(), skip_special_tokens=True)
        return output[0]
    except Exception as e:
        print(f"Translation error: {e}")
        return f"Error during translation: {str(e)}"

# Helper function for handling the UI translation process
def perform_translation(input_text, source_lang, target_lang):
    """Wrapper function for the Gradio interface"""
    source_code = language_codes[source_lang]
    target_code = language_codes[target_lang]

    # Heuristic: if the "English" input actually contains common romanized Hindi/Marathi
    # words, show its Devanagari transliteration alongside the regular translation.
    if source_code == "en" and target_code in ["hi", "mr"]:
        common_indic_words = {
            "hi": ["namaste", "dhanyavad", "kaise", "hai", "aap", "tum", "main"],
            "mr": ["namaskar", "dhanyawad", "kase", "ahe", "tumhi", "mi"]
        }

        words = input_text.lower().split()
        if any(word in common_indic_words.get(target_code, []) for word in words):
            transliterated = transliterate_text(input_text)
            if transliterated != input_text:
                translation = translate(input_text, source_code, target_code)
                return f"Transliterated: {transliterated}\n\nTranslated: {translation}"

    return translate(input_text, source_code, target_code)
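
# Illustrative usage with the UI-level language names (assumes models are loaded):
#   perform_translation("Hello, how are you?", "English", "Hindi")
#   perform_translation("namaste, aap kaise hai?", "English", "Hindi")
#       # -> may return the Devanagari transliteration together with the model translation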

# Create Gradio interface
def create_interface():
    with gr.Blocks(title="Neural Machine Translation - Indian Languages") as demo:
        gr.Markdown("# Neural Machine Translation for Indian Languages")
        gr.Markdown("Translate between English, Hindi, and Marathi using MarianMT models")

        with gr.Row():
            with gr.Column():
                source_lang = gr.Dropdown(
                    choices=list(language_codes.keys()),
                    label="Source Language",
                    value="English"
                )
                input_text = gr.Textbox(
                    lines=5,
                    placeholder="Enter text to translate...",
                    label="Input Text"
                )

            with gr.Column():
                target_lang = gr.Dropdown(
                    choices=list(language_codes.keys()),
                    label="Target Language",
                    value="Hindi"
                )
                output_text = gr.Textbox(
                    lines=5,
                    label="Translated Text",
                    placeholder="Translation will appear here..."
                )

        translate_btn = gr.Button("Translate", variant="primary")
        transliterate_btn = gr.Button("Transliterate Only", variant="secondary")

        # Event handlers
        translate_btn.click(
            fn=perform_translation,
            inputs=[input_text, source_lang, target_lang],
            outputs=[output_text],
            api_name="translate"
        )

        # Direct transliteration handler
        def direct_transliterate(text):
            if not text.strip():
                return "Please enter text to transliterate"
            return transliterate_text(text)

        transliterate_btn.click(
            fn=direct_transliterate,
            inputs=[input_text],
            outputs=[output_text],
            api_name="transliterate"
        )

        # Examples for all language pairs
        gr.Examples(
            examples=[
                ["Hello, how are you?", "English", "Hindi"],
                ["नमस्ते, आप कैसे हैं?", "Hindi", "English"],
                ["Hello, how are you?", "English", "Marathi"],
                ["नमस्कार, तुम्ही कसे आहात?", "Marathi", "English"],
            ],
            inputs=[input_text, source_lang, target_lang],
            fn=perform_translation,
            outputs=output_text,
            cache_examples=True
        )

        gr.Markdown("""
        ## Model Information

        This demo uses fine-tuned MarianMT models for translation between:
        - English ↔️ Hindi
        - English ↔️ Marathi

        ### Features:
        - Bidirectional translation support
        - Transliteration support for romanized Indic text
        - Optimized models for each language pair
        """)

    return demo

# Launch the interface
if __name__ == "__main__":
    # Load all models before launching the interface
    if load_models():
        demo = create_interface()
        demo.launch(share=False)
    else:
        print("Failed to load models. Please check the model paths and try again.")