monkeypatch.llama_attn_hijack_flash

Flash attention monkey patch for the Llama model.

Functions

Name Description
flashattn_forward_with_s2attn Input shape: Batch x Time x Channel

flashattn_forward_with_s2attn

monkeypatch.llama_attn_hijack_flash.flashattn_forward_with_s2attn(
    self,
    hidden_states,
    attention_mask=None,
    position_ids=None,
    past_key_value=None,
    output_attentions=False,
    use_cache=False,
    padding_mask=None,
    cu_seqlens=None,
    max_seqlen=None,
)

Input shape: Batch x Time x Channel

From: https://github.com/dvlab-research/LongLoRA/blob/main/llama_attn_replace.py

attention_mask: [bsz, q_len]

cu_seqlens will be ignored if provided.
max_seqlen will be ignored if provided.