GRPO Training

Group Relative Policy Optimization — a reinforcement learning method for training language models with verifiable reward functions.

Overview

Group Relative Policy Optimization (GRPO) is a reinforcement learning method that improves language models by generating multiple completions per prompt, scoring them with reward functions, and using the relative ranking within each group to compute advantage estimates. Unlike DPO, which requires pre-collected preference pairs, GRPO generates its own training data online and can work with any programmatic reward signal (math correctness, format compliance, code execution results, etc.).

Use GRPO when you have a task with a verifiable reward signal and want the model to discover solution strategies on its own. Use DPO when you already have human preference data. Use SFT when you have gold-standard completions to imitate directly.
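The group-relative advantage computation at the heart of GRPO can be sketched in a few lines. This is a simplified illustration, not Axolotl's actual implementation; the function name and epsilon are our own:

```python
def group_relative_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:
    """Normalize the rewards of one prompt group (num_generations completions).

    If every completion received the same reward, the std is 0 and every
    advantage is 0 -- the group carries no learning signal.
    """
    mean = sum(rewards) / len(rewards)
    std = (sum((r - mean) ** 2 for r in rewards) / len(rewards)) ** 0.5
    return [(r - mean) / (std + eps) for r in rewards]
```

For example, a group where one of four completions is correct under a 0/1 accuracy reward gets an advantage of roughly +1.73 for the winner and -0.58 for the rest.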

Axolotl’s GRPO implementation builds on TRL and adds async generation, streaming scoring, importance sampling correction, replay buffers, and multi-GPU scaling via FSDP and DeepSpeed.

Architecture

GRPO training uses a two-process architecture: a vLLM server for fast generation and a trainer process for scoring and gradient updates.

Terminal 1 (GPU 0)                    Terminal 2 (GPU 1)
┌──────────────────────┐              ┌──────────────────────────────────┐
│  vLLM Server         │              │  Trainer                         │
│                      │   HTTP       │                                  │
│  Serves base model   │◄────────────►│  Background thread:              │
│  + LoRA adapter      │  /generate   │    Send prompts to vLLM          │
│                      │  /set_lora   │    Pad & collate completions     │
│  Punica kernels for  │              │                                  │
│  LoRA inference      │              │  Main thread:                    │
│                      │              │    Score completions (rewards)   │
└──────────────────────┘              │    Compute policy log-probs      │
                                      │    Calculate advantages          │
                                      │    PPO-clip gradient update      │
                                      │    Sync LoRA weights to vLLM     │
                                      └──────────────────────────────────┘

Data flow for each training step:

  1. The background thread sends prompts to vLLM, which generates num_generations completions per prompt.
  2. The main thread scores completions using your reward functions.
  3. Advantages are computed within each prompt group (group-relative normalization).
  4. Policy log-probabilities are computed by running a forward pass on the training model.
  5. The PPO-clip loss is computed and gradients are applied.
  6. Periodically, LoRA adapter weights are synced back to vLLM so future generations reflect the updated policy.

With async prefetch enabled, step 1 for the next batch runs concurrently with steps 2-6 for the current batch.
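The overlap can be sketched with a bounded queue and a producer thread. This is a schematic with stand-in generate_fn/train_step_fn callables, not the actual implementation:

```python
import queue
import threading

def train_with_prefetch(generate_fn, train_step_fn, prompt_batches, prefetch_depth=1):
    """Schematic async prefetch: generation for batch N+1 overlaps training on batch N."""
    q = queue.Queue(maxsize=prefetch_depth)  # bounds how far ahead we generate

    def producer():
        for prompts in prompt_batches:
            q.put(generate_fn(prompts))  # blocks once prefetch_depth batches are waiting
        q.put(None)  # sentinel: no more batches

    threading.Thread(target=producer, daemon=True).start()
    while (batch := q.get()) is not None:
        train_step_fn(batch)  # score, advantages, log-probs, PPO-clip update
```

While train_step_fn runs on the main thread, the producer is already blocked inside generate_fn for the next batch, which is exactly the overlap prefetch_depth controls.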

Quick Start

A GRPO training run requires three components: a YAML config, a reward module (Python file), and a running vLLM server.

1. Write a reward module

Create a file called rewards.py in your working directory:

# rewards.py
import re


def accuracy_reward(completions, answer, **kwargs) -> list[float]:
    """Check if the completion contains the correct numerical answer."""
    rewards = []
    for completion, correct in zip(completions, answer):
        text = completion[0]["content"]
        # Extract the last number from the completion
        numbers = re.findall(r"-?\d+(?:\.\d+)?", text)
        predicted = numbers[-1] if numbers else ""
        rewards.append(1.0 if predicted == str(correct) else 0.0)
    return rewards


def format_reward(completions, **kwargs) -> list[float]:
    """Reward completions that use a structured thinking format."""
    rewards = []
    for completion in completions:
        text = completion[0]["content"]
        has_think = "<think>" in text and "</think>" in text
        has_answer = "<answer>" in text and "</answer>" in text
        rewards.append(1.0 if has_think and has_answer else 0.0)
    return rewards


def prompt_transform(cfg, *args, **kwargs):
    """Convert GSM8K dataset rows into chat prompts."""
    def transform_fn(example, tokenizer=None):
        label = example["answer"].split("####")[-1].strip().replace(",", "")
        return {
            "prompt": [
                {"role": "system", "content": "Solve the math problem. Show your reasoning in <think> tags and your final numerical answer in <answer> tags."},
                {"role": "user", "content": example["question"]},
            ],
            "answer": label,
        }
    return transform_fn, {"remove_columns": ["question"]}

2. Write the config

Create config.yaml:

base_model: Qwen/Qwen2.5-1.5B-Instruct

rl: grpo
chat_template: tokenizer_default

vllm:
  host: 0.0.0.0
  port: 8000
  gpu_memory_utilization: 0.85
  dtype: auto
  max_model_len: 2048

adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true

trl:
  use_vllm: true
  use_data_producer: true
  vllm_server_host: 0.0.0.0
  vllm_server_port: 8000
  vllm_server_timeout: 300
  vllm_lora_sync: true
  num_generations: 8
  max_completion_length: 512
  temperature: 0.7
  reward_funcs:
    - rewards.accuracy_reward
    - rewards.format_reward
  reward_weights:
    - 1.0
    - 0.5

datasets:
  - path: openai/gsm8k
    name: main
    type: rewards.prompt_transform
    split: train

skip_prepare_dataset: true
val_set_size: 0.0
sequence_len: 512
micro_batch_size: 2
gradient_accumulation_steps: 4
max_steps: 200
learning_rate: 5.0e-6
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 10

bf16: true
flash_attention: true
gradient_checkpointing: true

special_tokens:
  pad_token: "<|endoftext|>"

output_dir: ./grpo-output
logging_steps: 1

3. Start vLLM and train

# Terminal 1: Start vLLM server on GPU 0
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml

# Wait 30-90 seconds for model loading and CUDA graph capture

# Terminal 2: Train on GPU 1
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml

Tip

Use tmux or separate terminal sessions to manage the two processes. The vLLM server must remain running for the entire training duration.

Custom Reward Functions

Function signature

TRL calls reward functions with this signature:

def my_reward(completions, **kwargs) -> list[float]:

  • completions is a list of single-element lists, where each element is a dict {"role": "assistant", "content": "..."}. So completions[i][0]["content"] gives you the text of the i-th completion.
  • **kwargs contains all dataset columns that were not removed by the dataset transform. This is how you pass ground truth answers, metadata, or any other information to your reward function.
  • Return a list[float] with the same length as completions. You may return None for individual elements to exclude them from aggregation.

Example: accuracy reward with answer extraction

def accuracy_reward(completions, answer, **kwargs) -> list[float]:
    rewards = []
    for completion, correct_answer in zip(completions, answer):
        text = completion[0]["content"]
        # Extract answer from <answer>...</answer> tags
        match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
        predicted = match.group(1).strip() if match else ""
        rewards.append(1.0 if predicted == str(correct_answer) else 0.0)
    return rewards

Example: length penalty

def length_penalty(completions, **kwargs) -> list[float]:
    """Penalize very short or very long completions."""
    rewards = []
    for completion in completions:
        length = len(completion[0]["content"])
        if length < 50:
            rewards.append(-0.5)
        elif length > 2000:
            rewards.append(-0.2)
        else:
            rewards.append(0.0)
    return rewards

Multiple rewards and weighting

You can combine multiple reward functions with different weights:

trl:
  reward_funcs:
    - rewards.accuracy_reward
    - rewards.format_reward
    - rewards.length_penalty
  reward_weights:
    - 1.0    # accuracy is most important
    - 0.5    # format compliance
    - 0.1    # mild length preference

Rewards are combined by the multi_objective_aggregation strategy:

  • sum_then_normalize (default): applies the weights and sums all rewards for each completion, then normalizes the totals across the group.
  • normalize_then_sum (GDPO): normalizes each reward independently, then sums. This prevents one reward from dominating and is recommended when using multiple reward functions with different scales.

trl:
  multi_objective_aggregation: normalize_then_sum
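The difference between the two strategies can be illustrated with plain Python. This is a sketch of the idea, not the exact TRL/Axolotl code:

```python
def _normalize(xs, eps=1e-4):
    """Normalize a list of values to zero mean, unit-ish std."""
    mean = sum(xs) / len(xs)
    std = (sum((x - mean) ** 2 for x in xs) / len(xs)) ** 0.5
    return [(x - mean) / (std + eps) for x in xs]

def sum_then_normalize(reward_lists, weights):
    """Default: weighted sum per completion first, then one normalization."""
    totals = [sum(w * rs[i] for w, rs in zip(weights, reward_lists))
              for i in range(len(reward_lists[0]))]
    return _normalize(totals)

def normalize_then_sum(reward_lists, weights):
    """GDPO: normalize each reward across the group first, then weighted-sum.
    A reward with a large raw scale can no longer drown out the others."""
    normed = [_normalize(rs) for rs in reward_lists]
    return [sum(w * ns[i] for w, ns in zip(weights, normed))
            for i in range(len(reward_lists[0]))]
```

With an accuracy reward in {0, 1} and a raw length reward in the hundreds, sum_then_normalize is dominated by length, while normalize_then_sum puts both on equal footing before the weights apply.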

Dataset transforms

The dataset transform converts raw HuggingFace dataset rows into chat-format prompts:

def prompt_transform(cfg, *args, **kwargs):
    def map_fn(example, tokenizer=None):
        return {
            "prompt": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": example["question"]},
            ],
            # Keep 'answer' column for the reward function
            "answer": example["answer"],
        }
    # Remove columns consumed by the transform; keep columns needed by rewards
    return map_fn, {"remove_columns": ["question"]}

The transform returns a tuple of (map_function, kwargs_dict). The remove_columns in the kwargs dict removes columns that are no longer needed. Columns that your reward functions reference via **kwargs (like answer) must not be removed.

Warning

The reward module must be importable from the directory where you run axolotl train. If your reward file is rewards.py, the import path is rewards.accuracy_reward. If it is inside a package my_rewards/scoring.py, use my_rewards.scoring.accuracy_reward.
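It is worth sanity-checking the import path and the reward behavior standalone before launching a run. A quick script along these lines works; the fallback definition below exists only so this sketch runs anywhere, and in practice the import should succeed:

```python
# check_rewards.py -- run from the same directory you will run `axolotl train` from
import re

try:
    from rewards import accuracy_reward  # the same import path used in config.yaml
except ImportError:
    # Fallback copy of the GSM8K-style reward so this sketch is self-contained
    def accuracy_reward(completions, answer, **kwargs):
        out = []
        for completion, correct in zip(completions, answer):
            numbers = re.findall(r"-?\d+(?:\.\d+)?", completion[0]["content"])
            out.append(1.0 if (numbers[-1] if numbers else "") == str(correct) else 0.0)
        return out

# Completions arrive as single-element lists of chat messages
completions = [
    [{"role": "assistant", "content": "<think>2 + 2 = 4</think><answer>4</answer>"}],
    [{"role": "assistant", "content": "I think the answer is 5"}],
]
print(accuracy_reward(completions, answer=["4", "4"]))  # -> [1.0, 0.0]
```

If the import fails here, it will also fail inside the trainer.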

Reward models (neural network rewards)

Instead of a Python function, you can pass a HuggingFace model path as a reward function. TRL will load it as a reward model and use its scalar output as the reward:

trl:
  reward_funcs:
    - OpenAssistant/reward-model-deberta-v3-large-v2
    - rewards.format_reward
  reward_weights:
    - 1.0
    - 0.3

Using math_verify

The math_verify library provides robust mathematical answer verification but uses signal.alarm() internally, which only works in the main thread. If you use math_verify in a reward function, set reward_num_workers to use subprocess workers:

trl:
  reward_num_workers: 4

Each worker runs in its own subprocess with its own main thread, so signal.alarm() works correctly.

vLLM Setup

GRPO requires a running vLLM server for generation. For a complete guide on server modes, LoRA sync, weight synchronization, and restart procedures, see vLLM Serving.

The minimal setup:

vllm:
  host: 0.0.0.0
  port: 8000
  gpu_memory_utilization: 0.85

trl:
  use_vllm: true
  vllm_lora_sync: true         # Recommended with LoRA — faster sync, no NCCL contention
  vllm_sync_interval: 5        # Sync weights every 5 steps

CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml   # GPU 0: vLLM
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml         # GPU 1: training

Warning

vLLM must be restarted between experiments — stale weight syncs corrupt server state. See Restart Requirements.

Async Training Features

Async GRPO overlaps generation and training to reduce wall-clock time. While the model trains on the current batch, the next batch is already being generated by vLLM.

Enabling async prefetch

trl:
  use_data_producer: true
  async_prefetch: true
  prefetch_depth: 1
  vllm_sync_interval: 2

  • use_data_producer: true enables the data producer protocol (required for all async features).
  • async_prefetch: true runs generation in a background thread.
  • prefetch_depth controls how many batches to prefetch ahead (1 is usually sufficient).
  • vllm_sync_interval controls how often LoRA weights are synced to vLLM (every N optimizer steps). Lower values mean fresher generations but more sync overhead.

Tip

Because the background thread generates with slightly stale model weights, async mode benefits from importance sampling correction (see next section). Enable vllm_importance_sampling_correction: true when using async_prefetch: true.

Streaming partial batch

Instead of scoring the entire batch at once, streaming mode scores one prompt group at a time. This reduces peak memory during scoring and enables finer-grained zero-advantage skipping.

trl:
  streaming_partial_batch: true
  streaming_min_groups: 1

streaming_min_groups controls the minimum number of prompt groups scored per chunk. Setting it to 1 gives maximum granularity.

Zero-advantage batch skipping

When all advantages in a micro-batch are zero (every completion in the group got the same reward), there is no learning signal. This feature skips the forward/backward pass entirely for such micro-batches.

trl:
  skip_zero_advantage_batches: true   # default

This is enabled by default and logged as skipped_zero_adv_batches in training metrics. It is a safety net, not a major optimization – it only saves significant time when the model cannot solve any prompts in the batch.

Replay buffer

The replay buffer caches rollout groups that had learning signal (non-zero reward variance) and replaces zero-signal groups in later batches. This improves data utilization when many prompts yield no reward variance.

trl:
  replay_buffer_size: 100
  replay_recompute_logps: true

Warning

When replay_recompute_logps: false, replayed data uses stale log-probabilities which creates an IS mismatch. Keep the default true unless you have a specific reason to disable it.

Deferred re-rolling

Prompts where the model gets zero reward for all generations are buffered and re-injected into later batches, when the model may have improved enough to produce useful completions.

trl:
  reroll_start_fraction: 0.5   # Start re-rolling after 50% of training
  reroll_max_groups: 1          # Max groups to replace per batch

Set reroll_start_fraction: 1.0 to disable. This is most useful for tasks where the model starts weak but steadily improves.

Parallel reward workers

Reward functions that use signal.alarm() (like math_verify) only work in the main thread. Parallel reward workers run each function in its own subprocess:

trl:
  reward_num_workers: 4

Work is sharded across workers by prompt group. For simple reward functions, a single worker is usually sufficient – the overhead of IPC can exceed the computation time.

Importance Sampling and Off-Policy Correction

When using async prefetch, completions are generated from a slightly older policy. IS correction adjusts the gradient to account for this mismatch.

trl:
  vllm_importance_sampling_correction: true
  importance_sampling_level: token     # 'token' recommended (especially with Liger kernel)
  off_policy_mask_threshold: 0.5       # KL threshold — masks sequences that are too off-policy

Use token-level IS; sequence-level IS has numerical issues with Liger's chunked computation. The off_policy_mask_threshold (OPSM) is a safety net that drops sequences whose KL divergence from the generation policy exceeds the threshold; 0.5 is a reasonable starting point.
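The combination of token-level ratios, capping, and OPSM can be sketched as follows. This is illustrative only; the names and the crude KL estimator are our assumptions, not Axolotl's internals:

```python
import math

def token_is_weights(logp_train, logp_gen, cap=2.0, kl_threshold=0.5):
    """Per-token IS weights for one completion.

    logp_train: log-probs from the current training model
    logp_gen:   log-probs recorded when vLLM generated the completion
    """
    # Crude sequence-level KL estimate between generation and training policies
    kl = sum(g - t for g, t in zip(logp_gen, logp_train)) / len(logp_train)
    if abs(kl) > kl_threshold:
        # OPSM: the sequence drifted too far off-policy -- drop it from the loss
        return [0.0] * len(logp_train)
    # Token-level importance ratio pi_train / pi_gen, capped for stability
    return [min(math.exp(t - g), cap) for t, g in zip(logp_train, logp_gen)]
```

On-policy data (identical log-probs) yields weights of exactly 1.0, so the correction is a no-op when generation and training stay in sync.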

For detailed coverage of IS modes (token_mask, token_truncate, etc.), capping, and bias-corrected KL, see vLLM Serving — IS Correction.

Scaling

FP8 training

FP8 quantization halves model VRAM usage with minimal impact on training quality. It does not significantly speed up computation for small models but allows larger models to fit in memory.

fp8: true
torch_compile: true

Warning

FP8 requires patching for zero-padding edge cases. The act_quant_kernel can produce NaN when input is all zeros (padding positions). If you see NaN in grad norms, check whether your padding token embedding is non-zero.

FSDP (Fully Sharded Data Parallel)

FSDP distributes model parameters across multiple GPUs for training while vLLM runs on a separate GPU:

fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
gradient_checkpointing_kwargs:
  use_reentrant: false

Launch with:

# GPU 0: vLLM
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml

# GPUs 1,2: Training (FSDP will shard across both visible GPUs)
CUDA_VISIBLE_DEVICES=1,2 axolotl train config.yaml

Warning

async_prefetch: true can deadlock with FSDP because background threads perform unsynchronized FSDP collectives across ranks. With multi-GPU FSDP, only rank 0 generates in the background thread and results are broadcast to all ranks. If you still see hangs, set async_prefetch: false.

DeepSpeed ZeRO-3

deepspeed: deepspeed_configs/zero3_bf16.json
gradient_checkpointing_kwargs:
  use_reentrant: true   # Required -- non-reentrant causes CheckpointError with ZeRO-3

Note

DeepSpeed ZeRO-3 requires use_reentrant: true for gradient checkpointing. This is the opposite of the FSDP recommendation. Non-reentrant checkpointing causes tensor metadata mismatches during recomputation with ZeRO-3’s parameter partitioning.

Multi-GPU considerations

| Concern | Recommendation |
|---|---|
| vLLM GPU allocation | Dedicate one or more GPUs to vLLM; do not share with trainer GPUs |
| Weight sync contention | Use vllm_lora_sync: true to avoid NCCL contention between training and vLLM |
| FSDP + async | Use async_prefetch: false or rely on rank-0-only background generation |
| DeepSpeed + gradient checkpointing | Must use use_reentrant: true |
| OOM during scoring | Reduce micro_batch_size or num_generations; the logits tensor scales with batch_size * vocab_size |

Monitoring and Debugging

For detailed metric ranges, failure diagnosis, and OOM debugging, see Training Stability & Debugging.

Quick health checks during GRPO training:

  • rewards/*/mean should be > 0.15 within 20 steps — if it stays at 0, test your reward function standalone
  • reward_std should be > 0 on most steps — all-zero means no learning signal
  • entropy in 0.05-0.5 — below 0.01 suggests mode collapse
  • grad_norm in 0.001-1.0 — > 10 is unstable, 0.0 is expected when zero-advantage skip fires
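The checks above can be automated against a training log. A rough sketch follows; the name: value log format assumed here may not match your setup, so adapt the regex:

```python
import re

THRESHOLDS = {  # drawn from the health checks above
    "reward_std": lambda v: v > 0.0,   # all-zero means no learning signal
    "entropy": lambda v: v >= 0.01,    # below 0.01 suggests mode collapse
    "grad_norm": lambda v: v <= 10.0,  # above 10 is unstable
}

def scan_metrics(log_text):
    """Return the most recent value of each tracked metric found in the log."""
    metrics = {}
    for name in THRESHOLDS:
        matches = re.findall(rf"{name}[':\s]+(-?\d+(?:\.\d+)?)", log_text)
        if matches:
            metrics[name] = float(matches[-1])
    return metrics

def health_problems(metrics):
    """List every tracked metric that violates its healthy range."""
    return [name for name, ok in THRESHOLDS.items()
            if name in metrics and not ok(metrics[name])]
```

Run it periodically over the tee'd log file and alert when health_problems returns anything.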

Tip

Pipe training output to a log file: axolotl train config.yaml 2>&1 | tee /tmp/training.log

Configuration Reference

All GRPO-specific options live under the trl: key in your config. Standard training options (learning_rate, micro_batch_size, etc.) are set at the top level as usual.

Core GRPO

| Option | Type | Default | Description |
|---|---|---|---|
| use_vllm | bool | false | Enable vLLM for generation |
| vllm_mode | "server" or "colocate" | null | vLLM deployment mode |
| vllm_server_host | str | "0.0.0.0" | vLLM server hostname |
| vllm_server_port | int | 8000 | vLLM server port |
| vllm_server_timeout | int | null | Timeout (seconds) for vLLM responses |
| num_generations | int | null | Completions generated per prompt |
| generation_batch_size | int | null | Number of unique prompts per generation step |
| max_completion_length | int | null | Maximum tokens per completion |
| beta | float | null | KL penalty coefficient |
| num_iterations | int | null | Iterations per batch (mu in the GRPO paper) |
| epsilon | float | null | PPO clipping lower bound |
| epsilon_high | float | null | PPO clipping upper bound |
| loss_type | str | null | Loss formulation: grpo, bnpo, or dr_grpo |
| scale_rewards | bool | true | Normalize rewards by standard deviation |
| mask_truncated_completions | bool | false | Exclude truncated completions from loss |

Reward functions

| Option | Type | Default | Description |
|---|---|---|---|
| reward_funcs | list[str] | null | Import paths to reward functions or HF model IDs |
| reward_weights | list[float] | null | Relative weights for each reward function |
| multi_objective_aggregation | str | null | "sum_then_normalize" (GRPO) or "normalize_then_sum" (GDPO) |
| rollout_func | str | null | Import path to custom rollout function for OpenEnv-style tasks |

Generation parameters

| Option | Type | Default | Description |
|---|---|---|---|
| temperature | float | null | Sampling temperature |
| top_p | float | null | Nucleus sampling probability |
| top_k | int | null | Top-k sampling |
| min_p | float | null | Minimum probability threshold |
| repetition_penalty | float | null | Penalty for repeated tokens |
| generation_kwargs | dict | null | Additional vLLM SamplingParams (e.g., stop_token_ids) |
| chat_template_kwargs | dict | null | Chat template kwargs (e.g., {enable_thinking: false}) |
| vllm_guided_decoding_regex | str | null | Regex constraint for guided decoding |

Async pipeline

| Option | Type | Default | Description |
|---|---|---|---|
| use_data_producer | bool | false | Enable data producer protocol (required for async features) |
| async_prefetch | bool | false | Generate next batch in background thread |
| prefetch_depth | int | null | Number of batches to prefetch ahead |
| vllm_sync_interval | int | null | Sync LoRA weights to vLLM every N steps |
| vllm_lora_sync | bool | false | Use filesystem LoRA sync instead of NCCL merge |
| streaming_partial_batch | bool | null | Score prompt groups incrementally |
| streaming_min_groups | int | null | Minimum groups per streaming chunk |
| skip_zero_advantage_batches | bool | true | Skip micro-batches with zero learning signal |
| reward_num_workers | int | 1 | Subprocess workers for reward computation |
| vllm_enable_sleep_mode | bool | null | Offload vLLM weights when idle (colocate mode) |

Importance sampling

| Option | Type | Default | Description |
|---|---|---|---|
| vllm_importance_sampling_correction | bool | null | Enable IS correction for async distribution shift |
| importance_sampling_level | "token" or "sequence" | null | Granularity of IS ratios. Use token with Liger |
| vllm_importance_sampling_mode | str | null | token_mask, token_truncate, sequence_mask, or sequence_truncate |
| vllm_importance_sampling_cap | float | null | Cap C for IS ratio clipping/masking |
| off_policy_mask_threshold | float | null | KL threshold for off-policy sequence masking (OPSM) |
| use_bias_correction_kl | bool | null | Apply IS correction to KL divergence term |

Replay and re-roll

| Option | Type | Default | Description |
|---|---|---|---|
| replay_buffer_size | int | 0 | Max cached high-signal groups. 0 = disabled |
| replay_recompute_logps | bool | true | Recompute log-probs for replayed data with current model |
| reroll_start_fraction | float | 1.0 | Start re-rolling failed prompts after this fraction of training. 1.0 = disabled |
| reroll_max_groups | int | 1 | Max prompt groups to replace with re-rolls per batch |

Reference model

| Option | Type | Default | Description |
|---|---|---|---|
| sync_ref_model | bool | false | Periodically sync reference model with training model |
| ref_model_mixup_alpha | float | 0.9 | EMA coefficient for reference model sync |
| ref_model_sync_steps | int | 64 | Sync reference model every N steps |

Logging

| Option | Type | Default | Description |
|---|---|---|---|
| log_completions | bool | false | Log sample completions to W&B |
| num_completions_to_print | int | null | Number of completions to print per step |
| use_liger_loss | bool | null | Use Liger fused kernel for GRPO loss (reduces VRAM) |