Enhance Mesh class documentation by adding missing line breaks in docstrings for improved readability. Update device handling in FlexiCubes and FlexiCubesGeometry classes to default to 'cuda', ensuring consistent device usage across the application. Refactor ImageDreamDiffusion class to assert mode validity and streamline camera matrix pre-computation.
d493b2e
import numpy as np
import torch
from torchvision import transforms as T

from imagedream.camera_utils import get_camera_for_index
from imagedream.ldm.models.diffusion.ddim import DDIMSampler
from imagedream.ldm.util import set_seed, add_random_background
from libs.base_utils import do_resize_content


class ImageDreamDiffusion:
    def __init__(
        self,
        model,
        device,
        dtype,
        mode,
        num_frames,
        camera_views,
        ref_position,
        random_background=False,
        offset_noise=False,
        resize_rate=1,
        image_size=256,
        seed=1234,
    ) -> None:
        assert mode in ["pixel", "local"]
        size = image_size
        self.seed = seed
        batch_size = max(4, num_frames)

        neg_texts = "uniform low no texture ugly, boring, bad anatomy, blurry, pixelated, obscure, unnatural colors, poor lighting, dull, and unclear."
        uc = model.get_learned_conditioning([neg_texts]).to(device)
        sampler = DDIMSampler(model)

        # pre-compute camera matrices
        camera = [get_camera_for_index(i).squeeze() for i in camera_views]
        camera[ref_position] = torch.zeros_like(camera[ref_position])  # set ref camera to zero
        camera = torch.stack(camera)
        camera = camera.repeat(batch_size // num_frames, 1).to(device)

        self.image_transform = T.Compose(
            [
                T.Resize((size, size)),
                T.ToTensor(),
                T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        self.dtype = dtype
        self.ref_position = ref_position
        self.mode = mode
        self.random_background = random_background
        self.resize_rate = resize_rate
        self.num_frames = num_frames
        self.size = size
        self.device = device
        self.batch_size = batch_size
        self.model = model
        self.sampler = sampler
        self.uc = uc
        self.camera = camera
        self.offset_noise = offset_noise

    @staticmethod
    def i2i(
        model,
        image_size,
        prompt,
        uc,
        sampler,
        ip=None,
        step=20,
        scale=5.0,
        batch_size=8,
        ddim_eta=0.0,
        dtype=torch.float32,
        device="cuda",
        camera=None,
        num_frames=4,
        pixel_control=False,
        transform=None,
        offset_noise=False,
    ):
""" The function supports additional image prompt. | |
Args: | |
model (_type_): the image dream model | |
image_size (_type_): size of diffusion output (standard 256) | |
prompt (_type_): text prompt for the image (prompt in type str) | |
uc (_type_): unconditional vector (tensor in shape [1, 77, 1024]) | |
sampler (_type_): imagedream.ldm.models.diffusion.ddim.DDIMSampler | |
ip (Image, optional): the image prompt. Defaults to None. | |
step (int, optional): _description_. Defaults to 20. | |
scale (float, optional): _description_. Defaults to 7.5. | |
batch_size (int, optional): _description_. Defaults to 8. | |
ddim_eta (float, optional): _description_. Defaults to 0.0. | |
dtype (_type_, optional): _description_. Defaults to torch.float32. | |
device (str, optional): _description_. Defaults to "cuda". | |
camera (_type_, optional): camera info in tensor, shape: torch.Size([5, 16]) mean: 0.11, std: 0.49, min: -1.00, max: 1.00 | |
num_frames (int, optional): _num of frames (views) to generate | |
pixel_control: whether to use pixel conditioning. Defaults to False, True when using pixel mode | |
transform: Compose( | |
Resize(size=(256, 256), interpolation=bilinear, max_size=None, antialias=warn) | |
ToTensor() | |
Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) | |
) | |
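
        Example (sketch only; assumes `model`, `uc`, `sampler`, `camera`, and
        `transform` were built as in `ImageDreamDiffusion.__init__`, and that
        `pil_image` is a PIL.Image — these names are illustrative, not fixed
        by this module):

            views = ImageDreamDiffusion.i2i(
                model, 256, "a photo of a chair", uc, sampler,
                ip=pil_image, step=50, scale=5.0, batch_size=5,
                device="cuda", camera=camera, num_frames=5,
                pixel_control=True, transform=transform,
            )
            # -> list of 5 HxWx3 uint8 arrays, one per view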
""" | |
        ip_raw = ip
        if not isinstance(prompt, list):
            prompt = [prompt]
        with torch.no_grad(), torch.autocast(device_type=torch.device(device).type, dtype=dtype):
            c = model.get_learned_conditioning(prompt).to(
                device
            )  # shape: torch.Size([1, 77, 1024]) mean: -0.17, std: 1.02, min: -7.50, max: 13.05
            c_ = {"context": c.repeat(batch_size, 1, 1)}  # batch_size
            uc_ = {"context": uc.repeat(batch_size, 1, 1)}

            if camera is not None:
                # shape: torch.Size([5, 16]) mean: 0.11, std: 0.49, min: -1.00, max: 1.00
                c_["camera"] = uc_["camera"] = camera
                c_["num_frames"] = uc_["num_frames"] = num_frames

            if ip is not None:
                ip_embed = model.get_learned_image_conditioning(ip).to(
                    device
                )  # shape: torch.Size([1, 257, 1280]) mean: 0.06, std: 0.53, min: -6.83, max: 11.12
                ip_ = ip_embed.repeat(batch_size, 1, 1)
                c_["ip"] = ip_
                uc_["ip"] = torch.zeros_like(ip_)

            if pixel_control:
                assert camera is not None
                ip = transform(ip).to(
                    device
                )  # shape: torch.Size([3, 256, 256]) mean: 0.33, std: 0.37, min: -1.00, max: 1.00
                ip_img = model.get_first_stage_encoding(
                    model.encode_first_stage(ip[None, :, :, :])
                )  # shape: torch.Size([1, 4, 32, 32]) mean: 0.23, std: 0.77, min: -4.42, max: 3.55
                c_["ip_img"] = ip_img
                uc_["ip_img"] = torch.zeros_like(ip_img)

            shape = [4, image_size // 8, image_size // 8]  # [4, 32, 32]
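
            # Offset-noise-style init: instead of starting DDIM from pure
            # N(0, 1) noise, noise a constant latent filled with the reference
            # image's per-channel latent mean at (near) the final timestep, so
            # the starting noise's low-frequency content (overall color and
            # brightness) already matches the prompt image.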
            if offset_noise:
                ref = transform(ip_raw).to(device)
                ref_latent = model.get_first_stage_encoding(model.encode_first_stage(ref[None, :, :, :]))
                ref_mean = ref_latent.mean(dim=(-1, -2), keepdim=True)
                time_steps = torch.randint(model.num_timesteps - 1, model.num_timesteps, (batch_size,), device=device)
                x_T = model.q_sample(torch.ones([batch_size] + shape, device=device) * ref_mean, time_steps)

            samples_ddim, _ = sampler.sample(
                # shape: torch.Size([5, 4, 32, 32]) mean: 0.29, std: 0.85, min: -3.38, max: 4.43
                S=step,
                conditioning=c_,
                batch_size=batch_size,
                shape=shape,
                verbose=False,
                unconditional_guidance_scale=scale,
                unconditional_conditioning=uc_,
                eta=ddim_eta,
                x_T=x_T if offset_noise else None,
            )

            x_sample = model.decode_first_stage(samples_ddim)
            x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
            x_sample = 255.0 * x_sample.permute(0, 2, 3, 1).cpu().numpy()

        return list(x_sample.astype(np.uint8))

    def diffuse(self, t, ip, n_test=2):
        set_seed(self.seed)
        ip = do_resize_content(ip, self.resize_rate)
        if self.random_background:
            ip = add_random_background(ip)

        images = []
        for _ in range(n_test):
            img = self.i2i(
                self.model,
                self.size,
                t,
                self.uc,
                self.sampler,
                ip=ip,
                step=50,
                scale=5,
                batch_size=self.batch_size,
                ddim_eta=0.0,
                dtype=self.dtype,
                device=self.device,
                camera=self.camera,
                num_frames=self.num_frames,
                pixel_control=(self.mode == "pixel"),
                transform=self.image_transform,
                offset_noise=self.offset_noise,
            )
            img = np.concatenate(img, 1)
            img = np.concatenate((img, ip.resize((self.size, self.size))), axis=1)
            images.append(img)
        set_seed()  # unset random and numpy seed
        return images
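

# Minimal usage sketch for the stage-1 pipeline. Assumes a loaded ImageDream
# model `model` (loading is out of scope here); the camera indices, prompt,
# and `pil_image` below are illustrative placeholders, not fixed by this module:
#
#   pipe = ImageDreamDiffusion(
#       model, device="cuda", dtype=torch.float16, mode="pixel",
#       num_frames=5, camera_views=[0, 1, 2, 3, 4], ref_position=4,
#   )
#   grids = pipe.diffuse("a photo of a chair", pil_image, n_test=2)
#   # each element of `grids` is a uint8 strip of all views plus the input image
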

class ImageDreamDiffusionStage2:
    def __init__(
        self,
        model,
        device,
        dtype,
        num_frames,
        camera_views,
        ref_position,
        random_background=False,
        offset_noise=False,
        resize_rate=1,
        mode="pixel",
        image_size=256,
        seed=1234,
    ) -> None:
        assert mode in ["pixel", "local"]
        size = image_size
        self.seed = seed
        batch_size = max(4, num_frames)

        neg_texts = "uniform low no texture ugly, boring, bad anatomy, blurry, pixelated, obscure, unnatural colors, poor lighting, dull, and unclear."
        uc = model.get_learned_conditioning([neg_texts]).to(device)
        sampler = DDIMSampler(model)

        # pre-compute camera matrices
        camera = [get_camera_for_index(i).squeeze() for i in camera_views]
        if ref_position is not None:
            camera[ref_position] = torch.zeros_like(camera[ref_position])  # set ref camera to zero
        camera = torch.stack(camera)
        camera = camera.repeat(batch_size // num_frames, 1).to(device)

        self.image_transform = T.Compose(
            [
                T.Resize((size, size)),
                T.ToTensor(),
                T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        self.dtype = dtype
        self.mode = mode
        self.ref_position = ref_position
        self.random_background = random_background
        self.resize_rate = resize_rate
        self.num_frames = num_frames
        self.size = size
        self.device = device
        self.batch_size = batch_size
        self.model = model
        self.sampler = sampler
        self.uc = uc
        self.camera = camera
        self.offset_noise = offset_noise

    @staticmethod
    def i2iStage2(
        model,
        image_size,
        prompt,
        uc,
        sampler,
        pixel_images,
        ip=None,
        step=20,
        scale=5.0,
        batch_size=8,
        ddim_eta=0.0,
        dtype=torch.float32,
        device="cuda",
        camera=None,
        num_frames=4,
        pixel_control=False,
        transform=None,
        offset_noise=False,
    ):
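        """Stage-2 sampling: same contract as `i2i`, except that when
        `pixel_control` is True it conditions on `pixel_images` (one coarse
        image per view, encoded to latents) rather than a single `ip_img`
        latent. Returns a list of HxWx3 uint8 arrays, one per view.
        """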
        ip_raw = ip
        if not isinstance(prompt, list):
            prompt = [prompt]
        with torch.no_grad(), torch.autocast(device_type=torch.device(device).type, dtype=dtype):
            c = model.get_learned_conditioning(prompt).to(
                device
            )  # shape: torch.Size([1, 77, 1024]) mean: -0.17, std: 1.02, min: -7.50, max: 13.05
            c_ = {"context": c.repeat(batch_size, 1, 1)}  # batch_size
            uc_ = {"context": uc.repeat(batch_size, 1, 1)}

            if camera is not None:
                # shape: torch.Size([5, 16]) mean: 0.11, std: 0.49, min: -1.00, max: 1.00
                c_["camera"] = uc_["camera"] = camera
                c_["num_frames"] = uc_["num_frames"] = num_frames

            if ip is not None:
                ip_embed = model.get_learned_image_conditioning(ip).to(
                    device
                )  # shape: torch.Size([1, 257, 1280]) mean: 0.06, std: 0.53, min: -6.83, max: 11.12
                ip_ = ip_embed.repeat(batch_size, 1, 1)
                c_["ip"] = ip_
                uc_["ip"] = torch.zeros_like(ip_)

            if pixel_control:
                assert camera is not None
                # encode each per-view coarse image to a latent for pixel conditioning
                transed_pixel_images = torch.stack([transform(i).to(device) for i in pixel_images])
                latent_pixel_images = model.get_first_stage_encoding(model.encode_first_stage(transed_pixel_images))
                c_["pixel_images"] = latent_pixel_images
                uc_["pixel_images"] = torch.zeros_like(latent_pixel_images)

            shape = [4, image_size // 8, image_size // 8]  # [4, 32, 32]
            if offset_noise:
                # same offset-noise init as in i2i: start from noise centered
                # on the reference image's per-channel latent mean
                ref = transform(ip_raw).to(device)
                ref_latent = model.get_first_stage_encoding(model.encode_first_stage(ref[None, :, :, :]))
                ref_mean = ref_latent.mean(dim=(-1, -2), keepdim=True)
                time_steps = torch.randint(model.num_timesteps - 1, model.num_timesteps, (batch_size,), device=device)
                x_T = model.q_sample(torch.ones([batch_size] + shape, device=device) * ref_mean, time_steps)

            samples_ddim, _ = sampler.sample(
                # shape: torch.Size([5, 4, 32, 32]) mean: 0.29, std: 0.85, min: -3.38, max: 4.43
                S=step,
                conditioning=c_,
                batch_size=batch_size,
                shape=shape,
                verbose=False,
                unconditional_guidance_scale=scale,
                unconditional_conditioning=uc_,
                eta=ddim_eta,
                x_T=x_T if offset_noise else None,
            )

            x_sample = model.decode_first_stage(samples_ddim)
            x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
            x_sample = 255.0 * x_sample.permute(0, 2, 3, 1).cpu().numpy()

        return list(x_sample.astype(np.uint8))

    def diffuse(self, t, ip, pixel_images, n_test=2):
        set_seed(self.seed)
        ip = do_resize_content(ip, self.resize_rate)
        pixel_images = [do_resize_content(i, self.resize_rate) for i in pixel_images]

        if self.random_background:
            bg_color = np.random.rand() * 255
            ip = add_random_background(ip, bg_color)
            pixel_images = [add_random_background(i, bg_color) for i in pixel_images]

        images = []
        for _ in range(n_test):
            img = self.i2iStage2(
                self.model,
                self.size,
                t,
                self.uc,
                self.sampler,
                pixel_images=pixel_images,
                ip=ip,
                step=50,
                scale=5,
                batch_size=self.batch_size,
                ddim_eta=0.0,
                dtype=self.dtype,
                device=self.device,
                camera=self.camera,
                num_frames=self.num_frames,
                pixel_control=(self.mode == "pixel"),
                transform=self.image_transform,
                offset_noise=self.offset_noise,
            )
            img = np.concatenate(img, 1)
            img = np.concatenate(
                (img, ip.resize((self.size, self.size)), *[i.resize((self.size, self.size)) for i in pixel_images]),
                axis=1,
            )
            images.append(img)
        set_seed()  # unset random and numpy seed
        return images
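

# Minimal usage sketch for the stage-2 pipeline. Assumes a loaded stage-2
# ImageDream model `model_stage2` and the per-view coarse images from stage 1
# as a list of PIL images `pixel_images`; all names are illustrative:
#
#   pipe2 = ImageDreamDiffusionStage2(
#       model_stage2, device="cuda", dtype=torch.float16,
#       num_frames=5, camera_views=[0, 1, 2, 3, 4], ref_position=None,
#   )
#   refined = pipe2.diffuse("a photo of a chair", pil_image, pixel_images, n_test=2)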