# synthex / test_backend.py
# Source: theaniketgiri/synthex ("backend"), revision 373e5ff
import requests
import json
import pytest
from typing import Dict, List
import os
from datetime import datetime
import time
# Known deployments of the backend API.
LOCAL_URL: str = "http://127.0.0.1:8000"
PROD_URL: str = "https://theaniketgiri-synthex.hf.space"
# Target server for every request in this module: the API_URL environment
# variable when set, otherwise the local development server.
BASE_URL: str = os.environ.get("API_URL", LOCAL_URL)
def wait_for_model_loading(max_retries=10, delay=30, timeout=10):
    """Poll ``{BASE_URL}/health`` until the backend reports its model as loaded.

    Args:
        max_retries: Maximum number of health checks to perform.
        delay: Seconds to sleep between consecutive checks.
        timeout: Per-request timeout in seconds. Without it, a server that
            accepts the connection but never answers would block forever and
            defeat the ``max_retries * delay`` bound.

    Returns:
        True as soon as a health response has a truthy ``model_loaded`` field;
        False once every attempt has been used without that happening.
    """
    for attempt in range(1, max_retries + 1):
        try:
            response = requests.get(f"{BASE_URL}/health", timeout=timeout)
            data = response.json()
            print(f"\nHealth check response: {json.dumps(data, indent=2)}")
            if data.get("model_loaded", False):
                return True
            if data.get("model_loading", False):
                print(f"Model is still loading, attempt {attempt}/{max_retries}")
            else:
                print(f"Model not loaded yet, attempt {attempt}/{max_retries}")
        except Exception as e:
            # Best effort on purpose: while the server is starting, connection
            # errors, timeouts and non-JSON bodies are expected, so log and retry.
            print(f"Error checking health: {str(e)}")
        # Sleeping after the final attempt would only delay returning False.
        if attempt < max_retries:
            time.sleep(delay)
    return False
class TestBackendAPI:
    """End-to-end tests for the backend HTTP API served at ``BASE_URL``.

    ``setup_class`` polls ``/health`` before any test runs and skips the whole
    class if the model never loads. Generation tests also skip on HTTP 503.
    """

    # Seconds to wait for any single HTTP response. Generation runs a model on
    # the server and can be slow, so this is deliberately generous; it exists
    # so a hung server fails the test instead of blocking the run forever.
    REQUEST_TIMEOUT = 300

    @classmethod
    def setup_class(cls):
        """Wait for the model to load; skip every test in the class if it never does."""
        if not wait_for_model_loading():
            pytest.skip("Model failed to load within timeout")

    @staticmethod
    def _error_detail(response):
        """Return the ``detail`` field of an error response, tolerating non-JSON bodies."""
        try:
            body = response.json()
        except ValueError:
            # E.g. an HTML error page from a proxy; report the raw text instead
            # of letting a decode error mask the real failure.
            return response.text or "Unknown error"
        if isinstance(body, dict):
            return body.get("detail", "Unknown error")
        return "Unknown error"

    def _post_generate(self, record_type, count):
        """POST a ``/generate`` request and return the raw response, whatever its status."""
        payload = {
            "record_type": record_type,
            "count": count,
        }
        return requests.post(
            f"{BASE_URL}/generate", json=payload, timeout=self.REQUEST_TIMEOUT
        )

    def _generate(self, record_type, count):
        """Request ``count`` records of ``record_type`` and return the decoded JSON body.

        Skips the calling test on 503 (model not loaded). Fails it with the
        server's error detail on 500. Any other non-200 status fails the
        status-code assertion.
        """
        response = self._post_generate(record_type, count)
        if response.status_code == 503:
            pytest.skip("Model not loaded")
        elif response.status_code == 500:
            pytest.fail(f"Generation failed: {self._error_detail(response)}")
        assert response.status_code == 200
        return response.json()

    def test_health(self):
        """Test the health check endpoint"""
        response = requests.get(f"{BASE_URL}/health", timeout=self.REQUEST_TIMEOUT)
        assert response.status_code == 200
        data = response.json()
        assert "status" in data
        assert data["status"] in ["healthy", "unhealthy"]
        assert "timestamp" in data
        assert "model_loaded" in data
        print("\n=== Health Check ===")
        print(f"Status: {data['status']}")
        print(f"Model Loaded: {data['model_loaded']}")
        print(f"Timestamp: {data['timestamp']}")

    @pytest.mark.parametrize("record_type", [
        "clinical_note",
        "discharge_summary",
        "lab_report",
        "prescription"
    ])
    def test_generate_single_record(self, record_type: str):
        """Test generating a single record of each type"""
        print(f"\n=== Generating {record_type} ===")
        data = self._generate(record_type, 1)
        assert isinstance(data, list)
        assert len(data) == 1
        record = data[0]
        print("Generated Record:")
        print(json.dumps(record, indent=2))
        # Validate record structure
        assert "type" in record
        assert record["type"] == record_type
        assert "content" in record
        assert "generated_at" in record

    def test_generate_multiple_records(self):
        """Test generating multiple records"""
        print("\n=== Generating Multiple Records ===")
        data = self._generate("clinical_note", 3)
        assert isinstance(data, list)
        assert len(data) == 3
        print(f"Generated {len(data)} records")
        for i, record in enumerate(data, 1):
            print(f"\nRecord {i}:")
            print(json.dumps(record, indent=2))

    def test_invalid_record_type(self):
        """Test error handling for invalid record type"""
        print("\n=== Testing Invalid Record Type ===")
        response = self._post_generate("invalid_type", 1)
        assert response.status_code == 422  # FastAPI validation error
        error = response.json()
        assert "detail" in error
        print(f"Error: {error['detail']}")

    def test_invalid_count(self):
        """Test error handling for invalid count"""
        print("\n=== Testing Invalid Count ===")
        response = self._post_generate("clinical_note", 0)
        assert response.status_code == 422  # FastAPI validation error
        error = response.json()
        assert "detail" in error
        print(f"Error: {error['detail']}")

    def test_record_content_quality(self):
        """Test the quality of generated record content"""
        print("\n=== Testing Record Content Quality ===")
        data = self._generate("clinical_note", 1)
        # Fail with a clear message rather than an IndexError on data[0].
        assert isinstance(data, list) and data, "No records returned"
        content = data[0]["content"]
        # Check content length
        assert len(content) > 100, "Content too short"
        # Check for common medical terms
        medical_terms = ["patient", "diagnosis", "treatment", "symptoms"]
        content_lower = content.lower()
        assert any(term in content_lower for term in medical_terms), "Missing medical terminology"
        print("Content Quality Checks Passed")
        print(f"Content Length: {len(content)} characters")
def main():
    """Run the test suite as a plain script, without pytest's collector.

    Like pytest would, this first runs ``TestBackendAPI.setup_class`` to wait
    for the model, then calls each test in order. A skip raised by
    ``pytest.skip`` is reported and the run continues. Assertion failures and
    ``pytest.fail`` still propagate and abort the script with a traceback.
    """
    print("Starting API Tests...")
    print(f"Testing against: {BASE_URL}")
    print("=" * 50)
    # pytest.skip() raises pytest.skip.Exception (a BaseException subclass).
    # Outside pytest nothing would catch it, so it is handled explicitly here.
    try:
        TestBackendAPI.setup_class()
    except pytest.skip.Exception as exc:
        print(f"Skipping all tests: {exc}")
        return
    test_suite = TestBackendAPI()
    # (bound test method, positional arguments) in execution order
    tests = [
        (test_suite.test_health, ()),
        (test_suite.test_generate_single_record, ("clinical_note",)),
        (test_suite.test_generate_multiple_records, ()),
        (test_suite.test_invalid_record_type, ()),
        (test_suite.test_invalid_count, ()),
        (test_suite.test_record_content_quality, ()),
    ]
    skipped = []
    for test, args in tests:
        try:
            test(*args)
        except pytest.skip.Exception as exc:
            skipped.append(test.__name__)
            print(f"Skipped {test.__name__}: {exc}")
    if skipped:
        print(f"\nTests completed with {len(skipped)} skipped: {', '.join(skipped)}")
    else:
        print("\nAll tests completed successfully!")
    print("=" * 50)
# Allow running this file as a plain script (`python test_backend.py`);
# main() invokes the test methods directly instead of via pytest.
if __name__ == "__main__":
    main()