core.trainers.ebft.strided

Strided block-parallel EBFT trainer for unstructured text data.

This trainer implements the full EBFT (Energy-Based Fine-Tuning) algorithm from the paper cited below, including strided block-parallel generation: multiple short rollouts are generated at different anchor points within a single document. This is what makes it possible to train on raw text (code, prose, etc.) that has no prompt/completion split.
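As a rough illustration of where rollouts start, the sketch below lays out anchor points in one tokenized document. It assumes anchors begin after `context_length` tokens, advance by `stride`, and stop once a full ground-truth block of `gen_length` tokens no longer fits. `strided_anchors` and `gen_length` are hypothetical names, and the trainer's exact boundary handling may differ.

```python
def strided_anchors(doc_len: int, context_length: int, stride: int, gen_length: int) -> list[int]:
    """Hypothetical helper: anchor positions for strided rollouts.

    An anchor `a` means "condition on doc[:a] and generate gen_length tokens",
    so the matching ground-truth block is doc[a:a + gen_length].
    """
    if stride <= 0 or gen_length <= 0:
        raise ValueError("stride and gen_length must be positive")
    # The last valid anchor must leave room for a full ground-truth block.
    return list(range(context_length, doc_len - gen_length + 1, stride))


doc = list(range(100, 116))  # stand-in for 16 token ids
anchors = strided_anchors(len(doc), context_length=4, stride=4, gen_length=3)
gt_blocks = [doc[a:a + 3] for a in anchors]
print(anchors)    # [4, 8, 12]
print(gt_blocks)  # [[104, 105, 106], [108, 109, 110], [112, 113, 114]]
```

Because every anchor reuses the same document prefix, all blocks can be scored in a single forward pass instead of one pass per anchor.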

Uses PyTorch flex_attention with a compiled block mask to run the strided attention pattern efficiently. When flex_attention is not available, it falls back to eager attention with dense 4D masks.

Reference: “Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models” (Jelassi et al., 2026), https://arxiv.org/abs/2603.12248

Classes

| Name | Description |
|---|---|
| AxolotlStridedEBFTTrainer | Strided block-parallel EBFT trainer for unstructured text data. |

AxolotlStridedEBFTTrainer

core.trainers.ebft.strided.AxolotlStridedEBFTTrainer(
    model,
    args,
    train_dataset,
    **kwargs,
)

Strided block-parallel EBFT trainer for unstructured text data.

Takes full text documents (no prompt/completion split needed), generates short rollouts at multiple anchor points via strided attention, and trains with feature-matching rewards.

When flex_attention is available (torch >= 2.5), uses compiled block masks for efficient fused attention kernels. Otherwise falls back to eager attention with dense 4D masks.
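One way to express this availability check is to probe for the flex_attention module, which PyTorch ships under `torch.nn.attention.flex_attention` starting with 2.5. The helper below is a hypothetical sketch; the trainer's own check may also consult the model's configured attention implementation.

```python
def pick_attn_implementation() -> str:
    """Hypothetical: "flex_attention" if this torch build provides it, else "eager"."""
    try:
        # Public module path in PyTorch 2.5+; older builds (or no torch) raise ImportError.
        from torch.nn.attention.flex_attention import flex_attention  # noqa: F401
    except ImportError:
        return "eager"
    return "flex_attention"


print(pick_attn_implementation())
```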

Methods

| Name | Description |
|---|---|
| compute_loss | Full strided EBFT training step. |

compute_loss
core.trainers.ebft.strided.AxolotlStridedEBFTTrainer.compute_loss(
    model,
    inputs,
    return_outputs=False,
    num_items_in_batch=None,
)

Full strided EBFT training step.

  1. Take the tokenized documents from inputs
  2. Generate n_samples short rollouts at strided anchor points
  3. Extract features from the frozen feature network for both generated and ground-truth (GT) blocks
  4. Compute alignment and diversity rewards for each block
  5. Compute RLOO (REINFORCE leave-one-out) advantages
  6. Compute the policy-gradient loss on the strided forward pass
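Steps 4 and 5 can be sketched in plain Python for one anchor block with several sampled rollouts. The reward form here (cosine alignment with the GT features minus a weighted mean cosine similarity to sibling rollouts) and the `diversity_weight` parameter are assumptions for illustration, not the trainer's exact formula. The advantage computation is the standard RLOO leave-one-out baseline.

```python
import math


def cosine(u: list[float], v: list[float]) -> float:
    """Cosine similarity; returns 0.0 if either vector is all zeros."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm > 0 else 0.0


def block_rewards(gen_feats, gt_feat, diversity_weight=1.0):
    """Hypothetical per-block reward: alignment minus a diversity penalty."""
    rewards = []
    for i, f in enumerate(gen_feats):
        alignment = cosine(f, gt_feat)
        # Mean similarity to the other rollouts generated at the same anchor.
        others = [cosine(f, g) for j, g in enumerate(gen_feats) if j != i]
        similarity = sum(others) / len(others) if others else 0.0
        rewards.append(alignment - diversity_weight * similarity)
    return rewards


def rloo_advantages(rewards):
    """A_i = r_i - mean of the other n - 1 rewards (leave-one-out baseline)."""
    n = len(rewards)
    if n < 2:
        raise ValueError("RLOO needs at least two samples per block")
    total = sum(rewards)
    return [r - (total - r) / (n - 1) for r in rewards]


print(block_rewards([[1.0, 0.0], [0.0, 1.0]], [1.0, 0.0]))  # [1.0, 0.0]
print(rloo_advantages([1.0, 2.0, 3.0]))  # [-1.5, 0.0, 1.5]
```

The leave-one-out baseline needs no learned value model, and the advantages within a block always sum to zero.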

Supports both unstructured text (no prompt/completion split) and structured data (prompt + completion with label masking). For structured data, anchors are placed only within the completion span.
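For structured data, one way to keep anchors inside the completion is to read the span off the labels. This sketch assumes the common Hugging Face convention that masked prompt positions carry label `-100` and that labels are aligned with input positions; `completion_anchors` and its boundary rules are hypothetical.

```python
IGNORE_INDEX = -100  # common Hugging Face convention for masked labels


def completion_anchors(labels, context_length, stride, gen_length):
    """Hypothetical: strided anchors restricted to the unmasked (completion) span."""
    supervised = [i for i, y in enumerate(labels) if y != IGNORE_INDEX]
    if not supervised:
        return []
    start, end = supervised[0], supervised[-1] + 1  # completion is tokens[start:end]
    # Never anchor inside the prompt, nor before context_length.
    first = max(start, context_length)
    # Each ground-truth block tokens[a:a + gen_length] must fit inside the completion.
    return list(range(first, end - gen_length + 1, stride))


labels = [IGNORE_INDEX] * 6 + [11, 12, 13, 14, 15, 16, 17, 18]
print(completion_anchors(labels, context_length=2, stride=3, gen_length=2))  # [6, 9, 12]
```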

Functions

| Name | Description |
|---|---|
| build_strided_dense_mask_and_positions | Build a dense 4D attention mask (eager fallback) and position IDs. |
| build_strided_position_ids | Build position IDs for strided generation (shared between flex and eager modes). |
| create_strided_block_mask | Create a BlockMask for flex_attention using the strided EBFT pattern. |
| override_attn_implementation | Temporarily override a model’s attention implementation. |

build_strided_dense_mask_and_positions

core.trainers.ebft.strided.build_strided_dense_mask_and_positions(
    full_sequence_length,
    prompt_length,
    context_length,
    generation_step,
    max_generation_length,
    stride,
    num_blocks,
    device,
    batch_size=1,
    dtype=torch.bfloat16,
)

Build a dense 4D attention mask (for the eager fallback) and the matching position IDs.
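The eager fallback needs an explicit additive mask. Assuming the Hugging Face convention of a `[batch, 1, query, key]` mask that holds `0.0` where attention is allowed and a large negative value where it is blocked, the sketch below turns any boolean `allowed(q, k)` rule into that shape using nested lists. `dense_additive_mask` is hypothetical; the real function builds a torch tensor of the requested `dtype` and also returns position IDs.

```python
import math


def dense_additive_mask(allowed, seq_len, batch_size=1, masked_value=-math.inf):
    """Hypothetical: boolean rule -> additive 4D mask indexed [batch][head][q][k]."""
    plane = [
        [0.0 if allowed(q, k) else masked_value for k in range(seq_len)]
        for q in range(seq_len)
    ]
    # A single head dimension (broadcast over heads); one copy per batch element.
    return [[[row[:] for row in plane]] for _ in range(batch_size)]


def causal(q, k):
    return k <= q


mask = dense_additive_mask(causal, seq_len=3, batch_size=2)
print(len(mask), len(mask[0]), len(mask[0][0]), len(mask[0][0][0]))  # 2 1 3 3
print(mask[0][0][0])  # [0.0, -inf, -inf]
```

Implementations often use `torch.finfo(dtype).min` instead of `-inf` so that a fully masked row cannot produce NaNs in softmax, which is why the sketch exposes `masked_value`.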

build_strided_position_ids

core.trainers.ebft.strided.build_strided_position_ids(
    full_sequence_length,
    prompt_length,
    context_length,
    generation_step,
    stride,
    num_blocks,
    device,
    batch_size=1,
)

Build position IDs for strided generation (shared between flex and eager modes).
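Position IDs are what let each generated block behave as if it continued the document at its own anchor. A minimal sketch, assuming the packed layout is the document (positions `0..P-1`) followed by `num_blocks` blocks of `tokens_per_block` generated tokens, with block `b` anchored at `context_length + b * stride`. `strided_position_ids` and `tokens_per_block` are hypothetical; the real function takes `generation_step`, returns one row per batch element, and may define the layout differently.

```python
def strided_position_ids(prompt_length, context_length, tokens_per_block, stride, num_blocks):
    """Hypothetical: position IDs for [document | block_0 | block_1 | ...]."""
    ids = list(range(prompt_length))  # document tokens keep their ordinary positions
    for b in range(num_blocks):
        anchor = context_length + b * stride
        # Generated token j of block b takes the position doc[anchor + j] would have.
        ids.extend(anchor + j for j in range(tokens_per_block))
    return ids


print(strided_position_ids(prompt_length=8, context_length=2, tokens_per_block=2, stride=3, num_blocks=2))
# [0, 1, 2, 3, 4, 5, 6, 7, 2, 3, 5, 6]
```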

create_strided_block_mask

core.trainers.ebft.strided.create_strided_block_mask(
    prompt_length,
    context_length,
    max_generation_length,
    stride,
    num_blocks,
    full_sequence_length,
    batch_size,
    num_heads,
    device,
)

Create a BlockMask for flex_attention using the strided EBFT pattern.

Returns a BlockMask that can be passed directly to model.forward() when using attn_implementation="flex_attention".

Parameters that vary across training steps (context_length, num_blocks) are captured as tensors, so torch.compile/dynamo treats them as dynamic values instead of guarding on their literal int values, which would trigger a recompile whenever they change.
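The strided pattern itself can be written as a flex_attention-style `mask_mod(b, h, q_idx, kv_idx)`. The sketch assumes the packed layout of the document followed by generated blocks of `gen_length` tokens, block `b` anchored at `context_length + b * stride`. It uses only comparisons, `//`, `&`, and `|`, so the same body runs on plain ints (as shown) and should also work elementwise on the index tensors flex_attention supplies. `make_strided_mask_mod` is hypothetical, not the module's implementation.

```python
def make_strided_mask_mod(prompt_length, context_length, gen_length, stride):
    """Hypothetical mask_mod for [document | block_0 | ... | block_{n-1}]."""
    P, C, G, S = prompt_length, context_length, gen_length, stride

    def mask_mod(b, h, q_idx, kv_idx):
        q_doc, kv_doc = q_idx < P, kv_idx < P
        q_gen, kv_gen = q_idx >= P, kv_idx >= P
        # Document tokens: plain causal attention over the document.
        doc_causal = q_doc & kv_doc & (kv_idx <= q_idx)
        # Block index of a generated token, and the anchor of the query's block.
        q_block = (q_idx - P) // G
        kv_block = (kv_idx - P) // G
        anchor = C + q_block * S
        # Generated tokens see the document strictly before their anchor...
        gen_to_prefix = q_gen & kv_doc & (kv_idx < anchor)
        # ...plus earlier (and current) tokens of their own block only.
        gen_to_block = q_gen & kv_gen & (q_block == kv_block) & (kv_idx <= q_idx)
        return doc_causal | gen_to_prefix | gen_to_block

    return mask_mod


mask_mod = make_strided_mask_mod(prompt_length=6, context_length=2, gen_length=2, stride=2)
for q in range(10):
    print("".join("x" if mask_mod(0, 0, q, k) else "." for k in range(10)))
# Rows 0-5 are the causal document; rows 6-9 are two generated blocks:
# xx....x...
# xx....xx..
# xxxx....x.
# xxxx....xx
```

With PyTorch, a function of this shape is what `torch.nn.attention.flex_attention.create_block_mask(mask_mod, B, H, Q_LEN, KV_LEN, device=...)` consumes to build a BlockMask.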

override_attn_implementation

core.trainers.ebft.strided.override_attn_implementation(model, implementation)

Temporarily override a model’s attention implementation.

Useful for forcing eager attention during generation (where sequence lengths change each step, causing dynamo recompiles) while keeping flex_attention for the fixed-size training forward pass.

Usage:

with override_attn_implementation(model, "eager"):
    # Pass any other forward kwargs (e.g. position_ids) as keyword arguments.
    output = model(input_ids, attention_mask=dense_4d_mask)