# Final_Assignment / test_improved_classification.py
# tonthatthienvu's picture
# Clean repository without binary files
# 37cadfb
# raw
# history blame
# 5.47 kB
#!/usr/bin/env python3
"""
Test for improved question classification and tool selection
Focuses on YouTube URL detection and appropriate tool selection
"""
import os
import sys
import re
from pathlib import Path
from question_classifier import QuestionClassifier
from main import GAIASolver
def test_youtube_classification():
    """Check YouTube-aware question classification and tool selection.

    Each question in a fixed table is passed to
    ``QuestionClassifier.classify_question``. The returned ``primary_agent``
    is compared with the expected agent type. The returned ``tools_needed``
    list is checked for the expected tool, or, for non-YouTube questions,
    for the absence of ``analyze_youtube_video``.

    Every case is printed as a PASS/FAIL report and all cases always run.
    Once the table is exhausted, any recorded failures are raised together so
    that pytest and the script entry point actually report a regression.

    Raises:
        AssertionError: if at least one case was misclassified or got the
            wrong tools; the message lists every failing case.
    """
    print("πŸ§ͺ Testing improved YouTube classification")
    print("=" * 50)

    classifier = QuestionClassifier()

    # expected_tool=None means "no specific tool is required for this case".
    test_cases = [
        {
            "id": "standard_youtube",
            "question": "In the video https://www.youtube.com/watch?v=L1vXCYZAYYM, what is the highest number of bird species visible?",
            "expected_type": "multimedia",
            "expected_tool": "analyze_youtube_video",
        },
        {
            "id": "shortened_youtube",
            "question": "Check this YouTube video https://youtu.be/L1vXCYZAYYM and count the birds",
            "expected_type": "multimedia",
            "expected_tool": "analyze_youtube_video",
        },
        {
            "id": "youtube_without_protocol",
            "question": "How many people appear in the YouTube video at youtube.com/watch?v=dQw4w9WgXcQ",
            "expected_type": "multimedia",
            "expected_tool": "analyze_youtube_video",
        },
        {
            "id": "youtube_embedded",
            "question": "Count the number of times 'hello' is said in youtube.com/embed/dQw4w9WgXcQ",
            "expected_type": "multimedia",
            "expected_tool": "analyze_youtube_video",
        },
        {
            "id": "youtube_without_direct_url",
            "question": "There's a YouTube video about bird watching. How many species can you see?",
            "expected_type": "multimedia",  # Should detect this as likely multimedia
            "expected_tool": None,  # May not specifically use analyze_youtube_video without URL
        },
        {
            "id": "non_youtube_video",
            "question": "Analyze the video file and tell me how many people appear in it.",
            "expected_type": "multimedia",
            "expected_tool": None,  # Should NOT be analyze_youtube_video
        },
    ]

    failures = []
    for case in test_cases:
        case_id = case['id']
        question = case['question']
        expected_type = case['expected_type']
        expected_tool = case['expected_tool']

        print(f"\nπŸ“ Testing case: {case_id}")
        print(f"Question: {question}")

        classification = classifier.classify_question(question)

        # .get so a result missing the key is reported as a FAIL for this case
        # instead of aborting the whole run with a KeyError.
        agent_type = classification.get('primary_agent')
        print(f"🎯 Classified as: {agent_type}")
        if agent_type == expected_type:
            print(f"βœ… PASS: Correctly classified as {expected_type}")
        else:
            print(f"❌ FAIL: Expected {expected_type} but got {agent_type}")
            failures.append(
                f"{case_id}: expected type {expected_type!r}, got {agent_type!r}"
            )

        tools = classification.get('tools_needed', [])
        print(f"πŸ”§ Tools selected: {tools}")
        if expected_tool is not None:
            if expected_tool in tools:
                print(f"βœ… PASS: Correctly included {expected_tool} tool")
            else:
                print(f"❌ FAIL: Expected {expected_tool} tool but not found")
                failures.append(f"{case_id}: missing tool {expected_tool!r} in {tools!r}")
        elif "analyze_youtube_video" in tools and "youtube" not in question.lower():
            # Only flag the YouTube tool when the question never mentions YouTube.
            print("❌ FAIL: Incorrectly included analyze_youtube_video tool for non-YouTube question")
            failures.append(f"{case_id}: analyze_youtube_video selected for non-YouTube question")

        print("πŸ“‹ Classification data:")
        for key, value in classification.items():
            if key not in ['question_summary']:  # Skip lengthy fields
                print(f" - {key}: {value}")
        print("-" * 50)

    # Raise explicitly (not `assert`) so the check survives `python -O`.
    if failures:
        raise AssertionError(
            "YouTube classification failures:\n" + "\n".join(failures)
        )
def test_solver_tool_selection():
    """Check that GAIASolver's classifier selects the YouTube analysis tool.

    Only ``solver.classifier.classify_question`` is exercised; the full
    ``solve_question`` pipeline is deliberately not run. If ``GAIASolver()``
    itself cannot be constructed, the check is reported and skipped. The
    ``try`` covers construction only, so errors from the classification step
    surface as real failures instead of being mislabelled as initialization
    errors.

    Raises:
        AssertionError: if ``analyze_youtube_video`` is not among the tools
            selected for a question containing a YouTube URL.
    """
    print("\n\nπŸ§ͺ Testing GAIASolver tool selection")
    print("=" * 50)

    # Best-effort: a solver that cannot be built is treated as "skip".
    try:
        solver = GAIASolver()
    except Exception as e:
        print(f"❌ Error initializing solver: {e}")
        print("Skipping solver tests")
        return

    test_question = {
        "task_id": "youtube_test",
        "question": "In the video https://www.youtube.com/watch?v=L1vXCYZAYYM, what is the highest number of bird species visible?",
    }
    print("\nπŸ“ Testing solver with YouTube question")
    print(f"Question: {test_question['question']}")

    classification = solver.classifier.classify_question(test_question['question'])
    tools = classification.get('tools_needed', [])
    print(f"🎯 Classified as: {classification.get('primary_agent')}")
    print(f"πŸ”§ Tools selected: {tools}")

    if "analyze_youtube_video" in tools:
        print("βœ… PASS: Correctly selected analyze_youtube_video tool")
    else:
        print("❌ FAIL: Did not select analyze_youtube_video tool for YouTube question")
        raise AssertionError(
            f"analyze_youtube_video not selected for YouTube question; got {tools!r}"
        )
if __name__ == "__main__":
    # Run every suite even if an earlier one raises, then exit non-zero if any
    # suite failed so the script's status is meaningful from a shell or CI.
    failed_suites = []
    for suite in (test_youtube_classification, test_solver_tool_selection):
        try:
            suite()
        except Exception as exc:  # top-level boundary: report and continue
            print(f"❌ {suite.__name__} failed: {exc}")
            failed_suites.append(suite.__name__)
    if failed_suites:
        sys.exit(1)