Rewriting the Mamba-2 SSD Layer in Triton
Mamba has been central to the efficiency conversation: its SSD layer is a linear recurrence that scales better than attention at long contexts.
The recurrence itself is sequential: each hidden state depends on the previous one. Unrolling the computation reveals a parallel prefix scan. The reference implementation recognizes this at the algorithm level and parallelizes across chunks, but inside each chunk it still processes timesteps sequentially.
So I rewrote the SSD layer in Triton to exploit the within-tile parallelism.
Math
The recurrence is:
$$h_t = A_t\,h_{t-1} + B_t\,x_t, \qquad y_t = C_t^\top h_t$$

where $A_t$ is a scalar decay per head, $B_t$ projects the input into state space, and $C_t$ reads out. This is sequential, $h_t$ depends on $h_{t-1}$, but unrolling it gives

$$h_t = \sum_{s=1}^{t} \Big(\prod_{r=s+1}^{t} A_r\Big)\, B_s\,x_s.$$
The product $G_{t,s} = \prod_{r=s+1}^{t} A_r$ is the cumulative decay from timestep $s$ to $t$; it controls how much the past input at position $s$ still matters by the time we reach $t$. If the decay $A$ is close to 1, the gate stays large and the model has long memory. If $A$ is close to 0, the gate shrinks quickly and recent inputs dominate.
So every $h_t$ is just a weighted sum of all past inputs, where the weights are these gates. This is what makes the recurrence a parallel prefix scan: the per-step updates compose associatively, meaning partial results can be combined in any grouping (which is exactly what `tl.associative_scan` exploits).
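To make both halves of that concrete, here is the recurrence as a naive sequential loop (a hedged sketch: shapes are folded to a single (batch, head) slice, and `x` is pre-projected into state space so `Bx[t]` is $B_t x_t$; the `naive` baseline in the benchmarks below is presumably a loop in this spirit):

```python
import torch

def naive_ssd(A, Bx, C):
    # A: (T,) scalar decay per timestep; Bx, C: (T, N) with N = d_state
    T, N = Bx.shape
    h = Bx.new_zeros(N)
    y = Bx.new_empty(T)
    for t in range(T):
        h = A[t] * h + Bx[t]   # h_t = A_t h_{t-1} + B_t x_t
        y[t] = C[t] @ h        # y_t = C_t^T h_t
    return y
```

Each step is an affine update $h \mapsto a\,h + b$, and composing two of them, $(a_1, b_1)$ then $(a_2, b_2)$, yields $(a_2 a_1,\ a_2 b_1 + b_2)$, which is again affine. That composition is associative, so it can serve as the scan's combine function. A minimal sketch (the actual kernel's combine may differ in layout and dtype details):

```python
import triton
import triton.language as tl

@triton.jit
def affine_combine(a1, b1, a2, b2):
    # Compose h -> a1*h + b1 followed by h -> a2*h + b2.
    return a2 * a1, a2 * b1 + b2

# Inside a kernel, something like:
#   decays, states = tl.associative_scan((a_tile, b_tile), axis=0,
#                                        combine_fn=affine_combine)
# yields all h values in the tile in parallel.
```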
The Kernel
Both variants launch one Triton program per (batch, head) pair, 128 programs in parallel for our benchmark config. Each program walks through the full sequence in tiles of BLOCK_T timesteps, carrying the hidden state h across tile boundaries.
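A skeleton of that structure (a hedged sketch, not the actual kernel: pointer layout and names are mine, seqlen is assumed divisible by BLOCK_T, and the inner loop shown is the sequential variant's; the parallel variant swaps it for a `tl.associative_scan` over the tile):

```python
import triton
import triton.language as tl

@triton.jit
def ssd_fwd_skeleton(A_ptr, BX_ptr, C_ptr, Y_ptr, seqlen,
                     D_STATE: tl.constexpr, BLOCK_T: tl.constexpr):
    pid = tl.program_id(0)            # one program per (batch, head) pair
    base = pid * seqlen               # start of this program's timestep slice
    n = tl.arange(0, D_STATE)
    h = tl.zeros((D_STATE,), dtype=tl.float32)   # carried across tiles
    for start in range(0, seqlen, BLOCK_T):
        # h here is a tile-boundary state: what the sequential variant
        # checkpoints for backward (checkpoint stores omitted in this sketch)
        for i in range(BLOCK_T):      # sequential variant's inner loop
            t = start + i
            a = tl.load(A_ptr + base + t)
            bx = tl.load(BX_ptr + (base + t) * D_STATE + n)
            c = tl.load(C_ptr + (base + t) * D_STATE + n)
            h = a * h + bx                                     # h_t = A_t h_{t-1} + B_t x_t
            tl.store(Y_ptr + base + t, tl.sum(c * h, axis=0))  # y_t = C_t^T h_t
```

Launched with `grid = (batch * n_heads,)`, which gives the 128 programs above.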
The sequential variant only saves one h vector per tile boundary. During backward it replays the recurrence from that saved state to reconstruct the intermediate h values. Memory cost: one d_state-sized vector per tile boundary.
The parallel variant uses `tl.associative_scan` to compute all h values within a tile simultaneously. There's no sequential replay, so every $h_t$ has to be written to high-bandwidth memory (HBM) during the forward pass and read back during the backward. Memory cost: seqlen × d_state floats per (batch, head).
Since they differ massively in backward structure, each variant has its own backward kernel. The sequential backward recomputes h within each tile from the saved boundary state, then differentiates through the recurrence. The parallel backward loads the full saved h tensor and differentiates through the associative scan, which is basically a reverse scan over the tile.
TL;DR: The sequential backward is cheaper on memory traffic; the parallel backward pays an HBM round-trip for the saved activations.
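To put numbers on that tradeoff with the benchmark config below (batch=16, n_heads=8, d_state=64) at seqlen=32768, and assuming fp32 activations (halve everything for bf16): the parallel variant saves $16 \cdot 8 \cdot 32768 \cdot 64 \cdot 4\,\mathrm{B} \approx 1.07\,\mathrm{GB}$ per forward pass, while the sequential variant at BLOCK_T=64 saves only tile-boundary states, $16 \cdot 8 \cdot (32768/64) \cdot 64 \cdot 4\,\mathrm{B} \approx 16.8\,\mathrm{MB}$, a factor of BLOCK_T less.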
Results
All benchmarks: batch=16, n_heads=8, d_state=64, H200 SXM. Timing uses 100 warmup + 500 timed runs with torch.cuda.Event.
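For reference, a sketch of that harness (hedged; `fn` is whichever forward call is under test):

```python
import torch

def time_ms(fn, warmup=100, iters=500):
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()                 # wait for all timed work to finish
    return start.elapsed_time(end) / iters   # mean ms per call
```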
Forward latency
| seqlen | naive (ms) | seq (ms) | par (ms) | mamba (ms) | par (GB/s) | par vs mamba |
|---|---|---|---|---|---|---|
| 512 | 48.03 | 0.178 | 0.075 | 0.636 | 343 | 8.48x |
| 1024 | 76.70 | 0.626 | 0.069 | 0.494 | 750 | 7.20x |
| 2048 | 127.50 | 1.250 | 0.139 | 0.631 | 741 | 4.55x |
| 4096 | 265.70 | 2.506 | 0.258 | 0.632 | 797 | 2.45x |
| 8192 | 766.6* | 4.988 | 0.511 | 0.762 | 805 | 1.49x |
| 16384 | 1619.3* | 10.011 | 1.015 | 1.374 | 810 | 1.48x |
| 32768 | 3162.8* | 20.018 | 2.021 | 2.757 | 813 | 1.48x |
\* single-run wall-clock estimate; the Python loop is too slow for a 500-run average.
The first thing to notice is mamba_ssm sitting flat at ~0.62ms across seqlen 512-2048 despite the data volume increasing 4x. That time isn't proportional work; it's the fixed cost of launching four separate kernels: cumsum, chunk_state, state_passing, and chunk_scan. At short sequences that launch overhead dominates, so the "8x" speedup is really just "we have one kernel, they have four."
Bandwidth plateaus at ~810 GB/s from 8K onward, well short of the H200's ~4.8 TB/s HBM peak, so the kernel is compute-bound, not memory-bound: the scan tree's arithmetic is the ceiling. On Hopper, extra compute is cheap relative to extra memory reads, so this is the right tradeoff.
A question was still bugging me, though: was the advantage just launch overhead, or is the kernel itself more efficient?
Pure kernel time
To confirm the speedup isn't just our kernel having a cheaper launch, I isolated pure kernel time with torch.profiler.
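Roughly, the measurement looks like this (a sketch; `parallel_ssd` and the `"ssd"` name filter stand in for the real entry point and kernel name):

```python
import torch
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CUDA]) as prof:
    for _ in range(100):
        parallel_ssd(x, A, B, C)
    torch.cuda.synchronize()

# Sum device-side time for our kernel's events only, excluding launch overhead.
for evt in prof.key_averages():
    if "ssd" in evt.key.lower():
        print(evt.key, evt.self_cuda_time_total / 100, "us/call")
```

The kernel-only times: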
| seqlen | par (µs) | mamba (µs) | speedup |
|---|---|---|---|
| 2048 | 109 | 179 | 1.64x |
| 4096 | 215 | 347 | 1.62x |
| 8192 | 419 | 674 | 1.61x |
| 16384 | 838 | 1331 | 1.59x |
| 32768 | 1685 | 2628 | 1.56x |
Thankfully, the speedup holds up. mamba_ssm's four-kernel pipeline means four HBM round trips per forward call; our kernel pays in arithmetic instead, via the scan tree. As discussed in the last section, that's the right tradeoff on Hopper.
That being said, we have one last thing to tweak: BLOCK_T, the number of timesteps each tile processes.
BLOCK_T
My initial choice was BLOCK_T=64, but let's try a few other numbers.
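A sketch of the sweep (hypothetical harness; `parallel_ssd` stands in for the kernel's Python wrapper, assumed to expose the tile size as a keyword argument):

```python
import triton

for block_t in (64, 128, 256, 512):
    ms = triton.testing.do_bench(lambda: parallel_ssd(x, A, B, C, BLOCK_T=block_t))
    print(f"BLOCK_T={block_t}: {ms:.3f} ms")
```

In production code, `triton.autotune` over the same configs would pick this automatically. The results: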
| BLOCK_T | ms | GB/s |
|---|---|---|
| 64 | 0.510 | 805 |
| 128 | 0.422 | 973 |
| 256 | 0.510 | 807 |
| 512 | 1.771 | 232 |
BLOCK_T=128 is the sweet spot, about 20% faster than my original 64; at 512 throughput collapses, presumably because the larger tile hurts occupancy or spills registers.
Forward + Backward
The backward pass is where the gap widens. As we saw earlier, mamba_ssm's chunked decomposition means every stage of the forward has its own gradient kernel. Our kernel collapses this into one reverse scan.
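Why a single reverse scan suffices (a sketch in the Math section's notation): $h_t$ affects the loss through $y_t$ and through $h_{t+1}$, so the adjoint satisfies

$$\frac{\partial \mathcal{L}}{\partial h_t} = C_t\,\frac{\partial \mathcal{L}}{\partial y_t} + A_{t+1}\,\frac{\partial \mathcal{L}}{\partial h_{t+1}},$$

which is the same affine recurrence run right-to-left, so the forward's combine function works here too (recent Triton exposes a `reverse=True` flag on `tl.associative_scan` for exactly this). Here's what that costs in practice (fs = our fused kernel, mb = mamba_ssm; all times in ms):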
| seqlen | fs_fwd | fs_fwd+bwd | fs_bwd | mb_fwd | mb_fwd+bwd | mb_bwd | fwd+bwd speedup |
|---|---|---|---|---|---|---|---|
| 2048 | 0.112 | 2.535 | 2.424 | 0.628 | 10.789 | 10.161 | 4.26x |
| 4096 | 0.215 | 2.542 | 2.327 | 0.643 | 10.897 | 10.255 | 4.29x |
| 8192 | 0.421 | 2.531 | 2.110 | 0.682 | 11.181 | 10.498 | 4.42x |
| 16384 | 0.838 | 4.215 | 3.377 | 1.332 | 11.231 | 9.899 | 2.66x |
| 32768 | 1.684 | 8.386 | 6.702 | 2.691 | 20.267 | 17.576 | 2.42x |
These results were hard to interpret at first. The 2K–8K entries for our kernel sit flat at ~2.53ms despite seqlen doubling twice; if the kernel were doing proportional work at those sizes, the time would scale, but it doesn't. The culprit is PyTorch's autograd overhead: before any GPU work happens, Python has to traverse the computation graph built during the forward pass, figure out which backward kernels to call in what order, and schedule them. At small sequence lengths that bookkeeping dominates actual kernel execution. The same thing happened in the forward pass at seqlen=512, where event-recording overhead caused a 32% timing discrepancy. The speedup worth quoting is therefore the 2.42x at 32K, where both implementations are well clear of their overhead floors.
Additionally, at 32K our kernel's backward is ~4x its forward; mamba_ssm's is ~6.5x. We pay one HBM round trip for the saved activations, while mamba_ssm pays one per decomposed stage, four in total.
Conclusion
Our kernel shows the reference mamba_ssm implementation leaves measurable performance on the table: 1.56x in pure kernel time on an H200. What interests me next is Nemotron-3-Super, since it uses a hybrid Mamba + Transformer backbone at production scale. Would we find similar gaps in NVIDIA's own inference stack?