
Learning Trajectories are Generalization Indicators

Jingwen Fu1, Zhizheng Zhang2, Dacheng Yin3†, Yan Lu2, Nanning Zheng1‡
fu1371252069@stu.xjtu.edu.cn
{zhizzhang,yanlu}@microsoft.com
ydc@mail.ustc.edu.cn
nnzheng@mail.xjtu.edu.cn
1National Key Laboratory of Human-Machine Hybrid Augmented Intelligence,
National Engineering Research Center for Visual Information and Applications,
and Institute of Artificial Intelligence and Robotics, Xi’an Jiaotong University,
2Microsoft Research Asia, 3University of Science and Technology of China
†Work done during internships at Microsoft Research Asia. ‡Corresponding Authors.
Abstract

This paper explores the connection between learning trajectories of Deep Neural Networks (DNNs) and their generalization capabilities when optimized using (stochastic) gradient descent algorithms. Instead of concentrating solely on the generalization error of the DNN post-training, we present a novel perspective for analyzing generalization error by investigating the contribution of each update step to the change in generalization error. This perspective enables a more direct comprehension of how the learning trajectory influences generalization error. Building upon this analysis, we propose a new generalization bound that incorporates more extensive trajectory information. Our proposed generalization bound depends on the complexity of the learning trajectory and the ratio between the bias and the diversity of the training set. Experimental observations reveal that our method effectively captures the generalization error throughout the training process. Furthermore, our approach can also track changes in generalization error when adjustments are made to learning rates and label noise levels. These results demonstrate that learning trajectory information is a valuable indicator of a model’s generalization capabilities.

1 Introduction

The generalizability of a Deep Neural Network (DNN) is a crucial research topic in the field of machine learning. Deep neural networks are commonly trained with a limited number of training samples while being tested on unseen samples. Despite the commonly used independent and identically distributed (i.i.d.) assumption between the training and testing sets, there often exists a varying degree of discrepancy between them in real-world applications. Generalization theories study the generalization of DNNs by modeling the gap between the empirical risk [36] and the population risk [36]. Classical uniform convergence based methods [20] adopt the complexity of the function space to analyze this generalization error. These theories find that a more complex function space results in a larger generalization error [37]. However, they are not well applicable to DNNs [32, 22]. In deep learning, the double descent phenomenon [6] exists, which shows that a larger function-space complexity may lead to a smaller generalization error. This violates the aforementioned property of uniform convergence methods and motivates further study of the generalization of DNNs.

Although the function space of DNNs is vast, not all functions within that space can be discovered by learning algorithms. Therefore, some representative works bound the generalization of DNNs based on the properties of the learning algorithm, e.g., the stability of the algorithm [11] and information-theoretic analysis [39]. These works rely on the relation between the input (i.e., training data) and output (weights of the model after training) of the learning algorithm to infer the generalization ability of the learned model. Here, the relation refers to how the change of one sample in the training data impacts the final weights of the model in stability analyses, and to the mutual information between the weights and the training data in information-theoretic analyses. Although some works [24, 11] leverage some information from the training process to understand the properties of the learning algorithm, limited trajectory information is conveyed.

The purpose of this article is to enhance our theoretical comprehension of the relation between the learning trajectory and generalization. While some recent experiments [9, 13] have shown a strong correlation between the information contained in the learning trajectory and generalization, the theoretical understanding behind this is still underexplored. By investigating the contribution of each update step to the change in generalization error, we give a new generalization bound with rich trajectory-related information. Our work can serve as a starting point for understanding those experimental discoveries.

1.1 Our Contribution

Our contributions can be summarized below:

  • We demonstrate that learning trajectory information serves as a valuable indicator of generalization abilities. With this motivation, we present a novel perspective for analyzing generalization error by investigating the contribution of each update step to the change in generalization error.

  • Utilizing the aforementioned modeling technique, we introduce a novel generalization bound for deep neural networks (DNNs). Our proposed bound provides a greater depth of trajectory-related insights than existing methods.

  • Our method effectively captures the generalization error throughout the training process, and the assumption underlying this method is confirmed by experiments. Furthermore, our approach can also track changes in generalization error when adjustments are made to learning rates and label noise levels.

2 Related Work

Generalization Theories

Existing works on studying the generalization of DNNs can be divided into three categories: methods based on the complexity of the function space, methods based on learning algorithms, and methods based on PAC-Bayes. The first category considers the generalization of DNNs from the perspective of the complexity of the function space. Many methods for measuring the complexity of the function space have been proposed, e.g., VC dimension [38], Rademacher Complexity [4], and covering number [32]. These works fail when applied to DNN models because the complexity of the function space of a DNN model is so large that the resulting bounds become trivial [40]. This motivates recent works to rethink the generalization of DNNs based on the accessible information in different learning algorithms, such as the stability of the algorithm [11] and information-theoretic analysis [39]. Among them, the stability of the algorithm [7] measures how changing one sample of the training data impacts the finally learned model weights, and the information-theory based generalization bounds [29, 30, 39] rely on the mutual information between the input (training data) and output (weights after training) of the learning algorithm. Another line of work is the PAC-Bayes method [19], which bounds the expected error rate of a classifier chosen from a posterior distribution in terms of the KL divergence from a given prior distribution. Our research modifies the conventional Rademacher Complexity to calculate the complexity of the space explored by a learning algorithm, which in turn helps derive the generalization bound. Our approach resembles the first category, as we also rely on the complexity of the function space. However, our method differs in that we focus on the function space explored by the learning trajectory rather than the entire function space. The novelty of our technique lies in addressing the dependence on training data within the function space explored by the learning trajectory, a dependency that is not permitted by the original Rademacher Complexity theory.

Generalization Analysis for SGD

Optimization plays a nonnegligible role in the success of DNNs. Therefore, many prior works study the generalization of DNNs by exploring properties of SGD, which can be summarized into two categories: stability of SGD and information-theoretic analysis. The most popular approach in the former category is to analyze the stability of the weight updates. Hardt et al. [11] is the first work to analyze the stability of SGD under smoothness and Lipschitz assumptions. Follow-up works try to discard the smoothness [5] or Lipschitz [25] assumptions to obtain more general bounds. Information-theoretic methods leverage the chain rule of KL-divergence to calculate the mutual information between the learned model weights and the data. This line of work is mainly applied to Stochastic Gradient Langevin Dynamics (SGLD), i.e., SGD with noise injected in each step of the parameter update [28]. Negrea et al. [23] and Haghifam et al. [10] improve the results using data-dependent priors. Neu et al. [24] construct an auxiliary iterative noisy process to adapt this method to the SGD scenario. In contrast to these studies, our approach utilizes more information related to learning trajectories. A more detailed comparison can be found in Table 2 and Appendix B.

3 Generalization Bound

Let us consider a supervised learning problem with an instance space 𝒵\mathcal{Z} and a parameter space 𝒲\mathcal{W}. The loss function is defined as f:𝒲×𝒵+f:\mathcal{W}\times\mathcal{Z}\rightarrow\mathbb{R}_{+}. We denote the distribution over the instance space 𝒵\mathcal{Z} as μ\mu. The nn i.i.d. samples drawn from μ\mu are denoted as S={z1,,zn}μnS=\{z_{1},...,z_{n}\}\sim\mu^{n}. Given parameters 𝐰\mathbf{w}, the empirical risk and the population risk are denoted as FS(𝐰)1ninf(𝐰,zi)F_{S}(\mathbf{w})\triangleq\frac{1}{n}\sum_{i}^{n}f(\mathbf{w},z_{i}) and Fμ(𝐰)𝔼zμ[f(𝐰,z)]F_{\mu}(\mathbf{w})\triangleq\mathbb{E}_{z\sim\mu}[f(\mathbf{w},z)], respectively. Our work studies the generalization error of the learned model, i.e., Fμ(𝐰)FS(𝐰)F_{\mu}(\mathbf{w})-F_{S}(\mathbf{w}). For an optimization process, the learning trajectory is represented as a function 𝐉:𝒲\mathbf{J}:\mathbb{N}\to\mathcal{W}. We use 𝐉𝐭\mathbf{J_{t}} to denote the weights of the model after tt updates, where 𝐉𝐭=𝐉(t)\mathbf{J_{t}}=\mathbf{J}(t). The learning algorithm is defined as 𝒜:μn×𝐉\mathcal{A}:\mu^{n}\times\mathbb{R}\to\mathbf{J}, where the second input \mathbb{R} denotes all randomness in the algorithm 𝒜\mathcal{A}, including the randomness in initialization, batch sampling, etc. We simply use 𝒜(S)\mathcal{A}(S) to represent a random choice for the second input term. Given two functions U,VU,V, tU(t)dV(t)tU(t)(V(t+1)V(t))\int_{t}U(t)\mathrm{d}V(t)\triangleq\sum_{t}U(t)(V(t+1)-V(t)), and we use \|\cdot\| to denote the L2 norm. If SS is a set, then |S||S| denotes the number of elements in SS. 𝔼t\mathbb{E}_{t} denotes taking the expectation conditioned on {𝐉𝐢|it}\{\mathbf{J_{i}}|i\leq t\}.

Let the mini-batch BB be a random subset sampled from the dataset SS, with |B|=b|B|=b. The averaged function value over mini-batch BB is denoted as FB(𝐰)1bzBf(𝐰,z)F_{B}(\mathbf{w})\triangleq\frac{1}{b}\sum_{z\in B}f(\mathbf{w},z). The parameter update with gradient descent can be formulated as:

𝐉𝐭+𝟏=𝐉𝐭ηtFS(𝐉𝐭).\mathbf{J_{t+1}}=\mathbf{J_{t}}-\eta_{t}\nabla F_{S}(\mathbf{J_{t}}). (1)

where ηt\eta_{t} is the learning rate for the tt-th update. The parameter update with stochastic gradient descent is:

𝐉𝐭+𝟏=𝐉𝐭ηtFB(𝐉𝐭).\mathbf{J_{t+1}}=\mathbf{J_{t}}-\eta_{t}\nabla F_{B}(\mathbf{J_{t}}). (2)

Let ϵ(𝐰)FS(𝐰)FB(𝐰)\epsilon(\mathbf{w})\triangleq\nabla F_{S}(\mathbf{w})-\nabla F_{B}(\mathbf{w}) be the gradient noise in mini-batch updating, where 𝐰\mathbf{w} is the weights of a DNN. Then we can transform Equation (2) into:

𝐉𝐭+𝟏=𝐉𝐭ηtFS(𝐉𝐭)+ηtϵ(𝐉𝐭).\mathbf{J_{t+1}}=\mathbf{J_{t}}-\eta_{t}\nabla F_{S}(\mathbf{J_{t}})+\eta_{t}\epsilon(\mathbf{J_{t}}). (3)

The covariance of the gradients over the entire dataset SS can be calculated as:

Σ(𝐰)1ni=1nf(𝐰,zi)f(𝐰,zi)TFS(𝐰)FS(𝐰)T.\Sigma(\mathbf{w})\triangleq\frac{1}{n}\sum_{i=1}^{n}\nabla f(\mathbf{w},z_{i})\nabla f(\mathbf{w},z_{i})^{\mathrm{T}}-\nabla F_{S}(\mathbf{w})\nabla F_{S}(\mathbf{w})^{\mathrm{T}}. (4)

Therefore, the covariance of the gradient noise ϵ(𝐰)\epsilon(\mathbf{w}) is:

C(𝐰)nbb(n1)Σ(𝐰).C(\mathbf{w})\triangleq\frac{n-b}{b(n-1)}\Sigma(\mathbf{w}). (5)

Since for any 𝐰\mathbf{w} we have 𝔼(ϵ(𝐰))=0\mathbb{E}(\epsilon(\mathbf{w}))\!=\!0, we can represent ϵ(𝐰)\epsilon(\mathbf{w}) as C(𝐰)12ϵC(\mathbf{w})^{\frac{1}{2}}\epsilon^{\prime}, where ϵ\epsilon^{\prime} is a random variable with zero mean and identity covariance matrix. Here, ϵ\epsilon^{\prime} can follow any distribution, including the Gaussian distribution [12] and the 𝒮α𝒮\mathcal{S}\alpha\mathcal{S} distribution [34].
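To make this decomposition concrete, the following is a minimal NumPy sketch (not the authors' code) that computes per-sample gradients for a linear least-squares model, forms Σ(𝐰) as in Equation (4) and C(𝐰) as in Equation (5), and empirically checks that the mini-batch noise ϵ(𝐰) has zero mean and covariance C(𝐰). The model, dimensions, and sample counts are illustrative assumptions.

```python
# Minimal sketch of the gradient-noise decomposition in Eqs. (3)-(5).
# The linear least-squares model and all sizes are illustrative assumptions.
import numpy as np

rng = np.random.default_rng(0)
n, d, b = 100, 5, 10                       # dataset size, dimension, batch size
X = rng.normal(size=(n, d))
y = rng.normal(size=n)
w = rng.normal(size=d)

# Per-sample gradients of f(w, z_i) = 0.5 * (w^T x_i - y_i)^2.
G = (X @ w - y)[:, None] * X               # row i is grad f(w, z_i)
grad_S = G.mean(axis=0)                    # full-batch gradient, grad F_S(w)

# Gradient covariance over S, Eq. (4).
Sigma = G.T @ G / n - np.outer(grad_S, grad_S)

# Covariance of the mini-batch noise eps(w) = grad F_S(w) - grad F_B(w), Eq. (5).
C = (n - b) / (b * (n - 1)) * Sigma

# Monte Carlo check: the noise has mean ~0 and covariance ~C.
eps = np.array([grad_S - G[rng.choice(n, b, replace=False)].mean(axis=0)
                for _ in range(20000)])
print(np.abs(eps.mean(axis=0)).max())              # close to 0
print(np.abs(np.cov(eps.T, bias=True) - C).max())  # small deviation
```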

The primary objective of our work is to propose a new generalization bound that incorporates more comprehensive trajectory-related information. The key aspects of this information are: 1) it should be adaptive and change according to different learning trajectories; 2) it should not rely on extra information about the data distribution μ\mu beyond the training data SS.

3.1 Investigating generalization along the learning trajectory

As noted above, the learning trajectory is represented by a function 𝐉:𝒲\mathbf{J}:\mathbb{N}\to\mathcal{W}, which defines the relationship between the model weights and the training timestep tt. 𝐉t\mathbf{J}_{t} denotes the model weights after tt updates. Note that 𝐉\mathbf{J} depends on SS, because it comes from the equation 𝐉=𝒜(S)\mathbf{J}=\mathcal{A}(S). We simply use f(𝐉𝐭):𝒵+f(\mathbf{J_{t}}):\mathcal{Z}\rightarrow\mathbb{R}_{+} to represent the function after the tt-th update. Our goal is to analyze the generalization error, i.e., Fμ(𝐉𝐓)FS(𝐉𝐓)F_{\mu}(\mathbf{\mathbf{J_{T}}})-F_{S}(\mathbf{J_{T}}), where TT denotes the total number of training steps.

We reformulate the function corresponding to the finally obtained model as:

f(𝐉𝐓)=f(𝐉𝟎)+t=1T(f(𝐉𝐭)f(𝐉𝐭𝟏)).f(\mathbf{J_{T}})=f(\mathbf{J_{0}})+\sum_{t=1}^{T}(f(\mathbf{J_{t}})-f(\mathbf{J_{t-1}})). (6)

Therefore, the generalization error can be rewritten as:

Fμ(𝐉𝐓)FS(𝐉𝐓)=Fμ(𝐉𝟎)FS(𝐉𝟎)(i)+t=1T[(Fμ(𝐉𝐭)Fμ(𝐉𝐭𝟏))(FS(𝐉𝐭)FS(𝐉𝐭𝟏))](ii)t.F_{\mu}(\mathbf{\mathbf{J_{T}}})-F_{S}(\mathbf{J_{T}})=\underbrace{F_{\mu}(\mathbf{J_{0}})-F_{S}(\mathbf{J_{0}})}_{(i)}+\sum_{t=1}^{T}\underbrace{[(F_{\mu}(\mathbf{J_{t}})-F_{\mu}(\mathbf{J_{t-1}}))-(F_{S}(\mathbf{J_{t}})-F_{S}(\mathbf{J_{t-1}}))]}_{(ii)_{t}}. (7)

In this form, we divide the generalization error into two parts: (i)(i) is the generalization error before training, and (ii)t(ii)_{t} is the generalization error caused by the tt-th update.

Typically, 𝐉𝟎\mathbf{J_{0}} is independent of the data SS. Therefore, we have 𝔼(i)=0\mathbb{E}(i)=0. Combining this with Equation (7), we have:

𝔼[Fμ(𝐉𝐓)FS(𝐉𝐓)]=𝔼t=1T(ii)t.\mathbb{E}[F_{\mu}(\mathbf{\mathbf{J_{T}}})-F_{S}(\mathbf{J_{T}})]=\mathbb{E}\sum_{t=1}^{T}(ii)_{t}. (8)

Analyzing the generalization error after training can thus be transformed into analyzing the increase in generalization error at each update. This is a straightforward yet quite different way to extract information from the learning trajectory compared with previous works. Below, we list the two techniques most commonly used by previous works to extract information from the learning trajectory; a small sketch of our decomposition follows this list.

  • (T1). This method leverages the chain rule of mutual information to calculate an upper bound on the mutual information between 𝐉𝐓\mathbf{J_{T}} and the training data SS, i.e., I(S;𝐉𝐓)I(S;𝐉𝐭𝐓)t=0TI(S;𝐉𝐭|𝐉𝐢<𝐭)I(S;\mathbf{J_{T}})\leq I(S;\mathbf{J_{t\leq T}})\leq\sum_{t=0}^{T}I(S;\mathbf{J_{t}}|\mathbf{J_{i<t}}). I(S;𝐉𝐓)I(S;\mathbf{J_{T}}) is the quantity of interest in these theories.

  • (T2). This method assumes another dataset SS^{\prime}, obtained by replacing one sample in SS with another sample drawn from the distribution μ\mu. 𝐉\mathbf{J^{\prime}} is the learning trajectory trained from SS^{\prime} with the same randomness value as 𝐉\mathbf{J}. Denote Δk𝐉𝐤𝐉𝐤\Delta_{k}\triangleq\|\mathbf{J_{k}}-\mathbf{J^{\prime}_{k}}\| and assume Δ0=0\Delta_{0}=0. The quantity of interest is ΔT\Delta_{T}, whose upper bound is calculated by iteratively applying the formula Δkck1Δk1+ek1\Delta_{k}\leq c_{k-1}\Delta_{k-1}+e_{k-1}.

(T1) is commonly utilized in analyzing Stochastic Gradient Langevin Dynamics (SGLD) [18, 2, 28], while (T2) is frequently employed in stability-based works for analyzing SGD [11, 15, 5]. Our method offers several benefits: 1) we directly focus on the change in generalization error rather than intermediate values such as Δk\Delta_{k} and I(S;𝐉𝐭|𝐉𝐢<𝐭)I(S;\mathbf{J_{t}}|\mathbf{J_{i<t}}); 2) the generalization error is exactly the sum of the (ii)t(ii)_{t}, while (T1) and (T2) bound I(S;𝐉𝐓)I(S;\mathbf{J_{T}}) and ΔT\Delta_{T} only from above; and 3) from this perspective, we can extract more in-depth trajectory-related information. For (T1), the computation of I(S;𝐉𝐭|𝐉𝐢<𝐭)I(S;\mathbf{J_{t}}|\mathbf{J_{i<t}}) primarily involves the information of Fμ(𝐉𝐭)\nabla F_{\mu}(\mathbf{J_{t}}), which is inaccessible to us (details in Appendix D and Neu et al. [24]). (T2) faces the challenge that only the upper bounds of ckc_{k} and eke_{k} can be calculated, and these upper bounds remain unchanged across different learning trajectories. Consequently, both (T1) and (T2) have difficulty conveying meaningful trajectory information.
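To illustrate the decomposition in Equations (7) and (8), the following is a minimal NumPy sketch (an illustrative assumption, not the authors' code) that tracks the per-step increments (ii)_t during gradient descent on a toy least-squares problem, using a large held-out set as a proxy for μ, and checks that (i) plus the sum of the increments telescopes to the final generalization error.

```python
# Minimal sketch of the per-step decomposition in Eqs. (7)-(8). A large
# held-out set stands in for the population mu; all sizes are assumptions.
import numpy as np

rng = np.random.default_rng(1)
n, d, eta, T = 50, 5, 0.05, 200
Xs, ys = rng.normal(size=(n, d)), rng.normal(size=n)
Xp, yp = rng.normal(size=(20000, d)), rng.normal(size=20000)   # proxy for mu

def risk(w, X, y):                       # F_S(w) or the proxy for F_mu(w)
    return 0.5 * np.mean((X @ w - y) ** 2)

w = np.zeros(d)
gen = risk(w, Xp, yp) - risk(w, Xs, ys)  # term (i): error before training
term_i, increments = gen, []
for t in range(T):
    w = w - eta * Xs.T @ (Xs @ w - ys) / n        # GD step, Eq. (1)
    gen_next = risk(w, Xp, yp) - risk(w, Xs, ys)
    increments.append(gen_next - gen)             # term (ii)_t
    gen = gen_next

# Telescoping check: (i) + sum_t (ii)_t equals the final generalization error.
print(term_i + sum(increments))
print(risk(w, Xp, yp) - risk(w, Xs, ys))
```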

Table 1: Comparison of the generalization bounds with stability-based methods for SGD learning algorithms. T.R.T is an abbreviation for the term related to the trajectory, defined as a term that 1) varies based on different learning trajectories, and 2) does not rely on extra information about the data distribution μ\mu beyond the training data SS. We can infer that the proposed bound incorporates a greater amount of trajectory-related information. Other related works are discussed in Appendix D.
Method β\beta-Smooth LL-Lipschitz Convex Small LR Other Conditions Generalization Bound T.R.T
Hardt et al. [11] 2L2nt=1Tηt\frac{2L^{2}}{n}\sum_{t=1}^{T}\eta_{t} t=1Tηt\sum_{t=1}^{T}\eta_{t}
Hardt et al. [11] f[0,1],ηt<ctf\in[0,1],\eta_{t}<\frac{c}{t} 𝒪(1nL2βc+1Tβcβc+1)\mathcal{O}(\frac{1}{n}L^{\frac{2}{\beta c+1}}T^{\frac{\beta c}{\beta c+1}}) Tβcβc+1T^{\frac{\beta c}{\beta c+1}}
Zhang et al. [42] T>n,ηt=cβtT>n,\eta_{t}=\frac{c}{\beta t} 16L2Tcn1+c\frac{16L^{2}T^{c}}{n^{1+c}} TcT^{c}
Zhou et al. [43] 𝔼zSf(𝐰,z)FS(𝐰)2B2\mathbb{E}_{z\in S}\|\nabla f(\mathbf{w},z)-\nabla F_{S}(\mathbf{w})\|^{2}\leq B^{2} 𝒪(1nL2βFμ(𝐉𝟎)+12𝔼B2logT)\mathcal{O}(\sqrt{\frac{1}{n}L\sqrt{2\beta F_{\mu}(\mathbf{J_{0}})+\frac{1}{2}\mathbb{E}B^{2}}\log T}) logT\sqrt{\log T}
Bassily et al. [5] Projected SGD 2L2t=1T1ηt2+4L2nt=1T1ηt2L^{2}\sqrt{\sum_{t=1}^{T-1}\eta_{t}^{2}}+\frac{4L^{2}}{n}\sum_{t=1}^{T-1}\eta_{t} t=1T1ηt\sum_{t=1}^{T-1}\eta_{t}
Lei and Ying [16] Projected SGD 𝒪((1+Tn2)t=1Tηt2)\mathcal{O}((1+\frac{T}{n^{2}})\sum_{t=1}^{T}\eta_{t}^{2}) Tt=1Tηt2T\sum_{t=1}^{T}\eta_{t}^{2}
Ours (Theorem 3.6) Fμ(𝐰)γFS(𝐰)\|\nabla F_{\mu}(\mathbf{w})\|\leq\gamma\|\nabla F_{S}(\mathbf{w})\| Theorem 3.6 t𝑑FS(𝐉𝐭)1+Tr(Σ(𝐉𝐭))FS(𝐉𝐭)2\int_{t}dF_{S}(\mathbf{J_{t}})\sqrt{1+\frac{\operatorname{Tr}(\Sigma(\mathbf{J_{t}}))}{\|\nabla F_{S}(\mathbf{J_{t}})\|^{2}}}

3.2 A New Generalization Bound

In this section, we introduce the generalization bound based on our aforementioned modelling. Let us start with the definitions of commonly used assumptions.

Definition 3.1.

The function ff is LL-Lipschitz if for all 𝐰𝟏,𝐰𝟐𝒲\mathbf{w_{1}},\mathbf{w_{2}}\in\mathcal{W} and all z𝒵z\in\mathcal{Z}, we have f(𝐰𝟏,z)f(𝐰𝟐,z)L𝐰𝟏𝐰𝟐\|f(\mathbf{w_{1}},z)-f(\mathbf{w_{2}},z)\|\leq L\|\mathbf{w_{1}}-\mathbf{w_{2}}\|.

Definition 3.2.

The function ff is β\beta-smooth if for all 𝐰𝟏,𝐰𝟐𝒲\mathbf{w_{1}},\mathbf{w_{2}}\in\mathcal{W} and all z𝒵z\in\mathcal{Z}, we have f(𝐰𝟏,z)f(𝐰𝟐,z)β𝐰𝟏𝐰𝟐\|\nabla f(\mathbf{w_{1}},z)-\nabla f(\mathbf{w_{2}},z)\|\leq\beta\|\mathbf{w_{1}}-\mathbf{w_{2}}\|.

Definition 3.3.

The function ff is convex if for all 𝐰𝟏,𝐰𝟐𝒲\mathbf{w_{1}},\mathbf{w_{2}}\in\mathcal{W} and all z𝒵z\in\mathcal{Z}, we have f(𝐰𝟏,z)f(𝐰𝟐,z)+(𝐰𝟏𝐰𝟐)Tf(𝐰𝟐,z)f(\mathbf{w_{1}},z)\geq f(\mathbf{w_{2}},z)+(\mathbf{w_{1}}-\mathbf{w_{2}})^{\mathrm{T}}\nabla f(\mathbf{w_{2}},z).

Here, the LL-Lipschitz assumption implies that f(𝐰,z)L\|\nabla f(\mathbf{w},z)\|\leq L holds. The β\beta-smooth assumption indicates that the largest eigenvalue of 2f(𝐰,z)\nabla^{2}f(\mathbf{w},z) is smaller than β\beta. Convexity indicates that the smallest eigenvalue of 2f(𝐰,z)\nabla^{2}f(\mathbf{w},z) is nonnegative. These assumptions constrain the gradients and Hessian matrices on both the training data and the unseen samples in the test set. Since the values of gradients and Hessian matrices on the training set are accessible, the key role of these assumptions is to deliver knowledge about the unseen samples in the test set.

We now introduce our new generalization bound, starting with the assumption it requires.

Assumption 3.4.

There exists a value γ\gamma such that for all 𝐰{𝐉𝐭|t}\mathbf{w}\in\{\mathbf{J_{t}}|t\in\mathbb{N}\}, we have Fμ(𝐰)γFS(𝐰)\|\nabla F_{\mu}(\mathbf{w})\|\leq\gamma\|\nabla F_{S}(\mathbf{w})\|.

Remark 3.5.

Assumption 3.4 restricts the norm of the population gradient Fμ(𝐰)\nabla F_{\mu}(\mathbf{w}). This assumption is easily satisfied when nn is large, because we have limnFS(𝐰)=Fμ(𝐰)\lim\limits_{n\rightarrow\infty}\|\nabla F_{S}(\mathbf{w})\|=\|\nabla F_{\mu}(\mathbf{w})\|. When nn is not large enough, the assumption still holds before SGD enters the neighbourhood of a convergence point. For the case where SGD enters the neighbourhood of a convergence point, we give a relaxed assumption and its corresponding generalization bound in Appendix B. According to Zhang et al. [41], this case rarely happens in practice. Section 4 gives experiments that explore the assumption.

Theorem 3.6.

Under Assumption 3.4, given SμnS\sim\mu^{n}, let 𝐉=𝒜(S)\mathbf{J}=\mathcal{A}(S), where 𝒜\mathcal{A} denotes the SGD or GD algorithm trained for TT steps. Then we have:

𝔼[Fμ(𝐉𝐓)FS(𝐉𝐓)]2γ𝕍m𝔼tdFS(𝐉𝐭)n1+Tr(Σ(𝐉𝐭))FS(𝐉𝐭)22+𝒪(ηm)\mathbb{E}[F_{\mu}(\mathbf{J_{T}})-F_{S}(\mathbf{J_{T}})]\leq-2\gamma^{\prime}\mathbb{V}_{m}\mathbb{E}\int_{t}\frac{dF_{S}(\mathbf{J_{t}})}{\sqrt{n}}\sqrt{1+\frac{\operatorname{Tr}(\Sigma(\mathbf{J_{t}}))}{\|\nabla F_{S}(\mathbf{J_{t}})\|_{2}^{2}}}+\mathcal{O}(\eta_{m}) (9)

where 𝕍(𝐰)=FS(𝐰)𝔼US|U|nFU(𝐰)n|U|nFS/U(𝐰)\mathbb{V}(\mathbf{w})=\frac{\|\nabla F_{S}(\mathbf{w})\|}{\mathbb{E}_{U\subset S}\|\frac{|U|}{n}\nabla F_{U}(\mathbf{w})-\frac{n-|U|}{n}\nabla F_{S/U}(\mathbf{w})\|}, 𝕍m=maxt𝕍(𝐉𝐭)\mathbb{V}_{m}\!=\!\max\limits_{t}\mathbb{V}(\mathbf{J_{t}}), γ=max{1,maxUS;t|U|FU(𝐉𝐭)nFS(𝐉𝐭)}γ\gamma^{\prime}\!=\!\max\{1,\max\limits_{U\subset S;t}\frac{|U|\|\nabla F_{U}(\mathbf{J_{t}})\|}{n\|\nabla F_{S}(\mathbf{J_{t}})\|}\}\gamma and ηmmaxtηt\eta_{m}\triangleq\max\limits_{t}\eta_{t}.

Remark 3.7.

Our generalization bound mainly relies on information from gradients. 𝕍(𝐰)\mathbb{V}(\mathbf{w}) is related to the variance of the gradient: when the variance of the gradients across different samples in the training set SS is large, the value of 𝕍(𝐰)\mathbb{V}(\mathbf{w}) is small, and vice versa. Note that we have |U|<n|U|<n due to USU\subset S. Our bound becomes trivial if 𝔼US|U|nFU(𝐰)n|U|nFS/U(𝐰)=0\mathbb{E}_{U\subset S}\|\frac{|U|}{n}\nabla F_{U}(\mathbf{w})-\frac{n-|U|}{n}\nabla F_{S/U}(\mathbf{w})\|=0. This rarely happens in practice, because it requires that for all USU\subset S, we have |U|FU(𝐰)=(n|U|)FS/U(𝐰)|U|\nabla F_{U}(\mathbf{w})=(n-|U|)\nabla F_{S/U}(\mathbf{w}). An example for the linear regression case is given in the appendix. We also give a relaxed-assumption version of this theorem in Appendix B. The generalization bound provides a clear insight into how the reduction of training loss leads to an increase in generalization error.
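As an illustration, 𝕍(𝐰) can be estimated by Monte Carlo over random subsets U. The sketch below reuses a per-sample gradient matrix G (row i equal to ∇f(𝐰, z_i), as in the earlier snippet); since the theorem does not pin down the distribution of U, uniform subsets of size n/2 are an illustrative assumption here.

```python
# Minimal Monte Carlo sketch of V(w) from Theorem 3.6. G has shape (n, d)
# with row i equal to grad f(w, z_i). Uniform subsets U of size n // 2 are
# an illustrative assumption; the theorem does not fix this distribution.
import numpy as np

def V_of_w(G, rng, trials=2000):
    n = G.shape[0]
    grad_S = G.mean(axis=0)
    vals = []
    for _ in range(trials):
        mask = np.zeros(n, dtype=bool)
        mask[rng.choice(n, n // 2, replace=False)] = True
        g_U = G[mask].mean(axis=0)                 # grad F_U(w)
        g_rest = G[~mask].mean(axis=0)             # grad F_{S \ U}(w)
        k = n // 2
        vals.append(np.linalg.norm(k / n * g_U - (n - k) / n * g_rest))
    return np.linalg.norm(grad_S) / np.mean(vals)  # denominator ~ E_U ||.||
```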

Proof Sketch

The proof of this theorem is given in Appendix A. Here, we give a sketch of the proof.

Step 1

Beginning with Equation (8), we decompose Fμ(𝐉𝐓)FS(𝐉𝐓)F_{\mu}(\mathbf{\mathbf{J_{T}}})-F_{S}(\mathbf{J_{T}}) into a linear part (genlin(𝐉𝐓)\operatorname{gen}^{lin}(\mathbf{J_{T}})) and a nonlinear part (gennl(𝐉𝐓)\operatorname{gen}^{nl}(\mathbf{J_{T}})). We have genlin(𝐉𝐓)=t=1T(ii)tlin\operatorname{gen}^{lin}(\mathbf{J_{T}})=\sum_{t=1}^{T}(ii)^{lin}_{t}, where (ii)tlin(𝐉𝐭𝐉𝐭𝟏)T(Fμ(𝐉𝐭𝟏)FS(𝐉𝐭𝟏))(ii)^{lin}_{t}\triangleq(\mathbf{J_{t}}-\mathbf{J_{t-1}})^{\mathrm{T}}(\nabla F_{\mu}(\mathbf{J_{t-1}})-\nabla F_{S}(\mathbf{J_{t-1}})). The nonlinear part is gennl(𝐉𝐓)=Fμ(𝐉𝐓)FS(𝐉𝐓)genlin(𝐉𝐓)\operatorname{gen}^{nl}(\mathbf{J_{T}})=F_{\mu}(\mathbf{\mathbf{J_{T}}})-F_{S}(\mathbf{J_{T}})-\operatorname{gen}^{lin}(\mathbf{J_{T}}). We tackle these two parts differently. Here, we focus on analyzing genlin(𝐉𝐓)\operatorname{gen}^{lin}(\mathbf{J_{T}}) because it dominates under small learning rates. A detailed discussion of gennl(𝐉𝐓)\operatorname{gen}^{nl}(\mathbf{J_{T}}) is given in the appendix (Proposition A.1 and Subsection C.3).

Step 2

We construct the additive linear space 𝐉|S{t=0T1𝐰𝐭Tf(𝐉𝐭)|𝐰𝐭Δt}\mathcal{L}_{\mathbf{J}|S}\triangleq\{\sum_{t=0}^{T-1}\mathbf{w_{t}}^{\mathrm{T}}\nabla f(\mathbf{J_{t}})\ |\ \|\mathbf{w_{t}}\|\leq\Delta_{t}\}, where ΔtηtFS(𝐉𝐭)\Delta_{t}\triangleq\|\eta_{t}\nabla F_{S}(\mathbf{J_{t}})\|. Then 𝔼[genlin(𝐉𝐓)]2γ𝕍m𝔼RS(𝐉|S)\mathbb{E}[\operatorname{gen}^{lin}(\mathbf{J_{T}})]\leq 2\gamma^{\prime}\mathbb{V}_{m}\mathbb{E}R_{S}(\mathcal{L}_{\mathbf{J}|S}), where RS(𝐉|S)𝔼σsuph𝐉|S(1ni=1nσih(zi))R_{S}(\mathcal{L}_{\mathbf{J}|S})\triangleq\mathbb{E}_{\sigma}\sup\limits_{h\in\mathcal{L}_{\mathbf{J}|S}}(\frac{1}{n}\sum_{i=1}^{n}\sigma_{i}h(z_{i})).

Step 3

Finally, we compute the upper bound of RS(𝐉|S)R_{S}(\mathcal{L}_{\mathbf{J}|S}), which follows the same techniques as standard Rademacher Complexity theory. Combining this with Proposition A.1 establishes the theorem.

Technical Novelty

Directly applying Rademacher complexity to calculate the generalization error bound fails because the large complexity of a neural network’s function space leads to trivial bounds [40]. In this work, we want to calculate the complexity of the function space that can be explored during the training process. However, there are two challenges. First, the trajectory of a neural network is a "line" rather than a function space whose complexity can be measured. To solve this problem, we introduce the additive linear space 𝐉|S\mathcal{L}_{\mathbf{J}|S}. This space contains the local information of the learning trajectory and can serve as a pseudo function space. Second, the function space 𝐉|S\mathcal{L}_{\mathbf{J}|S} depends on the sample set SS, while the theory of Rademacher complexity requires that the function space be independent of the training samples. To decouple this dependence, we adapt the Rademacher complexity and obtain 𝔼[genlin(𝐉𝐓)]2γ𝕍m𝔼RS(𝐉|S)\mathbb{E}[\operatorname{gen}^{lin}(\mathbf{J_{T}})]\leq 2\gamma^{\prime}\mathbb{V}_{m}\mathbb{E}R_{S}(\mathcal{L}_{\mathbf{J}|S}), where γ\gamma^{\prime} is introduced to decouple the dependence mentioned above.

Compared with Previous Works

In Table 1, we present a summary of stability-based methods; other methods are outlined in Appendix D. We focus on generalization bounds from previous works that eliminate terms depending on extra information about the data distribution μ\mu, apart from the training data SS, using assumptions such as smoothness or Lipschitz continuity. Analyzing Table 1 reveals that most prior works depend primarily on the learning rate η\eta and the total number of training steps TT. This implies that two runs with an identical learning rate schedule and total number of training steps receive the same bound, which does not align with practical experience. Our proposed generalization bound considers the evolution of function values, gradient covariance, and gradient norms throughout the training process. As a result, our bound encompasses more comprehensive information about the learning trajectory.

Asymptotic Analysis

We first analyze the dependence of 𝕍\mathbb{V} on nn. 𝕍\mathbb{V} is calculated as 𝕍(𝐰)=FS(𝐰)𝔼US|U|nFU(𝐰)n|U|nFS/U(𝐰)\mathbb{V}(\mathbf{w})=\frac{\|\nabla F_{S}(\mathbf{w})\|}{\mathbb{E}_{U\subset S}\|\frac{|U|}{n}\nabla F_{U}(\mathbf{w})-\frac{n-|U|}{n}\nabla F_{S/U}(\mathbf{w})\|}. The gradient of an individual sample is unrelated to the sample size nn, and |U|n|U|\sim n. Therefore, 𝕍=𝒪(1)\mathbb{V}=\mathcal{O}(1). Similarly, we have 𝔼tdFS(𝐉𝐭)n1+Tr(Σ(𝐉𝐭))FS(𝐉𝐭)22=𝒪(1n)\mathbb{E}\int_{t}\frac{dF_{S}(\mathbf{J_{t}})}{\sqrt{n}}\sqrt{1+\frac{\operatorname{Tr}(\Sigma(\mathbf{J_{t}}))}{\|\nabla F_{S}(\mathbf{J_{t}})\|_{2}^{2}}}=\mathcal{O}(\frac{1}{\sqrt{n}}). As for the 𝒪(ηm)\mathcal{O}(\eta_{m}) term in Theorem 3.6, we have limn𝒪(ηm)=0\lim\limits_{n\to\infty}\mathcal{O}(\eta_{m})=0 according to Proposition A.1. We simply assume that 𝒪(ηm)=𝒪(1nc)\mathcal{O}(\eta_{m})=\mathcal{O}(\frac{1}{n^{c}}). Therefore, our bound is 𝒪(1nmin{0.5,c})\mathcal{O}(\frac{1}{n^{\text{min}\{0.5,c\}}}).

Next, to draw a clearer comparison with stability-based methods, we present the following corollary. It employs the β\beta-smooth assumption to bound gennl(𝐉𝐓)\operatorname{gen}^{nl}(\mathbf{J_{T}}) and uses a learning rate setting similar to those found in stability-based works.

Corollary 3.8.

If the function f()f(\cdot) is β\beta-smooth, then under Assumption 3.4, given SμnS\sim\mu^{n}, let 𝐉=𝒜(S)\mathbf{J}=\mathcal{A}(S), ηt=cβ(t+1)\eta_{t}=\frac{c}{\beta(t+1)}, M22=maxt𝔼t1(FS(𝐉𝐭)+ϵ(𝐉𝐭)2)M^{2}_{2}=\max\limits_{t}\mathbb{E}_{t-1}(\|\nabla F_{S}(\mathbf{J_{t}})+\epsilon(\mathbf{J_{t}})\|^{2}) and M44=maxt𝔼t1(FS(𝐉𝐭)+ϵ(𝐉𝐭)4)M^{4}_{4}=\max\limits_{t}\mathbb{E}_{t-1}(\|\nabla F_{S}(\mathbf{J_{t}})+\epsilon(\mathbf{J_{t}})\|^{4}), where 𝒜\mathcal{A} denotes the SGD or GD algorithm trained for TT steps. Then we have:

𝔼[Fμ(𝐉𝐓)FS(𝐉𝐓)]\displaystyle\mathbb{E}[F_{\mu}(\mathbf{J_{T}})-F_{S}(\mathbf{J_{T}})]\leq 2γ𝕍m𝔼tdFS(𝐉𝐭)n1+Tr(Σ(𝐉𝐭))FS(𝐉𝐭)22\displaystyle-2\gamma^{\prime}\mathbb{V}_{m}\mathbb{E}\int_{t}\frac{dF_{S}(\mathbf{J_{t}})}{\sqrt{n}}\sqrt{1+\frac{\operatorname{Tr}(\Sigma(\mathbf{J_{t}}))}{\|\nabla F_{S}(\mathbf{J_{t}})\|_{2}^{2}}} (10)
+2c2γ𝕍mM42𝔼tdtnβ2(t+1)4(1+Tr(Σ(𝐉𝐭))FS(𝐉𝐭)22)\displaystyle+2c^{2}\gamma^{\prime}\mathbb{V}_{m}M_{4}^{2}\sqrt{\mathbb{E}\int_{t}\frac{dt}{n\beta^{2}(t+1)^{4}}\left(1+\frac{\operatorname{Tr}(\Sigma(\mathbf{J_{t}}))}{\|\nabla F_{S}(\mathbf{J_{t}})\|_{2}^{2}}\right)}
+2c2M22β.\displaystyle+2c^{2}\frac{M_{2}^{2}}{\beta}.

where 𝕍(𝐰)=FS(𝐰)𝔼US|U|nFU(𝐰)n|U|nFS/U(𝐰)\mathbb{V}(\mathbf{w})=\frac{\|\nabla F_{S}(\mathbf{w})\|}{\mathbb{E}_{U\subset S}\|\frac{|U|}{n}\nabla F_{U}(\mathbf{w})-\frac{n-|U|}{n}\nabla F_{S/U}(\mathbf{w})\|}, 𝕍m=maxt𝕍(𝐉𝐭)\mathbb{V}_{m}\!=\!\max\limits_{t}\mathbb{V}(\mathbf{J_{t}}) and γ=max{1,maxUS;t|U|FU(𝐉𝐭)nFS(𝐉𝐭)}γ\gamma^{\prime}\!=\!\max\{1,\max\limits_{U\subset S;t}\frac{|U|\|\nabla F_{U}(\mathbf{J_{t}})\|}{n\|\nabla F_{S}(\mathbf{J_{t}})\|}\}\gamma.

3.3 Analysis

3.3.1 Generalization Bounds

Our obtained generalization bound is:

𝔼[Fμ(𝐉𝐓)FS(𝐉𝐓)]γBias of Training Set𝕍m1Diversity of Training Set(2𝔼tdFS(𝐉𝐭)n1+Tr(Σ(𝐉𝐭))FS(𝐉𝐭)22)Complexity of Learning Trajectory+𝒪(ηm)\mathbb{E}[F_{\mu}(\mathbf{J_{T}})-F_{S}(\mathbf{J_{T}})]\leq\!\underbrace{\gamma^{\prime}}_{\text{\tiny Bias of Training Set}}\!\overbrace{\mathbb{V}_{m}}^{\frac{1}{\text{Diversity of Training Set}}}\!\underbrace{\!\left(\!-2\mathbb{E}\!\int_{t}\frac{dF_{S}(\mathbf{J_{t}})}{\sqrt{n}}\sqrt{1+\frac{\operatorname{Tr}(\Sigma(\mathbf{J_{t}}))}{\|\nabla F_{S}(\mathbf{J_{t}})\|_{2}^{2}}}\!\right)\!}_{\text{Complexity of Learning Trajectory}}+\mathcal{O}(\!\eta_{m}\!) (11)

The "Bias of Training Set" refers to the disparity between the characteristics of the training set and those of the broader population. To measure this difference, we use the distance between the norm of the popular gradient and that of the training set gradient, as specified in Assumption 3.4. The "Diversity of Training Set" can be understood as the variation among the samples in the training set, which in turn affects the quality of the training data. The ratio Bias of Training SetDiversity of Training Set\frac{\text{Bias of Training Set}}{\text{Diversity of Training Set}} gives us the property of information conveyed by the training set. It is important to consider the properties of the training set, as the data may not contribute equally to the generalization[35]. The detail version of the equation can be found in Theorem 3.6.

3.3.2 Comparison with Uniform Stability Results

Here, we compare our modelling with uniform stability [11] from several perspectives in Table 2.

Table 2: Comparison with Uniform Stability Methods. The β\beta refers to the β\beta-smooth assumption (see Definition 3.2). SS denotes the training set.
Uniform Stability[11] Ours
Assumption 𝐰𝒲z𝒵f(𝐰,z)L\forall\mathbf{w}\in\mathcal{W}\quad\forall z^{\prime}\in\mathcal{Z}\quad\|\nabla f(\mathbf{w},z^{\prime})\|\leq L 𝐰{𝐉𝐭|t}𝔼zμf(𝐰,z)γ𝔼zSf(𝐰,z)\forall\mathbf{w}\in\{\mathbf{J_{t}}|t\in\mathbb{N}\}\quad\|\mathbb{E}_{z^{\prime}\sim\mu}\nabla f(\mathbf{w},z^{\prime})\|\leq\gamma\|\mathbb{E}_{z\in S}\nabla f(\mathbf{w},z)\|
Modelling Method of SGD Epoch Structure Full Batch Gradient + Stochastic Noise
Batch Size 1 n\leq n
Trajectory Information in Bound Learning rate and number of training steps Values in Trajectory (gradient norm and covariance)
Perspective Stability of Algorithm Complexity of Learning Trajectory

The concept of uniform stability is commonly used to evaluate the generalization ability of SGD by assessing its stability when a single training sample is altered. Our primary point of comparison is Hardt et al. [11], as their work is considered the most representative analysis of the stability of SGD.

First, the assumption of Uniform Stability requires the gradient norm of all input samples for all weights to be bounded by LL, whereas our assumption only limits the expectation of the gradients at the weights along the learning trajectory. Second, Uniform Stability uses an epoch structure to model stochastic gradient descent, whereas our approach regards each stochastic gradient step as a full-batch gradient step with added stochastic noise. The epoch structure complicates the modelling process because it requires a consideration of sampling; as a result, Hardt et al. [11] only consider the setting with batch size 1. Third, the bound of Uniform Stability only uses hyperparameter settings such as the learning rate and the number of training steps. In contrast, our bound contains more trajectory-related information, such as the gradient norm and covariance. Finally, Uniform Stability derives the generalization bound from the stability of the algorithm, while our approach leverages the complexity of the learning trajectory. In summary, there are notable differences between our approach and Uniform Stability in the assumptions made, the modelling process, the type of information used in the bound, and the overall perspective.

4 Experiments

4.1 Tightness of Our Bounds

Table 3: Numeric comparison with stability-based works on toy examples. The value for Zhang et al. [42] is large because our bound and that of Hardt et al. [11] depend on L2β\frac{L^{2}}{\beta}, while Zhang et al. [42] depends on L2L^{2}. LL and β\beta are usually large numbers.
Gen Error Ours Hardt et al. [11] Zhang et al. [42]
1.49 3.62 4.04 4417

In a toy dataset setting, we compare our generalization bound with stability-based methods.

Reasons for toy examples

1) Some values in the bounds are hard to calculate. Computing β\beta (under the β\beta-smooth assumption) and LL (under the LL-Lipschitz assumption) in stability-based works, as well as the values of 𝕍\mathbb{V} and γ\gamma in our proposed bound, is challenging. 2) Stability-based methods require a batch size of 1. Training with batch size 1 and the learning rate setting ηt=1t\eta_{t}=\frac{1}{t} is difficult on complex datasets.

Construction of the toy examples

In the following, we discuss the construction of the toy dataset used to compare the tightness of the generalization bounds. The training data is Xtr={xi}i=1nX_{tr}=\{x_{i}\}_{i=1}^{n}, where each xix_{i} is sampled from the Gaussian distribution 𝒩(0,𝐈d)\mathcal{N}(0,\mathbf{I}_{d}). Sampling 𝐰~𝒩(0,𝐈d)\tilde{\mathbf{w}}\sim\mathcal{N}(0,\mathbf{I}_{d}), the ground truth is generated by yi=1if𝐰~Txi>0else 0y_{i}=1\ \ \text{if}\ \ \tilde{\mathbf{w}}^{\mathrm{T}}x_{i}>0\ \ \text{else}\ \ 0. The learned weights are denoted as 𝐰\mathbf{w}. The prediction y~\tilde{y} is calculated as y~i=𝐰Txi\tilde{y}_{i}=\mathbf{w}^{\mathrm{T}}x_{i}. The loss for a single data point is li=yi𝐰Txi2l_{i}=\left\|y_{i}-\mathbf{w}^{\mathrm{T}}x_{i}\right\|_{2}, and the training loss is =i=1nli\mathcal{L}=\sum_{i=1}^{n}l_{i}. The test data is Xte={xi}X_{te}=\{x^{\prime}_{i}\}, where xi𝒩(0,𝐈d)x^{\prime}_{i}\sim\mathcal{N}(0,\mathbf{I}_{d}). We use 100 samples for training and 1,000 samples for evaluation. The model is trained using SGD for 200 epochs.
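The following is a minimal NumPy sketch of this toy setup. The input dimension, the learning-rate constant, and the reading of the per-sample loss as a squared error are illustrative assumptions rather than the paper's exact configuration.

```python
# Minimal sketch of the toy example above. The dimension d, the learning-rate
# constant, and reading the per-sample loss as squared error are assumptions.
import numpy as np

rng = np.random.default_rng(0)
n, n_te, d = 100, 1000, 20
w_true = rng.normal(size=d)                          # the w-tilde above
X_tr = rng.normal(size=(n, d))
y_tr = (X_tr @ w_true > 0).astype(float)             # labels in {0, 1}
X_te = rng.normal(size=(n_te, d))
y_te = (X_te @ w_true > 0).astype(float)

beta = np.linalg.eigvalsh(X_tr.T @ X_tr / n).max()   # smoothness proxy
w, step = np.zeros(d), 1
for epoch in range(200):                             # SGD, batch size 1
    for i in rng.permutation(n):
        eta = 1.0 / (beta * step)                    # eta_t = 1 / (beta * t)
        w -= eta * (w @ X_tr[i] - y_tr[i]) * X_tr[i] # grad of 0.5*(w^T x - y)^2
        step += 1

gen_err = np.mean((X_te @ w - y_te) ** 2) - np.mean((X_tr @ w - y_tr) ** 2)
print(gen_err)
```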

We evaluate the tightness of our bound by comparing our results with the bounds of Hardt et al. [11] and Zhang et al. [42]. We set the learning rate as ηt=1βt\eta_{t}=\frac{1}{\beta t}. Our reasons for comparing with these two papers are: 1) Hardt et al. [11] is a representative study, 2) both papers have theorems using a learning rate setting of ηt=𝒪(1t)\eta_{t}=\mathcal{O}(\frac{1}{t}), which aligns with Corollary 3.8 in our paper, and 3) they do not assume convexity. The generalization bounds we compare are Corollary 3.8 from our paper, Theorem 3.12 from Hardt et al. [11], and Theorem 5 from Zhang et al. [42].

Our results are given in Table 3. Our bound is tighter under this setting.

4.2 Capturing the trend of generalization error

In this section, we 1) conduct deep learning experiments to verify Assumption 3.4 and 2) verify whether our proposed generalization bound can capture the changes in generalization error. In these experiments, we mainly consider the term 𝒞(𝐉𝐭)2i=0tdFS(𝐉𝐢)n1+Tr(Σ(𝐉𝐢))FS(𝐉𝐢)22\mathcal{C}(\mathbf{J_{t}})\triangleq-2\int_{i=0}^{t}\frac{dF_{S}(\mathbf{J_{i}})}{\sqrt{n}}\sqrt{1+\frac{\operatorname{Tr}(\Sigma(\mathbf{J_{i}}))}{\|\nabla F_{S}(\mathbf{J_{i}})\|_{2}^{2}}}. We omit the terms γ\gamma^{\prime} and 𝕍m\mathbb{V}_{m}, because all the trajectory-related information that we want to explore is contained in 𝒞(𝐉𝐭)\mathcal{C}(\mathbf{J_{t}}). Capturing the trend of generalization error is regarded as an important problem in Nagarajan [21]. Unless further specified, we use the default setting of experiments on the CIFAR-10 dataset [14] with the VGG13 [33] network. The experimental details for each figure can be found in Appendix C.2.
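The quantity 𝒞(𝐉𝐭) can be accumulated online during training from quantities that are already available: the training loss and the per-sample gradients. Below is a minimal sketch (illustrative, not the authors' code) using the discrete-integral convention ∫U dV ≜ Σ_t U(t)(V(t+1)−V(t)) from Section 3; the interface, where grads_per_step[t] holds the (n, d) per-sample gradient matrix at 𝐉𝐭, is an assumption.

```python
# Minimal sketch of accumulating C(J_t) during training, using the discrete
# integral convention from Section 3. losses[t] = F_S(J_t); grads_per_step[t]
# is the (n, d) matrix of per-sample gradients at J_t (an assumed interface).
import numpy as np

def trajectory_complexity(losses, grads_per_step):
    n = grads_per_step[0].shape[0]
    C = [0.0]
    for t in range(len(losses) - 1):
        G = grads_per_step[t]
        g = G.mean(axis=0)                          # grad F_S(J_t)
        tr_Sigma = (G ** 2).sum() / n - g @ g       # Tr(Sigma(J_t)), Eq. (4)
        factor = np.sqrt(1.0 + tr_Sigma / (g @ g))
        C.append(C[-1] - 2.0 * factor / np.sqrt(n) * (losses[t + 1] - losses[t]))
    return C                                        # C[t] approximates C(J_t)
```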

Our observations are:

  • Assumption 3.4 is valid when SGD is not exhibiting extreme overfitting.

  • The term 𝒞(𝐉𝐭)\mathcal{C}(\mathbf{J_{t}}) depicts how the generalization error varies along the training process, and it can also track the changes in generalization error when adjustments are made to learning rates and label noise levels.

Figure 1: Exploration of Assumption 3.4 for different datasets. γ~t\widetilde{\gamma}_{t} is stable before the training loss reaches a relatively small value. The assumption holds if training is stopped before extreme overfitting. A relaxed assumption and its corresponding generalization bound are given in Appendix B for the extreme-overfitting situation.
Exploring Assumption 3.4 for different datasets during the training process

To explore Assumption 3.4, we define γtFμ(𝐉𝐭)FS(𝐉𝐭)\gamma_{t}\triangleq\frac{\|\nabla F_{\mu}(\mathbf{J_{t}})\|}{\|\nabla F_{S}(\mathbf{J_{t}})\|} and γ~tFS(𝐉𝐭)FS(𝐉𝐭)\widetilde{\gamma}_{t}\triangleq\frac{\|\nabla F_{S^{\prime}}(\mathbf{J_{t}})\|}{\|\nabla F_{S}(\mathbf{J_{t}})\|}, where SS^{\prime} is another dataset i.i.d. sampled from the distribution μ\mu. Because SS^{\prime} is independent of SS, we have γ~tγt\widetilde{\gamma}_{t}\approx\gamma_{t}. We find that γ~t\widetilde{\gamma}_{t} is stable around 1 during the early stage of training (Figure 1). When the training loss reaches a relatively small value, γ~t\widetilde{\gamma}_{t} increases as training continues. This phenomenon remains consistent across the CIFAR-10, CIFAR-100, and SVHN datasets. The γ\gamma in Assumption 3.4 can be assigned as γ=maxtγ~t\gamma=\max_{t}\widetilde{\gamma}_{t}. We can always find such a γ\gamma if the optimizer is not extremely overfitting. In the extremely overfitting case, we can use the relaxed theorem in Appendix B to bound the generalization error.
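As an illustration, γ̃_t only requires gradient norms on the two sets at the current weights. A minimal sketch follows, assuming a grad_fn(w, data) helper (hypothetical, not an API from the paper) that returns the dataset-averaged gradient.

```python
# Minimal sketch of gamma_tilde_t = ||grad F_{S'}(J_t)|| / ||grad F_S(J_t)||.
# grad_fn(w, data) is an assumed helper returning the averaged gradient.
import numpy as np

def gamma_tilde(grad_fn, w, S, S_prime):
    return np.linalg.norm(grad_fn(w, S_prime)) / np.linalg.norm(grad_fn(w, S))

# Assumption 3.4's gamma can then be taken as the max over recorded steps:
# gamma = max(gamma_tilde(grad_fn, w_t, S, S_prime) for w_t in trajectory)
```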

The bound captures the trend of generalization error during the training process
Figure 2: Exploration of 𝒞(𝐉𝐭)\mathcal{C}(\mathbf{J_{t}}) during the training process. Left: The curve FS(𝐉𝐭)+𝒞(𝐉𝐭)F_{S}(\mathbf{J_{t}})+\mathcal{C}(\mathbf{J_{t}}) exhibits a pattern comparable to the curve FS(𝐉𝐭)F_{S^{\prime}}(\mathbf{J_{t}}). Center: After the early stage, 𝒞(𝐉𝐭)\mathcal{C}(\mathbf{J_{t}}) and FS(𝐉𝐭)FS(𝐉𝐭)F_{S^{\prime}}(\mathbf{J_{t}})-F_{S}(\mathbf{J_{t}}) have a similar trend. Right: The value of d𝒞(𝐉𝐭)dFS(𝐉𝐭)\frac{d\mathcal{C}(\mathbf{J_{t}})}{dF_{S}(\mathbf{J_{t}})} along the training process.

Both the generalization error and 𝒞(𝐉𝐭)\mathcal{C}(\mathbf{J_{t}}) change as training continues. Therefore, we want to verify whether they correlate with each other during the training process. Here, we use the term FS(𝐉𝐭)FS(𝐉𝐭)F_{S^{\prime}}(\mathbf{J_{t}})-F_{S}(\mathbf{J_{t}}) to approximate the generalization error. We find that FS(𝐉𝐭)FS(𝐉𝐭)F_{S^{\prime}}(\mathbf{J_{t}})-F_{S}(\mathbf{J_{t}}) has a trend similar to 𝒞(𝐉𝐭)\mathcal{C}(\mathbf{J_{t}}) (Figure 2, Center). Moreover, the curve FS(𝐉𝐭)+𝒞(𝐉𝐭)F_{S}(\mathbf{J_{t}})+\mathcal{C}(\mathbf{J_{t}}) exhibits a pattern comparable to the curve FS(𝐉𝐭)F_{S^{\prime}}(\mathbf{J_{t}}) (Figure 2, Left). To explore whether 𝒞(𝐉𝐭)\mathcal{C}(\mathbf{J_{t}}) reveals the influence of the change of FS(𝐉𝐭)F_{S}(\mathbf{J_{t}}) on the generalization error, we plot d𝒞(𝐉𝐭)dFS(𝐉𝐭)\frac{d\mathcal{C}(\mathbf{J_{t}})}{dF_{S}(\mathbf{J_{t}})} during the training process (Figure 2, Right). d𝒞(𝐉𝐭)dFS(𝐉𝐭)\frac{d\mathcal{C}(\mathbf{J_{t}})}{dF_{S}(\mathbf{J_{t}})} increases slowly during the early stage of training but surges rapidly afterward. This discovery aligns with our intuition about overfitting.

Figure 3: 𝒞(𝐉𝐓)\mathcal{C(\mathbf{J_{T}})} correlates with the generalization error FS(𝐉𝐭)FS(𝐉𝐭)F_{S^{\prime}}(\mathbf{J_{t}})-F_{S}(\mathbf{J_{t}}). Left: 𝒞(𝐉𝐭)\mathcal{C}(\mathbf{J_{t}}) and the generalization error under different label noise levels. Right: 𝒞(𝐉𝐭)\mathcal{C}(\mathbf{J_{t}}) and the generalization error under different learning rates. 𝒞(𝐉𝐭)\mathcal{C}(\mathbf{J_{t}}) can capture the trend of generalization error caused by the learning rate when the learning rate is small. Appendix E provides a proof that a large learning rate results in a smaller proposed generalization bound. Further discussion of why a small learning rate leads to a larger generalization error can be found in Li et al. [17] and Barrett and Dherin [3].
The complexity of the learning trajectory correlates with the generalization error

In Figure 3, we carry out experiments under various settings. Each data point in the figure represents the average of three repeated experiments. The results demonstrate that both the generalization error and 𝒞(𝐉𝐭)\mathcal{C}(\mathbf{J_{t}}) increase as the level of label noise is raised (Figure 3, Left). Another set of experiments measures 𝒞(𝐉𝐭)\mathcal{C}(\mathbf{J_{t}}) and the generalization error for different learning rates and finds that 𝒞(𝐉𝐭)\mathcal{C}(\mathbf{J_{t}}) captures the trend of the generalization error (Figure 3, Right). The reasons why a larger learning rate results in a smaller generalization error have been explored in Li et al. [17] and Barrett and Dherin [3]. Additionally, Appendix E discusses why a larger learning rate can lead to a smaller 𝒞(𝐉𝐭)\mathcal{C}(\mathbf{J_{t}}).

5 Limitation

Our method requires the assumption of a small learning rate. However, this assumption is also commonly used in previous works. For example, Hardt et al. [11], Zhang et al. [42], and Zhou et al. [43] explicitly require that the learning rate be small and decayed at a rate of 𝒪(1t)\mathcal{O}(\frac{1}{t}). Some methods have no explicit requirement but show that a large learning rate pushes the generalization bounds to a trivial point. For example, the generalization bounds in [5, 16] have a term t=1Tηt2\sum_{t=1}^{T}\eta_{t}^{2} that does not decay as the data size nn increases; the value of this term is non-negligible when the learning rate is large. The small learning rate assumption widens the gap between theory and practice, and eliminating it is an important direction for future work.

6 Conclusion

In this study, we investigate the relation between learning trajectories and the generalization capabilities of Deep Neural Networks (DNNs) from a unique standpoint. We show that learning trajectories can serve as reliable indicators of DNNs’ generalization performance. To understand the relation between the learning trajectory and generalization error, we analyze how each update step impacts the generalization error. Based on this, we propose a novel generalization bound that encompasses extensive information related to the learning trajectory. The conducted experiments validate our newly proposed assumption. Experimental findings reveal that our method effectively captures the generalization error throughout the training process. Furthermore, our approach can also track changes in generalization error when adjustments are made to learning rates and label noise levels.

7 Acknowledgement

We thank all the anonymous reviewers for their valuable comments. This work was supported in part by the National Natural Science Foundation of China (Grant No. 62088102).

References

  • Ahn et al. [2022] K. Ahn, J. Zhang, and S. Sra. Understanding the unstable convergence of gradient descent. In International Conference on Machine Learning, pages 247–257. PMLR, 2022.
  • Banerjee et al. [2022] A. Banerjee, T. Chen, X. Li, and Y. Zhou. Stability based generalization bounds for exponential family langevin dynamics. arXiv preprint arXiv:2201.03064, 2022.
  • Barrett and Dherin [2020] D. G. Barrett and B. Dherin. Implicit gradient regularization. arXiv preprint arXiv:2009.11162, 2020.
  • Bartlett and Mendelson [2002] P. L. Bartlett and S. Mendelson. Rademacher and gaussian complexities: Risk bounds and structural results. Journal of Machine Learning Research, 3(Nov):463–482, 2002.
  • Bassily et al. [2020] R. Bassily, V. Feldman, C. Guzmán, and K. Talwar. Stability of stochastic gradient descent on nonsmooth convex losses. Advances in Neural Information Processing Systems, 33:4381–4391, 2020.
  • Belkin et al. [2019] M. Belkin, D. Hsu, S. Ma, and S. Mandal. Reconciling modern machine-learning practice and the classical bias–variance trade-off. Proceedings of the National Academy of Sciences, 116(32):15849–15854, 2019.
  • Bousquet and Elisseeff [2002] O. Bousquet and A. Elisseeff. Stability and generalization. The Journal of Machine Learning Research, 2:499–526, 2002.
  • Chandramoorthy et al. [2022] N. Chandramoorthy, A. Loukas, K. Gatmiry, and S. Jegelka. On the generalization of learning algorithms that do not converge. arXiv preprint arXiv:2208.07951, 2022.
  • Cohen et al. [2021] J. M. Cohen, S. Kaur, Y. Li, J. Z. Kolter, and A. Talwalkar. Gradient descent on neural networks typically occurs at the edge of stability. arXiv preprint arXiv:2103.00065, 2021.
  • Haghifam et al. [2020] M. Haghifam, J. Negrea, A. Khisti, D. M. Roy, and G. K. Dziugaite. Sharpened generalization bounds based on conditional mutual information and an application to noisy, iterative algorithms. Advances in Neural Information Processing Systems, 33:9925–9935, 2020.
  • Hardt et al. [2016] M. Hardt, B. Recht, and Y. Singer. Train faster, generalize better: Stability of stochastic gradient descent. In International conference on machine learning, pages 1225–1234. PMLR, 2016.
  • Jastrzebski et al. [2017] S. Jastrzebski, Z. Kenton, D. Arpit, N. Ballas, A. Fischer, Y. Bengio, and A. Storkey. Three factors influencing minima in sgd. arXiv preprint arXiv:1711.04623, 2017.
  • Jastrzebski et al. [2021] S. Jastrzebski, D. Arpit, O. Astrand, G. B. Kerg, H. Wang, C. Xiong, R. Socher, K. Cho, and K. J. Geras. Catastrophic fisher explosion: Early phase fisher matrix impacts generalization. In International Conference on Machine Learning, pages 4772–4784. PMLR, 2021.
  • Krizhevsky et al. [2009] A. Krizhevsky, G. Hinton, et al. Learning multiple layers of features from tiny images. 2009.
  • Lei [2022] Y. Lei. Stability and generalization of stochastic optimization with nonconvex and nonsmooth problems. arXiv preprint arXiv:2206.07082, 2022.
  • Lei and Ying [2020] Y. Lei and Y. Ying. Fine-grained analysis of stability and generalization for stochastic gradient descent. In International Conference on Machine Learning, pages 5809–5819. PMLR, 2020.
  • Li et al. [2019] Y. Li, C. Wei, and T. Ma. Towards explaining the regularization effect of initial large learning rate in training neural networks. Advances in Neural Information Processing Systems, 32, 2019.
  • Luo et al. [2022] X. Luo, B. Luo, and J. Li. Generalization bounds for gradient methods via discrete and continuous prior. Advances in Neural Information Processing Systems, 35:10600–10614, 2022.
  • McAllester [1999] D. A. McAllester. Pac-bayesian model averaging. In Proceedings of the twelfth annual conference on Computational learning theory, pages 164–170, 1999.
  • Mohri et al. [2018] M. Mohri, A. Rostamizadeh, and A. Talwalkar. Foundations of machine learning. MIT press, 2018.
  • Nagarajan [2021] V. Nagarajan. Explaining generalization in deep learning: progress and fundamental limits. arXiv preprint arXiv:2110.08922, 2021.
  • Nagarajan and Kolter [2019] V. Nagarajan and J. Z. Kolter. Uniform convergence may be unable to explain generalization in deep learning. Advances in Neural Information Processing Systems, 32, 2019.
  • Negrea et al. [2019] J. Negrea, M. Haghifam, G. K. Dziugaite, A. Khisti, and D. M. Roy. Information-theoretic generalization bounds for sgld via data-dependent estimates. Advances in Neural Information Processing Systems, 32, 2019.
  • Neu et al. [2021] G. Neu, G. K. Dziugaite, M. Haghifam, and D. M. Roy. Information-theoretic generalization bounds for stochastic gradient descent. In Conference on Learning Theory, pages 3526–3545. PMLR, 2021.
  • Nikolakakis et al. [2022] K. E. Nikolakakis, F. Haddadpour, A. Karbasi, and D. S. Kalogerias. Beyond lipschitz: Sharp generalization and excess risk bounds for full-batch gd. arXiv preprint arXiv:2204.12446, 2022.
  • Oksendal [2013] B. Oksendal. Stochastic differential equations: an introduction with applications. Springer Science & Business Media, 2013.
  • Park et al. [2022] S. Park, U. Simsekli, and M. A. Erdogdu. Generalization bounds for stochastic gradient descent via localized ϵ\epsilon-covers. Advances in Neural Information Processing Systems, 35:2790–2802, 2022.
  • Pensia et al. [2018] A. Pensia, V. Jog, and P.-L. Loh. Generalization error bounds for noisy, iterative algorithms. In 2018 IEEE International Symposium on Information Theory (ISIT), pages 546–550. IEEE, 2018.
  • Russo and Zou [2016] D. Russo and J. Zou. Controlling bias in adaptive data analysis using information theory. In Artificial Intelligence and Statistics, pages 1232–1240. PMLR, 2016.
  • Russo and Zou [2019] D. Russo and J. Zou. How much does your data exploration overfit? controlling bias via information usage. IEEE Transactions on Information Theory, 66(1):302–323, 2019.
  • Sagun et al. [2017] L. Sagun, U. Evci, V. U. Guney, Y. Dauphin, and L. Bottou. Empirical analysis of the hessian of over-parametrized neural networks. arXiv preprint arXiv:1706.04454, 2017.
  • Shalev-Shwartz and Ben-David [2014] S. Shalev-Shwartz and S. Ben-David. Understanding machine learning: From theory to algorithms. Cambridge university press, 2014.
  • Simonyan and Zisserman [2014] K. Simonyan and A. Zisserman. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014.
  • Simsekli et al. [2019] U. Simsekli, L. Sagun, and M. Gurbuzbalaban. A tail-index analysis of stochastic gradient noise in deep neural networks. In International Conference on Machine Learning, pages 5827–5837. PMLR, 2019.
  • Sorscher et al. [2022] B. Sorscher, R. Geirhos, S. Shekhar, S. Ganguli, and A. Morcos. Beyond neural scaling laws: beating power law scaling via data pruning. Advances in Neural Information Processing Systems, 35:19523–19536, 2022.
  • Vapnik [1991] V. Vapnik. Principles of risk minimization for learning theory. Advances in neural information processing systems, 4, 1991.
  • Vapnik [1999] V. Vapnik. The nature of statistical learning theory. Springer science & business media, 1999.
  • Vapnik and Chervonenkis [2015] V. N. Vapnik and A. Y. Chervonenkis. On the uniform convergence of relative frequencies of events to their probabilities. In Measures of complexity, pages 11–30. Springer, 2015.
  • Xu and Raginsky [2017] A. Xu and M. Raginsky. Information-theoretic analysis of generalization capability of learning algorithms. Advances in Neural Information Processing Systems, 30, 2017.
  • Zhang et al. [2021] C. Zhang, S. Bengio, M. Hardt, B. Recht, and O. Vinyals. Understanding deep learning (still) requires rethinking generalization. Communications of the ACM, 64(3):107–115, 2021.
  • Zhang et al. [2022a] J. Zhang, H. Li, S. Sra, and A. Jadbabaie. Neural network weights do not converge to stationary points: An invariant measure perspective. In International Conference on Machine Learning, pages 26330–26346. PMLR, 2022a.
  • Zhang et al. [2022b] Y. Zhang, W. Zhang, S. Bald, V. Pingali, C. Chen, and M. Goswami. Stability of sgd: Tightness analysis and improved bounds. In Uncertainty in Artificial Intelligence, pages 2364–2373. PMLR, 2022b.
  • Zhou et al. [2022] Y. Zhou, Y. Liang, and H. Zhang. Understanding generalization error of sgd in nonconvex optimization. Machine Learning, 111(1):345–375, 2022.
  • Zhu et al. [2018] Z. Zhu, J. Wu, B. Yu, L. Wu, and J. Ma. The anisotropic noise in stochastic gradient descent: Its behavior of escaping from sharp minima and regularization effects. arXiv preprint arXiv:1803.00195, 2018.

Appendix A Proof of Theorem 3.6

We restate Equation (7) and Equation (8):

Fμ(𝐉𝐓)FS(𝐉𝐓)=Fμ(𝐉𝟎)FS(𝐉𝟎)(i)+t=1T[(Fμ(𝐉𝐭)Fμ(𝐉𝐭𝟏))(FS(𝐉𝐭)FS(𝐉𝐭𝟏))](ii)t,F_{\mu}(\mathbf{\mathbf{J_{T}}})-F_{S}(\mathbf{J_{T}})=\underbrace{F_{\mu}(\mathbf{J_{0}})-F_{S}(\mathbf{J_{0}})}_{(i)}+\sum_{t=1}^{T}\underbrace{[(F_{\mu}(\mathbf{J_{t}})-F_{\mu}(\mathbf{J_{t-1}}))-(F_{S}(\mathbf{J_{t}})-F_{S}(\mathbf{J_{t-1}}))]}_{(ii)_{t}}, (12)

and

𝔼[Fμ(𝐉𝐓)FS(𝐉𝐓)]=𝔼t=1T(ii)t.\mathbb{E}[F_{\mu}(\mathbf{\mathbf{J_{T}}})-F_{S}(\mathbf{J_{T}})]=\mathbb{E}\sum_{t=1}^{T}(ii)_{t}. (13)

Using a Taylor expansion of the function f()f(\cdot), we have:

f(𝐉𝐭)f(𝐉𝐭𝟏)=(𝐉𝐭𝐉𝐭𝟏)Tf(𝐉𝐭𝟏)+𝒪(𝐉𝐭𝐉𝐭𝟏2).f(\mathbf{J_{t}})-f(\mathbf{J_{t-1}})=(\mathbf{J_{t}}-\mathbf{J_{t-1}})^{\mathrm{T}}\nabla f(\mathbf{J_{t-1}})+\mathcal{O}(\|\mathbf{J_{t}}-\mathbf{J_{t-1}}\|^{2}). (14)

Therefore, we can define (ii)tlin(ii)^{lin}_{t} as:

(ii)tlin(𝐉𝐭𝐉𝐭𝟏)T(Fμ(𝐉𝐭𝟏)FS(𝐉𝐭𝟏)).(ii)^{lin}_{t}\triangleq(\mathbf{J_{t}}-\mathbf{J_{t-1}})^{\mathrm{T}}(\nabla F_{\mu}(\mathbf{J_{t-1}})-\nabla F_{S}(\mathbf{J_{t-1}})). (15)

The (ii)t(ii)_{t} can be decomposed as (ii)t=(ii)tlin+(ii)tnl(ii)_{t}=(ii)_{t}^{lin}+(ii)_{t}^{nl}, where (ii)tnl(ii)t(ii)tlin(ii)_{t}^{nl}\triangleq(ii)_{t}-(ii)_{t}^{lin}.

Then Equation (13) can be decomposed as:

𝔼[Fμ(𝐉𝐓)FS(𝐉𝐓)]=𝔼t=1T(ii)tlingenlin(𝐉𝐓)+𝔼t=1T(ii)tnlgennl(𝐉𝐓).\mathbb{E}[F_{\mu}(\mathbf{\mathbf{J_{T}}})-F_{S}(\mathbf{J_{T}})]=\mathbb{E}\underbrace{\sum_{t=1}^{T}(ii)^{lin}_{t}}_{\operatorname{gen}^{lin}(\mathbf{J_{T}})}+\mathbb{E}\underbrace{\sum_{t=1}^{T}(ii)^{nl}_{t}}_{\operatorname{gen}^{nl}(\mathbf{J_{T}})}. (16)
Proposition A.1.

For the gradient descent or the stochastic gradient descent algorithm, we have:

𝔼[genlin(𝐉𝐓)]=𝔼[t=0T1ηtFS(𝐉𝐭)T(FS(𝐉𝐭)Fμ(𝐉𝐭))].\mathbb{E}[\operatorname{gen}^{lin}(\mathbf{J_{T}})]=\mathbb{E}[\sum_{t=0}^{T-1}\eta_{t}\nabla F_{S}(\mathbf{J_{t}})^{\mathrm{T}}(\nabla F_{S}(\mathbf{J_{t}})-\nabla F_{\mu}(\mathbf{J_{t}}))]. (17)

If T=𝒪(1ηm)T=\mathcal{O}(\frac{1}{\eta_{m}}), we have:

|gennl(𝐉𝐓)|=𝒪(ηm),|\operatorname{gen}^{nl}(\mathbf{J_{T}})|=\mathcal{O}(\eta_{m}), (18)

where $\eta_{m}\triangleq\max_{t}\eta_{t}$, and we have:

\lim_{n\rightarrow\infty}|\operatorname{gen}^{nl}(\mathbf{J_{T}})|=0.   (19)
Remark A.2.

Furthermore, we give an experimental exploration of $\operatorname{gen}^{nl}(\mathbf{J_{T}})$ in Appendix C.3. We discover that if the optimizer does not enter the EoS (Edge of Stability) regime [9], we have $\operatorname{gen}^{nl}(\mathbf{J_{T}})\approx 0$. A commonly used assumption in stability-based generalization theories is $\eta_{m}\leq\frac{2}{\beta}$. For gradient descent, the maximum eigenvalue of the Hessian hovers above $\frac{2}{\eta}$ once the optimizer enters EoS. This indicates that the assumption $\eta_{m}\leq\frac{2}{\beta}$ is valid only when the optimizer does not enter EoS. In addition, we observe in Section 4 that the proposed bound effectively captures the generalization error trend under commonly used experimental settings.

Proof.

Analysis of $\operatorname{gen}^{lin}(\mathbf{J_{T}})$.

Since $\epsilon(\mathbf{w})=C(\mathbf{w})^{\frac{1}{2}}\epsilon^{\prime}$ and $\mathbb{E}[\epsilon^{\prime}]=0$ (detailed in Equation (3) and Equation (5)), we get:

\mathbb{E}_{t-1}[\epsilon_{t}^{\mathrm{T}}(\nabla F_{\mu}(\mathbf{J_{t}})-\nabla F_{S}(\mathbf{J_{t}}))]   (20)
=\mathbb{E}_{t-1}[(\epsilon^{\prime})^{\mathrm{T}}(C(\mathbf{J_{t}})^{\frac{1}{2}})^{\mathrm{T}}(\nabla F_{\mu}(\mathbf{J_{t}})-\nabla F_{S}(\mathbf{J_{t}}))]
=\mathbb{E}_{t-1}[\epsilon^{\prime}]^{\mathrm{T}}\,\mathbb{E}_{t-1}[(C(\mathbf{J_{t}})^{\frac{1}{2}})^{\mathrm{T}}(\nabla F_{\mu}(\mathbf{J_{t}})-\nabla F_{S}(\mathbf{J_{t}}))]
=0.

Combining this with Formula (3), we have

\mathbb{E}[\operatorname{gen}^{lin}(\mathbf{J_{T}})]=\mathbb{E}\Big[\sum_{t=0}^{T-1}\eta_{t}\nabla F_{S}(\mathbf{J_{t}})^{\mathrm{T}}(\nabla F_{S}(\mathbf{J_{t}})-\nabla F_{\mu}(\mathbf{J_{t}}))\Big].   (21)

Analysis of $\operatorname{gen}^{nl}(\mathbf{J_{T}})$.

Here, we denote $M\triangleq\max_{t}\|\nabla F_{S}(\mathbf{J_{t}})\|$. According to the definition of $\operatorname{gen}^{nl}(\mathbf{J_{T}})$, we have:

|\operatorname{gen}^{nl}(\mathbf{J_{T}})|\leq|F_{\mu}(\mathbf{J_{T}})-F_{\mu}^{lin}(\mathbf{J_{T}})|+|F_{S}(\mathbf{J_{T}})-F_{S}^{lin}(\mathbf{J_{T}})|   (22)
=\Big|\sum_{t=0}^{T-1}\mathcal{O}(\|\mathbf{J_{t+1}}-\mathbf{J_{t}}\|^{2})\Big|+\Big|\sum_{t=0}^{T-1}\mathcal{O}(\|\mathbf{J_{t+1}}-\mathbf{J_{t}}\|^{2})\Big|
=\sum_{t=0}^{T-1}\mathcal{O}(\eta_{t}^{2}\|\nabla F_{S}(\mathbf{J_{t}})\|^{2})
=\mathcal{O}(T\eta_{m}^{2}M^{2})
=\mathcal{O}(\tfrac{1}{\eta_{m}}\eta_{m}^{2}M^{2})
=\mathcal{O}(\eta_{m}).

Because all the elements of the training set $S$ are sampled from the distribution $\mu$, we have $\lim_{n\rightarrow\infty}\nabla F_{S}(\mathbf{w})=\nabla F_{\mu}(\mathbf{w})$. Therefore:

\lim_{n\rightarrow\infty}(ii)^{lin}_{t}=\lim_{n\rightarrow\infty}(\mathbf{J_{t}}-\mathbf{J_{t-1}})^{\mathrm{T}}(\nabla F_{\mu}(\mathbf{J_{t-1}})-\nabla F_{S}(\mathbf{J_{t-1}}))=0.   (23)

Moreover, by the same argument, we also have:

\lim_{n\rightarrow\infty}\big[F_{\mu}(\mathbf{J_{T}})-F_{S}(\mathbf{J_{T}})\big]=0.   (24)

Because $F_{\mu}(\mathbf{J_{T}})-F_{S}(\mathbf{J_{T}})=\sum_{t=1}^{T}(ii)^{lin}_{t}+\sum_{t=1}^{T}(ii)^{nl}_{t}$, we have:

\lim_{n\rightarrow\infty}|\operatorname{gen}^{nl}(\mathbf{J_{T}})|=\lim_{n\rightarrow\infty}\Big|\sum_{t=1}^{T}(ii)^{nl}_{t}\Big|=\lim_{n\rightarrow\infty}\Big|F_{\mu}(\mathbf{J_{T}})-F_{S}(\mathbf{J_{T}})-\sum_{t=1}^{T}(ii)^{lin}_{t}\Big|=|0-0|=0.   (25) ∎

According to Equation (17), we analyze the generalization error of $\mathcal{F}_{\mathbf{J}|S}\triangleq\{\sum_{t=0}^{T-1}\mathbf{w_{t}}^{\mathrm{T}}\nabla f(\mathbf{J_{t}})\ |\ \mathbf{w_{t}}=\delta_{t}\frac{\nabla F_{S}(\mathbf{J_{t}})}{\|\nabla F_{S}(\mathbf{J_{t}})\|}\}$ as a proxy for the generalization error of the function trained with the SGD or GD algorithm. The value of $\operatorname{gen}^{lin}(\mathbf{J_{T}})$ equals the generalization error of $\mathcal{F}_{\mathbf{J}|S}$. To analyze $\mathcal{F}_{\mathbf{J}|S}$, we construct an additive linear space $\mathcal{L}_{\mathbf{J}|S}\triangleq\{\sum_{t=0}^{T-1}\mathbf{w_{t}}^{\mathrm{T}}\nabla f(\mathbf{J_{t}})\ |\ \|\mathbf{w_{t}}\|\leq\delta_{t}\}$, where $\delta_{t}\triangleq\eta_{t}\|\nabla F_{S}(\mathbf{J_{t}})\|$. Here, we use $\mathbf{J}|S$ to emphasize that $\mathbf{J}$ depends on $S$.

Under Assumption 3.4 (introduced in the main paper), we have the following lemma.

Lemma A.3.

Under Assumption 3.4, given $S\sim\mu^{n}$, let $\mathbf{J}=\mathcal{A}(S)$. We have:

\mathbb{E}[\operatorname{gen}^{lin}(\mathbf{J_{T}})]\leq 2\gamma^{\prime}\mathbb{V}_{m}\mathbb{E}R_{S}(\mathcal{L}_{\mathbf{J}|S}),   (26)

where $\mathbb{V}(\mathbf{w})=\frac{\|\nabla F_{S}(\mathbf{w})\|}{\mathbb{E}_{U\subset S}\|\frac{|U|}{n}\nabla F_{U}(\mathbf{w})-\frac{|S|-|U|}{n}\nabla F_{S/U}(\mathbf{w})\|}$, $\mathbb{V}_{m}=\max_{t}\mathbb{V}(\mathbf{J_{t}})$ and $\gamma^{\prime}=\max\{1,\max_{U\subset S;t}\frac{|U|\|\nabla F_{U}(\mathbf{J_{t}})\|}{n\|\nabla F_{S}(\mathbf{J_{t}})\|}\}\gamma$.

Proof.

For a function $h$, we define $h_{\mu}\triangleq\mathbb{E}_{z\sim\mu}[h(z)]$ and $h_{S}\triangleq\frac{1}{n}\sum_{z_{i}\in S}h(z_{i})$. Given a function space $H$, the maximum generalization error over the space is defined as $\Phi(S,H)\triangleq\sup_{h\in H}(h_{\mu}-h_{S})$. Then:

\Phi(S,H|S)=\sup_{h\in H|S}(h_{\mu}-h_{S})   (27)
=\sup_{h\in H|S}(\mathbb{E}_{S^{\prime}}h_{S^{\prime}}-h_{S})
\leq\mathbb{E}_{S^{\prime}}\sup_{h\in H|S}(h_{S^{\prime}}-h_{S})
=\mathbb{E}_{S^{\prime},\sigma}\sup_{h\in H|S}\Big(\frac{1}{n}\sum_{i}^{n}\sigma_{i}(h(z_{i}^{\sigma_{i}})-h(z_{i}^{-\sigma_{i}}))\Big)
\leq\mathbb{E}_{S^{\prime},\sigma}\sup_{h\in H|S}\Big(\frac{1}{n}\sum_{i}^{n}\sigma_{i}h(z_{i}^{\sigma_{i}})\Big)+\mathbb{E}_{S^{\prime},\sigma}\sup_{h\in H|S}\Big(\frac{1}{n}\sum_{i}^{n}\sigma_{i}h(z_{i}^{\sigma_{i}})\Big)
=2\,\mathbb{E}_{S^{\prime},\sigma}\sup_{h\in H|S}\Big(\frac{1}{n}\sum_{i}^{n}\sigma_{i}h(z_{i}^{\sigma_{i}})\Big),

where $S^{\prime}$ is another i.i.d. sample set drawn from $\mu^{n}$ and $\sigma$ denotes the Rademacher variable. The superscript $\sigma_{i}$ in $z_{i}^{\sigma_{i}}$ indicates whether $z_{i}^{\sigma_{i}}$ belongs to $S$ or $S^{\prime}$: if $\sigma_{i}=-1$, then $z_{i}^{\sigma_{i}}\in S$; otherwise $z_{i}^{\sigma_{i}}\in S^{\prime}$.

R_{S}(\mathcal{L}_{\mathbf{J}|S})\triangleq\mathbb{E}_{\sigma}\sup_{h\in\mathcal{L}_{\mathbf{J}|S}}\Big(\frac{1}{n}\sum_{i}^{n}\sigma_{i}h(z_{i})\Big)   (28)
=\mathbb{E}_{\sigma}\sup_{h\in\mathcal{L}_{\mathbf{J}|S}}\Big(\frac{1}{n}\big(\sum_{z\in S_{+}}h(z)-\sum_{z\in S_{-}}h(z)\big)\Big)
=\mathbb{E}_{\sigma}\Big(\frac{1}{n}\sum_{t=0}^{T-1}\delta_{t}\|g_{S_{+}}(\mathbf{J_{t}})-g_{S_{-}}(\mathbf{J_{t}})\|\Big),

where $S_{+}\triangleq\{z_{i}\ |\ \sigma_{i}=+1\}$, $S_{-}\triangleq\{z_{i}\ |\ \sigma_{i}=-1\}$, and $g_{S}(\mathbf{w})\triangleq|S|\nabla F_{S}(\mathbf{w})$. Similarly,

\mathbb{E}_{S^{\prime},\sigma}\sup_{h\in\mathcal{F}_{\mathbf{J}|S}}\Big(\frac{1}{n}\sum_{i}^{n}\sigma_{i}h(z_{i}^{\sigma_{i}})\Big)=\mathbb{E}_{S^{\prime},\sigma}\Big(\frac{1}{n}\big(\sum_{z\in S^{\prime}_{+}}h(z)-\sum_{z\in S_{-}}h(z)\big)\Big)   (29)
=\mathbb{E}_{S^{\prime},\sigma}\Big(\frac{1}{n}\sum_{t=0}^{T-1}\delta_{t}\frac{g_{S}(\mathbf{J_{t}})^{\mathrm{T}}}{\|g_{S}(\mathbf{J_{t}})\|}(g_{S^{\prime}_{+}}(\mathbf{J_{t}})-g_{S_{-}}(\mathbf{J_{t}}))\Big)
=\mathbb{E}_{\sigma}\Big(\frac{1}{n}\sum_{t=0}^{T-1}\delta_{t}\frac{g_{S}(\mathbf{J_{t}})^{\mathrm{T}}}{\|g_{S}(\mathbf{J_{t}})\|}(|S_{+}|\nabla F_{\mu}(\mathbf{J_{t}})-g_{S_{-}}(\mathbf{J_{t}}))\Big),

where $S^{\prime}_{+}$ is the subset of $S^{\prime}$ with $|S^{\prime}_{+}|=|S_{+}|$. Defining $k\triangleq\gamma^{\prime}\mathbb{V}_{m}$, we have:

k\,\mathbb{E}_{\sigma}\sup_{h\in\mathcal{L}_{\mathbf{J}|S}}\Big(\frac{1}{n}\sum_{i}^{n}\sigma_{i}h(z_{i})\Big)-\mathbb{E}_{S^{\prime},\sigma}\sup_{h\in\mathcal{F}_{\mathbf{J}|S}}\Big(\frac{1}{n}\sum_{i}^{n}\sigma_{i}h(z_{i}^{\sigma_{i}})\Big)   (30)
=k\,\mathbb{E}_{\sigma}\Big(\frac{1}{n}\sum_{t=0}^{T-1}\delta_{t}\|g_{S_{+}}(\mathbf{J_{t}})-g_{S_{-}}(\mathbf{J_{t}})\|\Big)-\mathbb{E}_{S^{\prime},\sigma}\Big(\frac{1}{n}\sum_{t=0}^{T-1}\delta_{t}\frac{g_{S}(\mathbf{J_{t}})^{\mathrm{T}}}{\|g_{S}(\mathbf{J_{t}})\|}(g_{S^{\prime}_{+}}(\mathbf{J_{t}})-g_{S_{-}}(\mathbf{J_{t}}))\Big)
=k\,\mathbb{E}_{\sigma}\Big(\frac{1}{n}\sum_{t=0}^{T-1}\delta_{t}\|g_{S_{+}}(\mathbf{J_{t}})-g_{S_{-}}(\mathbf{J_{t}})\|\Big)-\mathbb{E}_{\sigma}\Big(\frac{1}{n}\sum_{t=0}^{T-1}\delta_{t}\frac{g_{S}(\mathbf{J_{t}})^{\mathrm{T}}}{\|g_{S}(\mathbf{J_{t}})\|}(|S_{+}|\nabla F_{\mu}(\mathbf{J_{t}})-g_{S_{-}}(\mathbf{J_{t}}))\Big)
\geq k\,\mathbb{E}_{\sigma}\Big(\frac{1}{n}\sum_{t=0}^{T-1}\delta_{t}\|g_{S_{+}}(\mathbf{J_{t}})-g_{S_{-}}(\mathbf{J_{t}})\|\Big)-\mathbb{E}_{\sigma}\Big(\frac{1}{n}\sum_{t=0}^{T-1}\delta_{t}\||S_{+}|\nabla F_{\mu}(\mathbf{J_{t}})-g_{S_{-}}(\mathbf{J_{t}})\|\Big)
\geq k\,\frac{1}{n}\sum_{t=0}^{T-1}\delta_{t}\,\mathbb{E}_{\sigma}\|g_{S_{+}}(\mathbf{J_{t}})-g_{S_{-}}(\mathbf{J_{t}})\|-\sum_{t=0}^{T-1}\delta_{t}\gamma^{\prime}\|\nabla F_{S}(\mathbf{J_{t}})\|
\geq 0.

Therefore, combining Equations (27) and (30), we have $\mathbb{E}[\operatorname{gen}^{lin}(\mathbf{J_{T}})]\leq 2\gamma^{\prime}\mathbb{V}_{m}\mathbb{E}R_{S}(\mathcal{L}_{\mathbf{J}|S})$. ∎

Lemma A.4.

Given $\mathbf{J}=\mathcal{A}(S)$, $R_{S}(\mathcal{L}_{\mathbf{J}|S})$ can be upper bounded as:

\mathbb{E}R_{S}(\mathcal{L}_{\mathbf{J}|S})\leq-\mathbb{E}\int_{t}\frac{dF_{S}(\mathbf{J_{t}})}{\sqrt{n}}\sqrt{1+\frac{\operatorname{Tr}(\Sigma(\mathbf{J_{t}}))}{\|\nabla F_{S}(\mathbf{J_{t}})\|_{2}^{2}}}.   (31)
Proof.

Let us start with the calculation of $R_{S}(\{\mathbf{w}^{\mathrm{T}}\nabla f(\mathbf{J_{t}})\ |\ \|\mathbf{w}\|\leq\delta\})$:

R_{S}(\{\mathbf{w}^{\mathrm{T}}\nabla f(\mathbf{J_{t}})\ |\ \|\mathbf{w}\|\leq\delta\})=\frac{1}{n}\mathbb{E}_{\sigma}\Big(\sup_{\|\mathbf{w}\|\leq\delta}\mathbf{w}^{\mathrm{T}}\sum_{i=1}^{n}\sigma_{i}\nabla f(\mathbf{J_{t}},z_{i})\Big)   (32)
=\frac{\delta}{n}\mathbb{E}_{\sigma}\Big(\sqrt{\big\|\sum_{i=1}^{n}\sigma_{i}\nabla f(\mathbf{J_{t}},z_{i})\big\|^{2}}\Big)
\leq\frac{\delta}{n}\sqrt{\mathbb{E}_{\sigma}\big\|\sum_{i=1}^{n}\sigma_{i}\nabla f(\mathbf{J_{t}},z_{i})\big\|^{2}}
\overset{\blacktriangle}{\leq}\frac{\delta}{n}\sqrt{\mathbb{E}_{\sigma}\sum_{i=1}^{n}\big\|\sigma_{i}\nabla f(\mathbf{J_{t}},z_{i})\big\|^{2}}
=\frac{\delta}{n}\sqrt{\sum_{i=1}^{n}\big\|\nabla f(\mathbf{J_{t}},z_{i})\big\|^{2}},

where $\blacktriangle$ uses the fact that $\mathbb{E}[\sigma_{i}\sigma_{j}]=0$ for $i\neq j$.
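As a quick numerical illustration of the $\blacktriangle$ step (ours, not part of the proof), the following Python snippet checks $\mathbb{E}_{\sigma}\|\sum_{i}\sigma_{i}g_{i}\|\leq\sqrt{\sum_{i}\|g_{i}\|^{2}}$ on synthetic vectors; the shapes and seed are arbitrary choices:

```python
import numpy as np

# Monte Carlo check of E_sigma || sum_i sigma_i g_i || <= sqrt( sum_i ||g_i||^2 ),
# i.e., Jensen's inequality plus E[sigma_i sigma_j] = 0 for i != j.
rng = np.random.default_rng(0)
g = rng.standard_normal((50, 10))                # n = 50 stand-ins for per-sample gradients
sigma = rng.choice([-1.0, 1.0], size=(100_000, 50))
lhs = np.linalg.norm(sigma @ g, axis=1).mean()   # E_sigma || sum_i sigma_i g_i ||
rhs = np.sqrt((g ** 2).sum())                    # sqrt( sum_i ||g_i||^2 )
print(f"{lhs:.3f} <= {rhs:.3f}")                 # the inequality holds
```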

Because $\mathbf{w_{t}}$ is independent of $\mathbf{w_{t^{\prime}}}$ for $t\neq t^{\prime}$, the supremum decomposes across $t$ and we have:

R_{S}(\mathcal{L}_{\mathbf{J}|S})=R_{S}(\{f(\mathbf{J_{0}})+\sum_{t=0}^{T-1}\mathbf{w_{t}}^{\mathrm{T}}\nabla f(\mathbf{J_{t}})\ |\ \|\mathbf{w_{t}}\|\leq\delta_{t}\})   (33)
=\sum_{t=0}^{T-1}R_{S}(\{\mathbf{w_{t}}^{\mathrm{T}}\nabla f(\mathbf{J_{t}})\ |\ \|\mathbf{w_{t}}\|\leq\delta_{t}\})
\leq\sum_{t=0}^{T-1}\frac{\delta_{t}}{n}\sqrt{\sum_{i=1}^{n}\|\nabla f(\mathbf{J_{t}},z_{i})\|^{2}}.

The trace of the covariance of the gradient noise can be calculated as:

\operatorname{Tr}[\Sigma(\mathbf{w})]=\operatorname{Tr}\Big[\frac{1}{n}\sum_{i=1}^{n}\nabla f(\mathbf{w},z_{i})\nabla f(\mathbf{w},z_{i})^{\mathrm{T}}-\nabla F_{S}(\mathbf{w})\nabla F_{S}(\mathbf{w})^{\mathrm{T}}\Big]   (34)
=\frac{1}{n}\sum_{i=1}^{n}\operatorname{Tr}[\nabla f(\mathbf{w},z_{i})\nabla f(\mathbf{w},z_{i})^{\mathrm{T}}]-\operatorname{Tr}[\nabla F_{S}(\mathbf{w})\nabla F_{S}(\mathbf{w})^{\mathrm{T}}]
=\frac{1}{n}\sum_{i=1}^{n}\|\nabla f(\mathbf{w},z_{i})\|^{2}-\|\nabla F_{S}(\mathbf{w})\|^{2}.

Plugging Equation (34) and $\delta_{t}\triangleq\eta_{t}\|\nabla F_{S}(\mathbf{J_{t}})\|$ into Equation (33), we have:

R_{S}(\mathcal{L}_{\mathbf{J}|S})\leq\sum_{t=0}^{T-1}\frac{\delta_{t}}{n}\sqrt{\sum_{i=1}^{n}\|\nabla f(\mathbf{J_{t}},z_{i})\|^{2}}   (35)
=\sum_{t=0}^{T-1}\frac{\delta_{t}}{\sqrt{n}}\sqrt{\operatorname{Tr}[\Sigma(\mathbf{J_{t}})]+\|\nabla F_{S}(\mathbf{J_{t}})\|^{2}}
=\sum_{t=0}^{T-1}\frac{\eta_{t}\|\nabla F_{S}(\mathbf{J_{t}})\|}{\sqrt{n}}\sqrt{\operatorname{Tr}[\Sigma(\mathbf{J_{t}})]+\|\nabla F_{S}(\mathbf{J_{t}})\|^{2}}.

When $\eta_{t}$ is small, $\delta_{t}\approx-\mathbb{E}_{\epsilon}\frac{(\mathbf{J_{t+1}}-\mathbf{J_{t}})^{\mathrm{T}}\nabla F_{S}(\mathbf{J_{t}})}{\|\nabla F_{S}(\mathbf{J_{t}})\|}\approx-\mathbb{E}_{\epsilon}\frac{F_{S}(\mathbf{J_{t+1}})-F_{S}(\mathbf{J_{t}})}{\|\nabla F_{S}(\mathbf{J_{t}})\|}$ holds; therefore we have:

\mathbb{E}R_{S}(\mathcal{L}_{\mathbf{J}|S})\leq\mathbb{E}\sum_{t=0}^{T-1}\frac{\delta_{t}}{n}\sqrt{\sum_{i=1}^{n}\|\nabla f(\mathbf{J_{t}},z_{i})\|_{2}^{2}}   (36)
\approx-\mathbb{E}\sum_{t=0}^{T-1}\frac{F_{S}(\mathbf{J_{t+1}})-F_{S}(\mathbf{J_{t}})}{\sqrt{n}}\sqrt{1+\frac{\operatorname{Tr}(\Sigma(\mathbf{J_{t}}))}{\|\nabla F_{S}(\mathbf{J_{t}})\|_{2}^{2}}}
\approx-\mathbb{E}\int_{t}\frac{dF_{S}(\mathbf{J_{t}})}{\sqrt{n}}\sqrt{1+\frac{\operatorname{Tr}(\Sigma(\mathbf{J_{t}}))}{\|\nabla F_{S}(\mathbf{J_{t}})\|_{2}^{2}}}.  ∎

Theorem A.5.

Under Assumption 3.4, given $S\sim\mu^{n}$, let $\mathbf{J}=\mathcal{A}(S)$, where $\mathcal{A}$ denotes the SGD or GD algorithm training for $T$ steps. We have:

\mathbb{E}[F_{\mu}(\mathbf{J_{T}})-F_{S}(\mathbf{J_{T}})]\leq-2\gamma^{\prime}\mathbb{V}_{m}\mathbb{E}\int_{t}\frac{dF_{S}(\mathbf{J_{t}})}{\sqrt{n}}\sqrt{1+\frac{\operatorname{Tr}(\Sigma(\mathbf{J_{t}}))}{\|\nabla F_{S}(\mathbf{J_{t}})\|_{2}^{2}}}+\mathcal{O}(\eta_{m}),   (37)

where $\mathbb{V}(\mathbf{w})=\frac{\|\nabla F_{S}(\mathbf{w})\|}{\mathbb{E}_{U\subset S}\|\frac{|U|}{n}\nabla F_{U}(\mathbf{w})-\frac{|S|-|U|}{n}\nabla F_{S/U}(\mathbf{w})\|}$, $\mathbb{V}_{m}=\max_{t}\mathbb{V}(\mathbf{J_{t}})$ and $\gamma^{\prime}=\max\{1,\max_{U\subset S;t}\frac{|U|\|\nabla F_{U}(\mathbf{J_{t}})\|}{n\|\nabla F_{S}(\mathbf{J_{t}})\|}\}\gamma$.

Proof.

We rewrite Equation (2), the update rule of SGD with batch size $b$, here:

\mathbf{J_{t}}=\mathbf{J_{t-1}}-\eta_{t}\nabla F_{S}(\mathbf{J_{t-1}})+\eta_{t}\epsilon_{t},   (38)

where we abbreviate $\epsilon(\mathbf{J_{t}})$ as $\epsilon_{t}$. We can then expand the function $f$ at $\mathbf{J_{T}}$ as:

f^{lin}(\mathbf{J_{T}})\triangleq f(\mathbf{J_{0}})+\sum_{t=0}^{T-1}(-\eta_{t}\nabla F_{S}(\mathbf{J_{t}})+\eta_{t}\epsilon_{t})^{\mathrm{T}}\nabla f(\mathbf{J_{t}})   (39)
=f(\mathbf{J_{0}})-\sum_{t=0}^{T-1}\eta_{t}\nabla F_{S}(\mathbf{J_{t}})^{\mathrm{T}}\nabla f(\mathbf{J_{t}})+\sum_{t=0}^{T-1}\eta_{t}\epsilon_{t}^{\mathrm{T}}\nabla f(\mathbf{J_{t}}).

Note that when the learning rate is small, we have $f(\mathbf{J_{T}})\approx f^{lin}(\mathbf{J_{T}})$.

The difference between the population value and the empirical value of the linearized function can be calculated as:

\mathbb{E}\Big[F_{\mu}(\mathbf{J_{0}})+\sum_{t=0}^{T-1}(-\eta_{t}\nabla F_{S}(\mathbf{J_{t}})+\eta_{t}\epsilon_{t})^{\mathrm{T}}\nabla F_{\mu}(\mathbf{J_{t}})\Big]-\mathbb{E}\Big[F_{S}(\mathbf{J_{0}})+\sum_{t=0}^{T-1}(-\eta_{t}\nabla F_{S}(\mathbf{J_{t}})+\eta_{t}\epsilon_{t})^{\mathrm{T}}\nabla F_{S}(\mathbf{J_{t}})\Big]   (40)
=\mathbb{E}\Big[F_{\mu}(\mathbf{J_{0}})-F_{S}(\mathbf{J_{0}})+\sum_{t=0}^{T-1}(-\eta_{t}\nabla F_{S}(\mathbf{J_{t}})+\eta_{t}\epsilon_{t})^{\mathrm{T}}(\nabla F_{\mu}(\mathbf{J_{t}})-\nabla F_{S}(\mathbf{J_{t}}))\Big]
=\mathbb{E}\Big[\sum_{t=0}^{T-1}\eta_{t}\nabla F_{S}(\mathbf{J_{t}})^{\mathrm{T}}(\nabla F_{S}(\mathbf{J_{t}})-\nabla F_{\mu}(\mathbf{J_{t}}))+\sum_{t=0}^{T-1}\eta_{t}\epsilon_{t}^{\mathrm{T}}(\nabla F_{\mu}(\mathbf{J_{t}})-\nabla F_{S}(\mathbf{J_{t}}))\Big]
\overset{\blacktriangle}{=}\mathbb{E}\Big[\sum_{t=0}^{T-1}\eta_{t}\nabla F_{S}(\mathbf{J_{t}})^{\mathrm{T}}(\nabla F_{S}(\mathbf{J_{t}})-\nabla F_{\mu}(\mathbf{J_{t}}))\Big]
\leq\Phi(S,\mathcal{F}_{\mathbf{J}|S}),

where $\blacktriangle$ uses the identity $\mathbb{E}[\epsilon_{t}^{\mathrm{T}}(\nabla F_{\mu}(\mathbf{J_{t}})-\nabla F_{S}(\mathbf{J_{t}}))]=0$ from Equation (20), and $\mathbb{E}[F_{\mu}(\mathbf{J_{0}})-F_{S}(\mathbf{J_{0}})]=0$ since $\mathbf{J_{0}}$ does not depend on $S$.

Because $\mathbb{E}[F_{\mu}(\mathbf{J_{T}})-F_{S}(\mathbf{J_{T}})]=\mathbb{E}[F^{lin}_{\mu}(\mathbf{J_{T}})+\mathcal{O}(\eta_{m})-F^{lin}_{S}(\mathbf{J_{T}})-\mathcal{O}(\eta_{m})]=\mathbb{E}[F^{lin}_{\mu}(\mathbf{J_{T}})-F^{lin}_{S}(\mathbf{J_{T}})]+\mathcal{O}(\eta_{m})$ (from Proposition A.1), applying Lemma A.3 and Lemma A.4 proves the theorem. ∎

Corollary A.6.

If the function $f(\cdot)$ is $\beta$-smooth, under Assumption 3.4, given $S\sim\mu^{n}$, let $\mathbf{J}=\mathcal{A}(S)$, $\eta_{t}=\frac{c}{\beta(t+1)}$, $M^{2}_{2}=\max_{t}\mathbb{E}_{t-1}(\|\nabla F_{S}(\mathbf{J_{t}})+\epsilon(\mathbf{J_{t}})\|^{2})$ and $M^{4}_{4}=\max_{t}\mathbb{E}_{t-1}(\|\nabla F_{S}(\mathbf{J_{t}})+\epsilon(\mathbf{J_{t}})\|^{4})$, where $\mathcal{A}$ denotes the SGD or GD algorithm training for $T$ steps. We have:

\mathbb{E}[F_{\mu}(\mathbf{J_{T}})-F_{S}(\mathbf{J_{T}})]\leq-2\gamma^{\prime}\mathbb{V}_{m}\mathbb{E}\int_{t}\frac{dF_{S}(\mathbf{J_{t}})}{\sqrt{n}}\sqrt{1+\frac{\operatorname{Tr}(\Sigma(\mathbf{J_{t}}))}{\|\nabla F_{S}(\mathbf{J_{t}})\|_{2}^{2}}}   (41)
+2c^{2}\gamma^{\prime}\mathbb{V}_{m}M_{4}^{2}\sqrt{\mathbb{E}\int_{t}\frac{dt}{n\beta^{2}(t+1)^{4}}\Big(1+\frac{\operatorname{Tr}(\Sigma(\mathbf{J_{t}}))}{\|\nabla F_{S}(\mathbf{J_{t}})\|_{2}^{2}}\Big)}
+2c^{2}\frac{M_{2}^{2}}{\beta},

where $\mathbb{V}(\mathbf{w})=\frac{\|\nabla F_{S}(\mathbf{w})\|}{\mathbb{E}_{U\subset S}\|\frac{|U|}{n}\nabla F_{U}(\mathbf{w})-\frac{|S|-|U|}{n}\nabla F_{S/U}(\mathbf{w})\|}$, $\mathbb{V}_{m}=\max_{t}\mathbb{V}(\mathbf{J_{t}})$ and $\gamma^{\prime}=\max\{1,\max_{U\subset S;t}\frac{|U|\|\nabla F_{U}(\mathbf{J_{t}})\|}{n\|\nabla F_{S}(\mathbf{J_{t}})\|}\}\gamma$.

Proof.

If $f(\cdot)$ is $\beta$-smooth, we have:

f(\mathbf{J_{t+1}})-f(\mathbf{J_{t}})\leq(\mathbf{J_{t+1}}-\mathbf{J_{t}})^{\mathrm{T}}\nabla f(\mathbf{J_{t}})+\frac{1}{2}\beta\|\mathbf{J_{t+1}}-\mathbf{J_{t}}\|^{2},   (42)
f(\mathbf{J_{t+1}})-f(\mathbf{J_{t}})\geq(\mathbf{J_{t+1}}-\mathbf{J_{t}})^{\mathrm{T}}\nabla f(\mathbf{J_{t}})-\frac{1}{2}\beta\|\mathbf{J_{t+1}}-\mathbf{J_{t}}\|^{2}.   (43)

Combining the two inequalities, we obtain:

|\operatorname{gen}^{nl}(\mathbf{J_{T}})|\leq|F_{\mu}(\mathbf{J_{T}})-F_{\mu}^{lin}(\mathbf{J_{T}})|+|F_{S}(\mathbf{J_{T}})-F_{S}^{lin}(\mathbf{J_{T}})|   (44)
\leq\frac{\beta}{2}\sum_{t=0}^{T-1}\|\mathbf{J_{t+1}}-\mathbf{J_{t}}\|^{2}+\frac{\beta}{2}\sum_{t=0}^{T-1}\|\mathbf{J_{t+1}}-\mathbf{J_{t}}\|^{2}
=\beta\sum_{t=0}^{T-1}\|\mathbf{J_{t+1}}-\mathbf{J_{t}}\|^{2}.

The generalization error can be divided into three parts:

\mathbb{E}[F_{\mu}(\mathbf{J_{T}})-F_{S}(\mathbf{J_{T}})]\leq-2\gamma^{\prime}\mathbb{V}_{m}\mathbb{E}\int_{t}\frac{dF_{S}(\mathbf{J_{t}})}{\sqrt{n}}\sqrt{1+\frac{\operatorname{Tr}(\Sigma(\mathbf{J_{t}}))}{\|\nabla F_{S}(\mathbf{J_{t}})\|_{2}^{2}}}   (45)
\underbrace{-2\gamma^{\prime}\mathbb{V}_{m}\mathbb{E}\int_{t}\frac{dF^{lin}_{S}(\mathbf{J_{t}})-dF_{S}(\mathbf{J_{t}})}{\sqrt{n}}\sqrt{1+\frac{\operatorname{Tr}(\Sigma(\mathbf{J_{t}}))}{\|\nabla F_{S}(\mathbf{J_{t}})\|_{2}^{2}}}}_{(A)}
+\underbrace{\beta\,\mathbb{E}\sum_{t=0}^{T-1}\|\mathbf{J_{t+1}}-\mathbf{J_{t}}\|^{2}}_{(B)}.

The term $(A)$ arises from replacing $F^{lin}_{S}(\mathbf{J_{t+1}})-F^{lin}_{S}(\mathbf{J_{t}})$ with $F_{S}(\mathbf{J_{t+1}})-F_{S}(\mathbf{J_{t}})$. The term $(B)$ is induced by $\operatorname{gen}^{nl}(\mathbf{J_{T}})$. We first give an upper bound of $(A)$ using $M_{4}^{4}$:

(A)\overset{(\star)}{\leq}2\gamma^{\prime}\mathbb{V}_{m}\mathbb{E}\sum_{t=0}^{T-1}\frac{\beta\|\mathbf{J_{t+1}}-\mathbf{J_{t}}\|^{2}}{\sqrt{n}}\sqrt{1+\frac{\operatorname{Tr}(\Sigma(\mathbf{J_{t}}))}{\|\nabla F_{S}(\mathbf{J_{t}})\|_{2}^{2}}}   (46)
\overset{(\star\star)}{\leq}2c^{2}\gamma^{\prime}\mathbb{V}_{m}\mathbb{E}\sum_{t=0}^{T-1}\frac{\|\nabla F_{S}(\mathbf{J_{t}})+\epsilon(\mathbf{J_{t}})\|^{2}}{\beta\sqrt{n}(t+1)^{2}}\sqrt{1+\frac{\operatorname{Tr}(\Sigma(\mathbf{J_{t}}))}{\|\nabla F_{S}(\mathbf{J_{t}})\|_{2}^{2}}}
=2c^{2}\gamma^{\prime}\mathbb{V}_{m}\sum_{t=0}^{T-1}\mathbb{E}_{t-1}\frac{\|\nabla F_{S}(\mathbf{J_{t}})+\epsilon(\mathbf{J_{t}})\|^{2}}{\beta\sqrt{n}(t+1)^{2}}\sqrt{1+\frac{\operatorname{Tr}(\Sigma(\mathbf{J_{t}}))}{\|\nabla F_{S}(\mathbf{J_{t}})\|_{2}^{2}}}
\overset{(\star\star\star)}{\leq}2c^{2}\gamma^{\prime}\mathbb{V}_{m}\sum_{t=0}^{T-1}\sqrt{\frac{\mathbb{E}_{t-1}\|\nabla F_{S}(\mathbf{J_{t}})+\epsilon(\mathbf{J_{t}})\|^{4}}{\beta^{2}n(t+1)^{4}}\mathbb{E}_{t-1}\Big(1+\frac{\operatorname{Tr}(\Sigma(\mathbf{J_{t}}))}{\|\nabla F_{S}(\mathbf{J_{t}})\|_{2}^{2}}\Big)}
\leq 2c^{2}\gamma^{\prime}\mathbb{V}_{m}\sum_{t=0}^{T-1}\sqrt{\frac{M_{4}^{4}}{\beta^{2}n(t+1)^{4}}\mathbb{E}_{t-1}\Big(1+\frac{\operatorname{Tr}(\Sigma(\mathbf{J_{t}}))}{\|\nabla F_{S}(\mathbf{J_{t}})\|_{2}^{2}}\Big)}
\leq 2c^{2}\gamma^{\prime}\mathbb{V}_{m}M_{4}^{2}\sqrt{\mathbb{E}\sum_{t=0}^{T-1}\frac{1}{\beta^{2}n(t+1)^{4}}\Big(1+\frac{\operatorname{Tr}(\Sigma(\mathbf{J_{t}}))}{\|\nabla F_{S}(\mathbf{J_{t}})\|_{2}^{2}}\Big)},

where $(\star)$ is due to Equation (44), $(\star\star)$ is due to the update rule of $\mathbf{J_{t}}$, and $(\star\star\star)$ is due to Hölder's inequality. Next, we use $M_{2}^{2}$ to give an upper bound for $(B)$:

(B)\leq\frac{c^{2}}{\beta}\sum_{t=0}^{T-1}\frac{1}{(t+1)^{2}}\mathbb{E}\|\nabla F_{S}(\mathbf{J_{t}})+\epsilon(\mathbf{J_{t}})\|^{2}   (47)
\leq\frac{c^{2}}{\beta}\sum_{t=0}^{T-1}\frac{1}{(t+1)^{2}}M_{2}^{2}
\leq\frac{c^{2}}{\beta}\Big(M_{2}^{2}+\sum_{t=1}^{T-1}\frac{1}{(t+1)^{2}}M_{2}^{2}\Big)
\leq\frac{c^{2}}{\beta}\Big(M_{2}^{2}+\sum_{t=1}^{T-1}\Big(\frac{1}{t}-\frac{1}{t+1}\Big)M_{2}^{2}\Big)
\leq\frac{c^{2}}{\beta}\Big(2M_{2}^{2}-\frac{1}{T}M_{2}^{2}\Big)
\leq 2c^{2}\frac{M_{2}^{2}}{\beta}.
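As a quick sanity check (ours, not part of the proof) of the telescoping step $\sum_{t=0}^{T-1}\frac{1}{(t+1)^{2}}\leq 2-\frac{1}{T}$ used above:

```python
# Numeric check: sum_{t=0}^{T-1} 1/(t+1)^2 <= 2 - 1/T for several T.
for T in (1, 2, 10, 1000):
    s = sum(1.0 / (t + 1) ** 2 for t in range(T))
    assert s <= 2 - 1 / T
    print(f"T={T}: {s:.4f} <= {2 - 1 / T:.4f}")
```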

Substituting the upper bounds of $(A)$ and $(B)$ into Equation (45), we obtain the result. ∎

Appendix B Relaxed Assumption and Corresponding Bound

Assumption B.1.

There exist values $\gamma$, $T_{0}$ and $\zeta$ such that for all $\mathbf{w}\in\{\mathbf{J_{t}}\ |\ t\in\mathbb{N}\ \wedge\ t<T_{0}\}$ we have $\|\nabla F_{\mu}(\mathbf{w})\|\leq\gamma\|\nabla F_{S}(\mathbf{w})\|$, and for all $\mathbf{w}\in\{\mathbf{J_{t}}\ |\ t\in\mathbb{N}\ \wedge\ t\geq T_{0}\}$ we have $\|\nabla F_{\mu}(\mathbf{w})\|\leq\gamma\|\nabla F_{S}(\mathbf{w})\|+\zeta$.

Theorem B.2.

Under Assumption B.1, given $S\sim\mu^{n}$, let $\mathbf{J}=\mathcal{A}(S)$, where $\mathcal{A}$ denotes the SGD or GD algorithm training for $T$ steps. We have:

\mathbb{E}[F_{\mu}(\mathbf{J_{T}})-F_{S}(\mathbf{J_{T}})]\leq-2\gamma^{\prime}\mathbb{V}_{m}\mathbb{E}\int_{t}\frac{dF_{S}(\mathbf{J_{t}})}{\sqrt{n}}\sqrt{1+\frac{\operatorname{Tr}(\Sigma(\mathbf{J_{t}}))}{\|\nabla F_{S}(\mathbf{J_{t}})\|_{2}^{2}}}+\frac{1}{2}\sum_{t=T_{0}}^{T}\eta_{t}\|\nabla F_{S}(\mathbf{J_{t}})\|\zeta+\mathcal{O}(\eta_{m}),   (48)

where $\mathbb{V}(\mathbf{w})=\frac{\|\nabla F_{S}(\mathbf{w})\|}{\mathbb{E}_{U\subset S}\|\frac{|U|}{n}\nabla F_{U}(\mathbf{w})-\frac{|S|-|U|}{n}\nabla F_{S/U}(\mathbf{w})\|}$, $\mathbb{V}_{m}=\max_{t}\mathbb{V}(\mathbf{J_{t}})$ and $\gamma^{\prime}=\max\{1,\max_{U\subset S;t}\frac{|U|\|\nabla F_{U}(\mathbf{J_{t}})\|}{n\|\nabla F_{S}(\mathbf{J_{t}})\|}\}\gamma$.

Proof.

Most of the proof is the same as in Appendix A, except that Equation (30) is replaced by:

k\,\mathbb{E}_{\sigma}\sup_{h\in\mathcal{L}_{\mathbf{J}|S}}\Big(\frac{1}{n}\sum_{i}^{n}\sigma_{i}h(z_{i})\Big)-\mathbb{E}_{S^{\prime},\sigma}\sup_{h\in\mathcal{F}_{\mathbf{J}|S}}\Big(\frac{1}{n}\sum_{i}^{n}\sigma_{i}h(z_{i}^{\sigma_{i}})\Big)   (49)
=k\,\mathbb{E}_{\sigma}\Big(\frac{1}{n}\sum_{t=0}^{T-1}\delta_{t}\|g_{S_{+}}(\mathbf{J_{t}})-g_{S_{-}}(\mathbf{J_{t}})\|\Big)-\mathbb{E}_{S^{\prime},\sigma}\Big(\frac{1}{n}\sum_{t=0}^{T-1}\delta_{t}\frac{g_{S}(\mathbf{J_{t}})^{\mathrm{T}}}{\|g_{S}(\mathbf{J_{t}})\|}(g_{S^{\prime}_{+}}(\mathbf{J_{t}})-g_{S_{-}}(\mathbf{J_{t}}))\Big)
=k\,\mathbb{E}_{\sigma}\Big(\frac{1}{n}\sum_{t=0}^{T-1}\delta_{t}\|g_{S_{+}}(\mathbf{J_{t}})-g_{S_{-}}(\mathbf{J_{t}})\|\Big)-\mathbb{E}_{\sigma}\Big(\frac{1}{n}\sum_{t=0}^{T-1}\delta_{t}\frac{g_{S}(\mathbf{J_{t}})^{\mathrm{T}}}{\|g_{S}(\mathbf{J_{t}})\|}(|S_{+}|\nabla F_{\mu}(\mathbf{J_{t}})-g_{S_{-}}(\mathbf{J_{t}}))\Big)
\geq k\,\mathbb{E}_{\sigma}\Big(\frac{1}{n}\sum_{t=0}^{T-1}\delta_{t}\|g_{S_{+}}(\mathbf{J_{t}})-g_{S_{-}}(\mathbf{J_{t}})\|\Big)-\mathbb{E}_{\sigma}\Big(\frac{1}{n}\sum_{t=0}^{T-1}\delta_{t}\||S_{+}|\nabla F_{\mu}(\mathbf{J_{t}})-g_{S_{-}}(\mathbf{J_{t}})\|\Big)
\geq k\,\frac{1}{n}\sum_{t=0}^{T-1}\delta_{t}\,\mathbb{E}_{\sigma}\|g_{S_{+}}(\mathbf{J_{t}})-g_{S_{-}}(\mathbf{J_{t}})\|-\sum_{t=0}^{T-1}\delta_{t}\gamma^{\prime}\|\nabla F_{S}(\mathbf{J_{t}})\|-\frac{1}{2}\sum_{t=T_{0}}^{T}\delta_{t}\zeta
\geq-\frac{1}{2}\sum_{t=T_{0}}^{T}\eta_{t}\|\nabla F_{S}(\mathbf{J_{t}})\|\zeta.  ∎

Remark B.3.

Compared with Theorem 3.6, there is an extra term $\frac{1}{2}\sum_{t=T_{0}}^{T}\eta_{t}\|\nabla F_{S}(\mathbf{J_{t}})\|\zeta$ here. Since the unrelaxed assumption $\|\nabla F_{\mu}(\mathbf{w})\|\leq\gamma\|\nabla F_{S}(\mathbf{w})\|$ fails only when $\|\nabla F_{S}(\mathbf{w})\|$ is relatively small, this extra term has a small value.

Appendix C Experiments

C.1 Calculation of $\mathcal{C}(\mathbf{J})$

To reduce the computational cost, we construct a randomly sampled subset $S_{\operatorname{sp}}=\{z^{\operatorname{sp}}_{1},...,z^{\operatorname{sp}}_{n_{\operatorname{sp}}}\}\subset S$:

\int_{t}\frac{dF_{S}(\mathbf{J_{t}})}{\sqrt{n}}\sqrt{1+\frac{\operatorname{Tr}(\Sigma(\mathbf{J_{t}}))}{\|\nabla F_{S}(\mathbf{J_{t}})\|_{2}^{2}}}=\int_{t}\frac{dF_{S}(\mathbf{J_{t}})}{\sqrt{n}}\sqrt{\frac{\sum_{i=1}^{n}\|\nabla f(\mathbf{J_{t}},z_{i})\|_{2}^{2}}{n}\frac{1}{\|\nabla F_{S}(\mathbf{J_{t}})\|_{2}^{2}}}
\approx\int_{t}\frac{dF_{S}(\mathbf{J_{t}})}{\sqrt{n}}\sqrt{\frac{\sum_{i=1}^{n_{\operatorname{sp}}}\|\nabla f(\mathbf{J_{t}},z^{\operatorname{sp}}_{i})\|_{2}^{2}}{n_{\operatorname{sp}}}\frac{1}{\|\nabla F_{S}(\mathbf{J_{t}})\|_{2}^{2}}}.

Denote the weights after $t$ epochs of training as $\mathbf{X_{t}}$. We can then roughly calculate $\mathcal{C}(\mathbf{J})$ as:

\sum_{t=1}^{T}\frac{F_{S}(\mathbf{X_{t}})-F_{S}(\mathbf{X_{t-1}})}{\sqrt{n}}\sqrt{\frac{\sum_{i=1}^{n_{\operatorname{sp}}}\|\nabla f(\mathbf{X_{t}},z^{\operatorname{sp}}_{i})\|_{2}^{2}}{n_{\operatorname{sp}}}\frac{1}{\|\nabla F_{S}(\mathbf{X_{t}})\|_{2}^{2}}}.
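For concreteness, the following PyTorch sketch estimates this sum along a toy training run. The model, data, and all names are our own illustrative choices (not the paper's released code), and the subset size $n_{\operatorname{sp}}$ is set arbitrarily:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(20, 64), nn.Tanh(), nn.Linear(64, 10))
loss_fn = nn.CrossEntropyLoss()
X, y = torch.randn(512, 20), torch.randint(0, 10, (512,))   # training set S (n = 512)
X_sp, y_sp = X[:32], y[:32]                                 # subset S_sp (n_sp = 32)

def grad_sq_norm(inputs, targets):
    """Squared norm of the gradient of the mean loss over (inputs, targets)."""
    model.zero_grad()
    loss_fn(model(inputs), targets).backward()
    return sum(p.grad.pow(2).sum().item() for p in model.parameters())

opt = torch.optim.SGD(model.parameters(), lr=0.05)
complexity = 0.0
prev_loss = loss_fn(model(X), y).item()                     # F_S(X_0)
for epoch in range(20):
    for i in range(0, len(X), 128):                         # one epoch of SGD
        opt.zero_grad()
        loss_fn(model(X[i:i+128]), y[i:i+128]).backward()
        opt.step()
    per_sample = [grad_sq_norm(X_sp[j:j+1], y_sp[j:j+1]) for j in range(len(X_sp))]
    full = grad_sq_norm(X, y)                               # ||grad F_S(X_t)||^2
    cur_loss = loss_fn(model(X), y).item()                  # F_S(X_t)
    ratio = (sum(per_sample) / len(per_sample)) / full      # ~ 1 + Tr(Sigma)/||grad F_S||^2
    complexity += (cur_loss - prev_loss) / len(X) ** 0.5 * ratio ** 0.5
    prev_loss = cur_loss

print("trajectory term:", complexity)  # the bound multiplies the negated sum by 2*gamma'*V_m
```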

C.2 Experimental Details

Here, we give the detailed experimental settings for each figure.

Figure 1  The learning rate is fixed at 0.05 throughout training. The batch size is 256. All experiments are trained for 100 epochs. The test accuracies for CIFAR-10, CIFAR-100, and SVHN are 87.64%, 55.08%, and 92.80%, respectively.

Figure 2  The initial learning rate is set to 0.05 with a batch size of 1024. We use a cosine annealing learning rate schedule to adjust the learning rate during training.

Figure 3  Each point is an average over three repeated experiments. We stop training when the training loss is smaller than 0.2.

C.3 Experimental Exploration of $\operatorname{gen}^{nl}(\mathbf{J_{T}})$

In this section, we investigate the conditions under which $\operatorname{gen}^{nl}(\mathbf{J_{T}})\approx 0$. Since directly calculating $\operatorname{gen}^{nl}(\mathbf{J_{T}})$ is challenging, we concentrate on its upper bound $|F_{\mu}(\mathbf{J_{T}})-F_{\mu}^{lin}(\mathbf{J_{T}})|+|F_{S}(\mathbf{J_{T}})-F_{S}^{lin}(\mathbf{J_{T}})|$.

We conduct the experiment on the cifar10-5k dataset with an fc-tanh network, following the settings of [9]. Cifar10-5k [9] is a subset of the CIFAR-10 dataset. Building upon the work of [1], we compute the Relative Progress Ratio (RP) and Test Relative Progress Ratio (TRP) throughout the training process. We first consider the case of gradient descent. The definitions of RP and TRP for gradient descent are as follows:

\operatorname{RP}(\mathbf{J_{t}})\triangleq\frac{F_{S}(\mathbf{J_{t+1}})-F_{S}(\mathbf{J_{t}})}{\eta\|\nabla F_{S}(\mathbf{J_{t}})\|^{2}},   (50)
\operatorname{TRP}(\mathbf{J_{t}})\triangleq\frac{F_{S^{\prime}}(\mathbf{J_{t+1}})-F_{S^{\prime}}(\mathbf{J_{t}})}{\eta\nabla F_{S}(\mathbf{J_{t}})^{\mathrm{T}}\nabla F_{S^{\prime}}(\mathbf{J_{t}})}.   (51)

Therefore, we have:

F_{S}(\mathbf{J_{0}})+\sum_{t=1}^{T}(F_{S}(\mathbf{J_{t}})-F_{S}(\mathbf{J_{t-1}}))-F_{S}(\mathbf{J_{0}})-\sum_{t=1}^{T}(\mathbf{J_{t}}-\mathbf{J_{t-1}})^{\mathrm{T}}\nabla F_{S}(\mathbf{J_{t-1}})   (52)
=\sum_{t=1}^{T}\big[(F_{S}(\mathbf{J_{t}})-F_{S}(\mathbf{J_{t-1}}))-(\mathbf{J_{t}}-\mathbf{J_{t-1}})^{\mathrm{T}}\nabla F_{S}(\mathbf{J_{t-1}})\big]
=\sum_{t=1}^{T}\big[(F_{S}(\mathbf{J_{t}})-F_{S}(\mathbf{J_{t-1}}))+\eta\|\nabla F_{S}(\mathbf{J_{t-1}})\|^{2}\big]
=\sum_{t=1}^{T}\big[\eta(1+\operatorname{RP}(\mathbf{J_{t-1}}))\|\nabla F_{S}(\mathbf{J_{t-1}})\|^{2}\big].

In the same way, we have:

F_{S^{\prime}}(\mathbf{J_{0}})+\sum_{t=1}^{T}(F_{S^{\prime}}(\mathbf{J_{t}})-F_{S^{\prime}}(\mathbf{J_{t-1}}))-F_{S^{\prime}}(\mathbf{J_{0}})-\sum_{t=1}^{T}(\mathbf{J_{t}}-\mathbf{J_{t-1}})^{\mathrm{T}}\nabla F_{S^{\prime}}(\mathbf{J_{t-1}})   (53)
=\sum_{t=1}^{T}\big[\eta(1+\operatorname{TRP}(\mathbf{J_{t-1}}))\nabla F_{S}(\mathbf{J_{t-1}})^{\mathrm{T}}\nabla F_{S^{\prime}}(\mathbf{J_{t-1}})\big].
Figure 4: Exploration of $\operatorname{gen}^{nl}(\mathbf{J_{T}})$ for gradient descent. Experiments are conducted on the cifar10-5k dataset with cross-entropy loss. The blue dashed line in the fourth row denotes $\frac{2}{\eta}$. Gradient descent enters the EoS regime when the sharpness rises above $\frac{2}{\eta}$. Both RP and TRP take values around -1 when the sharpness is below $\frac{2}{\eta}$.

Combining Equation (52) and Equation (53), we have:

|\operatorname{gen}^{nl}(\mathbf{J_{T}})|\leq\sum_{t=1}^{T}\eta\big[(1+\operatorname{TRP}(\mathbf{J_{t-1}}))|\nabla F_{S}(\mathbf{J_{t-1}})^{\mathrm{T}}\nabla F_{S^{\prime}}(\mathbf{J_{t-1}})|+(1+\operatorname{RP}(\mathbf{J_{t-1}}))\|\nabla F_{S}(\mathbf{J_{t-1}})\|^{2}\big].   (54)

Therefore, if $\operatorname{RP}(\mathbf{J_{t}})\approx-1$ and $\operatorname{TRP}(\mathbf{J_{t}})\approx-1$ for all $t$, then $|\operatorname{gen}^{nl}(\mathbf{J_{T}})|\approx 0$.

From Figure 4 we find that in the stable regime, where the sharpness is below $\frac{2}{\eta}$, we have $\operatorname{TRP}\approx\operatorname{RP}\approx-1$. Under a small learning rate, gradient descent does not enter the edge-of-stability regime, so $\operatorname{TRP}\approx\operatorname{RP}\approx-1$ throughout training and $\operatorname{gen}^{nl}(\mathbf{J_{T}})\approx 0$.
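To make the computation of RP and TRP concrete, here is a minimal PyTorch sketch of Equations (50) and (51) along a full-batch gradient descent run; the toy network, data, loss, and all names are illustrative assumptions rather than the paper's actual experimental setup:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(10, 32), nn.Tanh(), nn.Linear(32, 1))
mse = nn.MSELoss()
Xs, ys = torch.randn(256, 10), torch.randn(256, 1)  # training set S
Xp, yp = torch.randn(256, 10), torch.randn(256, 1)  # held-out set S'
eta = 0.05

def loss_and_grad(X, y):
    net.zero_grad()
    loss = mse(net(X), y)
    loss.backward()
    return loss.item(), torch.cat([p.grad.flatten() for p in net.parameters()])

for step in range(100):
    f_s, g_s = loss_and_grad(Xs, ys)                # F_S(J_t), grad F_S(J_t)
    f_p, g_p = loss_and_grad(Xp, yp)                # F_S'(J_t), grad F_S'(J_t)
    with torch.no_grad():                           # GD step: J_{t+1} = J_t - eta * grad F_S(J_t)
        i = 0
        for p in net.parameters():
            p -= eta * g_s[i:i + p.numel()].view_as(p)
            i += p.numel()
    rp = (mse(net(Xs), ys).item() - f_s) / (eta * g_s.dot(g_s).item())   # Eq. (50)
    trp = (mse(net(Xp), yp).item() - f_p) / (eta * g_s.dot(g_p).item())  # Eq. (51)
    if step % 20 == 0:
        print(f"step {step:3d}: RP = {rp:+.3f}, TRP = {trp:+.3f}")  # both ~ -1 in the stable regime
```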

Next, we consider the case of Stochastic Gradient Descent (SGD). Due to the stochastic estimation of the gradient, we need to rely on some approximations. Let $\mathbf{X_{t}^{i}}$ represent the weights after $t$ epochs and $i$ iterations of training. We assume a constant learning rate $\eta$ for SGD. The gradient is approximated as follows:

\eta\nabla F_{S}(\mathbf{X_{t}})\approx\frac{B}{n}(\mathbf{X_{t}}-\mathbf{X_{t+1}})=\frac{\eta B}{n}\sum_{i=1}^{\frac{n}{B}}\nabla F_{S}(\mathbf{X_{t}^{i}}),   (55)

and we approximate $\nabla F_{S^{\prime}}(\mathbf{X_{t}^{i}})$ as:

\eta\nabla F_{S^{\prime}}(\mathbf{X_{t}^{i}})\approx\eta\nabla F_{S^{\prime}}(\mathbf{X_{t}}).   (56)

Therefore, we have:

\operatorname{RP}(\mathbf{X_{t}})\approx\frac{\eta_{ef}(F_{S}(\mathbf{X_{t+1}})-F_{S}(\mathbf{X_{t}}))}{\|\mathbf{X_{t+1}}-\mathbf{X_{t}}\|^{2}},   (57)
\operatorname{TRP}(\mathbf{X_{t}})\approx\frac{F_{S^{\prime}}(\mathbf{X_{t+1}})-F_{S^{\prime}}(\mathbf{X_{t}})}{(\mathbf{X_{t}}-\mathbf{X_{t+1}})^{\mathrm{T}}\nabla F_{S^{\prime}}(\mathbf{X_{t}})},   (58)
where $\eta_{ef}\triangleq\frac{n}{B}\eta$ is the effective learning rate for SGD.
Figure 5: Exploration of $\operatorname{gen}^{nl}(\mathbf{J_{T}})$ for the SGD case. Here, the effective learning rate is defined as $\eta_{ef}\triangleq\frac{n}{B}\eta$. We still have $\operatorname{gen}^{nl}(\mathbf{J_{T}})\approx 0$ under a small learning rate.

Figure 5 shows that the conclusions for SGD are similar to those for GD, except that the conditions for entering the EoS regime differ.
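Under the approximations in Equations (55)-(58), RP and TRP for SGD can be estimated purely from epoch-boundary quantities. A small helper sketch (function and argument names are our own; w_prev and w_next are flattened weight snapshots at $\mathbf{X_{t}}$ and $\mathbf{X_{t+1}}$, and grad_p is the held-out gradient at $\mathbf{X_{t}}$):

```python
def sgd_rp(f_s_next, f_s, w_next, w_prev, eta, n, B):
    """Epoch-level RP estimate, Eq. (57): eta_ef * dF_S / ||X_{t+1} - X_t||^2."""
    eta_ef = (n / B) * eta                      # effective learning rate
    dw = w_next - w_prev
    return eta_ef * (f_s_next - f_s) / float(dw.dot(dw))

def sgd_trp(f_p_next, f_p, w_next, w_prev, grad_p):
    """Epoch-level TRP estimate, Eq. (58): dF_S' / <X_t - X_{t+1}, grad F_S'(X_t)>."""
    return (f_p_next - f_p) / float((w_prev - w_next).dot(grad_p))
```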

Appendix D Other Related Work

Table 4: Comparison of trajectory-based generalization bounds. Only our proposed method applies to SGD while conveying rich trajectory-related information.

Method | Conditions | Trajectory-Related Term (T.R.T)
Nikolakakis et al. [25] | Gradient descent, $\eta_{t}\leq\frac{c}{t}\leq\frac{1}{\beta}$, $\beta$-smooth | $\sum_{t=1}^{T}\eta_{t}\frac{1}{n}\sum_{i=1}^{n}\|\nabla f(\mathbf{J_{t}},z_{i})\|^{2}$
Neu et al. [24] | $\beta$-smooth, $\mathbb{E}[\|\nabla f(\mathbf{w},z)-\nabla F_{\mu}(\mathbf{w})\|]\leq v$, $f(\cdot)$ is sub-Gaussian | $\sqrt{T\eta^{2}}$
Park et al. [27] | Weak Lipschitz continuity, piecewise $\beta^{\prime}$-smooth, $f(\cdot)$ is bounded, $\eta<\frac{2}{\beta}$ | $T$
Ours | Small learning rate, $\|\nabla F_{\mu}(\mathbf{w})\|\leq\gamma\|\nabla F_{S}(\mathbf{w})\|$ | $\int_{t}dF_{S}(\mathbf{J_{t}})\sqrt{1+\frac{\operatorname{Tr}(\Sigma(\mathbf{J_{t}}))}{\|\nabla F_{S}(\mathbf{J_{t}})\|^{2}}}$

This part compares works that are not listed in Table 2. Table 4 lists other trajectory-based generalization bounds. Nikolakakis et al. [25] is a stability-based work designed mainly for the generalization of gradient descent. It removes the Lipschitz assumption and replaces it with the term $\sum_{t=1}^{T}\eta_{t}\frac{1}{n}\sum_{i=1}^{n}\|\nabla f(\mathbf{J_{t}},z_{i})\|^{2}$ in the generalization bound, which helps enrich the trajectory information in the bound. The limitation of this work is that it only applies to gradient descent and is hard to extend to stochastic gradient descent. Neu et al. [24] adapt the information-theoretic generalization bound to stochastic gradient descent. Theorem 1 in Neu et al. [24] contains rich information about the learning trajectory, but most of it involves $\nabla F_{\mu}(\mathbf{w})$, which is unavailable in practice. Therefore, we mainly consider the result of Corollary 2 in Neu et al. [24], which removes the term $\nabla F_{\mu}(\mathbf{w})$ via the assumptions listed in Table 4. For this corollary, the only remaining trajectory information is $\sqrt{T\eta^{2}}$. Although Neu et al. [24] do not require the small learning rate assumption, their bound contains the dimension of the model, which is large for deep neural networks. Compared with these works, our proposed method has the advantage that it both reveals rich information about the learning trajectory and applies to stochastic gradient descent.

Chandramoorthy et al. [8] analyze generalization behavior based on statistical algorithmic stability. Their generalization bound can be applied to algorithms that do not converge. Let $S^{(i)}$ be the dataset obtained by replacing $z_{i}$ in $S$ with another sample $z_{i}^{\prime}$ drawn from the distribution $\mu$. The generalization bound relies on the stability measure $m\triangleq\sup\{\frac{1}{T}\sum_{t=0}^{T-1}f(\mathbf{J_{t}}|S,z)-\frac{1}{T}\sum_{t=0}^{T-1}f(\mathbf{J_{t}}|S^{(i)},z)\ |\ z\in\mathcal{Z},i\in[n]\}$. We do not compare directly with this method because the calculation of $m$ relies on $S^{(i)}$, which contains a sample outside of $S$; we therefore treat this result as an intermediate one. More assumptions are needed to remove this dependence on information about unseen samples, i.e., samples outside the set $S$.

Appendix E Effect of Learning Rate and Stochastic Noise

In this part, we analyze how the learning rate and the stochastic noise jointly affect our proposed generalization bound. Specifically, we denote by $p_{t}(\mathbf{w})$ the distribution of $\mathbf{J_{t}}$ during training over multiple training steps. Following the work of [12], we consider the following SDE as an approximation:

\mathrm{d}\mathbf{w}=-\nabla F_{S}(\mathbf{w})\mathrm{d}t+\sqrt{\eta}\,C^{\frac{1}{2}}\mathrm{d}\mathbf{W}(t).   (59)

The SDE can be regarded as the continuous counterpart of Equation (3) when the distribution of the noise term $\epsilon^{\prime}$ in Equation (3) is set to a Gaussian distribution. The influence of the noise $\epsilon$ on $p_{t}(\mathbf{w})$ is characterized by the following theorem.

Theorem E.1.

When the update of the weight $\mathbf{w}$ follows Equation (59) and the covariance matrix $C$ is the Hessian of a function with a scalar output, we have:

\frac{\partial p_{t}(\mathbf{w})}{\partial t}=\sum_{i=1}^{d}\frac{\partial}{\partial\mathbf{w_{i}}}\Big[\nabla F_{S}(\mathbf{w})p_{t}(\mathbf{w})+\frac{\eta}{2}\big[\nabla\operatorname{Tr}(C(\mathbf{w}))+\underbrace{C(\mathbf{w})\nabla_{w}\log(p_{t}(\mathbf{w}))}_{\text{damping factor}}\big]p_{t}(\mathbf{w})\Big].   (60)
Remark E.2.

Previous studies ([44, 31, 12]) show that the covariance matrix $C$ is approximately equal to the Hessian of the loss function with respect to the parameters of the DNN. Thus, the above condition that the covariance matrix $C$ is the Hessian of a scalar-valued function is easily satisfied. Formula (60) contains three parts. The term $\nabla F_{S}(\mathbf{w})p_{t}(\mathbf{w})$ increases the probability of the parameters being located in regions of the parameter space with low $F_{S}(\mathbf{w})$. $\nabla\operatorname{Tr}(C(\mathbf{w}))$ and $C(\mathbf{w})\nabla_{w}\log(p_{t}(\mathbf{w}))$ usually counteract each other: $\nabla\operatorname{Tr}(C(\mathbf{w}))$ increases the probability of the parameters being located in regions with low $\operatorname{Tr}(C(\mathbf{w}))$, while $C(\mathbf{w})\nabla_{w}\log(p_{t}(\mathbf{w}))$ serves as a damping factor that prevents the probability from concentrating in a small region. Therefore, a larger learning rate exerts a stronger force driving the weights toward regions with lower $\operatorname{Tr}(C(\mathbf{w}))$. According to Equation (5), this also gives a lower $\Sigma(\mathbf{w})$. As a result, a large learning rate yields a smaller bound value in Theorem 3.6.
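To make this decomposition concrete, the following numpy sketch evaluates the three 1-D drift contributions implied by Equation (60) for toy choices of $F_{S}$, $C$, and a Gaussian $p_{t}$ (all our own illustrative assumptions), showing that the $\operatorname{Tr}(C)$ force scales linearly with $\eta$:

```python
import numpy as np

# Drift velocity implied by Eq. (60): v(w) = -( F'(w) + (eta/2) TrC'(w) + (eta/2) C(w) (log p)'(w) ).
w = np.linspace(-2.0, 2.0, 401)
F_grad = 4 * w * (w**2 - 1)            # gradient of toy loss F_S(w) = (w^2 - 1)^2
C = 1.5 + np.tanh(w)                   # positive, increasing Tr(C(w))
dTrC = np.gradient(C, w[1] - w[0])     # d/dw Tr(C(w))
dlogp = -w                             # d/dw log p_t(w) for p_t = N(0, 1)

for eta in (0.05, 0.5):
    loss_force = -F_grad               # pulls w toward the minima of F_S
    noise_force = -0.5 * eta * dTrC    # pulls w toward low Tr(C) (here, w < 0)
    damping = -0.5 * eta * C * dlogp   # opposes concentration of p_t
    print(f"eta={eta}: mean Tr(C)-force = {noise_force.mean():+.4f}")
# The Tr(C)-force grows linearly with eta, matching the remark above.
```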

Proof.

Based on the condition described above, we can write $C(\mathbf{w})=\nabla\nabla G(\mathbf{w})$, where $G$ is a function with a scalar output.

We first prove that $\nabla\cdot C(\mathbf{w})=\nabla\operatorname{Tr}(C(\mathbf{w}))$ as follows:

[\nabla\cdot C(\mathbf{w})]_{j}=[\nabla\cdot\nabla\nabla G(\mathbf{w})]_{j}   (61)
=\sum_{i}\frac{\partial}{\partial w_{i}}\frac{\partial}{\partial w_{i}}\frac{\partial}{\partial w_{j}}G(\mathbf{w})
=\frac{\partial}{\partial w_{j}}\sum_{i}\frac{\partial}{\partial w_{i}}\frac{\partial}{\partial w_{i}}G(\mathbf{w})
=\frac{\partial}{\partial w_{j}}\operatorname{Tr}(C(\mathbf{w})).
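This identity can also be checked symbolically; a short sympy sketch with an arbitrary smooth scalar $G$ of our own choosing:

```python
import sympy as sp

# Check div(C) = grad(Tr C) for C = Hessian(G), as in Eq. (61).
w1, w2, w3 = sp.symbols("w1 w2 w3")
ws = [w1, w2, w3]
G = sp.exp(w1 * w2) + sp.sin(w2 * w3) + w1**2 * w3            # any smooth scalar G
C = sp.Matrix([[sp.diff(G, a, b) for b in ws] for a in ws])   # C = Hessian of G

div_C = sp.Matrix([sum(sp.diff(C[i, j], ws[i]) for i in range(3)) for j in range(3)])
grad_trC = sp.Matrix([sp.diff(C.trace(), a) for a in ws])
print(sp.simplify(div_C - grad_trC))                          # zero vector -> identity holds
```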

Thus $\nabla\cdot C=\nabla\operatorname{Tr}(C)$. According to the Fokker-Planck equation [26], we have:

\frac{\partial p_{t}(\mathbf{w})}{\partial t}=\sum_{i=1}^{d}\frac{\partial}{\partial\mathbf{w_{i}}}\big[\nabla F_{S}(\mathbf{w})p_{t}(\mathbf{w})\big]+\frac{1}{2}\eta\sum_{i=1}^{d}\frac{\partial}{\partial\mathbf{w_{i}}}\Big[\sum_{j}^{d}\frac{\partial}{\partial\mathbf{w_{j}}}\big[C(\mathbf{w})p_{t}(\mathbf{w})\big]\Big]   (62)
=\sum_{i=1}^{d}\frac{\partial}{\partial\mathbf{w_{i}}}\big[\nabla F_{S}(\mathbf{w})p_{t}(\mathbf{w})\big]+\frac{1}{2}\eta\sum_{i=1}^{d}\frac{\partial}{\partial\mathbf{w_{i}}}\big[p_{t}(\mathbf{w})\nabla\cdot C+p_{t}(\mathbf{w})C\nabla_{w}\log p_{t}(\mathbf{w})\big]
=\sum_{i=1}^{d}\frac{\partial}{\partial\mathbf{w_{i}}}\Big[\nabla F_{S}(\mathbf{w})p_{t}(\mathbf{w})+\frac{1}{2}\eta\big[\nabla\cdot C(\mathbf{w})+C(\mathbf{w})\nabla_{w}\log p_{t}(\mathbf{w})\big]p_{t}(\mathbf{w})\Big]
=\sum_{i=1}^{d}\frac{\partial}{\partial\mathbf{w_{i}}}\Big[\nabla F_{S}(\mathbf{w})p_{t}(\mathbf{w})+\frac{1}{2}\eta\big[\nabla\operatorname{Tr}(C(\mathbf{w}))+C(\mathbf{w})\nabla_{w}\log p_{t}(\mathbf{w})\big]p_{t}(\mathbf{w})\Big].

Therefore, the theorem is proven. ∎