File size: 22,081 Bytes
37cadfb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
#!/usr/bin/env python3
"""
LLM-based Question Classifier for Multi-Agent GAIA Solver
Routes questions to appropriate specialist agents based on content analysis
"""

import os
import json
import re
from typing import Dict, List, Optional, Tuple
from enum import Enum
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Import LLM (using same setup as main solver)
from smolagents import InferenceClientModel


class AgentType(Enum):
    """Specialist agents that a question can be routed to.

    Member values are the lowercase routing keys stored under
    ``primary_agent`` / ``secondary_agents`` in classification dicts.
    """

    # Video, audio and image analysis.
    MULTIMEDIA = "multimedia"
    # Web search, Wikipedia and academic-paper lookups.
    RESEARCH = "research"
    # Puzzles, calculations and pattern recognition.
    LOGIC_MATH = "logic_math"
    # Excel sheets, Python code and other document analysis.
    FILE_PROCESSING = "file_processing"
    # Catch-all when the classification is unclear.
    GENERAL = "general"


# Regex patterns used to spot media content in question text.

# Basic YouTube URL: optional scheme and "www.", a youtube.com or youtu.be host
# (``youtu\.?be`` also accepts the bare word "youtube"), a slash, then the
# shortest run of characters up to the next whitespace or end of string.
YOUTUBE_URL_PATTERN = r'(https?://)?(www\.)?(youtube\.com|youtu\.?be)/.+?(?=\s|$)'

# Broader YouTube URL pattern: optionally consumes a known path prefix
# (watch?v=, embed/, v/, shorts/, playlist?list=, channel/, user/) or any one
# path segment, then captures the following ID-like segment, which stops at
# whitespace, "&", "?" or "/".
ENHANCED_YOUTUBE_URL_PATTERN = r'(https?://)?(www\.)?(youtube\.com|youtu\.?be)/(?:watch\?v=|embed/|v/|shorts/|playlist\?list=|channel/|user/|[^/\s]+/?)?([^\s&?/]+)'

# Patterns for the heuristic fallback classifier, which searches them against
# the lowercased question. Several are bare words ("video", "sound", "image"),
# so they match broadly.
VIDEO_PATTERNS = [
    r'youtube\.(com|be)',
    r'video',
    r'watch\?v=',
]
AUDIO_PATTERNS = [
    r'\.mp3\b',
    r'\.wav\b',
    r'audio',
    r'sound',
    r'listen',
    r'music',
    r'podcast',
]
IMAGE_PATTERNS = [
    r'\.jpg\b',
    r'\.jpeg\b',
    r'\.png\b',
    r'\.gif\b',
    r'image',
    r'picture',
    r'photo',
]


class QuestionClassifier:
    """LLM-powered question classifier for agent routing.

    ``classify_question`` short-circuits questions containing a YouTube URL
    with regexes, asks a small hosted instruct model to classify everything
    else, and falls back to keyword heuristics when the model call or the
    parsing of its JSON reply fails.
    """

    def __init__(self):
        """Create the classifier and its model client.

        The token is read from the ``HUGGINGFACE_TOKEN`` environment variable
        (``load_dotenv()`` runs at import time, so a ``.env`` file also works).

        Raises:
            ValueError: If ``HUGGINGFACE_TOKEN`` is unset or empty.
        """
        self.hf_token = os.getenv("HUGGINGFACE_TOKEN")
        if not self.hf_token:
            raise ValueError("HUGGINGFACE_TOKEN environment variable is required")

        # Initialize lightweight model for classification
        self.classifier_model = InferenceClientModel(
            model_id="Qwen/Qwen2.5-7B-Instruct",  # Smaller, faster model for classification
            token=self.hf_token
        )

    def classify_question(self, question: str, file_name: str = "") -> Dict:
        """
        Classify a GAIA question and determine the best agent routing.

        YouTube questions are detected with regexes and answered without an
        LLM call. Everything else is sent to the classifier model, and its JSON
        reply is normalized by ``_validate_classification``. Any failure in that
        round-trip (model error, malformed JSON, unexpected shape) degrades to
        ``_fallback_classification``.

        Args:
            question: The question text
            file_name: Associated file name (if any)

        Returns:
            Dict with classification results and routing information
        """
        # Fast path: an explicit YouTube URL needs no LLM call.
        if re.search(ENHANCED_YOUTUBE_URL_PATTERN, question):
            return self._create_youtube_video_classification(question, file_name)

        # Secondary check: the domain mentioned without a path (which the URL
        # pattern above requires) alongside video-related wording.
        question_lower = question.lower()
        if "youtube" in question_lower and any(term in question_lower for term in ["video", "watch", "channel"]):
            if re.search(r'(youtube\.com|youtu\.be)', question):
                return self._create_youtube_video_classification(question, file_name)

        classification_prompt = self._build_classification_prompt(question, file_name)

        try:
            messages = [{"role": "user", "content": classification_prompt}]
            response = self.classifier_model(messages)

            classification_text = self._extract_json_text(response.content or "")
            classification = json.loads(classification_text)
            if not isinstance(classification, dict):
                raise ValueError(
                    f"expected a JSON object, got {type(classification).__name__}"
                )

            return self._validate_classification(classification, question, file_name)

        except Exception as e:
            # Routing is best-effort: any failure here falls back to the
            # keyword heuristics instead of propagating to the solver.
            print(f"Classification error: {e}")
            return self._fallback_classification(question, file_name)

    @staticmethod
    def _build_classification_prompt(question: str, file_name: str) -> str:
        """Return the routing prompt sent to the classifier model."""
        return f"""
Analyze this GAIA benchmark question and classify it for routing to specialist agents.

Question: {question}
Associated file: {file_name if file_name else "None"}

Classify this question into ONE primary category and optionally secondary categories:

AGENT CATEGORIES:
1. MULTIMEDIA - Questions involving video analysis, audio transcription, image analysis
   Examples: YouTube videos, MP3 files, PNG images, visual content analysis
   
2. RESEARCH - Questions requiring web search, Wikipedia lookup, or factual data retrieval
   Examples: Factual lookups, biographical info, historical data, citations, sports statistics, company information, academic papers
   Note: If a question requires looking up data first (even for later calculations), classify as RESEARCH
   
3. LOGIC_MATH - Questions involving pure mathematical calculations or logical reasoning with given data
   Examples: Mathematical puzzles with provided numbers, algebraic equations, geometric calculations, logical deduction puzzles
   Note: Use this ONLY when all data is provided and no external lookup is needed
   
4. FILE_PROCESSING - Questions requiring file analysis (Excel, Python code, documents)
   Examples: Spreadsheet analysis, code execution, document parsing
   
5. GENERAL - Simple questions or unclear classification

ANALYSIS REQUIRED:
1. Primary agent type (required)
2. Secondary agent types (if question needs multiple specialists)
3. Complexity level (1-5, where 5 is most complex)
4. Tools needed (list specific tools that would be useful)
5. Reasoning (explain your classification choice)

Respond in JSON format:
{{
    "primary_agent": "AGENT_TYPE",
    "secondary_agents": ["AGENT_TYPE2", "AGENT_TYPE3"],
    "complexity": 3,
    "confidence": 0.95,
    "tools_needed": ["tool1", "tool2"],
    "reasoning": "explanation of classification",
    "requires_multimodal": false,
    "estimated_steps": 5
}}
"""

    @staticmethod
    def _extract_json_text(text: str) -> str:
        """Return the JSON payload from a model reply.

        Unwraps a ```json or bare ``` fence. A missing closing fence no longer
        chops the final character off the payload. Any prose left around the
        object is then trimmed by slicing from the first "{" to the last "}".
        """
        text = text.strip()
        for fence in ("```json", "```"):
            start = text.find(fence)
            if start != -1:
                start += len(fence)
                end = text.find("```", start)
                # find() returns -1 for an unterminated fence; slicing to -1
                # would silently drop the last character.
                text = (text[start:] if end == -1 else text[start:end]).strip()
                break

        first, last = text.find("{"), text.rfind("}")
        if first != -1 and last > first:
            text = text[first:last + 1]
        return text

    def _create_youtube_video_classification(self, question: str, file_name: str = "") -> Dict:
        """Build a multimedia classification for a YouTube video question.

        No LLM call is made. The URL is taken from ``ENHANCED_YOUTUBE_URL_PATTERN``
        (falling back to ``YOUTUBE_URL_PATTERN``, then to any ``http(s)://``
        token when the text mentions "youtube"). The task type and complexity
        are inferred from keywords in the question.

        Args:
            question: The question text.
            file_name: Associated file name, if any (only used for ``has_file``).

        Returns:
            Classification dict routed to ``multimedia`` with
            ``media_type="youtube_video"``, the extracted ``media_url``
            (``"unknown_youtube_url"`` when nothing URL-like is found) and a
            ``task_type`` hint.
        """
        question_lower = question.lower()

        youtube_url_match = (
            re.search(ENHANCED_YOUTUBE_URL_PATTERN, question)
            or re.search(YOUTUBE_URL_PATTERN, question)
        )
        if youtube_url_match:
            youtube_url = youtube_url_match.group(0)
        elif "youtube" in question_lower:
            url_match = re.search(r'https?://\S+', question)
            youtube_url = url_match.group(0) if url_match else "unknown_youtube_url"
        else:
            youtube_url = "unknown_youtube_url"
        # The patterns stop only at whitespace, so "...watch?v=ID," or "(...ID)"
        # would otherwise pass sentence punctuation on to the video tools.
        youtube_url = youtube_url.rstrip(".,;:!?)]}>'\"")

        # Infer the kind of video task, and a complexity estimate, from keywords.
        complexity = 3  # Default
        confidence = 0.98  # High default confidence for YouTube questions
        if any(term in question_lower for term in ['count', 'how many', 'highest number']):
            complexity = 2  # Counting tasks
            task_type = "counting"
        elif any(term in question_lower for term in ['relationship', 'compare', 'difference']):
            complexity = 4  # Comparative analysis
            task_type = "comparison"
        elif any(term in question_lower for term in ['say', 'speech', 'dialogue', 'talk', 'speak']):
            complexity = 3  # Speech analysis
            task_type = "speech_analysis"
        elif any(term in question_lower for term in ['scene', 'visual', 'appear', 'shown']):
            complexity = 3  # Visual analysis
            task_type = "visual_analysis"
        else:
            task_type = "general_video_analysis"

        # analyze_youtube_video is always first; the audio tool is appended as
        # a fallback for questions about what is said.
        tools_needed = ["analyze_youtube_video"]
        if "audio" in question_lower or any(term in question_lower for term in ['say', 'speech', 'dialogue']):
            tools_needed.append("analyze_audio_file")

        return {
            "primary_agent": "multimedia",
            "secondary_agents": [],
            "complexity": complexity,
            "confidence": confidence,
            "tools_needed": tools_needed,
            "reasoning": f"Question contains a YouTube URL and requires {task_type}",
            "requires_multimodal": True,
            "estimated_steps": 3,
            "question_summary": question[:100] + "..." if len(question) > 100 else question,
            "has_file": bool(file_name),
            "media_type": "youtube_video",
            "media_url": youtube_url,
            "task_type": task_type  # Task type for more specific downstream handling
        }

    def _validate_classification(self, classification: Dict, question: str, file_name: str) -> Dict:
        """Validate and normalize a classification dict decoded from the LLM.

        Agent names are matched case-insensitively against ``AgentType``
        values. An unknown or missing primary agent becomes ``"general"``.
        Unknown, duplicate, or primary-repeating secondary agents are dropped.
        Missing or non-numeric ``confidence``/``complexity`` fall back to
        0.5/3 and are clamped to [0, 1] and [1, 5].

        Args:
            classification: Raw dict decoded from the model's JSON reply.
            question: Original question text (used for ``question_summary``).
            file_name: Associated file name (used for ``has_file``).

        Returns:
            Normalized classification dict with lowercase agent names.
        """
        valid_agents = [agent.value for agent in AgentType]

        def normalize_agent(name) -> str:
            # Models vary in case and whitespace ("RESEARCH", " Research ").
            return name.strip().lower() if isinstance(name, str) else ""

        def to_number(value, default, cast):
            try:
                return cast(value)
            except (TypeError, ValueError, OverflowError):
                return default

        primary_agent = normalize_agent(classification.get("primary_agent"))
        if primary_agent not in valid_agents:
            primary_agent = AgentType.GENERAL.value

        raw_secondary = classification.get("secondary_agents")
        if isinstance(raw_secondary, str):
            raw_secondary = [raw_secondary]
        elif not isinstance(raw_secondary, (list, tuple)):
            raw_secondary = []
        secondary_agents: List[str] = []
        for agent in raw_secondary:
            name = normalize_agent(agent)
            # Repeating the primary would make requires_coordination
            # spuriously true in get_routing_recommendation.
            if name in valid_agents and name != primary_agent and name not in secondary_agents:
                secondary_agents.append(name)

        confidence = max(0.0, min(1.0, to_number(classification.get("confidence", 0.5), 0.5, float)))
        complexity = max(1, min(5, to_number(classification.get("complexity", 3), 3, int)))

        tools_needed = classification.get("tools_needed", [])
        if isinstance(tools_needed, str):
            tools_needed = [tools_needed]
        elif not isinstance(tools_needed, list):
            tools_needed = []

        return {
            "primary_agent": primary_agent,
            "secondary_agents": secondary_agents,
            "complexity": complexity,
            "confidence": confidence,
            "tools_needed": tools_needed,
            "reasoning": classification.get("reasoning", "Automated classification"),
            "requires_multimodal": classification.get("requires_multimodal", False),
            "estimated_steps": classification.get("estimated_steps", 5),
            "question_summary": question[:100] + "..." if len(question) > 100 else question,
            "has_file": bool(file_name)
        }

    def _fallback_classification(self, question: str, file_name: str = "") -> Dict:
        """Heuristic classification used when the LLM path fails.

        Checks run from most to least specific:

        1. YouTube URL -> ``_create_youtube_video_classification``.
        2. "youtube" plus video wording but no parseable URL -> multimedia.
        3. Supported attachment (.xlsx, .py, .csv, .pdf) -> file_processing.
        4. Video / audio / image patterns, searched in the question and in the
           attached file's extension -> multimedia with a ``media_type``.
        5. Keyword heuristics for multimedia, research and logic/math.
        6. Otherwise -> general.

        Args:
            question: The question text.
            file_name: Associated file name, if any.

        Returns:
            Classification dict with the same core keys as the LLM path.
        """
        question_lower = question.lower()
        file_lower = (file_name or "").lower()
        question_summary = question[:100] + "..." if len(question) > 100 else question

        def generic_result(primary_agent: str, tools_needed: List[str]) -> Dict:
            return {
                "primary_agent": primary_agent,
                "secondary_agents": [],
                "complexity": 3,
                "confidence": 0.6,
                "tools_needed": tools_needed,
                "reasoning": "Fallback heuristic classification",
                "requires_multimodal": bool(file_name),
                "estimated_steps": 5,
                "question_summary": question_summary,
                "has_file": bool(file_name)
            }

        # 1. Explicit YouTube URL: reuse the dedicated builder for consistency.
        if re.search(ENHANCED_YOUTUBE_URL_PATTERN, question):
            return self._create_youtube_video_classification(question, file_name)

        # 2. YouTube referenced with video wording but without a usable URL.
        if "youtube" in question_lower and any(keyword in question_lower for keyword in
                                              ["video", "watch", "link", "url", "channel"]):
            return {
                "primary_agent": "multimedia",
                "secondary_agents": [],
                "complexity": 3,
                "confidence": 0.85,
                "tools_needed": ["analyze_youtube_video"],
                "reasoning": "Fallback detected YouTube reference without complete URL",
                "requires_multimodal": True,
                "estimated_steps": 3,
                "question_summary": question_summary,
                "has_file": bool(file_name),
                "media_type": "youtube_video",
                "media_url": "youtube_reference_detected"  # Placeholder
            }

        # 3. A supported attachment is the strongest remaining signal. It must
        #    be checked before the keyword heuristics: almost every question
        #    contains "what"/"who"/"number", which used to make this branch
        #    unreachable (e.g. "What is the output of the attached Python code?").
        extension = os.path.splitext(file_lower)[1]
        file_tools = {
            ".xlsx": "analyze_excel_file",
            ".py": "analyze_python_code",
            ".csv": "analyze_text_file",
            ".pdf": "analyze_text_file",
        }
        if extension in file_tools:
            return generic_result("file_processing", [file_tools[extension]])

        # 4. Media patterns. Only the attachment's extension is appended, so
        #    "clip.mp3" or "board.png" is recognized while words in the base
        #    name (e.g. "image_data") cannot trigger a media route.
        media_text = f"{question_lower} {extension}" if extension else question_lower
        media_routes = (
            # (media_type, patterns, tool, complexity, estimated_steps)
            ("video", VIDEO_PATTERNS, "analyze_video_frames", 3, 4),
            ("audio", AUDIO_PATTERNS, "analyze_audio_file", 3, 3),
            ("image", IMAGE_PATTERNS, "analyze_image_with_gemini", 2, 2),
        )
        for media_type, patterns, tool, complexity, steps in media_routes:
            if any(re.search(pattern, media_text) for pattern in patterns):
                return {
                    "primary_agent": "multimedia",
                    "secondary_agents": [],
                    "complexity": complexity,
                    "confidence": 0.8,
                    "tools_needed": [tool],
                    "reasoning": f"Fallback detected {media_type}-related content",
                    "requires_multimodal": True,
                    "estimated_steps": steps,
                    "question_summary": question_summary,
                    "has_file": bool(file_name),
                    "media_type": media_type
                }

        # 5. Plain keyword heuristics (substring matches on the question).
        if any(keyword in question_lower for keyword in ["multimedia", "visual", "picture", "screenshot"]):
            return generic_result("multimedia", ["analyze_image_with_gemini"])
        if any(keyword in question_lower for keyword in ["wikipedia", "search", "find", "who", "what", "when", "where"]):
            return generic_result("research", ["research_with_comprehensive_fallback"])
        if any(keyword in question_lower for keyword in ["calculate", "number", "count", "math", "opposite", "pattern"]):
            return generic_result("logic_math", ["advanced_calculator"])

        # 6. Default
        return generic_result("general", [])

    def batch_classify(self, questions: List[Dict]) -> List[Dict]:
        """Classify several questions, tagging each result with its task id.

        Args:
            questions: Dicts with optional ``question``, ``file_name`` and
                ``task_id`` keys; missing keys default to ``""``.

        Returns:
            One classification dict per input, in input order, each with an
            added ``task_id`` key.
        """
        results = []

        for q in questions:
            question_text = q.get("question", "")
            file_name = q.get("file_name", "")
            task_id = q.get("task_id", "")

            classification = self.classify_question(question_text, file_name)
            classification["task_id"] = task_id

            results.append(classification)

        return results

    def get_routing_recommendation(self, classification: Dict) -> Dict:
        """Translate a classification into routing and execution hints.

        Args:
            classification: Dict with at least ``primary_agent``,
                ``complexity`` and ``secondary_agents`` keys.

        Returns:
            Dict with ``primary_route``, ``requires_coordination`` (any
            secondary agents), ``parallel_execution`` (two or more secondary
            agents), ``estimated_duration`` ("short" for complexity <= 2, "long"
            for >= 4, otherwise "medium") and ``special_requirements``.

        Raises:
            KeyError: If one of the required keys is missing.
        """
        primary_agent = classification["primary_agent"]
        complexity = classification["complexity"]

        routing = {
            "primary_route": primary_agent,
            "requires_coordination": len(classification["secondary_agents"]) > 0,
            "parallel_execution": False,
            "estimated_duration": "medium",
            "special_requirements": []
        }

        # Add special requirements based on agent type
        if primary_agent == "multimedia":
            routing["special_requirements"].extend([
                "Requires yt-dlp and ffmpeg for video processing",
                "Needs Gemini Vision API for image analysis",
                "May need large temp storage for video files"
            ])
        elif primary_agent == "research":
            routing["special_requirements"].extend([
                "Requires web search and Wikipedia API access",
                "May need academic database access",
                "Benefits from citation tracking tools"
            ])
        elif primary_agent == "file_processing":
            routing["special_requirements"].extend([
                "Requires file processing libraries (pandas, openpyxl)",
                "May need sandboxed code execution environment",
                "Needs secure file handling"
            ])

        # Adjust duration estimate based on complexity
        if complexity >= 4:
            routing["estimated_duration"] = "long"
        elif complexity <= 2:
            routing["estimated_duration"] = "short"

        # Suggest parallel execution for multi-agent scenarios
        if len(classification["secondary_agents"]) >= 2:
            routing["parallel_execution"] = True

        return routing


def test_classifier():
    """Smoke-test the classifier on sample GAIA questions and print the results.

    Nothing is asserted; the classifications and routing are printed for
    manual inspection. Constructing ``QuestionClassifier`` requires the
    ``HUGGINGFACE_TOKEN`` environment variable (it raises ``ValueError``
    otherwise). Questions without a YouTube URL invoke the classifier model.
    """

    # Sample questions from our GAIA set
    test_questions = [
        {
            "task_id": "video_test",
            "question": "In the video https://www.youtube.com/watch?v=L1vXCYZAYYM, what is the highest number of bird species to be on camera simultaneously?",
            "file_name": ""
        },
        {
            "task_id": "youtube_short_test",
            "question": "Check this YouTube video https://youtu.be/L1vXCYZAYYM and count the birds",
            "file_name": ""
        },
        {
            "task_id": "video_url_variation",
            "question": "How many people appear in the YouTube video at youtube.com/watch?v=dQw4w9WgXcQ",
            "file_name": ""
        },
        {
            "task_id": "research_test",
            "question": "How many studio albums were published by Mercedes Sosa between 2000 and 2009?",
            "file_name": ""
        },
        {
            "task_id": "logic_test",
            "question": ".rewsna eht sa \"tfel\" drow eht fo etisoppo eht etirw ,ecnetnes siht dnatsrednu uoy fI",
            "file_name": ""
        },
        {
            "task_id": "file_test",
            "question": "What is the final numeric output from the attached Python code?",
            "file_name": "script.py"
        }
    ]

    classifier = QuestionClassifier()

    print("🧠 Testing Question Classifier")
    print("=" * 50)

    for question in test_questions:
        text = question["question"]
        # Only mark truncation when the question was actually cut.
        preview = text[:80] + "..." if len(text) > 80 else text
        print(f"\n📝 Question: {preview}")
        classification = classifier.classify_question(
            text,
            question["file_name"]
        )

        print(f"🎯 Primary Agent: {classification['primary_agent']}")
        print(f"🔧 Tools Needed: {classification['tools_needed']}")
        print(f"📊 Complexity: {classification['complexity']}/5")
        print(f"🎲 Confidence: {classification['confidence']:.2f}")
        print(f"💭 Reasoning: {classification['reasoning']}")

        routing = classifier.get_routing_recommendation(classification)
        print(f"🚀 Routing: {routing['primary_route']} ({'coordination needed' if routing['requires_coordination'] else 'single agent'})")


if __name__ == "__main__":
    # Running the module directly executes the sample-question smoke test,
    # which needs HUGGINGFACE_TOKEN to construct the classifier.
    test_classifier()