tsi-org commited on
Commit
0ebd095
·
verified ·
1 Parent(s): 8f69df8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -29
app.py CHANGED
@@ -653,31 +653,52 @@ with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
653
  label="Generation Status"
654
  )
655
 
656
- # Add download button and output
657
- download_btn = gr.Button("💾 Download Video", variant="secondary")
658
- download_output = gr.File(label="Download")
659
-
660
  # Define a wrapper function to ensure proper handling of outputs
661
  def safe_frame_generator(p, s, f):
662
  # Clear frames from previous generation
663
- APP_STATE.update({"current_frames": []})
664
 
665
  # Reset the final video display
666
  yield None, None, gr.update(visible=False)
667
 
 
 
 
668
  # Call the generator function and yield frames, status, and video updates
669
  try:
670
- for frame, status_html, *video_update in video_generation_handler_streaming(p, s, f):
671
- # Ensure we always have three outputs
672
- if not video_update:
673
- video_update = [gr.update(visible=False)] # Default - hide video
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
674
 
675
- yield frame, status_html, video_update[0]
676
  except Exception as e:
677
  import traceback
678
- print(f"Error in generator: {e}")
679
  traceback.print_exc()
680
- error_html = f"<div style='color: red; padding: 10px; border: 1px solid #ffcccc; border-radius: 5px;'>Error: {e}</div>"
681
  yield None, error_html, gr.update(visible=False)
682
 
683
  # Connect the generator to the streaming video
@@ -687,23 +708,7 @@ with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
687
  outputs=[streaming_video, status_display, final_video]
688
  )
689
 
690
- # Function to handle download button click
691
- def download_video(fps):
692
- if not APP_STATE.get("current_frames"):
693
- return None
694
- video_path = save_frames_as_video(APP_STATE["current_frames"], fps)
695
- return video_path
696
-
697
- # Connect download button
698
- download_btn.click(
699
- fn=download_video,
700
- inputs=[fps],
701
- outputs=[download_output],
702
- show_progress=True,
703
- api_name="download_video" # Make it accessible via API
704
- )
705
-
706
- # Make the FPS slider visible for download quality control
707
  fps.visible = True
708
 
709
  enhance_button.click(
 
653
  label="Generation Status"
654
  )
655
 
 
 
 
 
656
  # Define a wrapper function to ensure proper handling of outputs
657
  def safe_frame_generator(p, s, f):
658
  # Clear frames from previous generation
659
+ APP_STATE["current_frames"] = []
660
 
661
  # Reset the final video display
662
  yield None, None, gr.update(visible=False)
663
 
664
+ # Collect all frames from this generation
665
+ collected_frames = []
666
+
667
  # Call the generator function and yield frames, status, and video updates
668
  try:
669
+ # Set save_frames=True to explicitly ensure frames are collected
670
+ for frame, status_html in video_generation_handler_streaming(p, s, f, save_frames=True):
671
+ # Track frames for this specific session
672
+ if frame is not None and isinstance(frame, np.ndarray):
673
+ collected_frames.append(frame.copy())
674
+
675
+ # Show status update during generation
676
+ yield frame, status_html, gr.update(visible=False)
677
+
678
+ # Check if this is the final frame
679
+ if "Complete" in str(status_html) or "100%" in str(status_html):
680
+ # Create the final video
681
+ if collected_frames:
682
+ print(f"Creating final video from {len(collected_frames)} frames at {f} FPS")
683
+ temp_file = save_frames_as_video(collected_frames, f)
684
+ if temp_file:
685
+ # Save these frames as the current set
686
+ APP_STATE["current_frames"] = collected_frames
687
+ yield frame, status_html, gr.update(visible=True, value=temp_file)
688
+ else:
689
+ yield frame, status_html, gr.update(visible=False)
690
+
691
+ # Ensure final frame is properly handled
692
+ if collected_frames and "Complete" not in str(status_html) and "100%" not in str(status_html):
693
+ print(f"Generation complete, creating final video from {len(collected_frames)} frames")
694
+ temp_file = save_frames_as_video(collected_frames, f)
695
+ if temp_file:
696
+ yield frame, status_html, gr.update(visible=True, value=temp_file)
697
 
 
698
  except Exception as e:
699
  import traceback
 
700
  traceback.print_exc()
701
+ error_html = f"<div style='color: red; padding: 10px; border: 1px solid #ffcccc; border-radius: 5px;'>Error: {str(e)}</div>"
702
  yield None, error_html, gr.update(visible=False)
703
 
704
  # Connect the generator to the streaming video
 
708
  outputs=[streaming_video, status_display, final_video]
709
  )
710
 
711
+ # Make the FPS slider visible for video quality control
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
712
  fps.visible = True
713
 
714
  enhance_button.click(