Learning to grok: Emergence of in-context learning and skill composition in modular arithmetic tasks
Abstract
Large language models can solve tasks that were not present in the training set. This capability is believed to be due to in-context learning and skill composition. In this work, we study the emergence of in-context learning and skill composition in a collection of modular arithmetic tasks. Specifically, we consider a finite collection of linear modular functions $z = a\,x + b\,y \bmod p$ labeled by the vector $(a, b) \in \mathbb{Z}_p^2$. We use some of these tasks for pre-training and the rest for out-of-distribution testing. We empirically show that a GPT-style transformer exhibits a transition from in-distribution to out-of-distribution generalization as the number of pre-training tasks increases. We find that the smallest model capable of out-of-distribution generalization requires two transformer blocks, while for deeper models, the out-of-distribution generalization phase is transient, necessitating early stopping. Finally, we perform an interpretability study of the pre-trained models, revealing highly structured representations in both phases, and discuss the learnt algorithm.
{tianyuh, ddoshi, aritrad}@umd.edu
1 Introduction
Large language models (LLMs) can perform simple tasks that have not been present in the training set. This ability is usually achieved via in-context learning [5]. More importantly, LLMs can perform an even larger variety of very complex tasks upon appropriate prompting or fine-tuning. The latter ability to perform complex tasks is usually attributed to the following mechanism. First, LLMs learn a large variety of simple tasks and, then, how to compose those skills to form very complex skills [3]. Furthermore, LLMs also exhibit “emergent capabilities” – a sudden emergence of a new complex skill as a function of scale (either model parameters, compute or data) [28, 9]. It is plausible that these sudden performance improvements are due to one or both of these mechanisms. For example, LLMs show grokking on algorithmic tasks [23], which results from the model learning very structured representations [18, 11, 19]. Once these representations emerge, the model abruptly learns how to perform the task.
In this work, we set out to examine skill composition both empirically and mechanistically. Inspired by prior work that investigated the emergence of in-context learning on linear regression tasks [1, 24], we introduce a finite collection of discrete modular arithmetic tasks [23] generalized to the in-context learning setting. Each task corresponds to learning a linear function over $\mathbb{Z}_p$ from the examples provided in the context of the autoregressive model (AM). In the bi-variate case there are $p^2$ such functions labeled by the vector $(a, b) \in \mathbb{Z}_p^2$. The main objective of this algorithmic dataset is to probe how the AM utilizes the tasks it has learnt during training to solve new tasks.
Our analysis shows that the solution found by the AM after optimization is qualitatively different from the linear regression cases studied before [1]. In those cases, due to the continuous nature of the task, the AM develops an emergent first-order optimization method that minimizes an emergent quadratic loss function. Furthermore, as shown in [1], a single linear attention layer can solve the regression problem, while adding extra layers and non-linearities slightly modifies the gradient descent. In the modular arithmetic case, the AM first learns how to solve the pre-training tasks and later (given enough different tasks) develops a generalizing solution by combining the solved tasks.
Our main findings, as well as the structure of the algorithmic dataset, are illustrated in Fig. 1. Our main findings are: (i) There are four different phases at the end of pre-training, depending on the number of tasks and the number of examples per task. (ii) At inference time, there is a generalization transition in the number of few-shot examples: as the number of examples grows, the model starts to generalize. This effect is somewhat similar to the transition in sample complexity for modular arithmetic found in [23]. (iii) The model develops a striking circular representation for all of the tasks that naturally generalizes the circular representations found in the original work [23]. We further find that deeper models are easier to optimize, but much harder to interpret. The optimization is discussed in more detail in the main text. Here we highlight that optimization for these tasks is challenging and the AM tends to prefer minima that solve only a handful of tasks and memorize the training set. To avoid such minima, we make sure that every batch contains an equal number of sequences from each task (meaning that no task is over- or under-represented in any batch). We further find that for larger models early stopping is necessary because the generalizing solution is transient.
![Refer to caption](x1.png)
We organize our paper as follows. Section 2 contains the literature review. In Section 3 we explain our notation and discuss the experimental details. In Section 4 we demonstrate empirically that the out-of-distribution ICL ability emerges as the number of training tasks increases. We also study the effects of model depth and task difficulty. In Section 5 we carefully examine a minimal setting, i.e., a two-block transformer: we compare the representations learnt in the four different phases and show that in the generalizing phase the representations are highly structured and generalize the original modular addition case of [23].
2 Related Works
In-Context Learning (ICL)
Brown et al. [5] first demonstrated that the performance of large models improves substantially when a few examples of the task at hand are provided in the prompt at inference time. Akyürek et al. [2], Ahn et al. [1], von Oswald et al. [27] showed that the AM implements emergent first-order optimization on an emergent objective function to solve linear regression tasks. Furthermore, [2] showed that larger models learn to perform Bayesian estimation. Garg et al. [10] demonstrated that transformers can learn several simple classes of functions in context. Kirsch et al. [15] showed how task diversity and model size affect ICL performance on unseen tasks, using a mixture of modified MNIST datasets. Raventos et al. [24] investigated the relation between task diversity and out-of-distribution ICL ability on linear regression tasks. Lin and Lee [16] identified two operating modes in ICL using a mixture of linear regression tasks, where for the first several shots the model tries to figure out the correct task vector and later uses it to predict the correct results. Boix-Adserà et al. [4] showed theoretically and experimentally that with enough pre-training data, a transformer model can perform abstract reasoning that an MLP cannot. Guo et al. [13] showed that transformers can use lower layers to memorize and upper layers to perform ICL in a feature regression setting. It was found in [25] that ICL is a transient phase from the optimization point of view: it goes away once the model is over-trained. Hendel et al. [14], Liu et al. [17] showed that language models form in-context vectors, which can be extracted and used to control model predictions.
Modular Arithmetic
Power et al. [23] discovered grokking, where models trained on modular arithmetic datasets undergo an abrupt change from random guessing to generalization on the test set long after the model has memorized the training set. Gromov [11], Nanda et al. [19], Gu et al. [12] showed that for modular addition tasks, models learn to map integers to Fourier features to solve the task. Liu et al. [18] showed that grokking is related to learning highly structured features, and that the grokking transition can be explained by a toy model. Zhong et al. [29] showed that there is more than one algorithm that a model can implement to solve modular addition. Doshi et al. [6] showed that corruption of the labels does not prevent the models from finding a generalizing solution. Doshi et al. [7] showed that MLP and transformer models can solve a specific family of modular polynomial tasks by bijectively mapping them to modular addition tasks.
Interpretability
Elhage et al. [8], Olsson et al. [21] showed that transformers can form induction heads that predict the next token in a sequence by identifying and copying patterns from earlier in the sequence. Using several pieces of indirect empirical evidence, they showed that those heads might constitute the core mechanism of ICL. Nichani et al. [20] showed theoretically and empirically how disentangled transformers learn causal structures from in-context Markov chains by forming induction heads.
3 Preliminaries
Linear Modular Functions
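We consider modular arithmetic tasks of the form: $z_i^t = a^t x_i + b^t y_i \bmod p$. We will refer to the coefficients $(a^t, b^t) \in \mathbb{Z}_p^2$ as the task vector. The superscript $t$ labels the $p^2$ possible tasks. We will refer to $(x_i, y_i) \in \mathbb{Z}_p^2$ as the input vector, which is labeled by the subscript $i$.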
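As a minimal sketch of this task family, assuming tasks of the form $z = a x + b y \bmod p$ and an illustrative modulus `P = 29`, the snippet below enumerates all task vectors and input vectors and evaluates a label. The names `P`, `label`, `ALL_TASKS`, and `ALL_INPUTS` are hypothetical and do not come from the authors' code.

```python
import itertools

P = 29  # assumed prime modulus, chosen for illustration only


def label(task, inp, p=P):
    """Return z = a*x + b*y mod p for task vector (a, b) and input vector (x, y)."""
    a, b = task
    x, y = inp
    return (a * x + b * y) % p


# Task vectors and input vectors both live on the same p x p grid,
# so there are p**2 of each.
ALL_TASKS = list(itertools.product(range(P), repeat=2))
ALL_INPUTS = list(itertools.product(range(P), repeat=2))

if __name__ == "__main__":
    print(len(ALL_TASKS), len(ALL_INPUTS))
    print(label((2, 3), (5, 7)))
```

Because the labels are linear in both the task vector and the input vector, any task can be evaluated on any input with the same one-line function.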
In-Context Learning with Transformers
We use GPT-like transformers [5] with ReLU activation function and Rotary Positional Embedding (RoPE) [26]. The model has consecutive blocks, attention-heads, and embedding dimension . Each number is tokenized as an independent token. The pre-training is done following a slightly modified next-token prediction setup, with sequences of the form:
$$\left(x_1,\; y_1,\; z_1^t,\; \ldots,\; x_{n_{\mathrm{ctx}}},\; y_{n_{\mathrm{ctx}}},\; z_{n_{\mathrm{ctx}}}^t\right) \qquad (1)$$
where $n_{\mathrm{ctx}}$ is the maximum number of in-context examples. The model is asked to predict only the labels $z_i^t$. We emphasize that we do not explicitly provide the task vectors to the model (see Fig. 1); this information is implicit in the in-context labels $z_i^t$. In order for the model to generalize, it must determine the underlying task vector from the few-shot examples.
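The sequence layout can be sketched as follows, assuming each in-context example contributes three tokens (two inputs and one label) and that the loss is applied only at the label positions. The helper names `build_sequence` and `label_positions`, and the default modulus `p=29`, are illustrative assumptions rather than the authors' implementation.

```python
def build_sequence(task, inputs, p=29):
    """Flatten in-context examples into tokens: x_1, y_1, z_1, x_2, y_2, z_2, ..."""
    a, b = task
    tokens = []
    for x, y in inputs:
        tokens.extend([x, y, (a * x + b * y) % p])
    return tokens


def label_positions(n_examples):
    """Indices of the label tokens z_i inside the flattened sequence.

    In an autoregressive model the prediction for position k is read from
    the output at position k - 1, so the loss mask is shifted by one token.
    """
    return [3 * i + 2 for i in range(n_examples)]


if __name__ == "__main__":
    seq = build_sequence((1, 2), [(3, 4), (5, 6)])
    print(seq)
    print([seq[k] for k in label_positions(2)])
```

Note that the task vector never appears among the tokens; it can only be inferred from the labels.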
Generalization
There are two notions of generalization in this setup. (i) In-distribution: generalization to unseen input vectors, but on task vectors the model has seen during pre-training. (ii) Out-of-distribution: generalization to task vectors the model has not seen during pre-training. To clearly separate these regimes, we split the task vectors into an in-distribution (i.d.) set and an out-of-distribution (o.o.d.) set. Similarly, we split the input vectors into train and test sets. This results in four distinct sets of sequences, one for each pairing of a task split (i.d. or o.o.d.) with an input split (train or test). The set built from i.d. tasks and train inputs is used for pre-training, while the other three sets are used for evaluation.
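The four evaluation regimes can be sketched as a pairing of a task split with an input split. The snippet below is only an illustration: it splits both grids uniformly at random, whereas the paper selects pre-training tasks with its own rule, and the modulus `P = 29` and the split sizes (64 tasks, 500 inputs) are hypothetical values.

```python
import itertools
import random

P = 29  # assumed prime modulus, for illustration only


def split(items, n_first, rng):
    """Shuffle a copy of `items` and cut it into two disjoint parts."""
    items = list(items)
    rng.shuffle(items)
    return items[:n_first], items[n_first:]


rng = random.Random(0)
grid = list(itertools.product(range(P), repeat=2))  # all (a, b) or (x, y) pairs

# Hypothetical split sizes; a random split is a simplification.
tasks_id, tasks_ood = split(grid, 64, rng)
inputs_train, inputs_test = split(grid, 500, rng)

# One set of sequences per (task split, input split) pairing.
splits = {
    ("i.d.", "train"): (tasks_id, inputs_train),    # used for pre-training
    ("i.d.", "test"): (tasks_id, inputs_test),      # unseen inputs, seen tasks
    ("o.o.d.", "train"): (tasks_ood, inputs_train),  # seen inputs, unseen tasks
    ("o.o.d.", "test"): (tasks_ood, inputs_test),    # unseen inputs and tasks
}

if __name__ == "__main__":
    for name, (tasks, inputs) in splits.items():
        print(name, len(tasks), len(inputs))
```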
Pre-Training Task Selection and Sequence Design
We always sample the pre-training task vectors in sets of 4, following the rectangular rule, shown in Figure 2(a). Additionally, each batch contains an equal representation from all the task vectors in the set . Moreover, all the tasks share the same sequence of inputs. For example, a batch with two different task vectors and two distinct input sequences per task (resulting in four total sequences) is shown in Figure 2(b).
![Refer to caption](x2.png)
This structured approach creates a coherent signal from the sequences within each batch, ensuring that the model learns multiple task vectors with reasonable batch sizes. In contrast, if the batches are sampled i.i.d., the model is confused by the batch noise and cannot learn any of the tasks.
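The batch design described above (equal representation of every task, with all tasks sharing the same input sequences) can be sketched as follows. This is a minimal illustration under stated assumptions: the function name `make_batch`, the modulus `p=29`, and sampling inputs without replacement within a sequence are choices made here, and the task-selection rule itself is not reproduced.

```python
import random


def make_batch(tasks, input_pool, seqs_per_task, n_examples, rng, p=29):
    """Build a batch in which every task sees the same input sequences."""
    # Draw the input sequences once, then reuse them for every task.
    shared_inputs = [rng.sample(input_pool, n_examples) for _ in range(seqs_per_task)]
    batch = []
    for a, b in tasks:
        for inputs in shared_inputs:
            tokens = []
            for x, y in inputs:
                tokens.extend([x, y, (a * x + b * y) % p])
            batch.append(tokens)
    return batch


if __name__ == "__main__":
    rng = random.Random(0)
    pool = [(x, y) for x in range(29) for y in range(29)]
    # Two task vectors and two input sequences per task: four sequences in total.
    for row in make_batch([(1, 2), (3, 4)], pool, 2, 3, rng):
        print(row)
```

Within such a batch, sequences belonging to different tasks differ only in their label tokens, which is one way to read the "coherent signal" mentioned above.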
Detailed discussions on task selection and sequence design are presented in Appendix D.
Default Setting
Unless stated explicitly, we will use , the number of heads , and embedding dimension , with in-context examples. All models are trained with AdamW optimizer with batch size for k steps. We have also tied the embedding layer of the model with the readout layer.
4 Emergence of In-Context Learning and Task Composition
In this section, we demonstrate that a transformer model with depth can develop ICL and out-of-distribution generalization on modular arithmetic tasks. We delve deeper into the two notions of generalization (i.d. and o.o.d.), and discuss the relevant factors. We find that the model’s ability to generalize out-of-distribution is predominantly determined by the number of pretraining tasks .
4.1 Transition driven by the number of tasks
![Refer to caption](x3.png)
![Refer to caption](x4.png)
![Refer to caption](x5.png)
![Refer to caption](x6.png)
![Refer to caption](x7.png)
In Figure 3(a), we show the accuracy of the models versus the number of training tasks and the number of few-shot examples, quantified as a fraction of the total number of data points, on the four sets of sequences defined in Section 3. The phase diagram in Figure 1 is constructed by merging the last-shot-accuracy versions of these four diagrams, shown in Figure 18(a) of Appendix G.
The ability of the model to generalize in-distribution increases with , as can be seen by comparing the first two panels of Figure 3(a). This behavior is in correspondence with the original work on grokking, where the transition to generalizing solution is driven by the amount of data. Further, we observe that an increase in enhances the in-context sample efficiency, i.e. the model generalizes at inference time with fewer few-shot examples. This indicates the onset of the transition from a task-memorizing solution to the one that generalizes out-of-distribution. The model switches to a new algorithmic way of solving the task and the solution is more few-shot-sample-efficient.
Shifting our focus to the last two panels of Figure 3(a), we see that when the number of pre-training tasks is large enough, the model can solve new tasks that were absent from the training set. Notably, there appears to be a trade-off between memorization and generalization when the model attains this o.o.d. generalization ability. As the o.o.d. performance increases, the pre-training performance simultaneously degrades. This phenomenon indicates a shift in the algorithm implemented by the model. Prior to this transition, the model primarily needed to select possible vectors from the list of memorized tasks and apply them. Post-transition, however, the model adopts a more universal approach to solving the task in-context. We emphasize that the model learns to perform ICL in both scenarios. The difference lies in the approach to generalization. When the model can only generalize in-distribution, its task is to classify the sequence as one of the seen tasks or as unknown. Once it matches the sequence to one of the memorized task vectors, it does well on pairs that only appear in the test set. However, as the number of task vectors grows, the model fails to store them all and is forced to find a method of determining the task vector algorithmically at inference time. In that case the model performs equally well on seen and unseen tasks. In fact, the small two-layer model we study has such a low capacity that it entirely skips the in-distribution generalization phase and jumps immediately from pure memorization to out-of-distribution generalization.
Next, to further illustrate the effect of task diversity, we plot the pre-training accuracy (set ) and the o.o.d. test accuracy (set ) as a function of training steps (Figure 3(b, c)); for various values of . We observe a clear memorization-to-generalization transition as task diversity increases. Interestingly, for , the ICL ability on set exhibits non-monotonic behavior, where the o.o.d. performance rises and falls along the training process. This phenomenon is likely due to a competition between the memorizing and generalizing circuits inside the model. Note that this phenomenon is akin to the one analyzed in Singh et al. [25], albeit with a different setup.
Further evidence supporting the two-circuit competition can be observed in panel (d) of Figure 3. The loss curves show a “monotonic non-monotonic monotonic" transition as the task diversity increases. With a minimal number of pre-training tasks, the model primarily engages in memorization, resulting in a monotonically increasing o.o.d. loss curve. As the number of pre-training tasks increases, the loss curve exhibits non-monotonic behavior, indicating competition between two distinct neural circuits. This transient nature of o.o.d. generalization for is a peculiar case where memorization circuits are initially suppressed but eventually prevail. With substantial task diversity, the circuits responsible for generalization take over, culminating in a monotonic loss curve. Similar insights can be derived from examining the monotonicity of the accuracy curves in panel (e).
4.2 Effect of Model Size and Task Difficulty
![Refer to caption](x8.png)
A natural question to ask is if similar phenomena can be observed with different model sizes or task difficulties. Here we present our results with and in Figure 4 and leave the results for other prime values in Appendix H.
When comparing phase diagrams in Figure 3 with Figure 4, we observe that those phase diagrams across different depths are qualitatively similar, where the o.o.d. generalization only emerges with a large enough number of pre-training tasks. As model capacity decreases, performance on both the pre-training set and the o.o.d. test set degrades. This is particularly evident in the case, where the pre-training accuracy falls drastically as the model gains o.o.d. generalization ability.
Interesting observations can be made by comparing loss and accuracy on the o.o.d. test set as a function of context length at the end of training. First, it is evident that as the model depth decreases, the -shot loss surge attributable to memorization becomes milder. Notably, for models with , there is no loss surge in the -shot case across all three depths. Furthermore, the model with behaves significantly differently from the corresponding one with case, where the model fails to perform ICL for the o.o.d. test set. This is also distinct from the case, where the model tends heavily toward memorization due to its excessive capacity. Instead, the model manages to maintain a better balance between memorization and generalization at the end of pre-training. Consequently, the model has a -shot loss surge followed by a notable drop in ICL loss. This suggests that optimally leverages the available model capacity to facilitate effective learning dynamics for o.o.d. generalization.
5 Interpretability
![Refer to caption](x9.png)
An important question to consider is what algorithm the model implements to achieve o.o.d. generalization. Here we propose such an algorithm and provide empirical evidence supporting our claims. For simplicity, consider the case with two-shot inference, where the model sees two in-context examples $(x_1, y_1, z_1)$ and $(x_2, y_2, z_2)$ followed by a query input $(x, y)$.
The model learns to (i) re-scale the in-context inputs and (ii) perform linear transformations (over $\mathbb{Z}_p$). For example, in the above case, the model needs to find the two constants $c_1$ and $c_2$ such that
$$c_1\,(x_1, y_1) + c_2\,(x_2, y_2) = (x, y) \bmod p. \qquad (2)$$
Once the model has figured out the constants, the result can be simply computed with modular addition, $z = c_1 z_1 + c_2 z_2 \bmod p$ [11, 19, 18, 29]. Notably, with sequences that provide more demonstrations, the model has the flexibility to combine in-context examples to derive the answer. (In an autoregressive setting, the actual position at which the model needs to output the prediction is shifted by one token, but this does not change our argument.) However, no matter which specific implementation the model chooses, solving linear systems over $\mathbb{Z}_p$ to figure out those constants is a must. This requires the model to be able to combine in-context examples and perform modular arithmetic over them, which breaks down to the following three essential skills other than copying information across in-context examples:
In Section 5.1, we demonstrate that while models of both depths learn skills I and II perfectly, the deeper model outperforms the shallower one in skill III, namely combining the in-context examples. We attribute this disparity to the limited capacity of the shallower model. Then, in Section 5.2, we show that there are special attention heads that implement skills I and II. We explicitly show that the structure of these heads deteriorates as pre-training task diversity decreases. We leave the discussion of embedding layers (shown in Figure 1(e)) to Appendix E.
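The two-shot computation proposed above can be written out as a reference calculation: find $c_1, c_2$ with $c_1 (x_1, y_1) + c_2 (x_2, y_2) = (x, y) \bmod p$ and return $c_1 z_1 + c_2 z_2 \bmod p$. The sketch below uses Cramer's rule over $\mathbb{Z}_p$; it illustrates the arithmetic only and makes no claim about how the transformer implements it internally. The function name `solve_two_shot` and the modulus `p=29` are assumptions made for illustration.

```python
def solve_two_shot(ex1, ex2, query, p=29):
    """Predict z for `query` from two in-context examples (x, y, z), working mod p.

    Returns None when the two in-context inputs are linearly dependent mod p,
    in which case the constants c1, c2 are not uniquely determined.
    """
    (x1, y1, z1), (x2, y2, z2) = ex1, ex2
    x, y = query
    det = (x1 * y2 - x2 * y1) % p
    if det == 0:
        return None
    det_inv = pow(det, -1, p)  # modular inverse, requires Python 3.8+
    # Cramer's rule for c1*(x1, y1) + c2*(x2, y2) = (x, y) mod p.
    c1 = ((x * y2 - x2 * y) * det_inv) % p
    c2 = ((x1 * y - x * y1) * det_inv) % p
    return (c1 * z1 + c2 * z2) % p


if __name__ == "__main__":
    a, b, p = 3, 5, 29  # task vector, hidden from the solver
    ex1 = (1, 2, (a * 1 + b * 2) % p)
    ex2 = (4, 7, (a * 4 + b * 7) % p)
    print(solve_two_shot(ex1, ex2, (6, 9)), (a * 6 + b * 9) % p)
```

The solver never recovers the task vector explicitly; linearity of the task guarantees that the same combination of inputs also combines the labels correctly.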
5.1 Model Learns to Combine
We find telling signatures of Equation 2 upon analyzing model predictions with varying numbers of examples in context. Consider the 1-shot scenario, wherein only one demonstration is given to the model. This is equivalent to setting $c_2 = 0$ in Equation 2. In this case, one would expect that the model will only solve the task correctly for examples that obey $(x, y) = c\,(x_1, y_1) \bmod p$ for some $c$. Indeed, this is exactly what we observe in Figure 5: the model correctly predicts all the targets when the inputs are re-scalings of the in-context input $(x_1, y_1)$.
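The 1-shot case can be checked by brute force: a query is solvable from a single demonstration only if it is a re-scaling of the in-context input modulo $p$. The sketch below assumes $p = 29$ and uses the hypothetical helper `one_shot_predict`; it describes the arithmetic constraint, not the model's internals.

```python
def one_shot_predict(example, query, p=29):
    """Predict z for `query` from one example (x1, y1, z1), if possible.

    Succeeds only when query == c * (x1, y1) mod p for some c, in which
    case the answer is c * z1 mod p; otherwise returns None.
    """
    x1, y1, z1 = example
    x, y = query
    for c in range(p):
        if ((c * x1) % p, (c * y1) % p) == (x, y):
            return (c * z1) % p
    return None


if __name__ == "__main__":
    a, b, p = 2, 5, 29  # task vector, hidden from the predictor
    example = (3, 4, (a * 3 + b * 4) % p)
    solvable = [
        (x, y)
        for x in range(p)
        for y in range(p)
        if one_shot_predict(example, (x, y)) is not None
    ]
    print(len(solvable), "of", p * p, "queries are re-scalings of the example")
```

For a non-zero in-context input, only the queries lying on one line through the origin of the $p \times p$ grid are solvable, which is the kind of pattern the 1-shot panels are expected to show.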
For few-shot scenarios, the model does two things. It can predict the targets correctly whenever the query input is proportional to any of the inputs of the previous demonstrations (Figure 5 row 1). Additionally, and more importantly, it can correctly solve substantially more than just these cases (Figure 5 row 1,3). This results from the model’s ability to combine in-context examples and use them for prediction. Evidence for this can be seen in the proliferation of red points in row 3 of Figure 5 as more examples are given in-context. The shallower model is inferior to its deeper counterpart in this latter skill, hinting at an imperfect implementation of the algorithm and an only partially correct choice of the set of constants.
5.2 Attention Heads Implement Essential Skills
![Refer to caption](x10.png)
In the previous section, we found “black-box” evidence suggesting that the model is implementing an algorithm akin to solving Equation 2. Now, we turn to “white-box” interpretability to identify the components within the transformer model that are responsible for the essential skills outlined at the beginning of Section 5. In Figure 6(a, b) we analyze the important attention heads in a model that generalizes o.o.d. We compare them with the heads from a model that does not generalize o.o.d. in Figure 6(c, d).
In Figure 6(a), we show the attention head from layer 1 that implements skill I. In the top panel, we see that each query only pays attention to itself and the two preceding keys. This pattern likely stems from the fact that each example in the sequence contains three tokens $(x, y, z)$, and suggests that this head is mostly focused on the information within the example.
In the bottom panel, we perform principal component analysis (PCA) on the outputs of this head. Specifically, we feed a batched k-shot sequence of the form , where the first inputs are fixed and the last input is scanned over all possible pairs. We concatenate the resulting features from and , resulting in batch of features, and perform PCA on this matrix. We project all these dimensional features onto the first two principal components. Annotating the pairs $(x, y)$ with $(\log_{27} x, \log_{27} y)$ (we use 27 as the base of the logarithm, which is a primitive root of $p = 29$), we find a “clock-of-clocks”, where clocks of period 28 are themselves arranged in a bigger clock of period 28. The number 0 is located at the center of the clocks (this is expected since $\log_{27} 0$ is mathematically ill-defined and needs special treatment). We observe similar clock-of-clocks for concatenated features from and as well. We refer the reader to Appendix E for further details.
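The logarithmic annotation can be reproduced with a small lookup table. The sketch below assumes $p = 29$ and base 27 (the primitive root named in the text) and builds the discrete logarithm of every non-zero residue; the helper names `discrete_log_table` and `clock_angle` are hypothetical. Zero has no discrete logarithm, which is why it needs separate treatment.

```python
import math


def discrete_log_table(base=27, p=29):
    """Map each non-zero residue v to the exponent k with base**k = v (mod p)."""
    table = {}
    value = 1
    for k in range(p - 1):
        table[value] = k
        value = (value * base) % p
    return table


def clock_angle(v, table, p=29):
    """Angle of residue v on a clock of period p - 1, or None for v = 0."""
    if v % p == 0:
        return None  # log of 0 is undefined
    return 2 * math.pi * table[v % p] / (p - 1)


if __name__ == "__main__":
    table = discrete_log_table()
    # A primitive root reaches all p - 1 non-zero residues.
    print(len(table))
    print(table[27], clock_angle(27, table))
```

Under this labelling, multiplication modulo $p$ becomes addition of exponents modulo $p - 1 = 28$, which is consistent with clocks of period 28.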
In Figure 6(b), we analyze the head from layer 2 that effectively implements skill II; it does so by building upon skill I and leveraging the features discussed above. The upper panel shows a highly structured attention map that focuses on the current example as well as the examples preceding it. This pattern aligns with the step in our proposed algorithm where the model needs to compare pairs across different examples in-context and selectively re-scale some of them for later combination. Additionally, as the number of in-context demonstrations increases, the attention map indicates that the model begins to focus on multiple examples simultaneously. This suggests that the model might be employing a weighted-average approach over these examples, as discussed earlier.
By conducting a PCA analysis similar to that in Figure 6(a), we also identify clocks when annotating examples in the format. Unlike the previously discussed pattern, the specifics of this "clock-of-clocks" arrangement vary depending on the position and the choice of task vector . This variability suggests that the head in question is utilizing information about the specific task from the context to appropriately rescale the equations.
We further highlight the importance of the signatures found in the above paragraphs via comparison with a model that does not generalize o.o.d., shown in Figure 6(c, d). Note that as the number of pre-training tasks decreases, the attention map corresponding to the one in panel (a) becomes mosaicked, and at the same time the PCA projections lose their shape. As a result, those models also lose the ability to perform ICL on modular arithmetic out-of-distribution.
From Figure 5, we infer that the shallower model has only partially acquired skill III, due to its limited capacity. On the other hand, the deeper model is very good at combining equations via skill III, explaining its superior performance.
6 Discussion
We have investigated the emergence of ICL and skill composition in AMs on a novel algorithmic dataset. The dataset includes a large discrete set of modular arithmetic tasks and was specifically designed to force models to learn how to solve a variety of tasks. It consists of learning linear modular functions, where the model is expected to identify and perform a modular operation in-context. When the number of tasks is relatively small, the models simply memorize the task vectors and classify the input vectors by the task vectors they have memorized. In this case, the models still have to learn how to perform the memorized tasks, because they do generalize in-distribution and develop ICL capabilities. Once the number of training tasks becomes too large, the models transition to a qualitatively different, algorithmic approach, where the task vector is determined at inference time.
Finally, we have investigated the learnt representations and showed that qualitatively different circuits are formed in different phases.
Limitations
We have limited ourselves to algorithmic datasets. Without further investigation, it is difficult to say which lessons can be transferred to realistic language models, and which lessons are specific to the current setting. Mechanistic interpretability analysis of deeper models proved to be much more difficult than that of smaller models. Consequently, we still do not understand the role of every part of the network in the deeper cases.
Acknowledgments
T.H. thanks Yue Xu and Dayal Singh Kalra for helpful discussions. A.G.’s work at the University of Maryland was supported in part by NSF CAREER Award DMR-2045181, Sloan Foundation and the Laboratory for Physical Sciences through the Condensed Matter Theory Center. The authors acknowledge the University of Maryland supercomputing resources (http://hpcc.umd.edu) made available for conducting the research reported in this paper.
References
- Ahn et al. [2023] Kwangjun Ahn, Xiang Cheng, Hadi Daneshmand, and Suvrit Sra. Transformers learn to implement preconditioned gradient descent for in-context learning. In A. Oh, T. Neumann, A. Globerson, K. Saenko, M. Hardt, and S. Levine, editors, Advances in Neural Information Processing Systems, volume 36, pages 45614–45650. Curran Associates, Inc., 2023. URL https://proceedings.neurips.cc/paper_files/paper/2023/file/8ed3d610ea4b68e7afb30ea7d01422c6-Paper-Conference.pdf.
- Akyürek et al. [2023] Ekin Akyürek, Dale Schuurmans, Jacob Andreas, Tengyu Ma, and Denny Zhou. What learning algorithm is in-context learning? Investigations with linear models, 2023. URL https://openreview.net/forum?id=0g0X4H8yN4I.
- Arora and Goyal [2023] Sanjeev Arora and Anirudh Goyal. A theory for emergence of complex skills in language models. arXiv preprint arXiv:2307.15936, 2023.
- Boix-Adserà et al. [2024] Enric Boix-Adserà, Omid Saremi, Emmanuel Abbe, Samy Bengio, Etai Littwin, and Joshua M. Susskind. When can transformers reason with abstract symbols? In The Twelfth International Conference on Learning Representations, 2024. URL https://openreview.net/forum?id=STUGfUz8ob.
- Brown et al. [2020] Tom Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared D Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, Sandhini Agarwal, Ariel Herbert-Voss, Gretchen Krueger, Tom Henighan, Rewon Child, Aditya Ramesh, Daniel Ziegler, Jeffrey Wu, Clemens Winter, Chris Hesse, Mark Chen, Eric Sigler, Mateusz Litwin, Scott Gray, Benjamin Chess, Jack Clark, Christopher Berner, Sam McCandlish, Alec Radford, Ilya Sutskever, and Dario Amodei. Language models are few-shot learners. In H. Larochelle, M. Ranzato, R. Hadsell, M.F. Balcan, and H. Lin, editors, Advances in Neural Information Processing Systems, volume 33, pages 1877–1901. Curran Associates, Inc., 2020. URL https://proceedings.neurips.cc/paper_files/paper/2020/file/1457c0d6bfcb4967418bfb8ac142f64a-Paper.pdf.
- Doshi et al. [2023] Darshil Doshi, Aritra Das, Tianyu He, and Andrey Gromov. To grok or not to grok: Disentangling generalization and memorization on corrupted algorithmic datasets. arXiv preprint arXiv:2310.13061, 2023.
- Doshi et al. [2024] Darshil Doshi, Tianyu He, Aritra Das, and Andrey Gromov. On learning modular polynomials. In ICLR 2024 Workshop on Bridging the Gap Between Practice and Theory in Deep Learning, 2024. URL https://openreview.net/forum?id=QO0y9ysrgu.
- Elhage et al. [2021] Nelson Elhage, Neel Nanda, Catherine Olsson, Tom Henighan, Nicholas Joseph, Ben Mann, Amanda Askell, Yuntao Bai, Anna Chen, Tom Conerly, Nova DasSarma, Dawn Drain, Deep Ganguli, Zac Hatfield-Dodds, Danny Hernandez, Andy Jones, Jackson Kernion, Liane Lovitt, Kamal Ndousse, Dario Amodei, Tom Brown, Jack Clark, Jared Kaplan, Sam McCandlish, and Chris Olah. A mathematical framework for transformer circuits. Transformer Circuits Thread, 2021. https://transformer-circuits.pub/2021/framework/index.html.
- Ganguli et al. [2022] Deep Ganguli, Danny Hernandez, Liane Lovitt, Amanda Askell, Yuntao Bai, Anna Chen, Tom Conerly, Nova Dassarma, Dawn Drain, Nelson Elhage, et al. Predictability and surprise in large generative models. In Proceedings of the 2022 ACM Conference on Fairness, Accountability, and Transparency, pages 1747–1764, 2022.
- Garg et al. [2023] Shivam Garg, Dimitris Tsipras, Percy Liang, and Gregory Valiant. What can transformers learn in-context? a case study of simple function classes. 2023.
- Gromov [2023] Andrey Gromov. Grokking modular arithmetic, 2023. URL https://arxiv.org/abs/2301.02679.
- Gu et al. [2024] Jiuxiang Gu, Chenyang Li, Yingyu Liang, Zhenmei Shi, Zhao Song, and Tianyi Zhou. Fourier circuits in neural networks: Unlocking the potential of large language models in mathematical reasoning and modular arithmetic, 2024.
- Guo et al. [2024] Tianyu Guo, Wei Hu, Song Mei, Huan Wang, Caiming Xiong, Silvio Savarese, and Yu Bai. How do transformers learn in-context beyond simple functions? a case study on learning with representations. In The Twelfth International Conference on Learning Representations, 2024. URL https://openreview.net/forum?id=ikwEDva1JZ.
- Hendel et al. [2023] Roee Hendel, Mor Geva, and Amir Globerson. In-context learning creates task vectors, 2023.
- Kirsch et al. [2024] Louis Kirsch, James Harrison, Jascha Sohl-Dickstein, and Luke Metz. General-purpose in-context learning by meta-learning transformers, 2024.
- Lin and Lee [2024] Ziqian Lin and Kangwook Lee. Dual operating modes of in-context learning, 2024.
- Liu et al. [2024] Sheng Liu, Haotian Ye, Lei Xing, and James Zou. In-context vectors: Making in context learning more effective and controllable through latent space steering, 2024.
- Liu et al. [2022] Ziming Liu, Ouail Kitouni, Niklas S Nolte, Eric Michaud, Max Tegmark, and Mike Williams. Towards understanding grokking: An effective theory of representation learning. Advances in Neural Information Processing Systems, 35:34651–34663, 2022.
- Nanda et al. [2023] Neel Nanda, Lawrence Chan, Tom Lieberum, Jess Smith, and Jacob Steinhardt. Progress measures for grokking via mechanistic interpretability. In The Eleventh International Conference on Learning Representations, 2023. URL https://openreview.net/forum?id=9XFSbDPmdW.
- Nichani et al. [2024] Eshaan Nichani, Alex Damian, and Jason D. Lee. How transformers learn causal structure with gradient descent, 2024.
- Olsson et al. [2022] Catherine Olsson, Nelson Elhage, Neel Nanda, Nicholas Joseph, Nova DasSarma, Tom Henighan, Ben Mann, Amanda Askell, Yuntao Bai, Anna Chen, Tom Conerly, Dawn Drain, Deep Ganguli, Zac Hatfield-Dodds, Danny Hernandez, Scott Johnston, Andy Jones, Jackson Kernion, Liane Lovitt, Kamal Ndousse, Dario Amodei, Tom Brown, Jack Clark, Jared Kaplan, Sam McCandlish, and Chris Olah. In-context learning and induction heads, 2022.
- Paszke et al. [2019] Adam Paszke, Sam Gross, Francisco Massa, Adam Lerer, James Bradbury, Gregory Chanan, Trevor Killeen, Zeming Lin, Natalia Gimelshein, Luca Antiga, Alban Desmaison, Andreas Köpf, Edward Yang, Zach DeVito, Martin Raison, Alykhan Tejani, Sasank Chilamkurthy, Benoit Steiner, Lu Fang, Junjie Bai, and Soumith Chintala. PyTorch: An imperative style, high-performance deep learning library, 2019.
- Power et al. [2022] Alethea Power, Yuri Burda, Harri Edwards, Igor Babuschkin, and Vedant Misra. Grokking: Generalization beyond overfitting on small algorithmic datasets. arXiv preprint arXiv:2201.02177, 2022.
- Raventos et al. [2023] Allan Raventos, Mansheej Paul, Feng Chen, and Surya Ganguli. Pretraining task diversity and the emergence of non-bayesian in-context learning for regression. In Thirty-seventh Conference on Neural Information Processing Systems, 2023. URL https://openreview.net/forum?id=BtAz4a5xDg.
- Singh et al. [2023] Aaditya K Singh, Stephanie C.Y. Chan, Ted Moskovitz, Erin Grant, Andrew M Saxe, and Felix Hill. The transient nature of emergent in-context learning in transformers. In Thirty-seventh Conference on Neural Information Processing Systems, 2023. URL https://openreview.net/forum?id=Of0GBzow8P.
- Su et al. [2023] Jianlin Su, Yu Lu, Shengfeng Pan, Ahmed Murtadha, Bo Wen, and Yunfeng Liu. RoFormer: Enhanced transformer with rotary position embedding, 2023.
- von Oswald et al. [2023] Johannes von Oswald, Eyvind Niklasson, Ettore Randazzo, João Sacramento, Alexander Mordvintsev, Andrey Zhmoginov, and Max Vladymyrov. Transformers learn in-context by gradient descent, 2023.
- Wei et al. [2022] Jason Wei, Yi Tay, Rishi Bommasani, Colin Raffel, Barret Zoph, Sebastian Borgeaud, Dani Yogatama, Maarten Bosma, Denny Zhou, Donald Metzler, et al. Emergent abilities of large language models. arXiv preprint arXiv:2206.07682, 2022.
- Zhong et al. [2023] Ziqian Zhong, Ziming Liu, Max Tegmark, and Jacob Andreas. The clock and the pizza: Two stories in mechanistic explanation of neural networks. In Thirty-seventh Conference on Neural Information Processing Systems, 2023. URL https://openreview.net/forum?id=S5wmbQc1We.
Appendix A Experimental Details
A.1 Model and Training Hyperparameters
Architecture
We used GPT-like architectures with Rotary Positional Embedding () and ReLU activations. We fix the number of heads , embedding dimension and MLP widening factor throughout every model. We use throughout the paper. Embedding layers and the output layer are tied via weight tying.
Initialization
All linear-layer and embedding-layer weights are sampled from a Gaussian distribution at initialization, with the exception that the last linear layer in each MLP is sampled from . No bias is used in any layer.
Optimization and Schedule
We trained most models using the AdamW optimizer with learning rate , weight decay , , , , batch size , in-context examples for k steps, together with a linear warmup starting from and a cosine annealing towards the end to . Weight decay is not applied to LayerNorm layers.
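The warmup-then-cosine schedule described above can be sketched in plain Python. This is an illustration only: the function name `lr_at_step` and all of its arguments are hypothetical, and no numeric value here is taken from our experiments.

```python
import math


def lr_at_step(step, total_steps, warmup_steps, peak_lr, start_lr, final_lr):
    """Learning rate at a given step: linear warmup, then cosine annealing.

    All arguments are hypothetical placeholders for illustration.
    """
    if step < warmup_steps:
        # Linear interpolation from start_lr up to peak_lr.
        frac = step / max(1, warmup_steps)
        return start_lr + frac * (peak_lr - start_lr)
    # Fraction of the post-warmup phase completed, clipped to [0, 1].
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, progress)
    # Cosine factor decays smoothly from 1 (at peak_lr) to 0 (at final_lr).
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return final_lr + cosine * (peak_lr - final_lr)
```

A function of this form can be evaluated once per optimizer step and written into the optimizer's parameter groups. Excluding LayerNorm parameters from weight decay would be handled separately, for example by placing them in a parameter group with zero weight decay.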
Hyperparameter Choice
For models we scanned learning rates and weight decay values . Then we transfer our hyperparameters to other depths. Benefiting from the extra scaling in the initialization of the last linear layer in each MLP, we find that the hyperparameters perform well for other depths. For larger values, we lower the learning rate to .
A.2 Further Details of Each Plot in the Main Text
Figure 1 (b) Phase diagram constructed using the data shown in Figure 18; the threshold for defining each phase is set to for the corresponding set. (c) Accuracy versus number-of-shots curve for , , and . (d, e) model with , and .
Figure 3 (a) Selected the best out of three random seeds with early stopping. (b, c) Averaged over three random seeds, with standard error shown. All o.o.d. data are measured every step for randomly sampled sequences during pre-training. (d, e) Used the checkpoint at the end of pre-training, averaged over 128 random sequences sampled from .
Figure 4 (a, b) Both phase diagrams show the best of three random seeds with early stopping. The loss versus number-of-shots curves are plotted with .
Figure 5 We use a model trained with and , and a model trained with and with batch size .
Figure 6 (a, b) Attention heads extracted from a model trained with and . (c, d) Both models are trained with .
Appendix B Details of resources used
To generate each phase diagram and training curve, we used GPU days on NVIDIA A100 40GB, with automatic mixed precision (BFloat16) and Flash Attention implemented in the PyTorch library [22]. During the exploration stage, we also used around another GPU days. Most inference was run on one-seventh of an NVIDIA A100 40GB GPU, the cost of which is negligible compared to the pre-training cost.
Appendix C Table of log map for
Appendix D Task Selection and Sequence Design
Task Selection
During the initial exploration phase, we observed that it was challenging for the model to learn multiple modular arithmetic tasks simultaneously. Typically, the loss would stagnate at a plateau indefinitely.
We hypothesize that when the model is trained on multiple modular arithmetic tasks, the strong bias inherent in each individual task may interfere significantly. When tasks are selected randomly, the resultant noise appears to inhibit the learning capabilities of the model, preventing it from acquiring any meaningful patterns or rules from the data.
Ultimately, we adopted a more structured approach to task selection, as illustrated in Figure 2, based on the following rationale: if the model is to learn a specific task vector , it would presumably be more straightforward if one of the two components (either or ) has already been learned by the model. This would leave the model with only the other component to decipher, which we assume is a comparatively simpler task. Thus, we decided to employ the rectangular rule for sampling task vectors, which we believe reduces the complexity of the learning process by partially leveraging previously acquired knowledge.
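One simple way to obtain a task collection in which every task shares each of its components with other tasks is to take a full grid over a set of first components and a set of second components. The sketch below does exactly that and is not necessarily the exact procedure behind Figure 2. The names `sample_rectangular_tasks` and `is_rectangular`, the exclusion of zero components, and the use of a prime modulus `p` are assumptions made for illustration.

```python
import random
from itertools import product


def sample_rectangular_tasks(p, n_a, n_b, seed=0):
    """Return n_a * n_b task vectors (a, b) forming a full grid.

    Assumption: components are drawn from 1..p-1 (zero is excluded here).
    """
    rng = random.Random(seed)
    a_values = rng.sample(range(1, p), n_a)  # distinct first components
    b_values = rng.sample(range(1, p), n_b)  # distinct second components
    return sorted(product(a_values, b_values))


def is_rectangular(tasks):
    """Check closure: (a1, b1) and (a2, b2) present implies (a1, b2) present."""
    task_set = set(tasks)
    return all((a1, b2) in task_set for a1, _ in tasks for _, b2 in tasks)
```

With a grid like this, any new pair in the collection reuses a component that already appears in other tasks, which is the property the rationale above relies on. A uniformly random collection of the same size would generally fail the `is_rectangular` check.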
In Figure 7, we show an ablation plot of phase diagrams with a randomly selected task vector collection for pre-training. We still use the special sequence design.
![Refer to caption](x11.png)
Sequence Design
Following a similar spirit, we use a balanced batch where sequences generated from all task vectors appear exactly the same number of times during the training. We further align the examples across sequences generated by different task vectors, which we believe reduces the chance of the model getting confused by the same input appearing at different positions within the batch. Without this design, we could not make the model train.
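The balanced, aligned batch described above can be sketched as follows. We assume each task is the linear modular function z = a·x + b·y mod p labeled by the task vector (a, b). The helper name `make_aligned_batch` and its arguments are hypothetical, and the real pipeline may differ in details such as tokenization.

```python
import random


def make_aligned_batch(task_vectors, p, n_shots, seed=0):
    """Build one sequence per task vector, all sharing the same inputs.

    Each sequence is a list of (x, y, z) triples with z = (a*x + b*y) % p.
    Balanced: every task vector contributes exactly one sequence.
    Aligned: the k-th example of every sequence uses the same (x, y).
    """
    rng = random.Random(seed)
    # Sample the inputs once so they are shared across all tasks.
    inputs = [(rng.randrange(p), rng.randrange(p)) for _ in range(n_shots)]
    batch = []
    for a, b in task_vectors:
        sequence = [(x, y, (a * x + b * y) % p) for x, y in inputs]
        batch.append(sequence)
    return batch
```

In this construction, sequences in the same batch differ only in their labels, so a given input always appears at the same position across tasks.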
Here, we also show the training curve per task in Figure 8, trained with all of our tricks. We see that the model first learned very few tasks and then eventually found its way out. For training without the task selection and sequence design, the loss typically plateaus around .
![Refer to caption](x12.png)
![Refer to caption](x13.png)
![Refer to caption](x14.png)
![Refer to caption](x15.png)
Appendix E Additional Interpretability Results
In this section, we show additional interpretability results. We first show evidence that a model also implements a version of the scale-and-combine algorithm we proposed in Section 5, and that it can do better than the model. Then, we continue our discussion of how the algorithm might be implemented inside the model. This includes the role of the embedding (Section E.1), the role of other attention heads (Section E.2), and finally, the role of the MLP and LayerNorm (Section E.3).
E.1 PCA over Embeddings
We begin the discussion with the role of the embedding layers. By further examining different models, we find highly structured embeddings in models. models, on the other hand, do not have such structured embedding layers.
First, we focus on the embedding of models. As shown in Figure 9, the logarithm of each number is clearly split into even and odd groups, and each group forms one clock, which is a suitable embedding for doing modular multiplications. However, one should note that this does not obviate the importance of the head shown in Figure 6(a), as the model still needs a way to distinguish the same number appearing at different positions in the context.
Curiously, we could not find such a structured embedding for the model, as shown in Figure 10. However, as we will show in the next subsection, this unstructured embedding, together with the first layer, prepares a foundation for the essential heads in the later layer to create a similar “clock-of-clocks” feature, as shown in Figure 6 (a, b).
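To make the role of the logarithm concrete, the sketch below builds a discrete-log table modulo a small prime. It assumes that "logarithm" here means the discrete logarithm with respect to a primitive root g modulo a prime p. The function names are hypothetical, and the brute-force search is only meant for small moduli.

```python
def primitive_root(p):
    """Smallest primitive root modulo an odd prime p (brute force)."""
    for g in range(2, p):
        seen = set()
        value = 1
        for _ in range(p - 1):
            value = (value * g) % p
            seen.add(value)
        # g is a primitive root if its powers hit every nonzero residue.
        if len(seen) == p - 1:
            return g
    raise ValueError("no primitive root found; is p an odd prime?")


def log_table(p, g):
    """Map each nonzero residue x to the exponent k with g**k % p == x."""
    table = {}
    value = 1
    for k in range(p - 1):
        table[value] = k
        value = (value * g) % p
    return table
```

In log space, multiplication modulo p becomes addition modulo p − 1, i.e. log(x·y) = log(x) + log(y) mod (p − 1), which is why a clock-like arrangement of logarithms suits modular multiplication. Residues with even logarithm are exactly the quadratic residues modulo p, which is one way to describe the even and odd groups mathematically. Whether the model uses this fact is not something the sketch establishes.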
![Refer to caption](x16.png)
![Refer to caption](x17.png)
![Refer to caption](x18.png)
![Refer to caption](x19.png)
![Refer to caption](x20.png)
![Refer to caption](x21.png)
E.2 Attention Heads
![Refer to caption](x22.png)
![Refer to caption](x23.png)
![Refer to caption](x24.png)
![Refer to caption](x25.png)
![Refer to caption](x26.png)
![Refer to caption](x27.png)
E.2.1 Model
To continue the story, we analyse the attention heads in different models. We first study the model, where we also find a similar head with “clock-of-clocks” (Figure 6(a)). Moreover, two other heads put together are seemingly equivalent to Figure 6(b). We surmise that any model that solves modular arithmetic in context requires such heads.
From Figure 11(a), we see that the head still pays attention locally within three token positions. Importantly, it also creates a clock-of-clocks when performing PCA over concatenated features. The difference here is that the clock winds twice to go back to its origin. We believe that this factor-of-two difference in the period means that this head is effectively combining the role of the embedding layer of the model and the role of the head in Figure 6(a). Overall, we do not yet know why the model needs to split the logarithm of numbers into even and odd groups.
From Figure 11(b, c), we find two heads that are structured but different from Figure 6(b). The PCA pattern forms clocks when annotated in logarithm space, but depends on the choice of the sequences. This hints that those two heads are trying to re-scale the examples by comparing them with the previous inputs. We think that the two can be combined to form a pattern similar to Figure 6(b).
E.2.2 Model
Next, we focus on the model. In Figure 12, we plot a PCA similar to the one in Figure 6(a), but with different concatenated features from the same heads: and . We again see clock-of-clocks, solidifying our argument that this head provides proper spaces for later heads to perform modular operations.
![Refer to caption](x28.png)
![Refer to caption](x29.png)
Finally, in Figure 13, we show all the attention patterns we had for models. The corresponding PCA over concatenated features is shown in Figure 14 (except for layer 1 head 4, for which we plot for ). The special head we choose here behaves similarly to the one in Figure 6(a), with its main focus shifted by one position but still within two preceding tokens. We believe this head implements another variant of Equation 2: it re-scales the equations for the next layer and helps later heads work out how to combine them to predict new examples.
We have reasons to believe that the non-clock heads are not essential to the models’ performance. However, we leave a careful exploration of this front for future work.
![Refer to caption](x30.png)
![Refer to caption](x31.png)
![Refer to caption](x32.png)
![Refer to caption](x33.png)
![Refer to caption](x34.png)
![Refer to caption](x35.png)
![Refer to caption](x36.png)
![Refer to caption](x37.png)
![Refer to caption](x38.png)
![Refer to caption](x39.png)
![Refer to caption](x40.png)
![Refer to caption](x41.png)
![Refer to caption](x42.png)
![Refer to caption](x43.png)
![Refer to caption](x44.png)
![Refer to caption](x45.png)
E.3 MLP and LayerNorm
The arguments in Section 5 and the experiments in the previous subsection strongly suggest that skill III is implemented within the MLP layer. We searched the MLP layers for signals similar to those in Nanda et al. [19] and Gromov [11], but did not find conclusive evidence. Similarly, no obvious signals were found in LayerNorm. We think it is crucial to study these layers more carefully, and we leave this task for future work.
E.4 Label Noise
To gain insight into how the model combines the in-context examples, we introduce label corruption in the in-context examples. In particular, we note the effect of (i) the amount and (ii) the position of label corruption on the model’s performance. When we corrupt a single in-context example for the model, the model’s performance remains unaffected for longer sequences. This hints that a weighted average of the in-context inputs is used in the model’s prediction. The model, however, did not show such resilience.
Next, we corrupt multiple in-context examples in random locations. We study the effect on model performance as the amount of corrupted labels increases. While the model is easily overwhelmed, the model is able to offer strong resistance even at label corruption, for long sequences. This behavior remains invariant with the change in task vector for the particular sequence, indicating the universality of the underlying algorithm necessary for o.o.d. generalization.
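The label-corruption procedure can be sketched as below. This is a minimal illustration, assuming that in-context examples are stored as (x, y, z) triples with labels in 0..p−1. The helper name `corrupt_labels` and the choice of replacing a label with a uniformly random wrong value are assumptions, not a description of the exact experimental code.

```python
import random


def corrupt_labels(sequence, p, n_corrupt, seed=0):
    """Replace the labels of n_corrupt randomly chosen in-context examples.

    Returns the corrupted copy and the sorted list of corrupted positions.
    The inputs (x, y) are left untouched; only z is changed.
    """
    rng = random.Random(seed)
    positions = rng.sample(range(len(sequence)), n_corrupt)
    corrupted = list(sequence)
    for i in positions:
        x, y, z = corrupted[i]
        # A shift in 1..p-1 guarantees the new label differs from the old one.
        wrong = (z + rng.randrange(1, p)) % p
        corrupted[i] = (x, y, wrong)
    return corrupted, sorted(positions)
```

Sweeping `n_corrupt` probes the effect of the amount of corruption, while fixing `n_corrupt = 1` and choosing the position by hand probes the effect of position.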
![Refer to caption](x46.png)
![Refer to caption](x47.png)
![Refer to caption](x48.png)
Appendix F Additional Training Curves
We plot some selected training curves for (Figure 16) and (Figure 17) from the Figure 4 phase diagrams. We see that even for , ICL can be transient. With increased or , the transient nature goes away.
![Refer to caption](x49.png)
![Refer to caption](x50.png)
![Refer to caption](x51.png)
![Refer to caption](x52.png)
![Refer to caption](x53.png)
![Refer to caption](x54.png)
![Refer to caption](x55.png)
![Refer to caption](x56.png)
Appendix G Additional Phase Diagrams
In Figure 18, we plot detailed and extended versions of the phase diagrams shown in Figure 1 and Figure 4. The four-phase picture shown in Figure 1 still holds for other depths.
![Refer to caption](x57.png)
![Refer to caption](x58.png)
![Refer to caption](x59.png)
Appendix H Different Choice of
In this section, we check the effect of varying the task difficulty, i.e., the value of . In Figure 19, we plot the o.o.d. generalization accuracy. Clearly, as the task gets harder, the model needs to see more tasks to generalize out-of-distribution.
![Refer to caption](x60.png)
![Refer to caption](x61.png)