import torch
import spaces
import os
import diffusers
import PIL
from diffusers.utils import load_image
from diffusers import FluxControlNetModel, FluxControlNetPipeline, AutoencoderKL
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
import gradio as gr
from accelerate import dispatch_model, infer_auto_device_map
from PIL import Image
from diffusers import FluxTransformer2DModel
from transformers import T5EncoderModel
import gc
# FluxControlNet image generation demo: loads 8-bit FLUX components and serves them via Gradio.
# Prefer the correctly spelled HUGGINGFACE_TOKEN, but keep honouring the historical
# misspelled HUGGINFACE_TOKEN so existing deployments that set it keep working.
huggingface_token = os.getenv("HUGGINGFACE_TOKEN") or os.getenv("HUGGINFACE_TOKEN")
device = "cuda"
torch_dtype = torch.bfloat16
MAX_SEED = 1000000

def self_attention_slicing(module, slice_size=3):
    """Wrap a callable so it is evaluated in chunks along one tensor dimension.

    Modified from Diffusers' original for Flux compatibility.

    Args:
        module: Callable taking tensors and returning a single tensor; the
            per-chunk results are joined with ``torch.cat``.
        slice_size: Positive chunk length along the sliced dimension, or
            ``"auto"`` to call ``module`` unchanged.

    Returns:
        A wrapper with the same call signature as ``module``. The dimension to
        slice is read from a ``dim`` keyword argument (default 1); that keyword
        is still forwarded to ``module``. Only tensor arguments whose size along
        ``dim`` equals that of the first positional argument are sliced; every
        other argument is passed through unchanged to each chunk call.

    Raises:
        ValueError: if ``slice_size`` is neither ``"auto"`` nor a positive int.
    """
    if slice_size != "auto" and (not isinstance(slice_size, int) or slice_size < 1):
        raise ValueError(f"slice_size must be 'auto' or a positive int, got {slice_size!r}")

    def sliced_attention(*args, **kwargs):
        if slice_size == "auto":
            return module(*args, **kwargs)

        dim = kwargs.get("dim", 1)
        total = args[0].shape[dim]
        if total == 0:
            # Nothing to split; torch.cat of an empty list would fail.
            return module(*args, **kwargs)

        def _is_sliceable(value):
            # Slice only tensors that actually share the sliced dimension.
            return (
                torch.is_tensor(value)
                and -value.dim() <= dim < value.dim()
                and value.shape[dim] == total
            )

        def _chunk(value, start, length):
            return value.narrow(dim, start, length) if _is_sliceable(value) else value

        outputs = []
        for start in range(0, total, slice_size):
            # The last chunk may be shorter than slice_size.
            length = min(slice_size, total - start)
            chunk_args = [_chunk(arg, start, length) for arg in args]
            chunk_kwargs = {k: _chunk(v, start, length) for k, v in kwargs.items()}
            outputs.append(module(*chunk_args, **chunk_kwargs))
        return torch.cat(outputs, dim=dim)

    return sliced_attention
    
# 1. Load the two largest components with 8-bit bitsandbytes quantization.
# The T5 text encoder is a transformers model, so it uses the transformers
# flavour of BitsAndBytesConfig.
quant_config = TransformersBitsAndBytesConfig(load_in_8bit=True,)
text_encoder_2_8bit = T5EncoderModel.from_pretrained(
    "LPX55/FLUX.1-merged_uncensored",
    subfolder="text_encoder_2",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
    token=huggingface_token
)
# quant_config is deliberately rebound: the Flux transformer is a diffusers
# model and needs diffusers' own BitsAndBytesConfig class.
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True,)
transformer_8bit = FluxTransformer2DModel.from_pretrained(
    "LPX55/FLUX.1-merged_uncensored",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
    token=huggingface_token
)
# The VAE is loaded without quantization (bf16) from the black-forest-labs
# FLUX.1-dev repo and moved to the target device right away.
good_vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="vae",
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    device_map=None,  # Disable automatic mapping
    token=huggingface_token
).to(device)
# 2. Assemble the ControlNet pipeline. The explicitly passed components
# (upscaler ControlNet, VAE, 8-bit transformer, 8-bit T5 encoder) are used in
# place of the repo's own; components not passed here come from the repo.
pipe = FluxControlNetPipeline.from_pretrained(
    "LPX55/FLUX.1-merged_uncensored",
    controlnet=FluxControlNetModel.from_pretrained(
        "jasperai/Flux.1-dev-Controlnet-Upscaler",
        torch_dtype=torch.bfloat16
    ),
    vae=good_vae,  # Unquantized VAE loaded above
    transformer=transformer_8bit,
    text_encoder_2=text_encoder_2_8bit,
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    device_map=None,
    token=huggingface_token  # Token read from the environment near the top of the file
)
# NOTE(review): 8-bit bitsandbytes modules may not be relocatable via .to();
# confirm which components actually end up on `device` after this call.
pipe.to(device)
# 3. Memory optimizations. Sequential CPU offloading, VAE tiling and xformers
# memory-efficient attention are intentionally not enabled here.
if hasattr(pipe, "vae") and pipe.vae is not None:
    try:
        # Prefer the VAE's own built-in sliced decoding when it is available.
        pipe.vae.enable_slicing()
    except AttributeError:
        # Fallback: wrap decode() with self_attention_slicing (chunk size 2).
        # NOTE(review): the wrapper torch.cat()s whatever decode() returns;
        # confirm decode() returns a plain tensor before relying on this path.
        print("Falling back to manual attention slicing.")
        pipe.vae.decode = self_attention_slicing(pipe.vae.decode, 2)

# Ask every pipeline component that supports it to use attention slice size 1.
# NOTE(review): confirm the Flux components actually implement attention slicing.
pipe.enable_attention_slicing(1)

print("VRAM used: {:.2f}GB".format(torch.cuda.memory_allocated() / 1e9))
@spaces.GPU
def generate_image(prompt, scale, steps, seed, control_image, controlnet_conditioning_scale, guidance_scale, guidance_start, guidance_end):
    """Generate an image with the ControlNet pipeline from an uploaded control image.

    Args:
        prompt: Text prompt for the pipeline.
        scale: Factor applied to the control image's width and height; the
            output is generated at this scaled size (snapped down to a
            multiple of 8, minimum 8).
        steps: Number of inference steps (coerced to int; sliders may send floats).
        seed: Seed for a CPU ``torch.Generator`` (coerced to int).
        control_image: Control image in any form diffusers' ``load_image``
            accepts (the Gradio input supplies a PIL image).
        controlnet_conditioning_scale: ControlNet conditioning strength.
        guidance_scale: Guidance scale passed to the pipeline.
        guidance_start: Passed as ``control_guidance_start``.
        guidance_end: Passed as ``control_guidance_end``.

    Returns:
        The first image produced by the pipeline.

    Raises:
        gr.Error: if no control image was provided.
    """
    print(f"Memory Usage: {torch.cuda.memory_summary(device=None, abbreviated=False)}")

    if control_image is None:
        raise gr.Error("Please upload a control image.")

    # Load control image
    control_image = load_image(control_image)
    src_w, src_h = control_image.size
    # Scale first, then snap down to a multiple of 8 so the resized control
    # image and the requested output size agree exactly.
    scale = float(scale)
    width = max(8, int(src_w * scale) // 8 * 8)
    height = max(8, int(src_h * scale) // 8 * 8)
    control_image = control_image.resize((width, height))
    print("Size to: " + str(width) + ", " + str(height))
    generator = torch.Generator().manual_seed(int(seed))

    image = pipe(
        prompt=prompt,
        control_image=control_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        num_inference_steps=int(steps),
        guidance_scale=guidance_scale,
        # Generate at the scaled size; previously the unscaled size was passed,
        # so the Scale input never affected the output resolution.
        height=height,
        width=width,
        control_guidance_start=guidance_start,
        control_guidance_end=guidance_end,
        generator=generator,
    ).images[0]
    return image
# Create Gradio interface. The input order matches generate_image's parameters.
iface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
        gr.Slider(1, 3, value=1, label="Scale"),
        # step=1 keeps these integer-valued inputs integral; without an explicit
        # step Gradio derives one from the slider range, which need not be 1.
        gr.Slider(2, 20, value=8, step=1, label="Steps"),
        gr.Slider(0, MAX_SEED, value=42, step=1, label="Seed"),
        gr.Image(type="pil", label="Control Image"),
        gr.Slider(0, 1, value=0.6, label="ControlNet Scale"),
        gr.Slider(1, 20, value=3.5, label="Guidance Scale"),
        gr.Slider(0, 1, value=0.0, label="Control Guidance Start"),
        gr.Slider(0, 1, value=1.0, label="Control Guidance End"),
    ],
    outputs=[
        gr.Image(type="pil", label="Generated Image", format="png"),
    ],
    title="FLUX ControlNet Image Generation",
    description="Generate images using the FluxControlNetPipeline. Upload a control image and enter a prompt to create an image.",
)
# Report CUDA memory once everything is loaded, then collect garbage before serving.
print(f"Memory Usage: {torch.cuda.memory_summary(device=None, abbreviated=False)}")
gc.enable()
gc.collect()
# Launch the app.
# NOTE(review): share=True is presumably ignored when running on Hugging Face Spaces; confirm.
iface.launch(show_error=True, share=True)