import streamlit as st
import pandas as pd
import numpy as np
import os
import json
import gzip
import re
from urllib.parse import quote, unquote
# Updated CSS styles to use default background
# NOTE(review): the stylesheet body appears to have been stripped from this copy;
# CUSTOM_CSS is injected verbatim via st.markdown(..., unsafe_allow_html=True) in main().
CUSTOM_CSS = """
"""
def get_hierarchy_files():
    """Return the hierarchy JSON file names available under ``hierarchies/``.

    Both plain ``*.json`` files and gzip-compressed ``*.json.gz`` files are
    reported, the latter under their uncompressed name so that
    ``load_hierarchy_data`` (which falls back to ``<name>.gz``) can open them.

    Returns:
        A sorted, de-duplicated list of ``*.json`` file names; empty when the
        directory does not exist.
    """
    hierarchy_dir = 'hierarchies'
    if not os.path.exists(hierarchy_dir):
        return []
    names = set()
    for entry in os.listdir(hierarchy_dir):
        if entry.endswith('.json'):
            names.add(entry)
        elif entry.endswith('.json.gz'):
            # Report compressed hierarchies under their uncompressed name;
            # load_hierarchy_data() will locate the ".gz" sibling itself.
            names.add(entry[:-len('.gz')])
    # Sorted for a deterministic selectbox order across reruns.
    return sorted(names)
def parse_filename(filename):
    """Parse a hierarchy filename to extract metadata.

    Expected layout (underscore-separated, after stripping ``.json``)::

        <prefix>_<date>_<embedder>_<summarizer>_<clustermethod>_emb_<contribution>
            [_<building-method...>[_<levels>[_<seed>]]]

    ``building-method`` is either a single token (e.g. ``bidirectional``) or
    one of the two-token forms ``top down`` / ``bottom up``, which shift the
    positions of the levels and seed fields by one. Missing trailing fields
    degrade to ``'Unknown'`` rather than raising.

    Returns:
        dict with keys: date, embedder, summarizer, clustermethod,
        contribution_type, building_method, clusterlevel, clusterlevel_array,
        level_count, random_seed.
    """
    base = filename.replace('.json', '')
    parts = base.split('_')
    # parts[6] (contribution type) is the deepest unconditional access below,
    # so anything shorter than 7 parts cannot be parsed.
    # (Bug fix: the original checked `len(parts) < 6`, which let a 6-part
    # name through and crashed with IndexError on parts[6].)
    if len(parts) < 7:
        return {
            'date': 'Unknown',
            'embedder': 'Unknown',
            'summarizer': 'Unknown',
            'clustermethod': 'Unknown',
            'contribution_type': 'Unknown',
            'building_method': 'Unknown',
            'clusterlevel': 'Unknown',
            'clusterlevel_array': [],
            'level_count': 0,
            'random_seed': 'Unknown'
        }
    # These positions are consistent across formats.
    date_str = parts[1]
    embedder = parts[2]
    summarizer = parts[3]
    clustermethod = parts[4]
    # parts[5] is typically an "emb" placeholder
    contribution_type = parts[6]
    building_method = None
    clusterlevel_str = None
    seed = None
    if len(parts) > 7:
        if parts[7] == "bidirectional":
            building_method = "bidirectional"
            if len(parts) > 8:
                clusterlevel_str = parts[8]
            if len(parts) > 9:
                seed = parts[9]
        elif parts[7] == "top" and len(parts) > 8 and parts[8] == "down":
            building_method = "top_down"
            if len(parts) > 9:
                clusterlevel_str = parts[9]
            if len(parts) > 10:
                seed = parts[10]
        elif parts[7] == "bottom" and len(parts) > 8 and parts[8] == "up":
            building_method = "bottom_up"
            if len(parts) > 9:
                clusterlevel_str = parts[9]
            if len(parts) > 10:
                seed = parts[10]
        else:
            # Single-token building method.
            building_method = parts[7]
            if len(parts) > 8:
                clusterlevel_str = parts[8]
            if len(parts) > 9:
                seed = parts[9]
    # Format YYYYMMDD dates with slashes for better readability.
    formatted_date = f"{date_str[:4]}/{date_str[4:6]}/{date_str[6:]}" if len(date_str) == 8 else date_str
    # Cluster levels are dash-separated, e.g. "10-5" -> ['10', '5'].
    clusterlevel_array = clusterlevel_str.split('-') if clusterlevel_str else []
    return {
        'date': formatted_date,
        'embedder': embedder,
        'summarizer': summarizer,
        'clustermethod': clustermethod,
        'contribution_type': contribution_type,
        'building_method': building_method or 'Unknown',
        'clusterlevel': clusterlevel_str or 'Unknown',
        'clusterlevel_array': clusterlevel_array,
        'level_count': len(clusterlevel_array),
        'random_seed': seed or 'Unknown'
    }
def format_hierarchy_option(filename):
    """Build the human-readable selectbox label for a hierarchy file."""
    meta = parse_filename(filename)
    levels = "×".join(meta['clusterlevel_array'])
    details = ", ".join([
        f"{meta['embedder']}/{meta['summarizer']}",
        meta['contribution_type'],
        meta['building_method'],
        f"{meta['level_count']} levels: {levels}",
        f"seed: {meta['random_seed']}",
    ])
    return f"{meta['date']} - {meta['clustermethod']} ({details})"
@st.cache_data
def load_hierarchy_data(filename):
    """Load hierarchy data with support for compressed files.

    Looks for ``hierarchies/<filename>`` first and falls back to a
    gzip-compressed sibling ``<filename>.gz``.

    Returns:
        The parsed JSON document, or ``{"clusters": []}`` (after showing a
        Streamlit error) when the file is missing or unreadable.
    """
    filepath = os.path.join('hierarchies', filename)
    # Prefer the uncompressed file when present. Read as UTF-8 explicitly:
    # JSON files are UTF-8, and the platform default encoding may differ.
    if os.path.exists(filepath):
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                return json.load(f)
        except (OSError, json.JSONDecodeError) as e:
            # Mirror the error handling of the gzip branch instead of crashing.
            st.error(f"Error loading file {filepath}: {str(e)}")
            return {"clusters": []}
    # Fall back to the gzip-compressed variant.
    gzip_filepath = filepath + '.gz'
    if os.path.exists(gzip_filepath):
        try:
            with gzip.open(gzip_filepath, 'rt', encoding='utf-8') as f:
                return json.load(f)
        except Exception as e:
            st.error(f"Error loading compressed file {gzip_filepath}: {str(e)}")
            return {"clusters": []}
    st.error(f"Could not find hierarchy file: {filepath} or {gzip_filepath}")
    return {"clusters": []}
def get_cluster_statistics(clusters):
    """Compute per-level cluster statistics with hover tooltips.

    Args:
        clusters: list of ``(cluster_node, path)`` tuples as produced by
            ``find_clusters_in_path``.

    Returns:
        dict mapping metric label to ``{'value': ..., 'tooltip': ...}``.
        Numeric stats fall back to 0 when the level is empty.
    """
    def _count_papers(node):
        # Leaf clusters hold paper dicts directly; inner clusters hold
        # sub-clusters, distinguished by 'paper_id' on the first child.
        children = node.get("children")
        if not children:
            return 0
        if "paper_id" in children[0]:
            return len(children)
        return sum(_count_papers(child) for child in children)

    cluster_count = len(clusters)
    paper_counts = [_count_papers(cluster) for cluster, _ in clusters]
    # Single construction path (the original duplicated the whole result
    # literal for the empty case, which risked the two copies drifting).
    if paper_counts:
        total_papers = sum(paper_counts)
        avg = round(total_papers / cluster_count, 2)
        med = round(np.median(paper_counts), 2)
        std = round(np.std(paper_counts), 2)
        mx = max(paper_counts)
        mn = min(paper_counts)
    else:
        total_papers = avg = med = std = mx = mn = 0
    return {
        'Total Clusters': {'value': cluster_count, 'tooltip': 'Total number of clusters at this level'},
        'Total Papers': {'value': total_papers, 'tooltip': 'Total number of papers across all clusters at this level'},
        'Average Papers per Cluster': {'value': avg, 'tooltip': 'Average number of papers per cluster'},
        'Median Papers': {'value': med, 'tooltip': 'Median number of papers per cluster'},
        'Standard Deviation': {'value': std, 'tooltip': 'Standard deviation of paper counts across clusters'},
        'Max Papers in Cluster': {'value': mx, 'tooltip': 'Maximum number of papers in any single cluster'},
        'Min Papers in Cluster': {'value': mn, 'tooltip': 'Minimum number of papers in any single cluster'}
    }
def calculate_citation_metrics(node):
    """Calculate total, average, and maximum citation and influential
    citation counts for a cluster.

    Walks the cluster tree rooted at *node*, collecting the Semantic Scholar
    citation counts of every paper leaf, then aggregates them. Papers without
    a ``semantic_scholar`` entry count as zero citations.
    """
    citation_values = []     # per-paper citation counts
    influential_values = []  # per-paper influential-citation counts

    def _collect(cluster):
        children = cluster.get("children")
        if not children:
            return
        first = children[0]
        if isinstance(first, dict) and "paper_id" in first:
            # Leaf cluster: children are paper records.
            for paper in children:
                if not isinstance(paper, dict):
                    continue
                scholar = paper.get('semantic_scholar', {}) or {}
                citation_values.append(scholar.get('citationCount', 0))
                influential_values.append(scholar.get('influentialCitationCount', 0))
        else:
            # Inner cluster: recurse into sub-clusters.
            for child in children:
                if isinstance(child, dict):
                    _collect(child)

    _collect(node)
    paper_count = len(citation_values)
    total_citations = sum(citation_values)
    total_influential = sum(influential_values)
    return {
        'total_citations': total_citations,
        'avg_citations': round(total_citations / paper_count, 2) if paper_count else 0,
        'max_citations': max(citation_values, default=0),
        'total_influential_citations': total_influential,
        'avg_influential_citations': round(total_influential / paper_count, 2) if paper_count else 0,
        'max_influential_citations': max(influential_values, default=0),
        'paper_count': paper_count,
    }
def find_clusters_in_path(data, path):
    """Find clusters or papers at the given path in the hierarchy.

    Descends ``data["clusters"]`` following the cluster ids in *path*.
    Returns a list of ``(node, path_to_node)`` tuples: papers when the
    target level holds paper records, otherwise sub-clusters. Returns []
    when the path cannot be resolved.

    NOTE(review): ids are compared with ``==``; query-param path segments
    are strings, so integer cluster ids would never match — confirm ids
    are stored as strings in the hierarchy JSON.
    """
    if not data or "clusters" not in data:
        return []
    nodes = data["clusters"]
    if not path:
        # Root level: every top-level cluster with an empty path.
        return [(cluster, []) for cluster in nodes]
    last = len(path) - 1
    for depth, segment in enumerate(path):
        match = next(
            (c for c in nodes if c.get("cluster_id") == segment),
            None,
        )
        if match is None:
            # Path segment not found at this level.
            return []
        children = match.get("children")
        if not children:
            # Matched cluster has nothing below it.
            return []
        nodes = children
        if depth == last:
            if isinstance(nodes[0], dict) and "paper_id" in nodes[0]:
                # Target level contains papers.
                return [(paper, path) for paper in nodes]
            # Target level contains sub-clusters; extend the path per child.
            result = []
            for child in nodes:
                if isinstance(child, dict) and child.get("cluster_id") is not None:
                    result.append((child, path + [child["cluster_id"]]))
            return result
    return []
def parse_json_abstract(abstract_text):
    """Parse a JSON-formatted abstract string into a formatted display string.

    Structured abstracts carrying a top-level "Problem" object are expanded
    into a labelled summary; anything else (including non-JSON text or a
    malformed "Problem" value) is returned unchanged.
    """
    try:
        parsed = json.loads(abstract_text)
        if "Problem" not in parsed:
            return abstract_text
        problem = parsed["Problem"]
        domain = problem.get('overarching problem domain', 'N/A')
        challenges = problem.get('challenges/difficulties', 'N/A')
        goal = problem.get('research question/goal', 'N/A')
        return f"""
Problem
Domain:
{domain}
Challenges:
{challenges}
Goal:
{goal}
"""
    except (json.JSONDecodeError, ValueError, TypeError):
        # Not valid JSON (or not the expected shape) — show the raw text.
        return abstract_text
def display_path_details(path, data, level_count):
    """Render an indented detail card for every cluster along *path*.

    Walks the hierarchy from the top, showing each cluster's title, a
    "Level N" jump-back button, and an expander with its abstract.
    NOTE(review): `level_count` is currently unused — confirm against callers.
    """
    if not path:
        return
    st.markdown("### Path Details")
    current = data["clusters"]
    # Dynamically generate level labels and containers.
    for i, cluster_id in enumerate(path):
        # Levels are numbered from 1; the top level is Level 1.
        level_number = i + 1
        indent = i * 32  # Indent 32 pixels per level (used by the HTML template, which appears stripped in this copy)
        for c in current:
            if c["cluster_id"] == cluster_id:
                # Create a container with proper indentation.
                st.markdown(f"""
""", unsafe_allow_html=True)
                # Add extra spacing at the bottom.
                st.markdown("", unsafe_allow_html=True)
                # Create a row with cluster name and level button.
                col1, col2 = st.columns([0.85, 0.15])
                with col1:
                    st.markdown(f"""
Cluster {c["cluster_id"]}: {c["title"]}
""", unsafe_allow_html=True)
                with col2:
                    button_clicked = st.button(f'Level {level_number}', key=f'level_btn_{i}_{c["cluster_id"]}')
                    if button_clicked:
                        # Jump back: truncate the path so this cluster's level becomes the view.
                        st.session_state.path = path[:i]
                        new_params = {}
                        new_params['hierarchy'] = st.query_params['hierarchy']
                        if st.session_state.path:
                            new_params['path'] = st.session_state.path
                        st.query_params.clear()
                        for key, value in new_params.items():
                            if isinstance(value, list):
                                for v in value:
                                    # NOTE(review): each assignment overwrites the previous one,
                                    # so only the last path segment survives — probably should
                                    # assign the whole list at once; confirm intended behavior.
                                    st.query_params[key] = v
                            else:
                                st.query_params[key] = value
                        st.rerun()
                # Calculate left margin for expander content to align with the header.
                # Use an extra container with margin to create the indentation.
                with st.container():
                    st.markdown(f"""
""", unsafe_allow_html=True)
                    # Remove the key parameter that was causing the error.
                    with st.expander("📄 Show Cluster Details", expanded=False):
                        # Parse abstract if it's in JSON format.
                        abstract_content = parse_json_abstract(c["abstract"])
                        st.markdown(f"""
{abstract_content}
""", unsafe_allow_html=True)
                current = c["children"]
                break
def display_paper(item):
    """Display detailed paper information including problem, solution, and
    results with Semantic Scholar info.

    Renders a header (title + citation counts) that is always visible, and a
    single collapsed expander holding abstract, problem/solution/results
    sections, and per-author metrics.
    """
    # Semantic Scholar payload may be absent or explicitly null — fall back to {}.
    semantic_scholar = item.get('semantic_scholar', {}) or {}
    url = semantic_scholar.get('url', '')
    citation_count = semantic_scholar.get('citationCount', 0)
    influential_citation_count = semantic_scholar.get('influentialCitationCount', 0)
    fields_of_study = semantic_scholar.get('fieldsOfStudy', []) or []
    # Generate field badges HTML.
    # NOTE(review): `url` and `field_badges_html` are unused by the template
    # below — the badge/link markup appears stripped in this copy; confirm
    # against the original HTML.
    field_badges_html = ""
    for field in fields_of_study:
        field_badges_html += f"{field} "
    # Basic information section with URL link and citation counts - Always visible.
    st.markdown(f"""
{item.get('title', 'Untitled Paper')}
🔗
⭐ Citations: {citation_count}
🔥 Influential Citations: {influential_citation_count}
""", unsafe_allow_html=True)
    # One main expander for all detailed information - Default collapsed.
    with st.expander("📑 Show Detailed Information", expanded=False):
        # Abstract section
        st.markdown("""
📄 Abstract
""", unsafe_allow_html=True)
        abstract_text = item.get('abstract', 'No abstract available')
        st.markdown(f"\n{abstract_text}\n", unsafe_allow_html=True)
        # Problem section (labels in the narrow column, values in the wide one)
        if 'problem' in item and item['problem']:
            st.markdown("""
🔍 Problem Details
""", unsafe_allow_html=True)
            problem = item['problem']
            cols = st.columns([1, 2])
            with cols[0]:
                st.markdown("""
Problem Domain
""", unsafe_allow_html=True)
                st.markdown("""
Challenges/Difficulties
""", unsafe_allow_html=True)
                st.markdown("""
Research Question/Goal
""", unsafe_allow_html=True)
            with cols[1]:
                st.markdown(f"""
{problem.get('overarching problem domain', 'Not specified')}
""", unsafe_allow_html=True)
                st.markdown(f"""
{problem.get('challenges/difficulties', 'Not specified')}
""", unsafe_allow_html=True)
                st.markdown(f"""
{problem.get('research question/goal', 'Not specified')}
""", unsafe_allow_html=True)
        # Solution section
        if 'solution' in item and item['solution']:
            st.markdown("""
💡 Solution Details
""", unsafe_allow_html=True)
            solution = item['solution']
            cols = st.columns([1, 2])
            with cols[0]:
                st.markdown("""
Solution Domain
""", unsafe_allow_html=True)
                st.markdown("""
Solution Approach
""", unsafe_allow_html=True)
                st.markdown("""
Novelty of Solution
""", unsafe_allow_html=True)
            with cols[1]:
                st.markdown(f"""
{solution.get('overarching solution domain', 'Not specified')}
""", unsafe_allow_html=True)
                st.markdown(f"""
{solution.get('solution approach', 'Not specified')}
""", unsafe_allow_html=True)
                st.markdown(f"""
{solution.get('novelty of the solution', 'Not specified')}
""", unsafe_allow_html=True)
        # Results section
        if 'results' in item and item['results']:
            st.markdown("""
📊 Results Details
""", unsafe_allow_html=True)
            results = item['results']
            cols = st.columns([1, 2])
            with cols[0]:
                st.markdown("""
Findings/Results
""", unsafe_allow_html=True)
                st.markdown("""
Potential Impact
""", unsafe_allow_html=True)
            with cols[1]:
                st.markdown(f"""
{results.get('findings/results', 'Not specified')}
""", unsafe_allow_html=True)
                st.markdown(f"""
{results.get('potential impact of the results', 'Not specified')}
""", unsafe_allow_html=True)
        # Author information (only when the nested authors list is present and non-empty)
        if 'semantic_scholar' in item and item['semantic_scholar'] and 'authors' in item['semantic_scholar'] and item['semantic_scholar']['authors']:
            st.markdown("""
👥 Authors
""", unsafe_allow_html=True)
            authors = item['semantic_scholar']['authors'] or []
            for author in authors:
                if not isinstance(author, dict):
                    continue
                st.markdown(f"""
{author.get('name', 'Unknown')}
Author ID: {author.get('authorId', 'N/A')}
Papers
{author.get('paperCount', 0)}
Citations
{author.get('citationCount', 0)}
h-index
{author.get('hIndex', 0)}
""", unsafe_allow_html=True)
    # Close paper-card div
    st.markdown("\n", unsafe_allow_html=True)
def display_cluster(item, path):
    """Display a collapsible cluster with citation metrics integrated into the
    header, including abstract expander and buttons."""
    # Unique ID combining cluster id and path, so Streamlit widget keys
    # don't collide when the same cluster id appears at different depths.
    cluster_id = item['cluster_id']
    unique_id = f"{cluster_id}_{'-'.join(map(str, path))}"
    # Calculate citation metrics using the updated function.
    citation_metrics = calculate_citation_metrics(item)
    # Parse the abstract (may be a structured JSON abstract).
    abstract_content = parse_json_abstract(item['abstract'])
    # Set the button text and behavior depending on whether the children
    # are papers (leaf level) or further sub-clusters.
    has_children = "children" in item and item["children"]
    if has_children:
        count = citation_metrics['paper_count'] if "paper_id" in item["children"][0] else len(item["children"])
        next_level_items = item["children"]
        is_next_level_papers = len(next_level_items) > 0 and "paper_id" in next_level_items[0]
        btn_text = f'View Papers ({count})' if is_next_level_papers else f'View Sub-clusters ({count})'
    # Title and paper count — kept on the same horizontal line.
    st.markdown(f"""
{item['title']}
📑{citation_metrics['paper_count']} papers
""", unsafe_allow_html=True)
    # Two-column layout: stats on the left, navigation button on the right.
    cols = st.columns([8, 2])
    with cols[0]:
        # Citation stats formatted with pipe separators.
        st.markdown(f"""
⭐ Citations:
Total {citation_metrics['total_citations']} |
Avg {citation_metrics['avg_citations']} |
Max {citation_metrics['max_citations']}
🔥 Influential Citations:
Total {citation_metrics['total_influential_citations']} |
Avg {citation_metrics['avg_influential_citations']} |
Max {citation_metrics['max_influential_citations']}
""", unsafe_allow_html=True)
        # Abstract expander, labelled "Cluster Summary".
        with st.expander("📄 Cluster Summary", expanded=False):
            st.markdown(f"""
{abstract_content}
""", unsafe_allow_html=True)
    with cols[1]:
        # If there are sub-clusters or papers, add the drill-down button.
        if has_children:
            # Button text is generated dynamically above (papers vs sub-clusters).
            if st.button(btn_text, key=f"btn_{unique_id}"):
                st.session_state.path.append(item['cluster_id'])
                st.rerun()
    # Divider between cluster cards.
    st.markdown("\n", unsafe_allow_html=True)
def main():
    """Streamlit entry point: hierarchy selector, metadata, statistics, and
    cluster/paper browsing driven by ``st.session_state.path``."""
    st.set_page_config(
        layout="wide",
        page_title="Paper Clusters Explorer",
        initial_sidebar_state="expanded",
        menu_items=None
    )
    # Light-theme styling (the style block appears stripped in this copy).
    st.markdown("""
""", unsafe_allow_html=True)
    st.markdown(CUSTOM_CSS, unsafe_allow_html=True)
    hierarchy_files = get_hierarchy_files()
    if not hierarchy_files:
        st.error("No hierarchy files found in /hierarchies directory")
        return
    # Manage file selection via query params (the param stores the name
    # without the '.json' suffix, URL-quoted).
    current_url = st.query_params.get('hierarchy', None)
    current_file = unquote(current_url) + '.json' if current_url else None
    hierarchy_options = {format_hierarchy_option(f): f for f in hierarchy_files}
    selected_option = st.selectbox(
        'Select Hierarchy',
        options=list(hierarchy_options.keys()),
        # NOTE(review): raises ValueError if the query param names a file that
        # no longer exists in hierarchies/ — consider guarding; confirm.
        index=list(hierarchy_options.values()).index(current_file) if current_file else 0
    )
    selected_file = hierarchy_options[selected_option]
    # Persist the selected file in the query params.
    if selected_file != current_file:
        st.query_params['hierarchy'] = quote(selected_file.replace('.json', ''))
    data = load_hierarchy_data(selected_file)
    info = parse_filename(selected_file)
    # Hierarchy metadata and navigation state.
    with st.expander("📋 Hierarchy Metadata", expanded=False):
        # Create a grid layout for metadata.
        col1, col2, col3 = st.columns(3)
        with col1:
            st.markdown(f"""
Clustering Method
{info['clustermethod']}
""", unsafe_allow_html=True)
        with col2:
            st.markdown(f"""
Embedder / Summarizer
{info['embedder']} / {info['summarizer']}
Contribution Type
{info['contribution_type']}
""", unsafe_allow_html=True)
        with col3:
            st.markdown(f"""
Building Method
{info['building_method']}
Cluster Levels
{info['clusterlevel']} (Total: {info['level_count']})
""", unsafe_allow_html=True)
    # Initialize the navigation path from query params on first run.
    if 'path' not in st.session_state:
        path_params = st.query_params.get_all('path')
        st.session_state.path = [p for p in path_params if p]
    current_clusters = find_clusters_in_path(data, st.session_state.path)
    current_level = len(st.session_state.path)
    total_levels = info['level_count']
    level_name = f'Level {current_level + 1}' if current_level < total_levels else 'Papers'
    # Paper level when we've descended past all cluster levels, or when the
    # current items already carry 'paper_id'.
    is_paper_level = current_level >= total_levels or (current_clusters and "paper_id" in current_clusters[0][0])
    if not is_paper_level and current_clusters:
        with st.expander("📊 Cluster Statistics", expanded=False):
            stats = get_cluster_statistics(current_clusters)
            # Create a 3x2 grid for six small metric cards.
            row1_col1, row1_col2, row1_col3 = st.columns(3)
            row2_col1, row2_col2, row2_col3 = st.columns(3)
            # Row 1 - First 3 metrics
            with row1_col1:
                st.markdown(f"""
Total Clusters
{stats['Total Clusters']['value']}
""", unsafe_allow_html=True)
            with row1_col2:
                st.markdown(f"""
Total Papers
{stats['Total Papers']['value']}
""", unsafe_allow_html=True)
            with row1_col3:
                st.markdown(f"""
Avg Papers/Cluster
{stats['Average Papers per Cluster']['value']}
""", unsafe_allow_html=True)
            # Row 2 - Next 3 metrics
            with row2_col1:
                st.markdown(f"""
Median Papers
{stats['Median Papers']['value']}
""", unsafe_allow_html=True)
            with row2_col2:
                st.markdown(f"""
Max Papers in Cluster
{stats['Max Papers in Cluster']['value']}
""", unsafe_allow_html=True)
            with row2_col3:
                st.markdown(f"""
Min Papers in Cluster
{stats['Min Papers in Cluster']['value']}
""", unsafe_allow_html=True)
    # Back navigation button: pop one level off the path.
    if st.session_state.path:
        if st.button('← Back', key='back_button'):
            st.session_state.path.pop()
            st.rerun()
    # Current path display (breadcrumb with jump-back buttons).
    if st.session_state.path:
        # Collect the title of each cluster along the path.
        path_info = []
        current = data["clusters"]
        # Build (level number, title, id) for each cluster in the path.
        for i, cid in enumerate(st.session_state.path):
            level_num = i + 1  # Levels are numbered from 1
            for c in current:
                if c["cluster_id"] == cid:
                    path_info.append((level_num, c["title"], c["cluster_id"]))
                    current = c["children"]
                    break
        # Render the breadcrumb navigation.
        with st.container():
            st.markdown("🗂️ Current Path\n", unsafe_allow_html=True)
            # Root entry: resets the path entirely.
            col1, col2 = st.columns([0.3, 0.7])
            with col1:
                st.markdown(f"Root:\n", unsafe_allow_html=True)
            with col2:
                if st.button("All Papers", key="root_button"):
                    st.session_state.path = []
                    st.rerun()
            # One indented row per path level.
            for i, (level_num, title, cluster_id) in enumerate(path_info):
                col1, col2 = st.columns([0.3, 0.7])
                with col1:
                    st.markdown(f"Level {level_num}:\n", unsafe_allow_html=True)
                with col2:
                    # Button that jumps back to this level.
                    if st.button(f"{title}", key=f"lvl_{i}_{cluster_id}"):
                        # Truncate the path down to this level.
                        st.session_state.path = st.session_state.path[:i+1]
                        st.rerun()
    # Section heading for the content below.
    st.markdown(f"""
{'📑 Papers' if is_paper_level else '📂 ' + level_name}
""", unsafe_allow_html=True)
    for item, full_path in current_clusters:
        if is_paper_level:
            display_paper(item)
        else:
            display_cluster(item, full_path)
# Script entry point: run the Streamlit app.
if __name__ == '__main__':
    main()