|
from typing import Literal, Annotated |
|
from typing_extensions import TypedDict |
|
import json |
|
import tempfile |
|
import os |
|
|
|
from langchain_core.runnables import RunnableLambda, RunnableConfig |
|
|
|
from langgraph.graph import END |
|
from langgraph.types import Command |
|
from langgraph.prebuilt import InjectedState |
|
|
|
from langchain_community.utilities import BingSearchAPIWrapper |
|
from langchain_community.tools.bing_search.tool import BingSearchResults |
|
from langchain_community.document_loaders import JSONLoader |
|
|
|
from langchain.agents import tool |
|
|
|
from ea4all.src.shared.configuration import ( |
|
BaseConfiguration |
|
) |
|
|
|
from ea4all.src.shared.state import ( |
|
State |
|
) |
|
|
|
from ea4all.src.shared.utils import ( |
|
get_llm_client, |
|
format_docs, |
|
) |
|
|
|
def make_supervisor_node(config: RunnableConfig, members: list[str]) -> RunnableLambda: |
|
options = ["FINISH"] + members |
|
system_prompt = ( |
|
"You are a supervisor tasked with managing a conversation between the" |
|
f" following workers: {members}. Given the following user request," |
|
" respond with the worker to act next. Each worker will perform a" |
|
" task and respond with their results and status. When finished," |
|
" respond with FINISH." |
|
) |
|
|
|
configuration = BaseConfiguration.from_runnable_config(config) |
|
model = get_llm_client( |
|
configuration.supervisor_model, |
|
api_base_url="", |
|
) |
|
|
|
class Router(TypedDict): |
|
"""Worker to route to next. If no workers needed, route to FINISH.""" |
|
|
|
next: Literal[*options] |
|
|
|
def supervisor_node(state: State) -> Command[Literal[*members, "__end__"]]: |
|
"""An LLM-based router.""" |
|
messages = [ |
|
{"role": "system", "content": system_prompt}, |
|
] + state["messages"] |
|
response = model.with_structured_output(Router).invoke(messages) |
|
goto = response["next"] |
|
if goto == "FINISH": |
|
goto = END |
|
|
|
return Command(goto=goto, update={"next": goto}) |
|
|
|
return RunnableLambda(supervisor_node) |
|
|
|
async def websearch(state: dict[str, dict | str]) -> dict[str, dict[str, str]]:
    """
    Web search based on the re-phrased question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates messages key with an assistant message
        containing the formatted web results.
    """
    bing_subscription_key = os.environ.get("BING_SUBSCRIPTION_KEY", "")
    bing_search_url = os.environ.get("BING_SEARCH_URL", "https://api.bing.microsoft.com/v7.0/search")
    search = BingSearchAPIWrapper(
        bing_subscription_key=bing_subscription_key,
        bing_search_url=bing_search_url,
    )

    # NOTE(review): state is annotated as a dict but read via getattr, which
    # only works for attribute-style (e.g. Pydantic) states — confirm against
    # the caller; a plain dict would always fall through to the AttributeError.
    question = getattr(state, 'messages')[-1].content if getattr(state, 'messages', False) else getattr(state, 'question')

    web_results = BingSearchResults(
        api_wrapper=search,
        handle_tool_error=True,
        args_schema={"k": "5"},
    )

    result = await web_results.ainvoke({"query": question})

    # The tool returns a Python-repr string (single-quoted). Parse it with
    # ast.literal_eval, which handles apostrophes/quotes inside result text;
    # the old naive quote-swap + json.loads is kept only as a fallback.
    try:
        result_json = ast.literal_eval(result)
    except (ValueError, SyntaxError):
        result_json = json.loads(result.replace("'", "\""))

    # Round-trip through a temp file because JSONLoader only accepts a path.
    # Close the file before loading (required on Windows) and always unlink
    # it afterwards so repeated calls don't leak temp files.
    temp_path = None
    try:
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as temp_file:
            json.dump(result_json, temp_file)
            temp_path = temp_file.name

        loader = JSONLoader(file_path=temp_path, jq_schema=".[]", text_content=False)
        docs = loader.load()
    finally:
        if temp_path is not None:
            os.unlink(temp_path)

    return {"messages": {"role": "assistant", "content": format_docs(docs)}}
|
|