Kaixuanliu lvkaokao commited on
Commit
510b9cd
·
0 Parent(s):

Duplicate from Intel/textual-inversion-training

Browse files

Co-authored-by: lvkaokao <lvkaokao@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Textual Inversion Training
3
+ emoji: 📉
4
+ colorFrom: purple
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 3.14.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ duplicated_from: Intel/textual-inversion-training
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,559 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ from pathlib import Path
4
+ import argparse
5
+ import shutil
6
+ # from train_dreambooth import run_training
7
+ from textual_inversion import run_training
8
+ from convertosd import convert
9
+ from PIL import Image
10
+ from slugify import slugify
11
+ import requests
12
+ import torch
13
+ import zipfile
14
+ import tarfile
15
+ import urllib.parse
16
+ import gc
17
+ from diffusers import StableDiffusionPipeline
18
+ from huggingface_hub import snapshot_download
19
+
20
+
21
+ is_spaces = True if "SPACE_ID" in os.environ else False
22
+ #is_shared_ui = True if "IS_SHARED_UI" in os.environ else False
23
+ if(is_spaces):
24
+ is_shared_ui = True if ("lvkaokao/textual-inversion-training" in os.environ['SPACE_ID'] or "Intel/textual-inversion-training" in os.environ['SPACE_ID']) else False
25
+ else:
26
+ is_shared_ui = False
27
+
28
+ css = '''
29
+ .instruction{position: absolute; top: 0;right: 0;margin-top: 0px !important}
30
+ .arrow{position: absolute;top: 0;right: -110px;margin-top: -8px !important}
31
+ #component-4, #component-3, #component-10{min-height: 0}
32
+ .duplicate-button img{margin: 0}
33
+ '''
34
+ maximum_concepts = 1
35
+
36
+ #Pre download the files
37
+ '''
38
+ model_v1_4 = snapshot_download(repo_id="CompVis/stable-diffusion-v1-4")
39
+ #model_v1_5 = snapshot_download(repo_id="runwayml/stable-diffusion-v1-5")
40
+ model_v1_5 = snapshot_download(repo_id="stabilityai/stable-diffusion-2")
41
+ model_v2_512 = snapshot_download(repo_id="stabilityai/stable-diffusion-2-base", revision="fp16")
42
+ safety_checker = snapshot_download(repo_id="multimodalart/sd-sc")
43
+ '''
44
+ model_v1_4 = "CompVis/stable-diffusion-v1-4"
45
+ model_v1_5 = "stabilityai/stable-diffusion-2"
46
+ model_v2_512 = "stabilityai/stable-diffusion-2-base"
47
+
48
+ model_to_load = model_v1_4
49
+
50
+
51
+ with zipfile.ZipFile("mix.zip", 'r') as zip_ref:
52
+ zip_ref.extractall(".")
53
+
54
+ def swap_text(option):
55
+ mandatory_liability = "You must have the right to do so and you are liable for the images you use, example:"
56
+ if(option == "object"):
57
+ instance_prompt_example = "cttoy"
58
+ freeze_for = 30
59
+ return [f"You are going to train `object`(s), upload 5-10 images of each object you are planning on training on from different angles/perspectives. {mandatory_liability}:", '''<img src="file/cat-toy.png" />''', f"You should name your concept with a unique made up word that has low chance of the model already knowing it (e.g.: `{instance_prompt_example}` here). Images will be automatically cropped to 512x512.", freeze_for, gr.update(visible=False)]
60
+ elif(option == "person"):
61
+ instance_prompt_example = "julcto"
62
+ freeze_for = 70
63
+ return [f"You are going to train a `person`(s), upload 10-20 images of each person you are planning on training on from different angles/perspectives. {mandatory_liability}:", '''<img src="file/person.png" />''', f"You should name your concept with a unique made up word that has low chance of the model already knowing it (e.g.: `{instance_prompt_example}` here). Images will be automatically cropped to 512x512.", freeze_for, gr.update(visible=True)]
64
+ elif(option == "style"):
65
+ instance_prompt_example = "trsldamrl"
66
+ freeze_for = 10
67
+ return [f"You are going to train a `style`, upload 10-20 images of the style you are planning on training on. Name the files with the words you would like {mandatory_liability}:", '''<img src="file/trsl_style.png" />''', f"You should name your concept with a unique made up word that has low chance of the model already knowing it (e.g.: `{instance_prompt_example}` here). Images will be automatically cropped to 512x512.", freeze_for, gr.update(visible=False)]
68
+
69
+ def swap_base_model(selected_model):
70
+ global model_to_load
71
+ if(selected_model == "v1-4"):
72
+ model_to_load = model_v1_4
73
+ elif(selected_model == "v1-5"):
74
+ model_to_load = model_v1_5
75
+ else:
76
+ model_to_load = model_v2_512
77
+
78
+ def count_files(*inputs):
79
+ file_counter = 0
80
+ concept_counter = 0
81
+ for i, input in enumerate(inputs):
82
+ if(i < maximum_concepts-1):
83
+ files = inputs[i]
84
+ if(files):
85
+ concept_counter+=1
86
+ file_counter+=len(files)
87
+ uses_custom = inputs[-1]
88
+ type_of_thing = inputs[-4]
89
+ if(uses_custom):
90
+ Training_Steps = int(inputs[-3])
91
+ else:
92
+ Training_Steps = file_counter*200
93
+ if(Training_Steps > 2400):
94
+ Training_Steps=2400
95
+ elif(Training_Steps < 1400):
96
+ Training_Steps=1400
97
+ if(is_spaces):
98
+ summary_sentence = f'''The training should take around 24 hours for 1000 steps using the default free CPU.<br><br>'''
99
+ else:
100
+ summary_sentence = f'''You are going to train {concept_counter} {type_of_thing}(s), with {file_counter} images for {Training_Steps} steps.<br><br>'''
101
+
102
+ return([gr.update(visible=True), gr.update(visible=True, value=summary_sentence)])
103
+
104
+ def update_steps(*files_list):
105
+ file_counter = 0
106
+ for i, files in enumerate(files_list):
107
+ if(files):
108
+ file_counter+=len(files)
109
+ return(gr.update(value=file_counter*200))
110
+
111
+ def pad_image(image):
112
+ w, h = image.size
113
+ if w == h:
114
+ return image
115
+ elif w > h:
116
+ new_image = Image.new(image.mode, (w, w), (0, 0, 0))
117
+ new_image.paste(image, (0, (w - h) // 2))
118
+ return new_image
119
+ else:
120
+ new_image = Image.new(image.mode, (h, h), (0, 0, 0))
121
+ new_image.paste(image, ((h - w) // 2, 0))
122
+ return new_image
123
+
124
+ def train(*inputs):
125
+ if is_shared_ui:
126
+ raise gr.Error("This Space only works in duplicated instances")
127
+
128
+ torch.cuda.empty_cache()
129
+ if 'pipe' in globals():
130
+ global pipe, pipe_is_set
131
+ del pipe
132
+ pipe_is_set = False
133
+ gc.collect()
134
+
135
+ if os.path.exists("output_model"): shutil.rmtree('output_model')
136
+ if os.path.exists("concept_images"): shutil.rmtree('concept_images')
137
+ if os.path.exists("diffusers_model.tar"): os.remove("diffusers_model.tar")
138
+ if os.path.exists("model.ckpt"): os.remove("model.ckpt")
139
+ if os.path.exists("hastrained.success"): os.remove("hastrained.success")
140
+ file_counter = 0
141
+ print(inputs)
142
+
143
+ os.makedirs('concept_images', exist_ok=True)
144
+ files = inputs[maximum_concepts*3]
145
+ init_word = inputs[maximum_concepts*2]
146
+ prompt = inputs[maximum_concepts]
147
+ if(prompt == "" or prompt == None):
148
+ raise gr.Error("You forgot to define your concept prompt")
149
+
150
+ for j, file_temp in enumerate(files):
151
+ file = Image.open(file_temp.name)
152
+ image = pad_image(file)
153
+ image = image.resize((512, 512))
154
+ extension = file_temp.name.split(".")[1]
155
+ image = image.convert('RGB')
156
+ image.save(f'concept_images/{j+1}.jpg', format="JPEG", quality = 100)
157
+ file_counter += 1
158
+
159
+
160
+ os.makedirs('output_model',exist_ok=True)
161
+ uses_custom = inputs[-1]
162
+ type_of_thing = inputs[-4]
163
+ remove_attribution_after = inputs[-6]
164
+ experimental_face_improvement = inputs[-9]
165
+ which_model = inputs[-10]
166
+ if(uses_custom):
167
+ Training_Steps = int(inputs[-3])
168
+ else:
169
+ Training_Steps = 1000
170
+
171
+ print(os.listdir("concept_images"))
172
+
173
+ args_general = argparse.Namespace(
174
+ pretrained_model_name_or_path = model_to_load,
175
+ train_data_dir="concept_images",
176
+ learnable_property=type_of_thing,
177
+ placeholder_token=prompt,
178
+ initializer_token=init_word,
179
+ resolution=512,
180
+ train_batch_size=1,
181
+ gradient_accumulation_steps=2,
182
+ use_bf16=True,
183
+ max_train_steps=Training_Steps,
184
+ learning_rate=5.0e-4,
185
+ scale_lr=True,
186
+ lr_scheduler="constant",
187
+ lr_warmup_steps=0,
188
+ output_dir="output_model",
189
+ )
190
+ print("Starting single training...")
191
+ lock_file = open("intraining.lock", "w")
192
+ lock_file.close()
193
+ run_training(args_general)
194
+
195
+ gc.collect()
196
+ torch.cuda.empty_cache()
197
+ if(which_model in ["v1-5"]):
198
+ print("Adding Safety Checker to the model...")
199
+ shutil.copytree(f"{safety_checker}/feature_extractor", "output_model/feature_extractor")
200
+ shutil.copytree(f"{safety_checker}/safety_checker", "output_model/safety_checker")
201
+ shutil.copy(f"model_index.json", "output_model/model_index.json")
202
+
203
+ if(not remove_attribution_after):
204
+ print("Archiving model file...")
205
+ with tarfile.open("diffusers_model.tar", "w") as tar:
206
+ tar.add("output_model", arcname=os.path.basename("output_model"))
207
+ if os.path.exists("intraining.lock"): os.remove("intraining.lock")
208
+ trained_file = open("hastrained.success", "w")
209
+ trained_file.close()
210
+ print(os.listdir("output_model"))
211
+ print("Training completed!")
212
+ return [
213
+ gr.update(visible=True, value=["diffusers_model.tar"]), #result
214
+ gr.update(visible=True), #try_your_model
215
+ gr.update(visible=True), #push_to_hub
216
+ gr.update(visible=True), #convert_button
217
+ gr.update(visible=False), #training_ongoing
218
+ gr.update(visible=True) #completed_training
219
+ ]
220
+ else:
221
+ hf_token = inputs[-5]
222
+ model_name = inputs[-7]
223
+ where_to_upload = inputs[-8]
224
+ push(model_name, where_to_upload, hf_token, which_model, True)
225
+ hardware_url = f"https://huggingface.co/spaces/{os.environ['SPACE_ID']}/hardware"
226
+ headers = { "authorization" : f"Bearer {hf_token}"}
227
+ body = {'flavor': 'cpu-basic'}
228
+ requests.post(hardware_url, json = body, headers=headers)
229
+
230
+ import time
231
+ pipe_is_set = False
232
+ def generate(prompt, steps):
233
+
234
+ print("prompt: ", prompt)
235
+ print("steps: ", steps)
236
+
237
+ torch.cuda.empty_cache()
238
+ from diffusers import StableDiffusionPipeline
239
+ global pipe_is_set
240
+ if(not pipe_is_set):
241
+ global pipe
242
+ if torch.cuda.is_available():
243
+ pipe = StableDiffusionPipeline.from_pretrained("./output_model", torch_dtype=torch.float16)
244
+ pipe = pipe.to("cuda")
245
+ else:
246
+ pipe = StableDiffusionPipeline.from_pretrained("./output_model", torch_dtype=torch.float)
247
+ pipe_is_set = True
248
+
249
+ start_time = time.time()
250
+ image = pipe(prompt, num_inference_steps=steps, guidance_scale=7.5).images[0]
251
+ print("cost: ", time.time() - start_time)
252
+ return(image)
253
+
254
+ def push(model_name, where_to_upload, hf_token, which_model, comes_from_automated=False):
255
+
256
+ if(not os.path.exists("model.ckpt")):
257
+ convert("output_model", "model.ckpt")
258
+ from huggingface_hub import HfApi, HfFolder, CommitOperationAdd
259
+ from huggingface_hub import create_repo
260
+ model_name_slug = slugify(model_name)
261
+ api = HfApi()
262
+ your_username = api.whoami(token=hf_token)["name"]
263
+ if(where_to_upload == "My personal profile"):
264
+ model_id = f"{your_username}/{model_name_slug}"
265
+ else:
266
+ model_id = f"sd-dreambooth-library/{model_name_slug}"
267
+ headers = {"Authorization" : f"Bearer: {hf_token}", "Content-Type": "application/json"}
268
+ response = requests.post("https://huggingface.co/organizations/sd-dreambooth-library/share/SSeOwppVCscfTEzFGQaqpfcjukVeNrKNHX", headers=headers)
269
+
270
+ images_upload = os.listdir("concept_images")
271
+ image_string = ""
272
+ instance_prompt_list = []
273
+ previous_instance_prompt = ''
274
+ for i, image in enumerate(images_upload):
275
+ instance_prompt = image.split("_")[0]
276
+ if(instance_prompt != previous_instance_prompt):
277
+ title_instance_prompt_string = instance_prompt
278
+ instance_prompt_list.append(instance_prompt)
279
+ else:
280
+ title_instance_prompt_string = ''
281
+ previous_instance_prompt = instance_prompt
282
+ image_string = f'''{title_instance_prompt_string} {"(use that on your prompt)" if title_instance_prompt_string != "" else ""}
283
+ {image_string}![{instance_prompt} {i}](https://huggingface.co/{model_id}/resolve/main/concept_images/{urllib.parse.quote(image)})'''
284
+ readme_text = f'''---
285
+ license: creativeml-openrail-m
286
+ tags:
287
+ - text-to-image
288
+ ---
289
+ ### {model_name} Dreambooth model trained by {api.whoami(token=hf_token)["name"]} with [Hugging Face Dreambooth Training Space](https://huggingface.co/spaces/multimodalart/dreambooth-training) with the {which_model} base model
290
+
291
+ You run your new concept via `diffusers` [Colab Notebook for Inference](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_inference.ipynb). Don't forget to use the concept prompts!
292
+
293
+ Sample pictures of:
294
+ {image_string}
295
+ '''
296
+ #Save the readme to a file
297
+ readme_file = open("model.README.md", "w")
298
+ readme_file.write(readme_text)
299
+ readme_file.close()
300
+ #Save the token identifier to a file
301
+ text_file = open("token_identifier.txt", "w")
302
+ text_file.write(', '.join(instance_prompt_list))
303
+ text_file.close()
304
+ try:
305
+ create_repo(model_id,private=True, token=hf_token)
306
+ except:
307
+ import time
308
+ epoch_time = str(int(time.time()))
309
+ create_repo(f"{model_id}-{epoch_time}", private=True,token=hf_token)
310
+ operations = [
311
+ CommitOperationAdd(path_in_repo="token_identifier.txt", path_or_fileobj="token_identifier.txt"),
312
+ CommitOperationAdd(path_in_repo="README.md", path_or_fileobj="model.README.md"),
313
+ CommitOperationAdd(path_in_repo=f"model.ckpt",path_or_fileobj="model.ckpt")
314
+ ]
315
+ api.create_commit(
316
+ repo_id=model_id,
317
+ operations=operations,
318
+ commit_message=f"Upload the model {model_name}",
319
+ token=hf_token
320
+ )
321
+ api.upload_folder(
322
+ folder_path="output_model",
323
+ repo_id=model_id,
324
+ token=hf_token
325
+ )
326
+ api.upload_folder(
327
+ folder_path="concept_images",
328
+ path_in_repo="concept_images",
329
+ repo_id=model_id,
330
+ token=hf_token
331
+ )
332
+ if is_spaces:
333
+ if(not comes_from_automated):
334
+ extra_message = "Don't forget to remove the GPU attribution after you play with it."
335
+ else:
336
+ extra_message = "The GPU has been removed automatically as requested, and you can try the model via the model page"
337
+ api.create_discussion(repo_id=os.environ['SPACE_ID'], title=f"Your model {model_name} has finished trained from the Dreambooth Train Spaces!", description=f"Your model has been successfully uploaded to: https://huggingface.co/{model_id}. {extra_message}",repo_type="space", token=hf_token)
338
+
339
+ return [gr.update(visible=True, value=f"Successfully uploaded your model. Access it [here](https://huggingface.co/{model_id})"), gr.update(visible=True, value=["diffusers_model.tar", "model.ckpt"])]
340
+
341
+ def convert_to_ckpt():
342
+ convert("output_model", "model.ckpt")
343
+ return gr.update(visible=True, value=["diffusers_model.tar", "model.ckpt"])
344
+
345
+ def check_status(top_description):
346
+ print('=='*20)
347
+ print(os.listdir("./"))
348
+
349
+ if os.path.exists("hastrained.success"):
350
+ if is_spaces:
351
+ update_top_tag = gr.update(value=f'''
352
+ <div class="gr-prose" style="max-width: 80%">
353
+ <h2>Your model has finished training ✅</h2>
354
+ <p>Yay, congratulations on training your model. Scroll down to play with with it, save it (either downloading it or on the Hugging Face Hub). Once you are done, your model is safe, and you don't want to train a new one, go to the <a href="https://huggingface.co/spaces/{os.environ['SPACE_ID']}">settings page</a> and downgrade your Space to a CPU Basic</p>
355
+ </div>
356
+ ''')
357
+ else:
358
+ update_top_tag = gr.update(value=f'''
359
+ <div class="gr-prose" style="max-width: 80%">
360
+ <h2>Your model has finished training ✅</h2>
361
+ <p>Yay, congratulations on training your model. Scroll down to play with with it, save it (either downloading it or on the Hugging Face Hub).</p>
362
+ </div>
363
+ ''')
364
+ show_outputs = True
365
+ elif os.path.exists("intraining.lock"):
366
+ update_top_tag = gr.update(value='''
367
+ <div class="gr-prose" style="max-width: 80%">
368
+ <h2>Don't worry, your model is still training! ⌛</h2>
369
+ <p>You closed the tab while your model was training, but it's all good! It is still training right now. You can click the "Open logs" button above here to check the training status. Once training is done, reload this tab to interact with your model</p>
370
+ </div>
371
+ ''')
372
+ show_outputs = False
373
+ else:
374
+ update_top_tag = gr.update(value=top_description)
375
+ show_outputs = False
376
+ if os.path.exists("diffusers_model.tar"):
377
+ update_files_tag = gr.update(visible=show_outputs, value=["diffusers_model.tar"])
378
+ else:
379
+ update_files_tag = gr.update(visible=show_outputs)
380
+ return [
381
+ update_top_tag, #top_description
382
+ gr.update(visible=show_outputs), #try_your_model
383
+ gr.update(visible=show_outputs), #push_to_hub
384
+ update_files_tag, #result
385
+ gr.update(visible=show_outputs), #convert_button
386
+ ]
387
+
388
+ def checkbox_swap(checkbox):
389
+ return [gr.update(visible=checkbox), gr.update(visible=checkbox), gr.update(visible=checkbox), gr.update(visible=checkbox)]
390
+
391
+ with gr.Blocks(css=css) as demo:
392
+ with gr.Box():
393
+ if is_shared_ui:
394
+ top_description = gr.HTML(f'''
395
+ <div class="gr-prose" style="max-width: 80%">
396
+ <h2>Attention - This Space doesn't work in this shared UI</h2>
397
+ <p>For it to work, you can either run locally or duplicate the Space and run it on your own profile using the free CPU or a (paid) private T4 GPU for training. CPU training takes a long time while each T4 costs US$0.60/h which should cost < US$1 to train most models using default settings!&nbsp;&nbsp;<a class="duplicate-button" style="display:inline-block" href="https://huggingface.co/spaces/{os.environ['SPACE_ID']}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></p>
398
+ <img class="instruction" src="file/duplicate.png">
399
+ <img class="arrow" src="file/arrow.png" />
400
+ </div>
401
+ ''')
402
+ elif(is_spaces):
403
+ top_description = gr.HTML(f'''
404
+ <div class="gr-prose" style="max-width: 80%">
405
+ <h2>You have successfully duplicated the Textual Inversion Training Space 🎉</h2>
406
+ <p>If you want to use CPU, it will take a long time to run the training below. If you want to use GPU, please get this ready: <a href="https://huggingface.co/spaces/{os.environ['SPACE_ID']}/settings">attribute a T4 GPU to it (via the Settings tab)</a> and run the training below. You will be billed by the minute from when you activate the GPU until when it is turned it off.</p>
407
+ </div>
408
+ ''')
409
+ else:
410
+ top_description = gr.HTML(f'''
411
+ <div class="gr-prose" style="max-width: 80%">
412
+ <h2>You have successfully cloned the Dreambooth Training Space locally 🎉</h2>
413
+ <p>Do a <code>pip install requirements-local.txt</code></p>
414
+ </div>
415
+ ''')
416
+ gr.Markdown("# Textual Inversion Training UI 💭")
417
+ gr.Markdown("Customize Stable Diffusion by training it on a new concept. This Space is based on [Intel® Neural Compressor](https://github.com/intel/neural-compressor/tree/master/examples/pytorch/diffusion_model/diffusers/textual_inversion) with [🧨 diffusers](https://github.com/huggingface/diffusers)")
418
+
419
+ with gr.Row() as what_are_you_training:
420
+ type_of_thing = gr.Dropdown(label="What would you like to train?", choices=["object", "person", "style"], value="object", interactive=True)
421
+ base_model_to_use = gr.Dropdown(label="Which base model would you like to use?", choices=["v1-4", "v1-5", "v2-512"], value="v1-4", interactive=True)
422
+
423
+ #Very hacky approach to emulate dynamically created Gradio components
424
+ with gr.Row() as upload_your_concept:
425
+ with gr.Column():
426
+ thing_description = gr.Markdown("You are going to train an `object`, please upload 1-5 images of the object to teach new concepts to Stable Diffusion, example")
427
+ thing_experimental = gr.Checkbox(label="Improve faces (prior preservation) - can take longer training but can improve faces", visible=False, value=False)
428
+ thing_image_example = gr.HTML('''<img src="file/dicoo-toy.png" class="aligncenter" height="128" width="128" />''')
429
+ things_naming = gr.Markdown("You should name your concept with a unique made up word that never appears in the model vocab (e.g.: `dicoo*` here). **The meaning of the initial word** is to initialize the concept word embedding which will make training easy (e.g.: `toy` here). Images will be automatically cropped to 512x512.")
430
+
431
+ with gr.Column():
432
+ file_collection = []
433
+ concept_collection = []
434
+ init_collection = []
435
+ buttons_collection = []
436
+ delete_collection = []
437
+ is_visible = []
438
+
439
+ row = [None] * maximum_concepts
440
+ for x in range(maximum_concepts):
441
+ ordinal = lambda n: "%d%s" % (n, "tsnrhtdd"[(n // 10 % 10 != 1) * (n % 10 < 4) * n % 10::4])
442
+ if(x == 0):
443
+ visible = True
444
+ is_visible.append(gr.State(value=True))
445
+ else:
446
+ visible = False
447
+ is_visible.append(gr.State(value=False))
448
+
449
+ file_collection.append(gr.File(label=f'''Upload the images for your {ordinal(x+1) if (x>0) else ""} concept''', file_count="multiple", interactive=True, visible=visible))
450
+ with gr.Column(visible=visible) as row[x]:
451
+ concept_collection.append(gr.Textbox(label=f'''{ordinal(x+1) if (x>0) else ""} concept word - use a unique, made up word to avoid collisions'''))
452
+ init_collection.append(gr.Textbox(label=f'''{ordinal(x+1) if (x>0) else ""} initial word - to init the concept embedding'''))
453
+ with gr.Row():
454
+ if(x < maximum_concepts-1):
455
+ buttons_collection.append(gr.Button(value="Add +1 concept", visible=visible))
456
+ if(x > 0):
457
+ delete_collection.append(gr.Button(value=f"Delete {ordinal(x+1)} concept"))
458
+
459
+ counter_add = 1
460
+ for button in buttons_collection:
461
+ if(counter_add < len(buttons_collection)):
462
+ button.click(lambda:
463
+ [gr.update(visible=True),gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), True, None],
464
+ None,
465
+ [row[counter_add], file_collection[counter_add], buttons_collection[counter_add-1], buttons_collection[counter_add], is_visible[counter_add], file_collection[counter_add]], queue=False)
466
+ else:
467
+ button.click(lambda:[gr.update(visible=True),gr.update(visible=True), gr.update(visible=False), True], None, [row[counter_add], file_collection[counter_add], buttons_collection[counter_add-1], is_visible[counter_add]], queue=False)
468
+ counter_add += 1
469
+
470
+ counter_delete = 1
471
+ for delete_button in delete_collection:
472
+ if(counter_delete < len(delete_collection)+1):
473
+ delete_button.click(lambda:[gr.update(visible=False),gr.update(visible=False), gr.update(visible=True), False], None, [file_collection[counter_delete], row[counter_delete], buttons_collection[counter_delete-1], is_visible[counter_delete]], queue=False)
474
+ counter_delete += 1
475
+
476
+ with gr.Accordion("Custom Settings", open=False):
477
+ swap_auto_calculated = gr.Checkbox(label="Use custom settings")
478
+ gr.Markdown("The default steps is 1000. If your results aren't really what you wanted, it may be underfitting and you need more steps.")
479
+ steps = gr.Number(label="How many steps", value=1000)
480
+ # need to remove
481
+ perc_txt_encoder = gr.Number(label="Percentage of the training steps the text-encoder should be trained as well", value=30, visible=False)
482
+ # perc_txt_encoder = 30
483
+
484
+ with gr.Box(visible=False) as training_summary:
485
+ training_summary_text = gr.HTML("", visible=False, label="Training Summary")
486
+ is_advanced_visible = True if is_spaces else False
487
+ training_summary_checkbox = gr.Checkbox(label="Automatically remove paid GPU attribution and upload model to the Hugging Face Hub after training", value=False, visible=is_advanced_visible)
488
+ training_summary_model_name = gr.Textbox(label="Name of your model", visible=False)
489
+ training_summary_where_to_upload = gr.Dropdown(["My personal profile", "Public Library"], label="Upload to", visible=False)
490
+ training_summary_token_message = gr.Markdown("[A Hugging Face write access token](https://huggingface.co/settings/tokens), go to \"New token\" -> Role : Write. A regular read token won't work here.", visible=False)
491
+ training_summary_token = gr.Textbox(label="Hugging Face Write Token", type="password", visible=False)
492
+
493
+ train_btn = gr.Button("Start Training")
494
+
495
+ training_ongoing = gr.Markdown("## Training is ongoing ⌛... You can close this tab if you like or just wait. If you did not check the `Remove GPU After training`, you can come back here to try your model and upload it after training. Don't forget to remove the GPU attribution after you are done. ", visible=False)
496
+
497
+ #Post-training UI
498
+ completed_training = gr.Markdown('''# ✅ Training completed.
499
+ ### Don't forget to remove the GPU attribution after you are done trying and uploading your model''', visible=False)
500
+
501
+ with gr.Row():
502
+ with gr.Box(visible=True) as try_your_model:
503
+ gr.Markdown("## Try your model")
504
+ prompt = gr.Textbox(label="Type your prompt")
505
+ result_image = gr.Image()
506
+ inference_steps = gr.Slider(minimum=1, maximum=150, value=50, step=1)
507
+ generate_button = gr.Button("Generate Image")
508
+
509
+ with gr.Box(visible=False) as push_to_hub:
510
+ gr.Markdown("## Push to Hugging Face Hub")
511
+ model_name = gr.Textbox(label="Name of your model", placeholder="Tarsila do Amaral Style")
512
+ where_to_upload = gr.Dropdown(["My personal profile", "Public Library"], label="Upload to")
513
+ gr.Markdown("[A Hugging Face write access token](https://huggingface.co/settings/tokens), go to \"New token\" -> Role : Write. A regular read token won't work here.")
514
+ hf_token = gr.Textbox(label="Hugging Face Write Token", type="password")
515
+
516
+ push_button = gr.Button("Push to the Hub")
517
+
518
+ result = gr.File(label="Download the uploaded models in the diffusers format", visible=True)
519
+ success_message_upload = gr.Markdown(visible=False)
520
+ convert_button = gr.Button("Convert to CKPT", visible=False)
521
+
522
+ #Swap the examples and the % of text encoder trained depending if it is an object, person or style
523
+ type_of_thing.change(fn=swap_text, inputs=[type_of_thing], outputs=[thing_description, thing_image_example, things_naming, perc_txt_encoder, thing_experimental], queue=False, show_progress=False)
524
+
525
+ #Swap the base model
526
+ base_model_to_use.change(fn=swap_base_model, inputs=base_model_to_use, outputs=[])
527
+
528
+ #Update the summary box below the UI according to how many images are uploaded and whether users are using custom settings or not
529
+ for file in file_collection:
530
+ #file.change(fn=update_steps,inputs=file_collection, outputs=steps)
531
+ file.change(fn=count_files, inputs=file_collection+[type_of_thing]+[steps]+[perc_txt_encoder]+[swap_auto_calculated], outputs=[training_summary, training_summary_text], queue=False)
532
+
533
+ steps.change(fn=count_files, inputs=file_collection+[type_of_thing]+[steps]+[perc_txt_encoder]+[swap_auto_calculated], outputs=[training_summary, training_summary_text], queue=False)
534
+ perc_txt_encoder.change(fn=count_files, inputs=file_collection+[type_of_thing]+[steps]+[perc_txt_encoder]+[swap_auto_calculated], outputs=[training_summary, training_summary_text], queue=False)
535
+
536
+ #Give more options if the user wants to finish everything after training
537
+ if(is_spaces):
538
+ training_summary_checkbox.change(fn=checkbox_swap, inputs=training_summary_checkbox, outputs=[training_summary_token_message, training_summary_token, training_summary_model_name, training_summary_where_to_upload],queue=False, show_progress=False)
539
+ #Add a message for while it is in training
540
+ train_btn.click(lambda:gr.update(visible=True), inputs=None, outputs=training_ongoing)
541
+
542
+ #The main train function
543
+ train_btn.click(fn=train, inputs=is_visible+concept_collection+init_collection+file_collection+[base_model_to_use]+[thing_experimental]+[training_summary_where_to_upload]+[training_summary_model_name]+[training_summary_checkbox]+[training_summary_token]+[type_of_thing]+[steps]+[perc_txt_encoder]+[swap_auto_calculated], outputs=[result, try_your_model, push_to_hub, convert_button, training_ongoing, completed_training], queue=False)
544
+
545
+ #Button to generate an image from your trained model after training
546
+ print('=='*20)
547
+ print(prompt)
548
+ print(inference_steps)
549
+ generate_button.click(fn=generate, inputs=[prompt, inference_steps], outputs=result_image, queue=False)
550
+
551
+ #Button to push the model to the Hugging Face Hub
552
+ push_button.click(fn=push, inputs=[model_name, where_to_upload, hf_token, base_model_to_use], outputs=[success_message_upload, result], queue=False)
553
+ #Button to convert the model to ckpt format
554
+ convert_button.click(fn=convert_to_ckpt, inputs=[], outputs=result, queue=False)
555
+
556
+ #Checks if the training is running
557
+ demo.load(fn=check_status, inputs=top_description, outputs=[top_description, try_your_model, push_to_hub, result, convert_button], queue=False, show_progress=False)
558
+
559
+ demo.queue(default_enabled=False).launch(debug=True)
arrow.png ADDED
cat-toy-deprec.png ADDED
cat-toy.png ADDED
cattoy.png ADDED
convertosd.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
2
+ # *Only* converts the UNet, VAE, and Text Encoder.
3
+ # Does not convert optimizer state or any other thing.
4
+ # Written by jachiam
5
+
6
+ import argparse
7
+ import os.path as osp
8
+
9
+ import torch
10
+ import gc
11
+
12
+ # =================#
13
+ # UNet Conversion #
14
+ # =================#
15
+
16
+ unet_conversion_map = [
17
+ # (stable-diffusion, HF Diffusers)
18
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
19
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
20
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
21
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
22
+ ("input_blocks.0.0.weight", "conv_in.weight"),
23
+ ("input_blocks.0.0.bias", "conv_in.bias"),
24
+ ("out.0.weight", "conv_norm_out.weight"),
25
+ ("out.0.bias", "conv_norm_out.bias"),
26
+ ("out.2.weight", "conv_out.weight"),
27
+ ("out.2.bias", "conv_out.bias"),
28
+ ]
29
+
30
+ unet_conversion_map_resnet = [
31
+ # (stable-diffusion, HF Diffusers)
32
+ ("in_layers.0", "norm1"),
33
+ ("in_layers.2", "conv1"),
34
+ ("out_layers.0", "norm2"),
35
+ ("out_layers.3", "conv2"),
36
+ ("emb_layers.1", "time_emb_proj"),
37
+ ("skip_connection", "conv_shortcut"),
38
+ ]
39
+
40
+ unet_conversion_map_layer = []
41
+ # hardcoded number of downblocks and resnets/attentions...
42
+ # would need smarter logic for other networks.
43
+ for i in range(4):
44
+ # loop over downblocks/upblocks
45
+
46
+ for j in range(2):
47
+ # loop over resnets/attentions for downblocks
48
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
49
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
50
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
51
+
52
+ if i < 3:
53
+ # no attention layers in down_blocks.3
54
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
55
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
56
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
57
+
58
+ for j in range(3):
59
+ # loop over resnets/attentions for upblocks
60
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
61
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
62
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
63
+
64
+ if i > 0:
65
+ # no attention layers in up_blocks.0
66
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
67
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
68
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
69
+
70
+ if i < 3:
71
+ # no downsample in down_blocks.3
72
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
73
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
74
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
75
+
76
+ # no upsample in up_blocks.3
77
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
78
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
79
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
80
+
81
+ hf_mid_atn_prefix = "mid_block.attentions.0."
82
+ sd_mid_atn_prefix = "middle_block.1."
83
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
84
+
85
+ for j in range(2):
86
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
87
+ sd_mid_res_prefix = f"middle_block.{2*j}."
88
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
89
+
90
+
91
+ def convert_unet_state_dict(unet_state_dict):
92
+ # buyer beware: this is a *brittle* function,
93
+ # and correct output requires that all of these pieces interact in
94
+ # the exact order in which I have arranged them.
95
+ mapping = {k: k for k in unet_state_dict.keys()}
96
+ for sd_name, hf_name in unet_conversion_map:
97
+ mapping[hf_name] = sd_name
98
+ for k, v in mapping.items():
99
+ if "resnets" in k:
100
+ for sd_part, hf_part in unet_conversion_map_resnet:
101
+ v = v.replace(hf_part, sd_part)
102
+ mapping[k] = v
103
+ for k, v in mapping.items():
104
+ for sd_part, hf_part in unet_conversion_map_layer:
105
+ v = v.replace(hf_part, sd_part)
106
+ mapping[k] = v
107
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
108
+ return new_state_dict
109
+
110
+
111
+ # ================#
112
+ # VAE Conversion #
113
+ # ================#
114
+
115
+ vae_conversion_map = [
116
+ # (stable-diffusion, HF Diffusers)
117
+ ("nin_shortcut", "conv_shortcut"),
118
+ ("norm_out", "conv_norm_out"),
119
+ ("mid.attn_1.", "mid_block.attentions.0."),
120
+ ]
121
+
122
+ for i in range(4):
123
+ # down_blocks have two resnets
124
+ for j in range(2):
125
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
126
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
127
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
128
+
129
+ if i < 3:
130
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
131
+ sd_downsample_prefix = f"down.{i}.downsample."
132
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
133
+
134
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
135
+ sd_upsample_prefix = f"up.{3-i}.upsample."
136
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
137
+
138
+ # up_blocks have three resnets
139
+ # also, up blocks in hf are numbered in reverse from sd
140
+ for j in range(3):
141
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
142
+ sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
143
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
144
+
145
+ # this part accounts for mid blocks in both the encoder and the decoder
146
+ for i in range(2):
147
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
148
+ sd_mid_res_prefix = f"mid.block_{i+1}."
149
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
150
+
151
+
152
+ vae_conversion_map_attn = [
153
+ # (stable-diffusion, HF Diffusers)
154
+ ("norm.", "group_norm."),
155
+ ("q.", "query."),
156
+ ("k.", "key."),
157
+ ("v.", "value."),
158
+ ("proj_out.", "proj_attn."),
159
+ ]
160
+
161
+
162
+ def reshape_weight_for_sd(w):
163
+ # convert HF linear weights to SD conv2d weights
164
+ return w.reshape(*w.shape, 1, 1)
165
+
166
+
167
+ def convert_vae_state_dict(vae_state_dict):
168
+ mapping = {k: k for k in vae_state_dict.keys()}
169
+ for k, v in mapping.items():
170
+ for sd_part, hf_part in vae_conversion_map:
171
+ v = v.replace(hf_part, sd_part)
172
+ mapping[k] = v
173
+ for k, v in mapping.items():
174
+ if "attentions" in k:
175
+ for sd_part, hf_part in vae_conversion_map_attn:
176
+ v = v.replace(hf_part, sd_part)
177
+ mapping[k] = v
178
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
179
+ weights_to_convert = ["q", "k", "v", "proj_out"]
180
+ print("Converting to CKPT ...")
181
+ for k, v in new_state_dict.items():
182
+ for weight_name in weights_to_convert:
183
+ if f"mid.attn_1.{weight_name}.weight" in k:
184
+ new_state_dict[k] = reshape_weight_for_sd(v)
185
+ return new_state_dict
186
+
187
+
188
+ # =========================#
189
+ # Text Encoder Conversion #
190
+ # =========================#
191
+ # pretty much a no-op
192
+
193
+
194
+ def convert_text_enc_state_dict(text_enc_dict):
195
+ return text_enc_dict
196
+
197
+
198
+ def convert(model_path, checkpoint_path):
199
+ unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
200
+ vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
201
+ text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
202
+
203
+ # Convert the UNet model
204
+ unet_state_dict = torch.load(unet_path, map_location='cpu')
205
+ unet_state_dict = convert_unet_state_dict(unet_state_dict)
206
+ unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
207
+
208
+ # Convert the VAE model
209
+ vae_state_dict = torch.load(vae_path, map_location='cpu')
210
+ vae_state_dict = convert_vae_state_dict(vae_state_dict)
211
+ vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
212
+
213
+ # Convert the text encoder model
214
+ text_enc_dict = torch.load(text_enc_path, map_location='cpu')
215
+ text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
216
+ text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
217
+
218
+ # Put together new checkpoint
219
+ state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
220
+
221
+ state_dict = {k:v.half() for k,v in state_dict.items()}
222
+ state_dict = {"state_dict": state_dict}
223
+ torch.save(state_dict, checkpoint_path)
224
+ del state_dict, text_enc_dict, vae_state_dict, unet_state_dict
225
+ torch.cuda.empty_cache()
226
+ gc.collect()
dicoo-toy.png ADDED
duplicate.png ADDED
mix.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:09207c4e95fcf5296eb0ff708fdc672da960aeb2864d298810db5094b072a0d4
3
+ size 28022653
model_index.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "StableDiffusionPipeline",
3
+ "_diffusers_version": "0.6.0",
4
+ "feature_extractor": [
5
+ "transformers",
6
+ "CLIPFeatureExtractor"
7
+ ],
8
+ "safety_checker": [
9
+ "stable_diffusion",
10
+ "StableDiffusionSafetyChecker"
11
+ ],
12
+ "scheduler": [
13
+ "diffusers",
14
+ "PNDMScheduler"
15
+ ],
16
+ "text_encoder": [
17
+ "transformers",
18
+ "CLIPTextModel"
19
+ ],
20
+ "tokenizer": [
21
+ "transformers",
22
+ "CLIPTokenizer"
23
+ ],
24
+ "unet": [
25
+ "diffusers",
26
+ "UNet2DConditionModel"
27
+ ],
28
+ "vae": [
29
+ "diffusers",
30
+ "AutoencoderKL"
31
+ ]
32
+ }
person.png ADDED
requirements-local.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu113
2
+ torch==1.12.1+cu113
3
+ torchvision==0.13.1+cu113
4
+ diffusers==0.9.0
5
+ accelerate==0.12.0
6
+ OmegaConf
7
+ wget
8
+ pytorch_lightning
9
+ huggingface_hub
10
+ ftfy
11
+ transformers
12
+ pyfiglet
13
+ triton==2.0.0.dev20220701
14
+ bitsandbytes
15
+ python-slugify
16
+ requests
17
+ tensorboard
18
+ pip install git+https://github.com/facebookresearch/xformers@7e4c02c#egg=xformers
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu113
2
+ torch==1.12.1+cu113
3
+ torchvision==0.13.1+cu113
4
+ diffusers==0.9.0
5
+ accelerate==0.12.0
6
+ OmegaConf
7
+ wget
8
+ pytorch_lightning
9
+ huggingface_hub
10
+ ftfy
11
+ transformers
12
+ pyfiglet
13
+ triton==2.0.0.dev20220701
14
+ bitsandbytes
15
+ python-slugify
16
+ requests
17
+ tensorboard
18
+ https://github.com/apolinario/xformers/releases/download/0.0.2/xformers-0.0.14.dev0-cp38-cp38-linux_x86_64.whl
textual_inversion.py ADDED
@@ -0,0 +1,612 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import itertools
3
+ import math
4
+ import os
5
+ import random
6
+ from pathlib import Path
7
+ from typing import Optional
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn.functional as F
12
+ import torch.utils.checkpoint
13
+ from torch.utils.data import Dataset
14
+
15
+ import PIL
16
+ from accelerate import Accelerator
17
+ from accelerate.logging import get_logger
18
+ from accelerate.utils import set_seed
19
+ from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
20
+ from diffusers.optimization import get_scheduler
21
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
22
+ from huggingface_hub import HfFolder, Repository, whoami
23
+ from PIL import Image
24
+ from torchvision import transforms
25
+ from tqdm.auto import tqdm
26
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
27
+ import gc
28
+
29
+ logger = get_logger(__name__)
30
+
31
+
32
+ def save_progress(text_encoder, placeholder_token_id, accelerator, args):
33
+ logger.info("Saving embeddings")
34
+ learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
35
+ learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
36
+ torch.save(learned_embeds_dict, os.path.join(args.output_dir, "learned_embeds.bin"))
37
+
38
+
39
+ def parse_args():
40
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
41
+ parser.add_argument(
42
+ "--save_steps",
43
+ type=int,
44
+ default=500,
45
+ help="Save learned_embeds.bin every X updates steps.",
46
+ )
47
+ parser.add_argument(
48
+ "--pretrained_model_name_or_path",
49
+ type=str,
50
+ default=None,
51
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
52
+ )
53
+ parser.add_argument(
54
+ "--tokenizer_name",
55
+ type=str,
56
+ default=None,
57
+ help="Pretrained tokenizer name or path if not the same as model_name",
58
+ )
59
+ parser.add_argument(
60
+ "--train_data_dir", type=str, default=None, help="A folder containing the training data."
61
+ )
62
+ parser.add_argument(
63
+ "--placeholder_token",
64
+ type=str,
65
+ default=None,
66
+ help="A token to use as a placeholder for the concept.",
67
+ )
68
+ parser.add_argument(
69
+ "--initializer_token", type=str, default=None, help="A token to use as initializer word."
70
+ )
71
+ parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
72
+ parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
73
+ parser.add_argument(
74
+ "--output_dir",
75
+ type=str,
76
+ default="text-inversion-model",
77
+ help="The output directory where the model predictions and checkpoints will be written.",
78
+ )
79
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
80
+ parser.add_argument(
81
+ "--resolution",
82
+ type=int,
83
+ default=512,
84
+ help=(
85
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
86
+ " resolution"
87
+ ),
88
+ )
89
+ parser.add_argument(
90
+ "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
91
+ )
92
+ parser.add_argument(
93
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
94
+ )
95
+ parser.add_argument("--num_train_epochs", type=int, default=100)
96
+ parser.add_argument(
97
+ "--max_train_steps",
98
+ type=int,
99
+ default=5000,
100
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
101
+ )
102
+ parser.add_argument(
103
+ "--gradient_accumulation_steps",
104
+ type=int,
105
+ default=1,
106
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
107
+ )
108
+ parser.add_argument(
109
+ "--learning_rate",
110
+ type=float,
111
+ default=1e-4,
112
+ help="Initial learning rate (after the potential warmup period) to use.",
113
+ )
114
+ parser.add_argument(
115
+ "--scale_lr",
116
+ action="store_true",
117
+ default=True,
118
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
119
+ )
120
+ parser.add_argument(
121
+ "--lr_scheduler",
122
+ type=str,
123
+ default="constant",
124
+ help=(
125
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
126
+ ' "constant", "constant_with_warmup"]'
127
+ ),
128
+ )
129
+ parser.add_argument(
130
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
131
+ )
132
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
133
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
134
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
135
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
136
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
137
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
138
+ parser.add_argument(
139
+ "--hub_model_id",
140
+ type=str,
141
+ default=None,
142
+ help="The name of the repository to keep in sync with the local `output_dir`.",
143
+ )
144
+ parser.add_argument(
145
+ "--logging_dir",
146
+ type=str,
147
+ default="logs",
148
+ help=(
149
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
150
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
151
+ ),
152
+ )
153
+ parser.add_argument(
154
+ "--mixed_precision",
155
+ type=str,
156
+ default="no",
157
+ choices=["no", "fp16", "bf16"],
158
+ help=(
159
+ "Whether to use mixed precision. Choose"
160
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
161
+ "and an Nvidia Ampere GPU."
162
+ ),
163
+ )
164
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
165
+
166
+ args = parser.parse_args()
167
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
168
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
169
+ args.local_rank = env_local_rank
170
+
171
+ '''
172
+ if args.train_data_dir is None:
173
+ raise ValueError("You must specify a train data directory.")
174
+ '''
175
+
176
+ return args
177
+
178
+
179
+ imagenet_templates_small = [
180
+ "a photo of a {}",
181
+ "a rendering of a {}",
182
+ "a cropped photo of the {}",
183
+ "the photo of a {}",
184
+ "a photo of a clean {}",
185
+ "a photo of a dirty {}",
186
+ "a dark photo of the {}",
187
+ "a photo of my {}",
188
+ "a photo of the cool {}",
189
+ "a close-up photo of a {}",
190
+ "a bright photo of the {}",
191
+ "a cropped photo of a {}",
192
+ "a photo of the {}",
193
+ "a good photo of the {}",
194
+ "a photo of one {}",
195
+ "a close-up photo of the {}",
196
+ "a rendition of the {}",
197
+ "a photo of the clean {}",
198
+ "a rendition of a {}",
199
+ "a photo of a nice {}",
200
+ "a good photo of a {}",
201
+ "a photo of the nice {}",
202
+ "a photo of the small {}",
203
+ "a photo of the weird {}",
204
+ "a photo of the large {}",
205
+ "a photo of a cool {}",
206
+ "a photo of a small {}",
207
+ ]
208
+
209
+ imagenet_style_templates_small = [
210
+ "a painting in the style of {}",
211
+ "a rendering in the style of {}",
212
+ "a cropped painting in the style of {}",
213
+ "the painting in the style of {}",
214
+ "a clean painting in the style of {}",
215
+ "a dirty painting in the style of {}",
216
+ "a dark painting in the style of {}",
217
+ "a picture in the style of {}",
218
+ "a cool painting in the style of {}",
219
+ "a close-up painting in the style of {}",
220
+ "a bright painting in the style of {}",
221
+ "a cropped painting in the style of {}",
222
+ "a good painting in the style of {}",
223
+ "a close-up painting in the style of {}",
224
+ "a rendition in the style of {}",
225
+ "a nice painting in the style of {}",
226
+ "a small painting in the style of {}",
227
+ "a weird painting in the style of {}",
228
+ "a large painting in the style of {}",
229
+ ]
230
+
231
+
232
+ class TextualInversionDataset(Dataset):
233
+ def __init__(
234
+ self,
235
+ data_root,
236
+ tokenizer,
237
+ learnable_property="object", # [object, style]
238
+ size=512,
239
+ repeats=100,
240
+ interpolation="bicubic",
241
+ flip_p=0.5,
242
+ set="train",
243
+ placeholder_token="*",
244
+ center_crop=False,
245
+ ):
246
+ self.data_root = data_root
247
+ self.tokenizer = tokenizer
248
+ self.learnable_property = learnable_property
249
+ self.size = size
250
+ self.placeholder_token = placeholder_token
251
+ self.center_crop = center_crop
252
+ self.flip_p = flip_p
253
+
254
+ self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
255
+
256
+ self.num_images = len(self.image_paths)
257
+ self._length = self.num_images
258
+
259
+ if set == "train":
260
+ self._length = self.num_images * repeats
261
+
262
+ self.interpolation = {
263
+ "linear": PIL.Image.LINEAR,
264
+ "bilinear": PIL.Image.BILINEAR,
265
+ "bicubic": PIL.Image.BICUBIC,
266
+ "lanczos": PIL.Image.LANCZOS,
267
+ }[interpolation]
268
+
269
+ self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
270
+ self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
271
+
272
+ def __len__(self):
273
+ return self._length
274
+
275
+ def __getitem__(self, i):
276
+ example = {}
277
+ image = Image.open(self.image_paths[i % self.num_images])
278
+
279
+ if not image.mode == "RGB":
280
+ image = image.convert("RGB")
281
+
282
+ placeholder_string = self.placeholder_token
283
+ text = random.choice(self.templates).format(placeholder_string)
284
+
285
+ example["input_ids"] = self.tokenizer(
286
+ text,
287
+ padding="max_length",
288
+ truncation=True,
289
+ max_length=self.tokenizer.model_max_length,
290
+ return_tensors="pt",
291
+ ).input_ids[0]
292
+
293
+ # default to score-sde preprocessing
294
+ img = np.array(image).astype(np.uint8)
295
+
296
+ if self.center_crop:
297
+ crop = min(img.shape[0], img.shape[1])
298
+ h, w, = (
299
+ img.shape[0],
300
+ img.shape[1],
301
+ )
302
+ img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
303
+
304
+ image = Image.fromarray(img)
305
+ image = image.resize((self.size, self.size), resample=self.interpolation)
306
+
307
+ image = self.flip_transform(image)
308
+ image = np.array(image).astype(np.uint8)
309
+ image = (image / 127.5 - 1.0).astype(np.float32)
310
+
311
+ example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
312
+ return example
313
+
314
+
315
+ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
316
+ if token is None:
317
+ token = HfFolder.get_token()
318
+ if organization is None:
319
+ username = whoami(token)["name"]
320
+ return f"{username}/{model_id}"
321
+ else:
322
+ return f"{organization}/{model_id}"
323
+
324
+
325
+ def freeze_params(params):
326
+ for param in params:
327
+ param.requires_grad = False
328
+
329
+
330
+ def merge_two_dicts(starting_dict: dict, updater_dict: dict) -> dict:
331
+ """
332
+ Starts from base starting dict and then adds the remaining key values from updater replacing the values from
333
+ the first starting/base dict with the second updater dict.
334
+
335
+ For later: how does d = {**d1, **d2} replace collision?
336
+
337
+ :param starting_dict:
338
+ :param updater_dict:
339
+ :return:
340
+ """
341
+ new_dict: dict = starting_dict.copy() # start with keys and values of starting_dict
342
+ new_dict.update(updater_dict) # modifies starting_dict with keys and values of updater_dict
343
+ return new_dict
344
+
345
+ def merge_args(args1: argparse.Namespace, args2: argparse.Namespace) -> argparse.Namespace:
346
+ """
347
+
348
+ ref: https://stackoverflow.com/questions/56136549/how-can-i-merge-two-argparse-namespaces-in-python-2-x
349
+ :param args1:
350
+ :param args2:
351
+ :return:
352
+ """
353
+ # - the merged args
354
+ # The vars() function returns the __dict__ attribute to values of the given object e.g {field:value}.
355
+ merged_key_values_for_namespace: dict = merge_two_dicts(vars(args1), vars(args2))
356
+ args = argparse.Namespace(**merged_key_values_for_namespace)
357
+ return args
358
+
359
+ def run_training(args_imported):
360
+ args_default = parse_args()
361
+ args = merge_args(args_default, args_imported)
362
+
363
+ print(args)
364
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
365
+
366
+ accelerator = Accelerator(
367
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
368
+ mixed_precision=args.mixed_precision,
369
+ log_with="tensorboard",
370
+ logging_dir=logging_dir,
371
+ )
372
+
373
+ # If passed along, set the training seed now.
374
+ if args.seed is not None:
375
+ set_seed(args.seed)
376
+
377
+ # Handle the repository creation
378
+ if accelerator.is_main_process:
379
+ if args.push_to_hub:
380
+ if args.hub_model_id is None:
381
+ repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
382
+ else:
383
+ repo_name = args.hub_model_id
384
+ repo = Repository(args.output_dir, clone_from=repo_name)
385
+
386
+ with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
387
+ if "step_*" not in gitignore:
388
+ gitignore.write("step_*\n")
389
+ if "epoch_*" not in gitignore:
390
+ gitignore.write("epoch_*\n")
391
+ elif args.output_dir is not None:
392
+ os.makedirs(args.output_dir, exist_ok=True)
393
+
394
+ # Load the tokenizer and add the placeholder token as a additional special token
395
+ if args.tokenizer_name:
396
+ tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
397
+ elif args.pretrained_model_name_or_path:
398
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
399
+
400
+ # Add the placeholder token in tokenizer
401
+ num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
402
+ if num_added_tokens == 0:
403
+ raise ValueError(
404
+ f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
405
+ " `placeholder_token` that is not already in the tokenizer."
406
+ )
407
+
408
+ # Convert the initializer_token, placeholder_token to ids
409
+ token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
410
+ # Check if initializer_token is a single token or a sequence of tokens
411
+ if len(token_ids) > 1:
412
+ raise ValueError("The initializer token must be a single token.")
413
+
414
+ initializer_token_id = token_ids[0]
415
+ placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
416
+
417
+ # Load models and create wrapper for stable diffusion
418
+ text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
419
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
420
+ unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
421
+
422
+ # Resize the token embeddings as we are adding new special tokens to the tokenizer
423
+ text_encoder.resize_token_embeddings(len(tokenizer))
424
+
425
+ # Initialise the newly added placeholder token with the embeddings of the initializer token
426
+ token_embeds = text_encoder.get_input_embeddings().weight.data
427
+ token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
428
+
429
+ # Freeze vae and unet
430
+ freeze_params(vae.parameters())
431
+ freeze_params(unet.parameters())
432
+ # Freeze all parameters except for the token embeddings in text encoder
433
+ params_to_freeze = itertools.chain(
434
+ text_encoder.text_model.encoder.parameters(),
435
+ text_encoder.text_model.final_layer_norm.parameters(),
436
+ text_encoder.text_model.embeddings.position_embedding.parameters(),
437
+ )
438
+ freeze_params(params_to_freeze)
439
+
440
+ if args.scale_lr:
441
+ args.learning_rate = (
442
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
443
+ )
444
+
445
+ # Initialize the optimizer
446
+ optimizer = torch.optim.AdamW(
447
+ text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
448
+ lr=args.learning_rate,
449
+ betas=(args.adam_beta1, args.adam_beta2),
450
+ weight_decay=args.adam_weight_decay,
451
+ eps=args.adam_epsilon,
452
+ )
453
+
454
+ # TODO (patil-suraj): load scheduler using args
455
+ noise_scheduler = DDPMScheduler(
456
+ beta_start=0.00085,
457
+ beta_end=0.012,
458
+ beta_schedule="scaled_linear",
459
+ num_train_timesteps=1000,
460
+ )
461
+
462
+ train_dataset = TextualInversionDataset(
463
+ data_root=args.train_data_dir,
464
+ tokenizer=tokenizer,
465
+ size=args.resolution,
466
+ placeholder_token=args.placeholder_token,
467
+ repeats=args.repeats,
468
+ learnable_property=args.learnable_property,
469
+ center_crop=args.center_crop,
470
+ set="train",
471
+ )
472
+ train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)
473
+
474
+ # Scheduler and math around the number of training steps.
475
+ overrode_max_train_steps = False
476
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
477
+ if args.max_train_steps is None:
478
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
479
+ overrode_max_train_steps = True
480
+
481
+ lr_scheduler = get_scheduler(
482
+ args.lr_scheduler,
483
+ optimizer=optimizer,
484
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
485
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
486
+ )
487
+
488
+ text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
489
+ text_encoder, optimizer, train_dataloader, lr_scheduler
490
+ )
491
+
492
+ # Move vae and unet to device
493
+ vae.to(accelerator.device)
494
+ unet.to(accelerator.device)
495
+
496
+ # Keep vae and unet in eval model as we don't train these
497
+ vae.eval()
498
+ unet.eval()
499
+
500
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
501
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
502
+ if overrode_max_train_steps:
503
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
504
+ # Afterwards we recalculate our number of training epochs
505
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
506
+
507
+ # We need to initialize the trackers we use, and also store our configuration.
508
+ # The trackers initializes automatically on the main process.
509
+ if accelerator.is_main_process:
510
+ accelerator.init_trackers("textual_inversion", config=vars(args))
511
+
512
+ # Train!
513
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
514
+
515
+ logger.info("***** Running training *****")
516
+ logger.info(f" Num examples = {len(train_dataset)}")
517
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
518
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
519
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
520
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
521
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
522
+ # Only show the progress bar once on each machine.
523
+ progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
524
+ progress_bar.set_description("Steps")
525
+ global_step = 0
526
+
527
+ for epoch in range(args.num_train_epochs):
528
+ text_encoder.train()
529
+ for step, batch in enumerate(train_dataloader):
530
+ with accelerator.accumulate(text_encoder):
531
+ # Convert images to latent space
532
+ latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
533
+ latents = latents * 0.18215
534
+
535
+ # Sample noise that we'll add to the latents
536
+ noise = torch.randn(latents.shape).to(latents.device)
537
+ bsz = latents.shape[0]
538
+ # Sample a random timestep for each image
539
+ timesteps = torch.randint(
540
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
541
+ ).long()
542
+
543
+ # Add noise to the latents according to the noise magnitude at each timestep
544
+ # (this is the forward diffusion process)
545
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
546
+
547
+ # Get the text embedding for conditioning
548
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
549
+
550
+ # Predict the noise residual
551
+ noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
552
+
553
+ loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
554
+ accelerator.backward(loss)
555
+
556
+ # Zero out the gradients for all token embeddings except the newly added
557
+ # embeddings for the concept, as we only want to optimize the concept embeddings
558
+ if accelerator.num_processes > 1:
559
+ grads = text_encoder.module.get_input_embeddings().weight.grad
560
+ else:
561
+ grads = text_encoder.get_input_embeddings().weight.grad
562
+ # Get the index for tokens that we want to zero the grads for
563
+ index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
564
+ grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
565
+
566
+ optimizer.step()
567
+ lr_scheduler.step()
568
+ optimizer.zero_grad()
569
+
570
+ # Checks if the accelerator has performed an optimization step behind the scenes
571
+ if accelerator.sync_gradients:
572
+ progress_bar.update(1)
573
+ global_step += 1
574
+ if global_step % args.save_steps == 0:
575
+ save_progress(text_encoder, placeholder_token_id, accelerator, args)
576
+
577
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
578
+ progress_bar.set_postfix(**logs)
579
+ accelerator.log(logs, step=global_step)
580
+
581
+ if global_step >= args.max_train_steps:
582
+ break
583
+
584
+ accelerator.wait_for_everyone()
585
+
586
+ # Create the pipeline using using the trained modules and save it.
587
+ if accelerator.is_main_process:
588
+ pipeline = StableDiffusionPipeline(
589
+ text_encoder=accelerator.unwrap_model(text_encoder),
590
+ vae=vae,
591
+ unet=unet,
592
+ tokenizer=tokenizer,
593
+ scheduler=PNDMScheduler(
594
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
595
+ ),
596
+ safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
597
+ feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
598
+ )
599
+ pipeline.save_pretrained(args.output_dir)
600
+ # Also save the newly trained embeddings
601
+ save_progress(text_encoder, placeholder_token_id, accelerator, args)
602
+
603
+ if args.push_to_hub:
604
+ repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
605
+
606
+ accelerator.end_training()
607
+ torch.cuda.empty_cache()
608
+ gc.collect()
609
+
610
+
611
+ if __name__ == "__main__":
612
+ main()
trsl_style.png ADDED