Sequence Parallelism

Train with long sequences split across multiple GPUs.

Sequence parallelism is a technique that splits sequences across multiple GPUs, allowing you to train with very long sequences that wouldn’t fit on a single GPU. Each GPU processes a different portion of the sequence, and the results are aggregated through a ring communication pattern.

When to Use Sequence Parallelism

Use sequence parallelism when:

  • You need to train with sequence lengths that don’t fit into a single GPU’s memory
  • You have multiple GPUs available
  • You’re experiencing OOM (Out Of Memory) errors with long sequences

Configuration

To enable sequence parallelism, add the following to your configuration file:

# Set to a divisor (> 1) of the number of GPUs available
sequence_parallel_degree: 4  # Split sequences across 4 GPUs
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1

The sequence_parallel_degree should be a divisor of the total number of GPUs. For example:

  • With 8 GPUs, valid values would be 2, 4, or 8
  • With 4 GPUs, valid values would be 2 or 4
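
The divisor rule above can be checked with a few lines of Python. This is only an illustrative sketch: `valid_sp_degrees` is a hypothetical helper, not part of Axolotl, and it assumes nothing beyond the rule stated here (a divisor of the GPU count that is greater than 1).

```python
def valid_sp_degrees(num_gpus: int) -> list[int]:
    """Return every divisor of num_gpus that is greater than 1."""
    if num_gpus < 2:
        # Sequence parallelism needs at least 2 GPUs.
        return []
    return [d for d in range(2, num_gpus + 1) if num_gpus % d == 0]


print(valid_sp_degrees(8))  # [2, 4, 8]
print(valid_sp_degrees(4))  # [2, 4]
```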

Implementation Details

When sequence parallelism is enabled:

  1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
  2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids
  3. Position IDs are adjusted to maintain proper relative positions, especially for packed sequences
  4. The trainer uses special ring communication patterns for attention operations
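
To make the chunking step concrete, here is a minimal sketch of how one sequence's fields could be sliced for a given rank. It is not Axolotl's collator: `shard_batch` is a hypothetical helper, and it assumes contiguous, equal-sized chunks and a sequence length that is already divisible by the degree. The real implementation may pad or lay out chunks differently.

```python
def shard_batch(
    batch: dict[str, list[int]], sp_degree: int, rank: int
) -> dict[str, list[int]]:
    """Return the slice of every field that `rank` would hold.

    Assumption: contiguous, equal-sized chunks (illustration only).
    """
    if not 0 <= rank < sp_degree:
        raise ValueError(f"rank {rank} is outside the group of size {sp_degree}")
    shard = {}
    for name, values in batch.items():
        if len(values) % sp_degree != 0:
            raise ValueError(
                f"{name}: length {len(values)} is not divisible by {sp_degree}"
            )
        size = len(values) // sp_degree
        shard[name] = values[rank * size:(rank + 1) * size]
    return shard


# One sequence of 8 tokens with the four fields the collator chunks.
batch = {
    "input_ids": [101, 7, 8, 9, 10, 11, 12, 102],
    "attention_mask": [1] * 8,
    "labels": [-100, 7, 8, 9, 10, 11, 12, 102],
    "position_ids": list(range(8)),
}

for rank in range(4):
    # Each rank keeps the original positions of its own tokens.
    print(rank, shard_batch(batch, sp_degree=4, rank=rank)["position_ids"])
```

Note that each rank's `position_ids` are the original positions of its tokens, not a range restarting at 0, which is what keeps relative positions correct across the group.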

Requirements

To use sequence parallelism, you need:

  • Multiple GPUs (at least 2)
  • The ring-flash-attn package. Install with:
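    • pip install "axolotl[ring-flash-attn]" (preferred)
    • pip install "ring-flash-attn>=0.1.4"
  • Quote the package specifier as shown; otherwise most shells treat >= as an output redirection, and some try to expand the square brackets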

Limitations

  • Flash attention must be enabled for this to work (flash_attention: true in config YAML)
  • Communication between GPUs in a sequence parallel group may add a small performance overhead
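
The requirements and limitations above can be expressed as a small pre-flight check. This is a sketch under stated assumptions: `check_sp_config` is a hypothetical helper, not something Axolotl provides, and it treats the config as a plain dictionary using the `sequence_parallel_degree` and `flash_attention` keys described on this page.

```python
def check_sp_config(cfg: dict, num_gpus: int) -> list[str]:
    """Return human-readable problems; an empty list means the config looks fine."""
    problems = []
    degree = cfg.get("sequence_parallel_degree", 1)
    if degree <= 1:
        # Sequence parallelism is off, so nothing to check.
        return problems
    if num_gpus < 2:
        problems.append("sequence parallelism needs at least 2 GPUs")
    if num_gpus % degree != 0:
        problems.append(
            f"sequence_parallel_degree={degree} does not divide {num_gpus} GPUs"
        )
    if not cfg.get("flash_attention", False):
        problems.append("flash_attention: true is required")
    return problems


cfg = {"sequence_parallel_degree": 4, "flash_attention": True}
print(check_sp_config(cfg, num_gpus=8))  # []
```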

Example

base_model: meta-llama/Llama-3-8B-Instruct
sequence_len: 8192

...

sequence_parallel_degree: 4  # Split each sequence into 4 parts, one per GPU
flash_attention: true  # Required with sequence parallelism
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1

...

This trains the Llama 3 8B model with an 8K context length. Each sequence is split into 4 subsequences of 2048 tokens, one per GPU in a sequence parallel group of 4.

Sample Packing with Sequence Parallelism

Sequence parallelism is compatible with Axolotl’s sample packing functionality. When using both features together:

  1. Samples are first packed together
  2. The packed sequences are then divided across GPUs in the sequence parallel group
  3. Position IDs are automatically adjusted to maintain proper relative positions
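
The three steps above can be illustrated with plain Python lists. This is a simplified sketch, not Axolotl's packing code: the helper names are hypothetical, position IDs are assumed to restart at 0 for each packed sample, and the split is assumed to be contiguous and even.

```python
def packed_position_ids(sample_lengths: list[int]) -> list[int]:
    """Position IDs for a packed sequence, restarting at 0 for each sample."""
    ids: list[int] = []
    for length in sample_lengths:
        ids.extend(range(length))
    return ids


def split_across_group(values: list[int], sp_degree: int) -> list[list[int]]:
    """Split a packed sequence into sp_degree contiguous, equal chunks."""
    if len(values) % sp_degree != 0:
        raise ValueError("packed length must be divisible by sp_degree")
    size = len(values) // sp_degree
    return [values[i * size:(i + 1) * size] for i in range(sp_degree)]


# Step 1: pack a 3-token sample and a 5-token sample into one sequence.
position_ids = packed_position_ids([3, 5])
print(position_ids)  # [0, 1, 2, 0, 1, 2, 3, 4]

# Steps 2 and 3: divide across 2 GPUs; each chunk keeps its original positions.
print(split_across_group(position_ids, sp_degree=2))
# [[0, 1, 2, 0], [1, 2, 3, 4]]
```

The second chunk begins at position 1 because the 5-token sample straddles the split point. Keeping those original values, rather than renumbering each chunk from 0, is what preserves relative positions.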

Effect on Batch Size

When using sequence parallelism, your effective global batch size is divided by the sequence_parallel_degree. This happens because:

  • Each group of sequence_parallel_degree GPUs works on the same batch (just different parts of each sequence)
  • The number of batches processed per step decreases

For example:

  • With 8 GPUs and no sequence parallelism: 8 different batches processed per step
  • With 8 GPUs and sequence_parallel_degree=4: only 2 different batches processed per step (each split across 4 GPUs)
  • If your per-GPU micro_batch_size is 2, the global batch size decreases from 16 to 4
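
The arithmetic in this example can be reproduced with a short function. It is a sketch only: `global_batch_size` is a hypothetical helper that ignores gradient accumulation and simply counts how many independent batches the GPUs process per step.

```python
def global_batch_size(num_gpus: int, micro_batch_size: int, sp_degree: int = 1) -> int:
    """Sequences processed per step across all GPUs (no gradient accumulation)."""
    if sp_degree < 1 or num_gpus % sp_degree != 0:
        raise ValueError("sp_degree must be a positive divisor of num_gpus")
    # GPUs in the same sequence parallel group share one batch,
    # so only num_gpus // sp_degree groups see distinct data.
    independent_groups = num_gpus // sp_degree
    return independent_groups * micro_batch_size


print(global_batch_size(8, micro_batch_size=2))               # 16
print(global_batch_size(8, micro_batch_size=2, sp_degree=4))  # 4
```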