integrations.kd.topk_logprob.forward_kl
Forward KL divergence loss for knowledge distillation (KD) from top-k teacher logprobs.
Classes
| Name | Description |
|---|---|
| ChunkedTopKKDLoss | A wrapper that chunks (splits) the student and teacher outputs along the time dimension |
ChunkedTopKKDLoss
integrations.kd.topk_logprob.forward_kl.ChunkedTopKKDLoss(
num_output_chunks=8,
kd_temperature=1.0,
)
A wrapper that chunks (splits) the student and teacher outputs along the time dimension to reduce peak memory usage when upcasting from bf16 to fp32, especially for large vocabularies.
Usage is analogous to ForwardKLWithChunkedOutputLoss, but adapted to top-k teacher logprobs.
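The chunking idea can be illustrated with a small plain-Python sketch. It is not the class's actual implementation: the function names (chunk_ranges, kl_sum_and_count, chunked_topk_kd_loss) are hypothetical, nested lists stand in for tensors, and it assumes a standard top-k forward KL with T² scaling and normalization by num_items_in_batch (or by the count of valid entries). Pure Python gives no memory savings; the sketch only shows the bookkeeping that lets a chunked loss match an unchunked one. Each chunk contributes an unnormalized sum and a valid-entry count, and the result is normalized once at the end.

```python
import math
from typing import List, Tuple


def chunk_ranges(length: int, num_chunks: int) -> List[Tuple[int, int]]:
    """Contiguous [start, end) ranges with torch.chunk-style sizes:
    ceil(length / num_chunks) positions each, the last chunk possibly shorter."""
    size = max(1, math.ceil(length / num_chunks))
    return [(s, min(s + size, length)) for s in range(0, length, size)]


def kl_sum_and_count(student_rows, token_ids, logprobs, mask, temperature):
    """Unnormalized top-k forward KL over a slice of positions, plus the
    number of valid (position, k) entries that contributed to it."""
    total, count = 0.0, 0
    for z_raw, ids, lps, keep_row in zip(student_rows, token_ids, logprobs, mask):
        z = [x / temperature for x in z_raw]
        m = max(z)
        lse = m + math.log(sum(math.exp(v - m) for v in z))  # stable logsumexp
        for tok, lp, keep in zip(ids, lps, keep_row):
            if keep:
                total += math.exp(lp) * (lp - (z[tok] - lse))
                count += 1
    return total, count


def chunked_topk_kd_loss(student_logits, target_token_ids, target_logprobs,
                         target_mask, num_output_chunks=8, kd_temperature=1.0,
                         num_items_in_batch=-1):
    """Hypothetical chunked loss over nested lists shaped [B][seq][...].
    Only one chunk of positions is processed at a time; just the running
    sum and valid-entry count are carried between chunks."""
    seq_len = len(target_token_ids[0])  # assumes all sequences share one (padded) length
    total, count = 0.0, 0
    for start, end in chunk_ranges(seq_len, num_output_chunks):
        for b in range(len(target_token_ids)):
            s, c = kl_sum_and_count(
                student_logits[b][start:end],  # student is implicitly cut to teacher length
                target_token_ids[b][start:end],
                target_logprobs[b][start:end],
                target_mask[b][start:end],
                kd_temperature,
            )
            total += s
            count += c
    total *= kd_temperature ** 2  # classical KD temperature scaling
    # Normalize once, globally: averaging per-chunk means would over-weight
    # chunks that contain few valid entries.
    denom = num_items_in_batch if num_items_in_batch > 0 else max(count, 1)
    return total / denom
```

Because every chunk reports a sum rather than a mean, changing num_output_chunks changes only the floating-point summation order, not the value of the loss.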
Functions
| Name | Description |
|---|---|
| loss | A KD loss function that is TorchScript-friendly. |
loss
integrations.kd.topk_logprob.forward_kl.loss(
student_logits,
target_token_ids,
target_logprobs,
target_mask,
num_items_in_batch=-1,
kd_temperature=1.0,
)
A KD loss function that is TorchScript-friendly.
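Under the usual formulation of top-k forward KL distillation (an assumption: this page does not state the formula), the loss can be written as below. Here $z$ are the student logits, $T$ is kd_temperature, $i_{b,t,k}$ are the target_token_ids, $\ell_{b,t,k}$ are the renormalized target_logprobs, $m_{b,t,k}$ is target_mask, $V$ is the vocabulary size, and $t$ runs over the first teacher_seq_len positions. The $T^2$ factor and the choice of normalizer $N$ are likewise assumed, not confirmed by this reference.

$$
\log q_{b,t,k} = \frac{z_{b,t,\,i_{b,t,k}}}{T} - \log \sum_{v=1}^{V} \exp\!\left(\frac{z_{b,t,v}}{T}\right)
$$

$$
\mathcal{L} = \frac{T^{2}}{N} \sum_{b,t,k} m_{b,t,k}\, e^{\ell_{b,t,k}} \left(\ell_{b,t,k} - \log q_{b,t,k}\right),
\qquad
N = \begin{cases} \text{num\_items\_in\_batch} & \text{if num\_items\_in\_batch} > 0 \\ \sum_{b,t,k} m_{b,t,k} & \text{otherwise} \end{cases}
$$

The student's log-softmax is taken over the full vocabulary, while the sum runs only over the teacher's top-k tokens, so the student's probability mass outside the top-k set is never compared directly.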
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| student_logits | torch.Tensor | The logits of the student model. Shape: [B, student_seq_len, vocab_size] | required |
| target_token_ids | torch.Tensor | The top-k teacher/target token IDs. Shape: [B, teacher_seq_len, top_k] | required |
| target_logprobs | torch.Tensor | The top-k teacher/target logprobs, which should already be renormalized over the top-k tokens. Shape: [B, teacher_seq_len, top_k] | required |
| target_mask | torch.Tensor | The mask for valid tokens. Shape: [B, teacher_seq_len, top_k] | required |
| num_items_in_batch | int | The number of items in the batch. | -1 |
| kd_temperature | float | The temperature for KD. | 1.0 |
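A plain-Python reference sketch of how the parameters above fit together, written against nested lists shaped like the tensors in the table. The name topk_forward_kl_reference is hypothetical, and the T² scaling and normalization rule (divide by num_items_in_batch when it is positive, otherwise by the number of valid entries) are assumptions rather than the documented behavior of loss. The guard against an all-masked batch is also an addition of this sketch.

```python
import math
from typing import List


def topk_forward_kl_reference(
    student_logits: List[List[List[float]]],   # [B][student_seq_len][vocab_size]
    target_token_ids: List[List[List[int]]],   # [B][teacher_seq_len][top_k]
    target_logprobs: List[List[List[float]]],  # [B][teacher_seq_len][top_k], renormalized
    target_mask: List[List[List[int]]],        # [B][teacher_seq_len][top_k], 1 = valid
    num_items_in_batch: int = -1,
    kd_temperature: float = 1.0,
) -> float:
    total, n_valid = 0.0, 0
    for b in range(len(target_token_ids)):
        teacher_seq_len = len(target_token_ids[b])
        # Only the first teacher_seq_len student positions are compared.
        for t in range(teacher_seq_len):
            z = [x / kd_temperature for x in student_logits[b][t]]
            m = max(z)
            lse = m + math.log(sum(math.exp(v - m) for v in z))  # full-vocab logsumexp
            for k, tok in enumerate(target_token_ids[b][t]):
                if not target_mask[b][t][k]:
                    continue  # masked-out entries contribute nothing
                student_lp = z[tok] - lse
                teacher_lp = target_logprobs[b][t][k]
                total += math.exp(teacher_lp) * (teacher_lp - student_lp)
                n_valid += 1
    total *= kd_temperature ** 2  # assumed classical KD scaling
    # Assumed normalization: by num_items_in_batch when positive, else by the
    # number of valid entries (guarded against an all-masked batch).
    denom = num_items_in_batch if num_items_in_batch > 0 else max(n_valid, 1)
    return total / denom
```

For example, a single position with student logits [0.0, 0.0] and a teacher that puts all of its (renormalized) mass on token 0 gives a loss of log 2 under this sketch; passing num_items_in_batch=2 halves it.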