Spaces: Running on Zero

Update app.py

app.py CHANGED
@@ -34,6 +34,7 @@ from tqdm import tqdm
 import imageio
 import av
 import uuid
+import tempfile

 from pipeline import CausalInferencePipeline
 from demo_utils.constant import ZERO_VAE_CACHE
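The newly imported `tempfile` module is not used in the hunks shown below (the chunk paths stay uuid-based in `gradio_tmp/`). Purely as an illustration of what it could be used for, here is a minimal sketch of letting the OS allocate a unique `.ts` path; the helper and its use are assumptions, not part of this commit:

```python
# Illustrative only: allocate a unique .ts chunk path with tempfile instead of
# composing "block_{idx}_{uuid}.ts" by hand. The diff itself keeps the uuid scheme.
import os
import tempfile

os.makedirs("gradio_tmp", exist_ok=True)
fd, ts_path = tempfile.mkstemp(suffix=".ts", dir="gradio_tmp")
os.close(fd)  # the muxer (e.g. av.open) will reopen the path for writing
print(ts_path)
```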
@@ -68,7 +69,7 @@ T2V_CINEMATIC_PROMPT = \
 '''1. For overly concise user inputs, reasonably infer and add details to make the video more complete and appealing without altering the original intent;\n''' \
 '''2. Enhance the main features in user descriptions (e.g., appearance, expression, quantity, race, posture, etc.), visual style, spatial relationships, and shot scales;\n''' \
 '''3. Output the entire prompt in English, retaining original text in quotes and titles, and preserving key input information;\n''' \
-'''4. Prompts should match the user
+'''4. Prompts should match the user’s intent and accurately reflect the specified style. If the user does not specify a style, choose the most appropriate style for the video;\n''' \
 '''5. Emphasize motion information and different camera movements present in the input description;\n''' \
 '''6. Your output should have natural motion attributes. For the target category described, add natural actions of the target using simple and direct verbs;\n''' \
 '''7. The revised prompt should be around 80-100 words long.\n''' \
@@ -146,75 +147,58 @@ APP_STATE = {
     "fp8_applied": False,
     "current_use_taehv": False,
     "current_vae_decoder": None,
+    "current_frames": [],
 }

-
-DOWNLOAD_FRAMES = []
-
-def frames_to_mp4_chunk(frames, filepath, fps=15):
+def frames_to_ts_file(frames, filepath, fps = 15):
     """
-    Convert frames to
+    Convert frames directly to .ts file using PyAV.
+
+    Args:
+        frames: List of numpy arrays (HWC, RGB, uint8)
+        filepath: Output file path
+        fps: Frames per second
+
+    Returns:
+        The filepath of the created file
     """
     if not frames:
         return filepath

+    height, width = frames[0].shape[:2]
+
+    # Create container for MPEG-TS format
+    container = av.open(filepath, mode='w', format='mpegts')
+
+    # Add video stream with optimized settings for streaming
+    stream = container.add_stream('h264', rate=fps)
+    stream.width = width
+    stream.height = height
+    stream.pix_fmt = 'yuv420p'
+
+    # Optimize for low latency streaming
+    stream.options = {
+        'preset': 'ultrafast',
+        'tune': 'zerolatency',
+        'crf': '23',
+        'profile': 'baseline',
+        'level': '3.0'
+    }
+
     try:
-
-
-
-
-
-        return filepath
-
-    except Exception as e:
-        print(f"❌ Error creating MP4 chunk: {e}")
-        # Fallback to PyAV if imageio fails
-        try:
-            height, width = frames[0].shape[:2]
-            container = av.open(filepath, mode='w', format='mp4')
-
-            stream = container.add_stream('h264', rate=fps)
-            stream.width = width
-            stream.height = height
-            stream.pix_fmt = 'yuv420p'
-            stream.options = {
-                'preset': 'ultrafast',
-                'tune': 'zerolatency',
-                'crf': '28'
-            }
-
-            for frame_np in frames:
-                frame = av.VideoFrame.from_ndarray(frame_np, format='rgb24')
-                frame = frame.reformat(format=stream.pix_fmt)
-                for packet in stream.encode(frame):
-                    container.mux(packet)
-
-            for packet in stream.encode():
+        for frame_np in frames:
+            frame = av.VideoFrame.from_ndarray(frame_np, format='rgb24')
+            frame = frame.reformat(format=stream.pix_fmt)
+            for packet in stream.encode(frame):
                 container.mux(packet)
-
-
-
+
+        for packet in stream.encode():
+            container.mux(packet)

-
-
-
-
-def create_download_mp4():
-    global DOWNLOAD_FRAMES
-    if not DOWNLOAD_FRAMES:
-        return None
-    try:
-        os.makedirs("downloads", exist_ok=True)
-        timestamp = int(time.time())
-        mp4_path = f"downloads/video_{timestamp}.mp4"
-        with imageio.get_writer(mp4_path, fps=args.fps, codec='libx264', quality=8) as writer:
-            for frame in DOWNLOAD_FRAMES:
-                writer.append_data(frame)
-        print(f"✅ Download MP4 created: {mp4_path}")
-        return mp4_path
-    except Exception as e:
-        print(f"❌ Download error: {e}")
-        return None
+    finally:
+        container.close()
+
+    return filepath

 def initialize_vae_decoder(use_taehv=False, use_trt=False):
     if use_trt:
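A quick way to sanity-check the new chunk writer is to feed it synthetic frames and re-open the result with PyAV. This is a minimal sketch, not part of the commit; it assumes `frames_to_ts_file` as defined above is available in the current scope (for example, pasted into a scratch script), and the frame size is arbitrary:

```python
# Minimal sanity check for the MPEG-TS chunk writer (not part of the commit).
# Assumes frames_to_ts_file from above is in scope.
import av
import numpy as np

# 16 random 320x576 RGB frames, shaped (H, W, C) uint8 as the docstring requires.
frames = [np.random.randint(0, 256, (320, 576, 3), dtype=np.uint8) for _ in range(16)]
path = frames_to_ts_file(frames, "test_chunk.ts", fps=15)

# Re-open the chunk and confirm the stream really is H.264 in an MPEG-TS container.
with av.open(path) as container:
    stream = container.streams.video[0]
    decoded = sum(1 for _ in container.decode(stream))
    print(container.format.name, stream.codec_context.name, decoded)
```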
@@ -275,17 +259,15 @@ pipeline.to(dtype=torch.float16).to(gpu)

 @torch.no_grad()
 @spaces.GPU
-def video_generation_handler_streaming(prompt, seed=42, fps=15):
+def video_generation_handler_streaming(prompt, seed=42, fps=15, save_frames=True):
     """
-    Generator function that yields
+    Generator function that yields .ts video chunks using PyAV for streaming.
+    Now optimized for block-based processing.
     """
-    global DOWNLOAD_FRAMES
-    DOWNLOAD_FRAMES = []  # Reset frames
-
     if seed == -1:
         seed = random.randint(0, 2**32 - 1)

-    print(f"🎬 Starting
+    print(f"🎬 Starting PyAV streaming: '{prompt}', seed: {seed}")

     # Setup
     conditional_dict = text_encoder(text_prompts=[prompt])
@@ -372,13 +354,14 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
             frame_np = np.transpose(frame_np, (1, 2, 0))  # CHW -> HWC

             all_frames_from_block.append(frame_np)
-            DOWNLOAD_FRAMES.append(frame_np)  # Store for download
             total_frames_yielded += 1

-            # Yield status update for each frame
+            # Yield status update for each frame (cute tracking!)
             blocks_completed = idx
             current_block_progress = (frame_idx + 1) / pixels.shape[1]
             total_progress = (blocks_completed + current_block_progress) / num_blocks * 100
+
+            # Cap at 100% to avoid going over
             total_progress = min(total_progress, 100.0)

             frame_status_html = (
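The download path added later in this commit reads `APP_STATE["current_frames"]`, and the handler gains a `save_frames` flag, but the per-frame append that would populate that buffer is not visible in these hunks (only the old `DOWNLOAD_FRAMES.append` is removed). A hypothetical sketch of that bookkeeping, purely illustrative:

```python
# Hypothetical bookkeeping (not shown in this diff): mirror each decoded frame
# into a shared buffer so download_video(), wired up further below, has data.
import numpy as np

APP_STATE = {"current_frames": []}  # mirrors the key this commit adds to APP_STATE

def record_frame(frame_np, save_frames=True):
    """Append a decoded HWC uint8 frame to the download buffer when enabled."""
    if save_frames:
        APP_STATE["current_frames"].append(frame_np)

record_frame(np.zeros((320, 576, 3), dtype=np.uint8))
print(len(APP_STATE["current_frames"]))
```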
@@ -393,21 +376,25 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
                 f"</div>"
             )

+            # Yield None for video but update status (frame-by-frame tracking)
             yield None, frame_status_html

-        #
+        # Encode entire block as one chunk immediately
         if all_frames_from_block:
             print(f"📹 Encoding block {idx} with {len(all_frames_from_block)} frames")

             try:
                 chunk_uuid = str(uuid.uuid4())[:8]
-
-
+                ts_filename = f"block_{idx:04d}_{chunk_uuid}.ts"
+                ts_path = os.path.join("gradio_tmp", ts_filename)

-
+                frames_to_ts_file(all_frames_from_block, ts_path, fps)

-                #
-
+                # Calculate final progress for this block
+                total_progress = (idx + 1) / num_blocks * 100
+
+                # Yield the actual video chunk
+                yield ts_path, gr.update()

             except Exception as e:
                 print(f"⚠️ Error encoding block {idx}: {e}")
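The handler therefore yields `(video_chunk, status)` pairs: `None` plus HTML for per-frame status updates, and a `.ts` path plus `gr.update()` once a block is encoded. A non-Gradio consumer can drive it the same way, and because each MPEG-TS segment is written by its own muxer, segments can simply be byte-concatenated. A rough sketch, assuming the models behind the handler are already loaded and using a made-up prompt:

```python
# Drive the streaming generator outside of Gradio and stitch the chunks together.
# Assumes the Space's models are loaded; the prompt below is just an example.
chunk_paths = []
for video_chunk, status_html in video_generation_handler_streaming("a red fox running", seed=7, fps=15):
    if video_chunk is not None:      # per-block yield: path to a .ts segment
        chunk_paths.append(video_chunk)
    # else: per-frame yield, only the HTML status changed

# MPEG-TS segments can be concatenated at the byte level into one playable stream.
with open("full_video.ts", "wb") as out:
    for path in chunk_paths:
        with open(path, "rb") as segment:
            out.write(segment.read())
```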
@@ -428,13 +415,41 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
         f" 📊 Generated {total_frames_yielded} frames across {num_blocks} blocks"
         f" </p>"
         f" <p style='margin: 4px 0 0 0; color: #0f5132; font-size: 14px;'>"
-        f" 🎬 Playback: {fps} FPS • 📁 Format:
+        f" 🎬 Playback: {fps} FPS • 📁 Format: MPEG-TS/H.264"
         f" </p>"
         f" </div>"
         f"</div>"
     )
     yield None, final_status_html
-    print(f"
+    print(f" PyAV streaming complete! {total_frames_yielded} frames across {num_blocks} blocks")
+
+def save_frames_as_video(frames, fps=15):
+    """
+    Convert frames to a downloadable MP4 video file.
+
+    Args:
+        frames: List of numpy arrays (HWC, RGB, uint8)
+        fps: Frames per second
+
+    Returns:
+        Path to the saved video file
+    """
+    if not frames:
+        return None
+
+    # Create a temporary file with a unique name
+    temp_file = os.path.join("gradio_tmp", f"download_{uuid.uuid4()}.mp4")
+
+    # Use imageio to write the video file
+    try:
+        writer = imageio.get_writer(temp_file, fps=fps, codec='h264', quality=9)
+        for frame in frames:
+            writer.append_data(frame)
+        writer.close()
+        return temp_file
+    except Exception as e:
+        print(f"Error saving video: {e}")
+        return None

 # --- Gradio UI Layout ---
 with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
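Note that `save_frames_as_video` opens the imageio writer manually and only closes it on the success path, whereas the removed `create_download_mp4` used the writer as a context manager, which also covers the error path. A sketch of that variant under a hypothetical name:

```python
# Sketch: same output as save_frames_as_video, but the `with` block guarantees
# the writer is closed (and the file finalized) even if append_data raises.
import os
import uuid
import imageio

def save_frames_as_video_cm(frames, fps=15):
    if not frames:
        return None
    temp_file = os.path.join("gradio_tmp", f"download_{uuid.uuid4()}.mp4")
    try:
        with imageio.get_writer(temp_file, fps=fps, codec='h264', quality=9) as writer:
            for frame in frames:
                writer.append_data(frame)
        return temp_file
    except Exception as e:
        print(f"Error saving video: {e}")
        return None
```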
@@ -504,20 +519,31 @@ with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
                 label="Generation Status"
             )

-
-
-                label="📥 Download MP4",
-                value=create_download_mp4,
-                variant="secondary"
-            )
+            download_btn = gr.Button("💾 Download Video", variant="secondary")
+            download_output = gr.File(label="Download")

     # Connect the generator to the streaming video
     start_btn.click(
-        fn=video_generation_handler_streaming,
+        fn=lambda p, s, f: (APP_STATE.update({"current_frames": []}) or video_generation_handler_streaming(p, s, f)),
         inputs=[prompt, seed, fps],
         outputs=[streaming_video, status_display]
     )

+    # Function to handle download button click
+    def download_video(fps):
+        if not APP_STATE.get("current_frames"):
+            return None
+        video_path = save_frames_as_video(APP_STATE["current_frames"], fps)
+        return video_path
+
+    # Connect download button
+    download_btn.click(
+        fn=download_video,
+        inputs=[fps],
+        outputs=[download_output],
+        show_progress=True
+    )
+
     enhance_button.click(
         fn=enhance_prompt,
         inputs=[prompt],
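The lambda wrapper is worth a second look: Gradio decides whether to stream outputs by checking whether `fn` itself is a generator function, so a lambda that merely returns the generator object may not be treated as streaming. A small wrapper that resets the frame buffer while preserving generator semantics could look like the following sketch (the name `start_generation` is hypothetical; the other names come from this file):

```python
# Sketch of an alternative to the lambda: a real generator function that clears
# the shared frame buffer and then delegates to the streaming handler.
def start_generation(prompt, seed, fps):
    APP_STATE["current_frames"] = []
    yield from video_generation_handler_streaming(prompt, seed, fps)

start_btn.click(
    fn=start_generation,
    inputs=[prompt, seed, fps],
    outputs=[streaming_video, status_display],
)
```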
@@ -530,12 +556,10 @@ if __name__ == "__main__":
     import shutil
     shutil.rmtree("gradio_tmp")
     os.makedirs("gradio_tmp", exist_ok=True)
-    os.makedirs("downloads", exist_ok=True)

     print("🚀 Starting Self-Forcing Streaming Demo")
     print(f"📁 Temporary files will be stored in: gradio_tmp/")
-    print(f"
-    print(f"🎯 Chunk encoding: MP4/H.264 (more compatible)")
+    print(f"🎯 Chunk encoding: PyAV (MPEG-TS/H.264)")
     print(f"⚡ GPU acceleration: {gpu}")

     demo.queue().launch(
@@ -546,8 +570,6 @@ if __name__ == "__main__":
         max_threads=40,
         mcp_server=True
     )
-
-
 # import subprocess
 # subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)