core.trainers.ebft.strided
Strided block-parallel EBFT trainer for unstructured text data.
This trainer implements the full EBFT algorithm from the paper, including strided block-parallel generation where multiple short rollouts are generated at different anchor points within a single document. This is essential for training on raw text data (code, prose, etc.) without prompt/completion splits.
Uses torch flex_attention with a compiled block mask for efficient strided attention patterns. Falls back to eager attention with dense 4D masks when flex_attention is not available.
“Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models”
(Jelassi et al., 2026) https://arxiv.org/abs/2603.12248
Classes
| Name | Description |
|---|---|
| AxolotlStridedEBFTTrainer | Strided block-parallel EBFT trainer for unstructured text data. |
AxolotlStridedEBFTTrainer
core.trainers.ebft.strided.AxolotlStridedEBFTTrainer(
model,
args,
train_dataset,
**kwargs,
)
Strided block-parallel EBFT trainer for unstructured text data.
Takes full text documents (no prompt/completion split needed), generates short rollouts at multiple anchor points via strided attention, and trains with feature-matching rewards.
When flex_attention is available (torch >= 2.5), uses compiled block masks for efficient fused attention kernels. Otherwise falls back to eager attention with dense 4D masks.
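The availability check described above can be sketched as follows. This is a minimal illustration, not the trainer's actual logic: the helper name `pick_attn_implementation` is hypothetical, and the sketch only checks whether the `torch.nn.attention.flex_attention` module can be located, which is one plausible way to decide between the two paths.

```python
import importlib.util


def pick_attn_implementation():
    """Hypothetical helper: return "flex_attention" if the module exists, else "eager"."""
    try:
        # find_spec imports the parent packages; it raises ModuleNotFoundError
        # (a subclass of ImportError) if torch or torch.nn.attention is missing,
        # and returns None if only the flex_attention submodule is missing.
        spec = importlib.util.find_spec("torch.nn.attention.flex_attention")
    except ImportError:
        spec = None
    return "flex_attention" if spec is not None else "eager"


implementation = pick_attn_implementation()
print(f"Selected attention implementation: {implementation}")
```

A real check may also need to account for device or kernel support, which this sketch does not attempt.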
Methods
| Name | Description |
|---|---|
| compute_loss | Full strided EBFT training step. |
compute_loss
core.trainers.ebft.strided.AxolotlStridedEBFTTrainer.compute_loss(
model,
inputs,
return_outputs=False,
num_items_in_batch=None,
)
Full strided EBFT training step.
- Take tokenized documents from inputs
- Generate n_samples short rollouts at strided anchor points
- Extract features from frozen network for both generated and GT blocks
- Compute alignment/diversity rewards per block
- Compute RLOO advantages
- Policy gradient loss on the strided forward pass
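The RLOO (REINFORCE leave-one-out) advantage step in the list above can be illustrated in plain Python. This is a sketch of the standard leave-one-out baseline, assuming the rewards passed in belong to the `n_samples` rollouts of a single block; the function name `rloo_advantages` is hypothetical, and how the trainer groups or normalizes rewards is not specified here.

```python
def rloo_advantages(rewards):
    """Leave-one-out advantages: each reward minus the mean of the other rewards."""
    n = len(rewards)
    if n < 2:
        raise ValueError("RLOO needs at least two samples per group")
    total = sum(rewards)
    # Baseline for sample i is the mean reward of the other n - 1 samples.
    return [r - (total - r) / (n - 1) for r in rewards]


# Three rollouts of the same block (assumed grouping) with scalar rewards.
advantages = rloo_advantages([1.0, 2.0, 3.0])
print(advantages)  # [-1.5, 0.0, 1.5]
```

Because each baseline excludes the sample's own reward, the advantages within a group sum to zero.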
Supports both unstructured text (no prompt/completion split) and structured data (prompt + completion with labels masking). For structured data, anchors are placed only within the completion span.
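As a rough illustration of anchor placement, the sketch below enumerates evenly spaced anchor positions and restricts them to a completion span when one is given. Everything here is an assumption made for illustration: the helper `strided_anchor_positions`, its parameters, and the rule that each anchor must leave room for `gen_len` ground-truth tokens are hypothetical and may not match how the trainer selects anchors.

```python
def strided_anchor_positions(seq_len, stride, gen_len, start=0):
    """Hypothetical helper: anchors every `stride` tokens from `start`,
    keeping only anchors with at least `gen_len` tokens after them."""
    if stride <= 0 or gen_len <= 0:
        raise ValueError("stride and gen_len must be positive")
    last_valid = seq_len - gen_len  # last anchor that still fits a full block
    return list(range(start, last_valid + 1, stride))


# Unstructured text: anchors may fall anywhere in the document.
print(strided_anchor_positions(seq_len=16, stride=4, gen_len=4))

# Structured data: assume the completion starts at token 8, so anchors
# are only placed from that position onward.
print(strided_anchor_positions(seq_len=16, stride=4, gen_len=4, start=8))
```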
Functions
| Name | Description |
|---|---|
| build_strided_dense_mask_and_positions | Build dense 4D attention mask (eager fallback) + position IDs. |
| build_strided_position_ids | Build position IDs for strided generation (shared between flex and eager modes). |
| create_strided_block_mask | Create a BlockMask for flex_attention using the strided EBFT pattern. |
| override_attn_implementation | Temporarily override a model’s attention implementation. |
build_strided_dense_mask_and_positions
core.trainers.ebft.strided.build_strided_dense_mask_and_positions(
full_sequence_length,
prompt_length,
context_length,
generation_step,
max_generation_length,
stride,
num_blocks,
device,
batch_size=1,
dtype=torch.bfloat16,
)
Build dense 4D attention mask (eager fallback) + position IDs.
build_strided_position_ids
core.trainers.ebft.strided.build_strided_position_ids(
full_sequence_length,
prompt_length,
context_length,
generation_step,
stride,
num_blocks,
device,
batch_size=1,
)
Build position IDs for strided generation (shared between flex and eager modes).
create_strided_block_mask
core.trainers.ebft.strided.create_strided_block_mask(
prompt_length,
context_length,
max_generation_length,
stride,
num_blocks,
full_sequence_length,
batch_size,
num_heads,
device,
)
Create a BlockMask for flex_attention using the strided EBFT pattern.
Returns a BlockMask that can be passed directly to model.forward() when using attn_implementation="flex_attention".
Parameters that vary across training steps (context_length, num_blocks) are captured as tensors so that torch.compile/dynamo treats them as dynamic values instead of guarding on literal int values, which would trigger a recompile whenever they change.
override_attn_implementation
core.trainers.ebft.strided.override_attn_implementation(model, implementation)
Temporarily override a model’s attention implementation.
Useful for forcing eager attention during generation (where sequence lengths change each step, causing dynamo recompiles) while keeping flex_attention for the fixed-size training forward pass.
Usage::
    with override_attn_implementation(model, "eager"):
        output = model(input_ids, attention_mask=dense_4d_mask, ...)
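For readers unfamiliar with the pattern, a temporary override like this is typically written as a context manager that restores the previous value on exit. The sketch below is not the library's implementation: `override_attn_sketch` is a hypothetical name, and storing the setting on `model.config._attn_implementation` is an assumption, demonstrated on a dummy object instead of a real model.

```python
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def override_attn_sketch(model, implementation):
    """Hypothetical sketch: swap the attention setting, restore it on exit."""
    config = model.config
    previous = config._attn_implementation  # assumed attribute name
    config._attn_implementation = implementation
    try:
        yield model
    finally:
        # Restore even if the body raises, so training keeps its original setting.
        config._attn_implementation = previous


# Dummy stand-in for a model; a real model carries far more state.
dummy = SimpleNamespace(config=SimpleNamespace(_attn_implementation="flex_attention"))

with override_attn_sketch(dummy, "eager"):
    inside = dummy.config._attn_implementation

after = dummy.config._attn_implementation
print(inside, after)  # eager flex_attention
```

The `try`/`finally` is the important design choice: the original implementation is restored even when generation fails partway through.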