Xingqian Xu committed
Commit 2fbcf51 · 0 Parent(s)

New app first commit

This view is limited to 50 files because the commit contains too many changes.

Files changed (50)
  1. .gitattributes +36 -0
  2. .gitignore +9 -0
  3. README.md +14 -0
  4. app.py +1083 -0
  5. assets/demo/mcg_example/e0i0.jpg +0 -0
  6. assets/demo/mcg_example/e0i1.jpg +0 -0
  7. assets/demo/mcg_example/e0i2.jpg +0 -0
  8. assets/demo/misc/mask_inst1.gif +3 -0
  9. assets/demo/misc/mask_inst2.gif +3 -0
  10. assets/demo/misc/mask_inst3.gif +3 -0
  11. assets/demo/misc/noimage.jpg +0 -0
  12. assets/demo/reg_example/benz.jpg +0 -0
  13. assets/demo/reg_example/boy_and_girl.jpg +0 -0
  14. assets/demo/reg_example/church.jpg +0 -0
  15. assets/demo/reg_example/firework.jpg +0 -0
  16. assets/demo/reg_example/ghibli.jpg +0 -0
  17. assets/demo/reg_example/horse.jpg +0 -0
  18. assets/demo/reg_example/house_by_lake.jpg +0 -0
  19. assets/demo/reg_example/matisse.jpg +0 -0
  20. assets/demo/reg_example/night_light.jpg +0 -0
  21. assets/demo/reg_example/noimage.jpg +0 -0
  22. assets/demo/reg_example/paris.jpg +0 -0
  23. assets/demo/reg_example/penguin.jpg +0 -0
  24. assets/demo/reg_example/san_diego.jpg +0 -0
  25. assets/demo/reg_example/scream.jpg +0 -0
  26. assets/demo/reg_example/space.jpg +0 -0
  27. assets/demo/reg_example/tiger.jpg +0 -0
  28. assets/demo/reg_example/train.jpg +0 -0
  29. assets/demo/reg_example/vermeer.jpg +0 -0
  30. assets/demo/tcg_example/e0i0.jpg +0 -0
  31. assets/demo/tcg_example/e0i1.jpg +0 -0
  32. assets/demo/tcg_example/e1i0.jpg +0 -0
  33. assets/demo/tcg_example/e1i1.jpg +0 -0
  34. assets/demo/tcg_example/e2i0.jpg +0 -0
  35. assets/figures/share_instruction.png +0 -0
  36. configs/model/autokl.yaml +23 -0
  37. configs/model/clip.yaml +13 -0
  38. configs/model/openai_unet.yaml +96 -0
  39. configs/model/optimus.yaml +102 -0
  40. configs/model/vd.yaml +29 -0
  41. cusomized_gradio_blocks.py +271 -0
  42. gradio_cached_examples/12/Image Result/23645e03-6435-4819-a746-2840be976419/captions.json +1 -0
  43. gradio_cached_examples/12/Image Result/23645e03-6435-4819-a746-2840be976419/tmp0m_lns_xtd2zm06b.png +0 -0
  44. gradio_cached_examples/12/Image Result/23645e03-6435-4819-a746-2840be976419/tmp9xugbhobbnp5ds0r.png +0 -0
  45. gradio_cached_examples/12/log.csv +2 -0
  46. lib/__init__.py +0 -0
  47. lib/cfg_helper.py +612 -0
  48. lib/cfg_holder.py +28 -0
  49. lib/log_service.py +166 -0
  50. lib/model_zoo/__init__.py +4 -0
.gitattributes ADDED
@@ -0,0 +1,36 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.gif filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,9 @@
+ __pycache__
+ .vscode/
+ src/
+ data/
+ data
+ log/
+ log
+ pretrained/
+ pretrained
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: Versatile Diffusion
+ emoji:
+ colorFrom: blue
+ colorTo: purple
+ sdk: gradio
+ sdk_version: 3.9.1
+ app_file: app.py
+ pinned: false
+ license: mit
+ python_version: 3.8.5
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
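The front matter above is what Hugging Face Spaces reads to build the runtime: app.py is the entry point, gradio is pinned to 3.9.1, and Python to 3.8.5. A minimal sanity check for a matching local environment could look like the following (an illustrative sketch; only the version numbers come from the README):

    import sys
    import gradio as gr

    # the versions below are the ones pinned in the front matter above
    assert gr.__version__ == "3.9.1", "this Space pins sdk_version 3.9.1"
    assert sys.version_info[:2] == (3, 8), "this Space pins python_version 3.8.5"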
app.py ADDED
@@ -0,0 +1,1083 @@
1
+ ################################################################################
2
+ # Copyright (C) 2023 Xingqian Xu - All Rights Reserved #
3
+ # #
4
+ # Please visit Versatile Diffusion's arXiv paper for more details, link at #
5
+ # arxiv.org/abs/2211.08332 #
6
+ # #
7
+ # Besides, this work is also inspired by many established techniques including:#
8
+ # Denoising Diffusion Probabilistic Model; Denoising Diffusion Implicit Model; #
9
+ # Latent Diffusion Model; Stable Diffusion; Stable Diffusion - Img2Img; Stable #
10
+ # Diffusion - Variation; ImageMixer; DreamBooth; Stable Diffusion - Lora; More #
11
+ # Control for Free; Prompt-to-Prompt; #
12
+ # #
13
+ ################################################################################
14
+
15
+ import gradio as gr
16
+ import os
17
+ import PIL
18
+ from PIL import Image
19
+ from pathlib import Path
20
+ import numpy as np
21
+ import numpy.random as npr
22
+ from contextlib import nullcontext
23
+ import types
24
+
25
+ import torch
26
+ import torchvision.transforms as tvtrans
27
+ from lib.cfg_helper import model_cfg_bank
28
+ from lib.model_zoo import get_model
29
+ from cusomized_gradio_blocks import create_myexamples, customized_as_example, customized_postprocess
30
+
31
+ n_sample_image = 2
32
+ n_sample_text = 4
33
+ cache_examples = True
34
+
35
+ from lib.model_zoo.ddim import DDIMSampler
36
+
37
+ ##########
38
+ # helper #
39
+ ##########
40
+
41
+ def highlight_print(info):
42
+ print('')
43
+ print(''.join(['#']*(len(info)+4)))
44
+ print('# '+info+' #')
45
+ print(''.join(['#']*(len(info)+4)))
46
+ print('')
47
+
48
+ def decompose(x, q=20, niter=100):
49
+ x_mean = x.mean(-1, keepdim=True)
50
+ x_input = x - x_mean
51
+ u, s, v = torch.pca_lowrank(x_input, q=q, center=False, niter=niter)
52
+ ss = torch.stack([torch.diag(si) for si in s])
53
+ x_lowrank = torch.bmm(torch.bmm(u, ss), torch.permute(v, [0, 2, 1]))
54
+ x_remain = x_input - x_lowrank
55
+ return u, s, v, x_mean, x_remain
56
+
57
+ class adjust_rank(object):
58
+ def __init__(self, max_drop_rank=[1, 5], q=20):
59
+ self.max_semantic_drop_rank = max_drop_rank[0]
60
+ self.max_style_drop_rank = max_drop_rank[1]
61
+ self.q = q
62
+
63
+ def t2y0_semf_wrapper(t0, y00, t1, y01):
64
+ return lambda t: (np.exp((t-0.5)*2)-t0)/(t1-t0)*(y01-y00)+y00
65
+ t0, y00 = np.exp((0 -0.5)*2), -self.max_semantic_drop_rank
66
+ t1, y01 = np.exp((0.5-0.5)*2), 1
67
+ self.t2y0_semf = t2y0_semf_wrapper(t0, y00, t1, y01)
68
+
69
+ def x2y_semf_wrapper(x0, x1, y1):
70
+ return lambda x, y0: (x-x0)/(x1-x0)*(y1-y0)+y0
71
+ x0 = 0
72
+ x1, y1 = self.max_semantic_drop_rank+1, 1
73
+ self.x2y_semf = x2y_semf_wrapper(x0, x1, y1)
74
+
75
+ def t2y0_styf_wrapper(t0, y00, t1, y01):
76
+ return lambda t: (np.exp((t-0.5)*2)-t0)/(t1-t0)*(y01-y00)+y00
77
+ t0, y00 = np.exp((1 -0.5)*2), -(q-self.max_style_drop_rank)
78
+ t1, y01 = np.exp((0.5-0.5)*2), 1
79
+ self.t2y0_styf = t2y0_styf_wrapper(t0, y00, t1, y01)
80
+
81
+ def x2y_styf_wrapper(x0, x1, y1):
82
+ return lambda x, y0: (x-x0)/(x1-x0)*(y1-y0)+y0
83
+ x0 = q-1
84
+ x1, y1 = self.max_style_drop_rank-1, 1
85
+ self.x2y_styf = x2y_styf_wrapper(x0, x1, y1)
86
+
87
+ def __call__(self, x, lvl):
88
+ if lvl == 0.5:
89
+ return x
90
+
91
+ if x.dtype == torch.float16:
92
+ fp16 = True
93
+ x = x.float()
94
+ else:
95
+ fp16 = False
96
+ std_save = x.std(axis=[-2, -1])
97
+
98
+ u, s, v, x_mean, x_remain = decompose(x, q=self.q)
99
+
100
+ if lvl < 0.5:
101
+ assert lvl>=0
102
+ for xi in range(0, self.max_semantic_drop_rank+1):
103
+ y0 = self.t2y0_semf(lvl)
104
+ yi = self.x2y_semf(xi, y0)
105
+ yi = 0 if yi<0 else yi
106
+ s[:, xi] *= yi
107
+
108
+ elif lvl > 0.5:
109
+ assert lvl <= 1
110
+ for xi in range(self.max_style_drop_rank, self.q):
111
+ y0 = self.t2y0_styf(lvl)
112
+ yi = self.x2y_styf(xi, y0)
113
+ yi = 0 if yi<0 else yi
114
+ s[:, xi] *= yi
115
+ x_remain = 0
116
+
117
+ ss = torch.stack([torch.diag(si) for si in s])
118
+ x_lowrank = torch.bmm(torch.bmm(u, ss), torch.permute(v, [0, 2, 1]))
119
+ x_new = x_lowrank + x_mean + x_remain
120
+
121
+ std_new = x_new.std(axis=[-2, -1])
122
+ x_new = x_new / std_new * std_save
123
+
124
+ if fp16:
125
+ x_new = x_new.half()
126
+
127
+ return x_new
128
+
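# Illustrative usage of decompose/adjust_rank (a sketch, not part of the committed file):
# the "Focus (Semantic -- Style)" sliders below pass their value in as `lvl`, and the
# context embedding keeps its shape while its principal components are re-weighted.
#
#   ctx = torch.randn(1, 257, 768)                    # assumed CLIP image-context shape
#   adjust = adjust_rank(max_drop_rank=[1, 5], q=20)
#   mid = adjust(ctx, lvl=0.5)    # 0.5 returns the input unchanged
#   sem = adjust(ctx, lvl=0.0)    # lvl < 0.5 attenuates the leading principal components
#   sty = adjust(ctx, lvl=1.0)    # lvl > 0.5 attenuates trailing components and drops the residual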
129
+ def remove_duplicate_word(tx):
130
+ def combine_words(input, length):
131
+ combined_inputs = []
132
+ if len(splitted_input)>1:
133
+ for i in range(len(input)-1):
134
+ combined_inputs.append(input[i]+" "+last_word_of(splitted_input[i+1],length)) #add the last word of the right-neighbour (overlapping) sequence (before it has expanded), which is the next word in the original sentence
135
+ return combined_inputs, length+1
136
+
137
+ def remove_duplicates(input, length):
138
+ bool_broke=False #this means we didn't find any duplicates here
139
+ for i in range(len(input) - length):
140
+ if input[i]==input[i + length]: #found a duplicate piece of sentence!
141
+ for j in range(0, length): #remove the overlapping sequences in reverse order
142
+ del input[i + length - j]
143
+ bool_broke = True
144
+ break #break the for loop as the loop length does not matches the length of splitted_input anymore as we removed elements
145
+ if bool_broke:
146
+ return remove_duplicates(input, length) #if we found a duplicate, look for another duplicate of the same length
147
+ return input
148
+
149
+ def last_word_of(input, length):
150
+ splitted = input.split(" ")
151
+ if len(splitted)==0:
152
+ return input
153
+ else:
154
+ return splitted[length-1]
155
+
156
+ def split_and_puncsplit(text):
157
+ tx = text.split(" ")
158
+ txnew = []
159
+ for txi in tx:
160
+ txqueue=[]
161
+ while True:
162
+ if txi[0] in '([{':
163
+ txqueue.extend([txi[:1], '<puncnext>'])
164
+ txi = txi[1:]
165
+ if len(txi) == 0:
166
+ break
167
+ else:
168
+ break
169
+ txnew += txqueue
170
+ txstack=[]
171
+ if len(txi) == 0:
172
+ continue
173
+ while True:
174
+ if txi[-1] in '?!.,:;}])':
175
+ txstack = ['<puncnext>', txi[-1:]] + txstack
176
+ txi = txi[:-1]
177
+ if len(txi) == 0:
178
+ break
179
+ else:
180
+ break
181
+ if len(txi) != 0:
182
+ txnew += [txi]
183
+ txnew += txstack
184
+ return txnew
185
+
186
+ if tx == '':
187
+ return tx
188
+
189
+ splitted_input = split_and_puncsplit(tx)
190
+ word_length = 1
191
+ intermediate_output = False
192
+ while len(splitted_input)>1:
193
+ splitted_input = remove_duplicates(splitted_input, word_length)
194
+ if len(splitted_input)>1:
195
+ splitted_input, word_length = combine_words(splitted_input, word_length)
196
+ if intermediate_output:
197
+ print(splitted_input)
198
+ print(word_length)
199
+ output = splitted_input[0]
200
+ output = output.replace(' <puncnext> ', '')
201
+ return output
202
+
203
+ def get_instruction(mode):
204
+ t2i_instruction = ["Generate image from text prompt."]
205
+ i2i_instruction = ["Generate image conditioned on reference image.",]
206
+ i2t_instruction = ["Generate text from reference image. "]
207
+ t2t_instruction = ["Generate text from reference text prompt. "]
208
+ dcg_instruction = ["Generate image conditioned on both text and image."]
209
+ tcg_instruction = ["Generate image conditioned on text and up to two images."]
210
+ mcg_instruction = ["Generate image from multiple contexts."]
211
+
212
+ if mode == "Text-to-Image":
213
+ return '\n'.join(t2i_instruction)
214
+ elif mode == "Image-Variation":
215
+ return '\n'.join(i2i_instruction)
216
+ elif mode == "Image-to-Text":
217
+ return '\n'.join(i2t_instruction)
218
+ elif mode == "Text-Variation":
219
+ return '\n'.join(t2t_instruction)
220
+ elif mode == "Dual-Context":
221
+ return '\n'.join(dcg_instruction)
222
+ elif mode == "Triple-Context":
223
+ return '\n'.join(tcg_instruction)
224
+ elif mode == "Multi-Context":
225
+ return '\n'.join(mcg_instruction)
226
+ else:
227
+ assert False
228
+
229
+ ########
230
+ # main #
231
+ ########
232
+ class vd_dummy(object):
233
+ def __init__(self, *args, **kwarg):
234
+ self.which = 'Vdummy'
235
+ def inference_t2i(self, *args, **kwarg): pass
236
+ def inference_i2i(self, *args, **kwarg): pass
237
+ def inference_i2t(self, *args, **kwarg): pass
238
+ def inference_t2t(self, *args, **kwarg): pass
239
+ def inference_dcg(self, *args, **kwarg): pass
240
+ def inference_tcg(self, *args, **kwarg): pass
241
+ def inference_mcg(self, *args, **kwarg):
242
+ return None, None
243
+
244
+ class vd_inference(object):
245
+ def __init__(self, fp16=False, which='v2.0'):
246
+ highlight_print(which)
247
+ self.which = which
248
+
249
+ if self.which == 'v1.0':
250
+ cfgm = model_cfg_bank()('vd_four_flow_v1-0')
251
+ else:
252
+ assert False, 'Model type not supported'
253
+ net = get_model()(cfgm)
254
+
255
+ if self.which == 'v1.0':
256
+ sd = torch.load('pretrained/vd-four-flow-v1-0.pth', map_location='cpu')
257
+ net.load_state_dict(sd, strict=False)
258
+
259
+ if fp16:
260
+ highlight_print('Running in FP16')
261
+ if self.which == 'v1.0':
262
+ net.ctx['text'].fp16 = True
263
+ net.ctx['image'].fp16 = True
264
+ net = net.half()
265
+ self.dtype = torch.float16
266
+ else:
267
+ self.dtype = torch.float32
268
+
269
+ self.use_cuda = torch.cuda.is_available()
270
+ if self.use_cuda:
271
+ net.to('cuda')
272
+ self.net = net
273
+ self.sampler = DDIMSampler(net)
274
+
275
+ self.output_dim = [512, 512]
276
+ self.n_sample_image = n_sample_image
277
+ self.n_sample_text = n_sample_text
278
+ self.ddim_steps = 50
279
+ self.ddim_eta = 0.0
280
+ self.scale_textto = 7.5
281
+ self.image_latent_dim = 4
282
+ self.text_latent_dim = 768
283
+ self.text_temperature = 1
284
+
285
+ if which == 'v1.0':
286
+ self.adjust_rank_f = adjust_rank(max_drop_rank=[1, 5], q=20)
287
+ self.scale_imgto = 7.5
288
+ self.disentanglement_noglobal = True
289
+
290
+ def inference_t2i(self, text, seed):
291
+ n_samples = self.n_sample_image
292
+ scale = self.scale_textto
293
+ sampler = self.sampler
294
+ h, w = self.output_dim
295
+ u = self.net.ctx_encode([""], which='text').repeat(n_samples, 1, 1)
296
+ c = self.net.ctx_encode([text], which='text').repeat(n_samples, 1, 1)
297
+ shape = [n_samples, self.image_latent_dim, h//8, w//8]
298
+ np.random.seed(seed)
299
+ torch.manual_seed(seed + 100)
300
+ x, _ = sampler.sample(
301
+ steps=self.ddim_steps,
302
+ x_info={'type':'image'},
303
+ c_info={'type':'text', 'conditioning':c, 'unconditional_conditioning':u,
304
+ 'unconditional_guidance_scale':scale},
305
+ shape=shape,
306
+ verbose=False,
307
+ eta=self.ddim_eta)
308
+ im = self.net.vae_decode(x, which='image')
309
+ im = [tvtrans.ToPILImage()(i) for i in im]
310
+ return im
311
+
312
+ def inference_i2i(self, im, fid_lvl, fcs_lvl, clr_adj, seed):
313
+ n_samples = self.n_sample_image
314
+ scale = self.scale_imgto
315
+ sampler = self.sampler
316
+ h, w = self.output_dim
317
+ device = self.net.device
318
+
319
+ BICUBIC = PIL.Image.Resampling.BICUBIC
320
+ im = im.resize([w, h], resample=BICUBIC)
321
+
322
+ if fid_lvl == 1:
323
+ return [im]*n_samples
324
+
325
+ cx = tvtrans.ToTensor()(im)[None].to(device).to(self.dtype)
326
+
327
+ c = self.net.ctx_encode(cx, which='image')
328
+ if self.disentanglement_noglobal:
329
+ c_glb = c[:, 0:1]
330
+ c_loc = c[:, 1: ]
331
+ c_loc = self.adjust_rank_f(c_loc, fcs_lvl)
332
+ c = torch.cat([c_glb, c_loc], dim=1).repeat(n_samples, 1, 1)
333
+ else:
334
+ c = self.adjust_rank_f(c, fcs_lvl).repeat(n_samples, 1, 1)
335
+ u = torch.zeros_like(c)
336
+
337
+ shape = [n_samples, self.image_latent_dim, h//8, w//8]
338
+ np.random.seed(seed)
339
+ torch.manual_seed(seed + 100)
340
+ if fid_lvl!=0:
341
+ x0 = self.net.vae_encode(cx, which='image').repeat(n_samples, 1, 1, 1)
342
+ step = int(self.ddim_steps * (1-fid_lvl))
343
+ x, _ = sampler.sample(
344
+ steps=self.ddim_steps,
345
+ x_info={'type':'image', 'x0':x0, 'x0_forward_timesteps':step},
346
+ c_info={'type':'image', 'conditioning':c, 'unconditional_conditioning':u,
347
+ 'unconditional_guidance_scale':scale},
348
+ shape=shape,
349
+ verbose=False,
350
+ eta=self.ddim_eta)
351
+ else:
352
+ x, _ = sampler.sample(
353
+ steps=self.ddim_steps,
354
+ x_info={'type':'image',},
355
+ c_info={'type':'image', 'conditioning':c, 'unconditional_conditioning':u,
356
+ 'unconditional_guidance_scale':scale},
357
+ shape=shape,
358
+ verbose=False,
359
+ eta=self.ddim_eta)
360
+
361
+ imout = self.net.vae_decode(x, which='image')
362
+
363
+ if clr_adj == 'Simple':
364
+ cx_mean = cx.view(3, -1).mean(-1)[:, None, None]
365
+ cx_std = cx.view(3, -1).std(-1)[:, None, None]
366
+ imout_mean = [imouti.view(3, -1).mean(-1)[:, None, None] for imouti in imout]
367
+ imout_std = [imouti.view(3, -1).std(-1)[:, None, None] for imouti in imout]
368
+ imout = [(ii-mi)/si*cx_std+cx_mean for ii, mi, si in zip(imout, imout_mean, imout_std)]
369
+ imout = [torch.clamp(ii, 0, 1) for ii in imout]
370
+
371
+ imout = [tvtrans.ToPILImage()(i) for i in imout]
372
+ return imout
373
+
374
+ def inference_i2t(self, im, seed):
375
+ n_samples = self.n_sample_text
376
+ scale = self.scale_imgto
377
+ sampler = self.sampler
378
+ h, w = self.output_dim
379
+ device = self.net.device
380
+
381
+ BICUBIC = PIL.Image.Resampling.BICUBIC
382
+ im = im.resize([w, h], resample=BICUBIC)
383
+
384
+ cx = tvtrans.ToTensor()(im)[None].to(device)
385
+ c = self.net.ctx_encode(cx, which='image').repeat(n_samples, 1, 1)
386
+ u = self.net.ctx_encode(torch.zeros_like(cx), which='image').repeat(n_samples, 1, 1)
387
+
388
+ shape = [n_samples, self.text_latent_dim]
389
+ np.random.seed(seed)
390
+ torch.manual_seed(seed + 100)
391
+ x, _ = sampler.sample(
392
+ steps=self.ddim_steps,
393
+ x_info={'type':'text',},
394
+ c_info={'type':'image', 'conditioning':c, 'unconditional_conditioning':u,
395
+ 'unconditional_guidance_scale':scale},
396
+ shape=shape,
397
+ verbose=False,
398
+ eta=self.ddim_eta)
399
+ tx = self.net.vae_decode(x, which='text', temperature=self.text_temperature)
400
+ tx = [remove_duplicate_word(txi) for txi in tx]
401
+ tx_combined = '\n'.join(tx)
402
+ return tx_combined
403
+
404
+ def inference_t2t(self, text, seed):
405
+ n_samples = self.n_sample_text
406
+ scale = self.scale_textto
407
+ sampler = self.sampler
408
+ u = self.net.ctx_encode([""], which='text').repeat(n_samples, 1, 1)
409
+ c = self.net.ctx_encode([text], which='text').repeat(n_samples, 1, 1)
410
+ shape = [n_samples, self.text_latent_dim]
411
+ np.random.seed(seed)
412
+ torch.manual_seed(seed + 100)
413
+ x, _ = sampler.sample(
414
+ steps=self.ddim_steps,
415
+ x_info={'type':'text',},
416
+ c_info={'type':'text', 'conditioning':c, 'unconditional_conditioning':u,
417
+ 'unconditional_guidance_scale':scale},
418
+ shape=shape,
419
+ verbose=False,
420
+ eta=self.ddim_eta)
421
+ tx = self.net.vae_decode(x, which='text', temperature=self.text_temperature)
422
+ tx = [remove_duplicate_word(txi) for txi in tx]
423
+ tx_combined = '\n'.join(tx)
424
+ return tx_combined
425
+
426
+ def inference_dcg(self, imctx, fcs_lvl, textctx, textstrength, seed):
427
+ n_samples = self.n_sample_image
428
+ sampler = self.sampler
429
+ h, w = self.output_dim
430
+ device = self.net.device
431
+
432
+ c_info_list = []
433
+
434
+ if (textctx is not None) and (textctx != "") and (textstrength != 0):
435
+ ut = self.net.ctx_encode([""], which='text').repeat(n_samples, 1, 1)
436
+ ct = self.net.ctx_encode([textctx], which='text').repeat(n_samples, 1, 1)
437
+ scale = self.scale_imgto*(1-textstrength) + self.scale_textto*textstrength
438
+
439
+ c_info_list.append({
440
+ 'type':'text',
441
+ 'conditioning':ct,
442
+ 'unconditional_conditioning':ut,
443
+ 'unconditional_guidance_scale':scale,
444
+ 'ratio': textstrength, })
445
+ else:
446
+ scale = self.scale_imgto
447
+ textstrength = 0
448
+
449
+ BICUBIC = PIL.Image.Resampling.BICUBIC
450
+ cx = imctx.resize([w, h], resample=BICUBIC)
451
+ cx = tvtrans.ToTensor()(cx)[None].to(device).to(self.dtype)
452
+ ci = self.net.ctx_encode(cx, which='image')
453
+
454
+ if self.disentanglement_noglobal:
455
+ ci_glb = ci[:, 0:1]
456
+ ci_loc = ci[:, 1: ]
457
+ ci_loc = self.adjust_rank_f(ci_loc, fcs_lvl)
458
+ ci = torch.cat([ci_glb, ci_loc], dim=1).repeat(n_samples, 1, 1)
459
+ else:
460
+ ci = self.adjust_rank_f(ci, fcs_lvl).repeat(n_samples, 1, 1)
461
+
462
+ c_info_list.append({
463
+ 'type':'image',
464
+ 'conditioning':ci,
465
+ 'unconditional_conditioning':torch.zeros_like(ci),
466
+ 'unconditional_guidance_scale':scale,
467
+ 'ratio': (1-textstrength), })
468
+
469
+ shape = [n_samples, self.image_latent_dim, h//8, w//8]
470
+ np.random.seed(seed)
471
+ torch.manual_seed(seed + 100)
472
+ x, _ = sampler.sample_multicontext(
473
+ steps=self.ddim_steps,
474
+ x_info={'type':'image',},
475
+ c_info_list=c_info_list,
476
+ shape=shape,
477
+ verbose=False,
478
+ eta=self.ddim_eta)
479
+
480
+ imout = self.net.vae_decode(x, which='image')
481
+ imout = [tvtrans.ToPILImage()(i) for i in imout]
482
+ return imout
483
+
484
+ def inference_tcg(self, *args):
485
+ args_imag = list(args[0:10]) + [None, None, None, None, None]*2
486
+ args_rest = args[10:]
487
+ imin, imout = self.inference_mcg(*args_imag, *args_rest)
488
+ return imin, imout
489
+
490
+ def inference_mcg(self, *args):
491
+ imctx = [args[0:5], args[5:10], args[10:15], args[15:20]]
492
+ textctx, textstrength, seed = args[20:]
493
+
494
+ n_samples = self.n_sample_image
495
+ sampler = self.sampler
496
+ h, w = self.output_dim
497
+ device = self.net.device
498
+
499
+ c_info_list = []
500
+
501
+ if (textctx is not None) and (textctx != "") and (textstrength != 0):
502
+ ut = self.net.ctx_encode([""], which='text').repeat(n_samples, 1, 1)
503
+ ct = self.net.ctx_encode([textctx], which='text').repeat(n_samples, 1, 1)
504
+ scale = self.scale_imgto*(1-textstrength) + self.scale_textto*textstrength
505
+
506
+ c_info_list.append({
507
+ 'type':'text',
508
+ 'conditioning':ct,
509
+ 'unconditional_conditioning':ut,
510
+ 'unconditional_guidance_scale':scale,
511
+ 'ratio': textstrength, })
512
+ else:
513
+ scale = self.scale_imgto
514
+ textstrength = 0
515
+
516
+ input_save = []
517
+ imc = []
518
+ for im, imm, strength, fcs_lvl, use_mask in imctx:
519
+ if (im is None) and (imm is None):
520
+ continue
521
+ BILINEAR = PIL.Image.Resampling.BILINEAR
522
+ BICUBIC = PIL.Image.Resampling.BICUBIC
523
+ if use_mask:
524
+ cx = imm['image'].resize([w, h], resample=BICUBIC)
525
+ cx = tvtrans.ToTensor()(cx)[None].to(self.dtype).to(device)
526
+ m = imm['mask'].resize([w, h], resample=BILINEAR)
527
+ m = tvtrans.ToTensor()(m)[None, 0:1].to(self.dtype).to(device)
528
+ m = (1-m)
529
+ cx_show = cx*m
530
+ ci = self.net.ctx_encode(cx, which='image', masks=m)
531
+ else:
532
+ cx = im.resize([w, h], resample=BICUBIC)
533
+ cx = tvtrans.ToTensor()(cx)[None].to(self.dtype).to(device)
534
+ ci = self.net.ctx_encode(cx, which='image')
535
+ cx_show = cx
536
+
537
+ input_save.append(tvtrans.ToPILImage()(cx_show[0]))
538
+
539
+ if self.disentanglement_noglobal:
540
+ ci_glb = ci[:, 0:1]
541
+ ci_loc = ci[:, 1: ]
542
+ ci_loc = self.adjust_rank_f(ci_loc, fcs_lvl)
543
+ ci = torch.cat([ci_glb, ci_loc], dim=1).repeat(n_samples, 1, 1)
544
+ else:
545
+ ci = self.adjust_rank_f(ci, fcs_lvl).repeat(n_samples, 1, 1)
546
+ imc.append(ci * strength)
547
+
548
+ cis = torch.cat(imc, dim=1)
549
+ c_info_list.append({
550
+ 'type':'image',
551
+ 'conditioning':cis,
552
+ 'unconditional_conditioning':torch.zeros_like(cis),
553
+ 'unconditional_guidance_scale':scale,
554
+ 'ratio': (1-textstrength), })
555
+
556
+ shape = [n_samples, self.image_latent_dim, h//8, w//8]
557
+ np.random.seed(seed)
558
+ torch.manual_seed(seed + 100)
559
+ x, _ = sampler.sample_multicontext(
560
+ steps=self.ddim_steps,
561
+ x_info={'type':'image',},
562
+ c_info_list=c_info_list,
563
+ shape=shape,
564
+ verbose=False,
565
+ eta=self.ddim_eta)
566
+
567
+ imout = self.net.vae_decode(x, which='image')
568
+ imout = [tvtrans.ToPILImage()(i) for i in imout]
569
+ return input_save, imout
570
+
571
+ # vd_inference = vd_dummy()
572
+ vd_inference = vd_inference(which='v1.0', fp16=True)
573
+
574
+ #################
575
+ # sub interface #
576
+ #################
577
+
578
+ def t2i_interface(with_example=False):
579
+ gr.HTML('<p id=myinst>&nbsp Description: ' + get_instruction("Text-to-Image") + '</p>')
580
+ with gr.Row():
581
+ with gr.Column():
582
+ text = gr.Textbox(lines=4, placeholder="Input prompt...", label='Text Input')
583
+ seed = gr.Number(20, label="Seed", precision=0)
584
+ button = gr.Button("Run")
585
+ with gr.Column():
586
+ img_output = gr.Gallery(label="Image Result", elem_id='customized_imbox').style(grid=n_sample_image)
587
+
588
+ button.click(
589
+ vd_inference.inference_t2i,
590
+ inputs=[text, seed],
591
+ outputs=[img_output])
592
+
593
+ if with_example:
594
+ gr.Examples(
595
+ label='Examples',
596
+ examples=get_example('Text-to-Image'),
597
+ fn=vd_inference.inference_t2i,
598
+ inputs=[text, seed],
599
+ outputs=[img_output],
600
+ cache_examples=cache_examples),
601
+
602
+ def i2i_interface(with_example=False):
603
+ gr.HTML('<p id=myinst>&nbsp Description: ' + get_instruction("Image-Variation") + '</p>')
604
+ with gr.Row():
605
+ with gr.Column():
606
+ img_input = gr.Image(label='Image Input', type='pil', elem_id='customized_imbox')
607
+ sim_flag = gr.Checkbox(label='Show Detail Controls')
608
+ with gr.Row():
609
+ fid_lvl = gr.Slider(label="Fidelity (Dislike -- Same)", minimum=0, maximum=1, value=0, step=0.02, visible=False)
610
+ fcs_lvl = gr.Slider(label="Focus (Semantic -- Style)", minimum=0, maximum=1, value=0.5, step=0.02, visible=False)
611
+ clr_adj = gr.Radio(label="Color Adjustment", choices=["None", "Simple"], value='Simple', visible=False)
612
+ explain = gr.HTML('<p id=myinst>&nbsp Fidelity: How closely the output image should resemble the reference image (0-dislike (default), 1-same).</p>'+
613
+ '<p id=myinst>&nbsp Focus: What the output image should focus on (0-semantic, 0.5-balanced (default), 1-style).</p>',
614
+ visible=False)
615
+ seed = gr.Number(20, label="Seed", precision=0)
616
+ button = gr.Button("Run")
617
+ with gr.Column():
618
+ img_output = gr.Gallery(label="Image Result", elem_id='customized_imbox').style(grid=n_sample_image)
619
+
620
+ sim_flag.change(
621
+ fn=lambda x: {
622
+ explain : gr.update(visible=x),
623
+ fid_lvl : gr.update(visible=x),
624
+ fcs_lvl : gr.update(visible=x),
625
+ clr_adj : gr.update(visible=x), },
626
+ inputs=sim_flag,
627
+ outputs=[explain, fid_lvl, fcs_lvl, clr_adj, seed],)
628
+
629
+ button.click(
630
+ vd_inference.inference_i2i,
631
+ inputs=[img_input, fid_lvl, fcs_lvl, clr_adj, seed],
632
+ outputs=[img_output])
633
+
634
+ if with_example:
635
+ gr.Examples(
636
+ label='Examples',
637
+ examples=get_example('Image-Variation'),
638
+ fn=vd_inference.inference_i2i,
639
+ inputs=[img_input, fid_lvl, fcs_lvl, clr_adj, seed],
640
+ outputs=[img_output],
641
+ cache_examples=cache_examples),
642
+
643
+ def i2t_interface(with_example=False):
644
+ gr.HTML('<p id=myinst>&nbsp Description: ' + get_instruction("Image-to-Text") + '</p>')
645
+ with gr.Row():
646
+ with gr.Column():
647
+ img_input = gr.Image(label='Image Input', type='pil', elem_id='customized_imbox')
648
+ seed = gr.Number(20, label="Seed", precision=0)
649
+ button = gr.Button("Run")
650
+ with gr.Column():
651
+ txt_output = gr.Textbox(lines=4, label='Text Result')
652
+
653
+ button.click(
654
+ vd_inference.inference_i2t,
655
+ inputs=[img_input, seed],
656
+ outputs=[txt_output])
657
+
658
+ if with_example:
659
+ gr.Examples(
660
+ label='Examples',
661
+ examples=get_example('Image-to-Text'),
662
+ fn=vd_inference.inference_i2t,
663
+ inputs=[img_input, seed],
664
+ outputs=[txt_output],
665
+ cache_examples=cache_examples),
666
+
667
+ def t2t_interface(with_example=False):
668
+ gr.HTML('<p id=myinst>&nbsp Description: ' + get_instruction("Text-Variation") + '</p>')
669
+ with gr.Row():
670
+ with gr.Column():
671
+ text = gr.Textbox(lines=4, placeholder="Input prompt...", label='Text Input')
672
+ seed = gr.Number(20, label="Seed", precision=0)
673
+ button = gr.Button("Run")
674
+ with gr.Column():
675
+ txt_output = gr.Textbox(lines=4, label='Text Result')
676
+
677
+ button.click(
678
+ vd_inference.inference_t2t,
679
+ inputs=[text, seed],
680
+ outputs=[txt_output])
681
+
682
+ if with_example:
683
+ gr.Examples(
684
+ label='Examples',
685
+ examples=get_example('Text-Variation'),
686
+ fn=vd_inference.inference_t2t,
687
+ inputs=[text, seed],
688
+ outputs=[txt_output],
689
+ cache_examples=cache_examples, )
690
+
691
+ class image_mimage_swap(object):
692
+ def __init__(self, block0, block1):
693
+ self.block0 = block0
694
+ self.block1 = block1
695
+ self.which_update = 'both'
696
+
697
+ def __call__(self, x0, x1, flag):
698
+ if self.which_update == 'both':
699
+ return self.update_both(x0, x1, flag)
700
+ elif self.which_update == 'visible':
701
+ return self.update_visible(x0, x1, flag)
702
+ elif self.which_update == 'visible_oneoff':
703
+ return self.update_visible_oneoff(x0, x1, flag)
704
+ else:
705
+ assert False
706
+
707
+ def update_both(self, x0, x1, flag):
708
+ if flag:
709
+ ug0 = gr.update(visible=False)
710
+ if x0 is None:
711
+ ug1 = gr.update(value=None, visible=True)
712
+ else:
713
+ if (x1 is not None) and ('mask' in x1):
714
+ value1 = {'image':x0, 'mask':x1['mask']}
715
+ else:
716
+ value1 = {'image':x0, 'mask':None}
717
+ ug1 = gr.update(value=value1, visible=True)
718
+ else:
719
+ if (x1 is not None) and ('image' in x1):
720
+ value0 = x1['image']
721
+ else:
722
+ value0 = None
723
+ ug0 = gr.update(value=value0, visible=True)
724
+ ug1 = gr.update(visible=False)
725
+ return {
726
+ self.block0 : ug0,
727
+ self.block1 : ug1,}
728
+
729
+ def update_visible(self, x0, x1, flag):
730
+ return {
731
+ self.block0 : gr.update(visible=not flag),
732
+ self.block1 : gr.update(visible=flag), }
733
+
734
+ def update_visible_oneoff(self, x0, x1, flag):
735
+ self.which_update = 'both'
736
+ return {
737
+ self.block0 : gr.update(visible=not flag),
738
+ self.block1 : gr.update(visible=flag), }
739
+
740
+ class example_visible_only_hack(object):
741
+ def __init__(self, checkbox_list, functor_list):
742
+ self.checkbox_list = checkbox_list
743
+ self.functor_list = functor_list
744
+
745
+ def __call__(self, *args):
746
+ for bi, fi, vi in zip(self.checkbox_list, self.functor_list, args):
747
+ if bi.value != vi:
748
+ fi.which_update = 'visible_oneoff'
749
+
750
+ def dcg_interface(with_example=False):
751
+ gr.HTML('<p id=myinst>&nbsp Description: ' + get_instruction("Dual-Context") + '</p>')
752
+ with gr.Row():
753
+ input_session = []
754
+ with gr.Column():
755
+ img = gr.Image(label='Image Input', type='pil', elem_id='customized_imbox')
756
+ fcs = gr.Slider(label="Focus (Semantic -- Style)", minimum=0, maximum=1, value=0.5, step=0.02)
757
+ gr.HTML('<p id=myinst>&nbsp Focus: Focus on what aspect of the image? (0-semantic, 0.5-balanced (default), 1-style).</p>')
758
+
759
+ text = gr.Textbox(lines=2, placeholder="Input prompt...", label='Text Input')
760
+ tstrength = gr.Slider(label="Text Domination (NoEffect -- TextOnly)", minimum=0, maximum=1, value=0, step=0.02)
761
+
762
+ seed = gr.Number(20, label="Seed", precision=0)
763
+ button = gr.Button("Run")
764
+
765
+ with gr.Column():
766
+ output_gallary = gr.Gallery(label="Image Result", elem_id='customized_imbox').style(grid=n_sample_image)
767
+
768
+ input_list = []
769
+ for i in input_session:
770
+ input_list += i
771
+ button.click(
772
+ vd_inference.inference_dcg,
773
+ inputs=[img, fcs, text, tstrength, seed],
774
+ outputs=[output_gallary])
775
+
776
+ if with_example:
777
+ gr.Examples(
778
+ label='Examples',
779
+ examples=get_example('Dual-Context'),
780
+ fn=vd_inference.inference_dcg,
781
+ inputs=[img, fcs, text, tstrength, seed],
782
+ outputs=[output_gallary],
783
+ cache_examples=cache_examples)
784
+
785
+ def tcg_interface(with_example=False):
786
+ gr.HTML('<p id=myinst>&nbsp Description: ' + get_instruction("Triple-Context") + '</p>')
787
+ with gr.Row():
788
+ input_session = []
789
+ with gr.Column(min_width=940):
790
+ with gr.Row():
791
+ with gr.Column():
792
+ img0 = gr.Image(label='Image Input', type='pil', elem_id='customized_imbox')
793
+ img0.as_example = types.MethodType(customized_as_example, img0)
794
+ imgm0 = gr.Image(label='Image Input with Mask', type='pil', elem_id='customized_imbox', tool='sketch', source="upload", visible=False)
795
+ imgm0.postprocess = types.MethodType(customized_postprocess, imgm0)
796
+ imgm0.as_example = types.MethodType(customized_as_example, imgm0)
797
+ istrength0 = gr.Slider(label="Weight", minimum=0, maximum=1, value=1, step=0.02)
798
+ fcs0 = gr.Slider(label="Focus (Semantic -- Style)", minimum=0, maximum=1, value=0.5, step=0.02)
799
+ msk0 = gr.Checkbox(label='Use mask?')
800
+ swapf0 = image_mimage_swap(img0, imgm0)
801
+
802
+ msk0.change(
803
+ fn=swapf0,
804
+ inputs=[img0, imgm0, msk0],
805
+ outputs=[img0, imgm0],)
806
+ input_session.append([img0, imgm0, istrength0, fcs0, msk0])
807
+
808
+ with gr.Column():
809
+ img1 = gr.Image(label='Image Input', type='pil', elem_id='customized_imbox')
810
+ img1.as_example = types.MethodType(customized_as_example, img1)
811
+ imgm1 = gr.Image(label='Image Input with Mask', type='pil', elem_id='customized_imbox', tool='sketch', source="upload", visible=False)
812
+ imgm1.postprocess = types.MethodType(customized_postprocess, imgm1)
813
+ imgm1.as_example = types.MethodType(customized_as_example, imgm1)
814
+ istrength1 = gr.Slider(label="Weight", minimum=0, maximum=1, value=1, step=0.02)
815
+ fcs1 = gr.Slider(label="Focus (Semantic -- Style)", minimum=0, maximum=1, value=0.5, step=0.02)
816
+ msk1 = gr.Checkbox(label='Use mask?')
817
+ swapf1 = image_mimage_swap(img1, imgm1)
818
+
819
+ msk1.change(
820
+ fn=swapf1,
821
+ inputs=[img1, imgm1, msk1],
822
+ outputs=[img1, imgm1],)
823
+ input_session.append([img1, imgm1, istrength1, fcs1, msk1])
824
+
825
+ gr.HTML('<p id=myinst>&nbsp Weight: The strength of the reference image. This weight is subject to <u>Text Domination</u>.</p>'+
826
+ '<p id=myinst>&nbsp Focus: Focus on what aspect of the image? (0-semantic, 0.5-balanced (default), 1-style).</p>'+
827
+ '<p id=myinst>&nbsp Mask: Remove regions on reference image so they will not influence the output.</p>',)
828
+
829
+ text = gr.Textbox(lines=2, placeholder="Input prompt...", label='Text Input')
830
+ tstrength = gr.Slider(label="Text Domination (NoEffect -- TextOnly)", minimum=0, maximum=1, value=0, step=0.02)
831
+
832
+ seed = gr.Number(20, label="Seed", precision=0)
833
+ button = gr.Button("Run")
834
+
835
+ with gr.Column(min_width=470):
836
+ input_gallary = gr.Gallery(label="Input Display", elem_id="customized_imbox").style(grid=2)
837
+ output_gallary = gr.Gallery(label="Image Result", elem_id="customized_imbox").style(grid=n_sample_image)
838
+
839
+ input_list = []
840
+ for i in input_session:
841
+ input_list += i
842
+ input_list += [text, tstrength, seed]
843
+ button.click(
844
+ vd_inference.inference_tcg,
845
+ inputs=input_list,
846
+ outputs=[input_gallary, output_gallary])
847
+
848
+ if with_example:
849
+ create_myexamples(
850
+ label='Examples',
851
+ examples=get_example('Triple-Context'),
852
+ fn=vd_inference.inference_tcg,
853
+ inputs=input_list,
854
+ outputs=[input_gallary, output_gallary, ],
855
+ cache_examples=cache_examples, )
856
+
857
+ gr.HTML('<br><p id=myinst>&nbsp How to add mask: Please see the following instructions.</p><br>'+
858
+ '<img src="file/assets/demo/misc/mask_inst1.gif" style="float:left;max-width:450px;">'+
859
+ '<img src="file/assets/demo/misc/mask_inst2.gif" style="float:left;max-width:450px;">'+
860
+ '<img src="file/assets/demo/misc/mask_inst3.gif" style="float:left;max-width:450px;">',)
861
+
862
+ def mcg_interface(with_example=False):
863
+ num_img_input = 4
864
+ gr.HTML('<p id=myinst>&nbsp Description: ' + get_instruction("Multi-Context") + '</p>')
865
+ with gr.Row():
866
+ input_session = []
867
+ with gr.Column():
868
+ for idx in range(num_img_input):
869
+ with gr.Tab('Image{}'.format(idx+1)):
870
+ img = gr.Image(label='Image Input', type='pil', elem_id='customized_imbox')
871
+ img.as_example = types.MethodType(customized_as_example, img)
872
+ imgm = gr.Image(label='Image Input with Mask', type='pil', elem_id='customized_imbox', tool='sketch', source="upload", visible=False)
873
+ imgm.postprocess = types.MethodType(customized_postprocess, imgm)
874
+ imgm.as_example = types.MethodType(customized_as_example, imgm)
875
+
876
+ with gr.Row():
877
+ istrength = gr.Slider(label="Weight", minimum=0, maximum=1, value=1, step=0.02)
878
+ fcs = gr.Slider(label="Focus (Semantic -- Style)", minimum=0, maximum=1, value=0.5, step=0.02)
879
+ msk = gr.Checkbox(label='Use mask?')
880
+ gr.HTML('<p id=myinst>&nbsp Weight: The strength of the reference image. This weight is subject to <u>Text Domination</u>.</p>'+
881
+ '<p id=myinst>&nbsp Focus: Focus on what aspect of the image? (0-semantic, 0.5-balanced (default), 1-style).</p>'+
882
+ '<p id=myinst>&nbsp Mask: Remove regions on reference image so they will not influence the output.</p>',)
883
+
884
+ msk.change(
885
+ fn=image_mimage_swap(img, imgm),
886
+ inputs=[img, imgm, msk],
887
+ outputs=[img, imgm],)
888
+ input_session.append([img, imgm, istrength, fcs, msk])
889
+
890
+ text = gr.Textbox(lines=2, placeholder="Input prompt...", label='Text Input')
891
+ tstrength = gr.Slider(label="Text Domination (NoEffect -- TextOnly)", minimum=0, maximum=1, value=0, step=0.02)
892
+
893
+ seed = gr.Number(20, label="Seed", precision=0)
894
+ button = gr.Button("Run")
895
+
896
+
897
+ with gr.Column():
898
+ input_gallary = gr.Gallery(label="Input Display", elem_id='customized_imbox').style(grid=4)
899
+ output_gallary = gr.Gallery(label="Image Result", elem_id='customized_imbox').style(grid=n_sample_image)
900
+
901
+ input_list = []
902
+ for i in input_session:
903
+ input_list += i
904
+ input_list += [text, tstrength, seed]
905
+ button.click(
906
+ vd_inference.inference_mcg,
907
+ inputs=input_list,
908
+ outputs=[input_gallary, output_gallary], )
909
+
910
+ if with_example:
911
+ create_myexamples(
912
+ label='Examples',
913
+ examples=get_example('Multi-Context'),
914
+ fn=vd_inference.inference_mcg,
915
+ inputs=input_list,
916
+ outputs=[input_gallary, output_gallary],
917
+ cache_examples=cache_examples, )
918
+
919
+ gr.HTML('<br><p id=myinst>&nbsp How to add mask: Please see the following instructions.</p><br>'+
920
+ '<img src="file/assets/demo/misc/mask_inst1.gif" style="float:left;max-width:450px;">'+
921
+ '<img src="file/assets/demo/misc/mask_inst2.gif" style="float:left;max-width:450px;">'+
922
+ '<img src="file/assets/demo/misc/mask_inst3.gif" style="float:left;max-width:450px;">',)
923
+
924
+ ###########
925
+ # Example #
926
+ ###########
927
+
928
+ def get_example(mode):
929
+ if mode == 'Text-to-Image':
930
+ case = [
931
+ ['a dream of a village in china, by Caspar David Friedrich, matte painting trending on artstation HQ', 23],
932
+ ['a beautiful landscape with mountains and rivers', 20],
933
+ ]
934
+ elif mode == "Image-Variation":
935
+ case = [
936
+ ['assets/demo/reg_example/ghibli.jpg', 0, 0.5, 'None', 20],
937
+ ['assets/demo/reg_example/ghibli.jpg', 0.5, 0.5, 'None', 20],
938
+ ['assets/demo/reg_example/matisse.jpg', 0, 0, 'None', 20],
939
+ ['assets/demo/reg_example/matisse.jpg', 0, 1, 'Simple', 20],
940
+ ['assets/demo/reg_example/vermeer.jpg', 0.2, 0.3, 'None', 30],
941
+ ]
942
+ elif mode == "Image-to-Text":
943
+ case = [
944
+ ['assets/demo/reg_example/house_by_lake.jpg', 20],
945
+ ]
946
+ elif mode == "Text-Variation":
947
+ case = [
948
+ ['heavy arms gundam penguin mech', 20],
949
+ ]
950
+ elif mode == "Dual-Context":
951
+ case = [
952
+ ['assets/demo/reg_example/benz.jpg', 0.5, 'cyberpunk 2077', 0.7, 22],
953
+ ['assets/demo/reg_example/ghibli.jpg', 1, 'Red maple on a hill in golden Autumn.', 0.5, 21],
954
+ ]
955
+ elif mode == "Triple-Context":
956
+ case = [
957
+ [
958
+ 'assets/demo/reg_example/night_light.jpg', None, 1 , 0.5, False,
959
+ 'assets/demo/reg_example/paris.jpg' , None, 0.94, 0.5, False,
960
+ "snow on the street", 0.4, 28],
961
+ [
962
+ 'assets/demo/tcg_example/e1i0.jpg', None, 1 , 0.5, False,
963
+ 'assets/demo/tcg_example/e1i1.jpg', None, 0.94, 0.5, False,
964
+ "a painting of an elegant woman in front of the moon", 0.2, 217],
965
+ [
966
+ 'assets/demo/tcg_example/e2i0.jpg', None, 1, 0.5, False,
967
+ 'assets/demo/reg_example/paris.jpg', None, 1, 0.5, False,
968
+ "", 0, 29],
969
+ [
970
+ 'assets/demo/tcg_example/e0i0.jpg', None, 1 , 0.5, False,
971
+ 'assets/demo/tcg_example/e0i1.jpg', None, 0.9, 0.5, False,
972
+ "rose blooms on the tree", 0.2, 20],
973
+ [
974
+ 'assets/demo/reg_example/ghibli.jpg', None, 1 , 1 , False,
975
+ 'assets/demo/reg_example/space.jpg' , None, 0.84, 0.5, False,
976
+ "", 0, 20],
977
+ [
978
+ 'assets/demo/reg_example/train.jpg' , None, 0.8, 0.5, False,
979
+ 'assets/demo/reg_example/matisse.jpg', None, 1 , 1 , False,
980
+ "", 0, 20],
981
+ ]
982
+ elif mode == "Multi-Context":
983
+ case = [
984
+ [
985
+ 'assets/demo/mcg_example/e0i0.jpg', None, 1, 0.5, False,
986
+ 'assets/demo/mcg_example/e0i1.jpg', None, 1, 0.5, False,
987
+ 'assets/demo/mcg_example/e0i2.jpg', None, 0.86, 0.5, False,
988
+ None, None, 1, 0.5, False,
989
+ "", 0, 20],
990
+ ]
991
+ else:
992
+ raise ValueError
993
+ return case
994
+
995
+ #############
996
+ # Interface #
997
+ #############
998
+
999
+ css = """
1000
+ #customized_imbox {
1001
+ min-height: 450px;
1002
+ }
1003
+ #customized_imbox>div[data-testid="image"] {
1004
+ min-height: 450px;
1005
+ }
1006
+ #customized_imbox>div[data-testid="image"]>div {
1007
+ min-height: 450px;
1008
+ }
1009
+ #customized_imbox>div[data-testid="image"]>iframe {
1010
+ min-height: 450px;
1011
+ }
1012
+ #customized_imbox>div.unpadded_box {
1013
+ min-height: 450px;
1014
+ }
1015
+ #myinst {
1016
+ font-size: 0.8rem;
1017
+ margin: 0rem;
1018
+ color: #6B7280;
1019
+ }
1020
+ """
1021
+
1022
+ if True:
1023
+ with gr.Blocks(css=css) as demo:
1024
+ gr.HTML(
1025
+ """
1026
+ <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
1027
+ <h1 style="font-weight: 900; font-size: 3rem; margin: 0rem">
1028
+ Versatile Diffusion{}
1029
+ </h1>
1030
+ <h2 style="font-weight: 450; font-size: 1rem; margin-top: 0.5rem; margin-bottom: 0.5rem">
1031
+ We built <b>Versatile Diffusion (VD), the first unified multi-flow multimodal diffusion framework</b>, as a step towards <b>Universal Generative AI</b>.
1032
+ VD can natively support image-to-text, image-variation, text-to-image, and text-variation,
1033
+ and can be further extended to other applications such as
1034
+ semantic-style disentanglement, image-text dual-guided generation, latent image-to-text-to-image editing, and more.
1035
+ Future versions will support more modalities such as speech, music, video and 3D.
1036
+ </h2>
1037
+ <h3 style="font-weight: 450; font-size: 1rem; margin: 0rem">
1038
+ Xingqian Xu, Atlas Wang, Eric Zhang, Kai Wang,
1039
+ and <a href="https://www.humphreyshi.com/home">Humphrey Shi</a>
1040
+ [<a href="https://arxiv.org/abs/2211.08332" style="color:blue;">arXiv</a>]
1041
+ [<a href="https://github.com/SHI-Labs/Versatile-Diffusion" style="color:blue;">GitHub</a>]
1042
+ </h3>
1043
+ </div>
1044
+ """.format(' '+vd_inference.which))
1045
+ # .format('')) #
1046
+
1047
+ with gr.Tab('Text-to-Image'):
1048
+ t2i_interface(with_example=True)
1049
+ with gr.Tab('Image-Variation'):
1050
+ i2i_interface(with_example=True)
1051
+ with gr.Tab('Image-to-Text'):
1052
+ i2t_interface(with_example=True)
1053
+ with gr.Tab('Text-Variation'):
1054
+ t2t_interface(with_example=True)
1055
+ with gr.Tab('Dual-Context Image-Generation'):
1056
+ dcg_interface(with_example=True)
1057
+ with gr.Tab('Triple-Context Image-Blender'):
1058
+ tcg_interface(with_example=True)
1059
+ with gr.Tab('Multi-Context Image-Blender'):
1060
+ mcg_interface(with_example=True)
1061
+
1062
+ gr.HTML(
1063
+ """
1064
+ <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
1065
+ <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
1066
+ <b>Caution</b>:
1067
+ We would like to raise users' awareness of the potential issues and concerns with this demo.
1068
+ Like previous large foundation models, Versatile Diffusion could be problematic in some cases, partially due to the imperfect training data and pretrained network (VAEs / context encoders) with limited scope.
1069
+ In its future research phase, VD may do better on tasks such as text-to-image, image-to-text, etc., with the help of more powerful VAEs, more sophisticated network designs, and more cleaned data.
1070
+ So far, we keep all features available for research testing both to show the great potential of the VD framework and to collect important feedback to improve the model in the future.
1071
+ We welcome researchers and users to report issues with the HuggingFace community discussion feature or email the authors.
1072
+ </h3>
1073
+ <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
1074
+ <b>Biases and content acknowledgement</b>:
1075
+ Beware that VD may output content that reinforces or exacerbates societal biases, as well as realistic faces, pornography, and violence.
1076
+ VD was trained on the LAION-2B dataset, which scraped non-curated online images and text, and may contain unintended content even though illegal content was removed.
1077
+ VD in this demo is meant only for research purposes.
1078
+ </h3>
1079
+ </div>
1080
+ """)
1081
+
1082
+ demo.launch(share=True)
1083
+ # demo.launch(debug=True)
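The Gradio tabs above are thin wrappers around the vd_inference methods, so the same flows can be scripted without the UI. A minimal sketch using the instance and demo assets from this commit (the prompt string and seeds are illustrative):

    # text-to-image, image variation, and image captioning, driven directly
    images = vd_inference.inference_t2i("a beautiful landscape with mountains and rivers", 20)
    variations = vd_inference.inference_i2i(Image.open("assets/demo/reg_example/ghibli.jpg"),
                                            fid_lvl=0, fcs_lvl=0.5, clr_adj="Simple", seed=20)
    caption = vd_inference.inference_i2t(Image.open("assets/demo/reg_example/house_by_lake.jpg"), 20)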
assets/demo/mcg_example/e0i0.jpg ADDED
assets/demo/mcg_example/e0i1.jpg ADDED
assets/demo/mcg_example/e0i2.jpg ADDED
assets/demo/misc/mask_inst1.gif ADDED

Git LFS Details

  • SHA256: 90732a23a9a275649068654ae0c29418ea28ffb45eef6605da6d42e77390e808
  • Pointer size: 132 Bytes
  • Size of remote file: 5.23 MB
assets/demo/misc/mask_inst2.gif ADDED

Git LFS Details

  • SHA256: 183544affa3f5c76cf347e25d991a87e0eeb426b042f70ea33ef9acc6217d53f
  • Pointer size: 132 Bytes
  • Size of remote file: 5.82 MB
assets/demo/misc/mask_inst3.gif ADDED

Git LFS Details

  • SHA256: 6136887307c45b86ce451eff1102c7e996d46a107795439b3f35e4391d348b30
  • Pointer size: 132 Bytes
  • Size of remote file: 5.53 MB
assets/demo/misc/noimage.jpg ADDED
assets/demo/reg_example/benz.jpg ADDED
assets/demo/reg_example/boy_and_girl.jpg ADDED
assets/demo/reg_example/church.jpg ADDED
assets/demo/reg_example/firework.jpg ADDED
assets/demo/reg_example/ghibli.jpg ADDED
assets/demo/reg_example/horse.jpg ADDED
assets/demo/reg_example/house_by_lake.jpg ADDED
assets/demo/reg_example/matisse.jpg ADDED
assets/demo/reg_example/night_light.jpg ADDED
assets/demo/reg_example/noimage.jpg ADDED
assets/demo/reg_example/paris.jpg ADDED
assets/demo/reg_example/penguin.jpg ADDED
assets/demo/reg_example/san_diego.jpg ADDED
assets/demo/reg_example/scream.jpg ADDED
assets/demo/reg_example/space.jpg ADDED
assets/demo/reg_example/tiger.jpg ADDED
assets/demo/reg_example/train.jpg ADDED
assets/demo/reg_example/vermeer.jpg ADDED
assets/demo/tcg_example/e0i0.jpg ADDED
assets/demo/tcg_example/e0i1.jpg ADDED
assets/demo/tcg_example/e1i0.jpg ADDED
assets/demo/tcg_example/e1i1.jpg ADDED
assets/demo/tcg_example/e2i0.jpg ADDED
assets/figures/share_instruction.png ADDED
configs/model/autokl.yaml ADDED
@@ -0,0 +1,23 @@
+ autokl:
+   symbol: autokl
+   find_unused_parameters: false
+
+ autokl_v1:
+   super_cfg: autokl
+   type: autoencoderkl
+   args:
+     embed_dim: 4
+     ddconfig:
+       double_z: true
+       z_channels: 4
+       resolution: 256
+       in_channels: 3
+       out_ch: 3
+       ch: 128
+       ch_mult: [1, 2, 4, 4]
+       num_res_blocks: 2
+       attn_resolutions: []
+       dropout: 0.0
+     lossconfig: null
+   # pth: pretrained/kl-f8.pth
+   hfm: ['shi-labs/versatile-diffusion-model', 'pretrained_pth/kl-f8.pth']
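autokl_v1 inherits the base autokl fields through super_cfg, a pattern every config file in this commit uses. The real resolution lives in lib/cfg_helper (not shown in this 50-file view); an illustrative merge with the same effect, assuming the entries are loaded as plain dicts, might be:

    from copy import deepcopy

    def resolve(cfg_bank, name):
        # illustrative only; lib/cfg_helper.model_cfg_bank is the authoritative loader
        entry = deepcopy(cfg_bank[name])
        parent = entry.pop('super_cfg', None)
        if parent is None:
            return entry
        merged = resolve(cfg_bank, parent)
        merged.update(entry)      # child keys override inherited keys (shallow merge)
        return merged

    # resolve(bank, 'autokl_v1') would then carry symbol and find_unused_parameters from 'autokl'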
configs/model/clip.yaml ADDED
@@ -0,0 +1,13 @@
+ clip:
+   symbol: clip
+   args: {}
+
+ clip_text_context_encoder:
+   super_cfg: clip
+   type: clip_text_context_encoder
+   args: {}
+
+ clip_image_context_encoder:
+   super_cfg: clip
+   type: clip_image_context_encoder
+   args: {}
configs/model/openai_unet.yaml ADDED
@@ -0,0 +1,96 @@
+ #########
+ # v1 2d #
+ #########
+
+ openai_unet_2d_v1:
+   type: openai_unet_2d_next
+   args:
+     in_channels: 4
+     out_channels: 4
+     model_channels: 320
+     attention_resolutions: [ 4, 2, 1 ]
+     num_res_blocks: [ 2, 2, 2, 2 ]
+     channel_mult: [ 1, 2, 4, 4 ]
+     num_heads: 8
+     context_dim: 768
+     use_checkpoint: True
+     parts: [global, data, context]
+
+ openai_unet_2d_v1_g:
+   super_cfg: openai_unet_2d_v1
+   args:
+     parts: [global]
+
+ openai_unet_2d_v1_d:
+   super_cfg: openai_unet_2d_v1
+   args:
+     parts: [data]
+
+ openai_unet_2d_v1_c:
+   super_cfg: openai_unet_2d_v1
+   args:
+     parts: [context]
+
+ openai_unet_2d_v1_gd:
+   super_cfg: openai_unet_2d_v1
+   args:
+     parts: [global, data]
+
+ openai_unet_2d_v1_gc:
+   super_cfg: openai_unet_2d_v1
+   args:
+     parts: [global, context]
+
+ openai_unet_2d_v1_dc:
+   super_cfg: openai_unet_2d_v1
+   args:
+     parts: [data, context]
+
+ #########
+ # v1 0d #
+ #########
+
+ openai_unet_0d_v1:
+   type: openai_unet_0d_next
+   args:
+     input_channels: 768
+     model_channels: 320
+     output_channels: 768
+     num_noattn_blocks: [ 2, 2, 2, 2 ]
+     channel_mult: [ 1, 2, 4, 4 ]
+     second_dim: [ 4, 4, 4, 4 ]
+     with_attn: [true, true, true, false]
+     num_heads: 8
+     context_dim: 768
+     use_checkpoint: True
+     parts: [global, data, context]
+
+ openai_unet_0d_v1_g:
+   super_cfg: openai_unet_0d_v1
+   args:
+     parts: [global]
+
+ openai_unet_0d_v1_d:
+   super_cfg: openai_unet_0d_v1
+   args:
+     parts: [data]
+
+ openai_unet_0d_v1_c:
+   super_cfg: openai_unet_0d_v1
+   args:
+     parts: [context]
+
+ openai_unet_0d_v1_gd:
+   super_cfg: openai_unet_0d_v1
+   args:
+     parts: [global, data]
+
+ openai_unet_0d_v1_gc:
+   super_cfg: openai_unet_0d_v1
+   args:
+     parts: [global, context]
+
+ openai_unet_0d_v1_dc:
+   super_cfg: openai_unet_0d_v1
+   args:
+     parts: [data, context]
configs/model/optimus.yaml ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ optimus:
3
+ symbol: optimus
4
+ find_unused_parameters: false
5
+ args: {}
6
+
7
+ optimus_bert_encoder:
8
+ super_cfg: optimus
9
+ type: optimus_bert_connector
10
+ # pth: pretrained/optimus_bert_encoder.pth
11
+ args:
12
+ config:
13
+ architectures:
14
+ - BertForMaskedLM
15
+ attention_probs_dropout_prob: 0.1
16
+ finetuning_task: null
17
+ hidden_act: gelu
18
+ hidden_dropout_prob: 0.1
19
+ hidden_size: 768
20
+ initializer_range: 0.02
21
+ intermediate_size: 3072
22
+ layer_norm_eps: 1.e-12
23
+ max_position_embeddings: 512
24
+ num_attention_heads: 12
25
+ num_hidden_layers: 12
26
+ num_labels: 2
27
+ output_attentions: false
28
+ output_hidden_states: false
29
+ pruned_heads: {}
30
+ torchscript: false
31
+ type_vocab_size: 2
32
+ vocab_size: 28996
33
+ latent_size: 768
34
+
35
+ optimus_bert_tokenizer:
36
+ super_cfg: optimus
37
+ type: optimus_bert_tokenizer
38
+ args:
39
+ do_lower_case: false
40
+ max_len: 512
41
+ vocab_file: lib/model_zoo/optimus_models/vocab/bert-base-cased-vocab.txt
42
+
43
+ optimus_gpt2_decoder:
44
+ super_cfg: optimus
45
+ type: optimus_gpt2_connector
46
+ # pth: pretrained/optimus_gpt2_decoder.pth
47
+ args:
48
+ config:
49
+ architectures:
50
+ - GPT2LMHeadModel
51
+ attn_pdrop: 0.1
52
+ embd_pdrop: 0.1
53
+ finetuning_task: null
54
+ hidden_size: 768
55
+ initializer_range: 0.02
56
+ latent_size: 768
57
+ layer_norm_epsilon: 1.e-05
58
+ max_position_embeddings: 1024
59
+ n_ctx: 1024
60
+ n_embd: 768
61
+ n_head: 12
62
+ n_layer: 12
63
+ n_positions: 1024
64
+ num_attention_heads: 12
65
+ num_hidden_layers: 12
66
+ num_labels: 1
67
+ output_attentions: false
68
+ output_hidden_states: false
69
+ pretrained_config_archive_map:
70
+ gpt2 : https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json
71
+ gpt2-medium : https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json
72
+ gpt2-large : https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json
73
+ pruned_heads: {}
74
+ resid_pdrop: 0.1
75
+ summary_activation: null
76
+ summary_first_dropout: 0.1
77
+ summary_proj_to_labels: true
78
+ summary_type: cls_index
79
+ summary_use_proj: true
80
+ torchscript: false
81
+ vocab_size: 50260
82
+
83
+ optimus_gpt2_tokenizer:
84
+ super_cfg: optimus
85
+ type: optimus_gpt2_tokenizer
86
+ args:
87
+ do_lower_case: false
88
+ max_len: 1024
89
+ vocab_file: lib/model_zoo/optimus_models/vocab/gpt2-vocab.json
90
+ merges_file: lib/model_zoo/optimus_models/vocab/gpt2-merges.txt
91
+
92
+ optimus_v1:
93
+ super_cfg: optimus
94
+ type: optimus_vae_next
95
+ pth: pretrained/optimus-vae.pth
96
+ args:
97
+ encoder: MODEL(optimus_bert_encoder)
98
+ decoder: MODEL(optimus_gpt2_decoder)
99
+ tokenizer_encoder: MODEL(optimus_bert_tokenizer)
100
+ tokenizer_decoder: MODEL(optimus_gpt2_tokenizer)
101
+ args:
102
+ latent_size: 768
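optimus_v1 stitches the BERT encoder, the GPT-2 decoder, and both tokenizers together through MODEL(...) references, which the config bank in lib/cfg_helper.py expands into the sub-configs defined above. A short, hedged sketch of loading the fully resolved config, assuming it is run from the repository root with the repo's dependencies installed:

from lib.cfg_helper import model_cfg_bank

cfg = model_cfg_bank()('optimus_v1')   # resolves super_cfg inheritance and MODEL(...) references
print(cfg.type)                        # optimus_vae_next
print(cfg.args.encoder.type)           # optimus_bert_connector, expanded from MODEL(optimus_bert_encoder)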
configs/model/vd.yaml ADDED
@@ -0,0 +1,29 @@
1
+ vd_base:
2
+ symbol: vd
3
+ find_unused_parameters: true
4
+ type: vd_v2_0
5
+ args:
6
+ beta_linear_start: 0.00085
7
+ beta_linear_end: 0.012
8
+ timesteps: 1000
9
+ use_ema: false
10
+
11
+ ###########
12
+ # vd v1.0 #
13
+ ###########
14
+
15
+ vd_four_flow_v1-0:
16
+ super_cfg: vd_base
17
+ args:
18
+ vae_cfg_list:
19
+ - [image, MODEL(autokl_v1)]
20
+ - [text, MODEL(optimus_v1)]
21
+ ctx_cfg_list:
22
+ - [image, MODEL(clip_image_context_encoder)]
23
+ - [text, MODEL(clip_text_context_encoder)]
24
+ diffuser_cfg_list:
25
+ - [image, MODEL(openai_unet_2d_v1)]
26
+ - [text, MODEL(openai_unet_0d_v1_dc)]
27
+ global_layer_ptr: image
28
+ latent_scale_factor:
29
+ image: 0.18215
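vd_four_flow_v1-0 wires an image flow (AutoKL VAE, CLIP image context encoder, 2D UNet) and a text flow (Optimus VAE, CLIP text context encoder, 0D UNet) into one model, and latent_scale_factor gives the constant used to normalize the image VAE latents before diffusion. A hedged, illustrative sketch of how such a scale factor is commonly applied; encode and decode are placeholder method names, not the repo's actual API:

scale = 0.18215  # latent_scale_factor for the image flow

def to_latent(vae, x):
    # encode a pixel-space batch and scale the latent before diffusion
    return vae.encode(x) * scale

def from_latent(vae, z):
    # undo the scaling before decoding back to pixel space
    return vae.decode(z / scale)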
cusomized_gradio_blocks.py ADDED
@@ -0,0 +1,271 @@
1
+ from __future__ import annotations
2
+
3
+ import ast
4
+ import csv
5
+ import inspect
6
+ import os
7
+ import subprocess
8
+ import tempfile
9
+ import threading
10
+ import warnings
11
+ from pathlib import Path
12
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Tuple
13
+
14
+ import matplotlib
15
+ import matplotlib.pyplot as plt
16
+ import numpy as np
17
+ import PIL
18
+ import PIL.Image
19
+
20
+ import gradio
21
+ from gradio import components, processing_utils, routes, utils
22
+ from gradio.context import Context
23
+ from gradio.documentation import document, set_documentation_group
24
+ from gradio.flagging import CSVLogger
25
+
26
+ if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
27
+ from gradio.components import IOComponent
28
+
29
+ CACHED_FOLDER = "gradio_cached_examples"
30
+ LOG_FILE = "log.csv"
31
+
32
+ def create_myexamples(
33
+ examples: List[Any] | List[List[Any]] | str,
34
+ inputs: IOComponent | List[IOComponent],
35
+ outputs: IOComponent | List[IOComponent] | None = None,
36
+ fn: Callable | None = None,
37
+ cache_examples: bool = False,
38
+ examples_per_page: int = 10,
39
+ _api_mode: bool = False,
40
+ label: str | None = None,
41
+ elem_id: str | None = None,
42
+ run_on_click: bool = False,
43
+ preprocess: bool = True,
44
+ postprocess: bool = True,
45
+ batch: bool = False,):
46
+ """Top-level synchronous function that creates Examples. Provided for backwards compatibility, i.e. so that gr.Examples(...) can be used to create the Examples component."""
47
+ examples_obj = MyExamples(
48
+ examples=examples,
49
+ inputs=inputs,
50
+ outputs=outputs,
51
+ fn=fn,
52
+ cache_examples=cache_examples,
53
+ examples_per_page=examples_per_page,
54
+ _api_mode=_api_mode,
55
+ label=label,
56
+ elem_id=elem_id,
57
+ run_on_click=run_on_click,
58
+ preprocess=preprocess,
59
+ postprocess=postprocess,
60
+ batch=batch,
61
+ _initiated_directly=False,
62
+ )
63
+ utils.synchronize_async(examples_obj.create)
64
+ return examples_obj
65
+
66
+ class MyExamples(gradio.helpers.Examples):
67
+ def __init__(
68
+ self,
69
+ examples: List[Any] | List[List[Any]] | str,
70
+ inputs: IOComponent | List[IOComponent],
71
+ outputs: IOComponent | List[IOComponent] | None = None,
72
+ fn: Callable | None = None,
73
+ cache_examples: bool = False,
74
+ examples_per_page: int = 10,
75
+ _api_mode: bool = False,
76
+ label: str | None = "Examples",
77
+ elem_id: str | None = None,
78
+ run_on_click: bool = False,
79
+ preprocess: bool = True,
80
+ postprocess: bool = True,
81
+ batch: bool = False,
82
+ _initiated_directly: bool = True,):
83
+
84
+ if _initiated_directly:
85
+ warnings.warn(
86
+ "Please use gr.Examples(...) instead of gr.examples.Examples(...) to create the Examples.",
87
+ )
88
+
89
+ if cache_examples and (fn is None or outputs is None):
90
+ raise ValueError("If caching examples, `fn` and `outputs` must be provided")
91
+
92
+ if not isinstance(inputs, list):
93
+ inputs = [inputs]
94
+ if outputs and not isinstance(outputs, list):
95
+ outputs = [outputs]
96
+
97
+ working_directory = Path().absolute()
98
+
99
+ if examples is None:
100
+ raise ValueError("The parameter `examples` cannot be None")
101
+ elif isinstance(examples, list) and (
102
+ len(examples) == 0 or isinstance(examples[0], list)
103
+ ):
104
+ pass
105
+ elif (
106
+ isinstance(examples, list) and len(inputs) == 1
107
+ ): # If there is only one input component, examples can be provided as a regular list instead of a list of lists
108
+ examples = [[e] for e in examples]
109
+ elif isinstance(examples, str):
110
+ if not Path(examples).exists():
111
+ raise FileNotFoundError(
112
+ "Could not find examples directory: " + examples
113
+ )
114
+ working_directory = examples
115
+ if not (Path(examples) / LOG_FILE).exists():
116
+ if len(inputs) == 1:
117
+ examples = [[e] for e in os.listdir(examples)]
118
+ else:
119
+ raise FileNotFoundError(
120
+ "Could not find log file (required for multiple inputs): "
121
+ + LOG_FILE
122
+ )
123
+ else:
124
+ with open(Path(examples) / LOG_FILE) as logs:
125
+ examples = list(csv.reader(logs))
126
+ examples = [
127
+ examples[i][: len(inputs)] for i in range(1, len(examples))
128
+ ] # remove header and unnecessary columns
129
+
130
+ else:
131
+ raise ValueError(
132
+ "The parameter `examples` must either be a string directory or a list"
133
+ "(if there is only 1 input component) or (more generally), a nested "
134
+ "list, where each sublist represents a set of inputs."
135
+ )
136
+
137
+ input_has_examples = [False] * len(inputs)
138
+ for example in examples:
139
+ for idx, example_for_input in enumerate(example):
140
+ # if not (example_for_input is None):
141
+ if True:
142
+ try:
143
+ input_has_examples[idx] = True
144
+ except IndexError:
145
+ pass # If there are more example components than inputs, ignore. This can sometimes be intentional (e.g. loading from a log file where outputs and timestamps are also logged)
146
+
147
+ inputs_with_examples = [
148
+ inp for (inp, keep) in zip(inputs, input_has_examples) if keep
149
+ ]
150
+ non_none_examples = [
151
+ [ex for (ex, keep) in zip(example, input_has_examples) if keep]
152
+ for example in examples
153
+ ]
154
+
155
+ self.examples = examples
156
+ self.non_none_examples = non_none_examples
157
+ self.inputs = inputs
158
+ self.inputs_with_examples = inputs_with_examples
159
+ self.outputs = outputs
160
+ self.fn = fn
161
+ self.cache_examples = cache_examples
162
+ self._api_mode = _api_mode
163
+ self.preprocess = preprocess
164
+ self.postprocess = postprocess
165
+ self.batch = batch
166
+
167
+ with utils.set_directory(working_directory):
168
+ self.processed_examples = [
169
+ [
170
+ component.postprocess(sample)
171
+ for component, sample in zip(inputs, example)
172
+ ]
173
+ for example in examples
174
+ ]
175
+ self.non_none_processed_examples = [
176
+ [ex for (ex, keep) in zip(example, input_has_examples) if keep]
177
+ for example in self.processed_examples
178
+ ]
179
+ if cache_examples:
180
+ for example in self.examples:
181
+ if len([ex for ex in example if ex is not None]) != len(self.inputs):
182
+ warnings.warn(
183
+ "Examples are being cached but not all input components have "
184
+ "example values. This may result in an exception being thrown by "
185
+ "your function. If you do get an error while caching examples, make "
186
+ "sure all of your inputs have example values for all of your examples "
187
+ "or you provide default values for those particular parameters in your function."
188
+ )
189
+ break
190
+
191
+ with utils.set_directory(working_directory):
192
+ self.dataset = components.Dataset(
193
+ components=inputs_with_examples,
194
+ samples=non_none_examples,
195
+ type="index",
196
+ label=label,
197
+ samples_per_page=examples_per_page,
198
+ elem_id=elem_id,
199
+ )
200
+
201
+ self.cached_folder = Path(CACHED_FOLDER) / str(self.dataset._id)
202
+ self.cached_file = Path(self.cached_folder) / "log.csv"
203
+ self.cache_examples = cache_examples
204
+ self.run_on_click = run_on_click
205
+
206
+ from gradio import utils, processing_utils
207
+ from PIL import Image as _Image
208
+ from pathlib import Path
209
+ import numpy as np
210
+
211
+ def customized_postprocess(self, y):
212
+ if y is None:
213
+ return None
214
+
215
+ if isinstance(y, dict):
216
+ if self.tool == "sketch" and self.source in ["upload", "webcam"]:
217
+ y, mask = y["image"], y["mask"]
218
+ if y is None:
219
+ return None
220
+ elif isinstance(y, np.ndarray):
221
+ im = processing_utils.encode_array_to_base64(y)
222
+ elif isinstance(y, _Image.Image):
223
+ im = processing_utils.encode_pil_to_base64(y)
224
+ elif isinstance(y, (str, Path)):
225
+ im = processing_utils.encode_url_or_file_to_base64(y)
226
+ else:
227
+ raise ValueError("Cannot process this value as an Image")
228
+ im = self._format_image(im)
229
+
230
+ if mask is None:
231
+ return im
232
+ elif isinstance(y, np.ndarray):
233
+ mask_im = processing_utils.encode_array_to_base64(mask)
234
+ elif isinstance(y, _Image.Image):
235
+ mask_im = processing_utils.encode_pil_to_base64(mask)
236
+ elif isinstance(y, (str, Path)):
237
+ mask_im = processing_utils.encode_url_or_file_to_base64(mask)
238
+ else:
239
+ raise ValueError("Cannot process this value as an Image")
240
+
241
+ return {"image": im, "mask" : mask_im,}
242
+
243
+ elif isinstance(y, np.ndarray):
244
+ return processing_utils.encode_array_to_base64(y)
245
+ elif isinstance(y, _Image.Image):
246
+ return processing_utils.encode_pil_to_base64(y)
247
+ elif isinstance(y, (str, Path)):
248
+ return processing_utils.encode_url_or_file_to_base64(y)
249
+ else:
250
+ raise ValueError("Cannot process this value as an Image")
251
+
252
+ # def customized_as_example(self, input_data=None):
253
+ # if input_data is None:
254
+ # return str('assets/demo/misc/noimage.jpg')
255
+ # elif isinstance(input_data, dict):
256
+ # im = np.array(PIL.Image.open(input_data["image"])).astype(float)
257
+ # mask = np.array(PIL.Image.open(input_data["mask"])).astype(float)/255
258
+ # imm = (im * (1-mask)).astype(np.uint8)
259
+ # import time
260
+ # ctime = int(time.time()*100)
261
+ # impath = 'assets/demo/temp/temp_{}.png'.format(ctime)
262
+ # PIL.Image.fromarray(imm).save(impath)
263
+ # return str(utils.abspath(impath))
264
+ # else:
265
+ # return str(utils.abspath(input_data))
266
+
267
+ def customized_as_example(self, input_data=None):
268
+ if input_data is None:
269
+ return str('assets/demo/misc/noimage.jpg')
270
+ else:
271
+ return str(utils.abspath(input_data))
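create_myexamples mirrors gr.Examples(...) but builds the MyExamples helper above, which tolerates empty example cells and pairs with the customized image postprocess/as_example functions. A hedged usage sketch; the component names and the run callback are placeholders rather than the ones app.py actually wires up:

import gradio as gr
from cusomized_gradio_blocks import create_myexamples

def run(prompt, image):
    # placeholder callback; the real demo dispatches to the Versatile Diffusion model
    return image

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    image = gr.Image(type="pil")
    result = gr.Image(label="Image Result")
    button = gr.Button("Run")
    button.click(run, inputs=[prompt, image], outputs=result)
    create_myexamples(
        examples=[["a tiger", "assets/demo/misc/noimage.jpg"]],
        inputs=[prompt, image],
        outputs=result,
        fn=run,
        cache_examples=False,
    )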
gradio_cached_examples/12/Image Result/23645e03-6435-4819-a746-2840be976419/captions.json ADDED
@@ -0,0 +1 @@
1
+ {"/home/james/Project/vd-demo/gradio_cached_examples/12/Image Result/23645e03-6435-4819-a746-2840be976419/tmp9xugbhobbnp5ds0r.png": null, "/home/james/Project/vd-demo/gradio_cached_examples/12/Image Result/23645e03-6435-4819-a746-2840be976419/tmp0m_lns_xtd2zm06b.png": null}
gradio_cached_examples/12/Image Result/23645e03-6435-4819-a746-2840be976419/tmp0m_lns_xtd2zm06b.png ADDED
gradio_cached_examples/12/Image Result/23645e03-6435-4819-a746-2840be976419/tmp9xugbhobbnp5ds0r.png ADDED
gradio_cached_examples/12/log.csv ADDED
@@ -0,0 +1,2 @@
1
+ Image Result,flag,username,timestamp
2
+ /home/james/Project/vd-demo/gradio_cached_examples/12/Image Result/23645e03-6435-4819-a746-2840be976419,,,2023-02-07 08:44:12.243513
lib/__init__.py ADDED
File without changes
lib/cfg_helper.py ADDED
@@ -0,0 +1,612 @@
1
+ import os
2
+ import os.path as osp
3
+ import shutil
4
+ import copy
5
+ import time
6
+ import pprint
7
+ import numpy as np
8
+ import torch
9
+ import matplotlib
10
+ import argparse
11
+ import json
12
+ import yaml
13
+ from easydict import EasyDict as edict
14
+
15
+ from .model_zoo import get_model
16
+
17
+ ############
18
+ # cfg_bank #
19
+ ############
20
+
21
+ def cfg_solvef(cmd, root):
22
+ if not isinstance(cmd, str):
23
+ return cmd
24
+
25
+ if cmd.find('SAME')==0:
26
+ zoom = root
27
+ p = cmd[len('SAME'):].strip('()').split('.')
28
+ p = [pi.strip() for pi in p]
29
+ for pi in p:
30
+ try:
31
+ pi = int(pi)
32
+ except:
33
+ pass
34
+
35
+ try:
36
+ zoom = zoom[pi]
37
+ except:
38
+ return cmd
39
+ return cfg_solvef(zoom, root)
40
+
41
+ if cmd.find('SEARCH')==0:
42
+ zoom = root
43
+ p = cmd[len('SEARCH'):].strip('()').split('.')
44
+ p = [pi.strip() for pi in p]
45
+ find = True
46
+ # Depth first search
47
+ for pi in p:
48
+ try:
49
+ pi = int(pi)
50
+ except:
51
+ pass
52
+
53
+ try:
54
+ zoom = zoom[pi]
55
+ except:
56
+ find = False
57
+ break
58
+
59
+ if find:
60
+ return cfg_solvef(zoom, root)
61
+ else:
62
+ if isinstance(root, dict):
63
+ for ri in root:
64
+ rv = cfg_solvef(cmd, root[ri])
65
+ if rv != cmd:
66
+ return rv
67
+ if isinstance(root, list):
68
+ for ri in root:
69
+ rv = cfg_solvef(cmd, ri)
70
+ if rv != cmd:
71
+ return rv
72
+ return cmd
73
+
74
+ if cmd.find('MODEL')==0:
75
+ goto = cmd[len('MODEL'):].strip('()')
76
+ return model_cfg_bank()(goto)
77
+
78
+ if cmd.find('DATASET')==0:
79
+ goto = cmd[len('DATASET'):].strip('()')
80
+ return dataset_cfg_bank()(goto)
81
+
82
+ return cmd
83
+
84
+ def cfg_solve(cfg, cfg_root):
85
+ # This function resolves cfg elements so that
86
+ # all surrogate inputs are settled.
87
+ # (i.e. SAME(***) )
88
+ if isinstance(cfg, list):
89
+ for i in range(len(cfg)):
90
+ if isinstance(cfg[i], (list, dict)):
91
+ cfg[i] = cfg_solve(cfg[i], cfg_root)
92
+ else:
93
+ cfg[i] = cfg_solvef(cfg[i], cfg_root)
94
+ if isinstance(cfg, dict):
95
+ for k in cfg:
96
+ if isinstance(cfg[k], (list, dict)):
97
+ cfg[k] = cfg_solve(cfg[k], cfg_root)
98
+ else:
99
+ cfg[k] = cfg_solvef(cfg[k], cfg_root)
100
+ return cfg
101
+
102
+ class model_cfg_bank(object):
103
+ def __init__(self):
104
+ self.cfg_dir = osp.join('configs', 'model')
105
+ self.cfg_bank = edict()
106
+
107
+ def __call__(self, name):
108
+ if name not in self.cfg_bank:
109
+ cfg_path = self.get_yaml_path(name)
110
+ with open(cfg_path, 'r') as f:
111
+ cfg_new = yaml.load(
112
+ f, Loader=yaml.FullLoader)
113
+ cfg_new = edict(cfg_new)
114
+ self.cfg_bank.update(cfg_new)
115
+
116
+ cfg = self.cfg_bank[name]
117
+ cfg.name = name
118
+ if 'super_cfg' not in cfg:
119
+ cfg = cfg_solve(cfg, cfg)
120
+ self.cfg_bank[name] = cfg
121
+ return copy.deepcopy(cfg)
122
+
123
+ super_cfg = self.__call__(cfg.super_cfg)
124
+ # unlike other fields,
125
+ # args will not be replaced but updated.
126
+ if 'args' in cfg:
127
+ if 'args' in super_cfg:
128
+ super_cfg.args.update(cfg.args)
129
+ else:
130
+ super_cfg.args = cfg.args
131
+ cfg.pop('args')
132
+
133
+ super_cfg.update(cfg)
134
+ super_cfg.pop('super_cfg')
135
+ cfg = super_cfg
136
+ try:
137
+ delete_args = cfg.pop('delete_args')
138
+ except:
139
+ delete_args = []
140
+
141
+ for dargs in delete_args:
142
+ cfg.args.pop(dargs)
143
+
144
+ cfg = cfg_solve(cfg, cfg)
145
+ self.cfg_bank[name] = cfg
146
+ return copy.deepcopy(cfg)
147
+
148
+ def get_yaml_path(self, name):
149
+ if name.find('openai_unet')==0:
150
+ return osp.join(
151
+ self.cfg_dir, 'openai_unet.yaml')
152
+ elif (name.find('clip')==0) or (name.find('openclip')==0):
153
+ return osp.join(
154
+ self.cfg_dir, 'clip.yaml')
155
+ elif name.find('vd')==0:
156
+ return osp.join(
157
+ self.cfg_dir, 'vd.yaml')
158
+ elif name.find('optimus')==0:
159
+ return osp.join(
160
+ self.cfg_dir, 'optimus.yaml')
161
+ elif name.find('autokl')==0:
162
+ return osp.join(
163
+ self.cfg_dir, 'autokl.yaml')
164
+ else:
165
+ raise ValueError
166
+
167
+ class dataset_cfg_bank(object):
168
+ def __init__(self):
169
+ self.cfg_dir = osp.join('configs', 'dataset')
170
+ self.cfg_bank = edict()
171
+
172
+ def __call__(self, name):
173
+ if name not in self.cfg_bank:
174
+ cfg_path = self.get_yaml_path(name)
175
+ with open(cfg_path, 'r') as f:
176
+ cfg_new = yaml.load(
177
+ f, Loader=yaml.FullLoader)
178
+ cfg_new = edict(cfg_new)
179
+ self.cfg_bank.update(cfg_new)
180
+
181
+ cfg = self.cfg_bank[name]
182
+ cfg.name = name
183
+ if cfg.get('super_cfg', None) is None:
184
+ cfg = cfg_solve(cfg, cfg)
185
+ self.cfg_bank[name] = cfg
186
+ return copy.deepcopy(cfg)
187
+
188
+ super_cfg = self.__call__(cfg.super_cfg)
189
+ super_cfg.update(cfg)
190
+ cfg = super_cfg
191
+ cfg.super_cfg = None
192
+ try:
193
+ delete = cfg.pop('delete')
194
+ except:
195
+ delete = []
196
+
197
+ for dargs in delete:
198
+ cfg.pop(dargs)
199
+
200
+ cfg = cfg_solve(cfg, cfg)
201
+ self.cfg_bank[name] = cfg
202
+ return copy.deepcopy(cfg)
203
+
204
+ def get_yaml_path(self, name):
205
+ if name.find('laion2b')==0:
206
+ return osp.join(
207
+ self.cfg_dir, 'laion2b.yaml')
208
+ else:
209
+ raise ValueError
210
+
211
+ class experiment_cfg_bank(object):
212
+ def __init__(self):
213
+ self.cfg_dir = osp.join('configs', 'experiment')
214
+ self.cfg_bank = edict()
215
+
216
+ def __call__(self, name):
217
+ if name not in self.cfg_bank:
218
+ cfg_path = self.get_yaml_path(name)
219
+ with open(cfg_path, 'r') as f:
220
+ cfg = yaml.load(
221
+ f, Loader=yaml.FullLoader)
222
+ cfg = edict(cfg)
223
+
224
+ cfg = cfg_solve(cfg, cfg)
225
+ cfg = cfg_solve(cfg, cfg)
226
+ # twice for SEARCH
227
+ self.cfg_bank[name] = cfg
228
+ return copy.deepcopy(cfg)
229
+
230
+ def get_yaml_path(self, name):
231
+ return osp.join(
232
+ self.cfg_dir, name+'.yaml')
233
+
234
+ def load_cfg_yaml(path):
235
+ if osp.isfile(path):
236
+ cfg_path = path
237
+ elif osp.isfile(osp.join('configs', 'experiment', path)):
238
+ cfg_path = osp.join('configs', 'experiment', path)
239
+ elif osp.isfile(osp.join('configs', 'experiment', path+'.yaml')):
240
+ cfg_path = osp.join('configs', 'experiment', path+'.yaml')
241
+ else:
242
+ assert False, 'No such config!'
243
+
244
+ with open(cfg_path, 'r') as f:
245
+ cfg = yaml.load(f, Loader=yaml.FullLoader)
246
+ cfg = edict(cfg)
247
+ cfg = cfg_solve(cfg, cfg)
248
+ cfg = cfg_solve(cfg, cfg)
249
+ return cfg
250
+
251
+ ##############
252
+ # cfg_helper #
253
+ ##############
254
+
255
+ def get_experiment_id(ref=None):
256
+ if ref is None:
257
+ time.sleep(0.5)
258
+ return int(time.time()*100)
259
+ else:
260
+ try:
261
+ return int(ref)
262
+ except:
263
+ pass
264
+
265
+ _, ref = osp.split(ref)
266
+ ref = ref.split('_')[0]
267
+ try:
268
+ return int(ref)
269
+ except:
270
+ assert False, 'Invalid experiment ID!'
271
+
272
+ def record_resume_cfg(path):
273
+ cnt = 0
274
+ while True:
275
+ if osp.exists(path+'.{:04d}'.format(cnt)):
276
+ cnt += 1
277
+ continue
278
+ shutil.copyfile(path, path+'.{:04d}'.format(cnt))
279
+ break
280
+
281
+ def get_command_line_args():
282
+ parser = argparse.ArgumentParser()
283
+ parser.add_argument('--debug', action='store_true', default=False)
284
+ parser.add_argument('--config', type=str)
285
+ parser.add_argument('--gpu', nargs='+', type=int)
286
+
287
+ parser.add_argument('--node_rank', type=int)
288
+ parser.add_argument('--node_list', nargs='+', type=str)
289
+ parser.add_argument('--nodes', type=int)
290
+ parser.add_argument('--addr', type=str, default='127.0.0.1')
291
+ parser.add_argument('--port', type=int, default=11233)
292
+
293
+ parser.add_argument('--signature', nargs='+', type=str)
294
+ parser.add_argument('--seed', type=int)
295
+
296
+ parser.add_argument('--eval', type=str)
297
+ parser.add_argument('--eval_subdir', type=str)
298
+ parser.add_argument('--pretrained', type=str)
299
+
300
+ parser.add_argument('--resume_dir', type=str)
301
+ parser.add_argument('--resume_step', type=int)
302
+ parser.add_argument('--resume_weight', type=str)
303
+
304
+ args = parser.parse_args()
305
+
306
+ # Special handling the resume
307
+ if args.resume_dir is not None:
308
+ cfg = edict()
309
+ cfg.env = edict()
310
+ cfg.env.debug = args.debug
311
+ cfg.env.resume = edict()
312
+ cfg.env.resume.dir = args.resume_dir
313
+ cfg.env.resume.step = args.resume_step
314
+ cfg.env.resume.weight = args.resume_weight
315
+ return cfg
316
+
317
+ cfg = load_cfg_yaml(args.config)
318
+ cfg.env.debug = args.debug
319
+ cfg.env.gpu_device = [0] if args.gpu is None else list(args.gpu)
320
+ cfg.env.master_addr = args.addr
321
+ cfg.env.master_port = args.port
322
+ cfg.env.dist_url = 'tcp://{}:{}'.format(args.addr, args.port)
323
+
324
+ if args.node_list is None:
325
+ cfg.env.node_rank = 0 if args.node_rank is None else args.node_rank
326
+ cfg.env.nodes = 1 if args.nodes is None else args.nodes
327
+ else:
328
+ import socket
329
+ hostname = socket.gethostname()
330
+ assert cfg.env.master_addr == args.node_list[0]
331
+ cfg.env.node_rank = args.node_list.index(hostname)
332
+ cfg.env.nodes = len(args.node_list)
333
+ cfg.env.node_list = args.node_list
334
+
335
+ istrain = False if args.eval is not None else True
336
+ isdebug = cfg.env.debug
337
+
338
+ if istrain:
339
+ if isdebug:
340
+ cfg.env.experiment_id = 999999999999
341
+ cfg.train.signature = ['debug']
342
+ else:
343
+ cfg.env.experiment_id = get_experiment_id()
344
+ if args.signature is not None:
345
+ cfg.train.signature = args.signature
346
+ else:
347
+ if 'train' in cfg:
348
+ cfg.pop('train')
349
+ cfg.env.experiment_id = get_experiment_id(args.eval)
350
+ if args.signature is not None:
351
+ cfg.eval.signature = args.signature
352
+
353
+ if isdebug and (args.eval is None):
354
+ cfg.env.experiment_id = 999999999999
355
+ cfg.eval.signature = ['debug']
356
+
357
+ if args.eval_subdir is not None:
358
+ if isdebug:
359
+ cfg.eval.eval_subdir = 'debug'
360
+ else:
361
+ cfg.eval.eval_subdir = args.eval_subdir
362
+ if args.pretrained is not None:
363
+ cfg.eval.pretrained = args.pretrained
364
+ # This overrides the pretrained weight specified in cfg.model
365
+
366
+ if args.seed is not None:
367
+ cfg.env.rnd_seed = args.seed
368
+
369
+ return cfg
370
+
371
+ def cfg_initiates(cfg):
372
+ cfge = cfg.env
373
+ isdebug = cfge.debug
374
+ isresume = 'resume' in cfge
375
+ istrain = 'train' in cfg
376
+ haseval = 'eval' in cfg
377
+ cfgt = cfg.train if istrain else None
378
+ cfgv = cfg.eval if haseval else None
379
+
380
+ ###############################
381
+ # get some environment params #
382
+ ###############################
383
+
384
+ cfge.computer = os.uname()
385
+ cfge.torch_version = str(torch.__version__)
386
+
387
+ ##########
388
+ # resume #
389
+ ##########
390
+
391
+ if isresume:
392
+ resume_cfg_path = osp.join(cfge.resume.dir, 'config.yaml')
393
+ record_resume_cfg(resume_cfg_path)
394
+ with open(resume_cfg_path, 'r') as f:
395
+ cfg_resume = yaml.load(f, Loader=yaml.FullLoader)
396
+ cfg_resume = edict(cfg_resume)
397
+ cfg_resume.env.update(cfge)
398
+ cfg = cfg_resume
399
+ cfge = cfg.env
400
+ log_file = cfg.train.log_file
401
+
402
+ print('')
403
+ print('##########')
404
+ print('# resume #')
405
+ print('##########')
406
+ print('')
407
+ with open(log_file, 'a') as f:
408
+ print('', file=f)
409
+ print('##########', file=f)
410
+ print('# resume #', file=f)
411
+ print('##########', file=f)
412
+ print('', file=f)
413
+
414
+ pprint.pprint(cfg)
415
+ with open(log_file, 'a') as f:
416
+ pprint.pprint(cfg, f)
417
+
418
+ ####################
419
+ # node distributed #
420
+ ####################
421
+
422
+ if cfg.env.master_addr!='127.0.0.1':
423
+ os.environ['MASTER_ADDR'] = cfge.master_addr
424
+ os.environ['MASTER_PORT'] = '{}'.format(cfge.master_port)
425
+ if cfg.env.dist_backend=='nccl':
426
+ os.environ['NCCL_SOCKET_FAMILY'] = 'AF_INET'
427
+ if cfg.env.dist_backend=='gloo':
428
+ os.environ['GLOO_SOCKET_FAMILY'] = 'AF_INET'
429
+
430
+ #######################
431
+ # cuda visible device #
432
+ #######################
433
+
434
+ os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
435
+ [str(gid) for gid in cfge.gpu_device])
436
+
437
+ #####################
438
+ # return resume cfg #
439
+ #####################
440
+
441
+ if isresume:
442
+ return cfg
443
+
444
+ #############################################
445
+ # some misc setting that not need in resume #
446
+ #############################################
447
+
448
+ cfgm = cfg.model
449
+ cfge.gpu_count = len(cfge.gpu_device)
450
+
451
+ ##########################################
452
+ # align batch size and num worker config #
453
+ ##########################################
454
+
455
+ gpu_n = cfge.gpu_count * cfge.nodes
456
+ def align_batch_size(bs, bs_per_gpu):
457
+ assert (bs is not None) or (bs_per_gpu is not None)
458
+ bs = bs_per_gpu * gpu_n if bs is None else bs
459
+ bs_per_gpu = bs // gpu_n if bs_per_gpu is None else bs_per_gpu
460
+ assert (bs == bs_per_gpu * gpu_n)
461
+ return bs, bs_per_gpu
462
+
463
+ if istrain:
464
+ cfgt.batch_size, cfgt.batch_size_per_gpu = \
465
+ align_batch_size(cfgt.batch_size, cfgt.batch_size_per_gpu)
466
+ cfgt.dataset_num_workers, cfgt.dataset_num_workers_per_gpu = \
467
+ align_batch_size(cfgt.dataset_num_workers, cfgt.dataset_num_workers_per_gpu)
468
+ if haseval:
469
+ cfgv.batch_size, cfgv.batch_size_per_gpu = \
470
+ align_batch_size(cfgv.batch_size, cfgv.batch_size_per_gpu)
471
+ cfgv.dataset_num_workers, cfgv.dataset_num_workers_per_gpu = \
472
+ align_batch_size(cfgv.dataset_num_workers, cfgv.dataset_num_workers_per_gpu)
473
+
474
+ ##################
475
+ # create log dir #
476
+ ##################
477
+
478
+ if istrain:
479
+ if not isdebug:
480
+ sig = cfgt.get('signature', [])
481
+ sig = sig + ['s{}'.format(cfge.rnd_seed)]
482
+ else:
483
+ sig = ['debug']
484
+
485
+ log_dir = [
486
+ cfge.log_root_dir,
487
+ '{}_{}'.format(cfgm.symbol, cfgt.dataset.symbol),
488
+ '_'.join([str(cfge.experiment_id)] + sig)
489
+ ]
490
+ log_dir = osp.join(*log_dir)
491
+ log_file = osp.join(log_dir, 'train.log')
492
+ if not osp.exists(log_file):
493
+ os.makedirs(osp.dirname(log_file))
494
+ cfgt.log_dir = log_dir
495
+ cfgt.log_file = log_file
496
+
497
+ if haseval:
498
+ cfgv.log_dir = log_dir
499
+ cfgv.log_file = log_file
500
+ else:
501
+ model_symbol = cfgm.symbol
502
+ if cfgv.get('dataset', None) is None:
503
+ dataset_symbol = 'nodataset'
504
+ else:
505
+ dataset_symbol = cfgv.dataset.symbol
506
+
507
+ log_dir = osp.join(cfge.log_root_dir, '{}_{}'.format(model_symbol, dataset_symbol))
508
+ exp_dir = search_experiment_folder(log_dir, cfge.experiment_id)
509
+ if exp_dir is None:
510
+ if not isdebug:
511
+ sig = cfgv.get('signature', []) + ['evalonly']
512
+ else:
513
+ sig = ['debug']
514
+ exp_dir = '_'.join([str(cfge.experiment_id)] + sig)
515
+
516
+ eval_subdir = cfgv.get('eval_subdir', None)
517
+ # override subdir in debug mode (if eval_subdir is set)
518
+ eval_subdir = 'debug' if (eval_subdir is not None) and isdebug else eval_subdir
519
+
520
+ if eval_subdir is not None:
521
+ log_dir = osp.join(log_dir, exp_dir, eval_subdir)
522
+ else:
523
+ log_dir = osp.join(log_dir, exp_dir)
524
+
525
+ disable_log_override = cfgv.get('disable_log_override', False)
526
+ if osp.isdir(log_dir):
527
+ if disable_log_override:
528
+ assert False, 'Overriding an existing log_dir is disabled at [{}]'.format(log_dir)
529
+ else:
530
+ os.makedirs(log_dir)
531
+
532
+ log_file = osp.join(log_dir, 'eval.log')
533
+ cfgv.log_dir = log_dir
534
+ cfgv.log_file = log_file
535
+
536
+ ######################
537
+ # print and save cfg #
538
+ ######################
539
+
540
+ pprint.pprint(cfg)
541
+ if cfge.node_rank==0:
542
+ with open(log_file, 'w') as f:
543
+ pprint.pprint(cfg, f)
544
+ with open(osp.join(log_dir, 'config.yaml'), 'w') as f:
545
+ yaml.dump(edict_2_dict(cfg), f)
546
+ else:
547
+ with open(osp.join(log_dir, 'config.yaml.{}'.format(cfge.node_rank)), 'w') as f:
548
+ yaml.dump(edict_2_dict(cfg), f)
549
+
550
+ #############
551
+ # save code #
552
+ #############
553
+
554
+ save_code = False
555
+ if istrain:
556
+ save_code = cfgt.get('save_code', False)
557
+ elif haseval:
558
+ save_code = cfgv.get('save_code', False)
559
+ save_code = save_code and (cfge.node_rank==0)
560
+
561
+ if save_code:
562
+ codedir = osp.join(log_dir, 'code')
563
+ if osp.exists(codedir):
564
+ shutil.rmtree(codedir)
565
+ for d in ['configs', 'lib']:
566
+ fromcodedir = d
567
+ tocodedir = osp.join(codedir, d)
568
+ shutil.copytree(
569
+ fromcodedir, tocodedir,
570
+ ignore=shutil.ignore_patterns(
571
+ '*__pycache__*', '*build*'))
572
+ for codei in os.listdir('.'):
573
+ if osp.splitext(codei)[1] == '.py':
574
+ shutil.copy(codei, codedir)
575
+
576
+ #######################
577
+ # set matplotlib mode #
578
+ #######################
579
+
580
+ if 'matplotlib_mode' in cfge:
581
+ try:
582
+ matplotlib.use(cfge.matplotlib_mode)
583
+ except:
584
+ print('Warning: matplotlib mode [{}] failed to be set!'.format(cfge.matplotlib_mode))
585
+
586
+ return cfg
587
+
588
+ def edict_2_dict(x):
589
+ if isinstance(x, dict):
590
+ xnew = {}
591
+ for k in x:
592
+ xnew[k] = edict_2_dict(x[k])
593
+ return xnew
594
+ elif isinstance(x, list):
595
+ xnew = []
596
+ for i in range(len(x)):
597
+ xnew.append( edict_2_dict(x[i]) )
598
+ return xnew
599
+ else:
600
+ return x
601
+
602
+ def search_experiment_folder(root, exid):
603
+ target = None
604
+ for fi in os.listdir(root):
605
+ if not osp.isdir(osp.join(root, fi)):
606
+ continue
607
+ if int(fi.split('_')[0]) == exid:
608
+ if target is not None:
609
+ return None # duplicated
610
+ elif target is None:
611
+ target = fi
612
+ return target
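cfg_solvef and cfg_solve implement the small surrogate syntax used across the YAML configs: SAME(a.b) copies another field, SEARCH(...) does a depth-first lookup, and MODEL(...) / DATASET(...) pull whole sub-configs from the banks above. A minimal illustration of the SAME case with made-up values, assuming the repo's dependencies are importable:

from easydict import EasyDict as edict
from lib.cfg_helper import cfg_solve

cfg = edict({
    'model': {'hidden_size': 768},
    'head': {'dim': 'SAME(model.hidden_size)'},   # surrogate reference to another field
})
cfg = cfg_solve(cfg, cfg)
print(cfg.head.dim)   # 768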
lib/cfg_holder.py ADDED
@@ -0,0 +1,28 @@
1
+ import copy
2
+
3
+ def singleton(class_):
4
+ instances = {}
5
+ def getinstance(*args, **kwargs):
6
+ if class_ not in instances:
7
+ instances[class_] = class_(*args, **kwargs)
8
+ return instances[class_]
9
+ return getinstance
10
+
11
+ ##############
12
+ # cfg_holder #
13
+ ##############
14
+
15
+ @singleton
16
+ class cfg_unique_holder(object):
17
+ def __init__(self):
18
+ self.cfg = None
19
+ # this is used to track the main code files.
20
+ self.code = set()
21
+ def save_cfg(self, cfg):
22
+ self.cfg = copy.deepcopy(cfg)
23
+ def add_code(self, code):
24
+ """
25
+ Record that a new main code entry point has been
26
+ reached by adding its name to the set.
27
+ """
28
+ self.code.add(code)
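cfg_unique_holder is a process-wide singleton: every call to the class returns the same instance, so a config stored once at startup can be read back from anywhere, which is how print_log in lib/log_service.py locates the log file. A short usage sketch with a made-up config:

from easydict import EasyDict as edict
from lib.cfg_holder import cfg_unique_holder as cfguh

cfguh().save_cfg(edict({'train': {'log_file': 'train.log'}}))
cfguh().add_code('main_train')

# any later call returns the same instance, hence the same cfg
assert cfguh().cfg.train.log_file == 'train.log'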
lib/log_service.py ADDED
@@ -0,0 +1,166 @@
1
+ import timeit
2
+ import numpy as np
3
+ import os
4
+ import os.path as osp
5
+ import shutil
6
+ import copy
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.distributed as dist
10
+ from .cfg_holder import cfg_unique_holder as cfguh
11
+ from . import sync
12
+
13
+ print_console_local_rank0_only = True
14
+
15
+ def print_log(*console_info):
16
+ local_rank = sync.get_rank('local')
17
+ if print_console_local_rank0_only and (local_rank!=0):
18
+ return
19
+ console_info = [str(i) for i in console_info]
20
+ console_info = ' '.join(console_info)
21
+ print(console_info)
22
+
23
+ if local_rank!=0:
24
+ return
25
+
26
+ log_file = None
27
+ try:
28
+ log_file = cfguh().cfg.train.log_file
29
+ except:
30
+ try:
31
+ log_file = cfguh().cfg.eval.log_file
32
+ except:
33
+ return
34
+ if log_file is not None:
35
+ with open(log_file, 'a') as f:
36
+ f.write(console_info + '\n')
37
+
38
+ class distributed_log_manager(object):
39
+ def __init__(self):
40
+ self.sum = {}
41
+ self.cnt = {}
42
+ self.time_check = timeit.default_timer()
43
+
44
+ cfgt = cfguh().cfg.train
45
+ use_tensorboard = getattr(cfgt, 'log_tensorboard', False)
46
+
47
+ self.ddp = sync.is_ddp()
48
+ self.rank = sync.get_rank('local')
49
+ self.world_size = sync.get_world_size('local')
50
+
51
+ self.tb = None
52
+ if use_tensorboard and (self.rank==0):
53
+ import tensorboardX
54
+ monitoring_dir = osp.join(cfguh().cfg.train.log_dir, 'tensorboard')
55
+ self.tb = tensorboardX.SummaryWriter(osp.join(monitoring_dir))
56
+
57
+ def accumulate(self, n, **data):
58
+ if n < 0:
59
+ raise ValueError
60
+
61
+ for itemn, di in data.items():
62
+ if itemn in self.sum:
63
+ self.sum[itemn] += di * n
64
+ self.cnt[itemn] += n
65
+ else:
66
+ self.sum[itemn] = di * n
67
+ self.cnt[itemn] = n
68
+
69
+ def get_mean_value_dict(self):
70
+ value_gather = [
71
+ self.sum[itemn]/self.cnt[itemn] \
72
+ for itemn in sorted(self.sum.keys()) ]
73
+
74
+ value_gather_tensor = torch.FloatTensor(value_gather).to(self.rank)
75
+ if self.ddp:
76
+ dist.all_reduce(value_gather_tensor, op=dist.ReduceOp.SUM)
77
+ value_gather_tensor /= self.world_size
78
+
79
+ mean = {}
80
+ for idx, itemn in enumerate(sorted(self.sum.keys())):
81
+ mean[itemn] = value_gather_tensor[idx].item()
82
+ return mean
83
+
84
+ def tensorboard_log(self, step, data, mode='train', **extra):
85
+ if self.tb is None:
86
+ return
87
+ if mode == 'train':
88
+ self.tb.add_scalar('other/epochn', extra['epochn'], step)
89
+ if 'lr' in extra:
90
+ self.tb.add_scalar('other/lr', extra['lr'], step)
91
+ for itemn, di in data.items():
92
+ if itemn.find('loss') == 0:
93
+ self.tb.add_scalar('loss/'+itemn, di, step)
94
+ elif itemn == 'Loss':
95
+ self.tb.add_scalar('Loss', di, step)
96
+ else:
97
+ self.tb.add_scalar('other/'+itemn, di, step)
98
+ elif mode == 'eval':
99
+ if isinstance(data, dict):
100
+ for itemn, di in data.items():
101
+ self.tb.add_scalar('eval/'+itemn, di, step)
102
+ else:
103
+ self.tb.add_scalar('eval', data, step)
104
+ return
105
+
106
+ def train_summary(self, itern, epochn, samplen, lr, tbstep=None):
107
+ console_info = [
108
+ 'Iter:{}'.format(itern),
109
+ 'Epoch:{}'.format(epochn),
110
+ 'Sample:{}'.format(samplen),]
111
+
112
+ if lr is not None:
113
+ console_info += ['LR:{:.4E}'.format(lr)]
114
+
115
+ mean = self.get_mean_value_dict()
116
+
117
+ tbstep = itern if tbstep is None else tbstep
118
+ self.tensorboard_log(
119
+ tbstep, mean, mode='train',
120
+ itern=itern, epochn=epochn, lr=lr)
121
+
122
+ loss = mean.pop('Loss')
123
+ mean_info = ['Loss:{:.4f}'.format(loss)] + [
124
+ '{}:{:.4f}'.format(itemn, mean[itemn]) \
125
+ for itemn in sorted(mean.keys()) \
126
+ if itemn.find('loss') == 0
127
+ ]
128
+ console_info += mean_info
129
+ console_info.append('Time:{:.2f}s'.format(
130
+ timeit.default_timer() - self.time_check))
131
+ return ' , '.join(console_info)
132
+
133
+ def clear(self):
134
+ self.sum = {}
135
+ self.cnt = {}
136
+ self.time_check = timeit.default_timer()
137
+
138
+ def tensorboard_close(self):
139
+ if self.tb is not None:
140
+ self.tb.close()
141
+
142
+ # ----- also include some small utils -----
143
+
144
+ def torch_to_numpy(*argv):
145
+ if len(argv) > 1:
146
+ data = list(argv)
147
+ else:
148
+ data = argv[0]
149
+
150
+ if isinstance(data, torch.Tensor):
151
+ return data.to('cpu').detach().numpy()
152
+
153
+ elif isinstance(data, (list, tuple)):
154
+ out = []
155
+ for di in data:
156
+ out.append(torch_to_numpy(di))
157
+ return out
158
+
159
+ elif isinstance(data, dict):
160
+ out = {}
161
+ for ni, di in data.items():
162
+ out[ni] = torch_to_numpy(di)
163
+ return out
164
+
165
+ else:
166
+ return data
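distributed_log_manager accumulates weighted sums per metric, averages them across ranks when DDP is active, and formats a one-line training summary that can also be mirrored to TensorBoard. A hedged sketch of the intended call pattern inside a training loop; the loss values are placeholders, and it assumes the usual runtime state (a config already saved into cfg_unique_holder, lib.sync set up, and a visible GPU for the all-reduce buffer):

from lib.log_service import distributed_log_manager, print_log

logm = distributed_log_manager()
for itern in range(1, 101):
    loss_image, loss_text = 0.5, 0.3          # placeholder per-step losses
    batch_size = 4
    logm.accumulate(batch_size, Loss=loss_image + loss_text,
                    loss_image=loss_image, loss_text=loss_text)
    if itern % 50 == 0:
        print_log(logm.train_summary(itern, epochn=0, samplen=itern * batch_size, lr=1e-4))
        logm.clear()
logm.tensorboard_close()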
lib/model_zoo/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .common.get_model import get_model
2
+ from .common.get_optimizer import get_optimizer
3
+ from .common.get_scheduler import get_scheduler
4
+ from .common.utils import get_unit