
Adversarially Balanced Representation for Continuous Treatment Effect Estimation

Amirreza Kazemi1, Martin Ester1
Abstract

Individual treatment effect (ITE) estimation requires adjusting for the covariate shift between populations with different treatments, and deep representation learning has shown great promise in learning a balanced representation of covariates. However, existing methods mostly consider the scenario of binary treatments. In this paper, we consider the more practical and challenging scenario in which the treatment is a continuous variable (e.g., the dosage of a medication), and we address the two main challenges of this setup. We propose the adversarial counterfactual regression network (ACFR), which adversarially minimizes the representation imbalance in terms of KL divergence and preserves the impact of the treatment value on the outcome prediction by leveraging an attention mechanism. Theoretically, we show that the ACFR objective function is grounded in an upper bound on the counterfactual outcome prediction error. Our experimental evaluation on semi-synthetic datasets demonstrates the empirical superiority of ACFR over a range of state-of-the-art methods.

Introduction

Estimating the individual treatment effect (ITE) from observational datasets has important applications in domains such as personalized medicine (Kent, Steyerberg, and van Klaveren 2018; Prosperi et al. 2020), economics (Varian 2016), and recommendation systems (Wang et al. 2020). An observational dataset contains units with their covariates $X$, the assigned treatment $T$, and the outcome after intervention $Y$ (also known as the factual outcome). Since the treatment assignment policy is unknown, there typically exists an inherent treatment-selection bias stemming from confounding covariates that influence both the treatment assignment and the outcome. Consequently, a causal ITE estimator must adjust for the covariate shift among different treatment populations (Bareinboim, Tian, and Pearl 2014; Bareinboim and Pearl 2012).

In recent years, the representation learning approach (Johansson, Shalit, and Sontag 2016; Shalit, Johansson, and Sontag 2017) has demonstrated remarkable success in adjusting for covariate shift. Briefly, the idea is to learn a balanced representation of the covariates (rather than balancing the covariates themselves) using an encoder, and then to predict the outcomes from the representation using an outcome prediction network. However, the majority of representation learning methods have been proposed for binary treatments (Hassanpour and Greiner 2019; Zhang, Liu, and Li 2021; Yao et al. 2018), and incorporating continuous treatments into their architectures is challenging (Chu et al. 2023). We elaborate on the challenges associated with continuous treatments below.

i) Balancing representation for continuous treatments
To achieve a balanced representation $Z$, most existing methods minimize the shift in terms of an Integral Probability Metric (IPM) distance (Sriperumbudur et al. 2009), defined as $\mathrm{IPM}_G(p,q)=\sup_{g\in G}\int_S g(s)\big(p(s)-q(s)\big)\,ds$, where different choices of the function class $G$ lead to different distributional distances. The distance is taken between the distributions $P(Z,T)$ and $P(Z)\,P(T)$, as this allows bounding the counterfactual prediction error (Shalit, Johansson, and Sontag 2017). Given that the marginal distribution $P(Z)$ is unknown, (Bellot, Dhir, and Prando 2022) proposed minimizing the IPM distance between $P(Z,T=t)$ and $P(Z,T=\lnot t)$, where $t$ denotes a treatment value in the data and $\lnot t$ denotes all treatments except $t$. Similarly, (Wang et al. 2022a) suggested discretizing the treatment range into intervals and minimizing the maximum IPM distance between the distributions of any two intervals. Despite minimizing an upper bound, these methods rely in practice on non-parametric approximations of several IPM distances, which may be inaccurate for high-dimensional representations and small training sets (Liang 2019). Furthermore, IPMs are by definition worst-case distances, and obtaining a treatment-invariant representation through an IPM might be overly restrictive, potentially excluding confounding factors that are important for outcome prediction (Zhang, Bellot, and van der Schaar 2020).
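As an illustration of one member of the IPM family, the sketch below estimates the squared kernel maximum mean discrepancy (MMD², the square of the IPM whose function class $G$ is the unit ball of a reproducing kernel Hilbert space) between samples of $P(Z,T)$ and of $P(Z)\,P(T)$. It is a minimal sketch, not the estimator used by any of the cited methods: the Gaussian product kernel, its bandwidth, and approximating samples of $P(Z)\,P(T)$ by randomly permuting the treatments across units are our own assumptions.

```python
import numpy as np

def rbf_kernel(a, b, bandwidth):
    """Gaussian kernel matrix between the rows of a (n x k) and b (m x k)."""
    sq_dists = np.sum(a**2, 1)[:, None] + np.sum(b**2, 1)[None, :] - 2 * a @ b.T
    return np.exp(-np.maximum(sq_dists, 0.0) / (2 * bandwidth**2))

def mmd2_joint_vs_product(z, t, bandwidth=1.0, seed=0):
    """Biased MMD^2 estimate between samples of P(Z, T) and of P(Z)P(T).

    Product-distribution samples pair each z_i with a randomly permuted
    treatment (an approximation chosen for this illustration).
    """
    rng = np.random.default_rng(seed)
    joint = np.column_stack([z, t])
    product = np.column_stack([z, rng.permutation(t)])
    k_jj = rbf_kernel(joint, joint, bandwidth)
    k_pp = rbf_kernel(product, product, bandwidth)
    k_jp = rbf_kernel(joint, product, bandwidth)
    return float(k_jj.mean() + k_pp.mean() - 2 * k_jp.mean())

# Hypothetical toy representations with and without treatment information.
rng = np.random.default_rng(1)
z = rng.normal(size=(300, 2))
t_dependent = z[:, 0]                 # treatment fully determined by z
t_independent = rng.normal(size=300)  # treatment unrelated to z
print(mmd2_joint_vs_product(z, t_dependent), mmd2_joint_vs_product(z, t_independent))
```

With this kernel, the estimate is clearly larger when the treatment is a function of $Z$ than when it is drawn independently of $Z$; a balanced representation is one that drives this quantity toward zero.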

There are also non-IPM representation learning methods. For instance, (Du et al. 2021; Berrevoets et al. 2020) adopt an adversarial discriminator in order to balance the distributions in the latent representation, (Yao et al. 2018) proposed preserving the local similarity in the representation space, (Hassanpour and Greiner 2019; Wu et al. 2022) learn a disentangled representation to distinguish different latent factors, (Zhang, Bellot, and van der Schaar 2020) enforce invertibility of the encoder function to prevent the loss of covariate information, and (Alaa and van der Schaar 2017) introduced a regularization scheme in order to generalize to counterfactual outcomes. However, these methods lack theoretical justification (Du et al. 2021; Hassanpour and Greiner 2019; Yao et al. 2018) or their guarantees are limited to binary treatments (Zhang, Bellot, and van der Schaar 2020; Alaa and van der Schaar 2017).

ii) Treatment impact in outcome prediction network
In order to predict the outcomes of different treatments, the outcome prediction network on top of the representation needs to incorporate the treatment value. Feeding the treatment value as an additional input to the outcome prediction network, alongside the representation, causes the network to overfit to the much higher-dimensional representation and largely limits the impact of the treatment value. Also, a distinct prediction head for each treatment value (as in the case of binary treatments) is not practical. Instead, (Schwab et al. 2020) proposed the dose response network (DRNet), which divides the treatment range into intervals and uses a distinct network to predict the outcome for each interval. (Nie et al. 2021) proposed the varying coefficient network (VCNet), which involves the treatment value in the network parameters through spline functions of the treatment. Nonetheless, neither network can capture the dependency between representation and treatment effectively, since the choice of spline functions in VCNet (or of the intervals in DRNet) is made irrespective of the representation values (Zhang et al. 2022).

In this paper, we propose a representation learning method to accurately predict potential outcomes of a continuous treatment. We address the above challenges through the following contributions:

  1. We prove that, under certain assumptions, the counterfactual error is bounded by the factual error and the KL divergence between $P(Z)P(T)$ and $P(Z,T)$. Unlike the IPM distance, the KL divergence can be estimated parametrically, leading to a more reliable bound.

  2. We propose the Adversarial Counterfactual Regression (ACFR) network. ACFR minimizes the KL divergence through an adversarial game, extracting a balanced representation for continuous treatments. ACFR also minimizes the factual prediction error with a cross-attention network that captures the complex dependency between the treatment and the representation.

  3. We conduct an experimental comparison of ACFR against state-of-the-art methods on the semi-synthetic News and TCGA datasets, and analyze the robustness of the methods to varying levels of treatment-selection bias.

Problem Setup

We assume a dataset of the form $D=\{x_i,t_i,y_i\}_{i=1}^{N}$, where $x_i\in\mathcal{X}\subseteq\mathbb{R}^d$ denotes the covariates of the $i$th unit, $t_i\in[0,1]$ is the continuous treatment that unit $i$ received, and $y_i\in\mathcal{Y}\subseteq\mathbb{R}$ denotes the outcome of interest for unit $i$ after receiving treatment $t_i$. $N$ is the total number of units, and $d$ is the dimension of the covariates. We aim to learn a machine learning model that predicts the causal quantity $\mu(x,t)=E_{\mathcal{Y}}[Y(t)\mid X=x]$, the expected potential outcome under treatment $t$ for an individual with covariates $x$. Note that, unlike binary ITE estimation, the goal is to predict all potential outcomes, not just the difference between them. As in previous work, we rely on the following standard assumptions to make treatment effects identifiable from an observational dataset.

Assumption 1 (Unconfoundedness). $\{Y(t)\}_{t\in[0,1]}\perp T\mid X$. In words, given the covariates, the treatment assignment and the potential outcomes are independent.

Assumption 2 (Overlap). $p(t\mid x)>0$ for all $t\in[0,1]$ and all $x\in\mathcal{X}$. In words, every unit can receive every treatment level $t$ with positive (conditional) density.

Under these assumptions, $\mu(x,t)$ can be rewritten in terms of observable quantities and can therefore be estimated:

$$\mu(x,t)=E_{\mathcal{Y}}[Y(t)\mid X=x]=E_{\mathcal{Y}}[Y\mid X=x,T=t]$$

Theoretical Analysis

We analyze the properties of the encoder function $\phi:\mathcal{X}\rightarrow\mathcal{Z}$, where $\mathcal{Z}$ is the representation space, the outcome prediction function $h:\mathcal{Z}\times[0,1]\rightarrow\mathcal{Y}$, and the loss function $L:\mathcal{Y}\times\mathcal{Y}\rightarrow\mathbb{R}^+$.

Definition 1. Define $\ell_{L,h,\phi}(x,t)=L\big(h(\phi(x),t),y\big):\mathcal{X}\times[0,1]\rightarrow\mathbb{R}^+$ to be the unit-loss for a unit with covariates $x$ that is intervened on with treatment $t$. The unit-loss $\ell_{L,h,\phi}(x,t)$ measures the loss $L$ between the predicted outcome $\hat{y}=h(\phi(x),t)$ and the ground-truth outcome $y=\mu(x,t)$.

Using the definition of unit-loss, we can define the expected prediction error for a treatment $t$ by marginalizing over the covariate distribution. Because of treatment-selection bias, the covariate distribution of samples that received treatment $t$ (factual) differs from that of samples that did not receive treatment $t$ (counterfactual). We define the factual error $\varepsilon_f^\ell(t)$ by marginalizing over $p(x\mid t)$ and the counterfactual error $\varepsilon_{cf}^\ell(t)$ by marginalizing over $p(x)$ as follows.

$$\varepsilon_f^\ell(t)=\int_{\mathcal{X}}\ell_{L,h,\phi}(x,t)\,p(x\mid t)\,dx$$
$$\varepsilon_{cf}^\ell(t)=\int_{[0,1]\setminus\{t\}}\int_{\mathcal{X}}\ell_{L,h,\phi}(x,t)\,p(x\mid t')\,p(t')\,dx\,dt'=\int_{\mathcal{X}}\ell_{L,h,\phi}(x,t)\,p(x)\,dx$$

The last equality uses $\int p(x\mid t')\,p(t')\,dt'=p(x)$; excluding the single point $t$ from the range does not change the integral.

We also define the expected error over all treatments by marginalizing over their range $[0,1]$: $\varepsilon_f^\ell=\int_{[0,1]}\int_{\mathcal{X}}\ell_{L,h,\phi}(x,t)\,p(x,t)\,dx\,dt$ and $\varepsilon_{cf}^\ell=\int_{[0,1]}\int_{\mathcal{X}}\ell_{L,h,\phi}(x,t)\,p(x)\,p(t)\,dx\,dt$. Note that the expected factual error $\varepsilon_f^\ell$ integrates over the joint distribution $p(x,t)$, whereas the expected counterfactual error $\varepsilon_{cf}^\ell$ integrates over $p(x)\,p(t)$. We aim to reduce the distributional distance in the representation space $Z$ so that minimizing $\varepsilon_f^\ell$ also minimizes $\varepsilon_{cf}^\ell$. We need the following assumption on the encoder $\phi$ to ensure that balancing properties generalize from the representation space to the covariate space.

Assumption 3. The encoder function $\phi$ is a twice-differentiable one-to-one mapping, and the representation space $\mathcal{Z}$ is the image of $\mathcal{X}$ under $\phi$ with the induced distribution $p_\phi(z)$.

We also need the following assumption to ensure that the unit-loss is not arbitrarily large for any $(x,t)$ pair. A similar constraint on the unit-loss function is also required when specifying the IPM distance.

Assumption 4. Let $G$ be the class of functions with infinity norm at most 1, $G=\{g:\mathcal{Z}\times[0,1]\rightarrow\mathbb{R}^+ \mid \|g\|_\infty\leq 1\}$. Then there exists a constant $C>0$ such that $\ell_{L,h,\phi}(x,t)/C\in G$. This means that for any $(x,t)$ we have $\ell_{L,h,\phi}(x,t)/C\leq 1$.
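The gap between the expected factual and counterfactual errors defined above can be illustrated with a small Monte Carlo sketch. Assuming paired samples $(x_i,t_i)$ from $p(x,t)$, $\varepsilon_f^\ell$ is estimated by averaging the unit-loss over the observed pairs, and $\varepsilon_{cf}^\ell$ by pairing each $x_i$ with a treatment taken from another unit via a random permutation, which approximates sampling from $p(x)\,p(t)$. The toy data and `toy_loss` are hypothetical stand-ins for $\ell_{L,h,\phi}$, chosen to mimic a model whose error grows away from the treatments it has seen.

```python
import numpy as np

def factual_and_counterfactual_error(unit_loss, x, t, seed=0):
    """Monte Carlo estimates of eps_f (over p(x, t)) and eps_cf (over p(x) p(t))."""
    rng = np.random.default_rng(seed)
    eps_f = float(np.mean(unit_loss(x, t)))                    # observed pairs
    eps_cf = float(np.mean(unit_loss(x, rng.permutation(t))))  # x_i with another unit's t
    return eps_f, eps_cf

# Hypothetical confounded toy data: the treatment closely tracks the covariate.
rng = np.random.default_rng(0)
x = rng.uniform(size=1000)
t = np.clip(x + 0.05 * rng.normal(size=1000), 0.0, 1.0)

# Hypothetical unit-loss of a model that is accurate only near the treatments
# it was trained on, so its error grows with the gap between x and t.
toy_loss = lambda x, t: (x - t) ** 2
eps_f, eps_cf = factual_and_counterfactual_error(toy_loss, x, t)
print(eps_f, eps_cf)
```

Under strong confounding, a small factual error says little about the counterfactual error, which is the gap that Proposition 1 controls.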

Note that Assumptions 3 and 4 are common in the representation learning literature (Shalit, Johansson, and Sontag 2017; Bellot, Dhir, and Prando 2022; Hassanpour and Greiner 2019). We now present our main theoretical result, which bounds the expected counterfactual error $\varepsilon_{cf}^\ell$ by the expected factual error $\varepsilon_f^\ell$ plus a term in the KL divergence between the distributions $p_\phi(z,t)$ and $p_\phi(z)\,p(t)$.

Proposition 1 (Counterfactual Generalization Bound). Given the one-to-one encoder function $\phi:\mathcal{X}\rightarrow\mathcal{Z}$, the outcome prediction function $h:\mathcal{Z}\times[0,1]\rightarrow\mathcal{Y}$, and the unit-loss function $\ell_{L,h,\phi}(x,t)$ that satisfies Assumption 4,

$$\varepsilon_{cf}^\ell\leq\varepsilon_f^\ell+C\sqrt{2\,D_{KL}\big(p_\phi(z,t)\,\|\,p_\phi(z)\,p(t)\big)}$$

Note that the KL divergence is non-negative and is zero if and only if the two distributions are identical. Therefore, $D_{KL}\big(p_\phi(z,t)\,\|\,p_\phi(z)\,p(t)\big)=0$ implies $p_\phi(z,t)=p_\phi(z)\,p(t)$, which by Assumption 3 carries over to $p(x,t)=p(x)\,p(t)$, and hence $\varepsilon_f^\ell=\varepsilon_{cf}^\ell$. We can also interpret the above bound from an information-theoretic perspective. Briefly, minimizing the right-hand side of Proposition 1 yields a representation $z$ that has maximum mutual information with the outcome $y$ given the treatment $t$ (corresponding to minimizing $\varepsilon_f^\ell$) and minimum mutual information with the treatment $t$ (corresponding to minimizing the KL divergence term).
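As a numerical sanity check of Proposition 1 (not a substitute for the proof in Appendix A), the snippet below evaluates both sides of the bound on random discrete toy distributions, where the integrals become sums. Since $\phi$ is one-to-one, it works directly on a grid of $(z,t)$ values; the grid size, the constant `C`, and the random loss values are arbitrary choices for illustration.

```python
import numpy as np

def kl(p, q):
    """KL divergence D_KL(p || q) between strictly positive discrete distributions."""
    return float(np.sum(p * np.log(p / q)))

rng = np.random.default_rng(0)
C = 2.0  # hypothetical constant from Assumption 4: the unit-loss lies in [0, C]
for _ in range(1000):
    p_zt = rng.uniform(0.01, 1.0, size=(4, 3))   # toy joint p(z, t) on a 4 x 3 grid
    p_zt /= p_zt.sum()
    p_prod = np.outer(p_zt.sum(axis=1), p_zt.sum(axis=0))  # p(z) p(t)
    loss = rng.uniform(0.0, C, size=(4, 3))      # arbitrary unit-loss values in [0, C]
    eps_f = float(np.sum(loss * p_zt))           # expectation under p(z, t)
    eps_cf = float(np.sum(loss * p_prod))        # expectation under p(z) p(t)
    assert eps_cf <= eps_f + C * np.sqrt(2 * kl(p_zt, p_prod)) + 1e-12
print("the bound held in all 1000 toy trials")
```

For losses bounded in $[0,C]$, one way to see why no trial fails is Pinsker's inequality, which bounds the total variation distance between the two distributions by $\sqrt{D_{KL}/2}$.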

For some applications, one might be interested in the treatment effect between two different treatments rather than in predicting all counterfactual outcomes. For instance, in the binary treatment setting it is standard to report model performance in terms of the precision of estimating heterogeneous effect (PEHE) (Hill 2011), which measures the squared difference between the ground-truth treatment effect $\tau(x)=\mu(x,1)-\mu(x,0)$ and the predicted treatment effect $\hat{\tau}(x)=h(\phi(x),1)-h(\phi(x),0)$. We define its continuous counterpart $\varepsilon_{pehe}(t_1,t_2)$ between two treatment levels $t_1$ and $t_2$ and present an upper bound on it in Proposition 2.

Definition 2. Define $\varepsilon_{pehe}(t_1,t_2)=\int_{\mathcal{X}}\Big[\big(\mu(x,t_1)-\mu(x,t_2)\big)-\big(h(\phi(x),t_1)-h(\phi(x),t_2)\big)\Big]^2p(x)\,dx$ to be the expected precision of estimating heterogeneous effect between treatment levels $t_1$ and $t_2$.
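Definition 2 can be estimated by averaging over a sample of covariates drawn from $p(x)$. In this sketch, `predict(x, t)` stands for the composition $h(\phi(x),t)$, and the ground-truth function `mu` and both predictors are hypothetical examples.

```python
import numpy as np

def pehe(mu, predict, x, t1, t2):
    """Empirical eps_pehe(t1, t2): mean over units x_i drawn from p(x)."""
    true_effect = mu(x, t1) - mu(x, t2)
    pred_effect = predict(x, t1) - predict(x, t2)  # predict(x, t) plays h(phi(x), t)
    return float(np.mean((true_effect - pred_effect) ** 2))

# Hypothetical ground truth and predictors.
mu = lambda x, t: np.sin(np.pi * t) * x
offset = lambda x, t: mu(x, t) + 0.3           # constant error cancels in the effect
wrong_slope = lambda x, t: mu(x, t) + 0.5 * t * x

x = np.linspace(0.0, 1.0, 101)
print(pehe(mu, offset, x, 0.2, 0.8), pehe(mu, wrong_slope, x, 0.2, 0.8))
```

A predictor that is off by a constant has (numerically) zero $\varepsilon_{pehe}$, because the offset cancels in the difference, whereas an error that depends on both $t$ and $x$ does not cancel.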

Proposition 2 (Precision of Estimating Heterogeneous Effect Bound). Given the one-to-one encoder function $\phi$ and the outcome prediction function $h$ as in Proposition 1, and a unit-loss function $\ell_{L,h,\phi}(x,t)$ that satisfies Assumption 4 and whose associated $L$ is the squared error $\|\cdot\|^2$,

$$\varepsilon_{pehe}(t_1,t_2)\leq\varepsilon_f^\ell(t_1)+\varepsilon_f^\ell(t_2)+C\Big[\sqrt{2\,D_{KL}\big(p_\phi(z)\,\|\,p_\phi(z\mid t_1)\big)}+\sqrt{2\,D_{KL}\big(p_\phi(z)\,\|\,p_\phi(z\mid t_2)\big)}\Big]$$

The proofs of Propositions 1 and 2 can be found in Appendix A.

Method

Figure 1: The architecture of the Adversarial CounterFactual Regression network, consisting of three sub-networks: encoder $\phi$, outcome predictor $h$, and treatment predictor $\pi$. Networks $\phi$ and $h$ are trained to minimize the outcome prediction loss $l_{pred}$, and networks $\phi$ and $\pi$ are trained to maximize and minimize, respectively, the adversarial loss $l_{adv}$. The encoder and treatment predictor are implemented using linear layers, and the outcome predictor network consists of a cross-attention module followed by a linear layer.

Based on Proposition 1, we can control the counterfactual error by learning an encoder function $\phi$ and an outcome prediction function $h$ that jointly minimize the distribution shift and the factual outcome error. Our method implements the functions $\phi$ and $h$ as neural networks and trains them with an objective function inspired by Proposition 1. Figure 1 illustrates the architecture of the ACFR network, consisting of an encoder network $\phi$, an outcome prediction network $h$, and a treatment prediction network $\pi$. The two key aspects of ACFR, distribution shift minimization and outcome prediction error minimization, are presented below. (Note that although the previous section assumes $\phi$ is a one-to-one mapping, empirically we obtained better results for all methods using a neural network encoder without this constraint.)

Figure 2: t-SNE plot of the latent representation $Z$ learned using different distributional distances. After training each method on the News dataset, we mapped validation samples into the latent representation and plotted them using 2D t-SNE. We categorized the samples into 4 intervals with respect to their assigned treatment value, and each interval corresponds to a color. We consider two important classes of IPM metrics, HSIC and Wasserstein. The treatment value is least distinguishable in the KL divergence representation, followed by IPM-ADMIT (minimization with the algorithm proposed in (Wang et al. 2022b)) and IPM (minimization with the procedure proposed in (Bellot, Dhir, and Prando 2022)).

Distribution Shift Minimization

As discussed earlier, in order to minimize distribution shift we aim to minimize the KL divergence term with respect to encoder ϕ\phi. The KL divergence can be rewritten as follows:

$$D_{KL}\big(p_\phi(z,t)\,\|\,p_\phi(z)\,p(t)\big)=I(T,Z;\phi)=H(T)-H(T\mid Z;\phi)$$

where $I(T,Z;\phi)$ is the mutual information between $T$ and $Z$, and $H(T\mid Z;\phi)$ is the conditional entropy. The marginal entropy $H(T)$ does not depend on $\phi$; thus, minimizing the KL divergence is equivalent to maximizing the conditional entropy $H(T\mid Z;\phi)=\mathbb{E}_{p_\phi(t,z)}[-\log p_\phi(t\mid z)]$. However, as $p_\phi(t\mid z)$ is intractable, we introduce a variational distribution $q_\pi(t\mid z)$ defined over the same space to approximate it. Since the cross-entropy $\mathbb{E}_{p_\phi(t,z)}[-\log q_\pi(t\mid z)]$ upper-bounds $H(T\mid Z;\phi)$, with equality when $q_\pi(t\mid z)=p_\phi(t\mid z)$, the following holds (Farnia and Tse 2016):

$$\max_\phi H(T\mid Z;\phi)=\max_\phi\inf_\pi\mathbb{E}_{p_\phi(t,z)}\big[-\log q_\pi(t\mid z)\big]$$
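The decomposition of the KL term into $H(T)-H(T\mid Z;\phi)$ can be checked numerically on a discrete toy joint distribution over $(z,t)$, using $H(T\mid Z)=\mathbb{E}_{p(z,t)}[-\log p(t\mid z)]$. The grid size and the random joint distribution are arbitrary choices for illustration.

```python
import numpy as np

def entropy(p):
    """Shannon entropy (in nats) of a discrete distribution."""
    p = p[p > 0]
    return float(-np.sum(p * np.log(p)))

rng = np.random.default_rng(0)
p_zt = rng.uniform(0.01, 1.0, size=(5, 4))  # toy joint p(z, t); rows index z
p_zt /= p_zt.sum()
p_z, p_t = p_zt.sum(axis=1), p_zt.sum(axis=0)

kl_term = float(np.sum(p_zt * np.log(p_zt / np.outer(p_z, p_t))))
h_t = entropy(p_t)
# H(T | Z) = E_{p(z,t)}[-log p(t | z)] with p(t | z) = p(z, t) / p(z)
h_t_given_z = float(-np.sum(p_zt * np.log(p_zt / p_z[:, None])))
print(kl_term, h_t - h_t_given_z)  # the two numbers agree up to rounding
```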

We assume that the distribution $q_\pi(t\mid z)$ is a normal distribution with a fixed variance. We estimate the mean of $q_\pi(t\mid z)$ with a neural network called the treatment-prediction network $\pi$. By approximating $p_\phi(z,t)$ with the empirical data, we derive the following mean squared adversarial loss term from the above negative log-likelihood.

$$l_{adv}=\max_\phi\min_\pi\frac{1}{N}\sum_{i=1}^{N}\big(t_i-\pi(\phi(x_i))\big)^2$$
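Written out, the step from the Gaussian negative log-likelihood to this squared loss is as follows, where $\sigma^2$ is the fixed variance (its value is not specified in the text) and $c_\sigma$ is the Gaussian normalizing constant, which depends only on $\sigma$:

```latex
-\log q_\pi(t \mid z)
  = -\log\left[\frac{1}{c_\sigma}\exp\!\left(-\frac{(t-\pi(z))^2}{2\sigma^2}\right)\right]
  = \frac{(t-\pi(z))^2}{2\sigma^2} + \log c_\sigma,
\qquad
\mathbb{E}_{p_\phi(t,z)}\big[-\log q_\pi(t \mid z)\big]
  \approx \frac{1}{2\sigma^2}\cdot\frac{1}{N}\sum_{i=1}^{N}\big(t_i-\pi(\phi(x_i))\big)^2 + \log c_\sigma
```

Because $\sigma$ is fixed, the constant $\log c_\sigma$ and the positive factor $1/(2\sigma^2)$ change neither the minimizing $\pi$ nor the maximizing $\phi$ (the scale can be absorbed into the trade-off parameter $\gamma$), so the empirical objective reduces to the mean squared error above.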

Specifically, the treatment-predictor network $\pi$ (the green network in Figure 1) is trained to minimize $l_{adv}$ by predicting the treatment $t$ from the representation $z=\phi(x)$. The encoder network $\phi$ (the blue network in Figure 1) is trained to maximize $l_{adv}$ by extracting $z$ in such a way that the assigned treatment $t$ is not distinguishable. Therefore, the KL divergence can be estimated and minimized using two networks and an adversarial loss. Through alternating optimization with respect to $\phi$ and $\pi$, and assuming that the treatment predictor $\pi$ reaches the optimum in each iteration, the resulting representation $z$ has a desired property: $\mathbb{E}[t\mid z]=\mathbb{E}[t]$ (Wang, He, and Katabi 2020). This implies that knowing the latent representation $z$ does not provide additional information for predicting the expected treatment. Figure 2 also shows that the representation learned through KL divergence minimization is less predictive of the treatment value than the representations obtained with two classes of IPM, HSIC (Gretton et al. 2007) and Wasserstein (Villani 2008), indicating a more effective reduction of the shift.
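The role of the inner minimization can be seen in a simplified setting. Assuming a linear treatment predictor $\pi$ instead of a neural network, $\min_\pi l_{adv}$ has a closed-form least-squares solution, so two hypothetical representations can be compared: one that retains the confounder driving the treatment and one that discards it. The encoder is pushed toward the second kind, for which the minimal adversarial loss approaches the treatment variance, i.e., $z$ carries almost no (linear) information about $t$.

```python
import numpy as np

def min_adversarial_loss_linear(z, t):
    """Inner minimum of l_adv over a linear-plus-bias treatment predictor pi.

    For a linear pi the minimization has a closed form (least squares); in ACFR,
    pi is a neural network updated by gradient steps instead.
    """
    design = np.column_stack([z, np.ones(len(z))])
    coef, *_ = np.linalg.lstsq(design, t, rcond=None)
    return float(np.mean((t - design @ coef) ** 2))

rng = np.random.default_rng(0)
x = rng.normal(size=(500, 3))
t = 1.0 / (1.0 + np.exp(-x[:, 0]))  # confounded: treatment depends on x[:, 0]
z_leaky = x[:, :2]                  # a representation that keeps x[:, 0]
z_balanced = x[:, 1:]               # a representation that drops x[:, 0]
print(min_adversarial_loss_linear(z_leaky, t),
      min_adversarial_loss_linear(z_balanced, t), t.var())
```

The leaky representation lets $\pi$ drive the loss close to zero, while for the balanced one the best $\pi$ can do is roughly predict the mean treatment, leaving a loss near $\mathrm{Var}(t)$.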

Factual Outcome Error Minimization

In this section, we discuss the minimization of the factual prediction error $\varepsilon_f^\ell$. Recall that:

$$\varepsilon_f^\ell=\int_{[0,1]}\int_{\mathcal{X}}\ell_{L,h,\phi}(x,t)\,p(x,t)\,dx\,dt=\int_{[0,1]}\int_{\mathcal{X}}L\big(h(\phi(x),t),y\big)\,p(x,t)\,dx\,dt$$

Here, the outcome $y$ is a continuous variable, and we take $L$ to be the squared loss. By approximating $p(x,t)$ with the empirical distribution, we derive the following outcome prediction loss, which is minimized with respect to $\phi$ and $h$:

$$l_{pred}=\min_{\phi,h}\frac{1}{N}\sum_{i=1}^{N}\big(y_i-h(\phi(x_i),t_i)\big)^2$$

The encoder network $\phi$ is as defined in the previous section. The outcome prediction network $h$, however, needs to be specifically designed to maintain the impact of the treatment on the outcome. We aim to obtain an informative embedding of the treatment value and, similar to (Zhang et al. 2022), predict the outcome from this embedding and the representation using an attention-based network. (Zhang et al. 2022) proposed to learn the embedding with a neural network. Although neural networks are universal function approximators, it has been shown that they cannot extract an expressive embedding from a scalar value due to optimization difficulties (Gorishniy, Rubachev, and Babenko 2023). Instead, we construct the treatment embedding by applying a set of predefined spline functions to the treatment $t$, shown as $S(t)=[s_1(t),s_2(t),\dots,s_m(t)]$ in Figure 1. Spline functions can approximate a function in a piecewise manner (Eilers and Marx 1996).
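The text does not specify which spline family or knots are used, so the following is only a minimal sketch of one possible choice of $S(t)$: a clamped quadratic B-spline basis on $[0,1]$ with uniformly spaced interior knots, evaluated with the Cox-de Boor recursion. The function name `bspline_basis`, the number of basis functions, and the degree are hypothetical.

```python
import numpy as np

def bspline_basis(t, num_basis=8, degree=2):
    """Hypothetical treatment embedding S(t) = [s_1(t), ..., s_m(t)], m = num_basis.

    Clamped B-spline basis on [0, 1] with uniformly spaced interior knots,
    evaluated with the Cox-de Boor recursion. Returns shape (len(t), num_basis).
    """
    t = np.atleast_1d(np.asarray(t, dtype=float))
    n_interior = num_basis - degree - 1
    assert n_interior >= 0, "need num_basis > degree"
    knots = np.concatenate([np.zeros(degree + 1),
                            np.linspace(0.0, 1.0, n_interior + 2)[1:-1],
                            np.ones(degree + 1)])
    # Degree 0: indicator of the half-open knot span [knots[j], knots[j+1]) holding t.
    basis = np.zeros((len(t), len(knots) - 1))
    for j in range(len(knots) - 1):
        basis[:, j] = (knots[j] <= t) & (t < knots[j + 1])
    # t = 1 lies outside every half-open span; assign it to the last non-empty one.
    last = np.nonzero(knots[:-1] < knots[1:])[0][-1]
    basis[t >= 1.0, last] = 1.0
    # Raise the degree one step at a time (Cox-de Boor recursion).
    for d in range(1, degree + 1):
        new = np.zeros((len(t), len(knots) - 1 - d))
        for j in range(len(knots) - 1 - d):
            left = knots[j + d] - knots[j]
            right = knots[j + d + 1] - knots[j + 1]
            if left > 0:
                new[:, j] += (t - knots[j]) / left * basis[:, j]
            if right > 0:
                new[:, j] += (knots[j + d + 1] - t) / right * basis[:, j + 1]
        basis = new
    return basis

S = bspline_basis(np.linspace(0.0, 1.0, 11))  # 11 treatments -> 11 x 8 embedding
```

Each row of `S` is one treatment embedding $S(t)$; with a clamped basis, the $m$ values are non-negative and sum to one for every $t\in[0,1]$.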

The treatment embedding and the representation are then passed to the cross-attention layer (red module in Figure 1) to learn the dependency between treatment and representation. A cross-attention layer has three matrices, query $Q$, key $K$, and value $V$: $Q$ is computed from the treatment embedding using the parameters $h_q$, while $K$ and $V$ are computed from the representation using the parameters $h_k$ and $h_v$, respectively. The output of the cross-attention layer is $\sigma\big(QK^{T}/\sqrt{d_k}\big)V$, where the rows of $Q$, $K$, and $V$ are tokens, $d_k$ is the dimension of the query and key vectors, and $\sigma$ denotes the row-wise softmax function. We then predict the outcome $\hat{y}$ with a linear layer after the cross-attention module.
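The following sketch shows one way to wire a single cross-attention head between the spline embedding $S(t)$ and the representation $z$. The token layout is our assumption, not taken from the text: each spline value $s_j(t)$ becomes one query token and each entry of $z$ becomes one key/value token, $h_q$, $h_k$, and $h_v$ are linear maps shared across tokens, and the attended outputs are mean-pooled before the final linear layer. Under this layout the parameter count does not depend on the number of splines $m$.

```python
import numpy as np

def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

class CrossAttentionHead:
    """Hypothetical single-head cross-attention outcome head (sketch).

    Queries come from the spline embedding S(t), one token per spline value;
    keys and values come from the representation z, one token per entry.
    """

    def __init__(self, d_k=8, d_v=8, seed=0):
        rng = np.random.default_rng(seed)
        self.w_q = rng.normal(scale=0.5, size=(1, d_k))  # h_q: scalar -> d_k
        self.w_k = rng.normal(scale=0.5, size=(1, d_k))  # h_k: scalar -> d_k
        self.w_v = rng.normal(scale=0.5, size=(1, d_v))  # h_v: scalar -> d_v
        self.w_out = rng.normal(scale=0.5, size=d_v)     # final linear layer
        self.d_k = d_k

    def __call__(self, spline_emb, z):
        """spline_emb: (batch, m), z: (batch, d_z) -> (y_hat (batch,), attention)."""
        q = spline_emb[..., None] * self.w_q           # (batch, m, d_k)
        k = z[..., None] * self.w_k                    # (batch, d_z, d_k)
        v = z[..., None] * self.w_v                    # (batch, d_z, d_v)
        scores = q @ k.transpose(0, 2, 1) / np.sqrt(self.d_k)  # (batch, m, d_z)
        attn = softmax(scores, axis=-1)                # each row sums to 1
        context = attn @ v                             # (batch, m, d_v)
        y_hat = context.mean(axis=1) @ self.w_out      # pool over splines, then linear
        return y_hat, attn

head = CrossAttentionHead()
rng = np.random.default_rng(1)
z = rng.normal(size=(4, 10))
y6, attn6 = head(rng.uniform(size=(4, 6)), z)     # 6 splines
y12, attn12 = head(rng.uniform(size=(4, 12)), z)  # 12 splines, same parameters
```

The same head accepts embeddings with 6 or 12 splines without any new parameters, and each row of the attention matrix is a probability distribution over the entries of $z$.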

Unlike the ad hoc architectures of VCNet (Nie et al. 2021) and DRNet (Schwab et al. 2020), our outcome prediction network is flexible in terms of the number of spline functions: we can incorporate as many splines as needed without increasing the number of model parameters. This is particularly important in individual effect estimation, because each individual responds differently to a given treatment, and hence different spline functions might be necessary to approximate the treatment-response functions of different patients. The proposed architecture can incorporate a large number of spline functions, and the attention layer learns how relevant each spline is for estimating each patient's treatment-response function. It is also worth mentioning that by setting $h_q$ to the identity, $h_k$ to unity, and parameterizing $h_v$ with a neural network (which are sub-optimal choices), we recover VCNet and DRNet with a cross-attention layer.

Adversarial Counterfactual Regression

Algorithm 1 Adversarial CounterFactual Regression
1: Input: factual samples $\{(x_i,t_i,y_i)\}_{i=1}^{N}$, encoder network with initial parameters $\phi_0$, treatment-predictor network with initial parameters $\pi_0$, hypothesis network with initial parameters $h_0$, batch size $b$, number of iterations $T$, inner-loop size $M$, trade-off parameter $\gamma$, and step sizes $\eta_1$ and $\eta_2$.
2: for $t\leftarrow 0$ to $T-1$ do
3:     Sample a mini-batch $B=\{i_1,i_2,\dots,i_b\}$
4:     Encode into latent representation: $z_B=\phi_t(x_B)$
5:     Initialize $\omega_0=\pi_t$
6:     for $m\leftarrow 0$ to $M-1$ do
7:         Compute $l_{adv}$ and update $\omega_m$:
8:         $\hat{t}_B=\omega_m(z_B)$, $\;l_{adv}=\frac{1}{b}\sum_{i\in B}(t_i-\hat{t}_i)^2$
9:         $\omega_{m+1}=\omega_m-\eta_2\gamma\nabla_\omega l_{adv}(\omega_m)$
10:     end for
11:     $\pi_{t+1}=\omega_M$
12:     Compute $l_{pred}$ and update $\phi_t$ and $h_t$:
13:     $\hat{y}_B=h_t(z_B,t_B)$, $\;l_{pred}=\frac{1}{b}\sum_{i\in B}(y_i-\hat{y}_i)^2$
14:     $\phi_{t+1}=\phi_t-\eta_1\big(\nabla_\phi l_{pred}(\phi_t)-\gamma\nabla_\phi l_{adv}(\phi_t)\big)$
15:     $h_{t+1}=h_t-\eta_1\nabla_h l_{pred}(h_t)$
16: end for
17: Return $\phi_T$ and $h_T$

To minimize the distribution shift, we derived the adversarial loss $l_{adv}$ and introduced the encoder $\phi$ and the treatment-prediction network $\pi$ to optimize it. Similarly, to predict the factual outcome, we derived the outcome prediction loss $l_{pred}$ and introduced its associated attention-based prediction network $h$ to minimize it. We now train the entire network to optimize the following objective function, as Proposition 1 suggests:

$$\mathcal{L}_{ACFR}=\max_\pi\min_{\phi,h}\; l_{pred}-\gamma\,l_{adv}$$

where $\gamma$ is a tunable parameter. Algorithm 1 presents pseudo-code for training the ACFR network. At each iteration, a mini-batch of samples is given as input to the network (line 3). In the first stage, ACFR predicts the treatment using the encoder and treatment-prediction networks, computes $l_{adv}$, and updates only the parameters of $\pi$ for $M$ iterations (lines 5-11). In the second stage, ACFR predicts the outcome from the encoded representation of the batch and its treatments using the outcome prediction network $h$, computes $l_{pred}$, and updates $h$ with respect to $l_{pred}$ and $\phi$ with respect to $l_{pred}-\gamma\,l_{adv}$ (lines 12-15). Finally, the parameters of the encoder and the outcome prediction network are returned for the inference phase (line 17).
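To make the alternating structure of Algorithm 1 concrete, the sketch below runs it with linear stand-ins for all three networks, $\phi(x)=xW$, $\pi(z)=za+b$, and $h(z,t)=zc+et+f$, so that the gradients can be written by hand in NumPy. These parameterizations, the step sizes, and the toy data are hypothetical; ACFR itself uses neural networks and the attention-based $h$. Comments point to the corresponding lines of Algorithm 1.

```python
import numpy as np

def train_acfr_linear(x, t, y, k=3, iters=100, inner=5, batch=64,
                      gamma=1.0, eta1=0.005, eta2=0.05, seed=0):
    """Toy run of Algorithm 1 with hypothetical linear networks.

    phi(x) = x @ W,  pi(z) = z @ a + b,  h(z, t) = z @ c + e * t + f.
    Gradients are derived by hand for these linear forms only.
    """
    rng = np.random.default_rng(seed)
    n, d = x.shape
    W = rng.normal(scale=0.1, size=(d, k))     # encoder phi
    a, b = np.zeros(k), 0.0                    # treatment predictor pi
    c, e, f = np.zeros(k), 0.0, 0.0            # outcome predictor h
    for _ in range(iters):
        idx = rng.choice(n, size=batch, replace=False)      # line 3: mini-batch
        xb, tb, yb = x[idx], t[idx], y[idx]
        z = xb @ W                                          # line 4: encode
        for _ in range(inner):                              # lines 5-11: update pi only
            g_t = -2.0 * (tb - (z @ a + b)) / batch         # d l_adv / d t_hat
            a = a - eta2 * gamma * (z.T @ g_t)
            b = b - eta2 * gamma * float(g_t.sum())
        # Lines 12-15: both gradients at the current parameters, then updates.
        g_y = -2.0 * (yb - (z @ c + e * tb + f)) / batch    # d l_pred / d y_hat
        g_t = -2.0 * (tb - (z @ a + b)) / batch             # d l_adv / d t_hat
        grad_w_pred = xb.T @ np.outer(g_y, c)
        grad_w_adv = xb.T @ np.outer(g_t, a)
        W = W - eta1 * (grad_w_pred - gamma * grad_w_adv)   # descend l_pred, ascend l_adv
        c, e, f = (c - eta1 * (z.T @ g_y),
                   e - eta1 * float(tb @ g_y),
                   f - eta1 * float(g_y.sum()))
    return W, (a, b), (c, e, f)

# Hypothetical toy data: unit-norm covariates, treatment driven by x[:, 0].
rng = np.random.default_rng(0)
x = rng.normal(size=(500, 5))
x /= np.linalg.norm(x, axis=1, keepdims=True)
t = 1.0 / (1.0 + np.exp(-3.0 * x[:, 0]))
y = x[:, 1] + t
W, (a, b), (c, e, f) = train_acfr_linear(x, t, y)
```

The encoder update subtracts $\gamma\nabla_W l_{adv}$ with a minus sign inside the step, i.e., it performs gradient ascent on the adversarial loss, while $\pi$ performs descent on the same loss.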

Experiments

Method | News MISE | News PE | TCGA MISE | TCGA PE
GPS | 3.21 ± 0.34 | 0.39 ± 0.03 | 6.50 ± 1.21 | 2.30 ± 0.27
MLP | 2.91 ± 0.33 | 0.31 ± 0.02 | 4.81 ± 0.54 | 1.15 ± 0.23
DRNet-HSIC | 1.59 ± 0.20 | 0.21 ± 0.01 | 2.03 ± 0.27 | 1.24 ± 0.23
DRNet-Wass | 1.64 ± 0.21 | 0.21 ± 0.01 | 2.01 ± 0.20 | 1.29 ± 0.21
VCNet-HSIC | 1.28 ± 0.10 | 0.16 ± 0.01* | 1.99 ± 0.11 | 0.94 ± 0.14
VCNet-Wass | 1.43 ± 0.11 | 0.17 ± 0.01 | 1.76 ± 0.12 | 0.92 ± 0.14
ADMIT-HSIC | 1.25 ± 0.12 | 0.18 ± 0.01 | 1.81 ± 0.23 | 0.86 ± 0.15
ADMIT-Wass | 1.35 ± 0.20 | 0.18 ± 0.01 | 1.67 ± 0.23 | 0.81 ± 0.14
SCIGAN | 1.21 ± 0.15 | 0.20 ± 0.01 | 1.85 ± 0.14 | 0.97 ± 0.14
ACFR w/o attention | 1.58 ± 0.15 | 0.19 ± 0.01 | 1.86 ± 0.21 | 1.01 ± 0.15
ACFR | 1.12 ± 0.12* | 0.18 ± 0.01 | 1.60 ± 0.20* | 0.76 ± 0.12*
Table 1: Results on the News and TCGA datasets for the out-of-sample setting (mean ± std over 20 realizations; lower is better). * marks the best result in each column.
Dataset | #Samples | #Covariates | Outcome function
TCGA | 9659 | 4000 | $y=10\big(v_1^Tx+12\,v_2^Tx\,t-12\,v_3^Tx\,t^2\big)$
News | 5000 | 3477 | $y=10\Big(v_1^Tx+\sin\big(\frac{v_2^Tx}{v_3^Tx}\pi t\big)\Big)$
Treatment assignment (both datasets): $t\sim\text{Beta}(\alpha,\beta)$ with $\beta=\frac{2(\alpha-1)v_2^Tx}{v_3^Tx}+2-\alpha$
Table 2: Datasets and data generating functions.

Treatment effect estimation methods have to be evaluated on predicting potential outcomes, including counterfactuals, which are unavailable in real-world observational datasets. Therefore, synthetic or semi-synthetic datasets are commonly used, since their treatment assignment mechanism and outcome function are known, so counterfactual outcomes can be generated. Note that this does not change the fact that only factual outcomes are accessible during training. In this section, we present our experimental results. The code for synthetic data generation and the implementation of the methods can be found at https://github.com/amirrezakazemi/acfr

Setup

Semi-synthetic data generation

We used the TCGA (Network et al. 2013) and News (Johansson, Shalit, and Sontag 2016) semi-synthetic datasets. The TCGA dataset consists of gene expression measurements of the 4000 most variable genes for 9659 cancer patients. The News dataset, introduced as a benchmark in (Johansson, Shalit, and Sontag 2016), consists of 3477 word counts for 5000 randomly sampled news items from the New York Times corpus. For each dataset, we first normalized each covariate and then scaled every sample to have norm 1. We then split the datasets with a 68/12/20 ratio into training, validation, and test sets. We followed the treatment and outcome generating process of (Bica, Jordon, and van der Schaar 2020), summarized in Table 2. The parameter $\alpha$ in the treatment function determines the level of treatment-selection bias ($\alpha$ is set to 2 in all experiments unless otherwise stated), and $v_1$, $v_2$, and $v_3$ are vectors whose entries are sampled from the normal distribution $\mathcal{N}(0,1)$ and then normalized. Using the functions in Table 2, we assigned the treatment and factual outcome for all samples in the training and validation sets. All methods are then trained on the training set, and the validation set is used for hyperparameter selection. As in (Bica, Jordon, and van der Schaar 2020), potential outcomes for a unit are generated using the outcome function given the unit's covariates and 65 grid points in the range $[0,1]$ as an approximation of the treatment range.

Baselines

DRNet (Schwab et al. 2020) and VCNet (Nie et al. 2021) are state-of-the-art neural networks for continuous treatment effect estimation. Following (Bellot, Dhir, and Prando 2022), we use the following versions of these two methods as the main baselines: HSIC (Gretton et al. 2007) is a version that uses the Hilbert-Schmidt independence criterion to minimize the distribution shift, while Wass (Villani 2008) is a version that uses the Wasserstein distance for that purpose. We consider SCIGAN (Bica, Jordon, and van der Schaar 2020) as the state-of-the-art generative method for continuous treatments, and also compare against the ADMIT network (Wang et al. 2022b) with its proposed algorithm for estimating IPM distances. Finally, we include the Generalized Propensity Score (GPS) and an MLP network as baselines. The MLP network consists of two fully connected layers without any attempt to reduce the distribution shift.

Metrics

Given $\mu(x,t)$ as the ground-truth outcome of the unit with covariates $x$ under treatment $t$ and $f(x,t)$ as the predicted outcome, we report the performance of the methods in terms of the following two metrics defined in (Schwab et al. 2020). The Mean Integrated Squared Error (MISE) is the squared error of the predicted outcome averaged over all treatment values and all units. The Policy Error (PE) measures the average squared difference between the ground-truth outcome under the true optimal treatment and the ground-truth outcome under the estimated optimal treatment, where $t_i^*$ and $\hat{t}_i^*$ denote the ground-truth and predicted best treatments, respectively.

$$\text{MISE}=\frac{1}{N}\sum_{i=1}^{N}\int_{0}^{1}\big[\mu(x_{i},t)-f(x_{i},t)\big]^{2}\,dt$$
$$\text{PE}=\frac{1}{N}\sum_{i=1}^{N}\big[\mu(x_{i},t_{i}^{*})-\mu(x_{i},\hat{t}_{i}^{*})\big]^{2}$$
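The integral in MISE has to be approximated numerically. The sketch below assumes the 65-point treatment grid described for the datasets, the trapezoidal rule for the integral, and that the best treatment is the one maximizing the outcome, so $t_i^*$ and $\hat{t}_i^*$ are grid argmaxes of $\mu$ and $f$. The helper names `trapezoid` and `mise_and_pe` are hypothetical.

```python
def trapezoid(values, grid):
    """Trapezoidal rule for samples `values` taken at increasing points `grid`."""
    return sum((grid[k + 1] - grid[k]) * (values[k] + values[k + 1]) / 2
               for k in range(len(grid) - 1))

def mise_and_pe(mu, f, units, grid):
    """Grid approximations of MISE and PE.

    mu(x, t): ground-truth outcome; f(x, t): predicted outcome.
    The optimal treatment is taken as the grid point maximizing the outcome.
    """
    mise, pe = 0.0, 0.0
    for x in units:
        true_vals = [mu(x, t) for t in grid]
        pred_vals = [f(x, t) for t in grid]
        mise += trapezoid([(a - b) ** 2 for a, b in zip(true_vals, pred_vals)], grid)
        t_star = grid[max(range(len(grid)), key=lambda k: true_vals[k])]
        t_hat = grid[max(range(len(grid)), key=lambda k: pred_vals[k])]
        # PE compares ground-truth outcomes, evaluated at the true and predicted optima.
        pe += (mu(x, t_star) - mu(x, t_hat)) ** 2
    n = len(units)
    return mise / n, pe / n
```

For example, a prediction equal to the ground truth shifted by a constant 0.1 gives MISE = 0.01 and PE = 0, because a constant shift leaves every argmax unchanged; PE only penalizes errors that move the predicted optimum.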

Results

We performed two sets of experiments for potential outcome prediction, called out-of-sample prediction and within-sample prediction. The out-of-sample experiment evaluates how well the models predict the potential outcomes of units in the held-out test set, and the within-sample experiment evaluates the same for units in the training set.

Prediction error

For all methods, we report the mean and standard deviation of MISE and PE, in the format mean ± std, over 20 realizations of each dataset. Table 1 shows that on the TCGA dataset ACFR outperformed the contenders in both metrics, and on the News dataset ACFR achieved the best and second-best results in terms of MISE and PE, respectively. Comparing ACFR with ACFR w/o attention shows the substantial gain contributed by the cross-attention layer, demonstrating the effectiveness of the proposed outcome prediction network. In terms of cost, DRNet and VCNet have more parameters, whereas ACFR and ADMIT are more time-consuming because of the inner loop each uses to minimize the distribution shift. Implementation details and the within-sample results are given in Appendix B.

Treatment-selection bias robustness

We also investigate the robustness of four methods (ACFR, VCNet-HSIC, DRNet-HSIC, and ADMIT-HSIC) against varying levels of treatment-selection bias. As mentioned earlier, the $\alpha$ parameter of the Beta distribution in the treatment generating function controls the amount of bias. As $\alpha$ increases, the treatment-selection bias and the covariate shift of the observational dataset increase, and consequently we expect the errors of the methods to increase as well.
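The treatment function of Table 1 is not restated here, so the following is only an illustration of why a larger $\alpha$ means stronger selection bias. It assumes a mode-matching parameterization in the style of (Bica, Jordon, and van der Schaar 2020): for a covariate-dependent target treatment $m \in (0,1]$, setting $\beta = (\alpha-1)/m + 2 - \alpha$ gives $\mathrm{Beta}(\alpha,\beta)$ its mode at $m$. As $\alpha$ grows the variance shrinks, so the assigned treatment becomes more predictable from the covariates. The names `beta_params`, `beta_variance`, and `assign_treatment` are hypothetical.

```python
import random

def beta_params(alpha, mode):
    """Return (alpha, beta) so that Beta(alpha, beta) has its mode at `mode`.

    Assumes alpha > 1 and 0 < mode <= 1; beta = (alpha - 1) / mode + 2 - alpha.
    """
    beta = (alpha - 1.0) / mode + 2.0 - alpha
    return alpha, beta

def beta_variance(a, b):
    """Variance of a Beta(a, b) random variable."""
    return a * b / ((a + b) ** 2 * (a + b + 1))

def assign_treatment(mode, alpha, rng):
    """Sample a treatment in [0, 1] concentrated around a covariate-dependent mode."""
    a, b = beta_params(alpha, mode)
    return rng.betavariate(a, b)

# Usage: treatments for a unit whose (hypothetical) target treatment is 0.7.
rng = random.Random(0)
samples = [assign_treatment(0.7, 2.0, rng) for _ in range(5)]
```

With the mode fixed at 0.5, the variance drops from 0.05 at $\alpha = 2$ to about 0.019 at $\alpha = 6$, which is why the experiment above uses $\alpha$ as the bias knob.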

Figure 3: Robustness of ACFR against varying levels of treatment-selection bias, determined by the $\alpha$ parameter of the treatment assignment distribution. ACFR demonstrates robust performance in terms of MISE and PE compared to the baselines.

As shown in Figure 3, ACFR performs consistently across bias levels and, in the out-of-sample setting, outperforms the contenders in MISE by a notable margin at a strong bias level ($\alpha = 6$).

Related Work

Continuous Treatment Effect Estimation

Existing methods for continuous treatments can be categorized into those estimating the average effect and those estimating the individual effect. In the first category, (Hirano and Imbens 2004) proposed the generalized propensity score (GPS), which generalizes the notion of propensity score to continuous treatments. (Wu et al. 2021; Fong, Hazlett, and Imai 2018) proposed approaches to matching and covariate balancing, respectively, based on weights learned using the GPS. (Nie et al. 2021) proposed the Varying Coefficient network (VCNet), which extracts a representation sufficient for GPS prediction and predicts the outcome using a network in which the treatment value influences the outcome indirectly through the network parameters instead of being given directly as input. (Bahadori, Tchetgen, and Heckerman 2022) proposed an entropy balancing method that learns more stable weights than GPS-based weights. Our approach is fundamentally different, since the sufficiency of the (generalized) propensity score holds only for the average effect, whereas we aim to estimate the individual effect.

In the individual effect category, methods are mostly based on learning a balanced representation. DRNet (Schwab et al. 2020) discretized the treatment range into intervals, minimized the pairwise shift in the representation space between the populations that fall within these intervals, and proposed a hierarchical multi-head network to predict outcomes for the intervals. (Bellot, Dhir, and Prando 2022) demonstrated that the IPM between $P(Z,T)$ and $P(Z)P(T)$, together with the factual outcome error, yields an upper bound on the counterfactual error. Nonetheless, in practice they minimized the IPM sample-wise, since $P(Z)$ is unknown. Similarly, (Wang et al. 2022b) demonstrated that discretizing the treatment range and minimizing the maximum pairwise IPM bounds the counterfactual error. Our method is different in that we learn the balanced representation by minimizing the KL divergence parametrically.

Adversarial Balanced Representation

Learning a balanced (invariant) representation using an adversarial discriminator has been studied in the transfer learning literature to align source domain(s) with target domain(s) (Ganin et al. 2016; Tzeng et al. 2017; Wang, He, and Katabi 2020). Similarly, in causal inference, (Du et al. 2021; Berrevoets et al. 2020) balanced the distributions of two treatment groups adversarially, and (Bica et al. 2020) extended this approach to the setting of multiple time-varying treatments. However, existing methods consider only scenarios with a finite number of treatment options and do not provide theoretical guarantees of their generalization capability.

Conclusion

This paper has investigated the problem of continuous treatment effect estimation and introduced the ACFR (Adversarial Counter-Factual Regression) method for predicting potential outcomes. We proved a new bound on the counterfactual error using the KL divergence instead of an IPM distance, which has the benefit that the KL divergence can be estimated parametrically, resulting in a more reliable bound. Based on this error bound, ACFR uses an adversarial neural network architecture to minimize the KL divergence of the representations and a cross-attention network to minimize the factual prediction error. ACFR is not restricted to continuous treatments, and in future work we plan to extend and evaluate the ACFR framework for structured and time-series treatments. Nonetheless, we note that ACFR, like many treatment effect estimation methods, relies on the strong ignorability assumption, which does not necessarily hold in real-world applications.

Acknowledgement

We would like to thank Oliver Schulte and Sharan Vaswani for feedback on the paper. This research was supported by the Natural Sciences and Engineering Research Council of Canada (NSERC) Discovery Grant.

References

  • Alaa and van der Schaar (2017) Alaa, A. M.; and van der Schaar, M. 2017. Bayesian Inference of Individualized Treatment Effects using Multi-task Gaussian Processes. arXiv:1704.02801.
  • Bahadori, Tchetgen, and Heckerman (2022) Bahadori, M. T.; Tchetgen, E. T.; and Heckerman, D. E. 2022. End-to-End Balancing for Causal Continuous Treatment-Effect Estimation. arXiv:2107.13068.
  • Bareinboim and Pearl (2012) Bareinboim, E.; and Pearl, J. 2012. Controlling selection bias in causal inference. In Artificial Intelligence and Statistics, 100–108. PMLR.
  • Bareinboim, Tian, and Pearl (2014) Bareinboim, E.; Tian, J.; and Pearl, J. 2014. Recovering from Selection Bias in Causal and Statistical Inference. Proceedings of the AAAI Conference on Artificial Intelligence, 28(1).
  • Bellot, Dhir, and Prando (2022) Bellot, A.; Dhir, A.; and Prando, G. 2022. Generalization bounds and algorithms for estimating conditional average treatment effect of dosage. arXiv:2205.14692.
  • Berrevoets et al. (2020) Berrevoets, J.; Jordon, J.; Bica, I.; Gimson, A.; and van der Schaar, M. 2020. OrganITE: Optimal transplant donor organ offering using an individual treatment effect. In Larochelle, H.; Ranzato, M.; Hadsell, R.; Balcan, M.; and Lin, H., eds., Advances in Neural Information Processing Systems, volume 33, 20037–20050. Curran Associates, Inc.
  • Bica et al. (2020) Bica, I.; Alaa, A. M.; Jordon, J.; and van der Schaar, M. 2020. Estimating Counterfactual Treatment Outcomes over Time Through Adversarially Balanced Representations. arXiv:2002.04083.
  • Bica, Jordon, and van der Schaar (2020) Bica, I.; Jordon, J.; and van der Schaar, M. 2020. Estimating the Effects of Continuous-valued Interventions using Generative Adversarial Networks. arXiv:2002.12326.
  • Chu et al. (2023) Chu, Z.; Huang, J.; Li, R.; Chu, W.; and Li, S. 2023. Causal Effect Estimation: Recent Advances, Challenges, and Opportunities. arXiv:2302.00848.
  • Du et al. (2021) Du, X.; Sun, L.; Duivesteijn, W.; Nikolaev, A.; and Pechenizkiy, M. 2021. Adversarial balancing-based representation learning for causal effect inference with observational data. Data Mining and Knowledge Discovery, 35(4): 1713–1738.
  • Eilers and Marx (1996) Eilers, P. H.; and Marx, B. D. 1996. Flexible smoothing with B-splines and penalties. Statistical Science, 11(2): 89–121.
  • Farnia and Tse (2016) Farnia, F.; and Tse, D. 2016. A Minimax Approach to Supervised Learning.
  • Fong, Hazlett, and Imai (2018) Fong, C.; Hazlett, C.; and Imai, K. 2018. Covariate balancing propensity score for a continuous treatment: Application to the efficacy of political advertisements. The Annals of Applied Statistics, 12(1): 156–177.
  • Ganin et al. (2016) Ganin, Y.; Ustinova, E.; Ajakan, H.; Germain, P.; Larochelle, H.; Laviolette, F.; Marchand, M.; and Lempitsky, V. 2016. Domain-Adversarial Training of Neural Networks. arXiv:1505.07818.
  • Gorishniy, Rubachev, and Babenko (2023) Gorishniy, Y.; Rubachev, I.; and Babenko, A. 2023. On Embeddings for Numerical Features in Tabular Deep Learning. arXiv:2203.05556.
  • Gretton et al. (2007) Gretton, A.; Fukumizu, K.; Teo, C. H.; Song, L.; Schölkopf, B.; and Smola, A. 2007. A Kernel Statistical Test of Independence. In NIPS.
  • Hassanpour and Greiner (2019) Hassanpour, N.; and Greiner, R. 2019. Learning disentangled representations for counterfactual regression. In International Conference on Learning Representations.
  • Hill (2011) Hill, J. L. 2011. Bayesian Nonparametric Modeling for Causal Inference. Journal of Computational and Graphical Statistics, 20: 217 – 240.
  • Hirano and Imbens (2004) Hirano, K.; and Imbens, G. W. 2004. The propensity score with continuous treatments. Applied Bayesian modeling and causal inference from incomplete-data perspectives, 226164: 73–84.
  • Johansson, Shalit, and Sontag (2016) Johansson, F.; Shalit, U.; and Sontag, D. 2016. Learning Representations for Counterfactual Inference. In Balcan, M. F.; and Weinberger, K. Q., eds., Proceedings of The 33rd International Conference on Machine Learning, volume 48 of Proceedings of Machine Learning Research, 3020–3029. New York, New York, USA: PMLR.
  • Kent, Steyerberg, and van Klaveren (2018) Kent, D. M.; Steyerberg, E.; and van Klaveren, D. 2018. Personalized evidence based medicine: predictive approaches to heterogeneous treatment effects. BMJ, 363.
  • Kobrosly (2020) Kobrosly, R. W. 2020. causal-curve: A Python Causal Inference Package to Estimate Causal Dose-Response Curves. Journal of Open Source Software, 5(52): 2523.
  • Liang (2019) Liang, T. 2019. Estimating Certain Integral Probability Metric (IPM) is as Hard as Estimating under the IPM. arXiv:1911.00730.
  • Network et al. (2013) Network, C.; Weinstein, J.; Collisson, E.; Mills, G.; Shaw, K.; Ozenberger, B.; Ellrott, K.; Shmulevich, I.; Sander, C.; Stuart, J.; and Vandin, F. 2013. The Cancer Genome Atlas Pan-Cancer analysis project. Nature Genetics, 45(10): 1113–1120.
  • Nie et al. (2021) Nie, L.; Ye, M.; Liu, Q.; and Nicolae, D. 2021. VCNet and functional targeted regularization for learning causal effects of continuous treatments. arXiv preprint arXiv:2103.07861.
  • Prosperi et al. (2020) Prosperi, M.; Guo, Y.; Sperrin, M.; Koopman, J. S.; Min, J. S.; He, X.; Rich, S.; Wang, M.; Buchan, I. E.; and Bian, J. 2020. Causal inference and counterfactual prediction in machine learning for actionable healthcare. Nature Machine Intelligence, 2(7): 369–375.
  • Schwab et al. (2020) Schwab, P.; Linhardt, L.; Bauer, S.; Buhmann, J. M.; and Karlen, W. 2020. Learning Counterfactual Representations for Estimating Individual Dose-Response Curves. Proceedings of the AAAI Conference on Artificial Intelligence, 34(04): 5612–5619.
  • Shalit, Johansson, and Sontag (2017) Shalit, U.; Johansson, F. D.; and Sontag, D. 2017. Estimating individual treatment effect: generalization bounds and algorithms. In International conference on machine learning, 3076–3085. PMLR.
  • Sriperumbudur et al. (2009) Sriperumbudur, B. K.; Fukumizu, K.; Gretton, A.; Schölkopf, B.; and Lanckriet, G. R. G. 2009. On integral probability metrics, ϕ\phi-divergences and binary classification. arXiv:0901.2698.
  • Tzeng et al. (2017) Tzeng, E.; Hoffman, J.; Saenko, K.; and Darrell, T. 2017. Adversarial Discriminative Domain Adaptation. arXiv:1702.05464.
  • Varian (2016) Varian, H. R. 2016. Causal inference in economics and marketing. Proceedings of the National Academy of Sciences, 113(27): 7310–7315.
  • Villani (2008) Villani, C. 2008. Optimal Transport: Old and New. In Optimal Transport: Old and New.
  • Wang, He, and Katabi (2020) Wang, H.; He, H.; and Katabi, D. 2020. Continuously indexed domain adaptation. arXiv preprint arXiv:2007.01807.
  • Wang et al. (2022a) Wang, X.; Lyu, S.; Wu, X.; Wu, T.; and Chen, H. 2022a. Generalization Bounds for Estimating Causal Effects of Continuous Treatments. In Koyejo, S.; Mohamed, S.; Agarwal, A.; Belgrave, D.; Cho, K.; and Oh, A., eds., Advances in Neural Information Processing Systems, volume 35, 8605–8617. Curran Associates, Inc.
  • Wang et al. (2022b) Wang, X.; Lyu, S.; Wu, X.; Wu, T.; and Chen, H. 2022b. Generalization Bounds for Estimating Causal Effects of Continuous Treatments. In Oh, A. H.; Agarwal, A.; Belgrave, D.; and Cho, K., eds., Advances in Neural Information Processing Systems.
  • Wang et al. (2020) Wang, Y.; Liang, D.; Charlin, L.; and Blei, D. M. 2020. Causal Inference for Recommender Systems. In Proceedings of the 14th ACM Conference on Recommender Systems, RecSys ’20, 426–431. New York, NY, USA: Association for Computing Machinery. ISBN 9781450375832.
  • Wu et al. (2022) Wu, A.; Yuan, J.; Kuang, K.; Li, B.; Wu, R.; Zhu, Q.; Zhuang, Y.; and Wu, F. 2022. Learning decomposed representations for treatment effect estimation. IEEE Transactions on Knowledge and Data Engineering, 35(5): 4989–5001.
  • Wu et al. (2021) Wu, X.; Mealli, F.; Kioumourtzoglou, M.-A.; Dominici, F.; and Braun, D. 2021. Matching on Generalized Propensity Scores with Continuous Exposures. arXiv:1812.06575.
  • Yao et al. (2018) Yao, L.; Li, S.; Li, Y.; Huai, M.; Gao, J.; and Zhang, A. 2018. Representation Learning for Treatment Effect Estimation from Observational Data. In Neural Information Processing Systems.
  • Zhang, Liu, and Li (2021) Zhang, W.; Liu, L.; and Li, J. 2021. Treatment effect estimation with disentangled latent factors. arXiv:2001.10652.
  • Zhang, Bellot, and van der Schaar (2020) Zhang, Y.; Bellot, A.; and van der Schaar, M. 2020. Learning Overlapping Representations for the Estimation of Individualized Treatment Effects. arXiv:2001.04754.
  • Zhang et al. (2022) Zhang, Y.-F.; Zhang, H.; Lipton, Z. C.; Li, L. E.; and Xing, E. P. 2022. Exploring Transformer Backbones for Heterogeneous Treatment Effect Estimation. arXiv:2202.01336.

Appendix
 

Appendix A: Proofs


Proposition 1 (Counterfactual Generalization Bound). Given the one-to-one encoder function $\phi:\mathcal{X}\rightarrow\mathcal{Z}$, the outcome prediction function $h:\mathcal{Z}\times[0,1]\rightarrow\mathcal{Y}$, and the unit-loss function $\ell_{L,h,\phi}(x,t)$ that satisfies Assumption 4,

$$\varepsilon_{cf}^{\ell}\leq\varepsilon_{f}^{\ell}+C\sqrt{2D_{KL}\big(p_{\phi}(z,t)\,\|\,p_{\phi}(z)p(t)\big)}$$
Proof.

Let $\psi:\mathcal{Z}\rightarrow\mathcal{X}$ be the inverse of $\phi$. Following the proof technique of (Shalit, Johansson, and Sontag 2017), the result follows from the derivation below:

$$
\begin{aligned}
\varepsilon_{cf}^{\ell}-\varepsilon_{f}^{\ell}
&=\int_{[0,1]}\int_{\mathcal{X}}\ell_{L,h,\phi}(x,t)\,\big[p(x)p(t)-p(x,t)\big]\,dx\,dt && (1)\\
&=\int_{[0,1]}\int_{\mathcal{Z}}\ell_{L,h,\phi}(\psi(z),t)\,\big[p(\psi(z))p(t)-p(\psi(z),t)\big]\,\big|J_{\psi}(z)\big|\,dz\,dt && (2)\\
&=\int_{[0,1]}\int_{\mathcal{Z}}\ell_{L,h,\phi}(\psi(z),t)\,\big[p_{\phi}(z)p(t)-p_{\phi}(z,t)\big]\,dz\,dt && (3)\\
&\leq C\int_{[0,1]}\int_{\mathcal{Z}}\big|p_{\phi}(z)p(t)-p_{\phi}(z,t)\big|\,dz\,dt && (4)\\
&\leq C\sqrt{2\int_{[0,1]}\int_{\mathcal{Z}}p_{\phi}(z,t)\log\bigg(\frac{p_{\phi}(z,t)}{p_{\phi}(z)p(t)}\bigg)\,dz\,dt} && (5)\\
&=C\sqrt{2\,D_{KL}\big(p_{\phi}(z,t)\,\|\,p_{\phi}(z)p(t)\big)} && (6)
\end{aligned}
$$

where equality (3) holds by the reparameterization $x=\psi(z)$, inequality (4) holds by Assumption 4 constraining the function $\ell$, inequality (5) follows from Pinsker's inequality $\int|p-q| = 2\,TV(p,q) \leq \sqrt{2D_{KL}(p\,\|\,q)}$, which holds for either argument order of the KL divergence because the $\ell_1$ distance is symmetric, and (6) is the definition of the KL divergence. ∎
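Step (5) rests on Pinsker's inequality. The sketch below checks it numerically on a small discrete joint table standing in for $p_\phi(z,t)$; the table is an arbitrary illustration, not data from the experiments, and the helper names are hypothetical. The $\ell_1$ distance between the joint and the product of its marginals is at most $\sqrt{2D_{KL}}$ for both argument orders, and $D_{KL}(p(z,t)\,\|\,p(z)p(t))$ is the mutual information between $z$ and $t$.

```python
import math

def kl(p, q):
    """KL(p || q) for discrete distributions; requires q > 0 wherever p > 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def l1(p, q):
    """The l1 distance sum |p - q|, i.e. twice the total variation distance."""
    return sum(abs(pi - qi) for pi, qi in zip(p, q))

def joint_and_product(joint):
    """Flatten a joint table p(z, t) (rows: z, columns: t) and p(z) p(t) in the same order."""
    pz = [sum(row) for row in joint]
    pt = [sum(col) for col in zip(*joint)]
    flat_joint = [v for row in joint for v in row]
    flat_product = [a * b for a in pz for b in pt]
    return flat_joint, flat_product

joint = [[0.3, 0.1],
         [0.1, 0.5]]
pj, pp = joint_and_product(joint)
gap = l1(pj, pp)                                 # 0.56 for this table
bound_joint_first = math.sqrt(2 * kl(pj, pp))    # Pinsker with KL(p(z,t) || p(z)p(t))
bound_product_first = math.sqrt(2 * kl(pp, pj))  # Pinsker with KL(p(z)p(t) || p(z,t))
```

Here the gap of 0.56 sits below both bounds (about 0.596 and 0.635), and both bounds vanish exactly when $z$ and $t$ are independent, which is the balanced-representation target.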

In order to prove Proposition 2, we first prove the following lemma.
Lemma 1. Given the one-to-one encoder function $\phi:\mathcal{X}\rightarrow\mathcal{Z}$, the outcome prediction function $h:\mathcal{Z}\times[0,1]\rightarrow\mathcal{Y}$, and the unit-loss function $\ell_{L,h,\phi}(x,t)$ that satisfies Assumption 4, for any treatment $t$ in the valid range,

$$\varepsilon_{cf}^{\ell}(t)\leq\varepsilon_{f}^{\ell}(t)+C\sqrt{2D_{KL}\big(p_{\phi}(z)\,\|\,p_{\phi}(z|t)\big)}$$
Proof.

Let $\psi:\mathcal{Z}\rightarrow\mathcal{X}$ be the inverse of $\phi$.

$$
\begin{aligned}
\varepsilon_{cf}^{\ell}(t)-\varepsilon_{f}^{\ell}(t)
&=\int_{\mathcal{X}}\ell_{L,h,\phi}(x,t)\,\big[p(x)-p(x|t)\big]\,dx && (1)\\
&=\int_{\mathcal{Z}}\ell_{L,h,\phi}(\psi(z),t)\,\big[p_{\phi}(z)-p_{\phi}(z|t)\big]\,dz && (2)\\
&\leq C\int_{\mathcal{Z}}\big|p_{\phi}(z)-p_{\phi}(z|t)\big|\,dz && (3)\\
&\leq C\sqrt{2\int_{\mathcal{Z}}p_{\phi}(z)\log\bigg(\frac{p_{\phi}(z)}{p_{\phi}(z|t)}\bigg)\,dz} && (4)\\
&=C\sqrt{2\,D_{KL}\big(p_{\phi}(z)\,\|\,p_{\phi}(z|t)\big)} && (5)
\end{aligned}
$$

where (2) holds by the reparameterization $x=\psi(z)$, (3) by Assumption 4, and (4) by Pinsker's inequality. ∎

Note the difference between Lemma 1 and Proposition 1: in Lemma 1 the counterfactual error is restricted to a single treatment value, so it is bounded by the factual error at that specific treatment plus the shift between the marginal representation distribution $p_{\phi}(z)$ and the conditional distribution $p_{\phi}(z|t)$ at that treatment. Using this lemma, we can prove the following.

Proposition 2 (Precision of Estimating Heterogeneous Effect Bound). Given the same encoder function $\phi$ and outcome prediction function $h$ as in Proposition 1, and a unit-loss function $\ell_{L,h,\phi}(x,t)$ that satisfies Assumption 4 and whose associated $L$ is the squared error $\|\cdot\|^{2}$,

$$
\varepsilon_{pehe}(t_{1},t_{2})\leq 2\Big[\varepsilon_{f}^{\ell}(t_{1})+\varepsilon_{f}^{\ell}(t_{2})\Big]+2C\bigg[\sqrt{2\,D_{KL}\big(p_{\phi}(z)\,\|\,p_{\phi}(z|t_{1})\big)}+\sqrt{2\,D_{KL}\big(p_{\phi}(z)\,\|\,p_{\phi}(z|t_{2})\big)}\bigg]
$$

Proof.

$$
\begin{aligned}
\varepsilon_{pehe}(t_{1},t_{2})
&=\int_{\mathcal{X}}\Big[\big(\mu(x,t_{1})-\mu(x,t_{2})\big)-\big(h(\phi(x),t_{1})-h(\phi(x),t_{2})\big)\Big]^{2}p(x)\,dx && (1)\\
&\leq 2\int_{\mathcal{X}}\big(\mu(x,t_{1})-h(\phi(x),t_{1})\big)^{2}p(x)\,dx+2\int_{\mathcal{X}}\big(\mu(x,t_{2})-h(\phi(x),t_{2})\big)^{2}p(x)\,dx && (2)\\
&=2\,\varepsilon_{cf}^{\ell}(t_{1})+2\,\varepsilon_{cf}^{\ell}(t_{2}) && (3)\\
&\leq 2\Big[\varepsilon_{f}^{\ell}(t_{1})+\varepsilon_{f}^{\ell}(t_{2})\Big]+2C\bigg[\sqrt{2\,D_{KL}\big(p_{\phi}(z)\,\|\,p_{\phi}(z|t_{1})\big)}+\sqrt{2\,D_{KL}\big(p_{\phi}(z)\,\|\,p_{\phi}(z|t_{2})\big)}\bigg] && (4)
\end{aligned}
$$

where inequality (2) follows by writing the integrand as $(a-b)^{2}$ with $a=\mu(x,t_{1})-h(\phi(x),t_{1})$ and $b=\mu(x,t_{2})-h(\phi(x),t_{2})$ and applying $(a-b)^{2}\leq 2a^{2}+2b^{2}$, and the last two lines hold by the definition of the counterfactual error and Lemma 1, respectively. ∎
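The quantities in Proposition 2 can be estimated on a finite sample. The sketch below computes the empirical $\varepsilon_{pehe}(t_1,t_2)$ and the per-treatment counterfactual squared errors for toy functions $\mu$ and $h\circ\phi$; both functions and the name `empirical_errors` are hypothetical, chosen only to illustrate the pointwise inequality $(a-b)^2 \leq 2a^2 + 2b^2$ that relates these quantities.

```python
def empirical_errors(mu, pred, xs, t1, t2):
    """Empirical PEHE between t1 and t2 and counterfactual squared errors at each.

    mu(x, t) is the ground-truth outcome and pred(x, t) plays the role of
    h(phi(x), t); all averages use the same sample xs, standing in for p(x).
    """
    n = len(xs)
    pehe = sum(((mu(x, t1) - mu(x, t2)) - (pred(x, t1) - pred(x, t2))) ** 2
               for x in xs) / n
    cf1 = sum((mu(x, t1) - pred(x, t1)) ** 2 for x in xs) / n
    cf2 = sum((mu(x, t2) - pred(x, t2)) ** 2 for x in xs) / n
    return pehe, cf1, cf2

# Hypothetical toy functions: prediction errors are +0.5 at t = 0 and -0.5 at t = 1.
def mu(x, t):
    return x * t

def pred(x, t):
    return x * t + (t - 0.5)

xs = [i / 10 for i in range(11)]
pehe, cf1, cf2 = empirical_errors(mu, pred, xs, 0.0, 1.0)
```

Because the prediction errors at the two treatments have equal magnitude and opposite sign, the empirical PEHE (1.0) is exactly twice the sum of the two counterfactual errors (0.5), the equality case of $(a-b)^2 \leq 2a^2 + 2b^2$.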

Appendix B: Additional Experiments

We discuss the implementation and the set of parameters and hyperparameters for each method.

Adversarial Counterfactual Regression

For the representation and treatment-predictor networks, we use feedforward layers. To obtain the treatment embedding, we vary the degree $\in\{2,3,4\}$ and the knots $\in\{\{1/3,2/3\},\{1/4,2/4,3/4\},\{1/5,2/5,3/5,4/5\}\}$ used to construct the spline functions of the treatment. For the attention module in the outcome-prediction network, we use a feedforward layer of dimension 32 for News and 64 for TCGA. Finally, we vary the dimension of the last layer in $\{16,8\}$ for News and $\{32,16\}$ for TCGA.
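The text specifies only the degree and interior knots of the treatment splines. The sketch below evaluates a B-spline basis on $[0,1]$ with the Cox-de Boor recursion, assuming a clamped knot vector (0 and 1 repeated degree + 1 times); the clamping and the name `bspline_basis` are assumptions, and how ACFR maps these basis values into its treatment embedding is not shown.

```python
def bspline_basis(t, degree, interior_knots):
    """Evaluate all B-spline basis functions of `degree` at t in [0, 1].

    Uses a clamped knot vector and the Cox-de Boor recursion; returns
    len(interior_knots) + degree + 1 values that sum to 1 on [0, 1].
    """
    knots = [0.0] * (degree + 1) + list(interior_knots) + [1.0] * (degree + 1)
    # Degree 0: indicator of the knot span containing t (t = 1 goes to the last non-empty span).
    basis = []
    for i in range(len(knots) - 1):
        left, right = knots[i], knots[i + 1]
        inside = left <= t < right or (t == 1.0 and right == 1.0 and left < right)
        basis.append(1.0 if inside else 0.0)
    # Raise the degree one step at a time; zero-width spans contribute nothing.
    for p in range(1, degree + 1):
        nxt = []
        for i in range(len(knots) - p - 1):
            term = 0.0
            if knots[i + p] > knots[i]:
                term += (t - knots[i]) / (knots[i + p] - knots[i]) * basis[i]
            if knots[i + p + 1] > knots[i + 1]:
                term += ((knots[i + p + 1] - t)
                         / (knots[i + p + 1] - knots[i + 1]) * basis[i + 1])
            nxt.append(term)
        basis = nxt
    return basis
```

For example, degree 2 with knots $\{1/3, 2/3\}$ yields 5 basis functions, the same count as the 5 heads used in the VCNet configuration below; larger degrees or more knots give a richer, smoother treatment embedding at the cost of more parameters.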

We select the trade-off parameter of the objective function from $\gamma\in\{10^{-2},10^{-1},1,10\}$ and vary the number of inner loops for optimizing the treatment predictor, $M\in\{1,10,100\}$.

IPM Minimization Techniques

We use two different techniques to minimize the IPM distance.

a) As proposed in (Bellot, Dhir, and Prando 2022), for each sample in a batch, we consider the distance between the joint distribution of that sample, $p(z_{i},t_{i})$, and the distribution of all other samples in that batch, $p(z_{j},t_{j})$. This involves computing the IPM term once for each sample as follows:

$$\frac{1}{N}\sum_{i=1}^{N}\mathrm{IPM}\Big(\{z_{i},t_{i}\},\{z_{i},t_{j}\}_{j:j\neq i}\Big)$$
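A sketch of technique (a) follows. It implements the formula exactly as written above, pairing each factual $(z_i, t_i)$ with the pairs $(z_i, t_j)$, $j \neq i$. As a stand-in IPM it uses the biased squared maximum mean discrepancy (MMD) with a Gaussian kernel on the concatenated $(z, t)$; this kernel choice and the names `rbf`, `mmd2`, and `samplewise_ipm` are assumptions, since the experiments use the HSIC and Wasserstein variants.

```python
import math

def rbf(u, v, sigma=1.0):
    """Gaussian kernel between two equal-length tuples."""
    return math.exp(-sum((a - b) ** 2 for a, b in zip(u, v)) / (2 * sigma ** 2))

def mmd2(P, Q, sigma=1.0):
    """Biased squared MMD between two point sets (lists of equal-length tuples)."""
    kpp = sum(rbf(a, b, sigma) for a in P for b in P) / len(P) ** 2
    kqq = sum(rbf(a, b, sigma) for a in Q for b in Q) / len(Q) ** 2
    kpq = sum(rbf(a, b, sigma) for a in P for b in Q) / (len(P) * len(Q))
    return kpp + kqq - 2 * kpq

def samplewise_ipm(z, t, sigma=1.0):
    """(1/N) sum_i IPM({(z_i, t_i)}, {(z_i, t_j)}_{j != i}) with MMD^2 as the IPM."""
    n = len(z)
    total = 0.0
    for i in range(n):
        factual = [tuple(z[i]) + (t[i],)]
        others = [tuple(z[i]) + (t[j],) for j in range(n) if j != i]
        total += mmd2(factual, others, sigma)
    return total / n
```

The cost is quadratic in the batch size per sample, which is one reason this term is computed per batch rather than over the whole dataset.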

b) As proposed in (Wang et al. 2022b), we first divide the treatment range into $l$ equal intervals. For each interval $\Delta_{k}$, we compute the IPM distance between the distribution corresponding to that interval and the distributions of the other intervals, and we then minimize the maximum of these distances as follows:

$$\frac{1}{N}\sum_{i=1}^{N}\sum_{k=1}^{l}\max_{j\neq k}\mathrm{IPM}\big(\Delta_{k},\Delta_{j}\big)$$
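A sketch of technique (b) for a single batch follows. It bins treatments into $l$ equal-width intervals of $[0,1]$ and sums, over intervals, the largest IPM to any other interval. As a simplifying assumption it uses the IPM over unit-norm linear functions, which reduces to the Euclidean distance between interval means, instead of HSIC or Wasserstein; empty intervals are skipped, and the outer averaging over $i$ in the formula is left to the caller. The names `mean_embedding_distance` and `interval_max_ipm` are hypothetical.

```python
import math

def mean_embedding_distance(P, Q):
    """IPM over unit-norm linear functions: Euclidean distance between the means."""
    d = len(P[0])
    mp = [sum(p[k] for p in P) / len(P) for k in range(d)]
    mq = [sum(q[k] for q in Q) / len(Q) for k in range(d)]
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(mp, mq)))

def interval_max_ipm(z, t, n_intervals, ipm=mean_embedding_distance):
    """sum_k max_{j != k} IPM(Delta_k, Delta_j) over equal-width intervals of [0, 1]."""
    bins = [[] for _ in range(n_intervals)]
    for zi, ti in zip(z, t):
        k = min(int(ti * n_intervals), n_intervals - 1)  # t = 1 falls in the last interval
        bins[k].append(zi)
    nonempty = [b for b in bins if b]
    total = 0.0
    for k, bk in enumerate(nonempty):
        others = [ipm(bk, bj) for j, bj in enumerate(nonempty) if j != k]
        if others:
            total += max(others)
    return total
```

Taking the maximum over $j$ focuses the penalty on the worst-aligned pair of intervals, rather than averaging it away across well-aligned ones.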

We use two IPM families, the Hilbert-Schmidt Independence Criterion and the Wasserstein distance, with the implementations from (Bellot, Dhir, and Prando 2022) and (Shalit, Johansson, and Sontag 2017), respectively.

Multi-Layer Perceptron

We construct an MLP network using two feedforward layers, which takes the concatenated covariates and treatment as input and predicts the outcome. Its objective is the mean squared error between the ground truth and the predicted outcome.

Dose-Response Network + IPM

The original implementation of DRNet considered the treatment variable as a pair of a medication and a dosage, where the medication is categorical and the dosage is continuous. We adjusted the architecture and algorithm for a single continuous treatment. To minimize the distribution distance in the representation space, we minimize the IPM distance using the first technique described above. We also use 5 distinct regression heads for the samples in 5 equal intervals: $[0,0.2],[0.2,0.4],\dots,[0.8,1]$. Each regression head and the representation network consist of feedforward layers, and, as in the MLP architecture, the treatment value is given to each regression head as input. The weight of the IPM loss term is selected from $\{10^{-3},10^{-2},10^{-1},1\}$ for News and $\{10^{-2},10^{-1},1,10\}$ for TCGA.
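The routing of samples to the 5 regression heads can be sketched as follows. Because the listed intervals share their endpoints, this sketch assumes half-open intervals $[0,0.2), [0.2,0.4), \dots$ with $t = 1$ assigned to the last head. Each head is a hypothetical linear map over the representation concatenated with $t$, standing in for the feedforward heads; `head_index` and `drnet_predict` are illustrative names, not DRNet's code.

```python
def head_index(t, n_heads=5):
    """Index of the regression head whose interval contains t in [0, 1].

    Intervals are [0, 0.2), [0.2, 0.4), ..., [0.8, 1]; t = 1 goes to the last head.
    """
    if not 0.0 <= t <= 1.0:
        raise ValueError("treatment must lie in [0, 1]")
    return min(int(t * n_heads), n_heads - 1)

def drnet_predict(rep, t, heads):
    """Route a representation to its interval head; the head also receives t as input.

    `heads` is a hypothetical list of (weights, bias) linear heads over [rep..., t].
    """
    weights, bias = heads[head_index(t, len(heads))]
    features = list(rep) + [t]
    return sum(w * f for w, f in zip(weights, features)) + bias
```

Only one head is evaluated per sample, so each head is trained on the units whose treatments fall in its interval, while the shared representation network sees all units.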

Varying Coefficient Network + IPM

We adjusted the implementation of VCNet, which was originally proposed for average treatment effect estimation and learns the representation based on the propensity score. The adjusted VCNet has two sub-networks. The representation network consists of feedforward layers, and the outcome prediction network incorporates the treatment value into the network parameters. Specifically, we consider a set of spline functions with degree 2 and knots $[1/3,2/3]$ and use 5 heads, each associated with one spline function. The output of this dynamic network is a linear combination of the spline functions whose weights are the outputs of the regression heads. To minimize the distribution distance in the representation space, we minimize the IPM distance using the first technique explained above. The weight of the IPM loss term is selected from $\{10^{-2},10^{-1},1\}$ for News and $\{10^{-2},10^{-1},1,10\}$ for TCGA.

ADMIT + IPM

The IPM minimization procedure in ADMIT is based on the second technique. ADMIT has three sub-networks: a representation network, a re-weighting network, and a hypothesis network. For the representation network we use feedforward layers, and for the other two we use the dynamic network proposed by (Nie et al. 2021), as described above, to maintain the impact of the treatment. The weight of the IPM loss term is selected from $\{10^{-2},10^{-1},1\}$ for both datasets, and the number of intervals from $\{3,4,5\}$.

In the feedforward layers of the above implementations, we vary the number of nodes $\in\{50,100\}$ and the number of hidden layers $\in\{0,1\}$. For all methods, the step size is selected from $\{10^{-5},10^{-4},10^{-3}\}$ and the batch size from $\{32,64\}$. For the Generalized Propensity Score, we used the implementation in (Kobrosly 2020) for continuous outcomes, and we similarly adjusted the implementation of SCIGAN (Bica, Jordon, and van der Schaar 2020) for continuous treatments. To select the best set of hyperparameters, we used a Bayesian approach, specifically the Tree-Structured Parzen Estimator in the Optuna package.
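The hyperparameter search can be sketched with Optuna's `TPESampler` over the ACFR grid listed above. The objective `validation_mise` is a hypothetical stand-in for "train ACFR with these settings and return the validation MISE", and `tune` is an illustrative name. So that the sketch runs without Optuna installed, it falls back to plain random search over the same grid; that fallback is an addition, not part of the described setup.

```python
import random

try:
    import optuna  # provides optuna.samplers.TPESampler
except ImportError:  # fall back to plain random search when Optuna is unavailable
    optuna = None

# ACFR search space taken from the text above (step size and batch size are shared).
SEARCH_SPACE = {
    "gamma": [1e-2, 1e-1, 1.0, 10.0],
    "inner_loops": [1, 10, 100],
    "lr": [1e-5, 1e-4, 1e-3],
    "batch_size": [32, 64],
}

def validation_mise(params):
    """Hypothetical stand-in for training ACFR with `params` and returning validation MISE."""
    return (abs(params["gamma"] - 1.0) + 1e3 * abs(params["lr"] - 1e-4)
            + params["inner_loops"] / 1000 + 0.05 * (params["batch_size"] == 64))

def tune(n_trials=40, seed=0):
    """Return (best_params, best_value) using TPE if Optuna is installed."""
    if optuna is not None:
        optuna.logging.set_verbosity(optuna.logging.WARNING)

        def objective(trial):
            params = {name: trial.suggest_categorical(name, choices)
                      for name, choices in SEARCH_SPACE.items()}
            return validation_mise(params)

        study = optuna.create_study(direction="minimize",
                                    sampler=optuna.samplers.TPESampler(seed=seed))
        study.optimize(objective, n_trials=n_trials)
        return study.best_params, study.best_value
    rng = random.Random(seed)
    trials = [{name: rng.choice(choices) for name, choices in SEARCH_SPACE.items()}
              for _ in range(n_trials)]
    best = min(trials, key=validation_mise)
    return best, validation_mise(best)
```

TPE models which settings produced low validation error and proposes new trials from that region, which usually needs fewer full training runs than an exhaustive sweep of the grid.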

Additional Results

Analogous to the out-of-sample results, we present the prediction errors and selection-bias robustness results of all methods in the within-sample setting:

                    News                        TCGA
Method              MISE          PE            MISE          PE
GPS                 3.08 ± 0.33   0.36 ± 0.02   6.25 ± 0.97   1.95 ± 0.29
MLP                 2.79 ± 0.32   0.31 ± 0.02   4.72 ± 0.65   1.27 ± 0.23
DRNet-HSIC          1.32 ± 0.20   0.20 ± 0.01   1.91 ± 0.25   1.04 ± 0.19
DRNet-Wass          1.34 ± 0.21   0.19 ± 0.01   1.88 ± 0.20   1.04 ± 0.21
VCNet-HSIC          1.18 ± 0.11   0.16 ± 0.01   1.59 ± 0.12   0.87 ± 0.14
VCNet-Wass          1.23 ± 0.10   0.13 ± 0.01   1.39 ± 0.11*  0.82 ± 0.14
ADMIT-HSIC          1.12 ± 0.12   0.12 ± 0.01*  1.71 ± 0.23   0.76 ± 0.15
ADMIT-Wass          1.20 ± 0.10   0.13 ± 0.01   1.46 ± 0.21   0.79 ± 0.13
SCIGAN              1.15 ± 0.11   0.16 ± 0.01   1.58 ± 0.21   0.86 ± 0.10
ACFR w/o attention  1.34 ± 0.13   0.17 ± 0.01   1.66 ± 0.20   0.92 ± 0.13
ACFR                0.95 ± 0.12*  0.15 ± 0.01   1.42 ± 0.22   0.62 ± 0.11*
(* marks the best result in each column.)
Table 1: Results on News and TCGA datasets for the within-sample setting.
Figure 1: Robustness of ACFR against varying levels of selection bias in the within-sample setting, determined by the $\alpha$ parameter of the treatment assignment distribution. ACFR demonstrates robust performance in terms of MISE and PE compared to the contenders.