KevinNg99 commited on
Commit
55b359a
·
1 Parent(s): 1a63574

fix resolution

Browse files
Files changed (1) hide show
  1. app.py +45 -9
app.py CHANGED
@@ -200,6 +200,15 @@ class HunyuanImageApp:
200
 
201
  self.pipeline = pipeline
202
  self.current_use_distilled = None
 
 
 
 
 
 
 
 
 
203
 
204
 
205
  def print_peak_memory(self):
@@ -207,6 +216,16 @@ class HunyuanImageApp:
207
  stats = torch.cuda.memory_stats()
208
  peak_bytes_requirement = stats["allocated_bytes.all.peak"]
209
  print(f"Before refiner Peak memory requirement: {peak_bytes_requirement / 1024 ** 3:.2f} GB")
 
 
 
 
 
 
 
 
 
 
210
 
211
  @space_context(duration=300)
212
  def generate_image(self,
@@ -423,15 +442,25 @@ def create_interface(auto_load: bool = True, use_distilled: bool = False, device
423
  value=""
424
  )
425
 
426
- with gr.Row():
427
- width = gr.Slider(
428
- minimum=512, maximum=2048, step=64, value=2048,
429
- label="Width", info="Image width in pixels"
430
- )
431
- height = gr.Slider(
432
- minimum=512, maximum=2048, step=64, value=2048,
433
- label="Height", info="Image height in pixels"
434
- )
 
 
 
 
 
 
 
 
 
 
435
 
436
  with gr.Row():
437
  num_inference_steps = gr.Slider(
@@ -579,6 +608,13 @@ def create_interface(auto_load: bool = True, use_distilled: bool = False, device
579
  # )
580
 
581
  # Event handlers
 
 
 
 
 
 
 
582
  generate_btn.click(
583
  fn=app.generate_image,
584
  inputs=[
 
200
 
201
  self.pipeline = pipeline
202
  self.current_use_distilled = None
203
+
204
+ # Define aspect ratio mappings
205
+ self.aspect_ratio_mappings = {
206
+ "16:9": (2560, 1536),
207
+ "4:3": (2304, 1792),
208
+ "1:1": (2048, 2048),
209
+ "3:4": (1792, 2304),
210
+ "9:16": (1536, 2560)
211
+ }
212
 
213
 
214
  def print_peak_memory(self):
 
216
  stats = torch.cuda.memory_stats()
217
  peak_bytes_requirement = stats["allocated_bytes.all.peak"]
218
  print(f"Before refiner Peak memory requirement: {peak_bytes_requirement / 1024 ** 3:.2f} GB")
219
+
220
+ def update_resolution(self, aspect_ratio_choice: str) -> Tuple[int, int]:
221
+ """Update width and height based on selected aspect ratio."""
222
+ # Extract the aspect ratio key from the choice (e.g., "16:9" from "16:9 (2560×1536)")
223
+ aspect_key = aspect_ratio_choice.split(" (")[0]
224
+ if aspect_key in self.aspect_ratio_mappings:
225
+ return self.aspect_ratio_mappings[aspect_key]
226
+ else:
227
+ # Default to 1:1 if not found
228
+ return self.aspect_ratio_mappings["1:1"]
229
 
230
  @space_context(duration=300)
231
  def generate_image(self,
 
442
  value=""
443
  )
444
 
445
+ # Predefined aspect ratios
446
+ aspect_ratios = [
447
+ ("16:9 (2560×1536)", "16:9"),
448
+ ("4:3 (2304×1792)", "4:3"),
449
+ ("1:1 (2048×2048)", "1:1"),
450
+ ("3:4 (1792×2304)", "3:4"),
451
+ ("9:16 (1536×2560)", "9:16")
452
+ ]
453
+
454
+ aspect_ratio = gr.Radio(
455
+ choices=aspect_ratios,
456
+ value="1:1",
457
+ label="Aspect Ratio",
458
+ info="Select the aspect ratio for image generation"
459
+ )
460
+
461
+ # Hidden width and height inputs that get updated based on aspect ratio
462
+ width = gr.Number(value=2048, visible=False)
463
+ height = gr.Number(value=2048, visible=False)
464
 
465
  with gr.Row():
466
  num_inference_steps = gr.Slider(
 
608
  # )
609
 
610
  # Event handlers
611
+ # Update width and height when aspect ratio changes
612
+ aspect_ratio.change(
613
+ fn=app.update_resolution,
614
+ inputs=[aspect_ratio],
615
+ outputs=[width, height]
616
+ )
617
+
618
  generate_btn.click(
619
  fn=app.generate_image,
620
  inputs=[