Training Stability & Debugging

Guide to monitoring, debugging, and stabilizing training runs in axolotl

This guide covers practical techniques for monitoring training health, diagnosing instability, and resolving common failures in both supervised fine-tuning (SFT) and reinforcement learning (GRPO/EBFT) workflows.

Monitoring Training

Key Metrics for SFT

Every SFT run should be monitored through at least these four metrics:

| Metric | What It Tells You | Healthy Range |
|---|---|---|
| train/loss | How well the model fits training data | Decreasing; typically 0.5–2.0 for chat fine-tuning |
| eval/loss | Generalization performance | Tracks train loss with a small gap; divergence signals overfitting |
| grad_norm | Gradient magnitude | 0.1–10.0; spikes above 100 indicate instability |
| learning_rate | Current LR from scheduler | Should follow the expected schedule (warmup, then decay) |
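A quick spike check over a logged grad_norm series can be sketched in plain Python; the threshold of 100 mirrors the table above, and the function name is illustrative:

```python
# Sketch: flag steps whose grad_norm exceeds an instability threshold.
# The 100.0 default echoes the "spikes above 100" guidance; tune per run.
def find_grad_spikes(grad_norms, threshold=100.0):
    return [step for step, g in enumerate(grad_norms) if g > threshold]

print(find_grad_spikes([0.5, 1.2, 150.0, 0.8]))  # [2]
```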
Tip: Set Up Logging Early

Enable W&B or TensorBoard from the start. Debugging a failed run without metrics is guesswork.

wandb_project: my-project
wandb_run_id:   # optional, for resuming
logging_steps: 1

Key Metrics for RL (GRPO)

GRPO training logs a richer set of metrics. These are the critical ones:

| Metric | Healthy Range | Red Flag |
|---|---|---|
| rewards/<name>/mean | > 0.15 within 20 steps | Stays at 0 – reward function is broken or task is too hard |
| reward_std | > 0 on most steps | Always 0 – no learning signal (all completions get the same reward) |
| frac_reward_zero_std | < 0.8 | 1.0 on every step – zero-advantage skip fires constantly, no gradient updates |
| grad_norm | 0.001–1.0 | 0.0 is acceptable occasionally (zero-advantage skip); > 10.0 is unstable |
| entropy | 0.05–0.5 | < 0.01 suggests mode collapse; > 1.0 suggests the model is not converging |
| kl | 0.0–0.5 | > 2.0 suggests the policy has diverged too far from the reference |
| sampling/sampling_logp_difference/mean | < 0.1 | > 1.0 means the policy has diverged far from the vLLM server weights |
| sampling/importance_sampling_ratio/min | > 0.1 | Near 0 indicates stale off-policy data; increase vllm_sync_interval |
| clip_ratio/region_mean | < 0.1 | > 0.3 means PPO clipping is too aggressive |
| completions/mean_length | Task-dependent | Monotonic increase to max length suggests reward hacking |
| completions/clipped_ratio | < 0.3 | > 0.8 means most completions hit max_completion_length – increase it |
Note: EBFT-Specific Metrics

For EBFT training, also monitor ebft/alignment (should trend upward, healthy 0.3–0.9), ebft/diversity (healthy 0.01–0.1; > 1.0 indicates mode collapse), and ebft/cfm_loss (should trend downward, < 10).

SFT Stability

Loss Plateau

Symptom: Loss stops decreasing early in training, well above expected values.

Causes and fixes:

  • Learning rate too low: Increase by 2–5x. Typical ranges: full fine-tune 1e-5 to 5e-5, LoRA 1e-4 to 3e-4.
  • Insufficient warmup: Set warmup_steps to 5–10% of total steps. Too-aggressive learning at the start can push the model into a flat region.
  • Data quality: Check that labels are correctly masked. Use axolotl preprocess and inspect tokenized samples to confirm only the target tokens are trainable.
  • Weight decay too high: Default 0.01 is usually fine. Values above 0.1 can suppress learning in LoRA.
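The label-masking check above can be sketched in plain Python, assuming the Hugging Face convention that positions excluded from the loss carry label -100:

```python
# Sketch: inspect how much of a tokenized sample is actually trainable.
# Assumes the HF convention of -100 for positions excluded from the loss.
def trainable_fraction(labels, ignore_index=-100):
    trainable = sum(1 for label in labels if label != ignore_index)
    return trainable / len(labels)

# Prompt tokens masked, answer tokens trainable:
labels = [-100, -100, -100, -100, 42, 17, 9, 2]
print(trainable_fraction(labels))  # 0.5
```

If the fraction is near 1.0 on chat data, the prompt is likely not being masked; if it is near 0.0, the loss has almost nothing to learn from.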

Loss Spikes

Symptom: Loss suddenly jumps by 2–10x then (possibly) recovers.

Causes and fixes:

  • Bad data samples: A single malformed or extremely long example can cause a spike. Temporarily set sample_packing: false and check whether spikes correlate with specific batches.
  • Learning rate too high: Reduce by 2–5x, or increase warmup.
  • Gradient accumulation mismatch: Effective batch size = micro_batch_size * gradient_accumulation_steps * num_gpus. Very large effective batch sizes amplify gradient noise.
  • Mixed precision issues: With bf16: true, some operations can lose precision. If spikes are severe, try fp32 for diagnosis.
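The effective-batch-size arithmetic above, as a one-line sketch:

```python
# Effective batch size = micro_batch_size * gradient_accumulation_steps * num_gpus
def effective_batch_size(micro_batch_size, grad_accum_steps, num_gpus):
    return micro_batch_size * grad_accum_steps * num_gpus

print(effective_batch_size(2, 8, 4))  # 64
```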

Overfitting

Symptom: Train loss keeps decreasing but eval loss starts increasing.

Fixes:

  • Increase val_set_size (e.g., 0.05) and monitor eval/loss.
  • Reduce num_epochs or max_steps.
  • Increase weight_decay (try 0.01–0.1).
  • Use a smaller LoRA rank (lora_r). Typical values: 8–32.
  • Increase dropout: lora_dropout: 0.05.

RL/GRPO Stability

Reward Never Increases

If rewards/*/mean stays at 0 for more than 20 steps:

  1. Test reward function standalone: Run it outside training with known inputs to verify it returns nonzero values.

    cd experiments && python -c "import my_rewards; print(my_rewards.accuracy_reward(...))"
  2. Check dataset columns: The reward function receives **kwargs containing dataset columns. Verify the columns it needs (e.g., answer) are not removed by the dataset transform.

  3. Check completion content: Enable log_completions: true in the trl: config and inspect logged completions in W&B. If completions are empty or incoherent, the model may be too weak for the task.

  4. Verify vLLM is serving the right model: Hit the vLLM health endpoint and confirm the model name matches your config.
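Step 1 can look like the following standalone check. The reward function name, signature, and matching rule here are hypothetical placeholders; adapt them to your own module:

```python
# Hypothetical reward function in the GRPO style: it receives completions
# plus dataset columns (e.g. `answer`) via **kwargs and returns one float
# per completion. Names and logic are illustrative, not axolotl's API.
def accuracy_reward(completions, answer, **kwargs):
    return [1.0 if a in c else 0.0 for c, a in zip(completions, answer)]

rewards = accuracy_reward(
    completions=["The answer is 42.", "I am not sure."],
    answer=["42", "42"],
)
print(rewards)  # [1.0, 0.0]
```

If a check like this returns all zeros on inputs you know should score, the bug is in the reward function, not the training loop.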

Entropy Collapse (Mode Collapse)

Symptom: entropy drops below 0.01; all completions become nearly identical.

Fixes:

  • Increase temperature in generation kwargs (try 0.8–1.0).
  • Reduce learning rate.
  • Add a KL penalty term (beta parameter in GRPO config).
  • Check that num_generations is sufficient (16+ gives better advantage estimates).
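To build intuition for the entropy thresholds above, here is a minimal sketch of Shannon entropy (natural log) over a next-token distribution:

```python
import math

# Sketch: Shannon entropy (natural log) of a probability distribution.
# A peaked distribution (mode collapse) has entropy near 0; a spread
# distribution has higher entropy.
def entropy(probs):
    return -sum(p * math.log(p) for p in probs if p > 0)

peaked = [0.97, 0.01, 0.01, 0.01]
uniform = [0.25, 0.25, 0.25, 0.25]
print(round(entropy(peaked), 3))   # 0.168 -- collapse territory
print(round(entropy(uniform), 3))  # 1.386 (= ln 4)
```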

IS Ratio Divergence

Symptom: sampling/importance_sampling_ratio/min drops near 0, or sampling/sampling_logp_difference/mean exceeds 1.0.

This means the policy has diverged significantly from the weights used by vLLM for generation. The importance sampling correction becomes unreliable.

Fixes:

  • Decrease vllm_sync_interval (sync weights more often).
  • Enable off_policy_mask_threshold (e.g., 0.5) to mask stale off-policy samples.
  • Use importance_sampling_level: token for finer-grained correction.
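The masking idea behind off_policy_mask_threshold can be sketched as follows; the 0.5 threshold echoes the example above, and the exact semantics in axolotl may differ:

```python
import math

# Sketch: importance-sampling ratio between the current policy and the
# (possibly stale) sampling distribution, plus a threshold mask that
# drops stale samples. Illustrative only, not axolotl's exact code.
def is_ratio(policy_logp, sampling_logp):
    return math.exp(policy_logp - sampling_logp)

def keep_sample(policy_logp, sampling_logp, threshold=0.5):
    return is_ratio(policy_logp, sampling_logp) >= threshold

print(keep_sample(-1.0, -1.1))  # True  (ratio ~1.1, still near on-policy)
print(keep_sample(-5.0, -1.0))  # False (ratio ~0.018, stale sample)
```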

Gradient Norm Instability

Symptom: grad_norm oscillates wildly or exceeds 10.0 regularly.

Fixes:

  • Enable gradient clipping: max_grad_norm: 1.0 (default in most configs).
  • Reduce learning rate.
  • Increase gradient_accumulation_steps to smooth out noisy batches.
  • Check for NaN issues (see next section).

NaN and Inf Handling

Common Causes

| Cause | Where It Manifests | Detection |
|---|---|---|
| FP8 zero-scale division | Forward pass logits | grad_norm: nan; loss becomes NaN immediately |
| Gradient explosion | Backward pass | grad_norm spikes to inf, then loss goes NaN |
| Bad data (empty sequences) | Logprob computation | NaN in specific batches only |
| Numerical overflow in log-softmax | Loss computation | Large negative logprobs cause exp() overflow |

FP8-Specific NaN Issues

FP8 quantization (fp8: true) can produce NaN when the activation quantization kernel divides by max(abs(x)) / 448. If the input tensor is all zeros (e.g., padding positions), the scale becomes 0, causing division by zero.

Fixes applied in axolotl:

  • The act_quant_kernel has a zero-guard: s = tl.where(s == 0, 1.0, s).
  • A safety net nan_to_num(logits, nan=0.0) is applied in _get_per_token_logps_and_entropies.
  • Embedding padding is zero-padded for FP8 compatibility.
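The zero-guard can be illustrated in plain Python (the real Triton kernel operates elementwise on tiles; this shows just the scalar idea, and 448 is the FP8 E4M3 maximum):

```python
# Sketch of the FP8 activation-scale zero-guard: an all-zero input tile
# (e.g. padding positions) would otherwise yield scale 0, then a
# division by zero during quantization.
FP8_E4M3_MAX = 448.0

def act_quant_scale(abs_max):
    s = abs_max / FP8_E4M3_MAX
    return 1.0 if s == 0 else s  # mirrors tl.where(s == 0, 1.0, s)

print(act_quant_scale(0.0))    # 1.0 (guarded, no NaN downstream)
print(act_quant_scale(448.0))  # 1.0
```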
Important: After Modifying Triton Kernels

If you patch any Triton JIT kernel (e.g., the FP8 quantization kernels in transformers), you must clear the Triton cache for changes to take effect:

rm -rf ~/.triton/cache

General NaN Debugging Steps

  1. Enable anomaly detection (slow, but pinpoints the source):

    torch.autograd.set_detect_anomaly(True)
  2. Check grad_norm: If it goes to NaN, the backward pass is the problem. If loss is NaN but grad_norm was fine on the previous step, the forward pass is the problem.

  3. Reduce to single GPU, single batch: Eliminate distributed training variables.

  4. Inspect data: Print the batch that triggers NaN. Look for empty sequences, extreme token IDs, or unexpected padding patterns.

OOM Debugging

Out-of-memory errors are the most common training failure. Use this systematic approach, from least to most disruptive:

Step 1: Reduce Batch Size

The single highest-impact change. VRAM scales roughly linearly with batch size.

micro_batch_size: 1              # Start here
gradient_accumulation_steps: 16  # Increase to maintain effective batch size

For GRPO specifically, the logits tensor for policy logprob computation can be very large: batch_size * num_generations * seq_len * vocab_size elements in bf16. For example, with num_generations: 16 and micro_batch_size: 8, the logits tensor alone is:

8 * 16 * 2048 * 151936 * 2 bytes ≈ 74 GiB  (far too large)

Reduce micro_batch_size to 2–4 for GRPO.
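The arithmetic above as a quick sanity check (bf16 is 2 bytes per element; 151936 is the vocabulary size already used in the example):

```python
# Sketch: estimate the GRPO logits-tensor footprint in GiB.
def logits_gib(micro_batch, num_generations, seq_len, vocab_size, bytes_per_el=2):
    return micro_batch * num_generations * seq_len * vocab_size * bytes_per_el / 1024**3

print(round(logits_gib(8, 16, 2048, 151936), 1))  # 74.2 -- far too large
print(round(logits_gib(2, 16, 2048, 151936), 1))  # 18.5 with micro_batch_size: 2
```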

Step 2: Enable Gradient Checkpointing

Trades compute for memory by recomputing activations during the backward pass instead of storing them.

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false     # Recommended default
Warning: Reentrant Checkpointing Exceptions

Some configurations require use_reentrant: true:

  • DeepSpeed ZeRO-3 (non-reentrant causes CheckpointError)
  • EBFT strided mode with flex_attention

Step 3: Use Quantization

Load the base model in reduced precision:

# 4-bit QLoRA
adapter: qlora
load_in_4bit: true

# 8-bit
load_in_8bit: true

# FP8 (saves ~50% model VRAM, same compute speed as bf16)
fp8: true

Step 4: Reduce Sequence Length

sequence_len: 1024     # Down from 2048 or 4096

For GRPO, also reduce max_completion_length. Memory scales quadratically with sequence length when using standard attention.
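The quadratic term comes from the attention score matrix. A rough sketch, using 32 heads as an illustrative figure:

```python
# Sketch: bytes for the attention score matrix that standard attention
# materializes per layer (flash attention avoids storing it entirely).
def attn_scores_bytes(batch, num_heads, seq_len, bytes_per_el=2):
    return batch * num_heads * seq_len * seq_len * bytes_per_el

# Doubling seq_len quadruples this term:
print(attn_scores_bytes(1, 32, 2048) / 1024**2)  # 256.0 MiB
print(attn_scores_bytes(1, 32, 4096) / 1024**2)  # 1024.0 MiB
```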

Step 5: Use Flash Attention

Reduces attention memory from O(n^2) to O(n):

flash_attention: true

Step 6: Offload with DeepSpeed

For extreme cases, offload optimizer states or parameters to CPU:

deepspeed: deepspeed_configs/zero3_bf16.json

Diagnosing the Specific Culprit

Use the profiler_steps config option to capture GPU memory snapshots:

profiler_steps: [1, 2]

This generates PyTorch profiler traces you can inspect to see exactly which tensor allocation caused the OOM.

Common Errors

| Error Message | Likely Cause | Fix |
|---|---|---|
| exitcode: -9 | System RAM exhaustion | Reduce dataset size, dataset_num_proc, or number of data workers |
| exitcode: -7 (DeepSpeed) | DeepSpeed version issue | pip install -U deepspeed |
| CUDA out of memory | GPU VRAM exhaustion | Follow the OOM debugging steps above |
| RuntimeError: NCCL communicator was aborted | GPU communication failure | See NCCL docs; check NCCL_DEBUG=INFO output |
| ValueError: Asking to pad but the tokenizer does not have a padding token | Missing pad token | Add special_tokens: { pad_token: "<\|endoftext\|>" } to the config |
| 'DummyOptim' object has no attribute 'step' | DeepSpeed on a single GPU | Remove the deepspeed: section from the config |
| unable to load strategy X then None is not callable | Reward module not importable | Run cd experiments && python -c "import my_rewards" to check |
| generation_batch_size not divisible by num_generations | micro_batch_size too small | Set micro_batch_size >= num_generations and make it divisible |
| 'weight' must be 2-D | FSDP1 flattened parameters | Use fsdp_version: 2, or skip unwrap_model when FSDP is enabled |
| CheckpointError (tensor count mismatch) | Non-reentrant checkpointing + ZeRO-3 or flex_attention | Set use_reentrant: true in gradient_checkpointing_kwargs |
| BFloat16 TypeError during weight sync | NumPy does not support bf16 | Fixed in axolotl's weight_serde.py (automatic bf16-to-fp16 conversion) |
| Content end boundary is before start boundary | Chat template parsing issue | Check that eos_token matches the template; file a GitHub issue if persistent |
| CAS service error during data processing | HuggingFace XET issue | Set export HF_HUB_DISABLE_XET=1 |
| Training hangs (multi-GPU) | FSDP + async prefetch deadlock | Set async_prefetch: false with FSDP |

Profiling

PyTorch Profiler

Axolotl supports PyTorch profiler integration via the config:

profiler_steps: [1, 2, 3]

This captures profiler traces for the specified steps. View them in TensorBoard:

tensorboard --logdir output_dir/runs

Or open the .json trace file in chrome://tracing.

CUDA Memory Snapshots

For detailed memory analysis, use PyTorch’s memory snapshot API. Add this to your training script or use it interactively:

import torch

# Enable memory history tracking
torch.cuda.memory._record_memory_history()

# ... run your training step ...

# Save snapshot
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")

Visualize with PyTorch's memory visualizer: drag the pickle into https://pytorch.org/memory_viz, or render an HTML report from the command line (module path may vary across PyTorch versions):

python -m torch.cuda._memory_viz trace_plot memory_snapshot.pickle -o memory_snapshot.html

Quick GPU Memory Check

During training, monitor GPU utilization in a separate terminal:

watch -n 1 nvidia-smi

For programmatic access within axolotl, the logged metrics memory/max_alloc and memory/max_reserved come from torch.cuda.max_memory_allocated() and torch.cuda.max_memory_reserved(). Note these report PyTorch’s view of memory, which may differ from nvidia-smi (see FAQ).

W&B and Logging

Enabling Logging

wandb_project: my-project
wandb_entity: my-team          # optional
wandb_run_id: run-123          # optional, for resuming
wandb_name: experiment-name    # optional
logging_steps: 1               # log every step (recommended for RL)

Debug Logging

For detailed axolotl-internal debug output:

AXOLOTL_LOG_LEVEL=DEBUG axolotl train config.yaml 2>&1 | tee /tmp/training.log
Tip: Always Log to a File

Pipe training output to a log file so you can inspect it after the run:

axolotl train config.yaml 2>&1 | tee /tmp/my_run.log

What Axolotl Logs

SFT metrics (logged every logging_steps):

  • train/loss, eval/loss – training and validation loss
  • train/grad_norm – gradient L2 norm (before clipping)
  • train/learning_rate – current learning rate
  • memory/max_alloc, memory/max_reserved – peak GPU memory

GRPO/RL metrics (logged every step):

  • rewards/<name>/mean, rewards/<name>/std – per-reward-function statistics
  • reward, reward_std – aggregated reward across all reward functions
  • frac_reward_zero_std – fraction of prompt groups where all completions got the same reward
  • completions/mean_length, completions/min_length, completions/max_length – completion token lengths
  • completions/clipped_ratio – fraction of completions that hit the max length
  • completions/mean_terminated_length, completions/min_terminated_length, completions/max_terminated_length – lengths of naturally terminated completions
  • kl – KL divergence between policy and reference
  • entropy – policy entropy (measure of output diversity)
  • clip_ratio/region_mean, clip_ratio/low_mean, clip_ratio/high_mean – PPO clipping statistics
  • sampling/sampling_logp_difference/mean, sampling/sampling_logp_difference/max – log-probability difference between policy and sampling distribution
  • sampling/importance_sampling_ratio/min, sampling/importance_sampling_ratio/mean, sampling/importance_sampling_ratio/max – IS ratio statistics for off-policy correction
  • num_tokens – total tokens processed

Reading W&B Charts

For a healthy GRPO run, expect to see:

  1. reward: Gradual upward trend. May start near 0 and reach 0.3–0.8 depending on task difficulty. Not monotonic – fluctuations are normal.
  2. entropy: Gradual decrease from initial values (often 0.3–0.6) as the model becomes more confident. Should not collapse to near-zero.
  3. grad_norm: Mostly in the 0.001–1.0 range. Occasional 0.0 values are fine (zero-advantage skip). Persistent values above 10.0 need investigation.
  4. kl: Starts near 0 and grows slowly. If it shoots up rapidly, the policy is diverging from the reference.
  5. completions/mean_length: Should reflect the task’s natural answer length. If it steadily increases to max_completion_length, the model may be reward-hacking by generating longer outputs.