import os
import tempfile
from typing import Dict, List, Optional
from bs4 import BeautifulSoup
import yt_dlp
import pandas as pd
import requests
import torch
from langchain_community.document_loaders import YoutubeLoader
from langchain_community.retrievers import BM25Retriever
from langchain_community.tools import BearlyInterpreterTool
from langchain.docstore.document import Document
from smolagents import (
DuckDuckGoSearchTool,
SpeechToTextTool,
Tool,
VisitWebpageTool,
WikipediaSearchTool,
)
from transformers import AutoProcessor, AutoModelForImageTextToText
class RelevantInfoRetrieverTool(Tool):
name = "relevant_info_retriever"
description = "Retrieves relevant to the query information."
inputs = {
"query": {
"type": "string",
"description": "The query for which to retrieve information.",
},
"docs": {
"type": "string",
"description": "The source documents from which to choose in order to retrieve relevant information",
},
}
output_type = "string"
def forward(self, query: str, docs: List[Document]):
        retriever = BM25Retriever.from_documents(docs)
        results = retriever.invoke(query)
if results:
return "\n\n".join([doc.page_content for doc in results])
else:
return "No relevant information found."
class YoutubeTranscriptTool(Tool):
name = "youtube_transcript"
description = "Fetches youtube video's transcript."
inputs = {
"youtube_url": {
"type": "string",
"description": "The youtube video url",
},
"source_langs": {
"type": "array",
"description": "A list of language codes in a descending priority for the video trascript.",
"items": {"type": "string"},
"default": ["en"],
"required": False,
"nullable": True,
},
"target_lang": {
"type": "string",
"description": "The language to which the transcript will be translated.",
"default": "en",
"required": False,
"nullable": True,
},
}
output_type = "string"
    def forward(
        self,
        youtube_url: str,
        source_langs: Optional[List[str]] = None,
        target_lang: Optional[str] = "en",
    ):
        # Avoid a mutable default argument; fall back to English.
        if source_langs is None:
            source_langs = ["en"]
        try:
            loader = YoutubeLoader.from_youtube_url(
                youtube_url,
                add_video_info=True,
                language=source_langs,
                translation=target_lang,
                # transcript_format=TranscriptFormat.CHUNKS,
                # chunk_size_seconds=30,
            )
            transcript_docs = loader.load()
            # output_type is "string", so join the loaded documents' text.
            return "\n\n".join(doc.page_content for doc in transcript_docs)
        except Exception as e:
            return f"Error fetching video's transcript: {e}"
class ReverseStringTool(Tool):
name = "reverse_string"
description = "Reverses the input string."
inputs = {
"string": {
"type": "string",
"description": "The string that needs to be reversed.",
}
}
output_type = "string"
def forward(self, string: str):
try:
            return string[::-1]
except Exception as e:
return f"Error reversing string: {e}"
class SmolVLM2:
"""The parent class for visual analyzer tools (using SmolVLM2-500M-Video model)"""
def __init__(self):
"""Initializations for the analyzer tool"""
model_path = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct"
device = "cpu" # "cuda" if torch.cuda.is_available() else "cpu"
self.processor = AutoProcessor.from_pretrained(model_path)
self.model = AutoModelForImageTextToText.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
# _attn_implementation="flash_attention_2",
).to(device)
class ImagesAnalyzerTool(Tool, SmolVLM2):
name = "image_analyzer"
description = "Analyzes each input image according to the query"
inputs = {
"query": {
"type": "string",
"description": "The query according to which the image will be analyzed.",
},
"images_urls": {
"type": "array",
"description": "A list of strings containing the images' urls",
"items": {"type": "string"},
},
}
output_type = "string"
def __init__(self):
Tool.__init__(self)
SmolVLM2.__init__(self)
def forward(self, query: str, images_urls: List[str]):
try:
# Image message entities for the different images' urls
image_message_ents = [{"type": "image", "url": iu} for iu in images_urls]
messages = [
{
"role": "user",
"content": [
{
"type": "text",
"text": query,
},
]
+ image_message_ents,
},
]
inputs = self.processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(self.model.device, dtype=torch.bfloat16)
generated_ids = self.model.generate(
**inputs, do_sample=False, max_new_tokens=64
)
generated_texts = self.processor.batch_decode(
generated_ids,
skip_special_tokens=True,
)
return generated_texts[0]
except Exception as e:
return f"Error analyzing image(s): {e}"
class VideoAnalyzerTool(Tool, SmolVLM2):
name = "video_analyzer"
description = "Analyzes video at a specified path according to the query"
inputs = {
"query": {
"type": "string",
"description": "The query according to which the video will be analyzed.",
},
"video_path": {
"type": "string",
"description": "A string containing the video path",
},
}
output_type = "string"
def __init__(self):
Tool.__init__(self)
SmolVLM2.__init__(self)
def forward(self, query: str, video_path: str) -> str:
try:
messages = [
{
"role": "user",
"content": [
{"type": "video", "path": video_path},
{"type": "text", "text": query},
],
},
]
inputs = self.processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(self.model.device, dtype=torch.bfloat16)
generated_ids = self.model.generate(
**inputs, do_sample=False, max_new_tokens=64
)
generated_texts = self.processor.batch_decode(
generated_ids,
skip_special_tokens=True,
)
return generated_texts[0]
except Exception as e:
return f"Error analyzing video: {e}"
finally:
# Cleanup if needed
if video_path and os.path.exists(video_path):
os.remove(video_path)
class FileDownloaderTool(Tool):
name = "file_downloader"
description = "Downloads a file returning the name of the temporarily saved file"
inputs = {
"file_url": {
"type": "string",
"description": "The url from which the file shall be downloaded.",
},
}
output_type = "string"
def forward(self, file_url: str) -> str:
response = requests.get(file_url, stream=True)
response.raise_for_status()
        # Best-effort filename from the Content-Disposition header,
        # e.g. 'attachment; filename="report.xlsx"' -> 'report.xlsx'.
        original_filename = (
            response.headers.get("content-disposition", "")
            .split("=")[-1]
            .strip('"')
        )
# Even if original_filename is empty or there is no extension, ext will be ""
ext = os.path.splitext(original_filename)[-1]
with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp_file:
for chunk in response.iter_content(chunk_size=8192):
tmp_file.write(chunk)
return tmp_file.name
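# Illustrative sketch: download a file to a temp path, then clean up. The
# URL is a placeholder; the caller owns cleanup since delete=False is used.
def _demo_file_downloader():
    path = FileDownloaderTool().forward("https://example.com/data.xlsx")
    print(f"Saved to {path}")
    os.remove(path)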
class YoutubeVideoDownloaderTool(Tool):
name = "youtube_video_downloader"
description = "Downloads the video from the specified url and returns the path where the video was saved"
inputs = {
"video_url": {
"type": "string",
"description": "A string containing the video url",
},
}
output_type = "string"
def forward(self, video_url: str) -> str:
try:
saved_video_path = ""
temp_dir = tempfile.gettempdir()
ydl_opts = {
"outtmpl": f"{temp_dir}/%(title)s.%(ext)s", # Absolute or relative path
"quiet": True,
}
# Download youtube video as a file in tmp directory
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
info = ydl.extract_info(video_url, download=True)
saved_video_path = ydl.prepare_filename(info)
return saved_video_path
except Exception as e:
return f"Error downloading video: {e}"
class LoadXlsxFileTool(Tool):
name = "load_xlsx_file"
description = "This tool loads xlsx file into pandas and returns it"
inputs = {"file_path": {"type": "string", "description": "File path"}}
output_type = "object"
def forward(self, file_path: str) -> object:
return pd.read_excel(file_path)
class LoadTextFileTool(Tool):
name = "load_text_file"
description = "This tool loads any text file"
inputs = {"file_path": {"type": "string", "description": "File path"}}
output_type = "string"
def forward(self, file_path: str) -> str:
with open(file_path, "r", encoding="utf-8") as file:
return file.read()
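# Illustrative sketch: round-trip a small temp file through LoadTextFileTool.
def _demo_load_text_file():
    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False, encoding="utf-8") as f:
        f.write("hello, world")
        tmp_path = f.name
    print(LoadTextFileTool().forward(tmp_path))
    os.remove(tmp_path)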
class WebpageTablesContextRetrieverTool(Tool):
name = "webpage_tables_context_retriever"
description = """Retrieves structural context for all tables on a webpage.
Returns table indexes with captions, headers, and surrounding text to help identify relevant tables.
Use this first to determine which table index to extract."""
inputs = {
"url": {"type": "string", "description": "The URL of the webpage to analyze"}
}
output_type = "object"
def forward(self, url: str) -> Dict:
"""Retrieve context information for all tables on the page"""
try:
response = requests.get(url, timeout=15)
response.raise_for_status()
soup = BeautifulSoup(response.text, "html.parser")
tables = soup.find_all("table")
if not tables:
return {
"status": "success",
"tables": [],
"message": "No tables found on page",
"url": url,
}
results = []
for i, table in enumerate(tables):
context = {
"index": i,
"id": table.get("id", ""),
"class": " ".join(table.get("class", [])),
"summary": table.get("summary", ""),
"caption": self._get_table_caption(table),
"preceding_header": self._get_preceding_header(table),
"surrounding_text": self._get_surrounding_text(table),
}
results.append(context)
return {
"status": "success",
"tables": results,
"url": url,
"message": f"Found {len(results)} tables with context information",
"suggestion": "Use html_table_extractor with the most relevant index",
}
except Exception as e:
return {
"status": "error",
"url": url,
"message": f"Failed to retrieve table contexts: {str(e)}",
}
def _get_table_caption(self, table) -> str:
"""Extract table caption text if available"""
caption = table.find("caption")
return caption.get_text(strip=True) if caption else ""
    def _get_preceding_header(self, table) -> str:
        """Find the nearest preceding heading"""
        heading = table.find_previous(["h1", "h2", "h3", "h4", "h5", "h6"])
        return heading.get_text(strip=True) if heading else ""
def _get_surrounding_text(self, table, chars=150) -> str:
"""Get relevant text around the table"""
prev_text = " ".join(
t.strip()
for t in table.find_all_previous(string=True, limit=3)
if t.strip()
)
next_text = " ".join(
t.strip() for t in table.find_all_next(string=True, limit=3) if t.strip()
)
return f"...{prev_text[-chars:]} [TABLE] {next_text[:chars]}..."
class HtmlTableExtractorTool(Tool):
name = "html_table_extractor"
description = """Extracts a specific HTML table as structured data.
Use after webpage_tables_context_retriever to get the correct table index."""
inputs = {
"page_url": {
"type": "string",
"description": "The webpage URL containing the table",
},
"table_index": {
"type": "integer",
"description": "0-based index of the table to extract (from webpage_tables_context_retriever)",
},
}
output_type = "object"
def forward(self, page_url: str, table_index: int) -> Dict:
"""Extract a specific table by index"""
try:
# First verify the URL is accessible
test_request = requests.head(page_url, timeout=5)
test_request.raise_for_status()
# Read all tables
tables = pd.read_html(page_url)
if not tables:
return {
"status": "error",
"message": "No tables found at URL",
"url": page_url,
}
# Validate index
if table_index < 0 or table_index >= len(tables):
return {
"status": "error",
"message": f"Invalid table index {table_index}. Page has {len(tables)} tables.",
"url": page_url,
"available_indexes": list(range(len(tables))),
}
            # Convert the DataFrame to a JSON-serializable format
            df = tables[table_index]
            return {
                "status": "success",
                "table_index": table_index,
                "table_data": df.to_dict(orient="records"),
"url": page_url,
}
except Exception as e:
return {
"status": "error",
"message": f"Table extraction failed: {str(e)}",
"url": page_url,
"table_index": table_index,
}
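# Illustrative sketch of the intended two-step flow: inspect table contexts
# first, then extract one table by index (index 0 is arbitrary here).
def _demo_table_extraction():
    url = "https://en.wikipedia.org/wiki/List_of_countries_by_population_(United_Nations)"
    print(WebpageTablesContextRetrieverTool().forward(url).get("message"))
    result = HtmlTableExtractorTool().forward(page_url=url, table_index=0)
    print(result.get("status"), str(result.get("table_data"))[:200])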