monkeypatch.accelerate.fsdp2
Monkeypatch for accelerate FSDP2 that fixes an error raised when an OrderedDict is modified during iteration, and fixes saving of full state dicts.
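The first fix concerns a general Python failure mode rather than anything FSDP-specific. The sketch below is not the patched accelerate code; it only illustrates, with hypothetical helper names (`add_prefix_unsafe`, `add_prefix_safe`), why mutating an OrderedDict inside a loop over it fails and how iterating over a snapshot of the keys avoids the problem.

```python
from collections import OrderedDict


def add_prefix_unsafe(d, prefix):
    # Mutates d while iterating over it; the OrderedDict iterator
    # detects the change and raises RuntimeError.
    for key in d:
        d[prefix + key] = d.pop(key)


def add_prefix_safe(d, prefix):
    # Snapshot the keys first so the loop never iterates the dict being changed.
    for key in list(d.keys()):
        d[prefix + key] = d.pop(key)
    return d


params = OrderedDict(weight=1, bias=2)

try:
    add_prefix_unsafe(OrderedDict(params), "module.")
    unsafe_failed = False
except RuntimeError:
    unsafe_failed = True

renamed = add_prefix_safe(OrderedDict(params), "module.")
```

The same snapshot pattern applies to any loop that renames, removes, or re-inserts entries of the mapping it is walking.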
Functions
| Name | Description |
|---|---|
| fsdp2_load_full_state_dict | Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the parameters from rank 0 to all other ranks. |
| fsdp2_prepare_model | Prepares the model for FSDP2 in-place. Also returns the model to avoid misuse of the original model. |
| get_state_dict | Returns the state dictionary of a model sent through [Accelerator.prepare] potentially without full precision. |
| patch_initialize_missing_keys_for_fsdp | Patch _initialize_missing_keys to skip re-initialization on FSDP non-rank-0. |
| patch_peft_param_wrapper_for_fsdp2 | Patch PEFT’s _LoraParameterProxy.forward for FSDP2 DTensor compatibility. |
| patch_tied_keys_for_meta_device | Patch _adjust_tied_keys_with_tied_pointers to skip meta tensors. |
fsdp2_load_full_state_dict
monkeypatch.accelerate.fsdp2.fsdp2_load_full_state_dict(
    _accelerator,
    model,
    full_sd,
    offload_to_cpu=False,
)
Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the parameters from rank 0 to all other ranks. This function modifies the model in-place.
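Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| _accelerator | Accelerator | The accelerator instance | required |
| model | torch.nn.Module | The model to load the state dict into. It is expected to be on the meta device; otherwise a VRAM spike can occur. | required |
| full_sd | dict | The full state dict to load. It may be present only on rank 0. | required |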
fsdp2_prepare_model
monkeypatch.accelerate.fsdp2.fsdp2_prepare_model(accelerator, model)
Prepares the model for FSDP2 in-place. Also returns the model to avoid misuse of the original model.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| accelerator | Accelerator | The accelerator instance | required |
| model | torch.nn.Module | The model to prepare | required |
Returns
| Name | Type | Description |
|---|---|---|
| | torch.nn.Module | Prepared model |
get_state_dict
monkeypatch.accelerate.fsdp2.get_state_dict(self, model, unwrap=True)
Returns the state dictionary of a model sent through [Accelerator.prepare] potentially without full precision.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| model | torch.nn.Module | A PyTorch model sent through [Accelerator.prepare] | required |
| unwrap | bool, optional | Whether to return the original underlying state_dict of model or to return the wrapped state_dict | True |
Returns
| Name | Type | Description |
|---|---|---|
| | dict | The state dictionary of the model, potentially without full precision. |
Example:
>>> import torch
>>> from accelerate import Accelerator
>>> accelerator = Accelerator()
>>> net = torch.nn.Linear(2, 2)
>>> net = accelerator.prepare(net)
>>> state_dict = accelerator.get_state_dict(net)
patch_initialize_missing_keys_for_fsdp
monkeypatch.accelerate.fsdp2.patch_initialize_missing_keys_for_fsdp()
Patch _initialize_missing_keys to skip re-initialization on FSDP non-rank-0.
When using cpu_ram_efficient_loading, non-rank-0 processes load weights on meta device and move them to CPU as empty tensors. Without this patch, initialize_weights() re-initializes ALL parameters (via guarded init functions), which is slow and uses extra RAM per process.
The fix marks all params/buffers with is_hf_initialized=True before calling the original method, so guarded init functions (init.normal_, init.zeros_, etc.) become no-ops on non-rank-0 processes. The real weights arrive later via FSDP broadcast from rank 0.
Upstream fix: https://github.com/huggingface/transformers/pull/44473
Remove this patch once transformers includes the fix in a stable release.
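The flag-based skip described above can be sketched in plain PyTorch. This is an illustration under assumptions, not the transformers or axolotl implementation: `guarded_normal_` and `mark_all_initialized` are hypothetical helpers, and the attribute name `is_hf_initialized` is taken from the description above as a stand-in for whatever the library checks internally.

```python
import torch
from torch import nn


def guarded_normal_(param, std=0.02):
    # Hypothetical guard: leave tensors alone once they carry the flag.
    if getattr(param, "is_hf_initialized", False):
        return param
    return nn.init.normal_(param, mean=0.0, std=std)


def mark_all_initialized(module):
    # Flag every parameter and buffer so guarded init functions skip them.
    for tensor in list(module.parameters()) + list(module.buffers()):
        tensor.is_hf_initialized = True


layer = nn.Linear(4, 4)
with torch.no_grad():
    layer.weight.zero_()
    layer.bias.zero_()

mark_all_initialized(layer)
for p in layer.parameters():
    guarded_normal_(p)  # no-op: the flag is set

# The flagged parameters still hold the zeros written above.
untouched = all(bool((p == 0).all()) for p in layer.parameters())

# An unflagged parameter is initialized as usual.
fresh = nn.Parameter(torch.zeros(8))
guarded_normal_(fresh)
changed = bool((fresh != 0).any())
```

Because the guard returns before touching the data, marking tensors up front avoids the per-process cost of initializing weights that will be overwritten by the broadcast anyway.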
patch_peft_param_wrapper_for_fsdp2
monkeypatch.accelerate.fsdp2.patch_peft_param_wrapper_for_fsdp2()
Patch PEFT’s _LoraParameterProxy.forward for FSDP2 DTensor compatibility.
PEFT’s ParamWrapper applies LoRA via torch.nn.utils.parametrize, which adds delta_weight to the base weight W inside _LoraParameterProxy.forward(). Under FSDP2, W may be a DTensor (from FSDP unshard) while delta_weight is a regular Tensor (or vice versa), causing a RuntimeError on mixed types.
This patch promotes the non-DTensor operand to match the DTensor’s spec using DTensor.from_local(), which is free for Replicate placement (just metadata wrapping, no communication).
patch_tied_keys_for_meta_device
monkeypatch.accelerate.fsdp2.patch_tied_keys_for_meta_device()
Patch _adjust_tied_keys_with_tied_pointers to skip meta tensors.
Meta tensors all share data_ptr()==0, causing every parameter to be incorrectly grouped as “tied”. Skipping them is safe since they have no real storage.
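A minimal sketch of pointer-based tie detection with the meta-device skip is shown below. It is not the body of _adjust_tied_keys_with_tied_pointers; `group_tied_by_pointer` is a hypothetical helper that only illustrates why the skip is needed.

```python
import torch


def group_tied_by_pointer(named_tensors):
    """Group tensor names by data_ptr(), skipping tensors on the meta device."""
    groups = {}
    for name, tensor in named_tensors.items():
        if tensor.device.type == "meta":
            # No real storage, so the pointer carries no tying information.
            continue
        groups.setdefault(tensor.data_ptr(), []).append(name)
    # Only pointers shared by two or more names indicate tied weights.
    return [names for names in groups.values() if len(names) > 1]


shared = torch.zeros(4)
real = {
    "embed.weight": shared,
    "lm_head.weight": shared,  # same storage as embed.weight
    "norm.weight": torch.ones(4),
}
meta = {
    "a.weight": torch.empty(4, device="meta"),
    "b.weight": torch.empty(4, device="meta"),
}

tied_real = group_tied_by_pointer(real)
tied_meta = group_tied_by_pointer(meta)
```

Without the device check, the two unrelated meta tensors would land in the same group and be reported as tied.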