integrations.kd.topk_logprob.forward_kl

Loss for knowledge distillation (KD) using forward KL divergence over the teacher's top-k logprobs.

Classes

| Name | Description |
| --- | --- |
| ChunkedTopKKDLoss | A wrapper that chunks (splits) the student and teacher outputs along the time dimension. |

ChunkedTopKKDLoss

integrations.kd.topk_logprob.forward_kl.ChunkedTopKKDLoss(
    num_output_chunks=8,
    kd_temperature=1.0,
)

A wrapper that chunks (splits) the student and teacher outputs along the time dimension to reduce peak memory usage when upcasting from bf16 to fp32, especially for large vocabularies.

Usage is analogous to ForwardKLWithChunkedOutputLoss but adapted to top-K teacher logprobs.
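The chunking idea can be sketched as follows. This is a minimal illustration, not the class's actual implementation: the function names `_topk_kl_sum` and `chunked_topk_kl_sketch` are hypothetical, all tensors are assumed to share the same sequence length, and normalizing by the number of valid top-k entries is an assumption. Only `num_output_chunks` comes from the documented constructor.

```python
import torch


def _topk_kl_sum(student_logits, target_token_ids, target_logprobs, target_mask):
    """Summed forward KL over one chunk (hypothetical helper)."""
    # The fp32 upcast happens here, per chunk, so only one chunk of
    # [B, chunk_len, vocab_size] fp32 values is materialized at a time.
    student_logprobs = torch.log_softmax(student_logits.float(), dim=-1)
    # Keep only the student's logprobs at the teacher's top-k token IDs.
    student_topk = torch.gather(student_logprobs, dim=-1, index=target_token_ids)
    target_logprobs = target_logprobs.float()
    kl = target_logprobs.exp() * (target_logprobs - student_topk)
    # Zero out entries that the mask marks as invalid.
    return kl.masked_fill(~target_mask.bool(), 0.0).sum()


def chunked_topk_kl_sketch(
    student_logits,    # [B, seq_len, vocab_size]
    target_token_ids,  # [B, seq_len, top_k]
    target_logprobs,   # [B, seq_len, top_k]
    target_mask,       # [B, seq_len, top_k]
    num_output_chunks=8,
):
    # Split every tensor along the sequence (time) dimension. Because the
    # sequence lengths match, torch.chunk yields aligned pieces.
    pieces = zip(
        student_logits.chunk(num_output_chunks, dim=1),
        target_token_ids.chunk(num_output_chunks, dim=1),
        target_logprobs.chunk(num_output_chunks, dim=1),
        target_mask.chunk(num_output_chunks, dim=1),
    )
    total = torch.zeros((), dtype=torch.float32, device=student_logits.device)
    for s, ids, lp, m in pieces:
        total = total + _topk_kl_sum(s, ids, lp, m)
    # Assumption: average over valid top-k entries.
    return total / target_mask.bool().sum().clamp(min=1)
```

Because the per-chunk sums are simply added together, the result should match the unchunked computation up to floating-point rounding; the chunking only changes how much fp32 memory is live at once.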

Functions

| Name | Description |
| --- | --- |
| loss | A KD loss function that is TorchScript-friendly. |

loss

integrations.kd.topk_logprob.forward_kl.loss(
    student_logits,
    target_token_ids,
    target_logprobs,
    target_mask,
    num_items_in_batch=-1,
    kd_temperature=1.0,
)

A KD loss function that is TorchScript-friendly.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| student_logits | torch.Tensor | The logits of the student model. Shape: [B, student_seq_len, vocab_size] | required |
| target_token_ids | torch.Tensor | The top-k teacher/target token IDs. Shape: [B, teacher_seq_len, top_k] | required |
| target_logprobs | torch.Tensor | The top-k teacher/target logprobs. These should already be re-normalized. Shape: [B, teacher_seq_len, top_k] | required |
| target_mask | torch.Tensor | The mask for valid tokens. Shape: [B, teacher_seq_len, top_k] | required |
| num_items_in_batch | int | The number of items in the batch. | -1 |
| kd_temperature | float | The temperature for KD. | 1.0 |
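To make the parameter shapes concrete, here is a minimal sketch of a top-k forward KL computation that accepts the same arguments. It is an illustration under stated assumptions, not the library's implementation: the name `topk_forward_kl_sketch` is hypothetical, the student sequence is assumed to be truncated to the teacher's length, the temperature is assumed to divide the student logits, and when `num_items_in_batch` is not positive the sum is assumed to be averaged over valid top-k entries.

```python
import torch


def topk_forward_kl_sketch(
    student_logits,    # [B, student_seq_len, vocab_size]
    target_token_ids,  # [B, teacher_seq_len, top_k]
    target_logprobs,   # [B, teacher_seq_len, top_k], already re-normalized
    target_mask,       # [B, teacher_seq_len, top_k]
    num_items_in_batch=-1,
    kd_temperature=1.0,
):
    target_logprobs = target_logprobs.float()
    teacher_seq_len = target_token_ids.shape[1]

    # Assumption: align the student to the teacher's sequence length and
    # apply the temperature to the student logits.
    scaled = student_logits[:, :teacher_seq_len, :].float() / kd_temperature

    # Student logprobs over the full vocabulary, gathered at the teacher's
    # top-k token IDs so both sides have shape [B, teacher_seq_len, top_k].
    student_logprobs = torch.log_softmax(scaled, dim=-1)
    student_topk = torch.gather(student_logprobs, dim=-1, index=target_token_ids)

    # Forward KL terms: p_teacher * (log p_teacher - log p_student).
    per_entry = target_logprobs.exp() * (target_logprobs - student_topk)

    # Drop entries that the mask marks as invalid.
    mask = target_mask.bool()
    total = per_entry.masked_fill(~mask, 0.0).sum()

    if num_items_in_batch > 0:
        return total / num_items_in_batch
    # Assumption: fall back to averaging over valid top-k entries.
    return total / mask.sum().clamp(min=1)
```

A quick sanity check on this sketch: when the student's distribution over the top-k tokens equals the teacher's, every term `log p_teacher - log p_student` is zero and the loss vanishes.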