Spaces:
TDN-M
/
Running on Zero

TDN-M committed on
Commit
c9207a5
·
verified ·
1 Parent(s): c60ab48

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -128
app.py CHANGED
@@ -13,15 +13,12 @@ from huggingface_hub import HfApi, hf_hub_download, snapshot_download
13
  from TTS.tts.configs.xtts_config import XttsConfig
14
  from TTS.tts.models.xtts import Xtts
15
  from vinorm import TTSnorm
16
- from langchain_community.llms import HuggingFacePipeline
 
17
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
18
- from gradio_client import Client
19
- import cv2
20
- from moviepy.editor import AudioFileClip, ImageSequenceClip
21
- import gc
22
- from content_generation import create_content # Nhập hàm create_content từ file content_generation.py
23
 
24
- # download for mecab
25
  os.system("python -m unidic download")
26
  HF_TOKEN = os.environ.get("HF_TOKEN")
27
  api = HfApi(token=HF_TOKEN)
@@ -58,19 +55,29 @@ supported_languages = config.languages
58
  if not "vi" in supported_languages:
59
  supported_languages.append("vi")
60
 
61
- # Load LangChain components với hình mới
62
  model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-xl")
63
  tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")
64
  pipe = pipeline(
65
  'text2text-generation',
66
  model=model,
67
  tokenizer=tokenizer,
68
- max_length=1024 # Cập nhật max_length
69
  )
70
  local_llm = HuggingFacePipeline(pipeline=pipe)
71
- llm_chain = caption_chain.chain(llm=local_llm)
72
- sum_llm_chain = tag_chain.chain(llm=local_llm)
73
- pexels_api_key = os.getenv('pexels_api_key')
 
 
 
 
 
 
 
 
 
 
74
 
75
  def normalize_vietnamese_text(text):
76
  text = (
@@ -100,81 +107,23 @@ def calculate_keep_len(text, lang):
100
  return 13000 * word_count + 2000 * num_punct
101
  return -1
102
 
103
- def create_video_from_audio_and_images(audio_path, images, output_path):
104
- audio_clip = AudioFileClip(audio_path)
105
- duration = audio_clip.duration
106
-
107
- # Calculate frame rate based on number of images and audio duration
108
- frame_rate = len(images) / duration
109
-
110
- # Create video clip from images
111
- video_clip = ImageSequenceClip(images, fps=frame_rate)
112
-
113
- # Set audio for video clip
114
- final_clip = video_clip.set_audio(audio_clip)
115
-
116
- # Write result to file
117
- final_clip.write_videofile(output_path, codec='libx264', audio_codec='aac')
118
- audio_clip.close()
119
- video_clip.close()
120
- final_clip.close()
121
-
122
- def truncate_prompt(prompt, tokenizer, max_length=512):
123
- """Truncate prompt to fit within the maximum token length."""
124
- tokens = tokenizer.tokenize(prompt)
125
- if len(tokens) > max_length:
126
- tokens = tokens[:max_length]
127
- prompt = tokenizer.convert_tokens_to_string(tokens)
128
- return prompt
129
-
130
- def generate_images_from_sentences(sentences):
131
- try:
132
- client = Client("ByteDance/Hyper-FLUX-8Steps-LoRA")
133
- for i, sentence in enumerate(sentences):
134
- print(f"Generating image for sentence {i + 1}: {sentence}")
135
- result = client.predict(
136
- height=1024,
137
- width=1024,
138
- steps=8,
139
- scales=3.5,
140
- prompt=sentence,
141
- seed=3413,
142
- api_name="/process_image"
143
- )
144
- image_path = os.path.join(folder_path, f"image_{i + 1}.png")
145
- result.save(image_path)
146
- print(f"Saved image at {image_path}")
147
- except Exception as e:
148
- print("Error! Failed generating images")
149
- print(e)
150
- return []
151
-
152
  @spaces.GPU
153
  def predict(
154
  prompt,
155
  language,
156
  audio_file_pth,
157
  normalize_text=True,
158
- use_llm=False, # Thêm tùy chọn sử dụng LLM
159
- content_type="Theo yêu cầu", # Loại nội dung (ví dụ: "triết lý sống" hoặc "Theo yêu cầu")
160
  ):
161
- if use_llm:
162
- # Nếu sử dụng LLM, tạo nội dung văn bản từ đầu vào
163
- print("I: Generating text with LLM...")
164
- generated_text = create_content(prompt, content_type, language)
165
- print(f"Generated text: {generated_text}")
166
- prompt = generated_text # Gán văn bản được tạo bởi LLM vào biến prompt
167
-
168
  if language not in supported_languages:
169
  metrics_text = gr.Warning(
170
  f"Language you put {language} in is not in our Supported Languages, please choose from dropdown"
171
  )
172
- return (None, None, metrics_text)
173
 
174
  speaker_wav = audio_file_pth
175
  if len(prompt) < 2:
176
  metrics_text = gr.Warning("Please give a longer prompt text")
177
- return (None, None, metrics_text)
178
 
179
  try:
180
  metrics_text = ""
@@ -194,15 +143,12 @@ def predict(
194
  metrics_text = gr.Warning(
195
  "It appears something wrong with reference, did you unmute your microphone?"
196
  )
197
- return (None, None, metrics_text)
198
 
199
  prompt = re.sub("([^\x00-\x7F]|\w)(\.|\。|\?)", r"\1 \2\2", prompt)
200
  if normalize_text and language == "vi":
201
  prompt = normalize_vietnamese_text(prompt)
202
-
203
- # Truncate prompt to fit within the maximum token length
204
- prompt = truncate_prompt(prompt, tokenizer, max_length=512)
205
-
206
  print("I: Generating new audio...")
207
  t0 = time.time()
208
  out = MODEL.inference(
@@ -227,38 +173,15 @@ def predict(
227
  keep_len = calculate_keep_len(prompt, language)
228
  out["wav"] = out["wav"][:keep_len]
229
  torchaudio.save("output.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
230
-
231
- # Tạo video từ file audio và các cảnh
232
- print("I: Generating images from sentences...")
233
- # Sử dụng UUID để tạo tên thư mục ngắn gọn
234
- folder_name = f"video_{uuid.uuid4().hex}"
235
- os.makedirs(folder_name, exist_ok=True)
236
- folder_path = os.path.join(folder_name, "images")
237
- os.makedirs(folder_path, exist_ok=True)
238
-
239
- # Tách các câu từ văn bản
240
- sentences = [x.strip() for x in re.split(r'[.!?]', prompt) if len(x.strip()) > 6]
241
-
242
- # Tạo ảnh minh họa cho từng câu
243
- images = generate_images_from_sentences(sentences)
244
-
245
- # Tạo video từ file audio và các ảnh
246
- video_path = os.path.join(folder_name, "Final_Ad_Video.mp4")
247
- create_video_from_audio_and_images("output.wav", images, video_path)
248
-
249
- print(f"I: Video generated at {video_path}")
250
- metrics_text += f"Video generated at {video_path}\n"
251
-
252
- return ("output.wav", video_path, metrics_text)
253
  except RuntimeError as e:
254
  if "device-side assert" in str(e):
255
- # cannot do anything on cuda device side error, need to restart
256
  print(
257
  f"Exit due to: Unrecoverable exception caused by language:{language} prompt:{prompt}",
258
  flush=True,
259
  )
260
  gr.Warning("Unhandled Exception encounter, please retry in a minute")
261
- print("Cuda device-assert Runtime encountered need restart")
262
  error_time = datetime.datetime.now().strftime("%d-%m-%Y-%H:%M:%S")
263
  error_data = [
264
  error_time,
@@ -273,7 +196,7 @@ def predict(
273
  csv.writer(write_io).writerows([error_data])
274
  csv_upload = write_io.getvalue().encode()
275
  filename = error_time + "_" + str(uuid.uuid4()) + ".csv"
276
- print("Writing error csv")
277
  error_api = HfApi()
278
  error_api.upload_file(
279
  path_or_fileobj=csv_upload,
@@ -281,7 +204,7 @@ def predict(
281
  repo_id="coqui/xtts-flagged-dataset",
282
  repo_type="dataset",
283
  )
284
- # speaker_wav
285
  print("Writing error reference audio")
286
  speaker_filename = error_time + "_reference_" + str(uuid.uuid4()) + ".wav"
287
  error_api = HfApi()
@@ -308,16 +231,9 @@ def predict(
308
  metrics_text = gr.Warning(
309
  "Something unexpected happened please retry again."
310
  )
311
- return (None, None, metrics_text)
312
- except Exception as e:
313
- print("Unexpected error:", str(e))
314
- metrics_text = gr.Warning(
315
- "An unexpected error occurred. Please try again later."
316
- )
317
- return (None, None, metrics_text)
318
- return ("output.wav", None, metrics_text)
319
 
320
- # Cập nhật giao diện Gradio
321
  with gr.Blocks(analytics_enabled=False) as demo:
322
  with gr.Row():
323
  with gr.Column():
@@ -327,7 +243,7 @@ with gr.Blocks(analytics_enabled=False) as demo:
327
  """
328
  )
329
  with gr.Column():
330
- # placeholder to align the image
331
  pass
332
 
333
  with gr.Row():
@@ -367,16 +283,6 @@ with gr.Blocks(analytics_enabled=False) as demo:
367
  info="Normalize Vietnamese text",
368
  value=True,
369
  )
370
- use_llm_checkbox = gr.Checkbox(
371
- label="Sử dụng LLM để tạo nội dung",
372
- info="Use LLM to generate content",
373
- value=False,
374
- )
375
- content_type_dropdown = gr.Dropdown(
376
- label="Loại nội dung",
377
- choices=["triết lý sống", "Theo yêu cầu"],
378
- value="Theo yêu cầu",
379
- )
380
  ref_gr = gr.Audio(
381
  label="Reference Audio (Giọng mẫu)",
382
  type="filepath",
@@ -391,7 +297,6 @@ with gr.Blocks(analytics_enabled=False) as demo:
391
 
392
  with gr.Column():
393
  audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)
394
- video_gr = gr.Video(label="Generated Video")
395
  out_text_gr = gr.Text(label="Metrics")
396
 
397
  tts_button.click(
@@ -401,10 +306,8 @@ with gr.Blocks(analytics_enabled=False) as demo:
401
  language_gr,
402
  ref_gr,
403
  normalize_text,
404
- use_llm_checkbox, # Thêm checkbox để bật/tắt LLM
405
- content_type_dropdown, # Thêm dropdown để chọn loại nội dung
406
  ],
407
- outputs=[audio_gr, video_gr, out_text_gr],
408
  api_name="predict",
409
  )
410
 
 
13
  from TTS.tts.configs.xtts_config import XttsConfig
14
  from TTS.tts.models.xtts import Xtts
15
  from vinorm import TTSnorm
16
+ from langchain.prompts import PromptTemplate
17
+ from langchain.chains import LLMChain
18
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
19
+ from langchain_community.llms import HuggingFacePipeline
 
 
 
 
20
 
21
+ # Download for mecab
22
  os.system("python -m unidic download")
23
  HF_TOKEN = os.environ.get("HF_TOKEN")
24
  api = HfApi(token=HF_TOKEN)
 
55
  if not "vi" in supported_languages:
56
  supported_languages.append("vi")
57
 
58
+ # Load LangChain components with the new model
59
  model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-xl")
60
  tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")
61
  pipe = pipeline(
62
  'text2text-generation',
63
  model=model,
64
  tokenizer=tokenizer,
65
+ max_length=1024 # Update max_length
66
  )
67
  local_llm = HuggingFacePipeline(pipeline=pipe)
68
+
69
+ # Define the caption_chain function
70
+ def caption_chain(llm):
71
+ sum_template = """What is the most significant action, place, or thing? Say it in at most 5 words:
72
+
73
+ {sentence}
74
+ """
75
+ sum_prompt = PromptTemplate(template=sum_template, input_variables=["sentence"])
76
+ sum_llm_chain = LLMChain(prompt=sum_prompt, llm=llm)
77
+ return sum_llm_chain
78
+
79
+ # Initialize the caption chain
80
+ llm_chain = caption_chain(llm=local_llm)
81
 
82
  def normalize_vietnamese_text(text):
83
  text = (
 
107
  return 13000 * word_count + 2000 * num_punct
108
  return -1
109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  @spaces.GPU
111
  def predict(
112
  prompt,
113
  language,
114
  audio_file_pth,
115
  normalize_text=True,
 
 
116
  ):
 
 
 
 
 
 
 
117
  if language not in supported_languages:
118
  metrics_text = gr.Warning(
119
  f"Language you put {language} in is not in our Supported Languages, please choose from dropdown"
120
  )
121
+ return (None, metrics_text)
122
 
123
  speaker_wav = audio_file_pth
124
  if len(prompt) < 2:
125
  metrics_text = gr.Warning("Please give a longer prompt text")
126
+ return (None, metrics_text)
127
 
128
  try:
129
  metrics_text = ""
 
143
  metrics_text = gr.Warning(
144
  "It appears something wrong with reference, did you unmute your microphone?"
145
  )
146
+ return (None, metrics_text)
147
 
148
  prompt = re.sub("([^\x00-\x7F]|\w)(\.|\。|\?)", r"\1 \2\2", prompt)
149
  if normalize_text and language == "vi":
150
  prompt = normalize_vietnamese_text(prompt)
151
+
 
 
 
152
  print("I: Generating new audio...")
153
  t0 = time.time()
154
  out = MODEL.inference(
 
173
  keep_len = calculate_keep_len(prompt, language)
174
  out["wav"] = out["wav"][:keep_len]
175
  torchaudio.save("output.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  except RuntimeError as e:
177
  if "device-side assert" in str(e):
178
+ # Cannot do anything on CUDA device side error, need to restart
179
  print(
180
  f"Exit due to: Unrecoverable exception caused by language:{language} prompt:{prompt}",
181
  flush=True,
182
  )
183
  gr.Warning("Unhandled Exception encounter, please retry in a minute")
184
+ print("CUDA device-assert Runtime encountered need restart")
185
  error_time = datetime.datetime.now().strftime("%d-%m-%Y-%H:%M:%S")
186
  error_data = [
187
  error_time,
 
196
  csv.writer(write_io).writerows([error_data])
197
  csv_upload = write_io.getvalue().encode()
198
  filename = error_time + "_" + str(uuid.uuid4()) + ".csv"
199
+ print("Writing error CSV")
200
  error_api = HfApi()
201
  error_api.upload_file(
202
  path_or_fileobj=csv_upload,
 
204
  repo_id="coqui/xtts-flagged-dataset",
205
  repo_type="dataset",
206
  )
207
+ # Speaker WAV
208
  print("Writing error reference audio")
209
  speaker_filename = error_time + "_reference_" + str(uuid.uuid4()) + ".wav"
210
  error_api = HfApi()
 
231
  metrics_text = gr.Warning(
232
  "Something unexpected happened please retry again."
233
  )
234
+ return (None, metrics_text)
235
+ return ("output.wav", metrics_text)
 
 
 
 
 
 
236
 
 
237
  with gr.Blocks(analytics_enabled=False) as demo:
238
  with gr.Row():
239
  with gr.Column():
 
243
  """
244
  )
245
  with gr.Column():
246
+ # Placeholder to align the image
247
  pass
248
 
249
  with gr.Row():
 
283
  info="Normalize Vietnamese text",
284
  value=True,
285
  )
 
 
 
 
 
 
 
 
 
 
286
  ref_gr = gr.Audio(
287
  label="Reference Audio (Giọng mẫu)",
288
  type="filepath",
 
297
 
298
  with gr.Column():
299
  audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)
 
300
  out_text_gr = gr.Text(label="Metrics")
301
 
302
  tts_button.click(
 
306
  language_gr,
307
  ref_gr,
308
  normalize_text,
 
 
309
  ],
310
+ outputs=[audio_gr, out_text_gr],
311
  api_name="predict",
312
  )
313