Views Can Be Deceiving: Improved SSL Through Feature Space Augmentation

Kimia Hamidieh1  Haoran Zhang1  Swami Sankaranarayanan2  Marzyeh Ghassemi1
1MIT, 2Sony AI
{hamidieh,haoranz,swamiviv,mghassem}@mit.edu
Abstract

Supervised learning methods have been found to exhibit inductive biases favoring simpler features. When such features are spuriously correlated with the label, this can result in suboptimal performance on minority subgroups. Despite the growing popularity of methods which learn from unlabeled data, the extent to which these representations encode spurious features is unclear. In this work, we explore the impact of spurious features on Self-Supervised Learning (SSL) for visual representation learning. We first empirically show that commonly used augmentations in SSL can cause undesired invariances in the image space, and illustrate this with a simple example. We further show that classical approaches in combating spurious correlations, such as dataset re-sampling during SSL, do not consistently lead to invariant representations. Motivated by these findings, we propose LateTVG to remove spurious information from these representations during pretraining, by regularizing later layers of the encoder via pruning. We find that our method produces representations which outperform the baselines on several benchmarks, without the need for group or label information during SSL.

1 Introduction

Standard supervised machine learning models exhibit high overall performance but often perform poorly on minority subgroups (Shah et al., 2020; McCoy et al., 2019; Gururangan et al., 2018). One potential cause is the presence of spurious correlations, which are features that are only correlated with the label for specific subsets of data. For instance, a machine learning model tasked with predicting bird species from images across different habitats may use the background the bird commonly appears in as a “shortcut”, instead of core features specific to the bird such as the shape of its beak or plumage. This results in poor performance on bird groups that appear in unexpected environments (Sagawa et al., 2020a). Identifying spurious correlations in the supervised learning setting has been well studied, where empirical risk minimization has been shown to exploit spurious correlations and result in poor performance for minority subgroups (Hashimoto et al., 2018). As downstream tasks are explicitly defined, the label can be used to distinguish between core and spurious features (Liu et al., 2021a; Zhang et al., 2022). Recent work has proposed various methods to identify and mitigate the effects of spurious features, such as learning multiple prediction heads (Lee et al., 2022b), causal inference (Creager et al., 2021), data augmentation (Gao et al., 2023) and targeted strategies such as importance weighting (Lahoti et al., 2020), re-sampling (Idrissi et al., 2021; Tu et al., 2020), or approaches based on group distributionally robust optimization (Sagawa et al., 2020a; Duchi et al., 2019).

More recently, self-supervised learning (SSL) has emerged as a common form of pre-training for task-agnostic learning with large, unlabeled datasets (Chen et al., 2020a; He et al., 2019; Grill et al., 2020; Chen & He, 2020; Caron et al., 2020; Zbontar et al., 2021; Chen et al., 2020b). SSL methods learn representations from unlabeled datasets by solving an auxiliary pretext task (Doersch et al., 2015), such as inducing invariance between the representations of two augmented views of the same image (He et al., 2019; Chen et al., 2020a). These methods have shown impressive results for a wide range of downstream tasks and datasets (Liu et al., 2021b; Jaiswal et al., 2020; Tamkin et al., 2021).

Capturing core features – rather than spurious features – is essential for learning effective representations that can be used in downstream tasks, but is particularly difficult in the case of SSL due to the absence of labeled data during the pre-training process. Given only unlabeled data, we define spurious features as those that strongly correlate with core features for most examples in the training set, but are not useful for downstream tasks. For example, when training an SSL model on multi-object images, larger objects may interfere with the learning of smaller objects (Chen et al., 2021). If the downstream task involves only the prediction of smaller objects, the larger (spurious) object may suppress the smaller (core) object from being learned. Large-scale unlabeled datasets that are commonly used in machine learning are inevitably imbalanced (Van Horn et al., 2021), have been found to be biased towards spuriously correlated sensitive attributes (Calude & Longo, 2017) such as gender or race (Agarwal et al., 2021), and can also include label-irrelevant features (Torralba & Efros, 2011; Fan et al., 2014).

In this paper, we investigate the impact of spurious correlations on SSL pre-training. We first show theoretically that image augmentations used in SSL pre-training can lead to spurious connectivity when learning representations, causing the model to fail to predict the label using core features in downstream tasks. We empirically evaluate spurious connectivity, and then show that existing methods for utilizing group information in ERM-based approaches do not provide an analogous improvement in SSL pre-training. We then propose Late-layer Transformation-based View Generation, or LateTVG – a method that induces invariance to spurious features in the representation space by regularizing final layers of the featurizer via pruning. Importantly, since our approach addresses SSL pre-training, we do not assume that model developers know a priori the identity or values of the spurious features that exist in the data. We first evaluate LateTVG on several popular benchmarks for spurious feature learning, and then connect our method to the theoretical analysis by showing that LateTVG models empirically exhibit lower spurious connectivity. Our method demonstrates improved discriminative ability, especially over minority subgroups, for downstream predictive tasks, without access to group or label information. We make the following contributions:

  • We provide theoretical arguments (Sec 3.3) that illustrate how common augmentations used in SSL pre-training affect the model’s ability to rely on spurious features, for downstream linear classifiers.

  • We explore the extent of spurious learning in self-supervised representations through the lens of downstream worst-group performance. We empirically show that known techniques for avoiding spurious correlations, such as re-sampling of the training set given group information, do not consistently improve core feature representations (Sec 4.4).

  • We propose LateTVG – an approach that corrects for the biases caused by augmentations, by modifying views of samples in the representation space (Sec 5.1). We find that LateTVG effectively improves worst-group performance in downstream tasks on four datasets by enforcing core feature learning (Sec 5.2).

2 Related Work

Spurious Correlations.

Spurious correlations arise in supervised learning models (Koh et al., 2021; Joshi et al., 2023; Singla & Feizi, 2021) in a variety of domains, from medical imaging (Zech et al., 2018; DeGrave et al., 2021) to natural language processing (Tu et al., 2020; Wang & Culotta, 2020). A variety of approaches have been proposed to learn classifiers which do not make use of spurious information. Methods like GroupDRO (Sagawa et al., 2020a) and DFR (Kirichenko et al., 2022) require group information during training, while methods like JTT (Liu et al., 2021a), LfF (Nam et al., 2020), CVaR DRO (Duchi et al., 2019), and CnC (Zhang et al., 2022) do not. However, all methods require group information for model selection.

Self-supervised Representation Learning.

Self-supervised learning methods learn representations from large-scale unlabeled datasets where annotations are scarce. In vision applications, the pretext task is typically to maximize similarity between two augmented views of the same image (Jing & Tian, 2020). This can be done in a contrastive fashion using the InfoNCE loss (Oord et al., 2018), such as in Chen et al. (2020a) and Chen et al. (2020b), or without the need for negative samples at all, as in Grill et al. (2020); Caron et al. (2020); Chen & He (2020); Caron et al. (2021); Oquab et al. (2023); Zbontar et al. (2021). Prior work has shown that SSL models may learn to spuriously associate certain foreground items with certain backgrounds (Meehan et al., 2023). In this work, we explore one potential mechanism for this phenomenon, both theoretically and empirically.

Representation Learning under Dataset Imbalance and Shortcuts.

Self-supervised models have demonstrated increased robustness to dataset imbalance (Liu et al., 2021b; Jiang et al., 2021b; a), and the dominance of easier or larger features suppressing the learning of other features (Chen et al., 2021). Some prior work has addressed shortcut learning in contrastive learning through adversarial feature modification without group labels (Robinson et al., 2021). However, other approaches to group robustness or fairness in self-supervised learning require group information or labels (Tsai et al., 2020; Song et al., 2019; Wang et al., 2021; Bordes et al., 2023; Scalbert et al., 2023). This paper focuses on learning representations from an unlabeled dataset with spurious correlations, encompassing both dataset imbalance and features of varying difficulty.

Regularization in Self-supervised Learning.

The concept of regularizing a specific subset of the network is relatively unexplored in self-supervised learning but finds motivation in recent findings from supervised settings, such as addressing minority examples (Hooker et al., 2019), out-of-distribution generalization (Zhang et al., 2021), late-layer regularizations through head weight-decay (Abnar et al., 2021), and initialization (Zhou et al., 2022). Additionally, Lee et al. (2022a) propose surgically fine-tuning specific layers of the network to handle distribution shifts in particular categories. These studies provide support for the approach of targeting a specific component of the network in self-supervised learning.

3 Spurious Connectivity Induces Downstream Failures

In this section, we introduce a toy setting to demonstrate that common augmentations used in SSL pre-training affect a model’s ability to rely on spurious features for downstream linear classifiers. We consider a binary classification problem with a binary spurious attribute, with an equal number of samples per group (Section 3.2). We show that augmentations applied during SSL pre-training can introduce undesired invariances in the representation space learned by a contrastive objective, making the downstream linear classifier trained on representations more reliant on the spurious feature (Section 3.3).

3.1 Background and Setup

Setup.

We consider learning representations from an unlabeled data space $\mathcal{X}$ generated from an underlying latent feature space $\mathcal{Z}\in\mathbb{R}^{m}:=\{z_{\text{core}},z_{\text{spur}},\dots,z_{m}\}$, where $z_{\text{core}}$ and $z_{\text{spur}}$ are correlated features. For a given downstream task with labeled samples, we assume that each $x\in\mathcal{X}$ belongs to a class given by the ground-truth labeling function $y:\mathcal{X}\rightarrow\mathcal{Y}$, where $z_{\text{core}}$ determines the labels for our downstream task of interest, while $z_{\text{spur}}$ determines the spurious attribute, which is easier to learn, and is not of interest for downstream tasks. We can define a deterministic attribute function $a:\mathcal{X}\rightarrow\mathcal{S}$ where each $x\in\mathcal{X}$ takes a value in $\mathcal{S}$. Let $g=(y(x),a(x))$ denote the subgroup of a given sample $x$, where $\mathcal{G}=\mathcal{Y}\times\mathcal{S}$ is the set of all possible subgroups.
Figure 1 illustrates the subgroups on the Waterbirds dataset, where the background is a spurious feature that correlates with the bird species.

Contrastive learning.

We aim to learn representations by bringing together data-augmented views of the same input, which we refer to as positive pairs, using a contrastive objective. Let $P_{+}$ be the distribution of positive pairs, which can be defined as the marginal probability of generating the augmented pair $x$ and $x^{\prime}$ from the same image in the (natural) population data. Thus the distribution $P_{+}$ relies on both the original data distribution and the choice of SSL augmentations. To analyze the representation space learned in contrastive learning and the core feature predictivity of the representations, consider a weighted graph with vertex set $\mathcal{X}$ where the undirected edge $(x,x^{\prime})$ has weight $w_{xx^{\prime}}=P_{+}(x,x^{\prime})$, similar to the augmentation graph in HaoChen et al. (2021).

Although the augmentation graph learns semantically similar structures that enable generalization to new domains (Shen et al., 2022), the inductive biases set by these augmentations are not well studied. In this work, we draw attention to cases where augmentations can create spurious connectivities within subgroups of the data, and when and why these connectivities can cause the downstream linear model to rely on the spurious feature.

Figure 1: Analysing SSL augmentations. (a) Images generated from a latent space with correlating features. (b) If the connectivity induced by SSL augmentations between subgroups with the same spurious features is higher than the ones with the same invariant features, learned representations lead a downstream linear model to separate the data based on the spurious feature (red dashed line) instead of the invariant feature (green dashed line). Our empirical evaluation in Table 4 shows that this is indeed the case across different datasets considered in this work.

3.2 Spurious Connectivity in a Toy Setup

In this section, we introduce a setting in which contrastive objectives can learn representations that cause linear downstream models to fail on downstream tasks. To start, we investigate how augmentations can transform the samples such that the subgroup assignment changes.

Definition 3.1.

Subgroup connectivity. Define the average subgroup connectivity given two disjoint subsets $G_{1},G_{2}\subseteq\mathcal{X}$ as $w(G_{1},G_{2})=\frac{1}{|G_{1}|\cdot|G_{2}|}\sum_{x\in G_{1},x^{\prime}\in G_{2}}w_{xx^{\prime}}$, where $w_{xx^{\prime}}$ is the probability of generating the augmented pair $x$ and $x^{\prime}$ from the same image in the natural population data.

Intuitively, this subgroup connectivity is the average weight of edges connecting $G_{1}$ to $G_{2}$, and is proportional to the probability of a sample $x\in G_{1}$ being transformed to a sample $x^{\prime}\in G_{2}$ via augmentations. See Appendix C for further details.

We specifically define the following terms to be the expected value of $w(G_{1},G_{2})$ from Definition 3.1, when subgroups $G_{1}$ and $G_{2}$ have the following properties:

  • Spurious connectivity ($\alpha$): $G_{1}$ and $G_{2}$ share the same spurious attribute but differ in class

  • Invariant connectivity ($\beta$): $G_{1}$ and $G_{2}$ share the same class but differ in spurious attribute

  • Opposite connectivity ($\gamma$): $G_{1}$ and $G_{2}$ differ both in the spurious attribute and the label

Here, $\alpha$, $\beta$, and $\gamma$ are average values estimated across a dataset consisting of subgroups.

Toy Setup.

We consider a downstream classification problem where a spurious attribute is present, and both the input and the spurious attribute take binary values. We define the probability of sampling a positive pair $(x,x^{\prime})$ based on the expected connectivity terms $\alpha_{\text{toy}}$, $\beta_{\text{toy}}$, $\gamma_{\text{toy}}$, and $\rho_{\text{toy}}$ as follows:

$$P_{+}(x,x^{\prime})=\begin{cases}\alpha_{\text{toy}},&\text{if }a(x)\neq a(x^{\prime})\text{ and }y(x)=y(x^{\prime})\\ \beta_{\text{toy}},&\text{if }a(x)=a(x^{\prime})\text{ and }y(x)\neq y(x^{\prime})\\ \gamma_{\text{toy}},&\text{if }a(x)\neq a(x^{\prime})\text{ and }y(x)\neq y(x^{\prime})\\ \rho_{\text{toy}},&\text{if }a(x)=a(x^{\prime})\text{ and }y(x)=y(x^{\prime})\end{cases}$$

Note that the average subgroup connectivity for this setup would be exactly the same as the corresponding connectivity variable. Thus in our running example we have $\alpha=\alpha_{\text{toy}}$, $\beta=\beta_{\text{toy}}$, $\gamma=\gamma_{\text{toy}}$, and we can use them interchangeably. For this simplified augmentation graph, the expected connectivity terms between groups are a property of the graph, and independent of the model or architecture we use for learning representations. Combined with a contrastive objective, the expected connectivity can be a proxy for how close different subgroups are going to be in the representation space.
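The statement that the average subgroup connectivity coincides with the corresponding toy parameter can be checked numerically. The following is a minimal sketch rather than the authors' code: the numeric values, the choice to represent each sample only by its (class, attribute) pair, and the helper names `pair_weight` and `avg_connectivity` are all assumptions made for illustration. The weights are treated as relative edge weights and are not normalized into a probability distribution.

```python
from itertools import product

# Hypothetical toy values (assumed for illustration only).
ALPHA, BETA, GAMMA, RHO = 0.2, 0.3, 0.05, 0.45


def pair_weight(g, g_prime):
    """Edge weight P_+(x, x') for samples in groups g=(y, a) and g'=(y', a'),
    following the piecewise definition of the toy setup as written."""
    (y, a), (y2, a2) = g, g_prime
    if a != a2 and y == y2:
        return ALPHA
    if a == a2 and y != y2:
        return BETA
    if a != a2 and y != y2:
        return GAMMA
    return RHO  # same class and same attribute


def avg_connectivity(G1, G2):
    """Average subgroup connectivity w(G1, G2) from Definition 3.1."""
    total = sum(pair_weight(x, x2) for x in G1 for x2 in G2)
    return total / (len(G1) * len(G2))


# Each sample is represented by its (class, attribute) pair; n samples per group.
n = 5
groups = {g: [g] * n for g in product([0, 1], repeat=2)}

w_same_class = avg_connectivity(groups[(0, 0)], groups[(0, 1)])  # differ in attribute only
w_same_attr = avg_connectivity(groups[(0, 0)], groups[(1, 0)])   # differ in class only
w_opposite = avg_connectivity(groups[(0, 0)], groups[(1, 1)])    # differ in both
print(w_same_class, w_same_attr, w_opposite)
```

Because every pair of samples drawn from two fixed groups receives the same weight, the average over pairs reduces to that single weight, independent of the group size `n`.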

3.3 Analysis of the Toy Setting

In Section 4.2, we empirically show that common augmentations used in contrastive learning can be detrimental to learning invariant representations, as they implicitly encourage samples to cluster primarily based on the spurious feature. Based on this observation, we make the following assumption.

Assumption 3.2.

Given a spurious attribute function $a:\mathcal{X}\rightarrow|G|$ which is defined for all $x\in\mathcal{X}$, we assume that for a data point $x\in\mathcal{X}$, the probability of distorting the labeling of the augmented images sampled from the augmentation distribution $\mathcal{A}(\cdot|x)$ is greater than the probability of distorting the attribute. More formally,

$$\Pr_{\tilde{x}\sim\mathcal{A}(\cdot|x)}\left(y(\tilde{x})\neq y(x),\,a(\tilde{x})=a(x)\right)\geq\Pr_{\tilde{x}\sim\mathcal{A}(\cdot|x)}\left(y(\tilde{x})=y(x),\,a(\tilde{x})\neq a(x)\right)$$

Lemma 3.3.

Consider the set of (unlabeled) population data $\mathcal{X}$ in a binary-class setting where the spurious attribute takes binary values, consisting of $|\mathcal{G}|=4$ groups, with the same number of examples per group. Consider a simplified augmentation graph with parameters $\alpha$, $\beta$, $\rho$, $\gamma$ defined as in Section 3.2, and assume that augmentations are more likely to change only one of class or attribute than to change both at the same time ($\alpha>\gamma$, $\beta>\gamma$), and are more likely to change neither than to change either one ($\rho>\alpha$, $\rho>\beta$).

Under these conditions, the spectral contrastive loss recovers both invariant and spurious features, and for each sample in the population data, the spurious feature is bounded by the constant $B_{sp}=\sqrt{\beta-\alpha-\gamma+\rho}$, while the invariant feature is bounded by $B_{inv}=\sqrt{\alpha-\beta-\gamma+\rho}$, in the representation space. Proof in Appendix C.
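To see where the two quantities under the square roots come from, one can collapse the simplified augmentation graph to a 4×4 group-level matrix and check its eigenvectors directly. The sketch below is an illustration under stated assumptions, not the proof from Appendix C: it uses the unnormalized group-level adjacency matrix (with equal group sizes, normalization should only rescale all eigenvalues by a common factor), hypothetical numeric values, and entries assigned exactly as in the piecewise definition of the toy setup.

```python
# Hypothetical connectivity values with rho > alpha, beta > gamma (assumed).
alpha, beta, gamma, rho = 0.2, 0.3, 0.05, 0.45

# Group order: (y=0,a=0), (y=0,a=1), (y=1,a=0), (y=1,a=1).
# Entries follow the piecewise toy definition as written:
#   alpha: same class, different attribute
#   beta:  same attribute, different class
#   gamma: both differ;  rho: same group
A = [
    [rho,   alpha, beta,  gamma],
    [alpha, rho,   gamma, beta],
    [beta,  gamma, rho,   alpha],
    [gamma, beta,  alpha, rho],
]


def matvec(M, v):
    """Plain-Python matrix-vector product."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in M]


def is_eigenpair(M, v, lam, tol=1e-12):
    """Check M v = lam v entrywise, up to a small tolerance."""
    return all(abs(mv - lam * vi) < tol for mv, vi in zip(matvec(M, v), v))


class_vec = [1, 1, -1, -1]  # sign encodes the class y
attr_vec = [1, -1, 1, -1]   # sign encodes the spurious attribute a

lam_class = alpha - beta - gamma + rho  # quantity under the root in B_inv
lam_attr = beta - alpha - gamma + rho   # quantity under the root in B_sp

print(is_eigenpair(A, class_vec, lam_class), is_eigenpair(A, attr_vec, lam_attr))
```

Under these assumptions, the class-indicator vector and the attribute-indicator vector are eigenvectors of the group-level matrix, and their eigenvalues match the expressions inside $B_{inv}$ and $B_{sp}$ respectively, which is consistent with reading the two bounds as the scales of the corresponding feature directions.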

Corollary 3.4.

Given Assumption 3.2, where $\alpha>\beta$ in the simplified augmentation graph, the margin of the spurious classifier is $B_{sp}$, and is less than the margin of the invariant classifier $B_{inv}$, and the max-margin classifier trained on representations given by spectral clustering converges to the spurious classifier.

This suggests that even with the same number of samples across different groups during pre-training, downstream linear classifiers can rely on the spurious feature to make predictions, where the representations are determined by the simplified augmentation graph and the spectral contrastive loss.

4 Exploring Spurious Learning in Representations

In this section, we investigate the performance of downstream linear models trained on self-supervised representations, empirically verify our assumption regarding spurious and invariant connectivity, and show that in practice – similar to our toy analysis – having the same number of examples across groups in the presence of spurious connectivity does not lead to performance gains.

4.1 Experimental Setup

Datasets

We evaluate methods on five commonly used benchmarks in spurious correlations – CelebA (Liu et al., 2015), CMNIST (Arjovsky et al., 2019), MetaShift (Liang & Zou, 2022), Spurious CIFAR-10 (Nagarajan et al., 2020), and Waterbirds (Wah et al., 2011) (see Appendix D.1 for dataset descriptions). For each dataset, we train an encoder with an SSL-based pre-training step, followed by supervised training of a linear model that probes the representations learned using SSL for the downstream task.

SSL Pre-training

For the SSL pre-training, we train SimSiam (Chen & He, 2020) models with a ResNet backbone throughout the paper. The training splits used during the pre-training stage are unbalanced and contain spuriously correlated data. The group/label counts for each dataset and split are shown in Appendix D.1. The backbone networks used for most of our experiments are initialized with random weights, unless specified otherwise. We additionally report results for SimCLR (Chen et al., 2020a) models in Section 5.2.1.

Downstream Task

For downstream task prediction, we train a linear layer using logistic regression on top of the pretrained embeddings. Note that the backbone is frozen during this finetuning phase and only the linear layer is updated. We use a balanced dataset for training where the spurious correlation does not hold. To create this downstream training dataset, we subsample majority groups (Sagawa et al., 2020b; Idrissi et al., 2021), to avoid the geometrical skews (Nagarajan et al., 2020) of the linear classifier on representations. Then, we evaluate the learned representations on the standard test split of each dataset, where group information is given. For each run, we report the average and worst-group accuracy.
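The two reported metrics can be computed as in the following minimal sketch. The function names and the example predictions are hypothetical, and the average is taken over samples, which is an assumption about how the average accuracy is defined.

```python
from collections import defaultdict


def group_accuracies(y_true, y_pred, attrs):
    """Accuracy per subgroup g = (label, spurious attribute)."""
    correct, total = defaultdict(int), defaultdict(int)
    for y, p, a in zip(y_true, y_pred, attrs):
        total[(y, a)] += 1
        correct[(y, a)] += int(y == p)
    return {g: correct[g] / total[g] for g in total}


def average_and_worst(y_true, y_pred, attrs):
    """Return (sample-averaged accuracy, worst-group accuracy)."""
    per_group = group_accuracies(y_true, y_pred, attrs)
    avg = sum(int(y == p) for y, p in zip(y_true, y_pred)) / len(y_true)
    return avg, min(per_group.values())


# Hypothetical test set in which the classifier predicts the spurious
# attribute instead of the label: majority groups (y == a) are all correct,
# minority groups (y != a) are all wrong.
y_true = [0, 0, 0, 0, 1, 1, 1, 1]
attrs = [0, 0, 0, 1, 1, 1, 1, 0]
y_pred = list(attrs)

avg_acc, worst_acc = average_and_worst(y_true, y_pred, attrs)
print(avg_acc, worst_acc)
```

In this constructed example the average accuracy stays high while the worst-group accuracy collapses, which is the gap that worst-group evaluation is meant to expose.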

Empirical Evaluation of Spurious Connectivity

To evaluate the connectivity term for each pair of subgroups in datasets exhibiting spurious correlations, we conduct an empirical analysis similar to Shen et al. (2022). Specifically, we train a classifier to distinguish between each pair of subgroups and evaluate its performance on a subset of the data that has been augmented with SSL augmentations. The error of the classifier represents the probability that the augmentation module alters the subgroup assignment for each example between the two subgroups, making them indistinguishable. Figure 1 illustrates this procedure.
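The procedure can be sketched on synthetic data as follows. Everything here is an assumption made for illustration: samples are two-dimensional (core, spurious) feature vectors rather than images, the "augmentation" is Gaussian noise that perturbs the core coordinate more than the spurious one, the classifier is a nearest-centroid rule fit on un-augmented data, and all function names are hypothetical. It is not the setup described in Appendix E.

```python
import random


def fit_centroids(X, labels):
    """Nearest-centroid classifier: one mean vector per label."""
    sums, counts = {}, {}
    for x, lab in zip(X, labels):
        s = sums.setdefault(lab, [0.0] * len(x))
        for i, v in enumerate(x):
            s[i] += v
        counts[lab] = counts.get(lab, 0) + 1
    return {lab: [v / counts[lab] for v in s] for lab, s in sums.items()}


def predict(centroids, x):
    """Label of the closest centroid in squared Euclidean distance."""
    return min(
        centroids,
        key=lambda lab: sum((xi - ci) ** 2 for xi, ci in zip(x, centroids[lab])),
    )


def augment(x, rng, core_noise=1.0, spur_noise=0.1):
    # Hypothetical augmentation: distorts the core coordinate more strongly.
    return [x[0] + rng.gauss(0, core_noise), x[1] + rng.gauss(0, spur_noise)]


def connectivity_estimate(group_a, group_b, rng):
    """Error rate of a two-group classifier evaluated on augmented samples."""
    X = group_a + group_b
    labels = [0] * len(group_a) + [1] * len(group_b)
    centroids = fit_centroids(X, labels)
    errors = sum(predict(centroids, augment(x, rng)) != lab for x, lab in zip(X, labels))
    return errors / len(X)


rng = random.Random(0)


def sample_group(y, a, n=200):
    # (core, spurious) coordinates centred at -1 or +1 with small spread.
    return [[(2 * y - 1) + rng.gauss(0, 0.1), (2 * a - 1) + rng.gauss(0, 0.1)]
            for _ in range(n)]


g00, g10, g01 = sample_group(0, 0), sample_group(1, 0), sample_group(0, 1)
spurious_conn = connectivity_estimate(g00, g10, rng)   # same attribute, different class
invariant_conn = connectivity_estimate(g00, g01, rng)  # same class, different attribute
print(spurious_conn, invariant_conn)
```

With this particular noise model, groups that differ only in class become harder to tell apart after augmentation than groups that differ only in the spurious attribute, so the estimated spurious connectivity exceeds the invariant one.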

The Role of Initialization

In representation learning, encoders are not typically trained from scratch but initialized from a model pretrained on larger datasets, such as ImageNet (Deng et al., 2009). Recent work in transfer learning (Geirhos et al., 2018; Salman et al., 2022) has questioned this assumption and pointed out that biases in pretrained models linger even after finetuning on downstream target tasks. In this section and more broadly in our work, we focus on performing SSL pre-training from randomly initialized weights. In addition, since the datasets considered in this work are similar to ImageNet, the performance of off-the-shelf ImageNet pretrained models is expected to be higher. For completeness, we have added these results to Appendix G.2.

4.2 High Levels of Spurious Connectivity in Practice

We measure connectivity across four datasets in Table 4, and on all of them, we find that the average spurious connectivity is higher than invariant connectivity. We also confirm that both these values are higher than the probability of simultaneously changing both spurious attributes and invariant attributes. This means that the samples within the training set are more likely to be connected to each other through the spurious attribute, rather than the core feature. This suggests that the contrastive loss prefers alignment based on the spurious attribute instead of the class.

Table 1: We report the error of classifiers trained to distinguish between two subgroups as a proxy for the probability of augmentations flipping group assignments between each two groups in the dataset, or the connectivity of two subgroups in the image space.

| Dataset | Spurious Connectivity (%) | Invariant Connectivity (%) | Opposite Connectivity (%) |
| --- | --- | --- | --- |
| celebA | 10.4 | 3.7 | 2.8 |
| cmnist | 31.6 | 8.3 | 6.8 |
| metashift | 16.3 | 13.6 | 5.0 |
| waterbirds | 25.3 | 11.2 | 7.8 |

We compute the connectivity terms by training classifiers to distinguish augmented data from each combination of the two groups in the dataset and reporting their error rates.

The details of the choice of augmentations and training for this step can be found in Appendix E.

4.3 SSL Models Learn Spurious Features

To measure the reliance of downstream models on spurious correlations, we measure the accuracy of the downstream model on each group in the test set, and use the worst-performing group accuracy as a lens to reason about spurious correlations. We find that, across all datasets, SSL models exhibit gaps between worst-group and average accuracy when predicting the core feature (Table 5 in Appendix D.3).

These results indicate that, unlike in supervised learning (Menon et al., 2021; Kirichenko et al., 2022; Rosenfeld et al., 2022), training the final layer on a balanced set where the spurious correlation does not hold is not sufficient for improving worst-group accuracy when predicting the core attribute.

4.4 Resampling During SSL does not Improve Downstream Performance

To probe the effect of availability of such group information during the SSL pre-training stage, we examine whether classical approaches for combating spurious correlations, such as re-sampling training examples (Idrissi et al., 2021), are effective in removing spurious information during SSL pre-training.

Assuming that group information is available, we train SimSiam on datasets re-sampled using the following strategies: (i) Balancing groups by resampling training examples to match the downstream validation distribution. (ii) Downsampling examples in majority groups to have the same number of examples in all groups. (iii) Upsampling minority examples to have the same number of examples in all groups.
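
The three strategies can be sketched as index-selection routines over a vector of group ids. This is a minimal sketch, assuming integer group labels; the function names, the `target_props` dictionary (standing in for the downstream validation distribution), and the choice to sample with replacement when balancing are all assumptions rather than details from the paper.

```python
import numpy as np

def downsample(groups, rng):
    """(ii) Keep min-group-size examples from every group (no replacement)."""
    ids, counts = np.unique(groups, return_counts=True)
    n = counts.min()
    return np.concatenate(
        [rng.choice(np.flatnonzero(groups == g), size=n, replace=False) for g in ids]
    )

def upsample(groups, rng):
    """(iii) Keep every example and duplicate minority examples up to max group size."""
    ids, counts = np.unique(groups, return_counts=True)
    n = counts.max()
    chosen = []
    for g in ids:
        idx = np.flatnonzero(groups == g)
        extra = rng.choice(idx, size=n - len(idx), replace=True) if n > len(idx) else idx[:0]
        chosen.append(np.concatenate([idx, extra]))
    return np.concatenate(chosen)

def balance_to_target(groups, target_props, n_total, rng):
    """(i) Resample so group proportions match a target distribution."""
    chosen = []
    for g, p in target_props.items():
        idx = np.flatnonzero(groups == g)
        chosen.append(rng.choice(idx, size=int(round(p * n_total)), replace=True))
    return np.concatenate(chosen)
```

Each function returns indices into the training set, so the same routine can wrap any dataset object that supports integer indexing.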

Table 2: Worst-group accuracy difference (%) between each balancing strategy and the original training set. Original training performance is shown in parentheses below each dataset. Full results can be found in Appendix Table 10.

| Sampling Strategy | celebA (77.5) | cmnist (75.4) | metashift (42.3) | spurcifar10 (43.4) | waterbirds (48.3) |
|---|---|---|---|---|---|
| Balancing | -1.7 | -8.7 | -3.8 | -8.3 | +3.0 |
| Downsampling | +0.3 | -10.6 | +3.9 | -14.4 | +0.5 |
| Upsampling | +4.1 | -5.3 | +2.7 | -19.4 | -0.3 |

As shown in Table 2, re-sampling during self-supervised pre-training does not consistently improve downstream worst-group accuracy. We do see minor improvements for metashift and celebA, but these contrast with large drops for spurcifar10 and cmnist. Given that the downstream linear model is trained on a downsampled dataset where such correlations do not exist, this means that re-sampling during self-supervised training does not necessarily improve the linear separability of representations with respect to the core feature, even given a balanced finetuning dataset. This is analogous to our findings in the toy setting in Section 3.3.

5 Creating Robust Representations via Feature Space Augmentations

In the previous sections, we showed that augmentation mechanisms used in SSL result in poor performance under spuriously correlated features in the training set. Instead of curating specific image augmentations that correct for these biases in the image space, we propose an approach to target spurious connectivity in the representation space by modifying positive pairs. In this section, we describe our approach, LateTVG, which improves the performance of SSL models by introducing pruning-based regularization to the later layers of the encoder.

5.1 Late-layer Transformation-based View Generation

Motivated by improved SSL model invariance when trained with augmentations in image space (Chen et al., 2020a), we propose a model transformation module that specifically targets augmentations that modify the spurious feature in representation space. We propose Late-layer Transformation-based View Generation (LateTVG), which uses feature space transformations to mitigate spurious learning in SSL models and improve learning of the core feature.

Formally, we propose using a model transformation module $\mathcal{U}$ that transforms any given model $f_{\theta}$ parameterized by $\theta=\{W_{1},\dots,W_{n}\}$ to $f_{\tilde{\theta}}$. At each step, we draw a transformation $\phi_{M,\theta}\sim\mathcal{U}$ to obtain the transformed encoder. Each model transformation can be defined with a mask $M\in\{0,1\}^{|\theta|}$, where we transform the unmasked weights $(1-M)\odot\theta$ by $\phi$, and keep the rest of the weights $M\odot\theta$ the same to obtain $\tilde{\theta}$. Here, we propose a specific transformation module $\mathcal{U}$.

Transformations.

For mitigating spurious connectivity, we choose a simple transformation targeted towards regularizing the final layers of the encoder. In our experiments, we consider a threshold pruning transformation module, which uses magnitude pruning on $a\%$ of the weights in all layers deeper than $L$. More specifically, we propose a model transformation module $\mathcal{U}_{\text{Prune},L,a}$, with $\phi(\theta)=0$, $M:=M_{L,a}=\{M^{l}_{L}\odot\text{Top}_{a}(W_{l})\mid l\in[n]\}$ and $\text{Top}_{a}(W_{l})_{i,j}=\mathbb{I}(|W_{l_{(i,j)}}|\text{ in top }a\%\text{ of }\theta)$. Note that in this specific setting, the module transformation is deterministic (i.e. $|\mathcal{U}|=1$), though our formalization also allows for random transformations such as randomized pruning or re-initialization.
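
To make the threshold pruning transformation concrete, here is a minimal sketch that zeroes weights by magnitude in every layer deeper than $L$ and leaves earlier layers untouched. Several details are assumptions made for illustration: the function name `late_layer_prune` is hypothetical, the $a\%$ smallest-magnitude weights are the ones set to zero (the usual convention for magnitude pruning), and the ranking is done per layer, whereas the definition above ranks magnitudes against all of $\theta$.

```python
import numpy as np

def late_layer_prune(weights, L, a):
    """Return a pruned copy of `weights` (a list of arrays W_1..W_n).

    Layers with 1-based index l <= L are copied unchanged. In deeper layers,
    the a% smallest-magnitude entries are set to zero (phi(theta) = 0).
    Assumption: magnitudes are ranked within each layer.
    """
    pruned = []
    for l, W in enumerate(weights, start=1):
        if l <= L:
            pruned.append(W.copy())
            continue
        k = int(round(W.size * a / 100.0))   # number of entries to zero
        keep = np.ones(W.size, dtype=bool)   # keep[i] = True means the weight is kept
        if k > 0:
            smallest = np.argsort(np.abs(W), axis=None, kind="stable")[:k]
            keep[smallest] = False
        pruned.append(np.where(keep.reshape(W.shape), W, 0.0))
    return pruned
```

Because the operation depends only on the current weights, $L$, and $a$, calling it twice on the same weights gives the same result, which matches the deterministic case $|\mathcal{U}|=1$ noted above.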

Figure 2: We use model transformation modules to create new views of training examples in the representation space. The introduced set of transformations removes the features learned in the final few layers, and provides final representations invariant to such transformations.

To learn these representations, given two random augmentations $t,t^{\prime}\sim\mathcal{T}$ from the augmentation module $\mathcal{T}$, two views $x_{1}=t(x)$ and $x_{2}=t^{\prime}(x)$ are generated from an input image $x$. At each step, given a feature encoder $f$ and a model transformation module $\mathcal{U}$, we obtain a transformed model $\tilde{f}=\phi(f)$ with $\phi\sim\mathcal{U}$. During training, examples $x_{1}$ and $x_{2}$ are respectively passed through the normal encoder, $v_{1}=f(x_{1})$, and the transformed encoder, $\tilde{v}_{2}=\tilde{f}(x_{2})$. The encoded feature $\tilde{v}_{2}$ is now a positive example that should be close to $v_{1}$ in the representation space. An algorithmic representation of the method can be found in Appendix B.
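
The forward computation described above can be sketched with a toy fully connected encoder. This is an illustrative sketch under several assumptions, not the authors' implementation: the encoder is a small ReLU network, the projection head, predictor, and stop-gradient used by SimSiam are omitted, negative cosine similarity is used as the alignment objective, pruning is done per layer on the smallest-magnitude weights, and all function names are hypothetical.

```python
import numpy as np

def encode(x, weights):
    """Toy encoder f: ReLU layers followed by a final linear layer."""
    h = x
    for W in weights[:-1]:
        h = np.maximum(h @ W, 0.0)
    return h @ weights[-1]

def prune_late(weights, L, a):
    """Zero the a% smallest-magnitude weights in layers deeper than L (assumed per layer)."""
    out = []
    for l, W in enumerate(weights, start=1):
        k = int(round(W.size * a / 100.0)) if l > L else 0
        keep = np.ones(W.size, dtype=bool)
        keep[np.argsort(np.abs(W), axis=None, kind="stable")[:k]] = False
        out.append(np.where(keep.reshape(W.shape), W, 0.0))
    return out

def neg_cosine(u, v, eps=1e-12):
    """Mean negative cosine similarity between rows of u and v."""
    u = u / (np.linalg.norm(u, axis=1, keepdims=True) + eps)
    v = v / (np.linalg.norm(v, axis=1, keepdims=True) + eps)
    return float(-(u * v).sum(axis=1).mean())

def latetvg_loss(x1, x2, weights, L, a):
    """Align v1 = f(x1) with the transformed-encoder view v2~ = f~(x2)."""
    v1 = encode(x1, weights)                            # normal encoder
    v2_tilde = encode(x2, prune_late(weights, L, a))    # transformed encoder
    return neg_cosine(v1, v2_tilde)
```

With `a=0` and identical inputs the two encoders coincide and the loss reaches its minimum of -1; increasing `a` perturbs only the later layers, so minimizing the loss pushes the representation to be invariant to that perturbation.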

Intuition for LateTVG.

When learning a discriminative process that maps data to a separable space, the variance among different subpopulations is stored in distinct regions of the network (Lee et al., 2022a). As a result, both spurious and core features, which describe the high-level data distribution, tend to reside at the end of a neural network. Thus, in LateTVG, we aim to encourage the final layers to learn more difficult features by applying a model transformation that targets these layers, causing the model to be invariant to final-layer transformations. As pruning in supervised models has been shown to affect minority examples more than majority ones (Hooker et al., 2019), we hypothesize that our transformation can be considered a curated view-generating operation for the minority groups. In particular, pruning would contribute to “forgetting” the minority examples from the network, resulting in upweighting the loss for these examples.

5.2 Experiments

In this section, we demonstrate the efficacy of LateTVG in mitigating the dependence on spurious correlations. We use the same experimental setup as described in Section 4.1. To evaluate LateTVG, we apply our SSL-LateTVG approach during the pre-training stage and compare its performance to standard SSL-Base models pre-trained with either SimSiam or SimCLR.

5.2.1 LateTVG Improves SSL Worst-Group Performance

Figure 3: Downstream worst-group accuracy of SSL-Late-TVG on the metashift dataset as we vary the percentage of minority group in the downstream training set. For all cases except for extreme minority decrement, SSL-Late-TVG outperforms the baseline.

The goal of this experiment is to understand how LateTVG affects worst-group performance in downstream tasks that use SSL representations. We compare the worst group accuracy of two approaches, SSL-Base and SSL-LateTVG on 5 different datasets. Both models used similar hyper-parameter grids and model selection criteria as noted previously. The results are presented in Table 3. We show the performance of the best hyperparameter combination here, and have provided figures of performance gains for all hyperparameters in Appendix D.2. It can be clearly observed that SSL-LateTVG outperforms the base model by large margins across most datasets and for both SimSiam and SimCLR. On cmnist, our performance is very close to the baseline model and we do not see significant improvement. We hypothesize that this is due to the fact that the base encoder on the easier cmnist dataset is already quite performant. On datasets where the base encoder performs poorly such as metashift and spurcifar10, our approach improves the performance by at least 10% over base SimSiam. On a dataset of a larger scale like celebA, LateTVG still improves upon a strong encoder baseline.

Table 3: Worst-group accuracy (%) of SSL-Base and LateTVG for SimSiam and SimCLR pre-training. Results for average accuracy can be found in Table 8.

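| Dataset | SimSiam SSL-base | SimSiam SSL-Late-TVG | SimCLR SSL-base | SimCLR SSL-Late-TVG |
|---|---|---|---|---|
| celebA | 77.5 | 83.1 | 76.7 | 82.2 |
| cmnist | 80.7 | 83.1 | 81.7 | 83.8 |
| metashift | 42.3 | 79.6 | 45.5 | 59.3 |
| spurcifar10 | 43.4 | 61.4 | 36.5 | 40.4 |
| waterbirds | 48.3 | 56.3 | 43.8 | 55.4 |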

Further, we find that LateTVG closes the gap in performance to supervised pretraining (Table 8). We emphasize that this is an unfair comparison to begin with, since supervised pretraining requires labeled data whereas SSL does not, which drastically reduces the annotation budget. Regardless, we find that LateTVG significantly narrows the gap between the SSL baseline and the ERM model, with relative improvements ranging from 17% for cmnist to 50% for spurcifar10. In the case of celebA, we even outperform the ERM baseline.

5.2.2 SSL downstream linear performance is less reliant on a balanced downstream dataset

Traditional approaches that mitigate spurious correlations in ERM-based settings assume that the downstream training set is balanced (Kirichenko et al., 2022). However, this still requires knowledge of the spurious feature, which we may not always have in practice. In this experiment, we challenge this assumption and analyze how SSL models behave when the downstream training set is imbalanced.

We vary the proportion of minority groups in the downstream training set by first downsampling the training set to have the same number of samples across groups, and then randomly sampling minority groups with weight $\lambda$ (x-axis in Figure 3) and majority groups with weight $1-\lambda$. We measure the worst-group accuracy of the trained linear models for each dataset. We show the results on metashift in Figure 3, comparing the performance of SSL-Base and SSL-LateTVG. We observe that LateTVG outperforms the baseline across a range of minority weights, implying that LateTVG is more robust to imbalances in downstream training data. This is a crucial aspect where LateTVG differs from other approaches in the supervised pretraining literature, such as DFR (Kirichenko et al., 2022), which requires a balanced training set for the reweighting strategy to be successful. Similar results for other datasets and linear models are provided in Appendix F.5.
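
The two-step sampling procedure can be sketched as below. This is one possible reading of the description, stated as an assumption: after downsampling to equal group sizes, each minority example receives an unnormalized sampling weight $\lambda$ and each majority example a weight $1-\lambda$, and the downstream set is drawn with replacement. The function name and arguments are hypothetical.

```python
import numpy as np

def sample_downstream(groups, minority_groups, lam, n_samples, seed=0):
    """Return indices of a downstream training set with minority weight `lam`."""
    rng = np.random.default_rng(seed)

    # Step 1: downsample so every group has the same number of examples.
    ids, counts = np.unique(groups, return_counts=True)
    n = counts.min()
    balanced = np.concatenate(
        [rng.choice(np.flatnonzero(groups == g), size=n, replace=False) for g in ids]
    )

    # Step 2: sample minority examples with weight lam, majority with 1 - lam.
    is_minority = np.isin(groups[balanced], minority_groups)
    w = np.where(is_minority, lam, 1.0 - lam).astype(float)
    return rng.choice(balanced, size=n_samples, replace=True, p=w / w.sum())
```

Under this reading, `lam=0.5` with equally many minority and majority groups reproduces the balanced setting, while values near 0 or 1 remove one side almost entirely.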

5.2.3 LateTVG reduces Spurious Connectivity in the Representation Space

Finally, we relate our method back to the theoretical analysis presented in Section 3 by computing the connectivity of the representation space learned by the SSL models, using the procedure outlined in Appendix E. In Table 4, we find that LateTVG empirically reduces the spurious connectivity while increasing the invariant connectivity for all datasets. Thus, we have shown that LateTVG successfully augments the representation space to induce desired invariances.

Table 4: We report the error of classifiers trained to distinguish between the representations of two subgroups as a proxy for connectivity terms. We find that LateTVG decreases spurious connectivity while increasing invariant connectivity in comparison to the baseline.
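
| Dataset | Representation Space | Spurious Connectivity | Invariant Connectivity | Opposite Connectivity |
|---|---|---|---|---|
| celebA | SSL-Base | 18.9 | 15.7 | 8.3 |
| celebA | SSL-Late-TVG | 15.8 | 17.9 | 8.0 |
| cmnist | SSL-Base | 37.3 | 3.2 | 2.7 |
| cmnist | SSL-Late-TVG | 34.8 | 3.8 | 3.0 |
| metashift | SSL-Base | 28.6 | 21.4 | 21.8 |
| metashift | SSL-Late-TVG | 27.3 | 27.3 | 21.3 |
| waterbirds | SSL-Base | 44.9 | 9.4 | 8.4 |
| waterbirds | SSL-Late-TVG | 44.6 | 13.5 | 12.8 |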

6 Conclusion

In this paper, we have investigated the impact of spurious correlations on self-supervised learning (SSL) pre-training and proposed a new approach, called LateTVG, to address the issue. Our experiments demonstrated that spurious correlations caused by data augmentation can lead to spurious connectivity and hinder the model’s ability to learn core features, which ultimately impacts downstream task performance. We have shown that traditional debiasing techniques, such as re-sampling, are not effective in mitigating the impact of spurious correlations in SSL pre-training. In contrast, LateTVG effectively improves the worst-group performance in downstream tasks by inducing invariance to spurious features in the representation space throughout training. Our approach does not require access to group or label information during training and can be applied to large-scale, imbalanced datasets with spurious correlations. We believe our work will help advance the field of SSL pre-training and encourage future research in developing methods that are robust to spurious correlations.

References

  • Abnar et al. (2021) Samira Abnar, Mostafa Dehghani, Behnam Neyshabur, and Hanie Sedghi. Exploring the limits of large scale pre-training. arXiv preprint arXiv:2110.02095, 2021.
  • Agarwal et al. (2021) Sandhini Agarwal, Gretchen Krueger, Jack Clark, Alec Radford, Jong Wook Kim, and Miles Brundage. Evaluating clip: towards characterization of broader capabilities and downstream implications. arXiv preprint arXiv:2108.02818, 2021.
  • Arjovsky et al. (2019) Martin Arjovsky, Léon Bottou, Ishaan Gulrajani, and David Lopez-Paz. Invariant risk minimization. arXiv preprint arXiv:1907.02893, 2019.
  • Bordes et al. (2023) Florian Bordes, Randall Balestriero, Quentin Garrido, Adrien Bardes, and Pascal Vincent. Guillotine regularization: Why removing layers is needed to improve generalization in self-supervised learning. Transactions on Machine Learning Research, 2023.
  • Calude & Longo (2017) Cristian S Calude and Giuseppe Longo. The deluge of spurious correlations in big data. Foundations of science, 22:595–612, 2017.
  • Caron et al. (2020) Mathilde Caron, Ishan Misra, Julien Mairal, Priya Goyal, Piotr Bojanowski, and Armand Joulin. Unsupervised learning of visual features by contrasting cluster assignments. arXiv preprint arXiv:2006.09882, 2020.
  • Caron et al. (2021) Mathilde Caron, Hugo Touvron, Ishan Misra, Hervé Jégou, Julien Mairal, Piotr Bojanowski, and Armand Joulin. Emerging properties in self-supervised vision transformers. In Proceedings of the IEEE/CVF international conference on computer vision, pp.  9650–9660, 2021.
  • Chen et al. (2020a) Ting Chen, Simon Kornblith, Mohammad Norouzi, and Geoffrey Hinton. A simple framework for contrastive learning of visual representations. In International conference on machine learning, pp.  1597–1607. PMLR, 2020a.
  • Chen et al. (2021) Ting Chen, Calvin Luo, and Lala Li. Intriguing properties of contrastive losses. Advances in Neural Information Processing Systems, 34:11834–11845, 2021.
  • Chen & He (2020) Xinlei Chen and Kaiming He. Exploring Simple Siamese Representation Learning. arXiv e-prints, art. arXiv:2011.10566, November 2020.
  • Chen et al. (2020b) Xinlei Chen, Haoqi Fan, Ross Girshick, and Kaiming He. Improved baselines with momentum contrastive learning. arXiv preprint arXiv:2003.04297, 2020b.
  • Creager et al. (2021) Elliot Creager, Jörn-Henrik Jacobsen, and Richard Zemel. Environment inference for invariant learning. In International Conference on Machine Learning, pp.  2189–2200. PMLR, 2021.
  • DeGrave et al. (2021) Alex J DeGrave, Joseph D Janizek, and Su-In Lee. Ai for radiographic covid-19 detection selects shortcuts over signal. Nature Machine Intelligence, 3(7):610–619, 2021.
  • Deng et al. (2009) Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei. Imagenet: A large-scale hierarchical image database. In 2009 IEEE conference on computer vision and pattern recognition, pp.  248–255. Ieee, 2009.
  • Doersch et al. (2015) Carl Doersch, Abhinav Gupta, and Alexei A Efros. Unsupervised visual representation learning by context prediction. In Proceedings of the IEEE international conference on computer vision, pp.  1422–1430, 2015.
  • Duchi et al. (2019) John C Duchi, Tatsunori Hashimoto, and Hongseok Namkoong. Distributionally robust losses against mixture covariate shifts. Under review, 2, 2019.
  • Fan et al. (2014) Jianqing Fan, Fang Han, and Han Liu. Challenges of big data analysis. National science review, 1(2):293–314, 2014.
  • Gao et al. (2023) Irena Gao, Shiori Sagawa, Pang Wei Koh, Tatsunori Hashimoto, and Percy Liang. Out-of-domain robustness via targeted augmentations. arXiv preprint arXiv:2302.11861, 2023.
  • Geirhos et al. (2018) Robert Geirhos, Patricia Rubisch, Claudio Michaelis, Matthias Bethge, Felix A Wichmann, and Wieland Brendel. Imagenet-trained cnns are biased towards texture; increasing shape bias improves accuracy and robustness. arXiv preprint arXiv:1811.12231, 2018.
  • Grill et al. (2020) Jean-Bastien Grill, Florian Strub, Florent Altché, Corentin Tallec, Pierre H. Richemond, Elena Buchatskaya, Carl Doersch, Bernardo Avila Pires, Zhaohan Daniel Guo, Mohammad Gheshlaghi Azar, Bilal Piot, Koray Kavukcuoglu, Rémi Munos, and Michal Valko. Bootstrap your own latent: A new approach to self-supervised Learning. arXiv e-prints, art. arXiv:2006.07733, June 2020.
  • Gururangan et al. (2018) Suchin Gururangan, Swabha Swayamdipta, Omer Levy, Roy Schwartz, Samuel R Bowman, and Noah A Smith. Annotation artifacts in natural language inference data. arXiv preprint arXiv:1803.02324, 2018.
  • HaoChen et al. (2021) Jeff Z HaoChen, Colin Wei, Adrien Gaidon, and Tengyu Ma. Provable guarantees for self-supervised deep learning with spectral contrastive loss. Advances in Neural Information Processing Systems, 34:5000–5011, 2021.
  • Hashimoto et al. (2018) Tatsunori Hashimoto, Megha Srivastava, Hongseok Namkoong, and Percy Liang. Fairness without demographics in repeated loss minimization. In International Conference on Machine Learning, pp.  1929–1938. PMLR, 2018.
  • He et al. (2019) Kaiming He, Haoqi Fan, Yuxin Wu, Saining Xie, and Ross Girshick. Momentum contrast for unsupervised visual representation learning. arXiv:1911.05722, 2019.
  • Hooker et al. (2019) Sara Hooker, Aaron Courville, Gregory Clark, Yann Dauphin, and Andrea Frome. What do compressed deep neural networks forget? arXiv preprint arXiv:1911.05248, 2019.
  • Idrissi et al. (2021) Badr Youbi Idrissi, Martin Arjovsky, Mohammad Pezeshki, and David Lopez-Paz. Simple data balancing achieves competitive worst-group-accuracy. arXiv preprint arXiv:2110.14503, 2021.
  • Jaiswal et al. (2020) Ashish Jaiswal, Ashwin Ramesh Babu, Mohammad Zaki Zadeh, Debapriya Banerjee, and Fillia Makedon. A survey on contrastive self-supervised learning. Technologies, 9(1):2, 2020.
  • Jiang et al. (2021a) Ziyu Jiang, Tianlong Chen, Ting Chen, and Zhangyang Wang. Improving contrastive learning on imbalanced data via open-world sampling. Advances in Neural Information Processing Systems, 34, 2021a.
  • Jiang et al. (2021b) Ziyu Jiang, Tianlong Chen, Bobak J Mortazavi, and Zhangyang Wang. Self-damaging contrastive learning. In International Conference on Machine Learning, pp.  4927–4939. PMLR, 2021b.
  • Jing & Tian (2020) Longlong Jing and Yingli Tian. Self-supervised visual feature learning with deep neural networks: A survey. IEEE transactions on pattern analysis and machine intelligence, 43(11):4037–4058, 2020.
  • Joshi et al. (2023) Siddharth Joshi, Yu Yang, Yihao Xue, Wenhan Yang, and Baharan Mirzasoleiman. Towards mitigating spurious correlations in the wild: A benchmark & a more realistic dataset. arXiv preprint arXiv:2306.11957, 2023.
  • Kirichenko et al. (2022) Polina Kirichenko, Pavel Izmailov, and Andrew Gordon Wilson. Last layer re-training is sufficient for robustness to spurious correlations. arXiv preprint arXiv:2204.02937, 2022.
  • Koh et al. (2021) Pang Wei Koh, Shiori Sagawa, Henrik Marklund, Sang Michael Xie, Marvin Zhang, Akshay Balsubramani, Weihua Hu, Michihiro Yasunaga, Richard Lanas Phillips, Irena Gao, et al. Wilds: A benchmark of in-the-wild distribution shifts. In International Conference on Machine Learning, pp.  5637–5664. PMLR, 2021.
  • Lahoti et al. (2020) Preethi Lahoti, Alex Beutel, Jilin Chen, Kang Lee, Flavien Prost, Nithum Thain, Xuezhi Wang, and Ed H. Chi. Fairness without demographics through adversarially reweighted learning, 2020.
  • Lee et al. (2022a) Yoonho Lee, Annie S Chen, Fahim Tajwar, Ananya Kumar, Huaxiu Yao, Percy Liang, and Chelsea Finn. Surgical fine-tuning improves adaptation to distribution shifts. arXiv preprint arXiv:2210.11466, 2022a.
  • Lee et al. (2022b) Yoonho Lee, Huaxiu Yao, and Chelsea Finn. Diversify and disambiguate: Learning from underspecified data. arXiv preprint arXiv:2202.03418, 2022b.
  • Liang & Zou (2022) Weixin Liang and James Zou. Metashift: A dataset of datasets for evaluating contextual distribution shifts and training conflicts. arXiv preprint arXiv:2202.06523, 2022.
  • Liu et al. (2021a) Evan Z Liu, Behzad Haghgoo, Annie S Chen, Aditi Raghunathan, Pang Wei Koh, Shiori Sagawa, Percy Liang, and Chelsea Finn. Just train twice: Improving group robustness without training group information. In International Conference on Machine Learning, pp.  6781–6792. PMLR, 2021a.
  • Liu et al. (2021b) Hong Liu, Jeff Z HaoChen, Adrien Gaidon, and Tengyu Ma. Self-supervised learning is more robust to dataset imbalance. arXiv preprint arXiv:2110.05025, 2021b.
  • Liu et al. (2015) Ziwei Liu, Ping Luo, Xiaogang Wang, and Xiaoou Tang. Deep learning face attributes in the wild. In Proceedings of the IEEE international conference on computer vision, pp. 3730–3738, 2015.
  • McCoy et al. (2019) R Thomas McCoy, Ellie Pavlick, and Tal Linzen. Right for the wrong reasons: Diagnosing syntactic heuristics in natural language inference. arXiv preprint arXiv:1902.01007, 2019.
  • Meehan et al. (2023) Casey Meehan, Florian Bordes, Pascal Vincent, Kamalika Chaudhuri, and Chuan Guo. Do ssl models have déjà vu? a case of unintended memorization in self-supervised learning, 2023.
  • Menon et al. (2021) Aditya Krishna Menon, Ankit Singh Rawat, and Sanjiv Kumar. Overparameterisation and worst-case generalisation: friend or foe? In International Conference on Learning Representations, 2021. URL https://openreview.net/forum?id=jphnJNOwe36.
  • Moayeri et al. (2022) Mazda Moayeri, Sahil Singla, and Soheil Feizi. Hard imagenet: Segmentations for objects with strong spurious cues. Advances in Neural Information Processing Systems, 35:10068–10077, 2022.
  • Nagarajan et al. (2020) Vaishnavh Nagarajan, Anders Andreassen, and Behnam Neyshabur. Understanding the failure modes of out-of-distribution generalization. arXiv preprint arXiv:2010.15775, 2020.
  • Nam et al. (2020) Junhyun Nam, Hyuntak Cha, Sungsoo Ahn, Jaeho Lee, and Jinwoo Shin. Learning from failure: De-biasing classifier from biased classifier. Advances in Neural Information Processing Systems, 33:20673–20684, 2020.
  • Oord et al. (2018) Aaron van den Oord, Yazhe Li, and Oriol Vinyals. Representation learning with contrastive predictive coding. arXiv preprint arXiv:1807.03748, 2018.
  • Oquab et al. (2023) Maxime Oquab, Timothée Darcet, Théo Moutakanni, Huy Vo, Marc Szafraniec, Vasil Khalidov, Pierre Fernandez, Daniel Haziza, Francisco Massa, Alaaeldin El-Nouby, et al. Dinov2: Learning robust visual features without supervision. arXiv preprint arXiv:2304.07193, 2023.
  • Robinson et al. (2021) Joshua Robinson, Li Sun, Ke Yu, Kayhan Batmanghelich, Stefanie Jegelka, and Suvrit Sra. Can contrastive learning avoid shortcut solutions? arXiv preprint arXiv:2106.11230, 2021.
  • Rosenfeld et al. (2022) Elan Rosenfeld, Pradeep Ravikumar, and Andrej Risteski. Domain-adjusted regression or: Erm may already learn features sufficient for out-of-distribution generalization. arXiv preprint arXiv:2202.06856, 2022.
  • Sagawa et al. (2020a) Shiori Sagawa, Pang Wei Koh, Tatsunori B. Hashimoto, and Percy Liang. Distributionally robust neural networks for group shifts: On the importance of regularization for worst-case generalization, 2020a.
  • Sagawa et al. (2020b) Shiori Sagawa, Aditi Raghunathan, Pang Wei Koh, and Percy Liang. An investigation of why overparameterization exacerbates spurious correlations. In International Conference on Machine Learning, pp.  8346–8356. PMLR, 2020b.
  • Salman et al. (2022) Hadi Salman, Saachi Jain, Andrew Ilyas, Logan Engstrom, Eric Wong, and Aleksander Madry. When does bias transfer in transfer learning? arXiv preprint arXiv:2207.02842, 2022.
  • Scalbert et al. (2023) Marin Scalbert, Maria Vakalopoulou, and Florent Couzinié-Devy. Improving domain-invariance in self-supervised learning via batch styles standardization. arXiv preprint arXiv:2303.06088, 2023.
  • Selvaraju et al. (2016) Ramprasaath R Selvaraju, Abhishek Das, Ramakrishna Vedantam, Michael Cogswell, Devi Parikh, and Dhruv Batra. Grad-cam: Why did you say that? arXiv preprint arXiv:1611.07450, 2016.
  • Shah et al. (2020) Harshay Shah, Kaustav Tamuly, Aditi Raghunathan, Prateek Jain, and Praneeth Netrapalli. The pitfalls of simplicity bias in neural networks. Advances in Neural Information Processing Systems, 33:9573–9585, 2020.
  • Shen et al. (2022) Kendrick Shen, Robbie M Jones, Ananya Kumar, Sang Michael Xie, Jeff Z HaoChen, Tengyu Ma, and Percy Liang. Connect, not collapse: Explaining contrastive learning for unsupervised domain adaptation. In International Conference on Machine Learning, pp.  19847–19878. PMLR, 2022.
  • Singla & Feizi (2021) Sahil Singla and Soheil Feizi. Salient imagenet: How to discover spurious features in deep learning? arXiv preprint arXiv:2110.04301, 2021.
  • Song et al. (2019) Jiaming Song, Pratyusha Kalluri, Aditya Grover, Shengjia Zhao, and Stefano Ermon. Learning controllable fair representations. In The 22nd International Conference on Artificial Intelligence and Statistics, pp.  2164–2173. PMLR, 2019.
  • Tamkin et al. (2021) Alex Tamkin, Vincent Liu, Rongfei Lu, Daniel Fein, Colin Schultz, and Noah Goodman. Dabs: A domain-agnostic benchmark for self-supervised learning. arXiv preprint arXiv:2111.12062, 2021.
  • Torralba & Efros (2011) Antonio Torralba and Alexei A Efros. Unbiased look at dataset bias. In CVPR 2011, pp.  1521–1528. IEEE, 2011.
  • Tsai et al. (2020) Yao-Hung Hubert Tsai, Yue Wu, Ruslan Salakhutdinov, and Louis-Philippe Morency. Demystifying self-supervised learning: An information-theoretical framework. arXiv e-prints, pp.  arXiv–2006, 2020.
  • Tu et al. (2020) Lifu Tu, Garima Lalwani, Spandana Gella, and He He. An empirical study on robustness to spurious correlations using pre-trained language models. Transactions of the Association for Computational Linguistics, 8:621–633, 2020.
  • Van Horn et al. (2021) Grant Van Horn, Elijah Cole, Sara Beery, Kimberly Wilber, Serge Belongie, and Oisin Mac Aodha. Benchmarking representation learning for natural world image collections. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pp.  12884–12893, 2021.
  • Wah et al. (2011) Catherine Wah, Steve Branson, Peter Welinder, Pietro Perona, and Serge Belongie. The caltech-ucsd birds-200-2011 dataset. 2011.
  • Wang et al. (2021) Tan Wang, Zhongqi Yue, Jianqiang Huang, Qianru Sun, and Hanwang Zhang. Self-supervised learning disentangled group representation as feature. Advances in Neural Information Processing Systems, 34:18225–18240, 2021.
  • Wang & Culotta (2020) Zhao Wang and Aron Culotta. Identifying spurious correlations for robust text classification. arXiv preprint arXiv:2010.02458, 2020.
  • Yang et al. (2023) Yuzhe Yang, Haoran Zhang, Dina Katabi, and Marzyeh Ghassemi. Change is hard: A closer look at subpopulation shift. In International Conference on Machine Learning, 2023.
  • Zbontar et al. (2021) Jure Zbontar, Li Jing, Ishan Misra, Yann LeCun, and Stéphane Deny. Barlow twins: Self-supervised learning via redundancy reduction. In International Conference on Machine Learning, pp. 12310–12320. PMLR, 2021.
  • Zech et al. (2018) John R Zech, Marcus A Badgeley, Manway Liu, Anthony B Costa, Joseph J Titano, and Eric Karl Oermann. Variable generalization performance of a deep learning model to detect pneumonia in chest radiographs: a cross-sectional study. PLoS medicine, 15(11):e1002683, 2018.
  • Zhang et al. (2021) Dinghuai Zhang, Kartik Ahuja, Yilun Xu, Yisen Wang, and Aaron Courville. Can subnetwork structure be the key to out-of-distribution generalization? In International Conference on Machine Learning, pp.  12356–12367. PMLR, 2021.
  • Zhang et al. (2022) Michael Zhang, Nimit S Sohoni, Hongyang R Zhang, Chelsea Finn, and Christopher Ré. Correct-n-contrast: A contrastive approach for improving robustness to spurious correlations. arXiv preprint arXiv:2203.01517, 2022.
  • Zhou et al. (2022) Hattie Zhou, Ankit Vani, Hugo Larochelle, and Aaron Courville. Fortuitous forgetting in connectionist networks. arXiv preprint arXiv:2202.00155, 2022.

Appendix A Limitations

In this work, we validated our method on several benchmark datasets containing spurious correlations from prior work (Sagawa et al., 2020a; Liang & Zou, 2022; Yang et al., 2023). However, we recognize that the scale of these datasets is small relative to typical SSL training corpora (e.g., ImageNet (Deng et al., 2009)). As these large datasets do not contain annotations of spurious features, we are unable to evaluate our method in these settings. In addition, we primarily focus on SimSiam (Chen & He, 2020) in our experiments, as it does not rely on large batch sizes and shows improved performance for smaller datasets. Moreover, we expect LateTVG to perform best in cases where Siamese encoders coupled with a stop-gradient operation are used when learning the representations.

Appendix B LateTVG Algorithm

We provide an algorithm representation of our proposed method LateTVG in Section 5.1 as follows.

Algorithm 1 Self-supervised Learning with LateTVG
1: Inputs: Encoder $f$ parameterized by $\theta=\{W_1,\dots,W_n\}$, projection head and predictor $g$, augmentation module $\mathcal{T}$, threshold $L$, pruning rate $a$, training epochs $N$.
2: Initialize $\tilde{f}$ with $\theta$
3: for all $i=1\rightarrow N$ do
4:     Stage 1: Self-supervised Training
5:     for all $i=1\rightarrow N$ do
6:         Draw two random augmentations $t,t'\sim\mathcal{T}$
7:         $x_1=t(x),\ x_2=t'(x)$  ▷ Generate views $x_1,x_2$ from input $x$ using augmentation $t$
8:         $v_1=f(x_1)$  ▷ Obtain encoded features from normal encoder $f$
9:         $\tilde{v}_2=\tilde{f}(x_2)$  ▷ Obtain encoded features from transformed encoder $\tilde{f}$
10:         $\mathcal{L}=\text{Loss}(v_1,\tilde{v}_2;g)$  ▷ Calculate contrastive loss given views $v_1$ and $\tilde{v}_2$
11:         Update $f,g$ to minimize $\mathcal{L}$  ▷ Update the encoder and other SSL parameters
12:     Stage 2: Model Transformation
13:     Compute the mask $M_{L,a}=\{M^{l}_{L}\odot\text{Top}_{a}(W_{l})\mid l\in[n]\}$, where $\text{Top}_{a}(W_{l})_{i,j}=\mathbb{I}(|W_{l_{(i,j)}}|\text{ in top }a\%\text{ of }\theta)$
14:     Update $\tilde{f}$ with parameters $\tilde{\theta}=M_{L,a}\odot\theta$  ▷ Magnitude pruning of weights
15: Return encoder $\tilde{f}$
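
As an illustration of Stage 2, the following is a minimal pure-Python sketch of late-layer magnitude pruning. It is not the authors' implementation: the function names, the flat-list weight layout, and two readings of the notation are assumptions made here, namely that layers with index below the threshold are left untouched, and that `keep_fraction` is the share of largest-magnitude weights (ranked over all of $\theta$) that survive in the later layers.

```python
def late_magnitude_mask(layers, first_pruned_layer, keep_fraction):
    """Build one 0/1 mask per layer (each layer is a flat list of floats).

    Assumption: layers before `first_pruned_layer` keep every weight; later
    layers keep only weights whose magnitude is in the global top fraction.
    """
    magnitudes = sorted((abs(w) for layer in layers for w in layer), reverse=True)
    kept = max(1, round(keep_fraction * len(magnitudes)))
    cutoff = magnitudes[kept - 1]  # smallest magnitude that is still kept
    masks = []
    for index, layer in enumerate(layers):
        if index < first_pruned_layer:
            masks.append([1] * len(layer))
        else:
            masks.append([1 if abs(w) >= cutoff else 0 for w in layer])
    return masks


def apply_mask(layers, masks):
    """Element-wise product of weights and masks (theta_tilde = M * theta)."""
    return [[w * m for w, m in zip(layer, mask)] for layer, mask in zip(layers, masks)]


# Hypothetical three-layer "encoder" with two weights per layer.
toy_layers = [[0.1, -2.0], [0.5, -0.05], [3.0, 0.2]]
toy_masks = late_magnitude_mask(toy_layers, first_pruned_layer=1, keep_fraction=0.5)
toy_pruned = apply_mask(toy_layers, toy_masks)
```

In this toy case the first layer is unchanged even though it holds a small weight, while the two later layers each lose their smaller-magnitude entry.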

Appendix C Theoretical Analysis of Spurious Connectivity

Setup

We consider the pretext task of learning representations from unlabeled population data $\mathcal{X}$ consisting of unknown groups $\mathcal{G}$ which are not equally represented. For a given downstream task with labeled samples, we assume that each $x\in\mathcal{X}$ belongs to one of $c=|\mathcal{Y}|$ classes, and let $y:\mathcal{X}\rightarrow[c]$ denote the ground-truth labeling function. Let us define $a:\mathcal{X}\rightarrow[m]$ as the deterministic attribute function creating groups (of potentially different sizes) as $\mathcal{G}=\mathcal{Y}\times\mathcal{S}$.

Spectral Contrastive Learning

In order to investigate why the invariant feature can be suppressed in contrastive learning, we consider the setting from HaoChen et al. (2021) – Spectral Contrastive learning, which achieves similar empirical results to other contrastive learning methods and is easier for theoretical analysis.

Given the set of all natural data or data without any augmentation $\overline{\mathcal{X}}$, we use $\mathcal{A}(\cdot|\bar{x})$ to denote the distribution of augmentations of $\bar{x}\in\overline{\mathcal{X}}$. For instance, when $\bar{x}$ represents an image, $\mathcal{A}(\cdot|\bar{x})$ can be the distribution of common augmentations that includes Gaussian blur, color distortion and random cropping.

Let $\mathcal{P}_{\overline{\mathcal{X}}}$ be the population distribution over $\overline{\mathcal{X}}$ from which we draw training data and test our final performance. For any two augmented data points $x,x'\in\mathcal{X}$, the weight between a pair $w_{xx'}$ is the marginal probability of generating the pair $x$ and $x'$ from a random data point $\bar{x}\sim\mathcal{P}_{\overline{\mathcal{X}}}$:

$$w_{xx'}=\mathbb{E}_{\bar{x}\sim\mathcal{P}_{\overline{\mathcal{X}}}}\left[\mathcal{A}(x|\bar{x})\,\mathcal{A}(x'|\bar{x})\right]$$

Define expansion between two sets similar to HaoChen et al. (2021) as below:

$$\phi(S_1,S_2)=\frac{\sum_{x\in S_1,\,x'\in S_2}w_{xx'}}{\sum_{x\in S_1}w_x}$$

where $w_x=\sum_{x'\in\mathcal{S}}w_{xx'}$. We note that this is similar to our definition of connectivity, where we have assumed the marginal distribution over $x$ is uniform, or $w_x=\frac{1}{N}$.
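
To make the edge weights and the expansion concrete, here is a small sketch on a hypothetical discrete example with two natural data points and three augmented views. The data, the function names, and the choice to sum $w_x$ over all augmented points are assumptions made for illustration only.

```python
from collections import defaultdict


def edge_weights(population, augment):
    """w[x][x'] = E over natural points of A(x | x_bar) * A(x' | x_bar)."""
    weights = defaultdict(lambda: defaultdict(float))
    for natural, prob in population.items():
        views = augment[natural]
        for x, p_x in views.items():
            for x_prime, p_x_prime in views.items():
                weights[x][x_prime] += prob * p_x * p_x_prime
    return weights


def expansion(weights, set_one, set_two):
    """phi(S1, S2): weight crossing from S1 to S2, normalized by the total weight of S1."""
    cross = sum(weights[x].get(x_prime, 0.0) for x in set_one for x_prime in set_two)
    total = sum(sum(weights[x].values()) for x in set_one)
    return cross / total


# Hypothetical population: two natural images, equally likely.
population = {"a": 0.5, "b": 0.5}
# Hypothetical augmentation distributions over three augmented views u, v, w.
augment = {"a": {"u": 0.8, "v": 0.2}, "b": {"v": 0.4, "w": 0.6}}

w = edge_weights(population, augment)
phi = expansion(w, {"u"}, {"v", "w"})  # 0.08 / 0.4 = 0.2 up to rounding
```

Here view `v` can be produced from both natural images, so it is the only bridge between `u` and `w`; the expansion from `{u}` to the rest is small because most of the weight of `u` stays on itself.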

Toy Setup

Let the ground-truth labeling function $y$ and the deterministic attribute function $a$ determine the subgroup $g=(y(x),a(x))$ of a given sample $x$. We suppose we have $n$ samples from each subgroup, and that labels and attributes take binary values (for ease of notation and operations).

Suppose that each edge in the augmentation graph is given by connectivity terms $\alpha$, $\beta$, $\rho$, $\gamma$ as below:

$$\begin{aligned}
\forall x,x'\in\mathcal{X}:\quad P_{+}(x,x') &= \mathds{1}(a(x)=a(x'),\,y(x)=y(x'))\,\rho\\
&+\mathds{1}(a(x)\neq a(x'),\,y(x)=y(x'))\,\alpha\\
&+\mathds{1}(a(x)=a(x'),\,y(x)\neq y(x'))\,\beta\\
&+\mathds{1}(a(x)\neq a(x'),\,y(x)\neq y(x'))\,\gamma
\end{aligned}$$

We suppose that each edge in the augmentation graph is deterministically equal to one of the connectivity terms, and make the following assumptions:

  1. $\alpha>\gamma,\ \beta>\gamma$ – The probability that augmentation changes the spurious attribute only, or the class only, is greater than the probability that augmentation changes both attribute and class (at the same time).

  2. $\rho>\alpha,\ \rho>\beta$ – The probability that augmentation keeps both attribute and class is greater than the probability that it changes the spurious attribute only, or the class only.

  3. $\alpha>\beta$, or Assumption 3.2 – The probability that augmentation changes the spurious feature is higher than the probability of it changing the class, as observed in 4.

C.1 Proof of Lemma 3.3

Proof.

Let $A\in\mathbb{R}^{4n\times 4n}$ be the adjacency matrix of the simplified augmentation graph. It is easy to show that $A$ is equivalent to the adjacency matrix $\bar{A}$ up to a rotation, where:

$$\begin{aligned}
\bar{A}={}&(\beta-\gamma)\cdot I_2\otimes\left(\mathbf{1}_2\mathbf{1}_2^{\top}\right)\otimes\left(\mathbf{1}_n\mathbf{1}_n^{\top}\right)\\
&+(\alpha-\gamma)\cdot\left(\mathbf{1}_2\mathbf{1}_2^{\top}\right)\otimes I_2\otimes\left(\mathbf{1}_n\mathbf{1}_n^{\top}\right)\\
&+(\rho-\beta-\alpha+\gamma)\cdot I_4\otimes\left(\mathbf{1}_n\mathbf{1}_n^{\top}\right)\\
&+\gamma\cdot\left(\mathbf{1}_4\mathbf{1}_4^{\top}\right)\otimes\left(\mathbf{1}_n\mathbf{1}_n^{\top}\right)
\end{aligned}$$

where $\mathbf{1}_k$ denotes the all-one vector of dimension $k$, and $\bar{\mathbf{1}}_k$ is its normalized version.

For the case of $n=1$, it is easy to show that the matrix reduces to the adjacency matrix of 4 nodes, each in one group, where the first two rows/columns correspond to samples with the same spurious attribute, and odd or even rows correspond to samples that are from the same class, based on the placements of $\alpha$ and $\beta$ in the matrix.
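
For concreteness, substituting $n=1$ into the expression for $\bar{A}$ gives the explicit matrix below. The node ordering is an assumption made for illustration: attribute-class pairs are ordered as $(0,0)$, $(0,1)$, $(1,0)$, $(1,1)$.

```latex
\bar{A}\big|_{n=1} =
\begin{pmatrix}
\rho   & \beta  & \alpha & \gamma \\
\beta  & \rho   & \gamma & \alpha \\
\alpha & \gamma & \rho   & \beta  \\
\gamma & \alpha & \beta  & \rho
\end{pmatrix}
```

Under this ordering, rows 1 and 2 share the spurious attribute (entry $\beta$) and rows 1 and 3 share the class (entry $\alpha$), which agrees with the definition of $P_{+}$ in the toy setup.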

Let $F$ be an embedding matrix with $u_x$ on the $x$-th row, which corresponds to the embedding of sample $x$, and consider the matrix-factorization form of the spectral contrastive loss below:

$$\min_{F\in\mathbb{R}^{N\times k}}\mathcal{L}_{\mathrm{mf}}(F):=\left\|\bar{A}-FF^{\top}\right\|_F^2$$

It is enough to compute the eigenvectors of $\bar{A}$ to obtain $F$. It is easy to compute the eigenvectors of $\bar{A}$ similar to Shen et al. (2022). The four sets of eigenvectors are as below:

  • For eigenvalue $\lambda_1=\rho+\beta+\alpha+\gamma$, the eigenvector is $\bar{\mathbf{1}}_2\otimes\bar{\mathbf{1}}_2\otimes\bar{\mathbf{1}}_n$.

  • For eigenvalue $\lambda_2=\rho+\beta-\alpha-\gamma$, the eigenvectors are $[1\;-1]^{\top}\otimes\bar{\mathbf{1}}_2\otimes\bar{\mathbf{1}}_n$.

  • For eigenvalue $\lambda_3=\rho-\beta+\alpha-\gamma$, the eigenvectors are $\bar{\mathbf{1}}_2\otimes[1\;-1]^{\top}\otimes\bar{\mathbf{1}}_n$.

  • $\lambda_4=\rho-\beta-\alpha+\gamma$, which is smaller than the first three eigenvalues, given the above assumptions.

Thus $F$ would be a rank-3 matrix with columns equal to $\sqrt{\lambda_i}$ multiplied by each eigenvector. Given the case of $n=1$ explained above and by induction, it is easy to show that $\lambda_2$ corresponds to the spurious attribute subspace, and $\lambda_3$ corresponds to the class. Projecting samples in $\bar{A}$ with representations as rows of $F$ onto the spurious subspace suggests that the spurious feature takes two values $\{-\sqrt{\lambda_2},\sqrt{\lambda_2}\}$, and similarly, the invariant feature takes two values $\{-\sqrt{\lambda_3},\sqrt{\lambda_3}\}$ in the representation space learned by spectral contrastive loss. ∎

Intuitively, this means that with higher spurious connectivity (that is, higher weights on edges connecting images that only share the same spurious attribute), spectral clustering will learn representations of the population data based on the spurious feature, rather than the invariant feature.
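
The eigenpairs used in the proof of Lemma 3.3 can be checked numerically. The sketch below builds the toy adjacency matrix directly from the four connectivity terms and verifies $\bar{A}v=\lambda v$ for each candidate. The connectivity values, the node ordering (attribute, class, sample index), and the sign-pattern eigenvector used for $\lambda_4$ are assumptions chosen for illustration. Note that with $n$ samples per group the matrix $\mathbf{1}_n\mathbf{1}_n^{\top}$ contributes a factor of $n$, so the check uses $n\lambda_i$; for $n=1$ this coincides with the closed forms above.

```python
from itertools import product


def toy_adjacency(alpha, beta, rho, gamma, n=1):
    """Adjacency matrix of the toy augmentation graph with n samples per group.

    Nodes are ordered by (attribute, class, sample index).
    """
    nodes = [(a, y) for a, y in product((0, 1), repeat=2) for _ in range(n)]

    def weight(p, q):
        same_attribute, same_class = p[0] == q[0], p[1] == q[1]
        if same_attribute and same_class:
            return rho
        if same_class:
            return alpha  # only the spurious attribute differs
        if same_attribute:
            return beta  # only the class differs
        return gamma

    return nodes, [[weight(p, q) for q in nodes] for p in nodes]


def eigen_residual(matrix, vector, eigenvalue):
    """Largest entry of |A v - lambda v|; near zero means (lambda, v) is an eigenpair."""
    return max(
        abs(sum(m * v for m, v in zip(row, vector)) - eigenvalue * x)
        for row, x in zip(matrix, vector)
    )


def sign(bit):
    return 1.0 if bit == 0 else -1.0


# Hypothetical values satisfying the three assumptions of the toy setup.
alpha, beta, rho, gamma, n = 0.2, 0.1, 0.6, 0.05, 3
nodes, A = toy_adjacency(alpha, beta, rho, gamma, n)
candidates = {
    "lambda_1": (rho + beta + alpha + gamma, [1.0 for _ in nodes]),
    "lambda_2": (rho + beta - alpha - gamma, [sign(a) for a, _ in nodes]),
    "lambda_3": (rho - beta + alpha - gamma, [sign(y) for _, y in nodes]),
    "lambda_4": (rho - beta - alpha + gamma, [sign(a) * sign(y) for a, y in nodes]),
}
residuals = {
    name: eigen_residual(A, vector, n * value)
    for name, (value, vector) in candidates.items()
}
print(residuals)
```

The vector paired with $\lambda_2$ changes sign with the attribute and the one paired with $\lambda_3$ changes sign with the class, which is the correspondence stated in the proof.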

Appendix D Data and Models

D.1 Datasets

We make use of the following five image datasets:

  • celebA (Liu et al., 2015): Gender (Male, Female) is spuriously correlated with Hair color (blond hair, not blond hair).

  • waterbirds (Sagawa et al., 2020a): Background (land, water) is spuriously correlated with bird type (landbird, waterbird).

  • cmnist (Colored MNIST): The color of the digit on the images is spuriously correlated with the binary class based on the number. This is the same setup as Arjovsky et al. (2019), except with no label flipping.

  • spurcifar10 (Spurious CIFAR10) (Nagarajan et al., 2020): The color of lines on the images is spuriously correlated with the class.

  • metashift (Liang & Zou, 2022): We consider the Cats vs Dogs task where Background (indoor, outdoor) is spuriously correlated with pet type (cat, dog).

Note that each dataset contains both labels (or core attribute) $y$ and a spurious attribute $a$. We then use the group information $g=(y,a)$ to partition dataset splits into groups.
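
The group partition $g=(y,a)$ is what the average and worst-group accuracies reported in the tables are computed over. A minimal sketch of that bookkeeping follows, assuming the usual definition of worst-group accuracy as the minimum per-group accuracy; the function names and the toy labels are hypothetical.

```python
from collections import defaultdict


def group_accuracies(labels, attributes, predictions):
    """Accuracy for each group g = (y, a)."""
    correct, total = defaultdict(int), defaultdict(int)
    for y, a, prediction in zip(labels, attributes, predictions):
        group = (y, a)
        total[group] += 1
        correct[group] += int(prediction == y)
    return {group: correct[group] / total[group] for group in total}


def average_and_worst_group(labels, attributes, predictions):
    """Overall accuracy and the minimum accuracy over groups."""
    per_group = group_accuracies(labels, attributes, predictions)
    average = sum(int(p == y) for y, p in zip(labels, predictions)) / len(labels)
    return average, min(per_group.values())


# Hypothetical split with binary labels and binary spurious attributes.
labels = [0, 0, 0, 1, 1, 1]
attributes = [0, 0, 1, 1, 0, 0]
predictions = [0, 0, 1, 1, 0, 1]
average, worst = average_and_worst_group(labels, attributes, predictions)
```

In this toy split the average accuracy is 4/6, while the single sample in group $(0,1)$ is misclassified, so the worst-group accuracy is 0.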

D.2 Methods and Hyperparameters

We use SimSiam (Chen & He, 2020) with ResNet encoders to train both base models and LateTVG. We select ResNet-18 models as the backbone for all datasets except celebA, for which we use ResNet-50 models.

For each dataset, we use the following set of hyperparameters for SimSiam training.

| Dataset | Learning Rate | Batch Size | Weight Decay | Number of Epochs |
|---|---|---|---|---|
| celebA | 0.01 | 128 | 1e-4 | 400 |
| cmnist | 1e-3 | 128 | 1e-5 | 1000 |
| metashift | 0.05 | 256 | 0.001 | 400 |
| spurcifar10 | 0.02 | 128 | 5e-4 | 800 |
| waterbirds | 0.01 | 64 | 1e-3 | 800 |

The specific augmentations that we use for learning the representations are the same as those in the SimSiam paper (Chen & He, 2020), but without color jitter.

Note that the model architecture and parameters for SSL-Base and SSL-Late-TVG are exactly the same, but SSL-Late-TVG uses the pruning hyperparameters to prune the encoder during training.

Computational Cost

The SSL-LateTVG model updates the same number of parameters as SSL-Base during training, with the forward pass keeping both the original and pruned encoder. The pruning operation costs O(n), where n is the number of parameters, so any FLOPs used for the extra pruning mechanism will be very small compared to a single forward pass.

D.3 The Role of Downstream Regularization

We investigate the impact of regularization techniques during downstream linear probing. Interestingly, we find that the presence and type of regularization have a notable effect on the accuracy of the worst-performing group, with improvements of approximately 10% on the celebA dataset and 7% on the metashift dataset. We hypothesize that the minority samples contribute more to the variance of the linear models, and the additional regularization helps penalize them, leading to a reduction in the variance of the downstream models.

Table 5: Accuracy (%) of SimSiam models pretrained on each dataset with random initialization.
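| Dataset | Average (None) | Average (L1) | Average (L2) | Worst-Group (None) | Worst-Group (L1) | Worst-Group (L2) |
|---|---|---|---|---|---|---|
| celebA | 78.5 | 81.9 | 82.8 | 66.1 | 77.5 | 76.7 |
| cmnist | 80.5 | 80.8 | 82.7 | 76.3 | 78.9 | 81.3 |
| metashift | 54.2 | 59.8 | 56.3 | 41.8 | 48.1 | 45.6 |
| spurcifar10 | 69.3 | 72.5 | 73.4 | 41.1 | 45.1 | 49.0 |
| waterbirds | 52.0 | 51.2 | 50.5 | 47.4 | 47.5 | 47.2 |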

Appendix E Measuring Spurious Connectivity in Augmentations

In this section, we present our methodology for measuring spurious connectivity in augmentations. We conduct experiments on four datasets, and our goal is to quantify the extent to which samples within the training set are connected to each other through the spurious attribute, as opposed to the core feature.

To estimate the average connectivity between two groups, denoted as $g_1$ and $g_2$, specified by class-attribute pairs $(y,a)$ and $(y',a')$, we follow the algorithm outlined below:

Initially, we label all training examples belonging to group $g_1$ (class $y$ and attribute $a$) as 0, and all training examples belonging to group $g_2$ (class $y'$ and attribute $a'$) as 1. Next, we train a classifier to distinguish between the two groups. The error of this classifier serves as a proxy for “the probability of augmented images being assigned to the other group”, or how close they are in the augmentation space. Instead of training a large classifier from scratch for each pair, we use CLIP’s representations in Section 4.2, and assume that it extracts all necessary features for distinguishing between the two groups. In Section 5.2.3, we instead use the representations learned by each SSL model.

We train a linear model on these features to distinguish between each of the two groups. It is important to note that the augmentations used in our experiments are the classical augmentations commonly employed in SimSiam, excluding Gaussian blur. Subsequently, we create the test set following a similar process, where images are labeled based on their group or class-attribute pairs. The trained linear classifier is evaluated on this strongly augmented test set. The test error of the classifier serves as an estimate for the connectivity between the two pairs, providing insights into the degree of connectivity based on the spurious attribute.

By applying this methodology to all four datasets, we obtain results regarding the average spurious connectivity compared to the invariant connectivity. Table 4 summarizes the findings, revealing that, across all datasets, the average spurious connectivity is higher than the invariant connectivity. Furthermore, we validate that both these connectivity values are higher than the probability of simultaneously changing both the spurious attribute and the invariant attribute. These observations indicate that the samples within the training set are more likely to be connected to each other through the spurious attribute, rather than the core feature. This finding suggests a preference of the contrastive loss for alignment based on the spurious attribute rather than class alignment.
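
The estimation procedure described in this appendix can be sketched in a few lines. The sketch below is not the experimental pipeline: it replaces CLIP or SSL features of augmented images with synthetic Gaussian feature vectors, and it uses a two-class nearest-centroid rule (which has a linear decision boundary) as a stand-in for the linear model. All names and values are hypothetical.

```python
import random


def sample_features(mean, count, rng):
    """Draw synthetic feature vectors around `mean` (stand-in for features of augmented images)."""
    return [[rng.gauss(m, 1.0) for m in mean] for _ in range(count)]


def centroid(points):
    return [sum(column) / len(points) for column in zip(*points)]


def squared_distance(u, v):
    return sum((a - b) ** 2 for a, b in zip(u, v))


def connectivity_estimate(train_g1, train_g2, test_g1, test_g2):
    """Test error of a nearest-centroid classifier separating group 1 (label 0) from group 2 (label 1)."""
    c1, c2 = centroid(train_g1), centroid(train_g2)

    def predict(x):
        return 0 if squared_distance(x, c1) <= squared_distance(x, c2) else 1

    errors = sum(predict(x) != 0 for x in test_g1) + sum(predict(x) != 1 for x in test_g2)
    return errors / (len(test_g1) + len(test_g2))


rng = random.Random(0)


def estimate(mean_g1, mean_g2, count=400):
    """Sample train and test features for both groups and return the test error."""
    return connectivity_estimate(
        sample_features(mean_g1, count, rng),
        sample_features(mean_g2, count, rng),
        sample_features(mean_g1, count, rng),
        sample_features(mean_g2, count, rng),
    )


close_groups = estimate([0.0, 0.0], [0.5, 0.5])      # heavily overlapping features
distant_groups = estimate([0.0, 0.0], [10.0, 10.0])  # well-separated features
print(close_groups, distant_groups)
```

A pair of groups whose augmented features overlap yields a high test error, which is read as high connectivity; a well-separated pair yields an error near zero.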

Appendix F Additional Results for LateTVG

F.1 LateTVG reduces background reliance in Hard ImageNet

We evaluate LateTVG on the Hard ImageNet dataset (Moayeri et al., 2022), which consists of 15 challenging ImageNet classes where models rely heavily on spurious correlations. The authors provide spuriousness rankings that enable creating a balanced subset.

In our experiments, we train the SSL model on the full Hard ImageNet train split, and train the linear classifier on the spurious-balanced subset. This tests the model’s ability to learn representations without exploiting spurious cues.

We then evaluate the downstream classifier on four different dataset splits as below:

  • None: Original test split

  • Gray: The object region is grayed out by replacing RGB values with the mean RGB value. This removes texture/color cues.

  • Gray BBox: The object region is removed by replacing it with the mean RGB value of the surrounding bounding box region. This ablates shape cues.

  • Tile: The object region is replaced by tiling the surrounding bounding box region. This also ablates shape cues.

A classifier relying on the spurious (i.e. non-object) features will achieve high performance in all evaluation splits. However, a classifier relying on the invariant features should perform decently on the original test split, but exhibit greatly reduced accuracy on the other splits. Thus, we desire high accuracy for the None split, and low accuracy for the other three splits.

Comparing the results to Section 7 of Moayeri et al. (2022), we find that the gap between None and the other three splits is already large in SSL-Base, and SSL-LateTVG further decreases the accuracy on the spurious datasets. This shows that the SSL-LateTVG encoder relies less on the spurious feature to predict the labels, which degrades the performance on splits that try to remove the core feature.

We do not tune the hyperparameters in this experiment, but we find that for all sets of hyperparameters, SSL-LateTVG results in lower downstream accuracy on Gray, Gray BBox, and Tile splits as shown in Table 6.

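| Algorithm | Pruning threshold, percentage | None $\uparrow$ | Gray $\downarrow$ | Gray BBox $\downarrow$ | Tile $\downarrow$ |
|---|---|---|---|---|---|
| SSL-Base | - | 79.5 | 61.6 | 53.5 | 58.1 |
| SSL-LateTVG | 46, 0.5 | 78.0 | 59.5 | 51.1 | 52.1 |
| SSL-LateTVG | 47, 0.5 | 76.7 | 59.1 | 49.6 | 54.4 |
| SSL-LateTVG | 48, 0.8 | 73.9 | 56.1 | 48.0 | 51.3 |
| SSL-LateTVG | 49, 0.8 | 68.4 | 50.7 | 42.4 | 44.7 |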

Table 6: We train SimSiam models with a ResNet-50 backbone on unlabeled data from Hard ImageNet containing spurious correlations. We then train the downstream linear classifier on a balanced subset and evaluate the downstream model on splits containing spurious features. LateTVG degrades the performance on these splits, without hyperparameter tuning.

F.2 LateTVG closes the gap to supervised pre-training

Self-supervised pretraining has shown a lot of promise in bridging the gap to supervised approaches in general representation learning. In this section, we explore whether this trend holds true for pre-training with data containing spurious correlations. To perform this analysis, we start with the same encoder model and vary only the pretraining strategy while fixing other aspects of the training, such as hyperparameter selection and model selection.

Table 7: Accuracy (%) of SSL models pre-trained on each dataset versus features of a supervised model: Representations obtained from the supervised featurizer are more predictive of the core feature than SimCLR and SimSiam featurizers

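| Dataset | Average Accuracy (SimCLR) | Average Accuracy (SimSiam) | Average Accuracy (Supervised) | Worst-Group Accuracy (SimCLR) | Worst-Group Accuracy (SimSiam) | Worst-Group Accuracy (Supervised) |
|---|---|---|---|---|---|---|
| celebA | 82.1 | 81.9 | 91.9 | 76.7 | 77.5 | 81.7 |
| cmnist | 82.5 | 82.1 | 98.4 | 81.7 | 80.7 | 94.9 |
| metashift | 55.1 | 55.8 | 89.8 | 45.5 | 42.3 | 83.5 |
| spurcifar10 | 69.3 | 75.1 | 89.9 | 36.5 | 43.4 | 79.6 |
| waterbirds | 47.5 | 50.7 | 67.9 | 43.8 | 48.3 | 41.1 |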

We emphasize that this is an unfair comparison to begin with, since supervised pretraining requires labeled data whereas SSL does not, hence reducing the annotation budget drastically, as shown in Table 7. However, the goal of this experiment is to understand to what extent SSL models, and specifically LateTVG, compare with ERM-based supervised pretraining strategies.

Table 8 shows the results of our experiment: we compare both average and worst-group accuracies for the SSL-based and ERM-based encoders across all our evaluation datasets. In terms of worst-group accuracy, LateTVG narrows the gap between the SSL baseline and the ERM model significantly, from a 17% relative improvement for cmnist to 50% for spurcifar10. In the case of celebA, we even outperform the ERM baseline. Similar to previous experiments, the relative boost in performance from LateTVG is higher in cases where the base encoder is weaker, indicating the strength of our final-layer augmentation in extracting useful signal relevant to the core features during pretraining.

Table 8: LateTVG with SimSiam closes the gap between SSL baseline and supervised pre-training on worst group and average accuracy.

             Average Accuracy                      Worst-group Accuracy
Dataset      SSL-base  SSL-Late-TVG  Supervised    SSL-base  SSL-Late-TVG  Supervised
celebA       81.9      88.9          91.9          77.5      83.1          81.7
cmnist       82.1      80.6          98.4          80.7      83.1          94.9
metashift    55.8      70.1          89.8          42.3      79.6          83.5
spurcifar10  75.1      76.1          89.9          43.4      61.4          79.6
waterbirds   50.7      54.8          67.9          48.3      56.3          41.1
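One reading of the "relative improvement" figures quoted in this section, consistent with both the 17% and 50% values, is the fraction of the SSL-to-supervised worst-group gap that LateTVG closes. The sketch below recomputes that quantity from the worst-group columns of Table 8. The function name `gap_closed` and this interpretation are assumptions, not definitions taken from the paper.

```python
def gap_closed(base, method, supervised):
    """Fraction of the (supervised - base) gap recovered by `method`.

    Assumed interpretation of "relative improvement"; not defined in the paper.
    """
    gap = supervised - base
    if gap <= 0:
        raise ValueError("baseline already matches or exceeds the supervised model")
    return (method - base) / gap


# Worst-group accuracies (SSL-base, SSL-Late-TVG, Supervised) copied from Table 8.
worst_group = {
    "cmnist": (80.7, 83.1, 94.9),
    "spurcifar10": (43.4, 61.4, 79.6),
}

for name, (base, method, supervised) in worst_group.items():
    pct = 100 * gap_closed(base, method, supervised)
    print(f"{name}: {pct:.0f}% of the worst-group gap closed")
```

Under this reading, cmnist and spurcifar10 come out at roughly 17% and 50%, matching the figures quoted in the text.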

F.3 SSL-LateTVG outperforms baseline across hyper-parameter settings

Disrupting the features and creating new views of the pairs is possible even with small amounts of pruning. We run a grid search over the last three to five convolutional layers of ResNet models, depending on the dataset, with pruning percentages chosen from [0.5, 0.7, 0.8, 0.9, 0.95]. We find that on the metashift dataset, all hyperparameter settings improve the worst-group accuracy and outperform the baseline. Average and worst-group accuracies for different pruning hyperparameters on the metashift and celebA datasets are shown in Figure 4.
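The grid described above can be enumerated as (threshold layer, pruning percentage) pairs. This is a minimal sketch: the helper `pruning_grid`, its parameters, and the 1-indexed layer numbering are hypothetical, and the 49-layer, last-four-layers setting is chosen only because it reproduces the threshold values 46 to 49 that appear in Table 6.

```python
from itertools import product

# Pruning percentages listed in the text.
PRUNE_PERCENTAGES = [0.5, 0.7, 0.8, 0.9, 0.95]


def pruning_grid(num_conv_layers, last_k, percentages=PRUNE_PERCENTAGES):
    """Enumerate (threshold_layer, percentage) pairs over the last `last_k` layers.

    Layers are numbered 1..num_conv_layers (an assumption of this sketch).
    """
    if not 1 <= last_k <= num_conv_layers:
        raise ValueError("last_k must be between 1 and num_conv_layers")
    thresholds = range(num_conv_layers - last_k + 1, num_conv_layers + 1)
    return list(product(thresholds, percentages))


# Hypothetical setting: 49 convolutional layers, search over the last four.
grid = pruning_grid(num_conv_layers=49, last_k=4)
print(len(grid), grid[0], grid[-1])
```

Each pair would then be used to pretrain one SSL-LateTVG model, so the size of the grid grows linearly in both the number of candidate layers and the number of percentages.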

Figure 4: Downstream worst-group accuracy of SSL-Late-TVG on the metashift (left) and celebA (right) datasets as we vary the model pruning hyperparameters.

Additionally, instead of choosing the best-performing model, we consider the top 5 models across different pruning hyperparameters and report their performance in Table 9. Even in this scenario, we observe large performance gains with LateTVG.
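The top-5 reporting described above can be sketched as follows. The sketch assumes that configurations are ranked by downstream worst-group accuracy and that the ± values are sample standard deviations; neither detail is stated in the text. The helper `summarize_top_k` is hypothetical and the scores are synthetic placeholders, not results from the paper.

```python
from statistics import mean, stdev


def summarize_top_k(results, k=5):
    """Return (mean, sample std) of the k highest scores in `results`.

    `results` maps a (threshold_layer, percentage) configuration to a score,
    for example a downstream worst-group accuracy.
    """
    if k < 2 or len(results) < k:
        raise ValueError("need k >= 2 and at least k configurations")
    top = sorted(results.values(), reverse=True)[:k]
    return mean(top), stdev(top)


# Synthetic placeholder scores, NOT results from the paper.
toy_results = {
    (46, 0.5): 60.0,
    (46, 0.8): 58.0,
    (47, 0.5): 62.0,
    (47, 0.9): 55.0,
    (48, 0.7): 61.0,
    (48, 0.95): 50.0,
    (49, 0.5): 59.0,
}

top5_mean, top5_std = summarize_top_k(toy_results)
print(f"{top5_mean:.2f} ± {top5_std:.2f}")
```

Reporting an aggregate over several configurations, rather than the single best one, reduces the influence of hyperparameter selection on the comparison.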

Table 9: LateTVG improves baseline worst group and average accuracy of SSL models.

Worst-group Accuracy
Algorithm     celebA        cmnist        metashift     spurcifar10   waterbirds
SSL-Base      77.52         80.7 ± 2.71   42.33 ± 2.32  43.44 ± 8.87  48.3 ± 1.82
SSL-Late-TVG  81.83 ± 1.75  77.18 ± 1.59  60.34 ± 0.97  54.58 ± 1.74  51.87 ± 2.37

Average Accuracy
Algorithm     celebA        cmnist        metashift     spurcifar10   waterbirds
SSL-Base      81.94         82.08 ± 1.17  55.8 ± 2.11   75.05 ± 0.19  50.68 ± 1.27
SSL-Late-TVG  87.32 ± 1.46  79.74 ± 1.19  69.7 ± 2.09   75.68 ± 0.72  55.36 ± 0.72

F.4 What features does LateTVG learn?

Recall that we motivated LateTVG by explaining that more difficult features could be learned in the later layers of an encoder, and that by removing the spurious feature from the encoder, we force the model to learn more complex features. In this section, we use Grad-CAM (Selvaraju et al., 2016) to compare SSL-Base and SSL-LateTVG. We take the representations that SSL-Base and SSL-LateTVG learn for metashift and use them to visualize the final layer of the encoder. We choose the best-performing LateTVG model based on downstream worst-group accuracy. We visualize the parts of the image that SSL-Base and LateTVG attend to for majority (Figure 5) and minority (Figure 6) groups.
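For reference, the standard Grad-CAM heatmap for a class score $y^{c}$ and the feature maps $A^{k}$ of a chosen convolutional layer is given below, where $Z$ is the number of spatial positions in each feature map. How $y^{c}$ is obtained from the self-supervised encoder (for example, from the downstream linear classifier) is not stated in this section and is an assumption here.

```latex
\alpha_k^{c} = \frac{1}{Z} \sum_{i} \sum_{j} \frac{\partial y^{c}}{\partial A^{k}_{ij}},
\qquad
L^{c}_{\text{Grad-CAM}} = \mathrm{ReLU}\!\left( \sum_{k} \alpha_k^{c} A^{k} \right)
```

The weights $\alpha_k^{c}$ are the spatially averaged gradients of the class score, and the ReLU keeps only the regions that contribute positively to it.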

Figure 5: We use Grad-CAM to explain the ResNet-18 SSL-base (top), and SSL-LateTVG model (bottom) for majority examples
Figure 6: We use Grad-CAM to explain the ResNet-18 SSL-base (top), and SSL-LateTVG model (bottom) for minority examples

F.5 Additional Downstream Imbalance Results

For both the best downstream linear model chosen based on worst-group accuracy, and linear models with no regularization, we observe the same trend for the datasets shown in Figure 7.

Figure 7: Effect of changing minority weight in downstream training set on metashift. Left: no regularization. Right: downstream hyperparameters tuned.

Appendix G Spurious Learning in Self-supervised Representations

G.1 Additional Re-sampling Results

We present the complete table from the experiment in Section 4.4.

Table 10: Downstream accuracy (%) of linear models. For each dataset, we pre-train the model on the original, up-sampled, down-sampled, and balanced training sets.
Dataset      SSL Train Set  Average  Worst Group
celebA       Balanced       86.4     75.8
celebA       Downsampled    83.2     77.8
celebA       Original       81.9     77.5
celebA       Upsampled      86.3     81.6
cmnist       Balanced       75.4     72.0
cmnist       Downsampled    74.7     70.1
cmnist       Original       82.1     80.7
cmnist       Upsampled      77.7     75.4
metashift    Balanced       60.7     38.5
metashift    Downsampled    55.7     46.2
metashift    Original       55.8     42.3
metashift    Upsampled      64.4     45.1
spurcifar10  Balanced       68.7     35.2
spurcifar10  Downsampled    53.1     29.0
spurcifar10  Original       75.0     43.4
spurcifar10  Upsampled      57.4     24.1
waterbirds   Balanced       53.1     51.3
waterbirds   Downsampled    51.0     48.8
waterbirds   Original       50.7     48.3
waterbirds   Upsampled      55.2     48.0
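As a generic illustration of group re-sampling, the sketch below equalizes group sizes either by subsampling every group to the smallest group or by sampling minority groups with replacement up to the largest group. The exact construction of the up-sampled, down-sampled, and balanced sets in Table 10 is not given in this appendix, so the function `resample_by_group` and both modes are assumptions.

```python
import random
from collections import Counter, defaultdict


def resample_by_group(groups, mode, seed=0):
    """Return sorted example indices after equalizing group sizes.

    mode="down": subsample every group to the smallest group size.
    mode="up":   add samples (with replacement) up to the largest group size.
    """
    rng = random.Random(seed)
    by_group = defaultdict(list)
    for idx, group in enumerate(groups):
        by_group[group].append(idx)
    sizes = [len(members) for members in by_group.values()]

    picked = []
    if mode == "down":
        target = min(sizes)
        for members in by_group.values():
            picked.extend(rng.sample(members, target))
    elif mode == "up":
        target = max(sizes)
        for members in by_group.values():
            extra = rng.choices(members, k=target - len(members))
            picked.extend(members + extra)
    else:
        raise ValueError("mode must be 'down' or 'up'")
    return sorted(picked)


# Toy group labels: 8 majority examples and 2 minority examples.
toy_groups = ["majority"] * 8 + ["minority"] * 2
down = resample_by_group(toy_groups, "down")
up = resample_by_group(toy_groups, "up")
print(Counter(toy_groups[i] for i in down), Counter(toy_groups[i] for i in up))
```

Note that this kind of re-sampling needs group annotations for the pretraining data, which LateTVG does not require.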

G.2 ImageNet Pre-trained Self-supervised Models

We obtain pre-trained ResNet-50 encoders with SimSiam, SimCLR, and CLIP training strategies, and evaluate the accuracy of core feature prediction as in the previous sections.

Table 11: Worst-group Accuracy (%) of ImageNet pre-trained models when evaluated on each dataset using a linear probe.
Dataset     CLIP  ERM_in  SimCLR_in  SimSiam_in
celebA      87.2  79.4    87.2       84.4
cmnist      88.9  85.4    85.5       86.4
metashift   83.1  83.1    79.7       69.5
waterbirds  81.1  84.3    78.8       79.3
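Worst-group accuracy, the metric reported in Table 11 and throughout this appendix, is the minimum per-group accuracy. A minimal sketch of the computation from linear-probe predictions follows; the function names and the toy inputs are illustrative and not taken from the paper's code.

```python
from collections import defaultdict


def group_accuracies(preds, labels, groups):
    """Accuracy of `preds` against `labels`, computed separately per group."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for pred, label, group in zip(preds, labels, groups):
        total[group] += 1
        correct[group] += int(pred == label)
    return {group: correct[group] / total[group] for group in total}


def worst_group_accuracy(preds, labels, groups):
    """Minimum accuracy over all groups present in `groups`."""
    return min(group_accuracies(preds, labels, groups).values())


# Toy inputs for illustration only.
toy_preds = [1, 1, 0, 0, 1, 0]
toy_labels = [1, 1, 0, 1, 1, 1]
toy_group_ids = ["a", "a", "a", "b", "b", "b"]

per_group = group_accuracies(toy_preds, toy_labels, toy_group_ids)
wga = worst_group_accuracy(toy_preds, toy_labels, toy_group_ids)
print(per_group, wga)
```

In the toy example the average accuracy is 4/6 while the worst group reaches only 1/3, which is the kind of disparity that average accuracy alone hides.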