import sys
sys.path.append('StableSR')
import os
import cv2
import torch
import torch.nn.functional as F
import gradio as gr
import torchvision
from torchvision.transforms.functional import normalize
from ldm.util import instantiate_from_config
from torch import autocast
import PIL
import numpy as np
from pytorch_lightning import seed_everything
from contextlib import nullcontext
from omegaconf import OmegaConf
from PIL import Image
import copy
from scripts.wavelet_color_fix import wavelet_reconstruction, adaptive_instance_normalization
from scripts.util_image import ImageSpliterTh
from basicsr.utils.download_util import load_file_from_url
from einops import rearrange, repeat
from itertools import islice
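# `ldm.*` and `scripts.*` are imported from the StableSR repository added to
# sys.path above; `basicsr` provides the pretrained-weight download helper.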

# Download weights
pretrain_model_url = {
    'stablesr_512': 'https://huggingface.co/Iceclear/StableSR/resolve/main/stablesr_000117.ckpt',
    'stablesr_768': 'https://huggingface.co/Iceclear/StableSR/resolve/main/stablesr_768v_000139.ckpt',
    'CFW': 'https://huggingface.co/Iceclear/StableSR/resolve/main/vqgan_cfw_00011.ckpt',
}

for k, url in pretrain_model_url.items():
    filename = url.split("/")[-1]
    if not os.path.exists(f'./{filename}'):
        load_file_from_url(url=url, model_dir='./', progress=True, file_name=None)

# Download sample images
image_urls = [
    ('01.png', 'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet128/Lincoln.png'),
    ('02.png', 'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet128/oldphoto6.png'),
    ('03.png', 'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet128/comic2.png'),
    ('04.png', 'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet128/OST_120.png'),
    ('05.png', 'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet65/comic3.png'),
]

for fname, url in image_urls:
    torch.hub.download_url_to_file(url, fname)

def load_img(path):
    image = Image.open(path).convert("RGB")
    w, h = image.size
    w, h = map(lambda x: x - x % 32, (w, h))
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2.*image - 1.
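
# Note: load_img trims each side down to a multiple of 32 and maps pixels from
# [0, 1] to [-1, 1]; e.g. a 515x300 RGB image becomes a (1, 3, 288, 512) tensor.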

def space_timesteps(num_timesteps, section_counts):
	"""
	Create a list of timesteps to use from an original diffusion process,
	given the number of timesteps we want to take from equally-sized portions
	of the original process.

	For example, if there are 300 timesteps and the section counts are
	[10, 15, 20], then the first 100 timesteps are strided down to 10 steps,
	the second 100 down to 15 steps, and the final 100 down to 20 steps.

	If the stride is a string starting with "ddim", the fixed striding from the
	DDIM paper is used, and only one section is allowed.

	:param num_timesteps: the number of diffusion steps in the original
		process to divide up.
	:param section_counts: either a list of numbers, or a string containing
		comma-separated numbers, indicating the step count per section. As a
		special case, use "ddimN", where N is the number of steps, to use the
		striding from the DDIM paper.
	:return: a set of diffusion steps from the original process to use.
	"""
	if isinstance(section_counts, str):
		if section_counts.startswith("ddim"):
			desired_count = int(section_counts[len("ddim"):])
			for i in range(1, num_timesteps):
				if len(range(0, num_timesteps, i)) == desired_count:
					return set(range(0, num_timesteps, i))
			raise ValueError(
				f"cannot create exactly {num_timesteps} steps with an integer stride"
			)
		section_counts = [int(x) for x in section_counts.split(",")]   #[250,]
	size_per = num_timesteps // len(section_counts)
	extra = num_timesteps % len(section_counts)
	start_idx = 0
	all_steps = []
	for i, section_count in enumerate(section_counts):
		size = size_per + (1 if i < extra else 0)
		if size < section_count:
			raise ValueError(
				f"cannot divide section of {size} steps into {section_count}"
			)
		if section_count <= 1:
			frac_stride = 1
		else:
			frac_stride = (size - 1) / (section_count - 1)
		cur_idx = 0.0
		taken_steps = []
		for _ in range(section_count):
			taken_steps.append(start_idx + round(cur_idx))
			cur_idx += frac_stride
		all_steps += taken_steps
		start_idx += size
	return set(all_steps)
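
# Example: respacing 1000 training steps into 4 sampling steps picks evenly
# strided indices, e.g. space_timesteps(1000, [4]) -> {0, 333, 666, 999}.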

def chunk(it, size):
	it = iter(it)
	return iter(lambda: tuple(islice(it, size)), ())

def load_model_from_config(config, ckpt, verbose=False):
	print(f"Loading model from {ckpt}")
	pl_sd = torch.load(ckpt, map_location="cpu")
	if "global_step" in pl_sd:
		print(f"Global Step: {pl_sd['global_step']}")
	sd = pl_sd["state_dict"]
	model = instantiate_from_config(config.model)
	m, u = model.load_state_dict(sd, strict=False)
	if len(m) > 0 and verbose:
		print("missing keys:")
		print(m)
	if len(u) > 0 and verbose:
		print("unexpected keys:")
		print(u)

	model.cuda()
	model.eval()
	return model

# Load the CFW (Controllable Feature Wrapping) VQGAN autoencoder used for encoding/decoding
device = torch.device("cuda")
vqgan_config = OmegaConf.load("StableSR/configs/autoencoder/autoencoder_kl_64x64x4_resi.yaml")
vq_model = instantiate_from_config(vqgan_config.model)
vq_sd = torch.load('./vqgan_cfw_00011.ckpt', map_location='cpu')['state_dict']
vq_model.load_state_dict(vq_sd, strict=False)
vq_model.cuda().eval()

os.makedirs('output', exist_ok=True)

def inference(image, upscale, dec_w, seed, model_type, ddpm_steps, colorfix_type):
	"""Run a single prediction on the model"""
	precision_scope = autocast
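	# dec_w sets the CFW fusion weight of the VQGAN decoder: larger values
	# favour fidelity to the input, smaller values favour generative quality.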
	vq_model.decoder.fusion_w = dec_w
	seed_everything(seed)

	if model_type == '512':
		config = OmegaConf.load("StableSR/configs/stableSRNew/v2-finetune_text_T_512.yaml")
		model = load_model_from_config(config, "./stablesr_000117.ckpt")
		min_size = 512
	else:
		config = OmegaConf.load("StableSR/configs/stableSRNew/v2-finetune_text_T_768v.yaml")
		model = load_model_from_config(config, "./stablesr_768v_000139.ckpt")
		min_size = 768

	model = model.to(device)
	model.configs = config
	model.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000,
							linear_start=0.00085, linear_end=0.0120, cosine_s=8e-3)
	model.num_timesteps = 1000

	sqrt_alphas_cumprod = copy.deepcopy(model.sqrt_alphas_cumprod)
	sqrt_one_minus_alphas_cumprod = copy.deepcopy(model.sqrt_one_minus_alphas_cumprod)

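	# Respace the 1000-step training schedule down to `ddpm_steps` sampling
	# steps and recompute the betas for this shorter schedule; the deep copies
	# above preserve the original statistics for q_sample_respace below.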
	use_timesteps = set(space_timesteps(1000, [ddpm_steps]))
	last_alpha_cumprod = 1.0
	new_betas = []
	timestep_map = []
	for i, alpha_cumprod in enumerate(model.alphas_cumprod):
		if i in use_timesteps:
			new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
			last_alpha_cumprod = alpha_cumprod
			timestep_map.append(i)
	new_betas = [beta.data.cpu().numpy() for beta in new_betas]
	model.register_schedule(given_betas=np.array(new_betas), timesteps=len(new_betas))
	model.num_timesteps = 1000
	model.ori_timesteps = list(use_timesteps)
	model.ori_timesteps.sort()
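	# num_timesteps stays at 1000 so timestep indices such as t=999 still refer
	# to the original schedule; ori_timesteps records the respaced subset of
	# steps that sampling will actually use.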
	model = model.to(device)

	try:  # global try/except: return (None, None) to the UI on any failure
			with torch.no_grad():
				with precision_scope("cuda"):
					with model.ema_scope():
						init_image = load_img(image)
						init_image = F.interpolate(
									init_image,
									size=(int(init_image.size(-2)*upscale),
											int(init_image.size(-1)*upscale)),
									mode='bicubic',
									)

						if init_image.size(-1) < min_size or init_image.size(-2) < min_size:
							ori_size = init_image.size()
							rescale = min_size * 1.0 / min(init_image.size(-2), init_image.size(-1))
							new_h = max(int(ori_size[-2]*rescale), min_size)
							new_w = max(int(ori_size[-1]*rescale), min_size)
							init_template = F.interpolate(
										init_image,
										size=(new_h, new_w),
										mode='bicubic',
										)
						else:
							init_template = init_image
							rescale = 1
						init_template = init_template.clamp(-1, 1)
						assert init_template.size(-1) >= min_size
						assert init_template.size(-2) >= min_size

						init_template = init_template.type(torch.float16).to(device)

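						# Images with at least one side <= 1280 px are restored in a single
						# pass (sample_canvas tiles in latent space once they exceed
						# min_size); larger images are split into overlapping patches below.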
						if init_template.size(-1) <= 1280 or init_template.size(-2) <= 1280:
							init_latent_generator, enc_fea_lq = vq_model.encode(init_template)
							init_latent = model.get_first_stage_encoding(init_latent_generator)
							text_init = ['']*init_template.size(0)
							semantic_c = model.cond_stage_model(text_init)

							noise = torch.randn_like(init_latent)

							t = repeat(torch.tensor([999]), '1 -> b', b=init_image.size(0))
							t = t.to(device).long()
							x_T = model.q_sample_respace(x_start=init_latent, t=t, sqrt_alphas_cumprod=sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod, noise=noise)

							if init_template.size(-1)<= min_size and init_template.size(-2) <= min_size:
								samples, _ = model.sample(cond=semantic_c, struct_cond=init_latent, batch_size=init_template.size(0), timesteps=ddpm_steps, time_replace=ddpm_steps, x_T=x_T, return_intermediates=True)
							else:
								samples, _ = model.sample_canvas(cond=semantic_c, struct_cond=init_latent, batch_size=init_template.size(0), timesteps=ddpm_steps, time_replace=ddpm_steps, x_T=x_T, return_intermediates=True, tile_size=int(min_size/8), tile_overlap=min_size//16, batch_size_sample=init_template.size(0))
							x_samples = vq_model.decode(samples * 1. / model.scale_factor, enc_fea_lq)
							if colorfix_type == 'adain':
								x_samples = adaptive_instance_normalization(x_samples, init_template)
							elif colorfix_type == 'wavelet':
								x_samples = wavelet_reconstruction(x_samples, init_template)
							x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
						else:
							im_spliter = ImageSpliterTh(init_template, 1280, 1000, sf=1)
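							# 1280-px patches with stride 1000 (280-px overlap); restored
							# patches are merged back together by im_spliter.gather().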
							for im_lq_pch, index_infos in im_spliter:
								init_latent = model.get_first_stage_encoding(model.encode_first_stage(im_lq_pch))  # move to latent space
								text_init = ['']*init_latent.size(0)
								semantic_c = model.cond_stage_model(text_init)
								noise = torch.randn_like(init_latent)
								# To start from an intermediate step instead, add noise to the LR
								# latent only up to that step rather than to t=999.
								t = repeat(torch.tensor([999]), '1 -> b', b=init_template.size(0))
								t = t.to(device).long()
								x_T = model.q_sample_respace(x_start=init_latent, t=t, sqrt_alphas_cumprod=sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod, noise=noise)
								# x_T = noise
								samples, _ = model.sample_canvas(cond=semantic_c, struct_cond=init_latent, batch_size=im_lq_pch.size(0), timesteps=ddpm_steps, time_replace=ddpm_steps, x_T=x_T, return_intermediates=True, tile_size=int(min_size/8), tile_overlap=min_size//16, batch_size_sample=im_lq_pch.size(0))
								_, enc_fea_lq = vq_model.encode(im_lq_pch)
								x_samples = vq_model.decode(samples * 1. / model.scale_factor, enc_fea_lq)
								if colorfix_type == 'adain':
									x_samples = adaptive_instance_normalization(x_samples, im_lq_pch)
								elif colorfix_type == 'wavelet':
									x_samples = wavelet_reconstruction(x_samples, im_lq_pch)
								im_spliter.update(x_samples, index_infos)
							x_samples = im_spliter.gather()
							x_samples = torch.clamp((x_samples+1.0)/2.0, min=0.0, max=1.0)

			if rescale > 1:
				x_samples = F.interpolate(
							x_samples,
							size=(int(init_image.size(-2)),
									int(init_image.size(-1))),
							mode='bicubic',
							)
				x_samples = x_samples.clamp(0, 1)
			x_sample = 255. * rearrange(x_samples[0].cpu().numpy(), 'c h w -> h w c')
			restored_img = x_sample.astype(np.uint8)
			Image.fromarray(restored_img).save('output/out.png')

			return restored_img, f'output/out.png'
	except Exception as error:
		print('Global exception', error)
		return None, None

# Gradio UI
with gr.Blocks(title="Exploiting Diffusion Prior for Real-World Image Super-Resolution") as demo:
    gr.HTML(
        """
        <div style="display: flex; justify-content: center; align-items: center; height: 40px;">
          <img src="https://user-images.githubusercontent.com/22350795/236680126-0b1cdd62-d6fc-4620-b998-75ed6c31bf6f.png"
               alt="StableSR logo" style='height:40px'>
        </div>
        <div style='text-align: center;'>
            <h2>Exploiting Diffusion Prior for Real-World Image Super-Resolution</h2>
            <p><strong>Official Gradio demo</strong> for <a href='https://github.com/IceClear/StableSR' target='_blank'>StableSR</a>.<br>
            🔥 StableSR is a general image super-resolution algorithm for real-world and AIGC images.</p>
        </div>
        """
    )

    gr.HTML(
        """
        <div style="margin-top:1em">
        <p>If StableSR is helpful, please help to ⭐ the <a href='https://github.com/IceClear/StableSR' target='_blank'>Github Repo</a>. Thanks!</p>
        <a href='https://github.com/IceClear/StableSR' target='_blank'>
            <img src='https://img.shields.io/github/stars/IceClear/StableSR?style=social'>
        </a>
        <hr>
        <h4>Citation</h4>
        <pre style="white-space: pre-wrap; background: #a7a7a7; padding: 1em; border-radius: 5px;">
            @article{wang2024exploiting,
              author = {Wang, Jianyi and Yue, Zongsheng and Zhou, Shangchen and Chan, Kelvin C.K. and Loy, Chen Change},
              title = {Exploiting Diffusion Prior for Real-World Image Super-Resolution},
              journal = {International Journal of Computer Vision},
              year = {2024}
            }
        </pre>
        <h4>License</h4>
        <p>This project is licensed under <a rel="license" href="https://github.com/IceClear/StableSR/blob/main/LICENSE.txt">S-Lab License 1.0</a>. Redistribution and use for non-commercial purposes should follow this license.</p>

        <h4>Contact</h4>
        <p>If you have any questions, please feel free to reach out at <b>iceclearwjy@gmail.com</b>.</p>

        <div style="margin-top:1em">
            🤗 Find Me:<br>
            <a href="https://twitter.com/Iceclearwjy">
                <img src="https://img.shields.io/twitter/follow/Iceclearwjy?label=%40Iceclearwjy&style=social" alt="Twitter Follow">
            </a>
            <a href="https://github.com/IceClear">
                <img src="https://img.shields.io/github/followers/IceClear?style=social" alt="Github Follow">
            </a>
        </div>
        <div style="text-align: center; margin-top:1em">
            <img src='https://visitor-badge.laobi.icu/badge?page_id=IceClear/StableSR' alt='visitors'>
        </div>
        </div>
        """
    )


    with gr.Row():
        with gr.Column():
            image = gr.Image(type="filepath", label="Input")
            upscale = gr.Number(value=1, label="Rescaling Factor")
            dec_w = gr.Slider(0, 1, value=0.5, step=0.01, label='CFW Fidelity')
            seed = gr.Number(value=42, label="Seed")
            model_type = gr.Dropdown(choices=["512", "768v"], value="512", label="Model")
            ddpm_steps = gr.Slider(10, 1000, value=200, step=1, label='DDPM Steps')
            colorfix_type = gr.Dropdown(choices=["none", "adain", "wavelet"], value="adain", label="Color Correction")
            run_btn = gr.Button("Run Inference")

        with gr.Column():
            output_image = gr.Image(type="numpy", label="Output")
            output_file = gr.File(label="Download the output")

    run_btn.click(
        fn=inference,
        inputs=[image, upscale, dec_w, seed, model_type, ddpm_steps, colorfix_type],
        outputs=[output_image, output_file]
    )

    gr.Examples(
        examples=[
            ['01.png', 4, 0.5, 42, "512", 200, "adain"],
            ['02.png', 4, 0.5, 42, "512", 200, "adain"],
            ['03.png', 4, 0.5, 42, "512", 200, "adain"],
            ['04.png', 4, 0.5, 42, "512", 200, "adain"],
            ['05.png', 4, 0.5, 42, "512", 200, "adain"]
        ],
        fn=inference,
        inputs=[image, upscale, dec_w, seed, model_type, ddpm_steps, colorfix_type],
        outputs=[output_image, output_file],
        cache_examples=True
    )

demo.queue()
demo.launch()