RetNet is a foundation architecture for LLMs that achieves training parallelism, low-cost inference, and good performance simultaneously, breaking the “impossible triangle” of sequence modeling.
The key contribution is the Retention mechanism, which has a dual form of recurrence and parallelism. This means we can train models in parallel (like Transformers) while doing inference recurrently (like RNNs).
Three computation paradigms: parallel, recurrent, and chunkwise recurrent.
Paper: https://arxiv.org/abs/2307.08621
The three main sequence modeling families each make a different tradeoff:
| | Training | Inference memory | Inference cost per step | Positional encoding |
|---|---|---|---|---|
| Transformer | Parallel | KV cache grows with sequence | $O(n)$ at step $n$ | RoPE / ALiBi / learned |
| RNN (RWKV, SSM) | Sequential | Fixed state | $O(1)$ per token | Built into recurrence |
| RetNet | Parallel | Fixed state | $O(1)$ per token | Built into diagonalization of $A$ |
vs. Transformer. Attention computes over the full context at every step: each new token attends to all previous tokens, so per-step inference memory and compute grow linearly with sequence length. RetNet replaces softmax with the decay matrix $D$ (no normalization across positions, just exponential weighting) and compresses the entire history into a fixed-size state $S_n$. The tradeoff: a Transformer can attend to any past token equally well, while RetNet exponentially forgets distant tokens at a rate controlled by $\gamma$.
vs. RNNs / RWKV / SSMs. These also maintain a fixed state and have O(1) inference, but they are trained sequentially: each step depends on the previous one, so training cannot be parallelized across the sequence. RetNet gets both: because its recurrence is linear, it can be unrolled into a closed-form sum (the parallel form), and the full sequence can be computed with a couple of matrix multiplies during training, just like a Transformer.
vs. Linear Attention. Linear attention replaces softmax with a kernel and also has O(1) recurrent inference. RetNet is similar in spirit but adds two things linear attention lacks: (1) the exponential decay $\gamma$, which gives a proper forgetting mechanism rather than equal weighting of all past tokens, and (2) the complex rotations from diagonalizing $A$, which give relative position encoding for free.
| Symbol | Shape | What it is |
|---|---|---|
| $X$ | $N \times d_{\text{model}}$ | Full input sequence ($N$ tokens), one row per token |
| $W_Q, W_K, W_V$ | $d_{\text{model}} \times d$ | Learned projection matrices |
| $q_n, k_n, v_n$ | $1 \times d$ | Query, key, value vectors for token $n$ |
| $S_n$ | $d \times d$ | Retention state: running memory accumulating all past key-value pairs |
| $o_n$ | $1 \times d$ | Output for token $n$, read from the state |
| $A$ | $d \times d$ | Recurrence matrix: controls how the state evolves |
| $\gamma$ | scalar | Decay factor: eigenvalue magnitude of $A$ after diagonalization |
| $\theta$ | $1 \times d$ | Rotation frequencies: eigenvalue angles of $A$, one per dimension; encode relative position |
| $\Lambda$ | $d \times d$ | Change of basis that diagonalizes $A$, absorbed into $W_Q$ / $W_K$ |
| $D$ | $N \times N$ | Causal decay matrix: combines the causal mask and exponential decay |
The intuition: $S_n$ is the memory at step $n$. It accumulates all past $(k_m, v_m)$ pairs, exponentially forgetting older ones via $\gamma$. The output $o_n$ is a query $q_n$ into that memory.
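The input sequence is packed into a matrix $X \in \mathbb{R}^{N \times d_{\text{model}}}$, and each token is projected to $q_n = x_n W_Q$, $k_n = x_n W_K$, $v_n = x_n W_V$.

The general recurrence. The retention state evolves as:

$$S_n = A\,S_{n-1} + k_n^\top v_n, \qquad S_0 = 0, \qquad o_n = q_n S_n$$

$S_n$ is a $d \times d$ matrix, an outer-product memory. Each token writes $k_n^\top v_n$ into it (store this key-value association), while $A$ decays what is already there.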
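A minimal NumPy sketch of the general recurrence, checking that the step-by-step state update matches the unrolled sum $o_n = \sum_{m \le n} q_n A^{n-m} k_m^\top v_m$. The function names, the random inputs, and the small scale of the hypothetical matrix `A` (kept small so its powers stay bounded) are illustrative assumptions, not the paper's code.

```python
import numpy as np


def recurrent_general(Q, K, V, A):
    """Run S_n = A S_{n-1} + k_n^T v_n and read out o_n = q_n S_n (row-vector convention)."""
    S = np.zeros((Q.shape[1], V.shape[1]))
    outputs = []
    for n in range(Q.shape[0]):
        S = A @ S + np.outer(K[n], V[n])  # decay old memory, write the new key-value pair
        outputs.append(Q[n] @ S)          # query the memory
    return np.stack(outputs)


def unrolled_general(Q, K, V, A):
    """Closed form: o_n = sum_{m <= n} (q_n A^{n-m} k_m^T) v_m."""
    out = np.zeros((Q.shape[0], V.shape[1]))
    for n in range(Q.shape[0]):
        for m in range(n + 1):
            weight = Q[n] @ np.linalg.matrix_power(A, n - m) @ K[m]  # scalar weight
            out[n] += weight * V[m]
    return out


rng = np.random.default_rng(0)
N, d = 6, 4
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
A = 0.5 * rng.standard_normal((d, d)) / np.sqrt(d)  # hypothetical recurrence matrix
print(np.allclose(recurrent_general(Q, K, V, A), unrolled_general(Q, K, V, A)))
```

Note that each step multiplies the full $d \times d$ state by $A$, which is the $O(d^3)$ cost the diagonalization removes.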
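Why diagonalize $A$? A full $d \times d$ matrix is expensive to apply at every step and hard to train stably. The paper constrains it to be diagonalizable:

$$A = \Lambda\,\operatorname{diag}(\gamma e^{i\theta})\,\Lambda^{-1}$$

$\operatorname{diag}(\gamma e^{i\theta})$ is diagonal with entries $\gamma e^{i\theta_j}$: a magnitude and a rotation per dimension (the paper then simplifies $\gamma$ to a single scalar). Since $A^{n-m} = \Lambda\,\operatorname{diag}(\gamma e^{i\theta})^{n-m}\,\Lambda^{-1}$, absorbing $\Lambda$ into $W_Q$ and $\Lambda^{-1}$ into $W_K$ (they are just linear projections; on the key side it enters as $W_K \Lambda^{-\top}$ because $k_m$ appears transposed) makes $\Lambda$ disappear from the recurrence:

$$S_n = \operatorname{diag}(\gamma e^{i\theta})\,S_{n-1} + k_n^\top v_n$$

So $\gamma$ is not an ad hoc decay knob: it is the eigenvalue magnitude of $A$. The state update is now $O(d^2)$ per step (each row of $S_{n-1}$ is scaled by one eigenvalue, plus an outer product), with no $d \times d$ matrix multiply.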
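A small numerical check of the "$\Lambda$ disappears" argument, assuming a random complex change of basis `Lam` (invertible with probability one) and hypothetical values for $\gamma$ and $\theta$. The full-matrix recurrence and the diagonal recurrence, with $\Lambda$ folded into the query side and $\Lambda^{-\top}$ into the key side, should give the same outputs. `run` is an illustrative helper, not part of any library.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 6, 4
gamma = 0.9                                    # hypothetical decay magnitude
theta = rng.uniform(0, np.pi, d)               # hypothetical rotation angles, one per dimension
Lam = rng.standard_normal((d, d)) + 1j * rng.standard_normal((d, d))  # change of basis
eig = gamma * np.exp(1j * theta)               # diagonal entries gamma * e^{i theta_j}
A = Lam @ np.diag(eig) @ np.linalg.inv(Lam)    # A = Lam diag(gamma e^{i theta}) Lam^{-1}

Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))


def run(Q, K, V, decay):
    """Generic recurrence S_n = decay(S_{n-1}) + k_n^T v_n, o_n = q_n S_n."""
    S = np.zeros((Q.shape[1], V.shape[1]), dtype=complex)
    out = []
    for n in range(Q.shape[0]):
        S = decay(S) + np.outer(K[n], V[n])
        out.append(Q[n] @ S)
    return np.array(out)


# Full d x d matrix product at every step
out_full = run(Q, K, V, lambda S: A @ S)

# Fold Lam into the query side and Lam^{-T} into the key side (i.e. into W_Q and W_K)
Q_abs = Q @ Lam
K_abs = K @ np.linalg.inv(Lam).T
# Diagonal recurrence: row j of S is just scaled by eig[j], O(d^2) per step
out_diag = run(Q_abs, K_abs, V, lambda S: eig[:, None] * S)

print(np.allclose(out_full, out_diag))
```

The equality holds because $(q\Lambda)\,\Gamma^{p}\,(k\Lambda^{-\top})^\top = q\,\Lambda\Gamma^{p}\Lambda^{-1}k^\top = q A^{p} k^\top$ for every power $p$, with $\Gamma = \operatorname{diag}(\gamma e^{i\theta})$.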
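Position encoding comes for free. Splitting $(\gamma e^{i\theta})^{n-m} = \gamma^{n-m}\,e^{in\theta}\,e^{-im\theta}$ and absorbing the phases into the projections turns them into xPos-style complex rotations on $q$ and $k$:

$$\tilde q_n = q_n \odot e^{in\theta}, \qquad \tilde k_m = k_m \odot e^{im\theta}$$

The inner product $\tilde q_n \tilde k_m^\dagger = \sum_j q_{n,j}\,k_{m,j}\,e^{i(n-m)\theta_j}$ then depends only on the offset $n - m$: relative position is baked in. No separate RoPE layer is needed on $q$ and $k$; the diagonalization of $A$ handles it.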
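A quick check of the relative-position property, assuming real $q$, $k$ and hypothetical random angles $\theta$: the score depends only on the offset $n - m$, not on the absolute positions. `rotate` and `score` are illustrative helpers, not functions from the paper's code.

```python
import numpy as np


def rotate(x, pos, theta):
    """xPos-style complex rotation: dimension j is multiplied by e^{i * pos * theta_j}."""
    return x * np.exp(1j * pos * theta)


rng = np.random.default_rng(1)
d = 8
theta = rng.uniform(0, np.pi, d)
q, k = rng.standard_normal(d), rng.standard_normal(d)


def score(n, m):
    """q~_n k~_m^dagger: conjugating the rotated key leaves the phase e^{i (n - m) theta}."""
    return rotate(q, n, theta) @ np.conj(rotate(k, m, theta))


same_offset_a = score(5, 2)        # offset 3
same_offset_b = score(105, 102)    # offset 3, shifted by 100 positions
other_offset = score(6, 2)         # offset 4
closed_form = np.sum(q * k * np.exp(1j * 3 * theta))
print(np.isclose(same_offset_a, same_offset_b), np.isclose(same_offset_a, closed_form))
```

Dropping the `np.conj` would instead produce the phase $e^{i(n+m)\theta}$, which depends on absolute positions.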
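Reading out the output and unrolling the recurrence, written in terms of the rotated vectors:

$$S_n = \gamma\,S_{n-1} + \tilde k_n^\dagger v_n, \qquad o_n = \tilde q_n S_n = \sum_{m=1}^{n} \gamma^{n-m}\,\tilde q_n \tilde k_m^\dagger\, v_m$$

Token $m$'s contribution to the current output is weighted by $\gamma^{n-m}$, exponentially smaller the further back it is. The conjugate on $\tilde k_m$ is what makes the phase difference $e^{i(n-m)\theta}$ emerge.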
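A minimal sketch of recurrent retention for a single head, assuming NumPy complex arithmetic, real projected inputs `Q`, `K`, `V`, 0-indexed positions, and hypothetical values for $\gamma$ and $\theta$. Each step touches only the fixed-size state `S`, and the result is checked against the unrolled sum. The sketch keeps complex numbers for readability; an equivalent real-valued form with cos/sin pairs is also possible.

```python
import numpy as np


def recurrent_retention(Q, K, V, gamma, theta):
    """o_n = q~_n S_n with S_n = gamma * S_{n-1} + k~_n^dagger v_n (0-indexed positions)."""
    S = np.zeros((Q.shape[1], V.shape[1]), dtype=complex)  # fixed d x d_v state
    out = []
    for n in range(Q.shape[0]):
        phase = np.exp(1j * n * theta)
        q_rot, k_rot = Q[n] * phase, K[n] * phase
        S = gamma * S + np.outer(np.conj(k_rot), V[n])     # decay, then write k~^dagger v
        out.append(q_rot @ S)                               # read out with the rotated query
    return np.array(out)


def unrolled_retention(Q, K, V, gamma, theta):
    """o_n = sum_{m <= n} gamma^{n-m} (sum_j q_nj k_mj e^{i (n-m) theta_j}) v_m."""
    out = np.zeros((Q.shape[0], V.shape[1]), dtype=complex)
    for n in range(Q.shape[0]):
        for m in range(n + 1):
            rel = np.sum(Q[n] * K[m] * np.exp(1j * (n - m) * theta))
            out[n] += gamma ** (n - m) * rel * V[m]
    return out


rng = np.random.default_rng(0)
N, d = 7, 4
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
theta = rng.uniform(0, np.pi, d)
o_rec = recurrent_retention(Q, K, V, 0.9, theta)
o_unr = unrolled_retention(Q, K, V, 0.9, theta)
print(np.allclose(o_rec, o_unr))
```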
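For training, process the full sequence at once. Apply the rotations across all positions:

$$Q = (XW_Q) \odot \Theta, \qquad K = (XW_K) \odot \bar\Theta, \qquad V = XW_V$$

where $\Theta_n = e^{in\theta}$ and $\bar\Theta$ is its complex conjugate (so $K$ gets the conjugated rotation, matching the recurrent form).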
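The causal decay matrix $D \in \mathbb{R}^{N \times N}$ encodes both causal masking and decay:

$$D_{nm} = \begin{cases} \gamma^{\,n-m}, & n \ge m \\ 0, & n < m \end{cases} \qquad\qquad \operatorname{Retention}(X) = (QK^\top \odot D)\,V$$

Row $n$, column $m$ of $D$ is exactly the weight $\gamma^{n-m}$ that token $m$ gets when computing the output at position $n$. Future tokens ($m > n$) are zeroed.

The result is mathematically identical to the recurrent form, just computed all at once.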
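A NumPy sketch of the parallel form for one head, under the same assumptions as before (complex rotations, 0-indexed positions, hypothetical random weights and hypothetical $\gamma$, $\theta$). `parallel_retention` builds $\Theta$ and $D$ and evaluates $(QK^\top \odot D)V$ in one shot; `recurrent_reference` is a hypothetical token-by-token loop included only to check that both give the same outputs.

```python
import numpy as np


def parallel_retention(X, W_Q, W_K, W_V, gamma, theta):
    """(Q K^T * D) V with Q = (X W_Q) * Theta and K = (X W_K) * conj(Theta)."""
    N = X.shape[0]
    Theta = np.exp(1j * np.arange(N)[:, None] * theta)   # Theta_n = e^{i n theta}, shape (N, d)
    Q = (X @ W_Q) * Theta
    K = (X @ W_K) * np.conj(Theta)                      # keys get the conjugated rotation
    V = X @ W_V
    n, m = np.arange(N)[:, None], np.arange(N)[None, :]
    D = np.where(n >= m, gamma ** (n - m).astype(float), 0.0)  # causal mask + decay
    return (Q @ K.T * D) @ V


def recurrent_reference(X, W_Q, W_K, W_V, gamma, theta):
    """Token-by-token version of the same computation, used only as a check."""
    S = np.zeros((W_K.shape[1], W_V.shape[1]), dtype=complex)
    out = []
    for n, x in enumerate(X):
        phase = np.exp(1j * n * theta)
        S = gamma * S + np.outer(np.conj((x @ W_K) * phase), x @ W_V)
        out.append(((x @ W_Q) * phase) @ S)
    return np.array(out)


rng = np.random.default_rng(0)
N, d_model, d = 8, 6, 4
X = rng.standard_normal((N, d_model))
W_Q, W_K, W_V = (rng.standard_normal((d_model, d)) / np.sqrt(d_model) for _ in range(3))
theta = rng.uniform(0, np.pi, d)
out_par = parallel_retention(X, W_Q, W_K, W_V, 0.9, theta)
out_rec = recurrent_reference(X, W_Q, W_K, W_V, 0.9, theta)
print(np.allclose(out_par, out_rec))
```

The parallel version materializes an $N \times N$ matrix, which is exactly what makes it GPU-friendly for short sequences and memory-bound for long ones.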
The chunkwise form is the training-time workhorse for long sequences. The idea: split the sequence into non-overlapping chunks of length $B$, run the parallel form inside each chunk, and pass a recurrent state across chunks.
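Notation. The $i$-th chunk (0-indexed) extracts rows $Bi$ through $B(i+1)-1$:

$$Q_{[i]} = Q_{Bi:B(i+1)}, \qquad K_{[i]} = K_{Bi:B(i+1)}, \qquad V_{[i]} = V_{Bi:B(i+1)}$$

Each is a $B \times d$ matrix. The $D$ matrix is now $B \times B$ (causal decay within the chunk).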
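Output for chunk $i$ has two terms:

$$\operatorname{Retention}(X_{[i]}) = \underbrace{(Q_{[i]} K_{[i]}^\top \odot D)\,V_{[i]}}_{\text{intra-chunk}} + \underbrace{(Q_{[i]} R_{i-1}) \odot \xi}_{\text{cross-chunk}}$$

The intra-chunk term is a standard parallel retention over $B$ tokens, fully parallelizable just like section 2.2.

The cross-chunk term reads from the carried state $R_{i-1}$ (a $d \times d$ matrix summarizing all history before chunk $i$). The scaling matrix $\xi$ weights each position within the chunk by how far it is from the start of the chunk (broadcast across columns):

$$\xi_j = \gamma^{\,j+1}, \qquad j = 0, \dots, B-1$$

Position $j = 0$ (start of chunk) gets weight $\gamma$; position $j$ gets $\gamma^{j+1}$. Because $R_{i-1}$ is measured at the last token of the previous chunk and position $j$ lies $j+1$ steps after it, this ensures the cross-chunk contribution decays correctly relative to the global timeline.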
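State update across chunks:

$$R_i = K_{[i]}^\top \left(V_{[i]} \odot \zeta\right) + \gamma^{B} R_{i-1}, \qquad \zeta_j = \gamma^{\,B-j-1}, \qquad R_{-1} = 0$$

where $\zeta$ down-weights tokens earlier in the chunk before folding them into the state (they will be further in the past for future chunks). The full previous state is discounted by $\gamma^B$, one chunk length of decay.

$R_i$ is $d \times d$: the state has fixed size regardless of total sequence length.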
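A NumPy sketch of the chunkwise form, assuming $N$ is a multiple of $B$, 0-indexed chunks, and hypothetical random rotated inputs. All function names are illustrative. It compares the chunkwise output with the parallel form for several chunk sizes; $B = 1$ reduces to the token-by-token recurrence and $B = N$ to a single parallel pass.

```python
import numpy as np


def decay_matrix(n, gamma):
    """D_ij = gamma^(i-j) for i >= j, else 0."""
    i, j = np.arange(n)[:, None], np.arange(n)[None, :]
    return np.where(i >= j, gamma ** (i - j).astype(float), 0.0)


def parallel_retention(Q, K, V, gamma):
    return (Q @ K.T * decay_matrix(len(Q), gamma)) @ V


def chunkwise_retention(Q, K, V, gamma, B):
    """Parallel inside each chunk of B tokens, recurrent state R across chunks."""
    N, d = Q.shape
    assert N % B == 0, "this sketch assumes N is a multiple of B"
    D = decay_matrix(B, gamma)
    j = np.arange(B)
    xi = (gamma ** (j + 1.0))[:, None]        # decay from the previous chunk's last token to row j
    zeta = (gamma ** (B - j - 1.0))[:, None]  # decay from row j to the end of its own chunk
    R = np.zeros((d, V.shape[1]), dtype=np.result_type(Q, K, V))  # R_{-1} = 0
    outputs = []
    for s in range(0, N, B):
        q, k, v = Q[s:s + B], K[s:s + B], V[s:s + B]
        intra = (q @ k.T * D) @ v             # standard parallel retention within the chunk
        cross = (q @ R) * xi                  # read the carried history
        outputs.append(intra + cross)
        R = k.T @ (v * zeta) + gamma ** B * R  # fold this chunk into the state
    return np.concatenate(outputs)


rng = np.random.default_rng(0)
N, d = 8, 4
theta = rng.uniform(0, np.pi, d)
Theta = np.exp(1j * np.arange(N)[:, None] * theta)
Q = rng.standard_normal((N, d)) * Theta           # rotated queries
K = rng.standard_normal((N, d)) * np.conj(Theta)  # conjugate-rotated keys
V = rng.standard_normal((N, d))
ref = parallel_retention(Q, K, V, 0.9)
for B in (1, 2, 4, 8):
    print(B, np.allclose(chunkwise_retention(Q, K, V, 0.9, B), ref))
```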
Why this is efficient. For a sequence of length $N$ with chunk size $B$, there are $N/B$ chunks. Each chunk costs $O(B^2 d)$ for the intra-chunk parallel term and $O(B d^2)$ for the cross-chunk recurrent term. Total: $O(NBd + Nd^2)$, linear in $N$. Standard attention is $O(N^2 d)$. The cross-chunk recurrence is sequential but only involves $N/B$ state updates, which are cheap.
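A back-of-the-envelope counter for these costs, with constant factors dropped. The sizes in the usage line ($N = 32768$, $B = 512$, $d = 128$) are hypothetical, chosen only to show that the ratio to full attention works out to $N^2 d / (NBd + Nd^2) = N/(B + d)$.

```python
def chunkwise_vs_attention_cost(N, B, d):
    """Leading-order multiply counts for one head (constants dropped)."""
    n_chunks = N // B
    intra = n_chunks * B * B * d   # Q K^T and (.) V within each chunk: O(B^2 d)
    cross = n_chunks * B * d * d   # Q R and K^T V per chunk: O(B d^2)
    attention = N * N * d          # full N x N score matrix: O(N^2 d)
    return intra + cross, attention


chunk, attn = chunkwise_vs_attention_cost(N=32768, B=512, d=128)
print(chunk, attn, attn / chunk)   # ratio equals N / (B + d) = 32768 / 640 = 51.2
```

Doubling $N$ doubles the chunkwise count but quadruples the attention count, which is the linear-versus-quadratic gap in concrete form.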
| | Intra-chunk | Cross-chunk | State size |
|---|---|---|---|
| Compute | $O(B^2 d)$ per chunk, parallel | $O(B d^2)$ per chunk, sequential | n/a |
| Memory | $B \times B$ score matrix | $d \times d$ carried state | Fixed, independent of $N$ |
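Raw retention scores can vary wildly in magnitude, especially as sequence length grows or as $\gamma$ varies per head. The paper exploits a key property: GroupNorm is scale-invariant, $\operatorname{GroupNorm}(\alpha \cdot \text{head}) = \operatorname{GroupNorm}(\text{head})$ for any positive scalar $\alpha$ (ignoring GroupNorm's small $\epsilon$). So we can insert normalizing scale factors at intermediate steps without changing the final output.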
Three stabilization steps, applied in sequence:
Step 1: Scale the scores.

$$QK^\top \;\to\; \frac{QK^\top}{\sqrt{d}}$$

Standard scaled dot product to prevent large inner products (same as in a Transformer).

Step 2: Normalize the decay matrix $D$.

$$\tilde D_{nm} = \frac{D_{nm}}{\sqrt{\sum_{i} D_{ni}}}$$

Divide each row of $D$ by the square root of its row sum. This normalizes the total weight each output position accumulates from the past, preventing early positions (which see fewer past tokens) from having systematically different magnitudes than late positions.

Step 3: Normalize the retention scores $R = QK^\top \odot \tilde D$ (not to be confused with the chunk state $R_i$).

$$\tilde R_{nm} = \frac{R_{nm}}{\max\!\left(\left|\sum_{i} R_{ni}\right|,\ 1\right)}$$

Divide each row of $R$ by its absolute row sum, clamped to at least 1. The clamp prevents division by near-zero values when the row sum is small. This keeps the magnitude of the aggregated output bounded regardless of sequence length.
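GroupNorm normalizes each token's output within each head, so it absorbs any positive rescaling of a whole row. Each of the three steps only rescales whole rows of the score matrix by a positive factor, so together they leave the mathematical output unchanged while keeping intermediate activations numerically well-behaved in both the forward and backward passes.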
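A NumPy sketch of the three steps for one head, with simplifying assumptions: real-valued $q$, $k$ (rotations omitted), a GroupNorm without learnable affine parameters, and a tiny `eps` so the scale invariance is visible numerically. All names are hypothetical. Because each step only rescales rows by a positive factor, the raw and stabilized outputs coincide after per-row normalization.

```python
import numpy as np


def group_norm_rows(H, eps=1e-6):
    """GroupNorm for one head, without affine parameters: normalize each token's row."""
    mu = H.mean(axis=-1, keepdims=True)
    var = H.var(axis=-1, keepdims=True)
    return (H - mu) / np.sqrt(var + eps)


def retention_head(Q, K, V, D, stabilize=True):
    """Real-valued single-head retention (rotations omitted) with optional stabilization."""
    if not stabilize:
        return (Q @ K.T * D) @ V
    scores = Q @ K.T / np.sqrt(Q.shape[1])                         # step 1: scaled dot product
    D_tilde = D / np.sqrt(D.sum(axis=1, keepdims=True))            # step 2: normalize rows of D
    R = scores * D_tilde
    R = R / np.maximum(np.abs(R.sum(axis=1, keepdims=True)), 1.0)  # step 3: clamp-normalize rows
    return R @ V


rng = np.random.default_rng(0)
N, d, gamma = 16, 8, 0.9
Q, K, V = (3.0 * rng.standard_normal((N, d)) for _ in range(3))
i, j = np.arange(N)[:, None], np.arange(N)[None, :]
D = np.where(i >= j, gamma ** (i - j).astype(float), 0.0)

raw = retention_head(Q, K, V, D, stabilize=False)
stable = retention_head(Q, K, V, D, stabilize=True)
# Each step only rescales whole rows by positive factors, so GroupNorm erases the difference
print(np.allclose(group_norm_rows(raw, eps=1e-12), group_norm_rows(stable, eps=1e-12)))
print(np.abs(raw).max(), np.abs(stable).max())
```

The second print line compares the largest raw and stabilized activations, which is where the steps pay off in practice.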
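Just like multi-head attention, retention is computed in parallel across heads. The “multi-scale” part: each head gets its own $\gamma$, giving different memory horizons:

$$\gamma = 1 - 2^{-5 - \operatorname{arange}(0,\,h)} \in \mathbb{R}^{h}$$

For $h$ heads this gives decays from $1 - 2^{-5} \approx 0.969$ up to $1 - 2^{-4-h}$. Heads with $\gamma$ close to 1 retain information over long ranges; heads with smaller $\gamma$ focus on local context. The model uses both simultaneously.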
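A tiny sketch of the per-head decays, assuming a hypothetical $h = 8$. The sum of all decay weights $\sum_{k \ge 0} \gamma^k = 1/(1-\gamma) = 2^{5+i}$ gives a rough sense of each head's memory horizon, from 32 tokens for the first head to 4096 for the last.

```python
import numpy as np


def head_gammas(h):
    """Per-head decay gamma_i = 1 - 2^(-5 - i), i = 0..h-1."""
    return 1.0 - 2.0 ** (-5 - np.arange(h))


gammas = head_gammas(8)               # h = 8 is a hypothetical head count
total_weight = 1.0 / (1.0 - gammas)   # sum_k gamma^k: an "effective window" in tokens
for g, w in zip(gammas, total_weight):
    print(f"gamma = {g:.6f}   effective window ~ {w:.0f} tokens")
```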
Different $\gamma$ values produce activations with different variances. A single LayerNorm would conflate them. GroupNorm with $h$ groups normalizes each head independently, correcting each head’s variance before concatenation.
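Full MSR block:

$$\begin{aligned}
\text{head}_i &= \operatorname{Retention}(X, \gamma_i) \\
Y &= \operatorname{GroupNorm}_h\big(\operatorname{Concat}(\text{head}_1, \dots, \text{head}_h)\big) \\
\operatorname{MSR}(X) &= \big(\operatorname{swish}(X W_G) \odot Y\big)\, W_O
\end{aligned}$$

The Swish gate $\operatorname{swish}(XW_G)$ is a learned, input-dependent signal that modulates the retention output before the final projection $W_O$. Similar to SwiGLU in FFN layers, it adds non-linearity and improves model capacity.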
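A compact NumPy sketch of the full MSR block. Several simplifications are assumptions of this sketch rather than details from the paper: real-valued heads with the rotation omitted, GroupNorm without learnable affine parameters, random square weights, and no batching. The usage lines check causality: editing later tokens leaves earlier outputs unchanged.

```python
import numpy as np


def swish(x):
    return x / (1.0 + np.exp(-x))


def group_norm_rows(H, eps=1e-6):
    """Per-token normalization of one head's output (no affine parameters)."""
    mu = H.mean(axis=-1, keepdims=True)
    return (H - mu) / np.sqrt(H.var(axis=-1, keepdims=True) + eps)


def retention(Q, K, V, gamma):
    """Real-valued parallel retention for one head (rotations omitted for brevity)."""
    n, m = np.arange(len(Q))[:, None], np.arange(len(Q))[None, :]
    D = np.where(n >= m, gamma ** (n - m).astype(float), 0.0)
    return (Q @ K.T * D) @ V


def msr(X, W_Q, W_K, W_V, W_G, W_O, h):
    """Multi-scale retention: one gamma per head, per-head norm, swish gate, output projection."""
    d_head = W_Q.shape[1] // h
    gammas = 1.0 - 2.0 ** (-5 - np.arange(h))
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    heads = []
    for i in range(h):
        cols = slice(i * d_head, (i + 1) * d_head)
        heads.append(group_norm_rows(retention(Q[:, cols], K[:, cols], V[:, cols], gammas[i])))
    Y = np.concatenate(heads, axis=1)        # GroupNorm with h groups == normalizing each head
    return (swish(X @ W_G) * Y) @ W_O


rng = np.random.default_rng(0)
N, d_model, h = 10, 16, 4
W_Q, W_K, W_V, W_G, W_O = (rng.standard_normal((d_model, d_model)) / np.sqrt(d_model) for _ in range(5))
X = rng.standard_normal((N, d_model))
out = msr(X, W_Q, W_K, W_V, W_G, W_O, h)

# Causality check: changing later tokens must not change earlier outputs
X2 = X.copy()
X2[6:] = rng.standard_normal((N - 6, d_model))
out2 = msr(X2, W_Q, W_K, W_V, W_G, W_O, h)
print(out.shape, np.allclose(out[:6], out2[:6]))
```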
All three forms compute the same function. The choice is purely about what is efficient given your hardware and sequence length.
Recurrent unrolls as a step-by-step state update. Token $n$ arrives, updates $S_n$, and immediately produces output $o_n$. You never store past tokens, only the state.

Parallel treats the entire sequence as a single matrix operation. No recurrence, no state. The $D$ matrix encodes all the decay weights at once. This is what GPUs are built for.

Chunkwise splits the sequence into blocks of size $B$. Inside each block: parallel. Across blocks: recurrent state $R_i$. It is a controlled interpolation between the two.
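| | Training compute | Training memory | Inference compute/step | Inference memory |
|---|---|---|---|---|
| Parallel | $O(N^2 d)$ | $O(N^2)$ | not used | n/a |
| Recurrent | $O(N d^2)$, sequential | $O(N d^2)$ | $O(d^2)$ | $O(d^2)$ |
| Chunkwise | $O(NBd + Nd^2)$ | $O(NB + (N/B)\,d^2)$ | not used | n/a |
| Transformer | $O(N^2 d)$ | $O(N^2)$ | $O(N d)$ | $O(N d)$ KV cache |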
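When $N \gg B + d$, the ratio of parallel to chunkwise training compute is $N^2 d / (NBd + Nd^2) = N/(B + d)$, so chunkwise is roughly $N/(B + d)$ times cheaper than parallel training.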
Recurrent is used at inference time. You process one token at a time, keep $S_n$ in memory, and output $o_n$. No KV cache: memory is $O(d^2)$ per head and fixed.

Parallel is used for short-sequence training. One shot, maximum GPU utilization, but $O(N^2)$ memory means it does not scale past a few thousand tokens.

Chunkwise is used for long-sequence training. The intra-chunk $B \times B$ matrix stays small, and the cross-chunk state $R_i$ does not grow with sequence length. For sequences in the tens of thousands, this is the only practical option.

The three forms are not approximations of each other; they are algebraically identical. The parallel form is the chunkwise form with $B = N$ (one chunk = whole sequence). The recurrent form is the chunkwise form with $B = 1$ (one chunk = one token). Chunkwise with intermediate $B$ just picks a sweet spot for GPU efficiency.
This differs from the Transformer-versus-RNN tradeoff, where you have to pick one architecture and accept its training or inference cost. In RetNet you train with the chunkwise form (parallel within chunks, recurrent across chunks) and switch to the recurrent form at inference: same weights, same math, different compute schedule.
WIP