EBFT Training
Overview
Energy-Based Fine-Tuning (EBFT) is a training method that optimizes language models by matching the internal feature representations of generated text to those of ground-truth completions. Instead of relying on external reward models or hand-crafted reward functions, EBFT extracts hidden states from intermediate layers of a frozen copy of the model and uses cosine similarity between generated and reference features as the reward signal.
Paper: “Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models” (Jelassi et al., 2026)
How EBFT Differs from Other RL Methods
| Method | Reward Signal | Requires | Best For |
|---|---|---|---|
| GRPO | External reward function(s) | Custom reward code or reward model | Tasks with verifiable answers (math, code) |
| DPO | Preference pairs (chosen vs rejected) | Paired preference data | Alignment with human preferences |
| EBFT | Feature similarity to ground truth | Ground-truth completions | Any task with reference outputs |
EBFT’s key advantage is that it needs only ground-truth completions – no reward engineering, no preference annotation, and no reward model training. The model’s own internal representations serve as the reward signal. This makes it particularly effective for:
- Code generation (match features of known-good solutions)
- Instruction following with reference outputs
- Continual pretraining on unstructured text (strided mode)
- Multi-turn dialogue with reference conversations
Reward Formulation
The EBFT reward for each generated completion is:
```
reward = alignment_coef * cosine_similarity(gen_features, gt_features)
       - diversity_coef * mean_pairwise_similarity(gen_features)
```
- Alignment: How closely the generated output’s internal representations match the ground truth. Higher is better.
- Diversity: Penalizes generated samples that are too similar to each other (prevents mode collapse). Lower is better.
- CFM loss (Cross-Feature Matching): Tracks ||mean(gen_features) - gt_features||^2 as a diagnostic. This is the quantity that EBFT ultimately minimizes.
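As a concrete reading of the CFM diagnostic, here is a minimal sketch for a single prompt. It assumes each pooled feature vector is a plain list of floats and that gen_features holds the n generated embeddings for that prompt. The function name cfm_loss is illustrative, not the trainer's API.

```python
from typing import Sequence


def cfm_loss(gen_features: Sequence[Sequence[float]], gt_features: Sequence[float]) -> float:
    """Squared L2 distance between the mean generated embedding and the ground truth.

    Hypothetical helper: gen_features is a list of n per-sample embeddings
    for one prompt, gt_features is the single reference embedding.
    """
    n = len(gen_features)
    if n == 0:
        raise ValueError("need at least one generated sample")
    dim = len(gt_features)
    # Average the generated embeddings dimension by dimension.
    mean = [sum(vec[d] for vec in gen_features) / n for d in range(dim)]
    # ||mean(gen_features) - gt_features||^2
    return sum((m - g) ** 2 for m, g in zip(mean, gt_features))


# Two samples whose mean is [2, 0]; squared distance to [2, 1] is 1.
print(cfm_loss([[1.0, 0.0], [3.0, 0.0]], [2.0, 1.0]))  # 1.0
```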
Modes
EBFT supports three operational modes, each suited to different use cases.
Structured Mode (Sync)
Uses vLLM on a separate GPU for generation, with sequential generate-score-train steps. This is the simplest mode and recommended for getting started.
```
GPU 0: vLLM Server (generates completions, receives weight syncs)
GPU 1: Trainer (feature extraction, reward computation, GRPO training)
```
When to use: Standard instruction-following or QA datasets where you have prompt/completion pairs. Requires 2 GPUs.
Structured Mode (Async)
Same architecture as sync, but overlaps generation of the next batch with training on the current batch. Faster throughput at the cost of slightly stale weights during generation.
When to use: Same data as sync mode, but when you want faster training and can tolerate weight staleness (controlled by vllm_sync_interval).
Strided Mode
Runs entirely on a single GPU with no vLLM dependency. Places anchor points throughout a document and generates short rollouts at each anchor using block-parallel attention patterns.
```
Single GPU: Base model + LoRA adapter
  - Strided block-parallel generation (flex_attention)
  - Feature extraction via disable_adapter()
  - No vLLM needed
```
When to use: Unstructured text data (raw code, prose, documents) where there is no natural prompt/completion split. Also works with structured data that includes prompt boundaries. Requires only 1 GPU.
Quick Start
Structured Mode
This minimal example fine-tunes Qwen2-0.5B on code data using EBFT with vLLM generation.
Step 1: Create a config file ebft_quickstart.yaml:
```yaml
base_model: Qwen/Qwen2-0.5B-Instruct
rl: ebft

ebft:
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  alignment_coef: 1.0
  diversity_coef: 1.0

trl:
  num_generations: 4
  max_completion_length: 256
  temperature: 0.7
  use_vllm: true
  vllm_server_host: 0.0.0.0
  vllm_server_port: 8000
  vllm_lora_sync: true
  vllm_sync_interval: 3
  use_data_producer: true
  async_prefetch: false
  scale_rewards: true
  loss_type: grpo

vllm:
  gpu_memory_utilization: 0.5
  max_model_len: 1024

datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_opencode.transform
    split: train[:500]

# Standard training settings (see getting-started.qmd for details)
adapter: lora
lora_r: 16
lora_alpha: 32
lora_target_linear: true
sequence_len: 1024
micro_batch_size: 4
gradient_accumulation_steps: 4
max_steps: 20
learning_rate: 5.0e-6
bf16: auto
flash_attention: true
gradient_checkpointing: true
output_dir: ./outputs/ebft-quickstart
```

Step 2: Start vLLM on GPU 0:

```bash
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve ebft_quickstart.yaml
```

Step 3: Wait approximately 30 seconds for vLLM to initialize, then start training on GPU 1:

```bash
CUDA_VISIBLE_DEVICES=1 axolotl train ebft_quickstart.yaml
```

The micro_batch_size must be divisible by num_generations. For example, with num_generations: 4, valid values are 4, 8, 12, etc.
Dataset Format
Structured mode datasets must produce two fields after the transform:
- prompt: Either a string or a list of chat messages ([{"role": "user", "content": "..."}])
- ground_truth: A string containing the reference completion
Example raw dataset row:
```json
{
  "input": "Write a function to compute fibonacci numbers.",
  "output": "def fibonacci(n):\n    if n <= 1:\n        return n\n    return fibonacci(n-1) + fibonacci(n-2)"
}
```

The ebft_opencode.transform converts this to the required {prompt, ground_truth} format automatically.
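Before training on a new dataset, it can help to check that a transform's output matches this contract. The sketch below is a hypothetical helper (check_structured_row is not part of axolotl). It encodes only the two rules listed above: prompt is a string or a list of {role, content} messages, and ground_truth is a string.

```python
from typing import Any, Mapping


def check_structured_row(row: Mapping[str, Any]) -> None:
    """Raise ValueError if `row` does not match the structured-mode contract."""
    if "prompt" not in row or "ground_truth" not in row:
        raise ValueError("row must contain 'prompt' and 'ground_truth'")

    prompt = row["prompt"]
    if isinstance(prompt, list):
        # Chat form: every message needs a string role and string content.
        for msg in prompt:
            if not (isinstance(msg, dict)
                    and isinstance(msg.get("role"), str)
                    and isinstance(msg.get("content"), str)):
                raise ValueError(f"malformed chat message: {msg!r}")
    elif not isinstance(prompt, str):
        raise ValueError("prompt must be a string or a list of chat messages")

    if not isinstance(row["ground_truth"], str):
        raise ValueError("ground_truth must be a string")


# A chat-form row passes silently.
check_structured_row({
    "prompt": [{"role": "user", "content": "Write a function to compute fibonacci numbers."}],
    "ground_truth": "def fibonacci(n): ...",
})
```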
Feature Extraction
EBFT extracts hidden states from intermediate transformer layers and pools them into per-sequence embeddings. These embeddings are compared between generated and ground-truth completions to compute rewards.
Feature Layers
The feature_layers parameter specifies which layers to extract, as fractions of total model depth:
```yaml
ebft:
  feature_layers: [0.25, 0.5, 0.75]  # Quarter, middle, three-quarter depth
```

For a 32-layer model, this extracts layers 8, 16, and 24. The hidden states from all selected layers are concatenated along the feature dimension, producing embeddings of size len(feature_layers) * hidden_dim (3 * hidden_dim with the default).
Using multiple layers captures both low-level syntactic features (early layers) and high-level semantic features (later layers). The default [0.25, 0.5, 0.75] works well across model sizes.
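The fraction-to-index mapping and the concatenation can be sketched as follows. This assumes the index is int(fraction * num_hidden_layers), which reproduces the 32-layer example; the exact rounding, the accepted fraction range, and whether index 0 refers to the embedding output are assumptions rather than documented behavior.

```python
from typing import List, Sequence


def resolve_feature_layers(fractions: Sequence[float], num_hidden_layers: int) -> List[int]:
    """Map fractional depths to integer layer indices (int() truncation assumed)."""
    layers = []
    for frac in fractions:
        if not 0.0 < frac <= 1.0:  # assumed valid range
            raise ValueError(f"layer fraction must be in (0, 1], got {frac}")
        layers.append(int(frac * num_hidden_layers))
    return layers


def concat_features(per_layer: Sequence[Sequence[float]]) -> List[float]:
    """Concatenate pooled per-layer vectors along the feature dimension."""
    return [x for vec in per_layer for x in vec]


print(resolve_feature_layers([0.25, 0.5, 0.75], 32))  # [8, 16, 24]
print(resolve_feature_layers([0.25, 0.5, 0.75], 24))  # [6, 12, 18]

# Three selected layers with hidden_dim 4 give a 3 * 4 = 12-dim embedding.
pooled = [[0.1] * 4, [0.2] * 4, [0.3] * 4]
print(len(concat_features(pooled)))  # 12
```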
Embed Methods
The embed_method controls how per-token hidden states are pooled into a single vector per sequence:
| Method | Description | Output Shape | Notes |
|---|---|---|---|
| last_token | Hidden state at the last non-padding token | (B, D) | Default. Good for autoregressive models where the last token summarizes the sequence. |
| mean_pooling | Mean of all non-padding token states | (B, D) | Considers the entire sequence equally. |
| completion_mean | Mean over completion tokens only (excludes prompt) | (B, D) | Focuses reward signal on generated content. Requires prompt length information. |
| concat | Concatenation of states at 25%, 50%, 75% positions | (B, 3*D) | Captures positional structure. Higher dimensional. |
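The four pooling rules in the table can be sketched in plain Python for a single sequence. The sketch assumes right padding (non-padding tokens come first), an attention mask of 1s and 0s, a prompt_len argument marking where the completion starts, and, for concat, positions taken within the non-padding length via int(fraction * length). These indexing choices are illustrative, not the trainer's exact implementation.

```python
from typing import List, Sequence

Vector = List[float]


def _mean(vectors: Sequence[Sequence[float]]) -> Vector:
    # Element-wise mean of a non-empty list of equal-length vectors.
    return [sum(col) / len(vectors) for col in zip(*vectors)]


def pool(hidden: Sequence[Sequence[float]], mask: Sequence[int],
         method: str, prompt_len: int = 0) -> Vector:
    """Pool per-token hidden states (T x D) into one vector for one sequence."""
    length = sum(mask)  # non-padding tokens; right padding assumed
    tokens = [list(h) for h in hidden[:length]]
    if method == "last_token":
        return tokens[-1]
    if method == "mean_pooling":
        return _mean(tokens)
    if method == "completion_mean":
        completion = tokens[prompt_len:]
        if not completion:
            raise ValueError("no completion tokens to pool")
        return _mean(completion)
    if method == "concat":
        # Assumed: states at 25%, 50%, 75% of the non-padding length.
        picks = [tokens[min(int(f * length), length - 1)] for f in (0.25, 0.5, 0.75)]
        return [x for vec in picks for x in vec]
    raise ValueError(f"unknown embed_method: {method}")


# Four real tokens plus one padding token, D = 2.
hidden = [[1.0, 0.0], [2.0, 0.0], [3.0, 1.0], [4.0, 1.0], [9.0, 9.0]]
mask = [1, 1, 1, 1, 0]
print(pool(hidden, mask, "last_token"))                     # [4.0, 1.0]
print(pool(hidden, mask, "mean_pooling"))                   # [2.5, 0.5]
print(pool(hidden, mask, "completion_mean", prompt_len=2))  # [3.5, 1.0]
print(len(pool(hidden, mask, "concat")))                    # 6
```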
```yaml
ebft:
  embed_method: completion_mean  # Focus on completion features
```

SVD Whitening
Whitening decorrelates the feature dimensions so that no single direction dominates the feature-matching loss. This is computed via SVD on the generated embeddings, with the same transform applied to the ground-truth embeddings.
```yaml
ebft:
  use_whitening: true
```

When whitening is enabled, the reward computation applies a whitening matrix W = U @ diag(1/S) @ U^T derived from the SVD of generated embeddings. This ensures all feature dimensions contribute equally to the alignment reward.
Singular values scale with sqrt(batch_size), so reward magnitudes are batch-size dependent. This is acceptable because the number of samples per prompt (n_samples_per_prompt or num_generations) is fixed during training.
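The whitening step can be sketched with NumPy, assuming the generated embeddings are stacked as the rows of an (n, D) matrix. With that layout, numpy.linalg.svd returns X = U diag(S) Vh, so the feature-space singular vectors are the columns of Vh.T. The sketch reads the formula's U as that feature-space basis. The matrix orientation and the eps guard are assumptions, not the trainer's code.

```python
import numpy as np


def whitening_matrix(gen: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Build a (D, D) whitening matrix from generated embeddings of shape (n, D).

    W = V diag(1/S) V^T, where V = Vh.T holds the feature-space singular
    vectors; each retained direction is rescaled by its singular value.
    """
    _, s, vh = np.linalg.svd(gen, full_matrices=False)
    v = vh.T  # (D, k) feature-space basis, k = min(n, D)
    return v @ np.diag(1.0 / (s + eps)) @ v.T


rng = np.random.default_rng(0)
gen = rng.normal(size=(8, 16))  # 8 generations, 16-dim features
gt = rng.normal(size=(1, 16))   # one ground-truth embedding

W = whitening_matrix(gen)
gen_w, gt_w = gen @ W, gt @ W   # the same transform is applied to both

# After whitening, every singular value of the generated batch is ~1.
print(np.round(np.linalg.svd(gen_w, compute_uv=False), 4))
```

Because S grows roughly with sqrt(n), the whitened vectors shrink as the batch grows, which is the batch-size dependence described above.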
Alignment and Diversity Coefficients
The two reward components are weighted by coefficients:
```yaml
ebft:
  alignment_coef: 1.0  # Weight for cosine similarity with ground truth
  diversity_coef: 1.0  # Weight for pairwise similarity penalty
```

Both values are scaled by 2 internally (per paper equation 7). The final reward per sample is:
```
reward_j = 2 * alignment_coef * cos(gen_j, gt)
         - 2 * diversity_coef * (1/(n-1)) * sum_{j' != j} dot(gen_j, gen_j')
```

where n is the number of samples generated for the prompt.
Setting diversity_coef: 0.0 disables the diversity penalty entirely, which may be appropriate when num_generations is small (e.g., 2).
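The per-sample formula can be computed for one prompt group as sketched below. It follows the formula literally: cosine similarity for the alignment term and raw dot products for the diversity term. It assumes the embeddings are already pooled (and whitened, if enabled). The function name ebft_rewards is illustrative, and skipping the diversity term when n = 1 is an assumption made to avoid dividing by zero.

```python
import math
from typing import List, Sequence


def _dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def _cos(a: Sequence[float], b: Sequence[float]) -> float:
    return _dot(a, b) / (math.sqrt(_dot(a, a)) * math.sqrt(_dot(b, b)))


def ebft_rewards(gen: Sequence[Sequence[float]], gt: Sequence[float],
                 alignment_coef: float = 1.0, diversity_coef: float = 1.0) -> List[float]:
    """Per-sample rewards for one prompt group, following the formula above."""
    n = len(gen)
    rewards = []
    for j, g in enumerate(gen):
        alignment = 2.0 * alignment_coef * _cos(g, gt)
        if n > 1:
            # Average dot product with the other n - 1 samples of the same prompt.
            pairwise = sum(_dot(g, o) for k, o in enumerate(gen) if k != j) / (n - 1)
        else:
            pairwise = 0.0  # no other samples to compare against (assumed)
        rewards.append(alignment - 2.0 * diversity_coef * pairwise)
    return rewards


print(ebft_rewards([[1.0, 0.0], [0.0, 1.0]], [1.0, 0.0]))  # [2.0, 0.0]
print(ebft_rewards([[1.0, 0.0], [1.0, 0.0]], [1.0, 0.0]))  # [0.0, 0.0]
```

In the second call, two identical samples cancel their alignment reward entirely, which is the mode-collapse penalty at work.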
Strided Mode
Strided mode is designed for training on unstructured text data where there is no natural prompt/completion boundary. Instead of generating full completions with vLLM, it places anchor points at regular intervals throughout each document and generates short rollouts at each anchor using block-parallel attention.
How Block-Parallel Generation Works
Given a document of length S tokens:
- Anchor placement: Starting at position anchor_offset, place anchors every stride tokens. Each anchor defines a block.
- Context window: Each block sees context_length tokens of preceding context from the original document.
- Generation: At each anchor, generate generate_max_len tokens autoregressively, conditioned only on the context window.
- Parallelism: All blocks are processed in a single forward pass using a specialized attention mask that prevents information leakage between blocks.
```
Document: [tok0, tok1, ..., tok_S]
              |          |          |
          anchor_0   anchor_1   anchor_2
              |          |          |
          [ctx][gen] [ctx][gen] [ctx][gen]
```
The attention mask ensures:
- Prompt tokens use standard causal attention
- Each generated block attends to its own context window and its own preceding generated tokens
- Blocks do not attend to each other’s generated tokens
When flex_attention is available (PyTorch >= 2.5), the mask is compiled into efficient fused kernels. Otherwise, a dense 4D attention mask is used as a fallback.
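The masking rules above can be illustrated with a dense boolean mask, the kind of structure the fallback path needs. This sketch assumes a layout the docs do not spell out: the S document tokens come first, followed by each block's generate_max_len generated tokens in anchor order, and a block anchored at position a uses document tokens [a - context_length, a) as its context. All function names are hypothetical.

```python
from typing import List


def anchor_positions(seq_len: int, anchor_offset: int, stride: int) -> List[int]:
    """Anchors every `stride` tokens starting at `anchor_offset`."""
    return list(range(anchor_offset, seq_len, stride))


def block_parallel_mask(seq_len: int, anchors: List[int],
                        context_length: int, gen_len: int) -> List[List[bool]]:
    """mask[q][k] is True when query position q may attend to key position k.

    Assumed layout: document tokens 0..seq_len-1, then one block of
    `gen_len` generated tokens per anchor, blocks in anchor order.
    """
    total = seq_len + len(anchors) * gen_len
    mask = [[False] * total for _ in range(total)]

    # Document tokens: ordinary causal attention among themselves.
    for q in range(seq_len):
        for k in range(q + 1):
            mask[q][k] = True

    for b, anchor in enumerate(anchors):
        block_start = seq_len + b * gen_len
        ctx_start = max(0, anchor - context_length)
        for t in range(gen_len):
            q = block_start + t
            # Own context window from the original document.
            for k in range(ctx_start, anchor):
                mask[q][k] = True
            # Own block's generated tokens, up to and including itself.
            for k in range(block_start, q + 1):
                mask[q][k] = True
    return mask


anchors = anchor_positions(seq_len=16, anchor_offset=4, stride=8)  # [4, 12]
mask = block_parallel_mask(16, anchors, context_length=4, gen_len=2)
print(len(mask))                 # 20 positions: 16 document + 2 blocks x 2
print(mask[16][0], mask[16][4])  # True False: block 0 sees only doc[0:4]
print(mask[18][17])              # False: block 1 cannot see block 0
```

Broadcast to (batch, 1, L, L), a mask like this could serve as a 4D fallback; its quadratic size is why flex_attention, which expresses the same rules as a mask_mod function passed to create_block_mask, is preferred.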
Strided Mode Configuration
```yaml
base_model: meta-llama/Llama-3.2-1B
rl: ebft

ebft:
  mode: strided
  stride: 8                   # Tokens between anchor points
  context_length: 8           # Context window per block
  generate_max_len: 8         # Tokens to generate per block
  n_samples_per_prompt: 4     # Independent rollouts per document
  temperature: 0.6
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: true
  alignment_coef: 1.0
  diversity_coef: 1.0
  rl_coef: 1.0                # RL policy gradient loss weight
  ce_coef: 0.03               # Cross-entropy loss on GT tokens
  advantage_estimator: rloo   # rloo, group_norm, or reinforce
  min_completion_prefix: 8    # Skip anchors in prompt region

datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_strided_structured.transform
    split: train[:1%]

sequence_len: 2048
micro_batch_size: 1
gradient_accumulation_steps: 2

adapter: lora
lora_r: 16
lora_alpha: 32
lora_target_linear: true

bf16: auto
flex_attention: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true         # Required with flex_attention
```

Run with a single command (no vLLM needed):

```bash
CUDA_VISIBLE_DEVICES=0 axolotl train config.yaml
```

Advantage Estimators
Strided mode supports three advantage estimation methods:
| Estimator | Formula | Requirements |
|---|---|---|
| rloo | Leave-one-out baseline: reward_j - mean(rewards_{-j}) | n_samples_per_prompt >= 2 |
| group_norm | Group normalization: (reward_j - mean) / std | n_samples_per_prompt >= 2 |
| reinforce | Raw reward as advantage (no baseline) | Works with n_samples_per_prompt = 1 |
When n_samples_per_prompt: 1, the trainer automatically falls back to reinforce and disables the diversity penalty (which requires multiple samples).
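The three estimators, plus the single-sample fallback, can be sketched for one document's rewards as follows. The population standard deviation and the eps guard in group_norm are assumptions, since the table does not specify them. The function name is illustrative.

```python
import math
from typing import List, Sequence


def advantages(rewards: Sequence[float], estimator: str = "rloo",
               eps: float = 1e-6) -> List[float]:
    """Advantages for the n rollouts of one document (illustrative sketch)."""
    n = len(rewards)
    if n == 1 or estimator == "reinforce":
        # Raw reward, no baseline; also the fallback for a single sample.
        return list(rewards)
    total = sum(rewards)
    if estimator == "rloo":
        # Baseline for sample j is the mean of the other n - 1 rewards.
        return [r - (total - r) / (n - 1) for r in rewards]
    if estimator == "group_norm":
        mean = total / n
        std = math.sqrt(sum((r - mean) ** 2 for r in rewards) / n)  # population std (assumed)
        return [(r - mean) / (std + eps) for r in rewards]
    raise ValueError(f"unknown advantage_estimator: {estimator}")


rewards = [1.0, 2.0, 3.0]
print(advantages(rewards, "rloo"))       # [-1.5, 0.0, 1.5]
print(advantages(rewards, "reinforce"))  # [1.0, 2.0, 3.0]
print(advantages([0.7], "rloo"))         # [0.7] (falls back to reinforce)
```

Because the RLOO baseline excludes the sample being scored, it does not depend on that sample's own reward; with a single sample there is nothing left to form a baseline from.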
Strided Mode Constraints
- flex_attention: true is strongly recommended. Without it, dense 4D masks consume significantly more memory.
- torch_compile: true must NOT be set. flex_attention compiles its own kernels internally; adding torch_compile causes conflicts and OOM.
- Gradient checkpointing must use use_reentrant: true. Non-reentrant checkpointing causes CheckpointError with flex_attention block masks.
- activation_offloading is incompatible with flex_attention.
Cross-Entropy Loss
Strided mode supports an optional cross-entropy loss term on ground-truth tokens. This acts as a regularizer to prevent the model from drifting too far from the original distribution:
```yaml
ebft:
  ce_coef: 0.03  # Small CE coefficient
  rl_coef: 1.0   # RL loss coefficient
```

The total loss is rl_coef * rl_loss + ce_coef * ce_loss. For structured mode, ce_coef is typically 0.0 since vLLM generation provides sufficient learning signal.
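A sketch of how the two terms combine. It assumes the common Hugging Face convention that label -100 marks positions excluded from the loss (for example, prompt tokens) and that CE is averaged over the remaining ground-truth tokens. The per-token log-probabilities and the rl_loss value are hypothetical inputs.

```python
from typing import Sequence

IGNORE_INDEX = -100  # assumed convention for positions excluded from CE


def masked_ce(token_logprobs: Sequence[float], labels: Sequence[int]) -> float:
    """Mean negative log-likelihood over positions whose label is not IGNORE_INDEX.

    token_logprobs[i] is the model's log-probability of labels[i] at position i.
    """
    picked = [-lp for lp, lab in zip(token_logprobs, labels) if lab != IGNORE_INDEX]
    return sum(picked) / len(picked) if picked else 0.0


def total_loss(rl_loss: float, ce_loss: float,
               rl_coef: float = 1.0, ce_coef: float = 0.03) -> float:
    """total = rl_coef * rl_loss + ce_coef * ce_loss."""
    return rl_coef * rl_loss + ce_coef * ce_loss


ce = masked_ce([-0.5, -1.0, -2.0], [IGNORE_INDEX, 42, 7])  # prompt token ignored
print(ce)                             # 1.5
print(round(total_loss(0.2, ce), 4))  # 0.245
```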
Dataset Formats
EBFT provides several built-in dataset transforms in src/axolotl/prompt_strategies/ebft/.
Built-In Transforms
| Transform | Input Format | Output Fields | Use Case |
|---|---|---|---|
| ebft_opencode.transform | {input, output} | {prompt, ground_truth} | OpenCodeInstruct, structured QA |
| ebft_strided_structured.transform | {input, output} | {input_ids, labels, prompt_length} | Strided mode with structured data |
| ebft_strided_chat.transform | {messages: [...]} | {input_ids, labels, prompt_length} | Strided mode with chat data |
| ebft_chat_multiturn.transform | {messages: [...]} | {prompt, ground_truth, remaining_turns} | Multi-turn: first-turn target |
| ebft_chat_multiturn.transform_last_turn | {messages: [...]} | {prompt, ground_truth} | Multi-turn: last-turn target |
| ebft_chat_multiturn.transform_all_turns | {messages: [...]} | {prompt[], ground_truth[]} | Multi-turn: one example per turn |
| ebft_reasoning.transform | {messages: [...]} (with <think>) | {prompt, ground_truth} | Reasoning/thinking datasets |
Structured Mode Datasets
For structured (sync/async) mode, the transform must produce prompt and ground_truth fields:
```yaml
datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_opencode.transform
    split: train[:500]
```

Multi-Turn Datasets
Multi-turn transforms extract conversation data for sequential rollout. The transform variant targets the first assistant turn, while transform_last_turn targets the final turn:
```yaml
datasets:
  - path: your/multiturn-dataset
    type: ebft_chat_multiturn.transform
```

When remaining_turns is present in the dataset output, the trainer performs sequential rollouts: it generates the first assistant turn with vLLM, then continues generating subsequent turns by building up the conversation history.
Strided Mode Datasets
Strided transforms tokenize the full document and produce input_ids, labels, and prompt_length:
```yaml
datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_strided_structured.transform
    split: train[:1%]
```

Custom Transforms
To use your own dataset format, write a transform function:
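```python
def transform(cfg, **kwargs):
    # Returns the per-row mapping function plus extra options for the map step.
    def transform_fn(example, tokenizer=None):
        # Map one raw row to the {prompt, ground_truth} fields EBFT expects.
        return {
            "prompt": [{"role": "user", "content": example["question"]}],
            "ground_truth": example["answer"],
        }

    # "__all__" drops every original column after mapping.
    return transform_fn, {"remove_columns": "__all__"}
```

The "__all__" sentinel removes all original dataset columns after the mapping step. Reference this transform in your config: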
```yaml
datasets:
  - path: your/dataset
    type: your_module.transform
```

Configuration Reference
Common Parameters (All Modes)
These parameters are set under the ebft: key in the YAML config.
| Parameter | Type | Default | Description |
|---|---|---|---|
| mode | "structured" or "strided" | "structured" | EBFT operating mode |
| feature_layers | list[float] | [0.25, 0.5, 0.75] | Fractional layer depths for feature extraction |
| embed_method | string | "last_token" | Pooling method: last_token, mean_pooling, completion_mean, or concat |
| use_whitening | bool | false | Apply SVD whitening to feature embeddings before reward computation |
| alignment_coef | float | 1.0 | Weight for alignment reward (cosine similarity with ground truth) |
| diversity_coef | float | 1.0 | Weight for diversity penalty (pairwise dot product between samples) |
| ce_coef | float | 0.0 | Cross-entropy loss coefficient on ground-truth tokens |
| adaptive_max_tokens | bool | true | Dynamically set vLLM max_tokens based on ground-truth length (structured mode) |
| gt_length_multiplier | float | 1.5 | Multiplier for ground-truth token count when computing adaptive max tokens (min 0.1) |
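One plausible reading of adaptive_max_tokens and gt_length_multiplier, sketched under assumptions: the ground-truth completion's token count is scaled by the multiplier and rounded up, and the result is clamped to [1, max_completion_length]. Only the multiplier and its 0.1 minimum come from the table; the ceiling, the clamp, and the helper name are assumptions.

```python
import math


def adaptive_max_tokens(gt_num_tokens: int, gt_length_multiplier: float = 1.5,
                        max_completion_length: int = 256) -> int:
    """Hypothetical helper: per-prompt generation budget from ground-truth length."""
    if gt_length_multiplier < 0.1:
        raise ValueError("gt_length_multiplier must be >= 0.1")
    budget = math.ceil(gt_num_tokens * gt_length_multiplier)  # rounding up is assumed
    return max(1, min(budget, max_completion_length))         # clamp is assumed


print(adaptive_max_tokens(100))     # 150
print(adaptive_max_tokens(400))     # 256 (capped)
print(adaptive_max_tokens(3, 0.1))  # 1
```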
Strided Mode Parameters
These additional parameters apply only when mode: strided.
| Parameter | Type | Default | Description |
|---|---|---|---|
| stride | int | 8 | Number of tokens between anchor points (must be >= 1) |
| context_length | int | 8 | Context window size for each generated block (must be >= 1) |
| generate_max_len | int | 8 | Number of tokens to generate per block (must be >= 1) |
| n_samples_per_prompt | int | 4 | Number of independent rollouts per document (must be >= 1) |
| temperature | float | 0.6 | Sampling temperature for strided generation |
| top_p | float | 1.0 | Top-p nucleus sampling threshold |
| rl_coef | float | 1.0 | RL policy gradient loss coefficient |
| advantage_estimator | string | "rloo" | Advantage estimation method: rloo, group_norm, or reinforce |
| min_completion_prefix | int | 0 | Minimum tokens into the completion span before placing anchors |
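For data that carries a prompt boundary (the strided transforms emit prompt_length), min_completion_prefix keeps anchors out of the prompt region. The sketch below assumes anchors lie on the usual grid starting at anchor_offset (taken as 0 by default here) and that any anchor before prompt_length + min_completion_prefix is skipped. This is an interpretation of the table, not the trainer's implementation.

```python
from typing import List


def strided_anchors(seq_len: int, stride: int, prompt_length: int = 0,
                    min_completion_prefix: int = 0, anchor_offset: int = 0) -> List[int]:
    """Anchor positions that survive the prompt-region filter (assumed semantics)."""
    if stride < 1:
        raise ValueError("stride must be >= 1")
    first_allowed = prompt_length + min_completion_prefix
    return [a for a in range(anchor_offset, seq_len, stride) if a >= first_allowed]


# 64-token document whose first 20 tokens are the prompt.
print(strided_anchors(64, stride=8, prompt_length=20, min_completion_prefix=8))
# [32, 40, 48, 56]
```

With min_completion_prefix at its default of 0, the anchor at position 24 would also be kept.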
Structured Mode TRL Parameters
These are set under the trl: key and control the GRPO training loop.
| Parameter | Type | Default | Description |
|---|---|---|---|
| num_generations | int | – | Number of completions generated per prompt |
| max_completion_length | int | – | Maximum tokens per generated completion |
| temperature | float | 0.7 | Sampling temperature for vLLM generation |
| use_vllm | bool | – | Enable vLLM generation backend |
| vllm_lora_sync | bool | false | Sync LoRA adapters via filesystem (recommended) |
| vllm_sync_interval | int | 1 | Steps between weight syncs to vLLM |
| use_data_producer | bool | – | Required for sync mode with LoRA sync |
| async_prefetch | bool | false | Enable async generation (overlaps with training) |
| streaming_partial_batch | bool | false | Score groups incrementally (async mode) |
| skip_zero_advantage_batches | bool | false | Skip micro-batches where all advantages are zero |
| scale_rewards | bool | – | Normalize rewards within each prompt group |
| loss_type | string | "grpo" | Loss type for policy optimization |
| epsilon | float | 0.2 | Clipping parameter for importance sampling |
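The epsilon parameter plays the familiar PPO-style clipping role. The per-token sketch below shows the standard clipped surrogate, assuming loss_type: grpo uses this form. It omits any KL term and the exact averaging over tokens and groups that the trainer applies.

```python
import math


def clipped_surrogate(logp_new: float, logp_old: float, advantage: float,
                      epsilon: float = 0.2) -> float:
    """Per-token loss: -min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A)."""
    ratio = math.exp(logp_new - logp_old)  # importance sampling ratio
    clipped = min(max(ratio, 1.0 - epsilon), 1.0 + epsilon)
    return -min(ratio * advantage, clipped * advantage)


# Ratio 1.5 with a positive advantage is clipped to 1.2, capping the update.
print(round(clipped_surrogate(math.log(1.5), 0.0, advantage=1.0), 6))   # -1.2
# With a negative advantage the unclipped term is smaller, so no cap applies.
print(round(clipped_surrogate(math.log(1.5), 0.0, advantage=-1.0), 6))  # 1.5
```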
Stop Tokens
vLLM needs explicit stop token IDs for generation. Common configurations:
```yaml
trl:
  generation_kwargs:
    stop_token_ids: [151645, 151643]  # Qwen: <|im_end|>, <|endoftext|>
```

Multi-Turn Chat Settings
For multi-turn conversations with Qwen3.5, disable thinking mode to prevent <think> tags in completions:
```yaml
trl:
  chat_template_kwargs:
    enable_thinking: false
```

Monitoring
Key Metrics
EBFT logs several custom metrics to wandb and the training console. Here is what to watch for:
| Metric | Healthy Range | Interpretation |
|---|---|---|
| ebft/alignment | 0.3 – 0.9, trending upward | Cosine similarity between generated and ground-truth features. Higher means the model is learning to produce representations that match the reference. |
| ebft/diversity | 0.01 – 0.1 | Mean pairwise similarity between different generations for the same prompt. Values above 1.0 indicate mode collapse. |
| ebft/cfm_loss | Below 10, trending downward | Cross-Feature Matching loss. This is the core quantity being minimized. Consistently above 100 indicates instability. |
| ebft/reward | Trending upward (may start negative) | Combined reward signal. If stuck at -1.0, the diversity penalty is dominating alignment. |
| grad_norm | 0.1 – 3.0 | Gradient magnitude. Values of 0.0 indicate zero-advantage skip (normal). Values above 10 suggest instability. |
| entropy | 0.05 – 0.5 | Policy entropy. Values below 0.01 suggest mode collapse. |
| IS ratio min | Above 0.1 | Importance sampling ratio minimum. Near-zero values mean the policy is too far off-policy; decrease vllm_sync_interval so vLLM receives fresh weights more often. |
Console Log Example
During training, you will see periodic EBFT reward logs:
```
ebft reward | align +0.412 ^ | divers +0.023 v | cfm 4.231 v | reward +0.389 ^
```
The arrows indicate the desired direction: alignment and reward should trend upward, while diversity and CFM loss should trend downward.
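To track these values outside wandb, the console line can be parsed with a regular expression. The pattern below is written against the example line above and assumes its format stays fixed (field names, signed decimals, ^/v arrows), so treat it as a sketch.

```python
import re
from typing import Dict

# One "name value arrow" field, e.g. "align +0.412 ^".
FIELD = re.compile(r"(\w+)\s+([+-]?\d+(?:\.\d+)?)\s+([\^v])")


def parse_ebft_log(line: str) -> Dict[str, float]:
    """Extract {field: value} from an 'ebft reward | ...' console line."""
    if not line.startswith("ebft reward"):
        raise ValueError("not an EBFT reward log line")
    return {name: float(value) for name, value, _arrow in FIELD.findall(line)}


line = "ebft reward | align +0.412 ^ | divers +0.023 v | cfm 4.231 v | reward +0.389 ^"
print(parse_ebft_log(line))
# {'align': 0.412, 'divers': 0.023, 'cfm': 4.231, 'reward': 0.389}
```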
Troubleshooting
| Symptom | Likely Cause | Fix |
|---|---|---|
| alignment stays below 0.1 | Feature layers not capturing useful information | Try different feature_layers or embed_method |
| diversity exceeds 1.0 | Mode collapse: generations are too similar | Increase diversity_coef or temperature |
| reward stuck at -1.0 | Diversity penalty dominates alignment | Reduce diversity_coef or increase alignment_coef |
| grad_norm consistently 0.0 | All micro-batches have zero advantage | Increase num_generations or check data quality |
| CheckpointError in strided mode | Incompatible gradient checkpointing settings | Set use_reentrant: true in gradient_checkpointing_kwargs |
| OOM during training | Logits tensor too large | Reduce sequence_len or micro_batch_size; strided mode uses chunked lm_head to mitigate this |
| vLLM 500 errors | truncate_prompt_tokens not supported | Ensure you are using axolotl vllm-serve (not trl vllm-serve) |
Feature Network Memory
In PEFT (LoRA) mode, the feature network shares base weights with the actor model by using the disable_adapter() context manager. This saves an entire model copy in VRAM (approximately 1–16 GB depending on model size). For non-PEFT training, a separate frozen deepcopy is created.
The disable_adapter() approach relies on an invariant: merge_adapter() is never called on the base weights. All weight sync paths (LoRA sync, HTTP, NCCL) compute merged weights as new tensors or save the adapter to the filesystem, leaving base weights unmodified.
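The invariant can be illustrated with a toy LoRA layer in NumPy. A sync path that builds merged weights as a new array leaves the base weights intact for feature extraction, whereas an in-place merge would silently turn the "frozen" feature network into the fine-tuned one. The sketch uses the standard LoRA scaling lora_alpha / lora_r and illustrates the invariant rather than axolotl's sync code. The memory helper assumes bf16 (2 bytes per parameter), which gives about 1 GB for a 0.5B model and 16 GB for an 8B model.

```python
import numpy as np


def model_copy_gb(num_params: float, bytes_per_param: int = 2) -> float:
    """Rough VRAM for one extra copy of the weights (2 bytes/param assumes bf16)."""
    return num_params * bytes_per_param / 1e9


def merged_copy(w: np.ndarray, a: np.ndarray, b: np.ndarray, scale: float) -> np.ndarray:
    """Safe pattern: return W + scale * (B @ A) as a NEW array; `w` is untouched."""
    return w + scale * (b @ a)


def merge_in_place(w: np.ndarray, a: np.ndarray, b: np.ndarray, scale: float) -> None:
    """Unsafe pattern: mutates `w`, so disabling the adapter no longer yields the base model."""
    w += scale * (b @ a)


rng = np.random.default_rng(0)
lora_r, lora_alpha = 2, 4
base = rng.normal(size=(6, 4))         # frozen weight shared with the feature network
lora_a = rng.normal(size=(lora_r, 4))
lora_b = rng.normal(size=(6, lora_r))
scale = lora_alpha / lora_r            # standard LoRA scaling

snapshot = base.copy()
to_sync = merged_copy(base, lora_a, lora_b, scale)
print(np.array_equal(base, snapshot))  # True: feature network still sees base weights

merge_in_place(base, lora_a, lora_b, scale)
print(np.array_equal(base, snapshot))  # False: base now includes the adapter delta

print(model_copy_gb(0.5e9), model_copy_gb(8e9))  # 1.0 16.0
```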
Examples
Complete example configurations are available in examples/ebft/:
| Config | Model | Mode | Description |
|---|---|---|---|
| llama-1b-ebft-strided-structured.yaml | Llama 3.2 1B | Strided | Single-GPU strided training on code data |
| qwen3-4b-ebft-structured.yaml | Qwen3 4B | Structured (sync) | Two-GPU structured training |
| qwen3-4b-ebft-structured-async.yaml | Qwen3 4B | Structured (async) | Two-GPU async training with prefetch |
| qwen3-8b-ebft-structured.yaml | Qwen3 8B | Structured (sync) | Two-GPU structured training for larger model |
| qwen35-4b-ebft-structured.yaml | Qwen3.5 4B | Structured (sync) | Two-GPU with Qwen3.5 |
| qwen35-4b-ebft-structured-async.yaml | Qwen3.5 4B | Structured (async) | Two-GPU async with Qwen3.5 |
| qwen35-9b-ebft-structured.yaml | Qwen3.5 9B | Structured (sync) | Two-GPU structured for 9B model |