import os
import getpass
import html

from typing import Annotated, Union
from typing_extensions import TypedDict

from langchain_community.graphs import Neo4jGraph
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import add_messages

# Module-level checkpointer shared by the app's graphs.
memory = MemorySaver()
def format_df(df):
    """
    Display the generated plan in a nice format.
    Returns HTML code (with inline CSS) as a string.
    """
    def format_cell(cell):
        if isinstance(cell, str):
            # Escape special characters, but preserve line breaks
            return html.escape(cell).replace('\n', '<br>')
        return cell

    # Escape every cell, then convert the DataFrame to HTML
    # (DataFrame.map is element-wise in pandas >= 2.1; use applymap on older versions)
    formatted_df = df.map(format_cell)
    html_table = formatted_df.to_html(escape=False, index=False)

    # Custom CSS to allow multiple lines and scrolling in cells
    css = """
    <style>
    table {
        border-collapse: collapse;
        width: 100%;
    }
    th, td {
        border: 1px solid black;
        padding: 8px;
        text-align: left;
        vertical-align: top;
        white-space: pre-wrap;
        max-width: 300px;
        max-height: 100px;
        overflow-y: auto;
    }
    th {
        background-color: #f2f2f2;
    }
    </style>
    """
    return css + html_table
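
# Usage sketch (hypothetical data, not part of this module): render a small
# plan DataFrame for display.
#     import pandas as pd
#     plan_df = pd.DataFrame({"step": ["1", "2"],
#                             "action": ["retrieve specs\nfilter by release", "summarize"]})
#     html_snippet = format_df(plan_df)  # "<style>...</style><table ...>...</table>"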

def format_doc(doc: dict) -> str:
    """Format a document dict as one Markdown-style '**key**: value' line per field."""
    formatted_string = ""
    for key in doc:
        formatted_string += f"**{key}**: {doc[key]}\n"
    return formatted_string
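
# Usage sketch (hypothetical document): each key/value pair becomes one line.
#     format_doc({"title": "TS 38.300", "release": "17"})
#     # -> "**title**: TS 38.300\n**release**: 17\n"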

def _set_env(var: str, value: str = None):
    """Set an environment variable if unset, falling back to an interactive prompt."""
    if not os.environ.get(var):
        if value:
            os.environ[var] = value
        else:
            os.environ[var] = getpass.getpass(f"{var}: ")

def init_app(openai_key: str = None, langsmith_key: str = None):
    """
    Initialize the app with user API keys and set up LangSmith tracing.
    Keys not passed explicitly fall back to environment variables, then to
    an interactive prompt.
    """
    _set_env("LANGSMITH_API_KEY", value=langsmith_key or os.getenv("langsmith_api_key"))
    _set_env("OPENAI_API_KEY", value=openai_key or os.getenv("openai_api_key"))
    # ChatGoogleGenerativeAI reads GOOGLE_API_KEY by default, so the Gemini key
    # is passed explicitly in get_model; keep it available in the environment too.
    _set_env("GEMINI_API_KEY", value=os.getenv("gemini_api_key"))
    os.environ["LANGCHAIN_TRACING_V2"] = "true"
    os.environ["LANGCHAIN_PROJECT"] = "3GPP Test"
def clear_memory(memory, thread_id: str = "") -> None:
"""
Clears checkpointer state for a given thread_id, broken for now
TODO : fix this
"""
memory = MemorySaver()
#checkpoint = base.empty_checkpoint()
#memory.put(config={"configurable": {"thread_id": thread_id}}, checkpoint=checkpoint, metadata={})

def get_model(model: str = "gemini-2.0-flash"):
    """
    Wrapper returning the correct LLM object for the 'model' parameter.
    """
    if model == "gpt-4o":
        llm = ChatOpenAI(model=model, base_url="https://llm.synapse.thalescloud.io/")
    elif model.startswith("gemini"):
        # ChatGoogleGenerativeAI reads GOOGLE_API_KEY by default, so pass the
        # Gemini key (set by init_app) explicitly.
        llm = ChatGoogleGenerativeAI(model=model, google_api_key=os.getenv("GEMINI_API_KEY"))
    else:
        # Unknown model name: warn and fall back to Gemini rather than raising.
        print(f"Warning: Model '{model}' not explicitly handled. Defaulting to Gemini.")
        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=os.getenv("GEMINI_API_KEY"))
    return llm
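
# Usage sketch: both branches return a LangChain chat model exposing the
# standard 'invoke' interface.
#     llm = get_model("gemini-2.0-flash")
#     reply = llm.invoke("What does 3GPP TS 38.300 cover?")  # hypothetical prompt
#     print(reply.content)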

class ConfigSchema(TypedDict):
    graph: Neo4jGraph
    plan_method: str
    use_detailed_query: bool


class State(TypedDict):
    messages: Annotated[list, add_messages]
    store_plan: list[str]
    current_plan_step: int
    valid_docs: list[str]


class DocRetrieverState(TypedDict):
    messages: Annotated[list, add_messages]
    query: str
    docs: list[dict]
    cyphers: list[str]
    current_plan_step: int
    valid_docs: list[Union[str, dict]]


class HumanValidationState(TypedDict):
    human_validated: bool
    process_steps: list[str]

def update_doc_history(left: list | None, right: list | None) -> list:
    """
    Reducer for the 'docs_in_processing' field: appends each element of
    'right' to the corresponding sub-list of 'left'.
    Doesn't work currently because of bad handling of duplicates.
    TODO: make this work (reference: https://langchain-ai.github.io/langgraph/how-tos/subgraph/#custom-reducer-functions-to-manage-state)
    """
    if not left:
        # This shouldn't happen
        left = [[]]
    if not right:
        right = []
    for i in range(len(right)):
        left[i].append(right[i])
    return left
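
# Behaviour sketch: with left = [["a"], ["b"]] and right = ["c", "d"], the
# reducer returns [["a", "c"], ["b", "d"]]. Note it raises IndexError when
# 'right' is longer than 'left', part of what the TODO above covers.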

class DocProcessorState(TypedDict):
    valid_docs: list[Union[str, dict]]
    docs_in_processing: list
    process_steps: list[Union[str, dict]]
    current_process_step: int
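
# Wiring sketch (assumed usage, following the LangGraph docs): these schemas
# plug into a StateGraph compiled with the module-level MemorySaver.
# 'plan_node' is a hypothetical node function, not defined in this module.
#     from langgraph.graph import StateGraph, START, END
#     builder = StateGraph(State, config_schema=ConfigSchema)
#     builder.add_node("plan", plan_node)
#     builder.add_edge(START, "plan")
#     builder.add_edge("plan", END)
#     graph = builder.compile(checkpointer=memory)
#     graph.invoke({"messages": [("user", "hi")]},
#                  config={"configurable": {"thread_id": "1"}})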