
Generalization Bounds for Data-Driven Numerical Linear Algebra

Peter Bartlett
UC Berkeley
peter@berkeley.edu
   Piotr Indyk
MIT
indyk@mit.edu
   Tal Wagner
Microsoft Research
tal.wagner@gmail.com
Abstract

Data-driven algorithms can adapt their internal structure or parameters to inputs from unknown application-specific distributions, by learning from a training sample of inputs. Several recent works have applied this approach to problems in numerical linear algebra, obtaining significant empirical gains in performance. However, no theoretical explanation for their success was known.

In this work we prove generalization bounds for those algorithms, within the PAC-learning framework for data-driven algorithm selection proposed by Gupta and Roughgarden (SICOMP 2017). Our main results are closely matching upper and lower bounds on the fat shattering dimension of the learning-based low rank approximation algorithm of Indyk et al. (NeurIPS 2019). Our techniques are general, and provide generalization bounds for many other recently proposed data-driven algorithms in numerical linear algebra, covering both sketching-based and multigrid-based methods. This considerably broadens the class of data-driven algorithms for which a PAC-learning analysis is available.

1 Introduction

Traditionally, algorithms are formally designed to handle a single unknown input. In reality, however, computational problems are often solved on multiple different yet related inputs over time. It is therefore natural to tailor the algorithm to the inputs it encounters, and leverage — or learn from — past inputs, in order to improve performance on future ones.

This paradigm addresses common scenarios in which we need to make design choices about the algorithm, such as setting parameters or selecting algorithmic components. It advocates making those design choices in an automated way, based on past inputs viewed as a training set, premised to share underlying similarities with future inputs. Rather than trying to model or gauge these similarities explicitly, the idea is to let a learning mechanism detect and exploit them implicitly. This is often referred to as self-improving, data-driven or learning-based algorithm design, and by now has been applied successfully to a host of computational problems in various domains [ACC+11, GR17, Bal20, MV20].

In particular, this approach has recently become popular in efficient algorithms for computational problems in linear algebra. Matrix computations are very widely used and are often hard to scale, leading to an enormous body of work on developing fast approximate algorithms for them. These algorithms often rely on internal auxiliary matrices, like the sketching matrix in sketch-based methods [Woo14] or the prolongation matrix in algebraic multigrid methods [BHM00], and the choice of auxiliary matrix is crucial for good performance. Classically, these auxiliary matrices are chosen either at random (from carefully designed distributions which are oblivious to the input) or handcrafted via elaborate heuristic methods. However, a recent surge of work on learning-based linear algebra shows they can be successfully learned in an automated way from past inputs [IVY19, ALN20, LLV+20, LGM+20, IWW21].

The empirical success of learning-based algorithms has naturally led to seeking theoretical frameworks for reasoning about them. [GR17] suggested modeling the problem in terms of statistical learning, and initiated a PAC-learning theory for algorithm selection. In their formulation, the inputs are drawn independently from an unknown distribution 𝒟\mathcal{D}, and the goal is to choose an algorithm from a given class \mathcal{L}, so as to approximately optimize the expected performance on 𝒟\mathcal{D}. One then proves upper bounds on the pseudo-dimension or the fat shattering dimension of \mathcal{L}, which are analogs of the VC-dimension for real-valued function classes. By classical VC theory, such bounds imply generalization, i.e., that an approximately optimal algorithm for 𝒟\mathcal{D} can be chosen from a bounded number of samples (proportional to either notion of dimension). Building on this framework, subsequent works have developed sets of tools for proving pseudo-dimension bounds for learning-based algorithms for various problems, focusing on combinatorial optimization and mechanism design [BNVW17, BDSV18, BSV18, BSV20, BDD+21, BPSV21].

However, existing techniques do not yield generalization bounds for the learning-based linear algebraic algorithms mentioned above, and so far, no generalization bounds were known for them.

Our contribution.

We develop a new approach for proving PAC-learning generalization bounds for learning-based algorithms, which is applicable to linear algebraic problems. For concreteness, we mostly focus on the learning-based low rank approximation algorithm of [IVY19], called IVY. Operating on input matrices of order n×nn\times n, IVY learns an auxiliary sketching matrix specified by nn real parameters. Our main result is, loosely speaking, that the fat shattering dimension of IVY is Θ~(n)\widetilde{\Theta}(n), which yields a similar bound on its sample complexity. (Throughout, we use O~(f)\widetilde{O}(f) for O(fpolylog(f))O(f\cdot\mathrm{polylog}(f)).) The precise bound we prove also depends on the fatness parameter ϵ\epsilon, the target low rank kk, and the row-dimension mm and sparsity ss of the learned sketching matrix; for now, it is instructive to think of all of them as constants. See Section 2.3 for the full result.

We proceed to show that the tools we develop also lead to pseudo-dimension or fat shattering dimension bounds for various other learning-based linear algebraic algorithms known in the literature, which in addition to IVY include the learned low rank approximation algorithms of [ALN20], [LLV+20] and [IWW21], and the learned regression algorithm of [LGM+20]. These algorithms represent both the sketch-based and the multigrid-based approaches in numerical linear algebra.

Our work significantly advances the line of research on the statistical foundations of learning-based algorithms, by extending the PAC-learning framework to a prominent and well-studied area of algorithms not previously covered by it, through developing a novel set of techniques.

2 Background and Main Results

2.1 Generalization Framework for Learning-Based Algorithms

The following is the PAC-learning framework for algorithm selection, due to [GR17]. Let 𝒳\mathcal{X} be a domain of inputs for a given computational problem, and let 𝒟\mathcal{D} be a distribution over 𝒳\mathcal{X}. Let \mathcal{L} be a class of algorithms that operate on inputs from 𝒳\mathcal{X}. We identify each algorithm LL\in\mathcal{L} with a function L:𝒳[0,1]L:\mathcal{X}\rightarrow[0,1] that maps the result of running LL on xx to a loss L(x)L(x), where we shift and normalize the loss to be in [0,1][0,1] for convenience. Let LL^{*}\in\mathcal{L} be an optimal algorithm in expectation on 𝒟\mathcal{D}, i.e., L=argminL𝔼x𝒟[L(x)]L^{*}=\mathrm{argmin}_{L\in\mathcal{L}}\;\mathbb{E}_{x\sim\mathcal{D}}[L(x)].

We wish to find an algorithm in \mathcal{L} that approximately matches the performance of LL^{*}. To this end, we draw \ell samples X~={x1,,x}\tilde{X}=\{x_{1},\ldots,x_{\ell}\} from 𝒟\mathcal{D}, and apply a learning procedure that maps X~\tilde{X} to an algorithm LL\in\mathcal{L}. We say that \mathcal{L} is (ϵ,δ)(\epsilon,\delta)-learnable with \ell samples if there is a learning procedure that maps X~\tilde{X} to LL\in\mathcal{L} such that

\Pr_{\tilde{X}}\left[\mathbb{E}_{x\sim\mathcal{D}}[L(x)]\leq\mathbb{E}_{x\sim\mathcal{D}}[L^{*}(x)]+\epsilon\right]\geq 1-\delta.

A canonical learning procedure is Empirical Risk Minimization (ERM), which maps X~\tilde{X} to an algorithm that minimizes the average loss over the samples, i.e., to argminL1i=1L(xi)\mathrm{argmin}_{L\in\mathcal{L}}\;\frac{1}{\ell}\sum_{i=1}^{\ell}L(x_{i}). We say that \mathcal{L} admits (ϵ,δ)(\epsilon,\delta)-uniform convergence with \ell samples if

\Pr_{\tilde{X}}\left[\forall L\in\mathcal{L}\;\;\left|\frac{1}{\ell}\sum_{i=1}^{\ell}L(x_{i})-\mathbb{E}_{x\sim\mathcal{D}}[L(x)]\right|\leq\epsilon\right]\geq 1-\delta.

It is straightforward that (ϵ,δ)(\epsilon,\delta)-uniform convergence implies (2ϵ,δ)(2\epsilon,\delta)-learnability with ERM with the same number of samples. To bound the number of samples, we have the following notions.
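To make the ERM procedure concrete, the following minimal sketch (ours, not part of the framework of [GR17]) selects from a finite candidate class by minimizing the empirical loss over the samples; the candidate class and loss used here are toy placeholders.

```python
import numpy as np

def erm_select(candidates, loss_fn, samples):
    """Empirical Risk Minimization: return the candidate whose average
    loss over the sample set is smallest."""
    avg_losses = [np.mean([loss_fn(c, x) for x in samples]) for c in candidates]
    return candidates[int(np.argmin(avg_losses))]

# Toy usage: each "algorithm" is a scalar parameter c, its loss on input x
# is (c - x)^2, and the inputs come from an unknown distribution.
rng = np.random.default_rng(0)
samples = list(rng.normal(loc=0.3, scale=0.1, size=20))
candidates = list(np.linspace(0.0, 1.0, 11))
best = erm_select(candidates, lambda c, x: (c - x) ** 2, samples)
print("ERM choice:", best)
```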

Definition \themydefinition (pseudo-dimension and fat shattering dimension).

Let X={x1,,xN}𝒳X=\{x_{1},\ldots,x_{N}\}\subset\mathcal{X}. We say that XX is pseudo-shattered by \mathcal{L} if there are thresholds r1,,rNr_{1},\ldots,r_{N}\in\mathbb{R} such that,

\forall I\subset\{1,\ldots,N\}\;\exists L\in\mathcal{L}\;\;\;\;\text{s.t.}\;\;\;\;L(x_{i})>r_{i}\Leftrightarrow i\in I.

For γ>0\gamma>0, we say that XX is γ\gamma-fat shattered by \mathcal{L} if there are thresholds r1,,rNr_{1},\ldots,r_{N}\in\mathbb{R} such that,

\forall I\subset\{1,\ldots,N\}\;\exists L\in\mathcal{L}\;\;\;\;\text{s.t.}\;\;\;\;i\in I\Rightarrow L(x_{i})>r_{i}+\gamma\;\;\;\;\text{and}\;\;\;\;i\notin I\Rightarrow L(x_{i})<r_{i}-\gamma.

The pseudo-dimension of \mathcal{L}, denoted pdim()\mathrm{pdim}(\mathcal{L}), is the maximum size of a pseudo-shattered set. The γ\gamma-fat shattering dimension of \mathcal{L}, denoted fatdimγ()\mathrm{fatdim}_{\gamma}(\mathcal{L}), is the maximum size of a γ\gamma-fat shattered set.

Note that the pseudo-dimension is an upper bound on the γ\gamma-fat shattering dimension for every γ>0\gamma>0. Classical results in PAC-learning theory show that these quantities govern the number of samples needed for (ϵ,δ)(\epsilon,\delta)-uniform convergence, and thus for (ϵ,δ)(\epsilon,\delta)-learning with ERM. (For example, a standard result is that O(ϵ2(pdim()+log(δ1)))O(\epsilon^{-2}\cdot(\mathrm{pdim}(\mathcal{L})+\log(\delta^{-1}))) samples suffice for (ϵ,δ)(\epsilon,\delta)-uniform convergence, and thus also for (ϵ,δ)(\epsilon,\delta)-learning with ERM. The pseudo-dimension yields somewhat tighter upper bounds for learning than the fat shattering dimension, while the latter also yields lower bounds. See, for example, Theorems 19.1, 19.2, 19.5 in [AB09].) Therefore, a typical goal is to prove upper and lower bounds on them.

2.2 Learning-Based Low Rank Approximation (LRA)

Let An×dA\in\mathbb{R}^{n\times d} be an input matrix, with ndn\geq d. By normalization, we assume throughout the paper that AF2=1\lVert A\rVert_{F}^{2}=1. Let [A]k[A]_{k} denote the optimal rank-kk approximation of AA in the Frobenius norm, i.e.,

[A]_{k}=\mathrm{argmin}_{A^{\prime}\in\mathbb{R}^{n\times d}\text{ of rank }k}\lVert A-A^{\prime}\rVert_{F}^{2}.

It is well-known that [A]k[A]_{k} can be computed using the singular value decomposition (SVD), in time O(nd2)O(nd^{2}). Since this running time does not scale well for large matrices, a very large body of work has been dedicated to developing fast approximate LRA algorithms, which output a matrix AA^{\prime} of rank kk that attains error close to [A]k[A]_{k}; see the surveys by [Mah11, Woo14, MT20].

SCW.

The SCW algorithm for LRA is due to [CW13], building on [Sar06, CW09]. It uses an auxiliary sketching matrix Sm×nS\in\mathbb{R}^{m\times n}, where its row-dimension mm is called the sketching dimension, and is chosen to be slightly larger than kk and much smaller than nn. The algorithm is specified in Algorithm 1.

In [CW13], SS is chosen at random from a data-oblivious distribution as follows. Let s{1,,m}s\in\{1,\ldots,m\} be a sparsity parameter. One chooses ss uniformly random entries in each column of SS, and chooses the value of each of them uniformly at random from {1,1}\{1,-1\}. The rest of the entries in SS are zero. [CW13] show that given ϵ>0\epsilon>0, by setting the sketching dimension to m=O~(k2/ϵ2)m=\widetilde{O}(k^{2}/\epsilon^{2}), even with sparsity s=1s=1, SCW returns AA^{\prime} of rank kk that satisfies AAF2(1+ϵ)A[A]kF2\lVert A-A^{\prime}\rVert_{F}^{2}\leq(1+\epsilon)\lVert A-[A]_{k}\rVert_{F}^{2} with high probability (over the random choice of SS), while running in time nearly linear in the size of AA.
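For concreteness, here is a minimal numpy sketch of this oblivious construction as we read it (whether the ss entries per column are sampled with or without replacement is our assumption); it is an illustration, not the reference implementation of [CW13].

```python
import numpy as np

def random_sparse_sketch(m, n, s, seed=None):
    """Sample an m x n sketching matrix with s nonzero entries per column,
    each placed in a (distinct) uniformly random row and set to +1 or -1."""
    rng = np.random.default_rng(seed)
    S = np.zeros((m, n))
    for j in range(n):
        rows = rng.choice(m, size=s, replace=False)
        S[rows, j] = rng.choice([-1.0, 1.0], size=s)
    return S

S = random_sparse_sketch(m=10, n=100, s=1, seed=0)
assert (np.count_nonzero(S, axis=0) == 1).all()
```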

Algorithm 1 The SCW low rank approximation algorithm

Input: An×dA\in\mathbb{R}^{n\times d}, target rank kk, auxiliary matrix Sm×nS\in\mathbb{R}^{m\times n}. Output: An×dA^{\prime}\in\mathbb{R}^{n\times d} of rank kk.  

1:  Compute the product SASA.
2:  If SASA is a zero matrix, return the zero matrix of order n×dn\times d.
3:  Compute the SVD U,Σ,VTU,\Sigma,V^{T} of SASA.
4:  Compute the product AVAV.
5:  Compute and return [AV]kVT[AV]_{k}V^{T}.
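The following numpy sketch is a direct transcription of Algorithm 1 (our own, with the rank-kk truncation done by a dense SVD); it is convenient for experimenting with the loss defined below.

```python
import numpy as np

def best_rank_k(M, k):
    """Best rank-k approximation of M in Frobenius norm, via truncated SVD."""
    U, sig, Vt = np.linalg.svd(M, full_matrices=False)
    return (U[:, :k] * sig[:k]) @ Vt[:k, :]

def scw(A, k, S):
    """SCW low rank approximation, following the steps of Algorithm 1."""
    SA = S @ A                                          # step 1
    if not SA.any():                                    # step 2: SA is all zeros
        return np.zeros_like(A)
    _, _, Vt = np.linalg.svd(SA, full_matrices=False)   # step 3: SVD of SA
    AV = A @ Vt.T                                       # step 4
    return best_rank_k(AV, k) @ Vt                      # step 5: [AV]_k V^T
```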

IVY.

The IVY algorithm due to [IVY19] is a learning-based variant of SCW. The sparsity pattern of SS is chosen similarly to SCW (ss non-zero entries per column, chosen uniformly at random) and remains fixed. However, the values of the non-zero entries in SS are now trainable parameters, learned from a training set of input matrices. In particular, SS is chosen by minimizing the empirical loss A𝒜trainASCWk(S,A)F2\sum_{A\in\mathcal{A}_{train}}\lVert A-\mathrm{SCW}_{k}(S,A)\rVert_{F}^{2}, where SCWk(S,A)\mathrm{SCW}_{k}(S,A) denotes the output matrix of Algorithm 1, using stochastic gradient descent (SGD) on a training set 𝒜trainn×d\mathcal{A}_{train}\subset\mathbb{R}^{n\times d}.

To cast IVY in the statistical learning framework from Section 2.1, let 𝒜\mathcal{A} denote the set of possible input matrices (i.e., all An×dA\in\mathbb{R}^{n\times d} with AF2=1\lVert A\rVert_{F}^{2}=1), and let 𝒮\mathcal{S} denote the set of possible sketching matrices (i.e., all Sm×nS\in\mathbb{R}^{m\times n} with the fixed sparsity pattern specified above). Every S𝒮S\in\mathcal{S} gives rise to an LRA algorithm, whose output rank-kk approximation of AA is A=SCWk(S,A)A^{\prime}=\mathrm{SCW}_{k}(S,A), and whose associated loss function LkSCW(S,):𝒜[0,1]L^{\mathrm{SCW}}_{k}(S,\cdot):\mathcal{A}\rightarrow[0,1] is given below. (It can be shown that the loss never exceeds AF2\lVert A\rVert_{F}^{2} for any AA and SS; since all input matrices are normalized so that AF2=1\lVert A\rVert_{F}^{2}=1, the loss is indeed contained in [0,1][0,1].)

L^{\mathrm{SCW}}_{k}(S,A)=\lVert A-\mathrm{SCW}_{k}(S,A)\rVert_{F}^{2}.

Given samples from an unknown distribution 𝒟\mathcal{D} over 𝒜\mathcal{A}, the learning problem is to choose S𝒮S\in\mathcal{S} which has approximately optimal loss in expectation over 𝒟\mathcal{D}. Our objective is to prove generalization bounds for learning a sketching matrix S𝒮S\in\mathcal{S}, or equivalently, for learning the class of loss functions IVY={LkSCW(S,)}S𝒮\mathcal{L}_{\mathrm{IVY}}=\{L^{\mathrm{SCW}}_{k}(S,\cdot)\}_{S\in\mathcal{S}}.

2.3 Our Main Results

We prove closely matching upper and lower bounds on the fat shattering dimension of IVY.

Theorem 2.1.

For every ϵ>0\epsilon>0 smaller than a sufficiently small constant, the ϵ\epsilon-fat shattering dimension of IVY, fatdimϵ(IVY)\mathrm{fatdim}_{\epsilon}(\mathcal{L}_{\mathrm{IVY}}), satisfies

\mathrm{fatdim}_{\epsilon}(\mathcal{L}_{\mathrm{IVY}})\leq O(ns\cdot(m+k\log(d/k)+\log(1/\epsilon))).

Furthermore, if ϵ<1/(2k)\epsilon<1/(2\sqrt{k}), then fatdimϵ(IVY)Ω(ns)\mathrm{fatdim}_{\epsilon}(\mathcal{L}_{\mathrm{IVY}})\geq\Omega(ns).

These results translate to sample complexity bounds on its uniform convergence and learnability:

Theorem 2.2.

Let ϵ,δ>0\epsilon,\delta>0 be smaller than sufficiently small constants. The number of samples needed for (ϵ,δ)(\epsilon,\delta)-uniform convergence for IVY, and thus for (ϵ,δ)(\epsilon,\delta)-learning the sketching matrix of IVY with ERM, is

O(\epsilon^{-2}\cdot(ns\cdot(m+k\log(d/k)+\log(1/\epsilon))+\log(1/\delta))).

Furthermore, if ϵ1/(256k)\epsilon\leq 1/(256\sqrt{k}), then (ϵ,ϵ)(\epsilon,\epsilon)-uniform convergence for IVY requires Ω(ϵ2ns/k)\Omega(\epsilon^{-2}\cdot ns/k) samples, and if ϵ<1/(2k)\epsilon<1/(2\sqrt{k}), then (ϵ,δ)(\epsilon,\delta)-learning IVY with any learning procedure requires Ω(ϵ1+ns)\Omega(\epsilon^{-1}+ns) samples.

For the sake of intuition, let us comment on typical settings for the various sizes in these results. The input matrix is of order n×dn\times d, where we think of n,dn,d as being arbitrarily large. (Recall we assume that dnd\leq n by convention; for all conceptual purposes it suffices to consider the square case n=dn=d.) The target rank kk can be thought of as a small integer constant, and the sketching dimension mm as only slightly larger (the larger mm is, the slower but more accurate the LRA algorithm would be). Both are generally independent of n,dn,d. Concretely, the empirical evaluations in [IVY19, IVWW19, IWW21, ALN20, LLV+20] use kk up to 4040, and mm up to 4k4k, for matrices with thousands of rows and columns. The sparsity ss is often chosen to be the smallest possible, s=1s=1, while some have found it beneficial to make it slightly larger, up to 88 [TYUC19, MT20]. The upshot is that the upper and lower bounds in Theorem 2.1 are essentially matching up to a factor of log(d/ϵ)\log(d/\epsilon). Furthermore, both are essentially proportional to nsns, which is the number of non-zero entries in the sketching matrix, i.e., the number of trainable parameters that IVY learns.

Other learning-based algorithms.

While we focus on IVY for concreteness, our techniques also yield bounds for many other learning-based linear algebraic algorithms. See Section 6.

Remark \theremark (on computational efficiency).

Like most prior work on PAC-learning, our results focus on the sample complexity of the learning problem, rather than the computational efficiency of learning (see, e.g., Remark 3.5 in [GR17]). It is currently unknown whether there exist efficient ERM learners for the data-driven algorithms considered in this work. In particular, while in practice IVY uses SGD to minimize the empirical loss, SGD is not known to provably converge to a global minimum in this setting. The computational efficiency of ERM for backpropagation-based algorithms (like IVY) remains a challenging open problem.

Remark \theremark (on the precision of real number representation).

We prove Theorem 2.1 without any restriction on the numerical precision of real numbers in the computational model, that is, even when they may have unbounded precision. If one considers only bounded precision models, say where real numbers are represented by bb-bit machine words, then the total number of possible sketching matrices is 2nsb2^{nsb}, leading trivially to a bound of O(nsb)O(nsb) on the pseudo-dimension of IVY. On the other hand, the lower bound in Theorem 2.1 holds even with 11-bit machine words. In summary, both our upper and lower bounds are proven in the most general computational model.

2.4 Technical Overview

Our starting point is a general framework of [GJ95] for bounding the VC-dimension or the pseudo-dimension. They showed that if the functions in a class (which in our case are the losses of candidate algorithms) can be computed by a certain type of algorithm, which we call a GJ algorithm, then the pseudo-dimension of the function class can be bounded in terms of the running time of that algorithm. We employ a simple yet crucial refinement of this framework, relying on more refined complexity measures of GJ algorithms than the running time.

Under this framework, ideally we would have liked to compute the SCW loss with a GJ algorithm which is efficient in these refined complexity measures, thus obtaining a bound on the pseudo-dimension of IVY. Unfortunately, it is not clear how to compute the SCW loss at all with a GJ algorithm, since the GJ framework only allows a narrow set of operations. Nonetheless, we show it can be approximated by a GJ algorithm, by relying on recent advances in numerical linear algebra, and specifically on “gap-independent” analyses of power method iterations for LRA, that do not depend on numerical properties of the input matrix (like eigenvalue gaps) [RST10, HMT11, BDMI14, Woo14, WC15, MM15]. We thus obtain a pseudo-dimension bound for the approximate loss, which translates to a fat shattering dimension bound for the true SCW loss, yielding generalization results for IVY, as well as for other learning-based numerical linear algebra algorithms.

Proof roadmap.

Section 3 presents the [GJ95] framework. In Section 4, as a warm-up, we prove tight bounds for IVY in the simple case m=k=1m=k=1. Section 5 contains the main proof of this paper, establishing the upper bound from Theorem 2.1. The lower bound is proven in Appendix C, together completing the proof of Theorem 2.1. Theorem 2.2 follows from Theorem 2.1 using mostly standard arguments, given in Appendix D.

2.5 Related Work

[IVY19] gave a “safeguarding” technique for IVY, showing it can be easily modified to guarantee that the learned algorithm never performs worse than the oblivious SCW algorithm on future matrices. The modification is simply to concatenate an oblivious sketching matrix vertically to the learned sketching matrix. While this result guarantees that learning does not hurt, it does not show that learning can help, and provides no generalization guarantees.

[IWW21] prove “consistency” results for their learning-based algorithms, showing that the learned sketching matrix performs well on the training input matrices from which it was learned. These results have no bearing on future matrices, and again provide no generalization guarantees.

[BDD+21] recently gave a general technique for proving generalization bounds for learning-based algorithms in the statistical learning framework of Section 2.1, based on piecewise decompositions of dual function classes. While there are some formal connections between their techniques and ours, their approach does not yield useful bounds for the linear algebraic algorithms that we study, since these algorithms do not exhibit the sufficiently simple piecewise dual structure that the technique of [BDD+21] requires. We discuss this in more detail in Appendix E.

3 The Goldberg-Jerrum Framework

Our upper bounds are based on a general framework due to [GJ95] for bounding the VC-dimension. We instantiate a refined version of it, which still follows immediately from their proofs. For completeness and self-containedness, Appendix A reproduces the proof of the variant we state below.

Definition \themydefinition.

A GJ algorithm Γ\Gamma operates on real-valued inputs, and can perform two types of operations:

  • Arithmetic operations of the form v′′=vvv^{\prime\prime}=v\odot v^{\prime}, where {+,,×,÷}\odot\in\{+,-,\times,\div\}.

  • Conditional statements of the form “if v0v\geq 0 … else …”.

In both cases, v,vv,v^{\prime} are either inputs or values previously computed by the algorithm.

Every intermediate value computed by Γ\Gamma is a multivariate rational function (i.e., the ratio of two polynomials) of its inputs. The degree of a rational function is the maximum of the degrees of the two polynomials in its numerator and denominator, when written as a reduced fraction. We now define two complexity measures of GJ algorithms.

Definition \themydefinition.

The degree of a GJ algorithm is the maximum degree of any rational function it computes of the inputs. The predicate complexity of a GJ algorithm is the number of distinct rational functions that appear in its conditional statements.

Theorem 3.1.

Using the notation of Section 2.1, suppose that each algorithm LL\in\mathcal{L} is specified by nn real parameters. Suppose that for every x𝒳x\in\mathcal{X} and rr\in\mathbb{R} there is a GJ algorithm Γx,r\Gamma_{x,r} that given LL\in\mathcal{L}, returns “true” if L(x)>rL(x)>r and “false” otherwise. Suppose Γx,r\Gamma_{x,r} has degree Δ\Delta and predicate complexity pp. Then, the pseudo-dimension of \mathcal{L} is at most O(nlog(Δp))O(n\log(\Delta p)).

Let us comment on the relation between this statement and the original theorem from [GJ95]. Their statement is that if Γx,r\Gamma_{x,r} has running time tt then the pseudo-dimension of \mathcal{L} is O(nt)O(nt). They prove it by implicitly proving Theorem 3.1, as it is not hard to see that if the running time is tt then both the degree and the predicate complexity are at most 2t2^{t}.

To illustrate why the refinement stated in Theorem 3.1 could be useful, let us first consider the degree. Consider computing qq power method iterations for a matrix Mn×nM\in\mathbb{R}^{n\times n}, i.e., computing MqπM^{q}\pi with some initial vector πn\pi\in\mathbb{R}^{n}. The running time is t=O(n2q)t=O(n^{2}q), so the pseudo-dimension upper bound we get based on the running time alone is O(nt)=O(n3q)O(nt)=O(n^{3}q). However, the degree of every entry in MqπM^{q}\pi (when considered as a rational function of the entries of MM and π\pi) is just q+1q+1, and hence Theorem 3.1 yields the better upper bound O(nlogq)O(n\log q).

As for the predicate complexity, consider for example a GJ algorithm for choosing the minimum of rr numbers (this operation will be useful for us in derandomizing the power method). The running time is t=O(r)t=O(r), and this is tight since GJ algorithms adhere to the comparison model, so the pseudo-dimension upper bound we get based on the running time alone is O(nt)=O(nr)O(nt)=O(nr). However, the predicate complexity is (r2){r\choose 2}, and the degree is just 11, since choosing the minimum among v1,,vrv_{1},\ldots,v_{r}\in\mathbb{R} involves only the polynomials vivjv_{i}-v_{j} for every i<ji<j. Hence, Theorem 3.1 yields the better upper bound O(nlogr)O(n\log r). The same reasoning applies to sorting rr numbers (see Section E.1 for an example that uses this operation), and to other useful subroutines.
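As a toy illustration of this point (ours, not from the paper), the following snippet selects a minimum while only ever evaluating the signs of the pairwise-difference polynomials vivjv_{i}-v_{j}, of which there are (r2){r\choose 2}.

```python
from itertools import combinations

def gj_min(values):
    """Select the index of the minimum using only sign tests of the
    pairwise-difference polynomials v_i - v_j (i < j)."""
    predicates = {(i, j): values[i] - values[j] >= 0
                  for i, j in combinations(range(len(values)), 2)}
    best = 0
    for i in range(1, len(values)):
        a, b = min(best, i), max(best, i)
        best = b if predicates[(a, b)] else a   # v_a >= v_b  =>  keep b
    return best, len(predicates)                # argmin, number of distinct predicates

idx, num_predicates = gj_min([3.2, -1.0, 0.5, 7.0])
print(idx, num_predicates)   # 1, C(4, 2) = 6
```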

Remark \theremark (oracle access to optimal solution).

It is also worth observing that Γx,r\Gamma_{x,r} has free access to all information about xx — and in particular, to the optimal solution of the computational problem addressed by \mathcal{L} with xx as the input — without any computational cost. For example, in LRA, Γx,r\Gamma_{x,r} has free access to the exact SVD factorization of the input matrix AA. (Note that the difficulty for Γx,r\Gamma_{x,r} lies in computing the losses of solutions other than the optimal one, namely the SCW loss induced by the given sketching matrix SS.) Similarly, in a regression problem minyAyb\min_{y}\lVert Ay-b\rVert, the input xx is the pair A,bA,b, and Γx,r\Gamma_{x,r} has free access to the optimal solution y=argminyAyby^{*}=\mathrm{argmin}_{y}\lVert Ay-b\rVert. These observations are useful in proving bounds for some of the algorithms we consider in Section 6.

4 Warm-up: The Case m=k=1m=k=1

As a warm-up, we consider the simple case where both the target rank kk and the sketching dimension mm equal 11. We show an upper bound of O(n)O(n) on the pseudo-dimension and a lower bound of Ω(n)\Omega(n) on the ϵ\epsilon-fat shattering dimension for every ϵ(0,0.5)\epsilon\in(0,0.5) (this immediately implies that both are Θ(n)\Theta(n)). This case is particularly simple because the SCW loss admits a closed-form expression, given next (the proof appears in Section B.1). Note that the sketching matrix in this case is just a vector, denoted wnw\in\mathbb{R}^{n} for the rest of this section, and it is specified by nn real parameters.

Fact \thefact.

The loss of SCW with input matrix An×dA\in\mathbb{R}^{n\times d}, sketching vector wnw\in\mathbb{R}^{n} and target rank k=1k=1 is \lVert A-\frac{1}{\lVert A^{T}w\rVert^{2}}AA^{T}ww^{T}A\rVert_{F}^{2}.
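This fact is easy to check numerically; the sketch below (ours, assuming wTA0w^{T}A\neq 0) compares the closed form against a direct run of Algorithm 1 with m=k=1m=k=1.

```python
import numpy as np

def scw_rank1(A, w):
    """SCW (Algorithm 1) with target rank k=1 and single-row sketch S = w^T."""
    SA = w[None, :] @ A                                  # 1 x d
    _, _, Vt = np.linalg.svd(SA, full_matrices=False)
    AV = A @ Vt.T                                        # n x 1, already rank <= 1
    return AV @ Vt                                       # [AV]_1 V^T

rng = np.random.default_rng(1)
A = rng.normal(size=(6, 4)); A /= np.linalg.norm(A)      # normalize so ||A||_F = 1
w = rng.normal(size=6)

loss_alg = np.linalg.norm(A - scw_rank1(A, w)) ** 2
ATw = A.T @ w
loss_formula = np.linalg.norm(A - np.outer(A @ ATw, ATw) / (ATw @ ATw)) ** 2
print(abs(loss_alg - loss_formula))                      # ~0
```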

Theorem 4.1.

The pseudo-dimension of IVY in the m=k=1m=k=1 case is Θ(n)\Theta(n), and the ϵ\epsilon-fat shattering dimension is Θ(n)\Theta(n) for every ϵ(0,0.5)\epsilon\in(0,0.5).

Proof.

Since the pseudo-dimension is an upper bound on the fat shattering dimension, it suffices to prove the upper bound for the former and the lower bound for the latter. For a fixed matrix AA and threshold rr\in\mathbb{R}, a GJ algorithm ΓA,r\Gamma_{A,r} can evaluate the loss formula from Section 4 on a given ww with degree O(1)O(1), and compare the loss to rr with predicate complexity 11. Therefore, by Theorem 3.1, the pseudo-dimension is O(n)O(n).

For the lower bound, we first argue that if AA is a rank-11 matrix, then SCW with sketching vector ww attains zero loss if and only if wTA0w^{T}A\neq 0. Indeed, by Section 4, if wTA=0w^{T}A=0 then the loss is AF2\lVert A\rVert_{F}^{2}, which is non-zero for a rank-11 matrix. If wTA0w^{T}A\neq 0, we write AA in SVD form A=σuvTA=\sigma uv^{T} with σ\sigma\in\mathbb{R}, unu\in\mathbb{R}^{n}, vdv\in\mathbb{R}^{d}, and u=v=1\lVert u\rVert=\lVert v\rVert=1, and by plugging this into Section 4, the loss is 0.

Now, consider the set of matrices A1,,Ann×dA_{1},\ldots,A_{n}\in\mathbb{R}^{n\times d}, where AiA_{i} is all zeros except for a single 11-entry in row ii, column 11. For every I{1,,n}I\subset\{1,\ldots,n\}, let wInw_{I}\in\mathbb{R}^{n} be its indicator vector. By the above, SCW with sketching vector wIw_{I} attains zero loss on the matrices {Ai:iI}\{A_{i}:i\in I\}, and loss AiF2=1\lVert A_{i}\rVert_{F}^{2}=1 on the rest. Therefore this set of nn matrices is ϵ\epsilon-fat shattered for every ϵ(0,0.5)\epsilon\in(0,0.5). ∎

5 Proof of the Upper Bound

In this section we prove the upper bound in Theorem 2.1 on the fat shattering dimension of IVY. We start by stating a formula for the output rank-kk matrix of SCW (Algorithm 1). For a matrix MM, we use MM^{\dagger} to denote its Moore-Penrose pseudo-inverse, and recall that we use [M]k[M]_{k} to denote its best rank-kk approximation. All of the proofs omitted from this section are collected in Appendix B.

Lemma \themylemma.

The output matrix of the SCW algorithm equals [A(SA)(SA)]k[A(SA)^{\dagger}(SA)]_{k}.
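A quick numerical check of this identity on a random instance (our sketch; it assumes SASA has full row rank, which holds generically for the dense SS used here):

```python
import numpy as np

def best_rank_k(M, k):
    U, sig, Vt = np.linalg.svd(M, full_matrices=False)
    return (U[:, :k] * sig[:k]) @ Vt[:k, :]

def scw(A, k, S):
    # Algorithm 1: SVD of SA, then [AV]_k V^T.
    _, _, Vt = np.linalg.svd(S @ A, full_matrices=False)
    return best_rank_k(A @ Vt.T, k) @ Vt

rng = np.random.default_rng(2)
n, d, m, k = 8, 6, 4, 2
A = rng.normal(size=(n, d)); A /= np.linalg.norm(A)
S = rng.normal(size=(m, n))

SA = S @ A
B = A @ np.linalg.pinv(SA) @ SA                          # A (SA)^dagger (SA)
print(np.linalg.norm(scw(A, k, S) - best_rank_k(B, k)))  # ~0
```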

5.1 Computing Projection Matrices

Next, we give a GJ algorithm for computing orthogonal projection matrices. Recall that the orthogonal projection matrix on the row-space of a matrix ZZ is given by ZZZ^{\dagger}Z.

Lemma \themylemma.

There is a GJ algorithm that given an input matrix ZZ with kk rows, computes ZZZ^{\dagger}Z. The algorithm has degree O(k)O(k) and predicate complexity at most 2k2^{k}. Furthermore, if ZZ is promised to have full rank kk, then the predicate complexity is zero.

Proof.

We start by proving the “furthermore” part, supposing ZZ has full rank kk. Since ZZ has full rank kk, by a known identity for the pseudo-inverse we can write ZZ=ZT(ZZT)1ZZ^{\dagger}Z=Z^{T}(ZZ^{T})^{-1}Z. We use the matrix inversion algorithm from [Csa76] (similar/identical algorithms have appeared in other places as well), to invert an invertible k×kk\times k matrix MM. Define the following matrices,

B_{1}=I,\;\;\;\;B_{i}=MB_{i-1}-\frac{\mathrm{tr}(MB_{i-1})}{i-1}\cdot I\;\;\;\;\text{for }i=2,\ldots,k.

The inverse is then given by,

M^{-1}=\frac{k}{\mathrm{tr}(MB_{k})}\cdot B_{k}.

Note that each entry of BiB_{i} is a polynomial of degree i1i-1 in the entries of MM, and therefore each entry of M1M^{-1} is a rational function of degree k1k-1 in the entries of MM. In our case M=ZZTM=ZZ^{T}, and the desired output is ZT(ZZT)1ZZ^{T}(ZZ^{T})^{-1}Z, and therefore it has degree O(k)O(k). This finishes the proof of the “furthermore” part.

We proceed to showing the lemma without the full rank assumption. Suppose MM has rank rkr\leq k which could be strictly smaller than kk (and rr need not be known to the algorithm). If we find a matrix Yr×dY\in\mathbb{R}^{r\times d} whose rows form a basis for the row-space of MM, then we could write the desired projection matrix as ZZ=YY=YT(YYT)1YZ^{\dagger}Z=Y^{\dagger}Y=Y^{T}(YY^{T})^{-1}Y, and compute it with a GJ-algorithm of degree O(r)O(k)O(r)\leq O(k) in the entries of YY, as above.

So, it remains to compute such a matrix YY. To this end we use another result from [Csa76]: if MM is a k×kk\times k matrix with characteristic polynomial fM(λ)=det(λIM)=i=0kciλkif_{M}(\lambda)=\mathrm{det}(\lambda I-M)=\sum_{i=0}^{k}c_{i}\lambda^{k-i}, then ci=1itr(MBi)c_{i}=-\frac{1}{i}\mathrm{tr}(MB_{i}), with BiB_{i} defined as above. In particular, the free coefficient ckc_{k} can be computed by a GJ algorithm of degree O(k)O(k). This yields a GJ algorithm for determining whether a k×kk\times k matrix has full rank, since this holds if and only if ck0c_{k}\neq 0.

Now, to compute YY from ZZ, we can simply go over the rows of ZZ one by one, tentatively add each row to our (partial) YY, and keep or remove it depending on whether YY still has full row rank. To check this, we compute YYTYY^{T} (which is a square matrix of order at most rkr\leq k) and check whether it has full rank as above. Overall, YY and therefore the output YT(YYT)1YY^{T}(YY^{T})^{-1}Y are computed using a GJ algorithm with degree O(k)O(k). Choosing YY involves conditional statements with up to 2k2^{k} rational functions, which are the free coefficients of the characteristic polynomials of every subset of the kk rows of ZZ, so the predicate complexity is at most 2k2^{k}. ∎

We remark that this lemma already yields an upper bound on the pseudo-dimension of IVY in the case m=km=k. The full regime mkm\geq k is more involved and will occupy the rest of this section.

Corollary \themycorollary.

The pseudo-dimension of IVY with m=km=k is O(nsk)O(nsk).

Proof.

SS has m=km=k rows, hence SASA has rank at most kk, hence [A(SA)(SA)]k=A(SA)(SA)[A(SA)^{\dagger}(SA)]_{k}=A(SA)^{\dagger}(SA). By Section 5, the SCW loss is AA(SA)(SA)F2\lVert A-A(SA)^{\dagger}(SA)\rVert_{F}^{2}. By Section 5.1, it can be computed with a GJ algorithm of degree O(k)O(k) and predicate complexity 2k2^{k}. Recalling that the learned sketching matrix in IVY is specified by nsns real parameters, the corollary follows from Theorem 3.1. ∎

5.2 Derandomized Power Method Iterations

For the rest of this section we denote B=A(SA)(SA)B=A(SA)^{\dagger}(SA) for brevity. Note that BB is of order n×dn\times d. By Section 5 the SCW loss equals A[B]kF2\lVert A-[B]_{k}\rVert_{F}^{2}, while Section 5.1 shows how to compute BB with a GJ algorithm. Our next goal is to approximately compute [B]k[B]_{k} with a GJ algorithm.

To this end we use a result on iterative methods for LRA, due to [MM15], building on ideas from [RST10, HMT11, WC15, BDMI14, Woo14]. It states that given BB, an initial crude approximation for [B]k[B]_{k} can be refined into a good approximation with a small number of powering iterations. The next theorem is somewhat implicit in [MM15]; see Section B.3 for details.

Theorem 5.1 ([MM15]).

Suppose we have a matrix Pd×kP\in\mathbb{R}^{d\times k} that satisfies

\lVert B-(BP)(BP)^{\dagger}B\rVert_{F}^{2}\leq O(kd)\cdot\lVert B-[B]_{k}\rVert_{F}^{2}. (1)

Then, Z=(BBT)qBPZ=(BB^{T})^{q}BP with q=O(ϵ1log(d/ϵ))q=O(\epsilon^{-1}\log(d/\epsilon)) satisfies BZZBF2(1+ϵ)B[B]kF2\lVert B-ZZ^{\dagger}B\rVert_{F}^{2}\leq(1+\epsilon)\cdot\lVert B-[B]_{k}\rVert_{F}^{2}.

Normally, PP is chosen at random as a matrix of independent gaussians, which satisfies Equation 1 with high probability. Since GJ algorithms are deterministic, we derandomize this approach using subsets of the standard basis.
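The following sketch (ours) illustrates Theorem 5.1 with a random gaussian PP on a random instance; the per-iteration QR is a standard numerical-stability device that preserves the column space of (BBT)qBP(BB^{T})^{q}BP.

```python
import numpy as np

rng = np.random.default_rng(4)
n, d, k, q = 50, 40, 3, 10
B = rng.normal(size=(n, d))

# Optimal rank-k error ||B - [B]_k||_F^2, from the singular values of B.
sig = np.linalg.svd(B, compute_uv=False)
opt = np.sum(sig[k:] ** 2)

P = rng.normal(size=(d, k))          # random gaussian start, as discussed above
Z = B @ P
for _ in range(q):
    # Spans the column space of (B B^T)^q B P; QR re-orthonormalization keeps
    # the same column space while avoiding overflow/underflow.
    Z, _ = np.linalg.qr(B @ (B.T @ Z))

err = np.linalg.norm(B - Z @ (Z.T @ B)) ** 2   # Z is orthonormal, so Z^+ = Z^T
print(err / opt)                               # approaches 1 as q grows
```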

Lemma \themylemma.

Suppose k<dk<d. For every Bn×dB\in\mathbb{R}^{n\times d}, there is a subset of size kk of the standard basis in d\mathbb{R}^{d}, such that if we organize its elements into a matrix Pd×kP\in\mathbb{R}^{d\times k}, it satisfies Equation 1.

5.3 The Proxy Loss

Let LkSCW(S,A)L^{\mathrm{SCW}}_{k}(S,A) denote the loss of SCW with input matrix AA, sketching matrix SS and target rank kk. By Section 5, LkSCW(S,A)=AA(SA)(SA)F2L^{\mathrm{SCW}}_{k}(S,A)=\lVert A-A(SA)^{\dagger}(SA)\rVert_{F}^{2}. We now approximate this quantity with a GJ algorithm. Formally, given ϵ>0\epsilon>0, we define a proxy loss L^k,ϵ(S,A)\hat{L}_{k,\epsilon}(S,A) as the output of the following GJ algorithm, which operates on an input SS with a fixed AA. Recall that dd is the column dimension of AA.

  1. Compute B=A(SA)(SA)B=A(SA)^{\dagger}(SA) using the GJ algorithm from Section 5.1.

  2. Let {Pi:i=1(dk)}\{P_{i}:i=1\ldots{d\choose k}\} be the set of matrices in d×k\mathbb{R}^{d\times k} whose columns form all of the possible kk-subsets of the standard basis in d\mathbb{R}^{d}.

  3. For every PiP_{i}, compute Zi=(BBT)qBPiZ_{i}=(BB^{T})^{q}BP_{i}, where q=O(ϵ1log(d/ϵ))q=O(\epsilon^{-1}\log(d/\epsilon)).

  4. Choose ZZ as the ZiZ_{i} that minimizes BZiZiBF2\lVert B-Z_{i}Z_{i}^{\dagger}B\rVert_{F}^{2}, using Section 5.1 to compute ZiZiZ_{i}Z_{i}^{\dagger}.

  5. Compute and return the proxy loss, L^k,ϵ(S,A)=AZZBF2\hat{L}_{k,\epsilon}(S,A)=\lVert A-ZZ^{\dagger}B\rVert_{F}^{2}.

Given rr\in\mathbb{R}, we can return “true” if L^k,ϵ(S,A)>r\hat{L}_{k,\epsilon}(S,A)>r and “false” otherwise, obtaining a GJ algorithm that fits the assumptions of Theorem 3.1.
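To see the procedure in action, here is a brute-force numerical analogue (our sketch; numpy's pinv and QR stand in for the exact GJ subroutines, and the constant inside q=O(ϵ1log(d/ϵ))q=O(\epsilon^{-1}\log(d/\epsilon)) is simply set to 11), compared against the true SCW loss A[B]kF2\lVert A-[B]_{k}\rVert_{F}^{2}:

```python
import numpy as np
from itertools import combinations

def proxy_loss(S, A, k, eps):
    """Numerical analogue of the proxy-loss procedure above (brute force over
    all k-subsets of the standard basis)."""
    n, d = A.shape
    SA = S @ A
    B = A @ np.linalg.pinv(SA) @ SA                       # step 1
    q = int(np.ceil(np.log(d / eps) / eps))               # q = O(eps^-1 log(d/eps))
    best_err, best_Z = np.inf, None
    for cols in combinations(range(d), k):                # steps 2-3
        Z = B @ np.eye(d)[:, list(cols)]
        for _ in range(q):                                # span of (B B^T)^q B P_i
            Z, _ = np.linalg.qr(B @ (B.T @ Z))            # re-orthonormalized for stability
        err = np.linalg.norm(B - Z @ (Z.T @ B)) ** 2
        if err < best_err:                                # step 4
            best_err, best_Z = err, Z
    return np.linalg.norm(A - best_Z @ (best_Z.T @ B)) ** 2   # step 5

rng = np.random.default_rng(5)
n, d, m, k, eps = 8, 6, 4, 2, 0.25
A = rng.normal(size=(n, d)); A /= np.linalg.norm(A)
S = rng.normal(size=(m, n))

B = A @ np.linalg.pinv(S @ A) @ (S @ A)
U, s, Vt = np.linalg.svd(B, full_matrices=False)
true_loss = np.linalg.norm(A - (U[:, :k] * s[:k]) @ Vt[:k, :]) ** 2  # ||A - [B]_k||_F^2
print(proxy_loss(S, A, k, eps) - true_loss)               # should lie in [0, eps]
```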

Lemma \themylemma.

This GJ algorithm has degree Δ=O(mkϵ1log(d/ϵ))\Delta=O(mk\epsilon^{-1}\log(d/\epsilon)) and predicate complexity p2m2O(k)(d/k)3kp\leq 2^{m}\cdot 2^{O(k)}\cdot(d/k)^{3k}.

Proof.

Since SASA has mm rows, computing (SA)(SA)(SA)^{\dagger}(SA) with Section 5.1 in step 1 has degree O(m)O(m) and predicate complexity 2m2^{m}. For each PiP_{i}, the qq powering iterations in step 3 blow up the degree by qq. Since ZiZ_{i} has kk columns, computing ZiZiZ_{i}Z_{i}^{\dagger} with (the transposed version of) Section 5.1 blows up the degree by O(k)O(k) and the predicate complexity by 2k2^{k}. Choosing ZZ in step 4 entails pairwise comparisons between (dk){d\choose k} values, blowing up the predicate complexity by ((dk)2)e2(ed/k)2k{{d\choose k}\choose 2}\leq e^{2}(ed/k)^{2k}. Step 5 blows up the degree by only O(1)O(1) and does not change the predicate complexity (note that ZZZZ^{\dagger} has already been computed). The final check whether L^k,ϵ(S,A)>r\hat{L}_{k,\epsilon}(S,A)>r is one of (dk)(ed/k)k{d\choose k}\leq(ed/k)^{k} polynomials (one per possible value of ZZ). Overall, the algorithm has degree O(mkq)=O(mkϵ1log(d/ϵ))O(mkq)=O(mk\epsilon^{-1}\log(d/\epsilon)), and predicate complexity at most 2m2O(k)(d/k)3k2^{m}\cdot 2^{O(k)}\cdot(d/k)^{3k}. ∎

Next we show that the proxy loss approximates the true SCW loss.

Lemma \themylemma.

For every S,AS,A, it holds that 0L^k,ϵ(S,A)LkSCW(S,A)ϵ0\leq\hat{L}_{k,\epsilon}(S,A)-L^{\mathrm{SCW}}_{k}(S,A)\leq\epsilon.

Proof.

Given S,AS,A, let B=A(SA)(SA)B=A(SA)^{\dagger}(SA), and let Un×kU\in\mathbb{R}^{n\times k} be a matrix whose columns are the top kk left-singular vectors of BB. This means [B]k=UUTB[B]_{k}=UU^{T}B. Therefore, by Section 5, we have LkSCW(S,A)=AUUTBF2L^{\mathrm{SCW}}_{k}(S,A)=\lVert A-UU^{T}B\rVert_{F}^{2}. Let ZZ be the matrix computed in step 4 of the GJ algorithm for L^k,ϵ(S,A)\hat{L}_{k,\epsilon}(S,A), and recall we have L^k,ϵ(S,A)=AZZBF2\hat{L}_{k,\epsilon}(S,A)=\lVert A-ZZ^{\dagger}B\rVert_{F}^{2}. On one hand, since ZZ has kk columns, ZZBZZ^{\dagger}B has rank at most kk. Therefore, the optimality of [B]k=UUTB[B]_{k}=UU^{T}B as a rank-kk approximation of BB implies,

\lVert B-UU^{T}B\rVert_{F}^{2}\leq\lVert B-ZZ^{\dagger}B\rVert_{F}^{2}. (2)

On the other hand, by Section 5.2, some PiP_{i} considered in step 2 of the GJ algorithm satisfies Equation 1. By Theorem 5.1, this implies that its corresponding ZiZ_{i} (computed in step 3) satisfies BZiZiBF2(1+ϵ)BUUTBF2\lVert B-Z_{i}Z_{i}^{\dagger}B\rVert_{F}^{2}\leq(1+\epsilon)\cdot\lVert B-UU^{T}B\rVert_{F}^{2}, and consequently, ZZ satisfies

\lVert B-ZZ^{\dagger}B\rVert_{F}^{2}\leq(1+\epsilon)\cdot\lVert B-UU^{T}B\rVert_{F}^{2}. (3)

Since (SA)(SA)(SA)^{\dagger}(SA) is a projection matrix, we have by the Pythagorean identity,

\hat{L}_{k,\epsilon}(S,A) =\lVert A-ZZ^{\dagger}B\rVert_{F}^{2}
=\lVert A-ZZ^{\dagger}A(SA)^{\dagger}(SA)\rVert_{F}^{2}
=\lVert A(SA)^{\dagger}(SA)-ZZ^{\dagger}A(SA)^{\dagger}(SA)\rVert_{F}^{2}+\lVert A(I-(SA)^{\dagger}(SA))\rVert_{F}^{2}
=\lVert B-ZZ^{\dagger}B\rVert_{F}^{2}+\lVert A-B\rVert_{F}^{2},

and similarly, LkSCW(S,A)=AUUTBF2=BUUTBF2+ABF2L^{\mathrm{SCW}}_{k}(S,A)=\lVert A-UU^{T}B\rVert_{F}^{2}=\lVert B-UU^{T}B\rVert_{F}^{2}+\lVert A-B\rVert_{F}^{2}. Putting these together, L^k,ϵ(S,A)LkSCW(S,A)=BZZBF2BUUTBF2\hat{L}_{k,\epsilon}(S,A)-L^{\mathrm{SCW}}_{k}(S,A)=\lVert B-ZZ^{\dagger}B\rVert_{F}^{2}-\lVert B-UU^{T}B\rVert_{F}^{2}. From Equations 2 and 3 we now get, 0L^(S,A)LkSCW(S,A)ϵBUUTBF20\leq\hat{L}(S,A)-L^{\mathrm{SCW}}_{k}(S,A)\leq\epsilon\cdot\lVert B-UU^{T}B\rVert_{F}^{2}. The lemma follows since both (IUUT)(I-UU^{T}) and (SA)(SA)(SA)^{\dagger}(SA) are projection matrices, implying that

\lVert(I-UU^{T})B\rVert_{F}^{2}\leq\lVert B\rVert_{F}^{2}=\lVert A(SA)^{\dagger}(SA)\rVert_{F}^{2}\leq\lVert A\rVert_{F}^{2},

and we recall that we assume throughout that AF2=1\lVert A\rVert_{F}^{2}=1. ∎

We can now complete the proof of the upper bound in Theorem 2.1. Let us recall notation: let 𝒜\mathcal{A} be the set of possible input matrices to the LRA problem (i.e., all matrices An×dA\in\mathbb{R}^{n\times d} with AF2=1\lVert A\rVert_{F}^{2}=1), and let 𝒮\mathcal{S} be the set of all possible sketching matrices that IVY can learn (which are all matrices of order m×nm\times n with the fixed sparsity pattern used by SCW and IVY). The class of IVY losses (whose fat shattering dimension we aim to bound) is IVY={LkSCW(S,):𝒜[0,1]}S𝒮\mathcal{L}_{\mathrm{IVY}}=\{L^{\mathrm{SCW}}_{k}(S,\cdot):\mathcal{A}\rightarrow[0,1]\}_{S\in\mathcal{S}}, and the class of proxy losses is ^ϵ={L^k,ϵ(S,):𝒜[0,1]}S𝒮\hat{\mathcal{L}}_{\epsilon}=\{\hat{L}_{k,\epsilon}(S,\cdot):\mathcal{A}\rightarrow[0,1]\}_{S\in\mathcal{S}}.

By Section 5.3, given A𝒜A\in\mathcal{A} and S𝒮S\in\mathcal{S}, the proxy loss L^k,ϵ(S,A)\hat{L}_{k,\epsilon}(S,A) can be computed by a GJ algorithm with degree Δ=O(mkϵ1log(d/ϵ))\Delta=O(mk\epsilon^{-1}\log(d/\epsilon)) and predicate complexity p2m2O(k)(d/k)3kp\leq 2^{m}\cdot 2^{O(k)}\cdot(d/k)^{3k}. Since each S𝒮S\in\mathcal{S} is defined by nsns real parameters, Theorem 3.1 yields that

\mathrm{pdim}(\hat{\mathcal{L}}_{\epsilon})=O(ns\log(\Delta p))=O(ns\cdot(m+k\log(d/k)+\log(1/\epsilon))). (4)

Let A1,,AN𝒜A_{1},\ldots,A_{N}\in\mathcal{A} be a subset of matrices of size NN which is ϵ\epsilon-fat shattered by IVY\mathcal{L}_{\mathrm{IVY}}. Recall that by Section 2.1, this means that there are thresholds r1,,rNr_{1},\ldots,r_{N}\in\mathbb{R}, such that for every I{1,,N}I\subset\{1,\ldots,N\} there exists S𝒮S\in\mathcal{S} such that LkSCW(S,Ai)>ri+ϵL^{\mathrm{SCW}}_{k}(S,A_{i})>r_{i}+\epsilon if iIi\in I and LkSCW(S,Ai)<riϵL^{\mathrm{SCW}}_{k}(S,A_{i})<r_{i}-\epsilon if iIi\notin I. By Section 5.3, this implies that L^k,ϵ(S,Ai)>ri\hat{L}_{k,\epsilon}(S,A_{i})>r_{i} if and only if iIi\in I. Thus, A1,,ANA_{1},\ldots,A_{N} is pseudo-shattered by ^ϵ\hat{\mathcal{L}}_{\epsilon}, implying that Npdim(^ϵ)N\leq\mathrm{pdim}(\hat{\mathcal{L}}_{\epsilon}). Since this holds for any subset of size NN which is ϵ\epsilon-fat shattered by IVY\mathcal{L}_{\mathrm{IVY}}, we have fatdimϵ(IVY)pdim(^ϵ)\mathrm{fatdim}_{\epsilon}(\mathcal{L}_{\mathrm{IVY}})\leq\mathrm{pdim}(\hat{\mathcal{L}}_{\epsilon}). The upper bound in Theorem 2.1 now follows from Equation 4.

6 Other Learning-Based Algorithms in Numerical Linear Algebra

Generally, our approach is applicable whenever the loss incurred by a given algorithm (or equivalently, by a given setting of parameters) on a fixed input can be computed by an efficient GJ algorithm (in the efficiency measures of Section 3), when given “oracle access” to the optimal solution for that fixed input. (See Remark 3 regarding oracle access to the optimal solution. This oracle access was not needed for proving our bounds for IVY, but it will be needed for two of the algorithms discussed in this section: Few-shot LRA and Learned Multigrid.) In this section, we show that in addition to IVY, our method also yields generalization bounds for various other data-driven and learning-based algorithms for problems in numerical linear algebra (specifically, LRA and regression) that have appeared in the literature. The first two algorithms discussed below (Butterfly LRA and Multi-sketch LRA) are variants of IVY, and the bounds for them essentially follow from Theorem 2.1. For the other two algorithms (Few-shot LRA and Learned Multigrid) we describe the requisite GJ algorithms below.

Butterfly learned LRA.

[ALN20] suggested a learned LRA algorithm similar to IVY, except that the learned sparse sketching matrix is replaced by a dense but implicitly-sparse butterfly gadget matrix (a known and useful gadget that arises in fast Fourier transforms). This gadget induces a sketching matrix Sm×nS\in\mathbb{R}^{m\times n} specified by O(mlogn)O(m\log n) learned parameters, where each entry is a product of logn\log n of those parameters. Since they use the IVY loss to learn the matrix, our results give the same upper bound as Theorem 2.1 on the fat shattering dimension of their algorithm, times logn\log n to account for the initial degree (in the GJ sense of Section 3) of the entries in SS.

Multi-sketch learned LRA.

[LLV+20] propose a more involved learned LRA algorithm, which uses two learned sketching matrices. Still, they train each of them using the IVY loss, and therefore the upper bound from Theorem 2.1 holds for their algorithm as well, up to constants.

Few-shot learned LRA.

[IWW21] use a different loss than IVY for learned LRA: loss(S,A)=UkTSTSUI0F2\mathrm{loss}(S,A)=\lVert U_{k}^{T}S^{T}SU-I_{0}\rVert_{F}^{2}, where UU is the left-factor in the SVD of AA, UkU_{k} is its restriction to the top-kk columns, and I0I_{0} is the identity of order kk concatenated on the right with zero columns to match the number of columns in UU. Recall that in the GJ framework, the GJ algorithm has free “oracle access” to UU (see Remark 3). Therefore, this loss can be computed by a GJ algorithm of degree 44 and predicate complexity 11, yielding an upper bound of O(ns)O(ns) on its pseudo-dimension by Theorem 3.1.
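For reference, here is a direct numpy transcription of this loss (our sketch; it takes UU to be the thin-SVD left factor, which is our reading of the definition):

```python
import numpy as np

def few_shot_loss(S, A, k):
    """loss(S, A) = || U_k^T S^T S U - I_0 ||_F^2, with U the left SVD factor
    of A (thin SVD), U_k its first k columns, and I_0 = [I_k | 0]."""
    U, _, _ = np.linalg.svd(A, full_matrices=False)
    I0 = np.eye(k, U.shape[1])
    return np.linalg.norm(U[:, :k].T @ S.T @ (S @ U) - I0) ** 2

rng = np.random.default_rng(6)
A = rng.normal(size=(20, 10))
S = rng.normal(size=(4, 20)) / np.sqrt(4)
print(few_shot_loss(S, A, k=3))
```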

Learned multigrid regression.

[LGM+20] presented a learning-based algorithm for linear regression, which is based on the well-studied multigrid paradigm for solving numerical problems. The specific algorithm they build on is known as 2-level algebraic multigrid (AMG). It approximates the solution to a regression problem minxAxb22\min_{x}\lVert Ax-b\rVert_{2}^{2}, where AA is a square matrix of order n×nn\times n, by iterative improvements that use a sparse auxiliary matrix Pm×nP\in\mathbb{R}^{m\times n} called a prolongation matrix. While there is a large body of literature on heuristic methods for choosing PP, [LGM+20] suggest to instead learn it using a graph neural network. Similarly to IVY, they keep the sparsity pattern of PP fixed, and learn the values of its non-zero entries as trainable parameters.

Let x(i)x^{(i)} denote the approximate solution produced by 2-level AMG in iteration ii, starting from some fixed initial guess x(0)x^{(0)}. For details of how x(i)x^{(i)} is computed in practice, we refer to [LGM+20]. For our purposes, it suffices to note that there is a closed-form formula for obtaining x(i)x^{(i)} from the true optimal solution xx^{*}, described next. Let s1,s21s_{1},s_{2}\geq 1 be two fixed (non-learned) integer parameters of the algorithm. Let LL be the lower-triangular part of the input matrix AA (including the diagonal). [LGM+20] restrict their algorithm to input matrices AA and prolongation matrices PP such that both LL and PTAPP^{T}AP are invertible. We then have the identity,

x^{(i)}=x^{*}+(I-L^{-1}A)^{s_{2}}\cdot(I-P(P^{T}AP)^{-1}P^{T}A)\cdot(I-L^{-1}A)^{s_{1}}\cdot(x^{(i-1)}-x^{*}). (5)

Let P0\lVert P\rVert_{0} denote the number of non-zeros in the fixed sparsity pattern of PP, and recall that mm is the row-dimension of PP. We show that the pseudo-dimension of 2-level AMG with qq iterations is O(P0qlogm)O(\lVert P\rVert_{0}\cdot q\log m), by describing a GJ algorithm Γ\Gamma for computing its loss Ax(q)b22\lVert Ax^{(q)}-b\rVert_{2}^{2}. To this end, note that Γ\Gamma has “oracle access” to any information about AA (see Remark 3). In particular, (IL1A)s1(I-L^{-1}A)^{s_{1}}, (IL1A)s2(I-L^{-1}A)^{s_{2}} and xx^{*} are all available to it. Furthermore, as shown in the proof of Section 5.1, Γ\Gamma can compute (PTAP)1(P^{T}AP)^{-1} using degree O(m)O(m) and predicate complexity zero. Using Equation 5, it can compute x(i)x^{(i)} from x(i1)x^{(i-1)} while increasing the degree by only 22 (from multiplying (PTAP)1(P^{T}AP)^{-1} by PP on the left and by PTP^{T} on the right). Iterating this, Γ\Gamma can compute x(q)x^{(q)} and Ax(q)b22\lVert Ax^{(q)}-b\rVert_{2}^{2} from the initial guess x(0)x^{(0)} using degree O(mq)O(m^{q}) and predicate complexity zero. Comparing the loss Ax(q)b22\lVert Ax^{(q)}-b\rVert_{2}^{2} to a given threshold rr increases the predicate complexity to 11. The upper bound O(P0qlogm)O(\lVert P\rVert_{0}\cdot q\log m) now follows from Theorem 3.1.
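The following sketch (ours, not the implementation of [LGM+20]) evaluates Equation 5 directly; here the prolongation PP is taken as a tall matrix with mm columns so that PTAPP^{T}AP is the m×mm\times m coarse operator, and the 1-D Laplacian system, the aggregation-style PP, and the parameter values are toy assumptions.

```python
import numpy as np

def amg_solve(A, b, P, s1, s2, q, x0=None):
    """Run q iterations of 2-level AMG via the closed form in Equation 5:
    x^(i) = x* + (I - L^-1 A)^s2 (I - P (P^T A P)^-1 P^T A) (I - L^-1 A)^s1 (x^(i-1) - x*)."""
    n = A.shape[0]
    x_star = np.linalg.solve(A, b)                     # "oracle" optimal solution
    L = np.tril(A)                                     # lower-triangular part of A, incl. diagonal
    smooth = np.eye(n) - np.linalg.solve(L, A)         # I - L^{-1} A
    coarse = np.eye(n) - P @ np.linalg.solve(P.T @ A @ P, P.T @ A)
    E = np.linalg.matrix_power(smooth, s2) @ coarse @ np.linalg.matrix_power(smooth, s1)
    x = np.zeros(n) if x0 is None else x0
    for _ in range(q):
        x = x_star + E @ (x - x_star)
    return x

# Toy usage: a 1-D Laplacian system with a piecewise-constant prolongation.
n = 16
A = 2 * np.eye(n) - np.eye(n, k=1) - np.eye(n, k=-1)
b = np.ones(n)
P = np.kron(np.eye(n // 2), np.ones((2, 1)))           # fixed sparsity pattern, n x (n/2)
x = amg_solve(A, b, P, s1=1, s2=1, q=10)
print(np.linalg.norm(A @ x - b) ** 2)                  # the loss ||A x^(q) - b||_2^2
```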

Acknowledgments

We thank Sebastien Bubeck for helpful discussions on statistical learning, and the anonymous reviewers for useful comments. This work was supported in part by NSF TRIPODS program (award DMS-2022448); Simons Investigator Award; GIST-MIT Research Collaboration grant; MIT-IBM Watson collaboration.

References

  • [AB09] Martin Anthony and Peter L Bartlett, Neural network learning: Theoretical foundations, cambridge university press, 2009.
  • [ACC+11] Nir Ailon, Bernard Chazelle, Kenneth L Clarkson, Ding Liu, Wolfgang Mulzer, and C Seshadhri, Self-improving algorithms, SIAM Journal on Computing 40 (2011), no. 2, 350–375.
  • [ALN20] Nir Ailon, Omer Leibovich, and Vineet Nair, Sparse linear networks with a fixed butterfly structure: Theory and practice, arXiv preprint arXiv:2007.08864 (2020).
  • [Bal20] Maria-Florina Balcan, Data-driven algorithm design, Beyond Worst Case Analysis of Algorithms (Tim Roughgarden, ed.), Cambridge University Press, 2020.
  • [BDD+21] Maria-Florina Balcan, Dan DeBlasio, Travis Dick, Carl Kingsford, Tuomas Sandholm, and Ellen Vitercik, How much data is sufficient to learn high-performing algorithms? generalization guarantees for data-driven algorithm design, STOC, 2021.
  • [BDMI14] Christos Boutsidis, Petros Drineas, and Malik Magdon-Ismail, Near-optimal column-based matrix reconstruction, SIAM Journal on Computing 43 (2014), no. 2, 687–717.
  • [BDSV18] Maria-Florina Balcan, Travis Dick, Tuomas Sandholm, and Ellen Vitercik, Learning to branch, International conference on machine learning, PMLR, 2018, pp. 344–353.
  • [BHM00] William L Briggs, Van Emden Henson, and Steve F McCormick, A multigrid tutorial, SIAM, 2000.
  • [BNVW17] Maria-Florina Balcan, Vaishnavh Nagarajan, Ellen Vitercik, and Colin White, Learning-theoretic foundations of algorithm configuration for combinatorial partitioning problems, Conference on Learning Theory, PMLR, 2017, pp. 213–274.
  • [BPSV21] Maria-Florina Balcan, Siddharth Prasad, Tuomas Sandholm, and Ellen Vitercik, Sample complexity of tree search configuration: Cutting planes and beyond, NeurIPS, 2021.
  • [BSV18] Maria-Florina Balcan, Tuomas Sandholm, and Ellen Vitercik, A general theory of sample complexity for multi-item profit maximization, Proceedings of the 2018 ACM Conference on Economics and Computation, 2018, pp. 173–174.
  • [BSV20]  , Refined bounds for algorithm configuration: The knife-edge of dual class approximability, International Conference on Machine Learning, PMLR, 2020, pp. 580–590.
  • [Csa76] L Csanky, Fast parallel matrix inversion algorithms, SIAM Journal on Computing 5 (1976), no. 4, 618–623.
  • [CW09] Kenneth L Clarkson and David P Woodruff, Numerical linear algebra in the streaming model, Proceedings of the forty-first annual ACM symposium on Theory of computing, 2009, pp. 205–214.
  • [CW13]  , Low rank approximation and regression in input sparsity time, Proceedings of the forty-fifth annual ACM symposium on Theory of Computing, 2013, pp. 81–90.
  • [DRVW06] Amit Deshpande, Luis Rademacher, Santosh S Vempala, and Grant Wang, Matrix approximation and projective clustering via volume sampling, Theory of Computing 2 (2006), no. 1, 225–247.
  • [GJ95] Paul W Goldberg and Mark R Jerrum, Bounding the vapnik-chervonenkis dimension of concept classes parameterized by real numbers, Machine Learning 18 (1995), no. 2-3, 131–148.
  • [GR17] Rishi Gupta and Tim Roughgarden, A pac approach to application-specific algorithm selection, SIAM Journal on Computing 46 (2017), no. 3, 992–1017.
  • [HMT11] Nathan Halko, Per-Gunnar Martinsson, and Joel A Tropp, Finding structure with randomness: Probabilistic algorithms for constructing approximate matrix decompositions, SIAM review 53 (2011), no. 2, 217–288.
  • [IVWW19] Piotr Indyk, Ali Vakilian, Tal Wagner, and David P Woodruff, Sample-optimal low-rank approximation of distance matrices, Conference on Learning Theory, PMLR, 2019, pp. 1723–1751.
  • [IVY19] Piotr Indyk, Ali Vakilian, and Yang Yuan, Learning-based low-rank approximations, NeurIPS, 2019.
  • [IWW21] Piotr Indyk, Tal Wagner, and David Woodruff, Few-shot data-driven algorithms for low rank approximation, NeurIPS, 2021.
  • [Kol06] Vladimir Koltchinskii, Local rademacher complexities and oracle inequalities in risk minimization, The Annals of Statistics 34 (2006), no. 6, 2593–2656.
  • [LGM+20] Ilay Luz, Meirav Galun, Haggai Maron, Ronen Basri, and Irad Yavneh, Learning algebraic multigrid using graph neural networks, ICML, 2020.
  • [LLV+20] Simin Liu, Tianrui Liu, Ali Vakilian, Yulin Wan, and David P Woodruff, Learning the positions in countsketch, arXiv preprint arXiv:2007.09890 (2020).
  • [LT13] Michel Ledoux and Michel Talagrand, Probability in banach spaces: Isoperimetry and processes, Springer Science & Business Media, 2013.
  • [Mah11] Michael W Mahoney, Randomized algorithms for matrices and data, arXiv preprint arXiv:1104.5557 (2011).
  • [MM15] Cameron Musco and Christopher Musco, Randomized block krylov methods for stronger and faster approximate singular value decomposition, NeurIPS (2015).
  • [MT20] Per-Gunnar Martinsson and Joel A Tropp, Randomized numerical linear algebra: Foundations and algorithms, Acta Numerica 29 (2020), 403–572.
  • [MV20] Michael Mitzenmacher and Sergei Vassilvitskii, Algorithms with predictions, Beyond Worst Case Analysis of Algorithms (Tim Roughgarden, ed.), Cambridge University Press, 2020.
  • [RST10] Vladimir Rokhlin, Arthur Szlam, and Mark Tygert, A randomized algorithm for principal component analysis, SIAM Journal on Matrix Analysis and Applications 31 (2010), no. 3, 1100–1124.
  • [Sar06] Tamas Sarlos, Improved approximation algorithms for large matrices via random projections, 2006 47th Annual IEEE Symposium on Foundations of Computer Science (FOCS’06), IEEE, 2006, pp. 143–152.
  • [Ste15] Ingo Steinwart, Measuring the capacity of sets of functions in the analysis of erm, Measures of Complexity, Springer, 2015, pp. 217–233.
  • [TYUC19] Joel A Tropp, Alp Yurtsever, Madeleine Udell, and Volkan Cevher, Streaming low-rank matrix approximation with an application to scientific simulation, SIAM Journal on Scientific Computing 41 (2019), no. 4, A2430–A2463.
  • [War68] Hugh E Warren, Lower bounds for approximation by nonlinear manifolds, Transactions of the American Mathematical Society 133 (1968), no. 1, 167–178.
  • [WC15] Rafi Witten and Emmanuel Candes, Randomized algorithms for low-rank matrix factorizations: sharp performance bounds, Algorithmica 72 (2015), no. 1, 264–281.
  • [Woo14] David P Woodruff, Sketching as a tool for numerical linear algebra, Theoretical Computer Science 10 (2014), no. 1-2, 1–157.

Appendix A Proof of the Goldberg-Jerrum Framework (Theorem 3.1)

In this section we prove Theorem 3.1, the main theorem of the GJ framework presented in Section 3. The proof is taken from [GJ95] with slight modifications, and we include it here for completeness.

We start with an intermediate result about polynomial formulas.

Definition.

A polynomial formula f:\mathbb{R}^{n}\rightarrow\{\text{True,False}\} is a DNF formula over boolean predicates of the form P(x_{1},\ldots,x_{n})\geq 0, where P is a polynomial in the real-valued inputs to f.

Theorem A.1.

Using the notation of Section 2.1, suppose that each algorithm L\in\mathcal{L} is specified by n real parameters. Suppose that for every x\in\mathcal{X} and r\in\mathbb{R} there is a polynomial formula f_{x,r} over n variables that, given L\in\mathcal{L}, checks whether L(x)>r. Suppose further that f_{x,r} has p distinct polynomials in its predicates, and each of them has degree at most \Delta. Then, the pseudo-dimension of \mathcal{L} is O(n\log(\Delta p)).

Proof.

Let x_{1},\ldots,x_{N} be a shattered set of instances, and r_{1},\ldots,r_{N} the corresponding loss thresholds. We need to show N=O(n\log(\Delta p)).

For every i=1,\ldots,N, let us denote f_{i}=f_{x_{i},r_{i}} for brevity. Then, for a given L\in\mathcal{L}, f_{i} checks whether L(x_{i})>r_{i}. Thus, x_{1},\ldots,x_{N} being a shattered set means that the family of N-dimensional boolean strings \{f_{1}(L),\ldots,f_{N}(L):L\in\mathcal{L}\} has size 2^{N}.

The truth value of each f_{i} is determined by the signs of p different polynomials in n variables, each of degree at most \Delta. Therefore, a boolean string f_{1}(L),\ldots,f_{N}(L) is determined by the signs of at most pN different polynomials in n variables, each of degree at most \Delta. We now use the following classical theorem.

Theorem A.2 ([War68]).

Suppose N\geq n. Then N polynomials in n variables, each of degree at most \Delta, take at most O(N\Delta/n)^{n} different sign patterns.

In our case, if N\leq n then the conclusion N=O(n\log(\Delta p)) is trivial, thus we may assume N>n and hence pN\geq n. Now by Warren's theorem, 2^{N}=|\{f_{1}(L),\ldots,f_{N}(L):L\in\mathcal{L}\}|\leq O(pN\Delta/n)^{n}, and by solving for N we get N=O(n\log(\Delta p)) as desired. ∎
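For concreteness, the final step of solving for N can be carried out as follows; here c\geq 1 stands for the constant hidden in Warren's bound (making this constant explicit is our only additional assumption). Taking logarithms in 2^{N}\leq(cpN\Delta/n)^{n} gives

N\leq n\log_{2}(cp\Delta)+n\log_{2}\frac{N}{n}.

If N\geq 4n then N/n\geq 2\log_{2}(N/n), so n\log_{2}(N/n)\leq N/2 and the display yields N\leq 2n\log_{2}(cp\Delta); otherwise N<4n. In either case N=O(n\log(\Delta p)).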

Now we can prove Theorem 3.1. We can describe \Gamma_{x,r} by a computation tree, where arithmetic operations correspond to nodes with one child, conditional statements correspond to nodes with two children, and leaves correspond to output values. We can transform the tree into a polynomial DNF formula (as defined above) that checks whether L(x)>r, by ORing the ANDs of the conditional statements along each root-to-leaf computation path that ends with a “true” leaf.

Each branching node in the tree (i.e., a node with two children) corresponds to a conditional statement in \Gamma_{x,r}, which has the form of a rational predicate, meaning it determines whether R\geq 0 for some rational function R of the inputs of \Gamma_{x,r}. It can be easily checked that we can replace each such rational predicate with O(1) polynomial predicates (that determine whether P\geq 0 for some polynomial P in the inputs of \Gamma_{x,r}) of degree no larger than that of R. Since \Gamma_{x,r} has predicate complexity p, the obtained formula has at most O(p) distinct polynomials in its predicates. Since \Gamma_{x,r} has degree \Delta, each of these polynomials has degree at most \Delta. Theorem 3.1 now follows from Theorem A.1.
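For concreteness, one such replacement is the following (a sketch; along any valid computation path of \Gamma_{x,r} the denominators that appear are nonzero). Writing the rational predicate as P/Q\geq 0 with polynomials P and Q,

\frac{P}{Q}\geq 0\;\;\Longleftrightarrow\;\;\bigl(P\geq 0\,\wedge\,Q\geq 0\bigr)\,\vee\,\bigl(-P\geq 0\,\wedge\,-Q\geq 0\bigr),

which uses only the polynomial predicates P\geq 0, -P\geq 0, Q\geq 0, -Q\geq 0, each of degree at most that of R.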

Appendix B Omitted Proofs from Sections 4 and 5

B.1 Proof of Section 4

If w^{T}A=0 then SCW returns a zero matrix, whose loss is \lVert A\rVert_{F}^{2}, and the statement holds. Now assume w^{T}A\neq 0. We go over the steps of SCW (Algorithm 1):

  • Compute the row vector w^{T}A.

  • Compute the SVD of w^{T}A. Note that for any nonzero row vector z^{T}, its SVD U\Sigma V^{T} is given by U being a 1\times 1 matrix whose only entry is 1, \Sigma being a 1\times 1 matrix whose only entry is \lVert z\rVert, and V^{T} being the row vector \frac{1}{\lVert z\rVert}z^{T}. So for z^{T}=w^{T}A we have V^{T}=\frac{1}{\lVert A^{T}w\rVert}w^{T}A.

  • Compute [AV]_{1}, the best rank-1 approximation of AV. But in our case AV equals \frac{1}{\lVert A^{T}w\rVert}AA^{T}w, which is already rank 1, so its best rank-1 approximation is itself, [AV]_{1}=\frac{1}{\lVert A^{T}w\rVert}AA^{T}w.

  • Return [AV]_{1}V^{T}, which in our case equals \frac{1}{\lVert A^{T}w\rVert^{2}}AA^{T}ww^{T}A.

So, the output rank-1 matrix of SCW is \frac{1}{\lVert A^{T}w\rVert^{2}}AA^{T}ww^{T}A, and its loss is as stated.
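The closed form above is easy to confirm numerically. The following is a minimal NumPy sketch (the random test instance and variable names are ours, not from the paper) that runs the steps of SCW with a 1\times n sketch w^{T} and compares the result to the formula.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 6, 5
A = rng.standard_normal((n, d))
w = rng.standard_normal(n)

# SCW with the 1 x n sketch w^T and target rank 1:
# sketch, SVD of the sketch, best rank-1 approximation of AV, lift back by V^T.
z = w @ A                                    # row vector w^T A
V = (z / np.linalg.norm(z)).reshape(-1, 1)   # right singular vector of w^T A
AV = A @ V                                   # already rank 1, so [AV]_1 = AV
scw_output = AV @ V.T

# Closed form derived above: (1 / ||A^T w||^2) * A A^T w w^T A.
closed_form = np.outer(A @ (A.T @ w), w @ A) / np.linalg.norm(A.T @ w) ** 2

print(np.allclose(scw_output, closed_form))  # expected: True
```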

B.2 Proof of Section 5

Write the SVD of A(SA)^{\dagger}(SA) as U\Sigma V^{T}. Note that (SA)^{\dagger}(SA) is the orthogonal projection on the row-space of SA, in which every vector is a linear combination of the rows of A. Therefore, projecting the rows of A onto the row-space of SA spans all of the row-space of SA, or in other words, the row-spans of A(SA)^{\dagger}(SA) and of SA are the same. Consequently, the rows of V^{T} form an orthonormal basis for the row-space of SA. This means in particular that VV^{T} is the orthogonal projection on the row-space of SA, thus (SA)^{\dagger}(SA)=VV^{T}, so we can write the matrix from the lemma statement as [A(SA)^{\dagger}(SA)]_{k}=[AVV^{T}]_{k}.

We recall that SCW returns the best rank-k approximation of A in the row-space of SA (see [Woo14]), meaning it returns ZV^{T} where Z is a rank-k matrix that minimizes \lVert A-ZV^{T}\rVert_{F}^{2}. So, we need to show that

\forall\;Z\;\;\text{of rank $k$,}\;\;\;\;\lVert A-[AVV^{T}]_{k}\rVert_{F}^{2}\leq\lVert A-ZV^{T}\rVert_{F}^{2}. (6)

We use the following observation.

Claim.

[AVV^{T}]_{k}VV^{T}=[AVV^{T}]_{k}.

Proof.

Recall that A(SA)^{\dagger}(SA)=AVV^{T}, thus the SVD of AVV^{T} is U\Sigma V^{T}, thus [AVV^{T}]_{k}=U_{k}\Sigma_{k}V_{k}^{T}. Since V_{k}^{T}VV^{T}=V_{k}^{T} (projecting a subset of k rows of V^{T} onto the row-space of V^{T} does not change anything), we have [AVV^{T}]_{k}VV^{T}=U_{k}\Sigma_{k}V_{k}^{T}VV^{T}=U_{k}\Sigma_{k}V_{k}^{T}=[AVV^{T}]_{k}. ∎

Proceeding with the proof of Section 5, we now have for every Z of rank k,

\lVert A-[AVV^{T}]_{k}\rVert_{F}^{2}=\lVert AVV^{T}-[AVV^{T}]_{k}VV^{T}\rVert_{F}^{2}+\lVert A(I-VV^{T})-[AVV^{T}]_{k}(I-VV^{T})\rVert_{F}^{2}
=\lVert AVV^{T}-[AVV^{T}]_{k}\rVert_{F}^{2}+\lVert A-AVV^{T}\rVert_{F}^{2}
\leq\lVert AVV^{T}-ZV^{T}\rVert_{F}^{2}+\lVert A-AVV^{T}\rVert_{F}^{2}
=\lVert A-ZV^{T}\rVert_{F}^{2},

where the first and last equalities are the Pythagorean identity (orthogonally projecting onto and against VV^{T}), the second equality is by Section B.2, and the inequality is by the optimality of [AVV^{T}]_{k} as a rank-k approximation of AVV^{T} (since Z has rank k). This proves Equation 6, proving the lemma.
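As a numerical sanity check of the lemma, the identity can be verified on random instances. The sketch below is ours; it assumes SCW is implemented, as in [Woo14], as the best rank-k approximation of A inside the row space of SA, computed as [AV]_{k}V^{T}.

```python
import numpy as np

def truncate(M, k):
    """Best rank-k approximation of M in Frobenius norm, via the SVD."""
    U, s, Vt = np.linalg.svd(M, full_matrices=False)
    return (U[:, :k] * s[:k]) @ Vt[:k]

rng = np.random.default_rng(1)
n, d, m, k = 8, 7, 3, 2
A = rng.standard_normal((n, d))
S = rng.standard_normal((m, n))

# SCW: best rank-k approximation of A within the row space of SA.
_, _, Vt = np.linalg.svd(S @ A, full_matrices=False)  # rows of Vt span the row space of SA
scw_output = truncate(A @ Vt.T, k) @ Vt               # [AV]_k V^T

# The matrix from the lemma statement: [A (SA)^dagger (SA)]_k.
lemma_matrix = truncate(A @ np.linalg.pinv(S @ A) @ (S @ A), k)

print(np.allclose(scw_output, lemma_matrix))          # expected: True
```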

B.3 Proof of Theorem 5.1

In this section we explain how to read the statement of Theorem 5.1 from [MM15]. All section, theorem and page numbers refer to the arXiv version of their paper (https://arxiv.org/abs/1504.05477v4).

The relevant result in [MM15] is the Frobenius-norm analysis of their Algorithm 1 (which they call “Simultaneous Iteration”), stated in their Theorem 11. As they state in their Section 5.1, for the purpose of low-rank approximation in the Frobenius norm, it suffices to return \mathbf{Q} instead of \mathbf{Z} (in the notation of their Algorithm 1). Thus, steps 4–6 of Algorithm 1 can be skipped. Since (again in the notation of their Algorithm 1) \mathbf{Q} is the result of orthonormalizing the columns of \mathbf{K}, we clearly have \mathbf{K}\mathbf{K}^{\dagger}=\mathbf{Q}\mathbf{Q}^{\dagger} (this is because \mathbf{M}\mathbf{M}^{\dagger} is the projection matrix on the column space of any matrix \mathbf{M}, and \mathbf{K} and \mathbf{Q} have the same column spaces). Since for our purpose we only need the projection matrix \mathbf{Q}\mathbf{Q}^{\dagger} (rather than the orthonormal basis \mathbf{Q}), we can also skip step 3, and simply use the matrix \mathbf{K} as the output of their Algorithm 1, while maintaining the guarantee of their Theorem 11 (with \mathbf{Z}\mathbf{Z}^{T} replaced by \mathbf{K}\mathbf{K}^{\dagger}).

Next, while their Algorithm 1 chooses the initial matrix \mathbf{\Pi} as a random Gaussian matrix, they state in their Section 5.1 that their analysis (and in particular, their Theorem 11) holds for every initial matrix \mathbf{\Pi} that satisfies the guarantee of their Lemma 4, which is equivalent to satisfying our Equation 1.

Finally, note that they set the number of iterations in step 1 of their Algorithm 1 to q=O(\epsilon^{-1}\log(d)), while we set it in Theorem 5.1 to q=O(\epsilon^{-1}\log(d/\epsilon)). This is because in their setting they may assume w.l.o.g. that \epsilon^{-1}=\mathrm{poly}(d) (see bottom of their page 10), while in our setting this is not the case.

Altogether, the statement of Theorem 5.1 follows.

B.4 Proof of Section 5.2

For completeness, we discuss two ways to prove the lemma.

Proof 1.

We start by noting that a significantly stronger version of Section 5.2 follows from Theorem 1.3 of [DRVW06]. Slightly rephrasing (and transposing) their theorem, it states that for every matrix B\in\mathbb{R}^{n\times d}, there is a distribution (that depends on B) over matrices P\in\mathbb{R}^{d\times k} whose columns are standard basis vectors, such that if we let \tilde{B}_{k} be the projection of the columns of B onto the column-space of BP, namely \tilde{B}_{k}=(BP)(BP)^{\dagger}B, then we have

\mathbb{E}_{P}\lVert B-\tilde{B}_{k}\rVert_{F}^{2}\leq(k+1)\cdot\lVert B-[B]_{k}\rVert_{F}^{2}.

In particular, there exists a P in the support of this distribution that satisfies this equation without the expectation, yielding Section 5.2.

Remark. Note that this is in fact a quantitatively stronger version of Section 5.2, since the O(kd) term in Equation 1 is replaced here by (k+1), which [DRVW06] furthermore show is the best possible. However, this improvement does not strengthen the final bounds we obtain in our theorems. This is because the analysis of the power method (Theorem 5.1) requires the \log d term in the number of iterations q even if the initial approximation (Equation 1) is up to a factor of (k+1) instead of O(kd). Technically, this stems from the analysis in [MM15]: in the equation immediately after their eq. (4) on page 11, the \log d term is needed not just to eliminate d from the numerator, but also to gain a d^{O(1)} term in the denominator.

Proof 2.

Since the aforementioned result of [DRVW06] is difficult and stronger than we require, we now also give a more basic proof of Section 5.2, for completeness.

Let B\in\mathbb{R}^{n\times d} be written in SVD form as B=U\Sigma V^{T}. Let V_{k}^{T} be the matrix with the top k rows of V^{T}, and V^{T}_{-k} the matrix with the remaining rows. We use the following fact from [Woo14], which originates in [BDMI14] and is also used in [MM15] (see their Lemma 14).

Lemma.

Let P\in\mathbb{R}^{d\times k} be any matrix such that the k\times k matrix V_{k}^{T}P has rank k. Then,

\lVert B-(BP)(BP)^{\dagger}B\rVert_{F}^{2}\leq\lVert B-[B]_{k}\rVert_{F}^{2}\cdot\left(1+\lVert V^{T}_{-k}P\rVert_{2}^{2}\cdot\lVert(V_{k}^{T}P)^{\dagger}\rVert_{2}^{2}\right).

We will construct a matrix P\in\mathbb{R}^{d\times k} whose columns are distinct standard basis vectors of \mathbb{R}^{d}. Since it would thus have orthonormal columns, as does V_{-k}, we would have \lVert V_{-k}^{T}P\rVert_{2}^{2}\leq\lVert V_{-k}^{T}\rVert_{2}^{2}\cdot\lVert P\rVert_{2}^{2}=1. Therefore, using Section B.4, for P to satisfy Equation 1 and thus prove Section 5.2, it suffices to construct it such that V_{k}^{T}P has rank k and satisfies \lVert(V_{k}^{T}P)^{\dagger}\rVert_{2}^{2}\leq d. This is equivalent to showing that the smallest of the k singular values of V_{k}^{T}P is at least 1/\sqrt{d}, which is what we do in the rest of the current proof.

We construct P by the following process (a short NumPy sketch of the process is given after the list). Initialize Z_{1}\in\mathbb{R}^{k\times d} as Z_{1}\leftarrow V_{k}^{T}, and a zero matrix P\in\mathbb{R}^{d\times k}. For i=1,\ldots,k:

  1. Let z^{i}_{1},\ldots,z^{i}_{d} denote the columns of Z_{i}.

  2. Let j_{i}\leftarrow\mathrm{argmax}_{j\in[d]}\lVert z^{i}_{j}\rVert_{2}^{2}.

  3. Set column i of P to be e_{j_{i}}.

  4. Let \Pi_{i} be the orthogonal projection matrix against \mathrm{span}(v_{j_{1}},\ldots,v_{j_{i}}).

  5. Z_{i+1}\leftarrow\Pi_{i}V_{k}^{T}.
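The process above is a column-pivoted greedy selection applied to V_{k}^{T}. The following NumPy sketch (our illustration; the function name and test instance are not from the paper) runs it on a random orthonormal V_{k}^{T} and reports the smallest singular value of V_{k}^{T}P next to the 1/\sqrt{d} bound asserted by Section 5.2.

```python
import numpy as np

def greedy_pivot_columns(Vk_t):
    """The process above: pick the largest-norm column of Z_i, then project all
    columns of V_k^T against the span of the columns of V_k^T picked so far."""
    k, d = Vk_t.shape
    Z = Vk_t.copy()
    picked = []
    basis = np.zeros((k, 0))                        # orthonormal basis of span(v_{j_1},...,v_{j_i})
    for _ in range(k):
        j = int(np.argmax(np.sum(Z ** 2, axis=0)))  # j_i: column of maximal squared norm
        picked.append(j)
        q = Vk_t[:, j:j + 1] - basis @ (basis.T @ Vk_t[:, j:j + 1])
        basis = np.hstack([basis, q / np.linalg.norm(q)])
        Z = Vk_t - basis @ (basis.T @ Vk_t)         # Z_{i+1} = Pi_i V_k^T
    P = np.zeros((d, k))
    P[picked, np.arange(k)] = 1.0                   # columns of P are e_{j_1},...,e_{j_k}
    return P

rng = np.random.default_rng(2)
n, d, k = 10, 8, 3
_, _, Vt = np.linalg.svd(rng.standard_normal((n, d)), full_matrices=False)
Vk_t = Vt[:k]                                       # top k right singular vectors (orthonormal rows)
P = greedy_pivot_columns(Vk_t)
sigma_min = np.linalg.svd(Vk_t @ P, compute_uv=False).min()
print(sigma_min, 1 / np.sqrt(d))                    # Section 5.2 asserts sigma_min >= 1/sqrt(d)
```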

Let v_{1},\ldots,v_{d} denote the columns of V_{k}^{T}. Observe that the columns of V_{k}^{T}P are v_{j_{1}},\ldots,v_{j_{k}}.

Claim.

For every i=1,\ldots,k we have \lVert z^{i}_{j_{i}}\rVert_{2}^{2}\geq\frac{k-i+1}{d}.

Proof.

Let i\in[k]. Since V_{k}^{T} has orthonormal rows, each of its k singular values equals 1, and any best rank-(i-1) approximation of it, [V_{k}^{T}]_{i-1}, satisfies \lVert V_{k}^{T}-[V_{k}^{T}]_{i-1}\rVert_{F}^{2}=k-(i-1). Since \Pi_{i-1} is a projection against i-1 directions, (I-\Pi_{i-1})V_{k}^{T} can be viewed as a rank-(i-1) approximation of V_{k}^{T}, and therefore

\lVert Z_{i}\rVert_{F}^{2}=\lVert\Pi_{i-1}V_{k}^{T}\rVert_{F}^{2}=\lVert V_{k}^{T}-(I-\Pi_{i-1})V_{k}^{T}\rVert_{F}^{2}\geq\lVert V_{k}^{T}-[V_{k}^{T}]_{i-1}\rVert_{F}^{2}=k-i+1.

Hence the average squared-\ell_{2} mass of columns in Z_{i} is at least \frac{k-i+1}{d}, and hence the column with maximal \ell_{2}-norm (which we recall is defined to be z_{j_{i}}^{i}) has squared \ell_{2}-norm at least \frac{k-i+1}{d}. ∎

Claim.

For every i=1,\ldots,k, z^{i}_{j_{i}} is spanned by v_{j_{1}},\ldots,v_{j_{i}}.

Proof.

Observe that z^{i}_{j_{i}}=\Pi_{i-1}v_{j_{i}} (with the convention that \Pi_{0} is the identity). Therefore, z^{i}_{j_{i}}=v_{j_{i}}-(I-\Pi_{i-1})v_{j_{i}}, and the claim follows by noting that I-\Pi_{i-1} is the orthogonal projection onto the subspace spanned by v_{j_{1}},\ldots,v_{j_{i-1}}. ∎

Claim.

z^{1}_{j_{1}},\ldots,z^{k}_{j_{k}} are pairwise orthogonal.

Proof.

Let i<i^{\prime}. By Section B.4, z^{i}_{j_{i}} is spanned by v_{j_{1}},\ldots,v_{j_{i}}. At the same time, z^{i^{\prime}}_{j_{i^{\prime}}} is a column of the matrix Z_{i^{\prime}}, whose columns have been orthogonally projected against v_{j_{1}},\ldots,v_{j_{i^{\prime}-1}}, a set that contains v_{j_{1}},\ldots,v_{j_{i}}. Thus z^{i^{\prime}}_{j_{i^{\prime}}} is orthogonal to z^{i}_{j_{i}}. ∎

Claim.

Let i\in[k]. Then v_{j_{i}} can be written uniquely as a linear combination of z_{j_{1}}^{1},\ldots,z_{j_{i}}^{i}, such that the coefficient of z_{j_{i}}^{i} is 1.

Proof.

We have shown in Section B.4 that z^{1}_{j_{1}},\ldots,z^{k}_{j_{k}} are orthogonal and in Section B.4 that each is non-zero, so they form a basis of \mathbb{R}^{k}. Thus v_{j_{i}} is written uniquely as a linear combination of them. As noted in the proof of Section B.4, we have z^{i}_{j_{i}}=\Pi_{i-1}v_{j_{i}}=v_{j_{i}}-(I-\Pi_{i-1})v_{j_{i}}, or equivalently v_{j_{i}}=z^{i}_{j_{i}}+(I-\Pi_{i-1})v_{j_{i}}.

We recall that I-\Pi_{i-1} is the orthogonal projection onto the subspace W=\mathrm{span}(v_{j_{1}},\ldots,v_{j_{i-1}}), whose dimension is at most i-1. By Section B.4, W contains z_{j_{1}},\ldots,z_{j_{i-1}}. Since z_{j_{1}},\ldots,z_{j_{k}} form a basis of \mathbb{R}^{k}, we now get that W is spanned by z_{j_{1}},\ldots,z_{j_{i-1}}, and cannot contain any z_{j_{i^{\prime}}} with i^{\prime}\geq i. In particular, (I-\Pi_{i-1})v_{j_{i}} is written uniquely as a linear combination of z_{j_{1}},\ldots,z_{j_{i-1}}. Recalling that v_{j_{i}}=z^{i}_{j_{i}}+(I-\Pi_{i-1})v_{j_{i}}, the claim follows. ∎

We can finally complete the proof of Section 5.2. For every i\in[k], let q_{i} denote the unit-length vector in the direction of z^{i}_{j_{i}}, i.e., q_{i}=\frac{1}{\lVert z^{i}_{j_{i}}\rVert}z^{i}_{j_{i}}. Let Q\in\mathbb{R}^{k\times k} be the matrix whose columns are q_{1},\ldots,q_{k}. Section B.4 means that we can write V_{k}^{T}P (since its columns, we recall, are v_{j_{1}},\ldots,v_{j_{k}}) as V_{k}^{T}P=QR, where R is an upper triangular matrix whose diagonal entries are \lVert z^{i}_{j_{i}}\rVert for i=1,\ldots,k. Since Q has orthonormal columns (from Section B.4 and the fact that we normalized its columns), this is a QR-decomposition of V_{k}^{T}P. Hence, V_{k}^{T}P and R have the same singular values. The diagonal entries of R are its eigenvalues, which are also its singular values as all are non-negative. By Section B.4, each diagonal entry of R is at least 1/\sqrt{d}. This implies the same lower bound on the smallest singular value of R and thus of V_{k}^{T}P, which, as explained earlier, implies Section 5.2.

Appendix C Proof of the Lower Bound

In this section we prove the lower bound in Theorem 2.1, restated next.

Theorem C.1.

For every s\leq k\leq m and \epsilon\in(0,\tfrac{1}{2\sqrt{k}}), the \epsilon-fat shattering dimension of IVY with target low rank k, sketching dimension m and sparsity s is \Omega(ns).

Proof.

We may assume w.l.o.g. that m=k, since we can always augment a sketching matrix with k rows with m-k additional zero rows, without changing the result of SCW.

We start with the special case s=k, where the sketching matrix S is allowed to be fully dense, and the desired lower bound is \Omega(nk). Let A_{0} be the n\times k matrix whose i-th column is e_{i} (the standard basis vector) for all i=1,\ldots,k. For every i\in\{1,\ldots,k\} and t\in\{k+1,\ldots,n\}, let A_{(i,t)} be given by replacing the i-th column of A_{0} with e_{t}. Observe that each A_{(i,t)} has rank k, and each of its singular values equals 1. We will show that the set of matrices Z=\{A_{(i,t)}:i=1,\ldots,k,\;t=k+1,\ldots,n\} is shattered, which would imply the lower bound since its size is k(n-k)=\Omega(nk).

Let Z^{\prime}\subseteq Z. It suffices to exhibit a sketching matrix S\in\mathbb{R}^{k\times n} (with unbounded sparsity, since s=k) such that for every A_{(i,t)}, the SCW loss of using S for a rank-k approximation of A_{(i,t)} is 0 if A_{(i,t)}\in Z^{\prime}, and at least 1 otherwise. We set the first k columns of S to be the order-k identity matrix. Then, for every t=k+1,\ldots,n, let J_{t}=\{i\in\{1,\ldots,k\}:A_{(i,t)}\in Z^{\prime}\}. In the t-th column of S, we put 1's in rows J_{t} and 0's in the remaining rows.

Fix A_{(i,t)}. Recall that SCW returns the best rank-k approximation of A_{(i,t)} in the row-space of SA_{(i,t)}, which is a k\times k matrix.

  • Suppose A_{(i,t)}\notin Z^{\prime}. In this case, we argue that the i-th row of SA_{(i,t)} is zero. This implies that the row-space of SA_{(i,t)} has dimension at most k-1, so SCW incurs loss at least 1.

    Indeed, row i of SA_{(i,t)} equals the sum of rows j of A_{(i,t)} for which S(i,j)=1. The only non-zero rows in A_{(i,t)}, by construction, are rows \{1,\ldots,k\}\setminus\{i\} and row t. Since the first k columns of S are the identity, S(i,j)=0 for every j\in\{1,\ldots,k\}\setminus\{i\}. Since i\notin J_{t}, by construction of S we have S(i,t)=0. Overall, the i-th row of SA_{(i,t)} is zero.

  • Suppose A_{(i,t)}\in Z^{\prime}. In this case, we argue that the row-space of SA_{(i,t)} has dimension k. Since the ambient row-dimension is also k, this means SCW returns a perfect rank-k approximation, i.e., zero loss.

    Indeed, again, the only non-zero rows of A_{(i,t)} are rows \{1,\ldots,k\}\setminus\{i\} and row t. But in this case i\in J_{t}, thus S(i,t)=1 and thus row i of SA_{(i,t)} equals e_{i}. For every j\in\{1,\ldots,k\}\setminus\{i\}, row j of SA_{(i,t)} equals either e_{j}+e_{i} (if j\in J_{t}) or just e_{j} (if j\notin J_{t}). It is thus clear that the rows of SA_{(i,t)} span all of \mathbb{R}^{k}.

In conclusion, the SCW loss is 0 on matrices in Z^{\prime} and at least 1 on matrices in Z\setminus Z^{\prime}.
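The dense-case construction is small enough to verify directly. The following NumPy sketch (ours; the specific n, k and the subset Z^{\prime} are arbitrary choices) builds S from a subset Z^{\prime} as above and checks that the SCW loss is zero exactly on the matrices of Z^{\prime}; the matrices here are unnormalized, so the loss outside Z^{\prime} is at least 1.

```python
import numpy as np

def scw_loss(S, A, k):
    """||A - (best rank-k approximation of A in the row space of SA)||_F^2."""
    SA = S @ A
    r = np.linalg.matrix_rank(SA)
    _, _, Vt = np.linalg.svd(SA, full_matrices=False)
    Vt = Vt[:r]                                  # orthonormal basis of the row space of SA
    U, s, Wt = np.linalg.svd(A @ Vt.T, full_matrices=False)
    approx = (U[:, :k] * s[:k]) @ Wt[:k] @ Vt
    return np.linalg.norm(A - approx, "fro") ** 2

n, k = 7, 3                                      # 0-based indices below

def A_mat(i, t):
    """A_{(i,t)}: columns e_1,...,e_k with column i replaced by e_t."""
    A = np.zeros((n, k))
    A[np.arange(k), np.arange(k)] = 1.0
    A[i, i], A[t, i] = 0.0, 1.0
    return A

Zprime = {(0, 3), (1, 4)}                        # the subset Z' to realize
S = np.zeros((k, n))
S[np.arange(k), np.arange(k)] = 1.0              # first k columns of S: the identity
for (i, t) in Zprime:
    S[i, t] = 1.0                                # 1's in rows J_t of column t

for i in range(k):
    for t in range(k, n):
        loss = scw_loss(S, A_mat(i, t), k)
        assert (loss < 1e-9) == ((i, t) in Zprime)
print("shattering pattern of Z' realized")
```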

Next we extend this to an \Omega(ns) lower bound for any s=1,\ldots,k. We assume for simplicity that s is an integer divisor of k and that k is an integer divisor of ns.

We partition an input matrix A\in\mathbb{R}^{n\times k} into k/s diagonal blocks of order (ns/k)\times s each (everything outside the diagonal blocks is zero). To construct the shattered set, we first set each diagonal block in A to A_{0}\in\mathbb{R}^{(ns/k)\times s} (as defined above), then choose one “critical” block b\in\{1,\ldots,k/s\} and replace its A_{0} with A_{(i,t)} as defined above, with i\in\{1,\ldots,s\} and t\in\{s+1,\ldots,ns/k\}. Denote the resulting matrix by A_{(b,i,t)}. The total size of the shattered set \{A_{(b,i,t)}\} is \tfrac{k}{s}\cdot s\cdot(\tfrac{ns}{k}-s)=(n-k)s=\Omega(ns). The corresponding sketching matrix, arising from the dense construction above, has block-diagonal structure with blocks of order s\times(ns/k) each, and in particular has at most s nonzeros per column, as needed.

The correctness of the construction follows immediately from the dense case. In more detail, let Z^{\prime\prime} be a subset of the shattered set \{A_{(b,i,t)}\}. Fix A_{(b,i,t)} (not necessarily in Z^{\prime\prime}). The k\times k matrix SA_{(b,i,t)} has block-diagonal structure with blocks of order s\times s. Each non-critical block equals the order-s identity, while the critical block, by the proof of the dense case above, has full rank s if A_{(b,i,t)}\in Z^{\prime\prime} and has rank at most s-1 otherwise. Therefore, SCW with sketching matrix S finds a zero-loss rank-k approximation of every A_{(b,i,t)}\in Z^{\prime\prime}, and incurs loss at least 1 on every A_{(b,i,t)}\notin Z^{\prime\prime}.

Finally, recall that we need to normalize our matrices to have squared Frobenius norm 1. In the above construction it is instead k, so we need to divide each entry by \sqrt{k}. This means that the loss is zero for the matrices in Z^{\prime\prime}, and is at least 1/\sqrt{k} for the shattered-set matrices outside Z^{\prime\prime}. Thus, the set is \epsilon-fat shattered for every \epsilon\in(0,\tfrac{1}{2\sqrt{k}}). ∎

Appendix D Proof of Theorem 2.2

D.1 Uniform Convergence and ERM Upper Bound

By classical results in learning theory, the fat shattering dimension can be used to obtain an upper bound on the number of samples needed for uniform convergence and for ERM learning (see for example Theorem 19.1 in [AB09]), and we could simply plug Theorem 2.1 into these results. However, we can get a sharper bound by exploiting the proxy loss from Section 5.3, since for the latter we have an upper bound on the pseudo-dimension, which yields sharper bounds for ERM learning than the fat shattering dimension.

Let \ell_{\mathrm{IVY}}=\ell_{\mathrm{IVY}}(\epsilon,\delta) be the number of samples required for (\epsilon,\delta)-uniform convergence for the family of IVY losses, \mathcal{L}_{\mathrm{IVY}}=\{L_{k}^{\mathrm{SCW}}(S,\cdot):\mathcal{A}\rightarrow[0,1]\}_{S\in\mathcal{S}}. Let \hat{\ell}=\hat{\ell}(\epsilon,\delta) be the number of samples required for (\epsilon,\delta)-uniform convergence for the family of proxy losses, \hat{\mathcal{L}}_{\epsilon}=\{\hat{L}_{k,\epsilon}(S,\cdot):\mathcal{A}\rightarrow[0,1]\}_{S\in\mathcal{S}}. Fix \hat{\ell} matrices A_{1},\ldots,A_{\hat{\ell}}\in\mathcal{A}. Section 5.3 implies

\forall\;S\in\mathcal{S},\;\;\;\;\left|\frac{1}{\hat{\ell}}\sum_{i=1}^{\hat{\ell}}\hat{L}_{k,\epsilon}(S,A_{i})-\frac{1}{\hat{\ell}}\sum_{i=1}^{\hat{\ell}}L_{k}^{\mathrm{SCW}}(S,A_{i})\right|\leq\epsilon

and

\forall\;S\in\mathcal{S},\;\;\;\;\left|\mathbb{E}_{A\sim\mathcal{D}}[\hat{L}_{k,\epsilon}(S,A)]-\mathbb{E}_{A\sim\mathcal{D}}[L_{k}^{\mathrm{SCW}}(S,A)]\right|\leq\epsilon.

Let \mathcal{D} be a distribution over \mathcal{A}. By definition of \hat{\ell}, with probability at least 1-\delta over sampling A_{1},\ldots,A_{\hat{\ell}} independently from \mathcal{D}, we have

\forall\;S\in\mathcal{S},\;\;\;\;\left|\frac{1}{\hat{\ell}}\sum_{i=1}^{\hat{\ell}}\hat{L}_{k,\epsilon}(S,A_{i})-\mathbb{E}_{A\sim\mathcal{D}}[\hat{L}_{k,\epsilon}(S,A)]\right|\leq\epsilon,

which combined with the above, implies

\forall\;S\in\mathcal{S},\;\;\;\;\left|\frac{1}{\hat{\ell}}\sum_{i=1}^{\hat{\ell}}L_{k}^{\mathrm{SCW}}(S,A_{i})-\mathbb{E}_{A\sim\mathcal{D}}[L_{k}^{\mathrm{SCW}}(S,A)]\right|\leq 3\epsilon.

Consequently, \ell_{\mathrm{IVY}}(3\epsilon,\delta)\leq\hat{\ell}(\epsilon,\delta). By classical results in learning theory (see, e.g., Theorem 3.2 in [GR17]), the sample complexity can be upper bounded using the pseudo-dimension, and in particular, \hat{\ell}(\epsilon,\delta)=O(\epsilon^{-2}\cdot(\mathrm{pdim}(\hat{\mathcal{L}}_{\epsilon})+\log(1/\delta))). By Equation 4 in Section 5.3, \mathrm{pdim}(\hat{\mathcal{L}}_{\epsilon})=O(ns\cdot(m+k\log(d/k)+\log(1/\epsilon))). Putting everything together,

\ell_{\mathrm{IVY}}(3\epsilon,\delta)=O(\epsilon^{-2}\cdot(ns\cdot(m+k\log(d/k)+\log(1/\epsilon))+\log(1/\delta))).

As a consequence, that many samples suffice for (6\epsilon,\delta)-learning IVY with ERM. The upper bound in Theorem 2.2 follows by scaling \epsilon down by a constant.

D.2 Uniform Convergence Lower Bound

We proceed to the lower bound on the number of samples needed for IVY to admit (\epsilon,\epsilon)-uniform convergence. It relies on some known results about the connection between uniform convergence and the fat shattering dimension, which we now detail.

We introduce some notation. Let \mathcal{L} be a class of functions L:\mathcal{X}\rightarrow[0,1], and let \mathcal{D} be a distribution over \mathcal{X}. For every L\in\mathcal{L}, denote its expected loss over \mathcal{D} by

z(L):=\mathbb{E}_{x\sim\mathcal{D}}[L(x)].

Given \ell i.i.d. samples (x_{1},\ldots,x_{\ell})\sim\mathcal{D}^{\ell}, denote the empirical loss over the samples by

\hat{z}_{\ell}(L):=\frac{1}{\ell}\sum_{i=1}^{\ell}L(x_{i}).

Suppose \mathcal{L} admits (\epsilon,\epsilon)-uniform convergence with \ell samples. This can now be written as

\Pr_{\mathcal{D}^{\ell}}\left[\sup_{L\in\mathcal{L}}\left|\hat{z}_{\ell}(L)-z(L)\right|\leq\epsilon\right]\geq 1-\epsilon. (7)

Our goal for this section is to prove a lower bound on \ell. We begin the proof. Since the supremum above is always at most 1, Equation 7 implies

\mathbb{E}_{\mathcal{D}^{\ell}}\left[\sup_{L\in\mathcal{L}}\left|\hat{z}_{\ell}(L)-z(L)\right|\right]\leq 2\epsilon. (8)

This expectation can be bounded using Rademacher averages. Define the centered class \mathcal{L}_{c} of \mathcal{L} as

\mathcal{L}_{c}:=\{L-z(L):L\in\mathcal{L}\}.

For the given sample (x_{1},\ldots,x_{\ell}), define the Rademacher average of \mathcal{L}_{c} as

\mathrm{Rad}_{\ell}(\mathcal{L}_{c}):=\mathbb{E}_{\sigma_{1},\ldots,\sigma_{\ell}}\left[\sup_{L^{\prime}\in\mathcal{L}_{c}}\left|\frac{1}{\ell}\sum_{i=1}^{\ell}\sigma_{i}L^{\prime}(x_{i})\right|\right]=\mathbb{E}_{\sigma_{1},\ldots,\sigma_{\ell}}\left[\sup_{L\in\mathcal{L}}\left|\frac{1}{\ell}\sum_{i=1}^{\ell}\sigma_{i}\left(L(x_{i})-z(L)\right)\right|\right],

where each of \sigma_{1},\ldots,\sigma_{\ell} is independently uniform in \{1,-1\}. The desymmetrization inequality for Rademacher processes (see [Kol06]) states that

\mathbb{E}_{\mathcal{D}^{\ell}}\left[\sup_{L\in\mathcal{L}}\left|\hat{z}_{\ell}(L)-z(L)\right|\right]\geq\frac{1}{2}\mathbb{E}_{\mathcal{D}^{\ell}}\left[\mathrm{Rad}_{\ell}(\mathcal{L}_{c})\right]. (9)

A version of Sudakov’s minorization inequality for Rademacher processes (Equation (26) in [Ste15], based on Corollary 4.14 in [LT13]) states that

\mathrm{Rad}_{\ell}(\mathcal{L}_{c})\geq\frac{\alpha_{1}}{\sqrt{\ell}}\cdot\left(\ln\left(2+\frac{\alpha_{2}}{\Psi_{\ell}(\mathcal{L}_{c})}\right)\right)^{-1/2}\cdot\sup_{\gamma>0}\gamma\sqrt{\ln\mathcal{N}_{2}(\gamma,\mathcal{L},\ell)}, (10)

where \alpha_{1},\alpha_{2}>0 are absolute constants, \mathcal{N}_{2}(\gamma,\mathcal{L},\ell) is the covering number, and

\Psi_{\ell}(\mathcal{L}_{c}):=\sup_{L^{\prime}\in\mathcal{L}_{c}}\sqrt{\frac{1}{\ell}\sum_{i=1}^{\ell}(L^{\prime}(x_{i}))^{2}}.

We forgo the definition of the covering number, since we only need the standard fact that it can be lower-bounded in terms of the fat shattering dimension. It is encompassed in the following claim.

Claim.

Let \gamma\in[4\epsilon,1) and \delta\in(0,\tfrac{1}{100}). Suppose \mathcal{L} admits (\epsilon,\delta)-uniform convergence with \ell samples, and \mathrm{fatdim}_{16\gamma}(\mathcal{L})>1. Then \ln\mathcal{N}_{2}(\gamma,\mathcal{L},\ell)\geq\tfrac{1}{8}\cdot\mathrm{fatdim}_{16\gamma}(\mathcal{L}).

Proof.

Let \ell^{\prime} be the number of samples required for (\tfrac{1}{2}\gamma,\delta)-learning \mathcal{L} (see Section 2.1 for the definition of (\epsilon,\delta)-learning). Theorem 19.5 in [AB09] gives the lower bound \ell^{\prime}\geq\tfrac{1}{16\alpha}(\mathrm{fatdim}_{\gamma/(2\alpha)}(\mathcal{L})-1) for any 0<\alpha<1/4. Setting \alpha=1/32 and using \mathrm{fatdim}_{16\gamma}(\mathcal{L})>1 yields \ell^{\prime}\geq 2\cdot\mathrm{fatdim}_{16\gamma}(\mathcal{L})-2\geq\mathrm{fatdim}_{16\gamma}(\mathcal{L}). Since (\epsilon,\delta)-uniform convergence implies (2\epsilon,\delta)-learning and thus (\tfrac{1}{2}\gamma,\delta)-learning (as \gamma\geq 4\epsilon), we have \ell\geq\ell^{\prime}\geq\mathrm{fatdim}_{16\gamma}(\mathcal{L}). Now we can apply Lemma 10.5 and Theorem 12.10 from [AB09], which yield \ln\mathcal{N}_{2}(\gamma,\mathcal{L},\ell)\geq\ln\mathcal{N}_{1}(\gamma,\mathcal{L},\ell)\geq\tfrac{1}{8}\cdot\mathrm{fatdim}_{16\gamma}(\mathcal{L}). ∎

Putting together everything so far (i.e., combining Equations 8, 9, 10 and D.2), we get

2\epsilon\geq\frac{1}{2}\mathbb{E}_{\mathcal{D}^{\ell}}\left[\frac{\alpha_{1}}{8\sqrt{\ell}}\cdot\left(\ln\left(2+\frac{\alpha_{2}}{\Psi_{\ell}(\mathcal{L}_{c})}\right)\right)^{-1/2}\cdot\sup_{\gamma\in[4\epsilon,1)}\gamma\sqrt{\mathrm{fatdim}_{16\gamma}(\mathcal{L})}\right],

provided that \mathrm{fatdim}_{16\gamma}(\mathcal{L})>1 (as needed for Section D.2).

The derivations so far have been for any \mathcal{L} that admits (\epsilon,\epsilon)-uniform convergence with \ell samples. Now we begin specializing the arguments for IVY, that is, for \mathcal{L}=\mathcal{L}_{\mathrm{IVY}}. We choose \gamma=1/(64\sqrt{k}). Note that the assumption in Theorem 2.2 is that \epsilon\leq 1/(256\sqrt{k}), ensuring that \gamma\geq 4\epsilon. By the lower bound in Theorem 2.1, \mathrm{fatdim}_{16\gamma}(\mathcal{L})=\Omega(ns). Plugging these above, we get

\epsilon\geq\Omega(1)\cdot\mathbb{E}_{\mathcal{D}^{\ell}}\left[\sqrt{\frac{ns}{\ell k}\cdot\left(\ln\left(2+\frac{\alpha_{2}}{\Psi_{\ell}(\mathcal{L}_{c})}\right)\right)^{-1}}\right]. (11)

It remains to bound \Psi_{\ell}(\mathcal{L}_{c}) from below. We show it is lower-bounded by a constant even for very simple input distributions for IVY, supported on only two matrices (and indeed, for any distribution such that there is a loss function L\in\mathcal{L} with loss 0 on half the distribution mass and loss 1 on the other half). Recall that

\Psi_{\ell}(\mathcal{L}_{c}):=\sup_{L^{\prime}\in\mathcal{L}_{c}}\sqrt{\frac{1}{\ell}\sum_{i=1}^{\ell}(L^{\prime}(x_{i}))^{2}}=\sup_{L\in\mathcal{L}}\sqrt{\frac{1}{\ell}\sum_{i=1}^{\ell}(L(x_{i})-z(L))^{2}}.

Let e_{1},\ldots,e_{n}\in\mathbb{R}^{n} be the standard basis in \mathbb{R}^{n}. Let A_{0},A_{0}^{\prime},A_{1},A_{1}^{\prime}\in\mathbb{R}^{n\times d} be defined as follows: the first k columns of A_{0}^{\prime} are e_{1},\ldots,e_{k}, and the rest are zero; the first k columns of A_{1}^{\prime} are e_{k+1},\ldots,e_{2k}, and the rest are zero; A_{0}=\frac{1}{\sqrt{k}}A_{0}^{\prime}; and A_{1}=\frac{1}{\sqrt{k}}A_{1}^{\prime}. Note that \lVert A_{0}\rVert_{F}^{2}=\lVert A_{1}\rVert_{F}^{2}=1. Let \mathcal{D} be the uniform distribution over \{A_{0},A_{1}\}. Let L_{0}\in\mathcal{L} be the IVY loss function induced by the sketching matrix S_{0}\in\mathbb{R}^{k\times n} whose rows are e_{1}^{T},\ldots,e_{k}^{T}. (In the notation of the previous sections, L_{0}=L_{k}^{\mathrm{SCW}}(S_{0},\cdot)\in\mathcal{L}_{\mathrm{IVY}}.) It is not hard to see that L_{0}(A_{0})=0 and L_{0}(A_{1})=1. In particular, z(L_{0})=\tfrac{1}{2}. Therefore, for any sample x_{1},\ldots,x_{\ell} from \mathcal{D},

\Psi_{\ell}(\mathcal{L}_{c})\geq\sqrt{\frac{1}{\ell}\sum_{i=1}^{\ell}(L_{0}(x_{i})-z(L_{0}))^{2}}=\sqrt{\frac{1}{\ell}\sum_{i=1}^{\ell}\left(\frac{1}{2}\right)^{2}}=\frac{1}{2}.

The proof is now easily completed. Plugging the above bound on \Psi_{\ell}(\mathcal{L}_{c}) into Equation 11 yields \epsilon\geq\Omega(1)\cdot\mathbb{E}_{\mathcal{D}^{\ell}}\left[\sqrt{ns/(\ell k)}\right]. Since n,s,\ell,k are constants with respect to the sample from \mathcal{D}^{\ell}, we can dispense with the expectation and write \epsilon\geq\Omega(\sqrt{ns/(\ell k)}). Rearranging yields \ell\geq\Omega(\epsilon^{-2}ns/k).
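The values L_{0}(A_{0})=0 and L_{0}(A_{1})=1 used above are easy to confirm numerically; the following small check (our sketch, computing the SCW loss directly from its definition) does so.

```python
import numpy as np

def scw_loss(S, A, k):
    """||A - (best rank-k approximation of A in the row space of SA)||_F^2."""
    SA = S @ A
    r = np.linalg.matrix_rank(SA)
    if r == 0:
        return np.linalg.norm(A, "fro") ** 2      # empty row space: the approximation is 0
    _, _, Vt = np.linalg.svd(SA, full_matrices=False)
    Vt = Vt[:r]
    U, s, Wt = np.linalg.svd(A @ Vt.T, full_matrices=False)
    return np.linalg.norm(A - (U[:, :k] * s[:k]) @ Wt[:k] @ Vt, "fro") ** 2

n, d, k = 10, 6, 2
A0 = np.zeros((n, d)); A0[np.arange(k), np.arange(k)] = 1 / np.sqrt(k)         # columns e_1,...,e_k
A1 = np.zeros((n, d)); A1[np.arange(k, 2 * k), np.arange(k)] = 1 / np.sqrt(k)  # columns e_{k+1},...,e_{2k}
S0 = np.zeros((k, n)); S0[np.arange(k), np.arange(k)] = 1.0                    # rows e_1^T,...,e_k^T

print(scw_loss(S0, A0, k), scw_loss(S0, A1, k))   # expected: 0 and 1 (up to floating point)
```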

D.3 General Learning Lower Bound

The lower bound \Omega(\epsilon^{-1}+ns) on the number of samples needed for (\epsilon,\delta)-learning IVY with any learning procedure follows by plugging the lower bound on the fat shattering dimension from Theorem 2.1 into Theorem 19.5 in [AB09].

Appendix E Connection to Prior Work

In this section we discuss the connection of our techniques to prior techniques for proving statistical generalization bounds for data-driven algorithms, and specifically to [GR17] and [BDD+21]. The goal is to place our work in the context of related work, and also to explain why previous techniques do not give useful generalization bounds for the linear algebraic algorithms considered in this paper.

E.1 Illustrative Example 1: Learning Greedy Heuristics for Knapsack

[GR17] discuss learned heuristics for the Knapsack problem as an illustrative example for their statistical learning framework. Recall that in the Knapsack problem, the input is n items with values v_{1},\ldots,v_{n}>0, respective costs c_{1},\ldots,c_{n}>0, and a cost limit C>0. The goal is to return I\subset\{1,\ldots,n\} that maximizes the total value \sum_{i\in I}v_{i} under the cost constraint \sum_{i\in I}c_{i}<C. This problem is well-known to be NP-hard.

[GR17] consider the following family of greedy heuristics for Knapsack. Given a parameter \rho\in\mathbb{R}, define the rank of item i as v_{i}/c_{i}^{\rho}, and let L_{\rho} be the greedy heuristic that adds the highest-ranked items to I as long as the cost limit is not exceeded. They then let \rho be a learnable parameter, and prove that the pseudo-dimension of the class of heuristics \mathcal{L}=\{L_{\rho}\}_{\rho\in\mathbb{R}} is O(\log n).

The crux of the proof is the observation that for a given instance, the “loss” (or in this case the utility, since this is a combinatorial maximization problem) of the returned solution I is fully determined by the results of the {n\choose 2} comparisons v_{i}/c_{i}^{\rho}\geq^{?}v_{j}/c_{j}^{\rho}, or equivalently \rho\geq^{?}\log(v_{j}/v_{i})/\log(c_{j}/c_{i}). This means that the parameter space \mathbb{R} is partitioned into O({n\choose 2}) intervals such that the “loss” (utility) of the given instance is constant on each interval. The pseudo-dimension can then be bounded by the log of the number of intervals. The duality-based framework developed by [BDD+21] is a vast generalization of this argument, and recovers the same bound for Knapsack.

The GJ framework presented in Section 3 recovers it as well, for the same underlying reason. Since the “loss” (utility) of a given instance is determined by the {n\choose 2} comparisons \rho\geq^{?}\log(v_{j}/v_{i})/\log(c_{j}/c_{i}), which are polynomial predicates of degree 1 in the variable \rho, it can be computed by a GJ algorithm with degree 1 and predicate complexity {n\choose 2}. The upper bound O(\log n) on the pseudo-dimension thus follows from Theorem 3.1. (Note that each heuristic L_{\rho} is specified by the single real parameter \rho, so n in the notation of Theorem 3.1 is 1.)
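As an illustration of this picture, the following Python sketch implements one natural reading of the greedy rule (items that do not fit are skipped and the scan continues; the specific numbers are arbitrary), and lists the O(n^{2}) breakpoints of \rho at which the item ranking, and hence the output of L_{\rho}, can change.

```python
import math

def greedy_knapsack_value(values, costs, limit, rho):
    """L_rho: rank items by v_i / c_i^rho and add them greedily under the cost limit."""
    order = sorted(range(len(values)), key=lambda i: values[i] / costs[i] ** rho, reverse=True)
    total_value, total_cost = 0.0, 0.0
    for i in order:
        if total_cost + costs[i] < limit:
            total_value += values[i]
            total_cost += costs[i]
    return total_value

values, costs, limit = [10.0, 7.0, 4.0], [5.0, 4.0, 1.0], 6.0

# Breakpoints rho = log(v_j / v_i) / log(c_j / c_i), where the ranks of items i and j swap.
breakpoints = sorted(
    math.log(values[j] / values[i]) / math.log(costs[j] / costs[i])
    for i in range(len(values)) for j in range(i + 1, len(values))
    if costs[i] != costs[j]
)
print(breakpoints)

# The utility of L_rho is piecewise constant in rho between consecutive breakpoints.
for rho in [-1.0, 0.5, 1.5, 3.0]:
    print(rho, greedy_knapsack_value(values, costs, limit, rho))
```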

E.2 Illustrative Example 2: IVY in the Case m=k=1

In order to compare our techniques with the duality-based framework of [BDD+21], it is instructive to consider the simple case m=k=1 of IVY. In Section 4, we proved a tight bound of O(n) on the pseudo-dimension of IVY in this case.

Our proof showed that for a given input matrix A\in\mathbb{R}^{n\times d}, the loss equals \lVert A\rVert_{F}^{2} if the sketching vector w\in\mathbb{R}^{n} satisfies w^{T}A=0, and equals 0 otherwise. By the main definition of [BDD+21], this means that the dual class of losses is (\mathcal{F},\mathcal{G},d)-piecewise decomposable, where \mathcal{F} is the class of constant-valued functions, \mathcal{G} is the class of n-dimensional linear threshold functions, and d is the column-dimension of A (note that d is the number of functions from \mathcal{G} involved in the condition w^{T}A=0). Denoting the dual classes of \mathcal{F} and \mathcal{G} by \mathcal{F}^{*} and \mathcal{G}^{*} respectively (see [BDD+21] for the definition of dual classes), we have \mathrm{pdim}(\mathcal{F}^{*})=0 and \mathrm{VCdim}(\mathcal{G}^{*})=n+1. Therefore, the main theorem of [BDD+21] gives a bound of O(n\log n) on the pseudo-dimension of IVY with m=k=1, which is looser than the tight bound by a \log n factor.

We remark that since the condition w^{T}A=0 is equivalent to \lVert w^{T}A\rVert^{2}=0, the dual class is also (\mathcal{F},\mathcal{G}_{2},1)-piecewise decomposable, where \mathcal{G}_{2} is a class of quadratic threshold functions in n variables. It can be checked that \mathrm{VCdim}(\mathcal{G}_{2}^{*})=\tfrac{1}{2}(n+1)(n+2), leading to an even looser bound of O(n^{2}\log n) on the pseudo-dimension.
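For the record, the stated value is simply the number of monomials of degree at most 2 in n variables (the features on which such quadratic threshold functions are linear):

1+n+\binom{n+1}{2}=\frac{(n+1)(n+2)}{2}.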

E.3 The General Case

Finally, let us point out a formal connection between the GJ framework and the duality framework of [BDD+21]. The proof of Theorem 3.1 in fact shows that if a class of algorithms admits a GJ algorithm with degree \Delta and predicate complexity p for computing the loss, then the dual class as defined by [BDD+21] is (\mathcal{F},\mathcal{G},p)-piecewise decomposable, where \mathcal{F} is the class of constant-valued functions, and \mathcal{G} is the class of polynomial threshold functions of degree \Delta^{\prime}=O(\Delta) in n variables (where n is the number of parameters that specify an algorithm, as in Theorem 3.1). The VC-dimension of the dual class \mathcal{G}^{*} can be upper-bounded by the number of monomials in n variables of degree at most \Delta^{\prime}, which is {n+\Delta^{\prime}\choose n}. (The VC-dimension of the dual class could in principle be even smaller, if the primal class has a simpler structure than the dual class. However, in typical scenarios the primal class is less nicely structured than the dual class, which is the motivation for the work of [BDD+21]. Furthermore, if the primal class is indeed simpler, then there is no reason to go through duality at all.) Therefore, the main theorem of [BDD+21] implies an upper bound of O\left({n+\Delta^{\prime}\choose n}\log\left({n+\Delta^{\prime}\choose n}\cdot p\right)\right) on the pseudo-dimension. Unfortunately, this is typically much looser than the bound O(n\log(\Delta p)) in Theorem 3.1. On the other hand, the result of [BDD+21] is much more general, and can handle decomposability beyond constant functions \mathcal{F} and polynomial thresholds \mathcal{G}.

Specifically for IVY, the main component in our proof of Theorem 2.1 was a GJ algorithm of degree \Delta=O(mk(d/\epsilon)^{O(1/\epsilon)}) and predicate complexity p\leq 2^{m}\cdot 2^{k}\cdot(ed/k)^{3k} for computing the proxy loss (Section 5.3). While the main result of [BDD+21] can technically be applied here as described above, the pseudo-dimension upper bound it gives is super-polynomial in n (as opposed to the linear dependence on n in Theorem 2.1), so it does not seem to be useful in our setting.