import numpy as np
import torch
from PIL import Image
from transformers import pipeline
import gradio as gr

# ===== Device Setup =====
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_index = 0 if torch.cuda.is_available() else -1

# ===== MiDaS Depth Estimation Setup =====
# Load the MiDaS DPT-Large model and its matching input transform from torch.hub.
midas = torch.hub.load("intel-isl/MiDaS", "DPT_Large")
midas.to(device).eval()
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
transform = midas_transforms.dpt_transform

# ===== Segmentation Setup =====
# SegFormer-B0 fine-tuned on ADE20K; use fp16 on GPU to save memory.
segmenter = pipeline(
    "image-segmentation",
    model="nvidia/segformer-b0-finetuned-ade-512-512",
    device=device_index,
    torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
)

# ===== Utility Functions =====
def resize_image(img: Image.Image, max_size: int = 512) -> Image.Image:
    """Downscale so the longer side is at most max_size, preserving aspect ratio."""
    width, height = img.size
    if max(width, height) > max_size:
        ratio = max_size / max(width, height)
        new_size = (int(width * ratio), int(height * ratio))
        return img.resize(new_size, Image.LANCZOS)
    return img

# ===== Depth Prediction =====
def predict_depth(image: Image.Image) -> Image.Image:
    # Normalize the input to an RGB PIL image. (The original check was
    # inverted: .convert() only exists on PIL images, so non-PIL inputs
    # must go through Image.fromarray first.)
    if isinstance(image, Image.Image):
        img = image.convert("RGB")
    else:
        img = Image.fromarray(np.asarray(image)).convert("RGB")
    img_np = np.array(img)

    # Convert to the format expected by MiDaS; dpt_transform already adds a
    # batch dimension, so only unsqueeze if it is missing.
    input_tensor = transform(img_np).to(device)
    input_batch = input_tensor.unsqueeze(0) if input_tensor.ndim == 3 else input_tensor

    # Predict depth and upsample back to the original resolution.
    with torch.no_grad():
        prediction = midas(input_batch)
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=img_np.shape[:2],
            mode="bicubic",
            align_corners=False,
        ).squeeze()

    # Normalize to 0-255; guard against a constant depth map to avoid
    # dividing by zero.
    depth_map = prediction.cpu().numpy()
    depth_range = depth_map.max() - depth_map.min()
    if depth_range > 0:
        depth_map = (depth_map - depth_map.min()) / depth_range
    else:
        depth_map = np.zeros_like(depth_map)
    depth_map = (depth_map * 255).astype(np.uint8)
    return Image.fromarray(depth_map)

# ===== Segmentation =====
def segment_image(img: Image.Image) -> Image.Image:
    img = img.convert("RGB")
    img_resized = resize_image(img)
    results = segmenter(img_resized)

    # Blend a random color into each predicted mask region (60% image, 40% color).
    overlay = np.array(img_resized, dtype=np.uint8)
    for res in results:
        mask = np.array(res["mask"], dtype=bool)
        color = np.random.randint(50, 255, 3, dtype=np.uint8)
        overlay[mask] = (overlay[mask] * 0.6 + color * 0.4).astype(np.uint8)
    return Image.fromarray(overlay)

# ===== Gradio App =====
def predict_fn(input_img: Image.Image) -> Image.Image:
    # 1. Compute depth map
    depth_img = predict_depth(input_img)
    # 2. Segment the depth map
    seg_img = segment_image(depth_img)
    return seg_img

iface = gr.Interface(
    fn=predict_fn,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Image(type="pil", label="Segmented Depth Overlay"),
    title="Depth-then-Segmentation Pipeline",
    description=(
        "Upload an image. First computes a depth map via MiDaS, "
        "then applies SegFormer segmentation on the depth map."
    ),
)

if __name__ == "__main__":
    iface.launch()
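
# ===== Optional: Headless Smoke Test =====
# A minimal sketch for exercising the pipeline without launching the Gradio UI,
# e.g. from an interactive session or a separate script that imports this
# module (launch() blocks when the file is run directly, so call this via
# import). The file names "sample.jpg" and "segmented_depth.png" are
# illustrative placeholders, not part of the original app.
def smoke_test(input_path: str = "sample.jpg",
               output_path: str = "segmented_depth.png") -> None:
    img = Image.open(input_path)        # placeholder path; adjust for your setup
    result = predict_fn(img)            # depth map -> segmentation overlay
    result.save(output_path)
    print(f"Saved overlay to {output_path}")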