Sek2810 committed on
Commit
430f76f
·
verified ·
1 Parent(s): 901fb88

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -125
app.py CHANGED
@@ -1,166 +1,174 @@
 
 
 
1
  import gradio as gr
2
  import numpy as np
3
- import random
4
  import torch
5
- import spaces
6
- from PIL import Image
 
 
 
 
 
7
  import os
 
8
 
9
- from models.transformer_sd3 import SD3Transformer2DModel
10
- from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline
11
-
12
- from transformers import AutoProcessor, SiglipVisionModel
13
- from huggingface_hub import hf_hub_download
14
 
 
 
 
 
 
 
 
 
15
 
16
- # Constants
17
- MAX_SEED = np.iinfo(np.int32).max
18
- MAX_IMAGE_SIZE = 1024
19
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
20
-
21
- model_path = 'stabilityai/stable-diffusion-3.5-large'
22
- image_encoder_path = "google/siglip-so400m-patch14-384"
23
- ipadapter_path = hf_hub_download(repo_id="InstantX/SD3.5-Large-IP-Adapter", filename="ip-adapter.bin")
24
 
25
  transformer = SD3Transformer2DModel.from_pretrained(
26
- model_path,
27
- subfolder="transformer",
28
- torch_dtype=torch.bfloat16
29
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
- pipe = StableDiffusion3Pipeline.from_pretrained(
32
- model_path,
33
- transformer=transformer,
34
- torch_dtype=torch.bfloat16
35
- ).to("cuda")
36
 
37
- pipe.init_ipadapter(
38
- ip_adapter_path=ipadapter_path,
39
- image_encoder_path=image_encoder_path,
40
- nb_token=64,
41
  )
 
 
 
 
 
 
42
 
43
- def resize_img(image, max_size=1024):
44
- width, height = image.size
45
- scaling_factor = min(max_size / width, max_size / height)
46
- new_width = int(width * scaling_factor)
47
- new_height = int(height * scaling_factor)
48
- return image.resize((new_width, new_height), Image.LANCZOS)
49
 
50
  @spaces.GPU
51
- def process_image(
52
- image,
53
- prompt,
54
- scale,
55
- seed,
56
- randomize_seed,
57
- width,
58
- height,
59
- progress=gr.Progress(track_tqdm=True),
60
- ):
61
- #pipe.to("cuda")
62
  if randomize_seed:
63
  seed = random.randint(0, MAX_SEED)
64
-
65
- if image is None:
66
- return None, seed
67
-
68
- # Convert to PIL Image if needed
69
- if not isinstance(image, Image.Image):
70
- image = Image.fromarray(image)
71
-
72
- # Resize image
73
- image = resize_img(image)
74
-
75
- # Generate the image
76
- result = pipe(
77
- clip_image=image,
78
  prompt=prompt,
79
- ipadapter_scale=scale,
80
- width=width,
81
- height=height,
82
- generator=torch.Generator().manual_seed(seed)
83
  ).images[0]
84
-
85
- return result, seed
86
 
87
- # UI CSS
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  css = """
89
  #col-container {
90
  margin: 0 auto;
91
- max-width: 960px;
92
  }
93
  """
94
 
95
- # Create the Gradio interface
96
  with gr.Blocks(css=css) as demo:
97
  with gr.Column(elem_id="col-container"):
98
- gr.Markdown("# InstantX's SD3.5 IP Adapter")
99
-
 
 
 
 
 
100
  with gr.Row():
101
- with gr.Column():
102
- input_image = gr.Image(
103
- label="Input Image",
104
- type="pil"
105
- )
106
- scale = gr.Slider(
107
- label="Image Scale",
108
- minimum=0.0,
109
- maximum=1.0,
110
- step=0.1,
111
- value=0.7,
112
- )
113
- prompt = gr.Text(
114
- label="Prompt",
115
- max_lines=1,
116
- placeholder="Enter your prompt",
117
- )
118
- run_button = gr.Button("Generate", variant="primary")
119
-
120
- with gr.Column():
121
- result = gr.Image(label="Result")
122
-
123
  with gr.Accordion("Advanced Settings", open=False):
124
  seed = gr.Slider(
125
  label="Seed",
126
  minimum=0,
127
  maximum=MAX_SEED,
128
  step=1,
129
- value=42,
130
  )
131
-
132
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
133
-
134
- with gr.Row():
135
- width = gr.Slider(
136
- label="Width",
137
- minimum=256,
138
- maximum=MAX_IMAGE_SIZE,
139
- step=32,
140
- value=1024,
141
- )
142
-
143
- height = gr.Slider(
144
- label="Height",
145
- minimum=256,
146
- maximum=MAX_IMAGE_SIZE,
147
- step=32,
148
- value=1024,
149
- )
150
-
151
- run_button.click(
152
- fn=process_image,
153
- inputs=[
154
- input_image,
155
- prompt,
156
- scale,
157
- seed,
158
- randomize_seed,
159
- width,
160
- height,
161
  ],
162
- outputs=[result, seed],
 
 
 
 
 
163
  )
164
 
165
- if __name__ == "__main__":
166
- demo.launch()
 
1
+ import random
2
+ import spaces
3
+
4
  import gradio as gr
5
  import numpy as np
 
6
  import torch
7
+ from diffusers import (
8
+ StableDiffusion3Pipeline,
9
+ SD3Transformer2DModel,
10
+ FlashFlowMatchEulerDiscreteScheduler,
11
+ AutoencoderTiny,
12
+ )
13
+ from peft import PeftModel
14
  import os
15
+ from huggingface_hub import snapshot_download
16
 
17
+ huggingface_token = os.getenv("HUGGINFACE_TOKEN")
 
 
 
 
18
 
19
+ model_path = snapshot_download(
20
+ repo_id="stabilityai/stable-diffusion-3-medium-diffusers",
21
+ repo_type="model",
22
+ ignore_patterns=["*.md", "*..gitattributes"],
23
+ local_dir="stable-diffusion-3-medium",
24
+ token=huggingface_token, # access token read from the HUGGINFACE_TOKEN environment variable; do not hard-code it here.
25
+ )
26
+ import spaces
27
 
28
+ device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
29
 
30
  transformer = SD3Transformer2DModel.from_pretrained(
31
+ model_path,
32
+ subfolder="transformer",
33
+ torch_dtype=torch.float16,
34
  )
35
+ transformer = PeftModel.from_pretrained(transformer, "jasperai/flash-sd3")
36
+
37
+
38
+ if torch.cuda.is_available():
39
+ torch.cuda.max_memory_allocated(device=device)
40
+ pipe = StableDiffusion3Pipeline.from_pretrained(
41
+ model_path,
42
+ transformer=transformer,
43
+ torch_dtype=torch.float16,
44
+ text_encoder_3=None,
45
+ tokenizer_3=None,
46
+ )
47
+ pipe.vae = AutoencoderTiny.from_pretrained(
48
+ "madebyollin/taesd3", torch_dtype=torch.float16
49
+ )
50
+
51
+ # pipe.vae.decoder.layers = torch.compile(
52
+ # pipe.vae.decoder.layers,
53
+ # fullgraph=True,
54
+ # dynamic=False,
55
+ # mode="max-autotune-no-cudagraphs",
56
+ # )
57
+ pipe.vae.config.shift_factor = 0.0
58
+
59
+ pipe = pipe.to(device)
60
+ else:
61
+ pipe = StableDiffusion3Pipeline.from_pretrained(
62
+ model_path,
63
+ transformer=transformer,
64
+ torch_dtype=torch.float16,
65
+ text_encoder_3=None,
66
+ tokenizer_3=None,
67
+ )
68
+ pipe = pipe.to(device)
69
 
 
 
 
 
 
70
 
71
+ pipe.scheduler = FlashFlowMatchEulerDiscreteScheduler.from_pretrained(
72
+ model_path,
73
+ subfolder="scheduler",
 
74
  )
75
+ pipe.set_progress_bar_config(disable=True)
76
+
77
+
78
+ MAX_SEED = np.iinfo(np.int32).max
79
+ MAX_IMAGE_SIZE = 1024
80
+ NUM_INFERENCE_STEPS = 4
81
 
 
 
 
 
 
 
82
 
83
  @spaces.GPU
84
+ def infer(prompt, seed, randomize_seed):
 
 
 
 
 
 
 
 
 
 
85
  if randomize_seed:
86
  seed = random.randint(0, MAX_SEED)
87
+
88
+ generator = torch.Generator().manual_seed(seed)
89
+
90
+ image = pipe(
 
 
 
 
 
 
 
 
 
 
91
  prompt=prompt,
92
+ guidance_scale=0,
93
+ num_inference_steps=NUM_INFERENCE_STEPS,
94
+ generator=generator,
 
95
  ).images[0]
 
 
96
 
97
+ return image
98
+
99
+
100
+ examples = [
101
+ "The image showcases a freshly baked bread, possibly focaccia, with rosemary sprigs and red pepper flakes sprinkled on top. It's sliced and placed on a wire cooling rack, with a bowl of mixed peppercorns beside it.",
102
+ 'a 3D render of a wizard raccoon holding a sign saying "SD3" with a magic wand.',
103
+ "A panda reading a book in a lush forest.",
104
+ "A raccoon trapped inside a glass jar full of colorful candies, the background is steamy with vivid colors",
105
+ "Pirate ship sailing on a sea with the milky way galaxy in the sky and purple glow lights",
106
+ "a cute cartoon fluffy rabbit pilot walking on a military aircraft carrier, 8k, cinematic",
107
+ "A 3d render of a futuristic city with a giant robot in the middle full of neon lights, pink and blue colors",
108
+ "A close up of an old elderly man with green eyes looking straight at the camera",
109
+ "photo of a huge red cat with green eyes sitting on a cloud in the sky, looking at the camera",
110
+ ]
111
+
112
  css = """
113
  #col-container {
114
  margin: 0 auto;
115
+ max-width: 712px;
116
  }
117
  """
118
 
 
119
  with gr.Blocks(css=css) as demo:
120
  with gr.Column(elem_id="col-container"):
121
+ gr.Markdown(
122
+ f"""
123
+ # ⚡ Flash Diffusion: FlashSD3 + TAESD3 ⚡️
124
+ [Flash Diffusion](https://gojasper.github.io/flash-diffusion-project/) with [Tiny AutoEncoder for Stable Diffusion 3](https://huggingface.co/madebyollin/taesd3)
125
+ """
126
+ )
127
+
128
  with gr.Row():
129
+ prompt = gr.Text(
130
+ label="Prompt",
131
+ show_label=False,
132
+ max_lines=1,
133
+ placeholder="Enter your prompt",
134
+ container=False,
135
+ )
136
+
137
+ run_button = gr.Button("Run", scale=0)
138
+
139
+ result = gr.Image(label="Result", show_label=False)
140
+
 
 
 
 
 
 
 
 
 
 
141
  with gr.Accordion("Advanced Settings", open=False):
142
  seed = gr.Slider(
143
  label="Seed",
144
  minimum=0,
145
  maximum=MAX_SEED,
146
  step=1,
147
+ value=0,
148
  )
149
+
150
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
151
+
152
+ examples = gr.Examples(examples=examples, inputs=[prompt], cache_examples=False)
153
+
154
+ gr.Markdown("**Disclaimer:**")
155
+ gr.Markdown(
156
+ "This demo is only for research purpose. Jasper cannot be held responsible for the generation of NSFW (Not Safe For Work) content through the use of this demo. Users are solely responsible for any content they create, and it is their obligation to ensure that it adheres to appropriate and ethical standards. Jasper provides the tools, but the responsibility for their use lies with the individual user."
157
+ )
158
+ gr.on(
159
+ [
160
+ run_button.click,
161
+ seed.change,
162
+ randomize_seed.change,
163
+ prompt.submit,
164
+ prompt.change,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  ],
166
+ fn=infer,
167
+ inputs=[prompt, seed, randomize_seed],
168
+ outputs=[result],
169
+ show_progress="minimal",
170
+ show_api=True,
171
+ trigger_mode="always_last",
172
  )
173
 
174
+ demo.queue().launch(show_api=True)