monkeypatch.models.gemma4_unified.fused_attn

Gemma 4 Unified fused attention monkeypatch.

Mirrors axolotl.monkeypatch.models.gemma4.fused_attn for the encoder-free gemma4_unified text backbone (Gemma4UnifiedTextAttention), replacing the per-layer RMSNorm + RoPE + transpose sequence with fused Triton kernels.

The math is identical to standard Gemma 4: q_norm and k_norm apply a scaled RMSNorm followed by RoPE, while v_norm applies an unscaled RMSNorm and no RoPE. The same kernels are therefore reused. Like standard Gemma 4 on transformers 5.10.x, the unified attention keys shared_kv_states by the layer type string (self.layer_type). The separate module is needed only because the unified backbone redefines its classes in its own namespace rather than modular-importing them from modeling_gemma4.
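The per-head math that the fused kernels replace can be sketched in plain Python. This is an unfused reference for illustration only, not the Triton kernels or the library's code. The helper names (rms_norm, rope, reference_qkv), the eps and base values, the rotate-half RoPE layout, and multiplying by the weight directly (rather than by 1 + weight) are all assumptions that the text above does not state.

```python
import math


def rms_norm(x, weight=None, eps=1e-6):
    """RMS-normalize one head vector; multiply by weight only if one is given."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    normed = [v / rms for v in x]
    if weight is None:
        return normed  # unscaled variant, as used for v in this sketch
    return [n * w for n, w in zip(normed, weight)]


def rope(x, position, base=10000.0):
    """Rotate-half RoPE on one head vector of even length (assumed layout)."""
    half = len(x) // 2
    out = [0.0] * len(x)
    for i in range(half):
        angle = position * base ** (-i / half)
        cos, sin = math.cos(angle), math.sin(angle)
        out[i] = x[i] * cos - x[i + half] * sin
        out[i + half] = x[i + half] * cos + x[i] * sin
    return out


def reference_qkv(q, k, v, q_weight, k_weight, position):
    """q and k: scaled RMSNorm then RoPE. v: unscaled RMSNorm, no RoPE."""
    q_out = rope(rms_norm(q, q_weight), position)
    k_out = rope(rms_norm(k, k_weight), position)
    v_out = rms_norm(v)
    return q_out, k_out, v_out
```

Because RoPE is a pure rotation, it leaves the vector norm unchanged, which makes a reference like this a convenient cross-check when validating a fused kernel numerically.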

Usage

from axolotl.monkeypatch.models.gemma4_unified.fused_attn import (
    patch_gemma4_unified_fused_attn,
)

# Pass install_shared_kv_workaround=True when activation checkpointing is enabled.
patch_gemma4_unified_fused_attn(install_shared_kv_workaround=True)

Functions

Name Description
patch_gemma4_unified_fused_attn Monkeypatch Gemma4UnifiedTextAttention.forward to use fused RMSNorm+RoPE kernels.

patch_gemma4_unified_fused_attn

monkeypatch.models.gemma4_unified.fused_attn.patch_gemma4_unified_fused_attn(
    install_shared_kv_workaround=False,
)

Monkeypatch Gemma4UnifiedTextAttention.forward to use fused RMSNorm+RoPE kernels, and optionally route shared_kv_states via a module-level side channel to avoid a VRAM leak under activation checkpointing (PR #3611).