from fastapi import FastAPI, Request
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    MarianMTModel,
    MarianTokenizer,
    MBartForConditionalGeneration,
    MBart50TokenizerFast,
)
import torch

# Import the app's chunking helpers
from chunking import get_max_word_length, chunk_text
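# Assumed contract for the helpers above (defined in chunking.py, not shown here):
#   get_max_word_length(langs) -> int        # safe per-chunk word budget for the given languages
#   chunk_text(text, limit)    -> list[str]  # splits text into chunks within that budget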
app = FastAPI()
# Map target languages to Hugging Face model IDs
MODEL_MAP = {
    "bg": "Helsinki-NLP/opus-mt-tc-big-en-bg",
    "cs": "Helsinki-NLP/opus-mt-en-cs",
    "da": "Helsinki-NLP/opus-mt-en-da",
    "de": "Helsinki-NLP/opus-mt-en-de",
    "el": "Helsinki-NLP/opus-mt-tc-big-en-el",
    "es": "Helsinki-NLP/opus-mt-tc-big-en-es",
    "et": "Helsinki-NLP/opus-mt-tc-big-en-et",
    "fi": "Helsinki-NLP/opus-mt-tc-big-en-fi",
    "fr": "Helsinki-NLP/opus-mt-en-fr",
    "hr": "facebook/mbart-large-50-many-to-many-mmt",
    "hu": "Helsinki-NLP/opus-mt-tc-big-en-hu",
    "is": "Helsinki-NLP/opus-mt-tc-big-en-gmq",  # multi-target model; needs a >>isl<< prefix
    "it": "Helsinki-NLP/opus-mt-tc-big-en-it",
    "lb": "alirezamsh/small100",
    "lt": "Helsinki-NLP/opus-mt-tc-big-en-lt",
    "lv": "facebook/mbart-large-50-many-to-many-mmt",
    "mk": "Helsinki-NLP/opus-mt-en-mk",
    "nb": "facebook/mbart-large-50-many-to-many-mmt",  # placeholder: mBART-50 has no Norwegian code
    "nl": "facebook/mbart-large-50-many-to-many-mmt",
    "no": "facebook/mbart-large-50-many-to-many-mmt",  # placeholder: mBART-50 has no Norwegian code
    "pl": "Helsinki-NLP/opus-mt-en-sla",  # multi-target model; needs a >>pol<< prefix
    "pt": "facebook/mbart-large-50-many-to-many-mmt",
    "ro": "facebook/mbart-large-50-many-to-many-mmt",
    "sk": "Helsinki-NLP/opus-mt-en-sk",
    "sl": "alirezamsh/small100",
    "sq": "alirezamsh/small100",
    "sv": "Helsinki-NLP/opus-mt-en-sv",
    "tr": "Helsinki-NLP/opus-mt-tc-big-en-tr",
}
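# mBART-50 is a many-to-many model, so each request must name the target
# language explicitly via its mBART-50 code. These helper tables are
# introduced here for the mBART fix; they are not part of any library API.
# (Norwegian is not among mBART-50's languages, hence the placeholders above.)
MBART_LANG_CODES = {
    "hr": "hr_HR",
    "lv": "lv_LV",
    "nl": "nl_XX",
    "pt": "pt_XX",
    "ro": "ro_RO",
}

# Multi-target Marian models expect a ">>lang<<" token prepended to the
# source text to select the output language (codes per the OPUS model cards).
MARIAN_TARGET_PREFIX = {
    "Helsinki-NLP/opus-mt-en-sla": {"pl": ">>pol<< "},
    "Helsinki-NLP/opus-mt-tc-big-en-gmq": {"is": ">>isl<< "},
}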
MODEL_CACHE = {}
# ✅ Load and cache a Hugging Face model (Marian, mBART, or Small100)
def load_model(model_id: str):
    """
    Load & cache the tokenizer/model pair for a model ID:
      - facebook/mbart-* -> MBart50TokenizerFast + MBartForConditionalGeneration
      - Helsinki-NLP/*   -> MarianTokenizer + MarianMTModel
      - anything else    -> AutoTokenizer + AutoModelForSeq2SeqLM
    """
    if model_id not in MODEL_CACHE:
        if model_id.startswith("facebook/mbart"):
            tokenizer = MBart50TokenizerFast.from_pretrained(model_id)
            model = MBartForConditionalGeneration.from_pretrained(model_id)
        elif model_id.startswith("Helsinki-NLP/"):
            tokenizer = MarianTokenizer.from_pretrained(model_id)
            model = MarianMTModel.from_pretrained(model_id)
        else:
            # e.g. alirezamsh/small100 -- its model card recommends a custom
            # SMALL100Tokenizer, so AutoTokenizer is a best-effort fallback here
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        model.to("cpu")
        model.eval()  # inference only
        MODEL_CACHE[model_id] = (tokenizer, model)
    return MODEL_CACHE[model_id]
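# First call per model downloads the weights; later calls hit MODEL_CACHE,
# e.g. tokenizer, model = load_model(MODEL_MAP["de"])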
# ✅ POST /translate
@app.post("/translate")
async def translate(request: Request):
    payload = await request.json()
    text = payload.get("text")
    target_lang = payload.get("target_lang")
    if not text or not target_lang:
        return {"error": "Missing 'text' or 'target_lang'"}
    model_id = MODEL_MAP.get(target_lang)
    if not model_id:
        return {"error": f"No model found for target language '{target_lang}'"}
    try:
        # Chunk the input to a length that is safe for the target language
        safe_limit = get_max_word_length([target_lang])
        chunks = chunk_text(text, safe_limit)
        tokenizer, model = load_model(model_id)

        gen_kwargs = {"num_beams": 5, "length_penalty": 1.2, "early_stopping": True}
        if model_id.startswith("facebook/mbart"):
            # mBART fix: declare the source language and force the target
            # language token at the start of generation
            mbart_code = MBART_LANG_CODES.get(target_lang)
            if not mbart_code:
                return {"error": f"No mBART-50 language code for '{target_lang}'"}
            tokenizer.src_lang = "en_XX"
            gen_kwargs["forced_bos_token_id"] = tokenizer.lang_code_to_id[mbart_code]
        elif hasattr(tokenizer, "tgt_lang"):
            # Small100-style multilingual tokenizers select the target this way
            tokenizer.tgt_lang = target_lang

        # Multi-target Marian models need a ">>lang<<" prefix on the source text
        prefix = MARIAN_TARGET_PREFIX.get(model_id, {}).get(target_lang, "")

        full_translation = []
        for chunk in chunks:
            inputs = tokenizer(prefix + chunk, return_tensors="pt", padding=True, truncation=True)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = model.generate(**inputs, **gen_kwargs)
            full_translation.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
        return {"translation": " ".join(full_translation)}
    except Exception as e:
        return {"error": f"Translation failed: {e}"}
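# Example request (illustrative values, assuming the server runs on port 7860):
#   curl -X POST http://localhost:7860/translate \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Hello world", "target_lang": "de"}'
# Response shape: {"translation": "..."} or {"error": "..."}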
# ✅ GET /languages
@app.get("/languages")
def list_languages():
    return {"supported_languages": list(MODEL_MAP.keys())}
# ✅ GET /health
@app.get("/health")
def health():
    return {"status": "ok"}
# ✅ Uvicorn startup (entry point when the Hugging Face Space runs this file directly)
import uvicorn

if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=7860)