MedGemma Fine-tuned for Medical Diagnosis

Model Description

This is a fine-tuned version of MedGemma (based on Google Gemma 2 2B) specialized for medical diagnosis classification based on patient symptoms.

Important Note: While MedGemma is originally a multimodal model built on the Gemma 3 architecture, capable of processing both text and medical images, this fine-tuned version is optimized as a causal language model specifically for text-based symptom-to-diagnosis classification.

The model was fine-tuned using LoRA (Low-Rank Adaptation) on a dataset of 10,000 medical cases with symptom-diagnosis pairs, achieving 98.9% training accuracy at checkpoint-1100.

Key Features

  • 🎯 High Accuracy: 98.9% training accuracy
  • 🔬 Medical Focus: Fine-tuned from multimodal MedGemma foundation
  • ⚡ Efficient: Uses LoRA adapters (~35M parameters) on Gemma-2-2B base
  • 💾 Compact: Smallest model in the ensemble at 2B parameters
  • 📝 Specialized: Text-only causal modeling for diagnostic classification

Model Background

Multimodal Origins

  • Original Architecture: MedGemma (multimodal medical model)
  • Base Model: Google Gemma 2 2B
  • Gemma Heritage: The original MedGemma builds on the Gemma 3 multimodal architecture
  • Capabilities: Originally supports text, images, and multimodal inputs

This Fine-Tuned Version

  • Task: Text-based symptom classification only
  • Mode: Causal language modeling (text-to-text)
  • Optimization: Focused on diagnostic prediction from text symptoms
  • Modality: Text input → Text output (diagnosis)

Training Details

Dataset

  • Size: 10,000 symptom-diagnosis pairs
  • Format: Text-based patient symptoms → Medical diagnosis
  • Modality: Text-only (single modality from multimodal base)
  • Train/Validation Split: Standard split with held-out validation
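
The exact record schema of the dataset is not published with this card. Purely as an illustration, each training pair could be serialized into the same prompt template used in the Inference Example below, with the gold diagnosis appended as the completion (hypothetical helper, not the original preprocessing code):

# Hypothetical formatting helper; the actual dataset preprocessing is not published.
# It mirrors the prompt template shown in the Inference Example section.
def format_training_example(symptoms: str, diagnosis: str) -> str:
    return (
        "Below is a patient case with symptoms. Provide ONLY the most likely diagnosis.\n\n"
        f"### Symptoms:\n{symptoms}\n\n"
        f"### Diagnosis:\n{diagnosis}"
    )

print(format_training_example("fever, cough, fatigue", "Influenza"))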

Training Configuration

LoRA Hyperparameters:

  • LoRA Rank (r): 16
  • LoRA Alpha: 32
  • LoRA Dropout: 0.05
  • Target Modules: q_proj, v_proj
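
Expressed as a PEFT LoraConfig, these values would look roughly as follows; bias and task_type are typical defaults assumed here, not stated in the card:

from peft import LoraConfig

lora_config = LoraConfig(
    r=16,                                 # LoRA rank
    lora_alpha=32,                        # scaling factor
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # attention query/value projections
    bias="none",                          # assumed default
    task_type="CAUSAL_LM",                # assumed default for text generation
)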

Training Hyperparameters:

  • Learning Rate: 2e-4
  • Batch Size: 4 per device
  • Gradient Accumulation Steps: 4
  • Effective Batch Size: 16
  • Number of Epochs: 3
  • Warmup Steps: 100
  • Optimizer: AdamW (8-bit)
  • Weight Decay: 0.01
  • Max Gradient Norm: 1.0
  • LR Scheduler: Linear with warmup
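
A sketch of the corresponding Hugging Face TrainingArguments; output_dir, logging settings, and the exact 8-bit optimizer variant are assumptions, not values reported from the original run:

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./medgemma-symptom-diagnosis",  # placeholder path
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,              # effective batch size 4 x 4 = 16
    num_train_epochs=3,
    learning_rate=2e-4,
    warmup_steps=100,
    weight_decay=0.01,
    max_grad_norm=1.0,
    lr_scheduler_type="linear",
    optim="adamw_bnb_8bit",                     # 8-bit AdamW via bitsandbytes (assumed variant)
    fp16=True,                                  # mixed precision
    logging_steps=50,                           # placeholder
)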

Training Environment:

  • GPU: NVIDIA A100 40GB
  • Precision: FP16 mixed precision
  • Quantization: 4-bit NF4 with double quantization
  • Framework: Hugging Face Transformers 4.45.0 + PEFT 0.12.0
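
Under these settings, preparing the quantized base model for LoRA training would typically follow the standard QLoRA recipe; the sketch below assumes the usual prepare_model_for_kbit_training step, which the card does not spell out:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import get_peft_model, prepare_model_for_kbit_training

# 4-bit NF4 quantization with double quantization, as listed above
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

base_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b",
    quantization_config=bnb_config,
    device_map="auto",
)

base_model = prepare_model_for_kbit_training(base_model)  # standard QLoRA prep (assumed)
peft_model = get_peft_model(base_model, lora_config)      # lora_config from the sketch above
peft_model.print_trainable_parameters()                   # ~35M trainable parameters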

Training Results

Final Checkpoint (checkpoint-1100):

  • Training Loss: ~0.05
  • Training Accuracy: 98.9%
  • Total Training Steps: 1100
  • Training Time: ~3 hours

Model Architecture

  • Base Model: google/gemma-2-2b (~2B parameters)
  • Original Capability: Multimodal (text + images)
  • Fine-Tuned Mode: Causal language model (text-only)
  • LoRA Adapters: ~35M trainable parameters
  • Model Size: ~35 MB (LoRA adapters only)
  • Architecture Type: Gemma decoder-only transformer with medical specialization

Intended Use

Primary Use Cases

✅ Medical diagnosis prediction from text symptom descriptions
✅ Clinical decision support systems (with medical oversight)
✅ Medical education and training
✅ Healthcare AI research
✅ Ensemble medical diagnosis systems
✅ Lightweight deployment scenarios (smallest model in ensemble)

Out of Scope

❌ Multimodal medical imaging analysis (requires base model capabilities)
❌ Direct patient care without medical professional oversight
❌ Emergency medical decisions
❌ Replacement for professional medical judgment

Usage

Installation

pip install transformers peft torch bitsandbytes accelerate

Loading the Model

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
import torch

# Quantization config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)

# Load base model (Gemma 2 2B)
base_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b",
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)

# Load fine-tuned adapter
model = PeftModel.from_pretrained(
    base_model,
    "Sugandha-Chauhan/MedGemma-SymptomDiagnosis"
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

Inference Example

# Suppress warnings
import warnings
warnings.filterwarnings('ignore')
import os
os.environ['TRANSFORMERS_VERBOSITY'] = 'error'

# Patient symptoms (text-only input)
symptoms = "fever, cough, fatigue, body aches, headache"

# Format prompt
prompt = f"""Below is a patient case with symptoms. Provide ONLY the most likely diagnosis.

### Symptoms:
{symptoms}

### Diagnosis:
"""

# Generate
inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
inputs = {k: v.to(model.device) for k, v in inputs.items()}

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=50,
        do_sample=False  # greedy decoding; temperature has no effect without sampling
    )

# Decode
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
diagnosis = result[len(prompt):].strip().split('\n')[0]

print(f"Diagnosis: {diagnosis}")

Performance Metrics

  • Training Loss: ~0.05
  • Training Accuracy: 98.9%
  • Training Steps: 1100
  • Checkpoint: checkpoint-1100
  • Base Model Size: 2B parameters
  • Adapter Size: 35 MB

Model Comparison

Part of a 4-model ensemble for medical diagnosis:

  • BioMistral-7B: BioMistral 7B base, 52 MB adapter, 99.1% training accuracy, checkpoint-700, specialty: medical text
  • MedAlpaca-7B: MedAlpaca 7B base, 64 MB adapter, 99.0% training accuracy, checkpoint-600, specialty: medical LLM
  • MedGemma (this model): Gemma 2 2B base, 35 MB adapter, 98.9% training accuracy, checkpoint-1100, specialty: multimodal base
  • BioGPT: BioGPT 1.5B base, 12 MB adapter, training accuracy TBD, checkpoint-1100, specialty: biomedical
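
This card does not specify how the four models' outputs are combined. As a minimal illustration, a simple majority vote over per-model prediction functions (such as predict_diagnosis above) could look like this:

from collections import Counter

def ensemble_diagnosis(symptoms: str, predictors) -> str:
    """Majority vote over per-model prediction callables (hypothetical combination rule)."""
    votes = [predict(symptoms) for predict in predictors]
    return Counter(votes).most_common(1)[0][0]

# Example (hypothetical callables, one per fine-tuned model):
# ensemble_diagnosis(symptoms, [biomistral_predict, medalpaca_predict, medgemma_predict, biogpt_predict])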

Key Advantages:

  • ✅ Smallest base model (2B parameters) - fastest inference
  • ✅ Built on a multimodal architecture (future expansion possible)
  • ✅ Google Gemma heritage - advanced language understanding
  • ✅ Compact deployment footprint

Limitations

  • Limited to diagnostic categories in 10K training samples
  • Text-only fine-tuning: Multimodal capabilities not utilized in this version
  • Performance depends on accurate symptom description
  • Does not consider patient history, labs, or imaging
  • May not perform well on rare conditions
  • English language only
  • Requires medical professional interpretation
  • Smaller base model may have less domain knowledge than 7B models

Future Enhancements

Given the multimodal base architecture, potential future versions could:

  • 🔬 Process medical images alongside text symptoms
  • 📊 Analyze lab results and diagnostic images
  • 🩺 Integrate multiple data modalities for diagnosis
  • 📈 Leverage the full MedGemma multimodal capabilities

Ethical Considerations

⚠️ Medical AI Ethics:

  • Should never replace professional medical judgment
  • Requires appropriate medical oversight in clinical settings
  • Users must understand model limitations
  • Clinical validation needed for real-world deployment

⚠️ Bias Considerations:

  • Training data may reflect diagnostic biases
  • Performance may vary across demographics
  • Regular monitoring recommended for production use
  • Smaller model size may amplify biases

Citation

If you use this model, please cite:

@misc{chauhan2024medgemma,
  author = {Sugandha Chauhan},
  title = {MedGemma Fine-tuned for Medical Diagnosis: Text-based Classification from Multimodal Foundation},
  year = {2024},
  publisher = {Hugging Face},
  howpublished = {\url{https://huggingface.co/Sugandha-Chauhan/MedGemma-SymptomDiagnosis}},
  note = {Fine-tuned causal model from multimodal MedGemma base}
}

Acknowledgments

  • Base Model: Google Gemma 2 team for the efficient 2B architecture
  • MedGemma: Multimodal medical foundation model
  • Gemma 3 Architecture: Multimodal processing capabilities of the original MedGemma
  • Training Framework: Hugging Face Transformers and PEFT libraries

Disclaimer

⚠️ IMPORTANT MEDICAL DISCLAIMER

This AI model is for educational and research purposes only. It is NOT:

  • A substitute for professional medical advice, diagnosis, or treatment
  • Approved for clinical use without medical oversight
  • Intended for emergency medical situations
  • A replacement for qualified healthcare providers

Always consult a physician or qualified healthcare provider for medical decisions.

License

Released under the Gemma Terms of Use, the same license as the Gemma 2 base model. See the base model card for details.

Contact

Author: Sugandha Chauhan
Repository: https://huggingface.co/Sugandha-Chauhan/MedGemma-SymptomDiagnosis
Model Portfolio: 4-model ensemble for medical diagnosis


Model Version: 1.0 (Checkpoint-1100)
Last Updated: November 2024
Architecture: Multimodal foundation → Text-only fine-tuning
Part of: Multi-model Medical Diagnosis System
