Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -57,14 +57,14 @@ def generate_image(prompt, base="Realistic", motion="", step=1, progress=gr.Prog
|
|
57 |
logger.info(f"Generating video with prompt: {prompt}, base: {base}, steps: {step}")
|
58 |
|
59 |
try:
|
60 |
-
# Tải
|
61 |
if step_loaded != step:
|
62 |
repo = "ByteDance/AnimateDiff-Lightning"
|
63 |
ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
|
64 |
pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
|
65 |
step_loaded = step
|
66 |
|
67 |
-
# Tải mô hình cơ sở
|
68 |
if base_loaded != base:
|
69 |
pipe.unet.load_state_dict(
|
70 |
torch.load(hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin"), map_location=device),
|
@@ -72,66 +72,59 @@ def generate_image(prompt, base="Realistic", motion="", step=1, progress=gr.Prog
|
|
72 |
)
|
73 |
base_loaded = base
|
74 |
|
75 |
-
# Tải motion LoRA
|
76 |
if motion_loaded != motion:
|
77 |
try:
|
78 |
pipe.unload_lora_weights()
|
79 |
-
if motion
|
80 |
pipe.load_lora_weights(motion, adapter_name="motion")
|
81 |
pipe.set_adapters(["motion"], [0.7])
|
82 |
motion_loaded = motion
|
83 |
except Exception as e:
|
84 |
-
logger.warning(f"Failed to load LoRA
|
85 |
motion_loaded = ""
|
86 |
|
87 |
progress((0, step))
|
88 |
def progress_callback(i, t, z):
|
89 |
progress((i + 1, step))
|
90 |
|
91 |
-
#
|
92 |
with torch.no_grad():
|
93 |
output = pipe(
|
94 |
prompt=prompt,
|
95 |
guidance_scale=1.2,
|
96 |
num_inference_steps=step,
|
97 |
-
num_frames=32,
|
98 |
callback=progress_callback,
|
99 |
callback_steps=1,
|
100 |
width=256,
|
101 |
height=256
|
102 |
)
|
103 |
|
104 |
-
# Chuẩn
|
105 |
-
frames = output.frames[0]
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
current_frames = len(frames)
|
112 |
-
if current_frames < target_frames:
|
113 |
-
repeat_factor = (target_frames + current_frames - 1) // current_frames
|
114 |
-
frames = np.repeat(frames, repeat_factor, axis=0)[:target_frames]
|
115 |
-
elif current_frames > target_frames:
|
116 |
frames = frames[:target_frames]
|
117 |
|
118 |
-
|
119 |
-
|
120 |
-
# Xuất video với imageio
|
121 |
name = str(uuid.uuid4()).replace("-", "")
|
122 |
-
|
123 |
-
|
124 |
-
|
|
|
|
|
|
|
|
|
125 |
|
126 |
-
#
|
127 |
-
|
128 |
-
raise FileNotFoundError(f"Video file {path} was not created.")
|
129 |
-
|
130 |
-
logger.info(f"Video generated successfully: {path}")
|
131 |
-
return path
|
132 |
|
133 |
except Exception as e:
|
134 |
-
logger.error(f"Error
|
135 |
raise
|
136 |
|
137 |
# Giao diện Gradio
|
|
|
57 |
logger.info(f"Generating video with prompt: {prompt}, base: {base}, steps: {step}")
|
58 |
|
59 |
try:
|
60 |
+
# Tải AnimateDiff Lightning checkpoint
|
61 |
if step_loaded != step:
|
62 |
repo = "ByteDance/AnimateDiff-Lightning"
|
63 |
ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
|
64 |
pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
|
65 |
step_loaded = step
|
66 |
|
67 |
+
# Tải mô hình cơ sở nếu thay đổi
|
68 |
if base_loaded != base:
|
69 |
pipe.unet.load_state_dict(
|
70 |
torch.load(hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin"), map_location=device),
|
|
|
72 |
)
|
73 |
base_loaded = base
|
74 |
|
75 |
+
# Tải motion LoRA nếu có
|
76 |
if motion_loaded != motion:
|
77 |
try:
|
78 |
pipe.unload_lora_weights()
|
79 |
+
if motion:
|
80 |
pipe.load_lora_weights(motion, adapter_name="motion")
|
81 |
pipe.set_adapters(["motion"], [0.7])
|
82 |
motion_loaded = motion
|
83 |
except Exception as e:
|
84 |
+
logger.warning(f"Failed to load motion LoRA: {e}")
|
85 |
motion_loaded = ""
|
86 |
|
87 |
progress((0, step))
|
88 |
def progress_callback(i, t, z):
|
89 |
progress((i + 1, step))
|
90 |
|
91 |
+
# Suy luận
|
92 |
with torch.no_grad():
|
93 |
output = pipe(
|
94 |
prompt=prompt,
|
95 |
guidance_scale=1.2,
|
96 |
num_inference_steps=step,
|
97 |
+
num_frames=32,
|
98 |
callback=progress_callback,
|
99 |
callback_steps=1,
|
100 |
width=256,
|
101 |
height=256
|
102 |
)
|
103 |
|
104 |
+
# Chuẩn hóa khung hình cho 8 giây
|
105 |
+
frames = output.frames[0]
|
106 |
+
fps = 24
|
107 |
+
target_frames = fps * 8
|
108 |
+
if len(frames) < target_frames:
|
109 |
+
frames = np.tile(frames, (target_frames // len(frames) + 1, 1, 1, 1))[:target_frames]
|
110 |
+
else:
|
|
|
|
|
|
|
|
|
|
|
111 |
frames = frames[:target_frames]
|
112 |
|
113 |
+
# Tạo video
|
|
|
|
|
114 |
name = str(uuid.uuid4()).replace("-", "")
|
115 |
+
video_path = os.path.join(output_dir, f"{name}.mp4")
|
116 |
+
export_to_video(frames, video_path, fps=fps)
|
117 |
+
|
118 |
+
if not os.path.exists(video_path):
|
119 |
+
raise FileNotFoundError("❌ Video was not created")
|
120 |
+
|
121 |
+
logger.info(f"✅ Video ready at {video_path}")
|
122 |
|
123 |
+
# ✅ Trả về đối tượng `gr.File` để có link tải được từ client/frontend
|
124 |
+
return gr.File(video_path, file_name=f"{name}.mp4", show_download_button=True)
|
|
|
|
|
|
|
|
|
125 |
|
126 |
except Exception as e:
|
127 |
+
logger.error(f"❌ Error in generation: {e}")
|
128 |
raise
|
129 |
|
130 |
# Giao diện Gradio
|