# Spaces: Running
import requests | |
import json | |
import pytest | |
from typing import Dict, List | |
import os | |
from datetime import datetime | |
import time | |
# Base URLs for the environments this suite can target.
LOCAL_URL = "http://127.0.0.1:8000"
PROD_URL = "https://theaniketgiri-synthex.hf.space"

# The API_URL environment variable selects the target server.  An unset or
# empty value falls back to the local server, and any trailing slash is
# stripped so that f"{BASE_URL}/health" never contains a double slash.
BASE_URL = (os.getenv("API_URL") or LOCAL_URL).rstrip("/")
def wait_for_model_loading(max_retries=10, delay=30, timeout=10):
    """Poll the ``/health`` endpoint until the model reports it is loaded.

    Args:
        max_retries: Maximum number of health checks to perform.
        delay: Seconds to sleep between two consecutive checks.
        timeout: Seconds to wait for each health request before that attempt
            is treated as failed.

    Returns:
        True as soon as a health response carries a truthy ``model_loaded``
        field; False when every attempt has been used up without one.
    """
    for attempt in range(1, max_retries + 1):
        try:
            response = requests.get(f"{BASE_URL}/health", timeout=timeout)
            data = response.json()
        except (requests.RequestException, ValueError) as exc:
            # ValueError covers a response body that is not valid JSON.
            print(f"Error checking health: {exc}")
        else:
            print(f"\nHealth check response: {json.dumps(data, indent=2)}")
            if not isinstance(data, dict):
                # Keep polling instead of crashing on an unexpected payload.
                print("Error checking health: response is not a JSON object")
            elif data.get("model_loaded", False):
                return True
            elif data.get("model_loading", False):
                print(f"Model is still loading, attempt {attempt}/{max_retries}")
            else:
                print(f"Model not loaded yet, attempt {attempt}/{max_retries}")

        # Sleeping after the final attempt would only delay the failure result.
        if attempt < max_retries:
            time.sleep(delay)
    return False
class TestBackendAPI:
    """End-to-end tests for the ``/health`` and ``/generate`` endpoints.

    Every test issues real HTTP requests against ``BASE_URL``.
    """

    # Upper bound, in seconds, for any single HTTP request, so a stalled
    # server fails the test instead of hanging the run.  The value is
    # deliberately generous because generation time is not known here.
    REQUEST_TIMEOUT = 300

    @classmethod
    def setup_class(cls):
        """Skip the whole class unless the model finishes loading in time."""
        if not wait_for_model_loading():
            pytest.skip("Model failed to load within timeout")

    def _post_generate(self, record_type, count):
        """POST a generation request and return the raw HTTP response."""
        payload = {
            "record_type": record_type,
            "count": count,
        }
        return requests.post(
            f"{BASE_URL}/generate",
            json=payload,
            timeout=self.REQUEST_TIMEOUT,
        )

    @staticmethod
    def _error_detail(response):
        """Return the ``detail`` of an error response, tolerating non-JSON bodies."""
        try:
            body = response.json()
        except ValueError:
            # A 500 may carry plain text or HTML rather than a JSON document.
            return response.text or "Unknown error"
        if isinstance(body, dict):
            return body.get("detail", "Unknown error")
        return body

    def _generate_records(self, record_type, count):
        """Request records and return the decoded body of a 200 response.

        Skips the calling test on 503, fails it with the server's error
        detail on 500, and asserts that any other status is 200.
        """
        response = self._post_generate(record_type, count)
        if response.status_code == 503:
            pytest.skip("Model not loaded")
        if response.status_code == 500:
            pytest.fail(f"Generation failed: {self._error_detail(response)}")
        assert response.status_code == 200
        return response.json()

    def test_health(self):
        """Test the health check endpoint."""
        response = requests.get(
            f"{BASE_URL}/health", timeout=self.REQUEST_TIMEOUT
        )
        assert response.status_code == 200
        data = response.json()
        assert "status" in data
        assert data["status"] in ("healthy", "unhealthy")
        assert "timestamp" in data
        assert "model_loaded" in data
        print("\n=== Health Check ===")
        print(f"Status: {data['status']}")
        print(f"Model Loaded: {data['model_loaded']}")
        print(f"Timestamp: {data['timestamp']}")

    # The default keeps pytest from resolving ``record_type`` as a fixture
    # (which does not exist and would make the test error at setup), while
    # explicit callers can still pass a record type.
    def test_generate_single_record(self, record_type: str = "clinical_note"):
        """Test generating a single record of the given type."""
        print(f"\n=== Generating {record_type} ===")
        data = self._generate_records(record_type, 1)
        assert isinstance(data, list)
        assert len(data) == 1
        record = data[0]
        print("Generated Record:")
        print(json.dumps(record, indent=2))
        # Validate record structure
        assert "type" in record
        assert record["type"] == record_type
        assert "content" in record
        assert "generated_at" in record

    def test_generate_multiple_records(self):
        """Test generating multiple records in one request."""
        print("\n=== Generating Multiple Records ===")
        data = self._generate_records("clinical_note", 3)
        assert isinstance(data, list)
        assert len(data) == 3
        print(f"Generated {len(data)} records")
        for i, record in enumerate(data, 1):
            print(f"\nRecord {i}:")
            print(json.dumps(record, indent=2))

    def test_invalid_record_type(self):
        """Test error handling for an invalid record type."""
        print("\n=== Testing Invalid Record Type ===")
        response = self._post_generate("invalid_type", 1)
        assert response.status_code == 422  # FastAPI validation error
        error = response.json()
        assert "detail" in error
        print(f"Error: {error['detail']}")

    def test_invalid_count(self):
        """Test error handling for an invalid count."""
        print("\n=== Testing Invalid Count ===")
        response = self._post_generate("clinical_note", 0)
        assert response.status_code == 422  # FastAPI validation error
        error = response.json()
        assert "detail" in error
        print(f"Error: {error['detail']}")

    def test_record_content_quality(self):
        """Test the quality of generated record content."""
        print("\n=== Testing Record Content Quality ===")
        data = self._generate_records("clinical_note", 1)
        # Fail with a clear message rather than an IndexError/TypeError below.
        assert isinstance(data, list) and data, "Expected a non-empty list of records"
        record = data[0]
        content = record["content"]
        # Check content length
        assert len(content) > 100, "Content too short"
        # Check for common medical terms
        medical_terms = ["patient", "diagnosis", "treatment", "symptoms"]
        content_lower = content.lower()
        assert any(term in content_lower for term in medical_terms), "Missing medical terminology"
        print("Content Quality Checks Passed")
        print(f"Content Length: {len(content)} characters")
def main():
    """Run the suite as a plain script, without the pytest runner.

    The class-level setup is executed first, as pytest would do, so the tests
    only start once the model-loading wait has finished.  Skips raised through
    ``pytest.skip()`` are caught and reported instead of ending the script
    with a traceback.  Assertion failures and ``pytest.fail()`` still
    propagate to the caller.
    """
    print("Starting API Tests...")
    print(f"Testing against: {BASE_URL}")
    print("=" * 50)

    test_suite = TestBackendAPI()

    # Called on the instance so it works whether or not setup_class is
    # declared as a classmethod.
    try:
        test_suite.setup_class()
    except pytest.skip.Exception as exc:
        print(f"\nTests skipped: {exc}")
        print("=" * 50)
        return

    # Each entry is (bound test method, positional arguments).
    test_calls = [
        (test_suite.test_health, ()),
        (test_suite.test_generate_single_record, ("clinical_note",)),
        (test_suite.test_generate_multiple_records, ()),
        (test_suite.test_invalid_record_type, ()),
        (test_suite.test_invalid_count, ()),
        (test_suite.test_record_content_quality, ()),
    ]

    skipped = []
    for test, args in test_calls:
        try:
            test(*args)
        except pytest.skip.Exception as exc:
            # Outside the pytest runner nothing else would catch a skip.
            skipped.append(test.__name__)
            print(f"Skipped {test.__name__}: {exc}")

    if skipped:
        print(f"\nTests completed with {len(skipped)} skipped: {', '.join(skipped)}")
    else:
        print("\nAll tests completed successfully!")
    print("=" * 50)


if __name__ == "__main__":
    main()