# Hugging Face Spaces page status at capture time: Runtime error
import torch | |
from PIL import Image | |
import gradio as gr | |
from huggingface_hub import hf_hub_download | |
from src_inference.pipeline import FluxPipeline | |
from src_inference.lora_helper import set_single_lora, clear_cache | |
# Select the compute device: the GPU when PyTorch can see CUDA, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Running on:", device)
# Download the OmniConsistency weights and style LoRAs, then build the FLUX pipeline.
import os

OMNI_REPO_ID = "showlab/OmniConsistency"

# Style LoRA files stored under the "LoRAs/" folder of OMNI_REPO_ID.
STYLE_LORA_FILES = (
    "Ghibli_rank128_bf16.safetensors",
    "American_Cartoon_rank128_bf16.safetensors",
    "Chinese_Ink_rank128_bf16.safetensors",
    "Jojo_rank128_bf16.safetensors",
    "Line_rank128_bf16.safetensors",
    "Rick_Morty_rank128_bf16.safetensors",
    "Vector_rank128_bf16.safetensors",
)

# OmniConsistency weights, attached to the transformer by set_single_lora below.
base_path = hf_hub_download(
    repo_id=OMNI_REPO_ID,
    filename="OmniConsistency.safetensors",
    local_dir="./Model",
)

# hf_hub_download keeps the repo-relative path ("LoRAs/<name>") underneath
# local_dir, so these files land in ./LoRAs/LoRAs/, not ./LoRAs/. Keep the
# path each call returns instead of hard-coding a directory; the previous
# hard-coded "./LoRAs" lookup did not contain the downloaded files.
lora_paths = {
    name: hf_hub_download(
        repo_id=OMNI_REPO_ID,
        filename=f"LoRAs/{name}",
        local_dir="./LoRAs",
    )
    for name in STYLE_LORA_FILES
}

# bfloat16 on GPU, float32 on CPU.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
).to(device)

set_single_lora(pipe.transformer, base_path, lora_weights=[1], cond_size=512)
# NOTE(review): presumably this clears only diffusers-managed LoRA adapters
# and leaves what set_single_lora installed intact -- confirm.
pipe.unload_lora_weights()

# NOTE(review): all style LoRAs are loaded into the same pipeline; which
# adapters are active at inference follows diffusers/PEFT defaults -- confirm,
# and consider pipe.set_adapters(...) or a style selector in the UI.
for name, path in lora_paths.items():
    lora_dir, weight_name = os.path.split(path)
    pipe.load_lora_weights(
        lora_dir,
        weight_name=weight_name,
        # Explicit name (file stem) so a style can be addressed later.
        adapter_name=os.path.splitext(name)[0],
    )
def generate_manga(input_image, prompt):
    """Generate a stylized frame from an input character image and a scene prompt.

    Args:
        input_image: PIL image from the ``gr.Image(type="pil")`` input, or None
            when the user submits without uploading an image.
        prompt: Scene description passed to the pipeline as the text prompt.

    Returns:
        The first entry of the pipeline's ``.images`` output, generated at a
        requested size of 1024x1024.

    Raises:
        gr.Error: if no input image was provided.
    """
    if input_image is None:
        # Without this guard the call dies with an opaque AttributeError on
        # None.convert; gr.Error shows a readable message in the UI instead.
        raise gr.Error("Please upload an input character image.")

    spatial_image = [input_image.convert("RGB")]
    try:
        image = pipe(
            prompt,
            height=1024,
            width=1024,
            guidance_scale=3.5,
            num_inference_steps=25,
            max_sequence_length=512,
            spatial_images=spatial_image,
            subject_images=[],
            cond_size=512,
        ).images[0]
    finally:
        # Run clear_cache even when generation raises. It presumably resets
        # per-request state cached on the transformer (confirm); skipping it
        # after a failure could leak into the next request.
        clear_cache(pipe.transformer)
    return image
# Gradio UI: a character image and a scene prompt go in; one generated frame comes out.
character_input = gr.Image(type="pil", label="Input Character")
prompt_input = gr.Textbox(label="Scene Prompt")
frame_output = gr.Image(label="Generated Manga Frame")

demo = gr.Interface(
    fn=generate_manga,
    inputs=[character_input, prompt_input],
    outputs=frame_output,
    title="OmniConsistency Manga Generator",
)

demo.launch()