Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,212 +1,310 @@
|
|
1 |
-
import os
|
2 |
-
import streamlit as st
|
3 |
-
from langchain_community.graphs import Neo4jGraph
|
4 |
-
import pandas as pd
|
5 |
-
import json
|
6 |
-
import time
|
7 |
-
|
8 |
-
from ki_gen.planner import build_planner_graph
|
9 |
-
|
10 |
-
from ki_gen.
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
"
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
"
|
49 |
-
"
|
50 |
-
"
|
51 |
-
"
|
52 |
-
"
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
#
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
#
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
212 |
""")
|
|
|
1 |
+
import os
|
2 |
+
import streamlit as st
|
3 |
+
from langchain_community.graphs import Neo4jGraph
|
4 |
+
import pandas as pd
|
5 |
+
import json
|
6 |
+
import time
|
7 |
+
|
8 |
+
from ki_gen.planner import build_planner_graph
|
9 |
+
# Update import path if init_app moved or args changed
|
10 |
+
from ki_gen.utils import init_app, memory, ConfigSchema, State # Import necessary types
|
11 |
+
from ki_gen.prompts import get_initial_prompt
|
12 |
+
|
13 |
+
from neo4j import GraphDatabase
|
14 |
+
|
15 |
+
# Set page config
st.set_page_config(page_title="Key Issue Generator", layout="wide")

# Neo4j Database Configuration
# NOTE(review): hard-coded Aura instance URI — consider moving to an env var.
NEO4J_URI = "neo4j+s://4985272f.databases.neo4j.io"
NEO4J_USERNAME = "neo4j"
# May be None when the env var is unset; checked before use in the UI below.
NEO4J_PASSWORD = os.getenv("neo4j_password")

# API Keys for LLM services
OPENAI_API_KEY = os.getenv("openai_api_key")
# GROQ_API_KEY is removed as we switch to Gemini
# GROQ_API_KEY = os.getenv("groq_api_key")
# Ensure Gemini API key is available in the environment
GEMINI_API_KEY = os.getenv("gemini_api_key")
LANGSMITH_API_KEY = os.getenv("langsmith_api_key")
|
30 |
+
|
31 |
+
def verify_neo4j_connectivity():
    """Verify connection to the Neo4j database.

    Returns:
        True when the database is reachable, otherwise a string of the form
        "Error: <details>". Callers must compare with ``is True`` because the
        failure value (a non-empty string) is truthy.
    """
    driver = None
    try:
        driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
        driver.verify_connectivity()
        return True  # Return simple boolean
    except Exception as e:
        return f"Error: {str(e)}"
    finally:
        # Close the driver on every path. The original closed it only on
        # success, leaking the connection whenever verify_connectivity()
        # raised.
        if driver is not None:
            driver.close()
|
41 |
+
|
42 |
+
def load_config() -> ConfigSchema:
    """Build the LangGraph run configuration with custom parameters.

    Returns:
        A dict of the form ``{"configurable": {...}}`` (including a live
        ``Neo4jGraph`` under the ``"graph"`` key) on success, or ``None``
        when the Neo4j connection or graph construction fails.
    """
    # Custom configuration based on provided parameters.
    # Default models set to gemini-2.0-flash.
    custom_config = {
        "main_llm": "gemini-2.0-flash",
        "plan_method": "generation",
        "use_detailed_query": False,
        "cypher_gen_method": "guided",
        "validate_cypher": False,
        "summarize_model": "gemini-2.0-flash",
        "eval_method": "binary",
        "eval_threshold": 0.7,
        "max_docs": 15,
        "compression_method": "llm_lingua",
        "compress_rate": 0.33,
        "force_tokens": ["."],  # Converting to list format as expected by the application
        "eval_model": "gemini-2.0-flash",
        "thread_id": "3"  # Consider making thread_id dynamic or user-specific
    }

    try:
        # Check connectivity ONCE and reuse the result. The original called
        # verify_neo4j_connectivity() a second time just to build the error
        # message, opening an extra throwaway connection (and potentially
        # reporting a different outcome than the one tested).
        status = verify_neo4j_connectivity()
        if status is True:
            # Add Neo4j graph object to config
            custom_config["graph"] = Neo4jGraph(
                url=NEO4J_URI,
                username=NEO4J_USERNAME,
                password=NEO4J_PASSWORD
            )
        else:
            st.error(f"Neo4j connection issue: {status}")
            # The graph is essential for the pipeline; abort configuration.
            return None
    except Exception as e:
        st.error(f"Error creating Neo4jGraph object: {e}")
        return None

    # Return wrapped in 'configurable' key as expected by LangGraph
    return {"configurable": custom_config}
|
86 |
+
|
87 |
+
|
88 |
+
def generate_key_issues(user_query):
    """Run the full key-issue pipeline: plan -> retrieve -> process.

    Drives the planner graph through its interrupt points: an initial stream
    generates the plan, a second stream (resumed with None) retrieves
    documents, state is updated to mark human validation, and a final stream
    processes the documents.

    Args:
        user_query: Free-text question driving plan generation.

    Returns:
        Tuple ``(final_result, valid_docs)`` — the content of the graph's last
        message (or None on any failure) and the final 'valid_docs' list
        (empty on failure or when absent from state).
    """
    # Initialize application with API keys (remove groq_key)
    init_app(
        openai_key=OPENAI_API_KEY,
        # groq_key=GROQ_API_KEY, # Remove Groq key
        langsmith_key=LANGSMITH_API_KEY
    )

    # Load configuration with custom parameters
    config = load_config()
    if not config or "configurable" not in config or not config["configurable"].get("graph"):
        st.error("Failed to load configuration or connect to Neo4j. Cannot proceed.")
        return None, []

    # Create status containers (placeholders updated in place as stages run)
    plan_status = st.empty()
    plan_display = st.empty()
    retrieval_status = st.empty()
    processing_status = st.empty()

    # Build planner graph
    plan_status.info("Building planner graph...")
    # Pass the full config dictionary to build_planner_graph
    graph = build_planner_graph(memory, config)

    # Execute initial prompt generation
    plan_status.info(f"Generating plan for query: {user_query}")

    # Accumulates every streamed message's content.
    # NOTE(review): collected but never returned or displayed — presumably
    # kept for debugging; confirm before removing.
    messages_content = []
    initial_prompt_data = get_initial_prompt(config, user_query)

    # Stream initial plan generation
    try:
        for event in graph.stream(initial_prompt_data, config, stream_mode="values"):
            if "messages" in event and event["messages"]:
                event["messages"][-1].pretty_print()
                messages_content.append(event["messages"][-1].content)
            # Add checks for specific nodes if needed for status updates
            # if "__start__" in event: # Example check
            #     plan_status.info("Starting plan generation...")

    except Exception as e:
        st.error(f"Error during initial graph stream: {e}")
        return None, []

    # Get the state with the generated plan (after initial stream/interrupt)
    try:
        # Ensure thread_id matches what's used internally if applicable
        state = graph.get_state(config)
        # Check if 'store_plan' exists and is a list
        stored_plan = state.values.get('store_plan', [])
        if isinstance(stored_plan, list) and stored_plan:
            steps = [i for i in range(1, len(stored_plan)+1)]
            plan_df = pd.DataFrame({'Plan steps': steps, 'Description': stored_plan})
            plan_status.success("Plan generation complete!")
            plan_display.dataframe(plan_df, use_container_width=True)
        else:
            plan_status.warning("Plan not found or empty in graph state after generation.")
            plan_display.empty() # Clear display if no plan

    except Exception as e:
        st.error(f"Error getting graph state or displaying plan: {e}")
        return None, []


    # Continue with plan execution for document retrieval
    # This part assumes the graph will continue after the first interrupt
    retrieval_status.info("Retrieving documents...")
    try:
        # Stream from the current state (None indicates continue)
        for event in graph.stream(None, config, stream_mode="values"):
            if "messages" in event and event["messages"]:
                event["messages"][-1].pretty_print()
                messages_content.append(event["messages"][-1].content)
            # Add checks for nodes like 'human_validation' if needed for status

    except Exception as e:
        st.error(f"Error during document retrieval stream: {e}")
        return None, []


    # Get updated state after document retrieval interrupt
    try:
        snapshot = graph.get_state(config)
        valid_docs_retrieved = snapshot.values.get('valid_docs', [])
        doc_count = len(valid_docs_retrieved) if isinstance(valid_docs_retrieved, list) else 0
        retrieval_status.success(f"Retrieved {doc_count} documents")

        # --- Human Validation / Processing Steps ---
        # This section needs interaction logic if manual validation is desired.
        # For now, setting default processing steps and marking as validated.
        processing_status.info("Processing documents...")
        process_steps = ["summarize"] # Default: just summarize

        # Update state to indicate human validation is complete and specify processing steps
        # This should happen *before* the next stream call that triggers processing
        graph.update_state(config, {'human_validated': True, 'process_steps': process_steps})

    except Exception as e:
        st.error(f"Error getting state after retrieval or setting up processing: {e}")
        return None, []


    # Continue execution with document processing
    try:
        for event in graph.stream(None, config, stream_mode="values"):
            if "messages" in event and event["messages"]:
                event["messages"][-1].pretty_print()
                messages_content.append(event["messages"][-1].content)
            # Check for the end node or final chatbot node if needed

    except Exception as e:
        st.error(f"Error during document processing stream: {e}")
        return None, []

    # Get final state after processing
    try:
        final_snapshot = graph.get_state(config)
        processing_status.success("Document processing complete!")

        # Extract final result and documents
        final_result = None
        valid_docs_final = []
        if "messages" in final_snapshot.values and final_snapshot.values["messages"]:
            # Assume the last message contains the final result
            final_result = final_snapshot.values["messages"][-1].content

        # Get the final state of valid_docs (might be processed summaries)
        valid_docs_final = final_snapshot.values.get('valid_docs', [])
        if not isinstance(valid_docs_final, list): # Ensure it's a list
            valid_docs_final = []

        return final_result, valid_docs_final

    except Exception as e:
        st.error(f"Error getting final state or extracting results: {e}")
        return None, []
|
226 |
+
|
227 |
+
# ---- Page header ----
st.title("Key Issue Generator")
st.write("Generate key issues from a Neo4j knowledge graph using advanced language models.")

# ---- Sidebar: database status ----
db_status = verify_neo4j_connectivity()
st.sidebar.header("Database Status")
if db_status is True:  # success sentinel is the boolean True
    st.sidebar.success("Connected to Neo4j database")
else:
    # Anything other than True is an error string from the check.
    st.sidebar.error(f"Database connection issue: {db_status}")

# ---- Query input ----
st.header("Enter Your Query")
query_text = st.text_area(
    "What would you like to explore?",
    "What are the main challenges in AI adoption for healthcare systems?",
    height=100,
)

# ---- Run the pipeline on demand ----
if st.button("Generate Key Issues", type="primary"):
    missing_creds = not OPENAI_API_KEY or not GEMINI_API_KEY or not LANGSMITH_API_KEY or not NEO4J_PASSWORD
    if missing_creds:
        st.error("Required API keys (OpenAI, Gemini, Langsmith) or database credentials are missing. Please check your environment variables.")
    elif db_status is not True:
        # Re-verify the connection state captured above before starting.
        st.error(f"Cannot start: Neo4j connection issue: {db_status}")
    else:
        with st.spinner("Processing your query..."):
            t0 = time.time()
            final_result, valid_docs = generate_key_issues(query_text)
            t1 = time.time()

            if final_result is not None:
                # Timing info in the sidebar.
                st.sidebar.info(f"Total execution time: {round(t1 - t0, 2)} seconds")

                # Main result.
                st.header("Generated Key Issues")
                st.markdown(final_result)

                # Offer the result as a plain-text download.
                st.download_button(
                    label="Download Results",
                    data=final_result,
                    file_name="key_issues_results.txt",
                    mime="text/plain",
                )

                # Processed documents, collapsed by default.
                if valid_docs:
                    with st.expander("View Processed Documents"):
                        for idx, doc in enumerate(valid_docs):
                            st.markdown(f"### Document {idx+1}")
                            # A doc may be a raw dict, a summary string, or
                            # something else entirely — render accordingly.
                            if isinstance(doc, dict):
                                for key, val in doc.items():
                                    st.markdown(f"**{key}**: {val}")
                            elif isinstance(doc, str):
                                st.markdown(doc)
                            else:
                                st.markdown(str(doc))
                            st.divider()
            else:
                # Stage-specific errors were already surfaced inside
                # generate_key_issues; this is a catch-all banner.
                if final_result is None:
                    st.error("Processing failed. Please check the console/logs for errors.")


# ---- Sidebar: about ----
with st.sidebar:
    st.header("About")
    st.info("""
    This application uses advanced language models (like Google Gemini) to analyze a Neo4j knowledge graph
    and generate key issues based on your query. The process involves:

    1. Creating a plan based on your query
    2. Retrieving relevant documents from the database
    3. Processing and summarizing the information
    4. Generating a comprehensive response
    """)
|