Gemma 3 fine tuning max token length

#22
by mukhayy - opened

I'm looking to fine-tune google/gemma-3-12b-it on my dataset of around 10k examples. However, the outputs in my dataset are quite lengthy (some reach 125k tokens, and the average is around 60k), so I thought I could take advantage of this model's max_position_embeddings = 131072. But I haven't seen any fine-tuning example that sets max_seq_length of trl.SFTTrainer to 131072.
Is that doable? Or does 131072 only apply to inference? How should people approach (and how are they approaching) fine-tuning on datasets with lengthy outputs? Also, from your experience, what hardware and how many GPUs would be the best option?
Thank you

Hi @mukhayy,

Welcome to the Google Gemma family of open models. The max_position_embeddings parameter of google/gemma-3-12b-it defines the maximum sequence length the model supports: its positional encodings and attention mechanisms are configured to work up to that length. If you would like to fine-tune the model on sequences this long, please consider a few things before you proceed:

  1. Fine-tuning the model on such long sequences requires very powerful hardware and a large amount of GPU memory, because activation memory grows with sequence length.
  2. Consider packing multiple examples into a single max_seq_length input to make more efficient use of GPU memory. trl.SFTTrainer can do this with packing=True (set through SFTConfig in recent TRL releases). This matters most when many of your examples (prompt plus output) are much shorter than max_seq_length.
  3. Make sure your dataset follows the chat template or instruction format that google/gemma-3-12b-it was instruction-tuned on. This is usually something like:

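```python
messages = [
    {
        "role": "system",
        "content": [{"type": "text", "text": "You are a helpful assistant."}]
    },
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/p-blog/candy.JPG"},
            {"type": "text", "text": "What animal is on the candy?"}
        ]
    },
    # For SFT, each training example also needs the target response as an assistant turn, e.g.:
    # {"role": "assistant", "content": [{"type": "text", "text": "<your long target output>"}]}
]
```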

  4. Parameter-efficient fine-tuning (LoRA/QLoRA): This is critical for fine-tuning large models like google/gemma-3-12b-it on consumer-grade or even many professional GPUs, since full fine-tuning is prohibitively expensive in VRAM and compute. QLoRA (quantized LoRA) is even more memory-efficient in such scenarios. You can enable this by passing a LoraConfig to SFTTrainer through its peft_config argument.
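To see why LoRA cuts memory so sharply, here is a small arithmetic sketch. LoRA keeps a weight matrix W of shape d_out × d_in frozen and trains two low-rank factors B (d_out × r) and A (r × d_in), so each adapted matrix adds only r·(d_in + d_out) trainable parameters. The 4096 × 4096 shape below is a hypothetical example, not Gemma 3's actual layer size, and the helper names are illustrative.

```python
def lora_trainable_params(d_out: int, d_in: int, rank: int) -> int:
    """Trainable parameters LoRA adds to one frozen d_out x d_in weight.

    LoRA learns B (d_out x rank) and A (rank x d_in); their product B @ A
    has the same shape as W, but only rank * (d_in + d_out) entries are trained.
    """
    return rank * (d_in + d_out)


def full_params(d_out: int, d_in: int) -> int:
    """Trainable parameters when the whole d_out x d_in weight is updated."""
    return d_out * d_in


# Hypothetical square projection; NOT Gemma 3's real layer shape.
d = 4096
for r in (8, 16, 64):
    lora = lora_trainable_params(d, d, r)
    full = full_params(d, d)
    print(f"rank={r:>2}: {lora:>9,} LoRA params vs {full:,} full ({lora / full:.2%})")
```

Because optimizer state and gradients are only kept for the trainable parameters, shrinking that count by two orders of magnitude is where most of the savings come from; QLoRA additionally stores the frozen base weights in 4-bit.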

Please take the above suggestions into account before fine-tuning on such a large dataset.
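The packing idea can be sketched as a small bin-packing problem over per-example token counts. The function below is a hypothetical first-fit-decreasing packer written for illustration; it is not TRL's implementation, and the token counts are made up to loosely resemble the ones in the question.

```python
from typing import List


def pack_first_fit_decreasing(lengths: List[int], max_len: int) -> List[List[int]]:
    """Group example indices into packed sequences of at most max_len tokens."""
    bins: List[List[int]] = []   # example indices in each packed sequence
    remaining: List[int] = []    # free token slots left in each packed sequence
    # Place the longest examples first; this tends to leave less wasted space.
    for idx in sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True):
        n = lengths[idx]
        if n > max_len:
            raise ValueError(f"example {idx} has {n} tokens > max_len={max_len}; truncate or drop it")
        for b, free in enumerate(remaining):
            if n <= free:            # first existing sequence with enough room
                bins[b].append(idx)
                remaining[b] -= n
                break
        else:                        # no room anywhere: start a new sequence
            bins.append([idx])
            remaining.append(max_len - n)
    return bins


# Hypothetical token counts (not real data).
lengths = [125_000, 60_000, 58_000, 40_000, 5_000, 3_000]
packed = pack_first_fit_decreasing(lengths, max_len=131_072)
for seq in packed:
    print(seq, sum(lengths[i] for i in seq))
```

With outputs averaging around 60k tokens, at most two typical examples fit into one 131072-token sequence, so packing helps much less here than on datasets of short examples. Whether attention is masked between packed examples depends on the trainer settings and attention implementation, so it is worth checking before relying on it.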

Hardware requirements: multiple GPUs are recommended when you are working with a large model such as the 12B or 27B variant, especially when training on very long sequences.

  1. 4x NVIDIA A100 (80GB): This would be an ideal setup. You could potentially use a slightly larger effective batch size and train more efficiently.
  2. 8x NVIDIA A6000 (48GB): Similar to the 4x A100 setup, this provides ample VRAM.
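A back-of-the-envelope estimate shows why multi-GPU setups like these are suggested. The sketch below assumes roughly 12e9 parameters, 2 bytes per parameter in bf16, 0.5 bytes per parameter in 4-bit, and the common rule of thumb of about 16 bytes per parameter for full fine-tuning with mixed-precision Adam. It also computes the size of one fully materialized attention-score matrix (one head, one layer, batch size 1), as eager attention would store it. It ignores activations, quantization metadata, and framework overhead, so treat the numbers as rough lower bounds.

```python
GIB = 1024 ** 3


def weights_gib(n_params: float, bytes_per_param: float) -> float:
    """Approximate memory for model weights (and optimizer state, if included)."""
    return n_params * bytes_per_param / GIB


def attention_scores_gib(seq_len: int, bytes_per_elem: int = 2) -> float:
    """Memory of one seq_len x seq_len score matrix (one head, one layer, batch 1)."""
    return seq_len * seq_len * bytes_per_elem / GIB


n = 12e9  # assumed parameter count of a 12B model
print(f"bf16 weights:          {weights_gib(n, 2):7.1f} GiB")
print(f"4-bit weights:         {weights_gib(n, 0.5):7.1f} GiB")
print(f"full FT (~16 B/param): {weights_gib(n, 16):7.1f} GiB")  # rule of thumb, not exact
for seq_len in (8_192, 32_768, 131_072):
    print(f"scores @ {seq_len:>7} tokens: {attention_scores_gib(seq_len):6.2f} GiB per head per layer")
```

At 131072 tokens a single bf16 score matrix is 32 GiB, which is why memory-efficient attention kernels such as FlashAttention, and splitting the work across several GPUs, matter at these sequence lengths.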

Please see the attached gist for a parameter configuration for training on a large corpus. Note that it is not complete code, as I don't have your actual dataset.

Thanks.
