monkeypatch.models.mamba_utils
Shared utilities for Mamba2 SSM sample-packing and context-parallelism patches.
Used by: nemotron_h, falcon_h1, granite_moe_hybrid
Functions
| Name | Description |
|---|---|
| ensure_mamba_kernels_loaded | Eagerly resolve mamba-ssm and causal-conv1d globals on target_module. |
| get_seq_idx | Convert position_ids [B, T] → seq_idx [B, T] int32 for mamba-ssm kernels. |
| is_cp_active | Return True if context parallelism (ring attention) is active on this rank. |
| mamba2_cp_correction | Apply CP correction to SSM output using the state received from the previous rank (rank - 1). |
| ring_shift_ssm_state | P2P ring: send h_final to rank+1, receive from rank-1 within CP group. |
| wrap_mamba_scan_for_cp | Wrap mamba_chunk_scan_combined in target_module to apply CP correction. |
ensure_mamba_kernels_loaded
monkeypatch.models.mamba_utils.ensure_mamba_kernels_loaded(target_module)
Eagerly resolve mamba-ssm and causal-conv1d globals on target_module.
Transformers >= 5.5 lazily loads these inside Mixer.__init__ via
lazy_load_kernel. Our monkeypatches run before model instantiation,
so the module globals are still None. This helper triggers the kernel
resolution early so the patched cuda_kernels_forward (and
wrap_mamba_scan_for_cp) can reference them.
get_seq_idx
monkeypatch.models.mamba_utils.get_seq_idx(position_ids)
Convert position_ids [B, T] → seq_idx [B, T] int32 for mamba-ssm kernels.
Example: position_ids [[0,1,2,3,0,1,2]] → seq_idx [[0,0,0,0,1,1,1]]
Under context parallelism a rank may receive a chunk that begins mid-sample (its first position id is not 0). In that case the cumulative count of sample starts begins at 0 rather than 1, so subtracting a constant 1 would yield -1, which is an invalid value for the Mamba kernels. Subtracting the first element of the cumsum instead normalises every chunk to start at 0 while still incrementing at every intra-chunk sample boundary.
Example (CP rank 1, chunk starts mid-sample): position_ids [[3,4,5,0,1,2]] → seq_idx [[0,0,0,1,1,1]]
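The two examples above can be reproduced with a small dependency-free sketch. It works on plain Python lists rather than tensors, and the name `seq_idx_from_position_ids` is hypothetical; it illustrates the cumsum-and-normalise idea described here and is not the library's implementation.

```python
def seq_idx_from_position_ids(position_ids):
    """Map rows of position ids to rows of per-chunk sequence indices."""
    result = []
    for row in position_ids:
        # Cumulative count of sample starts (a position id equal to 0).
        starts = 0
        cumsum = []
        for pos in row:
            if pos == 0:
                starts += 1
            cumsum.append(starts)
        # Subtract the first cumsum element (not a constant 1) so that a
        # chunk beginning mid-sample still starts at index 0.
        first = cumsum[0] if cumsum else 0
        result.append([value - first for value in cumsum])
    return result


packed = seq_idx_from_position_ids([[0, 1, 2, 3, 0, 1, 2]])
mid_sample = seq_idx_from_position_ids([[3, 4, 5, 0, 1, 2]])
```

In the second call the raw cumulative count is `[0, 0, 0, 1, 1, 1]`, so subtracting a constant 1 would have produced -1 for the first three tokens.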
is_cp_active
monkeypatch.models.mamba_utils.is_cp_active()
Return True if context parallelism (ring attention) is active on this rank.
Zero-cost when CP is not configured: the import guard ensures we only touch the distributed group if ring_flash_attn is installed.
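The import-guard pattern can be sketched as follows. This is only an illustration of the idea: how the real function inspects the ring attention process group is not shown in this page, so the `get_cp_world_size` callback and the function name `is_cp_active_sketch` are hypothetical stand-ins.

```python
import importlib.util


def is_cp_active_sketch(get_cp_world_size, module_name="ring_flash_attn"):
    """Return True only if the optional package exists and the CP group has >1 rank."""
    # Import guard: when the optional package is not installed, CP cannot be
    # configured, so return before touching any distributed state.
    if importlib.util.find_spec(module_name) is None:
        return False
    # Hypothetical callback standing in for a query of the CP process group.
    return get_cp_world_size() > 1


def _must_not_be_called():
    raise AssertionError("distributed state was touched without the package")


without_package = is_cp_active_sketch(_must_not_be_called, "no_such_package_xyz")
```

Because the guard returns early, the callback is never invoked when the package is missing, which is what keeps the check cheap in non-CP runs.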
mamba2_cp_correction
monkeypatch.models.mamba_utils.mamba2_cp_correction(
out,
h_final,
C,
cum_A,
h_prev,
num_heads,
head_dim,
seq_idx=None,
)
Apply CP correction to SSM output using the state received from the previous rank (rank - 1).
SSM output is linear in the initial hidden state, so the contribution of h_prev can be added analytically without a second forward pass.
For each timestep t in the local chunk:
propagated_state_t = cumA_t * h_prev    [B, H, d, n]
Δy_t = sum_over_n( C_t * propagated_state_t )    [B, H, d]
Here cumA_t denotes exp(cum_A_t), the decay factor accumulated from the start of the local chunk up to step t.
The corrected final state for this rank is:
h_final_corrected = h_final + cumA_T * h_prev
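The linearity argument can be checked numerically on a toy scalar SSM. The sketch below is an assumption-laden illustration, not the library code: it uses one head with a one-dimensional state, plain Python floats, made-up inputs, and a hypothetical `scan` helper. It splits a sequence across two simulated ranks and applies the two formulas above to the second half.

```python
import math
from itertools import accumulate


def scan(x, log_a, b, c, h0=0.0):
    """Scalar SSM: h_t = exp(log_a_t) * h_{t-1} + b_t * x_t,  y_t = c_t * h_t."""
    h, ys = h0, []
    for x_t, la_t, b_t, c_t in zip(x, log_a, b, c):
        h = math.exp(la_t) * h + b_t * x_t
        ys.append(c_t * h)
    return ys, h


# Toy inputs, invented for illustration only.
x = [1.0, -2.0, 0.5, 3.0, 1.5, -1.0]
log_a = [-0.1, -0.3, -0.2, -0.05, -0.4, -0.25]
b = [0.5, 1.0, -0.5, 0.25, 1.0, 0.75]
c = [1.0, 0.5, 2.0, -1.0, 0.5, 1.5]

# Reference: a single device scans the whole sequence.
full_y, full_h = scan(x, log_a, b, c)

# Simulated rank 0 and rank 1 each scan their half from a zero state.
k = 3
y0, h0_final = scan(x[:k], log_a[:k], b[:k], c[:k])
y1, h1_final = scan(x[k:], log_a[k:], b[k:], c[k:])

# Rank 1 receives h_prev from rank 0 and corrects its results analytically.
h_prev = h0_final
cum_A = list(accumulate(log_a[k:]))  # log-space cumulative sums over the chunk
y1_corrected = [y + c_t * math.exp(cA) * h_prev
                for y, c_t, cA in zip(y1, c[k:], cum_A)]
h1_corrected = h1_final + math.exp(cum_A[-1]) * h_prev
```

Concatenating `y0` with `y1_corrected` matches `full_y` up to floating-point rounding, and `h1_corrected` matches `full_h`, without re-running the scan on rank 1.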
Sample packing correctness (seq_idx): When sample packing is active, a CP rank may hold multiple packed sequences. Only the first sequence (seq_idx == 0) is a continuation of the previous rank’s chunk — subsequent sequences are brand-new and should receive zero correction from h_prev.
Passing seq_idx masks delta_y to zero for all tokens where
seq_idx > 0, preventing h_prev state from leaking into unrelated
packed sequences.
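The masking rule can be illustrated with the same kind of toy scalar SSM. This sketch assumes a scan whose state resets at every change of seq_idx, uses made-up inputs, and introduces the hypothetical helpers `packed_scan` and `masked_delta_y`; it covers only the output correction Δy, not the final-state update.

```python
import math
from itertools import accumulate


def packed_scan(x, log_a, b, c, seq_idx):
    """Scalar SSM scan from a zero state that resets whenever seq_idx changes."""
    h, ys = 0.0, []
    for t in range(len(x)):
        if t > 0 and seq_idx[t] != seq_idx[t - 1]:
            h = 0.0  # a new packed sample starts with a fresh state
        h = math.exp(log_a[t]) * h + b[t] * x[t]
        ys.append(c[t] * h)
    return ys, h


def masked_delta_y(c, cum_A, h_prev, seq_idx):
    """Correction term, zeroed for tokens that belong to a fresh sequence."""
    return [c_t * math.exp(cA) * h_prev if s == 0 else 0.0
            for c_t, cA, s in zip(c, cum_A, seq_idx)]


# Eight tokens: sample A (5 tokens) followed by sample B (3 tokens).
x = [1.0, -2.0, 0.5, 3.0, 1.5, -1.0, 2.0, 0.5]
log_a = [-0.1, -0.3, -0.2, -0.05, -0.4, -0.25, -0.15, -0.35]
b = [0.5, 1.0, -0.5, 0.25, 1.0, 0.75, -1.0, 0.5]
c = [1.0, 0.5, 2.0, -1.0, 0.5, 1.5, 1.0, -0.5]
ref_y, _ = packed_scan(x, log_a, b, c, [0, 0, 0, 0, 0, 1, 1, 1])

# Split after token 3: rank 1 holds the tail of A plus all of B.
k = 3
y0, h_prev = packed_scan(x[:k], log_a[:k], b[:k], c[:k], [0, 0, 0])
local_seq_idx = [0, 0, 1, 1, 1]
y1, _ = packed_scan(x[k:], log_a[k:], b[k:], c[k:], local_seq_idx)

cum_A = list(accumulate(log_a[k:]))
delta_y = masked_delta_y(c[k:], cum_A, h_prev, local_seq_idx)
y1_corrected = [y + d for y, d in zip(y1, delta_y)]
```

Only the first two tokens on the simulated rank 1 receive a nonzero correction; the three tokens of sample B are left untouched, so the combined output agrees with the single-device reference.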
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| out | torch.Tensor | SSM scan output from this rank, shape [B, T, D] where D = H*d. | required |
| h_final | torch.Tensor | Final SSM state from this rank, shape [B, H, d, n]. | required |
| C | torch.Tensor | Output projection matrices, shape [B, T, n_groups, n]. | required |
| cum_A | torch.Tensor | Cumulative log-transition factors, shape [B, T, H]. These are the log-space cumulative sums of A, so exp(cum_A_t) gives the transition matrix from step 0 to t. | required |
| h_prev | torch.Tensor | SSM state received from the previous rank (rank - 1); zeros on rank 0. Shape [B, H, d, n]. | required |
| num_heads | int | Number of SSM heads (H). | required |
| head_dim | int | Dimension per head (d). | required |
| seq_idx | torch.Tensor | None | Optional sequence index tensor, shape [B, T] int32. When provided, correction is zeroed for tokens where seq_idx > 0 (i.e. sequences that start fresh on this rank). | None |
Returns
| Name | Type | Description |
|---|---|---|
| corrected_out | torch.Tensor | out + Δy, shape [B, T, D]. |
| corrected_h_final | torch.Tensor | h_final + cumA_T * h_prev, shape [B, H, d, n]. |
ring_shift_ssm_state
monkeypatch.models.mamba_utils.ring_shift_ssm_state(h_final)
P2P ring: send h_final to rank+1, receive from rank-1 within CP group.
Uses synchronous send/recv on the ring attention process group. Rank 0 in the CP group receives zeros (no previous chunk).
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| h_final | torch.Tensor | Final SSM state from this rank’s forward pass. Shape is architecture-dependent, typically [B, H, d, n]. | required |
Returns
| Name | Type | Description |
|---|---|---|
| h_prev | torch.Tensor | SSM state received from the previous rank (rank - 1), same shape/dtype as h_final. Zero tensor on the first rank in the CP group. |
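The data movement described above can be pictured with a single-process simulation. The sketch below performs no distributed communication; `simulate_ring_shift` is a hypothetical helper that takes one state per rank (as a flat list of floats) and returns what each rank would hold as h_prev after the shift.

```python
def simulate_ring_shift(states_per_rank):
    """Return the h_prev each rank would receive: zeros on rank 0, else rank - 1's state."""
    # Rank 0 has no previous chunk, so it receives zeros of matching shape.
    zeros = [0.0] * len(states_per_rank[0])
    # Every other rank r receives a copy of the state sent by rank r - 1.
    return [zeros] + [list(state) for state in states_per_rank[:-1]]


received = simulate_ring_shift([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
```

The state of the last rank is not delivered to anyone in this simulation, matching the description that the first rank in the CP group starts from zeros.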
wrap_mamba_scan_for_cp
monkeypatch.models.mamba_utils.wrap_mamba_scan_for_cp(target_module)
Wrap mamba_chunk_scan_combined in target_module to apply CP correction.
After the scan, if CP is active the wrapper:
1. Sends the final SSM state to the next rank via ring_shift_ssm_state.
2. Computes cumA from the scan’s A / dt / dt_bias / dt_softplus args.
3. Calls mamba2_cp_correction to add the contribution of h_prev.
This is installed per-module so it only affects the architecture whose
modeling file imports mamba_chunk_scan_combined.
The approach follows Tri Dao’s Mamba-2 systems blog: each GPU computes its local output and final states, states are passed via P2P, then outputs are corrected — no ring attention needed for SSM layers.
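Step 2 of the wrapper (deriving cumA from A, dt, dt_bias and dt_softplus) can be sketched for a single head as follows. This assumes the usual Mamba-2 convention that the per-step log decay is dt * A after the bias and optional softplus have been applied to dt; any clamping of dt is not modelled, and `cum_log_a` is a hypothetical name rather than a function of this module.

```python
import math
from itertools import accumulate


def softplus(value):
    """softplus(v) = log(1 + exp(v))."""
    return math.log1p(math.exp(value))


def cum_log_a(A, dt, dt_bias=0.0, dt_softplus=False):
    """Log-space cumulative decay for one head over the local chunk."""
    steps = [d + dt_bias for d in dt]
    if dt_softplus:
        steps = [softplus(s) for s in steps]
    # Per-step log decay is dt_t * A (A is negative); accumulate over time.
    return list(accumulate(s * A for s in steps))


cum_A = cum_log_a(-1.0, [0.5, 0.25, 0.25])
decay_to_end = math.exp(cum_A[-1])  # factor applied to h_prev at the last step
```

Exponentiating each entry gives the factor by which h_prev has decayed at that timestep, which is the quantity the correction formulas multiply with h_prev.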