1 University of Illinois at Urbana Champaign, 2 Stanford University

Causally Inspired Regularization Enables Domain General Representations

Olawale Salaudeen Contact: [email protected] Sanmi Koyejo
Abstract

Given a causal graph representing the data-generating process shared across different domains/distributions, enforcing sufficient graph-implied conditional independencies can identify domain-general (non-spurious) feature representations. For the standard input-output predictive setting, we categorize the set of graphs considered in the literature into two distinct groups: (i) those in which the empirical risk minimizer across training domains gives domain-general representations and (ii) those where it does not. For the latter case (ii), we propose a novel framework with regularizations, which we demonstrate are sufficient for identifying domain-general feature representations without a priori knowledge (or proxies) of the spurious features. Empirically, our proposed method is effective for both (semi) synthetic and real-world data, outperforming other state-of-the-art methods in average and worst-domain transfer accuracy.

1 Introduction

A key feature of machine learning is its capacity to generalize across new domains. When these domains present different data distributions, the algorithm must leverage shared structural concepts to achieve out-of-distribution (OOD) or out-of-domain generalization. This capability is vital in numerous important real-world machine learning applications. For example, in safety-critical settings such as autonomous driving, a lack of resilience to unfamiliar distributions could lead to human casualties. Likewise, in the healthcare sector, where ethical considerations are critical, an inability to adjust to shifts in data distribution can result in unfair biases, manifesting as inconsistent performance across different demographic groups.

An influential approach to domain generalization is Invariant Causal Prediction (ICP; Peters et al., 2016). ICP posits that although some aspects of data distributions (like spurious or non-causal mechanisms (Pearl, 2010)) may change across domains, certain causal mechanisms remain constant. ICP suggests focusing on these invariant mechanisms for prediction. However, the estimation method for these invariant mechanisms suggested by Peters et al. (2016) struggles with scalability in high-dimensional feature spaces. To overcome this, Arjovsky et al. (2019) introduced Invariant Risk Minimization (IRM), designed to identify these invariant mechanisms by minimizing an objective. However, IRM requires strong assumptions for identifying the desired domain-general solutions (Ahuja et al., 2021; Rosenfeld et al., 2022); for instance, observing a number of domains proportional to the spurious features’ dimensions is necessary, posing a significant challenge in these high-dimensional settings.

Subsequent variants of IRM have been developed with improved capabilities for identifying domain-general solutions (Ahuja et al., 2020; Krueger et al., 2021; Robey et al., 2021; Wang et al., 2022; Ahuja et al., 2021). Additionally, regularizers for Distributionally Robust Optimization with subgroup shift have been proposed (GroupDRO) (Sagawa et al., 2019). However, despite their solid theoretical motivation, empirical evidence suggests that these methods may not consistently deliver domain-general solutions in practice (Gulrajani and Lopez-Paz, 2020; Kaur et al., 2022; Rosenfeld et al., 2022).

Kaur et al. (2022) demonstrated that regularizing directly for conditional independencies implied by the generative process can give domain-general solutions, including conditional independencies beyond those considered by IRM. However, their experimental approach involves regularization terms that require direct observation of spurious features, a condition not always feasible in real-world applications. Our proposed methodology also leverages regularizers inspired by the conditional independencies indicated by causal graphs but, crucially, it does so without necessitating prior knowledge (or proxies) of the spurious features.

1.1 Contributions

In this work,

  • we outline sufficient properties to uniquely identify domain-general predictors for a general set of generative processes that include domain-correlated spurious features,

  • we propose regularizers to implement these constraints without independent observations of the spurious features, and

  • finally, we show that the proposed framework outperforms the state-of-the-art on semi-synthetic and real-world data.

The code for our proposed method is provided at https://github.com/olawalesalaudeen/tcri.

Notation:

Capital letters denote bounded random variables, and corresponding lowercase letters denote their value. Unless otherwise stated, we represent latent domain-general features as $Z_{\text{dg}}\in\mathcal{Z}_{\text{dg}}\equiv\mathbb{R}^{m}$ and spurious latent features as $Z_{\text{spu}}\in\mathcal{Z}_{\text{spu}}\equiv\mathbb{R}^{o}$. Let $X\in\mathcal{X}\equiv\mathbb{R}^{d}$ be the observed feature space and the output space of an invertible function $\Gamma:\mathcal{Z}_{\text{dg}}\times\mathcal{Z}_{\text{spu}}\mapsto\mathcal{X}$ and $Y\in\mathcal{Y}\equiv\{0,1,\ldots,K-1\}$ be the observed label space for a $K$-class classification task.
We then define feature extractors aimed at identifying latent features $\Phi_{\text{dg}}:\mathcal{X}\mapsto\mathbb{R}^{m}$, $\Phi_{\text{spu}}:\mathcal{X}\mapsto\mathbb{R}^{o}$ so that $\Phi:\mathcal{X}\mapsto\mathbb{R}^{m+o}$ (that is, $\Phi(x)=[\Phi_{\text{dg}}(x);\Phi_{\text{spu}}(x)]\ \forall x\in\mathcal{X}$). We define $e$ as a discrete random variable denoting domains and $\mathcal{E}=\{P^{e}(Z_{\text{dg}},Z_{\text{spu}},X,Y):e=1,2,\ldots\}$ to be the set of possible domains. $\mathcal{E}_{tr}\subset\mathcal{E}$ is the set of observed domains available during training.

2 Related Work

The source of distribution shift can be isolated to components of the joint distribution. One special case of distribution shift is covariate shift (Shimodaira, 2000; Zadrozny, 2004; Huang et al., 2006; Gretton et al., 2009; Sugiyama et al., 2007; Bickel et al., 2009; Chen et al., 2016; Schneider et al., 2020), where only the covariate distribution $P(X)$ changes across domains. Ben-David et al. (2009) give upper bounds on target error based on the $\mathcal{H}$-divergence between the source and target covariate distributions, which motivates domain alignment methods like the Domain Adversarial Neural Networks (Ganin et al., 2016) and others (Long et al., 2015; Blanchard et al., 2017). Others have followed up on this work with other notions of covariate distance for domain adaptation, such as maximum mean discrepancy (MMD) (Long et al., 2016), Wasserstein distance (Courty et al., 2017), etc. However, Kpotufe and Martinet (2018) show that these divergence metrics fail to capture many important properties of transferability, such as asymmetry and non-overlapping support. Furthermore, Zhao et al. (2019) show that even with the alignment of covariates, large distances between label distributions can inhibit transfer; they propose a label conditional importance weighting adjustment to address this limitation. Other works have also proposed conditional covariate alignment (des Combes et al., 2020; Li et al., 2018c, b).

Another form of distribution shift is label shift, where only the label distribution changes across domains. Lipton et al. (2018) propose a method to address this scenario. Schrouff et al. (2022) illustrate that many real-world problems exhibit more complex ’compound’ shifts than just covariate or label shifts alone.

One can leverage domain adaptation to address distribution shifts; however, these methods are contingent on having access to unlabeled or partially labeled samples from the target domain during training. When such samples are available, more sophisticated domain adaptation strategies aim to leverage and adapt spurious feature information to enhance performance (Liu et al., 2021; Zhang et al., 2021; Kirichenko et al., 2022). However, domain generalization, as a problem, does not assume access to such samples (Muandet et al., 2013).

To address the domain generalization problem, Invariant Causal Predictors (ICP) leverage shared causal structure to learn domain-general predictors (Peters et al., 2016). Previous works, enumerated in the introduction (Section 1), have proposed various algorithms to identify domain-general predictors. Arjovsky et al. (2019) proposed invariant risk minimization (IRM), which, like its variants, is motivated by domain invariance:

$$\min_{w,\Phi}\frac{1}{|\mathcal{E}_{tr}|}\sum_{e\in\mathcal{E}_{tr}}R^{e}(w\circ\Phi)\quad\text{s.t. }w\in\mathop{\mathrm{argmin}}_{\widetilde{w}}R^{e}(\widetilde{w}\cdot\Phi),\ \forall e\in\mathcal{E}_{tr},$$

where $R^{e}(w\circ\Phi)=\mathbb{E}\big[\ell(y,w\cdot\Phi(x))\big]$, with loss function $\ell$, feature extractor $\Phi$, and linear predictor $w$. This objective aims to learn a representation $\Phi$ such that the predictor $w$ that minimizes empirical risk on average across all domains also minimizes within-domain empirical risk for all domains. However, Rosenfeld et al. (2020); Ahuja et al. (2020) showed that this objective requires unreasonable constraints on the number of observed domains at training time, e.g., observing distinct domains on the order of the rank of spurious features. Follow-up works, enumerated in the introduction (Section 1), have attempted to address these limitations with stronger constraints on the problem.
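The constraint in the IRM objective can be checked numerically for a fixed representation. The sketch below is an illustration under assumptions that are not part of the original text: the toy generator `make_domain`, the squared loss, and one-dimensional features are hypothetical choices, and the code only tests the constraint rather than implementing IRM itself. It compares the per-domain risk minimizers with the pooled minimizer for two candidate representations.

```python
import numpy as np

def make_domain(n, spu_noise, rng):
    """Hypothetical toy domain: z_dg causes y, and y causes z_spu."""
    z_dg = rng.normal(0.0, 1.0, n)
    y = 2.0 * z_dg + rng.normal(0.0, 1.0, n)
    # The noise scale on z_spu is what differs between domains.
    z_spu = y + spu_noise * rng.normal(0.0, 1.0, n)
    return z_dg, z_spu, y

def best_w(phi, y):
    """Minimizer over scalar w of the squared-loss risk mean((y - w * phi)^2)."""
    return float(phi @ y / (phi @ phi))

def invariance_gap(features, labels):
    """Largest gap between a per-domain minimizer and the pooled minimizer."""
    w_pooled = best_w(np.concatenate(features), np.concatenate(labels))
    return max(abs(best_w(p, y) - w_pooled) for p, y in zip(features, labels))

rng = np.random.default_rng(0)
domains = [make_domain(50_000, s, rng) for s in (0.5, 3.0)]
labels = [d[2] for d in domains]
gap_dg = invariance_gap([d[0] for d in domains], labels)
gap_spu = invariance_gap([d[1] for d in domains], labels)
print(f"gap with Phi = z_dg:  {gap_dg:.3f}")
print(f"gap with Phi = z_spu: {gap_spu:.3f}")
```

In this toy setting the gap is close to zero when the representation is the domain-general feature, so a single $w$ is simultaneously optimal in every domain, and clearly nonzero when it is the spurious feature, so the constraint is violated.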

Our method falls under domain generalization; however, unlike the domain-general solutions previously discussed, our proposed solution leverages different conditions than domain invariance directly, which we show may be more suited to learning domain-general representations.

3 Causality and Domain Generalization

We often represent causal relationships with a causal graph. A causal graph is a directed acyclic graph (DAG), $G=(V,E)$, with nodes $V$ representing random variables and directed edges $E$ representing causal relationships, i.e., parents are causes and children are effects. A structural equation model (SEM) provides a mathematical representation of the causal relationships in its corresponding DAG. Each variable $Y\in V$ is given by $Y=f_{Y}(X)+\varepsilon_{Y}$, where $X$ denotes the parents of $Y$ in $G$, $f_{Y}$ is a deterministic function, and $\varepsilon_{Y}$ is an error capturing exogenous influences on $Y$. The main property we need here is that $f_{Y}$ is invariant to interventions to $V\backslash\{Y\}$ and is consequently invariant to changes in $P(V)$ induced by these interventions. Interventions refer to changes to $f_{Z}$, $Z\in V\backslash\{Y\}$.

In this work, we focus on domain-general predictors $d_{g}$ that are linear functions of features with domain-general mechanisms, denoted as $g_{\text{dg}}\coloneqq w\circ\Phi_{\text{dg}}$, where $w$ is a linear predictor and $\Phi_{\text{dg}}$ identifies features with domain-general mechanisms. We use domain-general rather than domain-invariant since domain-invariance is strongly tied to the property $Y\perp\!\!\!\perp e\,|\,Z_{\text{dg}}$ (Arjovsky et al., 2019). As shown in the subsequent sections, this work leverages other properties of appropriate causal graphs to obtain domain-general features. This distinction is crucial given the challenges associated with learning domain-general features through domain-invariance methods (Rosenfeld et al., 2020).

Given the presence of a distribution shift, it is essential to identify some common structure across domains that can be utilized for out-of-distribution (OOD) generalization. For example, Shimodaira (2000) assume $P(Y|X)$ is shared across all domains for the covariate shift problem. In this work, we consider a setting where each domain is composed of observed features and labels, $X\in\mathcal{X},Y\in\mathcal{Y}$, where $X$ is given by an invertible function $\Gamma$ of two latent random variables: domain-general $Z_{\text{dg}}\in\mathcal{Z}_{\text{dg}}$ and spurious $Z_{\text{spu}}\in\mathcal{Z}_{\text{spu}}$. By construction, the conditional expectation of the label $Y$ given the domain-general features $Z_{\text{dg}}$ is the same across domains, i.e.,

$$\mathbb{E}_{e_{i}}\left[Y|Z_{\text{dg}}=z_{\text{dg}}\right]=\mathbb{E}_{e_{j}}\left[Y|Z_{\text{dg}}=z_{\text{dg}}\right]\quad\forall z_{\text{dg}}\in\mathcal{Z}_{\text{dg}},\ \forall e_{i}\neq e_{j}\in\mathcal{E}.\tag{1}$$

Conversely, this robustness to $e$ does not necessarily extend to spurious features $Z_{\text{spu}}$; in other words, $Z_{\text{spu}}$ may assume values that could lead a predictor relying on it to experience arbitrarily high error rates. Then, a sound strategy for learning a domain-general predictor (one that is robust to distribution shifts) is to identify the latent domain-general $Z_{\text{dg}}$ from the observed features $X$.
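A small simulation can illustrate why reliance on spurious features is risky. The sketch below is a toy example under stated assumptions rather than an experiment from this work: the generator `sample_domain`, the continuous label, the squared-error metric, and the sign flip in the test domain are all hypothetical. It fits one-feature least-squares predictors on two training domains and evaluates them on a shifted domain.

```python
import numpy as np

def sample_domain(n, spu_weight, rng):
    """Hypothetical domain: E[y | z_dg] = 2 * z_dg everywhere; spu_weight varies."""
    z_dg = rng.normal(0.0, 1.0, n)
    y = 2.0 * z_dg + rng.normal(0.0, 1.0, n)
    z_spu = spu_weight * y + rng.normal(0.0, 0.5, n)
    return z_dg, z_spu, y

def fit_slope(x, y):
    """Least-squares slope through the origin."""
    return float(x @ y / (x @ x))

def mse(x, y, w):
    """Mean squared error of the predictor w * x."""
    return float(np.mean((y - w * x) ** 2))

rng = np.random.default_rng(1)
train = [sample_domain(20_000, w, rng) for w in (1.0, 0.8)]
z_dg_tr, z_spu_tr, y_tr = (np.concatenate(parts) for parts in zip(*train))
w_dg, w_spu = fit_slope(z_dg_tr, y_tr), fit_slope(z_spu_tr, y_tr)

# Test domain in which the relation between y and z_spu is reversed.
z_dg_te, z_spu_te, y_te = sample_domain(20_000, -1.0, rng)
mse_dg_train, mse_dg_test = mse(z_dg_tr, y_tr, w_dg), mse(z_dg_te, y_te, w_dg)
mse_spu_train, mse_spu_test = mse(z_spu_tr, y_tr, w_spu), mse(z_spu_te, y_te, w_spu)
print(f"z_dg predictor:  train {mse_dg_train:.2f}, test {mse_dg_test:.2f}")
print(f"z_spu predictor: train {mse_spu_train:.2f}, test {mse_spu_test:.2f}")
```

In this toy the spurious predictor attains the lower training error, which is why risk minimization alone is drawn to it, yet its error grows sharply in the shifted domain while the error of the domain-general predictor stays essentially unchanged.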

The approach we take to do this is motivated by the Reichenbach Common Cause Principle, which claims that if two events are correlated, there is either a causal connection between the correlated events that is responsible for the correlation or there is a third event, a so-called (Reichenbachian) common cause, which brings about the correlation (Hitchcock and Rédei, 2021; Rédei, 2002). This principle allows us to posit the class of generative processes or causal mechanisms that give rise to the correlated observed features and labels, where the observed features are a function of domain-general and spurious features. We represent these generative processes as causal graphs. Importantly, the mapping from a node’s causal parents to itself is preserved in all distributions generated by the causal graph (Equation 1), and distributions can vary arbitrarily so long as they preserve the conditional independencies implied by the DAG (Markov Property (Pearl, 2010)).

We now enumerate DAGs that give observed features with spurious correlations with the label.

Figure 1: Partial Ancestral Graph (over nodes $e$, $Z_{\text{dg}}$, $Z_{\text{spu}}$, $Y$, $X$) representing all non-trivial and valid generative processes (DAGs); dashed edges indicate that an edge may or may not exist.
Valid DAGs.

We consider generative processes where both latent features, $Z_{\text{spu}}$ and $Z_{\text{dg}}$, and the observed $X$ are correlated with $Y$, and the observed $X$ is a function of only $Z_{\text{dg}}$ and $Z_{\text{spu}}$ (Figure 1).

Given this setup, there is an enumerable set of valid generative processes. Such processes (i) are without cycles, (ii) are feature complete, i.e., include edges from $Z_{\text{dg}}$ and $Z_{\text{spu}}$ to $X$ ($Z_{\text{dg}}\rightarrow X\leftarrow Z_{\text{spu}}$), and (iii) have observed features that mediate domain influence, i.e., there is no direct domain influence on the label ($e\not\rightarrow Y$). We discuss this enumeration in detail in Appendix B. The result of our analysis is identifying a representative set of DAGs that describe valid generative processes; these DAGs come from orienting the partial ancestral graph (PAG) in Figure 1. We compare the conditional independencies implied by the DAGs defined by Figure 1 as illustrated in Figure 2, resulting in three canonical DAGs in the literature (see Appendix B for further discussion). Other DAGs that induce spurious correlations are outside the scope of this work.

Figure 2: Generative Processes. Panels: (a) Causal (Arjovsky et al., 2019); (b) Anticausal (Rosenfeld et al., 2020); (c) Fully Informative Causal (Ahuja et al., 2021). Graphical models depicting the structure of possible data-generating processes; shaded nodes indicate observed variables. $X$ represents the observed features, $Y$ represents observed targets, and $e$ represents domain influences (domain indexes in practice). There is an explicit separation of domain-general $Z_{\text{dg}}$ and domain-specific $Z_{\text{spu}}$ features; they are combined to generate observed $X$. Dashed edges indicate the possibility of an edge.
Conditional independencies implied by identified DAGs (Figure 2).
  1. Fig. 2(a):

    $\mathbf{Z_{\text{dg}}\perp\!\!\!\perp Z_{\text{spu}}\,|\,\{Y,e\}}$; $Y\perp\!\!\!\perp e\,|\,Z_{\text{dg}}$.

    This causal graphical model implies that the mapping from $Z_{\text{dg}}$ to its causal child $Y$ is preserved and consequently, Equation 1 holds (Pearl, 2010; Peters et al., 2016). As an example, consider the task of predicting the spread of a disease. Features may include causes (vaccination rate and public health policies) and effects (coughing). $e$ is the time of month; the distribution of coughing changes depending on the season.

  2. Fig. 2(b):

    $\mathbf{Z_{\text{dg}}\perp\!\!\!\perp Z_{\text{spu}}\,|\,\{Y,e\}}$; $Z_{\text{dg}}\perp\!\!\!\perp Z_{\text{spu}}\,|\,Y$; $Y\perp\!\!\!\perp e\,|\,Z_{\text{dg}}$, $Z_{\text{dg}}\perp\!\!\!\perp e$.

    The causal graphical model does not directly imply that $Z_{\text{dg}}\rightarrow Y$ is preserved across domains. However, in this work, it represents the setting where the inverse of the causal direction is preserved (inverse: $Z_{\text{dg}}\rightarrow Y$), and thus Equation 1 holds. A context where this setting is relevant is in healthcare where medical conditions ($Y$) cause symptoms ($Z_{\text{dg}}$), but the prediction task is often predicting conditions from symptoms, and this mapping $Z_{\text{dg}}\rightarrow Y$, opposite of the causal direction, is preserved across distributions. Again, we may consider $e$ as the time of month; the distribution of coughing changes depending on the season.

  3. Fig. 2(c):

    $Y\perp\!\!\!\perp e\,|\,Z_{\text{dg}}$; $Z_{\text{dg}}\perp\!\!\!\perp e$.

    Similar to Figure 2(a), this causal graphical model implies that the mapping from $Z_{\text{dg}}$ to its causal child $Y$ is preserved, so Equation 1 holds (Pearl, 2010; Peters et al., 2016). This setting is especially interesting because it represents a Fully Informative Invariant Features setting, that is, $Z_{\text{spu}}\perp\!\!\!\perp Y\,|\,Z_{\text{dg}}$ (Ahuja et al., 2021). Said differently, $Z_{\text{spu}}$ does not induce a backdoor path from $e$ to $Y$ that $Z_{\text{dg}}$ does not block. As an example of this, we can consider the task of predicting hospital readmission rates. Features may include the severity of illness, which is a direct cause of readmission rates, and also include the length of stay, which is also caused by the severity of illness. However, length of stay may not be a cause of readmission; the correlation between the two would be a result of the confounding effect of a common cause, illness severity. $e$ is an indicator for distinct hospitals.

Table 1: Generative Processes and Sufficient Conditions for Domain-Generality

Graphs in Figure 2                                                      (a)   (b)   (c)
$Z_{\text{dg}}\perp\!\!\!\perp Z_{\text{spu}}\,|\,\{Y,e\}$              yes   yes   no
Identifying $Z_{\text{dg}}$ is necessary                                yes   yes   no

We call the condition $\mathbf{Y\perp\!\!\!\perp e\,|\,Z_{\text{dg}}}$ the domain invariance property. This condition is common to all the DAGs in Figure 2. We call the condition $Z_{\text{dg}}\perp\!\!\!\perp Z_{\text{spu}}\,|\,\{Y,e\}$ the target conditioned representation independence (TCRI) property. This condition is common to the DAGs in Figures 2(a) and 2(b). In the settings considered in this work, the TCRI property is equivalently $\mathbf{Z_{\text{dg}}\perp\!\!\!\perp Z_{\text{spu}}\,|\,Y\ \forall e\in\mathcal{E}}$ since $e$ will simply index the set of empirical distributions available at training.
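The TCRI property can be illustrated on simulated data from a causal graph of the type in Figure 2(a). The sketch below rests on assumptions that are not part of the original text: a linear-Gaussian toy with a continuous label, hypothetical coefficients, and partial correlation as a stand-in for a conditional independence test (zero partial correlation matches conditional independence only for jointly Gaussian variables). Within each domain it compares the marginal correlation of the two latent features with their correlation after conditioning on $Y$.

```python
import numpy as np

def residual(a, b):
    """Residual of a after least-squares regression on b (with intercept)."""
    design = np.column_stack([np.ones_like(b), b])
    coef, *_ = np.linalg.lstsq(design, a, rcond=None)
    return a - design @ coef

def corr(a, b):
    """Pearson correlation between two 1-D arrays."""
    return float(np.corrcoef(a, b)[0, 1])

rng = np.random.default_rng(2)
n = 50_000
marginals, partials = [], []
for dg_mean, spu_noise in [(0.0, 0.5), (1.0, 2.0)]:   # two hypothetical domains e
    z_dg = rng.normal(dg_mean, 1.0, n)                 # e -> z_dg
    y = 2.0 * z_dg + rng.normal(0.0, 1.0, n)           # z_dg -> y
    z_spu = y + spu_noise * rng.normal(0.0, 1.0, n)    # y -> z_spu <- e
    marginals.append(corr(z_dg, z_spu))
    partials.append(corr(residual(z_dg, y), residual(z_spu, y)))

for e, (m, p) in enumerate(zip(marginals, partials)):
    print(f"domain {e}: corr = {m:.3f}, partial corr given Y = {p:.3f}")
```

The two latent features are strongly correlated marginally, but within each domain the dependence vanishes (up to sampling noise) once $Y$ is accounted for, which is the structure the TCRI property describes.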

Domain generalization with conditional independencies.

Kaur et al. (2022) showed that sufficiently regularizing for the correct conditional independencies described by the appropriate DAGs can give domain-general solutions, i.e., identifies $Z_{\text{dg}}$. However, in practice, one does not (partially) observe the latent features independently to regularize directly. Other works have also highlighted the need to consider generative processes when designing algorithms that are robust to distribution shifts (Veitch et al., 2021; Makar et al., 2022). However, previous work has largely focused on regularizing for the domain invariance property, ignoring the conditional independence property $Z_{\text{dg}}\perp\!\!\!\perp Z_{\text{spu}}\,|\,Y,e$.

Sufficiency of ERM under Fully Informative Invariant Features.

Despite the known challenges of learning domain-general features from the domain-invariance properties in practice, this approach persists, likely due to it being the only property shared across all DAGs. We alleviate this constraint by observing that the graph in Fig. 2(c) falls under what Ahuja et al. (2021) refer to as the fully informative invariant features setting, meaning that $Z_{\text{spu}}$ is redundant, having only information about $Y$ that is already in $Z_{\text{dg}}$. Ahuja et al. (2021) show that the empirical risk minimizer is domain-general for bounded features.

Easy vs. hard DAGs imply the generality of TCRI.

Consequently, we categorize the generative processes into easy and hard cases (Table 1): (i) easy meaning that minimizing average risk gives domain-general solutions, i.e., ERM is sufficient (Fig. 2(c)), and (ii) hard meaning that one needs to identify $Z_{\text{dg}}$ to obtain domain-general solutions (Figs. 2(a)-2(b)). We show empirically that regularizing for $Z_{\text{dg}}\perp\!\!\!\perp Z_{\text{spu}}\,|\,Y\ \forall e\in\mathcal{E}$ also gives a domain-general solution in the easy case. The generality of TCRI follows from its sufficiency for identifying domain-general $Z_{\text{dg}}$ in the hard cases while still giving domain-general solutions empirically in the easy case.
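The easy/hard distinction can be seen in a linear toy, sketched below under assumptions that are not from the original text: pooled least squares stands in for ERM, the latent features are used directly as inputs (that is, $\Gamma$ is taken to be the identity), and the features are Gaussian, so this is only an illustration and not an instance of the bounded-feature result of Ahuja et al. (2021). The function `pooled_erm_weights` and its coefficients are hypothetical.

```python
import numpy as np

def pooled_erm_weights(z_spu_from, rng, n=50_000):
    """Pooled least squares of y on [1, z_dg, z_spu] over two hypothetical domains."""
    rows, targets = [], []
    for spu_noise in (0.5, 2.0):                    # e changes the z_spu noise scale
        z_dg = rng.normal(0.0, 1.0, n)
        y = 2.0 * z_dg + rng.normal(0.0, 1.0, n)
        # hard case: y -> z_spu; easy case: z_dg -> z_spu
        parent = y if z_spu_from == "y" else z_dg
        z_spu = parent + spu_noise * rng.normal(0.0, 1.0, n)
        rows.append(np.column_stack([np.ones(n), z_dg, z_spu]))
        targets.append(y)
    coef, *_ = np.linalg.lstsq(np.vstack(rows), np.concatenate(targets), rcond=None)
    return coef  # [intercept, weight on z_dg, weight on z_spu]

rng = np.random.default_rng(3)
easy = pooled_erm_weights("z_dg", rng)   # structure as in Fig. 2(c)
hard = pooled_erm_weights("y", rng)      # structure as in Fig. 2(a)
print(f"easy case: weight on z_dg = {easy[1]:.3f}, weight on z_spu = {easy[2]:.3f}")
print(f"hard case: weight on z_dg = {hard[1]:.3f}, weight on z_spu = {hard[2]:.3f}")
```

In the easy case the fitted weight on the spurious feature is approximately zero because it carries no information about the label beyond the domain-general feature; in the hard case the pooled fit places substantial weight on it, so the spurious feature must be separated out explicitly.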

4 Proposed Learning Framework

We have now clarified that hard DAGs (i.e., those not solved by ERM) share the TCRI property. The challenge is that $Z_{\text{dg}}$ and $Z_{\text{spu}}$ are not independently observed; otherwise, one could directly regularize. Existing work such as Kaur et al. (2022) empirically studies semi-synthetic datasets where $Z_{\text{spu}}$ is (partially) observed and directly learns $Z_{\text{dg}}$ by regularizing that $\Phi(X)\perp\!\!\!\perp Z_{\text{spu}}\,|\,Y,e$ for feature extractor $\Phi$. To our knowledge, we are the first to leverage the TCRI property without requiring observation of $Z_{\text{spu}}$. Next, we set up our approach with some key assumptions. The first is that the observed distributions are Markov to an appropriate DAG.

Assumption 4.1.

All distributions, sources and targets, are generated by one of the structural causal models $\mathcal{SCM}$ that follow:

$$\overbrace{\mathcal{SCM}(e)}^{causal} \coloneqq \begin{cases} Z_{\text{dg}}^{(e)} \sim P_{Z_{\text{dg}}}^{(e)}, \\ Y^{(e)} \leftarrow \langle w_{\text{dg}}^{*}, Z_{\text{dg}}^{(e)} \rangle + \eta_{Y}, \\ Z_{\text{spu}}^{(e)} \leftarrow \langle w_{\text{spu}}^{*}, Y \rangle + \eta_{Z_{\text{spu}}}^{(e)}, \\ X \leftarrow \Gamma(Z_{\text{dg}}, Z_{\text{spu}}), \end{cases} \tag{2}$$

$$\overbrace{\mathcal{SCM}(e)}^{anticausal} \coloneqq \begin{cases} Y^{(e)} \sim P_{Y}, \\ Z_{\text{dg}}^{(e)} \leftarrow \langle \tilde{w}_{\text{dg}}, Y \rangle + \eta_{Z_{\text{dg}}}^{(e)}, \\ Z_{\text{spu}}^{(e)} \leftarrow \langle w_{\text{spu}}^{*}, Y \rangle + \eta_{Z_{\text{spu}}}^{(e)}, \\ X \leftarrow \Gamma(Z_{\text{dg}}, Z_{\text{spu}}), \end{cases} \tag{3}$$

$$\overbrace{\mathcal{SCM}(e)}^{FIIF} \coloneqq \begin{cases} Z_{\text{dg}}^{(e)} \sim P_{Z_{\text{dg}}}^{(e)}, \\ Y^{(e)} \leftarrow \langle w_{\text{dg}}^{*}, Z_{\text{dg}}^{(e)} \rangle + \eta_{Y}, \\ Z_{\text{spu}}^{(e)} \leftarrow \langle w_{\text{spu}}^{*}, Z_{\text{dg}} \rangle + \eta_{Z_{\text{spu}}}^{(e)}, \\ X \leftarrow \Gamma(Z_{\text{dg}}, Z_{\text{spu}}), \end{cases} \tag{4}$$

where $P_{Z_{\text{dg}}}$ is the causal covariate distribution, the $w$'s are linear generative mechanisms, the $\eta$'s are exogenous independent noise variables, and $\Gamma: \mathcal{Z}_{\text{dg}} \times \mathcal{Z}_{\text{spu}} \rightarrow \mathcal{X}$ is an invertible function. It follows from having causal mechanisms that we can learn a predictor $w_{\text{dg}}^{*}$ for $Z_{\text{dg}}$ that is domain-general (Equations 2-4); $w_{\text{dg}}^{*}$ inverts the mapping $\tilde{w}_{\text{dg}}$ in the anticausal case.

These structural causal models (Equations 2-4) correspond to the causal graphs in Figures 2(a)-2(c), respectively.
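As a minimal sketch, the three generative processes in Equations 2-4 could be sampled as follows. The scalar latents, unit weights, Gaussian noise and its scales, and the fixed matrix `GAMMA` standing in for the invertible map $\Gamma$ are all illustrative assumptions, not settings taken from the paper.

```python
import numpy as np

# Assumed invertible linear stand-in for Gamma (hypothetical choice).
GAMMA = np.array([[1.0, 0.5], [0.5, 1.0]])

def sample_scm(kind, sigma_e, n=1000, sigma_y=0.25, seed=0):
    """Draw n samples from a scalar Gaussian instance of Equations 2-4.

    All weights w are set to 1 and all noise is Gaussian (assumptions).
    Returns the observation x and the variables (y, z_dg, z_spu).
    """
    rng = np.random.default_rng(seed)

    def noise(scale):
        return rng.normal(0.0, scale, size=n)

    if kind == "causal":        # Equation 2: Z_dg -> Y -> Z_spu
        z_dg = noise(sigma_e)
        y = z_dg + noise(sigma_y)
        z_spu = y + noise(sigma_e)
    elif kind == "anticausal":  # Equation 3: Z_dg <- Y -> Z_spu
        y = noise(1.0)
        z_dg = y + noise(sigma_y)
        z_spu = y + noise(sigma_e)
    elif kind == "fiif":        # Equation 4: Y <- Z_dg -> Z_spu
        z_dg = noise(sigma_e)
        y = z_dg + noise(sigma_y)
        z_spu = z_dg + noise(sigma_e)
    else:
        raise ValueError(f"unknown SCM kind: {kind}")

    x = np.column_stack([z_dg, z_spu]) @ GAMMA.T  # X <- Gamma(Z_dg, Z_spu)
    return x, y, z_dg, z_spu

x, y, z_dg, z_spu = sample_scm("anticausal", sigma_e=0.1)
```

Because `GAMMA` is invertible, the latents can be recovered exactly from `x` in this toy setting; the learning problem in the paper is harder because $\Gamma$ is unknown.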

Assumption 4.2 (Structural).

Causal Graphs and their distributions are Markov and Faithful (Pearl, 2010).

Given Assumption 4.2, we aim to leverage the TCRI property ($Z_{\text{dg}} \perp\!\!\!\perp Z_{\text{spu}} \,|\, Y \;\forall e \in \mathcal{E}_{tr}$) to learn the latent $Z_{\text{dg}}$ without observing $Z_{\text{spu}}$ directly. We do this by learning two feature extractors that, together, recover $Z_{\text{dg}}$ and $Z_{\text{spu}}$ and satisfy TCRI (Figure 3). We formally define these properties as follows.

[Figure 3 diagram, with nodes $X^{e}$, $\Phi_{\text{dg}}$, $\Phi_{\text{spu}}$, $\widehat{Z}_{\text{dg}}$, $\widehat{Z}_{\text{spu}}$, $\oplus$, $\theta_{c}$, $\theta_{e}$, $\widehat{y}_{c}$, $\widehat{y}_{e}$.]
Figure 3: Modeling approach. During training, both representations, $\Phi_{\text{dg}}$ and $\Phi_{\text{spu}}$, generate domain-general and domain-specific predictions, respectively. However, only the domain-invariant representations/predictions are used during testing, as indicated by the solid red arrows.

Definition 4.3 (Total Information Criterion (TIC)).

$\Phi = \Phi_{\text{dg}} \oplus \Phi_{\text{spu}}$ satisfies TIC with respect to random variables $X,\,Y,\,e$ if for $\Phi(X^{e}) = [\Phi_{\text{dg}}(X^{e}); \Phi_{\text{spu}}(X^{e})]$, there exists a linear operator $\mathcal{T}$ s.t. $\mathcal{T}(\Phi(X^{e})) = [Z_{\text{dg}}^{e}; Z_{\text{spu}}^{e}] \;\forall e \in \mathcal{E}_{tr}$.

In other words, a feature extractor that satisfies the total information criterion recovers the complete latent feature sets $Z_{\text{dg}},\,Z_{\text{spu}}$. This allows us to define the proposed implementation of the TCRI property non-trivially: the conditional independence of subsets of the latents may not have the same implications on domain generalization. We note that $X \perp\!\!\!\perp Y \,|\, Z_{\text{dg}}, Z_{\text{spu}}$, so $X$ has no information about $Y$ that is not in $Z_{\text{dg}}, Z_{\text{spu}}$.
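To make the criterion concrete, the following sketch checks whether some linear operator maps stacked features back to the latents by solving a least-squares problem. It assumes access to the ground-truth latents, which is only possible in simulation; the function name `tic_residual` and the toy features are hypothetical.

```python
import numpy as np

def tic_residual(features, latents):
    """Relative error of the best linear map T with T(features) ~= latents.

    features: (n, m + o) array, the stacked [Phi_dg(x); Phi_spu(x)].
    latents:  (n, d) array, the stacked [Z_dg; Z_spu] (known only in simulation).
    A value near zero indicates that a linear operator T as in TIC exists.
    """
    T, *_ = np.linalg.lstsq(features, latents, rcond=None)
    return float(np.linalg.norm(features @ T - latents) / np.linalg.norm(latents))

rng = np.random.default_rng(0)
z = rng.normal(size=(500, 2))               # columns: Z_dg, Z_spu
mix = np.array([[2.0, 1.0], [-1.0, 1.0]])   # invertible mixing of the latents

full = z @ mix        # features that retain all latent information
partial = z[:, :1]    # features that drop Z_spu entirely

res_full = tic_residual(full, z)
res_partial = tic_residual(partial, z)
```

Here `res_full` is numerically zero because the mixing is invertible, while `res_partial` is large because no linear map can recover the discarded latent.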

Definition 4.4 (Target Conditioned Representation Independence).

$\Phi = \Phi_{\text{dg}} \oplus \Phi_{\text{spu}}$ satisfies TCRI with respect to random variables $X,\,Y,\,e$ if $\Phi_{\text{dg}}(X) \perp\!\!\!\perp \Phi_{\text{spu}}(X) \,|\, Y \;\forall e \in \mathcal{E}$.

Proposition 4.5.

Assume that $\Phi_{\text{dg}}(X)$ and $\Phi_{\text{spu}}(X)$ are correlated with $Y$. Given Assumptions 4.1-4.2 and a representation $\Phi = \Phi_{\text{dg}} \oplus \Phi_{\text{spu}}$ that satisfies TIC, $\Phi_{\text{dg}}(X) = Z_{\text{dg}} \iff \Phi$ satisfies TCRI (see Appendix C for proof).

Proposition 4.5 shows that TCRI is necessary and sufficient to identify $Z_{\text{dg}}$ from a set of training domains. We note that we can verify whether $\Phi_{\text{dg}}(X)$ and $\Phi_{\text{spu}}(X)$ are correlated with $Y$ by checking whether the learned predictors perform better than chance. Next, we describe our proposed algorithm to implement the conditions to learn such a feature map. Figure 3 illustrates the learning framework.
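The role of the conditional independence can be illustrated numerically in a single simulated domain. The sketch below uses partial correlation given $Y$ as a stand-in for a conditional independence test, which is only appropriate for linear-Gaussian data; the noise scales and the leaky feature are assumptions chosen for illustration.

```python
import numpy as np

def partial_corr(a, b, y):
    """Correlation between a and b after linearly regressing y out of both."""
    design = np.column_stack([np.ones_like(y), y])
    res_a = a - design @ np.linalg.lstsq(design, a, rcond=None)[0]
    res_b = b - design @ np.linalg.lstsq(design, b, rcond=None)[0]
    return float(np.corrcoef(res_a, res_b)[0, 1])

# One domain of a causal SCM: Z_dg -> Y -> Z_spu (assumed noise scales).
rng = np.random.default_rng(0)
n = 20000
z_dg = rng.normal(0.0, 1.0, n)
y = z_dg + rng.normal(0.0, 0.5, n)
z_spu = y + rng.normal(0.0, 1.0, n)

# A feature equal to Z_dg is conditionally uncorrelated with Z_spu given Y ...
good = partial_corr(z_dg, z_spu, y)
# ... while a feature that leaks Z_spu is not.
bad = partial_corr(z_dg + 0.5 * z_spu, z_spu, y)
```

In this toy setting `good` is close to zero and `bad` is clearly positive, which is the signal a TCRI-style penalty exploits to separate the two feature sets.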

Learning Objective:

The first term in our proposed objective is

$$\mathcal{L}_{\Phi_{\text{dg}}} = \mathcal{R}^{e}(\theta_{c} \circ \Phi_{\text{dg}}),$$

where $\Phi_{\text{dg}}: \mathcal{X} \mapsto \mathbb{R}^{m}$ is a feature extractor, $\theta_{c}: \mathbb{R}^{m} \mapsto \mathcal{Y}$ is a linear predictor, and $\mathcal{R}^{e}(\theta_{c} \circ \Phi_{\text{dg}}) = \mathbb{E}\big[\ell(y, \theta_{c} \cdot \Phi_{\text{dg}}(x))\big]$ is the empirical risk achieved by the feature extractor and predictor pair on samples from domain $e$. $\Phi_{\text{dg}}$ and $\theta_{c}$ are designed to capture the domain-general portion of the framework.

Next, to implement the total information criterion, we use another feature extractor $\Phi_{\text{spu}}: \mathcal{X} \mapsto \mathbb{R}^{o}$, designed to capture the domain-specific information in $X$ that is not captured by $\Phi_{\text{dg}}$. Together, we have $\Phi = \Phi_{\text{dg}} \oplus \Phi_{\text{spu}}$, where $\Phi$ has domain-specific predictors $\theta_{e}: \mathbb{R}^{m+o} \mapsto \mathcal{Y}$ for each training domain, allowing the feature extractor to utilize domain-specific information to learn distinct optimal domain-specific (non-general) predictors:

$$\mathcal{L}_{\Phi} = \mathcal{R}^{e}\big(\theta_{e} \circ \Phi\big).$$

$\mathcal{L}_{\Phi}$ aims to ensure that $\Phi_{\text{dg}}$ and $\Phi_{\text{spu}}$ capture all of the information about $Y$ in $X$, i.e., the total information criterion. Since we do not know $o, m$, we select them to be the same size in our experiments; $o, m$ could be treated as hyperparameters, though we do not treat them as such.

Finally, we implement the TCRI property (Definition 4.4). We denote $\mathcal{L}_{TCRI}$ to be a conditional independence penalty for $\Phi_{\text{dg}}$ and $\Phi_{\text{spu}}$. We utilize the Hilbert-Schmidt Independence Criterion (HSIC) (Gretton et al., 2007) as $\mathcal{L}_{TCRI}$. However, in principle, any conditional independence penalty can be used in its place. The HSIC-based penalty is

$$\mathcal{L}_{TCRI}(\Phi_{\text{dg}}, \Phi_{\text{spu}}) = \frac{1}{2} \sum_{k \in \{0,1\}} \widehat{HSIC}\Big(\Phi_{\text{dg}}(X), \Phi_{\text{spu}}(X)\Big)^{y=k} = \frac{1}{2} \sum_{k \in \{0,1\}} \frac{1}{n_{k}^{2}} \text{tr}\Big(\mathbf{K}_{\Phi_{\text{dg}}} \mathbf{H}_{n_{k}} \mathbf{K}_{\Phi_{\text{spu}}} \mathbf{H}_{n_{k}}\Big)^{y=k},$$

where $k$ indicates which class the examples in the estimate correspond to, $n_{k}$ is the number of examples with label $k$, $C$ is the number of classes ($C = 2$ in the expression above), $\mathbf{K}_{\Phi_{\text{dg}}} \in \mathbb{R}^{n_{k} \times n_{k}}$ and $\mathbf{K}_{\Phi_{\text{spu}}} \in \mathbb{R}^{n_{k} \times n_{k}}$ are Gram matrices with entries $\mathbf{K}_{\Phi_{\text{dg}}}^{i,j} = \kappa(\Phi_{\text{dg}}(X)_{i}, \Phi_{\text{dg}}(X)_{j})$ and $\mathbf{K}_{\Phi_{\text{spu}}}^{i,j} = \omega(\Phi_{\text{spu}}(X)_{i}, \Phi_{\text{spu}}(X)_{j})$, the kernels $\kappa, \omega$ are radial basis functions, $\mathbf{H}_{n_{k}} = \mathbf{I}_{n_{k}} - \frac{1}{n_{k}} \mathbf{1}_{n_{k}} \mathbf{1}_{n_{k}}^{\top}$ is a centering matrix, $\mathbf{I}_{n_{k}}$ is the $n_{k} \times n_{k}$ dimensional identity matrix, $\mathbf{1}_{n_{k}}$ is the $n_{k}$-dimensional vector whose elements are all 1, and $\top$ denotes the transpose. We condition on the label by taking only examples of each label and computing the empirical HSIC; then, we take the average.
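A minimal NumPy sketch of this label-conditioned HSIC penalty is given below. The kernel bandwidth, the function names, and the toy data are assumptions; the centering matrix is taken to be the standard $\mathbf{I} - \frac{1}{n}\mathbf{1}\mathbf{1}^{\top}$, and the estimator is the biased form $\text{tr}(\mathbf{K}\mathbf{H}\mathbf{L}\mathbf{H})/n^{2}$.

```python
import numpy as np

def rbf_gram(feats, bandwidth=1.0):
    """Gram matrix of an RBF kernel; the bandwidth is an assumed hyperparameter."""
    sq = np.sum(feats**2, axis=1)
    sq_dists = sq[:, None] + sq[None, :] - 2.0 * feats @ feats.T
    return np.exp(-np.maximum(sq_dists, 0.0) / (2.0 * bandwidth**2))

def hsic(a, b, bandwidth=1.0):
    """Biased empirical HSIC between feature arrays a and b of shape (n, d)."""
    n = a.shape[0]
    h = np.eye(n) - np.ones((n, n)) / n  # standard centering matrix
    k_a, k_b = rbf_gram(a, bandwidth), rbf_gram(b, bandwidth)
    return float(np.trace(k_a @ h @ k_b @ h) / n**2)

def tcri_penalty(phi_dg, phi_spu, y, classes=(0, 1)):
    """Average HSIC over label values, i.e., HSIC conditioned on the label."""
    per_class = [hsic(phi_dg[y == k], phi_spu[y == k]) for k in classes]
    return sum(per_class) / len(classes)

# Toy check: both feature sets depend on y, but only one pair is dependent given y.
rng = np.random.default_rng(0)
n = 400
y = rng.integers(0, 2, size=n)
a = rng.normal(size=(n, 1)) + y[:, None]
b_indep = rng.normal(size=(n, 1)) + y[:, None]   # independent of a given y
b_dep = a + 0.1 * rng.normal(size=(n, 1))        # strongly dependent on a given y

penalty_indep = tcri_penalty(a, b_indep, y)
penalty_dep = tcri_penalty(a, b_dep, y)
```

The penalty for the conditionally independent pair is near zero even though both features are correlated with the label, which is why conditioning on $Y$ matters here.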

Taken together, the full objective to be minimized is as follows:

$$\mathcal{L} = \frac{1}{E_{tr}} \sum_{e \in \mathcal{E}_{tr}} \Bigg[\mathcal{R}^{e}(\theta_{c} \circ \Phi_{\text{dg}}) + \mathcal{R}^{e}(\theta_{e} \circ \Phi) + \beta \mathcal{L}_{TCRI}(\Phi_{\text{dg}}, \Phi_{\text{spu}})\Bigg],$$

where $\beta > 0$ is a hyperparameter and $E_{tr}$ is the number of training domains. Figure 3 shows the full framework. We note that when $\beta = 0$, this loss reduces to ERM.

Note that while we minimize this objective with respect to $\Phi, \theta_{c}, \theta_{1}, \ldots, \theta_{E_{tr}}$, only the domain-general representation and its predictor, $\theta_{c} \cdot \Phi_{\text{dg}}$, are used for inference.
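The structure of the full objective can be sketched for linear, scalar-valued feature extractors as follows. The logistic loss, the scalar features, and the within-class squared covariance used in place of $\mathcal{L}_{TCRI}$ are simplifying assumptions, and all function and parameter names are hypothetical.

```python
import numpy as np

def logistic_risk(logits, y):
    """Mean binary cross-entropy for labels y in {0, 1}."""
    return float(np.mean(np.logaddexp(0.0, logits) - y * logits))

def cond_cov_penalty(a, b, y):
    """Stand-in for L_TCRI (assumption): mean squared within-class covariance."""
    return float(np.mean([np.cov(a[y == k], b[y == k])[0, 1] ** 2 for k in (0, 1)]))

def tcri_objective(domains, w_dg, w_spu, theta_c, theta_e, beta):
    """Average over domains of R^e(theta_c o Phi_dg) + R^e(theta_e o Phi) + beta * L_TCRI.

    domains: list of (x, y) pairs, one per training domain.
    w_dg, w_spu: weight vectors of the linear extractors Phi_dg and Phi_spu.
    theta_c: scalar shared predictor; theta_e: one length-2 weight vector per domain.
    """
    total = 0.0
    for (x, y), th_e in zip(domains, theta_e):
        z_dg, z_spu = x @ w_dg, x @ w_spu                   # Phi_dg(x), Phi_spu(x)
        risk_dg = logistic_risk(theta_c * z_dg, y)          # domain-general head
        stacked = np.column_stack([z_dg, z_spu])            # Phi = Phi_dg (+) Phi_spu
        risk_all = logistic_risk(stacked @ th_e, y)         # domain-specific head
        total += risk_dg + risk_all + beta * cond_cov_penalty(z_dg, z_spu, y)
    return total / len(domains)

rng = np.random.default_rng(0)
domains = [(rng.normal(size=(50, 3)), rng.integers(0, 2, size=50)) for _ in range(2)]
w_dg, w_spu = rng.normal(size=3), rng.normal(size=3)
theta_e = [rng.normal(size=2) for _ in domains]

loss_no_penalty = tcri_objective(domains, w_dg, w_spu, 1.0, theta_e, beta=0.0)
loss_with_penalty = tcri_objective(domains, w_dg, w_spu, 1.0, theta_e, beta=1.0)
```

At inference time only `theta_c * (x @ w_dg)` would be evaluated; the domain-specific heads and `w_spu` exist to shape the representation during training.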

5 Experiments

We begin by evaluating with simulated data, i.e., with known ground truth mechanisms; we use Equation 5 to generate our simulated data, with domain parameter $\sigma_{e_{i}}$; code is provided in the supplemental materials.

$$\mathcal{SCM}(e_{i}) \coloneqq \begin{cases} Z_{\text{dg}}^{(e_{i})} \sim \mathcal{N}\left(0, \sigma_{e_{i}}^{2}\right), \\ Y^{(e_{i})} = Z_{\text{dg}}^{(e_{i})} + \mathcal{N}\left(0, \sigma_{y}^{2}\right), \\ Z_{\text{spu}}^{(e_{i})} = Y^{(e_{i})} + \mathcal{N}\left(0, \sigma_{e_{i}}^{2}\right). \end{cases} \tag{5}$$
Table 2: Continuous Simulated Results: Feature Extractor with a dummy predictor $\theta_{c} = 1$, i.e., $\widehat{y} = x \cdot \Phi_{\text{dg}} \cdot w$, where $x \in \mathbb{R}^{N \times 2}$, $\Phi_{\text{dg}}, \Phi_{\text{spu}} \in \mathbb{R}^{2 \times 1}$, $w \in \mathbb{R}$. Oracle indicates the coefficients achieved by regressing $Y$ on $Z_{\text{dg}}$ directly.

| Algorithm | $(\Phi_{\text{dg}})_{0}$ (i.e., $Z_{\text{dg}}$ weight) | $(\Phi_{\text{dg}})_{1}$ (i.e., $Z_{\text{spu}}$ weight) |
|---|---|---|
| ERM | 0.29 | 0.71 |
| IRM | 0.28 | 0.71 |
| TCRI | 1.01 | 0.06 |
| Oracle | 1.04 | 0.00 |

We observe 2 domains with parameters $\sigma_{e=0} = 0.1$ and $\sigma_{e=1} = 0.2$, with $\sigma_{y} = 0.25$, 5000 samples, and linear feature extractors and predictors. We use partial covariance as our conditional independence penalty $\mathcal{L}_{TCRI}$. Table 2 shows the learned value of $\Phi_{\text{dg}}$, where 'Oracle' indicates the true coefficients obtained by regressing $Y$ on domain-general $Z_{\text{dg}}$ directly. The ideal $\Phi_{\text{dg}}$ recovers $Z_{\text{dg}}$ and puts zero weight on $Z_{\text{spu}}$.
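The qualitative gap between ERM and the oracle can be reproduced with a short simulation of Equation 5. This is a sketch under stated assumptions, not the authors' supplemental code: it assumes 5000 samples per domain, takes the observation to be the latents themselves, $x = [Z_{\text{dg}}, Z_{\text{spu}}]$, and uses pooled least squares as the ERM fit. Under these assumptions the population least-squares weights work out to $2/7$ and $5/7$ (about 0.29 and 0.71), in line with the ERM row of Table 2, while regressing on $Z_{\text{dg}}$ alone gives a coefficient of 1.

```python
import numpy as np

def simulate_domain(sigma_e, n, rng, sigma_y=0.25):
    """Sample (Z_dg, Z_spu, Y) from Equation 5 for one domain."""
    z_dg = rng.normal(0.0, sigma_e, n)
    y = z_dg + rng.normal(0.0, sigma_y, n)
    z_spu = y + rng.normal(0.0, sigma_e, n)
    return z_dg, z_spu, y

rng = np.random.default_rng(0)
parts = [simulate_domain(sigma, 5000, rng) for sigma in (0.1, 0.2)]
z_dg, z_spu, y = (np.concatenate(cols) for cols in zip(*parts))

# ERM-style fit: pooled least squares on both latents.
x = np.column_stack([z_dg, z_spu])
w_erm, *_ = np.linalg.lstsq(x, y, rcond=None)

# Oracle fit: regress Y on the domain-general latent only.
w_oracle, *_ = np.linalg.lstsq(z_dg[:, None], y, rcond=None)
```

The pooled fit places most of its weight on $Z_{\text{spu}}$ because the spurious feature is a less noisy proxy for $Y$ in the training domains, which is exactly the failure mode the TCRI penalty is meant to remove.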

Now, we evaluate the efficacy of our proposed objective on non-simulated datasets.

5.1 Semisynthetic and Real-World Datasets

Algorithms:

We compare our method to baselines corresponding to DAG properties: Empirical Risk Minimization (ERM; Vapnik, 1991), Invariant Risk Minimization (IRM; Arjovsky et al., 2019), Variance Risk Extrapolation (V-REx; Krueger et al., 2021), Meta-Learning for Domain Generalization (MLDG; Li et al., 2018a), Group Distributionally Robust Optimization (GroupDRO; Sagawa et al., 2019), and Information Bottleneck methods (IB_ERM/IB_IRM; Ahuja et al., 2021). Additional baseline methods are provided in Appendix A.

We evaluate our proposed method on the semisynthetic ColoredMNIST (Arjovsky et al., 2019) and real-world Terra Incognita dataset (Beery et al., 2018). Given observed domains $\mathcal{E}_{tr} = \{e : 1, 2, \ldots, E_{tr}\}$, we train on $\mathcal{E}_{tr} \backslash e_i$ and evaluate the model on the unseen domain $e_i$, for each $e_i \in \mathcal{E}_{tr}$.
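The leave-one-domain-out protocol described above can be sketched as a small generator. The function name and the use of the ColoredMNIST flip probabilities as domain identifiers are illustrative assumptions.

```python
def leave_one_domain_out(domains):
    """Yield (train_domains, test_domain) pairs, holding out each domain once."""
    for held_out in domains:
        train = [d for d in domains if d != held_out]
        yield train, held_out

# Hypothetical domain identifiers: the three ColoredMNIST flip probabilities.
splits = list(leave_one_domain_out([0.1, 0.2, 0.9]))
for train, test in splits:
    print(f"train on {train}, evaluate on {test}")
```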

ColoredMNIST:

The ColoredMNIST dataset (Arjovsky et al., 2019) is composed of $7000$ ($2 \times 28 \times 28$, $1$) image and binary-label pairs, where each image shows a hand-written digit. There are three domains with different correlations between image color and label, i.e., the image color is spuriously related to the label by assigning a color to each of the two classes (0: digits 0-4, 1: digits 5-9). The color is then flipped with probabilities $\{0.1, 0.2, 0.9\}$ to create three domains, making the color-label relationship domain-specific because it changes across domains. There is also label flip noise of $0.25$, so we expect that the best accuracy a domain-general model can achieve is 75%, while a non-domain-general model can achieve higher. In this dataset, $Z_{\text{dg}}$ corresponds to the original image, $Z_{\text{spu}}$ the color, $e$ the label-color correlation, $Y$ the image label, and $X$ the observed colored image. This DAG follows the generative process of Figure 2(a) (Arjovsky et al., 2019).
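To make the 75% ceiling concrete, the following sketch simulates only the labels and colors of one domain (no images) under the flip probabilities stated above. It assumes the usual construction in which the color is derived from the noisy label and then flipped; `simulate_domain` and its parameters are hypothetical names, not part of any released code.

```python
import numpy as np

def simulate_domain(flip_prob, n=100_000, label_noise=0.25, seed=0):
    """Toy stand-in for one ColoredMNIST domain: labels and colors only."""
    rng = np.random.default_rng(seed)
    digit_class = rng.integers(0, 2, size=n)         # 0: digits 0-4, 1: digits 5-9
    y = digit_class ^ (rng.random(n) < label_noise)  # noisy observed label
    color = y ^ (rng.random(n) < flip_prob)          # color follows y, then is flipped
    return digit_class, color, y

for p in (0.1, 0.2, 0.9):
    digit_class, color, y = simulate_domain(p)
    shape_acc = np.mean(digit_class == y)  # about 0.75 in every domain
    color_acc = np.mean(color == y)        # about 1 - p, so it varies by domain
    print(f"flip={p}: shape-based {shape_acc:.3f}, color-based {color_acc:.3f}")
```

A predictor that reads the digit class is capped near 75% in every domain, whereas a predictor that reads the color does better in the 0.1 and 0.2 domains and far worse in the 0.9 domain.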

Spurious PACS:

Variables. $X$: images, $Y$: non-urban (elephant, giraffe, horse) vs. urban (dog, guitar, house, person). Domains. {{cartoon, art painting}, {art painting, cartoon}, {photo}} (Li et al., 2017). The photo domain is the same as in the original dataset. In the {cartoon, art painting} domain, urban examples are selected from the original cartoon domain, while non-urban examples are selected from the original art painting domain. In the {art painting, cartoon} domain, urban examples are selected from the original art painting domain, while non-urban examples are selected from the original cartoon domain. This sampling encourages the model to use spurious correlations (domain-related information) to predict the labels; however, since these relationships are flipped between the domains {cartoon, art painting} and {art painting, cartoon}, these predictions will be wrong when generalized to other domains.

Terra Incognita:

The Terra Incognita dataset contains subsets of the Caltech Camera Traps dataset (Beery et al., 2018) defined by Gulrajani and Lopez-Paz (2020). There are four domains representing different locations {L100, L38, L43, L46} of cameras in the American Southwest. There are 9 species of wild animals {bird, bobcat, cat, coyote, dog, opossum, rabbit, raccoon, squirrel} and a ‘no-animal’ class (empty) to be predicted. Like Ahuja et al. (2021), we classify this dataset as following the generative process in Figure 2(c), the Fully Informative Invariant Features (FIIF) setting. Additional details on model architecture, training, and hyperparameters are provided in Appendix 5.

Model Selection.

The standard approach for model selection is accuracy on a training-domain held-out validation set. We find that model selection across hyperparameters using this held-out training-domain validation accuracy often returns non-domain-general models in the ‘hard’ cases. One advantage of our model is that we can do model selection based on the TCRI condition (conditional independence between the two representations) on held-out training-domain validation examples to mitigate this challenge. In the easy case, we expect the empirical risk minimizer to be domain-general, so selecting the best-performing training-domain model is sound; we additionally do this for all baselines (see Appendix A.1 for further discussion). Empirically, this heuristic works in the examples we study in this work. Nevertheless, model selection under distribution shift remains a significant bottleneck for domain generalization.
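One way such a selection rule could look is sketched below: score each candidate by a dependence measure between its two representations within each label value, and keep the candidate with the lowest score. The biased empirical HSIC with Gaussian kernels is used here as the dependence measure (cf. Gretton et al., 2007); the within-label averaging, the fixed bandwidth, and all function names are assumptions for illustration rather than the exact procedure used in the experiments.

```python
import numpy as np

def rbf_kernel(x, sigma=1.0):
    """Gaussian kernel matrix for the rows of x (shape [n, d])."""
    sq_dists = np.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / (2.0 * sigma ** 2))

def hsic(x, y, sigma=1.0):
    """Biased empirical HSIC between paired samples x and y."""
    n = x.shape[0]
    centering = np.eye(n) - np.ones((n, n)) / n
    kx, ky = rbf_kernel(x, sigma), rbf_kernel(y, sigma)
    return float(np.trace(kx @ centering @ ky @ centering)) / (n - 1) ** 2

def conditional_hsic(z_dg, z_spu, labels):
    """Average HSIC within each label value, a simple proxy for dependence given Y."""
    scores = [hsic(z_dg[labels == c], z_spu[labels == c]) for c in np.unique(labels)]
    return float(np.mean(scores))

def select_model(candidates):
    """Return the candidate name with the smallest validation dependence score.

    `candidates` maps a (hypothetical) model name to (z_dg, z_spu, labels)
    computed on held-out training-domain validation examples.
    """
    return min(candidates, key=lambda name: conditional_hsic(*candidates[name]))
```

A candidate whose two representations are nearly copies of each other receives a much larger score than one whose representations share information only through the label, so the latter would be selected.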

5.2 Results and Discussion

Table 3: $\mathcal{E} \backslash e_{test} \rightarrow e_{test}$ (model selection on held-out source domains validation set). The ‘average’ column indicates the average generalization accuracy over all domains, each taken in turn as $e_{test}$; the ‘worst-case’ column indicates the worst generalization accuracy.

| Algorithm | ColoredMNIST average | ColoredMNIST worst-case | Spurious PACS average | Spurious PACS worst-case | Terra Incognita average | Terra Incognita worst-case |
|---|---|---|---|---|---|---|
| ERM | 51.6 ± 0.1 | 10.0 ± 0.1 | 57.2 ± 0.7 | 31.2 ± 1.3 | 44.2 ± 1.8 | 35.1 ± 2.8 |
| IRM | 51.7 ± 0.1 | 9.9 ± 0.1 | 54.7 ± 0.8 | 30.3 ± 0.3 | 38.9 ± 3.7 | 32.6 ± 4.7 |
| GroupDRO | 52.0 ± 0.1 | 9.9 ± 0.1 | 58.5 ± 0.4 | 37.7 ± 0.7 | 47.8 ± 0.9 | 39.9 ± 0.7 |
| VREx | 51.7 ± 0.2 | 10.2 ± 0.0 | 58.8 ± 0.4 | 37.5 ± 1.1 | 45.1 ± 0.4 | 38.1 ± 1.3 |
| IB_ERM | 51.5 ± 0.2 | 10.0 ± 0.1 | 56.3 ± 1.1 | 35.5 ± 0.4 | 46.0 ± 1.4 | 39.3 ± 1.1 |
| IB_IRM | 51.7 ± 0.0 | 9.9 ± 0.0 | 55.9 ± 1.2 | 33.8 ± 2.2 | 37.0 ± 2.8 | 29.6 ± 4.1 |
| TCRI_HSIC | 59.6 ± 1.8 | 45.1 ± 6.7 | 63.4 ± 0.2 | 62.3 ± 0.2 | 49.2 ± 0.3 | 40.4 ± 1.6 |

Table 4: Total Information Criterion: Domain General (DG) and Domain Specific (DS) Accuracies. The DG classifier is shared across all training domains, and the DS classifiers are trained on each domain. The first row indicates the domain from which the held-out examples are sampled, and the second indicates which domain-specific predictor is used. {+90%, +80%, -90%} indicate domains with $\{0.1, 0.2, 0.9\}$ digit label and color correlation, respectively. Column headers below are written as "first row / second row".

| Test Domain | DG Classifier (No DS clf.) / +90% | DG Classifier (No DS clf.) / +80% | DG Classifier (No DS clf.) / -90% | DS Classifier on +90 / +90% | DS Classifier on +90 / +80% | DS Classifier on +90 / -90% | DS Classifier on +80 / +90% | DS Classifier on +80 / +80% | DS Classifier on +80 / -90% | DS Classifier on -90 / +90% | DS Classifier on -90 / +80% | DS Classifier on -90 / -90% |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| +90% | 68.7 | 69.0 | 68.5 | - | 90.1 | 9.8 | - | 79.9 | 20.1 | - | 10.4 | 89.9 |
| +80% | 63.1 | 62.4 | 64.4 | 76.3 | - | 24.3 | 70.0 | - | 30.4 | 24.5 | - | 76.3 |
| -90% | 65.6 | 63.4 | 44.1 | 75.3 | 75.3 | - | 69.2 | 69.5 | - | 29.3 | 26.0 | - |

Worst-domain Accuracy.

A critical implication of domain generality is stability: robustness in worst-domain performance up to domain difficulty. While average accuracy across domains provides some insight into an algorithm’s ability to generalize to new domains, the average hides the variance of performance across domains. Average accuracy can improve while the worst-domain accuracy stays the same or decreases, leading to incorrect conclusions about domain generalization. Additionally, in real-world challenges such as algorithmic fairness where worst-group performance is considered, some metrics of fairness are analogous to achieving domain generalization (Creager et al., 2021).
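The gap between the two summaries is easy to see on the per-domain accuracies reported in Table 7 (ColoredMNIST setting a). The helper below is a minimal sketch; the name `summarize` is an assumption, and the numbers are copied from the ERM and TCRI_HSIC rows of that table.

```python
def summarize(domain_accuracies):
    """Return (average, worst-case) accuracy over held-out domains."""
    values = list(domain_accuracies.values())
    return sum(values) / len(values), min(values)

# Per-domain accuracies from Table 7 (oracle model selection).
erm = {"+90%": 72.8, "+80%": 74.7, "+70%": 73.3, "-90%": 16.3}
tcri = {"+90%": 72.1, "+80%": 73.6, "+70%": 72.6, "-90%": 49.9}

erm_avg, erm_worst = summarize(erm)
tcri_avg, tcri_worst = summarize(tcri)
print(f"ERM:  average {erm_avg:.1f}, worst-case {erm_worst:.1f}")
print(f"TCRI: average {tcri_avg:.1f}, worst-case {tcri_worst:.1f}")
```

ERM is slightly ahead on three of the four domains, yet its worst-case accuracy is far lower, which the average alone understates.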

Results.

On ColoredMNIST, TCRI achieves the highest average and worst-case accuracy among all methods (Table 3). No method recovers the domain-general model’s accuracy of $75\%$. However, TCRI improves on the best baseline by over 7 percentage points in both average and worst-case accuracy. Appendix A.2 shows transfer accuracies with cross-validation on held-out test domain examples (oracle), where TCRI again outperforms all baselines, achieving an average accuracy of 70.4% ± 0.4% and a worst-case accuracy of 65.7% ± 1.5%, showing that regularizing for TCRI gives close to optimal domain-general solutions.

Similarly, for the Spurious-PACS dataset, we observe that TCRI outperforms the baselines. TCRI achieves the highest average accuracy of 63.4% ± 0.2% and worst-case accuracy of 62.3% ± 0.2%, with the next best average, VREx, achieving 58.8% ± 0.4% and 37.5% ± 1.1%, respectively (Table 3). Additionally, for the Terra-Incognita dataset, TCRI achieves the highest average and worst-case accuracies of 49.2% ± 0.3% and 40.4% ± 1.6%, with the next best, GroupDRO, achieving 47.8% ± 0.9% and 39.9% ± 0.7%, respectively.

Appendix A.2 shows transfer accuracies with cross-validation on held-out target domain examples (oracle), where we observe that TCRI also obtains the highest average and worst-case accuracy for Spurious-PACS and Terra Incognita.

Overall, regularizing for TCRI gives the most domain-general solutions compared to our baselines, achieving the highest worst-case accuracy on all benchmarks. Additionally, TCRI achieves the highest average accuracy on ColoredMNIST and Spurious-PACS and the second highest on Terra Incognita, where we expect the empirical risk minimizer to be domain-general.

Additional results are provided in Appendix A.

Table 5: TIC ablation for ColoredMNIST.

| Algorithm | average | worst-case |
|---|---|---|
| TCRI_HSIC (No TIC) | 51.8 ± 5.9 | 27.7 ± 8.9 |
| TCRI_HSIC | 59.6 ± 1.8 | 45.1 ± 6.7 |

The Effect of the Total Information Criterion.

Without the TIC loss term, our proposed method is less effective. Table 5 shows that for ColoredMNIST, the hardest ‘hard’ case we encounter, removing the TIC criterion lowers both average and worst-case accuracy, by 7.8 and 17.4 percentage points, respectively.

Separation of Domain General and Domain Specific Features.

In the case of ColoredMNIST, we can reason about the extent of feature disentanglement from the accuracies achieved by the domain-general and domain-specific predictors. Table 4 shows the extent to which each component of $\Phi$, $\Phi_{\text{dg}}$ and $\Phi_{\text{spu}}$, behaves as expected. For each domain, we observe that the domain-specific predictors’ accuracies follow the same trend as the color-label correlation, indicating that they capture the color-label relationship. The domain-general predictor, however, does not follow such a trend, indicating that it is not using color as the predictor.

For example, when evaluating the domain-specific predictors from the +90% test domain experiment (row +90%) on held-out examples from the +80% training domain (column "DS Classifier on +80"), we find that the +80% domain-specific predictor achieves an accuracy of 79.9%, in line with the 80% one would expect from a predictor that uses a color correlation with the same direction ‘+’. Conversely, the -90% predictor achieves an accuracy of 20.1%, in line with the 20% one would expect from a predictor that uses a color correlation with the opposite direction ‘-’. The -90% domain has the opposite label-color pairing, so a color-based classifier will give the opposite label in any ‘+’ domain.
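The expected values in this example follow from a short calculation, sketched below under the idealized assumption that a domain-specific predictor uses color alone. The function name and argument names are hypothetical.

```python
def color_predictor_accuracy(train_flip, test_flip):
    """Expected accuracy of a purely color-based predictor under ColoredMNIST flips.

    A predictor fit where color mostly agrees with the label (flip < 0.5)
    predicts 'label = color'; one fit where it mostly disagrees (flip > 0.5)
    predicts the opposite.
    """
    agree_rate = 1.0 - test_flip  # P(color == label) in the evaluation domain
    return agree_rate if train_flip < 0.5 else 1.0 - agree_rate

# +80% predictor and -90% predictor, both evaluated on +80% data (flip 0.2).
same_direction = color_predictor_accuracy(train_flip=0.2, test_flip=0.2)
opposite_direction = color_predictor_accuracy(train_flip=0.9, test_flip=0.2)
print(f"same direction: {same_direction:.2f}, opposite: {opposite_direction:.2f}")
```

These idealized values, 0.80 and 0.20, sit next to the measured 79.9% and 20.1% in Table 4.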

Another advantage of this method, exemplified by Table 4, is that if one believes a particular domain is close to one of the training domains, one can opt to use the close domain’s domain-specific predictor and leverage spurious information to improve performance.

On Benchmarking Domain Generalization.

Previous work on benchmarking domain generalization showed that across standard benchmarks, the domain-unaware empirical risk minimizer outperforms or achieves equivalent performance to the state-of-the-art domain generalization methods (Gulrajani and Lopez-Paz, 2020). Additionally, Rosenfeld et al. (2022) give results showing that weak conditions define regimes where the empirical risk minimizer across domains is optimal in both average and worst-case accuracy. Consequently, to accurately evaluate our work and baselines, we focus on settings where it is clear that (i) the empirical risk minimizer fails, (ii) spurious features, as we have defined them, do not generalize across the observed domains, and (iii) there is room for improvement via better domain-general predictions. We discuss this point further in Appendix A.1.

Oracle Transfer Accuracies.

While model selection is an integral part of the machine learning development cycle, it remains a non-trivial challenge when there is a distribution shift. While we have proposed a selection process tailored to our method that can be generalized to other methods with an assumed causal graph, we acknowledge that model selection under distribution shift is still an important open problem. Consequently, we disentangle this challenge from the learning problem and evaluate an algorithm’s capacity to give domain-general solutions independently of model selection. We report experimental results using held-out test-set examples for model selection in Appendix A, Table 6. We find that our method, TCRI_HSIC, also outperforms baselines in this setting.

6 Conclusion and Future Work

We reduce the gap in learning domain-general predictors by leveraging conditional independence properties implied by generative processes to identify domain-general mechanisms. We do this without independent observations of domain-general and spurious mechanisms and show that our framework outperforms other state-of-the-art domain-generalization algorithms on real-world datasets in average and worst-case accuracy across domains. Future work includes further improvements to the framework to fully recover the strict set of domain-general mechanisms and model selection strategies that preserve desired domain-general properties.

Acknowledgements

OS was partially supported by the UIUC Beckman Institute Graduate Research Fellowship, NSF-NRT 1735252. This work is partially supported by the NSF III 2046795, IIS 1909577, CCF 1934986, NIH 1R01MH116226-01A, NIFA award 2020-67021-32799, the Alfred P. Sloan Foundation, and Google Inc.

References

  • Ahuja et al. [2020] Kartik Ahuja, Karthikeyan Shanmugam, Kush Varshney, and Amit Dhurandhar. Invariant risk minimization games. In International Conference on Machine Learning, pages 145–155. PMLR, 2020.
  • Ahuja et al. [2021] Kartik Ahuja, Ethan Caballero, Dinghuai Zhang, Jean-Christophe Gagnon-Audet, Yoshua Bengio, Ioannis Mitliagkas, and Irina Rish. Invariance principle meets information bottleneck for out-of-distribution generalization. Advances in Neural Information Processing Systems, 34:3438–3450, 2021.
  • Arjovsky et al. [2019] Martín Arjovsky, L. Bottou, Ishaan Gulrajani, and David Lopez-Paz. Invariant risk minimization. ArXiv, abs/1907.02893, 2019.
  • Beery et al. [2018] Sara Beery, Grant Van Horn, and Pietro Perona. Recognition in terra incognita. In Proceedings of the European conference on computer vision (ECCV), pages 456–473, 2018.
  • Ben-David et al. [2009] Shai Ben-David, John Blitzer, K. Crammer, A. Kulesza, Fernando C Pereira, and Jennifer Wortman Vaughan. A theory of learning from different domains. Machine Learning, 79:151–175, 2009.
  • Bickel et al. [2009] Steffen Bickel, Michael Brückner, and Tobias Scheffer. Discriminative learning under covariate shift. Journal of Machine Learning Research, 10(9), 2009.
  • Blanchard et al. [2017] Gilles Blanchard, Aniket Anand Deshmukh, Urun Dogan, Gyemin Lee, and Clayton Scott. Domain generalization by marginal transfer learning. arXiv preprint arXiv:1711.07910, 2017.
  • Chen et al. [2016] Xiangli Chen, Mathew Monfort, Anqi Liu, and Brian D Ziebart. Robust covariate shift regression. In Artificial Intelligence and Statistics, pages 1270–1279. PMLR, 2016.
  • Courty et al. [2017] Nicolas Courty, Rémi Flamary, Amaury Habrard, and Alain Rakotomamonjy. Joint distribution optimal transportation for domain adaptation. Advances in Neural Information Processing Systems, 30, 2017.
  • Creager et al. [2021] Elliot Creager, Jörn-Henrik Jacobsen, and Richard Zemel. Environment inference for invariant learning. In International Conference on Machine Learning, pages 2189–2200. PMLR, 2021.
  • des Combes et al. [2020] Rémi Tachet des Combes, Han Zhao, Yu-Xiang Wang, and Geoffrey J. Gordon. Domain adaptation with conditional distribution matching and generalized label shift. ArXiv, abs/2003.04475, 2020.
  • Ganin et al. [2016] Yaroslav Ganin, Evgeniya Ustinova, Hana Ajakan, Pascal Germain, Hugo Larochelle, François Laviolette, Mario Marchand, and Victor Lempitsky. Domain-adversarial training of neural networks. The journal of machine learning research, 17(1):2096–2030, 2016.
  • Gretton et al. [2007] A. Gretton, K. Fukumizu, C. Teo, Le Song, B. Schölkopf, and Alex Smola. A kernel statistical test of independence. In NIPS, 2007.
  • Gretton et al. [2009] Arthur Gretton, Alex Smola, Jiayuan Huang, Marcel Schmittfull, Karsten Borgwardt, and Bernhard Schölkopf. Covariate shift by kernel mean matching. Dataset shift in machine learning, 3(4):5, 2009.
  • Gulrajani and Lopez-Paz [2020] Ishaan Gulrajani and David Lopez-Paz. In search of lost domain generalization. CoRR, abs/2007.01434, 2020. URL https://arxiv.org/abs/2007.01434.
  • He et al. [2016] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 770–778, 2016.
  • Hitchcock and Rédei [2021] Christopher Hitchcock and Miklós Rédei. Reichenbach’s Common Cause Principle. In Edward N. Zalta, editor, The Stanford Encyclopedia of Philosophy. Metaphysics Research Lab, Stanford University, Summer 2021 edition, 2021.
  • Huang et al. [2006] Jiayuan Huang, Arthur Gretton, Karsten Borgwardt, Bernhard Schölkopf, and Alex Smola. Correcting sample selection bias by unlabeled data. Advances in Neural Information Processing Systems, 19, 2006.
  • Kaur et al. [2022] Jivat Neet Kaur, Emre Kiciman, and Amit Sharma. Modeling the data-generating process is necessary for out-of-distribution generalization. arXiv preprint arXiv:2206.07837, 2022.
  • Kirichenko et al. [2022] Polina Kirichenko, Pavel Izmailov, and Andrew Gordon Wilson. Last layer re-training is sufficient for robustness to spurious correlations. arXiv preprint arXiv:2204.02937, 2022.
  • Kpotufe and Martinet [2018] Samory Kpotufe and Guillaume Martinet. Marginal singularity, and the benefits of labels in covariate-shift. In Sébastien Bubeck, Vianney Perchet, and Philippe Rigollet, editors, Proceedings of the 31st Conference On Learning Theory, volume 75 of Proceedings of Machine Learning Research, pages 1882–1886. PMLR, 06–09 Jul 2018. URL https://proceedings.mlr.press/v75/kpotufe18a.html.
  • Krueger et al. [2021] David Krueger, Ethan Caballero, Joern-Henrik Jacobsen, Amy Zhang, Jonathan Binas, Dinghuai Zhang, Remi Le Priol, and Aaron Courville. Out-of-distribution generalization via risk extrapolation (rex). In International Conference on Machine Learning, pages 5815–5826. PMLR, 2021.
  • Li et al. [2017] Da Li, Yongxin Yang, Yi-Zhe Song, and Timothy M Hospedales. Deeper, broader and artier domain generalization. In Proceedings of the IEEE international conference on computer vision, pages 5542–5550, 2017.
  • Li et al. [2018a] Da Li, Yongxin Yang, Yi-Zhe Song, and Timothy M Hospedales. Learning to generalize: Meta-learning for domain generalization. In Thirty-Second AAAI Conference on Artificial Intelligence, 2018a.
  • Li et al. [2018b] Haoliang Li, Sinno Jialin Pan, Shiqi Wang, and Alex C. Kot. Domain generalization with adversarial feature learning. In 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 5400–5409, 2018b. doi: 10.1109/CVPR.2018.00566.
  • Li et al. [2018c] Ya Li, Xinmei Tian, Mingming Gong, Yajing Liu, Tongliang Liu, Kun Zhang, and D. Tao. Deep domain generalization via conditional invariant adversarial networks. In ECCV, 2018c.
  • Lipton et al. [2018] Zachary Chase Lipton, Yu-Xiang Wang, and Alex Smola. Detecting and correcting for label shift with black box predictors. ArXiv, abs/1802.03916, 2018.
  • Liu et al. [2021] Evan Z Liu, Behzad Haghgoo, Annie S Chen, Aditi Raghunathan, Pang Wei Koh, Shiori Sagawa, Percy Liang, and Chelsea Finn. Just train twice: Improving group robustness without training group information. In International Conference on Machine Learning, pages 6781–6792. PMLR, 2021.
  • Long et al. [2015] Mingsheng Long, Yue Cao, Jianmin Wang, and Michael I. Jordan. Learning transferable features with deep adaptation networks. ArXiv, abs/1502.02791, 2015.
  • Long et al. [2016] Mingsheng Long, Han Zhu, Jianmin Wang, and Michael I Jordan. Unsupervised domain adaptation with residual transfer networks. Advances in neural information processing systems, 29, 2016.
  • Makar et al. [2022] Maggie Makar, Ben Packer, Dan Moldovan, Davis Blalock, Yoni Halpern, and Alexander D’Amour. Causally motivated shortcut removal using auxiliary labels. In Gustau Camps-Valls, Francisco J. R. Ruiz, and Isabel Valera, editors, Proceedings of The 25th International Conference on Artificial Intelligence and Statistics, volume 151 of Proceedings of Machine Learning Research, pages 739–766. PMLR, 28–30 Mar 2022. URL https://proceedings.mlr.press/v151/makar22a.html.
  • Muandet et al. [2013] Krikamol Muandet, David Balduzzi, and Bernhard Schölkopf. Domain generalization via invariant feature representation. In International conference on machine learning, pages 10–18. PMLR, 2013.
  • Pearl [2010] J. Pearl. Causal inference. In NIPS Causality: Objectives and Assessment, 2010.
  • Peters et al. [2016] Jonas Peters, Peter Bühlmann, and Nicolai Meinshausen. Causal inference by using invariant prediction: identification and confidence intervals. Journal of the Royal Statistical Society. Series B (Statistical Methodology), pages 947–1012, 2016.
  • Rédei [2002] Miklós Rédei. Reichenbach’s Common Cause Principle and Quantum Correlations, pages 259–270. Springer Netherlands, Dordrecht, 2002. ISBN 978-94-010-0385-8. doi: 10.1007/978-94-010-0385-8_17. URL https://doi.org/10.1007/978-94-010-0385-8_17.
  • Robey et al. [2021] Alexander Robey, George J Pappas, and Hamed Hassani. Model-based domain generalization. Advances in Neural Information Processing Systems, 34:20210–20229, 2021.
  • Rosenfeld et al. [2020] Elan Rosenfeld, Pradeep Ravikumar, and Andrej Risteski. The risks of invariant risk minimization. arXiv preprint arXiv:2010.05761, 2020.
  • Rosenfeld et al. [2022] Elan Rosenfeld, Pradeep Ravikumar, and Andrej Risteski. An online learning approach to interpolation and extrapolation in domain generalization. In International Conference on Artificial Intelligence and Statistics, pages 2641–2657. PMLR, 2022.
  • Sagawa et al. [2019] Shiori Sagawa, Pang Wei Koh, Tatsunori B Hashimoto, and Percy Liang. Distributionally robust neural networks for group shifts: On the importance of regularization for worst-case generalization. arXiv preprint arXiv:1911.08731, 2019.
  • Schneider et al. [2020] Steffen Schneider, Evgenia Rusak, Luisa Eck, Oliver Bringmann, Wieland Brendel, and Matthias Bethge. Improving robustness against common corruptions by covariate shift adaptation. Advances in Neural Information Processing Systems, 33:11539–11551, 2020.
  • Schrouff et al. [2022] Jessica Schrouff, Natalie Harris, Oluwasanmi Koyejo, Ibrahim Alabdulmohsin, Eva Schnider, Krista Opsahl-Ong, Alex Brown, Subhrajit Roy, Diana Mincu, Christina Chen, et al. Maintaining fairness across distribution shift: do we have viable solutions for real-world applications? arXiv preprint arXiv:2202.01034, 2022.
  • Shimodaira [2000] Hidetoshi Shimodaira. Improving predictive inference under covariate shift by weighting the log-likelihood function. Journal of statistical planning and inference, 90(2):227–244, 2000.
  • Sugiyama et al. [2007] Masashi Sugiyama, Shinichi Nakajima, Hisashi Kashima, Paul Buenau, and Motoaki Kawanabe. Direct importance estimation with model selection and its application to covariate shift adaptation. Advances in Neural Information Processing Systems, 20, 2007.
  • Vapnik [1991] Vladimir Vapnik. Principles of risk minimization for learning theory. In NIPS, volume 91, pages 831–840, 1991.
  • Veitch et al. [2021] Victor Veitch, Alexander D’Amour, Steve Yadlowsky, and Jacob Eisenstein. Counterfactual invariance to spurious correlations: Why and how to pass stress tests. arXiv preprint arXiv:2106.00545, 2021.
  • Wang et al. [2022] Haoxiang Wang, Haozhe Si, Bo Li, and Han Zhao. Provable domain generalization via invariant-feature subspace recovery. In ICML, 2022.
  • Zadrozny [2004] Bianca Zadrozny. Learning and evaluating classifiers under sample selection bias. In Proceedings of the twenty-first international conference on Machine learning, page 114, 2004.
  • Zhang et al. [2021] Marvin Zhang, Henrik Marklund, Nikita Dhawan, Abhishek Gupta, Sergey Levine, and Chelsea Finn. Adaptive risk minimization: Learning to adapt to domain shift. Advances in Neural Information Processing Systems, 34, 2021.
  • Zhao et al. [2019] H. Zhao, Rémi Tachet des Combes, Kun Zhang, and Geoffrey J. Gordon. On learning invariant representations for domain adaptation. In ICML, 2019.

Appendix A Additional Results and Discussion

A.1 On Benchmarking Domain Generalization

Table 6: Oracle (model selection on held-out target domain validation set) $\mathcal{E} \backslash e_{test} \rightarrow e_{test}$. The ‘average’ column indicates the average generalization accuracy over all domains, each taken in turn as $e_{test}$; the ‘worst-case’ column indicates the worst generalization accuracy.

| Algorithm | ColoredMNIST average | ColoredMNIST worst-case | Spurious PACS average | Spurious PACS worst-case | Terra Incognita average | Terra Incognita worst-case |
|---|---|---|---|---|---|---|
| ERM | 57.8 ± 0.2 | 38.4 ± 1.4 | 59.2 ± 1.3 | 38.4 ± 1.4 | 52.9 ± 0.8 | 42.0 ± 0.6 |
| IRM | 68.9 ± 1.6 | 62.0 ± 4.9 | 67.5 ± 5.8 | 53.9 ± 6.6 | 42.6 ± 4.0 | 42.7 ± 1.2 |
| GroupDRO | 61.1 ± 1.3 | 37.6 ± 3.6 | 61.8 ± 1.8 | 40.0 ± 1.6 | 50.7 ± 1.0 | 42.7 ± 1.2 |
| VREx | 68.0 ± 2.5 | 59.4 ± 7.3 | 62.8 ± 2.4 | 38.7 ± 0.9 | 43.2 ± 2.0 | 34.9 ± 4.2 |
| IB_ERM | 65.0 ± 0.1 | 50.6 ± 0.3 | 67.3 ± 3.7 | 53.1 ± 8.0 | 49.0 ± 0.3 | 39.9 ± 0.8 |
| IB_IRM | 68.4 ± 1.0 | 58.5 ± 2.8 | 69.0 ± 1.3 | 62.3 ± 0.3 | 32.8 ± 6.6 | 20.4 ± 7.5 |
| TCRI_HSIC | 70.4 ± 0.4 | 65.7 ± 1.5 | 69.5 ± 1.1 | 62.3 ± 0.2 | 51.2 ± 0.1 | 43.0 ± 0.4 |

Oracle Transfer Accuracies.

While model selection is an integral part of the machine learning development cycle, it remains a non-trivial challenge when there is a distribution shift. While we have proposed a selection process tailored to our method that can be generalized to other methods with an assumed causal graph, we acknowledge that model selection under distribution shift is still an important open problem. Consequently, we disentangle this challenge from the learning problem and evaluate an algorithm’s capacity to give domain-general solutions independently of model selection. We report experimental results using held-out test-set examples for model selection in Table 6.

In this case, we find that there is indeed a separation between ERM and some domain-generalization algorithms, suggesting that model selection might be a substantial bottleneck for learning domain-general predictors. Nevertheless, we find that our method, TCRI_HSIC, still outperforms baselines in this setting.

Challenges of Benchmarking Domain Generalization.

We show some results below that illustrate the challenge of accurately evaluating the efficacy of an algorithm for domain generalization. We first note that we expect ERM (naive) to perform poorly in domain generalization tasks, certainly so when we observe worst-case shifts at test time. However, like other works [Gulrajani and Lopez-Paz, 2020], we observe that ERM performs as well as other baselines during transfer on various benchmark datasets. Previous theoretical results [Rosenfeld et al., 2022] suggest that this observation is indicative of properties of the benchmark domains that may be sufficient for ERM to give domain-general solutions, specifically that the distribution (and equivalently the loss) of the target domain can be written as a convex combination of those in the source domains.

To further investigate this, we develop additional experiments motivated by ColoredMNIST [Arjovsky et al., 2019], since its generative process is well understood. We note that in the +90%, +80%, and -90% domains of ColoredMNIST, the -90% domain has the opposite relationship between the spurious correlation and the label, so the use of spurious correlations from {+90%, +80%} generalizes catastrophically to the -90% domain. In this setting, the baseline algorithms we present, including ERM, achieve poor accuracy in the -90% domain while maintaining high accuracy in the +90% and +80% domains. Consequently, we investigate two settings, setting a: observe {+90%, +80%, +70%, -90%} domains, and setting b: observe {+90%, +80%, -80%, -90%} domains; in both, we focus on generalizing to the -90% domain. In setting a, we add another domain with the majority direction in the relationship between spurious correlation and labels. In setting b, we add another domain with the minority direction. Note that in setting a, the closest domain to -90% that can be generated with a convex combination of the other domains still has a ‘+’ correlation between the color and label. In setting b, however, one can generate a domain with a ‘-’ correlation between color and label with a convex combination of the other domains. Thus, we expect the empirical risk minimizer to give domain-general solutions in setting b but not in setting a.
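The convex-combination argument can be checked with a small calculation. The sketch below reduces each domain to a single number, its color flip probability, which is a simplifying assumption: a mixture of domains then has a color-label agreement rate equal to the weighted average of the per-domain rates. The function names are hypothetical.

```python
def reachable_agreement_range(flip_probs):
    """Range of color-label agreement rates attainable by mixing the given domains.

    Mixture weights are non-negative and sum to one, so the attainable rates
    form the interval between the smallest and largest per-domain rate.
    """
    rates = [1.0 - p for p in flip_probs]
    return min(rates), max(rates)

def can_mix_to_negative(flip_probs):
    """True if some mixture has a '-' correlation (agreement rate below 0.5)."""
    low, _ = reachable_agreement_range(flip_probs)
    return low < 0.5

# Source domains when the -90% domain (flip probability 0.9) is held out.
setting_a = can_mix_to_negative([0.1, 0.2, 0.3])  # {+90%, +80%, +70%}
setting_b = can_mix_to_negative([0.1, 0.2, 0.8])  # {+90%, +80%, -80%}
print(f"setting a: {setting_a}, setting b: {setting_b}")
```

Under this simplification, no mixture of the setting a sources has a ‘-’ correlation, while the setting b sources can produce one.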

We use Oracle model selection (held-out target data) to remove the effect of model selection for all methods in the results. In setting a, where we add a domain (+70%), the generalization accuracy to the -90% domain is still very different from the other domains (Table 7).

Table 7: ColoredMNIST setting a. Columns {+90%, +80%, +70%, -90%} indicate domains with $\{0.1, 0.2, 0.3, 0.9\}$ digit label and color correlation, respectively. We report domain accuracies over 3 trials each. We use the oracle selection method (held-out target data). $\mathcal{E} \backslash e_{test} \rightarrow e_{test}$.

| Algorithm | +90% | +80% | +70% | -90% |
|---|---|---|---|---|
| ERM | 72.8 ± 0.3 | 74.7 ± 0.3 | 73.3 ± 0.1 | 16.3 ± 1.5 |
| IRM | 49.0 ± 0.1 | 54.2 ± 2.0 | 50.3 ± 0.3 | 43.8 ± 2.8 |
| GroupDRO | 71.0 ± 0.6 | 72.2 ± 0.3 | 70.7 ± 0.9 | 36.4 ± 4.2 |
| VREx | 74.1 ± 1.3 | 72.6 ± 0.5 | 72.1 ± 0.5 | 19.5 ± 5.5 |
| TCRI_HSIC | 72.1 ± 1.5 | 73.6 ± 0.4 | 72.6 ± 0.4 | 49.9 ± 0.3 |

However, in setting b, where we add a domain (-80%), we observe that the generalization accuracy to the -90% domain is on par with the other domains (Table 8).

Table 8: ColoredMNIST setting b. Columns {+90%, +80%, -80%, -90%} indicate domains with $\{0.1, 0.2, 0.8, 0.9\}$ digit label and color correlation, respectively. We report the average domain accuracies over 3 trials each. We use the oracle selection method (held-out target data). $\mathcal{E} \backslash e_{test} \rightarrow e_{test}$.

| Algorithm | +90% | +80% | -80% | -90% |
|---|---|---|---|---|
| ERM | 58.4 ± 1.3 | 67.0 ± 0.5 | 64.2 ± 2.0 | 52.6 ± 3.2 |
| IRM | 56.7 ± 3.3 | 56.6 ± 2.8 | 51.6 ± 0.7 | 51.7 ± 0.7 |
| GroupDRO | 69.7 ± 0.8 | 71.7 ± 0.3 | 72.0 ± 0.2 | 71.4 ± 1.9 |
| VREx | 67.4 ± 1.9 | 70.4 ± 0.1 | 71.2 ± 0.2 | 59.4 ± 4.3 |
| TCRI_HSIC | 62.2 ± 4.4 | 70.0 ± 1.3 | 67.9 ± 1.4 | 65.4 ± 2.8 |

This illustrates the challenge of accurately evaluating an algorithm’s ability to give domain-general predictions. We note that it is generally difficult to distinguish between setting a and setting b. The primary signature we see is some consistency between the empirical risk minimizer and the other baselines. Gulrajani and Lopez-Paz [2020] observe a similar trend for standard benchmarks for domain generalization. Hence, we focus our empirical evaluations in this work on settings where we know that the ERM solution fails by design.

A.2 ColoredMNIST

ColoredMNIST: The ColoredMNIST dataset [Arjovsky et al., 2019] is composed of $7000$ ($2 \times 28 \times 28$, $1$) images of a hand-written digit and binary-label pairs. There are three domains with different correlations between image color and label, i.e., the image color is spuriously related to the label by assigning a color to each of the two classes (0: digits 0-4, 1: digits 5-9). The color is then flipped with probabilities $\{0.1, 0.2, 0.9\}$ to create three domains, making the color-label relationship domain-specific because it changes across domains. There is also label flip noise of $0.25$, so we expect that the best accuracy a domain-general model can achieve is 75%, while a non-domain-general model can achieve higher. In this dataset, $Z_{\text{dg}}$ corresponds to the original image, $Z_{\text{spu}}$ the color, $e$ the label-color correlation, $Y$ the image label, and $X$ the observed colored image. This DAG follows the generative process of Figure 2(a).
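The label and color sampling described above can be sketched as follows. This is a minimal illustration, not the authors' data pipeline: it assumes the color is assigned from the noisy label and then flipped, and it only produces (label, color) pairs rather than images. The function name and arguments are hypothetical.

```python
import random


def sample_label_and_color(digits, flip_prob, label_noise=0.25, seed=0):
    """Return (noisy binary label, color id) pairs for one domain.

    digits: iterable of digit classes 0-9.
    flip_prob: domain-specific probability that the color disagrees with the label.
    """
    rng = random.Random(seed)
    pairs = []
    for digit in digits:
        label = int(digit >= 5)            # class 0: digits 0-4, class 1: digits 5-9
        if rng.random() < label_noise:     # label flip noise (0.25 in the text)
            label = 1 - label
        color = label                      # one color per class ...
        if rng.random() < flip_prob:       # ... then flipped with the domain's probability
            color = 1 - color
        pairs.append((label, color))
    return pairs


# One list of pairs per domain; 0.1, 0.2, 0.9 are the flip probabilities from the text.
domains = {p: sample_label_and_color(range(10), flip_prob=p) for p in (0.1, 0.2, 0.9)}
```

With flip probability 0.9, color and label agree in roughly 10% of samples, which is why a color-based predictor learned on the other two domains transfers poorly.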

Table 9: ColoredMNIST Hyperparameters. Additional hyperparameters are provided in https://github.com/olawalesalaudeen/tcri.

| Algorithm | Hyperparameter | Default | Random Distribution |
|---|---|---|---|
| All | Learning Rate | $10^{-3}$ | $10^{\text{Uniform}(-4.5, -2.5)}$ |
| | Batch Size | 64 | $2^{\text{Uniform}(3, 9)}$ |
| TCRI | $\beta$ penalty weight | 100 | $10^{\text{Uniform}(-1, 5)}$ |
| | annealing steps | 500 | $10^{\text{Uniform}(2.5, 5)}$ |

Table 10: MNIST ConvNet architecture. All convolutions use 3×3 kernels and "same" padding.

| # | Layer |
|---|---|
| 1 | Conv2D (in=d, out=64) |
| 2 | ReLU |
| 3 | GroupNorm (groups=8) |
| 4 | Conv2D (in=64, out=128, stride=2) |
| 5 | ReLU |
| 6 | GroupNorm (groups=8) |
| 7 | Conv2D (in=128, out=128) |
| 8 | ReLU |
| 9 | GroupNorm (groups=8) |
| 10 | Conv2D (in=128, out=128) |
| 11 | ReLU |
| 12 | GroupNorm (groups=8) |
| 13 | Global average-pooling |

We use MNIST-ConvNet [Gulrajani and Lopez-Paz, 2020] backbones for the MNIST datasets (Table 10). Both $\Phi_{\text{dg}}$ and $\Phi_{\text{spu}}$ are linear layers of size $128 \times 128$ that are appended to the backbone. The predictors (classification hyperplanes) $\theta_c,\, \{\theta_1,\, \theta_2\}$ are also parameterized to be linear and appended to the $\Phi_{\text{dg}}$ and $\Phi$, respectively.
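To see why the appended linear layers take 128-dimensional inputs, the following sketch traces tensor shapes through the layers of Table 10. It assumes a padding of 1 for every 3×3 convolution (including the stride-2 layer) and a $2 \times 28 \times 28$ input; both are our assumptions for illustration, and the helper names are hypothetical.

```python
def conv_out(size, kernel=3, stride=1, padding=1):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def mnist_convnet_shapes(in_channels, height=28, width=28):
    """Shapes (C, H, W) after each conv block, then the pooled feature shape."""
    conv_specs = [(64, 1), (128, 2), (128, 1), (128, 1)]  # (out_channels, stride)
    shapes = []
    channels, h, w = in_channels, height, width
    for out_channels, stride in conv_specs:
        h = conv_out(h, stride=stride)
        w = conv_out(w, stride=stride)
        channels = out_channels
        shapes.append((channels, h, w))  # ReLU and GroupNorm do not change the shape
    shapes.append((channels,))           # global average-pooling removes H and W
    return shapes


for shape in mnist_convnet_shapes(in_channels=2):
    print(shape)
```

The final pooled feature has 128 entries, matching the $128 \times 128$ size of the layers appended to the backbone.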

We do a random search to select hyperparameters using the same scheme as Gulrajani and Lopez-Paz [2020] (https://github.com/facebookresearch/DomainBed). We sample 25 hyperparameter configurations with 5 random restarts each to generate error bars.
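As a rough illustration of the random search, the sketch below draws configurations from the distributions listed in the ColoredMNIST hyperparameter table. It is not the DomainBed implementation; the dictionary keys and the function name are hypothetical, and truncating to integers for batch size and annealing steps is our assumption.

```python
import random


def sample_hparams(rng, annealing_range=(2.5, 5.0)):
    """Draw one hyperparameter configuration from log-uniform distributions."""
    return {
        "lr": 10 ** rng.uniform(-4.5, -2.5),
        "batch_size": int(2 ** rng.uniform(3, 9)),
        "tcri_beta": 10 ** rng.uniform(-1, 5),
        "tcri_annealing_steps": int(10 ** rng.uniform(*annealing_range)),
    }


# 25 configurations, one seeded generator per configuration for reproducibility.
configs = [sample_hparams(random.Random(seed)) for seed in range(25)]
print(configs[0])
```

Each configuration would then be trained with several random restarts, and the spread across restarts gives the reported error bars.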

We show transfer accuracies with both source and target domain validation for model selection in Tables 11-12. We find that TCRI outperforms all baselines in average and worst-case accuracy.

Table 11: ColoredMNIST Transfer Accuracy (model selection on held-out source validation set). Columns {+90%, +80%, -90%} indicate domains with $\{0.1, 0.2, 0.9\}$ digit label and color correlation, respectively. The last three columns report domain accuracy statistics. $\mathcal{E} \backslash e_{test} \rightarrow e_{test}$.

| Algorithm | +90% | +80% | -90% | Avg | Std | Min |
|---|---|---|---|---|---|---|
| ERM | 71.6 ± 0.3 | 73.1 ± 0.1 | 10.0 ± 0.1 | 51.6 ± 0.1 | 29.4 ± 0.1 | 10.0 ± 0.1 |
| IRM | 72.1 ± 0.1 | 73.0 ± 0.3 | 9.9 ± 0.1 | 51.7 ± 0.1 | 29.5 ± 0.1 | 9.9 ± 0.1 |
| GroupDRO | 72.6 ± 0.2 | 73.4 ± 0.2 | 9.9 ± 0.1 | 52.0 ± 0.1 | 29.8 ± 0.1 | 9.9 ± 0.1 |
| VREx | 72.2 ± 0.2 | 72.7 ± 0.3 | 10.2 ± 0.0 | 51.7 ± 0.2 | 29.3 ± 0.1 | 10.2 ± 0.0 |
| IB_ERM | 71.0 ± 0.4 | 73.4 ± 0.3 | 10.0 ± 0.1 | 51.5 ± 0.2 | 29.4 ± 0.1 | 10.0 ± 0.1 |
| IB_IRM | 71.7 ± 0.2 | 73.4 ± 0.1 | 9.9 ± 0.0 | 51.7 ± 0.0 | 29.5 ± 0.0 | 9.9 ± 0.0 |
| TCRI_HSIC | 67.2 ± 2.3 | 65.6 ± 3.4 | 45.9 ± 6.9 | 59.6 ± 1.8 | 11.4 ± 3.3 | 45.1 ± 6.7 |

Table 12: Oracle ColoredMNIST Transfer Accuracy (model selection on held-out target validation set accuracy). Columns {+90%, +80%, -90%} indicate domains with $\{0.1, 0.2, 0.9\}$ digit label and color correlation, respectively. $\mathcal{E} \backslash e_{test} \rightarrow e_{test}$.

| Algorithm | ColoredMNIST average | ColoredMNIST worst-case | Spurious PACS average | Spurious PACS worst-case | Terra Incognita average | Terra Incognita worst-case |
|---|---|---|---|---|---|---|
| ERM | 57.8 ± 0.2 | 38.4 ± 1.4 | 59.2 ± 1.3 | 38.4 ± 1.4 | 52.9 ± 0.8 | 42.0 ± 0.6 |
| IRM | 68.9 ± 1.6 | 62.0 ± 4.9 | 67.5 ± 5.8 | 53.9 ± 6.6 | 42.6 ± 4.0 | 42.7 ± 1.2 |
| GroupDRO | 61.1 ± 1.3 | 37.6 ± 3.6 | 61.8 ± 1.8 | 40.0 ± 1.6 | 50.7 ± 1.0 | 42.7 ± 1.2 |
| VREx | 68.0 ± 2.5 | 59.4 ± 7.3 | 62.8 ± 2.4 | 38.7 ± 0.9 | 43.2 ± 2.0 | 34.9 ± 4.2 |
| IB_ERM | 65.0 ± 0.1 | 50.6 ± 0.3 | 67.3 ± 3.7 | 53.1 ± 8.0 | 49.0 ± 0.3 | 39.9 ± 0.8 |
| IB_IRM | 68.4 ± 1.0 | 58.5 ± 2.8 | 69.0 ± 1.3 | 62.3 ± 0.3 | 32.8 ± 6.6 | 20.4 ± 7.5 |
| TCRI_HSIC | 70.4 ± 0.4 | 65.7 ± 1.5 | 69.5 ± 1.1 | 62.3 ± 0.2 | 51.2 ± 0.1 | 43.0 ± 0.4 |

A.3 Spurious PACS

Spurious–PACS. Variables. $X$: images, $Y$: non-urban (elephant, giraffe, horse) vs. urban (dog, guitar, house, person). Domains. {{cartoon, art painting}, {art painting, cartoon}, {photo}} [Li et al., 2017]. The photo domain is the same as in the original dataset. In the {cartoon, art painting} domain, urban examples are selected from the original cartoon domain, while non-urban examples are selected from the original art painting domain. In the {art painting, cartoon} domain, urban examples are selected from the original art painting domain, while non-urban examples are selected from the original cartoon domain. This sampling encourages the model to use spurious correlations (domain-related information) to predict the labels; however, since these relationships are flipped between the domains {cartoon, art painting} and {art painting, cartoon}, these predictions will be wrong when generalized to other domains.
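The domain construction above amounts to a lookup on (original PACS style, class). The sketch below is one way to express it; the style strings, the short domain names "C x A", "A x C", and "P", and the function names are our own labels for illustration. The original PACS sketch style is not used in this construction, so it maps to no domain here.

```python
NON_URBAN = {"elephant", "giraffe", "horse"}
URBAN = {"dog", "guitar", "house", "person"}


def binary_label(pacs_class):
    """0 = non-urban, 1 = urban."""
    if pacs_class in NON_URBAN:
        return 0
    if pacs_class in URBAN:
        return 1
    raise ValueError(f"unknown PACS class: {pacs_class}")


def spurious_pacs_domain(style, pacs_class):
    """Map an original PACS example to its Spurious-PACS domain, or None if unused."""
    urban = binary_label(pacs_class) == 1
    if style == "photo":
        return "P"
    if style == "cartoon":
        # {cartoon, art painting}: urban from cartoon; {art painting, cartoon}: non-urban from cartoon
        return "C x A" if urban else "A x C"
    if style == "art_painting":
        return "A x C" if urban else "C x A"
    return None


print(spurious_pacs_domain("cartoon", "dog"), spurious_pacs_domain("art_painting", "dog"))
```

Within each mixed domain, style perfectly predicts the label, and the style-label pairing is reversed in the other mixed domain.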

Table 13: Spurious PACS Hyperparameters. Additional hyperparameters provided in https://github.com/olawalesalaudeen/tcri.

| Algorithm | Hyperparameter | Default | Range |
|---|---|---|---|
| All | Learning Rate | $10^{-3}$ | $10^{\text{Uniform}(-4.5, -2.5)}$ |
| | Batch Size | 64 | $2^{\text{Uniform}(3, 9)}$ |
| TCRI | $\beta$ penalty weight | 100 | $10^{\text{Uniform}(-1, 5)}$ |
| | annealing steps | 500 | $10^{\text{Uniform}(2.5, 5)}$ |

We use a ResNet-50 backbone [He et al., 2016]. $\Phi_{\text{dg}}$ and $\Phi_{\text{spu}}$ are linear layers of size $2048 \times 2048$ that are appended to the backbone. The predictors (classification hyperplanes) $\theta_c, \{\theta_1, \theta_2, \theta_3\}$ are linear and appended to $\Phi_{\text{dg}}$ and $\Phi$ layers, respectively.

Hyperparameters:

We do a random search to select hyperparameters using the same scheme as Gulrajani and Lopez-Paz [2020] (https://github.com/facebookresearch/DomainBed). We sample 5 hyperparameter configurations with 3 random restarts each to generate error bars.

We show transfer accuracies with both source and target domain validation for model selection in Tables 14-15. We find that TCRI outperforms all baselines in average and worst-case accuracy.

Table 14: Spurious–PACS Transfer Accuracy (model selection on held-out source validation set). $\mathcal{E} \backslash e_{test} \rightarrow e_{test}$.

Table 15: Oracle Spurious–PACS Transfer Accuracy (model selection on held-out target validation set). The first three columns are domains; the last three report domain accuracy statistics. $\mathcal{E} \backslash e_{test} \rightarrow e_{test}$.

| Algorithm | C x A | A x C | P | mean | std | min |
|---|---|---|---|---|---|---|
| ERM | 38.4 ± 1.4 | 43.4 ± 1.9 | 95.9 ± 0.6 | 59.2 | 26.0 | 38.4 |
| IRM | 62.8 ± 0.1 | 53.9 ± 6.6 | 85.8 ± 8.2 | 67.5 | 13.4 | 53.9 |
| GroupDRO | 40.0 ± 1.6 | 49.7 ± 2.9 | 95.7 ± 0.6 | 61.8 | 24.3 | 40.0 |
| VREx | 55.8 ± 5.5 | 38.7 ± 0.9 | 93.8 ± 0.8 | 62.8 | 23.0 | 38.7 |
| IB_ERM | 53.1 ± 8.0 | 55.4 ± 5.7 | 93.5 ± 1.8 | 67.3 | 18.5 | 53.1 |
| IB_IRM | 62.8 ± 0.1 | 62.3 ± 0.3 | 81.8 ± 7.0 | 69.0 | 9.1 | 62.3 |
| TCRI_HSIC | 64.0 ± 0.7 | 62.3 ± 0.2 | 82.4 ± 5.7 | 69.5 | 9.1 | 62.3 |
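The mean, std, and min columns summarize the per-domain accuracies of each row. The sketch below recomputes them for the ERM row; using the population standard deviation is our assumption, chosen because it appears to reproduce the reported values, and the function name is hypothetical.

```python
import statistics


def domain_accuracy_statistics(domain_accuracies):
    """Mean, population standard deviation, and minimum over domains, rounded to 0.1."""
    return (
        round(statistics.mean(domain_accuracies), 1),
        round(statistics.pstdev(domain_accuracies), 1),
        min(domain_accuracies),
    )


# ERM row of the oracle Spurious-PACS table: C x A, A x C, P.
erm = [38.4, 43.4, 95.9]
print(domain_accuracy_statistics(erm))
```

A low std together with a high min indicates accuracy that is stable across domains, which is the behavior expected of a domain-general predictor.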

A.4 Terra Incognita

The Terra Incognita dataset contains subsets of the Caltech Camera Traps dataset [Beery et al., 2018] defined by [Gulrajani and Lopez-Paz, 2020]. Four domains represent different locations {L100, L38, L43, L46} of cameras in the American Southwest. There are 10 different species of wild animals {bird, bobcat, cat, coyote, dog, empty, opossum, rabbit, raccoon, squirrel} (classes) to be predicted. Like Ahuja et al. [2021], we classify this dataset as following the generative process in Figure 2(c), the Fully Informative Invariant Features (FIIF) setting.

Table 16: Terra Incognita Hyperparameters. Additional hyperparameters provided in https://github.com/olawalesalaudeen/tcri.

| Algorithm | Hyperparameter | Default | Range |
|---|---|---|---|
| All | Learning Rate | $10^{-3}$ | $10^{\text{Uniform}(-4.5, -2.5)}$ |
| | Batch Size | 64 | $2^{\text{Uniform}(3, 9)}$ |
| TCRI | $\beta$ penalty weight | 100 | $10^{\text{Uniform}(-1, 5)}$ |
| | annealing steps | 500 | $10^{\text{Uniform}(0, 4)}$ |

We use a ResNet-50 backbone [He et al., 2016]. $\Phi_{\text{dg}}$ and $\Phi_{\text{spu}}$ are linear layers of size $2048 \times 2048$ that are appended to the backbone. The predictors (classification hyperplanes) $\theta_c, \{\theta_1, \theta_2, \theta_3, \theta_4\}$ are linear and appended to $\Phi_{\text{dg}}$ and $\Phi$ layers, respectively.

Hyperparameters:

We do a random search to select hyperparameters using the same scheme as Gulrajani and Lopez-Paz [2020] (https://github.com/facebookresearch/DomainBed). We sample 5 hyperparameter configurations with 3 random restarts each to generate error bars.

We show transfer accuracies with both source and target domain validation for model selection in Tables 17-18. We find that TCRI outperforms all baselines except ERM on average and outperforms all baselines in worst-case accuracy.

Table 17: Terra Incognita Transfer Accuracy (model selection on held-out source validation set). The first four columns are domains; the last three report domain accuracy statistics. $\mathcal{E} \backslash e_{test} \rightarrow e_{test}$.

| Algorithm | L100 | L38 | L43 | L46 | Avg | Std | Min |
|---|---|---|---|---|---|---|---|
| ERM | 43.6 ± 3.9 | 45.2 ± 0.6 | 53.0 ± 1.2 | 35.1 ± 2.8 | 44.2 ± 1.8 | 6.8 ± 1.0 | 35.1 ± 2.8 |
| IRM | 43.9 ± 3.3 | 35.7 ± 4.0 | 37.7 ± 7.8 | 38.3 ± 2.4 | 38.9 ± 3.7 | 5.4 ± 1.8 | 32.6 ± 4.7 |
| GroupDRO | 53.8 ± 4.6 | 40.5 ± 0.7 | 55.3 ± 1.5 | 41.8 ± 1.1 | 47.8 ± 0.9 | 7.7 ± 0.9 | 39.9 ± 0.7 |
| VREx | 48.8 ± 2.0 | 38.1 ± 1.3 | 54.4 ± 0.6 | 39.0 ± 1.4 | 45.1 ± 0.4 | 7.0 ± 0.9 | 38.1 ± 1.3 |
| IB_ERM | 46.1 ± 4.5 | 40.7 ± 0.7 | 55.2 ± 0.8 | 42.2 ± 1.1 | 46.0 ± 1.4 | 6.4 ± 0.8 | 39.3 ± 1.1 |
| IB_IRM | 39.7 ± 7.3 | 40.8 ± 2.3 | 34.7 ± 4.3 | 32.9 ± 2.6 | 37.0 ± 2.8 | 6.7 ± 1.3 | 29.6 ± 4.1 |
| TCRI_HSIC | 54.6 ± 2.4 | 48.6 ± 2.0 | 53.2 ± 1.0 | 40.4 ± 1.6 | 49.2 ± 0.3 | 6.1 ± 1.1 | 40.4 ± 1.6 |

Table 18: Oracle Terra Incognita Transfer Accuracy (model selection on held-out target validation set). The first four columns are domains; the last three report domain accuracy statistics. $\mathcal{E} \backslash e_{test} \rightarrow e_{test}$.

| Algorithm | L100 | L38 | L43 | L46 | Avg | Std | Min |
|---|---|---|---|---|---|---|---|
| ERM | 58.5 ± 1.8 | 52.0 ± 1.3 | 59.2 ± 0.2 | 42.0 ± 0.6 | 52.9 ± 0.8 | 7.0 ± 0.5 | 42.0 ± 0.6 |
| IRM | 53.0 ± 0.9 | 48.0 ± 1.8 | 36.3 ± 9.6 | 33.2 ± 3.9 | 42.6 ± 4.0 | 9.6 ± 1.7 | 30.8 ± 5.4 |
| GroupDRO | 56.2 ± 3.0 | 45.2 ± 2.3 | 58.0 ± 0.2 | 43.3 ± 0.7 | 50.7 ± 1.0 | 6.9 ± 0.9 | 42.7 ± 1.2 |
| VREx | 43.2 ± 1.5 | 49.3 ± 1.2 | 41.5 ± 7.8 | 38.9 ± 1.1 | 43.2 ± 2.0 | 6.5 ± 1.8 | 34.9 ± 4.2 |
| IB_ERM | 55.6 ± 1.7 | 47.2 ± 1.1 | 53.4 ± 0.7 | 39.9 ± 0.8 | 49.0 ± 0.3 | 6.4 ± 0.5 | 39.9 ± 0.8 |
| IB_IRM | 40.2 ± 8.2 | 31.9 ± 11.8 | 29.4 ± 4.4 | 29.7 ± 3.8 | 32.8 ± 6.6 | 8.2 ± 1.0 | 20.4 ± 7.5 |
| TCRI_HSIC | 57.7 ± 1.8 | 50.1 ± 1.8 | 54.1 ± 0.6 | 43.0 ± 0.4 | 51.2 ± 0.1 | 5.8 ± 0.7 | 43.0 ± 0.4 |

Appendix B DAGs

Figure 4: Partial Ancestral Graph (PAG) over the nodes $e$, $Z_{\text{dg}}$, $Z_{\text{spu}}$, $Y$, and $X$. Dashed edges indicate that the edge may or may not exist. The combination of $Y \rightarrow Z_{\text{dg}} \rightarrow Z_{\text{spu}}$, and $Y \rightarrow Z_{\text{dg}}$, $e \rightarrow Z_{\text{dg}}$ is not allowed.

B.1 On Valid DAGs

We consider other edges that could be introduced to Figure 4 where $Z_{\text{dg}} \not\perp\!\!\!\perp Z_{\text{spu}} \,|\, Y, e$, $Z_{\text{spu}} \not\perp\!\!\!\perp Y \,|\, Z_{\text{dg}}$, or that are not included in Figure 5. We then show that these edges either make the problem intractable or require new assumptions about the generative process. Note that we do not discuss edges that induce a cycle, since these are invalid.

  (i) $e - Y$: we cannot have a direct edge in either direction between $e$ and $Y$; otherwise, $Y$ is always dependent on $e$ and the problem becomes intractable.

  (ii) $e - X$: we cannot have a direct edge between $e$ and $X$ without making additional parametric assumptions about the role of $e$ in $\Gamma(Z_{\text{dg}}, Z_{\text{spu}}, e)$.

  (iii) $Z_{\text{spu}} \rightarrow Y$: we cannot have both $Z_{\text{dg}} \rightarrow Y$ and $Z_{\text{spu}} \rightarrow Y$, since then both mechanisms are domain general. WLOG, we let $Z_{\text{spu}}$ denote the features that never have domain-general mechanisms to $Y$.

  (iv) $Y \rightarrow Z_{\text{dg}} \rightarrow Z_{\text{spu}}$ and $Y \rightarrow Z_{\text{dg}} \leftarrow e$: conditioning on $Z_{\text{dg}}$ and/or $Z_{\text{spu}}$ makes $Y$ dependent on $e$, so $Y$ is always dependent on $e$ and the problem becomes intractable.
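The collider argument in the last case can be checked mechanically with a d-separation test. The sketch below is a small self-contained implementation using the standard ancestral moral graph criterion; the edge list encodes only the disallowed combination $Y \rightarrow Z_{\text{dg}} \leftarrow e$, $Z_{\text{dg}} \rightarrow Z_{\text{spu}}$, and the function names are hypothetical.

```python
from collections import deque


def ancestors_of(parents, nodes):
    """Return the given nodes together with all of their ancestors."""
    seen, stack = set(nodes), list(nodes)
    while stack:
        for parent in parents.get(stack.pop(), ()):
            if parent not in seen:
                seen.add(parent)
                stack.append(parent)
    return seen


def d_separated(edges, x, y, given=()):
    """True if x and y are d-separated given `given` in the DAG with directed `edges`."""
    given = set(given)
    parents = {}
    for u, v in edges:
        parents.setdefault(v, set()).add(u)
    keep = ancestors_of(parents, {x, y} | given)
    # Moralize the ancestral subgraph: link each node to its parents and marry co-parents.
    neighbors = {node: set() for node in keep}
    for child in keep:
        pars = sorted(parents.get(child, ()))
        for p in pars:
            neighbors[p].add(child)
            neighbors[child].add(p)
        for i, p in enumerate(pars):
            for q in pars[i + 1:]:
                neighbors[p].add(q)
                neighbors[q].add(p)
    # x and y are d-separated iff no path connects them once conditioned nodes are removed.
    seen, queue = {x}, deque([x])
    while queue:
        node = queue.popleft()
        if node == y:
            return False
        for nxt in neighbors[node]:
            if nxt not in seen and nxt not in given:
                seen.add(nxt)
                queue.append(nxt)
    return True


edges = [("Y", "Z_dg"), ("e", "Z_dg"), ("Z_dg", "Z_spu")]
print(d_separated(edges, "Y", "e"))
print(d_separated(edges, "Y", "e", given={"Z_dg"}))
print(d_separated(edges, "Y", "e", given={"Z_spu"}))
```

In this graph, $Y$ and $e$ are marginally d-separated, but conditioning on the collider $Z_{\text{dg}}$ or on its descendant $Z_{\text{spu}}$ opens the path between them.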

Figure 5: Generative Processes, panels (a), (b), and (c), each over the nodes $e$, $Z_{\text{dg}}$, $Z_{\text{spu}}$, $Y$, and $X$. Graphical model depicting the structure of our data-generating process; shaded nodes indicate observed variables. $X$ represents the observed features, $Y$ represents observed targets, and $e$ represents domain influences. There is an explicit separation of domain-general $Z_{\text{dg}}$ and domain-specific $Z_{\text{spu}}$ features combined to generate observed $X$. Dashed edges indicate the possibility of an edge.
Table 19: Generative Processes and Sufficient Conditions for Domain-Generality
Graphs in Figure 5
(a) (b) (c)
$Z_{\text{dg}} \perp\!\!\!\perp Z_{\text{spu}} \,|\, \{Y, e\}$
Identifying $Z_{\text{dg}}$ is necessary
Table 20: Generative Processes and Sufficient Algorithms
Graphs in Figure 5
(a) (b) (c)
Solved by ERM
Solved by TCRI

B.2 Fully Informative Invariant Features

We briefly summarize Ahuja et al. [2021]’s results on minimax-optimality of Empirical Risk Minimization in the Fully Informative Invariant Features setting (their Lemma 4). First, we informally state their assumptions.

  • Assumption 2: Linear structural equation model.

  • Assumptions 3-4: Bounded features.

  • Assumption 8: $w_{\text{dg}}$ partitions $\mathcal{Z}$ up to noise $\eta_Y$.

These assumptions are implied by our Assumption 4.1.

B.2.1 Proof of Sufficiency of ERM [Ahuja et al., 2021]

If Assumptions 2, 4, and 8 hold, then there exists a classifier that puts a non-zero weight on the spurious feature and continues to be Bayes optimal in all the training environments.

Proof.

Choose an arbitrary non-zero vector $\gamma$ and derive a bound on the margin of $(w_{\text{dg}},\, \gamma)$, where $w_{\text{dg}}$ is the true (optimal) linear predictor of $Y$ from $Z_{\text{dg}}$. Recall domain-general and domain-specific features $z_{\text{dg}} \in \mathcal{Z}_{\text{dg}},\, z_{\text{spu}} \in \mathcal{Z}_{\text{spu}}$, respectively. Let $y^{*} = \mathrm{sign}(w_{\text{dg}} \cdot z_{\text{dg}})$. The margin of $(w_{\text{dg}},\, \gamma)$ at point $(z_{\text{dg}},\, z_{\text{spu}})$ with respect to $y^{*}$ is defined as: $y^{*}(w_{\text{dg}} \cdot z_{\text{dg}}) + y^{*}(\gamma \cdot z_{\text{spu}})$.

Using the Cauchy-Schwarz inequality, we get $|y^{*}(\gamma \cdot z_{\text{spu}})| = |\gamma \cdot z_{\text{spu}}| \leq \|\gamma\| \, \|z_{\text{spu}}\|$.

Since $Z_{\text{spu}}$ is bounded, one can set $\gamma$ sufficiently small to control $y^{*}(\gamma \cdot Z_{\text{spu}})$. If $\|\gamma\| \leq \frac{c}{2 z^{\sup}}$, then $|\gamma \cdot z_{\text{spu}}| \leq \frac{c}{2}$, where $z^{\sup}$ satisfies $\|z\| \leq z^{\sup} \; \forall z \in \mathcal{Z}_{\text{spu}}$. From Assumption 8, $\exists\, c > 0$ s.t. $y^{*}(w_{\text{dg}} \cdot z_{\text{dg}}) \geq c$.

Using $|\gamma \cdot z_{\text{spu}}| \leq \frac{c}{2}$, the margin becomes $y^{*}(w_{\text{dg}} \cdot z_{\text{dg}}) + y^{*}(\gamma \cdot z_{\text{spu}}) \geq c - |\gamma \cdot z_{\text{spu}}| \geq \frac{c}{2}$.

From the above equation, it follows that $\mathrm{sign}\big((w_{\text{dg}}, \gamma) \cdot (z_{\text{dg}}, z_{\text{spu}})\big) = \mathrm{sign}\big((w_{\text{dg}}, 0) \cdot (z_{\text{dg}}, z_{\text{spu}})\big) \; \forall z_{\text{dg}} \in \mathcal{Z}_{\text{dg}},\, z_{\text{spu}} \in \mathcal{Z}_{\text{spu}}$.
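The margin argument can be sanity-checked numerically. The sketch below uses made-up two-dimensional features, a margin of $c = 0.1$, and spurious features bounded in $[-1, 1]^2$ (so $z^{\sup} = \sqrt{2}$); all of these values and names are illustrative assumptions, not quantities from the proof.

```python
import math
import random


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def same_sign_everywhere(w_dg, gamma, points):
    """True if adding gamma . z_spu never changes the sign of w_dg . z_dg."""
    return all(
        (dot(w_dg, z_dg) >= 0) == (dot(w_dg, z_dg) + dot(gamma, z_spu) >= 0)
        for z_dg, z_spu in points
    )


rng = random.Random(0)
w_dg = (1.0, -1.0)
c = 0.1                  # margin: |w_dg . z_dg| >= c for every sampled point
z_sup = math.sqrt(2.0)   # bound on the norm of z_spu in [-1, 1]^2

points = []
while len(points) < 1000:
    z_dg = (rng.uniform(-1, 1), rng.uniform(-1, 1))
    if abs(dot(w_dg, z_dg)) >= c:  # keep only points that respect the margin
        z_spu = (rng.uniform(-1, 1), rng.uniform(-1, 1))
        points.append((z_dg, z_spu))

scale = c / (2 * z_sup)
gamma_small = (scale / math.sqrt(2), scale / math.sqrt(2))  # norm equals c / (2 z_sup)
print(same_sign_everywhere(w_dg, gamma_small, points))

# A hand-built point where a large gamma flips the prediction.
flip_point = [((0.2, 0.0), (-1.0, 0.0))]
print(same_sign_everywhere((1.0, 0.0), (1.0, 0.0), flip_point))
```

With the small $\gamma$, the spurious term is at most $c/2$ in magnitude, so it cannot overturn a margin of at least $c$; the hand-built point shows that the guarantee is lost once $\|\gamma\|$ exceeds the bound.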

Now, this condition is used to compute the error of a spurious classifier, i.e., one based on $(w_{\text{dg}},\gamma)$. Define $g_{\text{spu}}=I\circ(w_{\text{dg}},\gamma)\circ\Gamma^{-1}$, where $I(\cdot)$ is an indicator function that returns 1 if its input is $\geq 0$ and 0 otherwise. The error achieved by $g_{\text{spu}}$ is

$$\begin{aligned}
R^{e}(g_{\text{spu}}) &=\mathbb{E}\Big[Y^{e}\oplus I\big((w_{\text{dg}},\gamma)\cdot(z_{\text{dg}},z_{\text{spu}})\big)\Big]\\
&=\mathbb{E}\Big[I\big((w_{\text{dg}},0)\cdot(z_{\text{dg}},z_{\text{spu}})\big)\oplus\eta_{y}\oplus I\big((w_{\text{dg}},\gamma)\cdot(z_{\text{dg}},z_{\text{spu}})\big)\Big]\\
&=\mathbb{E}[\eta_{y}].
\end{aligned}$$

The error achieved by $g_{\text{spu}}$ is then due to the noise in observed $Y$ and is, therefore, optimal in all domains. ∎

It follows from above that since $g_{\text{spu}}$ is Bayes optimal in every domain, it is also the empirical risk minimizer (ERM) as it minimizes the sum of risks across training domains.
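The argument above can be checked numerically. The following is a minimal sketch, not the authors' experimental setup: the margin `c`, the bound `z_sup`, the noise rate, the sampling distributions, and all variable names are assumptions chosen for illustration. It builds data with a margin of at least `c` on the domain-general features, sets `gamma = c / (2 * z_sup)`, and compares the spurious classifier with the domain-general one.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100_000
c = 0.5           # margin on the domain-general features (assumed value)
z_sup = 3.0       # bound on |z_spu| (assumed value)
noise_rate = 0.1  # P(eta_y = 1), the label-noise rate (assumed value)

# Domain-general features: first coordinate is at least c away from 0,
# so |w_dg . z_dg| >= c for w_dg = (1, 0).
w_dg = np.array([1.0, 0.0])
side = rng.choice([-1.0, 1.0], size=n)
z_dg = np.column_stack([side * rng.uniform(c, 2.0, size=n),
                        rng.uniform(-2.0, 2.0, size=n)])

# Noisy label: clean label XOR Bernoulli noise eta_y.
y_clean = (z_dg @ w_dg >= 0).astype(int)
eta = (rng.random(n) < noise_rate).astype(int)
y = y_clean ^ eta

# Bounded spurious feature that depends on the observed label.
z_spu = (2 * y - 1) * rng.uniform(0.0, z_sup, size=n)

# Spurious weight small enough that |gamma * z_spu| <= c / 2.
gamma = c / (2 * z_sup)

pred_dg = (z_dg @ w_dg >= 0).astype(int)
pred_spu = (z_dg @ w_dg + gamma * z_spu >= 0).astype(int)

err_spu = np.mean(pred_spu != y)
print("predictions identical:", np.array_equal(pred_spu, pred_dg))
print("error of g_spu:", err_spu, "| empirical noise rate:", eta.mean())
```

In this sketch the two classifiers agree on every sample, so the error of the spurious classifier equals the empirical rate of the label noise, mirroring $R^{e}(g_{\text{spu}})=\mathbb{E}[\eta_{y}]$.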

Appendix C Proof of Proposition 4.5

Assume that $\Phi_{\text{dg}}(X)$ and $\Phi_{\text{spu}}(X)$ are correlated with $Y$. Given Assumptions 4.1-4.2 and a representation $\Phi=\Phi_{\text{dg}}\oplus\Phi_{\text{spu}}$ that satisfies TIC, $\Phi_{\text{dg}}(X)=Z_{\text{dg}}\iff\Phi$ satisfies TCRI.

Proof.

‘only if’. Assume that $\Phi_{\text{dg}}(X)=Z_{\text{dg}}$. By the Total Information Criterion, we have that $\Phi_{\text{spu}}(X)=Z_{\text{spu}}$. We observe the following paths from $Z_{\text{dg}}$ to $Z_{\text{spu}}$: (i) $Z_{\text{dg}}\rightarrow Y\rightarrow Z_{\text{spu}}$, (ii) $Z_{\text{dg}}\leftarrow e\rightarrow Z_{\text{spu}}$, and (iii) $Z_{\text{dg}}\rightarrow X\leftarrow Z_{\text{spu}}$. Conditioning on $Y,e$ blocks both path (i) and path (ii); path (iii) contains a collider ($Z_{\text{dg}}$ and $Z_{\text{spu}}$ are common causes of $X$), so this path is blocked when $X$ is not in the conditioning set.
So, $Z_{\text{spu}}\perp\!\!\!\perp Z_{\text{dg}}\,|\,Y,e$ and therefore $\Phi_{\text{dg}}(X)\perp\!\!\!\perp\Phi_{\text{spu}}(X)\,|\,Y,e$, which completes this direction.
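The d-separation argument for this direction can be illustrated with a small simulation. This is a hedged sketch under assumptions that are not part of the proof: a linear-Gaussian structural model with hypothetical coefficients, a continuous stand-in for the domain variable $e$, and partial correlation as the dependence measure (zero partial correlation coincides with conditional independence only in the jointly Gaussian case). The helper `partial_corr` is defined here for illustration.

```python
import numpy as np

def partial_corr(a, b, controls):
    """Correlation of a and b after linearly regressing out the controls."""
    C = np.column_stack([np.ones(len(a))] + list(controls))
    res_a = a - C @ np.linalg.lstsq(C, a, rcond=None)[0]
    res_b = b - C @ np.linalg.lstsq(C, b, rcond=None)[0]
    return np.corrcoef(res_a, res_b)[0, 1]

rng = np.random.default_rng(0)
n = 200_000

# Hypothetical linear-Gaussian model following the three paths in the proof.
e = rng.normal(size=n)                          # continuous stand-in for the domain
z_dg = 0.8 * e + rng.normal(size=n)             # e -> Z_dg
y = 1.0 * z_dg + rng.normal(size=n)             # Z_dg -> Y
z_spu = 0.7 * y + 0.5 * e + rng.normal(size=n)  # Y -> Z_spu and e -> Z_spu
x = z_dg + z_spu + rng.normal(size=n)           # Z_dg -> X <- Z_spu (collider)

rho_marginal = np.corrcoef(z_dg, z_spu)[0, 1]
rho_given_y_e = partial_corr(z_dg, z_spu, [y, e])
rho_given_y_e_x = partial_corr(z_dg, z_spu, [y, e, x])

print("marginal correlation:       ", rho_marginal)
print("partial corr given Y, e:    ", rho_given_y_e)
print("partial corr given Y, e, X: ", rho_given_y_e_x)
```

Under these assumptions, the dependence between the two latent blocks vanishes (up to sampling error) once $Y$ and $e$ are controlled for, and reappears when the collider $X$ is added to the conditioning set.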

‘if’. Assume that $\Phi$ satisfies TCRI. We proceed by contradiction. Let $\Phi=[\Phi_{\text{dg}};\Phi_{\text{spu}}]$. We consider the following scenarios for $\Phi_{\text{dg}}\neq Z_{\text{dg}}$.

Scenario 1 (causal aggregation): Assume that $\Phi_{\text{dg}}(X)\subset Z_{\text{dg}}$. From TIC, we have that $Z_{\text{dg}}^{\dagger}\subset\Phi_{\text{spu}}(X)$, where $Z_{\text{dg}}^{\dagger}\subset Z_{\text{dg}}$ is the subset of $Z_{\text{dg}}$ not captured by $\Phi_{\text{dg}}$. Since $\Phi_{\text{dg}}(X)$ and $Z_{\text{dg}}^{\dagger}$ are both subsets of $Z_{\text{dg}}$, both are causes of $Y$, so $Y$ is a collider between them and conditioning on $Y$ opens this path. Hence $\Phi_{\text{dg}}(X)\not\perp\!\!\!\perp\Phi_{\text{spu}}(X)\,|\,Y,e$, violating TCRI and giving a contradiction. So, $Z_{\text{dg}}\subset\Phi_{\text{dg}}(X)$.
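The collider effect behind Scenario 1 can be seen in a small worked example. This is a hypothetical linear-Gaussian illustration, not part of the original proof: within a fixed domain $e$, let $Z_{1}$ stand for the part of $Z_{\text{dg}}$ captured by $\Phi_{\text{dg}}$ and $Z_{2}$ for $Z_{\text{dg}}^{\dagger}$, and suppose $Y=Z_{1}+Z_{2}+\varepsilon$ with $Z_{1},Z_{2},\varepsilon$ independent standard normal variables.

$$\begin{aligned}
\operatorname{Cov}(Z_{1},Z_{2}\,|\,e) &= 0,\\
\operatorname{Cov}(Z_{1},Z_{2}\,|\,Y,e) &= \operatorname{Cov}(Z_{1},Z_{2}\,|\,e)-\frac{\operatorname{Cov}(Z_{1},Y\,|\,e)\operatorname{Cov}(Y,Z_{2}\,|\,e)}{\operatorname{Var}(Y\,|\,e)} = 0-\frac{1\cdot 1}{3}=-\frac{1}{3}\neq 0.
\end{aligned}$$

Two causes of $Y$ that are independent within a domain therefore become dependent once $Y$ is conditioned on, which is the dependence that violates TCRI in this scenario.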

Scenario 2 (anticausal exclusion): Assume that $\Phi_{\text{dg}}(X)\subset Z_{\text{spu}}$. From TIC, we have that $Z_{\text{spu}}^{\dagger}\subset\Phi_{\text{spu}}(X)$, where $Z_{\text{spu}}^{\dagger}\subset Z_{\text{spu}}$ is the subset of $Z_{\text{spu}}$ not captured by $\Phi_{\text{dg}}$. From Assumption 4.2 (faithfulness), we have that $\Phi_{\text{dg}}(X)\not\perp\!\!\!\perp\Phi_{\text{spu}}(X)\,|\,Y,e$, violating TCRI and giving a contradiction. So, $Z_{\text{spu}}\not\subset\Phi_{\text{dg}}(X)$.

Combining scenarios 1-2, it follows that $\Phi_{\text{dg}}(X)=Z_{\text{dg}}$. ∎