Optimizations Guide
Axolotl includes numerous optimizations to speed up training, reduce memory usage, and handle large models.
This guide provides a high-level overview and directs you to the detailed documentation for each feature.
Speed Optimizations
These optimizations focus on increasing training throughput and reducing total training time.
Sample Packing
Improves GPU utilization by combining multiple short sequences into a single packed sequence for training. This requires enabling one of the attention implementations below.
- Config: `sample_packing: true`
- Learn more: Sample Packing
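A minimal config sketch is shown below; the `sequence_len` value is illustrative, and packing should be paired with one of the attention backends listed in the next section:

```yaml
# Pack multiple short examples into each sequence_len-sized block.
sample_packing: true
sequence_len: 4096        # illustrative; set to your target context length
pad_to_sequence_len: true # commonly paired with packing
```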
Attention Implementations
Using an optimized attention implementation is critical for training speed.
- Flash Attention 2 (Recommended): `flash_attention: true`. The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check AMD Support.
- Flex Attention: `flex_attention: true`.
- SDP Attention: `sdp_attention: true`. PyTorch's native implementation.
- Xformers: `xformers_attention: true`. Works with FP16.
Note: You should only enable one attention backend.
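For example, a config would set exactly one of the flags and leave the others unset (a sketch; pick whichever backend matches your hardware):

```yaml
# Enable exactly one attention backend.
flash_attention: true
# flex_attention: true
# sdp_attention: true
# xformers_attention: true
```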
LoRA Optimizations
Leverages optimized kernels to accelerate LoRA training and reduce memory usage.
- Learn more: LoRA Optimizations Documentation
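As a hedged illustration, the optimized kernels are toggled with flags like the following; the key names (`lora_mlp_kernel`, `lora_qkv_kernel`, `lora_o_kernel`) should be verified against the linked documentation for your Axolotl version and model architecture:

```yaml
adapter: lora
# Optimized Triton kernels for the MLP and attention projections
# (availability depends on model architecture and Axolotl version).
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true
```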
Memory Optimizations
These techniques help you fit larger models or use bigger batch sizes on your existing hardware.
Parameter Efficient Finetuning (LoRA & QLoRA)
Drastically reduces memory by training a small set of “adapter” parameters instead of the full model. This is the most common and effective memory-saving technique.
- Examples: Find configs with `lora` or `qlora` in the examples directory.
- Config Reference: See `adapter`, `load_in_4bit`, and `load_in_8bit` in the Configuration Reference.
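A typical QLoRA sketch, where the rank, alpha, and dropout values are illustrative defaults rather than recommendations:

```yaml
# QLoRA: 4-bit base model + trainable LoRA adapters.
adapter: qlora
load_in_4bit: true
lora_r: 16                # illustrative rank
lora_alpha: 32            # illustrative scaling factor
lora_dropout: 0.05
lora_target_linear: true  # apply adapters to all linear layers
```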
Gradient Checkpointing & Activation Offloading
These techniques save VRAM by changing how activations are handled.
- Gradient Checkpointing: re-computes activations during the backward pass, trading compute time for VRAM.
- Activation Offloading: moves activations to CPU RAM or disk, trading I/O overhead for VRAM.
- Learn more: Gradient Checkpointing and Offloading Docs
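A sketch of how these are enabled; the `activation_offloading` key name is an assumption based on the linked docs, so check them for the exact options in your version:

```yaml
# Trade recompute time for VRAM.
gradient_checkpointing: true
# Move activations to CPU RAM (key name per the offloading docs; verify for your version).
activation_offloading: true
```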
Cut Cross Entropy (CCE)
Reduces VRAM usage by using an optimized cross-entropy loss calculation.
- Learn more: Custom Integrations - CCE
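CCE is enabled through Axolotl's plugin system; a sketch is below, with the plugin path taken from the integrations docs and worth verifying there:

```yaml
plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
cut_cross_entropy: true
```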
Liger Kernels
Provides efficient Triton kernels to improve training speed and reduce memory usage.
- Learn more: Custom Integrations - Liger Kernels
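Liger kernels are likewise enabled via a plugin plus per-kernel flags; flag names have changed across Axolotl versions, so treat the ones below as assumptions to check against the integrations docs:

```yaml
plugins:
  - axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_fused_linear_cross_entropy: true
```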
Long Context Models
Techniques to train models on sequences longer than their original context window.
RoPE Scaling
Extends a model’s context window by interpolating its Rotary Position Embeddings.
- Config: Pass the `rope_scaling` config under `overrides_of_model_config:`. For the correct RoPE scaling values, check the respective model's configuration.
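For example, a linear-interpolation override might look like the sketch below; the `type` and `factor` fields are illustrative, and the accepted fields depend on the model's implementation:

```yaml
overrides_of_model_config:
  rope_scaling:
    type: linear   # or another scheme supported by the model
    factor: 2.0    # illustrative: 2x the original context window
```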
Sequence Parallelism
Splits long sequences across multiple GPUs, enabling training with sequence lengths that would not fit on a single device.
- Learn more: Sequence Parallelism Documentation
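A sketch of the relevant setting; the key name `sequence_parallel_degree` follows the sequence parallelism docs and should be confirmed there, and the degree must divide your GPU count:

```yaml
# Split each sequence across 4 GPUs.
sequence_parallel_degree: 4
flash_attention: true   # typically paired with an optimized attention backend
```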
Arctic Long Sequence Training (ALST)
ALST is a recipe that combines several techniques to train long-context models efficiently. It typically involves:
- TiledMLP to reduce memory usage in MLP layers.
- Tiled loss functions (such as CCE).
- Activation offloading to CPU.
- Example: ALST Example Configuration
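The linked example is the authoritative reference; purely as a hypothetical sketch of how those pieces might be combined in one config, with key names such as `tiled_mlp` being assumptions to verify against that example:

```yaml
# Hypothetical ALST-style combination; see the linked example for the real key names.
tiled_mlp: true              # assumption: TiledMLP toggle
plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
cut_cross_entropy: true      # tiled/optimized loss
activation_offloading: true  # assumption: offload activations to CPU
```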
Large Models (Distributed Training)
To train models that don’t fit on a single GPU, you’ll need to use a distributed training strategy like FSDP or DeepSpeed. These frameworks shard the model weights, gradients, and optimizer states across multiple GPUs and nodes.
- Learn more: Multi-GPU Guide
- Learn more: Multi-Node Guide
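As a sketch, you either point Axolotl at one of its bundled DeepSpeed configs or enable FSDP; the file path and the layer class to wrap below are illustrative and model-dependent:

```yaml
# Option A: DeepSpeed ZeRO-3 (path assumes the deepspeed_configs/ directory shipped with Axolotl).
deepspeed: deepspeed_configs/zero3_bf16.json

# Option B: FSDP (do not combine with DeepSpeed).
# fsdp:
#   - full_shard
#   - auto_wrap
# fsdp_config:
#   fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer  # model-specific
```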
N-D Parallelism (Beta)
For advanced scaling, Axolotl allows you to compose different parallelism techniques (e.g., data, tensor, and sequence parallelism). This is a powerful approach for training extremely large models by overcoming multiple bottlenecks at once.
- Learn more: N-D Parallelism Guide
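As a hypothetical sketch of composing parallelism degrees; these key names follow the N-D parallelism guide but may change while the feature is in beta, and the product of the degrees should equal your total GPU count:

```yaml
# Hypothetical 2 (tensor) x 2 (context) x 2 (data shard) layout on 8 GPUs.
tensor_parallel_size: 2
context_parallel_size: 2
dp_shard_size: 2
```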
Quantization
Techniques to reduce the precision of model weights for memory savings.
4-bit Training (QLoRA)
The recommended approach for quantization-based training. It loads the base model in 4-bit using `bitsandbytes` and then trains QLoRA adapters. See Adapter Finetuning for details.
FP8 Training
Enables training with 8-bit floating point precision on supported hardware (e.g., NVIDIA Hopper series GPUs) for significant speed and memory gains.
- Example: Llama 3 FP8 FSDP Example
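A sketch of enabling it; the `fp8` flag follows the linked example, which pairs it with FSDP on Hopper-class hardware:

```yaml
# Requires Hopper-class (e.g., H100) GPUs and a supported PyTorch build.
fp8: true
```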
Quantization Aware Training (QAT)
Simulates quantization effects during training, helping the model adapt and potentially improving the final accuracy of the quantized model.
- Learn more: QAT Documentation
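A hedged sketch of a QAT block; the field names and dtypes are assumptions based on the QAT docs and should be verified there:

```yaml
qat:
  weight_dtype: int8      # assumption: quantized weight dtype
  activation_dtype: int8  # assumption: fake-quantized activation dtype
  group_size: 32          # assumption: quantization group size
```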
GPTQ
Allows you to finetune LoRA adapters on top of a model that has already been quantized using the GPTQ method.
- Example: GPTQ LoRA Example
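A sketch of the relevant flags, assuming the base model has already been quantized with GPTQ; the `gptq` flag name follows the linked example:

```yaml
gptq: true     # base model is GPTQ-quantized
adapter: lora  # train LoRA adapters on top of the quantized weights
```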