innoai committed on
Commit
bce8064
·
verified ·
1 Parent(s): bd4727a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -1
app.py CHANGED
@@ -169,6 +169,27 @@ ASPECT_RATIOS = {
169
  }
170
  }
171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  def frames_to_ts_file(frames, filepath, fps = 15):
173
  """
174
  Convert frames directly to .ts file using PyAV.
@@ -360,7 +381,8 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15, aspect_ratio="16
360
 
361
  vae_cache, latents_cache = None, None
362
  if not APP_STATE["current_use_taehv"] and not args.trt:
363
- vae_cache = [c.to(device=gpu, dtype=torch.float16) for c in ZERO_VAE_CACHE]
 
364
 
365
  num_blocks = 7
366
  current_start_frame = 0
 
169
  }
170
  }
171
 
172
def get_vae_cache_for_aspect_ratio(aspect_ratio, device, dtype):
    """
    Allocate a fresh list of zero-filled VAE cache tensors for an aspect ratio.

    The latent height and width are read from ``ASPECT_RATIOS[aspect_ratio]``
    (keys ``"latent_h"`` and ``"latent_w"``). Four tensors are created on
    ``device`` with ``dtype`` and returned in a list, ordered from the
    smallest spatial size (latent size // 8) up to the full latent size.

    Raises ``KeyError`` if ``aspect_ratio`` is not present in
    ``ASPECT_RATIOS`` or the entry lacks the latent size keys.

    NOTE(review): this helper does not read ZERO_VAE_CACHE; the channel
    counts and spatial divisors below are hard-coded. Confirm they match
    the cache layout the VAE decoder actually expects.
    """
    dims = ASPECT_RATIOS[aspect_ratio]
    height, width = dims["latent_h"], dims["latent_w"]

    # One (channels, spatial divisor) pair per cache entry, coarsest first.
    layout = (
        (512, 8),  # 8x downsampled
        (512, 4),  # 4x downsampled
        (256, 2),  # 2x downsampled
        (128, 1),  # same spatial size as the latent
    )

    # Zero-filled placeholders; per the original comments these are expected
    # to be updated during generation.
    return [
        torch.zeros(
            1,
            channels,
            height // divisor,
            width // divisor,
            device=device,
            dtype=dtype,
        )
        for channels, divisor in layout
    ]
192
+
193
  def frames_to_ts_file(frames, filepath, fps = 15):
194
  """
195
  Convert frames directly to .ts file using PyAV.
 
381
 
382
  vae_cache, latents_cache = None, None
383
  if not APP_STATE["current_use_taehv"] and not args.trt:
384
+ # Create VAE cache with correct dimensions for the aspect ratio
385
+ vae_cache = get_vae_cache_for_aspect_ratio(aspect_ratio, gpu, torch.float16)
386
 
387
  num_blocks = 7
388
  current_start_frame = 0