fffiloni committed
Commit 7ba1d45 · verified · 1 Parent(s): 3f071e2

Create app.py

Files changed (1)
  1. app.py +173 -0
app.py ADDED
@@ -0,0 +1,173 @@
+ from huggingface_hub import snapshot_download
+
+ # Download All Required Models using `snapshot_download`
+
+ # Download Wan2.1-I2V-14B-480P model
+ wan_model_path = snapshot_download(
+     repo_id="Wan-AI/Wan2.1-I2V-14B-480P",
+     local_dir="./weights/Wan2.1-I2V-14B-480P",
+     #local_dir_use_symlinks=False
+ )
+
+ # Download Chinese wav2vec2 model
+ wav2vec_path = snapshot_download(
+     repo_id="TencentGameMate/chinese-wav2vec2-base",
+     local_dir="./weights/chinese-wav2vec2-base",
+     #local_dir_use_symlinks=False
+ )
+
+ # Download MeiGen MultiTalk weights
+ multitalk_path = snapshot_download(
+     repo_id="MeiGen-AI/MeiGen-MultiTalk",
+     local_dir="./weights/MeiGen-MultiTalk",
+     #local_dir_use_symlinks=False
+ )
+
+
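+ # Not in the original script: a minimal sanity check that all three snapshots
+ # landed on disk (snapshot_download returns the local folder path) before we
+ # start moving files around.
+ import os
+ for _d in [wan_model_path, wav2vec_path, multitalk_path]:
+     assert os.path.isdir(_d), f"Expected downloaded model directory at {_d}"
+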
+ import os
+ import shutil
+
+ # Define paths
+ base_model_dir = "./weights/Wan2.1-I2V-14B-480P"
+ multitalk_dir = "./weights/MeiGen-MultiTalk"
+
+ # File to rename
+ original_index = os.path.join(base_model_dir, "diffusion_pytorch_model.safetensors.index.json")
+ backup_index = os.path.join(base_model_dir, "diffusion_pytorch_model.safetensors.index.json_old")
+
+ # Rename the original index file
+ if os.path.exists(original_index):
+     os.rename(original_index, backup_index)
+     print("Renamed original index file to .json_old")
+
+ # Copy updated index file from MultiTalk
+ shutil.copy2(
+     os.path.join(multitalk_dir, "diffusion_pytorch_model.safetensors.index.json"),
+     base_model_dir
+ )
+
+ # Copy MultiTalk model weights
+ shutil.copy2(
+     os.path.join(multitalk_dir, "multitalk.safetensors"),
+     base_model_dir
+ )
+
+ print("Copied MultiTalk files into base model directory.")
+
+
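+ # Not in the original script: confirm the MultiTalk index and weights are now
+ # present in the base model directory before moving on.
+ for _f in ["diffusion_pytorch_model.safetensors.index.json", "multitalk.safetensors"]:
+     assert os.path.exists(os.path.join(base_model_dir, _f)), f"Missing {_f} in {base_model_dir}"
+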
+ import torch
+
+ # Check if a CUDA-compatible GPU is available
+ if torch.cuda.is_available():
+     # Get current GPU name
+     gpu_name = torch.cuda.get_device_name(torch.cuda.current_device())
+     print(f"Current GPU: {gpu_name}")
+
+     # Enforce GPU requirement
+     if "A100" not in gpu_name and "L4" not in gpu_name:
+         raise RuntimeError(f"This app requires an A100 or L4 GPU. Found: {gpu_name}")
+     elif "L4" in gpu_name:
+         print("Warning: L4 is supported, but A100 is recommended for faster inference.")
+ else:
+     raise RuntimeError("No CUDA-compatible GPU found. An A100 or L4 GPU is required.")
+
+
+ GPU_TO_VRAM_PARAMS = {
+     "NVIDIA A100": 11000000000,
+     "NVIDIA A100-SXM4-40GB": 11000000000,
+     "NVIDIA A100-SXM4-80GB": 22000000000,
+     "NVIDIA L4": 5000000000
+ }
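+ # Not in the original script: the exact-name lookup below would raise a bare
+ # KeyError for GPU names missing from the table (e.g. an "NVIDIA A100 80GB PCIe"),
+ # so fail early with a clearer message.
+ if gpu_name not in GPU_TO_VRAM_PARAMS:
+     raise RuntimeError(f"No num_persistent_param_in_dit preset for GPU: {gpu_name}")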
+ USED_VRAM_PARAMS = GPU_TO_VRAM_PARAMS[gpu_name]
+ print("Using", USED_VRAM_PARAMS, "for num_persistent_param_in_dit")
+
+ import subprocess
+ import json
+ import tempfile
+ #import os
+
+ def create_temp_input_json(prompt: str, cond_image_path: str, cond_audio_path: str) -> str:
+     """
+     Create a temporary JSON file with the user-provided prompt, image, and audio paths.
+     Returns the path to the temporary JSON file.
+     """
+     # Structure based on MultiTalk's expected input JSON format
+     data = {
+         "prompt": prompt,
+         "cond_image": cond_image_path,
+         "cond_audio": {
+             "person1": cond_audio_path
+         }
+     }
+
+     # Create a temp file
+     temp_json = tempfile.NamedTemporaryFile(delete=False, suffix=".json", mode='w', encoding='utf-8')
+     json.dump(data, temp_json, indent=4)
+     temp_json_path = temp_json.name
+     temp_json.close()
+
+     print(f"Temporary input JSON saved to: {temp_json_path}")
+     return temp_json_path
+
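+ # For reference, the helper above produces MultiTalk's single-speaker input
+ # format, e.g.:
+ # {
+ #     "prompt": "A woman sings passionately in a dimly lit studio.",
+ #     "cond_image": "examples/single/single1.png",
+ #     "cond_audio": {"person1": "examples/single/1.wav"}
+ # }
+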
+
+ def infer(prompt, cond_image_path, cond_audio_path):
+     # Build the MultiTalk input JSON from the user-provided prompt, image, and audio
+     input_json_path = create_temp_input_json(prompt, cond_image_path, cond_audio_path)
+
+     cmd = [
+         "python3", "generate_multitalk.py",
+         "--ckpt_dir", "weights/Wan2.1-I2V-14B-480P",
+         "--wav2vec_dir", "weights/chinese-wav2vec2-base",
+         "--input_json", input_json_path,
+         "--sample_steps", "20",
+         "--num_persistent_param_in_dit", str(USED_VRAM_PARAMS),  # subprocess args must be strings
+         "--mode", "streaming",
+         "--use_teacache",
+         "--save_file", "multi_long_mediumvram_exp"
+     ]
+
+     subprocess.run(cmd, check=True)
+
+     # generate_multitalk.py writes the result to <save_file>.mp4
+     return "multi_long_mediumvram_exp.mp4"
+
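+ # Optional standalone smoke test using the example assets referenced in the
+ # original script (commented out so inference only runs through the Gradio UI):
+ # infer("A woman sings passionately in a dimly lit studio.",
+ #       "examples/single/single1.png", "examples/single/1.wav")
+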
+ import gradio as gr
+
+ with gr.Blocks(title="MultiTalk Inference") as demo:
+     gr.Markdown("## 🎤 MultiTalk Inference Demo")
+
+     with gr.Row():
+         with gr.Column():
+             prompt_input = gr.Textbox(
+                 label="Text Prompt",
+                 placeholder="Describe the scene...",
+                 lines=4
+             )
+
+             image_input = gr.Image(
+                 type="filepath",
+                 label="Conditioning Image"
+             )
+
+             audio_input = gr.Audio(
+                 type="filepath",
+                 label="Conditioning Audio (.wav)"
+             )
+
+             submit_btn = gr.Button("Generate")
+
+         with gr.Column():
+             output_video = gr.Video(label="Generated Video")
+
+     submit_btn.click(
+         fn=infer,
+         inputs=[prompt_input, image_input, audio_input],
+         outputs=output_video
+     )
+
+ demo.launch()
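+ # Note: for generations this long, a common Gradio pattern (not used in the
+ # original script) is to enable the request queue before launching, e.g.
+ # demo.queue().launch(), to avoid HTTP timeouts on slow inference.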