MedGemma Fine-tuned for Medical Diagnosis
Model Description
This is a fine-tuned version of MedGemma (based on Google Gemma 2 2B) specialized for classifying medical diagnoses from patient symptom descriptions.
Important Note: While MedGemma is a multimodal medical model capable of processing both text and medical images, this fine-tuned version is optimized as a text-only causal language model for symptom-to-diagnosis classification.
The model was fine-tuned using LoRA (Low-Rank Adaptation) on a dataset of 10,000 medical cases with symptom-diagnosis pairs, achieving 98.9% training accuracy at checkpoint-1100.
Key Features
- High Accuracy: 98.9% training accuracy
- Medical Focus: Fine-tuned from the multimodal MedGemma foundation
- Efficient: Uses LoRA adapters (~35M parameters) on the Gemma-2-2B base
- Compact: One of the smallest models in the ensemble at 2B parameters
- Specialized: Text-only causal modeling for diagnostic classification
Model Background
Multimodal Origins
- Original Architecture: MedGemma (multimodal medical model)
- Base Model: Google Gemma 2 2B
- Heritage: Built on Google's Gemma model family, which shares research lineage with Gemini
- Capabilities: Originally supports text, images, and multimodal inputs
This Fine-Tuned Version
- Task: Text-based symptom classification only
- Mode: Causal language modeling (text-to-text)
- Optimization: Focused on diagnostic prediction from text symptoms
- Modality: Text input → Text output (diagnosis)
Training Details
Dataset
- Size: 10,000 symptom-diagnosis pairs
- Format: Text-based patient symptoms → medical diagnosis (an illustrative record is sketched below)
- Modality: Text-only (single modality from multimodal base)
- Train/Validation Split: Standard split with held-out validation
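The exact data schema is not published with this card. A hypothetical record illustrating the symptoms-to-diagnosis format (field names and the example diagnosis are illustrative, not taken from the actual dataset) might look like this:

# Hypothetical training record; field names and values are illustrative only.
example_record = {
    "symptoms": "fever, cough, fatigue, body aches, headache",
    "diagnosis": "Influenza",
}

# Each record is assumed to be rendered into the same prompt template used at
# inference time (see the Inference Example below), with the diagnosis appended
# as the completion target.
training_text = (
    "Below is a patient case with symptoms. Provide ONLY the most likely diagnosis.\n"
    f"### Symptoms:\n{example_record['symptoms']}\n"
    f"### Diagnosis:\n{example_record['diagnosis']}"
)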
Training Configuration
LoRA Hyperparameters (see the LoraConfig sketch after this list):
- LoRA Rank (r): 16
- LoRA Alpha: 32
- LoRA Dropout: 0.05
- Target Modules: q_proj, v_proj
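These settings map directly onto a PEFT LoraConfig. The following is a minimal sketch assuming a standard PEFT setup, not the exact training script:

from peft import LoraConfig

# LoRA configuration matching the hyperparameters listed above.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)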
Training Hyperparameters (see the TrainingArguments sketch after this list):
- Learning Rate: 2e-4
- Batch Size: 4 per device
- Gradient Accumulation Steps: 4
- Effective Batch Size: 16
- Number of Epochs: 3
- Warmup Steps: 100
- Optimizer: AdamW (8-bit)
- Weight Decay: 0.01
- Max Gradient Norm: 1.0
- LR Scheduler: Linear with warmup
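Expressed as Hugging Face TrainingArguments, the listed settings translate roughly as follows (a sketch; output_dir and logging_steps are placeholders, not values from the original run):

from transformers import TrainingArguments

# Approximate reconstruction of the listed training hyperparameters.
training_args = TrainingArguments(
    output_dir="./medgemma-symptom-diagnosis",  # placeholder path
    learning_rate=2e-4,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,    # effective batch size of 16
    num_train_epochs=3,
    warmup_steps=100,
    optim="adamw_bnb_8bit",           # 8-bit AdamW via bitsandbytes
    weight_decay=0.01,
    max_grad_norm=1.0,
    lr_scheduler_type="linear",
    fp16=True,                        # mixed-precision training
    logging_steps=50,                 # placeholder
)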
Training Environment (an end-to-end setup sketch follows this list):
- GPU: NVIDIA A100 40GB
- Precision: Mixed FP16
- Quantization: 4-bit NF4 with double quantization
- Framework: Hugging Face Transformers 4.45.0 + PEFT 0.12.0
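Putting these pieces together, a plausible end-to-end training setup under the settings above could look like the sketch below. This is not the author's exact script; lora_config and training_args are the sketches above, and train_dataset is assumed to be a tokenized dataset with labels.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, Trainer
from peft import get_peft_model, prepare_model_for_kbit_training

# 4-bit NF4 base model with double quantization, as listed above.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)
base = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b",
    quantization_config=bnb_config,
    device_map="auto",
)

# Prepare the quantized model for LoRA training and attach the adapters.
base = prepare_model_for_kbit_training(base)
model = get_peft_model(base, lora_config)

trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
# trainer.train()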
Training Results
Final Checkpoint (checkpoint-1100):
- Training Loss: ~0.05
- Training Accuracy: 98.9%
- Total Training Steps: 1100
- Training Time: ~3 hours
Model Architecture
- Base Model: google/gemma-2-2b (~2B parameters)
- Original Capability: Multimodal (text + images)
- Fine-Tuned Mode: Causal language model (text-only)
- LoRA Adapters: ~35M trainable parameters
- Model Size: ~35 MB (LoRA adapters only)
- Architecture Type: Gemma decoder-only transformer with medical specialization
Intended Use
Primary Use Cases
- ✅ Medical diagnosis prediction from text symptom descriptions
- ✅ Clinical decision support systems (with medical oversight)
- ✅ Medical education and training
- ✅ Healthcare AI research
- ✅ Ensemble medical diagnosis systems
- ✅ Lightweight deployment scenarios (one of the smaller models in the ensemble)
Out of Scope
- ❌ Multimodal medical imaging analysis (requires base model capabilities)
- ❌ Direct patient care without medical professional oversight
- ❌ Emergency medical decisions
- ❌ Replacement for professional medical judgment
Usage
Installation
pip install transformers peft torch bitsandbytes accelerate
Loading the Model
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
import torch
# Quantization config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)
# Load base model (Gemma 2 2B)
base_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b",
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)
# Load fine-tuned adapter
model = PeftModel.from_pretrained(
    base_model,
    "Sugandha-Chauhan/MedGemma-SymptomDiagnosis"
)
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
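Optionally, the adapter can be merged into the base model for deployment without a runtime PEFT dependency. This step is not described on this card; the sketch below assumes the base model is reloaded in FP16, since merging into a 4-bit quantized model is not supported in the same way:

import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

# Reload the base model in half precision (merging requires unquantized weights).
base_fp16 = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b", torch_dtype=torch.float16, device_map="auto"
)
merged = PeftModel.from_pretrained(
    base_fp16, "Sugandha-Chauhan/MedGemma-SymptomDiagnosis"
).merge_and_unload()
merged.save_pretrained("./medgemma-symptom-diagnosis-merged")  # placeholder path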
Inference Example
# Suppress warnings
import warnings
warnings.filterwarnings('ignore')
import os
os.environ['TRANSFORMERS_VERBOSITY'] = 'error'
# Patient symptoms (text-only input)
symptoms = "fever, cough, fatigue, body aches, headache"
# Format prompt
prompt = f"""Below is a patient case with symptoms. Provide ONLY the most likely diagnosis.
### Symptoms:
{symptoms}
### Diagnosis:
"""
# Generate
inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
inputs = {k: v.to(model.device) for k, v in inputs.items()}
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=50,
        do_sample=False  # greedy decoding; temperature is ignored when sampling is disabled
    )
# Decode
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
diagnosis = result[len(prompt):].strip().split('\n')[0]
print(f"Diagnosis: {diagnosis}")
Performance Metrics
| Metric | Value |
|---|---|
| Training Loss | ~0.05 |
| Training Accuracy | 98.9% |
| Training Steps | 1100 |
| Checkpoint | checkpoint-1100 |
| Base Model Size | 2B parameters |
| Adapter Size | 35 MB |
Model Comparison
Part of a 4-model ensemble for medical diagnosis:
| Model | Base | Size | Adapter Size | Training Acc | Checkpoint | Specialty |
|---|---|---|---|---|---|---|
| BioMistral-7B | BioMistral | 7B | 52 MB | 99.1% | 700 | Medical text |
| MedAlpaca-7B | MedAlpaca | 7B | 64 MB | 99.0% | 600 | Medical LLM |
| MedGemma | Gemma 2 | 2B | 35 MB | 98.9% | 1100 | Multimodal base |
| BioGPT | BioGPT | 1.5B | 12 MB | TBD | 1100 | Biomedical |
Key Advantages:
- ✅ Compact 2B base model for fast inference
- ✅ Built on a multimodal-capable foundation (future expansion possible)
- ✅ Google Gemma heritage with strong language understanding for its size
- ✅ Small deployment footprint (~35 MB adapter)
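How the four models are combined is not specified on this card. A minimal majority-vote sketch, assuming each model is wrapped in a callable that maps a symptom string to a diagnosis string (the wrapper names in the usage comment are illustrative), could look like this:

from collections import Counter

def ensemble_diagnosis(symptoms, predictors):
    """Return the majority-vote diagnosis across several models.

    `predictors` maps a model name to a callable that takes a symptom string
    and returns a diagnosis string (e.g. a wrapper around the inference code above).
    """
    votes = [predict(symptoms) for predict in predictors.values()]
    diagnosis, _count = Counter(votes).most_common(1)[0]
    return diagnosis

# Illustrative usage with hypothetical wrappers for the models in the table above:
# ensemble_diagnosis(
#     "fever, cough, fatigue",
#     {"medgemma": medgemma_predict, "biomistral": biomistral_predict,
#      "medalpaca": medalpaca_predict, "biogpt": biogpt_predict},
# )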
Limitations
- Limited to the diagnostic categories present in the 10,000 training samples
- Text-only fine-tuning: Multimodal capabilities not utilized in this version
- Performance depends on accurate symptom description
- Does not consider patient history, labs, or imaging
- May not perform well on rare conditions
- English language only
- Requires medical professional interpretation
- Smaller base model may have less domain knowledge than 7B models
Future Enhancements
Given the multimodal base architecture, potential future versions could:
- Process medical images alongside text symptoms
- Analyze lab results and diagnostic images
- Integrate multiple data modalities for diagnosis
- Leverage the full multimodal capabilities of the MedGemma foundation
Ethical Considerations
⚠️ Medical AI Ethics:
- Should never replace professional medical judgment
- Requires appropriate medical oversight in clinical settings
- Users must understand model limitations
- Clinical validation needed for real-world deployment
⚠️ Bias Considerations:
- Training data may reflect diagnostic biases
- Performance may vary across demographics
- Regular monitoring recommended for production use
- Smaller model size may amplify biases
Citation
If you use this model, please cite:
@misc{chauhan2024medgemma,
author = {Sugandha Chauhan},
title = {MedGemma Fine-tuned for Medical Diagnosis: Text-based Classification from Multimodal Foundation},
year = {2024},
publisher = {Hugging Face},
howpublished = {\url{https://huggingface.co/Sugandha-Chauhan/MedGemma-SymptomDiagnosis}},
note = {Fine-tuned causal model from multimodal MedGemma base}
}
Acknowledgments
- Base Model: Google Gemma 2 team for the efficient 2B architecture
- MedGemma: Multimodal medical foundation model
- Gemma/Gemini Research: Underlying architecture and multimodal processing capabilities
- Training Framework: Hugging Face Transformers and PEFT libraries
Disclaimer
⚠️ IMPORTANT MEDICAL DISCLAIMER
This AI model is for educational and research purposes only. It is NOT:
- A substitute for professional medical advice, diagnosis, or treatment
- Approved for clinical use without medical oversight
- Intended for emergency medical situations
- A replacement for qualified healthcare providers
Always consult a physician or qualified healthcare provider for medical decisions.
License
Released under the same license as Gemma 2. See base model for details.
Contact
Author: Sugandha Chauhan
Repository: https://huggingface.co/Sugandha-Chauhan/MedGemma-SymptomDiagnosis
Model Portfolio: 4-model ensemble for medical diagnosis
Model Version: 1.0 (Checkpoint-1100)
Last Updated: November 2024
Architecture: Multimodal foundation → Text-only fine-tuning
Part of: Multi-model Medical Diagnosis System