Spaces:
Running
Running
#!/usr/bin/env python3
"""
Tests for the improved question classification and tool selection.

Focuses on YouTube URL detection and selection of the appropriate tool.
"""
import os | |
import sys | |
import re | |
from pathlib import Path | |
from question_classifier import QuestionClassifier | |
from main import GAIASolver | |
def test_youtube_classification():
    """Check YouTube URL detection and tool selection in QuestionClassifier.

    Each case in the table below is classified with
    ``QuestionClassifier.classify_question``. The returned ``primary_agent``
    and ``tools_needed`` are compared against the case's expectations.
    Results are printed per case so the script stays readable when run
    directly. Every mismatch is also recorded and reported through a single
    assertion at the end, so pytest marks the test as failed instead of
    silently passing.

    Raises:
        AssertionError: if any case was misclassified or received the wrong
            tools.
    """
    print("🧪 Testing improved YouTube classification")
    print("=" * 50)

    classifier = QuestionClassifier()

    # "expected_tool" semantics:
    #   a tool name -> that tool must appear in tools_needed;
    #   None        -> no specific tool is required, but analyze_youtube_video
    #                  must not be selected for a question that never mentions
    #                  YouTube.
    test_cases = [
        {
            "id": "standard_youtube",
            "question": "In the video https://www.youtube.com/watch?v=L1vXCYZAYYM, what is the highest number of bird species visible?",
            "expected_type": "multimedia",
            "expected_tool": "analyze_youtube_video",
        },
        {
            "id": "shortened_youtube",
            "question": "Check this YouTube video https://youtu.be/L1vXCYZAYYM and count the birds",
            "expected_type": "multimedia",
            "expected_tool": "analyze_youtube_video",
        },
        {
            "id": "youtube_without_protocol",
            "question": "How many people appear in the YouTube video at youtube.com/watch?v=dQw4w9WgXcQ",
            "expected_type": "multimedia",
            "expected_tool": "analyze_youtube_video",
        },
        {
            "id": "youtube_embedded",
            "question": "Count the number of times 'hello' is said in youtube.com/embed/dQw4w9WgXcQ",
            "expected_type": "multimedia",
            "expected_tool": "analyze_youtube_video",
        },
        {
            "id": "youtube_without_direct_url",
            "question": "There's a YouTube video about bird watching. How many species can you see?",
            "expected_type": "multimedia",  # Should detect this as likely multimedia
            "expected_tool": None,  # May not specifically use analyze_youtube_video without URL
        },
        {
            "id": "non_youtube_video",
            "question": "Analyze the video file and tell me how many people appear in it.",
            "expected_type": "multimedia",
            "expected_tool": None,  # Should NOT be analyze_youtube_video
        },
    ]

    skipped_keys = {"question_summary"}  # lengthy fields left out of the dump
    failures = []

    for case in test_cases:
        question = case["question"]
        expected_type = case["expected_type"]
        expected_tool = case["expected_tool"]

        print(f"\n🔍 Testing case: {case['id']}")
        print(f"Question: {question}")

        classification = classifier.classify_question(question)

        # Check primary agent type
        agent_type = classification['primary_agent']
        print(f"🎯 Classified as: {agent_type}")
        if agent_type == expected_type:
            print(f"✅ PASS: Correctly classified as {expected_type}")
        else:
            message = f"Expected {expected_type} but got {agent_type}"
            print(f"❌ FAIL: {message}")
            failures.append(f"{case['id']}: {message}")

        # Check tool selection; `or []` also tolerates an explicit None.
        tools = classification.get('tools_needed') or []
        print(f"🔧 Tools selected: {tools}")
        if expected_tool is not None:
            if expected_tool in tools:
                print(f"✅ PASS: Correctly included {expected_tool} tool")
            else:
                message = f"Expected {expected_tool} tool but not found"
                print(f"❌ FAIL: {message}")
                failures.append(f"{case['id']}: {message}")
        elif "analyze_youtube_video" in tools and "youtube" not in question.lower():
            message = "Incorrectly included analyze_youtube_video tool for non-YouTube question"
            print(f"❌ FAIL: {message}")
            failures.append(f"{case['id']}: {message}")

        print("📊 Classification data:")
        for key, value in classification.items():
            if key not in skipped_keys:
                print(f" - {key}: {value}")
        print("-" * 50)

    assert not failures, "Classification failures:\n" + "\n".join(failures)
def test_solver_tool_selection():
    """Check that GAIASolver's classifier selects analyze_youtube_video for a YouTube URL.

    Only classification is exercised; ``solve_question`` is not run. If
    ``GAIASolver()`` raises during construction, the check is reported as
    skipped and the function returns without failing. What the constructor
    needs (credentials, models, ...) is not visible here; TODO confirm.

    Raises:
        AssertionError: if analyze_youtube_video is not among the selected
            tools.
    """
    print("\n\n🧪 Testing GAIASolver tool selection")
    print("=" * 50)

    # Guard only construction. Previously the try also wrapped
    # classification, so errors there (e.g. a missing key) were swallowed and
    # misreported as an initialization failure.
    try:
        solver = GAIASolver()
    except Exception as e:
        print(f"❌ Error initializing solver: {e}")
        print("Skipping solver tests")
        return

    test_question = {
        "task_id": "youtube_test",
        "question": "In the video https://www.youtube.com/watch?v=L1vXCYZAYYM, what is the highest number of bird species visible?",
    }

    print("\n🔍 Testing solver with YouTube question")
    print(f"Question: {test_question['question']}")

    # We don't need to run the full solve_question method;
    # just check that classification and tool selection are correct.
    classification = solver.classifier.classify_question(test_question['question'])
    # Same tolerant lookup as in test_youtube_classification.
    tools = classification.get('tools_needed') or []
    print(f"🎯 Classified as: {classification['primary_agent']}")
    print(f"🔧 Tools selected: {tools}")

    selected = "analyze_youtube_video" in tools
    if selected:
        print("✅ PASS: Correctly selected analyze_youtube_video tool")
    else:
        print("❌ FAIL: Did not select analyze_youtube_video tool for YouTube question")
    assert selected, "Did not select analyze_youtube_video tool for YouTube question"
if __name__ == "__main__":
    import traceback

    # Run every check even if an earlier one fails or errors, then surface
    # the overall outcome through the exit status so shell/CI callers can
    # detect failures.
    failed_checks = []
    for check in (test_youtube_classification, test_solver_tool_selection):
        try:
            check()
        except Exception:  # top-level boundary: report and keep going
            traceback.print_exc()
            failed_checks.append(check.__name__)

    if failed_checks:
        # SystemExit with a message prints it to stderr and exits with status 1.
        raise SystemExit(f"Failed checks: {', '.join(failed_checks)}")