monkeypatch.accelerate.fsdp2

Monkeypatches for Accelerate's FSDP2 integration that fix errors caused by modifying an OrderedDict during iteration and that fix saving full state dicts.
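The summary above mentions an error from modifying an OrderedDict during iteration. This is a general Python pitfall: mutating an OrderedDict while iterating over its live key view raises RuntimeError. The sketch below shows the failure and the usual remedy, which is to iterate over a snapshot of the keys. The helper names and the key-renaming scenario are hypothetical illustrations, not the patch's actual code.

```python
from collections import OrderedDict


def rename_prefix_unsafe(sd: OrderedDict, old: str, new: str) -> None:
    # Iterates over the live dict while popping/inserting keys:
    # the next step of the loop raises RuntimeError.
    for key in sd:
        if key.startswith(old):
            sd[new + key[len(old):]] = sd.pop(key)


def rename_prefix_safe(sd: OrderedDict, old: str, new: str) -> None:
    # Iterate over a snapshot (list) of the keys, so mutating sd is allowed.
    for key in list(sd.keys()):
        if key.startswith(old):
            sd[new + key[len(old):]] = sd.pop(key)
```

Note that renamed keys move to the end of the OrderedDict; callers that depend on key order need to rebuild the dict instead of mutating it in place.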

Functions

| Name | Description |
|------|-------------|
| fsdp2_load_full_state_dict | Loads the full state dict (which may exist only on rank 0) into the sharded model. This is done by broadcasting the parameters from rank 0 to all other ranks. |
| fsdp2_prepare_model | Prepares the model for FSDP2 in-place. Also returns the model to avoid misuse of the original model. |
| get_state_dict | Returns the state dictionary of a model sent through Accelerator.prepare, potentially without full precision. |
| patch_initialize_missing_keys_for_fsdp | Patch _initialize_missing_keys to skip re-initialization on non-rank-0 FSDP processes. |
| patch_peft_param_wrapper_for_fsdp2 | Patch PEFT’s _LoraParameterProxy.forward for FSDP2 DTensor compatibility. |
| patch_tied_keys_for_meta_device | Patch _adjust_tied_keys_with_tied_pointers to skip meta tensors. |

fsdp2_load_full_state_dict

monkeypatch.accelerate.fsdp2.fsdp2_load_full_state_dict(
    _accelerator,
    model,
    full_sd,
    offload_to_cpu=False,
)

Loads the full state dict (which may exist only on rank 0) into the sharded model. This is done by broadcasting the parameters from rank 0 to all other ranks. This function modifies the model in-place.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| _accelerator | Accelerator | The accelerator instance | required |
| model | torch.nn.Module | The model to load the state dict into. It is expected to be on the meta device; otherwise a VRAM spike can occur. | required |
| full_sd | dict | The full state dict to load. It may be present only on rank 0. | required |
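The broadcast-then-shard flow described above can be modeled without GPUs. The toy sketch below simulates every rank in a single Python process: only rank 0 holds the full parameters, each rank receives a copy through a simulated broadcast, and each rank then keeps only its contiguous slice along dim 0, which mirrors FSDP2's default dim-0 sharding. All function names are hypothetical; real code would use torch.distributed collectives on device tensors rather than list copies, and this is not the accelerate implementation.

```python
import math
from typing import Dict, List, Optional

Matrix = List[List[float]]


def simulated_broadcast(full: Matrix, world_size: int) -> List[Matrix]:
    # Stand-in for a broadcast from src=0: every rank receives a copy of rank 0's data.
    return [[row[:] for row in full] for _ in range(world_size)]


def shard_dim0(full: Matrix, rank: int, world_size: int) -> Matrix:
    # Keep a contiguous slice along dim 0; chunk size is rounded up, like torch.chunk.
    chunk = math.ceil(len(full) / world_size)
    return full[rank * chunk:(rank + 1) * chunk]


def load_full_state_dict_sim(
    full_sd_per_rank: List[Optional[Dict[str, Matrix]]],
) -> List[Dict[str, Matrix]]:
    world_size = len(full_sd_per_rank)
    full_sd = full_sd_per_rank[0]  # only rank 0 holds the real weights
    sharded: List[Dict[str, Matrix]] = [{} for _ in range(world_size)]
    for name, param in full_sd.items():
        received = simulated_broadcast(param, world_size)
        for rank in range(world_size):
            sharded[rank][name] = shard_dim0(received[rank], rank, world_size)
    return sharded
```

Processing one parameter at a time, as in this sketch, means the non-zero ranks only materialize one full parameter at once instead of the whole state dict.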

fsdp2_prepare_model

monkeypatch.accelerate.fsdp2.fsdp2_prepare_model(accelerator, model)

Prepares the model for FSDP2 in-place. Also returns the model to avoid misuse of the original model.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| accelerator | Accelerator | The accelerator instance | required |
| model | torch.nn.Module | The model to prepare | required |

Returns

| Name | Type | Description |
|------|------|-------------|
| | torch.nn.Module | The prepared model |

get_state_dict

monkeypatch.accelerate.fsdp2.get_state_dict(self, model, unwrap=True)

Returns the state dictionary of a model sent through Accelerator.prepare, potentially without full precision.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| model | torch.nn.Module | A PyTorch model sent through Accelerator.prepare | required |
| unwrap | bool, optional | Whether to return the original underlying state_dict of the model or the wrapped state_dict | True |

Returns

| Name | Type | Description |
|------|------|-------------|
| | dict | The state dictionary of the model, potentially without full precision. |

Example:

>>> import torch
>>> from accelerate import Accelerator

>>> accelerator = Accelerator()
>>> net = torch.nn.Linear(2, 2)
>>> net = accelerator.prepare(net)
>>> state_dict = accelerator.get_state_dict(net)

patch_initialize_missing_keys_for_fsdp

monkeypatch.accelerate.fsdp2.patch_initialize_missing_keys_for_fsdp()

Patch _initialize_missing_keys to skip re-initialization on non-rank-0 FSDP processes.

When using cpu_ram_efficient_loading, non-rank-0 processes load weights on the meta device and move them to CPU as empty tensors. Without this patch, initialize_weights() re-initializes ALL parameters (via guarded init functions), which is slow and uses extra RAM in every process.

The fix marks all params/buffers with is_hf_initialized=True before calling the original method, so guarded init functions (init.normal_, init.zeros_, etc.) become no-ops on non-rank-0 processes. The real weights arrive later via the FSDP broadcast from rank 0.

Upstream fix: https://github.com/huggingface/transformers/pull/44473

Remove this patch once transformers includes the fix in a stable release.
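The mechanism above, flagging tensors so that guarded init functions skip them, can be illustrated with a small toy model. In the sketch below, FakeParam stands in for a tensor, guarded_normal_ mimics a guarded init function that returns early for flagged tensors, and patched_initialize flags everything on non-rank-0 processes before delegating. All of these names are hypothetical, and the is_hf_initialized attribute simply reuses the flag name from the description; this is not the transformers code.

```python
import random
from dataclasses import dataclass
from typing import List


@dataclass
class FakeParam:
    # Toy stand-in for a tensor; is_hf_initialized mirrors the flag named above.
    data: List[float]
    is_hf_initialized: bool = False


def guarded_normal_(p: FakeParam, std: float = 0.02) -> None:
    # Guarded init: skip tensors that are already flagged as initialized.
    if p.is_hf_initialized:
        return
    p.data = [random.gauss(0.0, std) for _ in p.data]


def initialize_weights(params: List[FakeParam]) -> None:
    for p in params:
        guarded_normal_(p)


def patched_initialize(params: List[FakeParam], rank: int) -> None:
    # On non-rank-0 processes, flag everything first so init becomes a no-op;
    # the real weights would arrive later via the FSDP broadcast from rank 0.
    if rank != 0:
        for p in params:
            p.is_hf_initialized = True
    initialize_weights(params)
```

Rank 0 still runs the original initialization path, so any genuinely missing keys there are initialized as before.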

patch_peft_param_wrapper_for_fsdp2

monkeypatch.accelerate.fsdp2.patch_peft_param_wrapper_for_fsdp2()

Patch PEFT’s _LoraParameterProxy.forward for FSDP2 DTensor compatibility.

PEFT’s ParamWrapper applies LoRA via torch.nn.utils.parametrize, which adds delta_weight to the base weight W inside _LoraParameterProxy.forward(). Under FSDP2, W may be a DTensor (from the FSDP unshard) while delta_weight is a regular Tensor (or vice versa), and adding the two raises a RuntimeError because of the mixed tensor types.

This patch promotes the non-DTensor operand to match the DTensor’s spec using DTensor.from_local(), which is free for Replicate placement (just metadata wrapping, no communication).
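The promotion rule above can be sketched with a toy stand-in for DTensor, so it runs without a process group. FakeDTensor carries local data plus placement metadata and refuses to add a plain operand, mimicking the mixed-type failure; from_local wraps a plain value using the other operand's placements without copying data between ranks. FakeDTensor, from_local, and add_promoting are hypothetical names; in real code the wrapping step would be DTensor.from_local with the DTensor operand's device_mesh and placements.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class FakeDTensor:
    # Toy stand-in for DTensor: local data plus placement metadata.
    local: Tuple[float, ...]
    placements: Tuple[str, ...] = ("Replicate",)

    def __add__(self, other):
        if not isinstance(other, FakeDTensor):
            # Mimics the failure mode: mixing a DTensor with a plain tensor.
            raise RuntimeError("mixed DTensor and Tensor operands")
        if self.placements != other.placements:
            raise RuntimeError("placement mismatch")
        summed = tuple(a + b for a, b in zip(self.local, other.local))
        return FakeDTensor(summed, self.placements)


def from_local(local, placements):
    # Analogue of DTensor.from_local: wraps metadata only, no communication.
    return FakeDTensor(tuple(local), placements)


def add_promoting(w, delta):
    # Promote whichever operand is plain to match the other's placements.
    w_is_d = isinstance(w, FakeDTensor)
    d_is_d = isinstance(delta, FakeDTensor)
    if w_is_d and not d_is_d:
        delta = from_local(delta, w.placements)
    elif d_is_d and not w_is_d:
        w = from_local(w, delta.placements)
    elif not w_is_d and not d_is_d:
        # Both plain: ordinary elementwise addition.
        return tuple(a + b for a, b in zip(w, delta))
    return w + delta
```

Because the toy only ever uses a Replicate placement, wrapping is pure metadata, which matches the reasoning in the description for why the promotion is cheap.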

patch_tied_keys_for_meta_device

monkeypatch.accelerate.fsdp2.patch_tied_keys_for_meta_device()

Patch _adjust_tied_keys_with_tied_pointers to skip meta tensors.

Meta tensors all share data_ptr()==0, causing every parameter to be incorrectly grouped as “tied”. Skipping them is safe since they have no real storage.
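To see why meta tensors break tie detection, consider grouping parameters by storage pointer. The sketch below models each parameter as a (name, data_ptr, is_meta) record and reports groups of names that share a pointer. With meta tensors, every pointer is 0, so all parameters collapse into one bogus "tied" group unless meta entries are skipped. ParamInfo and find_tied_groups are hypothetical names used only for illustration, not the transformers function being patched.

```python
from collections import defaultdict
from typing import Dict, List, NamedTuple


class ParamInfo(NamedTuple):
    name: str
    data_ptr: int
    is_meta: bool


def find_tied_groups(params: List[ParamInfo], skip_meta: bool = True) -> List[List[str]]:
    # Group parameter names that share storage (same data_ptr).
    groups: Dict[int, List[str]] = defaultdict(list)
    for p in params:
        if skip_meta and p.is_meta:
            continue  # meta tensors have no storage; their data_ptr() is always 0
        groups[p.data_ptr].append(p.name)
    return [names for names in groups.values() if len(names) > 1]
```

Skipping meta entries loses no information, since a tensor without storage cannot share storage with anything.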