RLHF (Beta)

Reinforcement Learning from Human Feedback (RLHF) is a method for optimizing a language model using human feedback data.

Overview

The methods covered on this page include, but are not limited to, DPO, IPO, ORPO, KTO, GRPO, and SimPO.

RLHF using Axolotl

Important

This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.

We rely on the TRL library for implementations of various RL training methods, which we wrap and expose in axolotl. Each method has its own supported dataset loading options and prompt formats.

Tip

To see what each method supports, look in src/axolotl/prompt_strategies/{method}, where {method} is one of our supported methods. The value for type: can be retrieved from {method}.{function_name}.

DPO

Example config:

rl: dpo
datasets:
  - path: Intel/orca_dpo_pairs
    split: train
    type: chatml.intel
  - path: argilla/ultrafeedback-binarized-preferences
    split: train
    type: chatml

DPO supports the following types with the following dataset format:

chatml.argilla

{
    "system": "...", // optional
    "instruction": "...",
    "chosen_response": "...",
    "rejected_response": "..."
}

chatml.argilla_chat

{
    "chosen": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ],
    "rejected": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ]
}

chatml.icr

{
    "system": "...", // optional
    "input": "...",
    "chosen": "...",
    "rejected": "..."
}

chatml.intel

{
    "system": "...", // optional
    "question": "...",
    "chosen": "...",
    "rejected": "..."
}
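As a rough illustration of what a chatml.intel-style strategy does with such a record, the sketch below turns one row into prompt, chosen, and rejected strings. The helper name intel_to_chatml and the exact ChatML layout are assumptions made for illustration; the real strategy lives under src/axolotl/prompt_strategies and may differ in detail.

```python
def intel_to_chatml(sample: dict) -> dict:
    """Hypothetical helper: format one chatml.intel-style row as ChatML strings."""
    prompt = ""
    # The system field is optional, so only emit a system turn when it is set.
    if sample.get("system"):
        prompt += f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
    prompt += f"<|im_start|>user\n{sample['question']}<|im_end|>\n"
    # Leave the assistant turn open so chosen/rejected can complete it.
    prompt += "<|im_start|>assistant\n"
    return {
        "prompt": prompt,
        "chosen": f"{sample['chosen']}<|im_end|>",
        "rejected": f"{sample['rejected']}<|im_end|>",
    }


row = {"question": "What is 2 + 2?", "chosen": "4", "rejected": "5"}
formatted = intel_to_chatml(row)
```

The other chatml.* and llama3.* types follow the same idea with different field names and turn delimiters.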

chatml.prompt_pairs

{
    "system": "...", // optional
    "prompt": "...",
    "chosen": "...",
    "rejected": "..."
}

chatml.ultra

{
    "system": "...", // optional
    "prompt": "...",
    "chosen": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ],
    "rejected": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ]
}

llama3.argilla

{
    "system": "...", // optional
    "instruction": "...",
    "chosen_response": "...",
    "rejected_response": "..."
}

llama3.argilla_chat

{
    "chosen": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ],
    "rejected": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ]
}

llama3.icr

{
    "system": "...", // optional
    "input": "...",
    "chosen": "...",
    "rejected": "..."
}

llama3.intel

{
    "system": "...", // optional
    "question": "...",
    "chosen": "...",
    "rejected": "..."
}

llama3.prompt_pairs

{
    "system": "...", // optional
    "prompt": "...",
    "chosen": "...",
    "rejected": "..."
}

llama3.ultra

{
    "system": "...", // optional
    "prompt": "...",
    "chosen": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ],
    "rejected": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ]
}

zephyr.nectar

{
    "prompt": "...",
    "answers": [
        {
            "answer": "...",
            "rank": 1
        },
        {
            "answer": "...",
            "rank": 2
        }
        // ... more answers with ranks
    ]
}
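Because this format carries a ranked list rather than an explicit pair, a strategy has to pick which answers become chosen and rejected. The sketch below shows one plausible rule, best-ranked versus worst-ranked, and assumes that rank 1 is the best answer. The helper ranked_to_pair is hypothetical, and the selection rule used by the actual zephyr.nectar strategy may differ.

```python
def ranked_to_pair(sample: dict) -> dict:
    """Hypothetical helper: derive a preference pair from rank-annotated answers."""
    # Assumption: a lower rank number means a better answer (rank 1 is best).
    answers = sorted(sample["answers"], key=lambda a: a["rank"])
    if len(answers) < 2:
        raise ValueError("need at least two ranked answers to form a pair")
    return {
        "prompt": sample["prompt"],
        "chosen": answers[0]["answer"],
        "rejected": answers[-1]["answer"],
    }


sample = {
    "prompt": "Name a prime number.",
    "answers": [
        {"answer": "9", "rank": 3},
        {"answer": "7", "rank": 1},
        {"answer": "Maybe 7?", "rank": 2},
    ],
}
pair = ranked_to_pair(sample)
```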

chat_template.default

rl: dpo
datasets:
  - path: ...
    split: train
    type: chat_template.default
    field_messages: "messages"
    field_chosen: "chosen"
    field_rejected: "rejected"
    message_property_mappings:
      role: role
      content: content
    roles:
      user: ["user"]
      assistant: ["assistant"]
      system: ["system"]

Sample input format:

{
    "messages": [
        {
            "role": "system",
            "content": "..."
        },
        {
            "role": "user",
            "content": "..."
        },
        // ... more messages
    ],
    "chosen": {
        "role": "assistant",
        "content": "..."
    },
    "rejected": {
        "role": "assistant",
        "content": "..."
    }
}
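Conceptually, each record of this shape describes two conversations that share the same history and differ only in the final assistant reply. The sketch below makes that explicit; build_conversations is a hypothetical helper for inspecting data, not part of axolotl, and it does not apply any chat template.

```python
def build_conversations(sample: dict) -> tuple[list[dict], list[dict]]:
    """Hypothetical helper: shared history plus each candidate final reply."""
    history = list(sample["messages"])
    chosen_conv = history + [sample["chosen"]]
    rejected_conv = history + [sample["rejected"]]
    return chosen_conv, rejected_conv


sample = {
    "messages": [
        {"role": "system", "content": "You are concise."},
        {"role": "user", "content": "What is the capital of France?"},
    ],
    "chosen": {"role": "assistant", "content": "Paris."},
    "rejected": {"role": "assistant", "content": "France is a country in Europe."},
}
chosen_conv, rejected_conv = build_conversations(sample)
```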

user_defined.default

For custom behaviors, use the user_defined.default type and map the dataset fields and format strings in the config:

rl: dpo
datasets:
  - path: ...
    split: train
    type: user_defined.default

    field_prompt: "prompt"
    field_system: "system"
    field_chosen: "chosen"
    field_rejected: "rejected"
    prompt_format: "{prompt}"
    chosen_format: "{chosen}"
    rejected_format: "{rejected}"

The input is a simple JSON object whose field names are customizable through the config above.

{
    "system": "...",  // optional
    "prompt": "...",
    "chosen": "...",
    "rejected": "..."
}
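To make the role of the field_* and *_format options concrete, here is a minimal sketch of how such a mapping could be applied to one record. The helper apply_user_defined, the example field names, and the example template are hypothetical; this is an illustration of the idea, not axolotl's implementation, and it ignores the system field.

```python
def apply_user_defined(sample: dict, cfg: dict) -> dict:
    """Hypothetical helper: read the configured fields and fill the format strings."""
    return {
        "prompt": cfg["prompt_format"].format(prompt=sample[cfg["field_prompt"]]),
        "chosen": cfg["chosen_format"].format(chosen=sample[cfg["field_chosen"]]),
        "rejected": cfg["rejected_format"].format(rejected=sample[cfg["field_rejected"]]),
    }


# Example config: the dataset uses non-default column names.
cfg = {
    "field_prompt": "question",
    "field_chosen": "good",
    "field_rejected": "bad",
    "prompt_format": "Question: {prompt}\nAnswer: ",
    "chosen_format": "{chosen}",
    "rejected_format": "{rejected}",
}
sample = {"question": "Capital of France?", "good": "Paris", "bad": "Lyon"}
formatted = apply_user_defined(sample, cfg)
```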

IPO

As IPO is just DPO with a different loss function, all supported dataset formats for DPO are also supported for IPO.

rl: ipo
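The difference between the two losses can be sketched on a single preference pair. Both start from the same margin, the policy's log-ratio between chosen and rejected minus the reference model's log-ratio; DPO passes it through a log-sigmoid, while IPO regresses it toward 1/(2*beta). The function names below are ours, and the sketch ignores batching and any length normalization a trainer may apply.

```python
import math


def preference_margin(policy_chosen: float, policy_rejected: float,
                      ref_chosen: float, ref_rejected: float) -> float:
    """Policy log-ratio minus reference log-ratio for one pair of sequence log-probs."""
    return (policy_chosen - policy_rejected) - (ref_chosen - ref_rejected)


def dpo_loss(margin: float, beta: float = 0.1) -> float:
    """DPO: -log(sigmoid(beta * margin)), written in a numerically stable form."""
    z = beta * margin
    return math.log1p(math.exp(-z)) if z >= 0 else -z + math.log1p(math.exp(z))


def ipo_loss(margin: float, beta: float = 0.1) -> float:
    """IPO: squared distance between the margin and the target 1 / (2 * beta)."""
    return (margin - 1.0 / (2.0 * beta)) ** 2


margin = preference_margin(-10.0, -12.0, -11.0, -11.5)  # (2.0) - (0.5) = 1.5
losses = {"dpo": dpo_loss(margin), "ipo": ipo_loss(margin)}
```

Note that the DPO loss keeps decreasing as the margin grows, whereas the IPO loss is minimized at a finite margin, which is the regularizing effect IPO is designed for.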

ORPO

Paper: https://arxiv.org/abs/2403.07691

rl: orpo
orpo_alpha: 0.1
remove_unused_columns: false

chat_template: chatml
datasets:
  - path: argilla/ultrafeedback-binarized-preferences-cleaned
    type: chat_template.argilla

ORPO supports the following types with the following dataset format:

chat_template.argilla

{
    "system": "...",  // optional
    "prompt": "...",  // if available, will be taken as user message for single-turn instead of from list below

    // chosen/rejected should be same till last content and only even-number of alternating user/assistant turns
    "chosen": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ],
    "rejected": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ]
}
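The constraints in the comment above (an even number of alternating user/assistant turns, with chosen and rejected identical up to the last message) are easy to get wrong when preparing data. The following is a small pre-flight check you could run over a dataset; is_valid_orpo_pair is a hypothetical helper, not part of axolotl.

```python
def is_valid_orpo_pair(sample: dict) -> bool:
    """Hypothetical check for the chosen/rejected constraints described above."""
    chosen, rejected = sample["chosen"], sample["rejected"]
    for turns in (chosen, rejected):
        # Require a non-empty, even number of turns.
        if len(turns) == 0 or len(turns) % 2 != 0:
            return False
        # Turns must alternate user, assistant, user, assistant, ...
        for i, turn in enumerate(turns):
            if turn["role"] != ("user", "assistant")[i % 2]:
                return False
    # Everything before the final assistant reply must be identical.
    return chosen[:-1] == rejected[:-1]


good = {
    "chosen": [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello! How can I help?"},
    ],
    "rejected": [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "What?"},
    ],
}
ok = is_valid_orpo_pair(good)
```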

KTO

rl: kto
rl_beta: 0.1  # default
kto_desirable_weight: 1.0  # default
kto_undesirable_weight: 1.0  # default

remove_unused_columns: false

datasets:
  - path: argilla/ultrafeedback-binarized-preferences-cleaned-kto
    type: llama3.ultra
    split: train

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true

KTO supports the following types with the following dataset format:

chatml.argilla

{
    "system": "...", // optional
    "instruction": "...",
    "completion": "..."
}

chatml.argilla_chat

{
    "chosen": [
        {"role": "user", "content": "..."}
    ],
    "completion": [
        {"role": "assistant", "content": "..."}
    ]
}

chatml.intel

{
    "system": "...", // optional
    "question": "...",
    "completion": "..."
}

chatml.prompt_pairs

{
    "system": "...", // optional
    "prompt": "...",
    "completion": "..."
}

chatml.ultra

{
    "system": "...", // optional
    "prompt": "...",
    "completion": "..."
}

llama3.argilla

{
    "system": "...", // optional
    "instruction": "...",
    "completion": "..."
}

llama3.argilla_chat

{
    "completion": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ]
}

llama3.intel

{
    "system": "...", // optional
    "question": "...",
    "completion": "..."
}

llama3.prompt_pairs

{
    "system": "...", // optional
    "prompt": "...",
    "completion": "..."
}

llama3.ultra

{
    "system": "...", // optional
    "prompt": "...",
    "completion": "..."
}

user_defined.default

For custom behaviors, use the user_defined.default type and map the dataset fields and format strings in the config:

rl: kto
datasets:
  - path: ...
    split: train
    type: user_defined.default

    field_prompt: "prompt"
    field_system: "system"
    field_completion: "completion"
    field_label: "label"
    prompt_format: "{prompt}"
    completion_format: "{completion}"

The input is a simple JSON object whose field names are customizable through the config above.

{
    "system": "...",  // optional
    "prompt": "...",
    "completion": "...",
    "label": "..."
}
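Unlike DPO, KTO works on unpaired examples: each record is a single completion plus a label saying whether it is desirable. If you already have paired preference data, one way to reuse it is to split each pair into two labeled records, as sketched below. The helper pair_to_kto is hypothetical, and the sketch assumes a boolean label (true for desirable), which is the convention in TRL's KTO dataset format.

```python
def pair_to_kto(sample: dict) -> list[dict]:
    """Hypothetical helper: split one chosen/rejected pair into two KTO records."""
    base = {"prompt": sample["prompt"]}
    # Carry the optional system field over only when it is present.
    if sample.get("system"):
        base["system"] = sample["system"]
    return [
        {**base, "completion": sample["chosen"], "label": True},     # desirable
        {**base, "completion": sample["rejected"], "label": False},  # undesirable
    ]


pair = {"prompt": "Capital of France?", "chosen": "Paris", "rejected": "Lyon"}
records = pair_to_kto(pair)
```

If the resulting data is unbalanced, kto_desirable_weight and kto_undesirable_weight from the config above can be used to reweight the two classes.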

GRPO

Tip

Check out our GRPO cookbook.

If you have multiple GPUs available, we recommend using vLLM with the GRPOTrainer to significantly speed up trajectory generation during training. First, launch a vLLM server; the command below uses axolotl vllm-serve, and you may use a config file or CLI overrides to configure the server. In this example, we’re using 4 GPUs: 2 for training and 2 for vLLM.

Important

Make sure you’ve installed the correct version of vLLM by including it as an extra when installing axolotl, e.g. pip install axolotl[vllm].

base_model: Qwen/Qwen2.5-1.5B-Instruct

vllm:
    host: 0.0.0.0
    port: 8000
    tensor_parallel_size: 2
    gpu_memory_utilization: 0.85
    dtype: auto
    # max_model_len: # you may find it useful to set the vLLM model context length if you know this beforehand

rl: grpo
trl:
    use_vllm: true
    vllm_server_host: 0.0.0.0
    vllm_server_port: 8000
    vllm_server_timeout: 300

CUDA_VISIBLE_DEVICES=2,3 axolotl vllm-serve grpo.yaml

Your vLLM instance will now attempt to spin up, and it’s time to kick off training utilizing our remaining two GPUs. In another terminal, execute:

CUDA_VISIBLE_DEVICES=0,1 axolotl train grpo.yaml --num-processes 2
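If you script the two steps, it can help to wait until the vLLM port accepts connections before launching training. The sketch below is a generic TCP check using only the standard library; wait_for_port is a hypothetical helper, and an open port only indicates that the server process is listening, not necessarily that the model has finished loading.

```python
import socket
import time


def wait_for_port(host: str, port: int, timeout: float = 300.0, interval: float = 2.0) -> bool:
    """Return True once host:port accepts a TCP connection, False after timeout seconds."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            # A successful connect means something is listening on the port.
            with socket.create_connection((host, port), timeout=interval):
                return True
        except OSError:
            # Not up yet (refused or timed out); retry after a short pause.
            time.sleep(interval)
    return False


# Example (matches the port in the config above):
# wait_for_port("127.0.0.1", 8000, timeout=300)
```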

Reward functions

GRPO uses custom reward functions and dataset transformations. Define them in a local Python file so that the config can reference them as '{file_name}.{fn_name}'.

For example, to load OpenAI’s GSM8K and use a random reward for completions:

# rewards.py
import random

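def rand_reward_func(completions, **kwargs) -> list[float]:
    # Placeholder reward: one random score between 0 and 1 per completion.
    return [random.uniform(0, 1) for _ in completions]

def oai_gsm8k_transform(cfg, *args, **kwargs):
    def transform_fn(example, tokenizer=None):
        # GSM8K answers end with "#### <number>"; keep only that number, without thousands separators.
        label = example["answer"].split("####")[-1].strip().replace(",", "")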
        return {
            "prompt": [{"role": "user", "content": example["question"]},],
            "answer": label,
        }
    return transform_fn, {"remove_columns": ["question"]}

rl: grpo

trl:
    beta: 0.001
    max_completion_length: 256
    use_vllm: True
    num_generations: 4
    reward_funcs: ["rewards.rand_reward_func"]    # format: '{file_name}.{fn_name}'
    reward_weights: [1.0]
datasets:
  - path: openai/gsm8k
    name: main
    type: rewards.oai_gsm8k_transform  # format: '{file_name}.{fn_name}'
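A random reward is only useful for smoke-testing the pipeline. As a sketch of something closer to a real signal, the function below scores a completion 1.0 when the last number it contains matches the answer column produced by the transform above. The name gsm8k_correct_reward and the matching rule are our own assumptions; it relies on TRL passing dataset columns such as answer to reward functions as keyword arguments, and on conversational completions arriving as lists of message dicts.

```python
import re


def gsm8k_correct_reward(completions, answer, **kwargs) -> list[float]:
    """Hypothetical reward: 1.0 if the last number in the completion equals the label."""
    rewards = []
    for completion, label in zip(completions, answer):
        # Conversational prompts yield completions as lists of message dicts;
        # plain-text prompts yield strings.
        text = completion[0]["content"] if isinstance(completion, list) else completion
        # Drop thousands separators, then collect integers and decimals.
        numbers = re.findall(r"-?\d+(?:\.\d+)?", text.replace(",", ""))
        rewards.append(1.0 if numbers and numbers[-1] == label else 0.0)
    return rewards


completions = [
    [{"role": "assistant", "content": "She sells 9 eggs at $2 each, so the answer is 18."}],
    [{"role": "assistant", "content": "I think it is 20"}],
]
scores = gsm8k_correct_reward(completions, answer=["18", "18"])
```

To use it, add it to rewards.py and list it in reward_funcs as "rewards.gsm8k_correct_reward", with a matching entry in reward_weights.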

For more examples of custom reward functions, see the TRL GRPO docs.

For a description of the config options, see TRLConfig.

SimPO

SimPO uses CPOTrainer with an alternative loss function.

rl: simpo
rl_beta: 0.1  # default in CPOTrainer
cpo_alpha: 1.0  # default in CPOTrainer
simpo_gamma: 0.5  # default in CPOTrainer

This method uses the same dataset format as DPO.
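For intuition about how rl_beta and simpo_gamma enter the objective, here is a sketch of the SimPO preference term for a single pair. It is reference-free: it compares length-normalized (average per-token) log-probabilities under the policy only and requires the chosen response to win by a margin gamma. The function name is ours, and the sketch covers only the preference term, not any additional term a trainer may add.

```python
import math


def simpo_preference_loss(avg_logp_chosen: float, avg_logp_rejected: float,
                          beta: float = 0.1, gamma: float = 0.5) -> float:
    """-log(sigmoid(beta * (chosen - rejected) - gamma)) in a numerically stable form."""
    z = beta * (avg_logp_chosen - avg_logp_rejected) - gamma
    return math.log1p(math.exp(-z)) if z >= 0 else -z + math.log1p(math.exp(z))


# Average per-token log-probs of the two responses under the current policy.
loss = simpo_preference_loss(-0.8, -1.4)
```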

Using local dataset files

datasets:
  - ds_type: json
    data_files:
      - orca_rlhf.jsonl
    split: train
    type: chatml.intel
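A local file for this config is plain JSON Lines: one JSON object per line, with the fields expected by the chosen type (here chatml.intel). The sketch below writes such a file with the standard library; write_jsonl and the sample rows are hypothetical, and the file name simply mirrors the config above.

```python
import json
from pathlib import Path


def write_jsonl(rows: list[dict], path: str) -> int:
    """Write one JSON object per line and return the number of rows written."""
    with Path(path).open("w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")
    return len(rows)


rows = [
    {
        "system": "You are a helpful assistant.",  # optional
        "question": "What is 2 + 2?",
        "chosen": "2 + 2 = 4.",
        "rejected": "2 + 2 = 5.",
    },
]

# Example:
# write_jsonl(rows, "orca_rlhf.jsonl")
```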

TRL auto-unwrapping for PEFT

TRL supports auto-unwrapping PEFT models for RL training paradigms that rely on a reference model. This significantly reduces memory pressure, because an additional reference model does not need to be loaded: reference model log-probabilities can be obtained by disabling the PEFT adapters. This is enabled by default. To turn it off, pass the following config:

# load ref model when adapter training.
rl_adapter_ref_model: true