integrations.kd.kernels.liger
Liger Kernels for Chunked Top-K Log-Prob Distillation
Classes
| Name | Description |
|---|---|
| LigerFusedLinearKLTopKLogprobFunction | Chunked KL-divergence loss for top-k log-probs. |
| LigerFusedLinearKLTopKLogprobLoss | Wrapper for the chunked top-k log-prob KL-divergence loss. |
LigerFusedLinearKLTopKLogprobFunction
integrations.kd.kernels.liger.LigerFusedLinearKLTopKLogprobFunction()
Chunked KL-divergence loss for top-k log-probs.
Methods
| Name | Description |
|---|---|
| distillation_loss_fn | Compute Top-K KL divergence loss for a chunk. |
distillation_loss_fn
integrations.kd.kernels.liger.LigerFusedLinearKLTopKLogprobFunction.distillation_loss_fn(
student_logits_temp_scaled,
target_token_ids_chunk,
target_logprobs_chunk,
target_mask_chunk,
beta=0.0,
normalize_topk=True,
)
Compute Top-K KL divergence loss for a chunk.
Args:
- student_logits_temp_scaled: Student logits, scaled by temperature. Shape: (N, V).
- target_token_ids_chunk: Top-k teacher token IDs. Shape: (N, K).
- target_logprobs_chunk: Top-k teacher log probabilities (temp-scaled, normalized). Shape: (N, K).
- target_mask_chunk: Mask for valid top-k tokens. Shape: (N, K).
- beta: Controls the type of KL divergence: 0.0 for forward KL (P_teacher || P_student), 1.0 for reverse KL (P_student || P_teacher), 0.5 for symmetric KL (average of forward and reverse).
- normalize_topk: Whether to normalize the log probabilities.
Returns:
Sum of KL divergence losses for the chunk.
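To make the arguments above concrete, here is a minimal pure-Python sketch of a top-k KL computation over one chunk. It is not the library's kernel, which operates on tensors. The helper names `log_softmax` and `topk_kl_chunk` are hypothetical, and two details are assumptions rather than documented behavior: `beta` is treated as a linear blend of forward and reverse KL (which reproduces the three documented values 0.0, 0.5, and 1.0), and `normalize_topk` is taken to renormalize the student's log-probs over the valid top-k slots.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over one row of logits."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def topk_kl_chunk(student_logits, token_ids, teacher_logprobs, mask,
                  beta=0.0, normalize_topk=True):
    """Sum of top-k KL terms over a chunk of N rows (illustrative only)."""
    total = 0.0
    for logits, ids, t_row, valid in zip(student_logits, token_ids,
                                         teacher_logprobs, mask):
        full = log_softmax(logits)  # student log-probs over the vocabulary (V,)
        # Gather student log-probs at the teacher's valid top-k token IDs.
        pairs = [(full[i], t) for i, t, v in zip(ids, t_row, valid) if v]
        if not pairs:
            continue  # no valid top-k entries in this row
        s_lp = [s for s, _ in pairs]
        t_lp = [t for _, t in pairs]
        if normalize_topk:
            # Assumption: renormalize student mass over the valid top-k slots.
            m = max(s_lp)
            lse = m + math.log(sum(math.exp(s - m) for s in s_lp))
            s_lp = [s - lse for s in s_lp]
        # Forward KL: teacher-weighted; reverse KL: student-weighted.
        fwd = sum(math.exp(t) * (t - s) for s, t in zip(s_lp, t_lp))
        rev = sum(math.exp(s) * (s - t) for s, t in zip(s_lp, t_lp))
        # Assumption: beta blends the two directions linearly.
        total += (1.0 - beta) * fwd + beta * rev
    return total
```

When the renormalized student distribution matches the teacher's on the top-k tokens, both directions vanish and the chunk loss is zero regardless of `beta`.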
LigerFusedLinearKLTopKLogprobLoss
integrations.kd.kernels.liger.LigerFusedLinearKLTopKLogprobLoss(
weight_hard_loss=0.5,
weight_soft_loss=0.5,
temperature=1.0,
beta=1.0,
ignore_index=-100,
compiled=True,
chunk_size=1024,
compute_ce_loss=True,
normalize_topk=True,
)
Wrapper for the chunked top-k log-prob KL-divergence loss.
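The constructor arguments suggest how the pieces fit together: rows are processed `chunk_size` at a time, and a hard (cross-entropy) term and a soft (distillation) term are mixed by their weights. The sketch below illustrates that reading in plain Python. The helpers `chunk_bounds` and `combined_loss` are hypothetical, the weighted sum is an assumption about the wrapper's behavior, and `ignore_index` is applied only to the hard term, following the usual cross-entropy convention. Any averaging over tokens is deliberately left out because the source does not describe it.

```python
def chunk_bounds(n_rows, chunk_size=1024):
    """Return (start, end) row ranges covering n_rows in chunk_size steps."""
    return [(start, min(start + chunk_size, n_rows))
            for start in range(0, n_rows, chunk_size)]

def combined_loss(per_row_soft, per_row_hard, labels,
                  weight_hard_loss=0.5, weight_soft_loss=0.5,
                  ignore_index=-100, chunk_size=1024, compute_ce_loss=True):
    """Accumulate per-row losses chunk by chunk, then mix them by weight."""
    soft_total = 0.0
    hard_total = 0.0
    for start, end in chunk_bounds(len(labels), chunk_size):
        for i in range(start, end):
            soft_total += per_row_soft[i]
            # Assumption: rows labeled ignore_index add no cross-entropy.
            if compute_ce_loss and labels[i] != ignore_index:
                hard_total += per_row_hard[i]
    return weight_hard_loss * hard_total + weight_soft_loss * soft_total
```

With `compute_ce_loss=False` the hard term stays at zero, so only the weighted soft term remains.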