monkeypatch.gemma4_hybrid_mask

Hybrid attention mask fix for Gemma 4 (standard and unified).

Gemma 4 has full-attention (global) layers with head_dim=512, which exceeds the head dimension supported by flash-attention-2 (FA2). Axolotl’s hybrid-attention patch in patch_manager._apply_gemma_hybrid_attention works around this by forcing _attn_implementation="sdpa" on each global layer’s self_attn.config, leaving sliding-window layers on FA2.
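A minimal sketch of that per-layer override, assuming each decoder layer exposes self_attn.config and that the per-layer kinds are listed in a layer_types-style sequence with "full_attention" marking global layers. The model objects below are hypothetical SimpleNamespace stand-ins, not transformers classes, and the shallow copy per global layer is an assumption about how the override avoids leaking into the shared config; the real patch_manager code may differ.

```python
import copy
from types import SimpleNamespace


def force_sdpa_on_global_layers(layers, layer_types):
    """Give each global layer its own config copy that uses SDPA.

    Sliding-window layers keep the shared config (and therefore FA2).
    Returns the indices of the layers that were switched.
    """
    switched = []
    for idx, (layer, kind) in enumerate(zip(layers, layer_types)):
        if kind != "full_attention":
            continue
        # Shallow copy so the override does not leak into the shared config
        layer_cfg = copy.copy(layer.self_attn.config)
        layer_cfg._attn_implementation = "sdpa"
        layer.self_attn.config = layer_cfg
        switched.append(idx)
    return switched


# Hypothetical 4-layer model: three sliding-window layers, then one global layer
shared = SimpleNamespace(_attn_implementation="flash_attention_2")
layers = [SimpleNamespace(self_attn=SimpleNamespace(config=shared)) for _ in range(4)]
kinds = ["sliding_attention"] * 3 + ["full_attention"]

print(force_sdpa_on_global_layers(layers, kinds))       # [3]
print(layers[3].self_attn.config._attn_implementation)  # sdpa
print(shared._attn_implementation)                      # flash_attention_2
```

Note that the model-level config (shared here) still says FA2 after the override, which is exactly why the mask built once at the model level is a problem.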

The per-layer config override alone is insufficient, however: Gemma4TextModel.forward builds a single causal_mask_mapping dict from the model-level config and passes the corresponding mask to each decoder layer. Because the model-level config still selects FA2, the full_attention entry in that mapping is a 2D mask (FA2 format), whereas SDPA needs a 4D mask. The global layers then fail with:

RuntimeError: The expanded size of the tensor (S) must match the existing
size (B) at non-singleton dimension 2. Target sizes: [B, H, S, S]. Tensor
sizes: [B, S]

…when the sequence length grows past roughly 7k tokens.
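The error is a shape mismatch between the two mask layouts. The pure-Python sketch below imitates the shape rule of torch.Tensor.expand (dimensions aligned from the right, each existing size equal to the target size or 1) to show why a [B, S] mask cannot be expanded to the [B, H, S, S] attention shape while a [B, 1, S, S] mask can. The expand_shape helper is hypothetical and checks shapes only; it is not a claim about the exact code path inside SDPA.

```python
def expand_shape(tensor_shape, target_shape):
    """Mimic the shape check torch.Tensor.expand performs.

    Dimensions are aligned from the right; each existing dimension must
    equal the target size or be 1. New leading dimensions are allowed.
    """
    if len(target_shape) < len(tensor_shape):
        raise RuntimeError("target has fewer dimensions than the tensor")
    offset = len(target_shape) - len(tensor_shape)
    # Walk target dimensions right to left, pairing them with tensor dimensions
    for i in range(len(target_shape) - 1, offset - 1, -1):
        size = tensor_shape[i - offset]
        target = target_shape[i]
        if size != target and size != 1:
            raise RuntimeError(
                f"The expanded size of the tensor ({target}) must match the "
                f"existing size ({size}) at non-singleton dimension {i}. "
                f"Target sizes: {list(target_shape)}. "
                f"Tensor sizes: {list(tensor_shape)}"
            )
    return tuple(target_shape)


B, H, S = 2, 8, 8192
fa2_mask = (B, S)         # 2D padding mask: FA2 format
sdpa_mask = (B, 1, S, S)  # 4D mask: SDPA format

print(expand_shape(sdpa_mask, (B, H, S, S)))  # (2, 8, 8192, 8192)
try:
    expand_shape(fa2_mask, (B, H, S, S))
except RuntimeError as err:
    print(err)
```

The 2D mask's batch dimension B lands opposite a sequence dimension S at target dimension 2, which matches the "non-singleton dimension 2" wording of the error above.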

This module fixes the symptom by monkey-patching create_causal_mask in the model’s module namespace — NOT the original in masking_utils. The wrapper forces _attn_implementation="sdpa" on a shallow-copied config before calling through, so the full_attention mask built inside the text backbone’s forward is always 4D/SDPA-compatible. create_sliding_window_causal_mask is left alone, so sliding-window layers continue to receive FA2-format masks.
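Patching the module namespace works because `from masking_utils import create_causal_mask` creates a separate binding inside the importing module; rebinding that name changes what the modeling code calls without touching masking_utils or other models. A minimal sketch of such a wrapper, assuming the config is passed either as config=... or as the first positional argument. The module and the fake_create_causal_mask builder are hypothetical stand-ins, not the Axolotl or transformers code.

```python
import copy
import functools
import types


def _as_sdpa(config):
    # Shallow copy so the caller's (model-level) config keeps its FA2 setting
    sdpa_config = copy.copy(config)
    sdpa_config._attn_implementation = "sdpa"
    return sdpa_config


def make_sdpa_mask_wrapper(original):
    """Wrap a create_causal_mask-style function so it always builds SDPA masks."""

    @functools.wraps(original)
    def wrapper(*args, **kwargs):
        # Assumption: config arrives as config=... or as the first positional arg
        if "config" in kwargs:
            kwargs["config"] = _as_sdpa(kwargs["config"])
        elif args:
            args = (_as_sdpa(args[0]),) + args[1:]
        return original(*args, **kwargs)

    return wrapper


def fake_create_causal_mask(config, seq_len):
    # Hypothetical stand-in: report which mask format would be built
    return "4D" if config._attn_implementation == "sdpa" else "2D"


# Hypothetical modeling module holding its own imported binding of the builder
modeling_ns = types.ModuleType("fake_modeling_gemma4")
modeling_ns.create_causal_mask = fake_create_causal_mask
modeling_ns.create_causal_mask = make_sdpa_mask_wrapper(modeling_ns.create_causal_mask)

cfg = types.SimpleNamespace(_attn_implementation="flash_attention_2")
print(modeling_ns.create_causal_mask(config=cfg, seq_len=8))  # 4D
print(fake_create_causal_mask(cfg, 8))                        # 2D
print(cfg._attn_implementation)                               # flash_attention_2
```

A sliding-window builder in the same namespace would simply not be wrapped, so it keeps producing FA2-format masks.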

gemma4_unified reproduces the same mixed sliding/global architecture (global_head_dim=512) in its own modeling_gemma4_unified namespace, so both namespaces are patched when present.
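Patching "when present" can be sketched as a loop that tries each target module and skips the ones that fail to import. The module names below are hypothetical placeholders registered through sys.modules, not the real transformers paths; bool(patched) mirrors the True/False return value documented for patch_gemma4_hybrid_mask.

```python
import importlib
import sys
import types


def patch_namespaces(module_names, make_wrapper, attr="create_causal_mask"):
    """Wrap `attr` in every importable module; skip modules that are missing."""
    patched = []
    for name in module_names:
        try:
            module = importlib.import_module(name)
        except ImportError:
            continue  # e.g. this transformers build has no such variant
        original = getattr(module, attr, None)
        if original is None:
            continue
        setattr(module, attr, make_wrapper(original))
        patched.append(name)
    return patched


# Hypothetical setup: one fake modeling module exists, the second name does not
fake = types.ModuleType("fake_modeling_gemma4")
fake.create_causal_mask = lambda config: "mask"
sys.modules["fake_modeling_gemma4"] = fake

names = ["fake_modeling_gemma4", "fake_modeling_gemma4_unified_missing"]
result = patch_namespaces(names, lambda fn: (lambda config: "sdpa " + fn(config)))
print(result)                         # ['fake_modeling_gemma4']
print(fake.create_causal_mask(None))  # sdpa mask
print(bool(result))                   # True
```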

The patch is idempotent. Install once per process, before any Gemma 4 forward pass runs.
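Idempotence matters because wrapping an already-wrapped function would stack wrappers and apply the transformation twice. A minimal sketch using a marker attribute on the wrapper; the attribute name _gemma4_hybrid_mask_patched and the counting wrapper are hypothetical, and the real module may track state differently (for example, by storing originals in a dict).

```python
import functools
import types

_PATCH_FLAG = "_gemma4_hybrid_mask_patched"  # hypothetical marker attribute


def install_once(module, make_wrapper, attr="create_causal_mask"):
    """Wrap module.<attr> unless already wrapped. Returns True if wrapped now."""
    current = getattr(module, attr)
    if getattr(current, _PATCH_FLAG, False):
        return False  # already installed; do not stack a second wrapper
    wrapped = make_wrapper(current)
    setattr(wrapped, _PATCH_FLAG, True)
    setattr(module, attr, wrapped)
    return True


def make_counting_wrapper(original):
    # Adds 1 per layer of wrapping, so stacked wrappers would be visible
    @functools.wraps(original)
    def wrapper(*args, **kwargs):
        return original(*args, **kwargs) + 1

    return wrapper


ns = types.ModuleType("fake_ns")
ns.create_causal_mask = lambda: 0
print(install_once(ns, make_counting_wrapper))  # True
print(install_once(ns, make_counting_wrapper))  # False
print(ns.create_causal_mask())                  # 1 (wrapped only once)
```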

Functions

Name                          Description
patch_gemma4_hybrid_mask      Install the Gemma 4 hybrid-attention mask fix across all variants.
unpatch_gemma4_hybrid_mask    Restore the original create_causal_mask in every namespace (intended for tests).

patch_gemma4_hybrid_mask

monkeypatch.gemma4_hybrid_mask.patch_gemma4_hybrid_mask()

Install the Gemma 4 hybrid-attention mask fix across all variants.

Returns True if at least one namespace was patched. Returns False if none of the target modules could be imported (e.g. the installed transformers version predates Gemma 4); in that case nothing is changed and the caller can continue unaffected.

unpatch_gemma4_hybrid_mask

monkeypatch.gemma4_hybrid_mask.unpatch_gemma4_hybrid_mask()

Restore the original create_causal_mask in every namespace. Intended for use in tests.