Spaces:
Running
Running
File size: 7,852 Bytes
448903c dfad45c 3b31030 448903c dfad45c 3b31030 448903c 3b31030 448903c 69fe8b2 3b31030 448903c 3b31030 448903c 99bb959 448903c 3b31030 dfad45c 448903c 3b31030 448903c dfad45c 448903c dfad45c 3b31030 dfad45c 448903c dfad45c 799013a 448903c 3b31030 448903c 799013a 3b31030 448903c 3b31030 448903c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 |
import os
import requests
from dotenv import load_dotenv
from langgraph.graph import StateGraph, MessagesState, START
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_google_genai import ChatGoogleGenerativeAI
from langfuse.langchain import CallbackHandler
from tools.web_search import web_search
from tools.math import add_numbers_in_list, check_commutativity
from tools.extraction import extract_data_from_excel, extract_transcript_from_youtube, extract_transcript_from_audio
from rate_limiters import safe_invoke_with_retry_gemini
# Load the .env file first so that everything below (including the Langfuse
# handler) sees its values; override=True lets .env win over variables that
# are already set in the process environment.
load_dotenv(override=True)

# Provider passed to build_agent() by the manual test run at the bottom of
# this file.
PROVIDER = "google"

# Callback handler handed to agent.invoke() in the manual test run.
langfuse_handler = CallbackHandler()

# Tools bound to the LLM and served by the graph's ToolNode in build_agent().
tools = [
    add_numbers_in_list,
    web_search,
    check_commutativity,
    extract_data_from_excel,
    extract_transcript_from_youtube,
    extract_transcript_from_audio,
]
# --------------- Define the agent structure ---------------- #
def _create_llm(provider: str):
    """Instantiate the chat model for ``provider``.

    Args:
        provider: One of ``"hf"``, ``"google"`` or ``"openai"``.

    Returns:
        The chat model object constructed for that provider.

    Raises:
        ValueError: If ``provider`` is not one of the supported names.
    """
    if provider == "hf":
        endpoint = HuggingFaceEndpoint(
            repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
            task="text-generation",
            temperature=0.0,
            provider="hf-inference",
        )
        return ChatHuggingFace(llm=endpoint)

    if provider == "google":
        return ChatGoogleGenerativeAI(
            # Other models tried here: "gemini-2.0-flash",
            # "gemini-2.5-flash-lite-preview-06-17".
            model="gemini-2.5-flash-preview-05-20",
            max_tokens=2048,
            max_retries=2,
        )

    if provider == "openai":
        return ChatOpenAI(
            model="gpt-3.5-turbo",
            temperature=0,
            api_key=os.getenv("OPENAI_API_KEY"),
            max_tokens=512,
        )

    raise ValueError(f"Unsupported provider: {provider}")


def build_agent(provider: str = "hf"):
    """Build and compile the tool-calling agent graph.

    The graph has two nodes: ``assistant`` (calls the LLM with the system
    prompt prepended to the conversation) and ``tools`` (a ``ToolNode`` over
    the module-level ``tools`` list). After each assistant turn LangGraph's
    prebuilt ``tools_condition`` decides whether to go to ``tools``; the
    ``tools`` node always hands control back to ``assistant``.

    The ``USE_RATE_LIMITER`` environment variable (``"true"``/``"false"``,
    default ``"false"``) selects whether the LLM is called through
    ``safe_invoke_with_retry_gemini`` instead of directly.

    Args:
        provider: LLM backend to use: ``"hf"``, ``"google"`` or ``"openai"``.

    Returns:
        The compiled state graph.

    Raises:
        ValueError: If ``provider`` is unsupported, or if rate limiting is
            enabled for a provider other than ``"google"``.
        OSError: If ``system_prompt.txt`` cannot be read from the current
            working directory.
    """
    use_rate_limiter = os.getenv("USE_RATE_LIMITER", "false").lower() == "true"
    print(f"Building agent with provider: {provider}")

    llm = _create_llm(provider)

    # Fail at build time rather than on the first invoke: the only rate-limited
    # path available is the Gemini one, so any other combination can never work.
    if use_rate_limiter and provider != "google":
        raise ValueError(f"Rate limiting is not implemented for provider {provider}.")

    # Bind the tools to the LLM
    llm_with_tools = llm.bind_tools(tools)

    # Load the system prompt from the file (path is relative to the CWD)
    with open("system_prompt.txt", "r", encoding="utf-8") as f:
        system_prompt = f.read()
    sys_msg = SystemMessage(content=system_prompt)

    # --------------- Define nodes ---------------- #
    def assistant(state: MessagesState):
        """Node for the assistant to respond to user input."""
        messages = [sys_msg] + state["messages"]
        if use_rate_limiter:
            # Provider is guaranteed to be "google" here (validated above).
            response = safe_invoke_with_retry_gemini(
                llm_with_tools,
                messages,
                max_retries=2,
                wait_seconds=60,
            )
        else:
            response = llm_with_tools.invoke(messages)
        return {"messages": [response]}

    tool_node = ToolNode(tools=tools)

    # --------------- Build the state graph ---------------- #
    graph_builder = StateGraph(MessagesState)
    graph_builder.add_node("assistant", assistant)
    graph_builder.add_node("tools", tool_node)
    graph_builder.add_conditional_edges(
        "assistant",
        tools_condition,
    )
    graph_builder.add_edge("tools", "assistant")
    graph_builder.add_edge(START, "assistant")
    return graph_builder.compile()
# --------------- For manual testing ---------------- #
def _fetch_questions(questions_url: str) -> list:
    """Download the list of questions from ``questions_url``.

    Args:
        questions_url: Full URL of the scoring API's questions endpoint.

    Returns:
        The non-empty decoded JSON payload.

    Raises:
        SystemExit: If the request fails, the body is not valid JSON, or the
            returned list is empty. Nothing downstream can run without
            questions, so the script stops with a readable message.
    """
    print(f"Fetching questions from: {questions_url}")
    try:
        response = requests.get(questions_url, timeout=15)
        response.raise_for_status()
        questions_data = response.json()
    except (requests.RequestException, ValueError) as e:
        raise SystemExit(f"An unexpected error occurred fetching questions: {e}") from e

    if not questions_data:
        raise SystemExit("Fetched questions list is empty.")
    print(f"Fetched {len(questions_data)} questions.")
    return questions_data


def _build_message_content(q_data: dict, file_url: str) -> list:
    """Build the multi-part message content for one question.

    The question text always comes first. If the question has an attached
    file, one more part is appended depending on the file extension
    (matched case-insensitively).

    Args:
        q_data: Question record; reads ``"question"`` and ``"file_name"``.
        file_url: URL from which the attached file can be downloaded.

    Returns:
        A list of content-part dicts suitable for ``HumanMessage(content=...)``.
    """
    content = [{"type": "text", "text": q_data["question"]}]

    # Treat a missing or None file name the same as an empty one.
    file_name = (q_data.get("file_name") or "").lower()
    if not file_name:
        return content

    if file_name.endswith((".png", ".jpg", ".jpeg")):
        content.append({"type": "image_url", "image_url": {"url": file_url}})
    elif file_name.endswith(".py"):
        # For code files, we can just send the text content. This is best
        # effort: on failure the question is still asked without the code.
        try:
            response = requests.get(file_url, timeout=15)
            response.raise_for_status()
        except requests.RequestException as e:
            print(f"Error fetching code file: {e}")
        else:
            content.append({"type": "text", "text": response.text})
    elif file_name.endswith((".xlsx", ".xls")):
        content.append({"type": "text", "text": "Excel file url: " + file_url})
    elif file_name.endswith((".mp3", ".wav")):
        content.append({"type": "text", "text": "Audio file url: " + file_url})
    else:
        content.append({"type": "text", "text": f"File URL: {file_url} (file type not supported)"})
    return content


def main() -> None:
    """Run the agent on a single hard-coded question and print the exchange."""
    print("\n" + "-" * 30 + " Agent Starting " + "-" * 30)

    # Print run variables
    print(f"Provider: {PROVIDER}")
    # Default to "false" when USE_DDGS is unset instead of crashing on
    # None.lower(). NOTE(review): assumes the web_search tool also treats an
    # unset USE_DDGS as "use Tavily" -- confirm against tools/web_search.
    use_ddgs = os.getenv("USE_DDGS", "false").lower() == "true"
    print(f"Search engine used: {'DDGS' if use_ddgs else 'Tavily'}")

    agent = build_agent(provider=PROVIDER)  # Change PROVIDER to "hf" for HuggingFace
    print("Agent built successfully.")
    print("-" * 70)

    # 1. Resolve the API endpoints
    DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"
    files_url = f"{api_url}/files/"  # Needs task_id appended

    # 2. Fetch questions
    questions_data = _fetch_questions(questions_url)

    # 3. Pick one question by task_id (uncomment the one to test)
    # task_id = "8e867cd7-cff9-4e6c-867a-ff5ddc2550be"  # Sosa albums
    # task_id = "2d83110e-a098-4ebb-9987-066c06fa42d0"  # Reverse text example
    # task_id = "cca530fc-4052-43b2-b130-b30968d8aa44"  # Chess image
    # task_id = "4fc2f1ae-8625-45b5-ab34-ad4433bc21f8"  # Dinosaur ?
    # task_id = "6f37996b-2ac7-44b0-8e68-6d28256631b4"  # Commutativity check
    # task_id = "9d191bce-651d-4746-be2d-7ef8ecadb9c2"  # Youtube video
    # task_id = "cabe07ed-9eca-40ea-8ead-410ef5e83f91"  # Louvrier ?
    # task_id = "f918266a-b3e0-4914-865d-4faa564f1aef"  # Code example
    # task_id = "3f57289b-8c60-48be-bd80-01f8099ca449"  # at bats ?
    task_id = "7bd855d8-463d-4ed5-93ca-5fe35145f733"  # Excel file
    # task_id = "5a0c1adf-205e-4841-a666-7c3ef95def9d"  # Malko competition
    # task_id = "305ac316-eef6-4446-960a-92d80d542f82"  # Poland film
    # task_id = "bda648d7-d618-4883-88f4-3466eabd860e"  # Vietnamese
    # task_id = "cf106601-ab4f-4af9-b045-5295fe67b37d"  # Olympics
    # task_id = "a0c07678-e491-4bbc-8f0b-07405144218f"  # pitchers
    # task_id = "3cef3a44-215e-4aed-8e3b-b1e3f08063b7"  # grocery list
    # task_id = "840bfca7-4f7b-481a-8794-c560c340185d"  # Carolyn Collins Petersen
    # task_id = "1f975693-876d-457b-a649-393859e79bf3"  # Audio (pages)
    # task_id = "99c9cc74-fdc8-46c6-8f8d-3ce2d3bfeea3"  # Audio (recipe)

    q_data = next(
        (item for item in questions_data if item.get("task_id") == task_id),
        None,
    )
    if q_data is None:
        raise SystemExit(f"No question found with task_id {task_id}.")

    # 4. Build the user message and run the agent
    content = _build_message_content(q_data, f"{files_url}{task_id}")
    human_msg = HumanMessage(content=content)
    human_msg.pretty_print()

    try:
        result = agent.invoke(
            {"messages": [human_msg]},
            config={"callbacks": [langfuse_handler]},
        )
    except Exception as e:  # top-level boundary of a manual test script
        print(f"Error: {e}")
    else:
        # Print the full message history returned by the graph.
        for message in result["messages"]:
            message.pretty_print()


if __name__ == "__main__":
    main()