GRPO Training
Overview
Group Relative Policy Optimization (GRPO) is a reinforcement learning method that improves language models by generating multiple completions per prompt, scoring them with reward functions, and using the relative ranking within each group to compute advantage estimates. Unlike DPO, which requires pre-collected preference pairs, GRPO generates its own training data online and can work with any programmatic reward signal (math correctness, format compliance, code execution results, etc.).
Use GRPO when you have a task with a verifiable reward signal and want the model to discover solution strategies on its own. Use DPO when you already have human preference data. Use SFT when you have gold-standard completions to imitate directly.
Axolotl’s GRPO implementation builds on TRL and adds async generation, streaming scoring, importance sampling correction, replay buffers, and multi-GPU scaling via FSDP and DeepSpeed.
Architecture
GRPO training uses a two-process architecture: a vLLM server for fast generation and a trainer process for scoring and gradient updates.
```
Terminal 1 (GPU 0)                    Terminal 2 (GPU 1)
┌──────────────────────┐              ┌──────────────────────────────────┐
│  vLLM Server         │              │  Trainer                         │
│                      │     HTTP     │                                  │
│  Serves base model   │◄────────────►│  Background thread:              │
│  + LoRA adapter      │  /generate   │    Send prompts to vLLM          │
│                      │  /set_lora   │    Pad & collate completions     │
│  Punica kernels for  │              │                                  │
│  LoRA inference      │              │  Main thread:                    │
│                      │              │    Score completions (rewards)   │
└──────────────────────┘              │    Compute policy log-probs      │
                                      │    Calculate advantages          │
                                      │    PPO-clip gradient update      │
                                      │    Sync LoRA weights to vLLM     │
                                      └──────────────────────────────────┘
```
Data flow for each training step:
- The background thread sends prompts to vLLM, which generates `num_generations` completions per prompt.
- The main thread scores completions using your reward functions.
- Advantages are computed within each prompt group (group-relative normalization).
- Policy log-probabilities are computed by running a forward pass on the training model.
- The PPO-clip loss is computed and gradients are applied.
- Periodically, LoRA adapter weights are synced back to vLLM so future generations reflect the updated policy.
With async prefetch enabled, step 1 for the next batch runs concurrently with steps 2-6 for the current batch.
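The group-relative normalization above is simple enough to sketch directly. The helper below is illustrative only (not Axolotl's internal API): each completion's reward is normalized against the other completions for the same prompt.

```python
# Group-relative advantage: normalize each completion's reward against the
# other completions generated for the same prompt. Illustrative sketch,
# not Axolotl's internal code.
import statistics

def group_advantages(rewards: list[float]) -> list[float]:
    """(reward - group mean) / group std, with a zero-variance guard."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    if std == 0:
        return [0.0] * len(rewards)  # identical rewards -> no learning signal
    return [(r - mean) / std for r in rewards]

print(group_advantages([1.0, 0.0, 1.0, 0.0]))  # [1.0, -1.0, 1.0, -1.0]
```

A group where every completion earns the same reward yields all-zero advantages, which is exactly the case that zero-advantage batch skipping (described below) detects and skips.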
Quick Start
A GRPO training run requires three components: a YAML config, a reward module (Python file), and a running vLLM server.
1. Write a reward module
Create a file called rewards.py in your working directory:
```python
# rewards.py
import re


def accuracy_reward(completions, answer, **kwargs) -> list[float]:
    """Check if the completion contains the correct numerical answer."""
    rewards = []
    for completion, correct in zip(completions, answer):
        text = completion[0]["content"]
        # Extract the last number from the completion
        numbers = re.findall(r"-?\d+(?:\.\d+)?", text)
        predicted = numbers[-1] if numbers else ""
        rewards.append(1.0 if predicted == str(correct) else 0.0)
    return rewards


def format_reward(completions, **kwargs) -> list[float]:
    """Reward completions that use a structured thinking format."""
    rewards = []
    for completion in completions:
        text = completion[0]["content"]
        has_think = "<think>" in text and "</think>" in text
        has_answer = "<answer>" in text and "</answer>" in text
        rewards.append(1.0 if has_think and has_answer else 0.0)
    return rewards


def prompt_transform(cfg, *args, **kwargs):
    """Convert GSM8K dataset rows into chat prompts."""
    def transform_fn(example, tokenizer=None):
        label = example["answer"].split("####")[-1].strip().replace(",", "")
        return {
            "prompt": [
                {"role": "system", "content": "Solve the math problem. Show your reasoning in <think> tags and your final numerical answer in <answer> tags."},
                {"role": "user", "content": example["question"]},
            ],
            "answer": label,
        }
    return transform_fn, {"remove_columns": ["question"]}
```

2. Write the config
Create config.yaml:
```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct
rl: grpo
chat_template: tokenizer_default

vllm:
  host: 0.0.0.0
  port: 8000
  gpu_memory_utilization: 0.85
  dtype: auto
  max_model_len: 2048

adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true

trl:
  use_vllm: true
  use_data_producer: true
  vllm_server_host: 0.0.0.0
  vllm_server_port: 8000
  vllm_server_timeout: 300
  vllm_lora_sync: true
  num_generations: 8
  max_completion_length: 512
  temperature: 0.7
  reward_funcs:
    - rewards.accuracy_reward
    - rewards.format_reward
  reward_weights:
    - 1.0
    - 0.5

datasets:
  - path: openai/gsm8k
    name: main
    type: rewards.prompt_transform
    split: train

skip_prepare_dataset: true
val_set_size: 0.0
sequence_len: 512
micro_batch_size: 2
gradient_accumulation_steps: 4
max_steps: 200
learning_rate: 5.0e-6
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 10
bf16: true
flash_attention: true
gradient_checkpointing: true
special_tokens:
  pad_token: "<|endoftext|>"
output_dir: ./grpo-output
logging_steps: 1
```

3. Start vLLM and train
```bash
# Terminal 1: Start vLLM server on GPU 0
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
# Wait 30-90 seconds for model loading and CUDA graph capture

# Terminal 2: Train on GPU 1
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
```

Use tmux or separate terminal sessions to manage the two processes. The vLLM server must remain running for the entire training duration.
Custom Reward Functions
Function signature
TRL calls reward functions with this signature:
```python
def my_reward(completions, **kwargs) -> list[float]:
    ...
```

- `completions` is a list of single-element lists, where each element is a dict `{"role": "assistant", "content": "..."}`. So `completions[i][0]["content"]` gives you the text of the i-th completion.
- `**kwargs` contains all dataset columns that were not removed by the dataset transform. This is how you pass ground-truth answers, metadata, or any other information to your reward function.
- Return a `list[float]` with the same length as `completions`. You may return `None` for individual elements to exclude them from aggregation.
Example: accuracy reward with answer extraction
```python
import re

def accuracy_reward(completions, answer, **kwargs) -> list[float]:
    rewards = []
    for completion, correct_answer in zip(completions, answer):
        text = completion[0]["content"]
        # Extract answer from <answer>...</answer> tags
        match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
        predicted = match.group(1).strip() if match else ""
        rewards.append(1.0 if predicted == str(correct_answer) else 0.0)
    return rewards
```

Example: length penalty
```python
def length_penalty(completions, **kwargs) -> list[float]:
    """Penalize very short or very long completions."""
    rewards = []
    for completion in completions:
        length = len(completion[0]["content"])
        if length < 50:
            rewards.append(-0.5)
        elif length > 2000:
            rewards.append(-0.2)
        else:
            rewards.append(0.0)
    return rewards
```

Multiple rewards and weighting
You can combine multiple reward functions with different weights:
```yaml
trl:
  reward_funcs:
    - rewards.accuracy_reward
    - rewards.format_reward
    - rewards.length_penalty
  reward_weights:
    - 1.0   # accuracy is most important
    - 0.5   # format compliance
    - 0.1   # mild length preference
```

Rewards are combined by the `multi_objective_aggregation` strategy:

- `sum_then_normalize` (default): weights and sums all rewards first, then normalizes across the group.
- `normalize_then_sum` (GDPO): normalizes each reward independently, then sums. This prevents one reward from dominating and is recommended when using multiple reward functions with different scales.
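The difference between the two strategies shows up whenever rewards live on different scales. A numeric sketch, assuming per-group normalization to zero mean and unit standard deviation (the helper name is illustrative):

```python
import statistics

def normalize(xs: list[float]) -> list[float]:
    """Per-group normalization: (x - mean) / std, zero-variance guarded."""
    mean, std = statistics.fmean(xs), statistics.pstdev(xs)
    return [0.0] * len(xs) if std == 0 else [(x - mean) / std for x in xs]

# Two rewards on very different scales for one group of three completions
accuracy = [1.0, 0.0, 0.0]
length = [120.0, 80.0, 100.0]  # raw scale dwarfs the accuracy signal

# sum_then_normalize: weight and sum raw rewards, then normalize once.
# Even with a small 0.1 weight, the length reward drives most of the spread.
print(normalize([1.0 * a + 0.1 * l for a, l in zip(accuracy, length)]))

# normalize_then_sum (GDPO): each reward is brought to unit scale first,
# so accuracy and length contribute comparably.
print([a + l for a, l in zip(normalize(accuracy), normalize(length))])
```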
```yaml
trl:
  multi_objective_aggregation: normalize_then_sum
```

Dataset transforms
The dataset transform converts raw HuggingFace dataset rows into chat-format prompts:
```python
def prompt_transform(cfg, *args, **kwargs):
    def map_fn(example, tokenizer=None):
        return {
            "prompt": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": example["question"]},
            ],
            # Keep 'answer' column for the reward function
            "answer": example["answer"],
        }
    # Remove columns consumed by the transform; keep columns needed by rewards
    return map_fn, {"remove_columns": ["question"]}
```

The transform returns a tuple of `(map_function, kwargs_dict)`. The `remove_columns` entry in the kwargs dict drops columns that are no longer needed. Columns that your reward functions reference via `**kwargs` (like `answer`) must not be removed.
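Transforms can be sanity-checked outside of training. A hypothetical standalone run (inside Axolotl the trainer invokes the transform for you; the inline definition mirrors the one above):

```python
# Standalone check of a dataset transform (illustrative usage only).
def prompt_transform(cfg, *args, **kwargs):
    def map_fn(example, tokenizer=None):
        return {
            "prompt": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": example["question"]},
            ],
            "answer": example["answer"],  # kept for the reward function
        }
    return map_fn, {"remove_columns": ["question"]}

map_fn, map_kwargs = prompt_transform(cfg=None)
row = map_fn({"question": "What is 6 x 7?", "answer": "42"})
print(row["prompt"][1]["content"], row["answer"])  # What is 6 x 7? 42
print(map_kwargs)  # {'remove_columns': ['question']}
```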
The reward module must be importable from the directory where you run `axolotl train`. If your reward file is `rewards.py`, the import path is `rewards.accuracy_reward`. If it is inside a package `my_rewards/scoring.py`, use `my_rewards.scoring.accuracy_reward`.
Reward models (neural network rewards)
Instead of a Python function, you can pass a HuggingFace model path as a reward function. TRL will load it as a reward model and use its scalar output as the reward:
```yaml
trl:
  reward_funcs:
    - OpenAssistant/reward-model-deberta-v3-large-v2
    - rewards.format_reward
  reward_weights:
    - 1.0
    - 0.3
```

Using math_verify
The `math_verify` library provides robust mathematical answer verification but uses `signal.alarm()` internally, which only works in the main thread. If you use `math_verify` in a reward function, set `reward_num_workers` to use subprocess workers:
```yaml
trl:
  reward_num_workers: 4
```

Each worker runs in its own subprocess with its own main thread, so `signal.alarm()` works correctly.
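The restriction is easy to demonstrate: Python only allows installing signal handlers (which SIGALRM-based timeouts rely on) from the main thread. A minimal sketch:

```python
# Why math_verify needs subprocess workers: installing a signal handler
# raises ValueError anywhere but the main thread.
import signal
import threading

def try_install_handler(results):
    try:
        signal.signal(signal.SIGALRM, lambda signum, frame: None)
        results.append("ok")
    except ValueError:
        results.append("ValueError")  # raised outside the main thread

results = []
try_install_handler(results)  # main thread: handler installs fine
worker = threading.Thread(target=try_install_handler, args=(results,))
worker.start()
worker.join()
print(results)  # ['ok', 'ValueError']
```

Each subprocess worker has its own main thread, which sidesteps the restriction entirely.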
vLLM Setup
GRPO requires a running vLLM server for generation. For a complete guide on server modes, LoRA sync, weight synchronization, and restart procedures, see vLLM Serving.
The minimal setup:
```yaml
vllm:
  host: 0.0.0.0
  port: 8000
  gpu_memory_utilization: 0.85

trl:
  use_vllm: true
  vllm_lora_sync: true    # Recommended with LoRA — faster sync, no NCCL contention
  vllm_sync_interval: 5   # Sync weights every 5 steps
```

```bash
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml   # GPU 0: vLLM
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml        # GPU 1: training
```

vLLM must be restarted between experiments — stale weight syncs corrupt server state. See Restart Requirements.
Async Training Features
Async GRPO overlaps generation and training to reduce wall-clock time. While the model trains on the current batch, the next batch is already being generated by vLLM.
Enabling async prefetch
```yaml
trl:
  use_data_producer: true
  async_prefetch: true
  prefetch_depth: 1
  vllm_sync_interval: 2
```

- `use_data_producer: true` enables the data producer protocol (required for all async features).
- `async_prefetch: true` runs generation in a background thread.
- `prefetch_depth` controls how many batches to prefetch ahead (1 is usually sufficient).
- `vllm_sync_interval` controls how often LoRA weights are synced to vLLM (every N optimizer steps). Lower values mean fresher generations but more sync overhead.
Because the background thread generates with slightly stale model weights, async mode benefits from importance sampling correction (see next section). Enable `vllm_importance_sampling_correction: true` when using `async_prefetch: true`.
Streaming partial batch
Instead of scoring the entire batch at once, streaming mode scores one prompt group at a time. This reduces peak memory during scoring and enables finer-grained zero-advantage skipping.
```yaml
trl:
  streaming_partial_batch: true
  streaming_min_groups: 1
```

`streaming_min_groups` controls the minimum number of prompt groups scored per chunk. Setting it to 1 gives maximum granularity.
Zero-advantage batch skipping
When all advantages in a micro-batch are zero (every completion in the group got the same reward), there is no learning signal. This feature skips the forward/backward pass entirely for such micro-batches.
```yaml
trl:
  skip_zero_advantage_batches: true   # default
```

This is enabled by default and logged as `skipped_zero_adv_batches` in training metrics. It is a safety net, not a major optimization – it only saves significant time when the model cannot solve any prompts in the batch.
Replay buffer
The replay buffer caches rollout groups that had learning signal (non-zero reward variance) and replaces zero-signal groups in later batches. This improves data utilization when many prompts yield no reward variance.
```yaml
trl:
  replay_buffer_size: 100
  replay_recompute_logps: true
```

When `replay_recompute_logps: false`, replayed data uses stale log-probabilities, which creates an IS mismatch. Keep the default `true` unless you have a specific reason to disable it.
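The buffering logic can be sketched in a few lines. This is illustrative only (not Axolotl's internals): groups whose rewards vary are cached, and a zero-variance group is swapped for a cached high-signal group when one is available.

```python
# Sketch of the replay-buffer idea (illustrative; not Axolotl's implementation).
from collections import deque

buffer = deque(maxlen=100)  # corresponds to replay_buffer_size: 100

def process_group(prompts, rewards):
    if len(set(rewards)) > 1:   # non-zero reward variance: keep and cache
        buffer.append((prompts, rewards))
        return prompts, rewards
    if buffer:                  # zero variance: replay a cached group instead
        return buffer[-1]
    return prompts, rewards     # nothing cached yet: pass through
```

In the real trainer, replayed groups then have their log-probabilities recomputed under the current model when `replay_recompute_logps: true`.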
Deferred re-rolling
Prompts where the model gets zero reward for all generations are buffered and re-injected into later batches, when the model may have improved enough to produce useful completions.
```yaml
trl:
  reroll_start_fraction: 0.5   # Start re-rolling after 50% of training
  reroll_max_groups: 1         # Max groups to replace per batch
```

Set `reroll_start_fraction: 1.0` to disable. This is most useful for tasks where the model starts weak but steadily improves.
Parallel reward workers
Reward functions that use `signal.alarm()` (like `math_verify`) only work in the main thread. Parallel reward workers run each function in its own subprocess:
```yaml
trl:
  reward_num_workers: 4
```

Work is sharded across workers by prompt group. For simple reward functions, a single worker is usually sufficient – the overhead of IPC can exceed the computation time.
Importance Sampling and Off-Policy Correction
When using async prefetch, completions are generated from a slightly older policy. IS correction adjusts the gradient to account for this mismatch.
```yaml
trl:
  vllm_importance_sampling_correction: true
  importance_sampling_level: token    # 'token' recommended (especially with Liger kernel)
  off_policy_mask_threshold: 0.5      # KL threshold — masks sequences that are too off-policy
```

Use token-level IS. Sequence-level has numerical issues with Liger's chunked computation. The `off_policy_mask_threshold` (OPSM) is a safety net that drops sequences where KL divergence exceeds the threshold — 0.5 is a reasonable starting point.
For detailed coverage of IS modes (token_mask, token_truncate, etc.), capping, and bias-corrected KL, see vLLM Serving — IS Correction.
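As a sketch of what token-level correction with truncation does, assuming per-token log-probs from both the trainer model and the vLLM sampler are available (illustrative names; not Axolotl's internal code):

```python
# Token-level importance weight with truncation (the idea behind a
# token_truncate-style mode): exp(logp_trainer - logp_sampler) per token,
# capped so that a single very off-policy token cannot blow up the gradient.
import math

def token_is_weights(trainer_logps, sampler_logps, cap=2.0):
    """Per-token importance weights, truncated at `cap`."""
    return [min(math.exp(t - s), cap)
            for t, s in zip(trainer_logps, sampler_logps)]

# Token agreed on by both policies -> weight 1.0; token the trainer now
# considers less likely -> downweighted.
print(token_is_weights([-1.0, -2.0], [-1.0, -1.0]))
```

A masking-style mode would instead zero out weights beyond the cap rather than truncating them.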
Scaling
FP8 training
FP8 quantization halves model VRAM usage with minimal impact on training quality. It does not significantly speed up computation for small models but allows larger models to fit in memory.
```yaml
fp8: true
torch_compile: true
```

FP8 requires patching for zero-padding edge cases. The `act_quant_kernel` can produce NaN when the input is all zeros (padding positions). If you see NaN in grad norms, check whether your padding token embedding is non-zero.
FSDP (Fully Sharded Data Parallel)
FSDP distributes model parameters across multiple GPUs for training while vLLM runs on a separate GPU:
```yaml
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
gradient_checkpointing_kwargs:
  use_reentrant: false
```

Launch with:

```bash
# GPU 0: vLLM
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml

# GPUs 1,2: Training (FSDP will use both visible GPUs)
CUDA_VISIBLE_DEVICES=1,2 axolotl train config.yaml
```

`async_prefetch: true` can deadlock with FSDP because background threads perform unsynchronized FSDP collectives across ranks. With multi-GPU FSDP, only rank 0 generates in the background thread and results are broadcast to all ranks. If you still see hangs, set `async_prefetch: false`.
DeepSpeed ZeRO-3
```yaml
deepspeed: deepspeed_configs/zero3_bf16.json
gradient_checkpointing_kwargs:
  use_reentrant: true   # Required -- non-reentrant causes CheckpointError with ZeRO-3
```

DeepSpeed ZeRO-3 requires `use_reentrant: true` for gradient checkpointing. This is the opposite of the FSDP recommendation. Non-reentrant checkpointing causes tensor metadata mismatches during recomputation with ZeRO-3's parameter partitioning.
Multi-GPU considerations
| Concern | Recommendation |
|---|---|
| vLLM GPU allocation | Dedicate one or more GPUs to vLLM; do not share with trainer GPUs |
| Weight sync contention | Use vllm_lora_sync: true to avoid NCCL contention between training and vLLM |
| FSDP + async | Use async_prefetch: false or rely on rank-0-only background generation |
| DeepSpeed + gradient checkpoint | Must use use_reentrant: true |
| OOM during scoring | Reduce micro_batch_size or num_generations. The logits tensor scales with batch_size * vocab_size |
Monitoring and Debugging
For detailed metric ranges, failure diagnosis, and OOM debugging, see Training Stability & Debugging.
Quick health checks during GRPO training:
- `rewards/*/mean` should be > 0.15 within 20 steps — if it stays at 0, test your reward function standalone
- `reward_std` should be > 0 on most steps — all-zero means no learning signal
- `entropy` in 0.05-0.5 — below 0.01 suggests mode collapse
- `grad_norm` in 0.001-1.0 — > 10 is unstable; 0.0 is expected when zero-advantage skip fires
Pipe training output to a log file: `axolotl train config.yaml 2>&1 | tee /tmp/training.log`
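The first check above recommends testing reward functions standalone. A minimal harness that mimics TRL's calling convention (the quick-start `accuracy_reward` logic is inlined here so the snippet is self-contained):

```python
# Standalone reward check: TRL passes each completion as a single-element
# list of chat messages, plus dataset columns as keyword arguments.
import re

def accuracy_reward(completions, answer, **kwargs) -> list[float]:
    rewards = []
    for completion, correct in zip(completions, answer):
        text = completion[0]["content"]
        numbers = re.findall(r"-?\d+(?:\.\d+)?", text)
        predicted = numbers[-1] if numbers else ""
        rewards.append(1.0 if predicted == str(correct) else 0.0)
    return rewards

completions = [
    [{"role": "assistant", "content": "<think>6 * 7</think><answer>42</answer>"}],
    [{"role": "assistant", "content": "The answer is 41."}],
]
print(accuracy_reward(completions, answer=["42", "42"]))  # [1.0, 0.0]
```

If a reward stays at 0.0 on hand-written correct completions, the bug is in the reward function, not the model.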
Configuration Reference
All GRPO-specific options live under the trl: key in your config. Standard training options (learning_rate, micro_batch_size, etc.) are set at the top level as usual.
Core GRPO
| Option | Type | Default | Description |
|---|---|---|---|
| `use_vllm` | bool | `false` | Enable vLLM for generation |
| `vllm_mode` | `"server"` or `"colocate"` | null | vLLM deployment mode |
| `vllm_server_host` | str | `"0.0.0.0"` | vLLM server hostname |
| `vllm_server_port` | int | `8000` | vLLM server port |
| `vllm_server_timeout` | int | null | Timeout (seconds) for vLLM responses |
| `num_generations` | int | null | Completions generated per prompt |
| `generation_batch_size` | int | null | Number of unique prompts per generation step |
| `max_completion_length` | int | null | Maximum tokens per completion |
| `beta` | float | null | KL penalty coefficient |
| `num_iterations` | int | null | Iterations per batch (mu in the GRPO paper) |
| `epsilon` | float | null | PPO clipping lower bound |
| `epsilon_high` | float | null | PPO clipping upper bound |
| `loss_type` | str | null | Loss formulation: `grpo`, `bnpo`, or `dr_grpo` |
| `scale_rewards` | bool | `true` | Normalize rewards by standard deviation |
| `mask_truncated_completions` | bool | `false` | Exclude truncated completions from loss |
Reward functions
| Option | Type | Default | Description |
|---|---|---|---|
| `reward_funcs` | list[str] | null | Import paths to reward functions or HF model IDs |
| `reward_weights` | list[float] | null | Relative weights for each reward function |
| `multi_objective_aggregation` | str | null | `"sum_then_normalize"` (GRPO) or `"normalize_then_sum"` (GDPO) |
| `rollout_func` | str | null | Import path to custom rollout function for OpenEnv-style tasks |
Generation parameters
| Option | Type | Default | Description |
|---|---|---|---|
| `temperature` | float | null | Sampling temperature |
| `top_p` | float | null | Nucleus sampling probability |
| `top_k` | int | null | Top-k sampling |
| `min_p` | float | null | Minimum probability threshold |
| `repetition_penalty` | float | null | Penalty for repeated tokens |
| `generation_kwargs` | dict | null | Additional vLLM `SamplingParams` (e.g., `stop_token_ids`) |
| `chat_template_kwargs` | dict | null | Chat template kwargs (e.g., `{enable_thinking: false}`) |
| `vllm_guided_decoding_regex` | str | null | Regex constraint for guided decoding |
Async pipeline
| Option | Type | Default | Description |
|---|---|---|---|
| `use_data_producer` | bool | `false` | Enable data producer protocol (required for async features) |
| `async_prefetch` | bool | `false` | Generate next batch in background thread |
| `prefetch_depth` | int | null | Number of batches to prefetch ahead |
| `vllm_sync_interval` | int | null | Sync LoRA weights to vLLM every N steps |
| `vllm_lora_sync` | bool | `false` | Use filesystem LoRA sync instead of NCCL merge |
| `streaming_partial_batch` | bool | null | Score prompt groups incrementally |
| `streaming_min_groups` | int | null | Minimum groups per streaming chunk |
| `skip_zero_advantage_batches` | bool | `true` | Skip micro-batches with zero learning signal |
| `reward_num_workers` | int | `1` | Subprocess workers for reward computation |
| `vllm_enable_sleep_mode` | bool | null | Offload vLLM weights when idle (colocate mode) |
Importance sampling
| Option | Type | Default | Description |
|---|---|---|---|
| `vllm_importance_sampling_correction` | bool | null | Enable IS correction for async distribution shift |
| `importance_sampling_level` | `"token"` or `"sequence"` | null | Granularity of IS ratios; use `token` with Liger |
| `vllm_importance_sampling_mode` | str | null | `token_mask`, `token_truncate`, `sequence_mask`, or `sequence_truncate` |
| `vllm_importance_sampling_cap` | float | null | Cap C for IS ratio clipping/masking |
| `off_policy_mask_threshold` | float | null | KL threshold for off-policy sequence masking (OPSM) |
| `use_bias_correction_kl` | bool | null | Apply IS correction to KL divergence term |
Replay and re-roll
| Option | Type | Default | Description |
|---|---|---|---|
| `replay_buffer_size` | int | `0` | Max cached high-signal groups; 0 = disabled |
| `replay_recompute_logps` | bool | `true` | Recompute log-probs for replayed data with current model |
| `reroll_start_fraction` | float | `1.0` | Start re-rolling failed prompts after this fraction of training; 1.0 = disabled |
| `reroll_max_groups` | int | `1` | Max prompt groups to replace with re-rolls per batch |
Reference model
| Option | Type | Default | Description |
|---|---|---|---|
| `sync_ref_model` | bool | `false` | Periodically sync reference model with training model |
| `ref_model_mixup_alpha` | float | `0.9` | EMA coefficient for reference model sync |
| `ref_model_sync_steps` | int | `64` | Sync reference model every N steps |
Logging
| Option | Type | Default | Description |
|---|---|---|---|
| `log_completions` | bool | `false` | Log sample completions to W&B |
| `num_completions_to_print` | int | null | Number of completions to print per step |
| `use_liger_loss` | bool | null | Use Liger fused kernel for GRPO loss (reduces VRAM) |