kernels.dora

Triton kernels for DoRA (Weight-Decomposed Low-Rank Adaptation).

Fuses the per-row weight-norm computation and the magnitude scaling so that the full [out_features, in_features] combined weight matrix W + s * (B @ A) is never materialized. Instead, the B@A product is computed row-by-row inside the kernel.
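For reference, the quantity the kernel fuses can be sketched in plain PyTorch. This is not the Triton kernel. It is a minimal row-blocked reference, assuming W has already been dequantized to a float [out, in] tensor, in which only a [block_rows, in] slice of B@A exists at any time. The name `dora_scale_reference` and the `block_rows` parameter are hypothetical.

```python
import torch


def dora_scale_reference(W, A, B, s, magnitude, block_rows=256):
    """Hypothetical row-blocked reference for magnitude / ||W + s * B @ A||_2 (per row).

    Assumes W is an already-dequantized float tensor of shape [out, in].
    Only a [block_rows, in] slice of B @ A is allocated at any time.
    """
    out_features = W.shape[0]
    norms = torch.empty(out_features, dtype=W.dtype, device=W.device)
    for start in range(0, out_features, block_rows):
        stop = min(start + block_rows, out_features)
        # [rows, rank] @ [rank, in] -> [rows, in]: one block of the combined weight
        combined = W[start:stop] + s * (B[start:stop] @ A)
        # L2 norm over in_features gives one value per output row
        norms[start:stop] = torch.linalg.vector_norm(combined, dim=1)
    return magnitude / norms


# Compare against the naive version, which materializes the full [out, in] matrix.
torch.manual_seed(0)
out_f, in_f, rank = 64, 32, 4
W = torch.randn(out_f, in_f)
A = torch.randn(rank, in_f)
B = torch.randn(out_f, rank)
magnitude = torch.rand(out_f) + 0.5
s = 0.5

naive = magnitude / torch.linalg.vector_norm(W + s * (B @ A), dim=1)
blocked = dora_scale_reference(W, A, B, s, magnitude, block_rows=16)
```

Smaller `block_rows` lowers peak memory in this sketch at the cost of more, smaller matmuls.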

Functions

| Name | Description |
|------|-------------|
| triton_dora_scale | Compute DoRA mag_norm_scale using fused Triton kernel. |

triton_dora_scale

kernels.dora.triton_dora_scale(W, W_quant, A, B, s, magnitude, dtype)

Compute DoRA mag_norm_scale using fused Triton kernel.

Computes B@A row-by-row inside the kernel, avoiding the full [out_features, in_features] materialization.
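Written out per output row i, with m = magnitude, the returned value is the following. Row i of B@A depends only on B[i, :] and all of A, which is what makes a row-by-row computation possible.

$$
\text{mag\_norm\_scale}_i
= \frac{m_i}{\left\lVert W_{i,:} + s\, B_{i,:} A \right\rVert_2}
= \frac{m_i}{\sqrt{\sum_{j=1}^{\text{in}} \Bigl( W_{ij} + s \sum_{r=1}^{\text{rank}} B_{ir} A_{rj} \Bigr)^{2}}}
$$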

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| W | torch.Tensor | base weight [out, in] (possibly quantized) | required |
| W_quant | | quantization state | required |
| A | torch.Tensor | LoRA A [rank, in] | required |
| B | torch.Tensor | LoRA B [out, rank] | required |
| s | float | LoRA scaling factor | required |
| magnitude | torch.Tensor | learned magnitude [out] | required |
| dtype | torch.dtype | compute dtype | required |

Returns

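| Name | Type | Description |
|------|------|-------------|
| mag_norm_scale | torch.Tensor | Shape [out]: magnitude / ‖W + s * (B @ A)‖₂, with the L2 norm taken over each row (the in_features dimension). |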
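To show how the returned tensor can be consumed, here is a minimal sketch of a DoRA linear forward, assuming an unquantized W, no dropout, and a plain `x @ W.T` base layer. `dora_linear_forward` is a hypothetical helper, not part of this module. Because the scale is per output feature, it can be applied to the layer output instead of the weight; W + s * (B @ A) is formed below only to check the result.

```python
import torch


def dora_linear_forward(x, W, A, B, s, mag_norm_scale):
    """Hypothetical helper: apply a per-output DoRA scale without forming W + s * B @ A.

    Equivalent to x @ (mag_norm_scale[:, None] * (W + s * B @ A)).T
    """
    base = x @ W.T                              # [batch, out]
    lora = (x @ A.T) @ B.T                      # [batch, rank] -> [batch, out]
    return mag_norm_scale * (base + s * lora)   # broadcasts over the out dimension


torch.manual_seed(0)
batch, out_f, in_f, rank = 8, 16, 12, 4
x = torch.randn(batch, in_f)
W = torch.randn(out_f, in_f)
A = torch.randn(rank, in_f)
B = torch.randn(out_f, rank)
magnitude = torch.rand(out_f) + 0.5
s = 2.0

# Combined weight is materialized here only for verification.
combined = W + s * (B @ A)
mag_norm_scale = magnitude / torch.linalg.vector_norm(combined, dim=1)

y = dora_linear_forward(x, W, A, B, s, mag_norm_scale)
y_ref = x @ (mag_norm_scale[:, None] * combined).T
row_norms = torch.linalg.vector_norm(mag_norm_scale[:, None] * combined, dim=1)
```

Multiplying each row of the combined weight by mag_norm_scale rescales that row's L2 norm to exactly magnitude, which is the magnitude/direction split DoRA relies on.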