Mamba part 3 - Details of Mamba and Structured State Space
Based on West Coast Machine Learning's video on YouTube. If you find this content useful, support the original creators by watching, liking, and subscribing.
Briefing
Mamba’s core pitch is that sequence modeling can be made both fast and selective without attention’s quadratic cost. The approach builds on state space models (SSMs)—specifically the S4 line of work—then adds “selectivity” so the system’s internal transition behavior changes with the input. That combination aims to keep computation efficient on GPUs while still letting the model treat different tokens differently, which matters for language where some words carry more context than others.
The explanation starts with S4’s state space framing. Inputs (denoted U) are blended with a hidden state (X) that acts like a memory of prior context. In continuous time, the hidden state evolves according to a differential equation; discretization turns that into a step-by-step recurrence. A key practical advantage comes from two equivalent views: one that updates the hidden state sequentially (good for inference), and another that “unrolls” the recurrence into a form that enables parallel gradient computation across long sequences (good for training). This is why S4 is described as fast at inference with one-step updates, and fast at training by parallelizing across the whole sequence.
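To make the two views concrete, here is a minimal NumPy sketch (a toy scalar-state SSM with made-up values, not the actual S4 implementation) showing that the step-by-step recurrence and the unrolled convolution produce the same outputs:

```python
import numpy as np

# Toy discretized SSM: x_t = A_bar * x_{t-1} + B_bar * u_t,   y_t = C * x_t
# (scalar state and input for clarity; real S4 uses structured matrices)
A_bar, B_bar, C = 0.9, 0.5, 1.2            # assumed toy values
u = np.array([1.0, 0.0, 2.0, -1.0, 0.5])   # toy input sequence

# View 1: sequential recurrence (the inference-time form)
x, y_rec = 0.0, []
for u_t in u:
    x = A_bar * x + B_bar * u_t
    y_rec.append(C * x)

# View 2: unrolled convolution (the form that enables parallel training)
# y_t = sum_k C * A_bar**k * B_bar * u_{t-k}
L = len(u)
kernel = np.array([C * A_bar**k * B_bar for k in range(L)])
y_conv = [np.dot(kernel[:t + 1][::-1], u[:t + 1]) for t in range(L)]

print(np.allclose(y_rec, y_conv))  # True: the two views agree
```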
However, S4 struggled when moved from modalities like signals (the example given is echocardiograms) to language. The intuition offered is that language doesn’t behave like a repeating physical cycle: some tokens are crucial while others are effectively background, so the model needs context-dependent transitions rather than dynamics that depend only on time or position. That’s where Mamba’s selective state space (S6, the subject of this “Mamba part 3” discussion) comes in.
The transcript describes improvements associated with the S6 direction: a technique inspired by “Hungry Hungry Hippos” (H3) to collapse large matrix structure into a more efficient diagonal form, reducing computational burden. It also introduces input-dependent parameters—the discretization step size and the B and C matrices are computed from each token, so the effective transition and read-out behavior vary per token—and the system performs an input-driven selection mechanism (described as “attention-free” selection rather than standard attention). The result is that the model can modulate how strongly it “forgets” old state versus “remembers” new input depending on what token arrives.
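As a rough illustration of what “token-dependent parameters” means, here is a hedged NumPy sketch of a single selective SSM channel. The projection matrices, sizes, and softplus choice are assumptions made for the sketch; this is not Mamba’s actual kernel or parameterization:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_state, L = 8, 4, 6              # toy sizes (assumed, not Mamba's real dims)

A = -np.exp(rng.normal(size=d_state))      # fixed diagonal transition; negative => stable decay
W_delta = 0.1 * rng.normal(size=d_model)   # toy projections that read the token
W_B = 0.1 * rng.normal(size=(d_model, d_state))
W_C = 0.1 * rng.normal(size=(d_model, d_state))
w_u = 0.1 * rng.normal(size=d_model)       # projects the token to the scalar channel being modeled

def selective_step(x, tok):
    """One selective SSM step for a single channel: the token chooses its own Δ, B, C."""
    delta = np.log1p(np.exp(tok @ W_delta))  # softplus -> positive step size (how fast to update)
    B_t = tok @ W_B                          # input-dependent "write" direction
    C_t = tok @ W_C                          # input-dependent "read-out" direction
    u_t = tok @ w_u                          # the channel value this SSM tracks
    A_bar = np.exp(delta * A)                # per-token forget factor (still diagonal, cheap)
    x = A_bar * x + (delta * B_t) * u_t      # forget old state, write new input selectively
    return x, C_t @ x

x = np.zeros(d_state)
for tok in rng.normal(size=(L, d_model)):
    x, y_t = selective_step(x, tok)
```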
A major engineering theme then follows: making the recurrence run efficiently on GPUs. The discussion emphasizes that performance bottlenecks often come from memory movement, not raw arithmetic. Mamba’s implementation uses custom CUDA kernels, parallel scan strategies (referencing S5), and kernel fusion to avoid slow round-trips to main GPU memory and to compute recurrences in a way that fits GPU hardware well. The outcome is linear scaling with sequence length during training and constant-time-per-step behavior during autoregressive inference, without maintaining a key-value cache like Transformers.
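The parallel-scan idea can be sketched on the CPU: the linear recurrence is rewritten as an associative operator on (multiplier, offset) pairs, so prefixes can be combined in O(log L) rounds instead of a strictly serial loop. This is only an illustration of the associativity trick (Hillis–Steele style, in pure Python/NumPy), not the fused CUDA kernel:

```python
import numpy as np

# The recurrence x_t = a_t * x_{t-1} + b_t corresponds to composing affine maps.
# Treating each step as a pair (a, b) meaning x -> a*x + b, composition is
# (a1, b1) then (a2, b2) = (a2*a1, a2*b1 + b2), which is associative.

def combine(p, q):
    a1, b1 = p
    a2, b2 = q
    return a2 * a1, a2 * b1 + b2

rng = np.random.default_rng(0)
L = 8
a = rng.uniform(0.5, 1.0, size=L)   # per-step "forget" factors
b = rng.normal(size=L)              # per-step inputs

# Serial reference
x, serial = 0.0, []
for t in range(L):
    x = a[t] * x + b[t]
    serial.append(x)

# Hillis-Steele inclusive scan using the associative operator
pairs = list(zip(a, b))
step = 1
while step < L:
    pairs = [combine(pairs[t - step], pairs[t]) if t >= step else pairs[t]
             for t in range(L)]
    step *= 2
parallel = [b_ for _, b_ in pairs]  # applying the prefix map to x_{-1} = 0 leaves just the offset

print(np.allclose(serial, parallel))  # True
```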
Finally, the transcript addresses why this doesn’t collapse into a vanilla RNN. The state space formulation constrains the allowed dynamics through the continuous-time-to-discrete derivation, producing stable “forget/remember” behavior rather than arbitrary recurrent weights. The discussion also touches on vanishing gradients: while long-range influence still decays, the model’s well-conditioned transitions and gating-like structure aim to keep training numerically stable, and residual connections across layers help optimization. Overall, Mamba is presented as a middle path between Transformers’ full-history attention and RNN-style recurrence—trading quadratic attention for selective, hardware-optimized state space computation that can still capture long context.
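One standard way to see the “constrained, stable dynamics” point is the usual zero-order-hold discretization (stated here as background, not a quote from the video):

```latex
% Zero-order-hold discretization of \dot{x}(t) = A x(t) + B u(t) with step \Delta:
\bar{A} = e^{\Delta A}, \qquad
\bar{B} = (\Delta A)^{-1}\left(e^{\Delta A} - I\right)\Delta B, \qquad
x_t = \bar{A}\, x_{t-1} + \bar{B}\, u_t .
% If the (diagonal) entries of A have negative real part and \Delta > 0, then
% |e^{\Delta a_{ii}}| < 1: each state channel decays smoothly toward zero,
% acting like a learned forget gate rather than an arbitrary recurrent weight.
```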
Cornell Notes
Mamba builds on state space models (SSMs) from the S4 line, where a hidden state evolves via a discretized differential equation. S4’s recurrence can be computed sequentially for inference and unrolled for parallel training, but it underperformed on language because transitions weren’t context-dependent enough. Mamba’s selective state space adds input-driven parameterization (token-dependent transition behavior) so the model can decide how much to forget and how much to incorporate new information. The system is then engineered for GPU efficiency using parallel scan, kernel fusion, and memory-aware CUDA kernels to avoid slow memory traffic. The goal is linear-time training, constant-time-per-step inference, and attention-free selectivity that still matches language needs.
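To make the “no key-value cache” contrast concrete, here is a hedged toy sketch (arbitrary shapes and values, not the real decoding loops) comparing a fixed-size recurrent state with a Transformer-style cache that grows every step:

```python
import numpy as np

rng = np.random.default_rng(0)
d_state, steps = 16, 1000            # assumed toy sizes

A_bar = np.full(d_state, 0.95)       # toy per-channel forget factors
B_bar = 0.1 * rng.normal(size=d_state)
C = rng.normal(size=d_state)

# SSM-style decoding: the only thing carried between steps is a fixed-size state,
# so each new token costs the same work and memory (O(1) per step).
x = np.zeros(d_state)
for t in range(steps):
    u_t = rng.normal()               # stand-in for the next token's channel value
    x = A_bar * x + B_bar * u_t      # constant-size state update
    y_t = C @ x                      # read-out for this step

# Transformer-style decoding appends keys/values every step, so per-step cost
# and memory grow with sequence length:
kv_cache = []
for t in range(steps):
    kv_cache.append(rng.normal(size=d_state))  # cache grows linearly
    # attention at step t touches all t cached entries -> O(t) work per token
```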
How does S4 turn a continuous-time state space equation into something usable for token-by-token language modeling?
How does S4 manage to be fast at both training and inference, and why do the two regimes use different computational forms?
What problem shows up when S4-style dynamics move from signals (e.g., echocardiograms) to language?
What does “selectivity” add in Mamba, and how is it different from standard attention?
Why does Mamba’s GPU implementation matter as much as its math?
How does Mamba claim to avoid Transformer-style key-value caching at inference?
Review Questions
- What computational forms of S4 enable parallel training, and why does that differ from inference-time computation?
- Explain how input-dependent selectivity changes the model’s effective “forget vs remember” behavior compared with fixed SSM dynamics.
- What GPU-level techniques (parallel scan, kernel fusion, memory-aware CUDA kernels) are used to reduce wall-clock time, and what bottleneck do they target?
Key Points
1. S4 discretizes continuous-time state space dynamics into stepwise hidden-state updates, enabling both sequential inference and unrolled parallel training.
2. Language needs context-dependent transitions; fixed, position-only dynamics that work for repeating signals can underperform on text.
3. Mamba’s selective state space adds input-driven parameterization so the model can modulate how much prior state to retain versus overwrite when new tokens arrive.
4. “Hungry Hungry Hippos” (H3) motivates collapsing matrix structure (e.g., toward diagonal-like efficiency) to reduce compute cost.
5. Mamba’s speed relies heavily on GPU-aware engineering: parallel scan, kernel fusion, and keeping intermediates in fast memory to avoid bandwidth bottlenecks.
6. Mamba targets linear-time training in sequence length and constant-time-per-step autoregressive inference without key-value caching.
7. The state space formulation constrains recurrent dynamics to be stable and gating-like, which helps optimization even over long sequences.