
Learning Implicit Generative Models Using
Differentiable Graph Tests

Josip Djolonga
ETH Zürich
josipd@inf.ethz.ch
   Andreas Krause
ETH Zürich
krausea@ethz.ch
Abstract

Recently, there has been a growing interest in the problem of learning rich implicit models — those from which we can sample, but cannot evaluate their density. These models apply some parametric function, such as a deep network, to a base measure, and are learned end-to-end using stochastic optimization. One strategy for devising a loss function is through the statistics of two-sample tests — if we can fool a statistical test, the learned distribution should be a good model of the true data. However, not all tests can easily fit into this framework, as they might not be differentiable with respect to the data points, and hence with respect to the parameters of the implicit model. Motivated by this problem, in this paper we show how two such classical tests, the Friedman-Rafsky and $k$-nearest neighbour tests, can be effectively smoothed using ideas from undirected graphical models — the matrix tree theorem and cardinality potentials. Moreover, as we show experimentally, smoothing can significantly increase the power of the test, which may be of independent interest. Finally, we apply our method to learn implicit models.

1 Introduction

The main motivation for our work is that of learning implicit models, i.e., those from which we can easily sample, but cannot evaluate their density. Formally, we can generate a sample from an implicit distribution $Q$ by first drawing $\mathbf{z}$ from some known and fixed distribution $Q_{0}$, typically Gaussian or uniform, and then passing it through some differentiable function $f_{\bm{\theta}}$ parametrized by some vector ${\bm{\theta}}$ to generate $\mathbf{x}=f_{\bm{\theta}}(\mathbf{z})\sim Q$. The goal is then to optimize the parameters ${\bm{\theta}}$ of the mapping $f_{\bm{\theta}}$ so that $Q$ is as close as possible to some target distribution $P$, which we can access only via iid samples. The approach that we undertake in this paper is that of defeating statistical two-sample tests. These tests operate in the following setting — given two sets of iid samples, $X_{1}=\{\mathbf{x}_{1},\mathbf{x}_{2},\ldots,\mathbf{x}_{n_{1}}\}$ from $P$, and $X_{2}=\{\mathbf{x}_{n_{1}+1},\mathbf{x}_{n_{1}+2},\ldots,\mathbf{x}_{n_{1}+n_{2}}\}$ from $Q$, we have to distinguish between the following hypotheses

H_{0}\colon P=Q\quad\textrm{ vs }\quad H_{1}\colon P\neq Q.

The tests that we consider start by defining a function $T\colon(\mathbb{R}^{d})^{n_{1}}\times(\mathbb{R}^{d})^{n_{2}}\to\mathbb{R}$ that should result in a low value if the two samples come from different distributions. Then, the hypothesis $H_{0}$ is rejected at significance level $\alpha\in[0,1]$ if $T(X_{1},X_{2})$ is lower than some threshold $t_{\alpha}$, which is computed using a permutation test, as explained in Section 2. Going back to the original problem, one intuitive approach would be to maximize the expected statistic $\mathbb{E}_{\mathbf{x}_{i}\sim P,\mathbf{z}_{i}\sim Q_{0}}[T(\{\mathbf{x}_{i}\}_{i=1}^{n_{1}},\{f_{\bm{\theta}}(\mathbf{z}_{i})\}_{i=n_{1}+1}^{n_{1}+n_{2}})]$ using stochastic optimization over the parameters of the mapping $f_{\bm{\theta}}$. However, this requires the availability of the derivatives $\partial T/\partial\mathbf{x}_{i}$, which is unfortunately not always possible. For example, the Friedman-Rafsky (FR) and $k$-nearest neighbours ($k$-NN) tests, which have very desirable statistical properties (including consistency and convergence of their statistics to $f$-divergences), cannot be cast in the above framework as they use the output of a combinatorial optimization problem. Our main contribution is the development of differentiable versions of these tests that remedy the above problem by smoothing their statistics. We moreover show, similarly to these classical tests, that our tests are asymptotically normal under certain conditions, and derive the corresponding $t$-statistic, which can be evaluated with minimal additional complexity. Our smoothed tests can have more power than their classical variants, as we showcase with numerical experiments. Finally, we experimentally learn implicit models in Section 5.

Related work.

The problem of two-sample testing for distributional equality has received significant interest in statistics. For example, the celebrated Kolmogorov-Smirnov test compares two one-dimensional distributions by taking the maximal difference of the empirical CDFs. Another one-dimensional test is the runs test of Wald and Wolfowitz [1], which has been extended to the multivariate case by Friedman and Rafsky [2] (FR). It is exactly this test, together with the $k$-NN test originally suggested in [3], that we analyze. These tests have been analyzed in more detail by Henze and Penrose [4], and by Henze [5] and Schilling [6], respectively. Their asymptotic efficiency has been discussed by Bhattacharya [7]. Chen and Zhang [8] considered the problem of tie breaking when applying the FR test to discrete data and suggested averaging over all minimal spanning trees, which can be seen as a special case of our test in the low-temperature setting. A very prominent test that has been developed more recently is the kernel maximum mean discrepancy (MMD) test of Gretton et al. [9], which we compare with in Section 5. Its test statistic is differentiable and has been used for learning implicit models by Li et al. [10] and Dziugaite et al. [11]. Sutherland et al. [12] consider the problem of learning the kernel by creating a $t$-statistic using a variance estimator. Moreover, they also pioneered the idea of using tests for model criticism — for two fixed distributions, one optimizes over the parameters of the test (the kernel used). The energy test of Székely and Rizzo [13], a special case of the MMD test, has been used by Bellemare et al. [14].

Other approaches for learning implicit models that do not depend on two-sample tests have been developed as well. For example, one approach is to estimate the log-ratio of the distributions [15]. Another approach, which has recently sparked significant interest and can also be seen as estimating the log-ratio of the distributions, is the generative adversarial network (GAN) of Goodfellow et al. [16], who pose the problem as a two-player game. One can, as done in [12], combine GANs with two-sample tests by using the tests as feature matchers at some layer of the generating network [17]. Nowozin et al. [18] minimize an arbitrary $f$-divergence [19] using a GAN framework, which can be related to our approach, because our test statistics converge in the limit to specific $f$-divergences, as explained in Section 2. For an overview of various approaches to learning implicit models we direct the reader to Mohamed and Lakshminarayanan [20].

2 Classical Graph Tests

Let us start by introducing some notation. For any set $X=\{\mathbf{x}_{1},\mathbf{x}_{2},\ldots,\mathbf{x}_{n}\}$ of points in $\mathbb{R}^{d}$, we will denote by $\mathcal{G}(X)=(X,E)$ the complete directed graph defined over the vertex set $X$ with edges $E$ (for the FR test we will arbitrarily choose one of the two edges for each pair of nodes). We will moreover weigh this graph using some function $d\colon\mathbb{R}^{d}\times\mathbb{R}^{d}\to[0,\infty)$; a natural choice would be $d(\mathbf{x},\mathbf{x}^{\prime})=\|\mathbf{x}-\mathbf{x}^{\prime}\|$. Similarly, we will use $d(e)$ for the weight of the edge $e$ under $d(\cdot,\cdot)$. For any labelling of the vertices $\pi\colon X\to\{1,2\}$, and any edge $e\in E$ with adjacent vertices $i$ and $j$, we define $\Delta_{\pi}(e)=\llbracket\pi(i)\neq\pi(j)\rrbracket$, where $\llbracket S\rrbracket$ is the Iverson bracket that evaluates to 1 if $S$ is true and 0 otherwise; i.e., $\Delta_{\pi}(e)$ indicates whether the endpoints of $e$ have different labels under $\pi$. Remember that we are given $n_{1}$ points $X_{1}=\{\mathbf{x}_{1},\mathbf{x}_{2},\ldots,\mathbf{x}_{n_{1}}\}$ from $P$, and $n_{2}$ points $X_{2}=\{\mathbf{x}_{n_{1}+1},\mathbf{x}_{n_{1}+2},\ldots,\mathbf{x}_{n_{1}+n_{2}}\}$ from $Q$. In the remainder of the paper we will use $n=n_{1}+n_{2}$ for the total number of points. The tests are based on the following four-step strategy.

  (i) Pool the samples $X_{1}$ and $X_{2}$ together into $X=X_{1}\cup X_{2}=\{\mathbf{x}_{1},\mathbf{x}_{2},\ldots,\mathbf{x}_{n_{1}+n_{2}}\}$, and create the graph $\mathcal{G}(X)$. Define the mapping $\pi^{*}\colon X\to\{1,2\}$ evaluating to 1 on $X_{1}$ and to 2 on $X_{2}$.

  (ii) Using some well-defined algorithm $\mathcal{A}$, choose a subset $U^{*}=\mathcal{A}(\mathcal{G}(X))$ of the edges of this graph, with the underlying motivation that it defines some neighbourhood structure.

  (iii) Count how many edges in $U^{*}$ connect points from $X_{1}$ with points from $X_{2}$, i.e., compute the statistic $T_{\pi^{*}}(U^{*})=\sum_{e\in U^{*}}\Delta_{\pi^{*}}(e)$.

  (iv) Reject $H_{0}$ for small values of $T_{\pi^{*}}(U^{*})$.

These tests condition on the data and are executed as permutation tests, so that the critical value in step (iv) is computed using the quantiles of $T_{\pi}(U^{*})$ under $\pi\sim H_{0}$, where $\pi\colon X\to\{1,2\}$ is drawn uniformly at random from the set of ${n_{1}+n_{2}\choose n_{1}}$ labellings that map exactly $n_{1}$ points from $X$ to 1. Formally, the $p$-value is given as $\mathbb{E}_{\pi\sim H_{0}}[\llbracket T_{\pi^{*}}(U^{*})\geq T_{\pi}(U^{*})\rrbracket]$. We are now ready to introduce the two tests that we consider in this paper, which are obtained by using a different neighbourhood selection algorithm $\mathcal{A}$ in step (ii).
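For concreteness, a Monte Carlo version of this permutation test can be sketched as follows (the function name and the dense 0/1 adjacency layout for $U^{*}$ are ours, and the exact $p$-value would average over all ${n_{1}+n_{2}\choose n_{1}}$ labellings rather than a random subset):

import numpy as np

def permutation_p_value(adj, labels, n_perms=1000, rng=None):
    """Monte Carlo estimate of the p-value E_{pi~H0}[[T_{pi*}(U*) >= T_pi(U*)]].

    adj:    (n, n) 0/1 adjacency matrix of the chosen edge set U*,
    labels: length-n array holding pi* (1 for X1, 2 for X2).
    The classical tests reject H0 for small values of T, so we count the
    random labellings whose statistic is at most the observed one.
    """
    rng = np.random.default_rng() if rng is None else rng
    labels = np.asarray(labels)

    def statistic(lab):
        # T_pi(U*): number of edges in U* whose endpoints get different labels
        return (adj * (lab[:, None] != lab[None, :])).sum()

    t_obs = statistic(labels)
    hits = sum(statistic(rng.permutation(labels)) <= t_obs for _ in range(n_perms))
    return hits / n_perms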

Friedman-Rafsky (FR).

This test, developed by Friedman and Rafsky [2], uses the minimum-spanning tree (MST) of 𝒢(X)\mathcal{G}(X) as the neighbourhood structure UU^{*}, which can be computed using the classical algorithms of Prim [21] and Kruskal [22] in time O(n2logn)O(n^{2}\log n). If we use d(𝐱i,𝐱j)=𝐱i𝐱jd(\mathbf{x}_{i},\mathbf{x}_{j})=\|\mathbf{x}_{i}-\mathbf{x}_{j}\|, the problem is also known as the Euclidean spanning tree problem, and in this case Henze and Penrose [4] have proven that the test is consistent and has the following asymptotic limit.

Theorem 1 ([4]).

If d(𝐱,𝐱)=𝐱𝐱d(\mathbf{x},\mathbf{x}^{\prime})=\|\mathbf{x}-\mathbf{x}^{\prime}\| and n1/(n1+n2)α(0,1)n_{1}/(n_{1}+n_{2})\to\alpha\in(0,1), then it almost surely holds that

Tπ(U)n1+n22α(1α)p(𝐱)q(𝐱)αp(𝐱)+(1α)q(𝐱)𝑑𝐱,\frac{T_{\pi^{*}}(U^{*})}{n_{1}+n_{2}}\to 2\alpha(1-\alpha)\int\frac{p(\mathbf{x})q(\mathbf{x})}{\alpha p(\mathbf{x})+(1-\alpha)q(\mathbf{x})}d\mathbf{x},

where pp and qq are the densities of PP and QQ.

As noted by Berisha and Hero [23], after some algebraic manipulation of the right hand side of the above equation, we obtain that 1Tπ(U)n1+n22n1n21-T_{\pi^{*}}(U^{*})\frac{n_{1}+n_{2}}{2n_{1}n_{2}} converges almost surely to the following ff-divergence [19]

DαFR(PQ)\displaystyle D^{\textrm{FR}}_{\alpha}(P\,\|\,Q) =14α(1α)(αp(𝐱)(1α)q(𝐱))2αp(𝐱)+(1α)q(𝐱)𝑑𝐱\displaystyle=\frac{1}{4\alpha(1-\alpha)}\int\frac{(\alpha p(\mathbf{x})-(1-\alpha)q(\mathbf{x}))^{2}}{\alpha p(\mathbf{x})+(1-\alpha)q(\mathbf{x})}d\mathbf{x}
(2α1)24α(1α).\displaystyle-\frac{(2\alpha-1)^{2}}{4\alpha(1-\alpha)}.

In [23] it is also noted that if $n_{1}=n_{2}$, then $\alpha=1/2$; since in that case $1/(4\alpha(1-\alpha))=1$ and $(2\alpha-1)^{2}=0$, $D^{\textrm{FR}}_{1/2}$ is equal to $\frac{1}{2}\int\frac{(p(\mathbf{x})-q(\mathbf{x}))^{2}}{p(\mathbf{x})+q(\mathbf{x})}d\mathbf{x}$, which is, up to a constant factor, the symmetric $\chi^{2}$ divergence.

[Plot of the functions $f(x)$ generating $D^{\textrm{FR}}_{1/2}$ and $D^{\textrm{NN}}_{1/2}$.]
Figure 1: The functions generating the ff-divergences.

kk-nearest-neighbours (kk-NN).

Perhaps the most intuitive way to construct a neighbourhood structure is to connect each point $\mathbf{x}_{j}\in X$ to its $k$ nearest neighbours. Specifically, we will add the edge $\mathbf{x}_{i}\to\mathbf{x}_{j}$ to $U^{*}$ iff $\mathbf{x}_{i}$ is one of the $k$ closest neighbours of $\mathbf{x}_{j}$ as measured by $d(\mathbf{x},\mathbf{x}^{\prime})$. If one uses the Euclidean norm, then the asymptotic distribution and the consistency of the test have been proven by Schilling [6]. These results have been extended to arbitrary norms by Henze [5], who also proved the limiting behaviour of the statistic as $n\to\infty$.

Theorem 2 ([5]).

If n1/(n1+n2)α(0,1)n_{1}/(n_{1}+n_{2})\to\alpha\in(0,1), then 1Tπ(U)(n1+n2)k1-\frac{T_{\pi^{*}}(U^{*})}{(n_{1}+n_{2})k} converges in probability to

DαNN(PQ)α2p2(𝐱)+(1α)2q2(𝐱)αp(𝐱)+(1α)q(𝐱)𝑑𝐱,D^{\textrm{NN}}_{\alpha}(P\,\|\,Q)\equiv\int\frac{\alpha^{2}p^{2}(\mathbf{x})+(1-\alpha)^{2}q^{2}(\mathbf{x})}{\alpha p(\mathbf{x})+(1-\alpha)q(\mathbf{x})}d\mathbf{x},

where pp and qq are the continuous densities of PP and QQ.

As for the FR test, we can also re-write the limit as an $f$-divergence corresponding to $f(t)=(\alpha^{2}t^{2}+(1-\alpha)^{2})/(\alpha t+(1-\alpha))$; this $f$ does not vanish at one, but we can simply shift it. Moreover, if we compare the integrands in $D^{\textrm{FR}}_{\alpha}$ and $D^{\textrm{NN}}_{\alpha}$, we see that they are closely related: they differ by the term $2\alpha(1-\alpha)p(\mathbf{x})q(\mathbf{x})$ in the numerator. This can also be seen from Figure 1, where we plot the corresponding $f$-functions for the $n_{1}=n_{2}$ case.
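Indeed, expanding the square in the FR numerator makes the relation explicit:

(\alpha p(\mathbf{x})-(1-\alpha)q(\mathbf{x}))^{2}=\alpha^{2}p^{2}(\mathbf{x})+(1-\alpha)^{2}q^{2}(\mathbf{x})-2\alpha(1-\alpha)p(\mathbf{x})q(\mathbf{x}),

so, up to the constant $1/(4\alpha(1-\alpha))$ in front of $D^{\textrm{FR}}_{\alpha}$, the two integrands share the same denominator and their numerators differ exactly by $2\alpha(1-\alpha)p(\mathbf{x})q(\mathbf{x})$.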

3 Differentiable Graph Tests

While the tests from the previous section have been studied from a statistical perspective, we can not use them to train implicit models because the derivatives T/𝐱i\partial T/\partial\mathbf{x}_{i} are either zero or do not exist, as TT takes on finitely many values. The strategy that we undertake in this paper is to smooth them into continuously differentiable functions by relaxing them to expectations in natural probabilistic models. To motivate the models we will introduce, note that for both the kk-NN and the FR test, the optimal neighbourhood is the solution to the following optimization problem

U=argminUEeUd(e)s.t. ν(U)=1,U^{*}=\operatorname*{arg\,min}_{U\subseteq E}\sum_{e\in U}d(e)\;\textrm{s.t.\ }\nu(U)=1, (1)

where ν:2E{0,1}\nu\colon 2^{E}\to\{0,1\} indicates if the set of edges is valid, i.e., if every vertex has exactly kk neighbours in the kk-NN case, or if the set of edges forms a poly-tree in the MST case. Moreover, note that once we fix n1n_{1} and n2n_{2}, the optimization problem (1) depends only on the edge weights d(e)d(e), which we will concatenate in an arbitrary order and store in the vector 𝐝|E|\mathbf{d}\in\mathbb{R}^{|E|}. We want to design a probability distribution over UU that focuses on those configurations UU that are both feasible and have a low cost for problem (1). One such natural choice is the following Gibbs measure

P(U𝐝/λ)=eeUd(e)/λA(𝐝/λ)ν(U),P(U\mid\mathbf{d}/\lambda)=e^{-\sum_{e\in U}d(e)/\lambda-A(-\mathbf{d}/\lambda)}\nu(U), (2)

where λ\lambda is the so-called temperature parameter, and A(𝐝/λ)A(-\mathbf{d}/\lambda) is the log-partition function that ensures that the distribution is normalized. Note that UU^{*} is a MAP configuration of this distribution (2), and the distribution will concentrate on the MAP configurations as λ0\lambda\to 0. Once we have fixed the model, the strategy is clear — replace the statistic Tπ(U)T_{\pi^{*}}(U^{*}) with its expectation 𝔼U[Tπ(U)]{\mathbb{E}}_{U}[T_{\pi^{*}}(U)], which results in the following smooth statistic

T_{\pi^{*}}(U^{*})\;\longrightarrow\;T_{\pi^{*}}^{\lambda}\equiv\mathbb{E}_{U\sim P(\cdot\mid\mathbf{d}/\lambda)}[T_{\pi^{*}}(U)]=\sum_{e\in E}\Delta_{\pi^{*}}(e)\,\bm{\mu}(\mathbf{d}/\lambda)_{e},

where 𝝁(𝐝/λ)\bm{\mu}(\mathbf{d}/\lambda) are the marginal probabilities of the edges, i.e., [𝝁(𝐝/λ)]e=𝔼P(U𝐝/λ)[eU][\bm{\mu}(\mathbf{d}/\lambda)]_{e}=\mathbb{E}_{P(U\mid\mathbf{d}/\lambda)}[\llbracket e\in U\rrbracket]. Hence, we can compute the statistic as long as we can perform inference in (2). To compute its derivatives we can use the fact that (2) is a member of the exponential family. Namely, leveraging the classical properties of the log-partition function [24, Prop. 3.1], we obtain the following identities

\bm{\mu}(\mathbf{d}/\lambda)=\nabla A(-\mathbf{d}/\lambda),\textrm{ and} \qquad (3)
\frac{\partial\bm{\mu}(\mathbf{d}/\lambda)_{e}}{\partial(-d(e^{\prime})/\lambda)}=\mathbb{E}_{P(U\mid\mathbf{d}/\lambda)}[\llbracket\{e,e^{\prime}\}\subseteq U\rrbracket]-\bm{\mu}(\mathbf{d}/\lambda)_{e}\,\bm{\mu}(\mathbf{d}/\lambda)_{e^{\prime}}.

Thus, if we can compute both first- and second-order moments under (2), we get both the smoothed statistic and its derivative. We show how to do this for the kk-NN and FR tests in Section 4.
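For concreteness, once the marginals are available, the smoothed statistic is just a weighted count of cross-edges; a small NumPy sketch (the dense matrix layout and function name are ours):

import numpy as np

def smoothed_statistic(mu, labels):
    """T^lambda_{pi*} = sum_e Delta_{pi*}(e) mu(d/lambda)_e.

    mu:     (n, n) array with the edge marginals mu(d/lambda)_{i->j}
            (zero where there is no edge),
    labels: length-n array with entries 1 or 2 (the pooled labelling pi*).
    """
    labels = np.asarray(labels)
    delta = (labels[:, None] != labels[None, :]).astype(float)  # Delta_{pi*}(e)
    return (mu * delta).sum()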

A smooth pp-value.

Even though one can directly use the smoothed test statistic $T_{\pi^{*}}^{\lambda}$ as an objective when learning implicit models, higher values of this statistic do not necessarily correspond to higher $p$-values. Remember that to compute a $p$-value, one has to run a permutation test by computing quantiles of $T_{\pi}^{\lambda}$ under random draws of the labelling $\pi\sim H_{0}$. However, as this procedure is not smooth and can be costly to compute, we suggest the following $t$-statistic as an alternative that does not suffer from these problems

t^{\lambda}_{\pi^{*}}=\frac{T_{\pi^{*}}^{\lambda}-\mathbb{E}_{\pi\sim H_{0}}[T_{\pi}^{\lambda}]}{\sqrt{\mathbb{V}_{\pi\sim H_{0}}[T_{\pi}^{\lambda}]}}. \qquad (4)

The same strategy has been undertaken for the FR and $k$-NN tests in [2, 4, 6]. Before we show how to compute the first two moments under $H_{0}$, we need to define the matrix $\Pi$ holding the second moments of the variables $\Delta_{\pi}(e)$.

Lemma 1 ([2]).

The matrix Π|E|×|E|\Pi\in\mathbb{R}^{|E|\times|E|} with entries Πe,e=𝔼πH0[Δπ(e)Δπ(e)]\Pi_{e,e^{\prime}}={\mathbb{E}}_{\pi\sim H_{0}}[\Delta_{\pi}(e)\Delta_{\pi}(e^{\prime})] is equal to

\Pi_{e,e^{\prime}}=\begin{cases}\frac{2n_{1}n_{2}}{n(n-1)}&\textrm{ if }\delta(e)=\delta(e^{\prime}),\\ \frac{n_{1}n_{2}}{n(n-1)}&\textrm{ if }|\delta(e)\cap\delta(e^{\prime})|=1,\\ \frac{4n_{1}n_{2}(n_{1}-1)(n_{2}-1)}{n(n-1)(n-2)(n-3)}&\textrm{ if }\delta(e)\cap\delta(e^{\prime})=\emptyset,\end{cases}

where δ(e)\delta(e) is the set of vertices incident to the edge eEe\in E.

Theorem 3.

Assume that all valid configurations $U$ satisfy $|U|=m$, i.e., that $\nu(U)\neq 0$ implies $|U|=m$ (note that we have $m=kn$ for $k$-NN and $m=n-1$ for FR). Then, the first two moments of the statistic under $H_{0}$ are

\mathbb{E}_{\pi\sim H_{0}}[T^{\lambda}_{\pi}]=2mn_{1}n_{2}/n(n-1),\textrm{ and}
\mathbb{V}_{\pi\sim H_{0}}[T^{\lambda}_{\pi}]=\bm{\mu}(\mathbf{d}/\lambda)^{T}\Pi\,\bm{\mu}(\mathbf{d}/\lambda)-\frac{4n_{1}^{2}n_{2}^{2}}{n^{2}(n-1)^{2}}m^{2}.

While the computation of the mean is trivial, it seems that the computation of the variance needs O(|E|2)O(|E|^{2}) operations. However, we can simplify its computation to O(|E|)O(|E|) using the following result.

Lemma 2.

Define χ1=n1n2n(n1)\chi_{1}=\frac{n_{1}n_{2}}{n(n-1)} and χ2=4(n11)(n21)(n2)(n3)\chi_{2}=\frac{4(n_{1}-1)(n_{2}-1)}{(n-2)(n-3)}. Then, the variance can be computed as

\sigma^{2}=\chi_{1}(1-\chi_{2})\sum_{v}\Big(\sum_{e\in\delta(v)}\mu_{e}\Big)^{2}+\chi_{1}\chi_{2}\sum_{e\|e^{\prime}}\mu_{e}\mu_{e^{\prime}}+\chi_{1}(\chi_{2}-4\chi_{1})m^{2},

where $\sum_{e\|e^{\prime}}$ sums over all pairs of parallel edges, i.e., those connecting the same end-points (including the pairs $e^{\prime}=e$ and $e^{\prime}=\overline{e}$).
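A small NumPy sketch of the resulting $O(|E|)$ computation of the null mean and variance (the dense layout and function name are ours; per the proof in Appendix A, the parallel-edge sum includes the pairs $e^{\prime}=e$ and $e^{\prime}=\overline{e}$):

import numpy as np

def null_moments(mu, n1, n2, m):
    """Mean and variance of T^lambda under H0 (Theorem 3 / Lemma 2) in O(|E|).

    mu: (n, n) array of edge marginals mu_{i->j} with a zero diagonal.
    """
    n = n1 + n2
    chi1 = n1 * n2 / (n * (n - 1))
    chi2 = 4 * (n1 - 1) * (n2 - 1) / ((n - 2) * (n - 3))
    mean = 2 * m * chi1
    node_sums = mu.sum(axis=0) + mu.sum(axis=1)        # sum_{e in delta(v)} mu_e per vertex
    parallel = (mu ** 2).sum() + (mu * mu.T).sum()     # pairs e' = e and e' = reverse of e
    var = (chi1 * (1 - chi2) * (node_sums ** 2).sum()
           + chi1 * chi2 * parallel
           + chi1 * (chi2 - 4 * chi1) * m ** 2)
    return mean, var

The smoothed $t$-statistic (4) is then simply (smoothed_statistic(mu, labels) - mean) / np.sqrt(var).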

Approximate normality of tπλt_{\pi^{*}}^{\lambda}.

To better motivate the use of a $t$-statistic, we can, similarly to the arguments in [2, 4, 6], show that it is close to a normal distribution by casting it as a generalized correlation coefficient [25, 3]. Namely, these are tests whose statistics are of the form $\kappa=\sum_{i=1}^{n}\sum_{j=1}^{n}\overline{\mu}_{i,j}b_{i,j}$, and whose critical values are computed using the distribution of $\sum_{i=1}^{n}\sum_{j=1}^{n}\overline{\mu}_{i,j}b_{\pi(i),\pi(j)}$, where $\pi$ is a random permutation on $\{1,2,\ldots,n\}$. It is easily seen that we can fit the suggested tests in this framework if we set $\overline{\mu}_{i,j}=\frac{1}{2}(\bm{\mu}(\mathbf{d}/\lambda)_{i\to j}+\bm{\mu}(\mathbf{d}/\lambda)_{j\to i})$ and $b_{i,j}=\Delta_{\pi^{*}}(\{i,j\})$. Then, using the conditions of Barbour and Eagleson [26], we obtain the following bound on the deviation from normality.

Theorem 4.

Let n1/(n1+n2)α(0,1)n_{1}/(n_{1}+n_{2})\to\alpha\in(0,1), and define

  • S2=i,j,kμ¯i,jμ¯i,kS_{2}=\sum_{i,j,k}\overline{\mu}_{i,j}\overline{\mu}_{i,k}, i.e., the expected number of edges sharing a vertex,

  • S3=i,j,k,mμ¯i,jμ¯i,kμ¯i,mS_{3}=\sum_{i,j,k,m}\overline{\mu}_{i,j}\overline{\mu}_{i,k}\overline{\mu}_{i,m}, i.e., the expected number of 3 stars, and

  • L4=i,j,k,mμ¯i,jμ¯j,kμ¯k,mL_{4}=\sum_{i,j,k,m}\overline{\mu}_{i,j}\overline{\mu}_{j,k}\overline{\mu}_{k,m}, i.e., the expected number of paths with 4 nodes.

Then, the Wasserstein distance between the permutation null distribution of $t_{\pi}^{\lambda}$ and the standard normal is of order $O\big((nk^{3}+kS_{2}+S_{3}+L_{4})/\sigma^{3}\big)$.

Let us analyze the above bound in the setting in which we will use it — when $n_{1}=n_{2}$. First, let us look at the variance, as formulated in Lemma 2. The last term can be ignored as it is always non-negative, because $\chi_{2}\geq 4\chi_{1}$ (shown in the appendix). Because $\sum_{e\in\delta(v)}\mu_{e}\geq 1$, it follows that the variance grows as $\Omega(n)$. Thus, without any additional assumption on the growth of the neighbourhoods, we have asymptotic normality as $n\to\infty$ if the numerator is of order $o(n^{1.5})$. For example, this would be satisfied if the largest neighbourhood $\max_{i}\sum_{e\in\delta(i)}\overline{\mu}_{e}$ grows as $o(n^{1/6})$. Note that in the low temperature setting (when $\lambda\to 0$), the coordinates of $\bm{\mu}$ will be very close to either zero or one. As observed by Friedman and Rafsky [2], in this case the degrees of the nodes of both the $k$-NN and MST graphs are bounded by a constant independent of $n$ as $n\to\infty$ [27], so that $S_{2}$, $S_{3}$ and $L_{4}$ grow only linearly with $n$. We also observe experimentally in Section 5 that the distribution gets closer to normality as $\lambda$ decreases.

4 The Differentiable kk-NN and Friedman-Rafsky Tests

In this section, we discuss these two tests in more detail and show how to efficiently compute their statistics. Remember that to compute and optimize both $T_{\pi^{*}}^{\lambda}$ and $t_{\pi^{*}}^{\lambda}$ we have to be able to perform inference in the model $P(U)=\exp(-\sum_{e\in U}d(e)/\lambda-A(-\mathbf{d}/\lambda))\nu(U)$ by computing the first- and second-order moments of the edge indicator variables. We stress that, in the learning setting that we consider, $n$ refers to the number of data points in a mini-batch.

kk-NN.

The constraint $\nu(\cdot)$ in this case requires the number of edges in $U$ incoming at each node to be exactly $k$. First, note that the problem completely separates per node, i.e., the marginals of edges with different target vertices are independent. Formally, if we denote by $U_{i}$ the set of edges incoming at vertex $i$, then $U_{i}$ and $U_{j}$ are independent for $i\neq j$. Hence, for each node $i$ separately, we have to perform inference in

P(U_{i})\propto\exp\Big(-\sum_{j\in U_{i}}d(\mathbf{x}_{i},\mathbf{x}_{j})/\lambda\Big)\llbracket|U_{i}|=k\rrbracket,

which is a special case of the cardinality potentials considered by Tarlow et al. [28], Swersky et al. [29]. Swersky et al. [29] consider the same model, and note that we can compute all marginals in time O(nk)O(nk) using the algorithm in [28], which works by re-writing the model as a chain CRF and running the classical forward-backward algorithm. Hence, the total time complexity to compute the vector 𝝁(𝐝/λ)\bm{\mu}(\mathbf{d}/\lambda) is O(n2k)O(n^{2}k). Moreover, as marginalization requires only simple operations, we can compute the derivatives with any automatic differentiation software, and we thus do not provide formulas for the second-order moments. In [29] the authors provide an approximation for the Jacobian, which we did not use in our experiments, but instead we differentiate through the messages of the forward-backward algorithm.
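As an illustration, the per-node marginals can also be computed exactly with the following recursion over elementary symmetric polynomials — a minimal NumPy sketch that is mathematically equivalent to, but not the same implementation as, the chain-CRF forward-backward of [28, 29]; it is numerically naive (no log-domain computations) and the function name is ours:

import numpy as np

def exact_k_subset_marginals(neg_energies, k):
    """Marginals of P(S) proportional to exp(sum_{j in S} neg_energies[j]) * [|S| = k].

    neg_energies: length-m array (here, -d(x_i, x_l)/lambda for l != i, with m >= k).
    Runs in O(mk): mu_j = w_j * e_{k-1}(w without j) / e_k(w), where e_c is the
    c-th elementary symmetric polynomial of the weights w = exp(neg_energies).
    """
    w = np.exp(np.asarray(neg_energies, dtype=float))
    m = len(w)
    # forward[t, c] = e_c(w_1, ..., w_t), backward[t, c] = e_c(w_{t+1}, ..., w_m)
    forward = np.zeros((m + 1, k + 1))
    backward = np.zeros((m + 1, k + 1))
    forward[:, 0] = 1.0
    backward[:, 0] = 1.0
    for t in range(1, m + 1):
        for c in range(1, k + 1):
            forward[t, c] = forward[t - 1, c] + w[t - 1] * forward[t - 1, c - 1]
    for t in range(m - 1, -1, -1):
        for c in range(1, k + 1):
            backward[t, c] = backward[t + 1, c] + w[t] * backward[t + 1, c - 1]
    partition = forward[m, k]                                  # e_k of all weights
    mu = np.empty(m)
    for j in range(m):
        e_rest = sum(forward[j, c] * backward[j + 1, k - 1 - c] for c in range(k))
        mu[j] = w[j] * e_rest / partition
    return mu

For node $i$ one calls it with neg_energies set to $-d(\mathbf{x}_{i},\mathbf{x}_{l})/\lambda$ for all $l\neq i$; the returned marginals are non-negative and sum to $k$.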

As a concrete example, let us work out the simplest case — the kk-NN test with k=1k=1. In this case, the smoothed statistic reduces to

Tπλ(𝐱1,,𝐱n)=i=1nj=1π(i)π(j)nsi(𝐱1,,𝐱n)j,T^{\lambda}_{\pi^{*}}(\mathbf{x}_{1},\ldots,\mathbf{x}_{n})=\sum_{i=1}^{n}\sum_{\begin{subarray}{c}j=1\\ \pi^{*}(i)\neq\pi^{*}(j)\end{subarray}}^{n}s_{i}(\mathbf{x}_{1},\ldots,\mathbf{x}_{n})_{j},

where $s_{i}(\mathbf{x}_{1},\ldots,\mathbf{x}_{n})=\texttt{softmax}\big((-\|\mathbf{x}_{i}-\mathbf{x}_{l}\|/\lambda)_{l\neq i}\big)$. In other words, for each $i$ one computes the softmax of the negative scaled distances to all other points using $s_{i}$, and then sums up only those positions that correspond to points from the other sample. One interpretation of the loss is the following — maximize the number of incorrect predictions if we are to estimate the label $\pi^{*}(i)$ from $\mathbf{x}_{i}$ using a soft $1$-nearest neighbour approach.
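A minimal PyTorch sketch of this $k=1$ statistic (the function name and the use of torch.cdist are ours; it assumes the Euclidean distance and a pooled sample with a 0/1 labelling):

import torch

def smoothed_1nn_statistic(x, labels, lam):
    """Differentiable 1-NN statistic T^lambda_{pi*} for the pooled sample.

    x:      (n, d) tensor of pooled samples (the generated half carries gradients),
    labels: (n,) tensor with entries 0 or 1 marking the two samples,
    lam:    smoothing temperature lambda.
    """
    n = x.shape[0]
    dist = torch.cdist(x, x)                                   # pairwise Euclidean distances
    eye = torch.eye(n, dtype=torch.bool, device=x.device)
    dist = dist.masked_fill(eye, float('inf'))                 # exclude self-edges
    s = torch.softmax(-dist / lam, dim=1)                      # s[i, j]: soft neighbour weights
    cross = (labels.unsqueeze(0) != labels.unsqueeze(1)).float()
    return (s * cross).sum()                                   # mass on edges joining the samples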

Furthermore, we can also make a clear connection between the smooth 11-NN test and neighbourhood component analysis (NCA) [30]. Namely, we can see NCA as learning a mapping h:𝐱A𝐱h\colon\mathbf{x}\to A\mathbf{x} so that the test distinguishes (by minimizing TπλT^{\lambda}_{\pi^{*}}) the two samples as best as possible after applying hh on them. The extension of NCA to kk-NN [31] can be also seen as minimizing the test statistic for a particular instance of their loss function.

Friedman-Rafsky.

The model that we have to perform inference in for this test seems extremely complicated and intractable at first, because the constraint has the form $\nu(U)=\llbracket U\textrm{ forms a spanning tree}\rrbracket$. First, note that if $\mathbf{d}/\lambda$ had all entries equal to a constant $\gamma$, we would have $A(-\mathbf{d}/\lambda)=(1-n)\gamma+\log c_{\mathcal{G}(X)}$, where $c_{\mathcal{G}(X)}$ is the number of spanning trees in the graph $\mathcal{G}(X)$, which can be computed using Kirchhoff's (also known as the matrix-tree) theorem. To treat the weighted case, we use the approach of Lyons [32], who showed that the above model is a determinantal point process (DPP), so that marginalization can be done exactly as follows. First, create the incidence matrix $A\in\{-1,0,+1\}^{(n-1)\times|E|}$ of the graph $\mathcal{G}(X)$ after removing an arbitrary vertex, and construct its Laplacian $L=A\,\texttt{diag}\big[\exp(-\mathbf{d}/\lambda)\big]A^{T}$. Then, if we compute $H=L^{-1/2}A\,\texttt{diag}\big[\exp(-\mathbf{d}/(2\lambda))\big]$, the distribution $P(U)$ is a DPP with kernel matrix $K=H^{T}H$, implying that for every $W\subseteq E$

𝔼P(U𝐝/λ)[WU]=detKW,\mathbb{E}_{P(U\mid\mathbf{d}/\lambda)}[\llbracket W\subseteq U\rrbracket]=\det K_{W},

where KWK_{W} is the |W|×|W||W|\times|W| submatrix of KK formed by the rows and columns indexed by WW. Thus, we can easily compute all marginals and the smoothed test statistic and its derivatives using (3) as

\mu_{i\to j}=e^{-d(\mathbf{x}_{i},\mathbf{x}_{j})/\lambda}\,(\mathbf{u}_{i}-\mathbf{u}_{j})^{T}L^{-1}(\mathbf{u}_{i}-\mathbf{u}_{j}),\textrm{ and}
\frac{\partial\mu_{i\to j}}{\partial(d(\mathbf{x}_{k},\mathbf{x}_{l})/\lambda)}=e^{-\frac{d(\mathbf{x}_{i},\mathbf{x}_{j})+d(\mathbf{x}_{k},\mathbf{x}_{l})}{\lambda}}\big((\mathbf{u}_{i}-\mathbf{u}_{j})^{T}L^{-1}(\mathbf{u}_{k}-\mathbf{u}_{l})\big)^{2},

where $\mathbf{u}_{i}$ is the vector with all coordinates equal to zero, except the $i$-th coordinate, which is one, and the second identity holds for distinct edges $(i\to j)\neq(k\to l)$. Note that if we first compute the inverse $L^{-1}$, all quantities of the form $L^{-1}(\mathbf{u}_{i}-\mathbf{u}_{j})$ can be computed in time $O(n)$ as the vectors $\mathbf{u}_{i}$ have a single non-zero entry, for a total complexity of $O(n^{3})$.
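A dense NumPy sketch of this computation (a hypothetical helper, not the authors' implementation; it grounds the first vertex and returns the symmetric matrix of marginals, so each unordered pair appears twice):

import numpy as np

def fr_edge_marginals(x, lam):
    """Edge marginals of the spanning-tree Gibbs measure used by the smoothed FR test.

    Implements mu_{i->j} = exp(-d_ij/lam) (u_i - u_j)^T L^{-1} (u_i - u_j), with the
    Laplacian L built on the weights exp(-d/lam) and one vertex removed; O(n^3).
    """
    n = x.shape[0]
    d = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)
    w = np.exp(-d / lam)
    np.fill_diagonal(w, 0.0)
    L = np.diag(w.sum(axis=1)) - w                      # full graph Laplacian on exp(-d/lam)
    Linv = np.linalg.inv(L[1:, 1:])                     # ground (remove) vertex 0
    G = np.zeros((n, n))                                # pad back: vertex 0 gets zero rows/cols
    G[1:, 1:] = Linv
    diag = np.diag(G)
    resistance = diag[:, None] + diag[None, :] - 2 * G  # (u_i - u_j)^T L^{-1} (u_i - u_j)
    mu = w * resistance
    np.fill_diagonal(mu, 0.0)
    return mu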

To speed up this computation we can leverage the existing theory on fast solvers for Laplacian systems. Let us first create from $\mathcal{G}(X)$ the graph $e^{\mathcal{G}}(X)$ that has the same structure as $\mathcal{G}(X)$, but with edge weights $e^{-d(e)/\lambda}$ instead of $d(e)$. Hence, in this graph, a large weight between $\mathbf{x}$ and $\mathbf{x}^{\prime}$ indicates that these two points are similar to one another. In $e^{\mathcal{G}}(X)$, the marginals $\bm{\mu}_{e}$ are, up to the edge weights, the effective resistances (for additional properties of the effective resistances see [33]). Spielman and Srivastava [34] provide a method to compute all marginals at once in time that is $\tilde{O}(rn^{2}/\varepsilon^{2})$, where $\varepsilon$ is the desired relative precision and $r=\frac{1}{\lambda}(\max_{e}d(e)-\min_{e}d(e))$. The idea is to first solve for $Z^{T}=L^{-1}A\,\texttt{diag}\big[\exp(-\mathbf{d}/2\lambda)\big]R$, where $R\in\{-1/\sqrt{p},+1/\sqrt{p}\}^{|E|\times p}$ is a random projection matrix with elements chosen uniformly from $\{-1/\sqrt{p},+1/\sqrt{p}\}$ and $p=O(\log n/\varepsilon^{2})$. The suggested approximation is then $\mu_{i\to j}\approx e^{-d(\mathbf{x}_{i},\mathbf{x}_{j})/\lambda}\|Z(\mathbf{u}_{i}-\mathbf{u}_{j})\|^{2}$. While computing $Z$ naïvely would take $O(n^{3}+n^{2}p)$, one achieves the claimed bound with the Laplacian solver of Spielman and Teng [35].

As an extra benefit, the above connection provides an alternative interpretation of the smoothed FR test. Namely, assume that we want to create a spectral sparsifier [36] of $e^{\mathcal{G}}(X)$, which should contain significantly fewer edges, but be a good summary of the graph by having a similar spectrum. Spielman and Srivastava [34] provide a strategy to create such a sparsifier by sampling edges randomly, where edge $e$ is sampled proportionally to $\mu_{e}$. Hence, by optimizing $T^{\lambda}_{\pi^{*}}$ we are encouraging the constructed sparsifier of $e^{\mathcal{G}}(X)$ to have in expectation as many edges as possible connecting points from $X_{1}$ with points from $X_{2}$.

5 Experiments

We implemented our methods in Python using the PyTorch library. For the kk-NN test, we have adapted the code accompanying [29]. Throughout this section we used a 10 dimensional normal as Q0Q_{0}, drew samples of equal size n1=n2n_{1}=n_{2}, and used the 2\ell_{2} norm d(𝐱,𝐱)=𝐱𝐱2d(\mathbf{x},\mathbf{x}^{\prime})=\|\mathbf{x}-\mathbf{x}^{\prime}\|_{2} as a weighting function. We provide additional details in Appendix B.

Power as a function of λ\lambda and dd.

In our first experiment we analyze the effect of the smoothing strength on the power of our differentiable tests. In addition to the classical FR and $k$-NN tests, we have considered the unbiased MMD test [9] with the squared exponential kernel (as implemented in Shogun [37] using the code from [12]), and the energy test [13]. The problem that we consider, which is challenging in high dimensions, is that of differentiating the distribution $\mathcal{N}(\mathbf{0},I)$ from $\mathcal{N}((\mu,0,\ldots,0),\texttt{diag}(\sigma^{2},1,\ldots,1))$. This setting was considered to be fair in [38], as the KL divergence between the distributions is constant irrespective of the dimension. To set the smoothing strength and the bandwidth of the MMD kernel (in addition to the median heuristic) we used the same strategy as in [38] by setting $\lambda=d^{\gamma}$ for varying $\gamma\in[0,1]$. The results are presented in Figure 2, where we can observe that (i) our tests perform similarly to MMD for shift alternatives, while performing significantly better for scale alternatives, and (ii) by varying the smoothing parameter we can significantly increase the power of the test. In the third column we present only the best performing MMD, while we present the remaining results in Appendix B. Note that we expect the power to go to zero as the dimension increases [7, 38].

Learning.

As we have already hinted in the introduction, we stochastically optimize

\textrm{max.}_{\bm{\theta}}\,\mathbb{E}_{\mathbf{x}_{i}\sim P,\mathbf{z}_{i}\sim Q_{0}}\big[t_{\pi^{*}}^{\lambda}(\{\mathbf{x}_{i}\}_{i=1}^{n_{1}},\{f_{\bm{\theta}}(\mathbf{z}_{i})\}_{i=n_{1}+1}^{n_{1}+n_{2}})\big]

using the Adam [39] optimizer. To optimize, we draw at each round $n_{1}$ samples from the true distribution $P$ and $n_{2}=n_{1}$ samples from the base measure $Q_{0}$, and then plug them into the smoothed $t$-statistic.
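A minimal PyTorch sketch of this training loop (it assumes a hypothetical smoothed_t_statistic helper implementing $t_{\pi^{*}}^{\lambda}$, e.g. assembled from the marginal and null-moment computations of Sections 3 and 4, and a sample_p callable that returns real data):

import torch

def train_implicit_model(f_theta, sample_p, noise_dim, n1=256, lam=1.0,
                         steps=10_000, lr=1e-4):
    opt = torch.optim.Adam(f_theta.parameters(), lr=lr)
    for _ in range(steps):
        x_real = sample_p(n1)                              # n1 samples from P
        z = torch.randn(n1, noise_dim)                     # n2 = n1 samples from Q0
        x_fake = f_theta(z)
        x = torch.cat([x_real, x_fake], dim=0)             # pooled sample
        labels = torch.cat([torch.zeros(n1), torch.ones(n1)])
        t = smoothed_t_statistic(x, labels, lam)           # hypothetical: differentiable t^lambda
        loss = -t                                          # maximize the smoothed t-statistic
        opt.zero_grad()
        loss.backward()
        opt.step()
    return f_theta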

The first experiment we perform, with the goal of understanding the effects of λ\lambda, is on the toy two moons dataset from scikit-learn [40]. We show the results in Figure 3. From the second row, showing the estimated pp-value versus the correct one (from 1000 random permutations) at several points during training, we can indeed see that the permutation null gets closer to normality as λ\lambda decreases. Most importantly, note that the relationship is monotone, so that we would expect the optimization to not be significantly harmed if we use the approximation. Qualitatively, we can observe that the solutions have the general structure of PP, and that they improve as we decrease λ\lambda — the symmetry is better captured and the two moons get better separated.

MNIST.

Finally, we have trained several models on the MNIST [41] dataset, which we present in Figure 4. We can observe that despite the high (784) dimensional data and the fact that we use the distance directly on the pixels, the learned models generate digits that look mostly realistic and are competitive with those obtained using MMD [10, 11].

(a) Power against the alternative $(\mu=0.5,\sigma=1)$ from $n_{1}=n_{2}=128$ samples.
(b) Power against the alternative $(\mu=0,\sigma=3)$ from $n_{1}=n_{2}=128$ samples.
(c) Power against the alternative $(\mu=0,\sigma=3)$ from $n_{1}=n_{2}=256$ samples.
Figure 2: Test power when comparing two normal distributions. In the first two columns we present the 33-NN and FR tests as we vary λ\lambda — we use fr-γ\gamma for λ=dγ\lambda=d^{\gamma}, and fr-ct for the classical test (analogously for 33-NN). The legends presented in the first row are consistent across the respective columns. The last column compares the best performing of these tests with the best performing MMD tests (the remaining MMD plots are provided in Appendix B). Note that our smoothed tests have the largest power, and they significantly outperform their classical counterparts.
(a) Original data.
(b) $1$-NN with $\lambda=10$.
(c) $1$-NN with $\lambda=1$.
(d) $1$-NN with $\lambda=0.05$.
Figure 3: The effect of varying λ\lambda on the learned model and the normality of the null statistic. Note that with decreasing λ\lambda we get closer to normality, and the learned distribution better models the true one.
(a) $1$-NN with $\lambda=10$ and $n_{1}=256$.
(b) $1$-NN with $\lambda=10$ and $n_{1}=512$.
(c) FR with $\lambda=10$ and $n_{1}=128$.
(d) FR with $\lambda=5$ and $n_{1}=128$.
Figure 4: Four different models trained on MNIST.

6 Conclusion

We have developed smoothed two-sample graph tests that can be used for learning implicit models. These tests moreover outperform their classical equivalents on the problem of two sample testing. We have shown how to compute them by performing inference in undirected models, and presented alternative interpretations by drawing connections to neighbourhood component analysis and spectral graph sparsifiers. In the last section we have experimentally showcased the benefits of our approach, and presented results from a learned model.

Acknowledgements.

The research was partially supported by ERC StG 307036 and a Google European PhD Fellowship.

References

  • Wald and Wolfowitz [1940] Abraham Wald and Jacob Wolfowitz. On a test whether two samples are from the same population. Annals of Mathematical Statistics, 11(2):147–162, 1940.
  • Friedman and Rafsky [1979] Jerome H Friedman and Lawrence C Rafsky. Multivariate generalizations of the wald-wolfowitz and smirnov two-sample tests. Annals of Statistics, pages 697–717, 1979.
  • Friedman and Rafsky [1983] Jerome H Friedman and Lawrence C Rafsky. Graph-theoretic measures of multivariate association and prediction. Annals of Statistics, pages 377–391, 1983.
  • Henze and Penrose [1999] Norbert Henze and Mathew D Penrose. On the multivariate runs test. Annals of Statistics, pages 290–298, 1999.
  • Henze [1988] Norbert Henze. A multivariate two-sample test based on the number of nearest neighbor type coincidences. Annals of Statistics, pages 772–783, 1988.
  • Schilling [1986] Mark F Schilling. Multivariate two-sample tests based on nearest neighbors. Journal of the American Statistical Association, 81(395):799–806, 1986.
  • Bhattacharya [2015] Bhaswar B Bhattacharya. Power of graph-based two-sample tests. arXiv preprint arXiv:1508.07530, 2015.
  • Chen and Zhang [2013] Hao Chen and Nancy R Zhang. Graph-based tests for two-sample comparisons of categorical data. Statistica Sinica, pages 1479–1503, 2013.
  • Gretton et al. [2012] Arthur Gretton, Karsten M Borgwardt, Malte J Rasch, Bernhard Schölkopf, and Alexander Smola. A kernel two-sample test. Journal of Machine Learning Research, 13(Mar):723–773, 2012.
  • Li et al. [2015] Yujia Li, Kevin Swersky, and Rich Zemel. Generative moment matching networks. In International Conference on Machine Learning (ICML), 2015.
  • Dziugaite et al. [2015] Gintare Karolina Dziugaite, Daniel M. Roy, and Zoubin Ghahramani. Training generative neural networks via maximum mean discrepancy optimization. In Uncertainty in Artificial Intelligence (UAI), 2015.
  • Sutherland et al. [2016] Dougal J Sutherland, Hsiao-Yu Tung, Heiko Strathmann, Soumyajit De, Aaditya Ramdas, Alex Smola, and Arthur Gretton. Generative models and model criticism via optimized maximum mean discrepancy. In International Conference on Learning Representations (ICLR), 2016.
  • Székely and Rizzo [2013] Gábor J Székely and Maria L Rizzo. Energy statistics: A class of statistics based on distances. Journal of Statistical Planning and Inference, 143(8):1249–1272, 2013.
  • Bellemare et al. [2017] Marc G Bellemare, Ivo Danihelka, Will Dabney, Shakir Mohamed, Balaji Lakshminarayanan, Stephan Hoyer, and Rémi Munos. The cramer distance as a solution to biased wasserstein gradients. arXiv preprint arXiv:1705.10743, 2017.
  • Sugiyama et al. [2012] Masashi Sugiyama, Taiji Suzuki, and Takafumi Kanamori. Density ratio estimation in machine learning. Cambridge University Press, 2012.
  • Goodfellow et al. [2014] Ian Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio. Generative adversarial nets. In Advances in Neural Information Processing Systems (NIPS), pages 2672–2680, 2014.
  • Salimans et al. [2016] Tim Salimans, Ian Goodfellow, Wojciech Zaremba, Vicki Cheung, Alec Radford, and Xi Chen. Improved techniques for training gans. In Advances in Neural Information Processing Systems (NIPS), pages 2234–2242, 2016.
  • Nowozin et al. [2016] Sebastian Nowozin, Botond Cseke, and Ryota Tomioka. ff-GAN: Training generative neural samplers using variational divergence minimization. In Advances in Neural Information Processing Systems (NIPS), pages 271–279, 2016.
  • Ali and Silvey [1966] Syed Mumtaz Ali and Samuel D Silvey. A general class of coefficients of divergence of one distribution from another. Journal of the Royal Statistical Society. Series B (Methodological), pages 131–142, 1966.
  • Mohamed and Lakshminarayanan [2016] Shakir Mohamed and Balaji Lakshminarayanan. Learning in implicit generative models. arXiv preprint arXiv:1610.03483, 2016.
  • Prim [1957] Robert Clay Prim. Shortest connection networks and some generalizations. Bell Labs Technical Journal, 36(6):1389–1401, 1957.
  • Kruskal [1956] Joseph B Kruskal. On the shortest spanning subtree of a graph and the traveling salesman problem. Proceedings of the American Mathematical society, 7(1):48–50, 1956.
  • Berisha and Hero [2015] Visar Berisha and Alfred O Hero. Empirical non-parametric estimation of the fisher information. IEEE Signal Processing Letters, 22(7):988–992, 2015.
  • Wainwright and Jordan [2008] Martin J Wainwright and Michael I Jordan. Graphical models, exponential families, and variational inference. Foundations and Trends® in Machine Learning, 1(1-2), 2008.
  • Daniels [1944] Henry E Daniels. The relation between measures of correlation in the universe of sample permutations. Biometrika, 33(2):129–135, 1944.
  • Barbour and Eagleson [1986] AD Barbour and GK Eagleson. Random association of symmetric arrays. Stochastic Analysis and Applications, 4(3):239–281, 1986.
  • Yukich [2006] Joseph E Yukich. Probability theory of classical Euclidean optimization problems. Springer, 2006.
  • Tarlow et al. [2012] Daniel Tarlow, Kevin Swersky, Richard S Zemel, Ryan Prescott Adams, and Brendan J Frey. Fast exact inference for recursive cardinality models. Uncertainty in Artificial Intelligence (UAI), 2012.
  • Swersky et al. [2012] Kevin Swersky, Ilya Sutskever, Daniel Tarlow, Richard S Zemel, Ruslan R Salakhutdinov, and Ryan P Adams. Cardinality restricted boltzmann machines. In Advances in Neural Information Processing Systems (NIPS), pages 3293–3301, 2012.
  • Goldberger et al. [2005] Jacob Goldberger, Geoffrey E Hinton, Sam T Roweis, and Ruslan R Salakhutdinov. Neighbourhood components analysis. In Advances in Neural Information Processing Systems (NIPS), pages 513–520, 2005.
  • Tarlow et al. [2013] Daniel Tarlow, Kevin Swersky, Laurent Charlin, Ilya Sutskever, and Rich Zemel. Stochastic k-neighborhood selection for supervised and unsupervised learning. In International Conference on Machine Learning, pages 199–207, 2013.
  • Lyons [2003] Russell Lyons. Determinantal probability measures. Publications mathématiques de l’IHÉS, 98(1):167–212, 2003.
  • Chandra et al. [1996] Ashok K Chandra, Prabhakar Raghavan, Walter L Ruzzo, Roman Smolensky, and Prasoon Tiwari. The electrical resistance of a graph captures its commute and cover times. Computational Complexity, 6(4):312–340, 1996.
  • Spielman and Srivastava [2011] Daniel A Spielman and Nikhil Srivastava. Graph sparsification by effective resistances. SIAM Journal on Computing, 40(6):1913–1926, 2011.
  • Spielman and Teng [2014] Daniel A Spielman and Shang-Hua Teng. Nearly linear time algorithms for preconditioning and solving symmetric, diagonally dominant linear systems. SIAM Journal on Matrix Analysis and Applications, 35(3):835–885, 2014.
  • Spielman and Teng [2011] Daniel A Spielman and Shang-Hua Teng. Spectral sparsification of graphs. SIAM Journal on Computing, 40(4):981–1025, 2011.
  • Sonnenburg et al. [2010] Sören Sonnenburg, Sebastian Henschel, Christian Widmer, Jonas Behr, Alexander Zien, Fabio de Bona, Alexander Binder, Christian Gehl, Vojtěch Franc, et al. The shogun machine learning toolbox. Journal of Machine Learning Research, 11(Jun):1799–1802, 2010.
  • Ramdas et al. [2015] Aaditya Ramdas, Sashank Jakkam Reddi, Barnabás Póczos, Aarti Singh, and Larry A Wasserman. On the decreasing power of kernel and distance based nonparametric hypothesis tests in high dimensions. In AAAI, 2015.
  • Kingma and Ba [2015] Diederik Kingma and Jimmy Ba. Adam: A method for stochastic optimization. In International Conference on Learning Representations (ICLR), 2015.
  • Pedregosa et al. [2011] F. Pedregosa, G. Varoquaux, A. Gramfort, V. Michel, B. Thirion, O. Grisel, M. Blondel, P. Prettenhofer, R. Weiss, V. Dubourg, J. Vanderplas, A. Passos, D. Cournapeau, M. Brucher, M. Perrot, and E. Duchesnay. Scikit-learn: Machine learning in Python. Journal of Machine Learning Research, 12:2825–2830, 2011.
  • LeCun et al. [1998] Yann LeCun, Léon Bottou, Yoshua Bengio, and Patrick Haffner. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.

Appendix A Proofs

Proof of Theorem 3.

The expectation of the statistic under H0H_{0} is (when π\pi is a uniformly random labelling)

eEμ(𝐝/λ)e𝔼π[Δπ(e)]2n1n2/n(n1)=2mn1n2/n(n1),\sum_{e\in E}\mu(\mathbf{d}/\lambda)_{e}\underbrace{{\mathbb{E}}_{\pi}[\Delta_{\pi}(e)]}_{2n_{1}n_{2}/n(n-1)}=2mn_{1}n_{2}/n(n-1),

where the inner expectation 𝔼π[Δπ(e)]{\mathbb{E}}_{\pi}[\Delta_{\pi}(e)] has been computed in [2]. We can also easily compute the variance as

e,eECovπH0[μeΔπ(e),μeΔπ(e)]\displaystyle\sum_{e,e^{\prime}\in E}{\mathrm{Cov}}_{\pi\sim H_{0}}[\mu_{e}\Delta_{\pi}(e),\mu_{e^{\prime}}\Delta_{\pi}(e^{\prime})] =e,eEμeμe𝔼πH0[Δπ(e)Δπ(e)]Πe,e4n12n22n2(n1)2m2(𝔼πH0[Tπλ])2.\displaystyle=\sum_{e,e^{\prime}\in E}\mu_{e}\mu_{e^{\prime}}\underbrace{{\mathbb{E}}_{\pi\sim H_{0}}[\Delta_{\pi}(e)\Delta_{\pi}(e^{\prime})]}_{\Pi_{e,e^{\prime}}}-\underbrace{\frac{4n_{1}^{2}n_{2}^{2}}{n^{2}(n-1)^{2}}m^{2}}_{({\mathbb{E}}_{\pi\sim H_{0}}[T^{\lambda}_{\pi^{*}}])^{2}}. (5)

Proof of Lemma 2.

We can split the sum in the variance formula over all edge pairs into three groups as follows

eeen1n2n(n1)χ1μeμe+eee4n1n2(n11)(n21)n(n1)(n2)(n3)χ1χ2μeμe+en1n2n(n1)χ1(μe2+μeμe¯),\sum_{e}\sum_{e^{\prime}\sim e}\underbrace{\frac{n_{1}n_{2}}{n(n-1)}}_{\chi_{1}}\mu_{e}\mu_{e^{\prime}}+\sum_{e}\sum_{e^{\prime}\perp e}\underbrace{\frac{4n_{1}n_{2}(n_{1}-1)(n_{2}-1)}{n(n-1)(n-2)(n-3)}}_{\chi_{1}\chi_{2}}\mu_{e}\mu_{e^{\prime}}+\sum_{e}\underbrace{\frac{n_{1}n_{2}}{n(n-1)}}_{\chi_{1}}(\mu_{e}^{2}+\mu_{e}\mu_{\overline{e}}), (6)

where $\sum_{e^{\prime}\sim e}$ sums over all edges $e^{\prime}$ that share at least one vertex with $e$, $\sum_{e^{\prime}\perp e}$ sums over those edges that share no vertex with $e$, and $\overline{e}$ denotes the reverse edge of $e$ (if it exists; the corresponding term is zero otherwise). Note that each term $\mu_{e}\mu_{e^{\prime}}$ appears twice if $e\neq e^{\prime}$, as in the formula for the variance (5). Moreover, note that if $\delta(e)=\delta(e^{\prime})$, then in the above formula the term $\mu_{e}\mu_{e^{\prime}}$ (same for $\mu_{e^{\prime}}\mu_{e}$) gets multiplied by $2\chi_{1}=\Pi_{e,e^{\prime}}$, as it appears in both the first and the third term. Given the assumption that $|U|=m$ under $\nu(\cdot)$, we also know that

m2=(eμe)2=eeμeμe=eeeμeμe+eeeμeμe,m^{2}=(\sum_{e}\mu_{e})^{2}=\sum_{e}\sum_{e^{\prime}}\mu_{e}\mu_{e^{\prime}}=\sum_{e}\sum_{e^{\prime}\sim e}\mu_{e}\mu_{e^{\prime}}+\sum_{e}\sum_{e^{\prime}\perp e}\mu_{e}\mu_{e^{\prime}},

so that eq. (6) can be simplified to

χ1eeeμeμe+χ1χ2(m2eeeμeμe)+χ1e(μe2+μeμe¯),\chi_{1}\sum_{e}\sum_{e^{\prime}\sim e}\mu_{e}\mu_{e^{\prime}}+\chi_{1}\chi_{2}(m^{2}-\sum_{e}\sum_{e^{\prime}\sim e}\mu_{e}\mu_{e^{\prime}})+\chi_{1}\sum_{e}(\mu_{e}^{2}+\mu_{e}\mu_{\overline{e}}),

which can be further simplified to

χ1(1χ2)eeeμeμe+χ1e(μe2+μeμe¯)+χ1χ2m2.\chi_{1}(1-\chi_{2})\sum_{e}\sum_{e^{\prime}\sim e}\mu_{e}\mu_{e^{\prime}}+\chi_{1}\sum_{e}(\mu_{e}^{2}+\mu_{e}\mu_{\overline{e}})+\chi_{1}\chi_{2}m^{2}.

Now the result follows by observing that

v(eδ(v)μe)2=eeeμeμe+eμe2+eμeμe¯.\sum_{v}(\sum_{e\in\delta(v)}\mu_{e})^{2}=\sum_{e}\sum_{e^{\prime}\sim e}\mu_{e}\mu_{e^{\prime}}+\sum_{e}\mu_{e}^{2}+\sum_{e}\mu_{e}\mu_{\overline{e}}.

To understand why this holds, let us count how many times each term μeμe\mu_{e}\mu_{e^{\prime}} appears on both sides of the equality if we expand the lhs. If eee\neq e^{\prime} and they share exactly one vertex, then the lhs will have two μeμe\mu_{e}\mu_{e^{\prime}} terms, as μe\mu_{e} and μe\mu_{e^{\prime}} will be multiplied only at the term corresponding to the shared vertex. On the other hand, if e=ee=e^{\prime} we will again have two μeμe=μe2\mu_{e}\mu_{e^{\prime}}=\mu_{e}^{2} terms, as we get one contribution from each end-point of ee. Finally, if e=e¯e^{\prime}=\overline{e}, we have a total of four μeμe\mu_{e}\mu_{e^{\prime}} terms, as we get two μeμe\mu_{e}\mu_{e^{\prime}} from each end-point. Thus, eq. (6) is equal to

χ1(1χ2)(v(eδ(v)μe)2eμe2eμeμe¯)+χ1e(μe2+μeμe¯)+χ1χ2m2.\chi_{1}(1-\chi_{2})\big{(}\sum_{v}(\sum_{e\in\delta(v)}\mu_{e})^{2}-\sum_{e}\mu_{e}^{2}-\sum_{e}\mu_{e}\mu_{\overline{e}}\big{)}+\chi_{1}\sum_{e}(\mu_{e}^{2}+\mu_{e}\mu_{\overline{e}})+\chi_{1}\chi_{2}m^{2}.

Finally, if we subtract 4χ12m24\chi_{1}^{2}m^{2} and simplify the expression we have

χ1(1χ2)v(eδ(v)μe)2+χ1χ2eμe2+χ1χ2eμeμe¯+χ1(χ24χ1)m2,\chi_{1}(1-\chi_{2})\sum_{v}(\sum_{e\in\delta(v)}\mu_{e})^{2}+\chi_{1}\chi_{2}\sum_{e}\mu_{e}^{2}+\chi_{1}\chi_{2}\sum_{e}\mu_{e}\mu_{\overline{e}}+\chi_{1}(\chi_{2}-4\chi_{1})m^{2},

which is exactly what is claimed in the theorem, if we observe that ee and e¯\overline{e} are the only edges parallel to ee. ∎

Proof that χ24χ10\chi_{2}-4\chi_{1}\geq 0 when n1=n2=n/2n_{1}=n_{2}=n/2.

First, note that n1nn11n2\frac{n_{1}}{n}\leq\frac{n_{1}-1}{n-2}, if and only if n1n2n1nn1nn_{1}n-2n_{1}\leq nn_{1}-n, which is equivalent to n112nn_{1}\geq\frac{1}{2}n. Similarly, we have n2n1n21n3\frac{n_{2}}{n-1}\leq\frac{n_{2}-1}{n-3} iff nn23n2nn2nn2+1nn_{2}-3n_{2}\leq nn_{2}-n-n_{2}+1, which can be re-written as 2n2n+1-2n_{2}\leq-n+1, i.e., n2n212n_{2}\geq\frac{n}{2}-\frac{1}{2}. Combining these two inequalities proves the result.

Proof of Theorem 4.

Let us compute an upper bound on the quantities in [26].

a1\displaystyle a_{1} =1n(n1)i,jμ¯i,j=kn\displaystyle=\frac{1}{n(n-1)}\sum_{i,j}\overline{\mu}_{i,j}=\frac{k}{n} b1\displaystyle b_{1} =2n(n1)n2n1=Θ(1)\displaystyle=\frac{2}{n(n-1)}n_{2}n_{1}=\Theta(1)
a2\displaystyle a_{2} =1n(n1)(n2)i,j,kμ¯i,jμ¯i,kS2\displaystyle=\frac{1}{n(n-1)(n-2)}\underbrace{\sum_{i,j,k}\overline{\mu}_{i,j}\overline{\mu}_{i,k}}_{S_{2}} b2\displaystyle b_{2} =n2n12+n1n22n(n1)(n2)=Θ(1)\displaystyle=\frac{n_{2}n_{1}^{2}+n_{1}n_{2}^{2}}{n(n-1)(n-2)}=\Theta(1)
a3\displaystyle a_{3} =1n(n1)(n2)(n3)i,j,k,mμ¯i,jμ¯i,kμ¯i,mS3\displaystyle=\frac{1}{n(n-1)(n-2)(n-3)}\underbrace{\sum_{i,j,k,m}\overline{\mu}_{i,j}\overline{\mu}_{i,k}\overline{\mu}_{i,m}}_{S_{3}} b3\displaystyle b_{3} =n2n13+n1n23n(n1)(n2)(n3)=Θ(1)\displaystyle=\frac{n_{2}n_{1}^{3}+n_{1}n_{2}^{3}}{n(n-1)(n-2)(n-3)}=\Theta(1)
a4\displaystyle a_{4} =1n(n1)(n2)(n3)i,j,k,mμ¯k,iμ¯i,jμ¯j,mL4\displaystyle=\frac{1}{n(n-1)(n-2)(n-3)}\underbrace{\sum_{i,j,k,m}\overline{\mu}_{k,i}\overline{\mu}_{i,j}\overline{\mu}_{j,m}}_{L_{4}} b4\displaystyle b_{4} =2n22n12n(n1)(n2)(n3)=Θ(1)\displaystyle=2\frac{n_{2}^{2}n_{1}^{2}}{n(n-1)(n-2)(n-3)}=\Theta(1)
a5\displaystyle a_{5} =1n(n1)(n2)i,j,kμ¯i,j2μ¯i,k=O(a2)\displaystyle=\frac{1}{n(n-1)(n-2)}\sum_{i,j,k}\overline{\mu}_{i,j}^{2}\overline{\mu}_{i,k}=O(a_{2}) b5\displaystyle b_{5} =b2\displaystyle=b_{2}
a6\displaystyle a_{6} =1n(n1)i,jμ¯i,j3=O(a1)\displaystyle=\frac{1}{n(n-1)}\sum_{i,j}\overline{\mu}_{i,j}^{3}=O(a_{1}) b6\displaystyle b_{6} =b1\displaystyle=b_{1}
a7\displaystyle a_{7} =1n(n1)(n2)i,j,k,mμ¯i,jμ¯i,kμ¯j,k\displaystyle=\frac{1}{n(n-1)(n-2)}\sum_{i,j,k,m}\overline{\mu}_{i,j}\overline{\mu}_{i,k}\overline{\mu}_{j,k} b7\displaystyle b_{7} =n2n1n2+n1n2n1n(n1)(n2)=Θ(1)\displaystyle=\frac{n_{2}n_{1}n_{2}+n_{1}n_{2}n_{1}}{n(n-1)(n-2)}=\Theta(1)
a8\displaystyle a_{8} =1n(n1)i,jμ¯i,j2=O(a1)\displaystyle=\frac{1}{n(n-1)}\sum_{i,j}\overline{\mu}_{i,j}^{2}=O(a_{1}) b8\displaystyle b_{8} =b1.\displaystyle=b_{1}.

Then, the upper bound has the form

1σ3[\displaystyle\frac{1}{\sigma^{3}}\big{[} n4(a13k3/n3+a1a2O(kS2/n4)+a3O(S3/n4)+a4O(L4/n4))(b13+b1b2+b3+b4)O(1)+\displaystyle n^{4}(\underbrace{a_{1}^{3}}_{k^{3}/n^{3}}+\underbrace{a_{1}a_{2}}_{O(kS_{2}/n^{4})}+\underbrace{a_{3}}_{O(S_{3}/n^{4})}+\underbrace{a_{4}}_{O(L_{4}/n^{4})})\underbrace{(b_{1}^{3}+b_{1}b_{2}+b_{3}+b_{4})}_{O(1)}+
n3(a5O(S2/n3)+a1a8O(k2/n2))(b5+b1b8)O(1)+n2a6O(k/n)b6O(1)],\displaystyle n^{3}(\underbrace{a_{5}}_{O(S_{2}/n^{3})}+\underbrace{a_{1}a_{8}}_{O(k^{2}/n^{2})})\underbrace{(b_{5}+b_{1}b_{8})}_{O(1)}+n^{2}\underbrace{a_{6}}_{O(k/n)}\underbrace{b_{6}}_{O(1)}\big{]},

which can be simplified to

O(1σ3[nk3+kS2+S3+L4+S2+nk2+k/n])=O(1σ3(nk3+kS2+S3+L4)),O(\frac{1}{\sigma^{3}}\big{[}nk^{3}+kS_{2}+S_{3}+L_{4}+S_{2}+nk^{2}+k/n\big{]})=O\big{(}\frac{1}{\sigma^{3}}(nk^{3}+kS_{2}+S_{3}+L_{4})\big{)},

which is what is claimed in the theorem.

Appendix B Experiments

B.1 MMD

(a) $\mu=0.5,\sigma=1,n_{1}=128$.
(b) $\mu=0,\sigma=3,n_{1}=128$.
(c) $\mu=0,\sigma=3,n_{1}=256$.
Figure 5: The different MMD tests on the three setups in Figure 2. The legend is consistent across the panels.

B.2 Architecture

We have used the same architecture as in [10, 12], which, using the modules from PyTorch, can be written as follows.

import torch.nn as nn

# noise_dim: dimension of the base measure Q0 (10 in our experiments);
# ambient_dim: dimension of the data (784 for MNIST, 2 for the two moons data).
noise_dim, ambient_dim = 10, 784

generator = nn.Sequential(
    nn.Linear(noise_dim, 64),
    nn.ReLU(),
    nn.Linear(64, 256),
    nn.ReLU(),
    nn.Linear(256, 256),
    nn.ReLU(),
    nn.Linear(256, 1024),
    nn.ReLU(),
    nn.Linear(1024, ambient_dim))

For MNIST we have also added a terminal nn.Tanh layer.
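A batch from the implicit model can then be drawn by pushing base-measure noise through this network, e.g. (a small sketch using the generator defined above):

import torch

z = torch.randn(256, noise_dim)       # a mini-batch from the base measure Q0
x_fake = generator(z)                 # 256 samples from the implicit model Q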

B.3 Data

We have used the MNIST data as packaged by torchvision, with the additional processing of scaling the output to $[-1,1]$ as we are using a final Tanh layer. For the two moons data, we have used a noise level of 0.05.

B.4 Optimization

All details are provided in the table below. In some cases we have optimized with a larger step for a number of epochs, and then reduced it for the remaining epochs — in the table below these are separated by commas.

Model         Step size               Batch size   Epochs
Figure 3(b)   $10^{-4}$               256          500
Figure 3(c)   $10^{-4}$               256          500
Figure 3(d)   $10^{-4}$               256          500
Figure 4(a)   $10^{-3}$, $10^{-4}$    256          500, 500
Figure 4(b)   $10^{-3}$, $10^{-4}$    512          500, 500
Figure 4(c)   $10^{-3}$, $10^{-4}$    128          100, 100
Figure 4(d)   $10^{-4}$, $10^{-4}$    128          100, 100