File size: 8,224 Bytes
f8ac349
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
import os
import streamlit as st
from langchain_community.graphs import Neo4jGraph
import pandas as pd
import json
import time

from ki_gen.planner import build_planner_graph
from ki_gen.utils import init_app, memory
from ki_gen.prompts import get_initial_prompt

from neo4j import GraphDatabase

# Page setup: title and wide layout for the Streamlit page.
st.set_page_config(layout="wide", page_title="Key Issue Generator")

# Neo4j connection settings. The password comes from the environment and is
# None when the variable is not set.
NEO4J_URI = "neo4j+s://4985272f.databases.neo4j.io"
NEO4J_USERNAME = "neo4j"
NEO4J_PASSWORD = os.environ.get("neo4j_password")

# Keys for the LLM-related services, also read from the environment
# (each is None when its variable is not set).
OPENAI_API_KEY = os.environ.get("openai_api_key")
GROQ_API_KEY = os.environ.get("groq_api_key")
LANGSMITH_API_KEY = os.environ.get("langsmith_api_key")

def verify_neo4j_connectivity():
    """Check whether the configured Neo4j database can be reached.

    Returns:
        The value returned by ``driver.verify_connectivity()`` when the
        check succeeds, or a string of the form ``"Error: <message>"`` when
        any exception is raised while creating, probing or closing the
        driver.
    """
    try:
        credentials = (NEO4J_USERNAME, NEO4J_PASSWORD)
        driver = GraphDatabase.driver(NEO4J_URI, auth=credentials)
        # The driver is a context manager; leaving the block closes it.
        with driver:
            outcome = driver.verify_connectivity()
        return outcome
    except Exception as exc:
        # Failures are reported as text instead of being raised; the caller
        # looks for "Error" in the stringified result.
        return f"Error: {exc}"

def load_config():
    """Build the run configuration, including a Neo4j graph object.

    Returns:
        ``{"configurable": settings}``, where ``settings`` holds the fixed
        pipeline parameters plus the ``Neo4jGraph`` instance under the
        ``"graph"`` key. Returns ``None`` when the graph object cannot be
        created; in that case the error is shown on the page via
        ``st.error``.
    """
    # Create the graph object first: it is the only step that can fail, and
    # without it no configuration is returned.
    try:
        graph_handle = Neo4jGraph(
            url=NEO4J_URI,
            username=NEO4J_USERNAME,
            password=NEO4J_PASSWORD,
        )
    except Exception as exc:
        st.error(f"Error connecting to Neo4j: {exc}")
        return None

    # The same model name is used for planning, summarising and evaluation.
    model_name = "deepseek-r1-distill-llama-70b"

    settings = dict(
        main_llm=model_name,
        plan_method="generation",
        use_detailed_query=False,
        cypher_gen_method="guided",
        validate_cypher=False,
        summarize_model=model_name,
        eval_method="binary",
        eval_threshold=0.7,
        max_docs=15,
        compression_method="llm_lingua",
        compress_rate=0.33,
        force_tokens=["."],  # list form, as expected by the application
        eval_model=model_name,
        thread_id="3",
        graph=graph_handle,
    )
    return {"configurable": settings}

def _print_streamed_messages(graph, graph_input, config):
    """Run ``graph.stream`` to exhaustion, pretty-printing each event's last message.

    Args:
        graph: The planner graph returned by ``build_planner_graph``.
        graph_input: Value passed as the first argument of ``graph.stream``.
            The caller passes the initial prompt to start a run and ``None``
            to continue a run that has already been started.
        config: The configuration dict passed through to ``graph.stream``.
    """
    for event in graph.stream(graph_input, config, stream_mode="values"):
        # Skip events without messages (or with an empty message list) so
        # the [-1] lookup cannot raise IndexError.
        if "messages" in event and event["messages"]:
            event["messages"][-1].pretty_print()


def generate_key_issues(user_query):
    """Generate key issues for ``user_query`` from the Neo4j data.

    The function initialises the application with the API keys, loads the
    configuration, builds the planner graph and then runs it in three
    phases (plan generation, document retrieval, document processing),
    reporting progress through Streamlit placeholders.

    Args:
        user_query: The query text handed to ``get_initial_prompt``.

    Returns:
        A ``(final_result, valid_docs)`` tuple. ``final_result`` is the
        content of the last message in the final graph state and
        ``valid_docs`` is that state's ``'valid_docs'`` entry (``[]`` if
        absent). ``(None, [])`` is returned when the configuration could not
        be loaded or the final state contains no messages.
    """
    # Initialize application with API keys
    init_app(
        openai_key=OPENAI_API_KEY,
        groq_key=GROQ_API_KEY,
        langsmith_key=LANGSMITH_API_KEY
    )

    # Load configuration with custom parameters
    config = load_config()
    if not config:
        # Always return a 2-tuple: the caller unpacks the result into two
        # names, so a bare None here would raise TypeError instead of
        # reaching its error message.
        return None, []

    # Create status containers
    plan_status = st.empty()
    plan_display = st.empty()
    retrieval_status = st.empty()
    processing_status = st.empty()

    # Build planner graph
    plan_status.info("Building planner graph...")
    graph = build_planner_graph(memory, config["configurable"])

    # Phase 1: generate the plan from the initial prompt
    plan_status.info(f"Generating plan for query: {user_query}")
    _print_streamed_messages(graph, get_initial_prompt(config, user_query), config)

    # Read the generated plan back from the graph state and tabulate it
    plan = graph.get_state(config).values['store_plan']
    steps = list(range(1, len(plan) + 1))
    plan_df = pd.DataFrame({'Plan steps': steps, 'Description': plan})

    # Display the plan
    plan_status.success("Plan generation complete!")
    plan_display.dataframe(plan_df, use_container_width=True)

    # Phase 2: continue the run for document retrieval
    retrieval_status.info("Retrieving documents...")
    _print_streamed_messages(graph, None, config)

    # Get updated state after document retrieval
    snapshot = graph.get_state(config)
    doc_count = len(snapshot.values.get('valid_docs', []))
    retrieval_status.success(f"Retrieved {doc_count} documents")

    # Phase 3: document processing
    processing_status.info("Processing documents...")
    process_steps = ["summarize"]  # Using summarize as default processing step

    # Update state to indicate human validation is complete and specify processing steps
    graph.update_state(
        config,
        {'human_validated': True, 'process_steps': process_steps},
        as_node="human_validation",
    )

    # Continue execution with document processing
    _print_streamed_messages(graph, None, config)

    # Get final state after processing
    final_values = graph.get_state(config).values
    processing_status.success("Document processing complete!")

    # Guard against both a missing and an empty message list.
    final_messages = final_values.get("messages")
    if final_messages:
        return final_messages[-1].content, final_values.get('valid_docs', [])

    return None, []

# Page title and short description
st.title("Key Issue Generator")
st.write("Generate key issues from a Neo4j knowledge graph using advanced language models.")

# Show the database connectivity result in the sidebar. A failed check comes
# back as a string containing "Error"; anything else counts as connected.
connectivity_status = verify_neo4j_connectivity()
st.sidebar.header("Database Status")
if "Error" in str(connectivity_status):
    st.sidebar.error(f"Database connection issue: {connectivity_status}")
else:
    st.sidebar.success("Connected to Neo4j database")

# Query input, pre-filled with an example question
st.header("Enter Your Query")
user_query = st.text_area(
    label="What would you like to explore?",
    value="What are the main challenges in AI adoption for healthcare systems?",
    height=100,
)

# Run the pipeline when the button is pressed
if st.button("Generate Key Issues", type="primary"):
    # Every one of these must be set (non-empty) before anything is run.
    required_secrets = (OPENAI_API_KEY, GROQ_API_KEY, LANGSMITH_API_KEY, NEO4J_PASSWORD)
    if all(required_secrets):
        with st.spinner("Processing your query..."):
            start_time = time.time()
            final_result, valid_docs = generate_key_issues(user_query)
            elapsed_seconds = round(time.time() - start_time, 2)

            if not final_result:
                st.error("An error occurred during processing. Please check the logs for details.")
            else:
                # Execution time goes to the sidebar
                st.sidebar.info(f"Total execution time: {elapsed_seconds} seconds")

                # Final result, rendered as markdown
                st.header("Generated Key Issues")
                st.markdown(final_result)

                # Offer the result as a plain-text download
                st.download_button(
                    label="Download Results",
                    data=final_result,
                    file_name="key_issues_results.txt",
                    mime="text/plain",
                )

                # Retrieved documents, one section per document, listing
                # every key of the document with its value
                if valid_docs:
                    with st.expander("View Retrieved Documents"):
                        for position, doc in enumerate(valid_docs, start=1):
                            st.markdown(f"### Document {position}")
                            for key in doc:
                                value = doc[key]
                                st.markdown(f"**{key}**: {value}")
                            st.divider()
    else:
        st.error("Required API keys or database credentials are missing. Please check your environment variables.")

# "About" panel in the sidebar describing what the application does
st.sidebar.header("About")
st.sidebar.info("""

    This application uses advanced language models to analyze a Neo4j knowledge graph and generate key issues 

    based on your query. The process involves:

    

    1. Creating a plan based on your query

    2. Retrieving relevant documents from the database

    3. Processing and summarizing the information

    4. Generating a comprehensive response

    """)