DarianT commited on
Commit
c9c864b
·
1 Parent(s): 12dc184

Update demo to use LoRA weights

Browse files
Files changed (1) hide show
  1. app.py +21 -7
app.py CHANGED
@@ -3,7 +3,7 @@ import numpy as np
3
  import random
4
  import spaces # Uncomment if using ZeroGPU
5
 
6
- from diffusers import DiffusionPipeline
7
  import torch
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -11,19 +11,28 @@ model_repo_id = "stabilityai/stable-diffusion-2-1-base"
11
 
12
  torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
13
 
14
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
15
- pipe = pipe.to(device)
 
 
 
 
 
 
 
 
16
 
17
  backgrounds_list = ["forest", "city street", "beach", "office", "bus", "laboratory", "factory", "construction site", "hospital", "night club", ""]
18
  poses_list = ["portrait", "side-portrait"]
19
- id_list = ["ID_0", "ID_1", "ID_2", "ID_3", "ID_4", "ID_5"]
20
 
21
- gender_dict = {"ID_0": "male"}
22
  MAX_SEED = 10000
23
  image_size = 512
24
 
25
  @spaces.GPU # Uncomment if using ZeroGPU
26
  def infer(
 
27
  background,
28
  pose,
29
  negative_prompt,
@@ -34,16 +43,20 @@ def infer(
34
  progress=gr.Progress(track_tqdm=True),
35
  num_images=1
36
  ):
 
 
 
 
37
  if randomize_seed:
38
  seed = random.randint(0, MAX_SEED)
39
 
40
  generator = torch.Generator().manual_seed(seed)
41
 
42
  id = "ID_0"
43
- gender = gender_dict[id]
44
 
45
  # Construct prompt from dropdown selections
46
- prompt = f"face {pose.lower()} photo of {gender} {id} person, {background.lower()} background"
47
 
48
  print(prompt)
49
  print(negative_prompt)
@@ -144,6 +157,7 @@ with gr.Blocks(css=css) as demo:
144
  triggers=[run_button.click],
145
  fn=infer,
146
  inputs=[
 
147
  background,
148
  pose,
149
  negative_prompt,
 
3
  import random
4
  import spaces # Uncomment if using ZeroGPU
5
 
6
+ from diffusers import StableDiffusionPipeline, DDPMScheduler
7
  import torch
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
11
 
12
  torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
13
 
14
+ # pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
15
+ # pipe = pipe.to(device)
16
+
17
+ pipe = StableDiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch.float16).to(device)
18
+ pipe.scheduler = DDPMScheduler.from_pretrained(model_repo_id, subfolder="scheduler")
19
+
20
+ folder_of_lora_weights = "ID_Booth_LoRA_weights"
21
+ which_checkpoint = "checkpoint-31-6400"
22
+ lora_name = "pytorch_lora_weights.safetensors"
23
+
24
 
25
  backgrounds_list = ["forest", "city street", "beach", "office", "bus", "laboratory", "factory", "construction site", "hospital", "night club", ""]
26
  poses_list = ["portrait", "side-portrait"]
27
+ id_list = ["ID_0", "ID_1", "ID_2"]
28
 
29
+ gender_dict = {"ID_0": "male", "ID_1": "male", "ID_2": "female", "ID_2": "male"}
30
  MAX_SEED = 10000
31
  image_size = 512
32
 
33
  @spaces.GPU # Uncomment if using ZeroGPU
34
  def infer(
35
+ which_id,
36
  background,
37
  pose,
38
  negative_prompt,
 
43
  progress=gr.Progress(track_tqdm=True),
44
  num_images=1
45
  ):
46
+
47
+ full_lora_weights_path = f"{folder_of_lora_weights}/{which_id}/{which_checkpoint}/{lora_name}"
48
+ pipe.load_lora_weights(full_lora_weights_path)
49
+
50
  if randomize_seed:
51
  seed = random.randint(0, MAX_SEED)
52
 
53
  generator = torch.Generator().manual_seed(seed)
54
 
55
  id = "ID_0"
56
+ gender = gender_dict[which_id]
57
 
58
  # Construct prompt from dropdown selections
59
+ prompt = f"face {pose.lower()} photo of {gender} sks person, {background.lower()} background"
60
 
61
  print(prompt)
62
  print(negative_prompt)
 
157
  triggers=[run_button.click],
158
  fn=infer,
159
  inputs=[
160
+ which_id,
161
  background,
162
  pose,
163
  negative_prompt,