kernels.gemma4_fused_rope
Fused RMSNorm + (partial) RoPE Triton kernel for Gemma 4 / Qwen3 Q/K paths.
Functions
| Name | Description |
|---|---|
| fused_rms_norm_noscale | RMSNorm without a learned scale (used for v_norm). |
| fused_rms_norm_rope | Apply fused RMSNorm + (partial) RoPE. |
fused_rms_norm_noscale
kernels.gemma4_fused_rope.fused_rms_norm_noscale(x, eps=1e-06)

RMSNorm without a learned scale (used for v_norm).
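A slow reference makes the per-row math easier to check against the kernel. The sketch below uses a hypothetical helper, rms_norm_noscale_ref, which is not part of this module. It operates on one head vector (one (b, s, h) slice of x) and assumes the common RMSNorm convention of normalizing over the last dimension with eps added inside the square root.

```python
import math


def rms_norm_noscale_ref(row, eps=1e-6):
    """Hypothetical pure-Python reference for one head vector of length D.

    Assumption: eps is added to the mean of squares inside the square root,
    and no weight is applied afterwards (the "noscale" variant).
    """
    mean_sq = sum(v * v for v in row) / len(row)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms for v in row]


# The result has (almost exactly) unit root-mean-square; ratios between
# elements are unchanged because every element is divided by the same value.
out = rms_norm_noscale_ref([3.0, -4.0, 0.0, 5.0])
```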
fused_rms_norm_rope
kernels.gemma4_fused_rope.fused_rms_norm_rope(
x,
weight,
cos,
sin,
eps=1e-06,
unit_offset=False,
)

Apply fused RMSNorm + (partial) RoPE.
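For one head vector x = x[b, s, h, :] of length D, with w = weight and n_rot the last dimension of cos, the operation can be sketched as follows. This is a reading under stated assumptions, not the kernel source: it assumes eps sits inside the square root, the mean runs over all D columns, and RoPE uses the rotate-half pairing of column j with column j + n_rot/2 (the GPT-NeoX / Hugging Face layout). Here \cos_j and \sin_j stand for cos[b, s, j] and sin[b, s, j], shared by all H heads.

$$
\begin{aligned}
r &= \sqrt{\frac{1}{D}\sum_{j=0}^{D-1} x_j^2 + \varepsilon}, \qquad
y_j = \frac{x_j}{r}\,(w_j + \delta), \qquad
\delta = \begin{cases} 1 & \text{if unit\_offset=True} \\ 0 & \text{otherwise} \end{cases} \\
z_j &= \begin{cases}
y_j \cos_j - y_{j+h} \sin_j & 0 \le j < h \\
y_j \cos_j + y_{j-h} \sin_j & h \le j < n_{\mathrm{rot}} \\
y_j & n_{\mathrm{rot}} \le j < D
\end{cases}
\qquad h = n_{\mathrm{rot}}/2
\end{aligned}
$$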
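Shapes

| Argument | Shape | Notes |
|---|---|---|
| x | (B, S, H, D) | Post-projection Q or K activations (batch, sequence, heads, head dim). |
| weight | (D,) | Required RMSNorm scale. For the variant without a weight, use fused_rms_norm_noscale. |
| cos | (B, S, n_rot) | n_rot must be even and <= D. Only the leading n_rot columns of each head are rotated; the trailing D - n_rot columns are RMSNorm-only (partial rotary). |
| sin | (B, S, n_rot) | Same shape as cos. |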
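With unit_offset=True, the normalized values are scaled by (weight + 1.0) instead of weight (Gemma-style); the default, unit_offset=False, scales by weight directly.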
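For validating the kernel (or a port of it), a slow per-row reference makes the contract above concrete. rms_norm_rope_ref is a hypothetical helper, not part of kernels.gemma4_fused_rope. It assumes eps inside the square root, normalization over all D columns, and the rotate-half pairing of column j with column j + n_rot/2; this page does not say which RoPE layout the kernel uses, so confirm that before comparing outputs.

```python
import math


def rms_norm_rope_ref(row, weight, cos_row, sin_row, eps=1e-6, unit_offset=False):
    """Hypothetical pure-Python reference for one (b, s, h) head vector.

    row, weight: length D. cos_row, sin_row: length n_rot, shared by every
    head at the same (b, s). Assumes eps inside the sqrt and rotate-half RoPE
    (column j paired with column j + n_rot // 2).
    """
    d, n_rot = len(row), len(cos_row)
    if n_rot % 2 != 0 or n_rot > d or len(sin_row) != n_rot or len(weight) != d:
        raise ValueError("need len(weight) == D, even n_rot <= D, len(sin) == len(cos)")

    # 1) RMSNorm over all D columns, with the optional Gemma-style (weight + 1) scale.
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in row) / d + eps)
    offset = 1.0 if unit_offset else 0.0
    y = [v * inv_rms * (w + offset) for v, w in zip(row, weight)]

    # 2) Rotate-half RoPE on the leading n_rot columns only.
    half = n_rot // 2
    out = list(y)  # columns n_rot..D-1 stay RMSNorm-only (partial rotary)
    for j in range(half):
        a, b = y[j], y[j + half]
        out[j] = a * cos_row[j] - b * sin_row[j]
        out[j + half] = b * cos_row[j + half] + a * sin_row[j + half]
    return out


# D = 6, n_rot = 4; cos/sin use the same angle in both halves (HF-style layout).
theta = 0.3
row = [1.0, -2.0, 0.5, 3.0, 4.0, -1.0]
cos_row = [math.cos(theta)] * 4
sin_row = [math.sin(theta)] * 4

# unit_offset=True with zero weights gives an effective scale of 1 ...
rotated = rms_norm_rope_ref(row, [0.0] * 6, cos_row, sin_row, unit_offset=True)
# ... as does unit_offset=False with unit weights; cos=1, sin=0 means no rotation.
unrotated = rms_norm_rope_ref(row, [1.0] * 6, [1.0] * 4, [0.0] * 4)
```

The example also illustrates that unit_offset=True with weight w matches unit_offset=False with weight w + 1, and that the trailing D - n_rot columns are identical with or without rotation.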