monkeypatch.llama_attn_hijack_flash
Flash attention monkey patch for llama model
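As a minimal sketch of how a monkey patch like this is typically applied, the replacement forward is assigned onto the upstream `LlamaAttention` class from `transformers`. The helper name below is hypothetical and not part of this module's API; it only illustrates the standard assignment-based patching approach.

```python
# Sketch only: the helper name is hypothetical, not part of this module's API.
import transformers.models.llama.modeling_llama as llama_modeling

from monkeypatch.llama_attn_hijack_flash import flashattn_forward_with_s2attn


def hijack_llama_attention_example():
    # Reassign the class-level forward so every LlamaAttention instance
    # built by transformers routes through the patched implementation.
    llama_modeling.LlamaAttention.forward = flashattn_forward_with_s2attn
```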
Functions
| Name | Description |
| --- | --- |
| flashattn_forward_with_s2attn | Input shape: Batch x Time x Channel |
flashattn_forward_with_s2attn
monkeypatch.llama_attn_hijack_flash.flashattn_forward_with_s2attn(
    self,
    hidden_states,
    attention_mask=None,
    position_ids=None,
    past_key_value=None,
    output_attentions=False,
    use_cache=False,
    padding_mask=None,
    cu_seqlens=None,
    max_seqlen=None,
)
Input shape: Batch x Time x Channel
From: https://github.com/dvlab-research/LongLoRA/blob/main/llama_attn_replace.py
attention_mask: [bsz, q_len]
cu_seqlens will be ignored if provided
max_seqlen will be ignored if provided
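For illustration, a hedged sketch of the tensor shapes described above; the sizes are arbitrary, and the snippet only builds the inputs without invoking the patched forward.

```python
import torch

# Hypothetical sizes, chosen only to illustrate the documented shapes.
bsz, q_len, hidden_size = 2, 4096, 4096

# Batch x Time x Channel, as the docstring describes.
hidden_states = torch.randn(bsz, q_len, hidden_size)

# attention_mask: [bsz, q_len]
attention_mask = torch.ones(bsz, q_len, dtype=torch.bool)

# position_ids: [bsz, q_len]
position_ids = torch.arange(q_len).unsqueeze(0).expand(bsz, q_len)

# cu_seqlens and max_seqlen are accepted for interface compatibility
# but are ignored by this implementation when provided.
cu_seqlens, max_seqlen = None, None
```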