ArissBandoss commited on
Commit
54108c5
·
verified ·
1 Parent(s): 7b6884e

Update goai_helpers/goai_traduction.py

Browse files
Files changed (1) hide show
  1. goai_helpers/goai_traduction.py +56 -19
goai_helpers/goai_traduction.py CHANGED
@@ -1,18 +1,59 @@
1
  import torch
2
  import spaces
3
- from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
4
- from peft import PeftModel, PeftConfig
5
- import os
6
- import unicodedata
7
  from huggingface_hub import login
 
8
 
9
  max_length = 512
10
  auth_token = os.getenv('HF_SPACE_TOKEN')
11
  login(token=auth_token)
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  @spaces.GPU
15
- def goai_traduction(text, src_lang, tgt_lang):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
 
18
  if src_lang == "mos_Latn" and tgt_lang == "fra_Latn":
@@ -27,25 +68,21 @@ def goai_traduction(text, src_lang, tgt_lang):
27
  tokenizer.src_lang = src_lang
28
 
29
  # Tokenisation
30
- inputs = tokenizer(text, return_tensors="pt", truncation=False).to(device)
31
- input_length = inputs["input_ids"].shape[1]
32
-
33
 
34
  # ID du token de langue cible
35
  tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)
36
 
37
- # ID du token EOS
38
- eos_token_id = tokenizer.eos_token_id
39
-
40
- # Bloquer complètement le token EOS jusqu'à un certain point
41
  outputs = model.generate(
42
- **inputs,
43
- forced_bos_token_id=tgt_lang_id,
44
- max_new_tokens=1024,
45
- early_stopping=False,
46
- num_beams=5,
47
- no_repeat_ngram_size=0,
48
- length_penalty=1.0
 
49
  )
50
 
51
  # Décodage
 
1
  import torch
2
  import spaces
3
+ import re
4
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
 
5
  from huggingface_hub import login
6
+ import os
7
 
8
  max_length = 512
9
  auth_token = os.getenv('HF_SPACE_TOKEN')
10
  login(token=auth_token)
11
 
12
+ def split_text_intelligently(text, max_chunk_length=100):
13
+ """
14
+ Divise le texte en chunks en respectant les phrases complètes.
15
+ """
16
+ # Séparation basée sur les phrases (utilise les points, points d'interrogation, etc.)
17
+ sentences = re.split(r'([.!?])', text)
18
+ chunks = []
19
+ current_chunk = ""
20
+
21
+ for i in range(0, len(sentences), 2):
22
+ # Reconstruire la phrase avec sa ponctuation
23
+ if i + 1 < len(sentences):
24
+ sentence = sentences[i] + sentences[i+1]
25
+ else:
26
+ sentence = sentences[i]
27
+
28
+ # Si l'ajout de cette phrase dépasse la longueur maximale, on crée un nouveau chunk
29
+ if len(current_chunk) + len(sentence) > max_chunk_length and current_chunk:
30
+ chunks.append(current_chunk.strip())
31
+ current_chunk = sentence
32
+ else:
33
+ current_chunk += sentence
34
+
35
+ # Ajouter le dernier chunk s'il reste du texte
36
+ if current_chunk:
37
+ chunks.append(current_chunk.strip())
38
+
39
+ return chunks
40
 
41
  @spaces.GPU
42
+ def goai_traduction(text, src_lang, tgt_lang, max_chunk_length=100):
43
+ # Si le texte est trop long, le diviser en chunks
44
+ if len(text) > max_chunk_length:
45
+ chunks = split_text_intelligently(text, max_chunk_length)
46
+ translations = []
47
+
48
+ for chunk in chunks:
49
+ translated_chunk = translate_chunk(chunk, src_lang, tgt_lang)
50
+ translations.append(translated_chunk)
51
+
52
+ return " ".join(translations)
53
+ else:
54
+ return translate_chunk(text, src_lang, tgt_lang)
55
+
56
+ def translate_chunk(text, src_lang, tgt_lang):
57
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
58
 
59
  if src_lang == "mos_Latn" and tgt_lang == "fra_Latn":
 
68
  tokenizer.src_lang = src_lang
69
 
70
  # Tokenisation
71
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device)
 
 
72
 
73
  # ID du token de langue cible
74
  tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)
75
 
76
+ # Paramètres de génération optimisés pour éviter les répétitions
 
 
 
77
  outputs = model.generate(
78
+ **inputs,
79
+ forced_bos_token_id=tgt_lang_id,
80
+ max_new_tokens=512,
81
+ num_beams=5,
82
+ no_repeat_ngram_size=4,
83
+ repetition_penalty=2.0,
84
+ length_penalty=1.0,
85
+ early_stopping=True
86
  )
87
 
88
  # Décodage