monkeypatch.models.gemma4_unified.fused_attn
Gemma 4 Unified fused attention monkeypatch.
Mirrors `axolotl.monkeypatch.models.gemma4.fused_attn` for the
encoder-free `gemma4_unified` text backbone (`Gemma4UnifiedTextAttention`),
replacing the per-layer RMSNorm + RoPE + transpose sequence with fused Triton
kernels.
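For reference, the unfused per-head math that the kernels replace can be sketched in plain Python: `q` and `k` get RMSNorm with a scale followed by RoPE, and `v` gets RMSNorm without a scale and no RoPE. This is a minimal sketch, not the Triton kernels or the transformers code. It assumes a standard RMSNorm (`x / sqrt(mean(x^2) + eps)`, with the scale multiplied in elementwise) and half-split ("rotate_half") RoPE with inverse frequencies `base^(-2i/d)`. The helper names `rms_norm`, `apply_rope`, and `unfused_qkv`, as well as the `eps` and `base` defaults, are illustrative assumptions. The transpose step is omitted because the sketch works on a single head vector.

```python
import math
from typing import List, Optional, Sequence


def rms_norm(x: Sequence[float], scale: Optional[Sequence[float]] = None,
             eps: float = 1e-6) -> List[float]:
    """RMSNorm over one head vector; scale=None models a scale-free norm such as v_norm."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    out = [v * inv_rms for v in x]
    if scale is not None:
        # Assumption: the scale multiplies elementwise (no "1 + weight" offset).
        out = [o * s for o, s in zip(out, scale)]
    return out


def apply_rope(x: Sequence[float], position: int, base: float = 10000.0) -> List[float]:
    """Half-split ("rotate_half") RoPE on one head vector of even length."""
    d = len(x)
    half = d // 2
    out = list(x)
    for i in range(half):
        # Pair (i, i + half) rotates by angle position * base^(-2i/d).
        theta = position * base ** (-2.0 * i / d)
        c, s = math.cos(theta), math.sin(theta)
        x1, x2 = x[i], x[i + half]
        out[i] = x1 * c - x2 * s
        out[i + half] = x2 * c + x1 * s
    return out


def unfused_qkv(q, k, v, q_scale, k_scale, position):
    """Reference path: scaled RMSNorm + RoPE for q and k; unscaled RMSNorm only for v."""
    q_out = apply_rope(rms_norm(q, q_scale), position)
    k_out = apply_rope(rms_norm(k, k_scale), position)
    v_out = rms_norm(v)  # no scale, no RoPE
    return q_out, k_out, v_out


# Usage on toy 4-dimensional head vectors at sequence position 3.
q_out, k_out, v_out = unfused_qkv(
    [1.0, 2.0, 3.0, 4.0], [0.5, -1.0, 2.0, 0.0], [3.0, 4.0, 0.0, 0.0],
    q_scale=[1.0] * 4, k_scale=[2.0] * 4, position=3,
)
```

Because RoPE is a pure rotation of each `(i, i + half)` pair, it preserves the vector norm. This makes it easy to apply after the norm in a single pass.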
The math is identical to standard Gemma 4: `q_norm` and `k_norm` apply RMSNorm
with a scale followed by RoPE, while `v_norm` applies RMSNorm without a scale
and no RoPE, so the same kernels are reused. As in standard Gemma 4 on
transformers 5.10.x, the unified attention keys `shared_kv_states` by the
layer-type string (`self.layer_type`). A separate module is needed only because
the unified backbone redefines its classes in its own namespace instead of
modular-importing them from `modeling_gemma4`, so patching the standard Gemma 4
classes does not affect it.
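The namespace point can be illustrated with a small, framework-free sketch. The two `types.ModuleType` objects and the classes bound into them are hypothetical stand-ins for the transformers modeling modules, not the real ones. The sketch shows that replacing `forward` on one class leaves an independently defined class untouched, which is why the unified backbone needs its own patch.

```python
import types

# Hypothetical stand-ins for the two transformers modeling modules.
modeling_gemma4 = types.ModuleType("modeling_gemma4")
modeling_gemma4_unified = types.ModuleType("modeling_gemma4_unified")


class Gemma4TextAttention:  # stand-in, not the real transformers class
    def forward(self):
        return "eager"


class Gemma4UnifiedTextAttention:  # stand-in: redefined separately, not a subclass
    def forward(self):
        return "eager"


modeling_gemma4.Gemma4TextAttention = Gemma4TextAttention
modeling_gemma4_unified.Gemma4UnifiedTextAttention = Gemma4UnifiedTextAttention


def fused_forward(self):
    return "fused"


# Patch only the standard Gemma 4 class.
modeling_gemma4.Gemma4TextAttention.forward = fused_forward

print(modeling_gemma4.Gemma4TextAttention().forward())                 # fused
print(modeling_gemma4_unified.Gemma4UnifiedTextAttention().forward())  # eager
```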
Usage
```python
from axolotl.monkeypatch.models.gemma4_unified.fused_attn import (
    patch_gemma4_unified_fused_attn,
)

# Pass install_shared_kv_workaround=True when activation checkpointing is enabled.
patch_gemma4_unified_fused_attn(install_shared_kv_workaround=True)
```
Functions
| Name | Description |
|---|---|
| patch_gemma4_unified_fused_attn | Monkeypatch `Gemma4UnifiedTextAttention.forward` to use fused RMSNorm+RoPE kernels |
patch_gemma4_unified_fused_attn
```python
monkeypatch.models.gemma4_unified.fused_attn.patch_gemma4_unified_fused_attn(
    install_shared_kv_workaround=False,
)
```
Monkeypatch `Gemma4UnifiedTextAttention.forward` to use fused RMSNorm+RoPE
kernels, and optionally route `shared_kv_states` through a module-level side
channel to avoid a VRAM leak under activation checkpointing (PR #3611).