Multi-GPU
This guide covers advanced training configurations for multi-GPU setups using Axolotl.
1 Overview
Axolotl supports several methods for multi-GPU training:
- DeepSpeed (recommended)
- FSDP (Fully Sharded Data Parallel)
- Sequence parallelism
- FSDP + QLoRA
2 DeepSpeed
DeepSpeed is the recommended approach for multi-GPU training due to its stability and performance. It provides various optimization levels through ZeRO stages.
2.1 Configuration
Add to your YAML config:
```yaml
deepspeed: deepspeed_configs/zero1.json
```
2.2 Usage
```bash
# Fetch deepspeed configs (if not already present)
axolotl fetch deepspeed_configs

# Passing arg via config
axolotl train config.yml

# Passing arg via cli
axolotl train config.yml --deepspeed deepspeed_configs/zero1.json
```
2.3 ZeRO Stages
We provide default configurations for:
- ZeRO Stage 1 (`zero1.json`)
- ZeRO Stage 1 with torch compile (`zero1_torch_compile.json`)
- ZeRO Stage 2 (`zero2.json`)
- ZeRO Stage 3 (`zero3.json`)
- ZeRO Stage 3 with bf16 (`zero3_bf16.json`)
- ZeRO Stage 3 with bf16 and CPU offload of params (`zero3_bf16_cpuoffload_params.json`)
- ZeRO Stage 3 with bf16 and CPU offload of params and optimizer (`zero3_bf16_cpuoffload_all.json`)
For best performance, choose the configuration that offloads the least while still allowing your training run to fit in VRAM. Start with Stage 1, and move to Stage 2 and then Stage 3 only if you run out of memory.
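For example, if ZeRO-1 or ZeRO-2 still runs out of memory, you can point your config at one of the ZeRO-3 offload files. The fragment below is an illustrative sketch; the batch-size values are placeholders, not recommendations:

```yaml
# Illustrative config fragment: stepping up to ZeRO-3 with CPU offload of parameters.
# micro_batch_size and gradient_accumulation_steps are placeholder values to tune for your hardware.
deepspeed: deepspeed_configs/zero3_bf16_cpuoffload_params.json
bf16: true
micro_batch_size: 1
gradient_accumulation_steps: 8
```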
3 FSDP
3.1 Basic FSDP Configuration
```yaml
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_offload_params: true
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
4 Sequence parallelism
We support sequence parallelism (SP) via the ring-flash-attention project. SP splits each sequence across multiple GPUs, which is useful when a single long sequence causes OOM errors during training.
First, install `ring-flash-attn`, preferably via `pip install axolotl[ring-flash-attn]`, or from source with `pip install .[ring-flash-attn]`.
Your Axolotl YAML config should contain the following lines:
```yaml
sequence_parallel_degree: 4  # Split each sequence into 4 parts, one per GPU
flash_attention: true  # Required with sequence parallelism
# Optional; strides across the key dimension. Larger values use more memory but will make training faster.
heads_k_stride: 1
```
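For instance, on an 8-GPU node you might pair SP with a long context. The fragment below is an illustrative sketch (the sequence length is a placeholder); `sequence_parallel_degree` should evenly divide the total number of GPUs:

```yaml
# Illustrative sketch: 8 GPUs split into 2 data-parallel groups of 4 GPUs each.
# sequence_len is a placeholder value.
sequence_len: 32768
sequence_parallel_degree: 4
flash_attention: true
```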
See our dedicated guide for more details.
4.1 FSDP + QLoRA
For combining FSDP with QLoRA, see our dedicated guide.
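As a rough sketch of what that combination looks like, the fragment below pairs a QLoRA adapter with FSDP settings commonly used alongside 4-bit loading; treat it as illustrative and defer to the dedicated guide for the authoritative configuration:

```yaml
# Illustrative FSDP + QLoRA fragment; consult the dedicated guide for authoritative settings.
adapter: qlora
load_in_4bit: true

fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_offload_params: true
  fsdp_cpu_ram_efficient_loading: true
  fsdp_use_orig_params: false
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
```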
5 Performance Optimization
5.1 Liger Kernel Integration
Liger kernels provide fused Triton implementations of common LLM operations and can reduce memory usage while improving multi-GPU throughput. Please see the docs for more info.
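If your model architecture is supported, enabling the integration typically looks something like the sketch below; the available `liger_*` flags vary by model, so check the Liger docs for your architecture:

```yaml
# Illustrative sketch of enabling the Liger plugin; available liger_* flags vary by model.
plugins:
  - axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_fused_linear_cross_entropy: true
```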
6 Troubleshooting
6.1 NCCL Issues
For NCCL-related problems, see our NCCL troubleshooting guide.
6.2 Common Problems
For out-of-memory (OOM) errors (see the example fragment after this list):
- Reduce `micro_batch_size`
- Reduce `eval_batch_size`
- Adjust `gradient_accumulation_steps`
- Consider using a higher ZeRO stage

For training speed or stability issues:
- Start with DeepSpeed ZeRO-2
- Monitor loss values
- Check learning rates
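A config fragment combining the memory-related knobs above might look roughly like the following; all values are placeholders that depend on your model and hardware:

```yaml
# Illustrative OOM-mitigation settings; all values are placeholders to tune for your setup.
micro_batch_size: 1
eval_batch_size: 1
gradient_accumulation_steps: 4
# gradient_checkpointing is an additional common memory saver (not listed above).
gradient_checkpointing: true
deepspeed: deepspeed_configs/zero3_bf16.json
```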
For more detailed troubleshooting, see our debugging guide.