# Key Issue Generator — Streamlit application.
# (Removed non-code residue from a web scrape of the hosting page:
# Space status lines, file size, commit hashes, and a line-number gutter.)
import os
import streamlit as st
from langchain_community.graphs import Neo4jGraph
import pandas as pd
import json
import time
from ki_gen.planner import build_planner_graph
# Update import path if init_app moved or args changed
from ki_gen.utils import init_app, memory, ConfigSchema, State # Import necessary types
from ki_gen.prompts import get_initial_prompt
from neo4j import GraphDatabase
# Set page config — must be the first Streamlit call in the script.
st.set_page_config(page_title="Key Issue Generator", layout="wide")

# Neo4j Database Configuration
NEO4J_URI = "neo4j+s://4985272f.databases.neo4j.io"
NEO4J_USERNAME = "neo4j"
# NOTE: os.getenv returns None when the variable is unset; missing
# credentials are surfaced to the user by the button handler below.
NEO4J_PASSWORD = os.getenv("neo4j_password")

# API Keys for LLM services
OPENAI_API_KEY = os.getenv("openai_api_key")
# GROQ_API_KEY is removed as we switch to Gemini
# GROQ_API_KEY = os.getenv("groq_api_key")
# Ensure Gemini API key is available in the environment
GEMINI_API_KEY = os.getenv("gemini_api_key")
LANGSMITH_API_KEY = os.getenv("langsmith_api_key")
def verify_neo4j_connectivity():
    """Verify connection to the Neo4j database.

    Returns:
        True when the database is reachable, otherwise a string of the
        form ``"Error: <details>"``. Callers distinguish the cases with
        ``is True`` (see the sidebar status block and load_config).
    """
    driver = None
    try:
        driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
        driver.verify_connectivity()
        return True
    except Exception as e:
        return f"Error: {str(e)}"
    finally:
        # Close the driver on *every* path — the original only closed it on
        # success, leaking the connection pool when verify_connectivity()
        # raised (e.g. bad credentials or unreachable host).
        if driver is not None:
            driver.close()
def load_config() -> ConfigSchema:
    """Build the run configuration for the planner graph.

    Returns:
        A dict of the form ``{"configurable": {...}}`` as expected by
        LangGraph, with a live ``Neo4jGraph`` under the ``"graph"`` key;
        or ``None`` when the database is unreachable or the graph object
        cannot be created (the graph is required downstream).
    """
    # Custom configuration; default models are gemini-2.0-flash.
    custom_config = {
        "main_llm": "gemini-2.0-flash",
        "plan_method": "generation",
        "use_detailed_query": False,
        "cypher_gen_method": "guided",
        "validate_cypher": False,
        "summarize_model": "gemini-2.0-flash",
        "eval_method": "binary",
        "eval_threshold": 0.7,
        "max_docs": 15,
        "compression_method": "llm_lingua",
        "compress_rate": 0.33,
        "force_tokens": ["."],  # List format as expected by the application
        "eval_model": "gemini-2.0-flash",
        "thread_id": "3"  # Consider making thread_id dynamic or user-specific
    }

    try:
        # Check connectivity exactly once and reuse the result — the original
        # called verify_neo4j_connectivity() a second time just to format the
        # error message, opening (and closing) an extra driver.
        status = verify_neo4j_connectivity()
        if status is True:
            custom_config["graph"] = Neo4jGraph(
                url=NEO4J_URI,
                username=NEO4J_USERNAME,
                password=NEO4J_PASSWORD
            )
        else:
            st.error(f"Neo4j connection issue: {status}")
            # The graph is essential; signal failure to the caller.
            return None
    except Exception as e:
        st.error(f"Error creating Neo4jGraph object: {e}")
        return None

    # Wrap in 'configurable' as expected by LangGraph.
    return {"configurable": custom_config}
def _stream_graph_events(graph, stream_input, config, messages_content, stage_label):
    """Stream events from the planner graph, collecting message contents.

    Args:
        graph: Compiled planner graph (from build_planner_graph).
        stream_input: Input for ``graph.stream``; ``None`` resumes the
            graph from its current interrupt point.
        config: Full LangGraph config dict (``{"configurable": {...}}``).
        messages_content: List mutated in place — the content of the last
            message of each event is appended.
        stage_label: Human-readable stage name used in the error message.

    Returns:
        True when streaming completed without raising; False after
        reporting the failure to the UI via ``st.error``.
    """
    try:
        for event in graph.stream(stream_input, config, stream_mode="values"):
            if "messages" in event and event["messages"]:
                event["messages"][-1].pretty_print()
                messages_content.append(event["messages"][-1].content)
        return True
    except Exception as e:
        st.error(f"Error during {stage_label}: {e}")
        return False


def generate_key_issues(user_query):
    """Generate key issues for *user_query* from the Neo4j knowledge graph.

    Drives the planner graph through its three phases — plan generation,
    document retrieval, and document processing — resuming across the
    graph's interrupt points, and reports progress in the Streamlit UI.

    Returns:
        Tuple ``(final_result, valid_docs)``: the final answer text and the
        list of processed documents, or ``(None, [])`` on any failure
        (errors are surfaced via ``st.error``).
    """
    # Initialize application with API keys (Groq removed; the Gemini key is
    # read from the environment by the model wrappers).
    init_app(
        openai_key=OPENAI_API_KEY,
        langsmith_key=LANGSMITH_API_KEY
    )

    # Load configuration with custom parameters; abort early if the Neo4j
    # graph object is missing, since every later phase needs it.
    config = load_config()
    if not config or "configurable" not in config or not config["configurable"].get("graph"):
        st.error("Failed to load configuration or connect to Neo4j. Cannot proceed.")
        return None, []

    # Status containers for progressive UI updates.
    plan_status = st.empty()
    plan_display = st.empty()
    retrieval_status = st.empty()
    processing_status = st.empty()

    # Build planner graph with the shared checkpointer.
    plan_status.info("Building planner graph...")
    graph = build_planner_graph(memory, config)

    # --- Phase 1: plan generation (runs until the graph's first interrupt) ---
    plan_status.info(f"Generating plan for query: {user_query}")
    messages_content = []
    initial_prompt_data = get_initial_prompt(config, user_query)
    if not _stream_graph_events(graph, initial_prompt_data, config,
                                messages_content, "initial graph stream"):
        return None, []

    # Display the generated plan from the checkpointed state.
    try:
        state = graph.get_state(config)
        stored_plan = state.values.get('store_plan', [])
        if isinstance(stored_plan, list) and stored_plan:
            steps = list(range(1, len(stored_plan) + 1))
            plan_df = pd.DataFrame({'Plan steps': steps, 'Description': stored_plan})
            plan_status.success("Plan generation complete!")
            plan_display.dataframe(plan_df, use_container_width=True)
        else:
            plan_status.warning("Plan not found or empty in graph state after generation.")
            plan_display.empty()  # Clear display if no plan
    except Exception as e:
        st.error(f"Error getting graph state or displaying plan: {e}")
        return None, []

    # --- Phase 2: resume the graph to retrieve documents ---
    retrieval_status.info("Retrieving documents...")
    if not _stream_graph_events(graph, None, config,
                                messages_content, "document retrieval stream"):
        return None, []

    # Report retrieval results and configure the processing phase.
    try:
        snapshot = graph.get_state(config)
        valid_docs_retrieved = snapshot.values.get('valid_docs', [])
        doc_count = len(valid_docs_retrieved) if isinstance(valid_docs_retrieved, list) else 0
        retrieval_status.success(f"Retrieved {doc_count} documents")

        # No interactive validation UI yet: mark validation as done and use a
        # default processing pipeline. Must happen *before* the next stream
        # call that triggers processing.
        processing_status.info("Processing documents...")
        process_steps = ["summarize"]  # Default: just summarize
        graph.update_state(config, {'human_validated': True, 'process_steps': process_steps})
    except Exception as e:
        st.error(f"Error getting state after retrieval or setting up processing: {e}")
        return None, []

    # --- Phase 3: resume the graph to process the documents ---
    if not _stream_graph_events(graph, None, config,
                                messages_content, "document processing stream"):
        return None, []

    # Extract the final answer and the (possibly summarized) documents.
    try:
        final_snapshot = graph.get_state(config)
        processing_status.success("Document processing complete!")

        final_result = None
        if "messages" in final_snapshot.values and final_snapshot.values["messages"]:
            # The last message is assumed to contain the final result.
            final_result = final_snapshot.values["messages"][-1].content

        valid_docs_final = final_snapshot.values.get('valid_docs', [])
        if not isinstance(valid_docs_final, list):  # Ensure it's a list
            valid_docs_final = []

        return final_result, valid_docs_final
    except Exception as e:
        st.error(f"Error getting final state or extracting results: {e}")
        return None, []
# App header
st.title("Key Issue Generator")
st.write("Generate key issues from a Neo4j knowledge graph using advanced language models.")

# Check database connectivity once at page load; the result is either
# True or an "Error: ..." string (see verify_neo4j_connectivity).
connectivity_status = verify_neo4j_connectivity()
st.sidebar.header("Database Status")
# Use boolean identity check — any non-True value is an error string.
if connectivity_status is True:
    st.sidebar.success("Connected to Neo4j database")
else:
    # Display the error message returned
    st.sidebar.error(f"Database connection issue: {connectivity_status}")

# User input section
st.header("Enter Your Query")
user_query = st.text_area("What would you like to explore?",
                          "What are the main challenges in AI adoption for healthcare systems?",
                          height=100)

# Process button
if st.button("Generate Key Issues", type="primary"):
    # Update API key check for Gemini (Groq no longer required).
    if not OPENAI_API_KEY or not GEMINI_API_KEY or not LANGSMITH_API_KEY or not NEO4J_PASSWORD:
        st.error("Required API keys (OpenAI, Gemini, Langsmith) or database credentials are missing. Please check your environment variables.")
    elif connectivity_status is not True:  # Check DB connection again before starting
        st.error(f"Cannot start: Neo4j connection issue: {connectivity_status}")
    else:
        with st.spinner("Processing your query..."):
            start_time = time.time()
            # Call the main generation function
            final_result, valid_docs = generate_key_issues(user_query)
            end_time = time.time()
            if final_result is not None:  # Non-None result indicates success
                # Display execution time
                st.sidebar.info(f"Total execution time: {round(end_time - start_time, 2)} seconds")
                # Display final result
                st.header("Generated Key Issues")
                st.markdown(final_result)
                # Option to download results
                st.download_button(
                    label="Download Results",
                    data=final_result,  # Ensure final_result is string data
                    file_name="key_issues_results.txt",
                    mime="text/plain"
                )
                # Display retrieved/processed documents in expandable section
                if valid_docs:
                    with st.expander("View Processed Documents"):
                        for i, doc in enumerate(valid_docs):
                            st.markdown(f"### Document {i+1}")
                            # Handle doc format (could be string summary or original dict)
                            if isinstance(doc, dict):
                                for key in doc:
                                    st.markdown(f"**{key}**: {doc[key]}")
                            elif isinstance(doc, str):
                                st.markdown(doc)  # Display string directly if it's a summary
                            else:
                                st.markdown(str(doc))  # Fallback for other types
                            st.divider()
            else:
                # Specific error messages are shown inside generate_key_issues;
                # this is a fallback in case a failure path surfaced none.
                if final_result is None:  # Check explicit None return
                    st.error("Processing failed. Please check the console/logs for errors.")

# Help information in sidebar
with st.sidebar:
    st.header("About")
    st.info("""
This application uses advanced language models (like Google Gemini) to analyze a Neo4j knowledge graph
and generate key issues based on your query. The process involves:
1. Creating a plan based on your query
2. Retrieving relevant documents from the database
3. Processing and summarizing the information
4. Generating a comprehensive response
""")