|
import os |
|
import time |
|
import torch |
|
import torchaudio |
|
import spaces |
|
import tempfile |
|
from tqdm import tqdm |
|
from typing import Optional, Tuple |
|
from huggingface_hub import hf_hub_download, hf_hub_url, login |
|
|
|
from TTS.tts.configs.xtts_config import XttsConfig |
|
from TTS.tts.models.xtts import Xtts |
|
|
|
from goai_helpers.utils import download_file, diviser_phrases_moore |
|
from goai_helpers.goai_traduction import goai_traduction |
|
|
|
|
|
auth_token = os.getenv('HF_SPACE_TOKEN') |
|
login(token=auth_token) |
|
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
|
|
class MooreTTS: |
|
""" |
|
Classe Mooré Text-to-Speech (TTS) qui initialise et utilise un modèle TTS. |
|
Attributs : |
|
language_code (str) : code ISO de la langue pour le mooré. |
|
checkpoint_repo_or_dir (str) : URL ou chemin local vers le répertoire du point de contrôle du modèle. |
|
local_dir (str) : Le répertoire pour stocker les points de contrôle téléchargés. |
|
paths (dict) : Un dictionnaire des chemins vers les composants du modèle. |
|
config (XttsConfig) : Objet de configuration pour le modèle TTS. |
|
model (Xtts) : L'instance du modèle TTS. |
|
""" |
|
|
|
def __init__(self, checkpoint_repo_or_dir: str, local_dir: Optional[str] = None): |
|
""" |
|
Initialise l'instance MooreTTS. |
|
Args : |
|
checkpoint_repo_or_dir : Une chaîne représentant soit un dépôt Hugging Face, |
|
soit un répertoire local où le point de contrôle du modèle TTS est situé. |
|
local_dir : Une chaîne optionnelle représentant un chemin de répertoire local où les points de contrôle du modèle |
|
seront téléchargés. Si non spécifié, un répertoire local par défaut est utilisé |
|
basé sur `checkpoint_repo_or_dir`. |
|
Le processus d'initialisation implique la configuration de répertoires locaux pour les composants du modèle, |
|
l'assurance que le point de contrôle du modèle est disponible, et le chargement de la configuration et du tokenizer du modèle. |
|
""" |
|
|
|
|
|
self.language_code = 'mos' |
|
|
|
|
|
self.checkpoint_repo_or_dir = checkpoint_repo_or_dir |
|
|
|
|
|
self.local_dir = local_dir if local_dir else self.default_local_dir(checkpoint_repo_or_dir) |
|
|
|
|
|
self.paths = self.init_paths(self.local_dir) |
|
|
|
|
|
self.ensure_checkpoint_is_downloaded() |
|
|
|
|
|
self.config = XttsConfig() |
|
self.config.load_json(self.paths['config.json']) |
|
|
|
|
|
self.model = Xtts.init_from_config(self.config) |
|
|
|
|
|
|
|
self.model.load_checkpoint( |
|
self.config, |
|
checkpoint_path=self.local_dir+ "/model_compressed.pth" , |
|
vocab_path=self.paths['vocab.json'], |
|
use_deepspeed=False |
|
) |
|
|
|
if torch.cuda.is_available(): |
|
self.model.cuda() |
|
|
|
print("Model loaded successfully!") |
|
|
|
|
|
def ensure_checkpoint_is_downloaded(self): |
|
""" |
|
S'assure que le point de contrôle du modèle est téléchargé et disponible localement. |
|
""" |
|
if os.path.exists(self.checkpoint_repo_or_dir): |
|
return |
|
|
|
os.makedirs(self.local_dir, exist_ok=True) |
|
print("Téléchargement du point de contrôle depuis le hub...") |
|
|
|
for filename, filepath in self.paths.items(): |
|
if os.path.exists(filepath): |
|
print(f"Fichier {filepath} déjà existant. Passé...") |
|
continue |
|
|
|
file_url = hf_hub_url(repo_id=self.checkpoint_repo_or_dir, filename=filename) |
|
print(f"Téléchargement de {filename} depuis {file_url}") |
|
try: |
|
download_file(file_url, filepath) |
|
except Exception as e: |
|
print(f"Téléchargement de {filename} échoué: {e}") |
|
|
|
|
|
print("Point de contrôle téléchargé avec succès !") |
|
|
|
|
|
def default_local_dir(self, checkpoint_repo_or_dir: str) -> str: |
|
""" |
|
Génère un chemin de répertoire local par défaut pour stocker le point de contrôle du modèle. |
|
Args : |
|
checkpoint_repo_or_dir : Le dépôt ou chemin de répertoire original du point de contrôle. |
|
Returns : |
|
Le chemin de répertoire local par défaut. |
|
""" |
|
if os.path.exists(checkpoint_repo_or_dir): |
|
return checkpoint_repo_or_dir |
|
|
|
model_path = f"models--{checkpoint_repo_or_dir.replace('/', '--')}" |
|
local_dir = os.path.join(os.path.expanduser('~'), "mooreTTS", model_path) |
|
return local_dir.lower() |
|
|
|
@staticmethod |
|
def init_paths(local_dir: str) -> dict: |
|
""" |
|
Initialise les chemins vers les divers composants du modèle basés sur le répertoire local. |
|
Args : |
|
local_dir : Le répertoire local où les composants du modèle sont stockés. |
|
Returns : |
|
Un dictionnaire avec des clés comme noms des composants et des valeurs comme chemins des fichiers. |
|
""" |
|
components = ['model_compressed.pth', 'config.json', 'vocab.json', 'dvae.pth', 'mel_stats.pth'] |
|
return {name: os.path.join(local_dir, name) for name in components} |
|
|
|
def text_to_speech( |
|
self, |
|
tts_text: str, |
|
speaker_reference_wav_path: Optional[str] = None, |
|
temperature: Optional[float] = 0.1 |
|
) -> Tuple[int, torch.Tensor]: |
|
""" |
|
Convertit un texte en audio de synthèse vocale. |
|
Args : |
|
text : Le texte d'entrée à convertir en audio. |
|
speaker_reference_wav_path : Un chemin vers un fichier WAV de référence pour l'orateur. |
|
temperature : Le paramètre de température pour l'échantillonnage. |
|
enable_text_splitting : Indicateur pour activer ou désactiver la découpe du texte. |
|
Returns : |
|
Un tuple contenant le taux d'échantillonnage et le tenseur audio généré. |
|
""" |
|
if speaker_reference_wav_path is None: |
|
speaker_reference_wav_path = "./audios/ref1_male_17.wav" |
|
print("Utilisation du fichier de référence par défaut ./audios/ref1_male_17.wav") |
|
|
|
print("Calcul des latents de conditionnement de l'orateur...") |
|
gpt_cond_latent, speaker_embedding = self.model.get_conditioning_latents( |
|
audio_path=[speaker_reference_wav_path], |
|
gpt_cond_len=self.model.config.gpt_cond_len, |
|
max_ref_length=self.model.config.max_ref_len, |
|
sound_norm_refs=self.model.config.sound_norm_refs, |
|
) |
|
|
|
tts_texts = diviser_phrases_moore(tts_text) |
|
|
|
print("Début de l'inférence...") |
|
start_time = time.time() |
|
|
|
wav_chunks = [] |
|
for text in tqdm(tts_texts): |
|
wav_chunk = self.model.inference( |
|
text=text, |
|
language=self.language_code, |
|
gpt_cond_latent=gpt_cond_latent, |
|
speaker_embedding=speaker_embedding, |
|
temperature=0.1, |
|
length_penalty=1.0, |
|
repetition_penalty=10.0, |
|
top_k=10, |
|
top_p=0.3, |
|
) |
|
wav_chunks.append(torch.tensor(wav_chunk["wav"])) |
|
|
|
end_time = time.time() |
|
|
|
audio = torch.cat(wav_chunks, dim=0).unsqueeze(0).cpu() |
|
sampling_rate = torch.tensor(self.config.model_args.output_sample_rate).cpu().item() |
|
|
|
print(f"Voix générée en {end_time - start_time:.2f} secondes.") |
|
|
|
return sampling_rate, audio |
|
|
|
|
|
|
|
@spaces.GPU |
|
def text_to_speech(tts, text, reference_speaker: str, reference_audio: Optional[Tuple] = None): |
|
if reference_audio is not None: |
|
ref_sr, ref_audio = reference_audio |
|
ref_audio = torch.from_numpy(ref_audio) |
|
|
|
|
|
if ref_audio.ndim == 1: |
|
ref_audio = ref_audio.unsqueeze(0) |
|
|
|
|
|
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp: |
|
torchaudio.save(tmp.name, ref_audio, ref_sr) |
|
tmp_path = tmp.name |
|
|
|
|
|
sr, audio = tts.text_to_speech(text, speaker_reference_wav_path=tmp_path) |
|
|
|
|
|
os.unlink(tmp_path) |
|
else: |
|
|
|
sr, audio = tts.text_to_speech(text, speaker_reference_wav_path=reference_speaker) |
|
|
|
audio = audio.mean(dim=0) |
|
return audio, sr |
|
|
|
|
|
|
|
@spaces.GPU |
|
def goai_tts2( |
|
text, |
|
reference_speaker, |
|
reference_audio=None, |
|
solver="Midpoint", |
|
nfe=128, |
|
prior_temp=0.01, |
|
denoise_before_enhancement=False |
|
): |
|
|
|
tts_model = "ArissBandoss/coqui-tts-moore-V1" |
|
tts = MooreTTS(tts_model) |
|
|
|
reference_speaker = os.path.join("./exples_voix", reference_speaker) |
|
|
|
|
|
|
|
if reference_audio is not None: |
|
audio_array, sampling_rate = text_to_speech(tts, text, reference_speaker, reference_audio) |
|
else: |
|
audio_array, sampling_rate = text_to_speech(tts, text, reference_speaker=reference_speaker) |
|
|
|
yield text, (sampling_rate, audio_array.numpy()), None, None |
|
|
|
|
|
|
|
yield (sampling_rate, audio_array.numpy()) |
|
|
|
|
|
|
|
|
|
@spaces.GPU(duration=120) |
|
def goai_ttt_tts( |
|
text, |
|
reference_speaker, |
|
reference_audio=None, |
|
solver="Midpoint", |
|
nfe=128, |
|
prior_temp=0.01, |
|
denoise_before_enhancement=False |
|
): |
|
|
|
|
|
mos_text = goai_traduction( |
|
text, |
|
src_lang="fra_Latn", |
|
tgt_lang="mos_Latn" |
|
) |
|
yield mos_text, None, None, None |
|
|
|
|
|
reference_speaker = os.path.join("./exples_voix", reference_speaker) |
|
tts_model = "ArissBandoss/coqui-tts-moore-V1" |
|
tts = MooreTTS(tts_model) |
|
|
|
|
|
if reference_audio is not None: |
|
audio_array, sampling_rate = text_to_speech(tts, mos_text, reference_speaker, reference_audio) |
|
else: |
|
audio_array, sampling_rate = text_to_speech(tts, mos_text, reference_speaker=reference_speaker) |
|
|
|
yield mos_text, (sampling_rate, audio_array.numpy()), None, None |
|
|
|
|
|
|
|
yield mos_text, (sampling_rate, audio_array.numpy()) |