Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3
"""
Automated API testing script for the Stock Monitoring API.

Tests authentication, public and protected endpoints, and security features.

Updated for the new API architecture:
- Removed the /data/download endpoint (only /data/download-all is used now)
- Added the force_refresh and force_indicators request parameters
- Updated bulk-download strategy testing for the 3-month data period
- Added technical-indicator validation (SMA 10, 20, 50)
- Updated response validation for the new SMA fields
"""
import json
import os
import time
from pathlib import Path

import requests
from dotenv import load_dotenv
# Load environment variables from the .env file in the parent of this script's
# directory. The path is anchored to __file__ so the lookup does not depend on
# the current working directory (a bare "../.env" is resolved relative to
# wherever the script happens to be launched from).
DOTENV_PATH = Path(__file__).resolve().parent.parent / ".env"
load_dotenv(dotenv_path=DOTENV_PATH)
PORT = os.getenv("PORT", "7860")
print(f"Using PORT: {PORT}")

# Configuration
BASE_URL = f"http://localhost:{PORT}"
# None when API_KEY is set neither in the environment nor in the .env file.
API_KEY = os.getenv("API_KEY")
# Deliberately bogus key used to check that bad credentials are rejected.
INVALID_API_KEY = "invalid_key_for_testing"

# Headers
HEADERS_NO_AUTH = {"Content-Type": "application/json"}
HEADERS_VALID_AUTH = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {API_KEY}",
}
HEADERS_INVALID_AUTH = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {INVALID_API_KEY}",
}
def print_test_header(test_name):
    """Print a banner that opens a test section.

    The section name is framed above and below by a 60-character rule of
    ``=`` signs, and a blank line precedes the banner to separate it from the
    previous section's output.
    """
    rule = "=" * 60
    print("\n" + rule, f"π§ͺ {test_name}", rule, sep="\n")
def print_result(endpoint, method, expected_status, actual_status, passed):
    """Report the outcome of a single endpoint check on stdout.

    Args:
        endpoint: Path (or descriptive label) of the endpoint exercised.
        method: HTTP verb used, e.g. ``"GET"``.
        expected_status: Status code(s) the check wanted; printed verbatim.
        actual_status: Status code the server actually returned.
        passed: Whether the check succeeded.

    Returns:
        ``passed`` unchanged, so callers can fold it into a running
        ``all_passed &= print_result(...)`` accumulator.
    """
    report = [
        f"{'β ' if passed else 'β'} {method} {endpoint}",
        f" Expected: {expected_status}, Got: {actual_status}",
    ]
    if not passed:
        report.append(" β TEST FAILED")
    print("\n".join(report))
    return passed
def test_health_check():
    """Check that the root health endpoint answers without credentials.

    On success, the ``status``, ``timestamp`` and ``database.connected``
    fields of the JSON body are echoed for the operator.

    Returns:
        True when ``GET /`` answers 200; False for any other status or when
        the request or JSON decoding raises.
    """
    print_test_header("Health Check (Public Endpoint)")
    try:
        resp = requests.get(f"{BASE_URL}/", headers=HEADERS_NO_AUTH, timeout=10)
        ok = resp.status_code == 200
        print_result("/", "GET", 200, resp.status_code, ok)
        if not ok:
            return ok
        payload = resp.json()
        print(f" π Status: {payload.get('status')}")
        print(f" π Timestamp: {payload.get('timestamp')}")
        print(f" πΎ DB Connected: {payload.get('database', {}).get('connected')}")
        return ok
    except Exception as exc:
        print(f"β Health check failed: {exc}")
        return False
def test_public_endpoints():
    """Confirm the ticker listing is readable without an API key.

    Returns:
        True if ``GET /tickers?limit=5`` answers 200; False on any other
        status or when the request or JSON decoding raises.
    """
    print_test_header("Public Endpoints (No Auth Required)")
    outcome = True
    # GET /tickers is public, so no Authorization header is sent.
    try:
        resp = requests.get(f"{BASE_URL}/tickers?limit=5", headers=HEADERS_NO_AUTH, timeout=10)
        ok = resp.status_code == 200
        outcome = outcome & print_result("/tickers", "GET", 200, resp.status_code, ok)
        if ok:
            tickers = resp.json()
            print(f" π Returned {len(tickers)} tickers")
    except Exception as exc:
        print(f"β GET /tickers failed: {exc}")
        outcome = False
    return outcome
def test_protected_endpoints_no_auth():
    """Verify every protected route rejects requests that carry no credentials.

    Each route must answer 401 (Unauthorized) or 403 (Forbidden).

    Returns:
        True only if every route was rejected as expected; a request that
        raises counts as a failure.
    """
    print_test_header("Protected Endpoints - No Auth (Should Fail)")
    cases = (
        ("POST", "/tickers/update", {"force_refresh": False}),
        ("POST", "/tickers/update-async", {"force_refresh": False}),
        ("POST", "/data/download-all", {"force_refresh": False, "force_indicators": False}),
        ("GET", "/tasks", None),
        ("DELETE", "/tasks/old", None),
    )
    results = []
    for verb, path, body in cases:
        try:
            # requests.request dispatches on the verb; json=None sends no body.
            resp = requests.request(
                verb,
                f"{BASE_URL}{path}",
                headers=HEADERS_NO_AUTH,
                json=body,
                timeout=10,
            )
            rejected = resp.status_code in (401, 403)
            results.append(print_result(path, verb, "401/403", resp.status_code, rejected))
        except Exception as exc:
            print(f"β {verb} {path} failed: {exc}")
            results.append(False)
    return all(results)
def test_protected_endpoints_invalid_auth():
    """Verify protected routes turn away a bearer token that is not valid.

    A bad key must produce exactly 401 (Unauthorized).

    Returns:
        True only if every route answered 401; a request that raises counts
        as a failure.
    """
    print_test_header("Protected Endpoints - Invalid Auth (Should Fail)")
    cases = (
        ("POST", "/tickers/update", {"force_refresh": False}),
        ("POST", "/data/download-all", {"force_refresh": False, "force_indicators": False}),
        ("GET", "/tasks", None),
    )
    results = []
    for verb, path, body in cases:
        try:
            # requests.request dispatches on the verb; json=None sends no body.
            resp = requests.request(
                verb,
                f"{BASE_URL}{path}",
                headers=HEADERS_INVALID_AUTH,
                json=body,
                timeout=10,
            )
            rejected = resp.status_code == 401
            results.append(print_result(path, verb, "401", resp.status_code, rejected))
        except Exception as exc:
            print(f"β {verb} {path} failed: {exc}")
            results.append(False)
    return all(results)
def test_protected_endpoints_valid_auth():
    """Exercise the protected task endpoints using the configured API key.

    Checks, in order:
    - ``GET /tasks`` returns 200;
    - ``POST /tickers/update-async`` returns 200 and, if its body reports a
      ``task_id``, ``GET /tasks/{task_id}`` returns 200;
    - ``DELETE /tasks/old`` returns 200.

    Each request has its own ``try`` so that a transport or JSON error is
    reported against the request that actually caused it.

    Returns:
        True only if every check passed.
    """
    print_test_header("Protected Endpoints - Valid Auth (Should Succeed)")
    all_passed = True

    # Test GET /tasks
    try:
        response = requests.get(f"{BASE_URL}/tasks", headers=HEADERS_VALID_AUTH, timeout=10)
        passed = response.status_code == 200
        all_passed &= print_result("/tasks", "GET", 200, response.status_code, passed)
        if passed:
            data = response.json()
            print(f" π Found {len(data)} tasks")
    except Exception as e:
        print(f"β GET /tasks failed: {e}")
        all_passed = False

    # Test POST /tickers/update-async (safer than sync version)
    task_id = None
    try:
        response = requests.post(
            f"{BASE_URL}/tickers/update-async",
            headers=HEADERS_VALID_AUTH,
            json={"force_refresh": False},
            timeout=15,
        )
        passed = response.status_code == 200
        all_passed &= print_result("/tickers/update-async", "POST", 200, response.status_code, passed)
        if passed:
            task_id = response.json().get("task_id")
            print(f" π Task started: {task_id}")
    except Exception as e:
        print(f"β POST /tickers/update-async failed: {e}")
        all_passed = False

    # Test GET /tasks/{task_id}. This has its own try so that a failure here
    # is not misreported as a failure of the POST above.
    if task_id:
        try:
            time.sleep(1)  # Give task a moment to start
            response = requests.get(f"{BASE_URL}/tasks/{task_id}", headers=HEADERS_VALID_AUTH, timeout=10)
            passed = response.status_code == 200
            all_passed &= print_result(f"/tasks/{task_id}", "GET", 200, response.status_code, passed)
            if passed:
                task_data = response.json()
                print(f" π Task status: {task_data.get('status')}")
        except Exception as e:
            print(f"β GET /tasks/{task_id} failed: {e}")
            all_passed = False

    # Test DELETE /tasks/old
    try:
        response = requests.delete(f"{BASE_URL}/tasks/old", headers=HEADERS_VALID_AUTH, timeout=10)
        passed = response.status_code == 200
        all_passed &= print_result("/tasks/old", "DELETE", 200, response.status_code, passed)
        if passed:
            data = response.json()
            print(f" ποΈ Deleted {data.get('deleted', 0)} old tasks")
    except Exception as e:
        print(f"β DELETE /tasks/old failed: {e}")
        all_passed = False

    return all_passed
def test_data_endpoints():
    """Exercise the bulk-download endpoint and the per-ticker data endpoint.

    1. ``POST /data/download-all`` (authenticated) with ``force_refresh`` and
       ``force_indicators`` both False; prints the summary counters.
    2. ``GET /data/tickers/AAPL?days=5`` (public); prints the first record's
       close price and any SMA values it carries.

    Returns:
        True if both requests returned 200 and their bodies could be read.
    """
    print_test_header("Data Endpoints - Valid Auth (Should Succeed)")
    all_passed = True

    # Test POST /data/download-all (bulk download with automatic freshness check).
    # Per this script's notes, the endpoint skips the update when data is <24h
    # old, so the first run downloads everything and later runs may just
    # report "data is fresh".
    try:
        response = requests.post(
            f"{BASE_URL}/data/download-all",
            headers=HEADERS_VALID_AUTH,
            json={"force_refresh": False, "force_indicators": False},
            timeout=120,  # Bulk download might take longer with 3mo data and technical indicators
        )
        passed = response.status_code == 200
        all_passed &= print_result("/data/download-all", "POST", 200, response.status_code, passed)
        if passed:
            data = response.json()
            # `or {}` also covers an explicit null date_range, not just a
            # missing key, so the summary never crashes on `.get`.
            date_range = data.get('date_range') or {}
            print(f" π Processed {data.get('tickers_processed', 0)} tickers")
            print(f" π Created {data.get('records_created', 0)} records")
            print(f" π Updated {data.get('records_updated', 0)} records")
            print(f" π Date range: {date_range.get('start_date')} to {date_range.get('end_date')}")
            print(f" π¬ Message: {data.get('message', 'N/A')}")
    except Exception as e:
        print(f"β POST /data/download-all failed: {e}")
        all_passed = False

    # Test GET /data/tickers/{ticker} (public endpoint)
    try:
        response = requests.get(f"{BASE_URL}/data/tickers/AAPL?days=5", headers=HEADERS_NO_AUTH, timeout=10)
        passed = response.status_code == 200
        all_passed &= print_result("/data/tickers/AAPL", "GET", 200, response.status_code, passed)
        if passed:
            data = response.json()
            print(f" π Retrieved {len(data)} days of AAPL data")
            if data:
                # Treats the first record as the latest; this assumes the
                # endpoint returns newest-first (TODO confirm).
                latest = data[0]
                close = latest.get('close')
                # A null or missing close would make `:.2f` raise TypeError and
                # fail an otherwise successful request, so show N/A instead.
                close_text = f"${close:.2f}" if close is not None else "N/A"
                print(f" π° Latest close: {close_text}")
                # Check for technical indicators
                for field, label in (
                    ("sma_fast", "SMA Fast (10)"),
                    ("sma_med", "SMA Med (20)"),
                    ("sma_slow", "SMA Slow (50)"),
                ):
                    value = latest.get(field)
                    if value is not None:
                        print(f" π {label}: ${value:.2f}")
    except Exception as e:
        print(f"β GET /data/tickers/AAPL failed: {e}")
        all_passed = False

    return all_passed
def test_sql_injection_safety():
    """Probe ``/tickers`` with SQL-injection payloads in the ``limit`` parameter.

    A safe server either rejects the value (422 validation error) or handles
    it harmlessly (200). Any other status, or a request that raises, counts
    as a failure.

    Returns:
        True if every payload produced 200 or 422.
    """
    print_test_header("SQL Injection Safety Tests")
    payloads = (
        "'; DROP TABLE tickers; --",
        "' OR '1'='1",
        "1' UNION SELECT * FROM tasks --",
        "'; DELETE FROM tasks; --",
    )
    verdicts = []
    for payload in payloads:
        try:
            # Passed via params= so requests URL-encodes the payload.
            resp = requests.get(
                f"{BASE_URL}/tickers",
                params={"limit": payload},
                headers=HEADERS_NO_AUTH,
                timeout=10,
            )
            safe = resp.status_code in (200, 422)
            print_result(f"/tickers?limit={payload[:20]}...", "GET", "200/422", resp.status_code, safe)
            verdicts.append(safe)
        except Exception as exc:
            print(f"β SQL injection test failed: {exc}")
            verdicts.append(False)
    print(" π‘οΈ SQL injection tests completed")
    return all(verdicts)
def main():
    """Run every test suite in order and print an overall summary.

    Every suite runs even after an earlier one fails, so the report is
    complete.

    Returns:
        Process exit code: 0 if every suite passed, 1 otherwise.
    """
    print("π§ͺ Starting Stock Monitoring API Tests")
    print(f"π Base URL: {BASE_URL}")
    if API_KEY:
        print(f"π API Key: {API_KEY[:10]}...")
    else:
        # API_KEY is None when unset; slicing it would raise TypeError and
        # abort before any test runs. Report it and keep going so the public
        # and unauthenticated checks still execute.
        print("π API Key: <not set> - authenticated tests are expected to fail")

    suites = (
        test_health_check,
        test_public_endpoints,
        test_protected_endpoints_no_auth,
        test_protected_endpoints_invalid_auth,
        test_protected_endpoints_valid_auth,
        test_data_endpoints,
        test_technical_indicators,
        test_sql_injection_safety,
    )
    all_tests_passed = True
    for suite in suites:
        # `&=` (not `and`) so every suite runs even after an earlier failure.
        all_tests_passed &= suite()

    # Final results
    print(f"\n{'='*60}")
    if all_tests_passed:
        print("π ALL TESTS PASSED! β ")
        print("β API Key authentication is working")
        print("β Protected endpoints are secure")
        print("β SQL injection protection is active")
        print("β Public endpoints are accessible")
        print("β Bulk data download with freshness check is working")
        print("β Technical indicators (SMA 10, 20, 50) are working")
        print("β 3-month data period and force_indicators flag functional")
        print("β New optimized API architecture is functional")
    else:
        print("β SOME TESTS FAILED!")
        print("β οΈ Please check the API implementation")
    print(f"{'='*60}")
    return 0 if all_tests_passed else 1
def test_technical_indicators():
    """Force indicator recomputation, then check that AAPL records carry SMA fields.

    Steps:
    1. ``POST /data/download-all`` (authenticated) with ``force_indicators=True``.
    2. ``GET /data/tickers/AAPL?days=60`` (public), then count the records
       that have at least one of ``sma_fast`` / ``sma_med`` / ``sma_slow`` set.

    Returns:
        True if both requests returned 200 and at least one retrieved record
        carries an indicator value; False otherwise.
    """
    print_test_header("Technical Indicators Tests")
    all_passed = True

    # Test POST /data/download-all with force_indicators=True
    try:
        response = requests.post(
            f"{BASE_URL}/data/download-all",
            headers=HEADERS_VALID_AUTH,
            json={"force_refresh": False, "force_indicators": True},
            timeout=120,
        )
        passed = response.status_code == 200
        all_passed &= print_result("/data/download-all (force_indicators)", "POST", 200, response.status_code, passed)
        if passed:
            data = response.json()
            print(f" π Processed {data.get('tickers_processed', 0)} tickers")
            print(f" π¬ Message: {data.get('message', 'N/A')}")
    except Exception as e:
        print(f"β POST /data/download-all (force_indicators) failed: {e}")
        all_passed = False

    # Test that ticker data now includes technical indicators
    indicator_fields = ("sma_fast", "sma_med", "sma_slow")
    try:
        response = requests.get(f"{BASE_URL}/data/tickers/AAPL?days=60", headers=HEADERS_NO_AUTH, timeout=10)
        passed = response.status_code == 200
        all_passed &= print_result("/data/tickers/AAPL (indicators validation)", "GET", 200, response.status_code, passed)
        if passed:
            data = response.json()
            print(f" π Retrieved {len(data)} days of AAPL data for indicators test")
            indicators_found = sum(
                1 for record in data
                if any(record.get(field) is not None for field in indicator_fields)
            )
            print(f" π Records with indicators: {indicators_found}/{len(data)}")
            # Fail the suite when no indicators came back. Counting alone would
            # let the final summary claim indicators work when none exist.
            if indicators_found == 0:
                print(" β TEST FAILED: no AAPL records contain technical indicators")
                all_passed = False
            # SMA Slow (50) is only expected once 50+ records exist, so it is
            # inspected only then. data[:10] assumes newest-first ordering
            # (TODO confirm against the endpoint).
            if len(data) >= 50:
                recent_records = data[:10]
                sma_slow_count = sum(1 for r in recent_records if r.get('sma_slow') is not None)
                print(f" π Recent records with SMA Slow (50): {sma_slow_count}/{len(recent_records)}")
    except Exception as e:
        print(f"β Technical indicators validation failed: {e}")
        all_passed = False

    return all_passed
if __name__ == "__main__":
    # SystemExit instead of the site-module `exit()` helper, which is meant for
    # the interactive shell and is missing when Python runs with -S.
    raise SystemExit(main())