import sys
sys.path.append('StableSR')
import os
import cv2
import torch
import torch.nn.functional as F
import gradio as gr
import torchvision
from torchvision.transforms.functional import normalize
from ldm.util import instantiate_from_config
from torch import autocast
import PIL
import numpy as np
from pytorch_lightning import seed_everything
from contextlib import nullcontext
from omegaconf import OmegaConf
from PIL import Image
import copy
from scripts.wavelet_color_fix import wavelet_reconstruction, adaptive_instance_normalization
from scripts.util_image import ImageSpliterTh
from basicsr.utils.download_util import load_file_from_url
from einops import rearrange, repeat
from itertools import islice
# Download weights
pretrain_model_url = {
    'stablesr_512': 'https://huggingface.co/Iceclear/StableSR/resolve/main/stablesr_000117.ckpt',
    'stablesr_768': 'https://huggingface.co/Iceclear/StableSR/resolve/main/stablesr_768v_000139.ckpt',
    'CFW': 'https://huggingface.co/Iceclear/StableSR/resolve/main/vqgan_cfw_00011.ckpt',
}
for k, url in pretrain_model_url.items():
    filename = url.split("/")[-1]
    if not os.path.exists(f'./{filename}'):
        load_file_from_url(url=url, model_dir='./', progress=True, file_name=None)

# Download sample images
image_urls = [
    ('01.png', 'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet128/Lincoln.png'),
    ('02.png', 'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet128/oldphoto6.png'),
    ('03.png', 'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet128/comic2.png'),
    ('04.png', 'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet128/OST_120.png'),
    ('05.png', 'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet65/comic3.png'),
]
for fname, url in image_urls:
    torch.hub.download_url_to_file(url, fname)
def load_img(path):
    """Load an image, crop both sides to a multiple of 32, and scale it to [-1, 1]."""
    image = Image.open(path).convert("RGB")
    w, h = image.size
    w, h = map(lambda x: x - x % 32, (w, h))  # both sides must be multiples of 32
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)  # HWC -> NCHW with a batch dimension
    image = torch.from_numpy(image)
    return 2. * image - 1.
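# For illustration: a 500x333 RGB input would be resized to 480x320 and returned
# as a (1, 3, 320, 480) float tensor with values in [-1, 1].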
def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.
    For example, if there are 300 timesteps and the section counts are [10, 15, 20],
    then the first 100 timesteps are strided down to 10 timesteps, the second 100
    to 15 timesteps, and the final 100 to 20.
    If the stride is a string starting with "ddim", then the fixed striding
    from the DDIM paper is used, and only one section is allowed.
    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim"):])
            for i in range(1, num_timesteps):
                if len(range(0, num_timesteps, i)) == desired_count:
                    return set(range(0, num_timesteps, i))
            raise ValueError(
                f"cannot create exactly {desired_count} steps with an integer stride"
            )
        section_counts = [int(x) for x in section_counts.split(",")]  # e.g. [250]
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}"
            )
        if section_count <= 1:
            frac_stride = 1
        else:
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)
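# For illustration: with the default demo settings, space_timesteps(1000, [200])
# returns 200 roughly evenly spaced indices between 0 and 999 (both endpoints
# included); the sorted indices become model.ori_timesteps inside inference() below.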
def chunk(it, size):
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())
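# For illustration: list(chunk(range(5), 2)) == [(0, 1), (2, 3), (4,)].
# This helper is kept from the original StableSR scripts and is not called
# elsewhere in this file.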
def load_model_from_config(config, ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)
    model.cuda()
    model.eval()
    return model
# Load VQGAN model
device = torch.device("cuda")
vqgan_config = OmegaConf.load("StableSR/configs/autoencoder/autoencoder_kl_64x64x4_resi.yaml")
vq_model = instantiate_from_config(vqgan_config.model)
vq_sd = torch.load('./vqgan_cfw_00011.ckpt', map_location='cpu')['state_dict']
vq_model.load_state_dict(vq_sd, strict=False)
vq_model.cuda().eval()

os.makedirs('output', exist_ok=True)
def inference(image, upscale, dec_w, seed, model_type, ddpm_steps, colorfix_type):
    """Run a single prediction on the model."""
    precision_scope = autocast
    vq_model.decoder.fusion_w = dec_w
    seed_everything(seed)

    if model_type == '512':
        config = OmegaConf.load("StableSR/configs/stableSRNew/v2-finetune_text_T_512.yaml")
        model = load_model_from_config(config, "./stablesr_000117.ckpt")
        min_size = 512
    else:
        config = OmegaConf.load("StableSR/configs/stableSRNew/v2-finetune_text_T_768v.yaml")
        model = load_model_from_config(config, "./stablesr_768v_000139.ckpt")
        min_size = 768
    model = model.to(device)
    model.configs = config
    model.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000,
                            linear_start=0.00085, linear_end=0.0120, cosine_s=8e-3)
    model.num_timesteps = 1000

    sqrt_alphas_cumprod = copy.deepcopy(model.sqrt_alphas_cumprod)
    sqrt_one_minus_alphas_cumprod = copy.deepcopy(model.sqrt_one_minus_alphas_cumprod)

    # Respace the original 1000-step schedule down to ddpm_steps: keep only the
    # selected timesteps and recompute the betas between consecutive kept steps.
    use_timesteps = set(space_timesteps(1000, [ddpm_steps]))
    last_alpha_cumprod = 1.0
    new_betas = []
    timestep_map = []
    for i, alpha_cumprod in enumerate(model.alphas_cumprod):
        if i in use_timesteps:
            new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
            last_alpha_cumprod = alpha_cumprod
            timestep_map.append(i)
    new_betas = [beta.data.cpu().numpy() for beta in new_betas]
    model.register_schedule(given_betas=np.array(new_betas), timesteps=len(new_betas))
    model.num_timesteps = 1000
    model.ori_timesteps = list(use_timesteps)
    model.ori_timesteps.sort()
    model = model.to(device)
    try:  # global try
        with torch.no_grad():
            with precision_scope("cuda"):
                with model.ema_scope():
                    init_image = load_img(image)
                    init_image = F.interpolate(
                        init_image,
                        size=(int(init_image.size(-2) * upscale),
                              int(init_image.size(-1) * upscale)),
                        mode='bicubic',
                    )

                    # Make sure both sides are at least min_size before diffusion.
                    if init_image.size(-1) < min_size or init_image.size(-2) < min_size:
                        ori_size = init_image.size()
                        rescale = min_size * 1.0 / min(init_image.size(-2), init_image.size(-1))
                        new_h = max(int(ori_size[-2] * rescale), min_size)
                        new_w = max(int(ori_size[-1] * rescale), min_size)
                        init_template = F.interpolate(
                            init_image,
                            size=(new_h, new_w),
                            mode='bicubic',
                        )
                    else:
                        init_template = init_image
                        rescale = 1
                    init_template = init_template.clamp(-1, 1)
                    assert init_template.size(-1) >= min_size
                    assert init_template.size(-2) >= min_size
                    init_template = init_template.type(torch.float16).to(device)

                    if init_template.size(-1) <= 1280 or init_template.size(-2) <= 1280:
                        init_latent_generator, enc_fea_lq = vq_model.encode(init_template)
                        init_latent = model.get_first_stage_encoding(init_latent_generator)
                        text_init = [''] * init_template.size(0)
                        semantic_c = model.cond_stage_model(text_init)
                        noise = torch.randn_like(init_latent)
                        t = repeat(torch.tensor([999]), '1 -> b', b=init_image.size(0))
                        t = t.to(device).long()
                        x_T = model.q_sample_respace(x_start=init_latent, t=t, sqrt_alphas_cumprod=sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod, noise=noise)
                        if init_template.size(-1) <= min_size and init_template.size(-2) <= min_size:
                            samples, _ = model.sample(cond=semantic_c, struct_cond=init_latent, batch_size=init_template.size(0), timesteps=ddpm_steps, time_replace=ddpm_steps, x_T=x_T, return_intermediates=True)
                        else:
                            samples, _ = model.sample_canvas(cond=semantic_c, struct_cond=init_latent, batch_size=init_template.size(0), timesteps=ddpm_steps, time_replace=ddpm_steps, x_T=x_T, return_intermediates=True, tile_size=int(min_size / 8), tile_overlap=min_size // 16, batch_size_sample=init_template.size(0))
                        x_samples = vq_model.decode(samples * 1. / model.scale_factor, enc_fea_lq)
                        if colorfix_type == 'adain':
                            x_samples = adaptive_instance_normalization(x_samples, init_template)
                        elif colorfix_type == 'wavelet':
                            x_samples = wavelet_reconstruction(x_samples, init_template)
                        x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
                    else:
                        # Very large inputs: split into overlapping patches, restore each patch, then gather.
                        im_spliter = ImageSpliterTh(init_template, 1280, 1000, sf=1)
                        for im_lq_pch, index_infos in im_spliter:
                            init_latent = model.get_first_stage_encoding(model.encode_first_stage(im_lq_pch))  # move to latent space
                            text_init = [''] * init_latent.size(0)
                            semantic_c = model.cond_stage_model(text_init)
                            noise = torch.randn_like(init_latent)
                            # To start from an intermediate step instead, add noise to the LR latent up to that specific step.
                            t = repeat(torch.tensor([999]), '1 -> b', b=init_template.size(0))
                            t = t.to(device).long()
                            x_T = model.q_sample_respace(x_start=init_latent, t=t, sqrt_alphas_cumprod=sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod, noise=noise)
                            # x_T = noise
                            samples, _ = model.sample_canvas(cond=semantic_c, struct_cond=init_latent, batch_size=im_lq_pch.size(0), timesteps=ddpm_steps, time_replace=ddpm_steps, x_T=x_T, return_intermediates=True, tile_size=int(min_size / 8), tile_overlap=min_size // 16, batch_size_sample=im_lq_pch.size(0))
                            _, enc_fea_lq = vq_model.encode(im_lq_pch)
                            x_samples = vq_model.decode(samples * 1. / model.scale_factor, enc_fea_lq)
                            if colorfix_type == 'adain':
                                x_samples = adaptive_instance_normalization(x_samples, im_lq_pch)
                            elif colorfix_type == 'wavelet':
                                x_samples = wavelet_reconstruction(x_samples, im_lq_pch)
                            im_spliter.update(x_samples, index_infos)
                        x_samples = im_spliter.gather()
                        x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

                    if rescale > 1:
                        x_samples = F.interpolate(
                            x_samples,
                            size=(int(init_image.size(-2)),
                                  int(init_image.size(-1))),
                            mode='bicubic',
                        )
                        x_samples = x_samples.clamp(0, 1)
                    x_sample = 255. * rearrange(x_samples[0].cpu().numpy(), 'c h w -> h w c')
                    restored_img = x_sample.astype(np.uint8)
                    Image.fromarray(restored_img).save('output/out.png')
        return restored_img, 'output/out.png'
    except Exception as error:
        print('Global exception', error)
        return None, None
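# For illustration only (not executed): the handler above can also be called
# directly; the argument order matches the `inputs` list wired up below.
#   restored_np, out_path = inference('01.png', upscale=4, dec_w=0.5, seed=42,
#                                     model_type="512", ddpm_steps=200,
#                                     colorfix_type="adain")
# `restored_np` is an HxWx3 uint8 array and `out_path` points to 'output/out.png'.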
# Gradio UI
with gr.Blocks(title="Exploiting Diffusion Prior for Real-World Image Super-Resolution") as demo:
    gr.HTML(
        """
        <div style="display: flex; justify-content: center; align-items: center; height: 40px;">
            <img src="https://user-images.githubusercontent.com/22350795/236680126-0b1cdd62-d6fc-4620-b998-75ed6c31bf6f.png"
                 alt="StableSR logo" style='height:40px'>
        </div>
        <div style='text-align: center;'>
            <h2>Exploiting Diffusion Prior for Real-World Image Super-Resolution</h2>
            <p><strong>Official Gradio demo</strong> for <a href='https://github.com/IceClear/StableSR' target='_blank'>StableSR</a>.<br>
            🔥 StableSR is a general image super-resolution algorithm for real-world and AIGC images.</p>
        </div>
        """
    )
    gr.HTML(
        """
        <div style="margin-top:1em">
            <p>If StableSR is helpful, please help to ⭐ the <a href='https://github.com/IceClear/StableSR' target='_blank'>GitHub repo</a>. Thanks!</p>
            <a href='https://github.com/IceClear/StableSR' target='_blank'>
                <img src='https://img.shields.io/github/stars/IceClear/StableSR?style=social'>
            </a>
            <hr>
            <h4>Citation</h4>
            <pre style="white-space: pre-wrap; background: #a7a7a7; padding: 1em; border-radius: 5px;">
@article{wang2024exploiting,
  author  = {Wang, Jianyi and Yue, Zongsheng and Zhou, Shangchen and Chan, Kelvin C.K. and Loy, Chen Change},
  title   = {Exploiting Diffusion Prior for Real-World Image Super-Resolution},
  journal = {International Journal of Computer Vision},
  year    = {2024}
}
            </pre>
            <h4>License</h4>
            <p>This project is licensed under the <a rel="license" href="https://github.com/IceClear/StableSR/blob/main/LICENSE.txt">S-Lab License 1.0</a>. Redistribution and use for non-commercial purposes should follow this license.</p>
            <h4>Contact</h4>
            <p>If you have any questions, please feel free to reach out at <b>iceclearwjy@gmail.com</b>.</p>
            <div style="margin-top:1em">
                🤗 Find me:<br>
                <a href="https://twitter.com/Iceclearwjy">
                    <img src="https://img.shields.io/twitter/follow/Iceclearwjy?label=%40Iceclearwjy&style=social" alt="Twitter Follow">
                </a>
                <a href="https://github.com/IceClear">
                    <img src="https://img.shields.io/github/followers/IceClear?style=social" alt="GitHub Follow">
                </a>
            </div>
            <div style="text-align: center; margin-top:1em">
                <img src='https://visitor-badge.laobi.icu/badge?page_id=IceClear/StableSR' alt='visitors'>
            </div>
        </div>
        """
    )
    with gr.Row():
        with gr.Column():
            image = gr.Image(type="filepath", label="Input")
            upscale = gr.Number(value=1, label="Rescaling Factor")
            dec_w = gr.Slider(0, 1, value=0.5, step=0.01, label='CFW Fidelity')
            seed = gr.Number(value=42, label="Seed")
            model_type = gr.Dropdown(choices=["512", "768v"], value="512", label="Model")
            ddpm_steps = gr.Slider(10, 1000, value=200, step=1, label='DDPM Steps')
            colorfix_type = gr.Dropdown(choices=["none", "adain", "wavelet"], value="adain", label="Color Correction")
            run_btn = gr.Button("Run Inference")
        with gr.Column():
            output_image = gr.Image(type="numpy", label="Output")
            output_file = gr.File(label="Download the output")

    run_btn.click(
        fn=inference,
        inputs=[image, upscale, dec_w, seed, model_type, ddpm_steps, colorfix_type],
        outputs=[output_image, output_file]
    )
    gr.Examples(
        examples=[
            ['01.png', 4, 0.5, 42, "512", 200, "adain"],
            ['02.png', 4, 0.5, 42, "512", 200, "adain"],
            ['03.png', 4, 0.5, 42, "512", 200, "adain"],
            ['04.png', 4, 0.5, 42, "512", 200, "adain"],
            ['05.png', 4, 0.5, 42, "512", 200, "adain"],
        ],
        fn=inference,
        inputs=[image, upscale, dec_w, seed, model_type, ddpm_steps, colorfix_type],
        outputs=[output_image, output_file],
        cache_examples=True
    )

demo.queue()
demo.launch()