core.trainers.grpo.sampler

Repeat random sampler (similar to the one implemented in https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py) with added sequence parallelism support; i.e., it duplicates data across all ranks in the same sequence parallel group.

Classes

| Name                                | Description                                          |
|-------------------------------------|------------------------------------------------------|
| SequenceParallelRepeatRandomSampler | Sampler for GRPO training with sequence parallelism. |

SequenceParallelRepeatRandomSampler

core.trainers.grpo.sampler.SequenceParallelRepeatRandomSampler(
    dataset,
    mini_repeat_count,
    world_size,
    rank,
    batch_size=1,
    repeat_count=1,
    sequence_parallel_degree=1,
    shuffle=True,
    seed=0,
    drop_last=False,
)

Sampler for GRPO training with sequence parallelism.

This sampler ensures:

- Ranks in the same sequence parallel (SP) group receive identical data.
- Each index is repeated multiple times for sampling different completions.
- Entire batches are repeated for reuse in multiple updates.
- Data is properly distributed across SP groups.

In the table below, the values represent dataset indices. Each SP group has sequence_parallel_degree = 2 GPUs working together on the same data. There are 2 SP groups (SP0 and SP1), with world_size = 4 total GPUs.

                                   Sequence Parallel Groups
                            |       SP0        |       SP1        |
                            |  GPU 0  |  GPU 1 |  GPU 2  |  GPU 3 |
        global_step  step     <--->  mini_repeat_count=3
                              <---------->  batch_size=2 per SP group

                    ▲ ▲   0     0    [0 0 0  1 1 1]   [2 2 2  3 3 3]   <- SP groups get different data
     grad_accum=2   ▼ |   0     1    [0 0 0  1 1 1]   [2 2 2  3 3 3]   <- Same data for every GPU in an SP group
                      |   1     2    [0 0 0  1 1 1]   [2 2 2  3 3 3]   <- Same indices repeated for
  num_iterations=2    ▼   1     3    [0 0 0  1 1 1]   [2 2 2  3 3 3]      num_iterations=2 (with grad accumulation)

                          2     4    [4 4 4  5 5 5]   [6 6 6  7 7 7]   <- New batch of data indices
                          2     5    [4 4 4  5 5 5]   [6 6 6  7 7 7]
                                      ...
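The per-SP-group index stream shown above can be sketched in plain Python. This is an illustrative reconstruction, not the library's internals: the function name is hypothetical, and the round-robin assignment of batches to SP groups is an assumption inferred from the table.

```python
import random

def sp_group_index_stream(dataset_len, mini_repeat_count, batch_size, num_sp_groups,
                          sp_group, repeat_count=1, shuffle=True, seed=0,
                          drop_last=False):
    """Hypothetical sketch of the index stream one SP group receives."""
    indexes = list(range(dataset_len))
    if shuffle:
        random.Random(seed).shuffle(indexes)
    # Chunk the (possibly shuffled) indices into batches of `batch_size`.
    batches = [indexes[i:i + batch_size] for i in range(0, len(indexes), batch_size)]
    if drop_last:
        batches = [b for b in batches if len(b) == batch_size]
    # Assumed round-robin sharding: SP group g takes every num_sp_groups-th batch.
    own_batches = batches[sp_group::num_sp_groups]
    stream = []
    for batch in own_batches:
        for _ in range(repeat_count):           # reuse the whole batch for several updates
            for index in batch:
                stream.extend([index] * mini_repeat_count)  # back-to-back repeats per index
    return stream
```

With `shuffle=False`, `mini_repeat_count=3`, `batch_size=2`, and `repeat_count=2`, SP group 0 of 2 yields the SP0 column of the table: `[0 0 0 1 1 1]` twice, then `[4 4 4 5 5 5]` twice. Every GPU in the group iterates over this same stream.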

Parameters

| Name                     | Type  | Description                                        | Default  |
|--------------------------|-------|----------------------------------------------------|----------|
| dataset                  | Sized | Dataset to sample from.                            | required |
| mini_repeat_count        | int   | How many times to repeat each sample immediately.  | required |
| world_size               | int   | Total number of processes.                         | required |
| rank                     | int   | Rank of the current process.                       | required |
| batch_size               | int   | Number of samples per batch.                       | 1        |
| repeat_count             | int   | How many times to repeat the full sampling process.| 1        |
| sequence_parallel_degree | int   | Number of ranks in a sequence parallel group.      | 1        |
| shuffle                  | bool  | Whether to shuffle the dataset.                    | True     |
| seed                     | int   | Random seed for shuffling.                         | 0        |
| drop_last                | bool  | Whether to drop the last incomplete batch.         | False    |
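The relationship between rank, world_size, and sequence_parallel_degree can be illustrated with two small helpers. Both functions are hypothetical (not part of the library) and assume the common convention that consecutive ranks form one SP group:

```python
def sp_group_of(rank, sequence_parallel_degree):
    # Consecutive ranks share a group: ranks 0..d-1 form group 0, d..2d-1 form group 1, ...
    return rank // sequence_parallel_degree

def num_sp_groups(world_size, sequence_parallel_degree):
    # The world must divide evenly into SP groups.
    assert world_size % sequence_parallel_degree == 0, \
        "world_size must be divisible by sequence_parallel_degree"
    return world_size // sequence_parallel_degree
```

With `world_size=4` and `sequence_parallel_degree=2` this gives two groups, GPUs {0, 1} in SP0 and GPUs {2, 3} in SP1, matching the table above.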

Methods

| Name      | Description                     |
|-----------|---------------------------------|
| set_epoch | Sets the epoch for this sampler.|

set_epoch

core.trainers.grpo.sampler.SequenceParallelRepeatRandomSampler.set_epoch(epoch)

Sets the epoch for this sampler.

Parameters

| Name  | Type | Description                        | Default  |
|-------|------|------------------------------------|----------|
| epoch | int  | Epoch number to use for shuffling. | required |
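A minimal sketch of the set_epoch pattern, assuming it follows the torch.utils.data.DistributedSampler convention of folding the epoch into the shuffle seed. The class and method bodies here are illustrative, not the library's implementation:

```python
import random

class EpochShuffleSketch:
    """Illustrative stand-in showing why set_epoch matters: without it,
    every epoch would replay the same shuffled order."""

    def __init__(self, dataset_len, seed=0):
        self.dataset_len = dataset_len
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch):
        # Call before iterating each epoch so all ranks reshuffle identically.
        self.epoch = epoch

    def order(self):
        order = list(range(self.dataset_len))
        # Deterministic per (seed, epoch), and therefore identical on every rank.
        random.Random(self.seed + self.epoch).shuffle(order)
        return order
```

In training code, each rank would call `sampler.set_epoch(epoch)` at the start of every epoch; because the shuffle depends only on `seed + epoch`, all ranks, and in particular all GPUs within one SP group, agree on the same order.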