import streamlit as st
import tempfile
import os
import json
from typing import List, Dict, Any, Optional, Tuple
import traceback

# Import our modules
from src.document_processor import DocumentProcessor
from src.llm_extractor import LLMExtractor
from src.graph_builder import GraphBuilder
from src.visualizer import GraphVisualizer
from config.settings import Config
# Page config
st.set_page_config(
    page_title="Knowledge Graph Extraction",
    page_icon="🕸️",
    layout="wide"
)
# Initialize components
def initialize_components():
    config = Config()
    doc_processor = DocumentProcessor()
    llm_extractor = LLMExtractor()
    graph_builder = GraphBuilder()
    visualizer = GraphVisualizer()
    return config, doc_processor, llm_extractor, graph_builder, visualizer

config, doc_processor, llm_extractor, graph_builder, visualizer = initialize_components()
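# NOTE: These components are created at module level, so Streamlit re-creates them on
# every script rerun. If the underlying classes are safe to reuse, wrapping
# initialize_components() in @st.cache_resource would keep a single instance alive
# across reruns (assumption: none of the components hold per-run state).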
def process_uploaded_files(uploaded_files, api_key, batch_mode, visualization_type, layout_type,
                           show_labels, show_edge_labels, min_importance, entity_types_filter):
    """Process uploaded files and extract knowledge graph."""
    try:
        # Update API key
        if api_key.strip():
            config.OPENROUTER_API_KEY = api_key.strip()
            llm_extractor.config.OPENROUTER_API_KEY = api_key.strip()
            llm_extractor.headers["Authorization"] = f"Bearer {api_key.strip()}"

        if not config.OPENROUTER_API_KEY:
            st.error("❌ OpenRouter API key is required")
            return None

        if not uploaded_files:
            st.error("❌ Please upload at least one file")
            return None

        progress_bar = st.progress(0)
        status_text = st.empty()

        status_text.text("Loading documents...")
        progress_bar.progress(0.1)
        # Save uploaded files to temporary location
        file_paths = []
        for uploaded_file in uploaded_files:
            # Create temporary file
            with tempfile.NamedTemporaryFile(delete=False, suffix=f"_{uploaded_file.name}") as tmp_file:
                tmp_file.write(uploaded_file.getvalue())
                file_paths.append(tmp_file.name)
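        # Streamlit uploads live in memory, while process_documents() below takes file
        # paths, so each upload is first written to a NamedTemporaryFile and the
        # temporary copies are removed again once processing is done.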
        # Process documents
        doc_results = doc_processor.process_documents(file_paths, batch_mode)

        # Clean up temporary files
        for file_path in file_paths:
            try:
                os.unlink(file_path)
            except OSError:
                pass
        # Check for errors
        failed_files = [r for r in doc_results if r['status'] == 'error']
        if failed_files:
            error_msg = "Failed to process files:\n" + "\n".join(
                [f"- {r['file_path']}: {r['error']}" for r in failed_files]
            )
            if len(failed_files) == len(doc_results):
                st.error(f"❌ {error_msg}")
                return None

        status_text.text("Extracting entities and relationships...")
        progress_bar.progress(0.3)

        # Extract entities and relationships
        all_entities = []
        all_relationships = []
        extraction_errors = []

        for doc_result in doc_results:
            if doc_result['status'] == 'success':
                extraction_result = llm_extractor.process_chunks(doc_result['chunks'])
                if extraction_result.get('errors'):
                    extraction_errors.extend(extraction_result['errors'])
                all_entities.extend(extraction_result.get('entities', []))
                all_relationships.extend(extraction_result.get('relationships', []))

        if not all_entities:
            error_msg = "No entities extracted from documents"
            if extraction_errors:
                error_msg += f"\nExtraction errors: {'; '.join(extraction_errors[:3])}"
            st.error(f"❌ {error_msg}")
            return None
status_text.text("Building knowledge graph...") | |
progress_bar.progress(0.6) | |
# Build graph | |
graph = graph_builder.build_graph(all_entities, all_relationships) | |
if not graph.nodes(): | |
st.error("β No valid knowledge graph could be built") | |
return None | |
status_text.text("Applying filters...") | |
progress_bar.progress(0.7) | |
# Apply filters | |
filtered_graph = graph | |
if entity_types_filter: | |
filtered_graph = graph_builder.filter_graph( | |
entity_types=entity_types_filter, | |
min_importance=min_importance | |
) | |
elif min_importance > 0: | |
filtered_graph = graph_builder.filter_graph(min_importance=min_importance) | |
if not filtered_graph.nodes(): | |
st.error("β No entities remain after applying filters") | |
return None | |
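        # filter_graph() is called without passing a graph, so GraphBuilder is presumably
        # filtering the graph it built and stored internally above (the same pattern is
        # used for get_graph_statistics() and get_central_nodes() further down).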
status_text.text("Generating visualizations...") | |
progress_bar.progress(0.8) | |
# Generate graph visualization based on type | |
if visualization_type == "plotly": | |
graph_viz = visualizer.create_plotly_interactive(filtered_graph, layout_type) | |
graph_image_path = None | |
elif visualization_type == "pyvis": | |
graph_image_path = visualizer.create_pyvis_interactive(filtered_graph, layout_type) | |
graph_viz = None | |
elif visualization_type == "vis.js": | |
graph_viz = visualizer.create_interactive_html(filtered_graph) | |
graph_image_path = None | |
else: # matplotlib | |
graph_image_path = visualizer.visualize_graph( | |
filtered_graph, | |
layout_type=layout_type, | |
show_labels=show_labels, | |
show_edge_labels=show_edge_labels | |
) | |
graph_viz = None | |
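        # Convention in the branches above: plotly and vis.js produce an in-memory
        # figure/HTML string (graph_viz), while pyvis and matplotlib write a file and
        # return its path (graph_image_path). The result dict below carries both fields
        # so the UI can render whichever one is set.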
        # Get statistics
        stats = graph_builder.get_graph_statistics()
        stats_summary = visualizer.create_statistics_summary(filtered_graph, stats)

        # Get entity list
        entity_list = visualizer.create_entity_list(filtered_graph)

        # Get central nodes
        central_nodes = graph_builder.get_central_nodes()
        central_nodes_text = "## Most Central Entities\n\n"
        for i, (node, score) in enumerate(central_nodes, 1):
            central_nodes_text += f"{i}. **{node}** (centrality: {score:.3f})\n"

        status_text.text("Complete!")
        progress_bar.progress(1.0)

        # Success message
        success_msg = f"✅ Successfully processed {len([r for r in doc_results if r['status'] == 'success'])} document(s)"
        if failed_files:
            success_msg += f"\n⚠️ {len(failed_files)} file(s) failed to process"
        if extraction_errors:
            success_msg += f"\n⚠️ {len(extraction_errors)} extraction error(s) occurred"

        return {
            'success_msg': success_msg,
            'graph_image_path': graph_image_path,
            'graph_viz': graph_viz,
            'visualization_type': visualization_type,
            'stats_summary': stats_summary,
            'entity_list': entity_list,
            'central_nodes_text': central_nodes_text,
            'graph': filtered_graph
        }

    except Exception as e:
        st.error(f"❌ Error: {str(e)}")
        st.error(f"Full traceback:\n{traceback.format_exc()}")
        return None
# Main app
def main():
    st.title("🕸️ Knowledge Graph Extraction")
    st.markdown("""
Upload documents and extract knowledge graphs using LLMs via OpenRouter.
Supports PDF, TXT, DOCX, and JSON files.
""")

    # Sidebar for configuration
    with st.sidebar:
        st.header("📄 Document Upload")
        uploaded_files = st.file_uploader(
            "Choose files",
            type=['pdf', 'txt', 'docx', 'json'],
            accept_multiple_files=True
        )

        batch_mode = st.checkbox(
            "Batch Processing Mode",
            value=False,
            help="Process multiple files together"
        )

        st.header("🔑 API Configuration")
        api_key = st.text_input(
            "OpenRouter API Key",
            type="password",
            placeholder="Enter your OpenRouter API key",
            help="Get your key at openrouter.ai"
        )
st.header("ποΈ Visualization Settings") | |
visualization_type = st.selectbox( | |
"Visualization Type", | |
options=visualizer.get_visualization_options(), | |
index=1, # Default to plotly for interactivity | |
help="Choose visualization method" | |
) | |
layout_type = st.selectbox( | |
"Layout Algorithm", | |
options=visualizer.get_layout_options(), | |
index=0 | |
) | |
show_labels = st.checkbox("Show Node Labels", value=True) | |
show_edge_labels = st.checkbox("Show Edge Labels", value=False) | |
st.header("π Filtering Options") | |
min_importance = st.slider( | |
"Minimum Entity Importance", | |
min_value=0.0, | |
max_value=1.0, | |
value=0.3, | |
step=0.1 | |
) | |
entity_types_filter = st.multiselect( | |
"Entity Types Filter", | |
options=[], | |
help="Filter will be populated after processing" | |
) | |
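        # As written, the multiselect has no options, so entity_types_filter is always
        # empty on the first pass; the entity types found during processing would need to
        # be stored (e.g. in st.session_state) and fed back into `options` for this
        # filter to take effect on a subsequent run.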
        process_button = st.button("🚀 Extract Knowledge Graph", type="primary")

    # Main content area
    if process_button and uploaded_files:
        with st.spinner("Processing..."):
            result = process_uploaded_files(
                uploaded_files, api_key, batch_mode, visualization_type, layout_type,
                show_labels, show_edge_labels, min_importance, entity_types_filter
            )

        if result:
            # Store results in session state
            st.session_state['result'] = result
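            # The result is only written to session state here and never read back in this
            # file, so the tabs and export widgets below are lost on the next rerun (for
            # example, after clicking the export button); reading st.session_state['result']
            # on reruns would be needed to make the results persist.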
            # Display success message
            st.success(result['success_msg'])

            # Create tabs for results
            tab1, tab2, tab3, tab4 = st.tabs(["📊 Graph Visualization", "📈 Statistics", "📋 Entities", "🎯 Central Nodes"])

            with tab1:
                viz_type = result['visualization_type']
                if viz_type == "plotly" and result['graph_viz']:
                    st.plotly_chart(result['graph_viz'], use_container_width=True)
                    st.info("🎯 Interactive Plotly graph: Hover over nodes for details, drag to pan, scroll to zoom")
                elif viz_type == "pyvis" and result['graph_image_path'] and os.path.exists(result['graph_image_path']):
                    # Read HTML file and display
                    with open(result['graph_image_path'], 'r', encoding='utf-8') as f:
                        html_content = f.read()
                    st.components.v1.html(html_content, height=600, scrolling=True)
                    st.info("🎯 Interactive Pyvis graph: Drag nodes to rearrange, hover for details")
                elif viz_type == "vis.js" and result['graph_viz']:
                    st.components.v1.html(result['graph_viz'], height=600, scrolling=True)
                    st.info("🎯 Interactive vis.js graph: Drag nodes, hover for details, use physics simulation")
                elif viz_type == "matplotlib" and result['graph_image_path'] and os.path.exists(result['graph_image_path']):
                    st.image(result['graph_image_path'], caption="Knowledge Graph", use_column_width=True)
                    st.info("📊 Static matplotlib visualization")
                else:
                    st.error("Failed to generate graph visualization")

            with tab2:
                st.markdown(result['stats_summary'])

            with tab3:
                st.markdown(result['entity_list'])

            with tab4:
                st.markdown(result['central_nodes_text'])
            # Export options
            st.header("💾 Export Options")
            col1, col2 = st.columns(2)

            with col1:
                export_format = st.selectbox(
                    "Export Format",
                    options=["json", "graphml", "gexf"],
                    index=0
                )

            with col2:
                if st.button("📥 Export Graph"):
                    try:
                        export_data = graph_builder.export_graph(export_format)
                        st.text_area("Export Data", value=export_data, height=300)

                        # Download button
                        st.download_button(
                            label=f"Download {export_format.upper()} file",
                            data=export_data,
                            file_name=f"knowledge_graph.{export_format}",
                            mime="application/octet-stream"
                        )
                    except Exception as e:
                        st.error(f"Export failed: {str(e)}")

    elif process_button and not uploaded_files:
        st.warning("Please upload at least one file before processing.")
    # Instructions
    st.header("📖 Instructions")

    with st.expander("How to use this app"):
        st.markdown("""
1. **Upload Documents**: Select one or more files (PDF, TXT, DOCX, JSON) using the file uploader in the sidebar
2. **Enter API Key**: Get a free API key from [OpenRouter](https://openrouter.ai) and enter it in the sidebar
3. **Configure Settings**: Adjust visualization and filtering options in the sidebar
4. **Extract Graph**: Click the "Extract Knowledge Graph" button and wait for processing
5. **Explore Results**: View the graph, statistics, and entity details in the tabs
6. **Export**: Download the graph data in various formats
""")

    with st.expander("Features"):
        st.markdown("""
- **Multi-format Support**: PDF, TXT, DOCX, JSON files
- **Batch Processing**: Process multiple documents together
- **Smart Extraction**: Uses an LLM to identify important entities and relationships
- **Interactive Filtering**: Filter by entity type and importance
- **Multiple Layouts**: Various graph layout algorithms
- **Export Options**: JSON, GraphML, GEXF formats
- **Free Models**: Uses cost-effective OpenRouter models
""")

    with st.expander("Notes"):
        st.markdown("""
- File size limit: 10 MB per file
- Free OpenRouter models are used to minimize costs
- Processing time depends on document size and complexity
""")
if __name__ == "__main__":
    main()