ginipick committed
Commit 9fe68b1 · verified · 1 Parent(s): f1565b8

Update app.py

Files changed (1): app.py (+550 −13)
app.py CHANGED
@@ -13,21 +13,558 @@ import gradio_client.utils
  original_json_schema = gradio_client.utils._json_schema_to_python_type

  from PIL import ImageOps, ExifTags
- import sys

- # --- [Optional Patch] ---------------------------------------------------------
- # This patch fixes potential JSON schema parsing issues in Gradio/Gradio-Client.
- import gradio_client.utils
- original_json_schema = gradio_client.utils._json_schema_to_python_type
+ def preprocess_image(image):
+     # Rotate the image according to its EXIF orientation metadata
+     try:
+         image = ImageOps.exif_transpose(image)
+     except Exception as e:
+         print(f"EXIF transpose error: {e}")
+
+     # Resize oversized images (very large inputs may not be handled well by the model)
+     if max(image.width, image.height) > 1024:
+         image.thumbnail((1024, 1024), Image.LANCZOS)
+
+     # Ensure the image is in RGB mode
+     if image.mode != "RGB":
+         image = image.convert("RGB")
+
+     return image
+
+ # Applied inside caption_image: pil_image = preprocess_image(pil_image)
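+ # (EXIF transposition matters because phones often record rotation only in
+ # metadata; without it, the captioner would see a sideways image.)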
+
+ def patched_json_schema(schema, defs=None):
+     # Handle boolean schemas directly
+     if isinstance(schema, bool):
+         return "bool"
+
+     # If 'additionalProperties' is a boolean, replace it with a generic type
+     try:
+         if "additionalProperties" in schema and isinstance(schema["additionalProperties"], bool):
+             schema["additionalProperties"] = {"type": "any"}
+     except (TypeError, KeyError):
+         pass
+
+     # Attempt to parse normally; fall back to "any" on error
+     try:
+         return original_json_schema(schema, defs)
+     except Exception:
+         return "any"
+
+ gradio_client.utils._json_schema_to_python_type = patched_json_schema
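+ # Note: this assignment monkey-patches gradio_client globally, so every schema
+ # parsed in this process falls back to "any" instead of raising.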
+ # -----------------------------------------------------------------------------
+
+ # ----------------------------- Model Loading ----------------------------------
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ repo_id = "black-forest-labs/FLUX.1-dev"
+ adapter_id = "openfree/flux-chatgpt-ghibli-lora"
+
+ def load_model_with_retry(max_retries=5):
+     for attempt in range(max_retries):
+         try:
+             print(f"Loading model attempt {attempt+1}/{max_retries}...")
+             pipeline = DiffusionPipeline.from_pretrained(
+                 repo_id,
+                 torch_dtype=torch.bfloat16,
+                 use_safetensors=True,
+                 resume_download=True
+             )
+             print("Base model loaded successfully, now loading LoRA weights...")
+             pipeline.load_lora_weights(adapter_id)
+             pipeline = pipeline.to(device)
+             print("Pipeline is ready!")
+             return pipeline
+         except Exception as e:
+             if attempt < max_retries - 1:
+                 wait_time = 10 * (attempt + 1)
+                 print(f"Error loading model: {e}. Retrying in {wait_time} seconds...")
+                 import time
+                 time.sleep(wait_time)
+             else:
+                 raise Exception(f"Failed to load model after {max_retries} attempts: {e}")

- import ast  # added import; requirements: add albumentations
- script_repr = os.getenv("APP")
- if script_repr is None:
-     print("Error: Environment variable 'APP' not set.")
-     sys.exit(1)
+ pipeline = load_model_with_retry()

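+ # Loading retries back off linearly (10 s, 20 s, ...) to ride out transient Hub
+ # download failures during a cold start.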
+ # ----------------------------- Inference Function -----------------------------
+ MAX_SEED = np.iinfo(np.int32).max
+ MAX_IMAGE_SIZE = 1024
+
+ @spaces.GPU(duration=120)
+ def inference(
+     prompt: str,
+     seed: int,
+     randomize_seed: bool,
+     width: int,
+     height: int,
+     guidance_scale: float,
+     num_inference_steps: int,
+     lora_scale: float,
+ ):
+     # If "randomize_seed" is selected, choose a random seed
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     generator = torch.Generator(device=device).manual_seed(seed)
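+     # A dedicated torch.Generator keeps sampling reproducible for a given seed
+     # without disturbing global random state.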
+
+     print(f"Running inference with prompt: {prompt}")
+
+     try:
+         image = pipeline(
+             prompt=prompt,
+             guidance_scale=guidance_scale,
+             num_inference_steps=num_inference_steps,
+             width=width,
+             height=height,
+             generator=generator,
+             joint_attention_kwargs={"scale": lora_scale},
+         ).images[0]
+         return image, seed
+     except Exception as e:
+         print(f"Error during inference: {e}")
+         # Return a red error image of the specified size and the used seed
+         error_img = Image.new('RGB', (width, height), color='red')
+         return error_img, seed
+
+ # ----------------------------- Florence-2 Captioner ---------------------------
+ import subprocess
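+ # flash-attn is an optional speed-up; FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE is
+ # meant to skip compiling the CUDA kernels so the install cannot stall the Space.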
  try:
-     exec(script_repr)
+     subprocess.run(
+         'pip install flash-attn --no-build-isolation',
+         env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
+         shell=True
+     )
  except Exception as e:
-     print(f"Error executing script: {e}")
-     sys.exit(1)
+     print(f"Warning: Could not install flash-attn: {e}")
+
+ from transformers import AutoProcessor, AutoModelForCausalLM
+
+ # Function to safely load models
+ def load_caption_model(model_name):
+     try:
+         model = AutoModelForCausalLM.from_pretrained(
+             model_name, trust_remote_code=True
+         ).eval()
+         processor = AutoProcessor.from_pretrained(
+             model_name, trust_remote_code=True
+         )
+         return model, processor
+     except Exception as e:
+         print(f"Error loading caption model {model_name}: {e}")
+         return None, None
+
+ # Pre-load models and processors
+ print("Loading captioning models...")
+ default_caption_model = 'microsoft/Florence-2-large'
+ models = {}
+ processors = {}
+
+ # Try to load the default model
+ default_model, default_processor = load_caption_model(default_caption_model)
+ if default_model is not None and default_processor is not None:
+     models[default_caption_model] = default_model
+     processors[default_caption_model] = default_processor
+     print(f"Successfully loaded default caption model: {default_caption_model}")
+ else:
+     # Fall back to a simpler model
+     fallback_model = 'gokaygokay/Florence-2-Flux'
+     fallback_model_obj, fallback_processor = load_caption_model(fallback_model)
+     if fallback_model_obj is not None and fallback_processor is not None:
+         models[fallback_model] = fallback_model_obj
+         processors[fallback_model] = fallback_processor
+         default_caption_model = fallback_model
+         print(f"Loaded fallback caption model: {fallback_model}")
+     else:
+         print("WARNING: Failed to load any caption model!")
+
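+ # Florence-2 checkpoints ship custom modeling code, hence trust_remote_code=True
+ # on both the model and the processor.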
+ @spaces.GPU
+ def caption_image(image, model_name=default_caption_model):
+     """
+     Runs the selected Florence-2 model to generate a detailed caption.
+     """
+     from PIL import Image as PILImage
+     import numpy as np
+
+     print(f"Starting caption generation with model: {model_name}")
+
+     # Handle case where image is already a PIL image
+     if isinstance(image, PILImage.Image):
+         pil_image = image
+     else:
+         # Convert numpy array to PIL
+         if isinstance(image, np.ndarray):
+             pil_image = PILImage.fromarray(image)
+         else:
+             print(f"Unexpected image type: {type(image)}")
+             return "Error: Unsupported image type"
+
+     # Convert input to RGB if needed
+     if pil_image.mode != "RGB":
+         pil_image = pil_image.convert("RGB")
+
+     # Check if model is available
+     if model_name not in models or model_name not in processors:
+         available_models = list(models.keys())
+         if available_models:
+             model_name = available_models[0]
+             print(f"Requested model not available, using: {model_name}")
+         else:
+             return "Error: No caption models available"
+
+     model = models[model_name]
+     processor = processors[model_name]
+
+     task_prompt = "<DESCRIPTION>"
+     user_prompt = task_prompt + "Describe this image in great detail."
+
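+     # Florence-2 keys decoding on the task token: post_process_generation below
+     # is called with the same "<DESCRIPTION>" string used to build the prompt.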
+     try:
+         inputs = processor(text=user_prompt, images=pil_image, return_tensors="pt")
+
+         generated_ids = model.generate(
+             input_ids=inputs["input_ids"],
+             pixel_values=inputs["pixel_values"],
+             max_new_tokens=1024,
+             num_beams=3,
+             repetition_penalty=1.10,
+         )
+
+         generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+         parsed_answer = processor.post_process_generation(
+             generated_text, task=task_prompt, image_size=(pil_image.width, pil_image.height)
+         )
+
+         # Extract the caption
+         caption = parsed_answer.get("<DESCRIPTION>", "")
+         print(f"Generated caption: {caption}")
+         return caption
+     except Exception as e:
+         print(f"Error during captioning: {e}")
+         return f"Error generating caption: {str(e)}"
+
+ # --------- Process uploaded image and generate Ghibli style image ---------
+ @spaces.GPU(duration=120)
+ def process_uploaded_image(
+     image,
+     seed,
+     randomize_seed,
+     width,
+     height,
+     guidance_scale,
+     num_inference_steps,
+     lora_scale
+ ):
+     if image is None:
+         print("No image provided")
+         return None, None, "No image provided", "No image provided"
+
+     print("Starting image processing workflow")
+
+     # Step 1: Generate caption from the uploaded image
+     try:
+         caption = caption_image(image)
+         if caption.startswith("Error:"):
+             print(f"Captioning failed: {caption}")
+             # Use a default caption as fallback
+             caption = "A beautiful scene"
+     except Exception as e:
+         print(f"Exception during captioning: {e}")
+         caption = "A beautiful scene"
+
+     # Step 2: Append "ghibli style" to the caption
+     ghibli_prompt = f"{caption}, ghibli style"
+     print(f"Final prompt for Ghibli generation: {ghibli_prompt}")
+
+     # Step 3: Generate Ghibli-style image based on the caption
+     try:
+         generated_image, used_seed = inference(
+             prompt=ghibli_prompt,
+             seed=seed,
+             randomize_seed=randomize_seed,
+             width=width,
+             height=height,
+             guidance_scale=guidance_scale,
+             num_inference_steps=num_inference_steps,
+             lora_scale=lora_scale
+         )
+
+         print(f"Image generation complete with seed: {used_seed}")
+         return generated_image, used_seed, caption, ghibli_prompt
+     except Exception as e:
+         print(f"Error generating image: {e}")
+         error_img = Image.new('RGB', (width, height), color='red')
+         return error_img, seed, caption, ghibli_prompt
+
+ # Define Ghibli Studio Theme
+ ghibli_theme = gr.themes.Soft(
+     primary_hue="indigo",
+     secondary_hue="blue",
+     neutral_hue="slate",
+     font=[gr.themes.GoogleFont("Nunito"), "ui-sans-serif", "sans-serif"],
+     radius_size=gr.themes.sizes.radius_sm,
+ ).set(
+     body_background_fill="#f0f9ff",
+     body_background_fill_dark="#0f172a",
+     button_primary_background_fill="#6366f1",
+     button_primary_background_fill_hover="#4f46e5",
+     button_primary_text_color="#ffffff",
+     block_title_text_weight="600",
+     block_border_width="1px",
+     block_shadow="0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1)",
+ )
+
+ # Custom CSS for enhanced visuals
+ custom_css = """
+ .gradio-container {
+     max-width: 1200px !important;
+ }
+
+ .main-header {
+     text-align: center;
+     margin-bottom: 1rem;
+     font-weight: 800;
+     font-size: 2.5rem;
+     background: linear-gradient(90deg, #4338ca, #3b82f6);
+     -webkit-background-clip: text;
+     -webkit-text-fill-color: transparent;
+     padding: 0.5rem;
+ }
+
+ .tagline {
+     text-align: center;
+     font-size: 1.2rem;
+     margin-bottom: 2rem;
+     color: #4b5563;
+ }
+
+ .image-preview {
+     border-radius: 12px;
+     overflow: hidden;
+     box-shadow: 0 10px 15px -3px rgb(0 0 0 / 0.1), 0 4px 6px -4px rgb(0 0 0 / 0.1);
+ }
+
+ .panel-box {
+     border-radius: 12px;
+     background-color: rgba(255, 255, 255, 0.8);
+     padding: 1rem;
+     box-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1);
+ }
+
+ .control-panel {
+     padding: 1rem;
+     border-radius: 12px;
+     background-color: rgba(255, 255, 255, 0.9);
+     margin-bottom: 1rem;
+     border: 1px solid #e2e8f0;
+ }
+
+ .section-header {
+     font-weight: 600;
+     font-size: 1.1rem;
+     margin-bottom: 0.5rem;
+     color: #4338ca;
+ }
+
+ .transform-button {
+     font-weight: 600 !important;
+     margin-top: 1rem !important;
+ }
+
+ .footer {
+     text-align: center;
+     color: #6b7280;
+     margin-top: 2rem;
+     font-size: 0.9rem;
+ }
+
+ .output-panel {
+     background: linear-gradient(135deg, #f0f9ff, #e0f2fe);
+     border-radius: 12px;
+     padding: 1rem;
+     border: 1px solid #bfdbfe;
+ }
+ """
+
+ # ----------------------------- Gradio UI --------------------------------------
+ with gr.Blocks(analytics_enabled=False, theme=ghibli_theme, css=custom_css) as demo:
+     gr.HTML(
+         """
+         <div class="main-header">Open Ghibli Studio</div>
+         <div class="tagline">Transform your photos into magical Ghibli-inspired artwork</div>
+         """
+     )
+
+     # Background image for the app
+     gr.HTML(
+         """
+         <style>
+         body {
+             background-image: url('https://i.imgur.com/LxPQPR1.jpg');
+             background-size: cover;
+             background-position: center;
+             background-attachment: fixed;
+             background-repeat: no-repeat;
+             background-color: #f0f9ff;
+         }
+         @media (max-width: 768px) {
+             body {
+                 background-size: contain;
+             }
+         }
+         </style>
+         """
+     )
+
+     with gr.Row(equal_height=True):
+         with gr.Column(scale=1):
+             with gr.Group(elem_classes="panel-box"):
+                 gr.HTML('<div class="section-header">Upload Image</div>')
+                 upload_img = gr.Image(
+                     label="Drop your image here",
+                     type="pil",
+                     elem_classes="image-preview",
+                     height=400
+                 )
+
+                 with gr.Accordion("Advanced Settings", open=False):
+                     with gr.Group(elem_classes="control-panel"):
+                         gr.HTML('<div class="section-header">Generation Controls</div>')
+                         with gr.Row():
+                             img2img_seed = gr.Slider(
+                                 label="Seed",
+                                 minimum=0,
+                                 maximum=MAX_SEED,
+                                 step=1,
+                                 value=42,
+                                 info="Set a specific seed for reproducible results"
+                             )
+                             img2img_randomize_seed = gr.Checkbox(
+                                 label="Randomize Seed",
+                                 value=True,
+                                 info="Enable to get different results each time"
+                             )
+
+                     with gr.Group():
+                         gr.HTML('<div class="section-header">Image Size</div>')
+                         with gr.Row():
+                             img2img_width = gr.Slider(
+                                 label="Width",
+                                 minimum=256,
+                                 maximum=MAX_IMAGE_SIZE,
+                                 step=32,
+                                 value=1024,
+                                 info="Image width in pixels"
+                             )
+                             img2img_height = gr.Slider(
+                                 label="Height",
+                                 minimum=256,
+                                 maximum=MAX_IMAGE_SIZE,
+                                 step=32,
+                                 value=1024,
+                                 info="Image height in pixels"
+                             )
+
+                     with gr.Group():
+                         gr.HTML('<div class="section-header">Generation Parameters</div>')
+                         with gr.Row():
+                             img2img_guidance_scale = gr.Slider(
+                                 label="Guidance Scale",
+                                 minimum=0.0,
+                                 maximum=10.0,
+                                 step=0.1,
+                                 value=3.5,
+                                 info="Higher values follow the prompt more closely"
+                             )
+                             img2img_steps = gr.Slider(
+                                 label="Steps",
+                                 minimum=1,
+                                 maximum=50,
+                                 step=1,
+                                 value=30,
+                                 info="More steps = more detailed but slower generation"
+                             )
+
+                         img2img_lora_scale = gr.Slider(
+                             label="Ghibli Style Strength",
+                             minimum=0.0,
+                             maximum=1.0,
+                             step=0.05,
+                             value=1.0,
+                             info="Controls the intensity of the Ghibli style"
+                         )
+
+                 transform_button = gr.Button("Transform to Ghibli Style", variant="primary", elem_classes="transform-button")
+
+         with gr.Column(scale=1):
+             with gr.Group(elem_classes="output-panel"):
+                 gr.HTML('<div class="section-header">Ghibli Magic Result</div>')
+                 ghibli_output_image = gr.Image(
+                     label="Generated Ghibli Style Image",
+                     elem_classes="image-preview",
+                     height=400
+                 )
+                 ghibli_output_seed = gr.Number(label="Seed Used", interactive=False)
+
+                 # Debug elements
+                 with gr.Accordion("Image Details", open=False):
+                     extracted_caption = gr.Textbox(
+                         label="Detected Image Content",
+                         placeholder="The AI will analyze your image and describe it here...",
+                         info="AI-generated description of your uploaded image"
+                     )
+                     ghibli_prompt = gr.Textbox(
+                         label="Generation Prompt",
+                         placeholder="The prompt used to create your Ghibli image will appear here...",
+                         info="Final prompt used for the Ghibli transformation"
+                     )
+
+     gr.HTML(
+         """
+         <div class="footer">
+             <p>Open Ghibli Studio uses AI to transform your images into Ghibli-inspired artwork.</p>
+             <p>Powered by FLUX.1 and Florence-2 models.</p>
+         </div>
+         """
+     )
+
+     # Auto-process when image is uploaded
+     upload_img.upload(
+         process_uploaded_image,
+         inputs=[
+             upload_img,
+             img2img_seed,
+             img2img_randomize_seed,
+             img2img_width,
+             img2img_height,
+             img2img_guidance_scale,
+             img2img_steps,
+             img2img_lora_scale,
+         ],
+         outputs=[
+             ghibli_output_image,
+             ghibli_output_seed,
+             extracted_caption,
+             ghibli_prompt,
+         ]
+     )
+
+     # Manual process button
+     transform_button.click(
+         process_uploaded_image,
+         inputs=[
+             upload_img,
+             img2img_seed,
+             img2img_randomize_seed,
+             img2img_width,
+             img2img_height,
+             img2img_guidance_scale,
+             img2img_steps,
+             img2img_lora_scale,
+         ],
+         outputs=[
+             ghibli_output_image,
+             ghibli_output_seed,
+             extracted_caption,
+             ghibli_prompt,
+         ]
+     )
+
+ demo.launch(debug=True)