StableSR / app.py
Iceclear's picture
Update app.py
5c37840 verified
import sys
sys.path.append('StableSR')
import os
import cv2
import torch
import torch.nn.functional as F
import gradio as gr
import torchvision
from torchvision.transforms.functional import normalize
from ldm.util import instantiate_from_config
from torch import autocast
import PIL
import numpy as np
from pytorch_lightning import seed_everything
from contextlib import nullcontext
from omegaconf import OmegaConf
from PIL import Image
import copy
from scripts.wavelet_color_fix import wavelet_reconstruction, adaptive_instance_normalization
from scripts.util_image import ImageSpliterTh
from basicsr.utils.download_util import load_file_from_url
from einops import rearrange, repeat
from itertools import islice
# Download weights
# Maps a short model key to the URL of its checkpoint; each checkpoint is
# stored in the working directory under the last path component of its URL.
pretrain_model_url = {
    'stablesr_512': 'https://huggingface.co/Iceclear/StableSR/resolve/main/stablesr_000117.ckpt',
    'stablesr_768': 'https://huggingface.co/Iceclear/StableSR/resolve/main/stablesr_768v_000139.ckpt',
    'CFW': 'https://huggingface.co/Iceclear/StableSR/resolve/main/vqgan_cfw_00011.ckpt',
}
for k, url in pretrain_model_url.items():
    filename = url.rsplit("/", 1)[-1]
    # Checkpoints already present on disk are not fetched again.
    if os.path.exists('./' + filename):
        continue
    load_file_from_url(url=url, model_dir='./', progress=True, file_name=None)
# Download sample images
# (local file name, source URL) pairs used by the example gallery.
image_urls = [
    ('01.png', 'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet128/Lincoln.png'),
    ('02.png', 'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet128/oldphoto6.png'),
    ('03.png', 'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet128/comic2.png'),
    ('04.png', 'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet128/OST_120.png'),
    ('05.png', 'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet65/comic3.png'),
]
for fname, url in image_urls:
    # Skip samples that are already on disk so a restart does not fetch them
    # again, mirroring the existence check used for the model weights.
    if not os.path.exists(fname):
        torch.hub.download_url_to_file(url, fname)
def load_img(path):
    """Read an image file and return it as a ``(1, 3, H, W)`` float tensor.

    The image is converted to RGB and resized (Lanczos) so that both sides
    are rounded down to a multiple of 32. Pixel values are mapped from
    ``[0, 255]`` to ``[-1, 1]``.

    :param path: location of the image, passed to ``Image.open``.
    :return: a float32 ``torch.Tensor`` with values in ``[-1, 1]``.
    """
    picture = Image.open(path).convert("RGB")
    width, height = picture.size
    # Round each side down to the nearest multiple of 32.
    width = (width // 32) * 32
    height = (height // 32) * 32
    picture = picture.resize((width, height), resample=Image.LANCZOS)

    pixels = np.array(picture).astype(np.float32) / 255.0
    # HWC -> 1 x C x H x W
    batched = np.transpose(pixels[np.newaxis], (0, 3, 1, 2))
    tensor = torch.from_numpy(batched)
    return 2. * tensor - 1.
def space_timesteps(num_timesteps, section_counts):
    """
    Create a set of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.

    For example, if there's 300 timesteps and the section counts are [10,15,20]
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.

    If section_counts is a string starting with "ddim", then the fixed striding
    from the DDIM paper is used, and only one section is allowed.

    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    :raises ValueError: if no integer stride yields exactly N steps in the
                        "ddimN" case, or if a section has fewer original
                        steps than the count requested from it.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim"):])
            for i in range(1, num_timesteps):
                if len(range(0, num_timesteps, i)) == desired_count:
                    return set(range(0, num_timesteps, i))
            # Report the count that was asked for, not the length of the
            # original process.
            raise ValueError(
                f"cannot create exactly {desired_count} steps with an integer stride"
            )
        section_counts = [int(x) for x in section_counts.split(",")]

    # Split the original process into len(section_counts) contiguous
    # sections; the first `extra` sections get one additional step.
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}"
            )
        # Fractional stride so the picks run from the first to the last
        # index of the section.
        if section_count <= 1:
            frac_stride = 1
        else:
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)
def chunk(it, size):
    """Return an iterator over consecutive tuples of up to ``size`` items.

    Items are pulled from ``it`` in order; the last tuple may be shorter.
    Iteration ends as soon as an empty tuple would be produced.
    """
    source = iter(it)

    def take():
        # Next group of at most `size` items; () once `source` is exhausted.
        return tuple(islice(source, size))

    # Two-argument iter(): call `take` repeatedly until it returns ().
    return iter(take, ())
def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate ``config.model`` and load the weights stored in ``ckpt``.

    :param config: configuration object whose ``model`` entry is passed to
                   ``instantiate_from_config``.
    :param ckpt: path of a checkpoint containing a ``"state_dict"`` entry.
    :param verbose: when true, print any missing / unexpected keys reported
                    by the non-strict state-dict load.
    :return: the model, moved to CUDA and switched to eval mode.
    """
    print(f"Loading model from {ckpt}")
    checkpoint = torch.load(ckpt, map_location="cpu")
    if "global_step" in checkpoint:
        print(f"Global Step: {checkpoint['global_step']}")
    state_dict = checkpoint["state_dict"]

    model = instantiate_from_config(config.model)
    # strict=False: mismatched keys are reported instead of raising.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if verbose:
        if len(missing) > 0:
            print("missing keys:")
            print(missing)
        if len(unexpected) > 0:
            print("unexpected keys:")
            print(unexpected)

    model.cuda()
    model.eval()
    return model
# Load VQGAN model
device = torch.device("cuda")

# Build the autoencoder from its config, then load the CFW checkpoint
# non-strictly (key mismatches are tolerated).
vqgan_config = OmegaConf.load(
    "StableSR/configs/autoencoder/autoencoder_kl_64x64x4_resi.yaml"
)
vq_model = instantiate_from_config(vqgan_config.model)
vq_sd = torch.load(
    './vqgan_cfw_00011.ckpt',
    map_location='cpu',
)['state_dict']
vq_model.load_state_dict(vq_sd, strict=False)
vq_model.cuda()
vq_model.eval()

# Directory that receives the restored image written by inference().
os.makedirs('output', exist_ok=True)
def inference(image, upscale, dec_w, seed, model_type, ddpm_steps, colorfix_type):
    """Run a single prediction on the model.

    :param image: path of the input image (handed to ``load_img``), or
                  ``None`` when no image was supplied.
    :param upscale: factor by which the loaded image is bicubically resized
                    before restoration.
    :param dec_w: value assigned to ``vq_model.decoder.fusion_w``.
    :param seed: seed forwarded to ``seed_everything``.
    :param model_type: ``'512'`` selects the 512 config/checkpoint; any other
                       value selects the 768v config/checkpoint.
    :param ddpm_steps: number of sampling steps kept out of the 1000-step
                       schedule; coerced to ``int``.
    :param colorfix_type: ``'adain'`` or ``'wavelet'`` to apply the matching
                          colour correction; any other value applies none.
    :return: ``(restored_img, 'output/out.png')`` where ``restored_img`` is an
             ``H x W x C`` ``uint8`` array, or ``(None, None)`` when no image
             was supplied or an exception was raised while sampling.
    """
    # Without an input there is nothing to restore; return before paying for
    # a checkpoint load that could only end in a failure.
    if image is None:
        print('Global exception', 'no input image was provided')
        return None, None

    # UI widgets may hand over a float; the step count feeds range() in
    # space_timesteps and must be a real integer.
    ddpm_steps = int(ddpm_steps)

    vq_model.decoder.fusion_w = dec_w
    seed_everything(seed)

    if model_type == '512':
        config = OmegaConf.load("StableSR/configs/stableSRNew/v2-finetune_text_T_512.yaml")
        model = load_model_from_config(config, "./stablesr_000117.ckpt")
        min_size = 512
    else:
        config = OmegaConf.load("StableSR/configs/stableSRNew/v2-finetune_text_T_768v.yaml")
        model = load_model_from_config(config, "./stablesr_768v_000139.ckpt")
        min_size = 768

    model = model.to(device)
    model.configs = config

    # Register the full 1000-step linear schedule and keep copies of its
    # cumulative-alpha terms; they are needed by q_sample_respace after the
    # schedule is replaced by the respaced one below.
    model.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000,
                            linear_start=0.00085, linear_end=0.0120, cosine_s=8e-3)
    model.num_timesteps = 1000

    sqrt_alphas_cumprod = copy.deepcopy(model.sqrt_alphas_cumprod)
    sqrt_one_minus_alphas_cumprod = copy.deepcopy(model.sqrt_one_minus_alphas_cumprod)

    # Derive the betas of the shortened schedule from the retained timesteps.
    use_timesteps = set(space_timesteps(1000, [ddpm_steps]))
    last_alpha_cumprod = 1.0
    new_betas = []
    for i, alpha_cumprod in enumerate(model.alphas_cumprod):
        if i in use_timesteps:
            new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
            last_alpha_cumprod = alpha_cumprod
    new_betas = [beta.data.cpu().numpy() for beta in new_betas]

    model.register_schedule(given_betas=np.array(new_betas), timesteps=len(new_betas))
    model.num_timesteps = 1000
    model.ori_timesteps = sorted(use_timesteps)
    model = model.to(device)

    out_path = 'output/out.png'

    try:  # global try
        with torch.no_grad(), autocast("cuda"), model.ema_scope():
            init_image = load_img(image)
            init_image = F.interpolate(
                init_image,
                size=(int(init_image.size(-2) * upscale),
                      int(init_image.size(-1) * upscale)),
                mode='bicubic',
            )

            # Enlarge inputs whose shorter side is below min_size; `rescale`
            # records the factor so the result can be resized back afterwards.
            if init_image.size(-1) < min_size or init_image.size(-2) < min_size:
                ori_size = init_image.size()
                rescale = min_size * 1.0 / min(init_image.size(-2), init_image.size(-1))
                new_h = max(int(ori_size[-2] * rescale), min_size)
                new_w = max(int(ori_size[-1] * rescale), min_size)
                init_template = F.interpolate(
                    init_image,
                    size=(new_h, new_w),
                    mode='bicubic',
                )
            else:
                init_template = init_image
                rescale = 1

            init_template = init_template.clamp(-1, 1)
            assert init_template.size(-1) >= min_size
            assert init_template.size(-2) >= min_size

            init_template = init_template.type(torch.float16).to(device)

            # NOTE(review): with `or`, the tiled branch further down only
            # runs when BOTH sides exceed 1280 -- confirm `and` was not
            # intended here.
            if init_template.size(-1) <= 1280 or init_template.size(-2) <= 1280:
                init_latent_generator, enc_fea_lq = vq_model.encode(init_template)
                init_latent = model.get_first_stage_encoding(init_latent_generator)
                text_init = [''] * init_template.size(0)
                semantic_c = model.cond_stage_model(text_init)

                noise = torch.randn_like(init_latent)
                t = repeat(torch.tensor([999]), '1 -> b', b=init_image.size(0))
                t = t.to(device).long()
                x_T = model.q_sample_respace(
                    x_start=init_latent,
                    t=t,
                    sqrt_alphas_cumprod=sqrt_alphas_cumprod,
                    sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod,
                    noise=noise,
                )

                if init_template.size(-1) <= min_size and init_template.size(-2) <= min_size:
                    samples, _ = model.sample(
                        cond=semantic_c,
                        struct_cond=init_latent,
                        batch_size=init_template.size(0),
                        timesteps=ddpm_steps,
                        time_replace=ddpm_steps,
                        x_T=x_T,
                        return_intermediates=True,
                    )
                else:
                    samples, _ = model.sample_canvas(
                        cond=semantic_c,
                        struct_cond=init_latent,
                        batch_size=init_template.size(0),
                        timesteps=ddpm_steps,
                        time_replace=ddpm_steps,
                        x_T=x_T,
                        return_intermediates=True,
                        tile_size=int(min_size / 8),
                        tile_overlap=min_size // 16,
                        batch_size_sample=init_template.size(0),
                    )

                x_samples = vq_model.decode(samples * 1. / model.scale_factor, enc_fea_lq)
                if colorfix_type == 'adain':
                    x_samples = adaptive_instance_normalization(x_samples, init_template)
                elif colorfix_type == 'wavelet':
                    x_samples = wavelet_reconstruction(x_samples, init_template)
                x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
            else:
                # Process the image patch by patch and stitch the results.
                im_spliter = ImageSpliterTh(init_template, 1280, 1000, sf=1)
                for im_lq_pch, index_infos in im_spliter:
                    # move to latent space
                    init_latent = model.get_first_stage_encoding(
                        model.encode_first_stage(im_lq_pch)
                    )
                    text_init = [''] * init_latent.size(0)
                    semantic_c = model.cond_stage_model(text_init)

                    noise = torch.randn_like(init_latent)
                    # If you would like to start from the intermediate steps,
                    # you can add noise to LR to the specific steps.
                    t = repeat(torch.tensor([999]), '1 -> b', b=init_template.size(0))
                    t = t.to(device).long()
                    x_T = model.q_sample_respace(
                        x_start=init_latent,
                        t=t,
                        sqrt_alphas_cumprod=sqrt_alphas_cumprod,
                        sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod,
                        noise=noise,
                    )
                    samples, _ = model.sample_canvas(
                        cond=semantic_c,
                        struct_cond=init_latent,
                        batch_size=im_lq_pch.size(0),
                        timesteps=ddpm_steps,
                        time_replace=ddpm_steps,
                        x_T=x_T,
                        return_intermediates=True,
                        tile_size=int(min_size / 8),
                        tile_overlap=min_size // 16,
                        batch_size_sample=im_lq_pch.size(0),
                    )
                    _, enc_fea_lq = vq_model.encode(im_lq_pch)
                    x_samples = vq_model.decode(samples * 1. / model.scale_factor, enc_fea_lq)
                    if colorfix_type == 'adain':
                        x_samples = adaptive_instance_normalization(x_samples, im_lq_pch)
                    elif colorfix_type == 'wavelet':
                        x_samples = wavelet_reconstruction(x_samples, im_lq_pch)
                    im_spliter.update(x_samples, index_infos)
                x_samples = im_spliter.gather()
                x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

            # Undo the min_size enlargement so the output matches the
            # size of the upscaled input.
            if rescale > 1:
                x_samples = F.interpolate(
                    x_samples,
                    size=(int(init_image.size(-2)),
                          int(init_image.size(-1))),
                    mode='bicubic',
                )
                x_samples = x_samples.clamp(0, 1)

            x_sample = 255. * rearrange(x_samples[0].cpu().numpy(), 'c h w -> h w c')
            restored_img = x_sample.astype(np.uint8)
            Image.fromarray(restored_img).save(out_path)
            return restored_img, out_path
    except Exception as error:
        # Best effort for the UI: report the problem and return empty outputs.
        print('Global exception', error)
        return None, None
# Gradio UI
# Banner shown at the top of the page.
_HEADER_HTML = """
<div style="display: flex; justify-content: center; align-items: center; height: 40px;">
<img src="https://user-images.githubusercontent.com/22350795/236680126-0b1cdd62-d6fc-4620-b998-75ed6c31bf6f.png"
alt="StableSR logo" style='height:40px'>
</div>
<div style='text-align: center;'>
<h2>Exploiting Diffusion Prior for Real-World Image Super-Resolution</h2>
<p><strong>Official Gradio demo</strong> for <a href='https://github.com/IceClear/StableSR' target='_blank'>StableSR</a>.<br>
🔥 StableSR is a general image super-resolution algorithm for real-world and AIGC images.</p>
</div>
"""

# Citation, license, contact and badges.
_ABOUT_HTML = """
<div style="margin-top:1em">
<p>If StableSR is helpful, please help to ⭐ the <a href='https://github.com/IceClear/StableSR' target='_blank'>Github Repo</a>. Thanks!</p>
<a href='https://github.com/IceClear/StableSR' target='_blank'>
<img src='https://img.shields.io/github/stars/IceClear/StableSR?style=social'>
</a>
<hr>
<h4>Citation</h4>
<pre style="white-space: pre-wrap; background: #a7a7a7; padding: 1em; border-radius: 5px;">
@article{wang2024exploiting,
author = {Wang, Jianyi and Yue, Zongsheng and Zhou, Shangchen and Chan, Kelvin C.K. and Loy, Chen Change},
title = {Exploiting Diffusion Prior for Real-World Image Super-Resolution},
journal = {International Journal of Computer Vision},
year = {2024}
}
</pre>
<h4>License</h4>
<p>This project is licensed under <a rel="license" href="https://github.com/IceClear/StableSR/blob/main/LICENSE.txt">S-Lab License 1.0</a>. Redistribution and use for non-commercial purposes should follow this license.</p>
<h4>Contact</h4>
<p>If you have any questions, please feel free to reach out at <b>iceclearwjy@gmail.com</b>.</p>
<div style="margin-top:1em">
🤗 Find Me:<br>
<a href="https://twitter.com/Iceclearwjy">
<img src="https://img.shields.io/twitter/follow/Iceclearwjy?label=%40Iceclearwjy&style=social" alt="Twitter Follow">
</a>
<a href="https://github.com/IceClear">
<img src="https://img.shields.io/github/followers/IceClear?style=social" alt="Github Follow">
</a>
</div>
<div style="text-align: center; margin-top:1em">
<img src='https://visitor-badge.laobi.icu/badge?page_id=IceClear/StableSR' alt='visitors'>
</div>
</div>
"""

with gr.Blocks(title="Exploiting Diffusion Prior for Real-World Image Super-Resolution") as demo:
    gr.HTML(_HEADER_HTML)
    gr.HTML(_ABOUT_HTML)

    with gr.Row():
        # Left column: input image and sampling controls.
        with gr.Column():
            image = gr.Image(type="filepath", label="Input")
            upscale = gr.Number(value=1, label="Rescaling_Factor")
            dec_w = gr.Slider(0, 1, value=0.5, step=0.01, label='CFW_Fidelity')
            seed = gr.Number(value=42, label="Seeds")
            model_type = gr.Dropdown(choices=["512", "768v"], value="512", label="Model")
            ddpm_steps = gr.Slider(10, 1000, value=200, step=1, label='DDPM Steps')
            colorfix_type = gr.Dropdown(
                choices=["none", "adain", "wavelet"],
                value="adain",
                label="Color Correction",
            )
            run_btn = gr.Button("Run Inference")
        # Right column: restored image and downloadable file.
        with gr.Column():
            output_image = gr.Image(type="numpy", label="Output")
            output_file = gr.File(label="Download the output")

    # Component order matches the positional parameters of inference().
    demo_inputs = [image, upscale, dec_w, seed, model_type, ddpm_steps, colorfix_type]
    demo_outputs = [output_image, output_file]

    run_btn.click(fn=inference, inputs=demo_inputs, outputs=demo_outputs)

    # One example row per sample image 01.png .. 05.png, same settings each.
    gr.Examples(
        examples=[
            [f'{index:02d}.png', 4, 0.5, 42, "512", 200, "adain"]
            for index in range(1, 6)
        ],
        fn=inference,
        inputs=demo_inputs,
        outputs=demo_outputs,
        cache_examples=True,
    )

demo.queue()
demo.launch()