Estimating counterfactual treatment outcomes over time in complex multiagent scenarios

Keisuke Fujii,  Koh Takeuchi, Atsushi Kuribayashi, Naoya Takeishi,
Yoshinobu Kawahara, and Kazuya Takeda
K. Fujii is with Graduate School of Informatics at Nagoya University, Japan and RIKEN Center for Advanced Intelligence Project, Japan. e-mail: fujii@i.nagoya-u.ac.jp. K. Takeuchi is with Graduate School of Informatics at Kyoto University and RIKEN. A. Kuribayashi and K. Takeda are with Graduate School of Informatics at Nagoya University. N. Takeishi is with Graduate School of Engineering at the University of Tokyo and RIKEN Center for Advanced Intelligence Project. Y. Kawahara is with Graduate School of Information Science and Technology at Osaka University and RIKEN Center for Advanced Intelligence Project.
Abstract

Evaluation of intervention in a multiagent system, e.g., when humans should intervene in autonomous driving systems and when a player should pass to teammates for a good shot, is challenging in various engineering and scientific fields. Estimating the individual treatment effect (ITE) using counterfactual long-term prediction is practical to evaluate such interventions. However, most of the conventional frameworks did not consider the time-varying complex structure of multiagent relationships and covariate counterfactual prediction. This may lead to erroneous assessments of ITE and difficulty in interpretation. Here we propose an interpretable, counterfactual recurrent network in multiagent systems to estimate the effect of the intervention. Our model leverages graph variational recurrent neural networks and theory-based computation with domain knowledge for the ITE estimation framework based on long-term prediction of multiagent covariates and outcomes, which can confirm the circumstances under which the intervention is effective. On simulated models of an automated vehicle and biological agents with time-varying confounders, we show that our methods achieved lower estimation errors in counterfactual covariates and the most effective treatment timing than the baselines. Furthermore, using real basketball data, our methods performed realistic counterfactual predictions and evaluated the counterfactual passes in shot scenarios.

Index Terms:
multiagent modeling, causal inference, deep generative model, trajectory data, autonomous vehicle, sports.

I Introduction

Evaluation of intervention in a real-world multiagent system is a fundamental problem in a variety of engineering and scientific fields. For example, a human driver in an autonomous vehicle, a player in team sports, and an experimenter on animals intervene in multiagent systems to obtain desirable results (e.g., safe driving, a good shot, and specific behavior, respectively) as shown in Fig. 1. Given these processes and the complex interactions between agents, it is often difficult to estimate the intervention (or treatment) effect, which requires comparing the outcomes with and without interventions. There have been many methods to estimate the individual treatment effect (ITE), which evaluates the causal effect of treatment strategies on some important outcomes at the individual level in various fields (e.g., [1, 2, 3]). In particular, some methods have been proposed to deal with time-varying [4, 5] and hidden confounders [6, 7, 8].

However, most of the conventional frameworks did not consider the time-varying complex structures of multiagent relationships and counterfactual covariate predictions, which may lead to erroneous assessments of ITE and difficulty of interpretation. The structures of multiagent relationships include bottom-up ones based on interactions between local agents (often represented as a graph), and top-down ones described by global statistics and/or theories with domain knowledge [9, 10]. Since real-world multiagent systems do not usually have explicit governing equations, both top-down and bottom-up approaches can be simultaneously required for the modeling [11, 12]. That is, in our problem, we need to model multiagent systems in both data-driven and theory-based approaches. In addition, for extracting insights from the ITE estimation results, the circumstances where intervention is effective must be analyzed. Therefore, plausible and interpretable modeling of time-varying multiagent systems is required for estimating ITE.

In this paper, we propose a novel causal inference framework called Theory-based and Graph Variational Counterfactual Recurrent Network (TGV-CRN), which estimates the effect of intervention in multiagent systems in interpretable ways. The causal graph of the problem setting is illustrated in Fig. 2, where the hidden confounders at a particular time stamp not only have causal relations to the observed variables at the same time stamp but also are causally affected by the hidden confounders from previous time stamps. To model the hidden confounders and other causal relationships from data for the ITE estimation framework, we leverage graph variational recurrent neural networks (GVRNNs) [13] to represent local agent interactions and theory-based computation for global properties of the multiagent behaviors based on domain knowledge. Because this framework is based on the long-term prediction of multiagent covariates and outcomes, it can confirm under what circumstances the intervention is effective.

Our general motivation is to estimate ITE over time in complex multiagent scenarios. Specifically, in team sports, decision-making skills (e.g., whether a player with the ball should perform a pass or shot) are important. However, we cannot observe both patterns as data (i.e., performing a pass and a shot) in the same situation. For autonomous driving or animal behavioral science, in real-world scenarios, we likewise cannot obtain both data (i.e., with and without intervention) in the same situation (note that in the numerical experiments, we used synthetic data including both cases). For animal behavioral science, moreover, scientific experiments can hardly be performed at all possible intervention timings during movements and sometimes include the bias of the experimenters.

In summary, our main contributions are as follows. (1) We proposed a novel counterfactual recurrent network called TGV-CRN, which estimates the effect of intervention in multiagent systems in interpretable ways, compared with the previous counterfactual recurrent networks dealing with time-varying treatment [4, 5, 6, 7, 8]. (2) Methodologically, for the ITE estimation framework based on long-term prediction of multiagent covariates and outcomes, our model leverages GVRNN to represent local agent interactions and theory-based computation, which was not considered in the above previous work. This framework can confirm under what circumstances the intervention is effective. (3) In experiments using two simulated models of an automated vehicle and biological agents, we show that our methods achieved lower errors in estimating counterfactual covariates and the most effective treatment timing than the baselines. Furthermore, using real basketball data, our methods performed realistic counterfactual predictions. All of these subjects moved interactively in multiagent systems, which are not dealt with in the previous work. We extend our previous short paper [14] by adding theoretical background, experimental results of a synthetic Boid dataset and a real-world basketball dataset, and a sensitivity analysis using the CARLA dataset.

The remainder of this paper is organized as follows. First, in Section II, we describe our problem definition. Next, we describe our methods in Section III. We overview the related works in Section IV, present experimental results in Section V, and conclude this paper in Section VI.

Figure 1: The illustrations of our problems. Interventions in (A) an autonomous vehicle simulation, (B) a biological agent simulation, and (C) a real basketball game are shown. In (A) and (C), a single agent (A: an ego-vehicle and C: a ball player) is intervened whereas multiple agents (all boids) are intervened in (B). The motivations and variable definitions are described in the Introduction and Background sections. In short, we aim to perform long-term counterfactual prediction of outcomes and covariates from the past covariates and intervention (or treatment assignment). In (A), our approach can test autonomous driving software including human interventions without creating the same situations for controlled trials (as in real-world scenarios). The outcome is a safe driving distance. In (B), our approach has the potential to estimate the effect of an experimenter’s interventions on multi-animal behaviors, which improves the efficiency of experimental procedures for observing desired movements. The outcome is the angular velocity of multiple agents. In (C), our approach can estimate the effect of the selection of passes in basketball shot scenarios, and thus we can evaluate the decision-making skills in this situation (e.g., during a game). The outcome is the effectiveness of an attack.

II Background

In this section, we first give definitions of the notations used throughout the paper, present the assumptions of our methods for estimating ITE, and then introduce VRNNs (variational recurrent neural networks) [15] and GNN (graph neural network) [16] we used.

II-A Preliminary

The multiagent observational data is denoted as $X_{t},A_{t},Y_{t}$ at time stamp $t$. Examples of these variables are shown in Fig. 1. Let $X_{t}$ be the time-dependent covariates of the observational data such that $X_{t}=\{x^{(1)}_{t},\ldots,x^{(n)}_{t}\}$, where $x^{(i)}_{t}$ denotes the covariates for the $i$-th multiagent sample with $K$ agents, and $n$ denotes the number of samples. The relationships between agents are represented by the theory-based computation and GVRNN described in Section III. Although some related papers [5, 7] consider the static covariates $C$, which do not change over time, here we do not explicitly consider $C$ because we can easily formulate and implement our methods to add $C$ by conditioning. The treatment (or intervention) assignments are denoted as $A_{t}=\{a^{(1)}_{t},a^{(2)}_{t},\ldots,a^{(n)}_{t}\}$, where $a^{(i)}_{t}$ denotes the treatments assigned in the $i$-th sample. We consider $a^{(i)}_{t}\in\{0,1\}$, where $1$ is considered treated whereas $0$ is the control (i.e., a binary treatment setting), and we estimate the effect of the treatment assigned at time stamp $t$ on the outcomes $Y_{t+1}=\{y^{(1)}_{t+1},y^{(2)}_{t+1},\ldots,y^{(n)}_{t+1}\}$ at time stamp $t+1$. Note that in observational data, a multiagent sample can only belong to one group (i.e., either a treated or control group), thus the outcome from the other group is always missing and referred to as counterfactual. To represent the historical sequential data before time stamp $t$, we use the notation $\overline{X}_{t}=\{X_{1},X_{2},\ldots,X_{t-1}\}$ to denote the history of covariates observed before time stamp $t$, and $\overline{A}_{t}$ refers to the history of treatment assignments. Combining all covariates and treatments, we define $\mathcal{H}^{(i)}_{t}=\{\overline{x}^{(i)}_{t},\overline{a}^{(i)}_{t}\}$ as all the historical data collected before time stamp $t$.

We adopt the potential outcomes framework (e.g., [17]), as extended by [18] to account for time-varying treatments. The potential outcome $y^{(i)}_{a_{t}=a,t+1}$ of the $i$-th sample given the historical treatment can be formulated as $y^{(i)}_{a_{t}=a,t+1}=\mathbb{E}[y|x^{(i)}_{t},\mathcal{H}^{(i)}_{t},a_{t}=a]$, where $a\in\{0,1\}$. Then the ITE on the temporal observational data is defined as:

$$\tau^{(i)}_{t}=\mathbb{E}[y^{(i)}_{a_{t}=1,t+1}-y^{(i)}_{a_{t}=0,t+1}|x^{(i)}_{t},\mathcal{H}^{(i)}_{t}]. \tag{1}$$

Here, the observed outcome $y^{(i)}_{a_{t}=a,t+1}$ under treatment $a$ is called the factual outcome, while the unobserved one $y^{(i)}_{a_{t}=1-a,t+1}$ is the counterfactual outcome.
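As a minimal illustration of Eq. (1), the following sketch computes the ITE as the difference between two predicted potential outcomes. The function `predict_outcome` is a hypothetical stand-in for a learned outcome model and is not part of the proposed method; its linear form and coefficients are assumptions chosen only to make the example runnable.

```python
from typing import Sequence


def predict_outcome(x_t: Sequence[float], history: Sequence[float], a: int) -> float:
    """Hypothetical stand-in for E[y | x_t, H_t, a_t = a] (toy linear form, an assumption)."""
    base = 0.5 * sum(x_t) + 0.1 * sum(history)
    return base + 2.0 * a  # toy additive treatment effect of 2.0


def estimate_ite(x_t: Sequence[float], history: Sequence[float]) -> float:
    """ITE at time t: predicted outcome under a_t = 1 minus that under a_t = 0."""
    y1 = predict_outcome(x_t, history, a=1)  # potential outcome if treated
    y0 = predict_outcome(x_t, history, a=0)  # potential outcome if not treated
    return y1 - y0


tau = estimate_ite(x_t=[1.0, 2.0], history=[0.5, 0.5])
print(tau)
```

In observational data only one of the two potential outcomes is ever observed, so in practice at least one of `y1` and `y0` must come from a model.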

Figure 2: The illustration of causal graphs for our problem. We denote $X_{t},Z_{t},A_{t},Y_{t+1}$ as the dynamic covariates, representations of hidden confounders, treatment assignment, and outcomes, respectively. The black lines indicate the causal relations. The hidden confounders $Z_{t+1}$ usually affect the treatment assignment $A_{t+1}$, the outcome $Y_{t+2}$, and the covariate $X_{t}$. To infer $Z_{t+1}$, we can leverage the observational data $X_{t+1}$ and previous hidden confounders $Z_{t}$.

II-B Assumptions

Our estimation of ITE is based on the following standard assumptions [19, 4, 20], and we further extend the assumptions in our scenario to include time-varying observational data.

Assumption II.1 (Consistency).

The potential outcome under treatment history $\overline{A}$ is equal to the observed outcome if the actual treatment history is $\overline{A}$.

Assumption II.2 (Positivity).

For any sample $i$, if the probability $p(\overline{a}^{(i)}_{t-1},\overline{x}^{(i)}_{t})\neq 0$, then the probability of receiving treatment $0$ or $1$ is positive, i.e., $0<p(\overline{a}^{(i)}_{t},\overline{x}^{(i)}_{t})<1$, for all $\overline{a}^{(i)}_{t}$.

Assumption II.2 means that, for each time $t$, each treatment has a non-zero probability of being assigned. Besides these two assumptions, much of the existing work is based on the strong ignorability assumption as follows:

Definition 1 (Sequential Strong Ignorability).

Given the observed historical covariates $\overline{x}^{(i)}_{t}$ of the $i$-th sample, the potential outcome variables $\{y^{(i)}_{a_{t}=0,t+1},y^{(i)}_{a_{t}=1,t+1}\}$ are independent of the treatment assignment, i.e., $\{y^{(i)}_{a_{t}=0,t+1},y^{(i)}_{a_{t}=1,t+1}\}\perp\!\!\!\!\perp a^{(i)}_{t}|\overline{x}^{(i)}_{t}$.

Definition 1 means that there are no hidden confounders, i.e., all covariates affecting both the treatment assignment and the outcomes are present in the observational dataset. However, this condition is difficult to guarantee in practice, especially in real-world observational data (in other words, it is not testable in practice [19, 21]). In this paper, we relax such a strict assumption by acknowledging potential hidden confounders. Our proposed methods can learn the representations of the hidden confounders and eliminate the bias between the treatment assignments and outcomes at each time stamp.

In our approach, the learned representations (denoted by $Z_{t}=\{z^{(1)}_{t},z^{(2)}_{t},\ldots,z^{(n)}_{t}\}$) can be leveraged to infer the unobserved confounders and act as substitutes of hidden confounders. That is, we extend the strong ignorability assumption by considering the existence of hidden confounders $Z_{t}$, which influence the treatment assignment $A_{t}$ and potential outcomes $Y_{t+1}$. Given the hidden confounders $Z_{t}$, the potential outcome variables are independent of the treatment assignment at each time stamp. We aim to learn the representations of hidden confounders $Z_{t}$ for bias elimination based on the following assumptions [8]:

Assumption II.3 (Existence of Hidden Confounders).

(i) The hidden confounders may not be accessible, but the covariates are correlated with the hidden confounders, and can be considered as proxy variables, and (ii) hidden confounders at each time stamp are also influenced by the hidden confounders and treatment assignments from previous time stamps.

Based on the premise, we study the identification of ITE:

Theorem 1 (Identification of ITE).

If we recover $p(z_{t}^{(i)}|x_{t}^{(i)},\mathcal{H}_{t}^{(i)})$ and $p(y_{t+1}^{(i)}|z_{t}^{(i)},a_{t}^{(i)})$, then the proposed methods can recover the ITE under the causal graph in Fig. 2.

We provide a proof in Appendix G. For simplicity, the sample superscript $(i)$ will be omitted unless explicitly needed.

Figure 3: The illustration of TGV-CRN. (A) TGV-CRN aims to estimate ITE based on long-term prediction of multiagent covariates and outcomes while visualizing the long-term future covariate prediction. TGV-CRN leverages GVRNN to represent local agent interactions and theory-based functions for covariate and outcome prediction, which can confirm under what circumstances the intervention is effective. Specifically, (B) the training and inference processes of GNN encoder, prior, and decoder are illustrated. At each time stamp, the model takes the current covariates and treatment assignments as input to learn representations of the hidden confounders via GRNNs and GNN encoders. Then, via theory-based computations, the GNN decoders, and MLPs (multi-layer perceptron), the model predicts time-varying covariates, a potential outcome, and a treatment. We also use the gradient reversal layer before the treatment classifier to ensure the confounder representation distribution of the treated and that of the controlled are similar at the group level.

II-C Variational recurrent and graph neural networks

Here we explain VRNN [15] and GNN [16] used in the following section III.

VRNN. Let $x_{\leq T}=\{x_{1},\dots,x_{T}\}$ denote a sequence of variables of length $T$. The goal of sequential generative modeling is to learn the distribution over sequential data $\mathcal{D}$ consisting of multiple demonstrations. A common approach to model the trajectory is to factorize the joint distribution and then maximize the log-likelihood $\theta^{*}=\operatorname*{arg\,max}_{\theta}\sum_{x_{\leq T}\in\mathcal{D}}\sum_{t=1}^{T}\log p_{\theta}(x_{t}|x_{<t})$, where $\theta$ denotes the learnable parameters of models such as RNNs. However, RNNs with simple output distributions often struggle to capture highly variable and structured sequential data (e.g., multimodal behaviors) [22]. Recent work in sequential generative models addressed this issue by injecting stochastic latent variables into the model and optimizing with amortized variational inference to learn the latent variables (e.g., [15, 23, 24]). Among these methods, a variational RNN (VRNN) [15] has been widely used in base models for multiagent trajectories [13, 22] with unknown governing equations. A VRNN is essentially a variational autoencoder (VAE) conditioned on the hidden state of an RNN and is trained by maximizing the (sequential) evidence lower-bound (ELBO):

$$\mathcal{L}_{vrnn}=\mathbb{E}_{q_{\phi}(z_{\leq T}\mid x_{\leq T})}\Bigg[\sum_{t=1}^{T}\Big(\log p_{\theta}(x_{t}\mid z_{\leq t},x_{<t})-D_{KL}\big(q_{\phi}(z_{t}\mid x_{\leq t},z_{<t})\,||\,p_{\theta}(z_{t}\mid x_{<t},z_{<t})\big)\Big)\Bigg], \tag{2}$$

where $z_{t}$ is a stochastic latent variable of the VAE, and $p_{\theta}(x_{t}\mid z_{\leq t},x_{<t})$, $q_{\phi}(z_{t}\mid x_{\leq t},z_{<t})$, and $p_{\theta}(z_{t}\mid x_{<t},z_{<t})$ are the generative model, the approximate posterior (or inference model), and the prior model, respectively. The first term is the reconstruction term. The second term is the Kullback-Leibler (KL) divergence between the approximate posterior and the prior.
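To make the two terms of the ELBO concrete, the sketch below evaluates a single time step for diagonal Gaussian distributions, using the closed-form KL divergence between two Gaussians. It is an illustration under stated assumptions rather than the authors' implementation: the function names are hypothetical, the means and standard deviations are passed in directly instead of being produced by neural networks, and the decoder parameters are taken to correspond to a single sample of the latent variable.

```python
import math
from typing import Sequence, Tuple

Params = Tuple[Sequence[float], Sequence[float]]  # (mean, standard deviation) per dimension


def kl_diag_gaussians(mu_q, sigma_q, mu_p, sigma_p) -> float:
    """KL( N(mu_q, diag(sigma_q^2)) || N(mu_p, diag(sigma_p^2)) ), summed over dimensions."""
    total = 0.0
    for mq, sq, mp, sp in zip(mu_q, sigma_q, mu_p, sigma_p):
        total += math.log(sp / sq) + (sq ** 2 + (mq - mp) ** 2) / (2.0 * sp ** 2) - 0.5
    return total


def gaussian_log_prob(x, mu, sigma) -> float:
    """log N(x | mu, diag(sigma^2)), summed over dimensions."""
    total = 0.0
    for xi, m, s in zip(x, mu, sigma):
        total += -0.5 * math.log(2.0 * math.pi * s ** 2) - (xi - m) ** 2 / (2.0 * s ** 2)
    return total


def elbo_step(x_t: Sequence[float], dec: Params, enc: Params, pri: Params) -> float:
    """One time step of the ELBO: reconstruction log-likelihood minus KL(posterior || prior)."""
    recon = gaussian_log_prob(x_t, dec[0], dec[1])
    kl = kl_diag_gaussians(enc[0], enc[1], pri[0], pri[1])
    return recon - kl


# Toy values: posterior equals prior (KL = 0) and the decoder mean matches x_t exactly.
unit = ([0.0, 0.0], [1.0, 1.0])
step = elbo_step(x_t=[0.0, 0.0], dec=unit, enc=unit, pri=unit)
print(step)
```

Summing `elbo_step` over all time stamps of a sequence gives a single-sample estimate of the sequential ELBO.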

GNN. We then overview a graph neural network (GNN) based on [16]. Let $v_{k}$ be a feature vector for each node $k$ of $K$ agents. Next, a feature vector for each edge $e_{(k,j)}$ is computed based on the nodes to which it is connected. The edge feature vectors are sent as “messages” to each of the connected nodes to compute their new output state $o_{k}$. Formally, a single round of message passing operations of a graph net is characterized below:

\begin{align}
v\rightarrow e:\quad e_{(k,j)} &= f_{e}([v_{k},v_{j}]), \tag{3}\\
e\rightarrow v:\quad o_{k} &= f_{v}\left(\sum_{j\in N(k)}e_{(k,j)}\right), \tag{4}
\end{align}

where $N(k)$ is the set of neighbors of node $k$ and $f_{e}$ and $f_{v}$ are neural networks. In summary, a GNN takes in feature vectors $v_{1:K}$ and outputs a vector for each node $o_{1:K}$, i.e., $o_{1:K}=\mathrm{GNN}(v_{1:K})$. The operations of the GNN satisfy the permutation equivariance property as the edge construction is symmetric between pairs of nodes and the summation operator ignores the ordering of the edges [25].
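The single round of message passing in Eqs. (3) and (4) can be sketched as follows. This is a minimal sketch under stated assumptions: the graph is fully connected without self-loops, and `toy_f_e` and `toy_f_v` are hypothetical fixed functions standing in for the neural networks $f_{e}$ and $f_{v}$. The last lines check permutation equivariance on a toy input.

```python
from typing import Callable, List

Vector = List[float]


def message_passing(
    v: List[Vector],
    f_e: Callable[[Vector], Vector],
    f_v: Callable[[Vector], Vector],
) -> List[Vector]:
    """One round of v->e and e->v message passing on a fully connected graph (no self-loops)."""
    outputs = []
    for k in range(len(v)):
        # v -> e: one edge feature per neighbor j of node k (list concatenation = [v_k, v_j]).
        messages = [f_e(v[k] + v[j]) for j in range(len(v)) if j != k]
        # e -> v: sum the incoming messages elementwise, then transform the sum.
        aggregated = [sum(dim) for dim in zip(*messages)]
        outputs.append(f_v(aggregated))
    return outputs


def toy_f_e(pair: Vector) -> Vector:
    """Toy edge function (an assumption): relative feature v_j - v_k."""
    half = len(pair) // 2
    return [pair[half + d] - pair[d] for d in range(half)]


def toy_f_v(agg: Vector) -> Vector:
    """Toy node function (an assumption): elementwise ReLU."""
    return [max(0.0, a) for a in agg]


nodes = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]]
out = message_passing(nodes, toy_f_e, toy_f_v)

# Permuting the input nodes permutes the outputs in the same way.
perm = [2, 0, 1]
out_perm = message_passing([nodes[p] for p in perm], toy_f_e, toy_f_v)
print(out_perm == [out[p] for p in perm])
```

Because the aggregation is a sum over neighbors, the output for a node does not depend on the order in which its neighbors are listed.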

III Proposed Method

Here, we describe our TGV-CRN method for ITE estimation in multiagent observational data. The overall framework is illustrated in Fig. 3A. We aim to combine predictions of outcome and covariates using data-driven and theory-based approaches, while balancing the representations of treated and control groups to reduce the confounding bias. To this end, we first introduce the representation learning of hidden confounders with balancing by mapping the current multiagent observational data and historical information into the representation space. Next, we describe the prediction methods of the time-varying covariates, a potential outcome, and the treatment using the learned representations. Finally, we describe the loss function.

III-A Representation learning of confounders

Here, as a main approach in Fig. 3A, we extend a GVRNN [13] for local multiagent locations $x^{l}_{t}$ (i.e., specific for each agent) with theory-based computation. As its variant (e.g., used for the ablation study), a pure data-driven model combining GVRNN and VRNN [15] for global variables $x^{g}_{t}$ (i.e., common for all agents) is also considered. Since the global variables do not usually have the graph structure, VRNN without GNN is suitable. Here we describe the representation learning of hidden confounders.

GVRNN. We first describe Graph Variational RNN (GVRNN) [13] to obtain the representation from multiagent locations, which models the interactions between them at each step using GNNs. Let $x^{l}_{\leq T}=\{x^{l}_{1},\dots,x^{l}_{T}\}$ denote a sequence of covariates (here we consider multiagent locations). In this paper, GVRNN’s update equations are as follows:

\begin{align}
p_{\theta}(z^{l}_{t}|x^{l}_{\leq t},z^{l}_{<t}) &= \prod_{k}\mathcal{N}(z^{l}_{t,k}|\mu^{\mathrm{pri}}_{t,k},(\sigma^{\mathrm{pri}}_{t,k})^{2}), \tag{5}\\
q_{\phi}(z^{l}_{t}|x^{l}_{\leq t+1},z^{l}_{<t}) &= \prod_{k}\mathcal{N}(z^{l}_{t,k}|\mu^{\mathrm{enc}}_{t,k},(\sigma^{\mathrm{enc}}_{t,k})^{2}), \tag{6}\\
p_{\theta}(x^{l}_{t+1}|z^{l}_{\leq t},x^{l}_{\leq t}) &= \prod_{k}\mathcal{N}(x^{l}_{t+1,k}|\mu^{\mathrm{dec}}_{t+1,k},(\sigma^{\mathrm{dec}}_{t+1,k})^{2}), \tag{7}\\
h^{l}_{t+1,k} &= f^{l}_{rnn}(x^{l}_{t+1,k},z^{l}_{t,k},h^{l}_{t,k}), \tag{8}
\end{align}

where $h^{l}_{t}$ and $z^{l}_{t}$ are deterministic and stochastic latent variables, $\mathcal{N}(\cdot|\mu,\sigma^{2})$ denotes a multivariate normal distribution with mean $\mu$ and covariance matrix $\mathrm{diag}(\sigma^{2})$, and

\begin{align}
[\mu^{\mathrm{pri}}_{t,1:K},\sigma^{\mathrm{pri}}_{t,1:K}] &= \mathrm{GNN_{pri}}(h^{l}_{t,1:K}), \tag{9}\\
[\mu^{\mathrm{enc}}_{t,1:K},\sigma^{\mathrm{enc}}_{t,1:K}] &= \mathrm{GNN_{enc}}([x^{l}_{t+1,1:K},h^{l}_{t,1:K}]), \tag{10}\\
[\mu^{\mathrm{dec}}_{t+1,1:K},\sigma^{\mathrm{dec}}_{t+1,1:K}] &= \mathrm{GNN_{dec}}([z^{l}_{t,1:K},h^{l}_{t,1:K}]). \tag{11}
\end{align}

The prior network $\mathrm{GNN_{pri}}$, encoder $\mathrm{GNN_{enc}}$, and decoder $\mathrm{GNN_{dec}}$ are GNNs with learnable parameters $\phi$ and $\theta$. The relationship among them is illustrated in Fig. 3B. Here we used the mean value $\mu^{\mathrm{dec}}_{t+1,1:K}$ as input variables $\hat{x}^{l^{\prime}}_{t+1}$ in the following theory-based computation. GVRNN is trained by maximizing the sequential ELBO in a similar way to VRNN as described in Eq. (2), which is denoted as $\mathcal{L}_{gvrnn}$.

Combined representation learning. To construct a fully data-driven model (instead of the theory-based computation in Fig. 3A), we propose a hierarchical GVRNN combining GVRNN for multiagent locations and VRNN for global inputs to learn the representation of hidden confounders. In summary, each agent’s trajectory and other global information are processed through gated recurrent units (GRUs). The GRU parameters for the agents’ trajectories are shared across agents, but each agent keeps its own recurrent state. At each time stamp, the model takes the current covariates and treatment assignments as input for learning representations of hidden confounders $z_{t}=[z^{l}_{t},z^{g}_{t}]$ via GRUs, GNN, and MLP encoders.

III-B Prediction with learned representation

Our methods predict time-varying covariates, potential outcomes, and treatment for balancing by combining data-driven and theory-based approaches. Here our contributions are to propose the theory-based computation of global variables and the prediction of time-varying covariates. In this subsection, we describe the theory-based computation and the prediction of time-varying covariates, potential outcomes, and treatment for balancing.

Theory-based computation. Here we assume that we do not have simulators including governing equations (used e.g., in [16, 26]) of multiagent systems such as team sports. In such a situation, we can utilize theory or prior knowledge of the domain using two approaches. One is to partially incorporate rule-based models into the data-driven model such as a mathematical relationship [27] (e.g., between position and velocity), the biological constraints (e.g., turn angle in Sec. V-A), and critical behaviors (e.g., collision in Sec. V-A, ball movement in the air, and defending against the shot in Sec. V-B). Another is to compute auxiliary features for potential outcome predictions such as global variables (e.g., mean angular momentum in Sec. V-A or specific inter-agent distances in Secs. V-A and V-B). In our problem, we need to predict potential outcomes and covariates using predicted local variables from the data-driven model in Sec. III-A as shown in Fig. 3A. That is,

\begin{align}
\hat{y}_{t+1} &= f^{y}_{theory}([z^{l}_{t},x^{g}_{t},a_{t}]), \tag{12}\\
\hat{x}_{t+1} &= f^{x}_{theory}(\hat{x}^{l^{\prime}}_{t+1}), \tag{13}
\end{align}

where $\hat{x}_{t+1}=[\hat{x}^{l}_{t+1},\hat{x}^{g}_{t+1}]$ and $f^{y}_{theory}$ and $f^{x}_{theory}$ are theory-based functions utilizing domain knowledge. For the details, see Sec. V.
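As a purely illustrative sketch of what a theory-based function in the spirit of Eq. (13) might look like, the code below derives two global variables from predicted agent positions: an inter-agent distance and a rule-based collision flag. The function name `f_theory_x`, the choice of agent 0 as the agent of interest, and the `collision_radius` parameter are all assumptions made for this example; the functions actually used for each dataset are described in Sec. V.

```python
import math
from typing import List, Tuple

Position = Tuple[float, float]


def f_theory_x(pred_local: List[Position], collision_radius: float = 1.0):
    """Hypothetical theory-based step: derive global variables from predicted local positions.

    Agent 0 is treated as the agent of interest (an assumption of this sketch).
    """
    ego = pred_local[0]
    # Global variable 1: distance from the agent of interest to its nearest neighbor.
    dists = [math.dist(ego, other) for other in pred_local[1:]]
    min_dist = min(dists)
    # Global variable 2: rule-based flag for a critical event (here, a collision).
    collided = 1.0 if min_dist < collision_radius else 0.0
    # Return the local predictions together with the derived global variables.
    return pred_local, [min_dist, collided]


x_local, x_global = f_theory_x([(0.0, 0.0), (3.0, 4.0), (0.0, 0.5)])
print(x_global)
```

The point of the design is that quantities with a known definition are computed from the predicted local variables rather than learned, so the data-driven model only has to predict the local variables.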

Prediction of time-varying covariates. Another contribution of this paper is to propose methods to predict time-varying covariates, which can confirm under what circumstances the intervention is effective. The proposed TGV-CRN infers time-varying covariates via GNN decoders and the theory-based computation as illustrated in Fig. 3A. For a fully data-driven model, we use the VRNN decoder for global variables instead of the theory-based computation. We minimize the factual covariate loss function as follows:

$$\mathcal{L}_{x}=\frac{1}{nT}\sum^{n}\sum_{t=1}^{T}(\hat{x}_{t}-x_{t})^{2}. \tag{14}$$

Prediction of a potential outcome. Next, we describe a potential outcome prediction network to estimate the outcome $\hat{y}_{t+1}$ as described in Eq. (12). Depending on the data domain, an MLP was placed before or after the theory-based function $f^{y}_{theory}$ to properly model the potential outcome (for details, see Section V). For a fully data-driven model, we use an MLP instead of the theory-based computation. We minimize the factual loss function as follows:

$$\mathcal{L}_{y}=\frac{1}{nT}\sum^{n}\sum_{t=1}^{T}(\hat{y}_{t+1}-y_{t+1})^{2}. \tag{15}$$

Treatment prediction and balancing. We also predict the treatment assignments $\hat{a}_{t}$ at each time stamp. The predicted treatments $\hat{a}_{t}$ are obtained through a fully-connected layer with a sigmoid function as the last layer. That is, $\hat{a}_{t}$ is the probability of receiving treatment based on the confounders at time $t$, which is typically referred to as the propensity score [28]: $\hat{a}_{t}=p(a_{t}=1|z_{t})$.

Since we consider the binary treatment in this paper, we use a cross-entropy loss for the treatment prediction as follows:

$$\mathcal{L}_{a}=-\frac{1}{nT}\sum^{n}\sum_{t=1}^{T}(a_{t}\log{\hat{a}_{t}}+(1-a_{t})\log{(1-\hat{a}_{t})}). \tag{16}$$
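The cross-entropy loss of Eq. (16) can be sketched in a few lines. The function name `treatment_loss` is hypothetical, and the clipping of the predicted probabilities with `eps` is an addition of this sketch to avoid taking the logarithm of zero; it is not stated in the text.

```python
import math
from typing import Sequence


def treatment_loss(
    a: Sequence[Sequence[int]],
    a_hat: Sequence[Sequence[float]],
    eps: float = 1e-7,
) -> float:
    """Binary cross-entropy averaged over n samples and T time stamps."""
    n, T = len(a), len(a[0])
    total = 0.0
    for a_i, p_i in zip(a, a_hat):
        for a_t, p_t in zip(a_i, p_i):
            p = min(max(p_t, eps), 1.0 - eps)  # clipping is an assumption of this sketch
            total += a_t * math.log(p) + (1 - a_t) * math.log(1.0 - p)
    return -total / (n * T)


# An uninformative propensity of 0.5 gives log(2); sharper correct predictions give less.
loss_uninformative = treatment_loss([[1, 0]], [[0.5, 0.5]])
loss_sharp = treatment_loss([[1, 0]], [[0.9, 0.1]])
print(loss_uninformative, loss_sharp)
```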

One critical consideration is balancing the representations of treated and control groups, which helps reduce the confounding bias [5, 7, 8] and minimize the upper bound of the outcome inference error [29]. In this paper, we incorporate the adversarial learning-based balancing method using a gradient reversal layer [30], following the related studies [5, 8]. Figure 4 illustrates our adversarial training process. Let $\mathrm{GNN_{enc}}(\cdot;\theta_{g})$ be the GNN encoder with parameters $\theta_{g}$. Also, let $\mathrm{MLP_{a}}(\cdot;\theta_{a})$ and $f^{y}_{theory}(\cdot;\theta_{y})$ be the MLP for treatment prediction and the outcome prediction function, with parameters $\theta_{a}$ and $\theta_{y}$, respectively ($\theta_{y}$ exists only if $f^{y}_{theory}$ is a neural network). $\gamma$ and $\lambda$ are the hyperparameters (or weights) of the loss functions as described below. We add the gradient reversal layer before the treatment classifier $\mathrm{MLP_{a}}(\cdot;\theta_{a})$ to ensure that the confounder representation distributions of the treated and the controlled are similar at the group level [5, 8]. The gradient reversal layer does not change the input during forward-propagation, but during back-propagation it reverses the gradient by multiplying it by a negative scalar (i.e., $-\lambda\frac{\partial\mathcal{L}_{a}}{\partial\theta_{g}}$ in Fig. 4). Although the treatment classifier is trained to minimize the treatment prediction loss, the reversed gradient drives the encoder in the opposite direction, and this adversarial learning process contributes to the distribution balancing during the prediction of the potential outcome and the covariates.
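The behavior of the gradient reversal layer can be illustrated with a scalar toy model in which the gradients are written out by hand. Everything in this sketch is an assumption made for illustration: the class `GradientReversal`, the one-parameter "encoder" and "classifier", and the numeric values. In an automatic-differentiation framework the same idea would normally be expressed as a custom backward rule.

```python
import math


class GradientReversal:
    """Identity in the forward pass; multiplies the incoming gradient by -lam in the backward pass."""

    def __init__(self, lam: float) -> None:
        self.lam = lam

    def forward(self, z: float) -> float:
        return z  # the representation passes through unchanged

    def backward(self, grad_output: float) -> float:
        return -self.lam * grad_output  # reversed (and scaled) gradient sent to the encoder


def sigmoid(u: float) -> float:
    return 1.0 / (1.0 + math.exp(-u))


x, a = 1.5, 1                 # toy scalar covariate and observed treatment
theta_g, theta_a = 0.8, -0.4  # toy encoder and treatment-classifier parameters
grl = GradientReversal(lam=0.5)

z = theta_g * x                             # encoder output (representation)
a_hat = sigmoid(theta_a * grl.forward(z))   # predicted propensity
dL_dlogit = a_hat - a                       # gradient of the cross-entropy w.r.t. the logit
grad_theta_a = dL_dlogit * z                # classifier gradient (ordinary descent direction)
grad_z = dL_dlogit * theta_a                # gradient arriving at the layer from the classifier
grad_theta_g = grl.backward(grad_z) * x     # encoder receives the reversed gradient
print(grad_theta_a, grad_theta_g)
```

With gradient descent on both parameters, the classifier moves to predict the treatment better while the encoder moves to make the treatment harder to predict from the representation.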

Figure 4: Illustration of gradient reversal layer (GRL). The diagram is a part of Fig. 3 ignoring covariate prediction and includes backward passes. $\theta_{g}$, $\theta_{a}$, and $\theta_{y}$ are the neural network parameters of the GNN encoder, the MLP for treatment prediction, and the outcome prediction function (if $f^{y}_{theory}$ is a neural network), respectively. GRL multiplies the gradient by a certain negative constant during the backpropagation-based training.

III-C Loss function

The loss function for our method is defined as

$$\mathcal{L}=\mathcal{L}_{y}+\alpha\mathcal{L}_{gvrnn}+\gamma\mathcal{L}_{x}+\lambda\mathcal{L}_{a}, \tag{17}$$

where $\mathcal{L}_{y}$ is the factual outcome prediction loss, $\mathcal{L}_{gvrnn}$ is the ELBO in GVRNN, $\mathcal{L}_{x}$ is the covariate prediction loss, and $\mathcal{L}_{a}$ is the treatment prediction loss. $\alpha$, $\gamma$, and $\lambda$ are hyperparameters that balance the loss terms. In a pure data-driven model, we add $\mathcal{L}_{vrnn}$ to $\mathcal{L}$. To prevent over-fitting, we select the best-performing model using the loss function on the validation set. The sensitivity analysis of the hyperparameters is presented in Appendix I.
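The weighted combination in Eq. (17), together with the mean squared errors of Eqs. (14) and (15), can be sketched as below. The function names and the default weights of 1.0 are hypothetical. This sketch also assumes that the GVRNN term is passed in as the negative ELBO, so that every term is a quantity to be minimized.

```python
from typing import Sequence


def mse(pred: Sequence[Sequence[float]], target: Sequence[Sequence[float]]) -> float:
    """Mean squared error over n samples and T time stamps, as in Eqs. (14) and (15)."""
    n, T = len(pred), len(pred[0])
    total = sum((p - y) ** 2 for p_i, y_i in zip(pred, target) for p, y in zip(p_i, y_i))
    return total / (n * T)


def total_loss(
    l_y: float,
    l_gvrnn: float,
    l_x: float,
    l_a: float,
    alpha: float = 1.0,
    gamma: float = 1.0,
    lam: float = 1.0,
) -> float:
    """Weighted sum of the four loss terms; l_gvrnn is assumed to be the negative ELBO."""
    return l_y + alpha * l_gvrnn + gamma * l_x + lam * l_a


l_y = mse([[1.0, 2.0]], [[1.0, 4.0]])  # one sample, two time stamps
loss = total_loss(l_y, l_gvrnn=3.0, l_x=1.0, l_a=0.5, alpha=0.1, gamma=0.5, lam=0.2)
print(l_y, loss)
```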

IV Related work

Learning causal effects with time-varying data. Pioneering work in statistics and epidemiology developed methods to estimate the effects of time-varying exposures, such as g-computation, structural nested models, and marginal structural models (e.g., [31, 32, 19, 18]). To handle complex time-dependencies, Bayesian non-parametric approaches based on Gaussian processes [33, 34, 35] and Dirichlet processes [36] have been proposed. Moreover, to model time-dependencies without strong assumptions on functional forms, RNN approaches have been intensively investigated, such as recurrent marginal structural networks [4] and adversarial training to balance the historical confounders [5]. Recently, to estimate the treatment effects with hidden confounders, several methods have been proposed by relaxing the assumption of strong ignorability [6, 7, 8]. Our approach is related to these works, but they did not consider the utilization of domain knowledge and the explicit prediction of time-varying covariates, which is necessary for the proper interpretation of the results.

Representation learning for treatment effect estimation. Methods for estimating a balanced representation between treated and control groups in a hidden space have been proposed in the static setting [37, 38, 39]. Among neural network approaches, methods based on regularization for balancing [40, 29], local similarity among individuals [41], generative adversarial nets (GAN) [42], multi-task learning [43], probabilistic modeling [44, 45], and an optimal transport framework [46] have been proposed. In the dynamic setting, [5, 8] adopted adversarial training techniques with a gradient reversal layer, [6] proposed multi-task learning to build a factor model of the cause distribution, and [7] used inverse probability of treatment weighting. To model multiagent or networked systems, GNN-based approaches have been intensively used in prediction problems (e.g., [47, 48]). To obtain representations from the networked covariates, [8] incorporated graph convolutional networks and [49] incorporated a hierarchical spatial graph structure into causal estimation frameworks (other counterfactual learning methods on graphs were surveyed by [50]). Recently, continuous-time causal inference methods for dynamical systems [51, 52] have been proposed. We adopted adversarial training to balance the representation and introduced GVRNN with prior (domain) knowledge for a more accurate long-term prediction of potential outcomes and covariates.

Counterfactual prediction for multiagent data. In past research on real-world multiagent movements, counterfactual prediction of human movements has been studied. Forward models, such as those used for control problems in robotics, could also be considered. However, our problem is inherently a backward problem of modeling multiagent behaviors from data, so here we focus on such backward approaches. Counterfactual trajectory prediction methods have been proposed for pedestrians [53], animals [54], and team sports [13, 27, 55, 56]. Outside the context of causal inference based on the Rubin causal model, the causality of physics-based systems has been considered (sometimes called causal reasoning), for example using neural relational inference [16] and learning physics in a game engine with randomized stability [57]. In the context of reinforcement learning, counterfactual action prediction and evaluation have been considered in team sports [58, 59, 60, 61]. In traffic environments, methods to understand the reasons behind the maneuvers of other vehicles and pedestrians for escaping accidents [62, 63, 64] have been proposed. In the static setting of causal inference, [65] proposed a spatial convolutional counterfactual regression to estimate the effects of crowd movement guidance. In team sports, propensity score matching was used to investigate the causal effect of some plays or timeouts in many sports [66, 67, 68, 69]. In the dynamic setting, [70] applied a g-computation method to examine the effect of a specific pitch in baseball. We are the first to propose a framework for estimating ITE in multiagent motions.

V Experiments

The purpose of our experiments is to validate the proposed methods for application to real-world multiagent trajectories, which usually have no ground truth of the interaction rules. Hence, for verification, we first compared the ITE estimation performance of our methods with that of various baselines using two synthetic datasets with ground truth: the CARLA [71] autonomous driving simulator and a biological multiagent model called Boid [10]. In particular, real-world applications require long-term counterfactual predictions. We mainly verified the estimation of ITE using the CARLA dataset (relatively fewer interactions) and that of the best intervention timing using the Boid dataset (relatively more and complex interactions) because the intervention timing is sensitive. Finally, we examined the applicability of our methods to real-world data using the basketball (NBA) dataset. In all experiments, we separated the time $T$ into a burn-in period $1,\ldots,T_{b}$ and a prediction period $T_{p}=T_{b}+1,\ldots,T$. The intervention point during the intervention period is denoted as $T_{i}$ (a similar interval to $T_{p}$, but we set it differently in each problem). The hyperparameters of the models were determined by validation datasets in each experiment (for the details, see Appendices J, K, and L).

Here, we compared our methods to four baseline methods throughout: a simple RNN baseline using GRU [72], deep sequential weighting (DSW) [7] as a baseline considering hidden confounders, graph counterfactual recurrent network (GCRN: for clarity, we change the name) [8], and its variant that predicts covariates (GCRN+X). These baselines were modified to fit our setting (e.g., multiagent and time-varying outcome: for details, see Appendix H). For verification, since the prediction of the covariates is required, the most appropriate baseline is GCRN+X, which is compared via visualization because DSW and GCRN cannot predict the covariates. We also validated our approach with three variants: GV-CRN (without the theory-based computation, i.e., without global covariates), TV-CRN (removing GNN), and TG-CRN (replacing VRNN with RNN). To perform fair comparisons among the models, we trained all models for 20 epochs (the learning curves are shown in Appendix N). Other common training details are described in Appendix H and the code is provided at https://github.com/keisuke198619/TGV-CRN.

V-A Synthetic datasets

To verify our method, we compared its performance in inferring the causal effect with that of various baselines using two synthetic datasets with ground truth. We used two types of simulations: autonomous vehicle and biological agent (Boid) simulations. All simulations are rule-based rather than learning-based, as described below. As for performance metrics, we first adopted the commonly used potential outcome and covariate prediction errors. Potential outcome prediction errors were computed as the absolute error over all simulated outcomes: $L_{outcome}^{(a_{t},t,i,t^{\prime})}=|\hat{y}_{a_{t},t}^{(i,t^{\prime})}-y_{a_{t},t}^{(i,t^{\prime})}|$, where $a_{t}\in\{0,1\}$ is the treatment, $t\in[T_{b},T+1]$ is the time, $t^{\prime}\in T_{i}$ is the simulated temporal variation of the treatment timing, and $T_{b}$ is the burn-in period in the RNN. Covariate prediction errors were computed as the $L_{2}$ prediction error over all simulated covariates: $L_{covariates}^{(a_{t},t,i,t^{\prime})}=\|\hat{x}_{a_{t},t}^{(i,t^{\prime})}-x_{a_{t},t}^{(i,t^{\prime})}\|_{2}$. We took the average over all variables for evaluation.
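The two per-sample errors defined above can be sketched in plain Python as follows. The function names are hypothetical, and the inputs are a single scalar outcome and a single covariate vector for one combination of treatment, time, sample, and treatment timing.

```python
import math


def outcome_error(y_hat, y):
    # Absolute error between predicted and simulated potential outcome.
    return abs(y_hat - y)


def covariate_error(x_hat, x):
    # L2 (Euclidean) error between predicted and simulated covariate vectors.
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(x_hat, x)))


def mean(values):
    # Average used to aggregate errors over all evaluated combinations.
    values = list(values)
    return sum(values) / len(values)


# Hypothetical values, used only to exercise the functions.
e_out = outcome_error(0.5, 0.25)
e_cov = covariate_error([3.0, 0.0], [0.0, 4.0])
e_avg = mean([outcome_error(1.0, 0.0), outcome_error(2.0, 5.0)])
```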

We also adopted two widely used evaluation metrics in causal inference to measure the quality of the estimated individual treatment effects at different time stamps: the rooted precision in the estimation of heterogeneous effect (PEHE) [37], $\sqrt{\epsilon^{t}_{PEHE}}=\sqrt{\frac{1}{n}\sum_{i\in n}(\hat{\tau}_{t}^{(i)}-\tau_{t}^{(i)})^{2}}$, and the mean absolute error of the average treatment effect (ATE) [73], $\epsilon^{t}_{ATE}=|\frac{1}{n}\sum_{i\in n}\hat{\tau}_{t}^{(i)}-\frac{1}{n}\sum_{i\in n}\tau_{t}^{(i)}|$. We took the average over all time stamps for evaluation.
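As a small worked sketch of these two metrics at a single time stamp, the following functions take lists of estimated and ground-truth ITEs over $n$ samples. The function names and the example values are illustrative assumptions.

```python
import math


def rooted_pehe(tau_hat, tau):
    # Square root of the mean squared difference between estimated and true ITEs.
    n = len(tau)
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(tau_hat, tau)) / n)


def ate_error(tau_hat, tau):
    # Absolute difference between the estimated and true average treatment effects.
    n = len(tau)
    return abs(sum(tau_hat) / n - sum(tau) / n)


# Hypothetical ITEs for two samples at one time stamp.
tau_hat = [1.0, 3.0]
tau = [0.0, 1.0]
pehe = rooted_pehe(tau_hat, tau)  # sqrt((1 + 4) / 2)
ate = ate_error(tau_hat, tau)     # |2.0 - 0.5|
```

Note that individual errors of opposite sign can cancel in the ATE error but not in the PEHE, which is why both are reported.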

Regarding interpretability, we emphasize that there has been no previous work to visualize or interpret future covariates (i.e., multiagent trajectories), which can confirm under what circumstances the intervention is effective. We therefore illustrated trajectory prediction results in each domain. To demonstrate the superiority of our approach, we compared it with the baselines customized to predict future covariates; the comparison is also shown quantitatively by the covariate losses $L_{covariates}$.

Autonomous vehicle simulation. We first validated the performance of our methods on an autonomous vehicle simulation using the CARLA simulator [71] (ver. 0.9.8). The CARLA environment contains dynamic obstacles (e.g., pedestrians and cars) that interact with the ego car. To generate data, we simulated autonomous driving at various towns, starting points, and obstacle positions, which were subsampled at 2 Hz (see also Appendix K). Since we did not use the future path as inputs for generality, we randomly split the data from 7 towns into 904 training, 129 validation, and 259 test scenarios. Here we used two types of autonomous vehicles: a partial observation type with Autoware [74] (ver. 1.12.0) and a full observation type with the pre-installed CARLA simulator. The partial observation model often stops when dangerous situations arise for safety. The full observation model intervenes to accelerate the stopped vehicle after the safety confirmation. We set $T=60$, $T_{b}=40$, and $T_{i}=T_{b}+1,\ldots,T_{b}+10$. For counterfactual data, since dangerous situations (i.e., those requiring intervention) are limited, we only used data with and without intervention at the same timing (not at various timings).

We predict the safe driving distance of the ego car from the starting point as an outcome while driving without any collision with obstacles. That is, the treatment effect of the full observation model is defined as the difference between safe driving distances with and without interventions. As the local covariates $x^{l}_{t}$, we used the position, velocity, pose, and size information of the ego car and obstacles in the 2D map. The maximum number of dynamic obstacles was 120, but since the obstacles related to the ego car were limited and the number changes over time, we used the information of the nearest 10 obstacles as the covariates. As the global covariates $x^{g}_{t}$, we used the current driving distance $x_{t}^{dist}$ and (binary) collision $x_{t}^{coll}$ information of the ego car. The theory-based function $f^{x}_{theory}$ mathematically computed $\hat{x}_{t+1}$ using the predicted velocity and pose information, and added the learned positive velocity in intervention. $f^{y}_{theory}$ computed $\hat{y}_{t+1}$ using the predicted $\hat{y}^{\prime}_{t+1}$ as the output of the MLP such that $\hat{y}_{t+1}=(1-x_{t}^{coll})\hat{y}^{\prime}_{t+1}$ based on the definition of the safe driving distance.
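The theory-based outcome rule stated above, $\hat{y}_{t+1}=(1-x_{t}^{coll})\hat{y}^{\prime}_{t+1}$, amounts to masking the MLP output with the binary collision flag. A minimal sketch, with a hypothetical function name and made-up input values:

```python
def theory_outcome(y_prime_next, collided):
    """Safe driving distance prediction masked by the collision flag.

    y_prime_next: raw outcome prediction from the MLP (assumed scalar)
    collided:     binary collision covariate, 0 (no collision) or 1 (collision)
    """
    if collided not in (0, 1):
        raise ValueError("collided must be 0 or 1")
    # After a collision, the safe driving distance is forced to zero.
    return (1 - collided) * y_prime_next


safe = theory_outcome(0.8, 0)     # MLP prediction passes through unchanged
crashed = theory_outcome(0.8, 1)  # prediction is zeroed after a collision
```

Encoding this rule directly means the network does not have to learn from data that a collision ends safe driving.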

Quantitative verification results are shown in Table I. Our full model and its ablated variants show lower prediction errors of the outcome and covariates, $\sqrt{\epsilon^{t}_{PEHE}}$, and $\epsilon^{t}_{ATE}$ than the baselines. Overall, our full model and the variant without amortized variational inference (TG-CRN) show the best performances, suggesting that the necessity of complex modeling for long-term prediction on this dataset would be smaller than that on the following datasets. Among our four models, we found that on this dataset the combination of theory-based computation and GNN worked well in the covariate prediction, but not in the outcome prediction. Example results of our method are shown in Fig. 5, which helps interpret the results. Our method (TGV-CRN) shows better counterfactual prediction with intervention than the baseline (GCRN+X). On average, the outcome values with intervention compared to the absence of the intervention (i.e., ITE) were $0.005\pm 0.000$ (km) in our full model, $-0.021\pm 0.000$ in the baseline (GCRN+X), and $0.013\pm 0.001$ in the ground truth. As Fig. 5 shows, the baseline did not model the intervention effects. We expect that this method can be applied to the effect of human intervention in Level 3 autonomous vehicle simulations, which need human intervention. This approach can also examine the effect of autonomous control in Level 4 or 5 autonomous vehicle simulations (without human intervention) in cases where human intervention would be necessary in Level 3.

| Method | $L_{Outcome}$ | $\sqrt{\epsilon^{t}_{PEHE}}$ | $\epsilon^{t}_{ATE}$ | $L_{Covariates}$ |
|---|---|---|---|---|
| RNN | 3.330 ± 0.286 | 0.159 ± 0.025 | 0.129 ± 0.023 | 0.479 ± 0.0011 |
| DSW [7] | 0.161 ± 0.013 | 0.030 ± 0.003 | 0.019 ± 0.003 | - |
| GCRN [8] | 0.094 ± 0.009 | 0.024 ± 0.002 | 0.014 ± 0.002 | - |
| GCRN+X | 3.020 ± 0.496 | 0.049 ± 0.009 | 0.034 ± 0.007 | 2.072 ± 0.0098 |
| GV-CRN | 0.038 ± 0.003 | 0.022 ± 0.002 | 0.012 ± 0.001 | 0.372 ± 0.0010 |
| TG-CRN | 0.045 ± 0.004 | 0.021 ± 0.002 | 0.011 ± 0.001 | 0.082 ± 0.0004 |
| TV-CRN | 0.032 ± 0.003 | 0.023 ± 0.002 | 0.014 ± 0.002 | 0.102 ± 0.0004 |
| TGV-CRN (full) | 0.041 ± 0.003 | 0.020 ± 0.002 | 0.008 ± 0.001 | 0.098 ± 0.0004 |

TABLE I: Performance comparison on the CARLA dataset. The upper four rows are the baselines and the lower four rows are the proposed methods. A dash indicates that the method does not predict covariates.
Figure 5: Example CARLA results using our method. (Top) Visualization of covariates and (middle row and bottom) outcome time series in (left) ground truth without intervention, (middle column) counterfactual intervention using our model, and (right) the baseline. The middle row subfigures are enlarged views of the bottom ones from 20 s. An ego car (red square) and obstacles (black) are shown in the upper plots (see also Fig. 1A) at the intervention time, which is the solid line in the lower plots. The unfilled circle is the start of the long-term prediction (dashed line in the lower plots). The ego car moves from right to left and stops because of the obstacles. The videos are available on the GitHub page given above.
| Method | Boid: Treatment timing | Boid: $L_{Outcome}$ | Boid: $\sqrt{\epsilon^{t}_{PEHE}}$ | Boid: $\epsilon^{t}_{ATE}$ | Boid: $L_{Covariates}$ | NBA: $L_{Outcome}$ | NBA: $\hat{\tau}^{CFP}$ | NBA: $L_{Covariates}$ |
|---|---|---|---|---|---|---|---|---|
| RNN | 1.815 ± 0.068 | 0.668 ± 0.033 | 0.443 ± 0.054 | 0.085 ± 0.016 | 0.168 ± 0.0002 | 0.291 ± 0.003 | 0.170 ± 0.0065 | 0.969 ± 0.0024 |
| DSW [7] | 2.353 ± 0.080 | 0.586 ± 0.029 | 0.440 ± 0.054 | 0.082 ± 0.015 | - | 0.197 ± 0.003 | 0.166 ± 0.0064 | - |
| GCRN [8] | 2.290 ± 0.081 | 0.587 ± 0.028 | 0.443 ± 0.055 | 0.088 ± 0.016 | - | 0.170 ± 0.003 | 0.149 ± 0.0062 | - |
| GCRN+X | 2.290 ± 0.081 | 0.727 ± 0.022 | 0.440 ± 0.054 | 0.080 ± 0.014 | 0.162 ± 0.0001 | 0.431 ± 0.004 | 0.154 ± 0.0063 | 1.226 ± 0.0032 |
| GV-CRN | 1.900 ± 0.068 | 0.674 ± 0.028 | 0.536 ± 0.045 | 0.090 ± 0.014 | 0.329 ± 0.0003 | 0.475 ± 0.005 | 0.151 ± 0.0060 | 1.173 ± 0.0028 |
| TG-CRN | 1.750 ± 0.062 | 0.694 ± 0.041 | 0.501 ± 0.067 | 0.125 ± 0.021 | 0.086 ± 0.0002 | 0.347 ± 0.005 | 0.205 ± 0.0072 | 0.811 ± 0.0022 |
| TV-CRN | 1.692 ± 0.064 | 0.981 ± 0.044 | 0.483 ± 0.052 | 0.106 ± 0.018 | 0.090 ± 0.0002 | 0.165 ± 0.002 | 0.192 ± 0.0068 | 0.852 ± 0.0023 |
| TGV-CRN (full) | 1.853 ± 0.062 | 0.690 ± 0.032 | 0.537 ± 0.055 | 0.101 ± 0.016 | 0.085 ± 0.0002 | 0.231 ± 0.004 | 0.261 ± 0.0083 | 0.825 ± 0.0023 |

TABLE II: Performance comparison on the Boid simulation dataset (left columns) and the real-world NBA dataset (right columns). The upper four rows are the baselines and the lower four rows are the proposed methods. A dash indicates that the method does not predict covariates.

Biological multiagent simulation. Here, we validated our methods on the Boid model, which contains movement trajectories of 20 agents. The Boid model (originally, [75]) is a rule-based model to generate generic simulated flocking agents, and we used a unit-vector-based (rule-based) model [10] (for details, see Appendix J). In this paper, we intervene in the agents' recognition to generate torus (circle) behaviors from the swarm (random) behaviors. The outcome is defined as the mean angular momentum of individuals about the center of the group (assuming the mass of each agent $m=1$). That is, the treatment effect of the change in the recognition is estimated as the difference in the future mean angular momentum between simulations with and without the interventions.

In this model, 20 agents are described by a 2-D vector with a 1 m/s constant velocity in a 15 $\times$ 15 m boundary square. At each time stamp, a member will change direction according to the positions of all other members based on three zones. The first is the repulsion zone with radius $r_{r}=0.5$ m, in which individuals within each other's repulsion zone try to avoid each other by swimming in opposite directions. The second is the orientation zone, in which individuals try to move in the same direction; here we set radius $r_{o}=1$ to generate swarming behaviors before the intervention. To generate torus behaviors, we change the radius to $r_{o}=4$, which is the intervention in this study. The third is the attractive zone (radius $r_{a}=7.5$ m), in which agents move towards each other and tend to cluster.
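The three-zone rule and the intervention on $r_{o}$ can be illustrated with a small classifier of a neighbor by its distance. This is a sketch under assumptions: the zones are treated as nested shells with strict upper bounds, the function and argument names are hypothetical, and the actual simulation follows [10] and Appendix J.

```python
def classify_neighbor(distance, r_r=0.5, r_o=1.0, r_a=7.5):
    """Return the zone of a neighbor at the given distance (in meters).

    r_r, r_o, r_a are the repulsion, orientation, and attraction radii from the text.
    Boundary handling (strict '<') is an assumption made for this illustration.
    """
    if distance < r_r:
        return "repulsion"
    if distance < r_o:
        return "orientation"
    if distance < r_a:
        return "attraction"
    return "none"


# Before the intervention (r_o = 1), a neighbor at 3 m only attracts.
before = classify_neighbor(3.0, r_o=1.0)
# After the intervention (r_o = 4), the same neighbor triggers alignment instead.
after = classify_neighbor(3.0, r_o=4.0)
```

Widening the orientation zone makes more neighbors contribute to alignment rather than attraction, which is the change in recognition used here to move from swarm to torus behavior.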

To simulate the treatment assignments, we generated 20,800 factual samples (20,000 training, 400 validation, and 400 test samples). We set $T=14$ and $T_{b}=9$ and randomly picked the intervention point during the intervention period $T_{i}=9,\ldots,13$. The outcome is defined as the mean angular momentum among individuals at time $t+1$. We also created a counterfactual dataset only for the test dataset. Here we created all combinations of treatment points during the intervention period $T_{i}$. As the local covariates $x^{l}_{t}$, we used the position, velocity, and directional change of all agents. As the global covariates $x^{g}_{t}$, we used the current mean angular momentum. The theory-based function $f^{x}_{theory}$ mathematically computed $\hat{x}_{t+1}$ using the direction change $d$ at the next step, the maximum turn angle $\beta$ as a body constraint (for $d$ and $\beta$, see also Appendix J), and the attraction rule when agents are far from the center of the group. In addition, in $f^{x}_{theory}$ we added the orientation rule when agents are in the orientation zone and not in the repulsion zone. $f^{y}_{theory}$ was replaced with an MLP such that $\hat{y}_{t+1}={\rm{MLP_{y}}}([z^{l}_{t},x^{g}_{t},a_{t}])$.

The results are shown in Table II (left). In addition to the four indices in the CARLA experiment, we investigated the estimation error of the best intervention timing $|\operatorname*{arg\,max}_{t^{\prime}}(\hat{y}^{(i,t^{\prime})}_{T+1})-\operatorname*{arg\,max}_{t^{\prime}}(y^{(i,t^{\prime})}_{T+1})|$. The results indicate that our model and its variants with theory-based computation show better prediction performances in covariates and the best intervention timing than all of the baselines. However, the outcome prediction errors in our models were worse than those of the causal inference baselines (DSW, GCRN, and GCRN+X), which may have led to the degraded performances of our models in $\sqrt{\epsilon^{t}_{PEHE}}$ and $\epsilon^{t}_{ATE}$ compared with the baselines. Among our four models, combining all three core components (T, G, and V) did not work well on this dataset, and TG-CRN (without VRNN) was the best-performing model. On average, the ITE values were $0.160\pm 0.010$ in our best model (TG-CRN), $0.035\pm 0.001$ in the baseline (GCRN+X), and $0.091\pm 0.026$ in the ground truth. From this viewpoint (both average and variance), our best model was closer to the ground truth than the baseline. Example interpretable results of our method are shown in Fig. 6. Our best model (TG-CRN) shows better counterfactual covariate prediction with intervention than the baseline (GCRN+X). Moreover, the counterfactual prediction of the outcome in our model varied among the intervention times, whereas that in the baseline did not. We consider these to be important properties in our problem, and improved prediction of the potential outcome is left for future work. We expect that our approach can estimate the effect of an experimenter's interventions on multi-animal behaviors even for conditions in which the experiment was not performed.
Our approach is expected to improve the efficiency of experimental procedures for observing desired movements.
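The best-intervention-timing error used for the Boid dataset compares the timing that maximizes the predicted final outcome with the timing that maximizes the simulated one. A minimal sketch for a single sample, where the function name and the dictionary layout (treatment timing mapped to outcome at $T+1$) are assumptions for illustration:

```python
def best_timing_error(y_hat_by_timing, y_by_timing):
    """Absolute difference between predicted and true best intervention timings.

    Both arguments map a candidate treatment timing t' to the outcome at T+1
    obtained when intervening at t'.
    """
    best_predicted = max(y_hat_by_timing, key=y_hat_by_timing.get)
    best_true = max(y_by_timing, key=y_by_timing.get)
    return abs(best_predicted - best_true)


# Hypothetical outcomes for intervention timings 9..13 (the period T_i in the text).
predicted = {9: 0.10, 10: 0.40, 11: 0.30, 12: 0.20, 13: 0.05}
simulated = {9: 0.20, 10: 0.10, 11: 0.50, 12: 0.30, 13: 0.10}
timing_error = best_timing_error(predicted, simulated)
```

Because only the argmax matters, a model can score well on this metric while still having a large absolute outcome error, which is consistent with the two being reported separately in Table II.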

Figure 6: Example Boid results of our method. The configurations are the same as Fig. 5. (Top) 20 boids with 6 different colors move in a rule-based manner (but we did not use this information for the prediction) and the filled circles are the starting points of the trajectories. (Middle row and bottom) outcome time series in (left) ground truth without intervention, (middle column) counterfactual intervention using our model, and (right) the baseline. The middle row subfigures are enlarged views of the bottom ones from the 9th frame. The "a" in the lower caption is the intervention time. For example, "$a=9$" means the case of intervention at the 9th frame and "None" indicates no intervention. The videos are available on the GitHub page given above.

V-B Real-world basketball dataset

Finally, we examined the applicability of our methods to a real-world basketball dataset from the NBA. Data acquisition was based on the contract between the league (NBA) and the company (STATS LLC.), not between the players and us. Because the players are top-level professionals, the data was not anonymized. The company was licensed to acquire this data, and it was guaranteed that the use of the data would not infringe on any rights of players or teams. In this study, we used attack sequences from 630 games from the 2015/2016 NBA season (https://www.stats.com/data-science/), which contained the trajectories of 10 players and the ball. We extracted 47,467 attacks (i.e., offensive plays) as samples, which were subsampled at 5 Hz. We separated the dataset into 34,696 pre-training (458 games for training the following classifier of effective attack), 11,460 training (154 games, 1/10 of which is used for validation), and 1,305 test samples (18 games) in chronological order. Since scoring predictions are difficult in general (e.g., [76, 11]), we define the attack effectiveness as the outcome by predicting whether the attack is effective or not in the future. This is because evaluating team movements based on scores alone may not provide a holistic view, due to factors such as the shooting skills of individual players. In addition, we predict the attack effectiveness at the next time stamp using the pre-training samples and a logistic regression. The details are provided in Appendix L.

We verified our methods using factual data and provided insights using counterfactual prediction. In model verification using factual data, we set $T=95$ and $T_{b}=85$. For counterfactual predictions, we set $T=105$ and $T_{b}=95$ (at the end of attacks, i.e., a shot or turnover) and predicted all combinations of the counterfactual timing during the intervention period $T_{i}=95,\ldots,98$. As the local covariates $x^{l}_{t}$, we used the position and velocity of all agents including the ball. As the global covariates $x^{g}_{t}$, we used the ball player's information and areas (see also Appendix L), distances from the nearest defender (for the ball player and other attackers), successful shot probabilities of all attackers, and the game and shot clocks. The theory-based functions $f^{x}_{theory}$ and $f^{y}_{theory}$ computed $\hat{x}^{g}_{t+1}$ and $\hat{y}_{t+1}$ in rule-based manners. In addition, to perform long-term prediction, $f^{x}_{theory}$ also computed the ball and the defender nearest to the next ball player during a counterfactual pass in rule-based manners.

Our verification results using factual data are shown in Table II (right). We examined the counterfactual pass effect $\hat{\tau}^{CFP}=\max_{t^{\prime}\in T_{i}}(\max_{t=[T_{b},T]}(\hat{y}^{(i,t^{\prime})}_{T+1}))-y^{(i)}_{T_{b}+1}$ in addition to the outcome and covariate prediction errors. Our methods show better prediction performances in the factual covariates than the fully data-driven baselines. In the factual outcomes, our methods outperformed the most appropriate baseline (GCRN+X) and RNN, but only the method without GNN shows competitive performance with DSW and GCRN. One possible reason is the difficulty in modeling the relationship between the future outcome and current covariates, whereas GNN worked well only in the covariate prediction, as in the previous work [13]. As additional analysis, we report endpoint errors and distributions of long-term covariate prediction for basketball data in Appendix M. Again, the strength of our method is to model the covariates at the next timestep for interpretability of the model. Results in Appendix M show that the tendency of the endpoint prediction error for all models is similar to that of the mean prediction error, in which our approaches (TV-CRN and TGV-CRN) show the best performance. In addition, the distribution of the player velocity in the figure of Appendix M shows that our approach without theory-based computation had a wider distribution, like the ground truth, than other models, and those with theory-based computation had fewer zero-velocity bins, like the ground truth, than other models. Although our approach did not completely model the players' velocity, we show that each component was partially effective for covariate prediction. In $\hat{\tau}^{CFP}$, all models show positive values, which suggests that there may be a more promising shot opportunity by an extra pass in the shot situation.
Figure 7 shows realistic counterfactual prediction from our model to demonstrate interpretable results. Our model completed the counterfactual pass and the nearest defender chase, but the baseline model failed and predicted unrealistic behaviors (e.g., the attackers moved toward the outside of the court). In practice, our approach can estimate the effect of the selection of passes in basketball shot scenarios. We expect that our approach can evaluate the decision-making skills of players in a competitive game. Again, we emphasize these important properties in our problem.
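The counterfactual pass effect $\hat{\tau}^{CFP}$ defined above can be sketched for one attack as follows. The interpretation is an assumption: for each counterfactual pass timing $t^{\prime}$, the predicted attack effectiveness is taken over the prediction window, the maximum over time and over timings is kept, and the factual outcome at $T_{b}+1$ is subtracted. The function name and data layout are hypothetical.

```python
def counterfactual_pass_effect(y_hat_by_timing, y_factual):
    """Best predicted outcome over all counterfactual timings minus the factual outcome.

    y_hat_by_timing: maps a counterfactual pass timing t' to the list of predicted
                     outcomes over the prediction window
    y_factual:       factual outcome at T_b + 1 (the end of the actual attack)
    """
    best_per_timing = [max(series) for series in y_hat_by_timing.values()]
    return max(best_per_timing) - y_factual


# Hypothetical predicted attack effectiveness for pass timings 95..98.
predictions = {
    95: [0.20, 0.50, 0.45],
    96: [0.30, 0.40, 0.35],
    97: [0.25, 0.30, 0.28],
    98: [0.10, 0.15, 0.12],
}
cfp = counterfactual_pass_effect(predictions, y_factual=0.30)
```

A positive value indicates that at least one counterfactual pass timing is predicted to lead to a more effective attack than the shot actually taken.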

Figure 7: Example NBA results of our method. The configurations are the same as Fig. 6. (Top) Red and blue numbers, gray lines, and the orange circle and line indicate attackers, defenders, players' historical trajectories, and the ball, respectively. The positions of the numbers are at the end of the factual data (shot), which is shown as the break line in the lower plots. In the CF intervention subplots, colored trajectories indicate counterfactual predictions. The actual red player #5 shot (left), but in the counterfactual prediction (middle and right columns), the player tried to pass to a teammate. In the middle top (our method), the player successfully passes to the teammate red #3, but in the right top (baseline), the player's pass failed. (Bottom) Outcome time series (attack effectiveness) are shown. The "a" in the lower caption is the intervention time as in Fig. 6. For example, "$a=95$" means the case of intervention at the 95th frame (9.5 s). The video is available on the GitHub page given above.

VI Conclusions

In this paper, we proposed an interpretable counterfactual recurrent network in multiagent systems to estimate the effect of the intervention. Using synthetic CARLA and Boid datasets, we showed that our model achieved lower errors in estimating counterfactual covariates and the most effective treatment timing than the baselines. Furthermore, using real basketball data, our model performed realistic counterfactual prediction. We consider a general ITE framework for various domains, but the experimental results show that for each domain the effective modeling was different. Possible future research directions are to realize better modeling of the future outcomes, and to apply our approach to other multiagent domains such as animals and pedestrians using domain knowledge.

Acknowledgments

This work was supported by JSPS KAKENHI (Japan Society for the Promotion of Science, Grant Numbers 20H04075, 21H04892, and 21H05300), JST PRESTO (Japan Science and Technology Agency, Precursory Research for Embryonic Science and Technology, Grant Number JPMJPR20CA), and JST CREST (Core Research for Evolutional Science and Technology, Grant Number JPMJCR1913).

Appendix

G A proof of Theorem 1

Proof.

Under the aforementioned assumptions in the main text, we can prove the identification of ITE:

$$\begin{aligned}
\tau_{t} &=\mathbb{E}_{y}[y_{1,t+1}-y_{0,t+1}|x_{t},\mathcal{H}_{t}] && (18)\\
&=\mathbb{E}_{z}[\mathbb{E}_{y}[y_{1,t+1}-y_{0,t+1}|x_{t},z_{t},\mathcal{H}_{t}]|x_{t},\mathcal{H}_{t}] && (19)\\
&=\mathbb{E}_{z}[\mathbb{E}_{y}[y_{1,t+1}-y_{0,t+1}|z_{t}]|x_{t},\mathcal{H}_{t}] && (20)\\
&=\mathbb{E}_{z}[\mathbb{E}_{y}[y_{1,t+1}|z_{t},a_{t}=1]-\mathbb{E}_{y}[y_{0,t+1}|z_{t},a_{t}=0]|x_{t},\mathcal{H}_{t}] && (21)\\
&=\mathbb{E}_{z}[\mathbb{E}_{y}[y_{F,t+1}|z_{t},a_{t}=1]-\mathbb{E}_{y}[y_{F,t+1}|z_{t},a_{t}=0]|x_{t},\mathcal{H}_{t}], && (22)
\end{aligned}$$

where $\tau_{t}=\tau(x_{t},\mathcal{H}_{t})$, $y_{F,t+1}$ is a factual outcome, and we drop the instance index $(i)$ for simplification. Eq. (18) is the definition of ITE in our setting, Eq. (19) is a straightforward expectation over $p(z_{t}|x_{t},\mathcal{H}_{t})$, and Eq. (20) can be inferred from the structure of the causal graph shown in Fig. 2. Eq. (21) is based on the assumption that $z_{t}$ contains all the hidden confounders, as well as the positivity assumption, and Eq. (22) can be inferred from the consistency assumption. Thus, if our framework can correctly model $p(z_{t}|x_{t},\mathcal{H}_{t})$ and $p(y_{t}|z_{t},a_{t})$, then the ITEs can be identified under the causal graph in Fig. 2. ∎

H Common training setup

H-A Model training and computation

The code and data we used are provided at https://github.com/keisuke198619/TGV-CRN. This experiment was performed on an Intel(R) Xeon(R) CPU E5-2699 v4 (2.20 GHz $\times$ 16) with a GeForce TITAN X Pascal GPU. For the training of the proposed and baseline models, we used the Adam optimizer [77] with an initial learning rate of 0.0001 and 20 training epochs. We set the batch size to 256. For the hyperparameters in the loss function, we set $\alpha=0.1$ and $\lambda=0.1$ for all datasets, $\gamma=0.1$ for the Boid and CARLA datasets, and $\gamma=1$ for the NBA experiment because the latter dataset was more difficult to predict than the Boid and CARLA datasets.

H-B Baseline models implementation

We compared the performances of our methods to infer ITE with those in the following baselines: a simple RNN using GRU [72], deep sequential weighting (DSW) [7], dynamic networked observational data deconfounder (but for clarity, we change the name into GCRN: graph counterfactual recurrent network) [8].

RNN. This approach is based on GRU [72]. This model predicts the input (covariates) at the next time stamp, the potential outcome, and the probability of receiving treatment. We also model the hidden confounder as the hidden state of GRU, but do not learn the representation to reduce the confounding bias.

DSW [7]. Compared with the original model, we modified it to our setting (e.g., multiagent, time-varying outcome, and long-term prediction), and to fairly compare with our model, we removed the attention module.

GCRN [8]. Similarly, compared with the original model, we modified it to our setting (e.g., multiagent and long-term prediction) based on DSW, and to fairly compare with our model, we removed the attention module.

I Sensitivity analysis in hyperparameters

We performed a sensitivity analysis of the hyperparameters using the CARLA dataset. The hyperparameters are presented in Eq. (17). The results in Table III show the existence of a trade-off between the prediction performances of the outcome and covariates ($\gamma$).

Setting                                  $L_{Outcome}$    $\sqrt{\epsilon^{t}_{PEHE}}$    $\epsilon^{t}_{ATE}$    $L_{Covariates}$
$\alpha,\gamma,\lambda=0.1$ (default)    0.041 ± 0.003    0.020 ± 0.002    0.008 ± 0.001    0.098 ± 0.0004
$\alpha=1.0,\gamma,\lambda=0.1$          0.040 ± 0.003    0.020 ± 0.002    0.009 ± 0.001    0.102 ± 0.0004
$\alpha=0.01,\gamma,\lambda=0.1$         0.041 ± 0.003    0.019 ± 0.002    0.007 ± 0.001    0.097 ± 0.0004
$\gamma=1.0,\alpha,\lambda=0.1$          0.041 ± 0.003    0.019 ± 0.001    0.005 ± 0.001    0.086 ± 0.0004
$\gamma=0.01,\alpha,\lambda=0.1$         0.031 ± 0.002    0.020 ± 0.002    0.009 ± 0.001    0.146 ± 0.0005
$\lambda=1.0,\alpha,\gamma=0.1$          0.062 ± 0.005    0.019 ± 0.002    0.007 ± 0.001    0.121 ± 0.0005
$\lambda=0.01,\alpha,\gamma=0.1$         0.041 ± 0.003    0.020 ± 0.002    0.008 ± 0.001    0.098 ± 0.0004

TABLE III: Sensitivity analysis on the CARLA dataset.
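For readers unfamiliar with the two causal error columns in Table III, the sketch below computes them under their standard definitions at a single time step: the root of the precision in estimation of heterogeneous effect (PEHE) and the absolute error of the average treatment effect (ATE). The exact time-indexed definitions used in this paper are given in the main text; the function names and the toy inputs here are illustrative assumptions.

```python
import math


def sqrt_pehe(tau_true, tau_hat):
    """Root mean squared error between true and estimated per-sample effects."""
    n = len(tau_true)
    return math.sqrt(sum((h - t) ** 2 for t, h in zip(tau_true, tau_hat)) / n)


def ate_error(tau_true, tau_hat):
    """Absolute difference between the true and estimated average effects."""
    n = len(tau_true)
    return abs(sum(tau_hat) / n - sum(tau_true) / n)


if __name__ == "__main__":
    # Toy per-sample treatment effects (made up for illustration only).
    true_effects = [1.0, 2.0, 3.0]
    estimated_effects = [1.0, 2.0, 5.0]
    print(sqrt_pehe(true_effects, estimated_effects))
    print(ate_error(true_effects, estimated_effects))
```

The PEHE penalizes per-sample errors, whereas the ATE error only compares averages, so a model can have a small ATE error while individual estimates are still far off.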

J Boid dataset

The schooling model we used in this study was a unit-vector-based (rule-based) model [10], which accounts for the relative positions and direction vectors of neighboring fish agents, such that each fish tends to align its own direction vector with those of its neighbors. In this model, 20 agents (length: 0.5 m) move at a constant speed (1 m/s) within a square boundary (30 × 30 m). Each agent $k$ is described by a two-dimensional position vector $r^{k}=\left(x_{k}~y_{k}\right)^{T}$ and a velocity vector $v^{k}_{t}=\|v^{k}\|_{2}d_{k}$, where $x_{k}$ and $y_{k}$ are two-dimensional Cartesian coordinates, $v^{k}$ is a velocity vector, $\|\cdot\|_{2}$ is the Euclidean norm, and $d_{k}$ is a unit directional vector for agent $k$.

At each timestep, a member changes direction according to the positions of all other members. The space around an individual is divided into three zones, each of which modifies the unit vector of the velocity (for the zones, see the main text). Let $\lambda_{r}$, $\lambda_{o}$, and $\lambda_{a}$ be the numbers of agents in the zones of repulsion, orientation, and attraction, respectively. For $\lambda_{r}\neq 0$, the unit vector of an individual at the next timestep is given by:

$$d_{k}(t+1,\lambda_{r}\neq 0)=-\left(\frac{1}{\lambda_{r}-1}\sum_{j\neq k}^{\lambda_{r}}\frac{r^{kj}_{t}}{\|r^{kj}_{t}\|_{2}}\right), \tag{23}$$

where $r^{kj}=r^{j}-r^{k}$. The velocity vector points away from neighbors within this zone to prevent collisions. This zone is given the highest priority; the remaining zones are considered if and only if $\lambda_{r}=0$. The unit vector in this case is given by:

$$d_{k}(t+1,\lambda_{r}=0)=\frac{1}{2}\left(\frac{1}{\lambda_{o}}\sum_{j=1}^{\lambda_{o}}d_{j}(t)+\frac{1}{\lambda_{a}-1}\sum_{j\neq k}^{\lambda_{a}}\frac{r^{kj}_{t}}{\|r^{kj}_{t}\|_{2}}\right). \tag{24}$$

The first term corresponds to the orientation zone, while the second term corresponds to the attraction zone. The above equation contains a factor of $1/2$, which normalizes the unit vector when both zones contain neighbors. If no agents are found in any zone, the individual maintains a constant velocity at each timestep.
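The zone rules in Eqs. (23) and (24) can be sketched in plain Python as follows. This is a minimal illustration and not the released implementation: the function names and the radii `r_rep`, `r_ori`, and `r_att` are hypothetical (the actual zones are defined in the main text), the zones are assumed to be nested with repulsion innermost as in the model of [10], averages are taken over the neighbors found in each zone, and the result is renormalized to unit length so that the scalar factors in the equations do not change the returned direction.

```python
import math


def unit(v):
    """Return v scaled to unit length, or (0, 0) if v has zero length."""
    n = math.hypot(v[0], v[1])
    return (v[0] / n, v[1] / n) if n > 0 else (0.0, 0.0)


def mean_vec(vs):
    """Component-wise mean of a non-empty list of 2D vectors."""
    return (sum(v[0] for v in vs) / len(vs), sum(v[1] for v in vs) / len(vs))


def desired_direction(k, pos, dirs, r_rep, r_ori, r_att):
    """Desired unit direction of agent k at the next timestep.

    pos and dirs are lists of (x, y) positions and unit directions.
    r_rep < r_ori < r_att are hypothetical outer radii of the three zones.
    """
    rep, ori, att = [], [], []
    for j, (pj, dj) in enumerate(zip(pos, dirs)):
        if j == k:
            continue
        r_kj = (pj[0] - pos[k][0], pj[1] - pos[k][1])  # vector from k to j
        dist = math.hypot(r_kj[0], r_kj[1])
        if dist < r_rep:
            rep.append(unit(r_kj))
        elif dist < r_ori:
            ori.append(dj)
        elif dist < r_att:
            att.append(unit(r_kj))
    if rep:
        # Eq. (23): repulsion has priority and points away from neighbors.
        m = mean_vec(rep)
        d = unit((-m[0], -m[1]))
    else:
        # Eq. (24): average the orientation and attraction terms that exist.
        terms = [mean_vec(group) for group in (ori, att) if group]
        if not terms:
            return dirs[k]  # no neighbors in any zone: keep the direction
        d = unit(mean_vec(terms))
    # If the contributions cancel out exactly, keep the current direction.
    return d if d != (0.0, 0.0) else dirs[k]


if __name__ == "__main__":
    positions = [(0.0, 0.0), (0.5, 0.0)]
    directions = [(1.0, 0.0), (0.0, 1.0)]
    print(desired_direction(0, positions, directions, 1.0, 5.0, 10.0))
```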

In addition to the above, we constrain the angle by which a member can change its unit vector at each timestep to a maximum of $\beta=30$ deg. This condition was imposed to facilitate rigid body dynamics. Since we assumed point-like members, all information about the physical dimensions of the actual fish is lost, which leaves the unit vector free to rotate by any angle. In reality, however, the conservation of angular momentum limits the ability of the fish to turn by angle $\theta$ as follows:

$$d_{k}(t+1)\cdot d_{k}(t)=\begin{cases}\cos(\beta)&\text{if $\theta>\beta$}\\ \cos(\theta)&\text{otherwise}.\end{cases} \tag{25}$$

If the above condition is not satisfied, the angle of the desired direction at the next timestep is rescaled to $\theta=\beta$. In this way, any unphysical behavior, such as a 180 deg rotation of the velocity vector in a single timestep, is prevented.
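The turning constraint in Eq. (25) can be illustrated by rotating the current direction toward the desired one by at most $\beta$. The sketch below is an assumption-laden example rather than the authors' code: the function name `limit_turn` is hypothetical, and both inputs are assumed to be two-dimensional unit vectors.

```python
import math


def limit_turn(d_now, d_desired, beta_deg=30.0):
    """Rotate d_now toward d_desired by at most beta_deg degrees.

    Both arguments are assumed to be 2D unit vectors given as (x, y).
    """
    beta = math.radians(beta_deg)
    cross = d_now[0] * d_desired[1] - d_now[1] * d_desired[0]
    dot = d_now[0] * d_desired[0] + d_now[1] * d_desired[1]
    theta = math.atan2(cross, dot)  # signed turning angle in [-pi, pi]
    if abs(theta) <= beta:
        return d_desired  # the desired turn is within the limit
    turn = math.copysign(beta, theta)  # clamp the turn to +/- beta
    c, s = math.cos(turn), math.sin(turn)
    return (c * d_now[0] - s * d_now[1], s * d_now[0] + c * d_now[1])


if __name__ == "__main__":
    # A requested 180 deg reversal is reduced to a 30 deg turn.
    print(limit_turn((1.0, 0.0), (-1.0, 0.0)))
```

After clamping, the dot product of the new and current directions equals $\cos(\beta)$, which matches the first case of Eq. (25).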

In the simulation, the ground truth of $\tau_{T}$ was $0.091\pm 0.026$, which indicates that the intervention increased angular velocities (see also the main text).

K CARLA dataset

We used the CARLA simulator [71] (ver. 0.9.8). The CARLA environment contains dynamic obstacles (e.g., pedestrians and cars) that interact with the ego car. To generate data, we ran simulations with various towns, starting points, and obstacle positions, and subsampled the data at 2 Hz. The obstacles’ positions and the starting points of the ego car were randomly selected from the possible locations on the map for each run. Approximately 20 to 120 obstacles were placed on each map. Using the data collected in CARLA, the same driving conditions were reproduced in ROS (Robot Operating System) [78], and Autoware was used to make the vehicle drive the same route autonomously. Here we consider two types of autonomous vehicles: a partial observation type using Autoware [74] (ver. 1.12.0) and a full observation type using the pre-installed CARLA simulator. The partial observation model often stops in dangerous situations for safety: if an obstacle enters the deceleration or stopping range, the vehicle decelerates or stops in front of the obstacle. The deceleration range was set wider than usual for safety. The full observation model intervenes to accelerate the stopped ego car after confirming safety based on the speed information recorded at data collection.

We set $T=60$, $T_{b}=40$, and $T_{i}=T_{b}+1,\ldots,T_{b}+10$. We evaluated and predicted the safe driving distance of the ego car from the starting point while driving without collision with any obstacle (for details, see the main text). In the simulation, the ground truth of $\tau_{T}$ was $0.013\pm 0.001$, which indicates that the intervention increased the safe driving distance.

L NBA dataset

Here, we describe the details of the computation using the NBA dataset; the dataset description is given in the main text. We first describe the computation of the attack effectiveness used in this study. We then predict the effective attacks at the next time stamp using the pre-training samples and logistic regression.

For the reasons given in the main text, we compute a simple, interpretable indicator from available statistics (i.e., based on frequency) to evaluate whether a player attempts a better shot, rather than relying on the shot label or learning-based score prediction. From the available statistics, we focused on two basic factors for effective attacks at the individual player level: the shot zone on the court and the distance between the shooter and the nearest defender. These two factors have been considered important for basketball shot prediction [79, 76, 11]. In the NBA advanced stats (https://www.nba.com/stats/players/shots-closest-defender/), we can access the probabilities of successful shots in each zone and distance for each player. The shot zones are separated into four areas: the restricted area, in-the-paint, mid-range, and the 3-point area. The restricted area is defined as the area within a radius of 2.44 m from the ring (the distance between the side of the rectangle and the ring). In-the-paint is defined as the area within a radius of 5.46 m from the ring (the distance between the ring and the farthest vertex of the rectangle). The 3-point area is defined as the area outside the 3-point line. Mid-range is the remaining area. The shooter’s distance from the nearest defender is categorized into four ranges: 0-2 feet, 2-4 feet, 4-6 feet, and 6+ feet.

We define the attack effectiveness using the following criteria:

  • The shooter’s position in the restricted area is effective at any defender distance because a defender is often near the shooter.

  • The shooter’s position in the paint and mid-range is effective when the nearest defender is 6 feet or farther away (this range is regarded as “open” in the NBA advanced stats).

  • The shooter’s position in the 3-point area is effective when the player has a shot success probability of 0.35 or higher and the nearest defender is 6 feet or farther away (because some players do not shoot 3-pointers for tactical reasons).
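The three criteria above can be written as a small rule-based function. The following Python sketch is illustrative only: the function name, the zone labels, and the argument names are hypothetical, and it assumes that 0.35 is a lower bound on the player's 3-point success probability.

```python
OPEN_DISTANCE_FT = 6.0       # "open" range in the NBA advanced stats
THREE_POINT_MIN_PROB = 0.35  # assumed to be a lower bound (see lead-in)


def is_effective_attack(zone, defender_dist_ft, three_pt_prob=None):
    """Return True if a shot attempt counts as an effective attack.

    zone: one of "restricted", "paint", "mid_range", "three_point"
          (hypothetical labels for the four shot zones).
    defender_dist_ft: distance to the nearest defender in feet.
    three_pt_prob: the shooter's 3-point success probability, if known.
    """
    if zone == "restricted":
        return True  # effective at any defender distance
    is_open = defender_dist_ft >= OPEN_DISTANCE_FT
    if zone in ("paint", "mid_range"):
        return is_open
    if zone == "three_point":
        return (
            is_open
            and three_pt_prob is not None
            and three_pt_prob >= THREE_POINT_MIN_PROB
        )
    raise ValueError(f"unknown shot zone: {zone}")


if __name__ == "__main__":
    print(is_effective_attack("restricted", 1.0))
    print(is_effective_attack("mid_range", 4.5))
    print(is_effective_attack("three_point", 7.0, three_pt_prob=0.38))
```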

Based on the statistics before the game (e.g., for the pre-training data, we used the 2014/2015 season statistics) and the tracking data, we computed the probabilities of successful shots for each zone and distance for each player. For players who attempted fewer than 10 shots, we used the probability of players at the same position (i.e., guard, forward, center, guard/forward, and forward/center, based on the registration information from the NBA 2014/2015 season). Note that what constitutes a good shot in basketball depends on the court location and context, for example, 2-pointers and 3-pointers. Unfortunately, these probabilities are available for only two areas (the 2-point and 3-point areas) with four distance categories; thus, we computed the shot success probability in the restricted area, in-the-paint, and mid-range using that of the 2-point area. Based on the above definitions, for the pre-training data, there were 13,681 shot successes, 15,443 shot failures, 5,572 turnovers, 16,976 effective attacks, and 17,720 ineffective attacks. For the training data, there were 4,677 shot successes, 5,092 shot failures, 1,691 turnovers, 5,602 effective attacks, and 5,858 ineffective attacks. For the test data, there were 517 shot successes, 577 shot failures, 211 turnovers, 664 effective attacks, and 641 ineffective attacks. The probabilities of scoring, given that the attack was effective and ineffective, were 0.463/0.415/0.417 and 0.328/0.402/0.374 for the pre-training, training, and test data, respectively. We confirmed that effective attacks had a higher probability of a successful shot than ineffective attacks.
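As a quick arithmetic check of the counts above, the sketch below verifies that, in each split, the outcome counts (shot successes, shot failures, and turnovers) sum to the same total as the effectiveness counts. It assumes that every attack carries exactly one outcome and one effectiveness label, as the matching totals suggest; the dictionary layout is illustrative.

```python
# Counts copied from the text; the key names are illustrative.
COUNTS = {
    "pre-training": {"success": 13681, "failure": 15443, "turnover": 5572,
                     "effective": 16976, "ineffective": 17720},
    "training": {"success": 4677, "failure": 5092, "turnover": 1691,
                 "effective": 5602, "ineffective": 5858},
    "test": {"success": 517, "failure": 577, "turnover": 211,
             "effective": 664, "ineffective": 641},
}


def totals(c):
    """Return (total by outcome, total by effectiveness label)."""
    by_outcome = c["success"] + c["failure"] + c["turnover"]
    by_label = c["effective"] + c["ineffective"]
    return by_outcome, by_label


if __name__ == "__main__":
    for split, c in COUNTS.items():
        by_outcome, by_label = totals(c)
        share = c["effective"] / by_label  # fraction of effective attacks
        print(split, by_outcome, by_label, round(share, 3))
```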

After computing the attack effectiveness, we predict the effective attacks at the next time stamp using the pre-training samples and logistic regression. In the soccer domain, some studies have evaluated players and teams (e.g., [80, 81]) based on a prediction model (i.e., a classifier), on the assumption that a good play is one that will bring good outcomes (e.g., a score or ball recovery) in the future. These approaches can transform a discrete value (i.e., an outcome) into a continuous value (e.g., a probability). Following these studies, we also predict the effective attack at the next time stamp. To verify the prediction accuracy, we compared logistic regression with LightGBM [82], a popular classifier using a highly efficient gradient boosting decision tree, and with a baseline that predicts every attack as effective (i.e., the same label for all samples). For the training data, the accuracies of logistic regression, LightGBM, and the all-same-label prediction were 0.838, 0.838, and 0.731, respectively. For the test data, these were 0.679, 0.676, and 0.568, respectively. We confirmed that logistic regression performed competitively with LightGBM and achieved higher accuracy than the all-same-label prediction, while maintaining higher interpretability.

M Additional analysis in basketball dataset

As an additional analysis, we report the endpoint errors and distributions of long-term covariate prediction for the basketball data. The results in Table IV show that the tendency of the endpoint prediction error across models is similar to that of the mean prediction error, in which our approaches (TV-CRN and TGV-CRN) show the best performance. In addition, the distribution of player velocity in Fig. 8 shows that our approaches without theory-based computation had a wider distribution, closer to the ground truth, than the other models, and that those with theory-based computation had fewer zero-velocity bins, as in the ground truth, than the other models.

Model             $L_{Covariates}$ at endpoint
RNN               1.637 ± 0.0096
DSW [7]
GCRN [8]
GCRN+X            2.266 ± 0.0138
GV-CRN            1.937 ± 0.0116
TG-CRN            1.482 ± 0.0090
TV-CRN            1.566 ± 0.0094
TGV-CRN (full)    1.510 ± 0.0093

TABLE IV: Endpoint prediction error on the NBA dataset.
Figure 8: Histograms of player velocity in NBA dataset.
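To make the distinction between the mean and endpoint prediction errors concrete, the sketch below computes both for a single predicted trajectory. It assumes, purely for illustration, that the per-step error is the Euclidean distance between predicted and true two-dimensional positions; the covariate loss actually used in this paper is defined in the main text, and the function names are hypothetical.

```python
import math


def step_errors(pred, true):
    """Euclidean distance between predicted and true positions per step."""
    return [math.dist(p, t) for p, t in zip(pred, true)]


def mean_error(pred, true):
    """Average of the per-step errors over the prediction horizon."""
    errs = step_errors(pred, true)
    return sum(errs) / len(errs)


def endpoint_error(pred, true):
    """Error at the final step of the prediction horizon only."""
    return step_errors(pred, true)[-1]


if __name__ == "__main__":
    # Toy trajectories (made up): the prediction drifts away over time.
    predicted = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0)]
    observed = [(0.0, 0.0), (1.0, 1.0), (2.0, 2.0)]
    print(mean_error(predicted, observed))
    print(endpoint_error(predicted, observed))
```

In long-term prediction, errors typically accumulate over time, so the endpoint error is usually at least as large as the mean error.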

N Convergence in the learning of our models

We illustrate the change in the validation losses of our models over the training epochs in Figs. 9, 10, and 11. Compared with the CARLA results, which show better outcome and covariate prediction than the other datasets, the Boid and NBA results sometimes show unstable learning curves, but most of the models eventually converge on all datasets. To compare the models fairly, we trained all models for 20 epochs.

Figure 9: The learning curve of our model training in CARLA dataset. Left and right subfigures indicate losses in outcome and covariate prediction, respectively.
Figure 10: The learning curve of our model training in Boid dataset.
Figure 11: The learning curve of our model training in NBA dataset.

References

  • [1] T. A. Glass, S. N. Goodman, M. A. Hernán, and J. M. Samet, “Causal inference in public health,” Annual Review of Public Health, vol. 34, pp. 61–75, 2013.
  • [2] N. Baum-Snow and F. Ferreira, “Causal inference in urban and regional economics,” in Handbook of regional and urban economics.   Elsevier, 2015, vol. 5, pp. 3–68.
  • [3] P. Wang, W. Sun, D. Yin, J. Yang, and Y. Chang, “Robust tree-based causal inference for complex ad effectiveness analysis,” in Proceedings of the Eighth ACM International Conference on Web Search and Data Mining, 2015, pp. 67–76.
  • [4] B. Lim, A. M. Alaa, and M. van der Schaar, “Forecasting treatment responses over time using recurrent marginal structural networks.” Advances in Neural Information Processing Systems, vol. 18, pp. 7483–7493, 2018.
  • [5] I. Bica, A. M. Alaa, J. Jordon, and M. van der Schaar, “Estimating counterfactual treatment outcomes over time through adversarially balanced representations,” in International Conference on Learning Representations, 2020.
  • [6] I. Bica, A. Alaa, and M. Van Der Schaar, “Time series deconfounder: Estimating treatment effects over time in the presence of hidden confounders,” in International Conference on Machine Learning.   PMLR, 2020, pp. 884–895.
  • [7] R. Liu, C. Yin, and P. Zhang, “Estimating individual treatment effects with time-varying confounders,” in 2020 IEEE International Conference on Data Mining (ICDM).   IEEE, 2020, pp. 382–391.
  • [8] J. Ma, R. Guo, C. Chen, A. Zhang, and J. Li, “Deconfounding with networked observational data in a dynamic environment,” in Proceedings of the 14th ACM International Conference on Web Search and Data Mining, 2021, pp. 166–174.
  • [9] T. Vicsek, A. Czirók, E. Ben-Jacob, I. Cohen, and O. Shochet, “Novel type of phase transition in a system of self-driven particles,” Physical Review Letters, vol. 75, no. 6, pp. 1226–1229, 1995.
  • [10] I. D. Couzin, J. Krause, R. James, G. D. Ruxton, and N. R. Franks, “Collective memory and spatial sorting in animal groups,” Journal of Theoretical Biology, vol. 218, no. 1, pp. 1–11, 2002.
  • [11] K. Fujii, T. Kawasaki, Y. Inaba, and Y. Kawahara, “Prediction and classification in equation-free collective motion dynamics,” PLoS Computational Biology, vol. 14, no. 11, p. e1006545, 2018.
  • [12] K. Fujii, N. Takeishi, M. Hojo, Y. Inaba, and Y. Kawahara, “Physically-interpretable classification of network dynamics for complex collective motions,” Scientific Reports, vol. 10, no. 3005, 2020.
  • [13] R. A. Yeh, A. G. Schwing, J. Huang, and K. Murphy, “Diverse generation for multi-agent sports games,” in Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 2019, pp. 4610–4619.
  • [14] K. Fujii, K. Takeuchi, A. Kuribayashi, N. Takeishi, Y. Kawahara, and K. Takeda, “Estimating counterfactual treatment outcomes over time in multi-vehicle simulation,” in Proceedings of the 30th International Conference on Advances in Geographic Information Systems (SIGSPATIAL’22), 2022.
  • [15] J. Chung, K. Kastner, L. Dinh, K. Goel, A. C. Courville, and Y. Bengio, “A recurrent latent variable model for sequential data,” in Advances in Neural Information Processing Systems 28, 2015, pp. 2980–2988.
  • [16] T. Kipf, E. Fetaya, K.-C. Wang, M. Welling, and R. Zemel, “Neural relational inference for interacting systems,” in International Conference on Machine Learning, 2018, pp. 2688–2697.
  • [17] D. B. Rubin, “Bayesian inference for causal effects: The role of randomization,” The Annals of statistics, pp. 34–58, 1978.
  • [18] J. M. Robins and M. A. Hernán, “Estimation of the causal effects of time-varying exposures,” in Longitudinal Data Analysis, G. Fitzmaurice, M. Davidian, G. Verbeke et al., Eds.   New York, NY: Chapman & Hall/CRC Press, 2009, pp. 553–597.
  • [19] J. M. Robins, M. A. Hernan, and B. Brumback, “Marginal structural models and causal inference in epidemiology,” 2000.
  • [20] M. A. Hernán and J. M. Robins, Causal inference.   CRC Boca Raton, FL, 2010.
  • [21] J. Pearl, Causality.   Cambridge university press, 2009.
  • [22] E. Zhan, S. Zheng, Y. Yue, L. Sha, and P. Lucey, “Generating multi-agent trajectories using programmatic weak supervision,” in International Conference on Learning Representations, 2019.
  • [23] M. Fraccaro, S. K. Sønderby, U. Paquet, and O. Winther, “Sequential neural models with stochastic layers,” in Advances in Neural Information Processing Systems 29, 2016, pp. 2199–2207.
  • [24] A. G. A. P. Goyal, A. Sordoni, M.-A. Côté, N. R. Ke, and Y. Bengio, “Z-forcing: Training stochastic recurrent networks,” in Advances in Neural Information Processing Systems 30, 2017, pp. 6713–6723.
  • [25] M. Zaheer, S. Kottur, S. Ravanbakhsh, B. Poczos, R. R. Salakhutdinov, and A. J. Smola, “Deep sets,” Advances in Neural Information Processing Systems, vol. 30, pp. 3394–3404, 2017.
  • [26] N. Takeishi and A. Kalousis, “Physics-integrated variational autoencoders for robust and interpretable generative modeling,” Advances in Neural Information Processing Systems, vol. 34, pp. 14 809–14 821, 2021.
  • [27] K. Fujii, N. Takeishi, Y. Kawahara, and K. Takeda, “Decentralized policy learning with partial observation and mechanical constraints for multiperson modeling,” Neural Networks, vol. 171, pp. 40–52, 2024.
  • [28] P. R. Rosenbaum and D. B. Rubin, “The central role of the propensity score in observational studies for causal effects,” Biometrika, vol. 70, no. 1, pp. 41–55, 1983.
  • [29] U. Shalit, F. D. Johansson, and D. Sontag, “Estimating individual treatment effect: generalization bounds and algorithms,” in International Conference on Machine Learning.   PMLR, 2017, pp. 3076–3085.
  • [30] Y. Ganin, E. Ustinova, H. Ajakan, P. Germain, H. Larochelle, F. Laviolette, M. Marchand, and V. Lempitsky, “Domain-adversarial training of neural networks,” The Journal of Machine Learning Research, vol. 17, no. 1, pp. 2096–2030, 2016.
  • [31] J. Robins, “A new approach to causal inference in mortality studies with a sustained exposure period—application to control of the healthy worker survivor effect,” Mathematical Modelling, vol. 7, no. 9-12, pp. 1393–1512, 1986.
  • [32] J. M. Robins, “Correcting for non-compliance in randomized trials using structural nested mean models,” Communications in Statistics-Theory and methods, vol. 23, no. 8, pp. 2379–2412, 1994.
  • [33] Y. Xu, Y. Xu, and S. Saria, “A bayesian nonparametric approach for estimating individualized treatment-response curves,” in Machine Learning for Healthcare Conference.   PMLR, 2016, pp. 282–300.
  • [34] P. Schulam and S. Saria, “Reliable decision support using counterfactual models,” Advances in Neural Information Processing Systems, vol. 30, pp. 1697–1708, 2017.
  • [35] H. Soleimani, A. Subbaswamy, and S. Saria, “Treatment-response models for counterfactual reasoning with continuous-time, continuous-valued interventions,” in 33rd Conference on Uncertainty in Artificial Intelligence, UAI 2017.   AUAI Press Corvallis, OR, 2017.
  • [36] J. Roy, K. J. Lum, and M. J. Daniels, “A bayesian nonparametric approach to marginal structural models for point treatments and a continuous or survival outcome,” Biostatistics, vol. 18, no. 1, pp. 32–47, 2017.
  • [37] J. L. Hill, “Bayesian nonparametric modeling for causal inference,” Journal of Computational and Graphical Statistics, vol. 20, no. 1, pp. 217–240, 2011.
  • [38] S. Wager and S. Athey, “Estimation and inference of heterogeneous treatment effects using random forests,” Journal of the American Statistical Association, vol. 113, no. 523, pp. 1228–1242, 2018.
  • [39] A. Alaa and M. Schaar, “Limits of estimating heterogeneous treatment effects: Guidelines for practical algorithm design,” in International Conference on Machine Learning.   PMLR, 2018, pp. 129–138.
  • [40] F. Johansson, U. Shalit, and D. Sontag, “Learning representations for counterfactual inference,” in International Conference on Machine Learning.   PMLR, 2016, pp. 3020–3029.
  • [41] L. Yao, S. Li, Y. Li, M. Huai, J. Gao, and A. Zhang, “Representation learning for treatment effect estimation from observational data,” Advances in Neural Information Processing Systems, vol. 31, 2018.
  • [42] J. Yoon, J. Jordon, and M. Van Der Schaar, “Ganite: Estimation of individualized treatment effects using generative adversarial nets,” in International Conference on Learning Representations, 2018.
  • [43] C. Shi, D. M. Blei, and V. Veitch, “Adapting neural networks for the estimation of treatment effects,” in Proceedings of the 33rd International Conference on Neural Information Processing Systems, 2019, pp. 2507–2517.
  • [44] P. Grecov, A. N. Prasanna, K. Ackermann, S. Campbell, D. Scott, D. I. Lubman, and C. Bergmeir, “Probabilistic causal effect estimation with global neural network forecasting models,” IEEE Transactions on Neural Networks and Learning Systems, pp. 1–15, July 2022.
  • [45] M. Abroshan, K. H. Yip, C. Tekin, and M. van der Schaar, “Conservative policy construction using variational autoencoders for logged data with missing values,” IEEE Transactions on Neural Networks and Learning Systems, pp. 1–11, 2022.
  • [46] Q. Li, Z. Wang, S. Liu, G. Li, and G. Xu, “Causal optimal transport for treatment effect estimation,” IEEE Transactions on Neural Networks and Learning Systems, pp. 1–13, 2021.
  • [47] Z. Wu, S. Pan, G. Long, J. Jiang, and C. Zhang, “Graph wavenet for deep spatial-temporal graph modeling,” in Proceedings of the 28th International Joint Conference on Artificial Intelligence, 2019, pp. 1907–1913.
  • [48] Q. Zhang, J. Chang, G. Meng, S. Xiang, and C. Pan, “Spatio-temporal graph structure learning for traffic forecasting,” in Proceedings of the AAAI conference on artificial intelligence, vol. 34, no. 01, 2020, pp. 1177–1185.
  • [49] K. Takeuchi, R. Nishida, H. Kashima, and M. Onishi, “Causal effect estimation on hierarchical spatial graph data,” in Proceedings of the 29th ACM SIGKDD Conference on Knowledge Discovery and Data Mining, 2023, pp. 2145–2154.
  • [50] Z. Guo, T. Xiao, C. Aggarwal, H. Liu, and S. Wang, “Counterfactual learning on graphs: A survey,” arXiv preprint arXiv:2304.01391, 2023.
  • [51] N. Seedat, F. Imrie, A. Bellot, Z. Qian, and M. van der Schaar, “Continuous-time modeling of counterfactual outcomes using neural controlled differential equations,” in International Conference on Machine Learning.   PMLR, 2022, pp. 19 497–19 521.
  • [52] S. Jiang, Z. Huang, X. Luo, and Y. Sun, “Cf-gode: Continuous-time causal inference for multi-agent dynamical systems,” arXiv preprint arXiv:2306.11216, 2023.
  • [53] G. Chen, J. Li, J. Lu, and J. Zhou, “Human trajectory prediction via counterfactual analysis,” arXiv preprint arXiv:2107.14202, 2021.
  • [54] K. Fujii, N. Takeishi, K. Tsutsui, E. Fujioka, N. Nishiumi, R. Tanaka, M. Fukushiro, K. Ide, H. Kohno, K. Yoda, S. Takahashi, S. Hiryu, and Y. Kawahara, “Learning interaction rules from multi-animal trajectories via augmented behavioral models,” Advances in Neural Information Processing Systems, vol. 34, pp. 11 108–11 122, 2021.
  • [55] H. Nakahara, K. Takeda, and K. Fujii, “Estimating the effect of hitting strategies in baseball using counterfactual virtual simulation with deep learning,” International Journal of Computer Science in Sport, vol. 22, no. 1, pp. 1–12, January 2022.
  • [56] M. Teranishi, K. Tsutsui, K. Takeda, and K. Fujii, “Evaluation of creating scoring opportunities for teammates in soccer via trajectory prediction,” in International Workshop on Machine Learning and Data Mining for Sports Analytics.   NY: Springer, 2022, pp. 53–73.
  • [57] A. Lerer, S. Gross, and R. Fergus, “Learning physical intuition of block towers by example,” in International conference on machine learning.   PMLR, 2016, pp. 430–438.
  • [58] G. Liu and O. Schulte, “Deep reinforcement learning in ice hockey for context-aware player evaluation,” in Proceedings of the 27th International Joint Conference on Artificial Intelligence, 2018, pp. 3442–3448.
  • [59] G. Liu, Y. Luo, O. Schulte, and T. Kharrat, “Deep soccer analytics: learning an action-value function for evaluating soccer players,” Data Mining and Knowledge Discovery, vol. 34, no. 5, pp. 1531–1559, 2020.
  • [60] P. Rahimian, J. Van Haaren, T. Abzhanova, and L. Toka, “Beyond action valuation: A deep reinforcement learning framework for optimizing player decisions in soccer,” in 16th Annual MIT Sloan Sports Analytics Conference. Boston, MA, USA: MIT, 2022, p. 25.
  • [61] H. Nakahara, K. Tsutsui, K. Takeda, and K. Fujii, “Action valuation of on-and off-ball soccer players based on multi-agent deep reinforcement learning,” IEEE Access, vol. 11, pp. 131 237–131 244, 2023.
  • [62] V. Ramanishka, Y.-T. Chen, T. Misu, and K. Saenko, “Toward driving scene understanding: A dataset for learning driver behavior and causal reasoning,” in Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 2018, pp. 7699–7707.
  • [63] T. You and B. Han, “Traffic accident benchmark for causality recognition,” in European Conference on Computer Vision.   NY: Springer, 2020, pp. 540–556.
  • [64] D. McDuff, Y. Song, J. Lee, V. Vineet, S. Vemprala, N. Gyde, H. Salman, S. Ma, K. Sohn, and A. Kapoor, “Causalcity: Complex simulations with agency for causal discovery and reasoning,” arXiv preprint arXiv:2106.13364, 2021.
  • [65] K. Takeuchi, R. Nishida, H. Kashima, and M. Onishi, “Grab the reins of crowds: Estimating the effects of crowd movement guidance using causal inference,” in Proceedings of the 20th International Conference on Autonomous Agents and MultiAgent Systems, 2021, pp. 1290–1298.
  • [66] D. R. Yam and M. J. Lopez, “What was lost? a causal estimate of fourth down behavior in the national football league,” Journal of Sports Analytics, vol. 5, no. 3, pp. 153–167, 2019.
  • [67] A. Toumi and M. Lopez, “From grapes and prunes to apples and apples: Using matched methods to estimate optimal zone entry decision-making in the national hockey league,” in Carnegie Mellon Sports Analytics Conference 2019, 2019.
  • [68] C. Gibbs, R. Elmore, and B. Fosdick, “The causal effect of a timeout at stopping an opposing run in the nba,” arXiv preprint arXiv:2011.11691, 2020.
  • [69] H. Nakahara, K. Takeda, and K. Fujii, “Pitching strategy evaluation via stratified analysis using propensity score,” Journal of Quantitative Analysis in Sports, vol. 19, no. 2, pp. 91–102, 2023.
  • [70] D. M. Vock and L. F. B. Vock, “Estimating the effect of plate discipline using a causal inference framework: an application of the g-computation algorithm,” Journal of Quantitative Analysis in Sports, vol. 14, no. 2, pp. 37–56, 2018.
  • [71] A. Dosovitskiy, G. Ros, F. Codevilla, A. Lopez, and V. Koltun, “Carla: An open urban driving simulator,” in Conference on Robot Learning.   PMLR, 2017, pp. 1–16.
  • [72] K. Cho, B. Van Merriënboer, D. Bahdanau, and Y. Bengio, “On the properties of neural machine translation: Encoder-decoder approaches,” arXiv preprint arXiv:1409.1259, 2014.
  • [73] C. J. Willmott and K. Matsuura, “Advantages of the mean absolute error (mae) over the root mean square error (rmse) in assessing average model performance,” Climate Research, vol. 30, no. 1, pp. 79–82, 2005.
  • [74] S. Kato, E. Takeuchi, Y. Ishiguro, Y. Ninomiya, K. Takeda, and T. Hamada, “An open approach to autonomous vehicles,” IEEE Micro, vol. 35, no. 6, pp. 60–68, 2015.
  • [75] C. W. Reynolds, “Flocks, herds and schools: A distributed behavioral model,” in Proceedings of the 14th annual Conference on Computer Graphics and Interactive Techniques, 1987, pp. 25–34.
  • [76] K. Fujii, Y. Inaba, and Y. Kawahara, “Koopman spectral kernels for comparing complex dynamics: Application to multiagent sport plays,” in European Conference on Machine Learning and Knowledge Discovery in Databases (ECML-PKDD’17).   NY: Springer, 2017, pp. 127–139.
  • [77] D. P. Kingma and J. Ba, “Adam: A method for stochastic optimization,” in International Conference on Learning Representations, 2015.
  • [78] M. Quigley, K. Conley, B. Gerkey, J. Faust, T. Foote, J. Leibs, R. Wheeler, A. Y. Ng et al., “Ros: an open-source robot operating system,” in ICRA workshop on open source software, vol. 3, no. 3.2.   Kobe, Japan, 2009, p. 5.
  • [79] K. Fujii, K. Yokoyama, T. Koyama, A. Rikukawa, H. Yamada, and Y. Yamamoto, “Resilient help to switch and overlap hierarchical subsystems in a small human group,” Scientific Reports, vol. 6, 2016.
  • [80] T. Decroos, L. Bransen, J. Van Haaren, and J. Davis, “Actions speak louder than goals: Valuing player actions in soccer,” in KDD, 2019, pp. 1851–1861.
  • [81] K. Toda, M. Teranishi, K. Kushiro, and K. Fujii, “Evaluation of soccer team defense based on prediction models of ball recovery and being attacked,” PLoS One, vol. 17, no. 1, p. e0263051, 2022.
  • [82] G. Ke, Q. Meng, T. Finley, T. Wang, W. Chen, W. Ma, Q. Ye, and T.-Y. Liu, “Lightgbm: A highly efficient gradient boosting decision tree,” Advances in Neural Information Processing Systems, vol. 30, pp. 3146–3154, 2017.