Spaces:
Running
Running
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re | |
from dataclasses import dataclass | |
from typing import Any, Dict, Optional | |
from .local_python_executor import ( | |
BASE_BUILTIN_MODULES, | |
BASE_PYTHON_TOOLS, | |
evaluate_python_code, | |
) | |
from .tools import PipelineTool, Tool | |
@dataclass
class PreTool:
    """Plain record holding the declared attributes of a tool.

    The class only declares fields; ``@dataclass`` generates ``__init__``,
    ``__repr__`` and ``__eq__`` so instances can be built from keyword
    arguments and compared by value.
    """

    # Field meanings below are inferred from the names only -- confirm
    # against the code that constructs PreTool instances.
    name: str  # presumably the tool's identifier
    inputs: Dict[str, str]  # presumably input name -> type/description string
    output_type: type
    task: str
    description: str
    repo_id: str  # presumably the repository the tool is loaded from
class PythonInterpreterTool(Tool):
    """Tool that evaluates a Python snippet with ``evaluate_python_code``."""

    name = "python_interpreter"
    description = "This is a tool that evaluates python code. It can be used to perform calculations."
    inputs = {
        "code": {
            "type": "string",
            "description": "The python code to run in interpreter",
        }
    }
    output_type = "string"

    def __init__(self, *args, authorized_imports=None, **kwargs):
        """Configure the import allow-list and the per-instance input schema.

        Args:
            *args: Forwarded unchanged to ``Tool.__init__``.
            authorized_imports: Optional iterable of extra module names to
                allow, merged with ``BASE_BUILTIN_MODULES``. ``None`` means
                only the base modules are allowed.
            **kwargs: Forwarded unchanged to ``Tool.__init__``.
        """
        if authorized_imports is None:
            self.authorized_imports = list(set(BASE_BUILTIN_MODULES))
        else:
            self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(authorized_imports))
        # Instance-level ``inputs`` shadows the class-level one so that the
        # description can list the imports this instance really allows.
        # Interpolate the merged list, not the raw argument: the raw argument
        # may be None and never includes the base modules.
        self.inputs = {
            "code": {
                "type": "string",
                "description": (
                    "The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, "
                    f"else you will get an error. This code can only import the following python libraries: {self.authorized_imports}."
                ),
            }
        }
        self.base_python_tools = BASE_PYTHON_TOOLS
        self.python_evaluator = evaluate_python_code
        super().__init__(*args, **kwargs)

    def forward(self, code: str) -> str:
        """Evaluate ``code`` and report its printed output and result.

        Args:
            code: Python source to evaluate.

        Returns:
            A string of the form ``"Stdout:\\n<prints>\\nOutput: <result>"``.
        """
        # A fresh state per call: nothing carries over between invocations.
        state = {}
        output = str(
            self.python_evaluator(
                code,
                state=state,
                static_tools=self.base_python_tools,
                authorized_imports=self.authorized_imports,
            )[0]  # The second element is boolean is_final_answer
        )
        # The evaluator is expected to have stored captured prints in the
        # state dict under "_print_outputs".
        return f"Stdout:\n{str(state['_print_outputs'])}\nOutput: {output}"
class FinalAnswerTool(Tool):
    """Tool whose ``forward`` hands back the supplied answer unchanged."""

    name = "final_answer"
    description = "Provides a final answer to the given problem."
    inputs = {
        "answer": {
            "type": "any",
            "description": "The final answer to the problem",
        }
    }
    output_type = "any"

    def forward(self, answer: Any) -> Any:
        """Return ``answer`` exactly as received."""
        return answer
class UserInputTool(Tool):
    """Tool that prompts on standard input and returns what was typed."""

    name = "user_input"
    description = "Asks for user's input on a specific question"
    inputs = {
        "question": {
            "type": "string",
            "description": "The question to ask the user",
        }
    }
    output_type = "string"

    def forward(self, question):
        """Show ``question`` as a prompt and return the line read by ``input``."""
        prompt = f"{question} => Type your answer here:"
        return input(prompt)
class DuckDuckGoSearchTool(Tool):
    """Web-search tool backed by the ``duckduckgo_search`` package."""

    name = "web_search"
    description = """Performs a duckduckgo web search based on your query (think a Google search) then returns the top search results."""
    inputs = {
        "query": {
            "type": "string",
            "description": "The search query to perform.",
        }
    }
    output_type = "string"

    def __init__(self, max_results=10, **kwargs):
        """Create the search client.

        Args:
            max_results: Value passed as ``max_results`` on every search.
            **kwargs: Forwarded to the ``DDGS`` constructor.

        Raises:
            ImportError: If ``duckduckgo_search`` is not installed.
        """
        super().__init__()
        self.max_results = max_results
        # Imported lazily so the module loads without the optional dependency.
        try:
            from duckduckgo_search import DDGS
        except ImportError as e:
            raise ImportError(
                "You must install package `duckduckgo_search` to run this tool: for instance run `pip install duckduckgo-search`."
            ) from e
        self.ddgs = DDGS(**kwargs)

    def forward(self, query: str) -> str:
        """Search for ``query`` and render the hits as a markdown list.

        Raises:
            Exception: If the search returns no results.
        """
        hits = self.ddgs.text(query, max_results=self.max_results)
        if len(hits) == 0:
            raise Exception("No results found! Try a less restrictive/shorter query.")
        # One "[title](href)\nbody" entry per hit, separated by blank lines.
        rendered = "\n\n".join(f"[{hit['title']}]({hit['href']})\n{hit['body']}" for hit in hits)
        return "## Search Results\n\n" + rendered
class GoogleSearchTool(Tool):
    """Google web-search tool that queries either SerpApi or Serper."""

    name = "web_search"
    description = """Performs a google web search for your query then returns a string of the top search results."""
    inputs = {
        "query": {"type": "string", "description": "The search query to perform."},
        "filter_year": {
            "type": "integer",
            "description": "Optionally restrict results to a certain year",
            "nullable": True,
        },
    }
    output_type = "string"

    def __init__(self, provider: str = "serpapi"):
        """Select the provider and read its API key from the environment.

        Args:
            provider: ``"serpapi"`` selects SerpApi; any other value selects
                the Serper endpoint.

        Raises:
            ValueError: If the provider's API-key environment variable
                (``SERPAPI_API_KEY`` or ``SERPER_API_KEY``) is not set.
        """
        super().__init__()
        import os

        self.provider = provider
        # Each provider stores the result list under a different JSON key.
        if provider == "serpapi":
            self.organic_key = "organic_results"
            api_key_env_name = "SERPAPI_API_KEY"
        else:
            self.organic_key = "organic"
            api_key_env_name = "SERPER_API_KEY"
        self.api_key = os.getenv(api_key_env_name)
        if self.api_key is None:
            raise ValueError(f"Missing API key. Make sure you have '{api_key_env_name}' in your env variables.")

    def forward(self, query: str, filter_year: Optional[int] = None) -> str:
        """Run the search and format the organic results as markdown.

        Args:
            query: The search query.
            filter_year: If given, restrict results to that calendar year.

        Returns:
            A markdown string of numbered results, or a "No results found"
            message when the provider returns an empty result list.

        Raises:
            ValueError: If the provider answers with a non-200 status; the
                payload is the decoded JSON body, or the raw text when the
                body is not JSON.
            Exception: If the response has no organic-results key at all.
        """
        import requests

        if self.provider == "serpapi":
            params = {
                "q": query,
                "api_key": self.api_key,
                "engine": "google",
                "google_domain": "google.com",
            }
            base_url = "https://serpapi.com/search.json"
        else:
            params = {
                "q": query,
                "api_key": self.api_key,
            }
            base_url = "https://google.serper.dev/search"
        if filter_year is not None:
            # Custom date range covering the whole requested year.
            params["tbs"] = f"cdr:1,cd_min:01/01/{filter_year},cd_max:12/31/{filter_year}"

        # Bound the request so an unresponsive provider cannot hang the caller.
        response = requests.get(base_url, params=params, timeout=20)
        if response.status_code != 200:
            # Error bodies are not guaranteed to be JSON; fall back to the raw
            # text so the real failure is reported, not a decode error.
            try:
                error_payload = response.json()
            except ValueError:
                error_payload = response.text
            raise ValueError(error_payload)
        results = response.json()

        if self.organic_key not in results:
            if filter_year is not None:
                raise Exception(
                    f"No results found for query: '{query}' with filtering on year={filter_year}. Use a less restrictive query or do not filter on year."
                )
            raise Exception(f"No results found for query: '{query}'. Use a less restrictive query.")

        organic_results = results[self.organic_key]
        if not organic_results:
            year_filter_message = f" with filter year={filter_year}" if filter_year is not None else ""
            return f"No results found for '{query}'{year_filter_message}. Try with a more general query, or remove the year filter."

        web_snippets = []
        for idx, page in enumerate(organic_results):
            # Optional fields contribute nothing when the provider omits them.
            date_published = ""
            if "date" in page:
                date_published = "\nDate published: " + page["date"]

            source = ""
            if "source" in page:
                source = "\nSource: " + page["source"]

            snippet = ""
            if "snippet" in page:
                snippet = "\n" + page["snippet"]

            redacted_version = f"{idx}. [{page['title']}]({page['link']}){date_published}{source}\n{snippet}"
            web_snippets.append(redacted_version)

        return "## Search Results\n" + "\n\n".join(web_snippets)
class VisitWebpageTool(Tool):
    """Tool that downloads a page and returns its content as markdown."""

    name = "visit_webpage"
    description = (
        "Visits a webpage at the given url and reads its content as a markdown string. Use this to browse webpages."
    )
    inputs = {
        "url": {
            "type": "string",
            "description": "The url of the webpage to visit.",
        }
    }
    output_type = "string"

    def forward(self, url: str) -> str:
        """Fetch ``url`` and return its markdown rendering, truncated.

        Fetch and conversion failures are returned as an error-message
        string, not raised.

        Raises:
            ImportError: If one of the lazily imported packages is missing.
        """
        # Optional dependencies are imported at call time.
        try:
            import requests
            from markdownify import markdownify
            from requests.exceptions import RequestException

            from smolagents.utils import truncate_content
        except ImportError as e:
            raise ImportError(
                "You must install packages `markdownify` and `requests` to run this tool: for instance run `pip install markdownify requests`."
            ) from e

        try:
            # GET with a 20-second timeout; non-2xx statuses raise below.
            page = requests.get(url, timeout=20)
            page.raise_for_status()

            # HTML -> markdown, then squeeze runs of 3+ newlines down to 2.
            as_markdown = markdownify(page.text).strip()
            as_markdown = re.sub(r"\n{3,}", "\n\n", as_markdown)

            return truncate_content(as_markdown, 10000)
        except requests.exceptions.Timeout:
            return "The request timed out. Please try again later or check the URL."
        except RequestException as e:
            return f"Error fetching the webpage: {str(e)}"
        except Exception as e:
            return f"An unexpected error occurred: {str(e)}"
class SpeechToTextTool(PipelineTool):
    """Pipeline tool that transcribes audio with a Whisper checkpoint."""

    default_checkpoint = "openai/whisper-large-v3-turbo"
    description = "This is a tool that transcribes an audio into text. It returns the transcribed text."
    name = "transcriber"
    inputs = {
        "audio": {
            "type": "audio",
            "description": "The audio to transcribe. Can be a local path, an url, or a tensor.",
        }
    }
    output_type = "string"

    def __new__(cls, *args, **kwargs):
        """Bind the Whisper processor/model classes on the class, then build the instance."""
        # Imported here so that ``transformers`` is only needed once the tool
        # is actually instantiated.
        from transformers.models.whisper import WhisperForConditionalGeneration, WhisperProcessor

        cls.pre_processor_class = WhisperProcessor
        cls.model_class = WhisperForConditionalGeneration
        return super().__new__(cls, *args, **kwargs)

    def encode(self, audio):
        """Convert ``audio`` to its raw form and run the pre-processor on it."""
        from .agent_types import AgentAudio

        raw_audio = AgentAudio(audio).to_raw()
        return self.pre_processor(raw_audio, return_tensors="pt")

    def forward(self, inputs):
        """Generate output token ids from ``inputs["input_features"]``."""
        features = inputs["input_features"]
        return self.model.generate(features)

    def decode(self, outputs):
        """Decode ``outputs`` to text and return the first transcription."""
        transcriptions = self.pre_processor.batch_decode(outputs, skip_special_tokens=True)
        return transcriptions[0]
# Lookup table from a tool's ``name`` attribute to its class, for the
# subset of tools listed here.
TOOL_MAPPING = {
    tool.name: tool
    for tool in (
        PythonInterpreterTool,
        DuckDuckGoSearchTool,
        VisitWebpageTool,
    )
}

# Public API of this module.
__all__ = [
    "PythonInterpreterTool",
    "FinalAnswerTool",
    "UserInputTool",
    "DuckDuckGoSearchTool",
    "GoogleSearchTool",
    "VisitWebpageTool",
    "SpeechToTextTool",
]