kernels.dora
Triton kernels for DoRA (Weight-Decomposed Low-Rank Adaptation).
Fuses the weight norm computation and magnitude scaling to avoid materializing the full [out_features, in_features] combined weight matrix. The B@A product is computed row-by-row inside the kernel.
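The memory saving described above can be illustrated in plain PyTorch. The following is a minimal sketch, not the Triton kernel itself: `rowwise_norm_chunked` and its `block_rows` parameter are hypothetical names, and it assumes `W` is an ordinary dense (already dequantized) tensor. It computes the per-row L2 norm of `W + s * B @ A` one row block at a time, so only a `[block_rows, in_features]` slice of the combined weight exists at any moment.

```python
import torch


def rowwise_norm_chunked(W, A, B, s, block_rows=64):
    """Per-row L2 norm of W + s * (B @ A), computed one row block at a time.

    Hypothetical helper for illustration; assumes W is a dense tensor.
    The full [out_features, in_features] combined matrix is never built.
    """
    out_features = W.shape[0]
    norms = torch.empty(out_features, dtype=W.dtype, device=W.device)
    for start in range(0, out_features, block_rows):
        end = min(start + block_rows, out_features)
        # B[start:end] is [block, rank] and A is [rank, in], so the
        # product is only a [block, in] slice of the low-rank update.
        combined = W[start:end] + s * (B[start:end] @ A)
        norms[start:end] = torch.linalg.vector_norm(combined, dim=1)
    return norms
```

A fused kernel would presumably apply the same idea at a finer granularity, keeping each slice in on-chip memory rather than allocating it as a tensor.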
Functions
| Name | Description |
|---|---|
| triton_dora_scale | Compute DoRA mag_norm_scale using a fused Triton kernel. |
triton_dora_scale
kernels.dora.triton_dora_scale(W, W_quant, A, B, s, magnitude, dtype)

Compute DoRA mag_norm_scale using a fused Triton kernel.
Computes B@A row-by-row inside the kernel, avoiding the full [out_features, in_features] materialization.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| W | torch.Tensor | base weight [out, in] (possibly quantized) | required |
| W_quant | | quantization state | required |
| A | torch.Tensor | LoRA A [rank, in] | required |
| B | torch.Tensor | LoRA B [out, rank] | required |
| s | float | LoRA scaling factor | required |
| magnitude | torch.Tensor | learned magnitude [out] | required |
| dtype | torch.dtype | compute dtype | required |
Returns
| Name | Type | Description |
|---|---|---|
| mag_norm_scale | torch.Tensor | [out] tensor = magnitude / \|\|W + s * B @ A\|\|_2, with the L2 norm taken over each row (the in dimension) |
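
For comparison, the returned quantity can be written as an unfused PyTorch reference. This is a sketch under stated assumptions: `dora_scale_reference` is a hypothetical name, it takes a dense `W` and therefore omits the `W_quant` argument, and it adds no epsilon to the denominator, which the real kernel may or may not do. Unlike the fused kernel, it materializes the full `[out, in]` combined weight.

```python
import torch


def dora_scale_reference(W, A, B, s, magnitude, dtype):
    """Unfused reference for magnitude / ||W + s * B @ A||_2 (per output row).

    Hypothetical helper for illustration; assumes W is dense (not quantized).
    """
    # Full [out, in] combined weight, materialized in the compute dtype.
    combined = W.to(dtype) + s * (B.to(dtype) @ A.to(dtype))
    # One L2 norm per output row -> shape [out].
    weight_norm = torch.linalg.vector_norm(combined, dim=1)
    return magnitude.to(dtype) / weight_norm


# Example shapes: out=8, in=16, rank=4.
W = torch.randn(8, 16)
A = torch.randn(4, 16)
B = torch.zeros(8, 4)  # LoRA B is commonly zero-initialized
magnitude = torch.ones(8)
scale = dora_scale_reference(W, A, B, s=2.0, magnitude=magnitude, dtype=torch.float32)
```

A reference like this is mainly useful for checking a fused implementation numerically on small shapes.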