EBFT Training

Energy-Based Fine-Tuning uses feature-matching rewards from internal representations to train language models without external reward functions.

Overview

Energy-Based Fine-Tuning (EBFT) is a training method that optimizes language models by matching the internal feature representations of generated text to those of ground-truth completions. Instead of relying on external reward models or hand-crafted reward functions, EBFT extracts hidden states from intermediate layers of a frozen copy of the model and uses cosine similarity between generated and reference features as the reward signal.

Paper: “Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models” (Jelassi et al., 2026)

How EBFT Differs from Other RL Methods

| Method | Reward Signal | Requires | Best For |
|--------|---------------|----------|----------|
| GRPO | External reward function(s) | Custom reward code or reward model | Tasks with verifiable answers (math, code) |
| DPO | Preference pairs (chosen vs. rejected) | Paired preference data | Alignment with human preferences |
| EBFT | Feature similarity to ground truth | Ground-truth completions | Any task with reference outputs |

EBFT’s key advantage is that it needs only ground-truth completions – no reward engineering, no preference annotation, and no reward model training. The model’s own internal representations serve as the reward signal. This makes it particularly effective for:

  • Code generation (match features of known-good solutions)
  • Instruction following with reference outputs
  • Continual pretraining on unstructured text (strided mode)
  • Multi-turn dialogue with reference conversations

Reward Formulation

The EBFT reward for each generated completion is:

reward = alignment_coef * cosine_similarity(gen_features, gt_features)
       - diversity_coef * mean_pairwise_similarity(gen_features)

  • Alignment: How closely the generated output’s internal representations match the ground truth. Higher is better.
  • Diversity: Penalizes generated samples that are too similar to each other (prevents mode collapse). Lower pairwise similarity is better.
  • CFM loss (Cross-Feature Matching): Tracks ||mean(gen_features) - gt_features||^2, the quantity EBFT ultimately minimizes. It is logged as a diagnostic rather than optimized directly.
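As a concrete sketch, the per-group reward and CFM diagnostic can be computed as follows. This is illustrative only: the function and variable names are not the library's API, and the actual implementation normalizes and batches differently.

```python
import torch
import torch.nn.functional as F

def ebft_rewards(gen_feats, gt_feats, alignment_coef=1.0, diversity_coef=1.0):
    """Illustrative EBFT reward for one prompt group (not the library code).

    gen_feats: (n, d) embeddings of the n generations for one prompt.
    gt_feats:  (d,)   embedding of the ground-truth completion.
    """
    gen = F.normalize(gen_feats, dim=-1)   # unit-normalize each sample
    gt = F.normalize(gt_feats, dim=-1)

    alignment = gen @ gt                   # (n,) cosine similarity to GT

    # Mean pairwise similarity among the n generations (self-pairs excluded)
    sim = gen @ gen.T
    sim = sim - torch.diag(sim.diagonal())
    diversity = sim.sum(dim=-1) / max(gen.shape[0] - 1, 1)

    reward = alignment_coef * alignment - diversity_coef * diversity

    # CFM diagnostic: squared distance between the mean generated feature and GT
    cfm = (gen.mean(dim=0) - gt).pow(2).sum()
    return reward, cfm
```

When all generations are identical and match the ground truth, alignment and diversity cancel and the CFM diagnostic is zero, which is why a healthy run shows alignment rising while diversity stays small.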

Modes

EBFT supports three operational modes, each suited to different use cases.

Structured Mode (Sync)

Uses vLLM on a separate GPU for generation, with sequential generate-score-train steps. This is the simplest mode and recommended for getting started.

GPU 0: vLLM Server (generates completions, receives weight syncs)
GPU 1: Trainer (feature extraction, reward computation, GRPO training)

When to use: Standard instruction-following or QA datasets where you have prompt/completion pairs. Requires 2 GPUs.

Structured Mode (Async)

Same architecture as sync, but overlaps generation of the next batch with training on the current batch. Faster throughput at the cost of slightly stale weights during generation.

When to use: Same data as sync mode, but when you want faster training and can tolerate weight staleness (controlled by vllm_sync_interval).

Strided Mode

Runs entirely on a single GPU with no vLLM dependency. Places anchor points throughout a document and generates short rollouts at each anchor using block-parallel attention patterns.

Single GPU: Base model + LoRA adapter
  - Strided block-parallel generation (flex_attention)
  - Feature extraction via disable_adapter()
  - No vLLM needed

When to use: Unstructured text data (raw code, prose, documents) where there is no natural prompt/completion split. Also works with structured data that includes prompt boundaries. Requires only 1 GPU.

Quick Start

Structured Mode

This minimal example fine-tunes Qwen2-0.5B on code data using EBFT with vLLM generation.

Step 1: Create a config file ebft_quickstart.yaml:

base_model: Qwen/Qwen2-0.5B-Instruct

rl: ebft

ebft:
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  alignment_coef: 1.0
  diversity_coef: 1.0

trl:
  num_generations: 4
  max_completion_length: 256
  temperature: 0.7
  use_vllm: true
  vllm_server_host: 0.0.0.0
  vllm_server_port: 8000
  vllm_lora_sync: true
  vllm_sync_interval: 3
  use_data_producer: true
  async_prefetch: false
  scale_rewards: true
  loss_type: grpo

vllm:
  gpu_memory_utilization: 0.5
  max_model_len: 1024

datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_opencode.transform
    split: train[:500]

# Standard training settings (see getting-started.qmd for details)
adapter: lora
lora_r: 16
lora_alpha: 32
lora_target_linear: true
sequence_len: 1024
micro_batch_size: 4
gradient_accumulation_steps: 4
max_steps: 20
learning_rate: 5.0e-6
bf16: auto
flash_attention: true
gradient_checkpointing: true
output_dir: ./outputs/ebft-quickstart

Step 2: Start vLLM on GPU 0:

CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve ebft_quickstart.yaml

Step 3: Wait approximately 30 seconds for vLLM to initialize, then start training on GPU 1:

CUDA_VISIBLE_DEVICES=1 axolotl train ebft_quickstart.yaml

Important

The micro_batch_size must be divisible by num_generations. For example, with num_generations: 4, valid values are 4, 8, 12, etc.

Dataset Format

Structured mode datasets must produce two fields after the transform:

  • prompt: Either a string or a list of chat messages ([{"role": "user", "content": "..."}])
  • ground_truth: A string containing the reference completion

Example raw dataset row:

{
  "input": "Write a function to compute fibonacci numbers.",
  "output": "def fibonacci(n):\n    if n <= 1:\n        return n\n    return fibonacci(n-1) + fibonacci(n-2)"
}

The ebft_opencode.transform converts this to the required {prompt, ground_truth} format automatically.

Feature Extraction

EBFT extracts hidden states from intermediate transformer layers and pools them into per-sequence embeddings. These embeddings are compared between generated and ground-truth completions to compute rewards.

Feature Layers

The feature_layers parameter specifies which layers to extract, as fractions of total model depth:

ebft:
  feature_layers: [0.25, 0.5, 0.75]  # Quarter, middle, three-quarter depth

For a 32-layer model, this extracts layers 8, 16, and 24. The hidden states from all selected layers are concatenated along the feature dimension, producing embeddings of size len(feature_layers) * hidden_dim.

Tip

Using multiple layers captures both low-level syntactic features (early layers) and high-level semantic features (later layers). The default [0.25, 0.5, 0.75] works well across model sizes.
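The fraction-to-index mapping can be sketched as below. The exact rounding rule used by the trainer may differ, and resolve_feature_layers is a hypothetical name for illustration:

```python
def resolve_feature_layers(fractions, num_hidden_layers):
    # Map fractional depths to concrete layer indices (simple floor rule)
    return [int(f * num_hidden_layers) for f in fractions]

# For a 32-layer model with the default fractions:
print(resolve_feature_layers([0.25, 0.5, 0.75], 32))  # -> [8, 16, 24]
```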

Embed Methods

The embed_method controls how per-token hidden states are pooled into a single vector per sequence:

| Method | Description | Output Shape | Notes |
|--------|-------------|--------------|-------|
| last_token | Hidden state at the last non-padding token | (B, D) | Default. Good for autoregressive models where the last token summarizes the sequence. |
| mean_pooling | Mean of all non-padding token states | (B, D) | Considers the entire sequence equally. |
| completion_mean | Mean over completion tokens only (excludes prompt) | (B, D) | Focuses the reward signal on generated content. Requires prompt length information. |
| concat | Concatenation of states at the 25%, 50%, and 75% positions | (B, 3*D) | Captures positional structure. Higher dimensional. |

ebft:
  embed_method: completion_mean  # Focus on completion features
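A minimal sketch of the four pooling options, assuming (B, T, D) hidden states and a (B, T) padding mask. The library's masking and prompt-length handling may differ:

```python
import torch

def pool_hidden_states(hidden, mask, method="last_token", prompt_len=None):
    """Pool per-token states into one vector per sequence (illustrative).

    hidden: (B, T, D) hidden states; mask: (B, T) with 1 for real tokens.
    """
    if method == "last_token":
        last = mask.sum(dim=1) - 1                      # index of last real token
        return hidden[torch.arange(hidden.size(0)), last]
    if method == "mean_pooling":
        m = mask.unsqueeze(-1).float()
        return (hidden * m).sum(dim=1) / m.sum(dim=1).clamp(min=1)
    if method == "completion_mean":
        comp = mask.clone()
        comp[:, :prompt_len] = 0                        # drop prompt tokens
        m = comp.unsqueeze(-1).float()
        return (hidden * m).sum(dim=1) / m.sum(dim=1).clamp(min=1)
    if method == "concat":
        T = hidden.size(1)
        idx = [int(T * p) for p in (0.25, 0.5, 0.75)]
        return torch.cat([hidden[:, i] for i in idx], dim=-1)  # (B, 3*D)
    raise ValueError(f"unknown embed_method: {method}")
```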

SVD Whitening

Whitening decorrelates the feature dimensions so that no single direction dominates the feature-matching loss. This is computed via SVD on the generated embeddings, with the same transform applied to the ground-truth embeddings.

ebft:
  use_whitening: true

When whitening is enabled, the reward computation applies a whitening matrix W = U @ diag(1/S) @ U^T derived from the SVD of generated embeddings. This ensures all feature dimensions contribute equally to the alignment reward.

Note

Singular values scale with sqrt(batch_size), so reward magnitudes are batch-size dependent. This is acceptable because the number of samples per prompt (n_samples_per_prompt or num_generations) is fixed during training.
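A sketch of the whitening step, following the W = U @ diag(1/S) @ U^T form above with U taken as the right singular vectors of the generated-embedding matrix. This reading is an assumption; the library's implementation may center or regularize differently:

```python
import torch

def whiten(gen_emb, gt_emb, eps=1e-6):
    """Whiten both embedding sets using an SVD of the generated embeddings.

    gen_emb: (n, d) generated embeddings; gt_emb: (m, d) ground-truth embeddings.
    """
    # gen_emb = U @ diag(S) @ Vh; the rows of Vh are the right singular vectors
    _, S, Vh = torch.linalg.svd(gen_emb, full_matrices=False)
    V = Vh.T
    W = V @ torch.diag(1.0 / S.clamp(min=eps)) @ V.T  # (d, d) whitening matrix
    return gen_emb @ W, gt_emb @ W
```

After this transform the generated embeddings have unit scale along each singular direction, so no single direction dominates the alignment term.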

Alignment and Diversity Coefficients

The two reward components are weighted by coefficients:

ebft:
  alignment_coef: 1.0   # Weight for cosine similarity with ground truth
  diversity_coef: 1.0   # Weight for pairwise similarity penalty

Both values are scaled by 2 internally (per paper equation 7). The final reward per sample is:

reward_j = 2 * alignment_coef * cos(gen_j, gt)
         - 2 * diversity_coef * (1/(n-1)) * sum_{j' != j} dot(gen_j, gen_j')

Setting diversity_coef: 0.0 disables the diversity penalty entirely, which may be appropriate when num_generations is small (e.g., 2).

Strided Mode

Strided mode is designed for training on unstructured text data where there is no natural prompt/completion boundary. Instead of generating full completions with vLLM, it places anchor points at regular intervals throughout each document and generates short rollouts at each anchor using block-parallel attention.

How Block-Parallel Generation Works

Given a document of length S tokens:

  1. Anchor placement: Starting at position anchor_offset, place anchors every stride tokens. Each anchor defines a block.
  2. Context window: Each block sees context_length tokens of preceding context from the original document.
  3. Generation: At each anchor, generate generate_max_len tokens autoregressively, conditioned only on the context window.
  4. Parallelism: All blocks are processed in a single forward pass using a specialized attention mask that prevents information leakage between blocks.

Document:   [tok0, tok1, ..., tok_S]
                    |         |         |
                 anchor_0   anchor_1  anchor_2
                    |         |         |
             [ctx][gen]  [ctx][gen]  [ctx][gen]

The attention mask ensures:

  • Prompt tokens use standard causal attention
  • Each generated block attends to its own context window and its own preceding generated tokens
  • Blocks do not attend to each other’s generated tokens

When flex_attention is available (PyTorch >= 2.5), the mask is compiled into efficient fused kernels. Otherwise, a dense 4D attention mask is used as a fallback.
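The anchor placement and the three mask rules can be sketched as a pure-Python predicate. This is illustrative only: the real implementation builds a flex_attention block mask or a dense 4D mask, and its token layout may differ.

```python
def place_anchors(seq_len, anchor_offset, stride):
    """Anchor positions: every `stride` tokens starting at `anchor_offset`."""
    return list(range(anchor_offset, seq_len, stride))

def attention_allowed(q_block, k_block, q_pos, k_pos, anchors, context_length):
    """Whether a query token may attend to a key token.

    Block id -1 marks original document tokens; block id i >= 0 marks
    generated tokens belonging to anchor i. Positions are absolute.
    """
    if q_block == -1:
        # Document tokens: standard causal attention over the document
        return k_block == -1 and k_pos <= q_pos
    anchor = anchors[q_block]
    if k_block == -1:
        # Generated tokens see only their context window before the anchor
        return anchor - context_length <= k_pos < anchor
    if k_block == q_block:
        # ...and their own preceding generated tokens
        return k_pos <= q_pos
    return False  # blocks never attend to each other's generated tokens
```

With stride 8 and offset 8 in a 32-token document, anchors land at positions 8, 16, and 24, and each block's queries are confined to its own context window and rollout.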

Strided Mode Configuration

base_model: meta-llama/Llama-3.2-1B
rl: ebft

ebft:
  mode: strided
  stride: 8                    # Tokens between anchor points
  context_length: 8            # Context window per block
  generate_max_len: 8          # Tokens to generate per block
  n_samples_per_prompt: 4      # Independent rollouts per document
  temperature: 0.6
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: true
  alignment_coef: 1.0
  diversity_coef: 1.0
  rl_coef: 1.0                # RL policy gradient loss weight
  ce_coef: 0.03               # Cross-entropy loss on GT tokens
  advantage_estimator: rloo    # rloo, group_norm, or reinforce
  min_completion_prefix: 8     # Skip anchors in prompt region

datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_strided_structured.transform
    split: train[:1%]

sequence_len: 2048
micro_batch_size: 1
gradient_accumulation_steps: 2

adapter: lora
lora_r: 16
lora_alpha: 32
lora_target_linear: true

bf16: auto
flex_attention: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true          # Required with flex_attention

Run with a single command (no vLLM needed):

CUDA_VISIBLE_DEVICES=0 axolotl train config.yaml

Advantage Estimators

Strided mode supports three advantage estimation methods:

| Estimator | Formula | Requirements |
|-----------|---------|--------------|
| rloo | Leave-one-out baseline: reward_j - mean(rewards_{-j}) | n_samples_per_prompt >= 2 |
| group_norm | Group normalization: (reward_j - mean) / std | n_samples_per_prompt >= 2 |
| reinforce | Raw reward as advantage (no baseline) | Works with n_samples_per_prompt = 1 |

Warning

When n_samples_per_prompt: 1, the trainer automatically falls back to reinforce and disables the diversity penalty (which requires multiple samples).
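The three estimators can be sketched as follows, where rewards is the (n,) reward vector for one document's rollouts. This is an illustration, not the trainer's exact code:

```python
import torch

def compute_advantages(rewards, estimator="rloo", eps=1e-8):
    """Sketch of the three advantage estimators over one rollout group."""
    n = rewards.numel()
    if estimator == "reinforce" or n == 1:
        return rewards  # raw reward, no baseline (automatic fallback at n == 1)
    if estimator == "rloo":
        # Leave-one-out baseline: the mean of the other n-1 rewards
        baseline = (rewards.sum() - rewards) / (n - 1)
        return rewards - baseline
    if estimator == "group_norm":
        return (rewards - rewards.mean()) / (rewards.std() + eps)
    raise ValueError(f"unknown estimator: {estimator}")
```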

Strided Mode Constraints

  • flex_attention: true is strongly recommended. Without it, dense 4D masks consume significantly more memory.
  • torch_compile: true must NOT be set. flex_attention compiles its own kernels internally; adding torch_compile causes conflicts and OOM.
  • Gradient checkpointing must use use_reentrant: true. Non-reentrant checkpointing causes CheckpointError with flex_attention block masks.
  • activation_offloading is incompatible with flex_attention.

Cross-Entropy Loss

Strided mode supports an optional cross-entropy loss term on ground-truth tokens. This acts as a regularizer to prevent the model from drifting too far from the original distribution:

ebft:
  ce_coef: 0.03    # Small CE coefficient
  rl_coef: 1.0     # RL loss coefficient

The total loss is rl_coef * rl_loss + ce_coef * ce_loss. For structured mode, ce_coef is typically 0.0 since vLLM generation provides sufficient learning signal.
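The combination can be sketched like this, assuming labels use -100 to mask non-ground-truth positions (the usual Hugging Face convention; an assumption here, not confirmed by this page):

```python
import torch
import torch.nn.functional as F

def strided_total_loss(rl_loss, logits, labels, rl_coef=1.0, ce_coef=0.03):
    """total = rl_coef * rl_loss + ce_coef * ce_loss (illustrative sketch).

    logits: (B, T, V); labels: (B, T) with -100 at positions to ignore.
    """
    ce_loss = F.cross_entropy(
        logits.view(-1, logits.size(-1)),  # flatten to (B*T, V)
        labels.view(-1),
        ignore_index=-100,                 # only ground-truth tokens contribute
    )
    return rl_coef * rl_loss + ce_coef * ce_loss
```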

Dataset Formats

EBFT provides several built-in dataset transforms in src/axolotl/prompt_strategies/ebft/.

Built-In Transforms

| Transform | Input Format | Output Fields | Use Case |
|-----------|--------------|---------------|----------|
| ebft_opencode.transform | {input, output} | {prompt, ground_truth} | OpenCodeInstruct, structured QA |
| ebft_strided_structured.transform | {input, output} | {input_ids, labels, prompt_length} | Strided mode with structured data |
| ebft_strided_chat.transform | {messages: [...]} | {input_ids, labels, prompt_length} | Strided mode with chat data |
| ebft_chat_multiturn.transform | {messages: [...]} | {prompt, ground_truth, remaining_turns} | Multi-turn: first-turn target |
| ebft_chat_multiturn.transform_last_turn | {messages: [...]} | {prompt, ground_truth} | Multi-turn: last-turn target |
| ebft_chat_multiturn.transform_all_turns | {messages: [...]} | {prompt[], ground_truth[]} | Multi-turn: one example per turn |
| ebft_reasoning.transform | {messages: [...]} (with <think>) | {prompt, ground_truth} | Reasoning/thinking datasets |

Structured Mode Datasets

For structured (sync/async) mode, the transform must produce prompt and ground_truth fields:

datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_opencode.transform
    split: train[:500]

Multi-Turn Datasets

Multi-turn transforms extract conversation data for sequential rollout. The transform variant targets the first assistant turn, while transform_last_turn targets the final turn:

datasets:
  - path: your/multiturn-dataset
    type: ebft_chat_multiturn.transform

When remaining_turns is present in the dataset output, the trainer performs sequential rollouts: it generates the first assistant turn with vLLM, then continues generating subsequent turns by building up the conversation history.

Strided Mode Datasets

Strided transforms tokenize the full document and produce input_ids, labels, and prompt_length:

datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_strided_structured.transform
    split: train[:1%]

Custom Transforms

To use your own dataset format, write a transform function:

def transform(cfg, **kwargs):
    def transform_fn(example, tokenizer=None):
        return {
            "prompt": [{"role": "user", "content": example["question"]}],
            "ground_truth": example["answer"],
        }
    return transform_fn, {"remove_columns": "__all__"}

The "__all__" sentinel removes all original dataset columns after the mapping step. Reference this transform in your config:

datasets:
  - path: your/dataset
    type: your_module.transform

Configuration Reference

Common Parameters (All Modes)

These parameters are set under the ebft: key in the YAML config.

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| mode | "structured" or "strided" | "structured" | EBFT operating mode |
| feature_layers | list[float] | [0.25, 0.5, 0.75] | Fractional layer depths for feature extraction |
| embed_method | string | "last_token" | Pooling method: last_token, mean_pooling, completion_mean, or concat |
| use_whitening | bool | false | Apply SVD whitening to feature embeddings before reward computation |
| alignment_coef | float | 1.0 | Weight for alignment reward (cosine similarity with ground truth) |
| diversity_coef | float | 1.0 | Weight for diversity penalty (pairwise dot product between samples) |
| ce_coef | float | 0.0 | Cross-entropy loss coefficient on ground-truth tokens |
| adaptive_max_tokens | bool | true | Dynamically set vLLM max_tokens based on ground-truth length (structured mode) |
| gt_length_multiplier | float | 1.5 | Multiplier for ground-truth token count when computing adaptive max tokens (min 0.1) |

Strided Mode Parameters

These additional parameters apply only when mode: strided.

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| stride | int | 8 | Number of tokens between anchor points (must be >= 1) |
| context_length | int | 8 | Context window size for each generated block (must be >= 1) |
| generate_max_len | int | 8 | Number of tokens to generate per block (must be >= 1) |
| n_samples_per_prompt | int | 4 | Number of independent rollouts per document (must be >= 1) |
| temperature | float | 0.6 | Sampling temperature for strided generation |
| top_p | float | 1.0 | Top-p nucleus sampling threshold |
| rl_coef | float | 1.0 | RL policy gradient loss coefficient |
| advantage_estimator | string | "rloo" | Advantage estimation method: rloo, group_norm, or reinforce |
| min_completion_prefix | int | 0 | Minimum tokens into the completion span before placing anchors |

Structured Mode TRL Parameters

These are set under the trl: key and control the GRPO training loop.

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| num_generations | int | — | Number of completions generated per prompt |
| max_completion_length | int | — | Maximum tokens per generated completion |
| temperature | float | 0.7 | Sampling temperature for vLLM generation |
| use_vllm | bool | — | Enable vLLM generation backend |
| vllm_lora_sync | bool | false | Sync LoRA adapters via filesystem (recommended) |
| vllm_sync_interval | int | 1 | Steps between weight syncs to vLLM |
| use_data_producer | bool | — | Required for sync mode with LoRA sync |
| async_prefetch | bool | false | Enable async generation (overlaps with training) |
| streaming_partial_batch | bool | false | Score groups incrementally (async mode) |
| skip_zero_advantage_batches | bool | false | Skip micro-batches where all advantages are zero |
| scale_rewards | bool | — | Normalize rewards within each prompt group |
| loss_type | string | "grpo" | Loss type for policy optimization |
| epsilon | float | 0.2 | Clipping parameter for importance sampling |
Stop Tokens

vLLM needs explicit stop token IDs for generation. Common configurations:

trl:
  generation_kwargs:
    stop_token_ids: [151645, 151643]   # Qwen: <|im_end|>, <|endoftext|>

Multi-Turn Chat Settings

For multi-turn conversations with Qwen3.5, disable thinking mode to prevent <think> tags in completions:

trl:
  chat_template_kwargs:
    enable_thinking: false

Monitoring

Key Metrics

EBFT logs several custom metrics to wandb and the training console. Here is what to watch for:

| Metric | Healthy Range | Interpretation |
|--------|---------------|----------------|
| ebft/alignment | 0.3 – 0.9, trending upward | Cosine similarity between generated and ground-truth features. Higher means the model is learning to produce representations that match the reference. |
| ebft/diversity | 0.01 – 0.1 | Mean pairwise similarity between different generations for the same prompt. Values above 1.0 indicate mode collapse. |
| ebft/cfm_loss | Below 10, trending downward | Cross-Feature Matching loss. This is the core quantity being minimized. Consistently above 100 indicates instability. |
| ebft/reward | Trending upward (may start negative) | Combined reward signal. If stuck at -1.0, the diversity penalty is dominating alignment. |
| grad_norm | 0.1 – 3.0 | Gradient magnitude. Values of 0.0 indicate a zero-advantage skip (normal). Values above 10 suggest instability. |
| entropy | 0.05 – 0.5 | Policy entropy. Values below 0.01 suggest mode collapse. |
| IS ratio min | Above 0.1 | Importance sampling ratio minimum. Near-zero values mean generation is too far off-policy; decrease vllm_sync_interval to sync weights more often. |

Console Log Example

During training, you will see periodic EBFT reward logs:

ebft reward | align +0.412 ^ | divers +0.023 v | cfm 4.231 v | reward +0.389 ^

The arrows indicate the desired direction: alignment and reward should trend upward, while diversity and CFM loss should trend downward.

Troubleshooting

| Symptom | Likely Cause | Fix |
|---------|--------------|-----|
| alignment stays below 0.1 | Feature layers not capturing useful information | Try different feature_layers or embed_method |
| diversity exceeds 1.0 | Mode collapse – generations are too similar | Increase diversity_coef or temperature |
| reward stuck at -1.0 | Diversity penalty dominates alignment | Reduce diversity_coef or increase alignment_coef |
| grad_norm consistently 0.0 | All micro-batches have zero advantage | Increase num_generations or check data quality |
| CheckpointError in strided mode | Incompatible gradient checkpointing settings | Set use_reentrant: true in gradient_checkpointing_kwargs |
| OOM during training | Logits tensor too large | Reduce sequence_len or micro_batch_size; strided mode uses a chunked lm_head to mitigate this |
| vLLM 500 errors | truncate_prompt_tokens not supported | Ensure you are using axolotl vllm-serve (not trl vllm-serve) |

Feature Network Memory

In PEFT (LoRA) mode, the feature network shares base weights with the actor model by using the disable_adapter() context manager. This saves an entire model copy in VRAM (approximately 1–16 GB depending on model size). For non-PEFT training, a separate frozen deepcopy is created.

Note

The disable_adapter() approach relies on an invariant: merge_adapter() is never called on the base weights. All weight sync paths (LoRA sync, HTTP, NCCL) compute merged weights as new tensors or save the adapter to the filesystem, leaving base weights unmodified.

Examples

Complete example configurations are available in examples/ebft/:

| Config | Model | Mode | Description |
|--------|-------|------|-------------|
| llama-1b-ebft-strided-structured.yaml | Llama 3.2 1B | Strided | Single-GPU strided training on code data |
| qwen3-4b-ebft-structured.yaml | Qwen3 4B | Structured (sync) | Two-GPU structured training |
| qwen3-4b-ebft-structured-async.yaml | Qwen3 4B | Structured (async) | Two-GPU async training with prefetch |
| qwen3-8b-ebft-structured.yaml | Qwen3 8B | Structured (sync) | Two-GPU structured training for a larger model |
| qwen35-4b-ebft-structured.yaml | Qwen3.5 4B | Structured (sync) | Two-GPU with Qwen3.5 |
| qwen35-4b-ebft-structured-async.yaml | Qwen3.5 4B | Structured (async) | Two-GPU async with Qwen3.5 |
| qwen35-9b-ebft-structured.yaml | Qwen3.5 9B | Structured (sync) | Two-GPU structured for the 9B model |