import os
import tempfile
from typing import Dict, List, Optional

from bs4 import BeautifulSoup
import yt_dlp
import pandas as pd
import requests
import torch
from langchain_community.document_loaders import YoutubeLoader
from langchain_community.retrievers import BM25Retriever
from langchain_community.tools import BearlyInterpreterTool
from langchain.docstore.document import Document
from smolagents import (
    DuckDuckGoSearchTool,
    SpeechToTextTool,
    Tool,
    VisitWebpageTool,
    WikipediaSearchTool,
)
from transformers import AutoProcessor, AutoModelForImageTextToText


class RelevantInfoRetrieverTool(Tool):
    name = "relevant_info_retriever"
    description = "Retrieves information relevant to the query from the provided documents."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query for which to retrieve information.",
        },
        "docs": {
            # The forward() signature takes a list of langchain Documents,
            # so declare "object" rather than "string".
            "type": "object",
            "description": "The source documents (a list of langchain Documents) from which relevant information will be retrieved.",
        },
    }
    output_type = "string"

    def forward(self, query: str, docs: List[Document]):
        # BM25 is a purely lexical ranker, so no embedding model is needed.
        self.retriever = BM25Retriever.from_documents(docs)
        results = self.retriever.invoke(query)
        if results:
            return "\n\n".join(doc.page_content for doc in results)
        else:
            return "No relevant information found."


class YoutubeTranscriptTool(Tool):
    name = "youtube_transcript"
    description = "Fetches a YouTube video's transcript."
    inputs = {
        "youtube_url": {
            "type": "string",
            "description": "The youtube video url",
        },
        "source_langs": {
            "type": "array",
            "description": "A list of language codes in descending priority for the video transcript.",
            "items": {"type": "string"},
            "default": ["en"],
            "required": False,
            "nullable": True,
        },
        "target_lang": {
            "type": "string",
            "description": "The language to which the transcript will be translated.",
            "default": "en",
            "required": False,
            "nullable": True,
        },
    }
    # On success this tool returns langchain Documents (consumed by
    # relevant_info_retriever), not plain text, hence "object".
    output_type = "object"

    def forward(
        self,
        youtube_url: str,
        source_langs: Optional[List[str]] = None,
        target_lang: Optional[str] = "en",
    ):
        # Avoid a mutable default argument; fall back to English.
        source_langs = source_langs or ["en"]
        try:
            loader = YoutubeLoader.from_youtube_url(
                youtube_url,
                add_video_info=True,
                language=source_langs,
                translation=target_lang,
                # transcript_format=TranscriptFormat.CHUNKS,
                # chunk_size_seconds=30,
            )
            transcript_docs = loader.load()
            return transcript_docs
        except Exception as e:
            return f"Error fetching video's transcript: {e}"


class ReverseStringTool(Tool):
    name = "reverse_string"
    description = "Reverses the input string."
    inputs = {
        "string": {
            "type": "string",
            "description": "The string that needs to be reversed.",
        }
    }
    output_type = "string"

    def forward(self, string: str):
        try:
            # Idiomatic full-slice reversal.
            return string[::-1]
        except Exception as e:
            return f"Error reversing string: {e}"


class SmolVLM2:
    """The parent class for visual analyzer tools (using the SmolVLM2-500M-Video model)"""

    def __init__(self):
        """Initializations for the analyzer tool"""
        model_path = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct"
        device = "cpu"  # "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            # _attn_implementation="flash_attention_2",
        ).to(device)


class ImagesAnalyzerTool(Tool, SmolVLM2):
    name = "image_analyzer"
    description = "Analyzes each input image according to the query"
    inputs = {
        "query": {
            "type": "string",
            "description": "The query according to which the image will be analyzed.",
        },
        "images_urls": {
            "type": "array",
            "description": "A list of strings containing the images' urls",
            "items": {"type": "string"},
        },
    }
    output_type = "string"

    def __init__(self):
        Tool.__init__(self)
        SmolVLM2.__init__(self)

    def forward(self, query: str, images_urls: List[str]):
        try:
            # One image message entity per image url
            image_message_ents = [{"type": "image", "url": iu} for iu in images_urls]
            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": query,
                        },
                    ]
                    + image_message_ents,
                },
            ]
            inputs = self.processor.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            ).to(self.model.device, dtype=torch.bfloat16)
            generated_ids = self.model.generate(
                **inputs, do_sample=False, max_new_tokens=64
            )
            generated_texts = self.processor.batch_decode(
                generated_ids,
                skip_special_tokens=True,
            )
            return generated_texts[0]
        except Exception as e:
            return f"Error analyzing image(s): {e}"


class VideoAnalyzerTool(Tool, SmolVLM2):
    name = "video_analyzer"
    description = "Analyzes video at a specified path according to the query"
    inputs = {
        "query": {
            "type": "string",
            "description": "The query according to which the video will be analyzed.",
        },
        "video_path": {
            "type": "string",
            "description": "A string containing the video path",
        },
    }
    output_type = "string"

    def __init__(self):
        Tool.__init__(self)
        SmolVLM2.__init__(self)

    def forward(self, query: str, video_path: str) -> str:
        try:
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "video", "path": video_path},
                        {"type": "text", "text": query},
                    ],
                },
            ]
            inputs = self.processor.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            ).to(self.model.device, dtype=torch.bfloat16)
            generated_ids = self.model.generate(
                **inputs, do_sample=False, max_new_tokens=64
            )
            generated_texts = self.processor.batch_decode(
                generated_ids,
                skip_special_tokens=True,
            )
            return generated_texts[0]
        except Exception as e:
            return f"Error analyzing video: {e}"
        finally:
            # Cleanup: the (downloaded) video file is no longer needed once analyzed
            if video_path and os.path.exists(video_path):
                os.remove(video_path)


class FileDownloaderTool(Tool):
    name = "file_downloader"
    description = "Downloads a file and returns the path of the temporarily saved file"
    inputs = {
        "file_url": {
            "type": "string",
            "description": "The url from which the file shall be downloaded.",
        },
    }
    output_type = "string"

    def forward(self, file_url: str) -> str:
        response = requests.get(file_url, stream=True)
        response.raise_for_status()
        # Best-effort filename extraction from the Content-Disposition header,
        # e.g. 'attachment; filename="report.xlsx"'
        original_filename = (
            response.headers.get("content-disposition", "")
            .split("=", -1)[-1]
            .strip('"')
        )
        # Even if original_filename is empty or there is no extension, ext will be ""
        ext = os.path.splitext(original_filename)[-1]
        with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp_file:
            for chunk in response.iter_content(chunk_size=8192):
                tmp_file.write(chunk)
        return tmp_file.name
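

# Example usage (illustrative sketch; the URL is a placeholder):
#
#   downloader = FileDownloaderTool()
#   local_path = downloader.forward("https://example.com/data.xlsx")
#   # local_path points at a temp file whose suffix matches the original, if known.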


class YoutubeVideoDownloaderTool(Tool):
    name = "youtube_video_downloader"
    description = "Downloads the video from the specified url and returns the path where the video was saved"
    inputs = {
        "video_url": {
            "type": "string",
            "description": "A string containing the video url",
        },
    }
    output_type = "string"

    def forward(self, video_url: str) -> str:
        try:
            saved_video_path = ""
            temp_dir = tempfile.gettempdir()
            ydl_opts = {
                "outtmpl": f"{temp_dir}/%(title)s.%(ext)s",  # Absolute or relative path
                "quiet": True,
            }
            # Download the youtube video as a file in the tmp directory
            with yt_dlp.YoutubeDL(ydl_opts) as ydl:
                info = ydl.extract_info(video_url, download=True)
                saved_video_path = ydl.prepare_filename(info)
            return saved_video_path
        except Exception as e:
            return f"Error downloading video: {e}"


class LoadXlsxFileTool(Tool):
    name = "load_xlsx_file"
    description = "Loads an xlsx file into a pandas DataFrame and returns it"
    inputs = {"file_path": {"type": "string", "description": "File path"}}
    output_type = "object"

    def forward(self, file_path: str) -> object:
        return pd.read_excel(file_path)


class LoadTextFileTool(Tool):
    name = "load_text_file"
    description = "Loads a text file and returns its contents"
    inputs = {"file_path": {"type": "string", "description": "File path"}}
    output_type = "string"

    def forward(self, file_path: str) -> str:
        with open(file_path, "r", encoding="utf-8") as file:
            return file.read()


class WebpageTablesContextRetrieverTool(Tool):
    name = "webpage_tables_context_retriever"
    description = """Retrieves structural context for all tables on a webpage.
    Returns table indexes with captions, headers, and surrounding text to help identify relevant tables.
    Use this first to determine which table index to extract."""
    inputs = {
        "url": {"type": "string", "description": "The URL of the webpage to analyze"}
    }
    output_type = "object"

    def forward(self, url: str) -> Dict:
        """Retrieve context information for all tables on the page"""
        try:
            response = requests.get(url, timeout=15)
            response.raise_for_status()
            soup = BeautifulSoup(response.text, "html.parser")
            tables = soup.find_all("table")
            if not tables:
                return {
                    "status": "success",
                    "tables": [],
                    "message": "No tables found on page",
                    "url": url,
                }
            results = []
            for i, table in enumerate(tables):
                context = {
                    "index": i,
                    "id": table.get("id", ""),
                    "class": " ".join(table.get("class", [])),
                    "summary": table.get("summary", ""),
                    "caption": self._get_table_caption(table),
                    "preceding_header": self._get_preceding_header(table),
                    "surrounding_text": self._get_surrounding_text(table),
                }
                results.append(context)
            return {
                "status": "success",
                "tables": results,
                "url": url,
                "message": f"Found {len(results)} tables with context information",
                "suggestion": "Use html_table_extractor with the most relevant index",
            }
        except Exception as e:
            return {
                "status": "error",
                "url": url,
                "message": f"Failed to retrieve table contexts: {str(e)}",
            }

    def _get_table_caption(self, table) -> str:
        """Extract table caption text if available"""
        caption = table.find("caption")
        return caption.get_text(strip=True) if caption else ""

    def _get_preceding_header(self, table) -> str:
        """Find the nearest preceding heading"""
        heading = table.find_previous(["h1", "h2", "h3", "h4", "h5", "h6"])
        return heading.get_text(strip=True) if heading else ""

    def _get_surrounding_text(self, table, chars=150) -> str:
        """Get relevant text around the table"""
        prev_text = " ".join(
            t.strip()
            for t in table.find_all_previous(string=True, limit=3)
            if t.strip()
        )
        next_text = " ".join(
            t.strip() for t in table.find_all_next(string=True, limit=3) if t.strip()
        )
        return f"...{prev_text[-chars:]} [TABLE] {next_text[:chars]}..."


class HtmlTableExtractorTool(Tool):
    name = "html_table_extractor"
    description = """Extracts a specific HTML table as structured data.
    Use after webpage_tables_context_retriever to get the correct table index."""
    inputs = {
        "page_url": {
            "type": "string",
            "description": "The webpage URL containing the table",
        },
        "table_index": {
            "type": "integer",
            "description": "0-based index of the table to extract (from webpage_tables_context_retriever)",
        },
    }
    output_type = "object"

    def forward(self, page_url: str, table_index: int) -> Dict:
        """Extract a specific table by index"""
        try:
            # First verify the URL is accessible
            test_request = requests.head(page_url, timeout=5)
            test_request.raise_for_status()
            # Read all tables
            tables = pd.read_html(page_url)
            if not tables:
                return {
                    "status": "error",
                    "message": "No tables found at URL",
                    "url": page_url,
                }
            # Validate index
            if table_index < 0 or table_index >= len(tables):
                return {
                    "status": "error",
                    "message": f"Invalid table index {table_index}. Page has {len(tables)} tables.",
                    "url": page_url,
                    "available_indexes": list(range(len(tables))),
                }
            # Convert the DataFrame to a JSON-serializable list of row dicts
            df = tables[table_index]
            return {
                "status": "success",
                "table_index": table_index,
                "table_data": df.to_dict(orient="records"),
                "url": page_url,
            }
        except Exception as e:
            return {
                "status": "error",
                "message": f"Table extraction failed: {str(e)}",
                "url": page_url,
                "table_index": table_index,
            }
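

# How these tools might be wired into an agent (a minimal sketch, assuming the
# smolagents CodeAgent API; InferenceClientModel is named HfApiModel in older
# smolagents releases, so adjust to your installed version):
#
#   from smolagents import CodeAgent, InferenceClientModel
#
#   agent = CodeAgent(
#       model=InferenceClientModel(),
#       tools=[
#           DuckDuckGoSearchTool(),
#           VisitWebpageTool(),
#           WikipediaSearchTool(),
#           SpeechToTextTool(),
#           RelevantInfoRetrieverTool(),
#           YoutubeTranscriptTool(),
#           ReverseStringTool(),
#           ImagesAnalyzerTool(),
#           VideoAnalyzerTool(),
#           FileDownloaderTool(),
#           YoutubeVideoDownloaderTool(),
#           LoadXlsxFileTool(),
#           LoadTextFileTool(),
#           WebpageTablesContextRetrieverTool(),
#           HtmlTableExtractorTool(),
#       ],
#   )
#   answer = agent.run("What does the table on <url> say about <topic>?")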