RetNet: Retentive Network

RetNet is a foundation architecture for LLMs that achieves training parallelism, low-cost inference, and good performance simultaneously, three goals often described as the “impossible triangle” of sequence modeling.

The key contribution is the Retention mechanism, which has a dual form of recurrence and parallelism. This means we can train models in parallel (like Transformers) while doing inference recurrently (like RNNs).

Three computation paradigms:

  • Parallel — training parallelism on GPU
  • Recurrent — O(1) inference, reducing decode latency and GPU memory
  • Chunkwise Recurrent — efficient long-sequence modeling

Paper: https://arxiv.org/abs/2307.08621


Comparison with Other Architectures

The three main sequence modeling families each make a different tradeoff:

| | Training | Inference memory | Inference cost per step | Positional encoding |
|---|---|---|---|---|
| Transformer | Parallel | $O(n)$, KV cache grows with sequence | $O(n)$ per token | RoPE / ALiBi / learned |
| RNN (RWKV, SSM) | Sequential | $O(1)$, fixed state | $O(1)$ per token | Built into recurrence |
| RetNet | Parallel | $O(1)$, fixed state $S_n$ | $O(1)$ per token | Built into the diagonalization of $A$ |

vs. Transformer. Attention computes $\text{softmax}(QK^T / \sqrt{d})\, V$ over the full context at every step. Every new token attends to all previous tokens, so inference memory and compute grow linearly with sequence length. RetNet replaces softmax with the decay matrix $D$ (no normalization across positions, just exponential weighting) and compresses the entire history into a fixed-size state $S_n$. The tradeoff: Transformers can attend to any past token equally; RetNet exponentially forgets distant tokens at a rate controlled by $\gamma$.

vs. RNNs / RWKV / SSMs. These also maintain a fixed state and have $O(1)$ inference, but a classical recurrence is trained sequentially: each step depends on the previous one, so training cannot be parallelized across the sequence. (Some SSMs, such as S4, sidestep this by training in a convolutional form.) RetNet gets both: because the recurrence $S_n = \gamma S_{n-1} + K_n^T V_n$ can be unrolled into a closed-form sum (the parallel form), the full sequence can be computed with a few matrix multiplies during training, just like Transformers.

vs. Linear Attention. Linear attention replaces softmax with a kernel $\phi(Q)\phi(K)^T$ and also has $O(1)$ recurrent inference. RetNet is similar in spirit but adds two things linear attention lacks: (1) the exponential decay $\gamma^{n-m}$, which gives a proper forgetting mechanism rather than equal weighting of all past tokens, and (2) the complex rotations from diagonalizing $A$, which give relative position encoding for free.


Variables at a Glance

| Symbol | Shape | What it is |
|---|---|---|
| $X$ | $\lvert x \rvert \times d_{\text{model}}$ | Full input sequence, one row per token |
| $W_Q, W_K, W_V$ | $d \times d$ | Learned projection matrices |
| $Q_n, K_n, V_n$ | $1 \times d$ | Query, Key, Value vectors for token $n$ |
| $S_n$ | $d \times d$ | Retention state: running memory accumulating all past KV pairs |
| $o_n$ | $1 \times d$ | Output for token $n$, read from the state |
| $A$ | $d \times d$ | Recurrence matrix: controls how the state evolves |
| $\gamma$ | scalar $\in (0,1)$ | Decay factor: eigenvalue magnitude of $A$ after diagonalization |
| $\theta$ | scalar (one $\theta_j$ per dimension) | Rotation frequency: eigenvalue angle of $A$, encodes relative position |
| $\Lambda$ | $d \times d$, invertible | Change-of-basis that diagonalizes $A$, absorbed into $W_Q$ / $W_K$ |
| $D$ | $\lvert x \rvert \times \lvert x \rvert$ | Causal decay matrix: combines mask and exponential decay |

The intuition: $S_n$ is the memory at step $n$. It accumulates all past $(K_m, V_m)$ pairs, exponentially forgetting older ones via $\gamma$. The output $o_n$ is a query into that memory.


2.1 Retention: Recurrent Form

The input sequence $\{x_i\}$ is packed into a matrix $X \in \mathbb{R}^{|x| \times d_{\text{model}}}$.

The general recurrence. The retention state evolves as:

$$S_n = A \, S_{n-1} + K_n^T V_n$$

$S_n$ is a $d \times d$ matrix, an outer-product memory. Each token writes $K_n^T V_n$ into it (store this key-value association), while $A$ decays what is already there.

Why diagonalize $A$? A full $d \times d$ matrix $A$ is expensive and hard to train stably. The paper constrains it to be diagonalizable:

$$A = \Lambda \, (\gamma \, e^{i\Theta}) \, \Lambda^{-1}$$

$\gamma e^{i\Theta}$ is diagonal with entries $\gamma e^{i\theta_j}$: a magnitude $\gamma$ and a rotation $\theta_j$ per dimension. Absorbing $\Lambda$ into $W_Q$ and $\Lambda^{-1}$ into $W_K$ (they are just linear projections), with the phase $e^{i\Theta}$ moved onto $Q$ and $K$, $A$ disappears from the recurrence:

$$S_n = \gamma \, S_{n-1} + K_n^T V_n$$

So $\gamma$ is not an arbitrary extra knob: it is the eigenvalue magnitude of $A$. The state update is now $O(d^2)$ per step (a scalar decay plus one outer product), with no $d \times d$ matrix multiply.
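The simplified recurrence can be sketched in a few lines of NumPy. This is a minimal illustration rather than the paper's implementation: it assumes a single head, real-valued $Q$, $K$, $V$ (the complex rotations are left out), random inputs, and a hypothetical helper name `retention_step`.

```python
import numpy as np

def retention_step(S, q, k, v, gamma):
    """One decode step (hypothetical helper, real-valued, single head).

    S is the (d, d) state; q, k, v are the (d,) vectors of the current token.
    """
    S = gamma * S + np.outer(k, v)   # S_n = gamma * S_{n-1} + K_n^T V_n
    o = q @ S                        # o_n = Q_n S_n
    return S, o

rng = np.random.default_rng(0)
n, d, gamma = 6, 4, 0.9              # toy sizes, chosen only for illustration
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))

S = np.zeros((d, d))                 # the only memory carried between tokens
outputs = []
for t in range(n):
    S, o = retention_step(S, Q[t], K[t], V[t], gamma)
    outputs.append(o)
outputs = np.stack(outputs)          # shape (n, d)
```

The state `S` keeps its `(d, d)` shape no matter how many tokens have been processed, which is the property the recurrent form relies on at inference time.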

Position encoding comes for free. With $\Lambda$ absorbed into the projections, the remaining phase factors $e^{i\Theta}$ act as xPos-style complex rotations on $Q$ and $K$:

$$Q_n \leftarrow Q_n e^{in\theta}, \qquad K_m \leftarrow K_m e^{im\theta}$$

The inner product $Q_n K_m^*$ then gives $e^{i(n-m)\theta}$: the relative position $n-m$ is baked in. No separate RoPE layer is needed on $Q$ and $K$; the diagonalization of $A$ handles it.
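A tiny numeric check of the relative-position property, using only the standard library. The values of `theta`, `q`, and `k` are made up for illustration, and `rotated_score` is a hypothetical helper covering a single complex feature.

```python
import cmath

theta = 0.3                       # illustrative rotation frequency (assumed value)
q, k = 1.5 + 0.5j, -0.7 + 2.0j    # one complex feature of a query and a key (made up)

def rotated_score(n, m):
    """Score between a query at position n and a key at position m."""
    q_n = q * cmath.exp(1j * n * theta)   # Q_n <- Q_n e^{i n theta}
    k_m = k * cmath.exp(1j * m * theta)   # K_m <- K_m e^{i m theta}
    return q_n * k_m.conjugate()          # phase becomes e^{i (n - m) theta}

same_offset_a = rotated_score(5, 2)     # n - m = 3
same_offset_b = rotated_score(10, 7)    # n - m = 3, both positions shifted by 5
other_offset = rotated_score(5, 4)      # n - m = 1
```

The first two scores agree because only the offset $n-m$ survives the conjugate product; the third differs because its offset is different.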

Reading out the output and unrolling the recurrence:

$$o_n = Q_n S_n = \sum_{m=1}^{n} \gamma^{n-m} \left( Q_n e^{in\theta} \right) \left( K_m e^{im\theta} \right)^* V_m$$

Token $m$’s contribution to the current output is weighted by $\gamma^{n-m}$, exponentially smaller the further back it is. The conjugate $(\cdot)^*$ on $K_m$ is what makes the phase difference $e^{i(n-m)\theta}$ emerge.
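The unrolled sum can be checked against the step-by-step recurrence numerically. The sketch below is an illustration under assumptions: complex arithmetic is used directly (practical implementations typically use an equivalent real-valued rotation), the frequencies in `theta` are random placeholders, and all sizes are toy values.

```python
import numpy as np

rng = np.random.default_rng(1)
n, d, gamma = 5, 4, 0.8
theta = rng.uniform(0, np.pi, size=d)          # one frequency per dimension (assumed values)
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))

pos = np.arange(1, n + 1)[:, None]             # positions 1..n as a column
Qr = Q * np.exp(1j * pos * theta)              # Q_n e^{i n theta}
Kr = K * np.exp(1j * pos * theta)              # K_m e^{i m theta}

# Recurrent form: S_n = gamma * S_{n-1} + conj(rotated K_n)^T V_n, o_n = (rotated Q_n) S_n
S = np.zeros((d, d), dtype=complex)
recurrent = np.zeros((n, d), dtype=complex)
for t in range(n):
    S = gamma * S + np.outer(np.conj(Kr[t]), V[t])
    recurrent[t] = Qr[t] @ S

# Unrolled form: o_n = sum_m gamma^(n-m) * (Qr_n . conj(Kr_m)) * V_m
unrolled = np.zeros((n, d), dtype=complex)
for t in range(n):
    for m in range(t + 1):
        unrolled[t] += gamma ** (t - m) * (Qr[t] @ np.conj(Kr[m])) * V[m]

max_err = np.abs(recurrent - unrolled).max()   # difference between the two forms

# Shifting every position by the same amount leaves all scores unchanged,
# because only the offset n - m enters the phase.
shift = 7
Qs = Q * np.exp(1j * (pos + shift) * theta)
Ks = K * np.exp(1j * (pos + shift) * theta)
shift_err = np.abs(Qr @ np.conj(Kr).T - Qs @ np.conj(Ks).T).max()
```

Both `max_err` and `shift_err` should sit at floating-point rounding level.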


2.2 Retention: Parallel Form

For training, process the full sequence at once. Apply rotations across all positions:

$$Q = (X W_Q) \odot \Theta, \quad K = (X W_K) \odot \bar{\Theta}, \quad V = X W_V$$

where $\Theta_n = e^{in\theta}$ and $\bar{\Theta}$ is its conjugate (so $K$ gets the conjugated rotation, matching the recurrent form).

The causal decay matrix $D \in \mathbb{R}^{|x| \times |x|}$ encodes both causal masking and $\gamma$ decay:

$$D_{nm} = \begin{cases} \gamma^{n-m} & n \geq m \\ 0 & n < m \end{cases}$$

Row $n$, column $m$ of $D$ is exactly the weight $\gamma^{n-m}$ that token $m$ gets when computing output at position $n$. Future tokens are zeroed.
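As a concrete picture of $D$, here is a small worked example for four tokens. The decay $\gamma = 0.5$ is chosen only so the entries are easy to read, and `decay_matrix` is a hypothetical helper name.

```python
def decay_matrix(n, gamma):
    """D[r][c] = gamma**(r - c) on and below the diagonal, 0 above it."""
    return [[gamma ** (r - c) if r >= c else 0.0 for c in range(n)]
            for r in range(n)]

D = decay_matrix(4, 0.5)
for row in D:
    print(row)
# [1.0, 0.0, 0.0, 0.0]
# [0.5, 1.0, 0.0, 0.0]
# [0.25, 0.5, 1.0, 0.0]
# [0.125, 0.25, 0.5, 1.0]
```

Each row reads right to left as "current token, one step back, two steps back", with the weight halving at every step; the zeros above the diagonal are the causal mask.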

$$\text{Retention}(X) = (Q K^T \odot D) V$$

This is mathematically identical to the recurrent form, just computed all at once.
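The equivalence can be checked numerically. The following is a minimal sketch under simplifying assumptions: one head, real-valued inputs with the rotations omitted, random toy data, and hypothetical function names.

```python
import numpy as np

def retention_parallel(Q, K, V, gamma):
    """Parallel form: (Q K^T ⊙ D) V, with D the causal decay matrix."""
    n = Q.shape[0]
    idx = np.arange(n)
    diff = idx[:, None] - idx[None, :]                         # n - m for every pair
    D = np.where(diff >= 0, gamma ** np.maximum(diff, 0), 0.0)
    return (Q @ K.T * D) @ V

def retention_recurrent(Q, K, V, gamma):
    """Recurrent form: S_n = gamma * S_{n-1} + K_n^T V_n, o_n = Q_n S_n."""
    n, d = Q.shape
    S = np.zeros((d, V.shape[1]))
    out = np.zeros((n, V.shape[1]))
    for t in range(n):
        S = gamma * S + np.outer(K[t], V[t])
        out[t] = Q[t] @ S
    return out

rng = np.random.default_rng(2)
Q, K, V = (rng.standard_normal((7, 4)) for _ in range(3))
out_parallel = retention_parallel(Q, K, V, gamma=0.9)
out_recurrent = retention_recurrent(Q, K, V, gamma=0.9)
max_err = np.abs(out_parallel - out_recurrent).max()   # rounding-level difference
```

The parallel version materializes an $n \times n$ score matrix, while the recurrent version only ever holds the $d \times d$ state.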


2.3 Retention: Chunkwise Recurrent Form

The chunkwise form is the training-time workhorse for long sequences. The idea: split the sequence into non-overlapping chunks of length $B$, run the parallel form inside each chunk, and pass a recurrent state $R_i$ across chunks.

Notation. The $i$-th chunk extracts:

$$Q_{[i]} = Q_{Bi\,:\,B(i+1)}, \quad K_{[i]} = K_{Bi\,:\,B(i+1)}, \quad V_{[i]} = V_{Bi\,:\,B(i+1)}$$

Each is a $B \times d$ matrix. The $D$ matrix is now $B \times B$ (causal decay within the chunk).

Output for chunk $i$, in two terms:

$$\text{Retention}(X)_{[i]} = \underbrace{(Q_{[i]} K_{[i]}^T \odot D)\, V_{[i]}}_{\text{intra-chunk}} + \underbrace{(Q_{[i]}\, R_{i-1}) \odot \xi}_{\text{cross-chunk}}$$

The intra-chunk term is a standard parallel retention over $B$ tokens. It is fully parallelizable, just like section 2.2.

The cross-chunk term reads from the carried state $R_{i-1}$ (a $d \times d$ matrix summarizing all history before chunk $i$). The scaling factor $\xi$ (one value per row, broadcast across the $d$ columns) weights each position within the chunk by how far it is from the start of the chunk:

$$\xi_{j} = \gamma^{j+1}, \quad j = 0, \ldots, B-1$$

Position $j=0$ (start of chunk) gets weight $\gamma^1$; position $j=B-1$ gets $\gamma^B$. This ensures the cross-chunk contribution decays correctly relative to the global timeline.

State update across chunks:

$$R_i = K_{[i]}^T \left(V_{[i]} \odot \zeta\right) + \gamma^B\, R_{i-1}$$

where $\zeta_j = \gamma^{B-j-1}$ down-weights tokens earlier in the chunk before folding them into the state (they will be further in the past for future chunks). The full previous state $R_{i-1}$ is discounted by $\gamma^B$, one chunk length of decay.

The initial state is $R_0 = 0$. The state has fixed size $d \times d$ regardless of total sequence length.
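The chunk output and state update above can be sketched as follows. This is an illustrative NumPy version under stated assumptions: one head, real-valued inputs without rotations, a chunk size that divides the sequence length, and hypothetical helper names `decay_mask` and `retention_chunkwise`. It compares against the full parallel form for three chunk sizes, including the extremes $B = 1$ and $B = n$.

```python
import numpy as np

def decay_mask(n, gamma):
    """Causal decay matrix: gamma**(r - c) for r >= c, else 0."""
    idx = np.arange(n)
    diff = idx[:, None] - idx[None, :]
    return np.where(diff >= 0, gamma ** np.maximum(diff, 0), 0.0)

def retention_chunkwise(Q, K, V, gamma, B):
    n, d = Q.shape
    if n % B != 0:
        raise ValueError("this sketch assumes the chunk size divides the sequence length")
    D = decay_mask(B, gamma)                     # B x B decay inside a chunk
    j = np.arange(B)
    xi = (gamma ** (j + 1))[:, None]             # xi_j = gamma^(j+1), one value per row
    zeta = (gamma ** (B - j - 1))[:, None]       # zeta_j = gamma^(B-j-1)
    R = np.zeros((d, V.shape[1]))                # R_0 = 0
    out = np.zeros_like(V)
    for start in range(0, n, B):
        q, k, v = Q[start:start + B], K[start:start + B], V[start:start + B]
        intra = (q @ k.T * D) @ v                # parallel retention inside the chunk
        cross = (q @ R) * xi                     # read the carried state, decayed per row
        out[start:start + B] = intra + cross
        R = k.T @ (v * zeta) + gamma ** B * R    # fold this chunk into the state
    return out

rng = np.random.default_rng(3)
n, d, gamma = 8, 4, 0.9
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
reference = (Q @ K.T * decay_mask(n, gamma)) @ V  # full parallel form
errors = {B: np.abs(retention_chunkwise(Q, K, V, gamma, B) - reference).max()
          for B in (1, 4, n)}
```

All three entries of `errors` should be at rounding level, since changing $B$ only changes the compute schedule.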

Why this is efficient. For a sequence of length $n$ with chunk size $B$, there are $n/B$ chunks. Each chunk costs $O(B^2 d)$ for the intra-chunk parallel term and $O(B d^2)$ for the cross-chunk recurrent term. Total: $(n/B) \cdot O(B^2 d + B d^2) = O(n d (B + d))$, linear in $n$. Standard attention is $O(n^2 d)$. The cross-chunk recurrence is sequential but only involves $d \times d$ updates, which are cheap.

| | Intra-chunk | Cross-chunk | State size |
|---|---|---|---|
| Compute | $O(B^2 d)$ per chunk, parallel | $O(B d^2)$ per chunk, sequential | |
| Memory | $O(B^2)$ attention matrix | $O(d^2)$ | fixed $d \times d$ |

2.5 Retention Score Normalization

Raw retention scores can vary wildly in magnitude, especially as sequence length grows or as $\gamma$ varies per head. The paper exploits a key property: GroupNorm is scale-invariant, $\text{GroupNorm}(\alpha \cdot h) = \text{GroupNorm}(h)$ for any positive scalar $\alpha$ (up to the small epsilon inside the norm). So we can insert normalizing scale factors at intermediate steps without changing the final output.

Three stabilization steps, applied in sequence:

Step 1. Scale the scores:

$$\frac{QK^T}{\sqrt{d}}$$

Standard scaled dot-product to prevent large inner products (same as Transformer).

Step 2. Normalize the decay matrix $D$:

$$\hat{D}_{nm} = \frac{D_{nm}}{\sqrt{\sum_{i=1}^{n} D_{ni}}}$$

Divide each row of $D$ by the square root of its row sum. This normalizes the total weight each output position accumulates from the past, preventing early positions (which see fewer past tokens) from having systematically different magnitudes than late positions.

Step 3. Normalize the retention scores $R = QK^T \odot \hat{D}$:

$$\hat{R}_{nm} = \frac{R_{nm}}{\max\!\left(\left|\sum_{i=1}^{n} R_{ni}\right|,\; 1\right)}$$

Divide each row of $R$ by the absolute value of its row sum, clamped to at least 1. The clamp prevents division by near-zero values when the row sum is small. This keeps the magnitude of the aggregated output bounded regardless of sequence length.

Because GroupNorm absorbs any positive per-row rescaling, all three steps leave the mathematical output unchanged while keeping intermediate activations numerically well-behaved for both forward and backward passes.
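The invariance claim can be checked on toy data. The sketch below makes several assumptions that are not from the paper: a single head, real-valued random inputs, and a stand-in `standardize_rows` for GroupNorm that normalizes each token's output vector without affine parameters or epsilon.

```python
import numpy as np

def standardize_rows(H):
    """Stand-in for GroupNorm on one head (no affine, no epsilon):
    each token's d-dimensional output gets zero mean and unit variance."""
    return (H - H.mean(axis=1, keepdims=True)) / H.std(axis=1, keepdims=True)

rng = np.random.default_rng(4)
n, d, gamma = 6, 8, 0.9
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
idx = np.arange(n)
diff = idx[:, None] - idx[None, :]
D = np.where(diff >= 0, gamma ** np.maximum(diff, 0), 0.0)

# Retention without any stabilization
plain = (Q @ K.T * D) @ V

# Step 1: scale the scores
scores = Q @ K.T / np.sqrt(d)
# Step 2: divide each row of D by the square root of its row sum
D_hat = D / np.sqrt(D.sum(axis=1, keepdims=True))
# Step 3: divide each row of R by max(|row sum|, 1)
R = scores * D_hat
R_hat = R / np.maximum(np.abs(R.sum(axis=1, keepdims=True)), 1.0)
stabilized = R_hat @ V

max_err = np.abs(standardize_rows(plain) - standardize_rows(stabilized)).max()
```

Each step multiplies every row by a positive constant, so `plain` and `stabilized` differ before normalization but agree after it.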


2.6 Multi-Scale Retention (MSR)

Just like multi-head attention, retention is computed in parallel across $h = d_{\text{model}} / d$ heads. The “multi-scale” part: each head gets its own $\gamma_i$, giving different memory horizons:

$$\gamma = 1 - 2^{-5 - \text{arange}(0,\, h)}$$

For $h = 8$ this gives $\gamma \in \{1 - \tfrac{1}{32},\; 1 - \tfrac{1}{64},\; \ldots,\; 1 - \tfrac{1}{4096}\}$. Heads with $\gamma$ close to 1 retain information over long ranges; heads with smaller $\gamma$ focus on local context. The model uses both simultaneously.
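To make the per-head decay values concrete, the formula can be evaluated directly for $h = 8$. The "horizon" $1/(1-\gamma)$ used below is the sum of the geometric weights $1 + \gamma + \gamma^2 + \ldots$; it is offered here only as a rough illustrative measure of memory length, not as a quantity defined in the paper.

```python
h = 8
gammas = [1 - 2 ** (-5 - i) for i in range(h)]   # gamma_i = 1 - 2^(-5 - i)

# Sum of the weights 1 + gamma + gamma^2 + ... = 1 / (1 - gamma),
# used as a rough "effective horizon" in tokens (illustrative measure only).
horizons = [1 / (1 - g) for g in gammas]

for i, (g, hz) in enumerate(zip(gammas, horizons)):
    print(f"head {i}: gamma = {g:.6f}, horizon ~ {hz:.0f} tokens")
```

The horizons double from head to head, from 32 tokens for the first head up to 4096 for the last.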

Different $\gamma$ values produce activations with different variances. A single LayerNorm would conflate them. GroupNorm with $h$ groups normalizes each head independently, correcting each head’s variance before concatenation.

Full MSR block:

$$\text{head}_i = \text{Retention}(X,\; \gamma_i)$$

$$Y = \text{GroupNorm}_h\!\left(\text{Concat}(\text{head}_1, \ldots, \text{head}_h)\right)$$

$$\text{MSR}(X) = \big(\text{swish}(X W_G) \odot Y\big)\, W_O$$

The Swish gate $\text{swish}(X W_G)$ is a learned input-dependent signal that modulates the retention output before the final projection. Similar to SwiGLU in FFN layers, it adds non-linearity and improves model capacity.
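Putting the three equations together, a forward pass of the block might look like the sketch below. It is a simplified illustration, not the reference implementation: the weights are random, the rotations and score stabilization are omitted, GroupNorm is applied per head and per token without learned affine parameters, and the function name `msr` and all sizes are assumptions.

```python
import numpy as np

def swish(x):
    return x / (1.0 + np.exp(-x))               # x * sigmoid(x)

def msr(X, W_Q, W_K, W_V, W_G, W_O, gammas, eps=1e-6):
    """Multi-scale retention over one sequence X of shape (n, d_model)."""
    n, d_model = X.shape
    h = len(gammas)
    d = d_model // h                            # per-head width
    Q = (X @ W_Q).reshape(n, h, d)
    K = (X @ W_K).reshape(n, h, d)
    V = (X @ W_V).reshape(n, h, d)
    idx = np.arange(n)
    diff = idx[:, None] - idx[None, :]
    heads = []
    for i, g in enumerate(gammas):
        D = np.where(diff >= 0, g ** np.maximum(diff, 0), 0.0)   # per-head decay mask
        head = (Q[:, i] @ K[:, i].T * D) @ V[:, i]               # (n, d)
        # GroupNorm with h groups: standardize each head's slice per token
        mean = head.mean(axis=-1, keepdims=True)
        var = head.var(axis=-1, keepdims=True)
        heads.append((head - mean) / np.sqrt(var + eps))
    Y = np.concatenate(heads, axis=-1)                           # (n, d_model)
    return (swish(X @ W_G) * Y) @ W_O                            # gate, then project

rng = np.random.default_rng(5)
n, d_model, h = 6, 16, 4
gammas = 1 - 2.0 ** (-5 - np.arange(h))
params = [rng.standard_normal((d_model, d_model)) / np.sqrt(d_model) for _ in range(5)]
X = rng.standard_normal((n, d_model))
out = msr(X, *params, gammas)                   # shape (n, d_model)
```

Because the decay mask is causal and both the gate and the norm act per token, changing a later token leaves all earlier outputs untouched.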


Appendix A: The Three Forms Side by Side

All three forms compute the same function. The choice is purely about what is efficient given your hardware and sequence length.

What each form does

Recurrent unrolls as a step-by-step state update. Token $n$ arrives, updates $S_n$, and immediately produces output $o_n$. You never store past tokens, only the $d \times d$ state.

Parallel treats the entire sequence as a single matrix operation. No recurrence, no state. The $D$ matrix encodes all the decay weights at once. This is what GPUs are built for.

Chunkwise splits the sequence into blocks of size $B$. Inside each block: parallel. Across blocks: recurrent state $R_i$. It is a controlled interpolation between the two.

Complexity comparison

| | Training compute | Training memory | Inference compute/step | Inference memory |
|---|---|---|---|---|
| Parallel | $O(n^2 d)$ | $O(n^2)$ | not used | not used |
| Recurrent | $O(n d^2)$, sequential | $O(d^2)$ | $O(d^2)$ | $O(d^2)$ |
| Chunkwise | $O(nd(B+d))$ | $O(Bd + d^2)$ | not used | not used |
| Transformer | $O(n^2 d)$ | $O(n^2)$ | $O(nd)$ | $O(nd)$ KV cache |

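For typical $d = 256$, $B = 512$, $n = 8192$: the ratio of the two training costs is $n^2 d / \big(n d (B + d)\big) = n / (B + d) = 8192 / 768 \approx 10.7$, so chunkwise is roughly $10\times$ cheaper than parallel training in operation count.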
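The comparison can be reproduced by plugging the example values into the leading-order terms from the table. This is back-of-the-envelope arithmetic only: constants, memory traffic, and hardware utilization are all ignored, so the ratio is indicative rather than a measured speedup.

```python
d, B, n = 256, 512, 8192

parallel_ops = n * n * d               # O(n^2 d), constants dropped
chunkwise_ops = n * d * (B + d)        # O(n d (B + d)), constants dropped
ratio = parallel_ops / chunkwise_ops   # simplifies to n / (B + d)

print(f"parallel  ~ {parallel_ops:.2e} operations")
print(f"chunkwise ~ {chunkwise_ops:.2e} operations")
print(f"ratio     ~ {ratio:.1f}x")
```

Since the ratio is $n / (B + d)$, the advantage of the chunkwise form grows linearly with sequence length for fixed $B$ and $d$.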

When to use each

Recurrent is used at inference time. You process one token at a time, keep $S_n$ in memory, output $o_n$. No KV cache. Memory is $O(d^2)$ and fixed.

Parallel is used for short-sequence training. One shot, maximum GPU utilization, but $O(n^2)$ memory means it does not scale past a few thousand tokens.

Chunkwise is used for long-sequence training. The $B \times B$ intra-chunk matrix stays small, and the $d \times d$ cross-chunk state does not grow with sequence length. For sequences in the tens of thousands, this is the only practical option.

The key insight

The three forms are not approximations of each other: they are algebraically identical. The parallel form is the chunkwise form with $B = n$ (one chunk = whole sequence). The recurrent form is the chunkwise form with $B = 1$ (one chunk = one token). Chunkwise with intermediate $B$ just picks a sweet spot for GPU efficiency.

This is different from the Transformer-versus-RNN tradeoff, where you have to choose your architecture up front. In RetNet you train with chunkwise (parallel within chunk, recurrent across chunks) and switch to the recurrent form at inference: same weights, same math, different compute schedule.


Section 3: Code (WIP)

3.1 Retention: Recurrent Form

WIP

3.2 Retention: Parallel Form

WIP

3.3 Retention: Chunkwise Recurrent Form

WIP

3.4 Multi-Scale Retention (MSR)

WIP