svjack commited on
Commit
b742f12
·
verified ·
1 Parent(s): d203919

Update app_df.py

Browse files
Files changed (1) hide show
  1. app_df.py +224 -0
app_df.py CHANGED
@@ -4,6 +4,230 @@ pip uninstall -y torch torchvision xformers && pip install torch==2.5.0 torchvis
4
  pip install flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
5
  '''
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import os
8
  import gc
9
  import time
 
4
  pip install flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
5
  '''
6
 
7
+ '''
8
+ import os
9
+ import gc
10
+ import time
11
+ import random
12
+ import torch
13
+ import imageio
14
+ from diffusers.utils import load_image
15
+ from skyreels_v2_infer import DiffusionForcingPipeline
16
+ from skyreels_v2_infer.modules import download_model
17
+ from skyreels_v2_infer.pipelines import PromptEnhancer, resizecrop
18
+
19
# ---------------------
# Global initialisation (executed once, at import time)
# ---------------------

# Preset output sizes; each value is unpacked as (height, width) by the
# generation function.
RESOLUTION_CONFIG = {
    "540P": (544, 960),
    "720P": (720, 1280),
}

# Fixed negative prompt sent with every generation request.
negative_prompt = (
    "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, "
    "overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, "
    "poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, "
    "three legs, many people in the background, walking backwards"
)

# The checkpoint is only fetched when running as the shared UI.
is_shared_ui = True
if is_shared_ui:
    model_id = download_model("Skywork/SkyReels-V2-DF-1.3B-540P")
else:
    model_id = None

# Single pipeline instance, built once and reused by every generation call.
pipe = DiffusionForcingPipeline(
    model_id,
    dit_path=model_id,
    device=torch.device("cuda"),
    weight_dtype=torch.bfloat16,
    use_usp=False,
    offload=True,
)
49
+ # ---------------------
50
+ # 函数定义部分
51
+ # ---------------------
52
+
53
+
54
def _safe_filename_stem(text, max_chars=50):
    """Return the first ``max_chars`` characters of ``text`` with file-name-hostile characters removed."""
    # Path separators would redirect the write into another directory, and
    # non-printable characters (newlines, NUL, ...) are rejected or mangled by
    # some filesystems, so both are dropped.
    return "".join(
        ch for ch in text[:max_chars] if ch not in "/\\" and ch.isprintable()
    )


def generate_diffusion_forced_video(
    prompt,
    image=None,
    target_length="10",
    model_id="Skywork/SkyReels-V2-DF-1.3B-540P",
    resolution="540P",
    num_frames=257,
    ar_step=0,
    causal_attention=False,
    causal_block_size=1,
    base_num_frames=97,
    overlap_history=17,
    addnoise_condition=20,
    guidance_scale=6.0,
    shift=8.0,
    inference_steps=30,
    use_usp=False,
    offload=True,
    fps=24,
    seed=None,
    prompt_enhancer=False,
    teacache=True,
    teacache_thresh=0.2,
    use_ret_steps=True,
):
    """Generate a video with the module-level ``pipe`` and save it as an MP4.

    The pipeline is not rebuilt here; only per-call parameters are supplied.

    Args:
        prompt: Text prompt. Also used (truncated to 50 characters and
            sanitised) as the prefix of the output file name.
        image: Optional conditioning image, passed to ``load_image``. When it
            is taller than it is wide, the target height and width are swapped
            before it is resized and cropped.
        target_length: Accepted but not read by this function.
        model_id: Only forwarded as ``ckpt_dir`` to ``initialize_teacache``;
            it does not change which model the shared pipeline uses.
        resolution: Key of ``RESOLUTION_CONFIG``.
        num_frames: Number of frames requested from the pipeline.
        ar_step: Forwarded to the pipeline; when positive it also increases
            the step count given to TeaCache.
        causal_attention: When true, ``set_ar_attention(causal_block_size)``
            is called on the transformer.
        causal_block_size: Forwarded to the pipeline and used in the TeaCache
            step computation.
        base_num_frames: Forwarded to the pipeline; asking for more frames
            than this requires ``overlap_history``.
        overlap_history: Forwarded to the pipeline; must not be ``None`` when
            ``num_frames > base_num_frames``.
        addnoise_condition: Forwarded to the pipeline; values above 60 only
            trigger a printed warning.
        guidance_scale: Forwarded to the pipeline.
        shift: Forwarded to the pipeline.
        inference_steps: Passed as ``num_inference_steps``.
        use_usp: Accepted but not read by this function.
        offload: Accepted but not read by this function.
        fps: Forwarded to the pipeline and used when encoding the MP4.
        seed: Seed for the CUDA generator; a random one is drawn when ``None``.
        prompt_enhancer: When true and no image is given, the prompt is
            rewritten by ``PromptEnhancer`` before generation.
        teacache: When true, TeaCache is (re)initialised on the transformer.
        teacache_thresh: Forwarded to ``initialize_teacache``.
        use_ret_steps: Forwarded to ``initialize_teacache``.

    Returns:
        Path of the written MP4 file inside ``gradio_df_videos/``.

    Raises:
        ValueError: If ``resolution`` is unknown, or if a long video is
            requested without ``overlap_history``.
    """
    # Resolve the target size.
    if resolution not in RESOLUTION_CONFIG:
        raise ValueError(f"Invalid resolution: {resolution}")
    height, width = RESOLUTION_CONFIG[resolution]

    # Draw a seed from a private RNG so the process-wide `random` state is
    # not reseeded as a side effect of calling this function.
    if seed is None:
        seed = random.Random().randrange(4294967294)

    # Validate the long-video parameters.
    if num_frames > base_num_frames and overlap_history is None:
        raise ValueError("Specify `overlap_history` for long video generation. Try 17 or 37.")
    if addnoise_condition > 60:
        print("Warning: Large `addnoise_condition` may reduce consistency. Recommended: 20.")

    # Prepare the optional conditioning image.
    pil_image = None
    if image is not None:
        pil_image = load_image(image).convert("RGB")
        image_width, image_height = pil_image.size
        if image_height > image_width:
            # Portrait input: generate a portrait video.
            height, width = width, height
        pil_image = resizecrop(pil_image, height, width)

    # Optional prompt enhancement (text-to-video only).
    prompt_input = prompt
    if prompt_enhancer and pil_image is None:
        enhancer = PromptEnhancer()
        prompt_input = enhancer(prompt_input)
        # Drop the enhancer before generation and release cached GPU memory.
        del enhancer
        gc.collect()
        torch.cuda.empty_cache()

    # TeaCache initialisation (when enabled).
    # NOTE(review): `pipe` is shared between calls and nothing here switches
    # TeaCache or causal attention back off once an earlier call enabled
    # them -- confirm whether an explicit reset is needed.
    if teacache:
        if ar_step > 0:
            num_steps = (
                inference_steps + (((base_num_frames - 1) // 4 + 1) // causal_block_size - 1) * ar_step
            )
        else:
            num_steps = inference_steps
        pipe.transformer.initialize_teacache(
            enable_teacache=True,
            num_steps=num_steps,
            teacache_thresh=teacache_thresh,
            use_ret_steps=use_ret_steps,
            ckpt_dir=model_id,
        )

    # Optional causal attention.
    if causal_attention:
        pipe.transformer.set_ar_attention(causal_block_size)

    # Run the pipeline.
    with torch.amp.autocast("cuda", dtype=pipe.transformer.dtype), torch.no_grad():
        video_frames = pipe(
            prompt=prompt_input,
            negative_prompt=negative_prompt,
            image=pil_image,
            height=height,
            width=width,
            num_frames=num_frames,
            num_inference_steps=inference_steps,
            shift=shift,
            guidance_scale=guidance_scale,
            generator=torch.Generator(device="cuda").manual_seed(seed),
            overlap_history=overlap_history,
            addnoise_condition=addnoise_condition,
            base_num_frames=base_num_frames,
            ar_step=ar_step,
            causal_block_size=causal_block_size,
            fps=fps,
        )[0]

    # Save the video.
    os.makedirs("gradio_df_videos", exist_ok=True)
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    output_path = f"gradio_df_videos/{_safe_filename_stem(prompt)}_{seed}_{timestamp}.mp4"
    imageio.mimwrite(output_path, video_frames, fps=fps, quality=8, output_params=["-loglevel", "error"])
    return output_path
163
+
164
+ import os
165
+ from datasets import load_dataset
166
+ from PIL import Image
167
+ from diffusers.utils import load_image
168
+
169
# Load the dataset (train split).
dataset = load_dataset("svjack/Mavuika_PosterCraft_Product_Posters_WAV")["train"]

# Directory that receives the numbered videos and their prompt files.
output_dir = "Mavuika_generated_videos"
os.makedirs(output_dir, exist_ok=True)

# Generate one video per dataset row. A failure on one row is reported and
# the loop continues with the next row.
for idx, item in enumerate(dataset):
    # Named before the `try` so the `finally` clause can always clean it up.
    temp_input_path = f"temp_input_{idx:04d}.png"
    try:
        # Fetch the image and its prompt.
        pil_image = item["postercraft_image"]
        prompt = item["final_prompt"]

        # Write the image to a temporary file that is handed to
        # generate_diffusion_forced_video as a path.
        # NOTE(review): the fixed (544, 960) resize ignores the source aspect
        # ratio -- confirm this is intended for every row.
        pil_image.resize((544, 960)).save(temp_input_path)

        # Run the video generation.
        video_path = generate_diffusion_forced_video(
            prompt=prompt,
            image=temp_input_path,
            target_length="4",  # not read by the function; clip length comes from num_frames
            model_id="Skywork/SkyReels-V2-DF-1.3B-540P",
            resolution="540P",
            num_frames=97,
            ar_step=0,
            causal_attention=False,
            causal_block_size=1,
            base_num_frames=97,
            overlap_history=3,
            addnoise_condition=0,
            guidance_scale=6,
            shift=8,
            inference_steps=30,
            use_usp=False,
            offload=True,
            fps=24,
            seed=None,
            prompt_enhancer=False,
            teacache=True,
            teacache_thresh=0.2,
            use_ret_steps=True,
        )

        # Build the output paths.
        output_video_path = os.path.join(output_dir, f"{idx:04d}.mp4")
        output_txt_path = os.path.join(output_dir, f"{idx:04d}.txt")

        # Move the video into the output directory. os.replace overwrites an
        # existing file on every platform, so re-running the script does not
        # fail on rows that were already generated.
        os.replace(video_path, output_video_path)

        # Store the prompt next to the video.
        with open(output_txt_path, 'w', encoding='utf-8') as f:
            f.write(prompt)

        print(f"✅ 已生成并保存:{output_video_path}")

    except Exception as e:
        print(f"❌ 处理第 {idx} 张图片时出错: {e}")

    finally:
        # The temporary input image is no longer needed, whether or not
        # generation succeeded; cleanup is best-effort.
        if os.path.exists(temp_input_path):
            try:
                os.remove(temp_input_path)
            except OSError as cleanup_error:
                print(f"Warning: could not remove {temp_input_path}: {cleanup_error}")
229
+ '''
230
+
231
  import os
232
  import gc
233
  import time