Attention
Axolotl routes attention via a single config field:
```yaml
attn_implementation: <backend>
```

attn_implementation is passed through to transformers verbatim (via
model.config._attn_implementation). Accepted values are the HF-native
backends, axolotl-registered backends, or a hub-kernel path.
Backends
| attn_implementation | Description |
|---|---|
| eager | Plain PyTorch attention. No packing support. |
| sdpa | PyTorch scaled_dot_product_attention. No packing support. |
| flash_attention_2 | Dao-AILab Flash Attention 2. |
| flash_attention_3 | Dao-AILab Flash Attention 3 (Hopper+). |
| flex_attention | Torch Flex Attention (requires torch ≥ 2.6). |
| xformers | xFormers memory-efficient attention. |
| sage | SageAttention (QK int8 / PV fp16). |
| s2 | Shifted-Sparse Attention (LLaMA only, FA2 under the hood). |
| fp8 | torchao FP8 low-precision attention (requires SM90+, torch ≥ 2.11). Loaded as SDPA and patched post-load. |
| kernels-community/flash-attn3 | HF hub FA3 kernel. |
| kernels-community/sage-attention | HF hub SageAttention kernel. |
| Other <org>/<name> path | Any hub-kernel path supported by transformers. |
Short-form aliases (flash, fa2, flex, sdp, etc.) are not accepted —
set the canonical name above.
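For example, selecting Flash Attention 2 uses the full backend name from the table:

```yaml
attn_implementation: flash_attention_2   # not "fa2" or "flash"
```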
Capability flags
Axolotl derives three boolean capability flags from attn_implementation and
exposes them on the validated config:
- cfg.attn_supports_packing — backend supports varlen sample packing via position_ids. Gates multipack patches and sample_packing_drop_attention_mask.
- cfg.attn_uses_flash_lib — backend needs the flash_attn (Dao-AILab) monkeypatches (FA4 auto, LLaMA flash hijack, ring-FA).
- cfg.attn_needs_dtype_cast — backend requires fp16/bf16 embeddings (everything except eager and sdpa).
These are computed — they cannot be overridden from YAML.
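In practice the place you are most likely to notice these flags is sample packing: only backends with attn_supports_packing can be paired with it. A minimal sketch, assuming sample_packing is the usual axolotl packing switch:

```yaml
# Packs correctly: flash_attention_2 supports varlen packing via position_ids.
attn_implementation: flash_attention_2
sample_packing: true

# By contrast, sdpa reports no packing support (see the table above),
# so pairing it with sample_packing is not expected to work:
# attn_implementation: sdpa
```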
Per-backend notes
SDPA
Default PyTorch attention. See PyTorch docs.
```yaml
attn_implementation: sdpa
```

Flash Attention
Axolotl supports FA2, FA3, and FA4. The best available version is used automatically based on your installed packages and GPU.
```yaml
attn_implementation: flash_attention_2 # or flash_attention_3
```

Flash Attention 2
Requirements: Ampere, Ada, or Hopper GPUs (Turing or lower not supported)
```bash
pip install flash-attn --no-build-isolation
```

If you get an undefined symbol error while training, ensure you installed PyTorch before Axolotl. Alternatively, try reinstalling flash-attn or downgrading to an earlier version.
Flash Attention 3
Requirements: Hopper only and CUDA 12.8 (recommended)
```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper
python setup.py install
```

Flash Attention 4
Requirements: Hopper or Blackwell GPUs. Auto-applied when attn_uses_flash_lib
is true and FA4 is importable.
FA4 is still a pre-release on PyPI, so --pre is required:
```bash
pip install --pre flash-attn-4
```

Or from source:
```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/flash_attn/cute
pip install -e .

# FA2's flash_attn package includes a cute/ stub that shadows FA4.
# Remove it so Python can find the real FA4 module:
rm -r $(python -c "import flash_attn; print(flash_attn.__path__[0])")/cute
```

Hopper (SM90) users: The backward kernel is not yet included in the pip package. To use FA4 for training on Hopper, install from source using the instructions above.
FA4 only supports head dimensions up to 128 (d ≤ 128). The DeepSeek shape (192, 128) is
also supported but only on Blackwell. Axolotl automatically detects incompatible head dimensions
and falls back to FA2/3.
AMD
Requirements: ROCm 6.0 and above. See Flash Attention AMD docs.
Flex Attention
```yaml
attn_implementation: flex_attention
torch_compile: true # recommended
```

Requires torch ≥ 2.6. See PyTorch docs.
SageAttention
Requirements: Ampere, Ada, or Hopper GPUs.
```yaml
attn_implementation: sage
```

```bash
pip install sageattention==2.2.0 --no-build-isolation
```

Only LoRA/QLoRA recommended. Full finetuning has been observed to drop loss to 0. See GitHub Issue.
For more details: Sage Attention.
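Since only LoRA/QLoRA is recommended with this backend, here is a sketch of pairing sage with a LoRA adapter; the adapter/lora_* fields below are the usual axolotl adapter settings, shown as an assumption rather than a required combination:

```yaml
attn_implementation: sage
adapter: lora        # full finetuning with sage has been observed to drop loss to 0
lora_r: 16
lora_alpha: 32
```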
xFormers
```yaml
attn_implementation: xformers
```

Recommended for Turing GPUs or below (e.g. Colab T4).
Shifted Sparse Attention
Planned for deprecation. Prefer one of the backends above.
Requirements: LLaMA model architecture. Loaded as FA2 under the hood and patched to implement shifted-sparse attention. Does not support sample packing.
```yaml
attn_implementation: s2
```

FP8
torchao low-precision attention. Loaded as SDPA and patched post-load.
Requirements: SM90+ (Hopper/Blackwell), PyTorch ≥ 2.11, torchao ≥ 0.17, flash-attn with FA3. KV caching must be disabled.
```yaml
attn_implementation: fp8
```

Hub kernels
```yaml
attn_implementation: kernels-community/flash-attn3
```

Passed through to transformers; axolotl does not install the kernel itself.
For recognized hub paths the capability flags are set automatically; for
arbitrary paths axolotl uses conservative defaults (attn_supports_packing=False,
attn_uses_flash_lib=False).
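A sketch of an arbitrary hub path (the org/name below is a hypothetical placeholder, not a real kernel):

```yaml
# Hypothetical hub path for illustration only.
attn_implementation: some-org/some-custom-attn
# With the conservative defaults (attn_supports_packing=False),
# do not rely on sample packing with an unrecognized kernel.
```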
Migrating from legacy boolean flags
The following legacy config fields are deprecated and will be removed in a
future release. Each emits a DeprecationWarning when set and is stripped from
the validated config.
| Legacy | Canonical |
|---|---|
| flash_attention: true | attn_implementation: flash_attention_2 |
| sdp_attention: true | attn_implementation: sdpa |
| xformers_attention: true | attn_implementation: xformers |
| flex_attention: true | attn_implementation: flex_attention |
| sage_attention: true | attn_implementation: sage |
| s2_attention: true | attn_implementation: s2 |
| eager_attention: true | attn_implementation: eager |
Combining attn_implementation with a legacy flag (e.g. attn_implementation: flash_attention_2 and flash_attention: true) raises an error — pick one.
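For example, migrating a legacy config looks like this:

```yaml
# Before (deprecated, emits a DeprecationWarning):
# flash_attention: true

# After:
attn_implementation: flash_attention_2
```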
Existing example configs under examples/ still use the legacy flags. They
continue to work with a deprecation warning; they will be migrated in a
follow-up pass.