mung-bean commited on
Commit
218dfba
·
1 Parent(s): 53e70fe
Files changed (1) hide show
  1. batch_generator.py +5 -6
batch_generator.py CHANGED
@@ -23,7 +23,7 @@ import numpy as np
23
  import gc
24
 
25
  # Global device configuration
26
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
  dtype = torch.float16
28
 
29
  # Initialize global generator
@@ -33,7 +33,7 @@ generator = torch.Generator()
33
  print("Loading VAE...")
34
  vae = AutoencoderKL.from_pretrained(
35
  "madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype, use_safetensors=True
36
- ).to(device)
37
 
38
  print("Loading base pipeline...")
39
  pipe = StableDiffusionXLPipeline.from_pretrained(
@@ -42,7 +42,7 @@ pipe = StableDiffusionXLPipeline.from_pretrained(
42
  torch_dtype=dtype,
43
  variant="fp16",
44
  use_safetensors=True,
45
- )
46
 
47
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
48
 
@@ -58,7 +58,7 @@ openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
58
  print("Loading T2I adapter...")
59
  adapter = T2IAdapter.from_pretrained(
60
  "TencentARC/t2i-adapter-openpose-sdxl-1.0", torch_dtype=dtype
61
- )
62
 
63
  print("Loading adapter pipeline...")
64
  posepipe = StableDiffusionXLAdapterPipeline.from_pretrained(
@@ -68,8 +68,7 @@ posepipe = StableDiffusionXLAdapterPipeline.from_pretrained(
68
  torch_dtype=dtype,
69
  variant="fp16",
70
  use_safetensors=True,
71
- attn_implementation="xformers",
72
- )
73
 
74
 
75
  posepipe.scheduler = UniPCMultistepScheduler.from_config(posepipe.scheduler.config)
 
23
  import gc
24
 
25
  # Global device configuration
26
+
27
  dtype = torch.float16
28
 
29
  # Initialize global generator
 
33
  print("Loading VAE...")
34
  vae = AutoencoderKL.from_pretrained(
35
  "madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype, use_safetensors=True
36
+ ).to("cuda")
37
 
38
  print("Loading base pipeline...")
39
  pipe = StableDiffusionXLPipeline.from_pretrained(
 
42
  torch_dtype=dtype,
43
  variant="fp16",
44
  use_safetensors=True,
45
+ ).to("cuda")
46
 
47
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
48
 
 
58
  print("Loading T2I adapter...")
59
  adapter = T2IAdapter.from_pretrained(
60
  "TencentARC/t2i-adapter-openpose-sdxl-1.0", torch_dtype=dtype
61
+ ).to("cuda")
62
 
63
  print("Loading adapter pipeline...")
64
  posepipe = StableDiffusionXLAdapterPipeline.from_pretrained(
 
68
  torch_dtype=dtype,
69
  variant="fp16",
70
  use_safetensors=True,
71
+ ).to("cuda")
 
72
 
73
 
74
  posepipe.scheduler = UniPCMultistepScheduler.from_config(posepipe.scheduler.config)