
Rewriting the Mamba-2 SSD Layer in Triton

Mamba has become a centerpiece of the efficiency conversation; its SSD layer is a linear recurrence that scales better than attention at long contexts.

The recurrence itself is sequential: each hidden state depends on the previous one. Unrolling the computation, however, reveals a parallel prefix scan. The reference implementation recognizes this at the algorithm level and parallelizes across chunks, but inside each chunk it still processes timesteps sequentially.

So I rewrote the SSD layer in Triton to exploit the within-tile parallelism.

Math

The recurrence is:

$$ h_t = A_t h_{t-1} + B_t x_t $$
$$ y_t = C_t h_t $$

where $A_t$ is a scalar decay per head, $B_t$ projects the input into state space, and $C_t$ reads it out. This is sequential ($h_t$ depends on $h_{t-1}$), but unrolling it gives

$$ h_t = \mathrm{gate}_{t,0}\, h_0 + \sum_{s=1}^{t} \mathrm{gate}_{t,s}\, B_s x_s $$
$$ \mathrm{gate}_{t,s} = \prod_{k=s+1}^{t} A_k $$

$\mathrm{gate}_{t,s}$ is the cumulative decay from timestep $s$ to $t$: it controls how much the input at position $s$ still matters by the time we reach $t$. If the decay $A$ stays close to 1, the gates stay large and the model has long memory. If $A$ is close to 0, the gates shrink quickly and recent inputs dominate.

So every $h_t$ is just a weighted sum of all past inputs, where the weights are these gates. This is what makes the recurrence a parallel prefix scan: the combine step is associative, so partial results can be computed in any grouping and merged, which is exactly what tl.associative_scan exploits.
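To make the associativity concrete, here's a toy scalar version of the combine operator in plain Python. This is a sketch, not the kernel: the element values are illustrative, and the real kernel works on `d_state`-wide vectors.

```python
# Each scan element is a pair (a, h): `a` is the decay accumulated over a
# segment, `h` is that segment's state contribution. Combining segment
# [.., s] with segment [s+1, .., t] follows the unrolled recurrence.
def combine(left, right):
    a1, h1 = left
    a2, h2 = right
    return (a1 * a2, a2 * h1 + h2)

# Toy sequence: elements are (A_t, B_t * x_t), with h_0 = 0.
elems = [(0.9, 1.0), (0.8, 0.5), (0.95, -0.3), (0.7, 2.0)]

# Left-to-right fold: exactly the sequential recurrence h_t = A_t h_{t-1} + B_t x_t.
seq = elems[0]
for e in elems[1:]:
    seq = combine(seq, e)

# Same elements combined in tree order, as a parallel scan would do.
tree = combine(combine(elems[0], elems[1]), combine(elems[2], elems[3]))

assert abs(seq[1] - tree[1]) < 1e-12  # same final h either way
```

tl.associative_scan takes a combine function of exactly this shape (as a `@triton.jit` helper) and applies it across the tile with logarithmic tree depth instead of a linear chain.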

The Kernel

Both variants launch one Triton program per (batch, head) pair, 128 programs in parallel for our benchmark config. Each program walks through the full sequence in tiles of BLOCK_T timesteps, carrying the hidden state h across tile boundaries.

The sequential variant only saves one h vector per tile boundary. During backward it replays the recurrence from that saved state to reconstruct the intermediate h values. Memory cost: one vector per tile.
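In plain Python, the save-boundary-and-replay idea looks like this. It's a sketch with illustrative names, using a scalar state instead of a `d_state` vector:

```python
BLOCK_T = 4  # tile size in timesteps (64 or 128 in the real kernel)

def forward_save_boundaries(a, bx):
    """Run h_t = a_t * h_{t-1} + bx_t, saving h only at tile boundaries."""
    h, boundaries = 0.0, [0.0]  # state entering the first tile
    for t, (at, bxt) in enumerate(zip(a, bx)):
        h = at * h + bxt
        if (t + 1) % BLOCK_T == 0:
            boundaries.append(h)
    return h, boundaries

def replay_tile(a, bx, tile_idx, boundaries):
    """Reconstruct the intermediate h values of one tile from its boundary state."""
    h = boundaries[tile_idx]
    start = tile_idx * BLOCK_T
    hs = []
    for at, bxt in zip(a[start:start + BLOCK_T], bx[start:start + BLOCK_T]):
        h = at * h + bxt
        hs.append(h)
    return hs

a = [0.9] * 8
bx = [1.0] * 8
_, saved = forward_save_boundaries(a, bx)
# Replaying tile 1 reproduces states h_5..h_8 without ever having stored them.
replayed = replay_tile(a, bx, 1, saved)
```

The backward pass calls `replay_tile` once per tile, which is the recompute-versus-store tradeoff in miniature.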

The parallel variant uses tl.associative_scan, computing all h values within a tile simultaneously. There's no sequential replay, so every $ h_t $ has to be written to high-bandwidth memory (HBM) during the forward pass and read back during the backward pass; the memory cost is seqlen * d_state floats per (batch, head) pair.

Since they differ massively in backward structure, each variant has its own backward kernel. The sequential backward recomputes h within each tile from the saved boundary state, then differentiates through the recurrence. The parallel backward loads the full saved h tensor and differentiates through the associative scan, which is basically a reverse scan over the tile.

TL;DR: The sequential backward is cheaper on memory traffic; the parallel backward pays an HBM round-trip for the saved activations.
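To see what "differentiate through the recurrence" means concretely, here's the reverse scan on a toy scalar version. Names are illustrative; the real kernels do this per tile, vectorized over d_state:

```python
def forward(a, b):
    """h_t = a_t * h_{t-1} + b_t, returning every intermediate state."""
    h, hs = 0.0, []
    for at, bt in zip(a, b):
        h = at * h + bt
        hs.append(h)
    return hs

def backward_db(a, grad_h):
    """Given dL/dh_t for every t, return dL/db_t via a reverse scan:
    dh_t accumulates as grad_h[t] + a_{t+1} * dh_{t+1}."""
    db = [0.0] * len(a)
    dh = 0.0
    for t in reversed(range(len(a))):
        dh = grad_h[t] + (a[t + 1] * dh if t + 1 < len(a) else 0.0)
        db[t] = dh  # since dh_t/db_t = 1
    return db

a = [0.9, 0.8, 0.95, 0.7]
b = [1.0, 0.5, -0.3, 2.0]
grad = backward_db(a, [1.0] * 4)  # loss = sum of all h_t, so dL/dh_t = 1

# Finite-difference check on b[1].
eps = 1e-6
base = sum(forward(a, b))
b2 = list(b)
b2[1] += eps
fd = (sum(forward(a, b2)) - base) / eps
assert abs(fd - grad[1]) < 1e-4
```

The backward recurrence runs over the same decays as the forward one, just in reverse order, which is why it parallelizes with the same associative-scan machinery.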

Results

All benchmarks: batch=16, n_heads=8, d_state=64, H200 SXM. Timing uses 100 warmup + 500 timed runs with torch.cuda.Event.

Forward latency

| seqlen | naive_ms | seq_ms | par_ms | mamba_ms | par_GBps | par/mamba |
|--------|----------|--------|--------|----------|----------|-----------|
| 512    | 48.03    | 0.178  | 0.075  | 0.636    | 343      | 8.48x     |
| 1024   | 76.70    | 0.626  | 0.069  | 0.494    | 750      | 7.20x     |
| 2048   | 127.50   | 1.250  | 0.139  | 0.631    | 741      | 4.55x     |
| 4096   | 265.70   | 2.506  | 0.258  | 0.632    | 797      | 2.45x     |
| 8192   | 766.6*   | 4.988  | 0.511  | 0.762    | 805      | 1.49x     |
| 16384  | 1619.3*  | 10.011 | 1.015  | 1.374    | 810      | 1.48x     |
| 32768  | 3162.8*  | 20.018 | 2.021  | 2.757    | 813      | 1.48x     |
* single-run wall-clock estimate — Python loop too slow for 500-run average

The first thing to notice is mamba_ssm sitting flat at ~0.6ms across seqlen 512-2048 despite the data volume increasing 4x. That flatness isn't about the work; it's the fixed cost of launching 4 separate kernels: cumsum, chunk_state, state_passing, and chunk_scan. At short sequences that overhead dominates, so the "8x" speedup is really just "we have one kernel, they have four."

Bandwidth plateaus at ~810 GB/s from 8K onward, well short of the H200's ~4.8 TB/s HBM peak, so the kernel is compute-bound, not memory-bound: the scan tree's arithmetic is the ceiling. On Hopper, extra compute is cheap relative to extra memory reads, so this is the right tradeoff.
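Some back-of-envelope arithmetic supports that reading (assuming the commonly quoted ~4.8 TB/s HBM3e peak for the H200; the measured numbers come from the forward-latency table above):

```python
# At seqlen=32768 the forward table reports 2.021 ms at 813 GB/s.
bytes_moved = 2.021e-3 * 813e9    # ~1.64 GB of total traffic per forward

# The saved h activations alone (seqlen * d_state floats per (batch, head)):
h_bytes = 16 * 8 * 32768 * 64 * 4  # batch * n_heads * seqlen * d_state * fp32
# -> ~1.07 GB, i.e. most of the traffic is the activation write-out.

# And 813 GB/s is a small fraction of the (assumed) 4800 GB/s peak:
fraction_of_peak = 813 / 4800      # ~0.17, far from the bandwidth roof
```

If the kernel were memory-bound, we'd expect the plateau to sit much closer to the peak.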

One question was still bugging me, though: was the advantage just launch overhead, or is the kernel itself more efficient?

Pure Kernel time

To confirm the speedup isn't just our kernel having cheaper launch overhead, I isolated pure kernel time with torch.profiler:

| seqlen | par_us | mamba_us | speedup |
|--------|--------|----------|---------|
| 2048   | 109    | 179      | 1.64x   |
| 4096   | 215    | 347      | 1.62x   |
| 8192   | 419    | 674      | 1.61x   |
| 16384  | 838    | 1331     | 1.59x   |
| 32768  | 1685   | 2628     | 1.56x   |

Thankfully, the speedup holds up. mamba_ssm's four-kernel pipeline means four HBM round-trips per forward call; our kernel pays for that in arithmetic instead, via the scan tree. As discussed in the last section, this is the right tradeoff.

That being said, we have one last thing to tweak: BLOCK_T, the number of timesteps each tile processes.

BLOCK_T

My initial choice was BLOCK_T=64, but let's try a few other numbers.

| BLOCK_T | ms    | GB/s |
|---------|-------|------|
| 64      | 0.510 | 805  |
| 128     | 0.422 | 973  |
| 256     | 0.510 | 807  |
| 512     | 1.771 | 232  |

BLOCK_T=128 is the clear winner. My intuition as to why: below 128, tiles are too small and the fixed setup cost per tile is a larger fraction of total time. Going from 64 to 128 doubles the work per tile without changing the scan tree depth, so that fixed cost gets spread over more computation essentially for free. Above 128, the register file becomes the bottleneck: the scan needs to keep `h` live across all BLOCK_T timesteps, which is `BLOCK_T * d_state` floats. At 512 it spills, because `512 * 64 * 4 = 128KB` for that one tensor is more than an SM's register file can absorb. (This reminded me a lot of amortized analysis.)
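The register-pressure arithmetic, assuming Hopper's 256 KB register file per SM (64K 32-bit registers):

```python
# How many bytes of live h-state the scan must hold for a given tile size.
d_state = 64
bytes_per_float = 4  # fp32

def live_h_bytes(block_t):
    return block_t * d_state * bytes_per_float

assert live_h_bytes(128) == 32 * 1024   # 32 KB: fits comfortably
assert live_h_bytes(512) == 128 * 1024  # 128 KB: half the register file for one tensor
```

And that's before counting the gates, inputs, and scan temporaries that also need registers, so spilling at 512 is unsurprising.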

Forward + Backward

The backward pass is where the gap widens. As we have seen before, mamba-ssm's chunked decomposition means every stage of the forward has its own gradient kernel. Our kernel collapses this into one reverse scan. Here's what that costs in practice:

| seqlen | fs_fwd | fs_fwd+bwd | fs_bwd | mb_fwd | mb_fwd+bwd | mb_bwd | speedup |
|--------|--------|------------|--------|--------|------------|--------|---------|
| 2048   | 0.112  | 2.535      | 2.424  | 0.628  | 10.789     | 10.161 | 4.26x   |
| 4096   | 0.215  | 2.542      | 2.327  | 0.643  | 10.897     | 10.255 | 4.29x   |
| 8192   | 0.421  | 2.531      | 2.110  | 0.682  | 11.181     | 10.498 | 4.42x   |
| 16384  | 0.838  | 4.215      | 3.377  | 1.332  | 11.231     | 9.899  | 2.66x   |
| 32768  | 1.684  | 8.386      | 6.702  | 2.691  | 20.267     | 17.576 | 2.42x   |

These results were pretty hard to interpret. The 2K–8K entries for our kernel are flat at ~2.53ms despite seqlen doubling twice. If the kernel were actually doing proportional work at those sizes, the time would scale, but it doesn't. The culprit was PyTorch's autograd overhead: before any GPU work happens, Python has to traverse the computation graph built during the forward pass, figure out which backward kernels to call and in what order, and schedule them. At small sequence lengths this bookkeeping cost dominates the actual kernel execution time. The same thing happened in the forward pass at seqlen=512 where event recording overhead caused a 32% timing discrepancy. As such, the speedup (2.42x) that's actually worth noting is at 32K because both kernels are well above their floors.

Additionally, at 32K our kernel's backward is ~4x its forward time; mamba-ssm's is ~6.5x. We pay one HBM round trip for the saved activations, while mamba-ssm pays one per decomposed stage (four in total).

Conclusion

Our kernel shows the reference mamba-ssm implementation leaves measurable performance on the table: 1.56x in pure kernel time on an H200. What I'm interested in next is Nemotron-3-Super, since it runs a hybrid Mamba + Transformer backbone at production scale. Would we find similar gaps in NVIDIA's own inference stack?

References & Useful stuff

Mamba Blog by Tri Dao and Albert Gu

GitHub