prithivMLmods committed on
Commit
3f7b37c
·
verified ·
1 Parent(s): 27bb58e

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -383
app.py DELETED
@@ -1,383 +0,0 @@
1
- import gradio as gr
2
- import numpy as np
3
- import spaces
4
- import torch
5
- import random
6
- import json
7
- import os
8
- from PIL import Image
9
- from diffusers import FluxKontextPipeline
10
- from diffusers.utils import load_image, peft_utils
11
- from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard
12
- from safetensors.torch import load_file
13
- import requests
14
- import re
15
-
16
# Upper bound for user-visible seeds: the largest 32-bit signed integer.
MAX_SEED = np.iinfo(np.int32).max

# Load the base FLUX.1 Kontext pipeline onto the GPU in bfloat16.
pipe = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Temporary workaround for a diffusers LoRA-loading issue: the stock
# exclude-module derivation drops "proj_out" layers, so wrap the helper and
# filter those entries back out of the exclusion list.
try:
    from diffusers.utils.peft_utils import _derive_exclude_modules

    def new_derive_exclude_modules(*args, **kwargs):
        """Wrap diffusers' helper so 'proj_out' modules are never excluded."""
        exclude_modules = _derive_exclude_modules(*args, **kwargs)
        if exclude_modules is not None:
            exclude_modules = [n for n in exclude_modules if "proj_out" not in n]
        return exclude_modules

    peft_utils._derive_exclude_modules = new_derive_exclude_modules
except (ImportError, AttributeError):
    # BUGFIX: was a bare `except:` that swallowed everything, including
    # SystemExit/KeyboardInterrupt. The helper simply not existing in this
    # diffusers version is the only case we want to tolerate.
    pass
32
-
33
# Read the adapter catalogue and normalise each entry, filling in defaults
# for the optional fields.
with open("lora_configs.json", "r") as file:
    data = json.load(file)

lora_configs = []
for item in data:
    lora_configs.append(
        {
            "image": item["image"],
            "title": item["title"],
            "repo": item["repo"],
            "trigger_word": item.get("trigger_word", ""),
            "trigger_position": item.get("trigger_position", "prepend"),
            "weights": item.get("weights", "pytorch_lora_weights.safetensors"),
        }
    )
print(f"Loaded {len(lora_configs)} LoRAs from JSON")
48
-
49
# Adapter bookkeeping: the adapter dict currently applied to the pipeline,
# and a repo-id -> local-path cache of already-downloaded weight files.
active_lora_adapter = None
lora_cache = {}

def load_lora_weights(repo_id, weights_filename):
    """Download (and memoise) an adapter weight file from the Hugging Face Hub.

    Returns the local file path, or None when the download fails.
    """
    try:
        cached_path = lora_cache.get(repo_id)
        if cached_path is None:
            cached_path = hf_hub_download(repo_id=repo_id, filename=weights_filename)
            lora_cache[repo_id] = cached_path
        return cached_path
    except Exception as e:
        # Best-effort: report the failure and let the caller fall back.
        print(f"Error loading adapter from {repo_id}: {e}")
        return None
63
-
64
def on_lora_select(selected_state: gr.SelectData, lora_configs):
    """Refresh the selection banner and prompt placeholder after a gallery pick."""
    # Guard against a stale gallery index (e.g. after the catalogue changed).
    if selected_state.index >= len(lora_configs):
        return "### No adapter selected", gr.update(), None

    chosen = lora_configs[selected_state.index]
    repo_id = chosen["repo"]
    trigger_word = chosen["trigger_word"]  # read for parity with config schema

    banner = f"### Selected: [{repo_id}](https://huggingface.co/{repo_id})"
    placeholder = "optional description, e.g. 'a man with glasses and a beard'"
    return banner, gr.update(placeholder=placeholder), selected_state.index
76
-
77
def fetch_lora_from_hf(link):
    """Resolve a Hugging Face repo id ("user/repo") to adapter metadata.

    Returns a (repo_name, safetensors_filename, trigger_word) tuple.
    Raises ValueError for a malformed repo id and RuntimeError when the
    Hub lookup fails (both remain Exception subclasses for existing callers).
    """
    split_link = link.split("/")
    # BUGFIX/idiom: guard clause with a specific exception type instead of
    # a trailing `else: raise Exception(...)`.
    if len(split_link) != 2:
        raise ValueError("Invalid HuggingFace repository format")

    try:
        # The model card's instance_prompt doubles as the trigger word.
        model_card = ModelCard.load(link)
        trigger_word = model_card.data.get("instance_prompt", "")

        fs = HfFileSystem()
        list_of_files = fs.ls(link, detail=False)

        # Prefer a file that is explicitly a LoRA safetensors checkpoint.
        safetensors_file = None
        for file in list_of_files:
            if file.endswith(".safetensors") and "lora" in file.lower():
                safetensors_file = file.split("/")[-1]
                break

        if not safetensors_file:
            # Fall back to the diffusers default weight filename.
            safetensors_file = "pytorch_lora_weights.safetensors"

        return split_link[1], safetensors_file, trigger_word
    except Exception as e:
        # Chain the cause so the original Hub error is preserved.
        raise RuntimeError(f"Error loading adapter: {e}") from e
102
-
103
def load_user_lora(link):
    """Load a user-supplied adapter repo and build the HTML card describing it.

    Returns updates for: card visibility, card HTML, unload-button visibility,
    user_lora state, gallery, prompt title, and selected_state.
    """
    # Empty input: reset the custom-adapter UI back to its default state.
    if not link:
        return (
            gr.update(visible=False),
            "",
            gr.update(visible=False),
            None,
            gr.Gallery(selected_index=None),
            "### Click on an adapter in the gallery to select it",
            None,
        )

    try:
        repo_name, weights_file, trigger_word = fetch_lora_from_hf(link)

        card = f'''
    <div style="border: 1px solid #ddd; padding: 10px; border-radius: 8px; margin: 10px 0;">
        <span><strong>Loaded custom adapter:</strong></span>
        <div style="margin-top: 8px;">
            <h4>{repo_name}</h4>
            <small>{"Using: <code><b>"+trigger_word+"</b></code> as trigger word" if trigger_word else "No trigger word found"}</small>
        </div>
    </div>
    '''

        user_lora_data = {
            "repo": link,
            "weights": weights_file,
            "trigger_word": trigger_word,
        }

        return (
            gr.update(visible=True),
            card,
            gr.update(visible=True),
            user_lora_data,
            gr.Gallery(selected_index=None),
            f"Custom: {repo_name}",
            None,
        )
    except Exception as e:
        # Surface the failure in the card slot and keep the UI usable.
        return (
            gr.update(visible=True),
            f"Error: {str(e)}",
            gr.update(visible=False),
            None,
            gr.update(),
            "### Click on an adapter in the gallery to select it",
            None,
        )
131
-
132
def unload_user_lora():
    """Clear the custom-adapter link box and hide its card and unload button."""
    # Order matches the event handler's outputs:
    # [user_lora_input, unload_button, card, user_lora state, selected_state]
    return "", gr.update(visible=False), gr.update(visible=False), None, None
135
-
136
def sort_lora_gallery(lora_configs):
    """Return (gallery_items, configs) with adapters ordered by likes, descending.

    Entries with no "likes" key are treated as zero likes.
    """
    ranked = sorted(lora_configs, key=lambda cfg: cfg.get("likes", 0), reverse=True)
    gallery_items = [(cfg["image"], cfg["title"]) for cfg in ranked]
    return gallery_items, ranked
140
-
141
def generate_image_wrapper(input_image, prompt, selected_index, user_lora, seed=42, randomize_seed=False, steps=28, guidance_scale=2.5, lora_scale=1.75, width=960, height=1280, lora_configs=None, progress=gr.Progress(track_tqdm=True)):
    """Thin pass-through used by the Gradio event wiring.

    Forwards every argument unchanged to generate_image so the event handler
    signature stays decoupled from the GPU-decorated implementation.
    """
    return generate_image(
        input_image, prompt, selected_index, user_lora,
        seed, randomize_seed, steps, guidance_scale,
        lora_scale, width, height, lora_configs, progress,
    )
144
-
145
@spaces.GPU
def generate_image(input_image, prompt, selected_index, user_lora, seed=42, randomize_seed=False, steps=28, guidance_scale=2.5, lora_scale=1.0, width=960, height=1280, lora_configs=None, progress=gr.Progress(track_tqdm=True)):
    """Generate an edited image with the selected LoRA adapter.

    A custom user adapter takes precedence over a gallery selection; with no
    adapter at all the base pipeline runs on the unmodified prompt.
    Returns (image, seed, reuse_button_update); image is None on failure.
    """
    global active_lora_adapter, pipe

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Resolve which adapter to apply: custom user adapter wins over gallery.
    lora_to_use = None
    if user_lora:
        lora_to_use = user_lora
    elif selected_index is not None and lora_configs and selected_index < len(lora_configs):
        lora_to_use = lora_configs[selected_index]
        print(f"Loaded {len(lora_configs)} adapters from JSON")

    # Swap adapters only when the selection actually changed.
    if lora_to_use and lora_to_use != active_lora_adapter:
        try:
            if active_lora_adapter:
                pipe.unload_lora_weights()

            lora_path = load_lora_weights(lora_to_use["repo"], lora_to_use["weights"])
            if lora_path:
                pipe.load_lora_weights(lora_path, adapter_name="selected_lora")
                pipe.set_adapters(["selected_lora"], adapter_weights=[lora_scale])
                print(f"loaded: {lora_path} with scale {lora_scale}")
            active_lora_adapter = lora_to_use
        except Exception as e:
            print(f"Error loading adapter: {e}")
    elif lora_to_use:
        # BUGFIX: the original `else` also fired when no adapter was selected,
        # logging "already loaded" for lora_to_use == None.
        print(f"using already loaded adapter: {lora_to_use}")

    input_image = input_image.convert("RGB")

    # Build the final prompt around the adapter's trigger word.
    # BUGFIX: the original did lora_to_use["trigger_word"] unconditionally and
    # crashed with a TypeError when no adapter was selected.
    trigger_word = lora_to_use["trigger_word"] if lora_to_use else ""
    if trigger_word == ", How2Draw":
        prompt = f"create a How2Draw sketch of the person of the photo {prompt}, maintain the facial identity of the person and general features"
    elif trigger_word == "__ ":
        prompt = f" {prompt}. Accurately render the toolimpact logo and any tool impact iconography. The toolimpact logo begins with a two-line-tall drop-cap capital letter T with a dot in the center of its top bar."
    elif trigger_word:
        # Generic style-transfer phrasing; skipped entirely when there is no
        # trigger word so the base prompt is used as-is.
        prompt = f" {prompt}. convert the style of this photo or image to {trigger_word}. Maintain the facial identity of any persons and the general features of the image!"

    try:
        image = pipe(
            image=input_image,
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=steps,
            generator=torch.Generator().manual_seed(seed),
            width=width,
            height=height,
            max_area=width * height,
        ).images[0]
        return image, seed, gr.update(visible=True)
    except Exception as e:
        print(f"Error during generation: {e}")
        return None, seed, gr.update(visible=False)
206
-
207
# CSS styling for the Gradio UI: two-column layout, blue accent for the
# selection banner and generate button, and a scrollable adapter gallery.
css = """
#app_container {
    display: flex;
    gap: 20px;
}
#left_panel {
    min-width: 400px;
}
#lora_info {
    color: #2563eb;
    font-weight: bold;
}
#edit_prompt {
    flex-grow: 1;
}
#generate_button {
    background: linear-gradient(45deg, #2563eb, #3b82f6);
    color: white;
    border: none;
    padding: 8px 16px;
    border-radius: 6px;
    font-weight: bold;
}
.user_lora_card {
    background: #f8fafc;
    border: 1px solid #e2e8f0;
    border-radius: 8px;
    padding: 12px;
    margin: 8px 0;
}
#lora_gallery{
    overflow: scroll !important
}
"""
242
-
243
# Build the Gradio interface. NOTE: the output lists of the event handlers
# below are positional and must match the return-tuple order of the handler
# functions exactly.
with gr.Blocks(theme=gr.themes.Soft(), css=css, delete_cache=(60, 60)) as demo:
    # Per-session copy of the adapter catalogue (re-sorted on page load).
    gr_lora_configs = gr.State(value=lora_configs)

    title = gr.HTML(
        """<h1>Flux Kontext DLC😍</h1>""",
    )

    # Index of the gallery adapter currently selected (None = nothing picked).
    selected_state = gr.State(value=None)
    # Metadata dict for a user-supplied custom adapter, if any.
    user_lora = gr.State(value=None)

    with gr.Row(elem_id="app_container"):
        # Left panel: source image plus adapter selection.
        with gr.Column(scale=4, elem_id="left_panel"):
            with gr.Group(elem_id="lora_selection"):
                input_image = gr.Image(label="Upload a picture", type="pil", height=300)

                gallery = gr.Gallery(
                    label="Pick an Adapter",
                    allow_preview=False,
                    columns=3,
                    elem_id="lora_gallery",
                    show_share_button=False,
                    height=400
                )

                user_lora_input = gr.Textbox(
                    label="Or enter a custom HuggingFace adapter",
                    placeholder="e.g., username/adapter-name",
                    visible=True
                )
                user_lora_card = gr.HTML(visible=False)
                unload_user_lora_button = gr.Button("Remove custom adapter", visible=True)

        # Right panel: prompt, result image, and generation settings.
        with gr.Column(scale=5):
            with gr.Row():
                prompt = gr.Textbox(
                    label="Editing Prompt",
                    show_label=False,
                    lines=1,
                    max_lines=1,
                    placeholder="optional description, e.g. 'colorize and stylize, leave all else as is'",
                    elem_id="edit_prompt"
                )
                run_button = gr.Button("Generate", elem_id="generate_button")

            result = gr.Image(label="Generated Image", interactive=False)
            reuse_button = gr.Button("Reuse this image", visible=False)

            with gr.Accordion("Advanced Settings", open=True):
                lora_scale = gr.Slider(
                    label="Adapter Scale",
                    minimum=0,
                    maximum=2,
                    step=0.1,
                    value=1.5,
                    info="Controls the strength of the adapter effect"
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                steps = gr.Slider(
                    label="Steps",
                    minimum=1,
                    maximum=40,
                    value=10,
                    step=1
                )
                width = gr.Slider(
                    label="Width",
                    minimum=128,
                    maximum=2560,
                    step=1,
                    value=960,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=128,
                    maximum=2560,
                    step=1,
                    value=1280,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1,
                    maximum=10,
                    step=0.1,
                    value=2.8,
                )

            prompt_title = gr.Markdown(
                value="### Click on an adapter in the gallery to select it",
                visible=True,
                elem_id="lora_info",
            )

    # Event handlers
    # Typing a repo id loads a custom adapter and clears any gallery pick.
    # (user_lora_card appears twice on purpose: once for the visibility
    # update, once for the card HTML.)
    user_lora_input.input(
        fn=load_user_lora,
        inputs=[user_lora_input],
        outputs=[user_lora_card, user_lora_card, unload_user_lora_button, user_lora, gallery, prompt_title, selected_state],
    )

    unload_user_lora_button.click(
        fn=unload_user_lora,
        outputs=[user_lora_input, unload_user_lora_button, user_lora_card, user_lora, selected_state]
    )

    gallery.select(
        fn=on_lora_select,
        inputs=[gr_lora_configs],
        outputs=[prompt_title, prompt, selected_state],
        show_progress=False
    )

    # Generate on button click or on pressing Enter in the prompt box.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=generate_image_wrapper,
        inputs=[input_image, prompt, selected_state, user_lora, seed, randomize_seed, steps, guidance_scale, lora_scale, width, height, gr_lora_configs],
        outputs=[result, seed, reuse_button]
    )

    # Feed the generated image back in as the next input image.
    reuse_button.click(
        fn=lambda image: image,
        inputs=[result],
        outputs=[input_image]
    )

    # Initialize the gallery
    demo.load(
        fn=sort_lora_gallery,
        inputs=[gr_lora_configs],
        outputs=[gallery, gr_lora_configs]
    )

demo.queue(default_concurrency_limit=None)
demo.launch()