monkeypatch.scaled_softmax_attn
Scalable-Softmax (SSMax) attention patch using FlexAttention. SSMax computes softmax(scores * (s * log(n) + b)), where s is a scaling factor, b is a bias, and n is the position index. Ref: https://arxiv.org/abs/2501.19399
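To make the formula concrete, here is a minimal pure-Python sketch of SSMax over a single row of attention scores. It assumes the natural log and takes n to be the number of scores in the row (the number of keys a query attends over, which under causal masking grows with the query's position). The defaults s = 0.43 and b = 0.0 mirror those of `patch_scaled_softmax_attention`; the helper names `softmax` and `ssmax` are illustrative and not part of this module.

```python
import math
from typing import List


def softmax(xs: List[float]) -> List[float]:
    # Numerically stable softmax: subtract the max before exponentiating.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def ssmax(scores: List[float], s: float = 0.43, b: float = 0.0) -> List[float]:
    # Multiply every logit by (s * log(n) + b), with n = number of scores,
    # then apply an ordinary softmax.
    n = len(scores)
    factor = s * math.log(n) + b
    return softmax([x * factor for x in scores])


# One "relevant" key scoring 5.0 among n - 1 distractor keys scoring 0.0.
for n in (16, 256, 4096):
    scores = [5.0] + [0.0] * (n - 1)
    print(n, round(softmax(scores)[0], 3), round(ssmax(scores)[0], 3))
```

As n grows, plain softmax spreads the weight of the single high-scoring key across the distractors, while the log(n) factor in SSMax sharpens the logits enough to keep that key dominant.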
Functions
| Name | Description |
|---|---|
| patch_scaled_softmax_attention | Patch attention to apply SSMax via FlexAttention score_mod. |
| ssmax_flex_attention_forward | FlexAttention forward with SSMax: score * (s * log(n) + b). |
| unpatch_scaled_softmax_attention | Restore the original FlexAttention function. |
patch_scaled_softmax_attention
```python
monkeypatch.scaled_softmax_attn.patch_scaled_softmax_attention(
    scaling_factor_init=0.43,
    bias=0.0,
    model=None,
)
```

Patch attention to apply SSMax via the FlexAttention `score_mod`.
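One common way to implement a patch/unpatch pair like this is to save the original function once and swap a replacement into whatever table the model reads its attention function from. The sketch below assumes a hypothetical `ATTENTION_FUNCTIONS` dictionary and placeholder forwards; it does not reflect how this module actually locates FlexAttention or what it does with its `model` argument.

```python
from typing import Callable, Dict, Optional

# Hypothetical registry standing in for wherever the model looks up its
# FlexAttention implementation; the real patch target may differ.
ATTENTION_FUNCTIONS: Dict[str, Callable[..., str]] = {
    "flex_attention": lambda *args, **kwargs: "original flex attention",
}

_ORIGINAL: Optional[Callable[..., str]] = None
SSMAX_CONFIG = {"s": 0.43, "b": 0.0}


def ssmax_forward(*args, **kwargs) -> str:
    # Placeholder for a forward whose score_mod would read SSMAX_CONFIG.
    return f"ssmax flex attention (s={SSMAX_CONFIG['s']}, b={SSMAX_CONFIG['b']})"


def patch(scaling_factor_init: float = 0.43, bias: float = 0.0) -> None:
    global _ORIGINAL
    # Save the original only on the first call, so patching twice never
    # overwrites the saved original with the patched function.
    if _ORIGINAL is None:
        _ORIGINAL = ATTENTION_FUNCTIONS["flex_attention"]
    SSMAX_CONFIG.update(s=scaling_factor_init, b=bias)
    ATTENTION_FUNCTIONS["flex_attention"] = ssmax_forward


def unpatch() -> None:
    global _ORIGINAL
    # Restoring is a no-op if nothing was patched.
    if _ORIGINAL is not None:
        ATTENTION_FUNCTIONS["flex_attention"] = _ORIGINAL
        _ORIGINAL = None


patch()
patch(scaling_factor_init=0.5)  # second call only updates the parameters
print(ATTENTION_FUNCTIONS["flex_attention"]())
unpatch()
print(ATTENTION_FUNCTIONS["flex_attention"]())
```

Guarding the save with `_ORIGINAL is None` is what keeps a double patch reversible: without it, the second call would record the SSMax forward as the "original".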
ssmax_flex_attention_forward
```python
monkeypatch.scaled_softmax_attn.ssmax_flex_attention_forward(
    module,
    query,
    key,
    value,
    attention_mask,
    scaling=None,
    softcap=None,
    **kwargs,
)
```

FlexAttention forward with SSMax: score * (s * log(n) + b).
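The `score_mod` idea behind this forward can be sketched without PyTorch. The function below uses the same argument order as FlexAttention's `score_mod` (`score, batch, head, q_idx, kv_idx`) and multiplies each score by s * log(n) + b, taking n = q_idx + 1 as an assumption. `reference_attention` is a hypothetical, loop-based stand-in for the fused kernel for one batch element and one head; it assumes causal masking and the usual 1/sqrt(head_dim) scale when `scaling` is None, and it does not model `softcap`.

```python
import math
from typing import Callable, List, Optional

# FlexAttention-style score_mod: (score, batch, head, q_idx, kv_idx) -> new score.
ScoreMod = Callable[[float, int, int, int, int], float]


def make_ssmax_score_mod(s: float = 0.43, b: float = 0.0) -> ScoreMod:
    def score_mod(score: float, batch: int, head: int, q_idx: int, kv_idx: int) -> float:
        # Assumption: n = q_idx + 1, so log(n) is defined at the first position.
        n = q_idx + 1
        return score * (s * math.log(n) + b)

    return score_mod


def reference_attention(
    q: List[List[float]],
    k: List[List[float]],
    v: List[List[float]],
    score_mod: ScoreMod,
    scaling: Optional[float] = None,
) -> List[List[float]]:
    """Causal attention for one batch element and one head, in plain Python."""
    # Assumption: scaling=None falls back to 1 / sqrt(head_dim).
    scale = 1.0 / math.sqrt(len(q[0])) if scaling is None else scaling
    out = []
    for i, qi in enumerate(q):
        # Causal mask: query i only sees keys 0..i.
        scores = [
            score_mod(scale * sum(a * c for a, c in zip(qi, k[j])), 0, 0, i, j)
            for j in range(i + 1)
        ]
        m = max(scores)
        weights = [math.exp(x - m) for x in scores]
        total = sum(weights)
        out.append(
            [sum(weights[j] * v[j][t] for j in range(i + 1)) / total for t in range(len(v[0]))]
        )
    return out


q = k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
print(reference_attention(q, k, v, make_ssmax_score_mod()))
```

With the real FlexAttention the indices arrive as tensors, so an equivalent `score_mod` would use `torch.log` rather than `math.log`.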
unpatch_scaled_softmax_attention
```python
monkeypatch.scaled_softmax_attn.unpatch_scaled_softmax_attention()
```

Restore the original FlexAttention function.
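When the patch should apply only to part of a program, a context manager guarantees that the restore step runs even if an exception is raised. The helper below (`ssmax_patched`) is hypothetical and works against an illustrative `ATTENTION_FUNCTIONS` registry; it is not an API this module provides.

```python
from contextlib import contextmanager
from typing import Callable, Dict, Iterator

# Hypothetical registry standing in for the real FlexAttention lookup.
ATTENTION_FUNCTIONS: Dict[str, Callable[[], str]] = {"flex_attention": lambda: "original"}


@contextmanager
def ssmax_patched(replacement: Callable[[], str]) -> Iterator[None]:
    original = ATTENTION_FUNCTIONS["flex_attention"]
    ATTENTION_FUNCTIONS["flex_attention"] = replacement
    try:
        yield
    finally:
        # Runs on normal exit and on exceptions, so the original is always restored.
        ATTENTION_FUNCTIONS["flex_attention"] = original


with ssmax_patched(lambda: "ssmax"):
    print(ATTENTION_FUNCTIONS["flex_attention"]())
print(ATTENTION_FUNCTIONS["flex_attention"]())
```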