import csv
import datetime
import hashlib
import json
import os
import re
import subprocess
import time
import uuid
from io import BytesIO, StringIO
from pathlib import Path

import gradio as gr
import requests
import spaces
import torch
import torchaudio
from huggingface_hub import HfApi, hf_hub_download, snapshot_download
from PIL import Image
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from vinorm import TTSnorm

from content_generation import create_content

# Download the UniDic dictionary used for Japanese text processing in XTTS.
os.system("python -m unidic download")

# Hugging Face API client (also used to restart the Space on unrecoverable errors).
HF_TOKEN = os.environ.get("HF_TOKEN")
api = HfApi(token=HF_TOKEN)

# Download the viXTTS checkpoint (plus the speaker file from the upstream XTTS-v2 repo)
# if any of the required files are missing locally.
print("Downloading if not downloaded viXTTS")
checkpoint_dir = "model/"
repo_id = "capleaf/viXTTS"
use_deepspeed = False

os.makedirs(checkpoint_dir, exist_ok=True)
required_files = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]
files_in_dir = os.listdir(checkpoint_dir)
if not all(file in files_in_dir for file in required_files):
    snapshot_download(
        repo_id=repo_id,
        repo_type="model",
        local_dir=checkpoint_dir,
    )
    hf_hub_download(
        repo_id="coqui/XTTS-v2",
        filename="speakers_xtts.pth",
        local_dir=checkpoint_dir,
    )

xtts_config = os.path.join(checkpoint_dir, "config.json")
config = XttsConfig()
config.load_json(xtts_config)
MODEL = Xtts.init_from_config(config)
MODEL.load_checkpoint(
    config, checkpoint_dir=checkpoint_dir, use_deepspeed=use_deepspeed
)
if torch.cuda.is_available():
    MODEL.cuda()

supported_languages = config.languages
if "vi" not in supported_languages:
    supported_languages.append("vi")


def normalize_vietnamese_text(text):
    # Normalize Vietnamese text with vinorm, then clean up punctuation artifacts and
    # expand a few tokens ("AI", "%") so they are read naturally.
    text = (
        TTSnorm(text, unknown=False, lower=False, rule=True)
        .replace("..", ".")
        .replace("!.", "!")
        .replace("?.", "?")
        .replace(" .", ".")
        .replace(" ,", ",")
        .replace('"', "")
        .replace("'", "")
        .replace("AI", "Ây Ai")
        .replace("A.I", "Ây Ai")
        .replace("%", "phần trăm")
    )
    return text


def calculate_keep_len(text, lang):
    """Simple hack for short sentences: return how many audio samples to keep."""
    if lang in ["ja", "zh-cn"]:
        return -1

    word_count = len(text.split())
    num_punct = text.count(".") + text.count("!") + text.count("?") + text.count(",")

    if word_count < 5:
        return 15000 * word_count + 2000 * num_punct
    elif word_count < 10:
        return 13000 * word_count + 2000 * num_punct
    return -1
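
# For example, a 3-word prompt ending in a period keeps 15000 * 3 + 2000 * 1 = 47000
# samples, i.e. roughly 2 seconds at the 24 kHz output rate used when saving audio below.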


def generate_image_description(prompt):
    # Placeholder: derive a simple image description from the TTS prompt.
    return f"A visual representation of: {prompt}"


def txt2img(prompt, width, height):
    # Submit a text-to-image job to the TensorArt API, poll until it finishes,
    # and return the resulting PIL image (or an error string on failure).
    model_id = "770694094415489962"
    vae_id = "sdxl-vae-fp16-fix.safetensors"
    lora_items = [
        {"loraModel": "766419665653268679", "weight": 0.7},
        {"loraModel": "777630084346589138", "weight": 0.7},
        {"loraModel": "776587863287492519", "weight": 0.7},
    ]

    txt2img_data = {
        "request_id": hashlib.md5(str(int(time.time())).encode()).hexdigest(),
        "stages": [
            {
                "type": "INPUT_INITIALIZE",
                "inputInitialize": {"seed": -1, "count": 1},
            },
            {
                "type": "DIFFUSION",
                "diffusion": {
                    "width": width,
                    "height": height,
                    "prompts": [{"text": prompt}],
                    "negativePrompts": [{"text": "nsfw"}],
                    "sdModel": model_id,
                    "sdVae": vae_id,
                    "sampler": "Euler a",
                    "steps": 20,
                    "cfgScale": 3,
                    "clipSkip": 1,
                    "etaNoiseSeedDelta": 31337,
                    "lora": {"items": lora_items},
                },
            },
        ],
    }

    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",
        "Authorization": f'Bearer {os.getenv("api_key_token")}',
    }

    # Create the job.
    response = requests.post(
        "https://ap-east-1.tensorart.cloud/v1/jobs", json=txt2img_data, headers=headers
    )
    if response.status_code != 200:
        return f"Error: {response.status_code} - {response.text}"

    response_data = response.json()
    job_id = response_data["job"]["id"]
    print(f"Job created. ID: {job_id}")

    # Poll the job status every 10 seconds, up to a 5-minute timeout.
    start_time = time.time()
    timeout = 300
    while True:
        time.sleep(10)
        elapsed_time = time.time() - start_time
        if elapsed_time > timeout:
            return f"Error: Job timed out after {timeout} seconds."

        response = requests.get(
            f"https://ap-east-1.tensorart.cloud/v1/jobs/{job_id}", headers=headers
        )
        if response.status_code != 200:
            return f"Error: {response.status_code} - {response.text}"

        get_job_response_data = response.json()
        job_status = get_job_response_data["job"]["status"]
        print(f"Job status: {job_status}")

        if job_status == "SUCCESS":
            if "successInfo" in get_job_response_data["job"]:
                image_url = get_job_response_data["job"]["successInfo"]["images"][0]["url"]
                print(f"Job succeeded. Image URL: {image_url}")
                response_image = requests.get(image_url)
                img = Image.open(BytesIO(response_image.content))
                return img
            else:
                return "Error: Output is missing in the job response."
        elif job_status == "FAILED":
            return "Error: Job failed. Please try again with different settings."
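
# Note: the polling loop above assumes a response shape of roughly
#   {"job": {"id": ..., "status": ..., "successInfo": {"images": [{"url": ...}]}}},
# inferred from the fields the code accesses; only "SUCCESS" and "FAILED" are treated
# as terminal states, any other status keeps the loop polling.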


def create_video(image_path, audio_path, output_path):
    # Use ffmpeg to render a video: the generated image as a static background with
    # a mono waveform visualization of the audio overlaid in the center.
    command = [
        "ffmpeg",
        "-i", image_path,
        "-i", audio_path,
        "-filter_complex",
        "[1:a]aformat=channel_layouts=mono,showwaves=s=1200x400:mode=p2p:colors=blue@0.8[w];[0:v][w]overlay=(W-w)/2:(H-h)/2",
        "-c:v", "libx264",
        "-b:v", "2000k",
        "-c:a", "aac",
        "-b:a", "192k",
        "-y", output_path,
    ]
    subprocess.run(command, check=True)
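
# For reference, the equivalent command line (with hypothetical file names) is:
#   ffmpeg -i image.png -i audio.wav \
#     -filter_complex "[1:a]aformat=channel_layouts=mono,showwaves=s=1200x400:mode=p2p:colors=blue@0.8[w];[0:v][w]overlay=(W-w)/2:(H-h)/2" \
#     -c:v libx264 -b:v 2000k -c:a aac -b:a 192k -y output.mp4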


def generate_video(prompt, language, audio_file_pth, normalize_text, use_llm, content_type):
    # Reuse a previously synthesized "output.wav" if it exists; otherwise run TTS first.
    if not os.path.exists("output.wav"):
        audio_file, metrics_text = predict(
            prompt, language, audio_file_pth, normalize_text, use_llm, content_type
        )
        if not audio_file:
            return None, metrics_text
    else:
        audio_file = "output.wav"

    # Generate a background image from the prompt.
    image_description = generate_image_description(prompt)
    try:
        image = txt2img(image_description, width=800, height=600)
        if isinstance(image, str):
            # txt2img returns an error string on failure.
            return None, image

        image_path = os.path.join(SAVE_DIR, "generated_image.png")
        image.save(image_path)
    except Exception as e:
        return None, f"Error generating image: {str(e)}"

    # Combine the image and the audio into a video with a waveform overlay.
    video_output_path = os.path.join(SAVE_DIR, "output_video.mp4")
    try:
        create_video(image_path, audio_file, video_output_path)
    except Exception as e:
        return None, f"Error creating video: {str(e)}"

    return video_output_path, "Video created successfully!"


SAVE_DIR = "generated_images"
Path(SAVE_DIR).mkdir(exist_ok=True)


@spaces.GPU
def predict(
    prompt,
    language,
    audio_file_pth,
    normalize_text=True,
    use_llm=False,
    content_type="Theo yêu cầu",
):
    if use_llm:
        # Optionally expand the user's request into full content with the LLM helper.
        print("I: Generating text with LLM...")
        generated_text = create_content(prompt, content_type, language)
        print(f"Generated text: {generated_text}")
        prompt = generated_text

    if language not in supported_languages:
        metrics_text = gr.Warning(
            f"Language {language} is not in our supported languages, please choose from the dropdown"
        )
        return (None, metrics_text)

    speaker_wav = audio_file_pth
    if len(prompt) < 2:
        metrics_text = gr.Warning("Please give a longer prompt text")
        return (None, metrics_text)

    try:
        metrics_text = ""
        t_latent = time.time()

        # Compute speaker conditioning latents from the reference audio.
        try:
            (
                gpt_cond_latent,
                speaker_embedding,
            ) = MODEL.get_conditioning_latents(
                audio_path=speaker_wav,
                gpt_cond_len=30,
                gpt_cond_chunk_len=4,
                max_ref_length=60,
            )
        except Exception as e:
            print("Speaker encoding error", str(e))
            metrics_text = gr.Warning(
                "It appears something went wrong with the reference audio. Did you unmute your microphone?"
            )
            return (None, metrics_text)

        # Duplicate sentence-final punctuation so the model pauses more naturally.
        prompt = re.sub(r"([^\x00-\x7F]|\w)(\.|。|\?)", r"\1 \2\2", prompt)
        if normalize_text and language == "vi":
            prompt = normalize_vietnamese_text(prompt)

        print("I: Generating new audio...")
        t0 = time.time()
        out = MODEL.inference(
            prompt,
            language,
            gpt_cond_latent,
            speaker_embedding,
            repetition_penalty=5.0,
            temperature=0.75,
            enable_text_splitting=True,
        )
        inference_time = time.time() - t0
        print(f"I: Time to generate audio: {round(inference_time*1000)} milliseconds")
        metrics_text += (
            f"Time to generate audio: {round(inference_time*1000)} milliseconds\n"
        )
        real_time_factor = (time.time() - t0) / out["wav"].shape[-1] * 24000
        print(f"Real-time factor (RTF): {real_time_factor}")
        metrics_text += f"Real-time factor (RTF): {real_time_factor:.2f}\n"

        # Trim trailing audio for short prompts, then save at 24 kHz.
        keep_len = calculate_keep_len(prompt, language)
        out["wav"] = out["wav"][:keep_len]
        torchaudio.save("output.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
    except RuntimeError as e:
        if "device-side assert" in str(e):
            # Unrecoverable CUDA error: log the failing input to the flagged dataset
            # and restart the Space.
            print(
                f"Exit due to: Unrecoverable exception caused by language:{language} prompt:{prompt}",
                flush=True,
            )
            gr.Warning("Unhandled exception encountered, please retry in a minute")
            print("Cuda device-assert Runtime encountered need restart")

            error_time = datetime.datetime.now().strftime("%d-%m-%Y-%H:%M:%S")
            error_data = [
                error_time,
                prompt,
                language,
                audio_file_pth,
            ]
            error_data = [item if isinstance(item, str) else str(item) for item in error_data]
            print(error_data)
            print(speaker_wav)

            write_io = StringIO()
            csv.writer(write_io).writerows([error_data])
            csv_upload = write_io.getvalue().encode()
            filename = error_time + "_" + str(uuid.uuid4()) + ".csv"
            print("Writing error csv")
            error_api = HfApi()
            error_api.upload_file(
                path_or_fileobj=csv_upload,
                path_in_repo=filename,
                repo_id="coqui/xtts-flagged-dataset",
                repo_type="dataset",
            )

            # Also upload the reference audio that triggered the error.
            speaker_filename = error_time + "_reference_" + str(uuid.uuid4()) + ".wav"
            error_api = HfApi()
            error_api.upload_file(
                path_or_fileobj=speaker_wav,
                path_in_repo=speaker_filename,
                repo_id="coqui/xtts-flagged-dataset",
                repo_type="dataset",
            )

            space = api.get_space_runtime(repo_id=repo_id)
            if space.stage != "BUILDING":
                api.restart_space(repo_id=repo_id)
            else:
                print("TRIED TO RESTART but space is building")
        else:
            if "Failed to decode" in str(e):
                print("Speaker encoding error", str(e))
                metrics_text = gr.Warning(
                    "It appears something went wrong with the reference audio. Did you unmute your microphone?"
                )
            else:
                print("RuntimeError: non device-side assert error:", str(e))
                metrics_text = gr.Warning(
                    "Something unexpected happened, please retry."
                )
        return (None, metrics_text)

    return ("output.wav", metrics_text)


with gr.Blocks(analytics_enabled=False) as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
                # tts@TDNM ✨ https://www.tdn-m.com
                """
            )
        with gr.Column():
            pass

    with gr.Row():
        with gr.Column():
            input_text_gr = gr.Textbox(
                label="Bạn cần nội dung gì?",
                info="Tôi có thể viết và thu âm luôn cho bạn",
                value="Lời tự sự của AI, 150 từ",
            )
            language_gr = gr.Dropdown(
                label="Language (Ngôn ngữ)",
                choices=[
                    "vi", "en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru",
                    "nl", "cs", "ar", "zh-cn", "ja", "ko", "hu", "hi",
                ],
                max_choices=1,
                value="vi",
            )
            normalize_text = gr.Checkbox(
                label="Chuẩn hóa văn bản tiếng Việt",
                info="Normalize Vietnamese text",
                value=True,
            )
            use_llm_checkbox = gr.Checkbox(
                label="Sử dụng LLM để tạo nội dung",
                info="Use LLM to generate content",
                value=True,
            )
            content_type_dropdown = gr.Dropdown(
                label="Loại nội dung",
                choices=["triết lý sống", "Theo yêu cầu"],
                value="Theo yêu cầu",
            )
            ref_gr = gr.Audio(
                label="Reference Audio (Giọng mẫu)",
                type="filepath",
                value="nam-tai-llieu.wav",
            )
            tts_button = gr.Button(
                "Đọc 🗣️🔥",
                elem_id="send-btn",
                visible=True,
                variant="primary",
            )
            video_button = gr.Button("Tạo Video 🎥", visible=True)

        with gr.Column():
            audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)
            out_text_gr = gr.Text(label="Metrics")
            video_output = gr.Video(label="Generated Video", visible=True)
            video_status = gr.Text(label="Video Status")

    tts_button.click(
        predict,
        inputs=[
            input_text_gr,
            language_gr,
            ref_gr,
            normalize_text,
            use_llm_checkbox,
            content_type_dropdown,
        ],
        outputs=[audio_gr, out_text_gr],
        api_name="predict",
    )

    video_button.click(
        generate_video,
        inputs=[
            input_text_gr,
            language_gr,
            ref_gr,
            normalize_text,
            use_llm_checkbox,
            content_type_dropdown,
        ],
        outputs=[video_output, video_status],
    )

demo.queue()
demo.launch(debug=True, show_api=True, share=True)