kernels.rms_norm_gated

Fused RMSNorm + SiLU Gate Triton kernel.

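Computes:

    Y = (W + offset) * RMSNorm(X) * silu(G)

where

    RMSNorm(X) = X / sqrt(mean(X^2) + eps)
    silu(G)    = G * sigmoid(G)

The mean is taken over the hidden dimension, W is the learned per-channel weight, and G is the gate tensor.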

Used by Qwen3.5’s GatedDeltaNet linear attention layers (Qwen3_5RMSNormGated).
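For checking the formula by hand, the computation can be written as an unfused reference in plain Python. This is only a sketch of the math for a single row of length hidden_size, not the Triton kernel itself; the names rms_norm_gated_ref and silu are hypothetical helpers introduced here for illustration and are not part of kernels.rms_norm_gated.

```python
import math


def silu(g):
    """SiLU activation, g * sigmoid(g), written as g / (1 + exp(-g))."""
    return g / (1.0 + math.exp(-g))


def rms_norm_gated_ref(x, g, w, eps=1e-6, offset=0.0):
    """Unfused reference for one row: Y = (W + offset) * RMSNorm(X) * silu(G).

    x, g and w are equal-length lists of floats (one entry per hidden channel).
    Hypothetical helper for illustration; assumes moderate gate magnitudes
    (math.exp overflows for very large negative gates).
    """
    # mean(X^2) over the hidden dimension
    mean_sq = sum(v * v for v in x) / len(x)
    # 1 / sqrt(mean(X^2) + eps)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    # Scale, normalize and gate each channel
    return [
        (wi + offset) * xi * inv_rms * silu(gi)
        for xi, gi, wi in zip(x, g, w)
    ]


if __name__ == "__main__":
    row = rms_norm_gated_ref([3.0, 4.0], [0.5, -0.5], [1.0, 1.0])
    print(row)
```

A fused kernel produces the same values in a single pass over the row, which avoids writing the normalized intermediate back to memory. A reference like this is mainly useful for comparing outputs within a floating-point tolerance.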

Classes

Name Description
FusedRMSNormGated Fused RMSNorm + SiLU Gate.

FusedRMSNormGated

kernels.rms_norm_gated.FusedRMSNormGated(
    hidden_size,
    eps=1e-06,
    offset=0.0,
    **kwargs,
)

Fused RMSNorm + SiLU Gate.

Computes: Y = (W + offset) * RMSNorm(X) * silu(G), which reduces to Y = W * RMSNorm(X) * silu(G) with the default offset=0.0.

Drop-in replacement for Qwen3_5RMSNormGated, with a matching constructor signature, __init__(hidden_size, eps=1e-6, **kwargs), and a matching forward signature, forward(hidden_states, gate=None). The offset argument is an optional addition that defaults to 0.0.