theaniketgiri committed
Commit 373e5ff · 1 Parent(s): 3f61e65
Files changed (4)
  1. README.md +63 -25
  2. app.py +129 -32
  3. pytest.ini +10 -0
  4. test_backend.py +209 -0
README.md CHANGED
@@ -7,42 +7,80 @@ sdk: docker
 app_port: 7860
 ---

-# Synthex Medical Text Generator
-
-A synthetic medical text generator built with FastAPI and Hugging Face Transformers.
-
-## Features
-
-- Generate synthetic medical text data
-- Multiple record types:
-  - Clinical Notes
-  - Discharge Summaries
-  - Lab Reports
-  - Prescriptions
-- HIPAA-compliant fictional data
-- RESTful API endpoints
-
-## API Endpoints
-
-- `GET /`: Get API information
-- `POST /generate`: Generate medical records
-- `GET /health`: Health check endpoint
-
-## Example Usage
-
+# Synthex Backend
+
+FastAPI backend for the Synthex medical text generation service.
+
+## Project Structure
+
+```
+backend/
+├── app/                  # Main application code
+│   ├── api/              # API endpoints
+│   ├── core/             # Core functionality
+│   ├── models/           # Database models
+│   └── services/         # Business logic
+├── tests/                # Test files
+├── Dockerfile            # Docker configuration
+├── requirements.txt      # Production dependencies
+├── requirements-dev.txt  # Development dependencies
+└── README.md             # This file
+```
+
+## Setup
+
+1. Create a virtual environment:
+```bash
+python -m venv venv
+source venv/bin/activate  # On Windows: venv\Scripts\activate
+```
+
+2. Install dependencies:
+```bash
+pip install -r requirements.txt
+pip install -r requirements-dev.txt  # For development
+```
+
+3. Run the application:
+```bash
+uvicorn app.main:app --reload
+```
+
+## Testing
+
+1. Run all tests:
 ```bash
-# Generate a clinical note
-curl -X POST "https://theaniketgiri-synthex.hf.space/generate" \
-     -H "Content-Type: application/json" \
-     -d '{"record_type": "clinical_note", "count": 1}'
+python run_tests.py
 ```

-## Technical Details
-
-- Built with FastAPI
-- Uses Bio_ClinicalBERT model from Hugging Face
-- Docker container with Python 3.9
-- Exposed on port 7860
-
+2. Run backend API tests:
+```bash
+python test_backend.py
+```
+
+3. Run linters:
+```bash
+python run_linters.py
+```
+
+## API Endpoints
+
+- `GET /health`: Health check endpoint
+- `POST /generate`: Generate medical records
+  - Parameters:
+    - `record_type`: Type of record to generate
+    - `count`: Number of records to generate
+
+## Development
+
+- Use `requirements-dev.txt` for development dependencies
+- Run linters before committing
+- Write tests for new features
+- Follow PEP 8 style guide
+
+## Deployment
+
+The backend is deployed on Hugging Face Spaces. The Dockerfile is configured for this deployment.
+
 ## License
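
The endpoints documented above can be exercised directly once the server is up; a minimal sketch using `requests`, with the Space URL taken from the previous README (use `http://127.0.0.1:8000` for a local `uvicorn` run):

```python
# Illustrative client call, not part of this commit.
import requests

BASE_URL = "https://theaniketgiri-synthex.hf.space"  # or "http://127.0.0.1:8000" locally

# Health check: reports whether the model has finished loading.
print(requests.get(f"{BASE_URL}/health").json())

# Generate one clinical note; the payload fields match the parameters listed above.
payload = {"record_type": "clinical_note", "count": 1}
response = requests.post(f"{BASE_URL}/generate", json=payload)
print(response.status_code, response.json())
```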
app.py CHANGED
@@ -1,11 +1,19 @@
 from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel
-from typing import List, Optional
+from pydantic import BaseModel, Field, validator
+from typing import List, Optional, Literal
 from datetime import datetime
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 import json
+import os
+import logging
+import time
+from huggingface_hub import snapshot_download
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)

 app = FastAPI(
     title="Synthex Medical Text Generator",
@@ -25,18 +33,78 @@ app.add_middleware(
 # Initialize model and tokenizer
 model = None
 tokenizer = None
+MODEL_LOADED = False
+MODEL_LOADING = False
+
+def download_model_with_retry(model_name: str, max_retries: int = 3, retry_delay: int = 60):
+    """Download model with retry logic"""
+    for attempt in range(max_retries):
+        try:
+            logger.info(f"Downloading model (attempt {attempt + 1}/{max_retries})...")
+            # Download model files first
+            snapshot_download(
+                repo_id=model_name,
+                local_files_only=False,
+                resume_download=True
+            )
+            return True
+        except Exception as e:
+            logger.error(f"Download attempt {attempt + 1} failed: {str(e)}")
+            if attempt < max_retries - 1:
+                logger.info(f"Waiting {retry_delay} seconds before retrying...")
+                time.sleep(retry_delay)
+            else:
+                raise

 def load_model():
-    global model, tokenizer
-    if model is None or tokenizer is None:
-        model_name = "emilyalsentzer/Bio_ClinicalBERT"
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
-        model = AutoModelForCausalLM.from_pretrained(model_name)
-    return model, tokenizer
+    global model, tokenizer, MODEL_LOADED, MODEL_LOADING
+    try:
+        if not MODEL_LOADED and not MODEL_LOADING:
+            MODEL_LOADING = True
+            logger.info("Loading model and tokenizer...")
+            model_name = "emilyalsentzer/Bio_ClinicalBERT"
+
+            # Set environment variable to disable symlinks warning
+            os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
+
+            # Download model first
+            download_model_with_retry(model_name)
+
+            # Load tokenizer
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_name,
+                local_files_only=True
+            )
+            logger.info("Tokenizer loaded successfully")
+
+            # Load model
+            model = AutoModelForCausalLM.from_pretrained(
+                model_name,
+                local_files_only=True
+            )
+            logger.info("Model loaded successfully")
+            MODEL_LOADED = True
+            MODEL_LOADING = False
+        return model, tokenizer
+    except Exception as e:
+        MODEL_LOADING = False
+        logger.error(f"Error loading model: {str(e)}")
+        raise HTTPException(
+            status_code=503,
+            detail="Model loading failed. Please try again later."
+        )

 class GenerateRequest(BaseModel):
-    record_type: str
-    count: int = 1
+    record_type: Literal["clinical_note", "discharge_summary", "lab_report", "prescription"]
+    count: int = Field(gt=0, le=10, default=1)
+
+    @validator('count')
+    def validate_count(cls, v):
+        if v <= 0:
+            raise ValueError("Count must be greater than 0")
+        if v > 10:
+            raise ValueError("Count cannot exceed 10")
+        return v

 class MedicalRecord(BaseModel):
     type: str
@@ -58,32 +126,61 @@ async def generate_records(request: GenerateRequest):

         records = []
         for i in range(request.count):
-            # Generate text using the model
-            input_text = f"Generate a {request.record_type}:"
-            inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
-            outputs = model.generate(
-                inputs["input_ids"],
-                max_length=200,
-                num_return_sequences=1,
-                temperature=0.7,
-                top_p=0.9,
-                do_sample=True
-            )
-            generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-            # Create record
-            record = MedicalRecord(
-                type=request.record_type,
-                content=generated_text,
-                generated_at=datetime.now().isoformat()
-            )
-            records.append(record)
+            try:
+                # Generate text using the model
+                input_text = f"Generate a {request.record_type}:"
+                inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
+                outputs = model.generate(
+                    inputs["input_ids"],
+                    max_length=200,
+                    num_return_sequences=1,
+                    temperature=0.7,
+                    top_p=0.9,
+                    do_sample=True
+                )
+                generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+                # Create record
+                record = MedicalRecord(
+                    type=request.record_type,
+                    content=generated_text,
+                    generated_at=datetime.now().isoformat()
+                )
+                records.append(record)
+            except Exception as e:
+                logger.error(f"Error generating record {i+1}: {str(e)}")
+                raise HTTPException(
+                    status_code=500,
+                    detail=f"Error generating record: {str(e)}"
+                )

         return records

     except Exception as e:
-        raise HTTPException(status_code=500, detail=str(e))
+        logger.error(f"Error in generate_records: {str(e)}")
+        raise HTTPException(
+            status_code=500,
+            detail=f"Error generating records: {str(e)}"
+        )

 @app.get("/health")
 def health_check():
-    return {"status": "healthy"}
+    try:
+        # Try to load model if not loaded
+        if not MODEL_LOADED and not MODEL_LOADING:
+            load_model()
+        return {
+            "status": "healthy" if MODEL_LOADED else "loading",
+            "timestamp": datetime.now().isoformat(),
+            "model_loaded": MODEL_LOADED,
+            "model_loading": MODEL_LOADING
+        }
+    except Exception as e:
+        logger.error(f"Health check failed: {str(e)}")
+        return {
+            "status": "unhealthy",
+            "timestamp": datetime.now().isoformat(),
+            "model_loaded": MODEL_LOADED,
+            "model_loading": MODEL_LOADING,
+            "error": str(e)
+        }
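
The tightened `GenerateRequest` model above is what produces the 422 responses asserted in `test_backend.py`. A standalone sketch of that behaviour, assuming pydantic v1 (the version implied by the `@validator` decorator used here):

```python
# Illustrative only: shows how the new request constraints reject bad input.
from typing import Literal
from pydantic import BaseModel, Field, ValidationError

class GenerateRequest(BaseModel):
    record_type: Literal["clinical_note", "discharge_summary", "lab_report", "prescription"]
    count: int = Field(gt=0, le=10, default=1)

for bad in ({"record_type": "invalid_type", "count": 1},
            {"record_type": "clinical_note", "count": 0}):
    try:
        GenerateRequest(**bad)
    except ValidationError as exc:
        # FastAPI surfaces these errors as a 422 response under "detail".
        print(exc.errors())
```

Note that `Field(gt=0, le=10)` already enforces the same bounds the explicit `validate_count` validator checks, so the validator is effectively a belt-and-braces duplicate.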
pytest.ini ADDED
@@ -0,0 +1,10 @@
+[pytest]
+testpaths = .
+python_files = test_*.py
+python_classes = Test*
+python_functions = test_*
+addopts = -v --tb=short
+log_cli = true
+log_cli_level = INFO
+log_cli_format = %(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)
+log_cli_date_format = %Y-%m-%d %H:%M:%S
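
With this file at the project root, a bare `pytest` invocation (or `pytest.main()` from Python) picks up the verbosity, short-traceback, and live-log settings automatically; a minimal sketch:

```python
# Programmatic equivalent of running `pytest` next to pytest.ini;
# -v, --tb=short and the log_cli settings come from the ini file, so no flags are needed.
import sys
import pytest

sys.exit(pytest.main([]))
```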
test_backend.py ADDED
@@ -0,0 +1,209 @@
+import requests
+import json
+import pytest
+from typing import Dict, List
+import os
+from datetime import datetime
+import time
+
+# Base URLs for different environments
+LOCAL_URL = "http://127.0.0.1:8000"
+PROD_URL = "https://theaniketgiri-synthex.hf.space"
+
+# Use environment variable to determine which URL to use
+BASE_URL = os.getenv("API_URL", LOCAL_URL)
+
+def wait_for_model_loading(max_retries=10, delay=30):
+    """Wait for model to load before running tests"""
+    for i in range(max_retries):
+        try:
+            response = requests.get(f"{BASE_URL}/health")
+            data = response.json()
+            print(f"\nHealth check response: {json.dumps(data, indent=2)}")
+
+            if data.get("model_loaded", False):
+                return True
+            elif data.get("model_loading", False):
+                print(f"Model is still loading, attempt {i+1}/{max_retries}")
+            else:
+                print(f"Model not loaded yet, attempt {i+1}/{max_retries}")
+            time.sleep(delay)
+        except Exception as e:
+            print(f"Error checking health: {str(e)}")
+            time.sleep(delay)
+    return False
+
+class TestBackendAPI:
+    @classmethod
+    def setup_class(cls):
+        """Setup before running tests"""
+        if not wait_for_model_loading():
+            pytest.skip("Model failed to load within timeout")
+
+    def test_health(self):
+        """Test the health check endpoint"""
+        response = requests.get(f"{BASE_URL}/health")
+        assert response.status_code == 200
+        data = response.json()
+        assert "status" in data
+        assert data["status"] in ["healthy", "unhealthy"]
+        assert "timestamp" in data
+        assert "model_loaded" in data
+        print(f"\n=== Health Check ===")
+        print(f"Status: {data['status']}")
+        print(f"Model Loaded: {data['model_loaded']}")
+        print(f"Timestamp: {data['timestamp']}")
+
+    @pytest.mark.parametrize("record_type", [
+        "clinical_note",
+        "discharge_summary",
+        "lab_report",
+        "prescription"
+    ])
+    def test_generate_single_record(self, record_type: str):
+        """Test generating a single record of each type"""
+        url = f"{BASE_URL}/generate"
+        payload = {
+            "record_type": record_type,
+            "count": 1
+        }
+
+        print(f"\n=== Generating {record_type} ===")
+        response = requests.post(url, json=payload)
+
+        if response.status_code == 503:
+            pytest.skip("Model not loaded")
+        elif response.status_code == 500:
+            error = response.json()
+            pytest.fail(f"Generation failed: {error.get('detail', 'Unknown error')}")
+
+        assert response.status_code == 200
+
+        data = response.json()
+        assert isinstance(data, list)
+        assert len(data) == 1
+
+        record = data[0]
+        print(f"Generated Record:")
+        print(json.dumps(record, indent=2))
+
+        # Validate record structure
+        assert "type" in record
+        assert record["type"] == record_type
+        assert "content" in record
+        assert "generated_at" in record
+
+    def test_generate_multiple_records(self):
+        """Test generating multiple records"""
+        url = f"{BASE_URL}/generate"
+        payload = {
+            "record_type": "clinical_note",
+            "count": 3
+        }
+
+        print("\n=== Generating Multiple Records ===")
+        response = requests.post(url, json=payload)
+
+        if response.status_code == 503:
+            pytest.skip("Model not loaded")
+        elif response.status_code == 500:
+            error = response.json()
+            pytest.fail(f"Generation failed: {error.get('detail', 'Unknown error')}")
+
+        assert response.status_code == 200
+
+        data = response.json()
+        assert isinstance(data, list)
+        assert len(data) == 3
+
+        print(f"Generated {len(data)} records")
+        for i, record in enumerate(data, 1):
+            print(f"\nRecord {i}:")
+            print(json.dumps(record, indent=2))
+
+    def test_invalid_record_type(self):
+        """Test error handling for invalid record type"""
+        url = f"{BASE_URL}/generate"
+        payload = {
+            "record_type": "invalid_type",
+            "count": 1
+        }
+
+        print("\n=== Testing Invalid Record Type ===")
+        response = requests.post(url, json=payload)
+        assert response.status_code == 422  # FastAPI validation error
+
+        error = response.json()
+        assert "detail" in error
+        print(f"Error: {error['detail']}")
+
+    def test_invalid_count(self):
+        """Test error handling for invalid count"""
+        url = f"{BASE_URL}/generate"
+        payload = {
+            "record_type": "clinical_note",
+            "count": 0
+        }
+
+        print("\n=== Testing Invalid Count ===")
+        response = requests.post(url, json=payload)
+        assert response.status_code == 422  # FastAPI validation error
+
+        error = response.json()
+        assert "detail" in error
+        print(f"Error: {error['detail']}")
+
+    def test_record_content_quality(self):
+        """Test the quality of generated record content"""
+        url = f"{BASE_URL}/generate"
+        payload = {
+            "record_type": "clinical_note",
+            "count": 1
+        }
+
+        print("\n=== Testing Record Content Quality ===")
+        response = requests.post(url, json=payload)
+
+        if response.status_code == 503:
+            pytest.skip("Model not loaded")
+        elif response.status_code == 500:
+            error = response.json()
+            pytest.fail(f"Generation failed: {error.get('detail', 'Unknown error')}")
+
+        assert response.status_code == 200
+
+        data = response.json()
+        record = data[0]
+
+        # Check content length
+        assert len(record["content"]) > 100, "Content too short"
+
+        # Check for common medical terms
+        medical_terms = ["patient", "diagnosis", "treatment", "symptoms"]
+        content_lower = record["content"].lower()
+        assert any(term in content_lower for term in medical_terms), "Missing medical terminology"
+
+        print("Content Quality Checks Passed")
+        print(f"Content Length: {len(record['content'])} characters")
+
+def main():
+    """Run all tests"""
+    print("Starting API Tests...")
+    print(f"Testing against: {BASE_URL}")
+    print("=" * 50)
+
+    test_suite = TestBackendAPI()
+
+    # Run all tests
+    test_suite.test_health()
+    test_suite.test_generate_single_record("clinical_note")
+    test_suite.test_generate_multiple_records()
+    test_suite.test_invalid_record_type()
+    test_suite.test_invalid_count()
+    test_suite.test_record_content_quality()
+
+    print("\nAll tests completed successfully!")
+    print("=" * 50)
+
+if __name__ == "__main__":
+    main()
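
Because `BASE_URL` falls back to the local server, pointing the suite at the deployed Space only requires setting `API_URL` before pytest starts; an illustrative sketch reusing the `PROD_URL` value defined above:

```python
# Run test_backend.py against the deployed Space instead of a local uvicorn server.
# test_backend.py reads the target via os.getenv("API_URL", LOCAL_URL).
import os
import sys
import pytest

os.environ["API_URL"] = "https://theaniketgiri-synthex.hf.space"
sys.exit(pytest.main(["test_backend.py"]))
```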