import streamlit as st
import streamlit_analytics

import torch
import torchvision.transforms as transforms
from transformers import ViTModel, ViTConfig
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import io

# Begin streamlit_analytics tracking for this script run; the matching
# streamlit_analytics.stop_tracking() call sits at the end of the script.
streamlit_analytics.start_tracking()

# Configure the browser tab title ("ViewViz") and use the wide page layout.
# NOTE(review): Streamlit has historically required set_page_config to be the
# first st.* command executed — keep it above every other st.* call.
st.set_page_config(page_title="ViewViz", layout="wide")

# Inject custom CSS (unsafe_allow_html=True lets the raw <style> tag through).
# It overrides the app background/text colours, the button colours and corner
# rounding, and the slider accent colour.
st.markdown("""
    <style>
    .stApp {
        background-color: #2b3d4f;
        color: #ffffff;
    }
    .stButton>button {
        color: #2b3d4f;
        background-color: #4fd1c5;
        border-radius: 5px;
    }
    .stSlider>div>div>div>div {
        background-color: #4fd1c5;
    }
    </style>
    """, unsafe_allow_html=True)

# Compute device. Flip USE_GPU to True to run on CUDA; the CPU is used
# whenever USE_GPU is False or no CUDA device is available.
USE_GPU = False  # Set to True to use GPU, False to use CPU
if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Heatmap colour maps offered in the UI, keyed by their display label
# (insertion order is the order shown in the selectbox).
COLOR_SCHEMES = dict(
    Plasma=plt.cm.plasma,
    Viridis=plt.cm.viridis,
    Magma=plt.cm.magma,
    Inferno=plt.cm.inferno,
    Cividis=plt.cm.cividis,
    Spectral=plt.cm.Spectral,
    Coolwarm=plt.cm.coolwarm,
)

# Load the pre-trained Vision Transformer model
@st.cache_resource
def load_model():
    """Load the pretrained ViT backbone used for attention visualisation.

    Decorated with ``st.cache_resource`` so the weights are loaded once and
    reused across script reruns instead of being re-downloaded each time.

    The config is loaded with ``output_attentions=True`` and the ``"eager"``
    attention implementation (presumably because fused attention kernels do
    not return per-layer attention weights).

    Returns:
        The ``ViTModel`` switched to eval mode and moved to ``device``.
    """
    checkpoint = 'google/vit-base-patch16-384'
    vit_config = ViTConfig.from_pretrained(
        checkpoint,
        output_attentions=True,
        attn_implementation="eager",
    )
    vit = ViTModel.from_pretrained(checkpoint, config=vit_config)
    vit.eval()
    return vit.to(device)


model = load_model()

# Image preprocessing matching the checkpoint's own processor settings:
# resize to the 384x384 input resolution, convert to a [0, 1] tensor, then
# normalise with mean = std = 0.5 per channel (pixels mapped to [-1, 1]).
# google/vit-base-patch16-384 was trained with these statistics, not the
# ImageNet mean/std, so using the latter feeds the model mis-scaled inputs.
preprocess = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

def get_attention_map(img):
    """Return an attention-rollout heatmap for *img* from the global ViT model.

    Attention rollout works in four steps:
    1. Average each layer's attention over heads.
    2. Add an identity matrix to account for residual connections.
    3. Re-normalise each row to sum to 1.
    4. Multiply the layers together from first to last.

    Args:
        img: PIL image accepted by the module-level ``preprocess`` transform
            (the caller passes an RGB image).

    Returns:
        A 2-D numpy array of shape ``(grid, grid)`` holding the rolled-out
        attention from token 0 (the CLS token in ViT) to every patch token.

    Raises:
        ValueError: if the number of patch tokens is not a perfect square and
            so cannot be laid out as a square grid.
    """
    # Preprocess and add a batch dimension of size 1.
    input_tensor = preprocess(img).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(input_tensor, output_attentions=True)

    # One attention tensor per layer. Stacking and dropping the size-1 batch
    # axis gives (layers, heads, seq, seq), assuming HF's (batch, heads, seq,
    # seq) per-layer layout. The tensor is then averaged over heads.
    att_mat = torch.stack(outputs.attentions).squeeze(1)
    att_mat = att_mat.mean(dim=1)

    # Add residual connections and re-normalise each row. The identity is
    # allocated directly on the attention tensor's device/dtype, avoiding a
    # host allocation followed by a copy.
    seq_len = att_mat.size(-1)
    residual_att = torch.eye(seq_len, dtype=att_mat.dtype, device=att_mat.device)
    aug_att_mat = att_mat + residual_att
    aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1, keepdim=True)

    # Roll out by folding each layer into a running product. Only the final
    # product is needed, so intermediate layers are not stored.
    joint_attention = aug_att_mat[0]
    for layer_att in aug_att_mat[1:]:
        joint_attention = torch.matmul(layer_att, joint_attention)

    # Row 0 is the CLS token; columns 1: are the patch tokens, which must
    # form a square grid.
    num_patches = seq_len - 1
    grid_size = int(round(float(np.sqrt(num_patches))))
    if grid_size * grid_size != num_patches:
        raise ValueError(
            f"Cannot reshape {num_patches} patch tokens into a square grid"
        )
    mask = joint_attention[0, 1:].reshape(grid_size, grid_size).cpu().numpy()

    return mask

def overlay_attention_map(image, attention_map, overlay_strength, color_scheme):
    """Blend a colour-mapped attention heatmap over *image*.

    Args:
        image: PIL image to draw on; it is converted to RGBA for blending.
        attention_map: 2-D numpy array (e.g. from ``get_attention_map``),
            upsampled here to the image's pixel size.
        overlay_strength: blend weight of the heatmap. 0 keeps the image
            unchanged and 1 shows only the heatmap; the UI passes [0, 1].
        color_scheme: colormap callable (a ``COLOR_SCHEMES`` value) that maps
            a 2-D array of values in [0, 1] to an (H, W, 4) RGBA float array.

    Returns:
        A new RGBA ``PIL.Image`` containing the blended result.
    """
    # Upsample the patch-grid heatmap to the image's pixel size.
    resized = Image.fromarray(attention_map).resize(image.size, Image.BICUBIC)
    heatmap = np.array(resized)

    # Min-max normalise to [0, 1], which also absorbs bicubic overshoot.
    # A constant (or NaN) map would otherwise divide by zero and produce NaNs
    # that make the uint8 cast below undefined, so fall back to all zeros.
    lo = heatmap.min()
    span = heatmap.max() - lo
    if span > 0:
        heatmap = (heatmap - lo) / span
    else:
        heatmap = np.zeros_like(heatmap)

    # Colour-map the heatmap and scale the image to the same [0, 1] RGBA space.
    heatmap_rgba = color_scheme(heatmap)
    image_array = np.array(image.convert("RGBA")) / 255.0

    # Linear blend of image and heatmap.
    blended = image_array * (1 - overlay_strength) + heatmap_rgba * overlay_strength

    # Round and clamp before the uint8 cast. Plain truncation biases pixels
    # downward, and out-of-range values (e.g. overlay_strength outside [0, 1])
    # would otherwise wrap around.
    pixels = np.clip(np.rint(blended * 255), 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)

@st.cache_data(show_spinner=False, max_entries=16)
def _cached_attention_map(image_bytes):
    """Compute the attention map for an uploaded image, memoised by its bytes.

    Streamlit reruns this whole script on every widget change (slider, colour
    scheme). Without caching, the ViT forward pass would repeat on each
    rerun. Keying on the raw file bytes makes those reruns hit the cache,
    and ``max_entries`` bounds the memory the cache can hold.
    """
    img = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    return get_attention_map(img)


st.title("ViewViz")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

image = None
if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file).convert('RGB')
    except OSError:
        # Pillow signals undecodable files with UnidentifiedImageError, an
        # OSError subclass; report it instead of showing a raw traceback.
        st.error("Could not read the uploaded file as an image.")

if image is not None:
    # Show progress while the model runs, instead of announcing success
    # before any work has happened.
    with st.spinner("Starting Prediction Process..."):
        attention_map = _cached_attention_map(uploaded_file.getvalue())

    col1, col2 = st.columns(2)

    with col1:
        overlay_strength = st.slider("Heatmap Overlay Percentage", 0, 100, 50) / 100.0

    with col2:
        color_scheme_name = st.selectbox("Choose Heatmap Color Scheme", list(COLOR_SCHEMES.keys()))

    color_scheme = COLOR_SCHEMES[color_scheme_name]

    overlayed_image = overlay_attention_map(image, attention_map, overlay_strength, color_scheme)

    # NOTE(review): newer Streamlit releases deprecate use_column_width in
    # favour of use_container_width; switch once the minimum supported
    # Streamlit version allows it.
    st.image(overlayed_image, caption='Image with Heatmap Overlay', use_column_width=True)

    # Offer the blended image as a PNG download.
    buf = io.BytesIO()
    overlayed_image.save(buf, format="PNG")
    st.download_button(
        label="Download Image with Attention Map",
        data=buf.getvalue(),
        file_name="attention_map_overlay.png",
        mime="image/png"
    )

streamlit_analytics.stop_tracking()