core.trainers.ebft.rewards
Feature-matching reward utilities for Energy-Based Fine-Tuning (EBFT).
Ported from: feature-002/ebft_openrlhf/openrlhf/utils/embedding_utils.py
Paper: “Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models” (Jelassi et al., 2026), https://arxiv.org/abs/2603.12248
Functions
| Name | Description |
|---|---|
| apply_embed_method | Pool per-token hidden states into per-sequence embeddings. |
| extract_hidden_states | Forward pass through model, extracting and concatenating hidden states |
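| get_alignment_rewards | Compute alignment reward as cosine similarity between generated and ground-truth feature embeddings. |
| get_diversity_rewards | Compute diversity penalty as mean pairwise dot-product similarity between samples from the same prompt (excluding self-similarity). |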
| whiten_embeddings_batched | Whiten generated embeddings using SVD, then apply same transform to ground-truth. |
apply_embed_method
core.trainers.ebft.rewards.apply_embed_method(
hidden_states,
method,
attention_mask=None,
prompt_lengths=None,
)
Pool per-token hidden states into per-sequence embeddings.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_states | torch.Tensor | (B, S, D) concatenated hidden states | required |
| method | str | One of "last_token", "mean_pooling", "completion_mean", "concat" | required |
| attention_mask | torch.Tensor \| None | (B, S) mask for mean pooling | None |
| prompt_lengths | torch.Tensor \| None | (B,) number of prompt tokens per sample (for completion_mean) | None |
Returns
| Name | Type | Description |
|---|---|---|
| | torch.Tensor | Sequence embeddings: (B, D) for last_token/mean_pooling/completion_mean, (B, 3*D) for concat |
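The pooling modes described above can be illustrated with a minimal sketch. This is not the library's implementation: `pool_sketch` is a hypothetical name, right padding is assumed for `last_token`, and `concat` is left out because this page does not say which three components it concatenates.

```python
import torch


def pool_sketch(hidden_states, method, attention_mask=None, prompt_lengths=None):
    """Illustrative pooling of (B, S, D) hidden states into (B, D) embeddings.

    Hypothetical helper, not the library function.
    """
    B, S, D = hidden_states.shape
    device = hidden_states.device
    if attention_mask is None:
        # No mask given: treat every position as a real token.
        attention_mask = torch.ones(B, S, dtype=torch.long, device=device)
    mask = attention_mask.to(hidden_states.dtype)

    if method == "last_token":
        # Assumption: right padding, so the last real token is at sum(mask) - 1.
        last_idx = attention_mask.sum(dim=1).long() - 1
        return hidden_states[torch.arange(B, device=device), last_idx]

    if method == "completion_mean":
        if prompt_lengths is None:
            raise ValueError("completion_mean needs prompt_lengths")
        # Assumption: the first prompt_lengths[b] positions are prompt tokens
        # and are excluded from the average.
        positions = torch.arange(S, device=device).unsqueeze(0)  # (1, S)
        mask = mask * (positions >= prompt_lengths.unsqueeze(1)).to(mask.dtype)
    elif method != "mean_pooling":
        raise ValueError(f"unsupported method in this sketch: {method}")

    # Masked mean over the sequence dimension.
    summed = (hidden_states * mask.unsqueeze(-1)).sum(dim=1)
    counts = mask.sum(dim=1, keepdim=True).clamp(min=1.0)
    return summed / counts


hidden = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)  # (B=2, S=3, D=4)
mask = torch.tensor([[1, 1, 1], [1, 1, 0]])  # second sample has one pad token
prompt_lengths = torch.tensor([1, 1])

last = pool_sketch(hidden, "last_token", mask)
mean = pool_sketch(hidden, "mean_pooling", mask)
comp = pool_sketch(hidden, "completion_mean", mask, prompt_lengths)
```

All three results have shape (2, 4). With left-padded batches the `last_token` index computed here would be wrong, so check the padding side before reusing the idea.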
get_alignment_rewards
core.trainers.ebft.rewards.get_alignment_rewards(gen_embedding, gt_embedding)
Compute alignment reward as cosine similarity between generated and ground-truth feature embeddings.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| gen_embedding | torch.Tensor | (B, D) generated sequence embeddings | required |
| gt_embedding | torch.Tensor | (B, D) ground-truth sequence embeddings. If num_generations > 1, gt_embedding should be repeated to match gen_embedding’s batch dim. | required |
Returns
| Name | Type | Description |
|---|---|---|
| | torch.Tensor | Alignment rewards: (B,) cosine similarities in [-1, 1] |
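A minimal sketch of the row-wise cosine similarity, including the repetition of ground-truth embeddings when there are several generations per prompt. `alignment_sketch` is a hypothetical stand-in for the library function, and the contiguous layout (all generations of one prompt adjacent in the batch) is an assumption.

```python
import torch
import torch.nn.functional as F


def alignment_sketch(gen_embedding, gt_embedding):
    """Row-wise cosine similarity between two (B, D) tensors; returns (B,)."""
    return F.cosine_similarity(gen_embedding, gt_embedding, dim=-1)


num_generations = 2
gt = torch.tensor([[1.0, 0.0], [0.0, 1.0]])  # (num_prompts=2, D=2)
gen = torch.tensor(
    [[2.0, 0.0], [0.0, 3.0],   # two generations for prompt 0
     [0.0, -1.0], [1.0, 1.0]]  # two generations for prompt 1
)

# Assumption: generations for the same prompt are contiguous, so each
# ground-truth row is repeated num_generations times in place.
gt_repeated = gt.repeat_interleave(num_generations, dim=0)  # (4, 2)
rewards = alignment_sketch(gen, gt_repeated)  # shape (4,)
```

Because cosine similarity ignores vector length, the first generation scores 1 even though its norm is twice that of its ground truth.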
get_diversity_rewards
core.trainers.ebft.rewards.get_diversity_rewards(gen_embedding, num_generations)
Compute diversity penalty as mean pairwise dot-product similarity between samples from the same prompt (excluding self-similarity).
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| gen_embedding | torch.Tensor | (B, D) generated embeddings where B = num_prompts * num_generations | required |
| num_generations | int | Number of generations per prompt | required |
Returns
| Name | Type | Description |
|---|---|---|
| | torch.Tensor | Diversity penalties: (B,) mean similarity to other samples from same prompt |
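The penalty described above can be sketched as follows. `diversity_sketch` is a hypothetical name, and two points are assumptions rather than documented behavior: samples for the same prompt are taken to be contiguous in the batch, and embeddings are used as given (no normalization before the dot product).

```python
import torch


def diversity_sketch(gen_embedding, num_generations):
    """Mean dot product of each sample with the other samples of its prompt."""
    B, D = gen_embedding.shape
    if num_generations < 2:
        # A single sample per prompt has no peers to compare against.
        return torch.zeros(B, dtype=gen_embedding.dtype, device=gen_embedding.device)

    # Assumption: rows are grouped by prompt, so reshape to (P, G, D).
    groups = gen_embedding.reshape(-1, num_generations, D)
    sims = torch.bmm(groups, groups.transpose(1, 2))  # (P, G, G) dot products

    # Zero the diagonal to exclude each sample's similarity with itself.
    eye = torch.eye(num_generations, dtype=torch.bool, device=gen_embedding.device)
    sims = sims.masked_fill(eye, 0.0)

    # Average over the G - 1 other samples, then flatten back to (B,).
    return (sims.sum(dim=-1) / (num_generations - 1)).reshape(B)


gen = torch.tensor(
    [[1.0, 0.0], [0.0, 1.0],  # prompt 0: orthogonal samples
     [1.0, 1.0], [1.0, 1.0]]  # prompt 1: identical samples
)
penalties = diversity_sketch(gen, num_generations=2)
```

Orthogonal samples receive a penalty of 0, while the two identical samples each receive their shared dot product of 2, so collapsed generations are penalized more.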
whiten_embeddings_batched
core.trainers.ebft.rewards.whiten_embeddings_batched(
phi,
phi_gt,
whiten_tol=1e-05,
normalize=False,
)
Whiten generated embeddings using SVD, then apply same transform to ground-truth.
Whitening decorrelates feature dimensions so no single direction dominates the feature-matching loss. A pseudo-inverse is used for rank-deficient cases.
Note: Singular values scale with sqrt(B), so reward magnitudes are batch-size dependent. This is acceptable because B = n_samples_per_prompt, which is fixed during training (typically 2-4).
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| phi | torch.Tensor | (B, D) generated embeddings (used to estimate covariance) | required |
| phi_gt | torch.Tensor | (B, D) ground-truth embeddings | required |
| whiten_tol | float | Tolerance for singular value cutoff | 1e-05 |
| normalize | bool | If True, L2-normalize after whitening | False |
Returns
| Name | Type | Description |
|---|---|---|
| | tuple[torch.Tensor, torch.Tensor] | Whitened (phi, phi_gt) tuple, each (B, D) |
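One common way to whiten with an SVD and a pseudo-inverse is sketched below. It is not necessarily what the library does: `whiten_sketch` is a hypothetical name, the data is left uncentered, and `whiten_tol` is treated as a cutoff relative to the largest singular value, all of which are assumptions.

```python
import torch
import torch.nn.functional as F


def whiten_sketch(phi, phi_gt, whiten_tol=1e-5, normalize=False):
    """Whiten phi via SVD and apply the same (D, D) transform to phi_gt."""
    # Thin SVD: phi = U @ diag(S) @ Vh, with Vh of shape (min(B, D), D).
    U, S, Vh = torch.linalg.svd(phi, full_matrices=False)

    # Pseudo-inverse of the singular values: invert those above the cutoff
    # and zero the rest. Assumption: the cutoff is relative to the largest one.
    keep = S > whiten_tol * S.max()
    S_safe = torch.where(keep, S, torch.ones_like(S))  # avoid division by zero
    S_inv = torch.where(keep, 1.0 / S_safe, torch.zeros_like(S))

    # Whitening matrix W = V @ diag(S_inv) @ V^T, shape (D, D).
    W = (Vh.T * S_inv) @ Vh

    phi_w = phi @ W
    phi_gt_w = phi_gt @ W
    if normalize:
        phi_w = F.normalize(phi_w, dim=-1)
        phi_gt_w = F.normalize(phi_gt_w, dim=-1)
    return phi_w, phi_gt_w


torch.manual_seed(0)
phi = torch.randn(3, 5)     # B=3 generated embeddings, D=5
phi_gt = torch.randn(3, 5)  # matching ground-truth embeddings
phi_w, phi_gt_w = whiten_sketch(phi, phi_gt)

# For full-rank phi with B <= D, the whitened rows are orthonormal.
gram = phi_w @ phi_w.T
```

In this formulation `phi @ W` reduces to `U @ Vh` restricted to the retained directions, which is why the Gram matrix is close to the identity. Only `phi` determines the transform; `phi_gt` is merely projected through it.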