import streamlit as st
import pickle

import io
from typing import List, Optional

import markdown
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import plotly.graph_objects as go
import streamlit as st
from plotly import express as px
from plotly.subplots import make_subplots
from tqdm import trange

import torch
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

@st.cache(allow_output_mutation=True)
# @st.cache_resource
def load_dataset(data_index):
    """Unpickle and return one preprocessed validation shard.

    The shard is read from
    ``./data/preprocessed_image_net/val_data_{data_index}.pkl``.

    NOTE: ``pickle.load`` can execute arbitrary code while loading, so the
    shard files must come from a trusted source.
    """
    shard_path = f'./data/preprocessed_image_net/val_data_{data_index}.pkl'
    with open(shard_path, 'rb') as shard_file:
        return pickle.load(shard_file)

# suppress_st_warning=True: this cached function writes Streamlit elements
# (placeholders, text, progress bar), which st.cache otherwise warns about.
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
# @st.cache_resource
def load_dataset_dict():
    """Load every dataset shard and return them keyed by shard index.

    While loading, a "Loading datasets..." message and a progress bar are
    shown; both placeholders are cleared again before returning, including
    when loading a shard raises.

    Returns:
        dict mapping each shard index (0..4) to the object returned by
        ``load_dataset`` for that index.
    """
    num_shards = 5
    progress_empty = st.empty()
    text_empty = st.empty()
    text_empty.write("Loading datasets...")
    progress_bar = progress_empty.progress(0.0)
    dataset_dict = {}
    try:
        for data_index in trange(num_shards):
            dataset_dict[data_index] = load_dataset(data_index)
            progress_bar.progress((data_index + 1) / num_shards)
    finally:
        # Remove the temporary widgets even if a shard failed to load.
        progress_empty.empty()
        text_empty.empty()
    return dataset_dict


# @st.cache_data
@st.cache(allow_output_mutation=True)
def load_image(image_id):
    """Return the dataset entry for ``image_id``.

    The id is split into a shard index (``image_id // 10000``), which selects
    the shard via ``load_dataset``, and a key within that shard
    (``image_id % 10000``).
    """
    shard_size = 10000
    shard = load_dataset(image_id // shard_size)
    return shard[image_id % shard_size]

# @st.cache_data
@st.cache(allow_output_mutation=True)
def load_images(image_ids):
    """Return a list with ``load_image(image_id)`` for each id, in order."""
    return [load_image(image_id) for image_id in image_ids]


@st.cache(allow_output_mutation=True, suppress_st_warning=True, show_spinner=False)
# @st.cache_resource
def load_model(model_name):
    """Load a pretrained image classifier and put it in eval mode.

    ``'ResNet'`` and ``'ConvNeXt'`` are loaded through ``transformers``
    together with their feature extractor.  Any other name loads
    ``mobilenet_v2`` through ``torch.hub`` and has no feature extractor.

    Returns:
        Tuple ``(model, feature_extractor)``; ``feature_extractor`` is
        ``None`` for the ``torch.hub`` fallback.
    """
    # (model_name, checkpoint) pairs for the models served via transformers.
    hf_checkpoints = (
        ('ResNet', 'microsoft/resnet-50'),
        ('ConvNeXt', 'facebook/convnext-tiny-224'),
    )
    with st.spinner(f"Loading {model_name} model! This process might take 1-2 minutes..."):
        model_file_path = next(
            (path for name, path in hf_checkpoints if model_name == name),
            None,
        )
        if model_file_path is not None:
            feature_extractor = AutoFeatureExtractor.from_pretrained(model_file_path, crop_pct=1.0)
            model = AutoModelForImageClassification.from_pretrained(model_file_path)
        else:
            # Every unrecognised name falls back to MobileNetV2.
            model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
            feature_extractor = None
        model.eval()
    return model, feature_extractor


def make_grid(cols=None, rows=None):
    """Create ``rows`` Streamlit containers, each split by ``st.columns(cols)``.

    Returns a list with one ``st.columns(cols)`` result per row.  ``rows``
    must be an integer; the ``None`` default raises ``TypeError``.
    """
    grid = []
    for _ in range(rows):
        with st.container():
            grid.append(st.columns(cols))
    return grid


def use_container_width_percentage(percentage_width: int = 75):
    """Inject CSS that sets the block container's ``max-width`` in percent.

    The rule targets ``.reportview-container .main .block-container`` and is
    written to the page with ``st.markdown(..., unsafe_allow_html=True)``.
    """
    width_rule = f"max-width: {percentage_width}%;"
    # The triple braces produce a literal "{", the rule, and a literal "}".
    css = f""" 
                <style> 
                .reportview-container .main .block-container{{{width_rule}}}
                </style>    
                """
    st.markdown(css, unsafe_allow_html=True)

# Select matplotlib's "Agg" backend for the whole module.
matplotlib.use("Agg")
# Module-level theme colours.  They are the default arguments of
# Grid.__init__, are interpolated into the CSS written by
# _set_block_container_style, and are reassigned by
# select_block_container_style when "Dark Theme?" is ticked.
COLOR = "#31333f"
BACKGROUND_COLOR = "#ffffff"


def grid_demo():
    """Render the "Layout and Style Experiments" demo page.

    Writes the sidebar title and header and the intro markdown, then the
    block-container style controls and the resources section, and finally a
    seven-cell ``Grid`` holding markdown, text, a dataframe and two plots.

    NOTE(review): ``get_plotly_fig``, ``get_dataframe`` and
    ``get_matplotlib_plt`` appear in this file only as commented-out code;
    unless they are defined elsewhere, the grid section raises ``NameError``.
    """
    sidebar = st.sidebar
    sidebar.title("Layout and Style Experiments")
    sidebar.header("Settings")

    intro = """
# Layout and Style Experiments

The basic question is: Can we create a multi-column dashboard with plots, numbers and text using
the [CSS Grid](https://gridbyexample.com/examples)?

Can we do it with a nice api?
Can have a dark theme?
"""
    st.markdown(intro)

    select_block_container_style()
    add_resources_section()

    # The theme globals are read only after select_block_container_style()
    # has run, so a ticked "Dark Theme?" box is already reflected in them.
    with Grid("1 1 1", color=COLOR, background_color=BACKGROUND_COLOR) as grid:
        # Positional arguments: class, column start, column end, row start,
        # row end.
        grid.cell("a", 2, 3, 1, 2).markdown("# This is A Markdown Cell")

        text_cell = grid.cell("b", 2, 3, 2, 3)
        text_cell.text("The cell to the left is a dataframe")

        plotly_cell = grid.cell("c", 3, 4, 2, 3)
        plotly_cell.plotly_chart(get_plotly_fig())

        frame_cell = grid.cell("d", 1, 2, 1, 3)
        frame_cell.dataframe(get_dataframe())

        hint_cell = grid.cell("e", 3, 4, 1, 2)
        hint_cell.markdown("Try changing the **block container style** in the sidebar!")

        caption_cell = grid.cell("f", 1, 3, 3, 4)
        caption_cell.text("The cell to the right is a matplotlib svg image")

        svg_cell = grid.cell("g", 3, 4, 3, 4)
        svg_cell.pyplot(get_matplotlib_plt())


def add_resources_section():
    """Adds a resources section to the sidebar"""
    sidebar = st.sidebar
    sidebar.header("Add_resources_section")
    # There must be no space between "]" and "(": with one, Markdown renders
    # the bracketed text and the URL separately instead of as a link.
    sidebar.markdown(
        """
- [gridbyexample.com](https://gridbyexample.com/examples/)
"""
    )


class Cell:
    """A Cell can hold text, markdown, plots etc.

    Each cell stores a CSS class name, its grid placement (column and row
    start/end lines) and an ``inner_html`` string.  The content methods
    (``text``, ``markdown``, ``dataframe``, ``plotly_chart``, ``pyplot``)
    replace ``inner_html``; ``_to_style`` and ``_to_html`` emit the CSS rule
    and the ``<div>`` for the cell.
    """

    def __init__(
        self,
        class_: Optional[str] = None,
        grid_column_start: Optional[int] = None,
        grid_column_end: Optional[int] = None,
        grid_row_start: Optional[int] = None,
        grid_row_end: Optional[int] = None,
    ):
        self.class_ = class_
        self.grid_column_start = grid_column_start
        self.grid_column_end = grid_column_end
        self.grid_row_start = grid_row_start
        self.grid_row_end = grid_row_end
        self.inner_html = ""

    def _to_style(self) -> str:
        """Return the CSS rule that places this cell's class on the grid."""
        return f"""
.{self.class_} {{
    grid-column-start: {self.grid_column_start};
    grid-column-end: {self.grid_column_end};
    grid-row-start: {self.grid_row_start};
    grid-row-end: {self.grid_row_end};
}}
"""

    def text(self, text: str = ""):
        """Use ``text`` verbatim as the cell content (it is not escaped)."""
        self.inner_html = text

    def markdown(self, text):
        """Convert ``text`` from Markdown to HTML and use it as the content."""
        self.inner_html = markdown.markdown(text)

    def dataframe(self, dataframe: pd.DataFrame):
        """Use the HTML table produced by ``dataframe.to_html()`` as content."""
        self.inner_html = dataframe.to_html()

    def plotly_chart(self, fig):
        """Store a placeholder snippet that embeds ``fig.to_json()``.

        As the snippet's own text says, the plot is not expected to render.
        """
        self.inner_html = f"""
<script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
<body>
    <p>This should have been a plotly plot.
    But since *script* tags are removed when inserting MarkDown/ HTML i cannot get it to workto work.
    But I could potentially save to svg and insert that.</p>
    <div id='divPlotly'></div>
    <script>
        var plotly_data = {fig.to_json()}
        Plotly.react('divPlotly', plotly_data.data, plotly_data.layout);
    </script>
</body>
"""

    def pyplot(self, fig=None, **kwargs):
        """Render a matplotlib figure to inline SVG and use it as the content.

        Args:
            fig: Figure to render.  When ``None``, pyplot's current figure is
                used.
            **kwargs: Accepted but ignored.

        The rendered figure is closed afterwards.
        """
        # Save the figure the caller passed in, not whichever one is current.
        figure = plt.gcf() if fig is None else fig
        buffer = io.StringIO()
        # Only arguments savefig supports are passed; an unknown keyword is
        # rejected by recent matplotlib releases.
        figure.savefig(buffer, format="svg")
        plt.close(figure)

        svg = buffer.getvalue()
        # Drop the XML/DOCTYPE preamble by locating the root element; the
        # preamble length is not the same in every matplotlib release.
        svg_start = svg.find("<svg")
        if svg_start > 0:
            svg = svg[svg_start:]
        self.inner_html = '<div height="200px">' + svg + "</div>"

    def _to_html(self):
        """Return the cell as a ``<div>`` with classes ``box`` and ``class_``."""
        return f"""<div class="box {self.class_}">{self.inner_html}</div>"""


class Grid:
    """A (CSS) Grid

    Used as a context manager: cells are registered with ``cell()`` inside
    the ``with`` block, and on exit the grid stylesheet, the per-cell
    stylesheet and the cell markup are written with ``st.markdown``.

    Note that the stylesheet emits ``color`` as the CSS ``background-color``
    and ``background_color`` as the CSS ``color``.
    """

    def __init__(
        self,
        template_columns="1 1 1",
        gap="10px",
        background_color=COLOR,
        color=BACKGROUND_COLOR,
    ):
        self.cells: List[Cell] = []
        self.template_columns = template_columns
        self.gap = gap
        self.color = color
        self.background_color = background_color

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Emit the three fragments in order: grid CSS, cell CSS, cell HTML.
        fragment_builders = (
            self._get_grid_style,
            self._get_cells_style,
            self._get_cells_html,
        )
        for build_fragment in fragment_builders:
            st.markdown(build_fragment(), unsafe_allow_html=True)

    def _get_grid_style(self):
        """Return the ``<style>`` element for the wrapper, boxes and tables."""
        css = f"""
<style>
    .wrapper {{
    display: grid;
    grid-template-columns: {self.template_columns};
    grid-gap: {self.gap};
    background-color: {self.color};
    color: {self.background_color};
    }}
    .box {{
    background-color: {self.color};
    color: {self.background_color};
    border-radius: 0px;
    padding: 0px;
    font-size: 100%;
    text-align: center;
    }}
    table {{
        color: {self.color}
    }}
</style>
"""
        return css

    def _get_cells_style(self):
        """Return one ``<style>`` element holding every cell's CSS rule."""
        rules = "\n".join(cell._to_style() for cell in self.cells)
        return f"<style>{rules}</style>"

    def _get_cells_html(self):
        """Return the wrapper ``<div>`` containing every cell's markup."""
        boxes = "\n".join(cell._to_html() for cell in self.cells)
        return f'<div class="wrapper">{boxes}</div>'

    def cell(
        self,
        class_: str = None,
        grid_column_start: Optional[int] = None,
        grid_column_end: Optional[int] = None,
        grid_row_start: Optional[int] = None,
        grid_row_end: Optional[int] = None,
    ):
        """Create a ``Cell``, register it with this grid and return it."""
        new_cell = Cell(
            class_,
            grid_column_start,
            grid_column_end,
            grid_row_start,
            grid_row_end,
        )
        self.cells.append(new_cell)
        return new_cell


def select_block_container_style():
    """Add sidebar controls for the main block container's style.

    Renders widgets for max-width, dark theme and the four paddings, updates
    the module-level ``COLOR`` / ``BACKGROUND_COLOR`` to match the theme
    checkbox, and applies the result via ``_set_block_container_style``.
    """
    global COLOR, BACKGROUND_COLOR

    sidebar = st.sidebar
    sidebar.header("Block Container Style")
    max_width_100_percent = sidebar.checkbox("Max-width: 100%?", False)
    if not max_width_100_percent:
        max_width = sidebar.slider("Select max-width in px", 100, 2000, 1200, 100)
    else:
        # Not used for the CSS when 100% is selected; kept as a valid value.
        max_width = 1200
    dark_theme = sidebar.checkbox("Dark Theme?", False)
    padding_top = sidebar.number_input("Select padding top in rem", 0, 200, 5, 1)
    padding_right = sidebar.number_input("Select padding right in rem", 0, 200, 1, 1)
    padding_left = sidebar.number_input("Select padding left in rem", 0, 200, 1, 1)
    padding_bottom = sidebar.number_input(
        "Select padding bottom in rem", 0, 200, 10, 1
    )

    if dark_theme:
        BACKGROUND_COLOR = "rgb(17,17,17)"
        COLOR = "#fff"
    else:
        # Restore the light theme explicitly: module globals survive between
        # calls, so without this an unticked box would keep the dark colours.
        BACKGROUND_COLOR = "#ffffff"
        COLOR = "#31333f"

    _set_block_container_style(
        max_width,
        max_width_100_percent,
        padding_top,
        padding_right,
        padding_left,
        padding_bottom,
    )


def _set_block_container_style(
    max_width: int = 1200,
    max_width_100_percent: bool = False,
    padding_top: int = 5,
    padding_right: int = 1,
    padding_left: int = 1,
    padding_bottom: int = 10,
):
    """Write a ``<style>`` element for the main block container.

    Sets ``max-width`` (``100%`` when ``max_width_100_percent`` is true,
    otherwise ``max_width`` in px) and the four paddings in rem on
    ``.reportview-container .main .block-container``, and applies the
    module-level ``COLOR`` / ``BACKGROUND_COLOR`` to
    ``.reportview-container .main``.
    """
    max_width_rule = (
        "max-width: 100%;"
        if max_width_100_percent
        else f"max-width: {max_width}px;"
    )
    css = f"""
<style>
    .reportview-container .main .block-container{{
        {max_width_rule}
        padding-top: {padding_top}rem;
        padding-right: {padding_right}rem;
        padding-left: {padding_left}rem;
        padding-bottom: {padding_bottom}rem;
    }}
    .reportview-container .main {{
        color: {COLOR};
        background-color: {BACKGROUND_COLOR};
    }}
</style>
"""
    st.markdown(css, unsafe_allow_html=True)


# @st.cache
# def get_dataframe() -> pd.DataFrame():
#     """Dummy DataFrame"""
#     data = [
#         {"quantity": 1, "price": 2},
#         {"quantity": 3, "price": 5},
#         {"quantity": 4, "price": 8},
#     ]
#     return pd.DataFrame(data)


# def get_plotly_fig():
#     """Dummy Plotly Plot"""
#     return px.line(data_frame=get_dataframe(), x="quantity", y="price")


# def get_matplotlib_plt():
#     get_dataframe().plot(kind="line", x="quantity", y="price", figsize=(5, 3))