import os
import getpass
import html
from typing import Annotated, Union
from typing_extensions import TypedDict
from langchain_community.graphs import Neo4jGraph
# Remove ChatGroq import
# from langchain_groq import ChatGroq
# Add ChatGoogleGenerativeAI import
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.checkpoint.memory import MemorySaver
from langgraph.checkpoint import base
from langgraph.graph import add_messages
# Module-level in-memory checkpointer used as the default conversation store.
memory = MemorySaver()
def format_df(df):
"""
Used to display the generated plan in a nice format
Returns html code in a string
"""
def format_cell(cell):
if isinstance(cell, str):
# Encode special characters, but preserve line breaks
return html.escape(cell).replace('\n', '
')
return cell
# Convert the DataFrame to HTML with custom CSS
formatted_df = df.map(format_cell)
html_table = formatted_df.to_html(escape=False, index=False)
# Add custom CSS to allow multiple lines and scrolling in cells
css = """
"""
return css + html_table
def format_doc(doc: dict) -> str :
    """Render a document dict as markdown-style '**key**: value' lines, one per key."""
    return "".join(f"**{key}**: {value}\n" for key, value in doc.items())
def _set_env(var: str, value: str = None):
if not os.environ.get(var):
if value:
os.environ[var] = value
else:
os.environ[var] = getpass.getpass(f"{var}: ")
# groq support removed: no groq_key parameter anymore
def init_app(openai_key : str = None, langsmith_key : str = None):
    """
    Initialize app with user api keys and sets up proxy settings
    """
    # Map each required env var to the lowercase alias it may already be stored
    # under; _set_env prompts interactively only when neither is present.
    # (GROQ_API_KEY is intentionally no longer set.)
    for target, alias in (
        ("LANGSMITH_API_KEY", "langsmith_api_key"),
        ("OPENAI_API_KEY", "openai_api_key"),
        # ChatGoogleGenerativeAI reads GEMINI_API_KEY automatically, but set it
        # anyway in case other consumers need it
        ("GEMINI_API_KEY", "gemini_api_key"),
    ):
        _set_env(target, value=os.getenv(alias))
    os.environ["LANGSMITH_TRACING_V2"] = "true"
    os.environ["LANGCHAIN_PROJECT"] = "3GPP Test"
def clear_memory(memory, thread_id: str = "") -> None:
    """
    Clears checkpointer state for a given thread_id.

    Bug fix: the previous implementation rebound the *local* name `memory` to a
    fresh MemorySaver, which had no effect on the caller's checkpointer. This
    version mutates the given checkpointer in place.

    Args:
        memory: the checkpointer instance to clear.
        thread_id: thread whose state should be dropped; when empty, all
            threads are cleared.
    """
    # Newer langgraph checkpointers expose delete_thread() — prefer it when present.
    if thread_id and hasattr(memory, "delete_thread"):
        memory.delete_thread(thread_id)
        return
    # Fallback: MemorySaver keeps per-thread checkpoints in a `storage` dict
    # keyed by thread_id — NOTE(review): internal attribute, confirm on upgrade.
    storage = getattr(memory, "storage", None)
    if storage is None:
        return
    if thread_id:
        storage.pop(thread_id, None)
    else:
        storage.clear()
# get_model dispatches to ChatGoogleGenerativeAI for gemini-* models
def get_model(model : str = "gemini-2.0-flash"):
    """
    Wrapper to return the correct llm object depending on the 'model' param
    """
    if model == "gpt-4o":
        return ChatOpenAI(model=model, base_url="https://llm.synapse.thalescloud.io/")
    if model.startswith("gemini"):
        # Pass the API key explicitly, although the client reads the env var by default
        return ChatGoogleGenerativeAI(model=model, google_api_key=os.getenv("gemini_api_key"))
    # Unrecognised model name: warn, then fall back to the default Gemini model
    print(f"Warning: Model '{model}' not explicitly handled. Defaulting to Gemini.")
    return ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=os.getenv("gemini_api_key"))
class ConfigSchema(TypedDict):
    """Configuration values supplied to the graph at invocation time (not mutable state)."""
    # Neo4j graph connection used for retrieval queries
    graph: Neo4jGraph
    # Planning strategy selector — TODO confirm the set of allowed values
    plan_method: str
    # Whether retrieval should use the more detailed query variant
    use_detailed_query: bool
class State(TypedDict):
    """Top-level state threaded through the main LangGraph workflow."""
    # Conversation history; the add_messages reducer appends instead of replacing
    messages : Annotated[list, add_messages]
    # Generated plan — presumably one string per step; verify against producer
    store_plan : list[str]
    # Index of the plan step currently being executed
    current_plan_step : int
    # Documents validated as relevant so far
    valid_docs : list[str]
class DocRetrieverState(TypedDict):
    """State for the document-retrieval subgraph."""
    # Conversation history; the add_messages reducer appends instead of replacing
    messages: Annotated[list, add_messages]
    # The retrieval query being executed
    query: str
    # Raw documents returned by retrieval
    docs: list[dict]
    # Cypher queries issued against the Neo4j graph — presumably one per attempt; confirm with caller
    cyphers: list[str]
    # Index of the plan step this retrieval belongs to
    current_plan_step : int
    # Documents accepted as relevant (plain strings or structured dicts)
    valid_docs: list[Union[str, dict]]
class HumanValidationState(TypedDict):
    """State slice for the human-in-the-loop validation step."""
    # True once a human has approved the current results
    human_validated : bool
    # Processing steps — NOTE(review): elements look like step names; confirm producer
    process_steps : list[str]
def update_doc_history(left : list | None, right : list | None) -> list:
"""
Reducer for the 'docs_in_processing' field.
Doesn't work currently because of bad handlinf of duplicates
TODO : make this work (reference : https://langchain-ai.github.io/langgraph/how-tos/subgraph/#custom-reducer-functions-to-manage-state)
"""
if not left:
# This shouldn't happen
left = [[]]
if not right:
right = []
for i in range(len(right)):
left[i].append(right[i])
return left
class DocProcessorState(TypedDict):
    """State for the document-processing subgraph."""
    # Documents accepted so far (plain strings or structured dicts)
    valid_docs : list[Union[str, dict]]
    # Per-document processing history, reduced by update_doc_history
    docs_in_processing : list
    # Processing steps to apply — elements may be step names or structured specs
    process_steps : list[Union[str,dict]]
    # Index of the step currently being applied
    current_process_step : int