
AdaSAM: Boosting Sharpness-Aware Minimization with Adaptive Learning Rate and Momentum for Training Deep Neural Networks

Hao Sun, Li Shen, Qihuang Zhong, Liang Ding, Shixiang Chen
Jingwei Sun, Jing Li, Guangzhong Sun, and Dacheng Tao 
Hao Sun, Jingwei Sun, Jing Li, and Guangzhong Sun are with the School of Computer Science and Technology, University of Science and Technology of China, Hefei, China, 230000. (E-mail: ustcsh@mail.ustc.edu.cn, sunjw@ustc.edu.cn, lj@ustc.edu.cn, gzsun@ustc.edu.cn.) Qihuang Zhong is with the School of Computer Science, Wuhan University, Hubei, 430000. (E-mail: zhongqihuang@whu.edu.cn) Li Shen, Liang Ding, Shixiang Chen, and Dacheng Tao are with JD Explore Academy, Beijing, 100000. (E-mail: mathshenli@gmail.com, liangding.liam@gmail.com, chenshxiang@gmail.com, dacheng.tao@gmail.com.)
Abstract

Sharpness-aware minimization (SAM) has been extensively explored as an optimizer that generalizes better when training deep neural networks, by introducing an extra perturbation step that flattens the loss landscape of deep learning models. Integrating SAM with an adaptive learning rate and momentum acceleration, dubbed AdaSAM, has already been explored empirically for training large-scale deep neural networks, but without theoretical guarantees, owing to the triple difficulty of analyzing the coupled perturbation step, adaptive learning rate, and momentum step. In this paper, we analyze the convergence rate of AdaSAM in the stochastic non-convex setting. We theoretically show that AdaSAM admits a $\mathcal{O}(1/\sqrt{bT})$ convergence rate, which achieves the linear speedup property with respect to the mini-batch size $b$. Specifically, to decouple the stochastic gradient steps from the adaptive learning rate and the perturbed gradient, we introduce a delayed second-order momentum term so that they become independent when taking expectations during the analysis. We then bound them by showing that the adaptive learning rate has a limited range, which makes our analysis feasible. To the best of our knowledge, we are the first to provide a non-trivial convergence rate for SAM with an adaptive learning rate and momentum acceleration. Finally, we conduct experiments on several NLP tasks, which show that AdaSAM achieves superior performance compared with the SGD, AMSGrad, and SAM optimizers.

Index Terms:
Sharpness-aware minimization, Adaptive learning rate, Non-convex optimization, Linear speedup.

I Introduction

Sharpness-aware minimization (SAM) [1] is a powerful optimizer for training large-scale deep learning models that explicitly minimizes the gap between training performance and generalization performance. It has achieved remarkable results in training various deep neural networks, including ResNets [2, 1, 3], vision transformers [4, 5], and language models [6, 7, 8], on extensive benchmarks.

However, SAM-type methods suffer from several issues when training deep neural networks, especially the huge computation cost and the heavy hyper-parameter tuning procedure. In each iteration, SAM needs to compute gradients twice compared with classic optimizers such as SGD, Adam [9], and AMSGrad [10], due to the extra perturbation step. Hence, SAM requires forward and backward propagation twice for one parameter update, resulting in one extra forward-backward pass compared with the classic optimizers. Moreover, as there are two steps in the training process, it needs twice as many hyper-parameters, which makes tuning the learning rate laborious and costly.

Adaptive learning rate optimization methods [11] scale the gradients based on historical gradient information to accelerate convergence by tuning the learning rate automatically. Such methods, including Adagrad [12], Adam [9], and AMSGrad [10], have been proposed for computer vision, natural language processing, and generative neural network tasks [11, 13, 14, 15]. Recently, several works have tried to ease the learning rate tuning in SAM by inheriting the triplet advantages of SAM, an adaptive learning rate, and momentum acceleration. For example, [16] and [17] train ViT models and NLP models with adaptive learning rates and momentum acceleration, respectively. Although remarkable performance has been achieved, their convergence is still unknown once the adaptive learning rate and momentum acceleration are used in SAM. Directly analyzing the convergence is complicated because of the three coupled optimization steps, i.e., the adaptive learning rate estimation is coupled with the momentum step and the perturbation step of SAM.

In this paper, we analyze the convergence rate of SAM with an adaptive learning rate and momentum acceleration, dubbed AdaSAM, in the non-convex stochastic setting. To circumvent the difficulty in the analysis, we develop a novel technique to decouple the three-step training of SAM from the adaptive learning rate and the momentum step. The analysis is divided into three parts. The first part analyzes the SAM procedure. We then analyze the second step, which adopts the adaptive learning rate. We introduce the second-order momentum term from the previous iteration, which is related to the adaptive learning rate and independent of SAM when taking expectations. We can then bound the term composed of the SAM step and the previous second-order momentum thanks to the limited range of the adaptive learning rate. In the last part, we analyze the momentum acceleration combined with SAM and the adaptive learning rate. The momentum acceleration leads to an extra term in the convergence analysis; we introduce an auxiliary sequence to absorb it and show that its summation over all iterations is controllable. We prove that AdaSAM enjoys the linear speedup property with respect to the batch size, i.e., $\mathcal{O}(1/\sqrt{bT})$ where $b$ is the mini-batch size. Empirically, we apply AdaSAM to fine-tune the RoBERTa model on the GLUE benchmark to evaluate our theoretical findings. AdaSAM achieves the best performance in our experiments, winning 6 of 8 tasks, and the linear speedup can be clearly observed.

In the end, we summarize our contributions as follows:

  • We present the first convergence guarantee of the adaptive SAM method with momentum acceleration under the stochastic non-convex setting. Our results suggest that a large mini-batch can help convergence due to the established linear speedup with respect to batch size.

  • We conduct a series of experiments on various tasks. The results show that AdaSAM outperforms most of the state-of-the-art optimizers and the linear speedup is verified.

II Preliminary and Related Work

In this section, we first describe the basic problem setup and then introduce several related works on SAM, adaptive learning rates, and momentum acceleration.

II-A Problem Setup

In this work, we focus on stochastic nonconvex optimization

\min_{x\in\mathbb{R}^{d}}f(x):=\mathbb{E}_{\xi\sim D}f_{\xi}(x), \qquad (1)

where $d$ is the dimension of the variable $x$, $D$ is the unknown distribution of the data samples, $f_{\xi}(x)$ is a smooth and possibly non-convex function, and $f_{\xi_{i}}(x)$ denotes the objective function at the sampled data point $\xi_{i}$ drawn from the distribution $D$. In machine learning, this covers empirical risk minimization as a special case, where $f$ is the loss function and the dataset $D$ consists of $N$ data points, i.e., $D=\{\xi_{i},i=1,2,\ldots,N\}$. Problem (1) then reduces to the following finite-sum problem:

\min_{x\in\mathbb{R}^{d}}f(x):=\frac{1}{N}\sum_{i}f_{\xi_{i}}(x). \qquad (2)

Notations.

Without additional declaration, we write $f_{i}(x)$ for $f_{\xi_{i}}(x)$ for simplicity, which is the $i$-th loss function, where $x\in\mathbb{R}^{d}$ is the model parameter and $d$ is the parameter dimension. We denote the $\ell_{2}$ norm by $\|\cdot\|_{2}$. The Hadamard (element-wise) product of two vectors $a$ and $b$ is denoted by $a\odot b$. For a vector $a\in\mathbb{R}^{d}$, $\sqrt{a}$ denotes the vector whose $j$-th entry $(\sqrt{a})_{(j)}$ equals the square root of $a_{j}$.
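To make these vector operations concrete, here is a minimal NumPy sketch (added purely for illustration; it is not part of the original text) of the Hadamard product and the entry-wise square root that appear later in Algorithm 1:

import numpy as np

a = np.array([4.0, 9.0, 16.0])
b = np.array([0.5, 2.0, 1.0])

hadamard = a * b         # a ⊙ b: entry-wise product -> [2., 18., 16.]
entry_sqrt = np.sqrt(a)  # √a taken entry-wise       -> [2., 3., 4.]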

II-B Related Work

Sharpness-aware minimization

Many works aim to improve the generalization ability when training deep learning models. Methods such as dropout [18], weight decay [19], and regularization [20, 21] provide explicit ways to improve generalization. Previous work shows that sharp minima may lead to poor generalization, whereas flat minima perform better [22, 23, 24]. Therefore, sharpness is popularly considered to be closely related to generalization. Sharpness-aware minimization (SAM) [1] explicitly seeks flat minimizers by minimizing the training loss uniformly over an entire neighborhood. Specifically, SAM aims to solve the following minimax saddle point problem:

\min_{x}\max_{\|\delta\|\leq\rho}f(x+\delta)+\lambda\|x\|^{2}_{2}, \qquad (3)

where $\rho\geq 0$ and $\lambda\geq 0$ are two hyperparameters. That is, the perturbed loss function of $f(x)$ in a neighborhood is minimized instead of the original loss function $f(x)$. By using a Taylor expansion of $f(x+\delta)$ with respect to $\delta$, the inner max problem is approximately solved via

\delta^{*}(x)=\operatorname*{arg\,max}_{\|\delta\|\leq\rho}f(x+\delta)
\approx\operatorname*{arg\,max}_{\|\delta\|\leq\rho}f(x)+\delta^{\top}\nabla f(x)
=\operatorname*{arg\,max}_{\|\delta\|\leq\rho}\delta^{\top}\nabla f(x)=\rho\frac{\nabla f(x)}{\|\nabla f(x)\|}.

By dropping the quadratic term, (3) is simplified as the following minimization problem

\min_{x}f\left(x+\rho\frac{\nabla f(x)}{\|\nabla f(x)\|}\right). \qquad (4)

Since the stochastic gradient of $f\left(x+\rho\frac{\nabla f(x)}{\|\nabla f(x)\|}\right)$ on a batch of data $b$ includes a Hessian-vector product, SAM further approximates the gradient by

\nabla_{x}f_{b}\left(x+\rho\frac{\nabla f_{b}(x)}{\|\nabla f_{b}(x)\|}\right)\approx\nabla_{x}f_{b}(x)|_{x+\rho\frac{\nabla f_{b}(x)}{\|\nabla f_{b}(x)\|}}.

Then, SGD is applied along the negative direction $-\nabla_{x}f_{b}(x)|_{x+\rho\frac{\nabla f_{b}(x)}{\|\nabla f_{b}(x)\|}}$ to solve the surrogate minimization problem (4). It is easy to see that SAM requires two gradient back-propagations per update, i.e., for $\nabla f_{b}(x)$ and $\nabla_{x}f_{b}(x)|_{x+\rho\frac{\nabla f_{b}(x)}{\|\nabla f_{b}(x)\|}}$. Due to the extra hyperparameter $\rho$, one needs to carefully tune both $\rho$ and the learning rate in SAM. In practice, $\rho$ is predefined to control the radius of the neighborhood.
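To make the two back-propagations explicit, the following Python sketch (our simplified illustration with a hypothetical oracle grad_batch that returns the gradient of $f$ on a fixed mini-batch; it is not the authors' implementation) performs one SAM step with SGD as the base optimizer:

import numpy as np

def sam_sgd_step(x, grad_batch, lr, rho, eps=1e-12):
    # One SAM update: perturb x along the normalized mini-batch gradient,
    # then descend along the gradient evaluated at the perturbed point.
    g = grad_batch(x)                            # first pass: ∇f_b(x)
    delta = rho * g / (np.linalg.norm(g) + eps)  # δ = ρ ∇f_b(x)/||∇f_b(x)||
    g_sam = grad_batch(x + delta)                # second pass: ∇f_b(x + δ)
    return x - lr * g_sam                        # SGD step along -∇f_b(x + δ)

Both gradients are computed on the same mini-batch, which is why SAM doubles the per-update cost discussed above.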

Recently, several variants of SAM have been proposed to improve its performance. For example, [16, 17, 8] empirically incorporate an adaptive learning rate into SAM and show impressive generalization accuracy, but their convergence analysis has never been studied. ESAM [25] proposes an efficient method that sparsifies the gradients to alleviate the double computation cost of back-propagation. ASAM [17] modifies SAM by adaptively scaling the neighborhood so that the sharpness is invariant to parameter re-scaling. GSAM [16] simultaneously minimizes the perturbed function and a newly defined surrogate gap function to further improve the flatness of minimizers. Liu et al. [26] also study SAM in the large-batch training scenario and periodically update the perturbed gradient. Recently, [3, 8] improve the efficiency of SAM by adopting sparse gradient perturbation techniques. [27, 28] extend SAM to the federated learning setting with significant performance gains. On the other hand, some works analyze the convergence of SAM, such as [29], without considering the normalization step, i.e., the normalization in $\frac{\nabla f_{b}(x)}{\|\nabla f_{b}(x)\|}$.

Adaptive optimizer

Adaptive optimizers automatically adjust the learning rate based on historical gradient information. The first adaptive method, Adagrad [12], achieves better results than other first-order methods in the convex setting. When training deep neural networks, however, Adagrad decreases the learning rate rapidly and suffers degraded performance. Adadelta [30] was proposed to remedy this by introducing a learning rate based on an exponential average of historical gradients. Adam [9] additionally adds a momentum step to stabilize the training process and shows great performance on many tasks. However, Reddi et al. [10] give a counterexample showing that Adam can fail to converge even when the objective function is convex, and propose an alternative method, AMSGrad, with a convergence guarantee. Since then, many works [31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44] have studied the convergence of adaptive methods and their variants in the nonconvex setting. However, their analysis techniques cannot be directly extended to establish the convergence of SAM with an adaptive learning rate, due to the coupled perturbation step and adaptive learning rate.

Momentum acceleration

Momentum methods, such as Polyak's heavy ball method [45], Nesterov's accelerated gradient descent [46], and accelerated projected methods [47], are used to optimize the parameters of deep neural networks. In practice, they have been used to accelerate federated learning tasks [48], non-negative latent factor models [49], and recommender systems [50]. Many theoretical works [51, 52, 53] focus on analyzing momentum acceleration for non-convex problems. [54] shows the importance of tuning momentum when training deep neural networks. [55] first establishes linear convergence results for stochastic momentum methods. [56] proposes a class of accelerated zeroth-order and first-order momentum methods for mini-optimization and minimax-optimization problems. [57] extends the momentum method by introducing an RNA scheme and a constrained RNA formulation with nonlinear updates. [58] proposes a heuristic adaptive restart method, and [59] proposes a scheduled-restart momentum-accelerated SGD method named SRSGD that helps reduce the training time. [60] adds a momentum term to the distributed gradient algorithm.

III Methodology

In this section, we introduce SAM with an adaptive learning rate and momentum acceleration, dubbed AdaSAM, to stabilize the training process of SAM and ease the learning rate tuning. We then present the convergence results of AdaSAM. Finally, we give the proof sketch for the main theorem.

Input: Initial parameters $x_{0}$, $m_{-1}=0$, $\hat{v}_{-1}=\epsilon^{2}$ (a small positive scalar to avoid the denominator diminishing), base learning rate $\gamma$, neighborhood size $\rho$, and momentum parameters $\beta_{1}$, $\beta_{2}$.
Output: Optimized parameter $x_{T+1}$
1 for iteration $t\in\{0,1,2,\ldots,T-1\}$ do
2   Sample mini-batch $B=\{\xi_{t_{1}},\xi_{t_{2}},\ldots,\xi_{t_{|B|}}\}$;
3   Compute gradient $s_{t}=\nabla_{x}f_{B}(x)|_{x_{t}}=\frac{1}{b}\sum_{i\in B}\nabla f_{t_{i}}(x_{t})$;
4   Compute $\delta(x_{t})=\rho_{t}\frac{s_{t}}{\|s_{t}\|}$;
5   Compute SAM gradient $g_{t}=\nabla_{x}f_{B}(x)|_{x_{t}+\delta(x_{t})}$;
6   $m_{t}=\beta_{1}m_{t-1}+(1-\beta_{1})g_{t}$;
7   $v_{t}=\beta_{2}v_{t-1}+(1-\beta_{2})[g_{t}]^{2}$;
8   $\hat{v}_{t}=\max(\hat{v}_{t-1},v_{t})$;
9   $\eta_{t}=1/\sqrt{\hat{v}_{t}}$;
10  $x_{t+1}=x_{t}-\gamma m_{t}\odot\eta_{t}$;
11
12 end for
Algorithm 1 AdaSAM: SAM with adaptive learning rate and momentum acceleration

III-A AdaSAM Algorithm

AdaSAM for solving Problem (1) is described in Algorithm 1. In each iteration, a mini-batch gradient estimate $g_{t}$ at the perturbed point $x_{t}+\delta(x_{t})$ with batch size $b$ is computed, i.e.,

g_{t}=\nabla_{x}f_{b}(x)|_{x_{t}+\delta(x_{t})}=\frac{1}{b}\sum_{i\in B}\nabla f_{\xi_{i}}(x_{t}+\delta(x_{t})).

Here, $\delta(x_{t})$ is the extra perturbed gradient step in SAM, given as follows:

\delta(x_{t})=\rho\frac{s_{t}}{\|s_{t}\|},\quad{\rm where}\ s_{t}=\nabla_{x}f_{b}(x)|_{x_{t}}=\frac{1}{b}\sum_{i\in B}\nabla f_{\xi_{i}}(x_{t}).

Then, the momentum term of $g_{t}$ and the second-order moment term $[g_{t}]^{2}$ are accumulated as $m_{t}$ and $v_{t}$, respectively. AdaSAM then updates the iterate along $-m_{t}$ with the adaptive learning rate $\gamma\eta_{t}$.
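For concreteness, the following NumPy sketch mirrors Algorithm 1 line by line (a sketch only; sample_batch and grad are hypothetical callables we introduce for illustration, and initializing the second moment at $v_{-1}=\epsilon^{2}$ is an assumption, since Algorithm 1 only specifies $\hat{v}_{-1}$):

import numpy as np

def adasam(x0, sample_batch, grad, T, gamma=1e-3, rho=0.01,
           beta1=0.9, beta2=0.999, eps=1e-8):
    # Minimal AdaSAM loop following Algorithm 1.
    # sample_batch() draws a mini-batch; grad(x, batch) returns the
    # mini-batch gradient of f at x.
    x = x0.copy()
    m = np.zeros_like(x0)               # m_{-1} = 0
    v = np.full_like(x0, eps ** 2)      # assumed v_{-1} = eps^2
    v_hat = v.copy()                    # \hat{v}_{-1} = eps^2
    for t in range(T):
        batch = sample_batch()
        s = grad(x, batch)                               # line 3: s_t
        delta = rho * s / (np.linalg.norm(s) + 1e-12)    # line 4: δ(x_t)
        g = grad(x + delta, batch)                       # line 5: SAM gradient g_t
        m = beta1 * m + (1 - beta1) * g                  # line 6: momentum m_t
        v = beta2 * v + (1 - beta2) * g ** 2             # line 7: second moment v_t
        v_hat = np.maximum(v_hat, v)                     # line 8: AMSGrad-style max
        eta = 1.0 / np.sqrt(v_hat)                       # line 9: adaptive rate η_t
        x = x - gamma * m * eta                          # line 10: x_{t+1}
    return x

The same mini-batch is reused for the perturbation and the SAM gradient, matching the two forward-backward passes per update discussed in Section I.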

Remark 1.

Below, we give several comments on AdaSAM:

  • When $\beta_{2}=1$, the adaptive learning rate degenerates to a non-adaptive one, as in SGD. Then, AdaSAM recovers the classic SAM optimizer.

  • If we drop the 8-th line, $\hat{v}_{t}=\max(\hat{v}_{t-1},v_{t})$, then our algorithm becomes a SAM variant of Adam. The counterexample in [10] showing that Adam does not converge also applies to that variant, while AdaSAM can converge.

III-B Convergence Analysis

Before presenting the convergence results of the AdaSAM algorithm, we first introduce some necessary assumptions.

Assumption 1 ($L$-smooth).

$f_{i}$ and $f$ are differentiable with the gradient Lipschitz property: $\|\nabla f_{i}(x)-\nabla f_{i}(y)\|\leq L\|x-y\|$ and $\|\nabla f(x)-\nabla f(y)\|\leq L\|x-y\|$ for all $x,y\in\mathbb{R}^{d}$ and $i=1,2,\ldots,N$, which also implies the descent inequality, i.e., $f_{i}(y)\leq f_{i}(x)+\langle\nabla f_{i}(x),y-x\rangle+\frac{L}{2}\|y-x\|^{2}$.

Assumption 2 (Bounded variance).

The estimator of the gradient is unbiased and the variance of the stochastic gradient is bounded, i.e.,

\mathbb{E}\nabla f_{i}(x)=\nabla f(x),\quad\mathbb{E}\|\nabla f_{i}(x)-\nabla f(x)\|^{2}\leq\sigma^{2}.

When a mini-batch of size $b$ is used, we have $\mathbb{E}\|\nabla f_{b}(x)-\nabla f(x)\|^{2}\leq\frac{\sigma^{2}}{b}$.
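For completeness, the mini-batch bound follows from Assumption 2 and the independence of the samples in $B$ (a standard one-line argument we spell out here):

\mathbb{E}\|\nabla f_{b}(x)-\nabla f(x)\|^{2}=\mathbb{E}\Big\|\frac{1}{b}\sum_{i\in B}\big(\nabla f_{i}(x)-\nabla f(x)\big)\Big\|^{2}=\frac{1}{b^{2}}\sum_{i\in B}\mathbb{E}\|\nabla f_{i}(x)-\nabla f(x)\|^{2}\leq\frac{\sigma^{2}}{b},

where the cross terms vanish because the stochastic gradients are unbiased and sampled independently.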

Assumption 3 (Bounded stochastic gradients).

The stochastic gradient is uniformly bounded, i.e.,

\|\nabla f_{i}(x)\|_{\infty}\leq G,\quad{\rm for\ any}\ i=1,\ldots,N.
Remark 2.

The above assumptions are commonly used in the proof of convergence for adaptive stochastic gradient methods such as [61, 62, 31, 32].

Below, we briefly explain the main idea of the convergence analysis of AdaSAM. First, we discuss the difficulty of applying the adaptive learning rate to SAM. We notice that the main step of the convergence analysis involving the adaptive learning rate is to estimate the expectation $\mathbb{E}[x_{t+1}-x_{t}]=-\mathbb{E}\,m_{t}\odot\eta_{t}=-\mathbb{E}(1-\beta_{1})g_{t}\odot\eta_{t}-\mathbb{E}\,\beta_{1}m_{t-1}\odot\eta_{t}$, conditioned on the filtration $\sigma(x_{t})$. In this part, we first consider the situation $\beta_{1}=0$, which excludes the momentum. Then, we apply the delay technique to disentangle the dependence between $g_{t}$ and $\eta_{t}$, that is,

\mathbb{E}\,g_{t}\odot\eta_{t}=\mathbb{E}[g_{t}\odot\eta_{t-1}]+\mathbb{E}[g_{t}\odot(\eta_{t}-\eta_{t-1})]
=\nabla f(x_{t})\odot\eta_{t-1}+\mathbb{E}[g_{t}\odot(\eta_{t}-\eta_{t-1})].

The second term $\mathbb{E}[g_{t}\odot(\eta_{t}-\eta_{t-1})]$ is dominated by the first term $\nabla f(x_{t})\odot\eta_{t-1}$. It is then not difficult to obtain the convergence of stochastic gradient descent with an adaptive learning rate, such as AMSGrad. However, when we apply the same strategy to AdaSAM, we find that $\mathbb{E}\,g_{t}\odot\eta_{t-1}$ cannot be handled similarly because $\mathbb{E}\,g_{t}=\mathbb{E}\nabla_{x}f_{b}\left(x+\rho\frac{\nabla f_{b}(x)}{\|\nabla f_{b}(x)\|}\right)\neq\nabla f(x_{t})$. Inspired by [29, Lemma 16], our key observation is that

\mathbb{E}\nabla_{x}f_{b}\left(x+\rho\frac{\nabla f_{b}(x)}{\|\nabla f_{b}(x)\|}\right)\approx\mathbb{E}\nabla_{x}f_{b}\left(x+\rho\frac{\nabla f(x)}{\|\nabla f(x)\|}\right)=\nabla_{x}f\left(x+\rho\frac{\nabla f(x)}{\|\nabla f(x)\|}\right)

and we prove that the remaining terms, such as $\mathbb{E}\left(\nabla_{x}f_{b}\left(x+\rho\frac{\nabla f_{b}(x)}{\|\nabla f_{b}(x)\|}\right)-\nabla_{x}f_{b}\left(x+\rho\frac{\nabla f(x)}{\|\nabla f(x)\|}\right)\right)\odot\eta_{t-1}$, have small values that do not dominate the convergence rate.

On the other hand, when we apply the momentum step, we find that the term $\mathbb{E}\,m_{t-1}\odot\eta_{t}$ cannot be ignored. By introducing the auxiliary sequence $z_{t}=x_{t}+\frac{\beta_{1}}{1-\beta_{1}}(x_{t}-x_{t-1})$, we have $\mathbb{E}[z_{t+1}-z_{t}]=\mathbb{E}[\frac{\beta_{1}}{1-\beta_{1}}\gamma m_{t-1}\odot(\eta_{t-1}-\eta_{t})-\gamma g_{t}\odot\eta_{t}]$ (see (25) in the Appendix). The first term contains the momentum term, which is small because of the difference of the adaptive learning rates $\eta_{t-1}-\eta_{t}$. Thus, it is diminishing without hurting the convergence rate.
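The "limited range" of the adaptive learning rate used throughout the analysis can be made explicit (a short argument we include for readability, assuming the second moment is initialized at $v_{-1}=\hat{v}_{-1}=\epsilon^{2}$ with $\epsilon\leq G$): by Assumption 3 every coordinate of $g_{t}$ is bounded by $G$, so $v_{t}\leq G^{2}$ and $\epsilon^{2}\leq\hat{v}_{t}\leq G^{2}$ coordinate-wise, hence

\frac{1}{G}\leq(\eta_{t})_{j}=\frac{1}{\sqrt{(\hat{v}_{t})_{j}}}\leq\frac{1}{\epsilon},\qquad j=1,\ldots,d,

which is what later allows $\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}$ to be compared with $\|\nabla f(x_{t})\|^{2}$, as in (18).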

Theorem 1.

Under Assumptions 1-3, and with a fixed learning rate $\gamma$ satisfying $\gamma\leq\frac{\epsilon}{16L}$, for the sequence $\{x_{t}\}$ generated by Algorithm 1 we have the following convergence rate:

\frac{1}{T}\sum_{t=0}^{T-1}\mathbb{E}\|\nabla f(x_{t})\|^{2}_{2}\leq\frac{2G(f(x_{0})-f^{*})}{\gamma T}+\frac{8G\gamma L}{\epsilon}\frac{\sigma^{2}}{b\epsilon}+\Phi \qquad (5)

where

\Phi=\frac{45GL^{2}\rho_{t}^{2}}{\epsilon}+\frac{2G^{3}}{(1-\beta_{1})T}d\Big(\frac{1}{\epsilon}-\frac{1}{G}\Big)+\frac{6\gamma^{2}L^{2}\beta_{1}^{2}}{(1-\beta_{1})^{2}}\frac{dG^{3}}{\epsilon^{3}}+\frac{2(4+(\frac{\beta_{1}}{1-\beta_{1}})^{2})\gamma LG^{3}}{T}d(\epsilon^{-2}-G^{-2})+\frac{8G\gamma L}{\epsilon}\frac{L\rho_{t}^{2}}{\epsilon}, \qquad (6)

in which $T$ is the number of iterations, $f^{*}$ is the minimal value of the function $f$, $\gamma$ is the base learning rate, $b$ is the mini-batch size, and $d$ is the dimension of the parameter $x$. $\beta_{1}$, $G$, $L$, $\epsilon$, and $\sigma^{2}$ are fixed constants.

Theorem 1 characterizes the convergence of the sequence $\{x_{t}\}$ generated by AdaSAM in terms of the stochastic gradient residual. The first two terms on the right-hand side of Inequality (5) dominate the convergence rate. Compared with them, $\Phi$ is small when the neighborhood size $\rho$ and the learning rate $\gamma$ are chosen as small values related to a large iteration number $T$. We then directly obtain the following corollary.

Corollary 1 (Mini-batch linear speedup).

Under the same conditions as Theorem 1, when we choose the base learning rate $\gamma=O(\sqrt{\frac{b}{T}})$ and neighborhood size $\rho=O(\sqrt{\frac{1}{bT}})$, the following result holds:

\frac{1}{T}\sum_{t=0}^{T-1}\mathbb{E}\|\nabla f(x_{t})\|^{2}_{2}=O\left(\frac{1}{\sqrt{bT}}\right)+O\left(\frac{1}{bT}\right)+O\left(\frac{1}{T}\right)+O\left(\frac{1}{b^{\frac{1}{2}}T^{\frac{3}{2}}}\right)+O\left(\frac{b^{\frac{1}{2}}}{T^{\frac{3}{2}}}\right)+O\left(\frac{b}{T}\right).

When $T$ is sufficiently large, we achieve the linear speedup convergence rate with respect to the mini-batch size $b$, i.e.,

\frac{1}{T}\sum_{t=0}^{T-1}\mathbb{E}\|\nabla f(x_{t})\|^{2}_{2}=O\left(\frac{1}{\sqrt{bT}}\right). \qquad (7)
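To see where the rate comes from, one can plug the parameter choices of Corollary 1 into the two dominant terms of (5) (a short sanity check added here): with $\gamma=O(\sqrt{b/T})$,

\frac{2G(f(x_{0})-f^{*})}{\gamma T}=O\Big(\frac{1}{\sqrt{bT}}\Big),\qquad\frac{8G\gamma L}{\epsilon}\frac{\sigma^{2}}{b\epsilon}=O\Big(\frac{1}{b}\sqrt{\frac{b}{T}}\Big)=O\Big(\frac{1}{\sqrt{bT}}\Big),

while the remaining terms in $\Phi$, which scale as $\rho^{2}=O(1/(bT))$, $1/T$, or $\gamma^{2}=O(b/T)$, are of higher order once $T$ is sufficiently large, which gives (7).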
Remark 3.

Two comments are given about the above results:

  • To reach an $O(\delta)$ stationary point, when the batch size is 1, AdaSAM needs $T=O(\frac{1}{\delta^{2}})$ iterations. When the batch size is $b$, we need to run $T=O(\frac{1}{b\delta^{2}})$ steps. The method with batch size $b$ is thus $b$ times faster than with batch size 1, which means that it has the mini-batch linear speedup property.

  • According to [63, 64, 37], AdaSAM can be extended to a distributed version and achieves the linear speedup property with respect to the number of workers in the Parameter-Server setting.

TABLE I: Evaluating SGD, SAM, AMSGrad and AdaSAM on the GLUE benchmark with $\beta_{1}=0.9$
CoLA SST-2 MRPC STS-B RTE MNLI QNLI QQP
Model mcc. Acc. Acc./F1 Pcor./Scor. Acc. m./mm. Acc. F1/ Acc. Avg.
SGD 9.25 50.92 68.38/ 81.22 3.22/ 1.9 55.6 84.94/ 84.87 63.61 85.6/ 80.14 55.8
SAM($\rho=0.01$) 4.64 95.87 70.58/ 81.98 84.74/ 85.57 52.71 90.5/ 90.19 94.44 84.7/ 87.88 76.98
SAM($\rho=0.005$) 66.76 95.76 68.38/ 81.22 2/ 2 52.71 90.42/ 89.74 94.6 86.72/ 89.94 68.35
SAM(best) 66.76 95.87 70.58/ 81.98 84.74/ 85.57 52.71 90.5/ 90.19 94.6 86.72/ 89.94 82.51
AMSGrad 68.0 96.33 90.2/ 92.72 91.72/ 91.48 87.73 90.67/ 90.41 94.82 88.7/ 91.41 89.52
AdaSAM($\rho=0.01$) 65.29 96.33 91.18/ 93.64 90.13/ 90.36 84.84 90.97/ 90.42 94.65 88.55/ 91.23 88.97
AdaSAM($\rho=0.005$) 68.74 96.67 90.93/ 93.36 91.64/ 91.38 87.73 90.88/ 90.4 94.56 88.69/ 91.33 89.69
AdaSAM($\rho=0.001$) 67.3 96.1 90.2/ 92.96 91.9/ 91.62 85.92 90.45/ 90.4 94.56 88.64/ 91.27 89.28
AdaSAM(best) 68.74 96.67 91.18/ 93.64 91.9/ 91.62 87.73 90.97/ 90.42 94.65 88.69/ 91.33 89.8
Figure 1: The loss and evaluation metric vs. steps on MRPC, RTE, CoLA, SST-2, STS-B, MNLI, QQP, and QNLI ($\beta_{1}=0.9$).

III-C Proof Sketch

In this part, we give the proof sketch of Theorem 1; the complete proof is in the Appendix. Below, we first introduce an auxiliary sequence $z_{t}=x_{t}+\frac{\beta_{1}}{1-\beta_{1}}(x_{t}-x_{t-1})$. By applying the $L$-smooth condition, we have

f(z_{t+1})\leq f(z_{t})+\langle\nabla f(z_{t}),z_{t+1}-z_{t}\rangle+\frac{L}{2}\|z_{t+1}-z_{t}\|^{2}. \qquad (8)

Applying it to the sequence $\{z_{t}\}$ and using the delay strategy yields

f(z_{t+1})-f(z_{t})
\leq\langle\nabla f(z_{t}),\frac{\gamma\beta_{1}}{1-\beta_{1}}m_{t-1}\odot(\eta_{t-1}-\eta_{t})\rangle+\frac{L}{2}\|z_{t+1}-z_{t}\|^{2}
\;\;\;\;+\langle\nabla f(z_{t}),\frac{\gamma}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{s_{t}}{\|s_{t}\|})\odot(\eta_{t-1}-\eta_{t})\rangle
\;\;\;\;+\langle\nabla f(z_{t})-\nabla f(x_{t}),-\frac{\gamma}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{s_{t}}{\|s_{t}\|})\odot\eta_{t-1}\rangle
\;\;\;\;+\langle\nabla f(x_{t}),-\frac{\gamma}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{\nabla f(x_{t})}{\|\nabla f(x_{t})\|})\odot\eta_{t-1}\rangle
\;\;\;\;+\langle\nabla f(x_{t}),\frac{\gamma}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{\nabla f(x_{t})}{\|\nabla f(x_{t})\|})\odot\eta_{t-1}-\frac{\gamma}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{s_{t}}{\|s_{t}\|})\odot\eta_{t-1}\rangle. \qquad (9)

From Lemma 5, Lemma 6, and Lemma 7 in the Appendix, we can bound the terms in (9) as follows:

\langle\nabla f(z_{t}),\frac{\gamma}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{s_{t}}{\|s_{t}\|})\odot(\eta_{t-1}-\eta_{t})\rangle\leq\gamma G^{2}\|\eta_{t-1}-\eta_{t}\|_{1}, \qquad (10)

\langle\nabla f(z_{t}),\frac{\gamma\beta_{1}}{1-\beta_{1}}m_{t-1}\odot(\eta_{t-1}-\eta_{t})\rangle\leq\frac{\gamma\beta_{1}}{1-\beta_{1}}G^{2}\|\eta_{t-1}-\eta_{t}\|_{1}, \qquad (11)

\langle\nabla f(x_{t}),\frac{\gamma}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{\nabla f(x_{t})}{\|\nabla f(x_{t})\|})\odot\eta_{t-1}-\frac{\gamma}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{s_{t}}{\|s_{t}\|})\odot\eta_{t-1}\rangle\leq\frac{\gamma}{2\mu^{2}}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}+\frac{2\mu^{2}\gamma L^{2}\rho_{t}^{2}}{\epsilon}. \qquad (12)

We then substitute them into (9) and take the conditional expectation to get

\mathbb{E}f(z_{t+1})-f(z_{t})
\leq\mathbb{E}\langle\nabla f(x_{t}),-\frac{\gamma}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{\nabla f(x_{t})}{\|\nabla f(x_{t})\|})\odot\eta_{t-1}\rangle
\;\;\;\;+\frac{\gamma}{2\mu^{2}}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}+\frac{\gamma}{1-\beta_{1}}G^{2}\|\eta_{t-1}-\eta_{t}\|_{1}
\;\;\;\;+\mathbb{E}\langle\nabla f(z_{t})-\nabla f(x_{t}),-\frac{\gamma}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{s_{t}}{\|s_{t}\|})\odot\eta_{t-1}\rangle
\;\;\;\;+\frac{2\mu^{2}\gamma L^{2}\rho_{t}^{2}}{\epsilon}+\frac{L}{2}\mathbb{E}\|z_{t+1}-z_{t}\|^{2}, \qquad (13)

where $\mu>0$ is a constant to be determined. Next, from Lemma 8, Lemma 9, and Lemma 10 in the Appendix, we have

\mathbb{E}\langle\nabla f(x_{t}),-\frac{\gamma}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{\nabla f(x_{t})}{\|\nabla f(x_{t})\|})\odot\eta_{t-1}\rangle\leq-\gamma\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}+\mathbb{E}\frac{\gamma}{2\alpha^{2}}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}+\frac{\gamma\alpha^{2}L^{2}\rho_{t}^{2}}{2\epsilon}, \qquad (14)

\frac{L}{2}\mathbb{E}\|z_{t+1}-z_{t}\|^{2}\leq\frac{LG^{2}\gamma^{2}\beta_{1}^{2}}{(1-\beta_{1})^{2}}\mathbb{E}\|\eta_{t}-\eta_{t-1}\|^{2}+\gamma^{2}L\Big(3\frac{1+\beta}{\beta\epsilon}\big(\frac{L\rho_{t}^{2}}{\epsilon}+\frac{\sigma^{2}}{b\epsilon}+\mathbb{E}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}\big)+(1+\beta)G^{2}\mathbb{E}\|\eta_{t}-\eta_{t-1}\|^{2}\Big), \qquad (15)

\mathbb{E}\langle\nabla f(z_{t})-\nabla f(x_{t}),-\frac{\gamma}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{s_{t}}{\|s_{t}\|})\odot\eta_{t-1}\rangle\leq\frac{\gamma^{3}L^{2}\beta_{1}^{2}}{2\epsilon(1-\beta_{1})^{2}}\big(\frac{1}{\lambda_{1}^{2}}+\frac{1}{\lambda_{2}^{2}}+\frac{1}{\lambda_{3}^{2}}\big)\frac{dG_{\infty}^{2}}{\epsilon^{2}}+\frac{\gamma L^{2}\rho_{t}^{2}}{2\epsilon}(\lambda_{2}^{2}+4\lambda_{3}^{2})+\frac{\gamma\lambda_{1}^{2}}{2}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}. \qquad (16)

Next, we substitute these bounds into (13). Taking the expectation over all the history information yields

\mathbb{E}f(x_{t+1})-\mathbb{E}f(x_{t})
\leq-\gamma\Big(1-\frac{1}{2\mu^{2}}-\frac{1}{2\alpha^{2}}-\frac{3\gamma L(1+\beta)}{\beta\epsilon}-\frac{\lambda_{1}^{2}}{2}\Big)\mathbb{E}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}
\;\;\;\;+\frac{2\mu^{2}\gamma L^{2}\rho_{t}^{2}}{\epsilon}+\frac{\gamma}{1-\beta_{1}}G^{2}\mathbb{E}\|\eta_{t-1}-\eta_{t}\|_{1}+\frac{\gamma\alpha^{2}L^{2}\rho^{2}}{2\epsilon}
\;\;\;\;+\frac{\gamma^{3}L^{2}\beta_{1}^{2}}{2\epsilon(1-\beta_{1})^{2}}\big(\frac{1}{\lambda_{1}^{2}}+\frac{1}{\lambda_{2}^{2}}+\frac{1}{\lambda_{3}^{2}}\big)\frac{dG_{\infty}^{2}}{\epsilon^{2}}+\frac{\gamma L^{2}\rho_{t}^{2}}{2\epsilon}(\lambda_{2}^{2}+4\lambda_{3}^{2})
\;\;\;\;+\gamma^{2}LG^{2}\Big(\big(\frac{\beta_{1}}{1-\beta_{1}}\big)^{2}+1+\beta\Big)\mathbb{E}\|\eta_{t}-\eta_{t-1}\|^{2}
\;\;\;\;+\frac{3\gamma^{2}L(1+\beta)}{\beta\epsilon}\Big(\frac{L\rho_{t}^{2}}{\epsilon}+\frac{\sigma^{2}}{b\epsilon}\Big). \qquad (17)

We set $\mu^{2}=\alpha^{2}=8$, $\beta=3$, $\lambda_{1}^{2}=\frac{1}{4}$, $\lambda_{2}^{2}=\lambda_{3}^{2}=1$, and we choose $\frac{2\gamma L}{\epsilon}\leq\frac{1}{8}$. Note that $\eta_{t}$ is bounded. We have

\frac{\gamma}{2G}\mathbb{E}\|\nabla f(x_{t})\|^{2}\leq\frac{\gamma}{2}\mathbb{E}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2} \qquad (18)
\leq-\mathbb{E}f(x_{t+1})+\mathbb{E}f(x_{t})+\frac{45\gamma L^{2}\rho_{t}^{2}}{2\epsilon}+\frac{4\gamma^{2}L}{\epsilon}\Big(\frac{L\rho_{t}^{2}}{\epsilon}+\frac{\sigma^{2}}{b\epsilon}\Big)
\;\;\;\;+\frac{\gamma}{1-\beta_{1}}G^{2}\mathbb{E}\|\eta_{t-1}-\eta_{t}\|_{1}+\frac{3\gamma^{3}L^{2}\beta_{1}^{2}}{(1-\beta_{1})^{2}}\frac{dG_{\infty}^{2}}{\epsilon^{3}}
\;\;\;\;+\Big(4+\big(\frac{\beta_{1}}{1-\beta_{1}}\big)^{2}\Big)\gamma^{2}LG^{2}\mathbb{E}\|\eta_{t}-\eta_{t-1}\|^{2}. \qquad (19)

Then, telescoping it from $t=0$ to $t=T-1$, and assuming $\gamma$ is a constant, it follows that

\frac{1}{T}\sum_{t=0}^{T-1}\mathbb{E}\|\nabla f(x_{t})\|^{2}\leq\frac{2G(f(x_{0})-f^{*})}{\gamma T}+\frac{8G\gamma L}{\epsilon}\frac{\sigma^{2}}{b\epsilon}
\;\;\;\;+\frac{45GL^{2}\rho_{t}^{2}}{\epsilon}+\frac{2G^{3}}{(1-\beta_{1})T}d\Big(\frac{1}{\epsilon}-\frac{1}{G}\Big)+\frac{6\gamma^{2}L^{2}\beta_{1}^{2}}{(1-\beta_{1})^{2}}\frac{dG^{3}}{\epsilon^{3}}
\;\;\;\;+\frac{8G\gamma L}{\epsilon}\frac{L\rho_{t}^{2}}{\epsilon}+\frac{2(4+(\frac{\beta_{1}}{1-\beta_{1}})^{2})\gamma LG^{3}}{T}d(\epsilon^{-2}-G^{-2}), \qquad (20)

which completes the proof.

Figure 2: The linear speedup verification of AdaSAM with batch sizes of 4, 8, 16, and 32 on (a) MRPC, (b) RTE, and (c) CoLA.
TABLE II: Results of SGD, SAM, AMSGrad and AdaSAM on the GLUE benchmark without momentum, i.e., $\beta_{1}=0$
CoLA SST-2 MRPC STS-B RTE MNLI QNLI QQP
Model mcc. Acc. Acc./F1 Pcor./Scor. Acc. m./mm. Acc. F1/ Acc. Avg.
SGD 0 51.722 68.38/ 81.22 5.55/ 7.2 51.27 32.51/ 32.42 53.32 0/ 63.18 37.23
SAM($\rho=0.01$) 41.91 95.3 68.38/ 81.22 9.21/ 10.38 53.07 87.99/ 87.8 51.24 83.44/ 87.27 63.1
SAM($\rho=0.005$) 58.79 81.54 68.38/ 81.22 13.52/ 16.6 53.79 88.42/ 88.15 92.95 83.84/ 87.7 67.91
SAM(best) 58.79 95.3 68.38/ 81.22 13.52/ 16.6 53.79 88.42/ 88.15 92.95 83.84/ 87.7 69.06
AMSGrad 63.78 96.44 89.71/ 92.44 89.98/ 90.35 87.36 90.65/ 90.35 94.53 88.59/ 91.27 88.79
AdaSAM($\rho=0.01$) 69.23 96.22 89.96/ 92.84 88.83/ 89.07 87 90.83/ 90.41 94.8 88.67/ 91.38 89.1
AdaSAM($\rho=0.005$) 68.47 96.22 89.96/ 92.82 91.59/ 91.22 73.65 90.75/ 90.42 94.73 88.72/ 91.46 88.33
AdaSAM(best) 69.23 96.22 89.96/ 92.84 91.59/ 91.22 87 90.83/ 90.42 94.8 88.72/ 91.46 89.52

IV Experiments

In this section, we apply AdaSAM to train language models and compare it with SGD, AMSGrad, and SAM to show its effectiveness. Due to space limitations, more experiments, including visualization, task description, implementation details and results description, are placed in the Appendix.

IV-A Experimental Setup

Tasks and Datasets. We evaluate AdaSAM on a popular benchmark, i.e., the General Language Understanding Evaluation (GLUE) benchmark [65], which consists of several language understanding tasks including sentiment analysis, question answering, and textual entailment. For a fair comparison, we report single-task results, without multi-task or ensemble training. We evaluate the performance with the Accuracy ("Acc") metric for most tasks, except the F1 score for QQP and MRPC, the Pearson-Spearman correlations ("Pcor/Scor") for STS-B, and the Matthews correlation ("Mcc") for CoLA. Higher metric values indicate better performance.

Implementations. We conduct our experiments using a widely-used pre-trained language model, RoBERTa-large (https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz), in the open-source toolkit fairseq (https://github.com/facebookresearch/fairseq), with 24 transformer layers and a hidden size of 1024. For fine-tuning on each task, we use different combinations of hyper-parameters, including the learning rate, the number of epochs, and the batch size (due to the space limitation, we show the details of the datasets and training settings in Appendix A). In particular, for RTE, STS-B and MRPC of the GLUE benchmark, we first fine-tune the pre-trained RoBERTa-large model on the MNLI dataset and continue fine-tuning the RoBERTa-large-MNLI model on the corresponding single-task corpus for better performance, as many prior works did [66, 7]. All models are trained on an NVIDIA DGX SuperPOD cluster, in which each machine contains 8$\times$40GB A100 GPUs.

IV-B Results on GLUE Benchmark

Table I shows the performance of SGD, SAM, AMSGrad, and AdaSAM. For AdaSAM, we tune the neighborhood size $\rho$ of the perturbation over 0.01, 0.005, and 0.001. The results show that AdaSAM outperforms AMSGrad on 6 of 8 tasks, the exceptions being QNLI and QQP; overall, it improves the average score by 0.28 over AMSGrad. On the other hand, Table I indicates that SAM is better than SGD on 7 of 8 tasks, the exception being RTE, and that SAM can significantly improve performance. Comparing the results in Table I, we find that the adaptive learning rate methods are better than SGD tuned with hand-crafted learning rates. AdaSAM achieves the best metric on 6 tasks, namely CoLA, SST-2, MRPC, STS-B, RTE, and MNLI. In general, AdaSAM is better than the other methods.

In addition, Figure 1 shows the convergence speed in terms of the training loss and evaluation metrics vs. the number of steps. The loss curves of AdaSAM decrease faster than those of SAM and SGD on all tasks, and at a speed similar to AMSGrad. The evaluation metric curves show that AdaSAM and AMSGrad are better than SGD and SAM, and that AdaSAM decreases the loss as fast as AMSGrad on all tasks.

IV-C Mini-batch Speedup

In this part, we test the performance with different batch sizes to validate the linear speedup property. The experiments are conducted on the MRPC, RTE, and CoLA tasks with batch sizes of 4, 8, 16, and 32. We scale the learning rate as $\sqrt{N}$, similar to [67], where $N$ is the batch size. As shown in Figure 2, the training loss decreases faster as the batch size increases, and the loss curve with batch size 32 needs nearly half as many iterations as the curve with batch size 16.
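As a concrete example of the scaling rule (with a hypothetical reference batch size of 4 and base learning rate; the paper does not specify the exact constants):

import math

def scaled_lr(base_lr, batch_size, base_batch_size=4):
    # Scale the learning rate proportionally to sqrt(N), consistent
    # with the choice gamma = O(sqrt(b/T)) in Corollary 1 for fixed T.
    return base_lr * math.sqrt(batch_size / base_batch_size)

lrs = {b: scaled_lr(1e-5, b) for b in (4, 8, 16, 32)}  # sqrt-scaled learning rates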

IV-D Ablation Study

In this subsection, we conduct experiments with the momentum hyper-parameter $\beta_{1}$ set to 0 to evaluate the influence of the momentum acceleration and the adaptive learning rate. Table II shows that AdaSAM outperforms AMSGrad on 6 of 8 tasks, the exceptions being SST-2 and RTE. In Table II, we also compare SGD and SAM; without momentum, SAM outperforms SGD on all tasks. In this setting, AdaSAM without the momentum acceleration is still better than the other methods.

Comparing the results of Table I and Table II, we find that both the adaptive learning rate and the momentum acceleration are helpful for the model's generalization ability. Without the momentum term, SAM with an adaptive learning rate improves the average score by 0.74 over AMSGrad. With a momentum term, AdaSAM improves the average score by 0.28 over AMSGrad. This shows that the adaptive method can improve performance with or without momentum acceleration, and that it achieves the best performance with momentum acceleration. We can also see that momentum acceleration improves the performance of SAM, AMSGrad, and AdaSAM.

V Conclusion

In this work, we study the convergence rate of the sharpness-aware minimization optimizer with an adaptive learning rate and momentum acceleration, dubbed AdaSAM, in the stochastic non-convex setting. To the best of our knowledge, we are the first to provide the non-trivial $\mathcal{O}(1/\sqrt{bT})$ convergence rate of AdaSAM, which achieves a linear speedup property with respect to the mini-batch size $b$. We have conducted extensive experiments on several NLP tasks, which verify that AdaSAM achieves superior performance compared with the AMSGrad and SAM optimizers. Future work includes extending AdaSAM to the distributed setting and reducing the cost of the two gradient back-propagations.

References

  • [1] P. Foret, A. Kleiner, H. Mobahi, and B. Neyshabur, “Sharpness-aware minimization for efficiently improving generalization,” in International Conference on Learning Representations, 2021.
  • [2] K. He, X. Zhang, S. Ren, and J. Sun, “Deep residual learning for image recognition,” in Proceedings of the IEEE conference on computer vision and pattern recognition, 2016, pp. 770–778.
  • [3] P. Mi, L. Shen, T. Ren, Y. Zhou, X. Sun, R. Ji, and D. Tao, “Make sharpness-aware minimization stronger: A sparsified perturbation approach,” in Advances in Neural Information Processing Systems.
  • [4] A. Dosovitskiy, L. Beyer, A. Kolesnikov, D. Weissenborn, X. Zhai, T. Unterthiner, M. Dehghani, M. Minderer, G. Heigold, S. Gelly, J. Uszkoreit, and N. Houlsby, “An image is worth 16x16 words: Transformers for image recognition at scale,” in International Conference on Learning Representations, 2021.
  • [5] X. Chen, C.-J. Hsieh, and B. Gong, “When vision transformers outperform resnets without pre-training or strong data augmentations,” in International Conference on Learning Representation, 2022.
  • [6] J. Devlin, M.-W. Chang, K. Lee, and K. Toutanova, “Bert: Pre-training of deep bidirectional transformers for language understanding,” arXiv preprint arXiv:1810.04805, 2018.
  • [7] P. He, X. Liu, J. Gao, and W. Chen, “Deberta: Decoding-enhanced bert with disentangled attention,” in ICLR, 2020.
  • [8] Q. Zhong, L. Ding, L. Shen, P. Mi, J. Liu, B. Du, and D. Tao, “Improving sharpness-aware minimization with fisher mask for better generalization on language models,” arXiv preprint arXiv:2210.05497, 2022.
  • [9] D. P. Kingma and J. Ba, “Adam: A method for stochastic optimization,” in ICLR (Poster), 2015. [Online]. Available: http://arxiv.org/abs/1412.6980
  • [10] S. J. Reddi, S. Kale, and S. Kumar, “On the convergence of adam and beyond,” in International Conference on Learning Representations, 2018.
  • [11] H. Iiduka, “Appropriate learning rates of adaptive learning rate optimization algorithms for training deep neural networks,” IEEE Trans. Cybern., vol. 52, no. 12, pp. 13 250–13 261, 2022.
  • [12] J. Duchi, E. Hazan, and Y. Singer, “Adaptive subgradient methods for online learning and stochastic optimization.” Journal of machine learning research, vol. 12, no. 7, 2011.
  • [13] S. Ruder, “An overview of gradient descent optimization algorithms,” arXiv preprint arXiv:1609.04747, 2016.
  • [14] L. Liao, L. Shen, J. Duan, M. Kolar, and D. Tao, “Local adagrad-type algorithm for stochastic convex-concave optimization,” Machine Learning, pp. 1–20, 2022.
  • [15] J. Zhang, S. P. Karimireddy, A. Veit, S. Kim, S. Reddi, S. Kumar, and S. Sra, “Why are adaptive methods good for attention models?” Advances in Neural Information Processing Systems, vol. 33, pp. 15 383–15 393, 2020.
  • [16] J. Zhuang, B. Gong, L. Yuan, Y. Cui, H. Adam, N. C. Dvornek, sekhar tatikonda, J. s Duncan, and T. Liu, “Surrogate gap minimization improves sharpness-aware training,” in International Conference on Learning Representations, 2022.
  • [17] J. Kwon, J. Kim, H. Park, and I. K. Choi, “Asam: Adaptive sharpness-aware minimization for scale-invariant learning of deep neural networks,” in International Conference on Machine Learning. PMLR, 2021, pp. 5905–5914.
  • [18] N. Srivastava, G. E. Hinton, A. Krizhevsky, I. Sutskever, and R. Salakhutdinov, “Dropout: a simple way to prevent neural networks from overfitting,” J. Mach. Learn. Res., vol. 15, no. 1, pp. 1929–1958, 2014.
  • [19] I. Loshchilov and F. Hutter, “Decoupled weight decay regularization,” arXiv preprint arXiv:1711.05101, 2017.
  • [20] Z. Li, H. Zhao, Y. Guo, Z. Yang, and S. Xie, “Accelerated log-regularized convolutional transform learning and its convergence guarantee,” IEEE Trans. Cybern., vol. 52, no. 10, pp. 10 785–10 799, 2022.
  • [21] Y. Lu, Z. Zhang, G. Lu, Y. Zhou, J. Li, and D. Zhang, “Addi-reg: A better generalization-optimization tradeoff regularization method for convolutional neural networks,” IEEE Trans. Cybern., vol. 52, no. 10, pp. 10 827–10 842, 2022.
  • [22] Y. Jiang, B. Neyshabur, H. Mobahi, D. Krishnan, and S. Bengio, “Fantastic generalization measures and where to find them,” in International Conference on Learning Representations, 2020.
  • [23] N. S. Keskar, D. Mudigere, J. Nocedal, M. Smelyanskiy, and P. T. P. Tang, “On large-batch training for deep learning: Generalization gap and sharp minima,” in 5th International Conference on Learning Representations, ICLR 2017. OpenReview.net, 2017.
  • [24] H. He, G. Huang, and Y. Yuan, “Asymmetric valleys: Beyond sharp and flat local minima,” Advances in neural information processing systems, vol. 32, 2019.
  • [25] J. Du, H. Yan, J. Feng, J. T. Zhou, L. Zhen, R. S. M. Goh, and V. Tan, “Efficient sharpness-aware minimization for improved training of neural networks,” in International Conference on Learning Representations, 2022.
  • [26] Y. Liu, S. Mai, X. Chen, C.-J. Hsieh, and Y. You, “Towards efficient and scalable sharpness-aware minimization,” in Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), June 2022, pp. 12 360–12 370.
  • [27] Z. Qu, X. Li, R. Duan, Y. Liu, B. Tang, and Z. Lu, “Generalized federated learning via sharpness aware minimization,” in International Conference on Machine Learning. PMLR, 2022, pp. 18 250–18 280.
  • [28] Y. Sun, L. Shen, T. Huang, L. Ding, and D. Tao, “Fedspeed: Larger local interval, less communication round, and higher generalization accuracy,” in International Conference on Learning Representations.
  • [29] M. Andriushchenko and N. Flammarion, “Towards understanding sharpness-aware minimization,” in International Conference on Machine Learning. PMLR, 2022, pp. 639–668.
  • [30] M. D. Zeiler, “Adadelta: an adaptive learning rate method,” arXiv preprint arXiv:1212.5701, 2012.
  • [31] D. Zhou, J. Chen, Y. Cao, Y. Tang, Z. Yang, and Q. Gu, “On the convergence of adaptive gradient methods for nonconvex optimization,” arXiv preprint arXiv:1808.05671, 2018.
  • [32] X. Chen, S. Liu, R. Sun, and M. Hong, “On the convergence of a class of adam-type algorithms for non-convex optimization,” in International Conference on Learning Representations, 2019.
  • [33] M. Zaheer, S. Reddi, D. Sachan, S. Kale, and S. Kumar, “Adaptive methods for nonconvex optimization,” Advances in neural information processing systems, vol. 31, 2018.
  • [34] R. Ward, X. Wu, and L. Bottou, “Adagrad stepsizes: Sharp convergence over nonconvex landscapes,” in International Conference on Machine Learning. PMLR, 2019, pp. 6677–6686.
  • [35] A. Défossez, L. Bottou, F. Bach, and N. Usunier, “On the convergence of adam and adagrad,” CoRR, vol. abs/2003.02395, 2020. [Online]. Available: https://arxiv.org/abs/2003.02395
  • [36] F. Zou, L. Shen, Z. Jie, W. Zhang, and W. Liu, “A sufficient condition for convergences of adam and rmsprop,” in Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2019, pp. 11 127–11 135.
  • [37] C. Chen, L. Shen, F. Zou, and W. Liu, “Towards practical adam: Non-convexity, convergence theory, and mini-batch acceleration,” arXiv preprint arXiv:2101.05471, 2021.
  • [38] C. Chen, L. Shen, H. Huang, and W. Liu, “Quantized adam with error feedback,” ACM Transactions on Intelligent Systems and Technology (TIST), vol. 12, no. 5, pp. 1–26, 2021.
  • [39] C. Chen, L. Shen, F. Zou, and W. Liu, “Towards practical adam: Non-convexity, convergence theory, and mini-batch acceleration,” Journal of Machine Learning Research, vol. 23, pp. 1–47, 2022.
  • [40] C. Chen, L. Shen, W. Liu, and Z.-Q. Luo, “Efficient-adam: Communication-efficient distributed adam with complexity analysis,” arXiv preprint arXiv:2205.14473, 2022.
  • [41] F. Zou, L. Shen, Z. Jie, J. Sun, and W. Liu, “Weighted adagrad with unified momentum,” arXiv preprint arXiv:1808.03408, 2018.
  • [42] H. Iiduka, “Appropriate learning rates of adaptive learning rate optimization algorithms for training deep neural networks,” IEEE Transactions on Cybernetics, vol. 52, no. 12, pp. 13 250–13 261, 2021.
  • [43] S. Sun, Z. Cao, H. Zhu, and J. Zhao, “A survey of optimization methods from a machine learning perspective,” IEEE transactions on cybernetics, vol. 50, no. 8, pp. 3668–3681, 2019.
  • [44] H. Sakai and H. Iiduka, “Riemannian adaptive optimization algorithm and its application to natural language processing,” IEEE Transactions on Cybernetics, vol. 52, no. 8, pp. 7328–7339, 2021.
  • [45] B. T. Polyak, “Some methods of speeding up the convergence of iteration methods,” Ussr computational mathematics and mathematical physics, vol. 4, no. 5, pp. 1–17, 1964.
  • [46] Y. Nesterov, Introductory lectures on convex optimization: A basic course. Springer Science & Business Media, 2003, vol. 87.
  • [47] B. O’Donoghue and E. J. Candès, “Adaptive restart for accelerated gradient schemes,” Found. Comput. Math., vol. 15, no. 3, pp. 715–732, 2015.
  • [48] W. Liu, L. Chen, Y. Chen, and W. Zhang, “Accelerating federated learning via momentum gradient descent,” IEEE Transactions on Parallel and Distributed Systems, vol. 31, no. 8, pp. 1754–1766, 2020.
  • [49] X. Luo, Z. Liu, S. Li, M. Shang, and Z. Wang, “A fast non-negative latent factor model based on generalized momentum method,” IEEE Transactions on Systems, Man, and Cybernetics: Systems, vol. 51, no. 1, pp. 610–620, 2018.
  • [50] M. Shang, Y. Yuan, X. Luo, and M. Zhou, “An α\alphaβ\beta-divergence-generalized recommender for highly accurate predictions of missing user preferences,” IEEE transactions on cybernetics, vol. 52, no. 8, pp. 8006–8018, 2021.
  • [51] T. Yang, Q. Lin, and Z. Li, “Unified convergence analysis of stochastic momentum methods for convex and non-convex optimization,” arXiv preprint arXiv:1604.03257, 2016.
  • [52] S. S. Mannelli and P. Urbani, “Analytical study of momentum-based acceleration methods in paradigmatic high-dimensional non-convex problems,” in Advances in Neural Information Processing Systems 34: Annual Conference on Neural Information Processing Systems 2021, NeurIPS 2021, December 6-14, 2021, virtual, M. Ranzato, A. Beygelzimer, Y. N. Dauphin, P. Liang, and J. W. Vaughan, Eds., 2021, pp. 187–199.
  • [53] X. Gao, M. Gürbüzbalaban, and L. Zhu, “Global convergence of stochastic gradient hamiltonian monte carlo for nonconvex stochastic optimization: Nonasymptotic performance bounds and momentum-based acceleration,” Operations Research, vol. 70, no. 5, pp. 2931–2947, 2022.
  • [54] I. Sutskever, J. Martens, G. Dahl, and G. Hinton, “On the importance of initialization and momentum in deep learning,” in International conference on machine learning. PMLR, 2013, pp. 1139–1147.
  • [55] B. Can, M. Gürbüzbalaban, and L. Zhu, “Accelerated linear convergence of stochastic momentum methods in wasserstein distances,” in Proceedings of the 36th International Conference on Machine Learning, ICML 2019, 9-15 June 2019, Long Beach, California, USA, ser. Proceedings of Machine Learning Research, K. Chaudhuri and R. Salakhutdinov, Eds., vol. 97. PMLR, 2019, pp. 891–901.
  • [56] F. Huang, S. Gao, J. Pei, and H. Huang, “Accelerated zeroth-order and first-order momentum methods from mini to minimax optimization,” J. Mach. Learn. Res., vol. 23, pp. 36:1–36:70, 2022.
  • [57] R. Bollapragada, D. Scieur, and A. d’Aspremont, “Nonlinear acceleration of momentum and primal-dual algorithms,” Mathematical Programming, pp. 1–38, 2022.
  • [58] B. O’donoghue and E. Candes, “Adaptive restart for accelerated gradient schemes,” Foundations of computational mathematics, vol. 15, pp. 715–732, 2015.
  • [59] B. Wang, T. M. Nguyen, T. Sun, A. L. Bertozzi, R. G. Baraniuk, and S. J. Osher, “Scheduled restart momentum for accelerated stochastic gradient descent,” SIAM J. Imaging Sci., vol. 15, no. 2, pp. 738–761, 2022.
  • [60] B. Liu, L. Chai, and J. Yi, “Convergence analysis of distributed gradient descent algorithms with one and two momentum terms,” IEEE Transactions on Cybernetics, 2022.
  • [61] A. Cutkosky and F. Orabona, “Momentum-based variance reduction in non-convex sgd,” Advances in neural information processing systems, vol. 32, 2019.
  • [62] F. Huang, J. Li, and H. Huang, “Super-adam: faster and universal framework of adaptive gradients,” Advances in Neural Information Processing Systems, vol. 34, 2021.
  • [63] M. Li, D. G. Andersen, A. J. Smola, and K. Yu, “Communication efficient distributed machine learning with the parameter server,” Advances in Neural Information Processing Systems, vol. 27, 2014.
  • [64] M. Li, T. Zhang, Y. Chen, and A. J. Smola, “Efficient mini-batch training for stochastic optimization,” in Proceedings of the 20th ACM SIGKDD international conference on Knowledge discovery and data mining, 2014, pp. 661–670.
  • [65] A. Wang, A. Singh, J. Michael, F. Hill, O. Levy, and S. Bowman, “Glue: A multi-task benchmark and analysis platform for natural language understanding,” in EMNLP, 2018.
  • [66] Y. Liu, M. Ott, N. Goyal, J. Du, M. Joshi, D. Chen, O. Levy, M. Lewis, L. Zettlemoyer, and V. Stoyanov, “Roberta: A robustly optimized bert pretraining approach,” arXiv, 2019.
  • [67] X. Li, B. Karimi, and P. Li, “On distributed adaptive optimization with gradient compression,” in International Conference on Learning Representations, 2021.

In this supplementary material, we give additional discussion of this paper. In Appendix A, detailed experimental settings such as hyper-parameters are listed. In Appendix B, we first give the proof and then present several useful lemmas that help prove the main theorem. In Appendix C, we provide additional experimental illustrations.

Appendix A Experimental Settings

TABLE III: Experimental settings and data divisions for different downstream tasks. Notably, for each task in the GLUE benchmark, we provide the number of classes ("classes"), the learning rate ("lr"), the batch size ("bsz"), the total number of updates ("total"), the number of warmup updates ("warmup"), and the number of GPUs ("GPUs") used during fine-tuning.
MNLI QNLI QQP RTE SST-2 MRPC CoLA STS-B
experimental settings upon different downstream tasks
–classes 3 2 2 2 2 2 2 1
–lr 1e-5 1e-5 1e-5 2e-5 1e-5 1e-5 1e-5 2e-5
–bsz 256 128 256 32 64 32 32 32
–total 15,484 8,278 14,453 1,018 10,467 1,148 2,668 1,799
–warmup 929 496 867 61 628 68 160 107
–GPUs 4 4 8 2 2 2 2 2
data divisions for each dataset
train 392,720 104,743 363,870 2,491 67,350 5,801 8,551 5,749
dev 9,815 5,463 40,431 277 873 4,076 1,043 1,500
test 9,796 5,461 390,956 3,000 1,821 1,725 1,063 1,379

The GLUE benchmark contains 8 tasks: RTE, STS-B, CoLA, SST-2, MNLI, MRPC, QNLI, and QQP. CoLA is a single-sentence task. Each sentence has a label of 1 or -1, where 1 indicates a grammatical sentence and -1 indicates an ungrammatical one. The Matthews correlation coefficient, dubbed mcc, is used as the evaluation metric. STS-B is a similarity and paraphrase task. Each sample contains a pair of paragraphs, annotated from 1 to 5 based on the similarity between the two paragraphs. The metrics are the Pearson and Spearman correlation coefficients, dubbed p/s. RTE is an inference task. Each sample has two sentences; if the two sentences have an entailment relation, we view them as a positive sample, otherwise they form a negative sample. For RTE, the metric is accuracy, dubbed acc. SST-2 is a single-sentence task and its metric is accuracy. MNLI is a sentence-level task with 3 classes: entailment, contradiction, and neutral. MRPC is a task to classify whether the sentences in a pair are equivalent. QNLI is a question-answering task; if the sentence contains the answer to the question, then it is a positive sample. QQP is a social question-answering task that consists of question pairs from Quora; it determines whether the questions are equivalent. The metric of MNLI, MRPC, QNLI, and QQP is accuracy.

Appendix B Proof of the Main Results

We set z_{t}=x_{t}+\frac{\beta_{1}}{1-\beta_{1}}(x_{t}-x_{t-1}) for t\geq 0, and we assume x_{-1}=x_{0} and m_{-1}=0, so that z_{0}=x_{0}.

Throughout the proof we use the update rules x_{t+1}=x_{t}-\gamma m_{t}\odot\eta_{t} and m_{t}=\beta_{1}m_{t-1}+(1-\beta_{1})g_{t}, where g_{t}=\frac{1}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{s_{t}}{\|s_{t}\|}) is the mini-batch gradient at the perturbed point and s_{t}=\sum_{i\in B}\nabla f_{i}(x_{t}) is the perturbation direction. We then have

\displaystyle z_{t+1}-z_{t}=x_{t+1}+\frac{\beta_{1}}{1-\beta_{1}}(x_{t+1}-x_{t})-x_{t}-\frac{\beta_{1}}{1-\beta_{1}}(x_{t}-x_{t-1}) (21)
\displaystyle=\frac{1}{1-\beta_{1}}(x_{t+1}-x_{t})-\frac{\beta_{1}}{1-\beta_{1}}(x_{t}-x_{t-1}) (22)
\displaystyle=-\frac{1}{1-\beta_{1}}\gamma m_{t}\odot\eta_{t}+\frac{\beta_{1}}{1-\beta_{1}}\gamma m_{t-1}\odot\eta_{t-1} (23)
\displaystyle=-\frac{1}{1-\beta_{1}}\gamma(\beta_{1}m_{t-1}+(1-\beta_{1})g_{t})\odot\eta_{t}+\frac{\beta_{1}}{1-\beta_{1}}\gamma m_{t-1}\odot\eta_{t-1} (24)
\displaystyle=\frac{\beta_{1}}{1-\beta_{1}}\gamma m_{t-1}\odot(\eta_{t-1}-\eta_{t})-\gamma g_{t}\odot\eta_{t} (25)
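As a quick numerical sanity check of identity (25), the sketch below draws random vectors, applies the update rules recalled above, and verifies that both sides of (25) coincide. It only illustrates the algebra and plays no role in the analysis.

import numpy as np

rng = np.random.default_rng(0)
d, gamma, beta1 = 5, 0.1, 0.9
x_prev = rng.normal(size=d)                   # x_{t-1}
m_prev = rng.normal(size=d)                   # m_{t-1}
eta_prev = rng.uniform(0.5, 1.0, size=d)      # eta_{t-1}
eta = rng.uniform(0.5, 1.0, size=d)           # eta_t
g = rng.normal(size=d)                        # g_t, the perturbed mini-batch gradient

x = x_prev - gamma * m_prev * eta_prev        # x_t = x_{t-1} - gamma * m_{t-1} (elementwise) eta_{t-1}
m = beta1 * m_prev + (1.0 - beta1) * g        # m_t
x_next = x - gamma * m * eta                  # x_{t+1} = x_t - gamma * m_t (elementwise) eta_t

z = x + beta1 / (1.0 - beta1) * (x - x_prev)
z_next = x_next + beta1 / (1.0 - beta1) * (x_next - x)

lhs = z_next - z
rhs = beta1 / (1.0 - beta1) * gamma * m_prev * (eta_prev - eta) - gamma * g * eta
assert np.allclose(lhs, rhs)                  # identity (25) holds coordinate-wise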

By applying the L-smoothness of f, we have

f(zt+1)f(zt)+f(zt),zt+1zt+L2zt+1zt2\displaystyle f(z_{t+1})\leq f(z_{t})+\langle\nabla f(z_{t}),z_{t+1}-z_{t}\rangle+\frac{L}{2}\|z_{t+1}-z_{t}\|^{2} (26)

Rearranging, we have

f(zt+1)f(zt)\displaystyle f(z_{t+1})-f(z_{t})
f(zt),zt+1zt+L2zt+1zt2\displaystyle\leq\langle\nabla f(z_{t}),z_{t+1}-z_{t}\rangle+\frac{L}{2}\|z_{t+1}-z_{t}\|^{2} (27)
=f(zt),γβ11β1mt1(ηt1ηt)+f(zt),γgtηt+L2zt+1zt2\displaystyle=\langle\nabla f(z_{t}),\frac{\gamma\beta_{1}}{1-\beta_{1}}m_{t-1}\odot(\eta_{t-1}-\eta_{t})\rangle+\langle\nabla f(z_{t}),-\gamma g_{t}\odot\eta_{t}\rangle+\frac{L}{2}\|z_{t+1}-z_{t}\|^{2} (28)
=f(zt),γβ11β1mt1(ηt1ηt)+L2zt+1zt2\displaystyle=\langle\nabla f(z_{t}),\frac{\gamma\beta_{1}}{1-\beta_{1}}m_{t-1}\odot(\eta_{t-1}-\eta_{t})\rangle+\frac{L}{2}\|z_{t+1}-z_{t}\|^{2}
+f(zt),γtbiBfi(xt+ρtstst)(ηt1ηt)\displaystyle\;\;\;\;+\langle\nabla f(z_{t}),\frac{\gamma_{t}}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{s_{t}}{\|s_{t}\|})\odot(\eta_{t-1}-\eta_{t})\rangle
+f(zt),γtbiBfi(xt+ρtstst)ηt1\displaystyle\;\;\;\;+\langle\nabla f(z_{t}),-\frac{\gamma_{t}}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{s_{t}}{\|s_{t}\|})\odot\eta_{t-1}\rangle (29)
=f(zt),γβ11β1mt1(ηt1ηt)+L2zt+1zt2\displaystyle=\langle\nabla f(z_{t}),\frac{\gamma\beta_{1}}{1-\beta_{1}}m_{t-1}\odot(\eta_{t-1}-\eta_{t})\rangle+\frac{L}{2}\|z_{t+1}-z_{t}\|^{2}
+f(zt),γtbiBfi(xt+ρtstst)(ηt1ηt)\displaystyle\;\;\;\;+\langle\nabla f(z_{t}),\frac{\gamma_{t}}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{s_{t}}{\|s_{t}\|})\odot(\eta_{t-1}-\eta_{t})\rangle
+f(zt)f(xt),γtbiBfi(xt+ρtstst)ηt1\displaystyle\;\;\;\;+\langle\nabla f(z_{t})-\nabla f(x_{t}),-\frac{\gamma_{t}}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{s_{t}}{\|s_{t}\|})\odot\eta_{t-1}\rangle
+f(xt),γtbiBfi(xt+ρtstst)ηt1\displaystyle\;\;\;\;+\langle\nabla f(x_{t}),-\frac{\gamma_{t}}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{s_{t}}{\|s_{t}\|})\odot\eta_{t-1}\rangle (30)
=f(zt),γβ11β1mt1(ηt1ηt)+L2zt+1zt2\displaystyle=\langle\nabla f(z_{t}),\frac{\gamma\beta_{1}}{1-\beta_{1}}m_{t-1}\odot(\eta_{t-1}-\eta_{t})\rangle+\frac{L}{2}\|z_{t+1}-z_{t}\|^{2}
+f(zt),γtbiBfi(xt+ρtstst)(ηt1ηt)\displaystyle\;\;\;\;+\langle\nabla f(z_{t}),\frac{\gamma_{t}}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{s_{t}}{\|s_{t}\|})\odot(\eta_{t-1}-\eta_{t})\rangle
+f(zt)f(xt),γtbiBfi(xt+ρtfi(xt)fi(xt))ηt1\displaystyle\;\;\;\;+\langle\nabla f(z_{t})-\nabla f(x_{t}),-\frac{\gamma_{t}}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{\sum\nabla f_{i}(x_{t})}{\|\sum\nabla f_{i}(x_{t})\|})\odot\eta_{t-1}\rangle
+f(xt),γtbiBfi(xt+ρtf(xt)f(xt))ηt1γtbiBfi(xt+ρtstst)ηt1\displaystyle\;\;\;\;+\langle\nabla f(x_{t}),\frac{\gamma_{t}}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{\nabla f(x_{t})}{\|\nabla f(x_{t})\|})\odot\eta_{t-1}-\frac{\gamma_{t}}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{s_{t}}{\|s_{t}\|})\odot\eta_{t-1}\rangle
+f(xt),γtbiBfi(xt+ρtf(xt)f(xt))ηt1.\displaystyle\;\;\;\;+\langle\nabla f(x_{t}),-\frac{\gamma_{t}}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{\nabla f(x_{t})}{\|\nabla f(x_{t})\|})\odot\eta_{t-1}\rangle. (31)

From Lemma 5, Lemma 6, and Lemma 7, we have

f(zt),γtbiBfi(xt+ρtstst)(ηt1ηt)γtG2ηt1ηt1,\displaystyle\langle\nabla f(z_{t}),\frac{\gamma_{t}}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{s_{t}}{\|s_{t}\|})\odot(\eta_{t-1}-\eta_{t})\rangle\leq\gamma_{t}G^{2}\|\eta_{t-1}-\eta_{t}\|_{1}, (32)
f(zt),γβ11β1mt1(ηt1ηt)γβ11β1G2ηt1ηt1,\displaystyle\langle\nabla f(z_{t}),\frac{\gamma\beta_{1}}{1-\beta_{1}}m_{t-1}\odot(\eta_{t-1}-\eta_{t})\rangle\leq\frac{\gamma\beta_{1}}{1-\beta_{1}}G^{2}\|\eta_{t-1}-\eta_{t}\|_{1}, (33)
\displaystyle\langle\nabla f(x_{t}),\frac{\gamma_{t}}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{\nabla f(x_{t})}{\|\nabla f(x_{t})\|})\odot\eta_{t-1}-\frac{\gamma_{t}}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{s_{t}}{\|s_{t}\|})\odot\eta_{t-1}\rangle
γt2μ2f(xt)ηt12+2μ2γtL2ρt2ϵ.\displaystyle\leq\frac{\gamma_{t}}{2\mu^{2}}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}+\frac{2\mu^{2}\gamma_{t}L^{2}\rho_{t}^{2}}{\epsilon}. (34)

Taking conditional expectation, we have

𝔼f(zt+1)f(zt)\displaystyle\mathbb{E}f(z_{t+1})-f(z_{t}) (35)
𝔼f(xt),γtbiBfi(xt+ρtf(xt)f(xt))ηt1+L2𝔼zt+1zt2\displaystyle\leq\mathbb{E}\langle\nabla f(x_{t}),-\frac{\gamma_{t}}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{\nabla f(x_{t})}{\|\nabla f(x_{t})\|})\odot\eta_{t-1}\rangle+\frac{L}{2}\mathbb{E}\|z_{t+1}-z_{t}\|^{2}
+γt2μ2f(xt)ηt12+2μ2γtL2ρt2ϵ+γ1β1G2ηt1ηt1\displaystyle+\frac{\gamma_{t}}{2\mu^{2}}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}+\frac{2\mu^{2}\gamma_{t}L^{2}\rho_{t}^{2}}{\epsilon}+\frac{\gamma}{1-\beta_{1}}G^{2}\|\eta_{t-1}-\eta_{t}\|_{1}
+𝔼f(zt)f(xt),γtbiBfi(xt+ρtstst)ηt1\displaystyle+\mathbb{E}\langle\nabla f(z_{t})-\nabla f(x_{t}),-\frac{\gamma_{t}}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{s_{t}}{\|s_{t}\|})\odot\eta_{t-1}\rangle (36)

where μ>0\mu>0 is to be determined.

For the term

𝔼f(xt),γtbiBfi(xt+ρtf(xt)f(xt))ηt1,\displaystyle\;\;\;\;\mathbb{E}\langle\nabla f(x_{t}),-\frac{\gamma_{t}}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{\nabla f(x_{t})}{\|\nabla f(x_{t})\|})\odot\eta_{t-1}\rangle, (37)

the term

L2𝔼zt+1zt2,\displaystyle\frac{L}{2}\mathbb{E}\|z_{t+1}-z_{t}\|^{2}, (38)

and the term

𝔼f(zt)f(xt),γtbiBfi(xt+ρtstst)ηt1,\displaystyle\mathbb{E}\langle\nabla f(z_{t})-\nabla f(x_{t}),-\frac{\gamma_{t}}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{s_{t}}{\|s_{t}\|})\odot\eta_{t-1}\rangle, (39)

we apply Lemma 8, Lemma 10, and Lemma 9, respectively. Taking the expectation over the whole process, we have

𝔼f(zt+1)𝔼f(zt)\displaystyle\mathbb{E}f(z_{t+1})-\mathbb{E}f(z_{t})
γt2μ2𝔼f(xt)ηt12+2μ2γtL2ρt2ϵ+γ1β1G2𝔼ηt1ηt1\displaystyle\leq\frac{\gamma_{t}}{2\mu^{2}}\mathbb{E}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}+\frac{2\mu^{2}\gamma_{t}L^{2}\rho_{t}^{2}}{\epsilon}+\frac{\gamma}{1-\beta_{1}}G^{2}\mathbb{E}\|\eta_{t-1}-\eta_{t}\|_{1}
\displaystyle-\gamma_{t}\mathbb{E}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}+\frac{\gamma_{t}}{2\alpha^{2}}\mathbb{E}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}+\frac{\gamma_{t}\alpha^{2}L^{2}\rho^{2}}{2\epsilon}+\frac{LG^{2}\gamma^{2}\beta_{1}^{2}}{(1-\beta_{1})^{2}}\mathbb{E}\|\eta_{t}-\eta_{t-1}\|^{2}
+γt2L(31+ββϵ(𝔼f(xt)ηt12+Lρt2ϵ+σ2bϵ)+(1+β)G2𝔼ηtηt12)\displaystyle+\gamma_{t}^{2}L(3\frac{1+\beta}{\beta\epsilon}(\mathbb{E}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}+\frac{L\rho_{t}^{2}}{\epsilon}+\frac{\sigma^{2}}{b\epsilon})+(1+\beta)G^{2}\mathbb{E}\|\eta_{t}-\eta_{t-1}\|^{2})
+γ3L2β122ϵ(1β1)2(1λ12+1λ22+1λ32)dG2ϵ2+γλ122f(xt)ηt12+γL2ρt22ϵ(λ22+4λ32)\displaystyle+\frac{\gamma^{3}L^{2}\beta_{1}^{2}}{2\epsilon(1-\beta_{1})^{2}}(\frac{1}{\lambda_{1}^{2}}+\frac{1}{\lambda_{2}^{2}}+\frac{1}{\lambda_{3}^{2}})\frac{dG_{\infty}^{2}}{\epsilon^{2}}+\frac{\gamma\lambda_{1}^{2}}{2}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}+\frac{\gamma L^{2}\rho_{t}^{2}}{2\epsilon}(\lambda_{2}^{2}+4\lambda_{3}^{2}) (40)
=γt(112μ212α23γL(1+β)βϵλ122)𝔼f(xt)ηt12+2μ2γtL2ρt2ϵ+γ1β1G2𝔼ηt1ηt1\displaystyle=-\gamma_{t}(1-\frac{1}{2\mu^{2}}-\frac{1}{2\alpha^{2}}-\frac{3\gamma L(1+\beta)}{\beta\epsilon}-\frac{\lambda_{1}^{2}}{2})\mathbb{E}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}+\frac{2\mu^{2}\gamma_{t}L^{2}\rho_{t}^{2}}{\epsilon}+\frac{\gamma}{1-\beta_{1}}G^{2}\mathbb{E}\|\eta_{t-1}-\eta_{t}\|_{1}
+γtα2L2ρ22ϵ+3γt2L(1+β)βϵ(Lρt2ϵ+σ2bϵ)+γt2LG2((β11β1)2+1+β)𝔼ηtηt12\displaystyle+\frac{\gamma_{t}\alpha^{2}L^{2}\rho^{2}}{2\epsilon}+\frac{3\gamma_{t}^{2}L(1+\beta)}{\beta\epsilon}(\frac{L\rho_{t}^{2}}{\epsilon}+\frac{\sigma^{2}}{b\epsilon})+\gamma_{t}^{2}LG^{2}((\frac{\beta_{1}}{1-\beta_{1}})^{2}+1+\beta)\mathbb{E}\|\eta_{t}-\eta_{t-1}\|^{2}
+γ3L2β122ϵ(1β1)2(1λ12+1λ22+1λ32)dG2ϵ2+γL2ρt22ϵ(λ22+4λ32).\displaystyle+\frac{\gamma^{3}L^{2}\beta_{1}^{2}}{2\epsilon(1-\beta_{1})^{2}}(\frac{1}{\lambda_{1}^{2}}+\frac{1}{\lambda_{2}^{2}}+\frac{1}{\lambda_{3}^{2}})\frac{dG_{\infty}^{2}}{\epsilon^{2}}+\frac{\gamma L^{2}\rho_{t}^{2}}{2\epsilon}(\lambda_{2}^{2}+4\lambda_{3}^{2}). (41)

We set \mu^{2}=\alpha^{2}=8, \beta=3, \lambda_{1}^{2}=\frac{1}{4}, and \lambda_{2}^{2}=\lambda_{3}^{2}=1, and we choose \gamma_{t} such that \frac{2\gamma_{t}L}{\epsilon}\leq\frac{1}{8}. Then we have

\displaystyle\mathbb{E}f(z_{t+1})-\mathbb{E}f(z_{t})
γt2𝔼f(xt)ηt12+16γtL2ρt2ϵ+γ1β1G2𝔼ηt1ηt1\displaystyle\leq-\frac{\gamma_{t}}{2}\mathbb{E}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}+\frac{16\gamma_{t}L^{2}\rho_{t}^{2}}{\epsilon}+\frac{\gamma}{1-\beta_{1}}G^{2}\mathbb{E}\|\eta_{t-1}-\eta_{t}\|_{1}
+4γtL2ρ2ϵ+4γt2Lϵ(Lρt2ϵ+σ2bϵ)+(4+(β11β1)2)γt2LG2𝔼ηtηt12\displaystyle+\frac{4\gamma_{t}L^{2}\rho^{2}}{\epsilon}+\frac{4\gamma_{t}^{2}L}{\epsilon}(\frac{L\rho_{t}^{2}}{\epsilon}+\frac{\sigma^{2}}{b\epsilon})+(4+(\frac{\beta_{1}}{1-\beta_{1}})^{2})\gamma_{t}^{2}LG^{2}\mathbb{E}\|\eta_{t}-\eta_{t-1}\|^{2}
+3γ3L2β12ϵ(1β1)2dG2ϵ2+5γL2ρt22ϵ\displaystyle+\frac{3\gamma^{3}L^{2}\beta_{1}^{2}}{\epsilon(1-\beta_{1})^{2}}\frac{dG_{\infty}^{2}}{\epsilon^{2}}+\frac{5\gamma L^{2}\rho_{t}^{2}}{2\epsilon} (42)
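For completeness, we spell out why the coefficient of \mathbb{E}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2} in (41) is at least \frac{1}{2} under these choices, which gives the leading term in (42):

\displaystyle 1-\frac{1}{2\mu^{2}}-\frac{1}{2\alpha^{2}}-\frac{3\gamma_{t}L(1+\beta)}{\beta\epsilon}-\frac{\lambda_{1}^{2}}{2}=1-\frac{1}{16}-\frac{1}{16}-2\cdot\frac{2\gamma_{t}L}{\epsilon}-\frac{1}{8}\geq 1-\frac{1}{16}-\frac{1}{16}-\frac{1}{4}-\frac{1}{8}=\frac{1}{2}.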

Rearranging and using the boundedness of \eta_{t} (Lemma 3), we have

\displaystyle\frac{\gamma_{t}}{2G}\mathbb{E}\|\nabla f(x_{t})\|^{2}\leq\frac{\gamma_{t}}{2}\mathbb{E}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2} (43)
\displaystyle\leq-\mathbb{E}f(z_{t+1})+\mathbb{E}f(z_{t})+\frac{45\gamma_{t}L^{2}\rho_{t}^{2}}{2\epsilon}+\frac{\gamma}{1-\beta_{1}}G^{2}\mathbb{E}\|\eta_{t-1}-\eta_{t}\|_{1}
+4γt2Lϵ(Lρt2ϵ+σ2bϵ)+(4+(β11β1)2)γt2LG2𝔼ηtηt12+3γ3L2β12(1β1)2dG2ϵ3.\displaystyle+\frac{4\gamma_{t}^{2}L}{\epsilon}(\frac{L\rho_{t}^{2}}{\epsilon}+\frac{\sigma^{2}}{b\epsilon})+(4+(\frac{\beta_{1}}{1-\beta_{1}})^{2})\gamma_{t}^{2}LG^{2}\mathbb{E}\|\eta_{t}-\eta_{t-1}\|^{2}+\frac{3\gamma^{3}L^{2}\beta_{1}^{2}}{(1-\beta_{1})^{2}}\frac{dG_{\infty}^{2}}{\epsilon^{3}}. (44)

Summing from t=0 to T-1 with a constant step size \gamma_{t}\equiv\gamma, and using z_{0}=x_{0} together with f(z_{T})\geq f^{*}, we have

\displaystyle\frac{1}{T}\sum_{t=0}^{T-1}\mathbb{E}\|\nabla f(x_{t})\|^{2}\leq 2G\frac{\mathbb{E}f(z_{0})-\mathbb{E}f(z_{T})}{\gamma_{t}T}+\frac{45GL^{2}\rho_{t}^{2}}{\epsilon}+\frac{2G^{3}}{(1-\beta_{1})T}\mathbb{E}\sum_{t=0}^{T-1}\|\eta_{t-1}-\eta_{t}\|_{1}
+8GγtLϵ(Lρt2ϵ+σ2bϵ)+2(4+(β11β1)2)γtLG3T𝔼t=0T1ηtηt12+6γ2L2β12(1β1)2dG3ϵ3\displaystyle+\frac{8G\gamma_{t}L}{\epsilon}(\frac{L\rho_{t}^{2}}{\epsilon}+\frac{\sigma^{2}}{b\epsilon})+\frac{2(4+(\frac{\beta_{1}}{1-\beta_{1}})^{2})\gamma_{t}LG^{3}}{T}\mathbb{E}\sum_{t=0}^{T-1}\|\eta_{t}-\eta_{t-1}\|^{2}+\frac{6\gamma^{2}L^{2}\beta_{1}^{2}}{(1-\beta_{1})^{2}}\frac{dG^{3}}{\epsilon^{3}} (45)
2G(f(x0)f)γtT+45GL2ρt2ϵ+2G3(1β1)Td(1ϵ1G)+8GγtLϵ(Lρt2ϵ+σ2bϵ)\displaystyle\leq\frac{2G(f(x_{0})-f^{*})}{\gamma_{t}T}+\frac{45GL^{2}\rho_{t}^{2}}{\epsilon}+\frac{2G^{3}}{(1-\beta_{1})T}d(\frac{1}{\epsilon}-\frac{1}{G})+\frac{8G\gamma_{t}L}{\epsilon}(\frac{L\rho_{t}^{2}}{\epsilon}+\frac{\sigma^{2}}{b\epsilon})
+2(4+(β11β1)2)γtLG3Td(ϵ2G2)+6γ2L2β12(1β1)2dG3ϵ3\displaystyle+\frac{2(4+(\frac{\beta_{1}}{1-\beta_{1}})^{2})\gamma_{t}LG^{3}}{T}d(\epsilon^{-2}-G^{-2})+\frac{6\gamma^{2}L^{2}\beta_{1}^{2}}{(1-\beta_{1})^{2}}\frac{dG^{3}}{\epsilon^{3}} (46)
=2G(f(x0)f)γtT+8GγtLϵσ2bϵ+45GL2ρt2ϵ+2G3(1β1)Td(1ϵ1G)+8GγtLϵLρt2ϵ\displaystyle=\frac{2G(f(x_{0})-f^{*})}{\gamma_{t}T}+\frac{8G\gamma_{t}L}{\epsilon}\frac{\sigma^{2}}{b\epsilon}+\frac{45GL^{2}\rho_{t}^{2}}{\epsilon}+\frac{2G^{3}}{(1-\beta_{1})T}d(\frac{1}{\epsilon}-\frac{1}{G})+\frac{8G\gamma_{t}L}{\epsilon}\frac{L\rho_{t}^{2}}{\epsilon}
+2(4+(β11β1)2)γtLG3Td(ϵ2G2)+6γ2L2β12(1β1)2dG3ϵ3.\displaystyle+\frac{2(4+(\frac{\beta_{1}}{1-\beta_{1}})^{2})\gamma_{t}LG^{3}}{T}d(\epsilon^{-2}-G^{-2})+\frac{6\gamma^{2}L^{2}\beta_{1}^{2}}{(1-\beta_{1})^{2}}\frac{dG^{3}}{\epsilon^{3}}. (47)
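As a rough illustration of how (47) yields the \mathcal{O}(1/\sqrt{bT}) rate with linear speedup in the mini-batch size b, suppose, purely for this sketch, that \gamma_{t}\equiv\gamma=\mathcal{O}(\sqrt{b/T}) and \rho_{t}=\mathcal{O}(1/\sqrt{T}). Then the first two terms on the right-hand side of (47) satisfy

\displaystyle\frac{2G(f(x_{0})-f^{*})}{\gamma T}=\mathcal{O}\Big(\frac{1}{\sqrt{bT}}\Big),\qquad\frac{8G\gamma L}{\epsilon}\frac{\sigma^{2}}{b\epsilon}=\mathcal{O}\Big(\frac{1}{\sqrt{bT}}\Big),

and the remaining terms are of lower order once T is sufficiently large relative to b and d. This scaling is only meant to connect (47) with the rate stated in the abstract; the precise parameter choices follow the main text.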

B-A Technical Lemmas

Lemma 1.

Given two vectors a, b\in\mathbb{R}^{d}, we have \langle a,b\rangle\leq\frac{\lambda^{2}}{2}\|a\|^{2}+\frac{1}{2\lambda^{2}}\|b\|^{2} for any parameter \lambda>0.

Proof.
\displaystyle RHS=\frac{\lambda^{2}}{2}\sum_{j=1}^{d}(a)_{j}^{2}+\frac{1}{2\lambda^{2}}\sum_{j=1}^{d}(b)_{j}^{2}\geq\sum_{j=1}^{d}2\sqrt{\frac{\lambda^{2}}{2}(a)_{j}^{2}\times\frac{1}{2\lambda^{2}}(b)_{j}^{2}}=\sum_{j=1}^{d}|(a)_{j}|\times|(b)_{j}|\geq LHS. (48) ∎
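For instance, taking \lambda^{2}=\frac{1}{4}, the value used for \lambda_{1}^{2} in the proof of the main result, Lemma 1 specializes to

\displaystyle\langle a,b\rangle\leq\frac{1}{8}\|a\|^{2}+2\|b\|^{2},

and note that \lambda_{1}=\frac{1}{2}<1, which is why the lemma is stated for every \lambda>0.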

Lemma 2.

For any vectors x,y\in\mathbb{R}^{d}, we have

\displaystyle\|x\odot y\|^{2}\leq\|x\|^{2}\times\|y\|_{\infty}^{2}\leq\|x\|^{2}\times\|y\|^{2}. (49)
Proof.

The first inequality follows from \sum_{i=1}^{d}(x_{i}^{2}y_{i}^{2})\leq\sum_{i=1}^{d}(x_{i}^{2}\|y\|_{\infty}^{2}), and the second inequality follows from \|y\|_{\infty}^{2}\leq\|y\|^{2}. ∎

Lemma 3.

\eta_{t} is bounded, i.e., \frac{1}{G_{\infty}}\leq(\eta_{t})_{j}\leq\frac{1}{\epsilon} for every coordinate j.

Proof.

Since the stochastic gradient is bounded by G_{\infty} and (\eta_{t})_{j}=\frac{1}{\sqrt{(\hat{v}_{t})_{j}}}, following the update rule of \hat{v}_{t} we have \frac{1}{G_{\infty}}\leq(\eta_{t})_{j}\leq\frac{1}{\epsilon}. ∎
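To illustrate the bound numerically, the sketch below assumes an AMSGrad-style second-moment rule of the form v_{t}=\beta_{2}v_{t-1}+(1-\beta_{2})g_{t}^{2}, \hat{v}_{t}=\max(\hat{v}_{t-1},v_{t}) with \hat{v}_{-1}=\epsilon^{2}, and (\eta_{t})_{j}=1/\sqrt{(\hat{v}_{t})_{j}}. This concrete form is our assumption for the sketch and may differ in minor details from the update used by AdaSAM.

import numpy as np

rng = np.random.default_rng(1)
d, beta2, eps, G_inf = 4, 0.999, 1e-3, 5.0
v = np.zeros(d)
v_hat = np.full(d, eps ** 2)                   # assumed initialization, keeps v_hat >= eps^2
for t in range(1000):
    g = rng.uniform(-G_inf, G_inf, size=d)     # bounded stochastic gradient, |g_j| <= G_inf
    v = beta2 * v + (1.0 - beta2) * g ** 2     # second-moment estimate, stays <= G_inf^2
    v_hat = np.maximum(v_hat, v)               # AMSGrad-style maximum
    eta = 1.0 / np.sqrt(v_hat)
    assert np.all(eta <= 1.0 / eps + 1e-9) and np.all(eta >= 1.0 / G_inf - 1e-9)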

Lemma 4.

For the term defined in the algorithm, we have

\displaystyle\frac{1}{T}\mathbb{E}\sum_{t=0}^{T-1}\|\eta_{t-1}-\eta_{t}\|_{1}\leq\frac{d}{T}(\frac{1}{\epsilon}-\frac{1}{G}) (50)
Proof.

(\eta_{t})_{i}, the i-th coordinate of \eta_{t}, decreases as t increases. So we have

\displaystyle\frac{1}{T}\mathbb{E}\sum_{t=0}^{T-1}\|\eta_{t-1}-\eta_{t}\|_{1}=\mathbb{E}\frac{1}{T}\sum_{i=1}^{d}\sum_{t=0}^{T-1}|(\eta_{t-1})_{i}-(\eta_{t})_{i}|
\displaystyle\leq\mathbb{E}\frac{1}{T}\sum_{i=1}^{d}((\eta_{-1})_{i}-(\eta_{T-1})_{i})\leq\mathbb{E}\frac{1}{T}\sum_{i=1}^{d}(\frac{1}{\epsilon}-\frac{1}{G})=\frac{d}{T}(\frac{1}{\epsilon}-\frac{1}{G}) (51) ∎
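The telescoping step can also be checked numerically on any coordinate-wise non-increasing sequence with values in [1/G, 1/\epsilon]; the following sketch is only an illustration of Lemma 4.

import numpy as np

rng = np.random.default_rng(2)
T, d, eps, G = 50, 3, 1e-2, 10.0
# Build eta_{-1}, ..., eta_{T-1}: non-increasing in each coordinate, clipped to [1/G, 1/eps].
eta = 1.0 / eps - np.cumsum(rng.uniform(0.0, 5.0, size=(T + 1, d)), axis=0)
eta = np.clip(eta, 1.0 / G, 1.0 / eps)
lhs = np.abs(eta[:-1] - eta[1:]).sum() / T     # (1/T) * sum_t ||eta_{t-1} - eta_t||_1
rhs = d / T * (1.0 / eps - 1.0 / G)            # right-hand side of Lemma 4
assert lhs <= rhs + 1e-12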

Lemma 5.

For the term defined in the algorithm, we have

f(zt),γtbiBfi(xt+ρtstst)(ηt1ηt)γtG2ηt1ηt1\displaystyle\langle\nabla f(z_{t}),\frac{\gamma_{t}}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{s_{t}}{\|s_{t}\|})\odot(\eta_{t-1}-\eta_{t})\rangle\leq\gamma_{t}G^{2}\|\eta_{t-1}-\eta_{t}\|_{1} (52)
Proof.
f(zt),γtbiBfi(xt+ρtstst)(ηt1ηt)\displaystyle\langle\nabla f(z_{t}),\frac{\gamma_{t}}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{s_{t}}{\|s_{t}\|})\odot(\eta_{t-1}-\eta_{t})\rangle
γtj=1d|(f(zt))(j)|×|(1biBfi(xt+ρtfi(xt)fi(xt))(ηt1ηt))(j)|\displaystyle\leq\gamma_{t}\sum_{j=1}^{d}|(\nabla f(z_{t}))_{(j)}|\times|(\frac{1}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{\sum\nabla f_{i}(x_{t})}{\|\sum\nabla f_{i}(x_{t})\|})\odot(\eta_{t-1}-\eta_{t}))_{(j)}| (53)
γtGj=1d|((1biBfi(xt+ρtfi(xt)fi(xt))(ηt1ηt))(j)|\displaystyle\leq\gamma_{t}G\sum_{j=1}^{d}|((\frac{1}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{\sum\nabla f_{i}(x_{t})}{\|\sum\nabla f_{i}(x_{t})\|})\odot(\eta_{t-1}-\eta_{t}))_{(j)}| (54)
γtGbj=1diB|((fi(xt+ρtfi(xt)fi(xt))(ηt1ηt))(j)|\displaystyle\leq\frac{\gamma_{t}G}{b}\sum_{j=1}^{d}\sum_{i\in B}|((\nabla f_{i}(x_{t}+\rho_{t}\frac{\sum\nabla f_{i}(x_{t})}{\|\sum\nabla f_{i}(x_{t})\|})\odot(\eta_{t-1}-\eta_{t}))_{(j)}| (55)
=γtGbj=1diB|(fi(xt+ρtfi(xt)fi(xt))(j)×(ηt1ηt)(j)|\displaystyle=\frac{\gamma_{t}G}{b}\sum_{j=1}^{d}\sum_{i\in B}|(\nabla f_{i}(x_{t}+\rho_{t}\frac{\sum\nabla f_{i}(x_{t})}{\|\sum\nabla f_{i}(x_{t})\|})_{(j)}\times(\eta_{t-1}-\eta_{t})_{(j)}| (56)
γtG2bj=1diB|(ηt1ηt)(j)|\displaystyle\leq\frac{\gamma_{t}G^{2}}{b}\sum_{j=1}^{d}\sum_{i\in B}|(\eta_{t-1}-\eta_{t})_{(j)}| (57)
=γtG2ηt1ηt1\displaystyle=\gamma_{t}G^{2}\|\eta_{t-1}-\eta_{t}\|_{1} (58)

Lemma 6.

For the term defined in the algorithm, we have

f(zt),γβ11β1mt1(ηt1ηt)γβ11β1G2ηt1ηt1\displaystyle\langle\nabla f(z_{t}),\frac{\gamma\beta_{1}}{1-\beta_{1}}m_{t-1}\odot(\eta_{t-1}-\eta_{t})\rangle\leq\frac{\gamma\beta_{1}}{1-\beta_{1}}G^{2}\|\eta_{t-1}-\eta_{t}\|_{1} (59)
Proof.
f(zt),γβ11β1mt1(ηt1ηt)\displaystyle\langle\nabla f(z_{t}),\frac{\gamma\beta_{1}}{1-\beta_{1}}m_{t-1}\odot(\eta_{t-1}-\eta_{t})\rangle
γβ11β1j=1d|(f(zt))(j)|×|(mt1(ηt1ηt))(j)|\displaystyle\leq\frac{\gamma\beta_{1}}{1-\beta_{1}}\sum_{j=1}^{d}|(\nabla f(z_{t}))_{(j)}|\times|(m_{t-1}\odot(\eta_{t-1}-\eta_{t}))_{(j)}| (60)
γβ11β1Gj=1d|(mt1(ηt1ηt))(j)|\displaystyle\leq\frac{\gamma\beta_{1}}{1-\beta_{1}}G\sum_{j=1}^{d}|(m_{t-1}\odot(\eta_{t-1}-\eta_{t}))_{(j)}| (61)
=γβ11β1j=1d|(mt1)(j)×(ηt1ηt)(j)|\displaystyle=\frac{\gamma\beta_{1}}{1-\beta_{1}}\sum_{j=1}^{d}|(m_{t-1})_{(j)}\times(\eta_{t-1}-\eta_{t})_{(j)}| (62)
γβ11β1G2j=1d|(ηt1ηt)(j)|\displaystyle\leq\frac{\gamma\beta_{1}}{1-\beta_{1}}G^{2}\sum_{j=1}^{d}|(\eta_{t-1}-\eta_{t})_{(j)}| (63)
=γβ11β1G2ηt1ηt1\displaystyle=\frac{\gamma\beta_{1}}{1-\beta_{1}}G^{2}\|\eta_{t-1}-\eta_{t}\|_{1} (64)

Lemma 7.

For the term defined in the algorithm, we have

f(xt),γtbiBfi(xt+ρtf(xt)f(xt))ηt1γtbiBfi(xt+ρtstst)ηt1\displaystyle\langle\nabla f(x_{t}),\frac{\gamma_{t}}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{\nabla f(x_{t})}{\|\nabla f(x_{t})\|})\odot\eta_{t-1}-\frac{\gamma_{t}}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{s_{t}}{\|s_{t}\|})\odot\eta_{t-1}\rangle
γt2μ2f(xt)ηt12+2μ2γtL2ρt2ϵ.\displaystyle\leq\frac{\gamma_{t}}{2\mu^{2}}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}+\frac{2\mu^{2}\gamma_{t}L^{2}\rho_{t}^{2}}{\epsilon}. (65)
Proof.
f(xt),γtbiBfi(xt+ρtf(xt)f(xt))ηt1γtbiBfi(xt+ρtstst)ηt1\displaystyle\langle\nabla f(x_{t}),\frac{\gamma_{t}}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{\nabla f(x_{t})}{\|\nabla f(x_{t})\|})\odot\eta_{t-1}-\frac{\gamma_{t}}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{s_{t}}{\|s_{t}\|})\odot\eta_{t-1}\rangle
=f(xt)ηt1,γtbiB(fi(xt+ρtf(xt)f(xt))fi(xt+ρtiBfi(xt)iBfi(xt)))ηt1\displaystyle=\langle\nabla f(x_{t})\odot\sqrt{\eta_{t-1}},\frac{\gamma_{t}}{b}\sum_{i\in B}(\nabla f_{i}(x_{t}+\rho_{t}\frac{\nabla f(x_{t})}{\|\nabla f(x_{t})\|})-\nabla f_{i}(x_{t}+\rho_{t}\frac{\sum_{i\in B}\nabla f_{i}(x_{t})}{\|\sum_{i\in B}\nabla f_{i}(x_{t})\|}))\odot\sqrt{\eta_{t-1}}\rangle (66)
μ2γt2b2(fi(xt+ρtf(xt)f(xt))fi(xt+ρtiBfi(xt)iBfi(xt)))ηt12\displaystyle\leq\frac{\mu^{2}\gamma_{t}}{2b^{2}}\|\sum(\nabla f_{i}(x_{t}+\rho_{t}\frac{\nabla f(x_{t})}{\|\nabla f(x_{t})\|})-\nabla f_{i}(x_{t}+\rho_{t}\frac{\sum_{i\in B}\nabla f_{i}(x_{t})}{\|\sum_{i\in B}\nabla f_{i}(x_{t})\|}))\odot\sqrt{\eta_{t-1}}\|^{2}
+γt2μ2f(xt)ηt12\displaystyle+\frac{\gamma_{t}}{2\mu^{2}}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2} (67)
+μ2γt2bfi(xt+ρtf(xt)f(xt))fi(xt+ρtiBfi(xt)iBfi(xt))ηt12\displaystyle\leq+\frac{\mu^{2}\gamma_{t}}{2b}\sum\|\nabla f_{i}(x_{t}+\rho_{t}\frac{\nabla f(x_{t})}{\|\nabla f(x_{t})\|})-\nabla f_{i}(x_{t}+\rho_{t}\frac{\sum_{i\in B}\nabla f_{i}(x_{t})}{\|\sum_{i\in B}\nabla f_{i}(x_{t})\|})\odot\sqrt{\eta_{t-1}}\|^{2}
+γt2μ2f(xt)ηt12\displaystyle+\frac{\gamma_{t}}{2\mu^{2}}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2} (68)
+μ2γt2bfi(xt+ρtf(xt)f(xt))fi(xt+ρtiBfi(xt)iBfi(xt))2×ηt12\displaystyle\leq+\frac{\mu^{2}\gamma_{t}}{2b}\sum\|\nabla f_{i}(x_{t}+\rho_{t}\frac{\nabla f(x_{t})}{\|\nabla f(x_{t})\|})-\nabla f_{i}(x_{t}+\rho_{t}\frac{\sum_{i\in B}\nabla f_{i}(x_{t})}{\|\sum_{i\in B}\nabla f_{i}(x_{t})\|})\|^{2}\times\|\sqrt{\eta_{t-1}}\|^{2}_{\infty}
+γt2μ2f(xt)ηt12\displaystyle+\frac{\gamma_{t}}{2\mu^{2}}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2} (69)
γt2μ2f(xt)ηt12+μ2γtL2ρt22bϵf(xt)f(xt)iBfi(xt)iBfi(xt)2\displaystyle\leq\frac{\gamma_{t}}{2\mu^{2}}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}+\frac{\mu^{2}\gamma_{t}L^{2}\rho_{t}^{2}}{2b\epsilon}\sum\|\frac{\nabla f(x_{t})}{\|\nabla f(x_{t})\|}-\frac{\sum_{i\in B}\nabla f_{i}(x_{t})}{\|\sum_{i\in B}\nabla f_{i}(x_{t})\|}\|^{2} (70)
γt2μ2f(xt)ηt12+2μ2γtL2ρt2ϵ.\displaystyle\leq\frac{\gamma_{t}}{2\mu^{2}}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}+\frac{2\mu^{2}\gamma_{t}L^{2}\rho_{t}^{2}}{\epsilon}. (71)
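The last step above also uses the elementary fact that the difference of two unit vectors has norm at most 2:

\displaystyle\Big\|\frac{u}{\|u\|}-\frac{v}{\|v\|}\Big\|\leq\Big\|\frac{u}{\|u\|}\Big\|+\Big\|\frac{v}{\|v\|}\Big\|=2\quad\text{for all nonzero }u,v\in\mathbb{R}^{d},

so each summand in (70) is bounded by 4, and the factor 1/b cancels against the b terms of the sum.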

Lemma 8.

For the term defined in the algorithm, we have

𝔼f(xt),γtbiBfi(xt+ρtf(xt)f(xt))ηt1\displaystyle\;\;\;\;\mathbb{E}\langle\nabla f(x_{t}),-\frac{\gamma_{t}}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{\nabla f(x_{t})}{\|\nabla f(x_{t})\|})\odot\eta_{t-1}\rangle
γtf(xt)ηt12+𝔼γt2α2f(xt)ηt12+γtα2L2ρt22ϵ\displaystyle\leq-\gamma_{t}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}+\mathbb{E}\frac{\gamma_{t}}{2\alpha^{2}}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}+\frac{\gamma_{t}\alpha^{2}L^{2}\rho_{t}^{2}}{2\epsilon} (72)
Proof.
𝔼f(xt),γtbiBfi(xt+ρtf(xt)f(xt))ηt1\displaystyle\;\;\;\;\mathbb{E}\langle\nabla f(x_{t}),-\frac{\gamma_{t}}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{\nabla f(x_{t})}{\|\nabla f(x_{t})\|})\odot\eta_{t-1}\rangle
=γtf(xt)ηt12+𝔼f(xt),γtbiB(f(xt)fi(xt+ρtf(xt)f(xt)))ηt1\displaystyle=-\gamma_{t}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}+\mathbb{E}\langle\nabla f(x_{t}),\frac{\gamma_{t}}{b}\sum_{i\in B}(\nabla f(x_{t})-\nabla f_{i}(x_{t}+\rho_{t}\frac{\nabla f(x_{t})}{\|\nabla f(x_{t})\|}))\odot\eta_{t-1}\rangle (73)
=γtf(xt)ηt12+𝔼f(xt),γtbiB(fi(xt)fi(xt+ρtf(xt)f(xt)))ηt1\displaystyle=-\gamma_{t}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}+\mathbb{E}\langle\nabla f(x_{t}),\frac{\gamma_{t}}{b}\sum_{i\in B}(\nabla f_{i}(x_{t})-\nabla f_{i}(x_{t}+\rho_{t}\frac{\nabla f(x_{t})}{\|\nabla f(x_{t})\|}))\odot\eta_{t-1}\rangle (74)
γtf(xt)ηt12+𝔼γt2α2f(xt)ηt12\displaystyle\leq-\gamma_{t}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}+\mathbb{E}\frac{\gamma_{t}}{2\alpha^{2}}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}
+γtα22𝔼1biB(fi(xt)fi(xt+ρtf(xt)f(xt)))ηt12\displaystyle+\frac{\gamma_{t}\alpha^{2}}{2}\mathbb{E}\|\frac{1}{b}\sum_{i\in B}(\nabla f_{i}(x_{t})-\nabla f_{i}(x_{t}+\rho_{t}\frac{\nabla f(x_{t})}{\|\nabla f(x_{t})\|}))\odot\sqrt{\eta_{t-1}}\|^{2} (75)
γtf(xt)ηt12+𝔼γt2α2f(xt)ηt12\displaystyle\leq-\gamma_{t}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}+\mathbb{E}\frac{\gamma_{t}}{2\alpha^{2}}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}
+γtα22ϵ𝔼1biB(fi(xt)fi(xt+ρtf(xt)f(xt)))2\displaystyle+\frac{\gamma_{t}\alpha^{2}}{2\epsilon}\mathbb{E}\|\frac{1}{b}\sum_{i\in B}(\nabla f_{i}(x_{t})-\nabla f_{i}(x_{t}+\rho_{t}\frac{\nabla f(x_{t})}{\|\nabla f(x_{t})\|}))\|^{2} (76)
γtf(xt)ηt12+𝔼γt2α2f(xt)ηt12\displaystyle\leq-\gamma_{t}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}+\mathbb{E}\frac{\gamma_{t}}{2\alpha^{2}}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}
+γtα22bϵ𝔼iB(fi(xt)fi(xt+ρtf(xt)f(xt)))2\displaystyle+\frac{\gamma_{t}\alpha^{2}}{2b\epsilon}\mathbb{E}\sum_{i\in B}\|(\nabla f_{i}(x_{t})-\nabla f_{i}(x_{t}+\rho_{t}\frac{\nabla f(x_{t})}{\|\nabla f(x_{t})\|}))\|^{2} (77)
γtf(xt)ηt12+𝔼γt2α2f(xt)ηt12+γtα2L2ρt22bϵ𝔼iBf(xt)f(xt)2\displaystyle\leq-\gamma_{t}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}+\mathbb{E}\frac{\gamma_{t}}{2\alpha^{2}}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}+\frac{\gamma_{t}\alpha^{2}L^{2}\rho_{t}^{2}}{2b\epsilon}\mathbb{E}\sum_{i\in B}\|\frac{\nabla f(x_{t})}{\|\nabla f(x_{t})\|}\|^{2} (78)
=γtf(xt)ηt12+𝔼γt2α2f(xt)ηt12+γtα2L2ρt22ϵ\displaystyle=-\gamma_{t}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}+\mathbb{E}\frac{\gamma_{t}}{2\alpha^{2}}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}+\frac{\gamma_{t}\alpha^{2}L^{2}\rho_{t}^{2}}{2\epsilon} (79)

Lemma 9.

For the term defined in the algorithm, we have

𝔼f(zt)f(xt),γtbiBfi(xt+ρtstst)ηt1\displaystyle\mathbb{E}\langle\nabla f(z_{t})-\nabla f(x_{t}),-\frac{\gamma_{t}}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{s_{t}}{\|s_{t}\|})\odot\eta_{t-1}\rangle
γ3L2β122ϵ(1β1)2(1λ12+1λ22+1λ32)dG2ϵ2+γλ122f(xt)ηt12+γL2ρt22ϵ(λ22+4λ32).\displaystyle\leq\frac{\gamma^{3}L^{2}\beta_{1}^{2}}{2\epsilon(1-\beta_{1})^{2}}(\frac{1}{\lambda_{1}^{2}}+\frac{1}{\lambda_{2}^{2}}+\frac{1}{\lambda_{3}^{2}})\frac{dG_{\infty}^{2}}{\epsilon^{2}}+\frac{\gamma\lambda_{1}^{2}}{2}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}+\frac{\gamma L^{2}\rho_{t}^{2}}{2\epsilon}(\lambda_{2}^{2}+4\lambda_{3}^{2}). (80)
Proof.
𝔼f(zt)f(xt),γtbiBfi(xt+ρtstst)ηt1\displaystyle\mathbb{E}\langle\nabla f(z_{t})-\nabla f(x_{t}),-\frac{\gamma_{t}}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{s_{t}}{\|s_{t}\|})\odot\eta_{t-1}\rangle (81)
=γ𝔼(f(xt)f(zt))ηt1,1biBfi(xt+ρtiBfi(xt)iBfi(xt))ηt1\displaystyle=\gamma\mathbb{E}\langle(\nabla f(x_{t})-\nabla f(z_{t}))\odot\sqrt{\eta_{t-1}},\frac{1}{b}\sum_{i\in B}\nabla f_{i}(x_{t}+\rho_{t}\frac{\sum_{i\in B}\nabla f_{i}(x_{t})}{\|\sum_{i\in B}\nabla f_{i}(x_{t})\|})\odot\sqrt{\eta_{t-1}}\rangle (82)
=γ𝔼(f(xt)f(zt))ηt1,f(xt)ηt1\displaystyle=\gamma\mathbb{E}\langle(\nabla f(x_{t})-\nabla f(z_{t}))\odot\sqrt{\eta_{t-1}},\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\rangle
+γ𝔼(f(xt)f(zt))ηt1,1biB(fi(xt+ρtf(xt)f(xt))fi(xt))ηt1\displaystyle+\gamma\mathbb{E}\langle(\nabla f(x_{t})-\nabla f(z_{t}))\odot\sqrt{\eta_{t-1}},\frac{1}{b}\sum_{i\in B}(\nabla f_{i}(x_{t}+\rho_{t}\frac{\nabla f(x_{t})}{\|\nabla f(x_{t})\|})-\nabla f_{i}(x_{t}))\odot\sqrt{\eta_{t-1}}\rangle
+γ𝔼(f(xt)f(zt))ηt1,1biB(fi(xt+ρtiBfi(xt)iBfi(xt))fi(xt+ρtf(xt)f(xt))ηt1\displaystyle+\gamma\mathbb{E}\langle(\nabla f(x_{t})-\nabla f(z_{t}))\odot\sqrt{\eta_{t-1}},\frac{1}{b}\sum_{i\in B}(\nabla f_{i}(x_{t}+\rho_{t}\frac{\sum_{i\in B}\nabla f_{i}(x_{t})}{\|\sum_{i\in B}\nabla f_{i}(x_{t})\|})-\nabla f_{i}(x_{t}+\rho_{t}\frac{\nabla f(x_{t})}{\|\nabla f(x_{t})\|})\odot\sqrt{\eta_{t-1}}\rangle (83)
γ2(1λ12+1λ22+1λ32)𝔼(f(xt)f(zt))ηt12+γλ122f(xt)ηt12\displaystyle\leq\frac{\gamma}{2}(\frac{1}{\lambda_{1}^{2}}+\frac{1}{\lambda_{2}^{2}}+\frac{1}{\lambda_{3}^{2}})\mathbb{E}\|(\nabla f(x_{t})-\nabla f(z_{t}))\odot\sqrt{\eta_{t-1}}\|^{2}+\frac{\gamma\lambda_{1}^{2}}{2}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}
+γλ222𝔼1biB(fi(xt+ρtf(xt)f(xt))fi(xt))ηt12\displaystyle+\frac{\gamma\lambda_{2}^{2}}{2}\mathbb{E}\|\frac{1}{b}\sum_{i\in B}(\nabla f_{i}(x_{t}+\rho_{t}\frac{\nabla f(x_{t})}{\|\nabla f(x_{t})\|})-\nabla f_{i}(x_{t}))\odot\sqrt{\eta_{t-1}}\|^{2}
+γλ322𝔼1biB(fi(xt+ρtiBfi(xt)iBfi(xt))fi(xt+ρtf(xt)f(xt))ηt12\displaystyle+\frac{\gamma\lambda_{3}^{2}}{2}\mathbb{E}\|\frac{1}{b}\sum_{i\in B}(\nabla f_{i}(x_{t}+\rho_{t}\frac{\sum_{i\in B}\nabla f_{i}(x_{t})}{\|\sum_{i\in B}\nabla f_{i}(x_{t})\|})-\nabla f_{i}(x_{t}+\rho_{t}\frac{\nabla f(x_{t})}{\|\nabla f(x_{t})\|})\odot\sqrt{\eta_{t-1}}\|^{2} (84)
γ2(1λ12+1λ22+1λ32)𝔼(f(xt)f(zt))ηt12+γλ122f(xt)ηt12\displaystyle\leq\frac{\gamma}{2}(\frac{1}{\lambda_{1}^{2}}+\frac{1}{\lambda_{2}^{2}}+\frac{1}{\lambda_{3}^{2}})\mathbb{E}\|(\nabla f(x_{t})-\nabla f(z_{t}))\odot\sqrt{\eta_{t-1}}\|^{2}+\frac{\gamma\lambda_{1}^{2}}{2}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}
+γλ22L2ρt22ϵ+2λ32γL2ρt2ϵ\displaystyle+\frac{\gamma\lambda_{2}^{2}L^{2}\rho_{t}^{2}}{2\epsilon}+\frac{2\lambda_{3}^{2}\gamma L^{2}\rho_{t}^{2}}{\epsilon} (85)
γL22ϵ(1λ12+1λ22+1λ32)𝔼ztxt2+γλ122f(xt)ηt12\displaystyle\leq\frac{\gamma L^{2}}{2\epsilon}(\frac{1}{\lambda_{1}^{2}}+\frac{1}{\lambda_{2}^{2}}+\frac{1}{\lambda_{3}^{2}})\mathbb{E}\|z_{t}-x_{t}\|^{2}+\frac{\gamma\lambda_{1}^{2}}{2}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}
+γλ22L2ρt22ϵ+2λ32γL2ρt2ϵ\displaystyle+\frac{\gamma\lambda_{2}^{2}L^{2}\rho_{t}^{2}}{2\epsilon}+\frac{2\lambda_{3}^{2}\gamma L^{2}\rho_{t}^{2}}{\epsilon} (86)
\displaystyle=\frac{\gamma^{3}L^{2}\beta_{1}^{2}}{2\epsilon(1-\beta_{1})^{2}}(\frac{1}{\lambda_{1}^{2}}+\frac{1}{\lambda_{2}^{2}}+\frac{1}{\lambda_{3}^{2}})\|m_{t-1}\odot\eta_{t-1}\|^{2}+\frac{\gamma\lambda_{1}^{2}}{2}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}
+γλ22L2ρt22ϵ+2λ32γL2ρt2ϵ\displaystyle+\frac{\gamma\lambda_{2}^{2}L^{2}\rho_{t}^{2}}{2\epsilon}+\frac{2\lambda_{3}^{2}\gamma L^{2}\rho_{t}^{2}}{\epsilon} (87)
γ3L2β122ϵ(1β1)2(1λ12+1λ22+1λ32)dG2ϵ2+γλ122f(xt)ηt12+γL2ρt22ϵ(λ22+4λ32).\displaystyle\leq\frac{\gamma^{3}L^{2}\beta_{1}^{2}}{2\epsilon(1-\beta_{1})^{2}}(\frac{1}{\lambda_{1}^{2}}+\frac{1}{\lambda_{2}^{2}}+\frac{1}{\lambda_{3}^{2}})\frac{dG_{\infty}^{2}}{\epsilon^{2}}+\frac{\gamma\lambda_{1}^{2}}{2}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}+\frac{\gamma L^{2}\rho_{t}^{2}}{2\epsilon}(\lambda_{2}^{2}+4\lambda_{3}^{2}). (88)

Lemma 10.

For the term defined in the algorithm, we have

L2𝔼zt+1zt2LG2γ2β12(1β1)2𝔼ηtηt12\displaystyle\frac{L}{2}\mathbb{E}\|z_{t+1}-z_{t}\|^{2}\leq\frac{LG^{2}\gamma^{2}\beta_{1}^{2}}{(1-\beta_{1})^{2}}\mathbb{E}\|\eta_{t}-\eta_{t-1}\|^{2}
+γt2L(31+ββϵ(𝔼f(xt)ηt12+Lρt2ϵ+σ2bϵ)+(1+β)G2𝔼ηtηt12)\displaystyle+\gamma_{t}^{2}L(3\frac{1+\beta}{\beta\epsilon}(\mathbb{E}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}+\frac{L\rho_{t}^{2}}{\epsilon}+\frac{\sigma^{2}}{b\epsilon})+(1+\beta)G^{2}\mathbb{E}\|\eta_{t}-\eta_{t-1}\|^{2}) (89)
Proof.
L2𝔼zt+1zt2\displaystyle\frac{L}{2}\mathbb{E}\|z_{t+1}-z_{t}\|^{2}
=L2𝔼γβ11β1mt1(ηtηt1)γgtηt2\displaystyle=\frac{L}{2}\mathbb{E}\|\frac{\gamma\beta_{1}}{1-\beta_{1}}m_{t-1}\odot(\eta_{t}-\eta_{t-1})-\gamma g_{t}\odot\eta_{t}\|^{2} (90)
Lγ2β12(1β1)2𝔼mt1(ηtηt1)2+L𝔼γtb(fi(xt+ρtstst))ηt2\displaystyle\leq\frac{L\gamma^{2}\beta_{1}^{2}}{(1-\beta_{1})^{2}}\mathbb{E}\|m_{t-1}\odot(\eta_{t}-\eta_{t-1})\|^{2}+L\mathbb{E}\|\frac{\gamma_{t}}{b}\sum(\nabla f_{i}(x_{t}+\rho_{t}\frac{s_{t}}{\|s_{t}\|}))\odot\eta_{t}\|^{2} (91)
LG2γ2β12(1β1)2𝔼ηtηt12+L𝔼γtb(fi(xt+ρtstst))ηt2\displaystyle\leq\frac{LG^{2}\gamma^{2}\beta_{1}^{2}}{(1-\beta_{1})^{2}}\mathbb{E}\|\eta_{t}-\eta_{t-1}\|^{2}+L\mathbb{E}\|\frac{\gamma_{t}}{b}\sum(\nabla f_{i}(x_{t}+\rho_{t}\frac{s_{t}}{\|s_{t}\|}))\odot\eta_{t}\|^{2} (92)
=γt2L𝔼1b(fi(xt+ρtstst))ηt1+1b(fi(xt+ρtstst))(ηtηt1)2\displaystyle=\gamma_{t}^{2}L\mathbb{E}\|\frac{1}{b}\sum(\nabla f_{i}(x_{t}+\rho_{t}\frac{s_{t}}{\|s_{t}\|}))\odot\eta_{t-1}+\frac{1}{b}\sum(\nabla f_{i}(x_{t}+\rho_{t}\frac{s_{t}}{\|s_{t}\|}))\odot(\eta_{t}-\eta_{t-1})\|^{2}
+LG2γ2β12(1β1)2𝔼ηtηt12\displaystyle+\frac{LG^{2}\gamma^{2}\beta_{1}^{2}}{(1-\beta_{1})^{2}}\mathbb{E}\|\eta_{t}-\eta_{t-1}\|^{2} (93)
LG2γ2β12(1β1)2𝔼ηtηt12+γt2L((1+1β)𝔼1b(fi(xt+ρtstst))ηt12\displaystyle\leq\frac{LG^{2}\gamma^{2}\beta_{1}^{2}}{(1-\beta_{1})^{2}}\mathbb{E}\|\eta_{t}-\eta_{t-1}\|^{2}+\gamma_{t}^{2}L((1+\frac{1}{\beta})\mathbb{E}\|\frac{1}{b}\sum(\nabla f_{i}(x_{t}+\rho_{t}\frac{s_{t}}{\|s_{t}\|}))\odot\eta_{t-1}\|^{2}
+(1+β)𝔼1b(fi(xt+ρtstst))(ηtηt1)2)\displaystyle\;\;\;\;\;\;\;+(1+\beta)\mathbb{E}\|\frac{1}{b}\sum(\nabla f_{i}(x_{t}+\rho_{t}\frac{s_{t}}{\|s_{t}\|}))\odot(\eta_{t}-\eta_{t-1})\|^{2}) (94)
γt2L((1+1β)𝔼1b(fi(xt+ρtstst))ηt12+(1+β)G2𝔼ηtηt12)\displaystyle\leq\gamma_{t}^{2}L((1+\frac{1}{\beta})\mathbb{E}\|\frac{1}{b}\sum(\nabla f_{i}(x_{t}+\rho_{t}\frac{s_{t}}{\|s_{t}\|}))\odot\eta_{t-1}\|^{2}+(1+\beta)G^{2}\mathbb{E}\|\eta_{t}-\eta_{t-1}\|^{2})
+LG2γ2β12(1β1)2𝔼ηtηt12\displaystyle+\frac{LG^{2}\gamma^{2}\beta_{1}^{2}}{(1-\beta_{1})^{2}}\mathbb{E}\|\eta_{t}-\eta_{t-1}\|^{2} (95)
γt2L((1+1β)𝔼1b(fi(xt+ρtstst))ηt12×ηt12\displaystyle\leq\gamma_{t}^{2}L((1+\frac{1}{\beta})\mathbb{E}\|\frac{1}{b}\sum(\nabla f_{i}(x_{t}+\rho_{t}\frac{s_{t}}{\|s_{t}\|}))\odot\sqrt{\eta_{t-1}}\|^{2}\times\|\sqrt{\eta_{t-1}}\|^{2}_{\infty}
+(1+β)G2𝔼ηtηt12)+LG2γ2β12(1β1)2𝔼ηtηt12\displaystyle+(1+\beta)G^{2}\mathbb{E}\|\eta_{t}-\eta_{t-1}\|^{2})+\frac{LG^{2}\gamma^{2}\beta_{1}^{2}}{(1-\beta_{1})^{2}}\mathbb{E}\|\eta_{t}-\eta_{t-1}\|^{2} (96)
γt2L(1+ββϵ𝔼1b(fi(xt+ρtstst))ηt12+(1+β)G2𝔼ηtηt12)\displaystyle\leq\gamma_{t}^{2}L(\frac{1+\beta}{\beta\epsilon}\mathbb{E}\|\frac{1}{b}\sum(\nabla f_{i}(x_{t}+\rho_{t}\frac{s_{t}}{\|s_{t}\|}))\odot\sqrt{\eta_{t-1}}\|^{2}+(1+\beta)G^{2}\mathbb{E}\|\eta_{t}-\eta_{t-1}\|^{2})
+LG2γ2β12(1β1)2𝔼ηtηt12\displaystyle+\frac{LG^{2}\gamma^{2}\beta_{1}^{2}}{(1-\beta_{1})^{2}}\mathbb{E}\|\eta_{t}-\eta_{t-1}\|^{2} (97)
γt2L(31+ββϵ𝔼(f(xt)ηt12+(1bfi(xt)f(xt))ηt12\displaystyle\leq\gamma_{t}^{2}L(3\frac{1+\beta}{\beta\epsilon}\mathbb{E}(\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}+\|(\frac{1}{b}\sum\nabla f_{i}(x_{t})-\nabla f(x_{t}))\odot\sqrt{\eta_{t-1}}\|^{2}
+1b(fi(xt+ρtiBfi(xt)iBfi(xt))fi(xt))ηt12)+(1+β)G2𝔼ηtηt12)\displaystyle+\|\frac{1}{b}\sum(\nabla f_{i}(x_{t}+\rho_{t}\frac{\sum_{i\in B}\nabla f_{i}(x_{t})}{\|\sum_{i\in B}\nabla f_{i}(x_{t})\|})-\nabla f_{i}(x_{t}))\odot\sqrt{\eta_{t-1}}\|^{2})+(1+\beta)G^{2}\mathbb{E}\|\eta_{t}-\eta_{t-1}\|^{2})
+LG2γ2β12(1β1)2𝔼ηtηt12\displaystyle+\frac{LG^{2}\gamma^{2}\beta_{1}^{2}}{(1-\beta_{1})^{2}}\mathbb{E}\|\eta_{t}-\eta_{t-1}\|^{2} (98)
γt2L(31+ββϵ(𝔼f(xt)ηt12+𝔼1b(fi(xt+ρtiBfi(xt)iBfi(xt))fi(xt))ηt12\displaystyle\leq\gamma_{t}^{2}L(3\frac{1+\beta}{\beta\epsilon}(\mathbb{E}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}+\mathbb{E}\|\frac{1}{b}\sum(\nabla f_{i}(x_{t}+\rho_{t}\frac{\sum_{i\in B}\nabla f_{i}(x_{t})}{\|\sum_{i\in B}\nabla f_{i}(x_{t})\|})-\nabla f_{i}(x_{t}))\odot\sqrt{\eta_{t-1}}\|^{2}
+σ2bϵ)+(1+β)G2𝔼ηtηt12)+LG2γ2β12(1β1)2𝔼ηtηt12\displaystyle+\frac{\sigma^{2}}{b\epsilon})+(1+\beta)G^{2}\mathbb{E}\|\eta_{t}-\eta_{t-1}\|^{2})+\frac{LG^{2}\gamma^{2}\beta_{1}^{2}}{(1-\beta_{1})^{2}}\mathbb{E}\|\eta_{t}-\eta_{t-1}\|^{2} (99)
γt2L(31+ββϵ(𝔼f(xt)ηt12+1ϵ𝔼1b(fi(xt+ρtiBfi(xt)iBfi(xt))fi(xt))2\displaystyle\leq\gamma_{t}^{2}L(3\frac{1+\beta}{\beta\epsilon}(\mathbb{E}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}+\frac{1}{\epsilon}\mathbb{E}\|\frac{1}{b}\sum(\nabla f_{i}(x_{t}+\rho_{t}\frac{\sum_{i\in B}\nabla f_{i}(x_{t})}{\|\sum_{i\in B}\nabla f_{i}(x_{t})\|})-\nabla f_{i}(x_{t}))\|^{2}
+σ2bϵ)+(1+β)G2𝔼ηtηt12)+LG2γ2β12(1β1)2𝔼ηtηt12\displaystyle+\frac{\sigma^{2}}{b\epsilon})+(1+\beta)G^{2}\mathbb{E}\|\eta_{t}-\eta_{t-1}\|^{2})+\frac{LG^{2}\gamma^{2}\beta_{1}^{2}}{(1-\beta_{1})^{2}}\mathbb{E}\|\eta_{t}-\eta_{t-1}\|^{2} (100)
γt2L(31+ββϵ(𝔼f(xt)ηt12+1ϵb𝔼fi(xt+ρtiBfi(xt)iBfi(xt))fi(xt)2\displaystyle\leq\gamma_{t}^{2}L(3\frac{1+\beta}{\beta\epsilon}(\mathbb{E}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}+\frac{1}{\epsilon b}\mathbb{E}\sum\|\nabla f_{i}(x_{t}+\rho_{t}\frac{\sum_{i\in B}\nabla f_{i}(x_{t})}{\|\sum_{i\in B}\nabla f_{i}(x_{t})\|})-\nabla f_{i}(x_{t})\|^{2}
+σ2bϵ)+(1+β)G2𝔼ηtηt12)+LG2γ2β12(1β1)2𝔼ηtηt12\displaystyle+\frac{\sigma^{2}}{b\epsilon})+(1+\beta)G^{2}\mathbb{E}\|\eta_{t}-\eta_{t-1}\|^{2})+\frac{LG^{2}\gamma^{2}\beta_{1}^{2}}{(1-\beta_{1})^{2}}\mathbb{E}\|\eta_{t}-\eta_{t-1}\|^{2} (101)
γt2L(31+ββϵ(𝔼f(xt)ηt12+Lρt2ϵ+σ2bϵ)+(1+β)G2𝔼ηtηt12)\displaystyle\leq\gamma_{t}^{2}L(3\frac{1+\beta}{\beta\epsilon}(\mathbb{E}\|\nabla f(x_{t})\odot\sqrt{\eta_{t-1}}\|^{2}+\frac{L\rho_{t}^{2}}{\epsilon}+\frac{\sigma^{2}}{b\epsilon})+(1+\beta)G^{2}\mathbb{E}\|\eta_{t}-\eta_{t-1}\|^{2})
+LG2γ2β12(1β1)2𝔼ηtηt12.\displaystyle+\frac{LG^{2}\gamma^{2}\beta_{1}^{2}}{(1-\beta_{1})^{2}}\mathbb{E}\|\eta_{t}-\eta_{t-1}\|^{2}. (102)
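The step from (93) to (94) uses the standard inequality

\displaystyle\|a+b\|^{2}\leq\Big(1+\frac{1}{\beta}\Big)\|a\|^{2}+(1+\beta)\|b\|^{2}\quad\text{for any }\beta>0,

which follows from expanding the square and applying Lemma 1 with \lambda^{2}=\frac{1}{\beta}.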

Appendix C Additional Experimental Illustrations

C-A Experimental Illustrations

[Figure 3: sixteen panels, (a)–(p), showing the training loss and evaluation metric curves versus training steps for MRPC, RTE, CoLA, SST-2, STS-B, MNLI, QQP, and QNLI.]

Figure 3: The loss and evaluation metric vs. steps on MRPC, RTE, CoLA, SST-2, STS-B, MNLI, QQP, and QNLI (\beta_{1}=0).

In the ablation study, we conduct experiments on the GLUE benchmark with AdaSAM, AMSGrad, SAM, and SGD, respectively, where none of the optimizers uses the momentum component (\beta_{1}=0). As a supplement to Table II, Figure 3 shows the detailed loss and evaluation-metric curves versus the number of training steps. The loss curve of AdaSAM decreases faster than those of SAM and SGD on all tasks, and it decreases at a speed similar to that of AMSGrad. The metric curves of AdaSAM and AMSGrad show that the adaptive learning rate methods outperform SGD and SAM, and AdaSAM improves as fast as AMSGrad on all tasks.