monkeypatch.models.mamba_utils

Shared utilities for Mamba2 SSM sample-packing and context-parallelism patches.

Used by: nemotron_h, falcon_h1, granite_moe_hybrid

Functions

| Name | Description |
|---|---|
| ensure_mamba_kernels_loaded | Eagerly resolve mamba-ssm and causal-conv1d globals on target_module. |
| get_seq_idx | Convert position_ids [B, T] → seq_idx [B, T] int32 for mamba-ssm kernels. |
| is_cp_active | Return True if context parallelism (ring attention) is active on this rank. |
| mamba2_cp_correction | Apply the CP correction to the SSM output using the state received from rank - 1. |
| ring_shift_ssm_state | P2P ring: send h_final to rank + 1 and receive from rank - 1 within the CP group. |
| wrap_mamba_scan_for_cp | Wrap mamba_chunk_scan_combined in target_module to apply the CP correction. |

ensure_mamba_kernels_loaded

monkeypatch.models.mamba_utils.ensure_mamba_kernels_loaded(target_module)

Eagerly resolve mamba-ssm and causal-conv1d globals on target_module.

Transformers >= 5.5 lazily loads these inside Mixer.__init__ via lazy_load_kernel. Our monkeypatches run before model instantiation, so the module globals are still None. This helper triggers the kernel resolution early so the patched cuda_kernels_forward (and wrap_mamba_scan_for_cp) can reference them.
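A minimal sketch of eager global resolution, assuming the modeling module exposes each kernel as a module-level global that stays None until something fills it in. The helper name ensure_globals_loaded_sketch and the name-to-import mapping are hypothetical, and this sketch imports the packages directly with importlib rather than going through transformers' lazy_load_kernel.

```python
import importlib
import types
from typing import Dict, Tuple

# Hypothetical mapping: module-global name -> (import path, attribute).
# Paths are the usual mamba-ssm / causal-conv1d locations, shown for illustration.
KERNEL_SPEC: Dict[str, Tuple[str, str]] = {
    "mamba_chunk_scan_combined": ("mamba_ssm.ops.triton.ssd_combined", "mamba_chunk_scan_combined"),
    "selective_state_update": ("mamba_ssm.ops.triton.selective_state_update", "selective_state_update"),
    "causal_conv1d_fn": ("causal_conv1d", "causal_conv1d_fn"),
    "causal_conv1d_update": ("causal_conv1d", "causal_conv1d_update"),
}


def ensure_globals_loaded_sketch(target_module: types.ModuleType,
                                 spec: Dict[str, Tuple[str, str]] = KERNEL_SPEC) -> None:
    """Fill in module globals that are still None (hypothetical helper)."""
    for name, (module_path, attr) in spec.items():
        if getattr(target_module, name, None) is not None:
            continue  # already resolved: leave it untouched
        try:
            value = getattr(importlib.import_module(module_path), attr)
        except (ImportError, AttributeError):
            continue  # package or symbol missing: leave the global unresolved
        setattr(target_module, name, value)
```

Because the function only writes globals that are still None, calling it again after transformers has resolved the kernels itself is harmless.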

get_seq_idx

monkeypatch.models.mamba_utils.get_seq_idx(position_ids)

Convert position_ids [B, T] → seq_idx [B, T] int32 for mamba-ssm kernels.

Example: position_ids [[0,1,2,3,0,1,2]] → seq_idx [[0,0,0,0,1,1,1]]

The plain conversion marks each sample start (position_ids == 0), takes the cumulative sum along T, and subtracts 1. Under context parallelism, however, a rank may receive a chunk that begins mid-sample (position_ids[:, 0] != 0); the cumulative sum then starts at 0, and subtracting 1 would yield -1, an invalid value for the Mamba kernels. Subtracting the first element of the cumulative sum instead normalises every chunk to start at 0 while still incrementing at every intra-chunk sample boundary.

Example (CP rank 1, chunk starts mid-sample): position_ids [[3,4,5,0,1,2]] → seq_idx [[0,0,0,1,1,1]]
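A minimal PyTorch sketch of the conversion described above, assuming position_ids restarts at 0 at every packed-sample boundary. get_seq_idx_sketch is a hypothetical name used for illustration, not the module's own function.

```python
import torch


def get_seq_idx_sketch(position_ids: torch.Tensor) -> torch.Tensor:
    """position_ids [B, T] -> seq_idx [B, T] int32 (hypothetical re-implementation)."""
    # 1 at every token that starts a new sample, 0 elsewhere.
    starts = (position_ids == 0).to(torch.int32)
    # Running count of sample starts along the sequence dimension.
    csum = torch.cumsum(starts, dim=1)
    # Subtract the first column instead of a constant 1: a chunk that begins
    # mid-sample (csum starts at 0) and one that begins at a sample start
    # (csum starts at 1) both get seq_idx 0 at their first token.
    return (csum - csum[:, :1]).to(torch.int32)


print(get_seq_idx_sketch(torch.tensor([[0, 1, 2, 3, 0, 1, 2]])).tolist())  # [[0, 0, 0, 0, 1, 1, 1]]
print(get_seq_idx_sketch(torch.tensor([[3, 4, 5, 0, 1, 2]])).tolist())     # [[0, 0, 0, 1, 1, 1]]
```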

is_cp_active

monkeypatch.models.mamba_utils.is_cp_active()

Return True if context parallelism (ring attention) is active on this rank.

Zero-cost when CP is not configured: the import guard ensures we only touch the distributed group if ring_flash_attn is installed.
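A minimal sketch of a guarded CP check, under two assumptions: importlib.util.find_spec stands in for the import guard (it locates ring_flash_attn without importing it), and get_cp_group is a hypothetical accessor for the ring-attention process group, since the text does not say how the module looks that group up.

```python
import importlib.util

import torch.distributed as dist


def is_cp_active_sketch(get_cp_group=None) -> bool:
    """Hypothetical CP check; get_cp_group is an assumed zero-arg callable
    returning the ring-attention process group, or None if none is set."""
    # Cheap guard first: without ring_flash_attn there is no CP to detect.
    if importlib.util.find_spec("ring_flash_attn") is None:
        return False
    # No initialised process group means a single-process run.
    if not (dist.is_available() and dist.is_initialized()):
        return False
    group = get_cp_group() if get_cp_group is not None else None
    # CP only matters when the group spans more than one rank.
    return group is not None and dist.get_world_size(group) > 1
```

Both early returns avoid any collective call, which is what keeps the non-CP path cheap.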

mamba2_cp_correction

monkeypatch.models.mamba_utils.mamba2_cp_correction(
    out,
    h_final,
    C,
    cum_A,
    h_prev,
    num_heads,
    head_dim,
    seq_idx=None,
)

Apply the CP correction to the SSM output using the state received from the previous rank (rank - 1).

SSM output is linear in the initial hidden state, so the contribution of h_prev can be added analytically without a second forward pass.

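For each timestep t in the local chunk, with exp(cum_A_t) broadcast over the d and n axes and each head using the C of its group:

$$\text{propagated\_state}_t = \exp(\text{cum\_A}_t)\, h_{\text{prev}} \qquad [B, H, d, n]$$

$$\Delta y_t = \sum_{k=1}^{n} C_{t,k}\, \text{propagated\_state}_{t,k} \qquad [B, H, d]$$

The corrected final state for this rank, where T is the last timestep of the local chunk, is

$$h_{\text{final}}^{\text{corrected}} = h_{\text{final}} + \exp(\text{cum\_A}_T)\, h_{\text{prev}}$$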

Sample packing correctness (seq_idx): when sample packing is active, a CP rank may hold multiple packed sequences. Only the first sequence (seq_idx == 0) is a continuation of the previous rank’s chunk; subsequent sequences are brand-new and should receive zero correction from h_prev.

Passing seq_idx masks delta_y to zero for all tokens where seq_idx > 0, preventing h_prev state from leaking into unrelated packed sequences.

Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| out | torch.Tensor | SSM scan output from this rank, shape [B, T, D] where D = H*d. | required |
| h_final | torch.Tensor | Final SSM state from this rank, shape [B, H, d, n]. | required |
| C | torch.Tensor | Output projection matrices, shape [B, T, n_groups, n]. | required |
| cum_A | torch.Tensor | Cumulative log-transition factors, shape [B, T, H]: the log-space cumulative sums of the per-step dt * A, so exp(cum_A_t) is the per-head decay applied to the initial state from step 0 through step t. | required |
| h_prev | torch.Tensor | SSM state received from rank - 1 (zeros on rank 0), shape [B, H, d, n]. | required |
| num_heads | int | Number of SSM heads (H). | required |
| head_dim | int | Dimension per head (d). | required |
| seq_idx | torch.Tensor or None | Optional sequence-index tensor, shape [B, T], int32. When provided, the correction is zeroed for tokens where seq_idx > 0 (sequences that start fresh on this rank). | None |

Returns

| Name | Type | Description |
|---|---|---|
| corrected_out | torch.Tensor | out + Δy, shape [B, T, D]. |
| corrected_h_final | torch.Tensor | h_final + exp(cum_A_T) * h_prev, where T is the last local timestep, shape [B, H, d, n]. |
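The linearity argument can be checked numerically. Below is a minimal PyTorch sketch, assuming cum_A is the inclusive cumulative sum of dt * A and that heads are split evenly and contiguously across the C groups. cp_correction_sketch and naive_scan are hypothetical names; naive_scan is a plain sequential reference scan written only for this check (no D skip term, no z gating), not the mamba-ssm kernel.

```python
import torch


def cp_correction_sketch(out, h_final, C, cum_A, h_prev, num_heads, head_dim, seq_idx=None):
    """Hypothetical re-implementation of the correction described above.

    out: [B, T, H*d]   h_final, h_prev: [B, H, d, n]
    C:   [B, T, G, n]  cum_A: [B, T, H] (inclusive cumsum of dt * A)
    """
    batch, seqlen, n_groups, _ = C.shape
    decay = torch.exp(cum_A)  # [B, T, H]: decay of the initial state through step t
    # Assumption: heads are split evenly and contiguously across the C groups.
    C_heads = C.repeat_interleave(num_heads // n_groups, dim=2)  # [B, T, H, n]
    # delta_y[b, t, h, p] = decay[b, t, h] * sum_k C[b, t, h, k] * h_prev[b, h, p, k]
    delta_y = torch.einsum("bthk,bhpk->bthp", C_heads, h_prev) * decay[..., None]
    if seq_idx is not None:
        # Only tokens of the first packed sequence continue the previous rank's chunk.
        keep = (seq_idx == 0).to(delta_y.dtype)[..., None, None]  # [B, T, 1, 1]
        delta_y = delta_y * keep
    corrected_out = out + delta_y.reshape(batch, seqlen, num_heads * head_dim)
    corrected_h_final = h_final + decay[:, -1, :, None, None] * h_prev
    return corrected_out, corrected_h_final


def naive_scan(x, dt, A, B, C, h0):
    """Reference sequential scan h_t = exp(dt_t*A) h_{t-1} + dt_t x_t B_t, y_t = C_t h_t."""
    batch, seqlen, num_heads, head_dim = x.shape
    rep = num_heads // C.shape[2]
    B_heads = B.repeat_interleave(rep, dim=2)
    C_heads = C.repeat_interleave(rep, dim=2)
    h = h0.clone()
    ys = []
    for t in range(seqlen):
        dA = torch.exp(dt[:, t] * A)  # [B, H]
        h = dA[:, :, None, None] * h + (dt[:, t, :, None, None]
                                        * x[:, t, :, :, None] * B_heads[:, t, :, None, :])
        ys.append(torch.einsum("bhpk,bhk->bhp", h, C_heads[:, t]))
    return torch.stack(ys, dim=1).reshape(batch, seqlen, num_heads * head_dim), h


torch.manual_seed(0)
Bsz, T, H, d, G, n = 2, 6, 4, 3, 2, 5
x = torch.randn(Bsz, T, H, d, dtype=torch.float64)
dt = torch.rand(Bsz, T, H, dtype=torch.float64)
A = -torch.rand(H, dtype=torch.float64)
Bm = torch.randn(Bsz, T, G, n, dtype=torch.float64)
Cm = torch.randn(Bsz, T, G, n, dtype=torch.float64)
h_prev = torch.randn(Bsz, H, d, n, dtype=torch.float64)

y_local, h_local = naive_scan(x, dt, A, Bm, Cm, torch.zeros_like(h_prev))  # what this rank computes
y_true, h_true = naive_scan(x, dt, A, Bm, Cm, h_prev)                       # with the true initial state
cum_A = torch.cumsum(dt * A, dim=1)
y_fix, h_fix = cp_correction_sketch(y_local, h_local, Cm, cum_A, h_prev, H, d)
print(torch.allclose(y_fix, y_true), torch.allclose(h_fix, h_true))  # True True
```

The check works because the only difference between the two scans is the initial state, and its effect on every output is a fixed linear map.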

ring_shift_ssm_state

monkeypatch.models.mamba_utils.ring_shift_ssm_state(h_final)

P2P ring: send h_final to rank + 1 and receive from rank - 1 within the CP group.

Uses synchronous send/recv on the ring attention process group. Rank 0 in the CP group receives zeros (no previous chunk).

Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| h_final | torch.Tensor | Final SSM state from this rank’s forward pass. Shape is architecture-dependent, typically [B, H, d, n]. | required |

Returns

| Name | Type | Description |
|---|---|---|
| h_prev | torch.Tensor | SSM state received from rank - 1, with the same shape and dtype as h_final. Zero tensor on the first rank in the CP group. |
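A minimal sketch of the one-hop state hand-off, assuming group is the CP process group (None meaning the default group). ring_shift_sketch is hypothetical, and it departs from the description above in two labeled ways: it uses torch.distributed.batch_isend_irecv instead of blocking send/recv, and it returns zeros when no process group is initialised. The last rank does not send, since rank 0 would discard what it receives. dist.get_global_rank assumes PyTorch 2.x.

```python
import torch
import torch.distributed as dist


def ring_shift_sketch(h_final: torch.Tensor, group=None) -> torch.Tensor:
    """Hypothetical sketch: receive h_final from rank - 1, send ours to rank + 1."""
    h_prev = torch.zeros_like(h_final)  # what the first rank in the group keeps
    # Single-process fallback (an assumption of this sketch): behave like rank 0.
    if not (dist.is_available() and dist.is_initialized()):
        return h_prev
    rank = dist.get_rank(group)
    world = dist.get_world_size(group)

    def peer(group_rank: int) -> int:
        # P2POp takes a global rank, so translate from the group-local rank.
        return group_rank if group is None else dist.get_global_rank(group, group_rank)

    ops = []
    if rank + 1 < world:
        ops.append(dist.P2POp(dist.isend, h_final.contiguous(), peer(rank + 1), group))
    if rank > 0:
        ops.append(dist.P2POp(dist.irecv, h_prev, peer(rank - 1), group))
    if ops:
        for req in dist.batch_isend_irecv(ops):
            req.wait()
    return h_prev
```

Batching the send and receive lets each rank post both operations at once, so no rank's progress depends on the order in which its neighbours issue their calls.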

wrap_mamba_scan_for_cp

monkeypatch.models.mamba_utils.wrap_mamba_scan_for_cp(target_module)

Wrap mamba_chunk_scan_combined in target_module to apply CP correction.

After the scan, if CP is active, the wrapper:

1. Sends the final SSM state to the next rank via ring_shift_ssm_state.
2. Computes cumA from the scan’s A / dt / dt_bias / dt_softplus args.
3. Calls mamba2_cp_correction to add the contribution of h_prev.
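Step 2 can be sketched as follows, assuming the scan adds dt_bias before the optional softplus (the order used by mamba-ssm's chunked scan) and ignoring any dt_limit clamping. compute_cum_A_sketch is a hypothetical name; its output is in the log-space form that the cum_A parameter of mamba2_cp_correction expects.

```python
import torch
import torch.nn.functional as F


def compute_cum_A_sketch(dt, A, dt_bias=None, dt_softplus=False):
    """Hypothetical helper: log-space cumulative decay [B, T, H] from scan args.

    dt: [B, T, H], A: [H] (negative), dt_bias: [H] or None.
    dt_limit clamping is not handled here.
    """
    if dt_bias is not None:
        dt = dt + dt_bias
    if dt_softplus:
        dt = F.softplus(dt)
    # log(exp(dt_t * A)) = dt_t * A, and summing over time composes the decays.
    return torch.cumsum(dt * A, dim=1)


print(compute_cum_A_sketch(torch.tensor([[[0.5], [1.0]]]), torch.tensor([-2.0])).tolist())
# [[[-1.0], [-3.0]]]
```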

The wrapper is installed per module, so it only affects the architecture whose modeling file imports mamba_chunk_scan_combined.
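A minimal sketch of per-module patching, assuming the modeling file calls the scan through its own module-level name. Functions in that file resolve the global at call time, so replacing it with setattr affects only that module; other modules keep their own reference. wrap_module_fn_sketch, post_hook and is_active are hypothetical, and post_hook stands in for the ring shift and correction steps.

```python
import functools
import types
from typing import Any, Callable


def wrap_module_fn_sketch(target_module: types.ModuleType, name: str,
                          post_hook: Callable[..., Any],
                          is_active: Callable[[], bool]) -> None:
    """Hypothetical sketch: replace target_module.<name> with a wrapper."""
    original = getattr(target_module, name)
    if getattr(original, "_cp_wrapped", False):
        return  # already patched: avoid stacking wrappers

    @functools.wraps(original)
    def wrapper(*args, **kwargs):
        result = original(*args, **kwargs)
        if not is_active():
            return result  # CP off: same result as the unwrapped function
        # post_hook sees the result plus the original call arguments,
        # which is where the ring shift and correction would run.
        return post_hook(result, *args, **kwargs)

    wrapper._cp_wrapped = True
    setattr(target_module, name, wrapper)


fake = types.ModuleType("fake_modeling")
fake.scan = lambda x: x + 1
wrap_module_fn_sketch(fake, "scan", lambda result, x: result * 10, lambda: True)
print(fake.scan(1))  # 20
```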

The approach follows Tri Dao’s Mamba-2 systems blog: each GPU computes its local output and final states, states are passed via P2P, and then outputs are corrected, so no ring attention is needed for SSM layers.