License: CC BY 4.0
arXiv:2403.00745v1 [cs.LG] 01 Mar 2024

AtP*: An efficient and scalable method for localizing LLM behaviour to components

János Kramár (Google DeepMind), Tom Lieberum (Google DeepMind), Rohin Shah (Google DeepMind), Neel Nanda (Google DeepMind)
Abstract

Activation Patching is a method of directly computing causal attributions of behavior to model components. However, applying it exhaustively requires a sweep with cost scaling linearly in the number of model components, which can be prohibitively expensive for SoTA Large Language Models (LLMs). We investigate Attribution Patching (AtP) (Nanda, 2022), a fast gradient-based approximation to Activation Patching and find two classes of failure modes of AtP which lead to significant false negatives.
We propose a variant of AtP called AtP*, with two changes to address these failure modes while retaining scalability. We present the first systematic study of AtP and alternative methods for faster activation patching and show that AtP significantly outperforms all other investigated methods, with AtP* providing further significant improvement. Finally, we provide a method to bound the probability of remaining false negatives of AtP* estimates.

1 Introduction

As LLMs become ubiquitous and integrated into numerous digital applications, it’s an increasingly pressing research problem to understand the internal mechanisms that underlie their behaviour – this is the problem of mechanistic interpretability. A fundamental subproblem is to causally attribute particular behaviours to individual parts of the transformer forward pass, corresponding to specific components (such as attention heads, neurons, layer contributions, or residual streams), often at specific positions in the input token sequence. This is important because in numerous case studies of complex behaviours, they are found to be driven by sparse subgraphs within the model (Olsson et al., 2022; Wang et al., 2022; Meng et al., 2023).

A classic form of causal attribution uses zero-ablation, or knock-out, where a component is deleted and we see if this negatively affects a model’s output – a negative effect implies the component was causally important. More recent work has generalised this to replacing a component’s activations with samples from some baseline distribution (with zero-ablation being a special case where activations are resampled to be zero). We focus on the popular and widely used method of Activation Patching (also known as causal mediation analysis) (Geiger et al., 2022; Meng et al., 2023; Chan et al., 2022) where the baseline distribution is a component’s activations on some corrupted input, such as an alternate string with a different answer (Pearl, 2001; Robins and Greenland, 1992).

Given a causal attribution method, it is common to sweep across all model components, directly evaluating the effect of intervening on each of them via resampling (Meng et al., 2023). However, when working with SoTA models it can be expensive to attribute behaviour, especially to small components (e.g. heads or neurons): each intervention requires a separate forward pass, so the number of forward passes can easily climb into the millions or billions. For example, on a prompt of length 1024, there are $2.7 \cdot 10^{9}$ neuron nodes in Chinchilla 70B (Hoffmann et al., 2022).

We propose to accelerate this process by using Attribution Patching (AtP) (Nanda, 2022), a faster, approximate, causal attribution method, as a prefiltering step: after running AtP, we iterate through the nodes in decreasing order of absolute value of the AtP estimate, then use Activation Patching to more reliably evaluate these nodes and filter out false positives – we call this verification. We typically care about a small set of top contributing nodes, so verification is far cheaper than iterating over all nodes.
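The prefilter-then-verify loop described above can be sketched as follows, counting cost in forward passes. This is a hedged illustration rather than the authors' code: the function name, the dict inputs, and the assumption that each verification costs exactly one forward pass are ours, and a real verification step would run Activation Patching on the node instead of looking it up in a known top-k set.

```python
def verification_cost(atp_scores, true_effects, k):
    """Forward passes needed to verify the true top-k nodes when nodes are
    verified in decreasing order of |AtP estimate| (hypothetical helper).

    atp_scores, true_effects: dicts mapping node name -> signed effect.
    Assumes each verification costs exactly one forward pass.
    """
    # Target set: the k nodes with the largest true contribution |I(n)|.
    by_truth = sorted(true_effects, key=lambda n: abs(true_effects[n]), reverse=True)
    targets = set(by_truth[:k])
    # Prefilter: rank every node by the magnitude of its AtP estimate.
    order = sorted(atp_scores, key=lambda n: abs(atp_scores[n]), reverse=True)
    found = 0
    for passes, node in enumerate(order, start=1):
        # Verification step: in practice, run Activation Patching on `node`
        # here; this sketch just checks membership in the known top-k set.
        if node in targets:
            found += 1
            if found == k:
                return passes
    return len(order)


atp = {"a": 0.9, "b": 0.1, "c": 0.5, "d": 0.05}    # made-up AtP estimates
true = {"a": 1.0, "b": 0.8, "c": 0.02, "d": 0.01}  # made-up true effects
print(verification_cost(atp, true, k=2))   # 3: node "c" is a wasted verification
print(verification_cost(true, true, k=2))  # 2: an oracle ordering
```

A good prefilter keeps this count close to k, the oracle cost; false negatives push true top nodes far down the ordering and inflate it.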

Our contributions:
  • We investigate the performance of AtP, finding two classes of failure modes which produce false negatives. We propose a variant of AtP called AtP*, with two changes to address these failure modes while retaining scalability:

    • When patching queries and keys, recomputing the attention softmax and using a gradient-based approximation from then on, as gradients are a poor approximation to saturated attention.

    • Using dropout on the backwards pass to fix brittle false negatives, where significant positive and negative effects cancel out.

  • We introduce several alternative methods to approximate Activation Patching as baselines to AtP which outperform brute force Activation Patching.

  • We present the first systematic study of AtP and these alternatives and show that AtP significantly outperforms all other investigated methods, with AtP* providing further significant improvement.

  • To estimate the residual error of AtP* and statistically bound the sizes of any remaining false negatives, we provide a diagnostic method based on using AtP to filter out high-impact nodes and then patching random subsets of the remainder. Good diagnostics mean that practitioners can still gauge whether AtP is reliable in relevant domains without the cost of exhaustive verification.

Finally, we provide some guidance in Section 5.4 on how to successfully perform causal attribution in practice and what attribution methods are likely to be useful and under what circumstances.

[Figure 1 panels: (a) MLP neurons, on CITY-PP; (b) Attention nodes, on IOI-PP.]
Figure 1: Costs of finding the most causally important nodes in Pythia-12B using different methods, on sample prompt pairs (see Table 1). The shading indicates geometric standard deviation. Cost is measured in forward passes, thus each point's y-coordinate gives the number of forward passes required to find the top $x$ nodes. Note that each node must be verified, thus $y \geq x$, so all lines are above the diagonal, and an oracle for the verification order would produce the diagonal line. For a detailed description see Section 4.3.
[Figure 2 panels: (a) MLP neurons, on CITY-PP; (b) Attention nodes, on IOI-PP.]
Figure 2: Relative costs of methods across models, on sample prompt pairs. The costs are relative to having an oracle, which would verify nodes in decreasing order of true contribution size. Costs are aggregated using an inverse-rank-weighted geometric mean. This means they correspond to the area above the diagonal for each curve in Figure 1 and are relative to the area under the dotted (oracle) line. See Section 4.2 for more details on this metric. Note that GradDrop (the difference between AtP+QKfix and AtP*) comes with a noticeable upfront cost and so looks worse in this comparison, while still helping avoid false negatives as shown in Figure 1.

2 Background

2.1 Problem Statement

Our goal is to identify the contributions to model behavior by individual model components. We first formalize model components, then formalize model behaviour, and finally state the contribution problem in causal language. While we state the formalism in terms of a decoder-only transformer language model (Vaswani et al., 2017; Radford et al., 2018), and conduct all our experiments on models of that class, the formalism is also straightforwardly applicable to other model classes.

Model components.

We are given a model $\mathcal{M}: X \rightarrow \mathbb{R}^{V}$ that maps a prompt (token sequence) $x \in X := \{1, \ldots, V\}^{T}$ to output logits over a set of $V$ tokens, aiming to predict the next token in the sequence. We will view the model $\mathcal{M}$ as a computational graph $(N, E)$ where the node set $N$ is the set of model components, and a directed edge $e = (n_1, n_2) \in E$ is present iff the output of $n_1$ is a direct input into the computation of $n_2$. We will use $n(x)$ to represent the activation (intermediate computation result) of $n$ when computing $\mathcal{M}(x)$.

The choice of $N$ determines how fine-grained the attribution will be. For example, for transformer models, we could have a relatively coarse-grained attribution where each layer is considered a single node. In this paper we will primarily consider more fine-grained attributions that are more expensive to compute (see Section 4 for details); we revisit this issue in Section 5.

Model behaviour.

Following past work (Geiger et al., 2022; Chan et al., 2022; Wang et al., 2022), we assume a distribution $\mathcal{D}$ over pairs of inputs $x^{\text{clean}}, x^{\text{noise}}$, where $x^{\text{clean}}$ is a prompt on which the behaviour occurs, and $x^{\text{noise}}$ is a reference prompt which we use as a source of noise to intervene with¹. We are also given a metric² $\mathcal{L}: \mathbb{R}^{V} \rightarrow \mathbb{R}$, which quantifies the behaviour of interest.

¹ This precludes interventions which use activation values that are never actually realized, such as zero-ablation or mean ablation. An alternative formulation via distributions of activation values is also possible.

² Common metrics in language models are next token prediction loss, difference in log prob between a correct and incorrect next token, probability of the correct next token, etc.

Contribution of a component.

Similarly to the work referenced above, we define the contribution $c(n)$ of a node $n$ to the model's behaviour as the counterfactual absolute³ expected impact of replacing that node on the clean prompt with its value on the reference prompt $x^{\text{noise}}$.

³ The sign of the impact may be of interest, but in this work we'll focus on the magnitude, as a measure of causal importance.

Using do-calculus notation (Pearl, 2000) this can be expressed as $c(n) := |\mathcal{I}(n)|$, where

$$\mathcal{I}(n) := \mathbb{E}_{(x^{\text{clean}}, x^{\text{noise}}) \sim \mathcal{D}}\left[\mathcal{I}(n; x^{\text{clean}}, x^{\text{noise}})\right], \tag{1}$$

where we define the intervention effect $\mathcal{I}$ for $x^{\text{clean}}, x^{\text{noise}}$ as

$$\mathcal{I}(n; x^{\text{clean}}, x^{\text{noise}}) := \mathcal{L}\big(\mathcal{M}(x^{\text{clean}} \mid \operatorname{do}(n \leftarrow n(x^{\text{noise}})))\big) - \mathcal{L}\big(\mathcal{M}(x^{\text{clean}})\big). \tag{2}$$

Note that the need to average the effect across a distribution adds a potentially large multiplicative factor to the cost of computing $c(n)$, further motivating this work.

We can also intervene on a set of nodes $\eta = \{n_i\}$. To do so, we overwrite the values of all nodes in $\eta$ with their values from a reference prompt. Abusing notation, we write $\eta(x)$ for the set of activations of the nodes in $\eta$ when computing $\mathcal{M}(x)$:

$$\mathcal{I}(\eta; x^{\text{clean}}, x^{\text{noise}}) := \mathcal{L}\big(\mathcal{M}(x^{\text{clean}} \mid \operatorname{do}(\eta \leftarrow \eta(x^{\text{noise}})))\big) - \mathcal{L}\big(\mathcal{M}(x^{\text{clean}})\big) \tag{3}$$
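To make Equations 1 to 3 concrete, the sketch below runs activation patching on a hypothetical two-node toy model in plain Python. The model, the node names `h1` and `h2`, and the metric are invented for illustration; `forward(x, patch)` plays the role of $\mathcal{M}(x \mid \operatorname{do}(\cdot))$, and the scalar it returns stands in for $\mathcal{L}$.

```python
import math


def forward(x, patch=None):
    """Hypothetical two-node toy 'model'. Returns (metric, activations).

    `patch` maps node names to values that overwrite those nodes, i.e. the
    do(n <- value) intervention; downstream nodes are recomputed from it.
    """
    patch = patch or {}
    acts = {}
    acts["h1"] = patch.get("h1", math.tanh(2.0 * x[0]))
    acts["h2"] = patch.get("h2", max(0.0, x[1] + acts["h1"]))
    metric = acts["h1"] + 2.0 * acts["h2"]  # stands in for L(M(x))
    return metric, acts


def intervention_effect(nodes, x_clean, x_noise):
    """Eqs. (2)/(3): patch the noise activations of `nodes` into the clean run."""
    _, noise_acts = forward(x_noise)
    clean_metric, _ = forward(x_clean)
    patched_metric, _ = forward(x_clean, {n: noise_acts[n] for n in nodes})
    return patched_metric - clean_metric


def contribution(nodes, pairs):
    """Eq. (1) and c(n) = |I(n)|: average the signed effect, then take |.|."""
    effects = [intervention_effect(nodes, xc, xn) for xc, xn in pairs]
    return abs(sum(effects) / len(effects))


x_clean, x_noise = (1.0, 0.5), (-1.0, 0.5)
print(intervention_effect(["h1"], x_clean, x_noise))        # equals -(1 + 4*tanh(2))
print(intervention_effect(["h2"], x_clean, x_noise))        # equals -(1 + 2*tanh(2))
print(intervention_effect(["h1", "h2"], x_clean, x_noise))  # same as the h1 effect here
```

The set effect is not the sum of the two single-node effects (patching `h1` already drives `h2` to zero), a reminder that effects of interventions on sets need not be additive.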

We note that it is also valid to define contribution as the expected impact of replacing a node on the reference prompt with its value on the clean prompt, also known as denoising or knock-in. We follow Chan et al. (2022) and Wang et al. (2022) in using noising; however, denoising is also widely used in the literature (Meng et al., 2023; Lieberum et al., 2023). We briefly consider how this choice affects AtP in Section 5.2.

2.2 Attribution Patching

On state-of-the-art models, computing $c(n)$ for all $n$ can be prohibitively expensive as there may be billions or more nodes. Furthermore, computing this value precisely requires evaluating it on all prompt pairs, so the runtime cost of Equation 1 for each $n$ scales with the size of the support of $\mathcal{D}$.

We thus turn to a fast approximation of Equation 1. As suggested by Nanda (2022), Figurnov et al. (2016), and Molchanov et al. (2017), we can make a first-order Taylor expansion of $\mathcal{I}(n; x^{\text{clean}}, x^{\text{noise}})$ around $n(x^{\text{noise}}) \approx n(x^{\text{clean}})$:

$$\hat{\mathcal{I}}_{\text{AtP}}(n; x^{\text{clean}}, x^{\text{noise}}) := \left(n(x^{\text{noise}}) - n(x^{\text{clean}})\right)^{\intercal} \frac{\partial \mathcal{L}(\mathcal{M}(x^{\text{clean}}))}{\partial n}\Big|_{n = n(x^{\text{clean}})} \tag{4}$$

Then, similarly to Syed et al. (2023), we apply this to a distribution by taking the absolute value inside the expectation in Equation 1 rather than outside; this reduces the chance that positive and negative effects on different prompt pairs cancel and erroneously yield a much smaller estimate. (We briefly explore the amount of cancellation behaviour in the true effect distribution in Section B.2.) As a result, we get an estimate

$$\hat{c}_{\text{AtP}}(n) := \mathbb{E}_{x^{\text{clean}}, x^{\text{noise}}}\left[\left|\hat{\mathcal{I}}_{\text{AtP}}(n; x^{\text{clean}}, x^{\text{noise}})\right|\right]. \tag{5}$$

This procedure is also called Attribution Patching (Nanda, 2022) or AtP. AtP requires two forward passes and one backward pass to compute an estimate score for all nodes on a given prompt pair, and so provides a very significant speedup over brute force activation patching.
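On a hypothetical toy model, Equations 4 and 5 reduce to a few lines. In this sketch (not the authors' code) a hand-derived `backward` stands in for the single backward pass an autodiff framework would provide; the model, node names, and prompt pairs are invented. It shows both that AtP is exact for a node feeding the metric linearly and that taking the absolute value inside the expectation prevents cancellation across prompt pairs.

```python
import math


def forward(x):
    """Hypothetical toy model: returns L(M(x)) and its intermediate activations."""
    h1 = math.tanh(2.0 * x[0])
    pre2 = x[1] + h1
    h2 = max(0.0, pre2)
    return h1 + 2.0 * h2, {"h1": h1, "h2": h2, "pre2": pre2}


def backward(acts):
    """Hand-derived dL/dn at the given activations (stands in for autodiff)."""
    d_h2 = 2.0
    d_h1 = 1.0 + d_h2 * (1.0 if acts["pre2"] > 0 else 0.0)  # direct path + via ReLU
    return {"h1": d_h1, "h2": d_h2}


def atp_all(x_clean, x_noise):
    """Eq. (4) for every node at once: two forward passes and one backward pass."""
    _, clean = forward(x_clean)
    _, noise = forward(x_noise)
    grads = backward(clean)
    return {n: (noise[n] - clean[n]) * g for n, g in grads.items()}


def atp_contribution(node, pairs):
    """Eq. (5): the absolute value is taken inside the expectation."""
    return sum(abs(atp_all(xc, xn)[node]) for xc, xn in pairs) / len(pairs)


pairs = [((1.0, 0.5), (-1.0, 0.5)), ((-1.0, 0.5), (1.0, 0.5))]
est = atp_all(*pairs[0])
# h2 feeds the metric linearly, so its estimate equals the true patching effect,
# -(1 + 2*tanh(2)); h1 passes through a ReLU, so its linear estimate (-6*tanh(2))
# overshoots the true effect -(1 + 4*tanh(2)).
```

On these two mirrored pairs the signed AtP estimates for `h2` average to zero, while Equation 5 still reports a contribution of $1 + 2\tanh 2$.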

3 Methods

We now describe some failure modes of AtP and address them, yielding an improved method AtP*. We then discuss some alternative methods for estimating c(n)𝑐𝑛c(n)italic_c ( italic_n ), to put AtP(*)’s performance in context. Finally we discuss how to combine Subsampling, one such alternative method described in Section 3.3, and AtP* to give a diagnostic to statistically test whether AtP* may have missed important false negatives.

3.1 AtP improvements

We identify two common classes of false negatives occurring when using AtP.

The first failure mode occurs when the preactivation on $x^{\text{clean}}$ is in a flat region of the activation function (e.g. produces a saturated attention weight), but the preactivation on $x^{\text{noise}}$ is not in that region. As is apparent from Equation 4, AtP uses a linear approximation to the ground truth in Equation 1, so if the non-linear function is badly approximated by the local gradient, AtP ceases to be accurate. See Figure 3 for an illustration, and Figure 4, which denotes in color the maximal difference in attention observed between prompt pairs, suggesting that this failure mode occurs in practice.

Figure 3: A linear approximation to the attention probability is a particularly poor approximation in cases where one or both of the endpoints are in a saturated region of the softmax. Note that when varying only a single key, the softmax becomes a sigmoid of the dot product of that key and the query.
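The saturation problem can be reproduced numerically with a single softmax. In this sketch (our illustration, with made-up scores) only the score of key 0, i.e. its query-key dot product, is patched, and we linearize the attention weight in that score rather than in the key vector itself, which is a simplification of what AtP does. It also checks the identity from the Figure 3 caption: varying one score, the softmax weight is a sigmoid.

```python
import math


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return [e / z for e in exps]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


# Hypothetical attention scores (query-key dot products) for one query.
# Key 0 dominates on the clean prompt, so its attention weight is saturated.
clean_scores = [10.0, 0.0, 0.0]
noise_scores = [0.0, 0.0, 0.0]  # patching key 0 removes its advantage

p_clean = softmax(clean_scores)[0]
grad = p_clean * (1.0 - p_clean)  # d p_0 / d s_0 for a softmax
p_linear = p_clean + grad * (noise_scores[0] - clean_scores[0])  # first-order estimate
p_true = softmax(noise_scores)[0]  # what activation patching would give

# Varying only score 0, p_0 is a sigmoid of s_0 minus the log-sum-exp of the rest.
p_via_sigmoid = sigmoid(clean_scores[0] - math.log(sum(math.exp(s) for s in clean_scores[1:])))

print(f"clean {p_clean:.4f}, linear {p_linear:.4f}, true {p_true:.4f}")
# clean 0.9999, linear 0.9990, true 0.3333
```

The gradient at the saturated endpoint is about $10^{-4}$, so the linear estimate predicts almost no change while the true weight drops by two thirds.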

Another, unrelated failure mode occurs due to cancellation between direct and indirect effects: roughly, if the total effect (on some prompt pair) is a sum of direct and indirect effects (Pearl, 2001), $\mathcal{I}(n) = \mathcal{I}^{\text{direct}}(n) + \mathcal{I}^{\text{indirect}}(n)$, and these are close to cancelling, then a small multiplicative approximation error in $\hat{\mathcal{I}}_{\text{AtP}}^{\text{indirect}}(n)$, due to non-linearities such as GELU and softmax, can accidentally cause $|\hat{\mathcal{I}}_{\text{AtP}}^{\text{direct}}(n) + \hat{\mathcal{I}}_{\text{AtP}}^{\text{indirect}}(n)|$ to be orders of magnitude smaller than $|\mathcal{I}(n)|$.
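A worked example with hypothetical numbers shows how little error is needed. Suppose AtP gets the direct effect exactly but overestimates the indirect effect by 1%; the sketch below is plain arithmetic, not a model of any particular network.

```python
def total_effect_estimate(direct, indirect, rel_error):
    """Hypothetical: AtP gets the direct effect exactly but the indirect effect
    with a small multiplicative error (e.g. from GELU or softmax curvature)."""
    true_total = direct + indirect
    atp_total = direct + indirect * (1.0 + rel_error)
    return true_total, atp_total


true_total, atp_total = total_effect_estimate(direct=1.0, indirect=-0.99, rel_error=0.01)
print(true_total / atp_total)  # ~100: a 1% error shrinks the estimate 100-fold
```

The true total is 0.01 but the estimate is 0.0001, so the node drops far down the AtP ranking and becomes a false negative.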

3.1.1 False negatives from attention saturation

AtP relies on the gradient at each activation being reflective of the true behaviour of the function with respect to intervention at that activation. In some cases, though, a node may immediately feed into a non-linearity whose effect may not be adequately predicted by the gradient; for example, attention key and query nodes feeding into the attention softmax non-linearity. To showcase this, we plot the true rank of each node’s effect against its rank assigned by AtP in Figure 4 (left). The plot shows that there are many pronounced false negatives (below the dashed line), especially among keys and queries.

Normal activation patching for queries and keys involves changing a query or key and then re-running the rest of the model, keeping all else the same. AtP takes a linear approximation to the entire rest of the model rather than re-running it. We propose explicitly re-computing the first step of the rest of the model, i.e. the attention softmax, and then taking a linear approximation to the rest. Formally, for attention key and query nodes, instead of using the gradient on those nodes directly, we take the difference in attention weight caused by that key or query, multiplied by the gradient on the attention weights themselves. This requires finding the change in attention weights from each key and query patch, but that can be done efficiently using (for all keys and queries in total) less compute than two transformer forward passes. This correction avoids the problem of saturated attention, while otherwise retaining the performance of AtP.

Queries

For the queries, we can easily compute the adjusted effect by running the model on $x^{\text{noise}}$ and caching the noise queries. We then run the model on $x^{\text{clean}}$ and cache the attention keys and weights. Finally, we compute the attention weights that result from combining all the keys from the $x^{\text{clean}}$ forward pass with the queries from the $x^{\text{noise}}$ forward pass. This costs approximately as much as the unperturbed attention computation of the transformer forward pass. For each query node $n$ we refer to the resulting weight vector as $\operatorname{attn}(n)_{\text{patch}}$, in contrast with the weights $\operatorname{attn}(n)(x^{\text{clean}})$ from the clean forward pass. The improved attribution estimate for $n$ is then

\begin{align}
\hat{\mathcal{I}}^{Q}_{\text{AtPfix}}(n; x^{\text{clean}}, x^{\text{noise}}) &:= \sum_{k} \hat{\mathcal{I}}_{\text{AtP}}(\operatorname{attn}(n)_{k}; x^{\text{clean}}, x^{\text{noise}}) \tag{6}\\
&= \left(\operatorname{attn}(n)_{\text{patch}} - \operatorname{attn}(n)(x^{\text{clean}})\right)^{\intercal} \frac{\partial \mathcal{L}(\mathcal{M}(x^{\text{clean}}))}{\partial \operatorname{attn}(n)}\Big|_{\operatorname{attn}(n) = \operatorname{attn}(n)(x^{\text{clean}})} \tag{7}
\end{align}
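The query fix can be sketched for a single causal attention head in plain Python. This is a minimal illustration under stated assumptions, not the authors' implementation: the helper names, the list-of-lists layout, the $1/\sqrt{d}$ score scaling, and the assumption that the gradient with respect to the attention weights (`grad_attn`) is already available from the clean backward pass are ours. For brevity it recomputes the clean weights instead of caching them.

```python
import math


def softmax(xs):
    m = max(xs)
    e = [math.exp(x - m) for x in xs]
    z = sum(e)
    return [v / z for v in e]


def attn_row(q, keys, i):
    """Causal attention weights of the query at position i over keys 0..i."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, keys[j])) for j in range(i + 1)]
    return softmax(scores)


def atp_query_fix(q_clean, q_noise, k_clean, grad_attn):
    """Eq. (7) for every query node of one head.

    grad_attn[i][j] is dL/d attn[i][j] from the clean backward pass (assumed given).
    Returns one attribution estimate per query position.
    """
    estimates = []
    for i in range(len(q_clean)):
        a_clean = attn_row(q_clean[i], k_clean, i)
        a_patch = attn_row(q_noise[i], k_clean, i)  # noise query with clean keys
        estimates.append(sum((p - c) * g for p, c, g in zip(a_patch, a_clean, grad_attn[i])))
    return estimates


# Toy head with T = 2 positions and head dimension 1 (made-up numbers).
est = atp_query_fix(
    q_clean=[[0.0], [0.0]],
    q_noise=[[0.0], [4.0]],
    k_clean=[[1.0], [-1.0]],
    grad_attn=[[1.0], [1.0, 0.0]],
)
```

Each query needs one row of attention, so the whole fix costs about one ordinary attention computation, matching the cost stated above.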
Keys

For the keys we first describe a simple but inefficient method. We again run the model on xnoisesuperscript𝑥noisex^{\text{noise}}italic_x start_POSTSUPERSCRIPT noise end_POSTSUPERSCRIPT, caching the noise keys. We also run it on xcleansuperscript𝑥cleanx^{\text{clean}}italic_x start_POSTSUPERSCRIPT clean end_POSTSUPERSCRIPT, caching the clean queries and attention probabilities. Let key nodes for a single attention head be n1k,,nTksubscriptsuperscript𝑛𝑘1subscriptsuperscript𝑛𝑘𝑇n^{k}_{1},\dots,n^{k}_{T}italic_n start_POSTSUPERSCRIPT italic_k end_POSTSUPERSCRIPT start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT , … , italic_n start_POSTSUPERSCRIPT italic_k end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_T end_POSTSUBSCRIPT and let queries(ntk)={n1q,,nTq}queriessubscriptsuperscript𝑛𝑘𝑡subscriptsuperscript𝑛𝑞1subscriptsuperscript𝑛𝑞𝑇\operatorname{queries}(n^{k}_{t})=\{n^{q}_{1},\dots,n^{q}_{T}\}roman_queries ( italic_n start_POSTSUPERSCRIPT italic_k end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT ) = { italic_n start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT , … , italic_n start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_T end_POSTSUBSCRIPT } be the set of query nodes for the same head as node ntksubscriptsuperscript𝑛𝑘𝑡n^{k}_{t}italic_n start_POSTSUPERSCRIPT italic_k end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT. We then define

attnpatcht(nq)superscriptsubscriptattnpatch𝑡superscript𝑛𝑞\displaystyle\operatorname{attn}_{\text{patch}}^{t}(n^{q})roman_attn start_POSTSUBSCRIPT patch end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_t end_POSTSUPERSCRIPT ( italic_n start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT ) :=attn(nq)(xcleando(ntkntk(xnoise)))assignabsentattnsuperscript𝑛𝑞conditionalsuperscript𝑥cleandosubscriptsuperscript𝑛𝑘𝑡subscriptsuperscript𝑛𝑘𝑡superscript𝑥noise\displaystyle:=\operatorname{attn}(n^{q})(x^{\text{clean}}\mid\operatorname{do% }(n^{k}_{t}\leftarrow n^{k}_{t}(x^{\text{noise}}))):= roman_attn ( italic_n start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT ) ( italic_x start_POSTSUPERSCRIPT clean end_POSTSUPERSCRIPT ∣ roman_do ( italic_n start_POSTSUPERSCRIPT italic_k end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT ← italic_n start_POSTSUPERSCRIPT italic_k end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT ( italic_x start_POSTSUPERSCRIPT noise end_POSTSUPERSCRIPT ) ) ) (8)
Δtattn(nq)subscriptΔ𝑡attnsuperscript𝑛𝑞\displaystyle\Delta_{t}\operatorname{attn}(n^{q})roman_Δ start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT roman_attn ( italic_n start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT ) :=attnpatcht(nq)attn(nq)(xclean)assignabsentsuperscriptsubscriptattnpatch𝑡superscript𝑛𝑞attnsuperscript𝑛𝑞superscript𝑥clean\displaystyle:=\operatorname{attn}_{\text{patch}}^{t}(n^{q})-\operatorname{% attn}(n^{q})(x^{\text{clean}}):= roman_attn start_POSTSUBSCRIPT patch end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_t end_POSTSUPERSCRIPT ( italic_n start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT ) - roman_attn ( italic_n start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT ) ( italic_x start_POSTSUPERSCRIPT clean end_POSTSUPERSCRIPT ) (9)

The improved attribution estimate for ntksubscriptsuperscript𝑛𝑘𝑡n^{k}_{t}italic_n start_POSTSUPERSCRIPT italic_k end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT is then

^AtPfixK(ntk;xclean,xnoise)superscriptsubscript^AtPfix𝐾subscriptsuperscript𝑛𝑘𝑡superscript𝑥cleansuperscript𝑥noise\displaystyle\hat{\mathcal{I}}_{\text{AtPfix}}^{K}(n^{k}_{t};x^{\text{clean}},% x^{\text{noise}})over^ start_ARG caligraphic_I end_ARG start_POSTSUBSCRIPT AtPfix end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_K end_POSTSUPERSCRIPT ( italic_n start_POSTSUPERSCRIPT italic_k end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT ; italic_x start_POSTSUPERSCRIPT clean end_POSTSUPERSCRIPT , italic_x start_POSTSUPERSCRIPT noise end_POSTSUPERSCRIPT ) :=nqqueries(ntk)Δtattn(nq)((xclean))attn(nq)|attn(nq)=attn(nq)(xclean)\displaystyle:=\sum_{n^{q}\in\operatorname{queries}(n^{k}_{t})}\Delta_{t}% \operatorname{attn}(n^{q})^{\intercal}\frac{\partial\mathcal{L}(\mathcal{M}(x^% {\text{clean}}))}{\partial\operatorname{attn}(n^{q})}\Big{|}_{\operatorname{% attn}(n^{q})=\operatorname{attn}(n^{q})(x^{\text{clean}})}:= ∑ start_POSTSUBSCRIPT italic_n start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT ∈ roman_queries ( italic_n start_POSTSUPERSCRIPT italic_k end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT ) end_POSTSUBSCRIPT roman_Δ start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT roman_attn ( italic_n start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT ) start_POSTSUPERSCRIPT ⊺ end_POSTSUPERSCRIPT divide start_ARG ∂ caligraphic_L ( caligraphic_M ( italic_x start_POSTSUPERSCRIPT clean end_POSTSUPERSCRIPT ) ) end_ARG start_ARG ∂ roman_attn ( italic_n start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT ) end_ARG | start_POSTSUBSCRIPT roman_attn ( italic_n start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT ) = roman_attn ( italic_n start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT ) ( italic_x start_POSTSUPERSCRIPT clean end_POSTSUPERSCRIPT ) end_POSTSUBSCRIPT (10)

However, the procedure we just described is costly to execute, because naively computing Equation 9 for all $T$ keys requires $\operatorname{O}(T^{3})$ flops. In Section A.2.1 we describe a more efficient variant that takes no more compute than the forward-pass attention computation itself ($\operatorname{O}(T^{2})$ flops). Since Equation 6 is also cheaper to compute than a forward pass, the full QK fix costs less than two transformer forward passes (since the latter also include MLP computations).
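The $\operatorname{O}(T^{2})$ claim can be illustrated for the key side of Equation 10. Patching a single key $n^{k}_{t}$ changes only one attention logit per query, so each query's softmax can be updated in closed form: if key $t$'s logit moves by $\Delta$ and $g$ is the attention gradient for that query, then $\Delta_{t}\operatorname{attn}(n^{q})^{\intercal} g = a_t (e^{\Delta}-1)(g_t - \sum_j a_j g_j) / (1 + a_t(e^{\Delta}-1))$. The NumPy sketch below, for a single causal attention head, checks that update against a naive loop that re-runs the softmax once per patched key. It is not necessarily the variant of Section A.2.1: the function names are hypothetical, and `G` stands in for $\partial\mathcal{L}/\partial\operatorname{attn}$ on the clean run, assumed to be available from one backward pass.

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    z = x - x.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def key_fix_naive(Q, K, K_noise, G):
    """O(T^3): recompute the whole attention pattern once per patched key."""
    T, d = Q.shape
    mask = np.tril(np.ones((T, T), dtype=bool))  # query i attends to keys j <= i
    attn = softmax(np.where(mask, Q @ K.T / np.sqrt(d), -np.inf))
    out = np.zeros(T)
    for t in range(T):
        K_patched = K.copy()
        K_patched[t] = K_noise[t]                # do(key_t <- key_t(x_noise))
        attn_p = softmax(np.where(mask, Q @ K_patched.T / np.sqrt(d), -np.inf))
        out[t] = np.sum((attn_p - attn) * G)     # sum over queries of Delta_t attn . grad
    return out

def key_fix_fast(Q, K, K_noise, G):
    """O(T^2): closed-form softmax update when a single logit per query changes."""
    T, d = Q.shape
    mask = np.tril(np.ones((T, T), dtype=bool))
    scores = Q @ K.T / np.sqrt(d)                # clean logits [T_q, T_k]
    attn = softmax(np.where(mask, scores, -np.inf))
    scores_noise = Q @ K_noise.T / np.sqrt(d)    # logit of each patched key, per query
    growth = np.exp(scores_noise - scores) - 1.0 # e^Delta - 1 (may overflow for huge shifts)
    S = np.sum(attn * G, axis=1, keepdims=True)  # sum_j a_j g_j for each query
    contrib = attn * growth * (G - S) / (1.0 + attn * growth)
    return contrib.sum(axis=0)                   # sum over queries, one value per key
```

Masked (future) keys have $a_t = 0$ and therefore contribute nothing in either version.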

For attention nodes we show the effects of applying the query and key fixes in Figure 4 (middle). We observe that the propagation of Q/K effects has a major impact on reducing the false negative rate.

Figure 4: Ranks of $c(n)$ against ranks of $\hat{c}_{\text{AtP}}(n)$, on Pythia-12B on CITY-PP. Both improvements to AtP reduce the number of false negatives (bottom-right triangle area); in this case most of the improvement comes from the QK fix. Coloration indicates the maximum absolute difference in attention probability between $x^{\text{clean}}$ and the run where a given query or key is patched. Many false negatives are keys and queries with a large maximum difference in attention probability, suggesting they are due to attention saturation as illustrated in Figure 3. Output and value nodes are colored grey as they do not contribute to the attention probability.

3.1.2 False negatives from cancellation

This form of cancellation occurs when the backpropagated gradient from indirect effects is combined with the gradient from the direct effect. We propose a way to modify the backpropagation within attribution patching to reduce this issue. If we artificially zero out the gradient at a downstream layer that contributes to the indirect effect, the cancellation is disrupted. (This is equivalent to patching in clean activations at the outputs of that layer.) We therefore propose to do this iteratively, sweeping across the layers. Any node whose effect does not route through the layer being gradient-zeroed will have its estimate unaffected.

We call this method GradDrop. For every layer $\ell \in \{1, \ldots, L\}$ in the model, GradDrop computes an AtP estimate for all nodes, where gradients on the residual contribution from $\ell$ are set to 0, including their propagation to earlier layers. This provides a different estimate for all nodes for each layer that was dropped. When dropping layer $\ell$, we call the so-modified gradient $\frac{\partial\mathcal{L}^{\ell}}{\partial n} = \frac{\partial\mathcal{L}}{\partial n}\big(\mathcal{M}(x^{\text{clean}} \mid \operatorname{do}(n^{\text{out}}_{\ell} \leftarrow n^{\text{out}}_{\ell}(x^{\text{clean}})))\big)$, where $n^{\text{out}}_{\ell}$ is the contribution of layer $\ell$ to the residual stream across all positions. Using $\frac{\partial\mathcal{L}^{\ell}}{\partial n}$ in place of $\frac{\partial\mathcal{L}}{\partial n}$ in the AtP formula produces an estimate $\hat{\mathcal{I}}_{\text{AtP+GD}_{\ell}}(n)$. The estimates are then aggregated by averaging their absolute values and scaling by $\frac{L}{L-1}$, to avoid changing the direct-effect path's contribution (which is otherwise zeroed out when dropping the layer the node is in).
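To make the gradient modification concrete, here is a minimal NumPy sketch on a hypothetical toy model: a stack of residual layers $r_k = r_{k-1} + \tanh(W_k r_{k-1})$ with a linear readout $w \cdot r_L$ standing in for $\mathcal{L}$, where each node is a whole layer output. Dropping layer $\ell$ zeroes the gradient on its residual contribution and stops anything flowing back through it; the test checks this against finite differences of the model with $\operatorname{do}(n^{\text{out}}_{\ell} \leftarrow n^{\text{out}}_{\ell}(x^{\text{clean}}))$ applied. In a real transformer one would more likely insert a stop-gradient on the layer output; the toy model and all names here are illustrative assumptions.

```python
import numpy as np

def forward(x, Ws, perturb=None, freeze=None):
    """Toy residual stack r_k = r_{k-1} + tanh(W_k r_{k-1}).

    perturb: optional (k, delta) added to layer k's output (for finite differences).
    freeze:  optional (l, value) overwriting layer l's output, i.e. do(n_l^out <- value).
    Returns the final residual stream and the list of layer outputs.
    """
    r, outs = x, []
    for k, W in enumerate(Ws):
        o = np.tanh(W @ r)
        if perturb is not None and perturb[0] == k:
            o = o + perturb[1]
        if freeze is not None and freeze[0] == k:
            o = freeze[1]
        outs.append(o)
        r = r + o
    return r, outs

def graddrop_grads(x, Ws, w, drop=None):
    """d(w . r_L)/d(out_k) for every layer k, with layer `drop` gradient-zeroed."""
    _, outs = forward(x, Ws)
    g = w.copy()                              # gradient at the final residual stream
    grads = [None] * len(Ws)
    for k in reversed(range(len(Ws))):
        if k == drop:
            grads[k] = np.zeros_like(g)       # zero the gradient on this contribution
            continue                          # only the residual path carries g back
        grads[k] = g.copy()                   # out_k enters r_k additively
        g = g + Ws[k].T @ (g * (1.0 - outs[k] ** 2))  # path through tanh(W_k r_{k-1})
    return grads
```

An AtP+GD$_\ell$ estimate for layer $k$'s output would then be the dot product of its activation difference (noise minus clean) with `graddrop_grads(x, Ws, w, drop=l)[k]`.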

$$\hat{c}_{\text{AtP+GD}}(n) := \mathbb{E}_{x^{\text{clean}}, x^{\text{noise}}}\left[\frac{1}{L-1}\sum_{\ell=1}^{L}\left|\hat{\mathcal{I}}_{\text{AtP+GD}_{\ell}}(n; x^{\text{clean}}, x^{\text{noise}})\right|\right] \qquad (11)$$
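Equation 11 amounts to a few array reductions once the per-layer estimates are available. The sketch below assumes, hypothetically, that they have been stacked into an array of shape `[L, num_prompt_pairs, num_nodes]`. The test illustrates why the $\frac{1}{L-1}$ normalization leaves a node with only a direct effect unchanged: its estimate vanishes exactly when its own layer is dropped.

```python
import numpy as np

def graddrop_aggregate(per_layer: np.ndarray) -> np.ndarray:
    """c_hat_{AtP+GD}(n) as in Eq. (11).

    per_layer[l, k, n] is the AtP estimate for node n on prompt pair k,
    computed with layer l's gradient dropped.
    """
    L = per_layer.shape[0]
    # Sum of |estimates| over dropped layers divided by L - 1
    # (equivalently: the mean over layers times L / (L - 1)).
    per_pair = np.abs(per_layer).sum(axis=0) / (L - 1)
    # Expectation over prompt pairs.
    return per_pair.mean(axis=0)
```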

Note that the forward passes required for computing $\hat{\mathcal{I}}_{\text{AtP+GD}_{\ell}}(n; x^{\text{clean}}, x^{\text{noise}})$ don't depend on $\ell$, so the extra compute needed for GradDrop is $L$ backward passes from the same intermediate activations on a clean forward pass. This is also the case with the QK fix: the corrected attributions $\hat{\mathcal{I}}_{\text{AtPfix}}$ are dot products with the attention weight gradients, so the only thing that needs to be recomputed for $\hat{\mathcal{I}}_{\text{AtPfix+GD}_{\ell}}(n)$ is the modified gradient $\frac{\partial\mathcal{L}^{\ell}}{\partial\operatorname{attn}(n)}$. Thus, computing Equation 11 takes $L$ backward passes[4] on top of the costs for AtP.

[4] This can be reduced to $(L+1)/2$ by reusing intermediate results.

We show the result of applying GradDrop on attention nodes in Figure 4 (right) and on MLP nodes in Figure 5. In Figure 5, we show the true effect magnitude rank against the AtP+GradDrop rank, while highlighting nodes which improved drastically by applying GradDrop. We give some arguments and intuitions on the benefit of GradDrop in Section A.2.2.

Direct Effect Ratio

To provide some evidence that the observed false negatives are due to cancellation, we compute the ratio between the direct effect $c^{\text{direct}}(n)$ and the total effect $c(n)$. A higher direct effect ratio indicates more cancellation. Among the most significant false negatives corrected by GradDrop in Figure 5 (highlighted), two have high direct effect ratios (5.35 and 12.2) and the third has no direct effect (ratio 0), while the median direct effect ratio is 0 if counting all nodes, or 0.77 if only counting nodes that have a direct effect. Note that the direct effect ratio is only applicable to nodes that actually have a direct connection to the output, and not, e.g., to MLP nodes at non-final token positions, since all disconnected nodes have a direct effect of 0 by definition.

Figure 5: True rank and rank of AtP estimates with and without GradDrop, using Pythia-12B on the CITY-PP distribution with NeuronNodes. GradDrop provides a significant improvement to the largest neuron false negatives (red circles) relative to Default AtP (orange crosses).

3.2 Diagnostics

Despite the improvements we have proposed in Section 3.1, there is no guarantee that AtP* produces no false negatives. It is therefore desirable to obtain an upper confidence bound on the effect size of nodes that might be missed by AtP*, i.e. nodes that are not in the top $K$ AtP* estimates, for some $K$. Let the top $K$ nodes be $\text{Top}^{K}_{AtP*}$. It so happens that we can use subset sampling to obtain such a bound.

As described in Algorithm 1 and Section 3.3, the subset sampling algorithm returns summary statistics $\bar{i}^{n}_{\pm}$, $s^{n}_{\pm}$ and $\text{count}^{n}_{\pm}$ for each node $n$: the average effect size $\bar{i}^{n}_{\pm}$ of a subset conditional on the node being contained in that subset ($+$) or not ($-$), the sample standard deviations $s^{n}_{\pm}$, and the sample sizes $\text{count}^{n}_{\pm}$. Given these, consider a null hypothesis[5] $H_{0}^{n}$ that $|\mathcal{I}(n)| \geq \theta$, for some threshold $\theta$, versus the alternative hypothesis $H_{1}^{n}$ that $|\mathcal{I}(n)| < \theta$. We use a one-sided Welch's t-test[6] to test this hypothesis. The general practice with a compound null hypothesis is to select the simple sub-hypothesis that gives the greatest $p$-value, so to be conservative, the simple null hypothesis is that $\mathcal{I}(n) = \theta\operatorname{sign}(\bar{i}^{n}_{+} - \bar{i}^{n}_{-})$, giving a test statistic of $t^{n} = (\theta - |\bar{i}^{n}_{+} - \bar{i}^{n}_{-}|)/s^{n}_{\text{Welch}}$, which gives a $p$-value of $p^{n} = \mathbb{P}_{T \sim t_{\nu^{n}_{\text{Welch}}}}(T > t^{n})$.

[5] This is an unconventional form of $H_{0}$: typically a null hypothesis will say that an effect is insignificant. However, the framework of statistical hypothesis testing is based on determining whether the data let us reject the null hypothesis, and in this case the hypothesis we want to reject is the presence, rather than the absence, of a significant false negative.

[6] This relies on the populations being approximately unbiased and normally distributed, and not skewed. This tended to be true on inspection, and it is what the additivity assumption (see Section 3.3) predicts for a single prompt pair; however, a nonparametric bootstrap test may be more reliable, at the cost of additional compute.
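The per-node test can be sketched with the Python standard library alone. The sketch assumes the usual Welch standard error $s^{n}_{\text{Welch}} = \sqrt{(s^{n}_{+})^2/\text{count}^{n}_{+} + (s^{n}_{-})^2/\text{count}^{n}_{-}}$ and the Welch-Satterthwaite degrees of freedom for $\nu^{n}_{\text{Welch}}$, neither of which is spelled out above, and it evaluates the Student-$t$ tail by Simpson's rule purely to stay dependency-free. A library routine such as `scipy.stats.t.sf` would be more accurate for very small $p$-values; the function names are hypothetical.

```python
import math

def t_sf(t: float, nu: float, steps: int = 4000) -> float:
    """P(T > t) for a Student-t variable with nu degrees of freedom.

    Integrates the density from 0 to |t| with Simpson's rule (steps must be even).
    """
    log_norm = math.lgamma((nu + 1) / 2) - math.lgamma(nu / 2) - 0.5 * math.log(nu * math.pi)

    def pdf(x: float) -> float:
        return math.exp(log_norm - (nu + 1) / 2 * math.log1p(x * x / nu))

    a = abs(t)
    h = a / steps
    inner = sum((4 if j % 2 else 2) * pdf(j * h) for j in range(1, steps))
    area = (pdf(0.0) + pdf(a) + inner) * h / 3   # P(0 < T < |t|)
    tail = 0.5 - area if t >= 0 else 0.5 + area
    return min(1.0, max(0.0, tail))              # clamp tiny quadrature error

def false_negative_pvalue(mean_plus, mean_minus, s_plus, s_minus, n_plus, n_minus, theta):
    """p-value for H0^n: |I(n)| >= theta, via a one-sided Welch t-test."""
    v_plus, v_minus = s_plus ** 2 / n_plus, s_minus ** 2 / n_minus
    s_welch = math.sqrt(v_plus + v_minus)
    # Welch-Satterthwaite degrees of freedom (standard formula, assumed here).
    nu = (v_plus + v_minus) ** 2 / (v_plus ** 2 / (n_plus - 1) + v_minus ** 2 / (n_minus - 1))
    t = (theta - abs(mean_plus - mean_minus)) / s_welch
    return t_sf(t, nu)
```

A small $p$-value means the data are inconsistent with node $n$ having an effect of magnitude at least $\theta$.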

To get a combined conclusion across all nodes in $N \setminus \text{Top}^{K}_{AtP*}$, consider the hypothesis $H_{0} = \bigvee_{n \in N \setminus \text{Top}^{K}_{AtP*}} H_{0}^{n}$ that any of those nodes has true effect $|\mathcal{I}(n)| \geq \theta$. Since this is also a compound null hypothesis, $\max_{n} p^{n}$ is the corresponding $p$-value. Then, to find an upper confidence bound with specified confidence level $1-p$, we invert this procedure to find the lowest $\theta$ for which we still have at least that level of confidence. We repeat this for various settings of the sample size $m$ in Algorithm 1. The exact algorithm is described in Section A.3.
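Because each $p^{n}(\theta)$ decreases as $\theta$ grows, the inversion has a closed form: $p^{n}(\theta) \le 1 - \text{confidence}$ exactly when $\theta \ge |\bar{i}^{n}_{+} - \bar{i}^{n}_{-}| + s^{n}_{\text{Welch}}\, q$, where $q$ is the confidence-level quantile of the relevant $t$ distribution, so the bound is the maximum of that expression over the nodes outside $\text{Top}^{K}_{AtP*}$. The sketch below is one way to compute it, not necessarily the procedure of Section A.3. To stay within the standard library it substitutes a standard normal quantile for the Student-$t$ one by default, which is reasonable only when the Welch degrees of freedom are large and slightly anti-conservative otherwise; passing a $t$ quantile for the smallest $\nu$ among the nodes would be conservative.

```python
from statistics import NormalDist
from typing import Callable, Sequence

def false_negative_bound(
    diffs: Sequence[float],     # |i_bar_plus - i_bar_minus| for nodes outside the top K
    s_welch: Sequence[float],   # Welch standard errors for the same nodes
    confidence: float,          # 1 - p, e.g. 0.99
    quantile: Callable[[float], float] = NormalDist().inv_cdf,  # normal approximation
) -> float:
    """Smallest theta such that max_n p^n(theta) <= 1 - confidence.

    p^n(theta) = P(T > (theta - diff_n) / s_n) is decreasing in theta, so
    p^n(theta) <= alpha  <=>  theta >= diff_n + s_n * q(1 - alpha).
    The union hypothesis H0 is rejected once every node's test rejects.
    """
    q = quantile(confidence)
    return max(abs(d) + s * q for d, s in zip(diffs, s_welch))
```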

In Figure 6, we report the upper confidence bounds at confidence levels 90%, 99%, and 99.9% from running Algorithm 1 with a given $m$ (right subplots), as well as the number of nodes that have a true contribution $c(n)$ greater than $\theta$ (left subplots).

(a) IOI-PP
(b) IOI
Figure 6: Upper confidence bounds on effect magnitudes of false negatives (i.e. nodes not in the top 1024 nodes according to AtP*), at 3 confidence levels, varying the sampling budget. On the left we show in red the true effect of the nodes ranked highest by AtP*. We also show in orange the true effect magnitude at various ranks of the remaining nodes.
We can see that the bound for (a) finds the true biggest false negative reasonably early, while for (b), where there is no large false negative, we progressively keep gaining confidence with more data.
Note that the costs involved per prompt pair are substantially different between the subplots, and in particular this diagnostic for the distributional case (b) is substantially cheaper to compute than the verification cost of 1024 samples per prompt pair.

3.3 Baselines

Iterative

The most straightforward method is to directly do Activation Patching to find the true effect c(n)𝑐𝑛c(n)italic_c ( italic_n ) of each node, in some uninformed random order. This is necessarily inefficient.

However, if we are scaling to a distribution, it is possible to improve on this by alternating between two phases: (i) for each unverified node, pick a not-yet-measured prompt pair on which to patch it; (ii) rank the not-yet-verified nodes by their average observed patch effect magnitudes, take the top $|N|/|\mathcal{D}|$ nodes, and verify them. This balances the computational expenditure between the two phases, and allows us to find large nodes sooner, at least as long as their large effect shows up on many prompt pairs.
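The alternating schedule can be simulated in a few lines. In this sketch, `patch(node, pair)` is a hypothetical stand-in for one activation-patching forward pass, verifying a node is modeled as patching it on all of its remaining prompt pairs, and cost is counted in forward passes. Every node ends up patched on every prompt pair, so the total cost matches the uninformed baseline; what changes is how early the large nodes get verified.

```python
import random
from typing import Callable, List, Tuple

def iterative_baseline(
    num_nodes: int,
    num_pairs: int,
    patch: Callable[[int, int], float],  # hypothetical: patch(node, pair) -> effect
    seed: int = 0,
) -> List[Tuple[int, int]]:
    """Return (node, forward passes spent so far) in the order nodes are verified."""
    rng = random.Random(seed)
    # For each node, its not-yet-measured prompt pairs in random order.
    pending = {n: rng.sample(range(num_pairs), num_pairs) for n in range(num_nodes)}
    observed = {n: [] for n in range(num_nodes)}   # |effect| values seen so far
    unverified = set(range(num_nodes))
    batch = max(1, num_nodes // num_pairs)         # top |N| / |D| nodes per round
    cost, verified = 0, []
    while unverified:
        # Phase (i): patch each unverified node on one new prompt pair.
        for n in sorted(unverified):
            if pending[n]:
                observed[n].append(abs(patch(n, pending[n].pop())))
                cost += 1
        # Phase (ii): rank by mean observed magnitude (ties by id), verify the top batch.
        ranked = sorted(unverified, key=lambda n: (-sum(observed[n]) / len(observed[n]), n))
        for n in ranked[:batch]:
            cost += len(pending[n])                # patch on all remaining pairs
            pending[n].clear()
            unverified.remove(n)
            verified.append((n, cost))
    return verified
```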

Our remaining baseline methods rely on an approximate node additivity assumption: that when intervening on a set of nodes $\eta$, the measured effect $\mathcal{I}(\eta; x^{\text{clean}}, x^{\text{noise}})$ is approximately equal to $\sum_{n \in \eta}\mathcal{I}(n; x^{\text{clean}}, x^{\text{noise}})$.

Subsampling

Under the approximate node additivity assumption, we can construct an approximately unbiased estimator of $c(n)$. We select the sets $\eta_{k}$ to contain each node independently with some probability $p$, and additionally sample prompt pairs $x^{\text{clean}}_{k}, x^{\text{noise}}_{k} \sim \mathcal{D}$. For any node $n$ and sets of nodes $\eta_{k} \subset N$, let $\eta^{+}(n)$ be the collection of all those that contain $n$, and $\eta^{-}(n)$ the collection of those that don't contain $n$; we'll write these node sets as $\eta^{+}_{k}(n)$ and $\eta^{-}_{k}(n)$, and the corresponding prompt pairs as ${x^{\text{clean}}_{k}}^{+}(n), {x^{\text{noise}}_{k}}^{+}(n)$ and ${x^{\text{clean}}_{k}}^{-}(n), {x^{\text{noise}}_{k}}^{-}(n)$. The subsampling (or subset sampling) estimator is then given by

$$\hat{\mathcal{I}}_{\text{SS}}(n) := \frac{1}{|\eta^{+}(n)|}\sum_{k=1}^{|\eta^{+}(n)|}\mathcal{I}\big(\eta^{+}_{k}(n); {x^{\text{clean}}_{k}}^{+}(n), {x^{\text{noise}}_{k}}^{+}(n)\big) - \frac{1}{|\eta^{-}(n)|}\sum_{k=1}^{|\eta^{-}(n)|}\mathcal{I}\big(\eta^{-}_{k}(n); {x^{\text{clean}}_{k}}^{-}(n), {x^{\text{noise}}_{k}}^{-}(n)\big) \qquad (12)$$

$$\hat{c}_{\text{SS}}(n) := |\hat{\mathcal{I}}_{\text{SS}}(n)| \qquad (13)$$

The estimator $\hat{\mathcal{I}}_{\text{SS}}(n)$ is unbiased if there are no interaction effects, and has a small bias proportional to $p$ under a simple interaction model (see Section A.1.1 for proof).

In practice, we compute all the estimates $\hat{c}_{\text{SS}}(n)$ by sampling a binary mask over all nodes from i.i.d. $\text{Bernoulli}^{|N|}(p)$; each binary mask can be identified with a node set $\eta$. In Algorithm 1, we describe how to compute summary statistics related to Equation 13 efficiently for all nodes $n \in N$. The means $\bar{i}^{\pm}$ are enough to compute $\hat{c}_{\text{SS}}(n)$, while the other summary statistics are used to bound the magnitude of a false negative (cf. Section 3.2). (Note that $\text{count}^{\pm}_{n}$ is just an alternate notation for $|\eta^{\pm}(n)|$.)

Algorithm 1 Subsampling
Input: $p \in (0,1)$, model $\mathcal{M}$, metric $\mathcal{L}$, prompt pair distribution $\mathcal{D}$, number of samples $m$
$\text{count}^{\pm}, \text{runSum}^{\pm}, \text{runSquaredSum}^{\pm} \leftarrow 0^{|N|}$ ▷ Initialize counts and running sums to zero vectors
for $j \leftarrow 1$ to $m$ do
     $x^{\text{clean}}, x^{\text{noise}} \sim \mathcal{D}$
     $\text{mask}^{+} \leftarrow \text{Bernoulli}^{|N|}(p)$ ▷ Sample binary mask for patching
     $\text{mask}^{-} \leftarrow 1 - \text{mask}^{+}$
     $i \leftarrow \mathcal{I}(\{n \in N : \text{mask}^{+}_{n} = 1\}; x^{\text{clean}}, x^{\text{noise}})$ ▷ $\eta^{+} = \{n \in N : \text{mask}^{+}_{n} = 1\}$
     $\text{count}^{\pm} \leftarrow \text{count}^{\pm} + \text{mask}^{\pm}$
     $\text{runSum}^{\pm} \leftarrow \text{runSum}^{\pm} + i \cdot \text{mask}^{\pm}$
     $\text{runSquaredSum}^{\pm} \leftarrow \text{runSquaredSum}^{\pm} + i^{2} \cdot \text{mask}^{\pm}$
$\bar{i}^{\pm} \leftarrow \text{runSum}^{\pm} / \text{count}^{\pm}$
$s^{\pm} \leftarrow \sqrt{(\text{runSquaredSum}^{\pm} - \text{count}^{\pm}(\bar{i}^{\pm})^{2}) / (\text{count}^{\pm} - 1)}$
return $\text{count}^{\pm}, \bar{i}^{\pm}, s^{\pm}$ ▷ If diagnostics are not required, $\bar{i}^{\pm}$ is sufficient.
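For reference, here is a NumPy sketch of Algorithm 1. `sample_pair` and `patch_effect` are hypothetical callables standing in for drawing $x^{\text{clean}}, x^{\text{noise}} \sim \mathcal{D}$ and for one patched forward pass returning $\mathcal{I}(\eta^{+}; x^{\text{clean}}, x^{\text{noise}})$. Row 0 of each returned array holds the $+$ statistics and row 1 the $-$ statistics. The sketch assumes every node lands in each group at least twice, which holds with high probability for moderate $m$.

```python
import numpy as np
from typing import Callable, Tuple

def subsample(
    num_nodes: int,
    p: float,
    m: int,
    sample_pair: Callable[[np.random.Generator], object],  # hypothetical
    patch_effect: Callable[[np.ndarray, object], float],   # hypothetical
    seed: int = 0,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    rng = np.random.default_rng(seed)
    count = np.zeros((2, num_nodes))        # row 0: node patched (+), row 1: not (-)
    run_sum = np.zeros((2, num_nodes))
    run_sq_sum = np.zeros((2, num_nodes))
    for _ in range(m):
        pair = sample_pair(rng)
        mask_plus = rng.random(num_nodes) < p           # i.i.d. Bernoulli(p) mask
        masks = np.stack([mask_plus, ~mask_plus]).astype(float)
        i = patch_effect(mask_plus, pair)                # effect of patching eta^+
        count += masks
        run_sum += i * masks
        run_sq_sum += i ** 2 * masks
    mean = run_sum / count
    # Sample variance: (sum of squares - count * mean^2) / (count - 1).
    var = (run_sq_sum - count * mean ** 2) / (count - 1)
    std = np.sqrt(np.maximum(var, 0.0))                  # guard against rounding below 0
    return count, mean, std
```

The subsampling estimate of Equation 13 is then `np.abs(mean[0] - mean[1])`.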
Blocks & Hierarchical

Instead of sampling each $\eta$ independently, we can group nodes into fixed "blocks" $\eta$ of some size, and patch each block to find its aggregated contribution $c(\eta)$; we can then traverse the nodes, starting with high-contribution blocks and proceeding from there.

There is a tradeoff in terms of the block size: using large blocks increases the compute required to traverse a high-contribution block, but using small blocks increases the compute required to finish traversing all of the blocks. We refer to the fixed block size setting as Blocks. Another way to handle this tradeoff is to add recursion: the blocks can be grouped into higher-level blocks, and so forth. We call this method Hierarchical.
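One plausible reading of Hierarchical, under the additivity assumption, is a best-first search: patch a block, then repeatedly split whichever pending block has the largest measured effect magnitude until single nodes remain, whose effects are thereby verified; Blocks corresponds roughly to a single level of splitting. This sketch is not the procedure of Section A.1.2. `patch_effect` is a hypothetical callable returning the magnitude of the effect of patching a set of nodes on one prompt pair, at the cost of one forward pass. If effects within a block cancel, a large node can surface late.

```python
import heapq
from typing import Callable, List, Sequence, Tuple

def hierarchical_search(
    nodes: Sequence[int],
    patch_effect: Callable[[List[int]], float],  # hypothetical: |effect| of patching a set
    branching: int = 2,
) -> List[Tuple[int, float, int]]:
    """Best-first traversal that always splits the pending block with the largest |effect|.

    Returns (node, measured |effect|, forward passes spent so far) in discovery order.
    """
    heap: List[Tuple[float, int, List[int]]] = []
    passes = 0
    found: List[Tuple[int, float, int]] = []

    def measure(block: List[int]) -> None:
        nonlocal passes
        passes += 1                      # one patched forward pass per block
        # Negate so the min-heap pops the largest effect first; `passes` is a
        # unique tie-breaker, so the lists themselves are never compared.
        heapq.heappush(heap, (-patch_effect(block), passes, block))

    measure(list(nodes))
    while heap:
        neg_effect, _, block = heapq.heappop(heap)
        if len(block) == 1:              # a singleton block is a verified node
            found.append((block[0], -neg_effect, passes))
            continue
        size = -(-len(block) // branching)   # ceiling division
        for start in range(0, len(block), size):
            measure(block[start:start + size])
    return found
```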

We present results from both methods in our comparison plots, but relegate details to Section A.1.2. Relative to subsampling, these grouping-based methods have the disadvantage that on distributions, their cost scales linearly with the size of $\mathcal{D}$'s support, in addition to scaling with the number of nodes[7].

[7] AtP* also scales linearly in the same way, but with far fewer forward passes per prompt pair.

4 Experiments

4.1 Setup

Nodes

When attributing model behavior to components, an important choice is the partition of the model's computational graph into units of analysis or 'nodes' $N \ni n$ (cf. Section 2.1). We investigate two settings for the choice of $N$: AttentionNodes and NeuronNodes. For NeuronNodes, each MLP neuron[8] is a separate node. For AttentionNodes, we consider the query, key, and value vector for each head as distinct nodes, as well as the pre-linear per-head attention output[9]. We also refer to these units as 'sites'. For each site, we consider each copy of that site at different token positions as a separate node. As a result, we can identify each node $n \in N$ with a pair $(T, S)$ from the product TokenPosition $\times$ Site. Since our two settings for $N$ use different levels of granularity and are expected to have different per-node effect magnitudes, we present results on them separately.

[8] We use the neuron post-activation for the node; this makes no difference when causally intervening, but for AtP it's beneficial, because it makes the $n \mapsto \mathcal{L}(n)$ function more linear.

[9] We include the output node because it provides additional information about what function an attention head is serving, particularly in the case where its queries have negligible patch effects relative to its keys and/or values. This may happen as a result of choosing $x^{\text{clean}}, x^{\text{noise}}$ such that the query does not differ across the prompts.

Models

We investigate transformer language models from the Pythia suite (Biderman et al., 2023) of sizes between 410M and 12B parameters. This allows us to demonstrate that our methods are applicable across scale. Our cost-of-verified-recall plots in Figures 1, 7 and 8 refer to Pythia-12B. Results for other model sizes are presented via the relative-cost plots (cf. Section 4.2) in Figure 9 of the main body, and disaggregated via cost-of-verified-recall plots in Section B.3.

Effect Metric $\mathcal{L}$

All reported results use the negative log probability[10] as their loss function $\mathcal{L}$. We compute $\mathcal{L}$ relative to targets from the clean prompt $x^{\text{clean}}$. We briefly explore other metrics in Section B.4.

[10] Another popular metric is the difference in logits between the clean and noise target. As opposed to the negative logprob, the logit difference is linear in the final logits and thus might favor AtP. A downside of logit difference is that it is sensitive to the noise target, which may not be meaningful if there are multiple plausible completions, such as in IOI.
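As an illustration of the two metrics, the sketch below computes the negative log probability of the clean target and the logit difference from a vector of final logits. The function names and example logits are hypothetical; the point it shows is the one made in the footnote: the logit difference is additive (linear) in the logits, while the negative log probability is not.

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax over the vocabulary axis."""
    shifted = logits - logits.max()
    return shifted - np.log(np.exp(shifted).sum())


def neg_log_prob(final_logits, clean_target):
    """Loss used in this section: -log p(clean target token)."""
    return -log_softmax(final_logits)[clean_target]


def logit_diff(final_logits, clean_target, noise_target):
    """Alternative metric: clean-target logit minus noise-target logit."""
    return final_logits[clean_target] - final_logits[noise_target]


# Two hypothetical logit vectors over a 3-token vocabulary.
logits_a = np.array([2.0, 1.0, 0.0])
logits_b = np.array([-1.0, 0.5, 3.0])
```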

4.2 Measuring Effectiveness and Efficiency

Cost of verified recall

As mentioned in the introduction, we're primarily interested in finding the largest-effect nodes; see Appendix D for the distribution of $c(n)$ across models and distributions. Once we have obtained node estimates via a given method, it is relatively cheap to directly measure true effects of top nodes one at a time; we refer to this as "verification". Incorporating this into our methodology, we find that false positives are typically not a big issue; they are simply revealed during verification. In contrast, false negatives are not so easy to remedy without verifying all nodes, which is what we were trying to avoid.

We compare methods on the basis of total compute cost (in # of forward passes) to verify the $K$ nodes with biggest true effect magnitude, for varying $K$. The procedure being measured is to first compute estimates (incurring an estimation cost), and then sweep through nodes in decreasing order of estimated magnitude, measuring their individual effects $c(n)$ (i.e. verifying them), and incurring a verification cost. Then the total cost is the sum of these two costs.
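The procedure above can be sketched as follows. This is a minimal sketch assuming that verifying one node costs exactly one forward pass (the single-prompt-pair case) and that ties in estimated magnitude are broken by node index; the `recall` parameter is a hypothetical generalization for partial-recall targets such as the 90% used in Figure 7.

```python
import numpy as np


def cost_of_verified_recall(estimates, true_effects, k, estimation_cost, recall=1.0):
    """Forward passes spent until `recall` of the top-k true nodes are verified.

    Assumes one forward pass per verified node, with nodes verified in
    decreasing order of |estimate|.
    """
    estimates = np.asarray(estimates, dtype=float)
    true_effects = np.asarray(true_effects, dtype=float)
    order = np.argsort(-np.abs(estimates), kind="stable")
    step = np.empty(len(order), dtype=int)
    step[order] = np.arange(1, len(order) + 1)  # 1-based verification step per node
    top_true = np.argsort(-np.abs(true_effects), kind="stable")[:k]
    needed = int(np.ceil(recall * k - 1e-9))     # how many of the top-k must be found
    verification_cost = int(np.sort(step[top_true])[needed - 1])
    return estimation_cost + verification_cost


# Toy example: node 1 has a large true effect but a tiny estimate (a false negative).
true_effects = np.array([5.0, -4.0, 0.1, 3.0, 0.2])
estimates = np.array([4.0, 0.05, 0.1, 3.0, 0.3])
```

With these toy values, recalling the top two true nodes requires verifying all five nodes, because the false negative is ranked last; an oracle (estimates equal to true effects, zero estimation cost) would need only $K$ forward passes.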

Inverse-rank-weighted geometric mean cost

Sometimes we find it useful to summarize the method performance with a scalar; this is useful for comparing methods at a glance across different settings (e.g. model sizes, as in Figure 2), or for selecting hyperparameters (cf. Section B.5). The cost of verified recall of the top $K$ nodes is of interest for $K$ at varying orders of magnitude. In order to avoid the performance metric being dominated by small or large $K$, we assign similar total weight to different orders of magnitude: we use a weighted average with weight $1/K$ for the cost of the top $K$ nodes. Similarly, since the costs themselves may have different orders of magnitude, we average them on a log scale, i.e., we take a geometric mean.

This metric is also proportional to the area under the curve in plots like Figure 1. To produce a more understandable result, we always report it relative to (i.e. divided by) the oracle verification cost on the same metric; the diagonal line is the oracle, with relative cost 1. We refer to this as the IRWRGM (inverse-rank-weighted relative geometric mean) cost, or the relative cost.

Note that an individual practitioner's preferences may differ, in which case this metric may not accurately reflect the rank regime that matters to them. For example, AtP* pays a notable upfront cost relative to AtP or AtP+QKfix, which puts it at a disadvantage when it doesn't find additional false negatives; but this may or may not be practically significant. To understand the performance in more detail we advise referring to the cost of verified recall plots, like Figure 1 (or many more in Section B.3).
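The relative cost can be written down in a few lines. This is a sketch under two stated assumptions: `costs[K-1]` holds a method's cost of verified recall of the top $K$ nodes for $K = 1, \dots, \text{len(costs)}$, and the oracle cost for the top $K$ nodes is $K$ (verifying in true order with no estimation cost). Because the same $1/K$ weights are used throughout, dividing by the oracle before or after taking the geometric mean gives the same result.

```python
import numpy as np


def irwrgm_relative_cost(costs, oracle_costs):
    """Inverse-rank-weighted geometric mean of costs[K-1] / oracle_costs[K-1]."""
    costs = np.asarray(costs, dtype=float)
    oracle_costs = np.asarray(oracle_costs, dtype=float)
    ranks = np.arange(1, len(costs) + 1)
    weights = 1.0 / ranks                          # weight 1/K for the top-K cost
    log_ratio = np.log(costs / oracle_costs)       # average on a log scale
    return float(np.exp(np.sum(weights * log_ratio) / np.sum(weights)))


ranks = np.arange(1, 11)  # K = 1..10; the oracle cost of the top K nodes is K
```

A method that matches the oracle scores 1, and one that is uniformly twice as expensive scores 2.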

4.3 Single Prompt Pairs versus Distributions

We focus many of our experiments on single prompt pairs. This is primarily because it’s easier to set up and get ground truth data. It’s also a simpler setting in which to investigate the question, and one that’s more universally applicable, since a distribution to generalize to is not always available.

[Figure 7 panels: (a) NeuronNodes on CITY-PP; (b) AttentionNodes on IOI-PP]
Figure 7: Costs of finding the most causally-important nodes in Pythia-12B using different methods on clean prompt pairs, with 90% target recall. This highlights that the AtP* false negatives in Figure 1 are a small minority of nodes.

Clean single prompt pairs

As a starting point we report results on single prompt pairs which we expect to have relatively clean circuitry.[11] All singular prompt pairs are shown in Table 1. IOI-PP is chosen to resemble an instance from the indirect object identification (IOI) task (Wang et al., 2022), a task predominantly involving attention heads. CITY-PP is chosen to elicit factual recall, which previous research suggests involves early MLPs and a small number of late attention heads (Meng et al., 2023; Geva et al., 2023; Nanda et al., 2023). The country/city combinations were chosen such that Pythia-410M achieved low loss on both $x^{\text{clean}}$ and $x^{\text{noise}}$ and such that all places were represented by a single token.

[11] Formally, these represent prompt distributions via the delta distribution $p(x^{\text{clean}}, x^{\text{noise}}) = \delta_{x^{\text{clean}}_1, x^{\text{noise}}_1}(x^{\text{clean}}, x^{\text{noise}})$, where $x^{\text{clean}}_1, x^{\text{noise}}_1$ is the singular prompt pair.

Table 1 (Identifier: clean prompt / noise source prompt)

CITY-PP
  Clean prompt: ⟨BOS⟩City:␣Barcelona\nCountry:␣Spain
  Noise source prompt: ⟨BOS⟩City:␣Beijing\nCountry:␣China
IOI-PP
  Clean prompt: ⟨BOS⟩When␣Michael␣and␣Jessica␣went␣to␣the␣bar,␣Michael␣gave␣a␣drink␣to␣Jessica
  Noise source prompt: ⟨BOS⟩When␣Michael␣and␣Jessica␣went␣to␣the␣bar,␣Ashley␣gave␣a␣drink␣to␣Michael
RAND-PP
  Clean prompt: ⟨BOS⟩Her␣biggest␣worry␣was␣the␣festival␣might␣suffer␣and␣people␣might␣erroneously␣think
  Noise source prompt: ⟨BOS⟩also␣think␣that␣there␣should␣be␣the␣same␣rules␣or␣regulations␣when␣it

Table 1: Clean and noise source prompts for singular prompt pair distributions. Vertical lines denote tokenization boundaries. All prompts are preceded by the BOS (beginning of sequence) token. The last token is not part of the input. The last token of the clean prompt is used as the target in $\mathcal{L}$.

We show the cost of verified 100% recall for various methods in Figure 1, where we focus on NeuronNodes for CITY-PP and AttentionNodes for IOI-PP. Exhaustive results for smaller Pythia models are shown in Section B.3. Figure 2 shows the aggregated relative costs for all models on CITY-PP and IOI-PP.

Instead of applying the strict criterion of recalling all important nodes, we can also relax this constraint. In Figure 7, we show the cost of verified 90% recall in the two clean prompt pair settings.

Random prompt pair

The previous prompt pairs may in fact be the best-case scenarios: the interventions they create will be fairly localized to a specific circuit, and this may make it easy for AtP to approximate the contributions. It may thus be informative to see how the methods generalize to settings where the interventions are less surgical. To do this, we also report results in Figure 8 (top) and Figure 9 on a random prompt pair chosen from a non-copyright-protected section of The Pile (Gao et al., 2020) which we refer to as RAND-PP. The prompt pair was chosen such that Pythia-410M still achieved low loss on both prompts.

[Figure 8 panels: (a) RAND-PP MLP neurons; (b) RAND-PP Attention nodes; (c) A-AN MLP neurons; (d) IOI Attention nodes]
Figure 8: Costs of finding the most causally-important nodes in Pythia-12B using different methods, on a random prompt pair (see Table 1) and on distributions. The shading indicates geometric standard deviation. Cost is measured in forward passes, or forward passes per prompt pair in the distributional case.

[Figure 9 panels: (a) RAND-PP MLP neurons; (b) RAND-PP Attention nodes; (c) A-AN MLP neurons; (d) IOI Attention nodes]
Figure 9: Costs of methods across models, on random prompt pair and on distributions. The costs are relative to having an oracle (and thus verifying nodes in decreasing order of true contribution size); they're aggregated using an inverse-rank-weighted geometric mean. This means they correspond to the area above the diagonal for each curve in Figure 8.

We find that AtP/AtP* is only somewhat less effective here; this provides tentative evidence that the strong performance of AtP/AtP* isn’t reliant on the clean prompt using a particularly crisp circuit, or on the noise prompt being a precise control.

Distributions

Causal attribution is often of most interest when evaluated across a distribution, as laid out in Section 2. Of the methods, AtP, AtP*, and Subsampling scale reasonably to distributions: the former two because they're inexpensive, so running them $|\mathcal{D}|$ times is not prohibitive, and Subsampling because it intrinsically averages across the distribution and thus becomes proportionally cheaper relative to the verification via activation patching. In addition, having a distribution enables a more performant Iterative method, as described in Section 3.3.
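To illustrate why AtP is cheap enough to run once per prompt pair, the sketch below applies it to a toy one-hidden-layer network of our own construction (not a transformer). Nodes are post-activation hidden units, the loss is the negative log probability of a target, and the AtP estimate for node $n$ is $(n(x^{\text{noise}}) - n(x^{\text{clean}}))\,\partial\mathcal{L}/\partial n$ evaluated on the clean run. Ground truth patches one node at a time. The random inputs, targets, and sizes are hypothetical stand-ins for a prompt-pair distribution.

```python
import numpy as np

# Toy stand-in for a model: input vector -> ReLU hidden units (the nodes) -> logits.
rng = np.random.default_rng(0)
D_IN, D_HIDDEN, D_VOCAB = 8, 16, 5
W1 = rng.normal(size=(D_HIDDEN, D_IN)) / np.sqrt(D_IN)
W2 = rng.normal(size=(D_VOCAB, D_HIDDEN)) / np.sqrt(D_HIDDEN)


def hidden(x):
    """Post-activation neuron values; each entry is one node."""
    return np.maximum(W1 @ x, 0.0)


def loss_from_hidden(h, target):
    """Negative log probability of `target` given hidden activations h."""
    logits = W2 @ h
    shifted = logits - logits.max()
    return np.log(np.exp(shifted).sum()) - shifted[target]


def grad_loss_wrt_hidden(h, target):
    """Analytic dL/dh = W2^T (softmax(logits) - onehot(target))."""
    logits = W2 @ h
    p = np.exp(logits - logits.max())
    p /= p.sum()
    p[target] -= 1.0
    return W2.T @ p


def atp_estimates(x_clean, x_noise, target):
    """AtP: one forward pass per prompt plus one backward pass on the clean run."""
    h_clean, h_noise = hidden(x_clean), hidden(x_noise)
    return (h_noise - h_clean) * grad_loss_wrt_hidden(h_clean, target)


def patch_effects(x_clean, x_noise, target):
    """Ground truth: patch one node at a time (one extra forward pass per node)."""
    h_clean, h_noise = hidden(x_clean), hidden(x_noise)
    base = loss_from_hidden(h_clean, target)
    effects = np.empty(D_HIDDEN)
    for i in range(D_HIDDEN):
        h = h_clean.copy()
        h[i] = h_noise[i]
        effects[i] = loss_from_hidden(h, target) - base
    return effects


# A toy "distribution" of (x_clean, x_noise, target) triples; effects are averaged over it.
pairs = [(rng.normal(size=D_IN), rng.normal(size=D_IN), int(rng.integers(D_VOCAB)))
         for _ in range(10)]
atp_mean = np.mean([atp_estimates(*p) for p in pairs], axis=0)
true_mean = np.mean([patch_effects(*p) for p in pairs], axis=0)
```

The per-pair cost of AtP does not depend on the number of nodes, while ground-truth patching needs one forward pass per node per pair; averaging the per-pair estimates is what running AtP $|\mathcal{D}|$ times amounts to.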

We present a comparison of these methods on two distributional settings. The first is a reduced version of IOI (Wang et al., 2022) on 6 names, resulting in $6 \times 5 \times 4 = 120$ prompt pairs, where we evaluate AttentionNodes. The other distribution (A-AN) prompts the model to output an indefinite article ' a' or ' an', where we evaluate NeuronNodes. See Section B.1 for details on constructing these distributions. Results are shown in Figure 8 for Pythia-12B, and in Figure 9 across models. The results show that AtP continues to perform well, especially with the QK fix; in addition, the cancellation failure mode tends to be sensitive to the particular input prompt pair, and as a result, averaging across a distribution diminishes the benefit of GradDrops.
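One way the count $6 \times 5 \times 4 = 120$ can arise is from ordered triples of distinct names. The sketch below assumes, hypothetically, that each pair is built like IOI-PP in Table 1: a subject, an indirect object, and a replacement subject for the noise prompt. The actual construction is described in Section B.1 and may differ (e.g. in templates or name order); the last three names are placeholders.

```python
from itertools import permutations

# First three names appear in Table 1; the rest are hypothetical placeholders.
NAMES = ["Michael", "Jessica", "Ashley", "name_3", "name_4", "name_5"]
TEMPLATE = "When {a} and {b} went to the bar, {c} gave a drink to"


def ioi_pairs(names):
    """(clean prompt, noise prompt, clean target) for each ordered name triple."""
    pairs = []
    for subj, io, subj_noise in permutations(names, 3):  # 6 * 5 * 4 triples
        clean = TEMPLATE.format(a=subj, b=io, c=subj)
        noise = TEMPLATE.format(a=subj, b=io, c=subj_noise)
        pairs.append((clean, noise, " " + io))          # target: the indirect object
    return pairs


IOI_PAIRS = ioi_pairs(NAMES)
```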

An implication of Subsampling scaling well to this setting is that diagnostics may give reasonable confidence that no false negatives remain, with much less overhead than in the single-prompt-pair case; this is illustrated in Figure 6.

5 Discussion

5.1 Limitations

Prompt pair distributions

We only considered a small set of prompt pair distributions, which often were limited to a single prompt pair, since evaluating the ground truth can be quite costly. While we aimed to evaluate on distributions that are reasonably representative, our results may not generalize to other distributions.

Choice of Nodes $N$

In the NeuronNodes setting, we took MLP neurons as our fundamental unit of analysis. However, there is mounting evidence (Bricken et al., 2023) that the decomposition of signals into neuron contributions does not correspond directly to a semantically meaningful decomposition. Instead, achieving such a decomposition seems to require finding the right set of directions in neuron activation space (Bricken et al., 2023; Gurnee et al., 2023) – which we viewed as being out of scope for this paper. In Section 5.2 we further discuss the applicability of AtP to sparse autoencoders, a method of finding these decompositions.

More generally, we only considered relatively fine-grained nodes, because this is a case where exhaustive verification is prohibitively expensive, justifying the need for an approximate, fast method. Nanda (2022) speculates that AtP may perform worse on coarser components like full layers or entire residual streams, as a larger change may have more of a non-linear effect. There may still be benefit in speeding up such an analysis, particularly if the context length is long; our alternative methods may have something to offer here, though we leave investigation of this to future work.

It is popular in the literature to do Activation Patching with these larger components, with short contexts – this doesn’t pose a performance issue, and so our work would not provide any benefit here.

Caveats of $c(n)$ as importance measure

In this work we took the ground truth of activation patching, as defined in Equation 1, as our evaluation target. As discussed by McGrath et al. (2023), Equation 1 often significantly disagrees with a different evaluation target, the "direct effect", by putting lower weight on some contributions when later components would shift their behaviour to compensate for the earlier patched component. In the worst case this could be seen as producing additional false negatives not accounted for by our metrics. To some degree this is likely to be mitigated by the GradDrop formula in Eq. 11, which will include a term dropping out the effect of that downstream shift.

However, it is also questionable whether we need to concern ourselves with finding high-direct-effect nodes. For example, direct effect is easy to efficiently compute for all nodes, as explored by nostalgebraist (2020) – so there is no need for fast approximations like AtP if direct effect is the quantity of interest. This ease of computation is no free lunch, though, because direct effect is also more limited as a tool for finding causally important nodes: it would not be able to locate any nodes that contribute only instrumentally to the circuit rather than producing its output. For example, there is no direct effect from nodes at non-final token positions. We discuss the direct effect further in Section 3.1.2 and Section A.2.2.

Another nuance of our ground-truth definition occurs in the distributional setting. Some nodes may have a real and significant effect, but only on a single clean prompt (e.g. they only respond to a particular name in IOI[12] or object in A-AN). Since the effect is averaged over the distribution, the ground truth will not assign these nodes large causal importance. Depending on the goal of the practitioner this may or may not be desirable.

[12] We did observe this particular behavior in a few instances.

Effect size versus rank estimation

When evaluating the performance of various estimators, we focused on evaluating the relative rank of estimates, since our main goal was to identify important components (with effect size only instrumentally useful to this end), and we assumed a further verification step of the nodes with highest estimated effects one at a time, in contexts where knowing effect size is important. Thus, we do not present evidence about how closely the estimated effect magnitudes from AtP or AtP* match the ground truth. Similarly, we did not assess the prevalence of false positives in our analysis, because they can be filtered out via the verification process. Finally, we did not compare to past manual interpretability work to check whether our methods find the same nodes to be causally important as discovered by human researchers, as done in prior work (Conmy et al., 2023; Syed et al., 2023).

Other LLMs

While we think it likely that our results on the Pythia model family (Biderman et al., 2023) will transfer to other LLM families, we cannot rule out qualitatively different behavior without further evidence, especially on SoTA-scale models or models that significantly deviate from the standard decoder-only transformer architecture.

5.2 Extensions/Variants

Edge Patching

While we focus on computing the effects of individual nodes, edge activation patching can give more fine-grained information about which paths in the computational graph matter. However, it suffers from an even larger blowup in number of forward passes if done naively. Fortunately, AtP is easy to generalize to estimating the effects of edges between nodes (Nanda, 2022; Syed et al., 2023), while AtP* may provide further improvement. We discuss edge-AtP, and how to efficiently carry over the insights from AtP*, in Section C.2.

Coarser nodes N𝑁Nitalic_N

We focused on fine-grained attribution, rather than full layers or sliding windows (Meng et al., 2023; Geva et al., 2023). In the latter case there's less computational blowup to resolve, but for long contexts there may still be benefit in considering speedups like ours; on the other hand, such coarser nodes may behave less linearly, thus favouring other methods over AtP*. We leave investigation of this to future work.

Layer normalization

Nanda (2022) observed that AtP’s approximation to layer normalization may be a worse approximation when it comes to patching larger/coarser nodes: on average the patched and clean activations are likely to have similar norm, but may not have high cosine-similarity. They recommend treating the denominator in layer normalization as fixed, e.g. using a stop-gradient operator in the implementation. In Section C.1 we explore the effect of this, and illustrate the behaviour of this alternative form of AtP. It seems likely that this variant would indeed produce better results particularly when patching residual-stream nodes – but we leave empirical investigation of this to future work.
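The effect of freezing the denominator can be seen on a single layer norm in isolation. In the sketch below (our own construction, without the learned scale and bias), `ln_linear_approx` is the first-order estimate of `layer_norm(x + delta)` around `x`; with `freeze_scale=True` the standard deviation $\sigma$ is held constant, which is what a stop-gradient on $\sigma$ (e.g. `jax.lax.stop_gradient` in JAX) does to the gradient. For a patched activation with the same norm as the clean one but a different direction, the frozen-denominator estimate is exact, while the full linearization is not. This says nothing directly about AtP on a full model.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    """Layer norm without the learned scale and bias."""
    c = x - x.mean()
    sigma = np.sqrt(np.mean(c**2) + eps)
    return c / sigma


def ln_linear_approx(x, delta, eps=1e-5, freeze_scale=False):
    """First-order estimate of layer_norm(x + delta) around x."""
    c = x - x.mean()
    sigma = np.sqrt(np.mean(c**2) + eps)
    dc = delta - delta.mean()
    dy = dc / sigma
    if not freeze_scale:
        dsigma = np.mean(c * dc) / sigma  # derivative of sigma along delta
        dy = dy - c * dsigma / sigma**2
    return c / sigma + dy


# Patched activation with the same (centred) norm as the clean one but a random
# direction: similar norm, low cosine similarity.
rng = np.random.default_rng(0)
x_clean = rng.normal(size=64)
c_clean = x_clean - x_clean.mean()
c_other = rng.normal(size=64)
c_other -= c_other.mean()
x_patch = c_other * (np.linalg.norm(c_clean) / np.linalg.norm(c_other))
delta = x_patch - x_clean
target = layer_norm(x_patch)
err_full = np.linalg.norm(ln_linear_approx(x_clean, delta) - target)
err_frozen = np.linalg.norm(ln_linear_approx(x_clean, delta, freeze_scale=True) - target)
```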

Denoising

Denoising (Meng et al., 2023; Lieberum et al., 2023) is a different use case for patching, which may produce moderately different results: the difference is that each forward pass is run on $x^{\text{noise}}$ with the activation to patch taken from $x^{\text{clean}}$. Colloquially, this tests whether the patched activation is sufficient to recover model performance on $x^{\text{clean}}$, rather than necessary. We provide some preliminary evidence on the effect of this choice in Section B.4 but leave a more thorough investigation to future work.
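The difference between the two directions can be sketched on a toy network of our own construction (hypothetical weights, ReLU hidden units as nodes, loss measured against the clean target). Noising runs on $x^{\text{clean}}$ and patches a node from $x^{\text{noise}}$; denoising runs on $x^{\text{noise}}$ and patches the node from $x^{\text{clean}}$, so a negative denoising effect means the node alone partially restores performance.

```python
import numpy as np

# Hypothetical toy model: input -> ReLU hidden units (the nodes) -> logits.
rng = np.random.default_rng(1)
W1 = rng.normal(size=(6, 4))
W2 = rng.normal(size=(3, 6))


def run(x, target, patch=None):
    """Return (loss, hidden); `patch=(i, value)` overwrites hidden unit i."""
    h = np.maximum(W1 @ x, 0.0)
    if patch is not None:
        i, value = patch
        h[i] = value
    logits = W2 @ h
    shifted = logits - logits.max()
    return np.log(np.exp(shifted).sum()) - shifted[target], h


def noising_effect(i, x_clean, x_noise, target):
    """Run on x_clean, patch node i from x_noise: is the node necessary?"""
    base, _ = run(x_clean, target)
    _, h_noise = run(x_noise, target)
    patched, _ = run(x_clean, target, patch=(i, h_noise[i]))
    return patched - base


def denoising_effect(i, x_clean, x_noise, target):
    """Run on x_noise, patch node i from x_clean: is the node sufficient?"""
    base, _ = run(x_noise, target)
    _, h_clean = run(x_clean, target)
    patched, _ = run(x_noise, target, patch=(i, h_clean[i]))
    return patched - base


x_clean, x_noise = rng.normal(size=4), rng.normal(size=4)
clean_target = 0  # hypothetical clean-target token index
```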

Other forms of ablation

Further, in some settings it may be of interest to do mean-ablation, or even zero-ablation, and our tweaks remain applicable there; the random-prompt-pair result suggests AtP* isn’t overly sensitive to the noise distribution, so we speculate the results are likely to carry over.
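For plain AtP, changing the ablation scheme only changes the value being patched in; the first-order estimate $(a_{\text{replacement}} - a_{\text{clean}}) \cdot \partial\mathcal{L}/\partial a$ is otherwise unchanged. The sketch below, with hypothetical function names and toy numbers, shows resample, mean, and zero ablation side by side. It does not implement the AtP* modifications.

```python
import numpy as np


def replacement(kind, a_clean, a_noise=None, reference_acts=None):
    """Value patched into a node under different ablation schemes."""
    if kind == "resample":  # activation from the noise prompt (as in this paper)
        return a_noise
    if kind == "mean":      # mean activation over a reference set of prompts
        return np.mean(reference_acts, axis=0)
    if kind == "zero":
        return np.zeros_like(a_clean)
    raise ValueError(f"unknown ablation kind: {kind}")


def atp_estimate(a_clean, a_replacement, grad_clean):
    """First-order patch effect: (a_replacement - a_clean) . dL/da on the clean run."""
    return float(np.dot(a_replacement - a_clean, grad_clean))


# Toy values for one node's activation vector and its clean-run gradient.
a_clean = np.array([1.0, 2.0])
a_noise = np.array([3.0, 0.0])
grad_clean = np.array([0.5, -1.0])
reference_acts = np.array([[0.0, 1.0], [2.0, 3.0]])
```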

5.3 Applications

Automated Circuit Finding

A natural application of the methods we discussed in this work is the automatic identification and localization of sparse subgraphs or ‘circuits’ (Cammarata et al., 2020). A variant of this was already discussed in concurrent work by Syed et al. (2023) who combined edge attribution patching with the ACDC algorithm (Conmy et al., 2023). As we mentioned in the edge patching discussion, AtP* can be generalized to edge attribution patching, which may bring additional benefit for automated circuit discovery.

Another approach is to learn a (probabilistic) mask over nodes, similar to Louizos et al. (2018); Cao et al. (2021), where the probability scales with the currently estimated node contribution c(n)𝑐𝑛c(n)italic_c ( italic_n ). For that approach, a fast method to estimate all node effects given the current mask probabilities could prove vital.

Sparse Autoencoders

Recently there has been increased interest by the community in using sparse autoencoders (SAEs) to construct disentangled sparse representations with potentially more semantic coherence than transformer-native units such as neurons (Cunningham et al., 2023; Bricken et al., 2023). SAEs usually have many more nodes than the corresponding transformer block they are applied to. This could make exhaustive activation patching even more expensive, making the speedup of AtP* more valuable. However, due to the sparseness of the SAE, on a given forward pass the effect of most features will be zero. For example, some successful SAEs by Bricken et al. (2023) have 10-20 active features for 500 neurons for a given token position, which reduces the number of nodes by roughly 25-50x relative to the MLP setting, increasing the scale at which existing iterative methods remain practical. It is still an open research question, however, what degree of sparsity is feasible with tolerable reconstruction error for practically relevant or SoTA-scale models, where the methods discussed in this work may become more important again.

Steering LLMs

AtP* could be used to discover single nodes in the model that can be leveraged for targeted inference time interventions to control the model’s behavior. In contrast to previous work (Li et al., 2023; Turner et al., 2023; Zou et al., 2023) it might provide more localized interventions with less impact on the rest of the model’s computation. One potential exciting direction would be to use AtP* (or other gradient-based approximations) to see which sparse autoencoder features, if activated, would have a significant effect.

5.4 Recommendation

Our results suggest that if a practitioner is trying to do fast causal attribution, there are 2 main factors to consider: (i) the desired granularity of localization, and (ii) the confidence vs compute tradeoff.

Regarding (i), the desired granularity, smaller components (e.g. MLP neurons or attention heads) are more numerous but more linear, likely yielding better results from gradient-based methods like AtP. We are less sure AtP will be a good approximation if patching layers or sliding windows of layers, and in this case practitioners may want to do normal patching. If the number of forward passes required remains prohibitive (e.g. a long context times many layers, when doing per-token × layer patching), our other baselines may be useful. For a single prompt pair we particularly recommend trying Blocks, as it's easy to make sense of; for a distribution we recommend Subsampling because it scales better to many prompt pairs.

Regarding (ii), the confidence vs compute tradeoff, depending on the application, it may be desirable to run AtP as an activation patching prefilter followed by running the diagnostic to increase confidence. On the other hand, if false negatives aren’t a big concern then it may be preferable to skip the diagnostic – and if false positives aren’t either, then in certain cases practitioners may want to skip activation patching verification entirely. In addition, if the prompt pair distribution does not adequately highlight the specific circuit/behaviour of interest, this may also limit what can be learned from any localization methods.

If AtP is appropriate, our results suggest the best variant to use is probably AtP* for single prompt pairs, AtP+QKFix for AttentionNodes on distributions, and AtP for NeuronNodes (or other sites that aren’t immediately before a nonlinearity) on distributions.

Of course, these recommendations are best-substantiated in settings similar to those we studied: focused prompt pairs / distribution, attention node or neuron sites, nodewise attribution, measuring cross-entropy loss on the clean-prompt next token. If departing from these assumptions we recommend looking before you leap.

6 Related work

Localization and Mediation Analysis

This work is concerned with identifying the effect of all (important) nodes in a causal graph (Pearl, 2000), in the specific case where the graph represents a language model’s computation. A key method for finding important intermediate nodes in a causal graph is intervening on those nodes and observing the effect, which was first discussed under the name of causal mediation analysis by Robins and Greenland (1992); Pearl (2001).

Activation Patching

In recent years there has been increasing success at applying the ideas of causal mediation analysis to identify causally important nodes in deep neural networks, in particular via the method of activation patching, where the output of a model component is intervened on. This technique has been widely used by the community and successfully applied in a range of contexts (Olsson et al., 2022; Vig et al., 2020; Soulos et al., 2020; Meng et al., 2023; Wang et al., 2022; Hase et al., 2023; Lieberum et al., 2023; Conmy et al., 2023; Hanna et al., 2023; Geva et al., 2023; Huang et al., 2023; Tigges et al., 2023; Merullo et al., 2023; McDougall et al., 2023; Goldowsky-Dill et al., 2023; Stolfo et al., 2023; Feng and Steinhardt, 2023; Hendel et al., 2023; Todd et al., 2023; Cunningham et al., 2023; Finlayson et al., 2021; Nanda et al., 2023).

Chan et al. (2022) introduce causal scrubbing, a generalized algorithm to verify a hypothesis about the internal mechanism underlying a model's behavior, and detail their motivation behind performing noising and resample ablation rather than denoising or using mean or zero ablation: they interpret the hypothesis as implying the computation is invariant to some large set of perturbations, so their starting point is the clean unperturbed forward pass.[13]

[13] Our motivation for focusing on noising rather than denoising was a closely related one: we were motivated by automated circuit discovery, where gradually noising more and more of the model is the basic methodology for both of the approaches discussed in Section 5.3.

Another line of research concerning formalizing causal abstractions focuses on finding and verifying high-level causal abstractions of low-level variables (Geiger et al., 2020, 2021, 2022, 2023). See Jenner et al. (2022) for more details on how these different frameworks agree and differ. In contrast to those works, we are chiefly concerned with identifying the important low-level variables in the computational graph and are not investigating their semantics or potential groupings of lower-level variables into higher-level variables.

In addition to causal mediation analysis, intervening on node activations in the model forward pass has also been studied as a way of steering models towards desirable behavior (Rimsky et al., 2023; Zou et al., 2023; Turner et al., 2023; Jorgensen et al., 2023; Li et al., 2023; Belrose et al., 2023).

Attribution Patching / Gradient-based Masking

While we use the resample-ablation variant of AtP as formulated in Nanda (2022), similar formulations have been used in the past to successfully prune deep neural networks (Figurnov et al., 2016; Molchanov et al., 2017; Michel et al., 2019), or even identify causally important nodes for interpretability (Cao et al., 2021). Concurrent work by Syed et al. (2023) also demonstrates AtP can help with automatically finding causally important circuits in a way that agrees with previous manual circuit identification work. In contrast to Syed et al. (2023), we provide further analysis of AtP's failure modes, give improvements in the form of AtP*, and evaluate both methods as well as several baselines on a suite of larger models against a ground truth that is independent of human researchers' judgement.

7 Conclusion

In this paper, we have explored the use of attribution patching for node patch effect evaluation. We have compared attribution patching with alternatives and augmentations, characterized its failure modes, and presented reliability diagnostics. We have also discussed the implications of our contributions for other settings in which patching can be of interest, such as circuit discovery, edge localization, coarse-grained localization, and causal abstraction.

Our results show that AtP* can be a more reliable and scalable approach to node patch effect evaluation than alternatives. However, it is important to be aware of the failure modes of attribution patching, such as cancellation and saturation. We explored these in some detail, and provided mitigations, as well as recommendations for diagnostics to ensure that the results are reliable.

We believe that our work makes an important contribution to the field of mechanistic interpretability and will help to advance the development of more reliable and scalable methods for understanding the behavior of deep neural networks.

8 Author Contributions

János Kramár was research lead, and Tom Lieberum was also a core contributor – both were highly involved in most aspects of the project. Rohin Shah and Neel Nanda served as advisors and gave feedback and guidance throughout.

References

  • Belrose et al. (2023) N. Belrose, D. Schneider-Joseph, S. Ravfogel, R. Cotterell, E. Raff, and S. Biderman. LEACE: Perfect linear concept erasure in closed form. arXiv preprint arXiv:2306.03819, 2023.
  • Biderman et al. (2023) S. Biderman, H. Schoelkopf, Q. G. Anthony, H. Bradley, K. O’Brien, E. Hallahan, M. A. Khan, S. Purohit, U. S. Prashanth, E. Raff, A. Skowron, L. Sutawika, and O. van der Wal. Pythia: A suite for analyzing large language models across training and scaling. In A. Krause, E. Brunskill, K. Cho, B. Engelhardt, S. Sabato, and J. Scarlett, editors, International Conference on Machine Learning, ICML 2023, 23-29 July 2023, Honolulu, Hawaii, USA, volume 202 of Proceedings of Machine Learning Research, pages 2397–2430. PMLR, 2023. URL https://proceedings.mlr.press/v202/biderman23a.html.
  • Bricken et al. (2023) T. Bricken, A. Templeton, J. Batson, B. Chen, A. Jermyn, T. Conerly, N. Turner, C. Anil, C. Denison, A. Askell, R. Lasenby, Y. Wu, S. Kravec, N. Schiefer, T. Maxwell, N. Joseph, Z. Hatfield-Dodds, A. Tamkin, K. Nguyen, B. McLean, J. E. Burke, T. Hume, S. Carter, T. Henighan, and C. Olah. Towards monosemanticity: Decomposing language models with dictionary learning. Transformer Circuits Thread, 2023. https://transformer-circuits.pub/2023/monosemantic-features/index.html.
  • Cammarata et al. (2020) N. Cammarata, S. Carter, G. Goh, C. Olah, M. Petrov, L. Schubert, C. Voss, B. Egan, and S. K. Lim. Thread: Circuits. Distill, 2020. 10.23915/distill.00024. https://distill.pub/2020/circuits.
  • Cao et al. (2021) N. D. Cao, L. Schmid, D. Hupkes, and I. Titov. Sparse interventions in language models with differentiable masking, 2021.
  • Chan et al. (2022) L. Chan, A. Garriga-Alonso, N. Goldowsky-Dill, R. Greenblatt, J. Nitishinskaya, A. Radhakrishnan, B. Shlegeris, and N. Thomas. Causal scrubbing, a method for rigorously testing interpretability hypotheses. AI Alignment Forum, 2022. https://www.alignmentforum.org/posts/JvZhhzycHu2Yd57RN/causal-scrubbing-a-method-for-rigorously-testing.
  • Conmy et al. (2023) A. Conmy, A. N. Mavor-Parker, A. Lynch, S. Heimersheim, and A. Garriga-Alonso. Towards automated circuit discovery for mechanistic interpretability, 2023.
  • Cunningham et al. (2023) H. Cunningham, A. Ewart, L. Riggs, R. Huben, and L. Sharkey. Sparse autoencoders find highly interpretable features in language models, 2023.
  • Feng and Steinhardt (2023) J. Feng and J. Steinhardt. How do language models bind entities in context?, 2023.
  • Figurnov et al. (2016) M. Figurnov, A. Ibraimova, D. P. Vetrov, and P. Kohli. Perforatedcnns: Acceleration through elimination of redundant convolutions. In D. Lee, M. Sugiyama, U. Luxburg, I. Guyon, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 29. Curran Associates, Inc., 2016. URL https://proceedings.neurips.cc/paper_files/paper/2016/file/f0e52b27a7a5d6a1a87373dffa53dbe5-Paper.pdf.
  • Finlayson et al. (2021) M. Finlayson, A. Mueller, S. Gehrmann, S. Shieber, T. Linzen, and Y. Belinkov. Causal analysis of syntactic agreement mechanisms in neural language models. In C. Zong, F. Xia, W. Li, and R. Navigli, editors, Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (Volume 1: Long Papers), pages 1828–1843, Online, Aug. 2021. Association for Computational Linguistics. 10.18653/v1/2021.acl-long.144. URL https://aclanthology.org/2021.acl-long.144.
  • Gao et al. (2020) L. Gao, S. Biderman, S. Black, L. Golding, T. Hoppe, C. Foster, J. Phang, H. He, A. Thite, N. Nabeshima, S. Presser, and C. Leahy. The Pile: An 800gb dataset of diverse text for language modeling. arXiv preprint arXiv:2101.00027, 2020.
  • Geiger et al. (2020) A. Geiger, K. Richardson, and C. Potts. Neural natural language inference models partially embed theories of lexical entailment and negation, 2020.
  • Geiger et al. (2021) A. Geiger, H. Lu, T. Icard, and C. Potts. Causal abstractions of neural networks, 2021.
  • Geiger et al. (2022) A. Geiger, Z. Wu, H. Lu, J. Rozner, E. Kreiss, T. Icard, N. D. Goodman, and C. Potts. Inducing causal structure for interpretable neural networks, 2022.
  • Geiger et al. (2023) A. Geiger, C. Potts, and T. Icard. Causal abstraction for faithful model interpretation, 2023.
  • Geva et al. (2023) M. Geva, J. Bastings, K. Filippova, and A. Globerson. Dissecting recall of factual associations in auto-regressive language models, 2023.
  • Goldowsky-Dill et al. (2023) N. Goldowsky-Dill, C. MacLeod, L. Sato, and A. Arora. Localizing model behavior with path patching, 2023.
  • Gurnee et al. (2023) W. Gurnee, N. Nanda, M. Pauly, K. Harvey, D. Troitskii, and D. Bertsimas. Finding neurons in a haystack: Case studies with sparse probing, 2023.
  • Hanna et al. (2023) M. Hanna, O. Liu, and A. Variengien. How does gpt-2 compute greater-than?: Interpreting mathematical abilities in a pre-trained language model, 2023.
  • Hase et al. (2023) P. Hase, M. Bansal, B. Kim, and A. Ghandeharioun. Does localization inform editing? surprising differences in causality-based localization vs. knowledge editing in language models, 2023.
  • Hendel et al. (2023) R. Hendel, M. Geva, and A. Globerson. In-context learning creates task vectors, 2023.
  • Hoffmann et al. (2022) J. Hoffmann, S. Borgeaud, A. Mensch, E. Buchatskaya, T. Cai, E. Rutherford, D. de Las Casas, L. A. Hendricks, J. Welbl, A. Clark, T. Hennigan, E. Noland, K. Millican, G. van den Driessche, B. Damoc, A. Guy, S. Osindero, K. Simonyan, E. Elsen, O. Vinyals, J. Rae, and L. Sifre. An empirical analysis of compute-optimal large language model training. In S. Koyejo, S. Mohamed, A. Agarwal, D. Belgrave, K. Cho, and A. Oh, editors, Advances in Neural Information Processing Systems, volume 35, pages 30016–30030. Curran Associates, Inc., 2022. URL https://proceedings.neurips.cc/paper_files/paper/2022/file/c1e2faff6f588870935f114ebe04a3e5-Paper-Conference.pdf.
  • Huang et al. (2023) J. Huang, A. Geiger, K. D’Oosterlinck, Z. Wu, and C. Potts. Rigorously assessing natural language explanations of neurons, 2023.
  • Jenner et al. (2022) E. Jenner, A. Garriga-Alonso, and E. Zverev. A comparison of causal scrubbing, causal abstractions, and related methods. AI Alignment Forum, 2022. https://www.alignmentforum.org/posts/uLMWMeBG3ruoBRhMW/a-comparison-of-causal-scrubbing-causal-abstractions-and.
  • Jorgensen et al. (2023) O. Jorgensen, D. Cope, N. Schoots, and M. Shanahan. Improving activation steering in language models with mean-centring, 2023.
  • Li et al. (2023) K. Li, O. Patel, F. Viégas, H. Pfister, and M. Wattenberg. Inference-time intervention: Eliciting truthful answers from a language model, 2023.
  • Lieberum et al. (2023) T. Lieberum, M. Rahtz, J. Kramár, N. Nanda, G. Irving, R. Shah, and V. Mikulik. Does circuit analysis interpretability scale? evidence from multiple choice capabilities in chinchilla, 2023.
  • Louizos et al. (2018) C. Louizos, M. Welling, and D. P. Kingma. Learning sparse neural networks through $L_0$ regularization, 2018.
  • McDougall et al. (2023) C. McDougall, A. Conmy, C. Rushing, T. McGrath, and N. Nanda. Copy suppression: Comprehensively understanding an attention head, 2023.
  • McGrath et al. (2023) T. McGrath, M. Rahtz, J. Kramár, V. Mikulik, and S. Legg. The hydra effect: Emergent self-repair in language model computations, 2023.
  • Meng et al. (2023) K. Meng, D. Bau, A. Andonian, and Y. Belinkov. Locating and editing factual associations in gpt, 2023.
  • Merullo et al. (2023) J. Merullo, C. Eickhoff, and E. Pavlick. Circuit component reuse across tasks in transformer language models, 2023.
  • Michel et al. (2019) P. Michel, O. Levy, and G. Neubig. Are sixteen heads really better than one? In H. Wallach, H. Larochelle, A. Beygelzimer, F. d'Alché-Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc., 2019. URL https://proceedings.neurips.cc/paper_files/paper/2019/file/2c601ad9d2ff9bc8b282670cdd54f69f-Paper.pdf.
  • Molchanov et al. (2017) P. Molchanov, S. Tyree, T. Karras, T. Aila, and J. Kautz. Pruning convolutional neural networks for resource efficient inference. In International Conference on Learning Representations, 2017. URL https://openreview.net/forum?id=SJGCiw5gl.
  • Nanda (2022) N. Nanda. Attribution patching: Activation patching at industrial scale. 2022. URL https://www.neelnanda.io/mechanistic-interpretability/attribution-patching.
  • Nanda et al. (2023) N. Nanda, S. Rajamanoharan, J. Kramár, and R. Shah. Fact finding: Attempting to reverse-engineer factual recall on the neuron level, Dec 2023. URL https://www.alignmentforum.org/posts/iGuwZTHWb6DFY3sKB/fact-finding-attempting-to-reverse-engineer-factual-recall.
  • nostalgebraist (2020) nostalgebraist. interpreting gpt: the logit lens. 2020. URL https://www.alignmentforum.org/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens.
  • Olsson et al. (2022) C. Olsson, N. Elhage, N. Nanda, N. Joseph, N. DasSarma, T. Henighan, B. Mann, A. Askell, Y. Bai, A. Chen, T. Conerly, D. Drain, D. Ganguli, Z. Hatfield-Dodds, D. Hernandez, S. Johnston, A. Jones, J. Kernion, L. Lovitt, K. Ndousse, D. Amodei, T. Brown, J. Clark, J. Kaplan, S. McCandlish, and C. Olah. In-context learning and induction heads. Transformer Circuits Thread, 2022. https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html.
  • Pearl (2000) J. Pearl. Causality: Models, Reasoning and Inference. Cambridge University Press, 2000.
  • Pearl (2001) J. Pearl. Direct and indirect effects, 2001.
  • Radford et al. (2018) A. Radford, K. Narasimhan, T. Salimans, and I. Sutskever. Improving language understanding by generative pre-training, 2018.
  • Rimsky et al. (2023) N. Rimsky, N. Gabrieli, J. Schulz, M. Tong, E. Hubinger, and A. M. Turner. Steering llama 2 via contrastive activation addition, 2023.
  • Robins and Greenland (1992) J. M. Robins and S. Greenland. Identifiability and exchangeability for direct and indirect effects. Epidemiology, 3:143–155, 1992. URL https://api.semanticscholar.org/CorpusID:10757981.
  • Soulos et al. (2020) P. Soulos, R. T. McCoy, T. Linzen, and P. Smolensky. Discovering the compositional structure of vector representations with role learning networks. In A. Alishahi, Y. Belinkov, G. Chrupała, D. Hupkes, Y. Pinter, and H. Sajjad, editors, Proceedings of the Third BlackboxNLP Workshop on Analyzing and Interpreting Neural Networks for NLP, pages 238–254, Online, Nov. 2020. Association for Computational Linguistics. 10.18653/v1/2020.blackboxnlp-1.23. URL https://aclanthology.org/2020.blackboxnlp-1.23.
  • Stolfo et al. (2023) A. Stolfo, Y. Belinkov, and M. Sachan. A mechanistic interpretation of arithmetic reasoning in language models using causal mediation analysis, 2023.
  • Syed et al. (2023) A. Syed, C. Rager, and A. Conmy. Attribution patching outperforms automated circuit discovery, 2023.
  • Tigges et al. (2023) C. Tigges, O. J. Hollinsworth, A. Geiger, and N. Nanda. Linear representations of sentiment in large language models, 2023.
  • Todd et al. (2023) E. Todd, M. L. Li, A. S. Sharma, A. Mueller, B. C. Wallace, and D. Bau. Function vectors in large language models, 2023.
  • Turner et al. (2023) A. M. Turner, L. Thiergart, D. Udell, G. Leech, U. Mini, and M. MacDiarmid. Activation addition: Steering language models without optimization, 2023.
  • Vaswani et al. (2017) A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, L. Kaiser, and I. Polosukhin. Attention is all you need, 2017.
  • Veit et al. (2016) A. Veit, M. J. Wilber, and S. Belongie. Residual networks behave like ensembles of relatively shallow networks. In D. Lee, M. Sugiyama, U. Luxburg, I. Guyon, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 29. Curran Associates, Inc., 2016. URL https://proceedings.neurips.cc/paper_files/paper/2016/file/37bc2f75bf1bcfe8450a1a41c200364c-Paper.pdf.
  • Vig et al. (2020) J. Vig, S. Gehrmann, Y. Belinkov, S. Qian, D. Nevo, Y. Singer, and S. Shieber. Investigating gender bias in language models using causal mediation analysis. In H. Larochelle, M. Ranzato, R. Hadsell, M. Balcan, and H. Lin, editors, Advances in Neural Information Processing Systems, volume 33, pages 12388–12401. Curran Associates, Inc., 2020. URL https://proceedings.neurips.cc/paper_files/paper/2020/file/92650b2e92217715fe312e6fa7b90d82-Paper.pdf.
  • Wang et al. (2022) K. Wang, A. Variengien, A. Conmy, B. Shlegeris, and J. Steinhardt. Interpretability in the wild: a circuit for indirect object identification in gpt-2 small, 2022.
  • Welch (1947) B. L. Welch. The generalization of ‘Student’s’ problem when several different population variances are involved. Biometrika, 34(1-2):28–35, 01 1947. ISSN 0006-3444. 10.1093/biomet/34.1-2.28. URL https://doi.org/10.1093/biomet/34.1-2.28.
  • Zou et al. (2023) A. Zou, L. Phan, S. Chen, J. Campbell, P. Guo, R. Ren, A. Pan, X. Yin, M. Mazeika, A.-K. Dombrowski, S. Goel, N. Li, M. J. Byun, Z. Wang, A. Mallen, S. Basart, S. Koyejo, D. Song, M. Fredrikson, J. Z. Kolter, and D. Hendrycks. Representation engineering: A top-down approach to ai transparency, 2023.

Appendix A Method details

A.1 Baselines

A.1.1 Properties of Subsampling

Here we prove that the subsampling estimator $\hat{\mathcal{I}}_{\text{SS}}(n)$ from Section 3.3 is unbiased when there are no interaction effects. Furthermore, under a simple interaction model, we show that the bias of $\hat{\mathcal{I}}_{\text{SS}}(n)$ is $p$ times the total interaction effect of $n$ with the other nodes. Specifically, we assume a pairwise interaction model: given a set of nodes $\eta$, we have

$$\mathcal{I}(\eta;x) = \sum_{n\in\eta}\mathcal{I}(n;x) + \sum_{\{n,n'\}\subseteq\eta,\; n\neq n'}\sigma_{n,n'}(x) \tag{16}$$

with fixed constants $\sigma_{n,n'}(x)\in\mathbb{R}$ for each prompt pair $x\in\operatorname{support}(\mathcal{D})$, where the second sum runs over unordered pairs of distinct nodes in $\eta$. Let $\sigma_{n,n'} = \mathbb{E}_{x\sim\mathcal{D}}\left[\sigma_{n,n'}(x)\right]$ (abbreviated $\sigma_{nn'}$ below), and write $c(n) = \mathbb{E}_{x\sim\mathcal{D}}\left[\mathcal{I}(n;x)\right]$ for the mean effect of patching $n$ alone.

Let $p$ be the probability of including each node in a given $\eta$, and let $M$ be the number of samples, each consisting of a node mask drawn from $\operatorname{Bernoulli}^{|N|}(p)$ and a prompt pair $x$ drawn from $\mathcal{D}$. Here $\eta^{+}(n)$ and $\eta^{-}(n)$ denote the sampled masks that do and do not contain $n$, respectively, with corresponding prompt pairs $x_k^{+}$ and $x_k^{-}$. Then,

$$
\begin{aligned}
\mathbb{E}\left[\hat{\mathcal{I}}_{\text{SS}}(n)\right]
&= \mathbb{E}\left[\frac{1}{|\eta^{+}(n)|}\sum_{k=1}^{|\eta^{+}(n)|}\mathcal{I}(\eta^{+}_{k}(n);x_{k}^{+}) - \frac{1}{|\eta^{-}(n)|}\sum_{k=1}^{|\eta^{-}(n)|}\mathcal{I}(\eta^{-}_{k}(n);x_{k}^{-})\right] && \text{(17a)}\\
&= \mathbb{E}\left[\mathbb{E}\left[\frac{1}{|\eta^{+}(n)|}\sum_{k=1}^{|\eta^{+}(n)|}\mathcal{I}(\eta^{+}_{k}(n);x_{k}^{+}) - \frac{1}{|\eta^{-}(n)|}\sum_{k=1}^{|\eta^{-}(n)|}\mathcal{I}(\eta^{-}_{k}(n);x_{k}^{-}) \,\middle|\, |\eta^{+}(n)|\right]\right] && \text{(17b)}\\
&= \mathbb{E}\left[\mathbb{E}\left[\frac{|\eta^{+}(n)|}{|\eta^{+}(n)|}\mathbb{E}\left[\mathcal{I}(\eta_{1};x_{1}) \,\middle|\, n\in\eta_{1}\right] - \frac{|\eta^{-}(n)|}{|\eta^{-}(n)|}\mathbb{E}\left[\mathcal{I}(\eta_{1};x_{1}) \,\middle|\, n\notin\eta_{1}\right] \,\middle|\, |\eta^{+}(n)|\right]\right] && \text{(17c)}\\
&= \mathbb{E}\left[\mathcal{I}(\eta_{1};x_{1}) \,\middle|\, n\in\eta_{1}\right] - \mathbb{E}\left[\mathcal{I}(\eta_{1};x_{1}) \,\middle|\, n\notin\eta_{1}\right] && \text{(17d)}\\
&= c(n) + \mathbb{E}\left[\sum_{n'\neq n}\mathbb{1}[n'\in\eta_{1}]\left(c(n') + \sigma_{nn'} + \frac{1}{2}\sum_{n''\notin\{n',n\}}\mathbb{1}[n''\in\eta_{1}]\,\sigma_{n'n''}\right) \,\middle|\, n\in\eta_{1}\right] && \text{(17e)}\\
&\quad - \mathbb{E}\left[\sum_{n'\neq n}\mathbb{1}[n'\in\eta_{1}]\left(c(n') + \frac{1}{2}\sum_{n''\notin\{n',n\}}\mathbb{1}[n''\in\eta_{1}]\,\sigma_{n'n''}\right) \,\middle|\, n\notin\eta_{1}\right] && \text{(17f)}\\
&= c(n) + p\sum_{n'\neq n}\sigma_{nn'} && \text{(17g)}
\end{aligned}
$$

Equation 17g shows that if the interaction terms $\sigma_{nn'}$ are all zero, the estimator is unbiased. Otherwise, the bias is $p\sum_{n'\neq n}\sigma_{nn'}$, which scales with both the total interaction effect of $n$ and the inclusion probability $p$, as expected.
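As a sanity check on Equation 17g, the NumPy sketch below (our own illustration, not code from this paper) instantiates the pairwise model of Equation 16 with a single fixed prompt pair, random contributions $c(n)$ and random symmetric interactions $\sigma_{nn'}$. It computes the expectation in Equation 17d exactly by enumerating all $2^{N}$ masks, and also runs a Monte Carlo version of the subsampling estimator. The helper names, $N=6$ and $p=0.3$ are arbitrary assumptions chosen for illustration.

```python
import itertools

import numpy as np


def pairwise_effect(mask, c, sigma):
    """I(eta) under the pairwise model of Eq. 16, with the prompt pair held fixed.

    mask has shape (..., N) and marks the patched nodes. sigma is symmetric with
    a zero diagonal, so 0.5 * m^T sigma m counts each unordered pair in eta once.
    """
    m = mask.astype(float)
    return m @ c + 0.5 * np.sum((m @ sigma) * m, axis=-1)


def subsampling_estimate(c, sigma, p, num_masks, rng):
    """Monte Carlo subsampling estimate for every node: mean effect of the
    masks that include n minus mean effect of the masks that exclude n."""
    masks = rng.random((num_masks, len(c))) < p
    effects = pairwise_effect(masks, c, sigma)
    return np.array([effects[masks[:, n]].mean() - effects[~masks[:, n]].mean()
                     for n in range(len(c))])


def exact_expectation(c, sigma, p):
    """E[I | n in eta] - E[I | n not in eta] (Eq. 17d), by enumerating all masks."""
    N = len(c)
    masks = np.array(list(itertools.product([False, True], repeat=N)))
    probs = np.prod(np.where(masks, p, 1 - p), axis=1)  # Bernoulli(p) mask weights
    effects = pairwise_effect(masks, c, sigma)
    out = np.empty(N)
    for n in range(N):
        inc = masks[:, n]
        out[n] = probs[inc] @ effects[inc] / p - probs[~inc] @ effects[~inc] / (1 - p)
    return out


rng = np.random.default_rng(0)
N, p = 6, 0.3
c = rng.normal(size=N)                                  # c(n)
upper = np.triu(rng.normal(scale=0.5, size=(N, N)), 1)
sigma = upper + upper.T                                 # symmetric, zero diagonal

predicted = c + p * sigma.sum(axis=1)                   # Eq. 17g
print(np.allclose(exact_expectation(c, sigma, p), predicted))  # True
# Monte Carlo error; shrinks as num_masks grows.
print(np.abs(subsampling_estimate(c, sigma, p, 200_000, rng) - predicted).max())
```

Replacing `sigma` with a zero matrix makes `exact_expectation` return `c` (up to floating-point error), which is the unbiased case.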

A.1.2 Pseudocode for Blocks and Hierarchical baselines

Algorithm 2 details the Blocks baseline. As explained in Section 3.3, it involves a tradeoff in its block size hyperparameter $B$: with a small block size, much of the budget goes to evaluating all the blocks, while with a large block size, each high-contribution block contains many irrelevant nodes that must also be evaluated.

Algorithm 2 Blocks algorithm for causal attribution.
1: Input: block size $B$, compute budget $M$, nodes $N=\{n_i\}$, prompts $x^{\text{clean}}, x^{\text{noise}}$, intervention function $\tilde{\mathcal{I}}:\eta\mapsto\mathcal{I}(\eta;x^{\text{clean}},x^{\text{noise}})$
2: $\mathrm{numBlocks}\leftarrow\lceil|N|/B\rceil$
3: $\pi\leftarrow\operatorname{shuffle}\left(\left\{\left\lfloor\mathrm{numBlocks}\cdot i/|N|\right\rfloor \mid i\in\{0,\dots,|N|-1\}\right\}\right)$ ▷ Assign each node to a block.
4: for $i\leftarrow 0$ to $\mathrm{numBlocks}-1$ do
5:     $\mathrm{blockContribution}[i]\leftarrow|\tilde{\mathcal{I}}(\pi^{-1}(\{i\}))|$ ▷ $\pi^{-1}(\{i\}):=\{n\in N:\pi(n)=i\}$
6: $\mathrm{spentBudget}\leftarrow\mathrm{numBlocks}$
7: $\mathrm{topNodeContribs}\leftarrow\operatorname{CreateEmptyDictionary}()$
8: for all $i\in\{0,\dots,\mathrm{numBlocks}-1\}$ in decreasing order of $\mathrm{blockContribution}[i]$ do
9:     for all $n\in\pi^{-1}(\{i\})$ do ▷ Evaluate all nodes in the block.
10:         if $\mathrm{spentBudget}<M$ then
11:             $\mathrm{topNodeContribs}[n]\leftarrow|\tilde{\mathcal{I}}(\{n\})|$
12:             $\mathrm{spentBudget}\leftarrow\mathrm{spentBudget}+1$
13:         else
14:             return topNodeContribs
15: return topNodeContribs
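For readers who want to experiment with the Blocks baseline, a minimal Python sketch of Algorithm 2 follows. It assumes the patched metric is exposed as a hypothetical black-box callable `intervene(eta)` returning $\mathcal{I}(\eta;x^{\text{clean}},x^{\text{noise}})$ for a set of nodes (in practice, one patched forward pass); the function and argument names are ours and batching is omitted. The budget counter starts at the number of block evaluations already performed, so `budget` counts every patched forward pass.

```python
import math
import random
from typing import Callable, Dict, FrozenSet, Hashable, List, Sequence


def blocks_attribution(
    nodes: Sequence[Hashable],
    intervene: Callable[[FrozenSet[Hashable]], float],
    block_size: int,
    budget: int,
    seed: int = 0,
) -> Dict[Hashable, float]:
    """Blocks baseline: patch random blocks of nodes, then evaluate the nodes of
    the highest-|effect| blocks one by one until the budget is spent."""
    num_blocks = math.ceil(len(nodes) / block_size)
    # Spread nodes evenly over blocks, then shuffle the assignment (line 3).
    assignment = [num_blocks * i // len(nodes) for i in range(len(nodes))]
    random.Random(seed).shuffle(assignment)
    blocks: List[List[Hashable]] = [[] for _ in range(num_blocks)]
    for node, b in zip(nodes, assignment):
        blocks[b].append(node)

    # One patched forward pass per block (lines 4-5).
    block_contribution = [abs(intervene(frozenset(block))) for block in blocks]
    spent = num_blocks
    top_node_contribs: Dict[Hashable, float] = {}
    # Visit blocks in decreasing order of |block effect| (lines 8-14).
    for b in sorted(range(num_blocks), key=lambda j: -block_contribution[j]):
        for node in blocks[b]:
            if spent >= budget:
                return top_node_contribs
            top_node_contribs[node] = abs(intervene(frozenset([node])))
            spent += 1
    return top_node_contribs


# Toy additive "model": only three nodes have a nonzero effect.
true_effect = {7: 5.0, 42: -3.0, 91: 2.0}


def toy_intervene(eta):
    return sum(true_effect.get(n, 0.0) for n in eta)


result = blocks_attribution(list(range(100)), toy_intervene, block_size=10, budget=40)
print(sorted(n for n, v in result.items() if v > 0))  # [7, 42, 91]
```

With 100 nodes in blocks of 10 and a budget of 40, ten passes go to the blocks and the remaining thirty to the nodes of the three highest-|effect| blocks, which in this toy case always contain the three nonzero nodes.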

The Hierarchical baseline algorithm aims to resolve this tradeoff by using small blocks grouped into superblocks, so that it is not necessary to traverse all the small blocks before finding the key nodes. Algorithm 3 details the hierarchical algorithm in its iterative form, corresponding to batch size 1.

One aspect that might be surprising is that on line 22 we ensure a subblock is never added to the priority queue with a higher priority than its ancestor superblocks. The reason is that in practice we use batched inference rather than patching a single block at a time, so, depending on the batch size, we also evaluate blocks that are not the highest-priority unevaluated blocks, and this might significantly delay the evaluation of some blocks. To reduce this dependence on the batch size hyperparameter, line 22 ensures that every block is evaluated at most $L$ batches later than it would be with batch size 1.

Algorithm 3 Hierarchical algorithm for causal attribution, in iterative form. In practice we do additional batching rather than evaluating a single block at a time on line 15.
1: Input: branching factor $B$, number of levels $L$, compute budget $M$, nodes $N=\{n_i\}$, intervention function $\mathcal{I}$
2: $\mathrm{numTopLevelBlocks}\leftarrow\lceil|N|/B^{L}\rceil$
3: $\pi\leftarrow\operatorname{shuffle}\left(\left\{\left\lfloor\mathrm{numTopLevelBlocks}\cdot iB^{L}/|N|\right\rfloor \mid i\in\{0,\dots,|N|-1\}\right\}\right)$
4: for all $n_i\in N$ do
5:     $(d_{L-1},d_{L-2},\dots,d_{0})\leftarrow$ zero-padded final $L$ base-$B$ digits of $\pi_i$
6:     $\mathrm{address}(n_i)\leftarrow(\lfloor\pi_i/B^{L}\rfloor,d_{L-1},\dots,d_{0})$
7: $Q\leftarrow\operatorname{CreateEmptyPriorityQueue}()$
8: for $i\leftarrow 0$ to $\mathrm{numTopLevelBlocks}-1$ do
9:     $\operatorname{PriorityQueueInsert}(Q,[i],\infty)$
10: $\mathrm{spentBudget}\leftarrow 0$
11: $\mathrm{topNodeContribs}\leftarrow\operatorname{CreateEmptyDictionary}()$
12: repeat
13:     $(\mathrm{addressPrefix},\mathrm{priority})\leftarrow\operatorname{PriorityQueuePop}(Q)$
14:     $\mathrm{blockNodes}\leftarrow\{n\in N\mid\operatorname{StartsWith}(\mathrm{address}(n),\mathrm{addressPrefix})\}$
15:     $\mathrm{blockContribution}\leftarrow|\mathcal{I}(\mathrm{blockNodes})|$
16:     $\mathrm{spentBudget}\leftarrow\mathrm{spentBudget}+1$
17:     if $\mathrm{blockNodes}=\{n\}$ for some $n\in N$ then
18:         $\mathrm{topNodeContribs}[n]\leftarrow\mathrm{blockContribution}$
19:     else
20:         for $i\leftarrow 0$ to $B-1$ do
21:             if $\{n\in\mathrm{blockNodes}\mid\operatorname{StartsWith}(\mathrm{address}(n),\mathrm{addressPrefix}+[i])\}\neq\emptyset$ then
22:                 $\operatorname{PriorityQueueInsert}(Q,\mathrm{addressPrefix}+[i],\min(\mathrm{blockContribution},\mathrm{priority}))$
23: until $\mathrm{spentBudget}=M$ or $\operatorname{PriorityQueueEmpty}(Q)$
24: return topNodeContribs
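The sketch below renders Algorithm 3 in Python at batch size 1, under the same assumption of a hypothetical black-box `intervene(eta)`. Addresses are tuples, so StartsWith becomes a tuple-prefix comparison; Python's `heapq` is a min-heap, so priorities are stored negated and a counter breaks ties. All names are illustrative, not taken from the authors' code.

```python
import heapq
import itertools
import math
import random


def hierarchical_attribution(nodes, intervene, branching, levels, budget, seed=0):
    """Sequential (batch size 1) sketch of the Hierarchical baseline."""
    top_size = branching ** levels
    num_top = math.ceil(len(nodes) / top_size)
    # Spread nodes over num_top * B^L leaf slots, then shuffle (line 3).
    slots = [num_top * i * top_size // len(nodes) for i in range(len(nodes))]
    random.Random(seed).shuffle(slots)
    address = {}
    for node, slot in zip(nodes, slots):
        digits = tuple((slot // branching ** j) % branching for j in reversed(range(levels)))
        address[node] = (slot // top_size,) + digits  # (top-level block, d_{L-1}, ..., d_0)

    # heapq is a min-heap, so priorities are stored negated; the counter breaks ties.
    tie = itertools.count()
    heap = [(-math.inf, next(tie), (i,)) for i in range(num_top)]
    heapq.heapify(heap)
    spent, top_node_contribs = 0, {}
    while heap and spent < budget:
        neg_priority, _, prefix = heapq.heappop(heap)
        # O(|N|) scan per step: fine for a sketch, not for real sweeps.
        block = [n for n in nodes if address[n][:len(prefix)] == prefix]
        contribution = abs(intervene(frozenset(block)))
        spent += 1
        if len(block) == 1:
            top_node_contribs[block[0]] = contribution
            continue
        # Never give a sub-block a higher priority than its ancestors (line 22).
        child_priority = min(contribution, -neg_priority)
        for i in range(branching):
            child = prefix + (i,)
            if any(address[n][:len(child)] == child for n in block):
                heapq.heappush(heap, (-child_priority, next(tie), child))
    return top_node_contribs


true_effect = {7: 5.0, 42: -3.0, 200: 2.0}  # toy additive model; other nodes have zero effect


def toy_intervene(eta):
    return sum(true_effect.get(n, 0.0) for n in eta)


found = hierarchical_attribution(list(range(256)), toy_intervene, branching=4, levels=3, budget=48)
print({n: found[n] for n in true_effect})  # {7: 5.0, 42: 3.0, 200: 2.0}
```

In this toy example no subset of the nonzero nodes sums to zero, so every block containing one of them keeps a positive priority and all three are reached within the budget, before any zero-priority block is popped.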

A.2 AtP improvements

A.2.1 Pseudocode for corrected AtP on attention keys

As described in Section 3.1.1, computing Equation 10 naïvely for all nodes requires $\operatorname{O}(T^{3})$ flops at each attention head and prompt pair. Here we give a more efficient algorithm running in $\operatorname{O}(T^{2})$. In addition to keys, queries and attention probabilities, we now also cache attention logits (pre-softmax scaled key-query dot products).

We define $\operatorname{attnLogits}^t_{\text{patch}}(n^q)$ and $\Delta_t\operatorname{attnLogits}(n^q)$ analogously to Equations 8 and 9. For brevity we also define $\operatorname{attnLogits}_{\text{patch}}(n^q)_t:=\operatorname{attnLogits}^t_{\text{patch}}(n^q)_t$ and $\Delta\operatorname{attnLogits}(n^q)_t:=\Delta_t\operatorname{attnLogits}(n^q)_t$. This shorthand suffices because patching key $n^k_t$ leaves every other component of $\operatorname{attnLogits}(n^q)$ unchanged, and the aim of this algorithm is to avoid separately computing the effect of $\operatorname{do}(n^k_t\leftarrow n^k_t(x^{\text{noise}}))$ on any component of $\operatorname{attnLogits}$ other than the one for key node $n^k_t$.

Note that, for a key $n^k_t$ at position $t$ in the sequence, the non-$t$ components of $\operatorname{attn}(n^q)$ keep their relative proportions when $\operatorname{attnLogits}(n^q)_t$ is changed (changing one softmax input only rescales the shared normalizer). Hence $\Delta_t\operatorname{attn}(n^q)$ is $\mathrm{onehot}(t)-\operatorname{attn}(n^q)$ multiplied by some scalar $s_t$; specifically, to get the right attention weight on $n^k_t$, the scalar must be $s_t:=\frac{\Delta_t\operatorname{attn}(n^q)_t}{1-\operatorname{attn}(n^q)_t}$.
Additionally, we have $\log\left(\frac{\operatorname{attn}^t_{\text{patch}}(n^q)_t}{1-\operatorname{attn}^t_{\text{patch}}(n^q)_t}\right)=\log\left(\frac{\operatorname{attn}(n^q)_t}{1-\operatorname{attn}(n^q)_t}\right)+\Delta\operatorname{attnLogits}(n^q)_t$. Since the logodds function $p\mapsto\log\left(\frac{p}{1-p}\right)$ is the inverse of the sigmoid function $\sigma$, this gives $\operatorname{attn}^t_{\text{patch}}(n^q)_t=\sigma\left(\log\left(\frac{\operatorname{attn}(n^q)_t}{1-\operatorname{attn}(n^q)_t}\right)+\Delta\operatorname{attnLogits}(n^q)_t\right)$. Putting this together, we can compute all of $\operatorname{attnLogits}_{\text{patch}}(n^q)$ by combining all keys from the $x^{\text{noise}}$ forward pass with all queries from the $x^{\text{clean}}$ forward pass, then compute $\Delta\operatorname{attnLogits}(n^q)$, all $\Delta_t\operatorname{attn}(n^q)_t$, and thus all $\hat{\mathcal{I}}^K_{\text{AtPfix}}(n_t;x^{\text{clean}},x^{\text{noise}})$, using $\operatorname{O}(T^2)$ flops per attention head.
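The two facts used in this derivation can be checked numerically for a single query. The plain-Python sketch below (helper names are hypothetical) compares the $\operatorname{O}(T)$ logodds shortcut for every patched self-weight $\operatorname{attn}^t_{\text{patch}}(n^q)_t$ against a brute-force loop that re-runs the softmax once per patched key ($\operatorname{O}(T^2)$ per query), and measures how far $\Delta_t\operatorname{attn}(n^q)$ is from $s_t(\mathrm{onehot}(t)-\operatorname{attn}(n^q))$. It assumes the row of patched logits for this query is already available, e.g. from the clean query dotted with each noisy key.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


def sigmoid(x: float) -> float:
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def logodds(p: float) -> float:
    # log(p / (1 - p)), assuming 0 < p < 1.
    return math.log(p) - math.log1p(-p)


def patched_self_weights(logits: Sequence[float], patched_logits: Sequence[float]) -> List[float]:
    """attn_patch^t(n^q)_t for every t in O(T): shift each clean logodds by the logit change."""
    attn = softmax(logits)
    return [sigmoid(logodds(p) + (lp - l)) for p, l, lp in zip(attn, logits, patched_logits)]


def patched_self_weights_brute(logits: Sequence[float], patched_logits: Sequence[float]) -> List[float]:
    """Reference: one full softmax per patched key, O(T^2) per query."""
    out = []
    for t in range(len(logits)):
        x = list(logits)
        x[t] = patched_logits[t]
        out.append(softmax(x)[t])
    return out


def delta_attn_direction_error(logits: Sequence[float], patched_logits: Sequence[float], t: int) -> float:
    """Max |Delta_t attn - s_t (onehot(t) - attn)| over components; ~0 up to rounding."""
    attn = softmax(logits)
    x = list(logits)
    x[t] = patched_logits[t]
    delta = [p - q for p, q in zip(softmax(x), attn)]
    s_t = delta[t] / (1.0 - attn[t])
    return max(abs(delta[j] - s_t * ((1.0 if j == t else 0.0) - attn[j])) for j in range(len(logits)))
```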

Algorithm 4 computes the contribution of a query node $n^q$ and prompt pair $x^{\text{clean}},x^{\text{noise}}$ to the corrected AtP estimates $\hat{c}^K_{\text{AtPfix}}(n^k_t)$ for the key nodes $n^k_1,\dots,n^k_T$ of a single attention head, using $O(T)$ flops while avoiding numerical overflow.
We reuse the notation $\operatorname{attn}(n^q)$, $\operatorname{attn}^t_{\text{patch}}(n^q)$, $\Delta_t\operatorname{attn}(n^q)$, $\operatorname{attnLogits}(n^q)$, $\operatorname{attnLogits}_{\text{patch}}(n^q)$, and $s_t$ from Section 3.1.1, leaving the prompt pair implicit.

Algorithm 4 AtP correction for attention keys
1: $\mathbf{a}:=\operatorname{attnLogits}(n^q)$, $\mathbf{a}^{\text{patch}}:=\operatorname{attnLogits}_{\text{patch}}(n^q)$, $\mathbf{g}:=\frac{\partial\mathcal{L}(\mathcal{M}(x^{\text{clean}}))}{\partial\operatorname{attn}(n^q)}$
2: $t^*\leftarrow\operatorname{argmax}_t(a_t)$
3: $\ell\leftarrow\mathbf{a}-a_{t^*}-\log\left(\sum_t e^{a_t-a_{t^*}}\right)$ ▷ Clean log attn weights, $\ell=\log(\operatorname{attn}(n^q))$
4: $\mathbf{d}\leftarrow\ell-\log(1-e^{\ell})$ ▷ Clean logodds, $d_t=\log\left(\frac{\operatorname{attn}(n^q)_t}{1-\operatorname{attn}(n^q)_t}\right)$
5: $d_{t^*}\leftarrow a_{t^*}-\max_{t\neq t^*}a_t-\log\left(\sum_{t'\neq t^*}e^{a_{t'}-\max_{t\neq t^*}a_t}\right)$ ▷ Adjust $\mathbf{d}$; more stable for $a_{t^*}\gg\max_{t\neq t^*}a_t$
6: $\ell^{\text{patch}}\leftarrow\operatorname{logsigmoid}(\mathbf{d}+\mathbf{a}^{\text{patch}}-\mathbf{a})$ ▷ Patched log attn weights, $\ell^{\text{patch}}_t=\log(\operatorname{attn}^t_{\text{patch}}(n^q)_t)$
7: $\Delta\ell\leftarrow\ell^{\text{patch}}-\ell$ ▷ $\Delta\ell_t=\log\left(\frac{\operatorname{attn}^t_{\text{patch}}(n^q)_t}{\operatorname{attn}(n^q)_t}\right)$
8: $b\leftarrow\operatorname{softmax}(\mathbf{a})^\intercal\mathbf{g}$ ▷ $b=\operatorname{attn}(n^q)^\intercal\mathbf{g}$
9: for $t\leftarrow 1$ to $T$ do
10:     ▷ Compute scaling factor $s_t:=\frac{\Delta_t\operatorname{attn}(n^q)_t}{1-\operatorname{attn}(n^q)_t}$
11:     if $\ell^{\text{patch}}_t>\ell_t$ then ▷ Avoid overflow when $\ell^{\text{patch}}_t\gg\ell_t$
12:         $s_t\leftarrow e^{d_t+\Delta\ell_t+\log(1-e^{-\Delta\ell_t})}$ ▷ $s_t=\frac{\operatorname{attn}(n^q)_t}{1-\operatorname{attn}(n^q)_t}\cdot\frac{\operatorname{attn}^t_{\text{patch}}(n^q)_t}{\operatorname{attn}(n^q)_t}\left(1-\frac{\operatorname{attn}(n^q)_t}{\operatorname{attn}^t_{\text{patch}}(n^q)_t}\right)$
13:     else ▷ Avoid overflow when $\ell^{\text{patch}}_t\ll\ell_t$
14:         $s_t\leftarrow -e^{d_t+\log(1-e^{\Delta\ell_t})}$ ▷ $s_t=-\frac{\operatorname{attn}(n^q)_t}{1-\operatorname{attn}(n^q)_t}\left(1-\frac{\operatorname{attn}^t_{\text{patch}}(n^q)_t}{\operatorname{attn}(n^q)_t}\right)$
15:     $r_t\leftarrow s_t(g_t-b)$ ▷ $r_t=s_t(\mathrm{onehot}(t)-\operatorname{attn}(n^q))^\intercal\mathbf{g}=\Delta_t\operatorname{attn}(n^q)\cdot\frac{\partial\mathcal{L}(\mathcal{M}(x^{\text{clean}}))}{\partial\operatorname{attn}(n^q)}$
16: return $\mathbf{r}$
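For concreteness, here is a minimal pure-Python sketch of Algorithm 4 for one query node, paired with a brute-force reference that recomputes the softmax once per patched key. The function names, the use of Python lists rather than tensors, and the requirement $T\ge 2$ are assumptions of this sketch. Two small deviations from the pseudocode avoid evaluating $\log 0$ in floating point: line 4 is evaluated only for $t\neq t^*$ (line 5 overwrites $d_{t^*}$ anyway), and when $\Delta\ell_t=0$ exactly the sketch sets $s_t=0$, which is the value line 14 takes in exact arithmetic.

```python
import math
from typing import List, Sequence


def _logsigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def atp_key_correction(a: Sequence[float], a_patch: Sequence[float], g: Sequence[float]) -> List[float]:
    """Sketch of Algorithm 4: r[t] = Delta_t attn(n^q) . g for every key position t."""
    T = len(a)
    if T < 2:
        raise ValueError("this sketch assumes at least two key positions")
    t_star = max(range(T), key=lambda t: a[t])                              # line 2
    log_z = math.log(sum(math.exp(x - a[t_star]) for x in a))
    ell = [x - a[t_star] - log_z for x in a]                                # line 3
    others = [t for t in range(T) if t != t_star]
    d = [0.0] * T
    for t in others:                                                         # line 4 (t != t*)
        d[t] = ell[t] - math.log(-math.expm1(ell[t]))  # -expm1(l) = 1 - e^l, computed accurately
    m = max(a[t] for t in others)                                            # line 5
    d[t_star] = a[t_star] - m - math.log(sum(math.exp(a[t] - m) for t in others))
    ell_patch = [_logsigmoid(d[t] + a_patch[t] - a[t]) for t in range(T)]  # line 6
    dl = [ell_patch[t] - ell[t] for t in range(T)]                          # line 7
    b = sum(math.exp(ell[t]) * g[t] for t in range(T))                       # line 8: attn^T g
    r = []
    for t in range(T):                                                       # lines 9-15
        if ell_patch[t] > ell[t]:      # line 12: weight on key t increases
            s_t = math.exp(d[t] + dl[t] + math.log(-math.expm1(-dl[t])))
        elif dl[t] == 0.0:             # unchanged weight: line 14 tends to 0
            s_t = 0.0
        else:                          # line 14: weight on key t decreases
            s_t = -math.exp(d[t] + math.log(-math.expm1(dl[t])))
        r.append(s_t * (g[t] - b))     # line 15
    return r


def brute_force_key_effects(a: Sequence[float], a_patch: Sequence[float], g: Sequence[float]) -> List[float]:
    """Reference: (softmax(a with a_t replaced) - softmax(a)) . g for each t, O(T^2) per query."""
    def softmax(xs: Sequence[float]) -> List[float]:
        m = max(xs)
        e = [math.exp(x - m) for x in xs]
        z = sum(e)
        return [v / z for v in e]

    base = softmax(a)
    out = []
    for t in range(len(a)):
        x = list(a)
        x[t] = a_patch[t]
        out.append(sum((p - q) * gj for p, q, gj in zip(softmax(x), base, g)))
    return out
```

In a practical implementation the loop over $t$ would typically be vectorized across keys, queries, and heads; the scalar form here is only meant to mirror the pseudocode line by line.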

The corrected AtP estimates $\hat{c}^K_{\text{AtPfix}}(n^k_t)$ can then be computed using Equation 10; in other words, by summing the $r_t$ returned by Algorithm 4 over the queries $n^q$ of this attention head, and averaging over $x^{\text{clean}},x^{\text{noise}}\sim\mathcal{D}$.

A.2.2 Properties of GradDrop

In Section 3.1.2 we introduced GradDrop to address an AtP failure mode arising from cancellation between direct and indirect effects. Roughly, if the total effect (on some prompt pair) is $\mathcal{I}(n)=\mathcal{I}^{\text{direct}}(n)+\mathcal{I}^{\text{indirect}}(n)$ and these two terms nearly cancel, then a small multiplicative approximation error in $\hat{\mathcal{I}}^{\text{indirect}}_{\text{AtP}}(n)$, due to nonlinearities, can accidentally make $|\hat{\mathcal{I}}^{\text{direct}}_{\text{AtP}}(n)+\hat{\mathcal{I}}^{\text{indirect}}_{\text{AtP}}(n)|$ orders of magnitude smaller than $|\mathcal{I}(n)|$.

To address this failure mode with an improved estimator $\hat{c}_{\text{AtP+GD}}(n)$, GradDrop should satisfy three desiderata:

  1. $\hat{c}_{\text{AtP+GD}}(n)$ shouldn't be much smaller than $\hat{c}_{\text{AtP}}(n)$, because that would risk creating more false negatives.

  2. $\hat{c}_{\text{AtP+GD}}(n)$ should usually not be much larger than $\hat{c}_{\text{AtP}}(n)$, because that would create false positives, which also slow down verification and can effectively create false negatives at a given budget.

  3. If $\hat{c}_{\text{AtP}}(n)$ suffers from the cancellation failure mode, then $\hat{c}_{\text{AtP+GD}}(n)$ should be significantly larger than $\hat{c}_{\text{AtP}}(n)$.

Recall how GradDrop was defined in Section 3.1.2, using a virtual node $n^{\text{out}}_\ell$ to represent the residual-stream contributions of layer $\ell$:

$$
\begin{aligned}
\hat{c}_{\text{AtP+GD}}(n) &:= \mathbb{E}_{x^{\text{clean}},x^{\text{noise}}}\left[\frac{1}{L-1}\sum_{\ell=1}^{L}\left|\hat{\mathcal{I}}_{\text{AtP+GD}_\ell}(n;x^{\text{clean}},x^{\text{noise}})\right|\right]\\
&= \mathbb{E}_{x^{\text{clean}},x^{\text{noise}}}\left[\frac{1}{L-1}\sum_{\ell=1}^{L}\left|(n(x^{\text{noise}})-n(x^{\text{clean}}))^\intercal\frac{\partial\mathcal{L}^\ell}{\partial n}\right|\right]\\
&= \mathbb{E}_{x^{\text{clean}},x^{\text{noise}}}\left[\frac{1}{L-1}\sum_{\ell=1}^{L}\left|(n(x^{\text{noise}})-n(x^{\text{clean}}))^\intercal\frac{\partial\mathcal{L}}{\partial n}\left(\mathcal{M}(x^{\text{clean}}\mid\operatorname{do}(n^{\text{out}}_\ell\leftarrow n^{\text{out}}_\ell(x^{\text{clean}})))\right)\right|\right]
\end{aligned}
$$
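A toy scalar example may help build intuition for this definition. The sketch assumes a hypothetical linear residual network in which node $n$ is written into the residual stream before layer 1 and each layer $\ell$ adds $w_\ell$ times its input; GradDrop at layer $\ell$ is emulated by a stop-gradient on that layer's output in a hand-written backward pass. Layer 1's path nearly cancels the direct path, so the plain AtP estimate is small, while dropping the gradient through layer 1 exposes the much larger un-cancelled term (desideratum 3). Because the toy is linear it has no AtP approximation error, so it illustrates the mechanics of GradDrop rather than the nonlinear failure mode itself.

```python
from typing import List, Optional, Sequence


def node_gradient(weights: Sequence[float], dropped_layer: Optional[int] = None) -> float:
    """dL/dn for a toy net: r_0 = n, r_l = r_{l-1} + w_l * r_{l-1}, loss = r_L.

    Layers are numbered 1..L. If dropped_layer is set, that layer's output is treated
    as a constant in the backward pass (a stop-gradient), mimicking
    do(n_l^out <- n_l^out(x_clean)).
    """
    grad = 1.0  # d loss / d r_L
    for layer in range(len(weights), 0, -1):  # backward from layer L down to layer 1
        gain = 0.0 if layer == dropped_layer else weights[layer - 1]
        grad *= 1.0 + gain  # skip connection (1) plus the layer's own path (gain)
    return grad


def graddrop_estimates(delta_n: float, weights: Sequence[float]) -> List[float]:
    """Per-layer GradDrop estimates (n(x_noise) - n(x_clean)) * dL^l/dn for l = 1..L."""
    return [delta_n * node_gradient(weights, layer) for layer in range(1, len(weights) + 1)]


# Layer 1's path (-0.9) nearly cancels the direct path (+1); later layers rescale both.
weights = [-0.9, 0.5, 0.2]
atp = 1.0 * node_gradient(weights)            # plain AtP estimate, about 0.18
per_layer = graddrop_estimates(1.0, weights)  # about [1.8, 0.12, 0.15]
```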

To better understand the behaviour of GradDrop, let's look more carefully at the gradient $\frac{\partial\mathcal{L}}{\partial n}$. It can be expressed as a sum of path gradients over all paths from the node $n$ to the output. Each path is characterized by the set of layers $s$ it passes through (as opposed to bypassing them via the skip connection). We write the gradient along a path $s$ as $\frac{\partial\mathcal{L}_s}{\partial n}$.

Let $\mathcal{S}$ be the set of all subsets of the layers after the one containing $n$. For example, the direct-effect path corresponds to $\emptyset\in\mathcal{S}$. Then the total gradient can be expressed as

$$\frac{\partial\mathcal{L}}{\partial n} = \sum_{s\in\mathcal{S}}\frac{\partial\mathcal{L}_{s}}{\partial n}. \tag{18}$$

We can analogously define $\hat{\mathcal{I}}_{\text{AtP}}^{s}(n) = \left(n(x^{\text{noise}})-n(x^{\text{clean}})\right)^{\intercal}\frac{\partial\mathcal{L}_{s}}{\partial n}$, and decompose $\hat{\mathcal{I}}_{\text{AtP}}(n) = \sum_{s\in\mathcal{S}}\hat{\mathcal{I}}_{\text{AtP}}^{s}(n)$. The effect of applying GradDrop at some layer $\ell$ is then to drop all terms $\hat{\mathcal{I}}_{\text{AtP}}^{s}(n)$ with $\ell\in s$; in other words,

$$\hat{\mathcal{I}}_{\text{AtP+GD}_{\ell}}(n) = \sum_{\substack{s\in\mathcal{S}\\ \ell\notin s}}\hat{\mathcal{I}}_{\text{AtP}}^{s}(n). \tag{21}$$
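As a sanity check on Eqs. (18) and (21), the following sketch uses a hypothetical toy model (our assumption, not a model from the paper): the node $n$ is a scalar that passes through downstream residual layers of the form $h \leftarrow h + w_\ell h$ with scalar weights $w_\ell$, followed by a linear metric whose gradient is $g$. In this linear model the path $s$ carries gradient $g\prod_{\ell\in s} w_\ell$, and GradDrop at layer $\ell$ amounts to replacing that layer's backward factor $(1+w_\ell)$ by 1. The last lines show a case where the direct and indirect paths cancel exactly, so the AtP estimate is 0 while the GradDrop estimate is not. Note that in this linear toy the true effect is also 0; the cancellation false negatives discussed in the text arise from nonlinearity, which the sketch does not model.

```python
from itertools import combinations
from math import prod


def path_gradients(weights, g):
    """Gradient carried by each path, keyed by the tuple of downstream layers s
    the path passes through (the empty tuple is the direct path)."""
    layers = range(len(weights))
    return {
        s: g * prod(weights[l] for l in s)
        for r in range(len(weights) + 1)
        for s in combinations(layers, r)
    }


def total_gradient(weights, g):
    # Each residual layer h <- h + w*h multiplies the backward signal by (1 + w).
    return g * prod(1.0 + w for w in weights)


def graddrop_gradient(weights, g, dropped):
    # GradDrop at layer `dropped`: the gradient only flows through that
    # layer's skip connection, i.e. its factor (1 + w) becomes 1.
    return g * prod(1.0 + w for l, w in enumerate(weights) if l != dropped)


weights, g, delta_n = [0.5, -0.2, 0.3], 2.0, 0.1
paths = path_gradients(weights, g)

# Eq. (18): the total gradient is the sum over all paths.
assert abs(sum(paths.values()) - total_gradient(weights, g)) < 1e-12
# Eq. (21): GradDrop at layer l keeps exactly the paths with l not in s.
for l in range(len(weights)):
    kept = sum(v for s, v in paths.items() if l not in s)
    assert abs(kept - graddrop_gradient(weights, g, l)) < 1e-12

# Cancellation: one downstream layer with w = -1 exactly cancels the direct path.
atp = delta_n * total_gradient([-1.0], g)           # 0.0
atp_gd = delta_n * graddrop_gradient([-1.0], g, 0)  # 0.2
```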

We now use this decomposition to discuss the three desiderata.

Firstly, most node effects are approximately independent of most layers (see e.g. Veit et al. (2016)); for any layer $\ell$ that $n$'s effect is independent of, we have $\hat{\mathcal{I}}_{\text{AtP+GD}_{\ell}}(n) = \hat{\mathcal{I}}_{\text{AtP}}(n)$. Letting $K$ be the set of downstream layers that matter, this guarantees $\frac{1}{L-1}\sum_{\ell=1}^{L}\left|\hat{\mathcal{I}}_{\text{AtP+GD}_{\ell}}(n;x^{\text{clean}},x^{\text{noise}})\right| \geq \frac{L-|K|-1}{L-1}\left|\hat{\mathcal{I}}_{\text{AtP}}(n;x^{\text{clean}},x^{\text{noise}})\right|$, which meets the first desideratum.

Regarding the second desideratum: for each $\ell$ we have $\left|\hat{\mathcal{I}}_{\text{AtP+GD}_{\ell}}(n)\right| \leq \sum_{s\in\mathcal{S}}\left|\hat{\mathcal{I}}_{\text{AtP}}^{s}(n)\right|$, so overall $\frac{1}{L-1}\sum_{\ell=1}^{L}\left|\hat{\mathcal{I}}_{\text{AtP+GD}_{\ell}}(n)\right| \leq \frac{L-|K|-1}{L-1}\left|\hat{\mathcal{I}}_{\text{AtP}}(n)\right| + \frac{|K|}{L-1}\sum_{s\in\mathcal{S}}\left|\hat{\mathcal{I}}_{\text{AtP}}^{s}(n)\right|$. For the RHS to be much larger (e.g. $\alpha$ times larger) than $\left|\sum_{s\in\mathcal{S}}\hat{\mathcal{I}}_{\text{AtP}}^{s}(n)\right| = \left|\hat{\mathcal{I}}_{\text{AtP}}(n)\right|$, there must be a lot of cancellation between different paths, enough that, roughly, $\sum_{s\in\mathcal{S}}\left|\hat{\mathcal{I}}_{\text{AtP}}^{s}(n)\right| \geq \frac{(L-1)\alpha}{|K|}\left|\sum_{s\in\mathcal{S}}\hat{\mathcal{I}}_{\text{AtP}}^{s}(n)\right|$. This is possible, but seems generally unlikely for e.g. $\alpha > 3$.

Now consider the third desideratum: suppose $n$ is a cancellation false negative, with $\left|\hat{\mathcal{I}}_{\text{AtP}}(n)\right| \ll \left|\mathcal{I}(n)\right| \ll \left|\mathcal{I}^{\text{direct}}(n)\right| \approx \left|\hat{\mathcal{I}}_{\text{AtP}}^{\text{direct}}(n)\right|$, where $\hat{\mathcal{I}}_{\text{AtP}}^{\text{direct}}(n) = \hat{\mathcal{I}}_{\text{AtP}}^{\emptyset}(n)$ is the direct-path term. Then $\left|\sum_{s\in\mathcal{S}\setminus\{\emptyset\}}\hat{\mathcal{I}}_{\text{AtP}}^{s}(n)\right| = \left|\hat{\mathcal{I}}_{\text{AtP}}(n) - \hat{\mathcal{I}}_{\text{AtP}}^{\text{direct}}(n)\right| \gg \left|\mathcal{I}(n)\right|$.
The summands in $\sum_{s\in\mathcal{S}\setminus\{\emptyset\}}\hat{\mathcal{I}}_{\text{AtP}}^{s}(n)$ are the union, across layers $\ell$, of the summands in $\sum_{s\in\mathcal{S},\,\ell\in s}\hat{\mathcal{I}}_{\text{AtP}}^{s}(n) = \hat{\mathcal{I}}_{\text{AtP}}(n) - \hat{\mathcal{I}}_{\text{AtP+GD}_{\ell}}(n)$.

It is then possible, but intuitively unlikely, that $\sum_{\ell}\left|\hat{\mathcal{I}}_{\text{AtP}}(n) - \hat{\mathcal{I}}_{\text{AtP+GD}_{\ell}}(n)\right|$ would be much smaller than $\left|\hat{\mathcal{I}}_{\text{AtP}}(n) - \hat{\mathcal{I}}_{\text{AtP}}^{\text{direct}}(n)\right|$. Suppose the ratio is $\alpha$, i.e. suppose $\sum_{\ell}\left|\hat{\mathcal{I}}_{\text{AtP}}(n) - \hat{\mathcal{I}}_{\text{AtP+GD}_{\ell}}(n)\right| = \alpha\left|\hat{\mathcal{I}}_{\text{AtP}}(n) - \hat{\mathcal{I}}_{\text{AtP}}^{\text{direct}}(n)\right|$.
For example, if all indirect effects use paths of length 1, then the union is a disjoint union, so $\sum_{\ell}\left|\hat{\mathcal{I}}_{\text{AtP}}(n) - \hat{\mathcal{I}}_{\text{AtP+GD}_{\ell}}(n)\right| \geq \left|\sum_{\ell}\left(\hat{\mathcal{I}}_{\text{AtP}}(n) - \hat{\mathcal{I}}_{\text{AtP+GD}_{\ell}}(n)\right)\right| = \left|\hat{\mathcal{I}}_{\text{AtP}}(n) - \hat{\mathcal{I}}_{\text{AtP}}^{\text{direct}}(n)\right|$, and hence $\alpha \geq 1$. Now:

\begin{align}
\sum_{\ell\in K}\left|\hat{\mathcal{I}}_{\text{AtP+GD}_{\ell}}(n)\right| &\geq \sum_{\ell\in K}\left|\hat{\mathcal{I}}_{\text{AtP}}(n) - \hat{\mathcal{I}}_{\text{AtP+GD}_{\ell}}(n)\right| - |K|\left|\hat{\mathcal{I}}_{\text{AtP}}(n)\right| \tag{22}\\
&= \alpha\left|\hat{\mathcal{I}}_{\text{AtP}}(n) - \hat{\mathcal{I}}_{\text{AtP}}^{\text{direct}}(n)\right| - |K|\left|\hat{\mathcal{I}}_{\text{AtP}}(n)\right| \tag{23}\\
&\geq \alpha\left|\hat{\mathcal{I}}_{\text{AtP}}^{\text{direct}}(n)\right| - (|K|+\alpha)\left|\hat{\mathcal{I}}_{\text{AtP}}(n)\right| \tag{24}\\
\therefore\ \frac{1}{L-1}\sum_{\ell=1}^{L}\left|\hat{\mathcal{I}}_{\text{AtP+GD}_{\ell}}(n)\right| &= \frac{1}{L-1}\sum_{\ell\in K}\left|\hat{\mathcal{I}}_{\text{AtP+GD}_{\ell}}(n)\right| + \frac{L-|K|-1}{L-1}\left|\hat{\mathcal{I}}_{\text{AtP}}(n)\right| \tag{25}\\
&\geq \frac{\alpha}{L-1}\left|\hat{\mathcal{I}}_{\text{AtP}}^{\text{direct}}(n)\right| + \frac{L-2|K|-1-\alpha}{L-1}\left|\hat{\mathcal{I}}_{\text{AtP}}(n)\right| \tag{26}
\end{align}

The RHS is an improvement over $\left|\hat{\mathcal{I}}_{\text{AtP}}(n)\right|$ as long as $\alpha\left|\hat{\mathcal{I}}_{\text{AtP}}^{\text{direct}}(n)\right| > (2|K|+\alpha)\left|\hat{\mathcal{I}}_{\text{AtP}}(n)\right|$, which is likely given the assumptions, since $\left|\hat{\mathcal{I}}_{\text{AtP}}(n)\right| \ll \left|\hat{\mathcal{I}}_{\text{AtP}}^{\text{direct}}(n)\right|$.
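To make this condition concrete, the sketch below evaluates the right-hand side of Eq. (26) and the improvement condition for illustrative, made-up magnitudes (not measurements from our experiments): $L=24$, $|K|=3$, $\alpha=1$, $|\hat{\mathcal{I}}^{\text{direct}}_{\text{AtP}}(n)|=1$ and $|\hat{\mathcal{I}}_{\text{AtP}}(n)|=0.05$. The function names are hypothetical.

```python
def graddrop_lower_bound(alpha, direct, atp, k, num_layers):
    """Right-hand side of Eq. (26): a lower bound on
    (1/(L-1)) * sum_l |I_AtP+GD_l(n)|, given |I_AtP^direct(n)| = direct,
    |I_AtP(n)| = atp and |K| = k."""
    L = num_layers
    return alpha / (L - 1) * direct + (L - 2 * k - 1 - alpha) / (L - 1) * atp


def improves_on_atp(alpha, direct, atp, k):
    # Condition from the text: alpha * |I_direct| > (2|K| + alpha) * |I_AtP|.
    return alpha * direct > (2 * k + alpha) * atp


bound = graddrop_lower_bound(alpha=1.0, direct=1.0, atp=0.05, k=3, num_layers=24)
print(round(bound, 4), improves_on_atp(1.0, 1.0, 0.05, 3))  # 0.0783 True
```

Here the bound of about 0.078 exceeds the plain AtP magnitude of 0.05, matching the condition; with $|\hat{\mathcal{I}}_{\text{AtP}}(n)| = 0.2$ instead, both the bound check and the condition fail.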

Ultimately, though, the desiderata are validated empirically: our experiments consistently show GradDrop either decreasing or leaving unchanged the number of false negatives, and thus improving performance, apart from the upfront cost of the extra backward passes.

A.3 Algorithm for computing diagnostics

Given summary statistics $\bar{i}_{\pm}$, $s_{\pm}$ and $\text{count}_{\pm}$ for every node $n$, obtained from Algorithm 1, and a threshold $\theta>0$, we can use Welch's $t$-test (Welch, 1947) to test the hypothesis that $|\bar{i}_{+}-\bar{i}_{-}|\geq\theta$. Concretely, we compute the $t$-statistic via

$$s_{\bar{i}_{\pm}} = \frac{s_{\pm}}{\sqrt{\text{count}_{\pm}}} \tag{28}$$
$$t = \frac{\theta-|\bar{i}_{+}-\bar{i}_{-}|}{\sqrt{s_{\bar{i}_{+}}^{2}+s_{\bar{i}_{-}}^{2}}}. \tag{29}$$

The effective number of degrees of freedom $\nu$ can be approximated with the Welch–Satterthwaite equation

$$\nu_{\text{Welch}} = \frac{\left(\frac{s_{+}^{2}}{\text{count}_{+}}+\frac{s_{-}^{2}}{\text{count}_{-}}\right)^{2}}{\frac{s_{+}^{4}}{\text{count}_{+}^{2}(\text{count}_{+}-1)}+\frac{s_{-}^{4}}{\text{count}_{-}^{2}(\text{count}_{-}-1)}} \tag{30}$$

We then compute the probability ($p$-value) of obtaining a $t$ at least as large as the one observed, using the cumulative distribution function of Student's $t$-distribution with $\nu_{\text{Welch}}$ degrees of freedom. We take the maximum of the individual $p$-values over all nodes to obtain an aggregate upper bound. Finally, since the aggregate $p$-value decreases as $\theta$ grows, we use binary search to find the smallest threshold $\theta$ whose aggregate $p$-value is below a given target $p_{\text{target}}$. We show multiple such diagnostic curves in Section B.3 for different confidence levels $1-p_{\text{target}}$.
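The procedure above can be sketched in pure Python as follows. This is a minimal illustration under stated assumptions, not our implementation: the per-node tuple layout `(mean_pos, s_pos, count_pos, mean_neg, s_neg, count_neg)` and all function names are hypothetical, we assume $s_\pm > 0$ and $\text{count}_\pm \geq 2$, and the Student's $t$ tail is computed from the regularized incomplete beta function (a library routine such as SciPy's `scipy.stats.t.sf` could be swapped in). Because a larger $\theta$ makes $t$ larger and the $p$-value smaller, the search returns the smallest $\theta$ at which the aggregate $p$-value falls below the target.

```python
import math

_TINY = 1e-300  # guards against division by zero in Lentz's method


def _beta_cf(a, b, x, max_iter=500, eps=3e-14):
    """Continued fraction for the regularized incomplete beta function,
    evaluated with the modified Lentz method (cf. Numerical Recipes' betacf)."""
    qab, qap, qam = a + b, a + 1.0, a - 1.0
    c, d = 1.0, 1.0 - qab * x / qap
    d = 1.0 / (d if abs(d) > _TINY else _TINY)
    h = d
    for m in range(1, max_iter + 1):
        m2 = 2 * m
        even = m * (b - m) * x / ((qam + m2) * (a + m2))
        odd = -(a + m) * (qab + m) * x / ((a + m2) * (qap + m2))
        for aa in (even, odd):
            d = 1.0 + aa * d
            d = 1.0 / (d if abs(d) > _TINY else _TINY)
            c = 1.0 + aa / c
            c = c if abs(c) > _TINY else _TINY
            h *= d * c
        if abs(d * c - 1.0) < eps:
            break
    return h


def _reg_inc_beta(a, b, x):
    """Regularized incomplete beta function I_x(a, b)."""
    if x <= 0.0:
        return 0.0
    if x >= 1.0:
        return 1.0
    bt = math.exp(math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b)
                  + a * math.log(x) + b * math.log1p(-x))
    if x < (a + 1.0) / (a + b + 2.0):
        return bt * _beta_cf(a, b, x) / a
    return 1.0 - bt * _beta_cf(b, a, 1.0 - x) / b


def student_t_sf(t, nu):
    """P(T >= t) for Student's t-distribution with nu degrees of freedom."""
    tail = 0.5 * _reg_inc_beta(nu / 2.0, 0.5, nu / (nu + t * t))
    return tail if t >= 0 else 1.0 - tail


def welch_p_value(theta, mean_pos, s_pos, n_pos, mean_neg, s_neg, n_neg):
    """Per-node p-value following Eqs. (28)-(30)."""
    var_pos = s_pos ** 2 / n_pos  # s_{i+}^2, i.e. Eq. (28) squared
    var_neg = s_neg ** 2 / n_neg  # s_{i-}^2
    t = (theta - abs(mean_pos - mean_neg)) / math.sqrt(var_pos + var_neg)  # Eq. (29)
    nu = (var_pos + var_neg) ** 2 / (
        var_pos ** 2 / (n_pos - 1) + var_neg ** 2 / (n_neg - 1))  # Eq. (30)
    return student_t_sf(t, nu)


def aggregate_p_value(theta, node_stats):
    # Maximum of the per-node p-values.
    return max(welch_p_value(theta, *stats) for stats in node_stats)


def smallest_threshold(node_stats, p_target, iters=60):
    """Binary search for the smallest theta with aggregate p-value < p_target."""
    lo, hi = 0.0, 1.0
    while aggregate_p_value(hi, node_stats) >= p_target:
        hi *= 2.0  # p-values shrink as theta grows, so this terminates
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if aggregate_p_value(mid, node_stats) < p_target:
            hi = mid
        else:
            lo = mid
    return hi


# Hypothetical summary statistics for two nodes.
node_stats = [(0.0, 1.0, 100, 0.0, 1.0, 100), (0.1, 1.0, 100, 0.05, 1.0, 100)]
print(smallest_threshold(node_stats, p_target=0.05))
```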

Appendix B Experiments

B.1 Prompt Distributions

B.1.1 IOI

We use the following prompt template:

⟨BOS⟩When␣[A]␣and␣[B]␣went␣to␣the␣bar,␣[A/C]␣gave␣a␣drink␣to␣[B/A]

Each clean prompt $x^{\text{clean}}$ uses two names A and B with completion B, while a noise prompt $x^{\text{noise}}$ uses names A, B, and C with completion A. We construct all possible such assignments, with names chosen from the set {Michael, Jessica, Ashley, Joshua, David, Sarah}, resulting in 120 prompt pairs.
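A minimal sketch of how the 120 IOI pairs can be enumerated. It assumes (our reading, not stated explicitly above) that each noise prompt keeps its clean partner's A and B and draws C from the four remaining names, which gives 6 × 5 × 4 = 120 pairs; it also assumes the BOS token is prepended by the tokenizer and that completions carry a leading space. The names `TEMPLATE` and `ioi_prompt_pairs` are illustrative.

```python
from itertools import permutations

NAMES = ["Michael", "Jessica", "Ashley", "Joshua", "David", "Sarah"]
# BOS is assumed to be prepended by the tokenizer; the completion is excluded.
TEMPLATE = "When {a} and {b} went to the bar, {giver} gave a drink to"


def ioi_prompt_pairs(names=NAMES):
    """Return ((clean_prompt, clean_completion), (noise_prompt, noise_completion))."""
    pairs = []
    for a, b, c in permutations(names, 3):  # ordered, distinct A, B, C
        clean = (TEMPLATE.format(a=a, b=b, giver=a), " " + b)
        noise = (TEMPLATE.format(a=a, b=b, giver=c), " " + a)
        pairs.append((clean, noise))
    return pairs


pairs = ioi_prompt_pairs()
print(len(pairs))  # 120
```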

B.1.2 A-AN

We use the following prompt template to induce the prediction of an indefinite article.

⟨BOS⟩I␣want␣one␣pear.␣Can␣you␣pick␣up␣a␣pear␣for␣me?
␣I␣want␣one␣orange.␣Can␣you␣pick␣up␣an␣orange␣for␣me?
␣I␣want␣one␣[OBJECT].␣Can␣you␣pick␣up␣[a/an]

We found that the zero-shot performance of small models was relatively low, but it improved drastically when we provided a single example of each case. Model performance was sensitive to the ordering of the two examples but was better than random in all cases. The magnitude and sign of the impact of the few-shot ordering were inconsistent.

Clean prompts $x^{\text{clean}}$ contain objects inducing ‘␣a’, one of {boat, coat, drum, horn, map, pipe, screw, stamp, tent, wall}. Noise prompts $x^{\text{noise}}$ contain objects inducing ‘␣an’, one of {apple, ant, axe, award, elephant, egg, orange, oven, onion, umbrella}. This results in a total of 100 prompt pairs.
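A minimal sketch of the A-AN pair construction, assuming every clean object is paired with every noise object (consistent with the stated total of 10 × 10 = 100 pairs), that the three template lines are joined into one string (the ␣ at the start of the second and third lines marks the separating space), and that BOS is prepended by the tokenizer. The names `FEW_SHOT` and `aan_prompt` are illustrative.

```python
from itertools import product

A_OBJECTS = ["boat", "coat", "drum", "horn", "map",
             "pipe", "screw", "stamp", "tent", "wall"]
AN_OBJECTS = ["apple", "ant", "axe", "award", "elephant",
              "egg", "orange", "oven", "onion", "umbrella"]
# One demonstration of each case; BOS is assumed to be added by the tokenizer.
FEW_SHOT = ("I want one pear. Can you pick up a pear for me?"
            " I want one orange. Can you pick up an orange for me?")


def aan_prompt(obj):
    return f"{FEW_SHOT} I want one {obj}. Can you pick up"


# Each pair: ((clean_prompt, " a"), (noise_prompt, " an")).
pairs = [((aan_prompt(a), " a"), (aan_prompt(an), " an"))
         for a, an in product(A_OBJECTS, AN_OBJECTS)]
print(len(pairs))  # 100
```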

B.2 Cancellation across a distribution

As mentioned in Section 2, we average the magnitudes of effects across a distribution rather than taking the magnitude of the average effect. We do this because effects frequently cancel across a distribution, which, together with imprecise estimates, could lead to significant false negatives. A proper ablation study to quantify this effect exactly is beyond the scope of this work. In Figure 10, we show the degree of cancellation across the IOI distribution for various model sizes. To do so, we define the Cancellation Ratio of node $n$ as

$$1-\frac{\left|\sum_{x^{\text{clean}},x^{\text{noise}}}\mathcal{I}(n;x^{\text{clean}},x^{\text{noise}})\right|}{\sum_{x^{\text{clean}},x^{\text{noise}}}\left|\mathcal{I}(n;x^{\text{clean}},x^{\text{noise}})\right|}.$$
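A minimal sketch of the Cancellation Ratio for a single node, assuming `effects` holds the per-prompt-pair effects $\mathcal{I}(n;x^{\text{clean}},x^{\text{noise}})$; the function name is illustrative, and we choose to raise an error when all effects are zero, where the ratio is undefined.

```python
def cancellation_ratio(effects):
    """1 - |sum of effects| / sum of |effects| over prompt pairs, for one node."""
    total_abs = sum(abs(e) for e in effects)
    if total_abs == 0:
        raise ValueError("ratio is undefined when every effect is zero")
    return 1.0 - abs(sum(effects)) / total_abs


print(cancellation_ratio([0.5, -0.5]))  # 1.0: effects cancel completely
print(cancellation_ratio([0.2, 0.3]))   # 0.0: all effects share a sign
print(cancellation_ratio([3.0, -1.0]))  # 0.5
```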
[Figure 10 panels: (a) Pythia-410M, (b) Pythia-1B, (c) Pythia-2.8B, (d) Pythia-12B]
Figure 10: Cancellation ratio across IOI for various model sizes. A ratio of 1 means that positive and negative effects cancel out completely across the distribution, whereas a ratio of 0 means that all effects across the distribution have the same sign. We report the cancellation ratio for different percentiles of nodes, ranked by $\sum_{x^{\text{clean}},x^{\text{noise}}}\left|\mathcal{I}(n;x^{\text{clean}},x^{\text{noise}})\right|$.

B.3 Additional detailed results

We show the diagnostic measurements for Pythia-12B across all investigated distributions in Figure 11, and the cost of verified 100% recall curves for all models and settings in Figures 12 and 13.

Figure 11: Diagnostic of false negatives for Pythia-12B across distributions.
[Figure 11 panels: (a) AttentionNodes: (a.i) IOI-PP, (a.ii) RAND-PP, (a.iii) IOI; (b) NeuronNodes: (b.i) CITY-PP, (b.ii) RAND-PP, (b.iii) A-AN]
Figure 12: Cost of verified 100% recall curves, sweeping across models and settings for NeuronNodes.
[Figure 12 panels: (a) CITY-PP, (b) RAND-PP, (c) A-AN distribution; within each setting, (i) Pythia 410M, (ii) Pythia 1B, (iii) Pythia 2.8B, (iv) Pythia 12B]
Figure 13: Cost of verified 100% recall curves, sweeping across models and settings for AttentionNodes.
[Figure 13 panels: (a) IOI-PP, (b) RAND-PP, (c) IOI distribution; within each setting, (i) Pythia 410M, (ii) Pythia 1B, (iii) Pythia 2.8B, (iv) Pythia 12B]

B.4 Metrics

In this paper we focus on the difference in loss (negative log probability) as the metric $\mathcal{L}$. We provide some evidence that AtP and AtP* are not sensitive to the choice of $\mathcal{L}$: for Pythia-12B, on IOI-PP and IOI, Figure 14 shows rank scatter plots for three different metrics.

For IOI, we also show that the performance of AtP* looks notably worse when effects are evaluated via denoising instead of noising (cf. Section 2.1). We do not yet have a satisfactory explanation for this observation.

Figure 14: True ranks against AtP* ranks on Pythia-12B using various metrics $\mathcal{L}$. The last row shows the effect in the denoising (rather than noising) setting; we speculate that the lower-right subplot (log-odds denoising) is similar to the lower-middle one (logit-diff denoising) because IOI produces a bimodal distribution over the correct and alternate next token.

B.5 Hyperparameter selection

The iterative baseline and the AtP-based methods have no hyperparameters. For the remaining methods, we generally used 5 random seeds for each hyperparameter setting and selected the setting that produced the lowest IRWRGM cost (see Section 4.2).

For Subsampling, the two hyperparameters are the Bernoulli sampling probability $p$ and the number of samples to collect before verifying nodes in decreasing order of $\hat{c}_{\text{SS}}$. $p$ was chosen from $\{0.01, 0.03\}$ (footnote 14: we observed early on that larger values of $p$ consistently underperformed; we leave the investigation of more granular and smaller values of $p$ to future work). The number of steps was chosen from power-of-2 numbers of batches, with the batch size depending on the setting.

For Blocks, we swept across block sizes 2, 6, 20, 60, and 250. For Hierarchical, we used a branching factor of $B = 3$, because of the following heuristic argument. If all but one node had zero effect, then discovering that node would be a matter of iterating through the hierarchy levels. There would be $\log_B |N|$ levels, and at each level $B$ forward passes would be required to find which lower-level block the special node is in; thus the cost of finding the node would be $B \log_B |N| = \frac{B}{\log B} \log |N|$. The factor $\frac{B}{\log B}$ is minimized at $B = e$, or at $B = 3$ if $B$ must be an integer. The other hyperparameter is the number of levels; we swept this from 2 to 12.
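The heuristic can be checked numerically. The following is a minimal sketch of our own, not code from the paper: `cost_factor` evaluates $B/\ln B$ for integer branching factors, and `hierarchical_passes` simulates the idealized case above, in which exactly one node has a nonzero effect and $B$ forward passes are charged per level. Both function names and the single-special-node oracle are hypothetical.

```python
import math

def cost_factor(b: int) -> float:
    """B / ln(B): the factor multiplying log|N| in the heuristic cost B * log_B |N|."""
    return b / math.log(b)

def hierarchical_passes(num_nodes: int, special: int, branching: int) -> int:
    """Count forward passes for an idealized hierarchical search.

    Hypothetical setting from the argument above: only node `special` has a
    nonzero effect, each level splits the current block into `branching`
    sub-blocks, and one patched forward pass is charged per sub-block.
    """
    assert 0 <= special < num_nodes and branching >= 2
    lo, hi = 0, num_nodes                  # current block is the half-open range [lo, hi)
    passes = 0
    while hi - lo > 1:
        size = math.ceil((hi - lo) / branching)
        next_block = None
        for start in range(lo, hi, size):  # one patched forward pass per sub-block
            passes += 1
            end = min(start + size, hi)
            if start <= special < end:     # stands in for "this block has an effect"
                next_block = (start, end)
        lo, hi = next_block
    return passes

# Among integer branching factors, B / ln B is smallest at B = 3.
best_b = min(range(2, 11), key=cost_factor)
print(best_b, hierarchical_passes(3**6, special=100, branching=3))  # → 3 18
```

With $|N| = 3^6$ the search takes 6 levels of 3 passes each; note that $B = 2$ and $B = 4$ tie exactly, since $4/\ln 4 = 2/\ln 2$.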

Appendix C AtP variants

C.1 Residual-site AtP and Layer normalization

Let’s consider the behaviour of AtP on sites that contain much or all of the total signal in the residual stream, such as residual-stream sites. Nanda (2022) described a concern about this behaviour: that the linear approximation of the layer normalization would do poorly if the patched value is significantly different from the clean one but has a similar norm. The proposed modification to AtP to account for this was to hold the scaling factors (in the denominators) fixed when computing the backward pass. Here we’ll present an analysis of how this modification would affect the approximation error of AtP. (Empirical investigation of this issue is beyond the scope of this paper.)

Concretely, let the node under consideration be $n$, with clean and alternate values $n^{\mathrm{clean}}$ and $n^{\mathrm{noise}}$; and for simplicity, let’s assume the model does nothing more than an unparametrized RMSNorm $\mathcal{M}(n) := n/|n|$. Let’s now consider how well $\mathcal{M}(n^{\mathrm{noise}})$ is approximated, both by its first-order approximation $\hat{\mathcal{M}}_{\text{AtP}}(n^{\mathrm{noise}}) := \mathcal{M}(n^{\mathrm{clean}}) + \frac{1}{|n^{\mathrm{clean}}|}\mathcal{M}(n^{\mathrm{clean}})^{\perp}(n^{\mathrm{noise}} - n^{\mathrm{clean}})$, where $\mathcal{M}(n^{\mathrm{clean}})^{\perp} = I - \mathcal{M}(n^{\mathrm{clean}})\mathcal{M}(n^{\mathrm{clean}})^{\intercal}$ is the projection onto the hyperplane orthogonal to $\mathcal{M}(n^{\mathrm{clean}})$, and by the variant that fixes the denominator: $\hat{\mathcal{M}}_{\text{AtP+frozenLN}}(n^{\mathrm{noise}}) := n^{\mathrm{noise}}/|n^{\mathrm{clean}}|$.
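For concreteness, here is a minimal pure-Python sketch of the three maps in this toy setting (our illustration; vectors are plain lists and the helper names are hypothetical). The first-order map applies the Jacobian of $n \mapsto n/|n|$ at $n^{\mathrm{clean}}$, which is $(I - \mathcal{M}(n^{\mathrm{clean}})\mathcal{M}(n^{\mathrm{clean}})^{\intercal})/|n^{\mathrm{clean}}|$; the $1/|n^{\mathrm{clean}}|$ factor becomes 1 under the normalization assumed below.

```python
import math

def norm(v):
    return math.sqrt(sum(a * a for a in v))

def rms_norm(v):
    """The toy model M(n) := n / |n| from the text."""
    s = norm(v)
    return [a / s for a in v]

def atp_approx(n_noise, n_clean):
    """First-order (AtP) approximation of M(n_noise) around n_clean."""
    u = rms_norm(n_clean)
    delta = [a - b for a, b in zip(n_noise, n_clean)]
    along = sum(a * b for a, b in zip(u, delta))  # component of delta along u
    s = norm(n_clean)
    # M(n_clean) + (I - u u^T) delta / |n_clean|
    return [ui + (d - along * ui) / s for ui, d in zip(u, delta)]

def frozen_ln_approx(n_noise, n_clean):
    """AtP with the layer-norm denominator frozen at its clean value."""
    s = norm(n_clean)
    return [a / s for a in n_noise]

# A same-norm but orthogonal patch: the frozen variant is exact here,
# while the first-order map lands on the tangent plane at (1, 1, 0).
n_clean, n_noise = [2.0, 0.0, 0.0], [0.0, 2.0, 0.0]
print(rms_norm(n_noise), atp_approx(n_noise, n_clean), frozen_ln_approx(n_noise, n_clean))
```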

To quantify the error in the above, we’ll measure the error $\epsilon$ in terms of Euclidean distance. Let’s also assume, without loss of generality, that $|n^{\mathrm{clean}}| = 1$. Geometrically, then, $\mathcal{M}(n)$ is a projection onto the unit hypersphere, $\hat{\mathcal{M}}_{\text{AtP}}(n)$ is a projection onto the tangent hyperplane at $n^{\mathrm{clean}}$, and $\hat{\mathcal{M}}_{\text{AtP+frozenLN}}$ is the identity function.

Now, let’s define orthogonal coordinates $(x, y)$ on the plane spanned by $n^{\mathrm{clean}}, n^{\mathrm{noise}}$, such that $n^{\mathrm{clean}}$ is mapped to $(1, 0)$ and $n^{\mathrm{noise}}$ is mapped to $(x, y)$, with $y \geq 0$. Then $\epsilon_{\text{AtP}} := \left|\hat{\mathcal{M}}_{\text{AtP}}(n^{\mathrm{noise}}) - \mathcal{M}(n^{\mathrm{noise}})\right| = \sqrt{2 + y^2 - 2\frac{x + y^2}{\sqrt{x^2 + y^2}}}$, while $\epsilon_{\text{AtP+frozenLN}} := \left|\hat{\mathcal{M}}_{\text{AtP+frozenLN}}(n^{\mathrm{noise}}) - \mathcal{M}(n^{\mathrm{noise}})\right| = \left|\sqrt{x^2 + y^2} - 1\right|$.
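A quick numerical check of the two closed forms (our sketch, not code from the paper). In the $(x, y)$ coordinates, $\hat{\mathcal{M}}_{\text{AtP}}(n^{\mathrm{noise}})$ is the point $(1, y)$ on the tangent line $x = 1$, $\hat{\mathcal{M}}_{\text{AtP+frozenLN}}(n^{\mathrm{noise}})$ is $(x, y)$ itself, and $\mathcal{M}(n^{\mathrm{noise}}) = (x, y)/\sqrt{x^2 + y^2}$; the direct Euclidean distances should match the formulas above.

```python
import math

def eps_atp_closed(x, y):
    r = math.hypot(x, y)
    # max(...) guards against tiny negative values from floating-point rounding
    return math.sqrt(max(0.0, 2 + y**2 - 2 * (x + y**2) / r))

def eps_frozen_closed(x, y):
    return abs(math.hypot(x, y) - 1)

def eps_direct(x, y):
    """(eps_AtP, eps_AtP+frozenLN) computed geometrically, assuming |n_clean| = 1."""
    r = math.hypot(x, y)
    true_point = (x / r, y / r)  # M(n_noise)
    atp_point = (1.0, y)         # projection onto the tangent line at (1, 0)
    frozen_point = (x, y)        # n_noise / |n_clean|
    return math.dist(atp_point, true_point), math.dist(frozen_point, true_point)

for x, y in [(0.9, 0.1), (0.0, 1.0), (-1.0, 0.5), (1.5, 2.0)]:
    atp, frozen = eps_direct(x, y)
    print(f"x={x:+.1f} y={y:.1f}  eps_AtP={atp:.4f}  eps_frozenLN={frozen:.4f}")
```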

Plotting the error in Figure 15, we can see that, as might be expected, freezing the layer norm denominators helps whenever $n^{\mathrm{noise}}$ indeed has the same norm as $n^{\mathrm{clean}}$, and (barring unusual cases with $x > 1$) whenever the cosine similarity is less than $\frac{1}{2}$; but it mostly hurts if $n^{\mathrm{noise}}$ is close to $n^{\mathrm{clean}}$. This illustrates that, while freezing the denominators will generally be unhelpful when patch distances are small relative to the full residual signal (as with almost all nodes considered in this paper), it will likely be helpful in a different setting: patching residual streams, which could be quite unaligned but have similar norm.
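The sign of the difference can also be worked out in closed form (our derivation, not taken from the source). Write $r = \sqrt{x^2 + y^2}$ for the norm of $n^{\mathrm{noise}}$ and $c = x/r$ for its cosine similarity with $n^{\mathrm{clean}}$, and substitute $x = rc$, $y^2 = r^2(1 - c^2)$ into the two error formulas:

```latex
\begin{aligned}
\epsilon_{\text{AtP}}^2 - \epsilon_{\text{AtP+frozenLN}}^2
  &= \bigl(2 + r^2 - r^2c^2 - 2c - 2r + 2rc^2\bigr) - \bigl(r^2 - 2r + 1\bigr) \\
  &= (1 - c)^2 - c^2 (r - 1)^2 \\
  &= (1 - x)(1 + x - 2c).
\end{aligned}
```

Since both errors are nonnegative, freezing the denominators strictly helps exactly when $(1 - x)(1 + x - 2c) > 0$, i.e. when $2c - 1 < x < 1$ (the opposite sign pattern would need $1 < x < 2c - 1 \le 1$, which is impossible). At equal norms ($r = 1$, so $x = c$) this holds for every $c < 1$, and for $0 \le c < \frac{1}{2}$ it reduces to $x < 1$. For negative cosine similarity it additionally requires $r < 2 + 1/|c|$, so for very large $|n^{\mathrm{noise}}|$ freezing can hurt even in that regime; this may lie outside the range plotted in Figure 15.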

Figure 15 panels: (a) $\epsilon_{\text{AtP}}$; (b) $\epsilon_{\text{AtP+frozenLN}}$; (c) $\epsilon_{\text{AtP+frozenLN}} - \epsilon_{\text{AtP}}$.
Figure 15: A comparison of how AtP and AtP with frozen layernorm scaling behave in a toy setting where the model we’re trying to approximate is just $\mathcal{M}(n) := n/|n|$. The red region is where frozen layernorm scaling helps; the blue region is where it hurts. We find that unless $x > 1$, frozen layernorm scaling always has lower error when the cosine similarity between $n^{\mathrm{noise}}$ and $n^{\mathrm{clean}}$ is $< \frac{1}{2}$ (in other words, the angle is $> 60^{\circ}$), but often has higher error otherwise.

C.2 Edge AtP and AtP*

Here we will investigate edge attribution patching and how its cost scales if we use GradDrop and/or the QK fix. (For this section we’ll focus on a single prompt pair.)

First, let’s review what edge attribution patching is trying to approximate, and how it works.

C.2.1 Edge intervention effects

Given nodes $n_1, n_2$ where $n_1$ is upstream of $n_2$, if we were to patch in an alternate value for $n_1$, this could impact $n_2$ in a complicated nonlinear way. As discussed in Section 3.1.2, because LLMs have a residual stream, the “direct effect” can be understood as the one holding all other possible intermediate nodes between $n_1$ and $n_2$ fixed, and it’s a relatively simple function: it transforms the alternate value $n_1(x^{\text{noise}})$ into a residual stream contribution $r_{\text{out},\ell_1}(x^{\text{clean}} \mid \operatorname{do}(n_1 \leftarrow n_1(x^{\text{noise}})))$, carries it along the residual stream to an input $r_{\text{in},\ell_2} = r_{\text{in},\ell_2}(x^{\text{clean}}) + (r_{\text{out},\ell_1} - r_{\text{out},\ell_1}(x^{\text{clean}}))$, and transforms that into a value $n_2^{\text{direct}}$.

In the above, $\ell_1$ and $\ell_2$ are the semilayers containing $n_1$ and $n_2$, respectively. Let’s define $\mathbf{n}_{(\ell_1,\ell_2)}$ to be the set of non-residual nodes between semilayers $\ell_1$ and $\ell_2$. Then we can define the resulting $n_2^{\text{direct}}$ as:

$$n_2^{\text{direct}^{\ell_1}}(x^{\text{clean}} \mid \operatorname{do}(n_1 \leftarrow n_1(x^{\text{noise}}))) := n_2\big(x^{\text{clean}} \mid \operatorname{do}(n_1 \leftarrow n_1(x^{\text{noise}})), \operatorname{do}(\mathbf{n}_{(\ell_1,\ell_2)} \leftarrow \mathbf{n}_{(\ell_1,\ell_2)}(x^{\text{clean}}))\big).$$

The residual-stream input $r_{\text{in},\ell_2}^{\text{direct}^{\ell_1}}(x^{\text{clean}} \mid \operatorname{do}(n_1 \leftarrow n_1(x^{\text{noise}})))$ is defined similarly.

Finally, $n_2$ itself isn’t enough to compute the metric $\mathcal{L}$: for that we also need to let the forward pass $\mathcal{M}(x^{\text{clean}})$ run using the modified $n_2^{\text{direct}^{\ell_1}}(x^{\text{clean}} \mid \operatorname{do}(n_1 \leftarrow n_1(x^{\text{noise}})))$, while removing all other effects of $n_1$ (i.e. not patching it).

Writing this out, we have the edge intervention effect

$$
\begin{aligned}
\mathcal{I}(n_1 \rightarrow n_2; x^{\text{clean}}, x^{\text{noise}})
  &:= \mathcal{L}\big(\mathcal{M}(x^{\text{clean}} \mid \operatorname{do}(n_2 \leftarrow n_2^{\text{direct}^{\ell_1}}(x^{\text{clean}} \mid \operatorname{do}(n_1 \leftarrow n_1(x^{\text{noise}})))))\big) \\
  &\phantom{:=}\ - \mathcal{L}(\mathcal{M}(x^{\text{clean}})).
\end{aligned}
\tag{31}
$$
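To make Eq. (31) concrete, the sketch below evaluates the exact edge intervention effect in a hypothetical three-node toy model (all weights, node names, and the logistic loss are invented for illustration; this is not the paper’s code). The key step is that only the change in $n_1$’s residual contribution is carried into $n_2$’s input while the intermediate node `m` stays at its clean value; the clean forward pass is then rerun with only $n_2$ patched.

```python
import math

# Hypothetical toy: scalar nodes n1 -> m -> n2 in successive layers, each reading
# from and writing to a 2-d residual stream. All weights are made up.
READ = {"n1": (1.0, 0.5), "m": (0.3, -1.0), "n2": (0.8, 0.8)}
WRITE = {"n1": (1.0, -0.5), "m": (0.2, 0.7), "n2": (-1.0, 1.5)}
ORDER = ("n1", "m", "n2")
UNEMBED = (0.5, 1.0)

def dot(a, b):
    return sum(p * q for p, q in zip(a, b))

def forward(x, patches=None):
    """Run the toy model; `patches` maps node name -> forced value (the do-operator).

    Returns (loss, node values, residual input seen by each node)."""
    patches = patches or {}
    resid = list(x)
    values, inputs = {}, {}
    for name in ORDER:
        inputs[name] = tuple(resid)
        value = patches[name] if name in patches else math.tanh(dot(READ[name], resid))
        values[name] = value
        resid = [r + value * w for r, w in zip(resid, WRITE[name])]
    logit = dot(UNEMBED, resid)
    loss = math.log1p(math.exp(-logit))  # -log sigmoid(logit)
    return loss, values, inputs

def edge_effect(x_clean, x_noise, up="n1", down="n2"):
    """Exact edge intervention effect I(up -> down) in the spirit of Eq. (31)."""
    loss_clean, clean_vals, clean_inputs = forward(x_clean)
    _, noise_vals, _ = forward(x_noise)
    # Carry only the change in `up`'s residual contribution into `down`'s input;
    # all intermediate nodes (here: m) stay at their clean values.
    delta = noise_vals[up] - clean_vals[up]
    r_in = [r + delta * w for r, w in zip(clean_inputs[down], WRITE[up])]
    down_direct = math.tanh(dot(READ[down], r_in))
    # Rerun the clean forward pass with only `down` patched.
    loss_patched, _, _ = forward(x_clean, {down: down_direct})
    return loss_patched - loss_clean

x_clean, x_noise = (1.0, 0.0), (-1.0, 0.5)
n1_noise = forward(x_noise)[1]["n1"]
total = forward(x_clean, {"n1": n1_noise})[0] - forward(x_clean)[0]
print(edge_effect(x_clean, x_noise), total)
```

In this toy the edge effect is much smaller than the total effect of patching `n1`, because the total effect also flows through the intermediate node `m`.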

C.2.2 Nodes and Edges

Let’s briefly consider which edges we’d want to evaluate this on. In Section 4.1, we were able to conveniently separate attention nodes from MLP neurons, knowing that to handle both kinds of nodes, we’d just need to be able to handle each kind of node on its own and then combine the results. For edge interventions this of course isn’t true, because edges can go from MLP neurons to attention nodes, and vice versa. For the purposes of this section, we’ll assume that the node set $N$ contains the attention nodes and, for MLPs, either a node per layer (as in Syed et al. (2023)) or a node per neuron (as in the NeuronNodes setting).

Regarding the edges, the MLP nodes can reasonably be connected with any upstream or downstream node, but this isn’t true for the attention nodes, which have more structure amongst themselves: the key, query, and value nodes for an attention head can only affect downstream nodes via the attention output nodes for that head, and vice versa. As a result, on edges between different semilayers, upstream attention nodes must be attention head outputs, and downstream attention nodes must be keys, queries, or values. In addition, there are some within-attention-head edges, connecting each query node to the output node in the same position, and each key and value node to the output nodes in causally affectable positions.
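The edge rules above can be summarized as a small predicate. This is a hypothetical sketch: the `Node` fields, the kind labels, and the same-position requirement for cross-semilayer edges (taken from the per-token-position dot products described later in this section) are our encoding, not an API from the paper.

```python
from typing import NamedTuple, Optional

class Node(NamedTuple):
    kind: str                   # "mlp", "q", "k", "v", or "out" (attention head output)
    semilayer: int
    pos: int
    head: Optional[int] = None  # attention nodes only

def is_edge(up: Node, down: Node) -> bool:
    """Whether `up -> down` is one of the edges described in the text."""
    if up.semilayer < down.semilayer:
        # Cross-semilayer edge through the residual stream, at a single position:
        # upstream attention nodes must be head outputs, downstream ones q/k/v.
        return (up.pos == down.pos
                and up.kind in ("mlp", "out")
                and down.kind in ("mlp", "q", "k", "v"))
    if up.semilayer == down.semilayer and up.head == down.head and down.kind == "out":
        if up.kind == "q":
            return up.pos == down.pos  # query -> output at the same position
        if up.kind in ("k", "v"):
            return up.pos <= down.pos  # key/value -> causally affectable outputs
    return False

print(is_edge(Node("out", 0, 2, 1), Node("q", 2, 2, 0)))  # head output -> later query
print(is_edge(Node("k", 0, 3, 3), Node("out", 0, 2, 3)))  # key after the output position
```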

C.2.3 Edge AtP

As with node activation patching, the edge intervention effect $\mathcal{I}(n_1 \rightarrow n_2; x^{\text{clean}}, x^{\text{noise}})$ is costly to evaluate directly for every edge, since a forward pass is required each time. However, as with AtP, we can apply first-order approximations: we define

\begin{align}
\hat{\mathcal{I}}_{\text{AtP}}(n_1 \rightarrow n_2; x^{\text{clean}}, x^{\text{noise}})
  &:= \left(\Delta r_{n_1}^{\text{AtP}}(x^{\text{clean}}, x^{\text{noise}})\right)^{\intercal} \nabla_{r_{n_2}}^{\text{AtP}} \mathcal{L}(\mathcal{M}(x^{\text{clean}})), \tag{32}\\
\text{where } \Delta r_{n_1}^{\text{AtP}}(x^{\text{clean}}, x^{\text{noise}})
  &:= \operatorname{Jac}_{n_1}(r_{\text{out},\ell_1})(n_1(x^{\text{clean}}))\,\big(n_1(x^{\text{noise}}) - n_1(x^{\text{clean}})\big) \tag{33}\\
\text{and } \nabla_{r_{n_2}}^{\text{AtP}} \mathcal{L}(\mathcal{M}(x^{\text{clean}}))
  &:= \left(\operatorname{Jac}_{r_{\text{in},\ell_2}}(n_2)(r_{\text{in},\ell_2}(x^{\text{clean}}))\right)^{\intercal} \nabla_{n_2}\big(\mathcal{L}(\mathcal{M}(x^{\text{clean}}))\big)(n_2(x^{\text{clean}})), \tag{34}
\end{align}

and this is a close approximation when $n_1(x^{\text{noise}}) \approx n_1(x^{\text{clean}})$.

A key benefit of this decomposition is that the first term depends only on $n_1$ and the second term depends only on $n_2$; and both are easy to compute from a forward and backward pass on $x^{\text{clean}}$ and a forward pass on $x^{\text{noise}}$, just like AtP itself.
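The decomposition can be sketched for the simplest case we could think of: scalar nodes that add $n_1 \cdot w$ to the residual stream and read it as $n_2 = \tanh(v \cdot r_{\text{in}})$, with no layer norm. These node forms and the helper names are assumptions for illustration. Eqs. (33) and (34) then give one vector per upstream node and one per downstream node, and every edge score in Eq. (32) is a single dot product.

```python
def delta_r_atp(write_vec, n1_clean, n1_noise):
    """Eq. (33) for a scalar node that adds n1 * write_vec to the residual stream:
    the Jacobian of r_out w.r.t. n1 is write_vec itself."""
    return [w * (n1_noise - n1_clean) for w in write_vec]

def grad_r_atp(read_vec, n2_clean, dloss_dn2):
    """Eq. (34) for a scalar node n2 = tanh(read_vec . r_in), no layer norm:
    Jac_{r_in}(n2) = (1 - n2^2) * read_vec, transposed against dL/dn2."""
    return [r * (1.0 - n2_clean ** 2) * dloss_dn2 for r in read_vec]

def edge_atp_scores(delta_r, grads, layer_of):
    """Eq. (32) for every edge whose upstream node sits in an earlier layer."""
    scores = {}
    for n1, dr in delta_r.items():   # computed once per upstream node
        for n2, g in grads.items():  # computed once per downstream node
            if layer_of[n1] < layer_of[n2]:
                scores[(n1, n2)] = sum(a * b for a, b in zip(dr, g))
    return scores

delta_r = {"a": delta_r_atp([1.0, -0.5], n1_clean=0.5, n1_noise=0.25)}
grads = {"b": grad_r_atp([0.8, 0.8], n2_clean=0.5, dloss_dn2=2.0)}
print(edge_atp_scores(delta_r, grads, {"a": 0, "b": 1}))
```

The double loop is the only part whose cost grows with the number of node pairs; everything else is linear in the number of nodes.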

Then, to complete the edge-AtP evaluation, what remains computationally is to evaluate all the dot products between nodes in different semilayers, at each token position. This requires $d_{\mathrm{resid}} T (1 - \frac{1}{L}) |N|^2 / 2$ multiplications in total,[15] where $L$ is the number of layers, $T$ is the number of tokens, and $|N|$ is the total number of nodes. This cost exceeds the cost of computing all $\Delta r_{n_1}^{\text{AtP}}(x^{\text{clean}}, x^{\text{noise}})$ and $\nabla_{r_{n_2}}^{\text{AtP}} \mathcal{L}(\mathcal{M}(x^{\text{clean}}))$ on Pythia 2.8B, even with a single node per MLP layer; for a larger model, or especially with single-neuron nodes even in small models, the gap grows significantly.

[15] For simplicity, this formula omits edges within a single layer, but those are a small minority.
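The multiplication count can be checked by brute force. In this sketch (ours; the sizes are hypothetical), nodes are split evenly across $L$ layers with $k = |N|/L$ per layer, so the number of unordered cross-layer pairs is $|N|(|N|-1)/2 - L\,k(k-1)/2 = |N|^2(1 - \frac{1}{L})/2$, each costing one $d_{\mathrm{resid}}$-dimensional dot product per token.

```python
def quadratic_edge_cost(d_resid, tokens, layers, num_nodes):
    """d_resid * T * (1 - 1/L) * |N|^2 / 2 multiplications, as in the text."""
    return d_resid * tokens * (1 - 1 / layers) * num_nodes ** 2 / 2

def brute_force_edge_cost(d_resid, tokens, nodes_per_layer):
    """Count cross-layer node pairs directly; one d_resid-dim dot product each per token."""
    layer_of = [layer for layer, k in enumerate(nodes_per_layer) for _ in range(k)]
    pairs = sum(
        1
        for i in range(len(layer_of))
        for j in range(i + 1, len(layer_of))
        if layer_of[i] != layer_of[j]
    )
    return d_resid * tokens * pairs

# With nodes split evenly across layers, the two counts agree exactly.
print(quadratic_edge_cost(8, 3, 4, 20), brute_force_edge_cost(8, 3, [5] * 4))  # → 3600.0 3600
```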

Due to this observation, we’ll focus on the quadratic part of the compute cost, which pertains to pairs of nodes rather than single nodes: the number of multiplications in computing all $(\Delta r_{n_1}^{\text{AtP}}(x^{\text{clean}}, x^{\text{noise}}))^{\intercal} \nabla_{r_{n_2}}^{\text{AtP}} \mathcal{L}(\mathcal{M}(x^{\text{clean}}))$. Notably, we’ll also exclude within-attention-head edges from the “quadratic cost”: these edges, from some key, query, or value node to an attention output node, can be handled by minor variations of the nodewise AtP or AtP* methods for the corresponding key, query, or value node.

C.2.4 MLPs

There are a couple of issues that can come up around the MLP nodes. One is that, similarly to the attention saturation issue described in Section 3.1.1, the linear approximation to the MLP may be fairly poor in some cases, creating significant false negatives if $n_2$ is an MLP node. Another is that single-neuron nodes are very numerous, which makes the $d_{\mathrm{resid}}$-dimensional dot product per edge quite costly.

MLP saturation and fix

Just as clean activations that saturate the attention probability may have small gradients that lead to strongly underestimated effects, the same is true of the MLP nonlinearity. A similar fix is applicable: instead of using a linear approximation to the function from $n_1$ to $n_2$, we can linearly approximate the function from $n_1$ to the preactivation $n_{2,\text{pre}}$, and then recompute $n_2$ from that before multiplying by the gradient.
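A minimal scalar sketch of the fix, assuming $n_2$ is a single GELU neuron (the GELU choice, the numbers, and the helper names are ours). Here `d_pre` stands for the linearized change in $n_{2,\text{pre}}$ induced by $n_1$. With the clean preactivation deep in the flat region, plain linearization through the nonlinearity reports an effect near zero, while recomputing the activation recovers it.

```python
import math

def gelu(z):
    """Exact GELU: z * Phi(z)."""
    return 0.5 * z * (1.0 + math.erf(z / math.sqrt(2.0)))

def gelu_grad(z):
    """d/dz [z * Phi(z)] = Phi(z) + z * phi(z)."""
    phi = math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0))) + z * phi

def neuron_effect_linear(pre_clean, d_pre, dloss_dn2):
    """Plain AtP: linearize straight through the nonlinearity."""
    return gelu_grad(pre_clean) * d_pre * dloss_dn2

def neuron_effect_fixed(pre_clean, d_pre, dloss_dn2):
    """Saturation fix: linearize only up to the preactivation, then recompute n2."""
    return (gelu(pre_clean + d_pre) - gelu(pre_clean)) * dloss_dn2

# Saturated clean neuron (preactivation -4) pushed into the active region by n1.
print(neuron_effect_linear(-4.0, 5.0, 1.0), neuron_effect_fixed(-4.0, 5.0, 1.0))
```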

This kind of rearrangement, where the gradient-delta-activation dot product is computed in $d_{n_2}$ dimensions rather than $d_{\mathrm{resid}}$, will come up again; we’ll call it the factored form of AtP.
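Without a nonlinearity in between, the factored and unfactored forms compute the same number, $\Delta r^{\intercal}(J^{\intercal} g) = (J \Delta r)^{\intercal} g$, where $J$ is the Jacobian of $n_2$ with respect to its residual input and $g$ is the gradient with respect to $n_2$. The sketch below (ours, with made-up matrices) shows both orderings; the factored one is where a nonlinearity can be reapplied, as in the saturation fix.

```python
def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]

def transpose(m):
    return [list(col) for col in zip(*m)]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def edge_score_unfactored(delta_r, jac, grad_n2):
    # J^T g lives in d_resid dims and can be precomputed once per downstream node,
    # so the per-edge work is a single d_resid-dimensional dot product.
    return dot(delta_r, matvec(transpose(jac), grad_n2))

def edge_score_factored(delta_r, jac, grad_n2):
    # Map delta_r into n2's own d_{n2} dims first: d_{n2} * d_resid multiplications
    # per edge for J @ delta_r, plus a d_{n2}-dimensional dot product.
    return dot(matvec(jac, delta_r), grad_n2)

jac = [[1.0, 2.0, 0.0], [0.0, 1.0, -1.0]]  # d_{n2} = 2, d_resid = 3
delta_r, grad_n2 = [1.0, -1.0, 2.0], [3.0, 0.5]
print(edge_score_unfactored(delta_r, jac, grad_n2), edge_score_factored(delta_r, jac, grad_n2))
```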

If the nodes are neurons, the factored form requires no change to the number of multiplications; however, if they’re MLP layers, there’s a large increase in cost, by a factor of $d_{\mathrm{neurons}}$. This increase is mitigated in two ways: first, these edges are a small minority, outnumbered by edges ending in attention nodes by a factor of $3\times(\text{\# heads per layer})$; second, there is potential for parameter sharing.
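The MLP fix can be sketched for a single neuron as follows, under stated assumptions: GELU is used only as an example nonlinearity, the change in $n_2$’s preactivation induced by $n_1$ (`delta_pre`) is taken as already known (it is linear in $\Delta r_{n_1}$ once the layer-norm scale is frozen), and `grad_n2` is the gradient of the metric with respect to the neuron’s post-activation. The helper names are hypothetical. The sketch contrasts plain linear AtP with recomputing the nonlinearity on the patched preactivation.

```python
import math


def gelu(x: float) -> float:
    """Exact (erf-based) GELU, used here as an example MLP nonlinearity."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_grad(x: float) -> float:
    """Derivative of GELU: Phi(x) + x * phi(x)."""
    cdf = 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))
    pdf = math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)
    return cdf + x * pdf


def neuron_edge_estimates(pre_clean: float, delta_pre: float, grad_n2: float):
    """Return (linear AtP estimate, MLP-fix estimate) for one edge n1 -> n2.

    pre_clean: clean preactivation of neuron n2
    delta_pre: change in n2's preactivation caused by patching n1 (assumed known)
    grad_n2:   gradient of the metric w.r.t. n2's post-activation value
    """
    # Plain AtP: linearize straight through the nonlinearity.
    linear = grad_n2 * gelu_grad(pre_clean) * delta_pre
    # MLP fix: linearize only up to the preactivation, then recompute n2 exactly.
    fixed = grad_n2 * (gelu(pre_clean + delta_pre) - gelu(pre_clean))
    return linear, fixed


if __name__ == "__main__":
    # Saturated neuron: tiny clean gradient, large true effect.
    print(neuron_edge_estimates(pre_clean=-4.0, delta_pre=5.0, grad_n2=1.0))
```

With a saturated neuron (clean preactivation −4, shifted by +5), the clean gradient is nearly zero, so linear AtP reports almost no effect, while the MLP fix recovers an effect of roughly 0.84. For small shifts away from saturation, the two estimates agree closely.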

Neuron edges and parameter sharing

A useful observation is that each edge, across different token positions (and also across different batch entries, if we do this on more than one prompt pair), reuses the same parameter matrices in $\operatorname{Jac}_{n_1}(r_{\text{out},\ell_1})(n_1(x^{\text{clean}}))$ and $\operatorname{Jac}_{r_{\text{in},\ell_2}}(n_2)(r_{\text{in},\ell_2}(x^{\text{clean}}))$. Indeed, setting aside the MLP activation function, the only other nonlinearity in those functions is a layer normalization; if we freeze the scaling factor at its clean value as in Section C.1, the Jacobians are equal to the product of the corresponding parameter matrices, divided by the clean scaling factor.

Thus, if we premultiply the parameter matrices, we eliminate the need to do so at each token, which reduces the per-token quadratic cost by a factor of $d_{\mathrm{resid}}$ (i.e. to a scalar multiplication) for neuron-neuron edges, or by a factor of $d_{\mathrm{resid}}/d_{\mathrm{site}}$ (i.e. to a $d_{\mathrm{site}}$-dimensional dot product) for edges between neurons and some attention site.
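A minimal sketch of the premultiplication trick for a single neuron-to-neuron edge, assuming the layer-norm centering and gain have been folded into the weights so that, with the scale frozen at its clean value, $n_2$’s preactivation changes by $\Delta a_{n_1}(w_{\text{out}}\cdot w_{\text{in}})/\text{scale}$. Here `w_out` is the upstream neuron’s output weight vector, `w_in` the downstream neuron’s input weight vector, and `grad_pre` the gradient with respect to $n_2$’s preactivation; all names are hypothetical. Both functions return the same per-token estimates, but the shared version performs a single $d_{\mathrm{resid}}$-dimensional dot product in total rather than one per token.

```python
from typing import List

Vector = List[float]


def dot(u: Vector, v: Vector) -> float:
    return sum(a * b for a, b in zip(u, v))


def per_token_unshared(delta_a: Vector, w_out: Vector, w_in: Vector,
                       scale: Vector, grad_pre: Vector) -> Vector:
    """Recompute the d_resid-dim product at every token."""
    estimates = []
    for t in range(len(delta_a)):
        delta_resid = [delta_a[t] * w for w in w_out]  # n1's change to the residual
        delta_pre = dot(delta_resid, w_in) / scale[t]  # frozen clean LN scale
        estimates.append(delta_pre * grad_pre[t])
    return estimates


def per_token_shared(delta_a: Vector, w_out: Vector, w_in: Vector,
                     scale: Vector, grad_pre: Vector) -> Vector:
    """Premultiply the parameter vectors once; per token only scalars remain."""
    coupling = dot(w_out, w_in)  # shared across tokens (and prompt pairs)
    return [delta_a[t] * coupling / scale[t] * grad_pre[t] for t in range(len(delta_a))]


if __name__ == "__main__":
    args = ([2.0, -1.0], [1.0, 2.0, -1.0], [0.5, 0.5, 1.0], [1.0, 2.0], [1.0, 4.0])
    print(per_token_unshared(*args), per_token_shared(*args))
```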

It’s worth noting, though, that these premultiplied parameter matrices (or, indeed, the edge-AtP estimates if we use neuron sites) will in total be many times larger than the MLP weights themselves (specifically, $(L-1)\frac{d_{\mathrm{neurons}}}{4d_{\mathrm{resid}}}$ times larger), so storage needs careful consideration. It may be worth computing only the largest estimates, or the estimates above some threshold, rather than full estimates for all edges.
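One way to recover the $(L-1)\frac{d_{\mathrm{neurons}}}{4d_{\mathrm{resid}}}$ factor, assuming we store one $d_{\mathrm{neurons}}\times d_{\mathrm{neurons}}$ premultiplied matrix for each of the $\binom{L}{2}$ pairs of distinct layers and compare against the input and output weight matrices ($2d_{\mathrm{resid}}d_{\mathrm{neurons}}$ entries) of each of the $L$ MLPs. The function names and the example dimensions are illustrative.

```python
from math import comb


def premultiplied_to_mlp_weight_ratio(n_layers: int, d_resid: int, d_neurons: int) -> float:
    """Size of all neuron-neuron premultiplied matrices relative to the MLP weights.

    Assumes one d_neurons x d_neurons matrix per pair of distinct layers, versus
    W_in (d_resid x d_neurons) and W_out (d_neurons x d_resid) in each MLP.
    """
    premultiplied = comb(n_layers, 2) * d_neurons ** 2
    mlp_weights = n_layers * 2 * d_resid * d_neurons
    return premultiplied / mlp_weights


def closed_form(n_layers: int, d_resid: int, d_neurons: int) -> float:
    """(L - 1) * d_neurons / (4 * d_resid)."""
    return (n_layers - 1) * d_neurons / (4 * d_resid)


if __name__ == "__main__":
    print(premultiplied_to_mlp_weight_ratio(24, 1024, 4096), closed_form(24, 1024, 4096))
```

For example, when $d_{\mathrm{neurons}}=4d_{\mathrm{resid}}$ the ratio reduces to $L-1$.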

C.2.5 Edge AtP* costs

Let’s now consider how to adapt the AtP* proposals from Section 3.1 to this setting. We’ve already seen that the MLP fix, which is similarly motivated to the QK fix, has negligible cost in the neuron-nodes case, but comes with a $d_{\mathrm{neurons}}/d_{\mathrm{resid}}$ overhead in quadratic cost when each node is an MLP layer, at least on edges into those MLP nodes. We’ll consider the MLP fix to be part of edge-AtP*. Now let’s investigate the two corrections in regular AtP*: GradDrops and the QK fix.

GradDrops

GradDrops works by replacing the single backward pass in the AtP formula with $L$ backward passes; in effect, this gives $L$ values for the multiplicand $\nabla_{r_{n_2}}^{\text{AtP}}\mathcal{L}(\mathcal{M}(x^{\text{clean}}))$, so it imposes a multiplicative factor of $L$ on the quadratic cost (though some of these values will in fact be duplicates, and taking this into account drives the multiplicative factor down to $(L+1)/2$). Notably, this works equally well with “factored AtP”, as used for neuron edges; in particular, if $n_2$ is a neuron, the gradients can easily be combined and shared across $n_1$s, eliminating the $(L+1)/2$ quadratic-cost overhead.

However, the motivation for GradDrops was to account for multiple paths whose effects may cancel; in the edge-interventions setting, these can already be discovered in a different way (by identifying the responsible edges out of $n_2$), so the benefit of GradDrops is lessened. At the same time, the cost remains substantial. Thus, we’ll omit GradDrops from our recommended procedure, edge-AtP*.

QK fix

The QK fix applies to the $\nabla_{n_2}(\mathcal{L}(\mathcal{M}(x^{\text{clean}})))(n_2(x^{\text{clean}}))$ term, i.e. it replaces the linear approximation to the softmax with a correct calculation of the change in softmax, for each different input $\Delta r_{n_1}^{\text{AtP}}(x^{\text{clean}},x^{\text{noise}})$. As in Section 3.1.1, there’s the simpler case of $n_2$s that are query nodes, and the more complicated case of $n_2$s that are key nodes, which uses Algorithm 4; but both are cheap to do after computing the $\Delta\operatorname{attnLogits}$ corresponding to $n_2$.
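A minimal sketch of the QK fix for a query node $n_2$ at a single query position, assuming the logit changes $\Delta s_t = \Delta q\cdot k_t$ induced by $\Delta r_{n_1}$ have already been computed and that `grad_probs` holds the gradient of the metric with respect to that position’s attention probabilities. The $1/\sqrt{d_{\mathrm{key}}}$ scaling and the causal mask are omitted, the key-node case (Algorithm 4) is not covered, and the function names are hypothetical.

```python
import math
from typing import List, Tuple

Vector = List[float]


def softmax(xs: Vector) -> Vector:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    z = sum(exps)
    return [e / z for e in exps]


def query_edge_estimates(logits: Vector, delta_logits: Vector,
                         grad_probs: Vector) -> Tuple[float, float]:
    """Return (linear AtP estimate, QK-fix estimate) for one query-node edge.

    logits:       clean attention logits of the query position over keys
    delta_logits: change in those logits caused by patching n1
    grad_probs:   gradient of the metric w.r.t. the attention probabilities
    """
    p = softmax(logits)
    # Linear approximation: (diag(p) - p p^T) delta = p_i * (delta_i - p . delta).
    p_dot_delta = sum(pi * di for pi, di in zip(p, delta_logits))
    linear_dp = [pi * (di - p_dot_delta) for pi, di in zip(p, delta_logits)]
    # QK fix: exact change in the softmax for this particular delta.
    p_new = softmax([s + d for s, d in zip(logits, delta_logits)])
    exact_dp = [a - b for a, b in zip(p_new, p)]
    linear = sum(g * d for g, d in zip(grad_probs, linear_dp))
    fixed = sum(g * d for g, d in zip(grad_probs, exact_dp))
    return linear, fixed


if __name__ == "__main__":
    print(query_edge_estimates([10.0, 0.0, 0.0], [0.0, 20.0, 0.0], [0.0, 1.0, 0.0]))
```

In a saturated pattern (clean logits [10, 0, 0], with $n_1$ raising the second logit by 20), the linear estimate is below $10^{-3}$, while the exact softmax change moves almost all attention to the second key.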

The “factored AtP” way is to matrix-multiply $\Delta r_{n_1}^{\text{AtP}}(x^{\text{clean}},x^{\text{noise}})$ with the key or query weights and then with the clean queries or keys, respectively. This means that instead of the $d_{\mathrm{resid}}$ multiplications required for each edge $n_1\rightarrow n_2$ with AtP, we need $d_{\mathrm{resid}}d_{\mathrm{key}}+Td_{\mathrm{key}}$ multiplications (which, thanks to the causal mask, can be reduced to an average of $d_{\mathrm{key}}(d_{\mathrm{resid}}+(T+1)/2)$).

The “unfactored” option is to stay in the $r_{\text{in},\ell_2}$ space: premultiply the clean queries or keys by the respective key or query weight matrices, and then take the dot product of $\Delta r_{n_1}^{\text{AtP}}(x^{\text{clean}},x^{\text{noise}})$ with each one. This way, the quadratic part of the compute cost contains $d_{\mathrm{resid}}(T+1)/2$ multiplications; this is more efficient for short sequence lengths.
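The two options compute the same attention-logit changes for a query node; they differ only in which products can be shared across upstream nodes $n_1$. The sketch below (hypothetical names, a single query position, $q = rW_Q$ with $W_Q$ of shape $d_{\mathrm{resid}}\times d_{\mathrm{key}}$, no causal mask or $1/\sqrt{d_{\mathrm{key}}}$ scaling) returns the logit changes together with the per-edge multiplication count: $d_{\mathrm{resid}}d_{\mathrm{key}} + Td_{\mathrm{key}}$ for the factored form, and $Td_{\mathrm{resid}}$ for the unfactored form, whose $W_Q k_t$ vectors are precomputed once and reused for every $n_1$.

```python
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]  # W_Q stored row-major with shape (d_resid, d_key)


def delta_logits_factored(delta_r: Vector, W_Q: Matrix,
                          keys: List[Vector]) -> Tuple[Vector, int]:
    """Project delta_r to a query change, then dot with each clean key."""
    d_resid, d_key = len(W_Q), len(W_Q[0])
    delta_q = [sum(delta_r[i] * W_Q[i][j] for i in range(d_resid)) for j in range(d_key)]
    delta_logits = [sum(a * b for a, b in zip(delta_q, k)) for k in keys]
    mults = d_resid * d_key + len(keys) * d_key  # per edge n1 -> n2
    return delta_logits, mults


def precompute_key_directions(W_Q: Matrix, keys: List[Vector]) -> List[Vector]:
    """u_t = W_Q k_t for each key position t; computed once, shared by every n1."""
    return [[sum(w * kj for w, kj in zip(row, k)) for row in W_Q] for k in keys]


def delta_logits_unfactored(delta_r: Vector,
                            key_dirs: List[Vector]) -> Tuple[Vector, int]:
    """Stay in residual space: dot delta_r with each precomputed u_t."""
    delta_logits = [sum(a * b for a, b in zip(delta_r, u)) for u in key_dirs]
    mults = len(key_dirs) * len(delta_r)  # T * d_resid per edge
    return delta_logits, mults


if __name__ == "__main__":
    W_Q = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]]
    keys = [[1.0, 1.0, 1.0], [0.0, 2.0, 0.5]]
    print(delta_logits_factored([1.0, 2.0], W_Q, keys))
    print(delta_logits_unfactored([1.0, 2.0], precompute_key_directions(W_Q, keys)))
```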

This means that for edges into key and query nodes, the overhead of AtP+QKfix on the quadratic cost is a multiplicative factor of $\min\left(\frac{T+1}{2},\ d_{\mathrm{key}}\left(1+\frac{T+1}{2d_{\mathrm{resid}}}\right)\right)$.
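A small helper that evaluates this overhead and finds the sequence length at which the factored form becomes cheaper. The function names are hypothetical, and the example dimensions ($d_{\mathrm{resid}}=1024$, $d_{\mathrm{key}}=64$) are illustrative assumptions, not values taken from the paper.

```python
from typing import Optional


def qkfix_overhead(T: int, d_resid: int, d_key: int) -> float:
    """min((T+1)/2, d_key * (1 + (T+1) / (2 d_resid)))."""
    unfactored = (T + 1) / 2
    factored = d_key * (1 + (T + 1) / (2 * d_resid))
    return min(unfactored, factored)


def factored_crossover(d_resid: int, d_key: int, max_T: int = 100_000) -> Optional[int]:
    """Smallest T at which the factored form is strictly cheaper, if any."""
    for T in range(1, max_T + 1):
        if d_key * (1 + (T + 1) / (2 * d_resid)) < (T + 1) / 2:
            return T
    return None


if __name__ == "__main__":
    print(factored_crossover(1024, 64))
    print(qkfix_overhead(100, 1024, 64), qkfix_overhead(1000, 1024, 64))
```

For these dimensions the unfactored form is cheaper up to $T=135$ and the factored form from $T=136$ onward, consistent with solving $(T+1)/2 = d_{\mathrm{key}}(1 + (T+1)/(2d_{\mathrm{resid}}))$ for $T$.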

QK fix + GradDrops

If the QK fix is combined with GradDrops, then the first multiplication, by the $d_{\mathrm{resid}}\times d_{\mathrm{key}}$ matrix, can be shared between the different gradients; so the overhead on the quadratic cost of QKfix + GradDrops for edges into queries and keys, using the factored method, is $d_{\mathrm{key}}\left(1+\frac{(T+1)(L+1)}{4d_{\mathrm{resid}}}\right)$.
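To see what the sharing saves, the sketch below compares the overhead above with a hypothetical baseline that repeats the whole factored QK-fix computation for each of the $(L+1)/2$ distinct gradients. Both function names and the example dimensions are illustrative assumptions.

```python
def qkfix_gd_overhead_shared(T: int, L: int, d_resid: int, d_key: int) -> float:
    """Overhead from the text: the d_resid x d_key product is shared by all gradients."""
    return d_key * (1 + (T + 1) * (L + 1) / (4 * d_resid))


def qkfix_gd_overhead_unshared(T: int, L: int, d_resid: int, d_key: int) -> float:
    """Hypothetical baseline: redo the factored QK fix for each of (L+1)/2 gradients."""
    return (L + 1) / 2 * d_key * (1 + (T + 1) / (2 * d_resid))


if __name__ == "__main__":
    print(qkfix_gd_overhead_shared(127, 23, 1024, 64))
    print(qkfix_gd_overhead_unshared(127, 23, 1024, 64))
```

Expanding both expressions shows that the gap is exactly $d_{\mathrm{key}}((L+1)/2 - 1)$, i.e. the cost of repeating the $d_{\mathrm{resid}}\times d_{\mathrm{key}}$ product for every extra gradient.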

C.3 Conclusion

Considering all the above possibilities, it’s not obvious where the best tradeoff between correctness and compute cost lies in every situation. In Table 2 we provide formulas for the number of multiplications in the quadratic cost for each kind of edge, across the variations we’ve mentioned. In Figure 16 we plug in the four Pythia model sizes used elsewhere in the paper (e.g. in Figure 2) to enable numerical comparison.

\begin{tabular}{lllllll}
\hline
AtP variant & O$\rightarrow$V & O$\rightarrow$Q,K & O$\rightarrow$MLP & MLP$\rightarrow$V & MLP$\rightarrow$Q,K & MLP$\rightarrow$MLP \\
\hline
\multicolumn{7}{l}{\textit{MLP layers}} \\
AtP & $DH^2$ & $2DH^2$ & $DH$ & $DH$ & $2DH$ & $D$ \\
QKfix & $DH^2$ & $(T+1)DH^2$ & $DH$ & $DH$ & $(T+1)DH$ & $D$ \\
QKfix+GD & $\frac{L+1}{2}DH^2$ & $\frac{(L+1)(T+1)}{2}DH^2$ & $\frac{L+1}{2}DH$ & $\frac{L+1}{2}DH$ & $\frac{(L+1)(T+1)}{2}DH$ & $\frac{L+1}{2}D$ \\
AtP* & $DH^2$ & $(T+1)DH^2$ & $VNH$ & $DH$ & $(T+1)DH$ & $ND$ \\
AtP*+GD & $\frac{L+1}{2}DH^2$ & $\frac{(L+1)(T+1)}{2}DH^2$ & $VNH$ & $\frac{L+1}{2}DH$ & $\frac{(L+1)(T+1)}{2}DH$ & $ND$ \\
QKfix (long) & $DH^2$ & $(2D+T+1)KH^2$ & $DH$ & $DH$ & $(2D+T+1)KH$ & $D$ \\
QKfix+GD (long) & $\frac{L+1}{2}DH^2$ & $\frac{L+1}{2}(2D+T+1)KH^2$ & $\frac{L+1}{2}DH$ & $\frac{L+1}{2}DH$ & $\frac{L+1}{2}(2D+T+1)KH$ & $\frac{L+1}{2}D$ \\
AtP* (long) & $DH^2$ & $(2D+T+1)KH^2$ & $VNH$ & $DH$ & $(2D+T+1)KH$ & $ND$ \\
AtP*+GD (long) & $\frac{L+1}{2}DH^2$ & $\frac{L+1}{2}(2D+T+1)KH^2$ & $VNH$ & $\frac{L+1}{2}DH$ & $\frac{L+1}{2}(2D+T+1)KH$ & $ND$ \\
\hline
\multicolumn{7}{l}{\textit{Neurons}} \\
AtP & $DH^2$ & $2DH^2$ & $VNH$ & $VNH$ & $2KNH$ & $N^2$ \\
MLPfix & $DH^2$ & $2DH^2$ & $VNH$ & $VNH$ & $2KNH$ & $N^2$ \\
AtP* & $DH^2$ & $(T+1)DH^2$ & $VNH$ & $VNH$ & $(T+1)KNH$ & $N^2$ \\
AtP*+GD & $\frac{L+1}{2}DH^2$ & $\frac{L+1}{2}(T+1)DH^2$ & $VNH$ & $\frac{L+1}{2}VNH$ & $\frac{(L+1)(T+1)}{2}KNH$ & $N^2$ \\
AtP* (long) & $DH^2$ & $(2D+T+1)KH^2$ & $VNH$ & $VNH$ & $(T+1)KNH$ & $N^2$ \\
AtP*+GD (long) & $\frac{L+1}{2}DH^2$ & $\frac{L+1}{2}(2D+T+1)KH^2$ & $VNH$ & $\frac{L+1}{2}VNH$ & $\frac{(L+1)(T+1)}{2}KNH$ & $N^2$ \\
\hline
\end{tabular}
Table 2: Per-token, per-layer-pair total quadratic cost of each kind of between-layers edge, across edge-AtP variants, for full-MLP-layer nodes and for single-neuron nodes. GD denotes GradDrops; rows marked “(long)” use the factored form of the QK fix, which is cheaper for long prompts. For brevity, we omit the layer-pair factor $\binom{L}{2}$ that would otherwise appear in every cell, and use $D:=d_{\mathrm{resid}}$, $H:=\text{\# heads per layer}$, $K:=d_{\mathrm{key}}$, $V:=d_{\mathrm{value}}$, $N:=d_{\mathrm{neurons}}$.
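For numerical comparisons like those in Figure 16, the sketch below encodes a subset of Table 2’s rows (the variants without GradDrops) as Python expressions, sums the six edge kinds, and multiplies by the omitted $\binom{L}{2}$ factor. The function name is hypothetical. The Pythia-410M dimensions in the usage block are assumed from its public configuration and should be checked before relying on them; the prompt length is arbitrary.

```python
from math import comb
from typing import Dict


def table2_costs(D: int, H: int, K: int, V: int, N: int, T: int, L: int) -> Dict[str, int]:
    """Per-token quadratic cost for some Table 2 rows, including the C(L, 2) factor.

    Cell order: O->V, O->Q,K, O->MLP, MLP->V, MLP->Q,K, MLP->MLP.
    """
    rows = {
        "MLP layers / AtP": [D * H**2, 2 * D * H**2, D * H, D * H, 2 * D * H, D],
        "MLP layers / AtP*": [D * H**2, (T + 1) * D * H**2, V * N * H, D * H,
                              (T + 1) * D * H, N * D],
        "MLP layers / AtP* (long)": [D * H**2, (2 * D + T + 1) * K * H**2, V * N * H,
                                     D * H, (2 * D + T + 1) * K * H, N * D],
        "Neurons / AtP": [D * H**2, 2 * D * H**2, V * N * H, V * N * H,
                          2 * K * N * H, N**2],
        "Neurons / AtP*": [D * H**2, (T + 1) * D * H**2, V * N * H, V * N * H,
                           (T + 1) * K * N * H, N**2],
    }
    pairs = comb(L, 2)
    return {name: pairs * sum(cells) for name, cells in rows.items()}


if __name__ == "__main__":
    # Assumed Pythia-410M dimensions; T = 32 is an arbitrary prompt length.
    costs = table2_costs(D=1024, H=16, K=64, V=64, N=4096, T=32, L=24)
    for name, cost in costs.items():
        print(f"{name:26s} {cost:.3e}")
```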
Figure 16: A comparison of edge-AtP variants across model sizes and prompt lengths. AtP* here is defined to include QKfix and MLPfix, but not GradDrops. The costs vary across several orders of magnitude for each setting.
In the setting with full-MLP nodes, MLPfix carries substantial cost for short prompts, but barely matters for long prompts.
In the neuron-nodes setting, MLPfix is costless. GradDrops, however, still imposes a large cost in that setting: although it doesn’t affect MLP$\rightarrow$MLP edges, it does affect MLP$\rightarrow$Q,K edges, which dominate the cost once QKfix is applied.

Appendix D Distribution of true effects

In Figure 17, we show the distribution of $c(n)$ across models and prompt-pair distributions.

Figure 17: Distribution of true effects across models and prompt pair distributions

[Figure 17 panels: (a) Pythia-410M, (b) Pythia-1B, (c) Pythia-2.8B, (d) Pythia-12B; in each row, (i) attention nodes and (ii) neuron nodes.]