cli.merge_sharded_fsdp_weights
CLI to merge sharded FSDP model checkpoints into a single combined checkpoint.
Classes
| Name | Description |
|---|---|
| BFloat16CastPlanner | A custom planner to cast tensors to bfloat16 on the fly during loading. |
BFloat16CastPlanner
cli.merge_sharded_fsdp_weights.BFloat16CastPlanner()
A custom planner to cast tensors to bfloat16 on the fly during loading.
Functions
| Name | Description |
|---|---|
| do_cli | Parses axolotl config, CLI args, and calls merge_fsdp_weights. |
| merge_fsdp_weights | Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if SHARDED_STATE_DICT was used for the model. |
do_cli
cli.merge_sharded_fsdp_weights.do_cli(config=Path('examples/'), **kwargs)
Parses axolotl config, CLI args, and calls merge_fsdp_weights.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| config | Union[Path, str] | Path to axolotl config YAML file. | Path('examples/') |
| kwargs | | Additional keyword arguments to override config file values. | {} |
merge_fsdp_weights
cli.merge_sharded_fsdp_weights.merge_fsdp_weights(
checkpoint_dir,
output_path,
safe_serialization=False,
remove_checkpoint_dir=False,
)
Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if SHARDED_STATE_DICT was used for the model. Weights will be saved to {output_path}/model.safetensors if safe_serialization is True, otherwise to pytorch_model.bin.
Note: this is a CPU-bound process.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| checkpoint_dir | str | The directory containing the FSDP checkpoints (can be either the model or optimizer). | required |
| output_path | str | The path to save the merged checkpoint. | required |
| safe_serialization | bool, optional | Whether to save the merged weights with safetensors (recommended). | False |
| remove_checkpoint_dir | bool, optional | Whether to remove the checkpoint directory after merging. | False |
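The way safe_serialization selects the output file can be illustrated with a small sketch. The helper name expected_merged_file is hypothetical and not part of axolotl, and the sketch assumes that pytorch_model.bin is also written inside output_path, which the description above does not state explicitly.

```python
from pathlib import Path


def expected_merged_file(output_path: str, safe_serialization: bool = False) -> Path:
    """Return the file the merged weights are expected to be written to.

    Hypothetical helper for illustration only; it is not part of axolotl.
    Assumption: pytorch_model.bin is placed inside output_path, like
    model.safetensors.
    """
    # model.safetensors when safe_serialization is set, pytorch_model.bin otherwise.
    filename = "model.safetensors" if safe_serialization else "pytorch_model.bin"
    return Path(output_path) / filename


if __name__ == "__main__":
    print(expected_merged_file("merged", safe_serialization=True))
    print(expected_merged_file("merged"))
```

A caller could use a helper like this to locate the merged weights after merge_fsdp_weights returns, for example before uploading them or loading them for inference.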
Raises
| Name | Type | Description |
|---|---|---|
| | ValueError | If torch version < 2.3.0, or if checkpoint_dir does not exist. |
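The two documented failure conditions can be checked up front, before starting a long merge. The following is a minimal sketch under stated assumptions: parse_version and check_merge_preconditions are hypothetical names, the torch version is passed in as a string (for example torch.__version__) so the sketch runs without torch installed, and the real function may perform these checks differently.

```python
import re
from pathlib import Path

# Minimum torch version named in the Raises table.
MIN_TORCH = (2, 3, 0)


def parse_version(version: str) -> tuple:
    """Extract (major, minor, patch) from strings such as '2.3.0+cu121'."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    # A missing patch component is treated as 0.
    return tuple(int(part or 0) for part in match.groups())


def check_merge_preconditions(checkpoint_dir: str, torch_version: str) -> None:
    """Raise ValueError under the conditions listed in the Raises table.

    Hypothetical helper for illustration only; it is not part of axolotl.
    """
    if parse_version(torch_version) < MIN_TORCH:
        raise ValueError(f"torch >= 2.3.0 is required, found {torch_version}")
    if not Path(checkpoint_dir).exists():
        raise ValueError(f"checkpoint_dir does not exist: {checkpoint_dir}")
```

Passing the version as an argument keeps the check easy to test; in practice a caller would likely pass torch.__version__.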