N-D Parallelism (Beta)

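Axolotl enables training models at scale by composing different parallelism techniques. This is essential when:

  • the model's weights, gradients, and optimizer states do not fit in a single GPU's memory,
  • long sequences produce more activation memory than a single GPU can hold,
  • you want higher training throughput by spreading work across many GPUs or nodes,

or combinations of the above!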

Core Concepts

Parallelism strategies can be combined. The key is understanding how each one divides the workload. PyTorch’s DeviceMesh is the modern way to manage these combinations, creating a logical grid of your GPUs and assigning different parallel strategies to different dimensions of the grid.
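To make the "logical grid" concrete, here is a minimal pure-Python sketch of how a 4-D mesh assigns each GPU (global rank) a coordinate on every parallelism dimension. In PyTorch this grid is built by `torch.distributed.device_mesh.init_device_mesh`. The dimension order `(dp_replicate, dp_shard, cp, tp)` and the helper names below are assumptions for illustration; Axolotl may order or name the dimensions differently.

```python
from math import prod

# Assumed dimension order, outermost first. The last dimension varies fastest,
# so members of that dimension get consecutive ranks (typically the same node).
DIM_NAMES = ("dp_replicate", "dp_shard", "cp", "tp")


def rank_to_coords(rank, shape):
    """Map a global rank to its per-dimension grid coordinates (row-major)."""
    if not 0 <= rank < prod(shape):
        raise ValueError(f"rank {rank} is outside a mesh of {prod(shape)} GPUs")
    coords = []
    for size in reversed(shape):
        coords.append(rank % size)
        rank //= size
    return tuple(reversed(coords))


def coords_to_rank(coords, shape):
    """Inverse of rank_to_coords: flatten grid coordinates back to a rank."""
    rank = 0
    for c, size in zip(coords, shape):
        rank = rank * size + c
    return rank


shape = (2, 2, 1, 2)      # dp_replicate=2, dp_shard=2, cp=1, tp=2
world_size = prod(shape)  # each GPU occupies exactly one cell of the grid
for rank in range(world_size):
    coords = rank_to_coords(rank, shape)
    assert coords_to_rank(coords, shape) == rank  # round-trip is lossless
    print(rank, dict(zip(DIM_NAMES, coords)))
```

Ranks that differ only in their `tp` coordinate (for example ranks 0 and 1) form one TP group. Placing TP on the innermost, fastest-varying dimension is a common choice because it keeps those chatty groups on neighbouring GPUs.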

Data Parallelism

Data Parallelism focuses on splitting the global data batch across GPUs.

  • Distributed Data Parallel (DDP): The classic approach. The full model is replicated on every GPU. Each GPU processes a different slice of the data batch. Gradients are then averaged across all GPUs after the backward pass to keep the models synchronized. This can substantially improve data throughput compared to single-device training, but requires each GPU to hold the entire model, its gradients, and its optimizer states.

  • Fully Sharded Data Parallel (FSDP): A highly memory-efficient form of data parallelism (inspired by DeepSpeed's ZeRO). Instead of replicating the model, FSDP shards the model's parameters, gradients, and optimizer states across the GPUs in the data-parallel group. During computation, each GPU reconstructs the full parameters of a layer from all shards via an all_gather operation just before the layer runs, and can free them immediately afterward (reshard-after-forward).

    • FSDP maps to ZeRO stages:
      • ZeRO-2 (reshard_after_forward=False): Shards gradients and optimizer states. Model weights are replicated on each GPU.
      • ZeRO-3 (reshard_after_forward=True): Shards gradients, optimizer states, AND model parameters. This provides the most memory savings at the cost of more communication (re-gathering parameters for both forward and backward passes).
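To put rough numbers on these trade-offs, the sketch below estimates per-GPU memory for model states only. It assumes mixed-precision Adam with the accounting used in the ZeRO paper (2 bytes per parameter for 16-bit weights, 2 for 16-bit gradients, 12 for fp32 master weights plus Adam momentum and variance) and ignores activations, communication buffers, and fragmentation, so real usage is higher. `model_state_gib` is a hypothetical helper.

```python
# Bytes per parameter under mixed-precision Adam (ZeRO-paper accounting; an assumption).
PARAM_BYTES = 2    # 16-bit weights
GRAD_BYTES = 2     # 16-bit gradients
OPTIM_BYTES = 12   # fp32 master weights + Adam momentum + Adam variance


def model_state_gib(num_params, num_gpus, strategy):
    """Approximate per-GPU memory (GiB) for weights, gradients and optimizer state."""
    if strategy == "ddp":
        # Everything is replicated on every GPU.
        per_param = PARAM_BYTES + GRAD_BYTES + OPTIM_BYTES
    elif strategy == "zero2":
        # Weights replicated; gradients and optimizer state sharded.
        per_param = PARAM_BYTES + (GRAD_BYTES + OPTIM_BYTES) / num_gpus
    elif strategy == "zero3":
        # Weights, gradients and optimizer state all sharded.
        per_param = (PARAM_BYTES + GRAD_BYTES + OPTIM_BYTES) / num_gpus
    else:
        raise ValueError(f"unknown strategy: {strategy}")
    return num_params * per_param / 2**30


n_params = 8_000_000_000  # an 8B-parameter model, for illustration
for strategy in ("ddp", "zero2", "zero3"):
    print(strategy, round(model_state_gib(n_params, 8, strategy), 1))
```

For 8B parameters on 8 GPUs this gives roughly 119 GiB per GPU for DDP, about 28 GiB for ZeRO-2, and about 15 GiB for ZeRO-3. ZeRO-2 can never drop below the replicated 16-bit weights (about 15 GiB here), whereas ZeRO-3 keeps shrinking as GPUs are added.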

[Experimental] Tensor Parallelism (TP)

Also known as “horizontal model parallelism,” as described in the Megatron-LM paper. Instead of splitting the batch, TP splits the model’s layers themselves across GPUs.

  • How it works: For a linear layer Y = XA, the weight matrix A is split column-wise (A = [A_1, A_2]). The computation becomes Y_1 = XA_1 and Y_2 = XA_2, which can happen in parallel on different GPUs. The final output Y is simply the concatenation of Y_1 and Y_2. Check this comment for more detailed info.
  • Requirement: TP communicates inside every layer of each forward and backward pass, so it needs a very fast interconnect between GPUs (e.g., NVLink) and is typically not recommended across different nodes.
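A small numeric check of the column-wise split described above, in pure Python with nested lists standing in for tensors (no GPUs or communication involved): each "GPU" multiplies X by its own column block of A, and concatenating the partial outputs along the column dimension reproduces XA. In Megatron-LM the following linear layer is typically split row-wise so this concatenation never has to be materialized; that pairing is not shown here.

```python
def matmul(x, a):
    """Plain-Python matrix product of x (m x k) and a (k x n)."""
    return [[sum(x[i][t] * a[t][j] for t in range(len(a))) for j in range(len(a[0]))]
            for i in range(len(x))]


def split_columns(a, parts):
    """Split matrix a column-wise into `parts` equal blocks [A_1, ..., A_p]."""
    n = len(a[0])
    if n % parts:
        raise ValueError("columns must divide evenly across TP ranks")
    step = n // parts
    return [[row[p * step:(p + 1) * step] for row in a] for p in range(parts)]


def concat_columns(blocks):
    """Place row-aligned matrices side by side: [Y_1, Y_2, ...]."""
    return [sum((block[i] for block in blocks), []) for i in range(len(blocks[0]))]


X = [[1, 2], [3, 4]]        # activations (batch x hidden)
A = [[1, 0, 2, -1],         # weight (hidden x out_features)
     [0, 1, 1, 3]]

A_1, A_2 = split_columns(A, 2)  # each TP rank holds one column block
Y_1 = matmul(X, A_1)            # computed on "GPU 0"
Y_2 = matmul(X, A_2)            # computed on "GPU 1"
Y = concat_columns([Y_1, Y_2])

assert Y == matmul(X, A)        # the split computation matches the full one
```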

Context Parallelism (CP)

Context Parallelism (sometimes also called Sequence Parallelism) addresses the memory bottleneck of long sequences. The input sequence itself is split along the sequence-length dimension, and the chunks are distributed across GPUs.

  • How it works: If you have a sequence of 8192 tokens and a context_parallel_size of 4, each GPU will only handle a chunk of 2048 tokens.
  • The Challenge: Attention is not local; every token needs to "attend to" every other token (or, in causal models, every earlier token). Splitting the sequence breaks this.
  • The Solution (ring-flash-attention): GPUs are arranged in a "ring." Each GPU keeps its own queries and passes its block of keys and values (K/V) to its neighbor. With N = context_parallel_size GPUs, after N-1 steps every GPU has seen the K/V blocks from all other GPUs. At each step the partial attention result is computed with the highly optimized flash-attention kernel and merged into a running result using log-sum-exp (online softmax) rescaling, so the final output for each chunk matches attention over the full sequence.
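The ring procedure can be simulated in a single process to see why it is exact. This sketch assumes non-causal attention, plain Python lists instead of tensors, and no real communication or flash-attention kernel; `ring_attention` and `attention` are hypothetical helpers. Real implementations such as ring-flash-attention also handle causal masking. Each simulated rank keeps its query chunk, visits every K/V chunk once (its own first, then N-1 received ones), and merges partial results with a running max and denominator.

```python
import math


def attention(q, k, v):
    """Reference (non-causal) softmax attention over lists of vectors."""
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * vj[d] for w, vj in zip(weights, v)) / total
                    for d in range(len(v[0]))])
    return out


def ring_attention(q_chunks, k_chunks, v_chunks):
    """Simulate ring attention: rank r keeps its queries; K/V chunks hop around the ring."""
    n = len(q_chunks)            # plays the role of context_parallel_size
    dim = len(v_chunks[0][0])
    outputs = []
    for r in range(n):
        q = q_chunks[r]
        run_max = [-math.inf] * len(q)    # running max score per query
        denom = [0.0] * len(q)            # running softmax denominator
        numer = [[0.0] * dim for _ in q]  # running weighted sum of values
        # Step 0 uses the local K/V chunk; steps 1..n-1 use chunks passed along
        # the ring, so after n-1 hops every chunk has been seen exactly once.
        for step in range(n):
            src = (r - step) % n
            for i, qi in enumerate(q):
                for kj, vj in zip(k_chunks[src], v_chunks[src]):
                    s = sum(a * b for a, b in zip(qi, kj))
                    new_max = max(run_max[i], s)
                    rescale = math.exp(run_max[i] - new_max)  # shrink old sums
                    w = math.exp(s - new_max)
                    denom[i] = denom[i] * rescale + w
                    numer[i] = [acc * rescale + w * x for acc, x in zip(numer[i], vj)]
                    run_max[i] = new_max
        outputs.append([[acc / denom[i] for acc in numer[i]] for i in range(len(q))])
    return outputs


# 8 toy tokens with 3-dimensional q/k/v, split across 4 simulated GPUs.
seq_len, cp, d = 8, 4, 3
Q = [[math.sin(t + j) for j in range(d)] for t in range(seq_len)]
K = [[math.cos(0.5 * t + j) for j in range(d)] for t in range(seq_len)]
V = [[0.1 * t + j for j in range(d)] for t in range(seq_len)]
chunk = seq_len // cp  # 2 tokens per GPU, like 8192 / 4 = 2048 in the text


def split(x):
    return [x[i * chunk:(i + 1) * chunk] for i in range(cp)]


ring_out = [row for rank_out in ring_attention(split(Q), split(K), split(V))
            for row in rank_out]
reference = attention(Q, K, V)
assert all(abs(a - b) < 1e-9 for ra, rb in zip(ring_out, reference) for a, b in zip(ra, rb))
```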

Hybrid Sharding Data Parallel (HSDP)

HSDP is a 2D strategy that combines FSDP and DDP, typically for multi-node training.

  • Intra-Node (within a machine): Use FSDP. GPUs on the same node share fast interconnects (e.g., NVLink), so the frequent all_gather operations for sharded parameters stay fast.
  • Inter-Node (across machines): Use DDP. DDP only all-reduces gradients once per backward pass, whereas FSDP gathers parameters layer by layer in both the forward and backward passes, so DDP is a better fit for the slower node-to-node network (e.g., Ethernet or InfiniBand).
  • Example: With 2 nodes of 8 GPUs each (16 total), you could have dp_shard_size=8 (FSDP within each node) and dp_replicate_size=2 (DDP across the two nodes).
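The 16-GPU example can be written out as communication groups. This sketch assumes global ranks are numbered node by node (node 0 holds ranks 0 to 7, node 1 holds ranks 8 to 15), which is common with torchrun but not guaranteed; `hsdp_groups` is a hypothetical helper, not an Axolotl function.

```python
def hsdp_groups(num_nodes, gpus_per_node):
    """Return (shard_groups, replicate_groups) for an HSDP layout.

    Assumes global rank = node_index * gpus_per_node + local_rank, with
    dp_replicate_size = num_nodes and dp_shard_size = gpus_per_node.
    """
    # FSDP all_gather / reduce_scatter stay inside one node (fast NVLink).
    shard_groups = [
        [node * gpus_per_node + local for local in range(gpus_per_node)]
        for node in range(num_nodes)
    ]
    # DDP gradient all-reduce links the GPUs holding the same shard on every node.
    replicate_groups = [
        [node * gpus_per_node + local for node in range(num_nodes)]
        for local in range(gpus_per_node)
    ]
    return shard_groups, replicate_groups


shard, replicate = hsdp_groups(num_nodes=2, gpus_per_node=8)
print(len(shard), len(replicate))  # 2 8
print(shard[1])                    # [8, 9, 10, 11, 12, 13, 14, 15]
print(replicate[0])                # [0, 8]
```

Only the 8 replicate groups, each of size 2, cross the node boundary, which is what keeps the slow inter-node link limited to gradient synchronization.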

Usage

```yaml
# FSDP config. See https://docs.axolotl.ai/docs/multi-gpu.html#sec-fsdp
fsdp_version: 2
fsdp_config:
  # ...

# The number of GPUs to shard the model parameters across (FSDP dimension).
dp_shard_size: 4

# The number of times to replicate the sharded model (DDP dimension).
dp_replicate_size: 2

# Number of GPUs for Tensor Parallelism.
tensor_parallel_size: 1  # (default is 1, no TP)

# Number of GPUs for Context/Sequence Parallelism.
context_parallel_size: 1  # (default is 1, no CP)
```

Note: We recommend FSDP. With DeepSpeed, tensor_parallel_size is the only one of these options that is supported.

Examples

Tip

See our example configs here.

  1. HSDP on 2 nodes with 4 GPUs each (8 GPUs total):
    • You want FSDP within each node and DDP across nodes.
    • Set dp_shard_size: 4 and dp_replicate_size: 2.
  2. FSDP + TP on a single 8-GPU node:
    • You want to split the model across 4 GPUs using FSDP, and further split each layer across 2 GPUs with TP.
    • Set dp_shard_size: 4 and tensor_parallel_size: 2.
  3. FSDP + CP on a single 8-GPU node for long context:
    • You want to shard the model across all 8 GPUs and also split the sequence length across all 8 GPUs.
    • Set dp_shard_size: 8 and context_parallel_size: 8. Note: this means the data parallel group and context parallel group are the same. A more common setup might be to shard across a smaller group.

Support Matrix

This matrix describes how different parallelism methods can be combined in Axolotl.

| Combination | dp_replicate_size | dp_shard_size | tp_size | cp_size | Status & Notes |
|---|---|---|---|---|---|
| FSDP (ZeRO-3) | 1 | >1 | 1 | 1 | ✅ Fully supported. Shards model across all GPUs. |
| HSDP | >1 | >1 | 1 | 1 | ✅ Fully supported. FSDP intra-node, DDP inter-node. |
| FSDP + TP | 1 | >1 | >1 | 1 | 2D Parallelism. Shards the model across a dp_shard group, and TP-splits layers within the tp group. |
| HSDP + TP | >1 | >1 | >1 | 1 | 3D Parallelism. A powerful but complex combination. |
| FSDP + CP | 1 | >1 | 1 | >1 | 2D Parallelism. Combines FSDP with context parallelism. |
| FSDP + TP + CP | 1 | >1 | >1 | >1 | 3D Parallelism. Another advanced combination. |
| DDP + TP/CP | >1 | 1 | >1 | >1 | Not Supported. The ParallelismConfig explicitly prevents this, as composing pure DDP with TP or CP is currently not supported. You should use FSDP + TP/CP instead (dp_shard_size > 1). |
| Just TP / CP | 1 | 1 | >1 | >1 | ✅ Supported. Useful for inference or when the model fits on one GPU but context is too long. |
  • tp_size refers to tensor_parallel_size
  • cp_size refers to context_parallel_size
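The rules in the matrix can be encoded in a few lines. `describe_parallelism` is a hypothetical helper that only mirrors the table above; it does not reproduce the actual ParallelismConfig validation, and it does not check the sizes against the number of available GPUs.

```python
def describe_parallelism(dp_replicate_size=1, dp_shard_size=1, tp_size=1, cp_size=1):
    """Classify a parallelism setup per the support matrix (hypothetical helper).

    Raises ValueError for the unsupported "pure DDP + TP/CP" combination.
    """
    sizes = {"dp_replicate_size": dp_replicate_size, "dp_shard_size": dp_shard_size,
             "tp_size": tp_size, "cp_size": cp_size}
    for name, size in sizes.items():
        if not isinstance(size, int) or size < 1:
            raise ValueError(f"{name} must be a positive integer, got {size!r}")

    uses_tp_or_cp = tp_size > 1 or cp_size > 1
    if dp_replicate_size > 1 and dp_shard_size == 1 and uses_tp_or_cp:
        raise ValueError("pure DDP cannot be combined with TP/CP; use dp_shard_size > 1")

    parts = []
    if dp_shard_size > 1:
        parts.append("HSDP" if dp_replicate_size > 1 else "FSDP")
    elif dp_replicate_size > 1:
        parts.append("DDP")
    if tp_size > 1:
        parts.append("TP")
    if cp_size > 1:
        parts.append("CP")
    return " + ".join(parts) or "single GPU"


print(describe_parallelism(dp_replicate_size=2, dp_shard_size=4))  # HSDP
print(describe_parallelism(dp_shard_size=4, tp_size=2))            # FSDP + TP
```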