monkeypatch.ring_attn.patch


Ring attention group registration and flash attention patching.

Uses the ring-flash-attn (https://github.com/zhuzilin/ring-flash-attention) package, specifically its hf_adapter.substitute_hf_flash_attn function, to patch in its sequence-parallel version of Flash Attention 2.
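The patching approach can be illustrated with a generic monkey-patching sketch. Everything below is hypothetical: fake_hf_module stands in for the Hugging Face module whose attention function gets replaced, and substitute and make_ring_version are illustrative helpers, not part of the ring-flash-attn API. The point is only that callers which look the function up on the module at call time pick up the replacement.

```python
import types

# Hypothetical stand-in for a Hugging Face module exposing an attention function.
fake_hf_module = types.SimpleNamespace()


def flash_attention_forward(q_len: int) -> str:
    # Placeholder for the original (single-device) flash attention call.
    return f"local flash attn over {q_len} tokens"


fake_hf_module.flash_attention_forward = flash_attention_forward


def substitute(module, attr, make_replacement):
    """Replace module.attr with a function built from the original; return the original."""
    original = getattr(module, attr)
    setattr(module, attr, make_replacement(original))
    return original


def make_ring_version(original):
    def ring_flash_attention_forward(q_len: int) -> str:
        # A real ring implementation would exchange K/V blocks between ranks here.
        return "ring " + original(q_len)

    return ring_flash_attention_forward


original = substitute(fake_hf_module, "flash_attention_forward", make_ring_version)
print(fake_hf_module.flash_attention_forward(128))  # ring local flash attn over 128 tokens

# Restoring the original is just another setattr.
setattr(fake_hf_module, "flash_attention_forward", original)
```

Keeping a handle to the original function makes the patch reversible, which is useful in tests.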

We also provide patches for some accelerate functions to prepare the dataloader for sequence-parallel training.
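In sequence-parallel training, each rank in a group receives its own slice of every sequence. A minimal sketch, assuming contiguous, equal-sized chunks; shard_for_rank is a hypothetical helper, and the actual accelerate patches may split sequences differently and also need to shard labels and position IDs consistently.

```python
from typing import List, Sequence


def shard_for_rank(tokens: Sequence[int], rank: int, world_size: int) -> List[int]:
    """Return the contiguous slice of `tokens` owned by `rank` (hypothetical helper).

    Assumes the sequence length is divisible by world_size; pad beforehand otherwise.
    """
    if len(tokens) % world_size != 0:
        raise ValueError("sequence length must be divisible by the sequence-parallel world size")
    chunk = len(tokens) // world_size
    return list(tokens[rank * chunk:(rank + 1) * chunk])


seq = list(range(8))
shards = [shard_for_rank(seq, r, 4) for r in range(4)]
print(shards)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```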

Functions

Name                                  Description
get_ring_attn_group                   Getter for the ring attention group on this rank.
register_ring_attn_from_device_mesh   Create the ring attention group using a DeviceMesh and substitute flash attention with ring flash attention.
set_ring_attn_group                   Setter for the ring attention group on this rank.
update_ring_attn_params               Calculate the cumulative sequence lengths for the current forward pass and pass the value to the substituted ring_flash_attn.

get_ring_attn_group

monkeypatch.ring_attn.patch.get_ring_attn_group()

Getter for the ring attention group on this rank.

register_ring_attn_from_device_mesh

monkeypatch.ring_attn.patch.register_ring_attn_from_device_mesh(
    device_mesh,
    context_parallel_dim,
    heads_k_stride,
    ring_attn_func,
)

Create the ring attention group using a DeviceMesh and substitute flash attention with ring flash attention.

Parameters

Name                  Type                 Default   Description
device_mesh           DeviceMesh           required  DeviceMesh object containing the parallelism topology.
context_parallel_dim  tuple[str, ...]      required  Name of the sequence parallel dimension in the device mesh.
heads_k_stride        int | None           required  Sequence parallelism K head stride size. Passed through to the varlen_llama3 ring_flash_attn implementation.
ring_attn_func        RingAttnFunc | None  required  ring_flash_attn ring attention implementation. If sample packing is enabled, it must be a varlen function; otherwise, it must be a batch function.
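A DeviceMesh arranges global ranks in an n-dimensional grid, and a rank's ring attention group consists of the ranks that differ from it only along the sequence parallel dimension. The sketch below enumerates those groups, assuming ranks are laid out in row-major order (as with torch.arange(world_size).reshape(mesh_shape)). groups_along_dim is a hypothetical helper for illustration; with a real mesh, PyTorch's DeviceMesh.get_group(mesh_dim) returns this rank's group directly.

```python
import itertools
from typing import List, Sequence


def groups_along_dim(mesh_shape: Sequence[int], dim: int) -> List[List[int]]:
    """List the rank groups that vary only along `dim` of a row-major rank grid."""
    # Row-major strides: the stride of axis i is the product of the sizes after it.
    strides = [1] * len(mesh_shape)
    for i in range(len(mesh_shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * mesh_shape[i + 1]

    other_dims = [d for d in range(len(mesh_shape)) if d != dim]
    groups = []
    # Fix every coordinate except `dim`, then walk along `dim` to collect one group.
    for fixed in itertools.product(*(range(mesh_shape[d]) for d in other_dims)):
        base = sum(idx * strides[d] for idx, d in zip(fixed, other_dims))
        groups.append([base + k * strides[dim] for k in range(mesh_shape[dim])])
    return groups


# 8 GPUs as a (data_parallel=2, sequence_parallel=4) mesh.
print(groups_along_dim([2, 4], dim=1))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(groups_along_dim([2, 4], dim=0))  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```

With the sequence parallel dimension last, each ring group is a block of consecutive ranks, which typically keeps ring communication within a node.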

set_ring_attn_group

monkeypatch.ring_attn.patch.set_ring_attn_group(ring_attn_group)

Setter for the ring attention group on this rank.
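The getter and setter pair can be pictured as a module-level slot that other patched code reads during the forward pass. A minimal sketch, assuming the group is stored in a module-level variable; the variable name _RING_ATTN_GROUP is hypothetical, the function names mirror the documented ones but the bodies are illustrative, and a string stands in for a torch.distributed ProcessGroup.

```python
from typing import Any, Optional

# Hypothetical module-level slot holding this rank's ring attention group.
_RING_ATTN_GROUP: Optional[Any] = None


def set_ring_attn_group(ring_attn_group: Optional[Any]) -> None:
    """Store the process group this rank uses for ring attention."""
    global _RING_ATTN_GROUP
    _RING_ATTN_GROUP = ring_attn_group


def get_ring_attn_group() -> Optional[Any]:
    """Return the stored group, or None if it has not been set yet."""
    return _RING_ATTN_GROUP


print(get_ring_attn_group())  # None
set_ring_attn_group("group-for-ranks-0-3")  # stand-in for a ProcessGroup
print(get_ring_attn_group())  # group-for-ranks-0-3
```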

update_ring_attn_params

monkeypatch.ring_attn.patch.update_ring_attn_params(position_ids)

Calculate the cumulative sequence lengths for the current forward pass and pass the value to the substituted ring_flash_attn.

Parameters

Name                  Type                 Default   Description
position_ids          torch.Tensor | None  required  Optional tensor of position IDs (for sample-packed data).
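With sample packing, position IDs restart at 0 at the start of each packed sample, so the cumulative sequence lengths (cu_seqlens) are the indices where a 0 appears, followed by the total length. A minimal pure-Python sketch of that computation for a single packed row; cu_seqlens_from_position_ids is a hypothetical helper, and the real function presumably operates on torch tensors and also handles position_ids being None (non-packed data).

```python
from typing import List, Sequence


def cu_seqlens_from_position_ids(position_ids: Sequence[int]) -> List[int]:
    """Cumulative sample boundaries for a packed row whose position IDs restart at 0."""
    if len(position_ids) == 0:
        return [0]
    # Each 0 marks the first token of a new packed sample.
    starts = [i for i, p in enumerate(position_ids) if p == 0]
    if not starts or starts[0] != 0:
        raise ValueError("packed position_ids must begin with 0")
    # Close the last sample at the end of the row.
    return starts + [len(position_ids)]


# Three packed samples of lengths 3, 2 and 4.
print(cu_seqlens_from_position_ids([0, 1, 2, 0, 1, 0, 1, 2, 3]))  # [0, 3, 5, 9]
```

Adjacent differences of cu_seqlens recover the per-sample lengths (here 3, 2, 4), which is the form varlen flash attention kernels expect.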