monkeypatch.ring_attn.adapters.batch

HuggingFace flash attention adapter for basic ring attention (batch API).

Inspired by https://github.com/zhuzilin/ring-flash-attention/blob/ce9fd3935ca0e5f0592bb0826cbed18ec69da729/ring_flash_attn/adapters/hf_adapter.py. Our implementation closely follows the structure of that module, but we’ve trimmed it to support only the latest versions of transformers.

Functions

Name Description
create_flash_attn_forward_varlen_llama3 Create a ring flash attention forward function compatible with HuggingFace’s interface.
substitute_hf_flash_attn Substitute HuggingFace’s flash attention implementation with a ring-based implementation.

create_flash_attn_forward_varlen_llama3

monkeypatch.ring_attn.adapters.batch.create_flash_attn_forward_varlen_llama3(
    process_group,
    ring_attn_func,
)

Create a ring flash attention forward function compatible with HuggingFace’s interface.

Parameters

Name Type Description Default
process_group dist.ProcessGroup A PyTorch distributed process group. required
ring_attn_func RingAttnFunc Function from ring_flash_attention that replaces HF flash attention. required

Returns

Name Type Description
Callable A function that implements the ring flash attention forward pass with the signature expected by HuggingFace Transformers.
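
The factory pattern described above can be illustrated with a minimal sketch. It is not the module's implementation: make_forward_sketch and toy_ring_attn are hypothetical names, the process group is a plain string rather than a dist.ProcessGroup, and the forward signature is simplified compared with the one HuggingFace Transformers actually expects. It only shows how a closure can capture process_group and ring_attn_func so that the returned callable needs neither as an argument.

```python
from typing import Any, Callable


def make_forward_sketch(
    process_group: Any, ring_attn_func: Callable[..., Any]
) -> Callable[..., Any]:
    """Hypothetical factory: returns a forward function bound to a group."""

    def forward(query, key, value, *args, **kwargs):
        # Extra arguments are accepted and ignored in this sketch; a real
        # adapter would have to interpret them.
        return ring_attn_func(query, key, value, group=process_group)

    return forward


def toy_ring_attn(query, key, value, group=None):
    """Toy stand-in for a ring attention function (assumption, not a real API)."""
    return {"group": group, "lengths": (len(query), len(key), len(value))}


forward = make_forward_sketch("toy_group", toy_ring_attn)
result = forward([0.1, 0.2], [0.3, 0.4], [0.5, 0.6], attention_mask=None)
```

Because the group and the attention function are bound once at creation time, callers that only know the plain forward signature can use the returned function unchanged.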

substitute_hf_flash_attn

monkeypatch.ring_attn.adapters.batch.substitute_hf_flash_attn(
    process_group,
    ring_attn_func,
)

Substitute HuggingFace’s flash attention implementation with a ring-based implementation.

Parameters

Name Type Description Default
process_group dist.ProcessGroup PyTorch distributed process group for communication. required
ring_attn_func RingAttnFunc Function from ring_flash_attention that replaces HF flash attention. required
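
The substitution step can be pictured as swapping an attribute on a module at runtime. The sketch below uses only hypothetical stand-ins: toy_module plays the role of a library module, attention_forward is an invented attribute name, and substitute_sketch is not the function documented here. Returning the original function so it can be restored is a choice made for this sketch, not documented behaviour of substitute_hf_flash_attn.

```python
from types import SimpleNamespace

# Hypothetical stand-in for a library module exposing an attention forward.
toy_module = SimpleNamespace(
    attention_forward=lambda q, k, v, **kwargs: "default"
)


def substitute_sketch(module, process_group, ring_attn_func):
    """Replace module.attention_forward and return the original function."""
    original = module.attention_forward

    def ring_forward(q, k, v, **kwargs):
        # Delegate to the ring function, bound to the captured group.
        return ring_attn_func(q, k, v, group=process_group)

    module.attention_forward = ring_forward
    return original


def toy_ring_attn(q, k, v, group=None):
    """Toy stand-in for a ring attention function (assumption)."""
    return f"ring:{group}"


before = toy_module.attention_forward(1, 2, 3)
original = substitute_sketch(toy_module, "toy_group", toy_ring_attn)
after = toy_module.attention_forward(1, 2, 3, dropout=0.0)

# Undo the patch by putting the original function back.
toy_module.attention_forward = original
restored = toy_module.attention_forward(1, 2, 3)
```

A patch like this affects every caller that looks the function up through the module after the swap, which is why it is typically applied once, before the model runs its first forward pass.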