Longitudinal Targeted Minimum Loss-based Estimation with Temporal-Difference Heterogeneous Transformer
Abstract
We propose Deep Longitudinal Targeted Minimum Loss-based Estimation (Deep LTMLE), a novel approach to estimating the counterfactual mean outcome under dynamic treatment policies in longitudinal settings. Our approach uses a transformer architecture with heterogeneous type embedding trained with temporal-difference learning. After obtaining an initial estimate from the transformer, we follow the targeted minimum loss-based estimation (TMLE) framework to statistically correct for the bias commonly associated with machine learning algorithms. Our method also supports statistical inference by providing 95% confidence intervals grounded in asymptotic statistical theory. Simulation results demonstrate our method’s superior performance over existing approaches, particularly in complex scenarios with long time horizons. It remains effective in small-sample, short-duration contexts, matching the performance of asymptotically efficient estimators. To demonstrate our method in practice, we apply it to estimate counterfactual mean outcomes for standard versus intensive blood pressure management strategies in a real-world cardiovascular epidemiology cohort study.
1 Introduction
In medicine and public health, researchers frequently encounter data that are both high-dimensional and longitudinal. The outcomes of interest in these settings often involve the time to a failure event, such as total mortality (van der Laan & Robins, 2003; Salerno & Li, 2023). Estimating the counterfactual probability of the event is challenging in high-dimensional longitudinal settings: existing methods scale poorly, and their statistical performance degrades because of the curse of dimensionality (Wyss et al., 2022). In response, we propose an estimator that is computationally scalable and simultaneously allows for robust statistical inference. Our estimator incorporates a transformer architecture for estimating the target estimand, defined as the cumulative incidence probability under dynamic interventions, where the treatment sequence depends on patients’ evolving histories. The target estimand can be identified through the g-formula under suitable assumptions (Robins, 1986). However, the target functional involves integration over potentially high-dimensional time-dependent covariates across the time horizon, which poses computational challenges. Our method advances the longitudinal targeted minimum loss-based estimation (LTMLE) framework (van der Laan & Gruber, 2012; Lendle et al., 2017) by leveraging the computational capabilities of the transformer to estimate the target estimand and the relevant nuisance parameters.
A number of estimators for the target estimand have been proposed since the pioneering work of Robins (1986). These estimators first express the target parameter as a functional of nuisance parameters, given a structural assumption on the underlying variables. A common strategy for constructing an estimator is then the plug-in approach: one estimates the nuisance components with some models and plugs them into the target functional. However, since naively plugging in the estimated nuisance components causes bias, several methods have been proposed to remove this bias using the first variation of the target functional, called the influence function. Examples of such de-biasing techniques include one-step estimators (Klaassen, 1987; Bickel et al., 1993), estimating equations (Robins et al., 1994; Chernozhukov et al., 2022), and targeted minimum loss-based estimation (TMLE) (van der Laan & Rose, 2011). Notably, TMLE stands out because, as a plug-in estimator, it respects any conditional bounds on the outcome or global bounds on the statistical model, which improves finite-sample performance (Gruber & van der Laan, 2012).
The first-order bias of the plug-in estimator is the population mean of the influence function evaluated at the estimated nuisance distribution. Bias correction is performed by setting the empirical analogue of this term to zero. TMLE achieves this by optimizing a loss function along a submodel that passes through the initial nuisance estimate (Bang & Robins, 2005; van der Laan & Rubin, 2006; van der Laan & Rose, 2011). The loss function and the submodel are chosen so that the linear span of the derivative of the loss function along the submodel contains the efficient influence function, that is, the influence function with minimal variance. Targeting refers to this correction, performed by fluctuating the initial estimate along the submodel.
The current LTMLE, a TMLE developed for longitudinal data, relies on a sequential regression representation of the target estimand (Bang & Robins, 2005). An ensemble machine learning technique called the super learner is then used to estimate the nuisance components of the data-generating distribution (van der Laan et al., 2007). In complex real-world longitudinal data, these nuisance components, such as the survival probability at a given time, may depend on the entire past history. Therefore, the Markov property, which states that future values depend only on the present variables and are independent of the past, is not guaranteed to hold. In other words, every observed variable may depend on all preceding variables in the time ordering. Hence, we want our model for the nuisance components to accept histories of variable length as input. Within the targeted learning framework, we introduce a transformer architecture tailored to our longitudinal setting and propose a novel bias-correction method that uses a single fluctuation parameter across all time points.
Our contributions are as follows: 1) we develop a general method that uses a transformer architecture to facilitate valid statistical inference in longitudinal settings with survival outcomes under dynamic interventions; 2) we propose a bias-correction method that uses a one-dimensional fluctuation for a time horizon of any length; 3) we demonstrate statistical performance competitive with asymptotically efficient estimators in simple, low-dimensional settings, and superior statistical and computational performance in more complex settings; and 4) we apply our method to real-world medical data and present the results in a format that aligns with clinical research guidelines.
2 Related Work
In the data science literature, several methods have been proposed to predict counterfactual outcomes from patient history, including G-Net (Li et al., 2021), the counterfactual recurrent network (CRN) (Bica et al., 2020), and the causal transformer (CT) (Melnychuk et al., 2022). However, their target parameters do not involve survival outcomes, and their methods are optimized for the mean squared error (MSE) of individual predictions rather than for statistical inference. Closely related to the present study is DeepACE (Frauen et al., 2023), which uses deep neural networks to estimate all propensity scores and outcome regressions simultaneously. It also has an additional targeting layer that implements the one-dimensional submodel proposed by van der Laan (van der Laan & Rose, 2018). Our method differs from theirs in three main aspects. First, DeepACE incorporates the targeting step into its loss function, which requires an additional hyperparameter; there is little justification for the chosen value of this hyperparameter and no guidance on tuning it in practice. Our approach, in contrast, separates the targeting step, aligning more closely with the TMLE literature. Second, DeepACE does not address survival outcomes; specifically, it does not account for the degeneracy of the process after a patient’s event occurs. Third, while DeepACE uses the long short-term memory (LSTM) architecture, our method employs transformers, which are generally better at capturing long-range dependencies and can be trained more efficiently than LSTMs because their computation parallelizes across sequence positions. Moreover, DeepACE does not provide uncertainty measures such as confidence intervals, which limits its utility for statistical inference.
Our problem of estimating the mean counterfactual outcome from longitudinal observational data under dynamic interventions has been extensively investigated as off-policy evaluation in the bandit and reinforcement learning literature (Levine et al., 2020). Methods that use the influence function to correct the bias of a plugged-in initial estimate have also been introduced in this context (Jiang & Li, 2016; Farajtabar et al., 2018; Narita et al., 2021); however, these works did not provide tools for inference. Double reinforcement learning (Kallus & Uehara, 2020) uses efficient influence functions to correct plug-in bias and proves efficiency, in the spirit of double machine learning (Chernozhukov et al., 2018), which is a closed form of the more general debiased estimating equation framework (Chernozhukov et al., 2022). TMLE, in contrast, deforms the distribution itself to correct the bias before it is plugged into the target functional, so the resulting values stay within the domain of the functional.
3 Problem Formulation
In this section, following the roadmap of causal inference (Petersen & van der Laan, 2014; van der Laan & Rose, 2018; Dang et al., 2023), we first describe the experiment that generated the observed data and the statistical model that contains the data-generating distribution. Next, we define our causal target parameter. Then, we discuss the assumptions needed to identify our target parameter from the observed data. Finally, we outline the statistical approach for constructing an estimator and correcting its bias.
3.1 Data
We consider the general longitudinal setting involving repeated measurements of a set of variables for a group of patients over a period of time. In particular, our observed data consist of independent and identically distributed copies of the random vector
(1) |
with baseline covariates , time-dependent covariates , treatments , and outcome . We use to denote the true probability distribution of that generated the data, and is in some statistical model . The stopping time is a random variable (e.g., the time of death in survival analysis), and we use to denote the maximum time. We remark that in real-world data, patients are often subject to censoring. For a formulation of the data structure involving censoring nodes, see Appendix H.
3.2 Target Parameter
To define the target parameter, we introduce a structural causal model (SCM). In brief, an SCM assumes that each observed random variable is generated from its parent nodes and an external noise term by a production function as . By abuse of notation, we also denote the induced probability measure of by the same symbol . See Appendix C.1 for details.
Our target parameter is the counterfactual mean of the final outcome under a user-specified dynamic treatment policy , where is a probability measure on the treatment space conditioned on the whole history up to, but not including, . Specifically, our target parameter is given by
(2) |
which is the mean of the counterfactual outcome produced by replacing , defined as the observed treatment policy from the data, with in the structural causal model.
Identification
Under the positivity assumption:
(3) |
and the sequential randomization assumption:
(4) |
we can identify our target causal parameter through the g-formula as the mean of under the counterfactual distribution obtained by replacing the distributions with (Robins, 1986):
(5) |
Note that the consistency assumption , usually stated in the causal inference literature, is a consequence of the definition of the counterfactual outcome in our SCM. The problem thus reduces to estimating the statistical parameter:
(6) |
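To make the g-formula concrete, the following sketch evaluates it exactly for a toy two-period model with binary covariates, treatments, and outcome under a deterministic dynamic policy: the treatment mechanism is replaced by the policy, and the outcome mean is integrated over covariate histories. All conditional distributions, the policies `treat_if_elevated` and `never_treat`, and the function names are hypothetical illustrations, not the paper's data-generating process. With continuous or high-dimensional covariates this exhaustive enumeration becomes intractable, which is the computational challenge noted in the Introduction.

```python
from itertools import product

# Hypothetical toy model (not from the paper): binary L0, A0, L1, A1 and outcome Y.
def p_l0(l0):
    """P(L0 = l0)."""
    return 0.4 if l0 == 1 else 0.6

def p_l1(l1, l0, a0):
    """P(L1 = l1 | L0 = l0, A0 = a0): the covariate persists, treatment lowers it."""
    p1 = 0.2 + 0.5 * l0 - 0.15 * a0
    return p1 if l1 == 1 else 1.0 - p1

def mean_y(l0, a0, l1, a1):
    """E[Y | L0, A0, L1, A1]: probability of the event at the end of follow-up."""
    return 0.15 + 0.3 * l1 + 0.1 * l0 - 0.05 * a0 - 0.05 * a1

def treat_if_elevated(history):
    """Hypothetical dynamic policy: treat whenever the latest covariate equals 1."""
    return 1 if history[-1] == 1 else 0

def never_treat(history):
    """Static comparison policy."""
    return 0

def g_formula(policy):
    """Sum over covariate histories with the treatment mechanism replaced by `policy`."""
    psi = 0.0
    for l0, l1 in product((0, 1), repeat=2):
        a0 = policy((l0,))
        a1 = policy((l0, a0, l1))
        psi += p_l0(l0) * p_l1(l1, l0, a0) * mean_y(l0, a0, l1, a1)
    return psi
```

With these toy numbers, `g_formula(treat_if_elevated)` is 0.255 and `g_formula(never_treat)` is 0.31 (up to floating-point rounding).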
3.3 Targeted Minimum Loss-based Estimation
Given an estimator $\hat P$ of the data-generating distribution $P_0$, a natural estimator of the target functional is the plug-in estimator $\Psi(\hat P)$. Under a regularity condition, $\Psi(\hat P)$ admits the following first-order expansion:
$$\Psi(\hat P) - \Psi(P_0) = -P_0 D^*(\hat P) + R_2(\hat P, P_0), \qquad (7)$$
where $D^*(P)$ is the efficient influence function of $\Psi$ at $P$, and $R_2$ is the exact second-order remainder. The influence function quantifies how an estimator changes under small perturbations of the data distribution, and the efficient influence function is the influence function with minimal variance. The idea of TMLE is to eliminate the empirical analogue of the first term on the right-hand side by fluctuating $\hat P$ to find a distribution $P^*$ with $P_n D^*(P^*) = 0$, where $P_n$ is the empirical distribution and $Pf = \int f \, dP$ is shorthand for the expectation of a measurable function $f$ with respect to a probability measure $P$. Our task is therefore to obtain an initial estimate $\hat P$ from potentially large-scale, high-dimensional longitudinal data and to correct the bias of the plug-in estimator by fluctuating $\hat P$.
4 Proposed Method
In this section, we describe our proposed method, Deep Longitudinal Targeted Minimum Loss-based Estimation (Deep LTMLE). Let
(8) |
be the mean outcome at the stopping time given the history before node for , where future treatments follow a counterfactual treatment assignment policy . Similarly,
(9) |
is the mean outcome at the stopping time given the history before node , for . We abbreviate for if it is clear from the context, and similarly for . Our goal is to estimate by
(10) |
where is an estimate of such that is asymptotically efficient. We achieve this by proposing a temporal-difference heterogeneous transformer that yields an initial estimate , which we then update to via targeted minimum loss-based estimation (TMLE).
4.1 Temporal-Difference Heterogeneous Transformer
To learn the initial model , we use a temporal-difference loss as the objective to learn the underlying models for via stochastic gradient descent (SGD). The principle of temporal-difference learning (Sutton, 1988; Mnih et al., 2013) is to train to obey the temporal equality of :
(11) |
for and . The temporal-difference loss on a sample trajectory is thus given by for , where can be computed by Monte Carlo estimation if is continuous, and . In survival analysis, the components for are defined as , where is the binary cross-entropy loss. To obtain the updated model , we need to adjust after model training, taking into account the estimated model for the propensity score
(12) |
which we will describe in detail in the next section. Hence, the loss function also needs to include and is thus given by
(13) |
where is a hyperparameter that controls the relative weights of the losses. See Algorithm 1 for the optimization workflow. A convergence analysis of the algorithm is given in Appendix D.
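A minimal sketch of a per-trajectory temporal-difference loss for a binary (survival-type) outcome is given below. It assumes a deterministic policy; each prediction is supervised by the next prediction evaluated after replacing the treatment with the policy's choice (held fixed as a constant target), and the last prediction is supervised by the observed outcome. It also assumes that the combined objective is the outcome loss plus a weighted propensity loss. The names `td_loss`, `propensity_loss`, and `total_loss` are hypothetical, and the exact form of (13) in the paper may differ.

```python
import math

def bce(pred, target, eps=1e-7):
    """Binary cross-entropy between a predicted probability and a (possibly soft) target."""
    p = min(max(pred, eps), 1.0 - eps)  # clip to avoid log(0)
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))

def td_loss(q_pred, q_next_cf, y_final):
    """Temporal-difference loss for one trajectory (hypothetical helper).

    q_pred[t]    -- predicted mean outcome at step t given the observed history
    q_next_cf[t] -- prediction at step t + 1 after setting the step-t treatment to the
                    policy's choice; treated as a constant target (no gradient)
    y_final      -- observed binary outcome, used as the target of the last prediction
    """
    if len(q_next_cf) != len(q_pred) - 1:
        raise ValueError("need one bootstrapped target per non-final prediction")
    targets = list(q_next_cf) + [y_final]
    return sum(bce(q, tgt) for q, tgt in zip(q_pred, targets))

def propensity_loss(g_pred, a_obs):
    """Log loss of predicted treatment probabilities against observed binary treatments."""
    return sum(bce(g, a) for g, a in zip(g_pred, a_obs))

def total_loss(q_pred, q_next_cf, y_final, g_pred, a_obs, weight=1.0):
    """Assumed combined objective: outcome TD loss plus a weighted propensity loss."""
    return td_loss(q_pred, q_next_cf, y_final) + weight * propensity_loss(g_pred, a_obs)
```

In a deep-learning framework, the bootstrapped targets `q_next_cf` would be detached from the computation graph so that gradients flow only through `q_pred`.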
For the estimation of and , we propose a unified model architecture to simultaneously optimize deep neural networks and in an efficient, non-sequential manner by adapting a decoder-only Transformer (Vaswani et al., 2017; Brown et al., 2020) to longitudinal data with heterogeneous tokens. An overview of the model architecture is given in Figure 1. For each sampled sequence in the training set, we feed each token in the sequence to a linear embedding layer according to its variable type. In the case of Figure 1, there are four different embedding layers , , , and . Each embedding layer has the same number of output dimensions. Then, each embedding is integrated with its positional encoding and type encoding , , , and that represent its timestamp and variable type information through an aggregation function (e.g. sum, concat):
(14) |
for , where we used concatenation as aggr in the experiments in this work. Note that we include the type embedding because need not be type-specific linear layers. For a more efficient and parallelizable embedding operation, we can pad each variable to the same number of dimensions before feeding it into the same embedding layer . The embedded sequence is then fed into the transformer, which produces and through output heads and at each position corresponding to token types and , respectively:
(15) |
(16) |
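The heterogeneous embedding step can be sketched as follows, assuming zero padding to a common width, a shared linear layer, a one-hot type encoding, and sinusoidal positional encodings, combined by concatenation as in the experiments. The variable types `W`, `L`, `A`, `Y`, the helper names, and the sinusoidal choice are assumptions for illustration; a real implementation would be a batched tensor operation with learned weights.

```python
import math

VAR_TYPES = ("W", "L", "A", "Y")  # hypothetical token types: baseline, covariate, treatment, outcome

def pad(values, width):
    """Zero-pad one variable's values so that every token has the same input width."""
    if len(values) > width:
        raise ValueError("variable is wider than the padded width")
    return list(values) + [0.0] * (width - len(values))

def linear(x, weight, bias):
    """Shared linear embedding; `weight` is a list of rows (out_dim x in_dim)."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]

def type_encoding(var_type):
    """One-hot encoding of the token's variable type."""
    return [1.0 if t == var_type else 0.0 for t in VAR_TYPES]

def positional_encoding(position, dim):
    """Sinusoidal encoding of the timestamp (Vaswani et al., 2017)."""
    enc = []
    for i in range(dim):
        angle = position / (10000 ** (2 * (i // 2) / dim))
        enc.append(math.sin(angle) if i % 2 == 0 else math.cos(angle))
    return enc

def embed_token(values, var_type, position, weight, bias, width, pos_dim):
    """aggr = concat: [shared value embedding | type encoding | positional encoding]."""
    x = pad(values, width)
    return linear(x, weight, bias) + type_encoding(var_type) + positional_encoding(position, pos_dim)

# Usage: tokens of different widths map to vectors of the same length.
W_EMB = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]  # 2 x 3 weights, illustrative only
B_EMB = [0.0, 0.0]
treatment_token = embed_token([0.5], "A", 0, W_EMB, B_EMB, width=3, pos_dim=2)
covariate_token = embed_token([0.1, 0.2, 0.3], "L", 1, W_EMB, B_EMB, width=3, pos_dim=2)
```

Because every token shares one padded input width and one embedding layer, a whole sequence of mixed variable types can be embedded in a single parallel operation.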
In practice, we can use a joint output layer for and for more efficient and parallelizable output generation, where the number of output dimensions is the sum of the number of dimensions for treatment and for outcome . We then compute softmax probabilities, masking out the last dimensions for and the first dimensions for .
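The masked softmax over a joint output layer can be sketched as below. The sketch assumes that treatment and outcome are categorical with `d_a` and `d_y` classes, so that one logit vector of length `d_a + d_y` serves both heads; the names `masked_softmax` and `read_joint_output` are hypothetical.

```python
import math

def masked_softmax(logits, keep):
    """Softmax over entries with keep[i] True; masked entries receive probability 0."""
    m = max(z for z, k in zip(logits, keep) if k)  # subtract the max for stability
    exps = [math.exp(z - m) if k else 0.0 for z, k in zip(logits, keep)]
    total = sum(exps)
    return [e / total for e in exps]

def read_joint_output(logits, d_a, d_y, head):
    """Split one joint output vector of size d_a + d_y into treatment or outcome probabilities."""
    if len(logits) != d_a + d_y:
        raise ValueError("joint output must have d_a + d_y entries")
    if head == "treatment":  # mask out the last d_y (outcome) dimensions
        return masked_softmax(logits, [True] * d_a + [False] * d_y)[:d_a]
    if head == "outcome":    # mask out the first d_a (treatment) dimensions
        return masked_softmax(logits, [False] * d_a + [True] * d_y)[d_a:]
    raise ValueError(f"unknown head: {head}")
```

For example, `read_joint_output([0.0, 0.0, 5.0, 5.0], 2, 2, "treatment")` returns `[0.5, 0.5]`: the large outcome logits do not affect the treatment probabilities because they are masked out.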
Unlike prior work (Melnychuk et al., 2022), our proposed architecture does not concatenate variables at the same timestamp or decode outputs sequentially after the transformer embedding block. This 1) allows us to handle different types of and different numbers of variables at different timestamps (e.g., starting from , ending at , while and are missing), and 2) makes the model fully parallelizable when we use padding instead of a learnable linear mapping for the embedding layer and use the joint output layer .
4.2 Targeted Minimum Loss-based Estimation
Efficient Influence Function
Since our target parameter is the counterfactual mean outcome at the final , the relevant parts of are for .
Theorem 4.1.
In our counterfactual mean case, the efficient influence function is given by
(17) |
where and .
This result is given in van der Laan & Gruber (2012).
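Because equation (17) is not reproduced here, the sketch below assumes the standard form of the efficient influence function for a deterministic dynamic policy from van der Laan & Gruber (2012), $D^* = \sum_{t=0}^{K} \big(\prod_{s \le t} I(A_s = d_s)/g_s\big)(\bar Q_{t+1} - \bar Q_t) + \bar Q_0 - \psi$ with $\bar Q_{K+1} = Y$, and computes its value for one subject. The function name and argument layout are hypothetical.

```python
def eif_value(q, y, follows, g_policy, psi):
    """One subject's efficient influence function value for a deterministic dynamic policy,
    assuming the standard LTMLE form (van der Laan & Gruber, 2012).

    q[t]        -- outcome regression Q_t at the subject's history, t = 0..K
    y           -- observed outcome, playing the role of Q_{K+1}
    follows[t]  -- True if the observed treatment A_t equals the policy's choice
    g_policy[t] -- estimated probability of the policy's treatment choice at step t
    psi         -- current estimate of the target parameter
    """
    q_next = list(q[1:]) + [y]
    value = q[0] - psi
    weight = 1.0  # cumulative inverse probability of following the policy
    for t in range(len(q)):
        if not follows[t]:
            break  # the indicator product is zero from this step onward
        weight /= g_policy[t]
        value += weight * (q_next[t] - q[t])
    return value
```

For a single time point with `q = [0.4]`, `y = 1`, a followed treatment with probability 0.5, and `psi = 0.4`, the value is 2 × (1 − 0.4) = 1.2; a subject who deviates immediately contributes only `q[0] - psi`.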
4.2.1 Temporal Difference Targeting
Submodel
We update the initial estimate for to such that . We realize this by fluctuating along a one-dimensional submodel through the initial fit given by, , where
(18) |
with a common fluctuation parameter across . If the outcome is survival, then we automatically have . In a general longitudinal setting for bounded ’s, we can re-scale both and to and use the same one-dimensional submodel.
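The sketch below assumes that the submodel (18) takes the usual logistic form $\operatorname{logit} \bar Q^{\varepsilon}_t = \operatorname{logit} \bar Q_t + \varepsilon$ with one $\varepsilon$ shared across time points. The text does not confirm whether (18) includes an additional covariate. The sketch also shows the re-scaling of a bounded outcome to $[0, 1]$ mentioned above. All names are hypothetical.

```python
import math

def logit(p):
    return math.log(p / (1.0 - p))

def expit(x):
    return 1.0 / (1.0 + math.exp(-x))

def fluctuate(q, eps):
    """Intercept-only logistic fluctuation; q must lie strictly inside (0, 1)."""
    return expit(logit(q) + eps)

def rescale(v, lower, upper):
    """Map a value known to lie in [lower, upper] onto [0, 1]."""
    return (v - lower) / (upper - lower)

def unscale(u, lower, upper):
    """Map a value in [0, 1] back to the original scale [lower, upper]."""
    return lower + u * (upper - lower)

# Usage: fluctuate a bounded continuous prediction (outcome bounds 10 to 20, illustrative).
u = rescale(15.0, 10.0, 20.0)                       # 0.5 on the unit scale
shifted = unscale(fluctuate(u, 0.4), 10.0, 20.0)     # stays within (10, 20) for any eps
```

The logistic form is what lets the fluctuated estimate respect the outcome bounds, the plug-in property highlighted in the Introduction. Predictions exactly at 0 or 1 would need truncation before taking the logit.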
Partial Loss function
We search for the optimal fluctuation with respect to the partial loss function
(19) |
where and , such that satisfies the following theorem:
Theorem 4.2.
For any , we have
(20) |
See Section B for the proof.
Corollary 4.3.
Suppose that we found an satisfying
(21) |
then solves the efficient influence function estimating equation, that is, the empirical mean of the efficient influence function is zero.
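The following sketch finds a common fluctuation $\varepsilon$ by solving the weighted temporal-difference score $\sum_i \sum_t w_{i,t}(\bar Q^{\varepsilon}_{i,t+1} - \bar Q^{\varepsilon}_{i,t}) = 0$, with $\bar Q^{\varepsilon}_{i,K+1} = Y_i$ and $w_{i,t}$ the cumulative inverse probability of following the policy. If $\psi$ is then set to the empirical mean of $\bar Q^{\varepsilon}_0$, the empirical mean of the standard-form efficient influence function is zero. We assume an intercept-only logistic fluctuation and use bisection as a simple root finder. The sketch illustrates the end condition of Corollary 4.3; it is not the paper's Algorithm 2, and all names and numbers are hypothetical.

```python
import math

def expit(x):
    return 1.0 / (1.0 + math.exp(-x))

def logit(p):
    return math.log(p / (1.0 - p))

def td_score(eps, q, y, w):
    """sum_i sum_t w[i][t] * (Q_{i,t+1}(eps) - Q_{i,t}(eps)), every Q fluctuated by the
    same eps and Q_{i,K+1} = y[i]."""
    total = 0.0
    for q_i, y_i, w_i in zip(q, y, w):
        q_eps = [expit(logit(v) + eps) for v in q_i]
        q_next = q_eps[1:] + [y_i]
        total += sum(wt * (b - a) for wt, a, b in zip(w_i, q_eps, q_next))
    return total

def solve_common_eps(q, y, w, lo=-20.0, hi=20.0, tol=1e-10, max_iter=200):
    """Bisection on [lo, hi]; requires the score to change sign on the bracket."""
    f_lo = td_score(lo, q, y, w)
    if f_lo * td_score(hi, q, y, w) > 0:
        raise ValueError("score does not change sign on the bracket")
    for _ in range(max_iter):
        mid = 0.5 * (lo + hi)
        f_mid = td_score(mid, q, y, w)
        if f_lo * f_mid <= 0:
            hi = mid                # root lies in [lo, mid]
        else:
            lo, f_lo = mid, f_mid   # root lies in [mid, hi]
        if hi - lo < tol:
            break
    return 0.5 * (lo + hi)

# Usage with two subjects and two time points (illustrative numbers).
q_init = [[0.3, 0.5], [0.6, 0.4]]     # initial Q_0, Q_1 per subject
y_obs = [1.0, 0.0]                    # observed outcomes
weights = [[2.0, 4.0], [1.5, 2.0]]    # cumulative inverse probabilities of following the policy
eps_hat = solve_common_eps(q_init, y_obs, weights)
psi_hat = sum(expit(logit(q_i[0]) + eps_hat) for q_i in q_init) / len(q_init)
```

Bisection only needs a sign change. For a binary outcome, very negative $\varepsilon$ pushes every fluctuated $\bar Q$ toward 0 and very positive $\varepsilon$ pushes it toward 1, so the score usually changes sign whenever some weighted subject has $Y = 1$ and another has $Y = 0$ at the last step.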
Convergence of Algorithm 2
Examining for different and ’s as a function of suggests that these curves are bell-shaped, concentrated at different ’s, and differ in spread. Summing across and across ’s therefore yields a function of that fluctuates substantially, and we expect a local minimum and a local maximum in the neighborhood of . We therefore expect the algorithm to converge in practice, and we did not encounter any convergence issues in our simulations.
Comparison to LTMLE
In LTMLE, one only needs a good estimate of and then performs backward sequential regression and targeting as described in van der Laan & Gruber (2012). However, errors in the estimation of can propagate as we move backward to obtain . In contrast, after our initial transformer step, we have good initial estimates for all . Instead of relying only on a good estimate of , our algorithm uses all of them and performs the targeting across with a fluctuation at each level. We are thus able to pool information across time in the targeting step.
4.2.2 Sequential Targeting
Alternatively, one could apply a sequential targeting procedure that is very similar to LTMLE but with given initials generated from the transformer step.
Submodel
We fluctuate each component of the initial fit along a submodel as
(22) |
Loss function
Starting from , and given that we have found , we search, among individuals whose , for the empirical loss minimizer with respect to the loss function as follows:
(23) |
where when and when . To initialize, we set .
Lemma 4.4.
Suppose that we found sequentially as described above; then solves the efficient influence function estimating equation.
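A sketch of the backward sequential targeting loop follows. It assumes intercept-only logistic fluctuations for (22), fitted by Newton's method, and a weighted loss for (23) whose weight is the cumulative inverse probability of following the policy, set to zero once a subject deviates. The targeted $\bar Q^*_{t+1}$ serves as the target at step $t$, and $Y$ serves as the target at the last step. The subject-level arrays and all names are hypothetical.

```python
import math

def expit(x):
    return 1.0 / (1.0 + math.exp(-x))

def logit(p):
    return math.log(p / (1.0 - p))

def fit_eps(q, target, w, tol=1e-12, max_iter=50):
    """Newton's method for the weighted logistic score with offset logit(q):
    sum_i w_i * (target_i - expit(logit(q_i) + eps)) = 0."""
    eps = 0.0
    for _ in range(max_iter):
        p = [expit(logit(qi) + eps) for qi in q]
        grad = sum(wi * (ti - pi) for wi, ti, pi in zip(w, target, p))
        hess = sum(wi * pi * (1.0 - pi) for wi, pi in zip(w, p))
        if hess == 0.0:
            break
        step = grad / hess
        eps += step
        if abs(step) < tol:
            break
    return eps

def sequential_targeting(q, y, w):
    """Backward targeting from t = K to t = 0 (hypothetical helper).

    q[i][t] -- initial Q_t for subject i from the transformer
    y[i]    -- observed outcome, the target at the last step
    w[i][t] -- weight in the loss at step t (e.g. cumulative inverse probability,
               zero once the subject deviates from the policy)
    """
    n, steps = len(q), len(q[0])
    q_star = [list(row) for row in q]
    target = list(y)
    for t in reversed(range(steps)):
        eps_t = fit_eps([q[i][t] for i in range(n)], target, [w[i][t] for i in range(n)])
        for i in range(n):
            q_star[i][t] = expit(logit(q[i][t]) + eps_t)
        target = [q_star[i][t] for i in range(n)]  # targeted Q_t becomes the next target
    return q_star

# Usage (illustrative numbers): three subjects, two time points.
q_init = [[0.3, 0.5], [0.6, 0.4], [0.5, 0.7]]
y_obs = [1.0, 0.0, 1.0]
weights = [[2.0, 4.0], [1.5, 2.0], [1.2, 0.0]]  # subject 3 deviates at the last step
q_targeted = sequential_targeting(q_init, y_obs, weights)
```

Each step uses the transformer's own initial `q[i][t]` as the offset rather than a fresh regression on the targeted next step. This is the distinction from LTMLE drawn in the comparison below.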
Comparison to LTMLE
While errors can still propagate as we move back in time, they propagate only through the targeting steps, whereas in LTMLE they can also propagate through the regressions. At each time step , LTMLE needs to first regress on to get an estimate and then perform the targeting through the submodel in (22). In contrast, we use only the initial estimate from our transformer fit, which does not depend on .
Why not target through an additional loss function
As in DeepACE, the targeting can be performed by introducing additional loss components to further train the transformer built in the first step. The derivative of this additional loss function equals the efficient influence function. However, we find that the penalty factor in front of this loss function is hard to tune: in nearly all cases it is hard to guarantee that the EIF equation is solved, and most of the time the additional loss degrades our initial fits, as shown in Section 5.1.
Model | Bias | | | RMSE | | | Coverage | | | Mean | | |
---|---|---|---|---|---|---|---|---|---|---|---|---|
LTMLE (GLM) | 0.0230 | 0.0766 | 0.1344 | 0.0265 | 0.0796 | 0.1381 | 1.00 | 1.00 | 1.00 | 0.43 | 0.69 | 0.76 |
LTMLE (SL) | 0.0144 | 0.0297 | 0.0477 | 0.0185 | 0.0344 | 0.0545 | 1.00 | 1.00 | 1.00 | 0.31 | 0.40 | 0.45 |
DeepACE | -0.0704 | -0.1491 | -0.2396 | 0.0948 | 0.1601 | 0.2453 | 1.00 | 1.00 | 1.00 | 0.74 | 0.69 | 0.57 |
Deep LTMLE | 0.0182 | 0.0304 | 0.0499 | 0.0264 | 0.0342 | 0.0532 | 1.00 | 0.94 | 0.71 | 0.17 | 0.09 | 0.06 |
Deep LTMLE | 0.0158 | 0.0286 | 0.0548 | 0.0188 | 0.0314 | 0.0589 | 1.00 | 0.93 | 0.73 | 0.16 | 0.09 | 0.06 |
Deep LTMLE | 0.0143 | 0.0305 | 0.0471 | 0.0204 | 0.0333 | 0.0509 | 1.00 | 0.93 | 0.76 | 0.16 | 0.08 | 0.06 |
5 Experiments
We conducted two experiments. In the first, we compared the bias, root-mean-squared error (RMSE), and coverage probability of our estimator with those of existing estimators over 100 repeated estimations, for both continuous and survival outcomes. The second experiment is an application of our proposed method to real-world data.
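The three evaluation metrics can be computed from repeated estimations as sketched below: bias is the mean error, RMSE is the root of the mean squared error, and coverage is the fraction of confidence intervals that contain the true value. The function name and the four illustrative repetitions are hypothetical; the experiments themselves use 100 repetitions.

```python
import math

def summarize(estimates, intervals, truth):
    """Bias, RMSE, and coverage probability of repeated estimates of a known truth."""
    n = len(estimates)
    bias = sum(e - truth for e in estimates) / n
    rmse = math.sqrt(sum((e - truth) ** 2 for e in estimates) / n)
    coverage = sum(lo <= truth <= hi for lo, hi in intervals) / n
    return bias, rmse, coverage

# Usage with four illustrative repetitions.
bias, rmse, coverage = summarize(
    [0.28, 0.32, 0.35, 0.25],
    [(0.20, 0.36), (0.25, 0.40), (0.31, 0.41), (0.20, 0.29)],
    truth=0.30,
)
```

Here the errors cancel (bias near 0), the RMSE is about 0.038, and two of the four intervals contain the truth (coverage 0.5).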
5.1 Synthetic Data with Continuous Outcome
First, we start with a very simple data-generating process with a continuous outcome, , and . The data-generating process is described in Section F.1. After fitting DeepACE, we additionally applied our targeting procedures to its fit.
The results are shown in Figure 2. The initial fits of Deep LTMLE and DeepACE had comparable bias. Even with its targeting loss, DeepACE failed to solve the efficient influence function equation. In contrast, because our method separates the targeting step, it solved the equation completely and succeeded in correcting the bias.
5.2 Synthetic Data with Survival Outcome
Model | Bias | | | RMSE | | | Coverage | | | Mean | | |
---|---|---|---|---|---|---|---|---|---|---|---|---|
LTMLE (SL) | 0.0075 | 0.0341 | 0.0574 | 0.0138 | 0.0491 | 0.0786 | 0.70 | 0.45 | 0.25 | 0.09 | 0.12 | 0.14 |
DeepACE | -0.0174 | -0.0434 | -0.0770 | 0.0788 | 0.1154 | 0.1341 | 1.00 | 1.00 | 1.00 | 0.67 | 0.78 | 0.86 |
Deep LTMLE | -0.0002 | 0.0108 | 0.0041 | 0.0162 | 0.0720 | 0.0772 | 1.00 | 0.95 | 1.00 | 0.18 | 0.27 | 0.32 |
Deep LTMLE | 0.0058 | 0.0429 | 0.0709 | 0.0205 | 0.0724 | 0.0968 | 0.95 | 0.90 | 0.95 | 0.18 | 0.26 | 0.31 |
Next, we evaluated Deep LTMLE under a highly complex data-generating process with survival outcomes, five-dimensional time-dependent covariates, non-Markovian dependencies, , and , imitating the setups from previous studies (Bica et al., 2020; Frauen et al., 2023). See Section F.3 for details.
Results are presented in Table 1. Deep LTMLE achieves, on average, a lower RMSE than the other methods, particularly in scenarios with larger , indicating its robustness in complex, realistic scenarios without Markovian dependencies. The benefits of our targeting procedures are clear for . For , we still see reductions in bias and RMSE when the temporal-difference targeting is applied. While Deep LTMLE’s coverage probability diminished at , the confidence intervals produced by LTMLE and DeepACE were notably over-conservative, with large estimated standard errors.
The pronounced bias of DeepACE can likely be attributed to three factors. First, DeepACE’s use of the squared-error loss for the outcome is known to induce greater bias than the logistic loss used in our approach when outcomes are sparse, which is common in survival analysis (Gruber & van der Laan, 2010). Second, DeepACE failed to solve the efficient influence function equation. Third, DeepACE does not account for the degeneracy of the survival outcome process.
Simple Synthetic Data with Survival Outcome
We also conducted an experiment with very simple synthetic survival data with one-dimensional time-dependent covariates, , and . Although LTMLE with GLM is expected to perform strongly in this experiment, Deep LTMLE remains highly competitive, matching LTMLE’s performance (Section G).
5.3 Semi-Synthetic Data
To evaluate the performance of the proposed methods, we generated realistic data from the Circulatory Risk in Communities Study (CIRCS) (Yamagishi et al., 2019), an ongoing long-term cardiovascular epidemiological cohort study spanning over half a century. See Section G.1 for details.
Table 2 shows the results with semi-synthetic data with unmeasured confounding, which reflects a real-world setting. Deep LTMLE performed best in terms of bias for all time horizons. Furthermore, as the time horizon increases from 10 to 30, LTMLE’s coverage probability drops to as low as 0.3, whereas Deep LTMLE maintains nominal coverage even in the longest time-horizon setting.
5.4 Real World Data
We applied Deep LTMLE to real-world data from CIRCS. We estimated the counterfactual mean outcomes after 30 years of sustained management under the standard blood pressure management strategy, which keeps systolic blood pressure (SBP) below 140 mmHg, and under the intensive strategy, which keeps SBP below 120 mmHg.
In real-world applications, we often encounter the practical problem of censoring, that is, loss to follow-up for various reasons. Our model can easily be generalized to this setting by adding censoring nodes; details are described in Appendix H.
The results are shown in Figure 3. The average treatment effect (ATE) of the intensive management strategy over the standard management strategy first increased, peaking 20 years after baseline, and then decreased with some fluctuation. The direction and trend of the ATE are consistent with the difference in empirical mean cumulative outcomes between the two groups that followed the two strategies.
5.5 Computation Details
DeepACE and Deep LTMLE were run on a GPU (Tesla T4) with 16 GB of memory, and LTMLE on a CPU (Intel Xeon Skylake 6230 @ 2.1 GHz) with 40 cores and 96 GB of memory. We used the R package ltmle with GLM and with a super learner (SL) library consisting of GLM, multivariate adaptive regression splines (earth package), and xgboost for the simple synthetic data and the real-world data (Lendle et al., 2017; Polley et al., 2021; Milborrow, 2023; Chen et al., 2022). Confidence intervals for LTMLE were constructed from its estimate of the efficient influence function.
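Confidence intervals based on an estimated efficient influence function are commonly Wald intervals, $\hat\psi \pm z_{0.975}\,\widehat{\mathrm{sd}}(D^*)/\sqrt{n}$. The sketch below assumes this normal-approximation form and uses the sample standard deviation of the per-subject influence-function values. The function name and example values are hypothetical.

```python
import math
from statistics import NormalDist, stdev

def eif_wald_ci(psi_hat, eif_values, level=0.95):
    """Wald interval psi_hat +/- z * sd(EIF) / sqrt(n) from per-subject influence values."""
    n = len(eif_values)
    z = NormalDist().inv_cdf(0.5 + level / 2)   # about 1.96 for a 95% interval
    se = stdev(eif_values) / math.sqrt(n)       # sample standard deviation (n - 1 divisor)
    return psi_hat - z * se, psi_hat + z * se, se

# Usage with illustrative influence-function values.
lower, upper, se = eif_wald_ci(0.30, [-1.0, 1.0, -1.0, 1.0])
```

A large spread in the estimated influence-function values, for example from extreme inverse-probability weights over long horizons, widens the interval directly through `se`.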
Model | Time, sec | | |
---|---|---|---|
LTMLE (SL) | 271 | 958 | 2122 |
DeepACE | 53 | 54 | 133 |
Deep LTMLE | 38 | 39 | 116 |
As shown in Table 3, Deep LTMLE leverages GPU acceleration to run substantially faster than LTMLE (e.g., 116 s versus 2122 s in the most demanding setting of Table 3), a considerable computational benefit for analyses involving long time horizons and high-dimensional time-dependent covariates.
6 Limitations
Our method relies on the sequential randomization and positivity assumptions on the intervention mechanism to identify the counterfactual outcome from observational data. However, to our surprise, the semi-synthetic simulations showed that our method remained robust, and still provided valid inference, even with unmeasured confounding that violates the sequential randomization assumption. Furthermore, our proposed model does not currently address several complexities often found in real-world data, such as visit processes, competing risks, and continuous time. These challenges will be the focus of our future research.
7 Conclusion
In this paper, we propose a variant of LTMLE that leverages the sequential learning capabilities of transformers. This approach fits all components of LTMLE simultaneously, allowing us to target the mean survival under dynamic interventions directly by weighting the loss function with cumulative inverse probabilities of intervention. The proposed method performs competitively with asymptotically efficient estimators in low-dimensional settings and exceeds the performance of existing models in high-dimensional scenarios. Our computational results suggest that the model scales to larger datasets with longer time horizons. We applied our method to real-world data and demonstrated causal inference on the effect of sustained blood pressure management strategies on total mortality.
Acknowledgement
This research is funded by NIH and Berkeley School of Public Health, Interdisciplinary Collaborative Research Grant. TS is supported by Fulbright scholarship program. The authors thank Dr. Ahmed Alaa at University of California, San Francisco and Berkeley for valuable discussions. The authors acknowledge the CIRCS investigators team for providing the real world data for experiments; Dr. Akihiko Kitamura at Yao City, Dr. Masahiko Kiyama at Osaka Center for Prevention of Cardiovascular Diseases, Dr. Takeo Okada at Osaka Center for Prevention of Cardiovascular Diseases, Dr. Yuji Shimizu at Osaka Center for Prevention of Cardiovascular Diseases, Dr. Hironori Imano at Kinki University, Dr. Tetsuya Ohira at Fukushima Prefecture Medical University, Dr. Kazumasa Yamagishi at Tsukuba University, and Dr. Isao Muraki at Osaka University.
References
- Bang & Robins (2005) Bang, H. and Robins, J. M. Doubly robust estimation in missing data and causal inference models. Biometrics, 61(4):962–973, 2005.
- Bica et al. (2020) Bica, I., Alaa, A. M., Jordon, J., and van der Schaar, M. Estimating counterfactual treatment outcomes over time through adversarially balanced representations. In International Conference on Learning Representations, 2020.
- Bickel et al. (1993) Bickel, P., Klaassen, C., Ritov, Y., and Wellner, J. Efficient and Adaptive Estimation for Semiparametric Models. Johns Hopkins Series in the Mathematical Sciences. Springer New York, 1993. ISBN 978-0-387-98473-5.
- Brown et al. (2020) Brown, T., Mann, B., Ryder, N., Subbiah, M., Kaplan, J. D., Dhariwal, P., Neelakantan, A., Shyam, P., Sastry, G., Askell, A., et al. Language models are few-shot learners. Advances in Neural Information Processing Systems, 33:1877–1901, 2020.
- Chen et al. (2022) Chen, T., He, T., Benesty, M., Khotilovich, V., Tang, Y., Cho, H., Chen, K., Mitchell, R., Cano, I., Zhou, T., Li, M., Xie, J., Lin, M., Geng, Y., Li, Y., and Yuan, J. xgboost: Extreme Gradient Boosting, 2022. URL https://CRAN.R-project.org/package=xgboost. R package version 1.7.6.1.
- Chernozhukov et al. (2018) Chernozhukov, V., Chetverikov, D., Demirer, M., Duflo, E., Hansen, C., Newey, W., and Robins, J. Double/debiased machine learning for treatment and structural parameters. The Econometrics Journal, 21(1):C1–C68, 2018.
- Chernozhukov et al. (2022) Chernozhukov, V., Escanciano, J. C., Ichimura, H., Newey, W. K., and Robins, J. M. Locally robust semiparametric estimation. Econometrica, 90(4):1501–1535, 2022.
- Dang et al. (2023) Dang, L. E., Gruber, S., Lee, H., Dahabreh, I. J., Stuart, E. A., Williamson, B. D., Wyss, R., Díaz, I., Ghosh, D., Kıcıman, E., Alemayehu, D., Hoffman, K. L., Vossen, C. Y., Huml, R. A., Ravn, H., Kvist, K., Pratley, R., Shih, M.-C., Pennello, G., Martin, D., Waddy, S. P., Barr, C. E., Akacha, M., Buse, J. B., van der Laan, M., and Petersen, M. A causal roadmap for generating high-quality real-world evidence. Journal of Clinical and Translational Science, 7(1):e212, 2023.
- Farajtabar et al. (2018) Farajtabar, M., Chow, Y., and Ghavamzadeh, M. More robust doubly robust off-policy evaluation. In Proceedings of the 35th International Conference on Machine Learning, volume 80 of Proceedings of Machine Learning Research, pp. 1447–1456. PMLR, 10–15 Jul 2018.
- Frauen et al. (2023) Frauen, D., Hatt, T., Melnychuk, V., and Feuerriegel, S. Estimating average causal effects from patient trajectories. Proceedings of the AAAI Conference on Artificial Intelligence, 37(6):7586–7594, 2023.
- Gruber & van der Laan (2010) Gruber, S. and van der Laan, M. J. A targeted maximum likelihood estimator of a causal effect on a bounded continuous outcome. The International Journal of Biostatistics, 6(1):Article 26, 2010. ISSN 1557-4679. doi: 10.2202/1557-4679.1260.
- Gruber & van der Laan (2012) Gruber, S. and van der Laan, M. J. Targeted minimum loss based estimation of a causal effect on an outcome with known conditional bounds. The International Journal of Biostatistics, 8(1):21–21, 2012. ISSN 1557-4679.
- Jiang & Li (2016) Jiang, N. and Li, L. Doubly robust off-policy value evaluation for reinforcement learning. In Proceedings of The 33rd International Conference on Machine Learning, volume 48 of Proceedings of Machine Learning Research, pp. 652–661, New York, New York, USA, 20–22 Jun 2016. PMLR.
- Kallus & Uehara (2020) Kallus, N. and Uehara, M. Double reinforcement learning for efficient and robust off-policy evaluation. In Proceedings of the 37th International Conference on Machine Learning, volume 119 of Proceedings of Machine Learning Research, pp. 5078–5088. PMLR, 13–18 Jul 2020.
- Kennedy (2022) Kennedy, E. H. Semiparametric doubly robust targeted double machine learning: A review. arXiv preprint arXiv:2203.06469, 2022.
- Klaassen (1987) Klaassen, C. A. J. Consistent estimation of the influence function of locally asymptotically linear estimators. The Annals of Statistics, 15(4):1548–1562, 1987.
- Lendle et al. (2017) Lendle, S. D., Schwab, J., Petersen, M. L., and van der Laan, M. J. ltmle: An R package implementing targeted minimum loss-based estimation for longitudinal data. Journal of Statistical Software, 81(1):1–21, 2017. doi: 10.18637/jss.v081.i01.
- Levine et al. (2020) Levine, S., Kumar, A., Tucker, G., and Fu, J. Offline Reinforcement Learning: Tutorial, Review, and Perspectives on Open Problems. arXiv preprint arXiv:2005.01643, 2020.
- Li et al. (2021) Li, R., Hu, S., Lu, M., Utsumi, Y., Chakraborty, P., Sow, D. M., Madan, P., Li, J., Ghalwash, M., Shahn, Z., and Lehman, L.-w. G-net: A recurrent network approach to G-computation for counterfactual prediction under a dynamic treatment regime. In Proceedings of Machine Learning for Health, volume 158 of Proceedings of Machine Learning Research, pp. 282–299. PMLR, 2021.
- Melnychuk et al. (2022) Melnychuk, V., Frauen, D., and Feuerriegel, S. Causal transformer for estimating counterfactual outcomes. In Chaudhuri, K., Jegelka, S., Song, L., Szepesvari, C., Niu, G., and Sabato, S. (eds.), Proceedings of the 39th International Conference on Machine Learning, volume 162 of Proceedings of Machine Learning Research, pp. 15293–15329. PMLR, 2022.
- Milborrow (2023) Milborrow, S. earth: Multivariate Adaptive Regression Splines, 2023. URL https://CRAN.R-project.org/package=earth. R package version 5.3.2.
- Mnih et al. (2013) Mnih, V., Kavukcuoglu, K., Silver, D., Graves, A., Antonoglou, I., Wierstra, D., and Riedmiller, M. Playing atari with deep reinforcement learning. arXiv preprint arXiv:1312.5602, 2013.
- Narita et al. (2021) Narita, Y., Yasui, S., and Yata, K. Debiased off-policy evaluation for recommendation systems. In Proceedings of the 15th ACM Conference on Recommender Systems, RecSys ’21, pp. 372–379, New York, NY, USA, 2021. Association for Computing Machinery. ISBN 9781450384582.
- Petersen & van der Laan (2014) Petersen, M. L. and van der Laan, M. J. Causal models and learning from data: Integrating causal modeling and statistical estimation. Epidemiology (Cambridge, Mass.), 25(3):418–426, 2014.
- Polley et al. (2021) Polley, E., LeDell, E., Kennedy, C., and van der Laan, M. SuperLearner: Super Learner Prediction, 2021. URL https://CRAN.R-project.org/package=SuperLearner. R package version 2.0-28.1.
- Robins (1986) Robins, J. A new approach to causal inference in mortality studies with a sustained exposure period—application to control of the healthy worker survivor effect. Mathematical modelling, 7(9-12):1393–1512, 1986.
- Robins et al. (1994) Robins, J. M., Rotnitzky, A., and Zhao, L. P. Estimation of Regression Coefficients When Some Regressors Are Not Always Observed. Journal of the American Statistical Association, 89(427):846–866, 1994.
- Salerno & Li (2023) Salerno, S. and Li, Y. High-dimensional survival analysis: Methods and applications. Annual Review of Statistics and Its Application, 10:25–49, 2023.
- Sutton (1988) Sutton, R. S. Learning to predict by the methods of temporal differences. Machine Learning, 3:9–44, 1988.
- van der Laan & Rubin (2006) van der Laan, M. and Rubin, D. Targeted Maximum Likelihood Learning. The International Journal of Biostatistics, 2(1), 2006.
- van der Laan & Gruber (2012) van der Laan, M. J. and Gruber, S. Targeted Minimum Loss Based Estimation of Causal Effects of Multiple Time Point Interventions. The International Journal of Biostatistics, 8(1), 2012.
- van der Laan & Robins (2003) van der Laan, M. J. and Robins, J. Unified Methods for Censored Longitudinal Data and Causality. Springer Series in Statistics. Springer New York, 2003. ISBN 978-0-387-21700-0.
- van der Laan & Rose (2011) van der Laan, M. J. and Rose, S. Targeted Learning: Causal Inference for Observational and Experimental Data. Springer Series in Statistics. Springer, 2011. ISBN 978-1-4419-9781-4, 978-1-4419-9782-1.
- van der Laan & Rose (2018) van der Laan, M. J. and Rose, S. Targeted Learning in Data Science: Causal Inference for Complex Longitudinal Studies. Springer Series in Statistics. Springer International Publishing, 2018.
- van der Laan et al. (2007) van der Laan, M. J., Polley, E. C., and Hubbard, A. E. Super learner. Statistical Applications in Genetics and Molecular Biology, 6(1):1309–1309, 2007. ISSN 1544-6115.
- Vaswani et al. (2017) Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, L., and Polosukhin, I. Attention is all you need. In Advances in Neural Information Processing Systems, volume 30. Curran Associates, Inc., 2017.
- Wyss et al. (2022) Wyss, R., Yanover, C., El‐Hay, T., Bennett, D., Platt, R. W., Zullo, A. R., Sari, G., Wen, X., Ye, Y., Yuan, H., Gokhale, M., Patorno, E., and Lin, K. J. Machine learning for improving high‐dimensional proxy confounder adjustment in healthcare database studies: An overview of the current literature. Pharmacoepidemiology and Drug Safety, 31(9):932–943, 2022. ISSN 1053-8569.
- Yamagishi et al. (2019) Yamagishi, K., Muraki, I., Kubota, Y., Hayama-Terada, M., Imano, H., Cui, R., Umesawa, M., Shimizu, Y., Sankai, T., Okada, T., Sato, S., Kitamura, A., Kiyama, M., and Iso, H. The Circulatory Risk in Communities Study (CIRCS): A Long-Term Epidemiological Study for Lifestyle-Related Disease Among Japanese Men and Women Living in Communities. Journal of Epidemiology, 29(3):83–91, 2019.
Appendix A Notation
Here we list the notation used in this article.
Observed variables | |
Maximum length of time-horizon | |
Stopping time | |
Baseline covariates | |
Time-dependent covariates (states) | |
Time-dependent treatments (controls) | |
Outcomes. In survival case, binary failure indicator defined as | |
Outcome at the end of the trajectory: | |
Parent nodes of . For example, | |
The true distribution of the observed variable | |
Estimator of | |
Propensity scores with | |
User-specified treatment policies | |
Target functional | |
True parameter | |
Estimator of | |
Estimator of the standard error of the estimator | |
State-action value functions | |
Value functions | |
for and | |
Propensity scores with | |
Clever covariates (importance weights) | |
Efficient influence function of : | |
Local least favorable submodel | |
Local least favorable submodel | |
for and | |
Loss function for temporal difference learning | |
Loss function for targeting | |
Weight for the propensity loss (hyperparameter) | |
Mean of a function under the distribution : | |
Embedding of a node | |
Type embedding of a node | |
Positional encoding at time | |
Production function of a node | |
Appendix B Proof
Appendix C Review of TMLE
C.1 Structural Causal Model
We assume that each node depends on all previous nodes in the trajectory; that is, we do not assume the Markov property. Each node $X$ is produced from its parent nodes $\mathrm{Pa}(X)$ and an independent noise random variable $U_X$ by a measurable function $f_X$: $X = f_X(\mathrm{Pa}(X), U_X)$. This production function induces the conditional distribution of $X$ given its parents by pushing forward the distribution of the noise variable: $P(X \in B \mid \mathrm{Pa}(X) = \mathrm{pa}) = P(f_X(\mathrm{pa}, U_X) \in B)$ for all measurable $B \subseteq \mathcal{X}$, where $\mathcal{X}$ is the domain of the random vector $X$. Starting from the nodes without parents, including the noise nodes and their distributions, the production functions and their causal structure, which can be described by a directed acyclic graph over the observables, generate the joint distribution of the observed random variables. For our longitudinal data, we define the propensity score of a treatment node as its conditional distribution given the patient history preceding that node. When treatment assignment is deterministic, that is, when no noise variable enters the generation of the treatment node, we use the same symbol for the production function, so that the propensity score puts probability one on the treatment value that the production function assigns to the history.
C.2 Causal Target Parameter and Identification
Our target parameter is the counterfactual mean of the final outcome under a user-specified dynamic treatment policy. It is the mean of the counterfactual outcome produced by replacing each treatment node with the value assigned by the policy in the structural causal model:
(24) |
To identify this causal target parameter from observational data, we assume positivity:
(25) |
and the sequential randomization:
(26) |
Note that consistency, which is usually stated as an assumption in the causal inference literature, is a consequence of the definition of the counterfactual outcome in our structural causal model. Under these identifiability conditions, the parameter is identified through the g-formula, that is, the mean of the final outcome under the counterfactual distribution obtained by replacing the treatment distributions with those implied by the policy:
(27) |
The problem then reduces to estimating the statistical parameter:
(28) |
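The identification argument above can be checked numerically on a toy structural causal model. The sketch below is a minimal illustration under assumed settings, not the paper's data or estimator: it uses two time points with binary nodes (the names `L0`, `A0`, `L1`, `A1`, `Y` and all coefficients are hypothetical), computes the counterfactual mean by replacing the treatment production functions with the static policy "always treat", and compares it with the g-formula evaluated on observational draws through iterated conditional means within strata (a saturated, nonparametric regression).

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def simulate(n, rng, policy=None):
    """Draw n trajectories (L0, A0, L1, A1, Y) from a toy SCM.

    Every node is a production function of its parents and an independent
    uniform noise variable. If `policy` is given, the treatment production
    functions are replaced by the constant rule A_t = policy (an intervention).
    """
    u = rng.uniform(size=(5, n))
    L0 = (u[0] < 0.5).astype(int)
    A0 = (u[1] < sigmoid(-0.5 + L0)).astype(int) if policy is None else np.full(n, policy)
    L1 = (u[2] < sigmoid(-1.0 + L0 + A0)).astype(int)
    A1 = (u[3] < sigmoid(-0.5 + L1 - 0.5 * A0)).astype(int) if policy is None else np.full(n, policy)
    Y = (u[4] < sigmoid(-1.0 + 0.5 * L0 + L1 + 0.8 * A1)).astype(int)
    return L0, A0, L1, A1, Y


def g_formula(L0, A0, L1, A1, Y, a=1):
    """Iterated conditional expectations for the static policy A0 = A1 = a."""
    # Inner regression E[Y | L0, A0 = a, L1, A1 = a], evaluated at every unit's (L0, L1)
    Q1 = np.empty(len(Y))
    for l0 in (0, 1):
        for l1 in (0, 1):
            stratum = (L0 == l0) & (L1 == l1)
            Q1[stratum] = Y[stratum & (A0 == a) & (A1 == a)].mean()
    # Outer regression E[Q1 | L0, A0 = a], averaged over the marginal distribution of L0
    psi = 0.0
    for l0 in (0, 1):
        stratum = L0 == l0
        psi += stratum.mean() * Q1[stratum & (A0 == a)].mean()
    return psi


rng = np.random.default_rng(0)
truth = simulate(400_000, rng, policy=1)[-1].mean()  # mean of Y under the intervention
estimate = g_formula(*simulate(400_000, rng))         # g-formula from observational draws
print(f"counterfactual mean {truth:.4f}, g-formula {estimate:.4f}")
```

Because treatment in this toy model depends only on the observed history and every stratum has treated units, sequential randomization and positivity hold by construction, so the two numbers agree up to Monte Carlo error.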
C.3 TMLE
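Bias correction by TMLE is based on the following first-order approximation of the target functional $\Psi$ around the true distribution $P_0$ (van der Laan & Rubin, 2006; van der Laan & Rose, 2011; Kennedy, 2022):
$$\Psi(P) - \Psi(P_0) = -P_0 D^*(P) + R_2(P, P_0), \tag{29}$$
where $D^*(P)$ is called the influence function, $Pf = \int f \, dP$ denotes the mean of a function $f$ under $P$, and $R_2(P, P_0)$ is the second-order remainder. This equation is the infinite-dimensional analogue of a first-order Taylor expansion.
Evaluating it at an initial estimator $\hat P$ and writing $P_n$ for the empirical distribution of the $n$ observations, the right-hand side can be further written as:
$$\Psi(\hat P) - \Psi(P_0) = (P_n - P_0) D^*(P_0) + (P_n - P_0)\{D^*(\hat P) - D^*(P_0)\} - P_n D^*(\hat P) + R_2(\hat P, P_0). \tag{30}$$
The second term, called the empirical process term, is $o_P(n^{-1/2})$ if $D^*(\hat P)$ belongs to a Donsker class and converges to $D^*(P_0)$ in $L^2(P_0)$. Given a good initial fit $\hat P$, these conditions are usually satisfied and, in addition, $R_2(\hat P, P_0) = o_P(n^{-1/2})$. Using further the property $P_0 D^*(P_0) = 0$ of the influence function, the expansion reduces to
$$\Psi(\hat P) - \Psi(P_0) = P_n D^*(P_0) - P_n D^*(\hat P) + o_P(n^{-1/2}). \tag{31}$$
The term $-P_n D^*(\hat P)$ is the first-order bias of the initial estimator. The idea is to find an update $\hat P^*$ in a close neighborhood of $\hat P$ that solves the empirical analogue of the equation $P_0 D^*(P_0) = 0$:
$$P_n D^*(\hat P^*) = o_P(n^{-1/2}). \tag{32}$$
Applying the same arguments to $\hat P^*$ instead of $\hat P$, we obtain
$$\Psi(\hat P^*) - \Psi(P_0) = P_n D^*(P_0) + o_P(n^{-1/2}). \tag{33}$$
Thus, our estimator $\Psi(\hat P^*)$ is a plug-in estimator that is asymptotically linear with influence function $D^*(P_0)$, so it is regular and attains the efficiency bound among regular asymptotically linear estimators. In particular, $\Psi(\hat P^*) \pm 1.96 \, [P_n \{D^*(\hat P^*)\}^2 / n]^{1/2}$ is an asymptotic 95% confidence interval.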
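To make the targeting step concrete, here is a minimal sketch for the simplest special case, a single time point and the treated mean E[Y^1] with a binary outcome; it is not the longitudinal Deep LTMLE procedure. It assumes a logistic fluctuation of a deliberately misspecified initial outcome fit along the clever covariate A/g(W), fits the fluctuation parameter by Newton-Raphson so that the empirical mean of the estimated efficient influence function is zero, as in (32), and reports the plug-in estimate with a Wald-type 95% interval based on that influence function. The data-generating process and all function names are hypothetical.

```python
import numpy as np


def expit(x):
    return 1.0 / (1.0 + np.exp(-x))


def logit(p):
    return np.log(p / (1.0 - p))


def fit_epsilon(y, offset, h, n_iter=25):
    """Newton-Raphson for eps in the logistic submodel logit Q_eps = offset + eps * h."""
    eps = 0.0
    for _ in range(n_iter):
        p = expit(offset + eps * h)
        eps += np.sum(h * (y - p)) / np.sum(h ** 2 * p * (1.0 - p))
    return eps


def tmle_treated_mean(A, Y, Q1_init, g1):
    """Targeted estimate of E[Y^1] from an initial fit Q1_init = Q(1, W) and g1 = P(A = 1 | W)."""
    H = A / g1                                  # clever covariate; zero for untreated units
    eps = fit_epsilon(Y, logit(Q1_init), H)     # untreated rows have H = 0 and do not affect eps
    Q1_star = expit(logit(Q1_init) + eps / g1)  # targeted Q(1, W): clever covariate at A = 1
    psi = Q1_star.mean()                        # plug-in estimator
    D = H * (Y - Q1_star) + Q1_star - psi       # estimated EIC (the residual term vanishes when A = 0)
    se = D.std(ddof=1) / np.sqrt(len(Y))
    return psi, se, D


rng = np.random.default_rng(1)
n = 100_000
W = rng.binomial(1, 0.5, size=n)
A = rng.binomial(1, expit(-0.8 + 1.6 * W))
Y = rng.binomial(1, expit(-1.0 + 2.0 * W + 0.5 * A))
psi_true = 0.5 * expit(-0.5) + 0.5 * expit(1.5)

g1 = np.where(W == 1, A[W == 1].mean(), A[W == 0].mean())  # saturated propensity fit
Q1_init = np.full(n, Y[A == 1].mean())                      # misspecified: ignores W
psi, se, D = tmle_treated_mean(A, Y, Q1_init, g1)
print(f"initial {Q1_init.mean():.3f}, targeted {psi:.3f} "
      f"[{psi - 1.96 * se:.3f}, {psi + 1.96 * se:.3f}], truth {psi_true:.3f}")
```

The naive initial fit is biased because treated units have a different covariate distribution; with a correctly specified propensity score, the single fluctuation step removes that bias, which is the double robustness that the expansion above formalizes.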
C.4 Efficient influence curve
The efficient influence function of our target parameter is given by (van der Laan & Gruber, 2012):
(34) | ||||
where by definition.
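For readers who want to compute an efficient influence curve and a Wald interval from fitted nuisance values, the sketch below implements the standard sequential-regression form of the EIC for a dynamic policy (van der Laan & Gruber, 2012), $D^* = \sum_{t=0}^{T-1} C_t (Q_{t+1} - Q_t) + Q_0 - \psi$ with cumulative weights $C_t = \prod_{s \le t} I(A_s = d_s(H_s)) / g_s(d_s(H_s) \mid H_s)$ and $Q_T = Y$. It is a hedged illustration: this notation and the array layout (`Q`, `follow`, `g`) are assumptions of the example, not the paper's interface.

```python
import numpy as np


def ltmle_eic(Q, follow, g):
    """Estimated EIC of psi = E[Q_0] for a dynamic policy.

    Q      : (n, T + 1) array; columns 0..T-1 are the regressions Q_t evaluated
             under the policy, column T is the observed outcome Y.
    follow : (n, T) 0/1 array, 1 if A_t equals the policy's choice d_t(H_t).
    g      : (n, T) array of propensities g_t(d_t(H_t) | H_t), bounded away from 0.
    """
    C = np.cumprod(follow / g, axis=1)   # clever covariates (cumulative importance weights)
    residuals = Q[:, 1:] - Q[:, :-1]     # Q_{t+1} - Q_t for t = 0..T-1
    # After a deviation from the policy C_t = 0; np.where keeps NaN placeholders out of the sum
    terms = np.where(C > 0, C * residuals, 0.0)
    psi = Q[:, 0].mean()
    D = terms.sum(axis=1) + Q[:, 0] - psi
    return psi, D


def wald_ci(psi, D, z=1.96):
    """Wald-type interval psi +/- z * sd(D) / sqrt(n)."""
    se = D.std(ddof=1) / np.sqrt(len(D))
    return psi - z * se, psi + z * se


# Two units, one time point: the first follows the policy, the second does not
Q = np.array([[0.4, 1.0], [0.6, 0.0]])
psi, D = ltmle_eic(Q, follow=np.array([[1], [0]]), g=np.array([[0.5], [0.5]]))
print(psi, D, wald_ci(psi, D))
```

In the two-unit example, the first unit contributes 2 x (1.0 - 0.4) + 0.4 - 0.5 and the second, which deviates from the policy, contributes only 0.6 - 0.5.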
Appendix D Convergence of Temporal Difference Learning
First, consider a flexible model and the corresponding . Initialize and then iteratively update by for until convergence. We show below that if we use a variation-independent parameter space for each and the parameter spaces contain the true , then in -steps this algorithm converges to the true solution .
Ignoring the parameterization and thinking only in terms of optimizing over parameter spaces, this algorithm corresponds to the following: initialize , and then for , compute and set as the one implied by the intervention and ; and set .
First, we claim that in a nonparametric model the -specific parameters are variation independent across . Consider a given (misspecified) . This implies a parameter space for the regressions . The parameter space of the free parameter is even larger than the parameter space of functions of , so this is indeed a reasonable condition. The parameter space over which we optimize at step of the algorithm is then the Cartesian product of the parameter spaces for across . Consider the -th step of the algorithm, in which the outcomes are and we optimize over all the . Then is the minimizer of . This means that the derivative with respect to along a path through in any direction at must equal zero, for all . Thus, at , we have
Consider the derivative with respect to . This yields the score equation for all , which implies that . The other components are some optimizer. Now we move to step . We now know that due to . Therefore, at the next step, the derivative with respect to yields , while again . Then, at step , we also obtain . Continuing in this manner, after steps we have .
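The argument can be illustrated with a tabular analogue in which a saturated regression (a within-stratum mean over the covariate history) plays the role of the flexible model; this is a sketch under that assumption, not the transformer-based training used for Deep LTMLE. All regressions Q_0, ..., Q_{T-1} are updated simultaneously from the previous iterate, starting from Q_t = 0, under the policy "always treat"; the data-generating process and helper names are hypothetical.

```python
import numpy as np


def expit(x):
    return 1.0 / (1.0 + np.exp(-x))


def tabular_fit(target, code, fit_rows):
    """Saturated regression: mean of `target` per history code, fit on `fit_rows`, evaluated for all rows."""
    out = np.full(len(code), np.nan)
    for c in np.unique(code[fit_rows]):
        out[code == c] = target[fit_rows & (code == c)].mean()
    return out


def one_update(Q, codes, A):
    """Update every Q_t (t < T) from the previous iterate's Q_{t+1}, as in one joint TD step."""
    T = A.shape[1]
    new = Q.copy()
    for t in range(T):
        followers = np.all(A[:, : t + 1] == 1, axis=1)  # followed "always treat" up to A_t
        new[:, t] = tabular_fit(Q[:, t + 1], codes[t], followers)
    return new


rng = np.random.default_rng(2)
n, T = 50_000, 3
L = np.zeros((n, T), dtype=int)
A = np.zeros((n, T), dtype=int)
prev_a = np.zeros(n)
for t in range(T):
    L[:, t] = rng.binomial(1, expit(-0.3 + 0.8 * prev_a))
    A[:, t] = rng.binomial(1, expit(0.4 + 0.5 * L[:, t]))
    prev_a = A[:, t]
Y = rng.binomial(1, expit(-1.0 + 0.5 * L.sum(axis=1) + 0.3 * A.sum(axis=1)))

# Encode the covariate history (L_0, ..., L_t) as an integer; among followers A_s = 1 for s <= t
codes = [L[:, : t + 1] @ (2 ** np.arange(t + 1)) for t in range(T)]

# Backward recursion (sequential regression): Q_{T-1} first, then Q_{T-2}, ...
Q_back = np.zeros((n, T + 1))
Q_back[:, T] = Y
for t in reversed(range(T)):
    followers = np.all(A[:, : t + 1] == 1, axis=1)
    Q_back[:, t] = tabular_fit(Q_back[:, t + 1], codes[t], followers)

# Joint updates from Q_t = 0: Q_{T-1} is exact after 1 step, Q_{T-2} after 2, ..., Q_0 after T
Q = np.zeros((n, T + 1))
Q[:, T] = Y
psi_path = []
for _ in range(T + 2):
    Q = one_update(Q, codes, A)
    psi_path.append(Q[:, 0].mean())
print(psi_path, Q_back[:, 0].mean())
```

The estimate of the target parameter stays at its initial value for the first T - 1 joint updates, reaches the backward-recursion value at update T, and does not move afterwards, mirroring the T-step convergence shown above.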
Appendix E Hyperparameter Tuning
We selected the hyperparameters shown in Table 4 by minimizing the empirical loss on a validation set consisting of 30% of the entire dataset. The parameter and, for the censoring mechanism, the parameter balance the learning rates of the -parts and the -parts, because the -parts are expected to be simpler than the -parts, which involve long-range prediction.
Data | Simple Synthetic Data | Complex Synthetic Data | Real World | ||||||||||
Model | Deep LTMLE | DeepACE | Deep LTMLE | DeepACE | Deep LTMLE | ||||||||
10 | 20 | 30 | 10 | 20 | 30 | 10 | 20 | 30 | 10 | 20 | 30 | 30 | |
Embedding dimension | 32 | 32 | 32 | 32 | 32 | 32 | 16 | 32 | 32 | 32 | 32 | 32 | 32 |
Dropout rate | 0 | 0 | 0.1 | 0.3 | 0.2 | 0.1 | 0 | 0 | 0 | 0.3 | 0.2 | 0.2 | 0 |
Hidden size | 64 | 64 | 16 | 8 | 4 | 4 | 64 | 32 | 16 | 4 | 4 | 4 | 16 |
Number of Layers | 8 | 4 | 4 | 1 | 1 | 2 | 4 | 4 | 4 | 2 | 8 | 2 | 8 |
Number of heads | 8 | 4 | 4 | — | — | — | 8 | 8 | 8 | — | — | — | 4 |
Learning rate | 1e-04 | 5e-04 | 5e-04 | 5e-03 | 1e-02 | 5e-03 | 1e-03 | 5e-04 | 1e-04 | 5e-04 | 5e-04 | 5e-04 | 5e-04 |
0.1 | 0.01 | 0.01 | 0.01 | 0.1 | 0.05 | 0.01 | 0.05 | 0.05 | 0.01 | 0.1 | 0.1 | 0.1 | |
— | — | — | 0.05 | 0.05 | 0.05 | — | — | — | 0.05 | 0.05 | 0.05 | 0.01 | |
Number of epochs | 100 | 200 | 400 | 100 | 200 | 100 | 100 | 100 | 400 | 100 | 100 | 100 | 100 |
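The selection rule described above, minimizing the empirical loss on a held-out 30% validation split, amounts to a grid search of the following form. This is a generic sketch under stated assumptions: a closed-form polynomial ridge regression stands in for the transformer, and the grid keys `degree` and `lam` are hypothetical placeholders for the hyperparameters listed in Table 4.

```python
import itertools

import numpy as np


def train_val_split(n, val_frac=0.3, seed=0):
    """Random split; the validation set holds `val_frac` of the data."""
    idx = np.random.default_rng(seed).permutation(n)
    n_val = int(round(val_frac * n))
    return idx[n_val:], idx[:n_val]


def fit_and_score(x_tr, y_tr, x_val, y_val, degree, lam):
    """Stand-in model: polynomial ridge regression; returns the validation mean squared error."""
    X_tr = np.vander(x_tr, degree + 1, increasing=True)
    X_val = np.vander(x_val, degree + 1, increasing=True)
    beta = np.linalg.solve(X_tr.T @ X_tr + lam * np.eye(degree + 1), X_tr.T @ y_tr)
    return np.mean((y_val - X_val @ beta) ** 2)


def grid_search(x, y, grid, val_frac=0.3, seed=0):
    """Evaluate every configuration in the grid and return (validation loss, config) of the best."""
    tr, val = train_val_split(len(y), val_frac, seed)
    results = []
    for values in itertools.product(*grid.values()):
        config = dict(zip(grid.keys(), values))
        loss = fit_and_score(x[tr], y[tr], x[val], y[val], **config)
        results.append((loss, config))
    return min(results, key=lambda r: r[0])


rng = np.random.default_rng(3)
x = rng.uniform(-2, 2, size=1000)
y = np.sin(x) + rng.normal(scale=0.2, size=1000)
best_loss, best_config = grid_search(x, y, {"degree": [1, 3, 5], "lam": [1e-3, 1e-1, 10.0]})
print(best_config, round(best_loss, 4))
```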
Appendix F Synthetic Data
F.1 Simple Synthetic Data with Continuous Outcome
The process iteratively generates variables , , , and over time steps , for . . At , , , . For , , , , is the sigmoid function. We set the counterfactual treatment at all time-points to 1 and and evaluated the counterfactual mean of survival under this treatment policy.
F.2 Simple Synthetic Data with Survival Outcome
The process iteratively generates variables , , , and over time steps , for . . At , , , . For , , , , with implying . Here is the sigmoid function. We set the counterfactual treatment at all time-points to 1 and and evaluated the counterfactual mean of survival under this treatment policy.
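Ground-truth counterfactual survival for simulations of this kind is typically obtained by Monte Carlo: simulate a large cohort from the same data-generating process with the treatment nodes set to 1 at every time point, then average. The sketch below shows that recipe with an absorbing failure indicator; because the exact formulas of this section are not reproduced here, the covariate, treatment, and hazard equations and their coefficients are hypothetical placeholders.

```python
import numpy as np


def expit(x):
    return 1.0 / (1.0 + np.exp(-x))


def simulate_failures(n, T, rng, treat=None):
    """Return an (n, T) 0/1 failure matrix; failure is absorbing (Y_t = 1 implies Y_s = 1, s > t).

    If `treat` is given, every treatment node is set to that value (the intervention).
    All equations below are hypothetical placeholders.
    """
    L = rng.normal(size=n)
    alive = np.ones(n, dtype=bool)
    Y = np.zeros((n, T), dtype=int)
    for t in range(T):
        A = rng.binomial(1, expit(0.5 * L)) if treat is None else np.full(n, treat)
        L = 0.7 * L - 0.5 * A + rng.normal(scale=0.5, size=n)
        event = rng.uniform(size=n) < expit(-3.0 + 0.8 * L - 0.5 * A)
        alive &= ~event
        Y[:, t] = ~alive  # failed at or before time t
    return Y


rng = np.random.default_rng(4)
Y_cf = simulate_failures(200_000, 10, rng, treat=1)
survival_curve = 1.0 - Y_cf.mean(axis=0)  # counterfactual survival at t = 0, ..., T-1
print(np.round(survival_curve, 4))
```

Because failure is absorbing in every simulated path, the averaged survival curve is automatically non-increasing in t.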
F.3 Complex Synthetic Data with Survival Outcome
First draw parameters and for , where is the length of time-dependency with corresponding to a Markovian process. Then, draw error in time-dependent variables for and , errors in treatment , for . For each , , then draw from an indicator function , with . The outcome is drawn from a Bernoulli distribution of a probability with . if for . We set the counterfactual treatment policy as for and evaluated the counterfactual mean of survival under this policy.
Appendix G Results with Simple Synthetic Data with Survival Outcome
Results of an experiment with the simple synthetic data described in Section F.2 are shown in Table 5. LTMLE is expected to perform well on these data because the Markovian dependencies reduce the burden of estimating the nuisance parameters; nevertheless, Deep LTMLE remains highly competitive in this setting, and after targeting its bias and coverage are close to those of LTMLE, with comparable or somewhat larger RMSE. Both targeting approaches gave a better bias-variance trade-off for the target parameter than the untargeted approach: bias and RMSE decreased for every time horizon considered, with RMSE reduced by roughly 27% to 38%. The targeting step also made a marked difference in coverage probability, raising it from 0.78 to 0.87 without targeting to 0.90 to 0.96 with targeting, much closer to the nominal 95%.
Bias | RMSE | Coverage | Mean | |||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
Model | ||||||||||||
LTMLE (GLM) | 0.0052 | 0.0045 | 0.0021 | 0.0202 | 0.0268 | 0.0308 | 0.95 | 0.94 | 0.93 | 0.02 | 0.03 | 0.03 |
LTMLE (SL) | 0.0056 | 0.0058 | 0.0061 | 0.0203 | 0.0263 | 0.0311 | 0.91 | 0.93 | 0.91 | 0.02 | 0.02 | 0.03 |
DeepACE | 0.0213 | 0.0462 | -0.1342 | 0.0266 | 0.0515 | 0.1397 | 1.00 | 1.00 | 1.00 | 0.19 | 0.70 | 0.70 |
Deep LTMLE | 0.0080 | 0.0133 | 0.0090 | 0.0292 | 0.0569 | 0.0449 | 0.79 | 0.78 | 0.87 | 0.02 | 0.04 | 0.03 |
Deep LTMLE | 0.0054 | 0.0070 | 0.0080 | 0.0207 | 0.0350 | 0.0329 | 0.91 | 0.95 | 0.91 | 0.02 | 0.04 | 0.03 |
Deep LTMLE | 0.0053 | 0.0053 | 0.0080 | 0.0207 | 0.0361 | 0.0310 | 0.90 | 0.96 | 0.92 | 0.02 | 0.04 | 0.03 |
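The Bias, RMSE, and Coverage columns of Table 5 follow the usual definitions over simulation replicates. A minimal sketch, assuming per-replicate point estimates, estimated standard errors, a known true value, and Wald-type 95% intervals for coverage; the function name is hypothetical and the fourth column (Mean) is not covered by this sketch.

```python
import numpy as np


def simulation_metrics(estimates, std_errors, truth, z=1.96):
    """Bias, RMSE, and Wald-interval coverage across simulation replicates."""
    estimates = np.asarray(estimates, dtype=float)
    std_errors = np.asarray(std_errors, dtype=float)
    errors = estimates - truth
    bias = errors.mean()
    rmse = np.sqrt(np.mean(errors ** 2))
    coverage = np.mean(np.abs(errors) <= z * std_errors)  # truth inside estimate +/- z * se
    return bias, rmse, coverage


bias, rmse, coverage = simulation_metrics([1.1, 0.9, 1.3], [0.1, 0.1, 0.1], truth=1.0)
print(round(bias, 3), round(rmse, 4), round(coverage, 3))
```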
G.1 Semi-Synthetic Data
As a compromise between fully synthetic and real-world data, we conducted several additional experiments with semi-synthetic data derived from real-world data, as in previous studies (Bica et al., 2020; Frauen et al., 2023). For this experiment, we used covariates from the Circulatory Risk in Communities Study (CIRCS) and fit the outcome regression given the history through each time point using XGBoost with early stopping. Outcomes were then generated from this fitted regression model. For the experiment, we sampled 1000 observations from the empirical distribution of the covariates and generated for with .
Appendix H Extension to Survival Analysis with Censoring
In this section, we describe the extended LTMLE model with censoring used for the real-world application in Section 5.4. We assume the following order of observed nodes , where are binary censoring nodes with indicating that the individual is censored. Our interest is to estimate the risk of our outcome , the mortality of the individual. However, because our observation period is long, individuals are at risk of being censored. Censoring here is loss to follow-up for administrative reasons, for example, moving to another area or declining to participate in the survey. We assume that nodes degenerate after a jump: when we observe a jump in or nodes, the process halts and all subsequent nodes remain constant at their last observed values. For example, if , then , , , and for all .
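The degeneration rule just described can be applied to a wide-format dataset as follows. This is a sketch assuming one row per individual and one column per time point for each node type, with every node carried forward from the first time t at which the failure or censoring indicator equals 1; the function name and array layout are hypothetical.

```python
import numpy as np


def degenerate_after_jump(Y, C, *others):
    """Freeze every node at its value at the first jump in Y or C.

    Y, C   : (n, T) 0/1 arrays (failure and censoring indicators).
    others : further (n, T) arrays, e.g. covariates and treatments.
    Returns copies in which, for s > first jump time, each node equals its value at the jump.
    """
    n, T = Y.shape
    jump = (Y == 1) | (C == 1)
    first = np.where(jump.any(axis=1), jump.argmax(axis=1), T - 1)  # no jump: nothing to freeze
    after = np.arange(T)[None, :] > first[:, None]
    rows = np.arange(n)
    return [np.where(after, X[rows, first][:, None], X) for X in (Y, C) + others]


Y = np.array([[0, 1, 0, 0], [0, 0, 0, 0]])
C = np.array([[0, 0, 0, 1], [0, 0, 1, 0]])
L = np.array([[5, 6, 7, 8], [1, 2, 3, 4]])
Y2, C2, L2 = degenerate_after_jump(Y, C, L)
print(Y2, C2, L2, sep="\n")
```

In the example, the first individual fails at t = 1, so later censoring and covariate values are overwritten; the second is censored at t = 2, so only the last column changes.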
We constructed a Deep LTMLE similar to the one described in Section 4 for this censored node structure. The only difference is an additional censoring-mechanism component, which enters the clever covariate and the loss function:
(35) | ||||
(36) |
where is an additional hyperparameter for the binary logistic loss of the censoring mechanism. The counterfactual treatment on the censoring process is , meaning suppression of censoring. Estimates of the target parameter and the efficient influence curve for the different treatment strategies were computed using Deep LTMLE, and the average treatment effects (ATEs) and their EICs were computed using the delta method. Based on the estimated EICs of the target parameters at each time point t, we constructed simultaneous confidence intervals.
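One common way to carry out the last two steps is sketched below, under assumptions not stated in the text: the ATE at each time point is the difference of two strategy-specific estimates, so by the delta method its EIC is the difference of their EICs, and the simultaneous band uses the 95% quantile of the maximum absolute value of a Gaussian vector whose correlation is estimated from those EICs. The function name and inputs (per-time-point estimates and n-by-K EIC matrices) are hypothetical, and the paper's exact construction may differ.

```python
import numpy as np


def ate_simultaneous_ci(psi1, psi0, D1, D0, level=0.95, n_draws=100_000, seed=0):
    """ATE_k = psi1_k - psi0_k for k = 1..K with a simultaneous confidence band.

    psi1, psi0 : (K,) estimates under two strategies at K time points.
    D1, D0     : (n, K) estimated EICs of those estimates.
    """
    n = D1.shape[0]
    ate = psi1 - psi0
    D = D1 - D0                                    # delta method: gradient of psi1 - psi0 is (1, -1)
    se = D.std(axis=0, ddof=1) / np.sqrt(n)
    corr = np.atleast_2d(np.corrcoef(D, rowvar=False))
    z = np.random.default_rng(seed).multivariate_normal(np.zeros(len(ate)), corr, size=n_draws)
    q = np.quantile(np.abs(z).max(axis=1), level)  # critical value for max_k |Z_k|
    return ate, ate - q * se, ate + q * se, q


rng = np.random.default_rng(5)
n, K = 500, 4
shared = rng.normal(size=(n, 1))
D1 = shared + rng.normal(size=(n, K))
D0 = 0.5 * rng.normal(size=(n, K))
ate, lower, upper, q = ate_simultaneous_ci(np.full(K, 0.30), np.full(K, 0.25), D1, D0)
print(q, lower, upper)
```

The critical value lies between the pointwise 1.96 and the Bonferroni value for K time points, and it shrinks toward 1.96 as the EICs at different time points become more strongly correlated.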