Spaces:
Running
Running
#!/usr/bin/env python3 | |
""" | |
Final test for YouTube question classification and tool selection | |
""" | |
from question_classifier import QuestionClassifier | |
# Substrings that mark a question as referring to a YouTube video; the short
# "youtu.be" link form does not contain "youtube", so both are needed.
_YOUTUBE_MARKERS = ("youtube", "youtu.be")


def _is_youtube_question(question):
    """Return True if *question* mentions YouTube, including youtu.be short links."""
    lowered = question.lower()
    return any(marker in lowered for marker in _YOUTUBE_MARKERS)


def _mark(ok):
    """Return the pass/fail glyph printed next to each individual check."""
    return '✅' if ok else '❌'


def test_classification(classifier=None):
    """Test that our classification improvements for YouTube questions are working.

    Runs a fixed set of questions through ``classifier.classify_question`` and
    prints, for each one, whether the returned ``primary_agent`` matches the
    expected agent and whether the expected tool appears in ``tools_needed``.
    For YouTube questions it additionally reports whether
    ``analyze_youtube_video`` is the first selected tool; that report is
    informational only and does not affect pass/fail.

    Results are only printed: the function returns None and does not raise
    when cases fail.

    Args:
        classifier: Object exposing ``classify_question(question) -> dict``
            whose result has a ``'primary_agent'`` key and optionally a
            ``'tools_needed'`` list. Defaults to a new ``QuestionClassifier()``.
            It has a default so pytest does not try to resolve it as a fixture.
    """
    if classifier is None:
        classifier = QuestionClassifier()

    test_cases = [
        {
            'question': 'In the video https://www.youtube.com/watch?v=L1vXCYZAYYM, what is the highest number of bird species?',
            'expected_agent': 'multimedia',
            'expected_tool': 'analyze_youtube_video',
        },
        {
            'question': 'Tell me about the video at youtu.be/dQw4w9WgXcQ',
            'expected_agent': 'multimedia',
            'expected_tool': 'analyze_youtube_video',
        },
        {
            'question': "What does Teal'c say in the YouTube video youtube.com/watch?v=XYZ123?",
            'expected_agent': 'multimedia',
            'expected_tool': 'analyze_youtube_video',
        },
        {
            'question': 'How many birds appear in this image?',
            'expected_agent': 'multimedia',
            'expected_tool': 'analyze_image_with_gemini',
        },
        {
            'question': 'When was the first Star Wars movie released?',
            'expected_agent': 'research',
            'expected_tool': None,
        },
    ]

    print("🧪 Testing Question Classification for YouTube Questions")
    print("=" * 70)

    passed = 0
    for i, case in enumerate(test_cases, start=1):
        question = case['question']
        expected_agent = case['expected_agent']
        expected_tool = case['expected_tool']
        is_youtube = _is_youtube_question(question)

        print(f"\nTest {i}: {question[:80]}...")

        classification = classifier.classify_question(question)
        # Treat a missing or None tool list as "no tools selected" so the
        # membership checks below cannot raise TypeError.
        tools = classification.get('tools_needed') or []

        agent_correct = classification['primary_agent'] == expected_agent

        if expected_tool:
            tool_correct = expected_tool in tools
        else:
            # No specific tool expected: only fail if analyze_youtube_video was
            # selected for a question that does not reference YouTube at all.
            tool_correct = 'analyze_youtube_video' not in tools or is_youtube

        print(f"Expected agent: {expected_agent}")
        print(f"Actual agent: {classification['primary_agent']}")
        print(f"Agent match: {_mark(agent_correct)}")
        print(f"Expected tool: {expected_tool}")
        print(f"Selected tools: {tools}")
        print(f"Tool match: {_mark(tool_correct)}")

        # Informational only: is the YouTube tool ranked first for YouTube questions?
        if tools and is_youtube:
            if tools[0] == 'analyze_youtube_video':
                print("✅ analyze_youtube_video correctly prioritized for YouTube question")
            else:
                print("❌ analyze_youtube_video not prioritized for YouTube question")

        if agent_correct and tool_correct:
            passed += 1
            print("✅ TEST PASSED")
        else:
            print("❌ TEST FAILED")

    print("\n" + "=" * 70)
    print(f"Final result: {passed}/{len(test_cases)} tests passed")
    if passed == len(test_cases):
        print("🎉 All tests passed! The classifier is working correctly.")
    else:
        print("⚠️ Some tests failed. Further improvements needed.")
if __name__ == "__main__":
    # Run the classification report when this file is executed as a script.
    test_classification()