Learning to grok: Emergence of in-context learning and skill composition in modular arithmetic tasks

Tianyu He a, b, †
&Darshil Doshi a, b, *
&Aritra Das b, *
&Andrey Gromov c, a, b    
Abstract

Large language models can solve tasks that were not present in the training set. This capability is believed to be due to in-context learning and skill composition. In this work, we study the emergence of in-context learning and skill composition in a collection of modular arithmetic tasks. Specifically, we consider a finite collection of linear modular functions $z = a\,x + b\,y \;\mathrm{mod}\; p$ labeled by the vector $(a,b) \in \mathbb{Z}_p^2$. We use some of these tasks for pre-training and the rest for out-of-distribution testing. We empirically show that a GPT-style transformer exhibits a transition from in-distribution to out-of-distribution generalization as the number of pre-training tasks increases. We find that the smallest model capable of out-of-distribution generalization requires two transformer blocks, while for deeper models, the out-of-distribution generalization phase is transient, necessitating early stopping. Finally, we perform an interpretability study of the pre-trained models, revealing the highly structured representations in both phases; and discuss the learnt algorithm.

a Condensed Matter Theory Center, University of Maryland, College Park
b Department of Physics, University of Maryland, College Park
c Meta AI
† Corresponding author
* These authors contributed equally
Source Code: https://github.com/ablghtianyi/ICL_Modular_Arithmetic

{tianyuh, ddoshi, aritrad}@umd.edu                     [email protected]

1 Introduction

Large language models (LLMs) can perform simple tasks that were not present in the training set. This ability is usually achieved via in-context learning [5]. More importantly, LLMs can perform an even larger variety of very complex tasks upon appropriate prompting or fine-tuning. This latter ability is usually attributed to the following mechanism: LLMs first learn a large variety of simple tasks and then learn how to compose those skills into very complex ones [3]. Furthermore, LLMs also exhibit “emergent capabilities”: the sudden appearance of a new complex skill as a function of scale (model parameters, compute, or data) [28, 9]. It is plausible that these sudden performance improvements are due to one or both of these mechanisms. For example, LLMs show grokking on algorithmic tasks [23], which results from the model learning very structured representations [18, 11, 19]. Once these representations emerge, the model abruptly learns how to perform the task.

In this work, we set out to examine skill composition both empirically and mechanistically. Inspired by prior work that investigated the emergence of in-context learning on linear regression tasks [1, 24], we introduce a finite collection of discrete modular arithmetic tasks [23] generalized to the in-context learning setting. Each task corresponds to learning a linear function $z = a\,x + b\,y \;\mathrm{mod}\; p$ over $\mathbb{Z}_p$ from the examples provided in the context of the autoregressive model (AM). In the bi-variate case there are $p^2$ such functions labeled by the vector $(a,b)$. The main objective of this algorithmic dataset is to probe how the AM utilizes the tasks it has learnt during training to solve new tasks.

Our analysis shows that the solution found by the AM after optimization is qualitatively different from the linear regression cases studied before [1]. In those cases, due to the continuous nature of the task, the AM develops an emergent first-order optimization method that minimizes an emergent quadratic loss function. Furthermore, as shown in [1], a single linear attention layer can solve the regression problem, while adding extra layers and non-linearities slightly modifies the gradient descent. In the modular arithmetic case, the AM first learns how to solve the pre-training tasks and later (assuming enough different tasks) develops a generalizing solution by combining the solved tasks.

Our main findings as well as the structure of the algorithmic dataset are illustrated in Fig. 1. Our main findings are: (i) There are four different phases at the end of pre-training, depending on the number of tasks, $n_{\mathrm{i.d.}}$, and the number of examples per task, $\alpha$. (ii) At inference time, there is a generalization transition in the number of few-shot examples: as the number of examples grows, the model starts to generalize. This effect is somewhat similar to the transition in sample complexity for modular arithmetic found in [23]. (iii) The model develops a striking circular representation for all of the tasks that naturally generalizes the circular representations found in the original work [23]. We further find that deeper models are easier to optimize, but much harder to interpret. The optimization is discussed in more detail in the main text. Here we highlight that optimization for these tasks is challenging and the AM tends to prefer minima that just solve a handful of tasks and memorize the training set. To avoid such minima we make sure that every task appears equally often in every batch (meaning that no task is over- or under-represented in any batch). We further find that for larger models early stopping is necessary because the generalizing solution is transient.

Figure 1: (a) The dataset. The tasks are labeled by vectors $(a,b) \in \mathbb{Z}_p^2$. Each table contains examples of $ax + by \;\mathrm{mod}\; p$. A fraction $1-\alpha$ of the examples is blacked out. The remaining examples are flattened into a single “document” in the batch. Each document is organized as a collection of triples $(x, y, ax+by)$ for $x, y$ from the training set (i.e. not blacked out in the table). To leverage training compute more effectively we predict not just the last element, but also other elements colored in red. This is similar to the traditional next-token prediction, with the main difference that we predict each third token ($x$ and $y$ do not have any correlation between each other). Every task appears exactly the same number of times in each batch. (b) Phase diagram for a six-layer model. We find four different phases. (1) in-distribution memorization: model does not generalize on $x, y$ it has not seen for the training and test tasks. (2) in-distribution generalization: model generalizes on $x, y$ it has not seen for the training tasks, but not for test tasks. (3) out-of-distribution memorization: model does not generalize on $x, y$ it has not seen for test tasks, but completely generalizes on train tasks. (4) out-of-distribution generalization: model generalizes on all $x, y$ for both train and test tasks. We focus on investigating phase (4) in more detail. (c) In-context sample complexity. Accuracy of the model in phase (4) as a function of the number of few-shot examples. (d) Representations developed by one of the attention heads in the first layer. These are projections of the embedding of a pair of numbers onto the two largest principal components (PCs) of the internal representation formed after passing through the attention layer and projection matrix. (e) The first 3 principal components of the embeddings partition the $\log_{27}$-annotated numbers into separate planes of even and odd groups. The number 0 is special, and is sandwiched between the two planes.

We organize our paper as follows. Section 2 contains the literature review. In Section 3 we explain our notation and discuss the experimental details. In Section 4 we demonstrate empirically that the out-of-distribution ICL ability emerges as the number of training tasks increases. We also study the effects of model depth and task difficulty. In Section 5 we carefully examine a minimal setting, i.e. a two-block transformer: we compare the representations learnt in the four different phases and show that in the generalizing phase the representations are highly structured and generalize the original modular addition case of [23].

2 Related Works

In-Context Learning (ICL)

Brown et al. [5] first demonstrated that the performance of large models improves substantially when a few examples of the task at hand are provided at inference time, in the prompt. Akyürek et al. [2], Ahn et al. [1], von Oswald et al. [27] showed that the AM implements emergent first-order optimization on an emergent objective function to solve linear regression tasks. Furthermore, [2] showed that larger models learn to perform Bayesian estimation. Garg et al. [10] demonstrated that transformers can learn several simple classes of functions in context. Kirsch et al. [15] showed how task diversity and model size affect the ICL performance on unseen tasks using a mixture of modified MNIST datasets. Raventos et al. [24] investigated the relation between task diversity and out-of-distribution ICL ability on linear regression tasks. Lin and Lee [16] identified two operating modes in ICL using a mixture of linear regression tasks, where for the first several shots, the model tries to figure out the correct task vector and later uses it to predict the correct results. Boix-Adserà et al. [4] showed theoretically and experimentally that with enough pre-training data, a transformer model can perform abstract reasoning that an MLP cannot do. Guo et al. [13] showed that transformers can use lower layers to memorize and upper layers to perform ICL in a feature regression setting. It was found in [25] that ICL is a transient phase from the optimization point of view: it goes away once the model is over-trained. Hendel et al. [14], Liu et al. [17] showed that language models form in-context vectors, which can be extracted and used to control model predictions.

Modular Arithmetic

Power et al. [23] discovered grokking, where models trained on modular arithmetic datasets exhibit an abrupt change from random guessing to generalization on the test set long after the model has memorized the training set. Gromov [11], Nanda et al. [19], Gu et al. [12] showed that for modular addition tasks, models learn to map integers to Fourier features to solve modular arithmetic tasks. Liu et al. [18] showed that grokking is related to learning highly structured features, and the grokking transition can be explained by a toy model. Zhong et al. [29] showed that there is more than one algorithm that a model can implement to solve modular addition. Doshi et al. [6] showed that corruption of the label does not prevent the models from finding a generalizing solution. Doshi et al. [7] showed that MLP and transformer models can solve a specific family of modular polynomial tasks by bijectively mapping them to modular addition tasks.

Interpretability

Elhage et al. [8], Olsson et al. [21] showed that transformers can form induction heads that predict the next token in a sequence by identifying and copying patterns from earlier in the sequence. Using several lines of indirect empirical evidence, they argued that those heads might constitute the core mechanism of ICL. Nichani et al. [20] showed theoretically and empirically how disentangled transformers learn causal structures from in-context Markov chains by forming induction heads.

3 Preliminaries

Linear Modular Functions

We consider modular arithmetic tasks of the form: $z^t_i = a^t x_i + b^t y_i \;\mathrm{mod}\; p$. We will refer to the coefficients $(a^t, b^t) \in \mathbb{Z}_p^2$ as the task vector. The superscript $t \in [1, p^2]$ labels the $p^2$ possible tasks. We will refer to $(x_i, y_i) \in \mathbb{Z}_p^2$ as the input vector, which is labeled by the subscript $i \in [1, p^2]$.
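As a concrete illustration of these definitions, the following minimal Python sketch enumerates the task vectors and input vectors and evaluates one label. The names `linear_mod`, `task_table`, `all_tasks` and `all_inputs` are ours and purely illustrative; they are not taken from the released source code.

```python
from itertools import product

P = 29  # prime modulus; the default value used in this paper


def linear_mod(a, b, x, y, p=P):
    """Evaluate z = a*x + b*y mod p for task vector (a, b) and input vector (x, y)."""
    return (a * x + b * y) % p


# Both the task vectors (a, b) and the input vectors (x, y) range over Z_p^2,
# so there are p^2 of each.
all_tasks = list(product(range(P), repeat=2))
all_inputs = list(product(range(P), repeat=2))


def task_table(a, b, p=P):
    """Full p x p label table of one task: table[x][y] = a*x + b*y mod p."""
    return [[linear_mod(a, b, x, y, p) for y in range(p)] for x in range(p)]
```

For instance, the task vector $(2, 3)$ applied to the input $(5, 7)$ gives $2 \cdot 5 + 3 \cdot 7 = 31 \equiv 2 \;\mathrm{mod}\; 29$.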

In-Context Learning with Transformers

We use GPT-like transformers [5] with ReLU activation function and Rotary Positional Embedding (RoPE) [26]. The model has $d$ consecutive blocks, $H$ attention-heads, and embedding dimension $d_{\mathrm{embed}}$. Each number is tokenized as an independent token. The pre-training is done following a slightly modified next-token prediction setup, with sequences of the form:

$$\bm{s}^t = \begin{pmatrix} x_1 & y_1 & {\color{red} z_1^t} & x_2 & y_2 & {\color{red} z_2^t} & \cdots & x_{n_{\mathrm{ctx}}} & y_{n_{\mathrm{ctx}}} & {\color{red} z^t_{n_{\mathrm{ctx}}}} \end{pmatrix} \in \mathbb{Z}_p^{3 \times n_{\mathrm{ctx}}}\,, \tag{1}$$

where $n_{\mathrm{ctx}}$ is the maximum number of in-context examples. The model is asked to predict only the labels $\{z^t_1, \cdots, z^t_{n_{\mathrm{ctx}}}\}$. We emphasize that we do not explicitly provide the task vectors $(a^t, b^t)$ to the model (see Fig. 1) – this information is implicit in the in-context labels $\{z^t_i\}$. In order for the model to generalize, it must determine the underlying task vector $(a^t, b^t)$ from the few-shot examples.
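The sequence layout described above can be sketched as follows. This is a minimal illustration under our own naming (`build_sequence`, `prediction_targets`), not the training code itself: it flattens the triples into a token list and marks every third position as a prediction target.

```python
def build_sequence(task, inputs, p=29):
    """Flatten (x, y, z) triples for one task into tokens plus a label mask.

    `task` is the task vector (a, b); `inputs` is a list of input vectors (x, y).
    The task vector itself never appears among the tokens.
    """
    a, b = task
    tokens, is_label = [], []
    for x, y in inputs:
        tokens += [x, y, (a * x + b * y) % p]
        is_label += [False, False, True]  # only every third token is predicted
    return tokens, is_label


def prediction_targets(tokens, is_label):
    """List of (context, target) pairs: the model sees tokens[:i], predicts tokens[i]."""
    return [(tokens[:i], tokens[i]) for i, flag in enumerate(is_label) if flag]
```

With task vector $(2, 3)$ and inputs $(5, 7), (1, 1)$ this yields the tokens `[5, 7, 2, 1, 1, 5]`, and the two targets are the labels `2` and `5`.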

Generalization

There are two notions of generalization in this setup. (i) In-distribution: Generalization to unseen input vectors $(x_i, y_i)$, but on task vector $(a^t, b^t)$ the model has seen during pre-training. (ii) Out-of-distribution: Generalization to task vectors the model has not seen during pre-training. To clearly separate these regimes, we split the task vectors into in-distribution (i.d.) set $\mathcal{T}_{\mathrm{i.d.}} \coloneqq \{(a^t, b^t)\}_{\mathrm{i.d.}}$ and out-of-distribution (o.o.d.) set $\mathcal{T}_{\mathrm{o.o.d.}} \coloneqq \{(a^t, b^t)\}_{\mathrm{o.o.d.}}$. Similarly, we split the input vectors into train and test sets: $\mathcal{X}_{\mathrm{train}} \coloneqq \{(x_i, y_i)\}_{\mathrm{train}}$, $\mathcal{X}_{\mathrm{test}} \coloneqq \{(x_i, y_i)\}_{\mathrm{test}}$. This results in four distinct sets of sequences constructed from those sets; we name them $S^{\mathrm{i.d.}}_{\mathrm{train}}, S^{\mathrm{i.d.}}_{\mathrm{test}}, S^{\mathrm{o.o.d.}}_{\mathrm{train}}$ and $S^{\mathrm{o.o.d.}}_{\mathrm{test}}$. The set $S^{\mathrm{i.d.}}_{\mathrm{train}}$ is used for pre-training, while the other three sets are used for evaluations.
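The two-way split of tasks and inputs, and the four resulting sequence sets, can be sketched as below. For simplicity this sketch assumes a uniformly random split of the task vectors and interprets $\alpha$ as the fraction of input vectors kept for training; the function and key names are hypothetical.

```python
import random
from itertools import product


def split_tasks_and_inputs(p=29, n_id=256, alpha=0.6, seed=0):
    """Split Z_p^2 twice: tasks into i.d./o.o.d., inputs into train/test.

    Assumption: both splits are uniformly random. `n_id` is the number of
    pre-training tasks; `alpha` is the fraction of inputs used for training.
    """
    rng = random.Random(seed)
    tasks = list(product(range(p), repeat=2))
    inputs = list(product(range(p), repeat=2))
    rng.shuffle(tasks)
    rng.shuffle(inputs)
    n_train = int(alpha * len(inputs))
    return {
        "T_id": tasks[:n_id],
        "T_ood": tasks[n_id:],
        "X_train": inputs[:n_train],
        "X_test": inputs[n_train:],
    }


def sequence_sets(split):
    """Pair each task set with each input set, giving the four sets of the text."""
    return {
        "S_id_train": (split["T_id"], split["X_train"]),   # used for pre-training
        "S_id_test": (split["T_id"], split["X_test"]),     # evaluation only
        "S_ood_train": (split["T_ood"], split["X_train"]), # evaluation only
        "S_ood_test": (split["T_ood"], split["X_test"]),   # evaluation only
    }
```

Sequences for each of the four sets are then built by drawing inputs from the paired input set and labeling them with a task vector from the paired task set.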

Pre-Training Task Selection and Sequence Design

We always sample the pre-training task vectors $\mathcal{T}_{\mathrm{i.d.}}$ in sets of 4, following the rectangular rule, shown in Figure 2(a). Additionally, each batch contains an equal representation from all the task vectors in the set $\mathcal{T}_{\mathrm{i.d.}}$. Moreover, all the tasks share the same sequence of inputs. For example, a batch with two different task vectors $(a^{t_1}, b^{t_1}); (a^{t_2}, b^{t_2})$ and two distinct input sequences per task (resulting in four total sequences) is shown in Figure 2(b).
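One possible reading of the rectangular rule is sketched below. We assume here that each set of 4 consists of the corners $(a_1, b_1), (a_1, b_2), (a_2, b_1), (a_2, b_2)$ of a rectangle in the $(a, b)$ plane, and that rectangles are drawn at random and rejected when they overlap previously chosen tasks. Both points are assumptions made for illustration; the precise rule is the one shown in Figure 2(a) and Appendix D.

```python
import random


def sample_rectangular_tasks(p=29, n_id=64, seed=0, max_tries=100_000):
    """Sample n_id task vectors in groups of 4 that form rectangle corners.

    Assumption: a 'rectangle' is {a1, a2} x {b1, b2} with a1 != a2, b1 != b2.
    """
    if n_id % 4 != 0:
        raise ValueError("n_id must be a multiple of 4")
    rng = random.Random(seed)
    chosen = set()
    for _ in range(max_tries):
        if len(chosen) == n_id:
            break
        a1, a2 = rng.sample(range(p), 2)  # two distinct values of a
        b1, b2 = rng.sample(range(p), 2)  # two distinct values of b
        corners = {(a1, b1), (a1, b2), (a2, b1), (a2, b2)}
        if chosen.isdisjoint(corners):    # reject overlapping rectangles
            chosen |= corners
    if len(chosen) != n_id:
        raise RuntimeError("could not place enough disjoint rectangles")
    return sorted(chosen)
```

Rejection sampling is only a convenient choice for a sketch; it becomes slow when `n_id` approaches $p^2$.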

Figure 2: Structured selection of pre-training tasks and sequences.

This structured approach creates a coherent signal from the sequences within each batch, ensuring that the model learns multiple task vectors with reasonable batch sizes. In contrast, if the batches are sampled i.i.d., the model is confused by the batch noise and cannot learn any tasks.
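The batch structure described in this subsection (every task equally represented, all tasks sharing the same input sequences) can be sketched as follows. The function name `make_batch` and the choice to draw inputs without replacement within a sequence are our assumptions.

```python
import random


def make_batch(tasks, train_inputs, seqs_per_task=2, n_ctx=32, p=29, seed=0):
    """Build a task-balanced batch in which all tasks share the same inputs.

    Returns len(tasks) * seqs_per_task token sequences of length 3 * n_ctx.
    Assumption: inputs within one sequence are drawn without replacement.
    """
    rng = random.Random(seed)
    # Draw the input sequences once; every task reuses exactly these inputs.
    shared_inputs = [rng.sample(train_inputs, n_ctx) for _ in range(seqs_per_task)]
    batch = []
    for a, b in tasks:
        for inputs in shared_inputs:
            seq = []
            for x, y in inputs:
                seq += [x, y, (a * x + b * y) % p]
            batch.append(seq)
    return batch
```

With two task vectors and `seqs_per_task=2` this reproduces the four-sequence example above: sequences belonging to different tasks agree on every $x$ and $y$ token and differ only in the labels.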

Detailed discussions on task selection and sequence design are presented in Appendix D.

Default Setting

Unless stated explicitly, we will use $p = 29$, the number of heads $H = 4$, and embedding dimension $d_{\mathrm{embed}} = 512$, with $n_{\mathrm{ctx}} = 32$ in-context examples. All models are trained with AdamW optimizer with batch size $B = 1024$ for $200$k steps. We have also tied the embedding layer of the model with the readout layer.
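For reference, the default setting can be collected in a small configuration object. The sketch below uses hypothetical field names; the derived quantities assume that the vocabulary contains only the $p$ number tokens and that the embedding dimension is split evenly across heads, neither of which is stated explicitly here.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DefaultConfig:
    p: int = 29                  # prime modulus
    n_heads: int = 4             # H
    d_embed: int = 512           # embedding dimension
    n_ctx: int = 32              # in-context examples per sequence
    batch_size: int = 1024       # B
    train_steps: int = 200_000   # 200k optimizer steps (AdamW)
    tie_embeddings: bool = True  # embedding layer tied with readout layer

    @property
    def seq_len(self) -> int:
        """Tokens per sequence: one (x, y, z) triple per in-context example."""
        return 3 * self.n_ctx

    @property
    def n_tasks(self) -> int:
        """Total number of task vectors (a, b) in Z_p^2."""
        return self.p ** 2

    @property
    def vocab_size(self) -> int:
        """Assumption: one token per residue and no special tokens."""
        return self.p

    @property
    def head_dim(self) -> int:
        """Assumption: d_embed is divided evenly across the attention heads."""
        return self.d_embed // self.n_heads
```

Under these assumptions each sequence has 96 tokens and there are 841 possible tasks in total.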

4 Emergence of In-Context Learning and Task Composition

In this section, we demonstrate that a transformer model with depth $d \geq 2$ can develop ICL and out-of-distribution generalization on modular arithmetic tasks. We delve deeper into the two notions of generalization (i.d. and o.o.d.), and discuss the relevant factors. We find that the model’s ability to generalize out-of-distribution is predominantly determined by the number of pretraining tasks $n_{\mathrm{i.d.}}$.

4.1 Transition driven by the number of tasks

Figure 3: Phase diagram for the depth $d = 6$ models. Panels: (a) accuracy phase diagrams of all four sets; (b) loss vs. step; (c) accuracy vs. step; (d) loss vs. number of shots; (e) accuracy vs. number of shots. (a) Accuracy on all four sets used to plot the phase diagram in Figure 1, with early stopping applied. Notably, in the regions when models generalize to o.o.d. sets, the pre-training performance degrades; (b, c) $\alpha = 0.6$ training accuracy and o.o.d. test accuracy (dotted line). For $n_{\mathrm{i.d.}} = 2^8$, we notice that the o.o.d. generalization ability of the model first improves then degrades as we train longer; (d, e) $\alpha = 0.6$, loss and accuracy vs context length, measured on $S^{\mathrm{o.o.d.}}_{\mathrm{test}}$ at the end of training, where for the $n_{\mathrm{i.d.}} = 2^8$ case the ICL ability fades away.

In Figure 3(a), we show the accuracy of $d = 6$ models vs the number of training tasks $n_{\mathrm{i.d.}}$ and the number of few-shot examples quantified by the fraction of the total number of data points, $\alpha$; on sets $S^{\mathrm{i.d.}}_{\mathrm{train}}$, $S^{\mathrm{i.d.}}_{\mathrm{test}}$, $S^{\mathrm{o.o.d.}}_{\mathrm{train}}$ and $S^{\mathrm{o.o.d.}}_{\mathrm{test}}$. The phase diagram in Figure 1 is constructed by merging the last shot accuracy version of these four diagrams shown in Figure 18(a) of Appendix G.

The ability of the model to generalize in-distribution increases with $\alpha$, as can be seen by comparing the first two panels of Figure 3(a). This behavior is in correspondence with the original work on grokking, where the transition to generalizing solution is driven by the amount of data. Further, we observe that an increase in $n_{\mathrm{i.d.}}$ enhances the in-context sample efficiency, i.e. the model generalizes at inference time with fewer few-shot examples. This indicates the onset of the transition from a task-memorizing solution to the one that generalizes out-of-distribution. The model switches to a new algorithmic way of solving the task and the solution is more few-shot-sample-efficient.

Shifting our focus to the last two panels of Figure 3(a), we see that when $n_{\mathrm{i.d.}} \geq 256$, the model can solve new tasks that were absent in the training set. Notably, there appears to be a trade-off between memorization and generalization when the model attains this o.o.d. generalization ability. As the o.o.d. performance increases, the pre-training performance simultaneously degrades. This phenomenon indicates a shift in the algorithm implemented by the model. Prior to this transition, the model primarily needed to select possible vectors from the list of memorized tasks and apply them. However, post-transition, the model adopts a more universal approach to solve the task in-context. We emphasize that the model learns to perform ICL in both scenarios. The difference lies in the approach to generalization. When the model can only generalize in-distribution, its task is to classify the sequence as one of the seen tasks or as unknown. Once it matches the sequence to one of the memorized task vectors, it does well for $x, y$ pairs that only appear in the test set. However, as the number of task vectors grows, the model fails to store them all and is forced to find a method of determining the task vector algorithmically at inference time. In that case the model performs equally well on seen and unseen tasks alike. In fact, the small two-layer model we study has such a low capacity that it entirely skips the in-distribution generalization phase and immediately jumps from pure memorization to out-of-distribution generalization.
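The difference between the two strategies can be made concrete with a brute-force reference procedure. The sketch below is not the algorithm learnt by the transformer; it merely enumerates which task vectors are consistent with the in-context examples, either among a list of memorized candidates or over all of $\mathbb{Z}_p^2$. The name `consistent_tasks` is illustrative.

```python
from itertools import product


def consistent_tasks(examples, p=29, candidates=None):
    """Task vectors (a, b) that reproduce every in-context example (x, y, z).

    If `candidates` is given, only those task vectors are considered
    (a lookup over memorized tasks); otherwise all p^2 vectors are tried.
    """
    if candidates is None:
        candidates = product(range(p), repeat=2)
    return [
        (a, b)
        for a, b in candidates
        if all((a * x + b * y) % p == z for x, y, z in examples)
    ]
```

For the task vector $(3, 5)$, the single example $(x, y, z) = (1, 0, 3)$ leaves $p = 29$ candidates (any $b$ is allowed), while adding $(0, 1, 5)$ pins the task down uniquely. In general, two examples whose input vectors are linearly independent mod $p$ suffice. If the candidate list is restricted to memorized tasks that do not include $(3, 5)$, the result is empty, which corresponds to classifying the sequence as unknown.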

Next, to further illustrate the effect of task diversity, we plot the pre-training accuracy (set $S^{\mathrm{i.d.}}_{\mathrm{train}}$) and the o.o.d. test accuracy (set $S^{\mathrm{o.o.d.}}_{\mathrm{test}}$) as a function of training steps (Figure 3(b, c)); for various values of $n_{\mathrm{i.d.}}$. We observe a clear memorization-to-generalization transition as task diversity increases. Interestingly, for $n_{\mathrm{i.d.}} = 2^8$, the ICL ability on $S^{\mathrm{o.o.d.}}_{\mathrm{test}}$ set exhibits non-monotonic behavior, where the o.o.d. performance rises and falls along the training process. This phenomenon is likely due to a competition between the memorizing and generalizing circuits inside the model. Note that this phenomenon is akin to the one analyzed in Singh et al. [25], albeit with a different setup.
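Because the o.o.d. performance can rise and then fall, the checkpoint used for evaluation matters. The following sketch shows one simple way to pick a checkpoint and to flag a transient curve. The selection criterion (highest o.o.d. accuracy among logged checkpoints) and the threshold `drop` are assumptions made for illustration; they are not a description of the exact early-stopping procedure used for the figures.

```python
def best_checkpoint(steps, ood_acc):
    """Return (step, accuracy) of the logged checkpoint with highest o.o.d. accuracy.

    Ties are resolved in favor of the earliest checkpoint.
    """
    i = max(range(len(ood_acc)), key=lambda k: ood_acc[k])
    return steps[i], ood_acc[i]


def is_transient(ood_acc, drop=0.1):
    """True if the final accuracy is at least `drop` below the peak accuracy."""
    return max(ood_acc) - ood_acc[-1] >= drop
```

For a made-up curve with accuracies `[0.03, 0.4, 0.8, 0.5]` logged at steps `[0, 1000, 2000, 3000]`, the selected checkpoint is the one at step 2000 and the curve is flagged as transient.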

Further evidence supporting the two-circuit competition can be observed in panel (d) of Figure 3. The loss curves show a “monotonic $\rightarrow$ non-monotonic $\rightarrow$ monotonic” transition as the task diversity increases. With a minimal number of pre-training tasks, the model primarily engages in memorization, resulting in a monotonically increasing o.o.d. loss curve. As the number of pre-training tasks increases, the loss curve exhibits non-monotonic behavior, indicating competition between two distinct neural circuits. This transient nature of o.o.d. generalization for $n_{\mathrm{i.d.}} = 2^8$ is a peculiar case where memorization circuits are initially suppressed but eventually prevail. With substantial task diversity, the circuits responsible for generalization take over, culminating in a monotonic loss curve. Similar insights can be derived from examining the monotonicity of the accuracy curves in panel (e).

4.2 Effect of Model Size and Task Difficulty

Figure 4: Phases of depth $d=4$ and $d=2$ models. With decreasing model capacity, the performance on both sets degrades. At the same time, the transient nature of ICL does not appear. (a, b) From left to right: accuracy phase diagrams on the pre-training set $S^{\mathrm{i.d.}}_{\mathrm{train}}$ and on the o.o.d. test set $S^{\mathrm{o.o.d.}}_{\mathrm{test}}$, with early stopping; loss and accuracy vs. context length on the o.o.d. test set $S^{\mathrm{o.o.d.}}_{\mathrm{test}}$ for $\alpha=0.6$.

A natural question is whether similar phenomena can be observed with different model sizes or task difficulties. Here we present our results with $d=4$ and $d=2$ in Figure 4 and leave the results for other prime $p$ values to Appendix H.

When comparing the phase diagrams in Figure 3 with those in Figure 4, we observe that the phase diagrams across different depths are qualitatively similar: o.o.d. generalization only emerges with a large enough number of pre-training tasks. As model capacity decreases, performance on both the pre-training set and the o.o.d. test set degrades. This is particularly evident in the $d=2$ case, where the pre-training accuracy falls drastically as the model gains o.o.d. generalization ability.

Interesting observations can be made by comparing loss and accuracy on the o.o.d. test set as a function of context length at the end of training. First, as the model depth decreases, the $1$-shot loss surge attributable to memorization becomes milder. Notably, for models with $n_{\mathrm{i.d.}}=2^{9}$, there is no loss surge in the $1$-shot case across all three depths. Furthermore, the $d=4$ model with $n_{\mathrm{i.d.}}=2^{8}$ behaves significantly differently from the corresponding $d=2$ model, which fails to perform ICL on the o.o.d. test set. This is also distinct from the $d=6$ case, where the model tends heavily toward memorization due to its excessive capacity. Instead, the $d=4$ model manages to maintain a better balance between memorization and generalization at the end of pre-training. Consequently, the model has a $1$-shot loss surge followed by a notable drop in ICL loss. This suggests that, among the depths we tested, $d=4$ best leverages the available model capacity to facilitate effective learning dynamics for o.o.d. generalization.

5 Interpretability

Figure 5: Performance of the $d=4$ and $d=2$ models on $k$-shot inference, on the grid of inputs $(x,y)\in\mathbb{Z}_p^2$ (task vector $=(6,6)$). Row 1: Models’ predictions on an o.o.d. task of the type $(x_1\;\;y_1\;\;{\color{red}z_1}\;\;\cdots\;\;x_k\;\;y_k\;\;{\color{red}z_k}\;\;x\;\;y\;\;{\color{red}?})$. Row 2: Analytical plots showing all the possible examples that are re-scalings of previous examples. In other words, we mark the points white whenever $(x,y)=c_i(x_i,y_i)$ for some $i\leq k$. Row 3: Row 1 minus row 2, using correct $=1$ and incorrect $=0$. The red points correspond to the examples that the model gets right even though they are not re-scalings of any of the previous examples, whereas the blue points are the examples the model misses despite their being re-scalings of a previous example. This row tells us about the model’s ability to combine the re-scaled in-context examples. Note that the $d=4$ model readily learns to combine previous examples, while its $d=2$ counterpart struggles due to its limited capacity.

An important question to consider is what algorithm the model implements to achieve o.o.d. generalization. Here we propose such an algorithm and provide empirical evidence supporting our claims. For simplicity, consider the case with two-shot inference:

$$\left(\begin{matrix} x_1 & y_1 & {\color{red}z^t_1} & x_2 & y_2 & {\color{red}z^t_2} & x & y & {\color{red}?} \end{matrix}\right)$$

The model learns to (i) re-scale the in-context inputs and (ii) perform linear transformations (over $\mathrm{GF}(p)$). For example, in the above case, the model needs to find two constants $c_1$ and $c_2$ such that

$$c_1(x_1,\;y_1)+c_2(x_2,\;y_2)=(x,\;y)\quad\mathrm{mod}\;p. \tag{2}$$

Once the model has figured out the constants, the result can be computed with modular addition: $c_1z^t_1+c_2z^t_2\;\mathrm{mod}\;p$ [11, 19, 18, 29]. Notably, with sequences that provide more demonstrations, the model has the flexibility to combine in-context examples to derive the answer. (Footnote 1: In an autoregressive setting, the actual position at which the model needs to output the prediction is shifted by one token, but this does not change our argument.) However, no matter which specific implementation the model chooses, solving linear systems in $\mathrm{GF}(p)$ to figure out those constants is a must. This requires the model to be able to combine in-context examples and perform modular arithmetic over them, which breaks down into the following three essential skills, in addition to copying information across in-context examples:

  I. Modular Map: Encode the tokens such that operations over $\mathrm{GF}(p)$ can be easily implemented.

  II. Multiplication: A necessary skill to re-scale in-context examples, as in eq. 2.

  III. Addition: A necessary skill for combining two examples in eq. 2.
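
To make the proposed two-shot algorithm concrete, the following is a minimal sketch of solving eq. 2 over $\mathrm{GF}(p)$ and then forming $c_1z^t_1+c_2z^t_2\;\mathrm{mod}\;p$. It is only a reference computation, not a description of what the transformer does internally. The function names (`solve_constants`, `predict_two_shot`, `label`) and the example inputs are hypothetical, and Cramer's rule is our choice of solver. The sketch deliberately returns `None` when the two in-context inputs are linearly dependent, even though some of those queries could still be answered by re-scaling a single example.

```python
from itertools import product

P = 29  # prime modulus used in the main text


def solve_constants(ex1, ex2, query, p=P):
    """Solve c1*(x1, y1) + c2*(x2, y2) = (x, y) mod p with Cramer's rule.

    Returns (c1, c2), or None when the two in-context inputs are linearly
    dependent over GF(p) (this sketch does not handle that case).
    """
    (x1, y1), (x2, y2), (x, y) = ex1, ex2, query
    det = (x1 * y2 - x2 * y1) % p
    if det == 0:
        return None
    det_inv = pow(det, -1, p)  # modular inverse, requires Python 3.8+
    c1 = ((x * y2 - x2 * y) * det_inv) % p
    c2 = ((x1 * y - x * y1) * det_inv) % p
    return c1, c2


def predict_two_shot(ex1, z1, ex2, z2, query, p=P):
    """Predict the query label as c1*z1 + c2*z2 mod p."""
    constants = solve_constants(ex1, ex2, query, p)
    if constants is None:
        return None
    c1, c2 = constants
    return (c1 * z1 + c2 * z2) % p


def label(a, b, x, y, p=P):
    """Ground-truth label z = a*x + b*y mod p for task vector (a, b)."""
    return (a * x + b * y) % p


# Illustrative check with the task vector (6, 6) used in Figure 5.
a, b = 6, 6
ex1, ex2 = (1, 2), (3, 5)  # arbitrary, linearly independent inputs
z1, z2 = label(a, b, *ex1), label(a, b, *ex2)
all_correct = all(
    predict_two_shot(ex1, z1, ex2, z2, (x, y)) == label(a, b, x, y)
    for x, y in product(range(P), repeat=2)
)
```

Because the task is linear, two linearly independent demonstrations determine the label of every query on the $p\times p$ grid, without ever recovering the task vector $(a,b)$ explicitly.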

In Section 5.1, we demonstrate that while models of both depths ($d=2,4$) learn skills I and II perfectly, the $d=4$ model outperforms its $d=2$ counterpart in skill III, namely combining the in-context examples. We attribute this disparity to the limited capacity of the $d=2$ models. Then, in Section 5.2, we show that there are special attention heads that implement skills I and II. We explicitly show that the structure of these heads deteriorates as the pre-training task diversity $n_{\mathrm{i.d.}}$ decreases. We leave the discussion of embedding layers (shown in Figure 1(e)) to Appendix E.

5.1 Model Learns to Combine

We find telling signatures of Equation 2 upon analyzing model predictions with varying numbers of examples in context. Consider the 1-shot scenario, wherein only one demonstration is given to the model. This is equivalent to setting $c_2=0$ in Equation 2. In this case, one would expect the model to solve the task correctly only for examples that obey $(x,y)=c_1(x_1,y_1)$ for some $c_1$. Indeed, this is exactly what we observe in Figure 5: the model correctly predicts all the targets when the inputs $(x,y)$ are re-scalings of the in-context input $(x_1,y_1)$.

For few-shot scenarios, the model does two things. It can predict the targets correctly whenever the query input is proportional to any of the inputs of the previous demonstrations (Figure 5, row 1). Additionally and more importantly, it can correctly solve substantially more than just these cases (Figure 5, rows 1 and 3). This results from the model’s ability to combine in-context examples and use them for prediction. Evidence for this can be seen in the proliferation of red points in row 3 of Figure 5 as more examples are given in-context. The $d=2$ model performs worse than its $d=4$ counterpart in this latter skill, hinting at an imperfect implementation of the algorithm and an only partially correct choice of the set of constants $\{c_1,c_2,\cdots\}$.
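
The analytical re-scaling grid (row 2 of Figure 5) and the difference map (row 3) can be sketched as follows. This is an illustrative reconstruction under our own assumptions rather than the plotting code behind the figure: the function names are hypothetical, we let $c$ range over all of $\mathbb{Z}_p$ including $0$ (so the origin is always marked), and the "perfect" prediction grid at the end is a stand-in for actual model outputs.

```python
def rescaling_mask(examples, p=29):
    """mask[x][y] is True when (x, y) = c * (x_i, y_i) mod p for some
    in-context input (x_i, y_i) and some c in Z_p (c = 0 included)."""
    mask = [[False] * p for _ in range(p)]
    for xi, yi in examples:
        for c in range(p):
            mask[(c * xi) % p][(c * yi) % p] = True
    return mask


def difference_map(correct, mask):
    """Row 1 minus row 2 with correct = 1 and incorrect = 0.

    +1: correct although not a re-scaling (red points),
    -1: a re-scaling that was missed (blue points),
     0: otherwise.
    """
    p = len(mask)
    return [
        [int(correct[x][y]) - int(mask[x][y]) for y in range(p)]
        for x in range(p)
    ]


# One nonzero example covers a single line through the origin.
one_shot = rescaling_mask([(1, 2)])
n_one_shot = sum(map(sum, one_shot))

# Two linearly independent examples cover two lines sharing the origin.
two_shot = rescaling_mask([(1, 2), (3, 5)])
n_two_shot = sum(map(sum, two_shot))

# Hypothetical predictor that gets every query right: all points outside
# the two lines would show up as red in the difference map.
perfect = [[True] * 29 for _ in range(29)]
n_red = sum(v == 1 for row in difference_map(perfect, two_shot) for v in row)
```

With $p=29$, a single demonstration covers only 29 of the 841 grid points, so any accuracy well above that level in the few-shot setting must come from combining examples rather than re-scaling one of them.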

5.2 Attention Heads Implement Essential Skills

Figure 6: Models that generalize o.o.d. (left) exhibit more structured attention maps. Additionally, the top-$2$ principal components of the features from the corresponding heads also show more structured patterns. The features are computed for sequences with an o.o.d. task vector $(a^t,b^t)=(6,6)$, looping over $(x_i,y_i)$ at a specific shot while the previous parts of the sequence are fixed. We annotate each PCA plot with the $(\log_{27}x_i,\log_{27}y_i)\;\mathrm{mod}\;p$ pairs. (a) The principal components form a circle of circles where the position of the outer circle is controlled by $x_i$. This pattern remains the same for different task vectors or shot choices; (b) Only pairs with even $\log_{27}x_i$ are plotted, with each $\log_{27}x_i$ circled in a different color. The PCA pattern forms a double circle similar to those in (a), with the key difference that the circles depend on the task vector and shot choices; (c, d) Models without o.o.d. generalization ability. We pick heads from the first block that correspond to the first column of (a). The structure of the attention maps and PCA patterns clearly deteriorates as the task diversity decreases.

In the previous section, we found “black-box” evidence suggesting that the model is implementing an algorithm akin to solving eq. 2. Now, we turn to “white-box” interpretability to identify the components within the transformer model that are responsible for the essential skills outlined at the beginning of Section 5. In Figure 6(a, b) we analyze the important attention heads in a model that generalizes o.o.d. We compare them with the heads from a model that does not generalize o.o.d. in Figure 6(c, d).

In Figure 6(a), we show the attention head from layer 1 that implements skill I. In the top panel, we see that each query only pays attention to itself and the two preceding keys. This pattern likely stems from the fact that each example in the sequence contains three tokens $x_i, y_i, z^t_i$, and suggests that this head is mostly focused on the information within the example.

In the bottom panel, we perform principal component analysis (PCA) on the outputs of this head. Specifically, we feed a batched $k$-shot sequence of the form $(x_1\;\;y_1\;\;{\color{red}z_1}\;\;\cdots\;\;x_k\;\;y_k\;\;{\color{red}z_k}\;\;x\;\;y\;\;{\color{red}z})$, where the first $k$ inputs are fixed and the last input $(x,y)$ is scanned over all $p^2$ possible pairs. We concatenate the resulting features from $x$ and $y$, obtaining a $p^2\times 2d_{\mathrm{embed}}$ matrix of features, and perform PCA on this matrix. We project all these $2d_{\mathrm{embed}}$-dimensional features onto the first two principal components. Annotating the pairs $(x,y)$ with $(\log_{27}x,\log_{27}y)$ (footnote 2: we use 27 as the base of the logarithm, which is a primitive root of $\mathrm{GF}(29)$), we find a “clock-of-clocks”, where clocks of period 28 are themselves arranged in a bigger clock of period 28. The number 0 is located at the center of the clocks (footnote 3: this is expected since $\log_{27}0$ is mathematically ill-defined and needs special treatment). We observe similar clock-of-clocks patterns for concatenated features from $(x,z)$ and $(y,z)$ as well. We refer the reader to Appendix E for further details.
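
The $\log_{27}$ annotation can be reproduced with a small discrete-logarithm table. The sketch below is illustrative only: the names `discrete_log_table`, `LOG27` and `annotate` are hypothetical, and returning `None` for the input 0 is our own convention for the ill-defined $\log_{27}0$. The table also makes explicit why this labeling is natural for a re-scaling skill: multiplication mod 29 becomes addition of logarithms mod 28, which matches the period-28 clocks.

```python
def discrete_log_table(g=27, p=29):
    """Map each nonzero residue v to the exponent k with g**k = v mod p.

    Raises ValueError if g is not a primitive root mod p, i.e. if its
    powers repeat before covering all p - 1 nonzero residues.
    """
    table = {}
    value = 1
    for k in range(p - 1):
        if value in table:
            raise ValueError(f"{g} is not a primitive root mod {p}")
        table[value] = k
        value = (value * g) % p
    return table


LOG27 = discrete_log_table()


def annotate(x, y, table=LOG27):
    """Return (log_27 x, log_27 y); None marks the ill-defined log of 0."""
    return table.get(x), table.get(y)


# Multiplication of nonzero residues turns into addition of logs mod 28.
logs_are_additive = all(
    LOG27[(u * v) % 29] == (LOG27[u] + LOG27[v]) % 28
    for u in range(1, 29)
    for v in range(1, 29)
)
```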

In Figure 6(b), we analyze the head from layer 2 that effectively implements skill II; it does so by building upon skill I and leveraging the features discussed above. The upper panel shows a highly structured attention map that focuses on the current example as well as the examples preceding it. This pattern aligns with the step in our proposed algorithm where the model needs to compare $(x,y)$ pairs across different in-context examples and selectively re-scale some of them for later combination. Additionally, as the number of in-context demonstrations increases, the attention map indicates that the model begins to focus on multiple examples simultaneously. This suggests that the model might be employing a weighted-average approach over these examples, as discussed earlier.

By conducting a PCA analysis similar to that in Figure 6(a), we also identify clocks when annotating examples in the $(\log_{27}x,\log_{27}y)$ format. Unlike the previously discussed pattern, the specifics of this “clock-of-clocks” arrangement vary depending on the position and the choice of task vector $(a^t,b^t)$. This variability suggests that the head in question is utilizing information about the specific task from the context to appropriately re-scale the equations.

We further highlight the importance of the signatures found in the above paragraphs via comparison with a model that does not generalize o.o.d., in Figure 6(c, d). Note that as the number of pre-training tasks $n_{\mathrm{i.d.}}$ decreases, the attention map similar to the one in panel (a) becomes mosaicked, and at the same time the PCA projections lose their shape. As a result, those models also lose the ability to perform ICL on modular arithmetic out-of-distribution.

From Figure 5, we infer that the $d=2$ model has only partially acquired skill III, due to its limited capacity. On the other hand, the $d=4$ model is very good at combining equations via skill III, explaining its superior performance.

6 Discussion

We have investigated the emergence of in-context learning and skill composition in AMs on a novel algorithmic dataset. The dataset includes a large discrete set of modular arithmetic tasks and was specifically designed to force models to learn how to solve a variety of tasks. It consists of learning linear modular functions, where the model is expected to identify and perform a modular operation in-context. When the number of tasks is relatively small, the models simply memorize the task vectors and classify the input vectors by the task vectors they have memorized. In this case, the models still have to learn how to perform the memorized tasks because they do generalize in-distribution and develop ICL capabilities. Once the number of training tasks becomes sufficiently large, the models transition to a qualitatively different, algorithmic approach, where the task vector is determined at inference time.

Finally, we have investigated the learnt representations and showed that qualitatively different circuits are formed in different phases.

Limitations

We have limited ourselves to algorithmic datasets. Without further investigation, it is difficult to say which lessons can be transferred to realistic language models and which are specific to the current setting. Mechanistic interpretability analysis of deeper models proved to be much more difficult than that of smaller models. Consequently, we still do not understand the role of every part of the network in the deeper cases.

Acknowledgments

T.H. thanks Yue Xu and Dayal Singh Kalra for helpful discussions. A.G.’s work at the University of Maryland was supported in part by NSF CAREER Award DMR-2045181, Sloan Foundation and the Laboratory for Physical Sciences through the Condensed Matter Theory Center. The authors acknowledge the University of Maryland supercomputing resources (http://hpcc.umd.edu) made available for conducting the research reported in this paper.

References

  • Ahn et al. [2023] Kwangjun Ahn, Xiang Cheng, Hadi Daneshmand, and Suvrit Sra. Transformers learn to implement preconditioned gradient descent for in-context learning. In A. Oh, T. Neumann, A. Globerson, K. Saenko, M. Hardt, and S. Levine, editors, Advances in Neural Information Processing Systems, volume 36, pages 45614–45650. Curran Associates, Inc., 2023. URL https://proceedings.neurips.cc/paper_files/paper/2023/file/8ed3d610ea4b68e7afb30ea7d01422c6-Paper-Conference.pdf.
  • Akyürek et al. [2023] Ekin Akyürek, Dale Schuurmans, Jacob Andreas, Tengyu Ma, and Denny Zhou. What learning algorithm is in-context learning? investigations with linear models, 2023. URL https://openreview.net/forum?id=0g0X4H8yN4I.
  • Arora and Goyal [2023] Sanjeev Arora and Anirudh Goyal. A theory for emergence of complex skills in language models. arXiv preprint arXiv:2307.15936, 2023.
  • Boix-Adserà et al. [2024] Enric Boix-Adserà, Omid Saremi, Emmanuel Abbe, Samy Bengio, Etai Littwin, and Joshua M. Susskind. When can transformers reason with abstract symbols? In The Twelfth International Conference on Learning Representations, 2024. URL https://openreview.net/forum?id=STUGfUz8ob.
  • Brown et al. [2020] Tom Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared D Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, Sandhini Agarwal, Ariel Herbert-Voss, Gretchen Krueger, Tom Henighan, Rewon Child, Aditya Ramesh, Daniel Ziegler, Jeffrey Wu, Clemens Winter, Chris Hesse, Mark Chen, Eric Sigler, Mateusz Litwin, Scott Gray, Benjamin Chess, Jack Clark, Christopher Berner, Sam McCandlish, Alec Radford, Ilya Sutskever, and Dario Amodei. Language models are few-shot learners. In H. Larochelle, M. Ranzato, R. Hadsell, M.F. Balcan, and H. Lin, editors, Advances in Neural Information Processing Systems, volume 33, pages 1877–1901. Curran Associates, Inc., 2020. URL https://proceedings.neurips.cc/paper_files/paper/2020/file/1457c0d6bfcb4967418bfb8ac142f64a-Paper.pdf.
  • Doshi et al. [2023] Darshil Doshi, Aritra Das, Tianyu He, and Andrey Gromov. To grok or not to grok: Disentangling generalization and memorization on corrupted algorithmic datasets. arXiv preprint arXiv:2310.13061, 2023.
  • Doshi et al. [2024] Darshil Doshi, Tianyu He, Aritra Das, and Andrey Gromov. On learning modular polynomials. In ICLR 2024 Workshop on Bridging the Gap Between Practice and Theory in Deep Learning, 2024. URL https://openreview.net/forum?id=QO0y9ysrgu.
  • Elhage et al. [2021] Nelson Elhage, Neel Nanda, Catherine Olsson, Tom Henighan, Nicholas Joseph, Ben Mann, Amanda Askell, Yuntao Bai, Anna Chen, Tom Conerly, Nova DasSarma, Dawn Drain, Deep Ganguli, Zac Hatfield-Dodds, Danny Hernandez, Andy Jones, Jackson Kernion, Liane Lovitt, Kamal Ndousse, Dario Amodei, Tom Brown, Jack Clark, Jared Kaplan, Sam McCandlish, and Chris Olah. A mathematical framework for transformer circuits. Transformer Circuits Thread, 2021. https://transformer-circuits.pub/2021/framework/index.html.
  • Ganguli et al. [2022] Deep Ganguli, Danny Hernandez, Liane Lovitt, Amanda Askell, Yuntao Bai, Anna Chen, Tom Conerly, Nova Dassarma, Dawn Drain, Nelson Elhage, et al. Predictability and surprise in large generative models. In Proceedings of the 2022 ACM Conference on Fairness, Accountability, and Transparency, pages 1747–1764, 2022.
  • Garg et al. [2023] Shivam Garg, Dimitris Tsipras, Percy Liang, and Gregory Valiant. What can transformers learn in-context? a case study of simple function classes. 2023.
  • Gromov [2023] Andrey Gromov. Grokking modular arithmetic, 2023. URL https://arxiv.longhoe.net/abs/2301.02679.
  • Gu et al. [2024] Jiuxiang Gu, Chenyang Li, Yingyu Liang, Zhenmei Shi, Zhao Song, and Tianyi Zhou. Fourier circuits in neural networks: Unlocking the potential of large language models in mathematical reasoning and modular arithmetic, 2024.
  • Guo et al. [2024] Tianyu Guo, Wei Hu, Song Mei, Huan Wang, Caiming Xiong, Silvio Savarese, and Yu Bai. How do transformers learn in-context beyond simple functions? a case study on learning with representations. In The Twelfth International Conference on Learning Representations, 2024. URL https://openreview.net/forum?id=ikwEDva1JZ.
  • Hendel et al. [2023] Roee Hendel, Mor Geva, and Amir Globerson. In-context learning creates task vectors, 2023.
  • Kirsch et al. [2024] Louis Kirsch, James Harrison, Jascha Sohl-Dickstein, and Luke Metz. General-purpose in-context learning by meta-learning transformers, 2024.
  • Lin and Lee [2024] Ziqian Lin and Kangwook Lee. Dual operating modes of in-context learning, 2024.
  • Liu et al. [2024] Sheng Liu, Haotian Ye, Lei Xing, and James Zou. In-context vectors: Making in context learning more effective and controllable through latent space steering, 2024.
  • Liu et al. [2022] Ziming Liu, Ouail Kitouni, Niklas S Nolte, Eric Michaud, Max Tegmark, and Mike Williams. Towards understanding grokking: An effective theory of representation learning. Advances in Neural Information Processing Systems, 35:34651–34663, 2022.
  • Nanda et al. [2023] Neel Nanda, Lawrence Chan, Tom Lieberum, Jess Smith, and Jacob Steinhardt. Progress measures for grokking via mechanistic interpretability. In The Eleventh International Conference on Learning Representations, 2023. URL https://openreview.net/forum?id=9XFSbDPmdW.
  • Nichani et al. [2024] Eshaan Nichani, Alex Damian, and Jason D. Lee. How transformers learn causal structure with gradient descent, 2024.
  • Olsson et al. [2022] Catherine Olsson, Nelson Elhage, Neel Nanda, Nicholas Joseph, Nova DasSarma, Tom Henighan, Ben Mann, Amanda Askell, Yuntao Bai, Anna Chen, Tom Conerly, Dawn Drain, Deep Ganguli, Zac Hatfield-Dodds, Danny Hernandez, Scott Johnston, Andy Jones, Jackson Kernion, Liane Lovitt, Kamal Ndousse, Dario Amodei, Tom Brown, Jack Clark, Jared Kaplan, Sam McCandlish, and Chris Olah. In-context learning and induction heads, 2022.
  • Paszke et al. [2019] Adam Paszke, Sam Gross, Francisco Massa, Adam Lerer, James Bradbury, Gregory Chanan, Trevor Killeen, Zeming Lin, Natalia Gimelshein, Luca Antiga, Alban Desmaison, Andreas Köpf, Edward Yang, Zach DeVito, Martin Raison, Alykhan Tejani, Sasank Chilamkurthy, Benoit Steiner, Lu Fang, Junjie Bai, and Soumith Chintala. Pytorch: An imperative style, high-performance deep learning library, 2019.
  • Power et al. [2022] Alethea Power, Yuri Burda, Harri Edwards, Igor Babuschkin, and Vedant Misra. Grokking: Generalization beyond overfitting on small algorithmic datasets. arXiv preprint arXiv:2201.02177, 2022.
  • Raventos et al. [2023] Allan Raventos, Mansheej Paul, Feng Chen, and Surya Ganguli. Pretraining task diversity and the emergence of non-bayesian in-context learning for regression. In Thirty-seventh Conference on Neural Information Processing Systems, 2023. URL https://openreview.net/forum?id=BtAz4a5xDg.
  • Singh et al. [2023] Aaditya K Singh, Stephanie C.Y. Chan, Ted Moskovitz, Erin Grant, Andrew M Saxe, and Felix Hill. The transient nature of emergent in-context learning in transformers. In Thirty-seventh Conference on Neural Information Processing Systems, 2023. URL https://openreview.net/forum?id=Of0GBzow8P.
  • Su et al. [2023] Jianlin Su, Yu Lu, Shengfeng Pan, Ahmed Murtadha, Bo Wen, and Yunfeng Liu. Roformer: Enhanced transformer with rotary position embedding, 2023.
  • von Oswald et al. [2023] Johannes von Oswald, Eyvind Niklasson, Ettore Randazzo, João Sacramento, Alexander Mordvintsev, Andrey Zhmoginov, and Max Vladymyrov. Transformers learn in-context by gradient descent, 2023.
  • Wei et al. [2022] Jason Wei, Yi Tay, Rishi Bommasani, Colin Raffel, Barret Zoph, Sebastian Borgeaud, Dani Yogatama, Maarten Bosma, Denny Zhou, Donald Metzler, et al. Emergent abilities of large language models. arXiv preprint arXiv:2206.07682, 2022.
  • Zhong et al. [2023] Ziqian Zhong, Ziming Liu, Max Tegmark, and Jacob Andreas. The clock and the pizza: Two stories in mechanistic explanation of neural networks. In Thirty-seventh Conference on Neural Information Processing Systems, 2023. URL https://openreview.net/forum?id=S5wmbQc1We.

Appendix A Experimental Details

A.1 Model and Training Hyperparameters

Architecture

We used GPT-like architectures with Rotary Positional Embedding ($\theta=10{,}000$) and ReLU activations. We fix the number of heads $H=4$, embedding dimension $d_{\mathrm{embed}}=512$ and MLP widening factor $4$ for every model. We use depths $d=2,4,6$ throughout the paper. The embedding layer and the output layer share weights via weight tying.
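
For reference, the architectural choices above can be collected in a small configuration object. This is only a sketch: the class `ModelConfig` and its field names are hypothetical and do not come from the released source code, and the derived per-head dimension assumes the usual convention $d_{\mathrm{embed}}/H$.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ModelConfig:
    depth: int                    # number of transformer blocks d (2, 4 or 6)
    n_heads: int = 4              # H
    d_embed: int = 512            # embedding dimension
    mlp_factor: int = 4           # MLP widening factor
    rope_theta: float = 10_000.0  # rotary positional embedding base
    activation: str = "relu"
    tie_embeddings: bool = True   # embedding and output layers share weights
    use_bias: bool = False        # no bias in any layer

    @property
    def head_dim(self) -> int:
        # Assumed convention: the embedding is split evenly across heads.
        return self.d_embed // self.n_heads

    @property
    def mlp_hidden(self) -> int:
        return self.mlp_factor * self.d_embed


# One configuration per depth used in the paper.
CONFIGS = {d: ModelConfig(depth=d) for d in (2, 4, 6)}
```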

Initialization

All linear layers and embedding layer weights are sampled from the Gaussian distribution $\mathcal{N}(0,0.02^2)$ at initialization, with the exception that the last linear layer in each MLP is sampled from $\mathcal{N}(0,0.02^2/2d)$. No bias is used in any layer.
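
A minimal sketch of this initialization rule, written in plain Python rather than a deep-learning framework. We read $0.02^2/2d$ as a variance of $0.02^2/(2d)$, i.e. a standard deviation of $0.02/\sqrt{2d}$; this reading, along with the function names `init_std` and `sample_weights`, is an assumption on our part.

```python
import math
import random

BASE_STD = 0.02  # standard deviation for all linear and embedding weights


def init_std(depth, is_last_mlp_linear=False):
    """Standard deviation of the Gaussian used at initialization.

    The last linear layer of each MLP is scaled down by sqrt(2 * depth),
    assuming the variance in the text is 0.02**2 / (2 * depth).
    """
    if is_last_mlp_linear:
        return BASE_STD / math.sqrt(2 * depth)
    return BASE_STD


def sample_weights(rows, cols, std, seed=0):
    """Draw a rows x cols weight matrix with i.i.d. N(0, std**2) entries."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, std) for _ in range(cols)] for _ in range(rows)]


# Standard deviations of the last MLP layer for the three depths studied.
last_layer_std = {d: init_std(d, is_last_mlp_linear=True) for d in (2, 4, 6)}
```

Under this reading, deeper models start with smaller weights in the last MLP layer of each block, which is the depth-dependent scaling referred to in the hyperparameter discussion.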

Optimization and Schedule

We trained most models using the AdamW optimizer with learning rate $\eta=1.5\times10^{-4}$, weight decay $\lambda=2.0$, $\beta_1=0.9$, $\beta_2=0.98$, $\epsilon=10^{-8}$, batch size $B=1024$, and $n_{\mathrm{ct}}=32$ in-context examples for $200$k steps, together with a $5\%$ linear warmup starting from $0.01\eta$ and a cosine annealing towards the end to $0.1\eta$. Weight decay is not applied to LayerNorm layers.
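
The learning-rate schedule described above can be sketched as a function of the training step. The exact interpolation is an assumption: we take the warmup to be linear in the learning rate from $0.01\eta$ to $\eta$ over the first $5\%$ of steps, followed by a half-cosine decay from $\eta$ to $0.1\eta$ over the remaining steps. The function name and its parameter names are hypothetical.

```python
import math


def learning_rate(step, total_steps=200_000, peak_lr=1.5e-4,
                  warmup_frac=0.05, start_factor=0.01, final_factor=0.1):
    """Linear warmup followed by cosine annealing.

    Warmup: from start_factor * peak_lr to peak_lr over warmup_frac of
    the run. Annealing: half-cosine from peak_lr to final_factor * peak_lr.
    """
    warmup_steps = round(warmup_frac * total_steps)
    if step < warmup_steps:
        frac = step / warmup_steps
        return peak_lr * (start_factor + (1.0 - start_factor) * frac)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return peak_lr * (final_factor + (1.0 - final_factor) * cosine)


# Learning rate at the start, at the end of warmup, and at the last step.
lr_start = learning_rate(0)
lr_peak = learning_rate(10_000)
lr_end = learning_rate(200_000)
```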

Hyperparameter Choice

For $d=2$ models we scanned learning rates $\eta\in\{7.5\times 10^{-5}, 1.5\times 10^{-4}, 3\times 10^{-4}, 6\times 10^{-4}\}$ and weight decay values $\lambda\in\{0.5, 1.0, 2.0, 5.0\}$. We then transferred these hyperparameters to other depths. Thanks to the extra scaling in the initialization of the last linear layer in each MLP, we find that the hyperparameters perform well at other depths. For larger values of $p$, we lower the learning rate to $10^{-4}$.

A.2 Further Details of Each Plot in the Main Text

Figure 1 (b) Phase diagram constructed using the $d=6$ data shown in Figure 18; the threshold for defining each phase is set to $75\%$ for the corresponding set. (c) Accuracy versus number of shots for $d=4$, $n_{\mathrm{i.d.}}=2^8$, $p=29$ and $\alpha=0.7$. (d, e) $d=2$ model with $p=29$, $n_{\mathrm{i.d.}}=2^9$ and $\alpha=0.6$.

Figure 3 (a) Best of three random seeds, selected with early stopping. (b, c) Averaged over three random seeds, with the standard error shown. All o.o.d. data are measured every $1{,}000$ steps on $16$ randomly sampled sequences during pre-training. (d, e) Checkpoint at the end of pre-training, averaged over 128 random sequences sampled from $S^{\mathrm{o.o.d.}}_{\mathrm{test}}$.

Figure 4 (a, b) Both phase diagrams show the best of three random seeds, selected with early stopping. The loss versus number of shots curves are plotted with $\alpha=0.6$.

Figure 5 We use a $d=2$ model trained with $\alpha=0.6$ and $n_{\mathrm{i.d.}}=2^9$, and a $d=4$ model trained with $\alpha=0.8$ and $n_{\mathrm{i.d.}}=2^9$ with batch size $B=1536$.

Figure 6 (a, b) Attention heads extracted from a $d=2$ model trained with $n_{\mathrm{i.d.}}=2^9$ and $\alpha=0.6$. (c, d) Both models are trained with $\alpha=0.6$.

Appendix B Details of resources used

To generate each phase diagram and training curve, we used $90$ GPU days on NVIDIA A100 40GB GPUs, with automatic mixed precision (BFloat16) and Flash Attention as implemented in the PyTorch library [22]. During the exploration stage, we used roughly another $90$ GPU days. Most inference was run on 1/7 of an NVIDIA A100 40GB GPU, at a cost that is negligible compared to the pre-training cost.

Appendix C Table of log map for $p=29$

$$\begin{array}{c|ccccccccccccccc}
n & 0 & 1 & 2 & 3 & 4 & 5 & 6 & 7 & 8 & 9 & 10 & 11 & 12 & 13 & 14 \\ \hline
\log_{27}(n) & 0 & 28 & 15 & 19 & 2 & 22 & 6 & 12 & 17 & 10 & 9 & 11 & 21 & 18 & 27
\end{array}$$

$$\begin{array}{c|cccccccccccccc}
n & 15 & 16 & 17 & 18 & 19 & 20 & 21 & 22 & 23 & 24 & 25 & 26 & 27 & 28 \\ \hline
\log_{27}(n) & 13 & 4 & 7 & 25 & 23 & 24 & 3 & 26 & 20 & 8 & 16 & 5 & 1 & 14
\end{array}$$
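The table above can be reproduced by enumerating powers of the base $27$ modulo $29$. The sketch below is illustrative: the function name `log_table` is hypothetical, and it follows the table's conventions that $\log_{27}(0)$ is set to $0$ and $\log_{27}(1)$ is reported as $28$ rather than $0$.

```python
def log_table(p=29, g=27):
    """Discrete logarithm base g modulo p, returned as a dict n -> log_g(n).

    Conventions taken from the table: log(0) is set to 0, and exponents
    run over 1..p-1, so log(1) = p - 1.
    """
    table = {0: 0}
    value = 1
    for k in range(1, p):
        value = (value * g) % p  # value == g**k mod p
        table[value] = k
    if len(table) != p:
        raise ValueError("g is not a primitive root modulo p")
    return table


if __name__ == "__main__":
    logs = log_table()
    print([logs[n] for n in range(15)])
    print([logs[n] for n in range(15, 29)])
```

In this representation, multiplication modulo $29$ becomes addition of logarithms modulo $28$, which is why the map is useful for reading the clock-like structures in the PCA plots.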

Appendix D Task Selection and Sequence Design

Task Selection

During the initial exploration phase, we observed that it was challenging for the model to learn multiple modular arithmetic tasks simultaneously. Typically, the loss would stagnate at a plateau indefinitely.

We hypothesize that when the model is trained on multiple modular arithmetic tasks, the strong bias inherent in each individual task may interfere significantly. When tasks are selected randomly, the resultant noise appears to inhibit the learning capabilities of the model, preventing it from acquiring any meaningful patterns or rules from the data.

Ultimately, we adopted a more structured approach to task selection, as illustrated in Figure 2, based on the following rationale: if the model is to learn a specific task vector $(a^t, b^t)$, it would presumably be more straightforward if one of the two components, either $a^t$ or $b^t$, has already been learned by the model. This would leave the model with only the other component to decipher, which we assume is a comparatively simpler task. Thus, we decided to employ the rectangular rule for sampling task vectors, which we believe reduces the complexity of the learning process by partially leveraging previously acquired knowledge.
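One simple way to guarantee that every task vector shares each of its components with other task vectors is to draw tasks from a Cartesian product of $a$-values and $b$-values. The sketch below illustrates only that rationale; it is an assumption and not necessarily the exact rectangular rule of Figure 2. The function name `sample_rectangular_tasks` and its parameters are hypothetical.

```python
import random


def sample_rectangular_tasks(p=29, n_a=4, n_b=4, seed=0):
    """Sample task vectors (a, b) as a full grid over chosen a- and b-values.

    Assumption: components are drawn from all residues 0..p-1. Every a
    appears with every selected b, so each component is shared across
    several tasks.
    """
    rng = random.Random(seed)
    a_values = rng.sample(range(p), n_a)
    b_values = rng.sample(range(p), n_b)
    return [(a, b) for a in a_values for b in b_values]


if __name__ == "__main__":
    tasks = sample_rectangular_tasks()
    print(len(tasks), tasks[:4])
```

In such a grid, once the model has learned some $(a, b)$, every other task in the same row or column differs in only one component.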

In Figure 7, we show an ablation plot of phase diagrams with a randomly selected task vector collection for pre-training. We still use the special sequence design.

Figure 7: $d=6$ phase diagram with randomly selected task vectors, showing the best results from two random seeds. Learning rate $\eta=10^{-4}$.

Sequence Design

In a similar spirit, we use a balanced batch in which sequences generated from all task vectors appear exactly the same number of times during training. We further align the examples across sequences generated by different task vectors, which we believe reduces the chance of the model being confused by the same input $(x,y)$ appearing at different positions within the batch. Without this design, we could not get the model to train.
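The balanced and aligned batch can be sketched as follows. This is an illustration under assumptions, not the released implementation: sequences are taken to be flattened triples $(x, y, z)$ with $z = a x + b y \bmod p$, "aligned" is read as every task vector seeing the same $(x, y)$ inputs at the same positions, and the function name `make_aligned_batch` is hypothetical.

```python
import random


def make_aligned_batch(task_vectors, p=29, n_shots=32, seqs_per_task=1, seed=0):
    """Build a batch in which every task appears equally often (balanced)
    and all tasks share the same (x, y) inputs position by position (aligned).

    Returns a list of (task_vector, token_sequence) pairs.
    """
    rng = random.Random(seed)
    batch = []
    for _ in range(seqs_per_task):
        # One shared list of inputs for this group of sequences.
        xy_pairs = [(rng.randrange(p), rng.randrange(p)) for _ in range(n_shots)]
        for a, b in task_vectors:
            tokens = []
            for x, y in xy_pairs:
                tokens.extend([x, y, (a * x + b * y) % p])
            batch.append(((a, b), tokens))
    return batch


if __name__ == "__main__":
    batch = make_aligned_batch([(1, 2), (3, 4)], n_shots=4, seqs_per_task=2)
    for task, tokens in batch:
        print(task, tokens)
```

Within each group, only the labels $z$ differ between sequences, so the task vector is the sole source of variation at every position.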

In Figure 8, we also show the training curve per task, trained with all of our tricks. We see that the model first learns very few tasks and then eventually finds its way out. When training without the task selection and sequence design, the loss typically plateaus at $\sim 3.0$.

Figure 8: Training curves for $d=4$, with loss and accuracy measured per task for the last in-context example. Trained with $n_{\mathrm{i.d.}}=2^8$ and $\alpha=0.8$. Six tasks are learned almost instantly: $(5,26)$, $(21,23)$, $(28,27)$, $(3,9)$, $(10,10)$, $(5,9)$.

Appendix E Additional Interpretability Results

In this section, we show additional results on interpretability. We first show evidence that a $d=4$ model also implements a version of the scale-and-combine algorithm we proposed in Section 5, and that it can do better than the $d=2$ model. We then continue our discussion of how the algorithm might be implemented inside the model. This includes the role of the embedding (Section E.1), the role of other attention heads (Section E.2), and finally the role of the MLP and LayerNorm (Section E.3).

E.1 PCA over Embeddings

We begin the discussion with the role of the embedding layers. By further examining different models, we find highly structured embeddings in $d=2$ models. In contrast, $d=4$ models do not have such structured embedding layers.

First, we focus on the embedding of $d=2$ models. As shown in Figure 9, the logarithms of the numbers are clearly split into even and odd groups, and each group forms one clock, which is a suitable embedding for modular multiplication. However, this does not obviate the importance of the head shown in Figure 6(a), as the model still needs a way to distinguish the same number appearing at different positions in the context.

Curiously, we could not find such a structured embedding for the $d=4$ model, as shown in Figure 10. However, as we will show in the next subsection, this unstructured embedding, together with the first layer, prepares a foundation for the essential heads in later layers to create a "clock-of-clocks" feature similar to the one shown in Figure 6 (a, b).

Figure 9: PCA over the embedding layers of $d=2$ models with different random seeds: (a) seed 2; (b) seed 1. Each number is annotated by its $\log_{27}$ value.

Figure 10: PCA over the embedding layer of the $d=4$ model, shown for the top-$3$ principal components. Each number is annotated by its $\log_{27}$ value.

E.2 Attention Heads

Figure 11: PCA over the features of specific heads for the $d=4$ model: (a) layer $2$, head $4$; (b) layer $3$, head $1$; (c) layer $3$, head $4$. Each number is annotated by its $\log_{27}$ value. We plot PCA over the concatenated (a, b) $(x,y)$ and (c) $(y,z)$ features. The circles in (b, c) depend heavily on the sequence, while (a) remains unchanged.

E.2.1 $d=4$ Model

To continue the story, we analyse the attention heads in different models. We first study the $d=4$ model, where we also find a head with a "clock-of-clocks" structure similar to Figure 6(a). Moreover, two other heads put together are seemingly equivalent to Figure 6(b). We surmise that any model that solves modular arithmetic in context requires such heads.

From Figure 11(a), we see that the head still pays attention locally within three token positions. Importantly, it also creates a clock-of-clocks when performing PCA over the concatenated $(x,y)$ features. The difference here is that the clock winds twice before returning to its origin. We believe that this factor-of-two difference in the period means that this head effectively combines the role of the embedding layer of the $d=2$ model and the role of the head in Figure 6(a). Overall, we do not yet know why the model needs to split the logarithms of numbers into even and odd groups.

From Figure 11(b, c), we find two heads that are structured but different from Figure 6(b). The PCA pattern forms clocks when annotated in logarithm space, but depends on the choice of sequence. This hints that these two heads are trying to re-scale the examples by comparing them with the previous inputs. We think that the two can be combined to form a pattern similar to Figure 6(b).

E.2.2 $d=2$ Model

Next, we focus on the $d=2$ model. In Figure 12, we plot a PCA similar to the one in Figure 6(a), but with different concatenated features from the same head: $(x,z)$ and $(y,z)$. We again see a clock-of-clocks, supporting our argument that this head provides proper spaces for later heads to perform modular operations.

Figure 12: PCA over the same head as in Figure 6(a), with features concatenated differently: (a) $(x,z)$; (b) $(y,z)$. We annotate numbers with their corresponding logarithm.

Finally, in Figure 13, we show all the attention patterns we obtained for $d=2$ models. The corresponding PCA over the concatenated $(x,y)$ features is shown in Figure 14 (except for layer 1 head 4, for which we plot $(x,z)$). The special head we choose here behaves similarly to the one in Figure 6(a), with its main focus shifted by one position but still within the two preceding tokens. We believe this head implements another variant of Equation 2: it re-scales equations for the next layer, then helps that layer figure out how to combine them to predict new examples.

We have reasons to believe that the non-clock heads are not essential to the models’ performance. However, we leave a careful exploration of this front for future work.

Figure 13: All attention heads in the $d=2$ model. Panels (a) to (d) show layer 1, heads 1 to 4; panels (e) to (h) show layer 2, heads 1 to 4.

Figure 14: PCA of the concatenated $(x,y)$ features for all heads in the $d=2$ model, except panel (d), which uses $(x,z)$. Panels (a) to (d) show layer 1, heads 1 to 4; panels (e) to (h) show layer 2, heads 1 to 4.

E.3 MLP and LayerNorm

The arguments in Section 5 and the experiments in the previous subsection strongly suggest that skill III is implemented within the MLP layer. We tried to hunt for signals similar to those in Nanda et al. [19] and Gromov [11] in the MLP layers, but did not find conclusive evidence. Similarly, no obvious signals were found in LayerNorm. We think it is crucial to study these layers more carefully, and we leave this task for future work.

E.4 Label Noise

To gain insight into how the model combines the in-context examples, we introduce label corruption in the in-context examples. In particular, we study the effect of (i) the amount and (ii) the position of label corruption on the model's performance. When we corrupt a single in-context example for the $d=4$ model, the model's performance remains unaffected for longer sequences. This hints that a weighted average of the in-context examples is used in the model's prediction. The $d=2$ model, however, did not show such resilience.

Next, we corrupt multiple in-context examples at random locations and study the effect on model performance as the number of corrupted labels increases. While the $d=2$ model is easily overwhelmed, the $d=4$ model offers strong resistance even at $\sim 40\%$ label corruption for long sequences. This behavior is unchanged across the task vectors used for the particular sequence, indicating the universality of the underlying algorithm necessary for o.o.d. generalization.
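The corruption step $z_j \to z'_j$ can be sketched as below. The details are assumptions made for illustration: sequences are flattened triples $(x, y, z)$, example positions $j$ are zero-indexed, and the corrupted label is drawn uniformly from the $p-1$ wrong values. The function name `corrupt_labels` is hypothetical.

```python
import random


def corrupt_labels(sequence, positions, p=29, rng=None):
    """Return a copy of a flattened (x, y, z) sequence in which the label z_j
    of every example j in `positions` is replaced by a different residue."""
    rng = rng or random.Random(0)
    corrupted = list(sequence)
    for j in positions:
        idx = 3 * j + 2  # index of z_j in the flattened sequence
        true_z = corrupted[idx]
        wrong = rng.randrange(p - 1)  # uniform over the p - 1 wrong labels
        if wrong >= true_z:
            wrong += 1  # skip the true label
        corrupted[idx] = wrong
    return corrupted


if __name__ == "__main__":
    seq = [1, 2, 3, 4, 5, 6, 7, 8, 9]
    print(corrupt_labels(seq, positions=[1]))
```

For the multiple-corruption experiment, `positions` would be a random subset of the example indices, for instance `random.sample(range(n_shots), k)` for `k` corrupted labels.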

Figure 15: (a) Single label corruption, $d=4$: performance of last-token prediction as a function of the position of a single corrupted label. The x-axis shows the number of shots presented to the model, while the y-axis shows the position $j$ where the label is corrupted, $z_j \to z^{\prime}_j$. The $d=4$ model is remarkably robust for long sequences, indicating the use of an algorithm that is not sensitive to the position of preceding examples. Qualitatively similar plots hold for other task vectors as well. (b, c) Multiple label corruption at random locations for the $d=4$ and $d=2$ models, respectively. The $d=4$ model is more resilient to label corruption than the $d=2$ model, implying that the algorithm the latter employs is imperfect due to limited capacity.

Appendix F Additional Training Curves

We plot selected training curves for $d=4$ (Figure 16) and $d=2$ (Figure 17) from the phase diagrams in Figure 4. We see that even for $d=4$, ICL can be transient. With increased $\alpha$ or $n_{\mathrm{i.d.}}$, the transient nature goes away.

Figure 16: Training curves for $d=4$, averaged over three random seeds: (a) $\alpha=0.7$; (b) $\alpha=0.6$.

Figure 17: Training curves for $d=2$, averaged over three random seeds: (a) $\alpha=0.7$; (b) $\alpha=0.6$.

Appendix G Additional Phase Diagrams

In Figure 18, we plot detailed and extended versions of the phase diagrams shown in Figure 1 and Figure 4. The four-phase picture shown in Figure 1 still holds for other depths.

Figure 18: Phase diagrams on all four sets for $d=4$ and $d=2$. Panels: (a) $d=6$, last shot; (b) $d=4$; (c) $d=2$.

Appendix H Different Choice of $p$

In this section, we check the effect of varying the task difficulty, i.e., the value of $p$. In Figure 19, we plot the o.o.d. generalization accuracy. As the task gets harder, the model clearly needs to see more tasks to generalize out-of-distribution.

Figure 19: We compare the best performance for $d=6$ models with different $n_{\mathrm{i.d.}}$ values, averaged over three seeds. We use learning rate $\eta=10^{-4}$ for $p=37$ and $p=47$, while keeping other hyperparameters the same as for $p=29$.