import spaces
import os
import gradio as gr
import torch
import safetensors
from huggingface_hub import hf_hub_download
from diffusers.utils import load_image, check_min_version
from controlnet_flux import FluxControlNetModel
from transformer_flux import FluxTransformer2DModel
from pipeline_flux_cnet import FluxControlNetInpaintingPipeline
from PIL import Image, ImageDraw
import numpy as np
import subprocess
from transformers import T5EncoderModel
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
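# Note: controlnet_flux, transformer_flux and pipeline_flux_cnet are assumed to be local modules
# shipped alongside this app (they are imported from the working directory, not an installed package).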
# Clear any stale ZeroGPU offload cache left over from previous runs
subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
HF_TOKEN = os.getenv("HF_TOKEN")
# Ensure that the minimum required version of diffusers is installed
check_min_version("0.30.2")
# 4-bit double-quantization config for the T5 text encoder to reduce VRAM usage.
# Note: text_encoder_2_4bit is loaded here but not wired into the pipeline below
# (the corresponding assignment is commented out).
quant_config = TransformersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
)
text_encoder_2_4bit = T5EncoderModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="text_encoder_2",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN
)
# quant_config = DiffusersBitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_use_double_quant=True,
# )
transformerx = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN
)
# text_encoder_8bit = T5EncoderModel.from_pretrained(
#     "black-forest-labs/FLUX.1-dev",
#     subfolder="text_encoder_2",
#     quantization_config=quant_config,
#     torch_dtype=torch.bfloat16,
#     use_safetensors=True,
#     token=HF_TOKEN
# )
# Build pipeline
controlnet = FluxControlNetModel.from_pretrained(
    "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
    # subfolder="controlnet",
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN
)
pipe = FluxControlNetInpaintingPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    controlnet=controlnet,
    # text_encoder_2=text_encoder_8bit,
    transformer=transformerx,
    torch_dtype=torch.bfloat16,
    # device_map="balanced",
    token=HF_TOKEN
)
# pipe.text_encoder_2 = text_encoder_2_4bit
# pipe.transformer = transformer_4bit
pipe.transformer.to(torch.bfloat16)
pipe.controlnet.to(torch.bfloat16)
pipe.to("cuda")
pipe.load_lora_weights("alimama-creative/FLUX.1-Turbo-Alpha", adapter_name="turbo")
pipe.set_adapters(["turbo"], adapter_weights=[0.95])
pipe.fuse_lora(lora_scale=1)
pipe.unload_lora_weights()
# We can utilize the enable_group_offload method for Diffusers model implementations
# pipe.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True)
# For any other model implementations, the apply_group_offloading function can be used
# pipe.push_to_hub("FLUX.1-Inpainting-8step_uncensored", private=True, token=HF_TOKEN)
# pipe.enable_vae_tiling()
# pipe.enable_model_cpu_offload()
# hf_device_map is only populated when a device_map is used, so guard the debug print
print(getattr(pipe, "hf_device_map", None))
def create_mask_from_editor(editor_value):
    """
    Create a binary mask from a Gradio ImageEditor value.
    Args:
        editor_value: EditorValue dict with 'background', 'layers', and 'composite' entries
    Returns:
        PIL Image where white regions of the composite are white and everything else is black
    Note: this helper is not used by the current UI (the editor component is hidden).
    """
    # The 'composite' key contains the final image with all layers applied
    composite_image = editor_value['composite']
    # Convert to an RGB numpy array (the editor may return an RGBA composite)
    composite_array = np.array(composite_image.convert("RGB"))
    # Create mask where the composite image is white
    mask_array = np.all(composite_array == (255, 255, 255), axis=-1).astype(np.uint8) * 255
    mask_image = Image.fromarray(mask_array)
    return mask_image
def create_mask_on_image(image, xyxy):
    """
    Create a white rectangular mask on the image given xyxy coordinates.
    Args:
        image: PIL Image
        xyxy: List of [x1, y1, x2, y2] coordinates
    Returns:
        Tuple of (mask, masked image): the black/white mask and a copy of the input
        with the masked region filled with white
    """
    # Convert to numpy array
    img_array = np.array(image)
    # Create mask
    mask = Image.new('RGB', image.size, (0, 0, 0))
    draw = ImageDraw.Draw(mask)
    # Draw white rectangle
    draw.rectangle(xyxy, fill=(255, 255, 255))
    # Convert mask to array
    mask_array = np.array(mask)
    # Apply mask to image
    masked_array = np.where(mask_array == 255, 255, img_array)
    return Image.fromarray(mask_array), Image.fromarray(masked_array)
def create_diptych_image(image):
    # Create a diptych image with the original on the left and a black panel on the right
    width, height = image.size
    diptych = Image.new('RGB', (width * 2, height), 'black')
    diptych.paste(image, (0, 0))
    return diptych
@spaces.GPU  # Request a GPU for this call on ZeroGPU Spaces (no-op when running elsewhere)
def inpaint_image(image, prompt, subject, editor_value):
    # Load the input image and build the diptych (reference on the left, black panel on the right)
    size = (1536, 768)
    image = load_image(image).convert("RGB").resize((768, 768))
    diptych_image = create_diptych_image(image)
    # mask = load_image(mask_path).convert("RGB").resize(size)
    # mask, mask_image = create_mask_on_image(image, [250, 275, 500, 400])
    # Mask the right half of the diptych -- this is the region to be inpainted
    mask, mask_image = create_mask_on_image(diptych_image, [768, 0, 1536, 768])
    generator = torch.Generator(device="cuda").manual_seed(24)
    # Calculate attention scale mask
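    # Shape notes (FLUX.1-dev defaults): the latent token grid is the pixel size divided by 16
    # (8x VAE downsampling plus 2x2 patchification), the transformer uses 24 attention heads, and
    # 512 is the T5 text sequence length prepended to the image tokens, so the full attention map
    # is (512 + H*W) x (512 + H*W). The intent is to boost attention between the generated (right)
    # panel and the reference (left) panel by attn_scale_factor, in the spirit of diptych prompting.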
    attn_scale_factor = 1.5
    # Pixel-resolution mask: 0 over the left (reference) panel, 1 over the right (generated) panel
    H, W = size[1] // 16, size[0] // 16
    attn_scale_mask = torch.zeros(size[1], size[0])
    attn_scale_mask[:, 768:] = 1.0  # height, width
    # Downsample to the latent token grid and flatten into a sequence of H*W token flags
    attn_scale_mask = torch.nn.functional.interpolate(attn_scale_mask[None, None, :, :], (H, W), mode='nearest-exact').flatten()
    # Expand to (1, heads, H*W, H*W): rows flag right-panel query tokens
    attn_scale_mask = attn_scale_mask[None, None, :, None].repeat(1, 24, 1, H * W)
    # Get inverted attention mask by subtracting from 1.0; after transposing, columns flag left-panel key tokens
    transposed_inverted_attn_scale_mask = (1.0 - attn_scale_mask).transpose(-1, -2)
    # Cross-panel region: right-panel tokens attending to left-panel tokens
    cross_attn_region = torch.logical_and(attn_scale_mask, transposed_inverted_attn_scale_mask)
    cross_attn_region = cross_attn_region * attn_scale_factor
    cross_attn_region[cross_attn_region < 1.0] = 1.0
    full_attn_scale_mask = torch.ones(1, 24, 512 + H * W, 512 + H * W)
    full_attn_scale_mask[:, :, 512:, 512:] = cross_attn_region
    # Convert to bfloat16 to match model dtype
    full_attn_scale_mask = full_attn_scale_mask.to(device=pipe.transformer.device, dtype=torch.bfloat16)
    subject_name = subject
    target_text_prompt = prompt
    prompt_final = f'A two side-by-side image of {subject_name}. LEFT: a photo of {subject_name}; RIGHT: a photo of {subject_name} {target_text_prompt}.'
    # Visualize the attention scale mask for the first head (full (512 + H*W) x (512 + H*W) map).
    # Clone first so the thresholding below does not modify the mask passed to the pipeline in place.
    attn_vis = full_attn_scale_mask[0, 0].clone()
    attn_vis[attn_vis <= 1.0] = 0
    attn_vis[attn_vis > 1.0] = 255
    attn_vis = attn_vis.cpu().float().numpy().astype(np.uint8)
    # Convert to PIL Image
    attn_vis_img = Image.fromarray(attn_vis)
    attn_vis_img.save('attention_mask_vis.png')
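    # Run the ControlNet inpainting pipeline on the diptych: control_mask marks the right half to be
    # regenerated while the left half conditions the result, and attn_scale_mask is the extra argument
    # of this repo's custom pipeline that applies the cross-panel attention scaling built above.
    # num_inference_steps=12 relies on the fused Turbo LoRA for few-step sampling.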
    with torch.inference_mode():
        result = pipe(
            prompt=prompt_final,
            height=size[1],
            width=size[0],
            control_image=diptych_image,
            control_mask=mask,
            num_inference_steps=12,
            generator=generator,
            controlnet_conditioning_scale=0.7,
            guidance_scale=1,
            negative_prompt="",
            true_guidance_scale=1.0,
            attn_scale_mask=full_attn_scale_mask,
        ).images[0]
    return result, attn_vis_img
# Create Gradio interface with structured layout
with gr.Blocks() as iface:
    gr.Markdown("## FLUX Inpainting with Diptych Prompting")
    gr.Markdown("Upload an image and enter a subject and prompt. The app builds a diptych (your image on the left, a masked panel on the right) and inpaints the right panel according to the prompt.")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                with gr.Accordion():
                    input_image = gr.Image(type="filepath", label="Upload Image")
            with gr.Row():
                prompt_preview = gr.Textbox(value="A two side-by-side image of 'subject_name'. LEFT: a photo of 'subject_name'; RIGHT: a photo of 'subject_name' 'target_text_prompt'", interactive=False)
                subject = gr.Textbox(lines=1, placeholder="Enter your subject", label="Subject")
                prompt = gr.Textbox(lines=2, placeholder="Enter your prompt here (e.g., 'wearing a christmas hat, in a busy street')", label="Prompt")
        with gr.Column():
            editor_value = gr.ImageEditor(type="pil", label="Image with Mask", sources="upload", visible=False)
            inpainted_image = gr.Image(type="pil", label="Inpainted Image")
            attn_vis_img = gr.Image(type="pil", label="Attn Vis Image")
    with gr.Row():
        inpaint_button = gr.Button("Inpaint")
    inpaint_button.click(fn=inpaint_image, inputs=[input_image, prompt, subject, editor_value], outputs=[inpainted_image, attn_vis_img])

# Launch the app
iface.launch()