ArissBandoss committed (verified)
Commit 7f3d8a9 · 1 parent: dfb286c

Update goai_helpers/goai_traduction.py

Files changed (1)
  1. goai_helpers/goai_traduction.py +18 -14
goai_helpers/goai_traduction.py CHANGED
@@ -18,32 +18,36 @@ def goai_traduction(text, src_lang, tgt_lang):
     if src_lang == "fra_Latn" and tgt_lang == "mos_Latn":
         model_id = "ArissBandoss/nllb-200-distilled-600M-finetuned-fr-to-mos-V4"
     elif src_lang == "mos_Latn" and tgt_lang == "fra_Latn":
-        model_id = "ArissBandoss/mos2fr-3B-1200"
+        model_id = "ArissBandoss/mos2fr-5B-800-fixed"  # Repaired model
     else:
         model_id = "ArissBandoss/nllb-200-distilled-600M-finetuned-fr-to-mos-V4"
-
+
     tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token)
     model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=auth_token).to(device)
-
+
+    # Tokenizer configuration
     tokenizer.src_lang = src_lang
+
+    # Tokenization
     inputs = tokenizer(text, return_tensors="pt").to(device)
-
-    # Add the target-language code
+
+    # Target-language token ID
     tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)
-
-    # Controlled generation
+
+    # Generation with the tuned parameters
     outputs = model.generate(
         **inputs,
         forced_bos_token_id=tgt_lang_id,
-        eos_token_id=tokenizer.eos_token_id,  # Make sure the model can stop
-        max_length=512,  # Test with 256, then increase gradually
-        do_sample=False,
-        early_stopping=True
+        max_new_tokens=1024,
+        num_beams=5,
+        early_stopping=False,
+        no_repeat_ngram_size=0,
+        length_penalty=1.0
     )
-
+
+    # Decoding
     translation = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
-    print("ici translation")
-    print(translation)
+
     return translation
 
 def real_time_traduction(input_text, src_lang, tgt_lang):
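
For readers landing on this commit, a minimal standalone sketch of the translation path as it now stands. It mirrors the function body above but inlines the surrounding setup the diff does not show: device and auth_token are module-level names in the original file, so their definitions here are assumptions, and the example sentence and hard-coded language pair are illustrative only. A Hugging Face access token may be needed if the checkpoints are gated.

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
auth_token = None  # assumption: set an HF access token here if the repo is gated

model_id = "ArissBandoss/nllb-200-distilled-600M-finetuned-fr-to-mos-V4"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=auth_token).to(device)

# NLLB models take the source language on the tokenizer and the target
# language as a forced first decoder token (language codes are special tokens).
tokenizer.src_lang = "fra_Latn"
inputs = tokenizer("Bonjour, comment allez-vous ?", return_tensors="pt").to(device)
tgt_lang_id = tokenizer.convert_tokens_to_ids("mos_Latn")

outputs = model.generate(
    **inputs,
    forced_bos_token_id=tgt_lang_id,
    max_new_tokens=1024,
    num_beams=5,
    early_stopping=False,
    no_repeat_ngram_size=0,
    length_penalty=1.0,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])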
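
As a side note on the parameter swap itself, a sketch using the transformers GenerationConfig API; the before/after values are copied from the diff, while the comments are one reading of the generation docs rather than the author's stated rationale.

from transformers import GenerationConfig

# Before this commit: greedy decoding (num_beams defaults to 1, so
# early_stopping=True had no effect) with the decoder output capped
# at 512 tokens by max_length.
old_cfg = GenerationConfig(
    max_length=512,
    do_sample=False,
    early_stopping=True,
)

# After this commit: 5-beam search with roughly double the output budget;
# for an encoder-decoder model, max_new_tokens and max_length differ only
# by the decoder start tokens. no_repeat_ngram_size=0 and length_penalty=1.0
# are the library defaults (no n-gram blocking, standard length normalization).
new_cfg = GenerationConfig(
    max_new_tokens=1024,
    num_beams=5,
    early_stopping=False,
    no_repeat_ngram_size=0,
    length_penalty=1.0,
)

# Usage with the sketch above:
# outputs = model.generate(**inputs, forced_bos_token_id=tgt_lang_id,
#                          generation_config=new_cfg)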