Transformer Normalisation Layers and the Independence of Semantic Subspaces

Stephen Menary
University of Manchester, UK
Samuel Kaski
1 University of Manchester, UK
2 Aalto University, Finland
André Freitas
1 Department of Computer Science, University of Manchester, UK
2 Idiap Research Institute, Switzerland
3 National Biomarker Centre, CRUK-MI, University of Manchester, UK
Abstract

Recent works have shown that transformers can solve contextual reasoning tasks by internally executing computational graphs called circuits. Circuits often use attention to logically match information from subspaces of the representation, e.g. using position-in-sequence to identify the previous token. In this work, we consider a semantic subspace to be any independent subspace of the latent representation that can fully determine an attention distribution. We show that Pre-Norm, the placement of normalisation layer used by state-of-the-art transformers, violates this ability unless the model learns a strict representation structure of orthogonal spheres. This is because it causes linear subspaces to interfere through their common normalisation factor. Theoretically, we analyse circuit stability by modelling this interference as random noise on the $L_2$-norms of the query/key/value vectors, predicting a phenomenon of circuit collapse when sparse-attention shifts to a different token. Empirically, we investigate the sensitivity of real-world models trained for mathematical addition, observing a 1% rate of circuit collapse when the norms are artificially perturbed by $\lesssim 10\%$. We contrast Pre-Norm with QKV-Norm, which places normalisation after the attention head’s linear operators. Theoretically this relaxes the representational constraints. Empirically we observe comparable in-distribution but worse out-of-distribution performance.

1 Introduction

Transformer-based models [1] are commonplace in machine learning, providing state-of-the-art contextual reasoning in domains ranging from natural language [2, 3] to protein-folding [4, 5, 6] and theoretical physics [7]. Recent interpretability work investigates the internal mechanisms that lead to specific model behaviours [8, 9, 10, 11, 12, 13, 14, 15]. This is important for predicting behaviour in new environments, enables practitioners to match the inductive bias of a model with the structure of its task, and informs the design of architectures that promote desirable behaviour.

Two such works discovered complete circuits [16] in trained transformers [8, 9, 10]. These are computational graphs that dominate the model prediction when activated in a specialised context. They perform a type of algorithmic reasoning by internally executing a sequence of logical operations, using attention to pass information between memory buffers that begin as token embeddings and become increasingly abstract. Furthermore, a number of attention heads have been identified as performing logical operations (see [11] section 5). To understand transformer behaviour, an important goal is to understand how logical attention heads operate, and their generality beyond the simple cases that facilitate interpretability.

One key observation is that the attention distribution is sometimes fully-determined by an independent subspace of the representation - for example, an attention layer can identify the previous token by accessing a subspace that encodes only position-in-sequence. Indeed, low-rank weight matrices can only access linear subspaces by construction. A second observation is that, like most deep architectures, transformers use normalisation layers to improve training stability. A leading choice is to place normalisation at the input to each attention layer, which we call Pre-Norm [17]. Some interpretability works ignore this layer because it has a linear-up-to-scale structure, absorbing the linear part into adjacent weights. In this work we argue that the layer is important, because Pre-Norm causes independent linear subspaces to interfere through a common normalisation factor, preventing their separation by linear attention layers.

The purpose of this work is to ask: if the use of independent subspaces is generally important, what are the expected consequences of Pre-Norm for (i) the latent representation structure, and (ii) circuit stability? To answer this, we take an abstract approach that complements direct interpretability by considering general behaviour beyond the interpretable limit. Our contributions are:

  1. Conceptual: we identify interference between independent subspaces as a potential destabiliser of circuits caused by Pre-Norm. We suggest separability of latent subspaces as a target for study, and show it is easily satisfied by the alternative QKV-Norm. This differs from Pre-Norm by placing the normalisation layer after the linear operators. It is similar to QK-Norm, for which sparse evidence currently exists [18, 19, 20].

  2. Theoretical: we formalise a semantic subspace as any independent subspace of the latent representation that can fully determine the attention distribution. We show that Pre-Norm can only achieve this when semantic subspaces are spherical and mutually orthogonal. By contrast, QKV-Norm requires only that subspaces be linearly independent, matching the No-Norm case in this sense. We study the stability of attention to subspace interference, predicting a potentially problematic phenomenon of circuit collapse when a sparse-attention distribution changes which embedding it attends to.

  3. Experimental: we measure the sensitivity of trained models to simulated interference in a numerical addition task. Constraining our predictions, we find that (i) Pre-Norm models induce a narrower distribution of embedding $L_2$-norms than QKV-Norm, (ii) the spread of $L_2$-norms is bounded to $\pm 20\%$ with 90% coverage, and (iii) the circuit collapse phenomenon occurs at a rate of 1% when norms are perturbed by $\mathcal{O}(10\%)$.

2 The idea

Independent subspaces are observed in real-world transformer circuits

Before providing a formal definition in section 5, we explain what we mean by a semantic subspace of the latent representation. To emphasise that this is observed in real-world models, we use a known example: the induction circuit [8, 9]. This two-layer circuit emerges in next-token-prediction models and implements a simple contextual reasoning algorithm called prefix-matching.

Consider text as a sequence of tokens (in this example, we tokenise per-word to help with visualisation), where the task is to predict the next token at every point. The induction circuit solves this by copying a previous example from the context window: e.g. if the input includes the phrase “Harry Potter” and the last observed word was “Harry”, the induction circuit will predict that “Potter” comes next. This solves the task even if the combination “Harry Potter” never occurred in the training data.
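As an illustrative sketch, the input-output behaviour of prefix-matching can be written as a few lines of Python. This describes only what the circuit computes, not how attention heads implement it. The function name `prefix_match_predict` and the choice to copy from the most recent earlier occurrence are assumptions made for this example.

```python
def prefix_match_predict(tokens):
    """Predict the next token by copying what followed the most recent
    earlier occurrence of the final token; return None if there is none."""
    if not tokens:
        return None
    last = tokens[-1]
    # Scan earlier positions from most recent to oldest (assumed tie-break).
    for i in range(len(tokens) - 2, -1, -1):
        if tokens[i] == last:
            return tokens[i + 1]
    return None


# Per-word tokenisation, as in the running example.
context = "Harry Potter went home . Then Harry".split()
print(prefix_match_predict(context))
```

The lookup relies only on token identity and relative position, which is why it can succeed on a combination that never appeared in the training data.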

To achieve this, we initially create an embedding for each token, encoding its position and type. Attention layers then copy information between embeddings in a directed way, using two components that determine (i) which embeddings to extract information from, and (ii) what to extract. Remarkably, the model learns to implement logical gates that we will call “match&pass”, internally composing the algorithm:

[Uncaptioned image]

Each match&pass step operates only on an independent subspace of information, which we will call a semantic subspace. In this example, there are four semantic subspaces corresponding to position, type, prev-type, and pred-suffix. We observe that the latent embeddings can contain various information, and it is instructive to think of them as memory buffers rather than tokens. The principle of composing logical operations that act on latent semantic subspaces is also observed in the more complex example of indirect-object identification in GPT2-Small [10].

The problem with Pre-Norm

We express the latent embeddings as $x=\sum_{\alpha}x_{\alpha}$ where $x_{\alpha}$ encodes the value of concept $\alpha$. This is important, because linear-attention layers extract information from $x$ using linear operators (section 4), and can only isolate $x_{\alpha}$ if each subspace $\{x_{\alpha}~|~\alpha\}$ is linearly independent. In other words, there must always exist a linear projection operator $P_{\alpha}$ such that $P_{\alpha}x=x_{\alpha}$.

Most transformers use either RMSNorm [21] or LayerNorm [22] for their internal normalisation layers. Geometrically, RMSNorm projects a vector $z\in\mathbb{R}^{N}$ onto the unit-sphere $S^{N-1}$ according to

$$z~\rightarrow~\frac{z}{|z|} \qquad \mathrm{where}~|z|~\triangleq~\sqrt[+]{\sum_{i=1}^{N}z_{i}^{2}}~~\mathrm{is~the}~L_{2}\text{-norm}. \tag{1}$$

LayerNorm is similar, projecting onto the sphere $S^{N-2}$ defined perpendicular to the direction $1^{N}$. This does not affect our analysis, and we focus on RMSNorm for simplicity. Normalisation layers sometimes also include gain and/or bias parameters, applying a stretch-and-translate to the sphere. Pre-Norm [17] normalises the latent embeddings at the input to every attention layer. Consider the example $x=x_{\mathrm{pos}}+x_{\mathrm{type}}+x_{\mathrm{prev-type}}$. Applying Pre-Norm, we find $P_{\mathrm{pos}}x=x_{\mathrm{pos}}$ is replaced by

$$P_{\mathrm{pos}}\frac{x}{|x|}~=~\frac{P_{\mathrm{pos}}x}{|x|}~=~\frac{x_{\mathrm{pos}}}{|x_{\mathrm{pos}}~+~x_{\mathrm{type}}~+~x_{\mathrm{prev-type}}|} \tag{2}$$

Therefore it is impossible for a linear-attention layer to extract $x_{\mathrm{pos}}$ without interference from $x_{\mathrm{type}}$ and $x_{\mathrm{prev-type}}$, unless $|x_{\mathrm{pos}}+x_{\mathrm{type}}+x_{\mathrm{prev-type}}|$ is a constant. In general, we have $P_{\alpha}\frac{x}{|x|}=\frac{x_{\alpha}}{|\sum_{\beta}x_{\beta}|}$, and semantic subspaces are entangled unless $|\sum_{\alpha}x_{\alpha}|$ is constant. This is only possible if $|x_{\alpha}|^{2}=\mathrm{const}_{\alpha}~\forall~\alpha$, i.e. every subspace is a sphere, and $x_{\alpha}^{T}x_{\beta}=0~\forall~x_{\alpha},x_{\beta\neq\alpha}$, i.e. all spheres are orthogonal (to maintain independence). This has several possible implications:

  1. It is a restrictive structure that must be learned during training, with unknown difficulty. Finite steps of gradient descent may separate the model from the manifold of acceptable representations, hindering the learning of circuit components that require semantic separation, like match&pass, especially when training with large learning rates.

  2. The constraint $|x_{\alpha}|^{2}=\mathrm{const}_{\alpha}$ removes a degree of freedom for every $\alpha$, reducing the information capacity of the embedding space. For example, an embedding on $\mathbb{R}^{5}$ could have the two-subspace structure $\mathcal{S}^{2}\bigoplus\mathcal{S}^{1}$ but not $\mathcal{S}^{2}\bigoplus\mathcal{S}^{2}$.

  3. We hypothesise that the structure may be violated by (i) a tradeoff with other representational effects, (ii) imperfect model training, or (iii) encountering unexpected semantic combinations at inference-time when generalising out-of-distribution. These would cause semantic subspaces to interfere through their common normalisation factor, manifesting as noise on the $L_2$-norms of the {query, key, value} vectors.

  4. It is a structure that we can search for empirically.
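The interference in Eq. (2) can be illustrated numerically. The following is a minimal sketch under stated assumptions: a toy four-dimensional embedding in which position and type occupy disjoint coordinates (so they are orthogonal and the projection simply keeps the first block), and normalisation $x \rightarrow x/|x|$ without gain parameters. The helper name `pre_norm_then_project` and all numerical values are hypothetical.

```python
import math


def norm(v):
    return math.sqrt(sum(c * c for c in v))


def pre_norm_then_project(x_pos, x_type):
    """Apply x -> x/|x| to x = [x_pos, x_type], then read out the position block.

    Assumed toy layout: the two subspaces use disjoint coordinates, so the
    projection P_pos keeps the first len(x_pos) coordinates.
    """
    x = list(x_pos) + list(x_type)
    n = norm(x)
    return [c / n for c in x[:len(x_pos)]]


x_pos = [1.0, 0.0]

# Case 1: type vectors with different norms (1 and 3), i.e. not on a sphere.
a = pre_norm_then_project(x_pos, [1.0, 0.0])
b = pre_norm_then_project(x_pos, [0.0, 3.0])
print(a, b)  # same x_pos, but the read-out is rescaled by the type component

# Case 2: type vectors with equal norms, i.e. on a sphere.
c = pre_norm_then_project(x_pos, [1.0, 0.0])
d = pre_norm_then_project(x_pos, [0.0, 1.0])
print(c, d)  # the read-out no longer depends on the type component
```

In this toy layout orthogonality holds by construction, so only the sphere condition is varied. The read-out direction is preserved in both cases, and the interference appears purely as a change of scale.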

A possible solution: QKV-Norm

A natural fix could be to apply the normalisation layer after the linear operators. In practice this means that we normalise the {query, key, value} vectors, called QKV-Norm and defined in section 4.
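To make the contrast concrete, the sketch below compares the two orderings for a single query vector. It is an illustration under assumptions rather than the formal definition of section 4: normalisation is taken to be plain division by the $L_2$-norm with no gain parameters, and `W_Q` is a hypothetical query matrix that reads only the first two (position) coordinates of a four-dimensional embedding.

```python
import math


def norm(v):
    return math.sqrt(sum(c * c for c in v))


def unit(v):
    n = norm(v)
    return [c / n for c in v]


def matvec(W, x):
    return [sum(w * c for w, c in zip(row, x)) for row in W]


# Hypothetical query matrix: reads the position block, ignores the type block.
W_Q = [[2.0, 0.0, 0.0, 0.0],
       [0.0, 1.0, 0.0, 0.0]]


def query_pre_norm(x):
    # Normalise the embedding first, then apply the linear operator.
    return matvec(W_Q, unit(x))


def query_qkv_norm(x):
    # Apply the linear operator first, then normalise the query.
    return unit(matvec(W_Q, x))


x1 = [1.0, 1.0, 1.0, 0.0]  # position (1, 1), type (1, 0)
x2 = [1.0, 1.0, 0.0, 3.0]  # same position, type with a different norm

print(query_pre_norm(x1), query_pre_norm(x2))  # parallel, different lengths
print(query_qkv_norm(x1), query_qkv_norm(x2))  # identical
```

Under the Pre-Norm ordering the two queries point in the same direction but differ in length, which matches the description of interference as noise on the $L_2$-norms. With normalisation applied after the linear operator, the type component never enters the normalisation factor.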

Paper strategy

Our work is based on three key observations: (i) semantic subspaces are observed in known circuits, (ii) they contribute to the model behaviour, and (iii) Pre-Norm requires them to follow a strict latent embedding structure or else interfere through the $L_2$-norms of the {query, key, value} vectors.

However, it is difficult to demonstrate specific examples of subspace interference. Firstly, a fully-converged model should learn to manage interference for in-distribution examples. Instead, we expect it to concern (i) training stability, (ii) model inductive bias, and (iii) out-of-distribution behaviour. Secondly, circuit explainability is difficult, only being achieved in simple cases. In general we expect circuits to become complicated, contain steps that are harder to interpret than match&pass, and exploit non-interpretable latent subspaces. Difficulty is further increased by polysemanticity [23], the ability for heads and features to change behaviour according to context.

In this work, we take an abstract approach instead. We formally define latent semantic separability, then investigate the theoretical consequences for Pre-Norm architectures if this behaviour is important generally. This allows us to make testable predictions about representation structure and model stability without needing to fully reverse-engineer a network or explain subspaces in human terms. We then place some data-driven limits on the effect size. Nonetheless, direct observation remains important, and we hope that future works can confirm or falsify the importance of the proposed representation structure and interference effect.

3 Related Works

Our work is motivated by transformer circuit discovery [8, 9, 10, 24, 25, 26] and formation [27, 28]. See [11] for a recent review of interpretability for language decoder models, with a list of known logical operations implemented by attention heads. This builds upon works in BERTology [13, 29]. We study normalisation, for which several formulations have been proposed [17, 30, 31, 32, 33]. Our QKV-Norm variant is similar to QK-Norm, which is studied by [18, 19, 20] for asymptotic performance and training stability at large learning rates. These are motivated by logit-regularisation, whereas we are motivated by representational inductive bias and stability to latent semantic interference.

We highlight other works that study transformer normalisation through its geometric interpretation as a projection onto a sphere. [34] investigates the role of normalisation in mixing the attention output with the residual stream in Post-Norm models, but does not consider Pre-Norm. [35] studies the computational abilities of Pre-LayerNorm architectures, in particular demonstrating that projection onto a sphere ensures that all keys reside on their own convex hull, preventing them from becoming “unselectable”. [36] interprets the latent embeddings of Pre-Norm models as a trajectory on a sphere. These works do not consider the interference of semantic subspaces. [37] and the contemporary work [38] study the role of LayerNorm in the related phenomenon of embedding rank collapse.

We highlight the contemporary work of [39], who also study multi-step contextual reasoning in transformers using matching operations over independent subspaces, for both Pre-Norm and Post-Norm. This builds upon [40], who study the learning of abstract symbolic reasoning in transformers, and works that manipulate the flow of information to promote algorithmic reasoning, e.g. [41].

We are not aware of previous works that study the impact of Pre-Norm’s spherical geometry on the structure of latent subspaces. However, many works consider linear subspaces, described in the following paragraph. These results are directly applicable to the No-Norm and QKV-Norm methods in this work, although QKV-Norm applies a subsequent spherical projection. [42] design subspace separability into their model by decoupling the normalisation layers for different mechanisms.

Works on vector embeddings [43, 44, 45] and the linear representation hypothesis [46, 47, 48] study the emergence of linear subspaces that encode separable concepts in embedding-unembedding models, using both interpretation and intervention techniques. Many works search for linear subspaces/directions in a transformer representation (e.g. linear probes [49, 50]) or search for faithful causal abstractions (e.g. [51]), with a survey provided in [11] sections 3-4. We also highlight works that study the use of features in linear superposition [52, 23]. This allows a model to store more features than it has dimensions, at the cost of interference in their linear projections.

The terminology of semantic subspaces is used more generally, e.g. [49, 53, 54]. We consider a definition that does not require humans to define the separable concepts, only that abstract latent features remain independent in an attention layer. We also highlight works that study subspaces of static (model input) and contextual (latent or model output) embeddings in transformers, e.g. [55, 54, 56, 57, 58, 59, 60, 61, 62] (review in [13]). These are relevant because they also decompose embeddings into a combination of abstract subspaces, capturing different semantic and syntactic structures in a natural language setting. These may be used as semantic subspaces in our work. We highlight [57] which studies interference between positional and contextual components using a decomposition similar to ours, and also experiments using a next-token addition task.

4 Formulation

Consider the No-Norm case. Let $\mathcal{X}$ be an unordered set of message receiving tokens, and $\mathcal{Y}$ the message senders. Let $x\in\mathbb{R}^{N_{x}}$ be the $N_{x}$-dimensional representation of an element in $\mathcal{X}$, and $y_{t}\in\mathbb{R}^{N_{y}}$ be the $t^{\mathrm{th}}$ element in $\mathcal{Y}$, with $1\leq t\leq T$. For self-attention we have $\mathcal{X}=\mathcal{Y}$. Let $W_{Q}\in\mathbb{R}^{N_{qkv}\times N_{x}}$ and $W_{K}\in\mathbb{R}^{N_{qkv}\times N_{y}}$ be the query and key weight matrices, with associated vectors $q=W_{Q}x\in\mathbb{R}^{N_{qkv}}$ and $k_{t}=W_{K}y_{t}\in\mathbb{R}^{N_{qkv}}$ on an $N_{qkv}$-dimensional latent space. We do not include biases in $\{q,~k_{t}\}$ because they contribute terms that are nullified by the softmax, or are reproduced by constant directions in $x$ (Theorem 12). We define dot-product attention scores as:

$$w_{t}~=~q^{T}k_{t}~=~x^{T}W_{Q}^{T}W_{K}y_{t}~=~x^{T}W_{QK}y_{t} \tag{3}$$

where $W_{QK}\triangleq W_{Q}^{T}W_{K}\in\mathbb{R}^{N_{x}\times N_{y}}$ is a matrix with $\mathrm{Rank}(W_{QK})\leq\min(N_{x},N_{y},N_{qkv})$. This is the maximum span of the attended subspace in $\{x,y_{t}\}$. The attention weights are

$$a_{t}~=~\texttt{softmax}\left(w_{t}\right)~=~\frac{e^{w_{t}}}{\sum_{t^{\prime}}e^{w_{t^{\prime}}}}. \tag{4}$$
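The statement that some bias terms are nullified by the softmax can be checked with a short sketch. It covers only one case, stated here as an assumption of the example: a bias $b_k$ added to every key shifts all scores by the same amount $q^{T}b_k$, and Eq. (4) is invariant to a shift that is constant in $t$. The other case mentioned in the text (terms reproduced by constant directions in $x$) is not illustrated. All numerical values are hypothetical.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def softmax(scores):
    m = max(scores)  # subtract the maximum for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


q = [1.0, -2.0]
keys = [[0.5, 0.0], [1.0, 1.0], [-1.0, 2.0]]
b_k = [3.0, 0.25]  # hypothetical key bias, shared by all keys

scores = [dot(q, k) for k in keys]
scores_biased = [dot(q, [ki + bi for ki, bi in zip(k, b_k)]) for k in keys]

a_plain = softmax(scores)
a_biased = softmax(scores_biased)
print(a_plain)
print(a_biased)  # matches a_plain up to floating-point rounding
```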

Let $v_{t}=W_{V}y_{t}\in\mathbb{R}^{N_{qkv}}$ be the value vectors with $W_{V}\in\mathbb{R}^{N_{qkv}\times N_{y}}$. We do not include biases in $v_{t}$ because they carry no dependence on the attended token. Each token emits the message $m_{t}=W_{O}v_{t}\equiv W_{O}W_{V}y_{t}\triangleq W_{OV}y_{t}$ where $W_{O}\in\mathbb{R}^{N_{x}\times N_{qkv}}$ is the output-matrix. Each attention-head updates $x$ by adding the attention-weighted convex combination of messages, $x\rightarrow x+\Delta x$ with $\Delta x=\sum_{t}a_{t}m_{t}$. We usually run $H$ attention-heads in parallel, giving the total update:

$$x~\rightarrow~x~+~\sum_{h=1}^{H}\sum_{t=1}^{T}a_{t}^{(h)}m_{t}^{(h)} \qquad \text{Multi-head attention} \tag{5}$$

with unique weights $\{W_{Q}^{(h)},W_{K}^{(h)},W_{V}^{(h)},W_{O}^{(h)}\}$ for each head index $h$.
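Equations (3)-(5) can be assembled into a short reference sketch of the No-Norm update for a single receiving embedding. This is a plain-Python illustration, not the implementation used in the experiments. The function names `head_update` and `multi_head_update` and the toy weights are hypothetical, and matrices are stored as lists of rows.

```python
import math


def matvec(W, x):
    return [sum(w * c for w, c in zip(row, x)) for row in W]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def softmax(scores):
    m = max(scores)  # subtract the maximum for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def head_update(x, ys, W_Q, W_K, W_V, W_O):
    """Return (a, dx): attention weights and the update sum_t a_t m_t."""
    q = matvec(W_Q, x)
    scores = [dot(q, matvec(W_K, y)) for y in ys]           # Eq. (3)
    a = softmax(scores)                                      # Eq. (4)
    messages = [matvec(W_O, matvec(W_V, y)) for y in ys]     # m_t = W_O W_V y_t
    dx = [sum(a_t * m[i] for a_t, m in zip(a, messages)) for i in range(len(x))]
    return a, dx


def multi_head_update(x, ys, heads):
    """Eq. (5): add the update of every head to x."""
    out = list(x)
    for W_Q, W_K, W_V, W_O in heads:
        _, dx = head_update(x, ys, W_Q, W_K, W_V, W_O)
        out = [o + d for o, d in zip(out, dx)]
    return out


# Toy head with N_x = N_y = 2 and N_qkv = 1: match on coordinate 0, pass coordinate 1.
head = ([[1.0, 0.0]], [[1.0, 0.0]], [[0.0, 1.0]], [[0.0], [1.0]])
x = [1.0, 0.0]
ys = [[0.0, 1.0], [0.0, 2.0]]
print(multi_head_update(x, ys, [head]))
```

In this toy case both senders have equal scores, so the update is the average of their two messages.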

We now introduce normalisation layers. Let $z \in \mathbb{R}^{N_z}$ be any $N_z$-dimensional vector, then $\texttt{N}(z; \alpha_z): \mathbb{R}^{N_z} \rightarrow \mathbb{R}^{N_z}$ is a normalisation function with parameters $\alpha_z$. We consider two such functions:

$$\texttt{RMSNorm}(z; \alpha_z) = \frac{\sqrt{N_z}}{|z|}\, \mathrm{diag}(\alpha_z)\, z \qquad\qquad \texttt{LayerNorm}(z; \alpha_z) = \frac{\sqrt{N_z}}{|z_\perp|}\, \mathrm{diag}(\alpha_z)\, z_\perp \tag{6}$$

[21, 22] where $P_\perp \triangleq \mathrm{diag}(1^{N_z}) - \frac{1}{N_z} 1^{N_z} {1^{N_z}}^T$ is a linear operator that subtracts the mean of $z$ from every component, $1^{N_z}$ is the vector of ones, and $z_\perp \triangleq P_\perp z$ is the component of $z$ perpendicular to $1^{N_z}$.
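The two functions in Eq. (6) can be written directly from their definitions. The sketch below is illustrative only: the function names are hypothetical, there is no learned bias and no numerical-stability constant, and it assumes the input has non-zero norm after any mean subtraction.

```python
import math

def rms_norm(z, alpha):
    """RMSNorm(z; alpha) = sqrt(N_z) / |z| * diag(alpha) z."""
    n_z = len(z)
    norm = math.sqrt(sum(z_i * z_i for z_i in z))
    scale = math.sqrt(n_z) / norm
    return [scale * a_i * z_i for a_i, z_i in zip(alpha, z)]

def layer_norm(z, alpha):
    """LayerNorm(z; alpha): subtract the mean, then apply RMSNorm to z_perp."""
    mean = sum(z) / len(z)
    z_perp = [z_i - mean for z_i in z]
    return rms_norm(z_perp, alpha)
```

With unit parameters, both outputs have $L_2$-norm $\sqrt{N_z}$, and the LayerNorm output additionally has zero mean. This is the fixed-norm constraint that the later sections rely on.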

The Pre-Norm strategy means applying normalisation to the inputs $\{x, y_t\}$. The QKV-Norm strategy means applying normalisation to the vectors $\{q, k_t, v_t\}$. We then have three cases:

| Strategy | $w_t$ | $v_t$ | Norm params |
|---|---|---|---|
| No-Norm | $x^T W_{QK}\, y_t$ | $W_V\, y_t$ | - |
| Pre-Norm (baseline) | $\texttt{N}(x; \alpha_x)^T\, W_{QK}\, \texttt{N}(y_t; \alpha^K_y)$ | $W_V\, \texttt{N}(y_t; \alpha^V_y)$ | $\{\alpha_x,~\alpha^K_y,~\alpha^V_y\}$ |
| QKV-Norm (alternate) | $\texttt{N}(W_Q x; \alpha_q)^T\, \texttt{N}(W_K y_t; \alpha_k)$ | $\texttt{N}(W_V y_t; \alpha_v)$ | $\{\alpha_q,~\alpha_k,~\alpha_v\}$ |

We note that several of these degrees of freedom are redundant and could be combined, e.g. $\alpha_q$ and $\alpha_k$. We do not consider these variations (i) because they are not relevant for the results of this paper, and (ii) to standardise the number of training parameters.

5 Theory: representation structure required for independent subspaces

Let $\mathbb{S}^N \equiv \mathbb{R}^N$ be an $N$-dimensional latent representation of $\mathcal{X}$ or $\mathcal{Y}$.

[Definition]   Semantic subspace: any independent $N_\alpha$-dimensional subspace $\mathbb{S}_\alpha^{N_\alpha} \subset \mathbb{S}^N$ for which every element may be uniquely identified by some parameters $\theta_\alpha$, such that it is possible for the attention scores $w_t$ to be fully specified by $\theta_\alpha$. Semantic separability: ability for parallel heads to be fully specified by different semantic subspaces.

Let $\{\alpha\}$ be the set of indivisible semantic subspaces. This can be seen as a co-ordinate system for the attendable embedding space. Semantic separability requires that each co-ordinate $\alpha$ be independently measurable by an attention head. Let $\mathbb{S}^N$ contain $N_s$ indivisible semantic subspaces $1 \leq \alpha \leq N_s$. Then $\mathbb{S}^N = \prod_\alpha \mathbb{S}^{N_\alpha}_\alpha \bigoplus \mathbb{S}_{\mathrm{null}}$ such that $\sum_\alpha N_\alpha \leq N$ satisfies semantic separability, where $\prod_\alpha, \bigoplus$ are Cartesian products and $\mathbb{S}_{\mathrm{null}}$ is a separable space of non-attended information.

The following theorems derive the representation structure required for semantic separability:

Semantically separable representation structures   [Proofs in appendix F]

Theorem 1.

No-Norm: If two heads with finite non-zero temperature attend to different semantic subspaces, the subspaces must be linearly independent $\mathbb{S}^{N_\alpha}_\alpha \equiv \mathbb{R}^{N_\alpha}$. Corollary: $W_{QK}$ is a low-rank matrix with (left and right) null-spaces that span all non-attended information.

Theorem 2.

Pre-Norm: Semantic subspaces must be represented as orthogonal spheres $\mathbb{S}^{N_\alpha} \equiv \mathcal{S}^{N_\alpha - 1}$ defined using the $L_2$-norm. Corollary: if either orthogonality or constant-norm is violated, semantic subspaces interfere through a multiplicative factor on $w_t$.

Theorem 3.

QKV-Norm: Semantic subspaces must be linearly independent.

We note that every linear subspace $\mathbb{R}^{N_\alpha}$ has $N_\alpha$ continuous degrees of freedom, whilst $\mathcal{S}^{N_\alpha - 1}$ has only $N_\alpha - 1$, the other being removed by the fixed-norm constraint. The subspace $\mathcal{S}^0$ is allowed and may be seen as a binary variable with values $\pm const_\alpha$, and the total representation can store $N_s$ such variables. For QKV-Norm, we note that the residual subspace $\mathbb{R}^{N_\alpha}$ only contributes $N_\alpha - 1$ continuous degrees of freedom to the attention calculation, because we apply the projection $\mathbb{R}^{N_\alpha} \rightarrow \mathcal{S}^{N_\alpha - 1}$ after extracting the subspace. Table 1 provides a summary.
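The interference described in the corollary to Theorem 2 can be seen in a toy setting. The sketch below makes several assumptions that are not from the paper: a 4-dimensional representation split into two orthogonal 2-dimensional blocks A and B, RMSNorm with unit parameters, and $W_Q = W_K$ chosen as the projection onto block A. All function names are hypothetical.

```python
import math

def dot(u, v):
    return sum(u_i * v_i for u_i, v_i in zip(u, v))

def rms_norm(z):
    """RMSNorm with unit parameters: sqrt(N_z) / |z| * z."""
    norm = math.sqrt(dot(z, z))
    return [math.sqrt(len(z)) / norm * z_i for z_i in z]

def project_a(z):
    """Hypothetical W_Q = W_K: keep only block A (the first two coordinates)."""
    return z[:2]

def score_no_norm(x, y):
    """w_t = x^T W_QK y_t, attending to block A only."""
    return dot(project_a(x), project_a(y))

def score_pre_norm(x, y):
    """Normalise the full inputs first, then extract block A."""
    return dot(project_a(rms_norm(x)), project_a(rms_norm(y)))

def score_qkv_norm(x, y):
    """Extract block A first, then normalise the query and key."""
    return dot(rms_norm(project_a(x)), rms_norm(project_a(y)))
```

For example, with `x = [1, 0, 1, 0]`, changing only the block-B part of a key from `[0, 1]` to `[0, 3]` leaves the No-Norm and QKV-Norm scores unchanged but rescales the Pre-Norm score, because the shared normalisation factor $|y_t|$ depends on block B. Rotating the block-B part at fixed norm (e.g. `[0, 1]` to `[1, 0]`) leaves the Pre-Norm score unchanged, in line with the orthogonal-spheres structure.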

Structure of messages

We note the special case of compositional annotation, in which a layer creates a semantic subspace that is extracted by a later layer. This is used by circuits including the induction circuit [9] described in section 2. By normalising the inputs, Pre-Norm induces a spheroid message structure close to the sphere required for separability in later layers. This may facilitate compositional annotation, aiding in circuit-formation. Message structures are summarised in Table 2.

| Strategy | $\mathbb{S}^N$ | Representation structure | Attendable d.o.f. |
|---|---|---|---|
| No-Norm | $\prod_\alpha \mathbb{R}^{N_\alpha}$ | Linearly independent subspaces | $N$ |
| Pre-LayerNorm | $\prod_\alpha \mathcal{S}^{N_\alpha - 1}$ | Orthogonal spheres $\perp 1^N$ | $N - N_s - 1$ |
| Pre-RMSNorm | $\prod_\alpha \mathcal{S}^{N_\alpha - 1}$ | Orthogonal spheres | $N - N_s$ |
| QKV-Norm | $\prod_\alpha \mathbb{R}^{N_\alpha}$ | Linearly independent subspaces | $N - N_s$ |

Table 1: Representation structure required for semantic separability; d.o.f. means degrees of freedom.

| Strategy | $m_t$ | Structure of $m_t$ | Compositional annotation if |
|---|---|---|---|
| No-Norm | $W_{OV}\, y_t$ | Linear | $m_t$ on independent subspace |
| Pre-Norm | $W_{OV}\, \texttt{N}(y_t; \alpha_v)$ | Spheroid | $m_t$ on orthogonal sphere |
| QKV-Norm | $W_O\, \texttt{N}(W_V y_t; \alpha_v)$ | Spheroid | $m_t$ on independent subspace |

Table 2: Summary of message structures induced by different placements of normalisation layer.

6 Theory: stability to subspace interference

We now investigate the impact of interfering subspaces. Consider the almost-separable limit, modelling interference as a random infinitesimal perturbation of the vectors $\{q, k_t, m_t\}$. Let $\epsilon$-symbols denote perturbations such that $\epsilon^{\Delta x(q)} \rightarrow \frac{\partial \Delta x}{\partial q} \epsilon^q$ for $\epsilon^q \rightarrow 0$ is the change of $\Delta x$ induced by $\epsilon^q$. We consider (i) the sparse limit, in which the attention is concentrated entirely on a single embedding, and (ii) the isotropic limit, in which it is distributed evenly among embeddings. We are particularly interested in the sparse case, since this highly directed flow of information is used by match&pass, although semantic separation can also be used by non-sparse heads.

[Definition]   Sparse attention: the low-temperature limit $a_t \approx \delta_{tt^*}$ and $\Delta x = m_{t^*}$, where $\delta$ is the Kronecker delta. This occurs when there is a large difference between the top two scores: $t^* = \mathrm{argmax}_t\, w_t$ and $w_{t^*} - \max_{t \neq t^*} w_t \gg 1$. Isotropic attention: the high-temperature limit $a_t = \frac{1}{T}$ and $\Delta x = \langle m_t \rangle_t$. This occurs when $w_t$ is constant, requiring $q = 0$ or constant $k_t$.

Stability of attention updates to perturbations on q, k, v   [Proofs in appendix F]

Theorem 4.

Consider independent infinitesimal perturbations on queries $\epsilon^q \in \mathbb{R}^{N_{qkv}}$, keys $\epsilon^k_t \in \mathbb{R}^{N_{qkv}}$, and messages $\epsilon^m_t \in \mathbb{R}^{N_{qkv}}$. These propagate onto $\Delta x = \sum_t a_t m_t$ as

\begin{align}
\epsilon^{\Delta x(q)} ~&\xrightarrow[\epsilon^q \rightarrow 0]{\text{perturb } q}~ \mathbb{E}_{a_t}\Big[m_t \tilde{k}_t^T\Big]\, \epsilon^q & \tilde{k}_t ~&\triangleq~ k_t - \mathbb{E}_{a_t}\Big[k_t\Big] \tag{7} \\
\epsilon^{\Delta x(k)} ~&\xrightarrow[\epsilon^k_t \rightarrow 0]{\text{perturb } k}~ \mathbb{E}_{a_t}\Big[\tilde{m}_t {\epsilon^k_t}^T\Big]\, q & \tilde{m}_t ~&\triangleq~ m_t - \mathbb{E}_{a_t}\Big[m_t\Big] \tag{8} \\
\epsilon^{\Delta x(m)} ~&\xrightarrow[\epsilon^m_t \rightarrow 0]{\text{perturb } m}~ \mathbb{E}_{a_t}\Big[\epsilon^m_t\Big] \tag{9}
\end{align}

where $\tilde{z}_t$ is the value of $z_t$ measured from the attention-weighted centroid $\mathbb{E}_{a_t}[z_t] = \sum_t a_t z_t$.
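Eq. (7) can be checked numerically by comparing its first-order prediction against a finite-difference perturbation of the query. The sketch below is an illustration under stated assumptions, not the paper's code: scores are taken as unscaled dot products $w_t = q^T k_t$ with $a_t = \mathrm{softmax}_t(w_t)$, and all function names are hypothetical.

```python
import math

def dot(u, v):
    return sum(u_i * v_i for u_i, v_i in zip(u, v))

def softmax(w):
    w_max = max(w)
    exps = [math.exp(w_t - w_max) for w_t in w]
    total = sum(exps)
    return [e / total for e in exps]

def attention(q, keys):
    """Attention weights a_t = softmax_t(q^T k_t) (assumed unscaled)."""
    return softmax([dot(q, k) for k in keys])

def delta_x(q, keys, msgs):
    """Delta x = sum_t a_t m_t."""
    a = attention(q, keys)
    return [sum(a_t * m[i] for a_t, m in zip(a, msgs)) for i in range(len(msgs[0]))]

def predicted_shift_q(q, keys, msgs, eps_q):
    """First-order prediction of Eq. (7): E_a[m_t k~_t^T] eps_q."""
    a = attention(q, keys)
    # Attention-weighted centroid of the keys, E_a[k_t].
    k_bar = [sum(a_t * k[j] for a_t, k in zip(a, keys)) for j in range(len(q))]
    out = [0.0] * len(msgs[0])
    for a_t, k, m in zip(a, keys, msgs):
        k_tilde = [k_j - kb_j for k_j, kb_j in zip(k, k_bar)]
        proj = dot(k_tilde, eps_q)  # scalar k~_t^T eps_q
        for i in range(len(out)):
            out[i] += a_t * m[i] * proj
    return out

def measured_shift_q(q, keys, msgs, eps_q):
    """Finite-difference change of Delta x when q -> q + eps_q."""
    q_pert = [q_i + e_i for q_i, e_i in zip(q, eps_q)]
    before = delta_x(q, keys, msgs)
    after = delta_x(q_pert, keys, msgs)
    return [a_i - b_i for a_i, b_i in zip(after, before)]
```

For a small `eps_q`, the two functions should agree up to terms of second order in the perturbation. The analogous checks for Eqs. (8) and (9) follow the same pattern with the keys or messages perturbed instead.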

Theorem 5.

For sparse attention:

$$\epsilon^{\Delta x(q)} \xrightarrow[\epsilon^q \rightarrow 0]{\text{perturb } q} 0 \qquad\quad \epsilon^{\Delta x(k)} \xrightarrow[\epsilon^k_t \rightarrow 0]{\text{perturb } k} 0 \qquad\quad \epsilon^{\Delta x(m)} \xrightarrow[\epsilon^m_t \rightarrow 0]{\text{perturb } m} \epsilon^m_{t^*} \tag{10}$$

i.e. the message is stable with respect to small interference in the queries and keys. Interference in the selected value is linearly transferred onto the message.
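One way to see how Eq. (10) follows from Eqs. (7)-(9) is to idealise sparse attention as exactly one-hot, $a_t = \delta_{tt^*}$. This is a sketch under that assumption; the proof in appendix F may proceed differently. The attention-weighted centroids then coincide with the selected token, so the centred quantities vanish at $t^*$:

\begin{align*}
\mathbb{E}_{a_t}[k_t] = k_{t^*} \;&\Rightarrow\; \tilde{k}_{t^*} = 0 \;\Rightarrow\; \mathbb{E}_{a_t}\Big[m_t \tilde{k}_t^T\Big] \epsilon^q = m_{t^*} \tilde{k}_{t^*}^T \epsilon^q = 0, \\
\mathbb{E}_{a_t}[m_t] = m_{t^*} \;&\Rightarrow\; \tilde{m}_{t^*} = 0 \;\Rightarrow\; \mathbb{E}_{a_t}\Big[\tilde{m}_t {\epsilon^k_t}^T\Big] q = \tilde{m}_{t^*} {\epsilon^k_{t^*}}^T q = 0, \\
\mathbb{E}_{a_t}[\epsilon^m_t] &= \epsilon^m_{t^*}.
\end{align*}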

Theorem 6.

For isotropic attention:

$$\epsilon^{\Delta x(q)} \xrightarrow[\epsilon^q \rightarrow 0]{\text{perturb } q} \langle m_t \tilde{k}_t^T \rangle_t\, \epsilon^q \qquad \epsilon^{\Delta x(k)} \xrightarrow[\epsilon^k_t \rightarrow 0]{\text{perturb } k} \langle \tilde{m}_t {\epsilon^k_t}^T \rangle_t\, q \qquad \epsilon^{\Delta x(m)} \xrightarrow[\epsilon^m_t \rightarrow 0]{\text{perturb } m} \langle \epsilon^m_t \rangle_t \tag{11}$$

N.B. isotropy requires $k_t = const$ or $q = 0$. Lemma 1: the update is stable to noisy $q$ when $k_t = const$, or when $m_t \perp k_t$ (c.f. keys and messages from independent subspaces). Lemma 2: the update is stable to noisy $k_t$ when $q = 0$, or when $m_t \perp \epsilon_t^k$. Lemma 3: the update is stable to noisy $m_t$ when $\langle \epsilon^m_t \rangle_t = 0$. Other cases propagate linearly.

Sparse attention is stable because softmax becomes an argmax for low-temperature heads, which is sensitive only to the order of $w_t$. However, this introduces a different vulnerability when perturbations cause the order of $w_t$ to change, as the attention distribution undergoes a phase transition to select a different token. We call this circuit collapse. For example, the induction circuit collapses when the operation "attend to the previous token" attends to any other token because of interference.

[Definition] Circuit collapse: spontaneous phase transition in which a sparse attention distribution selects a different token due to noise on $\{q, k_t\}$. Let $\epsilon^w_t = k_t^T \epsilon^q + q^T \epsilon^k_t + \mathcal{O}({\epsilon^q}^T \epsilon^k_t)$ be perturbations on $w_t$ that result from $\epsilon^q$ and $\epsilon^k_t$. Circuit collapse occurs when there exists a $t \neq t^*$ for which $w_{t^*} - w_t < \epsilon^w_t - \epsilon^w_{t^*}$.
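
The definition can be checked numerically. The following is a minimal NumPy sketch, not the authors' code: the function name `collapses` and the toy vectors in the usage note are illustrative assumptions. It evaluates the inequality above for additive perturbations, keeping the second-order cross term exactly, so the result coincides with asking whether the argmax of the perturbed scores has moved.

```python
import numpy as np

def collapses(q, K, eps_q, eps_K):
    """Return True if additive noise moves the top-scoring key away from t*.

    q: (d,) query; K: (T, d) keys; eps_q, eps_K: noise with the same shapes.
    """
    w = K @ q                                # unperturbed scores w_t = k_t . q
    t_star = int(np.argmax(w))               # attended token t*
    # Exact score perturbation, including the second-order cross term.
    eps_w = K @ eps_q + eps_K @ q + eps_K @ eps_q
    others = np.arange(len(w)) != t_star
    # Collapse: w_{t*} - w_t < eps^w_t - eps^w_{t*} for some t != t*.
    return bool(np.any((w[t_star] - w[others]) < (eps_w[others] - eps_w[t_star])))
```

For instance, with `q = (1, 0)` and keys `(1, 0)` and `(0.9, 0.1)`, the score gap is 0.1, so a key perturbation that raises the second score by 0.2 triggers collapse while one that raises it by 0.05 does not.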

We now study the $L_2$-norm interference that we expect to be induced by Pre-Norm when semantic separability is violated. This is characterised by perturbations that are parallel to their corresponding vector. Theorem 7 shows the conditions under which we expect circuit collapse to occur.

Stability of attention updates to scaling of q, k, v [Proofs in appendix F]

Theorem 7.

Sensitivity of sparse attention to multiplicative perturbations $\epsilon^q = \kappa^q q$ and $\epsilon^k = \kappa^k_t k_t$ with $\kappa^q, \kappa^k_t \ll 1$. Circuit collapse occurs when $\exists~t \neq t^*$ for which:

$$\frac{w_{t^*}}{w_t}~\begin{cases}~<~\lambda_w & \mathrm{if}~w_t\left(1+\kappa^q+\kappa^k_{t^*}\right)>0\\ ~>~\lambda_w & \mathrm{otherwise}\end{cases} \qquad \lambda_w~\triangleq~\frac{1+\kappa^q+\kappa^k_t}{1+\kappa^q+\kappa^k_{t^*}} \tag{12}$$

where temperature cancels in the fraction. Attention is fully stable above the critical transition point $\lambda_w$ (c.f. $w_t\left(1+\kappa^q+\kappa^k_{t^*}\right)>0$). We see that query perturbations alone are insufficient, as they result in $\lambda_w = 1$. Lemma: consider the special case when all keys have similar length $k_t \approx const$, the attended token has $\theta_{t^*} \approx 0$, the keys are far-from-orthogonal s.t. $\theta_t \ll 1$, and $\kappa^q \approx 0$. Using $w_t \triangleq |q||k_t|\cos\theta_t$, circuit collapse occurs when $\exists~t \neq t^*$ for which:

$$\frac{1}{2}\theta_t^2~\lesssim~\kappa^k_t - \kappa^k_{t^*} \qquad \mathrm{if}~w_t\left(1+\kappa^k_{t^*}\right)>0\text{, otherwise reverse} \tag{13}$$

i.e. stability requires either well-separated keys s.t. $\theta_t \gg 0$, or small perturbations $\kappa^k_t - \kappa^k_{t^*} \ll 1$.
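
For reference, Eq. (13) follows from Eq. (12) by a short small-angle expansion. The sketch below assumes the positive-sign branch $w_t(1+\kappa^k_{t^*})>0$, keeps only first order in $\kappa$ and second order in $\theta_t$, and uses the lemma's conditions ($|k_t| \approx |k_{t^*}|$, $\theta_{t^*} \approx 0$, $\kappa^q \approx 0$). It is a reconstruction from the stated assumptions; Appendix F remains the authoritative proof.

```latex
\begin{align}
\frac{w_{t^*}}{w_t}
  &= \frac{|q|\,|k_{t^*}|\cos\theta_{t^*}}{|q|\,|k_t|\cos\theta_t}
   \approx \frac{1}{\cos\theta_t}
   \approx 1 + \tfrac{1}{2}\theta_t^2 , \\
\lambda_w
  &= \frac{1+\kappa^k_t}{1+\kappa^k_{t^*}}
   \approx 1 + \kappa^k_t - \kappa^k_{t^*} , \\
\frac{w_{t^*}}{w_t} < \lambda_w
  &\;\Longrightarrow\;
   \tfrac{1}{2}\theta_t^2 \lesssim \kappa^k_t - \kappa^k_{t^*} .
\end{align}
```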

Theorem 8.

Sensitivity of isotropic attention to multiplicative perturbations. Say $\epsilon^k = \kappa^k_t k_t$ with $\kappa^k_t \ll 1$ where $\{\kappa_t\}$ have comparable amplitudes. Then

$$\epsilon^{\Delta x(k)}~\approx~\begin{cases}0 & \text{if $\kappa_t$ independent of ${\tilde{m}}_t$, by symmetry}\\ 0 & \text{if $\kappa_t \equiv \kappa$ for constant $\kappa$}\\ 0 & \text{if $q=0$}\\ w\langle {\tilde{m}}_t \kappa^k_t \rangle_t & \text{otherwise}\end{cases} \tag{14}$$
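
The cases of Theorem 8 can be illustrated with a small numerical experiment. The sketch below rests on two assumptions that are not defined in this section: that the attention update is $\Delta x = \sum_t a_t m_t$ with $a_t = \mathrm{softmax}(w_t)$, and that ${\tilde{m}}_t$ denotes the mean-centred message $m_t - \langle m_t \rangle_t$. The function names are hypothetical. Under those assumptions it compares the exact shift in the update with the first-order expression $w\langle {\tilde{m}}_t \kappa^k_t \rangle_t$.

```python
import numpy as np

def softmax(z):
    z = z - np.max(z)                        # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

def isotropic_shift(w, kappa, M):
    """Exact and first-order change in the attention update under key scaling.

    w: common unperturbed score (isotropic attention, all w_t equal);
    kappa: (T,) fractional key-norm perturbations; M: (T, d) messages m_t.
    """
    # Scaling k_t by (1 + kappa_t) scales its score w_t = q . k_t by the same factor.
    exact = softmax(w * (1.0 + kappa)) @ M - M.mean(axis=0)
    M_tilde = M - M.mean(axis=0)             # assumed meaning of m-tilde
    approx = w * (M_tilde * kappa[:, None]).mean(axis=0)
    return exact, approx
```

With a constant `kappa`, or with `w = 0` (the $q=0$ case), both returned vectors vanish up to floating-point error, matching the second and third rows of Eq. (14).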

7 Experimental results

We now use experiments to empirically probe (i) the real-world embedding structure, and (ii) the sensitivity to artificial noise on the {query, key, value} $L_2$-norms. Whilst this does not directly observe real-world interference, it constrains how important the effect can be.

We consider a base-10 integer-addition task with a question-answer structure, and train for next-token prediction. We use a decoder architecture, common for state-of-the-art language models, with $10$ layers, per-character tokenisation, and begin [ and end ] tokens. In the output, we mask (as *) the tokens that precede the answer. For example, the first training sequence has input [453+16+17-N846=1332 and output ***************1332]. We compare two models that use Pre-Norm and QKV-Norm respectively. Appendices A-C provide a full experimental setup and supplementary plots. In this section we make all plots using an in-distribution test set that is expected to have some overlap with the training set, bounded at $\ll 20\%$.
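
To make the sequence format concrete, the following sketch rebuilds the quoted example. It is an illustration rather than the authors' data generator: the helper names are hypothetical, and it assumes that N marks a negative number (consistent with 453+16+17-(-846)=1332) and that negative answers would be written the same way.

```python
def encode_int(n):
    """Write an integer per-character, using 'N' as the negative sign (as in N846)."""
    return ("N" if n < 0 else "") + str(abs(n))

def make_example(terms, ops):
    """Build (input, target) strings for one addition problem.

    terms: list of ints; ops: list of '+' or '-' placed between consecutive terms.
    """
    question = encode_int(terms[0])
    answer = terms[0]
    for op, term in zip(ops, terms[1:]):
        question += op + encode_int(term)
        answer += term if op == "+" else -term
    prompt = question + "="
    full = "[" + prompt + encode_int(answer) + "]"
    inputs = full[:-1]                       # everything except the end token
    # Next-token targets: shifted by one, with the question positions masked.
    targets = "*" * len(prompt) + full[1 + len(prompt):]
    return inputs, targets
```

Calling `make_example([453, 16, 17, -846], ["+", "+", "-"])` reproduces the input and output strings quoted above, each 20 characters long.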

We choose this task because it emphasises contextual reasoning in a small-scale setting, is configurable for complexity, and allows us to define meaningful out-of-distribution test sets. The Pre-Norm (QKV-Norm) model achieves an in-distribution per-token accuracy of $91.4\%$ ($91.0\%$), dropping to $87.5\%$ ($82.5\%$) when generalising out-of-distribution to intermediate complexity, and $66.7\%$ ($46.8\%$) for increased complexity. Statistical uncertainties are below $0.1\%$. The in-distribution performance is comparable, but QKV-Norm generalises worse in this task, implying it has learned less task-appropriate solutions. Appendix D shows additional comparisons suggesting that the Pre-Norm and QKV-Norm models behave differently, supporting the observations of [18, 19, 20].
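
A per-token accuracy of this kind can be computed as in the sketch below. This is an assumption about the metric rather than the authors' evaluation code: it pools all unmasked target tokens across sequences, counts the end token ] like any other, and the function name is hypothetical.

```python
def per_token_accuracy(predictions, targets, mask_token="*"):
    """Fraction of unmasked target tokens predicted correctly, pooled over sequences."""
    correct = total = 0
    for pred, target in zip(predictions, targets):
        for p, t in zip(pred, target):
            if t == mask_token:
                continue                     # masked positions do not count
            total += 1
            correct += (p == t)
    return correct / total if total else float("nan")
```

For a target `***1332]`, a prediction that gets one of the five unmasked tokens wrong scores 0.8.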

Embedding structure

Our theory predicts that Pre-Norm attention is stable with respect to information in non-attended subspaces if all input embeddings have similar $L_2$-norms, whereas QKV-Norm imposes no norm constraint. We seek to experimentally bound the degree to which this structure is learned in practice.

We do this by plotting the spread of norms with respect to their median. A confounding effect is that the norms may differ for (i) embeddings attended to by different heads, (ii) the same head acting in different contexts, and (iii) embeddings that are never attended. We therefore measure the ratio per-head, and weight each embedding by its assigned attention. We remove the begin-sequence token from consideration. Figure 1 shows the resulting spread for all attention layers. On the LHS, we see that 90% of the distribution is contained within an interval of $\pm 20\%$ when using Pre-Norm. On the RHS, we see that QKV-Norm allows a much wider spread. This is consistent with our theory, and experimentally bounds the representation effect on Pre-Norm to $\lesssim 20\%$ in this model. Supplementary Figures 16-16 show consistent results for two model variations, although we note that Pre-Norm and QKV-Norm are more comparable for the variation labelled Alternate.
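
One way to compute such an attention-weighted spread for a single head is sketched below. This is an illustrative NumPy implementation under stated assumptions, not the authors' analysis code: the function names are hypothetical, the weighted quantile uses a simple midpoint-rule CDF, and the begin-sequence token is assumed to have been removed from the inputs beforehand.

```python
import numpy as np

def weighted_quantile(values, weights, q):
    """Quantile of `values` where each entry counts in proportion to its weight."""
    order = np.argsort(values)
    v, w = values[order], weights[order]
    cdf = (np.cumsum(w) - 0.5 * w) / np.sum(w)   # midpoint rule for the weighted CDF
    return float(np.interp(q, cdf, v))

def norm_spread(embeddings, attention, coverage=0.9):
    """Central interval of L2-norm / median, weighted by attention, for one head.

    embeddings: (T, d) inputs to the head; attention: (T,) weights assigned to them.
    """
    norms = np.linalg.norm(embeddings, axis=-1)
    ratio = norms / weighted_quantile(norms, attention, 0.5)
    tail = 0.5 * (1.0 - coverage)
    return (weighted_quantile(ratio, attention, tail),
            weighted_quantile(ratio, attention, 1.0 - tail))
```

Because the norms are divided by their own weighted median, the returned interval is unchanged if every embedding is rescaled by a common factor.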

Figure 1: Spread of embedding $L_2$-norms experienced by attention heads at increasing model depth, excluding the [ token. For Pre-Norm, 90% of the spread is observed within an interval of $\pm 20\%$. Supplementary Figure 12 shows the distributions used to make this plot. Supplementary Figures 16-16 replicate the analysis for two model variations.

Model stability with simulated interference

Section 6 theoretically modelled the semantic interference induced by Pre-Norm as a random perturbation on the norms of $\{q, k_t, m_t\}$. To estimate the real-world sensitivity to such an effect, we artificially introduce uncorrelated uniform noise onto these norms inside our trained Pre-Norm model. Even though Gaussian noise is expected in the large-$N_s$ limit, we use uniform noise to avoid outliers. Figure 2 shows the evolution of in-distribution per-token accuracy with increasing noise RMS. On the LHS, we see that performance falls by $\gtrsim 10\%$ at only a $1\%$ noise level. We also show the trend excluding the end-sequence token, which contributes a significant fraction of the metric. On the RHS, we introduce $\{q, k_t\}$ noise only to sparse heads (when $\max_t a_t \geq 95\%$) and non-sparse heads (when $\max_t a_t < 70\%$). We see that the model is stable with respect to $\%$-scale noise on sparse-attention, and this regime is dominated by the non-sparse case.
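
The noise injection can be pictured with the following sketch. It is an assumption about the mechanics, not the authors' implementation: each vector is rescaled by $(1+\kappa)$ with $\kappa$ drawn uniformly and independently, and the quoted noise level is taken to be the RMS of $\kappa$. The function name is hypothetical.

```python
import numpy as np

def perturb_norms(x, rms, rng):
    """Rescale each row of x by (1 + kappa), with kappa uniform and of the given RMS.

    A uniform variable on [-a, a] has RMS a / sqrt(3), so a = sqrt(3) * rms.
    Directions are unchanged; only the L2-norms move.
    """
    a = np.sqrt(3.0) * rms
    kappa = rng.uniform(-a, a, size=x.shape[:-1])
    return x * (1.0 + kappa[..., None]), kappa
```

Unlike Gaussian noise, the perturbation here is bounded by $\sqrt{3}$ times the RMS, which is the sense in which uniform noise avoids outliers.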

Figure 2 (right) is consistent with the stability predictions of Theorems 5-6. However, it may also be explained if non-sparse distributions are simply more important to the model. This could be caused by non-sparse distributions being more common, as well as depth-dependence. This is because artificial noise is applied to all layers during the forward pass, so later layers are perturbed by both the noise component and the shifting of their inputs due to previous layers, which is expected to compound with depth. We are interested in capturing this effect; however, it may increase the importance of early layers. See Figure 27 for a visualisation of the observed attention maps.

Figure 2: Left: evolution of per-token accuracy as we increase noise on the $L_2$-norms of $\{q, k_t, m_t\}$. A $\gtrsim 10\%$ drop in performance is observed when $1\%$ noise is applied to all layers. Right: applying noise only to $\{q, k_t\}$, we see that non-sparse attention drives the drop at small noise, whereas the sparse case is stable. This is consistent with Theorems 5-6, but this interpretation is confounded by the relative importance of non-sparse distributions caused by frequency and depth-dependence.
Figure 3: Probability of circuit collapse vs increasing noise. This observes the effect predicted in Section 6, and measures that $1\%$ of sparse distributions collapse at a noise level of $11\%$.

Circuit collapse

Figure 3 shows the probability that our artificial noise causes the circuit collapse phenomenon as defined in Section 6. In this experiment, we add noise to every layer independently. This prevents the confounding effect of shifting inputs due to noise in previous layers. We observe that $1\%$ of sparse attention distributions collapse when they experience noise at a level of $11\%$. The corresponding noise level reduces to $7.5\%$ and $5.5\%$ for the two model variations shown in Appendix C.
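
A collapse probability of this kind can be estimated by Monte Carlo on a single attention head, as in the toy sketch below. It is not the authors' experiment: the function name and the `noise_on` switch are hypothetical, the noise is assumed uniform with the stated RMS, and the scores are rescaled by the exact product $(1+\kappa^q)(1+\kappa^k_t)$ rather than the first-order sum used in Theorem 7.

```python
import numpy as np

def collapse_rate(q, K, rms, n_trials, rng, noise_on=("q", "k")):
    """Estimate how often multiplicative norm noise changes the attended token."""
    w = K @ q                                # unperturbed scores
    t_star = int(np.argmax(w))               # attended token t*
    a = np.sqrt(3.0) * rms                   # uniform on [-a, a] has this RMS
    n_collapsed = 0
    for _ in range(n_trials):
        kappa_q = rng.uniform(-a, a) if "q" in noise_on else 0.0
        kappa_k = rng.uniform(-a, a, size=len(K)) if "k" in noise_on else np.zeros(len(K))
        # Scaling q and k_t by (1 + kappa) scales each score by the product.
        w_noisy = w * (1.0 + kappa_q) * (1.0 + kappa_k)
        n_collapsed += int(np.argmax(w_noisy) != t_star)
    return n_collapsed / n_trials
```

In this toy setting, noise applied to the query alone never changes the argmax, in line with the observation that query perturbations give $\lambda_w = 1$, whereas key noise causes collapse once $\kappa^k_t - \kappa^k_{t^*}$ can exceed roughly $\frac{1}{2}\theta_t^2$.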

8 Summary & Outlook

We have presented the idea that transformer Pre-Norm can cause interference between independent subspaces of the latent embeddings, a feature used by some real-world transformer circuits. Theoretically, we found this can only be avoided when using an embedding structure of orthogonal spheres. By contrast, the QKV-Norm architecture requires only linearly independent subspaces. We predict that sparse attention is stable with respect to interference, until a certain threshold of noise is reached, at which point it undergoes a phase transition called circuit collapse.

Empirically, we observe that the $L_2$-norms of attended embeddings are contained within a spread of $\pm 20\%$ for Pre-Norm (with 90% coverage), whilst QKV-Norm creates a wider spread. We simulate interference by introducing artificial noise onto the $L_2$-norms of $\{q, k_t, v_t\}$ in our trained Pre-Norm model, observing that $1\%$ of sparse distributions collapse at a noise level of $11\%$. We observe that per-token accuracy degrades by $\mathcal{O}(10\%)$ when norms are simultaneously perturbed by noise of 1% in all layers, but is stable to %-scale noise in only sparse distributions. This may be attributed to either the predicted stability of sparse attention, or to a difference in the importance of sparse vs non-sparse heads induced by frequency and depth-dependence. More work is needed to disentangle these.

This work contributes a theoretical hypothesis of model behaviour, and empirically constrains the effect size without full model reverse-engineering. We have made predictions on representation structure, interference, and circuit collapse that practitioners may search for in their own models.

9 Limitations

We have not directly observed subspace independence or interference, and further work is required to establish their importance in real-world models. Experimentally, we simulate interference as being independent and similar in amplitude across heads and layers, however it is possible that it is correlated and depth-dependent. Whilst our stability experiments demonstrate that the model is more stable with respect to noise in sparse than non-sparse distributions, we have not shown whether this is due to the inherent stability of the attention distribution predicted by our theory, or the relative importance of sparse vs non-sparse distributions to the model. We show experimental results for a small model on a targeted task (with model variations in Appendix C); further work is needed to study the behaviour of larger models and different corpora.

Acknowledgments and Disclosure of Funding

This work is supported by UKRI Turing AI World-Leading Researcher Fellowship (EP/W002973/1). This work was partially funded by the Swiss National Science Foundation (SNSF) project NeuMath (200021_204617), by the EPSRC grant EP/T026995/1, “EnnCore: End-to-End Conceptual Guarding of Neural Architectures” under Security for all in an AI enabled society, by the CRUK National Biomarker Centre, and supported by the Manchester Experimental Cancer Medicine Centre and the NIHR Manchester Biomedical Research Centre.

References

  • [1] Vaswani et al. Attention is all you need. In Advances in Neural Information Processing Systems, volume 30. Curran Associates, Inc., 2017.
  • [2] Bubeck et al. Sparks of artificial general intelligence: Early experiments with gpt-4, 2023.
  • [3] Touvron et al. Llama: Open and efficient foundation language models, 2023.
  • [4] Abramson et al. Accurate structure prediction of biomolecular interactions with alphafold 3. Nature, 2024.
  • [5] Baek et al. Accurate prediction of protein structures and interactions using a three-track neural network. Science, 373(6557):871–876, 2021.
  • [6] Chandra et al. Transformer-based deep learning for predicting protein properties in the life sciences. eLife, 12:e82819, jan 2023.
  • [7] Cai et al. Transforming the bootstrap: Using transformers to compute scattering amplitudes in planar n = 4 super yang-mills theory, 2024.
  • [8] Olsson et al. In-context learning and induction heads. Transformer Circuits Thread, 2022. https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html.
  • [9] Elhage et al. A mathematical framework for transformer circuits. Transformer Circuits Thread, 2021. https://transformer-circuits.pub/2021/framework/index.html.
  • [10] Wang et al. Interpretability in the wild: a circuit for indirect object identification in GPT-2 small. In The Eleventh International Conference on Learning Representations, 2023.
  • [11] Ferrando et al. A primer on the inner workings of transformer-based language models, 2024.
  • [12] Meng et al. Locating and editing factual associations in GPT. In Alice H. Oh, Alekh Agarwal, Danielle Belgrave, and Kyunghyun Cho, editors, Advances in Neural Information Processing Systems, 2022.
  • [13] Rogers et al. A primer in bertology: What we know about how bert works. Transactions of the Association for Computational Linguistics, 8:842–866, 12 2020.
  • [14] Liu et al. Transformers learn shortcuts to automata. In The Eleventh International Conference on Learning Representations, 2023.
  • [15] Goldowsky-Dill et al. Localizing model behavior with path patching, 2023.
  • [16] Cammarata et al. Thread: Circuits. Distill, 2020. https://distill.pub/2020/circuits.
  • [17] Xiong et al. On layer normalization in the transformer architecture. In Proceedings of the 37th International Conference on Machine Learning, ICML’20. JMLR.org, 2020.
  • [18] Henry et al. Query-key normalization for transformers. CoRR, abs/2010.04245, 2020.
  • [19] Wortsman et al. Small-scale proxies for large-scale transformer training instabilities. In The Twelfth International Conference on Learning Representations, 2024.
  • [20] Dehghani et al. Scaling vision transformers to 22 billion parameters. In Proceedings of the 40th International Conference on Machine Learning, volume 202 of Proceedings of Machine Learning Research, pages 7480–7512. PMLR, 23–29 Jul 2023.
  • [21] Biao Zhang and Rico Sennrich. Root mean square layer normalization. In Wallach et al, editor, Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc., 2019.
  • [22] Ba et al. Layer normalization. CoRR, abs/1607.06450, 2016.
  • [23] Elhage et al. Toy models of superposition. Transformer Circuits Thread, 2022.
  • [24] Stolfo et al. A mechanistic interpretation of arithmetic reasoning in language models using causal mediation analysis, 2023.
  • [25] Goldowsky-Dill et al. Localizing model behavior with path patching. ArXiv, abs/2304.05969, 2023.
  • [26] Javier Ferrando and Elena Voita. Information flow routes: Automatically interpreting language models at scale. ArXiv, abs/2403.00824, 2024.
  • [27] Singh et al. The transient nature of emergent in-context learning in transformers. In Thirty-seventh Conference on Neural Information Processing Systems, 2023.
  • [28] Singh et al. What needs to go right for an induction head? a mechanistic study of in-context learning circuits and their formation, 2024.
  • [29] Devlin et al. Bert: Pre-training of deep bidirectional transformers for language understanding. In North American Chapter of the Association for Computational Linguistics, 2019.
  • [30] Nguyen et al. Transformers without tears: Improving the normalization of self-attention. In Proceedings of the 16th International Conference on Spoken Language Translation, Hong Kong, November 2-3 2019. Association for Computational Linguistics.
  • [31] Toan Nguyen and David Chiang. Improving lexical choice in neural machine translation. In Proceedings of the 2018 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long Papers), pages 334–343, New Orleans, Louisiana, June 2018. Association for Computational Linguistics.
  • [32] Shleifer et al. Normformer: Improved transformer pretraining with extra normalization, 2021.
  • [33] Xu et al. Understanding and improving layer normalization. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d'Alché-Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc., 2019.
  • [34] Kobayashi et al. Incorporating Residual and Normalization Layers into Analysis of Masked Language Models. In Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing, pages 4547–4568, Online and Punta Cana, Dominican Republic, November 2021. Association for Computational Linguistics.
  • [35] Brody et al. On the expressivity role of layernorm in transformers’ attention. pages 14211–14221, 01 2023.
  • [36] Raul Molina. Traveling words: A geometric interpretation of transformers. ArXiv, abs/2309.07315, 2023.
  • [37] Dong et al. Attention is not all you need: Pure attention loses rank doubly exponentially with depth. PMLR, 139, 2021.
  • [38] Wu et al. On the role of attention masks and layernorm in transformers, 2024.
  • [39] Wang et al. Towards understanding how transformer perform multi-step reasoning with matching operation, 2024.
  • [40] Boix-Adserà et al. When can transformers reason with abstract symbols? In The Twelfth International Conference on Learning Representations, 2024.
  • [41] Csordás et al. The neural data router: Adaptive control flow in transformers improves systematic generalization. In International Conference on Learning Representations, 2022.
  • [42] Lamb et al. Transformers with competitive ensembles of independent mechanisms. ArXiv, abs/2103.00336, 2021.
  • [43] Mikolov et al. Efficient estimation of word representations in vector space. pages 1–12, 01 2013.
  • [44] Mikolov et al. Distributed representations of words and phrases and their compositionality. In C.J. Burges, L. Bottou, M. Welling, Z. Ghahramani, and K.Q. Weinberger, editors, Advances in Neural Information Processing Systems, volume 26. Curran Associates, Inc., 2013.
  • [45] Jeffrey Pennington, Richard Socher, and Christopher D. Manning. Glove: Global vectors for word representation. In Conference on Empirical Methods in Natural Language Processing, 2014.
  • [46] Mikolov et al. Linguistic regularities in continuous space word representations. Proceedings of NAACL-HLT, pages 746–751, 01 2013.
  • [47] Park et al. The linear representation hypothesis and the geometry of large language models. ArXiv, abs/2311.03658, 2023.
  • [48] Jiang et al. On the origins of linear representations in large language models, 2024.
  • [49] Dmitry Nikolaev and Sebastian Padó. Investigating semantic subspaces of transformer sentence embeddings through linear structural probing. 10 2023.
  • [50] Yonatan Belinkov. Probing classifiers: Promises, shortcomings, and advances. Computational Linguistics, 48(1):207–219, March 2022.
  • [51] Geiger et al. Finding alignments between interpretable causal variables and distributed neural representations. ArXiv, abs/2303.02536, 2023.
  • [52] Arora et al. Linear algebraic structure of word senses, with applications to polysemy. Transactions of the Association for Computational Linguistics, 6, 01 2016.
  • [53] Tripathi et al. Semantic subspace learning with conditional significance vectors. In The 2010 International Joint Conference on Neural Networks (IJCNN), pages 1–8, 2010.
  • [54] Coenen et al. Visualizing and measuring the geometry of bert. In Neural Information Processing Systems, 2019.
  • [55] Hewitt et al. A structural probe for finding syntax in word representations. In North American Chapter of the Association for Computational Linguistics, 2019.
  • [56] Kawin Ethayarajh. How contextual are contextualized word representations? Comparing the geometry of BERT, ELMo, and GPT-2 embeddings. In Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP), pages 55–65, Hong Kong, China, November 2019. Association for Computational Linguistics.
  • [57] Song et al. Uncovering hidden geometry in transformers via disentangling position and context, 2024.
  • [58] Mickus et al. How to dissect a muppet: The structure of transformer embedding spaces. Transactions of the Association for Computational Linguistics, 10:981–996, 09 2022.
  • [59] Hernandez et al. Linearity of relation decoding in transformer language models. In The Twelfth International Conference on Learning Representations, 2024.
  • [60] Chi et al. Finding universal grammatical relations in multilingual BERT. In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics, pages 5564–5577, Online, July 2020. Association for Computational Linguistics.
  • [61] Cai et al. Isotropy in the contextual embedding space: Clusters and manifolds. In International Conference on Learning Representations, 2021.
  • [62] Evan Hernandez and Jacob Andreas. The low-dimensional linear geometry of contextualized word representations. In Conference on Computational Natural Language Learning, 2021.
  • [63] Martín Abadi et al. TensorFlow: Large-scale machine learning on heterogeneous systems, 2015. Software available from tensorflow.org.
  • [64] Harris et al. Array programming with NumPy. Nature, 585(7825):357–362, September 2020.
  • [65] Abien Fred Agarap. Deep learning using rectified linear units (relu), 2019.
  • [66] Xavier Glorot and Yoshua Bengio. Understanding the difficulty of training deep feedforward neural networks. In Yee Whye Teh and Mike Titterington, editors, Proceedings of the Thirteenth International Conference on Artificial Intelligence and Statistics, volume 9 of Proceedings of Machine Learning Research, pages 249–256, Chia Laguna Resort, Sardinia, Italy, 13–15 May 2010. PMLR.
  • [67] François Chollet et al. Keras. https://keras.io, 2015.
  • [68] Ilya Loshchilov and Frank Hutter. Decoupled weight decay regularization, 2019.

Appendix A Experimental setup

A.1 Data

Because of the nature of the task, we do not require static datasets, and so generate both train and test data on-the-fly. This alleviates storage and memory concerns for long training runs, for which a static dataset would have to be large. Datasets are reproducible through configuration of the environment and a global random seed, which is used to manually control the random seeds of Python, TensorFlow [63], and NumPy [64]. This also reproduces the model initialisation.
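As an illustration of the seeding described above, the following is a minimal sketch rather than the exact mechanism used in our code. The helper name `set_global_seed` is hypothetical, and the NumPy and TensorFlow imports are guarded so that the sketch still runs when those packages are not installed.

```python
import random


def set_global_seed(seed: int) -> None:
    """Seed Python, NumPy and TensorFlow from one global value (hypothetical helper)."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy is not installed, nothing to seed
    try:
        import tensorflow as tf
        tf.random.set_seed(seed)
    except ImportError:
        pass  # TensorFlow is not installed, nothing to seed


set_global_seed(100)  # 100 is the seed listed for the main experiments
```

Calling the helper once, before any data is generated or any layer is built, is what makes both the data stream and the initial weights repeatable.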

A.2 Task specification

We consider an integer addition task, where each character is a base-10 numeral 0-9, mathematical operator {+, -, =, N}, or special character {[, ], *}. The N operator signifies that the following integer is negative, and is used to avoid overloading notation with the - operator, which means minus. The special characters are the begin-sequence token [, end-sequence token ], and mask character *. Input sequences in the same batch are right-padded with mask tokens to the same length, which do not contribute to the model. Characters that are masked in the output do not contribute to the evaluation metrics. We tokenise per-character so the model does not need to disambiguate different representations for identical patterns (e.g. if the number 112 is tokenised as [11,2], and 212 is tokenised as [2,12], then the pattern 12 has a context-dependent representation). The token dictionary has a length of 17.
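The per-character tokenisation and right-padding can be sketched as follows. The ordering of the vocabulary and the helper names are assumptions made for illustration; only the set of 17 symbols is fixed by the task.

```python
# Hypothetical vocabulary ordering; the task only fixes the 17 symbols.
VOCAB = list("0123456789") + list("+-=N") + list("[]*")
TOKEN_TO_ID = {ch: i for i, ch in enumerate(VOCAB)}
MASK = "*"


def encode(text: str) -> list[int]:
    """Map each character to its integer token id."""
    return [TOKEN_TO_ID[ch] for ch in text]


def pad_batch(sequences: list[str]) -> list[str]:
    """Right-pad every sequence with the mask character to a common length."""
    width = max(len(s) for s in sequences)
    return [s.ljust(width, MASK) for s in sequences]


batch = pad_batch(["[453+16+17-N846=1332", "[12+34=46"])
ids = [encode(s) for s in batch]
```

Because every character is its own token, a pattern such as 12 always maps to the same pair of ids regardless of the digits around it.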

For a decoder architecture, the model is a sequence-to-sequence transformer and each datapoint has a question-answer structure separated by the = character, e.g. the first datapoint is:

[453+16+17-N846=1332 → ***************1332]   (15)

The model must therefore predict the numerical outputs and the ] token. For an encoder-decoder architecture, the encoder input is the question and the decoder performs next-token prediction over the answer, e.g.

[453+16+17-N846] (encoder), [1332 (decoder) → 1332] (decoder)   (16)

To help visualise the task, Figure 4 shows the predictions of the baseline Pre-Norm model after 1 epoch. Figure 5 shows the fully-trained model, to help visualise the attainable in-distribution performance. The final-epoch per-token accuracy is logged as 92%; the model sometimes correctly predicts all digits of the answer, and otherwise appears to be correct in the leading digits. Figures 6-7 repeat this for the Large model variation, which acts on a more complex task setting and achieves a lower per-token accuracy of 57%. Once again, the correctly predicted tokens appear to be driven by the leading digits.

Figure 4: Baseline Pre-Norm model predictions after 1 training epoch.
Figure 5: Baseline Pre-Norm model predictions after training.
Figure 6: Large Pre-Norm model predictions after 1 training epoch.
Figure 7: Large Pre-Norm model predictions after training.

A.3 Data-generation process

One advantage of this task is the ability to modulate its complexity. Each dataset is defined by two hyperparameters:

| Dataset parameter | Example | Description |
| --- | --- | --- |
| $N$ | [3, 4, 6] | The allowed number of integers per-sequence |
| $L$ | [2, 3] | The allowed number of digits per-integer |

Each datapoint is generated by uniformly sampling a value of $N$, then uniformly sampling a value of $L$ for each integer. This ensures that examples are not simply dominated by integers with the maximum number of digits. Each integer is uniformly sampled from all positive and negative integers with that length. Between each integer, an operator is uniformly sampled from the list $[+,-]$. For example, the datapoint

[453+16+17-N846=1332 → ***************1332]   (17)

was generated by sampling a value of $N=4$ to determine that the sum contains four integers, then sampling four values of $L=[3,2,2,3]$ to determine their lengths, then sampling the numbers $[453,16,17,N846]$ and operators $[+,+,-]$. The inclusion of subtraction, addition of negative numbers, and double-negatives is intended to emphasise solutions that parse the context of each digit within the sum.
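The sampling procedure above can be sketched as follows. This is an illustrative reconstruction, not our generation code: the function names are hypothetical, integers are assumed to have no leading zeros, and writing a negative answer with the N prefix is an assumption.

```python
import random


def sample_integer(num_digits: int, rng: random.Random) -> int:
    """Uniformly sample a positive or negative integer with exactly num_digits digits."""
    magnitude = rng.randint(10 ** (num_digits - 1), 10 ** num_digits - 1)
    return magnitude if rng.random() < 0.5 else -magnitude


def fmt(value: int) -> str:
    """Write negative integers with the N prefix, e.g. -846 -> 'N846'."""
    return f"N{-value}" if value < 0 else str(value)


def generate_datapoint(n_choices, l_choices, rng):
    """Return (input string up to the end of the answer, integer answer)."""
    n = rng.choice(n_choices)                      # number of integers in the sum
    integers = [sample_integer(rng.choice(l_choices), rng) for _ in range(n)]
    operators = [rng.choice("+-") for _ in range(n - 1)]
    text, total = fmt(integers[0]), integers[0]
    for op, value in zip(operators, integers[1:]):
        text += op + fmt(value)
        total = total + value if op == "+" else total - value
    return f"[{text}={fmt(total)}", total


rng = random.Random(100)
sequence, answer = generate_datapoint([3, 4, 6], [2, 3], rng)
```

Sampling the length of each integer before its value is what prevents three-digit integers, which are ten times more numerous, from dominating the dataset.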

A.4 Train/test specifications

Table 3 shows the $N$ and $L$ parameters used for the Baseline and Alternate experiments. We also show the number of datapoints, and the per-datapoint sampling probability. This is a range, with higher probabilities for the simpler sums. Table 4 shows the task specification for the Large model variation, which is trained on a more complex setting. We also perform a scan over model size and learning rate to compare the training stability of Pre-Norm and QKV-Norm. These experiments were performed using an earlier problem configuration shown in Table 5.

| Dataset | $N$ | $L$ | Num datapoints | Datapoint probability |
| --- | --- | --- | --- | --- |
| Train | $[3,4,6]$ | $[2,3]$ | 110M, acc=90% @ 40M | $2\times 10^{-9}$ to $5\times 10^{-24}$ |
| Validation | $[5]$ | $[2,3]$ | 6.4k | $1\times 10^{-14}$ to $1\times 10^{-19}$ |
| In-distribution | $[3,4,6]$ | $[2,3]$ | 128k | $2\times 10^{-9}$ to $5\times 10^{-24}$ |
| OOD (interpolation) | $[5]$ | $[2,3]$ | 128k | $1\times 10^{-14}$ to $1\times 10^{-19}$ |
| OOD (extrapolation) | $[7,8,9]$ | $[2,3]$ | 128k | $7\times 10^{-21}$ to $1\times 10^{-35}$ |

Table 3: Dataset configurations used for Baseline and Alternate results.
| Dataset | $N$ | $L$ | Num datapoints | Datapoint probability |
| --- | --- | --- | --- | --- |
| Train | $[4,5,7,8]$ | $[3,4,5]$ | 25M | $4\times 10^{-17}$ to $3\times 10^{-49}$ |
| Validation | $[6]$ | $[3,4,5]$ | 6.4k | $2\times 10^{-24}$ to $2\times 10^{-36}$ |
| In-distribution | $[4,5,7,8]$ | $[3,4,5]$ | 128k | $4\times 10^{-17}$ to $3\times 10^{-49}$ |
| OOD (interpolation) | $[6]$ | $[3,4,5]$ | 128k | $2\times 10^{-24}$ to $2\times 10^{-36}$ |
| OOD (extrapolation) | $[9,10,11]$ | $[3,4,5]$ | 128k | $3\times 10^{-37}$ to $3\times 10^{-67}$ |

Table 4: Dataset configurations used for Large results.
| Dataset | $N$ | $L$ | Datapoint probability |
| --- | --- | --- | --- |
| Train set | $[3,4,5,7]$ | $[3,4,5]$ | $4\times 10^{-13}$ to $3\times 10^{-43}$ |
| In-distribution | $[3,4,5,7]$ | $[3,4,5]$ | $4\times 10^{-13}$ to $3\times 10^{-43}$ |

Table 5: Dataset configurations used for training stability results (Figure 24).

We halt training according to wall time, which leads to a range of observed dataset sizes. Model convergence may also occur much earlier. We therefore show an order-of-magnitude estimate for the number of observed datapoints, as well as the point at which the baseline model reaches 90% per-token accuracy (this represents almost-convergence, which is logged at 92.1%).

Note that our data-generation strategy does not ensure that training examples are exclusive (there may be repetitions), nor that the in-distribution test set does not contain overlap with training examples. The final column is therefore important, because it demonstrates that the highest per-datapoint sampling probability is $2\times 10^{-9}$, whilst the model converges with $\mathcal{O}(10^{7})$ datapoints and observes $\mathcal{O}(10^{8})$ in total. Since the datapoint probability is $2\times 10^{-9}$ for the simplest configurations and $5\times 10^{-24}$ for the most complicated (Table 3), this ensures that the in-distribution evaluation metric is dominated by novel examples. The validation set is only used for visual inspection of model behaviour during training, as in Figures 4-7.
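The per-datapoint probabilities can be checked with a short calculation. The sketch below assumes that integers have no leading zeros, so that there are $2\times 9\times 10^{d-1}$ signed integers with $d$ digits, and the function name is hypothetical. Under these assumptions it gives roughly $2\times 10^{-9}$ and $5\times 10^{-24}$ for the extremes of the Baseline training configuration, matching the range listed in Table 3.

```python
from fractions import Fraction


def datapoint_probability(n_choices, l_choices, n, lengths):
    """Probability of sampling one specific sum with n integers of the given digit lengths."""
    assert len(lengths) == n
    p = Fraction(1, len(n_choices))              # choice of N
    for num_digits in lengths:
        p *= Fraction(1, len(l_choices))         # choice of L for this integer
        count = 2 * 9 * 10 ** (num_digits - 1)   # signed integers with this many digits
        p *= Fraction(1, count)                  # choice of the integer itself
    p *= Fraction(1, 2) ** (n - 1)               # choice of each operator
    return p


simplest = datapoint_probability([3, 4, 6], [2, 3], 3, [2, 2, 2])
hardest = datapoint_probability([3, 4, 6], [2, 3], 6, [3] * 6)
print(f"{float(simplest):.1e}  {float(hardest):.1e}")
```

Exact fractions are used so that the very small probabilities of the long sums are not affected by floating-point rounding until the final conversion.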

A.5 Model specification (main experiments)

We use a decoder architecture, meaning that the dot-product self-attention layers are causally masked such that token $t$ can only attend to tokens $\leq t$. The model has the following structure:

Embedding + positional encoding
↓
$N_{layer}$ × [ Attention block → Feed-forward block (ReLU) ]
↓
Multi-layer perceptron (ReLU)
↓
Predicted logits

Embedding + positional encoding   We initialise each token embedding as $x = x_{type} + x_{pos}$, where $x_{type}$ is a token embedding with $N_{emb}$ elements, and $x_{pos}$ uses cyclic positional encodings of the same form as the original transformer architecture [1], with $N_{freq}$ frequencies initialised as a base-$e$ log-series between periods of $3$ and $1k$ tokens. For each sequence, all position indices are simultaneously offset by a random integer between $0$ and $50$. This augmentation is designed to encourage the use of relative rather than absolute positions. The frequencies are then left as trainable parameters. The positional encodings contribute the first $2N_{freq}$ components of $x_{pos}$, and the remaining components are set to $0$. This configuration guarantees that the token embeddings and positional encodings can be made orthogonal in the first layer, and that $x_{pos}$ has constant $L_2$-norm, consistent with our theoretical structure.
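A minimal sketch of these positional encodings at initialisation is given below. The interleaved sin/cos ordering, the inclusive log-spacing of the periods, and the function names are assumptions made for illustration. Each sin/cos pair contributes exactly 1 to the squared norm, which is why the $L_2$-norm is $\sqrt{N_{freq}}$ at every position.

```python
import math
import random


def make_frequencies(n_freq: int, min_period: float = 3.0, max_period: float = 1000.0):
    """Angular frequencies with periods log-spaced between the two limits (n_freq >= 2)."""
    log_lo, log_hi = math.log(min_period), math.log(max_period)
    periods = [math.exp(log_lo + (log_hi - log_lo) * i / (n_freq - 1)) for i in range(n_freq)]
    return [2.0 * math.pi / p for p in periods]


def positional_encoding(position: int, freqs, n_emb: int):
    """Sin/cos pairs in the first 2*n_freq components, zeros elsewhere."""
    x_pos = []
    for w in freqs:
        x_pos += [math.sin(w * position), math.cos(w * position)]
    return x_pos + [0.0] * (n_emb - len(x_pos))


freqs = make_frequencies(n_freq=32)
offset = random.randint(0, 50)  # one offset shared by every position in the sequence
encodings = [positional_encoding(t + offset, freqs, n_emb=512) for t in range(20)]
```

Because the offset is shared across the sequence, differences between positions are unchanged while absolute positions vary from one sequence to the next.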

Attention block   $N_{layer}$ is the number of residual blocks in our model, where our baseline is $N_{layer}=10$. The update is as formulated in section 4, where $H$ is the number of parallel attention heads per layer. Since the embeddings have length $N_{emb}$, we must have $N_x = N_{emb}$, whilst the latent dimension $N_{qkv}$ is configurable. Either the Pre-Norm or QKV-Norm strategy is used, as configured.

Feed-forward block   The feed-forward blocks update embeddings using the function $x \rightarrow x + FF(\texttt{LayerNorm}(x))$, where $FF$ is a dense network with one hidden layer of size $N_{ff}$. The network uses a ReLU [65] activation function on the intermediate layer, followed by a linear projection back onto embedding space. To maintain consistency with other models, we apply LayerNorm at the input to $FF$. Both LayerNorm and $FF$ use bias parameters.

Multi-layer perceptron   The final embeddings $x$ are mapped onto token logits $y$ using the function $y = MLP(\texttt{LayerNorm}(x))$, where $MLP$ is a multi-layer perceptron with two hidden layers of size $N_{MLP}$ and ReLU activation. The final layer is a linear projection onto the space of logits, which has length 17. For the training stability scan in Figure 24, the MLP has three hidden layers instead.

Hyperparameters   Table 6 shows the hyperparameters used to configure the networks of the main experiments. Table 7 shows the hyperparameters used for the training stability analysis. This experiment also uses encoder-decoder models, following the same setup as the original transformer architecture [1] and with the layer configurations listed here.

| Model | $N_{freq}$ | $N_{emb}$ | $N_{layer}$ | $H$ | $N_{qkv}$ | $N_{ff}$ | $N_{MLP}$ | seed |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Baseline | 32 | 512 | 10 | 12 | 64 | 512 | 2×512 | 100 |
| Alternate | 32 | 512 | 8 | 12 | 64 | 512 | 2×512 | 100 |
| Large | 32 | 1024 | 12 | 16 | 64 | 512 | 2×512 | 100 |

Table 6: Model hyperparameters for main experiments (i.e. other than training stability). Baseline is used for the main results presented in section 7. Alternate and Large are presented in appendix C to show reproducibility of observations.
| Model | $N_{freq}$ | $N_{emb}$ | $N_{layer}$ | $H$ | $N_{qkv}$ | $N_{ff}$ | $N_{MLP}$ | seed |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| All | 16 | - | - | 12 | - | 512 | 3×512 | 1,2 |

Table 7: Model hyperparameters for training stability experiments. Empty parameters are varied per-model and displayed in Figure 24.

Loss   The loss function is categorical cross entropy, calculated from the output logits.

A.6 Model initialisation

We use a custom initialisation strategy to give control over the initial state of the model. In particular, we use Checkpoint layers to ensure that the initial states are comparable between Pre-Norm and QKV-Norm. This ensures that any observed differences are driven by the normalisation function, rather than being confounded by the layer placement creating more/less favourable initial conditions.

Checkpoint layers are calibrated on the first training batch immediately prior to training. They use this data to measure the standard deviation at that point, and calculate a scale factor that fixes the standard deviation to a pre-defined hyperparameter $\sigma$. All subsequent passes through the layer simply apply this scale factor. This ensures that the model is initialised with a standard deviation of $\sigma$ at that point.
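The behaviour of a Checkpoint layer can be sketched as below. This is a simplified stand-in that acts on a flat list of activations rather than a tensor, and the use of the population standard deviation and the method names are assumptions.

```python
import statistics


class Checkpoint:
    """Rescale activations so their standard deviation on the first batch equals sigma."""

    def __init__(self, sigma: float):
        self.sigma = sigma
        self.scale = None  # set once, during calibration

    def calibrate(self, first_batch):
        """Measure the standard deviation on the first batch and fix the scale factor."""
        measured = statistics.pstdev(first_batch)
        self.scale = self.sigma / measured

    def __call__(self, values):
        """Apply the frozen scale factor; no statistics are recomputed."""
        return [self.scale * v for v in values]


layer = Checkpoint(sigma=0.5)
first_batch = [1.0, 3.0, -2.0, 4.0]
layer.calibrate(first_batch)
rescaled = layer(first_batch)
```

Unlike a normalisation layer, the scale is fixed after calibration, so the layer only sets the initial state and does not constrain the activations later in training.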

We apply Checkpoint layers to the token embeddings $x_{type}$ ($\sigma_{type}=0.5$), and the initial embeddings $x$ ($\sigma_x=1.0$), ensuring they are relatively balanced and unit scale. In every attention layer, we apply Checkpoint layers to re-calibrate the possibly-Pre-normalised embeddings to $\sigma_x$ immediately before applying the $W_Q$, $W_K$, and $W_V$ operators. This counteracts the fact that the transformer necessarily increases the embedding variance throughout the model at initialisation. We apply Checkpoint layers to $w_t$ in every attention layer, with constant $\sigma_w=0.1$. This controls the variance of the initial-state attention distribution. We apply Checkpoint layers to $\Delta x$ in every attention layer, with constant $\sigma_{\Delta x}=0.05$, calibrating it with respect to $x$.

In the attention layer, we use uniform initialisation of the weight matrices $W_Q$, $W_K$, $W_V$, and $W_O$. The limits are configured to ensure that the initial-state standard deviations on $w_t$ and $\Delta x$ are close to their target values. Defining $\sigma_{qk} \triangleq \sqrt[4]{\frac{\sigma_w}{N_{qkv}^{3}}}$, the limits are calculated as follows:

| Weight | Limits |
| --- | --- |
| $W_Q$ | $\pm\sqrt{3}\,\sigma_{qk}$ |
| $W_K$ | $\pm\sqrt{3}\,\sigma_{qk}$ |
| $W_V$ | $\pm\sqrt{\frac{3}{N_{qkv}}}$ |
| $W_O$ | $\pm\sqrt{\frac{3}{H N_{qkv}}}$ |

However, we note that this initialisation is superseded by the calibration of the Checkpoint layers for determining the initial state, and we include it only to promote numerical stability. All other feed-forward layers use Glorot uniform [66] initialisation, as implemented in Keras [67]. Normalisation gain parameters are initialised to $1$ and biases, where used, to $0$.
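The attention-layer limits can be computed as in the sketch below, where the function name and dictionary layout are hypothetical. A uniform distribution on $[-a, a]$ has standard deviation $a/\sqrt{3}$, so a limit of $\sqrt{3}\,\sigma$ corresponds to weights with standard deviation $\sigma$.

```python
import math


def attention_init_limits(n_qkv: int, n_heads: int, sigma_w: float = 0.1):
    """Uniform-initialisation limits (+/-) for W_Q, W_K, W_V and W_O."""
    sigma_qk = (sigma_w / n_qkv ** 3) ** 0.25
    return {
        "W_Q": math.sqrt(3.0) * sigma_qk,
        "W_K": math.sqrt(3.0) * sigma_qk,
        "W_V": math.sqrt(3.0 / n_qkv),
        "W_O": math.sqrt(3.0 / (n_heads * n_qkv)),
    }


# Baseline configuration: N_qkv = 64 and H = 12.
limits = attention_init_limits(n_qkv=64, n_heads=12)
```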

A.7 Training algorithm

We train using the AdamW optimiser [68] with learning rate $3\times 10^{-4}$ and weight decay of $0.01$, with all other parameters following their default values in TensorFlow+Keras v2.15.0. Each epoch consists of $2000$ batches of $128$ datapoints. For the main experiments and model variations, we use an adaptive learning rate decay strategy. This means that the learning rate is multiplied by a factor of $0.5$ if the training loss does not improve for $3$ consecutive epochs. We find that this balances training speed with improved performance by using small learning rates later in training. Training is halted after two days of wall time, which we observe to allow model convergence, as shown in Figure 8. For the model stability scan, training is run for $60$ hours, and learning rate is not allowed to decay (stability with respect to learning rate being one of the targets of study).
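The adaptive decay rule can be sketched as a small stateful scheduler. The class name is hypothetical, and two details are assumptions: "improve" is taken to mean a loss strictly below the best seen so far, and the stall counter is reset after each decay. Keras provides a comparable `ReduceLROnPlateau` callback with `factor` and `patience` arguments, although this sketch does not depend on it.

```python
class PlateauDecay:
    """Halve the learning rate when the training loss stalls for `patience` epochs."""

    def __init__(self, learning_rate=3e-4, factor=0.5, patience=3):
        self.learning_rate = learning_rate
        self.factor = factor
        self.patience = patience
        self.best_loss = float("inf")
        self.stalled_epochs = 0

    def update(self, epoch_loss: float) -> float:
        """Record one epoch's training loss and return the learning rate to use next."""
        if epoch_loss < self.best_loss:
            self.best_loss = epoch_loss
            self.stalled_epochs = 0
        else:
            self.stalled_epochs += 1
            if self.stalled_epochs >= self.patience:
                self.learning_rate *= self.factor
                self.stalled_epochs = 0
        return self.learning_rate


scheduler = PlateauDecay()
# Illustrative loss values only, not measurements.
rates = [scheduler.update(loss) for loss in [1.0, 0.9, 0.95, 0.92, 0.91, 0.8]]
```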

Figure 8: Model training curves for the Baseline Pre-Norm configuration.

A.8 Computational resources

The main experiments are all performed on a single Nvidia V100-SXM2-16GB (Volta) GPU. The models scanned for the stability analysis were trained on a batch cluster with a variety of compute nodes, using 8 cores per training run. A representative compute node is 2×12-core Intel Xeon E5-2690 v3 @ 2.60GHz + 128GB RAM.

A.9 Environment details

The main contributing package versions are as follows:

| Package | Version |
| --- | --- |
| Python | 3.11.5 (main, Sep 11 2023, 13:54:46) [GCC 11.2.0] |
| TensorFlow | 2.15.0 |
| Keras | 2.15.0 |
| NumPy | 1.26.2 |

Appendix B Extended main experiments

This appendix provides an extended explanation of the experimental results in section 7.

B.1 Embedding structure

Figure 1 presented the spread of embedding $L_2$-norms as a function of model depth. Let us now describe in detail how this plot was made. Figures 9 and 10 show the distribution of embedding $L_2$-norms at the input to every attention layer, for Baseline models trained using Pre-Norm (Figure 9) and QKV-Norm (Figure 10). Colours represent the initial token type corresponding to that embedding. Asterisks denote tokens in the answer, with all other labels denoting tokens in the question.

Figure 9: Distribution of embedding $L_2$-norms at different model depths using the Baseline Pre-Norm model.
Figure 10: Distribution of embedding $L_2$-norms at different model depths using the Baseline QKV-Norm model.

We see that the begin-sequence token (BEG) is often separate from the distribution, which may be because it remains non-annotated and fulfils a qualitatively different role. We remove this from our estimates to avoid erroneously inflating the spread. Interestingly, BEG still tends to be close to the main bulk for Pre-Norm, but can be very far for QKV-Norm, and we have to use overflow panels to capture it. This is consistent with our hypothesis that Pre-Norm stability requires embeddings to have similar norms, whilst QKV-Norm does not require this.

It is not sufficient to simply measure the spread of Figures 9 and 10, because an attention head may not be sensitive to all embeddings in the layer. The easiest way to account for this is to weight every embedding according to its assigned attention. Secondly, the distribution is expected to be narrow only on a per-head basis, and there is no reason why distinct heads cannot be centred around different medians. We therefore calculate the weighted distribution of embeddings on a per-head basis, as shown in Figures 11-12.

Figure 11: Embedding distributions at different model depths using the Baseline Pre-Norm model. The categories 0-9 and N are separated into whether they occur in the question (light colour) or answer (dark colour, distinguished by * label). These distributions are used to compute the LHS of Figure 1 after removing the BEG tokens.
Figure 12: Embedding distributions at different model depths using the Baseline QKV-Norm model. The categories 0-9 and N are separated into whether they occur in the question (light colour) or answer (dark colour, distinguished by * label). Note that overflow panels are excluded from this plot for legibility. These distributions are used to compute the RHS of Figure 1 after removing the BEG tokens.

B.2 Circuit collapse

Figure 3 shows the probability of circuit collapse. This is the probability that an attention distribution with no noise selects embedding $i$ with high probability $a_i \geq 95\%$ and, when noise is added, transitions such that some $k \neq i$ becomes the maximally attended embedding. This definition is chosen because it matches our theoretical results in section 6. However, it does not require that the distribution remains sparse after the noise addition. Figure 13 compares this baseline result (top) with an alternative definition (bottom), in which the second distribution must also be sparse, meaning $a_k \geq 95\%$ for some $k \neq i$. We see that $1\%$ of sparse attention distributions collapse at a noise level of $11\%$ when using the original definition, delayed until $17\%$ when using the sparse definition. Therefore we observe that the sparse-to-sparse case does occur, but requires a higher noise level.
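The two definitions of circuit collapse can be written as a single check on a pair of attention distributions, one computed without noise and one with noise. The function names and the list-based representation are assumptions made for illustration.

```python
def is_sparse(attention, threshold=0.95):
    """True if a single token receives at least `threshold` of the attention."""
    return max(attention) >= threshold


def argmax(values):
    """Index of the largest entry."""
    return max(range(len(values)), key=values.__getitem__)


def has_collapsed(clean, noisy, threshold=0.95, require_sparse_after=False):
    """Check whether a sparse attention distribution switches its top token under noise."""
    if not is_sparse(clean, threshold):
        return False  # only initially sparse distributions are considered
    if argmax(noisy) == argmax(clean):
        return False  # the maximally attended token is unchanged
    # Baseline definition: any switch counts. Alternative: the result must also be sparse.
    return is_sparse(noisy, threshold) if require_sparse_after else True


clean = [0.97, 0.02, 0.01]
noisy = [0.40, 0.55, 0.05]
baseline_collapse = has_collapsed(clean, noisy)
sparse_collapse = has_collapsed(clean, noisy, require_sparse_after=True)
```

A collapse rate would then be the fraction of initially sparse distributions for which the check returns true at a given noise level.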

Figure 13: Probability of circuit collapse vs increasing noise. Top: using the baseline definition. This is a reproduction of Figure 3. Bottom: requiring the attention distribution to remain sparse after switching to a different token.

Appendix C Main experiments: results with different models

In this appendix we reproduce the main experimental results using our model variations.

C.1 Embedding lengths

Figure 1 shows the empirical results demonstrating the attention-weighted spread of embeddings. Figures 15-16 show the results we obtain when we perform the same analysis using the Alternate and Large model variations. In all cases, we observe 90% of embeddings within a spread of roughly $\pm 30\%$ when using Pre-Norm. In all cases, the spread of embeddings for QKV-Norm is larger, although we note that the effect is smaller when using the Alternate variation.

Figure 14: Attention-weighted spread of embeddings at increasing model depth using the Baseline model and task configuration. This is a replication of Figure 1.
Figure 15: Attention-weighted spread of embeddings at increasing model depth using the Alternate model and task configuration.
Figure 16: Attention-weighted spread of embeddings at increasing model depth using the Large model and task configuration.

C.2 Model stability with simulated interference

Figure 3(left) shows the stability of the model predictions under simulated interference. Figure 18 shows the results we obtain when we perform the same analysis using the Alternate model variation. The Large model was not run due to its high computational load. We find that the Alternate model has a larger effect size than Baseline, with a $\gtrsim 20\%$ loss of per-token accuracy with only a $1\%$ noise effect. For completeness, we show QKV-Norm on the right-hand side. This is stable by construction, and only jitter due to finite sampling is observed. In these plots, we estimate the statistical uncertainty by evaluating over three datasets and calculating the standard error on the mean. This is plotted as a shaded band, but tends to be narrower than the line width.
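
As a hedged illustration of the uncertainty estimate, the sketch below computes a mean and its standard error from a handful of per-dataset accuracies. The use of the sample standard deviation (`ddof=1`) is an assumption, and the function name `mean_and_sem` is hypothetical.

```python
import numpy as np

def mean_and_sem(values):
    """Return the mean and the standard error on the mean of `values`.

    Assumes the unbiased sample standard deviation (ddof=1), divided by
    sqrt(n), where n is the number of independent evaluations.
    """
    values = np.asarray(values, dtype=float)
    mean = values.mean()
    sem = values.std(ddof=1) / np.sqrt(values.size)
    return mean, sem
```

With only three datasets the estimate of the error is itself noisy, which is one reason to read the shaded bands as indicative rather than precise.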

Figure 3(right) compares the stability when we only apply noise to sparse heads (defined as $\max_i a_i \geq 95\%$, thin dashed line) and non-sparse heads (defined as $\max_i a_i < 70\%$, thick dashed line). Figure 20 compares these results with the Alternate model variation. In both experiments, sparse-attention is stable under percent-level noise, and non-sparse distributions dominate this regime.

Note that later layers experience both the artificial noise injection as well as perturbation of their inputs due to the compounding of errors caused by noise in the previous layers.
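
The sparse and non-sparse selections above leave a gap between 70% and 95%. A minimal sketch of the classification follows; the label "intermediate" for the gap and the function name are illustrative assumptions, not terms from the experiments.

```python
import numpy as np

def classify_distribution(a, sparse_min=0.95, non_sparse_max=0.70):
    """Label an attention distribution by its maximum attention weight.

    "sparse"       if max_i a_i >= sparse_min
    "non-sparse"   if max_i a_i <  non_sparse_max
    "intermediate" otherwise (hypothetical label for the excluded gap)
    """
    peak = float(np.max(np.asarray(a, dtype=float)))
    if peak >= sparse_min:
        return "sparse"
    if peak < non_sparse_max:
        return "non-sparse"
    return "intermediate"
```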

Figure 17: Evolution of per-token accuracy as we increase noise on the $L_2$-norms of $\{q, k_t, m_t\}$ for the Baseline model and task configuration.
Figure 18: Evolution of per-token accuracy as we increase noise on the $L_2$-norms of $\{q, k_t, m_t\}$ for the Alternate model and task configuration.
Figure 19: Evolution of per-token accuracy as we increase noise on the $L_2$-norms of $\{q, k_t, m_t\}$ for the Baseline model and task configuration.
Figure 20: Evolution of per-token accuracy as we increase noise on the $L_2$-norms of $\{q, k_t, m_t\}$ for the Alternate model and task configuration.

C.3 Circuit collapse

Figure 3 shows the probability of circuit collapse. This is the probability that an attention distribution with no noise selects embedding $i$ with high probability $a_i \geq 95\%$, and when noise is added, it transitions such that some $k \neq i$ becomes the maximum attended embedding. Figures 22 and 23 show the results we obtain when we perform the same analysis using the Alternate and Large model variations. In both cases, we observe the onset of circuit collapse at smaller noise levels. Whereas for the Baseline model 1% of sparse attention distributions collapse at 11% noise, the corresponding noise level is 7.5% for Alternate and 5.5% for Large.

Figure 21: Probability of circuit collapse vs increasing noise using the Baseline model and task configuration. This is a replication of Figure 3.
Figure 22: Probability of circuit collapse vs increasing noise using the Alternate model and task configuration.
Figure 23: Probability of circuit collapse vs increasing noise using the Large model and task configuration.

Appendix D Additional comparisons between Pre-Norm and QKV-Norm

D.1 Model performance

Table 8 shows the per-token accuracy of the trained Baseline models. Pre-Norm and QKV-Norm have comparable in-distribution per-token accuracies of $91.4\%$ and $91.0\%$ respectively. However, for Pre-Norm (QKV-Norm), performance drops to $87.5\%$ ($82.5\%$) when generalising to intermediate task difficulty, and to $66.7\%$ ($46.8\%$) for increased difficulty. The larger performance drop of QKV-Norm implies that it has learned a less generalisable solution. This reinforces our motivation that architectural changes should be important for the inductive bias of a model.

Dataset | Pre-Norm | QKV-Norm
In-distribution | $91.38 \pm 0.04\%$ | $90.99 \pm 0.03\%$
OOD (interpolation) | $87.46 \pm 0.04\%$ | $82.54 \pm 0.04\%$
OOD (extrapolation) | $66.65 \pm 0.05\%$ | $46.76 \pm 0.05\%$
Table 8: Per-token accuracy for the Baseline models. Dataset configurations are shown in Table 5.

D.2 Training stability

Changing the normalisation layer is expected to affect the training rate and stability. To investigate this, Figure 24 shows the training curves for different model sizes and learning rates. The task is configured as presented in Table 7. The Depth parameter is the number of layers, where brackets indicate the values for an encoder-decoder model. For example, $(2,2)$ means that we use $2$ encoder blocks and $2$ decoder blocks. Each decoder block has a self-attention and a cross-attention layer, and so the total model has $6$ attention layers. A single Depth value indicates a decoder architecture, with the number of layers shown. Width is the number of neurons per layer, and Latent width is the number of neurons on the space of $\{q, k_t, v_t\}$ (called $N_{qkv}$ in section 4). Training curves on the top row use a learning rate of $0.001$, whilst those on the bottom row use a value of $0.0001$. In each panel, two training runs are shown, with different random seeds. Pre-Norm is shown in blue, and QKV-Norm in red.

We find that Pre-Norm training is unstable for large learning rates and model sizes, as shown by the flat blue curves in the top right-hand panels. Similar stabilisation improvements at large learning rates are reported for QK-Norm in [19, 20], which applies layer normalisation to $\{q, k_t\}$ but, unlike QKV-Norm, not to $v_t$. However, we note that training large models with a smaller learning rate leads to improved model performance, as shown by the panels on the bottom right. Finally, we note that both methods typically train the model at similar rates; however, small-model training follows a very different trajectory, with QKV-Norm learning more slowly at the beginning of training (bottom left panels). There is also some visible evidence that small-model training is actually less effective when using QKV-Norm with a large learning rate (top left panels).

Figure 24: Training curves when learning the task configuration shown in Table 7.

D.3 Attention sparsity

We find that our Pre-Norm models often exploit sparse-attention, whereas models trained with QKV-Norm do not. Similar behaviour is reported for QK-Norm in [18]. For a systematic comparison, Figure 25 shows a histogram of the maximum attention observed per-distribution (i.e. a histogram of $\max_i a_i$). When making this plot, we do not consider the first row of the attention matrix, in which the [ token attends fully to itself.
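
The histogram construction can be sketched as follows. This is a minimal illustration that assumes the attention weights of one example are stored in an array of shape (heads, queries, keys); the function name, the array layout, and the number of bins are assumptions rather than details of the original analysis.

```python
import numpy as np

def max_attention_histogram(attn, bins=20):
    """Histogram of max_i a_i over attention distributions.

    attn: array of shape (heads, T_query, T_key); each row sums to 1.
    The first query row is dropped, since there the first token can
    only attend to itself.
    """
    attn = np.asarray(attn, dtype=float)
    rows = attn[:, 1:, :]                      # skip the first query row
    max_per_row = rows.max(axis=-1).ravel()    # one value per distribution
    counts, edges = np.histogram(max_per_row, bins=bins, range=(0.0, 1.0))
    return max_per_row, counts, edges
```

A model that relies on sparse-attention would place a large share of `max_per_row` in the final bin, near $1$.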

We see that the Pre-Norm distribution has a sharp peak at $1$, indicating a significant use of sparse-attention. By contrast, the QKV-Norm distribution is weighted towards $0$ and has no peak at $1$. To verify this behaviour, Figure 26 shows an attention heatmap for a randomly chosen datapoint when using the Baseline Pre-Norm model, and Figure 27 shows the same datapoint for QKV-Norm. We observe a significantly less sparse attention matrix for QKV-Norm. Note that [18] also shows a similar visualisation.

Figure 25: Distribution of the maximum attention observed per-distribution, i.e. $\max_i a_i$, in the Baseline case. We observe that the Pre-Norm model often utilises sparse-attention, as seen by the peak at $1$. By contrast, QKV-Norm shows no such peak. Similar behaviour is reported for QK-Norm in [18].
Figure 26: Attention maps for a random in-distribution example using the Baseline Pre-Norm model. Several attention heads create sparse attention distributions.
Figure 27: Attention maps for a random in-distribution example using the Baseline QKV-Norm model. We observe much less sparsity than in the Pre-Norm model, shown in Figure 26. Similar behaviour is reported for QK-Norm in [18].

Appendix E Supplementary theorems

This appendix contains theorems that support the main results, providing additional context or being pre-requisite for the proofs in appendix F. We use the formulation of section 4, where $1 \leq \{t, t'\} \leq T$ are indices over tokens, $x \in \mathbb{R}^{N_x}$ is the message receiving embedding, $\{y_t \in \mathbb{R}^{N_y}\}$ are the message senders, $w_t = x^T W_{QK} y_t$, and $a_t = \texttt{softmax}_t w_t$ is the attention distribution.

Theorem 9.

Shifting attention scores wtsubscript𝑤𝑡w_{t}italic_w start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT by a constant offset does not affect the attention distribution. Therefore attention is fully determined by differences in scores.

Proof. Applying the shift $w_t \xrightarrow[\mathrm{offset~w}]{} w_t + \delta w ~\forall~ t$ with fixed $\delta w$, we have

$$
\begin{split}
a_t ~&=~ \frac{e^{w_t}}{\sum_{t'} e^{w_{t'}}} \\
&\xrightarrow[\mathrm{offset~w}]{}~ \frac{e^{\delta w} e^{w_t}}{\sum_{t'} e^{\delta w} e^{w_{t'}}} ~=~ \frac{e^{\delta w}}{e^{\delta w}} \frac{e^{w_t}}{\sum_{t'} e^{w_{t'}}} ~=~ 1 \cdot a_t ~=~ a_t
\end{split}
\tag{18}
$$

Alternatively we may write

$$
a_t ~=~ \frac{e^{w_t}}{\sum_{t'} e^{w_{t'}}} ~=~ \frac{e^{w_t}}{e^{w_t} \sum_{t'} e^{w_{t'} - w_t}} ~=~ \frac{1}{\sum_{t'} e^{w_{t'} - w_t}}
\tag{19}
$$

where $\left(w_{t'} + \delta w\right) - \left(w_t + \delta w\right) = w_{t'} - w_t$.
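
Theorem 9 is easy to check numerically. The sketch below is an illustration only: it compares a standard softmax with the difference form of Equation (19) and with a shifted set of scores. The function names are hypothetical.

```python
import numpy as np

def softmax(w):
    """Standard softmax over a 1-D array of scores."""
    w = np.asarray(w, dtype=float)
    e = np.exp(w - w.max())  # subtracting the max is itself a constant offset
    return e / e.sum()

def softmax_from_differences(w):
    """Equation (19): a_t = 1 / sum_{t'} exp(w_{t'} - w_t)."""
    w = np.asarray(w, dtype=float)
    diffs = w[None, :] - w[:, None]  # diffs[t, t'] = w_{t'} - w_t
    return 1.0 / np.exp(diffs).sum(axis=1)
```

The common numerical trick of subtracting `w.max()` before exponentiating relies on exactly this invariance.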

Theorem 10.

Multiplying attention scores by a positive factor changes the inverse-temperature of the attention distribution, modulating its sparsity (low temperature = less entropy = more sparse). Corollary: In the sparse limit, attention is fully determined by the order of $w_t$.

Proof. Applying the scaling $w_t \xrightarrow[\mathrm{scale~w}]{} \kappa w_t ~\forall~ t$ with fixed $\kappa > 0$, we have

$$
\begin{split}
\frac{e^{\kappa w_t}}{\sum_{t'} e^{\kappa w_{t'}}} ~~&=~~ \frac{1}{\sum_{t'} e^{\kappa (w_{t'} - w_t)}} \\
&\xrightarrow[\kappa \rightarrow 0]{}~~ \frac{1}{\sum_{t'} e^{0}} ~=~ \frac{1}{T} ~\forall~ t \qquad \text{[fully isotropic distribution]} \\
&\xrightarrow[\kappa \rightarrow \infty]{}~~ \begin{cases} 1 & \mathrm{for}~ t = \mathrm{argmax}_{t'} w_{t'} \\ 0 & \forall~ t \neq \mathrm{argmax}_{t'} w_{t'} \end{cases} \qquad \text{[fully sparse distribution]}
\end{split}
\tag{20}
$$

where the argmax operator is fully determined by the order of $w_t$.
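
The two limits of Theorem 10 can be illustrated with a short numerical sketch, assuming a small set of distinct scores. The helper names `scaled_attention` and `entropy` are illustrative.

```python
import numpy as np

def scaled_attention(w, kappa):
    """Softmax of kappa * w, where kappa > 0 acts as an inverse-temperature."""
    z = kappa * np.asarray(w, dtype=float)
    e = np.exp(z - z.max())
    return e / e.sum()

def entropy(a):
    """Shannon entropy in nats, treating 0 * log(0) as 0."""
    a = np.asarray(a, dtype=float)
    nz = a[a > 0]
    return float(-(nz * np.log(nz)).sum())
```

For small `kappa` the output approaches the uniform distribution $1/T$, and for large `kappa` it approaches a one-hot vector on the largest score, with the entropy falling as `kappa` grows.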

Theorem 11.

In the No-Norm case, the attention distribution $a_t$ is defined by the projection of $y_t$ onto a fixed vector $y_x$ for a given $x$. The length of $y_x$ is an inverse-temperature parameter.

Proof. Write $w_t = x^T W_{QK} y_t = (W_{QK}^T x)^T y_t \equiv y_x^T y_t$ where $y_x \triangleq W_{QK}^T x \in \mathbb{R}^{N_y}$, which is the dot-product between $y_t$ and a fixed vector $y_x$ on the row space of $W_{QK}$. Then, re-writing in terms of the vector lengths and the enclosed angle $\theta_{y_t} = \angle(y_x, y_t)$, we have $w_t = |y_x| |y_t| \cos\theta_{y_t}$. The factor $|y_x|$ is identical for all $t$, making it an inverse-temperature.
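
A small numerical sketch of Theorem 11 follows. It assumes the senders $y_t$ are stacked as the rows of a matrix `Y`; the function names and array layout are illustrative assumptions.

```python
import numpy as np

def scores_direct(x, W_qk, Y):
    """w_t = x^T W_QK y_t for every row y_t of Y (shape (T, N_y))."""
    return np.array([x @ W_qk @ y_t for y_t in Y])

def scores_from_projection(x, W_qk, Y):
    """Same scores, written as |y_x| |y_t| cos(theta) with y_x = W_QK^T x."""
    y_x = W_qk.T @ x                         # fixed vector for a given x
    len_x = np.linalg.norm(y_x)              # shared inverse-temperature factor
    len_t = np.linalg.norm(Y, axis=1)
    cos_theta = (Y @ y_x) / (len_x * len_t)  # cosine of the enclosed angle
    return len_x * len_t * cos_theta
```

Rescaling `x` by a positive constant rescales `len_x`, and therefore every score, by the same constant, which by Theorem 10 only changes the temperature of the resulting attention distribution.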

Theorem 12.

In the No-Norm case, bias parameters in the construction of query and key vectors are nullified by the softmax, or only contribute terms that may be recovered if $x$ contains a constant direction.

Proof. Consider a modification to the construction of query and key vectors that uses the affine transformations $q = W_Q x + b_Q$ and $k_t = W_K y_t + b_K$, with $W_Q \in \mathbb{R}^{N_{qkv} \times N_x}$, $W_K \in \mathbb{R}^{N_{qkv} \times N_y}$, $W_{QK} \triangleq W_Q^T W_K$, and $b_Q, b_K \in \mathbb{R}^{N_{qkv}}$. The dot-product attention scores are then:

$$
\begin{split}
w_t ~&=~ q^T k_t \\
&=~ \left(W_Q x + b_Q\right)^T \left(W_K y_t + b_K\right) \\
&=~ x^T W_{QK} y_t ~+~ (W_Q^T b_K)^T x ~+~ (W_K^T b_Q)^T y_t ~+~ b_Q^T b_K \\
w_t ~+~ const ~&=~ x^T W_{QK} y_t ~+~ (W_Q^T b_K)^T x ~+~ (W_K^T b_Q)^T y_t \\
&\triangleq~ x^T W_{QK} y_t ~+~ \rho_x^T x ~+~ \rho_y^T y_t \qquad \rightarrow~ \rho_x^T x = const ~\mathrm{given}~ x ~\rightarrow \\
&=~ x^T W_{QK} y_t ~+~ \rho_y^T y_t \qquad \rightarrow~ W_{QK} \triangleq \Omega^T \Lambda \Sigma ~\text{via SVD}~ \rightarrow \\
&=~ x^T \Omega^T \Lambda \Sigma y_t ~+~ \rho_y^T y_t \qquad \rightarrow~ x' \triangleq \Omega x,~~ y'_t \triangleq \Sigma y_t ~\rightarrow \\
&=~ {x'}^T \Lambda y'_t ~+~ \rho_y^T y_t
\end{split}
\tag{21}
$$

After expanding the terms, we find an additive constant $b_Q^T b_K$, and move this onto the LHS. Theorem 9 states that this has no impact on the output of the softmax operator. We identify $\rho_x \triangleq W_Q^T b_K$ and $\rho_y \triangleq W_K^T b_Q$ as vectors on the row-spaces of $W_Q$ and $W_K$ respectively, defined as linear maps of the special directions $b_K$ and $b_Q$. Since $x$ is constant for each softmax, $\rho_x^T x$ is constant, and we absorb it into the LHS.
We perform the singular value decomposition $W_{QK} \triangleq \Omega^T \Lambda \Sigma$ where $\{\Omega \in \mathbb{R}^{N_x \times N_x},~\Sigma \in \mathbb{R}^{N_y \times N_y}\}$ are orthonormal matrices and $\Lambda \in \mathbb{R}^{N_x \times N_y}$ is a diagonal matrix of non-negative singular values with maximum rank $\min(N_x, N_y, N_{qkv})$. Orthonormal matrices apply a basis change to the embedding space using rotations and reflections. We write the transformed embeddings as $x' \triangleq \Omega x$ and $y'_t \triangleq \Sigma y_t$. The dot-product then has two terms:

  1. ${x'}^T \Lambda y'_t = \sum_i \Lambda_{ii} x'_i y'_{ti}$ sculpts the attention distribution according to pairwise relationships between embeddings. We can say that $\{\Omega, \Sigma\}$ align the bases of $x$ and $y_t$, mapping them onto a common orthonormal coordinate system. $\Lambda_{ii}$ then assigns an importance weight to each coordinate $i$, determining the contribution of $x'_i y'_{ti}$.

  2. $\rho_y^T y_t$ means “token $t$ sends to all receivers when $y_t \parallel \rho_y$”, where $\rho_y$ must be a vector on the row-space of $W_K$. This may be recovered in the expansion of ${x'}^T \Lambda y'_t$ if there exists a direction $i$ for which $x'_i = const$.
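The reduction in Eq. 21 can be checked numerically. The following is a minimal sketch, assuming random Gaussian weights and hypothetical sizes `N_x`, `N_y`, `N_qkv`, `T` that are not taken from the paper. It compares the full scores $(W_Q x + b_Q)^T (W_K y_t + b_K)$ with the reduced form ${x'}^T \Lambda y'_t + \rho_y^T y_t$ and confirms that they differ only by a $t$-independent offset, so the softmax outputs agree.

```python
import numpy as np

def softmax(z):
    """Numerically stable softmax over a 1-D array of scores."""
    e = np.exp(z - z.max())
    return e / e.sum()

rng = np.random.default_rng(0)
N_x, N_y, N_qkv, T = 6, 5, 3, 4  # hypothetical sizes, chosen for illustration

W_Q, b_Q = rng.normal(size=(N_qkv, N_x)), rng.normal(size=N_qkv)
W_K, b_K = rng.normal(size=(N_qkv, N_y)), rng.normal(size=N_qkv)
x = rng.normal(size=N_x)        # receiving embedding, fixed within one softmax
Y = rng.normal(size=(T, N_y))   # row t holds the sending embedding y_t

# Full scores w_t = (W_Q x + b_Q)^T (W_K y_t + b_K), one entry per token t.
w_full = (Y @ W_K.T + b_K) @ (W_Q @ x + b_Q)

# Reduced scores x'^T Lambda y'_t + rho_y^T y_t, with W_QK = Omega^T Lambda Sigma.
W_QK = W_Q.T @ W_K
U, s, Vh = np.linalg.svd(W_QK, full_matrices=True)
Omega, Sigma = U.T, Vh                  # NumPy returns W_QK = U diag(s) Vh
Lam = np.zeros((N_x, N_y))
Lam[:s.size, :s.size] = np.diag(s)      # rectangular diagonal matrix
rho_y = W_K.T @ b_Q
x_p = Omega @ x                         # x' = Omega x
Y_p = Y @ Sigma.T                       # row t holds y'_t = Sigma y_t
w_reduced = Y_p @ (Lam.T @ x_p) + Y @ rho_y

offset = w_full - w_reduced             # rho_x^T x + b_Q^T b_K, the same for every t
rank = int(np.sum(s > 1e-10))           # numerical rank of W_QK
```

In this sketch `offset` is constant across tokens, and `rank` cannot exceed `N_qkv` because $W_{QK}$ is a product of two matrices with inner dimension `N_qkv`.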

Appendix F Proofs of theorems in the main text

This appendix provides proofs for the theorems presented in Sections 5 and 6.

Theorem 1. No-Norm: If two heads with finite non-zero temperature attend to different semantic subspaces, the subspaces must be linearly independent $\mathbb{S}^{N_\alpha}_\alpha \equiv \mathbb{R}^{N_\alpha}$. Corollary: $W_{QK}$ is a low-rank matrix with (left and right) null-spaces that span all non-attended information.

Proof. Let $\theta_A$ and $\theta_B$ be co-ordinates for the subspaces of $x$ attended to by heads A and B respectively, and $\phi$ be all other information. Let $\theta_A \perp \theta_B \perp \phi$ and $x \perp y_t$, where $\perp$ denotes independence. Without loss of generality, write

$$
x(\theta_A, \theta_B, \phi) ~=~ x_A(\theta_A) ~+~ x_B(\theta_B) ~+~ x_{other}(\theta_A, \theta_B, \phi)
\tag{22}
$$

Then write

$$
\begin{aligned}
w_t^{(A)}(\theta_A) ~&=~ \left(W_{QK}^{(A)} y_t\right)^T x(\theta_A, \theta_B, \phi) \\
&=~ \left(W_{QK}^{(A)} y_t\right)^T x_A(\theta_A) ~+~ \left(W_{QK}^{(A)} y_t\right)^T x_B(\theta_B) ~+~ \left(W_{QK}^{(A)} y_t\right)^T x_{other}(\theta_A, \theta_B, \phi)
\end{aligned}
\tag{23}
$$

which requires $\left(W_{QK}^{(A)} y_t\right)^T x_B(\theta_B) = 0$ and $\left(W_{QK}^{(A)} y_t\right)^T x_{other}(\theta_A, \theta_B, \phi) = 0$, since any cancellation between the two terms must be independent of $\theta_A, \phi$ and so can be absorbed entirely into the function $x_B(\theta_B)$.
This means that $x_B(\theta_B)$ and $x_{other}(\theta_A, \theta_B, \phi)$ must both be orthogonal to $W_{QK}^{(A)} y_t$, meaning that they reside on the left null space of $W_{QK}^{(A)}$, or are projected by ${W_{QK}^{(A)}}^T$ onto a null space of $y_t$.

Head A can only attend to $\theta_A$ if $x_A(\theta_A)$ is not on either of these null spaces, meaning that $x_A(\theta_A)$ is linearly independent of $x_B(\theta_B)$ and $x_{other}(\theta_A, \theta_B, \phi)$. Likewise, for head B,

$$
\begin{aligned}
w_t^{(B)}(\theta_B) ~&=~ \left(W_{QK}^{(B)} y_t\right)^T x(\theta_A, \theta_B, \phi) \\
&=~ \left(W_{QK}^{(B)} y_t\right)^T x_A(\theta_A) ~+~ \left(W_{QK}^{(B)} y_t\right)^T x_B(\theta_B) ~+~ \left(W_{QK}^{(B)} y_t\right)^T x_{other}(\theta_A, \theta_B, \phi)
\end{aligned}
\tag{24}
$$

requires that $x_B(\theta_B)$ is linearly independent of both $x_A(\theta_A)$ and $x_{other}(\theta_A, \theta_B, \phi)$. Since $x_{other}$ resides on both null spaces, it is linearly independent of both $x_A(\theta_A)$ and $x_B(\theta_B)$, and may be seen as a third subspace that passes information through to subsequent layers.

We can also write $w_t = \left(W_{QK}^T x\right)^T y_t$, and so the same argument also holds for subspaces on $y_t$. In this case, non-attended subspaces are spanned by the right null space of $W_{QK}$.
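The No-Norm construction can be illustrated with a small numerical sketch. It assumes hypothetical subspace sizes and a random, generally non-orthogonal basis, and it simplifies $x_{other}$ to depend on $\phi$ only. Head A's $W_{QK}^{(A)}$ is built so that the B and other subspaces lie in its left null space, which makes the scores independent of $\theta_B$ and $\phi$ even though the subspaces are merely linearly independent.

```python
import numpy as np

rng = np.random.default_rng(1)
N_A, N_B, N_o, N_y, T = 2, 3, 2, 4, 5  # hypothetical sizes
N_x = N_A + N_B + N_o

# Columns of `basis` span three linearly independent (not orthogonal) subspaces.
basis = rng.normal(size=(N_x, N_x))
E_A, E_B, E_o = np.split(basis, [N_A, N_A + N_B], axis=1)

# Rows of the inverse read off coordinates: R_A @ E_A = I, R_A @ E_B = 0, R_A @ E_o = 0.
R_A = np.linalg.inv(basis)[:N_A]

M = rng.normal(size=(N_A, N_y))   # arbitrary map acting on the attended coordinates
W_QK_A = R_A.T @ M                # rank <= N_A; left null space contains span(E_B, E_o)
Y = rng.normal(size=(T, N_y))     # row t holds y_t

def scores(theta_A, theta_B, phi):
    """No-Norm scores w_t = x^T W_QK^(A) y_t for every token t."""
    x = E_A @ theta_A + E_B @ theta_B + E_o @ phi
    return Y @ (W_QK_A.T @ x)

theta_A = rng.normal(size=N_A)
theta_B = rng.normal(size=N_B)
phi = rng.normal(size=N_o)

base = scores(theta_A, theta_B, phi)
moved_B = scores(theta_A, rng.normal(size=N_B), rng.normal(size=N_o))  # same theta_A
moved_A = scores(rng.normal(size=N_A), theta_B, phi)                   # new theta_A
```

Here `moved_B` matches `base` up to floating-point error while `moved_A` does not, which is the behaviour the theorem describes for a head that attends only to $\theta_A$.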

Theorem 2. Pre-Norm: Semantic subspaces must be represented as orthogonal spheres $\mathbb{S}^{N_\alpha} \equiv \mathcal{S}^{N_\alpha - 1}$ defined using the $L_2$-norm. Corollary: if either orthogonality or constant-norm is violated, semantic subspaces interfere through a multiplicative factor on $w_t$.

Proof. Write

$$
x(\theta_A, \theta_B, \phi) ~=~ x_A(\theta_A) ~+~ x_B(\theta_B) ~+~ x_{AB}(\theta_A, \theta_B) ~+~ x_{other}(\theta_A, \theta_B, \phi)
\tag{25}
$$

Then for head A we have

$$
w_t^{(A)}(\theta_A) ~=~ \frac{1}{\left|y_t\right| \left|x(\theta_A, \theta_B, \phi)\right|} {w^*_t}^{(A)}(\theta_A)
\tag{26}
$$

where $w^*_t$ are the attention scores from the No-Norm case, which requires $x_A(\theta_A)$ and $x_B(\theta_B)$ to be linearly independent. Now we additionally require $\left|x(\theta_A, \theta_B, \phi)\right| \perp \theta_B, \phi$, with

$$
|x| ~=~ \sqrt{|x_A|^2 ~+~ |x_B + x_{AB} + x_{other}|^2 ~+~ 2 x_A^T \left(x_B + x_{AB} + x_{other}\right)}
\tag{27}
$$

where we suppress parameter dependence for readability. Since $\sqrt{\cdot}$ is a monotonic function, this can only be satisfied if

$$
|x_A|^2 ~+~ |x_B + x_{AB} + x_{other}|^2 ~+~ 2 x_A^T \left(x_B + x_{AB} + x_{other}\right) ~\perp~ \theta_B, \phi
\tag{28}
$$

Repeating this process for head B gives

$$
|x_B|^2 ~+~ |x_A + x_{AB} + x_{other}|^2 ~+~ 2 x_B^T \left(x_A + x_{AB} + x_{other}\right) ~\perp~ \theta_A, \phi
\tag{29}
$$

Combining and collecting dependencies, we then have

\begin{align}
|x_A|^2 ~&=~ const & &\forall~ \theta_A \tag{30} \\
|x_B|^2 ~&=~ const & &\forall~ \theta_B \tag{31} \\
\left(x_{AB} + 2x_A + 2x_B\right)^T x_{AB} ~+~ 2 x_A^T x_B ~&=~ const & &\forall~ \theta_A, \theta_B \tag{32} \\
\left(x_{other} + 2x_A + 2x_B + 2x_{AB}\right)^T x_{other} ~&=~ const & &\forall~ \theta_A, \theta_B, \phi \tag{33}
\end{align}

We can go one step further, noticing that each individual term carries a different functional dependence, and so must independently be constant (N.B. if $|x_{AB}|^2 \propto x_A^T x_B$ then $|x_{AB}|^2 = const$ reduces to $x_A^T x_B = const$, which is already required). We then have, $\forall~ \mu, \nu \in \{A, B, AB, other\}$,

$$
|x_\mu| = const \qquad \mathrm{and} \qquad x_\mu^T x_\nu = const
\tag{34}
$$

The requirements $|x_A(\theta_A)| = const ~\forall~ \theta_A$ and $|x_B(\theta_B)| = const ~\forall~ \theta_B$ mean that the semantic subspaces have a spherical structure defined by the $L_2$-norm $|\cdot|$.

Now consider the requirement $x_A(\theta_A)^T x_B(\theta_B) = \mathrm{const}$. Say that $\theta_A$ and $\theta_B$ have $N_A$ and $N_B$ degrees of freedom, meaning that $x_A$ and $x_B$ have $N_A - 1$ and $N_B - 1$ respectively, since they each lose one by confinement to the sphere. Say that the constant is nonzero such that $x_A^T x_B \neq 0$. This means that there must be some direction $i$ for which $x_{Ai} x_{Bi} \neq 0$.
If we know all $N_A - 1$ coordinates of $x_A$, and all $N_B - 2$ coordinates of $x_B$ except for direction $i$, then we also know the value of $x_{Bi}$, because it is fixed by the constant. However, this would mean that $x_A$ and $x_B$ are not independent, violating the condition $\theta_A \perp \theta_B$. The only way to satisfy independence is if $x_{Ai} x_{Bi} = 0~\forall~i$, ensuring that degrees of freedom on $x_A$ and $x_B$ never become entangled.
Therefore, to satisfy semantic independence, we must have $x_A(\theta_A)^T x_B(\theta_B) = 0~\forall~\theta_A, \theta_B$. This means that the subspaces are not just linearly independent, but orthogonal.

We have shown the proof for semantic subspaces of $x$. As for Theorem 1, the same structure must be true for $y_t$ by symmetry.
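The role of orthogonality in this argument can be illustrated numerically. The sketch below, in plain Python, places $x_A$ and $x_B$ on circles in $\mathbb{R}^4$; the dimensions, radii, grid size, and helper names are illustrative assumptions and are not taken from the paper. When the two circles lie in orthogonal planes, $|x_A + x_B|$ is the same for every pair of angles, so a common normalisation factor carries no information from one subspace into the other. When the planes share a direction, the norm varies with both angles.

```python
import math

def circle_point(radius, angle, axis_u, axis_v):
    """Point on a circle of the given radius in the plane spanned by two orthonormal axes."""
    return [radius * (math.cos(angle) * u + math.sin(angle) * v)
            for u, v in zip(axis_u, axis_v)]

def norm_of_sum(x_a, x_b):
    """L2-norm of x_A + x_B."""
    return math.sqrt(sum((a + b) ** 2 for a, b in zip(x_a, x_b)))

def norm_spread(plane_a, plane_b, r_a=1.0, r_b=2.0, steps=12):
    """Max minus min of |x_A + x_B| over a grid of (theta_A, theta_B)."""
    norms = []
    for i in range(steps):
        for j in range(steps):
            x_a = circle_point(r_a, 2 * math.pi * i / steps, *plane_a)
            x_b = circle_point(r_b, 2 * math.pi * j / steps, *plane_b)
            norms.append(norm_of_sum(x_a, x_b))
    return max(norms) - min(norms)

# Standard basis of R^4 (hypothetical embedding dimension).
e0 = [1.0, 0.0, 0.0, 0.0]
e1 = [0.0, 1.0, 0.0, 0.0]
e2 = [0.0, 0.0, 1.0, 0.0]
e3 = [0.0, 0.0, 0.0, 1.0]

# Orthogonal planes: the norm of the sum does not depend on either angle.
orthogonal_spread = norm_spread((e0, e1), (e2, e3))
# Planes sharing the direction e1: the norm now depends on both angles.
overlapping_spread = norm_spread((e0, e1), (e1, e2))

print("orthogonal spheres :", orthogonal_spread)
print("overlapping spheres:", overlapping_spread)
```

In this toy setting the first spread is zero up to floating-point rounding, while the second is of the same order as the radii themselves.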


Theorem 3. QKV-Norm: Semantic subspaces must be linearly separable, reproducing the No-Norm case.

Proof.  We have

\begin{equation}
w_t^{(A)}(\theta_A) ~=~ \frac{1}{\left|k_t^{(A)}\right| \left|q^{(A)}\right|} {w^*_t}^{(A)}(\theta_A) \tag{35}
\end{equation}

where $w^*_t$ are the attention scores from the No-Norm case, which requires $x_A(\theta_A)$ and $x_B(\theta_B)$ to be linearly independent. Use

\begin{equation}
x(\theta_A, \theta_B, \phi) ~=~ x_A(\theta_A) ~+~ x_B(\theta_B) ~+~ x_{\mathrm{other}}(\theta_A, \theta_B, \phi) \tag{36}
\end{equation}

and

\begin{equation}
\begin{split}
q^{(A)}(\theta_A) ~&=~ W_Q^{(A)} x(\theta_A, \theta_B, \phi) \\
&=~ W_Q^{(A)} x_A(\theta_A) ~+~ W_Q^{(A)} x_B(\theta_B) ~+~ W_Q^{(A)} x_{\mathrm{other}}(\theta_A, \theta_B, \phi)
\end{split}
\tag{37}
\end{equation}

Since we already have the condition of linearly independent $x_A, x_B$, there must exist a linear projection operator $P_A$ such that $P_A x_A = x_A$. Defining $W_Q^{(A)} = P_A$, we then have

\begin{equation}
q^{(A)}(\theta_A) ~=~ W_Q^{(A)} x_A(\theta_A) \tag{38}
\end{equation}

This demonstrates that it is possible to separate linearly independent semantic subspaces on $x$. By symmetry of $w_t^{(A)}(\theta_A)$, the same must be true for $y_t$.
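As an illustration of the projection step, the sketch below builds an oblique projector for two directions in $\mathbb{R}^2$ that are linearly independent but not orthogonal. The basis vectors, the matrix `P_A`, and the helper names are hypothetical choices. The example additionally assumes that $x_{\mathrm{other}}$ vanishes and that the projector maps the B direction to zero, so that only the A component survives in the query.

```python
def matvec(matrix, vector):
    """Matrix-vector product for small dense lists."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, vector)) for row in matrix]

# Two linearly independent, non-orthogonal directions (illustrative values).
basis_a = [1.0, 0.0]
basis_b = [1.0, 1.0]

# Oblique projector onto span(basis_a) along span(basis_b):
# P_A basis_a = basis_a and P_A basis_b = 0 (the latter is an assumption of this sketch).
P_A = [[1.0, -1.0],
       [0.0, 0.0]]

def embed(alpha, beta):
    """x = x_A + x_B with x_A = alpha * basis_a and x_B = beta * basis_b."""
    return [alpha * a + beta * b for a, b in zip(basis_a, basis_b)]

def query_a(alpha, beta):
    """q^(A) = W_Q^(A) x, taking W_Q^(A) = P_A."""
    return matvec(P_A, embed(alpha, beta))

# Same content in subspace A, different content in subspace B.
q1 = query_a(0.7, -3.0)
q2 = query_a(0.7, 5.0)
print(q1, q2)
```

Both queries agree up to floating-point rounding, so in this toy case the query depends only on the coefficient in subspace A even though the two directions are not orthogonal.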


Theorem 4. Consider independent infinitesimal perturbations on queries $\epsilon^q \in \mathbb{R}^{N_{qkv}}$, keys $\epsilon^k_t \in \mathbb{R}^{N_{qkv}}$, and messages $\epsilon^m_t \in \mathbb{R}^{N_{qkv}}$. These propagate onto $\Delta x = \sum_t a_t m_t$ as

\begin{align}
\epsilon^{\Delta x(q)} ~~&\xrightarrow[\epsilon^q \rightarrow 0]{~~\mathrm{perturb~q}~~}~~ \mathop{\mathbb{E}}_{a_t}\Big[m_t \tilde{k}_t^T\Big] \epsilon^q & \tilde{k}_t ~&\triangleq~ k_t - \mathop{\mathbb{E}}_{a_t}\Big[k_t\Big] \tag{39} \\
\epsilon^{\Delta x(k)} ~~&\xrightarrow[\epsilon^k_t \rightarrow 0]{~~\mathrm{perturb~k}~~}~~ \mathop{\mathbb{E}}_{a_t}\Big[\tilde{m}_t {\epsilon^k_t}^T\Big] q & \tilde{m}_t ~&\triangleq~ m_t - \mathop{\mathbb{E}}_{a_t}\Big[m_t\Big] \tag{40} \\
\epsilon^{\Delta x(m)} ~~&\xrightarrow[\epsilon^m_t \rightarrow 0]{~~\mathrm{perturb~m}~~}~~ \mathop{\mathbb{E}}_{a_t}\Big[\epsilon^m_t\Big] \tag{41}
\end{align}

where $\tilde{z}_t$ is the value of $z_t$ measured from the attention-weighted centroid $\mathbb{E}_{a_t}[z_t] = \sum_t a_t z_t$.
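The centred quantities in this statement can be made concrete with a small sketch. The code below computes attention weights, the attention-weighted centroid of the keys, and the centred keys $\tilde{k}_t$ for a toy example; the numerical values and function names are assumptions chosen only for illustration. By construction, the attention-weighted mean of the centred keys is zero.

```python
import math

def softmax(scores):
    """Softmax over a list of scores, shifted by the max for numerical stability."""
    shift = max(scores)
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def centroid(weights, vectors):
    """Attention-weighted centroid E_a[z_t] = sum_t a_t z_t."""
    dim = len(vectors[0])
    return [sum(w * v[d] for w, v in zip(weights, vectors)) for d in range(dim)]

def centred(weights, vectors):
    """z~_t = z_t - E_a[z_t] for every token t."""
    c = centroid(weights, vectors)
    return [[v_d - c_d for v_d, c_d in zip(v, c)] for v in vectors]

# Toy query and keys (illustrative values).
query = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 2.0], [-1.0, 1.0]]

scores = [sum(k_d * q_d for k_d, q_d in zip(k, query)) for k in keys]
attention = softmax(scores)
keys_tilde = centred(attention, keys)

# The attention-weighted mean of the centred keys vanishes.
residual = centroid(attention, keys_tilde)
print(attention, residual)
```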

Proof.  Consider $q \rightarrow q + \epsilon^q$ where $\epsilon^q$ are infinitesimal perturbations on $q$. Then $\Delta x \rightarrow \Delta x + \epsilon^{\Delta x(q)}$ where by Taylor expansion we find

\begin{equation}
\epsilon^{\Delta x(q)} ~=~ \frac{\partial \Delta x}{\partial q} \epsilon^q ~+~ \mathcal{O}\left({\epsilon^q}^2\right) \tag{42}
\end{equation}

where the leading term is a matrix $\frac{\partial \Delta x}{\partial q}$ acting on a vector $\epsilon^q$. Differentiating gives

\begin{equation}
\frac{\partial \Delta x}{\partial q} ~=~ \sum_{ij} m_i \frac{\partial a_i}{\partial w_j} \frac{\partial w_j}{\partial q} \tag{43}
\end{equation}

with $a_i = \texttt{softmax}_i(w_i)$ and $w_i = k_i^T q$, and we are using $i, j, k$ etc. to index over tokens instead of $t, t', t''$ etc., because this is more readable when we have many summations. Then


\begin{equation}
\begin{split}
\frac{\partial a_i}{\partial w_j} ~&=~ \frac{\partial}{\partial w_j} ~ \frac{e^{w_i}}{\sum_k e^{w_k}} \\
&=~ \frac{\delta_{ij} e^{w_i}}{\sum_k e^{w_k}} ~+~ e^{w_i} \left( - \frac{e^{w_j}}{\left(\sum_k e^{w_k}\right)^2} \right) \\
&=~ \frac{e^{w_i}}{\sum_k e^{w_k}} \left( \delta_{ij} ~-~ \frac{e^{w_j}}{\sum_l e^{w_l}} \right) \\
&=~ a_i \left( \delta_{ij} ~-~ a_j \right)
\end{split}
\tag{44}
\end{equation}

and $\frac{\partial w_i}{\partial q} = k_i^T$, where we retain the transpose to indicate that this is an element of the dual vector space (i.e. covector). Inserting these results into our expression for $\epsilon^{\Delta x(q)}$ gives

\begin{equation}
\begin{split}
\epsilon^{\Delta x(q)} ~&=~ \sum_{ij} m_i a_i \left( \delta_{ij} ~-~ a_j \right) k_j^T \epsilon^q \\
&=~ \sum_i m_i a_i \left( k_i ~-~ \sum_j a_j k_j \right)^T \epsilon^q \\
&=~ \sum_i m_i a_i \tilde{k}_i^T \epsilon^q \\
&=~ \mathop{\mathbb{E}}_{a_i}\Big[ m_i \tilde{k}_i^T \Big] \epsilon^q
\end{split}
\tag{45}
\end{equation}

This is the result for Eq. 39. Repeating the process for perturbations on $k_i$, we have

\begin{equation}
\epsilon^{\Delta x(k)} ~=~ \sum_i \frac{\partial \Delta x}{\partial k_i} \epsilon^k_i ~+~ \mathcal{O}\left({\epsilon^k}^2\right) \tag{46}
\end{equation}

and

\begin{equation}
\begin{split}
\frac{\partial \Delta x}{\partial k_i} ~&=~ \sum_{jk} m_j \frac{\partial a_j}{\partial w_k} \frac{\partial w_k}{\partial k_i} \\
&=~ \sum_{jk} m_j a_j \left( \delta_{jk} ~-~ a_k \right) \delta_{ki} q^T \\
&=~ \sum_j m_j a_j \left( \delta_{ji} ~-~ a_i \right) q^T \\
&=~ a_i \tilde{m}_i q^T
\end{split}
\tag{47}
\end{equation}

Therefore

\begin{equation}
\epsilon^{\Delta x(k)} ~=~ \sum_i a_i \tilde{m}_i q^T \epsilon^k_i ~=~ \mathop{\mathbb{E}}_{a_i}\Big[ \tilde{m}_i {\epsilon^k_i}^T \Big] q \tag{48}
\end{equation}

which is the result for Eq. 40. Finally,

\begin{equation}
\begin{split}
\epsilon^{\Delta x(m)} ~&=~ \sum_i \frac{\partial \Delta x}{\partial m_i} \epsilon^m_i \\
&=~ \sum_i a_i \epsilon^m_i \\
&=~ \mathop{\mathbb{E}}_{a_i}\Big[ \epsilon^m_i \Big]
\end{split}
\tag{49}
\end{equation}

using $\frac{\partial \Delta x}{\partial m_i} = \frac{\partial}{\partial m_i} \sum_j a_j m_j = \sum_j a_j \delta_{ij} = a_i$. This is the result for Eq. 41.
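The three first-order results can be checked against a direct recomputation of $\Delta x$. The sketch below is a minimal finite-perturbation check in plain Python, assuming unnormalised dot-product scores $w_t = k_t^T q$ as in the proof; the toy values, the step size `h`, and all function names are illustrative assumptions. For each of Eqs. 39, 40 and 41 it compares the predicted shift with the actual change in $\Delta x$.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def add(u, v):
    return [a + b for a, b in zip(u, v)]

def softmax(scores):
    shift = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def weighted_sum(weights, vectors):
    """sum_t weights[t] * vectors[t], component by component."""
    return [sum(w * v[d] for w, v in zip(weights, vectors))
            for d in range(len(vectors[0]))]

def attend(q, keys, msgs):
    """Attention weights a_t and message delta_x = sum_t a_t m_t."""
    a = softmax([dot(k, q) for k in keys])
    return a, weighted_sum(a, msgs)

def centred(a, vectors):
    """Vectors measured from their attention-weighted centroid."""
    mean = weighted_sum(a, vectors)
    return [[x - mu for x, mu in zip(v, mean)] for v in vectors]

def max_error(predicted, q2, keys2, msgs2, baseline):
    """Largest gap between the predicted and the actual shift of delta_x."""
    _, perturbed = attend(q2, keys2, msgs2)
    return max(abs(p - (new - old))
               for p, new, old in zip(predicted, perturbed, baseline))

# Toy inputs (illustrative values only).
q = [0.3, -0.8]
keys = [[1.0, 0.5], [-0.4, 1.2], [0.7, -0.9]]
msgs = [[2.0, 0.0, 1.0], [0.0, -1.0, 0.5], [1.5, 1.0, -2.0]]
h = 1e-6
eps_q = [h, -2 * h]
eps_k = [[h, 0.0], [-h, 2 * h], [0.5 * h, h]]
eps_m = [[h, 0.0, -h], [2 * h, h, 0.0], [0.0, -h, h]]

a, dx = attend(q, keys, msgs)
k_tilde, m_tilde = centred(a, keys), centred(a, msgs)

# Eq. 39: E_a[m_t k~_t^T] eps_q
pred_q = weighted_sum([a_t * dot(kt, eps_q) for a_t, kt in zip(a, k_tilde)], msgs)
# Eq. 40: E_a[m~_t eps_k_t^T] q
pred_k = weighted_sum([a_t * dot(q, e) for a_t, e in zip(a, eps_k)], m_tilde)
# Eq. 41: E_a[eps_m_t]
pred_m = weighted_sum(a, eps_m)

err_q = max_error(pred_q, add(q, eps_q), keys, msgs, dx)
err_k = max_error(pred_k, q, [add(k, e) for k, e in zip(keys, eps_k)], msgs, dx)
err_m = max_error(pred_m, q, keys, [add(m, e) for m, e in zip(msgs, eps_m)], dx)
print(err_q, err_k, err_m)
```

With perturbations of size `h`, the gaps for queries and keys are expected to scale as `h` squared, consistent with the neglected second-order terms in Eqs. 42 and 46, while the message perturbation enters linearly.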

Theorem 5. For sparse attention:

$$
\epsilon^{\Delta x(q)}\xrightarrow[\epsilon^{q}\rightarrow 0]{\mathrm{~~perturb~q~~}}0
\qquad
\epsilon^{\Delta x(k)}\xrightarrow[\epsilon^{k}_{t}\rightarrow 0]{\mathrm{~~perturb~k~~}}0
\qquad
\epsilon^{\Delta x(m)}\xrightarrow[\epsilon^{m}_{t}\rightarrow 0]{\mathrm{~~perturb~m~~}}\epsilon^{m}_{t^{*}}
\tag{50}
$$

i.e. the message is stable with respect to small interference in the queries and keys. Interference in the selected value is linearly transferred onto the message.

Proof. For sparse attention we have $a_{t}=\delta_{tt^{*}}$ for some $t^{*}$. For perturbations of $q$, the RHS of Eq. 39 becomes

$$
\begin{split}
\mathop{\mathbb{E}}_{a_{t}}\Big[m_{t}\tilde{k}_{t}^{T}\Big]\epsilon^{q} &= \sum_{t}a_{t}m_{t}\tilde{k}_{t}^{T}\epsilon^{q}\\
&= \sum_{t}\delta_{tt^{*}}m_{t}\tilde{k}_{t}^{T}\epsilon^{q}\\
&= m_{t^{*}}\tilde{k}_{t^{*}}^{T}\epsilon^{q}\\
&= 0
\end{split}
\tag{51}
$$

where the final step is because $\tilde{k}_{t^{*}}=k_{t^{*}}-\mathbb{E}_{a_{t}}[k_{t}]=k_{t^{*}}-\sum_{t}\delta_{tt^{*}}k_{t}=k_{t^{*}}-k_{t^{*}}=0$. For perturbations of $k_{t}$, the RHS of Eq. 40 evaluates to $0$ because

$$
\begin{split}
\mathop{\mathbb{E}}_{a_{t}}\Big[\tilde{m}_{t}{\epsilon^{k}_{t}}^{T}\Big]q &= \sum_{t}a_{t}\tilde{m}_{t}q^{T}\epsilon^{k}_{t}\\
&= \sum_{t}\delta_{tt^{*}}\tilde{m}_{t}q^{T}\epsilon^{k}_{t}\\
&= \tilde{m}_{t^{*}}q^{T}\epsilon^{k}_{t^{*}}\\
&= 0
\end{split}
\tag{52}
$$

where the final step is because $\tilde{m}_{t^{*}}=m_{t^{*}}-\sum_{t}\delta_{tt^{*}}m_{t}=m_{t^{*}}-m_{t^{*}}=0$. For perturbations of $m_{t}$, the RHS of Eq. 41 evaluates to

$$
\mathop{\mathbb{E}}_{a_{t}}\Big[\epsilon^{m}_{t}\Big]~=~\sum_{t}a_{t}\epsilon^{m}_{t}~=~\sum_{t}\delta_{tt^{*}}\epsilon^{m}_{t}~=~\epsilon^{m}_{t^{*}}
\tag{53}
$$
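The behaviour stated in Theorem 5 can be checked numerically. The following is a minimal sketch rather than the authors' code: it assumes the message is $\Delta x=\sum_t a_t m_t$ with $a=\mathrm{softmax}(w)$, $w_t=q^Tk_t$ and unit temperature, and every vector below is a hypothetical toy value chosen so that attention is almost perfectly sparse on token $t^*=0$.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    top = max(logits)
    exps = [math.exp(w - top) for w in logits]
    total = sum(exps)
    return [e / total for e in exps]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def message(q, keys, values):
    """Attention message: sum_t a_t m_t with a = softmax(q . k_t)."""
    a = softmax([dot(q, k) for k in keys])
    return [sum(a_t * m[i] for a_t, m in zip(a, values))
            for i in range(len(values[0]))]

def max_abs_diff(u, v):
    return max(abs(a - b) for a, b in zip(u, v))

# Hypothetical toy head: the query aligns strongly with key 0,
# so the attention weights are close to delta_{t, t*} with t* = 0.
q = [20.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
values = [[1.0, 2.0], [3.0, -1.0], [0.5, 0.5]]
t_star = 0
base = message(q, keys, values)

# Small additive interference on the query, then on every key.
eps = 1e-3
q_pert = [q[0] + eps, q[1] - eps]
keys_pert = [[k[0] - eps, k[1] + eps] for k in keys]
shift_q = max_abs_diff(message(q_pert, keys, values), base)
shift_k = max_abs_diff(message(q, keys_pert, values), base)

# Interference on the selected value only.
eps_m = [0.01, -0.02]
values_pert = [list(m) for m in values]
values_pert[t_star] = [v + e for v, e in zip(values[t_star], eps_m)]
delta_m = [a - b for a, b in zip(message(q, keys, values_pert), base)]

print("shift from query noise:", shift_q)
print("shift from key noise:  ", shift_k)
print("shift from value noise:", delta_m, "vs eps_m:", eps_m)
```

In this toy setting the query and key shifts are negligible compared with the perturbation size, while the value shift reproduces `eps_m` almost exactly, which is the pattern Eq. 50 describes. The residual differences come from the attention being only approximately sparse.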
Theorem 6. For isotropic attention:

$$
\epsilon^{\Delta x(q)}\xrightarrow[\epsilon^{q}\rightarrow 0]{\mathrm{perturb~q}}\langle m_{t}\tilde{k}_{t}^{T}\rangle_{t}\,\epsilon^{q}
\qquad
\epsilon^{\Delta x(k)}\xrightarrow[\epsilon^{k}_{t}\rightarrow 0]{\mathrm{perturb~k}}\langle\tilde{m}_{t}{\epsilon^{k}_{t}}^{T}\rangle_{t}\,q
\qquad
\epsilon^{\Delta x(m)}\xrightarrow[\epsilon^{m}_{t}\rightarrow 0]{\mathrm{perturb~m}}\langle\epsilon^{m}_{t}\rangle_{t}
\tag{54}
$$

N.B. isotropy requires $k_{t}=\mathrm{const}$ or $q=0$. Lemma 1: the update is stable to noisy $q$ when $k_{t}=\mathrm{const}$, or when $m_{t}\perp k_{t}$ (c.f. keys and messages from independent subspaces). Lemma 2: the update is stable to noisy $k_{t}$ when $q=0$, or when $m_{t}\perp\epsilon_{t}^{k}$. Lemma 3: the update is stable to noisy $m_{t}$ when $\langle\epsilon^{m}_{t}\rangle_{t}=0$. Other cases propagate linearly.

Proof. For isotropic attention we have $a_{t}=\frac{1}{T}$. For perturbations of $q$, the RHS of Eq. 39 is

$$
\begin{split}
\mathop{\mathbb{E}}_{a_{t}}\Big[m_{t}\tilde{k}_{t}^{T}\Big]\epsilon^{q} &= \sum_{t}a_{t}m_{t}\tilde{k}_{t}^{T}\epsilon^{q}\\
&= \frac{1}{T}\sum_{t=1}^{T}m_{t}\tilde{k}_{t}^{T}\epsilon^{q}\\
&= \langle m_{t}\tilde{k}_{t}^{T}\rangle_{t}\,\epsilon^{q}
\end{split}
\tag{55}
$$

For lemma 1, we note that $k_{t}=\mathrm{const}$ implies $\tilde{k}_{t}=0$, and if $m_{t}\perp k_{t}$ then $\langle m_{t}\tilde{k}_{t}^{T}\rangle_{t}=\langle m_{t}k_{t}\rangle_{t}-\langle m_{t}\rangle_{t}\langle k_{t}\rangle_{t}=\mathrm{Cov}(m_{t},k_{t})=0$.


For perturbations of $k_{t}$, the RHS of Eq. 40 is

$$
\begin{split}
\mathop{\mathbb{E}}_{a_{t}}\Big[\tilde{m}_{t}{\epsilon^{k}_{t}}^{T}\Big]q &= \frac{1}{T}\sum_{t=1}^{T}\tilde{m}_{t}{\epsilon^{k}_{t}}^{T}q\\
&= \langle\tilde{m}_{t}{\epsilon^{k}_{t}}^{T}\rangle_{t}\,q
\end{split}
\tag{56}
$$

For lemma 2, this expression evaluates to $0$ if $q=0$, and if $m_{t}\perp\epsilon_{t}^{k}$ then $\langle\tilde{m}_{t}{\epsilon^{k}_{t}}^{T}\rangle_{t}=\langle m_{t}{\epsilon^{k}_{t}}^{T}\rangle_{t}-\langle m_{t}\rangle_{t}\langle{\epsilon^{k}_{t}}^{T}\rangle_{t}=\mathrm{Cov}(m_{t},{\epsilon^{k}_{t}}^{T})=0$.

For perturbations of $m_{t}$, the RHS of Eq. 41 evaluates to

$$
\mathop{\mathbb{E}}_{a_{t}}\Big[\epsilon^{m}_{t}\Big]~=~\frac{1}{T}\sum_{t=1}^{T}\epsilon^{m}_{t}~=~\langle\epsilon^{m}_{t}\rangle_{t}
\tag{57}
$$
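The three limits of Theorem 6 can be compared against a direct evaluation of the attention message. This is a hedged sketch under the same assumptions as before (softmax attention over $w_t=q^Tk_t$ at unit temperature, which the text does not fix here), using hypothetical toy vectors and $q=0$ to obtain $a_t=\frac{1}{T}$.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    top = max(logits)
    exps = [math.exp(w - top) for w in logits]
    total = sum(exps)
    return [e / total for e in exps]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def message(q, keys, values):
    """Attention message: sum_t a_t m_t with a = softmax(q . k_t)."""
    a = softmax([dot(q, k) for k in keys])
    return [sum(a_t * m[i] for a_t, m in zip(a, values))
            for i in range(len(values[0]))]

def max_abs_diff(u, v):
    return max(abs(a - b) for a, b in zip(u, v))

T = 4
q = [0.0, 0.0]  # q = 0 gives uniform attention a_t = 1/T
keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.5], [0.5, -1.0]]
values = [[1.0, 0.0], [0.0, 2.0], [-1.0, 1.0], [2.0, 2.0]]
base = message(q, keys, values)

# Query interference: compare the actual shift with <m_t k~_t^T>_t eps_q.
eps_q = [1e-4, -2e-4]
q_pert = [a + b for a, b in zip(q, eps_q)]
actual_q = [a - b for a, b in zip(message(q_pert, keys, values), base)]
k_mean = [sum(k[i] for k in keys) / T for i in range(2)]
k_tilde = [[k[i] - k_mean[i] for i in range(2)] for k in keys]
predicted_q = [sum(m[i] * dot(kt, eps_q) for m, kt in zip(values, k_tilde)) / T
               for i in range(2)]
err_q = max_abs_diff(actual_q, predicted_q)

# Lemma 1: constant keys keep the attention uniform whatever the query.
keys_const = [[0.3, -0.7] for _ in range(T)]
shift_const = max_abs_diff(message(q_pert, keys_const, values),
                           message(q, keys_const, values))

# Value interference: the shift is the plain average of the eps^m_t.
eps_m = [[0.01, 0.0], [-0.01, 0.02], [0.03, 0.0], [0.01, -0.02]]
values_pert = [[v + e for v, e in zip(m, em)] for m, em in zip(values, eps_m)]
actual_m = [a - b for a, b in zip(message(q, keys, values_pert), base)]
mean_eps_m = [sum(em[i] for em in eps_m) / T for i in range(2)]
err_m = max_abs_diff(actual_m, mean_eps_m)

print("query shift:", actual_q, "first-order prediction:", predicted_q)
print("shift with constant keys:", shift_const)
print("value shift error:", err_m)
```

The query prediction agrees with the actual shift up to terms of second order in `eps_q`, whereas the value result holds exactly (up to floating-point rounding) because the message is linear in $m_t$.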
Theorem 7. Sensitivity of sparse attention to multiplicative perturbations $\epsilon^{q}=\kappa^{q}q$ and $\epsilon^{k}=\kappa^{k}_{t}k_{t}$ with $\kappa^{q},\kappa^{k}_{t}\ll 1$. Circuit collapse occurs when $\exists~t\neq t^{*}$ for which:

$$
\frac{w_{t^{*}}}{w_{t}}~\begin{cases}~<~\lambda_{w}&\mathrm{if}~w_{t}\left(1+\kappa^{q}+\kappa^{k}_{t^{*}}\right)>0\\ ~>~\lambda_{w}&\mathrm{otherwise}\end{cases}
\qquad\qquad
\lambda_{w}~\triangleq~\frac{1+\kappa^{q}+\kappa^{k}_{t}}{1+\kappa^{q}+\kappa^{k}_{t^{*}}}
\tag{58}
$$

where temperature cancels in the fraction. Attention is fully stable above the critical transition point $\lambda_{w}$ (c.f. $w_{t}\left(1+\kappa^{q}+\kappa^{k}_{t^{*}}\right)>0$). We see that query perturbations alone are insufficient, as they result in $\lambda_{w}=1$. Lemma: consider the special case when all keys have similar length $k_{t}\approx\mathrm{const}$, the attended token has $\theta_{t^{*}}\approx 0$, the keys are far-from-orthogonal s.t. $\theta_{t}\ll 1$, and $\kappa^{q}\approx 0$. Using $w_{t}\triangleq|q||k_{t}|\cos\theta_{t}$, circuit collapse occurs when $\exists~t\neq t^{*}$ for which:

$$
\frac{1}{2}\theta_{t}^{2}~\lesssim~\kappa^{k}_{t}-\kappa^{k}_{t^{*}}
\qquad\qquad
\mathrm{if}~w_{t}\left(1+\kappa^{k}_{t^{*}}\right)>0\text{, otherwise reverse}
\tag{59}
$$

i.e. stability requires either well-separated keys s.t. $\theta_{t}\gg 0$, or small perturbations $\kappa^{k}_{t}-\kappa^{k}_{t^{*}}\ll 1$.
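The collapse criterion of Eq. 58 can be illustrated with a few scalar logits. The sketch below is illustrative only: the function names and the toy logits are hypothetical, sparse attention is modelled as selecting the largest logit, and $w_t\neq 0$ is assumed so that the ratio is defined. With $\kappa^{q}=0$ the first-order threshold $\lambda_w$ coincides with the exact rescaling of the logits, which makes the comparison clean.

```python
def collapse_threshold(kappa_q, kappa_k_t, kappa_k_star):
    """lambda_w of Eq. 58, first order in the kappas."""
    return (1.0 + kappa_q + kappa_k_t) / (1.0 + kappa_q + kappa_k_star)

def predicts_collapse(w_star, w_t, kappa_q, kappa_k_t, kappa_k_star):
    """Apply the case distinction of Eq. 58 to one competing token t (w_t != 0)."""
    lam = collapse_threshold(kappa_q, kappa_k_t, kappa_k_star)
    ratio = w_star / w_t
    if w_t * (1.0 + kappa_q + kappa_k_star) > 0:
        return ratio < lam
    return ratio > lam

def attended_index(logits):
    """Token selected by sparse attention: the largest logit."""
    return max(range(len(logits)), key=logits.__getitem__)

def perturb(logits, kappa_q, kappa_k):
    """Exact logits after q -> (1 + kappa_q) q and k_t -> (1 + kappa_k[t]) k_t."""
    return [(1.0 + kappa_q) * (1.0 + kk) * w for w, kk in zip(logits, kappa_k)]

# Hypothetical logits w_t = q^T k_t; token 0 is attended before perturbation.
w = [10.0, 9.5, 4.0]
t_star = attended_index(w)

# Key perturbations of about 4% in opposite directions on tokens 0 and 1.
kappa_q = 0.0
kappa_k = [-0.04, 0.04, 0.0]
attended_after = attended_index(perturb(w, kappa_q, kappa_k))
predicted = [predicts_collapse(w[t_star], w[t], kappa_q, kappa_k[t], kappa_k[t_star])
             for t in range(len(w)) if t != t_star]

# A query-only perturbation rescales every logit equally, so lambda_w = 1.
query_only_after = attended_index(perturb(w, 0.1, [0.0, 0.0, 0.0]))
query_only_predicted = [predicts_collapse(w[t_star], w[t], 0.1, 0.0, 0.0)
                        for t in range(len(w)) if t != t_star]

print("attended before:", t_star, "after key noise:", attended_after)
print("collapse predicted for t = 1, 2:", predicted)
print("attended after query-only noise:", query_only_after)
```

Here the close competitor (token 1) satisfies the criterion and takes over the attention, the distant competitor (token 2) does not, and a pure query perturbation leaves the attended token unchanged.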

Proof.   Apply $q\rightarrow q+\epsilon^{q}$ and $k_{t}\rightarrow k_{t}+\epsilon_{t}^{k}$ to $w_{t}=q^{T}k_{t}$, then we have $w_{t}\rightarrow w_{t}+\epsilon^{w}_{t}$ such that $\epsilon^{w}_{t}=q^{T}\epsilon_{t}^{k}+{\epsilon^{q}}^{T}k_{t}+{\epsilon^{q}}^{T}\epsilon_{t}^{k}$.
For multiplicative perturbations we have $\epsilon^{q}=\kappa^{q}q$ and $\epsilon^{k}_{t}=\kappa^{k}_{t}k_{t}$, and so $\epsilon^{w}_{t}=\kappa^{k}_{t}q^{T}k_{t}+\kappa^{q}q^{T}k_{t}+\kappa^{k}_{t}\kappa^{q}q^{T}k_{t}$.
Each term recovers a factor of $w_{t}=q^{T}k_{t}$, which we factor out to give $\epsilon^{w}_{t}=\left(\kappa^{q}+\kappa^{k}_{t}+\kappa^{k}_{t}\kappa^{q}\right)w_{t}$. The final term is subleading in the limit of small perturbations, and so

$$\epsilon^{w}_{t}~\xrightarrow[~\kappa^{q},\kappa^{k}_{t}\rightarrow 0~]{}~\left(\kappa^{q}+\kappa^{k}_{t}\right)w_{t}~+~\mathcal{O}\left(\kappa^{q}\kappa^{k}_{t}\right) \tag{60}$$
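
As a quick numerical sanity check of the factorisation leading to Eq. 60, the following NumPy sketch perturbs a query and a key multiplicatively and compares the exact change in $w_{t}$ with the factored and first-order expressions. The dimension, random seed, and perturbation sizes are arbitrary illustrative choices and do not come from the experiments in this paper.

```python
import numpy as np

# Illustrative values only: dimension, seed and kappas are assumptions.
rng = np.random.default_rng(0)
d = 8
q = rng.normal(size=d)      # query vector
k_t = rng.normal(size=d)    # key vector for token t
kappa_q, kappa_k = 0.03, -0.02

# Unperturbed and perturbed attention scores w_t = q^T k_t.
w_t = q @ k_t
w_perturbed = ((1 + kappa_q) * q) @ ((1 + kappa_k) * k_t)

# Exact change, the factored form, and the first-order limit (Eq. 60).
eps_exact = w_perturbed - w_t
eps_factored = (kappa_q + kappa_k + kappa_q * kappa_k) * w_t
eps_first_order = (kappa_q + kappa_k) * w_t

print("exact       :", eps_exact)
print("factored    :", eps_factored)
print("first order :", eps_first_order)
print("residual    :", eps_exact - eps_first_order)
```

The residual between the exact and first-order values equals the cross term $\kappa^{q}\kappa^{k}_{t}w_{t}$, which is the subleading contribution dropped in the limit.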

Circuit collapse occurs when $w_{t^{*}}-w_{t}<\epsilon^{w}_{t}-\epsilon^{w}_{t^{*}}$ for some $t$. Substituting our limit for $\epsilon^{w}_{t}$ gives

$$w_{t^{*}}-w_{t}~<~\left(\kappa^{q}+\kappa^{k}_{t}\right)w_{t}-\left(\kappa^{q}+\kappa^{k}_{t^{*}}\right)w_{t^{*}} \tag{61}$$

and collecting terms gives

$$\left(1+\kappa^{q}+\kappa^{k}_{t^{*}}\right)w_{t^{*}}~<~\left(1+\kappa^{q}+\kappa^{k}_{t}\right)w_{t} \tag{62}$$

We then divide each side by $w_{t}(1+\kappa^{q}+\kappa^{k}_{t^{*}})$, taking care to reverse the sign of the inequality when this factor is negative, to give

$$\frac{w_{t^{*}}}{w_{t}}~\begin{cases}~<~\lambda_{w}&\mathrm{if}~w_{t}\left(1+\kappa^{q}+\kappa^{k}_{t^{*}}\right)>0\\ ~>~\lambda_{w}&\mathrm{otherwise}\end{cases}\qquad\qquad\lambda_{w}~\triangleq~\frac{1+\kappa^{q}+\kappa^{k}_{t}}{1+\kappa^{q}+\kappa^{k}_{t^{*}}} \tag{63}$$

which is the first expression in the theorem. We note that any temperature parameter cancels in the fraction, which means that the attention head cannot become more stable by reducing its temperature to become more sparse. $\lambda_{w}$ has the limits

$$\lambda_{w}~\xrightarrow[\kappa^{q}\rightarrow 0]{~~\mathrm{keys~only}~~}~\frac{1+\kappa^{k}_{t}}{1+\kappa^{k}_{t^{*}}}\qquad\qquad\lambda_{w}~\xrightarrow[\kappa^{k}_{t},\kappa^{k}_{t^{*}}\rightarrow 0]{~~\mathrm{query~only}~~}~\frac{1+\kappa^{q}}{1+\kappa^{q}}=1 \tag{64}$$

meaning that query perturbations alone are insufficient to cause circuit collapse, contributing only when they co-occur with perturbations on the keys. Write $w_{t}=|q||k_{t}|\cos\theta_{t}$ with $\theta_{t}=q\wedge k_{t}$, and the approximation of identical key norms $|k_{t^{*}}|=|k_{t}|\equiv|k|$ turns this into $w_{t}=|q||k|\cos\theta_{t}$. Then

$$\frac{w_{t^{*}}}{w_{t}}~=~\frac{|q||k|\cos\theta_{t^{*}}}{|q||k|\cos\theta_{t}}~=~\frac{\cos\theta_{t^{*}}}{\cos\theta_{t}} \tag{65}$$

Then $\theta_{t^{*}}=0$ means that $\cos\theta_{t^{*}}=\cos 0=1$, and so $\frac{\cos\theta_{t^{*}}}{\cos\theta_{t}}=\frac{1}{\cos\theta_{t}}=\sec\theta_{t}$. We perform a Taylor expansion in $\theta_{t}$ to obtain

$$\frac{w_{t^{*}}}{w_{t}}~\approx~\sec\theta_{t}~\approx~1~+~\frac{1}{2}\theta_{t}^{2}~+~\mathcal{O}\left(\theta_{t}^{4}\right) \tag{66}$$

which is valid when $\theta_{t}\ll 1$. This is true for any $t\neq t^{*}$ for which $k_{t}$ is far from orthogonal to $k_{t^{*}}$. Substituting this into our circuit collapse condition, we have

$$1~+~\frac{1}{2}\theta_{t}^{2}~<~\frac{1+\kappa^{k}_{t}}{1+\kappa^{k}_{t^{*}}}\qquad\qquad\mathrm{if}~w_{t}\left(1+\kappa^{k}_{t^{*}}\right)>0 \tag{67}$$

where we consider the case of $\kappa^{q}\approx 0$ for readability. Re-arranging gives

$$\frac{1}{2}\theta_{t}^{2}~\lesssim~\kappa^{k}_{t}-\kappa^{k}_{t^{*}}\qquad\qquad\text{Circuit collapse when }k_{t}\text{ similar} \tag{68}$$

if $w_{t}(1+\kappa^{k}_{t^{*}})>0$, and we reverse the inequality otherwise. We have approximated the denominator on the RHS as $1+\kappa^{k}_{t^{*}}\approx 1$ for $\kappa^{k}_{t^{*}}\rightarrow 0$.

When $\theta_{t}\ll 1$, the LHS of Eq. 68 is small. This means that the attention head can tolerate only very small perturbations $\{\kappa^{k}_{t},\kappa^{k}_{t^{*}}\}$. Therefore semantic subspaces must either have a highly orthogonal substructure s.t. $\theta_{t}\gtrsim 1~\forall~t\neq t^{*}$, or be orthogonal s.t. $\kappa_{t}\ll 1~\forall~t$.
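
The collapse condition can be illustrated with a minimal two-dimensional sketch. It assumes unit-norm keys, a query aligned with the attended key ($\theta_{t^{*}}=0$), and a single competing key at angle $\theta_{t}$; the function names `collapses` and `lambda_w` and all numerical values are hypothetical choices made for illustration. The sketch evaluates the exactly perturbed scores rather than the first-order limit, so with $\kappa^{q}=0$ its threshold coincides with $\lambda_{w}$ in Eq. 63.

```python
import numpy as np

def lambda_w(kappa_t, kappa_star, kappa_q=0.0):
    """Threshold from Eq. 63 (first-order form)."""
    return (1 + kappa_q + kappa_t) / (1 + kappa_q + kappa_star)

def collapses(theta, kappa_t, kappa_star, kappa_q=0.0, temperature=1.0):
    """True if the perturbed score of key t overtakes that of key t*.

    Hypothetical setup: unit-norm q and keys in 2D, with q parallel to k_{t*}
    and k_t rotated by the angle theta away from q.
    """
    q = np.array([1.0, 0.0])
    k_star = np.array([1.0, 0.0])
    k_t = np.array([np.cos(theta), np.sin(theta)])
    q_pert = (1 + kappa_q) * q
    w_star = q_pert @ ((1 + kappa_star) * k_star) / temperature
    w_t = q_pert @ ((1 + kappa_t) * k_t) / temperature
    return bool(w_t > w_star)

theta = 0.1  # similar keys: 0.5 * theta**2 = 0.005
print(collapses(theta, kappa_t=0.01, kappa_star=0.0))    # perturbation above 0.005
print(collapses(theta, kappa_t=0.002, kappa_star=0.0))   # perturbation below 0.005
print(collapses(1.0, kappa_t=0.1, kappa_star=0.0))       # well-separated keys
print(collapses(theta, 0.0, 0.0, kappa_q=0.1))           # query-only perturbation
print(collapses(theta, 0.01, 0.0, temperature=0.1))      # lower temperature
```

In this toy setting the outcome depends only on whether $\sec\theta_{t}<\lambda_{w}$: a query-only rescaling or a change of temperature multiplies both scores by the same positive factor and so cannot change which key wins.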

{mdframed}

[backgroundcolor=green!5] Theorem.  14.  Sensitivity of isotropic attention to multiplicative perturbations. Say $\epsilon^{k}_{t}=\kappa^{k}_{t}k_{t}$ with $\kappa^{k}_{t}\ll 1$ where $\{\kappa_{t}\}$ have comparable amplitudes. Then

$$\epsilon^{\Delta x(k)}~\approx~\begin{cases}0~&~\text{if $\kappa_{t}$ independent of ${\tilde{m}}_{t}$, by symmetry}\\ 0~&~\text{if $\kappa_{t}\equiv\kappa$ for constant $\kappa$}\\ 0~&~\text{if $q=0$}\\ w\langle{\tilde{m}}_{t}\kappa^{k}_{t}\rangle_{t}~&~\text{otherwise}\end{cases} \tag{69}$$

Proof.  We begin with the following result from Theorem 6:

$$\epsilon^{\Delta x(k)}~\xrightarrow[\epsilon^{k}_{t}\rightarrow 0]{~~\mathrm{perturb~k}~~}~\langle{\tilde{m}}_{t}{\epsilon^{k}_{t}}^{T}\rangle_{t}~q \tag{70}$$

Substituting $\epsilon^{k}_{t}=\kappa^{k}_{t}k_{t}$ and taking $q$ inside the brackets gives

$$\langle{\tilde{m}}_{t}{\epsilon^{k}_{t}}^{T}\rangle_{t}~q~=~\langle{\tilde{m}}_{t}\kappa_{t}{k_{t}}^{T}\rangle_{t}~q~=~\langle{\tilde{m}}_{t}\kappa_{t}w_{t}\rangle_{t} \tag{71}$$

We then notice that isotropic attention requires that $w_{t}$ is a constant, which we call $w$. Then

$$\epsilon^{\Delta x(k)}~\approx~w\langle{\tilde{m}}_{t}\kappa_{t}\rangle_{t} \tag{72}$$

is our general result. We then note three special cases, each resulting in $\epsilon^{\Delta x(k)}=0$:

  1. If $\kappa_{t}\perp{\tilde{m}}_{t}$ then $\langle{\tilde{m}}_{t}\kappa_{t}\rangle_{t}=\langle m_{t}\kappa_{t}\rangle_{t}-\langle m_{t}\rangle_{t}\langle\kappa_{t}\rangle_{t}=\mathrm{Cov}(m_{t},\kappa_{t})=0$. This is the case when the interference $\kappa_{t}^{k}$ on the keys is not dominated by the same semantic subspace as the message $m_{t}$.

  2. If all keys are perturbed by the same factor $\kappa_{t}\equiv\kappa$, then $\langle{\tilde{m}}_{t}\kappa_{t}\rangle_{t}=\kappa\langle{\tilde{m}}_{t}\rangle_{t}=0$ because $\langle{\tilde{m}}_{t}\rangle_{t}=0$.

  3. Isotropic attention can be achieved by either $q=0$ or $k_{t}=\mathrm{const}$. If the case is $q=0$ then this implies $w=0$ also.
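
The general result of Eq. 72 and the constant-$\kappa$ special case can be checked numerically with the following NumPy sketch. It rests on assumptions about definitions given earlier in the paper: that the head output is $\sum_{t}a_{t}m_{t}$ with $a=\mathrm{softmax}(w)$ and no extra scaling, that $\langle\cdot\rangle_{t}$ is the attention-weighted mean (uniform under isotropic attention), and that ${\tilde{m}}_{t}=m_{t}-\langle m_{t}\rangle_{t}$. The helper `attention_update` and all sizes, seeds, and amplitudes are hypothetical.

```python
import numpy as np

def softmax(x):
    z = np.exp(x - np.max(x))
    return z / z.sum()

def attention_update(q, K, M):
    """Assumed head output: sum_t a_t m_t with a = softmax(K q)."""
    return softmax(K @ q) @ M

rng = np.random.default_rng(0)
T, d = 6, 4                      # illustrative sequence length and width
q = rng.normal(size=d)
w = 0.7                          # common score w_t = w for every key

# Build keys with q^T k_t = w for all t (isotropic attention with q != 0):
# remove each random key's component along q, then add a fixed one.
R = rng.normal(size=(T, d))
R -= np.outer(R @ q, q) / (q @ q)
K = R + w * q / (q @ q)

M = rng.normal(size=(T, d))             # messages m_t
kappa = 1e-3 * rng.normal(size=T)       # small multiplicative key perturbations

# Exact shift in the output versus the prediction w <m~_t kappa_t>_t.
exact = attention_update(q, (1 + kappa)[:, None] * K, M) - attention_update(q, K, M)
M_tilde = M - M.mean(axis=0)
predicted = w * (kappa[:, None] * M_tilde).mean(axis=0)

# Special case 2: the same factor on every key leaves the output unchanged.
uniform_shift = attention_update(q, 1.05 * K, M) - attention_update(q, K, M)

print("exact     :", exact)
print("predicted :", predicted)
print("uniform   :", uniform_shift)
```

Under these assumptions the exact and predicted shifts agree up to terms of second order in $\kappa_{t}$, and the uniform rescaling produces no shift beyond floating-point error.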