Attention

Supported attention modules in Axolotl

Axolotl routes attention via a single config field:

attn_implementation: <backend>

attn_implementation is passed through to transformers verbatim (via model.config._attn_implementation). Accepted values are the HF-native backends, axolotl-registered backends, or a hub-kernel path.
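For example, a minimal sketch of where the field sits in a config (the base model below is a placeholder):

base_model: NousResearch/Llama-2-7b-hf  # placeholder model
attn_implementation: flash_attention_2  # forwarded verbatim to transformers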

Backends

| attn_implementation | Description |
| --- | --- |
| eager | Plain PyTorch attention. No packing support. |
| sdpa | PyTorch scaled_dot_product_attention. No packing support. |
| flash_attention_2 | Dao-AILab Flash Attention 2. |
| flash_attention_3 | Dao-AILab Flash Attention 3 (Hopper+). |
| flex_attention | Torch Flex Attention (requires torch ≥ 2.6). |
| xformers | xFormers memory-efficient attention. |
| sage | SageAttention (QK int8 / PV fp16). |
| s2 | Shifted-Sparse Attention (LLaMA only, FA2 under the hood). |
| fp8 | torchao FP8 low-precision attention (requires SM90+, torch ≥ 2.11). Loaded as SDPA and patched post-load. |
| kernels-community/flash-attn3 | HF hub FA3 kernel. |
| kernels-community/sage-attention | HF hub SageAttention kernel. |
| Other <org>/<name> path | Any hub-kernel path supported by transformers. |

Short-form aliases (flash, fa2, flex, sdp, etc.) are not accepted — set the canonical name above.
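For example:

# Not accepted (short-form alias)
# attn_implementation: fa2

# Accepted (canonical name)
attn_implementation: flash_attention_2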

Capability flags

Axolotl derives three boolean capability flags from attn_implementation and exposes them on the validated config:

  • cfg.attn_supports_packing — backend supports varlen sample packing via position_ids. Gates multipack patches and sample_packing_drop_attention_mask.
  • cfg.attn_uses_flash_lib — backend needs the flash_attn (Dao-AILab) monkeypatches (FA4 auto, LLaMA flash hijack, ring-FA).
  • cfg.attn_needs_dtype_cast — backend requires fp16/bf16 embeddings (everything except eager and sdpa).

These are computed — they cannot be overridden from YAML.
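As an illustration, consider a flash_attention_2 config (a hedged sketch; the exact per-backend flag values live in Axolotl's backend registry, and the packing and flash-lib values below are assumptions based on the table above):

attn_implementation: flash_attention_2
sample_packing: true  # assumption: allowed because attn_supports_packing is derived as true

Here Axolotl would also derive attn_uses_flash_lib=True (Dao-AILab flash_attn backend) and attn_needs_dtype_cast=True (not eager or sdpa).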

Per-backend notes

SDPA

Default PyTorch attention. See PyTorch docs.

attn_implementation: sdpa

Flash Attention

Axolotl supports FA2, FA3, and FA4. The best available version is used automatically based on your installed packages and GPU.

attn_implementation: flash_attention_2  # or flash_attention_3

Flash Attention 2

Requirements: Ampere, Ada, or Hopper GPUs (Turing or lower not supported)

pip install flash-attn --no-build-isolation
Tip

If you get an undefined symbol error while training, make sure PyTorch was installed before Axolotl. Alternatively, try reinstalling flash-attn or downgrading it to an earlier version.

Flash Attention 3

Requirements: Hopper GPUs only; CUDA 12.8 is recommended.

git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper
python setup.py install

Flash Attention 4

Requirements: Hopper or Blackwell GPUs. Auto-applied when attn_uses_flash_lib is true and FA4 is importable.
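There is no separate attn_implementation value for FA4; a hedged sketch of a config that lets it be auto-selected (assuming FA4 is installed and the head dimensions are compatible, see the warning below):

attn_implementation: flash_attention_2  # FA4 is applied automatically when importable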

FA4 is still a pre-release on PyPI, so --pre is required:

pip install --pre flash-attn-4

Or from source:

git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/flash_attn/cute
pip install -e .

# FA2's flash_attn package includes a cute/ stub that shadows FA4.
# Remove it so Python can find the real FA4 module:
rm -r $(python -c "import flash_attn; print(flash_attn.__path__[0])")/cute
Note

Hopper (SM90) users: The backward kernel is not yet included in the pip package. To use FA4 for training on Hopper, install from source using the instructions above.

Warning

FA4 only supports head dimensions up to 128 (d ≤ 128). The DeepSeek shape (192, 128) is also supported but only on Blackwell. Axolotl automatically detects incompatible head dimensions and falls back to FA2/3.

AMD

Requirements: ROCm 6.0 and above. See Flash Attention AMD docs.

Flex Attention

attn_implementation: flex_attention
torch_compile: true  # recommended

Requires torch ≥ 2.6. See PyTorch docs.

SageAttention

Requirements: Ampere, Ada, or Hopper GPUs.

attn_implementation: sage
pip install sageattention==2.2.0 --no-build-isolation
Warning

Only LoRA/QLoRA training is recommended. Full finetuning has been observed to collapse the loss to 0. See GitHub Issue.
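Given the warning above, a hedged sketch pairing sage with a LoRA adapter (the LoRA hyperparameters are illustrative placeholders):

attn_implementation: sage
adapter: lora   # full finetuning with sage has been observed to collapse loss
lora_r: 16      # placeholder rank
lora_alpha: 32  # placeholder scaling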

For more details: Sage Attention.

xFormers

attn_implementation: xformers
Tip

Recommended for Turing GPUs or below (e.g. Colab T4).

Shifted Sparse Attention

Warning

Planned for deprecation. Prefer one of the backends above.

Requirements: LLaMA model architecture. Loaded as FA2 under the hood and patched to implement shifted-sparse attention. Does not support sample packing.

attn_implementation: s2
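If the rest of the config enables packing, it must be turned off for this backend; a hedged sketch:

attn_implementation: s2
sample_packing: false  # s2 does not support sample packing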

FP8

torchao low-precision attention. Loaded as SDPA and patched post-load.

Requirements: SM90+ (Hopper/Blackwell), PyTorch ≥ 2.11, torchao ≥ 0.17, flash-attn with FA3. KV caching must be disabled.

attn_implementation: fp8

Hub kernels

attn_implementation: kernels-community/flash-attn3

Passed through to transformers; axolotl does not install the kernel itself. For recognized hub paths the capability flags are set automatically; for arbitrary paths axolotl uses conservative defaults (attn_supports_packing=False, attn_uses_flash_lib=False).
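For an unrecognized path, a hedged sketch (the org/kernel name is hypothetical):

attn_implementation: my-org/my-attention-kernel  # hypothetical hub path
sample_packing: false  # arbitrary paths default to attn_supports_packing=False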

Migrating from legacy boolean flags

The following legacy config fields are deprecated and will be removed in a future release. Each emits a DeprecationWarning when set and is stripped from the validated config.

| Legacy | Canonical |
| --- | --- |
| flash_attention: true | attn_implementation: flash_attention_2 |
| sdp_attention: true | attn_implementation: sdpa |
| xformers_attention: true | attn_implementation: xformers |
| flex_attention: true | attn_implementation: flex_attention |
| sage_attention: true | attn_implementation: sage |
| s2_attention: true | attn_implementation: s2 |
| eager_attention: true | attn_implementation: eager |

Combining attn_implementation with a legacy flag (e.g. attn_implementation: flash_attention_2 and flash_attention: true) raises an error during config validation; set only one.
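For example, migrating a config that used the boolean flag:

# Before (deprecated, emits DeprecationWarning)
flash_attention: true

# After
attn_implementation: flash_attention_2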

Note

Existing example configs under examples/ still use the legacy flags. They continue to work with a deprecation warning; they will be migrated in a follow-up pass.