DAPDAG: Domain Adaptation via
Perturbed DAG Reconstruction

Yanke Li    Hatt Tobias    Ioana Bica    Mihaela van der Schaar
Abstract

Leveraging labelled data from multiple domains to enable prediction in another domain without labels is an important yet challenging problem. To address this problem, we introduce the framework DAPDAG (Domain Adaptation via Perturbed DAG Reconstruction) and propose to learn an auto-encoder that infers population statistics from features and reconstructs a directed acyclic graph (DAG) as an auxiliary task. The underlying DAG structure among observed variables is assumed invariant, while their conditional distributions are allowed to vary across domains, driven by a latent environmental variable $E$. The encoder is designed to serve as an inference device on $E$, while the decoder reconstructs each observed variable conditioned on its graphical parents in the DAG and the inferred $E$. We train the encoder and decoder jointly in an end-to-end manner and conduct experiments on both synthetic and real datasets with mixed types of variables. Empirical results demonstrate that reconstructing the DAG benefits the approximate inference. Furthermore, our approach achieves competitive performance against other benchmarks in prediction tasks, with better adaptation ability especially when the target domain differs significantly from the source domains.

Machine Learning, ICML

1 Introduction

Domain adaptation (DA) concerns the scenario where one wants to transfer a model learned from one or more labelled source datasets to a target dataset (which can be labelled or unlabelled) drawn from a different but related distribution. In many settings, a wealth of data may exist, contained in several datasets collected from different sources, such as different hospitals, yet the target domain has few labels available due to possible lag or infeasibility of data collection. Knowing what information can be transferred across domains and what needs to be adapted becomes the key to leveraging these datasets when presented with an unlabelled dataset from another domain, a problem that has attracted significant attention in the machine learning community. In this paper, we revisit the problem of unsupervised domain adaptation (UDA), where the target dataset is unlabelled and shares the same feature space with multiple source domains. To avoid repetition, the words “environment” and “domain” are used interchangeably in the following paragraphs.

For UDA, various approaches have been developed, and most of them fall into two main categories: either learning an invariant representation of features with implicit alignment over the source domains and the target domain, or utilising underlying causal assumptions and knowledge that provide clues on the source of distribution shift for better adaptation. Despite the success of invariant representation methods in visual UDA tasks (Wang & Deng, 2018; Deng et al., 2019; Kang et al., 2019; Lee et al., 2019; Liu et al., 2019; Jiang et al., 2020), their black-box nature remains opaque and causes issues in some situations (Zhao et al., 2019). Exploring the underlying causal structure and properties may help add more interpretability and make predictions across different domains more robust (robustness here refers to the generalisation ability of a model to unseen data). In most settings, the causal structure of the variables (both the features $\mathbf{X}$ and the label $Y$) is assumed to remain constant across domains and the label has a fixed conditional distribution given causal features (Schölkopf et al., 2012; Magliacane et al., 2017). In this work, we expect to capture similar invariant structural information but with conditional shift. More specifically, we cast the data generating process of distinct domains as a probability distribution with a continuous latent variable $E$ that perturbs the conditional distributions of observed variables. An auto-encoder approach is proposed to capture this latent $E$, utilising structural regularisation to facilitate sparsity and acyclicity among variable relationships. Our model is expected to be able to make inference on $E$, which is further used to adjust for domain shift in prediction. To accomplish this, the encoder structure is designed to approximate the posterior of $E$, drawing insights from methods of deep sets (Zaheer et al., 2017) and Bayesian inference (Maeda et al., 2020), while the decoder aims to reconstruct all observed variables in a DAG taking the inferred $E$ as input.

Contributions

The main contributions of this paper are three-fold:

  • We present a framework consisting of an encoder for approximate inference on the domain-specific variable $E$, and a decoder to reconstruct mixed-type data including continuous and binary variables. A novel training strategy is proposed to train our model: weighted stochastic domain selection, which enables inter- and intra-domain validation during training.

  • We provide a generalisation bound on the decoder in our structure with mixed-type data, validating the training loss form to some extent.

  • We validate our method with experiments on both simulated and real-world datasets, demonstrating the effectiveness of DAG reconstruction and performance gain of our approach in prediction tasks against benchmarks.

Related Work

Since we are more interested in causal methods for UDA, other UDA methods are not reviewed in this paper. For more detailed reviews of general DA methods, please refer to (Quiñonero-Candela et al., 2009) and (Pan & Yang, 2009). Various approaches have been proposed in causal UDA, yet most of them can be categorised into three classes: (1) Correcting distribution shift under different scenarios of UDA according to the underlying causal relations between $\mathbf{X}$ and $Y$ (e.g. estimating the target conditional distribution as a linear mixture of source domain conditionals by matching the target-domain feature distribution) (Schölkopf et al., 2012; Zhang et al., 2015; Stojanov et al., 2019); (2) Identifying an invariant subset of variables across domains for robust prediction (Magliacane et al., 2017; Rojas-Carulla et al., 2018); (3) Augmenting the causal graph by considering interventions or environmental changes as exogenous (context) variables which affect endogenous (system) variables, and implementing joint causal inference (JCI) on these augmented graphs (Mooij et al., 2020; Zhang et al., 2020). Our approach is closest to the third class, introducing a latent perturbation variable $E$ that induces conditional shift of observed variables. The resulting graph may no longer be causal; nevertheless, our focus is the DAG representation of the whole distribution, which enforces sparsity and acyclicity for better learning of $E$.

Our entire framework also bears resemblance to meta learning (for surveys, please see (Vilalta & Drissi, 2002; Vanschoren, 2018)). In our setting, the objective is to learn an algorithm from different training tasks (domains) and to apply the algorithm to a new task (domain). (Maeda et al., 2020) introduces an auto-encoder model to learn the latent embedding of different tasks under the Bayesian inference framework, which has a similar mechanism to ours except that our decoder aims to reconstruct all variables in a DAG instead of only the target variable. There also exist a few works using meta-learning approaches to handle variant causal structures across domains (Nair et al., 2019; Dasgupta et al., 2019; Ke et al., 2020; Löwe et al., 2020). Since our approach assumes an invariant DAG structure, we do not dive deeper into those methods, although they may provide useful references for our future work.

2 Preliminaries

Learning a causal DAG is a hard problem that requires exhaustive search over a super-exponential combinatorial DAG space, which becomes intractable in the high-dimensional case. However, recent advances in structure learning (Zheng et al., 2018; Yu et al., 2019; Zhang et al., 2019; Lachapelle et al., 2019; Yang et al., 2020; Zheng et al., 2020) reduce the original combinatorial optimisation problem to a continuous optimisation problem by using a novel acyclicity constraint, which accelerates learning and has inspired further work. Some works have been extended to more complex settings, including structure learning across non-stationary environments (Ghassami et al., 2018; Bengio et al., 2019; Ke et al., 2019). Despite differences in implementation, the above methods use end-to-end optimisation with standard off-the-shelf gradient-descent methods. In our work, we take advantage of continuous optimisation methods and focus on NO-TEARS methods (Zheng et al., 2018, 2020), which can be better integrated into the deep learning framework. We consider learning a DAG as an auxiliary task to improve the model’s generalisation and robustness (Kyono et al., 2020), while also contributing to better learning of the latent variable $E$.

An example is introduced below to recap the basic idea of the NO-TEARS method. Suppose we want to learn a linear SEM (Structural Equation Model) of the form $\mathbf{X}=\mathbf{X}\mathbf{B}+\mathbf{\epsilon}$, where $\mathbf{\epsilon}$ is the random noise variable and $\mathbf{B}\in\mathbb{R}^{d\times d}$ is the weighted adjacency matrix. Then it can be proved that:

$$\mathbf{B}\text{ is a DAG}\Leftrightarrow h(\mathbf{B})=Tr(e^{\mathbf{B}\odot\mathbf{B}})-d=0 \tag{1}$$

where $\odot$ is the Hadamard product, $[\mathbf{B}\odot\mathbf{B}]_{ij}=\mathbf{B}_{ij}^{2}$.

For a formal proof, please refer to (Zheng et al., 2018). This formulation converts learning a linear DAG into a non-convex optimisation problem:

$$\min_{\mathbf{B}}\quad\mathcal{L}(\mathbf{B})=\frac{1}{2n}||\mathbf{X}-\mathbf{X}\mathbf{B}||_{F}^{2}+\lambda|vec(\mathbf{B})|_{1}\quad\text{subject to}\quad h(\mathbf{B})=0 \tag{2}$$

Zheng et al. (2018) solve the above problem with the augmented Lagrangian method, which adds a quadratic penalty to the Lagrangian:

$$\min_{\mathbf{B}}\quad\mathcal{L}(\mathbf{B})+\frac{\rho}{2}|h(\mathbf{B})|^{2}+\alpha h(\mathbf{B}) \tag{3}$$

where $\rho$ is the penalty coefficient and $\alpha$ is the Lagrange multiplier. A further extension of this conversion to the case of general non-parametric DAGs has been proposed by (Zheng et al., 2020). Please refer to Appendix A for a detailed illustration.
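
The acyclicity function in (1) and the augmented objective in (3) can be sketched numerically as follows. This is a minimal illustration rather than the implementation used in the paper: the function names are hypothetical, and the matrix exponential is approximated by a truncated power series, which is an assumption that is adequate only for small, well-scaled matrices.

```python
import numpy as np


def matrix_exponential(a: np.ndarray, terms: int = 30) -> np.ndarray:
    """Truncated power series for exp(a) (an assumption made for brevity)."""
    result = np.eye(a.shape[0])
    term = np.eye(a.shape[0])
    for k in range(1, terms):
        term = term @ a / k          # a^k / k!
        result = result + term
    return result


def acyclicity(b: np.ndarray) -> float:
    """h(B) = Tr(exp(B o B)) - d, as in equation (1)."""
    d = b.shape[0]
    return float(np.trace(matrix_exponential(b * b)) - d)


def augmented_objective(x, b, lam, rho, alpha):
    """Objective (3): least squares + l1 penalty + augmented Lagrangian terms."""
    n = x.shape[0]
    fit = np.sum((x - x @ b) ** 2) / (2 * n)
    sparsity = lam * np.abs(b).sum()
    h = acyclicity(b)
    return fit + sparsity + 0.5 * rho * h ** 2 + alpha * h


# With X = XB + eps and samples as rows, B[i, j] != 0 encodes the edge i -> j.
# Chain 0 -> 1 -> 2: acyclic, so h should vanish.
dag = np.array([[0.0, 1.5, 0.0],
                [0.0, 0.0, -0.7],
                [0.0, 0.0, 0.0]])
# Adding the edge 2 -> 0 closes a cycle, so h should become positive.
cyclic = dag.copy()
cyclic[2, 0] = 0.9

h_dag = acyclicity(dag)
h_cyclic = acyclicity(cyclic)
```

In this sketch `h_dag` vanishes because every power of the squared chain matrix has a zero diagonal, while the cycle contributes a positive trace term to `h_cyclic`.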

3 Methodology

3.1 Formulation

Problem Setting

Let $Y$ be the target variable and $\mathbf{X}\in\mathbb{R}^{d}$ be the features. We consider $M$ labelled datasets from different source domains, i.e. $(\mathbf{X}_{i}^{m},Y_{i}^{m})_{i=1}^{n_{m}}\sim\mathbb{P}^{m}$, where $m\in\{1,2,\dots,M\}$ is the domain index, $\mathbb{P}^{m}$ is the probability distribution of $(\mathbf{X},Y)$ in domain $m$ and $n_{m}$ is the dataset size of domain $m$. Our objective is to predict $(Y_{i}^{\tau})_{i=1}^{n_{\tau}}$ given $(\mathbf{X}_{i}^{\tau})_{i=1}^{n_{\tau}}$ from the unlabelled target domain $\tau$.

Basic Assumptions

Let $\tilde{\mathbf{X}}=(\mathbf{X},Y)\in\mathbb{R}^{d+1}$ be the observed variables. We assume:

  • Besides $\tilde{\mathbf{X}}$, there is a latent environmental variable $E$ controlling the distribution shift of observed variables. For each domain, $E$ is sampled from its prior $\mathcal{N}(0,\sigma_{e}^{2})$ and fixed for data generation.

  • Observed data are generated according to a perturbed DAG: the conditional distribution of $\tilde{X}_{j}$ given its parents and $E$ follows an exponential family distribution of the form:

    $$p(\tilde{X}_{j}|\tilde{X}_{Pa(j)},E)=\exp\big(\eta(\tilde{X}_{Pa(j)},E)\cdot T(\tilde{X}_{j})+A(\tilde{X}_{Pa(j)},E)+B(\tilde{X}_{j})\big) \tag{4}$$

    where $\eta(\cdot)$, $A(\cdot)$ and $B(\cdot)$ are functions and $T(\cdot)$ is the sufficient statistic.
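
To make the form of (4) concrete, two standard cases can be written out. This is an illustrative worked example rather than part of the original derivation: $\eta$ abbreviates $\eta(\tilde{X}_{Pa(j)},E)$, and the unit variance in the Gaussian case is an assumption made only to keep the algebra short. Note that with the sign convention of (4), $A$ is the negative of the usual log-partition function.

```latex
% Binary variable with success probability 1/(1+e^{-\eta}):
p(\tilde{X}_j \mid \tilde{X}_{Pa(j)}, E)
  = \exp\big(\eta\,\tilde{X}_j - \log(1+e^{\eta})\big),
\qquad T(\tilde{X}_j)=\tilde{X}_j,\quad A=-\log(1+e^{\eta}),\quad B=0.

% Continuous variable, Gaussian with mean \eta and unit variance (assumed):
p(\tilde{X}_j \mid \tilde{X}_{Pa(j)}, E)
  = \exp\big(\eta\,\tilde{X}_j - \tfrac{1}{2}\eta^{2}
      - \tfrac{1}{2}\tilde{X}_j^{2} - \tfrac{1}{2}\log(2\pi)\big),
\qquad T(\tilde{X}_j)=\tilde{X}_j,\quad A=-\tfrac{1}{2}\eta^{2},\quad
B=-\tfrac{1}{2}\tilde{X}_j^{2}-\tfrac{1}{2}\log(2\pi).
```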

Perturbed DAG

We assume a perturbed DAG in which a joint environmental variable $E$ influences the conditional distributions of the observed variables across domains. This perturbed DAG is illustrated in Figure 1.

Figure 1: Perturbed DAG across different domains: for each domain, an environmental variable $E$ is generated and fixed for that domain, then all observed variables are sampled according to the DAG and $E$.

3.2 Model

We want a model that captures the differences in $E$ between domains and adapts to them accordingly. How to properly encode an empirical distribution into a statistic therefore becomes the cornerstone of our model. By analogy with classical statistical estimation methods such as Maximum Likelihood Estimation (MLE), our objective is to learn an estimation device that takes the samples of a domain as input and outputs the estimated $E$ for that domain. The model has an auto-encoder architecture, with an encoder that takes the whole domain sample to approximate $E$ and a decoder that reconstructs each feature according to its graphical parents and $E$. The latter bears resemblance to CASTLE (Kyono et al., 2020) except that $E$ is used for reconstruction. Figure 2 sketches the general model architecture, which consists of a domain encoder, a set of structural filters, shared hidden layers and separate output layers. We now explain each part in detail.

Figure 2: Overview of model structure: $\mathbf{W}_{1}^{k}$ is the structural filter of the $k$-th variable, a $(d+1)\times h$ matrix where $h$ is the number of hidden units in the hidden layer. To reconstruct the $k$-th variable, all entries in the $k$-th row of $\mathbf{W}_{1}^{k}$ are set to 0 to avoid using the variable itself for reconstruction.

3.2.1 Domain Encoder

We prefer an encoder that takes the features of a whole domain dataset and outputs an estimate of the environmental variable $E$ regardless of the order of the samples. Such permutation invariance is characterised by the following result from the theory of deep sets (Zaheer et al., 2017):

Theorem 3.1.

(Zaheer et al., 2017) A function $f(X)$ on a set $X$ having countable elements is a valid set function, i.e. invariant to the permutation of objects in $X$, if and only if it can be decomposed in the form $\rho(\sum_{x\in X}\phi(x))$ for suitable transformations $\rho$ and $\phi$.

The key to deep sets is to add up all representations and then apply nonlinear transformations. Further inspired by the approximate Bayesian posterior (Maeda et al., 2020) on the variable $E$, we design our encoder structure as shown in Figure 3, where:

$$V(\mathbf{X})=\Big(\sum_{i=1}^{n}\nu(x_{i})-(n-1)\nu_{0}\Big)^{-1} \tag{5}$$

$$\mu(\mathbf{X})=V(\mathbf{X})\Big(\sum_{i=1}^{n}\nu(x_{i})\phi(x_{i})\Big). \tag{6}$$

For point estimation of $E$, we directly let $\hat{E}=\mu(\mathbf{X})$. For approximate Bayesian inference on $E$, we sample $\hat{E}\sim\mathcal{N}(\mu(\mathbf{X}),V(\mathbf{X}))$. See Appendix B for more about the intuition behind the encoder structure design.
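
The pooling in (5) and (6) can be sketched in a few lines. The per-sample functions `nu` and `phi` below are hypothetical stand-ins for the learned networks, and `NU_0` is an assumed prior precision. Choosing `nu(x) >= NU_0` is one simple way (an assumption of this sketch, not a statement about the paper's architecture) to keep the pooled variance positive.

```python
import math

NU_0 = 1.0  # assumed value of nu_0 in equation (5)


def softplus(z: float) -> float:
    return math.log1p(math.exp(z))


def nu(x) -> float:
    """Hypothetical per-sample precision; always at least NU_0."""
    return NU_0 + softplus(sum(x))


def phi(x) -> float:
    """Hypothetical per-sample estimate of E."""
    return math.tanh(x[0] - x[1])


def encode(samples):
    """Return (mu, V) from equations (5) and (6) for one domain."""
    n = len(samples)
    precisions = [nu(x) for x in samples]
    variance = 1.0 / (sum(precisions) - (n - 1) * NU_0)                    # (5)
    mean = variance * sum(p * phi(x) for p, x in zip(precisions, samples))  # (6)
    return mean, variance


domain = [(0.2, -0.1), (1.0, 0.4), (-0.3, 0.8), (0.5, 0.5)]
mu, var = encode(domain)
mu_shuffled, var_shuffled = encode(list(reversed(domain)))
```

Because both statistics depend on the samples only through sums, reordering the domain changes the result only up to floating-point rounding, which is the permutation invariance required by Theorem 3.1.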

Figure 3: The Domain Encoder Structure.

3.2.2 Structural Filters

We directly use a weight matrix as each variable’s structural filter; more details are shown in Figure 2. The other hidden layers in the decoder architecture are shared by all variables and are discussed in the next sub-section.

3.2.3 Hidden and Output Layers

Shared Hidden Layers

The model is designed to have shared hidden layers for two purposes: (1) learning similar basis functions/representations among variables; (2) saving computational resources.

As mentioned in the assumptions, each variable follows an exponential family distribution conditioned on its parents (and $E$). Since distributions in the exponential family share a common form of probability density function, the shared hidden layers are expected to learn the similarity of basis representations among these variables, which are assumed to follow conditional distributions from the same family.

On the other hand, shared hidden layers can substantially reduce the computation needed to train the model. Normally, we would have separate hidden layers for each variable. However, this would introduce many more learnable parameters, which would decrease the model’s scalability in high-dimensional settings and could also aggravate over-fitting on small datasets.

Separate Output Layers

We have a separate output layer for each variable, which is of either continuous or binary type. For continuous variables, the output layer is simply a weight matrix without any activation function. For binary variables, the output layer is a weight matrix followed by a sigmoid activation function.
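
The decoder path for a single variable (masked structural filter, shared hidden layer, separate output layer) can be sketched as below. This is an illustration under stated assumptions, not the paper's implementation: the weights are random placeholders, only one shared hidden layer is used for brevity, the variable types in `is_binary` are hypothetical, and each filter is given one extra row for the inferred $E$ as described in the implementation details of Section 4.1.

```python
import numpy as np

rng = np.random.default_rng(0)
n_vars, h = 4, 16   # observed variables (features + label), hidden units

# One structural filter per variable; the extra last row takes the inferred E.
filters = rng.normal(scale=0.1, size=(n_vars, n_vars + 1, h))
for k in range(n_vars):
    filters[k, k, :] = 0.0   # variable k must not be used to reconstruct itself

shared_hidden = rng.normal(scale=0.1, size=(h, h))        # shared by all variables
output_weights = rng.normal(scale=0.1, size=(n_vars, h))  # one output layer each
is_binary = [False, False, True, False]                   # hypothetical types


def elu(z):
    # np.minimum avoids overflow in the branch that np.where discards.
    return np.where(z > 0, z, np.exp(np.minimum(z, 0.0)) - 1.0)


def reconstruct(k, x_tilde, e_hat):
    """Reconstruct variable k from the other observed variables and E."""
    inputs = np.append(x_tilde, e_hat)
    hidden = elu(inputs @ filters[k])      # structural filter of variable k
    hidden = elu(hidden @ shared_hidden)   # shared hidden layer
    out = hidden @ output_weights[k]       # separate output layer
    return 1.0 / (1.0 + np.exp(-out)) if is_binary[k] else out


x_tilde = np.array([0.5, -1.0, 1.0, 2.0])
label_hat = reconstruct(3, x_tilde, e_hat=0.3)   # last variable plays the role of Y
```

Zeroing row $k$ of the $k$-th filter means that changing the value of variable $k$ in the input leaves its own reconstruction unchanged.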

3.2.4 Loss Function

Let $g$, $\Theta_{1}$, $\Theta_{2}$ and $\Theta_{3}$ denote the parameters of the encoder, structural filters, shared hidden layers and output layers respectively ($\Theta=\Theta_{1}\cup\Theta_{2}\cup\Theta_{3}$). The model is trained by minimising the following loss function with respect to $g$ and $\Theta$ for each source domain index $m\in[M]$:

$$\mathcal{L}_{m}=\mathcal{L}_{N_{m}}(\mathbf{Y}^{m},f_{d+1}(g,\Theta))+\gamma\hat{E}_{m}^{2}+\lambda\mathcal{R}_{\mathcal{G}}(\tilde{\mathbf{X}}^{m},f_{g,\Theta}) \tag{7}$$

where for continuous variables:

$$\mathcal{L}_{N_{m}}(\mathbf{Y}^{m},f_{d+1}(g,\Theta))=\frac{1}{N_{m}}||\mathbf{Y}^{m}-f_{d+1}(\mathbf{X}^{m})||^{2} \tag{8}$$

and for binary variables:

$$\mathcal{L}_{N_{m}}(\mathbf{Y}^{m},f_{d+1}(g,\Theta))=-\frac{1}{N_{m}}\sum_{i=1}^{N_{m}}\big[\mathbf{Y}_{i}^{m}\log f_{d+1}(\mathbf{X}_{i}^{m})+(1-\mathbf{Y}_{i}^{m})\log(1-f_{d+1}(\mathbf{X}_{i}^{m}))\big]. \tag{9}$$

We also regularise the estimated $\hat{E}$, since a small $E$ is expected to give better generalisation of the decoder, as shown in Theorem 3.2. The DAG loss $\mathcal{R}_{\mathcal{G}}$ takes the form of:

$$\mathcal{R}_{\mathcal{G}}(\tilde{\mathbf{X}}^{m},f_{g,\Theta})=\mathcal{L}_{N_{m}}(f_{g,\Theta}(\tilde{\mathbf{X}}^{m}))+h(\Theta_{1})+\alpha h(\Theta_{1})^{2}+\beta l_{1}(\Theta_{1}) \tag{10}$$

where $\mathcal{L}_{N_{m}}(f_{g,\Theta}(\tilde{\mathbf{X}}^{m}))$ is the reconstruction loss for all variables, including the features and the label, in domain $m$. We use the mean squared loss (8) for continuous variables and the cross entropy loss (9) for binary variables. $h(\Theta_{1})=0$ is the acyclicity constraint of NO-TEARS (Zheng et al., 2020). $l_{1}(\Theta_{1})$ is the group lasso regularisation on the weight matrices in $\Theta_{1}$. $\alpha$, $\beta$ and $\gamma$ are the corresponding hyper-parameters.
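
The structural part of (10) can be sketched as follows. How an adjacency matrix is read off from the structural filters is an assumption here, made in the spirit of the non-parametric NO-TEARS formulation: entry $(j,k)$ is taken to be the $l_2$ norm of row $j$ of the $k$-th filter. For brevity the sketch ignores the extra row for $E$, and all function names are hypothetical.

```python
import numpy as np


def adjacency_from_filters(filters: np.ndarray) -> np.ndarray:
    """filters[k] has shape (d, h); A[j, k] is the L2 norm of row j of filters[k]."""
    return np.linalg.norm(filters, axis=2).T


def acyclicity(adjacency: np.ndarray, terms: int = 30) -> float:
    """h = Tr(exp(A o A)) - d, using a truncated power series (an assumption)."""
    d = adjacency.shape[0]
    squared = adjacency * adjacency
    power, trace = np.eye(d), float(d)
    for k in range(1, terms):
        power = power @ squared / k      # (A o A)^k / k!
        trace += float(np.trace(power))
    return trace - d


def structural_penalty(filters, alpha, beta):
    """h + alpha * h^2 + beta * group lasso: the last three terms of (10)."""
    adjacency = adjacency_from_filters(filters)
    h = acyclicity(adjacency)
    group_lasso = float(adjacency.sum())  # sum of row-wise L2 norms
    return h + alpha * h ** 2 + beta * group_lasso


d, hidden = 3, 2
filters = np.zeros((d, d, hidden))
filters[1, 0, :] = [3.0, 4.0]    # variable 0 feeds variable 1 (norm 5)
filters[2, 1, :] = [0.6, 0.8]    # variable 1 feeds variable 2 (norm 1)
penalty_dag = structural_penalty(filters, alpha=1.0, beta=0.1)

filters_cyclic = filters.copy()
filters_cyclic[0, 2, :] = [1.0, 0.0]   # variable 2 feeds variable 0: a cycle
penalty_cyclic = structural_penalty(filters_cyclic, alpha=1.0, beta=0.1)
```

For the acyclic filters only the group lasso term contributes, whereas closing the cycle makes the acyclicity terms dominate the penalty.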

Figure 4: Loss components and corresponding responsible model parts: the regularisation loss refers to the squared regularisation on $\hat{E}$ and the structural loss includes the DAG constraint and the $l_{1}$ loss on the structural filters.
Generalisation Bound of Decoder

We have derived a generalisation bound for the decoder $\Theta$ trained on i.i.d. data within the same domain, which validates the form of our loss function (7).

Theorem 3.2.

Let $f_{\Theta}:\tilde{\mathcal{X}}\rightarrow\tilde{\mathcal{X}}$ be an $L$-layer ReLU feed-forward neural network decoder with hidden layer size $h$. Then, under Assumptions C.1, C.2, C.3 and C.4 on the neural network norm and loss functions (refer to Appendix C.1 for more details), for all $\delta\in(0,1)$, with probability at least $1-\delta$ on a training domain with $N$ i.i.d. samples conditioned on a shared $E$, we have:

$$\mathcal{L}_{P}(f_{\Theta})\leq 4\mathcal{L}_{N}^{c}(f_{\Theta})+\mathcal{L}_{N}^{b}(f_{\Theta})+\frac{3}{N}\Big[\mathcal{R}_{\Theta_{1}}+C_{1}\cdot E^{2}+C_{2}\Big(\mathcal{V}(\Theta_{1})+\mathcal{V}(\Theta_{2})+\mathcal{V}(\Theta_{3})+\log\Big(\frac{8}{\delta}\Big)\Big)\Big]+C_{3} \tag{11}$$

where $C_{1}$, $C_{2}$ and $C_{3}$ are constants, $\mathcal{V}(\cdot)$ is the square of the $l_{2}$ norm of the corresponding parameters and $\mathcal{R}_{\Theta_{1}}$ is the DAG constraint on $\Theta_{1}$. For more details on the proof of the theorem, please refer to Appendix C.

3.2.5 Training Strategy

In this section, we introduce a novel algorithm for training our model with multiple domains. The flow chart of the training algorithm is depicted in Figure 5. For more details, please refer to Algorithm 1 in supplementary materials D.

Figure 5: Training epochs: we iterate stochastic domain selection and updating parameters within each training epoch for N times where N is the integer (training size/sampled batch size). For each epoch, we randomly select a domain according to the size weights of source domains, and sample a batch size that is no larger than the size of selected training domain. Then we update the encoder and decoder parameters respectively with alternating maximisation-maximisation procedures. We use different batch sizes for different training epochs because we want to ensure the encoder learned will be applicable to both large and small domain datasets. After N iterations of each epoch, we validate the updated model using validation sets from all source domains. If the validation scores on all validation sets are not improving for a predefined patience number, we cease the training and output the model.
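
The stochastic domain selection step described in Figure 5 can be sketched as follows. This is only an illustration: the function name and the domain sizes are hypothetical, and drawing the batch size uniformly is an assumption, since the text only requires it to be no larger than the selected domain.

```python
import random


def select_domain_and_batch(domain_sizes, max_batch, rng):
    """Pick a source domain with probability proportional to its size,
    then draw a batch of sample indices no larger than that domain."""
    domains = list(domain_sizes)
    weights = [domain_sizes[m] for m in domains]
    chosen = rng.choices(domains, weights=weights, k=1)[0]
    # Assumed rule: batch size uniform between 1 and the allowed maximum.
    batch_size = rng.randint(1, min(max_batch, domain_sizes[chosen]))
    indices = rng.sample(range(domain_sizes[chosen]), batch_size)
    return chosen, indices


rng = random.Random(0)
sizes = {"A": 500, "B": 200, "C": 50}   # hypothetical source domains
picks = [select_domain_and_batch(sizes, max_batch=128, rng=rng) for _ in range(1000)]
counts = {m: sum(1 for chosen, _ in picks if chosen == m) for m in sizes}
```

Larger domains are selected more often, while the varying batch size exposes the encoder to both small and large sample sets, in line with the motivation given in the caption.
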
Prediction in the Target Domain

To predict the target variable $Y^{\tau}$ in the target domain, we first feed the features of the whole unlabelled dataset into the encoder to get the predicted $\hat{E}_{\tau}$. Then, taking $\hat{E}_{\tau}$ and the features $\mathbf{X}^{\tau}$ as input, we go through the corresponding model components trained on the source domains in order: the structural filter of $Y$ (the last filter), the hidden layers and the output layer of $Y$ (the last output layer), to get the predicted $\hat{Y}^{\tau}$.

3.2.6 Bayesian Formulation

We can also put the whole framework into a Bayesian formulation. The log likelihood of the observed data $\tilde{\mathbf{X}}^{m}$ is

$$\log p(\tilde{\mathbf{X}}^{m})=-\log\frac{q(E|\mathbf{X}^{m})}{p(E)}+\log\frac{q(E|\mathbf{X}^{m})}{p(E|\tilde{\mathbf{X}}^{m})}+\log p(\tilde{\mathbf{X}}^{m}|E) \tag{12}$$

By taking the expectation on both sides of (12) with respect to a variational posterior $q(E|\mathbf{X}^{m})$ and dropping the resulting non-negative term $KL(q(E|\mathbf{X}^{m})||p(E|\tilde{\mathbf{X}}^{m}))$, the evidence lower bound (ELBO) of the marginal distribution of the observed data is derived as below:

$$\log p(\tilde{\mathbf{X}}^{m})\geq-KL(q(E|\mathbf{X}^{m})||p(E))+E_{q(E|\mathbf{X}^{m})}\Big[\sum_{i=1}^{n_{m}}\log p_{\Theta}(\mathbf{x}_{i}^{m},y_{i}^{m}|E)\Big] \tag{13}$$

where $KL(q(E|\mathbf{X}^{m})||p(E))=\frac{1}{2}[-1+\log(\sigma_{e}^{2})-\log(V(\mathbf{X}^{m}))+\frac{1}{\sigma_{e}^{2}}(\mu(\mathbf{X}^{m})^{2}+V(\mathbf{X}^{m}))]$ if we assume $q(E|\mathbf{X}^{m})=\mathcal{N}(\mu(\mathbf{X}^{m}),V(\mathbf{X}^{m}))$. Note that this KL term also contains a squared regularisation term on the estimated $\hat{E}$. We can then replace the prediction loss and reconstruction loss in (7) with the corresponding ELBO to train the Bayesian predictor.
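
The closed-form KL term can be checked numerically with a short sketch. The function names are hypothetical, and the numerical integral is a plain Riemann sum used only as a sanity check of the formula.

```python
import math


def kl_to_prior(mu: float, var: float, sigma_e2: float) -> float:
    """Closed form of KL( N(mu, var) || N(0, sigma_e2) ) as given after (13)."""
    return 0.5 * (-1.0 + math.log(sigma_e2) - math.log(var)
                  + (mu ** 2 + var) / sigma_e2)


def kl_numeric(mu, var, sigma_e2, steps=20001, width=10.0):
    """Riemann-sum approximation of the same KL divergence."""
    sd = math.sqrt(var)
    lo, hi = mu - width * sd, mu + width * sd
    de = (hi - lo) / (steps - 1)
    total = 0.0
    for i in range(steps):
        e = lo + i * de
        log_q = -0.5 * math.log(2 * math.pi * var) - (e - mu) ** 2 / (2 * var)
        log_p = -0.5 * math.log(2 * math.pi * sigma_e2) - e ** 2 / (2 * sigma_e2)
        total += math.exp(log_q) * (log_q - log_p) * de
    return total


closed = kl_to_prior(mu=1.0, var=0.5, sigma_e2=2.0)
approx = kl_numeric(mu=1.0, var=0.5, sigma_e2=2.0)
```

The term $\mu(\mathbf{X}^{m})^{2}/(2\sigma_{e}^{2})$ inside the closed form plays the same role as the penalty $\gamma\hat{E}_{m}^{2}$ in (7).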

Prediction

After obtaining the trained decoder $\Theta$ and the variational parameters $q_{g}$ (the encoder parameters), we perform prediction on the target domain $\tau$ by approximate inference via sampling:

$$P(y^{\tau}|\mathbf{x}^{\tau})=\int P(y^{\tau}|\mathbf{x}^{\tau},E)\,q(E|\mathbf{X}^{\tau})\,dE\approx\frac{1}{N}\sum_{i=1}^{N}P(y^{\tau}|\mathbf{x}^{\tau},E_{i}) \tag{14}$$

where $E_{i}\sim q(E|\mathbf{X}^{\tau})$.
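
The Monte Carlo average in (14) can be sketched as follows. The function `predict_given_e` is a hypothetical stand-in for the trained decoder's conditional prediction, with made-up coefficients; only the averaging over sampled values of $E$ reflects the formula above.

```python
import math
import random


def predict_given_e(x, e):
    """Hypothetical stand-in for the trained decoder's P(y = 1 | x, E)."""
    return 1.0 / (1.0 + math.exp(-(0.8 * x[0] - 0.5 * x[1] + e)))


def predict_marginal(x, mu, var, n_samples, rng):
    """Monte Carlo estimate of (14): average predictions over E_i ~ N(mu, var)."""
    sd = math.sqrt(var)
    draws = (rng.gauss(mu, sd) for _ in range(n_samples))
    return sum(predict_given_e(x, e) for e in draws) / n_samples


rng = random.Random(0)
x_target = (1.0, 2.0)
p_bayes = predict_marginal(x_target, mu=0.4, var=1.0, n_samples=500, rng=rng)
p_point = predict_marginal(x_target, mu=0.4, var=0.0, n_samples=50, rng=rng)
```

When the posterior variance is zero, every draw equals the posterior mean and the estimate reduces to the point prediction with $\hat{E}=\mu(\mathbf{X}^{\tau})$.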

4 Experiments

In this section, we empirically evaluate the performance of our method for UDA on synthetic and real-world datasets. To begin with, we briefly describe the experiment settings, including evaluation metrics, baselines and the benchmarks we compare with. In the second sub-section, we discuss experiments on two synthetic datasets which comply with our basic assumptions and demonstrate the performance improvement of DAPDAG (our method); please refer to Appendix E.5 for ablation studies on how each part of the model contributes to the performance gain. In the third sub-section, we introduce the real-world MAGGIC (Meta-Analysis Global Group in Chronic Heart Failure) datasets (Martínez-Sellés et al., 2012), which comprise 30 different studies of patients, and test our method on the processed datasets of selected studies against benchmarks.

4.1 Experiment Setups

Benchmarks

We benchmark DAPDAG against the plain MLP, CASTLE, MDAN (Multi-domain Adversarial Networks) (Zhao et al., 2018) and BRM (Meta Learning as Bayesian Risk Minimisation) (Maeda et al., 2020). We set MLP to be our baseline method and train it on merged data by directly combining all source domains. MDAN is representative of a class of well-founded DA methods (Pei et al., 2018; Sebag et al., 2019) that learn an invariant feature representation or implicit distribution alignment across domains. These methods use an adversarial objective to simultaneously minimise the training loss over labelled sources and the distance between the feature representations of each source domain and the target domain. Although this class of methods is usually applied in computer vision with high-dimensional image data, we adapt the structure and transfer the idea to our learning setting, where data are generated by a DAG with far fewer variables. BRM can also be regarded as an auto-encoder that makes inference on a latent variable perturbing the conditional distribution of $Y$, but without reconstructing a DAG as an auxiliary task.

Implementation and Training

All methods are implemented in PyTorch and run on GPU. The decoder architecture of DAPDAG is the same as that of CASTLE, except that DAPDAG has an extra domain encoder and an extra row in the structural filters for taking the inferred $E$. Moreover, the DAPDAG decoder has the same number of hidden layers and the same number of neurons in each hidden layer as the MLP, the BRM decoder and the feature extractor of MDAN. We fix the number of hidden layers to 2 and the number of hidden neurons to 16 for both synthetic and real datasets. For the encoder of DAPDAG and DoAMLP, we use a two-hidden-layer deep-set structure with the same number of neurons in each hidden layer as the decoder. The activation function used is ELU and each model is trained using the Adam optimiser (Kingma & Ba, 2014) with an early stopping regime. For data features with large scales in classification datasets, such as age and BMI (body-mass index), we standardise these variables to a mean of 0 and a variance of 1.

4.2 Synthetic Datasets

In this part, we present experiments on synthetic datasets. Please refer to E.3 in the supplementary materials for a more detailed description of the synthetic datasets.

Comparison with Benchmarks
Figure 6: Performance against benchmarks on synthetic datasets. DAPDAG-B denotes the DAPDAG under Bayesian formulation.

We compare DAPDAG with other benchmarks using varying numbers of training sources, with a size of 500 for each domain set. As the results in Figure 6 show, DAPDAG outperforms all other benchmarks on both classification and regression datasets. Although CASTLE cannot adjust for domain shift, it achieves better performance than MDAN, which is capable of domain adaptation. This supports the intuition that, in a causally perturbed system, forcing different distributions into a similar representation space may not help much compared to finding invariant causal features for prediction. However, these results only sketch a general performance gain of DAPDAG against other methods over multiple combinations of source and target domains. We also compare DAPDAG against benchmarks with respect to different target variables and average source-target difference. The results are shown in Figure 7. We observe that DAPDAG performs clearly better in scenarios where the target domain is significantly different from the sources and the target variable is not a sink node (a node with no descendants) in the underlying causal DAG.

Figure 7: Left sub-figure shows the DAG for generating the regression datasets (see more details in Appendix 2); Middle figure shows the average rank of each method’s performance on the regression datasets with respect to different target variables selected from nodes in the left DAG. We repeat experiments for each target variable selection over 30 combinations of 9 sources and 1 target domain; Right figure plots the rank of each method’s performance on the regression datasets with respect to the average distance between the target domain and source domains. Each averaged distance is the mean of absolute distance between the target domain and 9 sources.
Evaluation of DAG Learning

We have also included a few experiments evaluating the causal DAGs learned from synthetic regression datasets, as shown in Figure 8, where $d$ is the number of variables, $M$ is the number of training domains and SHD is the structural Hamming distance (used to measure the discrepancy between the learned graph and the truth; the lower the better). To generate synthetic datasets with different dimensions $d$, we randomly generate causal graphs and assign non-linear conditionals according to the causal order: for each variable, $X_{i}=\mathbf{W}_{i}^{2}\sigma(\mathbf{W}_{i}^{1}[Pa_{i}])+\epsilon_{i}$, where both $\mathbf{W}_{i}^{2}$ and $\mathbf{W}_{i}^{1}$ are randomly sampled weight matrices, $Pa_{i}$ represents the graphical parents of variable $i$, $\epsilon_{i}$ is the noise variable and $\sigma$ is the activation function. We compare our method with the baseline CD-NOD (Zhang et al., 2017) in the left plot of Figure 8 (for non-oriented edges, we use the ground-truth directions if possible). Due to an extra prediction loss on the target variable in addition to the reconstruction loss, the learned graphs usually deviate from the truth in terms of mis-specified and redundant edges. Yet this prediction loss in the Bayesian formulation becomes less important as the dimension increases, and the learned graph then approaches the ground truth, as shown in both the left and middle plots of Figure 8. Meanwhile, the right plot in Figure 8 demonstrates a highly positive relationship between the accuracy of graph learning and prediction performance.
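
For reference, the structural Hamming distance can be computed as in the sketch below. Conventions for SHD differ; this version counts a reversed edge once, which is an assumption and may not match the convention used for Figure 8. Graphs are given as sets of directed edges `(parent, child)` without self-loops.

```python
def structural_hamming_distance(true_edges, learned_edges):
    """Number of node pairs whose edge status differs (missing, extra or reversed).

    A reversed edge counts once here; other conventions count it twice.
    """
    true_set, learned_set = set(true_edges), set(learned_edges)
    pairs = {frozenset(e) for e in true_set | learned_set}
    distance = 0
    for pair in pairs:
        a, b = tuple(pair)
        true_state = ((a, b) in true_set, (b, a) in true_set)
        learned_state = ((a, b) in learned_set, (b, a) in learned_set)
        if true_state != learned_state:
            distance += 1
    return distance


truth = {(0, 1), (1, 2), (0, 2)}
learned = {(0, 1), (2, 1), (2, 3)}   # one reversed, one missing, one extra edge
shd = structural_hamming_distance(truth, learned)
```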

Figure 8: Evaluation on DAG Learning

Scalability Analysis: Please refer to part E.6 in the appendix for more details.

4.3 MAGGIC Datasets

In this section, we show experimental results on the MAGGIC datasets. Since DAPDAG-B (Bayesian formulation) performs better than DAPDAG on synthetic datasets, we only show the performance of DAPDAG-B in this part. We also add a benchmark, a data imputation method called MissForest (Stekhoven & Bühlmann, 2012), which imputes labels in the target domain as missing values. MAGGIC is a collection of 30 datasets from different medical studies containing patients who experienced heart failure. For the UDA task, we take the 12 variables shared by all studies and set the label to be the one-year survival indicator.

Figure 9: Performance of DAPDAG-B against benchmarks on MAGGIC datasets for each target study, using the remaining studies as source domains.
Performance

The experimental results on selected MAGGIC studies are shown in Figure 9. The results of our method are obtained using an environmental variable $\mathbf{E}$ of dimension 3, which is tuned as a hyper-parameter during model selection. We observe that DAPDAG-B outperforms the other benchmarks in APR scores on almost all of the selected studies. Although the improvement over the benchmarks is minor in a few studies such as "BATTL" and "Kirk", and the performance is even worse than MissForest in "Richa", DAPDAG-B exhibits a significant performance boost in other studies such as "Hilli", "Macin" and "NPC I", which are found to differ more from the remaining sources (please refer to E.8 in the supplementary materials).

5 Discussion

To sum up, we explore a novel auto-encoder structure that combines the estimation of population statistics using deep sets with the reconstruction of a DAG through a regularised decoder. We prove that, under certain assumptions, the loss function has components similar to terms in the generalisation bound of the decoder, which supports the form of the training loss. Experiments on synthetic and real datasets demonstrate the performance gain of our method over popular benchmarks in UDA tasks.

Better design of encoder.

Currently, the encoder needs to take the whole dataset from a domain as input, which greatly slows down training when the source dataset is large. At the same time, a source domain with a large sample set is preferred, since it helps capture the environmental variable $E$. Therefore, a better encoder should be designed to balance this trade-off in domain size.

Theoretical exploration on the encoder.

We have derived a generalisation bound for the decoder within the same domain, yet we have not looked into the properties of the encoder. We hope to investigate theoretical guarantees for the encoder's inference on $E$.

Extension to DA with Feature Mismatch.

Currently, we only focus on the task of UDA within the same feature space. In reality, it is highly likely that datasets will have different features available in each domain, as in the case of missing features across studies in the MAGGIC dataset. Although imputation can be a solution, it can fail if a large portion of the features do not overlap across domains. Therefore, it is important to develop approaches that can handle feature mismatch in the near future.

References

  • Bengio et al. (2019) Bengio, Y., Deleu, T., Rahaman, N., Ke, R., Lachapelle, S., Bilaniuk, O., Goyal, A., and Pal, C. A meta-transfer objective for learning to disentangle causal mechanisms. arXiv preprint arXiv:1901.10912, 2019.
  • Cuturi (2013) Cuturi, M. Sinkhorn distances: Lightspeed computation of optimal transport. Advances in neural information processing systems, 26:2292–2300, 2013.
  • Dasgupta et al. (2019) Dasgupta, I., Wang, J., Chiappa, S., Mitrovic, J., Ortega, P., Raposo, D., Hughes, E., Battaglia, P., Botvinick, M., and Kurth-Nelson, Z. Causal reasoning from meta-reinforcement learning. arXiv preprint arXiv:1901.08162, 2019.
  • Deng et al. (2019) Deng, Z., Luo, Y., and Zhu, J. Cluster alignment with a teacher for unsupervised domain adaptation. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pp.  9944–9953, 2019.
  • Germain et al. (2016) Germain, P., Bach, F., Lacoste, A., and Lacoste-Julien, S. Pac-bayesian theory meets bayesian inference. In Neural Information Processing Systems (NIPS 2016), pp. 1876–1884, 2016.
  • Ghassami et al. (2018) Ghassami, A., Kiyavash, N., Huang, B., and Zhang, K. Multi-domain causal structure learning in linear systems. Advances in neural information processing systems, 31:6266, 2018.
  • Jiang et al. (2020) Jiang, X., Lao, Q., Matwin, S., and Havaei, M. Implicit class-conditioned domain alignment for unsupervised domain adaptation. In International Conference on Machine Learning, pp. 4816–4827. PMLR, 2020.
  • Kaiser & Sipos (2021) Kaiser, M. and Sipos, M. Unsuitability of notears for causal graph discovery. arXiv preprint arXiv:2104.05441, 2021.
  • Kang et al. (2019) Kang, G., Jiang, L., Yang, Y., and Hauptmann, A. G. Contrastive adaptation network for unsupervised domain adaptation. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp.  4893–4902, 2019.
  • Ke et al. (2019) Ke, N. R., Bilaniuk, O., Goyal, A., Bauer, S., Larochelle, H., Schölkopf, B., Mozer, M. C., Pal, C., and Bengio, Y. Learning neural causal models from unknown interventions. arXiv preprint arXiv:1910.01075, 2019.
  • Ke et al. (2020) Ke, N. R., Wang, J., Mitrovic, J., Szummer, M., Rezende, D. J., et al. Amortized learning of neural causal representations. arXiv preprint arXiv:2008.09301, 2020.
  • Kingma & Ba (2014) Kingma, D. P. and Ba, J. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.
  • Kyono & van der Schaar (2019) Kyono, T. and van der Schaar, M. Improving model robustness using causal knowledge. arXiv preprint arXiv:1911.12441, 2019.
  • Kyono et al. (2020) Kyono, T., Zhang, Y., and van der Schaar, M. Castle: Regularization via auxiliary causal graph discovery. arXiv preprint arXiv:2009.13180, 2020.
  • Lachapelle et al. (2019) Lachapelle, S., Brouillard, P., Deleu, T., and Lacoste-Julien, S. Gradient-based neural dag learning. arXiv preprint arXiv:1906.02226, 2019.
  • Lee et al. (2019) Lee, S., Kim, D., Kim, N., and Jeong, S.-G. Drop to adapt: Learning discriminative features for unsupervised domain adaptation. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pp.  91–100, 2019.
  • Liu et al. (2019) Liu, H., Long, M., Wang, J., and Jordan, M. Transferable adversarial training: A general approach to adapting deep classifiers. In International Conference on Machine Learning, pp. 4013–4022. PMLR, 2019.
  • Löwe et al. (2020) Löwe, S., Madras, D., Zemel, R., and Welling, M. Amortized causal discovery: Learning to infer causal graphs from time-series data. arXiv preprint arXiv:2006.10833, 2020.
  • Maeda et al. (2020) Maeda, S.-i., Nakanishi, T., and Koyama, M. Meta learning as bayes risk minimization. arXiv preprint arXiv:2006.01488, 2020.
  • Magliacane et al. (2017) Magliacane, S., van Ommen, T., Claassen, T., Bongers, S., Versteeg, P., and Mooij, J. M. Domain adaptation by using causal inference to predict invariant conditional distributions. Advances in neural information processing systems, 2017.
  • Martínez-Sellés et al. (2012) Martínez-Sellés, M., Doughty, R. N., Poppe, K., Whalley, G. A., Earle, N., Tribouilloy, C., McMurray, J. J., Swedberg, K., Køber, L., Berry, C., et al. Gender and survival in patients with heart failure: interactions with diabetes and aetiology. results from the maggic individual patient meta-analysis. European journal of heart failure, 14(5):473–479, 2012.
  • Mooij et al. (2020) Mooij, J. M., Magliacane, S., and Claassen, T. Joint causal inference from multiple contexts. 2020.
  • Nair et al. (2019) Nair, S., Zhu, Y., Savarese, S., and Fei-Fei, L. Causal induction from visual observations for goal directed tasks. arXiv preprint arXiv:1910.01751, 2019.
  • Pan & Yang (2009) Pan, S. J. and Yang, Q. A survey on transfer learning. IEEE Transactions on knowledge and data engineering, 22(10):1345–1359, 2009.
  • Panaretos & Zemel (2019) Panaretos, V. M. and Zemel, Y. Statistical aspects of wasserstein distances. Annual review of statistics and its application, 6:405–431, 2019.
  • Pei et al. (2018) Pei, Z., Cao, Z., Long, M., and Wang, J. Multi-adversarial domain adaptation. In Thirty-second AAAI conference on artificial intelligence, 2018.
  • Quiñonero-Candela et al. (2009) Quiñonero-Candela, J., Sugiyama, M., Lawrence, N. D., and Schwaighofer, A. Dataset shift in machine learning. Mit Press, 2009.
  • Rojas-Carulla et al. (2018) Rojas-Carulla, M., Schölkopf, B., Turner, R., and Peters, J. Invariant models for causal transfer learning. The Journal of Machine Learning Research, 19(1):1309–1342, 2018.
  • Schölkopf et al. (2012) Schölkopf, B., Janzing, D., Peters, J., Sgouritsa, E., Zhang, K., and Mooij, J. On causal and anticausal learning. arXiv preprint arXiv:1206.6471, 2012.
  • Sebag et al. (2019) Sebag, A. S., Heinrich, L., Schoenauer, M., Sebag, M., Wu, L., and Altschuler, S. Multi-domain adversarial learning. In ICLR 2019-Seventh annual International Conference on Learning Representations, 2019.
  • Stekhoven & Bühlmann (2012) Stekhoven, D. J. and Bühlmann, P. Missforest—non-parametric missing value imputation for mixed-type data. Bioinformatics, 28(1):112–118, 2012.
  • Stojanov et al. (2019) Stojanov, P., Gong, M., Carbonell, J., and Zhang, K. Data-driven approach to multiple-source domain adaptation. In The 22nd International Conference on Artificial Intelligence and Statistics, pp.  3487–3496. PMLR, 2019.
  • Vanschoren (2018) Vanschoren, J. Meta-learning: A survey. arXiv preprint arXiv:1810.03548, 2018.
  • Vilalta & Drissi (2002) Vilalta, R. and Drissi, Y. A perspective view and survey of meta-learning. Artificial intelligence review, 18(2):77–95, 2002.
  • Wang & Deng (2018) Wang, M. and Deng, W. Deep visual domain adaptation: A survey. Neurocomputing, 312:135–153, 2018.
  • Yang et al. (2020) Yang, M., Liu, F., Chen, Z., Shen, X., Hao, J., and Wang, J. Causalvae: Structured causal disentanglement in variational autoencoder. arXiv preprint arXiv:2004.08697, 2020.
  • Yu et al. (2019) Yu, Y., Chen, J., Gao, T., and Yu, M. Dag-gnn: Dag structure learning with graph neural networks. In International Conference on Machine Learning, pp. 7154–7163. PMLR, 2019.
  • Zaheer et al. (2017) Zaheer, M., Kottur, S., Ravanbakhsh, S., Poczos, B., Salakhutdinov, R. R., and Smola, A. J. Deep sets. Advances in Neural Information Processing Systems, 30, 2017.
  • Zhang et al. (2015) Zhang, K., Gong, M., and Schölkopf, B. Multi-source domain adaptation: A causal view. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 29, 2015.
  • Zhang et al. (2017) Zhang, K., Huang, B., Zhang, J., Glymour, C., and Schölkopf, B. Causal discovery from nonstationary/heterogeneous data: Skeleton estimation and orientation determination. In IJCAI: Proceedings of the Conference, volume 2017, pp. 1347. NIH Public Access, 2017.
  • Zhang et al. (2020) Zhang, K., Gong, M., Stojanov, P., Huang, B., Liu, Q., and Glymour, C. Domain adaptation as a problem of inference on graphical models. Advances in neural information processing systems, 2020.
  • Zhang et al. (2019) Zhang, M., Jiang, S., Cui, Z., Garnett, R., and Chen, Y. D-vae: A variational autoencoder for directed acyclic graphs. arXiv preprint arXiv:1904.11088, 2019.
  • Zhao et al. (2018) Zhao, H., Zhang, S., Wu, G., Gordon, G. J., et al. Multiple source domain adaptation with adversarial learning. 2018.
  • Zhao et al. (2019) Zhao, H., Des Combes, R. T., Zhang, K., and Gordon, G. On learning invariant representations for domain adaptation. In International Conference on Machine Learning, pp. 7523–7532. PMLR, 2019.
  • Zheng et al. (2018) Zheng, X., Aragam, B., Ravikumar, P., and Xing, E. P. Dags with no tears: Continuous optimization for structure learning. Advances in neural information processing systems, 2018.
  • Zheng et al. (2020) Zheng, X., Dan, C., Aragam, B., Ravikumar, P., and Xing, E. Learning sparse nonparametric dags. In International Conference on Artificial Intelligence and Statistics, pp.  3414–3425. PMLR, 2020.

Appendix A NOTEARS for Learning Non-linear SEM

How can we construct a proxy of $\mathbf{B}$ for a general non-linear SEM? Suppose that in graph $\mathcal{G}$ there exists a function $f_{i}:\mathbb{R}^{d}\rightarrow\mathbb{R}$ for the $i$-th variable $X_{i}$ such that

$$\mathbb{E}[X_{i}|X_{Pa(i)}]=f_{i}(\mathbf{X}) \tag{15}$$

where if $X_{j}\not\in Pa(i)$ then $f_{i}(x_{1},...,x_{d+1})$ does not depend on $x_{j}$, so that the function $a(u):=f_{i}(x_{1},...,x_{j-1},u,x_{j+1},...,x_{d+1})$ is constant for all $u\in\mathbb{R}$. Zheng et al. (2020) use the partial derivatives $\frac{\partial f_{i}}{\partial x_{j}}$ to measure the dependence of $X_{i}$ on $X_{j}$. Denote $\partial_{j}f_{i}=\frac{\partial f_{i}}{\partial x_{j}}$; then it can be shown that

$$f_{i}\perp\!\!\!\perp X_{j}\Leftrightarrow||\partial_{j}f_{i}||_{L^{2}}=0 \tag{16}$$

where $||.||_{L^{2}}$ is the $L^{2}$-norm. Denote the matrix $\mathbf{A}(f)\in\mathbb{R}^{d\times d}$ with entries $[\mathbf{A}(f)]_{ij}:=||\partial_{j}f_{i}||_{L^{2}}$. Then $\mathbf{A}(f)$ becomes a non-linear surrogate of the adjacency matrix $\mathbf{B}$ in linear models. Now consider using an MLP to approximate $f_{i}$. Suppose the MLP has $h$ hidden layers and a single activation $\sigma:\mathbb{R}\rightarrow\mathbb{R}$:

$$\hat{f}_{i}(\mathbf{u})=\sigma(\sigma(...\sigma(\mathbf{u}\mathbf{W}_{i}^{(1)})\mathbf{W}_{i}^{(2)})\mathbf{W}_{i}^{(h)}), \tag{17}$$

where $\mathbf{u}\in\mathbb{R}^{d}$, $\mathbf{W}_{i}^{(l)}\in\mathbb{R}^{n_{l-1}\times n_{l}}$ and $n_{0}=d$. It is shown in Zheng et al. (2020) that if $[\mathbf{W}_{i}^{(1)}]_{bk}=0$ for all $k=1,...,n_{1}$, then $\hat{f}_{i}(\mathbf{u})$ is independent of the $b$-th input $u_{b}$. Let $\theta=(\theta_{1},...,\theta_{d})$ with $\theta_{i}=(\mathbf{W}_{i}^{(1)},...,\mathbf{W}_{i}^{(h)})$ and define $[A(\theta)]_{ij}$ as the norm of the $j$-th row of $\mathbf{W}_{i}^{(1)}$. Then it suffices to solve DAG learning by tackling the problem below (Zheng et al., 2020):

$$\begin{aligned}
\min\limits_{\theta}\quad & \frac{1}{n}\sum_{i=1}^{d}l(x_{i},\hat{f}_{i}(\mathbf{X},\theta_{i}))+\lambda||\mathbf{W}_{i}^{(1)}||_{1,1} \\
\text{subject to}\quad & h(A(\theta))=0
\end{aligned} \tag{18}$$
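
The surrogate adjacency matrix and the acyclicity constraint can be sketched in a few lines. This is an illustration under stated assumptions rather than the implementation used in the paper: $h$ is not restated in this appendix, so the code assumes the trace-exponential form $h(A)=\mathrm{tr}(e^{A\circ A})-d$ from NOTEARS (Zheng et al., 2018), evaluated with a truncated power series; the row norm is taken to be the $L^{2}$ norm; and the names `adjacency_from_first_layers` and `acyclicity` are hypothetical.

```python
import math


def adjacency_from_first_layers(first_layers):
    """first_layers[i] is W_i^(1), stored as d rows of n_1 weights.

    Entry [i][j] is the L2 norm of row j of W_i^(1), so a non-zero value
    means that the network for X_i actually reads input X_j.
    """
    return [[math.sqrt(sum(w * w for w in row)) for row in w1]
            for w1 in first_layers]


def matmul(a, b):
    d = len(a)
    return [[sum(a[i][k] * b[k][j] for k in range(d)) for j in range(d)]
            for i in range(d)]


def acyclicity(adj, terms=30):
    """Assumed NOTEARS penalty h(A) = tr(exp(A o A)) - d (truncated series)."""
    d = len(adj)
    m = [[v * v for v in row] for row in adj]  # Hadamard square A o A
    power = [[float(i == j) for j in range(d)] for i in range(d)]
    total = 0.0
    factorial = 1.0
    for k in range(1, terms + 1):
        power = matmul(power, m)  # m^k
        factorial *= k
        total += sum(power[i][i] for i in range(d)) / factorial
    return total


# Two variables: X_0 has no parents, X_1 reads X_0 only.
w_x0 = [[0.0, 0.0], [0.0, 0.0]]
w_x1 = [[0.5, -1.0], [0.0, 0.0]]
dag_penalty = acyclicity(adjacency_from_first_layers([w_x0, w_x1]))

# Letting X_0 read X_1 as well creates a 2-cycle and a positive penalty.
w_x0_cyclic = [[0.0, 0.0], [1.0, 0.0]]
cyclic_penalty = acyclicity(adjacency_from_first_layers([w_x0_cyclic, w_x1]))
```

Because every term of the series is non-negative, the penalty vanishes only when no weighted cycle exists, which is what makes it usable as an equality constraint in (18).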

Appendix B Intuition on the Encoder Design

In this part, we give an intuitive motivation for the encoder design, derived from the Bayesian posterior of $E$. Our objective is to infer the latent variable $E$ from a sample of features. Following the idea of Bayesian inference, the MAP estimate of a latent variable can be obtained by maximising its posterior. In our case, however, we aim to learn a direct but approximate mapping from the features to the key statistics of the posterior distribution of $E$ given those features, assuming that the posterior has a special form, e.g. Gaussian.

Consider the observed data $\{\mathbf{\tilde{x}}^{m}\}_{i=1}^{N_{m}}$ in source domain $m\in[M]$. For notational simplicity, we omit the domain index $m$ in the following text. We begin with the conditional probability of $\{\mathbf{\tilde{x}}\}_{i=1}^{n}$, i.i.d. data drawn from the same domain:

$$p(\{\mathbf{\tilde{x}}\}_{i=1}^{n}|E)=\prod_{i=1}^{n}p(\mathbf{\tilde{x}}_{i}|E). \tag{19}$$

For the posterior of $E$ given $\{\mathbf{\tilde{x}}\}_{i=1}^{n}$, we have:

$$\begin{aligned}
p(E|\{\mathbf{\tilde{x}}\}_{i=1}^{n}) &=\frac{p(\{\mathbf{\tilde{x}}\}_{i=1}^{n}|E)\cdot p(E)}{p(\{\mathbf{\tilde{x}}\}_{i=1}^{n})} \\
&\propto p(\{\mathbf{\tilde{x}}\}_{i=1}^{n}|E)\cdot p(E) \\
&\propto p(E)\prod_{i=1}^{n}p(\mathbf{\tilde{x}}_{i}|E) \\
&\propto p(E)\prod_{i=1}^{n}\frac{p(E|\mathbf{\tilde{x}}_{i})}{p(E)} \\
&\propto p(E)^{-(n-1)}\prod_{i=1}^{n}p(E|\mathbf{\tilde{x}}_{i}).
\end{aligned} \tag{20}$$

We further assume that both $p(E)$ and $p(E|\mathbf{\tilde{x}}_{i})$ are members of an exponential family, e.g. Gaussian distributions (without loss of generality), which can be expressed (approximately) as:

$$p(E)=\mathcal{N}(0,\nu_{0}^{-1}) \tag{21}$$

$$p(E|\mathbf{\tilde{x}}_{i})=\mathcal{N}(\phi(\mathbf{\tilde{x}}_{i}),\nu^{-1}(\mathbf{\tilde{x}}_{i})) \tag{22}$$

where $\phi$ and $\nu^{-1}$ are approximated mappings and $\nu_{0}^{-1}$ is the parameter for the prior variance of $E$. Then we can rewrite $p(E|\{\mathbf{\tilde{x}}\}_{i=1}^{n})$ as:

$$\begin{aligned}
p(E|\{\mathbf{\tilde{x}}\}_{i=1}^{n}) &\propto\exp(-0.5(1-n)\nu_{0}\cdot E^{2})\prod_{i=1}^{n}\exp(-0.5\nu(\mathbf{\tilde{x}}_{i})\cdot(E-\phi(\mathbf{\tilde{x}}_{i}))^{2}) \\
&\propto\exp(-0.5[(1-n)\nu_{0}E^{2}+\sum_{i=1}^{n}\nu(\mathbf{\tilde{x}}_{i})\cdot(E-\phi(\mathbf{\tilde{x}}_{i}))^{2}]) \\
&\propto\exp(-0.5[((1-n)\nu_{0}+\sum_{i=1}^{n}\nu(\mathbf{\tilde{x}}_{i}))E^{2}-2(\sum_{i=1}^{n}\phi(\mathbf{\tilde{x}}_{i})\nu(\mathbf{\tilde{x}}_{i}))E])
\end{aligned} \tag{23}$$

By completing the square in (23), we obtain the approximate posterior $p(E|\{\mathbf{\tilde{x}}\}_{i=1}^{n})\sim\mathcal{N}(\mu(\mathbf{\tilde{X}}),V(\mathbf{\tilde{X}}))$, which has a form similar to (6), except that the input in (23) is $\mathbf{\tilde{X}}$ instead of the $\mathbf{X}$ used in (6).
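
To make the completion of squares concrete: reading off the coefficients of $E^{2}$ and $E$ in (23) gives a precision of $(1-n)\nu_{0}+\sum_{i}\nu(\mathbf{\tilde{x}}_{i})$ and a mean equal to $\sum_{i}\phi(\mathbf{\tilde{x}}_{i})\nu(\mathbf{\tilde{x}}_{i})$ divided by that precision. The sketch below implements just this aggregation for a scalar $E$. It assumes the per-sample outputs $\phi$ and $\nu$ are already available as plain numbers; the function name `aggregate_posterior` is hypothetical, and the code is not claimed to reproduce (6), which is not restated here.

```python
def aggregate_posterior(phis, nus, nu0):
    """Combine per-sample Gaussians N(phi_i, 1/nu_i) with the prior N(0, 1/nu0).

    Follows the exponent of (23):
        precision = (1 - n) * nu0 + sum_i nu_i
        mean      = sum_i phi_i * nu_i / precision
    Returns (mean, variance).
    """
    n = len(phis)
    precision = (1 - n) * nu0 + sum(nus)
    if precision <= 0:
        # The Gaussian form is only valid for a positive precision.
        raise ValueError("aggregated precision must be positive")
    mean = sum(p * v for p, v in zip(phis, nus)) / precision
    return mean, 1.0 / precision


# With a single sample the prior cancels and the per-sample Gaussian is returned.
single_mean, single_var = aggregate_posterior([2.0], [4.0], nu0=1.0)

# With several samples the prior precision is subtracted (n - 1) times.
multi_mean, multi_var = aggregate_posterior([1.0, 2.0, 3.0], [2.0, 2.0, 2.0], nu0=1.0)
```

The explicit positivity check reflects a point the derivation leaves implicit: because the prior enters with a negative exponent for $n>1$, the per-sample precisions must outweigh $(n-1)\nu_{0}$ for the result to be a proper Gaussian.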

Appendix C Generalisation Bound with Mixed Type Data

Our proof of Theorem 3.2 mainly follows the work of Kyono et al. (2020), except for the extension to mixed data types that include binary variables and for the regularisation on the environmental variable $E$. Let $\mathcal{L}_{P}(f_{\Theta})$ and $\mathcal{L}_{N}(f_{\Theta})$ be the expected loss and the empirical loss, respectively. We further divide each loss into two components: $\mathcal{L}^{c}(f_{\Theta})$, the loss of the continuous variables, and $\mathcal{L}^{b}(f_{\Theta})$, the loss of the binary variables. Similar notation distinguishing variable types is applied to $\Theta$, $f_{\Theta}$ and $\tilde{\mathbf{X}}_{i}$.

C.1 Assumptions

Assumption C.1.

For any sample $\tilde{\mathbf{X}}=(\mathbf{X},Y)\sim P_{\tilde{\mathbf{X}}}$, the continuous variables $\tilde{\mathbf{X}}^{c}$ have bounded $l_{2}$ norm, i.e. $\exists B_{1}>0$ such that $||\tilde{\mathbf{X}}^{c}||_{2}\leq B_{1}$. This further implies that (Kyono et al., 2020):

$$\sup\limits_{\tilde{\mathbf{X}}\in\tilde{\mathcal{X}}}\mathbb{E}_{\mathbf{u}}||f_{\Theta_{\mathbf{u}}}^{c}(\tilde{\mathbf{X}})-f_{\Theta}^{c}(\tilde{\mathbf{X}})||^{2}\leq\gamma_{1} \tag{24}$$

where $\gamma_{1}$ is a constant.

Assumption C.2.

For any sample $\tilde{\mathbf{X}}=(\mathbf{X},Y)\sim P_{\tilde{\mathbf{X}}}$, we assume

$$\sup\limits_{\tilde{\mathbf{X}}\in\tilde{\mathcal{X}}}\max\{\sum_{j=1}^{b}\mathbb{E}_{\mathbf{u}}||\log\frac{f_{\Theta}^{b_{j}}(\tilde{\mathbf{X}})}{f_{\Theta_{\mathbf{u}}}^{b_{j}}(\tilde{\mathbf{X}})}||,\sum_{j=1}^{b}\mathbb{E}_{\mathbf{u}}||\log\frac{1-f_{\Theta}^{b_{j}}(\tilde{\mathbf{X}})}{1-f_{\Theta_{\mathbf{u}}}^{b_{j}}(\tilde{\mathbf{X}})}||\}\leq\gamma_{2} \tag{25}$$

where $\gamma_{2}$ is a constant.

Assumption C.3.

The squared loss function of the continuous variables, $\mathcal{L}^{c}(f_{\Theta})=||f_{\Theta}^{c}(\tilde{\mathbf{X}})-\tilde{\mathbf{X}}^{c}||^{2}$, is sub-Gaussian under the distribution $P_{\tilde{\mathbf{X}}}$ with a proxy-variance factor $s_{1}^{2}$ such that $\forall\epsilon\in\mathbb{R}$, $\mathbb{E}_{P}[\exp(\epsilon(\mathcal{L}^{c}(f_{\Theta})-\mathcal{L}_{P}^{c}(f_{\Theta})))]\leq\exp(\frac{\epsilon^{2}s_{1}^{2}}{2})$.

Assumption C.4.

The loss function for the binary variables, $\mathcal{L}^{b}(f_{\Theta})=\sum_{j=1}^{b}\mathcal{L}^{b_{j}}(f_{\Theta})$, where $\mathcal{L}^{b_{j}}$ is the cross-entropy loss function of the $j$-th binary variable, is sub-Gaussian under the distribution $P_{\tilde{\mathbf{X}}}$ with a proxy-variance factor $s_{2}^{2}$ such that $\forall\epsilon\in\mathbb{R}$, $\mathbb{E}_{P}[\exp(\epsilon(\mathcal{L}^{b}(f_{\Theta})-\mathcal{L}_{P}^{b}(f_{\Theta})))]\leq\exp(\frac{\epsilon^{2}s_{2}^{2}}{2})$.

C.2 Generalisation Bound for the Decoder

Proof. Denote $\Theta$ as the parameters of the DAPDAG decoder and $\Theta_{\mathbf{u}}$ as the perturbed $\Theta$, where each parameter in $\Theta$ is perturbed by a noise vector $\mathbf{u}\sim\mathcal{N}(\mathbf{0},\sigma^{2}\mathbf{I})$.
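
The quantity bounded by $\gamma_{1}$ in Assumption C.1, namely $\mathbb{E}_{\mathbf{u}}||f_{\Theta_{\mathbf{u}}}^{c}(\tilde{\mathbf{X}})-f_{\Theta}^{c}(\tilde{\mathbf{X}})||^{2}$, can be estimated by Monte Carlo for a given input. The sketch below does this for a toy one-hidden-layer tanh network with a scalar output. The network, its weights and the function names are illustrative assumptions and are unrelated to the actual DAPDAG decoder architecture.

```python
import math
import random


def forward(weights, x):
    """Toy one-hidden-layer tanh network with a scalar output."""
    w1, w2 = weights
    hidden = [math.tanh(sum(w * xi for w, xi in zip(row, x))) for row in w1]
    return sum(w * h for w, h in zip(w2, hidden))


def perturb(weights, sigma, rng):
    """Add independent N(0, sigma^2) noise to every parameter."""
    w1, w2 = weights
    return ([[w + rng.gauss(0, sigma) for w in row] for row in w1],
            [w + rng.gauss(0, sigma) for w in w2])


def perturbation_gap(weights, x, sigma, draws, rng):
    """Monte Carlo estimate of E_u |f_{Theta_u}(x) - f_Theta(x)|^2."""
    base = forward(weights, x)
    total = 0.0
    for _ in range(draws):
        total += (forward(perturb(weights, sigma, rng), x) - base) ** 2
    return total / draws


toy_weights = ([[0.5, -0.3], [0.2, 0.8]], [1.0, -1.0])
toy_input = [1.0, 2.0]
rng = random.Random(1)
gap_small = perturbation_gap(toy_weights, toy_input, 0.01, 500, rng)
gap_large = perturbation_gap(toy_weights, toy_input, 1.0, 500, rng)
```

The estimate grows with the noise scale $\sigma$, which is why the bound $\gamma_{1}$ should be read as depending on the chosen perturbation distribution.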

Step 1.

We first derive the upper bound for the expected loss over parameter perturbation and data distribution. For each shared $E$ within the same domain, we have (for simplicity, we omit $E$ from the notation, although it serves as an input to $f_{\Theta}$, $f_{\Theta_{\mathbf{u}}}$, $\mathcal{L}(f_{\Theta})$ and $\mathcal{L}(f_{\Theta_{\mathbf{u}}})$):

$$\begin{aligned}
\mathbb{E}_{\mathbf{u}}[\mathcal{L}_{N}^{c}(f_{\Theta_{\mathbf{u}}})] &=\mathbb{E}_{\mathbf{u}}[\frac{1}{N}\sum_{i=1}^{N}||f_{\Theta_{\mathbf{u}}}^{c}(\tilde{\mathbf{X}}_{i})-f_{\Theta}^{c}(\tilde{\mathbf{X}}_{i})+f_{\Theta}^{c}(\tilde{\mathbf{X}}_{i})-\tilde{\mathbf{X}}_{i}^{c}||^{2}] \\
&=\mathbb{E}_{\mathbf{u}}[\frac{1}{N}\sum_{i=1}^{N}||f_{\Theta_{\mathbf{u}}}^{c}(\tilde{\mathbf{X}}_{i})-f_{\Theta}^{c}(\tilde{\mathbf{X}}_{i})||^{2}]+\frac{1}{N}\sum_{i=1}^{N}||f_{\Theta}^{c}(\tilde{\mathbf{X}}_{i})-\tilde{\mathbf{X}}_{i}^{c}||^{2} \\
&\quad+\mathbb{E}_{\mathbf{u}}[\frac{2}{N}\sum_{i=1}^{N}(f_{\Theta_{\mathbf{u}}}^{c}(\tilde{\mathbf{X}}_{i})-f_{\Theta}^{c}(\tilde{\mathbf{X}}_{i}))\cdot(f_{\Theta}^{c}(\tilde{\mathbf{X}}_{i})-\tilde{\mathbf{X}}_{i}^{c})] \\
&\leq\mathbb{E}_{\mathbf{u}}[\frac{1}{N}\sum_{i=1}^{N}||f_{\Theta_{\mathbf{u}}}^{c}(\tilde{\mathbf{X}}_{i})-f_{\Theta}^{c}(\tilde{\mathbf{X}}_{i})||^{2}]+\frac{1}{N}\sum_{i=1}^{N}||f_{\Theta}^{c}(\tilde{\mathbf{X}}_{i})-\tilde{\mathbf{X}}_{i}^{c}||^{2} \\
&\quad+\mathbb{E}_{\mathbf{u}}[\frac{1}{N}\sum_{i=1}^{N}||f_{\Theta_{\mathbf{u}}}^{c}(\tilde{\mathbf{X}}_{i})-f_{\Theta}^{c}(\tilde{\mathbf{X}}_{i})||^{2}]+\frac{1}{N}\sum_{i=1}^{N}||f_{\Theta}^{c}(\tilde{\mathbf{X}}_{i})-\tilde{\mathbf{X}}_{i}^{c}||^{2} \\
&\leq 2\gamma_{1}+2\mathcal{L}_{N}^{c}(f_{\Theta})
\end{aligned} \tag{26}$$

Similarly, we can derive:

$$\begin{aligned}
\mathcal{L}_{P}^{c}(f_{\Theta}) &=\mathbb{E}_{P}\mathbb{E}_{\mathbf{u}}||f_{\Theta}^{c}(\tilde{\mathbf{X}})-f_{\Theta_{\mathbf{u}}}^{c}(\tilde{\mathbf{X}})+f_{\Theta_{\mathbf{u}}}^{c}(\tilde{\mathbf{X}})-\tilde{\mathbf{X}}^{c}||^{2} \\
&\leq 2\gamma_{1}+2\mathbb{E}_{\mathbf{u}}[\mathcal{L}_{P}^{c}(f_{\Theta_{\mathbf{u}}})]
\end{aligned} \tag{27}$$

where $\gamma_{1}$ is the upper bound of $\mathbb{E}_{\mathbf{u}}||f_{\Theta_{\mathbf{u}}}^{c}(\tilde{\mathbf{X}})-f_{\Theta}^{c}(\tilde{\mathbf{X}})||^{2}$ such that $\sup\limits_{\tilde{\mathbf{X}}\in\tilde{\mathcal{X}}}\mathbb{E}_{\mathbf{u}}||f_{\Theta_{\mathbf{u}}}^{c}(\tilde{\mathbf{X}})-f_{\Theta}^{c}(\tilde{\mathbf{X}})||^{2}\leq\gamma_{1}$. Let $Q_{\Theta_{\mathbf{u}}}$ and $P_{\Theta_{\mathbf{u}}}$ be the distribution and prior distribution of the perturbed decoder parameter $\Theta_{\mathbf{u}}$, and let $Q_{E}$ and $P_{E}$ be the distribution and prior distribution of $E$, respectively. According to Corollary 4 in Germain et al. (2016) and the original proof of CASTLE, their theoretical results transfer directly to the continuous variables in our framework. Given $P_{\Theta_{\mathbf{u}}}$ and $P_{E}$ that are independent of the training data in the domain with $E$, we can deduce from the PAC-Bayes theorem that, $\forall\delta\in(0,1)$, with probability at least $1-\delta$, for any $N$ i.i.d. training samples with the shared $E$:

$$\mathbb{E}_{\mathbf{u}}[\mathcal{L}_{P}^{c}(f_{\Theta_{\mathbf{u}}})]\leq\mathbb{E}_{\mathbf{u}}[\mathcal{L}_{N}^{c}(f_{\Theta_{\mathbf{u}}})]+\frac{1}{N}[2KL(Q_{E,\Theta_{\mathbf{u}}^{c}}||P_{E,\Theta_{\mathbf{u}}^{c}})+\log\frac{8}{\delta}]+\frac{s_{1}^{2}}{2} \tag{28}$$

where $Q_{E,\Theta_{\mathbf{u}}^{c}}=Q_{\Theta_{\mathbf{u}}^{c}}\cdot Q_{E}$ and $P_{E,\Theta_{\mathbf{u}}^{c}}=P_{\Theta_{\mathbf{u}}^{c}}\cdot P_{E}$. Combining (26), (27) and (28), we get:

$$\mathcal{L}_{P}^{c}(f_{\Theta})\leq 6\gamma_{1}+4\mathcal{L}_{N}^{c}(f_{\Theta})+\frac{2}{N}[2KL(Q_{E,\Theta_{\mathbf{u}}^{c}}||P_{E,\Theta_{\mathbf{u}}^{c}})+\log\frac{8}{\delta}]+s_{1}^{2} \tag{29}$$

For the $j$-th binary variable, denoted as $\tilde{\mathbf{X}}^{b_{j}}$, we have:

$$\begin{aligned}
\mathbb{E}_{\mathbf{u}}[\mathcal{L}_{N}^{b_{j}}(f_{\Theta_{\mathbf{u}}})] &=-\mathbb{E}_{\mathbf{u}}[\frac{1}{N}\sum_{i=1}^{N}(\tilde{\mathbf{X}}_{i}^{b_{j}}\log f_{\Theta_{\mathbf{u}}}^{b_{j}}(\tilde{\mathbf{X}}_{i})+(1-\tilde{\mathbf{X}}_{i}^{b_{j}})\log(1-f_{\Theta_{\mathbf{u}}}^{b_{j}}(\tilde{\mathbf{X}}_{i})))] \\
&\quad+\mathbb{E}_{\mathbf{u}}[\frac{1}{N}\sum_{i=1}^{N}(\tilde{\mathbf{X}}_{i}^{b_{j}}\log f_{\Theta}^{b_{j}}(\tilde{\mathbf{X}}_{i})+(1-\tilde{\mathbf{X}}_{i}^{b_{j}})\log(1-f_{\Theta}^{b_{j}}(\tilde{\mathbf{X}}_{i})))] \\
&\quad-\mathbb{E}_{\mathbf{u}}[\frac{1}{N}\sum_{i=1}^{N}(\tilde{\mathbf{X}}_{i}^{b_{j}}\log f_{\Theta}^{b_{j}}(\tilde{\mathbf{X}}_{i})+(1-\tilde{\mathbf{X}}_{i}^{b_{j}})\log(1-f_{\Theta}^{b_{j}}(\tilde{\mathbf{X}}_{i})))] \\
&=-\mathbb{E}_{\mathbf{u}}[\frac{1}{N}\sum_{i=1}^{N}(\tilde{\mathbf{X}}_{i}^{b_{j}}\log\frac{f_{\Theta_{\mathbf{u}}}^{b_{j}}(\tilde{\mathbf{X}}_{i})}{f_{\Theta}^{b_{j}}(\tilde{\mathbf{X}}_{i})}+(1-\tilde{\mathbf{X}}_{i}^{b_{j}})\log\frac{1-f_{\Theta_{\mathbf{u}}}^{b_{j}}(\tilde{\mathbf{X}}_{i})}{1-f_{\Theta}^{b_{j}}(\tilde{\mathbf{X}}_{i})})]+\mathcal{L}_{N}^{b_{j}}(f_{\Theta}) \\
&=\frac{1}{N}\sum_{i=1}^{N}\mathbb{E}_{\mathbf{u}}[\tilde{\mathbf{X}}_{i}^{b_{j}}\log\frac{f_{\Theta}^{b_{j}}(\tilde{\mathbf{X}}_{i})}{f_{\Theta_{\mathbf{u}}}^{b_{j}}(\tilde{\mathbf{X}}_{i})}+(1-\tilde{\mathbf{X}}_{i}^{b_{j}})\log\frac{1-f_{\Theta}^{b_{j}}(\tilde{\mathbf{X}}_{i})}{1-f_{\Theta_{\mathbf{u}}}^{b_{j}}(\tilde{\mathbf{X}}_{i})}]+\mathcal{L}_{N}^{b_{j}}(f_{\Theta}) \\
&\leq\frac{1}{N}\sum_{i=1}^{N}\mathbb{E}_{\mathbf{u}}[||\tilde{\mathbf{X}}_{i}^{b_{j}}\log\frac{f_{\Theta}^{b_{j}}(\tilde{\mathbf{X}}_{i})}{f_{\Theta_{\mathbf{u}}}^{b_{j}}(\tilde{\mathbf{X}}_{i})}||+||(1-\tilde{\mathbf{X}}_{i}^{b_{j}})\log\frac{1-f_{\Theta}^{b_{j}}(\tilde{\mathbf{X}}_{i})}{1-f_{\Theta_{\mathbf{u}}}^{b_{j}}(\tilde{\mathbf{X}}_{i})}||]+\mathcal{L}_{N}^{b_{j}}(f_{\Theta}) \\
&\leq\frac{1}{N}\sum_{i=1}^{N}\mathbb{E}_{\mathbf{u}}[||\log\frac{f_{\Theta}^{b_{j}}(\tilde{\mathbf{X}}_{i})}{f_{\Theta_{\mathbf{u}}}^{b_{j}}(\tilde{\mathbf{X}}_{i})}||+||\log\frac{1-f_{\Theta}^{b_{j}}(\tilde{\mathbf{X}}_{i})}{1-f_{\Theta_{\mathbf{u}}}^{b_{j}}(\tilde{\mathbf{X}}_{i})}||]+\mathcal{L}_{N}^{b_{j}}(f_{\Theta})
\end{aligned} \tag{30}$$
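
The last two inequalities in (30) rely only on the labels being in $\{0,1\}$, so they hold for every single perturbation and not just in expectation. The following sketch checks this numerically for one fixed pair of predicted and perturbed probabilities. The probability values and the function names `bce` and `log_ratio_bound` are made up for illustration.

```python
import math


def bce(labels, probs):
    """Mean binary cross-entropy of predicted probabilities against 0/1 labels."""
    return -sum(y * math.log(p) + (1 - y) * math.log(1 - p)
                for y, p in zip(labels, probs)) / len(labels)


def log_ratio_bound(probs, perturbed):
    """Mean over samples of |log(f / f_u)| + |log((1 - f) / (1 - f_u))|."""
    return sum(abs(math.log(p / q)) + abs(math.log((1 - p) / (1 - q)))
               for p, q in zip(probs, perturbed)) / len(probs)


labels = [1, 0, 1, 0]
probs = [0.9, 0.2, 0.6, 0.7]          # outputs of the unperturbed decoder
perturbed = [0.8, 0.3, 0.7, 0.5]      # outputs after one parameter perturbation

lhs = bce(labels, perturbed)
rhs = log_ratio_bound(probs, perturbed) + bce(labels, probs)
```

Here `lhs` plays the role of the perturbed empirical loss and `rhs` the right-hand side of (30) for a single draw of $\mathbf{u}$; averaging over draws recovers the expectation in the text.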

We then upper bound the expected perturbed loss for all binary variables:

$$\begin{aligned}
\mathbb{E}_{\mathbf{u}}[\mathcal{L}_{N}^{b}(f_{\Theta_{\mathbf{u}}})] &=\mathbb{E}_{\mathbf{u}}[\sum_{j=1}^{b}\mathcal{L}_{N}^{b_{j}}(f_{\Theta_{\mathbf{u}}})] \\
&\leq\sum_{j=1}^{b}\frac{1}{N}\sum_{i=1}^{N}\mathbb{E}_{\mathbf{u}}[||\log\frac{f_{\Theta}^{b_{j}}(\tilde{\mathbf{X}}_{i})}{f_{\Theta_{\mathbf{u}}}^{b_{j}}(\tilde{\mathbf{X}}_{i})}||+||\log\frac{1-f_{\Theta}^{b_{j}}(\tilde{\mathbf{X}}_{i})}{1-f_{\Theta_{\mathbf{u}}}^{b_{j}}(\tilde{\mathbf{X}}_{i})}||]+\sum_{j=1}^{b}\mathcal{L}_{N}^{b_{j}}(f_{\Theta}) \\
&=\frac{1}{N}\sum_{i=1}^{N}[\sum_{j=1}^{b}\mathbb{E}_{\mathbf{u}}||\log\frac{f_{\Theta}^{b_{j}}(\tilde{\mathbf{X}}_{i})}{f_{\Theta_{\mathbf{u}}}^{b_{j}}(\tilde{\mathbf{X}}_{i})}||+\sum_{j=1}^{b}\mathbb{E}_{\mathbf{u}}||\log\frac{1-f_{\Theta}^{b_{j}}(\tilde{\mathbf{X}}_{i})}{1-f_{\Theta_{\mathbf{u}}}^{b_{j}}(\tilde{\mathbf{X}}_{i})}||]+\mathcal{L}_{N}^{b}(f_{\Theta}) \\
&\leq 2\gamma_{2}+\mathcal{L}_{N}^{b}(f_{\Theta})
\end{aligned} \tag{31}$$

where $\gamma_{2}$ is a constant such that $\sup\limits_{\tilde{\mathbf{X}}\in\tilde{\mathcal{X}}}\max\{\sum_{j=1}^{b}\mathbb{E}_{\mathbf{u}}||\log\frac{f_{\Theta}^{b_{j}}(\tilde{\mathbf{X}})}{f_{\Theta_{\mathbf{u}}}^{b_{j}}(\tilde{\mathbf{X}})}||,\sum_{j=1}^{b}\mathbb{E}_{\mathbf{u}}||\log\frac{1-f_{\Theta}^{b_{j}}(\tilde{\mathbf{X}})}{1-f_{\Theta_{\mathbf{u}}}^{b_{j}}(\tilde{\mathbf{X}})}||\}\leq\gamma_{2}$. Similar to the population-loss bound for the continuous variables above, we also have:

$$\begin{aligned}
\mathcal{L}_{P}^{b}(f_{\Theta}) &=-\mathbb{E}_{P}\mathbb{E}_{\mathbf{u}}[\sum_{j=1}^{b}(\tilde{\mathbf{X}}^{b_{j}}\log f_{\Theta}^{b_{j}}(\tilde{\mathbf{X}})+(1-\tilde{\mathbf{X}}^{b_{j}})\log(1-f_{\Theta}^{b_{j}}(\tilde{\mathbf{X}})))] \\
&\quad+\mathbb{E}_{P}\mathbb{E}_{\mathbf{u}}[\sum_{j=1}^{b}(\tilde{\mathbf{X}}^{b_{j}}\log f_{\Theta_{\mathbf{u}}}^{b_{j}}(\tilde{\mathbf{X}})+(1-\tilde{\mathbf{X}}^{b_{j}})\log(1-f_{\Theta_{\mathbf{u}}}^{b_{j}}(\tilde{\mathbf{X}})))] \\
&\quad-\mathbb{E}_{P}\mathbb{E}_{\mathbf{u}}[\sum_{j=1}^{b}(\tilde{\mathbf{X}}^{b_{j}}\log f_{\Theta_{\mathbf{u}}}^{b_{j}}(\tilde{\mathbf{X}})+(1-\tilde{\mathbf{X}}^{b_{j}})\log(1-f_{\Theta_{\mathbf{u}}}^{b_{j}}(\tilde{\mathbf{X}})))] \\
&\leq 2\gamma_{2}+\mathbb{E}_{\mathbf{u}}[\mathcal{L}_{P}^{b}(f_{\Theta_{\mathbf{u}}})]
\end{aligned} \tag{32}$$

Similar to (28), given $P_{\Theta_{\mathbf{u}}}$ and $P_{E}$ that are independent of the training data in the domain with $E$, we can deduce from the PAC-Bayes theorem that, $\forall\delta\in(0,1)$, with probability at least $1-\delta$, for any $N$ i.i.d. training samples with the shared $E$:

$$\mathbb{E}_{\mathbf{u}}[\mathcal{L}_{P}^{b}(f_{\Theta_{\mathbf{u}}})]\leq\mathbb{E}_{\mathbf{u}}[\mathcal{L}_{N}^{b}(f_{\Theta_{\mathbf{u}}})]+\frac{1}{N}[2KL(Q_{E,\Theta_{\mathbf{u}}^{b}}||P_{E,\Theta_{\mathbf{u}}^{b}})+\log\frac{8}{\delta}]+\frac{s_{2}^{2}}{2} \tag{33}$$

where $Q_{E,\Theta_{\mathbf{u}}^{b}}=Q_{\Theta_{\mathbf{u}}^{b}}\cdot Q_{E}$ and $P_{E,\Theta_{\mathbf{u}}^{b}}=P_{\Theta_{\mathbf{u}}^{b}}\cdot P_{E}$. Combining the results from (31), (32) and (33), we get:

$$\mathcal{L}_{P}^{b}(f_{\Theta})\leq 4\gamma_{2}+\mathcal{L}_{N}^{b}(f_{\Theta})+\frac{1}{N}[2KL(Q_{E,\Theta_{\mathbf{u}}^{b}}||P_{E,\Theta_{\mathbf{u}}^{b}})+\log\frac{8}{\delta}]+\frac{s_{2}^{2}}{2} \tag{34}$$
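
The constants in (29) and (34) follow from chaining three inequalities in each case. As a sanity check on that bookkeeping, the sketch below composes the steps numerically and compares the result with the closed forms. All inputs are arbitrary illustrative numbers, the KL term is treated as a given scalar, and the function names are hypothetical.

```python
import math


def pac_bayes_term(kl, n, delta):
    """(1/N) [2 KL + log(8 / delta)], the complexity term in (28) and (33)."""
    return (2.0 * kl + math.log(8.0 / delta)) / n


def continuous_bound(emp_loss, gamma1, kl, n, delta, s1_sq):
    """Chain the three continuous-variable inequalities step by step."""
    perturbed_emp = 2.0 * gamma1 + 2.0 * emp_loss                      # (26)
    perturbed_pop = (perturbed_emp + pac_bayes_term(kl, n, delta)
                     + s1_sq / 2.0)                                     # (28)
    # Population-loss inequality that follows (26):
    return 2.0 * gamma1 + 2.0 * perturbed_pop


def binary_bound(emp_loss, gamma2, kl, n, delta, s2_sq):
    """Chain the three binary-variable inequalities step by step."""
    perturbed_emp = 2.0 * gamma2 + emp_loss                             # (31)
    perturbed_pop = (perturbed_emp + pac_bayes_term(kl, n, delta)
                     + s2_sq / 2.0)                                     # (33)
    return 2.0 * gamma2 + perturbed_pop                                 # (32)


# Arbitrary illustrative values.
gamma, emp, kl, n, delta, s_sq = 0.1, 0.5, 3.0, 200, 0.05, 0.4
term = pac_bayes_term(kl, n, delta)
closed_form_29 = 6.0 * gamma + 4.0 * emp + 2.0 * term + s_sq
closed_form_34 = 4.0 * gamma + emp + term + s_sq / 2.0
```

The factor of 2 that multiplies the PAC-Bayes term in (29) but not in (34) comes from the squared-loss decomposition, which doubles the perturbed population loss, whereas the cross-entropy decomposition is additive.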
Step 2.

Notice that $\Theta=\Theta_{1}\cup\Theta_{2}\cup\Theta_{3}$, where $\Theta_{1}$, $\Theta_{2}$ and $\Theta_{3}$ represent the parameters of the structural filters, the shared hidden layers and the output layers, respectively. We can further decompose $\Theta_{i}$ for $i=1,2,3$ and write $P_{\Theta_{\mathbf{u}}}$ and $Q_{\Theta_{\mathbf{u}}}$ in more detail. Let $\mathbf{W}$ be the weight matrix of a neural network layer and $L$ be the number of hidden layers; then we denote:

$$\Theta_{1}^{c}=\{\mathbf{W}_{1}^{c_{j}}\}_{j=1}^{c},\quad\Theta_{2}^{c}=\{\mathbf{W}_{k}\}_{k=2}^{L},\quad\Theta_{3}^{c}=\{\mathbf{W}_{o}^{c_{j}}\}_{j=1}^{c} \tag{35}$$

as the decoder parameters of the continuous variables. The analogous notation for the binary variables is:

$$\Theta_{1}^{b}=\{\mathbf{W}_{1}^{b_{j}}\}_{j=1}^{b},\quad\Theta_{2}^{b}=\{\mathbf{W}_{k}\}_{k=2}^{L},\quad\Theta_{3}^{b}=\{\mathbf{W}_{o}^{b_{j}}\}_{j=1}^{b}. \tag{36}$$

Therefore, it follows that:

$$\Theta_{1}=\Theta_{1}^{c}\cup\Theta_{1}^{b}=\{\mathbf{W}_{1}^{j}\}_{j=1}^{d+1},\quad\Theta_{2}=\Theta_{2}^{c}=\Theta_{2}^{b}=\{\mathbf{W}_{k}\}_{k=2}^{L},\quad\Theta_{3}=\Theta_{3}^{c}\cup\Theta_{3}^{b}=\{\mathbf{W}_{o}^{j}\}_{j=1}^{d+1} \tag{37}$$

Furthermore, we assume both $P_{\Theta_{\mathbf{u}}}$ and $Q_{\Theta_{\mathbf{u}}}$ can be decomposed into two parts such that:

$$
P_{\Theta_{\mathbf{u}}}=P_{\Theta_{\mathbf{u}}}^{1}\cdot P_{\Theta_{\mathbf{u}}}^{2},\quad Q_{\Theta_{\mathbf{u}}}=Q_{\Theta_{\mathbf{u}}}^{1}\cdot Q_{\Theta_{\mathbf{u}}}^{2} \tag{38}
$$

where $P_{\Theta_{\mathbf{u}}}^{1}$ and $Q_{\Theta_{\mathbf{u}}}^{1}$ are the corresponding prior and posterior distributions of the structural filters $\Theta_{1}$ that form a DAG, and $P_{\Theta_{\mathbf{u}}}^{2}$ and $Q_{\Theta_{\mathbf{u}}}^{2}$ are the prior and posterior distributions of the corresponding layer weights. For simplicity, $P_{\Theta_{\mathbf{u}}}^{1}$ and $Q_{\Theta_{\mathbf{u}}}^{1}$ are assumed to follow normal distributions:

$$
P_{\Theta_{\mathbf{u}}}^{1}=P_{\Theta_{1,\mathbf{u}}}^{1}\sim\mathcal{N}(h_{\Theta_{1,\mathbf{u}}};d+1,1),\quad Q_{\Theta_{\mathbf{u}}}^{1}=Q_{\Theta_{1,\mathbf{u}}}^{1}\sim\mathcal{N}(h_{\Theta_{1,\mathbf{u}}};h_{\Theta_{1}},1) \tag{39}
$$

and the variable $h_{\Theta_{1,\mathbf{u}}}$ and the constant $h_{\Theta_{1}}$ take the form:

$$
h_{\Theta_{1,\mathbf{u}}}=Tr(\exp(\mathbf{A}_{\mathbf{u}}\odot\mathbf{A}_{\mathbf{u}})),\quad h_{\Theta_{1}}=Tr(\exp(\mathbf{A}\odot\mathbf{A})) \tag{40}
$$

where $\mathbf{A}_{\mathbf{u}}$ is a $(d+1)\times(d+1)$ adjacency-proxy matrix such that $[\mathbf{A}_{\mathbf{u}}]_{i,j}$ is the $l_{2}$-norm of the $i$-th row of the $j$-th perturbed structural filter matrix $\mathbf{W}_{1,\mathbf{u}}^{j}$, and $\odot$ denotes the Hadamard product. From the earlier introduction of the NOTEARS method, we know that $h_{\Theta_{1,\mathbf{u}}}=Tr(\mathbf{I})+\sum_{k=1}^{\infty}\frac{1}{k!}\sum_{i=1}^{d+1}[(\mathbf{A}_{\mathbf{u}}\odot\mathbf{A}_{\mathbf{u}})^{k}]_{ii}\geq d+1$, and in fact each element of $\mathbf{A}_{\mathbf{u}}$ is non-negative, so the normal approximation in (40) may not be appropriate for Bayesian inference. Formally, truncated normal or exponential priors would give a better approximation.
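As an illustration of the quantity in (40), the following is a minimal NumPy sketch, not the authors' implementation. The function names `adjacency_proxy` and `notears_h`, the filter shape `(d+1, hidden)` and the truncation of the exponential series at a fixed number of terms are all assumptions made for this example.

```python
import numpy as np

def adjacency_proxy(filters):
    """Build A with A[i, j] = l2-norm of row i of the j-th structural filter.

    filters: list of d+1 arrays, each of (assumed) shape (d+1, hidden).
    """
    return np.stack([np.linalg.norm(W, axis=1) for W in filters], axis=1)

def notears_h(A, terms=50):
    """Approximate Tr(exp(A * A)) with a truncated power series.

    The k-th loop iteration adds Tr(M^k) / k!, where M = A * A (Hadamard).
    """
    M = A * A
    n = M.shape[0]
    total = float(n)          # Tr(I) = d + 1
    power = np.eye(n)
    for k in range(1, terms + 1):
        power = power @ M / k  # now equals M^k / k!
        total += np.trace(power)
    return total
```

For filters whose adjacency proxy is strictly upper triangular (an acyclic structure), every trace term vanishes and the value equals $d+1$; any cycle makes it strictly larger, which is why the penalty $[h_{\Theta_{1}}-(d+1)]^{2}$ is zero only for a DAG.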

$P_{\Theta_{\mathbf{u}}}^{2}$ and $Q_{\Theta_{\mathbf{u}}}^{2}$ are given as:

$$
\begin{aligned}
P_{\Theta_{\mathbf{u}}}^{2}&=\prod_{j=1}^{d+1}\mathcal{N}(\mathbf{W}_{1,\mathbf{u}}^{j};\mathbf{0},\sigma^{2}\mathbf{I})\prod_{k=2}^{L}\mathcal{N}(\mathbf{W}_{k,\mathbf{u}};\mathbf{0},\sigma^{2}\mathbf{I})\prod_{j=1}^{d+1}\mathcal{N}(\mathbf{W}_{o,\mathbf{u}}^{j};\mathbf{0},\sigma^{2}\mathbf{I}), &(41)\\
Q_{\Theta_{\mathbf{u}}}^{2}&=\prod_{j=1}^{d+1}\mathcal{N}(\mathbf{W}_{1,\mathbf{u}}^{j};\mathbf{W}_{1}^{j},\sigma^{2}\mathbf{I})\prod_{k=2}^{L}\mathcal{N}(\mathbf{W}_{k,\mathbf{u}};\mathbf{W}_{k},\sigma^{2}\mathbf{I})\prod_{j=1}^{d+1}\mathcal{N}(\mathbf{W}_{o,\mathbf{u}}^{j};\mathbf{W}_{o}^{j},\sigma^{2}\mathbf{I}). &(42)
\end{aligned}
$$

Recall that we also have a shared environmental variable $E$, which can be considered a parameter independent of each component of the decoder. Although the value of $E$ is obtained from the encoder, which takes sample features as input, for any $N$ i.i.d. samples drawn from the same domain this $E$ is fixed as a constant for the decoder. Here, we further assume:

$$
P_{E}=\mathcal{N}(0,\sigma_{e}^{2}),\quad Q_{E}=\mathcal{N}(E,\sigma_{e}^{2}). \tag{43}
$$
Step 3.

Using the fact that the KL divergence between two joint distributions is greater than or equal to the KL divergence between the corresponding marginal distributions, we can upper bound the KL terms in (29) and (34) by their joint-distribution versions:

$$
KL(Q_{E,\Theta_{\mathbf{u}}^{b}}\,\|\,P_{E,\Theta_{\mathbf{u}}^{b}})\leq KL(Q_{E,\Theta_{\mathbf{u}}}\,\|\,P_{E,\Theta_{\mathbf{u}}}),\quad KL(Q_{E,\Theta_{\mathbf{u}}^{c}}\,\|\,P_{E,\Theta_{\mathbf{u}}^{c}})\leq KL(Q_{E,\Theta_{\mathbf{u}}}\,\|\,P_{E,\Theta_{\mathbf{u}}}). \tag{44}
$$

We can in turn upper bound $KL(Q_{E,\Theta_{\mathbf{u}}}\,\|\,P_{E,\Theta_{\mathbf{u}}})$ as follows:

$$
\begin{aligned}
KL(Q_{E,\Theta_{\mathbf{u}}}\,\|\,P_{E,\Theta_{\mathbf{u}}})&=\int Q_{E,\Theta_{\mathbf{u}}}\log\frac{Q_{E,\Theta_{\mathbf{u}}}}{P_{E,\Theta_{\mathbf{u}}}}\,dE\,d\Theta_{\mathbf{u}}\\
&=\int Q_{E}Q_{\Theta_{\mathbf{u}}}^{1}Q_{\Theta_{\mathbf{u}}}^{2}\log\frac{Q_{E}Q_{\Theta_{\mathbf{u}}}^{1}Q_{\Theta_{\mathbf{u}}}^{2}}{P_{E}P_{\Theta_{\mathbf{u}}}^{1}P_{\Theta_{\mathbf{u}}}^{2}}\,dE\,d\Theta_{\mathbf{u}}\\
&=\int Q_{E}Q_{\Theta_{\mathbf{u}}}^{1}Q_{\Theta_{\mathbf{u}}}^{2}\log\frac{Q_{E}}{P_{E}}\,dE\,d\Theta_{\mathbf{u}}+\int Q_{E}Q_{\Theta_{\mathbf{u}}}^{1}Q_{\Theta_{\mathbf{u}}}^{2}\log\frac{Q_{\Theta_{\mathbf{u}}}^{1}}{P_{\Theta_{\mathbf{u}}}^{1}}\,dE\,d\Theta_{\mathbf{u}}\\
&\quad+\int Q_{E}Q_{\Theta_{\mathbf{u}}}^{1}Q_{\Theta_{\mathbf{u}}}^{2}\log\frac{Q_{\Theta_{\mathbf{u}}}^{2}}{P_{\Theta_{\mathbf{u}}}^{2}}\,dE\,d\Theta_{\mathbf{u}}\\
&\leq\int Q_{E}\log\frac{Q_{E}}{P_{E}}\,dE+\int Q_{\Theta_{\mathbf{u}}}^{1}\log\frac{Q_{\Theta_{\mathbf{u}}}^{1}}{P_{\Theta_{\mathbf{u}}}^{1}}\,d\Theta_{\mathbf{u}}+\int Q_{\Theta_{\mathbf{u}}}^{2}\log\frac{Q_{\Theta_{\mathbf{u}}}^{2}}{P_{\Theta_{\mathbf{u}}}^{2}}\,d\Theta_{\mathbf{u}}\\
&=\frac{E^{2}}{2\sigma_{e}^{2}}+\frac{1}{2}[h_{\Theta_{1}}-(d+1)]^{2}+\frac{1}{2\sigma^{2}}\Big[\sum_{j=1}^{d+1}\|\mathbf{W}_{1}^{j}\|_{F}^{2}+\sum_{k=1}^{L}\|\mathbf{W}_{k}\|_{F}^{2}+\sum_{j=1}^{d+1}\|\mathbf{W}_{o}^{j}\|_{F}^{2}\Big].
\end{aligned} \tag{45}
$$
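The last equality in (45) relies on the closed form of the KL divergence between two Gaussians that share the same isotropic covariance. As a reminder, this standard identity reads as follows, where $\boldsymbol{\mu}_{q}$ and $\boldsymbol{\mu}_{p}$ are generic mean vectors introduced only for this note:

```latex
KL\big(\mathcal{N}(\boldsymbol{\mu}_{q},\sigma^{2}\mathbf{I})\,\|\,\mathcal{N}(\boldsymbol{\mu}_{p},\sigma^{2}\mathbf{I})\big)
  =\frac{\|\boldsymbol{\mu}_{q}-\boldsymbol{\mu}_{p}\|_{2}^{2}}{2\sigma^{2}}
```

Applying it to $(Q_{E},P_{E})$ with means $(E,0)$ and variance $\sigma_{e}^{2}$ gives the first term; applying it to $(Q_{\Theta_{\mathbf{u}}}^{1},P_{\Theta_{\mathbf{u}}}^{1})$ with means $(h_{\Theta_{1}},d+1)$ and unit variance gives the second; and applying it to each weight matrix, whose prior mean is zero, gives the Frobenius-norm terms.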

By upper-bounding (34) and (29) using C.2, we obtain the final generalisation bound of the decoder for mixed-type variables. Given $P_{\Theta_{\mathbf{u}}}$ and $P_{E}$ that are independent of the training data within each domain, for all $\delta\in(0,1)$ and any $N$ i.i.d. training samples with the shared $E$, with probability at least $1-\delta$ we have:

$$
\begin{aligned}
\mathcal{L}_{P}(f_{\Theta},E)&=\mathcal{L}_{P}^{c}(f_{\Theta},E)+\mathcal{L}_{P}^{b}(f_{\Theta},E)\\
&\leq 4\mathcal{L}_{N}^{c}(f_{\Theta})+\mathcal{L}_{N}^{b}(f_{\Theta})+\frac{3}{N}\Big[\frac{E^{2}}{\sigma_{e}^{2}}+[h_{\Theta_{1}}-(d+1)]^{2}\\
&\quad+\frac{1}{\sigma^{2}}\Big(\sum_{j=1}^{d+1}\|\mathbf{W}_{1}^{j}\|_{F}^{2}+\sum_{k=1}^{L}\|\mathbf{W}_{k}\|_{F}^{2}+\sum_{j=1}^{d+1}\|\mathbf{W}_{o}^{j}\|_{F}^{2}\Big)+\log\frac{8}{\delta}\Big]+C_{3}
\end{aligned} \tag{46}
$$

where $C_{3}=6\gamma_{1}+4\gamma_{2}+s_{1}^{2}+\frac{s_{2}^{2}}{2}$.

Appendix D Training Algorithm

This section gives more details on the training algorithm of DAPDAG. The pseudo-code is shown in Algorithm 1.

Algorithm 1 Training Algorithm for DAPDAG
Input: $(\mathbf{X}_{i}^{m},Y_{i}^{m})_{i=1}^{n_{m}}$ for $m\in[M]$, where $M$ is the number of source domains; validation ratio $p$; patience $k$ for early stopping.
Output: Domain encoder $\phi$; structural filters $\Theta_{1}$; shared hidden layers $\Theta_{2}$; output layers $\Theta_{3}$ (decoder $\Theta=\Theta_{1}\cup\Theta_{2}\cup\Theta_{3}$).
for source index $m\in[M]$ do
  Randomly split $(\mathbf{X}_{i}^{m},Y_{i}^{m})_{i=1}^{n_{m}}$ into training and validation datasets according to $p$;
  Record the size of the training data: $N_{m}$;
end for
Obtain the total number of training samples from all domains: $N=\sum_{m=1}^{M}N_{m}$;
for source index $m\in[M]$ do
  Compute the weight of each training domain: $w_{m}=\frac{N_{m}}{N}$;
end for
Initialise all parameters;
for each training epoch do
  for index $i\in[M]$ do
    Randomly select a training domain $m\sim Cat(w_{1},w_{2},...,w_{M})$;
    Obtain the objective 3.2.6 for the selected domain $(\mathbf{X}_{i}^{m},Y_{i}^{m})_{i=1}^{N_{m}}$;
    Update the encoder parameters $\phi$ by maximising 3.2.6 with respect to $\phi$;
    Update the decoder parameters $\Theta$ by maximising 3.2.6 with respect to $\Theta$.
  end for
  Compute the sum of validation scores over all validation sets;
  if the validation score has not improved for $k$ epochs then
    stop training.
  end if
end for
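The control flow of Algorithm 1 (size-proportional domain sampling plus patience-based early stopping) can be sketched as below. This is only a skeleton under stated assumptions: `run_step` and `validate` are hypothetical callbacks standing in for the encoder/decoder updates on objective 3.2.6 and for the validation scoring, neither of which is reproduced here, and a higher validation score is assumed to be better.

```python
import numpy as np

def train(domain_sizes, run_step, validate, max_epochs=100, patience=5, seed=0):
    """Skeleton of the outer loop of Algorithm 1.

    domain_sizes: training-set size N_m of each source domain.
    run_step(m):  hypothetical callback updating encoder and decoder on domain m.
    validate():   hypothetical callback returning the summed validation score.
    """
    rng = np.random.default_rng(seed)
    sizes = np.asarray(domain_sizes, dtype=float)
    weights = sizes / sizes.sum()               # w_m = N_m / N
    n_domains = len(sizes)

    best_score, stale_epochs, history = -np.inf, 0, []
    for _ in range(max_epochs):
        for _ in range(n_domains):
            m = rng.choice(n_domains, p=weights)  # m ~ Cat(w_1, ..., w_M)
            run_step(int(m))
        score = validate()
        history.append(score)
        if score > best_score:
            best_score, stale_epochs = score, 0
        else:
            stale_epochs += 1
            if stale_epochs >= patience:          # no improvement for k epochs
                break
    return weights, history
```

Sampling domains in proportion to their size means that, in expectation, each training sample is visited equally often regardless of which domain it belongs to.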

Appendix E Experiments

E.1 Metrics

Classification

For the classification task, we report two scores: the Area Under the ROC Curve (AUC) and the Average Precision-Recall score (APR). An ROC (receiver operating characteristic) curve plots the True Positive Rate (TPR) against the False Positive Rate (FPR) at different classification thresholds, showing the performance of a classifier in a more balanced and robust manner. APR summarises a precision-recall curve as the weighted mean of the precision attained at each pre-defined threshold, with the increase in recall from the previous threshold used as the weight:

$$
APR=\sum_{n}(R_{n}-R_{n-1})P_{n} \tag{47}
$$

where $P_{n}$ and $R_{n}$ are the precision and recall at the $n$-th threshold. Both AUC and APR are computed from the classifier's predicted probabilities and the true labels in binary classification.
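To make (47) concrete, here is a small NumPy sketch. It assumes binary labels with at least one positive, distinct predicted scores, and that every predicted score is used as a threshold; the function name `average_precision` is chosen for this example only.

```python
import numpy as np

def average_precision(y_true, y_score):
    """Sum of (R_n - R_{n-1}) * P_n over thresholds taken at each score."""
    y_true = np.asarray(y_true)
    order = np.argsort(-np.asarray(y_score), kind="stable")  # descending scores
    hits = y_true[order]
    tp = np.cumsum(hits)                             # true positives so far
    precision = tp / np.arange(1, len(hits) + 1)     # P_n
    recall = tp / hits.sum()                         # R_n
    recall_prev = np.concatenate(([0.0], recall[:-1]))
    return float(np.sum((recall - recall_prev) * precision))
```

Only thresholds at which a positive is retrieved change the recall, so negatives contribute to the score solely by lowering the precision of later terms.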

Regression

For the regression task, we report the coefficient of determination ($R^{2}$), the proportion of the variation in $Y$ that is predictable from $\mathbf{X}$.
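For reference, the usual definition $R^{2}=1-SS_{res}/SS_{tot}$ can be computed as in the sketch below. The function name `r_squared` is an assumption of this example, and the code assumes the targets are not all identical (otherwise the denominator is zero).

```python
import numpy as np

def r_squared(y_true, y_pred):
    """Coefficient of determination: 1 - residual SS / total SS."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return float(1.0 - ss_res / ss_tot)
```

A perfect predictor scores 1, predicting the mean of $Y$ scores 0, and a predictor worse than the mean can score below 0.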

E.2 Benchmark: Domain-invariant Representation Methods

Here we give a more detailed description of adversarial methods for UDA with implicit alignment. Please refer to Figure 10 for a general idea of this class of methods.

Figure 10: Illustration of MDAN (extracted from the original paper): $\{S_{i}\}_{i}^{k}$ and $T$ are the indices of the source domains and the target domain respectively. The model has a shared multi-layer feature extractor (the same as the hidden layers in a plain MLP). The extracted feature vector is then used to reconstruct the label, against which the training loss is minimised over multiple source domains. In the meantime, the feature vector of an instance in source domain $S_{i}$ is also used to fool the specific domain classifier that intends to distinguish feature vectors from $S_{i}$ and the target domain $T$.

We have also included the data-driven unsupervised domain adaptation method proposed by (Zhang et al., 2020) in additional comparison experiments. Because it requires two-stage learning and many more parameters than our approach, we do not include it in the main text. For more details, please see Figure 18 for the experiments on synthetic regression datasets.

E.3 Synthetic Dataset Generation

We construct two synthetic datasets, for the classification and regression tasks respectively: the classification dataset follows a DAG learned from the MAGGIC dataset, and the regression dataset is generated from our own DAG design in Figure 11. The general generation algorithm is given in Algorithm 2.

Algorithm 2 Generation Algorithm for Synthetic Datasets
Input: Random seed for sampling; number of domains $M$; required hyper-parameters $N$ and $\sigma^{2}$.
Output: Synthetic datasets.
for $m\in[M]$ do
  Sample an environmental variable $\mathbf{E}_{m}\sim\mathcal{N}(0,\sigma^{2})$;
  Sample a domain size $N_{m}\sim Pois(N)$;
  for $i\in[N_{m}]$ do
    Generate classification data according to (48);
    Generate regression data according to (49).
  end for
end for

E.3.1 Classification

We take the learned causal graph in (Kyono & van der Schaar, 2019) as the ground truth for the synthetic classification data (as shown in the right part of Figure 11). The features of this made-up dataset carry explicit real-world meaning, so they are generated to be compatible with reality to some extent (e.g. the variable types, ranges of values, and positive and negative causal relations should respect real-world constraints, such as ages not being negative). We use 8 features to predict $Y$, the 5-year survival of "made-up" patients. These features are $X_{1}$: age of the patient; $X_{2}$: ethnicity of the patient; $X_{3}$: Angina; $X_{4}$: Myocardial Infarction; $X_{5}$: ACE Inhibitors; $X_{6}$: NYHA1; $X_{7}$: NYHA2; $X_{8}$: NYHA3. Equations (48) below give more details about their distributions and causal relationships.

$$
\left\{\begin{aligned}
X_{1}&\sim Pois(65+0.5\cdot E)\\
X_{2}&\sim Bernoulli(0.3-0.025\cdot E)\\
X_{3}&\sim Bernoulli(0.2)\\
X_{4}&\sim Bernoulli(\sigma(-0.5+0.2\cdot E+1.3\cdot X_{3}))\\
X_{5}&\sim Bernoulli(\sigma(-1+0.3\cdot E+0.015\cdot X_{1}+0.001\cdot X_{2}+1.5\cdot X_{3}))\\
X_{6}&\sim Bernoulli(0.175-0.015\cdot E)\\
X_{7}&\sim Bernoulli(0.3)\cdot\mathbf{I}_{X_{6}=0}\\
X_{8}&\sim Bernoulli(0.6)\cdot\mathbf{I}_{X_{6}+X_{7}=0}\\
T&\sim\log\mathcal{N}(1.5+0.4E-0.1(X_{1}-65)-0.05X_{2}-1.75X_{3}-2.5X_{4}\\
&\qquad+0.6X_{5}+0.25X_{6}-0.75X_{7}-2X_{8},\,1)\\
Y&=\mathbf{I}_{T>5}
\end{aligned}\right. \tag{48}
$$

where $T$ is an intermediate variable for deriving $Y$ that does not appear as a feature in the dataset, and $\log\mathcal{N}(\cdot,\cdot)$ stands for the log-normal distribution.
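The structural equations in (48) can be sampled for one domain along the following lines. This is a hedged sketch rather than the generator used in the experiments: the function names are invented for the example, and clipping the Bernoulli probabilities to $[0,1]$ and the Poisson rate to be non-negative is an added assumption to keep the sampler valid for extreme values of $E$.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def sample_classification_domain(E, n, rng):
    """Draw n samples of (X_1..X_8, Y) for a single domain with environment E."""
    clip = lambda p: float(np.clip(p, 0.0, 1.0))   # assumption: keep probabilities valid
    X1 = rng.poisson(max(65 + 0.5 * E, 0.0), size=n)
    X2 = rng.binomial(1, clip(0.3 - 0.025 * E), size=n)
    X3 = rng.binomial(1, 0.2, size=n)
    X4 = rng.binomial(1, sigmoid(-0.5 + 0.2 * E + 1.3 * X3))
    X5 = rng.binomial(1, sigmoid(-1 + 0.3 * E + 0.015 * X1 + 0.001 * X2 + 1.5 * X3))
    X6 = rng.binomial(1, clip(0.175 - 0.015 * E), size=n)
    X7 = rng.binomial(1, 0.3, size=n) * (X6 == 0)          # zero whenever X6 = 1
    X8 = rng.binomial(1, 0.6, size=n) * ((X6 + X7) == 0)   # zero whenever X6 or X7 = 1
    mu = (1.5 + 0.4 * E - 0.1 * (X1 - 65) - 0.05 * X2 - 1.75 * X3 - 2.5 * X4
          + 0.6 * X5 + 0.25 * X6 - 0.75 * X7 - 2 * X8)
    T = rng.lognormal(mean=mu, sigma=1.0)   # intermediate variable, not returned
    Y = (T > 5).astype(int)
    X = np.column_stack([X1, X2, X3, X4, X5, X6, X7, X8])
    return X, Y
```

The indicator factors make the three NYHA classes mutually exclusive, which matches their real-world meaning.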

E.3.2 Regression

Figure 11: The underlying DAGs for the synthetic classification (left) and regression (right) datasets.

The second dataset, for the regression task, is generated according to the DAG in Figure 11. Its structural equations are:

$$
\left\{\begin{aligned}
X_{1}&=0.8E+\epsilon_{1}\\
X_{2}&=0.4X_{1}^{2}+\epsilon_{2}\\
X_{3}&=0.3E+0.1\exp(X_{2})+\epsilon_{3}\\
Y&=-0.5E^{2}+\log(0.3X_{1}^{2}+0.7X_{2}^{2})+\epsilon_{y}\\
X_{4}&=0.1X_{1}\cdot\sqrt{\exp(E)}+\epsilon_{4}\\
X_{5}&=-0.25E\cdot X_{4}+0.6Y+\epsilon_{5}\\
X_{6}&=-1+0.2X_{3}\cdot Y+\epsilon_{6}\\
X_{7}&=-0.6E+3X_{6}+\epsilon_{7}
\end{aligned}\right. \tag{49}
$$
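A sampler for (49) might look like the sketch below. The noise distributions are not specified in the text, so i.i.d. zero-mean Gaussian noise with a configurable standard deviation (`noise_std`, a hypothetical parameter) is assumed here purely for illustration; the function name is likewise invented.

```python
import numpy as np

def sample_regression_domain(E, n, rng, noise_std=1.0):
    """Draw n samples of (X_1..X_7, Y) for a single domain with environment E."""
    # Assumption: every noise term is N(0, noise_std^2); the paper does not state this.
    eps = rng.normal(0.0, noise_std, size=(8, n))
    X1 = 0.8 * E + eps[0]
    X2 = 0.4 * X1 ** 2 + eps[1]
    X3 = 0.3 * E + 0.1 * np.exp(X2) + eps[2]
    Y = -0.5 * E ** 2 + np.log(0.3 * X1 ** 2 + 0.7 * X2 ** 2) + eps[3]
    X4 = 0.1 * X1 * np.sqrt(np.exp(E)) + eps[4]
    X5 = -0.25 * E * X4 + 0.6 * Y + eps[5]
    X6 = -1 + 0.2 * X3 * Y + eps[6]
    X7 = -0.6 * E + 3 * X6 + eps[7]
    X = np.column_stack([X1, X2, X3, X4, X5, X6, X7])
    return X, Y
```

Note that $X_{1}$ to $X_{3}$ are ancestors of $Y$ while $X_{5}$ to $X_{7}$ are descendants, so the feature set mixes causes and effects of the label.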

E.4 Verifying Intuition on Synthetic Datasets

In this section, we use synthetic causal data to verify the close relationship between the difference in $E$ and the Wasserstein distance between two distributions, and we examine how well DAPDAG can learn $E$ and exploit it for domain adaptation.

Wasserstein Distance between two empirical distributions
Figure 12: Simple Example of Transport between Two Empirical Distributions

There is extensive work on distances between distributions, e.g. the KL-divergence and the H-divergence, and on how to use these metrics in applications. In our work, we use the Wasserstein distance to measure the distance between two empirical distributions (Panaretos & Zemel, 2019). Its formal definition is as follows: the $p$-Wasserstein distance between probability measures $\mu$ and $\nu$ on $\mathbb{R}^{d}$ is defined as

$$
W_{p}(\mu,\nu)=\inf\limits_{X\sim\mu,Y\sim\nu}(\mathbb{E}\|X-Y\|^{p})^{\frac{1}{p}},\quad p\geq 1. \tag{50}
$$

From the optimal transport perspective, a very high-level reading of this metric is the minimum effort it would take to move points of mass from one distribution to the other. Consider the simple example in Figure 12, where we want to move the points in $p(x)$ to the locations of the points in $q(x)$. There are many ways of moving them, and the arrows in the figure depict one. What we are interested in, however, is the way that requires the least effort. This can be approximated using a numerical method called Sinkhorn iterations (Cuturi, 2013). Since our focus is on DA, we skip the details of this algorithm and directly apply it to compute the distance between each pair of synthetic datasets.
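For readers who want a feel for the computation, a bare-bones version of Sinkhorn iterations between two empirical distributions is sketched below. It is an illustration only and not the code used for the figures: uniform sample weights, a squared Euclidean cost, a regularisation strength expressed relative to the largest cost (`reg`), and a fixed iteration count are all assumptions of this example.

```python
import numpy as np

def sinkhorn_cost(x, y, reg=0.1, n_iter=500):
    """Entropy-regularised transport cost between two point clouds.

    Returns the transport cost under squared Euclidean distance and the
    transport plan; the square root of the cost approximates W_2.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    a = np.full(len(x), 1.0 / len(x))      # uniform mass on source points
    b = np.full(len(y), 1.0 / len(y))      # uniform mass on target points
    C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=2)
    K = np.exp(-C / (reg * C.max()))       # Gibbs kernel
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)                    # rescale to match row marginals
        v = b / (K.T @ u)                  # rescale to match column marginals
    P = u[:, None] * K * v[None, :]        # transport plan
    return float((P * C).sum()), P
```

Smaller values of `reg` bring the result closer to the unregularised optimal transport cost, at the price of slower convergence and a higher risk of numerical underflow.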

Visualisation of Results
Figure 13: Comparisons between E differences and Sinkhorn distances of synthetic classification and regression datasets (numbers in the x-axis stand for the indices of domain pairs, e.g. if we have total 9 domains, we will have 36 domain pairs indexed from 0 to 35.)
Figure 14: Comparisons between E differences and Sinkhorn distances of features from synthetic classification and regression datasets

Figure 13 shows an interesting fact: the difference between the $E$ values used to generate two synthetic datasets can be regarded as a good proxy for the Wasserstein distance between these two datasets. For the regression data, the absolute difference in $E$ almost fully coincides with the log of the Wasserstein distance in terms of both values and fluctuations. Since our method uses the features in the target domain for adaptation, we also plot the relationship between the $E$ difference and the Wasserstein distance of features in Figure 14. As shown in the plots, ignoring the labels barely affects the relationship. This finding provides strong evidence for using only features to detect the distribution difference and adjust for the shift accordingly.

Figure 15: Comparisons between E differences and Sinkhorn distances of synthetic classification datasets with and without standardisation.

However, on the classification dataset, despite the resemblance in fluctuations, the true values of the $E$ differences deviate to some extent from the distances computed on both the full variables and the features only. We can alleviate this issue by standardising the features with large scales. After standardisation, the distance better captures the variation in the $E$ difference, as illustrated in Figure 15.

Capturing the $E$ difference

How well can our method learn the $E$ difference, and thereby enable domain adaptation? We observe that as the number of sources increases, the learned $E$ difference tracks the trend of the true $E$ difference more closely, as shown in Figure 16. This supports the benefit of training on more sources for adaptation.

Figure 16: Comparisons between E differences and learned E differences from synthetic regression datasets with different number of source domains used (M is the number of source domains and numbers in the x-axis stand for the indices of domain pairs, here for better graphical presentation, we show 15 randomly-selected domain pairs).

E.5 Ablation Studies

We conduct ablation studies on the various loss components in (4), except for the regularisation loss on $E$, to better understand the sources of performance gain. Note that the comparison with CASTLE can be considered an ablation study on the encoder, and once $E$ is introduced, its square should be regularised, as proved in C.2, to lower the generalisation bound of the decoder during training. Therefore, a separate ablation study on the squared regularisation term in (7) is not necessary. Besides, the comparison with BRM serves as an ablation study on the structural filters. Both DAPDAG and CASTLE have as many structural filters as there are variables, and these filters contribute to the reconstruction of each variable and to DAG learning. In BRM, however, there is only one filter, which selects features locally for the target variable.

| Methods | AUC (M=3) | APR (M=3) | AUC (M=5) | APR (M=5) | AUC (M=7) | APR (M=7) |
|---|---|---|---|---|---|---|
| Dag+Spa | 0.947(0.021) | 0.814(0.091) | 0.954(0.017) | 0.820(0.075) | 0.958(0.015) | 0.827(0.063) |
| Rec+Dag | 0.959(0.006) | 0.825(0.072) | 0.961(0.004) | 0.849(0.063) | 0.962(0.004) | 0.872(0.049) |
| Rec+Spa | 0.960(0.005) | 0.845(0.069) | 0.962(0.005) | 0.854(0.055) | 0.963(0.003) | 0.890(0.044) |
| Rec+Dag+Spa | 0.961(0.004) | 0.856(0.036) | 0.964(0.004) | 0.883(0.033) | 0.965(0.003) | 0.893(0.031) |

Table 1: Ablation studies on the synthetic classification dataset. M is the number of source domains. AUC and APR scores are averaged over multiple runs, with the respective standard deviation in parentheses; in each run, a target domain and the source domains are randomly selected from a pool of domains. (Dag: NOTEARS regularisation; Spa: group-lasso regularisation on the structural filters; Rec: reconstruction loss of all observed variables.)
| Methods | $R^{2}$ (M=3) | $R^{2}$ (M=5) | $R^{2}$ (M=7) |
|---|---|---|---|
| Dag+Spa | 0.422(0.325) | 0.444(0.306) | 0.495(0.258) |
| Rec+Dag | 0.479(0.254) | 0.508(0.221) | 0.545(0.173) |
| Rec+Spa | 0.486(0.231) | 0.510(0.209) | 0.558(0.166) |
| Rec+Dag+Spa | 0.501(0.200) | 0.533(0.167) | 0.572(0.142) |

Table 2: Ablation studies on the synthetic regression dataset. M is the number of source domains. The $R^{2}$ scores are averaged over multiple runs, with the respective standard deviation in parentheses; in each run, a target domain and the source domains are randomly selected from a pool of domains. (Dag: NOTEARS regularisation; Spa: group-lasso regularisation on the structural filters; Rec: reconstruction loss of all observed variables.)
Figure 17: Ablation studies on synthetic datasets. R stands for only including the reconstruction loss in the DAG loss; D stands for only including the DAG constraint in the DAG loss; S stands for only including the sparsity loss in the DAG loss.

The comparison results in Figure 17 verify the importance of the structural filters and of regularising these filters to form a DAG. On the regression dataset, if we do not reconstruct each variable, DAPDAG performs even worse than BRM, which has a much simpler structure. Therefore, reconstructing all variables brings significant benefit to prediction, while the DAG and sparsity constraints further improve the model's robustness across different domains.

E.6 Scalability

We have extended the simulations to higher-dimensional cases; see Figure 18 for more information. In the right plot, we compare the training time of our method with that of two UDA benchmarks across data dimensions. Despite the minor gap between our approach and (Zhang et al., 2020), ours needs considerably less training time than theirs in high-dimensional settings.

Figure 18: Performance and Scalability

E.7 Processing of MAGGIC Dataset

Data pre-processing and domain selection can exert considerable influence on testing performance, because these datasets have extensive missing values or missing features in each study, while the usual data imputation methods tend to impute those missing values without taking domain distinctions into account. We acknowledge that future work is needed here, either on better data imputation or on UDA methodology that can deal with feature mismatch.

Imputation of Missing Values

Although MAGGIC contains a large number of instances, each study tends to have many missing values in certain features, or even a number of entirely missing features, which significantly violates our assumptions of causal sufficiency and feature match. Hence it is imperative to process these datasets before use. We omit those with missing labels and use MissForest (Stekhoven & Bühlmann, 2012), a non-parametric missing value imputation method for mixed-type data, to impute missing feature values. We first iterate the imputation of missing values over studies. When imputing missing entries in a study, we rely on the other features of that study as much as possible. For missing features that cannot be imputed from the single study, we resort to other studies in which those features are available. For binary features, the imputed values are fractions between 0 and 1, which are mapped to 0 or 1 with a threshold of 0.5.
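The final thresholding step can be sketched as follows. The imputation itself is not shown; the function name `binarise_imputed` and the convention that a value of exactly 0.5 maps to 1 are assumptions of this example, since the text does not specify how ties are handled.

```python
import numpy as np

def binarise_imputed(X, binary_cols, threshold=0.5):
    """Map imputed fractional values in binary columns back to {0, 1}."""
    X = np.array(X, dtype=float)           # copy, so the input is left untouched
    X[:, binary_cols] = (X[:, binary_cols] >= threshold).astype(float)
    return X
```

Observed entries that are already 0 or 1 pass through this step unchanged, so it can be applied to the whole imputed matrix.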

Selection of Domains

We then select, for the experiments, the processed studies that originally had the fewest missing features, because those studies should be least affected by data imputation and should best preserve the distribution shift from other domains. The selected studies are shown in Table 3, with the dataset size following in parentheses.

Standardisation of Non-binary features

All continuous features are standardised with mean of 0 and variance of 1, just following the same procedure as we do for synthetic classification dataset.
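A minimal sketch of this standardisation is given below. Whether the mean and variance are computed per study or over the pooled data is not stated in the text; the example simply uses whatever array it is given, and the function name `standardise` and the guard for constant columns are assumptions added for illustration.

```python
import numpy as np

def standardise(X, continuous_cols):
    """Rescale the chosen columns to zero mean and unit variance."""
    X = np.array(X, dtype=float)           # copy, so the input is left untouched
    cols = X[:, continuous_cols]
    std = cols.std(axis=0)
    std[std == 0] = 1.0                    # assumption: leave constant columns centred only
    X[:, continuous_cols] = (cols - cols.mean(axis=0)) / std
    return X
```

Binary columns are deliberately left out of `continuous_cols` so that they keep their 0/1 coding.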

Meanwhile, a recent work (Kaiser & Sipos, 2021) claims that continuous-optimisation (differentiable) methods of causal discovery such as NOTEARS may not work well on datasets with varying scales. They observe that learning results are inconsistent with respect to data scaling: variables with larger scales or variances tend to be child nodes. Standardising the data with large scales can alleviate the problem to some extent.

E.8 Supporting Experimental Figures and Plots

E.8.1 Comparison against Benchmarks on MAGGIC Datasets

AUC Scores

| Target Study | Deep MLP | MDAN | BRM | MissForest | CASTLE | DAPDAG-B |
|---|---|---|---|---|---|---|
| AHFMS (99.7) | 0.782(0.012) | 0.785(0.013) | 0.811(0.010) | 0.819(0.006) | 0.826(0.019) | 0.854(0.011) |
| BATTL (58.8) | 0.692(0.016) | 0.707(0.020) | 0.735(0.014) | 0.765(0.007) | 0.768(0.013) | 0.790(0.012) |
| Hilli (111.5) | 0.687(0.014) | 0.695(0.020) | 0.699(0.014) | 0.711(0.012) | 0.713(0.005) | 0.730(0.013) |
| Kirk (46.9) | 0.868(0.005) | 0.891(0.009) | 0.931(0.012) | 0.936(0.010) | 0.954(0.008) | 0.970(0.009) |
| Macin (361) | 0.569(0.010) | 0.567(0.007) | 0.619(0.015) | 0.588(0.014) | 0.625(0.017) | 0.646(0.015) |
| Mim B (59.4) | 0.596(0.011) | 0.578(0.014) | 0.612(0.018) | 0.647(0.013) | 0.616(0.024) | 0.659(0.016) |
| NPC I (71.5) | 0.517(0.016) | 0.524(0.011) | 0.540(0.017) | 0.542(0.021) | 0.533(0.011) | 0.571(0.020) |
| Richa (36.6) | 0.712(0.012) | 0.711(0.013) | 0.739(0.013) | 0.782(0.011) | 0.757(0.012) | 0.775(0.017) |
| SCR A (54.9) | 0.706(0.017) | 0.698(0.024) | 0.710(0.019) | 0.675(0.009) | 0.691(0.022) | 0.728(0.018) |
| Tribo (56.6) | 0.760(0.006) | 0.771(0.010) | 0.766(0.016) | 0.769(0.012) | 0.788(0.015) | 0.799(0.010) |

Table 3: Classification performance of DAPDAG-B on the MAGGIC dataset against other benchmarks for each target study in the selection pool. For each target study, the training domains are the remaining 9 studies in the selection pool, and its average distance to the sources is given after its name in the first column. The performance scores are the AUC averaged over 10 replicates, with the standard deviation in parentheses. Bold denotes the best.
APR Scores

| Target Study | Deep MLP | MDAN | BRM | MissForest | CASTLE | DAPDAG |
|---|---|---|---|---|---|---|
| AHFMS (196) | 0.914(0.011) | 0.921(0.008) | 0.930(0.014) | 0.936(0.007) | 0.938(0.009) | 0.949(0.009) |
| BATTL (363) | 0.947(0.013) | 0.953(0.010) | 0.955(0.009) | 0.965(0.003) | 0.966(0.004) | 0.970(0.006) |
| Hilli (176) | 0.853(0.007) | 0.865(0.013) | 0.866(0.010) | 0.869(0.006) | 0.881(0.002) | 0.897(0.008) |
| Kirk (215) | 0.923(0.007) | 0.938(0.006) | 0.952(0.012) | 0.967(0.004) | 0.969(0.002) | 0.972(0.007) |
| Macin (228) | 0.506(0.012) | 0.514(0.019) | 0.517(0.017) | 0.497(0.017) | 0.547(0.014) | 0.581(0.016) |
| Mim B (282) | 0.812(0.009) | 0.821(0.014) | 0.825(0.016) | 0.837(0.007) | 0.819(0.021) | 0.846(0.013) |
| NPC I (66) | 0.528(0.011) | 0.529(0.019) | 0.551(0.018) | 0.546(0.013) | 0.565(0.023) | 0.569(0.018) |
| Richa (627) | 0.879(0.008) | 0.884(0.011) | 0.894(0.011) | 0.921(0.010) | 0.912(0.006) | 0.918(0.010) |
| SCR A (324) | 0.959(0.011) | 0.952(0.012) | 0.963(0.013) | 0.965(0.008) | 0.967(0.013) | 0.975(0.012) |
| Tribo (663) | 0.914(0.005) | 0.920(0.014) | 0.927(0.012) | 0.921(0.009) | 0.935(0.010) | 0.939(0.011) |

Table 4: Classification performance of DAPDAG-B on the MAGGIC dataset against other benchmarks for each target study in the selection pool. For each target study, the training domains are the remaining 9 studies in the selection pool, and its sample size is given after its name in the first column. The performance scores are the APR averaged over 10 replicates, with the standard deviation in parentheses. Bold denotes the best.