# goai_helpers/goai_traduction.py
# (HF Space history: update by ArissBandoss, commit 9216991, verified)
import torch
import spaces
import re
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from huggingface_hub import login
import os
max_length = 512
auth_token = os.getenv('HF_SPACE_TOKEN')
login(token=auth_token)
def split_text_intelligently(text, max_chunk_length=80):
"""
Divise le texte en chunks en respectant les phrases complètes.
"""
# Séparation basée sur les phrases (utilise les points, points d'interrogation, etc.)
sentences = re.split(r'([.!?:])', text)
chunks = []
current_chunk = ""
for i in range(0, len(sentences), 2):
# Reconstruire la phrase avec sa ponctuation
if i + 1 < len(sentences):
sentence = sentences[i] + sentences[i+1]
else:
sentence = sentences[i]
# Si l'ajout de cette phrase dépasse la longueur maximale, on crée un nouveau chunk
if len(current_chunk) + len(sentence) > max_chunk_length and current_chunk:
chunks.append(current_chunk.strip())
current_chunk = sentence
else:
current_chunk += sentence
# Ajouter le dernier chunk s'il reste du texte
if current_chunk:
chunks.append(current_chunk.strip())
return chunks
@spaces.GPU
def goai_traduction(text, src_lang, tgt_lang, max_chunk_length=80):
# Si le texte est trop long, le diviser en chunks
if len(text) > max_chunk_length:
chunks = split_text_intelligently(text, max_chunk_length)
translations = []
for chunk in chunks:
translated_chunk = translate_chunk(chunk, src_lang, tgt_lang)
translations.append(translated_chunk)
return " ".join(translations)
else:
return translate_chunk(text, src_lang, tgt_lang)
def translate_chunk(text, src_lang, tgt_lang):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if src_lang == "mos_Latn" and tgt_lang == "fra_Latn":
model_id = "ArissBandoss/mos2fr-3B"
else:
#model_id = "ArissBandoss/nllb-200-distilled-600M-finetuned-fr-to-mos-V4"
model_id = "ArissBandoss/fr2mos-1B"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=auth_token).to(device)
# Configuration du tokenizer
tokenizer.src_lang = src_lang
# Tokenisation
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device)
# ID du token de langue cible
tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)
# Paramètres de génération optimisés pour éviter les répétitions
outputs = model.generate(
**inputs,
forced_bos_token_id=tgt_lang_id,
max_new_tokens=512,
early_stopping=True
)
# Décodage
translation = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
return translation
def real_time_traduction(input_text, src_lang, tgt_lang):
return goai_traduction(input_text, src_lang, tgt_lang)