File size: 3,785 Bytes
37cadfb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#!/usr/bin/env python3
"""
Final test for YouTube question classification and tool selection
"""

from question_classifier import QuestionClassifier

def test_classification():
    """Check that the classifier routes YouTube questions to the right agent and tool.

    Each sample question is passed to ``QuestionClassifier.classify_question``.
    The returned ``primary_agent`` is compared with the case's expected agent,
    and ``tools_needed`` is checked for the expected tool. For cases with no
    expected tool, the check is only that ``analyze_youtube_video`` was not
    selected for a question that does not mention YouTube.

    A per-case report and a final summary are printed. Whether
    ``analyze_youtube_video`` is listed first for YouTube questions is
    reported as well, but it does not affect the pass/fail outcome.

    Raises:
        AssertionError: If at least one case fails, so that pytest (and the
            exit status when run as a script) reflects the failure instead
            of it only being printed.
    """
    # Substrings that mark a question as being about a YouTube video.
    # 'youtu.be' covers short links, which do not contain the word 'youtube'.
    youtube_markers = ('youtube', 'youtu.be')

    # Initialize classifier
    classifier = QuestionClassifier()

    # Test cases
    test_cases = [
        {
            'question': 'In the video https://www.youtube.com/watch?v=L1vXCYZAYYM, what is the highest number of bird species?',
            'expected_agent': 'multimedia',
            'expected_tool': 'analyze_youtube_video'
        },
        {
            'question': 'Tell me about the video at youtu.be/dQw4w9WgXcQ',
            'expected_agent': 'multimedia',
            'expected_tool': 'analyze_youtube_video'
        },
        {
            'question': 'What does Teal\'c say in the YouTube video youtube.com/watch?v=XYZ123?',
            'expected_agent': 'multimedia',
            'expected_tool': 'analyze_youtube_video'
        },
        {
            'question': 'How many birds appear in this image?',
            'expected_agent': 'multimedia',
            'expected_tool': 'analyze_image_with_gemini'
        },
        {
            'question': 'When was the first Star Wars movie released?',
            'expected_agent': 'research',
            'expected_tool': None
        }
    ]

    print("🧪 Testing Question Classification for YouTube Questions")
    print("=" * 70)

    passed = 0
    for number, case in enumerate(test_cases, start=1):
        question = case['question']
        expected_agent = case['expected_agent']
        expected_tool = case['expected_tool']

        lowered_question = question.lower()
        is_youtube_question = any(
            marker in lowered_question for marker in youtube_markers
        )

        print(f"\nTest {number}: {question[:80]}...")

        # Classify the question
        classification = classifier.classify_question(question)

        # Read both result fields once. A missing agent or a missing/None
        # tool list is treated as a failed case rather than aborting the
        # remaining cases with KeyError/TypeError.
        actual_agent = classification.get('primary_agent')
        tools = classification.get('tools_needed') or []

        # Check primary agent type
        agent_correct = actual_agent == expected_agent

        # Check if expected tool is in tools list
        if expected_tool:
            tool_correct = expected_tool in tools
        else:
            # If no specific tool expected, just make sure analyze_youtube_video
            # isn't incorrectly selected for non-YouTube questions
            tool_correct = (
                'analyze_youtube_video' not in tools or is_youtube_question
            )

        agent_mark = '✅' if agent_correct else '❌'
        tool_mark = '✅' if tool_correct else '❌'

        # Print results
        print(f"Expected agent: {expected_agent}")
        print(f"Actual agent: {actual_agent}")
        print(f"Agent match: {agent_mark}")

        print(f"Expected tool: {expected_tool}")
        print(f"Selected tools: {tools}")
        print(f"Tool match: {tool_mark}")

        # Report (informational only) whether the YouTube tool comes first
        if tools and is_youtube_question:
            if tools[0] == 'analyze_youtube_video':
                print("✅ analyze_youtube_video correctly prioritized for YouTube question")
            else:
                print("❌ analyze_youtube_video not prioritized for YouTube question")

        # Print overall result
        if agent_correct and tool_correct:
            passed += 1
            print("✅ TEST PASSED")
        else:
            print("❌ TEST FAILED")

    total = len(test_cases)

    # Print summary
    print("\n" + "=" * 70)
    print(f"Final result: {passed}/{total} tests passed")

    if passed == total:
        print("🎉 All tests passed! The classifier is working correctly.")
    else:
        print("⚠️ Some tests failed. Further improvements needed.")

    # Without this, a test runner would report success even when cases failed.
    assert passed == total, (
        f"{total - passed} of {total} classification cases failed"
    )

# Run the classification checks when this file is executed directly as a
# script (the guard keeps them from running on import).
if __name__ == "__main__":
    test_classification()