Spaces:
Build error
Build error
| import gradio as gr | |
| import jax | |
| import numpy as np | |
| import jax.numpy as jnp | |
| from flax.jax_utils import replicate | |
| from flax.training.common_utils import shard | |
| from PIL import Image | |
| from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel | |
| import cv2 | |
# Fetch the ControlNet checkpoint (pinned to an exact revision) together with
# its weights; both are loaded in bfloat16.
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    "Nahrawy/controlnet-VIDIT-FAID",
    revision="615ba4a457b95a0eba813bcc8caf842c03a4f7bd",
    dtype=jnp.bfloat16,
)

# Build the Stable Diffusion v1-5 pipeline from its "flax" revision and attach
# the ControlNet loaded above.
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    revision="flax",
    dtype=jnp.bfloat16,
)
def create_key(seed=0):
    """Build a JAX PRNG key from an integer seed (defaults to 0)."""
    prng_key = jax.random.PRNGKey(seed)
    return prng_key
def process_mask(image):
    """Convert a conditioning image to a 512x512 single-channel mask.

    Args:
        image: Conditioning image as a NumPy array. The only caller in this
            file passes the value of a ``gr.Image`` input, which Gradio
            delivers in RGB channel order by default.

    Returns:
        The grayscale version of ``image`` resized to 512x512.
    """
    if image.ndim == 2:
        # Already single-channel; cvtColor rejects 2-D input for *2GRAY codes.
        mask = image
    else:
        # The input is RGB (not OpenCV's native BGR), so use the RGB
        # conversion to apply the luminance weights to the right channels.
        mask = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    return cv2.resize(mask, (512, 512))
def infer(prompts, negative_prompts, image):
    """Generate images for a text prompt conditioned on a depth-map image.

    Args:
        prompts: Text prompt; it is repeated once per generated sample.
        negative_prompts: Negative text prompt, repeated like ``prompts``.
        image: Conditioning image as a NumPy array; it is converted to a
            512x512 grayscale mask by ``process_mask``.

    Returns:
        A list of PIL images, one per JAX device.

    Raises:
        gr.Error: If no conditioning image was supplied.
    """
    if image is None:
        # Surface a readable message in the UI instead of an OpenCV crash.
        raise gr.Error("Please provide a conditioning image.")

    # The pipeline expects the ControlNet weights under this key.
    params["controlnet"] = controlnet_params

    # One sample per device: `shard` reshapes the batch to a leading axis of
    # size device_count, which fails for a batch of 1 on multi-device hosts.
    num_devices = jax.device_count()
    num_samples = num_devices

    # Fixed seed, so identical inputs reproduce identical outputs.
    rng = jax.random.split(create_key(0), num_devices)

    mask = Image.fromarray(process_mask(image))

    prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
    negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
    processed_image = pipe.prepare_image_inputs([mask] * num_samples)

    # Replicate the weights and split the inputs across devices.
    p_params = replicate(params)
    prompt_ids = shard(prompt_ids)
    negative_prompt_ids = shard(negative_prompt_ids)
    processed_image = shard(processed_image)

    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=p_params,
        prng_seed=rng,
        num_inference_steps=50,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images

    # Collapse the leading device/batch axes into a single sample axis.
    flat_images = np.asarray(output.reshape((num_samples,) + output.shape[-3:]))
    return pipe.numpy_to_pil(flat_images)
# Example inputs shown in the UI. The three lists are parallel: entry i of
# each list belongs to example i.
e_images = [
    '0.png',
    '1.png',
    '2.png',
]
e_prompts = [
    'a dog in the middle of the road, shadow on the ground,light direction north-east',
    'a skyscraper in the middle of an intersection, shadow on the ground, light direction east',
    'a red rural house, light temperature 5500, shadow on the ground, light direction south-west',
]
# The same negative prompt is used for every example (one list entry each).
e_negative_prompts = ['monochromatic, unrealistic, bad looking, full of glitches'] * 3

# Each row is ordered [prompt, negative_prompt, image].
examples = [
    [prompt, negative_prompt, image]
    for image, prompt, negative_prompt in zip(e_images, e_prompts, e_negative_prompts)
]
# Page heading rendered as Markdown. `title` was referenced below without ever
# being defined, which raised NameError at startup.
# NOTE(review): wording derived from the checkpoint ids used in this file;
# replace with the intended heading if one exists.
title = "# ControlNet (Nahrawy/controlnet-VIDIT-FAID) + Stable Diffusion v1-5"

with gr.Blocks() as demo:
    gr.Markdown(title)

    prompts = gr.Textbox(label='prompts')
    negative_prompts = gr.Textbox(label='negative_prompts')

    # Conditioning image on the left, generated results on the right.
    with gr.Row():
        with gr.Column():
            in_image = gr.Image(label="Depth Map Conditioning")
        with gr.Column():
            out_image = gr.Gallery(label="Generated Image")

    with gr.Row():
        btn = gr.Button("Run")

    # Example rows are ordered to match `inputs`; with cache_examples=True
    # Gradio runs `infer` on them ahead of time and stores the outputs.
    gr.Examples(
        examples=examples,
        inputs=[prompts, negative_prompts, in_image],
        outputs=out_image,
        fn=infer,
        cache_examples=True,
    )

    btn.click(
        fn=infer,
        inputs=[prompts, negative_prompts, in_image],
        outputs=out_image,
    )

demo.launch()