# LangGraph tool-calling agent for the HF Agents course (unit 4) scoring API.
import os
import sys

import requests
from dotenv import load_dotenv
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_openai import ChatOpenAI
from langfuse.langchain import CallbackHandler
from langgraph.graph import StateGraph, MessagesState, START
from langgraph.prebuilt import ToolNode, tools_condition

from rate_limiters import safe_invoke_with_retry_gemini
from tools.extraction import extract_data_from_excel, extract_transcript_from_youtube, extract_transcript_from_audio
from tools.math import add_numbers_in_list, check_commutativity
from tools.web_search import web_search
# Load .env, letting file values override anything already in the process env.
load_dotenv(override=True)

# Which LLM backend build_agent() is called with by default (see __main__).
PROVIDER = "google"

# Langfuse tracing callback, passed to agent.invoke() via config.
langfuse_handler = CallbackHandler()

# Tools the LLM may call; bound to the model in build_agent().
tools = [
    add_numbers_in_list,
    web_search,
    check_commutativity,
    extract_data_from_excel,
    extract_transcript_from_youtube,
    extract_transcript_from_audio,
]
# --------------- Define the agent structure ---------------- # | |
def build_agent(provider: str = "hf"):
    """Build and compile the tool-calling agent graph.

    Args:
        provider: LLM backend to use — "hf" (HuggingFace endpoint),
            "google" (Gemini) or "openai".

    Returns:
        A compiled LangGraph graph over ``MessagesState`` with an
        ``assistant`` node and a ``tools`` node wired in a ReAct loop.

    Raises:
        ValueError: If *provider* is unsupported, or if USE_RATE_LIMITER
            is enabled for a provider without rate-limiter support.
        FileNotFoundError: If ``system_prompt.txt`` is missing.
    """
    use_rate_limiter = os.getenv("USE_RATE_LIMITER", "false").lower() == "true"
    print(f"Building agent with provider: {provider}")

    # Fail fast at build time: the rate limiter is only implemented for
    # Gemini. Previously this ValueError was raised inside the assistant
    # node, i.e. only once a conversation was already running.
    if use_rate_limiter and provider != "google":
        raise ValueError(f"Rate limiting is not implemented for provider {provider}.")

    if provider == "hf":
        llm = HuggingFaceEndpoint(
            repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
            task="text-generation",
            temperature=0.0,
            provider="hf-inference",
        )
        llm = ChatHuggingFace(llm=llm)
    elif provider == "google":
        llm = ChatGoogleGenerativeAI(
            model="gemini-2.5-flash-preview-05-20",
            max_tokens=2048,
            max_retries=2,
        )
    elif provider == "openai":
        llm = ChatOpenAI(
            model="gpt-3.5-turbo",
            temperature=0,
            api_key=os.getenv("OPENAI_API_KEY"),
            max_tokens=512,
        )
    else:
        raise ValueError(f"Unsupported provider: {provider}")

    # Expose the tool schemas to the model so it can emit tool calls.
    llm_with_tools = llm.bind_tools(tools)

    # System prompt lives next to this script; sent ahead of every turn.
    with open("system_prompt.txt", "r", encoding="utf-8") as f:
        system_prompt = f.read()
    sys_msg = SystemMessage(content=system_prompt)

    # --------------- Define nodes ---------------- #
    def assistant(state: MessagesState):
        """Run the LLM on the system prompt plus the conversation so far."""
        if use_rate_limiter:
            # provider == "google" is guaranteed by the build-time check.
            response = safe_invoke_with_retry_gemini(
                llm_with_tools,
                [sys_msg] + state["messages"],
                max_retries=2,
                wait_seconds=60,
            )
        else:
            response = llm_with_tools.invoke([sys_msg] + state["messages"])
        return {"messages": [response]}

    tool_node = ToolNode(tools=tools)

    # --------------- Build the state graph ---------------- #
    graph_builder = StateGraph(MessagesState)
    graph_builder.add_node("assistant", assistant)
    graph_builder.add_node("tools", tool_node)
    # Route to "tools" when the last message has tool calls, else to END.
    graph_builder.add_conditional_edges(
        "assistant",
        tools_condition,
    )
    graph_builder.add_edge("tools", "assistant")
    graph_builder.add_edge(START, "assistant")
    return graph_builder.compile()
# --------------- For manual testing ---------------- # | |
if __name__ == "__main__":
    # Manual smoke test: build the agent, fetch one benchmark question from
    # the scoring API, and run the agent on it with Langfuse tracing.
    print("\n" + "-" * 30 + " Agent Starting " + "-" * 30)
    print(f"Provider: {PROVIDER}")
    # Default USE_DDGS to "false": os.getenv() returns None when the var is
    # unset and calling .lower() on it raised AttributeError.
    use_ddgs = os.getenv("USE_DDGS", "false").lower() == "true"
    print(f"Search engine used: {'DDGS' if use_ddgs else 'Tavily'}")
    agent = build_agent(provider=PROVIDER)  # Change to "hf" for HuggingFace
    print("Agent built successfully.")
    print("-" * 70)

    # Scoring API endpoints.
    DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"
    files_url = f"{api_url}/files/"  # Needs task_id appended

    # Fetch the question list; nothing below can run without it, so exit on
    # failure instead of hitting a NameError on questions_data later.
    print(f"Fetching questions from: {questions_url}")
    try:
        response = requests.get(questions_url, timeout=15)
        response.raise_for_status()
        questions_data = response.json()
    except Exception as e:
        sys.exit(f"An unexpected error occurred fetching questions: {e}")
    if not questions_data:
        sys.exit("Fetched questions list is empty.")
    print(f"Fetched {len(questions_data)} questions.")

    # Task to test manually — swap the id to exercise a different question
    # (image, audio, code, Excel, ... variants exist on the API).
    task_id = "7bd855d8-463d-4ed5-93ca-5fe35145f733"  # Excel file

    q_data = next((item for item in questions_data if item["task_id"] == task_id), None)
    if q_data is None:
        # Guard: indexing a None q_data below would raise TypeError.
        sys.exit(f"No question found with task_id {task_id}.")

    content = [
        {"type": "text", "text": q_data["question"]}
    ]
    if q_data["file_name"] != "":
        file_url = f"{files_url}{task_id}"
        file_name = q_data["file_name"]
        if file_name.endswith((".png", ".jpg", ".jpeg")):
            content.append({"type": "image_url", "image_url": {"url": file_url}})
        elif file_name.endswith(".py"):
            # For code files, we can just send the text content
            try:
                response = requests.get(file_url, timeout=15)
                response.raise_for_status()
                code_content = response.text
                content.append({"type": "text", "text": code_content})
            except Exception as e:
                print(f"Error fetching code file: {e}")
        elif file_name.endswith((".xlsx", ".xls")):
            # Excel/audio files are fetched by the corresponding tools, so
            # only the URL is passed to the model.
            content.append({"type": "text", "text": "Excel file url: " + file_url})
        elif file_name.endswith((".mp3", ".wav")):
            content.append({"type": "text", "text": "Audio file url: " + file_url})
        else:
            content.append({"type": "text", "text": f"File URL: {file_url} (file type not supported)"})

    human_msg = HumanMessage(content=content)
    human_msg.pretty_print()
    try:
        result = agent.invoke(
            {"messages": [human_msg]},
            config={"callbacks": [langfuse_handler]},
        )
        for message in result["messages"]:
            message.pretty_print()
    except Exception as e:
        print(f"Error: {e}")