utils.ctx_managers.sequence_parallel
Module for Axolotl trainer sequence parallelism manager and utilities
Classes
| Name | Description |
|---|---|
| AllGatherWithGrad | Custom autograd function for all-gather to preserve gradients. |
| SequenceParallelContextManager | Context manager for sequence parallelism operations. |
AllGatherWithGrad
utils.ctx_managers.sequence_parallel.AllGatherWithGrad()
Custom autograd function for all-gather to preserve gradients.
Methods
| Name | Description |
|---|---|
| backward | Backward pass for all-gather operation. |
| forward | Forward pass of all-gather of data along the sequence dimension. |
backward
utils.ctx_managers.sequence_parallel.AllGatherWithGrad.backward(ctx, grad_output)
Backward pass for all-gather operation.
Extracts the gradient slice corresponding to this rank’s original input from the full gradient tensor.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| ctx | torch.autograd.function.FunctionCtx | torch.autograd function context. | required |
| grad_output | torch.Tensor | Gradient from subsequent layers with respect to the concatenated output tensor. | required |
Returns
| Name | Type | Description |
|---|---|---|
| | tuple[torch.Tensor, None] | Tuple containing the gradient slice for this rank’s input tensor and None for the process group parameter, which doesn’t require gradients. |
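A minimal sketch of the slicing that backward performs, assuming each rank contributed an equal-length chunk and that the sequence dimension is dim 1. The helper name and the local_seq_len argument are illustrative, not part of the library API.

```python
import torch
import torch.distributed as dist


def backward_slice(
    grad_output: torch.Tensor, group: dist.ProcessGroup, local_seq_len: int
) -> torch.Tensor:
    # grad_output covers the full gathered sequence; keep only the slice that
    # corresponds to the chunk this rank contributed during the forward pass.
    rank = dist.get_rank(group)
    start = rank * local_seq_len
    return grad_output[:, start : start + local_seq_len].contiguous()
```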
forward
utils.ctx_managers.sequence_parallel.AllGatherWithGrad.forward(ctx, input_tensor, group)
Forward pass of all-gather of data along the sequence dimension.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| ctx | torch.autograd.function.FunctionCtx | torch.autograd function context. | required |
| input_tensor | torch.Tensor | Tensor from model output with sequence dimension. | required |
| group | dist.ProcessGroup | torch.distributed process group. | required |
Returns
| Name | Type | Description |
|---|---|---|
| | torch.Tensor | Tensor from gathering the input_tensor from across the process group and concatenating along the sequence dimension. |
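Since AllGatherWithGrad is a torch.autograd.Function, it is invoked through its .apply classmethod. A hedged usage sketch; the process group setup and the local_output variable below are illustrative, not taken from the trainer code.

```python
import torch.distributed as dist

# Illustrative sequence-parallel group; in practice the ranks come from the
# trainer's sequence parallelism configuration.
sp_group = dist.new_group(ranks=[0, 1, 2, 3])

# local_output: this rank's model output with a sequence dimension.
gathered = AllGatherWithGrad.apply(local_output, sp_group)
# gathered now holds the full sequence concatenated along the sequence
# dimension, and gradients through it flow back to each rank's local slice.
```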
SequenceParallelContextManager
utils.ctx_managers.sequence_parallel.SequenceParallelContextManager(
    models,
    sequence_parallel_degree,
    gradient_accumulation_steps,
    ring_attn_func,
    heads_k_stride,
)
Context manager for sequence parallelism operations.
This class provides a context that will automatically apply sequence parallelism during model forward passes using a pre-forward hook, and gather outputs from across the sequence parallelism group using a post-forward hook.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| models | list[nn.Module] | List of models to attach sequence parallelism pre- and post-forward hooks to. | required |
| sequence_parallel_degree | int | Number of processes to split sequences over. | required |
| gradient_accumulation_steps | int | Number of steps to accumulate gradients over. | required |
| ring_attn_func | RingAttnFunc | Which ring attention function to use. Currently unused. | required |
| heads_k_stride | int \| None | Sequence parallelism K head stride size. Passed through to the varlen_llama3 ring_flash_attn implementation. | required |
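A hedged usage sketch of the context manager. The import path (axolotl.utils.ctx_managers.sequence_parallel) and the surrounding variables (model, batch, ring_attn_func, the degree of 4) are assumptions for illustration, not taken verbatim from the trainer code.

```python
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager

context = SequenceParallelContextManager(
    models=[model],
    sequence_parallel_degree=4,
    gradient_accumulation_steps=1,
    ring_attn_func=ring_attn_func,  # currently unused by the manager
    heads_k_stride=None,
)

with context:
    # Pre-forward hooks shard each batch along the sequence dimension across
    # the sequence parallel group; post-forward hooks gather outputs back.
    outputs = model(**batch)
```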
Functions
| Name | Description |
|---|---|
| apply_sequence_parallelism | Apply sequence parallelism slicing to a batch. |
apply_sequence_parallelism
utils.ctx_managers.sequence_parallel.apply_sequence_parallelism(
    batch,
    local_rank,
    local_world_size,
    gradient_accumulation_steps,
    ring_attn_func,
)
Apply sequence parallelism slicing to a batch.
Special handling is implemented for an integer logits_to_keep, which indicates that only the last N tokens in the sequence should be kept during generation.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| batch | dict[str, torch.Tensor] | Batch dictionary (e.g., input_ids, attention_mask, etc.). | required |
| local_rank | int | Local rank in the sequence parallel group. | required |
| local_world_size | int | World size of the sequence parallel group. | required |
| gradient_accumulation_steps | int | Number of steps to accumulate gradients over. | required |
| ring_attn_func | RingAttnFunc | Which ring attention function to use. Currently unused; related to a TODO in the source. | required |
Returns
| Name | Type | Description |
|---|---|---|
| | tuple[dict[str, torch.Tensor], int, int] | Tuple of the batch dictionary with sliced tensors, the original sequence length before padding, and the number of padding tokens added. |
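An illustrative sketch of the even-chunk slicing this function performs on sequence-dimension tensors, assuming the sequence dimension is dim 1. The real implementation also pads sequences so they divide evenly across ranks, returns the original length and pad count, and special-cases integer logits_to_keep; all of that is omitted here, and the helper name is hypothetical.

```python
import torch


def slice_batch_for_rank(
    batch: dict[str, torch.Tensor], local_rank: int, local_world_size: int
) -> dict[str, torch.Tensor]:
    sliced = {}
    for name, tensor in batch.items():
        if tensor.dim() >= 2:  # tensors with a sequence dimension (batch, seq, ...)
            chunk = tensor.size(1) // local_world_size
            start = local_rank * chunk
            sliced[name] = tensor[:, start : start + chunk]
        else:
            sliced[name] = tensor  # scalar / per-sample values are left intact
    return sliced
```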