Markit_v2 / tests /test_implementation_structure.py
AnseMin's picture
Add advanced retrieval strategies and update dependencies for RAG implementation
21c909d
#!/usr/bin/env python3
"""
Test script to verify the Phase 1 implementation structure is correct.
This test checks imports, method signatures, and class structure without requiring API keys.
"""
import os
import sys
from pathlib import Path
def _ensure_src_importable() -> None:
    """Put the project's import roots on ``sys.path``.

    The tests in this file import ``src.rag.*``, which only resolves when the
    directory that *contains* ``src/`` is importable. Walk upwards from this
    file to the nearest directory holding a ``src/`` folder and append both
    that directory and ``src/`` itself, so the script works whether it sits in
    the project root or in a sub-directory such as ``tests/``.
    """
    here = Path(__file__).resolve().parent
    for candidate in (here, *here.parents):
        src_dir = candidate / "src"
        if src_dir.is_dir():
            for entry in (str(candidate), str(src_dir)):
                # Avoid piling up duplicates when the module is re-imported.
                if entry not in sys.path:
                    sys.path.append(entry)
            return
    # No src/ directory anywhere above this file: keep the historical
    # behaviour so unusual layouts see no change.
    sys.path.append(str(Path(__file__).parent / "src"))


_ensure_src_importable()
def test_imports():
    """Check that the modules added in Phase 1 can be imported.

    Imports the vector-store and chat-service modules from ``src.rag`` and the
    LangChain ``BM25Retriever`` / ``EnsembleRetriever`` classes, printing one
    status line after each group.

    Returns:
        bool: ``True`` when every import succeeds; ``False`` as soon as one
        raises (the error is printed, not propagated).
    """
    print("🔧 Testing Imports and Structure")
    print("=" * 40)

    try:
        # The imported names are intentionally unused: importing them *is*
        # the check.

        # Vector store
        from src.rag.vector_store import VectorStoreManager, vector_store_manager  # noqa: F401
        print("✅ VectorStoreManager imports successfully")

        # Chat service
        from src.rag.chat_service import RAGChatService, rag_chat_service  # noqa: F401
        print("✅ RAGChatService imports successfully")

        # LangChain retrievers
        from langchain_community.retrievers import BM25Retriever  # noqa: F401
        from langchain.retrievers import EnsembleRetriever  # noqa: F401
        print("✅ BM25Retriever and EnsembleRetriever import successfully")
    except Exception as e:
        # Broad on purpose: any import-time failure (missing package, error
        # raised while a module initialises) counts as a failed check.
        print(f"❌ Import test failed: {e}")
        return False

    return True
def test_method_signatures():
    """Check that the new retrieval methods exist and accept the expected parameters.

    Instantiates ``VectorStoreManager`` and ``RAGChatService``, asserts that
    the new methods are present, then uses ``inspect.signature`` to confirm
    that ``get_hybrid_retriever`` and ``chat_with_retrieval`` expose the
    expected parameter names.

    Returns:
        bool: ``True`` when every check passes; ``False`` if an import, a
        constructor or any assertion fails (the error is printed).
    """
    print("\n🔍 Testing Method Signatures")
    print("=" * 40)

    try:
        import inspect

        from src.rag.vector_store import VectorStoreManager
        from src.rag.chat_service import RAGChatService

        # VectorStoreManager: presence of the new retriever factories.
        vm = VectorStoreManager()
        for method_name in ("get_bm25_retriever", "get_hybrid_retriever"):
            assert hasattr(vm, method_name), f"{method_name} method missing"
        print("✅ VectorStoreManager has new methods")

        # RAGChatService: presence of the retrieval-aware chat entry points.
        cs = RAGChatService()
        for method_name in (
            "chat_with_retrieval",
            "chat_stream_with_retrieval",
            "set_default_retrieval_method",
        ):
            assert hasattr(cs, method_name), f"{method_name} method missing"
        print("✅ RAGChatService has new methods")

        # Basic signature check: only parameter *names* are verified, not
        # their order, kinds or defaults.
        hybrid_params = inspect.signature(vm.get_hybrid_retriever).parameters
        for param in ("k", "semantic_weight", "keyword_weight", "search_type", "search_kwargs"):
            assert param in hybrid_params, f"Parameter {param} missing from get_hybrid_retriever"
        print("✅ get_hybrid_retriever has correct parameters")

        chat_params = inspect.signature(cs.chat_with_retrieval).parameters
        for param in ("user_message", "retrieval_method", "retrieval_config"):
            assert param in chat_params, f"Parameter {param} missing from chat_with_retrieval"
        print("✅ chat_with_retrieval has correct parameters")
    except Exception as e:
        # AssertionError is an Exception subclass, so failed checks land here too.
        print(f"❌ Method signature test failed: {e}")
        return False

    return True
def test_class_attributes():
    """Check that freshly constructed instances carry the new private attributes.

    Returns:
        bool: ``True`` when ``VectorStoreManager()`` and ``RAGChatService()``
        expose every expected attribute; ``False`` if an import, a constructor
        or any assertion fails (the error is printed).
    """
    print("\n📋 Testing Class Attributes")
    print("=" * 40)

    try:
        from src.rag.vector_store import VectorStoreManager
        from src.rag.chat_service import RAGChatService

        # VectorStoreManager attributes
        vm = VectorStoreManager()
        for attr_name in ("_documents_cache", "_bm25_retriever"):
            assert hasattr(vm, attr_name), f"{attr_name} attribute missing"
        print("✅ VectorStoreManager has new attributes")

        # RAGChatService attributes
        cs = RAGChatService()
        for attr_name in (
            "_current_retrieval_method",
            "_default_retrieval_method",
            "_default_retrieval_config",
        ):
            assert hasattr(cs, attr_name), f"{attr_name} attribute missing"
        print("✅ RAGChatService has new attributes")
    except Exception as e:
        # AssertionError is an Exception subclass, so failed checks land here too.
        print(f"❌ Class attributes test failed: {e}")
        return False

    return True
def test_configuration_options():
    """Check that each retrieval method can be installed as the service default.

    For every ``(method, config)`` pair the shared ``rag_chat_service``
    instance is reconfigured through ``set_default_retrieval_method`` and its
    ``_default_retrieval_method`` / ``_default_retrieval_config`` attributes
    are compared with the values that were passed in.

    Returns:
        bool: ``True`` when all configurations are accepted; ``False`` on the
        first one that fails, or if the service cannot be imported.
    """
    print("\n⚙️ Testing Configuration Options")
    print("=" * 40)

    try:
        from src.rag.chat_service import rag_chat_service

        # (retrieval method, configuration) pairs to try in order.
        configs = [
            ("similarity", {"k": 4}),
            ("mmr", {"k": 3, "fetch_k": 10, "lambda_mult": 0.5}),
            ("hybrid", {"k": 4, "semantic_weight": 0.7, "keyword_weight": 0.3}),
        ]

        # NOTE(review): this mutates the module-level service and leaves it
        # configured as "hybrid" afterwards - confirm nothing that runs later
        # in the same process relies on the previous default.
        for method, config in configs:
            try:
                rag_chat_service.set_default_retrieval_method(method, config)
                assert rag_chat_service._default_retrieval_method == method
                assert rag_chat_service._default_retrieval_config == config
                print(f"✅ {method} configuration works")
            except Exception as e:
                # Report which method failed and stop at the first failure.
                print(f"❌ {method} configuration failed: {e}")
                return False

        return True
    except Exception as e:
        print(f"❌ Configuration test failed: {e}")
        return False
def _locate_requirements(start_dir):
    """Return the nearest ``requirements.txt`` at or above *start_dir*, or ``None``."""
    for directory in (start_dir, *start_dir.parents):
        candidate = directory / "requirements.txt"
        if candidate.is_file():
            return candidate
    return None


def test_requirements_updated():
    """Check that ``requirements.txt`` lists the new retrieval dependencies.

    The file is looked up next to this script first and then in each parent
    directory, so the check works both when the script sits in the project
    root and when it lives in a sub-directory such as ``tests/``. Each package
    name is matched as a plain substring of the file's text.

    Returns:
        bool: ``True`` when both ``langchain-community`` and ``rank-bm25``
        appear in the file; ``False`` if the file cannot be found or read, or
        if a package is missing.
    """
    print("\n📦 Testing Requirements Update")
    print("=" * 40)

    try:
        requirements_path = _locate_requirements(Path(__file__).resolve().parent)
        if requirements_path is None:
            print("❌ requirements.txt not found")
            return False

        # Explicit encoding so the result does not depend on the platform locale.
        content = requirements_path.read_text(encoding="utf-8")

        for package in ("langchain-community", "rank-bm25"):
            if package not in content:
                print(f"❌ {package} missing from requirements.txt")
                return False
            print(f"✅ {package} found in requirements.txt")

        return True
    except Exception as e:
        print(f"❌ Requirements test failed: {e}")
        return False
def main():
    """Run every structure test and print a pass/fail summary.

    Each test function is called in turn; an exception escaping a test is
    reported and recorded as a failure so the remaining tests still run.

    Returns:
        int: process exit status - ``0`` when all tests passed, ``1`` otherwise.
    """
    print("🚀 Phase 1 Implementation Structure Tests")
    print("Testing code structure without requiring API keys")
    print("=" * 60)

    tests = [
        ("Imports", test_imports),
        ("Method Signatures", test_method_signatures),
        ("Class Attributes", test_class_attributes),
        ("Configuration Options", test_configuration_options),
        ("Requirements Update", test_requirements_updated),
    ]

    # Maps test name -> outcome, in execution order.
    results = {}
    for test_name, test_func in tests:
        try:
            results[test_name] = test_func()
        except Exception as e:
            # The tests normally catch their own errors; this is a safety net.
            print(f"❌ {test_name} test crashed: {e}")
            results[test_name] = False

    # Summary
    print("\n📋 Structure Test Summary")
    print("=" * 40)

    passed_count = sum(1 for passed in results.values() if passed)
    total_count = len(results)

    for test_name, passed in results.items():
        status = "✅ PASSED" if passed else "❌ FAILED"
        print(f"{test_name}: {status}")

    print(f"\nOverall: {passed_count}/{total_count} tests passed")

    if passed_count != total_count:
        print(f"\n❌ {total_count - passed_count} structure issues found")
        return 1

    print("\n🎉 Phase 1 Implementation Structure is PERFECT!")
    print("✅ All imports work correctly")
    print("✅ All method signatures are correct")
    print("✅ All class attributes are present")
    print("✅ Configuration system works")
    print("✅ Requirements are updated")
    print("\n💡 The implementation is ready for use once API keys are configured!")
    return 0
if __name__ == "__main__":
    # sys.exit rather than the bare exit() helper: exit() is injected by the
    # `site` module for interactive use and is missing under `python -S`.
    sys.exit(main())