File size: 3,050 Bytes
2a6b9ff
 
 
 
 
 
 
 
 
 
 
 
 
2463cdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a6b9ff
 
 
 
 
 
 
 
2463cdb
 
 
 
 
 
 
 
 
 
2a6b9ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import torch
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download

from src_inference.pipeline import FluxPipeline
from src_inference.lora_helper import set_single_lora, clear_cache

# Prefer the GPU when torch can see one; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Running on:", device)

# Download the OmniConsistency base adapter and every style LoRA once at
# startup so generation never blocks on a network fetch later.
base_path = hf_hub_download(
    repo_id="showlab/OmniConsistency",
    filename="OmniConsistency.safetensors",
    local_dir="./Model",
)

# Style LoRAs shipped in the same repo; each gets cached under ./LoRAs.
# The original code repeated the same download call seven times, clobbering
# `lora_path` on each iteration; a loop keeps the identical side effects
# (all files fetched, `lora_path` ends pointing at the Vector LoRA).
LORA_STYLES = [
    "Ghibli",
    "American_Cartoon",
    "Chinese_Ink",
    "Jojo",
    "Line",
    "Rick_Morty",
    "Vector",
]
for style in LORA_STYLES:
    lora_path = hf_hub_download(
        repo_id="showlab/OmniConsistency",
        filename=f"LoRAs/{style}_rank128_bf16.safetensors",
        local_dir="./LoRAs",
    )

# Build the FLUX.1-dev pipeline: bf16 on GPU, fp32 on CPU (bf16 kernels are
# slow or unsupported on many CPUs).
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
).to(device)

# Attach the OmniConsistency adapter to the transformer; cond_size must match
# the value passed at inference time.
set_single_lora(pipe.transformer, base_path, lora_weights=[1], cond_size=512)

# Drop any previously-loaded LoRA weights, then load the full style set.
# The original repeated load_lora_weights seven times; a loop is equivalent.
pipe.unload_lora_weights()
for weight_name in (
    "Ghibli_rank128_bf16.safetensors",
    "American_Cartoon_rank128_bf16.safetensors",
    "Chinese_Ink_rank128_bf16.safetensors",
    "Jojo_rank128_bf16.safetensors",
    "Line_rank128_bf16.safetensors",
    "Rick_Morty_rank128_bf16.safetensors",
    "Vector_rank128_bf16.safetensors",
):
    pipe.load_lora_weights("./LoRAs", weight_name=weight_name)


def generate_manga(input_image, prompt):
    """Generate a 1024x1024 stylized frame from a reference image and prompt.

    Args:
        input_image: PIL image used as the spatial conditioning reference.
        prompt: Text description of the scene to render.

    Returns:
        The generated PIL image.
    """
    reference = input_image.convert("RGB")
    result = pipe(
        prompt,
        height=1024,
        width=1024,
        guidance_scale=3.5,
        num_inference_steps=25,
        max_sequence_length=512,
        spatial_images=[reference],
        subject_images=[],
        cond_size=512,
    )
    frame = result.images[0]

    # Release cached conditioning state so memory does not grow across calls.
    clear_cache(pipe.transformer)
    return frame

# Wire the generator into a minimal Gradio UI and start the server.
character_input = gr.Image(type="pil", label="Input Character")
scene_prompt = gr.Textbox(label="Scene Prompt")
frame_output = gr.Image(label="Generated Manga Frame")

demo = gr.Interface(
    fn=generate_manga,
    inputs=[character_input, scene_prompt],
    outputs=frame_output,
    title="OmniConsistency Manga Generator",
)

demo.launch()