import torch
from torchvision import transforms
from PIL import Image
import logging

# NOTE(review): configures the root logger as an import-time side effect
# (INFO level, "<time> - <level> - <message>" lines). Applications that set
# up logging themselves before importing this module are unaffected, since
# basicConfig is a no-op once the root logger has handlers.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Model-input preprocessing: resize to a fixed 1024x1024 (aspect ratio is
# not preserved), convert the PIL image to a float tensor, then normalize
# each channel. The mean/std values match the widely used ImageNet
# statistics and assume a 3-channel (RGB) input.
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

def process_image(image, birefnet, device="cuda"):
    """Remove the background of an image using a BiRefNet model.

    The image is preprocessed with the module-level ``transform_image``
    pipeline and run through *birefnet*. The sigmoid of the model's last
    output is resized back to the original image size and attached as the
    alpha channel.

    Note:
        The alpha channel is written into *image* in place (via
        ``Image.putalpha``) and that same object is returned. Pass a copy if
        the original must be preserved.

    Args:
        image (PIL.Image.Image): The image to process. Non-RGB images
            (e.g. RGBA, L) are accepted: the model is fed an RGB copy, so
            the caller's image keeps its own mode.
        birefnet (torch.nn.Module): The BiRefNet model. Assumed to already
            be on *device* and to return a sequence whose last element is
            the final prediction map — TODO confirm against the model.
        device (str): The device to run the model on (default: "cuda").

    Returns:
        PIL.Image.Image: *image* with the predicted mask as its alpha channel.

    Raises:
        RuntimeError: If any step fails; the original exception is chained
            as ``__cause__``.
    """
    try:
        image_size = image.size
        # Normalize in transform_image uses 3-element mean/std, so 1- or
        # 4-channel inputs (L, RGBA, ...) would fail with a shape mismatch.
        # Feed the model an RGB copy; putalpha below still targets `image`.
        rgb_image = image.convert("RGB")
        input_images = transform_image(rgb_image).unsqueeze(0).to(device)

        # Prediction: keep only the last output, map logits to [0, 1] with
        # sigmoid, and move to CPU for conversion to a PIL image.
        with torch.no_grad():
            preds = birefnet(input_images)[-1].sigmoid().cpu()
        pred = preds[0].squeeze()
        pred_pil = transforms.ToPILImage()(pred)
        # The prediction is at the model's input resolution; scale it back
        # to the original (width, height) before using it as alpha.
        mask = pred_pil.resize(image_size)
        image.putalpha(mask)
        logging.info("Image processed successfully.")
        return image
    except Exception as e:
        # logging.exception records the traceback; `from e` keeps the cause
        # attached for callers. RuntimeError is still caught by
        # `except Exception`, so existing handlers keep working.
        logging.exception("Error processing image: %s", e)
        raise RuntimeError(f"Error processing image: {e}") from e