import os
import random
import torch
import gradio as gr
import numpy as np
import spaces
from diffusers import DiffusionPipeline
from PIL import Image
# --- [Optional Patch] ---------------------------------------------------------
# This patch fixes potential JSON schema parsing issues in Gradio/Gradio-Client.
import gradio_client.utils

original_json_schema = gradio_client.utils._json_schema_to_python_type
from PIL import ImageOps

def preprocess_image(image):
    # Rotate the image according to its EXIF orientation tag
    try:
        image = ImageOps.exif_transpose(image)
    except Exception as e:
        print(f"EXIF transpose error: {e}")
    # Cap the image size (very large inputs may not be handled well by the model)
    if max(image.width, image.height) > 1024:
        image.thumbnail((1024, 1024), Image.LANCZOS)
    # Ensure the image is in RGB mode
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image
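# Sketch of the helper in isolation (hypothetical local file path):
#   img = preprocess_image(Image.open("photo.jpg"))
# It is applied in process_uploaded_image below to normalize uploads before
# captioning and generation.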
def patched_json_schema(schema, defs=None):
    # Handle boolean schema directly
    if isinstance(schema, bool):
        return "bool"
    # If 'additionalProperties' is a boolean, replace it with a generic type
    try:
        if "additionalProperties" in schema and isinstance(schema["additionalProperties"], bool):
            schema["additionalProperties"] = {"type": "any"}
    except (TypeError, KeyError):
        pass
    # Attempt to parse normally; fallback to "any" on error
    try:
        return original_json_schema(schema, defs)
    except Exception:
        return "any"

gradio_client.utils._json_schema_to_python_type = patched_json_schema
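# Minimal sanity check of the patch: bare boolean schemas, which previously
# crashed the client-side type mapper, now resolve to "bool".
assert patched_json_schema(True) == "bool"
assert patched_json_schema(False) == "bool"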
# -----------------------------------------------------------------------------

# ----------------------------- Model Loading ----------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu" | |
repo_id = "black-forest-labs/FLUX.1-dev" | |
adapter_id = "openfree/flux-chatgpt-ghibli-lora" | |
def load_model_with_retry(max_retries=5): | |
for attempt in range(max_retries): | |
try: | |
print(f"Loading model attempt {attempt+1}/{max_retries}...") | |
pipeline = DiffusionPipeline.from_pretrained( | |
repo_id, | |
torch_dtype=torch.bfloat16, | |
use_safetensors=True, | |
resume_download=True | |
) | |
print("Base model loaded successfully, now loading LoRA weights...") | |
pipeline.load_lora_weights(adapter_id) | |
pipeline = pipeline.to(device) | |
print("Pipeline is ready!") | |
return pipeline | |
except Exception as e: | |
if attempt < max_retries - 1: | |
wait_time = 10 * (attempt + 1) | |
print(f"Error loading model: {e}. Retrying in {wait_time} seconds...") | |
import time | |
time.sleep(wait_time) | |
else: | |
raise Exception(f"Failed to load model after {max_retries} attempts: {e}") | |
pipeline = load_model_with_retry() | |
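# Optional knob (assumption, not part of the original app): on smaller GPUs the
# pipeline can trade speed for VRAM with `pipeline.enable_model_cpu_offload()`.
# It is left disabled here since the Space targets ZeroGPU-allocated hardware.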
# ----------------------------- Inference Function -----------------------------
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
@spaces.GPU  # request a ZeroGPU slot for this call (the `spaces` import implies a ZeroGPU Space)
def inference(
    prompt: str,
    seed: int,
    randomize_seed: bool,
    width: int,
    height: int,
    guidance_scale: float,
    num_inference_steps: int,
    lora_scale: float,
):
    # If "randomize_seed" is selected, choose a random seed
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)
    print(f"Running inference with prompt: {prompt}")
    try:
        image = pipeline(
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            generator=generator,
            joint_attention_kwargs={"scale": lora_scale},
        ).images[0]
        return image, seed
    except Exception as e:
        print(f"Error during inference: {e}")
        # Return a red error image of the specified size and the used seed
        error_img = Image.new('RGB', (width, height), color='red')
        return error_img, seed
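# Reproducibility note: with randomize_seed=False and identical parameters,
# e.g. inference("a quiet village, ghibli style", 42, False, 1024, 1024, 3.5, 30, 1.0),
# the explicitly seeded torch.Generator makes the output deterministic for a
# given model and hardware setup.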
# ----------------------------- Florence-2 Captioner ---------------------------
import subprocess

try:
    # FLASH_ATTENTION_SKIP_CUDA_BUILD avoids compiling CUDA kernels from source;
    # os.environ is preserved so `pip` stays on PATH.
    subprocess.run(
        'pip install flash-attn --no-build-isolation',
        env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
        shell=True,
    )
except Exception as e:
    print(f"Warning: Could not install flash-attn: {e}")
from transformers import AutoProcessor, AutoModelForCausalLM

# Function to safely load models
def load_caption_model(model_name):
    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_name, trust_remote_code=True
        ).eval()
        processor = AutoProcessor.from_pretrained(
            model_name, trust_remote_code=True
        )
        return model, processor
    except Exception as e:
        print(f"Error loading caption model {model_name}: {e}")
        return None, None
# Pre-load models and processors
print("Loading captioning models...")
default_caption_model = 'microsoft/Florence-2-large'
models = {}
processors = {}

# Try to load the default model
default_model, default_processor = load_caption_model(default_caption_model)
if default_model is not None and default_processor is not None:
    models[default_caption_model] = default_model
    processors[default_caption_model] = default_processor
    print(f"Successfully loaded default caption model: {default_caption_model}")
else:
    # Fall back to a lighter model
    fallback_model = 'gokaygokay/Florence-2-Flux'
    fallback_model_obj, fallback_processor = load_caption_model(fallback_model)
    if fallback_model_obj is not None and fallback_processor is not None:
        models[fallback_model] = fallback_model_obj
        processors[fallback_model] = fallback_processor
        default_caption_model = fallback_model
        print(f"Loaded fallback caption model: {fallback_model}")
    else:
        print("WARNING: Failed to load any caption model!")
def caption_image(image, model_name=default_caption_model):
    """
    Runs the selected Florence-2 model to generate a detailed caption.
    """
    print(f"Starting caption generation with model: {model_name}")
    # Accept either a PIL image or a numpy array
    if isinstance(image, Image.Image):
        pil_image = image
    elif isinstance(image, np.ndarray):
        pil_image = Image.fromarray(image)
    else:
        print(f"Unexpected image type: {type(image)}")
        return "Error: Unsupported image type"
    # Convert input to RGB if needed
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")
    # Fall back to any available model if the requested one is missing
    if model_name not in models or model_name not in processors:
        available_models = list(models.keys())
        if available_models:
            model_name = available_models[0]
            print(f"Requested model not available, using: {model_name}")
        else:
            return "Error: No caption models available"
    model = models[model_name]
    processor = processors[model_name]
    task_prompt = "<DESCRIPTION>"
    user_prompt = task_prompt + "Describe this image in great detail."
    try:
        inputs = processor(text=user_prompt, images=pil_image, return_tensors="pt")
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            num_beams=3,
            repetition_penalty=1.10,
        )
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed_answer = processor.post_process_generation(
            generated_text, task=task_prompt, image_size=(pil_image.width, pil_image.height)
        )
        # Extract the caption
        caption = parsed_answer.get("<DESCRIPTION>", "")
        print(f"Generated caption: {caption}")
        return caption
    except Exception as e:
        print(f"Error during captioning: {e}")
        return f"Error generating caption: {str(e)}"
# --------- Process uploaded image and generate Ghibli style image ---------
def process_uploaded_image(
    image,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    lora_scale,
):
    if image is None:
        print("No image provided")
        return None, None, "No image provided", "No image provided"
    print("Starting image processing workflow")
    # Normalize the upload (EXIF orientation, size cap, RGB) before captioning
    image = preprocess_image(image)
    # Step 1: Generate caption from the uploaded image
    try:
        caption = caption_image(image)
        if caption.startswith("Error:"):
            print(f"Captioning failed: {caption}")
            # Use a default caption as fallback
            caption = "A beautiful scene"
    except Exception as e:
        print(f"Exception during captioning: {e}")
        caption = "A beautiful scene"
    # Step 2: Append "ghibli style" to the caption
    ghibli_prompt = f"{caption}, ghibli style"
    print(f"Final prompt for Ghibli generation: {ghibli_prompt}")
    # Step 3: Generate Ghibli-style image based on the caption
    try:
        generated_image, used_seed = inference(
            prompt=ghibli_prompt,
            seed=seed,
            randomize_seed=randomize_seed,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            lora_scale=lora_scale,
        )
        print(f"Image generation complete with seed: {used_seed}")
        return generated_image, used_seed, caption, ghibli_prompt
    except Exception as e:
        print(f"Error generating image: {e}")
        error_img = Image.new('RGB', (width, height), color='red')
        return error_img, seed, caption, ghibli_prompt
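# End-to-end flow: upload -> preprocess -> Florence-2 caption -> caption +
# ", ghibli style" -> FLUX + LoRA generation. The red placeholder image acts
# as a visible error signal instead of raising inside a Gradio callback.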
# Define Ghibli Studio Theme
ghibli_theme = gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    font=[gr.themes.GoogleFont("Nunito"), "ui-sans-serif", "sans-serif"],
    radius_size=gr.themes.sizes.radius_sm,
).set(
    body_background_fill="#f0f9ff",
    body_background_fill_dark="#0f172a",
    button_primary_background_fill="#6366f1",
    button_primary_background_fill_hover="#4f46e5",
    button_primary_text_color="#ffffff",
    block_title_text_weight="600",
    block_border_width="1px",
    block_shadow="0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1)",
)
# Custom CSS for enhanced visuals
custom_css = """
.gradio-container {
    max-width: 1200px !important;
}
.main-header {
    text-align: center;
    margin-bottom: 1rem;
    font-weight: 800;
    font-size: 2.5rem;
    background: linear-gradient(90deg, #4338ca, #3b82f6);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    padding: 0.5rem;
}
.tagline {
    text-align: center;
    font-size: 1.2rem;
    margin-bottom: 2rem;
    color: #4b5563;
}
.image-preview {
    border-radius: 12px;
    overflow: hidden;
    box-shadow: 0 10px 15px -3px rgb(0 0 0 / 0.1), 0 4px 6px -4px rgb(0 0 0 / 0.1);
}
.panel-box {
    border-radius: 12px;
    background-color: rgba(255, 255, 255, 0.8);
    padding: 1rem;
    box-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1);
}
.control-panel {
    padding: 1rem;
    border-radius: 12px;
    background-color: rgba(255, 255, 255, 0.9);
    margin-bottom: 1rem;
    border: 1px solid #e2e8f0;
}
.section-header {
    font-weight: 600;
    font-size: 1.1rem;
    margin-bottom: 0.5rem;
    color: #4338ca;
}
.transform-button {
    font-weight: 600 !important;
    margin-top: 1rem !important;
}
.footer {
    text-align: center;
    color: #6b7280;
    margin-top: 2rem;
    font-size: 0.9rem;
}
.output-panel {
    background: linear-gradient(135deg, #f0f9ff, #e0f2fe);
    border-radius: 12px;
    padding: 1rem;
    border: 1px solid #bfdbfe;
}
"""
# ----------------------------- Gradio UI --------------------------------------
with gr.Blocks(analytics_enabled=False, theme=ghibli_theme, css=custom_css) as demo:
    gr.HTML(
        """
        <div class="main-header">Open Ghibli Studio</div>
        <div class="tagline">Transform your photos into magical Ghibli-inspired artwork</div>
        """
    )
    # Background image for the app
    gr.HTML(
        """
        <style>
        body {
            background-image: url('https://i.imgur.com/LxPQPR1.jpg');
            background-size: cover;
            background-position: center;
            background-attachment: fixed;
            background-repeat: no-repeat;
            background-color: #f0f9ff;
        }
        @media (max-width: 768px) {
            body {
                background-size: contain;
            }
        }
        </style>
        """
    )
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            with gr.Group(elem_classes="panel-box"):
                gr.HTML('<div class="section-header">Upload Image</div>')
                upload_img = gr.Image(
                    label="Drop your image here",
                    type="pil",
                    elem_classes="image-preview",
                    height=400
                )
                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Group(elem_classes="control-panel"):
                        gr.HTML('<div class="section-header">Generation Controls</div>')
                        with gr.Row():
                            img2img_seed = gr.Slider(
                                label="Seed",
                                minimum=0,
                                maximum=MAX_SEED,
                                step=1,
                                value=42,
                                info="Set a specific seed for reproducible results"
                            )
                            img2img_randomize_seed = gr.Checkbox(
                                label="Randomize Seed",
                                value=True,
                                info="Enable to get different results each time"
                            )
                        with gr.Group():
                            gr.HTML('<div class="section-header">Image Size</div>')
                            with gr.Row():
                                img2img_width = gr.Slider(
                                    label="Width",
                                    minimum=256,
                                    maximum=MAX_IMAGE_SIZE,
                                    step=32,
                                    value=1024,
                                    info="Image width in pixels"
                                )
                                img2img_height = gr.Slider(
                                    label="Height",
                                    minimum=256,
                                    maximum=MAX_IMAGE_SIZE,
                                    step=32,
                                    value=1024,
                                    info="Image height in pixels"
                                )
                        with gr.Group():
                            gr.HTML('<div class="section-header">Generation Parameters</div>')
                            with gr.Row():
                                img2img_guidance_scale = gr.Slider(
                                    label="Guidance Scale",
                                    minimum=0.0,
                                    maximum=10.0,
                                    step=0.1,
                                    value=3.5,
                                    info="Higher values follow the prompt more closely"
                                )
                                img2img_steps = gr.Slider(
                                    label="Steps",
                                    minimum=1,
                                    maximum=50,
                                    step=1,
                                    value=30,
                                    info="More steps = more detailed but slower generation"
                                )
                            img2img_lora_scale = gr.Slider(
                                label="Ghibli Style Strength",
                                minimum=0.0,
                                maximum=1.0,
                                step=0.05,
                                value=1.0,
                                info="Controls the intensity of the Ghibli style"
                            )
            transform_button = gr.Button("Transform to Ghibli Style", variant="primary", elem_classes="transform-button")
        with gr.Column(scale=1):
            with gr.Group(elem_classes="output-panel"):
                gr.HTML('<div class="section-header">Ghibli Magic Result</div>')
                ghibli_output_image = gr.Image(
                    label="Generated Ghibli Style Image",
                    elem_classes="image-preview",
                    height=400
                )
                ghibli_output_seed = gr.Number(label="Seed Used", interactive=False)
                # Debug elements
                with gr.Accordion("Image Details", open=False):
                    extracted_caption = gr.Textbox(
                        label="Detected Image Content",
                        placeholder="The AI will analyze your image and describe it here...",
                        info="AI-generated description of your uploaded image"
                    )
                    ghibli_prompt = gr.Textbox(
                        label="Generation Prompt",
                        placeholder="The prompt used to create your Ghibli image will appear here...",
                        info="Final prompt used for the Ghibli transformation"
                    )
    gr.HTML(
        """
        <div class="footer">
            <p>Open Ghibli Studio uses AI to transform your images into Ghibli-inspired artwork.</p>
            <p>Powered by FLUX.1 and Florence-2 models.</p>
        </div>
        """
    )
    # Auto-process when image is uploaded
    upload_img.upload(
        process_uploaded_image,
        inputs=[
            upload_img,
            img2img_seed,
            img2img_randomize_seed,
            img2img_width,
            img2img_height,
            img2img_guidance_scale,
            img2img_steps,
            img2img_lora_scale,
        ],
        outputs=[
            ghibli_output_image,
            ghibli_output_seed,
            extracted_caption,
            ghibli_prompt,
        ]
    )

    # Manual process button
    transform_button.click(
        process_uploaded_image,
        inputs=[
            upload_img,
            img2img_seed,
            img2img_randomize_seed,
            img2img_width,
            img2img_height,
            img2img_guidance_scale,
            img2img_steps,
            img2img_lora_scale,
        ],
        outputs=[
            ghibli_output_image,
            ghibli_output_seed,
            extracted_caption,
            ghibli_prompt,
        ]
    )
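# Note (assumption, not in the original): on Spaces a request queue can smooth
# concurrent use, e.g. `demo.queue().launch(debug=True)`; the plain launch
# below preserves the original behavior.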
demo.launch(debug=True)