core.trainers.ebft.rewards

Feature-matching reward utilities for Energy-Based Fine-Tuning (EBFT).

Ported from: feature-002/ebft_openrlhf/openrlhf/utils/embedding_utils.py

Paper: “Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models” (Jelassi et al., 2026), https://arxiv.org/abs/2603.12248

Functions

| Name | Description |
| --- | --- |
| apply_embed_method | Pool per-token hidden states into per-sequence embeddings. |
| extract_hidden_states | Forward pass through model, extracting and concatenating hidden states at specified layer indices. |
| get_alignment_rewards | Compute alignment reward as cosine similarity between generated and ground-truth feature embeddings. |
| get_diversity_rewards | Compute diversity penalty as mean pairwise dot-product similarity between samples from the same prompt (excluding self-similarity). |
| whiten_embeddings_batched | Whiten generated embeddings using SVD, then apply the same transform to the ground-truth embeddings. |

apply_embed_method

core.trainers.ebft.rewards.apply_embed_method(
    hidden_states,
    method,
    attention_mask=None,
    prompt_lengths=None,
)

Pool per-token hidden states into per-sequence embeddings.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| hidden_states | torch.Tensor | (B, S, D) concatenated hidden states | required |
| method | str | One of "last_token", "mean_pooling", "completion_mean", "concat" | required |
| attention_mask | torch.Tensor \| None | (B, S) mask for mean pooling | None |
| prompt_lengths | torch.Tensor \| None | (B,) number of prompt tokens per sample (for completion_mean) | None |

Returns

| Name | Type | Description |
| --- | --- | --- |
| | torch.Tensor | Sequence embeddings: (B, D) for last_token/mean_pooling/completion_mean, (B, 3*D) for concat |
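
As a rough illustration of the three single-vector pooling modes, here is a minimal PyTorch sketch. It is not the library's code: `pool_hidden_states` and `masked_mean` are hypothetical names, right padding is assumed for `last_token`, and `concat` is left out because only its output shape is documented here.

```python
import torch


def pool_hidden_states(hidden_states, method, attention_mask=None, prompt_lengths=None):
    """Illustrative pooling sketch (hypothetical helper, not the library function)."""
    B, S, D = hidden_states.shape
    device = hidden_states.device
    if attention_mask is None:
        attention_mask = torch.ones(B, S, dtype=torch.long, device=device)
    mask = attention_mask.to(hidden_states.dtype)

    def masked_mean(m):
        # Average over positions where m == 1; clamp avoids division by zero.
        total = (hidden_states * m.unsqueeze(-1)).sum(dim=1)
        return total / m.sum(dim=1, keepdim=True).clamp(min=1.0)

    if method == "last_token":
        # Assumption: right padding, so the last real token is at index length - 1.
        last_idx = attention_mask.sum(dim=1).long() - 1
        return hidden_states[torch.arange(B, device=device), last_idx]
    if method == "mean_pooling":
        return masked_mean(mask)
    if method == "completion_mean":
        # Keep non-padding positions at or after each sample's prompt boundary.
        positions = torch.arange(S, device=device).unsqueeze(0)
        completion = mask * (positions >= prompt_lengths.unsqueeze(1)).to(mask.dtype)
        return masked_mean(completion)
    raise ValueError(f"unknown method: {method}")
```

Each branch maps a (B, S, D) tensor to (B, D), matching the shapes listed in the Returns table.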

extract_hidden_states

core.trainers.ebft.rewards.extract_hidden_states(
    model,
    input_ids,
    attention_mask,
    layer_indices,
    batch_size=None,
)

Forward pass through model, extracting and concatenating hidden states at specified layer indices.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | nn.Module | The frozen feature network | required |
| input_ids | torch.Tensor | (B, S) token ids | required |
| attention_mask | torch.Tensor | (B, S) attention mask | required |
| layer_indices | list[int] | List of layer indices to extract (e.g., [8, 16, 24] for a 32-layer model) | required |
| batch_size | int \| None | If set, process in chunks to reduce peak memory | None |

Returns

| Name | Type | Description |
| --- | --- | --- |
| | torch.Tensor | Concatenated hidden states: (B, S, num_layers * H), where num_layers is the number of entries in layer_indices |
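
The chunked extraction described above could look roughly like the sketch below. It assumes the model accepts `output_hidden_states=True` and returns an object with a `hidden_states` tuple (the convention used by Hugging Face Transformers models); `ToyFeatureNet` and `extract_sketch` are hypothetical names used only to keep the example runnable without downloading weights.

```python
from types import SimpleNamespace

import torch
import torch.nn as nn


class ToyFeatureNet(nn.Module):
    """Hypothetical stand-in for a frozen feature network (not a real model)."""

    def __init__(self, vocab=50, hidden=8, n_layers=4):
        super().__init__()
        self.embed = nn.Embedding(vocab, hidden)
        self.layers = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(n_layers)])

    def forward(self, input_ids, attention_mask=None, output_hidden_states=False):
        h = self.embed(input_ids)
        states = [h]  # index 0 = embedding output, index i = output of layer i
        for layer in self.layers:
            h = torch.tanh(layer(h))
            states.append(h)
        return SimpleNamespace(hidden_states=tuple(states))


@torch.no_grad()
def extract_sketch(model, input_ids, attention_mask, layer_indices, batch_size=None):
    B = input_ids.shape[0]
    step = batch_size or B  # one chunk when batch_size is None
    chunks = []
    for start in range(0, B, step):
        out = model(
            input_ids=input_ids[start:start + step],
            attention_mask=attention_mask[start:start + step],
            output_hidden_states=True,
        )
        # Concatenate the selected layers along the feature dimension.
        chunks.append(torch.cat([out.hidden_states[i] for i in layer_indices], dim=-1))
    return torch.cat(chunks, dim=0)
```

With two selected layers of width 8, the result has shape (B, S, 16), and chunking changes peak memory but not the output.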

get_alignment_rewards

core.trainers.ebft.rewards.get_alignment_rewards(gen_embedding, gt_embedding)

Compute alignment reward as cosine similarity between generated and ground-truth feature embeddings.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| gen_embedding | torch.Tensor | (B, D) generated sequence embeddings | required |
| gt_embedding | torch.Tensor | (B, D) ground-truth sequence embeddings. If num_generations > 1, gt_embedding should be repeated to match gen_embedding's batch dimension. | required |

Returns

| Name | Type | Description |
| --- | --- | --- |
| | torch.Tensor | Alignment rewards: (B,) cosine similarities in [-1, 1] |
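
A minimal sketch of the row-wise cosine similarity, including the repetition of ground-truth embeddings when there are several generations per prompt. `alignment_sketch` is a hypothetical name, and the use of `repeat_interleave` assumes that the generations for one prompt are contiguous in the batch.

```python
import torch
import torch.nn.functional as F


def alignment_sketch(gen_embedding, gt_embedding):
    # Row-wise cosine similarity; both inputs are (B, D), output is (B,).
    return F.cosine_similarity(gen_embedding, gt_embedding, dim=-1)


num_prompts, num_generations, D = 2, 3, 4
gen = torch.randn(num_prompts * num_generations, D)
gt = torch.randn(num_prompts, D)
# Assumed layout: rows [0, 1, 2] belong to prompt 0, rows [3, 4, 5] to prompt 1.
gt_repeated = gt.repeat_interleave(num_generations, dim=0)
rewards = alignment_sketch(gen, gt_repeated)  # shape (6,)
```

If the batch were instead ordered generation-major, `gt.repeat(num_generations, 1)` would be the matching repetition.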

get_diversity_rewards

core.trainers.ebft.rewards.get_diversity_rewards(gen_embedding, num_generations)

Compute diversity penalty as mean pairwise dot-product similarity between samples from the same prompt (excluding self-similarity).

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| gen_embedding | torch.Tensor | (B, D) generated embeddings where B = num_prompts * num_generations | required |
| num_generations | int | Number of generations per prompt | required |

Returns

| Name | Type | Description |
| --- | --- | --- |
| | torch.Tensor | Diversity penalties: (B,) mean similarity to the other samples from the same prompt |
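
The mean pairwise dot product within each prompt group can be sketched as follows. `diversity_sketch` is a hypothetical name, and the reshape assumes that samples from the same prompt are contiguous in the batch.

```python
import torch


def diversity_sketch(gen_embedding, num_generations):
    B, D = gen_embedding.shape
    # Assumed layout: samples from the same prompt are contiguous.
    groups = gen_embedding.reshape(B // num_generations, num_generations, D)
    sim = torch.bmm(groups, groups.transpose(1, 2))  # (P, G, G) dot products
    # Remove the diagonal (self-similarity), then average over the other G - 1 samples.
    off_diag = sim.sum(dim=-1) - torch.diagonal(sim, dim1=1, dim2=2)
    return (off_diag / max(num_generations - 1, 1)).reshape(B)


emb = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 0.0]])
penalties = diversity_sketch(emb, num_generations=2)
```

In this example the first pair is orthogonal and the second pair has dot product 2, so the penalty is 0 for the first two rows and 2 for the last two.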

whiten_embeddings_batched

core.trainers.ebft.rewards.whiten_embeddings_batched(
    phi,
    phi_gt,
    whiten_tol=1e-05,
    normalize=False,
)

Whiten generated embeddings using SVD, then apply the same transform to the ground-truth embeddings.

Whitening decorrelates feature dimensions so that no single direction dominates the feature-matching loss. A pseudo-inverse is used for rank-deficient cases.

Note: Singular values scale with sqrt(B), so reward magnitudes depend on batch size. This is acceptable because B = n_samples_per_prompt, which is fixed during training (typically 2-4).

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| phi | torch.Tensor | (B, D) generated embeddings (used to estimate covariance) | required |
| phi_gt | torch.Tensor | (B, D) ground-truth embeddings | required |
| whiten_tol | float | Tolerance for singular value cutoff | 1e-05 |
| normalize | bool | If True, L2-normalize after whitening | False |

Returns

| Name | Type | Description |
| --- | --- | --- |
| | tuple[torch.Tensor, torch.Tensor] | Whitened (phi, phi_gt) tuple, each (B, D) |
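
One way to realize SVD-based whitening with a pseudo-inverse is sketched below. This is an illustration under stated assumptions, not the ported implementation: `whiten_sketch` is a hypothetical name, no mean-centering is applied, singular values are not rescaled by B, and the cutoff is taken relative to the largest singular value. The actual function may differ on any of these points.

```python
import torch
import torch.nn.functional as F


def whiten_sketch(phi, phi_gt, whiten_tol=1e-5, normalize=False):
    # Thin SVD of the generated embeddings: phi = U diag(s) Vh.
    U, s, Vh = torch.linalg.svd(phi, full_matrices=False)
    # Pseudo-inverse: invert only singular values above a relative cutoff (assumed rule).
    keep = s > whiten_tol * s.max()
    inv_s = torch.where(keep, 1.0 / s.clamp(min=1e-12), torch.zeros_like(s))
    # Whitening map in feature space, estimated from phi alone.
    W = Vh.T @ torch.diag(inv_s) @ Vh  # (D, D)
    # The same map is applied to both sets of embeddings.
    phi_w, phi_gt_w = phi @ W, phi_gt @ W
    if normalize:
        phi_w = F.normalize(phi_w, dim=-1)
        phi_gt_w = F.normalize(phi_gt_w, dim=-1)
    return phi_w, phi_gt_w
```

Under these assumptions, a full-rank `phi` with B < D yields whitened rows that are orthonormal (`phi_w @ phi_w.T` is the identity), while directions with near-zero singular values are dropped instead of being amplified.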