import json

from Levenshtein import distance
import streamlit as st
import numpy as np
import plotly.express as px
from sklearn.decomposition import PCA


def load_data():
    """Load the precomputed embeddings and the matching word list.

    Both files are read from paths relative to the current working directory.

    Returns:
        A ``(embeddings, words)`` tuple: the array stored in
        ``data/simplesegmentT5_embeddings.npy`` and the decoded contents of
        ``data/words.json``.
    """
    embeddings = np.load("data/simplesegmentT5_embeddings.npy")

    # A context manager closes the handle deterministically instead of leaving
    # it to the garbage collector; JSON text is UTF-8, so do not rely on the
    # platform's default encoding.
    with open("data/words.json", "r", encoding="utf-8") as f:
        words = json.load(f)

    return embeddings, words


def project_embeddings(embeddings):
    """Reduce ``embeddings`` to three components with PCA.

    A fresh ``PCA(n_components=3)`` is fitted on the given rows each call, and
    the transformed rows are returned.
    """
    return PCA(n_components=3).fit_transform(embeddings)


def filter_words(words, remove_capitalized, length):
    """Return the indices of the words that pass the active filters.

    Args:
        words: Sequence of strings to filter.
        remove_capitalized: When truthy, drop every word that is not equal to
            its own lowercased form.
        length: Indexable pair ``(minimum, maximum)``; both bounds are
            inclusive.

    Returns:
        Indices into ``words`` of the kept entries, in ascending order.
    """

    def keep(word):
        # A word that changes under lower() contains an uppercase character.
        if remove_capitalized and word != word.lower():
            return False
        return length[0] <= len(word) <= length[1]

    return [position for position, word in enumerate(words) if keep(word)]


def color_length(words):
    """Return one color value per word: its character count."""
    return list(map(len, words))


def color_first_letter(words):
    """Return one color value per word based on its first character.

    The first character of the lowercased word is mapped through
    ``(ord(char) - 97) / 26`` and the result is clamped to the range 0..1.
    """
    values = []
    for word in words:
        offset = ord(word.lower()[0]) - 97  # 97 == ord("a")
        fraction = offset / 26
        # Clamp: characters before "a" give 0, characters far past "z" give 1.
        values.append(min(1, max(0, fraction)))
    return values


def color_levenshtein(words, reference_index=4):
    """Return one color value per word: its edit distance to a reference word.

    Args:
        words: Sequence of strings to score.
        reference_index: Position in ``words`` of the reference word. The
            default keeps the original choice of the fifth word. An index past
            the end is clamped to the last word, so short lists (for example
            after aggressive filtering) no longer raise ``IndexError``.

    Returns:
        A list with ``Levenshtein.distance(word, reference)`` for every word,
        or an empty list when ``words`` is empty.
    """
    if not words:
        return []

    # Look the reference up once, clamped to the available range.
    reference = words[min(reference_index, len(words) - 1)]
    return [distance(w, reference) for w in words]


def plot_scatter(words, embeddings, remove_capitalized, length, color_select):
    """Build the 3D scatter figure for the words that pass the filters.

    Args:
        words: All words, index-aligned with the rows of ``embeddings``.
        embeddings: Array indexed by row with the kept indices.
        remove_capitalized: Forwarded to ``filter_words``.
        length: ``(minimum, maximum)`` word length, forwarded to
            ``filter_words``.
        color_select: ``"Word length"`` colors points by word length; any other
            value colors them with ``color_levenshtein``.

    Returns:
        The Plotly figure produced by ``px.scatter_3d``.
    """
    kept = filter_words(words, remove_capitalized, length)
    kept_words = [words[i] for i in kept]

    # PCA is fitted on the kept rows only, so the axes depend on the filter.
    coords = project_embeddings(embeddings[kept])

    color_fn = color_length if color_select == "Word length" else color_levenshtein
    point_colors = color_fn(kept_words)

    figure = px.scatter_3d(
        x=coords[:, 0],
        y=coords[:, 1],
        z=coords[:, 2],
        width=800,
        height=600,
        color=point_colors,
        color_continuous_scale=px.colors.sequential.Viridis,
        hover_name=kept_words,
        title="SimpleSegmentT5 Embeddings",
    )

    marker_style = {"size": 6, "line": {"width": 2}}
    marker_traces = {"mode": "markers"}
    figure.update_traces(marker=marker_style, selector=marker_traces)

    return figure


def main():
    """Render the app: sidebar controls plus the filtered embedding scatter plot.

    Loads the data, reads the three sidebar settings and passes them to
    ``plot_scatter``, whose figure is shown with ``st.plotly_chart``.
    """
    embeddings, words = load_data()

    # The previous version also ran a PCA over all embeddings and built a
    # second figure here that was never displayed; that dead work is removed.

    st.sidebar.title("Settings")

    # NOTE(review): the widget key says "include" while the label says
    # "Remove"; the key is left unchanged because it is a runtime identifier.
    remove_checkbox = st.sidebar.checkbox(
        "Remove capitalized words",
        value=True,
        key="include_capitalized",
    )

    length_slider = st.sidebar.slider("Word length", 3, 9, (3, 9))
    color_select = st.sidebar.radio(
        "Color by",
        ["Word length", "Levenshtein distance to random word"],
    )

    st.plotly_chart(
        plot_scatter(words, embeddings, remove_checkbox, length_slider, color_select)
    )


# Run the app only when this file is executed as a script, not on import.
if __name__ == "__main__":

    main()