Spaces: TDN-M · Running on Zero

TDN-M committed · Commit 11019ca · verified · 1 Parent(s): c9207a5

Update app.py

Files changed (1):
  1. app.py +34 -40
app.py CHANGED
@@ -13,12 +13,9 @@ from huggingface_hub import HfApi, hf_hub_download, snapshot_download
 from TTS.tts.configs.xtts_config import XttsConfig
 from TTS.tts.models.xtts import Xtts
 from vinorm import TTSnorm
-from langchain.prompts import PromptTemplate
-from langchain.chains import LLMChain
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
-from langchain_community.llms import HuggingFacePipeline
+from content_generation import create_content  # Import the create_content function from content_generation.py
 
-# Download for mecab
+# download for mecab
 os.system("python -m unidic download")
 HF_TOKEN = os.environ.get("HF_TOKEN")
 api = HfApi(token=HF_TOKEN)
@@ -55,30 +52,6 @@ supported_languages = config.languages
 if not "vi" in supported_languages:
     supported_languages.append("vi")
 
-# Load LangChain components with the new model
-model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-xl")
-tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")
-pipe = pipeline(
-    'text2text-generation',
-    model=model,
-    tokenizer=tokenizer,
-    max_length=1024  # Update max_length
-)
-local_llm = HuggingFacePipeline(pipeline=pipe)
-
-# Define the caption_chain function
-def caption_chain(llm):
-    sum_template = """What is the most significant action, place, or thing? Say it in at most 5 words:
-
-    {sentence}
-    """
-    sum_prompt = PromptTemplate(template=sum_template, input_variables=["sentence"])
-    sum_llm_chain = LLMChain(prompt=sum_prompt, llm=llm)
-    return sum_llm_chain
-
-# Initialize the caption_chain and tag_chain
-llm_chain = caption_chain(llm=local_llm)
-
 def normalize_vietnamese_text(text):
     text = (
         TTSnorm(text, unknown=False, lower=False, rule=True)
@@ -113,7 +86,16 @@ def predict(
     language,
     audio_file_pth,
     normalize_text=True,
+    use_llm=False,  # Option to generate the text with an LLM first
+    content_type="Theo yêu cầu",  # Content type (e.g. "triết lý sống" or "Theo yêu cầu")
 ):
+    if use_llm:
+        # When the LLM is enabled, generate the text content from the input prompt
+        print("I: Generating text with LLM...")
+        generated_text = create_content(prompt, content_type, language)
+        print(f"Generated text: {generated_text}")
+        prompt = generated_text  # Use the LLM-generated text as the TTS prompt
+
     if language not in supported_languages:
         metrics_text = gr.Warning(
             f"Language you put {language} in is not in our Supported Languages, please choose from dropdown"
@@ -148,7 +130,6 @@ def predict(
             prompt = re.sub("([^\x00-\x7F]|\w)(\.|\。|\?)", r"\1 \2\2", prompt)
             if normalize_text and language == "vi":
                 prompt = normalize_vietnamese_text(prompt)
-
             print("I: Generating new audio...")
             t0 = time.time()
             out = MODEL.inference(
@@ -175,13 +156,13 @@ def predict(
             torchaudio.save("output.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
         except RuntimeError as e:
             if "device-side assert" in str(e):
-                # Cannot do anything on CUDA device side error, need to restart
+                # cannot do anything on cuda device side error, need to restart
                 print(
                     f"Exit due to: Unrecoverable exception caused by language:{language} prompt:{prompt}",
                     flush=True,
                 )
                 gr.Warning("Unhandled Exception encounter, please retry in a minute")
-                print("CUDA device-assert Runtime encountered need restart")
+                print("Cuda device-assert Runtime encountered need restart")
                 error_time = datetime.datetime.now().strftime("%d-%m-%Y-%H:%M:%S")
                 error_data = [
                     error_time,
@@ -196,7 +177,7 @@ def predict(
                 csv.writer(write_io).writerows([error_data])
                 csv_upload = write_io.getvalue().encode()
                 filename = error_time + "_" + str(uuid.uuid4()) + ".csv"
-                print("Writing error CSV")
+                print("Writing error csv")
                 error_api = HfApi()
                 error_api.upload_file(
                     path_or_fileobj=csv_upload,
@@ -204,7 +185,7 @@ def predict(
                     repo_id="coqui/xtts-flagged-dataset",
                     repo_type="dataset",
                 )
-                # Speaker WAV
+                # speaker_wav
                 print("Writing error reference audio")
                 speaker_filename = error_time + "_reference_" + str(uuid.uuid4()) + ".wav"
                 error_api = HfApi()
@@ -234,24 +215,25 @@ def predict(
             return (None, metrics_text)
     return ("output.wav", metrics_text)
 
+# Update the Gradio interface
 with gr.Blocks(analytics_enabled=False) as demo:
     with gr.Row():
         with gr.Column():
             gr.Markdown(
                 """
-                # tts@TDNM ✨ https://www.tdn-m.com
+                # tts@TDNM ✨ https://www.tdn-m.com
                 """
             )
         with gr.Column():
-            # Placeholder to align the image
+            # placeholder to align the image
             pass
 
     with gr.Row():
         with gr.Column():
             input_text_gr = gr.Textbox(
-                label="Text Prompt (Văn bản cần đọc)",
-                info="Mỗi câu nên từ 10 từ trở lên.",
-                value="Xin chào, tôi là một mô hình chuyển đổi văn bản thành giọng nói tiếng Việt.",
+                label="Bạn cần nội dung gì?",
+                info="Tôi có thể viết và thu âm luôn cho bạn",
+                value="Lời tự sự của AI, 150 từ",
             )
             language_gr = gr.Dropdown(
                 label="Language (Ngôn ngữ)",
@@ -283,10 +265,20 @@ with gr.Blocks(analytics_enabled=False) as demo:
                 info="Normalize Vietnamese text",
                 value=True,
             )
+            use_llm_checkbox = gr.Checkbox(
+                label="Sử dụng LLM để tạo nội dung",
+                info="Use LLM to generate content",
+                value=True,
+            )
+            content_type_dropdown = gr.Dropdown(
+                label="Loại nội dung",
+                choices=["triết lý sống", "Theo yêu cầu"],
+                value="Theo yêu cầu",
+            )
             ref_gr = gr.Audio(
                 label="Reference Audio (Giọng mẫu)",
                 type="filepath",
-                value="nam-tai-lieu.wav",
+                value="nam-tai-llieu.wav",
             )
             tts_button = gr.Button(
                 "Đọc 🗣️🔥",
@@ -306,6 +298,8 @@ with gr.Blocks(analytics_enabled=False) as demo:
             language_gr,
             ref_gr,
             normalize_text,
+            use_llm_checkbox,  # Checkbox to enable/disable the LLM
+            content_type_dropdown,  # Dropdown to select the content type
         ],
         outputs=[audio_gr, out_text_gr],
         api_name="predict",