monkeypatch.ring_attn.patch
Ring attention group registration and flash attention patching.
Makes use of the ring-flash-attn (https://github.com/zhuzilin/ring-flash-attention)
package, specifically the hf_adapter.substitute_hf_flash_attn function, to patch in
its sequence-parallel version of Flash Attention 2.
We also provide some patches for accelerate functions to prepare the dataloader for sequence-parallel training.
Functions
| Name | Description |
|---|---|
| get_ring_attn_group | Getter for ring attention group on this rank. |
| register_ring_attn_from_device_mesh | Create ring attention group using DeviceMesh and substitute flash attn with ring flash attn. |
| set_ring_attn_group | Setter for ring attention group on this rank. |
| update_ring_attn_params | Calculate the cumulative sequence lengths for the current forward pass and pass the value to the substituted ring_flash_attn. |
get_ring_attn_group
monkeypatch.ring_attn.patch.get_ring_attn_group()
Getter for ring attention group on this rank.
register_ring_attn_from_device_mesh
monkeypatch.ring_attn.patch.register_ring_attn_from_device_mesh(
device_mesh,
context_parallel_dim,
heads_k_stride,
ring_attn_func,
)
Create ring attention group using DeviceMesh and substitute flash attn with ring flash attn.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| device_mesh | DeviceMesh | DeviceMesh object containing the parallelism topology. | required |
| context_parallel_dim | tuple[str, …] | Name of the sequence parallel dimension in the device mesh. | required |
| heads_k_stride | int \| None | Sequence parallelism K head stride size. Passed through to the varlen_llama3 ring_flash_attn implementation. | required |
| ring_attn_func | RingAttnFunc \| None | ring_flash_attn ring attention implementation. If sample packing is enabled, it must be a varlen function; otherwise, it must be a batch function. | required |
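The constraint on ring_attn_func (varlen with sample packing, batch without) can be illustrated with a small check. This is only a sketch: the `RingAttnKind` enum, its members, and `check_ring_attn_func` are hypothetical names introduced for illustration and are not part of this module's API.

```python
from enum import Enum


class RingAttnKind(Enum):
    """Hypothetical stand-in for the two kinds of ring attention function."""

    VARLEN = "varlen"  # works on packed, variable-length sequences
    BATCH = "batch"  # works on regular, non-packed batches


def check_ring_attn_func(kind: RingAttnKind, sample_packing: bool) -> None:
    """Raise ValueError if the function kind does not match the packing mode."""
    expected = RingAttnKind.VARLEN if sample_packing else RingAttnKind.BATCH
    if kind is not expected:
        raise ValueError(
            f"sample_packing={sample_packing} requires a {expected.value} "
            f"function, got {kind.value}"
        )
```

Running such a check before registration would surface a mismatched configuration early, rather than at the first forward pass.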
set_ring_attn_group
monkeypatch.ring_attn.patch.set_ring_attn_group(ring_attn_group)
Setter for ring attention group on this rank.
update_ring_attn_params
monkeypatch.ring_attn.patch.update_ring_attn_params(position_ids)
Calculate the cumulative sequence lengths for the current forward pass and pass the value to the substituted ring_flash_attn.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| position_ids | torch.Tensor \| None | Optional tensor of position IDs (for sample packed data). | required |
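To show what "cumulative sequence lengths" means for sample-packed data, here is a minimal sketch in plain Python. It assumes a single packed row given as a list of ints, in which each packed sample restarts its position IDs at 0. The helper name `cu_seqlens_from_position_ids` is hypothetical; the actual function takes a torch.Tensor and may differ in detail.

```python
def cu_seqlens_from_position_ids(position_ids: list[int]) -> list[int]:
    """Return cumulative sequence lengths for one packed row of position IDs.

    Assumption: every packed sample begins with position ID 0.
    """
    # Each index where the position ID resets to 0 marks a sample boundary.
    boundaries = [i for i, pos in enumerate(position_ids) if pos == 0]
    # Guarantee a leading 0 even if the row does not start at position 0.
    if not boundaries or boundaries[0] != 0:
        boundaries.insert(0, 0)
    # The final entry is the total number of tokens in the row.
    return boundaries + [len(position_ids)]


# Three packed samples of lengths 3, 2 and 4.
print(cu_seqlens_from_position_ids([0, 1, 2, 0, 1, 0, 1, 2, 3]))  # [0, 3, 5, 9]
```

Consecutive entries of the result delimit each sample, so a varlen attention kernel can keep attention from crossing sample boundaries.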