top2vec / app /pages /02_Document_Explorer_📖.py
derek-thomas
Updating topic_word in AgGrid
176bc83
from logging import getLogger
from pathlib import Path
import pandas as pd
import plotly.express as px
import streamlit as st
from st_aggrid import AgGrid, ColumnsAutoSizeMode, GridOptionsBuilder
from streamlit_plotly_events import plotly_events
from utilities import initialization
# Run the shared setup imported from ``utilities`` before any page code.
# NOTE(review): presumably this populates the ``st.session_state`` entries
# read below (``data``, ``topics``, ``topic_str_to_word``) -- confirm in
# ``utilities``.
initialization()
# @st.cache(show_spinner=False)
# def initialize_state():
# with st.spinner("Loading app..."):
# if 'model' not in st.session_state:
# model = Top2Vec.load('models/model.pkl')
# model._check_model_status()
# model.hierarchical_topic_reduction(num_topics=20)
#
# st.session_state.model = model
# st.session_state.umap_model = joblib.load(proj_dir / 'models' / 'umap.sav')
# logger.info("loading data...")
#
# if 'data' not in st.session_state:
# logger.info("loading data...")
# data = pd.read_csv(proj_dir / 'data' / 'data.csv')
# data['topic_id'] = data['topic_id'].apply(lambda x: f'{x:02d}')
# st.session_state.data = data
# st.session_state.selected_data = data
# st.session_state.all_topics = list(data.topic_id.unique())
#
# if 'topics' not in st.session_state:
# logger.info("loading topics...")
# topics = pd.read_csv(proj_dir / 'data' / 'topics.csv')
# topics['topic_id'] = topics['topic_id'].apply(lambda x: f'{x:02d}')
# st.session_state.topics = topics
def reset():
    """Clear the current selection and show the full dataset again.

    Logs the reset, empties ``selected_points`` and points ``selected_data``
    back at the complete ``data`` frame held in the session state.

    NOTE(review): a second ``reset`` defined further down this file rebinds
    the name, so this version is not the one the Reset button ends up calling.
    """
    logger.info("Resetting...")
    state = st.session_state
    state.selected_points = []
    state.selected_data = state.data
def filter_df():
    """Narrow ``st.session_state.selected_data`` to the points picked in the plot.

    Reads the selection stored in ``st.session_state.selected_points`` (a list
    of point records carrying ``x`` and ``y``), and inner-joins the full
    ``st.session_state.data`` frame against those coordinates. When nothing is
    selected, ``selected_data`` is left untouched and only a log line is
    written.
    """
    selected_points = st.session_state.selected_points
    if not selected_points:
        logger.info("Lame")
        return

    # De-duplicate the coordinates before joining: if the same (x, y) pair
    # occurs more than once in the selection, an inner merge would repeat
    # every matching document once per occurrence.
    points_df = pd.DataFrame(selected_points).loc[:, ['x', 'y']].drop_duplicates()
    st.session_state.selected_data = st.session_state.data.merge(points_df, on=['x', 'y'])
    logger.info("Updates selected_data: %d", len(st.session_state.selected_data))
def reset():
    """Clear the current selection and show the full dataset again.

    Used as the ``on_click`` callback of the "Reset" button. This definition
    rebinds the name ``reset`` declared earlier in the file, so it is the one
    that actually runs; it therefore emits the same "Resetting..." log line as
    the earlier definition instead of silently dropping it.
    """
    logger.info("Resetting...")
    st.session_state.selected_data = st.session_state.data
    st.session_state.selected_points = []
def main():
    """Render the Document Explorer page.

    Writes the intro text and a Reset button, then draws a 2-D scatter plot of
    all documents coloured by topic. Points picked with the lasso/box tools are
    stored in ``st.session_state.selected_points`` and ``filter_df`` narrows
    ``st.session_state.selected_data`` accordingly. Two tabs then show the
    selection in AgGrid tables: "Docs" lists the selected documents, "Topics"
    lists how many selected documents fall into each topic.
    """
    st.write("""
# Topic Modeling
This shows a 2d representation of documents embeded in a semantic space. Each dot is a document
and the dots close represent documents that are close in meaning.
Zoom in and explore a topic of your choice. You can see the documents you select with the `lasso` or `box`
tool below in the corresponding tabs."""
             )

    st.button("Reset", help="Will Reset the selected points and the selected topics", on_click=reset)

    # Sorted so the legend is ordered: https://bioinformatics.stackexchange.com/a/18847
    data_to_model = st.session_state.data.sort_values(by='topic_id', ascending=True)
    # Assign the replaced column back rather than calling ``replace`` with
    # ``inplace=True`` on a column selection: the chained in-place form is not
    # guaranteed to update the parent frame (and does not under copy-on-write).
    data_to_model['topic_id'] = data_to_model['topic_id'].replace(st.session_state.topic_str_to_word)

    fig = px.scatter(data_to_model, x='x', y='y', color='topic_id', template='plotly_dark',
                     hover_data=['id', 'topic_id', 'x', 'y'])
    st.session_state.selected_points = plotly_events(fig, select_event=True, click_event=False)

    # One call is enough: nothing below changes the selection, so both tabs
    # read the same ``selected_data``.
    filter_df()
    has_selection = bool(st.session_state.selected_points)
    empty_hint = 'Select points in the graph with the `lasso` or `box` select tools to populate this table.'

    def get_topics_counts() -> pd.DataFrame:
        """Return per-topic counts of the selected documents joined with the topics table.

        The result has ``topic_id`` first, then ``topic_count``, then the
        remaining columns of ``st.session_state.topics``.
        """
        # Name the index and the value column explicitly so the result does not
        # depend on how ``value_counts`` labels its output (that labelling
        # differs between pandas 1.x and 2.x).
        counts = (
            st.session_state.selected_data["topic_id"]
            .value_counts()
            .rename_axis("topic_id")
            .reset_index(name="topic_count")
        )
        merged = counts.merge(st.session_state.topics, on='topic_id')
        cols = ['topic_id'] + [col for col in merged.columns if col != 'topic_id']
        return merged[cols]

    tab1, tab2 = st.tabs(["Docs", "Topics"])

    with tab1:
        if has_selection:
            # Work on a copy so adding ``topic_word`` never writes into a
            # slice of the frame kept in the session state.
            docs = st.session_state.selected_data[['id', 'topic_id', 'documents']].copy()
            docs['topic_word'] = docs['topic_id'].replace(st.session_state.topic_str_to_word)
            docs = docs[['id', 'topic_id', 'topic_word', 'documents']]

            builder = GridOptionsBuilder.from_dataframe(docs)
            builder.configure_pagination()
            AgGrid(docs, theme='streamlit', gridOptions=builder.build(),
                   columns_auto_size_mode=ColumnsAutoSizeMode.FIT_CONTENTS)
        else:
            st.markdown(empty_hint)

    with tab2:
        if has_selection:
            topic_counts = get_topics_counts()[['topic_id', 'topic_count', 'topic_0']]

            builder = GridOptionsBuilder.from_dataframe(topic_counts)
            builder.configure_pagination()
            builder.configure_column('topic_0', header_name='Topic Word', wrap_text=True)
            AgGrid(topic_counts, theme='streamlit', gridOptions=builder.build(),
                   columns_auto_size_mode=ColumnsAutoSizeMode.FIT_ALL_COLUMNS_TO_VIEW)
        else:
            st.markdown(empty_hint)
if __name__ == "__main__":
    # Setting up Logger and proj_dir. ``logger`` is bound here as a module
    # global; reset() and filter_df() look it up when they are called.
    logger = getLogger(__name__)
    proj_dir = Path(__file__).parents[2]

    # For max width tables
    pd.set_option('display.max_colwidth', 0)

    # Streamlit settings
    # st.set_page_config(layout="wide")
    # The title ends with the open-book emoji; it had been mangled into a
    # mojibake byte sequence, which would be shown verbatim on the page.
    md_title = "# Document Explorer 📖"
    st.markdown(md_title)
    st.sidebar.markdown(md_title)

    # initialize_state()
    main()