File size: 2,617 Bytes
74ce942
 
 
 
 
 
 
d5f15cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74ce942
ea72d75
74ce942
 
 
 
 
 
 
356174d
74ce942
ea72d75
 
 
74ce942
356174d
ea72d75
 
 
74ce942
 
 
ea72d75
74ce942
 
 
 
 
 
 
 
 
d5f15cb
74ce942
 
 
 
d5f15cb
ea72d75
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from logging import getLogger
from pathlib import Path

import pandas as pd
import plotly.graph_objects as go
import streamlit as st

from utilities import initialization

# Runs at module level on every execution of this script, before the title
# and main() are rendered.
# Presumably populates st.session_state (main() reads `model` and
# `topic_str_to_word` from it) — TODO confirm in utilities.initialization.
initialization()


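# NOTE: legacy session-state loader, kept for reference only. Its call in the
# __main__ block is commented out, and `initialization()` from `utilities`
# is invoked at module level instead.
# @st.cache(show_spinner=False)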
# def initialize_state():
#     with st.spinner("Loading app..."):
#         if 'model' not in st.session_state:
#             model = Top2Vec.load('models/model.pkl')
#             model._check_model_status()
#             model.hierarchical_topic_reduction(num_topics=20)
#
#             st.session_state.model = model
#             st.session_state.umap_model = joblib.load(proj_dir / 'models' / 'umap.sav')
#             logger.info("loading data...")
#
#         if 'data' not in st.session_state:
#             logger.info("loading data...")
#             data = pd.read_csv(proj_dir / 'data' / 'data.csv')
#             data['topic_id'] = data['topic_id'].apply(lambda x: f'{x:02d}')
#             st.session_state.data = data
#             st.session_state.selected_data = data
#             st.session_state.all_topics = list(data.topic_id.unique())
#
#         if 'topics' not in st.session_state:
#             logger.info("loading topics...")
#             topics = pd.read_csv(proj_dir / 'data' / 'topics.csv')
#             topics['topic_id'] = topics['topic_id'].apply(lambda x: f'{x:02d}')
#             st.session_state.topics = topics


def main():
    """Render the per-topic word chart.

    Writes a short introduction, reads a topic index (0-19) from a sidebar
    slider, and plots a horizontal bar chart whose values come from
    ``st.session_state.model.topic_word_scores_reduced`` and whose labels
    come from ``st.session_state.model.topic_words_reduced`` for that topic.
    The chart title looks up the topic's label in
    ``st.session_state.topic_str_to_word`` using the zero-padded index.
    """
    intro_text = (
        " \n"
        "    A way to dive into each topic. Use the slider on the left to choose the topic.\n"
        "\n"
        "    The `y` axis shows which words are closest to a topic centroid."
        " The `x` axis shows how correlated they are."
    )
    st.write(intro_text)

    # NOTE(review): the 0-19 range is hard-coded; presumably the model holds
    # 20 reduced topics — confirm against the initialization code.
    topic_idx = st.sidebar.slider("Topic Number", 0, 19, value=0)
    topic_key = f"{topic_idx:02}"  # two-digit, zero-padded key, e.g. "07"

    model = st.session_state.model
    # Both sequences are reversed before plotting; presumably so the first
    # entry ends up at the top of the horizontal chart — TODO confirm.
    word_bars = go.Bar(
        x=model.topic_word_scores_reduced[topic_idx][::-1],
        y=model.topic_words_reduced[topic_idx][::-1],
        orientation='h',
    )
    fig = go.Figure(word_bars)

    topic_label = st.session_state.topic_str_to_word[topic_key]
    fig.update_layout(
        title=f'Words for Topic {topic_key}: {topic_label}',
        yaxis_title='Top 20 topic words',
        xaxis_title='Distance to topic centroid',
    )

    # Second argument of st.plotly_chart is use_container_width.
    st.plotly_chart(fig, use_container_width=True)


if __name__ == "__main__":
    # Script entry point: set up module globals, display options and the
    # page title, then render the page body via main().

    # Logger named after this module, and the project root taken as the
    # directory two levels above the folder containing this file.
    logger = getLogger(__name__)
    proj_dir = Path(__file__).parents[2]

    # For max width tables
    # NOTE(review): pandas documents None as the "unlimited" value for
    # display.max_colwidth; confirm that 0 gives the intended behavior on the
    # pandas version in use.
    pd.set_option('display.max_colwidth', 0)

    # Streamlit settings
    # st.set_page_config(layout="wide")
    # The title previously ended in mojibake (the UTF-8 bytes of the books
    # emoji decoded with the wrong codec); the emoji itself is restored here.
    md_title = "# Topic Explorer 📚"
    st.markdown(md_title)
    st.sidebar.markdown(md_title)

    # initialize_state()
    main()