hoangkha1810 commited on
Commit
15a723c
·
verified ·
1 Parent(s): 787c8db

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -32
app.py CHANGED
@@ -57,14 +57,14 @@ def generate_image(prompt, base="Realistic", motion="", step=1, progress=gr.Prog
57
  logger.info(f"Generating video with prompt: {prompt}, base: {base}, steps: {step}")
58
 
59
  try:
60
- # Tải checkpoint AnimateDiff-Lightning
61
  if step_loaded != step:
62
  repo = "ByteDance/AnimateDiff-Lightning"
63
  ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
64
  pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
65
  step_loaded = step
66
 
67
- # Tải mô hình cơ sở
68
  if base_loaded != base:
69
  pipe.unet.load_state_dict(
70
  torch.load(hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin"), map_location=device),
@@ -72,66 +72,59 @@ def generate_image(prompt, base="Realistic", motion="", step=1, progress=gr.Prog
72
  )
73
  base_loaded = base
74
 
75
- # Tải motion LoRA (tùy chọn)
76
  if motion_loaded != motion:
77
  try:
78
  pipe.unload_lora_weights()
79
- if motion != "":
80
  pipe.load_lora_weights(motion, adapter_name="motion")
81
  pipe.set_adapters(["motion"], [0.7])
82
  motion_loaded = motion
83
  except Exception as e:
84
- logger.warning(f"Failed to load LoRA weights: {str(e)}. Proceeding without LoRA.")
85
  motion_loaded = ""
86
 
87
  progress((0, step))
88
  def progress_callback(i, t, z):
89
  progress((i + 1, step))
90
 
91
- # Tối ưu hóa suy luận
92
  with torch.no_grad():
93
  output = pipe(
94
  prompt=prompt,
95
  guidance_scale=1.2,
96
  num_inference_steps=step,
97
- num_frames=32, # Tạo 32 khung hình, sẽ lặp lại để đạt 192
98
  callback=progress_callback,
99
  callback_steps=1,
100
  width=256,
101
  height=256
102
  )
103
 
104
- # Chuẩn bị khung hình
105
- frames = output.frames[0] # frames là list các numpy arrays
106
- target_fps = 24
107
- target_duration = 8 # 8 giây
108
- target_frames = target_fps * target_duration # 192 khung hình
109
-
110
- # Lặp lại khung hình để đạt 192
111
- current_frames = len(frames)
112
- if current_frames < target_frames:
113
- repeat_factor = (target_frames + current_frames - 1) // current_frames
114
- frames = np.repeat(frames, repeat_factor, axis=0)[:target_frames]
115
- elif current_frames > target_frames:
116
  frames = frames[:target_frames]
117
 
118
- logger.info(f"Generated {len(frames)} frames for video.")
119
-
120
- # Xuất video với imageio
121
  name = str(uuid.uuid4()).replace("-", "")
122
- path = os.path.join(output_dir, f"{name}.mp4")
123
- logger.info(f"Saving video to {path}")
124
- export_to_video(frames, path, fps=target_fps)
 
 
 
 
125
 
126
- # Kiểm tra file video
127
- if not os.path.exists(path):
128
- raise FileNotFoundError(f"Video file {path} was not created.")
129
-
130
- logger.info(f"Video generated successfully: {path}")
131
- return path
132
 
133
  except Exception as e:
134
- logger.error(f"Error generating video: {str(e)}")
135
  raise
136
 
137
  # Giao diện Gradio
 
57
  logger.info(f"Generating video with prompt: {prompt}, base: {base}, steps: {step}")
58
 
59
  try:
60
+ # Tải AnimateDiff Lightning checkpoint
61
  if step_loaded != step:
62
  repo = "ByteDance/AnimateDiff-Lightning"
63
  ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
64
  pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
65
  step_loaded = step
66
 
67
+ # Tải mô hình cơ sở nếu thay đổi
68
  if base_loaded != base:
69
  pipe.unet.load_state_dict(
70
  torch.load(hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin"), map_location=device),
 
72
  )
73
  base_loaded = base
74
 
75
+ # Tải motion LoRA nếu
76
  if motion_loaded != motion:
77
  try:
78
  pipe.unload_lora_weights()
79
+ if motion:
80
  pipe.load_lora_weights(motion, adapter_name="motion")
81
  pipe.set_adapters(["motion"], [0.7])
82
  motion_loaded = motion
83
  except Exception as e:
84
+ logger.warning(f"Failed to load motion LoRA: {e}")
85
  motion_loaded = ""
86
 
87
  progress((0, step))
88
  def progress_callback(i, t, z):
89
  progress((i + 1, step))
90
 
91
+ # Suy luận
92
  with torch.no_grad():
93
  output = pipe(
94
  prompt=prompt,
95
  guidance_scale=1.2,
96
  num_inference_steps=step,
97
+ num_frames=32,
98
  callback=progress_callback,
99
  callback_steps=1,
100
  width=256,
101
  height=256
102
  )
103
 
104
+ # Chuẩn hóa khung hình cho 8 giây
105
+ frames = output.frames[0]
106
+ fps = 24
107
+ target_frames = fps * 8
108
+ if len(frames) < target_frames:
109
+ frames = np.tile(frames, (target_frames // len(frames) + 1, 1, 1, 1))[:target_frames]
110
+ else:
 
 
 
 
 
111
  frames = frames[:target_frames]
112
 
113
+ # Tạo video
 
 
114
  name = str(uuid.uuid4()).replace("-", "")
115
+ video_path = os.path.join(output_dir, f"{name}.mp4")
116
+ export_to_video(frames, video_path, fps=fps)
117
+
118
+ if not os.path.exists(video_path):
119
+ raise FileNotFoundError("❌ Video was not created")
120
+
121
+ logger.info(f"✅ Video ready at {video_path}")
122
 
123
+ # Trả về đối tượng `gr.File` để có link tải được từ client/frontend
124
+ return gr.File(video_path, file_name=f"{name}.mp4", show_download_button=True)
 
 
 
 
125
 
126
  except Exception as e:
127
+ logger.error(f"Error in generation: {e}")
128
  raise
129
 
130
  # Giao diện Gradio