import os

import torch
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download

from src_inference.pipeline import FluxPipeline
from src_inference.lora_helper import set_single_lora, clear_cache
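
# Run on the GPU when available; FLUX.1-dev inference on CPU will be extremely slow.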
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on:", device)
# Download the OmniConsistency base weights and the per-style LoRAs from the Hub.
base_path = hf_hub_download(
    repo_id="showlab/OmniConsistency",
    filename="OmniConsistency.safetensors",
    local_dir="./Model",
)

LORA_FILES = [
    "Ghibli_rank128_bf16.safetensors",
    "American_Cartoon_rank128_bf16.safetensors",
    "Chinese_Ink_rank128_bf16.safetensors",
    "Jojo_rank128_bf16.safetensors",
    "Line_rank128_bf16.safetensors",
    "Rick_Morty_rank128_bf16.safetensors",
    "Vector_rank128_bf16.safetensors",
]

# hf_hub_download returns the resolved local path; keep the paths so each LoRA can
# be loaded from wherever it actually lands on disk.
lora_paths = [
    hf_hub_download(
        repo_id="showlab/OmniConsistency",
        filename=f"LoRAs/{name}",
        local_dir="./LoRAs",
    )
    for name in LORA_FILES
]
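
# Build the FLUX.1-dev pipeline (bfloat16 on GPU, float32 on CPU).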
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
).to(device)

# Inject OmniConsistency into the transformer, then load each downloaded style LoRA
# from its resolved local directory.
set_single_lora(pipe.transformer, base_path, lora_weights=[1], cond_size=512)

pipe.unload_lora_weights()
for lora_path in lora_paths:
    pipe.load_lora_weights(
        os.path.dirname(lora_path),
        weight_name=os.path.basename(lora_path),
    )
def generate_manga(input_image, prompt):
    """Render the uploaded image in the loaded style(s) according to the scene prompt."""
    # The input image is passed to OmniConsistency as the spatial conditioning image.
    spatial_image = [input_image.convert("RGB")]

    image = pipe(
        prompt,
        height=1024,
        width=1024,
        guidance_scale=3.5,
        num_inference_steps=25,
        max_sequence_length=512,
        spatial_images=spatial_image,
        subject_images=[],
        cond_size=512,
    ).images[0]

    # Clear the conditioning cache on the transformer before the next request.
    clear_cache(pipe.transformer)
    return image
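
# Minimal Gradio UI: one character image and one scene prompt in, a single
# 1024x1024 stylized frame out.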
demo = gr.Interface(
    fn=generate_manga,
    inputs=[
        gr.Image(type="pil", label="Input Character"),
        gr.Textbox(label="Scene Prompt"),
    ],
    outputs=gr.Image(label="Generated Manga Frame"),
    title="OmniConsistency Manga Generator",
)
demo.launch()