integrations.kd.kernels.liger

Liger Kernels for Chunked Top-K Log-Prob Distillation

Classes

Name                                   Description
LigerFusedLinearKLTopKLogprobFunction  Chunked KL-divergence loss for top-k logprobs.
LigerFusedLinearKLTopKLogprobLoss      Wrapper for the chunked top-k logprob KL-divergence loss.

LigerFusedLinearKLTopKLogprobFunction

integrations.kd.kernels.liger.LigerFusedLinearKLTopKLogprobFunction()

Chunked KL-divergence loss for top-k logprobs.

Methods

Name                  Description
distillation_loss_fn  Compute Top-K KL divergence loss for a chunk.

distillation_loss_fn

integrations.kd.kernels.liger.LigerFusedLinearKLTopKLogprobFunction.distillation_loss_fn(
    student_logits_temp_scaled,
    target_token_ids_chunk,
    target_logprobs_chunk,
    target_mask_chunk,
    beta=0.0,
    normalize_topk=True,
)

Compute Top-K KL divergence loss for a chunk.

Args:
    student_logits_temp_scaled: Student logits, scaled by temperature. Shape: (N, V).
    target_token_ids_chunk: Top-k teacher token IDs. Shape: (N, K).
    target_logprobs_chunk: Top-k teacher log probabilities (temp-scaled, normalized). Shape: (N, K).
    target_mask_chunk: Mask for valid top-k tokens. Shape: (N, K).
    beta: Controls the type of KL divergence.
        0.0 for Forward KL (P_teacher || P_student).
        1.0 for Reverse KL (P_student || P_teacher).
        0.5 for Symmetric KL (average of Forward and Reverse).
    normalize_topk: Whether to normalize the log probabilities.

Returns:
    Sum of KL divergence losses for the chunk.
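The arithmetic described above can be illustrated with a plain-Python reference. This is a minimal sketch and not the fused kernel: the function name topk_kl_chunk is hypothetical, nested lists stand in for tensors of shape (N, V) and (N, K), and two behaviours are assumptions rather than documented facts: that normalize_topk renormalizes the student log probabilities over the valid top-k entries, and that values of beta between 0.0 and 1.0 interpolate linearly between Forward and Reverse KL (which reproduces the three documented cases).

```python
import math


def topk_kl_chunk(student_logits, target_token_ids, target_logprobs,
                  target_mask, beta=0.0, normalize_topk=True):
    """Sum of top-k KL terms over a chunk (illustrative reference only).

    student_logits:   N rows of V temperature-scaled logits.
    target_token_ids: N rows of K teacher token IDs.
    target_logprobs:  N rows of K teacher log probabilities.
    target_mask:      N rows of K booleans marking valid top-k entries.
    """
    total = 0.0
    for logits, ids, t_logps, mask in zip(
        student_logits, target_token_ids, target_logprobs, target_mask
    ):
        # Log-softmax over the full vocabulary, using a stable logsumexp.
        top = max(logits)
        lse = top + math.log(sum(math.exp(x - top) for x in logits))
        # Student log probabilities gathered at the teacher's top-k IDs.
        s_logps = [logits[i] - lse for i in ids]

        if normalize_topk:
            # Assumption: renormalize the student over the valid top-k
            # entries so both distributions sum to 1 on the same support.
            valid = [lp for lp, keep in zip(s_logps, mask) if keep]
            if valid:
                top_k = max(valid)
                lse_k = top_k + math.log(sum(math.exp(lp - top_k) for lp in valid))
                s_logps = [lp - lse_k for lp in s_logps]

        for s_lp, t_lp, keep in zip(s_logps, t_logps, mask):
            if not keep:
                continue  # masked entries contribute nothing
            forward = math.exp(t_lp) * (t_lp - s_lp)  # KL(P_teacher || P_student)
            reverse = math.exp(s_lp) * (s_lp - t_lp)  # KL(P_student || P_teacher)
            # Assumption: linear interpolation for intermediate beta values.
            total += (1.0 - beta) * forward + beta * reverse
    return total
```

With a uniform student over four tokens and a teacher that puts probability 0.5 on each of two top-k tokens, the sketch returns approximately zero when normalize_topk=True, because the renormalized student matches the teacher on the top-k support. With normalize_topk=False the same inputs give log(2), since the student assigns only 0.25 to each of those tokens.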

LigerFusedLinearKLTopKLogprobLoss

integrations.kd.kernels.liger.LigerFusedLinearKLTopKLogprobLoss(
    weight_hard_loss=0.5,
    weight_soft_loss=0.5,
    temperature=1.0,
    beta=1.0,
    ignore_index=-100,
    compiled=True,
    chunk_size=1024,
    compute_ce_loss=True,
    normalize_topk=True,
)

Wrapper for the chunked top-k logprob KL-divergence loss.
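The constructor arguments suggest how the pieces fit together: rows are processed chunk_size at a time, and the hard (cross-entropy) and soft (distillation) losses are mixed by their weights. The sketch below illustrates that reading in plain Python. The helper names chunk_bounds and combine_losses are hypothetical, and the weighted-sum combination, including treating the hard loss as zero when compute_ce_loss is False, is an assumption based on the parameter names rather than something this page states.

```python
def chunk_bounds(num_rows, chunk_size=1024):
    """Yield (start, stop) row ranges that cover num_rows in chunk_size steps."""
    for start in range(0, num_rows, chunk_size):
        yield start, min(start + chunk_size, num_rows)


def combine_losses(hard_loss, soft_loss, weight_hard_loss=0.5,
                   weight_soft_loss=0.5, compute_ce_loss=True):
    """Weighted sum of the hard and soft losses (assumed combination rule)."""
    if not compute_ce_loss:
        # Assumption: with cross-entropy disabled, the hard term is zero.
        hard_loss = 0.0
    return weight_hard_loss * hard_loss + weight_soft_loss * soft_loss


# Example: 2500 rows split into chunks of at most 1024 rows each.
bounds = list(chunk_bounds(2500, chunk_size=1024))
```

Processing rows in chunks keeps only one chunk's (chunk_size, V) logits in memory at a time, which is the usual motivation for a fused, chunked loss when the vocabulary size V is large.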