import torch
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
from src_inference.pipeline import FluxPipeline
from src_inference.lora_helper import set_single_lora, clear_cache
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on:", device)
# Download the OmniConsistency base weights and the per-style LoRAs.
base_path = hf_hub_download(
    repo_id="showlab/OmniConsistency",
    filename="OmniConsistency.safetensors",
    local_dir="./Model"
)

# The repo stores the style LoRAs under a "LoRAs/" subfolder; with
# local_dir="./LoRAs", hf_hub_download preserves that subfolder, so the
# files end up in ./LoRAs/LoRAs/.
LORA_STYLES = [
    "Ghibli",
    "American_Cartoon",
    "Chinese_Ink",
    "Jojo",
    "Line",
    "Rick_Morty",
    "Vector",
]
for style in LORA_STYLES:
    hf_hub_download(
        repo_id="showlab/OmniConsistency",
        filename=f"LoRAs/{style}_rank128_bf16.safetensors",
        local_dir="./LoRAs"
    )
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
).to(device)

# Attach the OmniConsistency weights to the FLUX transformer.
set_single_lora(pipe.transformer, base_path, lora_weights=[1], cond_size=512)

# Load the style LoRAs from the directory populated above (./LoRAs/LoRAs/).
pipe.unload_lora_weights()
for style in LORA_STYLES:
    pipe.load_lora_weights(
        "./LoRAs/LoRAs",
        weight_name=f"{style}_rank128_bf16.safetensors"
    )
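# Optional sketch (not used above): to keep a single style active instead of
# loading all seven LoRAs at once, a diffusers-style pipeline accepts an
# adapter_name at load time and can switch with set_adapters(). This assumes
# the custom FluxPipeline from src_inference exposes the standard diffusers
# LoRA loader mixin:
#
#   pipe.load_lora_weights(
#       "./LoRAs/LoRAs",
#       weight_name="Ghibli_rank128_bf16.safetensors",
#       adapter_name="ghibli",
#   )
#   pipe.set_adapters(["ghibli"], adapter_weights=[1.0])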
def generate_manga(input_image, prompt):
    """Stylize the input character image according to the scene prompt."""
    spatial_image = [input_image.convert("RGB")]
    image = pipe(
        prompt,
        height=1024,
        width=1024,
        guidance_scale=3.5,
        num_inference_steps=25,
        max_sequence_length=512,
        spatial_images=spatial_image,
        subject_images=[],
        cond_size=512,
    ).images[0]
    # Clear cached conditioning state between requests.
    clear_cache(pipe.transformer)
    return image
demo = gr.Interface(
    fn=generate_manga,
    inputs=[
        gr.Image(type="pil", label="Input Character"),
        gr.Textbox(label="Scene Prompt")
    ],
    outputs=gr.Image(label="Generated Manga Frame"),
    title="OmniConsistency Manga Generator"
)

demo.launch()
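# Quick headless test (hypothetical): bypass the Gradio UI and call the
# generator directly, assuming a reference image "character.png" exists in the
# working directory:
#
#   from PIL import Image
#   frame = generate_manga(
#       Image.open("character.png"),
#       "the character walking through a rainy neon-lit street",
#   )
#   frame.save("output.png")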