import whisper
from transformers import MarianMTModel, MarianTokenizer, AutoTokenizer, AutoModelForSeq2SeqLM
import os
import tempfile
import subprocess

# Load the "base" Whisper checkpoint once, at import time; process_video()
# reuses this module-level instance for every transcription.
model = whisper.load_model("base")

# Maps each supported non-English subtitle language to the Hugging Face
# checkpoint used to translate the English transcript into that language.
_TRANSLATION_MODELS = {
    "Hindi": "Helsinki-NLP/opus-mt-en-hi",
    "Spanish": "Helsinki-NLP/opus-mt-en-es",
    "French": "Helsinki-NLP/opus-mt-en-fr",
    "German": "Helsinki-NLP/opus-mt-en-de",
    "Telugu": "facebook/nllb-200-distilled-600M",
    "Portuguese": "Helsinki-NLP/opus-mt-en-pt",
    "Russian": "Helsinki-NLP/opus-mt-en-ru",
    "Chinese": "Helsinki-NLP/opus-mt-en-zh",
    "Arabic": "Helsinki-NLP/opus-mt-en-ar",
    "Japanese": "Helsinki-NLP/opus-mt-en-jap",
}

# Target-language token the NLLB checkpoint is forced to start with for Telugu.
_NLLB_TELUGU_CODE = "tel_Telu"


def _report_progress(progress, fraction, desc):
    """Forward a progress update to *progress* if a callback was supplied."""
    if progress:
        progress(fraction, desc=desc)


def _format_srt_timestamp(seconds):
    """Return *seconds* formatted as an SRT timestamp (``HH:MM:SS,mmm``)."""
    # Work in integer milliseconds so rounding cannot yield e.g. "59,1000".
    total_ms = max(0, int(round(float(seconds) * 1000)))
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    secs, millis = divmod(remainder, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"


def _translate_segments(source_segments, language):
    """Translate the text of each segment into *language*, keeping its timing."""
    model_name = _TRANSLATION_MODELS[language]
    generate_kwargs = {}
    if language == "Telugu":
        # Telugu is served by the multilingual NLLB checkpoint, which needs
        # the target language forced as the first generated token.
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        translation_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        generate_kwargs["forced_bos_token_id"] = tokenizer.convert_tokens_to_ids(
            _NLLB_TELUGU_CODE
        )
    else:
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        translation_model = MarianMTModel.from_pretrained(model_name)

    translated_segments = []
    for segment in source_segments:
        # truncation=True keeps an unusually long segment within the
        # tokenizer's configured maximum length instead of overflowing it.
        inputs = tokenizer(
            segment["text"], return_tensors="pt", padding=True, truncation=True
        )
        output_tokens = translation_model.generate(**inputs, **generate_kwargs)
        translated_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
        translated_segments.append(
            {"text": translated_text, "start": segment["start"], "end": segment["end"]}
        )
    return translated_segments


def process_video(video_path, language, progress=None):
    """Create an SRT subtitle file for a video, optionally translated.

    The video is re-encoded to MP4 with ffmpeg, transcribed as English by the
    module-level Whisper ``model`` and, unless *language* is ``"English"``,
    every segment is translated with the checkpoint registered for *language*.

    Args:
        video_path: Path of the input video. NOTE: this file is deleted before
            the function returns, whether or not processing succeeded.
        language: ``"English"`` or one of the keys of ``_TRANSLATION_MODELS``.
        progress: Optional callable invoked as ``progress(fraction, desc=...)``
            at each pipeline stage.

    Returns:
        The path of the written SRT file, or ``None`` when *language* is not
        supported or any step failed (the error is printed, not raised).
    """
    # NOTE: the output path is fixed, so a later call overwrites the subtitles
    # produced by an earlier one.
    srt_path = os.path.join(tempfile.gettempdir(), "subtitles.srt")

    try:
        # Reject unsupported languages before the expensive ffmpeg/Whisper work.
        if language != "English" and language not in _TRANSLATION_MODELS:
            print(f"Unsupported language: {language!r}")
            return None

        # A private scratch directory keeps concurrent calls from clobbering
        # each other's intermediate file and is removed automatically.
        with tempfile.TemporaryDirectory() as work_dir:
            output_video_path = os.path.join(work_dir, "converted_video.mp4")

            _report_progress(progress, 0.2, "🔄 Converting video to MP4...")
            subprocess.run(
                [
                    "ffmpeg", "-y", "-i", video_path,
                    "-c:v", "libx264", "-preset", "fast", output_video_path,
                ],
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )

            _report_progress(progress, 0.4, "📝 Transcribing audio...")
            result = model.transcribe(output_video_path, language="en")

        _report_progress(progress, 0.6, "🌐 Translating subtitles...")
        if language == "English":
            segments = result["segments"]
        else:
            segments = _translate_segments(result["segments"], language)

        _report_progress(progress, 0.8, "📝 Generating SRT file...")
        with open(srt_path, "w", encoding="utf-8") as srt_file:
            for index, segment in enumerate(segments, 1):
                start = _format_srt_timestamp(segment["start"])
                end = _format_srt_timestamp(segment["end"])
                text = segment["text"].strip()
                srt_file.write(f"{index}\n{start} --> {end}\n{text}\n\n")

        _report_progress(progress, 1.0, "✅ Done!")
        return srt_path

    except subprocess.CalledProcessError as e:
        # ffmpeg output is not guaranteed to be valid UTF-8.
        print(f"FFmpeg Error: {e.stderr.decode(errors='replace')}")
        return None
    except Exception as e:
        print(f"Unexpected Error: {str(e)}")
        return None
    finally:
        # The input video is treated as a temporary upload and always removed.
        # Guard against a missing/None path so cleanup cannot raise on its own.
        if video_path and os.path.isfile(video_path):
            os.remove(video_path)