MoE Expert Quantization

Reduce VRAM usage when training MoE model adapters by quantizing expert weights on load.

Transformers v5 changed MoE expert layers from nn.Linear modules to fused nn.Parameter tensors (3D or higher). As a result, bitsandbytes can no longer quantize them during model loading, so all expert weights are loaded in full bf16 precision and VRAM usage rises sharply.

quantize_moe_experts solves this by quantizing expert weights during model loading. It intercepts the weight loading process, quantizes each expert tensor on the fly, and immediately frees the original bf16 tensor from VRAM. This dramatically reduces peak memory: for example, GLM-4.7-Flash QLoRA drops from ~127 GiB to ~23 GiB of reserved memory.
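To get a feel for the storage side of this saving, the following back-of-the-envelope sketch compares bf16 and 4-bit storage for a given number of expert parameters. The block size of 64 and the single fp32 scale per block are assumptions about a typical blockwise 4-bit layout, the parameter count is hypothetical, and the helper names are made up for illustration. The result covers weight storage only, so it is not expected to reproduce the reserved-memory figures quoted above.

```python
# Rough storage estimate for expert weights before and after 4-bit quantization.
# Assumptions (illustrative, not taken from Axolotl): blockwise quantization with
# 64 elements per block and one fp32 scale per block, no double quantization.

BF16_BYTES_PER_PARAM = 2.0
GIB = 1024 ** 3


def bf16_bytes(n_params: int) -> float:
    """Bytes needed to hold n_params values in bf16."""
    return n_params * BF16_BYTES_PER_PARAM


def four_bit_bytes(n_params: int, block_size: int = 64, scale_bytes: int = 4) -> float:
    """Bytes needed for packed 4-bit codes plus one scale per block."""
    packed = n_params * 0.5                          # two 4-bit codes per byte
    scales = (n_params / block_size) * scale_bytes   # per-block scale overhead
    return packed + scales


# Hypothetical model with 20 billion expert parameters.
n_expert_params = 20_000_000_000
full = bf16_bytes(n_expert_params)
quant = four_bit_bytes(n_expert_params)

print(f"bf16 experts:  {full / GIB:.1f} GiB")
print(f"4-bit experts: {quant / GIB:.1f} GiB")
print(f"reduction:     {full / quant:.2f}x")
```

The per-block scales are why the reduction comes out somewhat below the nominal 4x for 16-bit to 4-bit.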

Usage

Enable expert quantization in your Axolotl config:

quantize_moe_experts: true

This works with both 4-bit (QLoRA) and 8-bit (LoRA) quantization.

Expert LoRA targeting

You can optionally apply LoRA adapters directly to expert weights using lora_target_parameters:

lora_target_parameters:
  - mlp.experts.gate_up_proj
  - mlp.experts.down_proj
  # - mlp.gate.weight  # router

Note: lora_dropout must be 0 when using lora_target_parameters.

Requirements

  • Requires either adapter: lora with load_in_8bit: true, or adapter: qlora with load_in_4bit: true
  • CUDA GPUs only (not tested with ROCm or other backends)
  • FSDP2 compatible for distributed training
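The adapter and precision rules above, together with the lora_dropout rule for lora_target_parameters, can be expressed as a small pre-flight check. This is only a sketch: check_moe_quant_config is a hypothetical helper, not part of Axolotl, and it assumes the config has already been parsed into a plain dict with the key names used in this document.

```python
def check_moe_quant_config(cfg: dict) -> list[str]:
    """Return a list of problems; an empty list means the combination looks valid.

    Hypothetical helper for illustration only; Axolotl performs its own validation.
    """
    problems = []
    if not cfg.get("quantize_moe_experts"):
        return problems  # feature disabled, nothing to check

    # Either LoRA with 8-bit loading, or QLoRA with 4-bit loading.
    lora_8bit = cfg.get("adapter") == "lora" and cfg.get("load_in_8bit") is True
    qlora_4bit = cfg.get("adapter") == "qlora" and cfg.get("load_in_4bit") is True
    if not (lora_8bit or qlora_4bit):
        problems.append(
            "quantize_moe_experts needs adapter: lora + load_in_8bit: true "
            "or adapter: qlora + load_in_4bit: true"
        )

    # Targeting expert parameters with LoRA requires zero dropout.
    if cfg.get("lora_target_parameters") and cfg.get("lora_dropout", 0) != 0:
        problems.append("lora_dropout must be 0 when using lora_target_parameters")

    return problems


example = {
    "quantize_moe_experts": True,
    "adapter": "qlora",
    "load_in_4bit": True,
    "lora_target_parameters": ["mlp.experts.gate_up_proj", "mlp.experts.down_proj"],
    "lora_dropout": 0.05,
}
for problem in check_moe_quant_config(example):
    print("config problem:", problem)
```

Returning a list of messages rather than raising on the first failure lets all problems be reported in one pass.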

Limitations

  • cpu_ram_efficient_loading hangs or takes a very long time with FSDP2 + QLoRA.
  • Total model parameter count may display incorrectly (trainable param count is correct).
  • FSDP LoRA (8-bit) may show a large VRAM spike during the first 1-2 steps, after which usage drops. QLoRA does not exhibit this.
  • FSDP2 may use more VRAM per GPU than single-GPU training because not all layers are properly sharded across ranks.
  • Model loading takes longer due to on-demand quantization, even on consecutive runs.
  • DeepSpeed has not been tested.

Implementation details

The quantization is applied by patching transformers to intercept weight loading. When a CUDA tensor with three or more dimensions and “expert” in its name is detected, it is quantized according to the active mode:

  • 4-bit mode: Uses bitsandbytes NF4 parametrization (configurable via bnb_4bit_quant_type).
  • 8-bit mode: Uses a custom row-wise int8 parametrization with bitsandbytes dequantization.
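To illustrate what "row-wise int8" means in general, here is a minimal pure-Python sketch of symmetric absmax quantization with one scale per row. It shows the generic technique only; it is not the parametrization used by Axolotl or bitsandbytes, and the function names are hypothetical. For a 3D expert tensor, each expert's 2D matrix would be processed in the same way.

```python
def quantize_rowwise_int8(matrix):
    """Quantize each row to integer codes in [-127, 127] with its own absmax scale."""
    codes, scales = [], []
    for row in matrix:
        absmax = max(abs(v) for v in row)
        scale = absmax / 127.0 if absmax > 0 else 1.0  # avoid dividing by zero
        codes.append([max(-127, min(127, round(v / scale))) for v in row])
        scales.append(scale)
    return codes, scales


def dequantize_rowwise_int8(codes, scales):
    """Recover approximate float values by multiplying each code by its row scale."""
    return [[c * s for c in row] for row, s in zip(codes, scales)]


weights = [
    [0.5, -1.0, 0.25],
    [10.0, 0.0, -2.5],
]
codes, scales = quantize_rowwise_int8(weights)
restored = dequantize_rowwise_int8(codes, scales)

# The rounding error per element is bounded by half of that row's scale.
max_err = max(
    abs(a - b) for row_a, row_b in zip(weights, restored) for a, b in zip(row_a, row_b)
)
print("int8 codes:", codes)
print("max abs error:", max_err)
```

Using a separate scale per row keeps a row with large values from degrading the resolution of rows with small values.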

The original bf16 tensor is freed immediately after quantization. Multiple sub-patches are applied to transformers, PEFT and accelerate FSDP2 to support these parametrized expert modules.
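The detect, quantize, and free sequence can be simulated without a GPU to see why it lowers peak memory. The sketch below tracks byte counts only. FakeTensor, is_expert_candidate, and simulate_load are hypothetical names, the tensor names and shapes are invented, tensors are assumed to arrive one at a time, and 0.5 bytes per parameter is a simplification that ignores quantization scales.

```python
from dataclasses import dataclass


@dataclass
class FakeTensor:
    """Stand-in for a checkpoint tensor; only the metadata is modeled."""
    name: str
    shape: tuple
    device: str = "cuda"

    @property
    def numel(self) -> int:
        n = 1
        for dim in self.shape:
            n *= dim
        return n


def is_expert_candidate(t: FakeTensor) -> bool:
    """Mirror the rule described above: CUDA, 3D or higher, 'expert' in the name."""
    return t.device == "cuda" and len(t.shape) >= 3 and "expert" in t.name


def simulate_load(tensors, quantize_on_load: bool, quant_bytes_per_param: float = 0.5):
    """Return (peak_bytes, final_bytes) for loading tensors one at a time."""
    resident = 0.0
    peak = 0.0
    for t in tensors:
        full = t.numel * 2                 # bf16 copy arrives on the device
        resident += full
        peak = max(peak, resident)
        if quantize_on_load and is_expert_candidate(t):
            resident += t.numel * quant_bytes_per_param  # quantized copy is created
            peak = max(peak, resident)     # both copies coexist briefly
            resident -= full               # original bf16 tensor is freed
    return peak, resident


checkpoint = [
    FakeTensor("model.layers.0.mlp.experts.gate_up_proj", (8, 64, 32)),
    FakeTensor("model.layers.0.mlp.experts.down_proj", (8, 16, 64)),
    FakeTensor("model.layers.0.self_attn.q_proj.weight", (64, 64)),
]

print("no quantization: ", simulate_load(checkpoint, quantize_on_load=False))
print("quantize on load:", simulate_load(checkpoint, quantize_on_load=True))
```

In this model, the peak with quantization on load is set by the largest single expert tensor plus its quantized copy, rather than by the sum of all bf16 expert tensors.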

For full implementation details, see PR #3439.