monkeypatch.fsdp2_qlora

Monkeypatch to add Params4bit and Int8Params support to FSDP2. This enables QLoRA + FSDP2 and 8-bit LoRA + FSDP2, and it also lets our LoRA / QLoRA Triton kernels work with FSDP2.

This patch modifies the _init_sharded_param and init_unsharded_param methods in FSDPParam to handle bitsandbytes Params4bit and Int8Params parameters, preserving their quantization metadata through the FSDP2 shard/unshard cycle.
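The wrap-and-preserve idea behind this patch can be sketched in plain Python. This is a minimal sketch under stated assumptions: `FakeQuantParam`, `FakeFSDPParam`, `apply_sketch_patch`, and the metadata attribute names `quant_state` and `blocksize` are hypothetical stand-ins, not the real bitsandbytes or PyTorch types, and the real `FSDPParam._init_sharded_param` has a different signature and does far more work. The sketch only shows the pattern: keep a reference to the original method, call it, then re-attach the quantization metadata that the re-wrap dropped.

```python
import functools


class FakeQuantParam:
    """Hypothetical stand-in for a quantized parameter that carries metadata."""

    def __init__(self, data, quant_state=None, blocksize=None):
        self.data = data
        self.quant_state = quant_state
        self.blocksize = blocksize


class FakeFSDPParam:
    """Hypothetical stand-in for FSDPParam; real sharding is far more involved."""

    def __init__(self, param, world_size=2, rank=0):
        self._init_sharded_param(param, world_size, rank)

    def _init_sharded_param(self, param, world_size, rank):
        # Naive re-wrap: keeps this rank's slice of the raw data, drops metadata.
        chunk = len(param.data) // world_size
        self.sharded_param = FakeQuantParam(param.data[rank * chunk:(rank + 1) * chunk])


# Hypothetical list of metadata attributes to carry across the re-wrap.
QUANT_ATTRS = ("quant_state", "blocksize")


def apply_sketch_patch():
    original = FakeFSDPParam._init_sharded_param

    @functools.wraps(original)
    def patched(self, param, world_size, rank):
        original(self, param, world_size, rank)
        # Re-attach any quantization metadata the original method dropped.
        for name in QUANT_ATTRS:
            value = getattr(param, name, None)
            if value is not None:
                setattr(self.sharded_param, name, value)

    FakeFSDPParam._init_sharded_param = patched


param = FakeQuantParam([1, 2, 3, 4], quant_state="nf4-state", blocksize=64)
print(FakeFSDPParam(param).sharded_param.quant_state)  # None: metadata lost
apply_sketch_patch()
shard = FakeFSDPParam(param).sharded_param
print(shard.data, shard.quant_state, shard.blocksize)  # [1, 2] nf4-state 64
```

Patching the class attribute (rather than individual instances) means every FSDPParam-like object created afterwards picks up the new behavior, which is why such patches have to be applied before the model is sharded.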

Functions

| Name | Description |
|------|-------------|
| apply_init_dtype_attrs_patch | Prevent FSDP2 mixed precision from casting non-float quantized params. |
| apply_init_sharded_param_patch | Apply patch to FSDPParam._init_sharded_param to support Params4bit. |
| apply_init_unsharded_param_patch | Apply patch to FSDPParam.init_unsharded_param to support Params4bit. |
| apply_linear8bitlt_save_patch | Patch Linear8bitLt._save_to_state_dict to handle DTensor-wrapped Int8Params. |

apply_init_dtype_attrs_patch

monkeypatch.fsdp2_qlora.apply_init_dtype_attrs_patch()

Prevent FSDP2 mixed precision from casting non-float quantized params.

When mixed precision is enabled (e.g., bf16), FSDP2’s init_dtype_attrs sets param_dtype=bf16 for all params. During all-gather, _to_dtype_if_needed casts the sharded param to param_dtype. For non-float params (packed 4-bit weights stored as uint8, or int8-quantized weights) that have no FSDP2 all-gather extensions, this cast destroys the quantized data.

Params4bit avoids this through its fsdp_pre_all_gather / fsdp_post_all_gather extensions, but our parametrize-based expert quantization stores weights as plain nn.Parameter tensors of dtype uint8 or int8, which have no such extensions.
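The guard this patch needs can be sketched as a small dtype-resolution rule. This is a hedged illustration, not the module's code: `resolve_param_dtype` and `FLOAT_DTYPES` are hypothetical names, dtypes are plain strings instead of `torch.dtype` objects, and the sketch adopts the convention that a resolved dtype of `None` means "skip the cast during all-gather".

```python
# Hypothetical set of dtype names treated as floating point in this sketch.
FLOAT_DTYPES = {"float16", "bfloat16", "float32", "float64"}


def resolve_param_dtype(orig_dtype, mp_param_dtype):
    """Return the dtype to cast to during all-gather, or None to skip casting.

    Mixed precision only applies to floating-point params; packed uint8 or
    int8 quantized params keep their original storage dtype.
    """
    if mp_param_dtype is None or mp_param_dtype == orig_dtype:
        return None  # no mixed-precision policy, or already in the target dtype
    if orig_dtype not in FLOAT_DTYPES:
        return None  # non-float quantized storage: never cast
    return mp_param_dtype


params = {"lora_A": "float32", "expert_w_packed": "uint8", "expert_w_int8": "int8"}
for name, dtype in params.items():
    print(name, resolve_param_dtype(dtype, "bfloat16"))
# lora_A bfloat16
# expert_w_packed None
# expert_w_int8 None
```

Deciding per parameter, instead of switching mixed precision off globally, keeps the bf16 compute path for the float LoRA adapters while leaving the quantized expert weights byte-for-byte intact.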

apply_init_sharded_param_patch

monkeypatch.fsdp2_qlora.apply_init_sharded_param_patch()

Apply patch to FSDPParam._init_sharded_param to support Params4bit.

apply_init_unsharded_param_patch

monkeypatch.fsdp2_qlora.apply_init_unsharded_param_patch()

Apply patch to FSDPParam.init_unsharded_param to support Params4bit.

apply_linear8bitlt_save_patch

monkeypatch.fsdp2_qlora.apply_linear8bitlt_save_patch()

Patch Linear8bitLt._save_to_state_dict to handle DTensor-wrapped Int8Params.

After FSDP2 sharding, Linear8bitLt.weight is a DTensor wrapping Int8Params. BnB’s _save_to_state_dict accesses self.weight.SCB directly, but DTensor does not forward custom attribute access to its _local_tensor, so the lookup fails. This patch temporarily unwraps the DTensor during saving so BnB can find the SCB attribute.
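The temporary-unwrap approach can be sketched with a context manager that swaps the wrapper for its local tensor and always restores it. This is a minimal sketch assuming simplified stand-ins: `FakeInt8Params`, `FakeDTensor`, `FakeLinear8bitLt`, `unwrapped_weight`, and `apply_sketch_save_patch` are hypothetical and only mimic the relevant behavior (a wrapper that hides `SCB`, and a save method that reads `self.weight.SCB` directly). The real patch may differ in its details.

```python
import contextlib


class FakeInt8Params:
    """Hypothetical stand-in for Int8Params: int8 data plus an SCB scale attribute."""

    def __init__(self, data, SCB):
        self.data = data
        self.SCB = SCB


class FakeDTensor:
    """Hypothetical stand-in for DTensor: it does NOT forward custom attributes."""

    __slots__ = ("_local_tensor",)

    def __init__(self, local_tensor):
        self._local_tensor = local_tensor


class FakeLinear8bitLt:
    def __init__(self, weight):
        self.weight = weight

    def _save_to_state_dict(self, destination, prefix):
        # Like the BnB method, this reads self.weight.SCB directly.
        destination[prefix + "SCB"] = self.weight.SCB
        destination[prefix + "weight"] = self.weight.data


@contextlib.contextmanager
def unwrapped_weight(module):
    """Temporarily replace a DTensor-like weight with its local tensor."""
    wrapped = module.weight
    local = getattr(wrapped, "_local_tensor", None)
    if local is None:
        yield  # plain weight: nothing to unwrap
        return
    module.weight = local
    try:
        yield
    finally:
        module.weight = wrapped  # always restore the sharded wrapper


def apply_sketch_save_patch():
    original = FakeLinear8bitLt._save_to_state_dict

    def patched(self, destination, prefix):
        with unwrapped_weight(self):
            original(self, destination, prefix)

    FakeLinear8bitLt._save_to_state_dict = patched


layer = FakeLinear8bitLt(FakeDTensor(FakeInt8Params([1, -2], SCB=[0.5])))
try:
    layer._save_to_state_dict({}, "layer.")
except AttributeError:
    print("unpatched save fails: the wrapper hides SCB")
apply_sketch_save_patch()
state = {}
layer._save_to_state_dict(state, "layer.")
print(state, type(layer.weight).__name__)
# {'layer.SCB': [0.5], 'layer.weight': [1, -2]} FakeDTensor
```

The `try`/`finally` restore matters: if saving raises, the module still ends up holding the sharded wrapper, so FSDP2's view of the parameter is never left in a half-unwrapped state.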