File size: 5,244 Bytes
70d06c8
 
 
 
 
 
 
 
 
19491ad
 
 
 
70d06c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19491ad
 
70d06c8
 
 
19491ad
 
70d06c8
 
19491ad
 
70d06c8
 
 
 
 
 
 
 
 
 
 
 
 
 
19491ad
 
70d06c8
 
 
 
 
19491ad
 
 
 
70d06c8
19491ad
 
 
 
70d06c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19491ad
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import os
import getpass
import html


from typing import Annotated, Union
from typing_extensions import TypedDict

from langchain_community.graphs import Neo4jGraph
# Remove ChatGroq import
# from langchain_groq import ChatGroq 
# Add ChatGoogleGenerativeAI import
from langchain_google_genai import ChatGoogleGenerativeAI 
from langchain_openai import ChatOpenAI

from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.checkpoint.memory import MemorySaver
from langgraph.checkpoint import base
from langgraph.graph import add_messages

# Module-level in-memory checkpointer shared by the app (per-process, not persistent).
memory = MemorySaver()

def format_df(df):
    """
    Used to display the generated plan in a nice format
    Returns html code in a string
    """
    def format_cell(cell):
        if isinstance(cell, str):
            # Encode special characters, but preserve line breaks
            return html.escape(cell).replace('\n', '<br>')
        return cell
    # Convert the DataFrame to HTML with custom CSS
    formatted_df = df.map(format_cell)
    html_table = formatted_df.to_html(escape=False, index=False)
    
    # Add custom CSS to allow multiple lines and scrolling in cells
    css = """
    <style>
        table {
            border-collapse: collapse;
            width: 100%;
        }
        th, td {
            border: 1px solid black;
            padding: 8px;
            text-align: left;
            vertical-align: top;
            white-space: pre-wrap;
            max-width: 300px;
            max-height: 100px;
            overflow-y: auto;
        }
        th {
            background-color: #f2f2f2;
        }
    </style>
    """
    
    return css + html_table

def format_doc(doc: dict) -> str:
    """
    Format a document dict as markdown-style lines: one "**key**: value" per entry.

    Iterates items() directly and joins once instead of the previous
    key-lookup loop with quadratic string concatenation.
    """
    return "".join(f"**{key}**: {value}\n" for key, value in doc.items())



def _set_env(var: str, value: str = None):
    if not os.environ.get(var):
        if value:
            os.environ[var] = value
        else:
            os.environ[var] = getpass.getpass(f"{var}: ")


# Remove groq_key parameter
def init_app(openai_key: str = None, langsmith_key: str = None):
    """
    Initialize the app with user API keys and set up LangSmith tracing env vars.

    Fix: the `openai_key` and `langsmith_key` parameters were previously
    accepted but ignored. They now take precedence, with the lowercase
    env vars kept as a backward-compatible fallback.
    """
    _set_env("LANGSMITH_API_KEY", value=langsmith_key or os.getenv("langsmith_api_key"))
    _set_env("OPENAI_API_KEY", value=openai_key or os.getenv("openai_api_key"))
    # ChatGoogleGenerativeAI can read this automatically, but set it explicitly
    # in case it is needed elsewhere.
    _set_env("GEMINI_API_KEY", value=os.getenv("gemini_api_key"))
    os.environ["LANGSMITH_TRACING_V2"] = "true"
    os.environ["LANGCHAIN_PROJECT"] = "3GPP Test"
    

def clear_memory(memory, thread_id: str = "") -> None:
    """
    Clear checkpointer state for a given thread_id.

    Fix: the previous implementation rebound the *local* name `memory` to a
    fresh MemorySaver, which was a no-op for the caller. MemorySaver keeps
    its checkpoints in an in-memory dict keyed by thread_id, so dropping the
    thread's entry actually clears it. Missing threads are ignored.
    """
    # getattr guard: only MemorySaver-like checkpointers expose `.storage`;
    # for other savers (e.g. SqliteSaver) this remains a safe no-op.
    storage = getattr(memory, "storage", None)
    if storage is not None:
        storage.pop(thread_id, None)

# Update get_model to use ChatGoogleGenerativeAI
def get_model(model : str = "gemini-2.0-flash"): 
    """
    Wrapper to return the correct llm object depending on the 'model' param
    """
    if model == "gpt-4o":
        llm = ChatOpenAI(model=model, base_url="https://llm.synapse.thalescloud.io/")
    # Check for gemini models
    elif model.startswith("gemini"):
        # Pass the API key explicitly, although it often reads from env var by default
        llm = ChatGoogleGenerativeAI(model=model, google_api_key=os.getenv("gemini_api_key")) 
    else:
        # Fallback or handle other models if necessary, maybe raise an error
        # For now, defaulting to Gemini if model name doesn't match others
        print(f"Warning: Model '{model}' not explicitly handled. Defaulting to Gemini.")
        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=os.getenv("gemini_api_key"))
    return llm


class ConfigSchema(TypedDict):
    # Static configuration passed to the graph at invocation time.
    graph: Neo4jGraph          # Neo4j connection used for document retrieval
    plan_method: str           # strategy name for plan generation — exact values defined by callers
    use_detailed_query: bool   # toggle for a more detailed retrieval query

class State(TypedDict):
    # Top-level LangGraph state for the planning workflow.
    messages : Annotated[list, add_messages]  # chat history; add_messages reducer appends/merges
    store_plan : list[str]                    # generated plan, one entry per step
    current_plan_step : int                   # index into store_plan of the step being executed
    valid_docs : list[str]                    # documents accepted so far

class DocRetrieverState(TypedDict):
    # State for the document-retrieval subgraph.
    messages: Annotated[list, add_messages]  # chat history; add_messages reducer appends/merges
    query: str                               # current retrieval query
    docs: list[dict]                         # raw documents fetched for the query
    cyphers: list[str]                       # Cypher queries issued against the Neo4j graph
    current_plan_step : int                  # plan step this retrieval belongs to
    valid_docs: list[Union[str, dict]]       # documents that passed validation

class HumanValidationState(TypedDict):
    # State for the human-in-the-loop validation step.
    human_validated : bool       # whether a human approved the current result
    process_steps : list[str]    # processing steps to apply — see DocProcessorState

def update_doc_history(left : list | None, right : list | None) -> list:
    """
    Reducer for the 'docs_in_processing' field.
    Doesn't work currently because of bad handlinf of duplicates
    TODO : make this work (reference : https://langchain-ai.github.io/langgraph/how-tos/subgraph/#custom-reducer-functions-to-manage-state)
    """
    if not left:
        # This shouldn't happen
        left = [[]]
    if not right:
        right = []

    for i in range(len(right)):
        left[i].append(right[i])
    return left


class DocProcessorState(TypedDict):
    # State for the document-processing subgraph.
    valid_docs : list[Union[str, dict]]      # documents that passed validation
    docs_in_processing : list                # per-doc processing history; see update_doc_history reducer
    process_steps : list[Union[str,dict]]    # processing steps to apply to the docs
    current_process_step : int               # index into process_steps currently running