Training process: SFT → DPO
The SFT dataset has been greatly expanded from previous models: 31M tokens, 25M of them trainable. Training uses rsLoRA and targets all modules, including lm_head & embed_tokens (at a lower LR).
The SFT dataset consists of RP/ERP, stories, in-character assistant data, anime & vtuber AMAs, and Nitral's Reddit NSFW writing prompts (slightly modified).
DPO focused on reducing repetition, misgendering of characters, parroting, and general logic issues. Chosen responses are high-quality, self-edited ERP/RP; rejected responses are MS3.2 outputs that were instructed to make mistakes / ignore instructions.
Axolotl configs
Not optimized for cost or performance efficiency; YMMV.
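For reference, configs like these can be run with Axolotl's standard CLI. A minimal sketch, assuming a recent Axolotl install (sft.yaml is a placeholder filename; the same commands apply to the DPO config):

# optional: tokenize and pack the dataset ahead of the training run
python -m axolotl.cli.preprocess sft.yaml
# multi-GPU launch; the deepspeed zero1 config is picked up from the YAML
accelerate launch -m axolotl.cli.train sft.yaml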
SFT (4×H200)
base_model: Darkhn/Magistral-Small-2509-Text-Only
tokenizer_use_mistral_common: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: false
deepspeed: deepspeed_configs/zero1.json
datasets:
- path: ./data/train_datasetv2.jsonl
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./Magi-SFT-v2-3
adapter: lora
peft_use_rslora: true
lora_model_dir:
sequence_len: 10280
sample_packing: true
lora_r: 256
lora_alpha: 16
lora_dropout: 0.075
lora_target_linear: true
lora_modules_to_save:
- embed_tokens
- lm_head
wandb_project: Magi-SFT-24B
wandb_name: Magi-SFT-v2-3
gradient_accumulation_steps: 1
micro_batch_size: 8
num_epochs: 2
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
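# lm_head/embed_tokens are fully trained (see lora_modules_to_save), so they get a much lower LR than the adapter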
lr_groups:
- name: lm_head_embed
modules:
- lm_head.weight
- model.embed_tokens.weight
lr: 4e-6
learning_rate: 3e-5
weight_decay: 0.01
max_grad_norm: 1.0
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.05
evals_per_epoch: 4
saves_per_epoch: 2
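The chat_template dataset type above reads JSONL rows of OpenAI-style message lists. A sketch of one row of train_datasetv2.jsonl (structure per the config; content invented for illustration):

{"messages": [{"role": "system", "content": "You are Rin, a vtuber doing an AMA."}, {"role": "user", "content": "How did the stream go today?"}, {"role": "assistant", "content": "Chat was feral, as usual."}]}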
DPO (2×H200)
# ====================
# MODEL CONFIGURATION
# ====================
base_model: ApocalypseParty/Magi-SFT-v2-3
model_type: MistralForCausalLM
tokenizer_type: AutoTokenizer
chat_template: mistral_v7_tekken
# ====================
# RL/DPO CONFIGURATION
# ====================
rl: dpo
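# beta scales the implicit KL penalty against the reference (SFT) model; lower values let the policy drift further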
rl_beta: 0.075
# ====================
# DATASET CONFIGURATION
# ====================
datasets:
- path: ./data/dpo_ms32_handcrafted_dataset.jsonl
type: chat_template.default
field_messages: messages
field_chosen: chosen
field_rejected: rejected
message_property_mappings:
role: role
content: content
roles:
system: ["system"]
user: ["user"]
assistant: ["assistant"]
dataset_prepared_path: ./dpo_data
train_on_inputs: false # Only train on assistant responses
# ====================
# LORA CONFIGURATION
# ====================
adapter: lora
load_in_8bit: false
lora_r: 32
lora_alpha: 32
lora_dropout: 0.1
lora_target_linear: true
# lora_modules_to_save: # Uncomment only if you added NEW tokens
# ====================
# TRAINING PARAMETERS
# ====================
num_epochs: 1
micro_batch_size: 2
gradient_accumulation_steps: 6
learning_rate: 5e-6
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_ratio: 0.05
weight_decay: 0.01
max_grad_norm: 0.5
# ====================
# SEQUENCE CONFIGURATION
# ====================
sequence_len: 10280
pad_to_sequence_len: true
# ====================
# HARDWARE OPTIMIZATIONS
# ====================
bf16: auto
tf32: false
flash_attention: true
gradient_checkpointing: offload
plugins:
- axolotl.integrations.liger.LigerPlugin
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
cut_cross_entropy: true
liger_rope: true
liger_rms_norm: true
liger_layer_norm: true
liger_glu_activation: true
liger_cross_entropy: false # Cut Cross Entropy overrides this
liger_fused_linear_cross_entropy: false # Cut Cross Entropy overrides this
deepspeed: deepspeed_configs/zero1.json
# ====================
# CHECKPOINTING
# ====================
evals_per_epoch: 1
saves_per_epoch: 2
load_best_model_at_end: true
metric_for_best_model: eval_loss
greater_is_better: false
# ====================
# LOGGING & OUTPUT
# ====================
output_dir: ./Magi-SFT-v2-3-DPO-3
logging_steps: 1
save_safetensors: true
# ====================
# WANDB TRACKING
# ====================
wandb_project: Magi-24B-DPO
wandb_name: Magi-SFT-v2-3-DPO-3
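The DPO loader reads preference pairs keyed by the field mappings above. A sketch of one row of dpo_ms32_handcrafted_dataset.jsonl (structure per the config; content invented for illustration):

{"messages": [{"role": "system", "content": "Roleplay as Mira."}, {"role": "user", "content": "\"You're late again,\" she teases."}], "chosen": {"role": "assistant", "content": "A self-edited, high-quality RP reply."}, "rejected": {"role": "assistant", "content": "An MS3.2 output instructed to repeat itself and ignore instructions."}}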