monkeypatch.scaled_softmax_attn

Scaled Softmax (SSMax) attention patch using FlexAttention. SSMax computes softmax(scores * (s * log(n) + b)), where n is the position index, s is the scaling factor, and b is the bias.

Ref: https://arxiv.org/abs/2501.19399
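The formula can be illustrated with a minimal plain-Python sketch. It follows the form given for ssmax_flex_attention_forward, score * (s * log(n) + b), and is not the module's implementation. The function names `ssmax` and `softmax` are hypothetical. `s` is treated as a fixed number (the default 0.43 is borrowed from the `scaling_factor_init` default in the signature below), and `n` is passed in explicitly.

```python
import math


def softmax(scores):
    """Plain softmax over one row of attention scores."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scores]
    total = sum(exps)
    return [e / total for e in exps]


def ssmax(scores, n, s=0.43, b=0.0):
    """Softmax after multiplying every score by (s * log(n) + b).

    Assumption: s and b are plain floats here; n is supplied by the caller.
    """
    scale = s * math.log(n) + b
    return softmax([x * scale for x in scores])


# One large score among many small ones, at two context lengths.
for n in (16, 1024):
    row = [5.0] + [0.0] * (n - 1)
    print(n, softmax(row)[0], ssmax(row, n=n)[0])
```

With plain softmax, the weight on the single large score shrinks as the row gets longer. Because the SSMax scale grows with log(n), the same score keeps most of the weight in this example.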

Functions

Name Description
patch_scaled_softmax_attention Patch attention to apply SSMax via FlexAttention score_mod.
ssmax_flex_attention_forward FlexAttention forward with SSMax: score * (s * log(n) + b).
unpatch_scaled_softmax_attention Restore the original FlexAttention function.

patch_scaled_softmax_attention

monkeypatch.scaled_softmax_attn.patch_scaled_softmax_attention(
    scaling_factor_init=0.43,
    bias=0.0,
    model=None,
)

Patch attention to apply SSMax via FlexAttention score_mod.

ssmax_flex_attention_forward

monkeypatch.scaled_softmax_attn.ssmax_flex_attention_forward(
    module,
    query,
    key,
    value,
    attention_mask,
    scaling=None,
    softcap=None,
    **kwargs,
)

FlexAttention forward with SSMax: score * (s * log(n) + b).
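As a rough illustration of how a score_mod can express this, the sketch below emulates the mechanism in plain Python. In PyTorch's FlexAttention, a score_mod is called as `score_mod(score, batch, head, q_idx, kv_idx)` with tensor arguments, so a real one would use tensor operations instead of `math.log`. The helper names `make_ssmax_score_mod` and `reference_attention_weights` are hypothetical, and taking n = q_idx + 1 (the 1-based query position under a causal mask) is an assumption, not something this page confirms.

```python
import math


def make_ssmax_score_mod(s=0.43, b=0.0):
    """Return a callable using FlexAttention's score_mod argument order."""

    def score_mod(score, batch, head, q_idx, kv_idx):
        # Assumption: n is the 1-based query position, so log(n) is
        # defined for the first token (q_idx == 0).
        n = q_idx + 1
        return score * (s * math.log(n) + b)

    return score_mod


def reference_attention_weights(scores, score_mod):
    """Apply score_mod to each causal (q_idx, kv_idx) pair, then softmax per row."""
    weights = []
    for q_idx, row in enumerate(scores):
        visible = [
            score_mod(row[kv_idx], 0, 0, q_idx, kv_idx)
            for kv_idx in range(q_idx + 1)
        ]
        m = max(visible)  # subtract the max for numerical stability
        exps = [math.exp(x - m) for x in visible]
        total = sum(exps)
        # Future positions are masked out and receive zero weight.
        masked = [0.0] * (len(row) - q_idx - 1)
        weights.append([e / total for e in exps] + masked)
    return weights


raw_scores = [[1.0, 9.0, 9.0], [1.0, 2.0, 9.0], [0.5, 0.5, 3.0]]
print(reference_attention_weights(raw_scores, make_ssmax_score_mod()))
```

The modification is applied to each score before the softmax, which is why the scaling can be written as a per-element function of the query index.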

unpatch_scaled_softmax_attention

monkeypatch.scaled_softmax_attn.unpatch_scaled_softmax_attention()

Restore the original FlexAttention function.
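The patch/unpatch pair follows the usual save-and-restore monkeypatch pattern. The sketch below shows that generic pattern only. `ATTENTION_REGISTRY`, `patch`, and `unpatch` are hypothetical stand-ins, and how this module actually stores the original function is not stated here.

```python
# Hypothetical stand-in for wherever the attention function is registered.
ATTENTION_REGISTRY = {"flex_attention": lambda *args, **kwargs: "original"}

# Holds the replaced function while a patch is active.
_original = None


def patch(replacement):
    """Swap in `replacement`, remembering the function it replaces."""
    global _original
    if _original is None:
        # Keep the true original even if patch() is called more than once.
        _original = ATTENTION_REGISTRY["flex_attention"]
    ATTENTION_REGISTRY["flex_attention"] = replacement


def unpatch():
    """Put the remembered function back; do nothing if no patch is active."""
    global _original
    if _original is not None:
        ATTENTION_REGISTRY["flex_attention"] = _original
        _original = None
```

Saving the original only on the first call means that repeated patching cannot overwrite it with an already patched function, and calling `unpatch` twice is harmless.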