core.trainers.grpo.async_trainer
Async GRPO training with streaming scoring and IS correction.
Works on stock TRL v0.29.0 and transformers v5.3.0 — no custom branches needed.
Features
- Async prefetch: background thread generates completions via vLLM while the main thread trains on the previous rollout.
- Deferred scoring: rewards, advantages, and policy logprobs are computed on the main thread (thread-safe with GPU forward passes).
- Streaming group scoring: scores prompt groups incrementally so that reward computation overlaps with the next group’s logprob computation.
- Importance sampling (IS) correction: corrects for stale vLLM weights.
- Off-Policy Sequence Mask (OPSM): drops sequences that have both a negative advantage and a high KL divergence.
- Configurable vLLM weight sync interval.
Classes exported
- AsyncGRPOConfig: GRPOConfig extended with async/streaming/IS fields
- AsyncGRPOTrainer: GRPOTrainer with async prefetch and IS correction
- ProducerConfig, DataProducer, BaseDataProducer, AsyncDataProducer: data producer protocol
Classes
| Name | Description |
|---|---|
| AsyncDataProducer | Wraps a synchronous DataProducer for background-thread data generation. |
| AsyncGRPOConfig | GRPOConfig extended with async prefetch, streaming scoring, and IS correction fields. |
| AsyncGRPOTrainer | GRPOTrainer with async prefetch, streaming scoring, and IS correction. |
| BaseDataProducer | Convenience base class with a default ProducerConfig and lifecycle hooks. |
| DataProducer | Abstract base class for online data producers. |
| DataProducerCallback | Marker class: if a DataProducer also inherits from this, the Trainer will automatically register it as a callback. |
| GRPODataProducer | Produces GRPO training rollouts using the trainer’s generation pipeline. |
| ProducerConfig | Configuration for a DataProducer. |
| RolloutDataset | A Dataset wrapping the output dict from _generate_and_score_completions. |
AsyncDataProducer
core.trainers.grpo.async_trainer.AsyncDataProducer(
inner,
background_produce_kwargs=None,
)
Wraps a synchronous DataProducer for background-thread data generation.
While the Trainer trains on the current rollout, this wrapper produces upcoming datasets in a background thread.
FSDP compatibility: Background threads must NOT call cross-rank collectives
(gather_object, broadcast_object_list, FSDP all-gather) because the main thread
may be doing FSDP forward/backward concurrently, causing deadlocks. When
num_processes > 1, only rank 0 runs BG generation; results are broadcast
to other ranks on the main thread when produce() is next called.
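The prefetch pattern described above can be sketched with a single-worker thread pool. This is a minimal illustration, not the module's implementation: `PrefetchingProducer` and `inner_produce` are hypothetical names, the real `produce()` also receives the model, and the rank-0-only generation and broadcast used under FSDP are left out.

```python
from concurrent.futures import ThreadPoolExecutor


class PrefetchingProducer:
    """Hypothetical stand-in: prefetches the next rollout in a background thread."""

    def __init__(self, inner_produce):
        self._inner_produce = inner_produce  # synchronous callable: global_step -> dataset
        self._pool = ThreadPoolExecutor(max_workers=1)
        self._pending = None  # Future for the rollout being generated in the background

    def produce(self, global_step):
        if self._pending is None:
            # First call: nothing has been prefetched yet, so generate synchronously.
            dataset = self._inner_produce(global_step)
        else:
            # Blocks until the background generation has finished.
            dataset = self._pending.result()
        # Start generating the next rollout while the caller trains on this one.
        self._pending = self._pool.submit(self._inner_produce, global_step + 1)
        return dataset

    def shutdown(self):
        if self._pending is not None:
            self._pending.cancel()  # no effect if the task already started or finished
        self._pool.shutdown(wait=True)


producer = PrefetchingProducer(lambda step: {"step": step})
first = producer.produce(0)   # generated synchronously
second = producer.produce(1)  # served from the background prefetch
producer.shutdown()
```

Because the second rollout is generated while the caller is still working on the first, it is produced with whatever weights were current at submission time, which is why the trainer pairs prefetching with IS correction.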
Methods
| Name | Description |
|---|---|
| produce | Return the next dataset, blocking if the prefetch hasn’t finished. |
| shutdown | Shut down the background thread pool and cancel pending futures. |
produce
core.trainers.grpo.async_trainer.AsyncDataProducer.produce(
model,
global_step,
**kwargs,
)
Return the next dataset, blocking if the prefetch hasn’t finished.
shutdown
core.trainers.grpo.async_trainer.AsyncDataProducer.shutdown()
Shut down the background thread pool and cancel pending futures.
AsyncGRPOConfig
core.trainers.grpo.async_trainer.AsyncGRPOConfig(
use_data_producer=False,
async_prefetch=False,
prefetch_depth=1,
vllm_sync_interval=1,
batch_flattening=False,
streaming_partial_batch=False,
streaming_min_groups=1,
vllm_importance_sampling_correction=True,
vllm_importance_sampling_mode='token_truncate',
vllm_importance_sampling_cap=3.0,
off_policy_mask_threshold=None,
use_bias_correction_kl=False,
)
GRPOConfig extended with async prefetch, streaming scoring, and IS correction fields.
Fields already present in stock GRPOConfig (e.g. importance_sampling_level,
multi_objective_aggregation) are redeclared here as a safeguard: if the installed
stock version does not define them, the defaults declared here still apply.
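As a rough illustration of what the default `vllm_importance_sampling_mode='token_truncate'` with `vllm_importance_sampling_cap=3.0` suggests, the sketch below computes a per-token importance ratio and truncates it from above at the cap. The function name and the exact formula are assumptions inferred from the field names; the source does not spell out the computation.

```python
import math


def truncated_is_weights(policy_logps, sampling_logps, cap=3.0):
    """Per-token ratio pi_policy / pi_sampler, truncated from above at `cap`.

    Assumed reading of 'token_truncate'; not taken from the module's source.
    """
    return [min(math.exp(p - s), cap) for p, s in zip(policy_logps, sampling_logps)]


policy_logps = [-1.0, -0.5, -2.0]    # log-probs under the current training policy
sampling_logps = [-1.0, -2.5, -1.0]  # log-probs reported by the (stale) vLLM sampler
weights = truncated_is_weights(policy_logps, sampling_logps)
# Token 0 is on-policy (ratio 1), token 1 is capped at 3.0, token 2 is down-weighted.
```

Truncating rather than discarding keeps every token in the loss while bounding the variance that a few very large ratios would introduce.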
AsyncGRPOTrainer
core.trainers.grpo.async_trainer.AsyncGRPOTrainer(*args, **kwargs)
GRPOTrainer with async prefetch, streaming scoring, and IS correction.
Drop-in replacement: pass AsyncGRPOConfig as args and use this trainer
instead of GRPOTrainer.
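A possible starting point for the drop-in swap is shown below. The field names come from the `AsyncGRPOConfig` signature; the values are illustrative, and whether `use_data_producer` must be enabled alongside `async_prefetch` is an assumption. The trainer construction is left as a comment because it needs the real package, a model, and a reward function.

```python
# Field names are from the AsyncGRPOConfig signature; values are illustrative only.
async_overrides = {
    "use_data_producer": True,   # assumption: enables the data producer path
    "async_prefetch": True,      # generate the next rollout in the background
    "prefetch_depth": 1,
    "vllm_sync_interval": 2,
    "vllm_importance_sampling_correction": True,
    "vllm_importance_sampling_mode": "token_truncate",
    "vllm_importance_sampling_cap": 3.0,
}

# Intended use (not executed here):
#   from core.trainers.grpo.async_trainer import AsyncGRPOConfig, AsyncGRPOTrainer
#   config = AsyncGRPOConfig(output_dir="out", **async_overrides)
#   trainer = AsyncGRPOTrainer(model=..., reward_funcs=..., args=config, train_dataset=...)
```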
Methods
| Name | Description |
|---|---|
| get_off_policy_mask | OPSM from DeepSeek-V3.2: drop sequences with negative advantage + high KL. |
get_off_policy_mask
core.trainers.grpo.async_trainer.AsyncGRPOTrainer.get_off_policy_mask(
advantages,
per_token_logps,
sampling_per_token_logps,
mask,
off_policy_threshold,
)
OPSM from DeepSeek-V3.2: drop sequences with negative advantage + high KL.
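The masking rule can be sketched on plain Python lists. This is a simplified reading, not the method's actual body: it assumes the KL term is estimated as the mean per-token log-ratio between the sampling policy and the current policy over valid tokens, and it returns one keep-flag per sequence rather than a tensor mask.

```python
def off_policy_sequence_mask(advantages, policy_logps, sampling_logps, mask, threshold):
    """Return one keep-flag (1.0 or 0.0) per sequence.

    A sequence is dropped when its advantage is negative AND the mean per-token
    log-ratio log(pi_sampler / pi_policy) over valid tokens exceeds `threshold`.
    """
    keep = []
    for adv, p_row, s_row, m_row in zip(advantages, policy_logps, sampling_logps, mask):
        n_valid = sum(m_row)
        # Sample-based KL estimate between the sampler and the current policy.
        kl = sum((s - p) * m for p, s, m in zip(p_row, s_row, m_row)) / max(n_valid, 1)
        drop = adv < 0 and kl > threshold
        keep.append(0.0 if drop else 1.0)
    return keep


advantages = [-1.0, -1.0, 1.0]
policy_logps = [[-3.0, -3.0], [-1.0, -1.0], [-3.0, -3.0]]
sampling_logps = [[-1.0, -1.0], [-1.0, -1.1], [-1.0, -1.0]]
mask = [[1, 1], [1, 1], [1, 1]]
keep = off_policy_sequence_mask(advantages, policy_logps, sampling_logps, mask, threshold=0.5)
# Only the first sequence is dropped: it is both negative-advantage and far off-policy.
```

The third sequence has drifted just as far as the first, but it is kept because its advantage is positive.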
BaseDataProducer
core.trainers.grpo.async_trainer.BaseDataProducer(config=None)
Convenience base class with a default ProducerConfig and lifecycle hooks.
Methods
| Name | Description |
|---|---|
| on_rollout_begin | Called before each produce() invocation. |
| on_rollout_end | Called after each produce() invocation with the produced dataset. |
on_rollout_begin
core.trainers.grpo.async_trainer.BaseDataProducer.on_rollout_begin(global_step)
Called before each produce() invocation.
on_rollout_end
core.trainers.grpo.async_trainer.BaseDataProducer.on_rollout_end(
dataset,
global_step,
)
Called after each produce() invocation with the produced dataset.
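The hook order described above can be illustrated with a stand-in base class. `BaseProducerSketch`, `LoggingProducer`, and `run_rollout` are hypothetical names; in the real module the hooks are presumably invoked by the trainer rather than by a helper like this one.

```python
class BaseProducerSketch:
    """Hypothetical stand-in for BaseDataProducer, with no-op hooks."""

    def on_rollout_begin(self, global_step):
        pass

    def on_rollout_end(self, dataset, global_step):
        pass

    def produce(self, model, global_step, **kwargs):
        raise NotImplementedError


class LoggingProducer(BaseProducerSketch):
    def __init__(self):
        self.events = []

    def on_rollout_begin(self, global_step):
        self.events.append(("begin", global_step))

    def on_rollout_end(self, dataset, global_step):
        self.events.append(("end", global_step, len(dataset)))

    def produce(self, model, global_step, **kwargs):
        # Toy dataset: two samples per rollout.
        return [{"prompt": f"p{global_step}-{i}"} for i in range(2)]


def run_rollout(producer, model, global_step):
    """Assumed call order: begin hook, produce(), end hook."""
    producer.on_rollout_begin(global_step)
    dataset = producer.produce(model, global_step)
    producer.on_rollout_end(dataset, global_step)
    return dataset


producer = LoggingProducer()
dataset = run_rollout(producer, model=None, global_step=3)
```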
DataProducer
core.trainers.grpo.async_trainer.DataProducer()
Abstract base class for online data producers.
Subclass this and implement produce() to supply fresh training data
each rollout round.
Methods
| Name | Description |
|---|---|
| produce | Generate a fresh training dataset. |
produce
core.trainers.grpo.async_trainer.DataProducer.produce(
model,
global_step,
*,
processing_class=None,
accelerator=None,
args=None,
**kwargs,
)
Generate a fresh training dataset.
DataProducerCallback
core.trainers.grpo.async_trainer.DataProducerCallback()
Marker class: if a DataProducer also inherits from this, the Trainer will automatically register it as a callback.
GRPODataProducer
core.trainers.grpo.async_trainer.GRPODataProducer(
config,
prompt_dataset,
*,
num_generations,
generation_batch_size,
train_batch_size,
steps_per_generation,
shuffle_dataset,
seed,
)
Produces GRPO training rollouts using the trainer’s generation pipeline.
Created before Trainer.__init__ completes; the trainer reference is injected later via set_trainer().
Methods
| Name | Description |
|---|---|
| produce | Generate a fresh GRPO training rollout. |
| set_trainer | Inject the live trainer reference and create the prompt DataLoader. |
produce
core.trainers.grpo.async_trainer.GRPODataProducer.produce(
model,
global_step,
*,
skip_policy_logps=False,
processing_class=None,
accelerator=None,
args=None,
_rank0_only=False,
**kwargs,
)
Generate a fresh GRPO training rollout.
set_trainer
core.trainers.grpo.async_trainer.GRPODataProducer.set_trainer(trainer)
Inject the live trainer reference and create the prompt DataLoader.
ProducerConfig
core.trainers.grpo.async_trainer.ProducerConfig(
mini_epochs=1,
max_rollouts=None,
steps_per_generation=None,
num_iterations=1,
async_prefetch=False,
prefetch_depth=1,
sync_warmup_rollouts=0,
eval_during_produce=True,
empty_cache_before_produce=False,
empty_cache_after_produce=False,
)
Configuration for a DataProducer.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| mini_epochs | int | Number of training passes over each produced dataset. | 1 |
| max_rollouts | int | None | Maximum number of produce-then-train rounds (None = unlimited). | None |
| steps_per_generation | int | None | Optimisation steps per produced dataset before regenerating. | None |
| num_iterations | int | Number of times to reuse each generation across optimisation steps. | 1 |
| async_prefetch | bool | Produce the next dataset in a background thread. | False |
| prefetch_depth | int | How many rollouts to queue ahead when async. | 1 |
| sync_warmup_rollouts | int | Initial on-policy rollouts before switching to async. | 0 |
| eval_during_produce | bool | Switch model to eval() during produce(). | True |
| empty_cache_before_produce | bool | torch.cuda.empty_cache() before produce(). | False |
| empty_cache_after_produce | bool | torch.cuda.empty_cache() after produce(). | False |
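To make the interaction of a few of these fields concrete, the sketch below plans a produce-then-train schedule from `mini_epochs`, `max_rollouts`, `sync_warmup_rollouts`, and `async_prefetch`. `ProducerConfigSketch` and `plan_rounds` are hypothetical, and the scheduling logic is an interpretation of the parameter descriptions rather than the trainer's actual loop.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ProducerConfigSketch:
    """Hypothetical subset of ProducerConfig, with the documented defaults."""
    mini_epochs: int = 1
    max_rollouts: Optional[int] = None
    sync_warmup_rollouts: int = 0
    async_prefetch: bool = False


def plan_rounds(config, available_rollouts):
    """List the planned rounds: training passes and async/sync mode per rollout."""
    if config.max_rollouts is None:
        rounds = available_rollouts
    else:
        rounds = min(available_rollouts, config.max_rollouts)
    plan = []
    for rollout_idx in range(rounds):
        # The first `sync_warmup_rollouts` rounds stay on-policy (synchronous).
        use_async = config.async_prefetch and rollout_idx >= config.sync_warmup_rollouts
        plan.append({"rollout": rollout_idx, "passes": config.mini_epochs, "async": use_async})
    return plan


config = ProducerConfigSketch(
    mini_epochs=2, max_rollouts=3, sync_warmup_rollouts=1, async_prefetch=True
)
plan = plan_rounds(config, available_rollouts=10)
```

With these values the plan holds three rounds of two passes each: the first is synchronous and the remaining two are prefetched.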
RolloutDataset
core.trainers.grpo.async_trainer.RolloutDataset(data)
A Dataset wrapping the output dict from _generate_and_score_completions.
Per-sample tensors are sliced by index; shared metadata is passed through.
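The slicing behaviour can be sketched with plain lists standing in for tensors. `RolloutDatasetSketch` is hypothetical, and so is its explicit `shared_keys` argument: the real class takes only `data`, and how it tells shared metadata apart from per-sample entries is not stated here.

```python
class RolloutDatasetSketch:
    """Hypothetical stand-in: index into per-sample entries, pass shared ones through."""

    def __init__(self, data, shared_keys=()):
        self.shared = {k: v for k, v in data.items() if k in shared_keys}
        self.per_sample = {k: v for k, v in data.items() if k not in shared_keys}
        # All per-sample entries are assumed to have the same length.
        self.length = len(next(iter(self.per_sample.values())))

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        item = {k: v[idx] for k, v in self.per_sample.items()}
        item.update(self.shared)  # shared metadata is attached unchanged
        return item


data = {
    "completion_ids": [[1, 2], [3, 4], [5, 6]],
    "advantages": [0.5, -0.5, 0.0],
    "num_items_in_batch": 3,  # illustrative shared key
}
dataset = RolloutDatasetSketch(data, shared_keys=("num_items_in_batch",))
sample = dataset[1]
```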
Functions
| Name | Description |
|---|---|
| make_rollout_collator | Return a collator that stacks per-sample tensors and passes shared keys through. |
make_rollout_collator
core.trainers.grpo.async_trainer.make_rollout_collator(shared_keys)
Return a collator that stacks per-sample tensors and passes shared keys through.
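A collator of this shape might look like the following sketch, which again uses lists where the real function would stack tensors. `make_collator_sketch` is a hypothetical name, and taking the shared value from the first sample assumes that shared keys are identical across samples.

```python
def make_collator_sketch(shared_keys):
    """Hypothetical stand-in for make_rollout_collator, operating on plain lists."""
    shared_keys = set(shared_keys)

    def collate(samples):
        batch = {}
        for key in samples[0]:
            if key in shared_keys:
                # Shared metadata: assumed identical across samples, so take the first.
                batch[key] = samples[0][key]
            else:
                # Per-sample values: gathered into a list (stand-in for tensor stacking).
                batch[key] = [sample[key] for sample in samples]
        return batch

    return collate


collate = make_collator_sketch(["num_items_in_batch"])
batch = collate([
    {"advantages": 0.5, "num_items_in_batch": 3},
    {"advantages": -0.5, "num_items_in_batch": 3},
])
```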