MoE Expert Quantization
Transformers v5 changed MoE expert layers from per-expert `nn.Linear` modules to fused `nn.Parameter` tensors (3D+). bitsandbytes can therefore no longer quantize them during model loading, so all expert weights are loaded in full bf16 precision, causing massive VRAM usage.
`quantize_moe_experts` solves this by quantizing expert weights during model loading: it intercepts the weight-loading process, quantizes each expert tensor on the fly, and immediately frees the original bf16 tensor from VRAM. This dramatically reduces peak memory.
For example, GLM-4.7-Flash QLoRA drops from ~127GiB to ~23GiB reserved memory.
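The quantize-then-free flow described above can be sketched in plain PyTorch. This is a toy illustration, not the actual patch: `should_quantize`, `toy_quantize`, and `quantize_experts_on_load` are hypothetical names, and the real implementation calls bitsandbytes inside transformers' loading path rather than using the simple int8 scheme shown here.

```python
import torch

def should_quantize(name: str, tensor: torch.Tensor) -> bool:
    # The patch targets fused expert weights: tensors with 3+ dims and
    # "expert" in the parameter name. (The real check also requires a
    # CUDA tensor; that is omitted here so the sketch runs on CPU.)
    return tensor.dim() >= 3 and "expert" in name

def toy_quantize(w: torch.Tensor):
    # Stand-in for the bitsandbytes call: row-wise absmax int8.
    scale = w.float().abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = (w.float() / scale).round().clamp(-127, 127).to(torch.int8)
    return q, scale

def quantize_experts_on_load(state_dict: dict) -> None:
    # Quantize each matching tensor as it is "loaded" and drop the
    # bf16 original so its memory can be freed immediately.
    for name in list(state_dict):
        w = state_dict[name]
        if should_quantize(name, w):
            state_dict[name] = toy_quantize(w)
            del w  # last reference to the bf16 tensor goes away here
```

Non-expert weights (e.g. 2D attention projections) pass through untouched; only the fused 3D+ expert tensors are replaced with their compact quantized form.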
Usage
Enable expert quantization in your Axolotl config:
```yaml
quantize_moe_experts: true
```

This works with both 4-bit (QLoRA) and 8-bit (LoRA) quantization.
Expert LoRA targeting
You can optionally apply LoRA adapters directly to expert weights using `lora_target_parameters`:

```yaml
lora_target_parameters:
  - mlp.experts.gate_up_proj
  - mlp.experts.down_proj
  # - mlp.gate.weight  # router
```

`lora_dropout` must be 0 when using `lora_target_parameters`.
Requirements
- Requires (`adapter: lora` and `load_in_8bit: true`) or (`adapter: qlora` and `load_in_4bit: true`)
- CUDA GPUs only (not tested with ROCm or other backends)
- FSDP2 compatible for distributed training
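As config fragments, the two supported combinations look like this (illustrative minimal snippets; other fields omitted):

```yaml
# 8-bit LoRA
adapter: lora
load_in_8bit: true
quantize_moe_experts: true
```

```yaml
# 4-bit QLoRA
adapter: qlora
load_in_4bit: true
quantize_moe_experts: true
```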
Limitations
- `cpu_ram_efficient_loading` hangs or takes a very long time with FSDP2 + QLoRA.
- Total model parameter count may display incorrectly (trainable param count is correct).
- FSDP LoRA (8-bit) may have a large initial VRAM spike at the first 1-2 steps, which then drops. QLoRA does not exhibit this.
- FSDP2 may use more VRAM per GPU than single-GPU training because not all layers are properly sharded across ranks.
- Model loading takes longer due to on-demand quantization, even on consecutive runs.
- DeepSpeed has not been tested.
Implementation details
The quantization is applied by patching transformers to intercept weight loading. When a 3D+ CUDA tensor with "expert" in its name is detected:
- 4-bit mode: uses bitsandbytes NF4 parametrization (configurable via `bnb_4bit_quant_type`).
- 8-bit mode: uses a custom row-wise int8 parametrization with bitsandbytes dequantization.
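The row-wise int8 scheme in 8-bit mode amounts to one absmax scale per row. A minimal sketch of the quantize/dequantize pair follows; this illustrates the math only, not the actual bitsandbytes kernels, and the function names are hypothetical.

```python
import torch

def quantize_rowwise_int8(w: torch.Tensor):
    # One fp32 absmax scale per row of the last dimension.
    scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = (w / scale).round().clamp(-127, 127).to(torch.int8)
    return q, scale

def dequantize_rowwise_int8(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Multiply each int8 row back by its scale to recover an approximation.
    return q.to(torch.float32) * scale

w = torch.randn(8, 64, 32)  # fused (num_experts, rows, cols) expert weight
q, scale = quantize_rowwise_int8(w)
w_hat = dequantize_rowwise_int8(q, scale)
max_err = (w - w_hat).abs().max().item()  # bounded by scale / 2 per element
```

Because the scale is per-row rather than per-tensor, a single outlier row cannot blow up the quantization error of the whole expert tensor.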
The original bf16 tensor is freed immediately after quantization. Multiple sub-patches are applied to transformers, PEFT, and Accelerate's FSDP2 integration to support these parametrized expert modules.
For full implementation details, see PR #3439.