adrienbrdne commited on
Commit
32d54de
·
verified ·
1 Parent(s): 77b16df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +309 -211
app.py CHANGED
@@ -1,212 +1,310 @@
1
- import os
2
- import streamlit as st
3
- from langchain_community.graphs import Neo4jGraph
4
- import pandas as pd
5
- import json
6
- import time
7
-
8
- from ki_gen.planner import build_planner_graph
9
- from ki_gen.utils import init_app, memory
10
- from ki_gen.prompts import get_initial_prompt
11
-
12
- from neo4j import GraphDatabase
13
-
14
- # Set page config
15
- st.set_page_config(page_title="Key Issue Generator", layout="wide")
16
-
17
- # Neo4j Database Configuration
18
- NEO4J_URI = "neo4j+s://4985272f.databases.neo4j.io"
19
- NEO4J_USERNAME = "neo4j"
20
- NEO4J_PASSWORD = os.getenv("neo4j_password")
21
-
22
- # API Keys for LLM services
23
- OPENAI_API_KEY = os.getenv("openai_api_key")
24
- GROQ_API_KEY = os.getenv("groq_api_key")
25
- LANGSMITH_API_KEY = os.getenv("langsmith_api_key")
26
-
27
- def verify_neo4j_connectivity():
28
- """Verify connection to Neo4j database"""
29
- try:
30
- with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
31
- return driver.verify_connectivity()
32
- except Exception as e:
33
- return f"Error: {str(e)}"
34
-
35
- def load_config():
36
- """Load configuration with custom parameters"""
37
- # Custom configuration based on provided parameters
38
- custom_config = {
39
- "main_llm": "deepseek-r1-distill-llama-70b",
40
- "plan_method": "generation",
41
- "use_detailed_query": False,
42
- "cypher_gen_method": "guided",
43
- "validate_cypher": False,
44
- "summarize_model": "deepseek-r1-distill-llama-70b",
45
- "eval_method": "binary",
46
- "eval_threshold": 0.7,
47
- "max_docs": 15,
48
- "compression_method": "llm_lingua",
49
- "compress_rate": 0.33,
50
- "force_tokens": ["."], # Converting to list format as expected by the application
51
- "eval_model": "deepseek-r1-distill-llama-70b",
52
- "thread_id": "3"
53
- }
54
-
55
- # Add Neo4j graph object to config
56
- try:
57
- neo_graph = Neo4jGraph(
58
- url=NEO4J_URI,
59
- username=NEO4J_USERNAME,
60
- password=NEO4J_PASSWORD
61
- )
62
- custom_config["graph"] = neo_graph
63
- except Exception as e:
64
- st.error(f"Error connecting to Neo4j: {e}")
65
- return None
66
-
67
- return {"configurable": custom_config}
68
-
69
- def generate_key_issues(user_query):
70
- """Main function to generate key issues from Neo4j data"""
71
- # Initialize application with API keys
72
- init_app(
73
- openai_key=OPENAI_API_KEY,
74
- groq_key=GROQ_API_KEY,
75
- langsmith_key=LANGSMITH_API_KEY
76
- )
77
-
78
- # Load configuration with custom parameters
79
- config = load_config()
80
- if not config:
81
- return None
82
-
83
- # Create status containers
84
- plan_status = st.empty()
85
- plan_display = st.empty()
86
- retrieval_status = st.empty()
87
- processing_status = st.empty()
88
-
89
- # Build planner graph
90
- plan_status.info("Building planner graph...")
91
- graph = build_planner_graph(memory, config["configurable"])
92
-
93
- # Execute initial prompt generation
94
- plan_status.info(f"Generating plan for query: {user_query}")
95
-
96
- messages_content = []
97
- for event in graph.stream(get_initial_prompt(config, user_query), config, stream_mode="values"):
98
- if "messages" in event:
99
- event["messages"][-1].pretty_print()
100
- messages_content.append(event["messages"][-1].content)
101
-
102
- # Get the state with the generated plan
103
- state = graph.get_state(config)
104
- steps = [i for i in range(1, len(state.values['store_plan'])+1)]
105
- plan_df = pd.DataFrame({'Plan steps': steps, 'Description': state.values['store_plan']})
106
-
107
- # Display the plan
108
- plan_status.success("Plan generation complete!")
109
- plan_display.dataframe(plan_df, use_container_width=True)
110
-
111
- # Continue with plan execution for document retrieval
112
- retrieval_status.info("Retrieving documents...")
113
- for event in graph.stream(None, config, stream_mode="values"):
114
- if "messages" in event:
115
- event["messages"][-1].pretty_print()
116
- messages_content.append(event["messages"][-1].content)
117
-
118
- # Get updated state after document retrieval
119
- snapshot = graph.get_state(config)
120
- doc_count = len(snapshot.values.get('valid_docs', []))
121
- retrieval_status.success(f"Retrieved {doc_count} documents")
122
-
123
- # Proceed to document processing
124
- processing_status.info("Processing documents...")
125
- process_steps = ["summarize"] # Using summarize as default processing step
126
-
127
- # Update state to indicate human validation is complete and specify processing steps
128
- graph.update_state(config, {'human_validated': True, 'process_steps': process_steps}, as_node="human_validation")
129
-
130
- # Continue execution with document processing
131
- for event in graph.stream(None, config, stream_mode="values"):
132
- if "messages" in event:
133
- event["messages"][-1].pretty_print()
134
- messages_content.append(event["messages"][-1].content)
135
-
136
- # Get final state after processing
137
- final_snapshot = graph.get_state(config)
138
- processing_status.success("Document processing complete!")
139
-
140
- if "messages" in final_snapshot.values:
141
- final_result = final_snapshot.values["messages"][-1].content
142
- return final_result, final_snapshot.values.get('valid_docs', [])
143
-
144
- return None, []
145
-
146
- # App header
147
- st.title("Key Issue Generator")
148
- st.write("Generate key issues from a Neo4j knowledge graph using advanced language models.")
149
-
150
- # Check database connectivity
151
- connectivity_status = verify_neo4j_connectivity()
152
- st.sidebar.header("Database Status")
153
- if "Error" not in str(connectivity_status):
154
- st.sidebar.success("Connected to Neo4j database")
155
- else:
156
- st.sidebar.error(f"Database connection issue: {connectivity_status}")
157
-
158
- # User input section
159
- st.header("Enter Your Query")
160
- user_query = st.text_area("What would you like to explore?",
161
- "What are the main challenges in AI adoption for healthcare systems?",
162
- height=100)
163
-
164
- # Process button
165
- if st.button("Generate Key Issues", type="primary"):
166
- if not OPENAI_API_KEY or not GROQ_API_KEY or not LANGSMITH_API_KEY or not NEO4J_PASSWORD:
167
- st.error("Required API keys or database credentials are missing. Please check your environment variables.")
168
- else:
169
- with st.spinner("Processing your query..."):
170
- start_time = time.time()
171
- final_result, valid_docs = generate_key_issues(user_query)
172
- end_time = time.time()
173
-
174
- if final_result:
175
- # Display execution time
176
- st.sidebar.info(f"Total execution time: {round(end_time - start_time, 2)} seconds")
177
-
178
- # Display final result
179
- st.header("Generated Key Issues")
180
- st.markdown(final_result)
181
-
182
- # Option to download results
183
- st.download_button(
184
- label="Download Results",
185
- data=final_result,
186
- file_name="key_issues_results.txt",
187
- mime="text/plain"
188
- )
189
-
190
- # Display retrieved documents in expandable section
191
- if valid_docs:
192
- with st.expander("View Retrieved Documents"):
193
- for i, doc in enumerate(valid_docs):
194
- st.markdown(f"### Document {i+1}")
195
- for key in doc:
196
- st.markdown(f"**{key}**: {doc[key]}")
197
- st.divider()
198
- else:
199
- st.error("An error occurred during processing. Please check the logs for details.")
200
-
201
- # Help information in sidebar
202
- with st.sidebar:
203
- st.header("About")
204
- st.info("""
205
- This application uses advanced language models to analyze a Neo4j knowledge graph and generate key issues
206
- based on your query. The process involves:
207
-
208
- 1. Creating a plan based on your query
209
- 2. Retrieving relevant documents from the database
210
- 3. Processing and summarizing the information
211
- 4. Generating a comprehensive response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  """)
 
1
+ import os
2
+ import streamlit as st
3
+ from langchain_community.graphs import Neo4jGraph
4
+ import pandas as pd
5
+ import json
6
+ import time
7
+
8
+ from ki_gen.planner import build_planner_graph
9
+ # Update import path if init_app moved or args changed
10
+ from ki_gen.utils import init_app, memory, ConfigSchema, State # Import necessary types
11
+ from ki_gen.prompts import get_initial_prompt
12
+
13
+ from neo4j import GraphDatabase
14
+
15
+ # Set page config
16
+ st.set_page_config(page_title="Key Issue Generator", layout="wide")
17
+
18
+ # Neo4j Database Configuration
19
+ NEO4J_URI = "neo4j+s://4985272f.databases.neo4j.io"
20
+ NEO4J_USERNAME = "neo4j"
21
+ NEO4J_PASSWORD = os.getenv("neo4j_password")
22
+
23
+ # API Keys for LLM services
24
+ OPENAI_API_KEY = os.getenv("openai_api_key")
25
+ # GROQ_API_KEY is removed as we switch to Gemini
26
+ # GROQ_API_KEY = os.getenv("groq_api_key")
27
+ # Ensure Gemini API key is available in the environment
28
+ GEMINI_API_KEY = os.getenv("gemini_api_key")
29
+ LANGSMITH_API_KEY = os.getenv("langsmith_api_key")
30
+
31
+ def verify_neo4j_connectivity():
32
+ """Verify connection to Neo4j database"""
33
+ try:
34
+ # Ensure driver closes properly
35
+ driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
36
+ driver.verify_connectivity()
37
+ driver.close() # Explicitly close the driver
38
+ return True # Return simple boolean
39
+ except Exception as e:
40
+ return f"Error: {str(e)}"
41
+
42
+ # Update load_config defaults
43
+ def load_config() -> ConfigSchema: # Add type hint
44
+ """Load configuration with custom parameters"""
45
+ # Custom configuration based on provided parameters
46
+ # Update default models to gemini-2.0-flash
47
+ custom_config = {
48
+ "main_llm": "gemini-2.0-flash",
49
+ "plan_method": "generation",
50
+ "use_detailed_query": False,
51
+ "cypher_gen_method": "guided",
52
+ "validate_cypher": False,
53
+ "summarize_model": "gemini-2.0-flash",
54
+ "eval_method": "binary",
55
+ "eval_threshold": 0.7,
56
+ "max_docs": 15,
57
+ "compression_method": "llm_lingua",
58
+ "compress_rate": 0.33,
59
+ "force_tokens": ["."], # Converting to list format as expected by the application
60
+ "eval_model": "gemini-2.0-flash",
61
+ "thread_id": "3" # Consider making thread_id dynamic or user-specific
62
+ }
63
+
64
+ # Add Neo4j graph object to config
65
+ neo_graph = None # Initialize to None
66
+ try:
67
+ # Verify connectivity before creating the graph object.
68
+ if verify_neo4j_connectivity() is True:
69
+ neo_graph = Neo4jGraph(
70
+ url=NEO4J_URI,
71
+ username=NEO4J_USERNAME,
72
+ password=NEO4J_PASSWORD
73
+ )
74
+ custom_config["graph"] = neo_graph
75
+ else:
76
+ st.error(f"Neo4j connection issue: {verify_neo4j_connectivity()}")
77
+ # Return None or raise error if graph is essential
78
+ return None
79
+
80
+ except Exception as e:
81
+ st.error(f"Error creating Neo4jGraph object: {e}")
82
+ return None
83
+
84
+ # Return wrapped in 'configurable' key as expected by LangGraph
85
+ return {"configurable": custom_config}
86
+
87
+
88
+ def generate_key_issues(user_query):
89
+ """Main function to generate key issues from Neo4j data"""
90
+ # Initialize application with API keys (remove groq_key)
91
+ init_app(
92
+ openai_key=OPENAI_API_KEY,
93
+ # groq_key=GROQ_API_KEY, # Remove Groq key
94
+ langsmith_key=LANGSMITH_API_KEY
95
+ )
96
+
97
+ # Load configuration with custom parameters
98
+ config = load_config()
99
+ if not config or "configurable" not in config or not config["configurable"].get("graph"):
100
+ st.error("Failed to load configuration or connect to Neo4j. Cannot proceed.")
101
+ return None, []
102
+
103
+ # Create status containers
104
+ plan_status = st.empty()
105
+ plan_display = st.empty()
106
+ retrieval_status = st.empty()
107
+ processing_status = st.empty()
108
+
109
+ # Build planner graph
110
+ plan_status.info("Building planner graph...")
111
+ # Pass the full config dictionary to build_planner_graph
112
+ graph = build_planner_graph(memory, config)
113
+
114
+ # Execute initial prompt generation
115
+ plan_status.info(f"Generating plan for query: {user_query}")
116
+
117
+ messages_content = []
118
+ initial_prompt_data = get_initial_prompt(config, user_query)
119
+
120
+ # Stream initial plan generation
121
+ try:
122
+ for event in graph.stream(initial_prompt_data, config, stream_mode="values"):
123
+ if "messages" in event and event["messages"]:
124
+ event["messages"][-1].pretty_print()
125
+ messages_content.append(event["messages"][-1].content)
126
+ # Add checks for specific nodes if needed for status updates
127
+ # if "__start__" in event: # Example check
128
+ # plan_status.info("Starting plan generation...")
129
+
130
+ except Exception as e:
131
+ st.error(f"Error during initial graph stream: {e}")
132
+ return None, []
133
+
134
+ # Get the state with the generated plan (after initial stream/interrupt)
135
+ try:
136
+ # Ensure thread_id matches what's used internally if applicable
137
+ state = graph.get_state(config)
138
+ # Check if 'store_plan' exists and is a list
139
+ stored_plan = state.values.get('store_plan', [])
140
+ if isinstance(stored_plan, list) and stored_plan:
141
+ steps = [i for i in range(1, len(stored_plan)+1)]
142
+ plan_df = pd.DataFrame({'Plan steps': steps, 'Description': stored_plan})
143
+ plan_status.success("Plan generation complete!")
144
+ plan_display.dataframe(plan_df, use_container_width=True)
145
+ else:
146
+ plan_status.warning("Plan not found or empty in graph state after generation.")
147
+ plan_display.empty() # Clear display if no plan
148
+
149
+ except Exception as e:
150
+ st.error(f"Error getting graph state or displaying plan: {e}")
151
+ return None, []
152
+
153
+
154
+ # Continue with plan execution for document retrieval
155
+ # This part assumes the graph will continue after the first interrupt
156
+ retrieval_status.info("Retrieving documents...")
157
+ try:
158
+ # Stream from the current state (None indicates continue)
159
+ for event in graph.stream(None, config, stream_mode="values"):
160
+ if "messages" in event and event["messages"]:
161
+ event["messages"][-1].pretty_print()
162
+ messages_content.append(event["messages"][-1].content)
163
+ # Add checks for nodes like 'human_validation' if needed for status
164
+
165
+ except Exception as e:
166
+ st.error(f"Error during document retrieval stream: {e}")
167
+ return None, []
168
+
169
+
170
+ # Get updated state after document retrieval interrupt
171
+ try:
172
+ snapshot = graph.get_state(config)
173
+ valid_docs_retrieved = snapshot.values.get('valid_docs', [])
174
+ doc_count = len(valid_docs_retrieved) if isinstance(valid_docs_retrieved, list) else 0
175
+ retrieval_status.success(f"Retrieved {doc_count} documents")
176
+
177
+ # --- Human Validation / Processing Steps ---
178
+ # This section would need interactive UI logic if manual human validation is desired.
179
+ # For now, setting default processing steps and marking as validated.
180
+ processing_status.info("Processing documents...")
181
+ process_steps = ["summarize"] # Default: just summarize
182
+
183
+ # Update state to indicate human validation is complete and specify processing steps
184
+ # This should happen *before* the next stream call that triggers processing
185
+ graph.update_state(config, {'human_validated': True, 'process_steps': process_steps})
186
+
187
+ except Exception as e:
188
+ st.error(f"Error getting state after retrieval or setting up processing: {e}")
189
+ return None, []
190
+
191
+
192
+ # Continue execution with document processing
193
+ try:
194
+ for event in graph.stream(None, config, stream_mode="values"):
195
+ if "messages" in event and event["messages"]:
196
+ event["messages"][-1].pretty_print()
197
+ messages_content.append(event["messages"][-1].content)
198
+ # Check for the end node or final chatbot node if needed
199
+
200
+ except Exception as e:
201
+ st.error(f"Error during document processing stream: {e}")
202
+ return None, []
203
+
204
+ # Get final state after processing
205
+ try:
206
+ final_snapshot = graph.get_state(config)
207
+ processing_status.success("Document processing complete!")
208
+
209
+ # Extract final result and documents
210
+ final_result = None
211
+ valid_docs_final = []
212
+ if "messages" in final_snapshot.values and final_snapshot.values["messages"]:
213
+ # Assume the last message contains the final result
214
+ final_result = final_snapshot.values["messages"][-1].content
215
+
216
+ # Get the final state of valid_docs (might be processed summaries)
217
+ valid_docs_final = final_snapshot.values.get('valid_docs', [])
218
+ if not isinstance(valid_docs_final, list): # Ensure it's a list
219
+ valid_docs_final = []
220
+
221
+ return final_result, valid_docs_final
222
+
223
+ except Exception as e:
224
+ st.error(f"Error getting final state or extracting results: {e}")
225
+ return None, []
226
+
227
+ # App header
228
+ st.title("Key Issue Generator")
229
+ st.write("Generate key issues from a Neo4j knowledge graph using advanced language models.")
230
+
231
+ # Check database connectivity
232
+ connectivity_status = verify_neo4j_connectivity()
233
+ st.sidebar.header("Database Status")
234
+ # Use boolean check
235
+ if connectivity_status is True:
236
+ st.sidebar.success("Connected to Neo4j database")
237
+ else:
238
+ # Display the error message returned
239
+ st.sidebar.error(f"Database connection issue: {connectivity_status}")
240
+
241
+ # User input section
242
+ st.header("Enter Your Query")
243
+ user_query = st.text_area("What would you like to explore?",
244
+ "What are the main challenges in AI adoption for healthcare systems?",
245
+ height=100)
246
+
247
+ # Process button
248
+ if st.button("Generate Key Issues", type="primary"):
249
+ # Update API key check for Gemini
250
+ if not OPENAI_API_KEY or not GEMINI_API_KEY or not LANGSMITH_API_KEY or not NEO4J_PASSWORD:
251
+ st.error("Required API keys (OpenAI, Gemini, Langsmith) or database credentials are missing. Please check your environment variables.")
252
+ elif connectivity_status is not True: # Check DB connection again before starting
253
+ st.error(f"Cannot start: Neo4j connection issue: {connectivity_status}")
254
+ else:
255
+ with st.spinner("Processing your query..."):
256
+ start_time = time.time()
257
+ # Call the main generation function
258
+ final_result, valid_docs = generate_key_issues(user_query)
259
+ end_time = time.time()
260
+
261
+ if final_result is not None: # Check if result is not None (indicating success)
262
+ # Display execution time
263
+ st.sidebar.info(f"Total execution time: {round(end_time - start_time, 2)} seconds")
264
+
265
+ # Display final result
266
+ st.header("Generated Key Issues")
267
+ st.markdown(final_result)
268
+
269
+ # Option to download results
270
+ st.download_button(
271
+ label="Download Results",
272
+ data=final_result, # Ensure final_result is string data
273
+ file_name="key_issues_results.txt",
274
+ mime="text/plain"
275
+ )
276
+
277
+ # Display retrieved/processed documents in expandable section
278
+ if valid_docs:
279
+ with st.expander("View Processed Documents"): # Update title
280
+ for i, doc in enumerate(valid_docs):
281
+ st.markdown(f"### Document {i+1}")
282
+ # Handle doc format (could be string summary or original dict)
283
+ if isinstance(doc, dict):
284
+ for key in doc:
285
+ st.markdown(f"**{key}**: {doc[key]}")
286
+ elif isinstance(doc, str):
287
+ st.markdown(doc) # Display string directly if it's a summary
288
+ else:
289
+ st.markdown(str(doc)) # Fallback for other types
290
+ st.divider()
291
+ else:
292
+ # Error messages are now shown within generate_key_issues
293
+ # st.error("An error occurred during processing. Please check the logs or console output for details.")
294
+ # Adding a placeholder here in case specific errors weren't caught
295
+ if final_result is None: # Check explicit None return
296
+ st.error("Processing failed. Please check the console/logs for errors.")
297
+
298
+
299
+ # Help information in sidebar
300
+ with st.sidebar:
301
+ st.header("About")
302
+ st.info("""
303
+ This application uses advanced language models (like Google Gemini) to analyze a Neo4j knowledge graph
304
+ and generate key issues based on your query. The process involves:
305
+
306
+ 1. Creating a plan based on your query
307
+ 2. Retrieving relevant documents from the database
308
+ 3. Processing and summarizing the information
309
+ 4. Generating a comprehensive response
310
  """)