kernels.gemma4_fused_rope

Fused RMSNorm + (partial) RoPE Triton kernel for Gemma 4 / Qwen3 Q/K paths.

Functions

Name                     Description
fused_rms_norm_noscale   RMSNorm without a learned scale (used for v_norm).
fused_rms_norm_rope      Apply fused RMSNorm + (partial) RoPE.

fused_rms_norm_noscale

kernels.gemma4_fused_rope.fused_rms_norm_noscale(x, eps=1e-06)

RMSNorm without a learned scale (used for v_norm).
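As a rough guide to what this computes, here is a plain-Python sketch of RMSNorm without a scale, applied to a single head vector of length D. It assumes the standard RMSNorm definition (mean of squares over the last dimension, with eps added inside the square root). The name rms_norm_noscale_ref is hypothetical; it is not part of this module and is not the Triton kernel itself.

```python
import math
from typing import List


def rms_norm_noscale_ref(x: List[float], eps: float = 1e-6) -> List[float]:
    """Hypothetical reference: RMSNorm over one vector, no learned scale.

    Assumes y_i = x_i / sqrt(mean(x^2) + eps), i.e. the usual RMSNorm
    formula with the weight multiplication omitted.
    """
    mean_sq = sum(v * v for v in x) / len(x)  # mean of squares over D
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)  # eps sits inside the sqrt
    return [v * inv_rms for v in x]


# Usage: normalize a 4-element head vector.
y = rms_norm_noscale_ref([1.0, -2.0, 2.0, -1.0])
print([round(v, 4) for v in y])
```

The fused kernel presumably applies the same per-vector computation to every (b, s, h) slice of the input.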

fused_rms_norm_rope

kernels.gemma4_fused_rope.fused_rms_norm_rope(
    x,
    weight,
    cos,
    sin,
    eps=1e-06,
    unit_offset=False,
)

Apply fused RMSNorm + (partial) RoPE.

Shapes

x:      (B, S, H, D), the post-projection Q or K tensor.
weight: (D,). Required; use fused_rms_norm_noscale for the variant without a weight.
cos:    (B, S, n_rot). n_rot must be even and <= D; the trailing D - n_rot
        columns are RMSNorm-only (partial rotary).
sin:    (B, S, n_rot), the same shape as cos.
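The shape contract above can be checked up front with a small helper like the following. This is a sketch only: check_rope_shapes is a hypothetical name, not part of this module, and it takes the documented shapes literally (no broadcasting of cos/sin over B or S), which may be stricter than what the kernel actually accepts.

```python
from typing import Tuple


def check_rope_shapes(
    x_shape: Tuple[int, ...],
    weight_shape: Tuple[int, ...],
    cos_shape: Tuple[int, ...],
    sin_shape: Tuple[int, ...],
) -> int:
    """Hypothetical helper: validate the documented shapes and return n_rot."""
    if len(x_shape) != 4:
        raise ValueError(f"x must be (B, S, H, D), got {x_shape}")
    b, s, _h, d = x_shape
    if weight_shape != (d,):
        raise ValueError(f"weight must be ({d},), got {weight_shape}")
    if cos_shape != sin_shape:
        raise ValueError(f"cos and sin shapes differ: {cos_shape} vs {sin_shape}")
    if len(cos_shape) != 3 or cos_shape[:2] != (b, s):
        raise ValueError(f"cos/sin must be ({b}, {s}, n_rot), got {cos_shape}")
    n_rot = cos_shape[2]
    if n_rot % 2 != 0 or n_rot > d:
        raise ValueError(f"n_rot must be even and <= D={d}, got {n_rot}")
    # Columns [n_rot, D) are normalized but not rotated (partial rotary).
    return n_rot


# Usage: D = 256 with 64 rotated columns.
n_rot = check_rope_shapes((2, 8, 4, 256), (256,), (2, 8, 64), (2, 8, 64))
print(f"rotated columns: {n_rot}, pass-through columns: {256 - n_rot}")
```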

With unit_offset=True, the normalized values are scaled by (weight + 1.0) rather than by weight (Gemma-style).
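For a single head vector, the combined operation can be sketched in plain Python as RMSNorm over all D columns, followed by RoPE on the first n_rot columns only. This assumes the Hugging Face "rotate_half" RoPE layout used by Gemma and Qwen3 (the rotated block is split into two halves, and cos/sin already carry length n_rot); the kernel's exact layout and precision handling are not confirmed here. The name rms_norm_rope_ref is hypothetical.

```python
import math
from typing import List


def rms_norm_rope_ref(
    x: List[float],
    weight: List[float],
    cos: List[float],
    sin: List[float],
    eps: float = 1e-6,
    unit_offset: bool = False,
) -> List[float]:
    """Hypothetical per-vector reference for RMSNorm + partial RoPE.

    x and weight have length D; cos and sin have length n_rot (even, <= D).
    Assumes the rotate_half layout: the rotated block is two halves of
    length n_rot // 2.
    """
    d, n_rot = len(x), len(cos)
    half = n_rot // 2

    # 1) RMSNorm over all D columns, scaled by weight or (weight + 1.0).
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / d + eps)
    scale = [w + 1.0 if unit_offset else w for w in weight]
    y = [x[i] * inv_rms * scale[i] for i in range(d)]

    # 2) RoPE on columns [0, n_rot): out = y * cos + rotate_half(y) * sin,
    #    where rotate_half(y) = concat(-y[half:n_rot], y[:half]).
    out = list(y)  # columns [n_rot, D) pass through unchanged
    for i in range(half):
        a, b = y[i], y[i + half]
        out[i] = a * cos[i] - b * sin[i]
        out[i + half] = b * cos[i + half] + a * sin[i + half]
    return out


# Usage: D = 4, n_rot = 2, a 90-degree rotation (cos = 0, sin = 1).
# With unit_offset=True, an all-zero weight gives an effective scale of 1.0.
out = rms_norm_rope_ref(
    [1.0, 0.0, 1.0, 1.0], [0.0] * 4, [0.0, 0.0], [1.0, 1.0], unit_offset=True
)
print([round(v, 4) for v in out])
```

Only the first two columns move under the rotation; the last two are the RMSNorm-only tail described in the shapes above.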