File size: 5,530 Bytes
efb198f 447b422 b8c0a2d 447b422 efb198f 282b03d efb198f 447b422 07ea3d5 efb198f 6df8ecd efb198f dd77d0b efb198f b8c0a2d efb198f 282b03d 447b422 b8c0a2d dd77d0b b8c0a2d 447b422 dd77d0b 1d177ae 447b422 dd77d0b b8c0a2d 282b03d b8c0a2d 447b422 b8c0a2d 447b422 efb198f b8c0a2d 447b422 efb198f 447b422 efb198f 6df8ecd efb198f 6df8ecd 447b422 07ea3d5 447b422 07ea3d5 282b03d dd77d0b 07ea3d5 447b422 07ea3d5 dd77d0b a40fe7a b302d2e b4d0c67 dd77d0b 8183d04 dd77d0b 07ea3d5 447b422 07ea3d5 6df8ecd 447b422 6df8ecd efb198f 6df8ecd efb198f b8c0a2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
from fastapi import FastAPI, Request
from transformers import (
MarianMTModel,
MarianTokenizer,
MBartForConditionalGeneration,
MBart50TokenizerFast,
AutoModelForSeq2SeqLM
)
import torch
from tokenization_small100 import SMALL100Tokenizer
# import your chunking helpers
from chunking import get_max_word_length, chunk_text
app = FastAPI()
# Map target languages to Hugging Face model IDs.
# Three model families are in play (see load_model):
#   - Helsinki-NLP/opus-mt-*  : bilingual/multilingual Marian models
#   - facebook/mbart-*        : one shared many-to-many MBART-50 checkpoint
#   - alirezamsh/small100     : one shared SMALL-100 checkpoint (tgt_lang set per request)
MODEL_MAP = {
    "bg": "Helsinki-NLP/opus-mt-tc-big-en-bg",
    "cs": "Helsinki-NLP/opus-mt-en-cs",
    "da": "Helsinki-NLP/opus-mt-en-da",
    "de": "Helsinki-NLP/opus-mt-en-de",
    "el": "Helsinki-NLP/opus-mt-tc-big-en-el",
    "es": "Helsinki-NLP/opus-mt-tc-big-en-es",
    "et": "Helsinki-NLP/opus-mt-tc-big-en-et",
    "fi": "Helsinki-NLP/opus-mt-tc-big-en-fi",
    "fr": "Helsinki-NLP/opus-mt-en-fr",
    "hr": "facebook/mbart-large-50-many-to-many-mmt",
    "hu": "Helsinki-NLP/opus-mt-tc-big-en-hu",
    "is": "mkorada/opus-mt-en-is-finetuned-v4",  # fine-tuned Icelandic model
    "it": "Helsinki-NLP/opus-mt-tc-big-en-it",
    "lb": "alirezamsh/small100",
    "lt": "Helsinki-NLP/opus-mt-tc-big-en-lt",
    "lv": "facebook/mbart-large-50-many-to-many-mmt",
    "cnr": "Helsinki-NLP/opus-mt-tc-base-en-sh",  # Montenegrin via Serbo-Croatian model (>>cnr<< prefix added in translate)
    "mk": "Helsinki-NLP/opus-mt-en-mk",
    "nl": "facebook/mbart-large-50-many-to-many-mmt",
    "no": "Confused404/eng-gmq-finetuned_v2-no",  # Norwegian via North-Germanic model (>>nob<< prefix added in translate)
    "pl": "Helsinki-NLP/opus-mt-en-sla",  # Polish via pan-Slavic model (>>pol<< prefix added in translate)
    "pt": "facebook/mbart-large-50-many-to-many-mmt",
    "ro": "facebook/mbart-large-50-many-to-many-mmt",
    "sk": "Helsinki-NLP/opus-mt-en-sk",
    "sl": "alirezamsh/small100",
    "sq": "alirezamsh/small100",
    "sv": "Helsinki-NLP/opus-mt-en-sv",
    "tr": "Helsinki-NLP/opus-mt-tc-big-en-tr"
}
# Cache loaded models/tokenizers, keyed by model id (see load_model)
MODEL_CACHE = {}
def load_model(model_id: str, target_lang: str):
    """
    Load & cache a (tokenizer, model) pair for *model_id*.

    - facebook/mbart-*    -> MBart50TokenizerFast & MBartForConditionalGeneration
    - alirezamsh/small100 -> SMALL100Tokenizer & AutoModelForSeq2SeqLM
    - all others          -> MarianTokenizer & MarianMTModel

    The SMALL-100 tokenizer is constructed per target language, so its cache
    entries are keyed by (model_id, target_lang). Previously the cache was
    bypassed entirely for small100, which re-loaded the full model weights
    from disk on every single request just to change tgt_lang.

    Returns:
        (tokenizer, model) tuple; the model is on CPU.
    """
    is_small100 = model_id == "alirezamsh/small100"
    # small100 entries are per target language; everything else per model id.
    cache_key = (model_id, target_lang) if is_small100 else model_id
    if cache_key not in MODEL_CACHE:
        if model_id.startswith("facebook/mbart"):
            tokenizer = MBart50TokenizerFast.from_pretrained(model_id)
            # MBART: always translate FROM English
            tokenizer.src_lang = "en_XX"
            model = MBartForConditionalGeneration.from_pretrained(model_id)
        elif is_small100:
            tokenizer = SMALL100Tokenizer.from_pretrained(model_id, tgt_lang=target_lang)
            model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        else:
            tokenizer = MarianTokenizer.from_pretrained(model_id)
            model = MarianMTModel.from_pretrained(model_id)
        model.to("cpu")
        MODEL_CACHE[cache_key] = (tokenizer, model)
    return MODEL_CACHE[cache_key]
@app.post("/translate")
async def translate(request: Request):
    """
    Translate English text into the requested target language.

    Expects a JSON body with "text" and "target_lang"; returns either
    {"translation": ...} or {"error": ...}.
    """
    payload = await request.json()
    text = payload.get("text")
    target_lang = payload.get("target_lang")
    if not text or not target_lang:
        return {"error": "Missing 'text' or 'target_lang'"}
    model_id = MODEL_MAP.get(target_lang)
    if not model_id:
        return {"error": f"No model found for target language '{target_lang}'"}
    try:
        # Split the input into chunks the model can handle
        safe_limit = get_max_word_length([target_lang])
        chunks = chunk_text(text, safe_limit)
        tokenizer, model = load_model(model_id, target_lang)
        is_mbart = model_id.startswith("facebook/mbart")
        # Models that need a language-token prefix prepended to the input
        prefix_by_model = {
            "Confused404/eng-gmq-finetuned_v2-no": ">>nob<<",
            "Helsinki-NLP/opus-mt-tc-base-en-sh": ">>cnr<<",
            "Helsinki-NLP/opus-mt-en-sla": ">>pol<<",
        }
        prefix = prefix_by_model.get(model_id)
        translated_parts = []
        for chunk in chunks:
            if prefix is not None:
                chunk = f"{prefix} {chunk}"
            # Tokenize and move tensors to the model's device
            encoded = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True)
            encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}
            gen_kwargs = {
                "num_beams": 5,
                "length_penalty": 1.2,
                "early_stopping": True,
            }
            if is_mbart:
                # MBART needs a forced BOS token naming the target language,
                # e.g. "de_DE", "es_ES" — but "nl"/"pt" use the "_XX" suffix.
                if target_lang in ("nl", "pt"):
                    lang_code = f"{target_lang}_XX"
                else:
                    lang_code = f"{target_lang}_{target_lang.upper()}"
                gen_kwargs["forced_bos_token_id"] = tokenizer.lang_code_to_id[lang_code]
            outputs = model.generate(**encoded, **gen_kwargs)
            translated_parts.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
        return {"translation": " ".join(translated_parts)}
    except Exception as e:
        return {"error": f"Translation failed: {e}"}
@app.get("/languages")
def list_languages():
    """Return the target-language codes this service can translate into."""
    supported = [code for code in MODEL_MAP]
    return {"supported_languages": supported}
@app.get("/health")
def health():
    """Liveness probe: report that the service is up."""
    return dict(status="ok")
if __name__ == "__main__":
    # Dev/standalone entry point: serve this module's `app` on all
    # interfaces at port 7860 (presumably the hosting platform's expected
    # port — confirm against deployment config).
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860)
|