core.trainers.ebft.kernels

Fused Triton kernels for strided EBFT.

These kernels avoid materializing intermediate tensors. In profiling, the elementwise/fill category accounted for roughly 40% of CUDA time, and it was dominated by exactly these intermediate materializations.

Kernels

  1. fused_log_softmax_gather: log_softmax + gather in one pass (the full vocabulary-sized output is never materialized)
  2. fused_reinforce_loss: -logp * advantage * mask, reduced to a scalar and normalized by the mask sum
  3. fused_cosine_similarity: batched cosine similarity without intermediate tensors
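The single-pass idea behind fused_log_softmax_gather can be sketched in plain PyTorch as an online log-sum-exp over vocabulary chunks: each row keeps a running maximum and a rescaled running sum, so no (…, V) log-probability tensor is ever stored. This is a minimal sketch, assuming the Triton kernel follows the standard online-softmax recurrence; the function name chunked_log_softmax_gather and its chunk parameter are hypothetical and not part of this module.

```python
import torch


def chunked_log_softmax_gather(
    logits: torch.Tensor, labels: torch.Tensor, chunk: int = 1024
) -> torch.Tensor:
    """Hypothetical sketch: log_softmax + gather via an online log-sum-exp.

    logits: (R, V) float tensor; labels: (R,) int64 tensor.
    Returns (R,) float32 log-probabilities of the labelled tokens.
    """
    rows, vocab = logits.shape
    # Per-row running max and running sum of exp(logit - running_max).
    running_max = torch.full(
        (rows,), float("-inf"), dtype=torch.float32, device=logits.device
    )
    running_sum = torch.zeros(rows, dtype=torch.float32, device=logits.device)
    for start in range(0, vocab, chunk):
        # Only one chunk of the vocabulary is upcast and exponentiated at a time.
        block = logits[:, start : start + chunk].float()
        new_max = torch.maximum(running_max, block.max(dim=-1).values)
        # Rescale the old sum to the new max, then add this chunk's contribution.
        running_sum = running_sum * torch.exp(running_max - new_max) + torch.exp(
            block - new_max.unsqueeze(-1)
        ).sum(dim=-1)
        running_max = new_max
    log_z = running_max + torch.log(running_sum)  # log of the softmax normalizer
    # Read the target logit directly instead of gathering from a full log_softmax.
    target = logits.gather(-1, labels.unsqueeze(-1)).squeeze(-1).float()
    return target - log_z
```

A real fused kernel would run this loop inside one program per row rather than in Python, but the recurrence is the same: peak extra memory scales with the chunk size, not with V.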

Functions

| Name | Description |
|------|-------------|
| fused_cosine_similarity | Compute cosine similarity along the last dimension. |
| fused_diversity_penalty | Compute the mean pairwise dot product (excluding self) per sample. |
| fused_log_softmax_gather | Compute log_softmax(logits, dim=-1).gather(-1, labels) without materializing the full output. |
| fused_reinforce_loss | Compute masked REINFORCE loss: (-logp * adv * mask).sum() / mask.sum(). |

fused_cosine_similarity

core.trainers.ebft.kernels.fused_cosine_similarity(a, b)

Compute cosine similarity along the last dimension.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| a, b | torch.Tensor | Tensors of shape (…, D); both must have the same shape. | required |

Returns

| Name | Type | Description |
|------|------|-------------|
| | torch.Tensor | Tensor of shape (…,) holding the cosine similarities. |
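When testing the kernel, an unfused PyTorch reference is convenient. The sketch below assumes fused_cosine_similarity agrees with torch.nn.functional.cosine_similarity up to floating-point tolerance. How the kernel guards against zero-norm vectors is not documented here, so the eps clamp is an assumption, and reference_cosine_similarity is a hypothetical helper name.

```python
import torch


def reference_cosine_similarity(
    a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8
) -> torch.Tensor:
    """Hypothetical unfused reference: cosine similarity over the last dim of (..., D) inputs."""
    dot = (a * b).sum(dim=-1)  # materializes a (..., D) product tensor
    norms = a.norm(dim=-1) * b.norm(dim=-1)
    return dot / norms.clamp_min(eps)  # assumed guard against zero-norm inputs


torch.manual_seed(0)
a = torch.randn(2, 5, 16)
b = torch.randn(2, 5, 16)
sim = reference_cosine_similarity(a, b)
print(sim.shape)  # torch.Size([2, 5])
```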

fused_diversity_penalty

core.trainers.ebft.kernels.fused_diversity_penalty(embeddings)

Compute, for each sample, the mean dot product with the other N - 1 samples in the same batch entry (self-pairs excluded).

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| embeddings | torch.Tensor | Tensor of shape (B, N, D), where N is n_samples (the number of samples). | required |

Returns

| Name | Type | Description |
|------|------|-------------|
| | torch.Tensor | Tensor of shape (B, N) holding the diversity penalties. |
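Read literally, "mean pairwise dot product (excluding self)" means entry [b, i] is the average of embeddings[b, i] · embeddings[b, j] over the N - 1 indices j ≠ i. The sketch below is an unfused reference under that reading; it assumes N ≥ 2, and reference_diversity_penalty is a hypothetical name, not part of this module.

```python
import torch


def reference_diversity_penalty(embeddings: torch.Tensor) -> torch.Tensor:
    """Hypothetical unfused reference: (B, N, D) -> (B, N) mean dot product with the other samples."""
    n = embeddings.shape[1]
    if n < 2:
        raise ValueError("need at least two samples per batch entry")
    gram = embeddings @ embeddings.transpose(-1, -2)  # (B, N, N) pairwise dot products
    self_dots = gram.diagonal(dim1=-2, dim2=-1)  # (B, N) entries where i == j
    # Drop the self term from each row sum, then average over the N - 1 others.
    return (gram.sum(dim=-1) - self_dots) / (n - 1)


emb = torch.tensor([[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]])  # B=1, N=3, D=2
print(reference_diversity_penalty(emb))  # tensor([[0.5000, 0.5000, 1.0000]])
```

Here the third vector overlaps with both others (dot product 1 each), so it receives the largest penalty.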

fused_log_softmax_gather

core.trainers.ebft.kernels.fused_log_softmax_gather(logits, labels)

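Compute log_softmax(logits, dim=-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1), i.e. the log probability of each target label, without materializing the full vocabulary-sized log-softmax output.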

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| logits | torch.Tensor | Float tensor (bf16 or fp32) of shape (B, S, V) or (B*S, V). | required |
| labels | torch.Tensor | int64 tensor of target indices, shape (B, S) or (B*S,). | required |

Returns

| Name | Type | Description |
|------|------|-------------|
| | torch.Tensor | float32 tensor of selected log probabilities, shape (B, S) or (B*S,). |
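An unfused reference is useful for checking the kernel's output. Tensor.gather requires the index to have as many dimensions as the input, so the labels get a trailing unit dimension that is squeezed away afterwards. The upcast to float32 mirrors the documented return dtype; whether the kernel also accumulates in float32 internally is an assumption, and reference_log_softmax_gather is a hypothetical name.

```python
import torch


def reference_log_softmax_gather(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Hypothetical unfused reference; materializes the full (..., V) log-softmax."""
    logprobs = torch.log_softmax(logits.float(), dim=-1)  # (..., V) float32
    return logprobs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)  # (...) float32


torch.manual_seed(0)
logits = torch.randn(2, 4, 10).to(torch.bfloat16)  # (B, S, V)
labels = torch.randint(0, 10, (2, 4))  # (B, S), int64
out = reference_log_softmax_gather(logits, labels)
print(out.shape, out.dtype)  # torch.Size([2, 4]) torch.float32
```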

fused_reinforce_loss

core.trainers.ebft.kernels.fused_reinforce_loss(
    per_token_logps,
    advantages,
    action_mask,
)

Compute masked REINFORCE loss: (-logp * adv * mask).sum() / mask.sum().

All three inputs are flattened to 1-D if they are not already, so they are expected to contain the same number of elements. Returns a scalar loss tensor.
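The formula above can be written as a short unfused reference. This is a minimal sketch: reference_reinforce_loss is a hypothetical name, the float32 upcast is an assumption, and with an all-zero mask this version returns NaN (0 / 0); how the fused kernel handles that case is not documented here.

```python
import torch


def reference_reinforce_loss(
    per_token_logps: torch.Tensor,
    advantages: torch.Tensor,
    action_mask: torch.Tensor,
) -> torch.Tensor:
    """Hypothetical unfused reference for (-logp * adv * mask).sum() / mask.sum()."""
    logp = per_token_logps.reshape(-1).float()
    adv = advantages.reshape(-1).float()
    mask = action_mask.reshape(-1).float()  # bool or 0/1 masks both become 0.0/1.0
    # Yields NaN if the mask selects no tokens (0 / 0).
    return (-logp * adv * mask).sum() / mask.sum()


logp = torch.tensor([-1.0, -2.0, -3.0])
adv = torch.tensor([1.0, 1.0, 2.0])
mask = torch.tensor([1, 0, 1])
print(reference_reinforce_loss(logp, adv, mask))  # tensor(3.5000)
```

The masked-out middle token contributes nothing to the numerator, and the denominator counts only the two unmasked tokens: (1 + 6) / 2 = 3.5.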