Spaces:
Runtime error
Runtime error
| import json | |
| import os | |
| import argparse | |
| import numpy as np | |
| import torch | |
| import gradio as gr | |
| from config import PipelineConfig | |
| from src.pipeline import FashionPipeline, PipelineOutput | |
config = PipelineConfig()
# Resolve to a torch.device in both branches (CUDA when available, otherwise CPU)
# so the pipeline always receives the same type regardless of hardware.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Built once at import time; `process` calls this single instance for every request.
fashion_pipeline = FashionPipeline(config, device=device)
def process(
    input_image: np.ndarray,
    prompt: str,
    negative_prompt: str,
    generate_from_mask: bool,
    num_inference_steps: int,
    guidance_scale: float,
    conditioning_scale: float,
    target_image_size: int,
    max_image_size: int,
    seed: int,
):
    """Run the fashion pipeline for one request coming from the Gradio UI.

    Args:
        input_image: Image from the ``gr.Image(type='numpy')`` input; ``None``
            when the user clicked Run without providing an image.
        prompt: Positive text prompt.
        negative_prompt: Negative text prompt.
        generate_from_mask: Forwarded to the pipeline; the UI labels this
            checkbox "Input image is already a control mask".
        num_inference_steps: Number of inference steps.
        guidance_scale: Guidance scale.
        conditioning_scale: Conditioning scale.
        target_image_size: Target image size.
        max_image_size: Maximum image size.
        seed: Seed forwarded to the pipeline.

    Returns:
        ``[generated_image, control_mask]`` from the pipeline output, in the
        order of the two output components wired to the Run button.

    Raises:
        gr.Error: If no input image was provided.
    """
    if input_image is None:
        # Surface a readable message in the UI instead of an opaque failure
        # deep inside the pipeline.
        raise gr.Error('Please upload an input image.')

    # Slider values may arrive as floats depending on the Gradio version; the
    # step count, sizes and seed are integral, so coerce them explicitly
    # (mirroring the float() coercion already applied to the scale values).
    output: PipelineOutput = fashion_pipeline(
        control_image=input_image,
        prompt=prompt,
        negative_prompt=negative_prompt,
        generate_from_mask=generate_from_mask,
        num_inference_steps=int(num_inference_steps),
        guidance_scale=float(guidance_scale),
        conditioning_scale=float(conditioning_scale),
        target_image_size=int(target_image_size),
        max_image_size=int(max_image_size),
        seed=int(seed),
    )
    return [
        output.generated_image,
        output.control_mask,
    ]
def read_content(file_path: str) -> str:
    """Return the whole text of ``file_path``, decoded as UTF-8."""
    with open(file_path, mode='r', encoding='utf-8') as handle:
        return handle.read()
image_dir = 'examples/images'
# os.listdir() returns entries in arbitrary, filesystem-dependent order; sort the
# listing so the image/prompt pairing below is deterministic across machines.
image_list = [os.path.join(image_dir, file_name) for file_name in sorted(os.listdir(image_dir))]
# NOTE(review): the zip below pairs images and prompts purely by position, so it
# assumes prompts.json lists its entries in the same (sorted) order as the image
# filenames — confirm against the data file.
with open('examples/prompts.json', 'r', encoding='utf-8') as f:
    prompts_list = list(json.load(f).values())
# Each value is indexed [0] / [1] and fed to the prompt / negative-prompt example
# inputs; zip stops at the shorter of the two sequences.
examples = [[image, prompt[0], prompt[1]] for image, prompt in zip(image_list, prompts_list)]
block = gr.Blocks().queue()
with block:
    # Header row rendered from a local HTML file.
    with gr.Row():
        gr.HTML(read_content('header.html'))
    with gr.Row():
        # Left column: inputs, advanced options, examples and footer.
        with gr.Column():
            input_image = gr.Image(type='numpy')
            prompt = gr.Textbox(label='Prompt')
            negative_prompt = gr.Textbox(label='Negative Prompt')
            with gr.Row():
                generate_from_mask = gr.Checkbox(label='Input image is already a control mask', value=False)
                run_button = gr.Button(value='Run')
            with gr.Accordion('Advanced options', open=False):
                target_image_size = gr.Slider(
                    label='Image target size:',
                    minimum=512,
                    maximum=2048,
                    value=768,
                    step=64,
                )
                max_image_size = gr.Slider(
                    label='Image max size:',
                    minimum=512,
                    maximum=2048,
                    value=1024,
                    step=64,
                )
                num_inference_steps = gr.Slider(label='Number of steps', minimum=1, maximum=100, value=20, step=1)
                guidance_scale = gr.Slider(label='Guidance scale', minimum=0.1, maximum=30.0, value=9.0, step=0.1)
                conditioning_scale = gr.Slider(label='Conditioning scale', minimum=0.0, maximum=5.0, value=1.0, step=0.1)
                seed = gr.Slider(label='Seed', minimum=0, maximum=config.max_seed, step=1, value=0)
            gr.Examples(examples=examples, inputs=[input_image, prompt, negative_prompt], label='Examples - Input Images', examples_per_page=12)
            # Footer; the <div> is explicitly closed so the markup is well-formed.
            gr.HTML(
                """
                <div class="footer">
                <p>This repo is based on the Unet from <a style="text-decoration: underline;" href="https://huggingface.co/spaces/wildoctopus/cloth-segmentation">cloth-segmentation</a>.
                It uses a pre-trained U2NET to extract upper-body (red), lower-body (green) and full-body (blue) masks, and then
                runs StableDiffusionXLControlNetPipeline with the trained controlnet_baseline to generate an image conditioned on these masks.
                </p>
                </div>
                """
            )
        # Right column: outputs.
        with gr.Column():
            generated_output = gr.Image(label='Generated', type='numpy', elem_id='generated')
            mask_output = gr.Image(label='Mask', type='numpy', elem_id='mask')
    # Order must match the positional parameters of `process`.
    ips = [input_image, prompt, negative_prompt, generate_from_mask, num_inference_steps, guidance_scale, conditioning_scale, target_image_size, max_image_size, seed]
    run_button.click(fn=process, inputs=ips, outputs=[generated_output, mask_output])
# Command-line interface: `--share` / `-s` asks Gradio for a public share link.
parser = argparse.ArgumentParser()
parser.add_argument('--share', '-s', action="store_true", default=False, help='Create public link for the app.')
args = parser.parse_args()
# Serve the UI assembled in `block`.
block.launch(share=args.share)