Stable Heterogeneous Treatment Effect Estimation across Out-of-Distribution Populations

Yuling Zhang (1), Anpeng Wu (3, 4), Kun Kuang (3), Liang Du (5), Zixun Sun (5), and Zhi Wang (1, 2)
Kun Kuang and Zhi Wang are the corresponding authors.
(1) Tsinghua Shenzhen International Graduate School, Tsinghua University
(2) Tsinghua-Berkeley Shenzhen Institute, Tsinghua University
(3) College of Computer Science and Technology, Zhejiang University
(4) Machine Learning Department, Mohamed bin Zayed University of Artificial Intelligence
(5) Interactive Entertainment Group, Tencent
Email: [email protected], {anpwu,kunkuang}@zju.edu.cn, {lucasdu,zixunsun}@tencent.com, [email protected]

Abstract

Heterogeneous treatment effect (HTE) estimation is vital for understanding how treatment effects vary across individuals or subgroups. Most existing HTE estimation methods focus on addressing selection bias induced by imbalanced distributions of confounders between treated and control units, but ignore distribution shifts across populations. Consequently, their applicability has been limited to the in-distribution (ID) population, which shares a similar distribution with the training dataset. In real-world applications, where population distributions are subject to continuous changes, there is an urgent need for stable HTE estimation across out-of-distribution (OOD) populations, which, however, remains an open problem. As pioneers in resolving this problem, we propose a novel Stable Balanced Representation Learning with Hierarchical-Attention Paradigm (SBRL-HAP) framework, which consists of 1) a Balancing Regularizer for eliminating selection bias, 2) an Independence Regularizer for addressing the distribution shift issue, and 3) a Hierarchical-Attention Paradigm for coordination between balance and independence. In this way, SBRL-HAP regresses counterfactual outcomes using ID data, while ensuring that the resulting HTE estimation can be successfully generalized to out-of-distribution scenarios, thereby enhancing the model’s applicability in real-world settings. Extensive experiments conducted on synthetic and real-world datasets demonstrate the effectiveness of our SBRL-HAP in achieving stable HTE estimation across OOD populations, with an average 10% reduction in the error metric PEHE and an 11% decrease in the ATE bias, compared to the SOTA methods.

Index Terms:
Heterogeneous Treatment Effect; Out-of-Distribution; Balanced Representation Learning; Hierarchical-Attention Optimization

I Introduction

Figure 1: Two main challenges in stable HTE estimation across OOD populations: (C1) selection bias from imbalanced treatment assignment, and (C2) distribution shift across various populations. The former is manifested as imbalanced distributions of covariates (e.g., age) between treated (i.e., T=1) and control (i.e., T=0) units in a specific population. The latter occurs frequently in real-world applications, resulting in out-of-distribution populations that have distinct covariate distributions from the training dataset. This work is among the first to synergistically resolve both selection bias and distribution shift.

Estimating Heterogeneous Treatment Effects (HTE) from observational data has gained increasing importance across various fields [1], including medicine, economics, and marketing [2, 3, 4, 5]. It provides practitioners with valuable insights into how treatment effects vary among different subpopulations, ultimately enabling personalized healthcare and explainable decision-making. However, reliable and robust estimation of HTE still faces significant challenges. One primary challenge in observational data is non-random treatment assignment, which can lead to imbalanced covariate distributions between treated and control units (top panel in Fig. 1). Taking healthcare as an example, in the study of the effect of treatment on outcomes, physicians may assign different treatment recommendations (e.g., taking the drug or not) based on the patient’s individual circumstances (e.g., age). Typically, physicians recommend that young individuals take the medication more often while advising older individuals not to take it. Such imbalanced treatment assignment can result in selection bias [6], manifested as differences in age distribution between the treated group and the control group. As selection bias has been taken seriously by academia and industry [7, 8, 9, 10], various methods such as propensity score matching, doubly robust estimation, stratification, inverse probability of treatment weighting (IPTW) [11, 12, 13, 14], and representation learning methods [15, 16, 17, 18, 19, 20, 21, 22] have been developed to reduce selection bias and estimate treatment effects more accurately.

However, one limitation is that these methods have only been tested and validated on data that is similar to the training data, known as in-distribution (ID) data. In real-world applications, where data or population distributions, specifically the covariate distributions, are subject to continuous changes [23, 24, 25, 26, 27], there is a concern regarding the performance of these methods when applied to populations with different covariate distributions compared to the training dataset [28, 29, 30, 31, 32]. This issue, referred to as distribution shift [33, 34], poses another significant challenge to achieving stable HTE estimation for out-of-distribution (OOD) populations. As shown in Fig. 1, the distribution of patients’ circumstances may change over time, seasons, holidays, urban and rural areas, etc., resulting in the emergence of various populations. These populations may have different data distributions and characteristics compared to the training data, and they may even include individuals that were not present in the training data. Due to inductive bias, the causal relations learnt from training data (e.g., data collected during weekdays) are typically not applicable to testing data (e.g., data collected during weekends). If we directly use the above causal models trained on one specific dataset, it may lead to unstable and unreliable HTE estimation for other populations. Such unreliability of HTE estimation can lead to inappropriate treatment choices, posing a huge threat to patients’ health and even resulting in catastrophic medical events. Therefore, there is an urgent demand to develop stable HTE estimation methods that can effectively generalize to unseen samples or different populations.

In this paper, we first study the problem of stable HTE estimation across OOD populations, and systematically review the two main challenges (Fig. 1): (C1) selection bias from imbalanced treatment assignment, and (C2) distribution shift across various populations. Selection bias in observational data can lead to unreliable and biased HTE estimation. Although many previous causal methods have been proposed to eliminate selection bias, they still suffer from the distribution shift issue, resulting in higher error and unstable estimates of HTE on out-of-distribution populations.

To address the selection bias, Balanced Representation Learning (BRL) has been developed to map the original covariates to a representation space and narrow the representation discrepancies across different treatment arms [15, 16, 17]. This approach enables accurate HTE estimation within the in-distribution data. Nevertheless, in the presence of distribution shifts across various populations, the problem of stable HTE estimation across OOD populations remains relatively unexplored. Among the many OOD generalization algorithms, Stable Learning (SL) stands out as a promising approach [35, 36, 37] based on the following observation. For general machine learning models, model crashes under distribution shifts are mainly caused by the unstable correlation between irrelevant features and the target outcome. This kind of unstable correlation fundamentally stems from the statistical dependence between relevant and irrelevant features [38, 39, 40]. Therefore, to address distribution shift and maintain performance across OOD data, SL methods propose to decorrelate all features by sample reweighting, facilitating models to recognize stable and invariant relationships between features and outcomes.

Building upon these methods, we propose a novel framework called Stable Balanced Representation Learning with Hierarchical-Attention Paradigm (SBRL-HAP), which comprises three core components: (a) Balancing Regularizer (BR) to eliminate selection bias and obtain balanced representations; (b) Independence Regularizer (IR) to reweight samples and enforce independence between features, addressing the distribution shift issue; (c) Hierarchical-Attention Paradigm (HAP) to assign distinct priorities to each neural network layer for comprehensive feature decorrelation throughout the learning process. Notably, in the training process of BR and IR, the learning of balanced representations and independence-driven weights can be interdependent. For instance, when the representations change, the learned weights must adapt accordingly. In such cases, optimizing one objective may come at the expense of the other. Consequently, we design a Hierarchical-Attention Paradigm to synergistically facilitate the learning of balanced representations and independence-driven weights, thereby alleviating conflicts. To differentiate, we refer to the model without the Hierarchical-Attention Paradigm as SBRL.

The primary contributions of this paper are threefold:

  • We are the first to investigate the problem of stable heterogeneous treatment effect estimation across out-of-distribution populations, and we pioneer the integration of representation balancing and stable training techniques.

  • We propose a novel SBRL-HAP framework in which the Hierarchical-Attention Paradigm eliminates selection bias and addresses distribution shifts through comprehensive decorrelation in a hierarchical manner. This flexible framework enables the extension of any existing representation balancing method to various OOD populations.

  • Extensive experiments conducted on synthetic and real-world data demonstrate the effectiveness of our SBRL-HAP in achieving stable HTE estimation across OOD populations, compared to the SOTA methods. On the OOD datasets, our SBRL-HAP reduces the error metric PEHE by 10% on average compared with the best baseline, and reduces the ATE bias by up to 14%.

II Related Work

Representation Balancing to Mitigate Selection Bias. Many prior works have concentrated on addressing the challenges of estimating heterogeneous treatment effects from observational data while mitigating selection bias, with a promising method being balanced representation learning [14, 41]. This method minimizes the distribution distance between treated and control groups, effectively balancing confounders and producing similar distributions in the representation space, ultimately improving prediction accuracy for heterogeneous treatment effects. Specifically, representation balancing methods can be broadly categorized into five groups: 1) Fundamental methods, such as CFR [15, 42], which learn balanced representation by directly minimizing distribution distance between the treated and control groups; 2) Reweighting methods, such as RCFR [43] and CFR-ISW [16], which incorporate information from treatments and use importance sampling techniques to further mitigate the negative impact of selection bias; 3) Similarity-based methods, such as SITE [22] and ACE [21], which focus on learning balanced representations while preserving similarity information among data points; 4) Subgroup methods, such as HNN [44] and SCI [20], which enhance the model’s predictive ability by identifying and partitioning sub-spaces within the representation; and 5) Decomposition methods, such as DR-CFR [18] and DeR-CFR [17], which separate confounders from pre-treatment variables to achieve precise balancing of covariates. These methods have proven successful in estimating treatment effects without taking distribution shifts into account, but they may be prone to performance degradation in OOD scenarios.

Stable Learning to Eliminate Distribution Shifts. Distribution shifts across distinct populations in HTE estimation are not as well explored, and stable learning is a promising approach to address the distribution shift issue [35]. Taking inspiration from variable balancing strategies in causal inference [45, 46, 47], stable learning eliminates dependence among covariates via sample reweighting to manifest causation, thus utilizing the stability of causation to achieve generalization. Recently, several studies, including [48, 49, 50, 39], have aimed to tackle the discrepancy between the training and testing distributions stemming from datasets collected at different time periods or platforms. These approaches have the potential to handle distribution shifts in HTE estimation. Among them, CRLR [48] addresses distribution shifts by simultaneously optimizing global confounder balancing and weighted logistic regression to estimate the causal effect of each variable on the outcome. However, CRLR requires that all the features and labels be binary, which is impractical in real-world applications. To overcome this limitation, DWR [49] proposes to utilize the statistical independence condition to enforce that variables are independent of each other, thereby relaxing the binary restriction. Furthermore, SRDO [50] constructs an uncorrelated design matrix from the original covariates to alleviate the issue of collinearity among variables. StableNet [39], in turn, goes beyond the linear case and addresses both linear and non-linear dependencies between variables using Random Fourier Features and the Hilbert-Schmidt Independence Criterion. Overall, stable learning techniques aim to realize model generalization across any distribution by excavating stable relationships through feature decorrelation.

Although Representation Balancing [15] and Stable Learning [39] can address selection bias from imbalanced treatment assignment and distribution shift across populations, respectively, their optimization objectives are not orthogonal. The learning of weights and representations can interfere with each other, which is why few works have discussed stable HTE estimation across populations. Considering the increasing importance of stable HTE estimation, this work proposes a novel framework named SBRL-HAP, in which the Hierarchical-Attention Paradigm coordinates the Balancing Regularizer and Independence Regularizer to extract balanced and stable representations, thus bridging these two topics.

III Problem Setup and Assumptions

III-A Problem Setup

In this paper, we study heterogeneous treatment effect estimation across multiple populations. For simplicity, we consider that the population used for training the model is drawn from environment $e\in\mathcal{E}$ and the target population is from environment $e^{\prime}\in\mathcal{E}$. Taking healthcare as an example, as illustrated in Fig. 1, we gather an observational dataset $D^{e}=\{\mathbf{X}^{e},T^{e},Y^{e}\}=\{\mathbf{x}_{i}^{e},t_{i}^{e},y_{i}^{t_{i},e}\}_{i=1}^{n}$ from urban hospitals represented by the environment $e$, where $\mathbf{x}^{e}_{i}\in\mathcal{X}$ denotes the covariates (e.g., patients' circumstances), $t_{i}^{e}\in\{0,1\}$ denotes the received treatment (e.g., taking the drug or not), and $y_{i}^{t_{i},e}\in\mathcal{Y}$ is the observed outcome corresponding to the treatment $t_{i}^{e}$. Then, in the target environment $e^{\prime}\in\mathcal{E}$ different from $e$, such as a remote village, we have a potential population denoted as $D^{e^{\prime}}=\{\mathbf{X}^{e^{\prime}}\}$. This dataset only includes the covariates $\mathbf{x}$ of the individuals in the target population, without the corresponding treatment or outcome information. Our goal is to learn a causal model from the dataset $D^{e}$ which enables accurate HTE estimation for target populations from different environments $e^{\prime}\in\mathcal{E}$. We refer to this problem as Heterogeneous Treatment Effect Estimation across Out-of-Distribution Populations.

Our work focuses on the Heterogeneous Treatment Effect at the individual level, i.e., Individual Treatment Effect (ITE), and Average Treatment Effect (ATE) at the population level.

Definition 3.1 (Individual Treatment Effect).

Given any environment $e\in\mathcal{E}$, the Individual Treatment Effect of unit $i$ is:

$$ITE_{i}^{e}=y_{i}^{1,e}-y_{i}^{0,e}, \tag{1}$$

where $y_{i}^{1,e}$ and $y_{i}^{0,e}$ are potential outcomes.

Definition 3.2 (Average Treatment Effect).

Given any environment $e\in\mathcal{E}$, the Average Treatment Effect of $D^{e}$ is:

$$ATE^{e}=\mathbb{E}[Y^{1,e}-Y^{0,e}]=\frac{1}{n}\sum_{i=1}^{n}(y_{i}^{1,e}-y_{i}^{0,e}). \tag{2}$$
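As a small numerical illustration of Eq. 1 and Eq. 2, the following sketch computes the ITE of each unit and the resulting ATE. The helper names `ite` and `ate` and the four-unit toy outcomes are hypothetical; both potential outcomes are visible here only because the data are made up, whereas observational data reveal just one of them per unit.

```python
import numpy as np

def ite(y1, y0):
    """Individual treatment effect per unit: y_i^1 - y_i^0 (Eq. 1)."""
    return np.asarray(y1, dtype=float) - np.asarray(y0, dtype=float)

def ate(y1, y0):
    """Average treatment effect: sample mean of the ITEs (Eq. 2)."""
    return float(ite(y1, y0).mean())

# Hypothetical potential outcomes for four units (toy values, not from the paper).
y1 = np.array([3.0, 2.0, 5.0, 4.0])   # outcome if treated
y0 = np.array([1.0, 2.0, 2.0, 3.0])   # outcome if not treated

tau = ite(y1, y0)       # array([2., 0., 3., 1.])
tau_bar = ate(y1, y0)   # 1.5
```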

III-B Assumptions

Given the training data $D^{e}=\{\mathbf{X}^{e},T^{e},Y^{e}\}$ from environment $e$, our goal is to find a regressor $f(\cdot):\mathcal{X}\times\mathcal{T}\rightarrow\mathcal{Y}$ capable of precisely predicting potential outcomes across different OOD environments $e^{\prime}\in\mathcal{E}$. To eliminate the selection bias in $D^{e}$, existing causal models rely on standard assumptions [51].

Assumption 3.1 (Stable Unit Treatment Value).

The distribution of the potential outcome of one unit is assumed to be independent of the treatment assignment of another unit.

Assumption 3.2 (Unconfoundedness).

The distribution of treatment is independent of the potential outcomes given the covariates. Formally, $T\perp(Y^{0},Y^{1})\,|\,\mathbf{X}$.

Assumption 3.3 (Overlap).

Every unit should have a nonzero probability of receiving either treatment status. Formally, $0<p(T=1|\mathbf{X})<1$.

Additionally, without any prior knowledge or structural assumptions, it is impossible to solve the distribution shift problem, since one cannot characterize the rare or unseen latent environments [52]. Therefore, we follow the assumption commonly used in studies of distribution shift [53, 52, 49].

Assumption 3.4 (Stable Representation).

There exists a stable representation $\Psi^{s}(\mathbf{X})$ of covariates $\mathbf{X}$ such that for any environment $e\in\mathcal{E}$, $\mathbb{E}[Y,T|\mathbf{X}^{e}]=\mathbb{E}[Y,T|\Psi^{s}(\mathbf{X}^{e})]$ holds.

This assumption implies that the covariates $\mathbf{X}$ comprise two parts: relevant features that have causal effects on the outcome $Y$, known as stable features $\mathbf{X}_{S}$; and irrelevant features (i.e., unstable features $\mathbf{X}_{V}$) that have $P^{e}(Y|\mathbf{X}_{V})\neq P^{e^{\prime}}(Y|\mathbf{X}_{V})$ and create instability for prediction. The existence of $\mathbf{X}_{S}$ makes it possible to precisely predict the outcome $Y$ using $\Psi^{s}(\mathbf{X})$, which is known as the stable representation with invariant relationships to the outcome $Y$ across different environments $e\in\mathcal{E}$ [52, 53].

Challenges. Overall, we formally discuss the challenges in stable HTE estimation across OOD populations. (C1) Selection bias refers to the inconsistent distribution of covariates between different treatment arms in a specific environment $e$, i.e., $P^{e}(\mathbf{X}^{t})\neq P^{e}(\mathbf{X}^{c})$, where $\mathbf{X}^{t}=\{\mathbf{x}_{i:t_{i}=1}\}$ and $\mathbf{X}^{c}=\{\mathbf{x}_{i:t_{i}=0}\}$. (C2) Distribution shift indicates that the marginal distribution of $\mathbf{X}$ shifts across environments while the conditional distribution $P(T,Y|\mathbf{X})$ remains unchanged. That is, $\forall e,e^{\prime}\in\mathcal{E}$, $P^{e}(T,Y|\mathbf{X})=P^{e^{\prime}}(T,Y|\mathbf{X})$ and $P^{e}(\mathbf{X})\neq P^{e^{\prime}}(\mathbf{X})$. One naive method to address selection bias and the issue of distribution shift is to combine representation-based methods and stable learning techniques. To this end, we propose Stable Balanced Representation Learning (SBRL) to estimate HTE across various populations. However, a novel challenge arises, i.e., (C3) the learning of balanced representations and independence-driven weights in SBRL can be interdependent, hence restricting the generalization performance of stable HTE estimation. It should be noted that current stable learning techniques, designed for typical prediction tasks, learn sample weights by decorrelating the last layer of the network, while balanced representations are required in the first half of the network. Once the balanced representations are updated, adaptive weight modification is necessary, which, however, cannot guarantee feature independence for generalization. As a result, prioritizing the optimization of one objective may come at the cost of the other, making it difficult to achieve stable HTE estimation across environments.
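To make (C1) and (C2) concrete, the following is a minimal simulation sketch, not the paper's data-generating process. It follows the healthcare example with a single covariate (age); the function `sample_env`, the logistic treatment rule, the outcome model, and all constants are assumptions chosen for illustration. Both environments share the same mechanisms for $T$ and $Y$ given $\mathbf{X}$, and only the marginal distribution of age differs.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_env(n, age_mean, rng):
    """Draw one hypothetical environment; only P(X) depends on age_mean."""
    age = rng.normal(age_mean, 10.0, size=n)
    # Shared mechanism P(T | X): younger patients are treated more often.
    p_treat = 1.0 / (1.0 + np.exp((age - 50.0) / 5.0))
    t = rng.binomial(1, p_treat)
    # Shared mechanism P(Y | X, T): the treatment effect shrinks with age.
    y = 0.1 * age + t * (5.0 - 0.05 * age) + rng.normal(0.0, 1.0, size=n)
    return age, t, y

age_e, t_e, y_e = sample_env(5000, 45.0, rng)      # training environment e
age_e2, t_e2, y_e2 = sample_env(5000, 65.0, rng)   # target environment e'

# (C1) Selection bias: covariate gap between treatment arms inside e.
gap_arms = age_e[t_e == 1].mean() - age_e[t_e == 0].mean()
# (C2) Distribution shift: covariate gap between the two environments.
gap_envs = age_e2.mean() - age_e.mean()
```

In this toy setup `gap_arms` is negative (treated units are younger on average) and `gap_envs` is close to the 20-year difference in mean age. In the problem setup of Section III-A, treatments and outcomes in the target environment would not be observed; they are generated here only because the data are simulated.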

To overcome the above challenges, we propose a novel framework named Stable Balanced Representation Learning with Hierarchical-Attention Paradigm (SBRL-HAP) which settles the conflict between balance and independence in a holistic and hierarchical manner.

IV Methodology

Figure 2: The framework of Stable Balanced Representation Learning with Hierarchical-Attention Paradigm (SBRL-HAP). SBRL-HAP consists of three modules: i. Balancing Regularizer restricts IPM for balanced representation, ii. Independence Regularizer eliminates feature dependence measured by HSIC-RFF for generalization, and iii. Hierarchical-Attention Paradigm decorrelates features comprehensively with a hierarchy for dispelling the interaction between balance and independence. With high flexibility, SBRL-HAP can be plugged into most balanced representation methods by replacing the neural network backbone.

In this section, we propose SBRL-HAP to stably estimate heterogeneous treatment effects across OOD populations. Firstly, we will present the overall framework of our SBRL-HAP. Subsequently, we will offer a thorough description of three components of SBRL-HAP. Finally, we will demonstrate the end-to-end optimization and training strategies.

Fig. 2 depicts the overall architecture of our SBRL-HAP which consists of three components for stable HTE estimation:

  • Balancing Regularizer (BR) employs Integral Probability Metrics (IPM) [54, 55] to measure the distribution discrepancy between the treated and control group, and proposes to adopt a model-free method to narrow the distribution discrepancy, so as to eliminate selection bias and obtain balanced representations.

  • Independence Regularizer (IR) learns sample weights to remove non-linear dependencies between features by utilizing the Hilbert-Schmidt Independence Criterion [56] with Random Fourier Features [57], thereby facilitating the identification of the stable relationships between features and potential outcomes.

  • Hierarchical-Attention Paradigm (HAP) emphasizes assigning distinct priorities to each neural network layer, in order to achieve comprehensive feature decorrelation with hierarchical attention. Therefore, HAP harmoniously integrates the Balancing Regularizer and Independence Regularizer, effectively resolving the conflict between balance and independence.
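The dependence measure that the Independence Regularizer drives toward zero can be sketched as follows. This is a simplified illustration in the spirit of HSIC with Random Fourier Features, assuming scalar features, cosine features with standard normal frequencies, and a squared Frobenius norm of the weighted cross-covariance; the function names `rff` and `weighted_dependence` are hypothetical, and the actual regularizer may differ in its details. The sample weights are taken as given here, whereas in training they would be learned to minimize this quantity.

```python
import numpy as np

def rff(x, n_feats, rng):
    """Random Fourier features sqrt(2) * cos(omega * x + phi) of a 1-D variable."""
    omega = rng.normal(size=n_feats)
    phi = rng.uniform(0.0, 2.0 * np.pi, size=n_feats)
    return np.sqrt(2.0) * np.cos(np.outer(x, omega) + phi)

def weighted_dependence(a, b, weights, n_feats=5, seed=0):
    """Squared Frobenius norm of the weighted cross-covariance of the RFFs."""
    rng = np.random.default_rng(seed)
    w = weights / weights.sum()            # normalize the sample weights
    u, v = rff(a, n_feats, rng), rff(b, n_feats, rng)
    u_c = u - w @ u                        # center with weighted means
    v_c = v - w @ v
    cov = (u_c * w[:, None]).T @ v_c       # weighted cross-covariance matrix
    return float(np.sum(cov ** 2))

# Toy check with uniform weights (hypothetical data).
rng = np.random.default_rng(1)
a = rng.normal(size=2000)
b = a + 0.1 * rng.normal(size=2000)   # strongly dependent on a
c = rng.normal(size=2000)             # independent of a
uniform = np.ones(2000)
d_dep = weighted_dependence(a, b, uniform)
d_ind = weighted_dependence(a, c, uniform)
```

With uniform weights, the score for the dependent pair should be much larger than for the independent pair; reweighting aims to shrink such scores for all feature pairs.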

Next, we will describe each component of our SBRL-HAP model in detail, and then demonstrate the end-to-end optimization and training strategy.

IV-A Balancing Regularizer

The Balancing Regularizer is designed to eliminate selection bias and obtain a balanced representation by reducing the distribution discrepancy between different treatment arms with a model-free method. A typical metric used for measuring the distribution discrepancy is the Integral Probability Metric (IPM) [54, 55], which is formally defined as

$$dist(P_{\Phi_{c}},P_{\Phi_{t}})=\sup\limits_{f\in\mathcal{F}}\left|\mathbb{E}_{x\sim P_{\Phi_{c}}}[f(x)]-\mathbb{E}_{x\sim P_{\Phi_{t}}}[f(x)]\right|, \tag{3}$$

where PΦc={Φ(𝐱i)}i:ti=0subscript𝑃subscriptΦ𝑐subscriptΦsubscript𝐱𝑖:𝑖subscript𝑡𝑖0P_{\Phi_{c}}=\{\Phi(\mathbf{x}_{i})\}_{i:t_{i}=0}italic_P start_POSTSUBSCRIPT roman_Φ start_POSTSUBSCRIPT italic_c end_POSTSUBSCRIPT end_POSTSUBSCRIPT = { roman_Φ ( bold_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) } start_POSTSUBSCRIPT italic_i : italic_t start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT = 0 end_POSTSUBSCRIPT and PΦt={Φ(𝐱i)}i:ti=1subscript𝑃subscriptΦ𝑡subscriptΦsubscript𝐱𝑖:𝑖subscript𝑡𝑖1P_{\Phi_{t}}=\{\Phi(\mathbf{x}_{i})\}_{i:t_{i}=1}italic_P start_POSTSUBSCRIPT roman_Φ start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT end_POSTSUBSCRIPT = { roman_Φ ( bold_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) } start_POSTSUBSCRIPT italic_i : italic_t start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT = 1 end_POSTSUBSCRIPT denote the covariate distribution of the control group and the treated group in the representation space ΦΦ\Phiroman_Φ, respectively. For rich enough function families \mathcal{F}caligraphic_F, dist(PΦc,PΦt)=0PΦc=PΦt𝑑𝑖𝑠𝑡subscript𝑃subscriptΦ𝑐subscript𝑃subscriptΦ𝑡0subscript𝑃subscriptΦ𝑐subscript𝑃subscriptΦ𝑡dist(P_{\Phi_{c}},P_{\Phi_{t}})=0\Rightarrow P_{\Phi_{c}}=P_{\Phi_{t}}italic_d italic_i italic_s italic_t ( italic_P start_POSTSUBSCRIPT roman_Φ start_POSTSUBSCRIPT italic_c end_POSTSUBSCRIPT end_POSTSUBSCRIPT , italic_P start_POSTSUBSCRIPT roman_Φ start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT end_POSTSUBSCRIPT ) = 0 ⇒ italic_P start_POSTSUBSCRIPT roman_Φ start_POSTSUBSCRIPT italic_c end_POSTSUBSCRIPT end_POSTSUBSCRIPT = italic_P start_POSTSUBSCRIPT roman_Φ start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT end_POSTSUBSCRIPT holds [15]. 
Most previous works constrain the IPM dist(PΦc,PΦt)𝑑𝑖𝑠𝑡subscript𝑃subscriptΦ𝑐subscript𝑃subscriptΦ𝑡dist(P_{\Phi_{c}},P_{\Phi_{t}})italic_d italic_i italic_s italic_t ( italic_P start_POSTSUBSCRIPT roman_Φ start_POSTSUBSCRIPT italic_c end_POSTSUBSCRIPT end_POSTSUBSCRIPT , italic_P start_POSTSUBSCRIPT roman_Φ start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT end_POSTSUBSCRIPT ) by directly optimizing network parameters [15, 42], thereby getting rid of selection bias. This practice may lead to an overbalanced representation discarding predictive information [17].

Therefore, we propose to adopt the sample reweighting technique to reduce the reliance on network parameters. Specifically, our Balancing Regularizer strives to mitigate selection bias by learning a set of sample weights $\mathbf{w}=(w_1,w_2,\dots,w_n)\in\mathbb{R}^{n}_{+}$ that minimize the following balance loss $\mathcal{L}_{\text{B}}$:

$$\min_{\mathbf{w}}\mathcal{L}_{\text{B}}=\sup_{f\in\mathcal{F}}\left|\mathbb{E}_{x\sim P_{\Phi_c}^{\mathbf{w}}}[f(x)]-\mathbb{E}_{x\sim P_{\Phi_t}^{\mathbf{w}}}[f(x)]\right|, \tag{4}$$

where $P_{\Phi_c}^{\mathbf{w}}=\{w_i\cdot\Phi(\mathbf{x}_i)\}_{i:t_i=0}$ and $P_{\Phi_t}^{\mathbf{w}}=\{w_i\cdot\Phi(\mathbf{x}_i)\}_{i:t_i=1}$ denote the weighted probability distributions of covariates in the representation space $\Phi$ with $t=0$ and $t=1$, respectively.
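As a minimal illustration of Equation (4), the sketch below evaluates the balance loss for one particular choice of $\mathcal{F}$: unit-norm linear functions, for which the supremum reduces to the Euclidean distance between the weighted group means. This choice of function family, the normalization of the weights within each treatment arm, and the function name `weighted_mean_discrepancy` are assumptions made for illustration; other IPMs (e.g., MMD or the Wasserstein distance) could be substituted.

```python
import numpy as np


def weighted_mean_discrepancy(phi, t, w):
    """Linear-function IPM between reweighted control and treated groups.

    phi: (n, d) representations Phi(x_i)
    t:   (n,) binary treatment indicators
    w:   (n,) positive sample weights
    """
    phi = np.asarray(phi, dtype=float)
    t = np.asarray(t).astype(bool)
    w = np.asarray(w, dtype=float)

    # Weighted mean of each arm; normalizing by the arm's total weight
    # is an illustrative assumption, not something fixed by Equation (4).
    w_c, w_t = w[~t], w[t]
    mean_c = (w_c[:, None] * phi[~t]).sum(axis=0) / w_c.sum()
    mean_t = (w_t[:, None] * phi[t]).sum(axis=0) / w_t.sum()

    # For f(x) = v^T x with ||v|| <= 1, the supremum of the mean gap
    # equals the Euclidean norm of the difference of the two means.
    return float(np.linalg.norm(mean_t - mean_c))
```

Under this sketch, increasing the weight of control units that resemble treated units shrinks the discrepancy without changing the representation network itself.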

IV-B Independence Regularizer

The Independence Regularizer aims to eliminate feature dependencies, so as to recognize stable representations against distribution shifts. As stated in previous studies [35, 53, 39], the statistical correlation between stable features $\mathbf{X}_S$ and unstable features $\mathbf{X}_V$ is a major cause of model failure under distribution shifts, and thus, independence between variables can lead to more reliable and stable models. When variables are independent, alterations in one variable do not exert any influence on the other variables. Thereby, the relationships between variables and outcomes can be regarded as stable causation, facilitating the superior performance of models across different OOD populations.

The Independence Regularizer employs the Hilbert-Schmidt Independence Criterion with Random Fourier Features to measure the non-linear correlation between two variables. HSIC is widely utilized to measure the dependency between two random variables by comparing their representations in a Hilbert space [39, 44, 20]:

$$\text{HSIC}(A,B)=\|\mathbf{K}_A-\mathbf{K}_B\|^{2}_{HS}, \tag{5}$$

where $\mathbf{K}_A=k_A(A,A)$ and $\mathbf{K}_B=k_B(B,B)$ are RBF kernel matrices, and $\|\cdot\|_{HS}$ is the Hilbert-Schmidt norm. If the product $k_Ak_B$ is characteristic, and $\mathbb{E}[k_A(A,A)]<\infty$ and $\mathbb{E}[k_B(B,B)]<\infty$ hold, then $A\perp B$ if and only if $\text{HSIC}(A,B)=0$. However, HSIC involving large-scale kernel matrices is computationally expensive.

Therefore, HSIC with Random Fourier Features (HSIC-RFF) is developed as an approximation technique for HSIC, leading to a notable reduction in time complexity. The function space of Random Fourier Features is:

$$\mathcal{H}_{\text{RFF}}=\{h:x\rightarrow\sqrt{2}\cos(wx+\varphi)\}, \tag{6}$$

where $w\sim\mathcal{N}(0,1)$ and $\varphi\sim\mathcal{U}(0,2\pi)$ are sampled from the standard normal distribution and the uniform distribution, respectively. Then, the HSIC statistic can be approximated by $\text{HSIC}_{\text{RFF}}$:

$$\begin{aligned}
\text{HSIC}_{\text{RFF}}(A,B) &= \big\|C_{\mathbf{u}(A),\mathbf{v}(B)}\big\|^{2}_{F} \\
&= \sum_{i=1}^{n_A}\sum_{j=1}^{n_B}\big|\mathrm{Cov}(u_i(A),v_j(B))\big|^{2},
\end{aligned} \tag{7}$$

where $\|\cdot\|_{F}$ is the Frobenius norm, and $C_{\mathbf{u}(A),\mathbf{v}(B)}\in\mathbb{R}^{n_A\times n_B}$ is the cross-covariance matrix of the random Fourier features $\mathbf{u}(A)$ and $\mathbf{v}(B)$, which contain the entries:

$$\begin{aligned}
\mathbf{u}(A) &= (u_1(A),u_2(A),\dots,u_{n_A}(A)),\quad u_i(A)\in\mathcal{H}_{\text{RFF}},\ \forall i, \\
\mathbf{v}(B) &= (v_1(B),v_2(B),\dots,v_{n_B}(B)),\quad v_j(B)\in\mathcal{H}_{\text{RFF}},\ \forall j,
\end{aligned} \tag{8}$$

where $n_A$ and $n_B$ denote the numbers of functions sampled from $\mathcal{H}_{\text{RFF}}$. The accuracy of the statistic $\text{HSIC}_{\text{RFF}}$ increases as $n_A$ and $n_B$ become larger; both default to 5.
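A minimal NumPy sketch of Equations (6)-(8) for two scalar variables follows. The function names, the seed handling, and the $1/(n-1)$ covariance normalization are illustrative assumptions; only the feature map $\sqrt{2}\cos(wx+\varphi)$, the sampling distributions of $w$ and $\varphi$, and the default $n_A=n_B=5$ are taken from the text.

```python
import numpy as np


def random_fourier_features(x, n_funcs, rng):
    """Map samples x of shape (n,) to (n, n_funcs) features sqrt(2)*cos(w*x + phi)."""
    x = np.asarray(x, dtype=float).reshape(-1, 1)
    w = rng.standard_normal(n_funcs)                # w ~ N(0, 1)
    phi = rng.uniform(0.0, 2.0 * np.pi, n_funcs)    # phi ~ U(0, 2*pi)
    return np.sqrt(2.0) * np.cos(x * w + phi)


def hsic_rff(a, b, n_a=5, n_b=5, seed=0):
    """Squared Frobenius norm of the cross-covariance of RFFs of a and b."""
    rng = np.random.default_rng(seed)
    u = random_fourier_features(a, n_a, rng)        # (n, n_a)
    v = random_fourier_features(b, n_b, rng)        # (n, n_b)
    u = u - u.mean(axis=0)                          # center each feature
    v = v - v.mean(axis=0)
    cov = u.T @ v / (len(u) - 1)                    # (n_a, n_b) cross-covariance
    return float((cov ** 2).sum())
```

With this sketch, the statistic is expected to be close to zero for independent samples and clearly larger for dependent ones; the cost is linear in the sample size rather than quadratic as for full kernel matrices.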

Motivated by [39, 38], our Independence Regularizer coherently optimizes sample weights $\mathbf{w}$ by decorrelating all features in the covariates (or their representations) $\mathbf{X}\in\mathbb{R}^{n\times m}$. That is, for any two features $\mathbf{X}_{:,a},\mathbf{X}_{:,b}\in\mathbf{X}$, the weighted statistic $\text{HSIC}_{\text{RFF}}$, denoted by $\text{HSIC}^{w}_{\text{RFF}}$, should be close to zero. Formally, for $\forall\,\mathbf{X}_{:,a},\mathbf{X}_{:,b}\in\mathbf{X}$,

$$\begin{aligned}
\text{HSIC}^{w}_{\text{RFF}}(\mathbf{X}_{:,a},\mathbf{X}_{:,b},\mathbf{w}) &= \big\|C^{w}_{\mathbf{u}(\mathbf{X}_{:,a}),\mathbf{v}(\mathbf{X}_{:,b})}\big\|^{2}_{F} \\
&= \sum_{i=1}^{n_A}\sum_{j=1}^{n_B}\big|\mathrm{Cov}(u_i(\mathbf{w}^{\top}\mathbf{X}_{:,a}),v_j(\mathbf{w}^{\top}\mathbf{X}_{:,b}))\big|^{2}\rightarrow 0.
\end{aligned} \tag{9}$$

The corresponding loss term can be denoted as:

$$\mathcal{L}_{\text{D}}(\mathbf{X},\mathbf{w})=\sum_{1\leq a\leq b\leq m}\text{HSIC}^{w}_{\text{RFF}}(\mathbf{X}_{:,a},\mathbf{X}_{:,b},\mathbf{w}). \tag{10}$$

Following prior work [39, 38], we apply the loss term $\mathcal{L}_{\text{D}}(\cdot,\cdot)$ to the last layer of the neural network $\mathbf{Z}^{p}$, and thus obtain the independence loss of our Independence Regularizer $\mathcal{L}_{\text{I}}=\mathcal{L}_{\text{D}}(\mathbf{Z}^{p},\mathbf{w})$. This is done to ensure that stable representations can establish the most direct mapping to the outcome.
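To make the decorrelation loss concrete, the following sketch computes a weighted RFF cross-covariance for every pair of columns and sums the squared Frobenius norms. Several details are assumptions for illustration: the weights are normalized to sum to one and used as probability masses in the covariance, the sum runs over distinct pairs $a<b$, and the names `weighted_cross_cov` and `decorrelation_loss` are hypothetical.

```python
import numpy as np


def weighted_cross_cov(u, v, w):
    """Weighted cross-covariance of feature matrices u (n, n_a) and v (n, n_b)."""
    w = np.asarray(w, dtype=float)
    w = w / w.sum()                      # treat weights as probability masses (assumption)
    mu_u = w @ u                         # weighted feature means
    mu_v = w @ v
    return (u - mu_u).T @ (w[:, None] * (v - mu_v))


def decorrelation_loss(X, w, n_funcs=5, seed=0):
    """Sum of weighted HSIC-RFF statistics over all distinct column pairs of X."""
    X = np.asarray(X, dtype=float)
    n, m = X.shape
    rng = np.random.default_rng(seed)
    # One independent set of random Fourier functions per column.
    freq = rng.standard_normal((m, n_funcs))
    phase = rng.uniform(0.0, 2.0 * np.pi, (m, n_funcs))
    feats = np.sqrt(2.0) * np.cos(X[:, :, None] * freq + phase)   # (n, m, n_funcs)

    loss = 0.0
    for a in range(m):
        for b in range(a + 1, m):
            c = weighted_cross_cov(feats[:, a], feats[:, b], w)
            loss += float((c ** 2).sum())
    return loss
```

As a usage example, consider two binary features whose joint cells have unequal counts: with uniform weights the loss is positive, whereas weights inversely proportional to the cell counts make the reweighted features independent and drive the loss to numerical zero.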

Note that our Balancing Regularizer and Independence Regularizer handle selection bias and distribution shift, respectively, and that both are based on sample reweighting. Therefore, we propose to directly integrate the Balancing Regularizer and the Independence Regularizer to achieve stable and reliable HTE estimation. This approach is named Stable Balanced Representation Learning (SBRL).

IV-C Hierarchical-Attention Paradigm (HAP)

Although the Balancing Regularizer and the Independence Regularizer are able to address selection bias and distribution shift separately, distribution shift in HTE estimation triggers an extra challenge, i.e., the contradiction between balance and independence as stated in Section III. This challenge poses a significant obstacle to reconciling the Balancing Regularizer and the Independence Regularizer, thereby making it difficult to achieve stable HTE estimates in OOD environments. To address this issue, we propose a Hierarchical-Attention Paradigm to form a coordinated and unified objective function.

The design of HAP stems from the following insight: applying decorrelation solely to the last layer of models, as traditional works suggest [39, 38], would induce interaction between the learning of balanced representation and independence-driven weights; one intuitive approach is to uniformly enforce decorrelation for each layer throughout the entire network. However, such indiscriminate constraints may lead to a large value for the independence loss and a relatively small value for the balance loss term, resulting in the disregard of the covariate balancing objective.

Consequently, we propose to divide the entire neural network into three priorities: the model’s last layer $\mathbf{Z}^{p}\in\mathbb{R}^{n\times d_p}$ with the first priority, the layer $\mathbf{Z}^{r}\in\mathbb{R}^{n\times d_r}$ for balanced representations $\Phi$ with the second priority, and other fully connected layers $\{\mathbf{Z}^{o}_{i}\in\mathbb{R}^{n\times d_o}\}_{i=1}^{l}$ with the third priority.
Then, besides the loss term $\mathcal{L}_{\text{I}}$ for $\mathbf{Z}^{p}$, we emphasize the necessity of the loss terms $\mathcal{L}_{\text{D}}(\mathbf{Z}^{r},\mathbf{w})$ and $\mathcal{L}_{\text{D}}(\mathbf{Z}^{o},\mathbf{w})$ with hierarchical attention for thorough removal of the negative impact of unstable features.

By integrating the Balancing Regularizer and the Independence Regularizer with hierarchical attention, the Hierarchical-Attention Paradigm optimizes sample weights $\mathbf{w}$ with the following loss function $\mathcal{L}_{\mathbf{w}}$:

$$\begin{aligned}
\min_{\mathbf{w}}\mathcal{L}_{\mathbf{w}}={} & \alpha\cdot\mathcal{L}_{\text{B}}+\gamma_1\cdot\mathcal{L}_{\text{I}}+\gamma_2\cdot\mathcal{L}_{\text{D}}(\mathbf{Z}^{r},\mathbf{w})+ \\
& \gamma_3\cdot\sum_{i=1}^{l}\mathcal{L}_{\text{D}}(\mathbf{Z}^{o}_{i},\mathbf{w})+\mathcal{R}_{\mathbf{w}},
\end{aligned} \tag{11}$$

where $\mathcal{R}_{\mathbf{w}}=\frac{1}{n}\sum_{i=1}^{n}(w_i-1)^{2}$ prevents all sample weights from collapsing to zero and keeps the model from focusing on only a few samples while ignoring the others. Besides, the values of the hyper-parameters $\alpha$ and $\{\gamma_1,\gamma_2,\gamma_3\}$ allow us to adjust the sensitivity to selection bias and distribution shift with hierarchical attention.
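The structure of Equation (11) can be sketched as a plain weighted sum of already-computed loss terms. The function names and the default values of `alpha` and `gammas` below are hypothetical placeholders that merely mimic decreasing priority across layers; they are not values reported in the text.

```python
def weight_regularizer(w):
    """R_w = (1/n) * sum_i (w_i - 1)^2, which keeps the weights close to one."""
    n = len(w)
    return sum((wi - 1.0) ** 2 for wi in w) / n


def hap_weight_loss(loss_b, loss_i, loss_r, losses_o, w,
                    alpha=1.0, gammas=(1.0, 0.1, 0.01)):
    """Combine the terms of Equation (11).

    loss_b:   balance loss L_B
    loss_i:   independence loss on the last layer, L_I = L_D(Z^p, w)
    loss_r:   decorrelation loss on the representation layer, L_D(Z^r, w)
    losses_o: iterable of L_D(Z^o_i, w) for the remaining layers
    gammas:   (gamma_1, gamma_2, gamma_3); the defaults are placeholders
    """
    gamma1, gamma2, gamma3 = gammas
    return (alpha * loss_b
            + gamma1 * loss_i
            + gamma2 * loss_r
            + gamma3 * sum(losses_o)
            + weight_regularizer(w))
```

Setting `gammas=(g, 0, 0)` in this sketch recovers last-layer-only decorrelation, while equal values correspond to the uniform per-layer constraint discussed above.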

Algorithm 1 Stable Balanced Representation Learning with Hierarchical-Attention Paradigm
Input: Observational dataset $D^{e}=\{\mathbf{x}_i^{e},t_i^{e},y_i^{t_i,e}\}^{n}$ from environment $e$
Output: $\hat{y}^{0}$, $\hat{y}^{1}$
Initialize network parameters $\mathbf{W},\mathbf{b}$
Initialize sample weights $\mathbf{w}\leftarrow\{1\}^{n}$
for $i=0$ to $\mathcal{I}$ do
     Calculate loss function $\mathcal{L}_{\text{Y}}^{w}$ with parameters $\mathbf{W},\mathbf{b}$ and sample weights $\mathbf{w}$
     Update $\mathbf{W},\mathbf{b}$ with gradient descent by fixing $\mathbf{w}$
     Calculate loss function $\mathcal{L}_{\mathbf{w}}$ with sample weights $\mathbf{w}$
     Update $\mathbf{w}$ with gradient descent by fixing $\mathbf{W},\mathbf{b}$
end for

IV-D Optimization and Training Procedure

By optimizing the loss function $\mathcal{L}_{\mathbf{w}}$, we can acquire the optimal sample weights $\mathbf{w}^{*}$ to guide the deep neural networks to achieve stable HTE estimation across OOD data. Note that our SBRL-HAP learns sample weights regardless of the model structure; hence, it is applicable to the backbone of nearly all balanced representation methods. We take the backbone of the most classic balanced representation algorithm, i.e., Counterfactual Regressor (CFR) [15], as an example, to illustrate our end-to-end training process.

The backbone of CFR contains two sub-modules, i.e., a shared representation network ($\Phi(\mathbf{x}_i)$) for representation extraction and multi-head predictive networks ($h_{t_i}(\Phi(\mathbf{x}_i))$) for potential outcome prediction.

The representation network is expected to provide a balanced representation $\Phi(\mathbf{X})$, so as to remove distribution discrepancies between the treated group $\{\Phi(\mathbf{x}_i)\}_{i:t_i=1}$ and the control group $\{\Phi(\mathbf{x}_i)\}_{i:t_i=0}$. Then, in HTE estimation, to avoid the treatment information being dominated by the high-dimensional covariates, the two-head networks $h_{t=0}(\Phi)$ and $h_{t=1}(\Phi)$ are adopted to predict outcomes in control and treated groups, with the prediction loss $\mathcal{L}_{\text{Y}}$:

$$\min_{h_0,h_1}\mathcal{L}_{\text{Y}}=\frac{1}{n}\sum_{i=1}^{n}l(h_{t_i}(\Phi(\mathbf{x}_i)),y_i^{t_i})+\mathcal{R}_{l_2}, \tag{12}$$

where $\mathcal{R}_{l_2}$ is the $l_2$ regularization for $h$, and $l(\cdot,\cdot)$ encodes the loss function, i.e., mean squared error (MSE) for continuous outcomes and cross-entropy error for binary outcomes.

To guide the above neural networks to achieve stable and unbiased prediction, we propose to plug our SBRL-HAP module into Equation (12) by

$$\min_{h_0,h_1}\mathcal{L}_{\text{Y}}^{w}=\frac{1}{n}\sum_{i=1}^{n}w_i\cdot l(h_{t_i}(\Phi(\mathbf{x}_i)),y_i^{t_i})+\mathcal{R}_{l_2}, \tag{13}$$

where $\{w_i\}_{i=1}^{n}\in\mathbf{w}^{*}$ are the optimal sample weights learnt with the loss function $\mathcal{L}_{\mathbf{w}}$.
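The data-fitting part of Equation (13) for a continuous outcome can be sketched as follows. The function name `weighted_factual_mse` is hypothetical, the two heads are represented only by their predictions, and the $l_2$ regularization term $\mathcal{R}_{l_2}$ is omitted for brevity.

```python
import numpy as np


def weighted_factual_mse(y0_hat, y1_hat, t, y, w):
    """Weighted MSE on factual outcomes.

    y0_hat, y1_hat: (n,) predictions of the control head h_0 and treated head h_1
    t:              (n,) binary treatment indicators
    y:              (n,) observed (factual) outcomes
    w:              (n,) sample weights
    """
    y0_hat, y1_hat, t, y, w = (
        np.asarray(v, dtype=float) for v in (y0_hat, y1_hat, t, y, w)
    )
    # Each unit is scored only by the head matching its observed treatment.
    y_hat = np.where(t == 1, y1_hat, y0_hat)
    return float(np.mean(w * (y_hat - y) ** 2))
```

With all weights equal to one, this reduces to the unweighted factual loss of Equation (12) without the regularizer.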

Ultimately, we adopt an alternating training strategy to iteratively optimize the loss function $\mathcal{L}_{\text{Y}}^{w}$ for heterogeneous outcome prediction and the loss function $\mathcal{L}_{\mathbf{w}}$ for stable and balanced representations. Algorithm 1 illustrates the details of the pseudo-code of our SBRL-HAP.
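The alternating scheme can be outlined generically as below. This is only a structural sketch under stated assumptions: the gradients of $\mathcal{L}_{\text{Y}}^{w}$ and $\mathcal{L}_{\mathbf{w}}$ are supplied as hypothetical callables `grad_params` and `grad_weights`, plain gradient descent with fixed step sizes is used, and the weights are clipped at a small positive value to stay in $\mathbb{R}^{n}_{+}$.

```python
import numpy as np


def alternating_training(params, w, grad_params, grad_weights,
                         n_iters=100, lr_params=1e-2, lr_w=1e-2,
                         min_weight=1e-6):
    """Alternate gradient steps on the network parameters and the sample weights.

    grad_params(params, w)  -> gradient of L_Y^w w.r.t. params (w held fixed)
    grad_weights(params, w) -> gradient of L_w w.r.t. w (params held fixed)
    """
    w = np.asarray(w, dtype=float)
    for _ in range(n_iters):
        # Step 1: update the network parameters with the weights fixed.
        params = params - lr_params * grad_params(params, w)
        # Step 2: update the sample weights with the parameters fixed.
        w = w - lr_w * grad_weights(params, w)
        # Keep the weights positive (clipping is an illustrative assumption).
        w = np.maximum(w, min_weight)
    return params, w
```

In practice the two callables would be obtained by automatic differentiation of the weighted prediction loss and the weight loss, respectively; the toy interface here only shows the order of the updates.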

V Experiments

TABLE I: The results (mean±std) of treatment effect estimation on synthetic data Syn_8_8_8_2 with different bias rates $\rho$. Bold marks the best result within each backbone group.

Metric: PEHE (Mean±Std)

| Bias Rate | $\rho=-3$ | $\rho=-2.5$ | $\rho=-1.5$ | $\rho=-1.3$ | $\rho=1.3$ | $\rho=1.5$ | $\rho=2.5^{*}$ | $\rho=3$ |
|---|---|---|---|---|---|---|---|---|
| TARNet | 0.565±0.009 | 0.558±0.007 | 0.567±0.003 | 0.559±0.006 | 0.461±0.003 | 0.420±0.005 | 0.363±0.003 | 0.358±0.003 |
| +SBRL | 0.474±0.008 | 0.459±0.008 | 0.492±0.007 | 0.489±0.009 | **0.410±0.007** | **0.377±0.004** | **0.341±0.004** | **0.332±0.004** |
| +SBRL-HAP | **0.440±0.005** | **0.435±0.007** | **0.442±0.008** | **0.462±0.004** | 0.444±0.005 | 0.421±0.006 | 0.404±0.006 | 0.407±0.005 |
| CFR | 0.559±0.009 | 0.552±0.007 | 0.563±0.003 | 0.555±0.006 | 0.459±0.003 | 0.418±0.005 | 0.363±0.003 | 0.357±0.003 |
| +SBRL | 0.475±0.008 | 0.460±0.008 | 0.492±0.007 | 0.490±0.009 | 0.410±0.007 | 0.378±0.004 | **0.341±0.004** | **0.332±0.004** |
| +SBRL-HAP | **0.419±0.005** | **0.412±0.005** | **0.429±0.004** | **0.433±0.005** | **0.401±0.007** | **0.374±0.006** | 0.354±0.005 | 0.352±0.005 |
| DeR-CFR | 0.431±0.007 | 0.439±0.009 | 0.449±0.007 | 0.455±0.008 | 0.376±0.005 | 0.338±0.005 | 0.311±0.004 | 0.306±0.005 |
| +SBRL | 0.431±0.005 | 0.429±0.007 | 0.441±0.004 | 0.446±0.007 | 0.371±0.006 | 0.335±0.006 | 0.301±0.006 | **0.293±0.002** |
| +SBRL-HAP | **0.350±0.006** | **0.353±0.009** | **0.373±0.006** | **0.374±0.009** | **0.340±0.006** | **0.312±0.006** | **0.295±0.006** | 0.295±0.006 |
| Improvement | 25.0%↑ | 25.4%↑ | 23.8%↑ | 22.0%↑ | 12.6%↑ | 10.5%↑ | 5.1%↑ | 3.6%↑ |

Metric: $\epsilon_{\text{ATE}}$ (Mean±Std)

| Bias Rate | $\rho=-3$ | $\rho=-2.5$ | $\rho=-1.5$ | $\rho=-1.3$ | $\rho=1.3$ | $\rho=1.5$ | $\rho=2.5^{*}$ | $\rho=3$ |
|---|---|---|---|---|---|---|---|---|
| TARNet | **0.019±0.006** | 0.032±0.008 | 0.012±0.004 | 0.015±0.005 | **0.021±0.008** | 0.021±0.008 | 0.018±0.006 | 0.021±0.007 |
| +SBRL | 0.029±0.005 | 0.040±0.006 | 0.027±0.004 | 0.026±0.005 | 0.029±0.011 | 0.020±0.006 | 0.026±0.006 | 0.029±0.008 |
| +SBRL-HAP | 0.021±0.006 | **0.025±0.009** | **0.012±0.004** | **0.015±0.005** | 0.023±0.008 | **0.019±0.008** | **0.017±0.007** | **0.021±0.007** |
| CFR | **0.018±0.006** | 0.032±0.008 | **0.012±0.004** | 0.014±0.004 | **0.021±0.008** | 0.020±0.008 | 0.018±0.006 | 0.021±0.007 |
| +SBRL | 0.029±0.005 | 0.040±0.006 | 0.028±0.004 | 0.026±0.005 | 0.030±0.011 | 0.021±0.006 | 0.027±0.006 | 0.029±0.008 |
| +SBRL-HAP | 0.019±0.006 | **0.024±0.009** | 0.015±0.005 | **0.013±0.004** | 0.024±0.008 | **0.018±0.006** | **0.013±0.006** | **0.015±0.007** |
| DeR-CFR | 0.017±0.006 | **0.021±0.007** | 0.014±0.004 | 0.020±0.005 | **0.021±0.008** | 0.020±0.007 | 0.019±0.006 | 0.021±0.006 |
| +SBRL | 0.021±0.007 | 0.033±0.005 | 0.024±0.005 | 0.028±0.006 | 0.027±0.011 | 0.018±0.005 | 0.022±0.007 | 0.029±0.008 |
| +SBRL-HAP | **0.013±0.003** | 0.023±0.008 | **0.013±0.005** | **0.015±0.005** | 0.022±0.009 | **0.013±0.005** | **0.019±0.007** | **0.021±0.008** |
| Improvement | 23.5%↑ | 25.0%↑ | 7.1%↑ | 25.0%↑ | 4.8%↓ | 35.0%↑ | 27.8%↑ | 28.6%↑ |
  * In this paper, we use synthetic data with $\rho=2.5$ as the training population. The testing data with $\rho=2.5$ can be regarded as the in-distribution population. As $\rho$ moves away from 2.5 (in particular, as it decreases toward $-3$), the difference in distribution between the testing and training datasets increases.

V-A Baselines

In this paper, we propose two model-agnostic frameworks, SBRL and SBRL-HAP (code: https://github.com/superpig99/SBRL-HAP), for estimating heterogeneous treatment effects across out-of-distribution environments. SBRL can be regarded as an ablation of SBRL-HAP without HAP. Most existing representation balancing methods can be incorporated into these frameworks as backbones, because our methods only introduce BR, IR, and HAP as additional regularizers to constrain representation learning, without being tied to specific models. Below, to demonstrate how SBRL and SBRL-HAP improve heterogeneous treatment effect estimation across OOD populations, we compare them with the baselines and describe how SBRL and SBRL-HAP are combined with each method:

  • TARNet [15] is a treatment-agnostic representation network with a shared representation network, which uses a two-head predictive network to predict the factual treated outcome and control outcome separately. Since TARNet does not include balance regularization, we only incorporate the Independence Regularizer into TARNet and refer to the result as TARNet+SBRL. TARNet+SBRL-HAP additionally achieves comprehensive feature decorrelation with hierarchical attention through the Hierarchical-Attention Paradigm.

  • CFR [15, 42] employs IPM to measure the distribution distance between the treated and control groups, and learns a balanced representation by minimizing IPM to eliminate selection bias. We refer to CFR with the Balancing Regularizer and the Independence Regularizer incorporated as CFR+SBRL. Furthermore, CFR+SBRL-HAP employs the Hierarchical-Attention Paradigm for comprehensive feature decorrelation through hierarchical attention mechanisms.

  • DeR-CFR [17] further considers confounder separation by learning separate representations for instrumental variables, confounding variables, and adjustment variables. This enables a more precise evaluation of heterogeneous treatment effects. We refer to DeR-CFR combined with the SBRL framework as DeR-CFR+SBRL, and to DeR-CFR combined with the SBRL-HAP framework as DeR-CFR+SBRL-HAP.

The aforementioned three baselines are the most classic solutions to the traditional HTE estimation problem within in-distribution populations, so we use them as vanilla models and compare them with their +SBRL and +SBRL-HAP variants. Other balanced representation methods, such as RCFR [43], CFR-ISW [16], SITE [22], and DR-CFR [18], are built upon these baselines and have not exceeded the performance of DeR-CFR [17]. Consequently, we only combine SBRL and SBRL-HAP with TARNet, CFR, and DeR-CFR to study the performance of our methods in estimating HTE across OOD populations.

V-B Metrics

Following previous work [15, 17], we adopt the Precision in Estimation of Heterogeneous Effect (PEHE) [58] and the bias of ATE prediction ($\epsilon_{\text{ATE}}$) to evaluate the individual-level and population-level performance respectively, where $\text{PEHE}=\sqrt{\frac{1}{n}\sum_{i=1}^{n}((\hat{y}_{i}^{1}-\hat{y}_{i}^{0})-(y_{i}^{1}-y_{i}^{0}))^{2}}$ and $\epsilon_{\text{ATE}}=|ATE-\hat{ATE}|$. Smaller values of these two metrics indicate better model performance.
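The two metrics can be computed directly from predicted and true potential outcomes. The sketch below is a minimal illustration; the function names are hypothetical, and it assumes that $ATE$ and $\hat{ATE}$ are the sample means of the true and predicted individual effects, which is the usual convention but is not spelled out in this section.

```python
import numpy as np

def pehe(y1_hat, y0_hat, y1, y0):
    """Root mean squared error between predicted and true individual effects."""
    ite_hat = np.asarray(y1_hat, dtype=float) - np.asarray(y0_hat, dtype=float)
    ite = np.asarray(y1, dtype=float) - np.asarray(y0, dtype=float)
    return float(np.sqrt(np.mean((ite_hat - ite) ** 2)))

def ate_error(y1_hat, y0_hat, y1, y0):
    """Absolute gap between true and predicted ATE (assumed: mean of ITEs)."""
    ate_hat = np.mean(np.asarray(y1_hat, dtype=float) - np.asarray(y0_hat, dtype=float))
    ate = np.mean(np.asarray(y1, dtype=float) - np.asarray(y0, dtype=float))
    return float(abs(ate - ate_hat))
```

Both functions need the true counterfactual outcomes, which is why they can only be evaluated on synthetic or semi-synthetic data.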

Besides, popular evaluation metrics for prediction tasks, such as the $F_{1}$ score [59], are also adopted to assist in evaluating the model performance. We utilize the average and stability of error [49] to evaluate the generalization performance. For example, the average of the $F_{1}$ score is defined as $\bar{F_{1}}=\frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}}F_{1}^{e}$, and the stability of the $F_{1}$ score is $F_{1}^{std}=\frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}}(F_{1}^{e}-\bar{F_{1}})^{2}$. Lower values of these two indicators mean better model stability.
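The average and stability indicators reduce to a mean and a mean squared deviation over the per-environment scores. The following sketch follows the two formulas as written above (so the stability term has no square root); the function name and the example scores are illustrative only.

```python
def average_and_stability(scores):
    """Return (mean, mean squared deviation) of per-environment scores."""
    n = len(scores)
    mean = sum(scores) / n
    stability = sum((s - mean) ** 2 for s in scores) / n
    return mean, stability

# Hypothetical F1 scores from two test environments.
mean_f1, stab_f1 = average_and_stability([0.5, 0.7])
```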

V-C Experimental Settings

In the experiment, we utilize ELU as the non-linear activation function and adopt the Adam optimizer to train all methods. We set the maximum number of iterations to 3000. Besides, we apply an exponentially decaying learning rate [60] and report the best-evaluated iterate with early stopping. We first identify the optimal hyper-parameters for all baseline algorithms through random-search trials [61]. Then, with the basic hyper-parameters fixed, we conduct a random search for the hyper-parameters $\{\gamma_{1},\gamma_{2},\gamma_{3}\}$ of the HSIC losses over the grid $\{0.0001,0.001,0.01,0.1,1,10,100\}$ to optimize our model.
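The second-stage search over $\{\gamma_{1},\gamma_{2},\gamma_{3}\}$ can be sketched as a plain random search over the grid above. This is a generic illustration, not the authors' tuning script: `evaluate` is a hypothetical callback that trains a model with the given coefficients and returns a validation error, and the trial budget is an assumption.

```python
import random

GAMMA_GRID = [0.0001, 0.001, 0.01, 0.1, 1, 10, 100]

def random_search(evaluate, n_trials=20, seed=0):
    """Sample (gamma1, gamma2, gamma3) from the grid; keep the lowest error."""
    rng = random.Random(seed)
    best_cfg, best_score = None, float("inf")
    for _ in range(n_trials):
        cfg = {f"gamma{k}": rng.choice(GAMMA_GRID) for k in (1, 2, 3)}
        score = evaluate(cfg)  # hypothetical: train and return validation error
        if score < best_score:
            best_cfg, best_score = cfg, score
    return best_cfg, best_score
```

The grid has $7^{3}=343$ combinations, so a random search with a small budget explores only a fraction of it.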

V-D Experiments on Synthetic Data

V-D1 Datasets

To simulate complex real-world scenarios, the synthetic data used in our study incorporates several key factors: 1) The observed covariates include not only confounding variables but also other relevant factors; 2) The imbalanced treatment assignment would introduce selection bias, reflecting the inherent biases that exist in observational studies; and 3) The synthetic data also incorporates distribution shifts that occur across different environments or populations. We generate synthetic data using the following process.

Covariates generation. We generate covariates from a multivariate normal distribution, i.e., $X_{1},X_{2},\ldots,X_{m}\sim\mathcal{N}(0,1)$, where $m=m_{I}+m_{C}+m_{A}+m_{V}$ denotes the dimension of covariates, and $\{m_{I},m_{C},m_{A},m_{V}\}$ denote the dimensions of instruments $I$, confounders $C$, adjustments $A$ and noise $V$, respectively. 
For generality, we design two settings of variable dimensions, $\{m_{I},m_{C},m_{A},m_{V}\}=\{8,8,8,2\}$ or $\{16,16,16,2\}$, with the sample size $n=10000$, and denote each setting as Syn_$m_{I}$_$m_{C}$_$m_{A}$_$m_{V}$.

Treatments generation. We produce treatment $t\sim\mathcal{B}(\frac{1}{1+e^{-z}})$, where $z=\frac{1}{10}\theta_{t}\times X_{IC}+\xi$, $X_{IC}$ denotes the covariates that belong to $I$ and $C$, and $\theta_{t}\sim\mathcal{U}((8,16)^{m_{I}+m_{C}})$.
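The covariate and treatment steps can be sketched as follows. This is a minimal illustration under stated assumptions: the columns are ordered as $I$, $C$, $A$, $V$; the product $\theta_{t}\times X_{IC}$ is read as an inner product per sample; $\mathcal{B}$ is the Bernoulli distribution; and the function name is hypothetical.

```python
import numpy as np

def generate_covariates_and_treatment(n=10000, m_i=8, m_c=8, m_a=8, m_v=2, seed=0):
    """Draw X ~ N(0, 1) and a Bernoulli treatment driven by X_I and X_C."""
    rng = np.random.default_rng(seed)
    m = m_i + m_c + m_a + m_v
    x = rng.standard_normal((n, m))            # X_1, ..., X_m ~ N(0, 1)
    x_ic = x[:, : m_i + m_c]                   # assumed column order: I, C, A, V
    theta_t = rng.uniform(8, 16, size=m_i + m_c)
    xi = rng.standard_normal(n)                # noise term xi ~ N(0, 1)
    z = x_ic @ theta_t / 10.0 + xi
    propensity = 1.0 / (1.0 + np.exp(-z))
    t = rng.binomial(1, propensity)            # t ~ Bernoulli(sigmoid(z))
    return x, t, propensity
```

Because the coefficients in $\theta_{t}$ are large, many propensities fall close to 0 or 1, which is the source of the selection bias between treated and control groups.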

Outcomes generation. Two potential outcomes are generated as follows: $Y^{0}=\text{sign}(\max(0,z^{0}-\bar{z^{0}}))$ and $Y^{1}=\text{sign}(\max(0,z^{1}-\bar{z^{1}}))$, where $z^{0}=\frac{1}{10}\frac{\theta_{y^{0}}\times X_{CA}}{m_{C}+m_{A}}$, $z^{1}=\frac{1}{10}\frac{\theta_{y^{1}}\times X^{2}_{CA}}{m_{C}+m_{A}}$, and 
$\theta_{y^{0}},\theta_{y^{1}}\sim\mathcal{U}((8,16)^{m_{C}+m_{A}})$ and $\xi\sim\mathcal{N}(0,1)$. The observed outcome is $Y=TY^{1}+(1-T)Y^{0}$.
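The outcome step can be sketched in the same way. The code below assumes that $\bar{z^{0}}$ and $\bar{z^{1}}$ are sample means, that $X^{2}_{CA}$ is the element-wise square of the confounder and adjustment columns, and that the products with $\theta_{y^{0}}$ and $\theta_{y^{1}}$ are inner products; the function names are hypothetical.

```python
import numpy as np

def generate_potential_outcomes(x_ca, seed=0):
    """Binary potential outcomes from the C and A covariates (shape n x d)."""
    rng = np.random.default_rng(seed)
    d = x_ca.shape[1]                          # d = m_C + m_A
    theta_y0 = rng.uniform(8, 16, size=d)
    theta_y1 = rng.uniform(8, 16, size=d)
    z0 = (x_ca @ theta_y0) / (10.0 * d)        # linear in X_CA
    z1 = ((x_ca ** 2) @ theta_y1) / (10.0 * d) # quadratic in X_CA
    # sign(max(0, z - mean)) is 1 above the mean and 0 otherwise.
    y0 = np.sign(np.maximum(0.0, z0 - z0.mean()))
    y1 = np.sign(np.maximum(0.0, z1 - z1.mean()))
    return y0, y1

def observed_outcome(t, y0, y1):
    """Factual outcome Y = T * Y^1 + (1 - T) * Y^0."""
    return t * y1 + (1 - t) * y0
```

Since both outcomes are binary, the individual effect $Y^{1}-Y^{0}$ takes values in $\{-1,0,1\}$, which is also why $F_{1}$ scores are meaningful for this data.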

Finally, to simulate the distribution shift, we generate different covariate distributions by biased sampling. For each sample, we select it with probability $\Pr=\prod_{X_{i}\in X_{V}}|\rho|^{-10*D_{i}}$, where $D_{i}=|Y^{1}-Y^{0}-\text{sign}(\rho)*X_{i}|$. If $\rho>0$, $\text{sign}(\rho)=1$; otherwise, $\text{sign}(\rho)=-1$. We generate different data distributions by altering the bias rate $\rho\in\{-3.0,-2.5,-1.5,-1.3,1.3,1.5,2.5,3.0\}$, where $\rho>1$ implies a positive correlation between the outcome $Y$ and the unstable features $X_{V}$, and $\rho<-1$ implies a negative correlation. The higher $|\rho|$ is, the stronger the correlation between $Y$ and $X_{V}$. Therefore, different values of $\rho$ refer to different environments. 
To evaluate the generalization of our SBRL and SBRL-HAP frameworks, we use the generated data with $\rho=2.5$ as the default training data, and use the data with different $\rho\in\{-3.0,-2.5,-1.5,-1.3,1.3,1.5,2.5,3.0\}$ as testing data from different environments.
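The biased sampling rule can be illustrated with a short sketch. It is an assumption-laden reading of the formula above: each sample is kept independently with probability $\Pr$ (simple rejection sampling), and the function names are hypothetical.

```python
import numpy as np

def selection_probability(y0, y1, x_v, rho):
    """Pr = prod over unstable features of |rho| ** (-10 * D_i)."""
    sign = 1.0 if rho > 0 else -1.0
    ite = (np.asarray(y1, dtype=float) - np.asarray(y0, dtype=float))[:, None]
    d = np.abs(ite - sign * np.asarray(x_v, dtype=float))  # D_i, shape (n, m_V)
    return np.prod(np.abs(rho) ** (-10.0 * d), axis=1)

def biased_sample(y0, y1, x_v, rho, seed=0):
    """Keep each sample independently with its selection probability."""
    rng = np.random.default_rng(seed)
    pr = selection_probability(y0, y1, x_v, rho)
    keep = rng.random(pr.shape[0]) < pr
    return np.flatnonzero(keep)
```

A sample whose unstable features equal $\text{sign}(\rho)(Y^{1}-Y^{0})$ has $D_{i}=0$ and is always kept, while the probability decays exponentially in $D_{i}$; a larger $|\rho|$ therefore ties $X_{V}$ more tightly to the treatment effect.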

Figure 3: Results of PEHE on synthetic data Syn_16_16_16_2 with different bias rate $\rho$ for the testing set. All models are trained with $\rho=2.5$.
Figure 4: Results of $F_{1}$ scores on synthetic data Syn_16_16_16_2 with different bias rate $\rho$ for the testing set. All models are trained with $\rho=2.5$. (a) $F_{1}$ scores for factual outcomes. (b) $F_{1}$ scores for counterfactual outcomes.

V-D2 Results of treatment effect estimation

Results of treatment effect estimation on synthetic data are shown in Table I and Fig. 3. Table I reveals that both SBRL and SBRL-HAP effectively boost the stability of ITE estimations across diverse OOD data, while presenting a comparable performance in ATE evaluation compared to the vanilla methods. According to Table I, as the distribution discrepancy between the testing set and the training set increases, the error metric PEHE of all methods gets worse. Our methods, however, counteract this performance degradation, and exhibit a more pronounced improvement as the bias rate $\rho$ decreases from $2.5$ to $-3$, with the maximum reduction of PEHE growing from $5.1\%$ to $25\%$. To validate the robustness of our method for high-dimensional data, we report the results of effect estimation on Syn_16_16_16_2 data. Fig. 3 depicts the excellent performance of our method on high-dimensional data. From the results on Syn_16_16_16_2 data, we have the following observations and analysis:

  • Three baselines fail to handle the problem of HTE estimation accompanied by distribution shifts. On the testing data with $\rho=2.5$, which shares the same distribution as the training set, PEHE is $0.417$, $0.418$, and $0.422$ for TARNet, CFR, and DeR-CFR, respectively. However, the performance of the baseline methods degrades gradually as the distribution gap between the testing data and the training data increases (i.e., as $\rho$ decreases). For instance, on the testing data with $\rho=-3$, PEHE of TARNet, CFR, and DeR-CFR worsens to $0.740$, $0.728$, and $0.625$, with a performance decrease of $77\%$, $74\%$, and $56\%$, respectively, where the performance decrease on OOD testing data is calculated as $\text{Decrease}=(\text{PEHE}_{\{\rho=-3\}}-\text{PEHE}_{\{\rho=2.5\}})/\text{PEHE}_{\{\rho=2.5\}}$. Such performance degradation of baseline methods is anticipated, as they erroneously capture the spurious correlation between unstable variables $X_{V}$ and the target outcome $Y$.

  • Compared to other baselines, DeR-CFR exhibits superior resistance to distribution shift: its performance degradation is about $20\%$ less than that of TARNet and CFR. This is attributed to DeR-CFR's confounder separation, which orthogonalizes confounding, instrumental, and adjustment variables. It indicates that decorrelating variables is beneficial in learning genuine and stable relationships.

  • Both SBRL and SBRL-HAP achieve more stable HTE estimation across various OOD data. With distribution shifts (i.e., $\rho$ shifts from $2.5$ to $-3$), the PEHE of DeR-CFR+SBRL varies from $0.428$ to $0.609$, indicating a $42\%$ performance drop. With SBRL-HAP, the PEHE of DeR-CFR changes from $0.488$ to $0.545$, a performance drop of only $11\%$. In contrast, the performance of the original DeR-CFR declines by $56\%$. These percentages demonstrate that our algorithm is more stable and its results are more robust on the OOD testing data. Besides, our algorithm exceeds all baselines on each OOD testing set (i.e., $\rho\in[-3,1.3]$). For example, by combining our SBRL-HAP, the PEHE of DeR-CFR under $\rho=-3$ reduces from $0.657$ to $0.545$, a $21\%$ performance improvement. This is because our algorithm resolves the conflict between balance and independence by hierarchical decorrelation, obtaining stable and balanced representations. Hence, our algorithm can improve the stability of HTE estimation.

  • Our approach outperforms baselines on OOD data ($\rho<2.5$) but performs worse on ID data ($\rho\geq 2.5$), which aligns with prior observations [62, 63, 64, 65, 66, 49, 67]. This is because unstable features tend to contribute to better inference in ID data [49]; however, our algorithm mitigates the influence of these features to prevent instability when estimating in OOD data.
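As a worked check of the performance-decrease formula used in the observations above, the following snippet recomputes three of the quoted percentages from the quoted PEHE values. The helper name is illustrative.

```python
def performance_decrease(pehe_ood, pehe_id):
    """Relative PEHE increase from the ID test set to an OOD test set."""
    return (pehe_ood - pehe_id) / pehe_id

# (PEHE at rho = -3, PEHE at rho = 2.5), as quoted in the text above.
quoted = {
    "TARNet": (0.740, 0.417),
    "CFR": (0.728, 0.418),
    "DeR-CFR+SBRL": (0.609, 0.428),
}
for name, (ood, in_dist) in quoted.items():
    print(f"{name}: {performance_decrease(ood, in_dist):.0%}")
# TARNet: 77%
# CFR: 74%
# DeR-CFR+SBRL: 42%
```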

Furthermore, Fig. 4 demonstrates that our method outperforms the other methods in stably predicting factual and counterfactual outcomes, as measured by the mean and standard deviation (std) of $F_{1}$ scores across all test sets. In particular, our SBRL-HAP reduces the std of $F_{1}$ scores from $0.058$ to $0.026$ for factual outcomes and from $0.040$ to $0.009$ for counterfactual outcomes, compared to the best baseline (i.e., DeR-CFR). Consequently, our method can significantly improve the stability of HTE estimation.

Figure 5: Nonlinear correlation among features in the balanced representation. As shown, the feature correlation is reduced by our SBRL, and further decreased by incorporating HAP.

V-D3 Decorrelation Performance

We examine the nonlinear correlation between features in the balanced representation $\Phi$ to illustrate the effectiveness of our method in mitigating the conflict between balance and independence. Specifically, we randomly sample 25 dimensions from the balanced representation learned by CFR, CFR+SBRL, and CFR+SBRL-HAP on Syn_16_16_16_2 data, and compute $\text{HSIC}_{\text{RFF}}$ between each pair of variables. As shown in Fig. 5, the balanced representation obtained from CFR exhibits strong correlation between features, with average $\text{HSIC}_{\text{RFF}}=0.85$, while direct integration of representation balancing and stable training techniques (i.e., CFR+SBRL) reduces the average $\text{HSIC}_{\text{RFF}}$ to $0.64$. Notably, CFR+SBRL-HAP can further decrease the average $\text{HSIC}_{\text{RFF}}$ to $0.58$, a $37\%$ reduction compared to CFR. Since the major difference between CFR+SBRL-HAP and CFR is the feature decorrelation with hierarchical attention, we conclude that such feature decorrelation helps the model identify stable features and acquire more effective associations with potential outcomes, thus enhancing the generalization ability.
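For readers who want to reproduce a pairwise dependence measure of this kind, the sketch below shows one common random-Fourier-feature approximation of HSIC between two scalar variables: map each variable through random cosine features and take the squared Frobenius norm of the empirical cross-covariance. It is a generic illustration, not the paper's exact $\text{HSIC}_{\text{RFF}}$: the number of features, the Gaussian frequencies, and the lack of normalization are all assumptions, so its values are not comparable to the averages reported above.

```python
import numpy as np

def random_fourier_features(x, n_features, rng):
    """Map a 1-D sample to features sqrt(2) * cos(w * x + phi)."""
    w = rng.standard_normal(n_features)
    phi = rng.uniform(0.0, 2.0 * np.pi, n_features)
    return np.sqrt(2.0) * np.cos(np.outer(x, w) + phi)  # shape (n, n_features)

def hsic_rff(a, b, n_features=5, seed=0):
    """Squared Frobenius norm of the cross-covariance of random features."""
    rng = np.random.default_rng(seed)
    u = random_fourier_features(np.asarray(a, dtype=float), n_features, rng)
    v = random_fourier_features(np.asarray(b, dtype=float), n_features, rng)
    u = u - u.mean(axis=0)
    v = v - v.mean(axis=0)
    cov = u.T @ v / (len(u) - 1)
    return float(np.sum(cov ** 2))
```

Applied to every pair of representation dimensions, such a statistic gives a matrix like the one visualized in Fig. 5; values near zero indicate approximately independent features.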

V-D4 Ablation Studies

Table II reports the effects of each sub-module of our SBRL-HAP by conducting ablation experiments on the Syn_16_16_16_2 dataset. The observations are as follows: (1) Each component of our SBRL-HAP is indispensable, since the absence of any one of them would hinder obtaining balanced and stable representations and damage the performance of HTE estimation on OOD data. (2) Compared to IR and BR, HAP has the greatest impact on the model's performance on OOD populations of Syn_16_16_16_2 data.

V-E Experiments on Real-world Data

V-E1 Datasets

We also conduct experiments on two real-world datasets, Twins and IHDP, which are widely used in HTE estimation literature [16, 21, 15].

Twins (http://www.nber.org/data/linked-birth-infant-death-data-vital-statistics-data.html). The Twins dataset originates from twin births in the USA between 1989 and 1991 [68]. The treatment corresponds to the twins' weight, where $t=1$ indicates the heavier twin and $t=0$ indicates the lighter one. The outcome corresponds to the twins' mortality after one year. We collect records of same-sex twins weighing less than 2000 g and without missing features, resulting in a total of 5271 records. The dataset consists of 43 variables $X=\{X_{1},X_{2},\ldots,X_{43}\}$, of which $X_{C}=\{X_{1},X_{2},\ldots,X_{28}\}$ are derived from the original data related to parents, pregnancy, and birth. 
In addition, 10 instrumental variables XI={X29,X30,,X38}subscript𝑋𝐼subscript𝑋29subscript𝑋30subscript𝑋38X_{I}=\{X_{29},X_{30},\ldots,X_{38}\}italic_X start_POSTSUBSCRIPT italic_I end_POSTSUBSCRIPT = { italic_X start_POSTSUBSCRIPT 29 end_POSTSUBSCRIPT , italic_X start_POSTSUBSCRIPT 30 end_POSTSUBSCRIPT , … , italic_X start_POSTSUBSCRIPT 38 end_POSTSUBSCRIPT } and 5 unstable variables XV={X39,X40,,X43}subscript𝑋𝑉subscript𝑋39subscript𝑋40subscript𝑋43X_{V}=\{X_{39},X_{40},\ldots,X_{43}\}italic_X start_POSTSUBSCRIPT italic_V end_POSTSUBSCRIPT = { italic_X start_POSTSUBSCRIPT 39 end_POSTSUBSCRIPT , italic_X start_POSTSUBSCRIPT 40 end_POSTSUBSCRIPT , … , italic_X start_POSTSUBSCRIPT 43 end_POSTSUBSCRIPT } are generated with normal distribution 𝒩(0,1)𝒩01\mathcal{N}(0,1)caligraphic_N ( 0 , 1 ). To simulate selection bias, treatment is assigned as follows: ti|xi(11+ez)similar-toconditionalsubscript𝑡𝑖subscript𝑥𝑖11superscript𝑒𝑧t_{i}|x_{i}\sim\mathcal{B}(\frac{1}{1+e^{-z}})italic_t start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ∼ caligraphic_B ( divide start_ARG 1 end_ARG start_ARG 1 + italic_e start_POSTSUPERSCRIPT - italic_z end_POSTSUPERSCRIPT end_ARG ), where z=wTXIC+η𝑧superscript𝑤𝑇subscript𝑋𝐼𝐶𝜂z=w^{T}X_{IC}+\etaitalic_z = italic_w start_POSTSUPERSCRIPT italic_T end_POSTSUPERSCRIPT italic_X start_POSTSUBSCRIPT italic_I italic_C end_POSTSUBSCRIPT + italic_η, w𝒰(0.1,0.1)similar-to𝑤𝒰0.10.1w\sim\mathcal{U}(-0.1,0.1)italic_w ∼ caligraphic_U ( - 0.1 , 0.1 ) and η𝒩(0,0.1)similar-to𝜂𝒩00.1\eta\sim\mathcal{N}(0,0.1)italic_η ∼ caligraphic_N ( 0 , 0.1 ). \mathcal{B}caligraphic_B denotes the Bernoulli distribution. 
Besides, to create distribution shift, we generate selection probabilities for each sample in the following way: Pr=XiXV|ρ|10DiPrsubscriptproductsubscript𝑋𝑖subscript𝑋𝑉superscript𝜌10subscript𝐷𝑖\Pr=\prod_{X_{i}\in X_{V}}|\rho|^{-10*D_{i}}roman_Pr = ∏ start_POSTSUBSCRIPT italic_X start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ∈ italic_X start_POSTSUBSCRIPT italic_V end_POSTSUBSCRIPT end_POSTSUBSCRIPT | italic_ρ | start_POSTSUPERSCRIPT - 10 ∗ italic_D start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_POSTSUPERSCRIPT, where Di=|Y1Y0sign(ρ)Xi|subscript𝐷𝑖subscript𝑌1subscript𝑌0sign𝜌subscript𝑋𝑖D_{i}=|Y_{1}-Y_{0}-\text{sign}(\rho)*X_{i}|italic_D start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT = | italic_Y start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT - italic_Y start_POSTSUBSCRIPT 0 end_POSTSUBSCRIPT - sign ( italic_ρ ) ∗ italic_X start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT |. Here, we set ρ=2.5𝜌2.5\rho=-2.5italic_ρ = - 2.5. Based on the sample probabilities, 20%percent2020\%20 % records are sampled as the testing set. Then, the rest data is randomly split into a training/validation set using a 70/30 ratio. Repeat the above data partitioning for 10 rounds to form the final dataset.
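The treatment assignment and selection probabilities described above can be sketched in NumPy as follows. This is only an illustration under stated assumptions: the covariates and potential outcomes are random placeholders rather than the real Twins records, $0.1$ in $\mathcal{N}(0,0.1)$ is read as a standard deviation, and the function name `selection_prob` is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5271  # number of Twins records reported in the text

# Placeholders standing in for the real covariates and potential outcomes.
X_C = rng.normal(size=(n, 28))     # stand-in for the 28 original covariates
X_I = rng.normal(size=(n, 10))     # instrumental variables ~ N(0, 1)
X_V = rng.normal(size=(n, 5))      # unstable variables ~ N(0, 1)
y0 = rng.binomial(1, 0.2, size=n)  # stand-in outcome for the lighter twin
y1 = rng.binomial(1, 0.2, size=n)  # stand-in outcome for the heavier twin

# Treatment assignment: t ~ Bernoulli(sigmoid(w^T X_IC + eta)).
X_IC = np.hstack([X_I, X_C])
w = rng.uniform(-0.1, 0.1, size=X_IC.shape[1])
eta = rng.normal(0.0, 0.1, size=n)  # assumption: 0.1 is the standard deviation
z = X_IC @ w + eta
t = rng.binomial(1, 1.0 / (1.0 + np.exp(-z)))


def selection_prob(y1, y0, X_shift, rho):
    """Pr = prod_i |rho| ** (-10 * D_i), with D_i = |y1 - y0 - sign(rho) * X_i|."""
    D = np.abs((y1 - y0)[:, None] - np.sign(rho) * X_shift)
    return np.prod(np.abs(rho) ** (-10.0 * D), axis=1)


pr = selection_prob(y1, y0, X_V, rho=-2.5)
```

Because $|\rho|>1$ and $D_i\ge 0$, every probability lies in $(0,1]$, and samples whose unstable variables track $\text{sign}(\rho)\,(Y_1-Y_0)$ closely receive the largest weights.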

TABLE II: Ablation experiments on the performance of each sub-module. (✓ refers to keeping the sub-module.)
BR ($\mathcal{L}_{\mathbf{B}}$) IR ($\mathcal{L}_{\mathbf{I}}$) HAP ($\mathcal{L}_{\mathbf{H}}$) PEHE ($\rho=2.5$) PEHE ($\rho=-3$)
✓ ✓ 0.457±0.006 0.594±0.002
✓ ✓ 0.502±0.007 0.584±0.006
✓ ✓ 0.439±0.006 0.662±0.015
✓ ✓ ✓ 0.460±0.007 0.591±0.004
  • * $\mathcal{L}_{\mathbf{H}}=\mathcal{L}_{\mathbf{D}}(\mathbf{Z}^{r},\mathbf{w})+\mathcal{L}_{\mathbf{D}}(\mathbf{Z}^{o},\mathbf{w})$.

IHDP (http://www.fredjo.com). This is a binary-treatment, continuous-outcome dataset generated from the Randomized Controlled Trial (RCT) data of the Infant Health and Development Program (IHDP) [58]. The RCT data was collected to evaluate the effect of specialist home visits on the cognitive test scores of premature infants. Hill induced selection bias by removing a biased subset of the treated group, and Shalit et al. simulated outcomes using setting "A" of the NPCI package [69]. This dataset contains 747 units (139 treated, 608 control) with 25 covariates (6 continuous, 19 discrete) related to children and mothers. To introduce distribution shift, we sample $10\%$ of the records as the testing set in a biased manner, with selection probabilities $\Pr=\prod_{X_i\in X_l}|\rho|^{-10*D_i}$, where $X_l$ are the continuous variables and $D_i=|Y_1-Y_0-\text{sign}(\rho)*X_i|$. The remaining $90\%$ of the records are randomly divided into training and validation sets in a 70/30 proportion.
Unlike Twins, where the unstable variables $X_V\sim\mathcal{N}(0,1)$ are synthetic, for IHDP we choose a subset of the original variables to introduce distribution shift. This creates a more complex scenario in which to verify the effectiveness of our method.
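The partitioning protocol (a biased draw of the testing set followed by a random 70/30 split of the remainder) might look like the sketch below. It assumes sampling without replacement with probabilities proportional to $\Pr$, which the text does not state explicitly; the helper name `ood_partition` and the random stand-in probabilities are hypothetical.

```python
import numpy as np


def ood_partition(pr, test_frac, train_frac=0.7, seed=0):
    """Draw a biased test set using weights pr, then split the rest at random."""
    rng = np.random.default_rng(seed)
    n = len(pr)
    n_test = int(round(test_frac * n))
    # Assumption: test records are drawn without replacement, proportional to pr.
    test_idx = rng.choice(n, size=n_test, replace=False, p=pr / pr.sum())
    rest = rng.permutation(np.setdiff1d(np.arange(n), test_idx))
    n_train = int(round(train_frac * len(rest)))
    return rest[:n_train], rest[n_train:], test_idx


# Stand-in selection probabilities for the 747 IHDP units.
pr = np.random.default_rng(1).uniform(0.01, 1.0, size=747)
train_idx, val_idx, test_idx = ood_partition(pr, test_frac=0.10)
```

Calling the function with a different `seed` per round would give the repeated partitions used to form replications.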

TABLE III: The results (mean±std) of treatment effect estimation on real-world data. Our methods significantly improve the accuracy of HTE estimation on the testing set, with comparable performance on the training set compared to the baselines.
Twins
Metric PEHE (Mean±Std) $\epsilon_{\text{ATE}}$ (Mean±Std)
Dataset Training Validation Testing Training Validation Testing
TARNet 0.313±0.010 0.342±0.014 0.630±0.012 0.024±0.005 0.028±0.007 0.355±0.007
+SBRL 0.309±0.011 0.336±0.014 0.621±0.009 0.026±0.004 0.031±0.006 0.348±0.004
+SBRL-HAP 0.236±0.006 0.239±0.007 0.547±0.003 0.057±0.001 0.056±0.002 0.321±0.002
CFR 0.294±0.013 0.313±0.018 0.613±0.012 0.024±0.004 0.025±0.005 0.352±0.005
+SBRL 0.287±0.014 0.307±0.018 0.611±0.013 0.020±0.005 0.023±0.006 0.356±0.006
+SBRL-HAP 0.236±0.005 0.238±0.007 0.547±0.003 0.056±0.001 0.056±0.002 0.321±0.001
DeR-CFR 0.229±0.002 0.229±0.003 0.585±0.009 0.041±0.013 0.040±0.013 0.385±0.013
+SBRL 0.229±0.002 0.229±0.003 0.584±0.009 0.040±0.013 0.039±0.013 0.384±0.013
+SBRL-HAP 0.236±0.002 0.236±0.004 0.552±0.006 0.048±0.010 0.047±0.011 0.330±0.011
IHDP
Metric PEHE (Mean±Std) $\epsilon_{\text{ATE}}$ (Mean±Std)
Dataset Training Validation Testing Training Validation Testing
TARNet 0.620±0.042 0.677±0.056 0.857±0.098 0.200±0.026 0.199±0.026 0.254±0.037
+SBRL 0.622±0.042 0.683±0.057 0.834±0.093 0.184±0.025 0.183±0.025 0.250±0.037
+SBRL-HAP 0.628±0.041 0.696±0.058 0.827±0.089 0.179±0.023 0.179±0.023 0.226±0.032
CFR 0.628±0.042 0.687±0.057 0.858±0.099 0.197±0.026 0.196±0.026 0.259±0.038
+SBRL 0.622±0.043 0.681±0.059 0.848±0.094 0.196±0.027 0.197±0.027 0.251±0.037
+SBRL-HAP 0.623±0.038 0.688±0.053 0.820±0.087 0.185±0.024 0.184±0.024 0.220±0.031
DeR-CFR 0.460±0.024 0.487±0.029 0.607±0.062 0.150±0.022 0.152±0.022 0.183±0.025
+SBRL 0.450±0.022 0.476±0.028 0.592±0.062 0.141±0.019 0.143±0.019 0.181±0.024
+SBRL-HAP 0.449±0.023 0.478±0.029 0.573±0.057 0.151±0.021 0.154±0.021 0.178±0.024

V-E2 Results

We report the mean and standard deviation (std) of the treatment effect estimation errors over 10 replications on Twins and 100 replications on IHDP in Table III. The results show that, in comparison with state-of-the-art methods, our methods achieve significantly better performance on the testing set, while avoiding overfitting and maintaining performance similar to the baseline methods on the training set. Especially on Twins, our proposed SBRL-HAP reduces the testing-set PEHE by $13.1\%$, $10.8\%$, and $5.6\%$ for TARNet, CFR, and DeR-CFR, respectively, and reduces the ATE bias by $9.6\%$, $8.8\%$, and $14.3\%$.
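The quoted percentages are relative reductions of the testing-set means in Table III. As a quick check, the following sketch recomputes them from the Twins rows; the dictionary layout and the helper name `relative_reduction` are illustrative choices, not part of the original code.

```python
# Testing-set means on Twins, copied from Table III: (baseline, +SBRL-HAP).
pehe = {"TARNet": (0.630, 0.547), "CFR": (0.613, 0.547), "DeR-CFR": (0.585, 0.552)}
ate = {"TARNet": (0.355, 0.321), "CFR": (0.352, 0.321), "DeR-CFR": (0.385, 0.330)}


def relative_reduction(base, new):
    """Percentage by which `new` is lower than `base`."""
    return 100.0 * (base - new) / base


pehe_gain = {m: relative_reduction(*v) for m, v in pehe.items()}
ate_gain = {m: relative_reduction(*v) for m, v in ate.items()}

for m in pehe:
    print(f"{m}: PEHE -{pehe_gain[m]:.2f}%, ATE bias -{ate_gain[m]:.2f}%")
```

The recomputed values agree with the percentages in the text to within rounding of the three-decimal table entries.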

Compared to synthetic datasets, our method still improves performance on real-world datasets, but the improvement is not consistently significant. From the characteristics and experimental results of the Twins and IHDP datasets, we make the following observations. During training, the hierarchical independence measure on the Twins dataset consistently remains significantly lower than on the other datasets employed. Since most parents made similar pregnancy preparations, the Twins dataset contains an abundance of similar or identical variables, so the distribution differences across environments are not highly significant. Although our algorithm addresses the OOD issue, the level of OOD is too low for the improvement to be remarkable. Similarly, due to limited distribution shift, the performance of our algorithm on the IHDP dataset improves by only $2.3\%\sim15.1\%$. Furthermore, for the IHDP dataset, we introduce a more complex covariate shift than the traditional settings [49, 53]: among the six continuous variables used for biased sampling, some may have a causal relationship with the outcome $Y$. Artificially introducing unstable correlations on these potentially stable features makes it difficult for the model to identify truly stable representations.

TABLE IV: Optimal hyper-parameters of CFR+SBRL-HAP.
Hyper-parameters Twins IHDP Syn_8_8_8_2 Syn_16_16_16_2
learning rate 1e-5 1e-3 1e-5 1e-4
batch norm 1 0 1 1
rep normalization 1 1 0 0
$\{d_r,d_y\}$ {3,3} {3,3} {3,3} {3,3}
$\{h_r,h_y\}$ {128,64} {256,128} {128,64} {128,64}
$\{\alpha,\lambda\}$ {1e-4,1e-4} {1,1e-4} {5e-2,1e-4} {1e-3,1e-4}
$\{\gamma_1,\gamma_2,\gamma_3\}$ {1,1,1e-1} {1e-1,1e-4,1e-4} {1,1,1e-1} {1,1e-3,1e-3}
  • * Set $\alpha$ to $0$ to get the optimal hyper-parameters of TARNet+SBRL-HAP.

TABLE V: Optimal hyper-parameters of DeR-CFR+SBRL-HAP.
Hyper-parameters Twins IHDP Syn_8_8_8_2 Syn_16_16_16_2
learning rate 1e-1 1e-3 1e-4 5e-4
batch norm 1 0 1 1
rep normalization 1 1 0 0
$\{d_r,d_y,d_t\}$ {3,3,2} {5,3,1} {2,2,3} {2,2,3}
$\{h_r,h_y,h_t\}$ {256,128,128} {32,256,128} {256,256,256} {256,256,256}
$\{\alpha,\beta,\gamma,\mu,\lambda\}$ {1e-2,5,1e-4,5,5} {10,5,1e-3,50,10} {1,1e-3,5,1,1} {1,1e-3,5,1,1}
$\{\gamma_1,\gamma_2,\gamma_3\}$ {1,1,1e-2} {1,1e-1,1e-2} {1,1e-2,1} {1,1e-2,1e-2}
  • * Refer to DeR-CFR [17] for the meaning of hyper-parameters $\{\alpha,\beta,\gamma,\mu,\lambda\}$.

V-F Hyper-parameter Analysis

Table IV and Table V list the optimal hyper-parameters of our SBRL-HAP for each dataset. Note that setting $\{\gamma_1,\gamma_2,\gamma_3\}$ to $0$ in Table IV and Table V gives the optimal hyper-parameters of our SBRL. Given that the hyper-parameters $\{\gamma_1,\gamma_2,\gamma_3\}$ determine the hierarchical attention for variable decorrelation, we investigate the impact of each of them on the model's performance and stability. As shown in Fig. 6, we report PEHE on the Syn_16_16_16_2 data with $\rho=2.5$ and $F_1$ scores of factual outcomes with $\rho=-3$ when varying $\{\gamma_1,\gamma_2,\gamma_3\}$ over $\{0,0.01,0.1,1,10,100\}$.
Since PEHE is higher under $\gamma_1=0$ than under $\gamma_1=100$, and lower under $\gamma_2=0$ than under $\gamma_2=100$, we conclude that it is better to give relatively more attention to the last layer of the model and comparatively less attention to the balanced representation layer. Besides, compared to $\gamma_1$ and $\gamma_2$, the impact of $\gamma_3$ on the model's performance and stability is more complex. This is because $\gamma_3$ controls attention to nearly all hidden layers, so slight modifications of $\gamma_3$ can result in significant changes in the entire loss. This analysis helps us identify the most suitable hyper-parameters for the experiments.
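A sensitivity study of this kind can be organized as a one-at-a-time sweep, sketched below. Two points are assumptions rather than statements from the text: that each $\gamma$ is varied while the other two stay fixed, and that the fixed values are the CFR+SBRL-HAP optimum for Syn_16_16_16_2 in Table IV. The function name `one_at_a_time` is hypothetical, and the training and evaluation step is deliberately left out.

```python
GRID = [0, 0.01, 0.1, 1, 10, 100]

# Assumed reference point: Table IV, Syn_16_16_16_2 column.
DEFAULT = {"gamma1": 1.0, "gamma2": 1e-3, "gamma3": 1e-3}


def one_at_a_time(default, grid):
    """Return (varied_name, config) pairs, changing one hyper-parameter per config."""
    configs = []
    for name in default:
        for value in grid:
            cfg = dict(default)  # copy so the reference point is never mutated
            cfg[name] = value
            configs.append((name, cfg))
    return configs


configs = one_at_a_time(DEFAULT, GRID)
print(len(configs))  # 3 hyper-parameters x 6 grid values
```

Each configuration would then be trained and scored (PEHE or $F_1$) to produce one point on the corresponding curve.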

V-G Training Cost Analysis

In our method, the network structure and the hierarchical-attention independence constraints are the primary contributors to the increased model complexity and training time. To investigate the cost of all methods, we run 10 replications on the IHDP dataset and report the average training time (in seconds) of a single execution in Table VI. Table VI indicates that our SBRL nearly doubles the training cost of TARNet and CFR (roughly 1.6 to 1.8 times). This is due to the additional training process for sample weights. Besides, our SBRL-HAP leads to over a 3-fold increase in the training time of TARNet and CFR, and a roughly 1.5-fold increase for DeR-CFR. This increase is primarily due to the hierarchical-attention optimization strategy. As the model complexity increases, both the accuracy and the stability of the model improve. Despite its higher computational time, our proposed method achieves the most stable and accurate treatment effect estimation. Moreover, the maximum training time of a single execution is less than 180 seconds, which is still acceptable.
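The fold increases quoted above follow directly from the timings in Table VI. A small sketch of the arithmetic, with the timings copied from that table and an illustrative dictionary layout:

```python
# Average training time (s) per execution on IHDP, copied from Table VI.
times = {
    "TARNet": {"base": 22.4, "+SBRL": 40.6, "+SBRL-HAP": 79.7},
    "CFR": {"base": 25.3, "+SBRL": 40.8, "+SBRL-HAP": 80.1},
    "DeR-CFR": {"base": 96.4, "+SBRL": 112.1, "+SBRL-HAP": 140.5},
}

# Ratio of each variant's time to its own backbone's time.
ratios = {
    m: {k: v[k] / v["base"] for k in ("+SBRL", "+SBRL-HAP")}
    for m, v in times.items()
}
max_time = max(max(v.values()) for v in times.values())

for m, r in ratios.items():
    print(f"{m}: +SBRL x{r['+SBRL']:.2f}, +SBRL-HAP x{r['+SBRL-HAP']:.2f}")
```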

Figure 6: Hyper-parameter sensitivity analysis on $\{\gamma_1,\gamma_2,\gamma_3\}$ within the specified range $\{0,0.01,0.1,1,10,100\}$ on the Syn_16_16_16_2 dataset. (a) The PEHE error with $\rho=2.5$. (b) The $F_1$ score on factual outcomes with $\rho=-3$. The reference red line indicates the optimal parameters for the setting.

Hardware configuration: CentOS Linux release 7.2 (Final) operating system with the AMD EPYC 7K62 48-Core CPU Processor, 1TB of RAM. Software configuration: Python 3.6.8 with TensorFlow 1.15.0, NumPy 1.19.5, Scikit-learn 0.24.2.

VI Conclusion and Future Work

In this paper, we are the first to study the problem of Heterogeneous Treatment Effect estimation across Out-of-Distribution populations. Previous causal methods have primarily concentrated on addressing selection bias within in-distribution data. However, in real-world applications, where distribution shifts are common, these methods may face challenges in effectively handling OOD data. To achieve more accurate HTE estimation on OOD data, we propose Stable Balanced Representation Learning with Hierarchical-Attention Paradigm (SBRL-HAP), which jointly addresses selection bias and distribution shift by synergistically optimizing a Balancing Regularizer and an Independence Regularizer in a Hierarchical-Attention Paradigm. One limitation is that when combining existing balanced representation methods with SBRL-HAP, the performance on in-distribution data may decrease compared to the vanilla methods. This is because vanilla methods rely on the inductive bias from unstable features to improve performance on in-distribution data, which does not generalize well to OOD populations. One potential way to balance stability and performance is to incorporate a module that measures the OOD level between the target domain and the source domain. Based on the measured OOD level, it would be feasible to use interpolation or spline methods to combine our algorithm with conventional supervised learning, which is left to future work.

TABLE VI: Training time (s) of various methods in a single execution on the IHDP dataset.
Method TARNet +SBRL +SBRL-HAP
Time (s) 22.4 40.6 79.7
Method CFR +SBRL +SBRL-HAP
Time (s) 25.3 40.8 80.1
Method DeR-CFR +SBRL +SBRL-HAP
Time (s) 96.4 112.1 140.5

Acknowledgment

This work was supported in part by National Key Research and Development Project of China (Grant No. 2023YFF0905502), Shenzhen Science and Technology Program (Grant No. RCYX20200714114523079 and JCYJ20220818101014030) and National Natural Science Foundation of China (Grant No. 62376243 and U20A20387). We would like to thank Tencent for supporting the research during Yuling Zhang’s internship. We also would like to thank Kuaishou for sponsoring the research. Anpeng Wu’s research was supported by the China Scholarship Council.

References

  • [1] Z. Chu, R. Li, S. Rathbun, and S. Li, “Continual Causal Inference with Incremental Observational Data,” Mar. 2023.
  • [2] M. Ai, B. Li, H. Gong, Q. Yu, S. Xue, Y. Zhang, Y. Zhang, and P. Jiang, “LBCF: A Large-Scale Budget-Constrained Causal Forest Algorithm,” in Proceedings of the ACM Web Conference 2022, ser. WWW ’22.   New York, NY, USA: Association for Computing Machinery, Apr. 2022, pp. 2310–2319.
  • [3] Z. Tan, S. Zhang, N. Hong, K. Kuang, Y. Yu, J. Yu, Z. Zhao, H. Yang, S. Pan, J. Zhou, and F. Wu, “Uncovering Causal Effects of Online Short Videos on Consumer Behaviors,” in Proceedings of the Fifteenth ACM International Conference on Web Search and Data Mining, ser. WSDM ’22.   New York, NY, USA: Association for Computing Machinery, Feb. 2022, pp. 997–1006.
  • [4] Y. Meng, S. Zhang, Z. Ye, B. Wang, Z. Wang, Y. Sun, Q. Liu, S. Yang, and D. Pei, “Causal Analysis of the Unsatisfying Experience in Realtime Mobile Multiplayer Games in the Wild,” in 2019 IEEE International Conference on Multimedia and Expo (ICME), Jul. 2019, pp. 1870–1875.
  • [5] S. Wager and S. Athey, “Estimation and Inference of Heterogeneous Treatment Effects using Random Forests,” Journal of the American Statistical Association, vol. 113, no. 523, pp. 1228–1242, Jul. 2018.
  • [6] J. Pearl, Causality. Cambridge University Press, 2009.
  • [7] Z. Wang, X. Chen, R. Zhou, Q. Dai, Z. Dong, and J.-R. Wen, “Sequential Recommendation with User Causal Behavior Discovery,” in 2023 IEEE 39th International Conference on Data Engineering (ICDE), Apr. 2023, pp. 28–40.
  • [8] F. Zhu, M. Zhong, X. Yang, L. Li, L. Yu, T. Zhang, J. Zhou, C. Chen, F. Wu, G. Liu, and Y. Wang, “DCMT: A Direct Entire-Space Causal Multi-Task Framework for Post-Click Conversion Estimation,” in 2023 IEEE 39th International Conference on Data Engineering (ICDE), Apr. 2023, pp. 3113–3125.
  • [9] B. Youngmann, M. Cafarella, Y. Moskovitch, and B. Salimi, “On Explaining Confounding Bias,” in 2023 IEEE 39th International Conference on Data Engineering (ICDE), Apr. 2023, pp. 1846–1859.
  • [10] F. Shen, K. Heravi, O. Gomez, S. Galhotra, A. Gilad, S. Roy, and B. Salimi, “Causal What-If and How-To Analysis Using HypeR,” in 2023 IEEE 39th International Conference on Data Engineering (ICDE), Apr. 2023, pp. 3663–3666.
  • [11] P. R. Rosenbaum and D. B. Rubin, “The central role of the propensity score in observational studies for causal effects,” Biometrika, vol. 70, no. 1, pp. 41–55, 1983.
  • [12] P. R. Rosenbaum, “Model-Based Direct Adjustment,” Journal of the American Statistical Association, vol. 82, no. 398, pp. 387–394, Jun. 1987.
  • [13] S. Li, N. Vlassis, J. Kawale, and Y. Fu, “Matching via dimensionality reduction for estimation of treatment effects in digital marketing campaigns,” in Proceedings of the Twenty-Fifth International Joint Conference on Artificial Intelligence, ser. IJCAI’16.   New York, New York, USA: AAAI Press, Jul. 2016, pp. 3768–3774.
  • [14] L. Yao, Z. Chu, S. Li, Y. Li, J. Gao, and A. Zhang, “A Survey on Causal Inference,” ACM Transactions on Knowledge Discovery from Data, vol. 15, no. 5, pp. 1–46, Oct. 2021.
  • [15] U. Shalit, F. D. Johansson, and D. Sontag, “Estimating individual treatment effect: Generalization bounds and algorithms,” in Proceedings of the 34th International Conference on Machine Learning.   PMLR, Jul. 2017, pp. 3076–3085.
  • [16] N. Hassanpour and R. Greiner, “CounterFactual Regression with Importance Sampling Weights,” in Proceedings of the Twenty-Eighth International Joint Conference on Artificial Intelligence.   Macao, China: International Joint Conferences on Artificial Intelligence Organization, Aug. 2019, pp. 5880–5887.
  • [17] A. Wu, J. Yuan, K. Kuang, B. Li, R. Wu, Q. Zhu, Y. T. Zhuang, and F. Wu, “Learning Decomposed Representations for Treatment Effect Estimation,” IEEE Transactions on Knowledge and Data Engineering, pp. 1–1, 2022.
  • [18] N. Hassanpour and R. Greiner, “Learning Disentangled Representations for CounterFactual Regression,” in International Conference on Learning Representations, Mar. 2020.
  • [19] P. Schwab, L. Linhardt, S. Bauer, J. M. Buhmann, and W. Karlen, “Learning Counterfactual Representations for Estimating Individual Dose-Response Curves,” Proceedings of the AAAI Conference on Artificial Intelligence, vol. 34, no. 04, pp. 5612–5619, Apr. 2020.
  • [20] L. Yao, Y. Li, S. Li, M. Huai, J. Gao, and A. Zhang, “SCI: Subspace Learning Based Counterfactual Inference for Individual Treatment Effect Estimation,” in Proceedings of the 30th ACM International Conference on Information & Knowledge Management, ser. CIKM ’21.   New York, NY, USA: Association for Computing Machinery, Oct. 2021, pp. 3583–3587.
  • [21] L. Yao, S. Li, Y. Li, M. Huai, J. Gao, and A. Zhang, “ACE: Adaptively Similarity-Preserved Representation Learning for Individual Treatment Effect Estimation,” in 2019 IEEE International Conference on Data Mining (ICDM), Nov. 2019, pp. 1432–1437.
  • [22] ——, “Representation Learning for Treatment Effect Estimation from Observational Data,” in Advances in Neural Information Processing Systems, vol. 31.   Curran Associates, Inc., 2018.
  • [23] Y. Zhang, C. Li, I. W. Tsang, H. Xu, L. Duan, H. Yin, W. Li, and J. Shao, “Diverse Preference Augmentation with Multiple Domains for Cold-start Recommendations,” in 2022 IEEE 38th International Conference on Data Engineering (ICDE), May 2022, pp. 2942–2955.
  • [24] J. Cao, J. Sheng, X. Cong, T. Liu, and B. Wang, “Cross-Domain Recommendation to Cold-Start Users via Variational Information Bottleneck,” in 2022 IEEE 38th International Conference on Data Engineering (ICDE), May 2022, pp. 2209–2223.
  • [25] S. Zhou, L. Wang, S. Zhang, Z. Wang, and W. Zhu, “Active Gradual Domain Adaptation: Dataset and Approach,” IEEE Transactions on Multimedia, vol. 24, pp. 1210–1220, 2022.
  • [26] S. Zhou, H. Zhao, S. Zhang, L. Wang, H. Chang, Z. Wang, and W. Zhu, “Online Continual Adaptation with Active Self-Training,” in Proceedings of The 25th International Conference on Artificial Intelligence and Statistics.   PMLR, May 2022, pp. 8852–8883.
  • [27] H. Fang, B. Chen, X. Wang, Z. Wang, and S.-T. Xia, “GIFD: A Generative Gradient Inversion Method with Feature Domain Optimization,” in Proceedings of the IEEE/CVF International Conference on Computer Vision, 2023, pp. 4967–4976.
  • [28] H. Wu, Y. Yan, G. Lin, M. Yang, M. K. Ng, and Q. Wu, “Iterative Refinement for Multi-Source Visual Domain Adaptation (Extended abstract),” in 2023 IEEE 39th International Conference on Data Engineering (ICDE), Apr. 2023, pp. 3829–3830.
  • [29] Y. Yan, H. Wu, Y. Ye, C. Bi, M. Lu, D. Liu, Q. Wu, and M. K. Ng, “Transferable Feature Selection for Unsupervised Domain Adaptation: Extended Abstract,” in 2023 IEEE 39th International Conference on Data Engineering (ICDE), Apr. 2023, pp. 3855–3856.
  • [30] C. Chen, J. Xiao, J. Liu, J. Zhang, J. Jia, and N. Hu, “Unsupervised Intra-Domain Adaptation for Recommendation via Uncertainty Minimization,” in 2023 IEEE 39th International Conference on Data Engineering Workshops (ICDEW), Apr. 2023, pp. 79–86.
  • [31] Z. Chen, T. Xiao, and K. Kuang, “BA-GNN: On Learning Bias-Aware Graph Neural Network,” in 2022 IEEE 38th International Conference on Data Engineering (ICDE), May 2022, pp. 3012–3024.
  • [32] J. Yuan, X. Ma, D. Chen, K. Kuang, F. Wu, and L. Lin, “Label-Efficient Domain Generalization via Collaborative Exploration and Generalization,” in Proceedings of the 30th ACM International Conference on Multimedia, ser. MM ’22.   New York, NY, USA: Association for Computing Machinery, Oct. 2022, pp. 2361–2370.
  • [33] M. Sugiyama, M. Krauledat, and K.-R. Müller, “Covariate Shift Adaptation by Importance Weighted Cross Validation,” The Journal of Machine Learning Research, vol. 8, pp. 985–1005, Dec. 2007.
  • [34] H. Shimodaira, “Improving predictive inference under covariate shift by weighting the log-likelihood function,” Journal of Statistical Planning and Inference, vol. 90, no. 2, pp. 227–244, Oct. 2000.
  • [35] P. Cui and S. Athey, “Stable learning establishes some common ground between causal inference and machine learning,” Nature Machine Intelligence, vol. 4, no. 2, pp. 110–115, Feb. 2022.
  • [36] H. Wang, Z. He, Z. C. Lipton, and E. P. Xing, “Learning Robust Representations by Projecting Superficial Statistics Out,” Mar. 2019.
  • [37] K. Muandet, D. Balduzzi, and B. Schölkopf, “Domain Generalization via Invariant Feature Representation,” in Proceedings of the 30th International Conference on Machine Learning.   PMLR, Feb. 2013, pp. 10–18.
  • [38] S. Fan, X. Wang, C. Shi, P. Cui, and B. Wang, “Generalizing Graph Neural Networks on Out-Of-Distribution Graphs,” Nov. 2021.
  • [39] X. Zhang, P. Cui, R. Xu, L. Zhou, Y. He, and Z. Shen, “Deep Stable Learning for Out-of-Distribution Generalization,” in Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2021, pp. 5372–5382.
  • [40] B. M. Lake, T. D. Ullman, J. B. Tenenbaum, and S. J. Gershman, “Building machines that learn and think like people,” Behavioral and Brain Sciences, vol. 40, p. e253, Jan. 2017.
  • [41] A. Wu, K. Kuang, R. Xiong, M. Zhu, Y. Liu, B. Li, F. Liu, Z. Wang, and F. Wu, “Learning Instrumental Variable from Data Fusion for Treatment Effect Estimation,” Proceedings of the AAAI Conference on Artificial Intelligence, vol. 37, no. 9, pp. 10324–10332, Jun. 2023.
  • [42] F. Johansson, U. Shalit, and D. Sontag, “Learning Representations for Counterfactual Inference,” in Proceedings of The 33rd International Conference on Machine Learning.   PMLR, Jun. 2016, pp. 3020–3029.
  • [43] F. D. Johansson, N. Kallus, U. Shalit, and D. Sontag, “Learning Weighted Representations for Generalization Across Designs,” Feb. 2018.
  • [44] Y. Chang and J. Dy, “Informative Subspace Learning for Counterfactual Inference,” Proceedings of the AAAI Conference on Artificial Intelligence, vol. 31, no. 1, Feb. 2017.
  • [45] K. Kuang, P. Cui, B. Li, M. Jiang, and S. Yang, “Estimating Treatment Effect in the Wild via Differentiated Confounder Balancing,” in Proceedings of the 23rd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, ser. KDD ’17.   New York, NY, USA: Association for Computing Machinery, Aug. 2017, pp. 265–274.
  • [46] S. Athey, G. W. Imbens, and S. Wager, “Approximate Residual Balancing: De-Biased Inference of Average Treatment Effects in High Dimensions,” Jan. 2018.
  • [47] J. Hainmueller, “Entropy Balancing for Causal Effects: A Multivariate Reweighting Method to Produce Balanced Samples in Observational Studies,” Political Analysis, vol. 20, no. 1, pp. 25–46, 2012.
  • [48] Z. Shen, P. Cui, K. Kuang, B. Li, and P. Chen, “Causally Regularized Learning with Agnostic Data Selection Bias,” in Proceedings of the 26th ACM International Conference on Multimedia, ser. MM ’18.   New York, NY, USA: Association for Computing Machinery, Oct. 2018, pp. 411–419.
  • [49] K. Kuang, R. Xiong, P. Cui, S. Athey, and B. Li, “Stable Prediction with Model Misspecification and Agnostic Distribution Shift,” Proceedings of the AAAI Conference on Artificial Intelligence, vol. 34, no. 04, pp. 4485–4492, Apr. 2020.
  • [50] Z. Shen, P. Cui, T. Zhang, and K. Kuang, “Stable Learning via Sample Reweighting,” Proceedings of the AAAI Conference on Artificial Intelligence, vol. 34, no. 04, pp. 5692–5699, Apr. 2020.
  • [51] G. W. Imbens and D. B. Rubin, Causal Inference in Statistics, Social, and Biomedical Sciences.   Cambridge University Press, 2015.
  • [52] J. Liu, Z. Hu, P. Cui, B. Li, and Z. Shen, “Heterogeneous Risk Minimization,” in Proceedings of the 38th International Conference on Machine Learning.   PMLR, Jul. 2021, pp. 6804–6814.
  • [53] R. Xu, X. Zhang, Z. Shen, T. Zhang, and P. Cui, “A Theoretical Analysis on Independence-driven Importance Weighting for Covariate-shift Generalization,” Jul. 2022.
  • [54] B. K. Sriperumbudur, K. Fukumizu, A. Gretton, B. Schölkopf, and G. R. G. Lanckriet, “On integral probability metrics, $\phi$-divergences and binary classification,” Oct. 2009.
  • [55] A. Müller, “Integral Probability Metrics and Their Generating Classes of Functions,” Advances in Applied Probability, vol. 29, no. 2, pp. 429–443, Jun. 1997.
  • [56] A. Gretton, K. Fukumizu, C. Teo, L. Song, B. Schölkopf, and A. Smola, “A Kernel Statistical Test of Independence,” in Advances in Neural Information Processing Systems, vol. 20.   Curran Associates, Inc., 2007.
  • [57] E. V. Strobl, K. Zhang, and S. Visweswaran, “Approximate Kernel-Based Conditional Independence Tests for Fast Non-Parametric Causal Discovery,” Journal of Causal Inference, vol. 7, no. 1, Mar. 2019.
  • [58] J. L. Hill, “Bayesian Nonparametric Modeling for Causal Inference,” Journal of Computational and Graphical Statistics, vol. 20, no. 1, pp. 217–240, Jan. 2011.
  • [59] H. R. Künsch, “The Jackknife and the Bootstrap for General Stationary Observations,” The Annals of Statistics, vol. 17, no. 3, pp. 1217–1241, 1989.
  • [60] J. Duchi, E. Hazan, and Y. Singer, “Adaptive Subgradient Methods for Online Learning and Stochastic Optimization,” The Journal of Machine Learning Research, vol. 12, pp. 2121–2159, Jul. 2011.
  • [61] J. Bergstra and Y. Bengio, “Random search for hyper-parameter optimization,” The Journal of Machine Learning Research, vol. 13, pp. 281–305, Feb. 2012.
  • [62] Y. Zhang, X. Wang, J. Liang, Z. Zhang, L. Wang, R. Jin, and T. Tan, “Free Lunch for Domain Adversarial Training: Environment Label Smoothing,” Jan. 2023.
  • [63] X. Tan, L. Yong, S. Zhu, C. Qu, X. Qiu, X. Yinghui, P. Cui, and Y. Qi, “Provably Invariant Learning without Domain Information,” in Proceedings of the 40th International Conference on Machine Learning.   PMLR, Jul. 2023, pp. 33 563–33 580.
  • [64] D. Krueger, E. Caballero, J.-H. Jacobsen, A. Zhang, J. Binas, D. Zhang, R. L. Priol, and A. Courville, “Out-of-Distribution Generalization via Risk Extrapolation (REx),” in Proceedings of the 38th International Conference on Machine Learning.   PMLR, Jul. 2021, pp. 5815–5826.
  • [65] Y.-F. Zhang, J. Wang, J. Liang, Z. Zhang, B. Yu, L. Wang, D. Tao, and X. Xie, “Domain-Specific Risk Minimization for Domain Generalization,” in Proceedings of the 29th ACM SIGKDD Conference on Knowledge Discovery and Data Mining, ser. KDD ’23.   New York, NY, USA: Association for Computing Machinery, Aug. 2023, pp. 3409–3421.
  • [66] X. Zhou, Y. Lin, W. Zhang, and T. Zhang, “Sparse Invariant Risk Minimization,” in Proceedings of the 39th International Conference on Machine Learning.   PMLR, Jun. 2022, pp. 27 222–27 244.
  • [67] M. Zhang, J. Yuan, Y. He, W. Li, Z. Chen, and K. Kuang, “MAP: Towards Balanced Generalization of IID and OOD through Model-Agnostic Adapters,” in Proceedings of the IEEE/CVF International Conference on Computer Vision, 2023, pp. 11 921–11 931.
  • [68] D. Almond, K. Y. Chay, and D. S. Lee, “The Costs of Low Birth Weight,” The Quarterly Journal of Economics, vol. 120, no. 3, pp. 1031–1083, Aug. 2005.
  • [69] V. Dorie, “Vdorie/npci,” May 2023.