mayacou committed on
Commit
dd77d0b
·
verified ·
1 Parent(s): fb386ea

Call mBART correctly — translate from English (en_XX) to the target language code

Browse files
Files changed (1) hide show
  1. app.py +60 -38
app.py CHANGED
@@ -4,7 +4,6 @@ from transformers import (
4
  MarianTokenizer,
5
  MBartForConditionalGeneration,
6
  MBart50TokenizerFast,
7
- AutoTokenizer,
8
  AutoModelForSeq2SeqLM
9
  )
10
  import torch
@@ -17,35 +16,34 @@ app = FastAPI()
17
 
18
  # Map target languages to Hugging Face model IDs
19
  MODEL_MAP = {
20
- "bg": "Helsinki-NLP/opus-mt-tc-big-en-bg", # bulgarian
21
- "cs": "Helsinki-NLP/opus-mt-en-cs", # czech
22
- "da": "Helsinki-NLP/opus-mt-en-da", # danish
23
- "de": "Helsinki-NLP/opus-mt-en-de", # german
24
- "el": "Helsinki-NLP/opus-mt-tc-big-en-el", # greek
25
- "es": "Helsinki-NLP/opus-mt-tc-big-en-es", # spanish
26
- "et": "Helsinki-NLP/opus-mt-tc-big-en-et", # estonian
27
- "fi": "Helsinki-NLP/opus-mt-tc-big-en-fi", # finnish
28
- "fr": "Helsinki-NLP/opus-mt-en-fr", # french
29
- "hr": "facebook/mbart-large-50-many-to-many-mmt", # croatian
30
- "hu": "Helsinki-NLP/opus-mt-tc-big-en-hu", # hungarian
31
- "is": "mkorada/opus-mt-en-is-finetuned-v4", # icelandic # Manas's fine-tuned model
32
- "it": "Helsinki-NLP/opus-mt-tc-big-en-it", # italian
33
- "lb": "alirezamsh/small100", # luxembourgish # small100
34
- "lt": "Helsinki-NLP/opus-mt-tc-big-en-lt", # lithuanian
35
- "lv": "facebook/mbart-large-50-many-to-many-mmt", # latvian
36
- "cnr": "Helsinki-NLP/opus-mt-tc-base-en-sh", # montegrin
37
- "mk": "Helsinki-NLP/opus-mt-en-mk", # macedonian
38
- # "nb": "facebook/mbart-large-50-many-to-many-mmt", # norwegian
39
- "nl": "facebook/mbart-large-50-many-to-many-mmt", # dutch
40
- "no": "Confused404/eng-gmq-finetuned_v2-no", # norwegian # Alex's fine-tuned model
41
- "pl": "Helsinki-NLP/opus-mt-en-sla", # polish
42
- "pt": "facebook/mbart-large-50-many-to-many-mmt", # portuguese
43
- "ro": "facebook/mbart-large-50-many-to-many-mmt", # romanian
44
- "sk": "Helsinki-NLP/opus-mt-en-sk", # slovak
45
- "sl": "alirezamsh/small100", # slovene
46
- "sq": "alirezamsh/small100", # albanian
47
- "sv": "Helsinki-NLP/opus-mt-en-sv", # swedish
48
- "tr": "Helsinki-NLP/opus-mt-tc-big-en-tr" # turkish
49
  }
50
 
51
  # Cache loaded models/tokenizers
@@ -55,13 +53,16 @@ def load_model(model_id: str, target_lang: str):
55
  """
56
  Load & cache:
57
  - facebook/mbart-* via MBart50TokenizerFast & MBartForConditionalGeneration
58
- - alirezamsh/small100 via AutoTokenizer & AutoModelForSeq2SeqLM
59
  - all others via MarianTokenizer & MarianMTModel
60
  """
 
61
  if model_id not in MODEL_CACHE or model_id == "alirezamsh/small100":
62
  if model_id.startswith("facebook/mbart"):
63
  tokenizer = MBart50TokenizerFast.from_pretrained(model_id)
64
- model = MBartForConditionalGeneration.from_pretrained(model_id)
 
 
65
  elif model_id == "alirezamsh/small100":
66
  tokenizer = SMALL100Tokenizer.from_pretrained(model_id, tgt_lang=target_lang)
67
  model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
@@ -93,18 +94,40 @@ async def translate(request: Request):
93
  chunks = chunk_text(text, safe_limit)
94
 
95
  tokenizer, model = load_model(model_id, target_lang)
 
96
  full_translation = []
97
 
98
  for chunk in chunks:
99
-
100
  if model_id == "Confused404/eng-gmq-finetuned_v2-no":
101
  chunk = f">>nob<< {chunk}"
102
  if model_id == "Helsinki-NLP/opus-mt-tc-base-en-sh":
103
  chunk = f">>cnr<< {chunk}"
104
-
105
- inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True)
106
- inputs = {k: v.to(model.device) for k, v in inputs.items()}
107
- outputs = model.generate(**inputs, num_beams=5, length_penalty=1.2, early_stopping=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  full_translation.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
109
 
110
  return {"translation": " ".join(full_translation)}
@@ -120,7 +143,6 @@ def list_languages():
120
  def health():
121
  return {"status": "ok"}
122
 
123
- # Uvicorn startup for local testing
124
  if __name__ == "__main__":
125
  import uvicorn
126
  uvicorn.run("app:app", host="0.0.0.0", port=7860)
 
4
  MarianTokenizer,
5
  MBartForConditionalGeneration,
6
  MBart50TokenizerFast,
 
7
  AutoModelForSeq2SeqLM
8
  )
9
  import torch
 
16
 
17
  # Map target languages to Hugging Face model IDs
18
  MODEL_MAP = {
19
+ "bg": "Helsinki-NLP/opus-mt-tc-big-en-bg",
20
+ "cs": "Helsinki-NLP/opus-mt-en-cs",
21
+ "da": "Helsinki-NLP/opus-mt-en-da",
22
+ "de": "Helsinki-NLP/opus-mt-en-de",
23
+ "el": "Helsinki-NLP/opus-mt-tc-big-en-el",
24
+ "es": "Helsinki-NLP/opus-mt-tc-big-en-es",
25
+ "et": "Helsinki-NLP/opus-mt-tc-big-en-et",
26
+ "fi": "Helsinki-NLP/opus-mt-tc-big-en-fi",
27
+ "fr": "Helsinki-NLP/opus-mt-en-fr",
28
+ "hr": "facebook/mbart-large-50-many-to-many-mmt",
29
+ "hu": "Helsinki-NLP/opus-mt-tc-big-en-hu",
30
+ "is": "mkorada/opus-mt-en-is-finetuned-v4",
31
+ "it": "Helsinki-NLP/opus-mt-tc-big-en-it",
32
+ "lb": "alirezamsh/small100",
33
+ "lt": "Helsinki-NLP/opus-mt-tc-big-en-lt",
34
+ "lv": "facebook/mbart-large-50-many-to-many-mmt",
35
+ "cnr": "Helsinki-NLP/opus-mt-tc-base-en-sh",
36
+ "mk": "Helsinki-NLP/opus-mt-en-mk",
37
+ "nl": "facebook/mbart-large-50-many-to-many-mmt",
38
+ "no": "Confused404/eng-gmq-finetuned_v2-no",
39
+ "pl": "Helsinki-NLP/opus-mt-en-sla",
40
+ "pt": "facebook/mbart-large-50-many-to-many-mmt",
41
+ "ro": "facebook/mbart-large-50-many-to-many-mmt",
42
+ "sk": "Helsinki-NLP/opus-mt-en-sk",
43
+ "sl": "alirezamsh/small100",
44
+ "sq": "alirezamsh/small100",
45
+ "sv": "Helsinki-NLP/opus-mt-en-sv",
46
+ "tr": "Helsinki-NLP/opus-mt-tc-big-en-tr"
 
47
  }
48
 
49
  # Cache loaded models/tokenizers
 
53
  """
54
  Load & cache:
55
  - facebook/mbart-* via MBart50TokenizerFast & MBartForConditionalGeneration
56
+ - alirezamsh/small100 via SMALL100Tokenizer & AutoModelForSeq2SeqLM
57
  - all others via MarianTokenizer & MarianMTModel
58
  """
59
+ # Always reload small100 so we can pass a new tgt_lang
60
  if model_id not in MODEL_CACHE or model_id == "alirezamsh/small100":
61
  if model_id.startswith("facebook/mbart"):
62
  tokenizer = MBart50TokenizerFast.from_pretrained(model_id)
63
+ # ── MBART: always translate FROM English
64
+ tokenizer.src_lang = "en_XX"
65
+ model = MBartForConditionalGeneration.from_pretrained(model_id)
66
  elif model_id == "alirezamsh/small100":
67
  tokenizer = SMALL100Tokenizer.from_pretrained(model_id, tgt_lang=target_lang)
68
  model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
 
94
  chunks = chunk_text(text, safe_limit)
95
 
96
  tokenizer, model = load_model(model_id, target_lang)
97
+ is_mbart = model_id.startswith("facebook/mbart")
98
  full_translation = []
99
 
100
  for chunk in chunks:
101
+ # special-prefix hacks for nor/cnr
102
  if model_id == "Confused404/eng-gmq-finetuned_v2-no":
103
  chunk = f">>nob<< {chunk}"
104
  if model_id == "Helsinki-NLP/opus-mt-tc-base-en-sh":
105
  chunk = f">>cnr<< {chunk}"
106
+
107
+ # tokenize
108
+ inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True)
109
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
110
+
111
+ # generate
112
+ if is_mbart:
113
+ # build e.g. "de_DE", "es_XX", etc.
114
+ lang_code = f"{target_lang}_{target_lang.upper()}"
115
+ bos_id = tokenizer.lang_code_to_id[lang_code]
116
+ outputs = model.generate(
117
+ **inputs,
118
+ forced_bos_token_id=bos_id,
119
+ num_beams=5,
120
+ length_penalty=1.2,
121
+ early_stopping=True
122
+ )
123
+ else:
124
+ outputs = model.generate(
125
+ **inputs,
126
+ num_beams=5,
127
+ length_penalty=1.2,
128
+ early_stopping=True
129
+ )
130
+
131
  full_translation.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
132
 
133
  return {"translation": " ".join(full_translation)}
 
143
  def health():
144
  return {"status": "ok"}
145
 
 
146
  if __name__ == "__main__":
147
  import uvicorn
148
  uvicorn.run("app:app", host="0.0.0.0", port=7860)