from fastapi import FastAPI, Request
from transformers import (
    MarianMTModel,
    MarianTokenizer,
    MBartForConditionalGeneration,
    MBart50TokenizerFast,
)
import torch
# import your chunking helpers
from chunking import get_max_word_length, chunk_text
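# Assumed contract for the local chunking module (inferred from how it is used
# in /translate below, not from its source): get_max_word_length([lang]) returns
# a safe per-chunk length for that language's model, and chunk_text(text, limit)
# splits the input into pieces no longer than that limit.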
app = FastAPI()
# Map target languages to Hugging Face model IDs
MODEL_MAP = {
"bg": "Helsinki-NLP/opus-mt-tc-big-en-bg",
"cs": "Helsinki-NLP/opus-mt-en-cs",
"da": "Helsinki-NLP/opus-mt-en-da",
"de": "Helsinki-NLP/opus-mt-en-de",
"el": "Helsinki-NLP/opus-mt-tc-big-en-el",
"es": "Helsinki-NLP/opus-mt-tc-big-en-es",
"et": "Helsinki-NLP/opus-mt-tc-big-en-et",
"fi": "Helsinki-NLP/opus-mt-tc-big-en-fi",
"fr": "Helsinki-NLP/opus-mt-en-fr",
"hr": "facebook/mbart-large-50-many-to-many-mmt",
"hu": "Helsinki-NLP/opus-mt-tc-big-en-hu",
"is": "Helsinki-NLP/opus-mt-tc-big-en-gmq",
"it": "Helsinki-NLP/opus-mt-tc-big-en-it",
"lb": "alirezamsh/small100",
"lt": "Helsinki-NLP/opus-mt-tc-big-en-lt",
"lv": "facebook/mbart-large-50-many-to-many-mmt",
"mk": "Helsinki-NLP/opus-mt-en-mk",
"nb": "facebook/mbart-large-50-many-to-many-mmt", #place holder!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
"nl": "facebook/mbart-large-50-many-to-many-mmt",
"no": "facebook/mbart-large-50-many-to-many-mmt", #place holder!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
"pl": "Helsinki-NLP/opus-mt-en-sla",
"pt": "facebook/mbart-large-50-many-to-many-mmt",
"ro": "facebook/mbart-large-50-many-to-many-mmt",
"sk": "Helsinki-NLP/opus-mt-en-sk",
"sl": "alirezamsh/small100",
"sq": "alirezamsh/small100",
"sv": "Helsinki-NLP/opus-mt-en-sv",
"tr": "Helsinki-NLP/opus-mt-tc-big-en-tr"
}
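
# Several MODEL_MAP entries are multilingual checkpoints rather than dedicated
# en->xx models: opus-mt-en-sla and opus-mt-tc-big-en-gmq are multi-target OPUS
# models (which normally expect a target-language token such as ">>pol<<" or
# ">>isl<<" prepended to the source text), and the mbart-50 / small100 entries
# are many-to-many models that need the target language forced at generation
# time. The mapping below is an illustrative assumption (ISO code -> mbart-50
# language code) used only by the hedged sketch after the /translate handler;
# the live endpoint does not use it.
MBART_LANG_CODES = {
    "hr": "hr_HR",
    "lv": "lv_LV",
    "nl": "nl_XX",
    "pt": "pt_XX",
    "ro": "ro_RO",
    # "nb"/"no": mbart-50 appears to have no Norwegian code, hence the placeholders above
}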
MODEL_CACHE = {}
# Load Hugging Face model (MBart50 or MarianMT)
def load_model(model_id: str):
"""
Load & cache either:
- MBart50 (facebook/mbart-*)
- MarianMT otherwise
"""
if model_id not in MODEL_CACHE:
if model_id.startswith("facebook/mbart"):
tokenizer = MBart50TokenizerFast.from_pretrained(model_id)
model = MBartForConditionalGeneration.from_pretrained(model_id)
else:
tokenizer = MarianTokenizer.from_pretrained(model_id)
model = MarianMTModel.from_pretrained(model_id)
model.to("cpu")
MODEL_CACHE[model_id] = (tokenizer, model)
return MODEL_CACHE[model_id]
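
# Caveat (assumption based on the small100 model card, not verified here):
# alirezamsh/small100 is an M2M100-style checkpoint, so the MarianTokenizer /
# MarianMTModel fallback above will most likely not load it correctly. The model
# card loads it with M2M100ForConditionalGeneration plus the repo's own
# SMALL100Tokenizer (tokenization_small100.py), which takes the target language,
# e.g. SMALL100Tokenizer.from_pretrained(model_id, tgt_lang="sl"). If the
# lb/sl/sq routes matter, load_model() needs a third branch for that tokenizer.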
# POST /translate
@app.post("/translate")
async def translate(request: Request):
    payload = await request.json()
    text = payload.get("text")
    target_lang = payload.get("target_lang")
    if not text or not target_lang:
        return {"error": "Missing 'text' or 'target_lang'"}
    model_id = MODEL_MAP.get(target_lang)
    if not model_id:
        return {"error": f"No model found for target language '{target_lang}'"}
    try:
        # chunk to safe length
        safe_limit = get_max_word_length([target_lang])
        chunks = chunk_text(text, safe_limit)
        tokenizer, model = load_model(model_id)
        full_translation = []
        for chunk in chunks:
            inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            outputs = model.generate(**inputs, num_beams=5, length_penalty=1.2, early_stopping=True)
            full_translation.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
        return {"translation": " ".join(full_translation)}
    except Exception as e:
        return {"error": f"Translation failed: {e}"}
# GET /languages
@app.get("/languages")
def list_languages():
return {"supported_languages": list(MODEL_MAP.keys())}
# GET /health
@app.get("/health")
def health():
return {"status": "ok"}
# Uvicorn startup (required by Hugging Face)
import uvicorn
if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=7860)
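
# Example request against a local run (values taken from the code above):
#
#   curl -X POST http://localhost:7860/translate \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Hello, world!", "target_lang": "de"}'
#
# A successful response has the shape {"translation": "..."}; errors come back
# as {"error": "..."}.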