---
language:
- en
license: gemma
library_name: transformers
pipeline_tag: text-generation
base_model: google/gemma-3-270m
tags:
- gemma3
- trl
- grpo
- rlhf
- sft
- math
- reasoning
- chain-of-thought
- experimental
- colab
- kaggle
---

# Gemma-3-270M GRPO (math + CoT)
⚠️ Experimental training run. This is a personal experiment and not production-ready. The current checkpoint is from ~1200 GRPO steps; accuracy is unstable and may be low. Prompt format and weights may change. Please evaluate carefully before any use.
Small reasoning-tuned variant of **Google’s Gemma-3-270M**. Two-stage recipe: **SFT** on math prompts with hidden `<think>` reasoning, then **GRPO** to reinforce *correct final answers* while *discouraging overly long hidden reasoning*.

**Note:** there is **no `<answer>` tag** in this project. The final answer is emitted as `\boxed{...}` after the closing `</think>` block.

> The model can emit `<think>…</think>` tokens. Examples below **strip** this by default.

---

## ✨ What’s inside

- **Base**: `google/gemma-3-270m`
- **Objective**:
  - **SFT**: learn the prompt format + produce a boxed final answer.
  - **GRPO**: reward `= 1.0` if the boxed answer matches ground truth (numeric or `sympy`-equivalent), else `0.0`, **minus** a small penalty proportional to the number of tokens inside `<think>…</think>`. KL regularization to the SFT reference.

---

## 🧠 Prompt & output format

Training/eval wrapper (no `<answer>` tag):

```text
<question>
{question}
</question>
<think>
…(internal scratch work)…
</think>
\boxed{FINAL_ANSWER}
```

- SFT builder resembled: `format_sft_example(question, reasoning, final_answer)` (sketched after the training recipe below)
- RL prompts use only `<question>\n{question}\n</question>\n<think>\n` and expect the model to write reasoning + `\boxed{...}`.

---

## 🚀 Quickstart

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch, re

MODEL_ID = "nirav-madhani/gemma3-270m-grpo-math"  # change if different

tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
    attn_implementation="eager",
).to(device).eval()

# Matches the final answer emitted as \boxed{...}.
BOX_RE = re.compile(r"\\boxed\{([^{}]+)\}")

def generate(question, max_new_tokens=160, temperature=0.2, top_p=0.95,
             return_boxed=True, show_think=False):
    prompt = f"<question>\n{question}\n</question>\n<think>\n"
    inputs = tok(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=temperature > 0,
            temperature=temperature,
            top_p=top_p,
            eos_token_id=tok.eos_token_id,
            pad_token_id=tok.pad_token_id,
            use_cache=True,
        )
    text = tok.decode(out[0], skip_special_tokens=True)
    if not show_think:
        # Drop the hidden reasoning block before returning.
        text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.S)
    if return_boxed:
        m = BOX_RE.search(text)
        return m.group(1).strip() if m else text.strip()
    return text

print(generate("If 3x + 5 = 17, what is x?"))
```

---

## 🏗️ Training recipe (summary)

**Stage 1 — SFT**

- `transformers==4.55.x`, `trl==0.21.0`
- Tokenizer: Gemma-3 fast; `pad_token = eos_token`
- `Trainer` + `DataCollatorForLanguageModeling`
- `max_seq_length = prompt_len + completion_len`; small per-device batch with gradient accumulation; BF16/FP16 on Ampere, else FP32.

**Stage 2 — RL (GRPO)**

- TRL **GRPO** (`trl==0.21.0`)
- Policy & reference initialized from **SFT**
- **Reward** (see the sketch below):
  - `+1.0` if `\boxed{...}` equals ground truth (float tolerance or `sympy` equivalence)
  - `−λ * (# tokens inside <think>…</think>)`
- Knobs (fit ~15 GB VRAM; tune per GPU):
  - `num_generations` (K): 2–8
  - `per_device_train_batch_size`: 1–2
  - `gradient_accumulation_steps`: 4–8 (*ensure* `batch * accum * world_size` is **divisible by K**)
  - `max_prompt_length`: ~160–256
  - `max_completion_length`: ~128–192
  - `beta` (KL): ~0.02
  - `attn_implementation="eager"`; enable `use_cache=True` if you have headroom

**Checkpointing**

- Checkpoints saved every *N* steps; keep the last 3; persisted to Google Drive or `/kaggle/working`.
- Inference loader grabs the **newest `checkpoint-XXXX/`**, else falls back RL root → SFT → base (see the resolver sketch below).
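A minimal sketch of the `format_sft_example` builder referenced in the prompt-format section above. The exact tag layout is inferred from that section, so treat this as a hypothetical reconstruction rather than the verbatim training code:

```python
def format_sft_example(question: str, reasoning: str, final_answer: str) -> str:
    """Hypothetical reconstruction of the SFT text builder (illustrative only)."""
    return (
        f"<question>\n{question}\n</question>\n"
        f"<think>\n{reasoning}\n</think>\n"
        f"\\boxed{{{final_answer}}}"
    )
```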
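For concreteness, here is a minimal sketch of the Stage-2 reward wired into TRL’s `GRPOTrainer`. The dataset column name `answer`, the λ value `THINK_PENALTY`, the whitespace-split token proxy, and the specific knob values are illustrative assumptions, not the exact training code:

```python
import re
import sympy
from trl import GRPOConfig, GRPOTrainer

BOX_RE = re.compile(r"\\boxed\{([^{}]+)\}")
THINK_RE = re.compile(r"<think>(.*?)</think>", re.S)
THINK_PENALTY = 0.001  # λ; illustrative value

def answers_match(pred: str, truth: str) -> bool:
    """Numeric comparison with a small tolerance, then sympy equivalence."""
    try:
        return abs(float(pred) - float(truth)) < 1e-6
    except ValueError:
        pass
    try:
        return sympy.simplify(sympy.sympify(pred) - sympy.sympify(truth)) == 0
    except (sympy.SympifyError, TypeError):
        return False

def reward_fn(completions, answer, **kwargs):
    """TRL passes completions plus dataset columns (here a hypothetical `answer`)."""
    rewards = []
    for completion, truth in zip(completions, answer):
        m = BOX_RE.search(completion)
        correct = 1.0 if m and answers_match(m.group(1).strip(), str(truth)) else 0.0
        think = THINK_RE.search(completion)
        # Whitespace split as a rough proxy for the token count inside <think>…</think>.
        n_think = len(think.group(1).split()) if think else 0
        rewards.append(correct - THINK_PENALTY * n_think)
    return rewards

# Knob values picked from the ranges above: batch * accum * world_size = 8, divisible by K = 4.
args = GRPOConfig(
    output_dir="gemma3-270m-grpo",
    num_generations=4,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    max_prompt_length=192,
    max_completion_length=160,
    beta=0.02,
)
# trainer = GRPOTrainer(model=sft_model, args=args, reward_funcs=reward_fn,
#                       train_dataset=train_ds)
# trainer.train()
```

With `beta > 0`, TRL adds a KL penalty against the frozen reference policy, matching the “KL regularization to the SFT reference” noted above.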
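And a small sketch of the checkpoint-resolution order described under **Checkpointing** (newest `checkpoint-XXXX/` → RL root → SFT → base); the directory names are placeholders:

```python
from pathlib import Path

def resolve_model_path(rl_dir="gemma3-270m-grpo", sft_dir="gemma3-270m-sft",
                       base_id="google/gemma-3-270m") -> str:
    """Pick the newest RL checkpoint, else the RL root, else SFT, else base weights."""
    rl = Path(rl_dir)
    if rl.is_dir():
        ckpts = sorted(
            (p for p in rl.glob("checkpoint-*") if p.is_dir()),
            key=lambda p: int(p.name.rsplit("-", 1)[-1]),
        )
        if ckpts:
            return str(ckpts[-1])          # newest checkpoint-XXXX/
        if (rl / "config.json").exists():
            return str(rl)                 # weights saved at the RL root
    if Path(sft_dir).is_dir():
        return sft_dir                     # fall back to the SFT model
    return base_id                         # finally, the base weights
```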
---

## 📊 Evaluation

*Early-stage.* Evaluate on your own math split or a GSM-style test set by extracting `\boxed{…}` and checking numeric or `sympy` equivalence. Track:

- reward mean/std and exact-match rate of final answers,
- KL vs. the reference model,
- output length and “think” token counts.

---

## ⚖️ License & usage

- **Base license**: Gemma models are released under the **Gemma license**; this derivative remains subject to those terms.
- Repo metadata sets `license: gemma`. Review Gemma’s terms before commercial use or redistribution.

---

## 🔒 Limitations & risks

- 270M parameters is **very small**; expect brittleness outside narrow math tasks.
- Hidden reasoning can be wrong; the examples **hide** it by default.
- No built-in safety filtering.

---

## 🧩 Repro notes

- Colab/Kaggle friendly; use `attn_implementation="eager"` to avoid FlashAttention mismatches.
- The GRPO progress bar’s “Training Loss” can read `0.0` — monitor **reward/KL/length** instead.
- Env tips:
  - set `TRANSFORMERS_NO_TORCHVISION=1`,
  - ensure compatible `numpy`/`scikit-learn` versions on Kaggle if `transformers.generation` pulls in sklearn.

---

## 🙌 Acknowledgements

- Base weights: **Google** `google/gemma-3-270m`
- RL training: **TRL** (`trl==0.21.0`)

---

## 📣 Citation

```bibtex
@software{nirav_gemma3_270m_grpo_math_2025,
  title  = {Gemma-3-270M GRPO (math + CoT)},
  author = {Nirav Madhani},
  year   = {2025},
  url    = {https://huggingface.co/nirav-madhani/gemma3-270m-grpo-math}
}
```