Lorenzob commited on
Commit
e22f66a
Β·
verified Β·
1 Parent(s): 7854a2b

Fix model recognition error and improve error handling

Browse files
Files changed (1) hide show
  1. app.py +86 -18
app.py CHANGED
@@ -2,7 +2,8 @@
2
  import gradio as gr
3
  import torch
4
  import os
5
- from transformers import AutoProcessor, SpeechT5ForTextToSpeech, set_seed
 
6
  import numpy as np
7
  from scipy import signal
8
  import warnings
@@ -12,38 +13,98 @@ warnings.filterwarnings("ignore")
12
  set_seed(42)
13
 
14
  # Definizioni di variabili globali
15
- MODEL_REPO = "Lorenzob/aurora-1.6b-complete" # Repository aggiornata con il modello completo
 
16
  SAMPLE_RATE = 24000 # Frequenza di campionamento per il modello TTS
17
 
18
  # Cache per il modello e il processor (per evitare di ricaricarli ad ogni richiesta)
19
  model = None
20
  processor = None
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  def load_model_and_processor():
23
- """Carica il modello e il processor solo se non sono giΓ  stati caricati"""
24
  global model, processor
25
 
26
  if model is None or processor is None:
27
  try:
28
- print("πŸ“‚ Caricamento del modello Aurora-1.6b-complete...")
 
 
29
  processor = AutoProcessor.from_pretrained(MODEL_REPO)
 
 
30
  model = SpeechT5ForTextToSpeech.from_pretrained(
31
  MODEL_REPO,
32
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
33
  device_map="auto"
34
  )
35
- print("βœ… Modello caricato con successo!")
 
36
  except Exception as e:
37
- print(f"❌ Errore nel caricamento del modello: {e}")
38
- # Fallback al modello originale di Dia se il caricamento fallisce
39
- print("⚠️ Tentativo di fallback al modello Dia-1.6B...")
40
- processor = AutoProcessor.from_pretrained("nari-labs/Dia-1.6B")
41
- model = SpeechT5ForTextToSpeech.from_pretrained(
42
- "nari-labs/Dia-1.6B",
43
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
44
- device_map="auto"
45
- )
46
- print("βœ… Modello di fallback caricato con successo!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  return model, processor
49
 
@@ -63,10 +124,13 @@ def text_to_speech(text, language="it", speaker_id=0, speed=1.0, show_log=False)
63
  # Prepara gli input per il modello
64
  inputs = processor(
65
  text=text,
66
- language=language,
67
  return_tensors="pt"
68
  )
69
 
 
 
 
 
70
  # Sposta gli input sul dispositivo di calcolo
71
  for k, v in inputs.items():
72
  if hasattr(v, "to"):
@@ -82,7 +146,11 @@ def text_to_speech(text, language="it", speaker_id=0, speed=1.0, show_log=False)
82
 
83
  # Genera il speech
84
  with torch.no_grad():
85
- speech = model.generate(**inputs, **gen_params)
 
 
 
 
86
 
87
  # Converti il tensore in un array numpy
88
  speech_array = speech.cpu().numpy().squeeze()
@@ -116,7 +184,7 @@ with gr.Blocks(title="Aurora-1.6b TTS Demo", theme=gr.themes.Soft()) as demo:
116
  gr.Markdown("""
117
  # πŸŽ™οΈ Aurora-1.6b Text-to-Speech Demo
118
 
119
- Questa demo utilizza il modello Aurora-1.6b-complete per la sintesi vocale (TTS), un modello fine-tuned basato su Dia-1.6B con pesi completi.
120
 
121
  Il modello supporta italiano, inglese, spagnolo, francese e tedesco, ma Γ¨ stato ottimizzato per l'italiano.
122
  """)
 
2
  import gradio as gr
3
  import torch
4
  import os
5
+ import json
6
+ from transformers import AutoProcessor, AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM, SpeechT5ForTextToSpeech, set_seed
7
  import numpy as np
8
  from scipy import signal
9
  import warnings
 
13
  set_seed(42)
14
 
15
  # Definizioni di variabili globali
16
+ MODEL_REPO = "Lorenzob/aurora-1.6b-complete" # Repository principale
17
+ FALLBACK_REPO = "nari-labs/Dia-1.6B" # Repository di fallback
18
  SAMPLE_RATE = 24000 # Frequenza di campionamento per il modello TTS
19
 
20
  # Cache per il modello e il processor (per evitare di ricaricarli ad ogni richiesta)
21
  model = None
22
  processor = None
23
 
24
+ def fix_model_config(model_path):
25
+ """Aggiunge il model_type alla configurazione se necessario"""
26
+ try:
27
+ config_path = os.path.join(model_path, "config.json")
28
+ if os.path.exists(config_path):
29
+ with open(config_path, "r") as f:
30
+ config = json.load(f)
31
+
32
+ # Aggiungi model_type se mancante
33
+ if "model_type" not in config:
34
+ config["model_type"] = "speecht5"
35
+ print(f"Aggiunto model_type 'speecht5' alla configurazione")
36
+
37
+ # Aggiungi architectures se mancante
38
+ if "architectures" not in config:
39
+ config["architectures"] = ["SpeechT5ForTextToSpeech"]
40
+ print(f"Aggiunto architectures al config")
41
+
42
+ # Salva la configurazione aggiornata
43
+ with open(config_path, "w") as f:
44
+ json.dump(config, f, indent=2)
45
+
46
+ print(f"Configurazione aggiornata e salvata in {config_path}")
47
+ return True
48
+ else:
49
+ print(f"File di configurazione non trovato in {model_path}")
50
+ return False
51
+ except Exception as e:
52
+ print(f"Errore nella modifica del config: {e}")
53
+ return False
54
+
55
  def load_model_and_processor():
56
+ """Carica il modello e il processor con gestione degli errori avanzata"""
57
  global model, processor
58
 
59
  if model is None or processor is None:
60
  try:
61
+ print(f"πŸ“‚ Tentativo di caricamento del modello da {MODEL_REPO}...")
62
+
63
+ # Prova a caricare il processor
64
  processor = AutoProcessor.from_pretrained(MODEL_REPO)
65
+
66
+ # Carica il modello specificando esplicitamente la classe
67
  model = SpeechT5ForTextToSpeech.from_pretrained(
68
  MODEL_REPO,
69
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
70
  device_map="auto"
71
  )
72
+ print("βœ… Modello principale caricato con successo!")
73
+
74
  except Exception as e:
75
+ print(f"❌ Errore nel caricamento del modello principale: {e}")
76
+
77
+ # Prova con il modello di fallback
78
+ try:
79
+ print(f"⚠️ Tentativo di caricamento del modello di fallback da {FALLBACK_REPO}...")
80
+
81
+ # Prova a caricare il processor di fallback
82
+ processor = AutoProcessor.from_pretrained(FALLBACK_REPO)
83
+
84
+ # Carica il modello di fallback specificando esplicitamente la classe
85
+ model = SpeechT5ForTextToSpeech.from_pretrained(
86
+ FALLBACK_REPO,
87
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
88
+ device_map="auto"
89
+ )
90
+ print("βœ… Modello di fallback caricato con successo!")
91
+
92
+ except Exception as e2:
93
+ print(f"❌ Errore anche nel caricamento del modello di fallback: {e2}")
94
+
95
+ # Se entrambi i tentativi falliscono, prova con un modello TTS generico ben supportato
96
+ try:
97
+ print("πŸ”„ Tentativo con un modello TTS generico (microsoft/speecht5_tts)...")
98
+ processor = AutoProcessor.from_pretrained("microsoft/speecht5_tts")
99
+ model = SpeechT5ForTextToSpeech.from_pretrained(
100
+ "microsoft/speecht5_tts",
101
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
102
+ device_map="auto"
103
+ )
104
+ print("βœ… Modello generico caricato con successo!")
105
+ except Exception as e3:
106
+ print(f"❌ Tutti i tentativi di caricamento sono falliti: {e3}")
107
+ raise RuntimeError("Impossibile caricare alcun modello TTS")
108
 
109
  return model, processor
110
 
 
124
  # Prepara gli input per il modello
125
  inputs = processor(
126
  text=text,
 
127
  return_tensors="pt"
128
  )
129
 
130
+ # Aggiungi il parametro di lingua se supportato dal processor
131
+ if "language" in processor.model_input_names:
132
+ inputs["language"] = language
133
+
134
  # Sposta gli input sul dispositivo di calcolo
135
  for k, v in inputs.items():
136
  if hasattr(v, "to"):
 
146
 
147
  # Genera il speech
148
  with torch.no_grad():
149
+ # Passa speaker_embeddings se disponibili/necessari
150
+ if hasattr(model, "generate_speech"):
151
+ speech = model.generate_speech(**inputs, **gen_params)
152
+ else:
153
+ speech = model.generate(**inputs, **gen_params)
154
 
155
  # Converti il tensore in un array numpy
156
  speech_array = speech.cpu().numpy().squeeze()
 
184
  gr.Markdown("""
185
  # πŸŽ™οΈ Aurora-1.6b Text-to-Speech Demo
186
 
187
+ Questa demo utilizza il modello Aurora-1.6b-complete per la sintesi vocale (TTS), un modello fine-tuned basato su Dia-1.6B.
188
 
189
  Il modello supporta italiano, inglese, spagnolo, francese e tedesco, ma Γ¨ stato ottimizzato per l'italiano.
190
  """)