
TDN-M committed
Commit b085276 (verified)
1 parent: 9de5092

Update app.py

Files changed (1):
  1. app.py +30 -139
app.py CHANGED
@@ -14,14 +14,12 @@ from TTS.tts.configs.xtts_config import XttsConfig
 from TTS.tts.models.xtts import Xtts
 from vinorm import TTSnorm
 
-# download for mecab
-# os.system("python -m unidic download")
-
+# Initialize Hugging Face API
 HF_TOKEN = os.environ.get("HF_TOKEN")
 api = HfApi(token=HF_TOKEN)
 
-# This will trigger downloading model
-print("Downloading if not downloaded viXTTS")
+# Download model files if not already downloaded
+print("Downloading viXTTS model files if not already present...")
 checkpoint_dir = "model/"
 repo_id = "capleaf/viXTTS"
 use_deepspeed = False
@@ -42,6 +40,7 @@ if not all(file in files_in_dir for file in required_files):
         local_dir=checkpoint_dir,
     )
 
+# Load model configuration and initialize model
 xtts_config = os.path.join(checkpoint_dir, "config.json")
 config = XttsConfig()
 config.load_json(xtts_config)
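The hunk header above shows the guard that downloads the model only when required files are missing from checkpoint_dir. For context, a minimal sketch of that pattern with huggingface_hub follows; the required_files list is an assumption, since the diff only shows the local_dir argument of the download call.

# Minimal sketch of the download guard named in the hunk header above.
# The required_files list is an assumption; only local_dir appears in this diff.
import os

from huggingface_hub import snapshot_download

checkpoint_dir = "model/"
repo_id = "capleaf/viXTTS"
required_files = ["model.pth", "config.json", "vocab.json"]  # assumed

files_in_dir = os.listdir(checkpoint_dir) if os.path.isdir(checkpoint_dir) else []
if not all(file in files_in_dir for file in required_files):
    snapshot_download(repo_id=repo_id, repo_type="model", local_dir=checkpoint_dir)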
@@ -52,8 +51,9 @@ MODEL.load_checkpoint(
 if torch.cuda.is_available():
     MODEL.cuda()
 
+# Supported languages
 supported_languages = config.languages
-if not "vi" in supported_languages:
+if "vi" not in supported_languages:
     supported_languages.append("vi")
 
 
@@ -74,7 +74,6 @@ def normalize_vietnamese_text(text):
 
 
 def calculate_keep_len(text, lang):
-    """Simple hack for short sentences"""
     if lang in ["ja", "zh-cn"]:
         return -1
 
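Only the ja/zh-cn early return and the trailing return -1 of calculate_keep_len are visible in this diff. For context, a plausible reconstruction of the full helper is sketched below; the numeric thresholds are assumptions, not taken from this commit.

# Hypothetical reconstruction of calculate_keep_len; only the ja/zh-cn branch
# and the final return -1 appear in this diff, and the thresholds below are
# assumptions.
def calculate_keep_len(text, lang):
    """Simple hack for short sentences: cap the number of output samples."""
    if lang in ["ja", "zh-cn"]:
        return -1  # languages without space-separated words: keep everything

    word_count = len(text.split())
    num_punct = (
        text.count(".") + text.count("!") + text.count("?") + text.count(",")
    )

    if word_count < 5:
        return 15000 * word_count + 2000 * num_punct  # samples at 24 kHz
    elif word_count == 5:
        return 13000 * word_count + 2000 * num_punct
    return -1  # keep the full waveform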
@@ -88,52 +87,39 @@ def calculate_keep_len(text, lang):
     return -1
 
 
-def predict(
-    prompt,
-    language,
-    audio_file_pth,
-    normalize_text=True,
-):
+def predict(prompt, language, audio_file_pth, normalize_text=True):
     if language not in supported_languages:
         metrics_text = gr.Warning(
-            f"Language you put {language} in is not in is not in our Supported Languages, please choose from dropdown"
+            f"Language {language} is not supported. Please choose from the dropdown."
         )
-
-        return (None, metrics_text)
-
-    speaker_wav = audio_file_pth
+        return None, metrics_text
 
     if len(prompt) < 2:
-        metrics_text = gr.Warning("Please give a longer prompt text")
-        return (None, metrics_text)
+        metrics_text = gr.Warning("Please provide a longer prompt text.")
+        return None, metrics_text
+
     try:
         metrics_text = ""
         t_latent = time.time()
 
         try:
-            (
-                gpt_cond_latent,
-                speaker_embedding,
-            ) = MODEL.get_conditioning_latents(
-                audio_path=speaker_wav,
+            gpt_cond_latent, speaker_embedding = MODEL.get_conditioning_latents(
+                audio_path=audio_file_pth,
                 gpt_cond_len=30,
                 gpt_cond_chunk_len=4,
                 max_ref_length=60,
             )
-
         except Exception as e:
-            print("Speaker encoding error", str(e))
-            metrics_text = gr.Warning(
-                "It appears something wrong with reference, did you unmute your microphone?"
-            )
-            return (None, metrics_text)
+            print("Speaker encoding error:", str(e))
+            metrics_text = gr.Warning("Error with reference audio.")
+            return None, metrics_text
 
-        prompt = re.sub("([^\x00-\x7F]|\w)(\.|\。|\?)", r"\1 \2\2", prompt)
+        prompt = re.sub("([^\x00-\x7F]|\w)(\.|\。|\?)", r"\1 \2", prompt)
 
         if normalize_text and language == "vi":
             prompt = normalize_vietnamese_text(prompt)
 
-        print("I: Generating new audio...")
+        print("Generating new audio...")
         t0 = time.time()
         out = MODEL.inference(
             prompt,
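The commit also tones down the sentence-punctuation regex: the replacement changes from r"\1 \2\2", which doubled the punctuation mark, to r"\1 \2", which only inserts a space before it. A small runnable check of the new behaviour:

import re

# New behaviour after this commit: add a space before sentence-ending
# punctuation (., 。, ?) that follows a word character or a non-ASCII
# character, without doubling the mark as the old r"\1 \2\2" replacement did.
pattern = r"([^\x00-\x7F]|\w)(\.|\。|\?)"
print(re.sub(pattern, r"\1 \2", "Xin chào. Bạn khỏe không?"))
# -> "Xin chào . Bạn khỏe không ?"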
@@ -145,100 +131,30 @@ def predict(
             enable_text_splitting=True,
         )
         inference_time = time.time() - t0
-        print(f"I: Time to generate audio: {round(inference_time*1000)} milliseconds")
-        metrics_text += (
-            f"Time to generate audio: {round(inference_time*1000)} milliseconds\n"
-        )
+        metrics_text += f"Time to generate audio: {round(inference_time * 1000)} ms\n"
         real_time_factor = (time.time() - t0) / out["wav"].shape[-1] * 24000
-        print(f"Real-time factor (RTF): {real_time_factor}")
         metrics_text += f"Real-time factor (RTF): {real_time_factor:.2f}\n"
 
-        # Temporary hack for short sentences
         keep_len = calculate_keep_len(prompt, language)
         out["wav"] = out["wav"][:keep_len]
 
         torchaudio.save("output.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
 
     except RuntimeError as e:
-        if "device-side assert" in str(e):
-            # cannot do anything on cuda device side error, need tor estart
-            print(
-                f"Exit due to: Unrecoverable exception caused by language:{language} prompt:{prompt}",
-                flush=True,
-            )
-            gr.Warning("Unhandled Exception encounter, please retry in a minute")
-            print("Cuda device-assert Runtime encountered need restart")
-
-            error_time = datetime.datetime.now().strftime("%d-%m-%Y-%H:%M:%S")
-            error_data = [
-                error_time,
-                prompt,
-                language,
-                audio_file_pth,
-            ]
-            error_data = [str(e) if type(e) != str else e for e in error_data]
-            print(error_data)
-            print(speaker_wav)
-            write_io = StringIO()
-            csv.writer(write_io).writerows([error_data])
-            csv_upload = write_io.getvalue().encode()
-
-            filename = error_time + "_" + str(uuid.uuid4()) + ".csv"
-            print("Writing error csv")
-            error_api = HfApi()
-            error_api.upload_file(
-                path_or_fileobj=csv_upload,
-                path_in_repo=filename,
-                repo_id="coqui/xtts-flagged-dataset",
-                repo_type="dataset",
-            )
-
-            # speaker_wav
-            print("Writing error reference audio")
-            speaker_filename = error_time + "_reference_" + str(uuid.uuid4()) + ".wav"
-            error_api = HfApi()
-            error_api.upload_file(
-                path_or_fileobj=speaker_wav,
-                path_in_repo=speaker_filename,
-                repo_id="coqui/xtts-flagged-dataset",
-                repo_type="dataset",
-            )
-
-            # HF Space specific.. This error is unrecoverable need to restart space
-            space = api.get_space_runtime(repo_id=repo_id)
-            if space.stage != "BUILDING":
-                api.restart_space(repo_id=repo_id)
-            else:
-                print("TRIED TO RESTART but space is building")
-
-        else:
-            if "Failed to decode" in str(e):
-                print("Speaker encoding error", str(e))
-                metrics_text = gr.Warning(
-                    metrics_text="It appears something wrong with reference, did you unmute your microphone?"
-                )
-            else:
-                print("RuntimeError: non device-side assert error:", str(e))
-                metrics_text = gr.Warning(
-                    "Something unexpected happened please retry again."
-                )
-        return (None, metrics_text)
-    return ("output.wav", metrics_text)
+        print("RuntimeError:", str(e))
+        metrics_text = gr.Warning("An error occurred during processing.")
+        return None, metrics_text
 
+    return "output.wav", metrics_text
 
+
 title = "viXTTS Demo"
 
-
 with gr.Blocks(analytics_enabled=False) as demo:
     with gr.Row():
         with gr.Column():
-            gr.Markdown(
-                """
-                viXTTS Demo
-                """
-            )
+            gr.Markdown("## viXTTS Demo")
         with gr.Column():
-            # placeholder to align the image
            pass
 
     with gr.Row():
@@ -251,33 +167,13 @@ with gr.Blocks(analytics_enabled=False) as demo:
             language_gr = gr.Dropdown(
                 label="Language",
                 info="Select an output language for the synthesised speech",
-                choices=[
-                    "vi",
-                    "en",
-                    "es",
-                    "fr",
-                    "de",
-                    "it",
-                    "pt",
-                    "pl",
-                    "tr",
-                    "ru",
-                    "nl",
-                    "cs",
-                    "ar",
-                    "zh-cn",
-                    "ja",
-                    "ko",
-                    "hu",
-                    "hi",
-                ],
-                max_choices=1,
+                choices=supported_languages,
                 value="vi",
             )
             normalize_text = gr.Checkbox(
                 label="Normalize Vietnamese Text",
                 info="Normalize Vietnamese Text",
-                default=True,
+                value=True,
             )
             ref_gr = gr.Audio(
                 label="Reference Audio",
@@ -289,19 +185,14 @@ with gr.Blocks(analytics_enabled=False) as demo:
 
         with gr.Column():
             audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)
-            out_text_gr = gr.Text(label="Metrics")
+            out_text_gr = gr.Textbox(label="Metrics")
 
     tts_button.click(
         predict,
-        [
-            input_text_gr,
-            language_gr,
-            ref_gr,
-            normalize_text,
-        ],
+        [input_text_gr, language_gr, ref_gr, normalize_text],
         outputs=[audio_gr, out_text_gr],
         api_name="predict",
     )
 
 demo.queue()
-demo.launch(debug=True, show_api=True, share=True)
+demo.launch(debug=True, show_api=True)
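Because the click handler is registered with api_name="predict", the endpoint can also be called programmatically. Below is a hypothetical client-side example with gradio_client; the Space id and reference file are placeholders, and depending on the gradio_client version the audio argument may need to be wrapped with handle_file().

from gradio_client import Client

# Hypothetical usage sketch; "<user>/<space>" stands in for the actual
# Space id, and reference.wav is a local speaker-reference recording.
client = Client("<user>/<space>")
audio_path, metrics = client.predict(
    "Xin chào, đây là bản demo viXTTS.",  # prompt
    "vi",                                 # language
    "reference.wav",                      # reference audio
    True,                                 # normalize Vietnamese text
    api_name="/predict",
)
print(audio_path)
print(metrics)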
 