Proxy Methods for Domain Adaptation
Google Research
Stanford University
University College London
Google DeepMind
Gatsby Computational Neuroscience Unit
Abstract
We study the problem of domain adaptation under distribution shift, where the shift is due to a change in the distribution of an unobserved, latent variable that confounds both the covariates and the labels. In this setting, neither the covariate shift nor the label shift assumptions apply. Our approach to adaptation employs proximal causal learning, a technique for estimating causal effects in settings where proxies of unobserved confounders are available. We demonstrate that proxy variables allow for adaptation to distribution shift without explicitly recovering or modeling latent variables. We consider two settings, (i) Concept Bottleneck: an additional “concept” variable is observed that mediates the relationship between the covariates and labels; (ii) Multi-domain: training data from multiple source domains is available, where each source domain exhibits a different distribution over the latent confounder. We develop a two-stage kernel estimation approach to adapt to complex distribution shifts in both settings. In our experiments, we show that our approach outperforms other methods, notably those which explicitly recover the latent confounder.
1 Introduction
The goal of domain adaptation is to transfer an accurate model from a labeled source domain to an unlabeled target domain, which has a different but related distribution (pan2010domain; koh2021wilds; malinin2021shifts). It is motivated by the fact that labeling data is often labor-intensive and sometimes requires domain expertise. For example, the distribution of patients diagnosed with a condition at hospital A and hospital B may differ due to patients’ socioeconomic status, demographics, and other factors. However, labeled data might only be available at hospital A and not at hospital B (e.g., due to less funding). As a result, a model that is accurate for patients from hospital A may perform poorly for patients from hospital B.
In order to provide guarantees on the accuracy of a transferred model, one of two classical assumptions has typically been made: label shift or covariate shift. Label shift (buck1966comparison; lipton2018detecting) assumes that the distribution of the label $Y$ shifts between source and target domains, but the conditional distribution $p(X \mid Y)$ does not. Conversely, covariate shift (shimodaira2000improving) assumes that the covariate distribution $p(X)$ shifts between domains, but the distribution $p(Y \mid X)$ stays the same. Each assumption provides theoretical guarantees on the generalization of a transferred classifier. Indeed, without any assumptions, the source and target domains could differ arbitrarily, making guarantees impossible. However, these assumptions are often too restrictive to apply in real-world settings (zhang2015multi; schrouff2022diagnosing). For instance, if covariates $X$ and labels $Y$ are confounded by a third variable $U$, it is possible for neither $p(X \mid Y)$ nor $p(Y \mid X)$ to be equal across domains. For example, demographic information $U$ could confound the relationship between a diagnosis $Y$ and a radiological image $X$. In this example, if two hospitals have different distributions over demographics, both label shift and covariate shift adaptation methods may fail to transfer a classifier across hospitals.
To address this, recent work has introduced a latent shift assumption: the distribution of $U$, an unobserved latent confounder of $X$ and $Y$, shifts between the source and target domain (alabdulmohsin2023adapting). In this setting, all distributions of $X$ and $Y$ (without conditioning on $U$) may differ across the domains, violating label and covariate shift assumptions.
Contributions. We propose techniques for domain adaptation under the latent shift assumption that are guaranteed to identify the optimal predictor $\mathbb{E}_q[Y \mid X]$ in the target domain. We make use of proxy methods (miao2018identifying), a recently developed framework for causal effect estimation in the presence of a hidden confounder $U$, given indirect proxy information on $U$. Compared to prior work (alabdulmohsin2023adapting), our techniques do not require identifying the distribution of the latent variable $U$, that $U$ be discrete, or further linear independence assumptions. We consider two settings: (1) Concept Bottleneck: we observe in both domains a proxy $W$ of the unobserved confounder and a concept $C$ that mediates the direct relationship between $X$ and $Y$ (alabdulmohsin2023adapting); or (2) Multi-Domain: we do not observe $C$ in either domain, but have access to observations $(X, Y, W)$ from multiple source domains. For both settings, we provide guarantees for identifying $\mathbb{E}_q[Y \mid X]$ without observing $Y$ in the target domain. When $\mathbb{E}_q[Y \mid X]$ is identifiable, we develop practical two-stage kernel estimators to perform adaptation.
2 Related Work
The development of techniques for learning robust models and adapting to distribution shift has a long history in machine learning, but recently has received increased attention (shen2021towards; zhou2022domain; wang2022generalizing).
Causality for domain adaptation. Our work is inspired by techniques that formulate the covariate/label shift settings as assumptions on the causal structure for domain adaptation and distributional robustness (e.g., scholkopf2012causal; peters2015causal; zhang2015multi; subbaswamy2019preventing; rothenhausler2021anchor; veitch2021counterfactual; magliacane2018domain; arjovsky2019invariant; ganin2016domain; ben2010theory; oberst2021regularizing).
Proximal causal inference. Our identification technique is inspired by approaches used to identify causal effects with unobserved confounding with observed proxies (kuroki2014measurement; miao2018identifying; deaner2018proxy; tchetgen2020introduction; mastouri2021proximal; cui2023semiparametric; xu2023kernel). These approaches design ‘bridge functions’ to connect quantities involving a proxy $W$ with those of the label $Y$. The key advantage of this approach is that these bridge functions implicitly marginalize over $U$. This allows these approaches to identify causal quantities without identifying distributions involving $U$.
Latent shift. Our work is most closely related to alabdulmohsin2023adapting, who introduced the setting of latent shift with proxies $W$ and concepts $C$. They showed that the optimal predictor $\mathbb{E}_q[Y \mid X]$ is identifiable in the target domain if $W$ and $C$ are observed in the source domain and $X$ is observed in the target domain. To do so, they required (a) identification of distributions involving $U$, (b) that $U$ is a discrete variable, (c) knowledge of the dimensionality of $U$, and (d) additional linear independence assumptions. In contrast, our work derives identification results for arbitrary $U$ and does not require any of (a)–(d). However, there is no free lunch: to achieve this, we require that proxies $W$ are observed in the target, and either that (i) concepts $C$ are also observed in the target, or (ii) we observe multiple source domains. For (ii) we do not require $C$ in either the source or the target, but for full identification we require that $U$ is discrete.
3 Problem Framework
Let $p$ and $q$ denote the probability distribution functions of the source domain and target domain, respectively. We use the subscripts $p$ and $q$ to indicate source and target quantities (e.g., $\mathbb{E}_p$ and $\mathbb{E}_q$). Our goal is to study identification and estimation of the optimal target predictor $\mathbb{E}_q[Y \mid X]$ when $Y$ is not observed in the target domain.
Concept Bottleneck. The first setting we study is described by the graph in Figure 1(c). We have two additional variables: (i) proxies $W$, which provide auxiliary information about $U$, or can be seen as a noisy version of it (kuroki2014measurement), and (ii) concepts $C$, which mediate or ‘bottleneck’ the relationship between the covariates $X$ and labels $Y$ (goyal2019explaining; koh2020concept). For example, koh2020concept describe a setting where the concepts $C$ are high-level clinical and morphological features of a knee X-ray $X$, which mediate the relationship with osteoarthritis severity $Y$. In this example, $U$ could describe demographic variations that alter symptoms $C$ and outcome $Y$, and the proxies $W$ could include patient background and clinical history (e.g., prior diagnoses, medications, procedures). For the source domain we assume we observe $(W, X, C, Y)$, and for the target domain we observe $(W, X, C)$.
We formalize the notion of latent shift, as introduced in alabdulmohsin2023adapting.
Assumption 1 (Concept Bottleneck, alabdulmohsin2023adapting).
The shift between $p$ and $q$ is located in the unobserved $U$, i.e., there is a latent shift $p(U) \neq q(U)$, but $p(V \mid U) = q(V \mid U)$, where $V = (W, X, C, Y)$.
This assumption states that the conditional distribution of every variable given $U$ is invariant across domains. However, as $p(U) \neq q(U)$, in general none of the marginal distributions are invariant: $p(V') \neq q(V')$ for $V' \subseteq V$. This assumption is a generalization of covariate shift (shimodaira2000improving) and label shift (buck1966comparison), with associated graphs in Figure 1(a)–1(b).
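To make the latent shift concrete, the following minimal sketch builds a toy discrete model with binary $U$, $X$ and $Y$ (a simplified graph $U \to X$, $U \to Y$, $X \to Y$ without $W$ and $C$), in which $p(X \mid U)$ and $p(Y \mid U, X)$ are shared across domains and only the latent marginal changes. All probabilities and helper names (`joint`, `y_given_x`, `x_given_y`) are hypothetical choices for illustration. Under these assumptions, both $p(Y \mid X)$ and $p(X \mid Y)$ differ between the two domains, so neither the covariate shift nor the label shift assumption holds.

```python
import numpy as np

# Hypothetical toy model (all numbers are illustrative): U -> X, U -> Y, X -> Y.
# Only the latent marginal changes between domains; p(X | U) and p(Y | U, X) are shared.
p_x_given_u = np.array([[0.8, 0.2],    # row u = 0: p(X = 0 | u), p(X = 1 | u)
                        [0.3, 0.7]])   # row u = 1
p_y1_given_ux = np.array([[0.1, 0.4],  # P(Y = 1 | u = 0, x = 0), P(Y = 1 | u = 0, x = 1)
                          [0.6, 0.9]]) # P(Y = 1 | u = 1, x)

def joint(p_u):
    """Joint table J[u, x, y] = p(u) p(x | u) p(y | u, x)."""
    p_y_given_ux = np.stack([1 - p_y1_given_ux, p_y1_given_ux], axis=-1)  # [u, x, y]
    return p_u[:, None, None] * p_x_given_u[:, :, None] * p_y_given_ux

def y_given_x(J):
    p_xy = J.sum(axis=0)                           # marginalize U -> p(x, y)
    return p_xy / p_xy.sum(axis=1, keepdims=True)  # rows: p(y | x)

def x_given_y(J):
    p_xy = J.sum(axis=0)
    return p_xy / p_xy.sum(axis=0, keepdims=True)  # columns: p(x | y)

J_source = joint(np.array([0.7, 0.3]))  # p(U)
J_target = joint(np.array([0.2, 0.8]))  # q(U) differs: a latent shift

print("p(Y=1|x):", y_given_x(J_source)[:, 1], "q(Y=1|x):", y_given_x(J_target)[:, 1])
```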
Assumption 2 (Structural assumption).
Graphs in Figure 1 are faithful and Markov (spirtes2000causation).
Under Assumption 2, we have the following conditional independence properties for the graph in Figure 1(c):
$$W \perp\!\!\!\perp (X, C) \mid U, \qquad Y \perp\!\!\!\perp (X, W) \mid (U, C).$$
With this conditional independence structure, $(U, C)$ blocks the information flow from $X$ to $Y$, and $U$ blocks the information flow from $(X, C)$ to $W$. We will see in Section 4 that these assumptions allow us to obtain $\mathbb{E}_q[Y \mid X]$ from $q(W, C \mid X)$ in the target domain, where the latter is a function of observed quantities.
Multi-domain. In the second setting, suppose we do not observe the concepts $C$ in any domain, but instead observe data from multiple source domains, according to the graph in Figure 1(d). For instance, we may want to learn a classifier for a target hospital that has only unlabelled data, using data from several source hospitals with labelled data. Here, let $Z$ be a random variable in $\mathcal{Z}$ denoting a prior over the source domains, and let $p(U \mid Z)$ be the distribution of $U$ given $Z$. We make $k_Z$ draws from $Z$, indexed by $r$, and write them as $z_1, \dots, z_{k_Z}$. For each source domain $r \in \{1, \dots, k_Z\}$, we observe $(X, Y, W) \sim p(X, Y, W \mid Z = z_r)$. For the target, we denote it with index $q$ and only observe $(X, W)$. In general, let $p_r(\cdot) = p(\cdot \mid Z = z_r)$ and $\mathbb{E}_r[\cdot] = \mathbb{E}[\cdot \mid Z = z_r]$ for any $r$. For this setting we replace Assumption 1 with the following shift assumption.
Assumption 3 (Multi-Domain).
For each $r, r'$ such that $z_r \neq z_{r'}$, we have $p_r(U) \neq p_{r'}(U)$ and $p_r(W, X, Y \mid U) = p_{r'}(W, X, Y \mid U)$.
4 Identification under Latent Shifts
Our identification techniques are inspired by proximal causal inference (tchetgen2020introduction). The key idea is to design so-called “bridge” functions to identify distributions confounded by unobserved variables. We first show that, with additional proxies $W$ and concepts $C$, $\mathbb{E}_q[Y \mid X]$ is identifiable under any latent shift.
4.1 Identification with Concepts
To prove identifiability, we need certain assumptions to hold for the shift. The first is a regularity assumption, also known as a completeness condition, and is commonly used to identify causal estimands (d2011completeness; miao2018identifying).
Assumption 4 (Informative variables).
Let $g$ be any mean squared integrable function. Both the source domain and the target domain, $r \in \{p, q\}$, satisfy $\mathbb{E}_r[g(U) \mid x, c] = 0$ for all $(x, c)$ if and only if $g(U) = 0$ almost surely with respect to $r(U)$.
At a high level, completeness states that $(X, C)$ must have sufficient variability related to the change of $U$. This is a common assumption made in proximal causal inference (cf. Condition (ii) in miao2018identifying and Assumption 3 in mastouri2021proximal). For more details on the justification of the completeness assumption, see the supplementary material of miao2022identifying.
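In the fully discrete case, completeness reduces to a rank condition: $\mathbb{E}[g(U) \mid x, c] = \sum_u p(u \mid x, c)\, g(u)$ is a matrix-vector product, so only $g = 0$ makes it vanish for all $(x, c)$ exactly when the matrix with rows $p(\cdot \mid x, c)$ has full column rank. The sketch below checks this with NumPy, assuming every value of $U$ has positive probability; rows index the conditioning values $(x, c)$, and the matrices and function names (`is_complete`, `null_direction`) are hypothetical.

```python
import numpy as np

def is_complete(p_u_given_x, tol=1e-10):
    """Discrete completeness check (a sketch).

    p_u_given_x[x, u] = p(u | x). Since E[g(U) | x] = sum_u p(u | x) g(u) = (P @ g)[x],
    the only g with E[g(U) | x] = 0 for every x is g = 0 exactly when P has
    full column rank (assuming every u has positive probability).
    """
    return np.linalg.matrix_rank(p_u_given_x, tol=tol) == p_u_given_x.shape[1]

def null_direction(p_u_given_x, tol=1e-10):
    """Return a unit-norm g != 0 with E[g(U) | x] = 0 for all x, or None if complete."""
    _, s, vt = np.linalg.svd(p_u_given_x)
    if int(np.sum(s > tol)) == p_u_given_x.shape[1]:
        return None
    return vt[-1]  # right singular vector lying in the null space of P

# Three conditioning values, binary U: the rows are informative -> complete.
P_informative = np.array([[0.9, 0.1], [0.5, 0.5], [0.2, 0.8]])
# Two conditioning values, three latent categories -> not complete.
P_coarse = np.array([[0.6, 0.3, 0.1], [0.1, 0.3, 0.6]])

g = null_direction(P_coarse)
print(is_complete(P_informative), is_complete(P_coarse), P_coarse @ g)
```

The second matrix illustrates the failure mode: with fewer conditioning values than latent categories, some non-zero $g$ is invisible to every conditional expectation.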
Second, we need a guarantee on the support of $U$. Intuitively, if a value $u$ has non-zero probability in the target domain, it should have non-zero probability in the source domain as well. Otherwise, it is impossible to adjust to certain shifts (as we never see these regimes in the source domain). This is similar to the positivity assumption commonly made in the causality literature (hernan2006estimating).
Assumption 5 (Positivity).
For any $u \in \mathcal{U}$, if $q(u) > 0$ then $p(u) > 0$.
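If data are generated according to Figure 1(c) and the regularity conditions in Assumptions 8–10 hold (see Appendix A.2), miao2018identifying first showed the existence of solutions $h_0^p$ and $h_0^q$ to the following equations:
$$\mathbb{E}_p[Y \mid x, c] = \int h_0^p(w, c)\, dp(w \mid x, c), \qquad \mathbb{E}_q[Y \mid x, c] = \int h_0^q(w, c)\, dq(w \mid x, c). \tag{4.1}$$
The terms $h_0^p$ and $h_0^q$ are called ‘bridge’ functions as they connect the proxy $W$ to the label $Y$. If we are able to identify $h_0^q$, then we can identify $\mathbb{E}_q[Y \mid x]$ by using eq. (4.1) to obtain $\mathbb{E}_q[Y \mid x, c]$ and marginalizing over $c$ with respect to $q(c \mid x)$.
We show that it is possible to connect identification of $h_0^q$ with that of $h_0^p$, leading directly to identification of $\mathbb{E}_q[Y \mid x]$.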
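Theorem 4.1. Suppose that Assumptions 1, 2, 4 and 5 hold and that a solution $h_0^p$ to the source equation in (4.1) exists. Then $h_0^p$ also solves the target equation in (4.1), and
$$\mathbb{E}_q[Y \mid x] = \int\!\!\int h_0^p(w, c)\, dq(w, c \mid x).$$
The proof is given in Appendix B.1. Hence, given $h_0^p$ from the source and $q(w, c \mid x)$ from the target, we are able to adapt to arbitrary distribution shifts in the unobserved $U$. The advantage of this approach is that it does not require estimating any distributions involving $U$. We demonstrate this in Section 5.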
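Theorem 4.1 can be checked numerically when every variable is discrete, where the bridge equation in (4.1) becomes a small linear system for each value of $c$ (the matrix construction of Appendix A.1). The sketch below assumes a hypothetical binary model consistent with Figure 1(c), with made-up probabilities and a latent marginal that changes from $p(U) = (0.7, 0.3)$ to $q(U) = (0.2, 0.8)$. It fits the bridge from source tables only, then forms $\sum_{w,c} h_0^p(w, c)\, q(w, c \mid x)$ from target quantities involving $(X, W, C)$; the latent $U$ is used solely to compute the ground truth for comparison. The helper names are illustrative.

```python
import numpy as np

# Hypothetical binary model consistent with Figure 1(c): U -> X, U -> W, U -> Y, X -> C -> Y.
# All probabilities are made up; only p(U) differs between source and target.
p_x_given_u = np.array([[0.8, 0.2], [0.3, 0.7]])    # [u, x]
p_w_given_u = np.array([[0.9, 0.1], [0.25, 0.75]])  # [u, w]
p_c_given_x = np.array([[0.85, 0.15], [0.2, 0.8]])  # [x, c]
ey_given_uc = np.array([[0.1, 0.6], [0.5, 0.95]])   # E[Y | u, c], invariant across domains

def joint(p_u):
    """J[u, x, w, c] = p(u) p(x | u) p(w | u) p(c | x)."""
    return (p_u[:, None, None, None] * p_x_given_u[:, :, None, None]
            * p_w_given_u[:, None, :, None] * p_c_given_x[None, :, None, :])

def fit_bridge(J):
    """Solve E_p[Y | x, c] = sum_w h(w, c) p(w | x, c) for each c, using source tables only."""
    p_xc = J.sum(axis=(0, 2))                                    # [x, c]
    p_w_given_xc = J.sum(axis=0) / p_xc[:, None, :]              # [x, w, c]
    ey_given_xc = np.einsum('uxwc,uc->xc', J, ey_given_uc) / p_xc
    cols = [np.linalg.solve(p_w_given_xc[:, :, c], ey_given_xc[:, c])
            for c in range(p_xc.shape[1])]
    return np.stack(cols, axis=1)                                # h[w, c]

def adapt(h, Jq):
    """E_q[Y | x] = sum_{w, c} h(w, c) q(w, c | x); uses only (X, W, C) in the target."""
    q_wc_given_x = Jq.sum(axis=0) / Jq.sum(axis=(0, 2, 3))[:, None, None]
    return np.einsum('xwc,wc->x', q_wc_given_x, h)

def oracle(J):
    """Ground-truth E[Y | x], computed with access to the latent U (for checking only)."""
    return np.einsum('uxwc,uc->x', J, ey_given_uc) / J.sum(axis=(0, 2, 3))

J_p = joint(np.array([0.7, 0.3]))
J_q = joint(np.array([0.2, 0.8]))
h = fit_bridge(J_p)
print("adapted:", adapt(h, J_q), "target truth:", oracle(J_q), "source E_p[Y|x]:", oracle(J_p))
```

Under these assumptions, the adapted predictor matches the target $\mathbb{E}_q[Y \mid x]$ while the source predictor does not; the linear systems are square and invertible here because $W$, $X$ and $U$ all have two categories and the chosen conditional tables are non-degenerate.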
While concepts can ensure identifiability, they may not be available in practice. In this case, a natural question is whether the optimal target predictor is still identifiable. In the next section we show that if we instead have access to data from multiple source domains, may again be identifiable.
4.2 The Blessings of Multiple Domains
We now turn to the multi-domain setting. The graphical structure in Figure 1(d) is similar to the structure in Figure 1(c) with $C$ replaced by $X$, $X$ replaced by $Z$, and the arrow between $U$ and $Z$ flipped. Although the bridge function proposed by miao2018identifying assumes an edge from $U$ to $Z$, changing the direction from $U \to Z$ to $Z \to U$ does not change the conditional independence structure (pearl2009causality). The main difference is that we will only be able to guarantee full identification when $U$ is discrete. We start by demonstrating this, and then give an example of the inherent difficulty of identification when $U$ is continuous.
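To begin, for simplicity, assume $U$ and $W$ are discrete (with $k_U$ and $k_W$ categories, respectively), as is $X$. We have finitely many samples from $Z$, denoted $z_1, \dots, z_{k_Z}$, corresponding to our training domains. We seek a bridge function (in this case, a matrix $M$ with entries $M_{w,x} = m_0(w, x)$) satisfying
$$\mathbb{E}_r[Y \mid x] = \sum_{w} M_{w,x}\, p_r(w \mid x) \tag{4.2}$$
for all $x$ and all $r \in \{1, \dots, k_Z\}$, where $\mathbb{E}_r$ is the conditional expectation obtained in domain $r$, and $p_r(w \mid x) = p(w \mid x, Z = z_r)$.
In order to identify $M$, and then $\mathbb{E}_q[Y \mid X]$, we need enough source domains to capture the variability of $U$. The following result describes how many we need.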
Proposition 4.2.
Suppose that we have $k_Z$ source domains and $U$, $W$ have $k_U$ and $k_W$ categories, respectively. Then, if $k_Z \geq k_U$ and $k_W \geq k_U$, and subject to appropriate rank conditions (see proof in Appendix B.2), the bridge function $M$ is identifiable and does not depend on the specific domain $r$.
This result generalizes the identification analysis developed in miao2018identifying. If the number of observed source domains is at least the dimension of the latent $U$, then subject to appropriate identifiability requirements (detailed in Appendix B.2), we can recover the bridge $M$.
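A discrete sketch of Proposition 4.2, assuming $k_Z = k_U = k_W = 2$, a hypothetical model consistent with Figure 1(d), and made-up probabilities. For each $x$, the equations (4.2) across the source domains form a linear system in the column $M_{\cdot, x}$; the fitted $M$ is then applied to the target table $q(w \mid x)$, which requires no labels. The helper name `domain_tables` and all numbers are illustrative; the target $\mathbb{E}_q[Y \mid x]$ is computed only to check the result.

```python
import numpy as np

# Hypothetical discrete model consistent with Figure 1(d): Z -> U, U -> {X, W, Y}, X -> Y.
p_x_given_u = np.array([[0.8, 0.2], [0.3, 0.7]])    # [u, x]
p_w_given_u = np.array([[0.9, 0.1], [0.25, 0.75]])  # [u, w]
ey_given_ux = np.array([[0.1, 0.4], [0.6, 0.9]])    # E[Y | u, x], shared by all domains

def domain_tables(p_u):
    """Observable tables p_r(w | x) and E_r[Y | x] for a domain with latent marginal p_u."""
    J = p_u[:, None, None] * p_x_given_u[:, :, None] * p_w_given_u[:, None, :]  # [u, x, w]
    p_x = J.sum(axis=(0, 2))
    p_w_given_x = J.sum(axis=0) / p_x[:, None]                  # [x, w]
    ey_given_x = np.einsum('uxw,ux->x', J, ey_given_ux) / p_x   # [x]
    return p_w_given_x, ey_given_x

# k_Z = 2 labelled source domains with different latent marginals.
sources = [domain_tables(np.array([0.7, 0.3])), domain_tables(np.array([0.4, 0.6]))]

n_w, n_x = p_w_given_u.shape[1], p_x_given_u.shape[1]
M = np.zeros((n_w, n_x))  # bridge matrix, M[w, x] = m_0(w, x)
for x in range(n_x):
    # One equation per source domain r: sum_w M[w, x] p_r(w | x) = E_r[Y | x], cf. (4.2).
    A = np.stack([p_w[x] for p_w, _ in sources])   # [r, w]
    b = np.array([ey[x] for _, ey in sources])     # [r]
    M[:, x] = np.linalg.lstsq(A, b, rcond=None)[0]

# Target: only (X, W) would be observed; E_q[Y | x] is computed here just for checking.
q_w_given_x, q_ey_given_x = domain_tables(np.array([0.1, 0.9]))
adapted = np.einsum('xw,wx->x', q_w_given_x, M)
print("adapted:", adapted, "target truth:", q_ey_given_x)
```

With a single source domain ($k_Z = 1 < k_U$), the system for each $x$ would be underdetermined, and `lstsq` would return only the minimum-norm solution, which in general does not transfer.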
Now, consider the case where $U$ is discrete but all observed variables are continuous. In this case we have the following system
$$\mathbb{E}_r[Y \mid x] = \int m_0(w, x)\, dp_r(w \mid x) \tag{4.3}$$
for $r = 1, \dots, k_Z$. The proof of existence of $m_0$ is a modification of Proposition A.2, as shown in Proposition A.3. In order to identify the target predictor $\mathbb{E}_q[Y \mid X]$, we need the following assumption.
Assumption 6.
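Let $g$ be a square integrable function on $\mathcal{U}$. For each $x \in \mathcal{X}$, $\mathbb{E}_r[g(U) \mid x] = 0$ for all $r \in \{1, \dots, k_Z\}$ if and only if $g(U) = 0$, almost surely.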
Given this assumption we can prove identifiability.
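Proposition 4.3. Suppose that Assumptions 2, 3 and 6 hold and that $m_0$ solves (4.3) for $r = 1, \dots, k_Z$. Then $m_0$ also solves (4.3) in the target domain, and $\mathbb{E}_q[Y \mid x] = \int m_0(w, x)\, dq(w \mid x)$.
The proof is given in Appendix B.3. Crucially, this result is valid only when Assumption 6 holds, and it remains unclear when this assumption is expected to hold. Proposition 4.2 suggests that Assumption 6 is not vacuous when $U$ is finite-dimensional. We plan to investigate this further in future work.
Now let us consider the case where $U$ is continuous. In this case, unfortunately, Assumption 6 may not hold, preventing identification of $\mathbb{E}_q[Y \mid X]$. This is illustrated in the following example.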
Example 4.4.
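Recall the decomposition of both sides of (4.3). Under Assumption 2 and given the existence of $m_0$ (Proposition A.2),
$$\int m_0(w, x)\, dp_r(w \mid x) = \int \Big( \int m_0(w, x)\, dp(w \mid u) \Big)\, dp_r(u \mid x) = \int \mathbb{E}[Y \mid u, x]\, dp_r(u \mid x), \tag{4.5}$$
$$\mathbb{E}_r[Y \mid x] = \int \mathbb{E}[Y \mid u, x]\, dp_r(u \mid x). \tag{4.6}$$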
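For every $r$, Eqs. (4.5) and (4.6) represent projections onto $p_r(u \mid x)$. Consider $\mathcal{U} = [-\pi, \pi]$ with periodic boundary conditions, and for a given $r$ define $p_r(u \mid x) = \frac{1 + \cos(r u)}{2\pi}$ (note that cosines form an orthogonal basis). We now construct an example where (4.5) holds for some $r$ but not for others. Define the difference
$$\delta(u) := \int m_0(w, x)\, dp(w \mid u) - \mathbb{E}[Y \mid u, x] = \cos\big((k + 1) u\big). \tag{4.7}$$
In this case, $\int \delta(u)\, p_r(u \mid x)\, du = 0$ for $r = 1, \dots, k$, while it equals $1/2$ for $r = k + 1$; in particular, (4.5) holds for all $r \leq k$ but not for $r = k + 1$.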
This example illustrates a larger point: for continuous $U$, no finite set of projections suffices to completely characterize the square integrable functions on $\mathcal{U}$. That said, as more projections are employed, and subject to appropriate assumptions on the smoothness of (4.7), the error will decrease as more domains are observed. The characterization of this convergence will be the topic of future work. In experiments, we show that the adaptation can still be effective even when the latent variable is continuous-valued and follows a different Beta distribution in each source domain, given just two training source domains.
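The construction in Example 4.4 can be checked numerically. The sketch below assumes the domain densities $p_r(u \mid x) = (1 + \cos(r u))/(2\pi)$ on $[-\pi, \pi]$ and the difference $\cos((k+1)u)$ from (4.7), and projects this difference onto each domain's density with a periodic rectangle rule; the grid size, the value of $k$ and the helper `p_r` are illustrative choices. The projections vanish for the first $k$ domains but equal $1/2$ for domain $k + 1$, so a bridge that fits $k$ domains can still be wrong for a new one.

```python
import numpy as np

# Uniform periodic grid on [-pi, pi): the rectangle rule is exact (up to round-off)
# for trigonometric polynomials of degree below n_grid.
n_grid = 4096
u = -np.pi + 2 * np.pi * np.arange(n_grid) / n_grid
du = 2 * np.pi / n_grid

def p_r(r):
    """Latent density for domain r: (1 + cos(r u)) / (2 pi) on [-pi, pi]."""
    return (1 + np.cos(r * u)) / (2 * np.pi)

k = 5                        # number of observed source domains
delta = np.cos((k + 1) * u)  # bridge error as a function of u, cf. (4.7)

# Projection of delta onto each domain's density, r = 1, ..., k + 1.
projections = np.array([np.sum(delta * p_r(r)) * du for r in range(1, k + 2)])
print(np.round(projections, 12))
```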
5 Kernel Bridge Function Estimation
We introduce kernel methods to estimate the bridge functions and subsequently leverage the estimates to adapt to distribution shifts. Section 4 shows that bridge functions for both settings can be adapted to the target domain, so we drop the domain-specific indices and use $h_0$ and $m_0$ to denote the bridge functions. We begin by introducing the notation.
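Notation. Let $\otimes$ be the tensor product, $\odot$ be the columnwise Khatri–Rao product and $\circ$ be the Hadamard product. For any space $\mathcal{F}$, let $k_{\mathcal{F}}: \mathcal{F} \times \mathcal{F} \to \mathbb{R}$ be a positive semidefinite kernel function and, for any $f \in \mathcal{F}$, let $\phi(f) = k_{\mathcal{F}}(f, \cdot)$ be the feature map. We denote by $\mathcal{H}_{\mathcal{F}}$ the RKHS on $\mathcal{F}$ associated with the kernel function $k_{\mathcal{F}}$. The RKHS has two properties: (i) $g(f) = \langle g, \phi(f) \rangle_{\mathcal{H}_{\mathcal{F}}}$ for all $g \in \mathcal{H}_{\mathcal{F}}$ and $f \in \mathcal{F}$, and (ii) $k_{\mathcal{F}}(f, f') = \langle \phi(f), \phi(f') \rangle_{\mathcal{H}_{\mathcal{F}}}$. We denote by $\langle \cdot, \cdot \rangle_{\mathcal{H}_{\mathcal{F}}}$ the inner product and by $\|\cdot\|_{\mathcal{H}_{\mathcal{F}}}$ the induced norm. For notational simplicity, we denote the product space associated with $\otimes$ as $\mathcal{H}_{\mathcal{F}\mathcal{G}} = \mathcal{H}_{\mathcal{F}} \otimes \mathcal{H}_{\mathcal{G}}$. We define the kernel mean embedding as $\mu_F = \mathbb{E}[\phi(F)]$ (smola2007hilbert) and the conditional mean embedding as $\mu_{F \mid g} = \mathbb{E}[\phi(F) \mid G = g]$ (song2009hilbert; singh2019kernel). For $\ell \in \{1, 2\}$, we denote the $\ell$-th batch of i.i.d. samples as $\{f^{(\ell)}_i\}_{i=1}^{n_\ell}$; we write the first batch without a superscript and the second batch with tildes, $\{\tilde f_j\}_{j=1}^{n_2}$. Define the Gram matrices as $K_{FF} = [k_{\mathcal{F}}(f_i, f_j)]_{i,j}$ and $K_{F\tilde F} = [k_{\mathcal{F}}(f_i, \tilde f_j)]_{i,j}$. Let $\Phi_F = [\phi(f_1), \dots, \phi(f_n)]$ be the vectorized feature map such that $K_{FF} = \Phi_F^\top \Phi_F$.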
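The Khatri–Rao and Hadamard products in this notation are linked by a simple identity: the features of the product kernel $k_{\mathcal{X}}(x, x')\, k_{\mathcal{C}}(c, c')$ are $\phi(x) \otimes \phi(c)$, so the feature matrix is $\Phi_X \odot \Phi_C$ and its Gram matrix is $K_{XX} \circ K_{CC}$. The sketch below verifies this with random finite-dimensional feature matrices (hypothetical stand-ins for the generally infinite-dimensional RKHS features) and a hand-written `khatri_rao` helper.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_x, d_c = 5, 3, 4
Phi_X = rng.normal(size=(d_x, n))   # hypothetical finite-dimensional features, one column per sample
Phi_C = rng.normal(size=(d_c, n))

def khatri_rao(A, B):
    """Columnwise Khatri-Rao product: column j equals np.kron(A[:, j], B[:, j])."""
    return np.einsum('ij,kj->ikj', A, B).reshape(A.shape[0] * B.shape[0], A.shape[1])

Phi_XC = khatri_rao(Phi_X, Phi_C)   # columns are phi(x_j) (x) phi(c_j)
K_XX = Phi_X.T @ Phi_X
K_CC = Phi_C.T @ Phi_C
K_XC = Phi_XC.T @ Phi_XC            # Gram matrix of the product kernel k_X * k_C

print(np.allclose(K_XC, K_XX * K_CC))  # prints True: Hadamard product of the two Gram matrices
```

This is why estimators such as (5.3) can work with $K_{XX} \circ K_{CC}$ directly, without ever forming tensor-product features.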
5.1 Adaptation with Concepts
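Suppose that $h_0 \in \mathcal{H}_{WC}$ for the bridge function $h_0$, where $\mathcal{H}_{WC}$ is an RKHS. It follows from Theorem 4.1 that
$$\mathbb{E}_q[Y \mid x] = \big\langle h_0, \mu^q_{WC \mid x} \big\rangle_{\mathcal{H}_{WC}}, \qquad \mu^q_{WC \mid x} = \mathbb{E}_q[\phi(W) \otimes \phi(C) \mid x]. \tag{5.1}$$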
To adapt to the distribution shifts, we estimate the bridge function $h_0$ in the source domain and the conditional mean embedding $\mu^q_{WC \mid x}$ in the target domain. The empirical estimate of the conditional mean embedding, along with the consistency proof, has been provided in (song2009hilbert; grunewalder2012conditional); thus we focus on the estimation procedure of the bridge function $h_0$.
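To estimate the bridge function $h_0$, we employ the regression method developed in mastouri2021proximal. Recall that $\mathbb{E}_p[Y \mid x, c] = \langle h_0, \mu_{W \mid x, c} \otimes \phi(c) \rangle_{\mathcal{H}_{WC}}$. We define the population risk function in the source domain as:
$$R(h) = \mathbb{E}_p\Big[\big(Y - \langle h, \mu_{W \mid X, C} \otimes \phi(C) \rangle_{\mathcal{H}_{WC}}\big)^2\Big], \qquad \mu_{W \mid x, c} = \mathbb{E}_p[\phi(W) \mid x, c]. \tag{5.2}$$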
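The procedure to optimize (5.2) involves two stages. In the first stage, we estimate the conditional mean embedding $\mu_{W \mid x, c}$, which we use as a plug-in estimator to estimate $h_0$ in the second stage. Given i.i.d. samples $\{(w_i, x_i, c_i)\}_{i=1}^{n_1}$ from the source distribution and a regularization parameter $\lambda_1 > 0$, we denote by $K_{XX}$, $K_{CC}$ the Gram matrices and by $\Phi_W$, $\Phi_X$, $\Phi_C$ the $n_1$-dimensional vectorized feature maps of $\{w_i\}$, $\{x_i\}$, $\{c_i\}$, respectively. Following the procedure developed in song2009hilbert, the estimate of $\mu_{W \mid x, c}$ is
$$\hat\mu_{W \mid x, c} = \Phi_W \big(K_{XX} \circ K_{CC} + n_1 \lambda_1 I\big)^{-1} \big(\Phi_X^\top \phi(x) \circ \Phi_C^\top \phi(c)\big). \tag{5.3}$$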
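A minimal NumPy sketch of the first stage (5.3), assuming Gaussian kernels with a shared bandwidth on $X$ and $C$ (the kernel choice, the bandwidth and the function names `gaussian_gram` and `cme_weights` are assumptions for illustration). It returns the weights $\beta(x, c) = (K_{XX} \circ K_{CC} + n_1 \lambda_1 I)^{-1}(\Phi_X^\top \phi(x) \circ \Phi_C^\top \phi(c))$, so that $\hat\mu_{W \mid x, c} = \Phi_W \beta(x, c)$ and, for a function $f$, $\hat{\mathbb{E}}[f(W) \mid x, c] = \sum_i \beta_i(x, c) f(w_i)$. The usage example checks the weights on synthetic data where $\mathbb{E}[W \mid x, c]$ is known.

```python
import numpy as np

def gaussian_gram(A, B, bandwidth=1.0):
    """k(a, b) = exp(-||a - b||^2 / (2 bandwidth^2)) for rows of A (n, d) and B (m, d)."""
    sq = np.sum(A**2, axis=1)[:, None] + np.sum(B**2, axis=1)[None, :] - 2 * A @ B.T
    return np.exp(-np.maximum(sq, 0.0) / (2 * bandwidth**2))

def cme_weights(X, C, x_new, c_new, lam, bandwidth=1.0):
    """Weights beta with mu_hat_{W|x,c} = Phi_W beta(x, c), following (5.3).

    Returns an (n1, m) array whose column j is beta(x_new[j], c_new[j]).
    """
    n1 = X.shape[0]
    K = gaussian_gram(X, X, bandwidth) * gaussian_gram(C, C, bandwidth)            # K_XX o K_CC
    k_new = gaussian_gram(X, x_new, bandwidth) * gaussian_gram(C, c_new, bandwidth)
    return np.linalg.solve(K + n1 * lam * np.eye(n1), k_new)

# Usage on synthetic data where E[W | x, c] = sin(x) + 0.5 c is known.
rng = np.random.default_rng(0)
n1 = 200
X = rng.uniform(-2, 2, size=(n1, 1))
C = rng.uniform(-2, 2, size=(n1, 1))
W = np.sin(X[:, 0]) + 0.5 * C[:, 0]
x_new = rng.uniform(-1.5, 1.5, size=(50, 1))
c_new = rng.uniform(-1.5, 1.5, size=(50, 1))
beta = cme_weights(X, C, x_new, c_new, lam=1e-4)
est = W @ beta   # sum_i beta_i(x, c) w_i, i.e. kernel ridge regression of W on (X, C)
truth = np.sin(x_new[:, 0]) + 0.5 * c_new[:, 0]
print("mean abs error:", np.mean(np.abs(est - truth)))
```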
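In the second stage, we replace $\mu_{W \mid x, c}$ with $\hat\mu_{W \mid x, c}$ in (5.2) and define the empirical risk. Given i.i.d. samples $\{(\tilde x_j, \tilde c_j, \tilde y_j)\}_{j=1}^{n_2}$ from the source distribution and a regularization parameter $\lambda_2 > 0$, we minimize
$$\hat L(h) = \frac{1}{n_2} \sum_{j=1}^{n_2} \Big( \tilde y_j - \big\langle h, \hat\mu_{W \mid \tilde x_j, \tilde c_j} \otimes \phi(\tilde c_j) \big\rangle_{\mathcal{H}_{WC}} \Big)^2 + \lambda_2 \|h\|^2_{\mathcal{H}_{WC}}. \tag{5.4}$$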
We follow the same analysis procedure as in mastouri2021proximal. The solution to (5.4) is given in the following proposition.
Proposition 5.1.
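Let $K_{WW}$, $K_{\tilde C \tilde C}$ be the Gram matrices of $\{w_i\}$ and $\{\tilde c_j\}$, respectively. Let $K_{X \tilde X}$, $K_{C \tilde C}$ be the cross Gram matrices of $(\{x_i\}, \{\tilde x_j\})$ and $(\{c_i\}, \{\tilde c_j\})$, respectively. For any $\lambda_1, \lambda_2 > 0$, there exists a unique optimal solution to (5.4) of the form
$$\hat h = \sum_{i=1}^{n_1} \sum_{j=1}^{n_2} B_{ij}\, \hat\alpha_j\, \phi(w_i) \otimes \phi(\tilde c_j),$$
where $B = (K_{XX} \circ K_{CC} + n_1 \lambda_1 I)^{-1} (K_{X \tilde X} \circ K_{C \tilde C})$, $G = (B^\top K_{WW} B) \circ K_{\tilde C \tilde C}$, and $\hat\alpha = (G + n_2 \lambda_2 I)^{-1} \tilde{\mathbf{y}}$ with $\tilde{\mathbf{y}} = (\tilde y_1, \dots, \tilde y_{n_2})^\top$.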
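Proposition 5.1 is an application of the Representer theorem (scholkopf2001generalized): the optimal estimate of the infinite-dimensional operator $h_0$ is a finite-rank operator spanned by the features of the samples of $W$ and $C$.
Finally, given the estimate $\hat h$ and a new sample $x$ from the target domain, we can construct the empirical predictor of (5.1) as $\hat{\mathbb{E}}_q[Y \mid x] = \langle \hat h, \hat\mu^q_{WC \mid x} \rangle_{\mathcal{H}_{WC}}$, where $\hat\mu^q_{WC \mid x}$ is the estimated conditional mean embedding of $(W, C)$ given $x$ in the target domain.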
This completes the full adaptation procedure.
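On classification tasks. For classification tasks, where the label is $Y \in \{1, \dots, k_Y\}$, we treat the multi-task regressor as a classifier. We encode $Y$ with a one-hot encoder and then regress on the encoded $Y$. Each label $y$ has a corresponding bridge function $h_0^{(y)}$ for $y \in \{1, \dots, k_Y\}$. For $j = 1, \dots, n_2$, let the encoded $\tilde y_j$ be $e_{\tilde y_j} \in \{0, 1\}^{k_Y}$. Then for each $y$, we can estimate $h_0^{(y)}$ by replacing $\tilde y_j$ in (5.4) with the $y$-th entry of $e_{\tilde y_j}$, i.e., $\mathbb{1}\{\tilde y_j = y\}$. For each new sample $x$, the predicted score of label $y$ is $\hat{\mathbb{E}}_q[\mathbb{1}\{Y = y\} \mid x] = \langle \hat h^{(y)}, \hat\mu^q_{WC \mid x} \rangle_{\mathcal{H}_{WC}}$, and we select the label that has the highest prediction score: $\hat y = \arg\max_{y} \hat{\mathbb{E}}_q[\mathbb{1}\{Y = y\} \mid x]$.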
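The two stages, the target prediction (5.1) and the one-hot treatment of labels can be combined in a short sketch. It assumes Gaussian kernels with unit bandwidth, splits the labelled source sample into two independent batches, and solves the second stage as kernel ridge regression on the plug-in features $\hat\mu_{W \mid \tilde x_j, \tilde c_j} \otimes \phi(\tilde c_j)$, which is one way of writing the unique minimizer of (5.4). The target embedding uses kernel ridge weights on $X$. The data generator `simulate`, the helper names and all hyperparameters are hypothetical.

```python
import numpy as np

def gram(A, B, bw=1.0):
    """Gaussian Gram matrix between the rows of A (n, d) and B (m, d)."""
    sq = np.sum(A**2, axis=1)[:, None] + np.sum(B**2, axis=1)[None, :] - 2 * A @ B.T
    return np.exp(-np.maximum(sq, 0.0) / (2 * bw**2))

def fit_bridge(stage1, stage2, lam1, lam2):
    """Two-stage kernel estimate of h in H_WC from source data, cf. (5.3)-(5.4).

    stage1 = (W, X, C) and stage2 = (X2, C2, Y2), with Y2 of shape (n2, k): one column
    per one-hot label (or a single column for regression).
    """
    W, X, C = stage1
    X2, C2, Y2 = stage2
    n1, n2 = X.shape[0], X2.shape[0]
    # Stage 1: column j of B holds the CME weights for (x2_j, c2_j).
    B = np.linalg.solve(gram(X, X) * gram(C, C) + n1 * lam1 * np.eye(n1),
                        gram(X, X2) * gram(C, C2))
    # Stage 2: kernel ridge regression on the features mu_hat_j (x) phi(c2_j).
    G = (B.T @ gram(W, W) @ B) * gram(C2, C2)
    alpha = np.linalg.solve(G + n2 * lam2 * np.eye(n2), Y2)    # (n2, k)

    def h(w, c):
        """Evaluate h at paired rows of w and c: sum_j alpha_j mu_hat_j(w) k(c2_j, c)."""
        return ((gram(w, W) @ B) * gram(c, C2)) @ alpha        # (m, k)
    return h

def predict_target(h, target, x_new, lam_q):
    """E_q[Y | x] ~ sum_t gamma_t(x) h(w_t, c_t), with target CME weights gamma, cf. (5.1)."""
    Wq, Xq, Cq = target
    nq = Xq.shape[0]
    gamma = np.linalg.solve(gram(Xq, Xq) + nq * lam_q * np.eye(nq), gram(Xq, x_new))
    return gamma.T @ h(Wq, Cq)                                  # (m, k) label scores

# Usage with a hypothetical generator for Figure 1(c); only P(U = 1) changes.
rng = np.random.default_rng(0)

def simulate(n, p_u):
    U = rng.binomial(1, p_u, size=n).astype(float)
    X = (U + 0.5 * rng.normal(size=n))[:, None]
    W = (U + 0.5 * rng.normal(size=n))[:, None]
    C = (X[:, 0] + 0.3 * rng.normal(size=n))[:, None]
    Y = (C[:, 0] + U + 0.3 * rng.normal(size=n) > 1.0).astype(int)
    return W, X, C, Y

W, X, C, Y = simulate(400, p_u=0.2)              # labelled source sample
Y_onehot = np.eye(2)[Y]
h = fit_bridge((W[:200], X[:200], C[:200]), (X[200:], C[200:], Y_onehot[200:]), 1e-3, 1e-3)
Wq, Xq, Cq, Yq = simulate(300, p_u=0.8)          # target: Yq is used only for evaluation
scores = predict_target(h, (Wq, Xq, Cq), Xq, 1e-3)
y_hat = scores.argmax(axis=1)
print("target accuracy:", np.mean(y_hat == Yq))
```

Each column of `Y_onehot` yields one bridge function $h^{(y)}$, and the predicted label is the column with the highest score, mirroring the argmax rule above. Passing a single column of real-valued labels instead gives the regression estimator.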
5.2 Adaptation with Multiple Domains
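In the multiple source domain setting, the estimation of $m_0$ follows similarly to that of $h_0$. Assuming that $m_0 \in \mathcal{H}_{WX}$, (4.3) can be written as
$$\mathbb{E}_r[Y \mid x] = \big\langle m_0, \mu^r_{W \mid x} \otimes \phi(x) \big\rangle_{\mathcal{H}_{WX}}, \qquad \mu^r_{W \mid x} = \mathbb{E}_r[\phi(W) \mid x],$$
for $r = 1, \dots, k_Z$. The task is to estimate $m_0$ from the source domains and then apply it to the target domain. We can define the population risk function as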
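$$R(m) = \mathbb{E}\Big[\big(Y - \langle m, \mu_{W \mid X, Z} \otimes \phi(X) \rangle_{\mathcal{H}_{WX}}\big)^2\Big], \qquad \mu_{W \mid x, z_r} = \mu^r_{W \mid x}, \tag{5.5}$$
where the expectation is taken over the source domains.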
We employ the same two-stage estimation procedure as for $h_0$: (i) we first estimate $\mu_{W \mid x, z}$ and then (ii) plug in the estimate to estimate $m_0$.
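At the $r$-th domain, we observe the samples $\{(w^r_i, x^r_i)\}_{i=1}^{n_r}$. As with (5.3), we learn a conditional mean embedding $\hat\mu_{W \mid x, z}$, where $\hat\mu_{W \mid x, z_r}$ estimates $\mu^r_{W \mid x}$ for $r = 1, \dots, k_Z$. In the second stage, given another batch of independent samples $\{(\tilde x^r_j, \tilde y^r_j)\}_{j=1}^{\tilde n_r}$ for $r = 1, \dots, k_Z$, we minimize:
$$\hat L(m) = \frac{1}{\tilde n} \sum_{r=1}^{k_Z} \sum_{j=1}^{\tilde n_r} \Big( \tilde y^r_j - \big\langle m, \hat\mu_{W \mid \tilde x^r_j, z_r} \otimes \phi(\tilde x^r_j) \big\rangle_{\mathcal{H}_{WX}} \Big)^2 + \lambda_2 \|m\|^2_{\mathcal{H}_{WX}}, \qquad \tilde n = \sum_{r=1}^{k_Z} \tilde n_r. \tag{5.6}$$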
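Then, $\hat m = \arg\min_m \hat L(m)$ has an analytical solution of a similar form to $\hat h$ in Proposition 5.1 (see Appendix C.2 for details). Finally, with the estimated conditional mean embedding $\hat\mu^q_{W \mid x}$ and a new sample $x$ from the target test set, we have
$$\hat{\mathbb{E}}_q[Y \mid x] = \big\langle \hat m, \hat\mu^q_{W \mid x} \otimes \phi(x) \big\rangle_{\mathcal{H}_{WX}}.$$
We convert the regression task with $Y \in \mathbb{R}$ to the classification task with $Y \in \{1, \dots, k_Y\}$ by learning $k_Y$ bridge functions, where each bridge function $m^{(y)}$ corresponds to label $y$.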
6 Experiments
Task | ORACLE | Cat-ERM | Avg-ERM | SA | MK | WCSC | DANN | MMD | Proposed |
---|---|---|---|---|---|---|---|---|---|
Task 1 | |||||||||
Task 2 | |||||||||
Task 3 | |||||||||
We verify our theory with both simulated and real data, demonstrating robustness to latent shifts and transferability of the bridge functions.
For the setting with concept variables present, we compare our method with the following baselines: Empirical Risk Minimization (ERM), Covariate shift weighting (COVAR) (shimodaira2000improving), Label shift weighting (LABEL) (buck1966comparison), and the spectral (LSA-S) and Wasserstein Autoencoder (LSA-WAE) latent shift adaptation approaches (alabdulmohsin2023adapting). For the multi-domain setting, we compare our method with the following baselines: Simple Adaptation (SA) (mansour2008domain), Weighted Combination of Source Classifiers (WCSC) (zhang2015multi), and Marginal Kernel (MK) (blanchard2011generalizing). We also compare with multi-domain generalization baselines (muandet2013domain): Domain Adversarial Neural Networks (DANN) (ganin2016domain) and Maximum Mean Discrepancy (MMD) (GreBorRasSchetal12). Additionally, we adapt the ERM method to the multi-domain setting by concatenating the source samples to learn one ERM model (Cat-ERM) or by averaging the results of the per-source-domain ERM models (Avg-ERM). The ORACLE model is trained on target distribution samples and evaluated on held-out target distribution samples. The tuning parameters for all models, including the proposed model, are selected using five-fold cross-validation. Details regarding the setups are in Appendix D.
Classification task. The task designed in alabdulmohsin2023adapting is a binary classification problem with and the latent variable is a Bernoulli random variable. Additionally, are continuous random variables and is a discrete variable. We have one source domain with . We evaluate the models on the target distribution with shifting from . The goal of this task is to investigate whether the adaptation method is robust to any arbitrary shift of .
The ORACLE and ERM model are implemented as MultiLayer Perceptrons (MLP). The kernel function used in the proposed method is the Gaussian kernel.
We compare the proposed method with the spectral (LSA-S) and Wasserstein Autoencoder (LSA-WAE) adaptation approaches developed in alabdulmohsin2023adapting. While all three methods are designed to adjust for shift in the same graph in Figure 1(c), our method additionally uses $(W, C)$ as training samples in the target domain, while LSA-S and LSA-WAE only use $X$. For all three methods, only $X$ is observed in the test data.
While the identification theory developed in alabdulmohsin2023adapting does not require $(W, C)$ in the target domain, we recognize that, in practice, having more information in the target domain may improve estimation. To make the methods more directly comparable, we design an additional step to incorporate $(W, C)$ from the target in the LSA-S algorithm. We describe this procedure in more detail in Appendix D.1.
Results are shown in Figure 2. The proposed method is more robust to the shift than the baselines and is close to the ORACLE model. With $(W, C)$ observed in the target domain, LSA-S does not improve over LSA-S without $(W, C)$. We also compare results under different noise levels and observe similar trends, as discussed in Appendix D.
dSprites dataset regression task. We test the proposed procedure on the dSprites (dsprites17) dataset, an image dataset described by five latent parameters (shape, scale, rotation, posX, and posY). Motivated by dsprites17’s experiments, we design a regression task where the dSprites images (64 × 64 = 4096-dimensional) are $X$ and are subject to a nonlinear confounder $U$, which is a rotation of the image. $U$ and $Y$ are continuous random variables. For this experiment, we have training samples and test samples. Further details about the procedure are in Appendix D.
In the results in Figure 2, we vary a shift parameter that controls the region of the source distribution on which the target distribution concentrates. We design the experiment such that increasing this parameter shifts the target distribution to increasingly low-mass regions of the source distribution. We compute the mean squared error of each method on test examples from the target distribution.
We find that, while the baseline methods degrade as the target distribution shift increases, the proposed method adapts and maintains low error, nearly matching the error achieved by the oracle, which is trained on target distribution samples.
6.1 Multi-Domain Adaptation
In the multi-domain setting, we use the same classification dataset provided in alabdulmohsin2023adapting as in Section D.6. We assume that $C$ is not observed in any domain and generate multiple datasets with different distributions on $U$.
Classification task. We construct three different tasks with different settings of $p(U)$ over the source and target domains. For each task, we construct three source domains and one target domain, drawing random training samples for each source domain and random training samples for the target domain. The source domains of Tasks 1–3 have different combinations of distributions on $U$, documented in Appendix D.3.
The backbone models for ORACLE, Cat-ERM, Avg-ERM, and SA (mansour2008domain) are simple MLPs; MK (blanchard2011generalizing) is a weighted kernel support vector machine; WCSC (zhang2015multi) is a re-weighted kernel density estimator. SA (mansour2008domain) assumes that the target distribution $q(X)$ is a convex combination of $p_r(X)$ for $r = 1, \dots, k_Z$; WCSC (zhang2015multi) assumes that $q(X \mid Y)$ is a linear mixture of $p_r(X \mid Y)$; and MK (blanchard2011generalizing) assumes that each domain is an i.i.d. realization from a general distribution over domains.
The results are shown in Table 1. Overall, we find our approach performs better than ERM and baseline multi-domain adaptation methods. All methods perform better in the setting of Task 2 than for Task 1, informally demonstrating the effect of the closeness of the source domains to the target domain. For Task 3, while our proposed approach performs best, ERM also performs well, and substantially better than the domain adaptation baselines.
Regression task. We consider two regression tasks, where $U$ is either a Bernoulli or a Beta random variable. We present the results in Appendix D.
6.2 Concept and multi-domain adaptation with MIMIC-CXR
We conduct a small-scale experiment using a sample of chest X-ray data extracted from the MIMIC-CXR dataset (johnson2019mimic). We briefly describe the experimental design and results here, and include a complete description in Appendix D.7. We consider classification of the absence of a radiological finding from low-dimensional embeddings of the X-rays (sellergren2022simplified), using the absence of a radiological finding in the radiology report as the target of prediction. This corresponds to the “No Finding” label defined by irvin2019chexpert.
We consider distribution shifts similar to settings in makar2022causally, where patient sex is considered as a possible “shortcut" in the classification of the absence of a radiological finding. We impose distribution shift through structured resampling of the data where and is held constant. We perform both concept adaptation and multi-domain adaptation experiments with the MIMIC-CXR data. For the concept adaptation experiment, we consider the concept variable to be the embedding of a radiology report associated with the chest X-ray. We experiment with the use of patient age as a potential proxy for due to a hypothesized correlation between the presence of radiological findings and patient age.
The results are summarized in Figure 3. For both experiments, we find that the performance of baseline models fit using only information from the source domain(s) degrades under distribution shift. In the concept adaptation experiment, adaptation is relatively successful, as much of the performance of comparator models fit using target domain data is recovered by the adaptation procedure.
However, we find that the multi-domain adaptation procedure is not successful. In this case, we find that while the multi-domain adaptation procedure marginally outperforms a model fit using the concatenated source domain data under distribution shift, it recovers substantially less of the performance of the target domain model than the concept adaptation procedure does. Furthermore, the adapted model does not outperform the kernel estimators that only leverage information from the source domains. The lack of success in this setting could potentially be explained by insufficient number or diversity of domains relative to the level of noise induced by sampling variability and limited sample size.
7 Discussion
We propose a strategy for adaptation under distribution shift in a latent variable using a bridge function approach (miao2018identifying; tchetgen2020introduction). This approach allows for identification of the optimal predictor in the target domain without identifying the distribution of the latent variable and without distributional assumptions on the form of the latent. We require that proxies of the latent variable are present and that (i) mediating concepts are available or (ii) data from multiple source domains are present.
We argue our approach is useful for two reasons. First, the latent distribution in general is only identifiable under strict distributional assumptions (locatello2019challenging). Second, recovery of the latent variable may be challenging in practice even if it is identifiable (rissanen2021critical). For example, because most latent variable estimation methods are designed to model the data generating process (kingma2013auto), one might allocate substantial modeling capacity to variability in the data and the latent variable that is irrelevant to modeling the shift in the conditional distribution of $Y$ given $X$. By contrast, we model only the components of the observable variables relevant to the adaptation.
Acknowledgments: We thank Zhu Li and Dimitri Meunier for helpful discussions. AG was partly supported by the Gatsby Charitable Foundation. OS was partly supported by the UIUC Beckman Institute Graduate Research Fellowship, NSF-NRT 1735252. KT was partly supported by NSF Graduate Research Fellowship Program. SK was partly supported by the NSF III 2046795, IIS 1909577, CCF 1934986, NIH 1R01MH116226-01A, NIFA award 2020-67021-32799, the Alfred P. Sloan Foundation, and Google Inc. This study was funded by Google LLC and/or a subsidiary thereof (‘Google’).
References
Appendix A Identification of the Distribution
In this section, we demonstrate the existence of the bridge functions $h_0$ and $m_0$ under certain regularity conditions. We first discuss the discrete case and then generalize to the continuous case.
A.1 The Discrete Case of the Bridge Function
The idea of a bridge function may seem abstract in the continuous setting. When every variable is discrete, however, the bridge function can be constructed by solving a series of matrix problems. This idea originates from miao2018identifying, and we apply the technique to show the construction of the bridge function when every variable is discrete.
Let
be a column vector, and a matrix, respectively. We define similarly
for . We define
analogously. As an alternative to finding a such that
the proxy problem is converted to finding a such that
First, under the condition that , we can write
(A.1)
Similarly, under the condition that , we have
(A.2)
We introduce the following assumption:
Assumption 7.
Columns of are linearly independent. For every , the columns of satisfy for all .
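To make the matrix view concrete, the sketch below solves a small discrete bridge equation with NumPy. It assumes the standard discrete form of miao2018identifying, $\sum_w h(w)\,P(W=w \mid Z=z) = \mathbb{E}[Y \mid Z=z]$ for every value $z$. Because the symbols in this appendix were lost in extraction, the variable roles and the names `solve_discrete_bridge`, `P_w_given_z`, and `e_y_given_z` are hypothetical stand-ins. A linear-independence condition of the kind in Assumption 7 is what makes the solution unique.

```python
import numpy as np


def solve_discrete_bridge(P_w_given_z: np.ndarray, e_y_given_z: np.ndarray) -> np.ndarray:
    """Solve sum_w h(w) * P(W=w | Z=z) = E[Y | Z=z] for all z (hypothetical roles).

    P_w_given_z: shape (n_w, n_z); column z is the distribution P(W | Z=z).
    e_y_given_z: shape (n_z,).
    Returns the bridge vector h with shape (n_w,).
    """
    # In matrix form the bridge equation reads P_w_given_z.T @ h = e_y_given_z.
    h, _, rank, _ = np.linalg.lstsq(P_w_given_z.T, e_y_given_z, rcond=None)
    if rank < P_w_given_z.shape[0]:
        # Without linear independence, many h satisfy the same equation.
        raise ValueError("P(W | Z) is rank deficient; the bridge is not unique")
    return h


# Toy example with three proxy values and three values of Z.
P = np.array([[0.7, 0.2, 0.1],
              [0.2, 0.6, 0.3],
              [0.1, 0.2, 0.6]])  # each column sums to one
h_true = np.array([1.0, -0.5, 2.0])
b = P.T @ h_true                 # implied E[Y | Z=z] for each z
h_hat = solve_discrete_bridge(P, b)
```

Once `h` is recovered from one distribution, it can be reused wherever the bridge equation is invariant, which is the property the transfer arguments in Appendix B exploit.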
A.2 Existence of the Bridge Function
Sufficient conditions for the existence of were originally discussed in miao2018identifying; we adapt them to our setting and briefly review them in this section. We impose the following completeness assumption and regularity conditions. The completeness assumption is equivalent to Condition (iii) in miao2018identifying.
Assumption 8.
For any mean squared integrable function and for , almost surely if and only if almost surely.
Let be either the distribution from or ; we consider as the conditional expectation operator associated with the kernel function
Then it follows that :
To find the solution , we assume the following.
Assumption 9.
For any , .
This is a sufficient condition to ensure that is a compact operator (carrasco2007linear, Example 2.3). Hence, by the definition of a compact operator, there exists a singular system of for every .
Assumption 10.
For fixed :
1.
2. .
The above two assumptions are restatements of Conditions (v)–(vii) in miao2018identifying. We adapt the results from Proposition 1 in miao2018identifying to the graph in Figure 0(c), which replaces node by and node by .
Proposition A.1 (Existence of , adapted from Proposition 1 in miao2018identifying).
A.3 Existence of Bridge Function
The proof of the existence of is similar to the analysis of . Let be the integral operator associated with the kernel function . Then, we can write
Proposition A.2 (Existence of , Proposition 1 in miao2018identifying).
Assume that
1. For any mean squared integrable function and for , almost surely if and only if almost surely;
2. For any , ;
3. For any , ;
4. For any , , where is the singular system of .
Then the solution of exists.
A.4 Auxiliary Lemma
We state Picard’s theorem as follows.
Lemma A.3 (Picard’s Theorem).
Let be a compact operator with singular system and be a given function in . Then the equation of the first kind has solutions if and only if
1. , where is the null space of the adjoint operator .
2. .
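In finite dimensions, Picard's theorem reduces to a statement about the singular value decomposition. The sketch below uses a hypothetical helper, `picard_solve`, that is not part of the paper. It checks condition 1 by testing whether $g$ lies in the span of the left singular vectors with nonzero singular values (the orthogonal complement of the null space of $K^\top$), and it evaluates the series of condition 2. For a matrix that series is always finite; for a compact operator whose singular values decay to zero it can diverge, which is why condition 2 is needed there.

```python
import numpy as np


def picard_solve(K: np.ndarray, g: np.ndarray, tol: float = 1e-10):
    """Finite-dimensional version of Picard's theorem for K h = g.

    Returns (h, in_range, picard_sum):
      in_range   -- condition 1: g has no component in N(K^T), i.e. g lies in range(K);
      picard_sum -- the series of condition 2, sum_n <g, u_n>^2 / s_n^2;
      h          -- the minimum-norm solution sum_n (<g, u_n> / s_n) v_n.
    """
    U, s, Vt = np.linalg.svd(K)
    r = int(np.sum(s > tol))            # number of nonzero singular values
    coeffs = U[:, :r].T @ g             # <g, u_n> for those singular values
    residual = g - U[:, :r] @ coeffs    # component of g lying in N(K^T)
    in_range = bool(np.linalg.norm(residual) <= 1e-8 * max(1.0, np.linalg.norm(g)))
    ratios = coeffs / s[:r]
    picard_sum = float(np.sum(ratios ** 2))
    h = Vt[:r].T @ ratios
    return h, in_range, picard_sum
```

Note that `picard_sum` equals the squared norm of the minimum-norm solution, mirroring how condition 2 guarantees a square-integrable solution in the infinite-dimensional case.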
Appendix B Transferring Bridge Functions
In this section, we discuss the identifiability results.
B.1 Proof of Theorem 4.1
For , recall that
Similarly, we can write
Under Assumption 4, we have
(B.1)
almost surely with respect to , .
Suppose that such that . Then, by Assumption 5, we must have . Hence, conditioned on the selected and and under Assumption 1, we have
We then can write
Note that, by Assumption 1, we have and hence the left hand side of the above equation is and we can conclude that:
almost surely. We complete the first part of proof.
To show the second part of the theorem, note that we can write
Since by Assumption 1, we can factorize the above equation as
Let the support of conditioned on be and . Hence, we have , and such that and . Then, we can further decompose the above as
Given , since the support of is included in the support of , if , we must have and hence by Assumption 5, and we can swap with .
Since , we can add it to the above term and arrive at
(B.2)
Since we can identify from the observable of the source domain by solving the linear system (4.1), given observable from the target domain, we can identify .
B.2 Proof of Proposition 4.2
The following proof is a generalization of the proof of miao2018identifying, suited to the multidomain case. All variables besides are assumed to be discrete-valued and multivariate: can take values for .
Let .
Similarly, define
This notation carries through to the remaining variables.
The approach we will take differs from the concept case (and standard proxy case) in the following way: we do not observe in the training or test domains, nor do we know its true dimension (indeed may be continuous valued). Rather, we assume that we have at least distinct draws from in training, where is the domain index, and that We also suppose that in test, we observe a distinct draw which was not seen in training.
Our goal is to obtain a bridge function, which in the categorical case will be a bridge matrix of dimension . Define for . We assume that for each ,
which implies that varies with and that we see a sufficient diversity of domains to span the space of vectors on .
The graphical model supports the conditional independence relation
however we will only require the standard proxy assumptions
Next, as in the concept case, we require
where we assume (as in the first condition of Assumption 7). The matrix is invariant to the distribution by construction. If we can solve for , then given a novel domain corresponding to the draw , we have
This allows us to compute conditional expectations under in the novel domain, based on observations of in this domain.
To solve for , we project both sides on a basis over arising from the training domains,
where we define , and likewise Then the above becomes
(B.3)
This demonstrates that we can recover the domain-invariant purely from observed data.
One domain is not enough: We illustrate with an example, where we again consider the case where all variables are categorical:
(B.4)
where is a vector of probabilities, is a vector of probabilities, and is a matrix for which we wish to solve. We have too few equations for the number of unknowns.
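The counting argument can be checked numerically. In the hypothetical abstraction below, each training domain $k$ contributes one vector equation $H p_k = q_k$ for an unknown $a \times b$ matrix $H$, so a single domain supplies only $a$ equations for $ab$ unknowns. Stacking at least $b$ domains whose vectors $p_k$ are linearly independent identifies $H = Q P^{+}$. This mirrors the projection onto a basis of training domains in (B.3), but the exact vectors there were lost in extraction, so the names `solve_bridge_matrix`, `P`, and `Q` are illustrative only.

```python
import numpy as np


def solve_bridge_matrix(P: np.ndarray, Q: np.ndarray) -> np.ndarray:
    """Solve H @ P = Q for an unknown (a, b) matrix H.

    Column k of P (shape (b, K)) and of Q (shape (a, K)) come from training domain k.
    H is identified only when the K domain vectors p_k span R^b.
    """
    b = P.shape[0]
    rank = np.linalg.matrix_rank(P)
    if rank < b:
        raise ValueError(f"need {b} linearly independent domains, got rank {rank}")
    # With rank(P) == b, P P^+ = I_b, so H = Q P^+ solves the system exactly.
    return Q @ np.linalg.pinv(P)
```

With one domain the rank check fails immediately, which is the "too few equations" situation in (B.4).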
B.3 Proof of Proposition 4.3
For all , we can write
(B.5)
(B.6)
By Assumption 6, the integrands of (B.5)–(B.6) have the following property
(B.7) |
almost surely with respect to . We will show that can be transferred to identify the distribution in the target domain.
We define the support set . Therefore, we can write
Furthermore, since we have , we can apply (B.7) to obtain
We complete the proof.
Appendix C Estimation Procedure
The estimation procedure of is discussed in Section C.1 and the estimation procedure of is discussed in Section C.2. In Section C.3, we discuss the case when either or is a discrete variable.
C.1 Proof of Proposition 5.1
The proof of Proposition 5.1 follows from the result in (mastouri2021proximal), which builds on the representer theorem (scholkopf2001generalized). There exists a such that
(C.1)
From song2009hilbert, we have and is the -th element of , a function on : . If we expand (C.1) with the previous expression, we have
where . Hence, the rest of the proof focuses on finding the expression of . Following the proof technique developed in (mastouri2021proximal), we introduce the following two lemmas to assist the analysis.
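For reference, the empirical conditional mean embedding of song2009hilbert represents $\mathbb{E}[f(W) \mid X=x]$ as a weighted sum $\sum_i \beta_i(x) f(w_i)$ with $\beta(x) = (K_X + n\lambda I)^{-1} k_X(x)$. The sketch below computes these weights with a Gaussian kernel. The variable roles, the length scale, the names `gaussian_gram` and `cme_weights`, and the $n\lambda$ scaling of the ridge term are assumptions for illustration and may differ from the exact parameterization used in the main text.

```python
import numpy as np


def gaussian_gram(A: np.ndarray, B: np.ndarray, length_scale: float = 1.0) -> np.ndarray:
    """Gaussian kernel matrix exp(-||a - b||^2 / (2 * length_scale^2))."""
    sq = np.sum(A**2, 1)[:, None] + np.sum(B**2, 1)[None, :] - 2 * A @ B.T
    return np.exp(-np.maximum(sq, 0.0) / (2 * length_scale**2))


def cme_weights(X_train: np.ndarray, X_query: np.ndarray, lam: float,
                length_scale: float = 1.0) -> np.ndarray:
    """Weights beta with shape (n_train, n_query) such that
    E[f(W) | X = x_j] is approximated by sum_i beta[i, j] * f(w_i)."""
    n = X_train.shape[0]
    K = gaussian_gram(X_train, X_train, length_scale)
    k_q = gaussian_gram(X_train, X_query, length_scale)
    # Regularized solve (K + n * lam * I)^{-1} k_X(x) for every query point.
    return np.linalg.solve(K + n * lam * np.eye(n), k_q)
```

As the regularization vanishes, querying at a training point returns (approximately) the indicator vector of that point, a quick sanity check on the implementation.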
Lemma C.1.
The square of the operator norm of , denoted as , can be represented as
Proof of Lemma C.1.
Write
Using the fact that , the above display can be written as
∎
Lemma C.2.
For any , ,
With Lemmas C.1–C.2, we can write (5.4) as
(C.2)
where
Then by setting the gradient of (C.2) with respect to to zero, we will obtain
Applying the Woodbury matrix identity, the above display is equivalent to
(C.3)
Using the fact that for matrices , , we can simplify as
Hence, using the fact that , we have
Hence, we can write (C.3) as
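The switch between feature-space and Gram-matrix expressions in this kind of proof relies on identities of the push-through type, $(A^\top A + \lambda I)^{-1} A^\top = A^\top (A A^\top + \lambda I)^{-1}$, which follows from $(A^\top A + \lambda I) A^\top = A^\top (A A^\top + \lambda I)$ and is closely related to the Woodbury identity. Because the matrices in (C.3) were lost in extraction, we cannot confirm the exact form used there; the sketch below only verifies this generic identity numerically on a ridge problem with random, hypothetical data.

```python
import numpy as np


def primal_ridge(A: np.ndarray, y: np.ndarray, lam: float) -> np.ndarray:
    """(A^T A + lam I_d)^{-1} A^T y: solves a d x d system in feature space."""
    d = A.shape[1]
    return np.linalg.solve(A.T @ A + lam * np.eye(d), A.T @ y)


def dual_ridge(A: np.ndarray, y: np.ndarray, lam: float) -> np.ndarray:
    """A^T (A A^T + lam I_n)^{-1} y: solves an n x n system in Gram-matrix form."""
    n = A.shape[0]
    return A.T @ np.linalg.solve(A @ A.T + lam * np.eye(n), y)


rng = np.random.default_rng(0)
A = rng.normal(size=(5, 8))   # n = 5 samples, d = 8 features
y = rng.normal(size=5)
```

The dual form is what makes kernel estimators practical: it only touches the $n \times n$ Gram matrix, even when the feature space is infinite-dimensional.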
C.2 Proof of Kernel Bridge Function
We begin with the result.
Proposition C.3.
Let , be the Gram matrices of and , respectively. Let , be the cross Gram matrices of and , respectively. For any , there exists a unique optimal solution to (5.6) of the form
where , , and .
C.3 Estimation with discrete or
When or is a discrete variable, a more efficient alternative to the estimator introduced in Section 5.1, which requires kernelized features of (or ), is to solve a separate regression of on for each (or ). Defining the index set , we modify (5.3) as
where and with . Alternatively, one can apply the form in (5.3) but use a binary kernel on (or ).
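The two options described here coincide in a simple case. With a product kernel $k(x, x')\,\mathbb{1}[c = c']$, the Gram matrix is block diagonal across the discrete values, so a single kernel ridge regression decomposes into one regression per value. The sketch below checks this numerically; it assumes an unscaled ridge term $\lambda I$ (if the penalty is scaled by the sample size, the per-group penalty must be adjusted accordingly), and every function name in it is hypothetical.

```python
import numpy as np


def rbf(A, B, ls=1.0):
    """Gaussian kernel matrix exp(-||a - b||^2 / (2 ls^2))."""
    sq = np.sum(A**2, 1)[:, None] + np.sum(B**2, 1)[None, :] - 2 * A @ B.T
    return np.exp(-np.maximum(sq, 0.0) / (2 * ls**2))


def krr_predict(K_train, y, K_test_train, lam):
    # Kernel ridge regression with an unscaled ridge term lam * I.
    alpha = np.linalg.solve(K_train + lam * np.eye(len(y)), y)
    return K_test_train @ alpha


def per_category_krr(X, c, y, X_new, c_new, lam, ls=1.0):
    """Fit a separate Gaussian-kernel ridge regression for each discrete value."""
    pred = np.empty(len(X_new))
    for v in np.unique(c_new):  # assumes every test value also occurs in training
        tr, te = c == v, c_new == v
        pred[te] = krr_predict(rbf(X[tr], X[tr], ls), y[tr],
                               rbf(X_new[te], X[tr], ls), lam)
    return pred


def binary_product_krr(X, c, y, X_new, c_new, lam, ls=1.0):
    """A single regression with the product kernel k(x, x') * 1[c == c']."""
    K = rbf(X, X, ls) * (c[:, None] == c[None, :])
    K_new = rbf(X_new, X, ls) * (c_new[:, None] == c[None, :])
    return krr_predict(K, y, K_new, lam)
```

The per-value version solves several small systems instead of one large one, which is where the efficiency gain comes from.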
Appendix D Experiments
In this section, we discuss the experimental settings and implementation details. We start by introducing the implementation details of all baselines and the proposed method. Then, we discuss the experimental settings.
D.1 Baselines of Adaptation with Concepts and Proxies
We introduce the baseline methods for the adaptation task with and . These include the baseline methods COVARS, LABELS, ORACLE, LSA-W, LSA-S, and LSA-S w/ target , as well as the proposed method. To select the parameters for the regression task on dSprites, we apply five-fold cross-validation with mean squared error as the metric to choose the kernel length scale and the ridge regularization penalty.
COVARS. We fit a domain classifier using logistic regression, compute instance weights following shimodaira2000improving, and learn a weighted kernel ridge regressor with a Gaussian kernel function on the source training samples.
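A minimal sketch of this baseline with scikit-learn, assuming the usual classifier-based density-ratio estimate $p_{\text{tgt}}(x)/p_{\text{src}}(x) = \frac{P(\text{tgt}\mid x)}{P(\text{src}\mid x)} \cdot \frac{n_{\text{src}}}{n_{\text{tgt}}}$. The function names and the values of `alpha` and `gamma` are hypothetical; in the experiments the kernel hyperparameters are chosen by cross-validation.

```python
import numpy as np
from sklearn.kernel_ridge import KernelRidge
from sklearn.linear_model import LogisticRegression


def covariate_shift_weights(X_src, X_tgt):
    """Importance weights p_tgt(x) / p_src(x) from a probabilistic domain classifier."""
    X = np.vstack([X_src, X_tgt])
    d = np.r_[np.zeros(len(X_src)), np.ones(len(X_tgt))]  # 1 marks the target domain
    clf = LogisticRegression().fit(X, d)
    p = clf.predict_proba(X_src)[:, 1]                     # P(target | x)
    # Bayes' rule: p_tgt(x) / p_src(x) = [P(target|x) / P(source|x)] * (n_src / n_tgt).
    return p / (1.0 - p) * (len(X_src) / len(X_tgt))


def fit_covars(X_src, y_src, X_tgt, alpha=1.0, gamma=1.0):
    """Importance-weighted Gaussian-kernel ridge regression on source data."""
    w = covariate_shift_weights(X_src, X_tgt)
    model = KernelRidge(alpha=alpha, kernel="rbf", gamma=gamma)
    return model.fit(X_src, y_src, sample_weight=w)
```

Source points that look like target points receive larger weights, so the regressor concentrates its fit where the target covariates lie.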
LABELS. The label shift baseline assumes oracle access to labels in the target domain. For the classification task, we compute instance weights using the observed frequencies in the validation set for the source domain and the training set for the target domain. For the regression task, we compute the weights by fitting a Gaussian kernel density estimator using the source validation set and the target training set separately. We then use the fitted densities to estimate for each sample in the source training set. Finally, we learn a sample-weighted kernel ridge regressor with a Gaussian kernel on the source training samples.
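For the regression variant, the weights are density ratios of the label distributions. A minimal sketch using SciPy's Gaussian KDE, with a hypothetical function name and the default bandwidth rule (the bandwidth used in the experiments is not stated here). The resulting weights would then be passed as sample weights to a kernel ridge regressor, for example through scikit-learn's `KernelRidge.fit(X, y, sample_weight=w)`.

```python
import numpy as np
from scipy.stats import gaussian_kde


def label_shift_weights(y_src_val, y_tgt_train, y_src_train):
    """Weights p_tgt(y) / p_src(y) for each source training label via Gaussian KDEs."""
    p_src = gaussian_kde(y_src_val)     # fit on the source validation labels
    p_tgt = gaussian_kde(y_tgt_train)   # fit on the target training labels
    # Evaluate both densities at the source training labels and take their ratio.
    return p_tgt(y_src_train) / p_src(y_src_train)
```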
ORACLE. For regression tasks, we learn a kernel ridge regressor with a Gaussian kernel on target training samples. For the classification task, we use a standard MLP trained with samples from the target domain. Details of the model structure are documented in Section D.2.
LSA-W. The estimation procedure follows Section 6 in alabdulmohsin2023adapting. In this case, we discretize the values of by applying an additional transform for each sample .
LSA-S. The estimation procedure follows Algorithm 2–5 in alabdulmohsin2023adapting.
LSA-S w/ target . We briefly describe the procedure to incorporate target into LSA-S. alabdulmohsin2023adapting showed that can be decomposed as
(D.1)
(D.2)
where is a permutation of original . Both LSA-WAE and LSA-S are multi-stage procedures to compute (a), (c), (d) individually and combine the results using formula (D.2) to obtain the predicted target distribution. Step (a) corresponds to Algorithm 5, (c) corresponds to Equation (17), and (d) corresponds to Algorithm 4 in (alabdulmohsin2023adapting).
With the additional from the target, we can obtain (b) by slightly modifying one of the estimation steps in LSA-S. We test this procedure, denoted LSA-S w/ target W, with (c) and (d) replaced by (b). Suppose that takes values in and let be a permutation of . Define the matrix as:
where is the estimated conditional kernel density function obtained by Algorithm 3 in alabdulmohsin2023adapting. The step (b) is computed by solving the following least-squares:
Then, we compute the predicted conditional probability based on (D.1).
Proposed Method. For the regression task using the dSprites dataset, we use the Gaussian kernel function as the feature map for both and . In the classification task, we also use the Gaussian kernel function for and . Additionally, we use a columnwise binary kernel for , which applies a binary kernel to each entry and takes the product of the resulting outputs. To compute , we apply a one-hot encoder to and use the results in Proposition 5.1. To choose the kernel length scale for the classification task, we use the validation set with the AUROC metric.
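Because the columnwise binary kernel multiplies per-entry indicators, it equals one exactly when two vectors agree in every entry and zero otherwise. A minimal NumPy sketch (the function name is ours, not from the paper):

```python
import numpy as np


def columnwise_binary_kernel(A, B):
    """k(a, b) = prod_j 1[a_j == b_j] for every pair of rows of A and B."""
    A = np.asarray(A)
    B = np.asarray(B)
    # Broadcast to (n_A, n_B, n_cols) and require equality in every column.
    return np.all(A[:, None, :] == B[None, :, :], axis=-1).astype(float)
```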
D.2 Baselines of Multi-Source Adaptation
For the first three baselines, Cat-ERM, Avg-ERM, and SA, we use a standard MLP as the backbone. It is a single-hidden-layer MLP with hidden size and ReLU activations. The network is trained using the Adam optimizer (kingma2014adam) with learning rate . The batch size is set to and the maximum number of iterations is set to .
Cat-ERM. We concatenate all the samples across environments into one dataset. Then, we train a standard MLP as specified above.
Avg-ERM. For each environment, we train a standard MLP model. During testing, we take the average of predictions from all models.
Simple Adaptation (SA) (mansour2008domain). To implement the method, we build kernel density estimators with a Gaussian kernel function to estimate the density for . We then reweight the output of each domain's classifier, a standard MLP, with the normalized weight . The kernel length scale is chosen using five-fold cross-validation with the AUROC metric.
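A minimal sketch of this density-weighted combination, assuming uniform mixing weights across source domains so that each domain's prediction is weighted by $D_k(x) / \sum_j D_j(x)$. The function name and the use of SciPy's default KDE bandwidth are assumptions; the experiments select the length scale by cross-validation.

```python
import numpy as np
from scipy.stats import gaussian_kde


def simple_adaptation_predict(X_domains, domain_predictors, X_new):
    """Combine per-domain classifiers with normalized KDE density weights.

    X_domains: list of (n_k, d) arrays of covariates from each source domain.
    domain_predictors: list of callables mapping an (m, d) array to P(Y=1 | x).
    Assumes uniform mixing weights across domains.
    """
    # gaussian_kde expects data with shape (d, n) and evaluation points with shape (d, m).
    dens = np.stack([gaussian_kde(Xk.T)(X_new.T) for Xk in X_domains])  # (K, m)
    weights = dens / dens.sum(axis=0, keepdims=True)                    # sums to 1 per point
    preds = np.stack([f(X_new) for f in domain_predictors])            # (K, m)
    return np.sum(weights * preds, axis=0)
```

Near the bulk of one source domain, the combined prediction is dominated by that domain's classifier.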
Marginal Kernel (MK) (blanchard2011generalizing). This method involves a kernel SVM with a product kernel on . For any and a distribution on , , the kernel function is defined as . Let be the number of samples. Here is a Gaussian kernel function, and is the mean of the Gram matrix , where for is an i.i.d. sample from and for is an i.i.d. sample from . To accommodate the large dataset, we precompute the Gram matrix and use it as input to a linear classifier trained with stochastic gradient descent (SGD), as implemented in scikit-learn (scikit-learn). The kernel length scale is chosen using five-fold cross-validation with the AUROC metric.
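The precomputed Gram matrix can be sketched as follows: the kernel between two distributions is estimated by the mean of the cross Gram matrix between their samples, and it is multiplied entrywise with a Gaussian kernel on the covariates. The function names and length scales are hypothetical; since both factors are positive semidefinite kernels, their entrywise product is as well.

```python
import numpy as np


def rbf(A, B, ls=1.0):
    """Gaussian kernel matrix exp(-||a - b||^2 / (2 ls^2))."""
    sq = np.sum(A**2, 1)[:, None] + np.sum(B**2, 1)[None, :] - 2 * A @ B.T
    return np.exp(-np.maximum(sq, 0.0) / (2 * ls**2))


def marginal_kernel_gram(domains, ls_p=1.0, ls_x=1.0):
    """Gram matrix of the product kernel k_P(P_k, P_l) * k_x(x, x').

    domains: list of (n_k, d) arrays, each an i.i.d. sample from one distribution P_k.
    k_P is estimated as the mean of the cross Gram matrix between two samples.
    """
    K_p = np.array([[rbf(a, b, ls_p).mean() for b in domains] for a in domains])
    X = np.vstack(domains)
    idx = np.concatenate([np.full(len(dk), k) for k, dk in enumerate(domains)])
    # Expand the domain-level kernel to all point pairs and multiply entrywise.
    return K_p[np.ix_(idx, idx)] * rbf(X, X, ls_x)
```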
Weighted Combination of Source Classifiers (WCSC) (zhang2015multi). For each source environment, we estimate the conditional probability using kernel density estimator with the Gaussian kernel function. The rest of the estimation procedure follows Section 2 in zhang2015multi. The kernel length scale is chosen using five-fold cross-validation with AUROC metric.
Proposed Method. We use a columnwise Gaussian kernel function as the feature map of and a Gaussian kernel function as the feature map of . The conditional mean embedding is estimated using the approach introduced in Section C.3. The analytical solution of is discussed in Proposition C.3. All kernel length scales and the regularization parameters , are selected using five-fold cross-validation with the AUROC metric.
ORACLE. The model is , where both the bridge function and are estimated using the target dataset, with the number of training samples equal to that of the source domain. All kernel length scales and the regularization parameters , are selected using five-fold cross-validation with the AUROC metric.
D.3 Classification Task
The classification task discussed in Section D.6 was first introduced by alabdulmohsin2023adapting. Let be the one-hot encoder; we follow their data generation procedure and generate samples using the following process:
where the matrices are defined as
The coefficient in Figure 2. Figure 4 displays additional results where . We generate training samples, validation samples, and testing samples for the classification task with concepts and proxies.
In the multi-domain case, we construct different tasks: Task is composed of such that , , and a target domain with . For task , we select such that , , and . For task , we select such that , , and . The results are shown in Table 1– 2.
D.4 Comparison to Domain Generalization Baselines
ORACLE | ARM | CDANN | CORAL | DANN | GroupDRO | IRM | MMD | VREx | Proposed | |
---|---|---|---|---|---|---|---|---|---|---|
Task 1 | ||||||||||
Task 2 | ||||||||||
Task 3 | ||||||||||
Given that we observe multiple domains at test time, a natural question is: Does adaptation give us an advantage over generalization? In generalization, we cannot assume to have any observations in the target domain. We compare our adaptation method with multi-domain generalization baselines (muandet2013domain): Adaptive Risk Minimization (ARM) (zhang2021adaptive), Conditional Domain Adversarial Neural Networks (CDANN) (long2018conditional), Correlation Alignment (CORAL) (sun2016deep), Domain Adversarial Neural Networks (DANN) (ganin2016domain), Distributionally Robust Optimization for Group Shifts (GroupDRO) (sagawa2019distributionally), Invariant Risk Minimization (IRM) (arjovsky2019invariant), Maximum Mean Discrepancy (MMD) (Borgwardt2006IntegratingSB), and Risk Extrapolation (REx) (krueger2021out).
In Table 2, we show that our proposed method for domain adaptation in the multi-domain setting outperforms the state-of-the-art multi-domain generalization methods.
D.5 Regression Tasks
We consider three tasks. We first introduce the simulated tasks and then discuss the task on the dSprites data (dsprites17).
D.5.1 Simulated Dataset
We consider the following data generation process.
Simulated regression task 1.
(D.3)
There are two source domains. We set for source domain and for source domain . According to the data generation process (D.3), is mostly positively correlated with in domain and negatively correlated with in domain . For each domain, we synthesize training samples and testing samples. We sweep across in the target domain. We run replications, and the results are shown in Figure 5. In the next task, we set to be a continuous random variable following a Beta distribution.
In this task, we expect the Cat-ERM method to fail drastically, as we anticipate that the predicted versus is a flat line: the prediction would be an average of the downward sloping line and the upward sloping line . The result in Figure 5 supports our hypothesis, as the mean squared error remains nearly flat as we vary the target distribution .
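The averaging effect can be reproduced with a toy example. The sketch below does not use the exact process (D.3), whose coefficients were lost in extraction; it assumes two hypothetical source domains with the same covariate distribution and opposite slopes. Pooling them, as Cat-ERM does, gives $\mathbb{E}[y \mid x] = \tfrac{1}{2}x - \tfrac{1}{2}x = 0$, so even a flexible model fit on the concatenated data targets a flat line.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 2000
# Hypothetical source domains: same covariate distribution, opposite slopes.
x1 = rng.normal(size=n)
y1 = x1 + 0.1 * rng.normal(size=n)
x2 = rng.normal(size=n)
y2 = -x2 + 0.1 * rng.normal(size=n)

# Cat-ERM-style pooling: fit one model on the concatenated data.
x_pool, y_pool = np.r_[x1, x2], np.r_[y1, y2]
slope, intercept = np.polyfit(x_pool, y_pool, deg=1)  # close to (0, 0)
```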
Simulated regression task 2.
There are two source domains, corresponding to two draws from which we write . We set for the first source domain , and for the second source domain . The corresponding distributions over are shown in Figure 6. Under this setting, we test the target domain with , with distributions shown in Figure 6. For each domain, we synthesize training samples and testing samples. We run replications, and the results are shown in Figure 5.
D.6 Adaptation with Concepts and Proxies
D.6.1 dSprites Dataset
We test the proposed procedure on the dSprites dataset (dsprites17), an image dataset described by five latent parameters (shape, scale, rotation, posX, and posY). Motivated by dsprites17’s experiments, we design a regression task where the dSprites images (64 × 64 = 4096-dimensional) are and subject to a nonlinear confounder which is a rotation of the image (Figure 7). We fix all other latent parameters: shape is heart, scale is maximized, and all others are set to their 0th position. and are continuous random variables. The data generation process is defined as follows
When fitting all models, both baselines and the proposed method, we project the images to via Gaussian Random Projection using the scikit-learn implementation (Bingham2001RandomPI; scikit-learn). Additionally, for the proposed method, we use a Gaussian kernel as the feature map for .
We generate training samples and test samples in our experiments. Then, we use five-fold cross-validation to select hyperparameters for baselines and proposed method for each () – hyperparameters are (i) ridge regression penalty and (ii) Gaussian kernel scaling factor. Once we select a set of hyperparameters for a value of , we perform 10 new random data regenerations to get transfer errors with 95% confidence intervals (Figure 2).
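The preprocessing and hyperparameter search described above can be sketched with scikit-learn for a plain kernel ridge regressor. The projected dimension is elided in the text, so `n_components` is a hypothetical argument, and the grids for the ridge penalty (`alpha`) and the Gaussian kernel scaling factor (`gamma`) are illustrative rather than the values used in the experiments.

```python
import numpy as np
from sklearn.kernel_ridge import KernelRidge
from sklearn.model_selection import GridSearchCV
from sklearn.random_projection import GaussianRandomProjection


def fit_projected_krr(images, y, n_components=64, seed=0):
    """Project flattened images, then pick KRR hyperparameters by five-fold CV."""
    X = images.reshape(len(images), -1)  # e.g. 64 x 64 images -> 4096 features
    proj = GaussianRandomProjection(n_components=n_components, random_state=seed)
    Z = proj.fit_transform(X)
    search = GridSearchCV(
        KernelRidge(kernel="rbf"),
        param_grid={"alpha": [1e-3, 1e-2, 1e-1, 1.0],   # ridge penalty (hypothetical grid)
                    "gamma": [1e-3, 1e-2, 1e-1]},       # kernel scaling (hypothetical grid)
        cv=5,
        scoring="neg_mean_squared_error",
    )
    search.fit(Z, y)
    return proj, search.best_estimator_, search.best_params_
```

New images must go through the same fitted projection (`proj.transform`) before prediction.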
D.7 Classification of radiological findings with MIMIC-CXR
We conduct a small-scale experiment with chest X-ray data extracted from the MIMIC-CXR dataset (johnson2019mimic). We consider classification of the absence of a radiological finding in a chest X-ray. For this, we use the set of labels extracted by irvin2019chexpert. These labels correspond to 14 categories of radiological findings extracted based on mentions in the associated radiology reports. We specifically consider classification of the “No Finding” () label, corresponding to cases where no pathology was identified as positive or uncertain in the radiology report.
To define the dataset, we consider the set of 217,536 chest X-rays with defined Chexpert labels (irvin2019chexpert), MIMIC-IV entries, and pretrained embeddings (sellergren2022simplified). We then filter this dataset to the 212,567 examples considered as a part of the “train” partition provided by the MIMIC-CXR database (johnson2019mimic). We then partition the data into training, validation, and testing splits such that 80%, 10%, and 10% of the examples belong to each partition, respectively. For adaptation, we consider BioBERT (lee2020biobert) 768-dimensional embeddings of the radiology reports as concepts and the patient’s age as a proxy variable . For simplicity, we use the patient anchor_age defined through linkage to the MIMIC-IV database, regardless of the patient’s age at the time of the chest X-ray. Similar to the dSprites experiment, we further reduce the dimensionality of and to using Gaussian Random Projection fit on the full training partition (170,053 examples).
To define distribution shifts, we adopt a problem formulation similar to that of makar2022causally, where patient sex is considered as a possible “shortcut" in the classification of the absence of a radiological finding. As in makar2022causally, we impose distribution shift through structured resampling of the data where . For example, when , the prevalence of and . We implement the shift through a weighted sampling procedure that maintains the label shift invariance within patient sex subgroups, i.e., preserves under the distribution shift, where corresponds to patient sex. This procedure further fixes the total proportion of male and female patients in the population at 50%. For our experiments, we consider nine domains corresponding to cases where .
We perform both concept adaptation and multi-domain adaptation experiments with the MIMIC-CXR data. For the concept adaptation experiment, we perform weighted sampling with replacement of 1,000 examples from each of the training, validation, and testing partitions defined previously, separately for each domain. We fix the source domain to the case where and then adapt to each of the nine target domains. For the multi-domain adaptation experiment, we randomly sample 500 examples per domain and partition from the sets of 1,000 examples defined for the concept experiment. For this experiment, we consider a case where two source domains corresponding to and are available. To match the size of the aggregate source domain data with the size of the target domain, we sample 250 examples per partition for each source domain. We repeat the sampling procedure five times and report the mean and standard deviation of performance metrics over the five replicates.
For both experiments, we perform two-fold cross-validation for the kernel length-scale parameters using data from the source domain(s). Here, we compare to ridge logistic regression models fit in the source and target domains, with the ridge penalty selected by five-fold cross-validation. We use LR-TARGET to refer to logistic regression models fit in a target domain, LR-SOURCE to refer to models fit in a source domain, and Cat-LR to refer to logistic regression models fit with concatenated data from the multiple source domains. We use Bridge-SOURCE to refer to the kernel estimator that leverages the bridge function ( or for the concept and multi-domain adaptation settings, respectively) and conditional mean embedding ( or ) fit on the source domain data. Bridge-TARGET refers to the kernel estimator where both the bridge function and conditional mean embedding are fit on the target domain data.