tsi-org committed on
Commit
54e6494
·
verified ·
1 Parent(s): a30355f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -88
app.py CHANGED
@@ -34,6 +34,7 @@ from tqdm import tqdm
34
  import imageio
35
  import av
36
  import uuid
 
37
 
38
  from pipeline import CausalInferencePipeline
39
  from demo_utils.constant import ZERO_VAE_CACHE
@@ -68,7 +69,7 @@ T2V_CINEMATIC_PROMPT = \
68
  '''1. For overly concise user inputs, reasonably infer and add details to make the video more complete and appealing without altering the original intent;\n''' \
69
  '''2. Enhance the main features in user descriptions (e.g., appearance, expression, quantity, race, posture, etc.), visual style, spatial relationships, and shot scales;\n''' \
70
  '''3. Output the entire prompt in English, retaining original text in quotes and titles, and preserving key input information;\n''' \
71
- '''4. Prompts should match the user's intent and accurately reflect the specified style. If the user does not specify a style, choose the most appropriate style for the video;\n''' \
72
  '''5. Emphasize motion information and different camera movements present in the input description;\n''' \
73
  '''6. Your output should have natural motion attributes. For the target category described, add natural actions of the target using simple and direct verbs;\n''' \
74
  '''7. The revised prompt should be around 80-100 words long.\n''' \
@@ -146,75 +147,58 @@ APP_STATE = {
146
  "fp8_applied": False,
147
  "current_use_taehv": False,
148
  "current_vae_decoder": None,
 
149
  }
150
 
151
- # Store frames for download
152
- DOWNLOAD_FRAMES = []
153
-
154
def frames_to_mp4_chunk(frames, filepath, fps=15):
    """
    Encode frames into an MP4 chunk at ``filepath``.

    imageio is tried first (more compatible than .ts for Gradio streaming);
    if it raises, the same frames are encoded with PyAV instead.

    Args:
        frames: Sequence of frame arrays. The PyAV fallback reads
            ``frames[0].shape[:2]`` as (height, width) and feeds each frame
            to ``av.VideoFrame.from_ndarray(..., format='rgb24')``.
        filepath: Output file path.
        fps: Frames per second for the encoded chunk.

    Returns:
        ``filepath`` in every case, including when ``frames`` is empty or
        both encoders fail. Failures are printed rather than raised, so the
        file may be missing or incomplete in those cases.
    """
    if not frames:
        return filepath

    try:
        # Use imageio to create MP4 chunk
        with imageio.get_writer(filepath, fps=fps, codec='libx264', quality=6) as writer:
            for frame_np in frames:
                writer.append_data(frame_np)
        return filepath
    except Exception as e:
        print(f"❌ Error creating MP4 chunk: {e}")

    # Fallback to PyAV if imageio fails
    try:
        height, width = frames[0].shape[:2]
        container = av.open(filepath, mode='w', format='mp4')
        try:
            stream = container.add_stream('h264', rate=fps)
            stream.width = width
            stream.height = height
            stream.pix_fmt = 'yuv420p'
            stream.options = {
                'preset': 'ultrafast',
                'tune': 'zerolatency',
                'crf': '28'
            }

            for frame_np in frames:
                frame = av.VideoFrame.from_ndarray(frame_np, format='rgb24')
                frame = frame.reformat(format=stream.pix_fmt)
                for packet in stream.encode(frame):
                    container.mux(packet)

            # Calling encode() with no frame flushes the encoder's buffered packets.
            for packet in stream.encode():
                container.mux(packet)
        finally:
            # Close even when encoding fails so the output handle is not leaked.
            container.close()
    except Exception as e2:
        print(f"❌ Both imageio and PyAV failed: {e2}")

    return filepath
201
-
202
def create_download_mp4():
    """
    Write every frame in DOWNLOAD_FRAMES to a timestamped MP4 under downloads/.

    Returns:
        Path of the written file, or None when DOWNLOAD_FRAMES is empty or
        writing fails (the error is printed, not raised).
    """
    if not DOWNLOAD_FRAMES:
        return None

    try:
        os.makedirs("downloads", exist_ok=True)
        # File name carries the current time truncated to whole seconds.
        timestamp = int(time.time())
        mp4_path = f"downloads/video_{timestamp}.mp4"
        writer = imageio.get_writer(mp4_path, fps=args.fps, codec='libx264', quality=8)
        with writer:
            for frame in DOWNLOAD_FRAMES:
                writer.append_data(frame)
        print(f"✅ Download MP4 created: {mp4_path}")
        return mp4_path
    except Exception as e:
        print(f"❌ Download error: {e}")
        return None
218
 
219
  def initialize_vae_decoder(use_taehv=False, use_trt=False):
220
  if use_trt:
@@ -275,17 +259,15 @@ pipeline.to(dtype=torch.float16).to(gpu)
275
 
276
  @torch.no_grad()
277
  @spaces.GPU
278
- def video_generation_handler_streaming(prompt, seed=42, fps=15):
279
  """
280
- Generator function that yields MP4 video chunks for streaming.
 
281
  """
282
- global DOWNLOAD_FRAMES
283
- DOWNLOAD_FRAMES = [] # Reset frames
284
-
285
  if seed == -1:
286
  seed = random.randint(0, 2**32 - 1)
287
 
288
- print(f"🎬 Starting MP4 streaming: '{prompt}', seed: {seed}")
289
 
290
  # Setup
291
  conditional_dict = text_encoder(text_prompts=[prompt])
@@ -372,13 +354,14 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
372
  frame_np = np.transpose(frame_np, (1, 2, 0)) # CHW -> HWC
373
 
374
  all_frames_from_block.append(frame_np)
375
- DOWNLOAD_FRAMES.append(frame_np) # Store for download
376
  total_frames_yielded += 1
377
 
378
- # Yield status update for each frame
379
  blocks_completed = idx
380
  current_block_progress = (frame_idx + 1) / pixels.shape[1]
381
  total_progress = (blocks_completed + current_block_progress) / num_blocks * 100
 
 
382
  total_progress = min(total_progress, 100.0)
383
 
384
  frame_status_html = (
@@ -393,21 +376,25 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
393
  f"</div>"
394
  )
395
 
 
396
  yield None, frame_status_html
397
 
398
- # Create MP4 chunk for this block
399
  if all_frames_from_block:
400
  print(f"📹 Encoding block {idx} with {len(all_frames_from_block)} frames")
401
 
402
  try:
403
  chunk_uuid = str(uuid.uuid4())[:8]
404
- mp4_filename = f"block_{idx:04d}_{chunk_uuid}.mp4"
405
- mp4_path = os.path.join("gradio_tmp", mp4_filename)
406
 
407
- frames_to_mp4_chunk(all_frames_from_block, mp4_path, fps)
408
 
409
- # Yield the MP4 chunk
410
- yield mp4_path, gr.update()
 
 
 
411
 
412
  except Exception as e:
413
  print(f"⚠️ Error encoding block {idx}: {e}")
@@ -428,13 +415,41 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
428
  f" 📊 Generated {total_frames_yielded} frames across {num_blocks} blocks"
429
  f" </p>"
430
  f" <p style='margin: 4px 0 0 0; color: #0f5132; font-size: 14px;'>"
431
- f" 🎬 Playback: {fps} FPS • 📁 Format: MP4/H.264 • 📥 Download ready!"
432
  f" </p>"
433
  f" </div>"
434
  f"</div>"
435
  )
436
  yield None, final_status_html
437
- print(f" MP4 streaming complete! {total_frames_yielded} frames across {num_blocks} blocks")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
438
 
439
  # --- Gradio UI Layout ---
440
  with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
@@ -504,20 +519,31 @@ with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
504
  label="Generation Status"
505
  )
506
 
507
- # Download button
508
- download_btn = gr.DownloadButton(
509
- label="📥 Download MP4",
510
- value=create_download_mp4,
511
- variant="secondary"
512
- )
513
 
514
  # Connect the generator to the streaming video
515
  start_btn.click(
516
- fn=video_generation_handler_streaming,
517
  inputs=[prompt, seed, fps],
518
  outputs=[streaming_video, status_display]
519
  )
520
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
521
  enhance_button.click(
522
  fn=enhance_prompt,
523
  inputs=[prompt],
@@ -530,12 +556,10 @@ if __name__ == "__main__":
530
  import shutil
531
  shutil.rmtree("gradio_tmp")
532
  os.makedirs("gradio_tmp", exist_ok=True)
533
- os.makedirs("downloads", exist_ok=True)
534
 
535
  print("🚀 Starting Self-Forcing Streaming Demo")
536
  print(f"📁 Temporary files will be stored in: gradio_tmp/")
537
- print(f"📥 Download files will be stored in: downloads/")
538
- print(f"🎯 Chunk encoding: MP4/H.264 (more compatible)")
539
  print(f"⚡ GPU acceleration: {gpu}")
540
 
541
  demo.queue().launch(
@@ -546,8 +570,6 @@ if __name__ == "__main__":
546
  max_threads=40,
547
  mcp_server=True
548
  )
549
-
550
-
551
  # import subprocess
552
  # subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
553
 
 
34
  import imageio
35
  import av
36
  import uuid
37
+ import tempfile
38
 
39
  from pipeline import CausalInferencePipeline
40
  from demo_utils.constant import ZERO_VAE_CACHE
 
69
  '''1. For overly concise user inputs, reasonably infer and add details to make the video more complete and appealing without altering the original intent;\n''' \
70
  '''2. Enhance the main features in user descriptions (e.g., appearance, expression, quantity, race, posture, etc.), visual style, spatial relationships, and shot scales;\n''' \
71
  '''3. Output the entire prompt in English, retaining original text in quotes and titles, and preserving key input information;\n''' \
72
+ '''4. Prompts should match the users intent and accurately reflect the specified style. If the user does not specify a style, choose the most appropriate style for the video;\n''' \
73
  '''5. Emphasize motion information and different camera movements present in the input description;\n''' \
74
  '''6. Your output should have natural motion attributes. For the target category described, add natural actions of the target using simple and direct verbs;\n''' \
75
  '''7. The revised prompt should be around 80-100 words long.\n''' \
 
147
  "fp8_applied": False,
148
  "current_use_taehv": False,
149
  "current_vae_decoder": None,
150
+ "current_frames": [],
151
  }
152
 
153
def frames_to_ts_file(frames, filepath, fps=15):
    """
    Convert frames directly to .ts file using PyAV.

    Args:
        frames: List of numpy arrays (HWC, RGB, uint8)
        filepath: Output file path
        fps: Frames per second

    Returns:
        The filepath of the created file. When ``frames`` is empty nothing
        is written and ``filepath`` is returned unchanged.

    Raises:
        Any exception raised by PyAV while opening, encoding or muxing is
        propagated to the caller after the container has been closed.
    """
    if not frames:
        return filepath

    height, width = frames[0].shape[:2]

    # Create container for MPEG-TS format
    container = av.open(filepath, mode='w', format='mpegts')
    try:
        # Stream setup lives inside the try so that a failure here still
        # closes the container instead of leaking the open file handle.
        stream = container.add_stream('h264', rate=fps)
        stream.width = width
        stream.height = height
        stream.pix_fmt = 'yuv420p'

        # Optimize for low latency streaming
        stream.options = {
            'preset': 'ultrafast',
            'tune': 'zerolatency',
            'crf': '23',
            'profile': 'baseline',
            'level': '3.0'
        }

        for frame_np in frames:
            frame = av.VideoFrame.from_ndarray(frame_np, format='rgb24')
            frame = frame.reformat(format=stream.pix_fmt)
            for packet in stream.encode(frame):
                container.mux(packet)

        # Calling encode() with no frame flushes the encoder's buffered packets.
        for packet in stream.encode():
            container.mux(packet)
    finally:
        container.close()

    return filepath
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
  def initialize_vae_decoder(use_taehv=False, use_trt=False):
204
  if use_trt:
 
259
 
260
  @torch.no_grad()
261
  @spaces.GPU
262
+ def video_generation_handler_streaming(prompt, seed=42, fps=15, save_frames=True):
263
  """
264
+ Generator function that yields .ts video chunks using PyAV for streaming.
265
+ Now optimized for block-based processing.
266
  """
 
 
 
267
  if seed == -1:
268
  seed = random.randint(0, 2**32 - 1)
269
 
270
+ print(f"🎬 Starting PyAV streaming: '{prompt}', seed: {seed}")
271
 
272
  # Setup
273
  conditional_dict = text_encoder(text_prompts=[prompt])
 
354
  frame_np = np.transpose(frame_np, (1, 2, 0)) # CHW -> HWC
355
 
356
  all_frames_from_block.append(frame_np)
 
357
  total_frames_yielded += 1
358
 
359
+ # Yield status update for each frame (cute tracking!)
360
  blocks_completed = idx
361
  current_block_progress = (frame_idx + 1) / pixels.shape[1]
362
  total_progress = (blocks_completed + current_block_progress) / num_blocks * 100
363
+
364
+ # Cap at 100% to avoid going over
365
  total_progress = min(total_progress, 100.0)
366
 
367
  frame_status_html = (
 
376
  f"</div>"
377
  )
378
 
379
+ # Yield None for video but update status (frame-by-frame tracking)
380
  yield None, frame_status_html
381
 
382
+ # Encode entire block as one chunk immediately
383
  if all_frames_from_block:
384
  print(f"📹 Encoding block {idx} with {len(all_frames_from_block)} frames")
385
 
386
  try:
387
  chunk_uuid = str(uuid.uuid4())[:8]
388
+ ts_filename = f"block_{idx:04d}_{chunk_uuid}.ts"
389
+ ts_path = os.path.join("gradio_tmp", ts_filename)
390
 
391
+ frames_to_ts_file(all_frames_from_block, ts_path, fps)
392
 
393
+ # Calculate final progress for this block
394
+ total_progress = (idx + 1) / num_blocks * 100
395
+
396
+ # Yield the actual video chunk
397
+ yield ts_path, gr.update()
398
 
399
  except Exception as e:
400
  print(f"⚠️ Error encoding block {idx}: {e}")
 
415
  f" 📊 Generated {total_frames_yielded} frames across {num_blocks} blocks"
416
  f" </p>"
417
  f" <p style='margin: 4px 0 0 0; color: #0f5132; font-size: 14px;'>"
418
+ f" 🎬 Playback: {fps} FPS • 📁 Format: MPEG-TS/H.264"
419
  f" </p>"
420
  f" </div>"
421
  f"</div>"
422
  )
423
  yield None, final_status_html
424
+ print(f" PyAV streaming complete! {total_frames_yielded} frames across {num_blocks} blocks")
425
+
426
def save_frames_as_video(frames, fps=15):
    """
    Convert frames to a downloadable MP4 video file.

    Args:
        frames: List of numpy arrays (HWC, RGB, uint8)
        fps: Frames per second

    Returns:
        Path to the saved video file under ``gradio_tmp``, or None when
        ``frames`` is empty or writing fails (the error is printed).
    """
    if not frames:
        return None

    # Create a temporary file with a unique name
    temp_file = os.path.join("gradio_tmp", f"download_{uuid.uuid4()}.mp4")

    try:
        # The context manager closes the writer even when a frame fails to
        # encode, so the underlying encoder/file handle is never leaked.
        with imageio.get_writer(temp_file, fps=fps, codec='h264', quality=9) as writer:
            for frame in frames:
                writer.append_data(frame)
    except Exception as e:
        print(f"Error saving video: {e}")
        # Best-effort cleanup: a truncated file is useless to the caller,
        # which only ever receives None on this path.
        try:
            os.remove(temp_file)
        except OSError:
            pass
        return None

    return temp_file
453
 
454
  # --- Gradio UI Layout ---
455
  with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
 
519
  label="Generation Status"
520
  )
521
 
522
+ download_btn = gr.Button("💾 Download Video", variant="secondary")
523
+ download_output = gr.File(label="Download")
 
 
 
 
524
 
525
  # Connect the generator to the streaming video
526
  start_btn.click(
527
+ fn=lambda p, s, f: (APP_STATE.update({"current_frames": []}) or video_generation_handler_streaming(p, s, f)),
528
  inputs=[prompt, seed, fps],
529
  outputs=[streaming_video, status_display]
530
  )
531
 
532
+ # Function to handle download button click
533
def download_video(fps):
    """
    Click handler for the download button.

    Passes the frames stored under APP_STATE["current_frames"] to
    save_frames_as_video and returns its result; returns None when no
    frames are stored.
    """
    stored_frames = APP_STATE.get("current_frames")
    if stored_frames:
        return save_frames_as_video(stored_frames, fps)
    return None
538
+
539
+ # Connect download button
540
+ download_btn.click(
541
+ fn=download_video,
542
+ inputs=[fps],
543
+ outputs=[download_output],
544
+ show_progress=True
545
+ )
546
+
547
  enhance_button.click(
548
  fn=enhance_prompt,
549
  inputs=[prompt],
 
556
  import shutil
557
  shutil.rmtree("gradio_tmp")
558
  os.makedirs("gradio_tmp", exist_ok=True)
 
559
 
560
  print("🚀 Starting Self-Forcing Streaming Demo")
561
  print(f"📁 Temporary files will be stored in: gradio_tmp/")
562
+ print(f"🎯 Chunk encoding: PyAV (MPEG-TS/H.264)")
 
563
  print(f"⚡ GPU acceleration: {gpu}")
564
 
565
  demo.queue().launch(
 
570
  max_threads=40,
571
  mcp_server=True
572
  )
 
 
573
  # import subprocess
574
  # subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
575