core.trainers.grpo.async_trainer

Async GRPO training with streaming scoring and IS correction.

Works on stock TRL v0.29.0 and transformers v5.3.0 — no custom branches needed.

Features

  • Async prefetch: background thread generates completions via vLLM while the main thread trains on the previous rollout.
  • Deferred scoring: rewards, advantages, and policy logprobs are computed on the main thread, which keeps GPU forward passes thread-safe.
  • Streaming group scoring: scores prompt groups incrementally so that reward computation overlaps with the next group’s logprob computation.
  • Importance sampling (IS) correction: corrects for completions sampled from stale vLLM weights.
  • Off-Policy Sequence Mask (OPSM): drops sequences that have both high KL divergence and negative advantage.
  • Configurable vLLM weight sync interval.

Classes exported

  • AsyncGRPOConfig: GRPOConfig extended with async/streaming/IS fields
  • AsyncGRPOTrainer: GRPOTrainer with async prefetch and IS correction
  • ProducerConfig, DataProducer, BaseDataProducer, AsyncDataProducer: data producer protocol

Classes

Name Description
AsyncDataProducer Wraps a synchronous DataProducer for background-thread data generation.
AsyncGRPOConfig GRPOConfig extended with async prefetch, streaming scoring, and IS correction fields.
AsyncGRPOTrainer GRPOTrainer with async prefetch, streaming scoring, and IS correction.
BaseDataProducer Convenience base class with a default ProducerConfig and lifecycle hooks.
DataProducer Abstract base class for online data producers.
DataProducerCallback Marker class: if a DataProducer also inherits from this, the Trainer will automatically register it as a callback.
GRPODataProducer Produces GRPO training rollouts using the trainer’s generation pipeline.
ProducerConfig Configuration for a DataProducer.
RolloutDataset A Dataset wrapping the output dict from _generate_and_score_completions.

AsyncDataProducer

core.trainers.grpo.async_trainer.AsyncDataProducer(
    inner,
    background_produce_kwargs=None,
)

Wraps a synchronous DataProducer for background-thread data generation.

While the Trainer trains on the current rollout, this wrapper produces upcoming datasets in a background thread.

FSDP compatibility: Background threads must NOT call cross-rank collectives (gather_object, broadcast_object_list, FSDP all-gather) because the main thread may be doing FSDP forward/backward concurrently, causing deadlocks. When num_processes > 1, only rank 0 runs BG generation; results are broadcast to other ranks on the main thread when produce() is next called.
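The prefetch pattern described above can be sketched with the standard library alone. This is a minimal illustration, not the module's implementation: `PrefetchingProducerSketch` and `produce_fn` are hypothetical names, the sketch queues exactly one rollout ahead, and it ignores the multi-rank broadcast logic entirely.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable, Optional


class PrefetchingProducerSketch:
    """Hypothetical stand-in: prefetches the next result of `produce_fn` in a background thread."""

    def __init__(self, produce_fn: Callable[[int], Any]) -> None:
        self._produce_fn = produce_fn
        self._pool = ThreadPoolExecutor(max_workers=1)
        self._pending: Optional[Future] = None

    def produce(self, global_step: int) -> Any:
        if self._pending is None:
            # First call: nothing is queued yet, so generate synchronously.
            result = self._produce_fn(global_step)
        else:
            # Blocks only if the background job has not finished yet.
            result = self._pending.result()
        # Queue the next rollout so it overlaps with training on `result`.
        self._pending = self._pool.submit(self._produce_fn, global_step + 1)
        return result

    def shutdown(self) -> None:
        # Cancel the queued job if it has not started, then stop the worker.
        if self._pending is not None:
            self._pending.cancel()
        self._pool.shutdown(wait=True)


prefetcher = PrefetchingProducerSketch(lambda step: {"rollout_for_step": step})
first = prefetcher.produce(0)   # generated synchronously
second = prefetcher.produce(1)  # taken from the background future
prefetcher.shutdown()
```

Because the second rollout is generated while the caller is still working with the first, it is produced with whatever weights the generator had at that time, which is the staleness the IS correction is meant to address.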

Methods

Name Description
produce Return the next dataset, blocking if the prefetch hasn’t finished.
shutdown Shut down the background thread pool and cancel pending futures.

produce
core.trainers.grpo.async_trainer.AsyncDataProducer.produce(
    model,
    global_step,
    **kwargs,
)

Return the next dataset, blocking if the prefetch hasn’t finished.

shutdown
core.trainers.grpo.async_trainer.AsyncDataProducer.shutdown()

Shut down the background thread pool and cancel pending futures.

AsyncGRPOConfig

core.trainers.grpo.async_trainer.AsyncGRPOConfig(
    use_data_producer=False,
    async_prefetch=False,
    prefetch_depth=1,
    vllm_sync_interval=1,
    batch_flattening=False,
    streaming_partial_batch=False,
    streaming_min_groups=1,
    vllm_importance_sampling_correction=True,
    vllm_importance_sampling_mode='token_truncate',
    vllm_importance_sampling_cap=3.0,
    off_policy_mask_threshold=None,
    use_bias_correction_kl=False,
)

GRPOConfig extended with async prefetch, streaming scoring, and IS correction fields.

Fields that may already be present in stock GRPOConfig (e.g. importance_sampling_level, multi_objective_aggregation) are declared here as a safeguard: if the installed stock version does not define them, the defaults on this class apply.
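To make the vllm_importance_sampling_cap default concrete, here is a small sketch of truncated per-token importance weights. It assumes that 'token_truncate' means the per-token ratio between the policy and sampling probabilities is clipped from above at the cap; the module's actual computation is not shown in this reference, and `truncated_is_weights` is a hypothetical helper that works on plain lists rather than tensors.

```python
import math
from typing import List


def truncated_is_weights(
    policy_logps: List[float],
    sampling_logps: List[float],
    cap: float = 3.0,
) -> List[float]:
    """Per-token weights exp(policy - sampling), truncated from above at `cap`."""
    weights = []
    for new_lp, old_lp in zip(policy_logps, sampling_logps):
        ratio = math.exp(new_lp - old_lp)
        # Assumed truncation rule: never let a single token's weight exceed the cap.
        weights.append(min(ratio, cap))
    return weights


weights = truncated_is_weights(
    policy_logps=[-1.0, -0.5, -0.1],
    sampling_logps=[-1.0, -1.0, -3.0],
)
```

In this example the first token is on-policy (weight 1.0), the second is mildly up-weighted, and the third would have a ratio of about 18 but is truncated to the cap of 3.0.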

AsyncGRPOTrainer

core.trainers.grpo.async_trainer.AsyncGRPOTrainer(*args, **kwargs)

GRPOTrainer with async prefetch, streaming scoring, and IS correction.

Drop-in replacement: pass AsyncGRPOConfig as args and use this trainer instead of GRPOTrainer.

Methods

Name Description
get_off_policy_mask Off-Policy Sequence Mask (OPSM) from DeepSeek-V3.2: drop sequences that have both negative advantage and high KL.

get_off_policy_mask
core.trainers.grpo.async_trainer.AsyncGRPOTrainer.get_off_policy_mask(
    advantages,
    per_token_logps,
    sampling_per_token_logps,
    mask,
    off_policy_threshold,
)

Off-Policy Sequence Mask (OPSM) from DeepSeek-V3.2: drop sequences that have both negative advantage and high KL.
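The masking rule can be illustrated on plain Python lists. This sketch assumes the KL term is estimated per sequence as the mean of (sampling logprob minus policy logprob) over unmasked tokens, and that a sequence is dropped only when its advantage is negative and that estimate exceeds the threshold. `off_policy_keep_mask` is a hypothetical helper, not the method itself.

```python
from typing import List


def off_policy_keep_mask(
    advantages: List[float],
    per_token_logps: List[List[float]],
    sampling_per_token_logps: List[List[float]],
    mask: List[List[int]],
    off_policy_threshold: float,
) -> List[int]:
    """Return 1 for sequences to keep and 0 for sequences to drop."""
    keep = []
    for adv, new_lps, old_lps, m in zip(
        advantages, per_token_logps, sampling_per_token_logps, mask
    ):
        # Assumed KL estimate: masked mean of (sampling logprob - policy logprob).
        n_tokens = max(sum(m), 1)
        avg_kl = sum((o - n) * t for n, o, t in zip(new_lps, old_lps, m)) / n_tokens
        drop = adv < 0 and avg_kl > off_policy_threshold
        keep.append(0 if drop else 1)
    return keep


keep = off_policy_keep_mask(
    advantages=[1.0, -1.0, -1.0],
    per_token_logps=[[-2.0, -2.0], [-2.0, -2.0], [-1.0, -1.0]],
    sampling_per_token_logps=[[-1.0, -1.0], [-1.0, -1.0], [-1.0, -1.0]],
    mask=[[1, 1], [1, 1], [1, 1]],
    off_policy_threshold=0.5,
)
```

Only the second sequence is dropped: the first has a positive advantage, and the third has a negative advantage but no divergence between the sampling and policy logprobs.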

BaseDataProducer

core.trainers.grpo.async_trainer.BaseDataProducer(config=None)

Convenience base class with a default ProducerConfig and lifecycle hooks.

Methods

Name Description
on_rollout_begin Called before each produce() invocation.
on_rollout_end Called after each produce() invocation with the produced dataset.

on_rollout_begin
core.trainers.grpo.async_trainer.BaseDataProducer.on_rollout_begin(global_step)

Called before each produce() invocation.

on_rollout_end
core.trainers.grpo.async_trainer.BaseDataProducer.on_rollout_end(
    dataset,
    global_step,
)

Called after each produce() invocation with the produced dataset.
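A subclass might use the two hooks for logging or bookkeeping. The sketch below uses a hypothetical stand-in base class (`HookedProducerSketch`) so that it runs without the library, and the `run_rollout` driver encodes an assumed call order of begin hook, produce(), end hook.

```python
from typing import Any, Dict, List


class HookedProducerSketch:
    """Hypothetical stand-in for a producer base class with no-op lifecycle hooks."""

    def on_rollout_begin(self, global_step: int) -> None:
        pass

    def on_rollout_end(self, dataset: Any, global_step: int) -> None:
        pass


class LoggingProducer(HookedProducerSketch):
    def __init__(self) -> None:
        self.events: List[str] = []

    def on_rollout_begin(self, global_step: int) -> None:
        self.events.append(f"begin step={global_step}")

    def on_rollout_end(self, dataset: Any, global_step: int) -> None:
        self.events.append(f"end step={global_step} size={len(dataset)}")

    def produce(self, model: Any, global_step: int, **kwargs: Any) -> List[Dict[str, int]]:
        # Toy rollout: four placeholder samples.
        return [{"sample": i} for i in range(4)]


def run_rollout(producer: LoggingProducer, global_step: int) -> List[Dict[str, int]]:
    # Assumed call order: begin hook, produce, end hook.
    producer.on_rollout_begin(global_step)
    dataset = producer.produce(model=None, global_step=global_step)
    producer.on_rollout_end(dataset, global_step)
    return dataset


logging_producer = LoggingProducer()
run_rollout(logging_producer, global_step=3)
```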

DataProducer

core.trainers.grpo.async_trainer.DataProducer()

Abstract base class for online data producers.

Subclass this and implement produce() to supply fresh training data each rollout round.

Methods

Name Description
produce Generate a fresh training dataset.

produce
core.trainers.grpo.async_trainer.DataProducer.produce(
    model,
    global_step,
    *,
    processing_class=None,
    accelerator=None,
    args=None,
    **kwargs,
)

Generate a fresh training dataset.
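As a rough illustration of the subclassing contract, the sketch below defines a hypothetical abstract class with the same produce() signature as documented above and a toy subclass. `DataProducerSketch` and `CountingProducer` are illustrative names; a real producer would return a dataset built from model generations rather than a list of strings.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class DataProducerSketch(ABC):
    """Hypothetical stand-in mirroring the documented produce() signature."""

    @abstractmethod
    def produce(
        self,
        model: Any,
        global_step: int,
        *,
        processing_class: Any = None,
        accelerator: Any = None,
        args: Any = None,
        **kwargs: Any,
    ) -> Any:
        ...


class CountingProducer(DataProducerSketch):
    def __init__(self, samples_per_rollout: int) -> None:
        self.samples_per_rollout = samples_per_rollout

    def produce(self, model, global_step, *, processing_class=None,
                accelerator=None, args=None, **kwargs) -> List[Dict[str, str]]:
        # Toy data: one placeholder prompt per sample, tagged with the step.
        return [
            {"prompt": f"step {global_step} prompt {i}"}
            for i in range(self.samples_per_rollout)
        ]


toy_dataset = CountingProducer(samples_per_rollout=2).produce(model=None, global_step=5)
```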

DataProducerCallback

core.trainers.grpo.async_trainer.DataProducerCallback()

Marker class: if a DataProducer also inherits from this, the Trainer will automatically register it as a callback.
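The marker-class idea can be sketched as a plain isinstance check. All class names below are hypothetical stand-ins, and how the Trainer actually performs the registration is not documented here.

```python
from typing import List


class CallbackMarkerSketch:
    """Hypothetical marker: carries no behaviour, only signals intent."""


class ProducerSketch:
    """Hypothetical stand-in for a data producer."""


class PlainProducer(ProducerSketch):
    pass


class CallbackProducer(ProducerSketch, CallbackMarkerSketch):
    pass


def collect_callbacks(producers: List[ProducerSketch]) -> List[ProducerSketch]:
    # Only producers that also inherit from the marker are registered.
    return [p for p in producers if isinstance(p, CallbackMarkerSketch)]


plain, with_callback = PlainProducer(), CallbackProducer()
registered = collect_callbacks([plain, with_callback])
```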

GRPODataProducer

core.trainers.grpo.async_trainer.GRPODataProducer(
    config,
    prompt_dataset,
    *,
    num_generations,
    generation_batch_size,
    train_batch_size,
    steps_per_generation,
    shuffle_dataset,
    seed,
)

Produces GRPO training rollouts using the trainer’s generation pipeline.

Created before Trainer.__init__ completes; the trainer reference is injected later via set_trainer().

Methods

Name Description
produce Generate a fresh GRPO training rollout.
set_trainer Inject the live trainer reference and create the prompt DataLoader.

produce
core.trainers.grpo.async_trainer.GRPODataProducer.produce(
    model,
    global_step,
    *,
    skip_policy_logps=False,
    processing_class=None,
    accelerator=None,
    args=None,
    _rank0_only=False,
    **kwargs,
)

Generate a fresh GRPO training rollout.

set_trainer
core.trainers.grpo.async_trainer.GRPODataProducer.set_trainer(trainer)

Inject the live trainer reference and create the prompt DataLoader.
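The late-injection pattern can be sketched as follows. `LateBoundProducerSketch` and `FakeTrainer` are hypothetical, the trainer's `generate` method is invented for the example, and a plain list stands in for the prompt DataLoader.

```python
from typing import Any, List, Optional


class FakeTrainer:
    """Hypothetical trainer stub; `generate` is invented for this example."""

    def generate(self, prompt: str) -> str:
        return prompt.upper()


class LateBoundProducerSketch:
    def __init__(self, prompts: List[str]) -> None:
        # Constructed before the trainer exists, so no trainer reference yet.
        self._prompts = list(prompts)
        self._trainer: Optional[Any] = None
        self._prompt_batches: Optional[List[str]] = None

    def set_trainer(self, trainer: Any) -> None:
        # Inject the live trainer and build the prompt source that depends on it.
        self._trainer = trainer
        self._prompt_batches = list(self._prompts)

    def produce(self, global_step: int) -> List[str]:
        if self._trainer is None or self._prompt_batches is None:
            raise RuntimeError("set_trainer() must be called before produce()")
        return [self._trainer.generate(p) for p in self._prompt_batches]


late_producer = LateBoundProducerSketch(["a", "b"])
try:
    late_producer.produce(0)
    failed_early = False
except RuntimeError:
    failed_early = True

late_producer.set_trainer(FakeTrainer())
rollout = late_producer.produce(0)
```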

ProducerConfig

core.trainers.grpo.async_trainer.ProducerConfig(
    mini_epochs=1,
    max_rollouts=None,
    steps_per_generation=None,
    num_iterations=1,
    async_prefetch=False,
    prefetch_depth=1,
    sync_warmup_rollouts=0,
    eval_during_produce=True,
    empty_cache_before_produce=False,
    empty_cache_after_produce=False,
)

Configuration for a DataProducer.

Parameters

Name Type Description Default
mini_epochs int Number of training passes over each produced dataset. 1
max_rollouts int | None Maximum number of produce-then-train rounds (None = unlimited). None
steps_per_generation int | None Optimisation steps per produced dataset before regenerating. None
num_iterations int Number of times to reuse each generation across optimisation steps. 1
async_prefetch bool Produce the next dataset in a background thread. False
prefetch_depth int How many rollouts to queue ahead when async. 1
sync_warmup_rollouts int Initial on-policy rollouts before switching to async. 0
eval_during_produce bool Switch model to eval() during produce(). True
empty_cache_before_produce bool Call torch.cuda.empty_cache() before produce(). False
empty_cache_after_produce bool Call torch.cuda.empty_cache() after produce(). False
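The interaction between mini_epochs and max_rollouts can be sketched as a produce-then-train loop. This is an illustration of the parameter descriptions above, not the trainer's actual loop: `LoopConfigSketch` and `run_rounds` are hypothetical, and `budget` stands in for whatever external stop condition applies when max_rollouts is None.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Tuple


@dataclass
class LoopConfigSketch:
    mini_epochs: int = 1
    max_rollouts: Optional[int] = None


def run_rounds(
    config: LoopConfigSketch,
    produce: Callable[[int], Any],
    train_pass: Callable[[Any], None],
    budget: int,
) -> List[Tuple[int, int]]:
    """Run produce-then-train rounds and return the (rollout, epoch) pairs visited."""
    rounds = budget if config.max_rollouts is None else min(budget, config.max_rollouts)
    visited = []
    for rollout_idx in range(rounds):
        dataset = produce(rollout_idx)
        # Each produced dataset is trained on `mini_epochs` times before regenerating.
        for epoch in range(config.mini_epochs):
            train_pass(dataset)
            visited.append((rollout_idx, epoch))
    return visited


visited = run_rounds(
    LoopConfigSketch(mini_epochs=2, max_rollouts=3),
    produce=lambda idx: [idx],
    train_pass=lambda dataset: None,
    budget=10,
)
```

With mini_epochs=2 and max_rollouts=3, the loop performs three rollouts and six training passes in total, even though the external budget would allow ten rollouts.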

RolloutDataset

core.trainers.grpo.async_trainer.RolloutDataset(data)

A Dataset wrapping the output dict from _generate_and_score_completions.

Per-sample tensors are sliced by index; shared metadata is passed through.
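The slicing behaviour can be sketched with plain lists in place of tensors. `RolloutDatasetSketch` is a hypothetical stand-in; unlike the documented constructor, it takes an explicit `shared_keys` argument because this reference does not say how shared metadata is identified. The key names in the example are illustrative.

```python
from typing import Any, Dict, Set


class RolloutDatasetSketch:
    """Hypothetical stand-in: indexes per-sample values, passes shared values through."""

    def __init__(self, data: Dict[str, Any], shared_keys: Set[str]) -> None:
        self._data = data
        self._shared_keys = shared_keys
        per_sample = [v for k, v in data.items() if k not in shared_keys]
        # The dataset length is the number of samples in any per-sample field.
        self._length = len(per_sample[0]) if per_sample else 0

    def __len__(self) -> int:
        return self._length

    def __getitem__(self, index: int) -> Dict[str, Any]:
        item = {}
        for key, value in self._data.items():
            # Shared metadata is returned unchanged; everything else is sliced.
            item[key] = value if key in self._shared_keys else value[index]
        return item


rollout_ds = RolloutDatasetSketch(
    {
        "completion_ids": [[1, 2], [3, 4], [5, 6]],
        "advantages": [0.5, -0.5, 0.0],
        "num_items_in_batch": 3,
    },
    shared_keys={"num_items_in_batch"},
)
sample = rollout_ds[1]
```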

Functions

Name Description
make_rollout_collator Return a collator that stacks per-sample tensors and passes shared keys through.

make_rollout_collator

core.trainers.grpo.async_trainer.make_rollout_collator(shared_keys)

Return a collator that stacks per-sample tensors and passes shared keys through.
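A collator of this shape can be sketched as a closure over the shared keys. The sketch below is an assumption-laden illustration: `make_collator_sketch` is a hypothetical name, per-sample values are gathered into lists where a tensor-based implementation would presumably stack them, and shared values are taken from the first sample.

```python
from typing import Any, Callable, Dict, Iterable, List


def make_collator_sketch(
    shared_keys: Iterable[str],
) -> Callable[[List[Dict[str, Any]]], Dict[str, Any]]:
    shared = set(shared_keys)

    def collate(samples: List[Dict[str, Any]]) -> Dict[str, Any]:
        batch: Dict[str, Any] = {}
        for key in samples[0]:
            if key in shared:
                # Shared metadata is identical across samples, so pass one copy through.
                batch[key] = samples[0][key]
            else:
                # Per-sample values are gathered along a new leading dimension.
                batch[key] = [s[key] for s in samples]
        return batch

    return collate


collate = make_collator_sketch(["num_items_in_batch"])
batch = collate(
    [
        {"advantages": 0.5, "num_items_in_batch": 2},
        {"advantages": -0.5, "num_items_in_batch": 2},
    ]
)
```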