import os
import random
import time
import torch
import gradio as gr
import numpy as np
import spaces
from diffusers import DiffusionPipeline
from PIL import Image
# --- [Optional Patch] ---------------------------------------------------------
# Some Gradio / gradio-client versions crash on JSON schemas that are bare
# booleans (or that use a boolean "additionalProperties"). This patch wraps the
# original parser and degrades gracefully instead of raising.
import gradio_client.utils
original_json_schema = gradio_client.utils._json_schema_to_python_type
from PIL import ImageOps
def preprocess_image(image):
    # Auto-rotate the image according to its EXIF orientation tag
try:
image = ImageOps.exif_transpose(image)
except Exception as e:
print(f"EXIF 변환 오류: {e}")
# 이미지 크기 조정 (너무 크면 모델이 제대로 처리하지 못할 수 있음)
if max(image.width, image.height) > 1024:
image.thumbnail((1024, 1024), Image.LANCZOS)
    # Ensure the image is in RGB mode
if image.mode != "RGB":
image = image.convert("RGB")
return image
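# Note: thumbnail() resizes in place and preserves aspect ratio; Image.LANCZOS
# remains a valid alias of Image.Resampling.LANCZOS on current Pillow releases.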
def patched_json_schema(schema, defs=None):
# Handle boolean schema directly
if isinstance(schema, bool):
return "bool"
# If 'additionalProperties' is a boolean, replace it with a generic type
try:
if "additionalProperties" in schema and isinstance(schema["additionalProperties"], bool):
schema["additionalProperties"] = {"type": "any"}
except (TypeError, KeyError):
pass
# Attempt to parse normally; fallback to "any" on error
try:
return original_json_schema(schema, defs)
except Exception:
return "any"
gradio_client.utils._json_schema_to_python_type = patched_json_schema
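# Illustrative check (not executed): with the patch in place, boolean schemas
# degrade gracefully instead of raising, e.g.
#   patched_json_schema(True)                            # -> "bool"
#   patched_json_schema({"additionalProperties": True})  # -> parsed type or "any"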
# -----------------------------------------------------------------------------
# ----------------------------- Model Loading ----------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
repo_id = "black-forest-labs/FLUX.1-dev"
adapter_id = "openfree/flux-chatgpt-ghibli-lora"
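# FLUX.1-dev is a ~12B-parameter rectified-flow transformer; loading it in
# bfloat16 roughly halves its memory footprint versus float32. The adapter is
# a Ghibli-style LoRA trained on top of it.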
def load_model_with_retry(max_retries=5):
for attempt in range(max_retries):
try:
print(f"Loading model attempt {attempt+1}/{max_retries}...")
pipeline = DiffusionPipeline.from_pretrained(
repo_id,
torch_dtype=torch.bfloat16,
use_safetensors=True,
resume_download=True
)
print("Base model loaded successfully, now loading LoRA weights...")
pipeline.load_lora_weights(adapter_id)
pipeline = pipeline.to(device)
print("Pipeline is ready!")
return pipeline
except Exception as e:
if attempt < max_retries - 1:
wait_time = 10 * (attempt + 1)
print(f"Error loading model: {e}. Retrying in {wait_time} seconds...")
                time.sleep(wait_time)
else:
raise Exception(f"Failed to load model after {max_retries} attempts: {e}")
pipeline = load_model_with_retry()
# ----------------------------- Inference Function -----------------------------
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
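# FLUX expects width/height divisible by 16; the sliders below step by 32,
# which satisfies that. MAX_SEED stays in int32 range so seeds round-trip
# cleanly through the UI.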
@spaces.GPU(duration=120)
def inference(
prompt: str,
seed: int,
randomize_seed: bool,
width: int,
height: int,
guidance_scale: float,
num_inference_steps: int,
lora_scale: float,
):
# If "randomize_seed" is selected, choose a random seed
if randomize_seed:
seed = random.randint(0, MAX_SEED)
generator = torch.Generator(device=device).manual_seed(seed)
print(f"Running inference with prompt: {prompt}")
try:
image = pipeline(
prompt=prompt,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
width=width,
height=height,
generator=generator,
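            # FLUX reads the LoRA strength per call from
            # joint_attention_kwargs["scale"], so no weight fusing is needed.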
joint_attention_kwargs={"scale": lora_scale},
).images[0]
return image, seed
except Exception as e:
print(f"Error during inference: {e}")
# Return a red error image of the specified size and the used seed
error_img = Image.new('RGB', (width, height), color='red')
return error_img, seed
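# Standalone usage sketch (illustrative, assumes the pipeline above is loaded):
#   img, used_seed = inference("a castle above the clouds, ghibli style",
#                              seed=0, randomize_seed=True, width=768, height=768,
#                              guidance_scale=3.5, num_inference_steps=30,
#                              lora_scale=1.0)
#   img.save("out.png")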
# ----------------------------- Florence-2 Captioner ---------------------------
import subprocess
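# Florence-2's remote code can take advantage of flash-attn when available.
# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE makes the install skip compiling CUDA
# kernels, which would otherwise time out on Spaces build machines.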
try:
    subprocess.run(
        'pip install flash-attn --no-build-isolation',
        # Merge with os.environ: passing a bare dict would wipe PATH and
        # break the pip invocation itself.
        env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
        shell=True,
        check=True,  # raise on a non-zero exit so the warning below fires
    )
except Exception as e:
print(f"Warning: Could not install flash-attn: {e}")
from transformers import AutoProcessor, AutoModelForCausalLM
# Function to safely load models
def load_caption_model(model_name):
try:
model = AutoModelForCausalLM.from_pretrained(
model_name, trust_remote_code=True
).eval()
processor = AutoProcessor.from_pretrained(
model_name, trust_remote_code=True
)
return model, processor
except Exception as e:
print(f"Error loading caption model {model_name}: {e}")
return None, None
# Pre-load models and processors
print("Loading captioning models...")
default_caption_model = 'microsoft/Florence-2-large'
models = {}
processors = {}
# Try to load the default model
default_model, default_processor = load_caption_model(default_caption_model)
if default_model is not None and default_processor is not None:
models[default_caption_model] = default_model
processors[default_caption_model] = default_processor
print(f"Successfully loaded default caption model: {default_caption_model}")
else:
# Fallback to simpler model
fallback_model = 'gokaygokay/Florence-2-Flux'
fallback_model_obj, fallback_processor = load_caption_model(fallback_model)
if fallback_model_obj is not None and fallback_processor is not None:
models[fallback_model] = fallback_model_obj
processors[fallback_model] = fallback_processor
default_caption_model = fallback_model
print(f"Loaded fallback caption model: {fallback_model}")
else:
print("WARNING: Failed to load any caption model!")
@spaces.GPU
def caption_image(image, model_name=default_caption_model):
"""
Runs the selected Florence-2 model to generate a detailed caption.
"""
    print(f"Starting caption generation with model: {model_name}")
    # Gradio may hand us either a PIL image or a NumPy array
    if isinstance(image, Image.Image):
        pil_image = image
    elif isinstance(image, np.ndarray):
        pil_image = Image.fromarray(image)
    else:
        print(f"Unexpected image type: {type(image)}")
        return "Error: Unsupported image type"
# Convert input to RGB if needed
if pil_image.mode != "RGB":
pil_image = pil_image.convert("RGB")
# Check if model is available
if model_name not in models or model_name not in processors:
available_models = list(models.keys())
if available_models:
model_name = available_models[0]
print(f"Requested model not available, using: {model_name}")
else:
return "Error: No caption models available"
    model = models[model_name].to(device)  # move to GPU inside the @spaces.GPU context
    processor = processors[model_name]
task_prompt = "<DESCRIPTION>"
user_prompt = task_prompt + "Describe this image in great detail."
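    # "<DESCRIPTION>" is the task token used by Florence-2-Flux style fine-tunes;
    # the base microsoft/Florence-2-large model is typically driven with
    # "<MORE_DETAILED_CAPTION>" instead.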
try:
        inputs = processor(text=user_prompt, images=pil_image, return_tensors="pt")
        inputs = inputs.to(device)  # keep tensors on the same device as the model
generated_ids = model.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
max_new_tokens=1024,
num_beams=3,
repetition_penalty=1.10,
)
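        # skip_special_tokens=False keeps Florence-2's task tokens, which
        # post_process_generation relies on to parse the raw output.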
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
parsed_answer = processor.post_process_generation(
generated_text, task=task_prompt, image_size=(pil_image.width, pil_image.height)
)
# Extract the caption
caption = parsed_answer.get("<DESCRIPTION>", "")
print(f"Generated caption: {caption}")
return caption
except Exception as e:
print(f"Error during captioning: {e}")
return f"Error generating caption: {str(e)}"
# --------- Process uploaded image and generate Ghibli style image ---------
@spaces.GPU(duration=120)
def process_uploaded_image(
image,
seed,
randomize_seed,
width,
height,
guidance_scale,
num_inference_steps,
lora_scale
):
if image is None:
print("No image provided")
return None, None, "No image provided", "No image provided"
print("Starting image processing workflow")
# Step 1: Generate caption from the uploaded image
try:
caption = caption_image(image)
if caption.startswith("Error:"):
print(f"Captioning failed: {caption}")
# Use a default caption as fallback
caption = "A beautiful scene"
except Exception as e:
print(f"Exception during captioning: {e}")
caption = "A beautiful scene"
# Step 2: Append "ghibli style" to the caption
ghibli_prompt = f"{caption}, ghibli style"
print(f"Final prompt for Ghibli generation: {ghibli_prompt}")
# Step 3: Generate Ghibli-style image based on the caption
try:
generated_image, used_seed = inference(
prompt=ghibli_prompt,
seed=seed,
randomize_seed=randomize_seed,
width=width,
height=height,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
lora_scale=lora_scale
)
print(f"Image generation complete with seed: {used_seed}")
return generated_image, used_seed, caption, ghibli_prompt
except Exception as e:
print(f"Error generating image: {e}")
error_img = Image.new('RGB', (width, height), color='red')
return error_img, seed, caption, ghibli_prompt
# Define Ghibli Studio Theme
ghibli_theme = gr.themes.Soft(
primary_hue="indigo",
secondary_hue="blue",
neutral_hue="slate",
font=[gr.themes.GoogleFont("Nunito"), "ui-sans-serif", "sans-serif"],
radius_size=gr.themes.sizes.radius_sm,
).set(
body_background_fill="#f0f9ff",
body_background_fill_dark="#0f172a",
button_primary_background_fill="#6366f1",
button_primary_background_fill_hover="#4f46e5",
button_primary_text_color="#ffffff",
block_title_text_weight="600",
block_border_width="1px",
block_shadow="0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1)",
)
# Custom CSS for enhanced visuals
custom_css = """
.gradio-container {
max-width: 1200px !important;
}
.main-header {
text-align: center;
margin-bottom: 1rem;
font-weight: 800;
font-size: 2.5rem;
background: linear-gradient(90deg, #4338ca, #3b82f6);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
padding: 0.5rem;
}
.tagline {
text-align: center;
font-size: 1.2rem;
margin-bottom: 2rem;
color: #4b5563;
}
.image-preview {
border-radius: 12px;
overflow: hidden;
box-shadow: 0 10px 15px -3px rgb(0 0 0 / 0.1), 0 4px 6px -4px rgb(0 0 0 / 0.1);
}
.panel-box {
border-radius: 12px;
background-color: rgba(255, 255, 255, 0.8);
padding: 1rem;
box-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1);
}
.control-panel {
padding: 1rem;
border-radius: 12px;
background-color: rgba(255, 255, 255, 0.9);
margin-bottom: 1rem;
border: 1px solid #e2e8f0;
}
.section-header {
font-weight: 600;
font-size: 1.1rem;
margin-bottom: 0.5rem;
color: #4338ca;
}
.transform-button {
font-weight: 600 !important;
margin-top: 1rem !important;
}
.footer {
text-align: center;
color: #6b7280;
margin-top: 2rem;
font-size: 0.9rem;
}
.output-panel {
background: linear-gradient(135deg, #f0f9ff, #e0f2fe);
border-radius: 12px;
padding: 1rem;
border: 1px solid #bfdbfe;
}
"""
# ----------------------------- Gradio UI --------------------------------------
with gr.Blocks(analytics_enabled=False, theme=ghibli_theme, css=custom_css) as demo:
gr.HTML(
"""
<div class="main-header">Open Ghibli Studio</div>
<div class="tagline">Transform your photos into magical Ghibli-inspired artwork</div>
"""
)
# Background image for the app
gr.HTML(
"""
<style>
body {
background-image: url('https://i.imgur.com/LxPQPR1.jpg');
background-size: cover;
background-position: center;
background-attachment: fixed;
background-repeat: no-repeat;
background-color: #f0f9ff;
}
@media (max-width: 768px) {
body {
background-size: contain;
}
}
</style>
"""
)
with gr.Row(equal_height=True):
with gr.Column(scale=1):
with gr.Group(elem_classes="panel-box"):
gr.HTML('<div class="section-header">Upload Image</div>')
upload_img = gr.Image(
label="Drop your image here",
type="pil",
elem_classes="image-preview",
height=400
)
with gr.Accordion("Advanced Settings", open=False):
with gr.Group(elem_classes="control-panel"):
gr.HTML('<div class="section-header">Generation Controls</div>')
with gr.Row():
img2img_seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=42,
info="Set a specific seed for reproducible results"
)
img2img_randomize_seed = gr.Checkbox(
label="Randomize Seed",
value=True,
info="Enable to get different results each time"
)
with gr.Group():
gr.HTML('<div class="section-header">Image Size</div>')
with gr.Row():
img2img_width = gr.Slider(
label="Width",
minimum=256,
maximum=MAX_IMAGE_SIZE,
step=32,
value=1024,
info="Image width in pixels"
)
img2img_height = gr.Slider(
label="Height",
minimum=256,
maximum=MAX_IMAGE_SIZE,
step=32,
value=1024,
info="Image height in pixels"
)
with gr.Group():
gr.HTML('<div class="section-header">Generation Parameters</div>')
with gr.Row():
img2img_guidance_scale = gr.Slider(
label="Guidance Scale",
minimum=0.0,
maximum=10.0,
step=0.1,
value=3.5,
info="Higher values follow the prompt more closely"
)
img2img_steps = gr.Slider(
label="Steps",
minimum=1,
maximum=50,
step=1,
value=30,
info="More steps = more detailed but slower generation"
)
img2img_lora_scale = gr.Slider(
label="Ghibli Style Strength",
minimum=0.0,
maximum=1.0,
step=0.05,
value=1.0,
info="Controls the intensity of the Ghibli style"
)
transform_button = gr.Button("Transform to Ghibli Style", variant="primary", elem_classes="transform-button")
with gr.Column(scale=1):
with gr.Group(elem_classes="output-panel"):
gr.HTML('<div class="section-header">Ghibli Magic Result</div>')
ghibli_output_image = gr.Image(
label="Generated Ghibli Style Image",
elem_classes="image-preview",
height=400
)
ghibli_output_seed = gr.Number(label="Seed Used", interactive=False)
# Debug elements
with gr.Accordion("Image Details", open=False):
extracted_caption = gr.Textbox(
label="Detected Image Content",
placeholder="The AI will analyze your image and describe it here...",
info="AI-generated description of your uploaded image"
)
ghibli_prompt = gr.Textbox(
label="Generation Prompt",
placeholder="The prompt used to create your Ghibli image will appear here...",
info="Final prompt used for the Ghibli transformation"
)
gr.HTML(
"""
<div class="footer">
<p>Open Ghibli Studio uses AI to transform your images into Ghibli-inspired artwork.</p>
<p>Powered by FLUX.1 and Florence-2 models.</p>
</div>
"""
)
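    # Event wiring: the upload event and the Transform button share one handler,
    # so dropping an image auto-runs the pipeline and the button re-runs it with
    # the current Advanced Settings.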
# Auto-process when image is uploaded
upload_img.upload(
process_uploaded_image,
inputs=[
upload_img,
img2img_seed,
img2img_randomize_seed,
img2img_width,
img2img_height,
img2img_guidance_scale,
img2img_steps,
img2img_lora_scale,
],
outputs=[
ghibli_output_image,
ghibli_output_seed,
extracted_caption,
ghibli_prompt,
]
)
# Manual process button
transform_button.click(
process_uploaded_image,
inputs=[
upload_img,
img2img_seed,
img2img_randomize_seed,
img2img_width,
img2img_height,
img2img_guidance_scale,
img2img_steps,
img2img_lora_scale,
],
outputs=[
ghibli_output_image,
ghibli_output_seed,
extracted_caption,
ghibli_prompt,
]
)
demo.launch(debug=True)