core.trainers.ebft.kernels
Fused Triton kernels for strided EBFT.
These kernels eliminate intermediate tensor materializations that dominate the elementwise/fill category (~40% of CUDA time in profiling).
Kernels
- fused_log_softmax_gather: log_softmax + gather in one pass (no full vocab materialization)
- fused_reinforce_loss: -logp * advantage * mask, reduced to scalar
- fused_cosine_similarity: batched cosine similarity without intermediate tensors
Functions
| Name | Description |
|---|---|
| fused_cosine_similarity | Compute cosine similarity along the last dimension. |
| fused_diversity_penalty | Compute mean pairwise dot product (excluding self) per sample. |
| fused_log_softmax_gather | Compute log_softmax(logits, dim=-1).gather(-1, labels) without materializing full output. |
| fused_reinforce_loss | Compute masked REINFORCE loss: (-logp * adv * mask).sum() / mask.sum(). |
fused_cosine_similarity
core.trainers.ebft.kernels.fused_cosine_similarity(a, b)
Compute cosine similarity along the last dimension.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| a, b | | (…, D) tensors of the same shape | required |
Returns
| Name | Type | Description |
|---|---|---|
| | torch.Tensor | (…,) tensor of cosine similarities |
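For comparison, here is a minimal unfused sketch of the same computation in plain PyTorch. The name `reference_cosine_similarity` and the `eps` clamp are assumptions made for illustration; how the fused kernel handles zero-norm vectors is not documented here.

```python
import torch
import torch.nn.functional as F


def reference_cosine_similarity(
    a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8
) -> torch.Tensor:
    """Hypothetical unfused reference: cosine similarity along the last dim."""
    dot = (a * b).sum(dim=-1)   # (...,) elementwise product, then reduce over D
    norm_a = a.norm(dim=-1)     # (...,) L2 norm of each vector in a
    norm_b = b.norm(dim=-1)     # (...,) L2 norm of each vector in b
    # eps guards against division by zero; this guard is an assumption.
    return dot / (norm_a * norm_b).clamp_min(eps)


torch.manual_seed(0)
a = torch.randn(2, 3, 16)
b = torch.randn(2, 3, 16)
sim = reference_cosine_similarity(a, b)   # shape (2, 3)
expected = F.cosine_similarity(a, b, dim=-1)
```

Each intermediate here (`a * b`, the two norms, their product) is a separate tensor, which is the kind of materialization the fused version is described as avoiding.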
fused_diversity_penalty
core.trainers.ebft.kernels.fused_diversity_penalty(embeddings)
Compute mean pairwise dot product (excluding self) per sample.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| embeddings | torch.Tensor | (B, N, D) tensor where N is n_samples | required |
Returns
| Name | Type | Description |
|---|---|---|
| | torch.Tensor | (B, N) tensor of diversity penalties |
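One way to read "mean pairwise dot product (excluding self)" is sketched below in plain PyTorch. The name `reference_diversity_penalty`, the division by N - 1, and the handling of N = 1 are assumptions inferred from the description, not confirmed behavior of the fused kernel.

```python
import torch


def reference_diversity_penalty(embeddings: torch.Tensor) -> torch.Tensor:
    """Hypothetical unfused reference for a (B, N, D) input, returning (B, N)."""
    _, n, _ = embeddings.shape
    # Gram matrix of all pairwise dot products within each batch element.
    gram = embeddings @ embeddings.transpose(1, 2)   # (B, N, N)
    self_dots = gram.diagonal(dim1=1, dim2=2)        # (B, N), the i == j terms
    # Sum over the other samples, then average over the N - 1 of them.
    # max(..., 1) avoids division by zero when N == 1 (an assumption).
    return (gram.sum(dim=-1) - self_dots) / max(n - 1, 1)


emb = torch.tensor([[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]])   # B=1, N=3, D=2
penalty = reference_diversity_penalty(emb)
```

For this input the first two vectors are orthogonal to each other and each has dot product 1 with the third, so the result is `[[0.5, 0.5, 1.0]]`. The (B, N, N) Gram matrix is the intermediate that a fused implementation would presumably avoid storing.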
fused_log_softmax_gather
core.trainers.ebft.kernels.fused_log_softmax_gather(logits, labels)
Compute log_softmax(logits, dim=-1).gather(-1, labels) without materializing full output.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| logits | torch.Tensor | (B, S, V) or (B*S, V) float tensor (bf16 or fp32) | required |
| labels | torch.Tensor | (B, S) or (B*S,) int64 tensor of target indices | required |
Returns
| Name | Type | Description |
|---|---|---|
| | torch.Tensor | (B, S) or (B*S,) float32 tensor of selected log probabilities |
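The identity behind this function is log_softmax(x)[y] = x[y] - logsumexp(x), so only one value per row needs to be written out. The sketch below shows that in plain PyTorch; `reference_log_softmax_gather` is a hypothetical name, and this is an unfused illustration, not the Triton kernel.

```python
import torch


def reference_log_softmax_gather(
    logits: torch.Tensor, labels: torch.Tensor
) -> torch.Tensor:
    """Hypothetical reference: selected log-probs without a full log_softmax output."""
    x = logits.float()   # upcast so the result is float32, as documented
    lse = torch.logsumexp(x, dim=-1)                          # (B, S) or (B*S,)
    picked = x.gather(-1, labels.unsqueeze(-1)).squeeze(-1)   # logit at each label
    return picked - lse                                       # log p(label)


torch.manual_seed(0)
logits = torch.randn(2, 5, 11)            # (B, S, V)
labels = torch.randint(0, 11, (2, 5))     # (B, S), int64
logps = reference_log_softmax_gather(logits, labels)
naive = torch.log_softmax(logits, dim=-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)
```

Unlike a fused kernel, this version still makes a float32 copy of bf16 logits; it only avoids allocating the (B, S, V) log-probability tensor that the naive expression produces.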
fused_reinforce_loss
core.trainers.ebft.kernels.fused_reinforce_loss(
per_token_logps,
advantages,
action_mask,
)
Compute masked REINFORCE loss: (-logp * adv * mask).sum() / mask.sum().
Inputs that are not already flat are flattened. Returns a scalar loss tensor.
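The formula above can be written out directly in plain PyTorch as a reference for checking results. `reference_reinforce_loss` is a hypothetical name; the float32 upcast is an assumption, and the fused kernel's behavior when the mask is all zeros is not documented here (this sketch would return NaN in that case).

```python
import torch


def reference_reinforce_loss(
    per_token_logps: torch.Tensor,
    advantages: torch.Tensor,
    action_mask: torch.Tensor,
) -> torch.Tensor:
    """Hypothetical unfused reference: (-logp * adv * mask).sum() / mask.sum()."""
    # Flatten every input so any matching shape is accepted.
    logp = per_token_logps.reshape(-1).float()
    adv = advantages.reshape(-1).float()
    mask = action_mask.reshape(-1).float()
    return (-logp * adv * mask).sum() / mask.sum()


logps = torch.tensor([[-1.0, -2.0], [-3.0, -4.0]])
adv = torch.tensor([[1.0, 0.5], [2.0, 1.0]])
mask = torch.tensor([[1, 1], [0, 0]])   # only the first row counts
loss = reference_reinforce_loss(logps, adv, mask)
```

With this input the two unmasked tokens contribute 1.0 each, so the loss is 2.0 / 2 = 1.0. The unfused version allocates an intermediate for each product before the reduction, which a single fused pass would not need.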