Sequence Parallelism
Sequence parallelism is a technique that splits sequences across multiple GPUs, allowing you to train with very long sequences that wouldn’t fit on a single GPU. Each GPU processes a different portion of the sequence, and the results are aggregated through a ring communication pattern.
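As a rough illustration of the splitting step (not Axolotl's actual code), the sketch below chunks one long sequence into equal contiguous slices, one per GPU in a hypothetical group of 4:

```python
# Minimal sketch, assuming a sequence-parallel group of 4 GPUs and an 8K sequence.
import torch

sequence_parallel_degree = 4          # hypothetical group size
input_ids = torch.arange(8192)        # one 8K-token sequence

# Each rank keeps only its contiguous slice of the sequence.
chunks = input_ids.chunk(sequence_parallel_degree, dim=0)
for rank, chunk in enumerate(chunks):
    print(rank, chunk.shape)          # every rank sees 2048 tokens
```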
When to Use Sequence Parallelism
Use sequence parallelism when:
- You need to train with sequence lengths that don’t fit into a single GPU’s memory
- You have multiple GPUs available
- You’re experiencing OOM (Out Of Memory) errors with long sequences
Configuration
To enable sequence parallelism, add the following to your configuration file:
```yaml
# Set to a divisor (> 1) of the number of GPUs available
sequence_parallel_degree: 4  # Split sequences across 4 GPUs
# Optional; strides across the key dimension.
# Larger values use more memory, but should make training faster.
heads_k_stride: 1
```
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
- With 8 GPUs, valid values would be 2, 4, or 8
- With 4 GPUs, valid values would be 2 or 4
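As a sketch of this rule only (Axolotl performs its own validation internally), a hypothetical divisibility check might look like:

```python
# Hypothetical helper for illustrating the rule; not part of Axolotl's API.
def validate_sp_degree(world_size: int, sequence_parallel_degree: int) -> None:
    if sequence_parallel_degree <= 1:
        raise ValueError("sequence_parallel_degree must be > 1 to have any effect")
    if world_size % sequence_parallel_degree != 0:
        raise ValueError(
            f"sequence_parallel_degree={sequence_parallel_degree} must divide "
            f"the number of GPUs ({world_size})"
        )

validate_sp_degree(world_size=8, sequence_parallel_degree=4)    # OK
# validate_sp_degree(world_size=8, sequence_parallel_degree=3)  # raises ValueError
```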
Implementation Details
When sequence parallelism is enabled:
- Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
- The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids
- Position IDs are adjusted to maintain proper relative positions, especially for packed sequences
- The trainer uses special ring communication patterns for attention operations
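The following is a simplified sketch of the chunking idea described above, not Axolotl's actual collator; the `shard_batch` helper and the tensor shapes are illustrative assumptions. Note that each rank keeps the original position values for its slice, which is how relative positions are preserved:

```python
import torch

def shard_batch(batch, sp_rank, sp_degree):
    """Keep only this rank's slice of every per-token tensor (illustrative only)."""
    sharded = {}
    for key in ("input_ids", "attention_mask", "labels", "position_ids"):
        # (batch, seq_len) -> (batch, seq_len // sp_degree)
        sharded[key] = batch[key].chunk(sp_degree, dim=1)[sp_rank]
    return sharded

batch = {
    "input_ids": torch.randint(0, 32000, (1, 8192)),
    "attention_mask": torch.ones(1, 8192, dtype=torch.long),
    "labels": torch.randint(0, 32000, (1, 8192)),
    "position_ids": torch.arange(8192).unsqueeze(0),
}

local = shard_batch(batch, sp_rank=1, sp_degree=4)
print(local["position_ids"][0, :3])   # tensor([2048, 2049, 2050]) -- positions preserved
```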
Requirements
To use sequence parallelism, you need:
- Multiple GPUs (at least 2)
- The `ring-flash-attn` package, installed with `pip install axolotl[ring-flash-attn]` (preferred) or `pip install ring-flash-attn>=0.1.4`
Limitations
- Flash attention must be enabled for this to work (`flash_attention: true` in the config YAML)
- May have a small performance overhead due to communication between GPUs
Example
```yaml
base_model: meta-llama/Llama-3-8B-Instruct
sequence_len: 8192
...
sequence_parallel_degree: 4  # Split each sequence into 4 parts, one per GPU
flash_attention: true        # Required with sequence parallelism
# Optional; strides across the key dimension.
# Larger values use more memory, but should make training faster.
heads_k_stride: 1
...
```
This will train the Llama 3 8B model with 8K context length, with each sequence split into 4 subsequences of length 2048 across 4 GPUs.
Sample Packing with Sequence Parallelism
Sequence parallelism is compatible with Axolotl’s sample packing functionality. When using both features together:
- Samples are first packed together
- The packed sequences are then divided across GPUs in the sequence parallel group
- Position IDs are automatically adjusted to maintain proper relative positions
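A toy example of how packed position IDs survive the split (illustrative values only, not Axolotl's implementation):

```python
import torch

# Two samples of length 3 and 5 packed into one sequence of length 8;
# position_ids restart at 0 at each sample boundary.
position_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 3, 4]])

# Splitting across a sequence-parallel group of 2 keeps those restarts intact:
rank0, rank1 = position_ids.chunk(2, dim=1)
print(rank0)  # tensor([[0, 1, 2, 0]])
print(rank1)  # tensor([[1, 2, 3, 4]])
```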
Effect on Batch Size
When using sequence parallelism, your effective global batch size is divided by the `sequence_parallel_degree`. This happens because:
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
- The number of batches processed per step decreases
For example:
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
- With 8 GPUs and `sequence_parallel_degree: 4`: only 2 different batches processed per step (each split across 4 GPUs)
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
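The arithmetic above, written out (assumed values; gradient accumulation is ignored for simplicity):

```python
num_gpus = 8
sequence_parallel_degree = 4
micro_batch_size = 2                  # per-GPU micro batch size

# Without sequence parallelism, every GPU processes its own batch.
global_batch_without_sp = num_gpus * micro_batch_size            # 16

# With sequence parallelism, only one batch exists per SP group.
data_parallel_groups = num_gpus // sequence_parallel_degree      # 2
global_batch_with_sp = data_parallel_groups * micro_batch_size   # 4

print(global_batch_without_sp, global_batch_with_sp)
```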