utils.fp32_norms

Helpers for keeping selected norm modules in fp32 under FSDP2.

Functions

| Name                   | Description                                                              |
|------------------------|--------------------------------------------------------------------------|
| get_fp32_norm_patterns | Resolve configured fp32 norm patterns from a config or tagged model.     |
| shard_norms_fp32       | Wrap matching norm modules with FSDP2 + fp32 MixedPrecisionPolicy.       |
| tag_model_fp32_norms   | Attach the resolved fp32 norm patterns to the model for FSDP2 prepare.   |

get_fp32_norm_patterns

utils.fp32_norms.get_fp32_norm_patterns(source)

Resolve the configured fp32 norm patterns from `source`, which may be either a config or a model previously tagged with `tag_model_fp32_norms`.
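The resolution order can be sketched as follows: a tagged model wins, otherwise the patterns are read from the config. This is a minimal sketch, not the module's implementation. The attribute name `_fp32_norm_patterns`, the config key `fp32_norms`, the function name `get_fp32_norm_patterns_sketch`, and the acceptance of a single string or a sequence of strings are all hypothetical assumptions.

```python
from typing import Any, List, Optional

# Hypothetical names: the real module may use different attribute/key names.
MODEL_ATTR = "_fp32_norm_patterns"  # attribute a tagged model is assumed to carry
CONFIG_KEY = "fp32_norms"           # config field assumed to hold the patterns


def get_fp32_norm_patterns_sketch(source: Any) -> Optional[List[str]]:
    """Return fp32 norm patterns from a tagged model or a config, else None."""
    if source is None:
        return None
    # A model tagged earlier is assumed to carry already-resolved patterns.
    tagged = getattr(source, MODEL_ATTR, None)
    if tagged is not None:
        return list(tagged)
    # Otherwise treat the source as a config, either dict-like or attribute-style.
    if isinstance(source, dict):
        raw = source.get(CONFIG_KEY)
    else:
        raw = getattr(source, CONFIG_KEY, None)
    if not raw:
        return None
    # Accept a single pattern string or any sequence of pattern strings.
    if isinstance(raw, str):
        return [raw]
    return list(raw)
```

Checking the model attribute first means a model that has already been tagged does not need the config passed along again.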

shard_norms_fp32

utils.fp32_norms.shard_norms_fp32(
    model,
    source=None,
    *,
    patterns=None,
    fully_shard_kwargs=None,
)

Wrap each norm module that matches the fp32 norm patterns with FSDP2 (`fully_shard`) using an fp32 `MixedPrecisionPolicy`.
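A rough sketch of the select-then-wrap idea. It assumes that patterns are shell-style globs matched against qualified module names from `named_modules()` and that `fully_shard_kwargs` are forwarded to `fully_shard`. Both assumptions are hypothetical, as are the helper names `select_fp32_norms` and `shard_norms_fp32_sketch`. `fully_shard` and `MixedPrecisionPolicy` are real PyTorch FSDP2 APIs, exported from `torch.distributed.fsdp` in recent releases.

```python
import fnmatch
from typing import Any, Dict, Iterable, List, Optional, Tuple


def select_fp32_norms(
    named_modules: Iterable[Tuple[str, Any]], patterns: Iterable[str]
) -> List[Tuple[str, Any]]:
    """Return (name, module) pairs whose qualified name matches any glob pattern."""
    patterns = list(patterns)
    return [
        (name, module)
        for name, module in named_modules
        # Skip the unnamed root module; match case-sensitively on every platform.
        if name and any(fnmatch.fnmatchcase(name, p) for p in patterns)
    ]


def shard_norms_fp32_sketch(
    model: Any, patterns: Iterable[str], fully_shard_kwargs: Optional[Dict[str, Any]] = None
) -> List[str]:
    """Wrap each matching norm in its own FSDP2 unit that keeps params and grads in fp32."""
    # Imported lazily so the selection helper above works without torch installed.
    import torch
    from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

    kwargs = dict(fully_shard_kwargs or {})
    # Force fp32 for both the unsharded parameters and the gradient reduction.
    kwargs["mp_policy"] = MixedPrecisionPolicy(
        param_dtype=torch.float32, reduce_dtype=torch.float32
    )
    matched = select_fp32_norms(model.named_modules(), patterns)
    for _, module in matched:
        fully_shard(module, **kwargs)
    return [name for name, _ in matched]
```

FSDP2 expects `fully_shard` to be applied bottom-up, so a step like this would run before the enclosing blocks and the root model are wrapped. If the surrounding model computes in bf16, `MixedPrecisionPolicy` also accepts an `output_dtype` that can cast the norm's output. Whether the real helper sets it is not stated here.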

tag_model_fp32_norms

utils.fp32_norms.tag_model_fp32_norms(model, cfg)

Resolve the fp32 norm patterns from `cfg` and attach them to `model`, so the FSDP2 prepare step can find them on the model itself.
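A minimal sketch of tagging, assuming the patterns live under a `fp32_norms` config key and are stored on a `_fp32_norm_patterns` attribute. Both names, and the function name `tag_model_fp32_norms_sketch`, are hypothetical.

```python
from typing import Any

# Hypothetical names: the real module may use different ones.
FP32_NORM_ATTR = "_fp32_norm_patterns"
CONFIG_KEY = "fp32_norms"


def tag_model_fp32_norms_sketch(model: Any, cfg: Any) -> Any:
    """Store the configured fp32 norm patterns on the model as a plain attribute."""
    # Read the patterns from a dict-like or attribute-style config.
    if isinstance(cfg, dict):
        raw = cfg.get(CONFIG_KEY)
    else:
        raw = getattr(cfg, CONFIG_KEY, None)
    if not raw:
        # Nothing configured: leave the model untagged.
        return model
    patterns = [raw] if isinstance(raw, str) else list(raw)
    # On a torch nn.Module, a plain list is stored as an ordinary instance attribute.
    setattr(model, FP32_NORM_ATTR, patterns)
    return model
```

Putting the patterns on the model is presumably useful because a later wrapping step may receive only the model and not the config. Under the same assumptions, that step would read them back with `getattr(model, "_fp32_norm_patterns", None)`.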