ChefEase committed on
Commit
4aea63c
·
verified ·
1 Parent(s): 35befad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +193 -81
app.py CHANGED
@@ -1,98 +1,210 @@
 
 
 
1
  import torch
2
- from PIL import Image
3
  import gradio as gr
4
- from huggingface_hub import hf_hub_download
5
- from huggingface_hub import HfFolder
6
-
7
  from src_inference.pipeline import FluxPipeline
8
- from src_inference.lora_helper import set_single_lora, clear_cache
9
- import os
10
 
11
- device = "cuda" if torch.cuda.is_available() else "cpu"
12
- print("Running on:", device)
 
 
 
13
 
14
- # Download and load model
15
- base_path = hf_hub_download(repo_id="showlab/OmniConsistency", filename="OmniConsistency.safetensors", local_dir="./Model")
16
- lora_path = hf_hub_download(
17
- repo_id="showlab/OmniConsistency",
18
- filename="LoRAs/Ghibli_rank128_bf16.safetensors",
19
- local_dir="./LoRAs"
20
- )
21
- lora_path = hf_hub_download(
22
- repo_id="showlab/OmniConsistency",
23
- filename="LoRAs/American_Cartoon_rank128_bf16.safetensors",
24
- local_dir="./LoRAs"
25
- )
26
- lora_path = hf_hub_download(
27
- repo_id="showlab/OmniConsistency",
28
- filename="LoRAs/Chinese_Ink_rank128_bf16.safetensors",
29
- local_dir="./LoRAs"
30
- )
31
- lora_path = hf_hub_download(
32
- repo_id="showlab/OmniConsistency",
33
- filename="LoRAs/Jojo_rank128_bf16.safetensors",
34
- local_dir="./LoRAs"
35
- )
36
- lora_path = hf_hub_download(
37
- repo_id="showlab/OmniConsistency",
38
- filename="LoRAs/Line_rank128_bf16.safetensors",
39
- local_dir="./LoRAs"
40
- )
41
- lora_path = hf_hub_download(
42
- repo_id="showlab/OmniConsistency",
43
- filename="LoRAs/Rick_Morty_rank128_bf16.safetensors",
44
- local_dir="./LoRAs"
45
- )
46
- lora_path = hf_hub_download(
47
  repo_id="showlab/OmniConsistency",
48
- filename="LoRAs/Vector_rank128_bf16.safetensors",
49
- local_dir="./LoRAs"
50
  )
51
 
52
- token = os.environ.get("HF_TOKEN")
53
  pipe = FluxPipeline.from_pretrained(
54
- "black-forest-labs/FLUX.1-dev",
55
- use_auth_token=token,
56
- torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
57
- ).to(device)
58
-
59
- set_single_lora(pipe.transformer, base_path, lora_weights=[1], cond_size=512)
60
-
61
- pipe.unload_lora_weights()
62
- pipe.load_lora_weights("./LoRAs", weight_name="Ghibli_rank128_bf16.safetensors")
63
- pipe.load_lora_weights("./LoRAs", weight_name="American_Cartoon_rank128_bf16.safetensors")
64
- pipe.load_lora_weights("./LoRAs", weight_name="Chinese_Ink_rank128_bf16.safetensors")
65
- pipe.load_lora_weights("./LoRAs", weight_name="Jojo_rank128_bf16.safetensors")
66
- pipe.load_lora_weights("./LoRAs", weight_name="Line_rank128_bf16.safetensors")
67
- pipe.load_lora_weights("./LoRAs", weight_name="Rick_Morty_rank128_bf16.safetensors")
68
- pipe.load_lora_weights("./LoRAs", weight_name="Vector_rank128_bf16.safetensors")
69
-
70
-
71
- def generate_manga(input_image, prompt):
72
- spatial_image = [input_image.convert("RGB")]
73
- image = pipe(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  prompt,
75
- height=1024,
76
- width=1024,
77
- guidance_scale=3.5,
78
- num_inference_steps=25,
79
  max_sequence_length=512,
 
80
  spatial_images=spatial_image,
81
- subject_images=[],
82
  cond_size=512,
83
  ).images[0]
 
84
 
85
  clear_cache(pipe.transformer)
86
- return image
87
-
88
- demo = gr.Interface(
89
- fn=generate_manga,
90
- inputs=[
91
- gr.Image(type="pil", label="Input Character"),
92
- gr.Textbox(label="Scene Prompt")
93
- ],
94
- outputs=gr.Image(label="Generated Manga Frame"),
95
- title="OmniConsistency Manga Generator"
96
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
- demo.launch()
 
 
 
1
+ import spaces
2
+ import os
3
+ import time
4
  import torch
 
5
  import gradio as gr
6
+ from PIL import Image
7
+ from huggingface_hub import hf_hub_download, list_repo_files
 
8
  from src_inference.pipeline import FluxPipeline
9
+ from src_inference.lora_helper import set_single_lora
 
10
 
11
# Hub id of the base FLUX model and the local folders used for LoRA weights.
BASE_PATH = "black-forest-labs/FLUX.1-dev"
LOCAL_LORA_DIR = "./LoRAs"
CUSTOM_LORA_DIR = "./Custom_LoRAs"

# Both LoRA folders must exist before anything is downloaded into them.
os.makedirs(LOCAL_LORA_DIR, exist_ok=True)
os.makedirs(CUSTOM_LORA_DIR, exist_ok=True)

# Fetch the OmniConsistency checkpoint that is attached to the transformer below.
print("downloading OmniConsistency base LoRA …")
omni_consistency_path = hf_hub_download(
    repo_id="showlab/OmniConsistency",
    filename="OmniConsistency.safetensors",
    local_dir="./Model",
)

# Load the FLUX pipeline in bfloat16, move it to the GPU, then attach the
# OmniConsistency weights with a 512-pixel condition size.
print("loading base pipeline …")
pipe = FluxPipeline.from_pretrained(BASE_PATH, torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
set_single_lora(
    pipe.transformer,
    omni_consistency_path,
    lora_weights=[1],
    cond_size=512,
)
30
+
31
def download_all_loras():
    """Download every built-in style LoRA into ``LOCAL_LORA_DIR``.

    Each style is requested from the ``showlab/OmniConsistency`` repo as
    ``LoRAs/<style>_rank128_bf16.safetensors``.
    """
    style_names = (
        "3D_Chibi", "American_Cartoon", "Chinese_Ink", "Clay_Toy",
        "Fabric", "Ghibli", "Irasutoya", "Jojo", "LEGO", "Line",
        "Macaron", "Oil_Painting", "Origami", "Paper_Cutting",
        "Picasso", "Pixel", "Poly", "Pop_Art", "Rick_Morty",
        "Snoopy", "Van_Gogh", "Vector",
    )
    remote_files = [
        f"LoRAs/{style}_rank128_bf16.safetensors" for style in style_names
    ]
    for remote_file in remote_files:
        hf_hub_download(
            repo_id="showlab/OmniConsistency",
            filename=remote_file,
            local_dir=LOCAL_LORA_DIR,
        )
45
+ download_all_loras()
46
+
47
def clear_cache(transformer):
    """Empty the ``bank_kv`` store of every attention processor.

    Args:
        transformer: Object exposing an ``attn_processors`` mapping whose
            values each carry a ``bank_kv`` container with a ``clear()`` method.
    """
    for processor in transformer.attn_processors.values():
        processor.bank_kv.clear()
50
+
51
@spaces.GPU()
def generate_image(
    lora_name,
    custom_repo_id,
    prompt,
    uploaded_image,
    width, height,
    guidance_scale,
    num_inference_steps,
    seed
):
    """Stylise ``uploaded_image`` with the selected LoRA and return both images.

    Args:
        lora_name: Name of a built-in style; used only when ``custom_repo_id``
            is empty.
        custom_repo_id: Optional Hugging Face repo id. When given, the first
            ``.safetensors`` file listed in that repo is downloaded and used
            instead of the built-in LoRA.
        prompt: Text prompt passed to the pipeline.
        uploaded_image: Input image (must provide ``.convert("RGB")``).
        width, height: Requested output size; accepted as text or numbers and
            rounded down to a multiple of 8 before inference.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps (coerced to ``int``).
        seed: Seed for the CPU random generator (coerced to ``int``).

    Returns:
        A ``(uploaded_image, generated_image)`` tuple.

    Raises:
        gr.Error: On a missing image, invalid dimensions, or any failure while
            fetching or loading the LoRA weights.
    """
    if uploaded_image is None:
        raise gr.Error("Please upload an image before generating.")

    # The size fields are free-text boxes, so report bad input as a UI error.
    try:
        width, height = int(width), int(height)
    except (TypeError, ValueError) as e:
        raise gr.Error(f"Width and height must be integers: {e}") from e
    if width < 8 or height < 8:
        raise gr.Error("Width and height must each be at least 8 pixels.")

    # Sliders can deliver floats; manual_seed() and the step count need ints.
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)
    generator = torch.Generator("cpu").manual_seed(seed)

    if custom_repo_id and custom_repo_id.strip():
        repo_id = custom_repo_id.strip()
        try:
            files = list_repo_files(repo_id)
            print("using custom LoRA from:", repo_id)
            safetensors_files = [f for f in files if f.endswith(".safetensors")]
            print("found safetensors files:", safetensors_files)
            if not safetensors_files:
                raise ValueError("No .safetensors files were found in this repo")
            fname = safetensors_files[0]
            lora_path = hf_hub_download(
                repo_id=repo_id,
                filename=fname,
                local_dir=CUSTOM_LORA_DIR,
            )
        except Exception as e:
            raise gr.Error(f"Load custom LoRA failed: {e}") from e
    else:
        if not lora_name:
            raise gr.Error("Please select a built-in LoRA or enter a custom LoRA repo.")
        # Built-in files were downloaded with their "LoRAs/" repo prefix, so
        # they live one directory below LOCAL_LORA_DIR.
        lora_path = os.path.join(
            f"{LOCAL_LORA_DIR}/LoRAs", f"{lora_name}_rank128_bf16.safetensors"
        )

    # Drop whichever style LoRA the previous request loaded before adding a new one.
    pipe.unload_lora_weights()
    try:
        pipe.load_lora_weights(
            os.path.dirname(lora_path),
            weight_name=os.path.basename(lora_path)
        )
    except Exception as e:
        raise gr.Error(f"Load LoRA failed: {e}") from e

    spatial_image = [uploaded_image.convert("RGB")]
    subject_images = []
    start = time.time()
    try:
        out_img = pipe(
            prompt,
            height=(height // 8) * 8,
            width=(width // 8) * 8,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            max_sequence_length=512,
            generator=generator,
            spatial_images=spatial_image,
            subject_images=subject_images,
            cond_size=512,
        ).images[0]
        print(f"inference time: {time.time()-start:.2f}s")
    finally:
        # Always reset the processors' bank_kv state, even when inference
        # fails, so a failed request cannot leak into the next one.
        clear_cache(pipe.transformer)

    return uploaded_image, out_img
115
+
116
+ # =============== Gradio UI ===============
117
def create_interface():
    """Build the Gradio Blocks UI for OmniConsistency image generation.

    The layout has an input column (image, prompt, built-in LoRA dropdown,
    custom LoRA repo box, Generate button) and an output column (before/after
    image slider plus an "Advanced Options" accordion), followed by a set of
    clickable examples.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    demo_lora_names = [
        "3D_Chibi", "American_Cartoon", "Chinese_Ink", "Clay_Toy",
        "Fabric", "Ghibli", "Irasutoya", "Jojo", "LEGO", "Line",
        "Macaron", "Oil_Painting", "Origami", "Paper_Cutting",
        "Picasso", "Pixel", "Poly", "Pop_Art", "Rick_Morty",
        "Snoopy", "Van_Gogh", "Vector"
    ]

    def update_trigger_word(lora_name, prompt):
        """Swap any built-in "<Style> style," trigger in ``prompt`` for the new one."""
        for name in demo_lora_names:
            trigger = " ".join(name.split("_")) + " style,"
            prompt = prompt.replace(trigger, "")
        new_trigger = " ".join(lora_name.split("_")) + " style,"
        return new_trigger + prompt

    # Example rows follow the parameter order of generate_image:
    # (lora, custom repo, prompt, image, width, height, guidance, steps, seed).
    examples = [
        ["3D_Chibi", "", "3D Chibi style, Two smiling colleagues enthusiastically high-five in front of a whiteboard filled with technical notes about multimodal learning, reflecting a moment of success and collaboration at OpenAI.",
         Image.open("./test_imgs/00.png"), 1024, 680, 3.5, 24, 42],
        ["Clay_Toy", "", "Clay Toy style, Three team members from OpenAI are gathered around a laptop in a cozy, festive setting, with holiday decorations in the background; one waves cheerfully while the others engage in light conversation, reflecting a relaxed and collaborative atmosphere.",
         Image.open("./test_imgs/01.png"), 1024, 560, 3.5, 24, 42],
        ["American_Cartoon", "", "American Cartoon style, In a dramatic and comedic moment from a classic Chinese film, an intense elder with a white beard and red hat grips a younger man, declaring something with fervor, while the subtitle at the bottom reads 'I want them all' — capturing both tension and humor.",
         Image.open("./test_imgs/02.png"), 1024, 568, 3.5, 24, 42],
        ["Origami", "", "Origami style, A thrilled fan wearing a Portugal football kit poses energetically with a smiling Cristiano Ronaldo, who gives a thumbs-up, as they stand side by side in a casual, cheerful moment—capturing the excitement of meeting a football legend.",
         Image.open("./test_imgs/03.png"), 672, 768, 3.5, 24, 42],
        ["Vector", "", "Vector style, A man glances admiringly at a passing woman, while his girlfriend looks at him in disbelief, perfectly capturing the theme of shifting attention and misplaced priorities in a humorous, relatable way.",
         Image.open("./test_imgs/04.png"), 1024, 512, 3.5, 24, 42]
    ]

    header = """
    <div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
    <a href="https://arxiv.org/abs/2505.18445"><img src="https://img.shields.io/badge/arXiv-2505.18445-A42C25.svg" alt="arXiv"></a>
    <a href="https://huggingface.co/showlab/OmniConsistency"><img src="https://img.shields.io/badge/🤗_HuggingFace-Model-ffbd45.svg" alt="HuggingFace"></a>
    <a href="https://github.com/showlab/OmniConsistency"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a>
    </div>
    """

    with gr.Blocks() as demo:
        gr.Markdown("# OmniConsistency LoRA Image Generation")
        gr.Markdown("Select a LoRA, enter a prompt, and upload an image to generate a new image with OmniConsistency.")
        gr.HTML(header)

        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(type="pil", label="Upload Image")
                prompt_box = gr.Textbox(label="Prompt",
                    value="3D Chibi style,",
                    info="Remember to include the necessary trigger words if you're using a custom LoRA."
                )
                lora_dropdown = gr.Dropdown(
                    demo_lora_names, label="Select built-in LoRA")
                custom_repo_box = gr.Textbox(
                    label="Enter Custom LoRA",
                    placeholder="LoRA Hugging Face path (e.g., 'username/repo_name')",
                    info="If you want to use a custom LoRA, enter its Hugging Face repo ID here and built-in LoRA will be Overridden. Leave empty to use built-in LoRAs. [Check the list of FLUX LoRAs](https://huggingface.co/models?other=base_model:adapter:black-forest-labs/FLUX.1-dev)"
                )
                gen_btn = gr.Button("Generate")
            with gr.Column(scale=1):
                output_image = gr.ImageSlider(label="Generated Image")
                with gr.Accordion("Advanced Options", open=False):
                    height_box = gr.Textbox(value="1024", label="Height")
                    width_box = gr.Textbox(value="1024", label="Width")
                    guidance_slider = gr.Slider(
                        0.1, 20, value=3.5, step=0.1, label="Guidance Scale")
                    steps_slider = gr.Slider(
                        1, 50, value=25, step=1, label="Inference Steps")
                    seed_slider = gr.Slider(
                        1, 2_147_483_647, value=42, step=1, label="Seed")

        lora_dropdown.select(fn=update_trigger_word, inputs=[lora_dropdown, prompt_box],
                             outputs=prompt_box)

        # One shared input list keeps the examples and the Generate button in
        # the same positional order as generate_image's parameters.
        generation_inputs = [
            lora_dropdown, custom_repo_box, prompt_box, image_input,
            width_box, height_box, guidance_slider, steps_slider, seed_slider,
        ]

        gr.Examples(
            examples=examples,
            inputs=generation_inputs,
            outputs=output_image,
            fn=generate_image,
            cache_examples=False,
            label="Examples"
        )

        gen_btn.click(
            fn=generate_image,
            inputs=generation_inputs,
            outputs=output_image
        )
    return demo
207
 
208
if __name__ == "__main__":
    # Build the Blocks app and serve it with ssr_mode turned off.
    demo = create_interface()
    demo.launch(ssr_mode=False)