AI & Machine Learning

Beyond Self-Attention: What Comes After Transformers

The transformer architecture has dominated AI for eight years now. Every major language model, most image generation systems, and a growing number of audio and video models are built on the self-attention mechanism introduced in the 'Attention Is All You Need' paper. But self-attention has a fundamental problem: its compute and memory cost scales quadratically with sequence length. Double the input length, quadruple the cost.

For short sequences this doesn't matter. For the 128K-token context windows that are now standard — and the million-token windows people want — it's a serious bottleneck. A wave of research is exploring alternatives: attention residuals that reuse computation across layers, linear attention variants that drop the quadratic cost, and hybrid architectures that mix attention with cheaper mechanisms. The transformer isn't going away, but it's being reshaped.

Why Self-Attention Is Expensive

To understand the alternatives, you need to understand what self-attention actually computes. Given a sequence of N tokens, self-attention calculates a relevance score between every pair of tokens. Token 1 vs token 2, token 1 vs token 3, ..., token 1 vs token N, then token 2 vs every other token, and so on. That's N² pairs.

import torch
import torch.nn.functional as F

def self_attention(Q, K, V):
    """
    Standard self-attention.
    Q, K, V: (batch, seq_len, d_model)
    The attention matrix is seq_len × seq_len.
    For seq_len = 1024:   ~1M entries   (manageable)
    For seq_len = 32768:  ~1B entries   (expensive)
    For seq_len = 131072: ~17B entries  (very expensive)
    """
    d_k = Q.size(-1)
    # This matmul creates the N×N attention matrix
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, V)

At 4K tokens, the attention matrix has about 16 million entries — no problem for modern GPUs. At 128K tokens, it has about 17 billion. At 1 million tokens, it's over a trillion. Even with Flash Attention (which doesn't reduce the amount of computation but dramatically improves memory access patterns), the quadratic scaling eventually wins.

This is why early transformers were limited to 512 or 1024 tokens. Each generation of hardware and optimization has pushed the ceiling higher, but we're fighting a mathematical wall. Linear scaling (O(N)) would be fundamentally better than quadratic scaling (O(N²)), and that's what most alternative architectures pursue.
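The arithmetic above is easy to check with a back-of-the-envelope calculation. This sketch (the function name is just for illustration) counts entries in a single attention matrix and the memory it would take at fp16:

```python
def attn_matrix_cost(seq_len, bytes_per_entry=2):
    """Entry count and fp16 memory (GiB) for one N×N attention matrix."""
    entries = seq_len ** 2
    return entries, entries * bytes_per_entry / 2**30

for n in (1024, 4096, 32768, 131072):
    entries, gib = attn_matrix_cost(n)
    print(f"N={n:>6}: {entries:>14,} entries, {gib:>8.2f} GiB")
```

At 128K tokens a single attention matrix would need roughly 32 GiB in fp16 if materialized naively — which is exactly why Flash Attention avoids materializing it at all.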

Attention Residuals: Reusing What You've Already Computed

One of the most pragmatic approaches to reducing attention cost doesn't replace attention — it makes each attention layer cheaper by reusing computation from previous layers.

The observation: in a deep transformer (say, 32 layers), the attention patterns in adjacent layers are often remarkably similar. Layer 15 and layer 16 tend to attend to similar positions, with small adjustments. Computing the full N² attention matrix from scratch at every layer is redundant — much of the work was already done one layer ago.

Attention residuals exploit this by computing a 'residual' attention pattern: the difference between what this layer wants to attend to and what the previous layer computed. If the difference is small (which it usually is in middle layers), the computation is cheaper. The full attention pattern is the sum of the previous layer's pattern plus the current layer's residual.

This is analogous to how video compression works: rather than storing each frame independently, you store a keyframe and then a series of differences (residuals) from that keyframe. The differences are usually much smaller than the full frame, so compression is dramatically better.

In practice, attention residuals reduce the compute cost of attention by 30-50% in the middle layers of deep models with minimal quality impact. The first and last few layers still need full attention computation (their patterns are more distinct), but the middle layers — which are the majority — get significant speedups.
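A minimal sketch of the bookkeeping, under illustrative assumptions — the prev_weights blending and the alpha mixing factor are hypothetical, not a published API, and this version still computes full attention (real implementations get their savings by approximating the current layer's residual term cheaply):

```python
import torch
import torch.nn.functional as F

def residual_attention(Q, K, V, prev_weights=None, alpha=1.0):
    """
    Sketch of residual attention bookkeeping (illustrative, not an
    actual library API). prev_weights is the (batch, N, N) attention
    pattern from the layer below; this layer's pattern is expressed
    as that pattern plus a scaled adjustment.
    """
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / d_k ** 0.5
    weights = F.softmax(scores, dim=-1)
    if prev_weights is not None:
        # Full pattern = previous layer's pattern + this layer's residual.
        weights = prev_weights + alpha * (weights - prev_weights)
        # Re-normalize rows so they remain a valid distribution.
        weights = weights / weights.sum(dim=-1, keepdim=True)
    # Return the weights so the next layer can reuse them.
    return torch.matmul(weights, V), weights
```

The returned weights are threaded to the next layer, mirroring the keyframe-plus-deltas structure of the video-compression analogy above.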

Linear Attention: Dropping the Quadratic Cost

Linear attention variants attempt to reformulate attention so it scales as O(N) instead of O(N²). The general approach: instead of computing the N×N attention matrix explicitly, find a way to compute the same (or approximately the same) output using linear operations.

The mathematical trick relies on the kernel decomposition of softmax. Standard attention computes softmax(QK^T)V. If you replace softmax with a different kernel function that can be decomposed as φ(Q) · φ(K)^T, you can rearrange the computation order: instead of (φ(Q) · φ(K)^T) · V (which has an N×N intermediate), compute φ(Q) · (φ(K)^T · V) (which has a d×d intermediate, where d is the model dimension). Since d << N for long sequences, this is dramatically cheaper.

def linear_attention(Q, K, V, feature_map=None):
    """
    Linear attention via kernel feature maps.
    Cost: O(N * d^2) instead of O(N^2 * d)
    """
    if feature_map is None:
        # ELU+1 is a common choice (from Katharopoulos et al.)
        feature_map = lambda x: F.elu(x) + 1
    Q = feature_map(Q)  # (batch, seq_len, d)
    K = feature_map(K)  # (batch, seq_len, d)
    # Key insight: compute K^T @ V first (d × d matrix)
    # instead of Q @ K^T first (N × N matrix)
    KV = torch.einsum('bnd,bnm->bdm', K, V)  # (batch, d, d)
    # Then multiply by Q
    output = torch.einsum('bnd,bdm->bnm', Q, KV)  # (batch, N, d)
    # Normalize
    Z = torch.einsum('bnd,bd->bn', Q, K.sum(dim=1))  # normalization
    output = output / Z.unsqueeze(-1)
    return output

The catch: replacing softmax with a different kernel function changes the attention distribution, and models trained with softmax attention don't necessarily transfer well to linear attention. The quality gap has narrowed significantly — recent linear attention variants achieve 95-98% of softmax attention quality — but the gap persists, especially on tasks that require precise long-range retrieval.

State Space Models: A Different Paradigm

State space models (SSMs) like Mamba take a fundamentally different approach. Instead of computing pairwise token relationships, they process the sequence through a recurrence — maintaining a fixed-size hidden state that gets updated at each token. This is inherently O(N): processing twice as many tokens takes twice as long, not four times.

The innovation in modern SSMs is making the recurrence parameters depend on the input (selective state spaces). This gives the model a form of content-based attention — it can 'choose' which information to remember and which to forget — without the quadratic cost. Mamba-style models match transformer quality on many benchmarks while being significantly faster for long sequences.

The trade-off: SSM recurrences are sequential by nature, which historically made them harder to parallelize during training than transformers (which process all tokens simultaneously). Modern SSMs recover much of that parallelism with parallel scan algorithms, but the engineering is more involved. Training efficiency matters — a model that's 2x faster at inference but 3x slower to train isn't necessarily a win, since most of the total compute goes to training.
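A minimal, non-selective recurrence shows why SSMs are O(N): the hidden state has a fixed size, so every token costs the same amount of work. The parameter names below are illustrative, not Mamba's actual parameterization — selective SSMs additionally make A and B functions of the input:

```python
import torch

def simple_ssm(x, A, B, C):
    """
    Minimal state-space recurrence (sketch of the O(N) idea, not Mamba):
        h_t = A * h_{t-1} + B @ x_t
        y_t = C @ h_t
    x: (seq_len, d_in); A: (d_state,) per-dimension decay;
    B: (d_state, d_in) input projection; C: (d_state,) readout.
    """
    seq_len = x.shape[0]
    h = torch.zeros(A.shape[0])
    ys = []
    for t in range(seq_len):
        h = A * h + B @ x[t]   # fixed-size state update: O(1) work per token
        ys.append(C @ h)       # readout from the state
    return torch.stack(ys)     # (seq_len,)
```

Doubling seq_len doubles the loop iterations and nothing else — there is no N×N intermediate anywhere.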

Hybrid Architectures: The Pragmatic Path

The current trend in production models is hybrid architectures that combine different attention mechanisms. The reasoning is simple: different parts of a model benefit from different types of computation.

  • Full attention for global reasoning. Some layers need to attend across the entire sequence — finding relevant context thousands of tokens away. These layers get standard (possibly Flash-optimized) self-attention.
  • Local attention for nearby context. Many layers primarily attend to nearby tokens (sliding window attention). Using a fixed window of 256-1024 tokens reduces cost to O(N·W) where W is the window size.
  • Linear attention for broad context. Some layers need to aggregate information across the sequence but don't need precise attention weights. Linear attention provides this at O(N) cost.
  • SSM layers for sequential processing. Mamba-style layers can efficiently process sequential dependencies without any attention computation.
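To make the local-attention cost concrete, here is a loop-form sketch of causal sliding-window attention. The windowing scheme is an illustrative assumption — production kernels compute this in batched, fused form — but each token looks at no more than `window` keys, giving the O(N·W) cost described above:

```python
import torch
import torch.nn.functional as F

def sliding_window_attention(Q, K, V, window=4):
    """
    Causal sliding-window attention sketch: token i attends only to
    the last `window` tokens (including itself), so total cost is
    O(N * window) rather than O(N^2).
    Q, K, V: (seq_len, d)
    """
    seq_len, d_k = Q.shape[-2], Q.shape[-1]
    outputs = []
    for i in range(seq_len):
        lo = max(0, i - window + 1)
        k = K[lo:i + 1]                                    # (<=window, d)
        v = V[lo:i + 1]
        scores = (Q[i:i + 1] @ k.transpose(-2, -1)) / d_k ** 0.5
        outputs.append(F.softmax(scores, dim=-1) @ v)      # (1, d)
    return torch.cat(outputs, dim=-2)                      # (seq_len, d)
```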

Models like Jamba (AI21) and various research architectures alternate between these mechanisms based on the layer's role. Early layers use local attention (processing syntax and local patterns). Middle layers use linear attention or SSMs (building broader representations). A few strategic layers use full attention (global reasoning and retrieval). This gives near-linear overall scaling while preserving the model quality that requires some full attention.
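The alternation can be written down as a layer schedule. The ratios and names below are assumptions for illustration — not Jamba's actual configuration — but they capture the pattern: local attention early, cheap mechanisms in the middle, full attention at strategic intervals:

```python
def hybrid_schedule(n_layers=32, full_attn_every=8, n_local=4):
    """Hypothetical hybrid layer schedule (illustrative ratios only)."""
    layers = []
    for i in range(n_layers):
        if i % full_attn_every == full_attn_every - 1:
            layers.append("full_attention")   # strategic global layers
        elif i < n_local:
            layers.append("local_attention")  # early: syntax, local patterns
        else:
            layers.append("ssm")              # middle: broad representations
    return layers
```

With these (assumed) defaults, only 4 of 32 layers pay the full N² cost, so overall scaling stays near-linear.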

What Developers Should Watch

If you're building applications on top of language models, the architectural changes happening underneath affect your work in concrete ways.

  • Context windows will keep growing. As attention costs drop, context windows expand. This changes application architecture: instead of building complex RAG pipelines to fit relevant context into a 4K window, you might just stuff everything into a 1M-token prompt. The simplicity is appealing, but latency and cost implications differ between architectures.
  • Latency profiles change. Transformers have relatively flat latency up to a point, then it increases quadratically. Linear-attention and SSM models degrade more gradually, with linear latency growth. For latency-sensitive applications, know your model's scaling behavior.
  • Quality differences are task-dependent. Linear attention models may slightly underperform on tasks requiring precise retrieval from specific positions in long contexts ('what was the third item in the list on page 47?'). They perform equally well on tasks requiring general understanding. Know your use case.
  • Inference optimization matters more. As models get more architecturally complex (mixing different attention types), inference engines need to handle heterogeneous computation efficiently. vLLM, TensorRT-LLM, and similar frameworks are adapting, but custom architectures may not be immediately supported.

The transformer isn't being replaced — it's being evolved. Self-attention remains the most expressive mechanism we have for modeling relationships between tokens. But it doesn't need to be used everywhere, for every layer, at full N² cost. The models of the next few years will use attention surgically — full attention where it matters most, cheaper alternatives everywhere else. The result will be models that are faster, handle longer contexts, and cost less to run, while matching or exceeding current quality. That's worth paying attention to.