File size: 4,206 Bytes
e8d1ade
 
fc0d433
e8d1ade
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc5008f
e8d1ade
 
bc5008f
882b133
 
 
 
bc5008f
 
 
882b133
bc5008f
882b133
fc0d433
e8d1ade
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc0d433
882b133
9bee04d
 
fc0d433
 
882b133
fc0d433
 
 
 
bc5008f
fc0d433
 
 
9bee04d
fc0d433
 
e8d1ade
 
 
fc0d433
 
e8d1ade
fc0d433
e8d1ade
 
fc0d433
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8d1ade
 
 
 
 
bc5008f
e8d1ade
 
 
 
fc0d433
e8d1ade
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
import re
import tempfile
import traceback
import zipfile

import gradio as gr
from transformers import pipeline

# Lazily filled cache of loaded translation pipelines, keyed by (src, tgt).
translator_cache = {}

# Language pairs that have a direct Helsinki-NLP opus-mt checkpoint. The model
# id for each pair follows the "Helsinki-NLP/opus-mt-<src>-<tgt>" pattern.
# Pairs without a direct model (e.g. ja<->zh) are not listed here.
_OPUS_MT_PAIRS = (
    ("en", "zh"),
    ("zh", "en"),
    ("en", "ja"),
    ("ja", "en"),
)

MODEL_MAP = {
    (src, tgt): f"Helsinki-NLP/opus-mt-{src}-{tgt}" for src, tgt in _OPUS_MT_PAIRS
}

def get_translator(src_lang, tgt_lang):
    """Return the translation pipeline for ``src_lang -> tgt_lang``.

    The pipeline is built on first request from the model id in ``MODEL_MAP``
    and memoized in ``translator_cache``, so later calls reuse it.

    Raises:
        ValueError: if ``MODEL_MAP`` has no model for the language pair.
    """
    pair = (src_lang, tgt_lang)
    if pair in translator_cache:
        return translator_cache[pair]
    if pair not in MODEL_MAP:
        raise ValueError(f"No model for {src_lang} to {tgt_lang}")
    translator_cache[pair] = pipeline("translation", model=MODEL_MAP[pair])
    return translator_cache[pair]

def safe_translate(text, src, tgt):
    """Translate ``text`` from ``src`` to ``tgt`` without raising.

    Routing:
      * a direct model in ``MODEL_MAP`` is used when one exists;
      * otherwise, for two distinct languages where both ``src -> en`` and
        ``en -> tgt`` models exist (ja<->zh with the current map), the text is
        pivoted through English;
      * any other pair yields ``"[Unsupported: src->tgt]"``.

    Blank (empty or whitespace-only) text is returned unchanged rather than
    being sent to the model.

    Returns:
        The translated string, the "[Unsupported: ...]" marker, or
        ``"[Translation error: ...]"`` if loading or running a model fails on
        any leg. The error is reported once. A failed first pivot leg is
        never fed into the second leg as if it were English source text.
    """
    if not text or text.isspace():
        return text
    try:
        if (src, tgt) in MODEL_MAP:
            return _translate_direct(text, src, tgt)
        if src != tgt and (src, "en") in MODEL_MAP and ("en", tgt) in MODEL_MAP:
            # No direct model: pivot through English. _translate_direct raises
            # on failure, so a broken first leg lands in the except below.
            english = _translate_direct(text, src, "en")
            return _translate_direct(english, "en", tgt)
        return f"[Unsupported: {src}->{tgt}]"
    except Exception as e:
        return f"[Translation error: {str(e)}]"


def _translate_direct(text, src, tgt):
    """Translate with the direct ``(src, tgt)`` model; let any failure propagate."""
    translator = get_translator(src, tgt)
    return translator(text, max_length=512)[0]["translation_text"]

def parse_srt(srt_text):
    """Parse SRT text into a list of ``(index, timestamp, text)`` tuples.

    Normalization and splitting:
      * CRLF and lone CR line endings are normalized to LF;
      * a leading UTF-8 BOM is removed;
      * cues are separated by one or more blank or whitespace-only lines.

    Splitting on a literal ``"\\n\\n"`` would fail on CRLF files and misalign
    fields after runs of extra blank lines.

    Multi-line cue text is joined with single spaces so it can be translated
    as one unit. Blocks with fewer than three lines (no cue text) are skipped.

    Args:
        srt_text: Raw contents of an .srt file.

    Returns:
        A list of ``(index, timestamp, text)`` string tuples in file order.
    """
    normalized = srt_text.replace("\r\n", "\n").replace("\r", "\n").lstrip("\ufeff")
    subtitles = []
    for block in re.split(r"\n\s*\n", normalized.strip()):
        lines = block.splitlines()
        if len(lines) >= 3:
            idx = lines[0].strip()
            timestamp = lines[1].strip()
            text = " ".join(lines[2:])
            subtitles.append((idx, timestamp, text))
    return subtitles

def reassemble_srt(subtitles):
    """Serialize ``(index, timestamp, text)`` triples back into SRT text.

    Each cue is written as three newline-separated parts. Cues are joined by
    a single blank line, and no trailing newline is added after the last cue.
    """
    cues = []
    for index, timestamp, body in subtitles:
        cues.append(f"{index}\n{timestamp}\n{body}")
    return "\n\n".join(cues)

def process_file(file_obj, src_lang, tgt_lang, output_dir, error_log):
    """Translate one SRT upload into a bilingual SRT file inside ``output_dir``.

    Each cue's text becomes ``"<original>\\n<translation>"``; the index and
    timestamp are kept.

    Args:
        file_obj: The uploaded file. Either a path (``str`` / ``os.PathLike``,
            as newer Gradio versions pass) or an object exposing the path as
            ``.name`` (older tempfile-wrapper style).
        src_lang: Source language code passed to ``safe_translate``.
        tgt_lang: Target language code passed to ``safe_translate``.
        output_dir: Directory the translated file is written to. The input's
            basename is reused; a numeric suffix is added if that name is
            already taken, so same-named uploads do not overwrite each other.
        error_log: List that receives a message plus traceback if this file
            fails. Failures are recorded rather than raised, so one bad file
            does not abort the batch.
    """
    input_path = None
    try:
        input_path = _input_path(file_obj)
        # utf-8-sig drops a leading BOM; undecodable bytes are still ignored
        # as before.
        with open(input_path, "r", encoding="utf-8-sig", errors="ignore") as f:
            raw_text = f.read()
        subtitles = parse_srt(raw_text)

        translated_subs = [
            (idx, ts, f"{txt}\n{safe_translate(txt, src_lang, tgt_lang)}")
            for idx, ts, txt in subtitles
        ]

        output_path = _unique_output_path(output_dir, os.path.basename(input_path))
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(reassemble_srt(translated_subs))

    except Exception as e:
        # Do not touch file_obj attributes here. If resolving the path was
        # what failed, doing so would raise again and escape the handler.
        label = input_path if input_path is not None else repr(file_obj)
        error_log.append(f"File {label} failed: {str(e)}\n{traceback.format_exc()}")


def _input_path(file_obj):
    """Return the filesystem path for a plain path or a file-like upload object."""
    if isinstance(file_obj, (str, os.PathLike)):
        return os.fspath(file_obj)
    return file_obj.name


def _unique_output_path(output_dir, filename):
    """Return ``output_dir/filename``, adding ``_1``, ``_2``, ... before the extension if taken."""
    stem, ext = os.path.splitext(filename)
    candidate = os.path.join(output_dir, filename)
    counter = 1
    while os.path.exists(candidate):
        candidate = os.path.join(output_dir, f"{stem}_{counter}{ext}")
        counter += 1
    return candidate

def batch_translate(files, src_lang, tgt_lang):
    """Translate every uploaded SRT file and bundle the results into one ZIP.

    Each file is handled by ``process_file``, which writes its output into a
    fresh temporary directory. All ``.srt`` outputs are zipped into
    ``translated_srt.zip``. Any per-file errors are written to ``log.txt``
    and added to the archive. If there are no uploads (``None`` or an empty
    list), a ZIP with only a log entry is returned instead of raising.

    If building the ZIP itself fails, ``fail.zip`` is returned instead. Its
    ``log.txt`` holds the ZIP error, its traceback, and any per-file errors.

    Args:
        files: Uploaded files from the Gradio ``File`` component, or ``None``.
        src_lang: Source language code.
        tgt_lang: Target language code.

    Returns:
        Path to the ZIP archive to offer for download.
    """
    tmp_dir = tempfile.mkdtemp()
    error_log = []

    if not files:
        error_log.append("No SRT files were uploaded.")
    for file_obj in files or []:
        process_file(file_obj, src_lang, tgt_lang, tmp_dir, error_log)

    zip_path = os.path.join(tmp_dir, "translated_srt.zip")
    try:
        with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zipf:
            # Sorted for a deterministic archive order. The extension check is
            # case-insensitive so "X.SRT" uploads are not silently dropped.
            for name in sorted(os.listdir(tmp_dir)):
                path = os.path.join(tmp_dir, name)
                if os.path.isfile(path) and name.lower().endswith(".srt"):
                    zipf.write(path, arcname=name)
            if error_log:
                log_path = os.path.join(tmp_dir, "log.txt")
                # Explicit UTF-8: messages may contain CJK file names and text.
                with open(log_path, "w", encoding="utf-8") as logf:
                    logf.write("\n".join(error_log))
                zipf.write(log_path, arcname="log.txt")
        return zip_path
    except Exception as e:
        report = f"ZIP error: {str(e)}\n\n{traceback.format_exc()}"
        if error_log:
            report += "\n\nPer-file errors:\n" + "\n".join(error_log)
        log_path = os.path.join(tmp_dir, "log.txt")
        with open(log_path, "w", encoding="utf-8") as logf:
            logf.write(report)
        fail_zip = os.path.join(tmp_dir, "fail.zip")
        with zipfile.ZipFile(fail_zip, "w") as zipf:
            zipf.write(log_path, arcname="log.txt")
        return fail_zip

# Build the web UI. Binding it to ``demo`` exposes the app object under the
# name Gradio's reload mode looks for by default.
demo = gr.Interface(
    fn=batch_translate,
    inputs=[
        gr.File(file_types=[".srt"], label="Upload SRT files", file_count="multiple"),
        gr.Dropdown(["en", "zh", "ja"], label="Source Language", value="ja"),
        gr.Dropdown(["en", "zh", "ja"], label="Target Language", value="zh"),
    ],
    outputs=gr.File(label="Download Translated ZIP"),
    title="Batch SRT Translator (EN-ZH-JA)",
    description="Upload .srt subtitle files and translate between English, Chinese, and Japanese. Dual-language output with original + translation. ZIP output. Errors will be logged.",
)
# NOTE(review): launching at import time is kept as-is; consider guarding this
# with ``if __name__ == "__main__":`` if the module is ever imported elsewhere.
demo.launch()