Transformer Normalisation Layers and the Independence of Semantic Subspaces
Stephen Menary
University of Manchester, UK
Samuel Kaski
University of Manchester, UK; Aalto University, Finland
André Freitas
Department of Computer Science, University of Manchester, UK; Idiap Research Institute, Switzerland; National Biomarker Centre, CRUK-MI, University of Manchester, UK
Abstract
Recent works have shown that transformers can solve contextual reasoning tasks by internally executing computational graphs called circuits. Circuits often use attention to logically match information from subspaces of the representation, e.g. using position-in-sequence to identify the previous token. In this work, we consider a semantic subspace to be any independent subspace of the latent representation that can fully determine an attention distribution. We show that Pre-Norm, the placement of normalisation layer used by state-of-the-art transformers, violates this ability unless the model learns a strict representation structure of orthogonal spheres. This is because it causes linear subspaces to interfere through their common normalisation factor. Theoretically, we analyse circuit stability by modelling this interference as random noise on the $L_2$-norms of the query/key/value vectors, predicting a phenomenon of circuit collapse when sparse-attention shifts to a different token. Empirically, we investigate the sensitivity of real-world models trained for mathematical addition, observing a 1% rate of circuit collapse when the norms are artificially perturbed by 10%. We contrast Pre-Norm with QKV-Norm, which places normalisation after the attention head’s linear operators. Theoretically this relaxes the representational constraints. Empirically we observe comparable in-distribution but worse out-of-distribution performance.
1 Introduction
Transformer-based models [1] are commonplace in machine learning, providing state-of-the-art contextual reasoning in domains ranging from natural language [2, 3] to protein-folding [4, 5, 6] and theoretical physics [7]. Recent interpretability work investigates the internal mechanisms that lead to specific model behaviours [8, 9, 10, 11, 12, 13, 14, 15]. This is important for predicting behaviour in new environments, enables practitioners to match the inductive bias of a model with the structure of its task, and informs the design of architectures that promote desirable behaviour.
Two such works discovered complete circuits [16] in trained transformers [8, 9, 10]. These are computational graphs that dominate the model prediction when activated in a specialised context. They perform a type of algorithmic reasoning by internally executing a sequence of logical operations, using attention to pass information between memory buffers that begin as token embeddings and become increasingly abstract. Furthermore, a number of attention heads have been identified as performing logical operations (see [11] section 5). To understand transformer behaviour, an important goal is to understand how logical attention heads operate, and their generality beyond the simple cases that facilitate interpretability.
One key observation is that the attention distribution is sometimes fully-determined by an independent subspace of the representation - for example, an attention layer can identify the previous token by accessing a subspace that encodes only position-in-sequence. Indeed, low-rank weight matrices can only access linear subspaces by construction. A second observation is that, like most deep architectures, transformers use normalisation layers to improve training stability. A leading choice is to place normalisation at the input to each attention layer, which we call Pre-Norm [17]. Some interpretability works ignore this layer because it has a linear-up-to-scale structure, absorbing the linear part into adjacent weights. In this work we argue that the layer is important, because Pre-Norm causes independent linear subspaces to interfere through a common normalisation factor, preventing their separation by linear attention layers.
The purpose of this work is to ask: if the use of independent subspaces is generally important, what are the expected consequences of Pre-Norm for (i) the latent representation structure, and (ii) circuit stability? To answer this, we take an abstract approach that complements direct interpretability by considering general behaviour beyond the interpretable limit. Our contributions are:
1.
Conceptual: we identify interference between independent subspaces as a potential destabiliser of circuits caused by Pre-Norm. We suggest separability of latent subspaces as a target for study, and show it is easily satisfied by the alternative QKV-Norm. This differs from Pre-Norm by placing the normalisation layer after the linear operators. It is similar to QK-Norm, for which sparse evidence currently exists [18, 19, 20].
2.
Theoretical: we formalise a semantic subspace as any independent subspace of the latent representation that can fully determine the attention distribution. We show that Pre-Norm can only achieve this when semantic subspaces are spherical and mutually orthogonal. By contrast, QKV-Norm requires only that subspaces be linearly independent, matching the No-Norm case in this sense. We study the stability of attention to subspace interference, predicting a potentially problematic phenomenon of circuit collapse when a sparse-attention distribution changes which embedding it attends to.
3.
Experimental: we measure the sensitivity of trained models to simulated interference in a numerical addition task. Constraining our predictions, we find that (i) Pre-Norm models induce a narrower distribution of embedding $L_2$-norms than QKV-Norm, (ii) we bound the spread of $L_2$-norms to with 90% coverage, and (iii) the circuit collapse phenomenon occurs at a rate of 1% when norms are perturbed by 10%.
2 The idea
Independent subspaces are observed in real-world transformer circuits
Before providing a formal definition in section 5, we explain what we mean by a semantic subspace of the latent representation. To emphasise that this is observed in real-world models, we use a known example: the induction circuit [8, 9]. This two-layer circuit emerges in next-token-prediction models and implements a simple contextual reasoning algorithm called prefix-matching.
Consider text to be a sequence of tokens (in this example, we tokenise per-word to help with visualisation), and our task is to predict the next token at every point. The induction circuit solves this by copying a previous example from the context window: e.g. if the input includes the phrase “Harry Potter” and the last observed word was “Harry”, the induction circuit will predict that “Potter” comes next. This solves the task even if the combination “Harry Potter” never occurred in the training data.
To achieve this, we initially create an embedding for each token, encoding its position and type. Attention layers then copy information between embeddings in a directed way, using two components that determine (i) which embeddings to extract information from, and (ii) what to extract. Remarkably, the model learns to implement logical gates that we will call “match&pass”, internally composing the algorithm: (1) match each position to the previous position, and pass the previous token’s type into prev-type; (2) match the current type to an earlier prev-type, and pass that earlier token’s type into pred-suffix.
Each match&pass step operates only on an independent subspace of information, which we will call a semantic subspace. In this example, there are four semantic subspaces corresponding to position, type, prev-type, and pred-suffix. We observe that the latent embeddings can contain various information, and it is instructive to think of them as memory buffers rather than tokens. The principle of composing logical operations that act on latent semantic subspaces is also observed in the more complex example of indirect-object identification in GPT2-Small [10].
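To make the match&pass decomposition concrete, here is a symbolic Python sketch. It is an idealisation, not a trained model: each buffer stores the four semantic subspaces as named fields, and the hypothetical `match_and_pass` function stands in for a sparse attention head that attends to the most recent matching buffer. Step 1 copies each token's type into the next buffer's prev-type; step 2 matches the current type against prev-type and copies the matching buffer's type into pred-suffix.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class Buffer:
    # Hypothetical memory buffer holding four named semantic subspaces.
    position: int
    type: str
    prev_type: Optional[str] = None
    pred_suffix: Optional[str] = None


def match_and_pass(buffers: List[Buffer],
                   query: Callable[[Buffer], object],
                   key: Callable[[Buffer], object],
                   value: Callable[[Buffer], object],
                   write: str) -> None:
    """Idealised sparse head: each receiver attends to the most recent
    earlier-or-current sender whose key equals its query, and copies that
    sender's value into the receiver's `write` subspace."""
    for receiver in buffers:
        q = query(receiver)
        senders = buffers[: receiver.position + 1]  # causal mask
        matches = [s for s in senders if key(s) == q]
        if matches:
            setattr(receiver, write, value(matches[-1]))


def induction(words: List[str]) -> Optional[str]:
    buffers = [Buffer(position=i, type=w) for i, w in enumerate(words)]
    # Step 1: match position against (own position - 1); pass type into prev_type.
    match_and_pass(buffers, lambda b: b.position - 1, lambda b: b.position,
                   lambda b: b.type, "prev_type")
    # Step 2: match type against prev_type; pass type into pred_suffix.
    match_and_pass(buffers, lambda b: b.type, lambda b: b.prev_type,
                   lambda b: b.type, "pred_suffix")
    return buffers[-1].pred_suffix


print(induction("Harry Potter went to see Harry".split()))  # Potter
```

Each step reads only some subspaces and writes a different one, which is why the two heads can operate without interfering in this idealised setting.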
The problem with Pre-Norm
We express the latent embeddings as $x = \sum_i x_i$, where $x_i$ encodes the value of concept $i$. This is important, because linear-attention layers extract information from $x$ using linear operators (section 4), and can only isolate $x_i$ if each subspace is linearly independent. In other words, there must always exist a linear projection operator $P_i$ such that $P_i x = x_i$.
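Such a projection requires only linear independence, not orthogonality. A minimal NumPy sketch, assuming two hypothetical random (hence independent but non-orthogonal) subspaces, builds the oblique projector from the pseudo-inverse of the stacked bases:

```python
import numpy as np


def oblique_projector(B1: np.ndarray, B2: np.ndarray) -> np.ndarray:
    """Projector onto span(B1) along span(B2).

    Requires the columns of [B1, B2] to be linearly independent; the two
    subspaces need NOT be orthogonal.
    """
    B = np.hstack([B1, B2])
    if np.linalg.matrix_rank(B) != B.shape[1]:
        raise ValueError("subspaces are not linearly independent")
    coeffs = np.linalg.pinv(B)           # maps x = B c back to the coordinates c
    return B1 @ coeffs[: B1.shape[1]]    # keep only the B1 coordinates


rng = np.random.default_rng(0)
B1 = rng.normal(size=(5, 2))   # hypothetical 2-dim concept subspace
B2 = rng.normal(size=(5, 1))   # hypothetical 1-dim concept subspace
x1 = B1 @ rng.normal(size=2)
x2 = B2 @ rng.normal(size=1)
P1 = oblique_projector(B1, B2)
print(np.allclose(P1 @ (x1 + x2), x1))  # True: x2 is removed exactly
```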
Most transformers use either RMSNorm [21] or LayerNorm [22] for their internal normalisation layers. Geometrically, RMSNorm projects a vector onto a sphere of radius $\sqrt{D}$, i.e. the unit sphere up to a constant factor, according to
$$\mathrm{RMSNorm}(x) = \frac{x}{\sqrt{|x|^2 / D}} = \sqrt{D}\,\frac{x}{|x|} \qquad (1)$$
where $D$ is the dimension of $x$. LayerNorm is similar, projecting onto the sphere defined perpendicular to the direction $\mathbf{1} = (1, \dots, 1)$. This does not affect our analysis, and we focus on RMSNorm for simplicity. Normalisation layers sometimes also include gain and/or bias parameters, applying a stretch-and-translate to the sphere. Pre-Norm [17] normalises the latent embeddings at the input to every attention layer. Consider the example $x = x_1 + x_2 + x_3$. Applying Pre-Norm, we find $x_1$ is replaced by
$$x_1 \;\to\; \frac{\sqrt{D}\, x_1}{|x_1 + x_2 + x_3|} \qquad (2)$$
Therefore it is impossible for a linear-attention layer to extract $x_1$ without interference from $x_2$ and $x_3$, unless $|x_1 + x_2 + x_3|$ is a constant. In general, we have $|x|^2 = \sum_i |x_i|^2 + \sum_{i \neq j} x_i \cdot x_j$, and semantic subspaces are entangled unless $|x|$ is constant. This is only possible if $|x_i|$ is constant, i.e. every subspace is a sphere, and $x_i \cdot x_j = 0$ for $i \neq j$, i.e. all spheres are orthogonal (to maintain independence). This has several possible implications:
1.
It is a restrictive structure that must be learned during training, with unknown difficulty. Finite steps of gradient descent may separate the model from the manifold of acceptable representations, hindering the learning of circuit components that require semantic separation, like match&pass, especially when training with large learning rates.
2.
The constraint removes a degree of freedom for every $x_i$, reducing the information capacity of the embedding space. For example, an embedding on $\mathbb{R}^3$ could have the two-subspace structure $\mathbb{S}^1 \times \mathbb{S}^0$ but not $\mathbb{R}^2 \times \mathbb{R}^1$.
3.
We hypothesise that the structure may be violated by (i) a tradeoff with other representational effects, (ii) imperfect model training, or (iii) encountering unexpected semantic combinations at inference-time when generalising out-of-distribution. These would cause semantic subspaces to interfere through their common normalisation factor, manifesting as noise on the $L_2$-norms of the {query, key, value} vectors.
4.
It is a structure that we can search for empirically.
A possible solution: QKV-Norm
A natural fix could be to apply the normalisation layer after the linear operators. In practice this means that we normalise the {query, key, value} vectors, called QKV-Norm and defined in section 4.
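The difference can be seen in a few lines of NumPy. Assuming a hypothetical 4-dimensional embedding with $x_1$ in the first two dimensions and $x_2$ in the last two, and an RMSNorm without gain, a query matrix that reads only $x_1$ gives a Pre-Norm query that changes with the length of $x_2$, while the QKV-Norm query does not:

```python
import numpy as np


def rms_norm(z: np.ndarray) -> np.ndarray:
    # RMSNorm without a gain: z / sqrt(mean(z**2)) = sqrt(dim) * z / |z|
    return z / np.sqrt(np.mean(z ** 2))


def pre_norm_query(W_Q, x):
    return W_Q @ rms_norm(x)      # normalise the whole embedding, then project


def qkv_norm_query(W_Q, x):
    return rms_norm(W_Q @ x)      # project onto the attended subspace, then normalise


# Hypothetical 4-dim embedding: x1 lives in dims 0-1, x2 in dims 2-3.
W_Q = np.array([[1.0, 0.0, 0.0, 0.0],
                [0.0, 1.0, 0.0, 0.0]])   # reads only the x1 subspace
x1 = np.array([1.0, 2.0, 0.0, 0.0])
x2_small = np.array([0.0, 0.0, 0.5, 0.0])
x2_large = np.array([0.0, 0.0, 3.0, 4.0])

for x2 in (x2_small, x2_large):
    print(pre_norm_query(W_Q, x1 + x2), qkv_norm_query(W_Q, x1 + x2))
```

Here $x_1$ and $x_2$ are orthogonal, but $|x_2|$ varies, so the constant-norm condition is violated and the Pre-Norm query picks up a multiplicative factor that depends on $x_2$.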
Paper strategy
Our work is based on three key observations: (i) semantic subspaces are observed in known circuits, (ii) they contribute to the model behaviour, and (iii) Pre-Norm requires them to follow a strict latent embedding structure or else interfere through the $L_2$-norms of the {query, key, value} vectors.
However, it is difficult to demonstrate specific examples of subspace interference. Firstly, a fully-converged model should learn to manage interference for in-distribution examples. Instead, we expect it to concern (i) training stability, (ii) model inductive bias, and (iii) out-of-distribution behaviour. Secondly, circuit explainability is difficult, only being achieved in simple cases. In general we expect circuits to become complicated, contain steps that are harder to interpret than match&pass, and exploit non-interpretable latent subspaces. Difficulty is further increased by polysemanticity [23], the ability for heads and features to change behaviour according to context.
In this work, we take an abstract approach instead. We formally define latent semantic separability, then investigate the theoretical consequences for Pre-Norm architectures if this behaviour is important generally. This allows us to make testable predictions about representation structure and model stability without needing to fully reverse-engineer a network or explain subspaces in human terms. We then place some data-driven limits on the effect size. Nonetheless, direct observation remains important, and we hope that future works can confirm or falsify the importance of the proposed representation structure and interference effect.
3 Related Works
Our work is motivated by transformer circuit discovery [8, 9, 10, 24, 25, 26] and formation [27, 28]. See [11] for a recent review of interpretability for language decoder models, with a list of known logical operations implemented by attention heads. This builds upon works in BERTology [13, 29]. We study normalisation, for which several formulations have been proposed [17, 30, 31, 32, 33]. Our QKV-Norm variant is similar to QK-Norm, which is studied by [18, 19, 20] for asymptotic performance and training stability at large learning rates. These are motivated by logit-regularisation, whereas we are motivated by representational inductive bias and stability to latent semantic interference.
We highlight other works that study transformer normalisation through its geometric interpretation as a projection onto a sphere. [34] investigates the role of normalisation in mixing the attention output with the residual stream in Post-Norm models, but does not consider Pre-Norm. [35] studies the computational abilities of Pre-LayerNorm architectures, in particular demonstrating that projection onto a sphere ensures that all keys reside on their own convex hull, preventing them from becoming “unselectable”. [36] interprets the latent embeddings of Pre-Norm models as a trajectory on a sphere. These works do not consider the interference of semantic subspaces. [37] and the contemporary work [38] study the role of LayerNorm in the related phenomenon of embedding rank collapse.
We highlight the contemporary work of [39], who also study multi-step contextual reasoning in transformers using matching operations over independent subspaces, for both Pre-Norm and Post-Norm. This builds upon [40], who study the learning of abstract symbolic reasoning in transformers, and works that manipulate the flow of information to promote algorithmic reasoning, e.g. [41].
We are not aware of previous works that study the impact of Pre-Norm’s spherical geometry on the structure of latent subspaces. However, many works consider linear subspaces, described in the following paragraph. These results are directly applicable to the No-Norm and QKV-Norm methods in this work, although QKV-Norm applies a subsequent spherical projection. [42] design subspace separability into their model by decoupling the normalisation layers for different mechanisms.
Works on vector embeddings [43, 44, 45] and the linear representation hypothesis [46, 47, 48] study the emergence of linear subspaces that encode separable concepts in embedding-unembedding models, using both interpretation and intervention techniques. Many works search for linear subspaces/directions in a transformer representation (e.g. linear probes [49, 50]) or search for faithful causal abstractions (e.g. [51]), with a survey provided in [11] sections 3-4. We also highlight works that study the use of features in linear superposition [52, 23]. This allows a model to store more features than it has dimensions, at the cost of interference in their linear projections.
The terminology of semantic subspaces is used more generally, e.g. [49, 53, 54]. We consider a definition that does not require humans to define the separable concepts, only that abstract latent features remain independent in an attention layer. We also highlight works that study subspaces of static (model input) and contextual (latent or model output) embeddings in transformers, e.g. [55, 54, 56, 57, 58, 59, 60, 61, 62] (review in [13]). These are relevant because they also decompose embeddings into a combination of abstract subspaces, capturing different semantic and syntactic structures in a natural language setting. These may be used as semantic subspaces in our work. We highlight [57] which studies interference between positional and contextual components using a decomposition similar to ours, and also experiments using a next-token addition task.
4 Formulation
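Consider the No-Norm case. Let $\mathcal{X}$ be an unordered set of message-receiving tokens, and $\mathcal{Y}$ the message senders. Let $x \in \mathbb{R}^{D}$ be the $D$-dimensional representation of an element in $\mathcal{X}$, and $y_t$ be the $t^{\text{th}}$ element in $\mathcal{Y}$, with $t \in \{1, \dots, T\}$. For self-attention we have $\mathcal{X} = \mathcal{Y}$. Let $W_Q$ and $W_K$ be the query and key weight matrices, with associated vectors $q = W_Q x$ and $k_t = W_K y_t$ on an $N$-dimensional latent space. We do not include biases in $\{q, k_t\}$ because they contribute terms that are nullified by the softmax, or are reproduced by constant directions in $\{x, y_t\}$ (Theorem 12). We define dot-product attention scores as:
$$w_t = q^\top k_t = x^\top W_{QK}\, y_t \qquad (3)$$
where $W_{QK} = W_Q^\top W_K$ is a $D \times D$ matrix with $\mathrm{rank}(W_{QK}) \le N$. This is the maximum span of the attended subspace in $y_t$. The attention weights are
$$a_t = \frac{\exp(w_t)}{\sum_{t'=1}^{T} \exp(w_{t'})} \qquad (4)$$
Let $v_t = W_V y_t$ be the value vectors, with value matrix $W_V$. We do not include biases in $v_t$ because they carry no dependence on the attended token. Each token emits the message $m_t = W_O v_t$, where $W_O$ is the output-matrix. Each attention-head updates $x \to x + \Delta x$ by adding the attention-weighted convex combination of messages, $\Delta x = \sum_t a_t m_t$ with $\sum_t a_t = 1$. We usually run $H$ attention-heads in parallel, giving the total update:
$$\Delta x = \sum_{h=1}^{H} \sum_{t=1}^{T} a_t^{(h)}\, W_O^{(h)} W_V^{(h)}\, y_t \qquad (5)$$
with unique weights for each head index $h$.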
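A minimal NumPy sketch of this No-Norm attention head, assuming unscaled dot-product scores $w_t = q^\top k_t$ (any $1/\sqrt{N}$ factor could be folded into $W_Q$) and hypothetical random weights; the function names are illustrative only:

```python
import numpy as np


def softmax(w: np.ndarray) -> np.ndarray:
    e = np.exp(w - w.max())        # shift by the max for numerical stability
    return e / e.sum()


def head_update(x, Y, W_Q, W_K, W_V, W_O):
    """Update Delta x for one receiver x (shape [D]) from senders Y (shape [T, D])."""
    q = W_Q @ x                    # query, shape [N]
    K = Y @ W_K.T                  # keys k_t = W_K y_t, shape [T, N]
    V = Y @ W_V.T                  # values v_t = W_V y_t, shape [T, N]
    a = softmax(K @ q)             # scores w_t = q . k_t -> attention weights a_t
    M = V @ W_O.T                  # messages m_t = W_O v_t, shape [T, D]
    return a @ M                   # convex combination sum_t a_t m_t


def multi_head_update(x, Y, heads):
    # Each head has its own (W_Q, W_K, W_V, W_O); their updates are summed.
    return sum(head_update(x, Y, *h) for h in heads)


rng = np.random.default_rng(0)
D, N, T = 8, 4, 5
heads = [tuple(rng.normal(size=s) for s in [(N, D), (N, D), (N, D), (D, N)])
         for _ in range(2)]
x, Y = rng.normal(size=D), rng.normal(size=(T, D))
print(multi_head_update(x, Y, heads).shape)  # (8,)
```

If every sender is identical, all messages coincide and the update reduces to that single message, which is a quick sanity check of the convex-combination structure.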
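We now introduce normalisation layers. Let $u$ be any $n$-dimensional vector, then $\mathcal{N}(u; \theta)$ is a normalisation function with parameters $\theta$. We consider two such functions:
$$\mathrm{RMSNorm}(u; g) = g \odot \frac{\sqrt{n}\, u}{|u|}, \qquad \mathrm{LayerNorm}(u; g, b) = g \odot \frac{\sqrt{n}\, P u}{|P u|} + b \qquad (6)$$
[21, 22], where $P = I - \frac{1}{n}\mathbf{1}\mathbf{1}^\top$ is a linear operator that subtracts the mean of $u$ from every component, $\mathbf{1}$ is the vector of ones, and $P u$ is the component of $u$ perpendicular to $\mathbf{1}$.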
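The Pre-Norm strategy means applying normalisation to the inputs $\{x, y_t\}$. The QKV-Norm strategy means applying normalisation to the vectors $\{q, k_t, v_t\}$. We then have three cases:

| Strategy | Query, key and value | Norm params |
|---|---|---|
| No-Norm | $q = W_Q x$, $k_t = W_K y_t$, $v_t = W_V y_t$ | - |
| Pre-Norm (baseline) | $q = W_Q \mathcal{N}(x; \theta)$, $k_t = W_K \mathcal{N}(y_t; \theta)$, $v_t = W_V \mathcal{N}(y_t; \theta)$ | $\theta$ |
| QKV-Norm (alternate) | $q = \mathcal{N}(W_Q x; \theta_Q)$, $k_t = \mathcal{N}(W_K y_t; \theta_K)$, $v_t = \mathcal{N}(W_V y_t; \theta_V)$ | $\theta_Q, \theta_K, \theta_V$ |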
We note that several of these degrees of freedom are redundant and could be combined, e.g. a normalisation gain can be absorbed into the adjacent weight matrix. We do not consider these variations (i) because they are not relevant for the results of this paper, and (ii) to standardise the number of training parameters.
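The three placements can be sketched in NumPy as follows. This assumes RMSNorm and LayerNorm without the small stabilising epsilon used in practice, omits gains and biases inside `qkv`, and uses hypothetical strategy strings; it is meant only to show where the normalisation sits relative to $W_Q$, $W_K$ and $W_V$.

```python
import numpy as np


def rms_norm(u, g=None):
    # sqrt(n) * u / |u|, optionally scaled elementwise by a gain g
    out = np.sqrt(u.size) * u / np.linalg.norm(u)
    return out if g is None else g * out


def layer_norm(u, g=None, b=None):
    u_perp = u - u.mean()          # remove the component along the ones vector
    out = np.sqrt(u.size) * u_perp / np.linalg.norm(u_perp)
    if g is not None:
        out = g * out
    if b is not None:
        out = out + b
    return out


def qkv(x, y, W_Q, W_K, W_V, strategy, norm=rms_norm):
    """Query, key and value for one receiver x and one sender y (gains omitted)."""
    if strategy == "no-norm":
        return W_Q @ x, W_K @ y, W_V @ y
    if strategy == "pre-norm":     # normalise the inputs, then project
        return W_Q @ norm(x), W_K @ norm(y), W_V @ norm(y)
    if strategy == "qkv-norm":     # project, then normalise q, k and v
        return norm(W_Q @ x), norm(W_K @ y), norm(W_V @ y)
    raise ValueError(f"unknown strategy: {strategy}")


rng = np.random.default_rng(0)
W_Q, W_K, W_V = (rng.normal(size=(4, 8)) for _ in range(3))
x, y = rng.normal(size=8), rng.normal(size=8)
q, k, v = qkv(x, y, W_Q, W_K, W_V, "qkv-norm")
print(np.linalg.norm(q))  # approximately 2.0 = sqrt(4)
```

Under QKV-Norm every query, key and value has the same length, so only the directions of the projected subspaces reach the attention scores.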
5 Theory: representation structure required for independent subspaces
Let $z$ be a $D$-dimensional latent representation of $x$ or $y_t$.
[Definition] Semantic subspace: any independent $N_i$-dimensional subspace $S_i$ for which every element may be uniquely identified by some parameters $\phi_i$, such that it is possible for the attention scores $w_t$ to be fully specified by $\phi_i$. Semantic separability: ability for parallel heads to be fully specified by different semantic subspaces.
Let $\mathcal{S}$ be the set of indivisible semantic subspaces. This can be seen as a co-ordinate system for the attendable embedding space. Semantic separability requires that each co-ordinate be independently measurable by an attention head. Let $\mathcal{S}$ contain the indivisible semantic subspaces $S_1, \dots, S_n$. Then a representation $z \in S_1 \times \dots \times S_n \times \mathcal{R}$ satisfies semantic separability, where $\times$ denotes Cartesian products and $\mathcal{R}$ is a separable space of non-attended information.
The following theorems derive the representation structure required for semantic separability:
Semantically separable representation structures [Proofs in appendix F]
Theorem 1.
No-Norm: If two heads with finite non-zero temperature attend to different semantic subspaces, the subspaces must be linearly independent. Corollary: $W_{QK}$ is a low-rank matrix with (left and right) null-spaces that span all non-attended information.
Theorem 2.
Pre-Norm: Semantic subspaces must be represented as orthogonal spheres defined using the $L_2$-norm. Corollary: if either orthogonality or constant-norm are violated, semantic subspaces interfere through a multiplicative factor on $\{q, k_t, v_t\}$.
Theorem 3.
QKV-Norm: Semantic subspaces must be linearly independent.
We note that every linear subspace $\mathbb{R}^{N_i}$ has $N_i$ continuous degrees of freedom, whilst the sphere $\mathbb{S}^{N_i-1}$ has only $N_i - 1$, the other being removed by the fixed-norm constraint. The subspace $\mathbb{S}^0$ is allowed and may be seen as a binary variable with values $\pm 1$, and the total representation can store up to $D$ such variables. For QKV-Norm, we note that the residual subspace $\mathbb{R}^{N_i}$ only contributes $N_i - 1$ continuous degrees of freedom to the attention calculation, because we apply the projection after extracting the subspace. Table 1 provides a summary.
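As a worked count, assume a hypothetical 3-dimensional embedding split into a 2-dimensional and a 1-dimensional semantic subspace, and apply the rule that $\mathbb{R}^{N_i}$ has $N_i$ continuous degrees of freedom while $\mathbb{S}^{N_i-1}$ has $N_i - 1$:

```latex
\begin{aligned}
\text{No-Norm:}\quad & \mathbb{R}^2 \times \mathbb{R}^1
  &&\Rightarrow\quad 2 + 1 = 3 \ \text{continuous d.o.f.} \\
\text{Pre-RMSNorm:}\quad & \mathbb{S}^1 \times \mathbb{S}^0
  &&\Rightarrow\quad (2-1) + (1-1) = 1 \ \text{continuous d.o.f.} + \text{one binary variable } (\pm 1)
\end{aligned}
```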
Structure of messages
We note the special case of compositional annotation, in which a layer creates a semantic subspace that is extracted by a later layer. This is used by circuits including the induction circuit [9] described in section 2. By normalising the inputs, Pre-Norm induces a spheroid message structure close to the sphere required for separability in later layers. This may facilitate compositional annotation, aiding in circuit-formation. Message structures are summarised in Table 2.
| Strategy | Representation structure | Attendable d.o.f. |
|---|---|---|
| No-Norm | Linearly independent subspaces | |
| Pre-LayerNorm | Orthogonal spheres | |
| Pre-RMSNorm | Orthogonal spheres | |
| QKV-Norm | Linearly independent subspaces | |
Table 1: Representation structure required for semantic separability; d.o.f. means degrees of freedom.
| Strategy | Structure of messages | Compositional annotation if |
|---|---|---|
| No-Norm | Linear | $m_t$ on independent subspace |
| Pre-Norm | Spheroid | $m_t$ on orthogonal sphere |
| QKV-Norm | Spheroid | $m_t$ on independent subspace |
Table 2: Summary of message structures induced by different placements of normalisation layer.
6 Theory: stability to subspace interference
We now investigate the impact of interfering subspaces. Consider the almost-separable limit, modelling interference as a random infinitesimal perturbation of the vectors $\{q, k_t, v_t\}$. Let $\delta$-symbols denote perturbations, such that $\delta_q \Delta x$, $\delta_k \Delta x$ and $\delta_m \Delta x$ are the changes of $\Delta x$ induced by $\delta q$, $\delta k_t$ and $\delta m_t$ respectively. We consider (i) the sparse limit, in which the attention is concentrated entirely on a single embedding, and (ii) the isotropic limit, in which it is distributed evenly among embeddings. We are particularly interested in the sparse case, since this highly directed flow of information is used by match&pass, although semantic separation can also be used by non-sparse heads.
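[Definition] Sparse attention: the low-temperature limit $a_t \to \delta_{t t^*}$ with $t^* = \arg\max_t w_t$, where $\delta$ is the Kronecker delta. This occurs when there is a large difference between the top two scores: $w_{t^*} - w_t \gg 1$ for all $t \neq t^*$. Isotropic attention: the high-temperature limit $a_t \to 1/T$ for $T$ attended tokens. This occurs when $w_t$ is constant over $t$, e.g. when $q = 0$ or $k_t$ is constant.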
Stability of attention updates to perturbations on q, k, v [Proofs in appendix F]
Theorem 4.
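Consider independent infinitesimal perturbations on queries $\delta q$, keys $\delta k_t$, and messages $\delta m_t$. These propagate onto $\Delta x$ as
$$\delta_q \Delta x = \sum_t a_t\, \tilde{m}_t\, k_t^\top \delta q \qquad (7)$$
$$\delta_k \Delta x = \sum_t a_t\, \tilde{m}_t\, q^\top \delta k_t \qquad (8)$$
$$\delta_m \Delta x = \sum_t a_t\, \delta m_t \qquad (9)$$
where $\tilde{m}_t = m_t - \bar{m}$ is the value of $m_t$ measured from the attention-weighted centroid $\bar{m} = \sum_t a_t m_t$.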
Theorem 5.
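For sparse attention, with attended token $t^*$:
$$\delta_q \Delta x = \delta_k \Delta x = 0, \qquad \delta_m \Delta x = \delta m_{t^*} \qquad (10)$$
i.e. the message is stable with respect to small interference in the queries and keys. Interference in the selected value is linearly transferred onto the message.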
Theorem 6.
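For isotropic attention:
$$\delta_q \Delta x = \frac{1}{T}\sum_t \tilde{m}_t\, k_t^\top \delta q, \qquad \delta_k \Delta x = \frac{1}{T}\sum_t \tilde{m}_t\, q^\top \delta k_t, \qquad \delta_m \Delta x = \frac{1}{T}\sum_t \delta m_t \qquad (11)$$
N.B. isotropy requires $q^\top k_t$ to be constant over $t$, e.g. $q = 0$ or constant $k_t$. Lemma 1: the update is stable to noisy $q$ when $k_t$ is constant, or when $\sum_t \tilde{m}_t k_t^\top = 0$ (cf. keys and messages from independent subspaces). Lemma 2: the update is stable to noisy $k_t$ when $q = 0$, or when $q^\top \delta k_t = 0$. Lemma 3: the update is stable to noisy $m_t$ when $\sum_t \delta m_t = 0$. Other cases propagate linearly.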
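The perturbation formulas can be checked numerically. Assuming scores $w_t = q^\top k_t$ with no separate temperature, a first-order expansion of $\Delta x = \sum_t a_t m_t$ gives $\delta \Delta x \approx \sum_t a_t \tilde{m}_t\,(k_t^\top \delta q + q^\top \delta k_t) + \sum_t a_t\, \delta m_t$, with $\tilde{m}_t = m_t - \sum_{t'} a_{t'} m_{t'}$. The sketch below compares this against a finite difference on random, hypothetical inputs:

```python
import numpy as np


def update(q, K, M):
    # Delta x = sum_t a_t m_t with a = softmax(K q)
    w = K @ q
    a = np.exp(w - w.max())
    a /= a.sum()
    return a @ M


def linearised_change(q, K, M, dq, dK, dM):
    """First-order change of Delta x from perturbations on q, the keys and the messages."""
    w = K @ q
    a = np.exp(w - w.max())
    a /= a.sum()
    M_tilde = M - a @ M                       # m_t measured from the centroid sum_t a_t m_t
    weighted = (a[:, None] * M_tilde).T       # columns a_t * m~_t, shape [D, T]
    from_q = weighted @ (K @ dq)              # sum_t a_t m~_t (k_t . dq)
    from_k = weighted @ (dK @ q)              # sum_t a_t m~_t (q . dk_t)
    from_m = a @ dM                           # sum_t a_t dm_t
    return from_q + from_k + from_m


rng = np.random.default_rng(1)
T, N, D, eps = 6, 4, 3, 1e-6
q, K, M = rng.normal(size=N), rng.normal(size=(T, N)), rng.normal(size=(T, D))
dq, dK, dM = (eps * rng.normal(size=s) for s in [(N,), (T, N), (T, D)])

exact = update(q + dq, K + dK, M + dM) - update(q, K, M)
approx = linearised_change(q, K, M, dq, dK, dM)
print(np.max(np.abs(exact - approx)) / np.max(np.abs(exact)))  # small: second order in eps
```

When the weights become one-hot, every $a_t \tilde{m}_t$ vanishes and only the $\delta m_{t^*}$ term survives, matching the sparse-limit behaviour.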
The stability of sparse attention arises because softmax becomes an argmax for low-temperature heads, which is only sensitive to the order of $\{w_t\}$. However, this introduces a different vulnerability: when perturbations change the order of $\{w_t\}$, the attention distribution undergoes a phase transition to select a different token. We call this circuit collapse. For example, the induction circuit collapses when the operation “attend to the previous token” attends to a different token because of interference.
[Definition] Circuit collapse: spontaneous phase transition in which a sparse attention distribution selects a different token due to noise on $\{q, k_t\}$. Let $\delta w_t$ be perturbations on $w_t$ that result from $\delta q$ and $\delta k_t$. Circuit collapse occurs when there exists a $t \neq t^*$ for which $w_t + \delta w_t > w_{t^*} + \delta w_{t^*}$.
We now study the $L_2$-norm interference that we expect to be induced by Pre-Norm when semantic separability is violated. This is characterised by perturbations that are parallel to their corresponding vector. Theorem 7 shows the conditions under which we expect circuit collapse to occur.
Stability of attention updates to scaling of q, k, v [Proofs in appendix F]
Theorem 7.
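Sensitivity of sparse attention to multiplicative perturbations $q \to (1+\epsilon_q)\, q$ and $k_t \to (1+\epsilon_t)\, k_t$ with $\epsilon_q, \epsilon_t > -1$. Circuit collapse occurs when $\exists\, t \neq t^*$ with $w_t > 0$ for which:
$$\frac{1+\epsilon_t}{1+\epsilon_{t^*}} > \frac{w_{t^*}}{w_t} \qquad (12)$$
where temperature cancels in the fraction. Below this critical transition point, attention is fully stable (cf. Theorem 5). We see that query perturbations alone are insufficient, as they result in a left-hand side of exactly 1. Lemma: consider the special case when all keys have similar length $|k_t| \approx |k|$, the attended token has the smallest angle $\theta_{t^*}$ between $q$ and $k_t$, the keys are far-from-orthogonal to $q$ s.t. $\cos\theta_t > 0$, and $|\epsilon| \ll 1$. Using $w_t \approx |q|\,|k| \cos\theta_t$, circuit collapse occurs when $\exists\, t \neq t^*$ for which:
$$\epsilon_t - \epsilon_{t^*} \gtrsim \frac{\cos\theta_{t^*}}{\cos\theta_t} - 1 \qquad (13)$$
i.e. stability requires either well-separated keys s.t. $\cos\theta_{t^*}/\cos\theta_t - 1 \gg |\epsilon|$, or small perturbations $|\epsilon_t - \epsilon_{t^*}| \ll \cos\theta_{t^*}/\cos\theta_t - 1$.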
Theorem 8.
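Sensitivity of isotropic attention to multiplicative perturbations. Say $q \to (1+\epsilon_q)\, q$, $k_t \to (1+\epsilon_t)\, k_t$ and $m_t \to (1+\epsilon^{m}_t)\, m_t$ with $|\epsilon| \ll 1$, where $\{\epsilon_q, \epsilon_t, \epsilon^{m}_t\}$ have comparable amplitudes. Then
$$\delta \Delta x \approx \frac{w}{T}\sum_t \epsilon_t\, \tilde{m}_t + \frac{1}{T}\sum_t \epsilon^{m}_t\, m_t \qquad (14)$$
where $w = q^\top k_t$ is the common score; the query perturbation drops out because $\sum_t \tilde{m}_t = 0$.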
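Because each score is bilinear in $q$ and $k_t$, a multiplicative perturbation rescales $w_t$ by $(1+\epsilon_q)(1+\epsilon_t)$, so collapse reduces to comparing $(1+\epsilon_t)\,w_t$ against $(1+\epsilon_{t^*})\,w_{t^*}$. The sketch below, assuming hypothetical positive scores, checks that a direct argmax comparison agrees with this ratio criterion and that $\epsilon_q$ never matters:

```python
import numpy as np


def collapses(w, eps_q, eps_k):
    """Direct check: does argmax_t w_t change when q -> (1+eps_q) q and k_t -> (1+eps_k[t]) k_t?"""
    w_new = (1 + eps_q) * (1 + eps_k) * w     # scores are bilinear in q and k_t
    return bool(np.argmax(w_new) != np.argmax(w))


def ratio_condition(w, eps_k):
    """Collapse criterion: some t != t* with (1+eps_t)/(1+eps_t*) > w_t*/w_t.
    Assumes all scores are positive; eps_q does not enter."""
    t_star = np.argmax(w)
    others = np.arange(len(w)) != t_star
    lhs = (1 + eps_k[others]) / (1 + eps_k[t_star])
    return bool(np.any(lhs > w[t_star] / w[others]))


rng = np.random.default_rng(2)
agree = 0
for _ in range(1000):
    w = rng.uniform(1.0, 2.0, size=8)         # hypothetical positive scores
    eps_q, eps_k = rng.uniform(-0.1, 0.1), rng.uniform(-0.1, 0.1, size=8)
    agree += collapses(w, eps_q, eps_k) == ratio_condition(w, eps_k)
print(agree)
```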
7 Experimental results
We now use experiments to empirically probe (i) the real-world embedding structure, and (ii) the sensitivity to artificial noise on the {query, key, value} $L_2$-norms. Whilst this does not directly observe real-world interference, it constrains the size of the effect.
We consider a base-10 integer-addition task with a question-answer structure, and train for next-token prediction. We use a decoder architecture, common for state-of-the-art language models, with layers, per-character tokenisation, and begin [ and end ] tokens. In the output, we mask * tokens that precede the answer. For example, the first training sequence has input [453+16+17-N846=1332 and output ***************1332]. We compare two models that use Pre-Norm and QKV-Norm respectively. Appendices A-C provide a full experimental setup and supplementary plots. In this section we make all plots using an in-distribution test set that is expected to have some overlap with the training set, bounded at .
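The sequence format can be reproduced with a short Python sketch. Two assumptions are inferred from the single example rather than stated in the text: `N` marks a negative operand (indeed $453+16+17+846 = 1332$), and a negative answer would be written the same way. The helper names are hypothetical.

```python
from typing import Tuple


def evaluate(question: str) -> int:
    """Evaluate e.g. '453+16+17-N846', reading 'N' as a unary minus (assumption)."""
    total, sign, i = 0, 1, 0
    while i < len(question):
        if question[i] in "+-":
            sign = 1 if question[i] == "+" else -1
            i += 1
            continue
        negative = question[i] == "N"
        if negative:
            i += 1
        j = i
        while j < len(question) and question[j].isdigit():
            j += 1
        value = int(question[i:j])
        total += sign * (-value if negative else value)
        i = j
    return total


def make_example(question: str) -> Tuple[str, str]:
    """Build (input, masked next-token target) strings for one question."""
    answer = evaluate(question)
    answer_str = str(answer) if answer >= 0 else "N" + str(-answer)  # assumed encoding
    sequence = "[" + question + "=" + answer_str + "]"
    inputs, targets = sequence[:-1], sequence[1:]       # shift by one token
    n_masked = len(question) + 1                         # targets before the answer
    return inputs, "*" * n_masked + targets[n_masked:]


print(make_example("453+16+17-N846"))
# ('[453+16+17-N846=1332', '***************1332]')
```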
We choose this task because it emphasises contextual reasoning in a small-scale setting, is configurable for complexity, and allows us to define meaningful out-of-distribution test sets. The Pre-Norm (QKV-Norm) model achieves an in-distribution per-token accuracy of (), dropping to () when generalising out-of-distribution to intermediate complexity, and () for increased complexity. Statistical uncertainties are below . The in-distribution performance is comparable, but QKV-Norm generalises worse in this task, implying it has learned less task-appropriate solutions. Appendix D shows additional comparisons suggesting that the Pre-Norm and QKV-Norm models behave differently, supporting the observations of [18, 19, 20].
Embedding structure
Our theory predicts that Pre-Norm attention is stable with respect to information in non-attended subspaces if all input embeddings have similar $L_2$-norms, whereas QKV-Norm imposes no norm constraint. We seek to experimentally bound the degree to which this structure is learned in practice.
We do this by plotting the spread of norms with respect to their median. A confounding effect is that the norms may differ for (i) embeddings attended to by different heads, (ii) the same head acting in different contexts, and (iii) embeddings that are never attended. We therefore measure the ratio of each norm to the median separately for each head, and weight each embedding by its assigned attention. We remove the begin-sequence token from consideration. Figure 1 shows the resulting spread for all attention layers. On the LHS, we see that 90% of the distribution is contained within an interval of when using Pre-Norm. On the RHS, we see that QKV-Norm allows a much wider spread. This is consistent with our theory, and experimentally bounds the representation effect on Pre-Norm to in this model. Supplementary Figures 16-16 show consistent results for two model variations, although we note that Pre-Norm and QKV-Norm are more comparable for the variation labelled Alternate.
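A sketch of the per-head spread measurement, assuming one particular estimator: an attention-weighted median for normalisation and attention-weighted 5th and 95th percentiles for the 90% interval. The helper names and the interpolation convention are our assumptions, and excluding the begin-sequence token is left to the caller.

```python
import numpy as np


def weighted_quantile(values, weights, q):
    """Quantile(s) of `values` under positive `weights`, by linear interpolation
    of the mid-point cumulative weight (one common convention among several)."""
    order = np.argsort(values)
    v = np.asarray(values, float)[order]
    w = np.asarray(weights, float)[order]
    cdf = (np.cumsum(w) - 0.5 * w) / w.sum()
    return np.interp(q, cdf, v)


def norm_spread(norms, attention, coverage=0.9):
    """Central interval of |y_t| / median|y_t| for one head, weighting each
    embedding by the attention it receives (hypothetical helper)."""
    norms, attention = np.asarray(norms, float), np.asarray(attention, float)
    keep = attention > 0
    norms, attention = norms[keep], attention[keep]
    ratio = norms / weighted_quantile(norms, attention, 0.5)
    lo, hi = weighted_quantile(ratio, attention, [(1 - coverage) / 2, (1 + coverage) / 2])
    return lo, hi


rng = np.random.default_rng(3)
norms = rng.normal(10.0, 0.5, size=1000)      # hypothetical norms seen by one head
attention = rng.uniform(0.0, 1.0, size=1000)  # hypothetical attention received
print(norm_spread(norms, attention))
```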
Figure 1: Spread of embedding $L_2$-norms experienced by attention heads at increasing model depth, excluding the [ token. For Pre-Norm, 90% of the spread is observed within an interval of . Supplementary Figure 12 shows the distributions used to make this plot. Supplementary Figures 16-16 replicate the analysis for two model variations.
Model stability with simulated interference
Section 6 theoretically modelled the semantic interference induced by Pre-Norm as a random perturbation on the norms of $\{q, k_t, v_t\}$. To estimate the real-world sensitivity to such an effect, we artificially introduce uncorrelated uniform noise onto these norms inside our trained Pre-Norm model. Even though Gaussian noise is expected in the large- limit, we use uniform noise to avoid outliers. Figure 2 shows the evolution of in-distribution per-token accuracy with increasing noise RMS. On the LHS, we see that performance falls by at only a 1% noise level. We also show the trend excluding the end-sequence token, which contributes a significant fraction of the metric. On the RHS, we introduce noise only to sparse heads (when ) and non-sparse heads (when ). We see that the model is stable with respect to -scale noise on sparse-attention, and this regime is dominated by the non-sparse case.
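The noise model can be sketched as follows. A uniform distribution on $[-a, a]$ has RMS $a/\sqrt{3}$, so a target RMS fixes $a = \sqrt{3}\cdot\mathrm{RMS}$. This hypothetical helper only rescales norms and leaves directions untouched; hooking it into a model's forward pass is not shown.

```python
import numpy as np


def perturb_norms(vectors: np.ndarray, rms: float, rng: np.random.Generator) -> np.ndarray:
    """Rescale each vector (last axis) by (1 + eps), eps ~ Uniform(-a, a).

    A uniform distribution on [-a, a] has RMS a / sqrt(3), so a = sqrt(3) * rms.
    Directions are unchanged; only the L2-norms are perturbed.
    """
    a = np.sqrt(3.0) * rms
    if a >= 1.0:
        raise ValueError("rms too large: scale factors could become non-positive")
    eps = rng.uniform(-a, a, size=vectors.shape[:-1] + (1,))
    return vectors * (1.0 + eps)


rng = np.random.default_rng(4)
k = rng.normal(size=(100_000, 16))             # hypothetical key vectors
k_noisy = perturb_norms(k, rms=0.01, rng=rng)  # a 1% noise level
rel = np.linalg.norm(k_noisy, axis=-1) / np.linalg.norm(k, axis=-1) - 1.0
print(np.sqrt(np.mean(rel ** 2)))              # close to 0.01
```

Uniform noise is bounded by construction, which is the stated reason for preferring it over Gaussian noise here.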
Figure 2 (right) is consistent with the stability predictions of Theorems 5-6. However, it may also be explained if non-sparse distributions are simply more important to the model, either because they are more common or because of depth-dependence. Depth-dependence arises because artificial noise is applied to all layers during the forward pass: later layers are perturbed both by the noise component and by the shifting of their inputs due to previous layers, which is expected to compound with depth. We are interested in capturing this effect; however, it may increase the importance of early layers. See Figure 27 for a visualisation of the observed attention maps.
Figure 2: Left: evolution of per-token accuracy as we increase noise on the $L_2$-norms of $\{q, k_t, v_t\}$. A drop in performance is observed when noise is applied to all layers. Right: applying noise only to sparse or only to non-sparse attention distributions, we see that non-sparse attention drives the drop at small noise, whereas the sparse case is stable. This is consistent with Theorems 5-6, but this interpretation is confounded by the relative importance of non-sparse distributions caused by frequency and depth-dependence.
Figure 3: Probability of circuit collapse vs increasing noise. This observes the effect predicted in Section 6, and measures that 1% of sparse distributions collapse at a noise level of 10%.
Circuit collapse
Figure 3 shows the probability that our artificial noise causes the circuit collapse phenomenon as defined in section 6. In this experiment, we add noise to every layer independently, which prevents the confounding effect of shifting inputs due to noise in previous layers. We observe that 1% of sparse attention distributions collapse when they experience noise at a level of 10%. This reduces to and for the two model variations shown in Appendix C.
8 Summary & Outlook
We have presented the idea that transformer Pre-Norm can cause interference between independent subspaces of the latent embeddings, a feature used by some real-world transformer circuits. Theoretically, we found this can only be avoided when using an embedding structure of orthogonal spheres. By contrast, the QKV-Norm architecture requires only linearly independent subspaces. We predict that sparse attention is stable with respect to interference, until a certain threshold of noise is reached, at which point it undergoes a phase transition called circuit collapse.
Empirically, we observe that the $L_2$-norms of attended embeddings are contained within a spread of for Pre-Norm (with 90% coverage), whilst QKV-Norm creates a wider spread. We simulate interference by introducing artificial noise onto the $L_2$-norms of $\{q, k_t, v_t\}$ in our trained Pre-Norm model, observing that 1% of sparse distributions collapse at a noise level of 10%. We observe that per-token accuracy degrades by when norms are simultaneously perturbed by noise of 1% in all layers, but is stable to %-scale noise in only sparse distributions. This may be attributed to either the predicted stability of sparse attention, or to a difference in the importance of sparse vs non-sparse heads induced by frequency and depth-dependence. More work is needed to disentangle these.
This work contributes a theoretical hypothesis of model behaviour, and empirically constrains the effect size without full model reverse-engineering. We have made predictions on representation structure, interference, and circuit collapse that practitioners may search for in their own models.
9 Limitations
We have not directly observed subspace independence or interference, and further work is required to establish their importance in real-world models. Experimentally, we simulate interference as independent and similar in amplitude across heads and layers; however, it may instead be correlated and depth-dependent. Whilst our stability experiments demonstrate that the model is more stable with respect to noise in sparse than non-sparse distributions, we have not shown whether this is due to the inherent stability of the attention distribution predicted by our theory, or to the relative importance of sparse vs non-sparse distributions to the model. We show experimental results for a small model on a targeted task (with model variations in Appendix C); further work is needed to study the behaviour of larger models and different corpora.
Acknowledgments and Disclosure of Funding
This work is supported by UKRI Turing AI World-Leading Researcher Fellowship (EP/W002973/1). This work was partially funded by the Swiss National Science Foundation (SNSF) project NeuMath (200021_204617), by the EPSRC grant EP/T026995/1, “EnnCore: End-to-End Conceptual Guarding of Neural Architectures” under Security for all in an AI enabled society, by the CRUK National Biomarker Centre, and supported by the Manchester Experimental Cancer Medicine Centre and the NIHR Manchester Biomedical Research Centre.
References
[1]
Vaswani et al.
Attention is all you need.
In Advances in Neural Information Processing Systems,
volume 30. Curran Associates, Inc., 2017.
[2]
Bubeck et al.
Sparks of artificial general intelligence: Early experiments with
GPT-4, 2023.
[3]
Touvron et al.
LLaMA: Open and efficient foundation language models, 2023.
[4]
Abramson et al.
Accurate structure prediction of biomolecular interactions with
AlphaFold 3.
Nature, 2024.
[5]
Baek et al.
Accurate prediction of protein structures and interactions using a
three-track neural network.
Science, 373(6557):871–876, 2021.
[6]
Chandra et al.
Transformer-based deep learning for predicting protein properties in
the life sciences.
eLife, 12:e82819, jan 2023.
[7]
Cai et al.
Transforming the bootstrap: Using transformers to compute scattering
amplitudes in planar N = 4 super Yang-Mills theory, 2024.
[8]
Olsson et al.
In-context learning and induction heads.
Transformer Circuits Thread, 2022.
https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html.
[9]
Elhage et al.
A mathematical framework for transformer circuits.
Transformer Circuits Thread, 2021.
https://transformer-circuits.pub/2021/framework/index.html.
[10]
Wang et al.
Interpretability in the wild: a circuit for indirect object
identification in GPT-2 small.
In The Eleventh International Conference on Learning
Representations, 2023.
[11]
Ferrando et al.
A primer on the inner workings of transformer-based language models,
2024.
[12]
Meng et al.
Locating and editing factual associations in GPT.
In Alice H. Oh, Alekh Agarwal, Danielle Belgrave, and Kyunghyun Cho,
editors, Advances in Neural Information Processing Systems, 2022.
[13]
Rogers et al.
A primer in BERTology: What we know about how BERT works.
Transactions of the Association for Computational Linguistics,
8:842–866, 12 2020.
[14]
Liu et al.
Transformers learn shortcuts to automata.
In The Eleventh International Conference on Learning
Representations, 2023.
[15]
Goldowsky-Dill et al.
Localizing model behavior with path patching, 2023.
[16]
Cammarata et al.
Thread: Circuits.
Distill, 2020.
https://distill.pub/2020/circuits.
[17]
Xiong et al.
On layer normalization in the transformer architecture.
In Proceedings of the 37th International Conference on Machine
Learning, ICML’20. JMLR.org, 2020.
[18]
Henry et al.
Query-key normalization for transformers.
CoRR, abs/2010.04245, 2020.
[19]
Wortsman et al.
Small-scale proxies for large-scale transformer training
instabilities.
In The Twelfth International Conference on Learning
Representations, 2024.
[20]
Dehghani et al.
Scaling vision transformers to 22 billion parameters.
In Proceedings of the 40th International Conference on Machine
Learning, volume 202 of Proceedings of Machine Learning Research,
pages 7480–7512. PMLR, 23–29 Jul 2023.
[21]
Biao Zhang and Rico Sennrich.
Root mean square layer normalization.
In H. Wallach et al., editors, Advances in Neural Information
Processing Systems, volume 32. Curran Associates, Inc., 2019.
[22]
Ba et al.
Layer normalization.
CoRR, abs/1607.06450, 2016.
[23]
Elhage et al.
Toy models of superposition.
Transformer Circuits Thread, 2022.
[24]
Stolfo et al.
A mechanistic interpretation of arithmetic reasoning in language
models using causal mediation analysis, 2023.
[25]
Goldowsky-Dill et al.
Localizing model behavior with path patching.
ArXiv, abs/2304.05969, 2023.
[26]
Javier Ferrando and Elena Voita.
Information flow routes: Automatically interpreting language models
at scale.
ArXiv, abs/2403.00824, 2024.
[27]
Singh et al.
The transient nature of emergent in-context learning in transformers.
In Thirty-seventh Conference on Neural Information Processing
Systems, 2023.
[28]
Singh et al.
What needs to go right for an induction head? A mechanistic study of
in-context learning circuits and their formation, 2024.
[29]
Devlin et al.
BERT: Pre-training of deep bidirectional transformers for language
understanding.
In North American Chapter of the Association for Computational
Linguistics, 2019.
[30]
Nguyen et al.
Transformers without tears: Improving the normalization of
self-attention.
In Proceedings of the 16th International Conference on Spoken
Language Translation, Hong Kong, November 2-3 2019. Association for
Computational Linguistics.
[31]
Toan Nguyen and David Chiang.
Improving lexical choice in neural machine translation.
In Proceedings of the 2018 Conference of the North American
Chapter of the Association for Computational Linguistics: Human Language
Technologies, Volume 1 (Long Papers), pages 334–343, New Orleans,
Louisiana, June 2018. Association for Computational Linguistics.
[32]
Shleifer et al.
NormFormer: Improved transformer pretraining with extra
normalization, 2021.
[33]
Xu et al.
Understanding and improving layer normalization.
In H. Wallach, H. Larochelle, A. Beygelzimer, F. d'Alché-Buc, E. Fox, and R. Garnett, editors, Advances in Neural
Information Processing Systems, volume 32. Curran Associates, Inc., 2019.
[34]
Kobayashi et al.
Incorporating Residual and Normalization Layers into
Analysis of Masked Language Models.
In Proceedings of the 2021 Conference on Empirical Methods in
Natural Language Processing, pages 4547–4568, Online and Punta Cana,
Dominican Republic, November 2021. Association for Computational Linguistics.
[35]
Brody et al.
On the expressivity role of LayerNorm in transformers’ attention.
pages 14211–14221, 01 2023.
[36]
Raul Molina.
Traveling words: A geometric interpretation of transformers.
ArXiv, abs/2309.07315, 2023.
[37]
Dong et al.
Attention is not all you need: Pure attention loses rank doubly
exponentially with depth.
PMLR, 139, 2021.
[38]
Wu et al.
On the role of attention masks and LayerNorm in transformers, 2024.
[39]
Wang et al.
Towards understanding how transformer perform multi-step reasoning
with matching operation, 2024.
[40]
Boix-Adserà et al.
When can transformers reason with abstract symbols?
In The Twelfth International Conference on Learning
Representations, 2024.
[41]
Csordás et al.
The neural data router: Adaptive control flow in transformers
improves systematic generalization.
In International Conference on Learning Representations, 2022.
[42]
Lamb et al.
Transformers with competitive ensembles of independent mechanisms.
ArXiv, abs/2103.00336, 2021.
[43]
Mikolov et al.
Efficient estimation of word representations in vector space.
pages 1–12, 01 2013.
[44]
Mikolov et al.
Distributed representations of words and phrases and their
compositionality.
In C.J. Burges, L. Bottou, M. Welling, Z. Ghahramani, and K.Q.
Weinberger, editors, Advances in Neural Information Processing Systems,
volume 26. Curran Associates, Inc., 2013.
[45]
Jeffrey Pennington, Richard Socher, and Christopher D. Manning.
GloVe: Global vectors for word representation.
In Conference on Empirical Methods in Natural Language
Processing, 2014.
[46]
Mikolov et al.
Linguistic regularities in continuous space word representations.
Proceedings of NAACL-HLT, pages 746–751, 01 2013.
[47]
Park et al.
The linear representation hypothesis and the geometry of large
language models.
ArXiv, abs/2311.03658, 2023.
[48]
Jiang et al.
On the origins of linear representations in large language models,
2024.
[49]
Dmitry Nikolaev and Sebastian Padó.
Investigating semantic subspaces of transformer sentence embeddings
through linear structural probing.
10 2023.
[50]
Yonatan Belinkov.
Probing classifiers: Promises, shortcomings, and advances.
Computational Linguistics, 48(1):207–219, March 2022.
[51]
Geiger et al.
Finding alignments between interpretable causal variables and
distributed neural representations.
ArXiv, abs/2303.02536, 2023.
[52]
Arora et al.
Linear algebraic structure of word senses, with applications to
polysemy.
Transactions of the Association for Computational Linguistics,
6, 01 2016.
[53]
Tripathi et al.
Semantic subspace learning with conditional significance vectors.
In The 2010 International Joint Conference on Neural Networks
(IJCNN), pages 1–8, 2010.
[54]
Coenen et al.
Visualizing and measuring the geometry of BERT.
In Neural Information Processing Systems, 2019.
[55]
Hewitt et al.
A structural probe for finding syntax in word representations.
In North American Chapter of the Association for Computational
Linguistics, 2019.
[56]
Kawin Ethayarajh.
How contextual are contextualized word representations? Comparing
the geometry of BERT, ELMo, and GPT-2 embeddings.
In Proceedings of the 2019 Conference on Empirical Methods in
Natural Language Processing and the 9th International Joint Conference on
Natural Language Processing (EMNLP-IJCNLP), pages 55–65, Hong Kong, China,
November 2019. Association for Computational Linguistics.
[57]
Song et al.
Uncovering hidden geometry in transformers via disentangling position
and context, 2024.
[58]
Mickus et al.
How to dissect a muppet: The structure of transformer embedding
spaces.
Transactions of the Association for Computational Linguistics,
10:981–996, 09 2022.
[59]
Hernandez et al.
Linearity of relation decoding in transformer language models.
In The Twelfth International Conference on Learning
Representations, 2024.
[60]
Chi et al.
Finding universal grammatical relations in multilingual BERT.
In Proceedings of the 58th Annual Meeting of the Association for
Computational Linguistics, pages 5564–5577, Online, July 2020. Association
for Computational Linguistics.
[61]
Cai et al.
Isotropy in the contextual embedding space: Clusters and manifolds.
In International Conference on Learning Representations, 2021.
[62]
Evan Hernandez and Jacob Andreas.
The low-dimensional linear geometry of contextualized word
representations.
In Conference on Computational Natural Language Learning, 2021.
[63]
Martín Abadi et al.
TensorFlow: Large-scale machine learning on heterogeneous systems,
2015.
Software available from tensorflow.org.
[64]
Harris et al.
Array programming with NumPy.
Nature, 585(7825):357–362, September 2020.
[65]
Abien Fred Agarap.
Deep learning using rectified linear units (ReLU), 2019.
[66]
Xavier Glorot and Yoshua Bengio.
Understanding the difficulty of training deep feedforward neural
networks.
In Yee Whye Teh and Mike Titterington, editors, Proceedings of
the Thirteenth International Conference on Artificial Intelligence and
Statistics, volume 9 of Proceedings of Machine Learning Research,
pages 249–256, Chia Laguna Resort, Sardinia, Italy, 13–15 May 2010. PMLR.
[68]
Ilya Loshchilov and Frank Hutter.
Decoupled weight decay regularization, 2019.
Appendix A Experimental setup
A.1 Data
Because of the nature of the task, we do not require static datasets, and instead generate both train and test data on the fly. This alleviates storage and memory concerns for long training runs, in which a static dataset would have to be large. Datasets are reproducible through the configuration of the environment and a global random seed, which is used to manually control the random seeds of Python, TensorFlow [63], and NumPy [64]. This also reproduces the model initialisation.
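A minimal sketch of seeding from one global seed, assuming nothing about the authors' code beyond the three libraries named above. The helper names `set_global_seed` and `sample_batch` are hypothetical, TensorFlow is treated as optional so that the sketch runs without it, and the seed value 100 is arbitrary. TensorFlow 2.15 also offers `tf.keras.utils.set_random_seed`, which seeds Python, NumPy and TensorFlow in a single call.

```python
import random

import numpy as np


def set_global_seed(seed):
    """Seed Python, NumPy and, if it is installed, TensorFlow from one seed."""
    random.seed(seed)
    np.random.seed(seed)
    try:
        import tensorflow as tf  # optional dependency for this sketch
    except ImportError:
        tf = None
    if tf is not None:
        tf.random.set_seed(seed)
    return np.random.default_rng(seed)


def sample_batch(rng, batch_size=4, length=6):
    """Stand-in for on-the-fly data: random token ids from a 17-token vocabulary."""
    return rng.integers(0, 17, size=(batch_size, length))


rng_a = set_global_seed(100)
batch_a, py_a = sample_batch(rng_a), random.random()
rng_b = set_global_seed(100)
batch_b, py_b = sample_batch(rng_b), random.random()
print(np.array_equal(batch_a, batch_b), py_a == py_b)  # True True
```

Re-seeding with the same value regenerates identical "on-the-fly" batches, which is what makes a dataset that is never stored on disk reproducible.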
A.2 Task specification
We consider an integer addition task, where each character is a base-10 numeral 0-9, a mathematical operator {+, -, =, N}, or a special character {[, ], *}. The N operator signifies that the following integer is negative, and is used to avoid overloading the - operator, which denotes subtraction. The special characters are the begin-sequence token [, the end-sequence token ], and the mask character *. Input sequences in the same batch are right-padded with mask tokens to the same length; these padding tokens do not contribute to the model. Characters that are masked in the output do not contribute to the evaluation metrics. We tokenise per character so that the model does not need to disambiguate different representations of identical patterns (e.g. if the number 112 is tokenised as [11,2] and 212 is tokenised as [2,12], then the pattern 12 has a context-dependent representation). The token dictionary therefore has 17 entries (10 digits, 4 operators and 3 special characters).
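A sketch of the per-character vocabulary described above, assuming an arbitrary id ordering (digits, then operators, then special characters); the names `VOCAB`, `tokenise` and `pad_batch` are hypothetical. It also shows why per-character tokenisation avoids context-dependent representations: the pattern 12 maps to the same ids wherever it occurs.

```python
DIGITS = [str(d) for d in range(10)]
OPERATORS = ["+", "-", "=", "N"]
SPECIAL = ["[", "]", "*"]  # begin-sequence, end-sequence, mask
VOCAB = DIGITS + OPERATORS + SPECIAL
TOKEN_TO_ID = {token: i for i, token in enumerate(VOCAB)}
MASK_ID = TOKEN_TO_ID["*"]


def tokenise(text):
    """Per-character tokenisation: every character maps to exactly one id."""
    return [TOKEN_TO_ID[ch] for ch in text]


def pad_batch(sequences):
    """Right-pad id sequences with the mask token; also return validity flags."""
    length = max(len(seq) for seq in sequences)
    ids = [seq + [MASK_ID] * (length - len(seq)) for seq in sequences]
    valid = [[True] * len(seq) + [False] * (length - len(seq)) for seq in sequences]
    return ids, valid


print(len(VOCAB))                        # 17
print(tokenise("112"), tokenise("212"))  # [1, 1, 2] [2, 1, 2]
ids, valid = pad_batch([tokenise("[12-10=2]"), tokenise("[12+N30=N18]")])
```

The validity flags play the role of the mask described above: positions marked `False` are padding and would be excluded from the model and the metrics.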
For a decoder architecture, the model is a sequence-to-sequence transformer and each datapoint has a question-answer structure separated by the = character, e.g. the first datapoint is:
(15)
The model must therefore predict the numerical outputs and the ] token. For an encoder-decoder architecture, the encoder input is the question and the decoder performs next-token prediction over the answer, e.g.
(16)
To help visualise the task, Figure 4 shows the predictions of the Baseline Pre-Norm model after 1 epoch, and Figure 5 shows the fully-trained model, illustrating the attainable in-distribution performance. The final-epoch per-token accuracy is logged as 92%; the model sometimes correctly predicts all digits of the answer, and otherwise tends to be correct in the leading digits. Figures 6-7 repeat this for the Large model variation, which is trained on a more complex task setting and achieves a lower per-token accuracy of 57%. Once again, the correctly predicted tokens are concentrated in the leading digits.
Figure 4: Baseline Pre-Norm model predictions after 1 training epoch.
Figure 5: Baseline Pre-Norm model predictions after training.
Figure 6: Large Pre-Norm model predictions after 1 training epoch.
Figure 7: Large Pre-Norm model predictions after training.
A.3 Data-generation process
One advantage of this task is the ability to modulate its complexity. Each dataset is defined by two hyperparameters:
Dataset parameter | Example | Description
 | [3, 4, 6] | The allowed number of integers per sequence
 | [2, 3] | The allowed number of digits per integer
Each datapoint is generated by uniformly sampling the number of integers, then uniformly sampling the number of digits for each integer. This ensures that examples are not dominated by integers with the maximum number of digits. Each integer is uniformly sampled from all positive and negative integers of that length. Between consecutive integers, an operator is uniformly sampled from {+, -}. For example, the datapoint
(17)
was generated by first sampling the number of integers, which determined that the sum contains four integers, then sampling four digit counts to determine their lengths, and finally sampling the numbers and operators. The inclusion of subtraction, addition of negative numbers, and double negatives is intended to emphasise solutions that parse the context of each digit within the sum.
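The generator below is a hedged reconstruction of this sampling scheme, not the authors' code. It assumes the example parameter values from the table above ([3, 4, 6] integers, [2, 3] digits), operators drawn from {+, -}, magnitudes without leading zeros, a uniformly random sign, and the layout "[question=answer]" with N marking negative numbers; `sample_datapoint` and `format_int` are hypothetical names.

```python
import numpy as np


def format_int(value):
    """Write an integer with the N marker for negative values, e.g. -7 -> 'N7'."""
    return ("N" if value < 0 else "") + str(abs(value))


def sample_datapoint(rng, n_int_options=(3, 4, 6), n_dig_options=(2, 3)):
    """Sample one '[question=answer]' string under the assumptions above."""
    n_int = int(rng.choice(n_int_options))          # how many integers in the sum
    question, total = "", 0
    for i in range(n_int):
        n_dig = int(rng.choice(n_dig_options))      # digits for this integer
        magnitude = int(rng.integers(10 ** (n_dig - 1), 10 ** n_dig))
        value = magnitude if rng.random() < 0.5 else -magnitude
        if i == 0:
            total = value
        else:
            op = str(rng.choice(["+", "-"]))        # operator between integers
            question += op
            total = total + value if op == "+" else total - value
        question += format_int(value)
    return "[" + question + "=" + format_int(total) + "]"


rng = np.random.default_rng(0)
for _ in range(3):
    print(sample_datapoint(rng))
```

Sampling the digit count before the value, rather than sampling the value directly, is what keeps short integers from being swamped by the far more numerous long ones.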
A.4 Train/test specifications
Table 3 shows the dataset parameters used for the Baseline and Alternate experiments, together with the number of datapoints and the per-datapoint sampling probability. The latter is a range, with higher probabilities for the simpler sums. Table 4 shows the task specification for the Large model variation, which is trained on a more complex setting. We also perform a scan over model size and learning rate to compare the training stability of Pre-Norm and QKV-Norm; these experiments were performed using an earlier problem configuration, shown in Table 5.
Dataset | Num datapoints | Datapoint probability
Train | 110M (acc = 90% @ 40M) | … to …
Validation | 6.4k | … to …
In-distribution | 128k | … to …
OOD (interpolation) | 128k | … to …
OOD (extrapolation) | 128k | … to …
Table 3: Dataset configurations used for Baseline and Alternate results.
Dataset | Num datapoints | Datapoint probability
Train | 25M | … to …
Validation | 6.4k | … to …
In-distribution | 128k | … to …
OOD (interpolation) | 128k | … to …
OOD (extrapolation) | 128k | … to …
Table 4: Dataset configurations used for Large results.
Dataset | Datapoint probability
Train set | … to …
In-distribution | … to …
Table 5: Dataset configurations used for training stability results (Figure 24).
We halt training according to wall time, which leads to a range of observed dataset sizes, and model convergence may occur much earlier. We therefore show an order-of-magnitude estimate for the number of observed datapoints, as well as the point at which the Baseline model reaches 90% per-token accuracy (this represents near-convergence; the final accuracy is logged at 92.1%).
Note that our data-generation strategy does not ensure that training examples are unique (there may be repetitions), nor that the in-distribution test set is disjoint from the training set. The final column is therefore important, because it shows the highest per-datapoint sampling probability, which should be compared with the ~40M datapoints observed before convergence and the ~110M observed in total. Since the per-datapoint probability is highest for the simplest configurations and lowest for the most complicated, this ensures that the in-distribution evaluation metric is dominated by novel examples. The validation set is only used for visual inspection of model behaviour during training, as in Figures 4-7.
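To make the repetition argument concrete, the sketch below computes the exact probability of one specific datapoint under an assumed version of the scheme in A.3: a uniform choice of the number of integers and of each integer's digit count, a uniform draw among the 2 × 9 × 10^(d-1) nonzero signed d-digit integers, and a uniform choice between two operators in each gap. The option lists (3, 4, 6) and (2, 3) are the example values from A.3, not the Baseline configuration given in Table 3, and the function names are hypothetical.

```python
from fractions import Fraction


def datapoint_probability(digit_counts, n_int_options=(3, 4, 6),
                          n_dig_options=(2, 3), n_operators=2):
    """Exact probability of one specific datapoint under the assumed scheme.

    digit_counts lists the number of digits of each integer in the sum.
    """
    if len(digit_counts) not in n_int_options:
        return Fraction(0)
    p = Fraction(1, len(n_int_options))           # choose the number of integers
    for d in digit_counts:
        if d not in n_dig_options:
            return Fraction(0)
        n_signed = 2 * 9 * 10 ** (d - 1)          # nonzero d-digit integers, both signs
        p *= Fraction(1, len(n_dig_options)) * Fraction(1, n_signed)
    p *= Fraction(1, n_operators) ** (len(digit_counts) - 1)  # one operator per gap
    return p


def expected_occurrences(digit_counts, n_observed, **options):
    """Expected number of times one specific datapoint appears in n_observed samples."""
    return float(datapoint_probability(digit_counts, **options) * n_observed)


simplest = datapoint_probability([2, 2, 2])
hardest = datapoint_probability([3] * 6)
print(simplest)                                      # 1/559872000
print(expected_occurrences([2, 2, 2], 110_000_000))  # about 0.196
```

Multiplying by the number of observed datapoints gives the expected number of times a single datapoint is seen; under these assumed options even the most likely sums are expected to appear less than once in 110M samples, but the actual Baseline options may differ.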
A.5 Model specification (main experiments)
We use a decoder architecture, meaning that the dot-product self-attention layers are causally masked such that token $i$ can only attend to tokens $j \leq i$. The model has the following structure:
Embedding + positional encoding We initialise each token embedding as , where is a token embedding with elements, and use cyclic positional encodings of the same form as the original transformer architecture [1], with frequencies initialised as a base log-series between periods of and tokens. For each sequence, all position indices are simultaneously offset by a random integer between and . This augmentation is designed to encourage the use of relative rather than absolute positions. The frequencies are then left as trainable parameters. The positional encodings contribute the first components of , and the remaining are set to . This configuration guarantees that the token embeddings and positional encodings can be made orthogonal in the first layer, and have constant norm, consistent with our theoretical structure.
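The sketch below illustrates this embedding layout under stated assumptions: the positional code is a sinusoid with one (sin, cos) pair per period and fills the first components, token embeddings fill the remaining components, the periods are log-spaced between 2 and 256, and token vectors are normalised to unit length. None of these values come from the paper, and the function names are hypothetical. Because the two parts occupy disjoint components they are orthogonal, and because sin² + cos² = 1 the positional part has the same norm at every position, so every embedding has the same norm.

```python
import numpy as np


def positional_encoding(positions, periods):
    """Sinusoidal encoding with one (sin, cos) pair per period."""
    angles = 2.0 * np.pi * positions[:, None] / periods[None, :]
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)


def initial_embeddings(token_ids, token_table, periods, rng, max_offset=64):
    """Positional code in the first components, token embedding in the rest."""
    n_pos = 2 * len(periods)
    offset = rng.integers(0, max_offset + 1)  # one random shift per sequence
    positions = np.arange(len(token_ids)) + offset
    x = np.zeros((len(token_ids), n_pos + token_table.shape[1]))
    x[:, :n_pos] = positional_encoding(positions, periods)
    x[:, n_pos:] = token_table[token_ids]
    return x


rng = np.random.default_rng(0)
periods = np.geomspace(2.0, 256.0, num=8)  # log-spaced periods (hypothetical range)
token_table = rng.standard_normal((17, 16))
token_table /= np.linalg.norm(token_table, axis=-1, keepdims=True)  # unit-norm tokens
x = initial_embeddings(np.array([14, 1, 2, 10, 3]), token_table, periods, rng)
print(np.linalg.norm(x, axis=-1))  # every entry is approximately sqrt(8 + 1) = 3
```

The dot product between two positional codes depends only on their separation, which is why shifting every position in a sequence by the same random offset preserves relative-position information: for `pe = positional_encoding(np.arange(10.0), periods)`, `pe[0] @ pe[3]` matches `pe[5] @ pe[8]`.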
Attention block is the number of residual blocks of our model, where our baseline is . The update is as formulated in section 4, where is the number of parallel attention heads per layer. Since the embeddings have length , we must have , whilst the latent dimension is configurable. Either the Pre-Norm or QKV-Norm strategy is used, as configured.
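The contrast between the two placements can be sketched for a single causally masked head. This is a hedged illustration rather than the formulation of section 4: it uses RMSNorm without learned gains (so that the subspace bookkeeping stays exact, whereas LayerNorm also subtracts the mean), omits the value and output projections, and builds query/key projections that read only a subspace A of the residual stream. Rescaling the complementary subspace B of one token changes the Pre-Norm attention distribution through the shared normalisation factor, but leaves the QKV-Norm distribution unchanged. All names and sizes are hypothetical.

```python
import numpy as np


def rms_norm(x, eps=1e-8):
    """Rescale each row to unit root-mean-square (normalisation gain fixed at 1)."""
    return x / np.sqrt(np.mean(x**2, axis=-1, keepdims=True) + eps)


def causal_softmax(scores):
    """Row-wise softmax in which query position i may only attend to keys j <= i."""
    t = scores.shape[-1]
    allowed = np.tril(np.ones((t, t), dtype=bool))
    masked = np.where(allowed, scores, -np.inf)
    masked = masked - masked.max(axis=-1, keepdims=True)
    weights = np.exp(masked)
    return weights / weights.sum(axis=-1, keepdims=True)


def attention_weights(x, w_q, w_k, placement):
    """Attention distribution of one head with Pre-Norm or QKV-Norm placement."""
    if placement == "pre":    # normalise the residual stream, then project
        h = rms_norm(x)
        q, k = h @ w_q, h @ w_k
    elif placement == "qkv":  # project, then normalise the query and key vectors
        q, k = rms_norm(x @ w_q), rms_norm(x @ w_k)
    else:
        raise ValueError(f"unknown placement: {placement}")
    return causal_softmax(q @ k.T / np.sqrt(q.shape[-1]))


rng = np.random.default_rng(0)
n_tokens, d_model, d_sub, d_head = 5, 16, 8, 4
x = rng.standard_normal((n_tokens, d_model))  # subspace A = first d_sub dims, B = rest
# Query/key projections that read only subspace A (their rows for B are zero).
w_q = np.zeros((d_model, d_head))
w_k = np.zeros((d_model, d_head))
w_q[:d_sub] = rng.standard_normal((d_sub, d_head))
w_k[:d_sub] = rng.standard_normal((d_sub, d_head))
# Perturb only subspace B of token 1: information this head never reads.
x_perturbed = x.copy()
x_perturbed[1, d_sub:] *= 3.0
for placement in ("pre", "qkv"):
    change = np.abs(attention_weights(x, w_q, w_k, placement)
                    - attention_weights(x_perturbed, w_q, w_k, placement)).max()
    print(placement, change)  # nonzero for "pre", numerically zero for "qkv"
```

In the Pre-Norm case the perturbation only rescales the query and key of token 1 without changing their direction, which is why interference of this kind can be modelled as noise on the norms of these vectors; if the norm of subspace B were fixed (a sphere), the effect would vanish.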
Feed-forward block The feed-forward blocks update embeddings using the function , where is a dense network with one hidden layer of size . The network uses a ReLU [65] activation function on the intermediate layer, followed by a linear projection back onto embedding space. To maintain consistency with other models, we apply LayerNorm at the input to . Both LayerNorm and use bias parameters.
Multi-layer perceptron The final embeddings are mapped onto token logits using the function , where is a multi-layer perceptron with two hidden layers of size and ReLU activation. The final layer is a linear projection onto the space of logits, which has length 17. For the training stability scan in Figure 24, the MLP has three hidden layers instead.
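The feed-forward block and the output MLP can be sketched as follows, assuming hypothetical sizes and a simple scaled-normal initialisation (the paper uses Glorot uniform, see A.6). The block applies LayerNorm with gain and bias at the input, one ReLU hidden layer and a linear projection back, and adds the result to the residual stream; the head has two ReLU hidden layers followed by a linear projection onto 17 token logits. Whether the head normalises its input is not stated, so this sketch does not.

```python
import numpy as np


def layer_norm(x, gain, bias, eps=1e-5):
    """LayerNorm over the last axis with learned gain and bias."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gain * (x - mean) / np.sqrt(var + eps) + bias


def relu(z):
    return np.maximum(z, 0.0)


def feed_forward_block(x, p):
    """Residual update x + F(LayerNorm(x)), F = ReLU hidden layer + linear map back."""
    h = layer_norm(x, p["ln_gain"], p["ln_bias"])
    h = relu(h @ p["w1"] + p["b1"])
    return x + h @ p["w2"] + p["b2"]


def logit_head(x, layers):
    """MLP with ReLU hidden layers and a final linear projection onto token logits."""
    for w, b in layers[:-1]:
        x = relu(x @ w + b)
    w, b = layers[-1]
    return x @ w + b


def dense_params(rng, n_in, n_out):
    return rng.standard_normal((n_in, n_out)) / np.sqrt(n_in), np.zeros(n_out)


rng = np.random.default_rng(0)
d_model, d_ff, d_hidden, vocab = 32, 64, 64, 17  # hypothetical sizes
w1, b1 = dense_params(rng, d_model, d_ff)
w2, b2 = dense_params(rng, d_ff, d_model)
ffn = {"ln_gain": np.ones(d_model), "ln_bias": np.zeros(d_model),
       "w1": w1, "b1": b1, "w2": w2, "b2": b2}
head = [dense_params(rng, d_model, d_hidden), dense_params(rng, d_hidden, d_hidden),
        dense_params(rng, d_hidden, vocab)]
x = rng.standard_normal((5, d_model))
logits = logit_head(feed_forward_block(x, ffn), head)
print(logits.shape)  # (5, 17)
```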
Hyperparameters Table 6 shows the hyperparameters used to configure the networks of the main experiments, and Table 7 shows those used for the training stability analysis. This experiment also uses encoder-decoder models, following the same setup as the original transformer architecture [1] and with the layer configurations listed here.
Model |  |  |  |  |  |  |  | seed
Baseline | 32 | 512 | 10 | 12 | 64 | 512 | 2×512 | 100
Alternate | 32 | 512 | 8 | 12 | 64 | 512 | 2×512 | 100
Large | 32 | 1024 | 12 | 16 | 64 | 512 | 2×512 | 100
Table 6: Model hyperparameters for main experiments (i.e. other than training stability). Baseline is used for the main results presented in section 7. Alternate and Large are presented in Appendix C to show reproducibility of observations.
Model |  |  |  |  |  |  |  | seed
All | 16 | - | - | 12 | - | 512 | 3×512 | 1, 2
Table 7: Model hyperparameters for training stability experiments. Parameters marked "-" are varied per model and displayed in Figure 24.
Loss The loss function is categorical cross entropy, calculated from the output logits.
A.6 Model initialisation
We use a custom initialisation strategy to give control over the initial state of the model. In particular, we use Checkpoint layers to ensure that the initial states are comparable between Pre-Norm and QKV-Norm. This ensures that any observed differences are driven by the normalisation function, rather than being confounded by the layer placement creating more/less favourable initial conditions.
Checkpoint layers are calibrated on the first training batch immediately prior to training. They use this data to measure the standard deviation at that point, and calculate a scale factor that fixes the standard deviation to a pre-defined hyperparameter . All subsequent passes through the layer simply apply this scale factor. This ensures that the model is initialised with a standard deviation of at that point.
We apply Checkpoint layers to the token embeddings and to the initial embeddings, ensuring they are relatively balanced and of unit scale. In every attention layer, we apply Checkpoint layers to re-calibrate the (possibly Pre-Normed) embeddings immediately before applying the query, key and value operators. This counteracts the fact that the transformer necessarily increases the embedding variance with depth at initialisation, since each residual block adds its output to the embeddings. We apply Checkpoint layers to in every attention layer, with constant . This controls the variance on the initial-state attention distribution. We apply Checkpoint layers to in every attention layer, with constant , calibrating it with respect to .
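A minimal sketch of a Checkpoint layer as described above. It assumes a single global standard deviation measured over the whole first batch (the paper may instead measure per feature or per head), and the class name `CheckpointScale` is hypothetical. After `calibrate` is called, the scale is frozen and simply applied on every later pass.

```python
import numpy as np


class CheckpointScale:
    """Hypothetical Checkpoint layer: calibrate once, then apply a fixed scale."""

    def __init__(self, target_std):
        self.target_std = target_std
        self.scale = None

    def calibrate(self, first_batch):
        # Measure the standard deviation on the first training batch only.
        self.scale = self.target_std / np.std(first_batch)
        return self

    def __call__(self, x):
        if self.scale is None:
            raise RuntimeError("calibrate() must be called before use")
        return self.scale * x  # the scale is frozen after calibration


rng = np.random.default_rng(0)
first_batch = 5.0 * rng.standard_normal((64, 32))
checkpoint = CheckpointScale(target_std=1.0).calibrate(first_batch)
print(np.std(checkpoint(first_batch)))  # 1.0 up to rounding
```

Because the scale is measured on data rather than derived from the weight initialisation, Pre-Norm and QKV-Norm models start from the same activation statistics at each checkpoint, whatever their layer placement.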
In the attention layer, we use uniform initialisation of the query, key, value and output weight matrices. The limits are configured so that the initial-state standard deviations are close to their target values, and are calculated as follows:
Weight
Limits
However, we note that this initialisation is superseded by the calibration of the Checkpoint layers for determining the initial state, and we include it only to promote numerical stability. All other feed-forward layers use Glorot uniform [66] initialisation, as implemented in Keras [67]. Normalisation gain parameters are initialised to and biases, where used, to .
A.7 Training algorithm
We train using the AdamW optimiser [68] with learning rate and weight decay of , with all other parameters following their default values in TensorFlow+Keras v2.15.0. Each epoch consists of batches of datapoints. For the main experiments and model variations, we use an adaptive learning rate decay strategy. This means that the learning rate is multiplied by a factor of if the training loss does not improve for consecutive epochs. We find that this balances training speed with improved performance by using small learning rates later in training. Training is halted after two days of wall time, which we observe to allow model convergence, as shown in Figure 8. For the model stability scan, training is run for hours, and learning rate is not allowed to decay (stability with respect to learning rate being one of the targets of study).
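The adaptive decay rule can be sketched as a small scheduler. The factor, the patience and the choice to restart the patience window after each decay are assumptions, since the paper's values are not reproduced here, and `PlateauDecay` is a hypothetical name. Keras provides similar behaviour through `tf.keras.callbacks.ReduceLROnPlateau` (with `monitor`, `factor` and `patience` arguments), though the text does not say whether the authors used it.

```python
class PlateauDecay:
    """Multiply the learning rate by `factor` after `patience` epochs without improvement."""

    def __init__(self, lr, factor=0.5, patience=3, min_lr=0.0):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.best = float("inf")
        self.wait = 0

    def step(self, epoch_loss):
        """Update after one epoch and return the learning rate for the next one."""
        if epoch_loss < self.best:
            self.best, self.wait = epoch_loss, 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.lr = max(self.lr * self.factor, self.min_lr)
                self.wait = 0  # restart the patience window after decaying
        return self.lr


schedule = PlateauDecay(lr=1e-3, factor=0.5, patience=2)
lrs = [schedule.step(loss) for loss in [1.0, 0.9, 0.95, 0.96, 0.97, 0.98]]
print(lrs)  # the rate halves after the 4th and 6th epochs
```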
Figure 8: Model training curves for the Baseline Pre-Norm configuration.
A.8 Computational resources
The main experiments are all performed on a single Nvidia V100-SXM2-16GB (Volta) GPU. The models in the stability scan were trained on a batch cluster with a variety of compute nodes, using multiple CPU cores per training run. A representative compute node is 2×12-core Intel Xeon E5-2690 v3 @ 2.60GHz + 128GB RAM.
A.9 Environment details
The main contributing package versions are as follows:
Package | Version
Python | 3.11.5 (main, Sep 11 2023, 13:54:46) [GCC 11.2.0]
TensorFlow | 2.15.0
Keras | 2.15.0
NumPy | 1.26.2
Appendix B Extended main experiments
This appendix provides an extended explanation of the experimental results in section 7.
B.1 Embedding structure
Figure 1 presented the spread of embedding norms as a function of model depth. Let us now describe in detail how this plot was made. Figures 9 and 10 show the distribution of embedding norms at the input to every attention layer, for Baseline models trained using Pre-Norm (Figure 9) and QKV-Norm (Figure 10). Colours represent the initial token type corresponding to each embedding. Asterisks denote tokens in the answer, and unmarked labels denote tokens in the question.
Figure 9: Distribution of embedding norms at different model depths using the Baseline Pre-Norm model.
Figure 10: Distribution of embedding norms at different model depths using the Baseline QKV-Norm model.
We see that the begin-sequence token (BEG) is often separate from the distribution, which may be because it remains non-annotated and fulfils a qualitatively different role. We remove this from our estimates to avoid erroneously inflating the spread. Interestingly, BEG still tends to be close to the main bulk for Pre-Norm, but can be very far for QKV-Norm, and we have to use overflow panels to capture it. This is consistent with our hypothesis that Pre-Norm stability requires embeddings to have similar norms, whilst QKV-Norm does not require this.
It is not sufficient to simply measure the spread in Figures 9 and 10, because an attention head may not be sensitive to all embeddings in the layer. The simplest way to account for this is to weight every embedding according to the attention it receives. Secondly, the distribution is expected to be narrow only on a per-head basis, and there is no reason why distinct heads cannot be centred around different medians. We therefore calculate the weighted distribution of embeddings on a per-head basis, as shown in Figures 11-12.
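One way to compute an attention-weighted, per-head spread is sketched below. It weights each key position by the total attention it receives from all queries of a head, drops position 0 (assumed here to hold the BEG token), and reports the ratio of the 95th to the 5th weighted percentile of the norms, i.e. the range containing 90% of the attention-weighted embeddings. The function names, the midpoint interpolation of the weighted CDF and this ratio definition of spread are assumptions rather than the paper's exact estimator.

```python
import numpy as np


def weighted_quantile(values, weights, quantiles):
    """Quantiles of `values` under `weights` (midpoint interpolation of the weighted CDF)."""
    values = np.asarray(values, dtype=float)
    weights = np.asarray(weights, dtype=float)
    keep = weights > 0                      # zero weights would make the CDF non-increasing
    values, weights = values[keep], weights[keep]
    order = np.argsort(values)
    values, weights = values[order], weights[order]
    cdf = (np.cumsum(weights) - 0.5 * weights) / weights.sum()
    return np.interp(quantiles, cdf, values)


def attention_weighted_spread(norms, attention, coverage=0.9, drop_first=True):
    """Per-head ratio of upper to lower weighted quantile of embedding norms.

    norms: [T] norms of the embeddings entering the layer.
    attention: [H, T, T] attention distributions (query, key) for each head.
    drop_first removes position 0, assumed here to hold the BEG token.
    """
    tail = 0.5 * (1.0 - coverage)
    spreads = []
    for head in attention:
        received = head.sum(axis=0)         # total attention each key position receives
        n, w = (norms[1:], received[1:]) if drop_first else (norms, received)
        low, high = weighted_quantile(n, w, [tail, 1.0 - tail])
        spreads.append(high / low)
    return np.array(spreads)


norms = np.array([10.0, 1.0, 2.0, 3.0, 4.0])
attention = np.full((2, 5, 5), 0.2)        # two heads with uniform attention
print(attention_weighted_spread(norms, attention))  # [4. 4.]
```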
Figure 11: Embedding distributions at different model depths using the Baseline Pre-Norm model. The categories 0-9 and N are separated into whether they occur in the question (light colour) or answer (dark colour, distinguished by label). These distributions are used to compute the LHS of Figure 1 after removing the BEG tokens.
Figure 12: Embedding distributions at different model depths using the Baseline QKV-Norm model. The categories 0-9 and N are separated into whether they occur in the question (light colour) or answer (dark colour, distinguished by label). Note that overflow panels are excluded from this plot for legibility. These distributions are used to compute the RHS of Figure 1 after removing the BEG tokens.
B.2 Circuit collapse
Figure 3 shows the probability of circuit collapse. This is the probability that an attention distribution which, without noise, selects one embedding with high probability (above the sparsity threshold) switches, once noise is added, so that a different embedding becomes the maximally attended one. This definition is chosen because it matches our theoretical results in section 6. However, it does not require that the distribution remains sparse after the noise is added. Figure 13 compares this baseline result (top) with an alternative definition (bottom), in which the noisy distribution must also be sparse, meaning that the newly selected embedding must also exceed the sparsity threshold. We see that 1% of sparse attention distributions collapse at a noise level of approximately 10% when using the original definition, whereas the onset is delayed to a higher noise level when using the sparse definition. Therefore we observe that the sparse-to-sparse case does occur, but requires a higher noise level.
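The two definitions can be written as boolean tests on pairs of attention distributions measured without and with noise. The sketch below is hypothetical: the sparsity threshold of 0.9 is a placeholder for the paper's threshold, and ties in the argmax are resolved by NumPy's first-index rule.

```python
import numpy as np


def collapse_flags(clean, noisy, threshold=0.9):
    """Per-distribution flags for the two circuit-collapse definitions.

    clean, noisy: [N, T] attention distributions without and with noise.
    """
    is_sparse = clean.max(axis=-1) >= threshold
    switched = clean.argmax(axis=-1) != noisy.argmax(axis=-1)
    collapsed = is_sparse & switched                                    # baseline definition
    collapsed_to_sparse = collapsed & (noisy.max(axis=-1) >= threshold)  # sparse-to-sparse
    return is_sparse, collapsed, collapsed_to_sparse


def collapse_probabilities(clean, noisy, threshold=0.9):
    """Fraction of sparse distributions that collapse under each definition."""
    is_sparse, collapsed, collapsed_to_sparse = collapse_flags(clean, noisy, threshold)
    n_sparse = max(int(is_sparse.sum()), 1)
    return float(collapsed.sum() / n_sparse), float(collapsed_to_sparse.sum() / n_sparse)


clean = np.array([[0.95, 0.05, 0.00],
                  [0.50, 0.50, 0.00],
                  [0.92, 0.08, 0.00]])
noisy = np.array([[0.30, 0.70, 0.00],
                  [0.10, 0.90, 0.00],
                  [0.05, 0.95, 0.00]])
print(collapse_probabilities(clean, noisy))  # (1.0, 0.5)
```

In this example the second row is ignored because it was not sparse without noise; the first row collapses only under the baseline definition, and the third satisfies both.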
Figure 13: Probability of circuit collapse vs increasing noise. Top: using the baseline definition. This is a reproduction of Figure 3. Bottom: requiring the attention distribution to remain sparse after switching to a different token.
Appendix C Main experiments: results with different models
In this appendix we reproduce the main experimental results using our model variations.
C.1 Embedding lengths
Figure 1 shows the empirical results demonstrating the attention-weighted spread of embeddings. Figures 15-16 show the results we obtain when we perform the same analysis using the Alternate and Large model variations. In all cases, when using Pre-Norm, we observe 90% of embeddings within a similar spread. In all cases, the spread of embeddings for QKV-Norm is larger, although we note that the effect is smaller when using the Alternate variation.
Figure 14: Attention-weighted spread of embeddings at increasing model depth using the Baseline model and task configuration. This is a replication of Figure 1.
Figure 15: Attention-weighted spread of embeddings at increasing model depth using the Alternate model and task configuration.
Figure 16: Attention-weighted spread of embeddings at increasing model depth using the Large model and task configuration.
C.2 Model stability with simulated interference
Figure 2 (left) shows the stability of the model predictions under simulated interference. Figure 18 shows the results we obtain when we perform the same analysis using the Alternate model variation; the Large model was not run due to its high computational load. We find that the Alternate model has a larger effect size than Baseline, losing more per-token accuracy at the same noise level. For completeness, we show QKV-Norm on the RHS. This is stable by construction, and only jitter due to finite sampling is observed. In these plots, we estimate the statistical uncertainty by evaluating over three datasets and calculating the standard error on the mean. This is plotted as a shaded band, but tends to be narrower than the line width.
Figure 2 (right) compares the stability when we only apply noise to sparse heads (thin dashed line) and to non-sparse heads (thick dashed line). Figure 20 compares these results with the Alternate model variation. In both experiments, sparse attention is stable under percent-level noise, and non-sparse distributions drive the accuracy drop in this regime.
Note that later layers experience both the artificial noise injection and the perturbation of their inputs caused by compounding errors from noise in previous layers.
Figure 17: Evolution of per-token accuracy as we increase noise on the norms of the query, key and value vectors, for the Baseline model and task configuration.
Figure 18: Evolution of per-token accuracy as we increase noise on the norms of the query, key and value vectors, for the Alternate model and task configuration.
Figure 19: Evolution of per-token accuracy as we increase noise on the norms of the query, key and value vectors of sparse and non-sparse heads separately, for the Baseline model and task configuration.
Figure 20: Evolution of per-token accuracy as we increase noise on the norms of the query, key and value vectors of sparse and non-sparse heads separately, for the Alternate model and task configuration.
C.3 Circuit collapse
Figure 3 shows the probability of circuit collapse: the probability that an attention distribution which, without noise, selects one embedding with high probability switches, once noise is added, so that a different embedding becomes the maximally attended one. Figures 22 and 23 show the results we obtain when we perform the same analysis using the Alternate and Large model variations. In both cases, we observe the onset of circuit collapse at smaller noise levels: whilst 1% of sparse attention distributions collapse at 11% noise in the Baseline model, the corresponding noise level is 7.5% for Alternate and 5.5% for Large.
Figure 21: Probability of circuit collapse vs increasing noise using the Baseline model and task configuration. This is a replication of Figure 3.
Figure 22: Probability of circuit collapse vs increasing noise using the Alternate model and task configuration.
Figure 23: Probability of circuit collapse vs increasing noise using the Large model and task configuration.
Appendix D Additional comparisons between Pre-Norm and QKV-Norm
D.1 Model performance
Table 8 shows the per-token accuracy for the trained Baseline models. Pre-Norm and QKV-Norm have comparable in-distribution per-token accuracies. However, QKV-Norm performance drops further than that of Pre-Norm when generalising to intermediate task difficulty (interpolation) and to increased difficulty (extrapolation). The larger performance drop of QKV-Norm implies that it has learned a less generalisable solution. This reinforces our motivation that architectural choices can be important for the inductive bias of a model.
Dataset
Pre-Norm
QKV-Norm
In-distribution
OOD (interpolation)
OOD (extrapolation)
Table 8: Per-token accuracy for the Baseline models. Dataset configurations are shown in Table 5.
D.2 Training stability
Changing the normalisation layer is expected to affect the training rate and stability. To investigate this, Figure 24 shows the training curves for different model sizes and learning rates. The task is configured as presented in Table 7. The Depth parameter is the number of layers, where brackets indicate the values for an encoder-decoder model. For example, a Depth of $(N_{\mathrm{enc}}, N_{\mathrm{dec}})$ means that we use $N_{\mathrm{enc}}$ encoder blocks and $N_{\mathrm{dec}}$ decoder blocks. Each decoder block has a self-attention and a cross-attention layer, and so the total model has $N_{\mathrm{enc}} + 2N_{\mathrm{dec}}$ attention layers. A single Depth value indicates a decoder-only architecture with that number of layers. Width is the number of neurons per layer, and Latent width is the number of neurons on the space of (called in section 4). Training curves on the top row use the larger learning rate, whilst the bottom row uses the smaller one. In each panel, two training runs are shown, with different random seeds. Pre-Norm is shown in blue, and QKV-Norm in red.
We find that Pre-Norm training is unstable for large learning rates and model sizes, as shown by the flat blue curves in the top right-hand panels. Similar stabilisation improvements at large learning rates are reported for QK-Norm [19, 20], which applies layer normalisation to the queries and keys but, unlike QKV-Norm, not to the values. However, we note that training large models with a smaller learning rate leads to improved model performance, as shown by the panels on the bottom right. Finally, we note that both methods typically train the model at similar rates; however, small-model training follows a very different trajectory, with QKV-Norm learning more slowly at the beginning of training (bottom left panels). There is also some visible evidence that small-model training is less effective when using QKV-Norm with a large learning rate (top left panels).
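The Depth convention of Figure 24 can be restated as a small counting rule. The helper below is hypothetical (it is not from the paper's code) and assumes that each encoder block contains one self-attention layer, each decoder block of an encoder-decoder model contains one self-attention and one cross-attention layer, and a decoder-only model has one attention layer per block.

```python
def count_attention_layers(depth):
    """Total number of attention layers implied by a Depth setting.

    depth: an int for a decoder-only model, or a tuple (n_enc, n_dec)
    for an encoder-decoder model.
    """
    if isinstance(depth, int):
        # Decoder-only: one self-attention layer per block.
        return depth
    n_enc, n_dec = depth
    # Encoder blocks: self-attention only.
    # Decoder blocks: self-attention plus cross-attention.
    return n_enc + 2 * n_dec


print(count_attention_layers((2, 2)))  # encoder-decoder -> 6
print(count_attention_layers(6))       # decoder-only -> 6
```

The example shows that an encoder-decoder Depth and a single Depth value can describe models with the same number of attention layers but a different arrangement of them.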
Figure 24: Training curves when learning the task configuration shown in Table 7.
D.3 Attention sparsity
We find that our Pre-Norm models often exploit sparse-attention, whereas models trained with QKV-Norm do not. Similar behaviour is reported for QK-Norm in [18]. For a systematic comparison, Figure 25 shows a histogram of the maximum attention weight observed in each distribution (i.e. the largest entry of each row of the attention matrix). When making this plot, we do not consider the first row of the attention matrix, in which the first token attends fully to itself.
We see that the Pre-Norm distribution has a sharp peak at 1, indicating a significant use of sparse-attention. By contrast, the QKV-Norm distribution is weighted towards lower values and has no peak at 1. To verify this behaviour, Figure 26 shows an attention heatmap for a randomly chosen datapoint when using the Baseline Pre-Norm model, and Figure 27 shows the same datapoint for QKV-Norm. We observe a significantly less sparse attention matrix for QKV-Norm. Note that [18] also shows a similar visualisation.
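The histogram in Figure 25 can be sketched as follows, assuming the attention maps are available as nested lists in which each row is a probability distribution over keys (the function name, the bin count and the toy map are illustrative). Row 0 is skipped, as in the text, because under a causal mask the first token can only attend to itself.

```python
def max_attention_histogram(attention_maps, n_bins=10):
    """Count, per bin on [0, 1], the maximum weight of each attention row."""
    counts = [0] * n_bins
    for attn in attention_maps:
        for row in attn[1:]:  # skip the first row (token 0 attends to itself)
            peak = max(row)
            # Clamp so that a peak of exactly 1.0 lands in the last bin.
            counts[min(int(peak * n_bins), n_bins - 1)] += 1
    return counts


# Toy causal attention map for a 4-token sequence.
toy_map = [
    [1.0],
    [0.5, 0.5],
    [0.0, 0.02, 0.98],
    [0.25, 0.25, 0.25, 0.25],
]
print(max_attention_histogram([toy_map]))
```

A model that relies on sparse-attention accumulates counts in the last bin, which corresponds to the peak at 1 described above.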
Figure 25: Distribution of the maximum attention weight observed in each distribution, in the Baseline case. We observe that the Pre-Norm model often utilises sparse-attention, as seen by the peak at 1. By contrast, QKV-Norm shows no such peak. Similar behaviour is reported for QK-Norm in [18].
Figure 26: Attention maps for a random in-distribution example using the Baseline Pre-Norm model. Several attention heads create sparse attention distributions.
Figure 27: Attention maps for a random in-distribution example using the Baseline QKV-Norm model. We observe much less sparsity than in the Pre-Norm model, shown in Figure 26. Similar behaviour is reported for QK-Norm in [18].
Appendix E Supplementary theorems
This appendix contains theorems that support the main results, providing additional context or serving as prerequisites for the proofs in appendix F. We use the formulation of section 4, where $i, j$ are indices over tokens, $x_i$ is the message-receiving embedding, $x_j$ are the message senders, $w_{ij} = q_i^T k_j$ are the attention scores, and $a_{ij} = \mathrm{softmax}_j(w_{ij})$ is the attention distribution.
Theorem 9. Shifting attention scores by a constant offset does not affect the attention distribution. Therefore attention is fully determined by differences in scores.
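Proof. Applying the shift $w_{ij} \to w_{ij} + c$, with $c$ independent of $j$, we have
$$\mathrm{softmax}_j(w_{ij} + c) = \frac{e^{w_{ij} + c}}{\sum_k e^{w_{ik} + c}} = \frac{e^{c}\, e^{w_{ij}}}{e^{c} \sum_k e^{w_{ik}}} = \mathrm{softmax}_j(w_{ij}). \tag{18}$$
Alternatively we may write
$$a_{ij} = \mathrm{softmax}_j(w_{ij}) = \frac{1}{\sum_k e^{w_{ik} - w_{ij}}} = \Big(\sum_k e^{\Delta_{ijk}}\Big)^{-1} \tag{19}$$
where $\Delta_{ijk} \equiv w_{ik} - w_{ij}$.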
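A quick numerical check of Theorem 9, assuming plain softmax attention over a list of scores (the scores and the offset are arbitrary illustrative values): adding a constant to every score leaves the distribution unchanged, and each weight can equivalently be computed from score differences alone.

```python
import math


def softmax(scores):
    m = max(scores)  # subtracting the max is itself an application of Theorem 9
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_from_differences(scores):
    """a_j = 1 / sum_k exp(w_k - w_j): depends only on score differences."""
    return [1.0 / sum(math.exp(w_k - w_j) for w_k in scores) for w_j in scores]


scores = [0.3, -1.2, 2.0, 0.7]
shifted = [w + 5.0 for w in scores]
print(softmax(scores))
print(softmax(shifted))
print(attention_from_differences(scores))
```

The three printed distributions agree up to floating-point rounding.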
Theorem 10.
Multiplying attention scores by a positive factor changes the inverse-temperature of the attention distribution, modulating its sparsity (low temperature = less entropy = more sparse). Corollary: in the sparse limit, attention is fully determined by the order of the scores $\{w_{ij}\}_j$.
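Proof. Applying the scaling $w_{ij} \to \beta w_{ij}$, with $\beta > 0$ independent of $j$, we have
$$\mathrm{softmax}_j(\beta w_{ij}) = \frac{e^{\beta w_{ij}}}{\sum_k e^{\beta w_{ik}}} \xrightarrow{\;\beta \to \infty\;} \delta_{j j^*}, \qquad j^* = \arg\max_k w_{ik}, \tag{20}$$
so that $\beta$ acts as an inverse temperature, and the argmax operator is fully determined by the order of $\{w_{ik}\}_k$ (assuming the maximum is unique).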
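Theorem 10 can be illustrated numerically. This sketch assumes unnormalised scalar scores and an explicit inverse temperature `beta` (both illustrative): increasing `beta` lowers the entropy of the distribution, never changes which token is ranked first, and drives the distribution towards a one-hot vector on the argmax.

```python
import math


def softmax(scores, beta=1.0):
    """Softmax with inverse temperature beta (beta > 0)."""
    scaled = [beta * s for s in scores]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(p):
    """Shannon entropy in nats, ignoring exactly-zero probabilities."""
    return -sum(x * math.log(x) for x in p if x > 0.0)


scores = [0.3, -1.2, 2.0, 0.7]
for beta in (0.5, 1.0, 4.0, 50.0):
    p = softmax(scores, beta)
    print(beta, round(entropy(p), 4), [round(x, 3) for x in p])
```

Entropy falls monotonically with `beta` because its derivative with respect to `beta` is minus `beta` times the attention-weighted variance of the scores.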
Theorem 11.
In the No-Norm case, the attention distribution is defined by the projection of $x_j$ onto a fixed vector $u_i$ for a given $x_i$. The length $|u_i|$ is an inverse-temperature parameter.
Proof. Write $w_{ij} = (W_Q x_i)^T (W_K x_j) = u_i^T x_j$, where $u_i = W_K^T W_Q x_i$; this is the dot-product between $x_j$ and a fixed vector on the row space of $W_K$. Then, re-writing in terms of the vector lengths and the enclosing angle $\theta_{ij}$, we have $w_{ij} = |u_i|\,|x_j| \cos\theta_{ij}$. The factor $|u_i|$ is identical for all $j$, making it an inverse-temperature.
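Theorem 11 can be checked numerically. This sketch assumes bias-free projections $q_i = W_Q x_i$ and $k_j = W_K x_j$ with randomly drawn, illustrative weights and embeddings, and confirms that each score equals $u_i^T x_j = |u_i|\,|x_j|\cos\theta_{ij}$, where $u_i = W_K^T W_Q x_i$ is shared across all senders $j$.

```python
import math
import random


def matvec(m, v):
    return [sum(m_rc * v_c for m_rc, v_c in zip(row, v)) for row in m]


def transpose(m):
    return [list(col) for col in zip(*m)]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def norm(a):
    return math.sqrt(dot(a, a))


rng = random.Random(0)
d, d_h, n_tokens = 4, 3, 5
W_Q = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(d_h)]
W_K = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(d_h)]
x_i = [rng.gauss(0, 1) for _ in range(d)]
senders = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n_tokens)]

# Fixed vector for receiver i, lying on the row space of W_K.
u_i = matvec(transpose(W_K), matvec(W_Q, x_i))

for x_j in senders:
    w_direct = dot(matvec(W_Q, x_i), matvec(W_K, x_j))  # (W_Q x_i) . (W_K x_j)
    cos_theta = dot(u_i, x_j) / (norm(u_i) * norm(x_j))
    w_geometric = norm(u_i) * norm(x_j) * cos_theta    # |u_i| |x_j| cos(theta_ij)
    print(round(w_direct, 6), round(w_geometric, 6))
```

Because `norm(u_i)` multiplies every score for receiver `i`, rescaling `x_i` only changes the sharpness of that receiver's attention, not the ranking of the senders.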
Theorem 12.
In the No-Norm case, bias parameters in the construction of query and key vectors are nullified by the softmax, or only contribute terms that may be recovered if $x$ contains a constant direction.
Proof. Consider a modification to the construction of query and key vectors that uses the affine transformations $q_i = W_Q x_i + b_Q$ and $k_j = W_K x_j + b_K$, with $W_Q, W_K \in \mathbb{R}^{d_h \times d}$ and $b_Q, b_K \in \mathbb{R}^{d_h}$. The dot-product attention scores are then:
$$w_{ij} = x_i^T W_Q^T W_K x_j + x_i^T W_Q^T b_K + b_Q^T W_K x_j + b_Q^T b_K. \tag{21}$$
After expanding the terms, we find an additive constant $b_Q^T b_K$, and move this onto the LHS. Theorem 9 states that this has no impact on the output of the softmax operator. We identify $W_Q^T b_K$ and $W_K^T b_Q$ as vectors on the row-spaces of $W_Q$ and $W_K$ respectively, defined as linear maps of the special directions $b_K$ and $b_Q$. Since $x_i$ is constant for each softmax, $x_i^T W_Q^T b_K$ is constant, and we absorb it into the LHS. We perform the singular value decomposition $W_Q^T W_K = U \Sigma V^T$, where $U, V$ are orthonormal matrices and $\Sigma$ is a diagonal matrix of non-negative singular values with rank at most $d_h$. Orthonormal matrices apply a basis change to the embedding space using rotations and reflections. We write the transformed embeddings as $x'_i = U^T x_i$ and $x'_j = V^T x_j$. The dot-product then has two terms:
1.
$x'^T_i \Sigma x'_j$ sculpts the attention distribution according to pairwise relationships between embeddings. We can say that $U$ and $V$ align the bases of $x_i$ and $x_j$, mapping them onto a common orthonormal coordinate system. $\Sigma$ then assigns an importance weight $\sigma_n$ to each coordinate $n$, determining the contribution of $x'_{in} x'_{jn}$.
2.
$b_Q^T W_K x_j = (W_K^T b_Q)^T x_j$ means “token $j$ sends to all receivers when $x_j$ has a large projection onto $W_K^T b_Q$”, where $W_K^T b_Q$ must be a vector on the row-space of $W_K$. This may be recovered in the expansion of $x_i^T W_Q^T W_K x_j$ if there exists a direction $e$ for which $e^T x_i$ is constant for all $i$.
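A numerical sketch of Theorem 12 with random, illustrative weights, biases and embeddings (the names mirror the affine maps $q_i = W_Q x_i + b_Q$ and $k_j = W_K x_j + b_K$, but the values are arbitrary): removing the two $j$-independent bias terms leaves the attention distribution unchanged, whereas the key-side term $b_Q^T W_K x_j$ does not cancel.

```python
import math
import random


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


rng = random.Random(1)
d, d_h, n_tokens = 4, 3, 6
W_Q = [[rng.gauss(0, 0.5) for _ in range(d)] for _ in range(d_h)]
W_K = [[rng.gauss(0, 0.5) for _ in range(d)] for _ in range(d_h)]
b_Q = [rng.gauss(0, 1) for _ in range(d_h)]
b_K = [rng.gauss(0, 1) for _ in range(d_h)]
x_i = [rng.gauss(0, 0.5) for _ in range(d)]
senders = [[rng.gauss(0, 0.5) for _ in range(d)] for _ in range(n_tokens)]

q_i = [a + b for a, b in zip(matvec(W_Q, x_i), b_Q)]
# Full affine scores (W_Q x_i + b_Q) . (W_K x_j + b_K).
full = [dot(q_i, [a + b for a, b in zip(matvec(W_K, x_j), b_K)]) for x_j in senders]
# Drop the j-independent terms x_i^T W_Q^T b_K and b_Q^T b_K; keep b_Q^T W_K x_j.
reduced = [dot(matvec(W_Q, x_i), matvec(W_K, x_j)) + dot(b_Q, matvec(W_K, x_j))
           for x_j in senders]
# Dropping b_Q as well generally changes the distribution.
no_bias = [dot(matvec(W_Q, x_i), matvec(W_K, x_j)) for x_j in senders]

print(max(abs(a - b) for a, b in zip(softmax(full), softmax(reduced))))
print(max(abs(a - b) for a, b in zip(softmax(full), softmax(no_bias))))
```

The first printed difference is at the level of floating-point rounding; the second is not, which is the surviving "token sends to all receivers" term.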
Appendix F Proofs of theorems in the main text
This appendix provides proofs for the theorems presented in sections 5 and 6.
Theorem 1. No-Norm: If two heads with finite non-zero temperature attend to different semantic subspaces, the subspaces must be linearly independent . Corollary: is a low-rank matrix with (left and right) null-spaces that span all non-attended information.
Proof.
Let and be co-ordinates for the subspaces of attended to by heads A and B respectively, and be all other information. Let and , where denotes independence. Without loss of generality, write
(22)
Then write
(23)
which requires and , since any cancellation between the two terms must be independent of and so can be absorbed entirely into the function . This means that and must both be orthogonal to , meaning that they reside on the left null space of , or are projected by onto a null space of .
Head A can only attend to if it is not on either of these null spaces, meaning that is linearly independent of and . Likewise for head B
(24)
requires that is linearly independent of both and . Since resides on both null spaces, it is linearly independent of both and , and may be seen as a third subspace that passes information through to subsequent layers.
We can also write , and so the same argument also holds for subspaces on . In this case, non-attended subspaces are spanned by the right null space of .
Theorem 2. Pre-Norm: Semantic subspaces must be represented as orthogonal spheres defined using the $L_2$-norm. Corollary: if either orthogonality or constant norm is violated, semantic subspaces interfere through a multiplicative factor on .
Proof. Write
(25)
Then for head A we have
(26)
where are the attention scores from the No-Norm case, which requires and to be linearly independent. Now we additionally require , with
(27)
where we suppress parameter dependence for readability. Since is a monotonic function, this can only be satisfied if
(28)
Repeating this process for head B gives
(29)
Combining and collecting dependencies, we then have
(30)
(31)
(32)
(33)
We can go one step further, noticing that each individual term carries a different functional dependence, and so must independently be constant (footnote 2: N.B. If then reduces to , which is already required). We then have
(34)
The requirements and mean that the semantic subspaces have a spherical structure defined by the $L_2$-norm .
Now consider the requirement . Say that and have and degrees of freedom, meaning that and have and respectively, since they each lose one by confinement to the sphere. Say that the constant is nonzero such that . This means that there must be some direction for which . If we know all coordinates of , and all coordinates of except for direction , then we also know the value of , because it is fixed by the constant. However, this would mean that and are not independent, violating the condition . The only way to satisfy independence is if , ensuring that degrees of freedom on and never become entangled. Therefore, to satisfy semantic independence, we must have . This means that the subspaces are not just linearly independent, but orthogonal.
We have shown the proof for semantic subspaces of . As for Theorem 1, the same structure must be true for by symmetry.
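The interference described by Theorem 2 can be reproduced in a toy model. This sketch assumes a 4-dimensional embedding $x = (\alpha, \beta)$ with two orthogonal 2-dimensional semantic subspaces, a head A whose bilinear form (an arbitrary matrix `M`) reads only $\alpha$, and a Pre-Norm that simply divides by the $L_2$-norm (LayerNorm's centring, $\sqrt{d}$ scale and learned gain are omitted). All values are illustrative.

```python
import math


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def l2_normalise(x):
    n = math.sqrt(sum(v * v for v in x))
    return [v / n for v in x]


M = [[1.0, 0.5], [0.0, 2.0]]  # head A's query-key bilinear form on alpha


def head_a_attention(receiver, senders, pre_norm):
    """Attention of head A, which reads only the alpha block x[:2]."""
    prep = l2_normalise if pre_norm else (lambda x: x)
    alpha_i = prep(receiver)[:2]
    scores = []
    for x_j in senders:
        alpha_j = prep(x_j)[:2]
        m_alpha_j = [sum(m * a for m, a in zip(row, alpha_j)) for row in M]
        scores.append(sum(a * b for a, b in zip(alpha_i, m_alpha_j)))
    return softmax(scores)


def max_change(p, r):
    return max(abs(a - b) for a, b in zip(p, r))


receiver = [1.0, 0.0, 0.3, 0.3]
senders = [[0.0, 1.0, 0.2, 0.1], [1.0, 0.0, 0.5, 0.5], [0.6, 0.8, 0.0, 0.4]]
perturbed = [[0.0, 1.0, 2.0, 2.0]] + senders[1:]  # only beta of sender 0 changes

# Spherical structure: |alpha| = |beta| = 1 for every sender.
spheres = [[0.0, 1.0, 1.0, 0.0], [1.0, 0.0, 0.0, 1.0], [0.6, 0.8, 0.6, 0.8]]
spheres_moved = [[0.0, 1.0, 0.0, 1.0]] + spheres[1:]  # beta moves along its circle

change_no_norm = max_change(head_a_attention(receiver, senders, False),
                            head_a_attention(receiver, perturbed, False))
change_pre_norm = max_change(head_a_attention(receiver, senders, True),
                             head_a_attention(receiver, perturbed, True))
change_spheres = max_change(head_a_attention(receiver, spheres, True),
                            head_a_attention(receiver, spheres_moved, True))
print(change_no_norm, change_pre_norm, change_spheres)
```

Only the Pre-Norm case with unconstrained norms shows a change: the $\beta$ edit alters the shared factor $1/|x_j|$, which rescales head A's score for that sender.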
Theorem 3. QKV-Norm: Semantic subspaces must be linearly separable, reproducing the No-Norm case.
Proof. We have
(35)
where are the attention scores from the No-Norm case, which requires and to be linearly independent. Use
(36)
and
(37)
Since we already have the condition of linearly independent , there must exist a linear projection operator such that . Defining , we then have
(38)
This demonstrates that it is possible to separate linearly independent semantic subspaces on . By symmetry of , the same must be true for .
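A toy sketch of Theorem 3, assuming a plain $L_2$ normalisation without gain and illustrative projections `W_Q`, `W_K` that read only the first two coordinates (the subspace $\alpha$) of a 4-dimensional embedding. Because normalisation is applied after the linear maps, changing the other subspace $\beta$ of a sender (which changes $|x_j|$) leaves the attention exactly unchanged; Pre-Norm with the same projections is shown for contrast. Value normalisation is not modelled here.

```python
import math


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def l2_normalise(x):
    n = math.sqrt(sum(v * v for v in x))
    return [v / n for v in x]


def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


# Projections that read only the alpha block (first two coordinates).
W_Q = [[1.0, 0.0, 0.0, 0.0], [0.5, 2.0, 0.0, 0.0]]
W_K = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]


def attention(receiver, senders, mode):
    if mode == "qkv_norm":  # normalise after the linear maps
        q = l2_normalise(matvec(W_Q, receiver))
        keys = [l2_normalise(matvec(W_K, x)) for x in senders]
    else:  # "pre_norm": normalise the embedding before the linear maps
        q = matvec(W_Q, l2_normalise(receiver))
        keys = [matvec(W_K, l2_normalise(x)) for x in senders]
    return softmax([dot(q, k) for k in keys])


receiver = [1.0, 0.2, 0.3, 0.3]
senders = [[0.0, 1.0, 0.2, 0.1], [1.0, 0.0, 0.5, 0.5], [0.6, 0.8, 0.0, 0.4]]
perturbed = [[0.0, 1.0, 2.0, 2.0]] + senders[1:]  # only beta of sender 0 changes

change_qkv = max(abs(a - b) for a, b in zip(attention(receiver, senders, "qkv_norm"),
                                             attention(receiver, perturbed, "qkv_norm")))
change_pre = max(abs(a - b) for a, b in zip(attention(receiver, senders, "pre_norm"),
                                             attention(receiver, perturbed, "pre_norm")))
print(change_qkv, change_pre)
```

Here the projection onto $\alpha$ plays the role of the linear operator that separates the subspaces before normalisation.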
Theorem 4. Consider independent infinitesimal perturbations on queries , keys , and messages . These propagate onto as
(39)
(40)
(41)
where is the value of measured from the attention-weighted centroid .
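The query term of Theorem 4 can be sanity-checked by finite differences. This sketch assumes unscaled scores $w_j = q \cdot k_j$ for a single receiver and uses the standard softmax Jacobian, which gives $\delta\Delta = \sum_j a_j (v_j - \bar v)\,[(k_j - \bar k)\cdot\delta q]$ with attention-weighted centroids $\bar v = \sum_j a_j v_j$ and $\bar k = \sum_j a_j k_j$. This is derived independently here and is intended to match the form described for Eq. 39; any temperature factor is not modelled, and the data are random and illustrative.

```python
import math
import random


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def message(q, keys, values):
    """Delta = sum_j a_j v_j with a = softmax_j(q . k_j)."""
    attn = softmax([sum(a * b for a, b in zip(q, k)) for k in keys])
    dim = len(values[0])
    return [sum(a * v[c] for a, v in zip(attn, values)) for c in range(dim)], attn


def predicted_change(q, keys, values, dq):
    """First-order change: sum_j a_j (v_j - v_bar) [(k_j - k_bar) . dq]."""
    delta, attn = message(q, keys, values)  # delta equals v_bar, the value centroid
    k_bar = [sum(a * k[c] for a, k in zip(attn, keys)) for c in range(len(q))]
    out = [0.0] * len(delta)
    for a, k, v in zip(attn, keys, values):
        proj = sum((kc - kb) * dc for kc, kb, dc in zip(k, k_bar, dq))
        for c in range(len(out)):
            out[c] += a * (v[c] - delta[c]) * proj
    return out


def finite_difference(q, keys, values, dq, h=1e-6):
    """Central-difference directional derivative of the message along dq."""
    plus = message([qc + h * dc for qc, dc in zip(q, dq)], keys, values)[0]
    minus = message([qc - h * dc for qc, dc in zip(q, dq)], keys, values)[0]
    return [(p - m) / (2 * h) for p, m in zip(plus, minus)]


rng = random.Random(2)
n_tokens, d_k, d_v = 5, 3, 2
q = [rng.gauss(0, 1) for _ in range(d_k)]
keys = [[rng.gauss(0, 1) for _ in range(d_k)] for _ in range(n_tokens)]
values = [[rng.gauss(0, 1) for _ in range(d_v)] for _ in range(n_tokens)]
dq = [rng.gauss(0, 1) for _ in range(d_k)]

print(predicted_change(q, keys, values, dq))
print(finite_difference(q, keys, values, dq))
```

Agreement between the two printed vectors confirms that only values and keys measured from their centroids contribute to the first-order response.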
Proof. Consider where are infinitesimal perturbations on . Then where by Taylor expansion we find
(42)
where the leading term is a matrix acting on a vector . Differentiating gives
(43)
with and , and we are using etc to index over tokens instead of etc, because this is more readable when we have many summations. Then
(44)
and , where we retain the transpose to indicate that this is an element of the dual vector space (i.e. covector). Inserting these results into our expression for gives
(45)
This is the result for Eq. 39. Repeating the process for perturbations on , we have
Theorem 5. For sparse attention:
(50)
i.e. the message is stable with respect to small interference in the queries and keys. Interference in the selected value is linearly transferred onto the message.
Proof. For sparse attention we have for some . For perturbations of , the RHS of Eq. 39 becomes
(51)
where the final step is because . For perturbations of , the RHS of Eq. 40 evaluates to because
(52)
where the final step is because . For perturbations of , the RHS of Eq. 41 evaluates to
(53)
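Theorem 5 can be illustrated with a nearly one-hot head. In this toy sketch (arbitrary values, unscaled scores $q \cdot k_j$), a small perturbation of the query changes the message only through the exponentially small off-target weights, whereas a perturbation of the selected value is transferred almost one-to-one onto the message.

```python
import math


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def message(q, keys, values):
    attn = softmax([sum(a * b for a, b in zip(q, k)) for k in keys])
    return [sum(a * v[c] for a, v in zip(attn, values)) for c in range(len(values[0]))]


keys = [[10.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
values = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]
q = [2.0, 0.0]  # scores 20, 0, -2: attention is almost one-hot on token 0

base = message(q, keys, values)
after_q = message([2.01, 0.01], keys, values)           # perturbed query
after_v = message(q, keys, [[1.01, 0.0]] + values[1:])  # perturbed selected value

print([b - a for a, b in zip(base, after_q)])  # tiny (order 1e-10): stable
print([b - a for a, b in zip(base, after_v)])  # close to [0.01, 0.0]: linear transfer
```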
Theorem 6.
For isotropic attention:
(54)
N.B. isotropy requires or . Lemma 1: the update is stable to noisy when , or when (cf. keys and messages from independent subspaces). Lemma 2: the update is stable to noisy when , or when . Lemma 3: the update is stable to noisy when . Other cases propagate linearly.
Proof. For isotropic attention we have . For perturbations of , the RHS of Eq. 39 is
(55)
For lemma 1, we note that implies , and if then .
For lemma 2, this expression evaluates to if , and if then .
For perturbations of , the RHS of Eq. 41 evaluates to
(57)
Theorem 7. Sensitivity of sparse attention to multiplicative perturbations and with . Circuit collapse occurs when for which:
(58)
where temperature cancels in the fraction. Attention is fully stable above the critical transition point (cf. ). We see that query perturbations alone are insufficient, as they result in . Lemma: consider the special case when all keys have similar length , the attended token has , the keys are far-from-orthogonal s.t. , and . Using , circuit collapse occurs when for which:
(59)
i.e. stability requires either well-separated keys s.t. , or small perturbations .
Proof. Apply and to , then we have such that . For multiplicative perturbations we have and , and so . Each term recovers a factor of , which we factor out to give . The final term is subleading in the limit of small perturbations, and so
(60)
Circuit collapse occurs when for some . Substituting our limit for gives
(61)
and collecting terms gives
(62)
We then divide each side by , taking care to reverse the sign of the inequality when this factor is negative, to give
(63)
which is the first expression in the theorem. We note that any temperature parameter cancels in the fraction, which means that the attention head cannot become more stable by reducing its temperature to become more sparse. has the limits
(64)
meaning that query perturbations alone are insufficient, contributing only when they co-occur with perturbations on the keys. Write with , and the approximation of identical key norms turns this into . Then
(65)
Then means that , and so . We perform a Taylor expansion in to obtain
(66)
which is valid when . This is true for any for which is far from orthogonal with . Substituting this into our circuit collapse condition, we have
(67)
where we consider the case of for readability. Re-arranging gives
(68)
if , and we reverse the inequality otherwise. We have approximated the denominator on the RHS as for .
When , the LHS of Eq. 68 is small. This means that the attention head can tolerate only very small perturbations . Therefore semantic subspaces must either have a highly orthogonal substructure s.t. , or be orthogonal s.t. .
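The structural claims of Theorem 7 can be checked in an exact (non-linearised) toy version of the noise model. This sketch assumes each score $w_j = q\cdot k_j$ is rescaled as $w_j \to \beta(1+\epsilon)(1+\kappa_j)\,w_j$ by multiplicative norm perturbations, with inverse temperature $\beta > 0$ and $1+\epsilon > 0$; the scores and perturbations are illustrative, and it does not reproduce Eqs. 58 and 59, which keep only first-order terms.

```python
def selected(scores):
    return max(range(len(scores)), key=lambda j: scores[j])


def collapses(scores, eps_q, eps_k, beta=1.0):
    """True if multiplicative norm noise moves the argmax to another token."""
    noisy = [beta * (1 + eps_q) * (1 + e) * w for w, e in zip(scores, eps_k)]
    return selected(noisy) != selected(scores)


close_scores = [5.0, 4.8, 1.0]      # competitor score close to the selected one
separated_scores = [5.0, 1.0, 0.5]  # well-separated scores
eps_k = [-0.05, 0.05, 0.0]

print(collapses(close_scores, eps_q=0.3, eps_k=[0.0, 0.0, 0.0]))  # query only
print(collapses(close_scores, eps_q=0.0, eps_k=eps_k))
print(collapses(separated_scores, eps_q=0.0, eps_k=eps_k))
print([collapses(close_scores, 0.0, eps_k, beta) for beta in (0.1, 1.0, 10.0)])
```

A query-only perturbation rescales every score by the same positive factor and cannot move the argmax; a 5% key perturbation collapses the head with a close competitor but not the well-separated one, and the outcome is the same at every temperature.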
Theorem 14. Sensitivity of isotropic attention to multiplicative perturbations. Say with where have comparable amplitudes. Then
(69)
Proof. We begin with the following result from Theorem 6:
(70)
Substituting and taking inside the brackets gives
(71)
We then notice that isotropic attention requires that is a constant, which we call . Then
(72)
is our general result. We then note three special cases, each resulting in :
1.
If then . This is the case when interference on the keys is not dominated by the same semantic subspace as the message .
2.
If all keys are perturbed by the same factor , then because .
3.
Isotropic attention can be achieved by either or . If the case is then this implies also.