Update device handling in TwoStagePipeline to default to 'cuda', and load checkpoints with map_location='cpu' before moving them to the target device for compatibility. Ensure consistent device usage across both stages and improve code clarity.
7391a72
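A minimal construction sketch for reference (the YAML paths and the config/models/sampler keys mirror the __main__ block at the bottom of this file; falling back to "cpu" when CUDA is unavailable is illustrative, not part of this change):

import torch
from omegaconf import OmegaConf

stage1_config = OmegaConf.load("configs/nf7_v3_SNR_rd_size_stroke.yaml").config
stage2_config = OmegaConf.load("configs/stage2-v2-snr.yaml").config
# device defaults to "cuda"; override it explicitly on CPU-only machines.
pipeline = TwoStagePipeline(
    stage1_config.models,
    stage2_config.models,
    stage1_config.sampler,
    stage2_config.sampler,
    device="cuda" if torch.cuda.is_available() else "cpu",
)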
import torch
from libs.base_utils import do_resize_content
from imagedream.ldm.util import (
    instantiate_from_config,
    get_obj_from_str,
)
from omegaconf import OmegaConf
from PIL import Image
import numpy as np


class TwoStagePipeline:
    def __init__(
        self,
        stage1_model_config,
        stage2_model_config,
        stage1_sampler_config,
        stage2_sampler_config,
        device="cuda",
        dtype=torch.float16,
        resize_rate=1,
    ) -> None:
""" | |
only for two stage generate process. | |
- the first stage was condition on single pixel image, gererate multi-view pixel image, based on the v2pp config | |
- the second stage was condition on multiview pixel image generated by the first stage, generate the final image, based on the stage2-test config | |
""" | |
        self.resize_rate = resize_rate

        # Load both stage models onto CPU first, then move them to the target device/dtype.
        self.stage1_model = instantiate_from_config(OmegaConf.load(stage1_model_config.config).model)
        self.stage1_model.load_state_dict(torch.load(stage1_model_config.resume, map_location="cpu"), strict=False)
        self.stage1_model = self.stage1_model.to(device).to(dtype)

        self.stage2_model = instantiate_from_config(OmegaConf.load(stage2_model_config.config).model)
        sd = torch.load(stage2_model_config.resume, map_location="cpu")
        self.stage2_model.load_state_dict(sd, strict=False)
        self.stage2_model = self.stage2_model.to(device).to(dtype)

        # Keep the device attribute consistent across both stages.
        self.stage1_model.device = device
        self.stage2_model.device = device
        self.device = device
        self.dtype = dtype

        self.stage1_sampler = get_obj_from_str(stage1_sampler_config.target)(
            self.stage1_model, device=device, dtype=dtype, **stage1_sampler_config.params
        )
        self.stage2_sampler = get_obj_from_str(stage2_sampler_config.target)(
            self.stage2_model, device=device, dtype=dtype, **stage2_sampler_config.params
        )

    def stage1_sample(
        self,
        pixel_img,
        prompt="3D assets",
        neg_texts="uniform low no texture ugly, boring, bad anatomy, blurry, pixelated, obscure, unnatural colors, poor lighting, dull, and unclear.",
        step=50,
        scale=5,
        ddim_eta=0.0,
    ):
        if isinstance(pixel_img, str):
            pixel_img = Image.open(pixel_img)
        if isinstance(pixel_img, Image.Image):
            if pixel_img.mode == "RGBA":
                # Flatten the alpha channel against a transparent background before converting to RGB.
                background = Image.new("RGBA", pixel_img.size, (0, 0, 0, 0))
                pixel_img = Image.alpha_composite(background, pixel_img).convert("RGB")
            else:
                pixel_img = pixel_img.convert("RGB")
        else:
            raise TypeError(f"pixel_img must be a file path or a PIL.Image, got {type(pixel_img)}")
        uc = self.stage1_sampler.model.get_learned_conditioning([neg_texts]).to(self.device)
        stage1_images = self.stage1_sampler.i2i(
            self.stage1_sampler.model,
            self.stage1_sampler.size,
            prompt,
            uc=uc,
            sampler=self.stage1_sampler.sampler,
            ip=pixel_img,
            step=step,
            scale=scale,
            batch_size=self.stage1_sampler.batch_size,
            ddim_eta=ddim_eta,
            dtype=self.stage1_sampler.dtype,
            device=self.stage1_sampler.device,
            camera=self.stage1_sampler.camera,
            num_frames=self.stage1_sampler.num_frames,
            pixel_control=(self.stage1_sampler.mode == "pixel"),
            transform=self.stage1_sampler.image_transform,
            offset_noise=self.stage1_sampler.offset_noise,
        )
        stage1_images = [Image.fromarray(img) for img in stage1_images]
        # Drop the reference view; only the generated views are returned.
        stage1_images.pop(self.stage1_sampler.ref_position)
        return stage1_images

    def stage2_sample(self, pixel_img, stage1_images, scale=5, step=50):
        if isinstance(pixel_img, str):
            pixel_img = Image.open(pixel_img)
        if isinstance(pixel_img, Image.Image):
            if pixel_img.mode == "RGBA":
                background = Image.new("RGBA", pixel_img.size, (0, 0, 0, 0))
                pixel_img = Image.alpha_composite(background, pixel_img).convert("RGB")
            else:
                pixel_img = pixel_img.convert("RGB")
        else:
            raise TypeError(f"pixel_img must be a file path or a PIL.Image, got {type(pixel_img)}")
        stage2_images = self.stage2_sampler.i2iStage2(
            self.stage2_sampler.model,
            self.stage2_sampler.size,
            "3D assets",
            self.stage2_sampler.uc,
            self.stage2_sampler.sampler,
            pixel_images=stage1_images,
            ip=pixel_img,
            step=step,
            scale=scale,
            batch_size=self.stage2_sampler.batch_size,
            ddim_eta=0.0,
            dtype=self.stage2_sampler.dtype,
            device=self.stage2_sampler.device,
            camera=self.stage2_sampler.camera,
            num_frames=self.stage2_sampler.num_frames,
            pixel_control=(self.stage2_sampler.mode == "pixel"),
            transform=self.stage2_sampler.image_transform,
            offset_noise=self.stage2_sampler.offset_noise,
        )
        stage2_images = [Image.fromarray(img) for img in stage2_images]
        return stage2_images

    def set_seed(self, seed):
        self.stage1_sampler.seed = seed
        self.stage2_sampler.seed = seed

    def __call__(self, pixel_img, prompt="3D assets", scale=5, step=50):
        pixel_img = do_resize_content(pixel_img, self.resize_rate)
        stage1_images = self.stage1_sample(pixel_img, prompt, scale=scale, step=step)
        stage2_images = self.stage2_sample(pixel_img, stage1_images, scale=scale, step=step)
        return {
            "ref_img": pixel_img,
            "stage1_images": stage1_images,
            "stage2_images": stage2_images,
        }


if __name__ == "__main__":
    stage1_config = OmegaConf.load("configs/nf7_v3_SNR_rd_size_stroke.yaml").config
    stage2_config = OmegaConf.load("configs/stage2-v2-snr.yaml").config
    stage2_sampler_config = stage2_config.sampler
    stage1_sampler_config = stage1_config.sampler
    stage1_model_config = stage1_config.models
    stage2_model_config = stage2_config.models

    pipeline = TwoStagePipeline(
        stage1_model_config,
        stage2_model_config,
        stage1_sampler_config,
        stage2_sampler_config,
    )

    img = Image.open("assets/astronaut.png")
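    # Optionally seed both samplers first for reproducible outputs (the value 0 is illustrative):
    # pipeline.set_seed(0)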
    rt_dict = pipeline(img)
    stage1_images = rt_dict["stage1_images"]
    stage2_images = rt_dict["stage2_images"]

    # Tile the views side by side and save one strip per stage.
    np_imgs = np.concatenate(stage1_images, axis=1)
    np_xyzs = np.concatenate(stage2_images, axis=1)
    Image.fromarray(np_imgs).save("pixel_images.png")
    Image.fromarray(np_xyzs).save("xyz_images.png")