utils.ctx_managers.sequence_parallel

Module for Axolotl trainer sequence parallelism manager and utilities

Classes

| Name | Description |
|------|-------------|
| AllGatherWithGrad | Custom autograd function for all-gather to preserve gradients. |
| SequenceParallelContextManager | Context manager for sequence parallelism operations. |

AllGatherWithGrad

utils.ctx_managers.sequence_parallel.AllGatherWithGrad()

Custom autograd function for all-gather to preserve gradients.

Methods

| Name | Description |
|------|-------------|
| backward | Backward pass for all-gather operation. |
| forward | Forward pass of all-gather of data with sequence dimension. |

backward
utils.ctx_managers.sequence_parallel.AllGatherWithGrad.backward(
    ctx,
    grad_output,
)

Backward pass for all-gather operation.

Extracts the gradient slice corresponding to this rank’s original input from the full gradient tensor.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| ctx | torch.autograd.function.FunctionCtx | torch.autograd function context. | required |
| grad_output | torch.Tensor | Gradient from subsequent layers with respect to the concatenated output tensor. | required |

Returns

| Name | Type | Description |
|------|------|-------------|
|  | tuple[torch.Tensor, None] | Tuple containing the gradient slice for this rank's input tensor and None for the process group parameter, which doesn't require gradients. |

forward
utils.ctx_managers.sequence_parallel.AllGatherWithGrad.forward(
    ctx,
    input_tensor,
    group,
)

Forward pass of all-gather of data with sequence dimension.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| ctx | torch.autograd.function.FunctionCtx | torch.autograd function context. | required |
| input_tensor | torch.Tensor | Tensor from model output with sequence dimension. | required |
| group | dist.ProcessGroup | torch.distributed process group. | required |

Returns

| Name | Type | Description |
|------|------|-------------|
|  | torch.Tensor | Tensor from gathering the input_tensor from across the process group and concatenating along the sequence dimension. |
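
For orientation, the following is a minimal, self-contained sketch of how an autograd-aware all-gather of this shape can be written. It is not the Axolotl implementation: it assumes a [batch, seq, ...] tensor layout, equal sequence shards on every rank, and concatenation along dim 1.

```python
import torch
import torch.distributed as dist


class AllGatherWithGradSketch(torch.autograd.Function):
    """Illustrative stand-in for AllGatherWithGrad, not the actual implementation."""

    @staticmethod
    def forward(ctx, input_tensor: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
        # Remember which slice of the gathered tensor belongs to this rank.
        ctx.rank = dist.get_rank(group)
        ctx.seq_len = input_tensor.shape[1]  # assumes [batch, seq, ...] layout

        # Gather every rank's shard and concatenate along the sequence dimension.
        world_size = dist.get_world_size(group)
        gathered = [torch.empty_like(input_tensor) for _ in range(world_size)]
        dist.all_gather(gathered, input_tensor.contiguous(), group=group)
        return torch.cat(gathered, dim=1)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Return only the gradient slice for this rank's original shard;
        # the process group argument receives no gradient, hence the trailing None.
        start = ctx.rank * ctx.seq_len
        grad_slice = grad_output[:, start : start + ctx.seq_len].contiguous()
        return grad_slice, None
```

With such a function, `AllGatherWithGradSketch.apply(tensor, group)` returns the full-sequence tensor while keeping the autograd graph connected to this rank's shard, which is the behavior the forward and backward descriptions above document.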

SequenceParallelContextManager

utils.ctx_managers.sequence_parallel.SequenceParallelContextManager(
    models,
    sequence_parallel_degree,
    gradient_accumulation_steps,
    ring_attn_func,
    heads_k_stride,
)

Context manager for sequence parallelism operations.

This class provides a context that will automatically apply sequence parallelism during model forward passes using a pre-forward hook, and gather outputs from across the sequence parallelism group using a post-forward hook.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| models | list[nn.Module] | List of models to attach the sequence parallelism pre- and post-forward hooks to. | required |
| sequence_parallel_degree | int | Number of processes to split sequences over. | required |
| gradient_accumulation_steps | int | Number of steps to accumulate gradients over. | required |
| ring_attn_func | RingAttnFunc | Which ring attention function to use. Currently unused. | required |
| heads_k_stride | int \| None | Sequence parallelism K head stride size. Passed through to the varlen_llama3 ring_flash_attn implementation. | required |
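
The hook mechanism described above can be pictured with the sketch below. This is an illustration, not the Axolotl class: it assumes each batch tensor carries its sequence on dim 1, that the model returns a plain tensor, and it uses a plain (non-differentiable) all_gather where the real manager uses AllGatherWithGrad so gradients flow back to each rank's shard.

```python
import torch
import torch.distributed as dist
import torch.nn as nn


class SequenceParallelContextSketch:
    """Illustrative hook wiring, not the Axolotl SequenceParallelContextManager."""

    def __init__(self, models: list[nn.Module], group: dist.ProcessGroup):
        self.models = models
        self.group = group
        self.handles = []

    def _pre_forward(self, module, args, kwargs):
        # Keep only this rank's slice of the sequence dimension (assumes equal shards).
        rank = dist.get_rank(self.group)
        world_size = dist.get_world_size(self.group)
        for key in ("input_ids", "attention_mask", "position_ids", "labels"):
            if key in kwargs and kwargs[key] is not None:
                kwargs[key] = kwargs[key].chunk(world_size, dim=1)[rank]
        return args, kwargs

    def _post_forward(self, module, args, output):
        # Reassemble the full-sequence output (assumes the model returns a plain tensor).
        # The real manager gathers with an autograd-aware all-gather (AllGatherWithGrad);
        # plain all_gather does not preserve gradients.
        world_size = dist.get_world_size(self.group)
        gathered = [torch.empty_like(output) for _ in range(world_size)]
        dist.all_gather(gathered, output.contiguous(), group=self.group)
        return torch.cat(gathered, dim=1)

    def __enter__(self):
        for model in self.models:
            self.handles.append(
                model.register_forward_pre_hook(self._pre_forward, with_kwargs=True)
            )
            self.handles.append(model.register_forward_hook(self._post_forward))
        return self

    def __exit__(self, *exc):
        # Detach the hooks so the models behave normally outside the context.
        for handle in self.handles:
            handle.remove()
        self.handles.clear()
```

Registering the hooks in `__enter__` and removing the handles in `__exit__` keeps the models untouched outside the context, which matches the context-manager framing of this class.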

Functions

| Name | Description |
|------|-------------|
| apply_sequence_parallelism | Apply sequence parallelism slicing to a batch. |

apply_sequence_parallelism

utils.ctx_managers.sequence_parallel.apply_sequence_parallelism(
    batch,
    local_rank,
    local_world_size,
    gradient_accumulation_steps,
    ring_attn_func,
)

Apply sequence parallelism slicing to a batch.

Special handling is implemented for an integer logits_to_keep, which indicates that only the last N tokens of the sequence should be kept during generation.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| batch | dict[str, torch.Tensor] | Batch dictionary (e.g., input_ids, attention_mask, etc.). | required |
| local_rank | int | Local rank in the sequence parallel group. | required |
| local_world_size | int | World size of the sequence parallel group. | required |
| gradient_accumulation_steps | int | Number of steps to accumulate gradients over. | required |
| ring_attn_func | RingAttnFunc | Which ring attention function to use. Currently unused, but related to the above TODO. | required |

Returns

| Name | Type | Description |
|------|------|-------------|
|  | tuple[dict[str, torch.Tensor], int, int] | Tuple of: the batch dictionary with sliced tensors, the original sequence length before padding, and the number of padding tokens added. |
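
To make the slicing concrete, here is a simplified, self-contained sketch of what per-rank sequence slicing can look like. It is not the Axolotl function: it assumes 2-D [batch, seq] tensors, ignores gradient_accumulation_steps, ring_attn_func, and logits_to_keep, and pads every 2-D tensor with zeros (real label tensors would normally be padded with an ignore index).

```python
import torch
import torch.nn.functional as F


def apply_sequence_parallelism_sketch(
    batch: dict[str, torch.Tensor],
    local_rank: int,
    local_world_size: int,
) -> tuple[dict[str, torch.Tensor], int, int]:
    """Simplified per-rank slicing along the sequence dimension (dim 1)."""
    original_seq_len = batch["input_ids"].shape[1]

    # Pad on the right so the sequence length divides evenly across ranks.
    pad_len = (-original_seq_len) % local_world_size
    if pad_len:
        batch = {
            # Zero-padding is a simplification; labels would normally use an ignore index.
            key: F.pad(value, (0, pad_len)) if value.dim() == 2 else value
            for key, value in batch.items()
        }

    # Each rank keeps one contiguous chunk of the padded sequence.
    chunk = (original_seq_len + pad_len) // local_world_size
    start, end = local_rank * chunk, (local_rank + 1) * chunk
    sliced = {
        key: value[:, start:end] if value.dim() >= 2 else value
        for key, value in batch.items()
    }
    return sliced, original_seq_len, pad_len
```

Called on each rank of the sequence-parallel group, this returns per-rank shards whose concatenation along dim 1 reconstructs the padded batch, along with the original sequence length and the number of padding tokens, mirroring the return tuple documented above.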