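"""FastAPI translation service.

Exposes /translate, /languages, and /health. English input is routed to a
per-language Hugging Face model (Marian, mBART-50, or SMALL-100) according to
MODEL_MAP, chunked to a safe length, and translated chunk by chunk.
"""
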
from fastapi import FastAPI, Request
from transformers import (
    MarianMTModel,
    MarianTokenizer,
    MBartForConditionalGeneration,
    MBart50TokenizerFast,
    AutoModelForSeq2SeqLM
)
import torch
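# SMALL100Tokenizer is not shipped with transformers: tokenization_small100.py
# is assumed to be vendored next to this file from the alirezamsh/small100 repo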
from tokenization_small100 import SMALL100Tokenizer

# local chunking helpers: split long inputs so each piece stays within a safe length
from chunking import get_max_word_length, chunk_text

app = FastAPI()

# Map target languages to Hugging Face model IDs
MODEL_MAP = {
    "bg": "Helsinki-NLP/opus-mt-tc-big-en-bg",
    "cs": "Helsinki-NLP/opus-mt-en-cs",
    "da": "Helsinki-NLP/opus-mt-en-da",
    "de": "Helsinki-NLP/opus-mt-en-de",
    "el": "Helsinki-NLP/opus-mt-tc-big-en-el",
    "es": "Helsinki-NLP/opus-mt-tc-big-en-es",
    "et": "Helsinki-NLP/opus-mt-tc-big-en-et",
    "fi": "Helsinki-NLP/opus-mt-tc-big-en-fi",
    "fr": "Helsinki-NLP/opus-mt-en-fr",
    "hr": "facebook/mbart-large-50-many-to-many-mmt",
    "hu": "Helsinki-NLP/opus-mt-tc-big-en-hu",
    "is": "mkorada/opus-mt-en-is-finetuned-v4",
    "it": "Helsinki-NLP/opus-mt-tc-big-en-it",
    "lb": "alirezamsh/small100",
    "lt": "Helsinki-NLP/opus-mt-tc-big-en-lt",
    "lv": "facebook/mbart-large-50-many-to-many-mmt",
    "cnr": "Helsinki-NLP/opus-mt-tc-base-en-sh",
    "mk": "Helsinki-NLP/opus-mt-en-mk",
    "nl": "facebook/mbart-large-50-many-to-many-mmt",
    "no": "Confused404/eng-gmq-finetuned_v2-no",
    "pl": "Helsinki-NLP/opus-mt-en-sla",
    "pt": "facebook/mbart-large-50-many-to-many-mmt",
    "ro": "facebook/mbart-large-50-many-to-many-mmt",
    "sk": "Helsinki-NLP/opus-mt-en-sk",
    "sl": "alirezamsh/small100",
    "sq": "alirezamsh/small100",
    "sv": "Helsinki-NLP/opus-mt-en-sv",
    "tr": "Helsinki-NLP/opus-mt-tc-big-en-tr"
}
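
# Three model families are configured above, each handled differently below:
#   - facebook/mbart-*     : set src_lang and force a target-language BOS token
#   - alirezamsh/small100  : one multilingual checkpoint; target chosen via tgt_lang
#   - Helsinki-NLP Marian  : one model per target, except the multi-target
#                            checkpoints (no/cnr/pl), which expect a >>lang<<
#                            prefix on the input text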

# Cache loaded models/tokenizers
MODEL_CACHE = {}

def load_model(model_id: str, target_lang: str):
    """
    Load & cache:
      - facebook/mbart-* via MBart50TokenizerFast & MBartForConditionalGeneration
      - alirezamsh/small100 via SMALL100Tokenizer & AutoModelForSeq2SeqLM
      - all others via MarianTokenizer & MarianMTModel
    """
    if model_id not in MODEL_CACHE:
        if model_id.startswith("facebook/mbart"):
            tokenizer = MBart50TokenizerFast.from_pretrained(model_id)
            # mBART always translates FROM English in this service
            tokenizer.src_lang = "en_XX"
            model = MBartForConditionalGeneration.from_pretrained(model_id)
        elif model_id == "alirezamsh/small100":
            tokenizer = SMALL100Tokenizer.from_pretrained(model_id, tgt_lang=target_lang)
            model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        else:
            tokenizer = MarianTokenizer.from_pretrained(model_id)
            model = MarianMTModel.from_pretrained(model_id)

        model.to("cpu")
        MODEL_CACHE[model_id] = (tokenizer, model)

    tokenizer, model = MODEL_CACHE[model_id]
    # small100 serves several target languages from one checkpoint; updating
    # tgt_lang on the cached tokenizer avoids reloading the model on every call
    if model_id == "alirezamsh/small100":
        tokenizer.tgt_lang = target_lang
    return tokenizer, model
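
# Optional warm-up (sketch, not part of the original flow): pre-loading the
# models you expect to serve hides the first-request download/initialization
# latency. Uncomment and adjust the language list to enable:
#
# @app.on_event("startup")
# def warm_up() -> None:
#     for lang in ("de", "fr"):
#         load_model(MODEL_MAP[lang], lang)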

@app.post("/translate")
async def translate(request: Request):
    payload     = await request.json()
    text        = payload.get("text")
    target_lang = payload.get("target_lang")

    if not text or not target_lang:
        return {"error": "Missing 'text' or 'target_lang'"}

    model_id = MODEL_MAP.get(target_lang)
    if not model_id:
        return {"error": f"No model found for target language '{target_lang}'"}

    try:
        # chunk to safe length
        safe_limit = get_max_word_length([target_lang])
        chunks     = chunk_text(text, safe_limit)

        tokenizer, model = load_model(model_id, target_lang)
        is_mbart = model_id.startswith("facebook/mbart")
        full_translation = []

        for chunk in chunks:
            # multi-target Marian checkpoints need a >>lang<< target-token prefix
            if model_id == "Confused404/eng-gmq-finetuned_v2-no":
                chunk = f">>nob<< {chunk}"
            if model_id == "Helsinki-NLP/opus-mt-tc-base-en-sh":
                chunk = f">>cnr<< {chunk}"
            if model_id == "Helsinki-NLP/opus-mt-en-sla":
                chunk = f">>pol<< {chunk}"

            # tokenize
            inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}

            # generate
            if is_mbart:
                # mBART-50 target codes look like "hr_HR"/"ro_RO", but Dutch and
                # Portuguese use the "_XX" suffix: "nl_XX"/"pt_XX"
                lang_code = f"{target_lang}_{target_lang.upper()}"
                if target_lang in ("nl", "pt"):
                    lang_code = f"{target_lang}_XX"
                bos_id = tokenizer.lang_code_to_id[lang_code]
                outputs = model.generate(
                    **inputs,
                    forced_bos_token_id=bos_id,
                    num_beams=5,
                    length_penalty=1.2,
                    early_stopping=True
                )
            else:
                outputs = model.generate(
                    **inputs,
                    num_beams=5,
                    length_penalty=1.2,
                    early_stopping=True
                )

            full_translation.append(tokenizer.decode(outputs[0], skip_special_tokens=True))

        return {"translation": " ".join(full_translation)}

    except Exception as e:
        return {"error": f"Translation failed: {e}"}

@app.get("/languages")
def list_languages():
    return {"supported_languages": list(MODEL_MAP.keys())}

@app.get("/health")
def health():
    return {"status": "ok"}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860)