monkeypatch.gradient_checkpointing.offload_disk

DISCO - DIsk-based Storage and Checkpointing with Optimized prefetching

Classes

Name                Description
Disco               Disco: DIsk-based Storage and Checkpointing with Optimized prefetching
DiskOffloadManager  Manages offloaded tensors and handles prefetching in a separate thread.

Disco

monkeypatch.gradient_checkpointing.offload_disk.Disco()

Disco: DIsk-based Storage and Checkpointing with Optimized prefetching. An advanced disk-based gradient checkpointer with prefetching.

Methods

Name          Description
backward      Backward pass that loads activations from disk with prefetching
forward       Forward pass that offloads activations to disk asynchronously
get_instance  Get or create the offload manager

backward
monkeypatch.gradient_checkpointing.offload_disk.Disco.backward(
    ctx,
    *grad_outputs,
)

Backward pass that loads activations from disk with prefetching
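Backpropagation visits checkpointed layers in the reverse of the order the forward pass visited them, so when backward prefetches "the next" activation it has to look at the most recently saved entries first. The sketch below illustrates only that ordering, using a plain list as a stack; BackwardOrderQueue and its methods are hypothetical names, not part of this module, and how Disco actually orders its prefetches is an assumption here.

```python
from typing import List


class BackwardOrderQueue:
    """Hypothetical helper: records activation paths in forward order and
    hands them out (plus prefetch candidates) in backward order."""

    def __init__(self, prefetch_size: int = 1) -> None:
        self.prefetch_size = prefetch_size
        self._stack: List[str] = []  # forward order; top of stack = last layer

    def record(self, path: str) -> None:
        # Called once per checkpointed layer during the forward pass.
        self._stack.append(path)

    def next_for_backward(self) -> str:
        # Backward consumes the most recently saved activation first.
        return self._stack.pop()

    def prefetch_candidates(self) -> List[str]:
        # The activations needed next sit just below the top of the stack,
        # nearest first. Guard n == 0 because [-0:] would return everything.
        if self.prefetch_size <= 0:
            return []
        return list(reversed(self._stack[-self.prefetch_size:]))
```

With four recorded layers and prefetch_size=2, the first backward step consumes the last layer's file and the candidates are the two layers before it, nearest first.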

forward
monkeypatch.gradient_checkpointing.offload_disk.Disco.forward(
    ctx,
    forward_function,
    hidden_states,
    *args,
    prefetch_size=1,
    prefetch_to_gpu=True,
    save_workers=4,
)

Forward pass that offloads activations to disk asynchronously
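The forward/backward pair follows the general pattern of a recomputing checkpoint implemented as a torch.autograd.Function: forward keeps only the layer input (here on disk instead of in memory) and discards intermediates, and backward reloads that input, reruns the layer with autograd enabled, and backpropagates through the recomputed graph. The sketch below is a minimal, synchronous illustration of that pattern, not Disco's implementation: DiskCheckpoint is a hypothetical name, there is no save worker pool or prefetching, and it assumes the extra *args do not need gradients.

```python
import os
import tempfile
import uuid

import torch


class DiskCheckpoint(torch.autograd.Function):
    """Hypothetical synchronous sketch of disk-offloaded checkpointing."""

    @staticmethod
    def forward(ctx, forward_function, hidden_states, *args):
        # Persist only the layer input; the layer runs without building a graph,
        # so its intermediate activations are freed immediately.
        path = os.path.join(tempfile.gettempdir(), f"act_{uuid.uuid4().hex}.pt")
        torch.save(hidden_states.detach().cpu(), path)
        ctx.path = path
        ctx.device = hidden_states.device
        ctx.forward_function = forward_function
        ctx.args = args  # assumed not to require gradients
        with torch.no_grad():
            return forward_function(hidden_states, *args)

    @staticmethod
    def backward(ctx, *grad_outputs):
        # Reload the saved input and recompute the layer with grad enabled.
        hidden_states = torch.load(ctx.path).to(ctx.device)
        os.remove(ctx.path)  # the file is no longer needed
        hidden_states.requires_grad_(True)
        with torch.enable_grad():
            outputs = ctx.forward_function(hidden_states, *ctx.args)
        # Parameter gradients accumulate inside the layer during this call.
        torch.autograd.backward(outputs, grad_outputs)
        # One entry per forward() input: forward_function, hidden_states, *args.
        return (None, hidden_states.grad) + (None,) * len(ctx.args)
```

Usage looks like DiskCheckpoint.apply(layer, x). As with reentrant checkpointing, the input must require grad for the layer's parameters to receive gradients.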

get_instance
monkeypatch.gradient_checkpointing.offload_disk.Disco.get_instance(
    prefetch_size=1,
    prefetch_to_gpu=True,
    save_workers=4,
)

Get or create the offload manager
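"Get or create" suggests a lazily constructed, shared manager. A common thread-safe way to do that is double-checked locking on a class attribute, sketched below. OffloadRegistry and _Manager are hypothetical stand-ins, and whether the real get_instance ignores the arguments of later calls (as this sketch does) is an assumption.

```python
import threading
from typing import Optional


class _Manager:
    """Hypothetical stand-in for DiskOffloadManager."""

    def __init__(self, prefetch_size=1, prefetch_to_gpu=True, save_workers=4):
        self.prefetch_size = prefetch_size
        self.prefetch_to_gpu = prefetch_to_gpu
        self.save_workers = save_workers


class OffloadRegistry:
    _instance: Optional[_Manager] = None
    _lock = threading.Lock()

    @classmethod
    def get_instance(cls, prefetch_size=1, prefetch_to_gpu=True, save_workers=4):
        # Double-checked locking: the unlocked check keeps the common path cheap,
        # the locked check ensures only one thread constructs the manager.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = _Manager(prefetch_size, prefetch_to_gpu, save_workers)
        # Later calls return the existing manager; their arguments are ignored.
        return cls._instance
```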

DiskOffloadManager

monkeypatch.gradient_checkpointing.offload_disk.DiskOffloadManager(
    prefetch_size=3,
    prefetch_to_gpu=True,
    save_workers=4,
)

Manages offloaded tensors and handles prefetching in a separate thread. Includes synchronization to prevent race conditions.

Methods

Name              Description
cleanup           Clean up all temporary files and stop the prefetch thread with proper synchronization
cleanup_tensor    Clean up a specific tensor file after it’s been used
load_tensor       Load tensor from disk or prefetch cache with proper synchronization
save_tensor       Save a tensor to disk asynchronously using thread-safe operations and return its file path
trigger_prefetch  Trigger prefetching of the next N tensors with proper synchronization
wait_for_save     Wait for a tensor to be saved to disk

cleanup
monkeypatch.gradient_checkpointing.offload_disk.DiskOffloadManager.cleanup()

Clean up all temporary files and stop the prefetch thread with proper synchronization.
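Shutting down safely means stopping the background thread before deleting the files it may still be reading. The sketch below shows one ordering that achieves this: set a stop event, post a sentinel to wake the thread, join it, then remove the temporary directory. PrefetchWorker is a hypothetical class; the stop-event-plus-sentinel design is an assumption, not confirmed for DiskOffloadManager.

```python
import queue
import shutil
import tempfile
import threading


class PrefetchWorker:
    """Hypothetical sketch: a temporary directory plus a background prefetch thread."""

    def __init__(self) -> None:
        self.temp_dir = tempfile.mkdtemp(prefix="disco_")
        self.requests = queue.Queue()  # paths to prefetch; None means "stop"
        self._stop = threading.Event()
        self.thread = threading.Thread(target=self._run, daemon=True)
        self.thread.start()

    def _run(self) -> None:
        while not self._stop.is_set():
            try:
                path = self.requests.get(timeout=0.1)
            except queue.Empty:
                continue
            if path is None:  # sentinel posted by cleanup()
                break
            # A real prefetcher would load `path` into a cache here.

    def cleanup(self) -> None:
        # 1) signal and wake the thread, 2) wait for it to exit,
        # 3) only then delete files it might still have been reading.
        self._stop.set()
        self.requests.put(None)
        self.thread.join()
        shutil.rmtree(self.temp_dir, ignore_errors=True)
```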

cleanup_tensor
monkeypatch.gradient_checkpointing.offload_disk.DiskOffloadManager.cleanup_tensor(
    file_path,
)

Clean up a specific tensor file after it’s been used
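Per-tensor cleanup lets disk usage shrink as backward progresses instead of waiting for the final cleanup. A minimal sketch, assuming the manager keeps a lock-protected dict of prefetched entries keyed by file path: drop the cached copy, then delete the file, tolerating a file that is already gone. The free function and its parameters are hypothetical.

```python
import os
import threading
from typing import Any, Dict


def cleanup_tensor(file_path: str, cache: Dict[str, Any], lock: threading.Lock) -> None:
    """Hypothetical helper: forget one offloaded activation once backward has used it."""
    with lock:
        cache.pop(file_path, None)  # drop any prefetched copy
    try:
        os.remove(file_path)
    except FileNotFoundError:
        pass  # already removed, so repeated calls are harmless
```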

load_tensor
monkeypatch.gradient_checkpointing.offload_disk.DiskOffloadManager.load_tensor(
    file_path,
    target_device='cuda',
)

Load tensor from disk or prefetch cache with proper synchronization
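A typical shape for this method is a fast path that takes an already-prefetched entry out of the cache under a lock, and a slow path that reads from disk when the prefetcher has not reached the file yet. The sketch below assumes that design. TensorStore is hypothetical, pickle stands in for tensor serialization so it runs without PyTorch, and the move to target_device (for example .to("cuda")) is omitted.

```python
import pickle
import threading
from typing import Any, Dict


class TensorStore:
    """Hypothetical sketch; pickle stands in for tensor (de)serialization."""

    def __init__(self) -> None:
        self.cache: Dict[str, Any] = {}  # filled by a prefetch thread
        self.lock = threading.Lock()
        self.disk_reads = 0

    def load_tensor(self, file_path: str) -> Any:
        # Fast path: pop the prefetched object under the lock so the
        # prefetch thread cannot modify the cache at the same moment.
        with self.lock:
            if file_path in self.cache:
                return self.cache.pop(file_path)
        # Slow path: the prefetcher has not loaded this file yet.
        self.disk_reads += 1
        with open(file_path, "rb") as f:
            return pickle.load(f)
```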

save_tensor
monkeypatch.gradient_checkpointing.offload_disk.DiskOffloadManager.save_tensor(
    tensor,
)

Save a tensor to disk asynchronously using thread-safe operations and return its file path.
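Returning a path before the write finishes works if the path is chosen up front and the write is handed to a worker pool whose size plays the role of save_workers. The sketch below assumes that design. AsyncSaver is hypothetical and pickle stands in for tensor serialization; with real GPU tensors the data would also need to be copied to CPU (and not mutated) before the background write.

```python
import os
import pickle
import tempfile
import threading
import uuid
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Dict


class AsyncSaver:
    """Hypothetical sketch of asynchronous saving; pickle stands in for tensor I/O."""

    def __init__(self, save_workers: int = 4) -> None:
        self.temp_dir = tempfile.mkdtemp(prefix="disco_")
        self.executor = ThreadPoolExecutor(max_workers=save_workers)
        self.pending: Dict[str, Future] = {}  # path -> in-flight write
        self.lock = threading.Lock()

    @staticmethod
    def _write(obj: Any, path: str) -> None:
        with open(path, "wb") as f:
            pickle.dump(obj, f)

    def save_tensor(self, obj: Any) -> str:
        # Choose a unique path now so the caller gets it back immediately,
        # while the actual write runs on a worker thread.
        path = os.path.join(self.temp_dir, f"{uuid.uuid4().hex}.pkl")
        future = self.executor.submit(self._write, obj, path)
        with self.lock:
            self.pending[path] = future
        return path
```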

trigger_prefetch
monkeypatch.gradient_checkpointing.offload_disk.DiskOffloadManager.trigger_prefetch(
    n=None,
)

Trigger prefetching of the next N tensors with proper synchronization
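One way to prefetch the next N files is to snapshot the upcoming paths under a lock and load them on a background thread, inserting each result into the shared cache under the same lock. In this sketch n falls back to prefetch_size when it is None, which is an assumption about the real method. Prefetcher, its order list, and the injected loader callable are hypothetical.

```python
import threading
from typing import Any, Callable, Dict, List, Optional


class Prefetcher:
    """Hypothetical sketch: load the next N queued files on a background thread."""

    def __init__(self, loader: Callable[[str], Any], prefetch_size: int = 3) -> None:
        self.loader = loader
        self.prefetch_size = prefetch_size
        self.order: List[str] = []  # files in the order they will be needed
        self.cache: Dict[str, Any] = {}
        self.lock = threading.Lock()

    def trigger_prefetch(self, n: Optional[int] = None) -> threading.Thread:
        n = self.prefetch_size if n is None else n
        # Snapshot the targets under the lock; skip files already cached.
        with self.lock:
            targets = [p for p in self.order[:n] if p not in self.cache]

        def work() -> None:
            for path in targets:
                value = self.loader(path)  # I/O happens off the caller's thread
                with self.lock:
                    self.cache[path] = value

        thread = threading.Thread(target=work, daemon=True)
        thread.start()
        return thread
```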

wait_for_save
monkeypatch.gradient_checkpointing.offload_disk.DiskOffloadManager.wait_for_save(
    file_path,
    timeout=None,
)

Wait for a tensor to be saved to disk
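Waiting on the pending write before reading a file is what prevents a load from racing an unfinished save. A minimal sketch, assuming in-flight writes are tracked as concurrent.futures futures keyed by path: wait up to timeout, re-raise any writer error, and forget the future once it has completed. The free function and its boolean return value are hypothetical, not the documented behavior of wait_for_save.

```python
import threading
from concurrent.futures import Future, wait
from typing import Dict, Optional


def wait_for_save(pending: Dict[str, Future], lock: threading.Lock,
                  file_path: str, timeout: Optional[float] = None) -> bool:
    """Hypothetical helper: return True once file_path is fully written,
    or False if the timeout expires first."""
    with lock:
        future = pending.get(file_path)
    if future is None:
        return True  # never submitted here, or already finished and forgotten
    done, _ = wait([future], timeout=timeout)
    if not done:
        return False
    future.result()  # re-raise any exception from the writer thread
    with lock:
        pending.pop(file_path, None)
    return True
```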