tsi-org commited on
Commit
9f3e36b
·
verified ·
1 Parent(s): 32aeeee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -66
app.py CHANGED
@@ -264,15 +264,15 @@ pipeline.to(dtype=torch.float16).to(gpu)
264
 
265
  @torch.no_grad()
266
  @spaces.GPU
267
- def video_generation_handler_streaming(prompt, seed=42, fps=15, save_frames=True):
268
  """
269
- Generator function that yields .ts video chunks using PyAV for streaming.
270
- Now optimized for block-based processing.
271
  """
272
  if seed == -1:
273
  seed = random.randint(0, 2**32 - 1)
274
 
275
- print(f"🎬 Starting PyAV streaming: '{prompt}', seed: {seed}")
276
 
277
  # Setup
278
  conditional_dict = text_encoder(text_prompts=[prompt])
@@ -381,55 +381,56 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15, save_frames=True
381
  f"</div>"
382
  )
383
 
384
- # Yield None for video but update status (frame-by-frame tracking)
385
- yield None, frame_status_html
386
 
387
- # Encode entire block as one chunk immediately
388
  if all_frames_from_block:
389
- print(f"📹 Encoding block {idx} with {len(all_frames_from_block)} frames")
390
 
391
- try:
392
- chunk_uuid = str(uuid.uuid4())[:8]
393
- ts_filename = f"block_{idx:04d}_{chunk_uuid}.ts"
394
- ts_path = os.path.join("gradio_tmp", ts_filename)
395
-
396
- frames_to_ts_file(all_frames_from_block, ts_path, fps)
397
-
398
- # Calculate final progress for this block
399
- total_progress = (idx + 1) / num_blocks * 100
400
-
401
- # Yield the actual video chunk
402
- yield ts_path, gr.update()
403
-
404
- except Exception as e:
405
- print(f"⚠️ Error encoding block {idx}: {e}")
406
- import traceback
407
- traceback.print_exc()
408
 
409
  current_start_frame += current_num_frames
410
 
411
- # Final completion status
412
- final_status_html = (
413
- f"<div style='padding: 16px; border: 1px solid #198754; background: linear-gradient(135deg, #d1e7dd, #f8f9fa); border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);'>"
414
- f" <div style='display: flex; align-items: center; margin-bottom: 8px;'>"
415
- f" <span style='font-size: 24px; margin-right: 12px;'>🎉</span>"
416
- f" <h4 style='margin: 0; color: #0f5132; font-size: 18px;'>Stream Complete!</h4>"
417
- f" </div>"
418
- f" <div style='background: rgba(255,255,255,0.7); padding: 8px; border-radius: 4px;'>"
419
- f" <p style='margin: 0; color: #0f5132; font-weight: 500;'>"
420
- f" 📊 Generated {total_frames_yielded} frames across {num_blocks} blocks"
421
- f" </p>"
422
- f" <p style='margin: 4px 0 0 0; color: #0f5132; font-size: 14px;'>"
423
- f" 🎬 Playback: {fps} FPS • 📁 Format: MPEG-TS/H.264"
424
- f" </p>"
425
- f" </div>"
426
- f"</div>"
427
- )
428
- yield None, final_status_html
429
- print(f" PyAV streaming complete! {total_frames_yielded} frames across {num_blocks} blocks")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430
 
431
  # Function to save frames as downloadable video
432
- def save_frames_as_video(frames, fps=15):
433
  """
434
  Convert frames to a downloadable MP4 video file.
435
 
@@ -444,8 +445,8 @@ def save_frames_as_video(frames, fps=15):
444
  print("No frames available to save")
445
  return None
446
 
447
- # Create a temporary file with a unique name
448
- temp_file = os.path.join("gradio_tmp", f"download_{uuid.uuid4()}.mp4")
449
 
450
  # Use PyAV for better quality and reliability
451
  try:
@@ -549,16 +550,21 @@ with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
549
  )
550
 
551
  with gr.Column(scale=3):
552
- gr.Markdown("### 📺 Video Stream")
553
 
554
- streaming_video = gr.Video(
555
- label="Live Stream",
556
- streaming=True,
557
- loop=True,
558
  height=400,
559
- autoplay=True,
560
  show_label=False,
561
- format="mp4" # Use more stable mp4 format when possible
 
 
 
 
 
 
 
562
  )
563
 
564
  status_display = gr.HTML(
@@ -574,29 +580,34 @@ with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
574
  download_btn = gr.Button("💾 Download Video", variant="secondary")
575
  download_output = gr.File(label="Download")
576
 
577
- # Define a wrapper function to ensure we always return the expected outputs
578
- def safe_video_generator(p, s, f):
579
  # Clear frames from previous generation
580
  APP_STATE["current_frames"] = []
581
 
582
- # Call the generator function and ensure it returns both video and status
 
 
 
583
  try:
584
- for video_path, status_html in video_generation_handler_streaming(p, s, f):
585
- if video_path is not None:
586
- yield video_path, status_html
587
- else:
588
- yield None, status_html
 
589
  except Exception as e:
590
  import traceback
591
  print(f"Error in generator: {e}")
592
  traceback.print_exc()
593
- yield None, f"<div style='color: red; padding: 10px; border: 1px solid #ffcccc; border-radius: 5px;'>Error: {e}</div>"
 
594
 
595
- # Connect the generator to the streaming video
596
  start_btn.click(
597
- fn=safe_video_generator,
598
  inputs=[prompt, seed, fps],
599
- outputs=[streaming_video, status_display]
600
  )
601
 
602
  # Function to handle download button click
 
264
 
265
  @torch.no_grad()
266
  @spaces.GPU
267
+ def video_generation_handler_frame_by_frame(prompt, seed=42, fps=15, save_frames=True):
268
  """
269
+ Generator function that yields individual frames and status updates.
270
+ No streaming - just frame by frame display.
271
  """
272
  if seed == -1:
273
  seed = random.randint(0, 2**32 - 1)
274
 
275
+ print(f"🎬 Starting frame-by-frame generation: '{prompt}', seed: {seed}")
276
 
277
  # Setup
278
  conditional_dict = text_encoder(text_prompts=[prompt])
 
381
  f"</div>"
382
  )
383
 
384
+ # No streaming - show the current frame and update status
385
+ yield frame_np, frame_status_html
386
 
387
+ # Save frames for download without streaming
388
  if all_frames_from_block:
389
+ print(f"💹 Processed block {idx} with {len(all_frames_from_block)} frames")
390
 
391
+ # We already yielded each frame individually for display
392
+ # No need to encode video chunks for streaming anymore
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393
 
394
  current_start_frame += current_num_frames
395
 
396
+ # Generate final video preview if we have frames
397
+ if APP_STATE["current_frames"]:
398
+ # Create a temporary preview file
399
+ preview_file = os.path.join("gradio_tmp", f"preview_{uuid.uuid4()}.mp4")
400
+ try:
401
+ # Save a preview video file
402
+ save_frames_as_video(APP_STATE["current_frames"], fps, preview_file)
403
+
404
+ # Final completion status with success message
405
+ final_status_html = (
406
+ f"<div style='padding: 16px; border: 1px solid #198754; background: linear-gradient(135deg, #d1e7dd, #f8f9fa); border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);'>"
407
+ f" <div style='display: flex; align-items: center; margin-bottom: 8px;'>"
408
+ f" <span style='font-size: 24px; margin-right: 12px;'>🎉</span>"
409
+ f" <h4 style='margin: 0; color: #0f5132; font-size: 18px;'>Generation Complete!</h4>"
410
+ f" </div>"
411
+ f" <div style='background: rgba(255,255,255,0.7); padding: 8px; border-radius: 4px;'>"
412
+ f" <p style='margin: 0; color: #0f5132; font-weight: 500;'>"
413
+ f" 📈 Generated {total_frames_yielded} frames across {num_blocks} blocks"
414
+ f" </p>"
415
+ f" <p style='margin: 4px 0 0 0; color: #0f5132; font-size: 14px;'>"
416
+ f" 🎬 Preview available • Click Download to save as MP4"
417
+ f" </p>"
418
+ f" </div>"
419
+ f"</div>"
420
+ )
421
+
422
+ # Return the last frame and completion message along with final video
423
+ yield APP_STATE["current_frames"][-1], final_status_html, gr.update(value=preview_file, visible=True)
424
+ except Exception as e:
425
+ print(f"Error creating preview: {e}")
426
+ # Just return the last frame and completion message
427
+ final_status_html = f"<div style='color: green; padding: 10px;'>Generation complete! {total_frames_yielded} frames generated. Ready to download.</div>"
428
+ yield APP_STATE["current_frames"][-1], final_status_html, gr.update(visible=False)
429
+
430
+ print(f"✅ Generation complete! {total_frames_yielded} frames across {num_blocks} blocks")
431
 
432
  # Function to save frames as downloadable video
433
+ def save_frames_as_video(frames, fps=15, output_path=None):
434
  """
435
  Convert frames to a downloadable MP4 video file.
436
 
 
445
  print("No frames available to save")
446
  return None
447
 
448
+ # Create a temporary file with a unique name or use provided path
449
+ temp_file = output_path if output_path else os.path.join("gradio_tmp", f"download_{uuid.uuid4()}.mp4")
450
 
451
  # Use PyAV for better quality and reliability
452
  try:
 
550
  )
551
 
552
  with gr.Column(scale=3):
553
+ gr.Markdown("### 📺 Video Preview")
554
 
555
+ # Replace streaming video with image display
556
+ preview_image = gr.Image(
557
+ label="Current Frame",
 
558
  height=400,
 
559
  show_label=False,
560
+ )
561
+
562
+ # Add a non-streaming video component for final result preview
563
+ final_video = gr.Video(
564
+ label="Final Video Preview",
565
+ visible=False,
566
+ autoplay=True,
567
+ loop=True
568
  )
569
 
570
  status_display = gr.HTML(
 
580
  download_btn = gr.Button("💾 Download Video", variant="secondary")
581
  download_output = gr.File(label="Download")
582
 
583
+ # Define a wrapper function to ensure proper handling of outputs
584
+ def safe_frame_generator(p, s, f):
585
  # Clear frames from previous generation
586
  APP_STATE["current_frames"] = []
587
 
588
+ # Reset the final video display
589
+ yield None, None, gr.update(visible=False)
590
+
591
+ # Call the generator function and yield frames, status, and video updates
592
  try:
593
+ for frame, status_html, *video_update in video_generation_handler_frame_by_frame(p, s, f):
594
+ # Ensure we always have three outputs
595
+ if not video_update:
596
+ video_update = [gr.update(visible=False)] # Default - hide video
597
+
598
+ yield frame, status_html, video_update[0]
599
  except Exception as e:
600
  import traceback
601
  print(f"Error in generator: {e}")
602
  traceback.print_exc()
603
+ error_html = f"<div style='color: red; padding: 10px; border: 1px solid #ffcccc; border-radius: 5px;'>Error: {e}</div>"
604
+ yield None, error_html, gr.update(visible=False)
605
 
606
+ # Connect the generator to the image and status display
607
  start_btn.click(
608
+ fn=safe_frame_generator,
609
  inputs=[prompt, seed, fps],
610
+ outputs=[preview_image, status_display, final_video]
611
  )
612
 
613
  # Function to handle download button click