import gradio as gr
import torch
import numpy as np
import librosa
import soundfile as sf
import tempfile
import os

from transformers import pipeline, VitsModel, AutoTokenizer
from datasets import load_dataset

# For Coqui TTS (XTTS-v2) used for Chinese and Japanese
try:
    from TTS.api import TTS as CoquiTTS
except ImportError:
    raise ImportError("Please install Coqui TTS via pip install TTS.")

# ------------------------------------------------------
# 1. ASR Pipeline (English) using Wav2Vec2
# ------------------------------------------------------
# English speech recognizer; predict() resamples input audio to 16 kHz before calling it.
asr = pipeline(
    task="automatic-speech-recognition",
    model="facebook/wav2vec2-base-960h",
)

# ------------------------------------------------------
# 2. Translation Models (9 languages)
# ------------------------------------------------------
# OPUS-MT (MarianMT) checkpoint that translates English into each target language.
# NOTE(review): opus-mt-en-trk, opus-mt-tc-big-en-pt and opus-mt-en-zh are
# multi-target models; their model cards ask for a sentence-initial ">>id<<"
# token to select the output language — confirm callers prepend one.
# NOTE(review): "jap" is not the ISO 639 code for Japanese ("ja" / "jpn");
# confirm opus-mt-en-jap actually produces Japanese.
translation_models = {
    "French": "Helsinki-NLP/opus-mt-en-fr",
    "Spanish": "Helsinki-NLP/opus-mt-en-es",
    "Vietnamese": "Helsinki-NLP/opus-mt-en-vi",
    "Indonesian": "Helsinki-NLP/opus-mt-en-id",
    "Turkish": "Helsinki-NLP/opus-mt-en-trk",
    "Portuguese": "Helsinki-NLP/opus-mt-tc-big-en-pt",
    "Korean": "Helsinki-NLP/opus-mt-tc-big-en-ko",
    "Chinese": "Helsinki-NLP/opus-mt-en-zh",
    "Japanese": "Helsinki-NLP/opus-mt-en-jap"
}

# transformers pipeline task name per language. transformers splits
# "translation_XX_to_YY" on "_" and rejects any other shape, so every entry
# must use underscores only (Korean previously read "translation_en_to-ko").
translation_tasks = {
    "French": "translation_en_to_fr",
    "Spanish": "translation_en_to_es",
    "Vietnamese": "translation_en_to_vi",
    "Indonesian": "translation_en_to_id",
    "Turkish": "translation_en_to_tr",
    "Portuguese": "translation_en_to_pt",
    "Korean": "translation_en_to_ko",
    "Chinese": "translation_en_to_zh",
    "Japanese": "translation_en_to_ja"
}

# ------------------------------------------------------
# 3. TTS Configuration
#    - MMS TTS (VITS) for: French, Spanish, Vietnamese, Indonesian, Turkish, Portuguese, Korean
#    - Coqui XTTS-v2 for: Chinese and Japanese
# ------------------------------------------------------
# MMS (VITS) voices: each model id is "facebook/mms-tts-<three-letter language code>".
# NOTE(review): some MMS checkpoints for non-Latin scripts expect uroman-romanized
# input (Korean may be one of them) — confirm Korean output is intelligible.
tts_config = {
    lang: {"model_id": f"facebook/mms-tts-{code}", "architecture": "vits", "type": "mms"}
    for lang, code in (
        ("French", "fra"),
        ("Spanish", "spa"),
        ("Vietnamese", "vie"),
        ("Indonesian", "ind"),
        ("Turkish", "tur"),
        ("Portuguese", "por"),
        ("Korean", "kor"),
    )
}
# Chinese and Japanese are voiced by Coqui XTTS-v2 rather than MMS.
tts_config["Chinese"] = {"type": "coqui"}
tts_config["Japanese"] = {"type": "coqui"}

# For Coqui, map languages to expected language codes.
# XTTS-v2 language code per target language. XTTS-v2's supported-language list
# spells Chinese as "zh-cn", so use that rather than the bare "zh".
coqui_lang_map = {
    "Chinese": "zh-cn",
    "Japanese": "ja"
}

# ------------------------------------------------------
# 4. Global Caches for Translators and TTS Models
# ------------------------------------------------------
# Filled lazily by the loader helpers so each model is built at most once per process.
translator_cache = {}    # target language -> transformers translation pipeline
mms_tts_cache = {}       # target language -> (VitsModel, tokenizer) pair
coqui_tts_cache = None   # single shared Coqui TTS instance, created on first use

# ------------------------------------------------------
# 5. Translator Helper
# ------------------------------------------------------
def get_translator(lang):
    """Return the translation pipeline for *lang*, creating and caching it on first use.

    Raises:
        KeyError: if *lang* is missing from the translation tables.
    """
    try:
        return translator_cache[lang]
    except KeyError:
        pass
    model_name, task_name = translation_models[lang], translation_tasks[lang]
    translator_cache[lang] = pipeline(task_name, model=model_name)
    return translator_cache[lang]

# ------------------------------------------------------
# 6. MMS TTS (VITS) Helper for languages using MMS TTS
# ------------------------------------------------------
def load_mms_tts(lang):
    """Return the cached ``(VitsModel, tokenizer)`` pair for *lang*, loading it on first use.

    Args:
        lang: key into ``tts_config`` whose entry carries a ``model_id``.

    Raises:
        KeyError: if *lang* (or its ``model_id``) is not configured.
        RuntimeError: if the model or tokenizer cannot be loaded; the underlying
            exception is chained as ``__cause__`` so the traceback is preserved.
    """
    if lang in mms_tts_cache:
        return mms_tts_cache[lang]
    model_id = tts_config[lang]["model_id"]
    try:
        model = VitsModel.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
    except Exception as e:
        raise RuntimeError(f"Failed to load MMS TTS model for {lang} ({model_id}): {e}") from e
    # Cache only after both parts loaded, so a failure leaves no partial entry.
    mms_tts_cache[lang] = (model, tokenizer)
    return mms_tts_cache[lang]

def run_mms_tts(text, lang):
    """Synthesize *text* with the MMS (VITS) model configured for *lang*.

    Returns:
        ``(sample_rate, waveform)``, where ``waveform`` is the model's
        ``waveform`` output squeezed and moved to a CPU numpy array, and
        ``sample_rate`` comes from the checkpoint config (16000 if absent).

    Raises:
        RuntimeError: if *text* is blank, tokenizes to nothing, or the model
            output has no ``waveform`` attribute.
    """
    if not text or not text.strip():
        raise RuntimeError(f"No text to synthesize for {lang}.")
    model, tokenizer = load_mms_tts(lang)
    inputs = tokenizer(text, return_tensors="pt")
    # The tokenizer drops characters outside its vocabulary; guard against an
    # empty id sequence, which the model cannot synthesize from.
    if inputs["input_ids"].numel() == 0:
        raise RuntimeError(f"MMS TTS tokenizer for {lang} produced no tokens for the input text.")
    with torch.no_grad():
        output = model(**inputs)
    if not hasattr(output, "waveform"):
        raise RuntimeError(f"MMS TTS model output for {lang} does not contain 'waveform'.")
    waveform = output.waveform.squeeze().cpu().numpy()
    # Use the checkpoint's own rate instead of assuming 16 kHz.
    sample_rate = getattr(model.config, "sampling_rate", 16000)
    return sample_rate, waveform

# ------------------------------------------------------
# 7. Coqui TTS Helper for Chinese and Japanese
# ------------------------------------------------------
def load_coqui_tts():
    """Return the shared Coqui XTTS-v2 instance, creating it (on CPU) on first call.

    Raises:
        RuntimeError: if the model cannot be created; the underlying exception
            is chained as ``__cause__``. The cache stays ``None`` on failure,
            so a later call retries.
    """
    global coqui_tts_cache
    if coqui_tts_cache is not None:
        return coqui_tts_cache
    try:
        coqui_tts_cache = CoquiTTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=False)
    except Exception as e:
        raise RuntimeError(f"Failed to load Coqui XTTS-v2 TTS: {e}") from e
    return coqui_tts_cache

def run_coqui_tts(text, lang, speaker=None, speaker_wav=None):
    """Synthesize *text* in *lang* with Coqui XTTS-v2 and return ``(sample_rate, data)``.

    XTTS-v2 is a multi-speaker model. Coqui's API raises when neither a named
    ``speaker`` nor a reference ``speaker_wav`` is given, so when the caller
    supplies neither, the first built-in speaker the model reports is used.

    Args:
        text: text to speak.
        lang: key into ``coqui_lang_map``.
        speaker: optional built-in XTTS speaker name.
        speaker_wav: optional path to a reference recording for voice cloning.

    Returns:
        ``(sample_rate, data)`` as read back from the temporary WAV by soundfile.

    Raises:
        RuntimeError: if no speaker was given and the model exposes no built-in
            speakers.
    """
    coqui_tts = load_coqui_tts()
    lang_code = coqui_lang_map[lang]
    if speaker is None and speaker_wav is None:
        builtin_speakers = list(getattr(coqui_tts, "speakers", None) or [])
        if not builtin_speakers:
            raise RuntimeError(
                "Coqui XTTS-v2 needs a voice: no built-in speakers are available, "
                "so pass speaker_wav (a reference recording)."
            )
        speaker = builtin_speakers[0]
    # Close the handle first so Coqui can reopen the path (required on Windows).
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tmp_name = tmp.name
    try:
        coqui_tts.tts_to_file(
            text=text,
            file_path=tmp_name,
            language=lang_code,
            speaker=speaker,
            speaker_wav=speaker_wav,
        )
        data, sr = sf.read(tmp_name)
    finally:
        if os.path.exists(tmp_name):
            os.remove(tmp_name)
    return sr, data

# ------------------------------------------------------
# 8. Main Prediction Function
# ------------------------------------------------------
# Sentence-initial target tokens for multi-target OPUS-MT checkpoints, whose
# model cards require a ">>id<<" prefix to pick the output language.
# Single-target checkpoints need no token and are absent here.
_TRANSLATION_TARGET_TOKENS = {
    "Turkish": ">>tur<<",       # opus-mt-en-trk serves many Turkic languages
    "Portuguese": ">>por<<",    # opus-mt-tc-big-en-pt
    "Chinese": ">>cmn_Hans<<",  # opus-mt-en-zh: Mandarin, Simplified script
}


def _prepare_asr_audio(sample_rate, audio_data):
    """Return *audio_data* as a 1-D floating-point array resampled to 16 kHz.

    Integer PCM is scaled by its dtype's maximum so samples fall in roughly
    [-1, 1]. Any array with more than one dimension is averaged over axis 1,
    which assumes a ``(samples, channels)`` layout. That layout is what
    ``gr.Audio(type="numpy")`` is expected to deliver — TODO confirm.
    """
    audio_data = np.asarray(audio_data)
    if np.issubdtype(audio_data.dtype, np.integer):
        full_scale = np.iinfo(audio_data.dtype).max
        audio_data = audio_data.astype(np.float32) / full_scale
    elif audio_data.dtype not in (np.float32, np.float64):
        audio_data = audio_data.astype(np.float32)
    if audio_data.ndim > 1:
        # Also collapses (n, 1) mono arrays, which the ASR pipeline cannot take.
        audio_data = audio_data.mean(axis=1)
    if sample_rate != 16000:
        audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000)
    return audio_data


def predict(audio, text, target_language):
    """Run one Gradio request through ASR → translation → TTS.

    1. Obtain English text: non-blank typed *text* wins. Otherwise *audio*
       (a ``(sample_rate, samples)`` tuple) is transcribed with Wav2Vec2.
    2. Translate the English text into *target_language*.
    3. Generate speech using either MMS TTS (VITS) or Coqui XTTS-v2.

    Returns:
        ``(english_text, translated_text, audio_out)``, matching the three
        Gradio outputs. ``audio_out`` is ``(sample_rate, waveform)`` or
        ``None``. A failing stage reports its error in a text output and
        returns ``None`` for audio, because the Audio component would treat a
        string as a file path.
    """
    # Step 1: Get English text. Gradio may pass None for an empty textbox.
    typed_text = (text or "").strip()
    if typed_text:
        english_text = typed_text
    elif audio is not None:
        sample_rate, audio_data = audio
        if np.asarray(audio_data).size == 0:
            return "No input provided.", "", None
        try:
            samples = _prepare_asr_audio(sample_rate, audio_data)
            asr_result = asr({"array": samples, "sampling_rate": 16000})
        except Exception as e:
            return f"ASR error: {e}", "", None
        english_text = asr_result["text"].lower()
    else:
        return "No input provided.", "", None

    # Step 2: Translate (model loading is inside the try so failures are reported).
    try:
        translator = get_translator(target_language)
        target_token = _TRANSLATION_TARGET_TOKENS.get(target_language)
        source_text = f"{target_token} {english_text}" if target_token else english_text
        translated_text = translator(source_text)[0]["translation_text"]
    except Exception as e:
        return english_text, f"Translation error: {e}", None

    # Step 3: TTS.
    try:
        tts_type = tts_config[target_language]["type"]
        if tts_type == "mms":
            sr, waveform = run_mms_tts(translated_text, target_language)
        elif tts_type == "coqui":
            sr, waveform = run_coqui_tts(translated_text, target_language)
        else:
            raise RuntimeError("Unknown TTS type for target language.")
    except Exception as e:
        # Keep the translation visible and surface the TTS failure next to it.
        return english_text, f"{translated_text}\n\nTTS error: {e}", None

    return english_text, translated_text, (sr, waveform)

# ------------------------------------------------------
# 9. Gradio Interface
# ------------------------------------------------------
# Dropdown entries; each must be a key of the translation and TTS tables.
language_choices = [
    "French", "Spanish", "Vietnamese", "Indonesian", "Turkish", "Portuguese", "Korean", "Chinese", "Japanese"
]

# Three inputs map to predict(audio, text, target_language); three outputs map
# to its (english_text, translated_text, audio) return tuple.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Audio(type="numpy", label="Record/Upload English Audio (optional)"),
        gr.Textbox(lines=4, placeholder="Or enter English text here", label="English Text Input (optional)"),
        gr.Dropdown(choices=language_choices, value="French", label="Target Language")
    ],
    outputs=[
        gr.Textbox(label="English Transcription"),
        gr.Textbox(label="Translation (Target Language)"),
        gr.Audio(label="Synthesized Speech")
    ],
    title="Multimodal Language Learning Aid",
    description=(
        "This app performs the following tasks:\n"
        "1. Transcribes English speech using Wav2Vec2 (accepts text input as well; typed text takes precedence over audio).\n"
        "2. Translates the English text to the target language using Helsinki-NLP models.\n"
        "3. Provides speech:\n"
        "   - For French, Spanish, Vietnamese, Indonesian, Turkish, Portuguese, and Korean: uses Facebook MMS TTS (VITS-based).\n"
        "   - For Chinese and Japanese: uses Coqui XTTS-v2.\n"
        "\nSelect your target language from the dropdown."
    ),
    allow_flagging="never"
)

if __name__ == "__main__":
    # 0.0.0.0 binds every network interface, so other machines can reach port 7860.
    iface.launch(server_port=7860, server_name="0.0.0.0")