integrations.expert_parallel.shard

Generic expert-weight sharding for @use_experts_implementation modules.

After this runs (in post_model_build, before FSDP wraps the model), each rank’s Experts modules hold only their local slice of the expert dimension. The registered deep_ep_* forward function then handles dispatch, local compute, and combine, in that order.

Functions

Name Description
shard_expert_weights Slice expert weights along dim 0 per the EP rank.

shard_expert_weights

integrations.expert_parallel.shard.shard_expert_weights(model, ep_group)

Slice expert weights along dim 0 per the EP rank.

Parameters

Name Type Description Default
model A built (but not yet FSDP-wrapped) HuggingFace model. required
ep_group torch.distributed.ProcessGroup for EP, or None for single-rank (no-op). required

Returns

Name Type Description
int Number of Experts modules sharded (0 if EP disabled or none found).

Raises

Name Type Description
ValueError If any Experts module’s num_experts is not divisible by the EP world size.
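
The slicing and the divisibility check described above can be illustrated with a minimal, pure-Python sketch. It is not the module's actual implementation: the helper name local_expert_slice is hypothetical, and the contiguous block layout (rank r owns experts r * n_local through (r + 1) * n_local - 1) is an assumption, since the source only states that weights are sliced along dim 0 per the EP rank. A nested list stands in for a weight tensor of shape [num_experts, ...].

```python
def local_expert_slice(num_experts: int, ep_rank: int, ep_world_size: int) -> slice:
    """Return the dim-0 slice of experts owned by ep_rank.

    Hypothetical helper; assumes each rank owns one contiguous block of experts.
    """
    if num_experts % ep_world_size != 0:
        raise ValueError(
            f"num_experts={num_experts} is not divisible by "
            f"EP world size {ep_world_size}"
        )
    if not 0 <= ep_rank < ep_world_size:
        raise ValueError(f"ep_rank={ep_rank} is outside [0, {ep_world_size})")
    per_rank = num_experts // ep_world_size
    start = ep_rank * per_rank
    return slice(start, start + per_rank)


# Stand-in for a weight of shape [num_experts, hidden]: row e holds expert e.
weights = [[float(e)] * 2 for e in range(8)]

# With 8 experts and 4 EP ranks, rank 1 keeps experts 2 and 3 (under the
# contiguous-block assumption above).
local = weights[local_expert_slice(8, ep_rank=1, ep_world_size=4)]
```

With a real tensor the same slice object would index dim 0 directly, for example weight[local_expert_slice(...)].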

DDP composition: the sharded parameters hold different content on each rank, so their fully qualified names are added to model._ddp_params_and_buffers_to_ignore. This prevents DDP’s startup broadcast from overwriting every rank’s slice with rank 0’s. FSDP composition is handled in ExpertParallelPlugin.fully_shard_experts.
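
The bookkeeping for the DDP ignore list might look like the following sketch. The helper add_to_ddp_ignore_list and the parameter names are hypothetical, and a SimpleNamespace stands in for the model so the example runs without PyTorch. Only the attribute name _ddp_params_and_buffers_to_ignore comes from the text above. With a real model, the names would presumably come from model.named_parameters(), filtered to parameters owned by Experts modules, and the attribute would need to be set before the model is wrapped in DDP.

```python
from types import SimpleNamespace


def add_to_ddp_ignore_list(model, param_names):
    """Append names to model._ddp_params_and_buffers_to_ignore, skipping duplicates.

    Hypothetical helper; preserves any names that were already registered.
    """
    ignored = list(getattr(model, "_ddp_params_and_buffers_to_ignore", []))
    seen = set(ignored)
    for name in param_names:
        if name not in seen:
            ignored.append(name)
            seen.add(name)
    model._ddp_params_and_buffers_to_ignore = ignored
    return ignored


# Stand-in for a built model; the parameter names below are illustrative only.
model = SimpleNamespace()
sharded_names = [
    "layers.0.mlp.experts.gate_up_proj",
    "layers.0.mlp.experts.down_proj",
]
add_to_ddp_ignore_list(model, sharded_names)

# Calling it again with an overlapping list does not create duplicates.
add_to_ddp_ignore_list(model, sharded_names[:1])
```

Keeping the update idempotent matters if sharding can run more than once on the same model, or if another integration has already registered its own ignored names.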