# NOTE(review): the lines "Spaces: Sleeping / Sleeping" below the original
# header were a Hugging Face Spaces status banner captured along with the
# source, not code; kept here as a comment so the file stays valid Python.
import os | |
import getpass | |
import html | |
from typing import Annotated, Union | |
from typing_extensions import TypedDict | |
from langchain_community.graphs import Neo4jGraph | |
# Remove ChatGroq import | |
# from langchain_groq import ChatGroq | |
# Add ChatGoogleGenerativeAI import | |
from langchain_google_genai import ChatGoogleGenerativeAI | |
from langchain_openai import ChatOpenAI | |
from langgraph.checkpoint.sqlite import SqliteSaver | |
from langgraph.checkpoint.memory import MemorySaver | |
from langgraph.checkpoint import base | |
from langgraph.graph import add_messages | |
memory = MemorySaver() | |
def format_df(df): | |
""" | |
Used to display the generated plan in a nice format | |
Returns html code in a string | |
""" | |
def format_cell(cell): | |
if isinstance(cell, str): | |
# Encode special characters, but preserve line breaks | |
return html.escape(cell).replace('\n', '<br>') | |
return cell | |
# Convert the DataFrame to HTML with custom CSS | |
formatted_df = df.map(format_cell) | |
html_table = formatted_df.to_html(escape=False, index=False) | |
# Add custom CSS to allow multiple lines and scrolling in cells | |
css = """ | |
<style> | |
table { | |
border-collapse: collapse; | |
width: 100%; | |
} | |
th, td { | |
border: 1px solid black; | |
padding: 8px; | |
text-align: left; | |
vertical-align: top; | |
white-space: pre-wrap; | |
max-width: 300px; | |
max-height: 100px; | |
overflow-y: auto; | |
} | |
th { | |
background-color: #f2f2f2; | |
} | |
</style> | |
""" | |
return css + html_table | |
def format_doc(doc: dict) -> str:
    """
    Render *doc* as markdown-ish text: one "**key**: value" line per entry.
    """
    return "".join(f"**{key}**: {value}\n" for key, value in doc.items())
def _set_env(var: str, value: str = None): | |
if not os.environ.get(var): | |
if value: | |
os.environ[var] = value | |
else: | |
os.environ[var] = getpass.getpass(f"{var}: ") | |
def init_app(openai_key : str = None, langsmith_key : str = None):
    """
    Initialize the app: export the API keys and tracing settings as
    environment variables.

    Args:
        openai_key: OpenAI API key; falls back to the ``openai_api_key``
            env var (previously this parameter was accepted but ignored).
        langsmith_key: LangSmith API key; falls back to the
            ``langsmith_api_key`` env var (was also ignored before).
    """
    _set_env("LANGSMITH_API_KEY", value=langsmith_key or os.getenv("langsmith_api_key"))
    _set_env("OPENAI_API_KEY", value=openai_key or os.getenv("openai_api_key"))
    # ChatGoogleGenerativeAI can read the key from the environment; export it
    # here so other code paths find it too.
    _set_env("GEMINI_API_KEY", value=os.getenv("gemini_api_key"))
    # Fixed: was "LANGSMITH_TRACING_V2", which LangSmith does not read;
    # the tracing flag is LANGCHAIN_TRACING_V2 (matching LANGCHAIN_PROJECT).
    os.environ["LANGCHAIN_TRACING_V2"] = "true"
    os.environ["LANGCHAIN_PROJECT"] = "3GPP Test"
def clear_memory(memory, thread_id: str = "") -> None:
    """
    Clear the checkpointer state stored for *thread_id*.

    The previous implementation rebound the local name ``memory`` to a fresh
    MemorySaver, which never affected the caller's checkpointer (a no-op).
    This version mutates the saver that was passed in.
    """
    if hasattr(memory, "delete_thread"):
        # Newer langgraph checkpoint savers expose delete_thread directly.
        memory.delete_thread(thread_id)
    elif hasattr(memory, "storage"):
        # MemorySaver keeps checkpoints in a dict keyed by thread_id;
        # pop is a safe no-op when the thread has no stored state.
        memory.storage.pop(thread_id, None)
def get_model(model : str = "gemini-2.0-flash"):
    """
    Wrapper to return the correct llm object depending on the 'model' param.

    Args:
        model: Model identifier; "gpt-4o" routes to the proxied OpenAI
            endpoint, anything starting with "gemini" to Google GenAI, and
            unknown names fall back to the default Gemini model (with a
            warning printed).
    """
    if model == "gpt-4o":
        return ChatOpenAI(model=model, base_url="https://llm.synapse.thalescloud.io/")
    # Fixed inconsistency: init_app exports the key as GEMINI_API_KEY
    # (uppercase), but only the lowercase name was read here, so the key
    # could silently resolve to None.  Accept either spelling.
    gemini_key = os.getenv("gemini_api_key") or os.getenv("GEMINI_API_KEY")
    if model.startswith("gemini"):
        return ChatGoogleGenerativeAI(model=model, google_api_key=gemini_key)
    # Unknown model name: warn and fall back to the default Gemini model.
    print(f"Warning: Model '{model}' not explicitly handled. Defaulting to Gemini.")
    return ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=gemini_key)
class ConfigSchema(TypedDict):
    """Per-invocation configuration passed to the LangGraph app."""
    # Neo4j graph connection used by the app
    graph: Neo4jGraph
    # Name of the strategy used to generate the plan -- TODO confirm valid values
    plan_method: str
    # Whether to expand the user query into a more detailed one -- presumably;
    # verify against the nodes that read this flag
    use_detailed_query: bool
class State(TypedDict):
    """Top-level state threaded through the main graph."""
    # Conversation history; the add_messages reducer merges new messages
    # instead of overwriting the list
    messages : Annotated[list, add_messages]
    # Generated plan, one step per entry
    store_plan : list[str]
    # Index of the plan step currently being executed
    current_plan_step : int
    # Documents accepted so far.  NOTE(review): declared list[str] here but
    # list[Union[str, dict]] in DocRetrieverState/DocProcessorState --
    # confirm the intended element type
    valid_docs : list[str]
class DocRetrieverState(TypedDict):
    """State for the document-retrieval subgraph."""
    # Conversation history; add_messages merges rather than replaces
    messages: Annotated[list, add_messages]
    # Query currently being resolved
    query: str
    # Raw documents returned by retrieval
    docs: list[dict]
    # Presumably the Cypher queries generated/executed against the Neo4j
    # graph during retrieval -- TODO confirm
    cyphers: list[str]
    # Index of the plan step this retrieval belongs to
    current_plan_step : int
    # Documents accepted so far (str summaries or structured dicts)
    valid_docs: list[Union[str, dict]]
class HumanValidationState(TypedDict):
    """State carrying the human-in-the-loop validation outcome."""
    # True once a human has approved the current docs -- presumably; verify
    # against the node that sets it
    human_validated : bool
    # Processing steps to apply (same field name as DocProcessorState)
    process_steps : list[str]
def update_doc_history(left : list | None, right : list | None) -> list: | |
""" | |
Reducer for the 'docs_in_processing' field. | |
Doesn't work currently because of bad handlinf of duplicates | |
TODO : make this work (reference : https://langchain-ai.github.io/langgraph/how-tos/subgraph/#custom-reducer-functions-to-manage-state) | |
""" | |
if not left: | |
# This shouldn't happen | |
left = [[]] | |
if not right: | |
right = [] | |
for i in range(len(right)): | |
left[i].append(right[i]) | |
return left | |
class DocProcessorState(TypedDict):
    """State for the document-processing subgraph."""
    # Documents accepted so far (str summaries or structured dicts)
    valid_docs : list[Union[str, dict]]
    # Per-document history lists; maintained by the update_doc_history reducer
    docs_in_processing : list
    # Steps to run on each document; str step name or dict with parameters --
    # TODO confirm the dict shape
    process_steps : list[Union[str,dict]]
    # Index into process_steps of the step being executed
    current_process_step : int