File size: 3,967 Bytes
efb198f
447b422
 
 
 
 
 
efb198f
 
447b422
07ea3d5
 
efb198f
 
6df8ecd
efb198f
6df8ecd
 
 
 
 
be484c1
6df8ecd
 
efb198f
6df8ecd
 
be484c1
 
7bc869d
6df8ecd
 
be484c1
6df8ecd
 
 
be484c1
6df8ecd
 
 
 
 
 
be484c1
efb198f
 
6df8ecd
efb198f
 
6df8ecd
447b422
 
 
 
 
 
efb198f
447b422
 
 
 
 
 
 
efb198f
 
 
447b422
6df8ecd
efb198f
 
447b422
 
 
efb198f
 
6df8ecd
efb198f
 
 
6df8ecd
 
 
447b422
07ea3d5
447b422
07ea3d5
6df8ecd
07ea3d5
447b422
07ea3d5
447b422
 
07ea3d5
 
 
447b422
07ea3d5
6df8ecd
447b422
 
6df8ecd
 
 
 
 
efb198f
6df8ecd
 
 
 
efb198f
6df8ecd
efb198f
 
6df8ecd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from fastapi import FastAPI, Request
from transformers import (
    MarianMTModel,
    MarianTokenizer,
    MBartForConditionalGeneration,
    MBart50TokenizerFast
)
import torch

# import your chunking helpers
from chunking import get_max_word_length, chunk_text

# FastAPI application object; the route decorators further down register on it.
app = FastAPI()

# Map target-language codes to Hugging Face model IDs.
# The keys are the values accepted as "target_lang" by POST /translate and
# reported by GET /languages; the values are passed to load_model().
MODEL_MAP = {
    "bg": "Helsinki-NLP/opus-mt-tc-big-en-bg",
    "cs": "Helsinki-NLP/opus-mt-en-cs",
    "da": "Helsinki-NLP/opus-mt-en-da",
    "de": "Helsinki-NLP/opus-mt-en-de",
    "el": "Helsinki-NLP/opus-mt-tc-big-en-el",
    "es": "Helsinki-NLP/opus-mt-tc-big-en-es",
    "et": "Helsinki-NLP/opus-mt-tc-big-en-et",
    "fi": "Helsinki-NLP/opus-mt-tc-big-en-fi",
    "fr": "Helsinki-NLP/opus-mt-en-fr",
    "hr": "facebook/mbart-large-50-many-to-many-mmt",
    "hu": "Helsinki-NLP/opus-mt-tc-big-en-hu",
    "is": "Helsinki-NLP/opus-mt-tc-big-en-gmq",
    "it": "Helsinki-NLP/opus-mt-tc-big-en-it",
    "lb": "alirezamsh/small100",
    "lt": "Helsinki-NLP/opus-mt-tc-big-en-lt",
    "lv": "facebook/mbart-large-50-many-to-many-mmt",
    "mk": "Helsinki-NLP/opus-mt-en-mk",
    "nb": "facebook/mbart-large-50-many-to-many-mmt",  # TODO: placeholder entry - NOTE(review): confirm this checkpoint can actually produce this language
    "nl": "facebook/mbart-large-50-many-to-many-mmt",
    "no": "facebook/mbart-large-50-many-to-many-mmt",  # TODO: placeholder entry - NOTE(review): confirm this checkpoint can actually produce this language
    "pl": "Helsinki-NLP/opus-mt-en-sla",
    "pt": "facebook/mbart-large-50-many-to-many-mmt",
    "ro": "facebook/mbart-large-50-many-to-many-mmt",
    "sk": "Helsinki-NLP/opus-mt-en-sk",
    "sl": "alirezamsh/small100",
    "sq": "alirezamsh/small100",
    "sv": "Helsinki-NLP/opus-mt-en-sv",
    "tr": "Helsinki-NLP/opus-mt-tc-big-en-tr"
}


# Process-wide cache: model ID -> (tokenizer, model) tuple, filled lazily by
# load_model() the first time a model ID is requested.
MODEL_CACHE = {}

# ✅ Load and cache a Hugging Face model (MBart50 or MarianMT)
def load_model(model_id: str):
    """
    Return the cached ``(tokenizer, model)`` pair for ``model_id``.

    On the first request for a given ID the pair is loaded with
    ``from_pretrained``, moved to the CPU and stored in ``MODEL_CACHE``;
    later requests are served straight from the cache.

    Class selection:
      - IDs starting with ``facebook/mbart`` use the MBart50 tokenizer/model
      - every other ID uses the MarianMT tokenizer/model

    NOTE(review): ``MODEL_MAP`` also lists ``alirezamsh/small100``, which
    takes the Marian branch here - confirm that checkpoint loads with the
    Marian classes.
    """
    cached_pair = MODEL_CACHE.get(model_id)
    if cached_pair is not None:
        return cached_pair

    # Pick the tokenizer/model classes once, then load both the same way.
    uses_mbart = model_id.startswith("facebook/mbart")
    tokenizer_cls = MBart50TokenizerFast if uses_mbart else MarianTokenizer
    model_cls = MBartForConditionalGeneration if uses_mbart else MarianMTModel

    tokenizer = tokenizer_cls.from_pretrained(model_id)
    model = model_cls.from_pretrained(model_id)
    model.to("cpu")

    MODEL_CACHE[model_id] = (tokenizer, model)
    return MODEL_CACHE[model_id]


# ✅ POST /translate
# Source language code handed to the mBART-50 tokenizer (all models in
# MODEL_MAP translate from English).
_MBART_SOURCE_LANG = "en_XX"

# mBART-50 language codes for the targets that MODEL_MAP routes to the
# many-to-many mBART checkpoint. Targets missing from this table (e.g. the
# "nb"/"no" placeholder entries) keep the previous, unforced generation.
_MBART_TARGET_CODES = {
    "hr": "hr_HR",
    "lv": "lv_LV",
    "nl": "nl_XX",
    "pt": "pt_XX",
    "ro": "ro_RO",
}

# Sentence-initial target-language tokens for the multi-target Marian
# checkpoints in MODEL_MAP (en-gmq for Icelandic, en-sla for Polish). A token
# is only used if the loaded tokenizer reports it as supported.
_MARIAN_TARGET_PREFIXES = {
    "is": ">>isl<<",
    "pl": ">>pol<<",
}


@app.post("/translate")
async def translate(request: Request):
    """
    Translate English text into the requested target language.

    Expects a JSON object body with:
      - ``text``: the text to translate
      - ``target_lang``: a key of ``MODEL_MAP``

    Returns ``{"translation": <str>}`` on success. Every failure is reported
    as ``{"error": <message>}`` (invalid body, missing fields, unknown target
    language, or any exception raised while chunking/loading/generating).

    The text is split with ``chunk_text`` using the limit from
    ``get_max_word_length``; each chunk is translated on its own and the
    pieces are joined with single spaces.
    """
    # A malformed or non-object body is answered with the same error shape
    # as every other failure instead of escaping as an unhandled exception.
    try:
        payload = await request.json()
    except ValueError:
        return {"error": "Request body must be valid JSON"}
    if not isinstance(payload, dict):
        return {"error": "Request body must be a JSON object"}

    text = payload.get("text")
    target_lang = payload.get("target_lang")

    if not text or not target_lang:
        return {"error": "Missing 'text' or 'target_lang'"}

    # Non-string values (e.g. a list) cannot be dictionary keys; treat them
    # as an unknown language rather than letting the lookup raise.
    model_id = MODEL_MAP.get(target_lang) if isinstance(target_lang, str) else None
    if not model_id:
        return {"error": f"No model found for target language '{target_lang}'"}

    try:
        # chunk to safe length
        safe_limit = get_max_word_length([target_lang])
        chunks = chunk_text(text, safe_limit)

        tokenizer, model = load_model(model_id)

        generate_kwargs = {
            "num_beams": 5,
            "length_penalty": 1.2,
            "early_stopping": True,
        }
        prefix = ""

        if model_id.startswith("facebook/mbart"):
            # The many-to-many mBART checkpoint only knows the output
            # language from the forced first decoder token.
            tokenizer.src_lang = _MBART_SOURCE_LANG
            target_code = _MBART_TARGET_CODES.get(target_lang)
            if target_code is not None:
                generate_kwargs["forced_bos_token_id"] = (
                    tokenizer.convert_tokens_to_ids(target_code)
                )
        else:
            # Multi-target Marian checkpoints select the output language
            # from a ">>id<<" token at the start of the input.
            candidate = _MARIAN_TARGET_PREFIXES.get(target_lang)
            if candidate in getattr(tokenizer, "supported_language_codes", ()):
                prefix = f"{candidate} "

        full_translation = []
        for chunk in chunks:
            source = f"{prefix}{chunk}" if prefix else chunk
            inputs = tokenizer(source, return_tensors="pt", padding=True, truncation=True)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            outputs = model.generate(**inputs, **generate_kwargs)
            full_translation.append(tokenizer.decode(outputs[0], skip_special_tokens=True))

        return {"translation": " ".join(full_translation)}

    except Exception as e:
        # Endpoint boundary: report the failure to the caller in the
        # response body, as the other error paths do.
        return {"error": f"Translation failed: {e}"}


# ✅ GET /languages
@app.get("/languages")
def list_languages():
    """Return every target-language code that has an entry in MODEL_MAP."""
    language_codes = [*MODEL_MAP]
    return {"supported_languages": language_codes}

# ✅ GET /health
@app.get("/health")
def health():
    """Report a fixed "ok" status; no models are loaded or checked here."""
    status_body = {"status": "ok"}
    return status_body

# ✅ Uvicorn startup (required by Hugging Face)
import uvicorn  # only used by the script entry point below
if __name__ == "__main__":
    # Start the server on all interfaces, port 7860, when run as a script.
    # NOTE(review): the "app:app" import string assumes this module is
    # importable as `app` (i.e. the file is app.py) - confirm the filename.
    uvicorn.run("app:app", host="0.0.0.0", port=7860)