Attention

Supported attention modules in Axolotl

SDP Attention

This uses scaled dot-product attention (SDPA), the attention implementation built into PyTorch.

sdp_attention: true
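
For reference, this is roughly the PyTorch call that this backend wraps. A minimal standalone sketch, independent of Axolotl; the tensor shapes are illustrative:

import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, num_heads, seq_len, head_dim)
q, k, v = (torch.randn(1, 8, 128, 64) for _ in range(3))

# PyTorch dispatches to the best available backend (flash, memory-efficient, or math)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 8, 128, 64])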

For more details: PyTorch docs

Flash Attention 2

Uses fast, memory-efficient fused kernels to compute exact attention.

flash_attention: true

For more details: Flash Attention

Nvidia

Requirements: Ampere, Ada, or Hopper GPUs

Note: For Turing GPUs or older, please use one of the other attention methods.

pip install flash-attn --no-build-isolation
Tip

If you get an "undefined symbol" error while training, ensure you installed PyTorch prior to Axolotl. Alternatively, try reinstalling flash-attn or downgrading it by a version.
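
To sanity-check the install (and that it links against your PyTorch build), you can call the kernel directly outside of Axolotl. A minimal sketch assuming a supported NVIDIA GPU, using the library's (batch, seq_len, num_heads, head_dim) layout:

import torch
from flash_attn import flash_attn_func

# flash-attn expects fp16/bf16 tensors on GPU, shaped (batch, seq_len, num_heads, head_dim)
q, k, v = (torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16) for _ in range(3))

out = flash_attn_func(q, k, v, causal=True)
print(out.shape)  # torch.Size([1, 128, 8, 64])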

Flash Attention 3

Requirements: Hopper GPUs only; CUDA 12.8 is recommended.

git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper

python setup.py install
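
To confirm the machine meets these requirements, you can check the GPU's compute capability and the CUDA version PyTorch was built with. A small sketch (Hopper reports compute capability 9.0; the toolkit used to build Flash Attention 3 may still differ from what torch reports):

import torch

major, minor = torch.cuda.get_device_capability()
print(f"compute capability: {major}.{minor}")  # Hopper GPUs report 9.0
print(f"PyTorch built with CUDA: {torch.version.cuda}")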

AMD

Requirements: ROCm 6.0 and above.

See Flash Attention AMD docs.
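
To confirm a ROCm build of PyTorch is in use before installing flash-attn, a quick sketch (torch.version.hip is None on CUDA builds of PyTorch):

import torch

# Non-None only on ROCm builds of PyTorch; ROCm GPUs are exposed via the torch.cuda API
print(torch.version.hip)
print(torch.cuda.is_available())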

Flex Attention

A flexible PyTorch API for attention used in combination with torch.compile.

flex_attention: true

# recommended
torch_compile: true
Note

We recommend using the latest stable version of PyTorch for best performance.

For more details: PyTorch docs
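
For reference, the underlying PyTorch API looks roughly like this: a minimal standalone sketch of torch.nn.attention.flex_attention with a causal block mask (assumes a recent PyTorch release and a CUDA GPU):

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def causal(b, h, q_idx, kv_idx):
    # Keep only positions where the query index is at or after the key index
    return q_idx >= kv_idx

# (batch, num_heads, seq_len, head_dim)
q, k, v = (torch.randn(1, 8, 128, 64, device="cuda") for _ in range(3))

block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=128, KV_LEN=128)
# Compiling flex_attention is what makes it fast in practice
compiled_flex_attention = torch.compile(flex_attention)
out = compiled_flex_attention(q, k, v, block_mask=block_mask)
print(out.shape)  # torch.Size([1, 8, 128, 64])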

SageAttention

Attention kernels that use INT8 quantization for QK and an FP16 accumulator for PV.

sage_attention: true

Requirements: Ampere, Ada, or Hopper GPUs

pip install sageattention==2.2.0 --no-build-isolation
Warning

Only LoRA/QLoRA is recommended at the moment. We found the loss drops to 0 during full fine-tuning. See the GitHub Issue.

For more details: Sage Attention
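
Outside of Axolotl, the kernel can be exercised directly. A minimal sketch based on the upstream SageAttention README; the sageattn entry point, its tensor_layout argument, and the "HND" layout name come from that README rather than from this document, so treat them as assumptions:

import torch
from sageattention import sageattn

# "HND" layout is assumed to mean (batch, num_heads, seq_len, head_dim), per the upstream README
q, k, v = (torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16) for _ in range(3))

out = sageattn(q, k, v, tensor_layout="HND", is_causal=True)
print(out.shape)  # torch.Size([1, 8, 128, 64])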

Note

We do not support SageAttention 3 at the moment. If you are interested in adding it or improving the SageAttention integration, please open an Issue.

xFormers

xformers_attention: true
Tip

We recommend using this with Turing GPUs or older (such as on Colab).

For more details: xFormers
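
For reference, the underlying xFormers call looks roughly like this (a minimal standalone sketch; xFormers expects (batch, seq_len, num_heads, head_dim) inputs in fp16/bf16 on GPU):

import torch
import xformers.ops as xops

# (batch, seq_len, num_heads, head_dim)
q, k, v = (torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16) for _ in range(3))

# LowerTriangularMask applies causal masking
out = xops.memory_efficient_attention(q, k, v, attn_bias=xops.LowerTriangularMask())
print(out.shape)  # torch.Size([1, 128, 8, 64])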

Shifted Sparse Attention

Warning

We plan to deprecate this! If you use this feature, we recommend switching to one of the methods above.

Requirements: LLaMA model architecture

flash_attention: true
s2_attention: true
Tip

No sample packing support!