monkeypatch.ring_attn.adapters.batch
HuggingFace flash attention adapter for basic ring attention (batch API).
Inspired by https://github.com/zhuzilin/ring-flash-attention/blob/ce9fd3935ca0e5f0592bb0826cbed18ec69da729/ring_flash_attn/adapters/hf_adapter.py. Our implementation closely follows the structure of that module, but we’ve trimmed it down to support only the latest versions of transformers.
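The adapter hands the actual computation to a ring attention kernel. As a rough illustration of what such a kernel computes, the pure-Python sketch below simulates ring attention in a single process. Each simulated "rank" keeps its own queries while the key/value shards circulate around the ring, and partial results are merged with an online softmax (running max and running sum). This is a toy under stated assumptions: no causal mask, no batch or head dimensions, non-empty shards, and the names `ring_attention` and `full_attention` are invented for this example. It is not the ring-flash-attention implementation.

```python
import math
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def full_attention(q: Vector, keys: List[Vector], values: List[Vector]) -> Vector:
    """Reference: ordinary softmax attention of one query over every key."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [dot(q, k) * scale for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(len(values[0]))]


def ring_attention(queries: List[List[Vector]],
                   key_shards: List[List[Vector]],
                   value_shards: List[List[Vector]]) -> List[List[Vector]]:
    """Simulate ring attention: rank r owns queries[r] and, at step s, holds the
    K/V shard that originated on rank (r - s) % world_size."""
    world_size = len(key_shards)
    dim = len(value_shards[0][0])
    outputs = []
    for rank in range(world_size):
        rank_out = []
        for q in queries[rank]:
            scale = 1.0 / math.sqrt(len(q))
            running_max = -math.inf  # largest score seen so far
            running_sum = 0.0        # softmax denominator, relative to running_max
            acc = [0.0] * dim        # unnormalized weighted sum of values
            for step in range(world_size):
                src = (rank - step) % world_size  # shard received at this step
                scores = [dot(q, k) * scale for k in key_shards[src]]
                new_max = max(running_max, max(scores))
                # Rescale what was accumulated under the old max (exp(-inf) == 0.0).
                correction = math.exp(running_max - new_max)
                weights = [math.exp(s - new_max) for s in scores]
                running_sum = running_sum * correction + sum(weights)
                acc = [a * correction
                       + sum(w * v[d] for w, v in zip(weights, value_shards[src]))
                       for d, a in enumerate(acc)]
                running_max = new_max
            rank_out.append([a / running_sum for a in acc])
        outputs.append(rank_out)
    return outputs
```

Softmax attention does not depend on the order in which keys are visited. Each rank's result should therefore match `full_attention` over the concatenated shards, up to floating-point rounding, even though every rank starts the ring at a different shard.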
Functions
| Name | Description |
|---|---|
| create_flash_attn_forward_varlen_llama3 | Create a ring flash attention forward function compatible with HuggingFace’s interface. |
| substitute_hf_flash_attn | Substitute HuggingFace’s flash attention implementation with ring-based implementation. |
create_flash_attn_forward_varlen_llama3
monkeypatch.ring_attn.adapters.batch.create_flash_attn_forward_varlen_llama3(
    process_group,
    ring_attn_func,
)

Create a ring flash attention forward function compatible with HuggingFace’s interface.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| process_group | dist.ProcessGroup | A PyTorch distributed process group. | required |
| ring_attn_func | RingAttnFunc | Function from ring_flash_attention to replace HF flash attention with. | required |
Returns
| Name | Type | Description |
|---|---|---|
|  | Callable | A function that implements the ring flash attention forward pass with the signature expected by HuggingFace Transformers. |
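A factory like this is naturally written as a closure. The process group and ring function are bound once, and the returned function exposes parameters shaped like HuggingFace’s flash-attention forward, so callers never pass the process group themselves. This is a minimal sketch under assumptions. The parameter list is a simplified subset chosen for illustration, `make_ring_forward` and the stub ring function in the usage are hypothetical, and the keyword names passed to `ring_attn_func` are assumed. The real varlen Llama 3 adapter also has to handle packed-sequence boundaries, which are omitted here.

```python
from typing import Any, Callable, Optional


def make_ring_forward(process_group: Any,
                      ring_attn_func: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical factory: bind the process group and ring function once and
    return a forward with a simplified, HF-style flash-attention signature."""

    def _flash_attention_forward(query_states, key_states, value_states,
                                 attention_mask, query_length, is_causal,
                                 dropout: float = 0.0,
                                 softmax_scale: Optional[float] = None,
                                 **kwargs: Any):
        # attention_mask, query_length and any extra kwargs are ignored in this
        # sketch; a real adapter must translate them for the ring kernel.
        return ring_attn_func(query_states, key_states, value_states,
                              dropout_p=dropout,
                              softmax_scale=softmax_scale,
                              causal=is_causal,
                              group=process_group)  # assumed keyword name

    return _flash_attention_forward
```

Because `process_group` lives in the closure rather than in the call signature, HuggingFace’s attention code can call the returned function exactly as it would call the original forward.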
substitute_hf_flash_attn
monkeypatch.ring_attn.adapters.batch.substitute_hf_flash_attn(
    process_group,
    ring_attn_func,
)

Substitute HuggingFace’s flash attention implementation with ring-based implementation.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| process_group | dist.ProcessGroup | PyTorch distributed process group for communication. | required |
| ring_attn_func | RingAttnFunc | Function from ring_flash_attention to replace HF flash attention with. | required |