core.trainers.mixins.activation_checkpointing

Trainer mixin for activation checkpointing with offloading

Classes

| Name | Description |
| --- | --- |
| ActivationOffloadingMixin | Trainer mixin class for activation checkpointing with offloading |

ActivationOffloadingMixin

core.trainers.mixins.activation_checkpointing.ActivationOffloadingMixin(
    *args,
    **kwargs,
)

Trainer mixin class for activation checkpointing with offloading

Functions

| Name | Description |
| --- | --- |
| get_lora_act_offloading_ctx_manager | Returns the activation offloading context manager for the model. All but the last output Linear in every step will be offloaded. |

get_lora_act_offloading_ctx_manager

core.trainers.mixins.activation_checkpointing.get_lora_act_offloading_ctx_manager(
    model,
    use_pin_memory=True,
    use_streams=True,
    min_offload_size=1024,
    max_fwd_stash_size=5,
    warn_if_no_head=True,
)

Returns the activation offloading context manager for the model. On every step, all activations are offloaded except those of the final output Linear layer.

If activation offloading is enabled, we return the OffloadActivations context manager. If activation offloading is disabled, we return a NoOpManager context manager.
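The enabled/disabled behaviour described above means callers can use the returned object in the same way in both cases. The following is a minimal sketch of that pattern using only the standard library. `RecordingOffload`, `pick_ctx_manager`, and `training_step` are hypothetical names made up for illustration, and `contextlib.nullcontext` stands in for the no-op manager. It is not the library's implementation.

```python
import contextlib


class RecordingOffload(contextlib.ContextDecorator):
    """Hypothetical stand-in for an active offloading context manager."""

    def __init__(self):
        self.events = []

    def __enter__(self):
        self.events.append("enter")
        return self

    def __exit__(self, exc_type, exc, tb):
        self.events.append("exit")
        return False  # never swallow exceptions


def pick_ctx_manager(enabled):
    """Return the active manager when enabled, otherwise a no-op one."""
    return RecordingOffload() if enabled else contextlib.nullcontext()


def training_step(ctx):
    # The calling code is identical whether or not offloading is active.
    with ctx:
        return "loss computed"


active = pick_ctx_manager(True)
print(training_step(active), active.events)
print(training_step(pick_ctx_manager(False)))
```

Because both branches return a context manager, the training loop needs no conditional of its own.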

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | nn.Module | Model to wrap with the activation offloading context manager. | required |
| use_pin_memory | bool, optional, defaults to True | Whether the offloaded Tensor will be placed in pinned memory on the CPU. Pinned memory allows the Tensor to be moved back onto the GPU more quickly, but it is a limited resource. | True |
| use_streams | bool, optional, defaults to True | Whether to use streams as a performance optimization, so that communication overlaps with computation. Requires a torch build after torch-2.5.0. | True |
| min_offload_size | int, optional, defaults to 1024 | Minimum number of bytes a Tensor must be in order to qualify for offloading. If the tensor is too small, we do not want to waste bandwidth and resources moving it to CPU and back. | 1024 |
| max_fwd_stash_size | int, optional, defaults to 5 | Maximum size of the forward stash, that is, the maximum number of consecutive activations to keep alive during the forward pass. This number must be at least 1. Keeping more activations alive potentially allows more overlap between the communication and compute streams, at the cost of increased memory usage. Keeping fewer activations alive conserves memory but may cause poor overlap between the streams, increasing runtime. | 5 |
| warn_if_no_head | bool, optional, defaults to True | Whether to warn if no output head is detected. If set to False, no warning will be raised if no output head is detected. | True |
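To make `min_offload_size` and `max_fwd_stash_size` more concrete, here is a small pure-Python sketch of the two ideas: a byte-size filter and a bounded stash. Everything in it is an illustrative assumption rather than the library's code. That includes the helper names, the `DTYPE_BYTES` table, the `>=` comparison at the boundary, and the evict-oldest policy.

```python
from collections import OrderedDict
from math import prod

# Assumed byte widths for a few dtypes; real code would ask the tensor itself.
DTYPE_BYTES = {"float32": 4, "bfloat16": 2, "int8": 1}


def tensor_nbytes(shape, dtype):
    """Size in bytes of a dense tensor with the given shape and dtype name."""
    return prod(shape) * DTYPE_BYTES[dtype]


def qualifies_for_offload(shape, dtype, min_offload_size=1024):
    """Assumed rule: offload only tensors of at least min_offload_size bytes."""
    return tensor_nbytes(shape, dtype) >= min_offload_size


class ForwardStash:
    """Toy bounded stash keeping at most max_fwd_stash_size recent entries."""

    def __init__(self, max_fwd_stash_size=5):
        if max_fwd_stash_size < 1:
            raise ValueError("max_fwd_stash_size must be at least 1")
        self.max_size = max_fwd_stash_size
        self.entries = OrderedDict()

    def add(self, tensor_id, payload):
        self.entries[tensor_id] = payload
        # Release the oldest entries once the stash exceeds its bound.
        while len(self.entries) > self.max_size:
            self.entries.popitem(last=False)


print(qualifies_for_offload((16, 16), "float32"))  # 1024 bytes -> True
print(qualifies_for_offload((8, 8), "float32"))    # 256 bytes -> False

stash = ForwardStash(max_fwd_stash_size=2)
for i in range(4):
    stash.add(i, f"activation-{i}")
print(list(stash.entries))  # [2, 3]
```

A larger stash bound keeps more entries resident at once, which mirrors the memory-versus-overlap trade-off described for `max_fwd_stash_size`.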

Returns

| Name | Type | Description |
| --- | --- | --- |
| OffloadActivations | contextlib.ContextDecorator | Activation offloading context manager for the model. |
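Since the documented return type is `contextlib.ContextDecorator`, the returned object should be usable both in a `with` statement and as a function decorator. The sketch below shows both forms with a hypothetical stand-in class, `TraceRegion`, which only records when the region is entered and left. It performs no offloading and is not part of the library.

```python
import contextlib


class TraceRegion(contextlib.ContextDecorator):
    """Hypothetical stand-in that logs entry and exit of the wrapped region."""

    def __init__(self):
        self.log = []

    def __enter__(self):
        self.log.append("enter")
        return self

    def __exit__(self, exc_type, exc, tb):
        self.log.append("exit")
        return False  # propagate any exception


region = TraceRegion()

# Form 1: wrap a block of code.
with region:
    region.log.append("forward A")


# Form 2: wrap every call of a function.
@region
def forward_b():
    region.log.append("forward B")


forward_b()
print(region.log)
```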