File size: 3,091 Bytes
78d1101
 
54108c5
 
78d1101
54108c5
1c7cbff
78d1101
 
 
 
6b82660
54108c5
585f2eb
54108c5
585f2eb
242f488
54108c5
 
 
 
 
 
 
 
 
 
585f2eb
 
54108c5
 
 
 
 
 
 
 
 
 
78d1101
 
242f488
585f2eb
 
 
54108c5
 
 
 
 
 
 
 
 
 
 
78d1101
261a5aa
4a24c4f
8b20d96
78d1101
9216991
 
7f3d8a9
dfb286c
 
fb3e214
7f3d8a9
261a5aa
7f3d8a9
 
54108c5
7f3d8a9
 
261a5aa
4a24c4f
585f2eb
c220549
54108c5
 
 
 
c220549
4a24c4f
 
7cc5f39
4a24c4f
261a5aa
20a6b7b
 
585f2eb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
import re
from functools import lru_cache

import spaces
import torch
from huggingface_hub import login
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

max_length = 512
auth_token = os.getenv('HF_SPACE_TOKEN')
login(token=auth_token)

def split_text_intelligently(text, max_chunk_length=80):
    """
    Divise le texte en chunks en respectant les phrases complètes.
    """
    # Séparation basée sur les phrases (utilise les points, points d'interrogation, etc.)
    sentences = re.split(r'([.!?:])', text)
    chunks = []
    current_chunk = ""
    
    for i in range(0, len(sentences), 2):
        # Reconstruire la phrase avec sa ponctuation
        if i + 1 < len(sentences):
            sentence = sentences[i] + sentences[i+1]
        else:
            sentence = sentences[i]
        
        # Si l'ajout de cette phrase dépasse la longueur maximale, on crée un nouveau chunk
        if len(current_chunk) + len(sentence) > max_chunk_length and current_chunk:
            chunks.append(current_chunk.strip())
            current_chunk = sentence
        else:
            current_chunk += sentence
    
    # Ajouter le dernier chunk s'il reste du texte
    if current_chunk:
        chunks.append(current_chunk.strip())
    
    return chunks

@spaces.GPU
def goai_traduction(text, src_lang, tgt_lang, max_chunk_length=80):
    # Si le texte est trop long, le diviser en chunks
    if len(text) > max_chunk_length:
        chunks = split_text_intelligently(text, max_chunk_length)
        translations = []
        
        for chunk in chunks:
            translated_chunk = translate_chunk(chunk, src_lang, tgt_lang)
            translations.append(translated_chunk)
        
        return " ".join(translations)
    else:
        return translate_chunk(text, src_lang, tgt_lang)

def translate_chunk(text, src_lang, tgt_lang):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    if src_lang == "mos_Latn" and tgt_lang == "fra_Latn":
        model_id = "ArissBandoss/mos2fr-3B"
    else:
        #model_id = "ArissBandoss/nllb-200-distilled-600M-finetuned-fr-to-mos-V4"
        model_id = "ArissBandoss/fr2mos-1B"
    
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=auth_token).to(device)
    
    # Configuration du tokenizer
    tokenizer.src_lang = src_lang
    
    # Tokenisation
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device)
    
    # ID du token de langue cible
    tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)
    
    # Paramètres de génération optimisés pour éviter les répétitions
    outputs = model.generate(
        **inputs,
        forced_bos_token_id=tgt_lang_id,
        max_new_tokens=512,
        early_stopping=True
    )
    
    # Décodage
    translation = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    
    return translation

def real_time_traduction(input_text, src_lang, tgt_lang):
    return goai_traduction(input_text, src_lang, tgt_lang)