ArissBandoss committed on
Commit
fb3e214
·
verified ·
1 Parent(s): 8cf4d3b

Update goai_helpers/goai_traduction.py

Browse files
Files changed (1) hide show
  1. goai_helpers/goai_traduction.py +41 -20
goai_helpers/goai_traduction.py CHANGED
@@ -12,7 +12,7 @@ login(token=auth_token)
12
 
13
 
14
  @spaces.GPU
15
- def goai_traduction(text, src_lang, tgt_lang):
16
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
 
18
  if src_lang == "fra_Latn" and tgt_lang == "mos_Latn":
@@ -22,38 +22,59 @@ def goai_traduction(text, src_lang, tgt_lang):
22
  else:
23
  model_id = "ArissBandoss/nllb-200-distilled-600M-finetuned-fr-to-mos-V4"
24
 
 
25
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token)
26
  model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=auth_token).to(device)
27
-
28
  print(f"Texte brut ({len(text)} caractères / {len(text.split())} mots):")
29
  print(text)
30
-
31
- print(tokenizer.model_max_length)
32
- print(model.model.encoder.embed_positions.weights.shape)
 
 
 
33
  # Configuration du tokenizer
34
  tokenizer.src_lang = src_lang
35
 
36
  # Tokenisation
37
  inputs = tokenizer(text, return_tensors="pt", truncation=False).to(device)
38
- print("Nombre de tokens :", inputs["input_ids"].shape[1])
 
 
 
 
 
39
 
40
  # ID du token de langue cible
41
  tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)
 
42
 
43
- # Génération avec les paramètres optimaux
44
- outputs = model.generate(
45
- **inputs,
46
- forced_bos_token_id=tgt_lang_id,
47
- max_new_tokens=1024,
48
- early_stopping=False,
49
- num_beams=5,
50
- no_repeat_ngram_size=0,
51
- length_penalty=1.0
52
- )
53
-
54
- # Décodage
55
- translation = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
56
-
 
 
 
 
 
 
 
 
 
 
 
57
  return translation
58
 
59
  def real_time_traduction(input_text, src_lang, tgt_lang):
 
12
 
13
 
14
  @spaces.GPU
15
+ def goai_traduction_debug(text, src_lang, tgt_lang):
16
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
 
18
  if src_lang == "fra_Latn" and tgt_lang == "mos_Latn":
 
22
  else:
23
  model_id = "ArissBandoss/nllb-200-distilled-600M-finetuned-fr-to-mos-V4"
24
 
25
+ print(f"Chargement du modèle: {model_id}")
26
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token)
27
  model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=auth_token).to(device)
28
+
29
  print(f"Texte brut ({len(text)} caractères / {len(text.split())} mots):")
30
  print(text)
31
+
32
+ print(f"Configuration du modèle:")
33
+ print(f"- tokenizer.model_max_length: {tokenizer.model_max_length}")
34
+ print(f"- Position embeddings shape: {model.model.encoder.embed_positions.weights.shape}")
35
+ print(f"- decoder.embed_positions shape: {model.model.decoder.embed_positions.weights.shape}")
36
+
37
  # Configuration du tokenizer
38
  tokenizer.src_lang = src_lang
39
 
40
  # Tokenisation
41
  inputs = tokenizer(text, return_tensors="pt", truncation=False).to(device)
42
+ input_ids = inputs["input_ids"][0]
43
+
44
+ print("Tokens d'entrée:")
45
+ print(f"- Nombre de tokens: {input_ids.shape[0]}")
46
+ print(f"- Premiers tokens: {input_ids[:10].tolist()}")
47
+ print(f"- Derniers tokens: {input_ids[-10:].tolist()}")
48
 
49
  # ID du token de langue cible
50
  tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)
51
+ print(f"Token ID de la langue cible ({tgt_lang}): {tgt_lang_id}")
52
 
53
+ for length_penalty in [1.0, 1.5, 2.0]:
54
+ for num_beams in [5, 8]:
55
+ print(f"\nTest avec length_penalty={length_penalty}, num_beams={num_beams}")
56
+
57
+
58
+ outputs = model.generate(
59
+ **inputs,
60
+ forced_bos_token_id=tgt_lang_id,
61
+ max_new_tokens=2048,
62
+ early_stopping=False,
63
+ num_beams=num_beams,
64
+ no_repeat_ngram_size=0,
65
+ length_penalty=length_penalty
66
+ )
67
+
68
+
69
+ translation = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
70
+
71
+ print(f"Traduction ({len(translation)} caractères / {len(translation.split())} mots):")
72
+ print(translation)
73
+ output_ids = outputs[0]
74
+ print(f"- Nombre de tokens générés: {output_ids.shape[0]}")
75
+ print(f"- Premiers tokens générés: {output_ids[:10].tolist()}")
76
+ print(f"- Derniers tokens générés: {output_ids[-10:].tolist()}")
77
+
78
  return translation
79
 
80
  def real_time_traduction(input_text, src_lang, tgt_lang):