monkeypatch.fsdp2_qlora
Monkeypatch to add Params4bit and Int8Params support to FSDP2. This enables QLoRA and 8-bit LoRA training with FSDP2, and lets our LoRA / QLoRA Triton kernels work with FSDP2.
This patch modifies the _init_sharded_param and init_unsharded_param methods in FSDPParam to handle bitsandbytes Params4bit and Int8Params parameters, preserving their quantization metadata through the FSDP2 shard/unshard cycle.
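The wrapping pattern described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not this module's implementation: QuantParamStub, FSDPParamStub, QUANT_ATTRS and apply_sharded_param_patch are hypothetical stand-ins, and "sharding" is a simple list slice rather than real FSDP2 chunking. It shows the general idea: keep a reference to the original method, run it, then copy the quantization metadata from the original parameter onto the newly created shard.

```python
import functools


class QuantParamStub:
    """Hypothetical stand-in for a quantized parameter such as Params4bit."""

    def __init__(self, data, quant_state=None, blocksize=64):
        self.data = data
        self.quant_state = quant_state  # metadata needed to dequantize
        self.blocksize = blocksize


class FSDPParamStub:
    """Hypothetical stand-in for FSDPParam (not the real torch class)."""

    def _init_sharded_param(self, param, shard_index, world_size):
        # Re-wrap the local shard in a fresh object, dropping the metadata,
        # the way a generic sharding path would.
        chunk = len(param.data) // world_size
        local = param.data[shard_index * chunk:(shard_index + 1) * chunk]
        self.sharded_param = QuantParamStub(local)


# Hypothetical list of metadata attributes to carry across sharding.
QUANT_ATTRS = ("quant_state", "blocksize")


def apply_sharded_param_patch(cls):
    """Wrap cls._init_sharded_param so quantization metadata survives."""
    original = cls._init_sharded_param
    if getattr(original, "_quant_patched", False):
        return  # already patched; wrapping twice would nest wrappers

    @functools.wraps(original)
    def patched(self, param, *args, **kwargs):
        original(self, param, *args, **kwargs)
        if isinstance(param, QuantParamStub):
            for name in QUANT_ATTRS:
                setattr(self.sharded_param, name, getattr(param, name))

    patched._quant_patched = True
    cls._init_sharded_param = patched


apply_sharded_param_patch(FSDPParamStub)
fsdp_param = FSDPParamStub()
weight = QuantParamStub([1, 2, 3, 4], quant_state={"absmax": 0.5}, blocksize=32)
fsdp_param._init_sharded_param(weight, shard_index=1, world_size=2)
print(fsdp_param.sharded_param.data, fsdp_param.sharded_param.quant_state)
# [3, 4] {'absmax': 0.5}
```

The guard on _quant_patched makes the patch idempotent, so calling the apply function more than once does not stack wrappers around the original method.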
Functions
| Name | Description |
|---|---|
| apply_init_dtype_attrs_patch | Prevent FSDP2 mixed precision from casting non-float quantized params. |
| apply_init_sharded_param_patch | Apply patch to FSDPParam._init_sharded_param to support Params4bit. |
| apply_init_unsharded_param_patch | Apply patch to FSDPParam.init_unsharded_param to support Params4bit. |
| apply_linear8bitlt_save_patch | Patch Linear8bitLt._save_to_state_dict to handle DTensor-wrapped Int8Params. |
apply_init_dtype_attrs_patch
monkeypatch.fsdp2_qlora.apply_init_dtype_attrs_patch()
Prevent FSDP2 mixed precision from casting non-float quantized params.
When mixed precision is enabled (e.g., bf16), FSDP2’s init_dtype_attrs sets param_dtype=bf16 for ALL params. During all-gather, _to_dtype_if_needed then casts each sharded param to param_dtype. For non-float params (packed 4-bit weights stored as uint8, or int8-quantized weights) that lack FSDP2 all-gather extensions, this cast destroys the quantized data.
Params4bit avoids this through its fsdp_pre_all_gather / fsdp_post_all_gather extensions, but our parametrize-based expert quantization stores weights as plain nn.Parameter tensors (uint8/int8) with no such extensions.
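A minimal sketch of the guard, assuming the fix amounts to clearing the cast target for non-float params after the original dtype setup runs. MixedPrecisionStub, DtypeParamStub, FLOAT_DTYPES and apply_dtype_guard are hypothetical names, and dtypes are plain strings here; real code would more likely check torch.dtype.is_floating_point. The convention that param_dtype=None means "no cast" is also part of this sketch's assumptions.

```python
from dataclasses import dataclass
from typing import Optional

# Hypothetical string dtypes; real code would check torch.dtype.is_floating_point.
FLOAT_DTYPES = {"float16", "bfloat16", "float32", "float64"}


@dataclass
class MixedPrecisionStub:
    """Hypothetical stand-in for a mixed precision policy."""
    param_dtype: Optional[str] = None


class DtypeParamStub:
    """Hypothetical stand-in for the dtype bookkeeping in FSDPParam."""

    def __init__(self, orig_dtype):
        self.orig_dtype = orig_dtype
        self.param_dtype = None

    def init_dtype_attrs(self, mp_policy):
        # Unpatched behavior: every param is scheduled for a cast to the
        # policy dtype; None means "already the right dtype, no cast".
        target = mp_policy.param_dtype
        self.param_dtype = None if target == self.orig_dtype else target


def apply_dtype_guard(cls):
    """Wrap init_dtype_attrs so non-float params are never cast."""
    original = cls.init_dtype_attrs

    def patched(self, mp_policy):
        original(self, mp_policy)
        if self.orig_dtype not in FLOAT_DTYPES:
            # uint8 / int8 storage must reach the all-gather untouched.
            self.param_dtype = None

    cls.init_dtype_attrs = patched


apply_dtype_guard(DtypeParamStub)
policy = MixedPrecisionStub(param_dtype="bfloat16")
for dtype in ("float32", "uint8", "int8"):
    p = DtypeParamStub(dtype)
    p.init_dtype_attrs(policy)
    print(dtype, "->", p.param_dtype)
# float32 -> bfloat16
# uint8 -> None
# int8 -> None
```

Float params keep the normal mixed precision behavior; only the quantized uint8/int8 storage is exempted from the cast.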
apply_init_sharded_param_patch
monkeypatch.fsdp2_qlora.apply_init_sharded_param_patch()
Apply patch to FSDPParam._init_sharded_param to support Params4bit.
apply_init_unsharded_param_patch
monkeypatch.fsdp2_qlora.apply_init_unsharded_param_patch()
Apply patch to FSDPParam.init_unsharded_param to support Params4bit.
apply_linear8bitlt_save_patch
monkeypatch.fsdp2_qlora.apply_linear8bitlt_save_patch()
Patch Linear8bitLt._save_to_state_dict to handle DTensor-wrapped Int8Params.
After FSDP2 sharding, Linear8bitLt.weight is a DTensor wrapping Int8Params. BnB’s _save_to_state_dict accesses self.weight.SCB directly, but DTensor doesn’t proxy custom attribute access to its _local_tensor. This patch temporarily unwraps the DTensor during saving so BnB can find the SCB attribute.
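The temporary unwrap can be sketched as a context manager that swaps the weight for its local tensor and always restores it. This is an illustration under stated assumptions, not the module's code: Int8ParamsStub, DTensorStub, Linear8bitStub, unwrapped_weight and apply_save_patch are hypothetical plain-Python stand-ins, and the sketch ignores nn.Module parameter registration and real DTensor internals.

```python
import contextlib


class Int8ParamsStub:
    """Hypothetical stand-in for bitsandbytes Int8Params."""

    def __init__(self, data, SCB):
        self.data = data
        self.SCB = SCB  # scale constants needed to dequantize


class DTensorStub:
    """Hypothetical DTensor stand-in that, like DTensor, does not forward
    custom attributes (such as SCB) to the tensor it wraps."""

    __slots__ = ("_local_tensor",)

    def __init__(self, local_tensor):
        self._local_tensor = local_tensor


class Linear8bitStub:
    """Hypothetical stand-in for Linear8bitLt."""

    def __init__(self, weight):
        self.weight = weight

    def _save_to_state_dict(self, destination, prefix):
        # Reads self.weight.SCB directly; fails on a DTensorStub.
        destination[prefix + "weight"] = self.weight.data
        destination[prefix + "SCB"] = self.weight.SCB


@contextlib.contextmanager
def unwrapped_weight(module):
    """Temporarily replace a DTensorStub weight with its local tensor."""
    wrapped = module.weight
    if isinstance(wrapped, DTensorStub):
        module.weight = wrapped._local_tensor
    try:
        yield module
    finally:
        module.weight = wrapped  # always restore the sharded weight


def apply_save_patch(cls):
    """Wrap cls._save_to_state_dict so it runs with the weight unwrapped."""
    original = cls._save_to_state_dict

    def patched(self, destination, prefix, *args, **kwargs):
        with unwrapped_weight(self):
            return original(self, destination, prefix, *args, **kwargs)

    cls._save_to_state_dict = patched


apply_save_patch(Linear8bitStub)
layer = Linear8bitStub(DTensorStub(Int8ParamsStub([1, -2], SCB=[0.1])))
state = {}
layer._save_to_state_dict(state, "layer.")
print(state)  # {'layer.weight': [1, -2], 'layer.SCB': [0.1]}
print(type(layer.weight).__name__)  # DTensorStub
```

Restoring the weight in a finally clause keeps the module sharded even if saving raises, so training can continue with the DTensor-wrapped weight after a checkpoint.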