import requests
import json
import pytest
from typing import Dict, List
import os
from datetime import datetime
import time

# Known deployment targets for the API under test.
LOCAL_URL = "http://127.0.0.1:8000"
PROD_URL = "https://theaniketgiri-synthex.hf.space"

# The API_URL environment variable selects the target; fall back to the local server.
BASE_URL = os.environ.get("API_URL", LOCAL_URL)

def wait_for_model_loading(max_retries=10, delay=30, timeout=30):
    """Poll the ``/health`` endpoint until the model reports it is loaded.

    Args:
        max_retries: Maximum number of health checks to perform.
        delay: Seconds to sleep between consecutive health checks.
        timeout: Per-request timeout in seconds, so an unresponsive server
            cannot block a single health check indefinitely.

    Returns:
        True as soon as a health response has a truthy ``model_loaded``
        field; False once all ``max_retries`` attempts are exhausted.
    """
    for attempt in range(1, max_retries + 1):
        try:
            response = requests.get(f"{BASE_URL}/health", timeout=timeout)
            data = response.json()
            print(f"\nHealth check response: {json.dumps(data, indent=2)}")

            if data.get("model_loaded", False):
                return True
            if data.get("model_loading", False):
                print(f"Model is still loading, attempt {attempt}/{max_retries}")
            else:
                print(f"Model not loaded yet, attempt {attempt}/{max_retries}")
        except Exception as e:
            # Deliberately broad: this is best-effort polling. Connection
            # errors, timeouts, non-JSON bodies and non-object JSON (no
            # ``.get``) are all reported and retried rather than aborting.
            print(f"Error checking health: {e}")

        # Sleeping after the final attempt only delays the False result.
        if attempt < max_retries:
            time.sleep(delay)
    return False

class TestBackendAPI:
    """End-to-end tests for the record-generation API served at ``BASE_URL``.

    ``setup_class`` polls the health endpoint first and skips every test in
    the class if the model never reports as loaded.
    """

    # Per-request timeouts in seconds. Without a timeout, ``requests`` waits
    # indefinitely, so a hung server would stall the whole run. Generation
    # gets a larger budget because it runs the model. Override on a subclass
    # if the target environment is slower.
    HEALTH_TIMEOUT = 30
    GENERATE_TIMEOUT = 300

    @classmethod
    def setup_class(cls):
        """Wait for the model to load; skip the whole class if it never does."""
        if not wait_for_model_loading():
            pytest.skip("Model failed to load within timeout")

    # -- helpers -----------------------------------------------------------

    def _post_generate(self, payload):
        """POST *payload* as JSON to ``/generate`` and return the raw response."""
        return requests.post(
            f"{BASE_URL}/generate", json=payload, timeout=self.GENERATE_TIMEOUT
        )

    @staticmethod
    def _error_detail(response):
        """Return the ``detail`` field of an error response.

        Falls back to the raw body text when the body is not JSON, and to the
        stringified JSON when it is not an object, so a malformed error body
        is reported instead of raising while handling the failure.
        """
        try:
            error = response.json()
        except ValueError:
            return response.text or "Unknown error"
        if isinstance(error, dict):
            return error.get("detail", "Unknown error")
        return str(error)

    def _generate_ok(self, payload):
        """POST *payload* to ``/generate`` and return the decoded JSON body.

        Skips the calling test on 503 (model not loaded), fails it with the
        server's error detail on 500, and otherwise requires HTTP 200.
        """
        response = self._post_generate(payload)
        if response.status_code == 503:
            pytest.skip("Model not loaded")
        elif response.status_code == 500:
            pytest.fail(f"Generation failed: {self._error_detail(response)}")
        assert response.status_code == 200
        return response.json()

    @staticmethod
    def _assert_record_structure(record, record_type):
        """Assert *record* has the fields every generated record must carry."""
        assert "type" in record
        assert record["type"] == record_type
        assert "content" in record
        assert "generated_at" in record

    # -- tests -------------------------------------------------------------

    def test_health(self):
        """Test the health check endpoint"""
        response = requests.get(f"{BASE_URL}/health", timeout=self.HEALTH_TIMEOUT)
        assert response.status_code == 200
        data = response.json()
        assert "status" in data
        assert data["status"] in ["healthy", "unhealthy"]
        assert "timestamp" in data
        assert "model_loaded" in data
        print("\n=== Health Check ===")
        print(f"Status: {data['status']}")
        print(f"Model Loaded: {data['model_loaded']}")
        print(f"Timestamp: {data['timestamp']}")

    @pytest.mark.parametrize("record_type", [
        "clinical_note",
        "discharge_summary",
        "lab_report",
        "prescription"
    ])
    def test_generate_single_record(self, record_type: str):
        """Test generating a single record of each type"""
        print(f"\n=== Generating {record_type} ===")
        data = self._generate_ok({"record_type": record_type, "count": 1})

        assert isinstance(data, list)
        assert len(data) == 1

        record = data[0]
        print("Generated Record:")
        print(json.dumps(record, indent=2))

        self._assert_record_structure(record, record_type)

    def test_generate_multiple_records(self):
        """Test generating multiple records"""
        print("\n=== Generating Multiple Records ===")
        data = self._generate_ok({"record_type": "clinical_note", "count": 3})

        assert isinstance(data, list)
        assert len(data) == 3

        print(f"Generated {len(data)} records")
        for i, record in enumerate(data, 1):
            print(f"\nRecord {i}:")
            print(json.dumps(record, indent=2))
            # Every record in a batch must satisfy the same contract as a
            # single generated record.
            self._assert_record_structure(record, "clinical_note")

    def test_invalid_record_type(self):
        """Test error handling for invalid record type"""
        print("\n=== Testing Invalid Record Type ===")
        response = self._post_generate({"record_type": "invalid_type", "count": 1})
        assert response.status_code == 422  # FastAPI validation error

        error = response.json()
        assert "detail" in error
        print(f"Error: {error['detail']}")

    def test_invalid_count(self):
        """Test error handling for invalid count"""
        print("\n=== Testing Invalid Count ===")
        response = self._post_generate({"record_type": "clinical_note", "count": 0})
        assert response.status_code == 422  # FastAPI validation error

        error = response.json()
        assert "detail" in error
        print(f"Error: {error['detail']}")

    def test_record_content_quality(self):
        """Test the quality of generated record content"""
        print("\n=== Testing Record Content Quality ===")
        data = self._generate_ok({"record_type": "clinical_note", "count": 1})

        # Validate the shape before indexing so a malformed body fails with a
        # clear assertion rather than an IndexError/KeyError/TypeError.
        assert isinstance(data, list)
        assert len(data) == 1
        record = data[0]
        assert "content" in record
        content = record["content"]
        assert isinstance(content, str), "Content is not a string"

        # Check content length
        assert len(content) > 100, "Content too short"

        # Check for common medical terms
        medical_terms = ["patient", "diagnosis", "treatment", "symptoms"]
        content_lower = content.lower()
        assert any(term in content_lower for term in medical_terms), "Missing medical terminology"

        print("Content Quality Checks Passed")
        print(f"Content Length: {len(content)} characters")

def main():
    """Run this module's tests through pytest and return pytest's exit code.

    Calling the test methods directly would bypass ``setup_class`` (the wait
    for the model) and the ``parametrize`` matrix. It would also escape
    pytest's skip/fail handling: ``pytest.skip`` and ``pytest.fail`` raise
    outcome exceptions that crash a plain script. Delegating to
    ``pytest.main`` runs the suite exactly as ``pytest`` would. The ``-s``
    flag keeps the tests' diagnostic prints visible.
    """
    print("Starting API Tests...")
    print(f"Testing against: {BASE_URL}")
    print("=" * 50)

    exit_code = pytest.main([__file__, "-v", "-s"])

    if exit_code == 0:
        print("\nAll tests completed successfully!")
    else:
        print(f"\nTests finished with failures (pytest exit code {int(exit_code)}).")
    print("=" * 50)
    return exit_code


if __name__ == "__main__":
    # Propagate pytest's exit status so CI/shell callers can detect failures.
    raise SystemExit(main())