# Training Stability & Debugging
This guide covers practical techniques for monitoring training health, diagnosing instability, and resolving common failures in both supervised fine-tuning (SFT) and reinforcement learning (GRPO/EBFT) workflows.
## Monitoring Training

### Key Metrics for SFT

Every SFT run should be monitored through at least these four metrics:
| Metric | What It Tells You | Healthy Range |
|---|---|---|
| `train/loss` | How well the model fits the training data | Decreasing; typically 0.5–2.0 for chat fine-tuning |
| `eval/loss` | Generalization performance | Tracks train loss with a small gap; divergence signals overfitting |
| `grad_norm` | Gradient magnitude | 0.1–10.0; spikes above 100 indicate instability |
| `learning_rate` | Current LR from the scheduler | Should follow the expected schedule (warmup, then decay) |
Enable W&B or TensorBoard from the start. Debugging a failed run without metrics is guesswork.

```yaml
wandb_project: my-project
wandb_run_id:  # optional, for resuming
logging_steps: 1
```

### Key Metrics for RL (GRPO)
GRPO training logs a richer set of metrics. These are the critical ones:
| Metric | Healthy Range | Red Flag |
|---|---|---|
| `rewards/<name>/mean` | > 0.15 within 20 steps | Stays at 0 – the reward function is broken or the task is too hard |
| `reward_std` | > 0 on most steps | Always 0 – no learning signal (all completions get the same reward) |
| `frac_reward_zero_std` | < 0.8 | 1.0 on every step – the zero-advantage skip fires constantly, so no gradient updates happen |
| `grad_norm` | 0.001–1.0 | 0.0 is acceptable occasionally (zero-advantage skip); > 10.0 is unstable |
| `entropy` | 0.05–0.5 | < 0.01 suggests mode collapse; > 1.0 suggests the model is not converging |
| `kl` | 0.0–0.5 | > 2.0 suggests the policy has diverged too far from the reference |
| `sampling/sampling_logp_difference/mean` | < 0.1 | > 1.0 means the policy has diverged far from the vLLM server weights |
| `sampling/importance_sampling_ratio/min` | > 0.1 | Near 0 indicates stale off-policy data; decrease `vllm_sync_interval` |
| `clip_ratio/region_mean` | < 0.1 | > 0.3 means PPO clipping is too aggressive |
| `completions/mean_length` | Task-dependent | Monotonically increasing to max length suggests reward hacking |
| `completions/clipped_ratio` | < 0.3 | > 0.8 means most completions hit `max_completion_length` – increase it |
For EBFT training, also monitor `ebft/alignment` (should trend upward; healthy 0.3–0.9), `ebft/diversity` (healthy 0.01–0.1; > 1.0 indicates mode collapse), and `ebft/cfm_loss` (should trend downward, < 10).
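The group-level reward statistics above are easy to reason about in isolation. Here is a minimal sketch of how `reward_std` and `frac_reward_zero_std` relate, assuming rewards arrive grouped per prompt; the grouping and function name are illustrative, not axolotl internals:

```python
from statistics import pstdev

def group_reward_stats(rewards_per_prompt):
    """rewards_per_prompt: one list of completion rewards per prompt group."""
    stds = [pstdev(group) for group in rewards_per_prompt]
    # a group with zero std contributes zero advantage, hence no gradient
    frac_zero_std = sum(s == 0.0 for s in stds) / len(stds)
    mean_std = sum(stds) / len(stds)
    return mean_std, frac_zero_std

# One group with spread (learning signal), one where every completion
# got the same reward (no signal -> zero-advantage skip for that group).
mean_std, frac_zero = group_reward_stats([[0.0, 1.0, 1.0, 0.0], [0.5] * 4])
print(mean_std, frac_zero)  # 0.25 0.5
```

If the zero-std fraction approaches 1.0, every prompt group is degenerate and no updates happen at all.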
## SFT Stability

### Loss Plateau

Symptom: Loss stops decreasing early in training, well above expected values.

Causes and fixes:
- Learning rate too low: Increase by 2–5x. Typical ranges: full fine-tune 1e-5 to 5e-5, LoRA 1e-4 to 3e-4.
- Insufficient warmup: Set `warmup_steps` to 5–10% of total steps. Too-aggressive learning at the start can push the model into a flat region.
- Data quality: Check that labels are correctly masked. Use `axolotl preprocess` and inspect tokenized samples to confirm only the target tokens are trainable.
- Weight decay too high: The default 0.01 is usually fine. Values above 0.1 can suppress learning in LoRA.
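The label-masking check can be run directly on preprocessed samples. A sketch assuming HF-style labels, where masked positions use `-100` (the ignore-index convention):

```python
IGNORE_INDEX = -100  # HF convention for positions excluded from the loss

def check_label_mask(labels):
    """Return how many positions are trainable, and their fraction."""
    trainable = sum(l != IGNORE_INDEX for l in labels)
    return trainable, trainable / len(labels)

# Example: 6 prompt tokens masked, 4 response tokens trainable.
labels = [-100] * 6 + [11, 12, 13, 14]
trainable, frac = check_label_mask(labels)
print(trainable, frac)  # 4 0.4
```

A fraction of 0.0 (everything masked) or 1.0 (prompt tokens training too) on most samples usually explains a plateau.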
### Loss Spikes

Symptom: Loss suddenly jumps by 2–10x and then (possibly) recovers.

Causes and fixes:
- Bad data samples: A single malformed or extremely long example can cause a spike. Set `sample_packing: false` temporarily and check whether spikes correlate with specific batches.
- Learning rate too high: Reduce by 2–5x, or increase warmup.
- Gradient accumulation mismatch: Effective batch size = `micro_batch_size * gradient_accumulation_steps * num_gpus`. Very large effective batch sizes amplify gradient noise.
- Mixed precision issues: With `bf16: true`, some operations can lose precision. If spikes are severe, try `fp32` for diagnosis.
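For the bad-sample cause, scanning token lengths for outliers before training is cheap. A sketch; the z-score cutoff is an arbitrary illustrative choice, not an axolotl feature:

```python
def find_outlier_lengths(lengths, z=4.0):
    """Return indices of samples whose token count is far above the mean."""
    mean = sum(lengths) / len(lengths)
    std = (sum((x - mean) ** 2 for x in lengths) / len(lengths)) ** 0.5
    return [i for i, x in enumerate(lengths) if std and (x - mean) / std > z]

# 99 normal samples plus one pathological 20k-token sample.
print(find_outlier_lengths([500] * 99 + [20000]))  # [99]
```

Dropping or truncating the flagged samples, then re-running, tells you quickly whether data length was the culprit.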
### Overfitting

Symptom: Train loss keeps decreasing while eval loss starts increasing.

Fixes:
- Increase `val_set_size` (e.g., 0.05) and monitor `eval/loss`.
- Reduce `num_epochs` or `max_steps`.
- Increase `weight_decay` (try 0.01–0.1).
- Use a smaller LoRA rank (`lora_r`). Typical values: 8–32.
- Increase dropout: `lora_dropout: 0.05`.
## RL/GRPO Stability

### Reward Never Increases

If `rewards/*/mean` stays at 0 for more than 20 steps:
1. Test the reward function standalone: Run it outside training with known inputs to verify it returns nonzero values.

   ```shell
   cd experiments && python -c "import my_rewards; print(my_rewards.accuracy_reward(...))"
   ```

2. Check dataset columns: The reward function receives `**kwargs` containing dataset columns. Verify the columns it needs (e.g., `answer`) are not removed by the dataset transform.

3. Check completion content: Enable `log_completions: true` in the `trl:` config and inspect the logged completions in W&B. If completions are empty or incoherent, the model may be too weak for the task.

4. Verify vLLM is serving the right model: Hit the vLLM health endpoint and confirm the model name matches your config.
### Entropy Collapse (Mode Collapse)

Symptom: `entropy` drops below 0.01; all completions become nearly identical.

Fixes:
- Increase `temperature` in the generation kwargs (try 0.8–1.0).
- Reduce the learning rate.
- Add a KL penalty term (the `beta` parameter in the GRPO config).
- Check that `num_generations` is sufficient (16+ gives better advantage estimates).
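To build intuition for the entropy thresholds, here is a minimal sketch of per-token softmax entropy over a toy 3-token vocabulary (illustrative only; the trainer computes this over the full vocabulary):

```python
import math

def token_entropy(logits):
    """Numerically stable softmax entropy at one token position."""
    m = max(logits)                               # subtract max before exp
    exps = [math.exp(l - m) for l in logits]
    z = sum(exps)
    return -sum((e / z) * math.log(e / z) for e in exps)

print(round(token_entropy([5.0, 5.0, 5.0]), 3))   # uniform: ln(3) ≈ 1.099
print(round(token_entropy([10.0, 0.0, 0.0]), 3))  # peaked: ~0.001, collapse
```

A near-uniform distribution gives high entropy; a sharply peaked one drives entropy toward 0, which is exactly the collapse signature described above.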
### IS Ratio Divergence

Symptom: `sampling/importance_sampling_ratio/min` drops near 0, or `sampling/sampling_logp_difference/mean` exceeds 1.0.

This means the policy has diverged significantly from the weights vLLM used for generation, so the importance sampling correction becomes unreliable.

Fixes:
- Decrease `vllm_sync_interval` (sync weights more often).
- Enable `off_policy_mask_threshold` (e.g., 0.5) to mask stale off-policy samples.
- Use `importance_sampling_level: token` for finer-grained correction.
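The mechanics behind these fixes can be sketched in a few lines. This assumes `off_policy_mask_threshold` masks tokens whose policy-vs-sampler log-prob gap exceeds the threshold, which is an interpretation for illustration, not the exact implementation:

```python
import math

def is_ratio_and_mask(policy_logps, sampler_logps, threshold=0.5):
    """Per-token importance ratio exp(logp_pi - logp_mu), masking stale tokens."""
    ratios = [math.exp(p - q) for p, q in zip(policy_logps, sampler_logps)]
    mask = [abs(p - q) <= threshold
            for p, q in zip(policy_logps, sampler_logps)]
    return ratios, mask

# Token 0: tiny log-prob gap, ratio near 1 (fresh, kept).
# Token 1: large gap, ratio near 0 (stale, masked out).
ratios, mask = is_ratio_and_mask([-1.0, -5.0], [-1.1, -1.0])
print([round(r, 3) for r in ratios], mask)  # [1.105, 0.018] [True, False]
```

Syncing weights more often keeps the log-prob gaps small, so fewer tokens fall outside the mask and the correction stays well-conditioned.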
### Gradient Norm Instability

Symptom: `grad_norm` oscillates wildly or regularly exceeds 10.0.

Fixes:
- Enable gradient clipping: `max_grad_norm: 1.0` (the default in most configs).
- Reduce the learning rate.
- Increase `gradient_accumulation_steps` to smooth out noisy batches.
- Check for NaN issues (see the next section).
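What `max_grad_norm: 1.0` does is rescale the global gradient so its L2 norm never exceeds the cap, the same idea as `torch.nn.utils.clip_grad_norm_`. A scalar sketch:

```python
def clip_grad_norm(grads, max_norm=1.0):
    """Rescale gradients so their global L2 norm is at most max_norm."""
    total = sum(g * g for g in grads) ** 0.5
    if total > max_norm:
        grads = [g * (max_norm / total) for g in grads]  # direction kept
    return grads, total

# Pre-clip norm 5.0 gets rescaled onto the unit sphere.
clipped, pre_norm = clip_grad_norm([3.0, 4.0], max_norm=1.0)
print([round(g, 3) for g in clipped], pre_norm)  # [0.6, 0.8] 5.0
```

Note that the logged `grad_norm` metric is the pre-clipping value, so clipping caps the update size without hiding the spikes from your charts.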
## NaN and Inf Handling

### Common Causes
| Cause | Where It Manifests | Detection |
|---|---|---|
| FP8 zero-scale division | Forward pass logits | `grad_norm: nan`; loss becomes NaN immediately |
| Gradient explosion | Backward pass | `grad_norm` spikes to inf, then the loss goes NaN |
| Bad data (empty sequences) | Logprob computation | NaN in specific batches only |
| Numerical overflow in log-softmax | Loss computation | Large negative logprobs cause `exp()` overflow |
### FP8-Specific NaN Issues

FP8 quantization (`fp8: true`) can produce NaN when the activation quantization kernel divides by `max(abs(x)) / 448`. If the input tensor is all zeros (e.g., padding positions), the scale becomes 0, causing division by zero.
Fixes applied in axolotl:

- The `act_quant_kernel` has a zero-guard: `s = tl.where(s == 0, 1.0, s)`.
- A safety net `nan_to_num(logits, nan=0.0)` is applied in `_get_per_token_logps_and_entropies`.
- Embedding padding is zero-padded for FP8 compatibility.
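The zero-guard logic is easy to replicate outside Triton. A plain-Python sketch of the scale computation (the function name is illustrative; the real kernel operates on tensor tiles):

```python
def fp8_act_scale(x, fp8_max=448.0):
    """Dynamic activation scale max|x| / 448, with the zero-guard applied."""
    s = max(abs(v) for v in x) / fp8_max
    # mirrors the kernel's `s = tl.where(s == 0, 1.0, s)`: an all-zero
    # tile (e.g. padding) would otherwise produce x / 0 -> NaN downstream
    return 1.0 if s == 0.0 else s

print(fp8_act_scale([0.0, 0.0, 0.0]))  # 1.0, instead of dividing by zero
print(fp8_act_scale([224.0, -112.0]))  # 0.5
```

Dividing zeros by the substituted scale of 1.0 still yields zeros, so the guard changes nothing for real activations while eliminating the NaN path.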
If you patch any Triton JIT kernel (e.g., the FP8 quantization kernels in transformers), you must clear the Triton cache for the changes to take effect:

```shell
rm -rf ~/.triton/cache
```

### General NaN Debugging Steps
1. Enable anomaly detection (slow, but pinpoints the source):

   ```python
   torch.autograd.set_detect_anomaly(True)
   ```

2. Check `grad_norm`: If it goes to NaN, the backward pass is the problem. If the loss is NaN but `grad_norm` was fine on the previous step, the forward pass is the problem.

3. Reduce to a single GPU and a single batch: This eliminates distributed-training variables.

4. Inspect the data: Print the batch that triggers the NaN. Look for empty sequences, extreme token IDs, or unexpected padding patterns.
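For the data-inspection step, a small recursive scanner saves eyeballing values by hand. This sketch walks plain Python containers of floats; for real tensors you would check `torch.isnan(t).any()` instead:

```python
import math

def find_nan(obj, path="batch"):
    """Return the path of the first NaN/Inf found in nested lists/dicts."""
    if isinstance(obj, dict):
        items = ((f"{path}.{k}", v) for k, v in obj.items())
    elif isinstance(obj, (list, tuple)):
        items = ((f"{path}[{i}]", v) for i, v in enumerate(obj))
    else:
        if isinstance(obj, float) and not math.isfinite(obj):
            return path
        return None
    for p, v in items:
        hit = find_nan(v, p)
        if hit:
            return hit
    return None

print(find_nan({"input_ids": [1.0, 2.0], "logits": [0.5, float("nan")]}))
# batch.logits[1]
```

Knowing the exact field and index that went non-finite usually tells you immediately whether the culprit is the data or the forward pass.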
## OOM Debugging

Out-of-memory errors are the most common training failure. Use this systematic approach, ordered from least to most disruptive:

### Step 1: Reduce Batch Size

This is the single highest-impact change. VRAM scales roughly linearly with batch size.
```yaml
micro_batch_size: 1              # Start here
gradient_accumulation_steps: 16  # Increase to maintain the effective batch size
```

For GRPO specifically, the logits tensor for the policy logprob computation can be very large: `batch_size * num_generations * seq_len * vocab_size` elements in bf16. For example, with `num_generations: 16` and `micro_batch_size: 8`, the logits tensor alone is:

```
8 * 16 * 2048 * 151936 * 2 bytes = ~75 GB (far too large)
```

Reduce `micro_batch_size` to 2–4 for GRPO.
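The arithmetic above generalizes to a one-line estimator, useful for sizing a run before launching it. It assumes bf16 logits at 2 bytes per element; the vocabulary size 151936 matches the worked example:

```python
def logits_vram_gb(micro_batch, num_generations, seq_len, vocab,
                   bytes_per_el=2):
    """Size in GiB of the [batch * generations, seq, vocab] logits tensor."""
    return micro_batch * num_generations * seq_len * vocab * bytes_per_el / 2**30

print(round(logits_vram_gb(8, 16, 2048, 151936)))  # 74 -> the ~75 GB case
print(round(logits_vram_gb(2, 16, 2048, 151936)))  # 19 -> why smaller
                                                   # micro_batch_size helps
```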
### Step 2: Enable Gradient Checkpointing

This trades compute for memory by recomputing activations during the backward pass instead of storing them.

```yaml
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false  # Recommended default
```

Some configurations require `use_reentrant: true`:

- DeepSpeed ZeRO-3 (non-reentrant checkpointing causes a `CheckpointError`)
- EBFT strided mode with flex_attention
### Step 3: Use Quantization

Load the base model in reduced precision:

```yaml
# 4-bit QLoRA
adapter: qlora
load_in_4bit: true

# 8-bit
load_in_8bit: true

# FP8 (saves ~50% model VRAM, same compute speed as bf16)
fp8: true
```

### Step 4: Reduce Sequence Length

```yaml
sequence_len: 1024  # Down from 2048 or 4096
```

For GRPO, also reduce `max_completion_length`. Memory scales quadratically with sequence length when using standard attention.
### Step 5: Use Flash Attention

This reduces attention memory from O(n^2) to O(n):

```yaml
flash_attention: true
```

### Step 6: Offload with DeepSpeed

For extreme cases, offload optimizer states or parameters to the CPU:

```yaml
deepspeed: deepspeed_configs/zero3_bf16.json
```

### Diagnosing the Specific Culprit

Use the `profiler_steps` config option to capture GPU memory snapshots:

```yaml
profiler_steps: [1, 2]
```

This generates PyTorch profiler traces you can inspect to see exactly which tensor allocation caused the OOM.
## Common Errors

| Error Message | Likely Cause | Fix |
|---|---|---|
| `exitcode: -9` | System RAM exhaustion | Reduce the dataset size, `dataset_num_proc`, or the number of data workers |
| `exitcode: -7` (DeepSpeed) | DeepSpeed version issue | `pip install -U deepspeed` |
| `CUDA out of memory` | GPU VRAM exhaustion | Follow the OOM debugging steps above |
| `RuntimeError: NCCL communicator was aborted` | GPU communication failure | See the NCCL docs; check `NCCL_DEBUG=INFO` output |
| `ValueError: Asking to pad but the tokenizer does not have a padding token` | Missing pad token | Add `special_tokens: { pad_token: "<\|endoftext\|>" }` to the config |
| `'DummyOptim' object has no attribute 'step'` | DeepSpeed on a single GPU | Remove the `deepspeed:` section from the config |
| `unable to load strategy X` then `None is not callable` | Reward module not importable | Run `cd experiments && python -c "import my_rewards"` to check |
| `generation_batch_size` not divisible by `num_generations` | `micro_batch_size` too small | Set `micro_batch_size >= num_generations` and make it divisible |
| `'weight' must be 2-D` | FSDP1 flattened parameters | Use `fsdp_version: 2` or skip `unwrap_model` when FSDP is enabled |
| `CheckpointError` (tensor count mismatch) | Non-reentrant checkpointing + ZeRO-3 or flex_attention | Set `use_reentrant: true` in `gradient_checkpointing_kwargs` |
| `BFloat16` `TypeError` during weight sync | NumPy does not support bf16 | Fixed in axolotl's `weight_serde.py` (automatic bf16-to-fp16 conversion) |
| `Content end boundary is before start boundary` | Chat template parsing issue | Check that `eos_token` matches the template; file a GitHub issue if it persists |
| `CAS service error` during data processing | HuggingFace XET issue | Set `export HF_HUB_DISABLE_XET=1` |
| Training hangs (multi-GPU) | FSDP + async prefetch deadlock | Set `async_prefetch: false` with FSDP |
## Profiling

### PyTorch Profiler

Axolotl supports PyTorch profiler integration via the config:

```yaml
profiler_steps: [1, 2, 3]
```

This captures profiler traces for the specified steps. View them in TensorBoard:

```shell
tensorboard --logdir output_dir/runs
```

Or open the `.json` trace file in `chrome://tracing`.
### CUDA Memory Snapshots

For detailed memory analysis, use PyTorch's memory snapshot API. Add this to your training script or use it interactively:

```python
import torch

# Enable memory history tracking
torch.cuda.memory._record_memory_history()

# ... run your training step ...

# Save a snapshot
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
```

Visualize it with PyTorch's memory visualizer:

```shell
python -m torch.cuda.memory._viz memory_snapshot.pickle
```

### Quick GPU Memory Check
During training, monitor GPU utilization in a separate terminal:

```shell
watch -n 1 nvidia-smi
```

For programmatic access within axolotl, the logged metrics `memory/max_alloc` and `memory/max_reserved` come from `torch.cuda.max_memory_allocated()` and `torch.cuda.max_memory_reserved()`. Note that these report PyTorch's view of memory, which may differ from `nvidia-smi` (see the FAQ).
## W&B and Logging

### Enabling Logging

```yaml
wandb_project: my-project
wandb_entity: my-team        # optional
wandb_run_id: run-123        # optional, for resuming
wandb_name: experiment-name  # optional
logging_steps: 1             # log every step (recommended for RL)
```

### Debug Logging
For detailed axolotl-internal debug output:

```shell
AXOLOTL_LOG_LEVEL=DEBUG axolotl train config.yaml 2>&1 | tee /tmp/training.log
```

Pipe training output to a log file so you can inspect it after the run:

```shell
axolotl train config.yaml 2>&1 | tee /tmp/my_run.log
```

### What Axolotl Logs
SFT metrics (logged every `logging_steps`):

- `train/loss`, `eval/loss` – training and validation loss
- `train/grad_norm` – gradient L2 norm (before clipping)
- `train/learning_rate` – current learning rate
- `memory/max_alloc`, `memory/max_reserved` – peak GPU memory

GRPO/RL metrics (logged every step):

- `rewards/<name>/mean`, `rewards/<name>/std` – per-reward-function statistics
- `reward`, `reward_std` – aggregated reward across all reward functions
- `frac_reward_zero_std` – fraction of prompt groups where all completions got the same reward
- `completions/mean_length`, `completions/min_length`, `completions/max_length` – completion token lengths
- `completions/clipped_ratio` – fraction of completions that hit the max length
- `completions/mean_terminated_length`, `completions/min_terminated_length`, `completions/max_terminated_length` – lengths of naturally terminated completions
- `kl` – KL divergence between the policy and the reference
- `entropy` – policy entropy (a measure of output diversity)
- `clip_ratio/region_mean`, `clip_ratio/low_mean`, `clip_ratio/high_mean` – PPO clipping statistics
- `sampling/sampling_logp_difference/mean`, `sampling/sampling_logp_difference/max` – log-probability difference between the policy and the sampling distribution
- `sampling/importance_sampling_ratio/min`, `sampling/importance_sampling_ratio/mean`, `sampling/importance_sampling_ratio/max` – IS ratio statistics for off-policy correction
- `num_tokens` – total tokens processed
### Reading W&B Charts

For a healthy GRPO run, expect to see:

- `reward/mean`: A gradual upward trend. It may start near 0 and reach 0.3–0.8 depending on task difficulty. Not monotonic – fluctuations are normal.
- `entropy`: A gradual decrease from initial values (often 0.3–0.6) as the model becomes more confident. It should not collapse to near zero.
- `grad_norm`: Mostly in the 0.001–1.0 range. Occasional 0.0 values are fine (zero-advantage skip). Persistent values above 10.0 need investigation.
- `kl`: Starts near 0 and grows slowly. If it shoots up rapidly, the policy is diverging from the reference.
- `completions/mean_length`: Should reflect the task's natural answer length. If it steadily increases to `max_completion_length`, the model may be reward-hacking by generating longer outputs.
reward/mean: Gradual upward trend. May start near 0 and reach 0.3–0.8 depending on task difficulty. Not monotonic – fluctuations are normal.entropy: Gradual decrease from initial values (often 0.3–0.6) as the model becomes more confident. Should not collapse to near-zero.grad_norm: Mostly in the 0.001–1.0 range. Occasional 0.0 values are fine (zero-advantage skip). Persistent values above 10.0 need investigation.kl: Starts near 0 and grows slowly. If it shoots up rapidly, the policy is diverging from the reference.completions/mean_length: Should reflect the task’s natural answer length. If it steadily increases tomax_completion_length, the model may be reward-hacking by generating longer outputs.