---
language:
- en
license: gemma
library_name: transformers
pipeline_tag: text-generation
base_model: google/gemma-3-270m
tags:
- gemma3
- trl
- grpo
- rlhf
- sft
- math
- reasoning
- chain-of-thought
- experimental
- colab
- kaggle
---
# Gemma-3-270M GRPO (math + CoT)
⚠️ Experimental training run. This is a personal experiment and not production-ready. The current checkpoint is from ~1200 GRPO steps; accuracy is unstable and may be low. Prompt format and weights may change. Please evaluate carefully before any use.
Small reasoning-tuned variant of **Google’s Gemma-3-270M**.
Two-stage recipe: **SFT** on math prompts with hidden `<think>…</think>` reasoning, then **GRPO** to reinforce *correct final answers* while *discouraging overly long hidden reasoning*.
**Note:** there is **no separate answer tag** in this project. The final answer is emitted as `\boxed{...}` after the closing `</think>`.
> The model can emit `<think>…</think>` blocks. Examples below **strip** them by default.
---
## ✨ What’s inside
- **Base**: `google/gemma-3-270m`
- **Objective**:
  - **SFT**: learn the prompt format + produce a boxed final answer.
  - **GRPO**: reward `= 1.0` if the boxed answer matches ground truth (numeric or `sympy`-equivalent), else `0.0`, **minus** a small penalty proportional to the number of tokens inside `<think>…</think>`, with KL regularization to the SFT reference.
---
## 🧠 Prompt & output format
Training/eval wrapper (no separate answer tag):
```text
<think>
…(internal scratch work)…
</think>
\boxed{FINAL_ANSWER}
```
- SFT builder resembled `format_sft_example(question, reasoning, final_answer)`; a sketch is shown below.
- RL prompts contain only the wrapped question (no reasoning, no answer) and expect the model to write its reasoning followed by `\boxed{...}`.
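For reference, a minimal sketch of what such a builder could look like. The `<question>` wrapper tag and the exact spacing are assumptions, not the verbatim training code:
```python
def format_sft_example(question: str, reasoning: str, final_answer: str) -> str:
    # One SFT training example: wrapped question, hidden reasoning, boxed answer.
    # The <question> tag name is an assumed placeholder; <think> matches the
    # format described above.
    return (
        f"<question>\n{question}\n</question>\n"
        f"<think>\n{reasoning}\n</think>\n"
        f"\\boxed{{{final_answer}}}"
    )
```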
---
## 🚀 Quickstart
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch, re

MODEL_ID = "nirav-madhani/gemma3-270m-grpo-math"  # change if different

tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
    attn_implementation="eager",
).to(device).eval()

BOX_RE = re.compile(r"\\boxed\{([^{}]+)\}")

def generate(question, max_new_tokens=160, temperature=0.2, top_p=0.95, return_boxed=True, show_think=False):
    # Prompt wrapper used during training; the <question> tag name is an assumption,
    # adjust it to match your SFT data. The model continues with <think>…</think>
    # reasoning followed by \boxed{...}.
    prompt = f"<question>\n{question}\n</question>\n<think>\n"
    inputs = tok(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=temperature > 0,
            temperature=temperature,
            top_p=top_p,
            eos_token_id=tok.eos_token_id,
            pad_token_id=tok.pad_token_id,
            use_cache=True,
        )
    text = tok.decode(out[0], skip_special_tokens=True)
    if not show_think:
        # Drop the hidden reasoning block before returning.
        text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.S)
    if return_boxed:
        m = BOX_RE.search(text)
        return m.group(1).strip() if m else text.strip()
    return text

print(generate("If 3x + 5 = 17, what is x?"))
```
---
## 🏗️ Training recipe (summary)
**Stage 1 — SFT**
- `transformers==4.55.x`, `trl==0.21.0`
- Tokenizer: Gemma-3 fast; `pad_token = eos_token`
- `Trainer` + `DataCollatorForLanguageModeling`
- `max_seq_length = prompt_len + completion_len`; small per-device batch with grad accumulation; BF16/FP16 on Ampere, else FP32.
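A minimal sketch of that setup. The hyperparameters are placeholders, `tok`/`model` are assumed to be the Gemma-3 tokenizer and base model, and `tokenized_ds` is assumed to be a pre-tokenized dataset of `format_sft_example` outputs:
```python
import torch
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling

# Causal-LM collator: labels mirror the input ids (no masked-LM objective).
collator = DataCollatorForLanguageModeling(tokenizer=tok, mlm=False)

args = TrainingArguments(
    output_dir="sft-out",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=2e-5,
    num_train_epochs=1,
    bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
    logging_steps=10,
    save_steps=200,
    save_total_limit=3,
)

trainer = Trainer(
    model=model,                 # base google/gemma-3-270m
    args=args,
    train_dataset=tokenized_ds,  # pre-tokenized SFT examples (assumed)
    data_collator=collator,
)
trainer.train()
```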
**Stage 2 — RL (GRPO)**
- TRL **GRPO** (`trl==0.21.0`)
- Policy & reference initialized from **SFT**
- **Reward** (sketched after this list):
  - `+1.0` if `\boxed{...}` equals ground truth (float tolerance or `sympy` equivalence)
  - `−λ * (#tokens inside <think>…</think>)`
- Knobs (fit ~15 GB VRAM; tune per GPU):
  - `num_generations (K)`: 2–8
  - `per_device_train_batch_size`: 1–2
  - `gradient_accumulation_steps`: 4–8
    (*ensure* `batch * accum * world_size` is **divisible by K**)
  - `max_prompt_length`: ~160–256
  - `max_completion_length`: ~128–192
  - `beta` (KL): ~0.02
  - `attn_implementation="eager"`; enable `use_cache=True` if you have headroom
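A condensed sketch of the reward function and trainer wiring under these knobs. `LAMBDA`, the dataset column names (`prompt`, `answer`), and the output paths are assumptions; `tok` is the tokenizer from the quickstart:
```python
import re
import sympy
from trl import GRPOConfig, GRPOTrainer

LAMBDA = 1e-3  # per-token penalty on hidden reasoning (placeholder value)

def answers_match(pred: str, gt: str, tol: float = 1e-6) -> bool:
    # Numeric comparison first, then a sympy-equivalence fallback.
    try:
        return abs(float(pred) - float(gt)) <= tol
    except (ValueError, TypeError):
        pass
    try:
        return sympy.simplify(sympy.sympify(pred) - sympy.sympify(gt)) == 0
    except Exception:
        return pred.strip() == gt.strip()

def boxed_reward(completions, answer, **kwargs):
    # TRL passes completions plus extra dataset columns (here: `answer`).
    rewards = []
    for completion, gt in zip(completions, answer):
        m = re.search(r"\\boxed\{([^{}]+)\}", completion)
        correct = 1.0 if (m and answers_match(m.group(1), str(gt))) else 0.0
        think = re.search(r"<think>(.*?)</think>", completion, flags=re.S)
        n_think = len(tok(think.group(1)).input_ids) if think else 0
        rewards.append(correct - LAMBDA * n_think)
    return rewards

config = GRPOConfig(
    output_dir="grpo-out",
    num_generations=4,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    max_prompt_length=192,
    max_completion_length=160,
    beta=0.02,
)
trainer = GRPOTrainer(
    model="sft-out",            # policy (and reference) initialized from the SFT checkpoint
    reward_funcs=boxed_reward,
    args=config,
    train_dataset=rl_dataset,   # dataset with `prompt` and `answer` columns (assumed)
)
trainer.train()
```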
**Checkpointing**
- Checkpoints saved every *N* steps; keep last 3; persisted to Google Drive or `/kaggle/working`.
- Inference loader grabs **newest `checkpoint-XXXX/`**, else RL root → SFT → base.
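A small sketch of that fallback logic (directory names are placeholders):
```python
from pathlib import Path

def latest_checkpoint(root: str) -> str:
    # Return the newest checkpoint-XXXX/ under `root`, or `root` itself if none exist.
    ckpts = sorted(Path(root).glob("checkpoint-*"), key=lambda p: int(p.name.split("-")[-1]))
    return str(ckpts[-1]) if ckpts else root

def resolve_model_path(rl_dir="grpo-out", sft_dir="sft-out", base="google/gemma-3-270m") -> str:
    # Prefer the newest RL checkpoint, then the SFT output, then the base model.
    for candidate in (rl_dir, sft_dir):
        if Path(candidate).exists():
            return latest_checkpoint(candidate)
    return base
```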
---
## 📊 Evaluation
*Early-stage.* Evaluate on your math split or GSM-style test by extracting `\boxed{…}` and checking numeric or `sympy`-equivalence. Track:
- reward mean/std, exact-match of final answers,
- KL vs. reference,
- output length and “think” token counts.
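A minimal exact-match loop, reusing the `generate` helper from the quickstart and the `answers_match` check sketched above (both names come from this card's snippets, not a packaged API):
```python
def exact_match(pairs):
    # `pairs` is a list of (question, ground_truth) tuples.
    hits = 0
    for question, gt in pairs:
        pred = generate(question)  # returns the boxed answer by default
        hits += bool(answers_match(str(pred), str(gt)))
    return hits / max(len(pairs), 1)

print(f"exact match: {exact_match([('If 3x + 5 = 17, what is x?', '4')]):.2%}")
```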
---
## ⚖️ License & usage
- **Base license**: Gemma models are under the **Gemma license**. This derivative remains subject to those terms.
- Repo metadata sets `license: gemma`. Review Gemma’s terms before commercial use/redistribution.
---
## 🔒 Limitations & risks
- 270M params is **very small**; expect brittleness outside narrow math tasks.
- Hidden reasoning can be wrong; we **hide** it by default.
- No built-in safety filtering.
---
## 🧩 Repro notes
- Colab/Kaggle friendly; use `attn_implementation="eager"` to avoid FA mismatches.
- GRPO progress bar’s “Training Loss” can be `0.0` — monitor **reward/KL/length**.
- Env tips:
  - set `TRANSFORMERS_NO_TORCHVISION=1`,
  - ensure compatible `numpy`/`scikit-learn` on Kaggle if `transformers.generation` pulls sklearn.
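If you apply the first tweak in a notebook, set the variable before importing `transformers`:
```python
import os
os.environ["TRANSFORMERS_NO_TORCHVISION"] = "1"  # must be set before the import below

import transformers
```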
---
## 🙌 Acknowledgements
- Base weights: **Google** `google/gemma-3-270m`
- RL training: **TRL** (`trl==0.21.0`)
---
## 📣 Citation
```bibtex
@software{nirav_gemma3_270m_grpo_math_2025,
  title  = {Gemma-3-270M GRPO (math + CoT)},
  author = {Nirav Madhani},
  year   = {2025},
  url    = {https://huggingface.co/nirav-madhani/gemma3-270m-grpo-math}
}
```