Causality Pursuit from Heterogeneous Environments via Neural Adversarial Invariance Learning

Yihong Gu1      Cong Fang2      Peter Bühlmann3      Jianqing Fan1
1Department of Operations Research and Financial Engineering, Princeton University
2School of Intelligence Science and Technology, Peking University
3Seminar for Statistics, ETH Zürich
(This version: June 30, 2024)
Abstract

Pursuing causality from data is a fundamental problem in scientific discovery, treatment intervention, and transfer learning. This paper introduces a novel algorithmic method for addressing nonparametric invariance and causality learning in regression models across multiple environments, where the joint distribution of response variables and covariates varies, but the conditional expectations of outcome given an unknown set of quasi-causal variables are invariant. The challenge of finding such an unknown set of quasi-causal or invariant variables is compounded by the presence of endogenous variables that have heterogeneous effects across different environments; including even one of them in the regression makes the estimation inconsistent. The proposed Focused Adversarial Invariance Regularization (FAIR) framework utilizes a minimax optimization approach that overcomes this barrier, driving regression models toward prediction-invariant solutions through adversarial testing. Leveraging the representation power of neural networks, FAIR neural networks (FAIR-NN) are introduced for causality pursuit. It is shown that FAIR-NN can find the invariant variables and quasi-causal variables under a minimal identification condition and that the resulting procedure is adaptive to low-dimensional composition structures in a non-asymptotic analysis. Under a structural causal model, variables identified by FAIR-NN represent pragmatic causality and provably align with exact causal mechanisms under conditions of sufficient heterogeneity. Computationally, FAIR-NN employs a novel Gumbel approximation with decreasing temperature and a stochastic gradient descent ascent algorithm. The procedures are demonstrated using simulated and real-data examples.

Keywords: Adversarial Estimation, Causal Discovery, Conditional Moment Restriction, Gumbel Approximation, Invariance, Neural Networks.

1 Introduction

A fundamental problem in statistics and machine learning is to predict the response variable $Y$ from explanatory covariates $X\in\mathbb{R}^{d}$ using collected data. The objective often centers on estimating the regression function $m_{0}(x)=\mathbb{E}[Y|X=x]$, which minimizes the population $L_{2}$ risk $\mathsf{R}(m)=\int|y-m(x)|^{2}\mu_{0}(dx,dy)$, starting from the pioneering work on least squares by Legendre (1805) and Gauss (1809). In the age of data, the problem of sample-efficient estimation of $m_{0}$ has been studied extensively. Many structural methods attempt to exploit low-dimensional structure such as sparsity, low-rankness, and additivity, and design optimal methods tailored to the assumed structure (Hastie et al., 2009; Wainwright, 2019; Fan et al., 2020). However, these methods lack scalable applicability and suffer from model misspecification due to their reliance on imposed structures. As an alternative, algorithmic methods (Breiman, 2001) such as neural networks can adapt to the low-dimensional structure efficiently (Schmidt-Hieber, 2020; Fan & Gu, 2024) without supervision of the function structure. This property endows them with universal applicability across various tasks and data.

Despite many celebrated efforts in the efficient estimation of $m_{0}$ or its variants such as the quantile function, the ultimate goal is to utilize observations to fit a model capable of making decent predictions on unseen data, elucidating the causal relationships among variables, and guiding decision-making in real-world scenarios. We instinctively regard $m_{0}$ as such a target function for achieving decent prediction and causal attribution. However, this can be flawed: $m_{0}$ can produce unstable predictions on unseen data and risk false scientific conclusions in numerous cases. Consider a simple thought experiment where we aim to classify an object in a picture as either a cow $(Y=1)$ or a camel $(Y=0)$ using two provided features $X_{1}$ (body shape) and $X_{2}$ (background color). In the data we collected from $\mu_{0}$, the cows usually appear on green grass, while camels often stay on yellow sand. Consequently, the conditional expectation $m_{0}(x_{1},x_{2})=\mathbb{E}_{\mu_{0}}[Y|X_{1}=x_{1},X_{2}=x_{2}]$ would be heavily dependent on $x_{2}$. Such a model is problematic both for prediction and attribution. Its application in a setting with a different background, such as a zoo, would lead to unreliable predictions. Furthermore, attributing the determination of an object to the background surrounding it also contradicts our understanding of causality. In the above case, we may prefer $m_{\star}(x)=\mathbb{E}[Y|X_{1}=x_{1}]$ for prediction and attribution, as we know the causal mechanisms.

We refer to the above problem as the “curse of endogeneity”: the conditional expectation of the residual for the (causal) function of potential interest, $m_{\star}$, is not zero given all the explanatory variables, i.e., $\mathbb{E}[Y-m_{\star}(X)|X]\neq 0$, leading to a misalignment between $m_{0}$ and $m_{\star}$, i.e., $m_{0}(X)-m_{\star}(X)\neq 0$. Hence, traditional regression techniques for estimating $m_{0}$ yield an unsatisfactory solution.

Causal inference methods offer structural remedies to the curse of endogeneity. These methods are structural in that they are tailored to pre-assumed, task-specific, and untestable identification conditions or cause-effect knowledge. This prior knowledge can be formally encoded in the potential outcome (Rubin, 1974) or structural causal model (Glymour et al., 2016) framework, and it fully shapes the “causality skeleton” that exactly determines the causal estimand of interest through some statistical estimand; the “association flesh” of the latter can then be estimated via structural or algorithmic regression techniques. Examples include estimating the average treatment effect (Robins et al., 1994) and conditional average treatment effects (Athey et al., 2019; Kennedy et al., 2024) under the unconfoundedness condition. The reliance of these methods on prior knowledge limits their scalable use, exposes them to severe model misspecification, and prevents their conclusions from going beyond hindsight, because it is impossible to falsify (Popper, 2005) these assumptions using data.

This paper aims to answer the following fundamental question:

Can we design methods that can algorithmically circumvent the “curse of endogeneity” without the supervision of cause-effect knowledge?    (Q)

Without prior causal structural knowledge, we leverage the principle by which humans understand causality: a causal association occurs consistently in the past, now, and (potentially) the future, or more broadly, in diverse environments. In other words, we pursue a data-driven or data-shaped causality that is invariant across diverse environments; this is essentially what one can pursue based only on observed data without prior knowledge. Hence we do not differentiate the concepts of invariance and (data-driven) causality in this paper. Leveraging the invariance principle, we propose a unified and algorithmic framework for causality pursuit that is robust to model misspecification based on data from multiple environments. Though the proposed data-driven causality is conceptually different from previous knowledge-based causality that pre-assumes the ground truth, these two types of causality can coincide when the heterogeneity of environments is sufficient.

1.1 The Canonical Model under Study

Let us revisit the thought experiment from the perspective of a hyper-intelligent alien, Alice. Alice knows nothing about cows and camels except for 1000 images with annotated labels highly associated with the background, for example, $r=90\%$ cows/camels on grass/sand. It is impossible for her to know, given this limited information, that the background cannot determine the object. In other words, both $X_{1}$ and $X_{2}$ can be regarded as causal out of pragmatic considerations. However, if she receives another set of 1000 labeled images, where $r=70\%$ cows/camels on grass/sand, she might begin to question the causal role of the background: the emerging evidence of the varying association between $X_{2}$ and $Y$ falsifies the hypothesis that $X_{2}$ is causal if she believes that causality persists across diverse environments.
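Alice's two batches of images can be mimicked with a small simulation. The sketch below is only illustrative: the function names, the binary encoding of the features, and the 95% accuracy of the body-shape feature (`shape_accuracy`) are assumptions made here, not part of the thought experiment.

```python
import numpy as np

def sample_images(n, r, rng, shape_accuracy=0.95):
    """Draw n labeled images; r is the fraction whose background matches the label."""
    y = rng.integers(0, 2, size=n)                 # 1 = cow, 0 = camel
    shape_ok = rng.random(n) < shape_accuracy      # assumed reliability of body shape
    x1 = np.where(shape_ok, y, 1 - y)              # body shape: 1 = cow-like
    background_ok = rng.random(n) < r
    x2 = np.where(background_ok, y, 1 - y)         # background: 1 = grass, 0 = sand
    return x1, x2, y

def cow_rate_given(feature, y):
    """Empirical P(Y = 1 | feature = 1)."""
    return y[feature == 1].mean()

rng = np.random.default_rng(0)
for r in (0.9, 0.7):
    x1, x2, y = sample_images(1000, r, rng)
    print(f"r={r}: P(cow | cow-like shape)={cow_rate_given(x1, y):.2f}, "
          f"P(cow | grass)={cow_rate_given(x2, y):.2f}")
```

In this toy setup the association between body shape and label is the same in both batches, whereas the association between background and label tracks $r$, which is the kind of evidence that lets Alice question the role of $X_{2}$.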

When there is no supervision of the cause-effect relationship, observations from heterogeneous sources are essential. We consider the following multi-environment regression problem that mimics human causality learning. Let $\mathcal{E}$ be the set of sources/environments. For each environment $e\in\mathcal{E}$, we observe $n$ i.i.d. data $(X_{1}^{(e)},Y_{1}^{(e)}),\ldots,(X_{n}^{(e)},Y_{n}^{(e)})\sim\mu^{(e)}$, where $\mu^{(e)}$, the joint distribution of $(X^{(e)},Y^{(e)})$, satisfies

$$Y^{(e)}=m^{\star}(X_{S^{\star}}^{(e)})+\varepsilon^{(e)}\qquad\text{with}\qquad\mathbb{E}[\varepsilon^{(e)}|X_{S^{\star}}^{(e)}]\equiv 0. \tag{1.1}$$

Here $S^{\star}$, the unknown true important variable set, and $m^{\star}:\mathbb{R}^{|S^{\star}|}\to\mathbb{R}$, the target regression function, are both invariant across different environments, but the joint distributions $\mu^{(e)}$ can vary. We aim to learn the set of quasi-causal variables $S^{\star}$ and estimate the invariant regression function $m^{\star}$ using data $\{\{(X_{i}^{(e)},Y_{i}^{(e)})\}_{i=1}^{n}\}_{e\in\mathcal{E}}$ from $|\mathcal{E}|$ heterogeneous environments. The same $n$ is used in the problem formulation only for expository simplicity; the extension to varying $n^{(e)}$ is straightforward. We refer to the above problem as nonparametric invariance pursuit or nonparametric causality pursuit interchangeably, since based on the data alone, without prior knowledge, we cannot differentiate these concepts.

Here, we temporarily refrain from causal discussions. Under particular scenarios, such a problem can be instantiated as causal discovery in the Structural Causal Model (SCM) framework (Peters et al., 2016) and as transfer learning with a more realistic assumption (Rojas-Carulla et al., 2018); see the details in Section A.1. We offer in Section 3 a rigorous and comprehensive interpretation of what $S^{\star}$ is in the SCM with interventions on $X$. It is also worth noting that model (1.1) requires invariance only in the first moment, rather than the full distributional invariance typically required for causal discovery (Peters et al., 2016), i.e., $\varepsilon^{(e)}\sim F_{\varepsilon}$ independent of $X_{S^{\star}}^{(e)}$. The first-moment condition is more realistic and allows for between-environment heteroscedastic errors.

It is important to note that the standard nonparametric regression generally diverges from our target $m^{\star}$, i.e., $\mathbb{E}[Y^{(e)}|X^{(e)}=x]\neq m^{\star}(x_{S^{\star}})$. This discrepancy arises because $\mathbb{E}[\varepsilon^{(e)}|X^{(e)}]\neq 0$. Such a “curse of endogeneity” problem is the main challenge we need to address. Including even one endogenous spurious variable in the regression function, for example, the background color $X_{2}$ in the above thought experiment, yields an inconsistent estimate of $m^{\star}$. Thus, it is essential to design an algorithm that eliminates all endogenous spurious variables.
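The discrepancy can be seen numerically in a toy linear instance of (1.1). The following sketch assumes a hypothetical three-variable structure $X_{1}\to Y\to X_{2}$ with $m^{\star}(x_{1})=2x_{1}$, where the strength `s` of the effect of $Y$ on $X_{2}$ varies across environments; all names and constants are illustrative choices, not taken from the paper.

```python
import numpy as np

def simulate_env(n, s, rng):
    """One environment of a toy SCM X1 -> Y -> X2 (assumed for illustration)."""
    x1 = rng.normal(size=n)
    y = 2.0 * x1 + rng.normal(size=n)      # invariant part: m*(x1) = 2 * x1, S* = {X1}
    x2 = s * y + rng.normal(size=n)        # endogenous child of Y; s changes with e
    return np.column_stack([x1, x2]), y

def least_squares(X, y):
    """Ordinary least squares without intercept."""
    coef, *_ = np.linalg.lstsq(X, y, rcond=None)
    return coef

rng = np.random.default_rng(0)
for s in (0.5, 2.0):
    X, y = simulate_env(50_000, s, rng)
    print(f"s={s}: all covariates -> {np.round(least_squares(X, y), 2)}, "
          f"X1 only -> {np.round(least_squares(X[:, :1], y), 2)}")
```

In this design the population least-squares coefficient on $X_{1}$ when $X_{2}$ is included equals $2/(1+s^{2})$, so it changes with the environment and never equals 2, whereas regressing on $X_{1}$ alone recovers the invariant coefficient in every environment.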

1.2 Our Algorithmic Remedy: FAIR Estimation

This paper proposes a unified estimation framework: the Focused Adversarial Invariance Regularized (FAIR) estimator. It regularizes the user-specified risk loss $\ell(y,v)$ by a novel regularizer. Specifically, the FAIR estimator is the solution of the following minimax optimization program

$$\min_{g\in\mathcal{G}}\max_{f^{(e)}\in\mathcal{F}_{S_{g}},\forall e\in\mathcal{E}}\underbrace{\sum_{e\in\mathcal{E}}\mathbb{E}_{\mu^{(e)}}\left[\ell(Y,g(X))\right]}_{\mathsf{R}(g)}+\gamma\underbrace{\sum_{e\in\mathcal{E}}\mathbb{E}_{\mu^{(e)}}\left[\{Y-g(X)\}f^{(e)}(X)-\{f^{(e)}(X)\}^{2}/2\right]}_{\mathsf{J}(g,\{f^{(e)}\}_{e\in\mathcal{E}})}. \tag{1.2}$$

Here $\ell(\cdot,\cdot)$ is a loss whose population solution leads to the conditional expectation, $\gamma>0$ is the regularization hyper-parameter to be determined, and $(\mathcal{G},\mathcal{F})$ are function classes specified by the user satisfying $\mathcal{G}\subseteq\mathcal{F}$. The first part is the risk minimization, and the second component is the test of exogeneity of the variables $S_{g}=\mathrm{supp}(g)$ used by the regression function $g$, where $\mathcal{F}_{S_{g}}=\{f\in\mathcal{F}:f(x)=h(x_{S_{g}})\text{ for some }h:\mathbb{R}^{|S_{g}|}\to\mathbb{R}\}$ is the testing function class for the prediction functions in $\mathcal{G}$ that only “focuses” on the variables $S_{g}$ that $g$ uses. Two useful choices for $(\mathcal{G},\mathcal{F})$ are the linear and square-integrable function classes, which correspond respectively to linear models and nonparametric regression models; see Section 4.1 for additional details. Note that the second component is nonnegative after maximization, as seen by comparing with $f^{(e)}=0$, so the penalty is nonnegative. For the empirical counterpart, we solve a similar minimax optimization program that substitutes $\mathbb{E}_{\mu^{(e)}}[\cdot]$ with the corresponding sample means.
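The empirical counterpart can be sketched as a plain function of the fitted values. This is a minimal illustration that assumes squared loss for $\ell$ and takes the predictions of $g$ and of the environment-specific testing functions $f^{(e)}$ as given arrays; the function name and the argument layout are hypothetical, and the helper does not enforce that each $f^{(e)}$ depends only on $X_{S_{g}}$.

```python
import numpy as np

def empirical_fair_objective(envs, gamma):
    """Sample version of R(g) + gamma * J(g, {f^(e)}) with squared loss (assumed).

    envs: list of (y, g_pred, f_pred) triples of 1-D arrays, one per environment.
    Returns (objective, risk, penalty).
    """
    risk, penalty = 0.0, 0.0
    for y, g_pred, f_pred in envs:
        resid = y - g_pred
        risk += np.mean(resid ** 2)                              # sample mean of loss
        penalty += np.mean(resid * f_pred - 0.5 * f_pred ** 2)   # sample mean in J
    return risk + gamma * penalty, risk, penalty

# Toy check with one environment: f = 0 gives zero penalty, while taking f equal to
# the residual gives half the mean squared residual.
y = np.array([1.0, 2.0, 3.0])
g_pred = np.zeros(3)
print(empirical_fair_objective([(y, g_pred, np.zeros(3))], gamma=2.0))
print(empirical_fair_objective([(y, g_pred, y - g_pred)], gamma=2.0))
```

In an actual fit the outer minimization runs over $g$ and the inner maximization over the testing functions; the helper above only evaluates the objective at one candidate pair.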

To see why such a FAIR penalty works, let us consider the nonparametric regression setting in which $\mathcal{F}=\{f:\mathbb{E}_{\mu^{(e)}}[f^{2}(X_{S_{g}})]<\infty\}$. By conditioning on $X_{S_{g}}$, for $f^{(e)}\in\mathcal{F}_{S_{g}}$, we have

$$\mathbb{E}_{\mu^{(e)}}\left[\{Y-g(X)\}f^{(e)}(X)\right]=\mathbb{E}_{\mu^{(e)}}\left[\left\{\mathbb{E}_{\mu^{(e)}}[Y|X_{S_{g}}]-g(X)\right\}f^{(e)}(X)\right].$$

Then, the supremum in (1.2) can be explicitly found and the objective now becomes

$$\min_{g\in\mathcal{G}}\mathsf{R}(g)+\gamma\cdot\mathsf{J}^{\star}(g)\qquad\text{with}\qquad\mathsf{J}^{\star}(g)=\frac{1}{2}\sum_{e\in\mathcal{E}}\mathbb{E}_{\mu^{(e)}}\left[\left|g(X)-\mathbb{E}_{\mu^{(e)}}[Y|X_{S_{g}}]\right|^{2}\right]. \tag{1.3}$$

Therefore, $g(X)=m^{\star}(X_{S^{\star}})$ is a minimax solution.
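The step from (1.2) to (1.3) can be spelled out by completing the square. The following is a sketch for a fixed $g$ and environment $e$ in the square-integrable setting; the shorthand $r_{g}^{(e)}$ is introduced here only for illustration.

$$\begin{aligned}
r_{g}^{(e)}(X) &:= \mathbb{E}_{\mu^{(e)}}[Y|X_{S_{g}}]-g(X),\\
\mathbb{E}_{\mu^{(e)}}\left[\{Y-g(X)\}f^{(e)}(X)-\tfrac{1}{2}\{f^{(e)}(X)\}^{2}\right]
&=\mathbb{E}_{\mu^{(e)}}\left[r_{g}^{(e)}(X)f^{(e)}(X)-\tfrac{1}{2}\{f^{(e)}(X)\}^{2}\right]\\
&=\tfrac{1}{2}\mathbb{E}_{\mu^{(e)}}\left[\{r_{g}^{(e)}(X)\}^{2}\right]-\tfrac{1}{2}\mathbb{E}_{\mu^{(e)}}\left[\{f^{(e)}(X)-r_{g}^{(e)}(X)\}^{2}\right].
\end{aligned}$$

Since $g$ depends on $X$ only through $X_{S_{g}}$, so does $r_{g}^{(e)}$, and the last expression is maximized over $\mathcal{F}_{S_{g}}$ at $f^{(e)}=r_{g}^{(e)}$, with maximal value $\frac{1}{2}\mathbb{E}_{\mu^{(e)}}[|g(X)-\mathbb{E}_{\mu^{(e)}}[Y|X_{S_{g}}]|^{2}]$. Summing over $e\in\mathcal{E}$ gives $\mathsf{J}^{\star}(g)$. Under (1.1), $r_{g}^{(e)}\equiv 0$ for $g=m^{\star}$ in every environment, so $\mathsf{J}^{\star}(m^{\star})=0$.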

To motivate (1.2), let us first consider the additional constraint $\mathbb{E}_{\mu^{(e)}}[f^{(e)}(X_{S_{g}})^{2}]=1$, so that the first part of the second component in (1.2) is essentially the maximal correlation between the residual $\{Y-g(X_{S_{g}})\}$ and the testing functions $f^{(e)}(X_{S_{g}})$. Hence, the criterion (1.2) seeks a set of variables $X_{S_{g}}$ that is as exogenous (as weakly correlated with the residuals) as possible for all testing functions in $\mathcal{F}_{S_{g}}$. By the Lagrange multiplier method, the constrained maximization problem can be written as

$$\max_{f^{(e)}\in\mathcal{F}_{S_{g}}}\mathbb{E}_{\mu^{(e)}}\left[\{Y-g(X)\}f^{(e)}(X)-\lambda\{f^{(e)}(X)\}^{2}\right].$$

Choosing the multiplier $\lambda=1/2$ (justified in the above paragraph) gives rise to the objective function (1.2).

The FAIR penalty screens out all endogenous spurious variables when $\gamma$ is sufficiently large. This is easily seen: when the penalty in (1.2) is not zero for some $g$, such a $g$ is dominated by $g=m^{\star}$ when $\gamma$ is sufficiently large. After the endogenous spurious variables are screened out, we can apply commonly used statistical variable selection methods (Hastie et al., 2009; Wainwright, 2019; Fan et al., 2020) to further eliminate exogenous spurious or weak causal variables. In addition, we will show that under the SCM with arbitrary and nondegenerate interventions on $X$, our proposed FAIR estimator can unveil $S^{\star}$, which is precisely expressed by the graph structure of the SCM; it can be interpreted as the “pragmatic” direct cause of the response $Y$ in general and will coincide with the direct causes if all the root children are intervened. The obtained result is clearly distinguished from what least squares, or even its worst-case variants like distributionally robust optimization (Duchi & Namkoong, 2021) and Maximin (Meinshausen & Bühlmann, 2015), can obtain. Our method indeed learns a certain data-driven causality, while the others cannot go beyond learning associations.
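The screening effect can be illustrated by brute force in a small linear instance of (1.2). The sketch below is not the paper's algorithm: it assumes linear classes without intercept, squared loss, and a hypothetical two-covariate structure $X_{1}\to Y\to X_{2}$ with $S^{\star}=\{X_{1}\}$, and it enumerates all nonempty supports. For linear testing functions the inner maximum has a closed form, and for a fixed support the resulting quadratic objective is minimized by pooled least squares.

```python
import itertools
import numpy as np

def simulate_env(n, s, rng):
    """Toy SCM X1 -> Y -> X2 (assumed); column 0 is X1, column 1 is X2."""
    x1 = rng.normal(size=n)
    y = 2.0 * x1 + rng.normal(size=n)
    x2 = s * y + rng.normal(size=n)        # endogenous child of Y; s differs by environment
    return np.column_stack([x1, x2]), y

def fair_linear_objective(envs, support, gamma):
    """R + gamma * J for linear g and f restricted to the given support."""
    cols = list(support)
    covs = [X[:, cols].T @ X[:, cols] / len(y) for X, y in envs]
    crosses = [X[:, cols].T @ y / len(y) for X, y in envs]
    # With squared loss and a fixed support, pooled least squares minimizes the objective.
    beta = np.linalg.solve(sum(covs), sum(crosses))
    risk, penalty = 0.0, 0.0
    for (X, y), cov, cross in zip(envs, covs, crosses):
        risk += np.mean((y - X[:, cols] @ beta) ** 2)
        moment = cross - cov @ beta        # sample mean of residual times X_S
        # Closed-form maximum over linear f of mean(resid * f - f**2 / 2).
        penalty += 0.5 * moment @ np.linalg.solve(cov, moment)
    return risk + gamma * penalty

def select_support(envs, gamma):
    """Return the nonempty support (0-based column indices) with the smallest objective."""
    d = envs[0][0].shape[1]
    candidates = [c for k in range(1, d + 1)
                  for c in itertools.combinations(range(d), k)]
    return min(candidates, key=lambda c: fair_linear_objective(envs, c, gamma))

rng = np.random.default_rng(0)
envs = [simulate_env(20_000, s, rng) for s in (0.5, 2.0)]
for gamma in (0.0, 10.0):
    print(f"gamma={gamma}: selected support {select_support(envs, gamma)}")
```

With $\gamma=0$ the criterion is plain pooled least squares, which favors using both covariates because $X_{2}$ lowers the pooled risk; with a larger $\gamma$ the penalty on supports containing $X_{2}$ dominates and only column 0 (that is, $X_{1}$) remains. Exhaustive enumeration is feasible only for a handful of covariates and is used here purely for illustration.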

1.3 New Contributions

We propose a unified, algorithmic, and sample-efficient methodological framework that can discover the invariant regression function, i.e., that solves a generalized version of the problem in Section 1.1. The method is simple, universal, fully algorithmic, and sample-efficient: it is just one optimization objective (1.2) complemented by one extra hyper-parameter $\gamma$; it accommodates many losses and can be seamlessly integrated with various machine learning algorithms; it does not require any prior structural knowledge; and it is almost as statistically efficient as standard regression in various cases.

As a special instance of our framework, the FAIR neural network (FAIR-NN) estimator, for which $\mathcal{G}$ and $\mathcal{F}$ are neural networks, is proposed to unveil $m^{\star}$ in (1.1). It is the first theoretically guaranteed estimator that can efficiently recover $m^{\star}$ under a single general and minimal identification condition associated with the heterogeneity of the environments. Its sample efficiency can be understood in several notable aspects: it requires the minimal identification condition, leading to fewer required environments; it exhibits the same $L_{2}$ error rate as if directly regressing $Y$ on known $X_{S^{\star}}$, regardless of the complexity of spurious associations; and it adapts to the unknown low-dimensional structure of the invariant association $m^{\star}$ in the same manner as Kohler & Langer (2021). In summary, the FAIR-NN estimator circumvents the “curse of dimensionality” and the “curse of endogeneity” simultaneously in a fully algorithmic manner, which does not rely on prior knowledge of the structure of $m^{\star}$ or of cause-effect relationships among variables.

While a complicated combinatorial constraint and minimax optimization are introduced in (1.2), we show that a variant of gradient descent, namely gradient descent ascent with a Gumbel approximation to handle the combinatorial “focused” constraint $f\in\mathcal{F}_{S_{g}}$, continues to apply to our specifically designed algorithm and neural network estimators with no curse of dimensionality in implementation. Numerical results in Section 5 support this.
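As an illustration of how a Gumbel-type relaxation can replace a combinatorial choice of variables with a continuous one, the following minimal Python sketch samples relaxed binary gates from real-valued logits. It is a generic Gumbel-sigmoid (binary Concrete) relaxation written under our own assumptions, not the implementation used for FAIR-NN: the function name `gumbel_sigmoid_gates`, the logits, and the temperature values are all hypothetical.

```python
import math
import random


def gumbel_sigmoid_gates(logits, temperature, rng):
    """Sample one relaxed binary gate in [0, 1] per variable.

    Generic Gumbel-sigmoid relaxation (an illustrative assumption,
    not necessarily the exact scheme used in the paper).
    """
    gates = []
    for w in logits:
        # The difference of two independent Gumbel(0, 1) draws is
        # Logistic(0, 1); log(u) - log(1 - u) with u ~ Uniform(0, 1)
        # samples it directly. Clamp u to avoid log(0).
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        noise = math.log(u) - math.log(1.0 - u)
        z = (w + noise) / temperature
        # Numerically stable sigmoid.
        if z >= 0:
            gate = 1.0 / (1.0 + math.exp(-z))
        else:
            ez = math.exp(z)
            gate = ez / (1.0 + ez)
        gates.append(gate)
    return gates


rng = random.Random(0)
logits = [3.0, -3.0, 0.5]  # hypothetical inclusion scores for 3 variables
for tau in (5.0, 1.0, 0.05):  # hypothetical decreasing temperatures
    print(tau, [round(a, 3) for a in gumbel_sigmoid_gates(logits, tau, rng)])
```

With a large temperature the gates remain in the interior of the unit interval; as the temperature decreases they concentrate near 0 or 1, so the relaxed gates approximate a hard variable-selection mask while the logits can still be updated by gradient-based methods.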

Though our framework is designed for algorithmic learning, it is versatile in that users can also incorporate strong prior structural knowledge, such as linearity or additivity of $m^{\star}$, into the FAIR estimation. This can be realized by restricting the function class $\mathcal{G}$ to this known structure and designating $\mathcal{F}$ as a more expansive class. We demonstrate that harnessing such strong structural knowledge can relax the condition for identification. It is worth pointing out that identification is viable even when $|\mathcal{E}|=1$, corresponding to observational data; see examples in Section B.6. At the methodological level, our method bridges the invariance principle (Peters et al., 2016) and the asymmetry principle (Janzing et al., 2016) for observational data into a unified framework.

1.4 Related Works and Comparisons

Starting from the pioneering work of Peters et al. (2016), there is a considerable literature proposing methods to estimate $m^{\star}$ in (1.1), predominantly when $m^{\star}$ is linear. These methods broadly fall into two categories: hypothesis test-based methods and optimization-based methods. For the hypothesis test-based methods (Peters et al., 2016; Heinze-Deml et al., 2018; Pfister et al., 2019), the Type-I error is controlled for an estimator $\widehat{S}$ with $\mathbb{P}(\widehat{S}\subseteq S^{\star})\geq 1-\alpha$. Nonetheless, these procedures may miss important variables or return conservative solutions such as $\widehat{S}=\emptyset$ due to the inherent worst-case construction in the algorithm. Additionally, the reliance on hypothesis tests hinders their seamless integration with machine learning algorithms, limiting their scalability. On the other hand, some optimization-based methods (Ghassami et al., 2017; Rothenhäusler et al., 2019, 2021) focus on linear $m^{\star}$ and tackle the problem under additional structures such as linear SCMs with additive interventions (Rothenhäusler et al., 2019). This limitation curtails their applicability to a broader nonparametric setting. Other optimization-based methods (Pfister et al., 2021; Yin et al., 2021) designed for linear models are heuristic and lack finite-sample guarantees. In summary, there is still a crucial gap towards efficiently estimating $m^{\star}$ without additional assumptions on the underlying model. Although Fan et al. (2023) recently bridged this gap for linear $m^{\star}$ through an optimization-based method, it remains unclear how to do so in the general nonparametric setting. This paper is the first to attain sample-efficient estimation for the general model with non-asymptotic guarantees in terms of both $|\mathcal{E}|$ and $n$.

Arjovsky et al. (2019) consider a general task, which aims to search for a data representation such that the optimal solution given that representation is optimal across diverse environments. They propose an optimization-based approach called invariant risk minimization (IRM), with many variants proposed subsequently. However, their method comes with no statistical guarantees and requires at least $d$ environments even for the linear model, and its improvement over standard empirical risk minimization is not clear (Rosenfeld et al., 2021; Kamath et al., 2021). Our paper is the first to offer a comprehensive theoretical analysis of general invariance learning when the representation class is $\{(x_{1},\ldots,x_{d})\to(a_{1}x_{1},\ldots,a_{d}x_{d}):a_{1},\ldots,a_{d}\in\{0,1\}\}$ and to show that sample-efficient estimation is in general viable even when $|\mathcal{E}|=2$. This is attainable mainly because of the exact invariance pursued by our FAIR penalty and its “focused” nature; see the discussion in Section A.2.

Under the SCM framework, there is a considerable literature on causal discovery using observational data (Spirtes et al., 2000; Richardson, 1996; Chickering, 2002; Hyttinen et al., 2013, 2014), but these methods cannot go beyond the Markov equivalence class (Geiger & Pearl, 1990) and thus fail to establish the exact cause-effect direction in general. This problem can be resolved by imposing additional assumptions in settings where the algorithm can only passively observe data rather than actively perform interventions. The resulting methods can be divided into two categories: one based on the invariance principle and the other based on the asymmetry principle. The invariance-based approaches (Peters et al., 2016) use samples from multiple experiments where some unknown intervention may apply to the variables other than $Y$. They leverage the idea that the cause-effect mechanism will remain constant while the reverse effect-cause association may vary. On the other hand, the asymmetry-based approaches (Shimizu et al., 2006; Hoyer et al., 2008; Zhang & Hyvärinen, 2009; Janzing et al., 2012; Peters et al., 2014) observe only a single observational dataset and use the idea that the cause-effect mechanism admits a simple, a priori known structure, whereas its inverse does not; an example is the additive noise structure (Hoyer et al., 2008). These two principles for causal discovery have so far appeared to be orthogonal. Our estimation framework is the first to offer a unified methodological perspective on these two principles with theoretical guarantees. It demonstrates the ability to simultaneously leverage both principles for identification and estimation.

Adversarial estimation was introduced in Goodfellow et al. (2014) for generative modeling. Its applications in statistics span distribution estimation (Liang, 2021), instrumental variable regression (Dikkala et al., 2020), estimation of the (implicit) influence function (Chernozhukov et al., 2020; Hirshberg & Wager, 2021), and so on. The idea of minimizing the worst-case reward among diverse environments can also be considered as “an algorithmic remedy” for out-of-distribution generalization. There are different choices of the “reward”, such as risk (Sagawa et al., 2020), excess risk (Agarwal & Zhang, 2022), and the negative of the explained variance (Meinshausen & Bühlmann, 2015). However, these methods are conceptually similar to running least squares in regression and thus cannot go beyond just learning associations. We adopt adversarial estimation in our framework in two novel respects. First, it allows us to use a simple objective function that homogenizes different tasks and prediction models for estimation. Second, such a minimax optimization objective and the Gumbel approximation in the implementation jointly relax the combinatorial nature in (1.3) and make a variant of gradient descent continue to work numerically.

1.5 Organization

This paper is structured as follows. We first present the proposed method with non-asymptotic theoretical analysis and causal interpretations for our canonical nonparametric causality (invariance) pursuit problem in Sections 2 and 3, respectively. This special instance of our framework also helps to illustrate the main idea and philosophy of our general invariance pursuit problem and FAIR estimation framework, which will be formally presented in Section 4. In the main text, we provide a sketch of the abstract unified result, from which all non-asymptotic results are derived as corollaries, along with its other applications in Section 4.3, and defer the detailed statements to the Appendix. We provide a computationally efficient implementation using variants of gradient descent and the Gumbel approximation, followed by its application to simulations and real data analysis, in Section 5. All the proofs are collected in the supplemental material.

1.6 Notations

We use upper case $(X,Y,Z)$ to represent random variables/vectors and denote their instances as $(x,y,z)$. Define $[n]=\{1,\ldots,n\}$. For a vector $x=(x_{1},\ldots,x_{d})^{\top}\in\mathbb{R}^{d}$, we let $\|x\|_{2}=(\sum_{j=1}^{d}x_{j}^{2})^{1/2}$. For a given index set $S=\{j_{1},\ldots,j_{|S|}\}\subseteq[d]$ with $j_{1}<\cdots<j_{|S|}$, we denote $[x]_{S}=(x_{j_{1}},\ldots,x_{j_{|S|}})^{\top}\in\mathbb{R}^{|S|}$ and abbreviate it as $x_{S}$ if there is no ambiguity. We let $a\lor b=\max\{a,b\}$ and $a\land b=\min\{a,b\}$. We use $a(n)\lesssim b(n)$, $b(n)\gtrsim a(n)$, or $a(n)=O(b(n))$ if there exists some constant $C>0$ such that $a(n)\leq Cb(n)$ for any $n\geq 3$. Denote $a(n)\asymp b(n)$ if $a(n)\lesssim b(n)$ and $a(n)\gtrsim b(n)$. In the theorem statements and proofs, we will use $C$ to represent universal constants that may vary from line to line and will use $\widetilde{C},\widetilde{C}_{1},\ldots$ to represent constants that may depend on the other constants defined in the paper.

In the context of the multi-environment setup, consider the following notations. For each $e\in\mathcal{E}$, let $\Theta^{(e)}=L_{2}(\mu^{(e)}_{x}):=\{f:\int f^{2}(x)\mu_{x}^{(e)}(dx)<\infty\}$, and denote $\|f\|_{2,e}=\{\int f^{2}(x)\mu^{(e)}_{x}(dx)\}^{1/2}$. Given $n$ observations $\{(X_{i}^{(e)},Y_{i}^{(e)})\}_{i=1}^{n}\subseteq\mathbb{R}^{d}\times\mathbb{R}$ drawn i.i.d. from $\mu^{(e)}$, we define $\mathbb{E}[f(X^{(e)},Y^{(e)})]=\int f(x,y)\mu^{(e)}(dx,dy)$ and $\widehat{\mathbb{E}}[f(X^{(e)},Y^{(e)})]=\frac{1}{n}\sum_{i=1}^{n}f(X^{(e)}_{i},Y^{(e)}_{i})$ for any $f\in\Theta^{(e)}$. We assume $\mathbb{E}[|Y^{(e)}|^{2}]<\infty$. Let $\bar{\mu}=\frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}}\mu^{(e)}$, and let $\Theta=L_{2}(\bar{\mu}_{x})$ be equipped with the norm $\|f\|_{2}=\{\int f^{2}(x)\bar{\mu}_{x}(dx)\}^{1/2}$. It is easy to verify that $\Theta=\bigcap_{e\in\mathcal{E}}\Theta^{(e)}$.
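Because $\bar{\mu}$ is the equal-weight mixture of the $\mu^{(e)}$, the pooled squared norm is the average of the per-environment squared norms, $\|f\|_{2}^{2}=\frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}}\|f\|_{2,e}^{2}$, which is finite exactly when every term is finite. The following minimal Python sketch computes the empirical analogues of these quantities. The function names and the toy data are hypothetical and serve only to illustrate the notation.

```python
def empirical_mean(f, sample):
    """Empirical expectation of f(X, Y) over one environment's (x, y) pairs."""
    return sum(f(x, y) for x, y in sample) / len(sample)


def squared_norm_env(f, sample):
    """Empirical analogue of ||f||_{2,e}^2 for a function f of x only."""
    return empirical_mean(lambda x, y: f(x) ** 2, sample)


def squared_norm_pooled(f, samples_by_env):
    """Empirical analogue of ||f||_2^2 under the equal-weight mixture."""
    per_env = [squared_norm_env(f, s) for s in samples_by_env.values()]
    return sum(per_env) / len(per_env)


# Hypothetical toy data: two environments, two observations each, d = 2.
samples = {
    "e1": [((1.0, 0.0), 0.5), ((0.0, 1.0), 1.5)],
    "e2": [((2.0, 1.0), 3.0), ((1.0, 1.0), 2.0)],
}


def f(x):
    return x[0] + 2.0 * x[1]


for env, sample in samples.items():
    print(env, squared_norm_env(f, sample))
print("pooled", squared_norm_pooled(f, samples))
```

With equal sample sizes per environment, the pooled value coincides with the plain average of $f^{2}$ over all observations.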

Let $S\subseteq[d]$ be any index set. Given a function class $\mathcal{H}\subseteq\{h:\mathbb{R}^{d}\to\mathbb{R}\}$, we define $\mathcal{H}_{S}$ to be the class of functions in $\mathcal{H}$ that only depend on the variables $x_{S}$, i.e., $\mathcal{H}_{S}=\{h\in\mathcal{H}:h(x)\equiv u(x_{S})\text{ for some }u:\mathbb{R}^{|S|}\to\mathbb{R},~\mu^{(e)}\text{-a.s.}~\forall e\in\mathcal{E}\}$. We sometimes also write $h(x_{S})$ instead of $h(x)$ for $h\in\mathcal{H}_{S}$ since $h$ only depends on $x_{S}$. For any $h\in\mathcal{H}$, we use $S_{h}\subseteq[d]$ to represent the index set of the variables $h$ depends on. We let $\{\mathcal{H}\}^{k}=\{(h_{1},\ldots,h_{k}):h_{i}\in\mathcal{H}~\forall i\in[k]\}$. For any joint distribution $\nu$ of $(X,Y)$, we use $\nu_{x}$ to denote the marginal distribution of $X$, and $\nu_{x,S}$ to denote the marginal distribution of $X_{S}$.
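To make the notion of a function that depends only on $x_{S}$ concrete, the following minimal Python sketch builds such a function from an arbitrary $h$ by zeroing out the coordinates outside $S$. This masking construction is only one convenient way to obtain a function of $x_{S}$ and is our own illustrative assumption: the class $\mathcal{H}_{S}$ is defined by the almost-sure identity above, and the masked function need not belong to a given $\mathcal{H}$. The helper name `restrict_to_subset` and the example function are hypothetical.

```python
def restrict_to_subset(h, S, d):
    """Return x -> h(x with coordinates outside S set to 0).

    The returned function depends only on x_S. S holds 0-based indices,
    whereas the text indexes coordinates from 1.
    """
    mask = [1.0 if j in S else 0.0 for j in range(d)]

    def h_S(x):
        return h([a * xj for a, xj in zip(mask, x)])

    return h_S


def h(x):
    # Hypothetical example function of three variables.
    return x[0] + 2.0 * x[1] * x[2]


h_S = restrict_to_subset(h, S={0, 2}, d=3)
# Changing the coordinate outside S leaves the value unchanged.
print(h_S([1.0, 5.0, 3.0]), h_S([1.0, -7.0, 3.0]))
```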

Neural Networks. We use neural networks as a scalable nonparametric technique: we adopt the fully connected deep neural network with ReLU activation $\sigma(\cdot)=\max\{0,\cdot\}$, and call it a deep ReLU network for short. Let $L,N$ be any positive integers. A deep ReLU network with depth $L$ and width $N$ admits the form

$$g(x)=T_{L+1}\circ\bar{\sigma}_{L}\circ T_{L}\circ\bar{\sigma}_{L-1}\circ\cdots\circ T_{2}\circ\bar{\sigma}_{1}\circ T_{1}(x). \tag{1.4}$$

Here $T_{l}(z)=W_{l}z+b_{l}:\mathbb{R}^{d_{l-1}}\to\mathbb{R}^{d_{l}}$ is an affine map with weight matrix $W_{l}\in\mathbb{R}^{d_{l}\times d_{l-1}}$ and bias vector $b_{l}\in\mathbb{R}^{d_{l}}$, where $(d_{0},d_{1},\ldots,d_{L},d_{L+1})=(d,N,\ldots,N,1)$, and $\bar{\sigma}_{l}:\mathbb{R}^{d_{l}}\to\mathbb{R}^{d_{l}}$ applies the ReLU activation $\sigma(\cdot)$ to each entry of a $d_{l}$-dimensional vector. The equal width is assumed for simplicity of presentation.

Definition 1 (Deep ReLU network class).

Define the family of deep ReLU networks taking a $d$-dimensional vector as input, with depth $L$ and width $N$, truncated by $B$, as $\mathcal{H}_{\mathtt{nn}}(d,L,N,B)=\{\widetilde{g}(x)=\mathrm{Tc}_{B}(g(x)):g(x)\text{ in }(1.4)\}$, where $\mathrm{Tc}_{B}:\mathbb{R}\to\mathbb{R}$ is the truncation operator defined as $\mathrm{Tc}_{B}(z)=\min\{|z|,B\}\cdot\mathrm{sign}(z)$.
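The network class in Definition 1 can be sketched in a few lines of Python. The sketch below evaluates one member of the class with layer widths $(d,N,\ldots,N,1)$, assuming that truncation means clipping the scalar output to $[-B,B]$. The helper names, the Gaussian initialization, and the zero biases are hypothetical choices made for illustration only and are not part of the definition.

```python
import random


def relu(v):
    """Apply the ReLU activation entrywise."""
    return [max(0.0, t) for t in v]


def affine(W, b, z):
    """Compute W z + b, with W given as a list of rows."""
    return [sum(w * t for w, t in zip(row, z)) + bi for row, bi in zip(W, b)]


def truncate(z, B):
    """Clip the scalar z to the interval [-B, B]."""
    return max(-B, min(B, z))


def init_network(d, L, N, rng):
    """Random weights for layer widths (d, N, ..., N, 1) with L hidden layers."""
    widths = [d] + [N] * L + [1]
    layers = []
    for l in range(1, L + 2):
        # W_l has widths[l] rows and widths[l - 1] columns.
        W = [[rng.gauss(0.0, 1.0) for _ in range(widths[l - 1])]
             for _ in range(widths[l])]
        b = [0.0] * widths[l]
        layers.append((W, b))
    return layers


def forward(layers, x, B):
    """Evaluate the truncated network: ReLU after every map except the last."""
    z = list(x)
    for W, b in layers[:-1]:
        z = relu(affine(W, b, z))
    W, b = layers[-1]
    return truncate(affine(W, b, z)[0], B)


layers = init_network(d=3, L=2, N=4, rng=random.Random(0))
print(forward(layers, [0.1, -0.2, 0.3], B=1.0))
```

The truncation guarantees that every function in the class is bounded by $B$ in absolute value, whatever the weights are.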

2 FAIR Least Squares Estimator Using Neural Networks

In this section, we show that one can use the FAIR-NN least squares estimator, a realization of the FAIR estimator obtained by setting $\ell(y,v)=\frac{1}{2}(y-v)^{2}$ and specifying both $(\mathcal{G},\mathcal{F})$ as neural networks, to attain sample-efficient estimation in nonparametric causality pursuit.

The main messages of this section are two-fold. From a theoretical perspective, it shows that sample-efficient estimation (in both $n$ and $|\mathcal{E}|$) in the general nonparametric causality pursuit problem is viable under a minimal identification condition related to the heterogeneity of the environments. From a methodological perspective, it demonstrates one key feature of our proposed framework: one can seamlessly integrate black-box machine learning models (e.g., neural networks) into it and fully exploit these models’ sample efficiency and capability of adapting to low-dimensional structures.

2.1 Setup

We introduce some notations. Recall that $\mu^{(e)}$ is the joint distribution of $(X,Y)$ in environment $e$. Let $m^{(e,S)}(x):=\mathbb{E}[Y^{(e)}|X_{S}^{(e)}=x_{S}]$ be the conditional expectation of $Y$ given $X_{S}$ in environment $e$. Recall that $\nu_{x,S}$ is the marginal distribution of $X_{S}$ for $(X,Y)\sim\nu$. It is easy to see that $\mu^{(e)}_{x,S}$ is absolutely continuous with respect to $\bar{\mu}_{x,S}=[\frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}}\mu^{(e)}]_{x,S}$ for any $S\subseteq[d]$; hence $\rho^{(e)}_{S}$, the Radon–Nikodym derivative of $\mu_{x,S}^{(e)}$ with respect to $\bar{\mu}_{x,S}$, is well defined. We define $\bar{m}^{(S)}(x)=\sum_{e\in\mathcal{E}}\rho^{(e)}_{S}(x_{S})m^{(e,S)}(x)$, which can be interpreted as the population-level least squares that regresses $Y$ on $X_{S}$ using all the data in $\mathcal{E}$.

Condition 1 (Model and Regularity Conditions).

There exist some positive constants $(C_{0},s_{\min})$ such that the following conditions hold.

  • (a)

    Data Generating Process: We collect data from $|\mathcal{E}|\in\mathbb{N}^{+}$ environments with $|\mathcal{E}|\leq n^{C_{0}}$. For each environment $e\in\mathcal{E}$, we observe $\{(X_{i}^{(e)},Y_{i}^{(e)})\}_{i=1}^{n}\overset{i.i.d.}{\sim}\mu^{(e)}$.

  • (b)

    Invariance Structure: There exists some set $S^{\star}$ and $m^{\star}:\mathbb{R}^{|S^{\star}|}\to\mathbb{R}$ such that $m^{(e,S^{\star})}(x)\equiv m^{\star}(x_{S^{\star}})$ for any $e\in\mathcal{E}$.

  • (c)

    Sub-Gaussian Response: For any $e\in\mathcal{E}$ and $t\geq 0$, $\mathbb{P}\left[|Y^{(e)}|\geq t\right]\leq C_{0}e^{-t^{2}/(2C_{0})}$.

  • (d)

    Boundedness: $X\in[-C_{0},C_{0}]^{d}$ $\bar{\mu}$-a.s. and $\|m^{(e,S)}\|_{\infty}\leq C_{0}$ for any $S\subseteq[d]$ and $e\in\mathcal{E}$.

  • (e)

    Nondegenerate Covariate: For any $S\subseteq[d]$ with $S^{\star}\setminus S\neq\emptyset$, $\inf_{m\in\Theta_{S}}\|m-m^{\star}\|_{2}^{2}\geq s_{\min}>0$.

Condition 1(a)–(b) is just a restatement of (1.1) together with i.i.d. data within each environment; data across different environments may be dependent. Conditions 1(c)–(d) are standard in nonparametric regression. Condition 1(e) rules out some degenerate cases, for example, $m^{\star}(x_{1})=x_{1}^{2}$ with $S^{\star}=\{1\}$ and $X_{2}=X_{1}^{4}$, or $m^{\star}(x_{1},x_{2})=f(x_{1})$ with $S^{\star}=\{1,2\}$, and is imposed for technical convenience. The target (invariant) regression function in nonparametric causality pursuit is $m^{\star}$.

2.2 Proposed FAIR-NN Least Squares Estimator

Given all the data $\{\{(X_{i}^{(e)},Y_{i}^{(e)})\}_{i=1}^{n}\}_{e\in\mathcal{E}}$ from heterogeneous environments, we consider using the following FAIR-NN least squares estimator to learn $m^{\star}$ in (1.1). Specifically, the FAIR-NN least squares estimator is the solution to the following minimax optimization objective

$$\widehat{g}\in\mathop{\mathrm{argmin}}_{g\in\mathcal{G}}\sup_{f^{\mathcal{E}}\in\{\mathcal{F}_{S_{g}}\}^{|\mathcal{E}|}}\frac{1}{|\mathcal{E}|\cdot n}\sum_{e\in\mathcal{E},i\in[n]}\left\{Y^{(e)}_{i}-g(X^{(e)}_{i})\right\}^{2}+\gamma\widehat{\mathsf{J}}(g,f^{\mathcal{E}}), \tag{2.1}$$

where the first part of the objective $\widehat{\mathsf{Q}}_{\gamma}(g,f^{\mathcal{E}})$ is the pooled least squares loss, which prevents the estimator from collapsing to conservative solutions, $\gamma$ is a hyper-parameter to be determined, and $\widehat{\mathsf{J}}(g,f^{\mathcal{E}})$ is the empirical counterpart of the focused adversarial invariance regularizer defined as

$$\widehat{\mathsf{J}}(g,f^{\mathcal{E}})=\frac{1}{|\mathcal{E}|\cdot n}\sum_{e\in\mathcal{E},i\in[n]}\left[\big\{Y^{(e)}_{i}-g(X_{i}^{(e)})\big\}f^{(e)}(X^{(e)}_{i})-\frac{1}{2}\big\{f^{(e)}(X^{(e)}_{i})\big\}^{2}\right]. \tag{2.2}$$
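To see why this regularizer acts as a test of invariance, it may help to maximize its population analogue over the discriminator in closed form. The following derivation is a sketch under stated assumptions that are not spelled out in this section: expectations are taken under $\mu^{(e)}$, both $g$ and $f^{(e)}$ depend on $x$ only through $x_{S}$ for a fixed set $S$, $m^{(e,S)}(x)=\mathbb{E}[Y^{(e)}|X^{(e)}_{S}=x_{S}]$, $\|h\|_{2,e}^{2}=\int h^{2}\,d\mu^{(e)}$, and $f^{(e)}$ ranges over all square-integrable functions of $x_{S}$ rather than a neural network class. By the tower property and completing the square,

$$\begin{aligned}
\mathbb{E}\Big[\big\{Y^{(e)}-g(X^{(e)})\big\}f^{(e)}(X^{(e)})-\tfrac{1}{2}\big\{f^{(e)}(X^{(e)})\big\}^{2}\Big]
&=\mathbb{E}\Big[\big\{m^{(e,S)}(X^{(e)})-g(X^{(e)})\big\}f^{(e)}(X^{(e)})-\tfrac{1}{2}\big\{f^{(e)}(X^{(e)})\big\}^{2}\Big]\\
&=\tfrac{1}{2}\|m^{(e,S)}-g\|_{2,e}^{2}-\tfrac{1}{2}\big\|f^{(e)}-(m^{(e,S)}-g)\big\|_{2,e}^{2}.
\end{aligned}$$

Under these assumptions the supremum over $f^{(e)}$ is attained at $f^{(e)}=m^{(e,S)}-g$ and equals $\frac{1}{2}\|m^{(e,S)}-g\|_{2,e}^{2}$. Averaged over environments, the penalty vanishes only if $g$ coincides with the conditional expectation $m^{(e,S)}$ in every environment, which is the invariance property the adversary is testing.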

The minimax program (2.1) is the empirical version of (1.2) via setting $\ell(y,v)=\frac{1}{2}(y-v)^{2}$. Here we specify the predictor function class $\mathcal{G}$ and testing (discriminator) function class $\mathcal{F}$ as

$$\mathcal{G}=\mathcal{H}_{\mathtt{nn}}(d,L,N,B)\qquad\text{and}\qquad\mathcal{F}=\mathcal{H}_{\mathtt{nn}}(d,L+2,2N,2B) \tag{2.3}$$

for neural network architecture hyper-parameters $N,L$ and truncation parameter $B=C_{0}$. Here $B$ can be larger than $C_{0}$ but should satisfy $B=O(1)$. A larger width, depth, and truncation parameter can also be adopted for $\mathcal{F}$. Our specification of $(N,L,B)$ for $\mathcal{F}$ here is for technical purposes: it ensures that any $m^{(e,S)}-g$ with $g\in\mathcal{G}$ can be well approximated by some $f\in\mathcal{F}$.
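The two empirical quantities in (2.1) and (2.2) can be evaluated directly once the predictor and discriminator outputs are available. The snippet below is a minimal NumPy sketch, not the authors' implementation: the function names are hypothetical, the inputs are assumed to be lists with one array per environment holding $Y^{(e)}_i$, $g(X^{(e)}_i)$ and $f^{(e)}(X^{(e)}_i)$, and every environment is assumed to have the same sample size $n$ as in Condition 1(a). It does not perform the minimax optimization itself.

```python
import numpy as np


def pooled_least_squares(y_envs, g_envs):
    """First term of (2.1): squared residuals averaged over all |E| * n samples."""
    resid = np.stack(y_envs) - np.stack(g_envs)  # shape (|E|, n)
    return float(np.mean(resid ** 2))


def fair_regularizer(y_envs, g_envs, f_envs):
    """Empirical regularizer (2.2): mean of residual * f - f^2 / 2."""
    resid = np.stack(y_envs) - np.stack(g_envs)
    f = np.stack(f_envs)
    return float(np.mean(resid * f - 0.5 * f ** 2))


def fair_objective(y_envs, g_envs, f_envs, gamma):
    """Objective inside the argmin/sup of (2.1) for fixed outputs of g and f."""
    return pooled_least_squares(y_envs, g_envs) + gamma * fair_regularizer(
        y_envs, g_envs, f_envs
    )


# Hypothetical toy inputs: two environments, two samples each.
y_envs = [np.array([1.0, 2.0]), np.array([0.0, -1.0])]
g_envs = [np.zeros(2), np.zeros(2)]
# An unrestricted discriminator that reproduces the residuals exactly.
f_envs = [y - g for y, g in zip(y_envs, g_envs)]
print(fair_objective(y_envs, g_envs, f_envs, gamma=2.0))
```

Setting the discriminator outputs equal to the residuals maximizes each summand of (2.2) pointwise, so the regularizer then equals half the mean squared residual. In the estimator itself the discriminator is restricted to networks in $\mathcal{F}_{S_g}$, so this upper value is generally not attained.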

2.3 Non-Asymptotic Result for FAIR-NN

Condition 2 (Identification for Nonparametric Causality Pursuit).

For any $S\subseteq[d]$ such that $\bar{\mu}(\{m^{\star}\neq\bar{m}^{(S\cup S^{\star})}\})>0$, there exist some $e,e^{\prime}\in\mathcal{E}$ such that $\min\{\mu^{(e)},\mu^{(e^{\prime})}\}(\{m^{(e,S)}\neq m^{(e^{\prime},S)}\})>0$.

Remark 1 (Minimal Heterogeneity Condition for Identification).

The above identification condition requires that whenever a bias emerges when regressing $Y$ on $X_{S\cup S^{\star}}$ using least squares, there should be noticeable shifts in the conditional expectation $m^{(e,S)}$ across environments. In other words, $S^{\star}$ is the maximum set that preserves the invariant associations. This condition is minimal. If it is violated, it would imply

$$\exists\,\widetilde{S}\subseteq[d]~\text{with}~\widetilde{S}\setminus S^{\star}\neq\emptyset\qquad\text{s.t.}\qquad\forall e\in\mathcal{E}~~\mathbb{E}[Y^{(e)}|X_{\widetilde{S}}^{(e)}]\equiv g(X_{\widetilde{S}}^{(e)})~~\mu^{(e)}\text{-a.s.}~~\text{for some}~g:\mathbb{R}^{|\widetilde{S}|}\to\mathbb{R},$$

in which both sets $S^{\star}$ and $\widetilde{S}$ embody the invariant conditional expectation structure; thus more environments are needed in this case to pinpoint $S^{\star}$. Such a minimal identification condition underscores that our proposed FAIR-NN estimator is “sample efficient” regarding the number of environments $|\mathcal{E}|$ required; see the discussions in Section 3. Notably, such an identification condition relaxes those employed in approaches using intersections like ICP (Peters et al., 2016; Heinze-Deml et al., 2018). These approaches require shifts of the conditional distributions for all $S$ with $\bar{m}^{(S)}\neq m^{\star}$ in order to identify $S^{\star}$.

Remark 2 (Relaxing Condition 2).

We claim that Condition 2 can be slightly relaxed, given that our algorithm searches for the most predictive variable set that preserves the invariance structure. However, the relaxed condition is technical in nature and lacks a clear semantic meaning; see discussions in Section A.4.

The following theorem provides an oracle-type inequality for the FAIR-NN least squares estimator in a structure-agnostic manner. The first term is the maximum approximation bias of neural networks across environments and the second term is related to the complexity of the neural networks used in the fitting. It implies that when the FAIR-NN penalty parameter $\gamma$ is large enough, all endogenous spurious variables can be surely screened (Fan & Lv, 2008) when $n$ is large enough, thus $m^{\star}$ can be estimated as well as if the invariant quasi-causal set of variables $S^{\star}$ were known. In addition, the theorem quantifies the amount of penalty needed, which is related to the signal-to-noise ratio of the problem.

Theorem 1 (Oracle-type Inequality for FAIR-NN Least Squares Estimator).

Assume Conditions 1 and 2 hold. Then $\gamma^{\star}_{\mathtt{NN}}=\sup_{S\subseteq[d]:\mathsf{b}_{\mathtt{NN}}(S)>0}(\mathsf{b}_{\mathtt{NN}}(S)/\bar{\mathsf{d}}_{\mathtt{NN}}(S))<\infty$, where

$$\mathsf{b}_{\mathtt{NN}}(S)=\|m^{\star}-\bar{m}^{(S\cup S^{\star})}\|_{2}^{2}\qquad\text{and}\qquad\bar{\mathsf{d}}_{\mathtt{NN}}(S)=\frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}}\|m^{(e,S)}-\bar{m}^{(S)}\|_{2,e}^{2}. \tag{2.4}$$

Consider the estimator that solves (2.1) using $\gamma\geq 8\gamma^{\star}_{\mathtt{NN}}$ and function classes (2.3) with $L,N$ satisfying $NL\leq n$ and $N\geq 4$. Then, there exists some constant $\widetilde{C}$ depending on $(d,C_{0})$ such that for any $n\geq 3$,

$$\frac{\|\widehat{g}-m^{\star}\|_{2}}{\widetilde{C}}\leq\max_{e\in\mathcal{E}}\inf_{h\in\mathcal{G}_{S^{\star}}}\|m^{\star}-h\|_{2,e}+\frac{NL\log^{3/2}n}{\sqrt{n}}+1_{\{\delta_{\mathtt{NN},1}>\mathsf{s}\}}\cdot\left(\gamma\delta_{\mathtt{NN},1}\right)$$

holds with probability at least $1-\widetilde{C}n^{-100}$. Here $\delta_{\mathtt{NN},1}=\max_{e\in\mathcal{E},S\subseteq[d]}\inf_{h\in\mathcal{G}_{S}}\|m^{(e,S)}-h\|_{2,e}+\frac{NL\log^{3/2}n}{\sqrt{n}}$ and $\mathsf{s}=\widetilde{C}^{-1}[1\land s_{\min}\land\{\gamma\inf_{S:\bar{\mathsf{d}}_{\mathtt{NN}}(S)>0}\bar{\mathsf{d}}_{\mathtt{NN}}(S)\}]/(1+\gamma)$, where $s_{\min}$ is defined in Condition 1(e).

As our result is non-asymptotic, for a given $n$, we may not be able to eliminate all endogenous spurious variables. The third term in Theorem 1 reflects this when the signal is not sufficiently large. It is more explicitly given in Corollary 1.

Remark 3 (Interpretation of $\mathsf{b}_{\mathtt{NN}}(S)$ and $\bar{\mathsf{d}}_{\mathtt{NN}}(S)$).

We refer to $\mathsf{b}_{\mathtt{NN}}(S)$ as the bias mean since it exactly characterizes the bias of the least squares estimator in the presence of endogenous spurious variables like the background color in the thought experiment. In particular, letting $\widehat{g}_{\mathtt{LSE}(S)}$ be the least squares estimator that regresses $Y$ on $X_{S}$ using all the data, namely, the FAIR-NN estimator with $\gamma=0$, Proposition 6 implies

$$\left|\frac{\|\widehat{g}_{\mathtt{LSE}(S)}-m^{\star}\|_{2}^{2}}{\mathsf{b}_{\mathtt{NN}}(S)}-1\right|=o_{\mathbb{P}}(1)\qquad\text{if}~~S^{\star}\subseteq S~\text{and}~\mathsf{b}_{\mathtt{NN}}(S)>0.$$

We refer to $\bar{\mathsf{d}}_{\mathtt{NN}}(S)$ as the bias variance because it measures the variation of the bias across environments. Specifically, when $S^{\star}\subseteq S$, the bias in environment $e$ is $(m^{(e,S)}-m^{\star})$, and $\bar{\mathsf{d}}_{\mathtt{NN}}(S)$ can be viewed as the variance of the bias with respect to the uniform distribution on $\mathcal{E}$ since $\bar{\mathsf{d}}_{\mathtt{NN}}(S)=\frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}}\|(m^{(e,S)}-m^{\star})-(\bar{m}^{(S)}-m^{\star})\|_{2,e}^{2}$.
We have $\bar{\mathsf{d}}_{\mathtt{NN}}(S^{\star})=0$ by the invariance structure in Condition 1(b).

Remark 4 (Identification).

Theorem 1 combines the identification result, which characterizes when it is possible to consistently estimate $m^{\star}$, and the finite-sample estimation error result, which characterizes how accurately we can estimate $m^{\star}$. The main identification message disentangled from the above theorem is that if the minimal heterogeneity condition (Condition 2) holds, then one can consistently estimate $m^{\star}$ provided $\gamma$ is larger than some threshold $8\gamma^{\star}_{\mathtt{NN}}$ that is independent of $n$.
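The threshold in Theorem 1 is a ratio of bias mean to bias variance, maximized over the variable sets that incur a bias. The sketch below illustrates this bookkeeping only; the function name and all numerical values are hypothetical, and in practice $\mathsf{b}_{\mathtt{NN}}(S)$ and $\bar{\mathsf{d}}_{\mathtt{NN}}(S)$ are unknown population quantities. It also makes visible why Condition 2 matters: a set with positive bias mean but zero bias variance would force the ratio to be infinite.

```python
import math


def gamma_star(bias_mean, bias_var):
    """Sup of b(S) / d(S) over sets S with b(S) > 0; inf if some such S has d(S) = 0."""
    ratios = []
    for subset, b in bias_mean.items():
        if b <= 0:
            continue  # sets without bias do not enter the supremum
        d = bias_var[subset]
        ratios.append(math.inf if d == 0 else b / d)
    return max(ratios, default=0.0)


# Hypothetical values for d = 2 covariates with invariant set S* = {1}.
# b(S) depends on S only through S ∪ S*, so b({2}) = b({1, 2}) and b(∅) = b({1}) = 0.
bias_mean = {
    frozenset(): 0.0,
    frozenset({1}): 0.0,
    frozenset({2}): 0.25,
    frozenset({1, 2}): 0.25,
}
bias_var = {
    frozenset(): 0.0,
    frozenset({1}): 0.0,
    frozenset({2}): 1.0,
    frozenset({1, 2}): 0.5,
}

g_star = gamma_star(bias_mean, bias_var)  # largest ratio is 0.25 / 0.5
threshold = 8 * g_star                    # penalty level required by Theorem 1
print(g_star, threshold)
```

Because there are finitely many subsets of $[d]$, the supremum is finite as soon as every biased set has strictly positive bias variance, which is what Condition 2 guarantees.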

2.4 Adapting to the Low-dimensional Structures Algorithmically

To present the explicit $L_{2}$ error rate under a specific nonparametric setup, we first introduce the concept of $(\beta,C)$-smooth functions.

Definition 2 ($(\beta,C)$-smooth Function).

Let $\beta=r+s$ for some nonnegative integer $r\geq 0$ and $0<s\leq 1$, and $C>0$. A $d$-variate function $f$ is $(\beta,C)$-smooth if for every sequence of nonnegative integers $\alpha\in\mathbb{N}^{d}$ such that $\sum_{j=1}^{d}\alpha_{j}=r$, the partial derivative $\partial^{\alpha}f=(\partial^{r}f)/(\partial x_{1}^{\alpha_{1}}\cdots\partial x_{d}^{\alpha_{d}})$ exists and satisfies $|\partial^{\alpha}f(x)-\partial^{\alpha}f(z)|\leq C\|x-z\|_{2}^{s}$. We use $\mathcal{H}_{\mathtt{HS}}(d,\beta,C)$ to denote the set of all the $d$-variate $(\beta,C)$-smooth functions.

One significant advantage of neural networks over traditional nonparametric methods is their intrinsic capability for algorithmic nonparametric regression. This enables them to learn low-dimensional structures with little or no explicit guidance regarding the forms of functions (Bauer & Kohler, 2019; Schmidt-Hieber, 2020; Kohler & Langer, 2021; Fan & Gu, 2024). We begin by elucidating the concept of the Hierarchical Composition Model (HCM), which is essentially an $l$-fold composition of $t$-variate $(\beta,C)$-smooth functions with $(\beta,t)$ in a certain set $\mathcal{O}$.

Definition 3 (Hierarchical Composition Model $\mathcal{H}_{\mathtt{HCM}}(d,l,\mathcal{O},C)$).

We define the function class of the hierarchical composition model $\mathcal{H}_{\mathtt{HCM}}(d,l,\mathcal{O},C)$ (Kohler & Langer, 2021) with $l,d\in\mathbb{N}^{+}$, $C\in\mathbb{R}^{+}$, and $\mathcal{O}$, a subset of $[1,\infty)\times\mathbb{N}^{+}$, in a recursive way as follows. Let $\mathcal{H}_{\mathtt{HCM}}(d,0,\mathcal{O},C)=\{h(x)=x_{j},j\in[d]\}$, and for each $l\geq 1$,

$$\begin{aligned}\mathcal{H}_{\mathtt{HCM}}(d,l,\mathcal{O},C)=\big\{&h:\mathbb{R}^{d}\to\mathbb{R}:h(x)=g(f_{1}(x),\ldots,f_{t}(x))\text{, where}\\&g\in\mathcal{H}_{\mathtt{HS}}(t,\beta,C)\text{ with }(\beta,t)\in\mathcal{O}\text{ and }f_{i}\in\mathcal{H}_{\mathtt{HCM}}(d,l-1,\mathcal{O},C)\big\}.\end{aligned}$$

Following Kohler & Langer (2021), we assume that all the compositions are at least Lipschitz functions to simplify the presentation. The minimax optimal $L_{2}$ estimation risk over $\mathcal{H}(d,l,\mathcal{O},C_{h})$ is $n^{-\alpha^{\star}/(2\alpha^{\star}+1)}$, where $\alpha^{\star}=\min_{(\beta,t)\in\mathcal{O}}(\beta/t)$ is the smallest dimensionality-adjusted degree of smoothness, which represents the hardest component in the composition.
For example, if $m^{\star}(x)=f_{1}(x_{1})+f_{2}(f_{3}(x_{2},x_{3}),f_{4}(x_{4},x_{5}))+f_{5}(x_{1},x_{3},x_{5})$ and all functions have a bounded second derivative, then the hardest component is the last one, and the dimensionality-adjusted degree of smoothness is $\alpha^{\star}=2/3$.
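The arithmetic behind this example can be checked with a short sketch. It assumes, as the example states, that every named component has smoothness $\beta=2$, and it reads the input dimension $t$ off the number of arguments of each $f_i$. The outer addition is left out because it is smooth of every order and so never attains the minimum. The dictionary `components` and the helper `adjusted_smoothness` are illustrative names, not part of the paper.

```python
from fractions import Fraction

# (beta, t) for each component named in the example: beta = 2 is the assumed
# smoothness (bounded second derivative), t is the number of arguments.
components = {
    "f1": (2, 1),
    "f2": (2, 2),
    "f3": (2, 2),
    "f4": (2, 2),
    "f5": (2, 3),
}


def adjusted_smoothness(pairs):
    """Return min over (beta, t) of beta / t, the dimensionality-adjusted smoothness."""
    return min(Fraction(beta, t) for beta, t in pairs)


alpha_star = adjusted_smoothness(components.values())

# Exponent of the minimax rate n^{-alpha / (2 alpha + 1)}.
rate_exponent = alpha_star / (2 * alpha_star + 1)

# Component attaining the minimum, i.e. the hardest one.
hardest = min(components, key=lambda name: Fraction(*components[name]))

print(alpha_star, rate_exponent, hardest)
```

Under these assumptions the hardest component is `f5`, and the rate exponent $\alpha^{\star}/(2\alpha^{\star}+1)$ evaluates to $2/7$.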

Condition 3 (Function Complexity and Neural Network Architecture).

The following holds:

(a) $m^{(e,S)}\in\mathcal{H}_{\mathtt{HCM}}(|S|,l,\mathcal{O},C_{h})$ for any $e\in\mathcal{E}$ and $S\subseteq[d]$, with $\alpha_{0}=\inf_{(\beta,t)\in\mathcal{O}}(\beta/t)$.

(b) $m^{\star}\in\mathcal{H}_{\mathtt{HCM}}(|S^{\star}|,l,\mathcal{O}^{\star},C_{h})$ with $\alpha^{\star}=\inf_{(\beta,t)\in\mathcal{O}^{\star}}(\beta/t)$.

(c) We choose $N,L$ satisfying $LN\asymp\{n(\log n)^{8\alpha^{\star}-3}\}^{\frac{1}{2(2\alpha^{\star}+1)}}$ and $(\log n)/(N\land L)=o(1)$.

(d) $\max\{C_{0},d,l,C_{h},\sup_{(\beta,t)\in\mathcal{O}}(\beta\lor t),\sup_{(\beta,t)\in\mathcal{O}^{\star}}(\beta\lor t)\}\leq C_{1}$ for some constant $C_{1}>1$.
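To get a feel for the scaling in item (c), the following sketch evaluates the right-hand side for given $n$ and $\alpha^{\star}$. The relation $\asymp$ only fixes the order of $LN$, so the hidden constant is set to 1 here purely for illustration; the function name `width_depth_product` is hypothetical.

```python
import math


def width_depth_product(n, alpha_star):
    """Evaluate {n (log n)^(8a - 3)}^(1 / (2 (2a + 1))) with a = alpha_star.

    The constant hidden in the asymptotic relation is taken to be 1, which is
    an assumption made only so that a number can be printed.
    """
    if n < 3:
        raise ValueError("n is assumed to be at least 3")
    base = n * math.log(n) ** (8.0 * alpha_star - 3.0)
    return base ** (1.0 / (2.0 * (2.0 * alpha_star + 1.0)))


for n in (10**3, 10**4, 10**5):
    print(n, round(width_depth_product(n, 2.0 / 3.0), 2))
```

How the product is split between the width $N$ and the depth $L$ is left open by (c), subject to $(\log n)/(N\land L)=o(1)$.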

Corollary 1 (Optimal Rate for FAIR-NN).

Under the setting of Theorem 1, assume further that Condition 3 holds. Then, for any $n\geq 3$, with probability at least $1-\widetilde{C}n^{-100}$, the following holds:

$$\frac{\|\widehat{g}-m^{\star}\|_{2}}{\widetilde{C}\log^{7}(n)}\leq n^{-\alpha^{\star}/(2\alpha^{\star}+1)}+1_{\{n<n_{0}\}}\,\gamma\cdot n^{-\alpha_{0}/(2\alpha^{\star}+1)}, \tag{2.5}$$

where $n_{0}$ depends on $(C_{1},\gamma,s_{\min},\inf_{S:\bar{\mathsf{d}}_{\mathtt{NN}}(S)>0}\bar{\mathsf{d}}_{\mathtt{NN}}(S))$, and $\widetilde{C}$ is a constant dependent only on $C_{1}$.

From Corollary 1, we obtain, up to logarithmic factors, the minimax convergence rate $n^{-\alpha^{\star}/(2\alpha^{\star}+1)}$, which is independent of both $\alpha_{0}$ and $\gamma$, when $n$ is larger than some constant $n_{0}$. Using neural networks in the predictor and discriminator function classes allows the estimator to adapt to the invariant regression function $m^{\star}$ efficiently from two crucial perspectives. First, as with neural networks in nonparametric regression (Schmidt-Hieber, 2020; Kohler & Langer, 2021; Fan & Gu, 2024), adopting neural networks in $\mathcal{G}$ makes the estimator algorithmically adaptive to the low-dimensional hierarchical structure. Second, the choice of the model parameters $(N,L)$ and the convergence rate depend only on $m^{\star}$. The (spurious) conditional expectations $m^{(e,S)}$ can be much more complex than $m^{\star}$, yet this complexity does not affect the convergence rate. This can be credited to the scalability of neural networks used as discriminators, i.e., their adaptivity in the regularization part of FAIR.
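The two regimes of the bound (2.5) can be illustrated numerically. The sketch below evaluates only the right-hand side of (2.5), i.e. the error normalized by $\widetilde{C}\log^{7}(n)$. The threshold $n_{0}$ is not given in closed form in the text, so it is passed in as a hypothetical input `n0`, and the values of $\alpha_{0}$ and $\gamma$ are arbitrary choices made for illustration.

```python
def rate_bound(n, alpha_star, alpha_0, gamma, n0):
    """Right-hand side of (2.5): main rate plus a gamma term active only when n < n0.

    n0 is a hypothetical threshold supplied by the caller; the corollary only
    states which quantities it depends on.
    """
    main = n ** (-alpha_star / (2.0 * alpha_star + 1.0))
    extra = gamma * n ** (-alpha_0 / (2.0 * alpha_star + 1.0)) if n < n0 else 0.0
    return main + extra


# Illustrative values only: alpha_star = 2/3, a rougher alpha_0, a large gamma.
for n in (10**2, 10**4, 10**6):
    print(n, rate_bound(n, 2.0 / 3.0, 0.1, 50.0, n0=10**3))
```

Once $n\geq n_{0}$ the indicator vanishes, so the printed value depends on neither `alpha_0` nor `gamma`.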

Remark 5 (Guaranteed for All $n$).

The error bound (2.5) holds for any $n\geq 3$, even when the estimator selects the wrong variables. Notably, the error bound does not inflate if the invariant signal $s_{\min}$ and the heterogeneity signal $\inf_{S\subseteq[d]:\bar{\mathsf{d}}_{\mathtt{NN}}(S)>0}\bar{\mathsf{d}}_{\mathtt{NN}}(S)$ are small. Though the error bound scales linearly with $\gamma$, the estimator we propose is not vulnerable to “weak spurious” variables, e.g., $x_{j}$ with $\sup_{e\in\mathcal{E}}\|m^{(e,S^{\star}\cup\{j\})}-m^{\star}\|_{2,e}\leq\epsilon$, provided that the ratios of the bias $\mathsf{b}_{\mathtt{NN}}(S)$ to the heterogeneity $\bar{\mathsf{d}}_{\mathtt{NN}}(S)$ are all controlled.

Remark 6 (Choice of the Hyper-parameter $\gamma$).

Though we have to choose a hyper-parameter $\gamma$ larger than a certain threshold to attain such a rate, the convergence rate is independent of $\gamma$. This implies that when the sample size $n$ is large, we do not need to tune the hyper-parameter $\gamma$ for optimal performance. Instead, we can choose some conservative (large) $\gamma$ such that the lower bound $\gamma\geq 8\gamma^{\star}_{\mathtt{NN}}$ is guaranteed.

3 Nonparametric Invariance Pursuit under SCMs

The results in Section 2 concern the problem of nonparametric invariance pursuit itself. From a population-level view, it pursues the “maximum invariant set” $S^{\star}$ satisfying

$$m^{(e,S^{\star})}\equiv\bar{m}^{(S^{\star})}~(\text{invariant})\quad\text{and}\quad\forall S\subseteq[d],~m^{(e,S)}\equiv\bar{m}^{(S)}\Longrightarrow\bar{m}^{(S\cup S^{\star})}=\bar{m}^{(S^{\star})}~(\text{maximum}). \tag{3.1}$$

Section 2 shows that the FAIR-NN estimator can estimate such an $S^{\star}$ ($m^{\star}$) efficiently. It is natural to ask:

Does such a maximum invariant set $S^{\star}$ exist? What is its semantic meaning?

We offer a clean yet general answer to the question in the setting of SCMs with arbitrary interventions (on $X$). The short answer is: Yes, and it can be interpreted as the “pragmatic direct causes”.

3.1 Structural Causal Model with Interventions on Covariates

We first introduce the concept of the structural causal model (SCM) (Glymour et al., 2016); see Fig. 1 for examples. An SCM states that each variable in the directed graph is a function of its parents (if any) and an independent innovation or noise.

Definition 4 (Structural Causal Model).

A structural causal model $M=(\mathcal{S},\nu)$ on $p$ variables $Z_{1},\ldots,Z_{p}$ can be described using $p$ assignment functions $\{f_{1},\ldots,f_{p}\}=\mathcal{S}$:

$$Z_{j}\leftarrow f_{j}(Z_{\mathtt{pa}(j)},U_{j}),\qquad j=1,\ldots,p,$$

where $\mathtt{pa}(j)\subseteq\{1,\ldots,p\}$ is the set of parents, or the direct causes, of the variable $Z_{j}$, and $\nu(du)=\prod_{j=1}^{p}\nu_{j}(du_{j})$ is the joint distribution over the $p$ independent exogenous variables $(U_{1},\ldots,U_{p})$. For a given model $M$, there is an associated directed graph $G(M)=(V,E)$ that describes the causal relationships among variables, where $V=[p]$ is the set of nodes and $E$ is the edge set such that $(i,j)\in E$ if and only if $i\in\mathtt{pa}(j)$. $G(M)$ is acyclic if there is no sequence $(v_{1},\ldots,v_{k})$ with $k\geq 2$ such that $v_{1}=v_{k}$ and $(v_{i},v_{i+1})\in E$ for any $i\in[k-1]$.
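The acyclicity requirement in Definition 4 can be checked mechanically from the parent map. The sketch below is one standard way to do so (Kahn's topological-sort algorithm, which is our choice and is not named in the text). It takes a dictionary `pa` sending each node $j$ to its parent set $\mathtt{pa}(j)$ and uses the convention $(i,j)\in E$ if and only if $i\in\mathtt{pa}(j)$. The function name `is_acyclic` and the toy graphs are illustrative.

```python
from collections import deque


def is_acyclic(pa):
    """Return True if the graph with edges (i, j) for i in pa[j] has no directed cycle.

    `pa` must list every node as a key, including nodes with no parents.
    """
    indegree = {j: len(parents) for j, parents in pa.items()}
    children = {j: [] for j in pa}
    for j, parents in pa.items():
        for i in parents:
            children[i].append(j)

    # Repeatedly remove nodes whose parents have all been removed.
    queue = deque(j for j, deg in indegree.items() if deg == 0)
    removed = 0
    while queue:
        i = queue.popleft()
        removed += 1
        for j in children[i]:
            indegree[j] -= 1
            if indegree[j] == 0:
                queue.append(j)

    # Nodes on a directed cycle are never removed.
    return removed == len(pa)


chain = {1: set(), 2: {1}, 3: {2}}   # 1 -> 2 -> 3
loop = {1: {3}, 2: {1}, 3: {2}}      # 1 -> 2 -> 3 -> 1
print(is_acyclic(chain), is_acyclic(loop))
```

A self-loop, which corresponds to $k=2$ and $v_{1}=v_{2}$ in the definition, is also reported as a cycle.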

As in Peters et al. (2016), we consider the following data-generating process in $|\mathcal{E}|$ environments. For each $e\in\mathcal{E}$, the process governing $p=d+1$ random variables $Z^{(e)}=(Z_{1}^{(e)},\ldots,Z_{d+1}^{(e)})=(X_{1}^{(e)},\ldots,X_{d}^{(e)},Y^{(e)})$ is derived from an SCM $M^{(e)}(\mathcal{S}^{(e)},\nu)$, whose induced graph $G(M^{(e)})$ is acyclic, with assignments

$$\begin{aligned}X_{j}^{(e)}&\leftarrow f_{j}^{(e)}(Z_{\mathtt{pa}(j)}^{(e)},U_{j}),\qquad j=1,\ldots,d\\ Y^{(e)}&\leftarrow f_{d+1}(X_{\mathtt{pa}(d+1)}^{(e)},U_{d+1}).\end{aligned} \tag{3.2}$$

Here the distribution of the exogenous variables $(U_{1},\ldots,U_{d+1})$, the cause-effect relationship graph $G$, and the structural assignment $f_{d+1}$ are invariant across $e\in\mathcal{E}$, while the structural assignments for $X$ may vary among $e\in\mathcal{E}$. We use the superscript $(e)$ to highlight this heterogeneity, which may arise from performing arbitrary interventions on the variables $X$. We write $Z_{\mathtt{pa}(j)}$ to emphasize that $Y$ can be a direct cause of some variables in the covariate vector. See an example in Fig. 1 (a).

To present the result, we consider an augmented SCM that incorporates the environment label $e$ as a variable $E$. We consider the case where $\mathcal{E}=\{0,\ldots,|\mathcal{E}|-1\}$. We let $0$ be the observational environment, and the rest are the interventional environments where some unknown, arbitrary interventions are applied to the variables in the set $I\subseteq[d]$ defined as $I:=\{j:\exists e\in\mathcal{E}~\text{s.t.}~f_{j}^{(e)}\neq f_{j}^{(0)}\}$. The interventions can be arbitrary: an intervention can be a “hard” do-intervention that sets $X_{j}$ to $v_{j}$, or a soft intervention that slightly perturbs the association, e.g., replacing $X_{j}\leftarrow 2X_{k}+U_{j}$ by $X_{j}\leftarrow 1.5X_{k}+U_{j}$. The shared cause-effect relationships in all the environments are encoded by $G$, or $\{\mathtt{pa}(j)\}_{j=1}^{d+1}$.
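A toy simulation may help to see why a soft intervention creates heterogeneity for a child of $Y$ but not for a parent. The sketch below is a hypothetical three-variable linear Gaussian instance of (3.2), not an example from the paper: $X_{1}\leftarrow U_{1}$, $Y\leftarrow X_{1}+U_{Y}$, and $X_{2}\leftarrow c_{e}Y+U_{2}$ with $c_{0}=2$ and $c_{1}=1.5$, which mirrors the soft intervention described above with $Y$ in the role of the parent. All names (`simulate`, `ols_slope`, `slopes`) are illustrative.

```python
import random


def simulate(env, n, seed):
    """Draw n samples (x1, x2, y) from the toy SCM in environment env (0 or 1)."""
    rng = random.Random(seed)
    coef = {0: 2.0, 1: 1.5}[env]                 # soft intervention on X2 only
    rows = []
    for _ in range(n):
        x1 = rng.gauss(0.0, 1.0)                 # X1 <- U1 (not intervened)
        y = x1 + rng.gauss(0.0, 1.0)             # Y <- X1 + U_Y (invariant)
        x2 = coef * y + rng.gauss(0.0, 1.0)      # X2 <- c_e * Y + U2 (child of Y)
        rows.append((x1, x2, y))
    return rows


def ols_slope(xs, ys):
    """Slope of the least-squares line of ys on xs."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    sxx = sum((x - mean_x) ** 2 for x in xs)
    return sxy / sxx


slopes = {}
for env in (0, 1):
    x1, x2, y = zip(*simulate(env, 20000, seed=env))
    slopes[env] = (ols_slope(x1, y), ols_slope(x2, y))
    print(env, slopes[env])
```

In this toy model the regression of $Y$ on its parent $X_{1}$ has population slope 1 in both environments, while the population slope of $Y$ on its child $X_{2}$ is $2c_{e}/(2c_{e}^{2}+1)$, which changes with the intervention.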

The following SCM $\widetilde{M}=(\widetilde{\mathcal{S}},\widetilde{\nu})$ on $d+2$ variables $Z=(Z_{1},\ldots,Z_{d},Z_{d+1},Z_{d+2})=(X_{1},\ldots,X_{d},Y,E)$ encodes all the information of the $|\mathcal{E}|$ models $\{M^{(e)}(\mathcal{S}^{(e)},\nu)\}_{e\in\mathcal{E}}$ in (3.2). Denote $\nu_{b}\sim\mathrm{Uniform}(\mathcal{E})$. Here $\widetilde{\nu}(du_{1},\ldots,du_{d+2})=\nu(du_{1},\ldots,du_{d+1})\,\nu_{b}(du_{d+2})$, and the assignments $\widetilde{\mathcal{S}}=\{\widetilde{f}_{1},\ldots,\widetilde{f}_{d+2}\}$ are defined as

$$\begin{aligned}E&\leftarrow\widetilde{f}_{d+2}(U_{d+2}):=U_{d+2}\\ X_{j}&\leftarrow\begin{cases}\widetilde{f}_{j}(Z_{\mathtt{pa}(j)},U_{j}):=f^{(0)}_{j}(Z_{\mathtt{pa}(j)},U_{j})&\qquad\forall j\in[d]\setminus I\\ \widetilde{f}_{j}(Z_{\mathtt{pa}(j)},E,U_{j}):=f^{(E)}_{j}(Z_{\mathtt{pa}(j)},U_{j})&\qquad\forall j\in I\end{cases}\\ Y&\leftarrow\widetilde{f}_{d+1}(X_{\mathtt{pa}(d+1)},U_{d+1}):=f_{d+1}(X_{\mathtt{pa}(d+1)},U_{d+1}),\end{aligned} \tag{3.3}$$

where $I$ is the set of all intervened variables in $\mathcal{E}$. It should be noted that throughout this section, the direct cause map $\mathtt{pa}:[d+1]\to[d+1]$ matches the causal relationship $G$ instead of $\widetilde{G}=G(\widetilde{M})$. See a graphical illustration of the construction in Fig. 1 (b).
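The construction in (3.3) amounts to drawing the environment label first and then letting only the intervened variables read it. A minimal sketch under a hypothetical toy model (two environments, $I=\{2\}$, linear Gaussian assignments chosen purely for illustration) is given below. `sample_augmented` and `F_X2` are illustrative names.

```python
import random

ENVIRONMENTS = (0, 1)

# Environment-specific assignments f_2^{(e)} for the single intervened variable X2.
F_X2 = {
    0: lambda y, u: 2.0 * y + u,
    1: lambda y, u: 1.5 * y + u,
}


def sample_augmented(rng):
    """Draw one sample (E, X1, Y, X2) from the toy augmented SCM."""
    e = rng.choice(ENVIRONMENTS)                 # E <- U_{d+2}, uniform on the environments
    x1 = rng.gauss(0.0, 1.0)                     # X1 is not in I: uses f^(0), ignores E
    y = x1 + rng.gauss(0.0, 1.0)                 # Y <- f_{d+1}(X_pa(d+1), U_{d+1}), ignores E
    x2 = F_X2[e](y, rng.gauss(0.0, 1.0))         # X2 is in I: uses f^(E)
    return {"E": e, "x1": x1, "y": y, "x2": x2}


rng = random.Random(0)
draws = [sample_augmented(rng) for _ in range(2000)]
share_env_1 = sum(d["E"] for d in draws) / len(draws)
print(share_env_1)
```

In the graph of this augmented model the node $E$ points only to the variables in $I$, which is why the parent map of the original graph $G$ and that of $\widetilde{G}$ differ exactly on $I$.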

We summarize the above construction as a condition.

[Figure 1: SCM graphs on the nodes $X_{1},\ldots,X_{11}$ and $Y$. Panels: $M^{(0)}$; $M^{(1)}$; the augmented model $\widetilde{M}$ with the environment node $E$; and a panel tabulating $\mathcal{E}$ against $S_{\star}$ in (3.4): $0\leftrightarrow 0$, $\{1,2,3,5,6,7,8,9\}$; $0\leftrightarrow 1$, $\{1,2,3,5,6,7,8,9\}$; $0\leftrightarrow 2$, $\{1,2,3,5,6,7\}$; $0\leftrightarrow 3$, $\{1,2,3,7\}$; $0\leftrightarrow 4$, $\{1,2,3\}$.]
Figure 1: (a) An illustration of the two-environment model. The SCMs in the two environments share the same associated graph: $M^{(0)}$ is an observational environment, and $M^{(1)}$ is an intervention environment in which some unknown intervention is applied to $(X_4, X_6, X_7)$, where $M^{(0)}$ and $M^{(1)}$ are defined as in (3.2). (b) A visualization of $\widetilde{G}$, the associated graph of $\widetilde{M}$ constructed from $(M^{(0)}, M^{(1)})$ and (3.3), which is another depiction of the environments in (a).
(c) An illustration of Theorem 2, showing how $S_\star$ therein changes as more and more environments are observed: an arrow from $E$ to $X_j$ with color $e$ means that $X_j$ is intervened on in environment $e \in \{1,2,3,4\}$. For example, $0 \leftrightarrow 3$ means that with interventions in environments 1, 2, and 3, the invariant variable set is $\{1,2,3,7\}$. Although $X_7$ is reverse causal and hence related to $Y$, we cannot know this based only on the given environments.
Condition 4 (SCM with Interventions on $X$).

Suppose $M^{(0)}, \ldots, M^{(|\mathcal{E}|-1)}$ are defined by (3.2), and $G$ is acyclic. Let $\widetilde{M}$ be the model constructed as in (3.3) from $\{M^{(e)}\}_{e\in\mathcal{E}}$, with $I$ being the given set of intervened variables.

3.2 Maximum Invariant Set as the Pragmatic Direct Causes

We characterize which $S^\star$ satisfies (3.1) given a fixed intervention set $I$, and how large $I$ should be to recover $Y$'s direct causes under arbitrary types of interventions. We define $\mathtt{ch}(k) := \{j : k \in \mathtt{pa}(j)\}$ as the set of children of variable $k$, and $\mathtt{at}(k)$ as the set of all ancestors of the variable $Z_k$, defined recursively as $\mathtt{at}(k) = \mathtt{pa}(k) \cup \bigcup_{j\in\mathtt{pa}(k)} \mathtt{at}(j)$ in the topological order of $G$. The following condition rules out some degenerate cases.

Condition 5 (Nondegenerate Interventions).

The following holds for $\widetilde{M}$: (a) $\forall S \subseteq [d]$ containing $Y$'s descendants, if $E \not\perp\!\!\!\perp_{\widetilde{M}} Y \mid X_S$, then there exist some $e, e' \in \mathcal{E}$ such that $(\mu^{(e)} \land \mu^{(e')})(\{m^{(e,S)} \neq m^{(e',S)}\}) > 0$; (b) $\widetilde{M}$ is faithful, that is,

$$\forall~\text{disjoint}~A, B, C \subseteq [d+2], \qquad Z_A \perp\!\!\!\perp Z_B \mid Z_C ~~\Longrightarrow~~ Z_A \perp\!\!\!\perp_{\widetilde{G}} Z_B \mid Z_C,$$

where $Z_A \perp\!\!\!\perp_{\widetilde{G}} Z_B \mid Z_C$ means that the node sets $A$ and $B$ are d-separated by $C$ in the graph $\widetilde{G}$; see Definition 2.4.1 in Glymour et al. (2016) for a formal definition of d-separation.

Condition (b), faithfulness with respect to the graph $\widetilde{G}$, requires that $\widetilde{G}$ truly depicts all the conditional independence relationships; it is widely used in the causal discovery literature. Condition (a) is further imposed since we only leverage the information in conditional expectations rather than in conditional distributions. We impose Condition 5 so that the dependence on $E$ of the conditional expectation of $Y$ given $X_S$, for any $S \subseteq [d]$, can be represented by the graph $\widetilde{G}$ itself. Condition 5 rules out some degenerate cases; see the justifications for Condition 5 and some degenerate examples in Section A.5. It should be noted that our general results in Theorem 2 and Proposition 1 apply to arbitrary forms of interventions under Condition 5, which is a mild condition, as the violation of faithfulness in Condition 5 occurs with probability zero under some suitable measure on the model (Spirtes et al., 2000).

Theorem 2 (General Identification under SCM with Interventions on $X$).

Under Condition 4, for

$$S_\star = \mathtt{pa}(d+1) \cup A(I) \cup \bigcup_{j\in A(I)} \left(\mathtt{pa}(j) \setminus \{d+1\}\right) \tag{3.4}$$

with $A(I) = \{j : j \in \mathtt{ch}(d+1),\ \mathtt{at}(j) \cap \mathtt{ch}(d+1) \cap I = \emptyset\}$, we have the invariance $m^{(e,S_\star)} \equiv \bar{m}^{(S_\star)} := m_\star$. If Condition 5 further holds, then Condition 2 holds with $S^\star = S_\star$ and $m^\star = m_\star$.

Theorem 2 exactly characterizes what $S^\star$ is in our nonparametric invariance pursuit under the SCM with interventions on $X$; it does not require the interventions to be "sufficient". Firstly, such an $S^\star$ is well-defined in that there exists one maximum set $S_\star$ satisfying the invariance condition (1.1) and the heterogeneity Condition 2 simultaneously. Secondly, in the SCM setting, such an $S^\star = S_\star$ can be represented in a simple way by (3.4), and it lies between the set of $Y$'s direct causes and the Markov blanket of $Y$. Note that $A(I)$ can be interpreted as the children of $Y$ that are "unaffected" by the interventions $I$. Theorem 2 states explicitly that the pursued set of invariant variables $S^\star$ is the union of (1) the parents of $Y$, (2) the unaffected children of $Y$, and (3) the parents of these unaffected children. The set $S^\star$ shrinks as $I$ enlarges. It finally recovers the set of direct causes of $Y$ when $I$ includes the "root children set" $I^\star$, as stated in the following Proposition 1. See an illustration in Fig. 1 (c).

Proposition 1 (Direct Cause Recovery).

(Sufficiency) Under Condition 4, define $I^\star = \{j : j \in \mathtt{ch}(d+1),\ \mathtt{at}(j) \cap \mathtt{ch}(d+1) = \emptyset\}$. If Condition 5 holds and $I \supseteq I^\star$, then Condition 2 holds with $S^\star = \mathtt{pa}(d+1)$.

(Necessity) Moreover, if $\bar{m}^{(S^\star \cup S)} \neq m^\star$ for any $S$ with $\mathtt{ch}(d+1) \cap S \neq \emptyset$, i.e., $Y$ does not have degenerate children, then Condition 2 holds only if $I \supseteq I^\star$.

We refer to $I^\star$ as the minimal intervention set because it is the exact minimal set of variables that should be intervened on for exact direct cause recovery in general, nondegenerate cases. The set $I^\star$ is determined by the cause-effect relationship graph $G$. In particular, $I^\star$ is $\{6,7\}$ for the example in Fig. 1. Notably, $X_8$ does not require intervention, as $X_7$, one of its ancestors, is included in $I^\star$.
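The definitions of $\mathtt{ch}(\cdot)$, $\mathtt{at}(\cdot)$, and $I^\star$ can be made concrete with a short sketch. The graph below is a hypothetical toy DAG chosen for illustration (it is not the graph of Fig. 1), and the names `PARENTS`, `children`, `ancestors`, and `root_children` are ours. Node 3 is a child of $Y$ whose ancestor, node 2, is also a child of $Y$, so only node 2 belongs to $I^\star$, mirroring the role of $X_8$ and $X_7$ in the remark above.

```python
# Hypothetical DAG, encoded as node -> set of parents; "Y" is the response.
PARENTS = {
    1: set(),
    "Y": {1},
    2: {"Y"},
    3: {"Y", 2},
    4: {3},
}

def children(k):
    """ch(k) = {j : k in pa(j)}."""
    return {j for j, pa in PARENTS.items() if k in pa}

def ancestors(k):
    """at(k) = pa(k) united with at(j) for every parent j (recursive)."""
    result = set(PARENTS[k])
    for j in PARENTS[k]:
        result |= ancestors(j)
    return result

def root_children(target):
    """I* = {j in ch(target) : at(j) and ch(target) are disjoint}."""
    ch = children(target)
    return {j for j in ch if not (ancestors(j) & ch)}

i_star = root_children("Y")
print(i_star)  # prints {2}
```

In this toy graph, intervening on node 2 alone would suffice in the sense of Proposition 1, provided the remaining conditions of the proposition hold.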

Unfortunately, $S_\star \supsetneq \mathtt{pa}(d+1)$ when $I^\star \not\subseteq I$ in general. This is because the observed environments lack the evidence needed to falsify the hypothesis that some variables in $S^\star$ are direct causes. Nevertheless, $S_\star$ in this setup can still be interpreted as the "contemporary direct causes" or "pragmatic direct causes" of $Y$ based on the observed environments. If future interventions are made within the set $I$, then $S_\star$ can be regarded as the direct causes, since the conditional expectation of $Y$ given $X_{S_\star}$ will remain invariant in a new environment $t$. Moreover, one can deploy such a predictor in unseen environments because it is the most predictive one among all the associations in environment $0$ that remain valid in environment $t$. This is formally stated in Proposition 2.

Proposition 2 (Robust Transfer Learning).

Under Condition 4, for a new environment $t$ with SCM $M^{(t)} = \{\mathcal{S}^{(t)}, \nu\}$ satisfying $f_j^{(t)} \equiv f_j^{(0)}$ for any $j \in [d+1] \setminus I$, i.e., only $X_I$ is intervened on, we have $\mathbb{E}[Y^{(t)} \mid X_{S_\star}^{(t)}] \equiv \mathbb{E}[Y^{(0)} \mid X_{S_\star}^{(0)}]$ with $S_\star$ in (3.4).
If Condition 5 holds and $M^{(t)}$ satisfies a condition akin to Condition 5 (see Section A.6), then $S_\star$ is the maximum set whose conditional expectation is transferable, in that for any $S \subseteq [d]$ such that $\mathbb{E}[Y^{(t)} \mid X_{S_\star \cup S}^{(t)}] \neq \mathbb{E}[Y^{(t)} \mid X_{S_\star}^{(t)}]$, one has $\mathbb{E}[Y^{(t)} \mid X_S^{(t)}] \neq \mathbb{E}[Y^{(0)} \mid X_S^{(0)}]$.

4 A Unified Framework

The proposed FAIR-NN least squares estimator is a special instance of our generic FAIR estimation framework, which unifies different risk losses and prediction models. Moreover, our framework also allows the user to incorporate additional structural knowledge into estimation such that identification is sometimes viable when $|\mathcal{E}| = 1$. The invariance pursuit problem, the estimation method, and the non-asymptotic results will be presented in a unified manner in this section.

4.1 General Invariance Pursuit from Heterogeneous Environments

In this section, we formalize the problem of invariance pursuit using data from multiple environments, which admits the canonical nonparametric invariance pursuit in Section 1.1 as a special case.

Let $Y \in \mathbb{R}$ be the response variable and $X \in \mathbb{R}^d$ be the explanatory variable. We consider the general setting in which we have collected data from multiple environments $\mathcal{E} = \{e_1, \ldots, e_{|\mathcal{E}|}\}$, where $\mathcal{E}$ is a finite set of environments. In each environment $e \in \mathcal{E}$, we observe $n$ i.i.d. observations $\{(X_i^{(e)}, Y_i^{(e)})\}_{i=1}^n$ drawn from some distribution $\mu^{(e)}$. Let $\Theta_g, \Theta_f \subseteq \Theta$ be the classes of prediction functions and testing functions, respectively. Our goal is to estimate the underlying invariant regression function $g^\star \in \Theta_g$ satisfying the invariance structure

$$\forall e \in \mathcal{E} \qquad \mathbb{E}\left[\left(Y^{(e)} - g^\star(X_{S^\star}^{(e)})\right) f(X^{(e)}_{S^\star})\right] = 0 \qquad \forall f \in [\Theta_f]_{S^\star}, \tag{4.1}$$

where $S^\star$ is the unknown set of true important variables. We refer to the above problem as invariance pursuit or causal pursuit interchangeably, as there is no evidence against causality in the available experiments.

The problem of estimating $g^\star$ in (4.1) is a generalized version of the canonical nonparametric invariance pursuit, which corresponds to $g^\star = m^\star$ in (1.1) and $\Theta_f = \Theta_g = \Theta$. It takes a general form and unifies several problems of interest studied in prior work. For example, when $\Theta_g$ and $\Theta_f$ are both linear function classes, it reduces to the linear invariance pursuit problem, i.e., estimating $g^\star(x) = (\beta^\star)^\top x = (\beta^\star_{S^\star})^\top x_{S^\star}$ with $\beta^\star \in \mathbb{R}^d$ satisfying $\mathrm{supp}(\beta^\star) = S^\star$ in the multi-environment linear regression (Fan et al., 2023) with linear invariance structure

$$\mathbb{E}\left[\left(Y^{(e)} - (\beta^\star_{S^\star})^\top X_{S^\star}^{(e)}\right) X_j^{(e)}\right] = 0 \qquad \forall e \in \mathcal{E},\ j \in S^\star. \tag{4.2}$$
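To see what the linear invariance structure (4.2) demands, the following minimal sketch checks its empirical counterpart on simulated data. Everything here is an illustrative assumption rather than the paper's experiment: a hypothetical linear SCM $X_1 \to Y \to X_2$, two environments that differ only in the mechanism of the child $X_2$, and a helper `max_moment_violation` that fits one pooled least squares coefficient on $X_S$ and reports the largest per-environment moment. This is a diagnostic of (4.2), not the FAIR estimator.

```python
import numpy as np

def simulate_env(a, n, rng):
    """Hypothetical linear SCM: X1 -> Y -> X2, with X2 = a * Y + noise.
    Only the coefficient a changes across environments."""
    x1 = rng.normal(size=n)
    y = 2.0 * x1 + rng.normal(size=n)
    x2 = a * y + rng.normal(size=n)
    return np.column_stack([x1, x2]), y

def max_moment_violation(envs, S):
    """Fit pooled least squares on X_S, then return the largest
    |mean((Y - beta^T X_S) * X_j)| over environments e and j in S."""
    Xs = np.vstack([X[:, S] for X, _ in envs])
    ys = np.concatenate([y for _, y in envs])
    beta, *_ = np.linalg.lstsq(Xs, ys, rcond=None)
    worst = 0.0
    for X, y in envs:
        resid = y - X[:, S] @ beta
        moments = (X[:, S] * resid[:, None]).mean(axis=0)
        worst = max(worst, float(np.abs(moments).max()))
    return worst

rng = np.random.default_rng(0)
envs = [simulate_env(a, 20_000, rng) for a in (1.0, 3.0)]

violation_parent = max_moment_violation(envs, [0])         # S = {X1}
violation_with_child = max_moment_violation(envs, [0, 1])  # S = {X1, X2}
print(violation_parent, violation_with_child)
```

Under these assumptions, the parent-only set satisfies (4.2) up to sampling noise in both environments, whereas including the endogenous child $X_2$ leaves per-environment moments far from zero even though the pooled moments vanish by construction.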

Another example is the augmented linear invariance pursuit, where $\Theta_g$ is linear and $\Theta_f = \{f(x) = \sum_{j=1}^d \beta_{0,j} x_j + \beta_{1,j} \phi(x_j)\}$ with some transform function $\phi : \mathbb{R} \to \mathbb{R}$. This can be further generalized to multiple transformed testing functions such as $\phi_1(x_j) = x_j^2$ and $\phi_2(x_j) = |x_j|$, but we keep one here for simplicity. The augmented linear invariance structure that realizes (4.1) in this case is

$$\mathbb{E}\left[\left(Y^{(e)} - (\beta^\star_{S^\star})^\top X_{S^\star}^{(e)}\right) X_j^{(e)}\right] = \mathbb{E}\left[\left(Y^{(e)} - (\beta^\star_{S^\star})^\top X_{S^\star}^{(e)}\right) \phi(X_j^{(e)})\right] = 0 \qquad \forall e \in \mathcal{E},\ j \in S^\star. \tag{4.3}$$

It coincides with the problem considered by Fan & Liao, (2014) when ||=11|\mathcal{E}|=1| caligraphic_E | = 1 and our method reduces to the FGMM method therein. The augmented linear invariance pursuit leverages further a part of the structural knowledge that 𝔼[Y(e)|XS(e)]=(βS)XS(e)𝔼delimited-[]conditionalsuperscript𝑌𝑒superscriptsubscript𝑋superscript𝑆𝑒superscriptsubscriptsuperscript𝛽superscript𝑆topsuperscriptsubscript𝑋superscript𝑆𝑒\mathbb{E}[Y^{(e)}|X_{S^{\star}}^{(e)}]=(\beta^{\star}_{S^{\star}})^{\top}X_{S% ^{\star}}^{(e)}blackboard_E [ italic_Y start_POSTSUPERSCRIPT ( italic_e ) end_POSTSUPERSCRIPT | italic_X start_POSTSUBSCRIPT italic_S start_POSTSUPERSCRIPT ⋆ end_POSTSUPERSCRIPT end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ( italic_e ) end_POSTSUPERSCRIPT ] = ( italic_β start_POSTSUPERSCRIPT ⋆ end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_S start_POSTSUPERSCRIPT ⋆ end_POSTSUPERSCRIPT end_POSTSUBSCRIPT ) start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT italic_X start_POSTSUBSCRIPT italic_S start_POSTSUPERSCRIPT ⋆ end_POSTSUPERSCRIPT end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ( italic_e ) end_POSTSUPERSCRIPT, which is much weaker than the assumption 𝔼[Y(e)|X(e)]=(βS)XS(e)𝔼delimited-[]conditionalsuperscript𝑌𝑒superscript𝑋𝑒superscriptsubscriptsuperscript𝛽superscript𝑆topsuperscriptsubscript𝑋superscript𝑆𝑒\mathbb{E}[Y^{(e)}|X^{(e)}]=(\beta^{\star}_{S^{\star}})^{\top}X_{S^{\star}}^{(% e)}blackboard_E [ italic_Y start_POSTSUPERSCRIPT ( italic_e ) end_POSTSUPERSCRIPT | italic_X start_POSTSUPERSCRIPT ( italic_e ) end_POSTSUPERSCRIPT ] = ( italic_β start_POSTSUPERSCRIPT ⋆ end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_S start_POSTSUPERSCRIPT ⋆ end_POSTSUPERSCRIPT end_POSTSUBSCRIPT ) start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT italic_X start_POSTSUBSCRIPT italic_S start_POSTSUPERSCRIPT ⋆ end_POSTSUPERSCRIPT end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ( italic_e ) end_POSTSUPERSCRIPT in the sparse linear regression. 
Identification is possible in this case even when $|\mathcal{E}|=1$. This is important for most biomedical studies, where data are usually collected in similar settings. In this case, the FAIR penalty eliminates endogenous spurious variables, making traditional variable selection methods applicable.

Remark 7.

We point out here that there are two kinds of spurious variables. One kind is endogenous spurious variables, such as $X_{2}=$ background color; the other is exogenous spurious variables, such as $X_{3}=$ the time the photo was taken or the type of camera used. The former are harmful, whereas the latter are nearly harmless in statistical prediction, transfer learning, and even statistical attribution or causality, if one thinks of $X_{3}$ as a weak causal variable. The purpose of our FAIR method is to surely screen out the endogenous spurious variables (Fan & Lv, 2008). Exogenous spurious variables can then be reduced using commonly used statistical variable selection methods.

Similar to the discussion in Section 1.1, the main challenge here is the curse of endogeneity. To address this issue, we harness the insight that the distributions of $(X,Y)$ across diverse environments capture the invariance structure (4.1). The central idea of this paper is to exploit the heterogeneity among different environments, i.e., the shifts in population distributions $\mu^{(e)}$, in conjunction with the above invariance structure (4.1) to pinpoint the invariant regression function $g^{\star}$.

It should be noted that both $g^{\star}$ and $S^{\star}$ are determined by $(\Theta_{g},\Theta_{f})$ and $\mathcal{E}$ through the structure (4.1). It is required that $\partial\Theta_{g}=\{g-g^{\prime}:g,g^{\prime}\in\Theta_{g}\}\subseteq\Theta_{f}$. In the case of $\Theta_{f}=\partial\Theta_{g}$, one uses only heterogeneity among different environments, or the “invariance principle”, to identify the invariant regression function $g^{\star}$, as in (4.2). Heterogeneous environments are essential in this case. By choosing a substantially larger $\Theta_{f}\supsetneq\partial\Theta_{g}$, one further injects the strong structural assumption that the invariant regression function lies in the class $\Theta_{g}$ rather than $\Theta_{f}\setminus\Theta_{g}$, as in (4.3). In this case, one leverages both heterogeneity among environments, i.e., the “invariance principle”, and the aforementioned prior structural knowledge, i.e., the “asymmetry principle”, to jointly identify $g^{\star}$. A single environment may suffice for identifying $g^{\star}$ when the two principles together give sufficient conditions.

4.2 General FAIR Estimation Framework

Let $\ell:\mathbb{R}\times\mathbb{R}\to\mathbb{R}$ be a user-determined risk loss such that

$$\frac{\partial\ell(y,v)}{\partial v}=(v-y)\psi(v)\qquad\text{and}\qquad\frac{\partial^{2}\ell(y,v)}{\partial v^{2}}>0, \tag{4.4}$$

which is slightly more general than the quasi-likelihood in the generalized linear model (Nelder & Wedderburn, 1972). The constraints in (4.4) ensure that the conditional expectation is the unique global minimizer, and they are satisfied by various risk losses. Two leading examples are the least squares loss $\ell(y,v)=\frac{1}{2}(y-v)^{2}$ with $\psi(v)=1$ for regression, and the cross-entropy loss $\ell(y,v)=-\log(1-v)-y\log\{v/(1-v)\}$ with $\psi(v)=1/\{v(1-v)\}$ for classification.
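The first condition in (4.4) can be checked numerically for the two leading examples. The snippet below is only an illustrative sketch, not part of the original procedure: the helper names (`numerical_derivative`, `max_gap`) and the grid points are assumptions made for this example. It compares a central finite-difference derivative of $\ell(y,\cdot)$ with $(v-y)\psi(v)$.

```python
import math

def squared_loss(y, v):
    # Least squares loss: l(y, v) = (y - v)^2 / 2, with psi(v) = 1.
    return 0.5 * (y - v) ** 2

def cross_entropy_loss(y, v):
    # Cross-entropy loss: l(y, v) = -log(1 - v) - y * log(v / (1 - v)), for v in (0, 1).
    return -math.log(1.0 - v) - y * math.log(v / (1.0 - v))

def psi_squared(v):
    return 1.0

def psi_cross_entropy(v):
    return 1.0 / (v * (1.0 - v))

def numerical_derivative(loss, y, v, h=1e-6):
    # Central finite difference in the second argument (hypothetical helper).
    return (loss(y, v + h) - loss(y, v - h)) / (2.0 * h)

def max_gap(loss, psi, ys, vs):
    # Largest |dl/dv - (v - y) * psi(v)| over a small grid of (y, v) pairs.
    return max(
        abs(numerical_derivative(loss, y, v) - (v - y) * psi(v))
        for y in ys
        for v in vs
    )

gap_ls = max_gap(squared_loss, psi_squared, [-1.0, 0.0, 2.0], [-0.5, 0.3, 1.7])
gap_ce = max_gap(cross_entropy_loss, psi_cross_entropy, [0.0, 1.0], [0.2, 0.5, 0.8])
print(gap_ls, gap_ce)
```

Both gaps should be at the level of finite-difference error, consistent with the stated $\psi$ for each loss.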

Given all the data $\{\{(X_{i}^{(e)},Y_{i}^{(e)})\}_{i=1}^{n}\}_{e\in\mathcal{E}}$ from heterogeneous environments, together with $(\Theta_{g},\Theta_{f})$ that may encode part of the prior information when $\Theta_{g}\neq\Theta$, our proposed focused adversarial invariance regularized estimator (FAIR estimator) is the solution to the following minimax optimization objective

$$\widehat{g}\in\mathop{\mathrm{argmin}}_{g\in\mathcal{G}}\sup_{f^{\mathcal{E}}\in\{\mathcal{F}_{S_{g}}\}^{|\mathcal{E}|}}\underbrace{\widehat{\mathsf{R}}(g)+\gamma\widehat{\mathsf{J}}(g,f^{\mathcal{E}})}_{=:\widehat{\mathsf{Q}}_{\gamma}(g,f^{\mathcal{E}})}, \tag{4.5}$$

where $\mathcal{G}\subseteq\Theta_{g}$ and $\mathcal{F}\subseteq\Theta_{f}$ are function classes that approximate $\Theta_{g}$ and $\Theta_{f}$, respectively. Here $\widehat{\mathsf{R}}(g)$ is the pooled sample mean of the user-specified loss across all the environments $\mathcal{E}$:

$$\widehat{\mathsf{R}}(g)=\frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}}\widehat{\mathbb{E}}\left[\ell(Y^{(e)},g(X^{(e)}))\right]=\frac{1}{|\mathcal{E}|\cdot n}\sum_{e\in\mathcal{E},i\in[n]}\ell(Y^{(e)}_{i},g(X^{(e)}_{i})), \tag{4.6}$$

$\gamma$ is a hyper-parameter to be determined, and $\widehat{\mathsf{J}}(g,f^{\mathcal{E}})$ is defined as in (2.2).
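As a concrete reading of the pooled risk (4.6), the following minimal sketch averages the per-environment sample means of the loss. The data layout (a dictionary from an environment label to a list of `(x, y)` pairs), the toy numbers, and the candidate function `g` are assumptions for illustration only; the penalty $\widehat{\mathsf{J}}$ from (2.2) is not implemented here.

```python
def squared_loss(y, v):
    return 0.5 * (y - v) ** 2

def pooled_risk(data, g, loss):
    # data: dict mapping an environment label to a list of (x, y) samples.
    # Returns the average over environments of the within-environment mean loss.
    env_means = [
        sum(loss(y, g(x)) for x, y in samples) / len(samples)
        for samples in data.values()
    ]
    return sum(env_means) / len(env_means)

# Toy data: two environments with n = 2 observations each (hypothetical values).
data = {
    "e1": [((1.0, 0.0), 2.0), ((2.0, 1.0), 5.0)],
    "e2": [((0.0, 3.0), 1.0), ((1.0, -1.0), 2.0)],
}

def g(x):
    # Candidate regression function that uses only the first covariate.
    return 2.0 * x[0]

risk = pooled_risk(data, g, squared_loss)

# With equal n per environment, this matches the flat average in (4.6).
all_samples = [pair for samples in data.values() for pair in samples]
flat = sum(squared_loss(y, g(x)) for x, y in all_samples) / len(all_samples)
print(risk, flat)
```

The two averages agree here because every environment has the same sample size $n$, as assumed in (4.6).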

Discussions and Extensions. From a high-level perspective, our proposed FAIR estimator searches for the most predictive variable set $S$ that preserves the invariance structure imposed by the specification of $(\Theta_{g},\Theta_{f})$. The FAIR estimation framework as presented has several limitations: (1) the loss $\ell$ is restricted in that the conditional expectation must uniquely minimize it; (2) the environment label is discrete; and (3) the discussion remains at the level of variable-selection invariance rather than general representation-level invariance. We discuss in Section A.3 how our entire framework can be easily extended to the cases where (1) and (2) fail to hold. We discuss the rationale, the comparison with IRM, and an extension addressing (3) in Section A.2.

4.3 Sketch of the Generic Result and Its Applications

The non-asymptotic results in Section 2 can be extended to the general FAIR estimation framework. The resulting statement, given formally in Theorem 4, unifies the identification condition and $L_{2}$ estimation errors for specific $(\Theta_{g},\Theta_{f})$ or $(\mathcal{G},\mathcal{F})$ under the least squares loss $\ell(y,v)=\frac{1}{2}(y-v)^{2}$. We sketch the main idea and give an informal statement here, deferring the complete result and applications to Appendix B.

Suppose $[\Theta_{g}]_{S}$ and $[\Theta_{f}]_{S}$ are closed subspaces of $\Theta_{S}$ for any $S\subseteq[d]$ so that one can define

$$\bar{g}^{(S)}(x)=\mathop{\mathrm{argmin}}_{g\in[\Theta_{g}]_{S}}\|g-\bar{m}^{(S)}\|_{2}\qquad\text{and}\qquad f^{(e,S)}(x)=\mathop{\mathrm{argmin}}_{f\in[\Theta_{f}]_{S}}\|f-m^{(e,S)}\|_{2,e}.$$

In this case, the invariant structure and the invariant regression function in (4.1) can be simplified as

$$f^{(e,S^{\star})}(x)\equiv\bar{g}^{(S^{\star})}(x):=g^{\star}(x). \tag{4.7}$$

Similar to the nonparametric bias mean and bias variance in Remark 3, we can define the generalized bias mean and bias variance with respect to $(\Theta_{g},\Theta_{f})$ as $\mathsf{b}(S)=\|\bar{g}^{(S\cup S^{\star})}-g^{\star}\|_{2}^{2}$ and $\bar{\mathsf{d}}(S)=\frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}}\|\bar{g}^{(S)}-f^{(e,S)}\|_{2,e}^{2}$. The general identification condition akin to 2 is

$$\forall~S\subseteq[d],\qquad\mathsf{b}(S)>0~~\Longrightarrow~~\bar{\mathsf{d}}(S)>0. \tag{4.8}$$

The above condition requires that whenever incorporating more variables in $S$ leads to better prediction performance, the set $S$ does not satisfy the invariance structure (4.1). 2 instantiates (4.8) by letting $\bar{\mathsf{d}}(S)=\bar{\mathsf{d}}_{\mathtt{NN}}(S)$ and $\mathsf{b}(S)=\mathsf{b}_{\mathtt{NN}}(S)$ with $(\mathsf{b}_{\mathtt{NN}}(S),\bar{\mathsf{d}}_{\mathtt{NN}}(S))$ defined in (2.4).
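To make condition (4.8) concrete, the sketch below checks it over all subsets of a toy problem with $d=2$. The numerical values of $\mathsf{b}(S)$ and $\bar{\mathsf{d}}(S)$ are hypothetical and chosen only for illustration (with $S^{\star}=\{1\}$ and $X_{2}$ playing the role of an endogenous spurious variable); in practice these are population quantities that are not directly observed.

```python
def violations(b, d_bar):
    # Sets S with b(S) > 0 but d_bar(S) = 0, i.e. the sets that violate (4.8).
    return [S for S in b if b[S] > 0 and not d_bar[S] > 0]

E = frozenset  # shorthand for building subset keys

# Hypothetical bias means; b depends on S only through the union of S and S*,
# so b({2}) = b({1, 2}) when S* = {1}.
b = {E(): 0.0, E({1}): 0.0, E({2}): 0.4, E({1, 2}): 0.4}

# Case 1: every set with positive bias mean also has positive bias variance.
d_bar_good = {E(): 0.0, E({1}): 0.0, E({2}): 0.3, E({1, 2}): 0.1}

# Case 2: {1, 2} improves prediction yet looks invariant, so (4.8) fails.
d_bar_bad = dict(d_bar_good)
d_bar_bad[E({1, 2})] = 0.0

holds_good = not violations(b, d_bar_good)
bad_sets = violations(b, d_bar_bad)
print(holds_good, bad_sets)
```

In the second case the environments carry too little heterogeneity to rule out $S=\{1,2\}$, which is exactly the situation the identification condition excludes.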

Theorem 3 (Main Result for FAIR Least Squares Estimator, Informal).

Under (4.7), (4.8) and some regularity conditions in regression, one can consistently estimate $g^{\star}$ by choosing $\gamma\geq 8\sup_{S:\mathsf{b}(S)>0}\{\mathsf{b}(S)/\bar{\mathsf{d}}(S)\}$. In this case, the FAIR estimator $\widehat{g}$ in (4.5) with $\ell(y,v)=\frac{1}{2}(y-v)^{2}$ satisfies, for any $n\geq 3$, w.h.p.,

$$\frac{\|\widehat{g}-g^{\star}\|_{2}}{C_{1}}\leq\delta_{\mathtt{stoc}}+\delta_{\mathtt{approx}}^{\star}+\gamma(\delta_{\mathtt{stoc}}+\delta_{\mathtt{approx}})1_{\{\delta_{\mathtt{stoc}}+\delta_{\mathtt{approx}}\geq\frac{s}{1+\gamma}\}}. \tag{4.9}$$

Here $\delta_{\mathtt{stoc}}$ is the stochastic error characterized by the local Rademacher complexity of $\mathcal{F},\partial\mathcal{G}$ and $n$, $\delta_{\mathtt{approx}}^{\star}$ measures a certain approximation error of $(\mathcal{G},\mathcal{F})$ w.r.t. $g^{\star}$, and $\delta_{\mathtt{approx}}$ measures the worst-case approximation error of $(\mathcal{G},\mathcal{F})$ w.r.t. all the $\{f^{(e,S)}\}$. The constant $s>0$ is the signal strength related to $\min_{S:\bar{\mathsf{d}}(S)>0}\bar{\mathsf{d}}(S)$ and $\min_{S:S^{\star}\setminus S\neq\emptyset}\inf_{g\in[\Theta_{g}]_{S}}\|g-g^{\star}\|_{2}$, and $C_{1}$ is a universal constant independent of the two quantities.

| $\Theta_{g}$ | $\Theta_{f}$ | $\mathcal{G}$ | $\mathcal{F}$ | Priors | $\lvert\mathcal{E}\rvert=1$ Ident | Result |
|---|---|---|---|---|---|---|
| Linear | Linear | Linear | Linear | None | Impossible | Thm 8 |
| Linear | Linear w/ $\phi$ | Linear | Linear w/ $\phi$ | Nearly Linear | Possible | Thm 9 |
| Linear | $\Theta$ | Linear | NN | Linear | Possible | Thm 10 |
| Additive | $\Theta$ | Additive NN | NN | Additive | Impossible | Thm 7 |
| $\Theta$ | $\Theta$ | NN | NN | None | Impossible | Thm 1 |

Table 1: Applications of Theorem 4. Recall that $\Theta$ is the set of all $L_{2}(\bar{\mu}_{x})$ functions. For the function classes in columns $\Theta_{g},\Theta_{f},\mathcal{G}$ and $\mathcal{F}$, “Linear” is $\{f(x)=\sum_{j=1}^{d}\beta_{j}x_{j}\}$, “Linear w/ $\phi$” is $\{f(x)=\sum_{j=1}^{d}\beta_{j}x_{j}+\alpha_{j}\phi(x_{j})\}$, “NN” is the deep ReLU network class, “Additive” is the class of additive functions $\{f(x)=\sum_{j=1}^{d}f_{j}(x_{j})\}$, and “Additive NN” is a structured neural network approximating additive functions. The column “Priors” indicates what prior structural knowledge is injected by the choice of $(\Theta_{g},\Theta_{f})$. For the second row, the prior is “nearly linear” because it only requires that the residual be uncorrelated with all the $\phi(x_{j})$ with $j\in S^{\star}$; the prior for the third row is exactly linear because $\Theta_{f}=\Theta$. The column “$|\mathcal{E}|=1$ Ident” indicates whether identification of $S^{\star}$ in (1.1) is possible with only one environment.

The complete and rigorous statement is deferred to Theorem 4 in Section B.1, with more loss functions $\ell$ covered in Theorem 5. These generic results characterize several advantages of our FAIR framework in sample efficiency. First, the error bound (4.9) is structure-agnostic in that it is the sum of an approximation error and a stochastic error, indicating that (1) our framework can fully exploit the capability of $(\mathcal{G},\mathcal{F})$ in learning low-dimensional structures, and (2) it incurs almost no additional cost in sample efficiency compared with standard regression. Moreover, the error rate applies to any $n$, implying that the estimation error is controlled even when the procedure selects the wrong variables, especially when the signal $s$ is weak. Finally, though a large enough regularization hyper-parameter $\gamma$ is needed to guarantee consistent estimation, the error is free of $\gamma$ when $n$ is large enough. We also apply our unified result to various specifications of $(\mathcal{G},\mathcal{F})$, obtaining non-asymptotic results on identification and convergence rates; see the summary in Table 1.
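The role of $\gamma$ in Theorem 3 can be illustrated with a small calculation. The sketch below computes the threshold $8\sup_{S:\mathsf{b}(S)>0}\{\mathsf{b}(S)/\bar{\mathsf{d}}(S)\}$ and evaluates the right-hand side of (4.9), which bounds $\|\widehat{g}-g^{\star}\|_{2}/C_{1}$. All numerical inputs (the values of $\mathsf{b}$, $\bar{\mathsf{d}}$, $s$, and the error terms) are hypothetical, and the function names are assumptions made for this example.

```python
def gamma_threshold(b, d_bar):
    # 8 * sup over {S : b(S) > 0} of b(S) / d_bar(S); requires condition (4.8).
    ratios = []
    for S, bias in b.items():
        if bias > 0:
            if not d_bar[S] > 0:
                raise ValueError("identification condition (4.8) fails for %r" % (S,))
            ratios.append(bias / d_bar[S])
    return 8.0 * max(ratios) if ratios else 0.0

def error_bound_rhs(delta_stoc, delta_approx_star, delta_approx, gamma, s):
    # Right-hand side of (4.9); the gamma term is active only when the
    # combined error is at least s / (1 + gamma).
    total = delta_stoc + delta_approx
    indicator = 1.0 if total >= s / (1.0 + gamma) else 0.0
    return delta_stoc + delta_approx_star + gamma * total * indicator

E = frozenset
b = {E(): 0.0, E({1}): 0.0, E({2}): 0.4, E({1, 2}): 0.4}      # hypothetical
d_bar = {E(): 0.0, E({1}): 0.0, E({2}): 0.3, E({1, 2}): 0.1}  # hypothetical

gamma_min = gamma_threshold(b, d_bar)

# Larger errors (small n): the indicator is on and the bound scales with gamma.
bound_small_n = error_bound_rhs(0.05, 0.02, 0.01, gamma_min, s=1.0)
# Smaller errors (large n): the indicator is off and the bound is free of gamma.
bound_large_n = error_bound_rhs(0.01, 0.02, 0.01, gamma_min, s=1.0)
print(gamma_min, bound_small_n, bound_large_n)
```

Once $\delta_{\mathtt{stoc}}+\delta_{\mathtt{approx}}$ drops below $s/(1+\gamma)$, the $\gamma$-dependent term vanishes, matching the remark that the error is free of $\gamma$ for large $n$.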

5 Experiments

5.1 An End-to-End Implementation

We realize the minimax optimization using gradient descent ascent, an approach similar to that adopted in GAN training (Goodfellow et al., 2014). The main challenge here is how to implement the “focused regularization”, which enforces $f^{(e)}\in\mathcal{F}_{S_{g}}$. Here we consider a re-parameterization trick that disentangles the function $g$ from the variable set $S_{g}$ it selects. To start with, we can write $g(x)=g(a\odot x)=g(x_{1}a_{1},\ldots,x_{d}a_{d})$ with $a\in\{0,1\}^{d}$ indicating the presence or absence of variables. Then the objective (4.5) can be written as

$$(\widehat{g},\widehat{a})\in\mathop{\mathrm{argmin}}_{g\in\mathcal{G},a\in\{0,1\}^{d}}\sup_{f^{\mathcal{E}}\in\{\mathcal{F}\}^{|\mathcal{E}|}}\widehat{\mathsf{R}}(g(a\odot\cdot))+\gamma\widehat{\mathsf{J}}(g(a\odot\cdot),f^{\mathcal{E}}(a\odot\cdot)). \tag{5.1}$$

A naive implementation is to first enumerate all possible $a\in\{0,1\}^{d}$ and then run gradient descent ascent for each given $a$, which is computationally inefficient since there are $2^{d}$ candidates. To avoid this, we first rewrite the optimization as a “continuous” optimization:

$$(\widehat{g},\widehat{w})\in\mathop{\mathrm{argmin}}_{g\in\mathcal{G},w\in\mathbb{R}^{d}}\sup_{f^{\mathcal{E}}\in\{\mathcal{F}\}^{|\mathcal{E}|}}\mathbb{E}_{B(w)}\left[\widehat{\mathsf{R}}(g(B(w)\odot\cdot))+\gamma\widehat{\mathsf{J}}(g(B(w)\odot\cdot),f^{\mathcal{E}}(B(w)\odot\cdot))\right],$$

where the $j$-th component of $B(w)\in\{0,1\}^{d}$ follows an independent Bernoulli distribution with success probability $\mathrm{sig}(w_{j})=\exp(w_{j})/(1+\exp(w_{j}))$. This is easily seen by taking $\widehat{w}=\mathrm{logit}(\widehat{a})=\log(\frac{\widehat{a}}{1-\widehat{a}})$. Note that $B_{j}(w_{j})=I(\mathrm{logit}(U_{j})\leq w_{j})$, where $U_{j}\sim\mathrm{uniform}[0,1]$, is discontinuous in $w_{j}$, but can be approximated as

$$B_j(w_j)\approx\frac{1}{1+e^{(\mathrm{logit}(U_j)-w_j)/\tau}}\equiv V_{\tau}(U_j,w_j)\quad\text{as}\quad\tau\to 0^{+}, \tag{5.2}$$

which is differentiable in $w_j$, so its gradient can be taken. Let

$$A_{\tau}(U,w)=(V_{\tau}(U_1,w_1),\ldots,V_{\tau}(U_d,w_d))^{\top}\in\mathbb{R}^{d}$$

with $\{U_j\}_{j=1}^{d}$ being i.i.d. uniform random variables. One can approximate the original objective (5.1) by

$$(\widehat{\theta},\widehat{w})\in\mathop{\mathrm{argmin}}_{\theta\in\mathbb{R}^{N_g},\,w\in\mathbb{R}^{d}}\ \sup_{\forall e\in\mathcal{E},\,\phi^{(e)}\in\mathbb{R}^{N_f}}\underbrace{\mathbb{E}_{A_{\tau}(U,w)}\left[\widehat{\mathsf{R}}(g(A\odot\cdot;\theta))+\gamma\widehat{\mathsf{J}}(g(A\odot\cdot),f^{\mathcal{E}}(A\odot\cdot;\{\phi^{(e)}\}_{e\in\mathcal{E}}))\right]}_{\mathbb{E}_{U}[\widehat{\mathsf{L}}(A_{\tau}(U,w),\theta,\{\phi^{(e)}\}_{e\in\mathcal{E}})]}, \tag{5.3}$$

where the parametrizations of $g\in\mathcal{G}$ and $f^{e}\in\mathcal{F}$ are used. Since $\mathrm{logit}(U_j)\stackrel{d}{=}U_{j,1}-U_{j,2}$ with $\{U_{j,1},U_{j,2}\}_{j=1}^{d}$ being i.i.d. Gumbel(0,1) random variables, the approximation (5.2) is also referred to as the Gumbel approximation.
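The relaxation in (5.2) can be checked numerically. The following is a minimal sketch, not the authors' implementation: the function names, the seed, the value $w=1$ and the temperatures are illustrative assumptions. It compares the hard gate $I(\mathrm{logit}(U_j)\leq w_j)$ with the soft gate $V_\tau(U_j,w_j)$ on the same uniform draws, and also draws the logistic noise as a difference of two Gumbel(0,1) variables.

```python
import math
import random


def uniform_open(rng):
    """Draw U from the open interval (0, 1) so that logit(U) is finite."""
    u = rng.random()
    while u <= 0.0:
        u = rng.random()
    return u


def logit(u):
    """Log-odds of u in (0, 1)."""
    return math.log(u / (1.0 - u))


def sigmoid(x):
    """Numerically stable logistic function sig(x)."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def hard_gate(u, w):
    """B_j(w_j) = I(logit(U_j) <= w_j), a Bernoulli(sig(w_j)) draw."""
    return 1.0 if logit(u) <= w else 0.0


def soft_gate(u, w, tau):
    """V_tau(U_j, w_j) = 1 / (1 + exp((logit(U_j) - w_j) / tau))."""
    return sigmoid((w - logit(u)) / tau)


def sample_logistic_noise(rng):
    """Difference of two independent Gumbel(0, 1) draws, distributed as logit(U)."""
    g1 = -math.log(-math.log(uniform_open(rng)))
    g2 = -math.log(-math.log(uniform_open(rng)))
    return g1 - g2


# Illustrative settings (assumptions, not taken from the paper).
rng = random.Random(0)
w = 1.0
uniforms = [uniform_open(rng) for _ in range(20000)]

# The hard gate is open with frequency close to sig(w).
hard_mean = sum(hard_gate(u, w) for u in uniforms) / len(uniforms)

# Mean absolute gap between soft and hard gates for decreasing temperatures.
mean_gap = {
    tau: sum(abs(soft_gate(u, w, tau) - hard_gate(u, w)) for u in uniforms)
    / len(uniforms)
    for tau in (1.0, 0.1, 0.01)
}

# Same frequency when the noise is drawn as a Gumbel difference.
gumbel_mean = sum(
    1.0 if sample_logistic_noise(rng) <= w else 0.0 for _ in range(20000)
) / 20000

print(hard_mean, gumbel_mean, sigmoid(w))
print(mean_gap)
```

In this sketch the gap between the soft and hard gates shrinks as $\tau$ decreases, which is the sense in which (5.2) holds as $\tau\to 0^{+}$, while the soft gate remains differentiable in $w_j$ for every fixed $\tau>0$.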

As is common in stochastic gradient descent with Gumbel approximation, one can gradually anneal the temperature hyperparameter $\tau$ during training. We defer the formal pseudo-code, Algorithm 1, to Section C.1. The code to reproduce the results in this section can be found at https://github.com/wmyw96/FAIR.

5.2 Simulations

In this section, we present the simulation results for the FAIR-Linear and FAIR-NN estimators implemented with the Gumbel approximation trick and the gradient descent ascent algorithm.

5.2.1 Finite Performance of FAIR-Linear Estimator

Data Generating Process. We consider the case where $|\mathcal{E}|=2$ and the data $(X^{(e)},Y^{(e)})$ in each environment $e\in\{0,1\}$ are generated from two SCMs sharing the same causal relationship between variables. For each trial, we first generate the parent-children relationship among the variables. We enumerate all $i\in[d+1]$ and, for each $i$, randomly pick at most $4$ parents for the variable $Z_i$ from $\{Z_1,\ldots,Z_{i-1}\}$; this step ensures that the induced graph is a DAG.
We fix $d=70$, let the variable $Z_{36}$ be $Y$, and let the remaining variables constitute the covariate $X$, that is, $(Z_1,\ldots,Z_{35},Z_{36},Z_{37},\ldots,Z_{71})=(X_1,\ldots,X_{35},Y,X_{36},\ldots,X_{70})$. We also enforce that $Y$ has at least $5$ parents and at least $5$ children by adding parents and children when needed. The structural assignment for each variable $Z_j$ is defined as

$$Z_j^{(e)}\leftarrow\sum_{k\in\mathtt{pa}(j)}C_{j,k}^{(e)}f_{j,k}^{(e)}(Z_k^{(e)})+C_{j,j}^{(e)}\varepsilon_j$$

where $(\varepsilon_1,\ldots,\varepsilon_{71})$ are independent standard normal random variables. For $j\neq 36$, the functions $f_{j,k}^{(e)}$ are sampled randomly from the candidate functions $\{\cos(x),\sin(x),\sin(\pi x),x,1/(1+e^{-x})\}$, and the coefficients $C_{j,k}^{(e)}$ are sampled from $\mathrm{Uniform}[-1.5,1.5]$ with $|C_{j,j}^{(e)}|\geq 0.5$. For $j=36$ and $k<j$, we have $f_{36,k}^{(e)}(x)=x$ and $C_{36,k}^{(0)}\equiv C_{36,k}^{(1)}$ for linearity and invariance.
The above data-generating process can be regarded as one observational environment $e=0$ and one interventional environment $e=1$ in which random and simultaneous interventions are applied to all the variables other than $Y$, while the assignment from $Y$'s parents to $Y$ remains unchanged and furnishes the target regression function $m^{\star}(x)=\sum_{k\in\mathtt{pa}(36)}C_{36,k}^{(e)}x_k$ in pursuit. In this case, we let $S^{\star}=\mathtt{pa}(36)$ and let $\beta^{\star}$, with support set $S^{\star}$, be such that $\beta_k^{\star}=C^{(0)}_{36,k}=C^{(1)}_{36,k}$ for any $k\in S^{\star}$.
We also let the noise variance differ between the two environments, i.e., $C_{36,36}^{(0)}\neq C_{36,36}^{(1)}$. Consequently, the model only has conditional expectation invariance rather than full conditional distribution invariance. Fig. 2 (a) visualizes the induced graph in one trial. The complex cause-effect relationships among the high-dimensional variables make the problem of causal pursuit and estimating $\beta^{\star}$ very challenging.
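The graph-sampling step described above can be sketched as follows. This is a hedged illustration rather than the authors' code: the function name `sample_parent_sets`, the uniform draw of the number of parents, and the way missing parents and children of $Y$ are added are all assumptions. Only the constraints come from the text: parents are drawn from earlier variables, at most $4$ per variable, and $Y=Z_{36}$ has at least $5$ parents and $5$ children.

```python
import random


def sample_parent_sets(d=70, y_index=36, max_parents=4,
                       min_y_parents=5, min_y_children=5, seed=0):
    """Sample parent sets for Z_1, ..., Z_{d+1} so that the graph is a DAG."""
    rng = random.Random(seed)
    n_nodes = d + 1
    parents = {i: set() for i in range(1, n_nodes + 1)}

    # Parents of Z_i come only from {Z_1, ..., Z_{i-1}}, so no cycle can arise.
    for i in range(2, n_nodes + 1):
        k = rng.randint(0, min(max_parents, i - 1))  # assumed uniform count
        parents[i] = set(rng.sample(range(1, i), k))

    # Add parents to Y until it has at least `min_y_parents`.
    candidates = [k for k in range(1, y_index) if k not in parents[y_index]]
    rng.shuffle(candidates)
    while len(parents[y_index]) < min_y_parents:
        parents[y_index].add(candidates.pop())

    # Add children to Y until it has at least `min_y_children`.
    later = range(y_index + 1, n_nodes + 1)
    children = [i for i in later if y_index in parents[i]]
    non_children = [i for i in later if y_index not in parents[i]]
    rng.shuffle(non_children)
    while len(children) < min_y_children:
        i = non_children.pop()
        parents[i].add(y_index)  # this node may now exceed max_parents by one
        children.append(i)

    return parents


parents = sample_parent_sets()
y_children = sorted(i for i, ps in parents.items() if 36 in ps)
print("parents of Y:", sorted(parents[36]))
print("children of Y:", y_children)
```

Restricting each variable's parents to lower-indexed variables makes the index order a topological order, which is why no separate acyclicity check is needed.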

Figure 2: The visualization of (a) the SCM and (b) $\mathrm{sig}(w)$ during training in one trial for the FAIR-Linear estimator. We use different colors to represent the different relationships with $Y$: blue = parent, red = child, orange = offspring, lightblue = other.

Implementation. For the FAIR-Linear estimator, we realize $\mathcal{G}$ and $\mathcal{F}$ by linear function classes, i.e., $\mathcal{G}=\{g(x)=\beta_g^{\top}x:\beta_g\in\mathbb{R}^{d}\}$ and $\mathcal{F}=\{f(x)=\beta_f^{\top}x:\beta_f\in\mathbb{R}^{d}\}$, and run gradient descent ascent using the Adam optimizer with a learning rate of 1e-3 and batch size $64$ for $50k$ iterations. In each iteration, one gradient descent update of the predictor parameters $\beta_g$ and the Gumbel logits $w$ is followed by three gradient ascent updates of the discriminators' parameters $(\beta_f^{(1)},\beta_f^{(2)})$.
We adopt a fixed hyper-parameter $\gamma=36$ and report the performance of the following estimators using the median of the estimation error $\|\widehat{\beta}-\beta^{\star}\|_2^2$ over $50$ replications for varying $n\in\{200,500,1000,2000,5000\}$.

  • (1) Pool-LS: it simply runs least squares on the full covariate $X$ using all the data.

  • (2) FAIR-GB: our FAIR-Linear estimator with Gumbel approximation, which outputs $\beta_g\odot\mathrm{sig}(w)$.

  • (3) FAIR-RF: it selects the variables $x_j$ with $\mathrm{sig}(w_j)>0.9$ in the fitted model from (2), i.e., $\widehat{S}=\{j:\mathrm{sig}(w_j)>0.9\}$, and refits least squares on $X_{\widehat{S}}$ using all the data.

  • (4) Oracle: it runs least squares on $X_{S^{\star}}$ using all the data.

  • (5) Semi-Oracle: it runs least squares on $X_{G^{c}}$ using all the data, where $G$ is the set of all the descendants of $Y$. Compared with the pooled least squares (ERM) in (1), it manually removes all the variables that lead to biased estimation; compared with the full Oracle estimator, it still keeps uncorrelated variables.
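The least-squares estimators in the list above differ only in which columns of $X$ they use. The sketch below, under stated assumptions, shows Pool-LS, the FAIR-RF refit with the $0.9$ threshold, and the Oracle on toy data. The helper names, the toy design, and the gate logits are hypothetical: the logits are hand-picked here rather than produced by FAIR training, and the toy data contain no endogenous variables.

```python
import numpy as np


def least_squares_on_subset(X, y, support):
    """Least squares of y on the columns in `support`; other coefficients are 0."""
    beta = np.zeros(X.shape[1])
    support = np.asarray(sorted(support), dtype=int)
    if support.size > 0:
        coef, *_ = np.linalg.lstsq(X[:, support], y, rcond=None)
        beta[support] = coef
    return beta


def fair_refit(X, y, gate_logits, threshold=0.9):
    """Select S_hat = {j : sig(w_j) > threshold} and refit least squares on it."""
    gate = 1.0 / (1.0 + np.exp(-gate_logits))
    selected = np.flatnonzero(gate > threshold)
    return least_squares_on_subset(X, y, selected), selected


# Toy data pooled over environments (illustrative assumption).
rng = np.random.default_rng(0)
n, d = 500, 4
X = rng.normal(size=(n, d))
beta_star = np.array([1.0, -2.0, 0.0, 0.0])
y = X @ beta_star + 0.1 * rng.normal(size=n)

# Hand-picked gate logits standing in for a trained w (assumption).
gate_logits = np.array([4.0, 3.0, -3.0, 0.0])

beta_pool = least_squares_on_subset(X, y, range(d))   # (1) Pool-LS
beta_refit, selected = fair_refit(X, y, gate_logits)  # (3) FAIR-RF
beta_oracle = least_squares_on_subset(X, y, [0, 1])   # (4) Oracle with S* = {0, 1}

print("selected:", selected.tolist())
print("refit error:", float(np.sum((beta_refit - beta_star) ** 2)))
```

When the selected set $\widehat{S}$ coincides with $S^{\star}$, the refit and the Oracle are the same least-squares problem, which is the mechanism behind the near-oracle behavior of FAIR-RF reported for large $n$.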

Fig. 2 (b) visualizes how the Gumbel gate values $\mathrm{sig}(w)$ for different covariates evolve during training in one trial. We can see that $\mathrm{sig}(w_j)$ for $j\in S^{\star}$ quickly increases and dominates the values for other variables, such as children/offspring of $Y$, during the whole training process.

Results. The results are shown in Fig. 3 (a). We can see that the squared $\ell_2$ estimation error $\|\widehat{\beta}-\beta^{\star}\|_2^2$ for the pooled least squares estimator ($\times$) does not decrease and remains very large ($\approx 1.5$) as $n$ increases, indicating that it converges to a biased solution. At the same time, the estimation error for FAIR-GB ($\blacklozenge$) decays as $n$ grows ($\approx 0.01$ when $n=1k$) and lies between that for least squares on $X_{G^{c}}$ (Semi-Oracle $\blacktriangledown$) and least squares on $X_{S^{\star}}$ (Oracle $\blacktriangle$). This is expected since the FAIR-Linear estimator is not designed to screen out all the exogenous spurious variables; these can be further regularized using common variable selection techniques; see footnote 4. We also observe that the training dynamics of adversarial estimation are highly unstable: though the estimator can converge to an estimate around $\beta^{\star}$ when $n$ is very large, it fails to converge to $\beta^{\star}$ at a rate comparable to that of standard least squares.
The FAIR-RF ($+$) estimator then completes the last step towards attaining better accuracy in this regard: its performance is very close to that of the Oracle estimator when $n$ is very large ($n=5000$).

Figure 3: The simulation results for linear models with (a) $d=70$ and (b) $d=15$. Both figures depict how the median estimation errors (based on $50$ replications, shown in log scale) for different estimators (marked with different shapes and colors) change when $n$ varies in (a) $\{200,500,1000,2000,5000\}$ and (b) $\{100,200,500,800,1000\}$, respectively.

Comparison with Other Methods. We also compare our FAIR-Linear estimator with the cousin estimator EILLS ($\blacktriangleright$) in Fan et al. (2023) and other invariance learning estimators (dotted lines), including invariant causal prediction (ICP $\blacktriangledown$; Peters et al., 2016), invariant risk minimization (IRM $+$; Arjovsky et al., 2019), and anchor regression (Anchor $\bullet$; Rothenhäusler et al., 2021), in a similar but lower-dimensional setting with $d=15$, under which ICP and EILLS can be computed within affordable time. For the FAIR-Linear estimator, we report the performance of FAIR-RF ($\blacklozenge$) and the one with brute force search (FAIR-BF $\blacksquare$). The results are shown in Fig. 3 (b): we can see that the FAIR family estimators ($\blacktriangleright$ $\blacksquare$ $\blacklozenge$ with solid lines) are the only ones attaining consistent estimation among all the invariance learning methods; see a detailed discussion of the data generating process and results in Section C.2.1.

5.2.2 Finite Performance of FAIR-NN Estimator

Data Generating Process. We consider the following data generating process with $d=26$ and $|\mathcal{E}|=2$ in each trial:

$$\begin{aligned}
X_i^{(e)} &\leftarrow \begin{cases}\varepsilon_i^{(e)} & \qquad i\leq 5\\ f_{i,0}^{(e)}(Y^{(e)})+\varepsilon_i^{(e)} & \qquad 6\leq i\leq 9\\ \sum_{j\in\mathtt{pa}(i)\subseteq[8]}f_{i,j}^{(e)}(X_j^{(e)})+\varepsilon_i^{(e)} & \qquad 10\leq i\leq 26\end{cases}\\
Y^{(e)} &\leftarrow m_k^{\star}(X_1^{(e)},\ldots,X_5^{(e)})+\varepsilon_0,
\end{aligned}$$

where the regression function $m^{\star}$ is either $m^{\star}_1(x)=\sum_{j=1}^{5}m_{0,j}(x_j)$ with randomly chosen $m_{0,j}$ or a hierarchical composition model $m^{\star}_2(x)=x_1x_2^{3}+\log(1+e^{\tanh(x_3)}+e^{x_4})+\sin(x_5)$; see the detailed model and omitted implementation details in Section C.2.2. In the two environments, the cause-effect relationships are shared. The variable $Y$'s parent set is $\{1,2,3,4,5\}$, its children set is $\{6,7,8,9\}$, and it may have further descendants in $\{9,\ldots,26\}$.
The above data generating process can be regarded as one observational environment $e=0$ and one interventional environment $e=1$ in which random and simultaneous interventions are applied to all the variables other than $Y$, while the assignment from $Y$'s parents to $Y$ remains unchanged and furnishes the target regression function $m_k^{\star}(x)$ with $k\in\{1,2\}$ in pursuit. Fig. 4 (a) visualizes the induced graph in one trial.

Figure 4: The visualization of (a) the SCM and (b) $\mathrm{sig}(w)$ during training in one trial for the FAIR-NN estimator when $k=1$. We use different colors to represent the different relationships with $Y$: blue = parent, red = child, orange = offspring, lightblue = other.

Implementation. We let $\mathcal{G}$ be the class of ReLU neural networks with depth $2$ and width $128$ and $\mathcal{F}$ be the class of ReLU neural networks with depth $2$ and width $196$, and run gradient descent ascent using similar experimental configurations. We use the following empirical mean squared error computed using another $2\times n_{\mathrm{test}}=2\times 30000$ i.i.d. sampled data

$$\widehat{\mathtt{MSE}}=\frac{1}{2n_{\mathrm{test}}}\sum_{e\in\mathcal{E}}\sum_{i=1}^{n_{\mathrm{test}}}\{m^{\star}(x^{(e)}_{i})-\widehat{m}(x^{(e)}_{i})\}^{2}$$

as the evaluation metric. We report the median of $\widehat{\mathtt{MSE}}$ over $100$ replications for the estimators (1) to (4), defined analogously to those for the linear model. For (1), (2), and (4), we also use ReLU neural networks with depth $2$ and width $128$ when running least squares. Fig. 4 (b) also visualizes how the Gumbel gate values $\mathrm{sig}(w)$ for different covariates evolve during training in one trial. We can see that the training dynamics for $\mathrm{sig}(w)$ are much more challenging and interesting than those for the linear model depicted in Fig. 2: the weights for some of $Y$'s children quickly increase at a rate comparable to that of the variables in $S^{\star}$ at the beginning, but this trend slows down and finally completely reverses in the middle of training. We leave a rigorous and in-depth analysis of such dynamics for future studies.

Figure 5: The simulation results for nonlinear models with (a) $m_1^{\star}$ and (b) $m_2^{\star}$. Both figures depict how the median estimation errors (based on $50$ replications) for different estimators (marked with different shapes and colors) change when $n$ varies in $\{1000,2000,3000,5000\}$ for (a) and $\{1000,2000,3000,5000,10000\}$ for (b).

Results. The results are shown in Fig. 5, and the messages are similar to those for the FAIR-Linear estimators. Pooled least squares yields biased estimation, while our proposed FAIR-NN estimator can unveil the invariant association $m^{\star}$ from the two environments. Moreover, the refitted FAIR-NN estimator can obtain near-oracle performance when $n$ is large.

5.3 Application I: Discovery in Real Physical Systems

We apply our method to perform causal discovery in the light tunnel datasets from Gamella et al. (2024). The data are collected from a real physical device under different manipulation settings. The tunnel device contains a controllable light source at one end and two linear polarizers mounted on rotating frames. Several sensors are deployed in various positions to measure the light intensity. The causal relationships between the variables of interest are known, so we have access to the ground-truth cause-effect relationships; see Fig. 2(d) and Fig. 3(a) therein for the device diagram and the cause-effect graphs, respectively. It is worth noting that the data are collected from a real-world device in which the associations between the measurements follow from real-world physical laws. This realistic nature, together with the knowledge of the ground-truth cause-effect relationships, makes it an excellent testbed for causal discovery algorithms.

Using the notation in Gamella et al. (2024), we use the variables $(R, G, B, \theta_1, \theta_2, \widetilde{V}_3, \widetilde{V}_2, \widetilde{V}_1, \widetilde{I}_3, \widetilde{I}_2, \widetilde{I}_1, \widetilde{C})$. Here $(R, G, B)$ is the intensity of the light source at three different wavelengths, $\widetilde{C}$ is the drawn electric current, $(\theta_1, \theta_2)$ represent the angles of the polarizer frames, and $(\widetilde{V}_3, \widetilde{V}_2, \widetilde{V}_1, \widetilde{I}_3, \widetilde{I}_2, \widetilde{I}_1)$ are the measurements of light-intensity sensors in various positions.

We aim to learn algorithmically the direct causes of $Y = \widetilde{I}_3$, the infrared measurement of the light-intensity sensor after the polarizers, among a subset of manipulable variables and measurement variables $(X_1, \ldots, X_{11}) = (R, G, B, \theta_1, \theta_2, \widetilde{V}_3, \widetilde{V}_2, \widetilde{V}_1, \widetilde{I}_2, \widetilde{I}_1, \widetilde{C})$ under the following two-environment experimental setting: $e = 0$ is the observational environment, and $e = 1$ is the interventional environment where the variables $\{\widetilde{V}_j\}_{j=1}^{3}$ and $\{\widetilde{I}_j\}_{j=1}^{2}$ are weakly intervened on.
This leads to the “equivalent” ground-truth cause-effect relationship among those variables and the effect of “environment intervention” shown in Fig. 6 (a). In this case, the variables $(R, G, B, \theta_1, \theta_2)$ are the direct causes, i.e., $S^{\star} = \{1, 2, 3, 4, 5\}$, and $\widetilde{V}_3$ is the spurious variable that will lead to biased estimation. The remaining variables are exogenous but have marginal predictive power, i.e., $\mathrm{Var}[Y|X_j] > 0$ for $j \geq 7$.

We will use the following datasets in the experiment: the observational environment dataset $\mathcal{D}_0$ with size $|\mathcal{D}_0| = 10^4$, the weakly interventional environment dataset $\mathcal{D}_1$ with $|\mathcal{D}_1| = 3000$, and five strongly interventional environment datasets $\mathcal{D}_{2,Z}$ with $Z \in \{\widetilde{V}_1, \widetilde{V}_2, \widetilde{V}_3, \widetilde{I}_1, \widetilde{I}_2\}$ and $|\mathcal{D}_{2,Z}| = 1000$.
In each trial, different methods use the same random subsample $\breve{\mathcal{D}} = \{\breve{\mathcal{D}}_0, \breve{\mathcal{D}}_1\}$ with $\breve{\mathcal{D}}_k \subseteq \mathcal{D}_k$ and $|\breve{\mathcal{D}}_k| = n = 1000$ to fit the model. How the fitted model $\widehat{f}$ quantitatively depends on the exogenous/endogenous spurious variable $Z$ is evaluated using the OOS $R^2$ in the corresponding $\mathcal{D}_{2,Z}$, defined as

\[
R^{2}_{\mathrm{OOS},Z} := \frac{\sum_{(X,Y)\in\mathcal{D}_{2,Z}} \{\widehat{f}(X) - Y\}^{2}}{\sum_{(X,Y)\in\mathcal{D}_{2,Z}} \{Y - \bar{Y}\}^{2}} \qquad \text{with} \qquad \bar{Y} = \frac{\sum_{(X,Y)\in\breve{\mathcal{D}}_0 \cup \breve{\mathcal{D}}_1} Y}{2n}.
\]

See the detailed data collection and experimental configuration in Section C.3.
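As an illustration of the evaluation metric, the following minimal sketch computes the displayed ratio on synthetic data. The function names and the synthetic arrays are hypothetical and are not taken from the experimental code. Since the reported values are close to $0.91$ for the best estimators, the sketch also returns one minus the ratio, which is the conventional "larger is better" form; treating the reported quantity as this form is an assumption.

```python
import numpy as np

def oos_ratio(y_test, y_pred, y_train_mean):
    """Ratio as displayed in the text: squared prediction error on D_{2,Z}
    divided by the squared deviation of test responses from Y_bar."""
    y_test = np.asarray(y_test, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    num = np.sum((y_pred - y_test) ** 2)
    den = np.sum((y_test - y_train_mean) ** 2)
    return num / den

def oos_r2(y_test, y_pred, y_train_mean):
    """Conventional form, 1 - ratio (an assumption, see the lead-in)."""
    return 1.0 - oos_ratio(y_test, y_pred, y_train_mean)

# Synthetic stand-ins: 2n training responses and 1000 test responses.
rng = np.random.default_rng(0)
n = 1000
y_train = rng.normal(size=2 * n)
y_bar = y_train.sum() / (2 * n)          # Y_bar over the two training subsamples
y_test = rng.normal(size=1000)

perfect = oos_r2(y_test, y_test, y_bar)                  # exact predictions
baseline = oos_r2(y_test, np.full(1000, y_bar), y_bar)   # always predict Y_bar
```

Under this convention, exact predictions give a value of one and the constant predictor $\bar{Y}$ gives zero; note that $\bar{Y}$ is computed from the training subsamples, not from the test environment.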

Figure 6: Discovery in Real Physical Systems: (a) the unified cause-effect relationship and interventions, similar to Fig. 1 (b). (b) the average out-of-sample $R^2$ for different estimators using the spider chart: the axis annotated by placeholder variable $Z$ corresponds to the test environment where $Z$ is strongly intervened on. We can see the performance of Oracle-NN and FAIR-NN-RF is almost identical. (c) the average (based on 100 replications) of the worst-case (across 5 environments) OOS $R^2$ for different methods as a function of $n$. (d) the variable selection rate over $100$ trials for different methods (top panel) and the variable selection rate for FAIR-NN for various $n$ (bottom panel). We use different colors to represent different relationships with $Y$: blue=parent, red=child, orange=neither ancestor nor descendant.
(e) the distribution of worst-case OOS $R^2$ (y-axis) for Gumbel-trick optimized FAIR-NN (Gumbel), the follow-up refitted estimator (Refit), and Pooled LS (Pooled) when FAIR-NN selects the wrong variables: the subplots from top to bottom consider the cases of (i) failure in selection consistency, (ii) a false positive in which it falsely selects the child $X_8 = \widetilde{V}_3$, and (iii) a false negative in which it does not select the entire ground-truth set $(X_1, \ldots, X_5) = (R, G, B, \theta_1, \theta_2)$.

The first four rows in Fig. 6 (d) report the variable selection results for several methods over $100$ trials. The nonlinear ICP method (Heinze-Deml et al., 2018) does not select any variables because of its conservative nature and the stronger heterogeneity condition it requires to recover the direct causes. We can see that FAIR-NN successfully recovers the direct causes $(R, G, B, \theta_1, \theta_2)$ in this case. It exploits neural networks' capability to efficiently detect nonlinear associations (Malus's law, $\widetilde{I}_3 \propto \cos^2(\theta_1 - \theta_2)$ for fixed $(R, G, B)$), while the linear counterpart FAIR-Linear fails to select the variables $(\theta_1, \theta_2)$.
It is worth pointing out that such causality recovery cannot be attained by the traditional tradeoff between predictive power and simplicity: the variable selection method based on random forest variable importance measures (ForestVarSel) in Heinze-Deml et al. (2018) cannot detect $(G, B, \theta_1, \theta_2)$ and falsely selects $(\widetilde{I}_1, \widetilde{I}_2)$. The last three rows in Fig. 6 (d) illustrate how the variable selection rate for the FAIR-NN estimator changes as $n$ grows.
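To see why a linear model can miss $(\theta_1, \theta_2)$ under Malus's law, consider the following toy simulation. It assumes, purely for illustration, that the two angles are independent and symmetric around zero; the actual angle distribution in the light tunnel data may differ. Under this assumption the response $\cos^2(\theta_1 - \theta_2)$ is uncorrelated with each angle, so a linear fit explains almost nothing even though the response is a deterministic function of the angles.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5000
# Hypothetical angle distribution: independent, symmetric around zero.
theta1 = rng.uniform(-np.pi / 2, np.pi / 2, size=n)
theta2 = rng.uniform(-np.pi / 2, np.pi / 2, size=n)
y = np.cos(theta1 - theta2) ** 2   # Malus's law up to a constant factor

def r_squared(features, target):
    """In-sample R^2 of least squares with an intercept."""
    design = np.column_stack([np.ones(len(target)), features])
    coef, *_ = np.linalg.lstsq(design, target, rcond=None)
    resid = target - design @ coef
    return 1.0 - resid.var() / target.var()

# Linear fit on the raw angles versus a fit on the correct nonlinear feature.
r2_linear = r_squared(np.column_stack([theta1, theta2]), y)
r2_malus = r_squared(np.cos(theta1 - theta2) ** 2, y)
```

In this sketch `r2_linear` is close to zero while `r2_malus` is essentially one, which mirrors the contrast between a linear selector and a function class rich enough to represent the nonlinearity.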

Fig. 6 (b) offers a quantitative illustration by showing the out-of-sample (OOS) $R^2$ of different estimators under environments with strong interventions on $(\widetilde{I}_1, \widetilde{I}_2, \widetilde{V}_1, \widetilde{V}_2, \widetilde{V}_3)$, respectively. The estimator denoted as Oracle-$M$ with $M \in \{\mathrm{Linear}, \mathrm{NN}\}$ refers to the method that regresses $Y$ on $X_{S^{\star}}$ using model $M$. In the spider chart, the red shade represents the out-of-sample $R^2$ under different interventions for the Oracle-NN estimator that regresses $Y$ on its direct causes. We can see that it performs uniformly under various interventions: all the OOS $R^2$ values are approximately equal to $0.91$. This is slightly better than that of the linear model (Oracle-Linear), by 0.04. This illustrates the capability of the neural networks introduced to detect weak, nonlinear causal signals from heterogeneous environments.
The PoolLS-NN estimator, which regresses $Y$ on $X$ using a neural network and all the data, fully exploits the strong spurious association between $\widetilde{V}_3$ and $Y = \widetilde{I}_3$; its heavy reliance on $\widetilde{V}_3$ lets it predict better (than the causal model Oracle-NN) when $\widetilde{V}_3$ is not intervened on. However, its OOS $R^2$ decreases significantly, by $0.2$, when $\widetilde{V}_3$ is strongly intervened on and hence the spurious association changes. In contrast, the OOS $R^2$ for FAIR-NN after refitting (FAIR-NN-RF) is almost identical to that for Oracle-NN. This quantitative result illustrates its capability to correct non-trivial and strong bias without any supervision and its efficiency in detecting nonlinear and weak signals.

Fig. 6 (c) shows how the worst-case OOS $R^2$ among the five strong-intervention environments changes for different estimators as $n$ grows. The performance of the Gumbel-trick optimized FAIR-NN estimator without refitting (FAIR-NN-GB) lies between those of Oracle-NN and Oracle-Linear and is significantly better than that of the PoolLS-NN estimator. This suggests that the gradient descent optimized algorithm has already found predictions nearly independent of the spurious variable, and that the success of variable selection in Fig. 6 (d) is not due to truncating weak but non-negligible spurious signals. Moreover, as shown in Fig. 6 (e), when it selects the wrong variables, it significantly outperforms the least squares estimators using either the full covariates or the selected covariates. This further supports the theoretical claims and the advantages of adopting penalized least squares.
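The worst-case summary used in Fig. 6 (c) can be sketched as follows: take the minimum OOS $R^2$ across the five test environments within each trial, then average over trials. The array below contains made-up numbers chosen only to show the aggregation; it is not experimental output.

```python
import numpy as np

# Made-up OOS R^2 values for illustration only: rows are trials, columns are
# the five strongly interventional test environments.
r2 = np.array([
    [0.90, 0.80, 0.85, 0.95, 0.70],
    [0.60, 0.90, 0.90, 0.90, 0.90],
    [0.75, 0.85, 0.95, 0.80, 0.90],
])

worst_per_trial = r2.min(axis=1)          # worst case across environments
avg_worst_case = worst_per_trial.mean()   # average over trials
```

Taking the minimum before averaging penalizes an estimator that does well in most environments but fails in one, which is the behavior expected of a model that relies on a spurious variable.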

5.4 Application II: Prediction Based on Extracted Features

We consider an image object classification task with a spurious background. The target is to classify water birds ($Y = 1$) and land birds ($Y = 0$) (see examples in Fig. 7 (a)) under backgrounds of water or land based on the feature $X \in \mathbb{R}^{500}$ extracted from a ResNet pre-trained on ImageNet. We train a linear classifier on top of $X$ using data from two environments $\{\mathcal{D}_k\}_{k=1}^{2}$. In the first environment $\mathcal{D}_1$, $r_w = 95\%$ of water birds appear on the water background and $r_l = 90\%$ of land birds appear on the land background. The spurious correlation rates are $r_w = 75\%$ and $r_l = 70\%$ in $\mathcal{D}_2$. A good predictor should be based on the core features related to the bird's appearance rather than the strong spurious correlation between the background and the label. The trained model is evaluated in a test environment $\mathcal{D}_3$ where the spurious correlation reverses: $r_w = 2\%$ and $r_l = 2\%$.
We repeat the experiment $10$ times, where in each trial the training datasets and the test dataset are sampled from a larger dataset with sizes $n = |\mathcal{D}_1| = |\mathcal{D}_2| = 50k$ and $|\mathcal{D}_3| = 30k$. We compare our FAIR-regularized estimator using $\mathcal{G} = \{\mathrm{sig}(\beta^{\top} x)\}$, $\mathcal{F} = \{\beta^{\top} x\}$ and classification loss $\ell(y, v) = -\log(1 - v) - y \log\{v / (1 - v)\}$ (FAIR-GB) with invariant risk minimization (IRM) (Arjovsky et al., 2019) and group distributionally robust optimization (GroupDRO) (Sagawa et al., 2020).
We also consider running Lasso on different environments for reference, including (1) using all the data $\mathcal{D}_1 \cup \mathcal{D}_2$ (Pooled Lasso); (2) using data in $\mathcal{D}_2$ (Lasso on D2); (3) using another randomized controlled environment $\mathcal{D}_4$ with $r_w = r_l = 50\%$ and $|\mathcal{D}_4| = n$ (Oracle). All the models are linear, and the performance of (3) can be seen as an upper bound on the performance of linear models; see data collection and experimental configuration details in Section C.4.
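The classification loss above is the binary cross-entropy written in a different form, since $-\log(1 - v) - y\log\{v/(1 - v)\} = -y\log v - (1 - y)\log(1 - v)$. With $v = \mathrm{sig}(t)$ it further reduces to $\log(1 + e^{t}) - yt$. The short numerical check below confirms both identities on a grid of scores; the function names are illustrative and not taken from the authors' implementation.

```python
import numpy as np

def sig(t):
    """Logistic sigmoid, the link in G = {sig(beta^T x)}."""
    return 1.0 / (1.0 + np.exp(-t))

def fair_loss(y, v):
    """Classification loss as written in the text."""
    return -np.log(1.0 - v) - y * np.log(v / (1.0 - v))

def cross_entropy(y, v):
    """Standard binary cross-entropy."""
    return -y * np.log(v) - (1.0 - y) * np.log(1.0 - v)

scores = np.linspace(-4.0, 4.0, 9)   # values of beta^T x
v = sig(scores)

# Largest discrepancy against the cross-entropy form, over both labels.
gap_ce = max(
    float(np.max(np.abs(fair_loss(y, v) - cross_entropy(y, v))))
    for y in (0.0, 1.0)
)
# Largest discrepancy against the logit form log(1 + e^t) - y t.
gap_logit = max(
    float(np.max(np.abs(fair_loss(y, v) - (np.log1p(np.exp(scores)) - y * scores))))
    for y in (0.0, 1.0)
)
```

Both gaps are at the level of floating-point rounding, so the three expressions can be used interchangeably.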

The performances are reported in Fig. 7 (b). Fig. 7 (c) also depicts how the test accuracy changes over iterations in one trial. FAIR-GB performs similarly to Oracle and significantly outperforms Lasso on D2, while the other methods (IRM, GroupDRO) fall behind Lasso on D2. This indicates that these methods cannot go beyond interpolating the spurious associations in $\mathcal{D}_1$ and $\mathcal{D}_2$, while our method can nearly eliminate the spurious association using the relatively small perturbations in the two environments.
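To make the role of the spurious correlation rates concrete, the sketch below computes the accuracy of a hypothetical classifier that looks only at the background, predicting "water bird" exactly when the background is water. The class prior `p_water` is an assumption, since the class balance is not stated here; in the test environment the result does not depend on it because $r_w = r_l$.

```python
def background_only_accuracy(r_w, r_l, p_water=0.5):
    """Accuracy of the rule 'predict water bird iff the background is water'.

    r_w: fraction of water birds that appear on a water background.
    r_l: fraction of land birds that appear on a land background.
    p_water: class prior of water birds (hypothetical, not given in the text).
    """
    return p_water * r_w + (1.0 - p_water) * r_l

acc_d1 = background_only_accuracy(0.95, 0.90)   # training environment D_1
acc_d2 = background_only_accuracy(0.75, 0.70)   # training environment D_2
acc_d3 = background_only_accuracy(0.02, 0.02)   # test environment D_3
```

Such a rule looks accurate in both training environments but is almost always wrong once the correlation reverses, which is why test accuracy in $\mathcal{D}_3$ is a demanding measure of reliance on the background.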

Method          Test Accuracy
Oracle          91.06 ± 0.24 %
Lasso on D2     84.57 ± 0.71 %
Pooled Lasso    79.08 ± 0.54 %
IRM             80.32 ± 0.67 %
GroupDRO        82.75 ± 1.10 %
FAIR-GB         89.56 ± 0.53 %
Figure 7: Prediction Based on Extracted Features: (a) provides two sample images in the dataset: land bird on land (top) and water bird in water (bottom). (b) reports the average $\pm$ standard deviation of test accuracy over 10 trials for different estimators. (c) shows how the test accuracy changes over iterations for different methods in one trial.

Acknowledgement

We thank Yiran Jia for helpful discussions on presenting a generic identification result on SCM using the unified graph including $E$, Yimu Zhang for the help with the numerical implementation in Section 5.4, and Xinwei Shen for suggestions of using Gumbel approximation in implementation.

Appendix

The appendix is organized as follows:

  • Appendix A contains the omitted discussions in the main text, including the applicable scenarios for the nonparametric invariance pursuit, some discussions and extensions on the method, and some discussions on the conditions in Section 2 and Section 3.

  • Appendix B contains the complete result that is sketched in Section 4.3.

  • Appendix C contains omitted discussions and results in the experiments section.

Appendix A Omitted Discussions and Results

A.1 Applicable Scenarios for Nonparametric Invariance Pursuit

This section is devoted to providing a self-contained introduction to the motivation behind the nonparametric invariance pursuit using statements akin to previous literature (Peters et al., 2016; Rojas-Carulla et al., 2018; Fan et al., 2023).

Causal Discovery.

If we can expect $\mathcal{E}$ to be heterogeneous enough, recovering $S^{\star}$ in nonparametric invariance pursuit coincides with discovering the direct causes of $Y$ when the multi-environment data come from an SCM with interventions on $X$.

Proposition 3.

Under the model (3.2), if we further assume that $\mathbb{E}[|Y^{(e)}|^2] < \infty$ for any $e \in \mathcal{E}$, then (1.1) holds with $S^{\star} = \mathtt{pa}(d+1)$.

The SCM (3.2) and Proposition 3 extend the framework described in Peters et al. (2016) (specifically Section 4.1 and Proposition 1). This model accommodates nonlinear structural assignments. Critically, the residuals $\varepsilon^{(e)} = Y^{(e)} - \mathbb{E}[Y^{(e)} | X_{S^{\star}}^{(e)}]$ do not need to be independent of $X_{S^{\star}}^{(e)}$ or remain invariant across various environments as represented by $\varepsilon^{(e)} \sim \mu_{\varepsilon}$. Such flexibility broadens the scope for various applications, including binary classification. According to Proposition 3, when restricted to model (3.2), a specific instantiation of our generic statistical model (1.1), identifying the true important variable set $S^{\star}$ is tantamount to pinpointing the direct cause of the target variable $Y$. Concurrently, unveiling the invariant association $m^{\star}$ aligns with uncovering the causal mechanism between $Y$ and its direct causes.
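The invariance described above can be illustrated with a toy linear SCM $X_1 \to Y \to X_2$ in two environments. This is a hypothetical special case chosen for simplicity and is not model (3.2) itself; all coefficients and scales are made up. The environment changes the scale of the parent $X_1$ and the mechanism of the child $X_2$, but not the mechanism of $Y$. Regressing $Y$ on its parent gives the same coefficient in both environments, while adding the child gives coefficients that change with the environment.

```python
import numpy as np

def simulate(n, x_scale, child_gain, seed):
    """Toy SCM X1 -> Y -> X2. The environment changes the scale of X1 and
    the mechanism of the child X2, but not the mechanism of Y."""
    rng = np.random.default_rng(seed)
    x1 = rng.normal(scale=x_scale, size=n)
    y = 2.0 * x1 + rng.normal(size=n)          # invariant mechanism of Y
    x2 = child_gain * y + rng.normal(size=n)   # environment-dependent child
    return x1, x2, y

def ols(features, target):
    """Least squares coefficients without intercept (all variables are centered)."""
    coef, *_ = np.linalg.lstsq(np.column_stack(features), target, rcond=None)
    return coef

envs = [simulate(20000, 1.0, 1.0, seed=0), simulate(20000, 2.0, 0.2, seed=1)]

# Regress Y on its parent only: the coefficient is close to 2 in both environments.
parent_only = [ols([x1], y)[0] for x1, x2, y in envs]
# Regress Y on parent and child: the coefficients differ across environments.
with_child = [ols([x1, x2], y) for x1, x2, y in envs]
```

In this toy setting the population coefficients of the regression on $(X_1, X_2)$ are $(1, 0.5)$ in the first environment and roughly $(1.92, 0.19)$ in the second, so including the child breaks invariance even though it improves in-sample fit.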

Transfer Learning. Suppose we collect data $\{(X_i^{(e)}, Y_i^{(e)})\}_{e \in \mathcal{E}, i \in [n]}$ from $|\mathcal{E}|$ distinct sources and aim to develop a model that produces decent predictions on the data $\{X_i^{(t)}\}_{i \in [n_t]}$ in an unseen environment $t$. A significant portion of transfer learning algorithms fundamentally relies on the covariate shift assumption, represented as

\[
\mathbb{E}[Y^{(t)} | X^{(t)}] \equiv \mathbb{E}[Y^{(e)} | X^{(e)}] \qquad \forall e \in \mathcal{E}.
\]

However, as illustrated in Fan et al. (2023) and Rojas-Carulla et al. (2018), this assumption is unlikely to hold when many variables are collected. Therefore, a more realistic assumption is that information from the true important variables is transferable, articulated as $\mathbb{E}[Y^{(t)} | X^{(t)}_{S^{\star}}] = \mathbb{E}[Y^{(e)} | X^{(e)}_{S^{\star}}]$. The subsequent proposition suggests that though $m^{\star}$ might not be the optimal predictor in the unseen environment $t$, it does minimize the worst-case $L_2$ risk, and the associated excess risk can be decomposed as follows.

We suppose that both the distributions $\mu^{(e)}$ we observe in $\mathcal{E}$ and the future distribution $\nu$ come from the following distribution family.

\[
\mathcal{U}_{S^{\star}, m^{\star}, \sigma^2} = \Big\{ \mu : \mathbb{E}_{\mu}[Y^2] < \infty,\ \mathbb{E}_{\mu}[Y | X_{S^{\star}}] = m^{\star}(X_{S^{\star}}),\ \mathbb{E}_{\mu}[\mathrm{Var}_{\mu}(Y | X_{S^{\star}})] \lor \max_{1 \leq j \leq d} \mathbb{E}_{\mu}[X_j^2] \leq \sigma^2 \Big\}.
\]
Proposition 4.

Let $\nu\in\mathcal{U}_{S^{\star},m^{\star},\sigma^{2}}$ be arbitrary. Define $\mathsf{R}_{\mathtt{oos}}(m;\nu_{x})=\sup_{\mu\in\mathcal{U}_{S^{\star},m^{\star},\sigma^{2}},\,\mu_{x}\sim\nu_{x}}\mathbb{E}_{(X,Y)\sim\mu}[|Y-m(X)|^{2}]$ and $\Theta^{(t)}=L_{2}(\nu_{x})$. We have

$$\forall m\in\Theta^{(t)},\qquad \mathsf{R}_{\mathtt{oos}}(m;\nu_{x})-\mathsf{R}_{\mathtt{oos}}(m^{\star};\nu_{x})=\|m-m^{\star}\|_{L_{2}(\nu_{x})}^{2}+2\sigma\|m-\widetilde{m}\|_{L_{2}(\nu_{x})},$$

where $\widetilde{m}(x)=\mathbb{E}_{X\sim\nu_{x}}[m(X)|X_{S^{\star}}=x_{S^{\star}}]$. The term $2\sigma\|m-\widetilde{m}\|_{L_{2}(\nu_{x})}$ is zero when $m\in\Theta_{S^{\star}}^{(t)}$.

Given the framework described above, our proposed method for solving the problem in Section 1.1 can be combined with the re-weighting technique (Gretton et al., 2009), a strategy that addresses discrepancies in the marginal distribution of $X$, to yield reliable predictions in the previously unobserved environment $t$.

A.2 Discussion on the Methods

We provide a discussion in a question-and-response manner.


[Q] The “focused regularizer” is combinatorial in nature from a computational standpoint. Can it be removed?

Answer: The short answer is No. The regularizer would have the same effect as running least squares if we did not force the discriminator to use the same variables as the predictor. This is also the main computational difficulty in our framework and the reason we use randomness relaxation and a Gumbel approximation in the implementation. Indeed, even for linear invariance pursuit, there are fundamental computational limits: no polynomial-time algorithm can attain consistent estimation in pursuing invariance without relying on structures beyond invariance.


[Q] The method has a similar form to IRM. What is the major difference?

Answer: The main difference is that we should at least let $\Theta_{f}\supseteq\Theta_{g}$. Such a constraint leverages the idea of over-identification and makes identification possible even when $|\mathcal{E}|=2$, provided there is enough heterogeneity. Suppose our regularizer, which can be seen as a “correct” way to pursue conditional expectation invariance, enforces $u^{(1)}=u^{(2)}$ for two $s$-dimensional parameter vectors $u^{(1)},u^{(2)}\in\mathbb{R}^{s}$. What IRM does is instead enforce $\sum_{i=1}^{s}u^{(1)}_{i}=\sum_{i=1}^{s}u^{(2)}_{i}$. It is hard to argue that the latter constraint is meaningful or achieves an effect similar to the former.
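As a toy illustration of the gap between the two constraints described above, the following sketch compares a sum-type constraint with full coordinate-wise equality. The vectors `u1` and `u2` are hypothetical values chosen only for illustration; they are not taken from IRM or from our method.

```python
import numpy as np

# Two hypothetical s-dimensional parameter vectors (illustrative values only).
u1 = np.array([1.0, 2.0, 3.0])
u2 = np.array([3.0, 2.0, 1.0])

# Sum-type constraint: only the coordinate totals must agree.
sums_match = bool(np.isclose(u1.sum(), u2.sum()))

# Full constraint: every coordinate must agree.
vectors_match = bool(np.allclose(u1, u2))

print(sums_match, vectors_match)
```

Here the sum-type constraint is satisfied although the two vectors differ in two of their three coordinates, so it leaves $s-1$ directions unconstrained.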


[Q] Could your proposed framework be extended to representation-level invariance, as in IRM?

Answer: The short answer is Yes, given its algorithmic nature. However, identification with two, or any constant number of, environments is then impossible: a number of environments linear in the dimension is required even for linear representation learning. For example, one can seek a linear representation $\Phi:\mathbb{R}^{d}\to\mathbb{R}^{r}$ such that

$$\mathbb{E}[Y^{(e)}|\Phi X^{(e)}]\equiv m^{\star}(\Phi X^{(e)}).$$

However, $|\mathcal{E}|\geq r$ is a necessary condition for identification even when the heterogeneity is sufficient and $r$ is known in advance. We conjecture that identification may be impossible with any finite number of environments $|\mathcal{E}|<\infty$ if $\Phi$ lies in some nonparametric function class.

A.3 Extensions to General Environment Variable and Loss Function

In the main text, we propose an estimation framework that leverages conditional expectation invariance with respect to discrete environment variables. Our adversarial estimation framework is in fact more versatile: one can easily extend it to other forms of conditional point-prediction invariance with respect to more general environment covariates. We briefly discuss the direct extension here and leave a rigorous treatment for future work. In the following, suppose we observe data $\{(X_{i},Y_{i},E_{i})\}_{i=1}^{n}$ drawn i.i.d. from some distribution $\mu_{0}$, where $X\in\mathbb{R}^{d}$ is the covariate used for prediction, $Y\in\mathbb{R}$ is the target response, and $E\in\mathbb{R}^{q}$ is the environment covariate with respect to which we wish our prediction to be invariant.

Let $\ell(u,y):\mathbb{R}\times\mathbb{R}\to\mathbb{R}$ be a user-defined loss whose population-level minimizer is not necessarily the conditional expectation but which satisfies certain regularity conditions. Let $\ell_{u}(u,y)=\partial\ell(u,y)/\partial u$ be the partial sub-gradient with respect to the prediction. Suppose the following general invariance structure with respect to $\ell$ and the environment covariate holds: there exist $S^{\star}\subseteq[d]$ and a function $g^{\star}$ that depends only on $x_{S^{\star}}$ such that

$$\mathbb{E}\left[\ell_{u}(g^{\star}(X_{S^{\star}}),Y)\,|\,X_{S^{\star}},E\right]\equiv 0. \tag{A.1}$$

It coincides with the main problem of study when $E$ is discrete and $\ell$ satisfies (4.4), but also allows for other losses and continuous environment labels. Other losses include, but are not limited to, the Huber loss for robust regression and the $L_{1}$ loss for median regression.
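To make the role of $\ell_{u}$ concrete, the following minimal sketch writes out the partial (sub-)gradient for three losses. The function names and the Huber threshold `delta` are assumptions introduced here for illustration, and the squared loss is scaled by $1/2$ for convenience.

```python
import numpy as np

def squared_grad(u, y):
    """l_u for the squared loss l(u, y) = 0.5 * (u - y)^2."""
    return u - y

def huber_grad(u, y, delta=1.0):
    """l_u for the Huber loss; delta is an assumed tuning constant."""
    # Residuals inside [-delta, delta] keep their value, the rest are clipped.
    return np.clip(u - y, -delta, delta)

def l1_grad(u, y):
    """A sub-gradient of l(u, y) = |u - y|; np.sign returns 0 at the kink."""
    return np.sign(u - y)

# Hypothetical predictions u and responses y.
u = np.array([0.0, 3.0])
y = np.array([0.5, 0.0])
print(squared_grad(u, y), huber_grad(u, y), l1_grad(u, y))
```

Any of these functions could play the role of $\ell_{u}$ in (A.1); whether the regularity conditions mentioned above hold for a given loss has to be checked separately.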

We consider the following minimax objective, a game between a predictor $g:\mathbb{R}^{d}\to\mathbb{R}$ and a discriminator $f:\mathbb{R}^{d}\times\mathbb{R}^{q}\to\mathbb{R}$:

$$\min_{g\in\mathcal{G}}\max_{f\in\mathcal{F}_{S_{g}}}\underbrace{\frac{1}{n}\sum_{i=1}^{n}\ell(g(X_{i}),Y_{i})}_{\widehat{\mathsf{R}}(g)}+\gamma\underbrace{\frac{1}{n}\sum_{i=1}^{n}\left[\ell_{u}(g(X_{i}),Y_{i})f(X_{i},E_{i})-0.5\{f(X_{i},E_{i})\}^{2}\right]}_{\widehat{\mathsf{J}}(g,f)}, \tag{A.2}$$

where $\gamma$ is a hyper-parameter to be determined, and $\mathcal{F}_{S_{g}}=\{f\in\mathcal{F}: f(x,e)=w(x_{S_{g}},e)\text{ for some }w\}$. Similar to the calculation in Section 1.2, one can expect that minimizing the population counterpart of the focused adversarial invariance regularizer $\max_{f\in\mathcal{F}_{S_{g}}}\widehat{\mathsf{J}}(g,f)$ has an effect similar to imposing (A.1). One can derive non-asymptotic identification and estimation error results akin to Theorem 4 and Theorem 5, provided the loss $\ell(u,y)$ is strongly convex and satisfies a certain Lipschitz property. We leave this for future studies.
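The empirical objective in (A.2) can be evaluated for fixed $g$ and $f$ as in the following minimal sketch. It assumes the squared loss $\ell(u,y)=0.5(u-y)^{2}$, a constant predictor (so that $S_{g}=\emptyset$ and the discriminator may depend on $e$ only), and a discrete environment label; the function name `fair_objective` and the toy data are hypothetical. This is not the implementation used in the paper, which relies on neural networks and a Gumbel approximation.

```python
import numpy as np

def sq_loss(u, y):
    return 0.5 * (u - y) ** 2

def sq_loss_grad(u, y):
    return u - y

def fair_objective(g_vals, f_vals, y, gamma, loss=sq_loss, loss_grad=sq_loss_grad):
    """Evaluate R_hat(g) + gamma * J_hat(g, f) from predictions and discriminator values."""
    risk = np.mean(loss(g_vals, y))
    penalty = np.mean(loss_grad(g_vals, y) * f_vals - 0.5 * f_vals ** 2)
    return risk + gamma * penalty, risk, penalty

# Toy data: two environments whose responses have different means.
y = np.array([1.0, 1.0, 3.0, 3.0])
e = np.array([0, 0, 1, 1])

# Constant predictor g(x) = 2, which uses no covariates.
g_vals = np.full_like(y, 2.0)

# With S_g empty, f may depend on e only; the inner maximizer is then the
# per-environment average of l_u(g(X), Y).
grad = sq_loss_grad(g_vals, y)
f_vals = np.array([grad[e == k].mean() for k in e])

total, risk, penalty = fair_objective(g_vals, f_vals, y, gamma=2.0)
print(total, risk, penalty)
```

In this toy case the penalty is strictly positive because the per-environment means of the residual differ, which is the kind of violation of (A.1) the discriminator is meant to detect; with `f_vals` set to zero the penalty vanishes.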

A.4 Discussion on Relaxing Nonparametric Invariance Pursuit Identification Condition

Our FAIR criterion searches for the most predictive variable set whose conditional expectation remains the same across different environments. That is, when $\gamma\to\infty$ and $n=\infty$, our population-level objective is equivalent to the following program:

$$\min_{g\in\Theta}\frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}}\mathbb{E}\left[|Y^{(e)}-g(X^{(e)})|^{2}\right]\quad\text{s.t.}\quad g(x)\equiv m^{(e,S_{g})}(x)\quad\forall e\in\mathcal{E}.$$

We say a set $S$ is an invariant set if $m^{(e,S)}\equiv\bar{m}^{(S)}$ for any $e\in\mathcal{E}$. Therefore, one can slightly relax the identification condition as follows: $S^{\star}$ is the most predictive invariant set, that is,

$$\forall S\subseteq[d],\qquad\text{if}~~m^{(e,S)}\equiv\bar{m}^{(S)},\qquad\text{then either }\|\bar{m}^{(S)}\|_{2}<\|\bar{m}^{(S^{\star})}\|_{2}~\text{or}~\bar{m}^{(S)}=m^{\star}. \tag{A.3}$$

The above condition is weaker than Condition 2 because Condition 2 essentially requires that the set $S^{\star}$ be the maximum invariant set,

$$\forall S\subseteq[d],\qquad\text{if}~~m^{(e,S)}\equiv\bar{m}^{(S)},\qquad\text{then }\bar{m}^{(S\cup S^{\star})}=\bar{m}^{(S^{\star})}. \tag{A.4}$$

Here (A.4) just rewrites Condition 2 in a manner similar to (A.3).

It is easy to derive results similar to Theorem 1 under (A.3) rather than Condition 2. We can construct cases where (A.3) holds but Condition 2 does not. One such case is Example 1 below with $s^{(1)}s^{(2)}=1$, under which both $\{1\}$ and $\{2\}$ are invariant sets but the set $\{1,2\}$ is not. In this case, Condition 2 no longer holds. However, our algorithm can still consistently estimate $m^{\star}$ provided (A.3) holds, that is, provided the variable $X_{1}$ has better predictive power. In the main text, we still adopt Condition 2 instead of (A.3), for the following two reasons. All our discussions are in the setting of an SCM with interventions.

First, as further shown in Section 3, the cases where Condition 2 fails to hold are degenerate. When the interventions are nondegenerate, there always exists a maximum invariant set $S^{\star}$, i.e., Condition 2 holds. This means that (A.3) is only “marginally” weaker than Condition 2.

The second reason is that $S^{\star}$ lacks semantic meaning in such cases. When the interventions are nondegenerate, $S^{\star}$ can be interpreted as the “contemporary/pragmatic direct causes”, which can be expressed as direct causes + unaffected children + parents of unaffected children in Proposition 4; such a variable set also has certain robust transfer learning properties, as stated in Proposition 2. Both interpretations remain valid even if the interventions are insufficient. However, when the interventions are degenerate, so that Condition 2 may fail while (A.3) holds, e.g., in Example 1, neither property holds any longer. If the true causal mechanism is $X_{1}\to Y\to X_{2}$, then it is possible to construct a data generating process such that $S^{\star}$ in (A.3) can be either $\{1\}$ or $\{2\}$.

A.5 Discussion on the Nondegenerate Intervention Condition

Conditions (a) and (b) in Condition 5 are imposed to eliminate some degenerate cases. To illustrate why these two conditions are needed, and why they hold in general, we consider the following two examples.

Introduction of condition (a)

From a high-level viewpoint, condition (a) is introduced to eliminate the cases where, although the conditional distributions shift across environments, the conditional expectations happen not to shift. This can be illustrated by the following example.

Example 1.

Consider the following canonical model, also presented in Example 4.1 of Fan et al. (2023):

$$\begin{aligned} X_{1}^{(e)} &\leftarrow \sqrt{0.5}\,U_{1}\\ Y^{(e)} &\leftarrow X_{1}^{(e)}+\sqrt{0.5}\,U_{3}\\ X_{2}^{(e)} &\leftarrow s^{(e)}Y^{(e)}+U_{2}, \end{aligned}$$

where $U_{1},U_{2},U_{3}$ are independent standard normal variables, and $\mathcal{E}=\{1,2\}$. We let $e=1$ be the observational environment and $e=2$ be the interventional environment, in which the linear effect of $Y$ on $X_{2}$ is intervened on ($s^{(1)}\neq s^{(2)}$). We also focus on the regime where $s^{(1)}+s^{(2)}\neq 0$, so that running least squares leads to a biased solution.

In the above model, we can see that

$$Y^{(e)}\,|\,X_{2}^{(e)}\sim\mathcal{N}\left(\frac{s^{(e)}}{(s^{(e)})^{2}+1}X_{2}^{(e)},\ \frac{1}{(s^{(e)})^{2}+1}\right).$$

It is easy to check that, in the case of a non-degenerate child ($s^{(1)}+s^{(2)}\neq 0$) and faithfulness on $\widetilde{M}$ ($s^{(1)}\neq s^{(2)}$), we have

$$Y^{(1)}\,|\,X_{2}^{(1)}\overset{d}{\neq}Y^{(2)}\,|\,X_{2}^{(2)},$$

or in other words, $Y\not\!\perp\!\!\!\perp E\,|\,X_{2}$. However, when $s^{(1)}=1/s^{(2)}=s$, the following holds:

$$\mathbb{E}[Y^{(1)}|X_{2}^{(1)}=x]=\frac{s^{(1)}}{(s^{(1)})^{2}+1}x=\frac{s}{s^{2}+1}x=\frac{s^{(2)}}{(s^{(2)})^{2}+1}x=\mathbb{E}[Y^{(2)}|X_{2}^{(2)}=x].$$

Condition 5 (a) is introduced to rule out the cases where $s^{(1)}=1/s^{(2)}=s$. It is easy to see that when $s^{(1)}$ and $s^{(2)}$ are independently generated from some prior distribution that is absolutely continuous with respect to the Lebesgue measure on $\mathbb{R}$, i.e., $S^{(1)},S^{(2)}\sim p_{s}$, then

$$\mathbb{P}\left[S^{(1)}S^{(2)}=1\right]=0.$$
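The degenerate case $s^{(1)}s^{(2)}=1$ in Example 1 can be checked numerically. The following Monte Carlo sketch simulates the model for the assumed values $s^{(1)}=2$ and $s^{(2)}=1/2$ and regresses $Y$ on $X_{2}$ in each environment; the function names, sample size, and seed are choices made here for illustration.

```python
import numpy as np

def simulate_env(s, n, rng):
    """Draw n samples of (X1, X2, Y) from the model in Example 1 with effect s."""
    u1, u2, u3 = rng.standard_normal((3, n))
    x1 = np.sqrt(0.5) * u1
    y = x1 + np.sqrt(0.5) * u3
    x2 = s * y + u2
    return x1, x2, y

def slope_and_resid_var(x, y):
    """Least-squares slope of y on x (with intercept) and the residual variance."""
    slope = np.cov(x, y)[0, 1] / np.var(x, ddof=1)
    resid = y - y.mean() - slope * (x - x.mean())
    return slope, resid.var(ddof=1)

rng = np.random.default_rng(0)
s1 = 2.0
s2 = 1.0 / s1  # degenerate choice: s1 * s2 = 1 while s1 != s2

results = {}
for name, s in [("e=1", s1), ("e=2", s2)]:
    _, x2, y = simulate_env(s, 200_000, rng)
    results[name] = slope_and_resid_var(x2, y)
    theory = (s / (s**2 + 1), 1 / (s**2 + 1))
    print(name, "estimated:", results[name], "theory:", theory)
```

Under these assumed values the theoretical slope $s/(s^{2}+1)$ equals $0.4$ in both environments, while the conditional variance $1/(s^{2}+1)$ is $0.2$ in one and $0.8$ in the other: the conditional distributions differ although the conditional expectations coincide.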
Introduction of condition (b).

Condition (b), the faithfulness condition on $\widetilde{M}$, eliminates the cases where, although interventions are applied, they happen to have no impact on the intervened variables. The following example illustrates such a case.

Example 2.

Consider the case where $\mathcal{E}=\{1,2\}$, and the data generating process is as follows:

$$\begin{aligned} Y^{(e)} &\leftarrow U_{3}\\ X^{(e)}_{1} &\leftarrow Y^{(e)}+e+U_{1}\\ X_{2}^{(e)} &\leftarrow 0.5Y^{(e)}-sX^{(e)}_{1}+e+U_{2}, \end{aligned}$$

where $U_{1},U_{2},U_{3}$ are independent standard normal variables and $s\neq 0.5$ is a fixed parameter. We let $e=1$ be the observational environment and $e=2$ be the interventional environment, in which mean shifts are applied to the variables $X_{1}$ and $X_{2}$.

In the above case, we have $S^{\star}=\mathtt{pa}(3)=\emptyset$, and there exists an effective simultaneous intervention on $(X_1,X_2)$. However, such an intervention fails to affect $X_2$ if and only if $s=1$, because its direct effect on $X_2$ and its indirect effect passing through $X_1$ cancel exactly when $s=1$. To be specific, $X_2^{(e)}$ can be written as

$$
X_2^{(e)} = 0.5\,Y^{(e)} - s\,(Y^{(e)} + e + U_1) + e + U_2 = (0.5-s)\,Y^{(e)} - sU_1 + U_2 + e(1-s).
$$

This implies that

$$
Y \perp\!\!\!\perp E \mid X_2
$$

provided $s=1$, under which the faithfulness on $\widetilde{M}$ fails to hold because we have $Y \not\perp\!\!\!\perp_{\widetilde{G}} E \mid X_2$ since the path $Y\to X_2\leftarrow E$ is not blocked by $X_2$. However, if the parameter $s$ is itself generated from some prior distribution that is absolutely continuous with respect to the Lebesgue measure on $\mathbb{R}$, i.e., $S\sim p_s$, then

$$
\mathbb{P}\left[S=1\right]=0.
$$
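
The cancellation in Example 2 can be checked numerically. The following is a minimal simulation sketch and not part of the original analysis: the function names `simulate_x2_mean` and `mean_shift`, the sample size, and the seeds are illustrative assumptions. It draws from the structural equations of the example and compares the mean of $X_2$ between $e=1$ and $e=2$, which by the decomposition of $X_2^{(e)}$ should differ by $1-s$.

```python
import random


def simulate_x2_mean(s, e, n=20000, seed=0):
    """Monte Carlo mean of X_2^{(e)} under the SCM of Example 2.

    s    : the fixed structural parameter in the equation for X_2
    e    : environment label (1 = observational, 2 = interventional)
    n    : number of simulated samples (illustrative choice)
    seed : seed of the private random generator (illustrative choice)
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        u1, u2, u3 = rng.gauss(0, 1), rng.gauss(0, 1), rng.gauss(0, 1)
        y = u3                           # Y   <- U_3
        x1 = y + e + u1                  # X_1 <- Y + e + U_1
        x2 = 0.5 * y - s * x1 + e + u2   # X_2 <- 0.5 Y - s X_1 + e + U_2
        total += x2
    return total / n


def mean_shift(s, n=20000, seed=0):
    """Estimated E[X_2^{(2)}] - E[X_2^{(1)}]; the population value is 1 - s."""
    # Independent seeds are used so the two environments do not share noise draws.
    return simulate_x2_mean(s, 2, n, seed) - simulate_x2_mean(s, 1, n, seed + 1)
```

Under these assumptions, `mean_shift(1.0)` should be close to zero, reflecting that the mean shift leaves $X_2$ unchanged when $s=1$, while `mean_shift(0.0)` should be close to one.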

A.6 The Complete Statement of Proposition 2

Specifically, we construct a unified SCM $(X,Y,E)\sim\bar{M}(\bar{\mathcal{S}},\nu)$ based on $M^{(0)}$ and the new environment $M^{(t)}$ as follows:

$$
\begin{aligned}
E &\leftarrow \text{Uniform}(\{0,t\}) \\
X_j &\leftarrow \begin{cases}
\bar{f}_j(X_{\mathtt{pa}(j)},U_j) := f^{(0)}_j(X_{\mathtt{pa}(j)},U_j) & \qquad \forall j\in[d]\setminus I \\
\bar{f}_j(X_{\mathtt{pa}(j)},E,U_j) := f^{(t)}_j(X_{\mathtt{pa}(j)},U_j) & \qquad \forall j\in I
\end{cases} \\
Y &\leftarrow \bar{f}_{d+1}(X_{\mathtt{pa}(d+1)},U_{d+1}) := f_{d+1}(X_{\mathtt{pa}(d+1)},U_{d+1}).
\end{aligned}
$$

We suppose that the following condition, similar to Condition 5, holds in the constructed graph.

Condition 6.

The following holds for $\bar{M}$: (1) $\forall S\subseteq[d]$ containing $Y$'s descendants, i.e., $d+1\in\cup_{j\in S}\mathtt{at}(j)$, if $E \not\perp\!\!\!\perp_{\bar{M}} Y \mid X_S$, then $(\mu^{(0)}\land\mu^{(t)})(\{m^{(0,S)}\neq m^{(t,S)}\})>0$; (2) $\bar{M}$ is faithful, that is,

$$
\forall~\text{disjoint}~A,B,C\subseteq[d+2],\qquad Z_A \perp\!\!\!\perp Z_B \mid Z_C ~~\overset{(a)}{\Longrightarrow}~~ Z_A \perp\!\!\!\perp_{\bar{G}} Z_B \mid Z_C,
$$

where $Z_A \perp\!\!\!\perp_{\bar{G}} Z_B \mid Z_C$ means that the node sets $A$ and $B$ are d-separated conditioned on $C$ in the graph $\bar{G}=G(\bar{M})$.

We are ready to give a complete statement of Proposition 2.

Proposition 5 (Formal Statement of Proposition 2).

Under the setting of Theorem 2, for a new environment $t$ with SCM $M^{(t)}=\{\mathcal{S}^{(t)},\nu\}$ satisfying $f_j^{(t)}\equiv f_j^{(0)}$ for any $j\in[d+1]\setminus I$, i.e., only $X_I$ is intervened, we also have $\mathbb{E}[Y^{(t)}|X_{S_\star}^{(t)}]\equiv\mathbb{E}[Y^{(0)}|X_{S_\star}^{(0)}]$. Suppose further that Condition 6 holds for the constructed SCM $\bar{M}$. Then $S_\star$ is the unique largest set whose conditional expectation is transferable, i.e., for any $S\subseteq[d]$ such that $\mathbb{E}[Y^{(t)}|X_{S_\star\cup S}^{(t)}]\neq\mathbb{E}[Y^{(t)}|X_{S_\star}^{(t)}]$, one has $\mathbb{E}[Y^{(t)}|X_S^{(t)}]\neq\mathbb{E}[Y^{(0)}|X_S^{(0)}]$.

Appendix B Generic Results and Its Applications

B.1 Main Result for the General FAIR Least Squares Estimator

This section offers a unified main result characterizing when the FAIR least squares estimator can identify the target regression function, together with a non-asymptotic $L_2$ error bound for general $(\mathcal{G},\mathcal{F})$. We first introduce some standard regularity conditions.

Condition 7 (Data Generating Process).

We collect data from $|\mathcal{E}|\in\mathbb{N}^{+}$ environments. For each environment $e\in\mathcal{E}$, we observe $(X_1^{(e)},Y_1^{(e)}),\ldots,(X_n^{(e)},Y_n^{(e)})\overset{i.i.d.}{\sim}\mu^{(e)}$.

Condition 8 (Sub-Gaussian Response).

For any $e\in\mathcal{E}$ and $t\geq 0$, $\mathbb{P}\left[|Y^{(e)}|\geq t\right]\leq C_y e^{-t^2/(2\sigma_y^2)}$, where $\sigma_y>0$ and $C_y>0$ are some constants independent of $e$ and $t$.

To control the statistical complexity of the function classes we use, we introduce the localized population Rademacher complexity, defined as follows.

Definition 5 (Localized Population Rademacher Complexity).

For a given radius $\delta>0$, function class $\mathcal{H}$, and distribution $\nu$, define

$$
R_{n,\nu}(\delta;\mathcal{H})=\mathbb{E}_{X,\varepsilon}\left[\sup_{h\in\mathcal{H},\,\|h\|_{L_2(\nu)}\leq\delta}\left|\frac{1}{n}\sum_{i=1}^{n}\varepsilon_i h(X_i)\right|\right],
$$

where $X_1,\ldots,X_n$ are i.i.d. samples from the distribution $\nu$, and $\varepsilon_1,\ldots,\varepsilon_n$ are i.i.d. Rademacher variables taking values in $\{-1,+1\}$ with equal probability, which are also independent of $(X_1,\ldots,X_n)$.
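
To make the definition concrete, the following is a hedged Monte Carlo sketch for one toy case chosen purely for illustration and not taken from the paper: $\mathcal{H}=\{x\mapsto ax\}$ over a finite grid of slopes $a$, with $\nu=\text{Uniform}[-1,1]$, so that $\|ax\|_{L_2(\nu)}=|a|/\sqrt{3}$. The function name `localized_rademacher` and all of its parameters are assumptions. For this linear class the supremum has a closed form, since $\sup_a |a z| = \max|a|\cdot|z|$.

```python
import math
import random


def localized_rademacher(slopes, delta, n=100, reps=2000, seed=0):
    """Monte Carlo estimate of R_{n,nu}(delta; H) for the toy class
    H = {x -> a * x : a in slopes} and nu = Uniform[-1, 1].

    Under this nu, the L2(nu) norm of x -> a * x equals |a| / sqrt(3).
    """
    rng = random.Random(seed)
    # Localization: keep only functions in the L2(nu) ball of radius delta.
    local = [a for a in slopes if abs(a) / math.sqrt(3) <= delta]
    if not local:
        return 0.0
    a_max = max(abs(a) for a in local)
    total = 0.0
    for _ in range(reps):
        # z = (1/n) * sum_i eps_i * X_i with Rademacher signs eps_i.
        z = sum(rng.choice((-1, 1)) * rng.uniform(-1, 1) for _ in range(n)) / n
        # For linear functions, the supremum over the local class is a_max * |z|.
        total += a_max * abs(z)
    return total / reps
```

In this toy case the population quantity grows linearly in $\delta$ (up to the grid resolution) and, by Jensen's inequality, is at most $\delta/\sqrt{n}$; the estimate returned by the sketch should reflect both properties.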

Condition 9 (Function Class).

Suppose the following holds for the function classes $\mathcal{G}$ and $\mathcal{F}$ we use:

  • (1).

    They are uniformly bounded by $B\geq 1$, i.e., $\sup_{h\in\mathcal{G}\cup\mathcal{F}}\|h\|_{\infty}\leq B$.

  • (2).

    $0\in\mathcal{F}$, and the statistical complexity of the function class $\mathcal{G}+\mathcal{F}:=\{g+f:g\in\mathcal{G},f\in\mathcal{F}_{S_g}\}$ is upper-bounded by $\delta_n$. In particular, there exists some quantity $1/n\leq\delta_n<1$ such that

    $$
    R_{n,\mu^{(e)}}(\delta;\partial\mathcal{G})\leq B\delta_n\delta\qquad\text{and}\qquad R_{n,\mu^{(e)}}(\delta;\partial(\mathcal{G}+\mathcal{F}))\leq 2B\delta_n\delta
    $$

    for any $e\in\mathcal{E}$ and $\delta\in[\delta_n,2B]$, where $\partial\mathcal{H}=\{h-h':h,h'\in\mathcal{H}\}$.

Note that when $-\mathcal{G}=\mathcal{G}$, $R_{n,\mu^{(e)}}(\delta;\partial\mathcal{G})=R_{n,\mu^{(e)}}(\delta;\mathcal{G})$. Conditions 7, 8, and 9 above are standard in the theoretical analysis of regression. Recalling the definitions of $m^{(e,S)}$ and $\bar{m}^{(S)}$ in Section 2.1, we now introduce the assumption specific to our multi-environment regression setting.

Condition 10 (Invariance and Identification).

For any $S$, let $\overline{\mathcal{G}_S}\supseteq\mathcal{G}_S$ and $\overline{\mathcal{F}_S}\supseteq\mathcal{F}_S$ be closed subspaces of $\Theta_S$ satisfying $\overline{\mathcal{G}_S}\subseteq\overline{\mathcal{F}_S}$. In this case, we can define $\Pi_{\mathcal{A}}(h)=\mathop{\mathrm{argmin}}_{a\in\mathcal{A}}\|a-h\|_2$ and $\Pi^{(e)}_{\mathcal{A}}(h)=\mathop{\mathrm{argmin}}_{a\in\mathcal{A}}\|a-h\|_{2,e}$ when $\mathcal{A}\in\{\overline{\mathcal{F}_S},\overline{\mathcal{G}_S}\}$ and $h\in\Theta_S$. Suppose the following holds:

  • 1.

    (Invariance) There exists some index set $S^{\star}\subseteq[d]$ such that

    $$
    \forall e\in\mathcal{E}\qquad \Pi^{(e)}_{\overline{\mathcal{F}_{S^{\star}}}}(m^{(e,S^{\star})})=\Pi_{\overline{\mathcal{G}_{S^{\star}}}}(\bar{m}^{(S^{\star})}):=g^{\star}.
    $$
  • 2.

    (Heterogeneity) For each $S\subseteq[d]$, if $\mathsf{b}_{\mathcal{G}}(S)>0$, then $\bar{\mathsf{d}}_{\mathcal{G},\mathcal{F}}(S)>0$, where

    $$
    \mathsf{b}_{\mathcal{G}}(S)=\|\Pi_{\overline{\mathcal{G}_{S\cup S^{\star}}}}(\bar{m}^{(S\cup S^{\star})})-g^{\star}\|_2^2~~\text{and}~~\bar{\mathsf{d}}_{\mathcal{G},\mathcal{F}}(S)=\frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}}\|\Pi^{(e)}_{\overline{\mathcal{F}_S}}(m^{(e,S)})-\Pi_{\overline{\mathcal{G}_S}}(\bar{m}^{(S)})\|_{2,e}^2. \tag{B.1}
    $$
  • 3.

    (Nondegenerate Covariate) For any $S\subseteq[d]$ such that $S^{\star}\setminus S\neq\emptyset$, we have $\inf_{g\in\overline{\mathcal{G}_S}}\|g-g^{\star}\|_2^2\geq s_{\min}$ for some constant $s_{\min}>0$.

The first condition, “invariance,” specifies the target regression function $g^{\star}$ of interest and states the invariance structure imposed for our theoretical analysis. It relaxes the general conditional expectation invariance (1.1) when $\overline{\mathcal{F}_S}\subsetneq\Theta_S$. Two leading examples are (1) the fully nonparametric class $\overline{\mathcal{G}_S}=\overline{\mathcal{F}_S}=\Theta_S$, and (2) the linear class $\overline{\mathcal{G}_S}=\overline{\mathcal{F}_S}=\{f(x)=\beta_S^{\top}x_S:\beta_S\in\mathbb{R}^{|S|}\}$. In the first example, we are interested in estimating the invariant conditional expectation $g^{\star}=m^{\star}$, and the invariance condition requires the conditional expectation invariance (1.1), namely

$$
\forall e\in\mathcal{E}\qquad m^{(e,S^{\star})}(x)=m^{\star}(x_{S^{\star}}).
$$

In the second example, when the covariance matrices $\mathbb{E}[X^{(e)}(X^{(e)})^{\top}]$ are positive definite in all environments, we are interested in estimating the invariant linear predictor $g^{\star}(x)=x^{\top}\beta^{\star}$, and the “invariance” condition only requires that

$$\forall e\in\mathcal{E}\qquad\beta^{(e,S^{\star})}\equiv\beta^{\star}\qquad\text{where}\quad\beta^{(e,S^{\star})}=\mathop{\mathrm{argmin}}_{\beta\in\mathbb{R}^{d},\,\beta_{(S^{\star})^{c}}=0}\mathbb{E}[|Y^{(e)}-\beta^{\top}X^{(e)}|^{2}],$$

that is, the best linear predictors constrained on $S^{\star}$ are the same across all environments. In this case, the conditional expectations $m^{(e,S^{\star})}(x)$ can be nonlinear or different.
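The linear invariance condition can be checked by hand in a small example. The sketch below is illustrative only: the two-environment structural model, the helper names `second_moments` and `best_linear_predictor`, and all numerical values are assumptions made for this example and do not come from the text. It computes population best linear predictors from second moments and shows that the predictor constrained on $S^{\star}=\{X_1\}$ is the same in both environments, while the predictor using both covariates is not.

```python
# Toy two-environment linear SCM (an illustrative assumption, not from the paper):
#   X1 ~ N(0, s1),  Y = 2*X1 + eps,  X2 = a_e*Y + eta,  eps, eta ~ N(0, 1).
# Only (s1, a_e) change across environments; S* = {X1} and beta* = (2, 0).

def second_moments(s1, a):
    """Population second moments E[X X^T] and E[X Y] for X = (X1, X2)."""
    var_y = 4.0 * s1 + 1.0
    sigma = [[s1, 2.0 * a * s1],
             [2.0 * a * s1, a * a * var_y + 1.0]]
    cross = [2.0 * s1, a * var_y]
    return sigma, cross

def best_linear_predictor(sigma, cross, support):
    """Solve the normal equations restricted to `support` (indices in {0, 1})."""
    beta = [0.0, 0.0]
    if len(support) == 1:
        j = support[0]
        beta[j] = cross[j] / sigma[j][j]
    elif len(support) == 2:
        det = sigma[0][0] * sigma[1][1] - sigma[0][1] ** 2
        beta[0] = (cross[0] * sigma[1][1] - sigma[0][1] * cross[1]) / det
        beta[1] = (sigma[0][0] * cross[1] - sigma[0][1] * cross[0]) / det
    return beta

environments = {"e1": (1.0, 1.0), "e2": (3.0, 2.0)}  # hypothetical (s1, a_e)
for name, (s1, a) in environments.items():
    sigma, cross = second_moments(s1, a)
    on_s_star = best_linear_predictor(sigma, cross, (0,))
    on_both = best_linear_predictor(sigma, cross, (0, 1))
    print(name, "constrained on S*:", on_s_star, "on {X1, X2}:", on_both)
```

In this toy model the child variable $X_2$ plays the role of an endogenous covariate: including it changes the best linear predictor from one environment to the next, which is what the invariance condition rules out for $S^{\star}$.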

The second condition, “heterogeneity”, is for identification and is fundamental to deriving population-level strong convexity with respect to $g^{\star}$. The two quantities in (B.1) are general forms of the bias mean and the bias variance, respectively. We refer to $\mathsf{b}_{\mathcal{G}}(S)$ as the bias mean because it is precisely the bias of the estimator that regresses $Y$ on $X_{S}$ using all the data when $S^{\star}\subseteq S$. This is formalized in the following proposition, which asserts that in the absence of our proposed regularizer, a vanilla least squares estimator will not consistently estimate $g^{\star}$, and the discrepancy $\|\widehat{g}-g^{\star}\|_{2}^{2}$ is approximately equal to $\mathsf{b}_{\mathcal{G}}(S)$ when $n$ is large.

Proposition 6 (Inconsistency of Least Squares Estimator).

Let $S$ be an index set such that $S^{\star}\subseteq S\subseteq[d]$. Assume Conditions 7–10 hold and $\mathsf{b}_{\mathcal{G}}(S)>0$. Suppose further that $U\delta_{n,\log n}+\inf_{g\in\mathcal{G}_{S}}\|g-\Pi_{\overline{\mathcal{G}_{S}}}(\bar{m}^{(S)})\|_{2}=o(1)$, where $U$ and $\delta_{n,t}$ are the two quantities defined in Theorem 4 below. Then the estimator $\widehat{g}_{\mathtt{R}}$ that minimizes (4.6) in $\mathcal{G}_{S}$ satisfies, for large enough $n$,

$$0.99\leq\frac{\|\widehat{g}_{\mathtt{R}}-g^{\star}\|_{2}^{2}}{\mathsf{b}_{\mathcal{G}}(S)}\leq 1.01$$

with probability at least $1-\{C_{y}(\sigma_{y}+1)+1\}n^{-100}$.
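A small Monte Carlo experiment can make this inconsistency concrete. The sketch below is an illustration under stated assumptions: it uses a hypothetical two-environment linear model ($X_1\to Y\to X_2$, with only the $Y\to X_2$ coefficient changing), and it uses a pooled least squares fit on $S=\{1,2\}$ as a stand-in for the unregularized estimator. For this toy model the population pooled solution is $(1, 1/3)$ rather than $\beta^{\star}=(2,0)$, and the squared $L_2$ distance under the pooled covariate distribution equals $0.5$.

```python
import random

def simulate(n, s1, a, rng):
    """Draw n samples (x1, x2, y) from X1 ~ N(0, s1), Y = 2*X1 + eps, X2 = a*Y + eta."""
    rows = []
    for _ in range(n):
        x1 = rng.gauss(0.0, s1 ** 0.5)
        y = 2.0 * x1 + rng.gauss(0.0, 1.0)
        x2 = a * y + rng.gauss(0.0, 1.0)
        rows.append((x1, x2, y))
    return rows

def pooled_least_squares(rows):
    """Least squares of y on (x1, x2) without intercept, via the 2x2 normal equations."""
    s11 = sum(r[0] * r[0] for r in rows)
    s12 = sum(r[0] * r[1] for r in rows)
    s22 = sum(r[1] * r[1] for r in rows)
    c1 = sum(r[0] * r[2] for r in rows)
    c2 = sum(r[1] * r[2] for r in rows)
    det = s11 * s22 - s12 * s12
    return ((c1 * s22 - s12 * c2) / det, (s11 * c2 - s12 * c1) / det)

rng = random.Random(0)
n = 20000  # samples per environment (hypothetical)
rows = simulate(n, 1.0, 1.0, rng) + simulate(n, 1.0, 2.0, rng)
beta_hat = pooled_least_squares(rows)

# Squared L2 distance to g*(x) = 2*x1 under the pooled covariate distribution,
# using the population pooled second-moment matrix [[1, 3], [3, 13.5]] of this toy model.
d1, d2 = beta_hat[0] - 2.0, beta_hat[1]
sq_error = 1.0 * d1 * d1 + 2.0 * 3.0 * d1 * d2 + 13.5 * d2 * d2
print("pooled fit:", beta_hat, "squared L2 error:", sq_error)
```

Increasing `n` does not shrink `sq_error` toward zero; it stabilizes near the bias of the pooled solution, which is the behavior the proposition describes.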

On the other hand, our proposed FAIR estimator will not converge to the biased solution under the “heterogeneity” condition. This condition is an abstraction of the “identification” conditions in previous subsections, for example, Condition 2 for FAIR-NN.

The last condition, “nondegenerate covariate”, ensures that the target regression function $g^{\star}$ cannot be exactly fitted by any function $g$ whose dependent variable set $S_{g}$ does not cover $S^{\star}$. It reduces to “non-collinearity” when $\mathcal{G}$ is linear.

In practice, we may only have access to an approximate solution. In our theoretical analysis, we focus on the performance of an approximate solution $(\widehat{g},\widehat{f}^{\mathcal{E}})$ satisfying

$$\sup_{f^{\mathcal{E}}\in\{\mathcal{F}_{S_{\widehat{g}}}\}^{|\mathcal{E}|}}\widehat{\mathsf{Q}}_{\gamma}(\widehat{g},f^{\mathcal{E}})-(\gamma+1)\delta_{\mathtt{opt}}^{2}\leq\widehat{\mathsf{Q}}_{\gamma}(\widehat{g},\widehat{f}^{\mathcal{E}})\leq\inf_{g\in\mathcal{G}}\sup_{f^{\mathcal{E}}\in\{\mathcal{F}_{S_{g}}\}^{|\mathcal{E}|}}\widehat{\mathsf{Q}}_{\gamma}(g,f^{\mathcal{E}})+(1+\gamma)\delta_{\mathtt{opt}}^{2} \tag{B.2}$$

with some optimization error $\delta_{\mathtt{opt}}^{2}>0$, where $\gamma$ in $(1+\gamma)$ is the same as that in $\widehat{\mathsf{Q}}_{\gamma}$. Now we are ready to state the main result regarding the statistical rate of convergence of our estimator $\widehat{g}$ to $g^{\star}$, that is,

$$\|\widehat{g}-g^{\star}\|_{2}=\left\{\int(\widehat{g}-g^{\star})^{2}\,\bar{\mu}_{x}(dx)\right\}^{1/2}.$$

Theorem 4 (Main Result for the FAIR Estimator with $\ell_{2}$ Loss).

Assume Conditions 7–10 hold. Define the critical threshold

$$\gamma^{\star}:=\sup_{S\subseteq[d]:\,\mathsf{b}_{\mathcal{G}}(S)>0}\frac{\mathsf{b}_{\mathcal{G}}(S)}{\bar{\mathsf{d}}_{\mathcal{G},\mathcal{F}}(S)}.$$

There exists some universal constant $C$ such that, for any $\gamma\geq 8\gamma^{\star}$, the following holds:

(1) General $L_{2}$ error rate. Let $t>0$ be arbitrary. Define general approximation errors with respect to the function classes $\mathcal{G}$ and $\mathcal{F}$ as

$$\delta_{\mathtt{a},\mathcal{G}}=\inf_{g\in\mathcal{G}_{S^{\star}}}\|g-g^{\star}\|_{2}\quad\text{and}\quad\delta_{\mathtt{a},\mathcal{F},\mathcal{G}}(S)=\sqrt{\frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}}\sup_{g\in\mathcal{G}:S_{g}=S}\inf_{f\in\mathcal{F}_{S_{g}}}\|\Pi_{\overline{\mathcal{F}_{S}}}^{(e)}(m^{(e,S)})-g-f\|_{2,e}^{2}},$$

and the stochastic error as $\delta_{n,t}=\delta_{n}+\{(\log(nB|\mathcal{E}|)+t+1)/n\}^{1/2}$, where $\delta_{n}$ is the quantity in Condition 9. Let $U=B(B+\sigma_{y}\sqrt{\log(n|\mathcal{E}|)})$; then

$$\|\widehat{g}-g^{\star}\|_{2}\leq C(1+\gamma)\left(U\delta_{n,t}+\delta_{\mathtt{a},\mathcal{G}}+\delta_{\mathtt{a},\mathcal{F},\mathcal{G}}(S_{\widehat{g}})+\delta_{\mathtt{a},\mathcal{F},\mathcal{G}}(S^{\star})+\delta_{\mathtt{opt}}\right) \tag{B.3}$$

with probability at least $\mathfrak{p}=1-6e^{-t}-2C_{y}(\sigma_{y}+1)n^{-100}$.

(2) Faster $L_{2}$ error rate. Moreover, if

$$\delta_{\mathtt{opt}}^{2}+\sup_{S\subseteq[d]}\delta^{2}_{\mathtt{a},\mathcal{F},\mathcal{G}}(S)+\delta^{2}_{\mathtt{a},\mathcal{G}}+UB\delta_{n,t}\leq\left\{1\land\frac{s_{\min}}{\gamma+1}\land\left(\frac{\gamma}{\gamma+1}\inf_{S:\,\bar{\mathsf{d}}_{\mathcal{G},\mathcal{F}}(S)>0}\bar{\mathsf{d}}_{\mathcal{G},\mathcal{F}}(S)\right)\right\}/C \tag{B.4}$$

then the following holds, with probability at least $\mathfrak{p}$,

$$\|\widehat{g}-g^{\star}\|_{2}\leq C\left(U\delta_{n,t}+\delta_{\mathtt{a},\mathcal{G}}+\delta_{\mathtt{a},\mathcal{F},\mathcal{G}}^{\star}+\delta_{\mathtt{opt}}\right), \tag{B.5}$$

where $\delta_{\mathtt{a},\mathcal{F},\mathcal{G}}^{\star}=\{\frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}}\sup_{g\in\mathcal{G}}\inf_{f\in\mathcal{F}_{S_{g}}}\|g^{\star}-g-f\|_{2,e}^{2}\}^{1/2}$.

Theorem 4 generalizes Theorem 4.4 in Fan et al., (2023) to a broad spectrum of $(\mathcal{G},\mathcal{F})$ configurations. After specifying the function classes $(\mathcal{G},\mathcal{F})$, one can further derive the corresponding identification condition by calculating $(\mathsf{b}_{\mathcal{G}}(S),\bar{\mathsf{d}}_{\mathcal{G},\mathcal{F}}(S))$ and establish a high-probability bound on the $L_{2}$ error by substituting the approximation errors $(\delta_{\mathtt{a},\mathcal{G}},\delta_{\mathtt{a},\mathcal{F},\mathcal{G}}(S),\delta_{\mathtt{a},\mathcal{F},\mathcal{G}}^{\star})$ and the stochastic error $\delta_{n}$ for the function classes $(\mathcal{G},\mathcal{F})$. In particular, when $\mathcal{G}$ and $\mathcal{F}$ are restricted to the linear function class, the resulting bounds not only match but also significantly improve the result in Fan et al., (2023); see Section B.6. All the results in Table 2 are direct corollaries of our abstract result Theorem 4.
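The substitution step can be sketched numerically. The code below simply evaluates the formulas for $\delta_{n,t}$, $U$, $\mathfrak{p}$, and the right-hand sides of (B.3) and (B.5). Every input value is hypothetical, and the universal constant $C$, which the theorem does not specify, is set to 1 purely as a placeholder, so the printed numbers indicate scaling behavior rather than actual error levels.

```python
import math

def stochastic_error(delta_n, n, B, num_envs, t):
    """delta_{n,t} = delta_n + sqrt((log(n * B * |E|) + t + 1) / n)."""
    return delta_n + math.sqrt((math.log(n * B * num_envs) + t + 1.0) / n)

def scale_U(n, B, sigma_y, num_envs):
    """U = B * (B + sigma_y * sqrt(log(n * |E|)))."""
    return B * (B + sigma_y * math.sqrt(math.log(n * num_envs)))

def success_probability(n, t, C_y, sigma_y):
    """p = 1 - 6 exp(-t) - 2 C_y (sigma_y + 1) n^{-100}."""
    tail = math.exp(-100.0 * math.log(n))  # n^{-100}; underflows to 0.0 for large n
    return 1.0 - 6.0 * math.exp(-t) - 2.0 * C_y * (sigma_y + 1.0) * tail

def general_bound(gamma, C, U, delta_nt, approx_errors, delta_opt):
    """Right-hand side of (B.3); approx_errors holds the three approximation terms."""
    return C * (1.0 + gamma) * (U * delta_nt + sum(approx_errors) + delta_opt)

def faster_bound(C, U, delta_nt, delta_a_G, delta_a_FG_star, delta_opt):
    """Right-hand side of (B.5), which has no gamma factor."""
    return C * (U * delta_nt + delta_a_G + delta_a_FG_star + delta_opt)

# Hypothetical inputs chosen only for illustration.
n, B, sigma_y, num_envs, t = 10_000, 1.0, 1.0, 2, math.log(600.0)
delta_nt = stochastic_error(delta_n=0.05, n=n, B=B, num_envs=num_envs, t=t)
U = scale_U(n, B, sigma_y, num_envs)
p = success_probability(n, t, C_y=1.0, sigma_y=sigma_y)
print("delta_{n,t}:", delta_nt, "U:", U, "probability:", p)
for gamma in (10.0, 20.0):
    print("gamma:", gamma, "(B.3) RHS:",
          general_bound(gamma, 1.0, U, delta_nt, (0.01, 0.01, 0.01), 0.01))
print("(B.5) RHS:", faster_bound(1.0, U, delta_nt, 0.01, 0.01, 0.01))
```

Doubling $\gamma$ roughly doubles the right-hand side of (B.3), while the right-hand side of (B.5) is unaffected by $\gamma$.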

It is required that $\gamma$ be greater than a constant-level critical threshold $8\gamma^{\star}$ for consistent estimation of $g^{\star}$. Theorem 4 further establishes a crude instance-dependent and oracle-type error bound (B.3) that holds for arbitrary $n\geq 2$ and scales linearly with $\gamma$. Furthermore, when the stochastic error and approximation errors all go to $0$ as $n$ increases and $n$ is large enough that (B.4) holds, we have (B.5), which improves the $L_{2}$ error bound (B.3) in two aspects: the error bound no longer depends on either $\gamma$ or the other $m^{(e,S)}$ with $S\neq S^{\star}$. The quantity on the RHS of (B.4) can be interpreted as the smaller of (1) the signal of the truly important variables and (2) the signal of heterogeneity. When one of these signals is weak, one can expect to need more data to distinguish signal from noise.

One important ingredient in the FAIR estimator is the choice of the regularization hyper-parameter $\gamma$ that promotes invariance. Theorem 4 offers some insights on choosing $\gamma$. First, $\gamma\geq C\gamma^{\star}$ is required so that the estimator correctly identifies $g^{\star}$ at the population level. Second, $\gamma$ influences the $L_{2}$ error rate when $n$ is not large enough for (B.4) to hold. Third, the final $L_{2}$ error rate (B.5), which applies when $n$ is large enough, is independent of $\gamma$. This indicates that the estimator's performance is not very sensitive to the choice of the hyper-parameter $\gamma$. In this case, one can adopt a slightly conservative, large $\gamma$ to meet the population condition $\gamma\geq C\gamma^{\star}$.
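To get a feel for the size of $\gamma^{\star}$, one can evaluate the ratio of bias mean to heterogeneity over all subsets in a small example. The sketch below rests on several assumptions that are not taken from the text: a hypothetical two-environment linear model ($X_1\to Y\to X_2$), linear classes for both $\mathcal{G}$ and $\mathcal{F}$, and linear-class versions of the two quantities in which conditional means are replaced by best linear predictors, namely $\mathsf{b}(S)=\|\bar\beta^{(S\cup S^{\star})}-\beta^{\star}\|^{2}$ in the pooled second-moment norm and $\bar{\mathsf{d}}(S)$ as the average over environments of $\|\beta^{(e,S)}-\bar\beta^{(S)}\|^{2}$ in the environment-wise norm.

```python
from itertools import combinations

def second_moments(s1, a):
    """Population E[X X^T] and E[X Y] for X1 ~ N(0, s1), Y = 2*X1 + eps, X2 = a*Y + eta."""
    var_y = 4.0 * s1 + 1.0
    sigma = [[s1, 2.0 * a * s1], [2.0 * a * s1, a * a * var_y + 1.0]]
    return sigma, [2.0 * s1, a * var_y]

def best_linear_predictor(sigma, cross, support):
    """Best linear predictor supported on `support` (a tuple of indices in {0, 1})."""
    beta = [0.0, 0.0]
    if len(support) == 1:
        j = support[0]
        beta[j] = cross[j] / sigma[j][j]
    elif len(support) == 2:
        det = sigma[0][0] * sigma[1][1] - sigma[0][1] ** 2
        beta[0] = (cross[0] * sigma[1][1] - sigma[0][1] * cross[1]) / det
        beta[1] = (sigma[0][0] * cross[1] - sigma[0][1] * cross[0]) / det
    return beta

def quadratic_form(sigma, u):
    """u^T sigma u, i.e. the squared L2 norm of x -> u^T x under second moments sigma."""
    return sigma[0][0] * u[0] ** 2 + 2.0 * sigma[0][1] * u[0] * u[1] + sigma[1][1] * u[1] ** 2

envs = [second_moments(1.0, 1.0), second_moments(1.0, 2.0)]  # hypothetical environments
k = len(envs)
pooled_sigma = [[sum(s[i][j] for s, _ in envs) / k for j in range(2)] for i in range(2)]
pooled_cross = [sum(c[i] for _, c in envs) / k for i in range(2)]
beta_star, s_star = [2.0, 0.0], {0}

def bias_mean(S):
    """Assumed linear analogue of b(S): pooled fit on S union S* versus beta*."""
    union = tuple(sorted(set(S) | s_star))
    pooled = best_linear_predictor(pooled_sigma, pooled_cross, union)
    return quadratic_form(pooled_sigma, [pooled[i] - beta_star[i] for i in range(2)])

def heterogeneity(S):
    """Assumed linear analogue of d-bar(S): spread of per-environment fits around the pooled fit."""
    pooled = best_linear_predictor(pooled_sigma, pooled_cross, S)
    gaps = []
    for sigma, cross in envs:
        local = best_linear_predictor(sigma, cross, S)
        gaps.append(quadratic_form(sigma, [local[i] - pooled[i] for i in range(2)]))
    return sum(gaps) / k

subsets = [S for r in range(3) for S in combinations(range(2), r)]
ratios = {S: bias_mean(S) / heterogeneity(S) for S in subsets if bias_mean(S) > 1e-12}
gamma_star = max(ratios.values())
print("ratios:", ratios, "gamma_star:", gamma_star, "8 * gamma_star:", 8.0 * gamma_star)
```

In this toy model only the subsets containing the endogenous variable $X_2$ have positive bias, and each of them also has positive heterogeneity, so the ratio is finite. A conservative choice of $\gamma$ would then sit somewhat above $8\gamma^{\star}$.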

B.2 Extension to the General Risk Loss under the Nonparametric Setting

Condition 11 (Risk Loss).

Define $\mathcal{V}=[\inf_{g\in\mathcal{G}\cup\{g^{\star}\}}\sup_{L}\{g(X)\geq L,\ \bar{\mu}_{x}\text{-a.s.}\},\ \sup_{g\in\mathcal{G}\cup\{g^{\star}\}}\inf_{U}\{g(X)\leq U,\ \bar{\mu}_{x}\text{-a.s.}\}]$ to be the range of values that $g(X)$ takes, and $\mathcal{Y}=[\sup_{l}\{Y\geq l,\ \bar{\mu}_{x}\text{-a.s.}\},\ \inf_{u}\{Y\leq u,\ \bar{\mu}_{x}\text{-a.s.}\}]$ to be the range of values that $Y$ takes. The loss $\ell(\cdot,\cdot)$ satisfies

  • (1)

    $\ell(y,v)<\infty$ for any $y\in\mathcal{Y}$ and $v\in\mathcal{V}$, and $\ell$ is twice continuously differentiable in $\mathcal{Y}\times\mathcal{V}$. Moreover, $\frac{\partial\ell(y,v)}{\partial v}=(v-y)\psi(v)$ for some continuously differentiable $\psi(v):\mathbb{R}\to\mathbb{R}$.

  • (2)

    There exists some universal constant $\zeta\geq 1$ such that

    $$|\psi(v)|\leq\zeta\qquad\text{and}\qquad\zeta^{-1}\leq\frac{\partial^{2}\ell}{\partial v^{2}}(Y,v)\leq\zeta\qquad\forall v\in\mathcal{V}\text{ and }\bar{\mu}\text{-a.s.}$$

The assumptions on the risk loss in Condition 11 are standard. The first part of (1) ensures that $\ell$ is well-defined on optimal solutions and linear combinations of them; the second part of (1) requires that the population-level global minimum is the conditional mean; and (2) guarantees that the loss function is strongly convex and smooth in the domain, and satisfies $|\ell(y,v)-\ell(y,v^{\prime})|\leq\zeta|y-\widetilde{v}||v-v^{\prime}|$ for some universal constant $\zeta$, which slightly relaxes the Lipschitz condition in Farrell et al., (2021) and Foster & Syrgkanis, (2019).
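As an illustration that is not taken from the text, the cross-entropy loss $\ell(y,v)=-y\log v-(1-y)\log(1-v)$ with $v$ restricted to $[\epsilon,1-\epsilon]$ appears to fit this template, since $\partial\ell/\partial v=(v-y)/\{v(1-v)\}$, so that $\psi(v)=1/\{v(1-v)\}$. The sketch below checks the derivative identity by finite differences and verifies the two bounds with $\zeta=1/\epsilon^{2}$ for binary $Y$. The value of $\epsilon$, the grid, and the function names are assumptions made for this example.

```python
import math

def cross_entropy(y, v):
    """Cross-entropy loss with prediction v in (0, 1)."""
    return -y * math.log(v) - (1.0 - y) * math.log(1.0 - v)

def psi(v):
    """Candidate psi such that d loss / d v = (v - y) * psi(v)."""
    return 1.0 / (v * (1.0 - v))

def d_loss(y, v, h=1e-6):
    """Central finite-difference approximation of d loss / d v."""
    return (cross_entropy(y, v + h) - cross_entropy(y, v - h)) / (2.0 * h)

def d2_loss(y, v):
    """Closed-form second derivative of the loss in v."""
    return y / v ** 2 + (1.0 - y) / (1.0 - v) ** 2

eps = 0.1                      # hypothetical truncation level for v
zeta = 1.0 / eps ** 2          # candidate constant in part (2) of the condition
slack = 1.0 + 1e-9             # guards against floating-point rounding at the endpoints
grid = [eps + k * (1.0 - 2.0 * eps) / 50 for k in range(51)]

max_gap = max(abs(d_loss(y, v) - (v - y) * psi(v)) for y in (0.0, 1.0) for v in grid)
psi_ok = all(abs(psi(v)) <= zeta * slack for v in grid)
curvature_ok = all(1.0 / (zeta * slack) <= d2_loss(y, v) <= zeta * slack
                   for y in (0.0, 1.0) for v in grid)
print("max derivative gap:", max_gap, "psi bounded:", psi_ok, "curvature bounded:", curvature_ok)
```

The squared loss $\ell(y,v)=\frac{1}{2}(y-v)^{2}$ is the simplest case, with $\psi\equiv 1$ and $\zeta=1$.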

We now state the invariance and identification condition when the general risk loss is adopted.

Condition 12 (Invariance and Identification for General Risk Loss).

Suppose the following holds

  • 1.

    (Invariance) There exists some index set $S^{\star}\subseteq[d]$ such that

    $$\forall e\in\mathcal{E}\qquad m^{(e,S^{\star})}=\bar{m}^{(S^{\star})}=:m^{\star}.$$
  • 2.

    (Heterogeneity) For each $S\subseteq[d]$, if $\mathsf{b}(S)>0$, then $\bar{\mathsf{d}}(S)>0$, where

    $$\mathsf{b}(S):=\|\bar{m}^{(S\cup S^{\star})}-m^{\star}\|_{2}^{2},\qquad\bar{\mathsf{d}}(S):=\frac{1}{|\mathcal{E}|}\sum_{e=1}^{m}\|m^{(e,S)}-\bar{m}^{(S)}\|_{2,e}^{2}. \tag{B.6}$$
  • 3.

    (Nondegenerate Covariate) For any $S\subseteq[d]$ such that $S^{\star}\setminus S\neq\emptyset$, we have $\inf_{g\in\Theta_{S}}\|g-m^{\star}\|_{2}^{2}\geq s_{\min}$ for some constant $s_{\min}>0$.

We are now ready to state the main result in this case.

Theorem 5 (Main Result for the FAIR Estimator with General Risk Loss).

Assume 7,8,9, and 1112 hold. Define the critical threshold

γ:=supS[d]:𝖻(S)>0𝖻(S)𝖽¯(S).assignsuperscript𝛾subscriptsupremum:𝑆delimited-[]𝑑𝖻𝑆0𝖻𝑆¯𝖽𝑆\displaystyle\gamma^{\star}:=\sup_{S\subseteq[d]:\mathsf{b}(S)>0}\frac{\mathsf% {b}(S)}{\bar{\mathsf{d}}(S)}.italic_γ start_POSTSUPERSCRIPT ⋆ end_POSTSUPERSCRIPT := roman_sup start_POSTSUBSCRIPT italic_S ⊆ [ italic_d ] : sansserif_b ( italic_S ) > 0 end_POSTSUBSCRIPT divide start_ARG sansserif_b ( italic_S ) end_ARG start_ARG over¯ start_ARG sansserif_d end_ARG ( italic_S ) end_ARG .

There exists some universal constant C𝐶Citalic_C such that, for any γ8ζ2γ𝛾8superscript𝜁2superscript𝛾\gamma\geq 8\zeta^{2}\gamma^{\star}italic_γ ≥ 8 italic_ζ start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT italic_γ start_POSTSUPERSCRIPT ⋆ end_POSTSUPERSCRIPT, the following holds:

(1) General L2subscript𝐿2L_{2}italic_L start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT error rate. Let t>0𝑡0t>0italic_t > 0 be arbitrary. Define general approximation errors with respect to the function class 𝒢𝒢\mathcal{G}caligraphic_G and \mathcal{F}caligraphic_F as

δ𝚊,𝒢=infg𝒢Sgm2andδ𝚊,,𝒢(S)=1||esupg𝒢:Sg=SinffSgm(e,S)gf2,e2,formulae-sequencesubscript𝛿𝚊𝒢subscriptinfimum𝑔subscript𝒢superscript𝑆subscriptnorm𝑔superscript𝑚2andsubscript𝛿𝚊𝒢𝑆1subscript𝑒subscriptsupremum:𝑔𝒢subscript𝑆𝑔𝑆subscriptinfimum𝑓subscriptsubscript𝑆𝑔superscriptsubscriptnormsuperscript𝑚𝑒𝑆𝑔𝑓2𝑒2\displaystyle\delta_{\mathtt{a},\mathcal{G}}=\inf_{g\in\mathcal{G}_{S^{\star}}% }\|g-m^{\star}\|_{2}~{}~{}~{}~{}\text{and}~{}~{}~{}~{}\delta_{\mathtt{a},% \mathcal{F},\mathcal{G}}(S)=\sqrt{\frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}% }\sup_{g\in\mathcal{G}:S_{g}=S}\inf_{f\in{\mathcal{F}_{S_{g}}}}\|m^{(e,S)}-g-f% \|_{2,e}^{2}},italic_δ start_POSTSUBSCRIPT typewriter_a , caligraphic_G end_POSTSUBSCRIPT = roman_inf start_POSTSUBSCRIPT italic_g ∈ caligraphic_G start_POSTSUBSCRIPT italic_S start_POSTSUPERSCRIPT ⋆ end_POSTSUPERSCRIPT end_POSTSUBSCRIPT end_POSTSUBSCRIPT ∥ italic_g - italic_m start_POSTSUPERSCRIPT ⋆ end_POSTSUPERSCRIPT ∥ start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT and italic_δ start_POSTSUBSCRIPT typewriter_a , caligraphic_F , caligraphic_G end_POSTSUBSCRIPT ( italic_S ) = square-root start_ARG divide start_ARG 1 end_ARG start_ARG | caligraphic_E | end_ARG ∑ start_POSTSUBSCRIPT italic_e ∈ caligraphic_E end_POSTSUBSCRIPT roman_sup start_POSTSUBSCRIPT italic_g ∈ caligraphic_G : italic_S start_POSTSUBSCRIPT italic_g end_POSTSUBSCRIPT = italic_S end_POSTSUBSCRIPT roman_inf start_POSTSUBSCRIPT italic_f ∈ caligraphic_F start_POSTSUBSCRIPT italic_S start_POSTSUBSCRIPT italic_g end_POSTSUBSCRIPT end_POSTSUBSCRIPT end_POSTSUBSCRIPT ∥ italic_m start_POSTSUPERSCRIPT ( italic_e , italic_S ) end_POSTSUPERSCRIPT - italic_g - italic_f ∥ start_POSTSUBSCRIPT 2 , italic_e end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT end_ARG ,

and the stochastic error as $\delta_{n,t}=\delta_{n}+\{(\log(nB|\mathcal{E}|)+t+1)/n\}^{1/2}$, where $\delta_{n}$ is the quantity in 9. Let $U=B(B+\sigma_{y}\sqrt{\log(n|\mathcal{E}|)})$, then

$$
\|\widehat{g}-m^{\star}\|_{2}\lor\|\widehat{g}-m^{\star}\|_{n}\leq C(\zeta+\gamma)\zeta\left(U\delta_{n,t}+\delta_{\mathtt{a},\mathcal{G}}+\delta_{\mathtt{a},\mathcal{F},\mathcal{G}}(S_{\widehat{g}})+\delta_{\mathtt{a},\mathcal{F},\mathcal{G}}(S^{\star})+\delta_{\mathtt{opt}}\right) \tag{B.7}
$$

with probability at least $\mathfrak{p}=1-6e^{-t}-2C_{y}(\sigma_{y}+1)n^{-100}$.

(2) Faster $L_{2}$ error rate. Moreover, if

$$
\delta_{\mathtt{opt}}^{2}+\sup_{S\subseteq[d]}\delta^{2}_{\mathtt{a},\mathcal{F},\mathcal{G}}(S)+\delta^{2}_{\mathtt{a},\mathcal{G}}+UB\delta_{n,t}\leq\left\{1\land\frac{s_{\min}}{(\gamma+\zeta)\zeta}\land\left(\frac{\gamma}{\gamma+\zeta}\inf_{S:\bar{\mathsf{d}}_{\mathcal{G},\mathcal{F}}(S)>0}\bar{\mathsf{d}}_{\mathcal{G},\mathcal{F}}(S)\right)\right\}/C \tag{B.8}
$$

then the following holds, with probability at least $\mathfrak{p}$,

$$
\|\widehat{g}-m^{\star}\|_{2}\lor\|\widehat{g}-m^{\star}\|_{n}\leq C\zeta^{2}\left(U\delta_{n,t}+\delta_{\mathtt{a},\mathcal{G}}+\delta_{\mathtt{a},\mathcal{F},\mathcal{G}}^{\star}+\delta_{\mathtt{opt}}\right), \tag{B.9}
$$

where $\delta_{\mathtt{a},\mathcal{F},\mathcal{G}}^{\star}=\left\{\frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}}\sup_{g\in\mathcal{G}}\inf_{f\in\mathcal{F}_{S_{g}}}\|m^{\star}-g-f\|_{2,e}^{2}\right\}^{1/2}$.

B.3 Key Ideas and Proof Sketch of Theorem 4

We first introduce some additional notation. Let

$$
\begin{aligned}
\mathsf{A}^{(e)}(g,f^{(e)})&=\mathbb{E}\left[\{Y^{(e)}-g(X^{(e)})\}f^{(e)}(X^{(e)})-\frac{1}{2}\{f^{(e)}(X^{(e)})\}^{2}\right],\\
\widehat{\mathsf{A}}^{(e)}(g,f^{(e)})&=\frac{1}{n}\sum_{i=1}^{n}\left[\{Y_{i}^{(e)}-g(X_{i}^{(e)})\}f^{(e)}(X_{i}^{(e)})-\frac{1}{2}\{f^{(e)}(X_{i}^{(e)})\}^{2}\right].
\end{aligned}
$$
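As a minimal numerical sketch of the empirical objective $\widehat{\mathsf{A}}^{(e)}$ for a single environment, the snippet below evaluates it on simulated data. The data-generating model, the predictor `g_x`, and the function name `empirical_A` are illustrative assumptions and do not come from the paper. The candidate `f_resid` is the pointwise maximizer over arbitrary vectors of fitted values, not over the restricted class $\mathcal{F}_{S_g}$.

```python
import numpy as np

def empirical_A(y, g_x, f_x):
    """Empirical objective for one environment:
    mean over samples of (y - g(x)) * f(x) - f(x)**2 / 2."""
    y, g_x, f_x = (np.asarray(v, dtype=float) for v in (y, g_x, f_x))
    return float(np.mean((y - g_x) * f_x - 0.5 * f_x ** 2))

# Hypothetical single-environment sample (assumption, for illustration only).
rng = np.random.default_rng(0)
x = rng.normal(size=500)
y = 2.0 * x + rng.normal(size=500)

g_x = 1.5 * x                 # hypothetical predictor g(x) = 1.5 x
f_zero = np.zeros_like(x)     # trivial discriminator f = 0
f_resid = y - g_x             # unrestricted pointwise maximizer: the residual

a_zero = empirical_A(y, g_x, f_zero)    # equals 0
a_resid = empirical_A(y, g_x, f_resid)  # equals half the mean squared residual
a_other = empirical_A(y, g_x, 0.5 * x)  # any other f scores no higher than a_resid
```

The objective is concave in `f_x`, so any discriminator other than the residual attains a value no larger than `a_resid`.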

Define the population-level pooled risk and FAIR estimator loss as

$$
\mathsf{R}(g)=\frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}}\mathbb{E}\left[\frac{1}{2}|Y^{(e)}-g(X^{(e)})|^{2}\right]\qquad\text{and}\qquad\mathsf{Q}_{\gamma}(g,f^{\mathcal{E}})=\mathsf{R}(g)+\gamma\mathsf{J}(g,f^{\mathcal{E}}).
$$

We will use the following theorem establishing approximate strong convexity with respect to $g^{\star}$.

Theorem 6.

Assume 10 holds and $\ell(y,v)=\frac{1}{2}(y-v)^{2}$. Let $\delta\in(0,1)$ be arbitrary. Then the following holds for any $\gamma\geq 4\delta^{-1}\gamma^{\star}$:

$$
\begin{aligned}
\mathsf{Q}_{\gamma}(g,f^{\mathcal{E}})-\mathsf{Q}_{\gamma}(\widetilde{g},\widetilde{f}^{\mathcal{E}})\geq{}&\frac{1-\delta}{2}\|g-\widetilde{g}\|_{2}^{2}+\frac{\gamma}{4}\bar{\mathsf{d}}_{\mathcal{G},\mathcal{F}}(S)+\frac{\gamma}{2}\|g-\Pi_{\overline{\mathcal{G}_{S}}}(\bar{m}^{(S)})\|_{2}^{2}\\
&-\frac{\gamma}{2|\mathcal{E}|}\sum_{e\in\mathcal{E}}\|f^{(e)}-\{\Pi_{\overline{\mathcal{F}_{S}}}^{(e)}(m^{(e,S)})-g\}\|_{2,e}^{2}-(\delta^{-1}+\gamma/2)\|\widetilde{g}-g^{\star}\|_{2}^{2}
\end{aligned}
$$

for any $g\in\mathcal{G}$, $\widetilde{g}\in\mathcal{G}_{S^{\star}}$ and $S_{\widetilde{g}}=S^{\star}$, $f^{\mathcal{E}}\in\{\overline{\mathcal{F}_{S_{g}}}\}^{|\mathcal{E}|}$, and $\widetilde{f}^{\mathcal{E}}\in\{\overline{\mathcal{F}_{S^{\star}}}\}^{|\mathcal{E}|}$.

Recall our definition of

$$
\delta_{n,t}=\delta_{n}+\sqrt{\frac{t+\log(nB|\mathcal{E}|)+1}{n}}\qquad\text{and}\qquad U=B(B+\sigma_{y}\sqrt{\log(n|\mathcal{E}|)}).
$$
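To give a feel for the magnitudes involved, the following sketch evaluates $\delta_{n,t}$ and $U$ for hypothetical inputs. The function names `stochastic_error` and `envelope_U` and all numerical values are assumptions chosen for illustration. In particular, $\delta_{n}$ is passed in as a fixed number here, although in the theory it also depends on $n$.

```python
import math

def stochastic_error(delta_n, n, B, num_envs, t):
    """delta_{n,t} = delta_n + sqrt((t + log(n * B * |E|) + 1) / n)."""
    return delta_n + math.sqrt((t + math.log(n * B * num_envs) + 1.0) / n)

def envelope_U(B, sigma_y, n, num_envs):
    """U = B * (B + sigma_y * sqrt(log(n * |E|)))."""
    return B * (B + sigma_y * math.sqrt(math.log(n * num_envs)))

# Hypothetical values (assumptions): B = 2, |E| = 3, t = 2, sigma_y = 1.
delta_small_n = stochastic_error(delta_n=0.05, n=1_000, B=2.0, num_envs=3, t=2.0)
delta_large_n = stochastic_error(delta_n=0.05, n=100_000, B=2.0, num_envs=3, t=2.0)
U = envelope_U(B=2.0, sigma_y=1.0, n=1_000, num_envs=3)
```

With $\delta_{n}$ held fixed, the second term of $\delta_{n,t}$ shrinks roughly like $\sqrt{\log n/n}$, while $U$ grows only like $\sqrt{\log n}$.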

The first proposition establishes instance-dependent error bounds on

$$
\Delta_{\mathsf{R}}(g,\widetilde{g}):=\{\widehat{\mathsf{R}}(g)-\widehat{\mathsf{R}}(\widetilde{g})\}-\{\mathsf{R}(g)-\mathsf{R}(\widetilde{g})\},
$$

and is standard in the nonparametric regression literature.

Proposition 7 (Instance-dependent error bounds for pooled risk).

Suppose 7, 8, 9 hold. There exists some universal constant $C$ such that for any $\eta>0$ and $t>0$, the following event

$$
\forall g,\widetilde{g}\in\mathcal{G},\quad|\Delta_{\mathsf{R}}(g,\widetilde{g})|\leq CU\left\{\delta_{n,t}^{2}+\delta_{n,t}\frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}}\|g-\widetilde{g}\|_{2,e}\right\}
$$

occurs with probability at least $1-3e^{-t}-C_{y}(\sigma_{y}+1)n^{-100}$.

The analysis of the focused adversarial invariance regularizer is more involved. The next proposition establishes the instance-dependent error bound for the regularizer. We define

$$
\Delta_{\mathsf{A}}^{(e)}(g,\widetilde{g},f^{(e)},\widetilde{f}^{(e)})=\mathsf{A}^{(e)}(g,f^{(e)})-\mathsf{A}^{(e)}(\widetilde{g},\widetilde{f}^{(e)})-\left\{\widehat{\mathsf{A}}^{(e)}(g,f^{(e)})-\widehat{\mathsf{A}}^{(e)}(\widetilde{g},\widetilde{f}^{(e)})\right\}
$$

and

$$
\mathcal{M}(\mathcal{G},\mathcal{F})=\left\{(g,\widetilde{g},f,\widetilde{f}):g,\widetilde{g}\in\mathcal{G}~\text{and}~f\in\mathcal{F}_{S_{g}},\widetilde{f}\in\mathcal{F}_{S_{\widetilde{g}}}\right\}.
$$

Proposition 8 (Instance-dependent error bounds for regularizer).

Suppose 7, 8, 9 hold. There exists some universal constant $C$ such that for any $t>0$, the following event

$$
\begin{aligned}
\forall e\in\mathcal{E},~&\forall(g,\widetilde{g},f^{(e)},\widetilde{f}^{(e)})\in\mathcal{M}(\mathcal{G},\mathcal{F}),\\
&|\Delta_{\mathsf{A}}^{(e)}(g,\widetilde{g},f^{(e)},\widetilde{f}^{(e)})|\leq CU\left(\delta_{n,t}\left(\|\widetilde{g}-g\|_{2,e}+\|\widetilde{g}+\widetilde{f}^{(e)}-g-f^{(e)}\|_{2,e}\right)+\delta_{n,t}^{2}\right)
\end{aligned}
$$

occurs with probability at least $1-3e^{-t}-C_{y}(\sigma_{y}+1)n^{-100}$.

We first apply Proposition 8 with $g$ and $\widetilde{g}$ taken to be the same. In this case, the $\max$-$\mathcal{F}$ optimization problem in a single environment $e\in\mathcal{E}$ for fixed $g\in\mathcal{G}$ is similar to a least squares regression that fits the target regression function

$$
\Pi_{\overline{\mathcal{F}_{S}}}^{(e)}(m^{(e,S)})-g.
$$
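The least squares connection can be made explicit by completing the square in the definition of $\mathsf{A}^{(e)}$. The following short derivation is a sketch added for illustration; it uses only the definition of $\mathsf{A}^{(e)}$ given above and suppresses the argument $X^{(e)}$ of $g$ and $f$:

$$
\{Y^{(e)}-g\}f-\frac{1}{2}f^{2}=\frac{1}{2}\{Y^{(e)}-g\}^{2}-\frac{1}{2}\{Y^{(e)}-g-f\}^{2},
\qquad\text{so}\qquad
\mathsf{A}^{(e)}(g,f)=\frac{1}{2}\mathbb{E}\left[\{Y^{(e)}-g\}^{2}\right]-\frac{1}{2}\mathbb{E}\left[\{Y^{(e)}-g-f\}^{2}\right].
$$

The first term does not depend on $f$, so maximizing $\mathsf{A}^{(e)}(g,\cdot)$ over a class of functions is the same as minimizing the squared error $\mathbb{E}[\{Y^{(e)}-g-f\}^{2}]$ over that class, that is, regressing the residual $Y^{(e)}-g(X^{(e)})$ on the covariates available to $f$. The same identity holds for $\widehat{\mathsf{A}}^{(e)}$ with expectations replaced by sample averages.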

Thus one can establish high-probability bounds on the $\|\cdot\|_{2,e}$ distance between the empirical loss maximizer $\widehat{f}_{g}^{(e)}$ and the above target function in terms of the statistical error $\delta_{n,t}$ and the approximation error rate $\delta_{\mathtt{a},\mathcal{F},\mathcal{G}}(e,S_{g})$, defined as

$$
\delta_{\mathtt{a},\mathcal{F},\mathcal{G}}(e,S):=\sup_{g\in\mathcal{G}:S_{g}=S}\inf_{f\in\mathcal{F}_{S}}\|\Pi_{\overline{\mathcal{F}_{S}}}^{(e)}(m^{(e,S)})-g-f\|_{2,e}.
$$

We formalize the above intuition in Proposition 9, an instance-dependent error bound that retains the optimization gap term.

Proposition 9 (Instance-dependent characterization of approximately optimal discriminator).

Let $0<\eta<1/2$ be arbitrary. Under the event defined in Proposition 8, the following holds:

$$
\begin{aligned}
&\forall e\in\mathcal{E},~\forall g\in\mathcal{G},~\forall f^{(e)}\in\mathcal{F}_{S_{g}},\\
&\qquad\|\Pi_{\overline{\mathcal{F}_{S}}}(m^{(e,S)})-g-f^{(e)}\|_{2,e}^{2}\leq\frac{2\eta^{-1}+2-4\eta}{1-2\eta}\delta^{2}_{\mathtt{a},\mathcal{F},\mathcal{G}}(e,S_{g})+\frac{2\eta^{-1}+4}{1-2\eta}C^{2}U^{2}\delta_{n,t}^{2}\\
&\qquad\qquad\qquad+\frac{4}{1-2\eta}\left\{\sup_{\breve{f}\in\mathcal{F}_{S_{g}}}\widehat{\mathsf{A}}^{(e)}(g,\breve{f})-\widehat{\mathsf{A}}^{(e)}(g,f^{(e)})\right\}
\end{aligned}
$$

where $C$ is the universal constant defined in Proposition 8. Averaging over all $e\in\mathcal{E}$, we obtain

$$
\begin{aligned}
&\forall g\in\mathcal{G},~\forall f^{\mathcal{E}}\in\{\mathcal{F}_{S_{g}}\}^{|\mathcal{E}|},\\
&\qquad\frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}}\|\Pi_{\overline{\mathcal{F}_{S}}}(m^{(e,S)})-g-f^{(e)}\|_{2,e}^{2}\leq\frac{2\eta^{-1}+2-4\eta}{1-2\eta}\delta^{2}_{\mathtt{a},\mathcal{F},\mathcal{G}}(S_{g})+\frac{2\eta^{-1}+4}{1-2\eta}C^{2}U^{2}\delta_{n,t}^{2}\\
&\qquad\qquad\qquad+\gamma^{-1}\frac{4}{1-2\eta}\left\{\sup_{\breve{f}^{\mathcal{E}}\in\{\mathcal{F}_{S_{g}}\}^{|\mathcal{E}|}}\widehat{\mathsf{Q}}_{\gamma}(g,\breve{f})-\widehat{\mathsf{Q}}_{\gamma}(g,f^{\mathcal{E}})\right\}
\end{aligned}
$$
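The role of the free parameter $\eta$ in Proposition 9 can be seen by tabulating the three $\eta$-dependent factors in the bound. The sketch below is purely illustrative: the function name `prop9_constants` is hypothetical, and the factors are copied from the displayed inequality (the statistical term is additionally multiplied by $C^{2}U^{2}\delta_{n,t}^{2}$, which is not computed here).

```python
def prop9_constants(eta):
    """Return the eta-dependent factors multiplying the approximation,
    statistical, and optimization-gap terms in Proposition 9."""
    if not 0.0 < eta < 0.5:
        raise ValueError("eta must lie in (0, 1/2)")
    denom = 1.0 - 2.0 * eta
    c_approx = (2.0 / eta + 2.0 - 4.0 * eta) / denom  # multiplies delta_a^2
    c_stat = (2.0 / eta + 4.0) / denom                # multiplies C^2 U^2 delta_{n,t}^2
    c_opt = 4.0 / denom                               # multiplies the optimization gap
    return c_approx, c_stat, c_opt

constants_quarter = prop9_constants(0.25)  # (18.0, 24.0, 8.0)
constants_small = prop9_constants(0.01)    # dominated by the 2 / eta terms
constants_near_half = prop9_constants(0.49)  # dominated by 1 / (1 - 2 eta)
```

All three factors blow up as $\eta\to 1/2$, and the first two also blow up as $\eta\to 0$, so an intermediate value such as $\eta=1/4$ keeps every factor a moderate absolute constant.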

Now we are ready to prove Theorem 4.

For the proof of (2), the faster $L_{2}$ rate, we divide the argument into two main steps as follows.

  1.

    In the first step, we establish a variable selection property: when Eq. (B.4) holds and the events defined in Propositions 7 and 8 occur, $\widehat{S}$ satisfies

    \[
    \forall e\in\mathcal{E}\qquad\Pi_{\overline{\mathcal{F}_{\widehat{S}}}}^{(e)}(m^{(e,\widehat{S})})=g^{\star}.
    \]

    We argue by contradiction: any $g$ for which the above constraint is violated on $S_{g}$ cannot be an approximate solution of the minimax optimization $\inf_{g}\sup_{f^{\mathcal{E}}}\widehat{\mathsf{Q}}_{\gamma}(g,f^{\mathcal{E}})$. This is summarized in Proposition 10 below.

  2.

    In the second step, we proceed conditioned on the above claim and derive a sharp $L_{2}$ error bound. To do so, we combine (1) the approximate strong convexity with respect to $g^{\star}$, i.e., Theorem 6, (2) the instance-dependent error bounds for $\mathsf{J}$ and $\mathsf{R}$, i.e., Propositions 7 and 8, and (3) the key fact that, if the claim in step 1 holds, then

    \begin{align*}
    \|\widetilde{g}+\widetilde{f}_{\widetilde{g}}^{(e)}-g-f_{g}^{(e)}\|_{2,e}
    &\leq\|\widetilde{g}+\widetilde{f}_{\widetilde{g}}^{(e)}-g^{\star}+\Pi_{\overline{\mathcal{F}_{S_{g}}}}^{(e)}(m^{(e,S_{g})})-g-f_{g}^{(e)}\|_{2,e}\\
    &\lesssim\|\widetilde{g}+\widetilde{f}_{\widetilde{g}}^{(e)}-g^{\star}\|_{2,e}+\|g^{\star}-g-f_{g}^{(e)}\|_{2,e}\\
    &\lesssim\delta_{n,t}+\delta_{\mathtt{a},\mathcal{F},\mathcal{G}}^{\star}.
    \end{align*}

The proof of (1) is similar to the second step in the proof of (2), but now we no longer have $g^{\star}=\Pi_{\overline{\mathcal{F}_{S_{g}}}}^{(e)}(m^{(e,S_{g})})$. The key challenge here is to establish an upper bound on $\|g^{\star}-\Pi_{\overline{\mathcal{F}_{S_{g}}}}^{(e)}(m^{(e,S_{g})})\|_{2,e}$ without imposing an additional population-level condition such as Condition 7 in an early version of Fan et al. (2023). Instead, we use the following instance-dependent bound:

\[
\frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}}\|g^{\star}-\Pi_{\overline{\mathcal{F}_{S_{g}}}}^{(e)}(m^{(e,S_{g})})\|_{2,e}^{2}\leq C\left((1+\gamma^{\star})\,\bar{\mathsf{d}}_{\mathcal{G},\mathcal{F}}(S_{g})+\|g-g^{\star}\|_{2}^{2}\right).
\]

This is a population-level, instance-dependent bound in the sense that both the left-hand side and the right-hand side depend on the function $g$.

Proposition 10.

Under the events defined in Propositions 7 and 8, the event

\[
\mathcal{A}_{+}:=\left\{\forall e\in\mathcal{E}\qquad\Pi_{\overline{\mathcal{F}_{\widehat{S}}}}^{(e)}(m^{(e,\widehat{S})})=g^{\star}\qquad\text{for}~~\widehat{S}=S_{\widehat{g}}\right\}\tag{B.10}
\]

occurs if condition (B.4) holds with some large universal constant $C$.

B.4 Applications of Theorem 4 and Connection to the Predecessors

We present some examples here, sorted by the potential approximation capability of the function class $(\mathcal{G},\mathcal{F})$.

Example 3 (Linear $\mathcal{G}$, Linear $\mathcal{F}$).

The simplest case is that $\mathcal{G}$ and $\mathcal{F}$ are both linear function classes, that is,

\[
\mathcal{G}=\mathcal{F}=\{h(x)=\beta^{\top}x:\beta\in\mathbb{R}^{d}\}:=\mathcal{H}_{\mathtt{lin}}(d).
\]

The objective takes on a form that closely resembles the EILLS objective proposed in Fan et al. (2023). To see this, recall that the EILLS objective is

\[
\frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}}\widehat{\mathbb{E}}[|Y^{(e)}-g(X^{(e)})|^{2}]+\frac{\gamma}{|\mathcal{E}|}\sum_{e\in\mathcal{E}}\|\widehat{r}_{g}^{(e)}\|_{2}^{2},
\qquad\text{where}\quad\widehat{r}_{g}^{(e)}=\widehat{\mathbb{E}}[\{Y^{(e)}-g(X^{(e)})\}X_{S_{g}}^{(e)}].
\]

If we take the supremum over all $f^{(e)}\in\mathcal{F}_{S_{g}}$ with $e\in\mathcal{E}$, the objective in (4.5) transforms into

\[
\sup_{f^{\mathcal{E}}\in\{\mathcal{F}_{S_{g}}\}^{|\mathcal{E}|}}\widehat{\mathsf{Q}}_{\gamma}(g,f^{\mathcal{E}})=\frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}}\widehat{\mathbb{E}}[|Y^{(e)}-g(X^{(e)})|^{2}]+\frac{\gamma}{|\mathcal{E}|}\sum_{e\in\mathcal{E}}(\widehat{r}_{g}^{(e)})^{\top}\{\widehat{\mathbb{E}}[X_{S}^{(e)}(X_{S}^{(e)})^{\top}]\}^{-1}(\widehat{r}_{g}^{(e)}),
\]

where $S=S_{g}$ in the Gram matrix.

It slightly stabilizes the EILLS objective in that the regularizer has a matched moment index compared with the pooled least squares loss; see a detailed explanation and theoretical justification in Section B.6.
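The quadratic-form regularizer above can be checked numerically in a single environment. The following is a minimal sketch, not the authors' implementation. It assumes that the per-environment adversarial term has the form $\widehat{\mathbb{E}}[2\{Y-g(X)\}f(X)-f(X)^{2}]$; objective (4.5) is not restated in this section, so this scaling is an assumption chosen to match the display above. The helper names `solve`, `moments`, `penalty_at`, and `penalty_closed_form` are hypothetical. For linear $f(x)=\beta^{\top}x_{S}$ the term equals $2\beta^{\top}\widehat{r}-\beta^{\top}\widehat{\Sigma}\beta$, which is maximized at $\beta=\widehat{\Sigma}^{-1}\widehat{r}$ with value $\widehat{r}^{\top}\widehat{\Sigma}^{-1}\widehat{r}$.

```python
import random

def solve(a, b):
    """Solve the linear system a x = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    m = [row[:] + [b[i]] for i, row in enumerate(a)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[piv] = m[piv], m[col]
        for r in range(col + 1, n):
            factor = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= factor * m[col][c]
    x = [0.0] * n
    for i in reversed(range(n)):
        s = m[i][n] - sum(m[i][j] * x[j] for j in range(i + 1, n))
        x[i] = s / m[i][i]
    return x

def moments(xs, resid):
    """Empirical r = E[resid * X] and Gram matrix Sigma = E[X X^T]."""
    n, d = len(xs), len(xs[0])
    r = [sum(e * x[j] for x, e in zip(xs, resid)) / n for j in range(d)]
    sigma = [[sum(x[j] * x[k] for x in xs) / n for k in range(d)] for j in range(d)]
    return r, sigma

def penalty_at(beta, xs, resid):
    """Empirical E[2 * resid * f(X) - f(X)^2] for f(x) = beta^T x (assumed form)."""
    total = 0.0
    for x, e in zip(xs, resid):
        f = sum(b * xj for b, xj in zip(beta, x))
        total += 2.0 * e * f - f * f
    return total / len(xs)

def penalty_closed_form(xs, resid):
    """Return (r^T Sigma^{-1} r, maximizer beta = Sigma^{-1} r)."""
    r, sigma = moments(xs, resid)
    beta = solve(sigma, r)
    return sum(ri * bi for ri, bi in zip(r, beta)), beta

rng = random.Random(0)
xs = [[rng.gauss(0, 1), rng.gauss(0, 1)] for _ in range(500)]
# Residual Y - g(X) that is still correlated with the first covariate.
resid = [0.7 * x[0] + rng.gauss(0, 1) for x in xs]
value, beta_star = penalty_closed_form(xs, resid)
```

Under this assumed form, the penalty vanishes only when the residual is empirically uncorrelated with every selected covariate, and the inverse Gram matrix makes the penalty invariant to rescaling of the covariates, which is one way to read the "matched moment index" remark.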

Example 4 (Linear $\mathcal{G}$, Augmented Linear $\mathcal{F}$).

Consider the case where $\mathcal{F}$ is potentially larger than $\mathcal{G}$, that is, $\mathcal{G}=\mathcal{H}_{\mathtt{lin}}(d)$ and $\mathcal{F}=\{f(x)=\beta^{\top}x+\beta_{\phi}^{\top}\bar{\phi}(x):\beta,\beta_{\phi}\in\mathbb{R}^{d}\}:=\mathcal{H}_{\mathtt{alin}}(d,\phi)$, where $\bar{\phi}(x)=(\phi(x_{1}),\ldots,\phi(x_{d}))$ applies a transformation function $\phi:\mathbb{R}\to\mathbb{R}$ to each entry of the vector $x$.

The proposed estimator utilizes both the heterogeneity among different environments and the strong prior knowledge that the true regression function admits a linear form. It bridges the EILLS estimator in Fan et al. (2023) and the Focused GMM estimator in Fan & Liao (2014) when the instrumental variables are $[X_{S},\bar{\phi}(X_{S})]$, and reduces to an improved version of the latter when $|\mathcal{E}|=1$.
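As a concrete illustration of the augmented linear class in Example 4, the sketch below builds the stacked features $[x,\bar{\phi}(x)]$ and evaluates a member of $\mathcal{H}_{\mathtt{alin}}(d,\phi)$. The choice $\phi=\tanh$ is an assumption made only for illustration (the example leaves $\phi$ generic), and the function names `augment` and `f_alin` are hypothetical.

```python
import math

def augment(x, phi=math.tanh):
    """Return the stacked features [x, phi_bar(x)], with phi applied entrywise."""
    return list(x) + [phi(v) for v in x]

def f_alin(x, beta, beta_phi, phi=math.tanh):
    """Evaluate f(x) = beta^T x + beta_phi^T phi_bar(x), a member of H_alin(d, phi)."""
    linear = sum(b * v for b, v in zip(beta, x))
    nonlinear = sum(b * phi(v) for b, v in zip(beta_phi, x))
    return linear + nonlinear

x = [0.5, -1.0, 2.0]
features = augment(x)
# Setting beta_phi = 0 recovers a plain linear function,
# so H_lin(d) is contained in H_alin(d, phi).
g_value = f_alin(x, beta=[1.0, 2.0, -1.0], beta_phi=[0.0, 0.0, 0.0])
```

The containment shown in the last line is what allows the linear $\mathcal{G}$ to sit inside the larger testing class $\mathcal{F}$.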

Example 5 (Linear $\mathcal{G}$, Neural Network $\mathcal{F}$).

We consider a more algorithmic version of Example 4 that uses neural networks to automatically learn the transformation function, that is, $\mathcal{G}=\mathcal{H}_{\mathtt{lin}}(d)$ and $\mathcal{F}=\mathcal{H}_{\mathtt{nn}}(d,L_{f},N_{f},B_{f})$ with neural network architecture hyper-parameters $(L_{f},N_{f},B_{f})$.

The above three estimators focus on linear $\mathcal{G}$, the simplest structural function class. We now consider a more complicated structural function class for the case where the invariant association is known to admit an additive form.

Example 6 (Additive Neural Network $\mathcal{G}$, Neural Network $\mathcal{F}$).

We let $\mathcal{G}=\mathcal{H}_{\mathtt{ann}}(d,L_{g},N_{g},B_{g}):=\{g(x)=\mathrm{Tc}_{B_{g}}(\sum_{j=1}^{d}g_{j}(x_{j})):g_{j}\in\mathcal{H}_{\mathtt{nn}}(1,L_{g},N_{g},\infty)\}$ and $\mathcal{F}=\mathcal{H}_{\mathtt{nn}}(d,L_{f},N_{f},B_{f})$. Here $(L_{g},N_{g},B_{g})$ and $(L_{f},N_{f},B_{f})$ are neural network architecture hyper-parameters.
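The additive class in Example 6 can be sketched as follows. This is a simplified illustration under stated assumptions, not the architecture used in the paper: $\mathrm{Tc}_{B}$ is assumed to clip its argument to $[-B,B]$, each component $g_{j}$ is a one-hidden-layer ReLU network with fixed random weights (so the depth $L_{g}$ is not modeled), and all function names are hypothetical.

```python
import random

def truncate(u, bound):
    """Clip u to [-bound, bound]; this is the assumed meaning of Tc_B."""
    return max(-bound, min(bound, u))

def make_univariate_net(width, rng):
    """Return a one-hidden-layer ReLU network R -> R with random fixed weights."""
    w1 = [rng.uniform(-1, 1) for _ in range(width)]
    b1 = [rng.uniform(-1, 1) for _ in range(width)]
    w2 = [rng.uniform(-1, 1) for _ in range(width)]

    def net(t):
        return sum(a * max(0.0, w * t + b) for a, w, b in zip(w2, w1, b1))

    return net

def make_additive_net(d, width, bound, seed=0):
    """Build g(x) = Tc_B(sum_j g_j(x_j)) with one univariate network per coordinate."""
    rng = random.Random(seed)
    components = [make_univariate_net(width, rng) for _ in range(d)]

    def g(x):
        return truncate(sum(gj(xj) for gj, xj in zip(components, x)), bound)

    return g, components

g, components = make_additive_net(d=3, width=8, bound=1.0)
value = g([0.2, -0.4, 1.5])
```

Each component sees only its own coordinate, so the additive structure is enforced by construction rather than learned, while the truncation keeps the output bounded by $B_{g}$.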

Finally, we present the most algorithmic estimator, the FAIR-NN estimator, in which both $\mathcal{G}$ and $\mathcal{F}$ are realized by fully-connected neural networks with no additional imposed structure.

Example 7 (Neural Network $\mathcal{G}$, Neural Network $\mathcal{F}$).

We let $\mathcal{G}=\mathcal{H}_{\mathtt{nn}}(d,L_{g},N_{g},B_{g})$ and $\mathcal{F}=\mathcal{H}_{\mathtt{nn}}(d,L_{f},N_{f},B_{f})$ with neural network architecture hyper-parameters $(L_{g},N_{g},B_{g})$ and $(L_{f},N_{f},B_{f})$.

| Example | $\mathcal{G}$ | $\mathcal{F}$ | Category | Short Name | Result |
|---|---|---|---|---|---|
| Example 3 | $\mathcal{H}_{\mathtt{lin}}(d)$ | $\mathcal{H}_{\mathtt{lin}}(d)$ | $\mathcal{G}\asymp\mathcal{F}$ | FAIR-Linear | Theorem 8 |
| Example 7 | $\mathcal{H}_{\mathtt{nn}}(d,L_{g},N_{g},B_{g})$ | $\mathcal{H}_{\mathtt{nn}}(d,L_{f},N_{f},B_{f})$ | $\mathcal{G}\asymp\mathcal{F}$ | FAIR-NN | Theorem 1 |
| Example 4 | $\mathcal{H}_{\mathtt{lin}}(d)$ | $\mathcal{H}_{\mathtt{alin}}(d,\phi)$ | $\mathcal{G}\ll\mathcal{F}$ | FAIR-AugLinear | Theorem 9 |
| Example 5 | $\mathcal{H}_{\mathtt{lin}}(d)$ | $\mathcal{H}_{\mathtt{nn}}(d,L_{f},N_{f},B_{f})$ | $\mathcal{G}\ll\mathcal{F}$ | FAIR-NNLinear | Theorem 10 |
| Example 6 | $\mathcal{H}_{\mathtt{ann}}(d,L_{g},N_{g},B_{g})$ | $\mathcal{H}_{\mathtt{nn}}(d,L_{f},N_{f},B_{f})$ | $\mathcal{G}\ll\mathcal{F}$ | FAIR-ANN | Theorem 7 |

Table 2: A Glimpse of Estimators

Our framework requires $\mathcal{G}\subseteq\mathcal{F}$. We can divide the above estimators into two main categories: (1) $\mathcal{G}$ has roughly the same representation power as $\mathcal{F}$, denoted as $\mathcal{G}\asymp\mathcal{F}$, and (2) $\mathcal{F}$ has at least as good representation power as $\mathcal{G}$, denoted as $\mathcal{G}\ll\mathcal{F}$. For the former, our framework uses only the heterogeneity among different environments to identify the invariant association. For the latter, our framework uses both the heterogeneity and the strong prior structural assumption that the invariant association cannot be significantly better approximated by $\mathcal{F}$ than by $\mathcal{G}$ to jointly identify the invariant association. Table 2 summarizes the proposed estimators and assigns each to one of these two categories.

B.5 FAIR-ANN: Bridging Invariance and Additional Structural Knowledge

We next consider the estimator that utilizes both heterogeneity and the strong structural assumption that the invariant association $m^{\star}$ admits an additive form to identify $m^{\star}$. This is summarized in the following assumption.

Condition 13 (Invariance and Nondegenerate Covariate for FAIR-ANN).

There exist some set $S^{\star}$ and $m^{\star}:\mathbb{R}^{|S^{\star}|}\to\mathbb{R}$ such that $m^{(e,S^{\star})}(x)\equiv m^{\star}(x_{S^{\star}})=\sum_{j\in S^{\star}}m^{\star}_{j}(x_{j})$ for any $e\in\mathcal{E}$. Moreover, for any $S\subseteq[d]$ with $S^{\star}\setminus S\neq\emptyset$, $\inf_{m\in\Theta_{S}}\|m-m^{\star}\|_{2}^{2}\geq s_{\min}>0$.

Condition 14 (Boundedness in Nonparametric Regression).

There exist constants $b_{x}$ and $b_{m}$ such that (1) $X\in[-b_{x},b_{x}]^{d}$ $\bar{\mu}$-a.s. and (2) $\|m^{(e,S)}\|_{\infty}\leq b_{m}$ for any $S\subseteq[d]$ and $e\in\mathcal{E}$.

Condition 15.

There exists some constant Casubscript𝐶𝑎C_{a}italic_C start_POSTSUBSCRIPT italic_a end_POSTSUBSCRIPT such that

j=1dmj(xj)22Ca1j=1dmj(xj)22(m1,,md)j=1dΘ{j}withmj(xj)μ¯x(dx)0.formulae-sequencesuperscriptsubscriptnormsuperscriptsubscript𝑗1𝑑subscript𝑚𝑗subscript𝑥𝑗22superscriptsubscript𝐶𝑎1superscriptsubscript𝑗1𝑑superscriptsubscriptnormsubscript𝑚𝑗subscript𝑥𝑗22for-allsubscript𝑚1subscript𝑚𝑑superscriptsubscriptproduct𝑗1𝑑subscriptΘ𝑗withsubscript𝑚𝑗subscript𝑥𝑗subscript¯𝜇𝑥𝑑𝑥0\displaystyle\left\|\sum_{j=1}^{d}m_{j}(x_{j})\right\|_{2}^{2}\geq C_{a}^{-1}% \sum_{j=1}^{d}\|m_{j}(x_{j})\|_{2}^{2}\qquad\forall(m_{1},\ldots,m_{d})\in% \prod_{j=1}^{d}\Theta_{\{j\}}~{}\text{with}~{}\int m_{j}(x_{j})\bar{\mu}_{x}(% dx)\equiv 0.∥ ∑ start_POSTSUBSCRIPT italic_j = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_d end_POSTSUPERSCRIPT italic_m start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT ( italic_x start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT ) ∥ start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT ≥ italic_C start_POSTSUBSCRIPT italic_a end_POSTSUBSCRIPT start_POSTSUPERSCRIPT - 1 end_POSTSUPERSCRIPT ∑ start_POSTSUBSCRIPT italic_j = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_d end_POSTSUPERSCRIPT ∥ italic_m start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT ( italic_x start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT ) ∥ start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT ∀ ( italic_m start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT , … , italic_m start_POSTSUBSCRIPT italic_d end_POSTSUBSCRIPT ) ∈ ∏ start_POSTSUBSCRIPT italic_j = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_d end_POSTSUPERSCRIPT roman_Θ start_POSTSUBSCRIPT { italic_j } end_POSTSUBSCRIPT with ∫ italic_m start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT ( italic_x start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT ) over¯ start_ARG italic_μ end_ARG start_POSTSUBSCRIPT italic_x end_POSTSUBSCRIPT ( italic_d italic_x ) ≡ 0 .

The above condition is referred to as the nonparametric version of the restricted strong convexity condition, which is widely used in the theoretical analysis of nonparametric high-dimensional additive models (Van de Geer, 2008; Raskutti et al., 2012; Yuan & Zhou, 2016). This condition is imposed so that $\prod_{j\in S}\Theta_{\{j\}}$ is a closed subspace of $\Theta_S$, on which we can define

$$A_S(h)=\mathop{\mathrm{argmin}}_{u\in\prod_{j\in S}\Theta_{\{j\}}}\|h-u\|_2,$$

which finds the unique additive function of $x_S$ that best fits $h$ in the $\|\cdot\|_2$ norm.
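The projection $A_S$ can be illustrated on a sample. The following is a minimal sketch, not the construction used in the analysis: it replaces the $L_2(\bar{\mu})$ norm by an empirical average and each univariate component by a polynomial of fixed degree. The function name `additive_projection`, the polynomial basis, and the uniform design are all assumptions made for illustration.

```python
import numpy as np

def additive_projection(h_values, X, S, degree=4):
    """Empirical analogue of A_S(h): least-squares fit of h by an additive
    function sum_{j in S} u_j(x_j), each u_j a polynomial of the given degree."""
    n = X.shape[0]
    # Intercept plus the powers x_j, x_j^2, ..., x_j^degree for each j in S.
    columns = [np.ones(n)]
    for j in S:
        for p in range(1, degree + 1):
            columns.append(X[:, j] ** p)
    design = np.column_stack(columns)
    coef, *_ = np.linalg.lstsq(design, h_values, rcond=None)
    return design @ coef

rng = np.random.default_rng(0)
X = rng.uniform(-1.0, 1.0, size=(5000, 3))

# An additive target is (approximately) reproduced by the projection ...
h_additive = np.sin(X[:, 0]) + X[:, 1] ** 2
err_additive = np.mean((h_additive - additive_projection(h_additive, X, [0, 1])) ** 2)

# ... while a pure interaction term leaves a non-vanishing residual.
h_interaction = X[:, 0] * X[:, 1]
err_interaction = np.mean((h_interaction - additive_projection(h_interaction, X, [0, 1])) ** 2)

print(f"additive residual:    {err_additive:.2e}")
print(f"interaction residual: {err_interaction:.2e}")
```

The gap between the two residuals is the kind of non-additivity signal that a comparison of a function with its additive projection is meant to detect.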

Condition 16 (Identification for FAIR-ANN).

For any $S\subseteq[d]$ such that $\bar{\mu}(\{m^\star\neq A_{S\cup S^\star}(\bar{m}^{(S\cup S^\star)})\})>0$, at least one of the following holds: (1) there exist some $e,e'\in\mathcal{E}$ such that $(\mu^{(e)}\land\mu^{(e')})(\{m^{(e,S)}\neq m^{(e',S)}\})>0$, or (2) $\bar{\mu}(\{\bar{m}^{(S)}\neq A_S(\bar{m}^{(S)})\})>0$.

With network hyper-parameters $N,L$, we realize $\mathcal{G}$ and $\mathcal{F}$ as

$$\mathcal{G}=\mathcal{H}_{\mathtt{ann}}(d,L,N,b_m)\qquad\text{and}\qquad\mathcal{F}=\mathcal{H}_{\mathtt{nn}}(d,L+2,2dN,2b_m).\tag{B.11}$$

Similarly to the choice for FAIR-NN in (2.3), $\mathcal{F}$ is chosen to ensure $\mathcal{G}-\mathcal{G}\subseteq\mathcal{F}$.

Theorem 7 (Optimal Rate for FAIR-ANN Least Squares Estimator).

Assume Conditions 7, 8, and 13-16 hold. Assume further that all the conditional moments $\{m^{(e,S)}\}_{e\in\mathcal{E},S\subseteq[d]}$ are $(\beta',C')$-smooth for some $\beta'>0$ and $C'>0$, and that $\delta_{\mathtt{opt}}=o(1)$. Consider the FAIR-ANN estimator that solves (B.2) with $\ell(y,v)=\frac{1}{2}(y-v)^2$ using $\gamma\ge 8\gamma^\star_{\mathtt{AN}}$ with

$$\gamma^\star_{\mathtt{AN}}:=\sup_{S\subseteq[d]:\,\bar{\mu}(\{m^\star\neq A_{S\cup S^\star}(\bar{m}^{(S\cup S^\star)})\})>0}\frac{\|m^\star-A_{S\cup S^\star}(\bar{m}^{(S\cup S^\star)})\|_2^2}{\frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}}\|m^{(e,S)}-A_S(\bar{m}^{(S)})\|_{2,e}^2},\tag{B.12}$$

and function class (B.11) with $L,N$ satisfying $LN\asymp\{n(\log n)^{8\beta^\star-3}\}^{\frac{1}{2(2\beta^\star+1)}}$ and $(\log n)/(N\land L)=o(1)$. Then, we have (1) $\gamma^\star_{\mathtt{AN}}\le\gamma^\star_{\mathtt{NN}}$, and (2) for $n$ large enough, the following event occurs with probability at least $1-\widetilde{C}n^{-100}$:

$$\sup_{\substack{m^\star=\sum_{j\in S^\star}m_j^\star(x_j)~\text{with}~m_j^\star\in\mathcal{H}_{\mathtt{HS}}(1,\beta^\star,C^\star)\\ \|m^\star\|_\infty\le b_m}}\|\widehat{g}-m^\star\|_2\le\widetilde{C}\left\{\delta_{\mathtt{opt}}+\left(\frac{\log^7 n}{n}\right)^{\frac{\beta^\star}{2\beta^\star+1}}\right\},\tag{B.13}$$

where $\widetilde{C}$ is a constant that depends on $(C_1,d,\beta^\star,C^\star,\sigma_y,C_y,b_x,b_m)$ but is independent of $\gamma$, $\delta_{\mathtt{opt}}$, and $n$.
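To get a feel for the scaling in Theorem 7, the sketch below evaluates the prescribed order of the depth-width product $LN$ and the statistical term of the bound, with all constants and $\delta_{\mathtt{opt}}$ dropped. It takes the exponent of the statistical term to be $+\beta^\star/(2\beta^\star+1)$, so that the term decays in $n$. The helper name `fair_ann_schedule` and the value $\beta^\star=2$ are assumptions for illustration only.

```python
import math

def fair_ann_schedule(n, beta):
    """Order of the depth-width product LN and of the statistical error term
    for sample size n and smoothness beta (constants omitted)."""
    log_n = math.log(n)
    depth_width = (n * log_n ** (8 * beta - 3)) ** (1.0 / (2 * (2 * beta + 1)))
    rate = (log_n ** 7 / n) ** (beta / (2 * beta + 1))
    return depth_width, rate

schedule = {n: fair_ann_schedule(n, beta=2.0) for n in (10**6, 10**9, 10**12)}
for n, (depth_width, rate) in schedule.items():
    print(f"n={n:.0e}  LN ~ {depth_width:9.1f}  statistical term ~ {rate:.3f}")
```

Because of the polylogarithmic factor, the statistical term can remain above 1 for moderate $n$, so these numbers are best read as asymptotic orders rather than as numerical error estimates.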

The choice of $N,L$ and the convergence rate align with those of FAIR-NN with $\alpha^\star=\beta^\star$. Given the strong structural prior knowledge that the true regression function is additive, FAIR-ANN requires a weaker identification condition (Condition 16) and a smaller critical threshold for $\gamma$. In particular, Condition 16 requires that for any $S$ such that regressing $Y$ on $X_{S\cup S^\star}$ via additive models yields a biased estimate, there should be either (1) a shift in the conditional moments $m^{(e,S)}$ across different environments, or (2) one of the conditional moments $m^{(e,S)}$ that is non-additive. This characteristic is called the “double identifiable” property, since meeting either of these conditions suffices to estimate $m^\star$ consistently. Notably, the critical threshold $\gamma^\star_{\mathtt{AN}}$ can be smaller than that of the FAIR-NN estimator. A small $\gamma$ can be adopted if either the signal of violating the additive structure or the signal of heterogeneity is strong.

B.6 Theoretical Analysis for Linear $\mathcal{G}$

In this section, we apply Theorem 4 to the case where the target regression function $g^\star$ is linear. Accordingly, we use the linear function class $\mathcal{H}_{\mathtt{lin}}(d)$ as our predictor function class $\mathcal{G}$. Our results suggest that enhancing the approximation ability of the discriminator function class $\mathcal{F}$ leads to (1) a stronger condition on invariance, and (2) a weaker identification condition and a smaller critical threshold $\gamma^\star$.

B.6.1 Linear $\mathcal{F}$

We first consider the case of a linear discriminator function class $\mathcal{F}=\mathcal{H}_{\mathtt{lin}}(d)$. We introduce some notation and state some standard regularity conditions for linear regression, which are also imposed in Fan et al. (2023).

Condition 17.

Suppose the following hold:

  • (1) The data satisfies Condition 7 with $|\mathcal{E}|\le n^{C_1}$ for some constant $C_1$.

  • (2) The covariance matrix $\Sigma^{(e)}=\mathbb{E}[X^{(e)}(X^{(e)})^\top]\in\mathbb{R}^{d\times d}$ in each environment satisfies $\lambda(\Sigma^{(e)})\ge\kappa_L$ for some constant $\kappa_L>0$.

  • (3) Define the pooled covariance matrix $\Sigma:=|\mathcal{E}|^{-1}\sum_{e\in\mathcal{E}}\Sigma^{(e)}$. There exist some positive constants $C_x,\sigma_x$ such that

    $$\forall e\in\mathcal{E},~\forall v\in\mathbb{R}^d~\text{with}~\|v\|_2=1,~\forall t\in[0,\infty),\qquad\mathbb{P}\left(|v^\top\Sigma^{-1/2}X^{(e)}|\ge t\right)\le C_xe^{-t^2/(2\sigma_x^2)}.$$

  • (4) Condition 8 holds.

Under Condition 17, the covariance matrices are all positive definite, so we can define

$$\beta^{(e,S)}=\mathop{\mathrm{argmin}}_{\beta\in\mathbb{R}^d:\,\beta_{S^c}=0}\mathbb{E}\left[|Y^{(e)}-\beta^\top X^{(e)}|^2\right].$$
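A sample analogue of $\beta^{(e,S)}$ is an ordinary least squares fit restricted to the coordinates in $S$, with the remaining coefficients set to zero. The sketch below is a minimal illustration under that reading; the function name `restricted_ols` and the simulated data are hypothetical and not part of the paper.

```python
import numpy as np

def restricted_ols(X, y, S):
    """Sample analogue of beta^{(e,S)}: least squares of y on X with the
    coefficients outside S fixed at zero (no intercept, as in the display)."""
    beta = np.zeros(X.shape[1])
    S = list(S)
    if S:
        coef, *_ = np.linalg.lstsq(X[:, S], y, rcond=None)
        beta[S] = coef
    return beta

rng = np.random.default_rng(0)
X = rng.normal(size=(20000, 3))
y = 2.0 * X[:, 0] - 1.0 * X[:, 1] + rng.normal(size=20000)

beta_full = restricted_ols(X, y, [0, 1, 2])  # all coordinates free
beta_sub = restricted_ols(X, y, [0])         # only the first coordinate is free
print(beta_full.round(2), beta_sub.round(2))
```

With independent covariates, as simulated here, restricting to a subset leaves the retained coefficient essentially unchanged; with correlated or endogenous covariates it generally does not.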

We can now state the invariance and identification conditions in this case.

Condition 18 (Invariance in Linear $\mathcal{G}$ and Linear $\mathcal{F}$).

There exist some $S^\star\subseteq[d]$ and $\beta^\star\in\mathbb{R}^d$ with $\beta^\star_{(S^\star)^c}=0$ and $\min_{j\in S^\star}|\beta_j^\star|=\beta_{\min}>0$ such that

$$\forall e\in\mathcal{E}\qquad\beta^{(e,S^\star)}=\beta^\star.\tag{B.14}$$

Let $\varepsilon^{(e)}=Y^{(e)}-(\beta^\star)^\top X^{(e)}$. The invariance equality (B.14) is equivalent to $X_{S^\star}$ being exogenous in all environments, that is,

$$\forall e\in\mathcal{E}\qquad\mathbb{E}[\varepsilon^{(e)}X_{S^\star}^{(e)}]=0.$$

Condition 19 (Identification for Linear $\mathcal{G}$ and Linear $\mathcal{F}$).

For any $S\subseteq[d]$ with $\sum_{e\in\mathcal{E}}\mathbb{E}[X_S^{(e)}\varepsilon^{(e)}]\neq 0$, there exist $e,e'\in\mathcal{E}$ such that $\beta^{(e,S)}\neq\beta^{(e',S)}$.
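Conditions 18 and 19 can be checked numerically on a toy example. The sketch below simulates two environments from a hypothetical linear structural model $X_1\to Y\to X_2$, in which only the strength of the $Y\to X_2$ edge changes; here $S^\star=\{1\}$ (index 0 in the code) and $X_2$ is endogenous. The model, the coefficient values, and the helper names are assumptions chosen for illustration.

```python
import numpy as np

def simulate_environment(n, child_coef, rng):
    """Toy linear SCM X1 -> Y -> X2; only the Y -> X2 strength varies."""
    x1 = rng.normal(size=n)
    y = 1.5 * x1 + rng.normal(size=n)
    x2 = child_coef * y + rng.normal(size=n)  # X2 is a child of Y (endogenous)
    return np.column_stack([x1, x2]), y

def restricted_ols(X, y, S):
    """Least squares of y on the columns in S, zeros elsewhere."""
    beta = np.zeros(X.shape[1])
    coef, *_ = np.linalg.lstsq(X[:, S], y, rcond=None)
    beta[S] = coef
    return beta

rng = np.random.default_rng(1)
environments = [simulate_environment(50000, c, rng) for c in (0.2, 1.0)]
beta_star = np.array([1.5, 0.0])

results = {}
for S in ([0], [1], [0, 1]):
    betas = [restricted_ols(X, y, S) for X, y in environments]
    # Pooled moment sum_e E[X_S eps] with eps = Y - beta_star^T X.
    moment = sum((X[:, S] * (y - X @ beta_star)[:, None]).mean(axis=0)
                 for X, y in environments)
    gap = np.abs(betas[0] - betas[1]).max()
    results[tuple(S)] = (np.abs(moment).max(), gap)
    print(S, "max |pooled moment| =", round(results[tuple(S)][0], 3),
          " coefficient gap =", round(gap, 3))
```

In this example the set containing only $X_1$ has a vanishing pooled moment and matching coefficients across environments, whereas every set that includes $X_2$ has a non-zero pooled moment together with a visible shift in the fitted coefficients, which is the pattern Condition 19 requires.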

We are ready to state the result using the truncated linear function class with bounded $L_2$ norm, that is,

$$\mathcal{H}_{\mathtt{lin}}(d,B_1,B_2)=\left\{f(x)=\mathrm{Tc}_{B_2}(\beta^\top x):\beta\in\mathbb{R}^d,\ \|\Sigma^{1/2}\beta\|_2\le B_1\right\}.$$
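As a concrete reading of this class, the sketch below builds a member $f(x)=\mathrm{Tc}_{B_2}(\beta^\top x)$ and checks the norm constraint via $\|\Sigma^{1/2}\beta\|_2^2=\beta^\top\Sigma\beta$. It assumes that $\mathrm{Tc}_B$ denotes clipping to $[-B,B]$; the function names and the numerical values are hypothetical.

```python
import numpy as np

def truncate(v, B):
    """Tc_B, assumed here to clip values to the interval [-B, B]."""
    return np.clip(v, -B, B)

def in_linear_class(beta, Sigma, B1):
    """Check ||Sigma^{1/2} beta||_2 <= B1 using beta^T Sigma beta."""
    return float(np.sqrt(beta @ Sigma @ beta)) <= B1

def truncated_linear(beta, B2):
    """Return f(x) = Tc_{B2}(beta^T x), applied to the rows of x."""
    return lambda x: truncate(x @ beta, B2)

Sigma = np.array([[1.0, 0.5], [0.5, 2.0]])
beta = np.array([1.0, -0.5])
f = truncated_linear(beta, B2=1.0)
x = np.array([[0.2, 0.1], [3.0, -2.0]])

print(in_linear_class(beta, Sigma, B1=2.0))
print(f(x))  # the second row has beta^T x = 4 and is clipped to B2 = 1
```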
Theorem 8 (Linear $\mathcal{G}$ and Linear $\mathcal{F}$).

Suppose Conditions 17-19 hold, and we choose

$$\mathcal{G}=\mathcal{H}_{\mathtt{lin}}(d,C_2,C_2\sqrt{\log n})\qquad\text{and}\qquad\mathcal{F}=\mathcal{H}_{\mathtt{lin}}(d,2C_2,2C_2\sqrt{\log n})$$

with some constant $C_2\ge 2(\sigma_x\lor 1)\max_{e\in\mathcal{E},S\subseteq[d]}\|\Sigma^{1/2}\beta^{(e,S)}\|_2$. Then, there exists some constant $\widetilde{C}$ that only depends on $(C_1,C_2,\sigma_x,C_x,\sigma_y,C_y)$ such that the FAIR least squares estimator using the above function classes and hyper-parameter $\gamma$ satisfying $\gamma\ge 8\gamma^\star_{\mathtt{LL}}=8\sup_{S:\mathsf{b}_{\mathtt{LL}}(S)>0}\mathsf{b}_{\mathtt{LL}}(S)/\bar{\mathsf{d}}_{\mathtt{LL}}(S)$, where

$$\begin{aligned}\mathsf{b}_{\mathtt{LL}}(S)&=\bigg\|\frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}}\mathbb{E}[X_{S\cup S^\star}^{(e)}\varepsilon^{(e)}]\bigg\|_{(\bar{\Sigma}_{S\cup S^\star})^{-1}}^2\le(\kappa_L)^{-1}\bigg\|\frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}}\mathbb{E}[X_S^{(e)}\varepsilon^{(e)}]\bigg\|_2^2,\\ \bar{\mathsf{d}}_{\mathtt{LL}}(S)&=\frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}}\|\beta^{(e,S)}_S-\beta^{(S)}_\dagger\|_{\Sigma_S^{(e)}}^2\ge\kappa_L\frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}}\|\beta^{(e,S)}-\bar{\beta}^{(S)}\|_2^2\end{aligned}\tag{B.15}$$

with $\beta_\dagger^{(S)}=(\bar{\Sigma}_S)^{-1}\{\frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}}\mathbb{E}[X^{(e)}_SY^{(e)}]\}$ and $\bar{\beta}^{(S)}=\frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}}\beta^{(e,S)}$, satisfies, with probability at least $1-\widetilde{C}n^{-100}$,

$$\forall n\ge 3\qquad\|\Sigma^{1/2}(\beta_{\widehat{g}}-\beta^\star)\|_2\le\widetilde{C}(1+\gamma)\sqrt{\frac{d\log^5(n)}{n}},\tag{B.16}$$

for g^(x)=TcB(βg^x)^𝑔𝑥subscriptTc𝐵superscriptsubscript𝛽^𝑔top𝑥\widehat{g}(x)=\mathrm{Tc}_{B}(\beta_{\widehat{g}}^{\top}x)over^ start_ARG italic_g end_ARG ( italic_x ) = roman_Tc start_POSTSUBSCRIPT italic_B end_POSTSUBSCRIPT ( italic_β start_POSTSUBSCRIPT over^ start_ARG italic_g end_ARG end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT italic_x ). Moreover, if d=o((1+γ2)n/(log6n))𝑑𝑜1superscript𝛾2𝑛superscript6𝑛d=o((1+\gamma^{2})n/(\log^{6}n))italic_d = italic_o ( ( 1 + italic_γ start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT ) italic_n / ( roman_log start_POSTSUPERSCRIPT 6 end_POSTSUPERSCRIPT italic_n ) ), then for large enough n𝑛nitalic_n, we further have

Σ1/2(βg^β)2C~dlog5(n)nsubscriptnormsuperscriptΣ12subscript𝛽^𝑔superscript𝛽2~𝐶𝑑superscript5𝑛𝑛\displaystyle\|{\Sigma}^{1/2}(\beta_{\widehat{g}}-\beta^{\star})\|_{2}\leq% \widetilde{C}\sqrt{\frac{d\log^{5}(n)}{n}}∥ roman_Σ start_POSTSUPERSCRIPT 1 / 2 end_POSTSUPERSCRIPT ( italic_β start_POSTSUBSCRIPT over^ start_ARG italic_g end_ARG end_POSTSUBSCRIPT - italic_β start_POSTSUPERSCRIPT ⋆ end_POSTSUPERSCRIPT ) ∥ start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT ≤ over~ start_ARG italic_C end_ARG square-root start_ARG divide start_ARG italic_d roman_log start_POSTSUPERSCRIPT 5 end_POSTSUPERSCRIPT ( italic_n ) end_ARG start_ARG italic_n end_ARG end_ARG (B.17)
Remark 8.

We present the results using truncated function classes, and there are poly-$\log n$ factors in the non-asymptotic $L_{2}$ error bounds. These are for technical convenience, so that we can directly apply our result Theorem 4, which focuses on uniformly bounded function classes. Indeed, one can use a finer analysis and obtain the $\ell_{2}$ error bound

$$\sqrt{\frac{d+\log n}{n}}$$

using an unbounded linear function class.
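As a rough numerical illustration of how much the poly-$\log n$ factor costs, the sketch below evaluates the two rates with every constant set to 1. The constant choice and the function names `truncated_rate` and `refined_rate` are assumptions made for illustration; the text gives no numerical value for $\widetilde{C}$.

```python
import math


def truncated_rate(d: int, n: int) -> float:
    """sqrt(d * log^5(n) / n): the rate in (B.17) with the constant set to 1."""
    return math.sqrt(d * math.log(n) ** 5 / n)


def refined_rate(d: int, n: int) -> float:
    """sqrt((d + log n) / n): the rate of the finer analysis, constant set to 1."""
    return math.sqrt((d + math.log(n)) / n)


if __name__ == "__main__":
    d = 20  # hypothetical dimension
    for n in (10**3, 10**4, 10**5, 10**6):
        print(n, truncated_rate(d, n), refined_rate(d, n))
```

Both expressions vanish as $n$ grows with $d$ fixed; the gap between them comes only from the logarithmic factors.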

The results in Theorem 8 align with (up to $\log(n)$ factors) and offer significant enhancements over Theorems 2 & 3 of Fan et al., (2023). Firstly, the “invariance” condition is relaxed: we only assume that the noise $\varepsilon^{(e)}$ and the true important variables $X_{S^{\star}}^{(e)}$ are uncorrelated, rather than conditionally independent, across different environments. Meanwhile, the identification Condition 19 exactly matches that in Fan et al., (2023) (refer to Condition 5 therein), and the critical threshold $\gamma^{\star}$ is reduced, as indicated by the inequality in (B.15) and given that $\kappa_{L}=O(1)$. Such an improvement can be attributed to the term $-\frac{1}{2}\{f^{(e)}\}^{2}$ in our minimax regularization, which stabilizes the objective. To see this, for $\beta$ with $\mathrm{supp}(\beta)=S^{\star}$, the population-level EILLS objective can be written as

$$(\beta-\beta^{\star})^{\top}\Sigma(\beta-\beta^{\star})+\gamma\frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}}(\beta-\beta^{\star})_{S^{\star}}^{\top}(\Sigma^{(e)}_{S^{\star}})^{2}(\beta-\beta^{\star})_{S^{\star}},$$

where a square of the covariance matrix appears in the regularizer. This does not match the order of the covariance matrix in the empirical risk part and makes the objective less stable. Meanwhile, the population-level FAIR objective with sup-$f$ in this case is

$$(1+\gamma)(\beta-\beta^{\star})^{\top}\Sigma(\beta-\beta^{\star}),$$

in which the problem of mismatched covariance matrix order disappears.
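One possible way to read "less stable" is through the conditioning of the two quadratic forms; this reading is an illustrative assumption, not a definition taken from the text. The sketch below builds the matrices of both forms for two hypothetical environments restricted to $S^{\star}$ and compares their condition numbers. The covariance values and the helper names are made up for the example.

```python
import numpy as np


def eills_quadratic(sigmas, gamma):
    """Matrix of the population EILLS quadratic form on S*:
    pooled covariance plus gamma times the average squared covariance."""
    sigmas = np.asarray(sigmas, dtype=float)
    pooled = sigmas.mean(axis=0)
    avg_squared = np.mean([s @ s for s in sigmas], axis=0)
    return pooled + gamma * avg_squared


def fair_quadratic(sigmas, gamma):
    """Matrix of the population FAIR quadratic form: (1 + gamma) * pooled covariance."""
    sigmas = np.asarray(sigmas, dtype=float)
    return (1.0 + gamma) * sigmas.mean(axis=0)


def condition_number(matrix):
    """Largest over smallest eigenvalue of a symmetric positive definite matrix."""
    eigvals = np.linalg.eigvalsh(matrix)  # eigenvalues in ascending order
    return float(eigvals[-1] / eigvals[0])


if __name__ == "__main__":
    # Two hypothetical environments (illustrative values only).
    sigma_1 = np.array([[1.0, 0.9], [0.9, 1.0]])
    sigma_2 = 2.0 * sigma_1
    for gamma in (0.0, 1.0, 10.0, 100.0):
        print(gamma,
              condition_number(eills_quadratic([sigma_1, sigma_2], gamma)),
              condition_number(fair_quadratic([sigma_1, sigma_2], gamma)))
```

Because the FAIR form is a scalar multiple of $\Sigma$, its condition number equals that of $\Sigma$ for every $\gamma$. In this toy example the EILLS condition number grows with $\gamma$, since the squared covariances increasingly dominate.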

We have also refined the non-asymptotic $L_{2}$ error bounds. On the one hand, we can derive the error bound without imposing stronger population-level conditions (Condition 7, required by Theorem 3 in Fan et al., (2023)). On the other hand, the faster $\ell_{2}$ error bound for sufficiently large $n$ is independent of the hyper-parameter $\gamma$ we choose. These refinements result from our tighter characterization of the instance-dependent error bounds compared to the ones in Fan et al., (2023); see the discussion on technical novelties in Section B.3.

B.6.2 Augmented Linear $\mathcal{F}$

Here we consider the case where the discriminator function class $\mathcal{F}$ is potentially larger than the predictor function class $\mathcal{G}$. We introduce the following notation. We let $[x,y]$ be the concatenation of two vectors $x\in\mathbb{R}^{d_{1}}$ and $y\in\mathbb{R}^{d_{2}}$ as a $(d_{1}+d_{2})$-dimensional vector. For each $S\subseteq[d]$, we define $\widetilde{X}_{S}^{(e)}=[X_{S}^{(e)},\bar{\phi}(X_{S}^{(e)})]\in\mathbb{R}^{2|S|}$, $\widetilde{\Sigma}^{(e)}_{S}=\mathbb{E}[\widetilde{X}_{S}^{(e)}(\widetilde{X}_{S}^{(e)})^{\top}]\in\mathbb{R}^{(2|S|)\times(2|S|)}$ and let $\widetilde{X}^{(e)}=\widetilde{X}^{(e)}_{[d]}$ and $\widetilde{\Sigma}^{(e)}=\widetilde{\Sigma}^{(e)}_{[d]}$. We impose additional regularity conditions due to the incorporation of the basis function $\phi$.
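The augmented quantities can be sketched on samples as follows. This is a minimal illustration that assumes $\bar{\phi}$ applies a scalar basis function $\phi$ coordinate-wise (consistent with $\widetilde{X}_{S}^{(e)}\in\mathbb{R}^{2|S|}$); the choice $\phi(t)=t^{2}$, the Gaussian data, and the helper names `augment` and `second_moment` are hypothetical.

```python
import numpy as np


def augment(x_s, phi):
    """Concatenate samples of X_S (shape (n, |S|)) with phi applied coordinate-wise."""
    x_s = np.asarray(x_s, dtype=float)
    return np.concatenate([x_s, phi(x_s)], axis=1)


def second_moment(x_tilde):
    """Sample analogue of the uncentered second-moment matrix E[X~ X~^T]."""
    return x_tilde.T @ x_tilde / x_tilde.shape[0]


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.normal(size=(5000, 3))   # one hypothetical environment with d = 3
    support = [0, 2]                 # a candidate set S
    x_tilde_s = augment(x[:, support], np.square)
    sigma_tilde_s = second_moment(x_tilde_s)
    print(x_tilde_s.shape, sigma_tilde_s.shape)
    # The smallest eigenvalue is the sample counterpart of the lower bound in Condition 20.
    print(np.linalg.eigvalsh(sigma_tilde_s).min())
```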

Condition 20.

There exists some constant $\widetilde{\kappa}_{L}>0$ such that $\lambda_{\min}(\widetilde{\Sigma}^{(e)})\geq\widetilde{\kappa}_{L}$ for any $e\in\mathcal{E}$. Moreover, define $\widetilde{\Sigma}:=|\mathcal{E}|^{-1}\sum_{e\in\mathcal{E}}\widetilde{\Sigma}^{(e)}$. There exist some positive constants $C_{\widetilde{x}},\sigma_{\widetilde{x}}$ such that

$$\forall e\in\mathcal{E},~\forall v\in\mathbb{R}^{2d}~\text{with}~\|v\|_{2}=1,~\forall t\in[0,\infty),\qquad\mathbb{P}\left(|v^{\top}(\widetilde{\Sigma})^{-1/2}\widetilde{X}^{(e)}|\geq t\right)\leq C_{\widetilde{x}}e^{-t^{2}/(2\sigma_{\widetilde{x}}^{2})}.$$

Under Condition 20, the covariance matrices of $\widetilde{X}$ are positive definite, so we can define

$$\widetilde{\beta}^{(e,S)}=[\breve{\beta},\breve{\beta}^{\phi}]\quad\text{with}\quad(\breve{\beta},\breve{\beta}^{\phi})=\mathop{\mathrm{argmin}}_{(\beta,\beta^{\phi})\in(\mathbb{R}^{d})^{2},\ \beta_{S^{c}}=\beta_{S^{c}}^{\phi}=0}\mathbb{E}[|Y^{(e)}-\beta^{\top}X^{(e)}-(\beta^{\phi})^{\top}\bar{\phi}(X^{(e)})|^{2}],$$

and let $\widetilde{\beta}^{(e,S)}_{S}=[\breve{\beta}_{S},\breve{\beta}^{\phi}_{S}]$ be a $2|S|$-dimensional vector. The invariance and identification conditions in this case are as follows.

Condition 21 (Invariance in Linear $\mathcal{G}$ and Augmented Linear $\mathcal{F}$).

There exists some $S^{\star}\subseteq[d]$ and $\beta^{\star}\in\mathbb{R}^{d}$ with $\beta^{\star}_{(S^{\star})^{c}}=0$ and $\min_{j\in S^{\star}}|\beta_{j}^{\star}|=\beta_{\min}>0$ such that

$$\forall e\in\mathcal{E}\qquad\widetilde{\beta}^{(e,S^{\star})}=[\beta^{\star},0]. \tag{B.18}$$

Let $\varepsilon^{(e)}=Y^{(e)}-(\beta^{\star})^{\top}X^{(e)}$ be the noise. The above invariance equality (B.18) is equivalent to requiring that both $X_{S^{\star}}$ and $\bar{\phi}(X_{S^{\star}})$ are uncorrelated with the noise across all the environments, that is,

$$\forall e\in\mathcal{E}\qquad\mathbb{E}[\varepsilon^{(e)}X_{S^{\star}}^{(e)}]=\mathbb{E}[\varepsilon^{(e)}\bar{\phi}(X_{S^{\star}}^{(e)})]=0.$$
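This equivalence can be checked from the normal equations. The derivation below is a sketch that takes $S=S^{\star}$ in the definition of $\widetilde{\beta}^{(e,S)}$ and uses the invertibility of $\widetilde{\Sigma}^{(e)}_{S^{\star}}$, which holds under Condition 20 because it is a principal submatrix of the positive definite $\widetilde{\Sigma}^{(e)}$. The symbol $\theta$ is introduced here only for illustration.

```latex
\begin{align*}
\widetilde{\beta}^{(e,S^\star)}_{S^\star}
  &= \operatorname*{argmin}_{\theta\in\mathbb{R}^{2|S^\star|}}
     \mathbb{E}\big[\,|Y^{(e)}-\theta^\top\widetilde{X}^{(e)}_{S^\star}|^2\big]
   = (\widetilde{\Sigma}^{(e)}_{S^\star})^{-1}\,
     \mathbb{E}\big[\widetilde{X}^{(e)}_{S^\star}Y^{(e)}\big], \\
Y^{(e)}
  &= [\beta^\star_{S^\star},0]^\top\widetilde{X}^{(e)}_{S^\star}+\varepsilon^{(e)}
  \quad\Longrightarrow\quad
\widetilde{\beta}^{(e,S^\star)}_{S^\star}-[\beta^\star_{S^\star},0]
   = (\widetilde{\Sigma}^{(e)}_{S^\star})^{-1}\,
     \mathbb{E}\big[\widetilde{X}^{(e)}_{S^\star}\varepsilon^{(e)}\big].
\end{align*}
```

Hence the left-hand side vanishes if and only if $\mathbb{E}[\varepsilon^{(e)}X_{S^{\star}}^{(e)}]=0$ and $\mathbb{E}[\varepsilon^{(e)}\bar{\phi}(X_{S^{\star}}^{(e)})]=0$.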
Condition 22 (Identification for Linear $\mathcal{G}$ and Augmented Linear $\mathcal{F}$).

For any $S\subseteq[d]$ with $\sum_{e\in\mathcal{E}}\mathbb{E}[X_{S}^{(e)}\varepsilon^{(e)}]\neq 0$, either (1) there exists some $e\in\mathcal{E}$ such that $\widetilde{\beta}^{(e,S)}\neq[\beta^{(e,S)},0]$, or (2) there exist $e,e^{\prime}\in\mathcal{E}$ such that $\beta^{(e,S)}\neq\beta^{(e^{\prime},S)}$.
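The following simulation is a minimal sketch of how Conditions 21 and 22 can play out, under a hypothetical two-variable Gaussian model that is not taken from the text: $X_{1}$ causes $Y$ with $\beta^{\star}=(2,0)$, and $X_{2}$ is a noisy child of $Y$ whose noise level $s$ changes across two environments. The basis $\phi(t)=t^{2}$ and the helper names `simulate` and `augmented_ls` are assumptions.

```python
import numpy as np


def simulate(n, child_noise, rng):
    """One environment: X1 causes Y, and X2 is a noisy child of Y (hypothetical model)."""
    x1 = rng.normal(size=n)
    eps = rng.normal(size=n)
    y = 2.0 * x1 + eps                          # beta* = (2, 0), S* = {X1}
    x2 = y + child_noise * rng.normal(size=n)   # endogenous: E[X2 * eps] = 1
    return np.column_stack([x1, x2]), y


def augmented_ls(x, y, support, phi=np.square):
    """Least squares of y on [X_S, phi(X_S)]; returns (linear part, basis part)."""
    x_s = x[:, support]
    design = np.concatenate([x_s, phi(x_s)], axis=1)
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)
    return coef[:len(support)], coef[len(support):]


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    envs = [simulate(50_000, s, rng) for s in (0.5, 2.0)]
    for support in ([0], [0, 1]):
        for e, (x, y) in enumerate(envs):
            lin, basis = augmented_ls(x, y, support)
            print(support, e, np.round(lin, 2), np.round(basis, 2))
```

For $S=\{1\}$ the fitted vector is close to $[\beta^{\star},0]$ in both environments, as the invariance condition requires. For $S=\{1,2\}$ the basis coefficients stay near zero in this Gaussian toy model, so case (1) gives no signal, while the linear coefficient on $X_{2}$ moves with the child-noise level (its population value is $1/(1+s^{2})$), which is case (2).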

For technical convenience, we also use a truncated function class for the discriminator, defined as $\mathcal{H}_{\mathtt{alin}}(d,\phi,B)=\{\widetilde{f}=\mathrm{Tc}_{B}(f):f\in\mathcal{H}_{\mathtt{alin}}\}$.
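A member of the truncated class can be sketched as below. Two assumptions are made because the definitions sit outside this section: $\mathrm{Tc}_{B}$ is taken to clip its argument to $[-B,B]$, and $\mathcal{H}_{\mathtt{alin}}$ is taken to contain functions of the form $\beta^{\top}x+(\beta^{\phi})^{\top}\bar{\phi}(x)$. The function names, coefficient values, and $C_{2}=1.5$ are hypothetical.

```python
import numpy as np


def truncate(values, bound):
    """Assumed form of Tc_B: clip values to the interval [-bound, bound]."""
    return np.clip(values, -bound, bound)


def truncated_augmented_linear(x, beta, beta_phi, phi, bound):
    """Evaluate Tc_B(beta^T x + beta_phi^T phi(x)) row-wise for x of shape (n, d)."""
    raw = x @ beta + phi(x) @ beta_phi
    return truncate(raw, bound)


if __name__ == "__main__":
    n = 1000
    c2 = 1.5                                   # hypothetical value of the constant C_2
    bound = 2.0 * c2 * np.sqrt(np.log(n))      # discriminator bound 2 * C_2 * sqrt(log n)
    rng = np.random.default_rng(0)
    x = rng.normal(size=(n, 2))
    f = truncated_augmented_linear(
        x, np.array([1.0, -2.0]), np.array([0.5, 0.0]), np.square, bound
    )
    print(bool(np.abs(f).max() <= bound))
```

The bound grows only like $\sqrt{\log n}$, so truncation rarely binds for sub-Gaussian inputs while keeping the class uniformly bounded.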

Theorem 9 (Linear $\mathcal{G}$ and Augmented Linear $\mathcal{F}$).

Suppose Conditions 17 and 20–22 hold, and we choose

$$\mathcal{G}=\mathcal{H}_{\mathtt{lin}}(d,C_{2},C_{2}\sqrt{\log n})\qquad\text{and}\qquad\mathcal{F}=\mathcal{H}_{\mathtt{alin}}(d,\phi,2C_{2}\sqrt{\log n})$$

with some constant $C_{2}\geq 2(\sigma_{\widetilde{x}}\lor 1)\max_{e\in\mathcal{E},S\subseteq[d]}\|\widetilde{\Sigma}^{1/2}\widetilde{\beta}^{(e,S)}\|_{2}$. Then, there exists some constant $\widetilde{C}$ that only depends on $(C_{1},C_{2},\sigma_{\widetilde{x}},C_{\widetilde{x}},\sigma_{y},C_{y})$ such that the FAIR least squares estimator using the above function classes and hyper-parameter $\gamma$ satisfying $\gamma\geq 8\gamma^{\star}_{\mathtt{LA}}=8\sup_{S:\mathsf{b}_{\mathtt{LL}}(S)>0}\mathsf{b}_{\mathtt{LL}}(S)/\bar{\mathsf{d}}_{\mathtt{LA}}(S)$, where

$$\bar{\mathsf{d}}_{\mathtt{LA}}(S)=\frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}}\|\widetilde{\beta}^{(e,S)}_{S}-[\beta^{(S)}_{\dagger},0]\|_{\widetilde{\Sigma}_{S}^{(e)}}^{2}\geq\bar{\mathsf{d}}_{\mathtt{LL}}(S)\qquad\text{with}~~\beta^{(S)}_{\dagger}~\text{defined in Theorem 8}, \tag{B.19}$$

satisfies the $L_{2}$ error bound (B.16) with probability at least $1-\widetilde{C}n^{-100}$. Moreover, if $d=o((1+\gamma^{2})n/(\log^{6}n))$, for large enough $n$, the error bound (B.17) also holds with probability at least $1-\widetilde{C}n^{-100}$.

We can see that the proposed estimator utilizes both the heterogeneity among different environments and the strong prior knowledge that the true regression function admits a linear form to help the identification. It bridges the EILLS estimator in Fan et al., (2023) and the Focused GMM (FGMM) estimator in Fan & Liao, (2014) when the instrumental variables are $[X_{S},\bar{\phi}(X_{S})]$, and hence has some advantages over each of them individually. We illustrate this as follows.

  1. When there are multiple environments, $|\mathcal{E}|>1$, the identification Condition 22 is weaker than those of both the EILLS and FGMM estimators. In particular, a consistent estimate of $\beta^{\star}$ is attainable if incorporating variables $x_{j}$ with $\sum_{e\in\mathcal{E}}\mathbb{E}[X_{j}^{(e)}\varepsilon^{(e)}]\neq 0$ results in either (1) a shift in the best linear predictor across environments or (2) fitted residuals that are strongly correlated with some nonlinear basis. We refer to this as the “double identifiable” property, given that satisfying either condition can lead to consistent estimation of the true parameter. Furthermore, the critical threshold $\gamma^{\star}$ can be smaller than that of the EILLS estimator according to the inequality $\bar{\mathsf{d}}_{\mathtt{LA}}(S)\geq\bar{\mathsf{d}}_{\mathtt{LL}}(S)$. This implies that the estimation is sample efficient, in the sense that it allows for a small $\gamma$, if either the signal of the nonlinear basis or the signal of heterogeneity is strong.

  2. If there is only one environment, $|\mathcal{E}|=1$, it reduces to an estimator similar to the FGMM estimator. Consistent estimation remains feasible in this case but is impossible for the EILLS estimator. Moreover, the identification condition in this case resembles and relaxes that in Fan & Liao, (2014).

At the same time, it should be noted that the above advantages over the EILLS estimator (linear $\mathcal{F}$) come at the cost of imposing the stronger invariance Condition 21, which requires that the noise be uncorrelated not only with $X_{j}^{(e)}$ but also with $\phi(X_{j}^{(e)})$ for any $j\in S^{\star}$ and $e\in\mathcal{E}$.

B.6.3 Neural Network $\mathcal{F}$

We impose some regularity conditions on the regression function.

Condition 23.

There exist some constants $(C_{m},\sigma_{m})$ such that $m^{(e,S)}$ is $C_{m}$-Lipschitz and $|m^{(e,S)}(0)|\leq C_{m}$ for any $e\in\mathcal{E}$ and $S\subseteq[d]$, and

$$\mathbb{P}(|m^{(e,S)}(X_{S}^{(e)})|\geq t)\leq C_{m}e^{-t^{2}/(2\sigma_{m}^{2})}\qquad\forall t\in[0,\infty).$$

In this case, we consider the strongest invariance condition together with the weakest identification condition when the predictor function class $\mathcal{G}$ is linear.

Condition 24 (Invariance in Linear $\mathcal{G}$ and Neural Network $\mathcal{F}$).

There exists some $S^{\star}\subseteq[d]$ and $\beta^{\star}\in\mathbb{R}^{d}$ with $\beta^{\star}_{(S^{\star})^{c}}=0$ and $\min_{j\in S^{\star}}|\beta_{j}^{\star}|=\beta_{\min}>0$ such that

$$\forall e\in\mathcal{E}\qquad\mathbb{E}[Y^{(e)}|X_{S^{\star}}^{(e)}]\equiv(\beta^{\star})^{\top}X^{(e)}. \tag{B.20}$$
Condition 25 (Identification for Linear 𝒢𝒢\mathcal{G}caligraphic_G and Neural Network \mathcal{F}caligraphic_F).

For any S[d]𝑆delimited-[]𝑑S\subseteq[d]italic_S ⊆ [ italic_d ] with e𝔼[XS(e)ε(e)]0subscript𝑒𝔼delimited-[]superscriptsubscript𝑋𝑆𝑒superscript𝜀𝑒0\sum_{e\in\mathcal{E}}\mathbb{E}[X_{S}^{(e)}\varepsilon^{(e)}]\neq 0∑ start_POSTSUBSCRIPT italic_e ∈ caligraphic_E end_POSTSUBSCRIPT blackboard_E [ italic_X start_POSTSUBSCRIPT italic_S end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ( italic_e ) end_POSTSUPERSCRIPT italic_ε start_POSTSUPERSCRIPT ( italic_e ) end_POSTSUPERSCRIPT ] ≠ 0, either (1) there exists some e𝑒e\in\mathcal{E}italic_e ∈ caligraphic_E such that μ(e)({m(e,S)Xβ(e,S)})>0superscript𝜇𝑒superscript𝑚𝑒𝑆superscript𝑋topsuperscript𝛽𝑒𝑆0\mu^{(e)}(\{m^{(e,S)}\neq X^{\top}\beta^{(e,S)}\})>0italic_μ start_POSTSUPERSCRIPT ( italic_e ) end_POSTSUPERSCRIPT ( { italic_m start_POSTSUPERSCRIPT ( italic_e , italic_S ) end_POSTSUPERSCRIPT ≠ italic_X start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT italic_β start_POSTSUPERSCRIPT ( italic_e , italic_S ) end_POSTSUPERSCRIPT } ) > 0, or (2) there exists e,e𝑒superscript𝑒e,e^{\prime}\in\mathcal{E}italic_e , italic_e start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ∈ caligraphic_E such that β(e,S)β(e,S)superscript𝛽𝑒𝑆superscript𝛽superscript𝑒𝑆\beta^{(e,S)}\neq\beta^{(e^{\prime},S)}italic_β start_POSTSUPERSCRIPT ( italic_e , italic_S ) end_POSTSUPERSCRIPT ≠ italic_β start_POSTSUPERSCRIPT ( italic_e start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT , italic_S ) end_POSTSUPERSCRIPT.

Theorem 10 (Linear 𝒢𝒢\mathcal{G}caligraphic_G and Neural Network \mathcal{F}caligraphic_F).

Suppose Conditions 17, 23–25 hold, and we choose the function classes $\mathcal{G} = \mathcal{H}_{\mathtt{lin}}(d, C_2, C_2\sqrt{\log n})$ and $\mathcal{F} = \mathcal{H}_{\mathtt{nn}}(d, \log^d n, \log^d n, C_2\sqrt{\log n})$ with some constant $C_2 \geq (1 \lor \sigma_x \lor \sigma_m) \max_{e \in \mathcal{E}, S \subseteq [d]} \|\Sigma^{1/2} \beta^\star\|_2$.
Then, there exists some constant $\widetilde{C}$ that only depends on $(C_1, C_2, d, \sigma_m, C_m, \sigma_y, C_y, \sigma_x, C_x)$ such that the FAIR estimator using the above function classes and hyper-parameter $\gamma$ satisfying $\gamma \geq 8\gamma^\star_{\mathtt{LN}} = 8 \sup_{S : \mathsf{b}_{\mathtt{LL}}(S) > 0} \mathsf{b}_{\mathtt{LL}}(S) / \bar{\mathsf{d}}_{\mathtt{LN}}(S)$, where

$$\bar{\mathsf{d}}_{\mathtt{LN}}(S) = \frac{1}{|\mathcal{E}|} \sum_{e \in \mathcal{E}} \|m^{(e,S)} - (\beta^{(S)}_{\dagger})^\top x_S\|_{2,e} \geq \mathsf{d}_{\mathtt{LA}}(S), \tag{B.21}$$

satisfies, for large enough $n$,

$$\|\beta_{\widehat{g}} - \beta^\star\|_2 \leq \widetilde{C} (\log^{d+3} n)\, n^{-1/2}$$

with probability at least $1 - \widetilde{C} n^{-100}$.

The estimator can be viewed as an advanced version of the one using $\mathcal{F} = \mathcal{H}_{\mathtt{alin}}(d, \phi)$. It leverages neural networks to search for appropriate basis functions $\phi$ with strong signals. With a proper choice of the neural network hyper-parameters, the estimator still attains the parametric optimal rate (up to logarithmic factors). Additionally, it requires a weaker identification condition, as described by Condition 25, and a reduced critical threshold $\gamma^\star$ according to the inequality $\bar{\mathsf{d}}_{\mathtt{LN}}(S) \geq \bar{\mathsf{d}}_{\mathtt{LA}}(S)$ in Theorem 10.

Appendix C Omitted Parts in Experiments

C.1 Pseudo-code of the Gradient Descent Ascent Algorithm

Algorithm 1 FAIR Gradient Descent Ascent Training
SGD Hyper-parameters: iteration steps $T$, batch size $m$, predictor/discriminator iteration steps $T_g$/$T_f$.
FAIR Hyper-parameters: invariance regularization $\gamma$.
Annealing Hyper-parameters: initial $\tau_0$ and final $\tau_T$.
Models: predictor $g(x;\theta)$, discriminators $\{f^{(e)}(x;\phi^{(e)})\}_{e\in\mathcal{E}}$, gate $w$.
Input: data $\{\mathcal{D}^{(e)}\}_{e\in\mathcal{E}}$ with $\mathcal{D}^{(e)} = \{(x_i^{(e)}, y_i^{(e)})\}_{i=1}^{n}$ from $|\mathcal{E}|$ environments, loss function $\ell(\cdot,\cdot)$.
Output: parameters of the prediction model: $w$ and $\theta$.

Initialize $\theta, \{\phi^{(e)}\}_{e\in\mathcal{E}}$ with random weights
Set $w = 0$

for $t \in \{1,\ldots,T\}$ do
     Set $\tau_t = \tau_0 \times (\tau_T/\tau_0)^{t/T}$
     for $t_f \in \{1,\ldots,T_f\}$ do   ▷ Discriminator Ascent
         Sample $\{u_j\}_{j=1}^{d} = \{(u_{j,1}, u_{j,2})\}_{j=1}^{d}$ from $\mathrm{Gumbel}(0,1)$.
         Calculate $a = (a_1,\ldots,a_d)$ with $a_j = V_{\tau_t}(w_j, u_j)$, where $V(\cdot)$ is defined in (5.2).
         for $e \in \mathcal{E}$ do   ▷ Update $f^{(e)}$
              Sample minibatch of $m$ examples $\{(x^{(e,i)}, y^{(e,i)})\}_{i=1}^{m}$ from $\mathcal{D}^{(e)}$.
              Update the discriminator parameters $\phi^{(e)}$ by ascending the stochastic gradient
              $$\nabla_{\phi^{(e)}} \frac{\gamma}{m} \sum_{i=1}^{m} \left[ \{y^{(e,i)} - g(x^{(e,i)})\} f_{\phi^{(e)}}(x^{(e,i)}) - \frac{1}{2} \{f_{\phi^{(e)}}(x^{(e,i)})\}^2 \right]$$
              where
              $$g(x) = g(a(w) \odot x; \theta) \qquad \text{and} \qquad f_{\phi^{(e)}}(x) = f(a(w) \odot x; \phi^{(e)})$$
         end for
     end for
     for $t_g \in \{1,\ldots,T_g\}$ do   ▷ Predictor Descent
         Sample $\{u_j\}_{j=1}^{d} = \{(u_{j,1}, u_{j,2})\}_{j=1}^{d}$ from $\mathrm{Gumbel}(0,1)$.
         Calculate $a = (a_1,\ldots,a_d)$ with $a_j = V_{\tau_t}(w_j, u_j)$, where $V(\cdot)$ is defined in (5.2).
         for $e \in \mathcal{E}$ do   ▷ Enumerate Environments
              Sample minibatch of $m$ examples $\{(x^{(e,i)}, y^{(e,i)})\}_{i=1}^{m}$ from $\mathcal{D}^{(e)}$.
              Calculate the loss as a function of $\theta$ and $w$, that is,
              $$L^{(e)}(\theta, w) = \frac{\gamma}{m} \sum_{i=1}^{m} \left[ \{y^{(e,i)} - g_{w,\theta}(x^{(e,i)})\} f_w(x^{(e,i)}) - \frac{1}{2} \{f_w(x^{(e,i)})\}^2 \right] + \frac{1}{m} \sum_{i=1}^{m} \left[ \ell\left(y^{(e,i)}, g_{w,\theta}(x^{(e,i)})\right) \right]$$
              where
              $$g_{w,\theta}(x) = g(a(w) \odot x; \theta) \qquad \text{and} \qquad f_w(x) = f(a(w) \odot x; \phi^{(e)})$$
         end for
         Update the predictor weights $w, \theta$ by descending the stochastic gradient
         $$\nabla_{(\theta, w)} \sum_{e \in \mathcal{E}} L^{(e)}(\theta, w)$$
     end for
end for
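
The annealing, gating, and loss steps of Algorithm 1 can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not the authors' implementation. Since (5.2) is not reproduced in this appendix, the sketch assumes the gate takes the binary Gumbel-softmax form $V_\tau(w, u) = \mathrm{sigmoid}((w + u_1 - u_2)/\tau)$, and it uses the squared loss in place of the generic $\ell(\cdot,\cdot)$. The function names and the values $\tau_0 = 5$, $\tau_T = 0.1$ are hypothetical.

```python
import math
import random


def temperature(t, T, tau0, tauT):
    """Geometric annealing: tau_t = tau0 * (tauT / tau0) ** (t / T)."""
    return tau0 * (tauT / tau0) ** (t / T)


def sample_gumbel(rng):
    """Draw one Gumbel(0, 1) variate by inverse-CDF sampling."""
    u = rng.random()
    # Guard against log(0); rng.random() lies in [0, 1).
    u = min(max(u, 1e-12), 1.0 - 1e-12)
    return -math.log(-math.log(u))


def gumbel_gate(w, tau, rng):
    """Relaxed binary gates a_j in [0, 1], one per covariate.

    ASSUMPTION: V_tau(w_j, u_j) = sigmoid((w_j + u_{j,1} - u_{j,2}) / tau).
    """
    gates = []
    for w_j in w:
        u1, u2 = sample_gumbel(rng), sample_gumbel(rng)
        z = (w_j + u1 - u2) / tau
        gates.append(0.5 * (1.0 + math.tanh(0.5 * z)))  # stable sigmoid
    return gates


def fair_batch_objective(batch, g, f, gamma):
    """Minibatch value of L^(e): gamma * mean[(y - g) f - f^2 / 2] + mean loss.

    `g` and `f` are callables acting on an already gated covariate vector.
    ASSUMPTION: the squared loss stands in for the generic loss ell(., .).
    """
    m = len(batch)
    penalty = sum((y - g(x)) * f(x) - 0.5 * f(x) ** 2 for x, y in batch) / m
    fit = sum((y - g(x)) ** 2 for x, y in batch) / m
    return gamma * penalty + fit


rng = random.Random(0)
w = [0.0, 0.0, 0.0]  # Algorithm 1 initializes the gate logits at w = 0
a_start = gumbel_gate(w, temperature(1, 100, tau0=5.0, tauT=0.1), rng)
a_end = gumbel_gate(w, temperature(100, 100, tau0=5.0, tauT=0.1), rng)
```

As in Algorithm 1, one gate vector is drawn per inner step and shared by all environments in that step. Under the assumed form of $V_\tau$, a small final temperature pushes each gate toward $0$ or $1$, so the relaxed gates increasingly behave like a hard variable selection.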

C.2 Detailed Simulation Configuration

C.2.1 Linear Model with $d = 15$

Data Generating Process.

The data-generating process is similar to that described in Section 5.2.1. We also let $|\mathcal{E}| = 2$ and use the same procedure to generate the parent-children relationships and structural assignments, except that (1) we use $d = 15$ and let the variable $Z_8$ be $Y$; (2) we enforce that $Y$ has at least $3$ parents and $3$ children; and (3) the structural assignment for variable $Y$ is

$$Y^{(e)} = Z_8^{(e)} \leftarrow \sum_{k \in \mathtt{pa}(8)} C_{8,k} Z_k^{(e)} + C_{8,8}\, \varepsilon_8,$$

that is, we let the noise variance be the same for the two environments. This is because our simulation comparisons include ICP, which requires invariance of the conditional distribution.
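
A short sketch may help to see what the shared noise scale means in practice. It is a hypothetical illustration only: the function name `assign_y`, the parent indices, and the coefficient values are made up, and the noise $\varepsilon_8$ is assumed to be standard normal, whereas its actual law is fixed in Section 5.2.1.

```python
import random


def assign_y(z_parents, coef, noise_scale, rng):
    """Y <- sum_k C_{8,k} Z_k^{(e)} + C_{8,8} * eps_8 for a single sample.

    `coef` maps a parent index k to C_{8,k}; `noise_scale` is C_{8,8}.
    Neither depends on the environment, so the conditional law of Y given
    its parents is the same in every environment.
    ASSUMPTION: eps_8 is standard normal.
    """
    eps = rng.gauss(0.0, 1.0)
    return sum(coef[k] * z_parents[k] for k in coef) + noise_scale * eps


rng = random.Random(1)
coef = {2: 1.0, 5: -0.5, 7: 2.0}     # hypothetical parents of Z_8 and coefficients
z_env0 = {2: 0.3, 5: 1.2, 7: -0.4}   # parent values in environment e = 0
z_env1 = {2: 2.0, 5: -1.0, 7: 0.9}   # shifted parent values in environment e = 1
y0 = assign_y(z_env0, coef, noise_scale=1.0, rng=rng)
y1 = assign_y(z_env1, coef, noise_scale=1.0, rng=rng)
```

Only the distribution of the parents changes between the two calls; the coefficients and the noise scale do not, which is the distributional invariance that ICP relies on.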

Implementation.

We use the same configurations in the implementation of FAIR-GB and FAIR-RF. We also use a fixed $\gamma = 36$ for all the FAIR family estimators, including EILLS. It is worth noting that ICP, anchor regression, and IRM each introduce an additional hyper-parameter. We pick it in an oracle way for them: we enumerate all the candidate hyper-parameters and select the one that minimizes the $L_2$ estimation error. We report the performance for $n \in \{100, 200, 500, 800, 1000\}$.
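
The oracle selection described above can be sketched as follows. The helper names `l2_error` and `oracle_select` and the toy shrinkage estimator are hypothetical; the only ingredient taken from the text is that the candidate minimizing the $L_2$ estimation error is chosen, which requires knowing $\beta^\star$ and is therefore possible only in simulations.

```python
import math


def l2_error(beta_hat, beta_star):
    """Euclidean distance between an estimate and the true coefficients."""
    return math.sqrt(sum((b - s) ** 2 for b, s in zip(beta_hat, beta_star)))


def oracle_select(candidates, fit, beta_star):
    """Return (best_hyper, best_estimate) minimizing the L2 estimation error.

    `fit(h)` returns the coefficient estimate of a baseline method run with
    hyper-parameter h.
    """
    best = min(candidates, key=lambda h: l2_error(fit(h), beta_star))
    return best, fit(best)


# Toy stand-in for a baseline: a shrinkage estimator indexed by h.
toy_fit = lambda h: [1.0 / (1.0 + h), 0.0]
best_h, best_beta = oracle_select([0.0, 0.25, 1.0], toy_fit, beta_star=[0.8, 0.0])
```

Because the selection uses the ground truth, it gives each competing method its most favorable hyper-parameter and so makes the comparison conservative for the FAIR estimators, which use a fixed $\gamma$.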

Discussion of Results.

For anchor regression and IRM, their performance and the corresponding relationships w.r.t. Pool-LS are similar to the 12-variable illustration in Fan et al., (2023). Anchor regression is almost the same as Pool-LS because it is essentially standard least squares when the environments are discrete: indeed, when $|\mathcal{E}| = 2$, it just runs least squares with a different intercept for the interventional environment $e = 1$. IRM improves on vanilla least squares by slightly decreasing the bias, but the improvement is negligible compared with the bias that remains.

For ICP, the performance is even worse than pooled least squares because it collapses to conservative solutions like $0$. Note that we apply interventions to all the variables in environment $e = 1$, under which it is possible for ICP to identify $\beta^\star$ and $S^\star$ when $n = \infty$. The large estimation error it exhibits is due to its inefficiency in estimation.

We can also see that the performance of FAIR-BF and FAIR-RF is similar, demonstrating the effectiveness of our proposed gradient descent ascent algorithm with Gumbel approximation. The performance of FAIR-GF and FAIR-RF is slightly better than that of EILLS. This is because the FAIR estimator is essentially running the most efficient pooled least squares when it selects the correct variables.

C.2.2 Nonlinear Model

Data Generating Process.

For the structural assignment, we let $\varepsilon_i^{(e)} = \varepsilon_i$ for $i \leq 5$ and $\varepsilon_i^{(e)} = C_{i,i}^{(e)} \varepsilon_i$, where $(\varepsilon_1, \ldots, \varepsilon_{26})$ are independent $\mathrm{Uniform}([-1.5, 1.5])$ random variables, so that the covariates are uniformly bounded, and $C_{i,i}^{(e)}$ are scalars that are randomly generated in each trial. $\varepsilon_0$ is a standard normal random variable that is independent of $(\varepsilon_1, \ldots, \varepsilon_{26})$.

For the assignments for the children of $Y$, we let $f_{i,0}^{(e)}(u) = C_{i,0}^{(e)} \tanh(u)$, where $C_{i,0}^{(e)}$ are scalars randomly sampled from $\mathrm{Uniform}([-1.5, 1.5])$ for $e = 0$ and from $\mathrm{Uniform}([-5, 5])$ for $e = 1$, and the noise level $C_{i,i}^{(e)}$ is a scalar generated from $\mathrm{Uniform}([1, 1.5])$.
For the assignments for other variables $X_i$ with $i \geq 10$, we let $f_{i,j}^{(e)}(u) = C_{i,j}^{(e)} h_{i,j}^{(e)}(u)$, where $h_{i,j}^{(e)}$ are randomly picked from the function set $\{\tanh(x), \sin(x), \cos(x)\}$, and the noise level $C_{i,i}^{(e)}$ is a scalar generated from $\mathrm{Uniform}([2, 3])$. For $m_1^\star$, it is $\sum_{j=1}^{5} f_{0,j}(x_j)$ with each $f_{0,j}$ randomly picked from $\{\tanh(x), \sin(x), \max(0, x), x\}$.
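
The random construction of $m_1^\star$ can be sketched as below. This is an illustration under assumptions rather than the simulation code: it assumes that the $j$-th component acts on the $j$-th parent of $Y$ (an additive model), and the names `COMPONENTS` and `sample_m1` are hypothetical.

```python
import math
import random

# Candidate component functions for m_1^*, as listed in the text.
COMPONENTS = {
    "tanh": math.tanh,
    "sin": math.sin,
    "relu": lambda u: max(0.0, u),
    "identity": lambda u: u,
}


def sample_m1(rng, n_parents=5):
    """Draw one component function per parent and return (names, m1)."""
    names = [rng.choice(sorted(COMPONENTS)) for _ in range(n_parents)]

    def m1(x):
        # ASSUMPTION: additive structure, sum_j f_{0,j}(x_j) over the parents.
        return sum(COMPONENTS[name](x_j) for name, x_j in zip(names, x))

    return names, m1


rng = random.Random(0)
names, m1 = sample_m1(rng)
# Parent noises are Uniform([-1.5, 1.5]) in the text, so inputs stay bounded.
x = [rng.uniform(-1.5, 1.5) for _ in range(5)]
value = m1(x)
```

Every candidate component vanishes at zero, so any sampled $m_1^\star$ satisfies $m_1^\star(0) = 0$, which is a convenient sanity check on an implementation.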

Implementation.

For the FAIR-NN implementation using Gumbel approximation, we also run gradient descent ascent using the Adam optimizer with a learning rate of 1e-3 and batch size $64$. The number of iterations is $70k$ for $m_1^\star$ and $80k$ for $m_2^\star$. In each iteration, one gradient descent update of the neural network parameters in $g$ and the Gumbel logits parameter $w$ is conducted, followed by three gradient ascent updates of the neural network parameters in $f^{(0)}$ and $f^{(1)}$. We also use a fixed $\gamma = 36$. The implementation details for the estimators are:

  • (1) Pool-LS: it simply runs least squares on the full covariate $X$ using all the data.

  • (2) FAIR-GB: our FAIR-NN estimator with Gumbel approximation; its prediction on the test dataset is evaluated by averaging the predictions over $100$ Gumbel samples.

  • (3) FAIR-RF: it first selects the variables $x_j$ in the fitted model in (2) with $\mathrm{sig}(w_j) > t$, i.e., $\widehat{S} = \{j : \mathrm{sig}(w_j) > t\}$, and runs least squares again on $X_{\widehat{S}}$ using all the data. Here we let $t = 0.6$ for $n \leq 2000$ and $t = 0.9$ for $n > 2000$.

  • (4) Oracle: it runs least squares on $X_{S^\star}$ using all the data.
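
The variable selection step of FAIR-RF in item (3) can be sketched as follows. The sketch assumes that $\mathrm{sig}$ denotes the logistic sigmoid and uses 0-based indices; the function names are hypothetical, and the refitting on $X_{\widehat{S}}$ is omitted.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def refit_threshold(n):
    """Threshold t from the text: 0.6 for n <= 2000 and 0.9 for n > 2000."""
    return 0.6 if n <= 2000 else 0.9


def select_variables(w, n):
    """Return S_hat = {j : sig(w_j) > t} as a list of 0-based indices."""
    t = refit_threshold(n)
    return [j for j, w_j in enumerate(w) if sigmoid(w_j) > t]


w_fitted = [3.0, 0.5, -2.0, 1.0]  # hypothetical fitted Gumbel logits
small_n = select_variables(w_fitted, n=1000)
large_n = select_variables(w_fitted, n=5000)
```

With these hypothetical logits, the looser threshold keeps three variables while the stricter one keeps only the variable whose gate is close to fully open.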

For FAIR-GB, we report the estimated MSE for the model in the last iteration. For other estimators, we also run gradient descent using the Adam optimizer for 10k iterations. We report the estimated MSE for the model with early stopping regularization: that is, we report the estimated MSE of the model that has the smallest validation error, where the validation data is sampled independently from the same distribution as the training data with sample size $n_{\mathrm{valid}} = \lfloor 3n/7 \rfloor$.
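
The early stopping rule and the validation sample size can be written down directly. The sketch below assumes that validation error and test MSE are recorded at a common set of checkpoints during training; the checkpointing scheme and the function names are hypothetical, since the text does not specify how often the validation error is evaluated.

```python
def validation_size(n):
    """n_valid = floor(3n / 7)."""
    return (3 * n) // 7


def early_stopping_report(val_errors, test_mses):
    """Return (checkpoint index, test MSE) at the smallest validation error."""
    best = min(range(len(val_errors)), key=lambda k: val_errors[k])
    return best, test_mses[best]


n_valid = validation_size(1000)
checkpoint, reported_mse = early_stopping_report(
    val_errors=[0.9, 0.4, 0.5],   # hypothetical validation errors
    test_mses=[1.0, 0.45, 0.6],   # hypothetical test MSEs at the same checkpoints
)
```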

C.3 Details of the Discovery in Real Physical System Application

Data Collection

We directly use the dataset ‘lt_interventions_standard_v1’ released in Gamella et al., (2024).

For the training dataset, given a fixed sample size $n$, the data in the first environment $e=0$ is sampled from the experimental setting ‘uniform_reference’. For the second environment $e=1$, a mixture of interventions is applied. Specifically, a weak intervention is applied to one of the variables $\widetilde{V}_3, \widetilde{V}_1, \widetilde{V}_2, \widetilde{I}_1, \widetilde{I}_2$ with probability $(1/3, 1/6, 1/6, 1/6, 1/6)$, respectively. This is equivalent to sampling data from the experimental settings ‘t_vis_3_weak’, ‘t_vis_1_weak’, ‘t_vis_2_weak’, ‘t_ir_1_weak’, ‘t_ir_2_weak’ with weights $(1/3, 1/6, 1/6, 1/6, 1/6)$.
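The mixture for environment $e=1$ can be sketched by drawing, for each observation, the experimental setting it comes from. The snippet below only draws the setting labels; loading the corresponding rows from the released dataset is omitted, and the function name `sample_settings` and the seed are assumptions made for illustration.

```python
import random
from collections import Counter

# Experimental settings and mixture weights for environment e = 1.
SETTINGS = ["t_vis_3_weak", "t_vis_1_weak", "t_vis_2_weak",
            "t_ir_1_weak", "t_ir_2_weak"]
WEIGHTS = [1 / 3, 1 / 6, 1 / 6, 1 / 6, 1 / 6]


def sample_settings(n: int, seed: int = 0) -> list[str]:
    """Draw, for each of the n observations, the setting it is sampled from."""
    rng = random.Random(seed)
    return rng.choices(SETTINGS, weights=WEIGHTS, k=n)


draws = sample_settings(6000)
counts = Counter(draws)  # empirical counts per setting
```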

For the test data used for evaluation in Fig. 6 (b)–(c), we use the data from the experimental settings ‘t_vis_3_strong’, ‘t_vis_1_strong’, ‘t_vis_2_strong’, ‘t_ir_1_strong’, ‘t_ir_2_strong’. The strong interventions create an out-of-support issue, i.e.,

$$|\mathrm{Mean}_{\mu_{X,i}}(X) - \mathrm{Mean}_{\bar{\mu}_n}(X)| > 1.6 \cdot \mathrm{Std}_{\bar{\mu}_n}(X),$$

where $\mu_{X,i}$ is the empirical distribution of $X$ in the experimental setting where a strong intervention is applied to $X$, and $\bar{\mu}_n$ is the empirical distribution of $X$ in the training dataset. Thus, we recenter the variable $X$ in the corresponding test intervention environment so that it has the same empirical mean as in the training dataset.
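The out-of-support check and the recentering step can be sketched on toy numbers as below. The helper names are hypothetical, and using the population standard deviation for $\mathrm{Std}_{\bar{\mu}_n}$ is an assumption; the text does not specify the normalization.

```python
import statistics


def is_out_of_support(test_x, train_x, factor=1.6):
    """Check |mean(test) - mean(train)| > factor * std(train)."""
    gap = abs(statistics.fmean(test_x) - statistics.fmean(train_x))
    return gap > factor * statistics.pstdev(train_x)


def recenter(test_x, train_x):
    """Shift test_x so that its empirical mean equals that of train_x."""
    shift = statistics.fmean(train_x) - statistics.fmean(test_x)
    return [x + shift for x in test_x]


# Toy values (illustrative only): the intervened test variable lies far from training.
train = [0.0, 1.0, 2.0, 3.0, 4.0]
test = [9.0, 10.0, 11.0]

flag = is_out_of_support(test, train)
test_recentered = recenter(test, train)
```

Only the mean is shifted; the spread of the intervened variable in the test environment is left unchanged.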

Explanation on the Equivalent Graph

We regress $\widetilde{I}_3$ on $(R, G, B, \theta_1, \theta_2, \widetilde{V}_1, \widetilde{V}_2, \widetilde{V}_3, \widetilde{I}_1, \widetilde{I}_2)$. There are several hidden confounders. Hence, given the existence of the hidden confounders $(L_{3,1}, L_{3,2})$, there should be an arrow from $\widetilde{V}_3$ to $\widetilde{I}_3$ and an arrow from $\widetilde{I}_3$ to $\widetilde{V}_3$ if $\widetilde{V}_3$ is not intervened on. Introducing the variable $\widetilde{V}_3$ in predicting $\widetilde{I}_3$ can increase the predictive power, since it provides additional information about $(L_{3,1}, L_{3,2})$. The (equivalent) arrow from $\widetilde{V}_3$ to $\widetilde{I}_3$ disappears once $\widetilde{V}_3$ is intervened on, because the intervention perturbs this association.

Experimental Setup

For the FAIR-NN implementation using the Gumbel approximation, we again run gradient descent ascent with the Adam optimizer, using a learning rate of 1e-3 and batch size $64$. The number of iterations is $100k$. In each iteration, one gradient descent update of the neural network parameters in $g$ and the Gumbel logits parameter $w$ is conducted, followed by three gradient ascent updates of the neural network parameters in $f^{(0)}$ and $f^{(1)}$. We also use a fixed $\gamma = 36$. All estimators use the same neural network architecture as in the FAIR-NN simulation. The implementation details for all the estimators are:

  • (1) Pooled-NN: it simply runs least squares on the full covariate $X$ using all the data.

  • (2) FAIR-NN-GB: our FAIR-NN estimator with Gumbel approximation; its prediction on the test dataset is evaluated by averaging the predictions over $100$ Gumbel samples.

  • (3) FAIR-NN-RF: it first selects the variables $x_j$ in the fitted model in (2) with $\mathrm{sig}(w_j) > t$ for $t = 0.9$, i.e., $\widehat{S} = \{j : \mathrm{sig}(w_j) > t\}$, and runs least squares again on $X_{\widehat{S}}$ using all the data.

  • (4) Oracle-NN: it runs least squares on $X_{S^\star}$ using all the data with a neural network model.

  • (5) Oracle-Linear: it runs least squares on $X_{S^\star}$ using all the data with a linear model.
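The test-time averaging over Gumbel samples in item (2) can be sketched as follows. This is an illustration under stated assumptions, not the authors' implementation: it uses the standard relaxed Bernoulli (binary concrete) gate $\mathrm{sig}((w_j + G_1 - G_2)/\tau)$ with independent Gumbel noises $G_1, G_2$, a hypothetical temperature `tau`, and a toy linear predictor in place of the fitted network $g$.

```python
import math
import random


def stable_sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def gumbel(rng: random.Random) -> float:
    """Draw one Gumbel(0, 1) variate; the uniform draw is clamped away from 0 and 1."""
    u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
    return -math.log(-math.log(u))


def sample_gate(logit: float, tau: float, rng: random.Random) -> float:
    """Relaxed Bernoulli gate sig((logit + G1 - G2) / tau)."""
    return stable_sigmoid((logit + gumbel(rng) - gumbel(rng)) / tau)


def averaged_prediction(predict, x, logits, tau=0.5, n_samples=100, seed=0):
    """Average predict(gates * x) over n_samples independent gate draws."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        gates = [sample_gate(w, tau, rng) for w in logits]
        total += predict([a * xi for a, xi in zip(gates, x)])
    return total / n_samples


# Toy predictor and logits (illustrative): the first variable is kept, the second dropped.
def toy_predict(z):
    return 2.0 * z[0] + 5.0 * z[1]


y_hat = averaged_prediction(toy_predict, x=[1.0, 1.0], logits=[8.0, -8.0])
```

With strongly positive and strongly negative logits the gates concentrate near one and zero, so the averaged prediction is close to the prediction that uses only the first variable.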

The out-of-sample $R^2$ for all the estimators is reported based on model selection using a validation set that is sampled from the same source as the training data with sample size $n' = 0.6n$. This model selection is adopted to prevent over-fitting.
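The alternating update schedule in the experimental setup above (one descent update followed by three ascent updates per iteration) can be sketched on a toy scalar minimax problem. Everything about the objective is an assumption made for illustration: it is not the FAIR loss, and plain gradient steps replace Adam. Only the learning rate, $\gamma = 36$, and the 1:3 schedule are taken from the text.

```python
GAMMA = 36.0   # regularization weight, as in the setup above
LR = 1e-3      # learning rate, as in the setup above
N_ASCENT = 3   # ascent updates per descent update


def objective_grads(theta: float, phi: float) -> tuple[float, float]:
    """Gradients of the toy objective (theta - 1)^2 + GAMMA * (theta * phi - phi^2 / 2)."""
    d_theta = 2.0 * (theta - 1.0) + GAMMA * phi
    d_phi = GAMMA * (theta - phi)
    return d_theta, d_phi


def run_gda(n_iters: int = 2000) -> tuple[float, float]:
    """Gradient descent ascent: theta minimizes, phi maximizes."""
    theta, phi = 0.0, 0.0
    for _iteration in range(n_iters):
        # One descent update of the minimizing player.
        d_theta, _unused = objective_grads(theta, phi)
        theta -= LR * d_theta
        # Followed by three ascent updates of the maximizing player.
        for _step in range(N_ASCENT):
            _unused, d_phi = objective_grads(theta, phi)
            phi += LR * d_phi
    return theta, phi


theta_hat, phi_hat = run_gda()
```

For this toy objective the inner maximizer is $\phi = \theta$, so the stationary point solves $2(\theta - 1) + \gamma\theta = 0$, i.e., $\theta = 2/(2+\gamma)$, which the iteration approaches.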

C.4 Details of the Prediction Based on Extracted Features

We generate datasets by combining the bird images in the CUB dataset (Wah et al., 2011) and the background images in the Places dataset (Zhou et al., 2017) using specific probabilities, which is similar to the waterbird setting in Sagawa et al., (2020) except for the spurious correlation ratio. In each environment, there are $50\%$ waterbirds and $50\%$ landbirds. The placement probabilities in each environment are as follows:

  • (a) Environment-1. We place $95\%$ of all waterbirds against a water background, with the remaining $5\%$ against a land background. We place $90\%$ of all landbirds against a land background, with the remaining $10\%$ against a water background. The dataset is denoted by $\mathcal{D}_1$, with $50$k images.

  • (b) Environment-2. We place $75\%$ of all waterbirds against a water background, with the remaining $25\%$ against a land background. We place $70\%$ of all landbirds against a land background, with the remaining $30\%$ against a water background. The dataset is denoted by $\mathcal{D}_2$, with $50$k images.

  • (c) Environment-3 (Test Environment). We only place $2\%$ of all waterbirds against a water background, with the remaining $98\%$ against a land background. We place $2\%$ of all landbirds against a land background, with the remaining $98\%$ against a water background. The dataset is denoted by $\mathcal{D}_3$, with $30$k images.

  • (d) Environment-4 (Oracle Environment). We place $50\%$ of all waterbirds against a water background, with the remaining $50\%$ against a land background. We place $50\%$ of all landbirds against a land background, with the remaining $50\%$ against a water background. The dataset is denoted by $\mathcal{D}_4$, with $30$k images.
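The background assignment in the four environments can be sketched as a Bernoulli draw per image. The snippet below only simulates the background labels, not the image compositing; the names `sample_background` and `simulate`, the per-class sample size, and the seed are assumptions made for illustration.

```python
import random

# P(water background | waterbird) and P(land background | landbird)
# for each environment, as listed above.
ENVIRONMENTS = {
    1: {"waterbird": 0.95, "landbird": 0.90},
    2: {"waterbird": 0.75, "landbird": 0.70},
    3: {"waterbird": 0.02, "landbird": 0.02},
    4: {"waterbird": 0.50, "landbird": 0.50},
}


def sample_background(bird: str, env: int, rng: random.Random) -> str:
    """Draw a 'water' or 'land' background for a bird of the given class."""
    p_match = ENVIRONMENTS[env][bird]
    matching = "water" if bird == "waterbird" else "land"
    other = "land" if matching == "water" else "water"
    return matching if rng.random() < p_match else other


def simulate(env: int, n_per_class: int = 10000, seed: int = 0) -> dict:
    """Empirical share of matching backgrounds per class, with equal class sizes."""
    rng = random.Random(seed)
    shares = {}
    for bird, matching in (("waterbird", "water"), ("landbird", "land")):
        hits = sum(sample_background(bird, env, rng) == matching
                   for _ in range(n_per_class))
        shares[bird] = hits / n_per_class
    return shares


shares_env1 = simulate(1)
shares_env3 = simulate(3)
```

The two training environments differ in how strongly the background predicts the bird class, while the test environment reverses that association.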

Class Identification

We use the CUB dataset (Wah et al., 2011), which contains images of birds along with pixel-level segmentation masks for each bird. When generating the dataset, we classify each bird as a waterbird if it belongs to the seabird or waterfowl categories (e.g., albatross, auklet, cormorant, frigatebird, fulmar, gull, jaeger, kittiwake, pelican, puffin, tern, gadwall, grebe, mallard, merganser, guillemot, or Pacific loon), and as a landbird otherwise.
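The class rule can be sketched as a keyword lookup on the species name. The keyword set below contains only the examples listed above, which the text introduces with "e.g.", so it may be incomplete; the function name `bird_class` and the example species names are illustrative assumptions.

```python
import re

# Keywords taken from the examples listed in the text (possibly incomplete).
WATERBIRD_WORDS = {
    "albatross", "auklet", "cormorant", "frigatebird", "fulmar", "gull",
    "jaeger", "kittiwake", "pelican", "puffin", "tern", "gadwall", "grebe",
    "mallard", "merganser", "guillemot",
}
WATERBIRD_PHRASES = {"pacific loon"}


def bird_class(species_name: str) -> str:
    """Map a species name such as 'Black_footed_Albatross' to its class label."""
    # Match whole words so that, e.g., 'Western' does not trigger 'tern'.
    words = re.findall(r"[a-z]+", species_name.lower())
    normalized = " ".join(words)
    if WATERBIRD_WORDS.intersection(words):
        return "waterbird"
    if any(phrase in normalized for phrase in WATERBIRD_PHRASES):
        return "waterbird"
    return "landbird"


labels = {name: bird_class(name)
          for name in ["Black_footed_Albatross", "Pacific_Loon",
                       "Western_Meadowlark", "Blue_Jay"]}
```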

Image Generation

When picking bird images from the CUB dataset, we use the provided pixel-level segmentation masks to crop each bird from its original background. Then we decide which environment they should be placed in and select either a water background (such as ocean or lake) or a land background (such as bamboo forest or broadleaf forest) sourced from the Places dataset (Zhou et al., 2017). We randomly select $70\%$ of the images in the CUB dataset as a training set and the remaining $30\%$ as a testing set, and generate our datasets for training and testing based on this split of the CUB dataset.

Feature Extraction

Based on the dataset, we use the PyTorch torchvision implementation of the ResNet50 model (He et al., 2016) with the pre-trained weights to extract features from the images, obtaining a $2048$-dimensional feature vector for each image. Then we apply principal component analysis (PCA), fitted on the whole training data $\mathcal{D}_1$ and $\mathcal{D}_2$, to reduce the dimension of the feature vector to $500$. We apply the same dimensionality reduction transformation to the data in the other environments.

Experiment Setup

We run FAIR-Linear with Gumbel approximation on the dataset. Following the standard setting, we apply the logistic loss and the Adam optimizer using a learning rate of 1e-2, weight decay of 1e-4, and batch size $4096$ for $10000$ iterations. In each iteration, one gradient descent update of the neural network parameters in $g$ and the Gumbel logits parameter $\omega$ is conducted based on $5$ gradient ascent updates of the neural network parameters in $f$. We fix $\gamma = 200$. The implementation details for all the estimators are:

(1) Oracle: it runs logistic regression with $\ell_1$ penalty and penalty weight $\alpha = 0.001$ on the oracle environment for $1000$ iterations.

(2) Pooled Lasso: it runs logistic regression with $\ell_1$ penalty and penalty weight $\alpha = 0.001$ on Environment-1 and Environment-2 for $1000$ iterations.

(3) Lasso on D2: it runs logistic regression with $\ell_1$ penalty and penalty weight $\alpha = 0.001$ on Environment-2 for $1000$ iterations.

(4) FAIR-GB: our FAIR-Linear estimator with Gumbel approximation trained on Environment-1 and Environment-2 for $10000$ iterations.

(5) IRM: it runs Invariant Risk Minimization (IRM) trained on Environment-1 and Environment-2 with $\ell_2$ regularizer weight $0.001$ and penalty weight $100$ for $10000$ iterations.

(6) GroupDRO: it runs Group Distributionally Robust Optimization (Group-DRO) on Environment-1 and Environment-2 using ResNet50 and $\gamma = 0.1$ for $10000$ iterations.

References

  • Agarwal & Zhang, (2022) Agarwal, A. & Zhang, T. (2022). Minimax regret optimization for robust machine learning under distribution shift. In Conference on Learning Theory (pp. 2704–2729). PMLR.
  • Anthony & Bartlett, (1999) Anthony, M. & Bartlett, P. L. (1999). Neural Network Learning: Theoretical Foundations. Cambridge University Press.
  • Arjovsky et al., (2019) Arjovsky, M., Bottou, L., Gulrajani, I., & Lopez-Paz, D. (2019). Invariant risk minimization. arXiv preprint arXiv:1907.02893.
  • Athey et al., (2019) Athey, S., Tibshirani, J., & Wager, S. (2019). Generalized random forests. The Annals of Statistics, 47(2), 1148.
  • Bartlett et al., (2019) Bartlett, P. L., Harvey, N., Liaw, C., & Mehrabian, A. (2019). Nearly-tight vc-dimension and pseudodimension bounds for piecewise linear neural networks. Journal of Machine Learning Research, 20(63), 1–17.
  • Bauer & Kohler, (2019) Bauer, B. & Kohler, M. (2019). On deep learning as a remedy for the curse of dimensionality in nonparametric regression. The Annals of Statistics, 47(4), 2261–2285.
  • Breiman, (2001) Breiman, L. (2001). Statistical modeling: The two cultures (with comments and a rejoinder by the author). Statistical science, 16(3), 199–231.
  • Chernozhukov et al., (2020) Chernozhukov, V., Newey, W., Singh, R., & Syrgkanis, V. (2020). Adversarial estimation of riesz representers. arXiv preprint arXiv:2101.00009.
  • Chickering, (2002) Chickering, D. M. (2002). Optimal structure identification with greedy search. Journal of machine learning research, 3(Nov), 507–554.
  • Dikkala et al., (2020) Dikkala, N., Lewis, G., Mackey, L., & Syrgkanis, V. (2020). Minimax estimation of conditional moment models. Advances in Neural Information Processing Systems, 33, 12248–12262.
  • Duchi & Namkoong, (2021) Duchi, J. C. & Namkoong, H. (2021). Learning models with uniform performance via distributionally robust optimization. The Annals of Statistics, 49(3), 1378–1406.
  • Fan et al., (2023) Fan, J., Fang, C., Gu, Y., & Zhang, T. (2023). Environment invariant linear least squares. arXiv preprint arXiv:2303.03092.
  • Fan & Gu, (2024) Fan, J. & Gu, Y. (2024). Factor augmented sparse throughput deep relu neural networks for high dimensional regression. Journal of the American Statistical Association, to appear.
  • Fan et al., (2022) Fan, J., Gu, Y., & Zhou, W.-X. (2022). How do noise tails impact on deep relu networks? arXiv preprint arXiv:2203.10418.
  • Fan et al., (2020) Fan, J., Li, R., Zhang, C.-H., & Zou, H. (2020). Statistical foundations of data science. Chapman and Hall/CRC.
  • Fan & Liao, (2014) Fan, J. & Liao, Y. (2014). Endogeneity in high dimensions. Annals of statistics, 42(3), 872.
  • Fan & Lv, (2008) Fan, J. & Lv, J. (2008). Sure independence screening for ultrahigh dimensional feature space. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 70(5), 849–911.
  • Farrell et al., (2021) Farrell, M. H., Liang, T., & Misra, S. (2021). Deep neural networks for estimation and inference. Econometrica, 89(1), 181–213.
  • Foster & Syrgkanis, (2019) Foster, D. J. & Syrgkanis, V. (2019). Orthogonal statistical learning. arXiv preprint arXiv:1901.09036.
  • Gamella et al., (2024) Gamella, J. L., Peters, J., & Bühlmann, P. (2024). The causal chambers: Real physical systems as a testbed for ai methodology. arXiv preprint arXiv:2404.11341.
  • Gauss, (1809) Gauss, C. F. (1809). Theoria Motus Corporum Coelestium in Sectionibus Conicis Solem Ambientium. Cambridge University Press; Reissue edition (May 19, 2011).
  • Geiger & Pearl, (1990) Geiger, D. & Pearl, J. (1990). On the logic of causal models. In Machine Intelligence and Pattern Recognition, volume 9 (pp. 3–14). Elsevier.
  • Ghassami et al., (2017) Ghassami, A., Salehkaleybar, S., Kiyavash, N., & Zhang, K. (2017). Learning causal structures using regression invariance. Advances in Neural Information Processing Systems, 30.
  • Glymour et al., (2016) Glymour, M., Pearl, J., & Jewell, N. P. (2016). Causal inference in statistics: A primer. John Wiley & Sons.
  • Goodfellow et al., (2014) Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., Courville, A., & Bengio, Y. (2014). Generative adversarial nets. Advances in neural information processing systems, 27.
  • Gretton et al., (2009) Gretton, A., Smola, A., Huang, J., Schmittfull, M., Borgwardt, K., Schölkopf, B., et al. (2009). Covariate shift by kernel mean matching. Dataset shift in machine learning, 3(4), 5.
  • Györfi et al., (2002) Györfi, L., Kohler, M., Krzyżak, A., & Walk, H. (2002). A Distribution-free Theory of Nonparametric Regression, volume 1. Springer.
  • Hastie et al., (2009) Hastie, T., Tibshirani, R., & Friedman, J. (2009). The elements of statistical learning: data mining, inference, and prediction. Springer Science & Business Media.
  • He et al., (2016) He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 770–778).
  • Heinze-Deml et al., (2018) Heinze-Deml, C., Peters, J., & Meinshausen, N. (2018). Invariant causal prediction for nonlinear models. Journal of Causal Inference, 6(2).
  • Hirshberg & Wager, (2021) Hirshberg, D. A. & Wager, S. (2021). Augmented minimax linear estimation. The Annals of Statistics, 49(6), 3206–3227.
  • Hoyer et al., (2008) Hoyer, P., Janzing, D., Mooij, J. M., Peters, J., & Schölkopf, B. (2008). Nonlinear causal discovery with additive noise models. Advances in neural information processing systems, 21.
  • Hyttinen et al., (2014) Hyttinen, A., Eberhardt, F., & Järvisalo, M. (2014). Constraint-based causal discovery: Conflict resolution with answer set programming. In Conference on Uncertainty in Artificial Intelligence (pp. 340–349). AUAI Press.
  • Hyttinen et al., (2013) Hyttinen, A., Hoyer, P. O., Eberhardt, F., & Järvisalo, M. (2013). Discovering cyclic causal models with latent variables: A general sat-based procedure. In Uncertainty in Artificial Intelligence (pp. 301). Citeseer.
  • Janzing et al., (2016) Janzing, D., Chaves, R., & Schölkopf, B. (2016). Algorithmic independence of initial condition and dynamical law in thermodynamics and causal inference. New Journal of Physics, 18(9), 093052.
  • Janzing et al., (2012) Janzing, D., Mooij, J., Zhang, K., Lemeire, J., Zscheischler, J., Daniušis, P., Steudel, B., & Schölkopf, B. (2012). Information-geometric approach to inferring causal directions. Artificial Intelligence, 182, 1–31.
  • Kamath et al., (2021) Kamath, P., Tangella, A., Sutherland, D., & Srebro, N. (2021). Does invariant risk minimization capture invariance? In International Conference on Artificial Intelligence and Statistics (pp. 4069–4077). PMLR.
  • Kennedy et al., (2024) Kennedy, E. H., Balakrishnan, S., Robins, J. M., & Wasserman, L. (2024). Minimax rates for heterogeneous causal effect estimation. The Annals of Statistics, 52(2), 793–816.
  • Kohler & Langer, (2021) Kohler, M. & Langer, S. (2021). On the rate of convergence of fully connected deep neural network regression estimates. The Annals of Statistics, 49(4), 2231–2249.
  • Legendre, (1805) Legendre, A.-M. (1805). Nouvelles méthodes pour la détermination des orbites des comètes [New Methods for the Determination of the Orbits of Comets] (in French). Paris: F. Didot.
  • Liang, (2021) Liang, T. (2021). How well generative adversarial networks learn distributions. Journal of Machine Learning Research, 22(228), 1–41.
  • Lu et al., (2021) Lu, J., Shen, Z., Yang, H., & Zhang, S. (2021). Deep network approximation for smooth functions. SIAM Journal on Mathematical Analysis, 53(5), 5465–5506.
  • Meinshausen & Bühlmann, (2015) Meinshausen, N. & Bühlmann, P. (2015). Maximin effects in inhomogeneous large-scale data. The Annals of Statistics, 43(4), 1801–1830.
  • Nelder & Wedderburn, (1972) Nelder, J. A. & Wedderburn, R. W. (1972). Generalized linear models. Journal of the Royal Statistical Society Series A: Statistics in Society, 135(3), 370–384.
  • Peters et al., (2016) Peters, J., Bühlmann, P., & Meinshausen, N. (2016). Causal inference by using invariant prediction: identification and confidence intervals. Journal of the Royal Statistical Society. Series B (Statistical Methodology), (pp. 947–1012).
  • Peters et al., (2014) Peters, J., Mooij, J. M., Janzing, D., & Schölkopf, B. (2014). Causal discovery with continuous additive noise models. Journal of Machine Learning Research, 15, 2009–2053.
  • Pfister et al., (2019) Pfister, N., Bühlmann, P., & Peters, J. (2019). Invariant causal prediction for sequential data. Journal of the American Statistical Association, 114(527), 1264–1276.
  • Pfister et al., (2021) Pfister, N., Williams, E. G., Peters, J., Aebersold, R., & Bühlmann, P. (2021). Stabilizing variable selection and regression. The Annals of Applied Statistics, 15(3), 1220–1246.
  • Popper, (2005) Popper, K. (2005). The logic of scientific discovery. Routledge.
  • Raskutti et al., (2012) Raskutti, G., Wainwright, M. J., & Yu, B. (2012). Minimax-optimal rates for sparse additive models over kernel classes via convex programming. Journal of machine learning research, 13(2).
  • Richardson, (1996) Richardson, T. (1996). Feedback models: Interpretation and discovery. PhD thesis, Carnegie Mellon.
  • Robins et al., (1994) Robins, J. M., Rotnitzky, A., & Zhao, L. P. (1994). Estimation of regression coefficients when some regressors are not always observed. Journal of the American statistical Association, 89(427), 846–866.
  • Rojas-Carulla et al., (2018) Rojas-Carulla, M., Schölkopf, B., Turner, R., & Peters, J. (2018). Invariant models for causal transfer learning. The Journal of Machine Learning Research, 19(1), 1309–1342.
  • Rosenfeld et al., (2021) Rosenfeld, E., Ravikumar, P., & Risteski, A. (2021). The risks of invariant risk minimization. International Conference on Learning Representations.
  • Rothenhäusler et al., (2019) Rothenhäusler, D., Bühlmann, P., & Meinshausen, N. (2019). Causal dantzig: fast inference in linear structural equation models with hidden variables under additive interventions. The Annals of Statistics, 47(3), 1688–1722.
  • Rothenhäusler et al., (2021) Rothenhäusler, D., Meinshausen, N., Bühlmann, P., & Peters, J. (2021). Anchor regression: Heterogeneous data meet causality. Journal of the Royal Statistical Society. Series B, Statistical Methodology, 83(2), 215–246.
  • Rubin, (1974) Rubin, D. B. (1974). Estimating causal effects of treatments in randomized and nonrandomized studies. Journal of educational Psychology, 66(5), 688.
  • Sagawa et al., (2020) Sagawa, S., Koh, P. W., Hashimoto, T. B., & Liang, P. (2020). Distributionally robust neural networks for group shifts: On the importance of regularization for worst-case generalization. International Conference on Learning Representations.
  • Schmidt-Hieber, (2020) Schmidt-Hieber, J. (2020). Nonparametric regression using deep neural networks with relu activation function (with discussion). The Annals of Statistics, 48(4), 1875–1921.
  • Shimizu et al., (2006) Shimizu, S., Hoyer, P. O., Hyvärinen, A., Kerminen, A., & Jordan, M. (2006). A linear non-gaussian acyclic model for causal discovery. Journal of Machine Learning Research, 7(10).
  • Spirtes et al., (2000) Spirtes, P., Glymour, C. N., & Scheines, R. (2000). Causation, prediction, and search. MIT press.
  • Van de Geer, (2008) Van de Geer, S. A. (2008). High-dimensional generalized linear models and the lasso. The Annals of Statistics, 36(2), 614–645.
  • Wah et al., (2011) Wah, C., Branson, S., Welinder, P., Perona, P., & Belongie, S. (2011). The caltech-ucsd birds-200-2011 dataset.
  • Wainwright, (2019) Wainwright, M. J. (2019). High-dimensional statistics: A non-asymptotic viewpoint, volume 48. Cambridge University Press.
  • Yin et al., (2021) Yin, M., Wang, Y., & Blei, D. M. (2021). Optimization-based causal estimation from heterogenous environments. arXiv preprint arXiv:2109.11990.
  • Yuan & Zhou, (2016) Yuan, M. & Zhou, D.-X. (2016). Minimax optimal rates of estimation in high dimensional additive models. The Annals of Statistics, 44(6), 2564–2593.
  • Zhang & Hyvärinen, (2009) Zhang, K. & Hyvärinen, A. (2009). On the identifiability of the post-nonlinear causal model. In 25th Conference on Uncertainty in Artificial Intelligence (UAI 2009) (pp. 647–655). AUAI Press.
  • Zhou et al., (2017) Zhou, B., Lapedriza, A., Khosla, A., Oliva, A., & Torralba, A. (2017). Places: A 10 million image database for scene recognition. IEEE transactions on pattern analysis and machine intelligence, 40(6), 1452–1464.