#!/usr/bin/env python3
"""
LLM-based Question Classifier for Multi-Agent GAIA Solver
Routes questions to appropriate specialist agents based on content analysis
"""

import os
import json
import re
from typing import Dict, List, Optional, Tuple
from enum import Enum
from dotenv import load_dotenv

# Load environment variables
load_dotenv()
# Import LLM (using same setup as main solver) | |
try: | |
from smolagents import InferenceClientModel | |
except ImportError: | |
# Fallback for newer smolagents versions | |
try: | |
from smolagents.models import InferenceClientModel | |
except ImportError: | |
# If all imports fail, we'll handle this in the class | |
InferenceClientModel = None | |


class AgentType(Enum):
    """Available specialist agent types"""
    MULTIMEDIA = "multimedia"            # Video, audio, image analysis
    RESEARCH = "research"                # Web search, Wikipedia, academic papers
    LOGIC_MATH = "logic_math"            # Puzzles, calculations, pattern recognition
    FILE_PROCESSING = "file_processing"  # Excel, Python code, document analysis
    GENERAL = "general"                  # Fallback for unclear cases


# Regular expression patterns for better content type detection
YOUTUBE_URL_PATTERN = r'(https?://)?(www\.)?(youtube\.com|youtu\.?be)/.+?(?=\s|$)'

# Enhanced YouTube URL pattern covering more variations (shortened links, video IDs, watch/embed/shorts/playlist/channel URLs)
ENHANCED_YOUTUBE_URL_PATTERN = r'(https?://)?(www\.)?(youtube\.com|youtu\.?be)/(?:watch\?v=|embed/|v/|shorts/|playlist\?list=|channel/|user/|[^/\s]+/?)?([^\s&?/]+)'

VIDEO_PATTERNS = [r'youtube\.(com|be)', r'video', r'watch\?v=']
AUDIO_PATTERNS = [r'\.mp3\b', r'\.wav\b', r'audio', r'sound', r'listen', r'music', r'podcast']
IMAGE_PATTERNS = [r'\.jpg\b', r'\.jpeg\b', r'\.png\b', r'\.gif\b', r'image', r'picture', r'photo']
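
# Illustrative (not exhaustive) examples of URLs the enhanced pattern is intended to match;
# <video-id> is a placeholder:
#   https://www.youtube.com/watch?v=L1vXCYZAYYM
#   https://youtu.be/L1vXCYZAYYM
#   youtube.com/shorts/<video-id>
#   https://www.youtube.com/embed/<video-id>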


class QuestionClassifier:
    """LLM-powered question classifier for agent routing"""

    def __init__(self):
        self.hf_token = os.getenv("HUGGINGFACE_TOKEN")
        if not self.hf_token:
            raise ValueError("HUGGINGFACE_TOKEN environment variable is required")

        # Initialize a lightweight model for classification
        if InferenceClientModel is not None:
            self.classifier_model = InferenceClientModel(
                model_id="Qwen/Qwen2.5-7B-Instruct",  # Smaller, faster model for classification
                token=self.hf_token
            )
        else:
            # Fallback: use a simple rule-based classifier
            self.classifier_model = None
            print("⚠️ Using fallback rule-based classification (InferenceClientModel not available)")

    def classify_question(self, question: str, file_name: str = "") -> Dict:
        """
        Classify a GAIA question and determine the best agent routing.

        Args:
            question: The question text
            file_name: Associated file name (if any)

        Returns:
            Dict with classification results and routing information
        """
        # First, check for a direct YouTube URL as a fast path (enhanced detection)
        if re.search(ENHANCED_YOUTUBE_URL_PATTERN, question):
            return self._create_youtube_video_classification(question, file_name)

        # Secondary check for YouTube keywords plus URL-like text
        question_lower = question.lower()
        if "youtube" in question_lower and any(term in question_lower for term in ["video", "watch", "channel"]):
            # Possible YouTube question; check more carefully
            if re.search(r'(youtube\.com|youtu\.be)', question):
                return self._create_youtube_video_classification(question, file_name)

        # Continue with regular classification
        # Create classification prompt
        classification_prompt = f"""
Analyze this GAIA benchmark question and classify it for routing to specialist agents.

Question: {question}
Associated file: {file_name if file_name else "None"}

Classify this question into ONE primary category and optionally secondary categories:

AGENT CATEGORIES:
1. MULTIMEDIA - Questions involving video analysis, audio transcription, image analysis
   Examples: YouTube videos, MP3 files, PNG images, visual content analysis

2. RESEARCH - Questions requiring web search, Wikipedia lookup, or factual data retrieval
   Examples: Factual lookups, biographical info, historical data, citations, sports statistics, company information, academic papers
   Note: If a question requires looking up data first (even for later calculations), classify as RESEARCH

3. LOGIC_MATH - Questions involving pure mathematical calculations or logical reasoning with given data
   Examples: Mathematical puzzles with provided numbers, algebraic equations, geometric calculations, logical deduction puzzles
   Note: Use this ONLY when all data is provided and no external lookup is needed

4. FILE_PROCESSING - Questions requiring file analysis (Excel, Python code, documents)
   Examples: Spreadsheet analysis, code execution, document parsing

5. GENERAL - Simple questions or unclear classification

ANALYSIS REQUIRED:
1. Primary agent type (required)
2. Secondary agent types (if question needs multiple specialists)
3. Complexity level (1-5, where 5 is most complex)
4. Tools needed (list specific tools that would be useful)
5. Reasoning (explain your classification choice)

Respond in JSON format:
{{
    "primary_agent": "AGENT_TYPE",
    "secondary_agents": ["AGENT_TYPE2", "AGENT_TYPE3"],
    "complexity": 3,
    "confidence": 0.95,
    "tools_needed": ["tool1", "tool2"],
    "reasoning": "explanation of classification",
    "requires_multimodal": false,
    "estimated_steps": 5
}}
"""

        try:
            # Get the classification from the LLM, or fall back
            if self.classifier_model is not None:
                messages = [{"role": "user", "content": classification_prompt}]
                response = self.classifier_model(messages)
            else:
                # Fallback to rule-based classification
                return self._fallback_classification(question, file_name)

            # Parse JSON response
            classification_text = response.content.strip()

            # Extract JSON if wrapped in code blocks
            if "```json" in classification_text:
                json_start = classification_text.find("```json") + 7
                json_end = classification_text.find("```", json_start)
                classification_text = classification_text[json_start:json_end].strip()
            elif "```" in classification_text:
                json_start = classification_text.find("```") + 3
                json_end = classification_text.find("```", json_start)
                classification_text = classification_text[json_start:json_end].strip()

            classification = json.loads(classification_text)

            # Validate and normalize the response
            return self._validate_classification(classification, question, file_name)

        except Exception as e:
            print(f"Classification error: {e}")
            # Fallback classification
            return self._fallback_classification(question, file_name)
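
    # Illustrative shape of the dict returned by classify_question() (example values only):
    #   {"primary_agent": "research", "secondary_agents": [], "complexity": 2,
    #    "confidence": 0.9, "tools_needed": ["research_with_comprehensive_fallback"],
    #    "reasoning": "...", "requires_multimodal": False, "estimated_steps": 3,
    #    "question_summary": "...", "has_file": False}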

    def _create_youtube_video_classification(self, question: str, file_name: str = "") -> Dict:
        """Create a specialized classification for YouTube video questions"""
        # Use the enhanced pattern for more robust URL detection
        youtube_url_match = re.search(ENHANCED_YOUTUBE_URL_PATTERN, question)
        if not youtube_url_match:
            # Fall back to the original pattern
            youtube_url_match = re.search(YOUTUBE_URL_PATTERN, question)

        # Extract the URL
        if youtube_url_match:
            youtube_url = youtube_url_match.group(0)
        else:
            # If we can't extract a URL but it looks like a YouTube question
            question_lower = question.lower()
            if "youtube" in question_lower:
                # Try to find any URL-like pattern
                url_match = re.search(r'https?://\S+', question)
                youtube_url = url_match.group(0) if url_match else "unknown_youtube_url"
            else:
                youtube_url = "unknown_youtube_url"

        # Determine complexity based on the question
        question_lower = question.lower()
        complexity = 3     # Default
        confidence = 0.98  # High default confidence for YouTube questions

        # Analyze the task more specifically
        if any(term in question_lower for term in ['count', 'how many', 'highest number']):
            complexity = 2  # Counting tasks
            task_type = "counting"
        elif any(term in question_lower for term in ['relationship', 'compare', 'difference']):
            complexity = 4  # Comparative analysis
            task_type = "comparison"
        elif any(term in question_lower for term in ['say', 'speech', 'dialogue', 'talk', 'speak']):
            complexity = 3  # Speech analysis
            task_type = "speech_analysis"
        elif any(term in question_lower for term in ['scene', 'visual', 'appear', 'shown']):
            complexity = 3  # Visual analysis
            task_type = "visual_analysis"
        else:
            task_type = "general_video_analysis"

        # Always use analyze_youtube_video as the primary tool; keeping it first in the
        # list gives it the highest priority even if other tools are suggested.
        tools_needed = ["analyze_youtube_video"]

        # Add secondary tools if the task might need them
        if "audio" in question_lower or any(term in question_lower for term in ['say', 'speech', 'dialogue']):
            tools_needed.append("analyze_audio_file")  # Add as fallback

        return {
            "primary_agent": "multimedia",
            "secondary_agents": [],
            "complexity": complexity,
            "confidence": confidence,
            "tools_needed": tools_needed,
            "reasoning": f"Question contains a YouTube URL and requires {task_type}",
            "requires_multimodal": True,
            "estimated_steps": 3,
            "question_summary": question[:100] + "..." if len(question) > 100 else question,
            "has_file": bool(file_name),
            "media_type": "youtube_video",
            "media_url": youtube_url,
            "task_type": task_type  # Task type for more specific handling
        }

    def _validate_classification(self, classification: Dict, question: str, file_name: str) -> Dict:
        """Validate and normalize the classification response"""
        # Ensure the primary agent is valid (normalize case before comparing)
        primary_agent = str(classification.get("primary_agent", "GENERAL")).upper()
        if primary_agent not in [agent.value.upper() for agent in AgentType]:
            primary_agent = "GENERAL"

        # Validate secondary agents
        secondary_agents = classification.get("secondary_agents", [])
        valid_secondary = [
            agent for agent in secondary_agents
            if agent.upper() in [a.value.upper() for a in AgentType]
        ]

        # Clamp confidence to [0, 1]
        confidence = max(0.0, min(1.0, classification.get("confidence", 0.5)))

        # Clamp complexity to [1, 5]
        complexity = max(1, min(5, classification.get("complexity", 3)))

        return {
            "primary_agent": primary_agent.lower(),
            "secondary_agents": [agent.lower() for agent in valid_secondary],
            "complexity": complexity,
            "confidence": confidence,
            "tools_needed": classification.get("tools_needed", []),
            "reasoning": classification.get("reasoning", "Automated classification"),
            "requires_multimodal": classification.get("requires_multimodal", False),
            "estimated_steps": classification.get("estimated_steps", 5),
            "question_summary": question[:100] + "..." if len(question) > 100 else question,
            "has_file": bool(file_name)
        }

    def _fallback_classification(self, question: str, file_name: str = "") -> Dict:
        """Fallback classification when the LLM is unavailable or fails"""
        # Simple heuristic-based fallback
        question_lower = question.lower()

        # Check for a YouTube URL first (most specific case) using the enhanced pattern
        youtube_match = re.search(ENHANCED_YOUTUBE_URL_PATTERN, question)
        if youtube_match:
            # Use the dedicated method for YouTube classification to ensure consistency
            return self._create_youtube_video_classification(question, file_name)

        # Secondary check for YouTube references (may not have a valid URL format)
        if "youtube" in question_lower and any(keyword in question_lower for keyword in
                                               ["video", "watch", "link", "url", "channel"]):
            # Likely a YouTube question even without a perfect URL match;
            # create a custom classification with high confidence
            return {
                "primary_agent": "multimedia",
                "secondary_agents": [],
                "complexity": 3,
                "confidence": 0.85,
                "tools_needed": ["analyze_youtube_video"],
                "reasoning": "Fallback detected YouTube reference without complete URL",
                "requires_multimodal": True,
                "estimated_steps": 3,
                "question_summary": question[:100] + "..." if len(question) > 100 else question,
                "has_file": bool(file_name),
                "media_type": "youtube_video",
                "media_url": "youtube_reference_detected"  # Placeholder
            }

        # Check other multimedia patterns
        # Video patterns (beyond YouTube)
        elif any(re.search(pattern, question_lower) for pattern in VIDEO_PATTERNS):
            return {
                "primary_agent": "multimedia",
                "secondary_agents": [],
                "complexity": 3,
                "confidence": 0.8,
                "tools_needed": ["analyze_video_frames"],
                "reasoning": "Fallback detected video-related content",
                "requires_multimodal": True,
                "estimated_steps": 4,
                "question_summary": question[:100] + "..." if len(question) > 100 else question,
                "has_file": bool(file_name),
                "media_type": "video"
            }

        # Audio patterns
        elif any(re.search(pattern, question_lower) for pattern in AUDIO_PATTERNS):
            return {
                "primary_agent": "multimedia",
                "secondary_agents": [],
                "complexity": 3,
                "confidence": 0.8,
                "tools_needed": ["analyze_audio_file"],
                "reasoning": "Fallback detected audio-related content",
                "requires_multimodal": True,
                "estimated_steps": 3,
                "question_summary": question[:100] + "..." if len(question) > 100 else question,
                "has_file": bool(file_name),
                "media_type": "audio"
            }

        # Image patterns
        elif any(re.search(pattern, question_lower) for pattern in IMAGE_PATTERNS):
            return {
                "primary_agent": "multimedia",
                "secondary_agents": [],
                "complexity": 2,
                "confidence": 0.8,
                "tools_needed": ["analyze_image_with_gemini"],
                "reasoning": "Fallback detected image-related content",
                "requires_multimodal": True,
                "estimated_steps": 2,
                "question_summary": question[:100] + "..." if len(question) > 100 else question,
                "has_file": bool(file_name),
                "media_type": "image"
            }

        # General multimedia keywords
        elif any(keyword in question_lower for keyword in ["multimedia", "visual", "picture", "screenshot"]):
            primary_agent = "multimedia"
            tools_needed = ["analyze_image_with_gemini"]

        # Research patterns
        elif any(keyword in question_lower for keyword in ["wikipedia", "search", "find", "who", "what", "when", "where"]):
            primary_agent = "research"
            tools_needed = ["research_with_comprehensive_fallback"]

        # Math/Logic patterns
        elif any(keyword in question_lower for keyword in ["calculate", "number", "count", "math", "opposite", "pattern"]):
            primary_agent = "logic_math"
            tools_needed = ["advanced_calculator"]

        # File processing
        elif file_name and any(ext in file_name.lower() for ext in [".xlsx", ".py", ".csv", ".pdf"]):
            primary_agent = "file_processing"
            if ".xlsx" in file_name.lower():
                tools_needed = ["analyze_excel_file"]
            elif ".py" in file_name.lower():
                tools_needed = ["analyze_python_code"]
            else:
                tools_needed = ["analyze_text_file"]

        # Default
        else:
            primary_agent = "general"
            tools_needed = []

        return {
            "primary_agent": primary_agent,
            "secondary_agents": [],
            "complexity": 3,
            "confidence": 0.6,
            "tools_needed": tools_needed,
            "reasoning": "Fallback heuristic classification",
            "requires_multimodal": bool(file_name),
            "estimated_steps": 5,
            "question_summary": question[:100] + "..." if len(question) > 100 else question,
            "has_file": bool(file_name)
        }

    def batch_classify(self, questions: List[Dict]) -> List[Dict]:
        """Classify multiple questions in batch"""
        results = []
        for q in questions:
            question_text = q.get("question", "")
            file_name = q.get("file_name", "")
            task_id = q.get("task_id", "")

            classification = self.classify_question(question_text, file_name)
            classification["task_id"] = task_id
            results.append(classification)

        return results
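
    # Expected input shape for batch_classify() (illustrative, task_id value is an example):
    #   [{"task_id": "abc123", "question": "Who wrote ...?", "file_name": ""}, ...]
    # Each returned dict is the classify_question() result plus the originating "task_id".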

    def get_routing_recommendation(self, classification: Dict) -> Dict:
        """Get specific routing recommendations based on the classification"""
        primary_agent = classification["primary_agent"]
        complexity = classification["complexity"]

        routing = {
            "primary_route": primary_agent,
            "requires_coordination": len(classification["secondary_agents"]) > 0,
            "parallel_execution": False,
            "estimated_duration": "medium",
            "special_requirements": []
        }

        # Add special requirements based on agent type
        if primary_agent == "multimedia":
            routing["special_requirements"].extend([
                "Requires yt-dlp and ffmpeg for video processing",
                "Needs Gemini Vision API for image analysis",
                "May need large temp storage for video files"
            ])
        elif primary_agent == "research":
            routing["special_requirements"].extend([
                "Requires web search and Wikipedia API access",
                "May need academic database access",
                "Benefits from citation tracking tools"
            ])
        elif primary_agent == "file_processing":
            routing["special_requirements"].extend([
                "Requires file processing libraries (pandas, openpyxl)",
                "May need sandboxed code execution environment",
                "Needs secure file handling"
            ])

        # Adjust the duration estimate based on complexity
        if complexity >= 4:
            routing["estimated_duration"] = "long"
        elif complexity <= 2:
            routing["estimated_duration"] = "short"

        # Suggest parallel execution for multi-agent scenarios
        if len(classification["secondary_agents"]) >= 2:
            routing["parallel_execution"] = True

        return routing
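
    # Illustrative routing result for a complexity-2 multimedia question with no
    # secondary agents:
    #   {"primary_route": "multimedia", "requires_coordination": False,
    #    "parallel_execution": False, "estimated_duration": "short",
    #    "special_requirements": ["Requires yt-dlp and ffmpeg for video processing", ...]}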


def test_classifier():
    """Test the classifier with sample GAIA questions"""
    # Sample questions from our GAIA set
    test_questions = [
        {
            "task_id": "video_test",
            "question": "In the video https://www.youtube.com/watch?v=L1vXCYZAYYM, what is the highest number of bird species to be on camera simultaneously?",
            "file_name": ""
        },
        {
            "task_id": "youtube_short_test",
            "question": "Check this YouTube video https://youtu.be/L1vXCYZAYYM and count the birds",
            "file_name": ""
        },
        {
            "task_id": "video_url_variation",
            "question": "How many people appear in the YouTube video at youtube.com/watch?v=dQw4w9WgXcQ",
            "file_name": ""
        },
        {
            "task_id": "research_test",
            "question": "How many studio albums were published by Mercedes Sosa between 2000 and 2009?",
            "file_name": ""
        },
        {
            "task_id": "logic_test",
            "question": ".rewsna eht sa \"tfel\" drow eht fo etisoppo eht etirw ,ecnetnes siht dnatsrednu uoy fI",
            "file_name": ""
        },
        {
            "task_id": "file_test",
            "question": "What is the final numeric output from the attached Python code?",
            "file_name": "script.py"
        }
    ]

    classifier = QuestionClassifier()

    print("🧠 Testing Question Classifier")
    print("=" * 50)

    for question in test_questions:
        print(f"\n📝 Question: {question['question'][:80]}...")

        classification = classifier.classify_question(
            question["question"],
            question["file_name"]
        )

        print(f"🎯 Primary Agent: {classification['primary_agent']}")
        print(f"🔧 Tools Needed: {classification['tools_needed']}")
        print(f"📊 Complexity: {classification['complexity']}/5")
        print(f"🎲 Confidence: {classification['confidence']:.2f}")
        print(f"💭 Reasoning: {classification['reasoning']}")

        routing = classifier.get_routing_recommendation(classification)
        print(f"🚀 Routing: {routing['primary_route']} ({'coordination needed' if routing['requires_coordination'] else 'single agent'})")


if __name__ == "__main__":
    test_classifier()
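
# Minimal usage sketch from another module (assumes this file is importable as
# `question_classifier`, a hypothetical module name, and HUGGINGFACE_TOKEN is set):
#
#   from question_classifier import QuestionClassifier
#
#   classifier = QuestionClassifier()
#   result = classifier.classify_question("How many studio albums were published by Mercedes Sosa between 2000 and 2009?")
#   print(result["primary_agent"], result["tools_needed"])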