Spaces:
Paused
Paused
gpu
Browse files- batch_generator.py +5 -6
batch_generator.py
CHANGED
|
@@ -23,7 +23,7 @@ import numpy as np
|
|
| 23 |
import gc
|
| 24 |
|
| 25 |
# Global device configuration
|
| 26 |
-
|
| 27 |
dtype = torch.float16
|
| 28 |
|
| 29 |
# Initialize global generator
|
|
@@ -33,7 +33,7 @@ generator = torch.Generator()
|
|
| 33 |
print("Loading VAE...")
|
| 34 |
vae = AutoencoderKL.from_pretrained(
|
| 35 |
"madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype, use_safetensors=True
|
| 36 |
-
).to(
|
| 37 |
|
| 38 |
print("Loading base pipeline...")
|
| 39 |
pipe = StableDiffusionXLPipeline.from_pretrained(
|
|
@@ -42,7 +42,7 @@ pipe = StableDiffusionXLPipeline.from_pretrained(
|
|
| 42 |
torch_dtype=dtype,
|
| 43 |
variant="fp16",
|
| 44 |
use_safetensors=True,
|
| 45 |
-
)
|
| 46 |
|
| 47 |
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
| 48 |
|
|
@@ -58,7 +58,7 @@ openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
|
|
| 58 |
print("Loading T2I adapter...")
|
| 59 |
adapter = T2IAdapter.from_pretrained(
|
| 60 |
"TencentARC/t2i-adapter-openpose-sdxl-1.0", torch_dtype=dtype
|
| 61 |
-
)
|
| 62 |
|
| 63 |
print("Loading adapter pipeline...")
|
| 64 |
posepipe = StableDiffusionXLAdapterPipeline.from_pretrained(
|
|
@@ -68,8 +68,7 @@ posepipe = StableDiffusionXLAdapterPipeline.from_pretrained(
|
|
| 68 |
torch_dtype=dtype,
|
| 69 |
variant="fp16",
|
| 70 |
use_safetensors=True,
|
| 71 |
-
|
| 72 |
-
)
|
| 73 |
|
| 74 |
|
| 75 |
posepipe.scheduler = UniPCMultistepScheduler.from_config(posepipe.scheduler.config)
|
|
|
|
| 23 |
import gc
|
| 24 |
|
| 25 |
# Global device configuration
|
| 26 |
+
|
| 27 |
dtype = torch.float16
|
| 28 |
|
| 29 |
# Initialize global generator
|
|
|
|
| 33 |
print("Loading VAE...")
|
| 34 |
vae = AutoencoderKL.from_pretrained(
|
| 35 |
"madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype, use_safetensors=True
|
| 36 |
+
).to("cuda")
|
| 37 |
|
| 38 |
print("Loading base pipeline...")
|
| 39 |
pipe = StableDiffusionXLPipeline.from_pretrained(
|
|
|
|
| 42 |
torch_dtype=dtype,
|
| 43 |
variant="fp16",
|
| 44 |
use_safetensors=True,
|
| 45 |
+
).to("cuda")
|
| 46 |
|
| 47 |
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
| 48 |
|
|
|
|
| 58 |
print("Loading T2I adapter...")
|
| 59 |
adapter = T2IAdapter.from_pretrained(
|
| 60 |
"TencentARC/t2i-adapter-openpose-sdxl-1.0", torch_dtype=dtype
|
| 61 |
+
).to("cuda")
|
| 62 |
|
| 63 |
print("Loading adapter pipeline...")
|
| 64 |
posepipe = StableDiffusionXLAdapterPipeline.from_pretrained(
|
|
|
|
| 68 |
torch_dtype=dtype,
|
| 69 |
variant="fp16",
|
| 70 |
use_safetensors=True,
|
| 71 |
+
).to("cuda")
|
|
|
|
| 72 |
|
| 73 |
|
| 74 |
posepipe.scheduler = UniPCMultistepScheduler.from_config(posepipe.scheduler.config)
|