integrations.nemo_gym.multi_turn
Multi-turn rollout function for NeMo Gym environments.
Delegates multi-turn orchestration to NeMo Gym’s agent servers via the /run endpoint. The agent handles generation (by calling our vLLM server), tool execution, session management, and reward computation.
This follows the same pattern as TRL’s reference implementation at examples/scripts/nemo_gym/train_multi_environment.py.
Architecture
    rollout_func(prompts, trainer)
      -> expand prompts by num_generations
      -> async POST /run to agent servers (one per sample)
      -> parse response: prompt_ids, completion_ids, logprobs, env_mask, reward
      -> return to TRL for GRPO training
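The flow above can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not the module's actual implementation: `expand_prompts`, `run_rollouts`, `fake_post_run`, and the `{"prompt": ...}` payload shape are hypothetical, and the real HTTP call to /run is replaced by an injected coroutine so the example runs without a server. Only the five response field names come from the description above.

```python
import asyncio
from typing import Any, Awaitable, Callable

# Hypothetical transport type: takes a request payload, returns the agent's parsed reply.
PostRun = Callable[[dict[str, Any]], Awaitable[dict[str, Any]]]

# Field names taken from the architecture description.
FIELDS = ("prompt_ids", "completion_ids", "logprobs", "env_mask", "reward")


def expand_prompts(prompts: list[str], num_generations: int) -> list[str]:
    """Repeat each prompt num_generations times, keeping repeats adjacent (assumed ordering)."""
    return [p for p in prompts for _ in range(num_generations)]


async def run_rollouts(
    prompts: list[str], num_generations: int, post_run: PostRun
) -> dict[str, list[Any]]:
    """Issue one request per expanded sample and collect the replies column-wise."""
    expanded = expand_prompts(prompts, num_generations)
    # All requests are in flight concurrently; gather preserves input order.
    responses = await asyncio.gather(*(post_run({"prompt": p}) for p in expanded))
    return {field: [r[field] for r in responses] for field in FIELDS}


async def fake_post_run(payload: dict[str, Any]) -> dict[str, Any]:
    """Stand-in for an HTTP POST to /run; returns fixed dummy values."""
    await asyncio.sleep(0)  # yield control, as a real network call would
    return {
        "prompt_ids": [1, 2],
        "completion_ids": [3, 4, 5],
        "logprobs": [-0.1, -0.2, -0.3],
        "env_mask": [1, 1, 0],
        "reward": float(len(payload["prompt"])),  # dummy reward: prompt length
    }


batch = asyncio.run(run_rollouts(["a", "bb"], 2, fake_post_run))
print(batch["reward"])
```

Injecting the transport keeps the expansion and parsing steps testable in isolation; in a real deployment the coroutine would wrap an HTTP client call against the agent URL.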
Functions
| Name | Description |
|---|---|
| create_nemo_gym_rollout_func | Create a TRL-compatible rollout_func that delegates to NeMo Gym agents. |
create_nemo_gym_rollout_func
integrations.nemo_gym.multi_turn.create_nemo_gym_rollout_func(
    agent_servers,
    dataset_lookup,
    request_timeout=10800,
)
Create a TRL-compatible rollout_func that delegates to NeMo Gym agents.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| agent_servers | dict[str, str] | Mapping of agent_name → agent URL (e.g., {"simple_agent": "http://host:port"}). | required |
| dataset_lookup | dict[int, dict] | Mapping of dataset index → full JSONL row dict. | required |
| request_timeout | float | HTTP timeout for /run requests, presumably in seconds (the default of 10800 would then be 3 hours). | 10800 |
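As an illustration of the argument shapes in the table above, the following sketch builds a `dataset_lookup` from JSONL text and a matching `agent_servers` mapping. The helper `build_dataset_lookup`, the 0-based row indexing, the `question` field, and the `localhost:8000` URL are all assumptions made for the example; the final call is left commented out because it needs the package installed and a running agent server.

```python
import json


def build_dataset_lookup(jsonl_lines: list[str]) -> dict[int, dict]:
    """Map each non-blank row's position (0-based, an assumption) to its parsed JSON object."""
    lookup: dict[int, dict] = {}
    for line in jsonl_lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines so indices stay contiguous
        lookup[len(lookup)] = json.loads(line)
    return lookup


# Hypothetical rows; real NeMo Gym datasets define their own fields.
raw = [
    '{"question": "2 + 2?"}',
    "",
    '{"question": "Capital of France?"}',
]
dataset_lookup = build_dataset_lookup(raw)

# Placeholder host and port; substitute the address of your agent server.
agent_servers = {"simple_agent": "http://localhost:8000"}

# With the package importable, the factory would then be called as documented:
# rollout_func = create_nemo_gym_rollout_func(
#     agent_servers=agent_servers,
#     dataset_lookup=dataset_lookup,
#     request_timeout=10800,
# )
```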
Returns
| Name | Type | Description |
|---|---|---|
| | | A rollout_func with signature (prompts: list[str], trainer) -> dict. |