# MangaNinja-demo / run_gradio.py
import spaces
import gradio as gr

import os
import re

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
from PIL import Image, ImageDraw

from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from inference.manganinjia_pipeline import MangaNinjiaPipeline
from diffusers import (
ControlNetModel,
DiffusionPipeline,
DDIMScheduler,
AutoencoderKL,
)
from src.models.mutual_self_attention_multi_scale import ReferenceAttentionControl
from src.models.unet_2d_condition import UNet2DConditionModel
from src.models.refunet_2d_condition import RefUNet2DConditionModel
from src.point_network import PointNet
from src.annotator.lineart import BatchLineartDetector
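# The inference config is expected to provide (judging from the lookups below):
#   model_path: pretrained_model_name_or_path, clip_vision_encoder_path,
#     controlnet_model_name, annotator_ckpts_path, point_net_path,
#     manga_control_model_path, manga_reference_model_path, manga_main_model_path
#   inference_config: output_path, device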
val_configs = OmegaConf.load('./configs/inference.yaml')
# download the checkpoints
from huggingface_hub import snapshot_download, hf_hub_download
os.makedirs("checkpoints", exist_ok=True)
# List of subdirectories to create inside "checkpoints"
subfolders = [
"StableDiffusion",
"models",
"MangaNinjia"
]
# Create each subdirectory
for subfolder in subfolders:
os.makedirs(os.path.join("checkpoints", subfolder), exist_ok=True)
# List of subdirectories to create inside "models"
models_subfolders = [
"clip-vit-large-patch14",
"control_v11p_sd15_lineart",
"Annotators"
]
# Create each subdirectory
for subfolder in models_subfolders:
os.makedirs(os.path.join("checkpoints/models", subfolder), exist_ok=True)
snapshot_download(
repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5",
local_dir = "./checkpoints/StableDiffusion"
)
snapshot_download(
repo_id = "openai/clip-vit-large-patch14",
local_dir = "./checkpoints/models/clip-vit-large-patch14"
)
snapshot_download(
repo_id = "lllyasviel/control_v11p_sd15_lineart",
local_dir = "./checkpoints/models/control_v11p_sd15_lineart"
)
hf_hub_download(
repo_id = "lllyasviel/Annotators",
filename = "sk_model.pth",
local_dir = "./checkpoints/models/Annotators"
)
snapshot_download(
repo_id = "Johanan0528/MangaNinjia",
local_dir = "./checkpoints/MangaNinjia"
)
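# Expected local layout after the downloads:
#   checkpoints/StableDiffusion/                   Stable Diffusion v1.5
#   checkpoints/models/clip-vit-large-patch14/     CLIP text/vision encoders
#   checkpoints/models/control_v11p_sd15_lineart/  lineart ControlNet
#   checkpoints/models/Annotators/sk_model.pth     line-art detector weights
#   checkpoints/MangaNinjia/                       fine-tuned MangaNinjia weights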
# === load the checkpoints ===
pretrained_model_name_or_path = val_configs.model_path.pretrained_model_name_or_path
refnet_clip_vision_encoder_path = val_configs.model_path.clip_vision_encoder_path
controlnet_clip_vision_encoder_path = val_configs.model_path.clip_vision_encoder_path
controlnet_model_name_or_path = val_configs.model_path.controlnet_model_name
annotator_ckpts_path = val_configs.model_path.annotator_ckpts_path
output_root = val_configs.inference_config.output_path
device = val_configs.inference_config.device
preprocessor = BatchLineartDetector(annotator_ckpts_path)
in_channels_reference_unet = 4
in_channels_denoising_unet = 4
in_channels_controlnet = 4
noise_scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
vae = AutoencoderKL.from_pretrained(
pretrained_model_name_or_path,
subfolder='vae'
)
denoising_unet = UNet2DConditionModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="unet",
in_channels=in_channels_denoising_unet,
low_cpu_mem_usage=False,
ignore_mismatched_sizes=True
)
reference_unet = RefUNet2DConditionModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="unet",
in_channels=in_channels_reference_unet,
low_cpu_mem_usage=False,
ignore_mismatched_sizes=True
)
refnet_tokenizer = CLIPTokenizer.from_pretrained(refnet_clip_vision_encoder_path)
refnet_text_encoder = CLIPTextModel.from_pretrained(refnet_clip_vision_encoder_path)
refnet_image_enc = CLIPVisionModelWithProjection.from_pretrained(refnet_clip_vision_encoder_path)
controlnet = ControlNetModel.from_pretrained(
controlnet_model_name_or_path,
in_channels=in_channels_controlnet,
low_cpu_mem_usage=False,
ignore_mismatched_sizes=True
)
controlnet_tokenizer = CLIPTokenizer.from_pretrained(controlnet_clip_vision_encoder_path)
controlnet_text_encoder = CLIPTextModel.from_pretrained(controlnet_clip_vision_encoder_path)
controlnet_image_enc = CLIPVisionModelWithProjection.from_pretrained(controlnet_clip_vision_encoder_path)
point_net = PointNet()
reference_control_writer = ReferenceAttentionControl(
reference_unet,
do_classifier_free_guidance=False,
mode="write",
fusion_blocks="full",
)
reference_control_reader = ReferenceAttentionControl(
denoising_unet,
do_classifier_free_guidance=False,
mode="read",
fusion_blocks="full",
)
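# As the modes suggest, the writer caches reference_unet's self-attention features
# during its forward pass and the reader injects them into the matching blocks of
# denoising_unet ("full" hooks every fusion block).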
controlnet.load_state_dict(
torch.load(val_configs.model_path.manga_control_model_path, map_location="cpu"),
strict=False,
)
point_net.load_state_dict(
torch.load(val_configs.model_path.point_net_path, map_location="cpu"),
strict=False,
)
reference_unet.load_state_dict(
torch.load(val_configs.model_path.manga_reference_model_path, map_location="cpu"),
strict=False,
)
denoising_unet.load_state_dict(
torch.load(val_configs.model_path.manga_main_model_path, map_location="cpu"),
strict=False,
)
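# strict=False: each fine-tuned checkpoint only needs to cover the parameters it
# actually changes; everything else keeps its base-model initialization.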
pipe = MangaNinjiaPipeline(
reference_unet=reference_unet,
controlnet=controlnet,
denoising_unet=denoising_unet,
vae=vae,
refnet_tokenizer=refnet_tokenizer,
refnet_text_encoder=refnet_text_encoder,
refnet_image_encoder=refnet_image_enc,
controlnet_tokenizer=controlnet_tokenizer,
controlnet_text_encoder=controlnet_text_encoder,
controlnet_image_encoder=controlnet_image_enc,
scheduler=noise_scheduler,
point_net=point_net
)
pipe = pipe.to(torch.device(device))
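# On ZeroGPU Spaces the `spaces` package defers real CUDA placement until a
# @spaces.GPU-decorated function runs, so moving the pipeline here is fine.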
def string_to_np_array(coord_string):
    """Parse a stringified list of coordinate pairs into an (N, 2) int array."""
    coord_string = coord_string.strip('[]')
    coords = re.findall(r'\d+', coord_string)
    coords = list(map(int, coords))
    return np.array(coords).reshape(-1, 2)
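# e.g. string_to_np_array("[[12, 34], [56, 78]]") -> array([[12, 34], [56, 78]])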
def infer_single(is_lineart, ref_image, target_image, output_coords_ref, output_coords_base,
                 seed=-1, num_inference_steps=20, guidance_scale_ref=9, guidance_scale_point=15):
    """
    ref_image / target_image: RGB PIL images
    output_coords_ref / output_coords_base: stringified lists of matched (row, col) points
    """
    # Seed an explicit Generator so sampling inside the pipeline is reproducible
    generator = torch.Generator(device=device).manual_seed(int(seed))
matrix1 = np.zeros((512, 512), dtype=np.uint8)
matrix2 = np.zeros((512, 512), dtype=np.uint8)
output_coords_ref = string_to_np_array(output_coords_ref)
output_coords_base = string_to_np_array(output_coords_base)
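    # Build sparse 512x512 label maps: the k-th matched pair is written as value
    # k+1 at its (row, col) position in both the reference and the target map.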
for index, (coords_ref,coords_base) in enumerate(zip(output_coords_ref,output_coords_base)):
y1, x1 = coords_ref
y2, x2 = coords_base
matrix1[y1, x1] = index + 1
matrix2[y2, x2] = index + 1
point_ref = torch.from_numpy(matrix1).unsqueeze(0).unsqueeze(0)
point_main = torch.from_numpy(matrix2).unsqueeze(0).unsqueeze(0)
    preprocessor.to(device, dtype=torch.float32)
pipe_out = pipe(
is_lineart,
ref_image,
target_image,
target_image,
denosing_steps=num_inference_steps,
processing_res=512,
match_input_res=True,
batch_size=1,
show_progress_bar=True,
guidance_scale_ref=guidance_scale_ref,
guidance_scale_point=guidance_scale_point,
preprocessor=preprocessor,
generator=generator,
point_ref=point_ref,
point_main=point_main,
)
return pipe_out
def inference_single_image(ref_image,
tar_image,
ddim_steps,
scale_ref,
scale_point,
seed,
output_coords1,
output_coords2,
is_lineart
):
if seed == -1:
seed = np.random.randint(10000)
    pipe_out = infer_single(
        is_lineart, ref_image, tar_image,
        output_coords_ref=output_coords1,
        output_coords_base=output_coords2,
        seed=seed,
        num_inference_steps=ddim_steps,
        guidance_scale_ref=scale_ref,
        guidance_scale_point=scale_point,
    )
return pipe_out
clicked_points_img1 = []
clicked_points_img2 = []
current_img_idx = 0
max_clicks = 14
point_size = 8
colors = [(255, 0, 0), (0, 255, 0)]
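
# Click-session state: clicks alternate between the two images (reference first),
# so index i in clicked_points_img1 pairs with index i in clicked_points_img2.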
# Process images: resizing them to 512x512
def process_image(ref, base):
ref_resized = cv2.resize(ref, (512, 512)) # Note OpenCV resize order is (width, height)
base_resized = cv2.resize(base, (512, 512))
return ref_resized, base_resized
# Draw square markers at the given (row, col) points onto a copy of the image
def draw_points(img, points, color):
    img_pil = Image.fromarray(img.astype('uint8'))
    draw = ImageDraw.Draw(img_pil)
    for row, col in points:
        # Paint a (2 * point_size + 1)^2 square, clipped to the image bounds
        # (PIL expects (x, y) = (col, row); size is (width, height))
        for dx in range(-point_size, point_size + 1):
            for dy in range(-point_size, point_size + 1):
                if 0 <= col + dy < img_pil.size[0] and 0 <= row + dx < img_pil.size[1]:
                    draw.point((col + dy, row + dx), fill=color)
    return np.array(img_pil)

# Handle a click on either image: record the point and redraw the markers
def get_select_coords(img1, img2, evt: gr.SelectData):
    global clicked_points_img1, clicked_points_img2, current_img_idx
    click_coords = (evt.index[1], evt.index[0])  # evt.index is (x, y); store as (row, col)
    if current_img_idx == 0:
        clicked_points_img1.append(click_coords)
        if len(clicked_points_img1) > max_clicks:
            clicked_points_img1 = []
        current_img = img1
        clicked_points = clicked_points_img1
    else:
        clicked_points_img2.append(click_coords)
        if len(clicked_points_img2) > max_clicks:
            clicked_points_img2 = []
        current_img = img2
        clicked_points = clicked_points_img2
    current_img_idx = 1 - current_img_idx
    img_out = draw_points(current_img, clicked_points, colors[current_img_idx])
    coord_array = np.array([(x, y) for x, y in clicked_points])
    return img_out, coord_array

# Undo the last clicked point and redraw both images
def undo_last_point(ref, base):
    global clicked_points_img1, clicked_points_img2, current_img_idx
    current_img_idx = 1 - current_img_idx  # step back to the image that received the last click
    if current_img_idx == 0 and clicked_points_img1:
        clicked_points_img1.pop()  # undo last point on the reference image
    elif current_img_idx == 1 and clicked_points_img2:
        clicked_points_img2.pop()  # undo last point on the base image
    if current_img_idx == 0:
        current_img, current_img_other = ref, base
        clicked_points, clicked_points_other = clicked_points_img1, clicked_points_img2
    else:
        current_img, current_img_other = base, ref
        clicked_points, clicked_points_other = clicked_points_img2, clicked_points_img1
    # Redraw both images without the removed point
    img_out = draw_points(current_img, clicked_points, colors[current_img_idx])
    img_out_other = draw_points(current_img_other, clicked_points_other, colors[1 - current_img_idx])
    updated_coords = str(np.array([(x, y) for x, y in clicked_points]).tolist())
    updated_coords_other = str(np.array([(x, y) for x, y in clicked_points_other]).tolist())
    # Outputs are ordered (ref image, ref coords, base image, base coords)
    if current_img_idx == 0:
        return img_out, updated_coords, img_out_other, updated_coords_other
    else:
        return img_out_other, updated_coords_other, img_out, updated_coords
# Main function to run the image processing
@spaces.GPU
def run_local(ref, base, *args, progress=gr.Progress(track_tqdm=True)):
image = Image.fromarray(base)
ref_image = Image.fromarray(ref)
pipe_out = inference_single_image(ref_image.copy(), image.copy(), *args)
to_save_dict = pipe_out.to_save_dict
to_save_dict['edit2'] = pipe_out.img_pil
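    # Gallery output: the colorized image plus (presumably) the extracted line art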
return [to_save_dict['edit2'], to_save_dict['edge2_black']]
with gr.Blocks() as demo:
with gr.Column():
gr.Markdown("# MangaNinja: Line Art Colorization with Precise Reference Following")
with gr.Row():
baseline_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", columns=1, height=768)
            with gr.Accordion("Advanced Options", open=True):
num_samples = 1
ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=50, step=1)
scale_ref = gr.Slider(label="Guidance of ref", minimum=0, maximum=30.0, value=9, step=0.1)
scale_point = gr.Slider(label="Guidance of points", minimum=0, maximum=30.0, value=15, step=0.1)
is_lineart = gr.Checkbox(label="Input is lineart", value=False)
seed = gr.Slider(label="Seed", minimum=-1, maximum=999999999, step=1, value=-1)
gr.Markdown("### Tutorial")
gr.Markdown("1. Upload the reference image and target image. Note that for the target image, there are two modes: you can upload an RGB image, and the model will automatically extract the line art; or you can directly upload the line art by checking the 'input is lineart' option.")
gr.Markdown("2. Click 'Process Images' to resize the images to 512*512 resolution.")
gr.Markdown("3. (Optional) **Starting from the reference image**, **alternately** click on the reference and target images in sequence to define matching points. Use 'Undo' to revert the last action.")
gr.Markdown("4. Click 'Generate' to produce the result.")
gr.Markdown("# Upload the reference image and target image")
with gr.Row():
ref = gr.Image(label="Reference Image",)
base = gr.Image(label="Target Image",)
gr.Button("Process Images").click(process_image, inputs=[ref, base], outputs=[ref, base])
with gr.Row():
            output_img1 = gr.Image(label="Reference Output")
            output_coords1 = gr.Textbox(lines=2, label="Clicked Coordinates Image 1 (list of [row, col] pairs)")
            output_img2 = gr.Image(label="Base Output")
            output_coords2 = gr.Textbox(lines=2, label="Clicked Coordinates Image 2 (list of [row, col] pairs)")
# Image click select functions
ref.select(get_select_coords, [ref, base], [output_img1, output_coords1])
base.select(get_select_coords, [ref, base], [output_img2, output_coords2])
# Undo button
undo_button = gr.Button("Undo")
undo_button.click(undo_last_point, inputs=[ref, base], outputs=[output_img1, output_coords1, output_img2, output_coords2])
run_local_button = gr.Button("Generate")
with gr.Row():
gr.Examples(
examples=[
['test_cases/hz0.png', 'test_cases/manga_target_examples/target_1.jpg'],
['test_cases/more_cases/az0.png', 'test_cases/manga_target_examples/target_2.jpg'],
['test_cases/more_cases/hi0.png', 'test_cases/manga_target_examples/target_3.jpg'],
['test_cases/more_cases/kn0.jpg', 'test_cases/manga_target_examples/target_4.jpg'],
['test_cases/more_cases/rk0.jpg', 'test_cases/manga_target_examples/target_5.jpg'],
],
inputs=[ref, base],
cache_examples=False,
examples_per_page=100
)
run_local_button.click(fn=run_local,
inputs=[ref,
base,
ddim_steps,
scale_ref,
scale_point,
seed,
output_coords1,
output_coords2,
is_lineart
],
outputs=[baseline_gallery]
)
demo.launch(show_api=False, show_error=True, ssr_mode=True)