core.trainers.grpo.sampler
Repeat random sampler (similar to the one implemented in https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py) that adds sequence-parallelism support, i.e. it duplicates data across ranks in the same sequence parallel group.
Classes
Name | Description |
---|---|
SequenceParallelRepeatRandomSampler | Sampler for GRPO training with sequence parallelism. |
SequenceParallelRepeatRandomSampler
core.trainers.grpo.sampler.SequenceParallelRepeatRandomSampler(
    dataset,
    mini_repeat_count,
    world_size,
    rank,
    batch_size=1,
    repeat_count=1,
    sequence_parallel_degree=1,
    shuffle=True,
    seed=0,
    drop_last=False,
)
Sampler for GRPO training with sequence parallelism.
This sampler ensures:

- Ranks in the same sequence parallel (SP) group receive identical data.
- Each index is repeated multiple times for sampling different completions.
- Entire batches are repeated for reuse in multiple updates.
- Data is properly distributed across SP groups.
In the table below, the values represent dataset indices. Each SP group has sequence_parallel_degree = 2 GPUs working together on the same data. There are 2 SP groups (SP0 and SP1), with world_size = 4 total GPUs.
```
Sequence Parallel Groups

                              |      SP0      |      SP1      |
                              | GPU 0 | GPU 1 | GPU 2 | GPU 3 |

     global_step   step        <--->  mini_repeat_count=3
                               <---------->  batch_size=2 per SP group

 grad_accum=2  ▲   ▲   0   0   [0 0 0 1 1 1]  [2 2 2 3 3 3]  <- SP groups get different data
               ▼   |   0   1   [0 0 0 1 1 1]  [2 2 2 3 3 3]  <- Same data for each SP group GPU
                   |   1   2   [0 0 0 1 1 1]  [2 2 2 3 3 3]  <- Repeat same indices for num_iterations=2
 num_iterations=2  ▼   1   3   [0 0 0 1 1 1]  [2 2 2 3 3 3]  <- When using gradient accumulation

                       2   4   [4 4 4 5 5 5]  [6 6 6 7 7 7]  <- New batch of data indices
                       2   5   [4 4 4 5 5 5]  [6 6 6 7 7 7]
                      ...
```
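The index pattern in the diagram can be reproduced with a short sketch. This is an illustrative reconstruction, not the actual implementation; the function name and the exact chunking logic are assumptions based on the parameters documented below (shuffling is disabled here so the indices match the diagram).

```python
import random

def sp_repeat_random_indices(
    dataset_len, mini_repeat_count, world_size, rank,
    batch_size=1, repeat_count=1, sequence_parallel_degree=1,
    shuffle=True, seed=0, epoch=0,
):
    """Sketch of the index stream one rank would draw (assumed logic)."""
    # Ranks in the same SP group share a group index, so they slice
    # identical portions of each shuffled batch.
    num_sp_groups = world_size // sequence_parallel_degree
    sp_group = rank // sequence_parallel_degree

    indices = list(range(dataset_len))
    if shuffle:
        random.Random(seed + epoch).shuffle(indices)

    # Each "global batch" is split evenly across SP groups; trailing
    # indices that do not fill a global batch are dropped (drop_last).
    global_batch = batch_size * num_sp_groups
    usable = dataset_len - dataset_len % global_batch
    out = []
    for start in range(0, usable, global_batch):
        chunk = indices[start:start + global_batch]
        local = chunk[sp_group * batch_size:(sp_group + 1) * batch_size]
        # Repeat each index mini_repeat_count times (multiple completions
        # per prompt), then repeat the whole batch repeat_count times
        # (reuse for multiple updates).
        batch = [i for i in local for _ in range(mini_repeat_count)]
        out.extend(batch * repeat_count)
    return out
```

With `dataset_len=8, mini_repeat_count=3, world_size=4, batch_size=2, repeat_count=2, sequence_parallel_degree=2, shuffle=False`, ranks 0 and 1 (SP0) both start with `[0, 0, 0, 1, 1, 1]` while rank 2 (SP1) starts with `[2, 2, 2, 3, 3, 3]`, matching the diagram.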
Parameters
Name | Type | Description | Default |
---|---|---|---|
dataset | Sized | Dataset to sample from. | required |
mini_repeat_count | int | How many times to repeat each sample immediately. | required |
world_size | int | Total number of processes. | required |
rank | int | Rank of current process. | required |
batch_size | int | Number of samples per batch. | 1 |
repeat_count | int | How many times to repeat the full sampling process. | 1 |
sequence_parallel_degree | int | Number of ranks in a sequence parallel group. | 1 |
shuffle | bool | Whether to shuffle the dataset. | True |
seed | int | Random seed for shuffling. | 0 |
drop_last | bool | Whether to drop the last incomplete batch. | False |
Methods
Name | Description |
---|---|
set_epoch | Sets the epoch for this sampler. |
set_epoch
core.trainers.grpo.sampler.SequenceParallelRepeatRandomSampler.set_epoch(epoch)
Sets the epoch for this sampler.
Parameters
Name | Type | Description | Default |
---|---|---|---|
epoch | int | Epoch number to use for shuffling. | required |
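A sketch of why `set_epoch` matters: deriving the shuffle order from seed plus epoch keeps the reshuffle deterministic and identical across all ranks, while still changing between epochs. The seeding scheme below is an assumption for illustration, not taken from the implementation.

```python
import random

def epoch_shuffled(indices, seed, epoch):
    # Seed the RNG with seed + epoch so every rank computes the same
    # order, and the order differs between epochs (assumed scheme).
    rng = random.Random(seed + epoch)
    out = list(indices)
    rng.shuffle(out)
    return out

# Typical training-loop pattern (mirroring
# torch.utils.data.DistributedSampler): call set_epoch before iterating.
order_rank0 = epoch_shuffled(range(8), seed=0, epoch=0)
order_rank1 = epoch_shuffled(range(8), seed=0, epoch=0)  # identical across ranks
next_epoch = epoch_shuffled(range(8), seed=0, epoch=1)   # reshuffled
```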