
Implicit Bias of SignGD and Adam on Multiclass Separable Data

Chen Fan    Mark Schmidt    Christos Thrampoulidis
Abstract

In the optimization of overparameterized models, different gradient-based methods can achieve zero training error yet converge to distinctly different solutions inducing different generalization properties. While a decade of research on implicit optimization bias has illuminated this phenomenon in various settings, even the foundational case of linear classification with separable data still has important open questions. We resolve a fundamental gap by characterizing the implicit bias of both Adam and Sign Gradient Descent in multi-class cross-entropy minimization: we prove that their iterates converge to solutions that maximize the margin with respect to the classifier matrix’s max-norm and characterize the rate of convergence. We extend our results to general p-norm normalized steepest descent algorithms and to other multi-class losses.


1 Introduction

Machine learning models are trained to minimize a surrogate loss function with the goal of finding model-weight configurations that generalize well. The loss function used for training is typically a proxy for the evaluation metric. The prototypical example is one-hot classification where the predominant choice, the cross-entropy (CE) loss, serves as a convex surrogate of the zero-one metric that measures correct class membership. Intuitively, a good surrogate loss is one where weight configurations minimizing the training loss also achieve strong generalization performance. However, modern machine learning models are overparameterized, leading to multiple weight configurations that achieve identical training loss but exhibit markedly different generalization properties (Zhang et al., 2017; Belkin et al., 2019).

This observation has motivated extensive research into the implicit bias (or implicit regularization) of gradient-based optimization. This theory investigates how optimizers select specific solutions from the infinite set of possible minimizers in overparameterized settings. The key insight is that gradient-based methods inherently prefer “simple” solutions according to optimizer-specific notions of simplicity. Understanding this preference requires analyzing not just loss convergence, but the geometric trajectory of parameter updates throughout training.

The prototypical example—which has rightfully earned its place as a “textbook” setting in the field—is the study of implicit optimization bias of gradient descent in linear classification. In this setting, embeddings of the training data are assumed fixed and only the classifier is learned, with overparameterization modeled through linear separability of the training data. (Soudry et al., 2018) demonstrated that while the gradient descent (GD) parameters diverge in norm, they converge in direction to the maximum-margin classifier—the minimum L_{2}-norm solution that separates the data. This non-trivial result, despite (or perhaps because of) the simplicity of its setting, spurred numerous research directions extending to nonlinear models, alternative optimizers, different loss functions, and theoretical explanations of phenomena like benign overfitting (see Sec. 2).

Our paper contributes a fundamental result to this literature by establishing the implicit bias of Adam and its simpler variant, signed GD (SignGD), in the multiclass version of this textbook setting: minimization of CE loss on multiclass linearly separable data. Our work addresses several gaps between theoretical assumptions and practice: the predominant focus on binary classification despite multiclass problems dominating real-world applications; the heavy theoretical emphasis on the exponential loss despite practitioners’ overwhelming preference for the cross-entropy loss; and the limited theoretical understanding of Adam’s implicit bias, which has only recently been studied in the binary case (Zhang et al., 2024) despite its widespread adoption in deep learning practice. Here, we provide a direct analysis of the multiclass setting by exploiting favorable properties of the softmax function, avoiding the traditional approach of reducing multiclass problems to binary ones (Wang & Scott, 2024).

Contributions. Our contributions are as follows: For multiclass separable data trained with CE loss, we show that the iterates of SignGD converge to a solution that maximizes the margin defined with respect to (w.r.t.) the matrix max-norm, at a rate \mathcal{O}(\frac{1}{t^{1/2}}). We generalize this result to normalized steepest descent algorithms w.r.t. any entry-wise matrix p-norm. This directly extends the results of Nacson et al. (2019); Sun et al. (2023) to multiclass classification and to the widely-used CE loss. To achieve this, we construct a proxy for the loss and show that it closely traces the loss value and gradient. We also show that the same machinery applies to other multiclass losses such as the exponential loss (Mukherjee & Schapire, 2010) and the PairLogLoss (Wang et al., 2021b).

Under the same setting, we prove that the iterates of Adam also maximize the margin w.r.t. the matrix max-norm, at a rate \mathcal{O}(\frac{1}{t^{1/3}}). This matches the convergence rate previously established for binary classification (Zhang et al., 2024), and remarkably, is independent of the number of classes k. This class-independence is non-trivial since naive reduction approaches that map a k-class problem in d dimensions to a binary problem in kd dimensions inevitably introduce k-dependent factors. Our key insight is to decompose the CE proxy function class-wise, enabling us to bound Adam’s first and second gradient moments separately for each class. This decomposition reveals how each component of the proxy function governs the dynamics of the weight vectors associated with its corresponding class. Finally, we experimentally verify our theoretical predictions for the algorithms considered. Specific to Adam, we numerically demonstrate that the implicit bias results are consistent with or without the (small) stability constant in the multiclass setting, and that the solutions found by SignGD and Adam favor the max-norm margin over the 2-norm margin.

2 Related Works

Starting with GD, the foundational result by (Soudry et al., 2018) showed that gradient descent optimization of the logistic loss on linearly separable data converges in direction to the L_{2} max-margin classifier at a rate O(1/\log(t)). Contemporaneous work by (Ji & Telgarsky, 2019) generalized this by removing the data separability requirement. (Ji et al., 2020) later connected these findings to earlier work on regularization paths of logistic loss minimization (Rosset et al., 2003), which enabled extensions to other loss functions (e.g., those with polynomial tail decay). More recently, (Wu et al., 2024b) extended these results to the large step size regime with the same O(1/\log(t)) rate. The relatively slow convergence rate to the max-margin classifier motivated investigation into adaptive step-sizes. (Nacson et al., 2019) showed that normalized gradient descent (NGD) with decaying step-size \eta_{t}=1/\sqrt{t} achieves L_{2}-margin convergence at rate O(1/\sqrt{t}). This rate was improved to O(1/t) by (Ji & Telgarsky, 2021) using constant step-sizes, and further to O(1/t^{2}) through a specific momentum formulation (Ji et al., 2021). Besides linear classification, the implicit bias of GD has been studied for least squares (Gunasekar et al., 2017, 2018), homogeneous (Lyu & Li, 2019; Ji & Telgarsky, 2020) and non-homogeneous neural networks (Wu et al., 2024a), and matrix factorization (Gunasekar et al., 2017); see (Vardi, 2023) for a survey.

Beyond GD, (Gunasekar et al., 2018) and (Nacson et al., 2019) showed that steepest descent optimization w.r.t. a norm \|\cdot\| yields updates that in the limit maximize the margin with respect to the same norm. (Sun et al., 2022) showed that mirror descent with potential function chosen as the p-th power of the p-norm (an algorithm which also enjoys efficient parallelization) yields updates that converge in direction to the classifier that maximizes the margin with respect to the p-norm. In both cases, the convergence rate is slow at O(1/\log(t)). Wang et al. (2023) further improved the rates for both steepest descent and mirror descent when p\in(1,2]. It is important to note that all these results apply only to the exponential loss. More recently, Tsilivis et al. (2024) have shown that iterates of steepest descent algorithms converge to a KKT point of a generalized margin maximization problem in homogeneous neural networks.

On the other hand, the implicit bias of adaptive algorithms such as Adagrad (Duchi et al., 2011) or Adam (Kingma & Ba, 2014) is less explored compared to GD. (Qian & Qian, 2019) studied the implicit bias of Adagrad and showed its directional convergence to a solution characterized by a quadratic minimization problem. (Wang et al., 2021a, 2022) demonstrated that the normalized iterates of Adam (with non-negligible stability constant) converge to a KKT point of an L_{2}-margin maximization problem for homogeneous neural networks. More recently and most relevant to our work, (Zhang et al., 2024) studied the implicit bias of Adam without the stability constant on binary linearly separable data. They showed that unlike GD, the Adam iterates converge to a solution that maximizes the margin with respect to the L_{\infty}-norm. This study excluding the stability constant is practically relevant given that the magnitude of the constant is typically very small (default 10^{-8} in PyTorch (Paszke et al., 2019)). This setting is also the focus of another closely-related recent study of the implicit bias of AdamW (Xie & Li, 2024), where the authors again establish that convergence aligns with the L_{\infty} (rather than L_{2}) geometry. Our work extends these latter studies to the multiclass setting (see Remark 6.5 for technical comparisons).

All the above-mentioned works focus solely on binary classification. The noticeable gap in the analysis of multiclass classification in the existing literature was recently emphasized by (Ravi et al., 2024), who extend the implicit bias result of (Soudry et al., 2018) to multiclass classification for losses with exponential tails, including cross-entropy, multiclass exponential, and PairLogLoss. Their approach leverages a framework introduced by (Wang & Scott, 2024), mapping multiclass analysis to binary cases. Our work directly addresses their open questions regarding the implicit bias of alternative gradient-based methods in multiclass settings by analyzing methods with adaptive step-sizes. Thanks to the adaptive step-sizes, our rates of convergence to the margin improve to polynomial dependence on t. Furthermore, our technical approach differs: rather than mapping to binary analysis, we work directly with multiclass losses, exploiting properties of the softmax function to produce elegant proofs that apply to all three losses studied by (Ravi et al., 2024). Our class-wise decomposition is crucial for analyzing Adam with the same convergence rate as the binary case, avoiding any extra factors that depend on the number of classes.

3 Preliminaries

Notations

For any integer k, [k] denotes \{1,\ldots,k\}. Matrices, vectors, and scalars are denoted by \bm{A}, \bm{a}, and a respectively. For matrix \bm{A}, we denote its (i,j)-th entry as \bm{A}[i,j], and for vector \bm{a}, its i-th entry as \bm{a}[i]. We consider entry-wise matrix p-norms defined as \|\bm{A}\|_{p}=(\sum_{i,j}|\bm{A}[i,j]|^{p})^{1/p}. Central to our results are: the infinity norm, denoted as \left\|\bm{A}\right\|_{\max}=\max_{i,j}|\bm{A}[i,j]| and called the max-norm, and the entry-wise 1-norm, denoted as \left\|\bm{A}\right\|_{\rm{sum}}=\sum_{i,j}|\bm{A}[i,j]|. For any other entry-wise p-norm with p>1, we write \|\bm{A}\| (dropping subscripts) for simplicity. We denote by \|\bm{A}\|_{*} the dual norm with respect to the standard matrix inner product \langle\bm{A},\bm{B}\rangle=\operatorname{tr}(\bm{A}^{\top}\bm{B}). The entry-wise 1-norm is dual to the max-norm. For vectors, the max-norm is equivalent to the infinity norm, denoted as \|\bm{a}\|_{\infty}, while we denote the \ell_{1} norm as \|\bm{a}\|_{1}. Let \delta_{ij} be the indicator such that \delta_{ij}=1 if and only if i=j. Denote by \mathbb{S}:\mathbb{R}^{k}\rightarrow\triangle^{k-1} the softmax map from k-dimensional vectors to the probability simplex \triangle^{k-1}, such that for \bm{a}\in\mathbb{R}^{k}:

\mathbb{S}(\bm{a})=\Big{[}\frac{\exp(\bm{a}[c])}{\sum_{c^{\prime}\in[k]}\exp(\bm{a}[c^{\prime}])}\Big{]}_{c=1}^{k}\in\triangle^{k-1}.

Let \mathbb{S}_{c}({\bm{v}}) denote the c-th entry of \mathbb{S}({\bm{v}}). Let \mathbb{S}^{\prime}(\bm{a})=\operatorname{diag}(\mathbb{S}(\bm{a}))-\mathbb{S}(\bm{a})\mathbb{S}(\bm{a})^{\top} denote the softmax gradient, with \operatorname{diag}(\cdot) a diagonal matrix. Finally, let \{\bm{e}_{c}\}_{c=1}^{k} be the standard basis vectors of \mathbb{R}^{k}.
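To make the notation concrete, the following minimal NumPy sketch (our own illustration, not part of the paper) implements the softmax map \mathbb{S} and its gradient \mathbb{S}^{\prime} exactly as defined above.

```python
import numpy as np

def softmax(a):
    """Softmax map S: R^k -> simplex, computed stably by shifting by the max."""
    z = np.exp(a - np.max(a))
    return z / z.sum()

def softmax_grad(a):
    """Softmax gradient S'(a) = diag(S(a)) - S(a) S(a)^T, a k x k PSD matrix."""
    s = softmax(a)
    return np.diag(s) - np.outer(s, s)

# Example: the entries of S(a) sum to one and S'(a) has zero row sums.
a = np.array([1.0, -0.5, 2.0])
print(softmax(a).sum())             # ~1.0
print(softmax_grad(a).sum(axis=1))  # ~[0, 0, 0]
```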

Setup

Consider a multiclass classification problem with training data \bm{h}_{1},\ldots,\bm{h}_{n} and labels y_{1},\ldots,y_{n}. Each datapoint \bm{h}_{i}\in\mathbb{R}^{d} is a vector in a d-dimensional embedding space (denote the data matrix {\bm{H}}=[\bm{h}_{1},\ldots,\bm{h}_{n}]^{\top}\in\mathbb{R}^{n\times d}), and each label y_{i}\in[k] represents one of k classes. We assume each class contains at least one datapoint. The classifier f_{{\bm{W}}}:\mathbb{R}^{d}\rightarrow\mathbb{R}^{k} is a linear model with weight matrix {\bm{W}}\in\mathbb{R}^{k\times d}. The model outputs logits \bm{\ell}_{i}=f_{{\bm{W}}}(\bm{h}_{i})={\bm{W}}\bm{h}_{i} for i\in[n], which are passed through the softmax map to produce class probabilities \hat{p}(c|\bm{h}_{i})=\mathbb{S}_{c}(\bm{\ell}_{i}). We train using empirical risk minimization (ERM): \mathcal{L}_{\text{ERM}}({\bm{W}}):=-\frac{1}{n}\sum_{i\in[n]}\ell\left({\bm{W}}\bm{h}_{i};y_{i}\right)\,, where the loss function \ell takes as input the logits of a datapoint and its label. The predominant choice in classification is the CE loss

\mathcal{L}({\bm{W}}):=-\frac{1}{n}\sum\nolimits_{i\in[n]}\log\big{(}\mathbb{S}_{y_{i}}({\bm{W}}\bm{h}_{i})\big{)}  (1)
=\frac{1}{n}\sum\nolimits_{i\in[n]}\log\big{(}1+\sum\nolimits_{c\neq y_{i}}\exp(-(\bm{e}_{y_{i}}-\bm{e}_{c})^{\top}{\bm{W}}\bm{h}_{i})\big{)}\,,

We focus our discussions on the CE loss due to its ubiquity in practice. However, our results hold for other multiclass losses such as the exponential (Mukherjee & Schapire, 2010) and the PairLogLoss (Wang et al., 2021b) (see App. F).
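As a sanity check on Eq. (1), the short sketch below (our own illustration) evaluates the CE loss in both of its equivalent forms on random data and confirms that they agree numerically.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, k = 8, 4, 3
H = rng.normal(size=(n, d))          # rows are datapoints h_i
y = rng.integers(0, k, size=n)       # labels in {0, ..., k-1}
W = rng.normal(size=(k, d))          # classifier matrix

def softmax(a):
    z = np.exp(a - np.max(a))
    return z / z.sum()

# First form of Eq. (1): negative log-softmax of the true class.
loss1 = -np.mean([np.log(softmax(W @ H[i])[y[i]]) for i in range(n)])

# Second form: log(1 + sum_{c != y_i} exp(-(e_{y_i} - e_c)^T W h_i)).
loss2 = 0.0
for i in range(n):
    logits = W @ H[i]
    margins = logits[y[i]] - np.delete(logits, y[i])
    loss2 += np.log(1.0 + np.sum(np.exp(-margins)))
loss2 /= n

print(loss1, loss2)  # the two values coincide up to floating-point error
```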

Define the maximum margin of the dataset w.r.t. norm \|\cdot\| as

\gamma_{\|\cdot\|}:=\max_{\|{\bm{W}}\|\leq 1}\,\min_{\begin{subarray}{c}i\in[n]\\ c\neq y_{i}\end{subarray}}\,\left(\bm{e}_{y_{i}}-\bm{e}_{c}\right)^{\top}{\bm{W}}\bm{h}_{i}\,,  (2)

where recall that \|\cdot\| denotes an entry-wise p-norm. Of special interest to us is the maximum margin with respect to the max-norm. For simplicity, we refer to this as \gamma:=\max_{\left\|{\bm{W}}\right\|_{\max}\leq 1}\,\min_{i\in[n],c\neq y_{i}}\,\left(\bm{e}_{y_{i}}-\bm{e}_{c}\right)^{\top}{\bm{W}}\bm{h}_{i}\,.
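Because the objective and constraints are linear in {\bm{W}} when the max-norm ball is used, the max-norm margin \gamma can be computed as a linear program. The sketch below is our own illustration (the helper name and the use of scipy.optimize.linprog are our choices, not the paper's); it sets up the LP directly from the definition.

```python
import numpy as np
from scipy.optimize import linprog

def max_norm_margin(H, y, k):
    """gamma = max_{||W||_max <= 1} min_{i, c != y_i} (e_{y_i}-e_c)^T W h_i,
    solved as an LP. Variables: x = [vec(W) row-major, t]; maximize t."""
    n, d = H.shape
    num_vars = k * d + 1
    A_ub, b_ub = [], []
    for i in range(n):
        for c in range(k):
            if c == y[i]:
                continue
            # Enforce (e_{y_i} - e_c)^T W h_i >= t, written as
            # -W[y_i,:].h_i + W[c,:].h_i + t <= 0.
            row = np.zeros(num_vars)
            row[y[i] * d:(y[i] + 1) * d] = -H[i]
            row[c * d:(c + 1) * d] = H[i]
            row[-1] = 1.0
            A_ub.append(row)
            b_ub.append(0.0)
    cost = np.zeros(num_vars)
    cost[-1] = -1.0  # linprog minimizes, so minimize -t
    bounds = [(-1.0, 1.0)] * (k * d) + [(None, None)]  # ||W||_max <= 1, t free
    res = linprog(cost, A_ub=np.array(A_ub), b_ub=np.array(b_ub),
                  bounds=bounds, method="highs")
    return res.x[-1], res.x[:-1].reshape(k, d)  # (gamma, a maximizing W)
```

The data are separable in the sense of Assumption 3.1 below exactly when the returned \gamma is strictly positive.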

Optimization Methods

To minimize the empirical loss \mathcal{L}_{\text{ERM}}, we use Adam without the stability constant, which performs the following coordinate-wise updates for iteration t\geq 0 and initialization {\bm{W}}_{0} (Kingma & Ba, 2014):

\mathbf{M}_{t}=\beta_{1}\mathbf{M}_{t-1}+(1-\beta_{1})\nabla\mathcal{L}({\bm{W}}_{t})  (3a)
\mathbf{V}_{t}=\beta_{2}\mathbf{V}_{t-1}+(1-\beta_{2})\nabla\mathcal{L}({\bm{W}}_{t})^{2}  (3b)
{\bm{W}}_{t+1}={\bm{W}}_{t}-\eta_{t}\frac{\mathbf{M}_{t}}{\sqrt{\mathbf{V}_{t}}},  (3c)

where \mathbf{M}_{t}, \mathbf{V}_{t} are first and second moment estimates of the gradient, with decay rates \beta_{1} and \beta_{2}. The squaring (\cdot)^{2} and division operations are applied entry-wise.
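For concreteness, a minimal NumPy sketch of one update step (3a)-(3c) as analyzed here, i.e., without the stability constant and without bias correction, is given below (our own illustration, not the paper's code):

```python
import numpy as np

def adam_no_eps_step(W, M, V, grad, lr, beta1=0.9, beta2=0.999):
    """One step of the update (3a)-(3c): Adam without the stability constant.
    All operations are entry-wise; W, M, V, grad are k x d arrays."""
    M = beta1 * M + (1.0 - beta1) * grad          # (3a) first-moment buffer
    V = beta2 * V + (1.0 - beta2) * grad ** 2     # (3b) second-moment buffer
    W = W - lr * M / np.sqrt(V)                   # (3c) entry-wise normalized step
    return W, M, V
```

Setting beta1 = beta2 = 0 gives M = grad and V = grad**2, so the step reduces to -lr * sign(grad), i.e., the SignGD update (4) below.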

In the special case \beta_{1}=\beta_{2}=0, the updates simplify to:

{\bm{W}}_{t+1}={\bm{W}}_{t}-\eta_{t}\,\texttt{sign}(\nabla\mathcal{L}({\bm{W}}_{t})),  (4)

which we recognize as the signed gradient descent (SignGD) algorithm. We will also consider a generalization of SignGD called normalized steepest descent (NSD) (Boyd & Vandenberghe, 2004) with respect to an entry-wise matrix p-norm (p\geq 1) \lVert\cdot\rVert, given by the updates:

{\bm{W}}_{t+1}={\bm{W}}_{t}-\eta_{t}\bm{\Delta}_{t},\qquad\text{where}
\bm{\Delta}_{t}:=\arg\max\nolimits_{\|\bm{\Delta}\|\leq 1}\langle\nabla\mathcal{L}({\bm{W}}_{t}),\bm{\Delta}\rangle\,.  (5)

Note that this reduces to SignGD, Coordinate Descent (e.g., (Nutini et al., 2015)), or normalized gradient descent (NGD) when the max-norm (i.e., p=\infty), the entry-wise 1-norm, or the Frobenius (Euclidean) norm (i.e., p=2) is used, respectively.
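The maximizer in Eq. (5) has a closed form for entry-wise p-norms via the standard dual-norm argument. The sketch below (our own illustration, not from the paper) returns the step direction and makes the three special cases above explicit.

```python
import numpy as np

def nsd_direction(G, p):
    """Direction Delta_t of Eq. (5): argmax of <G, Delta> over ||Delta||_p <= 1,
    where G is the gradient matrix and the norm is entry-wise."""
    if np.isinf(p):
        return np.sign(G)                      # p = inf: SignGD
    if p == 1:
        D = np.zeros_like(G)                   # p = 1: coordinate descent,
        idx = np.unravel_index(np.argmax(np.abs(G)), G.shape)
        D[idx] = np.sign(G[idx])               # move only the largest-|gradient| entry
        return D
    q = p / (p - 1.0)                          # dual exponent, 1/p + 1/q = 1
    scale = np.sum(np.abs(G) ** q) ** ((q - 1.0) / q)   # ||G||_q^{q-1}
    return np.sign(G) * np.abs(G) ** (q - 1.0) / scale  # p = 2 gives G / ||G||_F
```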

Assumptions

Establishing the implicit bias of the above-mentioned gradient-based optimization algorithms requires the following assumptions. First, we assume the data are linearly separable, ensuring the margin \gamma_{\|\cdot\|} is strictly positive, an assumption routinely used in previous works (Soudry et al., 2018; Ravi et al., 2024; Gunasekar et al., 2018; Nacson et al., 2019; Wu et al., 2024b).

Assumption 3.1.

There exists {\bm{W}}\in\mathbb{R}^{k\times d} such that \min_{c\neq y_{i}}(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}{\bm{W}}\bm{h}_{i}>0 for all i\in[n].

The second assumption ensures that all entries of the second moment buffer {\bm{V}}_{t} of Adam are bounded away from 0 for all t\geq 0. Previously used by Zhang et al. (2024) in binary classification, this assumption is satisfied when the data distribution is continuous and non-degenerate. A similar assumption appears in (Xie & Li, 2024).

Assumption 3.2.

The Adam initialization satisfies \nabla\mathcal{L}({\bm{W}}_{0})[c,j]^{2}\geq\omega for all c\in[k] and j\in[d].

In this work, we consider a decaying learning rate schedule of the form \eta_{t}=\Theta(\frac{1}{t^{a}}), where a\in(0,1]. Such a schedule has been studied in the convergence and implicit bias of various optimization algorithms (e.g., (Bottou et al., 2018; Nacson et al., 2019; Sun et al., 2023)), including Adam (Huang et al., 2021; Zhang et al., 2024; Xie & Li, 2024).

Assumption 3.3.

The learning rate schedule \{\eta_{t}\} is decreasing with respect to t and satisfies the following conditions: \lim_{t\rightarrow\infty}\eta_{t}=0 and \sum_{t=0}^{\infty}\eta_{t}=\infty.

Assumption 3.4 can be satisfied by the above learning rate schedule for all sufficiently large t, as shown in Zhang et al. (2024, Lemma C.1). It is used in our analysis of Adam.

Assumption 3.4.

The learning rate schedule satisfies the following: for constants \beta\in(0,1) and c_{1}>0, there exist a time t_{0}\in\mathbb{N}_{+} and a constant c_{2}=c_{2}(c_{1},\beta)>0 such that \sum_{s=0}^{t}\beta^{s}(e^{c_{1}\sum_{\tau=1}^{s}\eta_{s-\tau}}-1)\leq c_{2}\eta_{t} for all t\geq t_{0}.

Finally, we assume the 1-norm of the data is bounded. Similar assumptions were used in (Ji & Telgarsky, 2019), (Nacson et al., 2019), (Wu et al., 2024b), and (Zhang et al., 2024).

Assumption 3.5.

There exists a constant B>0 such that \lVert\bm{h}_{i}\rVert_{1}\leq B for all i\in[n].

4 {\mathcal{G}}({\bm{W}}) - A Proxy to \mathcal{L}({\bm{W}})

Analyzing margin convergence begins with studying loss convergence through second-order Taylor expansion of the CE loss:

\mathcal{L}({\bm{W}}+\bm{\Delta})=\mathcal{L}({\bm{W}})+\langle\nabla\mathcal{L}({\bm{W}}),\bm{\Delta}\rangle
+\frac{1}{2n}\sum\nolimits_{i\in[n]}\bm{h}_{i}^{\top}\bm{\Delta}^{\top}{\mathbb{S}}^{\prime}({\bm{W}}\bm{h}_{i})\bm{\Delta}\bm{h}_{i}+o(\|\bm{\Delta}\|_{F}^{3}),  (6)

where recall from the notation section that {\mathbb{S}}^{\prime}({\bm{W}}\bm{h}_{i})=\operatorname{diag}(\mathbb{S}({\bm{W}}\bm{h}_{i}))-\mathbb{S}({\bm{W}}\bm{h}_{i})\mathbb{S}({\bm{W}}\bm{h}_{i})^{\top}. To bound the loss at {\bm{W}}_{t+1}={\bm{W}}_{t}-\eta_{t}\bm{\Delta}_{t}, we must bound both terms in (6). For the NSD updates defined in Eq. (5), the first term evaluates to -\eta_{t}\|\nabla\mathcal{L}({\bm{W}}_{t})\|_{*} (\|\cdot\|_{*} is the dual norm). This leads to two key tasks: (1) Lower-bounding the dual gradient norm. (2) Upper-bounding the second-order term.

For the proof to proceed, these bounds should satisfy two desiderata: (1) they are expressible as the same function of {\bm{W}}_{t}, call it {\mathcal{G}}({\bm{W}}), up to constants; (2) the function {\mathcal{G}}({\bm{W}}) is a good proxy for the loss when the latter is small. The former helps with combining the terms, while the latter helps with demonstrating descent. Next, we obtain these key bounds for the CE loss by determining the appropriate proxy {\mathcal{G}}({\bm{W}}). We focus on the matrix max-norm \left\|\cdot\right\|_{\max}, which arises naturally in the analysis of SignGD and Adam. Later, we extend these results to arbitrary p-norms \|\cdot\| for p\geq 1, establishing the implicit bias of p-norm NSD.

Construction of {\mathcal{G}}({\bm{W}})

Before showing our construction for the CE loss, it is insightful to discuss how previous works do this in the binary case with labels y_{b,i}\in\{\pm 1\}, classifier vector {\bm{w}}\in\mathbb{R}^{d}, and binary margin \gamma_{b}\coloneqq\max_{\lVert{\bm{w}}\rVert\leq 1}\min_{i\in[n]}y_{b,i}{\bm{w}}^{\top}\bm{h}_{i}. For the exponential loss, Gunasekar et al. (2018) showed that \lVert\nabla\mathcal{L}({\bm{w}})\rVert\geq\gamma_{b}\mathcal{L}({\bm{w}}). For the logistic loss \ell(t)=\log(1+\exp(-t)), Zhang et al. (2024) proved \lVert\nabla\mathcal{L}({\bm{w}})\rVert_{1}\geq\gamma_{b}{\mathcal{G}}({\bm{w}}), where {\mathcal{G}}({\bm{w}})=\frac{1}{n}\sum_{i=1}^{n}|\ell^{\prime}(y_{b,i}{\bm{w}}^{\top}\bm{h}_{i})| and \ell^{\prime} is the first derivative. In both cases, one can take the common form {\mathcal{G}}_{b}({\bm{w}})=\frac{1}{n}\sum_{i=1}^{n}|\ell^{\prime}(y_{b,i}{\bm{w}}^{\top}\bm{h}_{i})|. The proof relies on showing \gamma_{b}\leq\min_{\bm{r}\in\triangle^{n-1}}\lVert{\bm{H}}^{T}\bm{r}\rVert via Fenchel duality (Telgarsky, 2013; Gunasekar et al., 2018) and appropriately choosing \bm{r}.

In the multiclass setting, where the loss involves vector-valued predictions, it is unclear how to extend the binary proof or the definition of {\mathcal{G}}({\bm{W}}). To this end, we realize that the key is in the proper manipulation of the gradient inner product \langle{\bm{A}},-\nabla\mathcal{L}({\bm{W}})\rangle (for an arbitrary matrix {\bm{A}}\in\mathbb{R}^{k\times d}). The CE gradient evaluates to \nabla\mathcal{L}({\bm{W}})=-\frac{1}{n}\sum_{i=1}^{n}(\bm{e}_{y_{i}}-\mathbb{S}({\bm{W}}\bm{h}_{i}))\bm{h}_{i}^{\top}, and using the fact that \mathbb{S}({\bm{W}}\bm{h}_{i})\in\triangle^{k-1}, it turns out that we can express (details in Lemma A.1)

\langle{\bm{A}},-\nabla\mathcal{L}({\bm{W}})\rangle=\frac{1}{n}\sum_{i\in[n]}\sum_{c\neq y_{i}}\mathbb{S}_{c}({\bm{W}}\bm{h}_{i})(\bm{e}_{y_{i}}-\bm{e}_{c})^{\top}{\bm{A}}\bm{h}_{i}\,.

This motivates defining {\mathcal{G}}({\bm{W}}) as:

{\mathcal{G}}({\bm{W}})\coloneqq\frac{1}{n}\sum_{i\in[n]}(1-\mathbb{S}_{y_{i}}({\bm{W}}\bm{h}_{i}))\,.  (7)
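The following short NumPy sketch (our own sanity check, not code from the paper) verifies on random data both the inner-product identity above and the equivalent form {\mathcal{G}}({\bm{W}})=\frac{1}{n}\sum_{i\in[n]}\sum_{c\neq y_{i}}\mathbb{S}_{c}({\bm{W}}\bm{h}_{i}) of the proxy in Eq. (7).

```python
import numpy as np

rng = np.random.default_rng(1)
n, d, k = 10, 6, 4
H = rng.normal(size=(n, d))
y = rng.integers(0, k, size=n)
W = rng.normal(size=(k, d))
A = rng.normal(size=(k, d))

def softmax(a):
    z = np.exp(a - np.max(a))
    return z / z.sum()

S = np.stack([softmax(W @ H[i]) for i in range(n)])   # S[i, c] = S_c(W h_i)
E = np.eye(k)

# CE gradient: grad L(W) = -(1/n) sum_i (e_{y_i} - S(W h_i)) h_i^T
grad = -np.mean([np.outer(E[y[i]] - S[i], H[i]) for i in range(n)], axis=0)

# Left side of the identity: <A, -grad L(W)>
lhs = np.sum(A * (-grad))
# Right side: (1/n) sum_i sum_{c != y_i} S[i, c] (e_{y_i} - e_c)^T A h_i
rhs = np.mean([sum(S[i, c] * (E[y[i]] - E[c]) @ A @ H[i]
                   for c in range(k) if c != y[i]) for i in range(n)])
print(np.isclose(lhs, rhs))  # True

# Proxy of Eq. (7), in both equivalent forms.
G1 = np.mean(1.0 - S[np.arange(n), y])
G2 = np.mean([S[i, [c for c in range(k) if c != y[i]]].sum() for i in range(n)])
print(np.isclose(G1, G2))    # True
```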

The following key lemma, which directly follows from the inner-product calculation above and our definition of {\mathcal{G}}({\bm{W}}), confirms that this is the right choice. For convenience, denote s_{ic}\coloneqq\mathbb{S}_{c}({\bm{W}}\bm{h}_{i}) for i\in[n],c\in[k].

Lemma 4.1 (Lower bounding the gradient dual-norm).

For any {\bm{W}}\in\mathbb{R}^{k\times d}, it holds that \left\|\nabla\mathcal{L}({\bm{W}})\right\|_{\rm{sum}}\geq\gamma\cdot{\mathcal{G}}({\bm{W}}).

Proof.

By duality, the above calculated formula for the CE-gradient inner product, and the fact that \sum_{c\in[k]}s_{ic}=1:

\left\|\nabla\mathcal{L}({\bm{W}})\right\|_{\rm{sum}}=\max\nolimits_{\left\|\bm{A}\right\|_{\max}\leq 1}\,\langle\bm{A},-\nabla\mathcal{L}({\bm{W}})\rangle
=\max\nolimits_{\left\|\bm{A}\right\|_{\max}\leq 1}\,\frac{1}{n}\sum\nolimits_{i\in[n]}\sum\nolimits_{c\neq y_{i}}s_{ic}\,(\bm{e}_{y_{i}}-\bm{e}_{c})^{\top}\bm{A}\bm{h}_{i}
\geq\max_{\left\|\bm{A}\right\|_{\max}\leq 1}\frac{1}{n}\sum_{i\in[n]}(1-s_{iy_{i}})\,\min_{c\neq y_{i}}(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}\bm{A}\bm{h}_{i}
\geq\frac{1}{n}\sum_{i\in[n]}(1-s_{iy_{i}})\cdot\max_{\left\|\bm{A}\right\|_{\max}\leq 1}\min_{i\in[n],c\neq y_{i}}(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}\bm{A}\bm{h}_{i}\,.

To finish, recall the definitions of {\mathcal{G}}({\bm{W}}) and of the max-norm margin \gamma. ∎

The lemma completes the first of our two tasks: lower bounding the gradient’s dual norm. Importantly, the factor appearing in front of {\mathcal{G}}({\bm{W}}) is the max-norm margin \gamma, which will prove crucial in the forthcoming margin analysis.

{\mathcal{G}}({\bm{W}}) and the second-order term

We now show how to bound the second-order term in (6). For this, we establish the following essential lemma.

Lemma 4.2.

For any \bm{s}\in\triangle^{k-1} in the probability simplex, any index c\in[k], and any {\bm{v}}\in\mathbb{R}^{k}, it holds:

{\bm{v}}^{\top}\left(\operatorname{diag}(\bm{s})-\bm{s}\bm{s}^{\top}\right){\bm{v}}\leq 4\,(1-s_{c})\,\|{\bm{v}}\|_{\infty}^{2}\,.
Proof.

By Hölder’s inequality (the entry-wise 1-norm and the max-norm are dual),

{\bm{v}}^{\top}\left(\operatorname{diag}(\bm{s})-\bm{s}\bm{s}^{\top}\right){\bm{v}}=\langle\operatorname{diag}(\bm{s})-\bm{s}\bm{s}^{\top},{\bm{v}}{\bm{v}}^{\top}\rangle
\leq\,\left\|\operatorname{diag}(\bm{s})-\bm{s}\bm{s}^{\top}\right\|_{\rm{sum}}\,\|{\bm{v}}\|_{\infty}^{2}\,.

Direct calculation yields \left\|\operatorname{diag}(\bm{s})-\bm{s}\bm{s}^{\top}\right\|_{\rm{sum}}=2\sum\nolimits_{c\in[k]}s_{c}(1-s_{c}). The advertised bound then follows by noting that \sum_{c\in[k]}s_{c}(1-s_{c})\leq 2(1-s_{c^{\prime}}) for any c^{\prime}\in[k] (verified in Lemma A.4). ∎
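As a quick numerical check of Lemma 4.2 (our own illustration, not part of the paper), one can sample random points on the simplex and random vectors and confirm the inequality:

```python
import numpy as np

rng = np.random.default_rng(2)
k = 5
for _ in range(1000):
    s = rng.dirichlet(np.ones(k))            # random point in the simplex
    v = rng.normal(size=k)
    c = rng.integers(k)                      # arbitrary index c
    lhs = v @ (np.diag(s) - np.outer(s, s)) @ v
    rhs = 4.0 * (1.0 - s[c]) * np.max(np.abs(v)) ** 2
    assert lhs <= rhs + 1e-12
print("Lemma 4.2 inequality held on all random trials")
```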

To bound the second-order term in Eq. (6), we can apply the above lemma with {\bm{v}}\leftarrow\bm{\Delta}\bm{h}_{i} and c\leftarrow y_{i}, and further use the inequality \|\bm{\Delta}\bm{h}_{i}\|_{\infty}\leq\left\|\bm{\Delta}\right\|_{\max}\|\bm{h}_{i}\|_{1}\leq B\left\|\bm{\Delta}\right\|_{\max} (the last step invoked Ass. 3.5). This yields an upper bound

2B^{2}\|\bm{\Delta}\|_{\max}^{2}\cdot\frac{1}{n}\sum_{i\in[n]}(1-\mathbb{S}_{y_{i}}({\bm{W}}\bm{h}_{i}))

for the second-order term in the CE loss expansion in terms of the proxy function {\mathcal{G}}({\bm{W}}). Thus, we have fulfilled the first desideratum by finding a {\mathcal{G}}({\bm{W}}) that simultaneously bounds the first- and second-order terms in Eq. (6).

Properties of {\mathcal{G}}({\bm{W}})

We now show that the function {\mathcal{G}}({\bm{W}}) in Eq. (7) meets the second desideratum: being a good proxy for the loss \mathcal{L}({\bm{W}}). This is rooted in elementary relationships between {\mathcal{G}}({\bm{W}}) and \mathcal{L}({\bm{W}}), which play crucial roles in various parts of the proof. Below, we summarize these key relationships.

Lemma 4.3 (Properties of {\mathcal{G}}({\bm{W}}) and \mathcal{L}({\bm{W}})).

Let {\bm{W}}\in\mathbb{R}^{k\times d}. The following hold: (i) Under Ass. 3.5, 2B\cdot{\mathcal{G}}({\bm{W}})\geq\left\|\nabla\mathcal{L}({\bm{W}})\right\|_{\rm{sum}}; (ii) 1\geq\frac{{\mathcal{G}}({\bm{W}})}{\mathcal{L}({\bm{W}})}\geq 1-\frac{n\mathcal{L}({\bm{W}})}{2}; (iii) If {\bm{W}} satisfies \mathcal{L}({\bm{W}})\leq\frac{\log 2}{n} or {\mathcal{G}}({\bm{W}})\leq\frac{1}{2n}, then \mathcal{L}({\bm{W}})\leq 2{\mathcal{G}}({\bm{W}}).

Lemma 4.3 (i) extends Lemma 4.1 by establishing a sandwich relationship between {\mathcal{G}}({\bm{W}}) and the gradient’s dual norm. The lemma’s statements (ii) and (iii) show that {\mathcal{G}}({\bm{W}}) can substitute for the loss: it lower bounds \mathcal{L}({\bm{W}}) and serves as an upper bound when either \mathcal{L}({\bm{W}}) or {\mathcal{G}}({\bm{W}}) is sufficiently small. Specifically, the ratio {\mathcal{G}}({\bm{W}})/\mathcal{L}({\bm{W}}) converges to 1 as the loss decreases, with the convergence rate depending on the rate of decrease. The key property (ii) may seem algebraically complex, but it turns out (Lemma B.3) that both sides of the sandwich relationship follow from the elementary fact that \forall x>0:1-x\leq e^{-x}\leq 1-x+x^{2}/2.
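The following self-contained sketch (our own illustration; the synthetic data and the choice of separating direction are assumptions made only for the demonstration) checks the sandwich of Lemma 4.3 (ii) numerically and illustrates that {\mathcal{G}}/\mathcal{L} approaches 1 as the loss becomes small.

```python
import numpy as np

rng = np.random.default_rng(3)
n, d, k = 12, 6, 3
C = rng.normal(size=(k, d)) * 3.0                     # class centers
y = np.repeat(np.arange(k), n // k)
H = C[y] + 0.1 * rng.normal(size=(len(y), d))         # nearly noiseless clusters

def softmax(a):
    z = np.exp(a - np.max(a))
    return z / z.sum()

def ce_loss(W):
    return -np.mean([np.log(softmax(W @ h)[c]) for h, c in zip(H, y)])

def proxy_G(W):
    return np.mean([1.0 - softmax(W @ h)[c] for h, c in zip(H, y)])

W_dir = C.copy()                                      # scale this direction
for scale in [0.5, 1.0, 2.0, 4.0]:
    L, G = ce_loss(scale * W_dir), proxy_G(scale * W_dir)
    # Lemma 4.3 (ii): L >= G >= L * (1 - n L / 2), valid for any W.
    assert G <= L + 1e-12 and G / L >= 1.0 - len(y) * L / 2.0 - 1e-12
    print(f"scale={scale:3.1f}  G/L={G / L:.4f}")     # ratio tends to 1 as L shrinks
```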

5 Implicit Bias of SignGD

We now leverage our construction of {\mathcal{G}}({\bm{W}}) to show that the margin of SignGD’s iterates converges to the max-norm margin. We only highlight the key steps of the proof and defer details to Appendix C.

SignGD Descent

We start by showing a descent property. Applying Lemmas 4.1 and 4.2 to lower and upper bound the first- and second-order terms in Eq. (6) yields:

\mathcal{L}({\bm{W}}_{t+1})\leq\mathcal{L}({\bm{W}}_{t})-\gamma\eta_{t}{\mathcal{G}}({\bm{W}}_{t})+2\eta_{t}^{2}B^{2}{\mathcal{G}}({\bm{W}}_{t})\sup_{\zeta\in[0,1]}\frac{{\mathcal{G}}({\bm{W}}_{t}+\zeta\bm{\Delta}_{t})}{{\mathcal{G}}({\bm{W}}_{t})}.

Algebraic manipulation of the definition of {\mathcal{G}}({\bm{W}}) allows us to bound the ratio on the right-hand side.

Lemma 5.1 (Ratio of {\mathcal{G}}({\bm{W}})).

For any \psi\in[0,1], we have the following: \frac{{\mathcal{G}}({\bm{W}}+\psi\triangle{\bm{W}})}{{\mathcal{G}}({\bm{W}})}\leq e^{2B\psi\left\|\triangle{\bm{W}}\right\|_{\max}}.

From this and the fact that \left\|\bm{\Delta}_{t}\right\|_{\max}\leq\eta_{t} for SignGD, we obtain

\mathcal{L}({\bm{W}}_{t+1})\leq\mathcal{L}({\bm{W}}_{t})-\gamma\eta_{t}(1-\alpha_{s_{1}}\eta_{t}){\mathcal{G}}({\bm{W}}_{t}),  (8)

where \alpha_{s_{1}}=2B^{2}e^{2B\eta_{0}}/\gamma. Given a decaying learning rate of the form \eta_{t}=\Theta(\frac{1}{t^{a}}), we can conclude that the loss starts to monotonically decrease after some time.

SignGD Unnormalized Margin

We now use the descent property in (8) to lower bound the unnormalized margin. An intermediate result towards this is recognizing that a sufficiently small loss \mathcal{L}({\bm{W}})\leq\frac{\log 2}{n} guarantees that {\bm{W}} separates the data (Lemma B.4). The descent property ensures that SignGD iterates will eventually achieve this loss threshold, thereby guaranteeing separability. The main result of this section shows that eventually the iterates achieve separability with a substantial margin.

Lemma 5.2 (SignGD Unnormalized Margin).

Assume \tilde{t} is such that \mathcal{L}({\bm{W}}_{t})\leq\frac{\log 2}{n} for all t>\tilde{t}. Then, the minimum unnormalized margin \min_{i\in[n],c\neq y_{i}}(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}{\bm{W}}_{t}\bm{h}_{i} of the iterates {\bm{W}}_{t}, t\geq\tilde{t}, is lower bounded by (with \alpha_{s_{2}}=2Be^{2B\eta_{0}})

\gamma\sum_{s=\tilde{t}}^{t-1}\eta_{s}\frac{{\mathcal{G}}({\bm{W}}_{s})}{\mathcal{L}({\bm{W}}_{s})}-\alpha_{s_{2}}\sum_{s=\tilde{t}}^{t-1}\eta_{s}^{2}.  (9)
Proof.

Exponentiating the negative of the unnormalized margin gives:

e^{-\min_{i\in[n],c\neq y_{i}}(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}{\bm{W}}_{t}\bm{h}_{i}}=\max_{i\in[n]}e^{-\min_{c\neq y_{i}}(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}{\bm{W}}_{t}\bm{h}_{i}}
\stackrel{(a)}{\leq}\max_{i\in[n]}\frac{1}{\log 2}\log\bigl{(}1+e^{-\min_{c\neq y_{i}}(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}{\bm{W}}_{t}\bm{h}_{i}}\bigr{)}
\leq\max_{i\in[n]}\frac{1}{\log 2}\log\bigl{(}1+\sum_{c\neq y_{i}}e^{-(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}{\bm{W}}_{t}\bm{h}_{i}}\bigr{)}\leq\frac{n\mathcal{L}({\bm{W}}_{t})}{\log 2}
\stackrel{(b)}{\leq}\exp\bigl{(}-\gamma\sum_{s=\tilde{t}}^{t-1}\eta_{s}\frac{{\mathcal{G}}({\bm{W}}_{s})}{\mathcal{L}({\bm{W}}_{s})}+\alpha_{s_{2}}\sum_{s=\tilde{t}}^{t-1}\eta_{s}^{2}\bigr{)}.

(a) is because \frac{\log(1+e^{-z})}{e^{-z}}\geq\log 2 for z\geq 0, with z chosen to be \min_{c\neq y_{i}}(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}{\bm{W}}_{t}\bm{h}_{i} for any i\in[n] (nonnegative since {\bm{W}}_{t} separates the data for t\geq\tilde{t}); (b) is by some manipulations of (8) (details in App. C.2). ∎

SignGD Margin Convergence

Proceeding from Eq. (9) requires showing convergence of the ratio \frac{{\mathcal{G}}({\bm{W}})}{\mathcal{L}({\bm{W}})}. The two key ingredients are given in Lemma 4.3 (ii) and (iii). Lemma 4.3 (ii) suggests that it is sufficient to study the convergence of \mathcal{L}({\bm{W}}), which is captured in (8). However, to obtain an explicit rate via (8), we need to rewrite {\mathcal{G}}({\bm{W}}_{t}) in terms of \mathcal{L}({\bm{W}}_{t}). This is where Lemma 4.3 (iii) helps. Putting them together, we arrive at the following theorem.

Theorem 5.3.

Suppose that Ass. 3.1, 3.3, and 3.5 hold. Then there exists t_{s_{2}}=t_{s_{2}}(n,\gamma,B,{\bm{W}}_{0}) such that the margin gap \gamma-{\min_{i\in[n],c\neq y_{i}}(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}{\bm{W}}_{t}\bm{h}_{i}}/{\left\|{\bm{W}}_{t}\right\|_{\max}} of SignGD’s iterates for all t>t_{s_{2}} is upper bounded by

\mathcal{O}\Big{(}\frac{\sum_{s=t_{s_{2}}}^{t-1}\eta_{s}e^{-\frac{\gamma}{4}\sum_{\tau=t_{s_{2}}}^{s-1}\eta_{\tau}}+\sum_{s=0}^{t_{s_{2}}-1}\eta_{s}+\sum_{s=t_{s_{2}}}^{t-1}\eta_{s}^{2}}{\sum_{s=0}^{t-1}\eta_{s}}\Big{)}.
Proof.

SignGD’s update rule (4) gives \left\|{\bm{W}}_{t}\right\|_{\max}\leq\left\|{\bm{W}}_{0}\right\|_{\max}+\sum_{s=0}^{t-1}\eta_{s}. We first find an iteration t_{s_{1}} via (8) such that \mathcal{L}({\bm{W}}_{t+1})\leq\mathcal{L}({\bm{W}}_{t})-\frac{\eta_{t}\gamma}{2}{\mathcal{G}}({\bm{W}}_{t}) for all t\geq t_{s_{1}}. Then, we find an iteration t_{s_{2}}>t_{s_{1}} such that \mathcal{L}({\bm{W}}_{t})\leq\frac{\log 2}{n} for all t\geq t_{s_{2}}. By Lemma 4.3 (iii), this guarantees \mathcal{L}({\bm{W}}_{t})\leq 2{\mathcal{G}}({\bm{W}}_{t}). Substituting this into the above recursion on \mathcal{L}({\bm{W}}_{t}), we obtain \mathcal{L}({\bm{W}}_{t})\leq\frac{\log 2}{n}e^{-\frac{\gamma}{4}\sum_{s=t_{s_{2}}}^{t-1}\eta_{s}}. By Lemma 4.3 (ii), we further obtain \frac{{\mathcal{G}}({\bm{W}}_{t})}{\mathcal{L}({\bm{W}}_{t})}\geq 1-e^{-\frac{\gamma}{4}\sum_{s=t_{s_{2}}}^{t-1}\eta_{s}}. Combining this with (9) and the upper bound on \left\|{\bm{W}}_{t}\right\|_{\max}, the final result is proved (details in App. C.4). ∎

In the proof, t_{s_{1}} can be set to (4B^{2}e^{2B\eta_{0}}/\gamma)^{1/a}, and t_{s_{2}} is chosen such that t_{s_{2}}\leq\Theta(n^{\frac{1}{1-a}}+n^{\frac{1}{1-a}}\mathcal{L}({\bm{W}}_{0})^{\frac{1}{1-a}}). Note that both are independent of the problem’s dimensionality, yielding the following convergence rates.

Corollary 5.4.

Set the learning rate \eta_{t}=\Theta(\frac{1}{t^{a}}), a\in(0,1]. Under the setting of Theorem 5.3, the margin gap of the SignGD iterates with a=1/2 reduces at rate \mathcal{O}(\frac{\log t+n}{t^{1/2}}). (The rates for other values of a can be found in Corollary C.5.)

Remark 5.5.

These results generalize to NSD defined in (5). Concretely, Lemmas 4.1, 4.2, and 4.3 (i) generalize to any p-norm (where the margin in (2) is defined w.r.t. the same norm) together with the dual norm on the loss gradient. These results are summarized in Lemmas E.1 and E.3 in App. E. In terms of margin convergence of NSD, (Nacson et al., 2019) showed a rate of \mathcal{O}(\frac{\log t}{t^{1/2}}) in the binary setting, limited to the exponential loss. Compared to this, our results hold for the more practical setting of multiclass data and CE loss. The rate in Cor. 5.4 surpasses the max-norm margin convergence rate of Adam in the binary setting (Zhang et al., 2024); thus, even in the binary setting it improves the latter and extends (Nacson et al., 2019)’s result to the logistic loss. We elaborate on this gap in the next section.

6 Implicit Bias of Adam

Our analysis of Adam for multiclass data, while inspired by the binary setting of (Zhang et al., 2024), requires significant new technical insights to achieve the same convergence rate without class-dependent factors. While the proof follows similar high-level steps as SignGD above and leverages all the key properties of {\mathcal{G}}({\bm{W}}) identified in Sec. 4, using {\mathcal{G}}({\bm{W}}) alone would lead to suboptimal bounds with an extra factor of k. Our crucial insight lies in a per-class decomposition of {\mathcal{G}}({\bm{W}}) as follows:

{\mathcal{G}}({\bm{W}})=\sum_{c\in[k]}\frac{1}{n}\sum_{i\in[n],y_{i}=c}(1-s_{iy_{i}})=\sum_{c\in[k]}\frac{1}{n}\sum_{i\in[n],y_{i}\neq c}s_{ic},

where recall that s_{ic}\coloneqq\mathbb{S}_{c}({\bm{W}}\bm{h}_{i}). This decomposition motivates further defining the “per-class proxies”:

{\mathcal{G}}_{c}({\bm{W}})\coloneqq\frac{1}{n}\sum_{i\in[n],y_{i}=c}(1-s_{iy_{i}})
\mathcal{Q}_{c}({\bm{W}})\coloneqq\frac{1}{n}\sum_{i\in[n],y_{i}\neq c}s_{ic}.

Both terms play crucial roles in our per-class analysis, with this decomposition being key to avoiding the factor k in our bounds. The remainder of this section highlights the technical challenges beyond those encountered in analyzing SignGD (full proof details are provided in Appendix D).
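To illustrate the decomposition (our own sketch, not from the paper), the per-class proxies can be computed directly from the softmax outputs and checked against {\mathcal{G}}({\bm{W}}):

```python
import numpy as np

rng = np.random.default_rng(4)
n, d, k = 15, 5, 4
H = rng.normal(size=(n, d))
y = rng.integers(0, k, size=n)
W = rng.normal(size=(k, d))

def softmax(a):
    z = np.exp(a - np.max(a))
    return z / z.sum()

S = np.stack([softmax(W @ h) for h in H])              # S[i, c] = s_{ic}

G_c = np.array([np.sum(1.0 - S[y == c, c]) / n for c in range(k)])  # G_c(W)
Q_c = np.array([np.sum(S[y != c, c]) / n for c in range(k)])         # Q_c(W)
G = np.mean(1.0 - S[np.arange(n), y])                                # G(W), Eq. (7)

print(np.isclose(G_c.sum(), G), np.isclose(Q_c.sum(), G))  # both True
```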

Adam Descent

To show descent of Adam via the Taylor expansion (6), the first-order term involves an additional error term when introducing the dual norm of the gradient, i.e.,

\langle\nabla\mathcal{L}({\bm{W}}_{t}),\bm{\Delta}_{t}\rangle\leq-\eta_{t}\left\|\nabla\mathcal{L}({\bm{W}}_{t})\right\|_{\rm{sum}}+\eta_{t}\underbrace{\Bigm{|}\langle\nabla\mathcal{L}({\bm{W}}_{t}),\frac{\bm{M}_{t}}{\sqrt{{\bm{V}}_{t}}}-\frac{\nabla\mathcal{L}({\bm{W}}_{t})}{|\nabla\mathcal{L}({\bm{W}}_{t})|}\rangle\Bigm{|}}_{\clubsuit}.  (10)

Given that we bound the remaining terms in (6) using {\mathcal{G}}({\bm{W}}), it seems natural to bound the \clubsuit term similarly. While one could attempt to relate the entries of both matrices \bm{M}_{t} and {\bm{V}}_{t} to {\mathcal{G}}({\bm{W}}_{t}), our entry-wise analysis of the \clubsuit term (following (Zhang et al., 2024)’s approach in the binary setting) would inevitably introduce an extra factor of k. To avoid this, we instead relate the row vectors \bm{M}_{t}[c,:] and {\bm{V}}_{t}[c,:] to the class-specific functions {\mathcal{G}}_{c}({\bm{W}}) and \mathcal{Q}_{c}({\bm{W}}) for each c\in[k]. This class-wise approach requires careful accounting of interactions between the k>2 classes throughout our analysis. We first show how to bound \bm{M}_{t}.

Lemma 6.1.

Let c\in[k]. Under the setting of Theorem 6.3, there exists a time t_{0} such that for all t\geq t_{0}:

|\mathbf{M}_{t}[c,j]-(1-\beta_{1}^{t+1})\nabla\mathcal{L}({\bm{W}}_{t})[c,j]|\leq\alpha_{M}\eta_{t}\bigl{(}{\mathcal{G}}_{c}({\bm{W}}_{t})+\mathcal{Q}_{c}({\bm{W}}_{t})\bigr{)},

where j\in[d], and \alpha_{M}:=B(1-\beta_{1})c_{2}.

Here, we provide a quick proof sketch. By the Adam update rule (3a), for all c\in[k] and j\in[d], the quantity |\mathbf{M}_{t}[c,j]-(1-\beta_{1}^{t+1})\nabla\mathcal{L}({\bm{W}}_{t})[c,j]| can be bounded by

\sum_{\tau=0}^{t}(1-\beta_{1})\beta_{1}^{\tau}\underbrace{|\nabla\mathcal{L}({\bm{W}}_{t-\tau})[c,j]-\nabla\mathcal{L}({\bm{W}}_{t})[c,j]|}_{\spadesuit}.

By explicitly writing out the gradient and grouping terms, we can show:

\spadesuit\leq B\frac{1}{n}\sum\nolimits_{i\in[n]}|\mathbb{S}_{c}({\bm{W}}_{t-\tau}\bm{h}_{i})-\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{i})|
=B\underbrace{\frac{1}{n}\sum\nolimits_{i\in[n]:y_{i}\neq c}|\mathbb{S}_{c}({\bm{W}}_{t-\tau}\bm{h}_{i})-\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{i})|}_{\spadesuit_{1}}
+B\underbrace{\frac{1}{n}\sum\nolimits_{i\in[n]:y_{i}=c}|\mathbb{S}_{y_{i}}({\bm{W}}_{t-\tau}\bm{h}_{i})-\mathbb{S}_{y_{i}}({\bm{W}}_{t}\bm{h}_{i})|}_{\spadesuit_{2}}.

In the above equality, we split the sum into two cases: samples where y_{i}=c and the rest. Using the definitions of {\mathcal{G}}_{c}({\bm{W}}) and \mathcal{Q}_{c}({\bm{W}}), we manipulate and express the first (\spadesuit_{1}) and second (\spadesuit_{2}) terms via \mathcal{Q}_{c}({\bm{W}}) and {\mathcal{G}}_{c}({\bm{W}}), respectively. This leaves us to bound expressions of the form |\mathbb{S}_{c}({\bm{v}}^{\prime})/\mathbb{S}_{c}({\bm{v}})-1| and |(1-\mathbb{S}_{c}({\bm{v}}^{\prime}))/(1-\mathbb{S}_{c}({\bm{v}}))-1|, respectively, for the two terms. We address this technical challenge in Lemma D.10, showing how to bound such terms using \|{\bm{v}}-{\bm{v}}^{\prime}\|_{\infty}, which we can control for Adam. We now show how to bound {\bm{V}}_{t}.

Lemma 6.2.

Let c\in[k]. Under the setting of Theorem 6.3, there exists a time t_{0} such that for all t\geq t_{0}:

\Bigm{|}\sqrt{{\bm{V}}_{t}[c,j]}-\sqrt{(1-\beta_{2}^{t+1})}\,|\nabla\mathcal{L}({\bm{W}}_{t})[c,j]|\Bigm{|}\leq\alpha_{V}\sqrt{\eta_{t}}\bigl{(}{\mathcal{G}}_{c}({\bm{W}}_{t})+\mathcal{Q}_{c}({\bm{W}}_{t})\bigr{)},

where j\in[d], and \alpha_{V}=B\sqrt{(1-\beta_{2})c_{2}}.

By Adam’s update rule (3b), we have for all c\in[k], j\in[d] that |{\bm{V}}_{t}[c,j]-(1-\beta_{2}^{t+1})\nabla\mathcal{L}({\bm{W}}_{t})[c,j]^{2}| can be bounded by

\sum_{\tau=0}^{t}(1-\beta_{2})\beta_{2}^{\tau}\underbrace{|\nabla\mathcal{L}({\bm{W}}_{t-\tau})[c,j]^{2}-\nabla\mathcal{L}({\bm{W}}_{t})[c,j]^{2}|}_{\diamond}.

For the \diamond term, we compute \nabla\mathcal{L}({\bm{W}})[c,j]^{2} and define the function f_{c,i,p}({\bm{W}}):=(\delta_{cy_{i}}-\mathbb{S}_{c}({\bm{W}}\bm{h}_{i}))(\delta_{cy_{p}}-\mathbb{S}_{c}({\bm{W}}\bm{h}_{p})) to obtain (after some algebra):

\diamond\leq\frac{B^{2}}{n^{2}}\sum\nolimits_{i\in[n]}\sum\nolimits_{p\in[n]}|f_{c,i,p}({\bm{W}}_{t-\tau})-f_{c,i,p}({\bm{W}}_{t})|.

Unlike for the first moment \bm{M}_{t}, this summation involves pairs of sample indices i and p, with the label c equal to one of four cases: y_{i}, non-y_{i}, y_{p}, non-y_{p}. Following our first-moment analysis strategy, we decompose the sum into these four components. The resulting bounds involve the squared terms {\mathcal{G}}_{c}({\bm{W}})^{2}, \mathcal{Q}_{c}({\bm{W}})^{2}, and {\mathcal{G}}_{c}({\bm{W}})\mathcal{Q}_{c}({\bm{W}}) due to the quadratic softmax expressions in each component (see Lemma D.10 for bounding ratios of softmax quadratics).

With Lemmas 6.1 and 6.2 at hand, we can now bound the Adam-specific term \clubsuit in Eq. (10) in terms of {\mathcal{G}}({\bm{W}}). The essence is to perform the double sum (over c\in[k] and j\in[d]) in \clubsuit in two stages: first sum over d for each c\in[k] with the help of Lemmas 6.1 and 6.2; then sum over c with the recognition that {\mathcal{G}}({\bm{W}})=\sum_{c\in[k]}{\mathcal{G}}_{c}({\bm{W}})=\sum_{c\in[k]}\mathcal{Q}_{c}({\bm{W}}). The results are summarized in Lemma D.4. From this point on, following similar steps as for SignGD, we can show that the CE loss (eventually) monotonically decreases (see Lemma D.5).

Adam Implicit Bias

To establish Adam’s implicit bias, we follow the same strategy as SignGD: bound the margin and control weight growth. For the margin bound, we leverage our descent property from above (details in Lemma D.6). Unlike SignGD’s simple updates, Adam’s moment averaging makes bounding \left\|{\bm{W}}_{t}\right\|_{\max} more challenging. Our solution again relies on the proxy {\mathcal{G}}({\bm{W}}), which we have shown approximates both the loss and gradients well. Specifically, using this, we prove that the second moment {\bm{V}}_{t} remains controlled when the loss is small, i.e., \forall c\in[k],j\in[d]:

{\bm{V}}_{t}[c,j]\leq\nabla\mathcal{L}({\bm{W}}_{t})[c,j]^{2}+\alpha_{V}\eta_{t}{\mathcal{G}}({\bm{W}}_{t})^{2}
\leq 4B^{2}{\mathcal{G}}({\bm{W}}_{t})^{2}+\alpha_{V}\eta_{0}{\mathcal{G}}({\bm{W}}_{t})^{2}
\leq(4B^{2}+\alpha_{V}\eta_{0})\mathcal{L}({\bm{W}}_{t})^{2},

where the penultimate and last inequalities are by Lemma 4.3 (i) and (ii), respectively. This suggests that with sufficiently small loss, all entries of {\bm{V}}_{t} remain bounded by 1. Building on our bound for {\bm{V}}_{t}, we apply Zhang et al. (2024, Lemma A.4) (see also Xie & Li (2024, Lemma 4.2)), which connects {\bm{W}}_{t} and \log({\bm{V}}_{t}) entry-wise. This allows us to bound \left\|{\bm{W}}_{t}\right\|_{\max} using learning rate sums under Ass. 3.2 (details in Lemma D.7). With the convergence \frac{{\mathcal{G}}({\bm{W}})}{\mathcal{L}({\bm{W}})}\rightarrow 1 following as in SignGD, we arrive at our main theorem.

Theorem 6.3.

Suppose that Ass. 3.1-3.5 hold, and \beta_{1}\leq\beta_{2}. There exists t_{a_{2}}=t_{a_{2}}(n,d,\gamma,B,{\bm{W}}_{0},\beta_{1},\beta_{2},\omega) such that the margin gap \gamma-{\min_{i\in[n],c\neq y_{i}}(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}{\bm{W}}_{t}\bm{h}_{i}}/{\left\|{\bm{W}}_{t}\right\|_{\max}} of Adam’s iterates for all t>t_{a_{2}} is upper bounded by

\mathcal{O}\Big{(}\frac{\sum_{s=t_{a_{2}}}^{t-1}\eta_{s}e^{-\frac{\gamma}{4}\sum_{\tau=t_{a_{2}}}^{s-1}\eta_{\tau}}+\sum_{s=0}^{t_{a_{2}}-1}\eta_{s}+d\sum_{s=t_{a_{2}}}^{t-1}\eta_{s}^{3/2}}{\sum_{s=0}^{t-1}\eta_{s}}\Big{)}.

The referenced iterations are t_{a_{1}}=\Theta(d^{2/a}) and t_{a_{2}}\leq\Theta(n^{\frac{1}{1-a}}d^{2/a}+n^{\frac{1}{1-a}}\mathcal{L}({\bm{W}}_{0})^{\frac{1}{1-a}}+\log(1/\omega)). Note that, compared to the respective iterations for SignGD, t_{a_{1}} depends on d, unlike t_{s_{1}}. These values yield the explicit rates below.

Corollary 6.4.

Set the learning rate \eta_{t}=\Theta(\frac{1}{t^{a}}), a\in(0,1]. Under the setting of Theorem 6.3, the margin gap of the Adam iterates with a=2/3 reduces at rate \mathcal{O}(\frac{d\log(t)+nd+[\log(1/\omega)]^{1/3}}{t^{1/3}}). (The rates for other values of a can be found in Corollary D.9.)

Remark 6.5.

These rates exactly match those in the binary case of (Zhang et al., 2024), with logarithmic dependence on the initialization parameter \omega (Ass. 3.2). This is only made possible through the fine-grained per-class bounding of the first and second moments using both {\mathcal{G}}_{c}({\bm{W}}) and \mathcal{Q}_{c}({\bm{W}}). Note that Lemma D.4 takes the same form as Zhang et al. (2024, Lemma A.4). However, without the tight per-class bound and the equivalent decomposition of {\mathcal{G}}({\bm{W}}) using either \mathcal{Q}_{c}({\bm{W}}) or {\mathcal{G}}_{c}({\bm{W}}), an extra factor of k would appear. Interestingly, our rates for SignGD in Corollary 5.4 reveal a theoretical gap: Adam’s optimal choice a=\frac{2}{3} yields \mathcal{O}(\frac{d\log(t)+nd}{t^{1/3}}), while SignGD achieves \mathcal{O}(\frac{\log(t)+n}{t^{1/2}}) with a=\frac{1}{2}. Despite achieving tightness w.r.t. the class dimension (k), this gap emerges from our entry-wise analysis of the \clubsuit term in (10) across the feature dimension (d) using the scalar functions {\mathcal{G}}_{c}({\bm{W}}),\mathcal{Q}_{c}({\bm{W}}). Closing this theoretical gap, revealed through our SignGD analysis and also present in the binary case (Zhang et al., 2024), forms an important direction for future work.

Experiments

We generate synthetic multiclass separable data as follows: k=5 class centers are sampled from a standard normal distribution; within each class, data points are sampled from a normal distribution \mathcal{N}(0,\sigma^{2}I), \sigma=0.1, around the class center. We set d=25, sample 5 data points for each class, and ensure that the margin is positive (thus the data is separable). We run different algorithms to minimize the CE loss using \eta_{t}=\frac{\eta_{0}}{t^{a}} (\eta_{0}=0.01), where (based on our theorems) a is set to 1/2, 1/2, and 2/3 for NGD, SignGD, and Adam, respectively. The stability constant \epsilon for Adam is set from \{0,10^{-6},10^{-7},10^{-8}\}. We denote the max-margin classifiers defined w.r.t. the 2-norm and the max-norm as {\bm{V}}_{2} and {\bm{V}}_{\infty} respectively. Fig. LABEL:fig:main_fig1 shows the following: (1) SignGD/Adam iterates favor the max-norm margin over the 2-norm margin. The opposite is true for NGD (Figs. LABEL:fig:l2_margin, LABEL:fig:max_margin); (2) SignGD/Adam iterates correlate well with {\bm{V}}_{\infty} whereas NGD correlates better with {\bm{V}}_{2} (Figs. LABEL:fig:cor_inf, LABEL:fig:cor_v2); (3) Results are consistent for different (small) values of the stability constant (curves nearly overlap).
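A minimal end-to-end sketch of this experiment (our own reimplementation under the stated setup; the data generation around the class centers, the margin-tracking helper, and the reported normalized margins are our assumptions) is:

```python
import numpy as np

rng = np.random.default_rng(0)
k, d, per_class, sigma, eta0 = 5, 25, 5, 0.1, 0.01
centers = rng.normal(size=(k, d))
y = np.repeat(np.arange(k), per_class)
H = centers[y] + sigma * rng.normal(size=(len(y), d))

def softmax(a):
    z = np.exp(a - np.max(a))
    return z / z.sum()

def ce_grad(W):
    # grad L(W) = (1/n) sum_i (S(W h_i) - e_{y_i}) h_i^T
    G = np.zeros_like(W)
    for h, c in zip(H, y):
        p = softmax(W @ h)
        p[c] -= 1.0
        G += np.outer(p, h) / len(y)
    return G

def norm_margin(W, ord):
    m = min((W[c] - W[cc]) @ h for h, c in zip(H, y) for cc in range(k) if cc != c)
    return m / np.linalg.norm(W.ravel(), ord)

def run(method, a, steps=5000):
    W = np.zeros((k, d)); M = np.zeros_like(W); V = np.zeros_like(W)
    for t in range(1, steps + 1):
        g, lr = ce_grad(W), eta0 / t ** a
        if method == "signgd":
            W -= lr * np.sign(g)
        elif method == "ngd":
            W -= lr * g / np.linalg.norm(g)
        else:  # Adam without the stability constant, Eq. (3a)-(3c)
            M = 0.9 * M + 0.1 * g
            V = 0.999 * V + 0.001 * g ** 2
            W -= lr * M / np.sqrt(V)
    return W

for method, a in [("ngd", 0.5), ("signgd", 0.5), ("adam", 2 / 3)]:
    W = run(method, a)
    print(method, "max-norm margin:", round(norm_margin(W, np.inf), 4),
          " 2-norm margin:", round(norm_margin(W, 2), 4))
```

In line with the observations above, SignGD/Adam are expected to attain a larger normalized max-norm margin than NGD, and vice versa for the 2-norm margin.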

7 Conclusion

We have characterized the implicit bias of SignGD, NSD, and Adam for multiclass separable data with CE loss, providing explicit rates for their margin maximization behavior. While these results establish fundamental theoretical guarantees, they are limited to linear models and separable data. Yet, they open several promising directions for future research: (1) Improving SignGD’s rates through appropriate momentum, leveraging ideas from Ji et al. (2021); Wang et al. (2023) for NGD. (2) Extending the results to non-separable data (Ji & Telgarsky, 2019) and soft-label classification (Thrampoulidis, 2024) settings. (3) Extending the analysis to non-linear architectures like self-attention (Tarzanagh et al., 2023b, a; Vasudeva et al., 2024; Julistiono et al., 2024). Our treatment of CE loss through softmax properties in Sec. 4 is particularly relevant, as the implicit bias in self-attention is fundamentally driven by the softmax map. Moreover, self-attention naturally aligns with our multiclass setting since each token attends to multiple other tokens through softmax. (4) Investigating whether the different implicit biases of SignGD/Adam relative to GD could theoretically explain empirical observations about their superior performance on heavy-tailed, imbalanced multiclass data (Kunstner et al., 2024). (5) From a statistical perspective, identifying specific scenarios where max-norm margin maximization leads to better generalization (building on initial investigations from (Salehi et al., 2019; Varma & Hassibi, 2024; Akhtiamov et al., 2024) on the impact of different regularization forms in linear classification, and recent findings from (Mohamadi et al., 2024) specifically related to infinity-norm regularization in neural networks).

Acknowledgement

This work is funded partially by NSERC Discovery Grants RGPIN-2021-03677 and RGPIN-2022-03669, Alliance Grant ALLRP 581098-22, a CIFAR AI Catalyst grant, and the Canada CIFAR AI Chair Program.

References

  • Akhtiamov et al. (2024) Akhtiamov, D., Ghane, R., and Hassibi, B. Regularized linear regression for binary classification. In 2024 IEEE International Symposium on Information Theory (ISIT), pp.  202–207. IEEE, 2024.
  • Belkin et al. (2019) Belkin, M., Hsu, D., Ma, S., and Mandal, S. Reconciling modern machine-learning practice and the classical bias–variance trade-off. Proceedings of the National Academy of Sciences, 116(32):15849–15854, 2019.
  • Bottou et al. (2018) Bottou, L., Curtis, F. E., and Nocedal, J. Optimization methods for large-scale machine learning. SIAM review, 60(2):223–311, 2018.
  • Boyd & Vandenberghe (2004) Boyd, S. and Vandenberghe, L. Convex optimization. Cambridge university press, 2004.
  • Duchi et al. (2011) Duchi, J., Hazan, E., and Singer, Y. Adaptive subgradient methods for online learning and stochastic optimization. Journal of machine learning research, 12(7), 2011.
  • Gunasekar et al. (2017) Gunasekar, S., Woodworth, B. E., Bhojanapalli, S., Neyshabur, B., and Srebro, N. Implicit regularization in matrix factorization. In Advances in Neural Information Processing Systems, pp.  6151–6159, 2017.
  • Gunasekar et al. (2018) Gunasekar, S., Lee, J., Soudry, D., and Srebro, N. Characterizing implicit bias in terms of optimization geometry. In International Conference on Machine Learning, pp.  1832–1841. PMLR, 2018.
  • Huang et al. (2021) Huang, F., Li, J., and Huang, H. Super-adam: faster and universal framework of adaptive gradients. Advances in Neural Information Processing Systems, 34:9074–9085, 2021.
  • Ji & Telgarsky (2019) Ji, Z. and Telgarsky, M. The implicit bias of gradient descent on nonseparable data. In Conference on learning theory, pp.  1772–1798. PMLR, 2019.
  • Ji & Telgarsky (2020) Ji, Z. and Telgarsky, M. Directional convergence and alignment in deep learning. Advances in Neural Information Processing Systems, 33:17176–17186, 2020.
  • Ji & Telgarsky (2021) Ji, Z. and Telgarsky, M. Characterizing the implicit bias via a primal-dual analysis. In Algorithmic Learning Theory, pp.  772–804. PMLR, 2021.
  • Ji et al. (2020) Ji, Z., Dudík, M., Schapire, R. E., and Telgarsky, M. Gradient descent follows the regularization path for general losses. In Conference on Learning Theory, pp.  2109–2136. PMLR, 2020.
  • Ji et al. (2021) Ji, Z., Srebro, N., and Telgarsky, M. Fast margin maximization via dual acceleration. In International Conference on Machine Learning, pp.  4860–4869. PMLR, 2021.
  • Julistiono et al. (2024) Julistiono, A. A. K., Tarzanagh, D. A., and Azizan, N. Optimizing attention with mirror descent: Generalized max-margin token selection. arXiv preprint arXiv:2410.14581, 2024.
  • Kingma & Ba (2014) Kingma, D. P. and Ba, J. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.
  • Kunstner et al. (2024) Kunstner, F., Yadav, R., Milligan, A., Schmidt, M., and Bietti, A. Heavy-tailed class imbalance and why adam outperforms gradient descent on language models. arXiv preprint arXiv:2402.19449, 2024.
  • Lyu & Li (2019) Lyu, K. and Li, J. Gradient descent maximizes the margin of homogeneous neural networks. arXiv preprint arXiv:1906.05890, 2019.
  • Mohamadi et al. (2024) Mohamadi, M. A., Li, Z., Wu, L., and Sutherland, D. J. Why do you grok? a theoretical analysis of grokking modular addition. arXiv preprint arXiv:2407.12332, 2024.
  • Mukherjee & Schapire (2010) Mukherjee, I. and Schapire, R. E. A theory of multiclass boosting. Advances in Neural Information Processing Systems, 23, 2010.
  • Nacson et al. (2019) Nacson, M. S., Lee, J., Gunasekar, S., Savarese, P. H. P., Srebro, N., and Soudry, D. Convergence of gradient descent on separable data. In The 22nd International Conference on Artificial Intelligence and Statistics, pp.  3420–3428. PMLR, 2019.
  • Nutini et al. (2015) Nutini, J., Schmidt, M., Laradji, I., Friedlander, M., and Koepke, H. Coordinate descent converges faster with the gauss-southwell rule than random selection. In International Conference on Machine Learning, pp.  1632–1641. PMLR, 2015.
  • Paszke et al. (2019) Paszke, A., Gross, S., Massa, F., Lerer, A., Bradbury, J., Chanan, G., Killeen, T., Lin, Z., Gimelshein, N., Antiga, L., et al. Pytorch: An imperative style, high-performance deep learning library. Advances in Neural Information Processing Systems, 32, 2019.
  • Qian & Qian (2019) Qian, Q. and Qian, X. The implicit bias of adagrad on separable data. Advances in Neural Information Processing Systems, 32, 2019.
  • Ravi et al. (2024) Ravi, H., Scott, C., Soudry, D., and Wang, Y. The implicit bias of gradient descent on separable multiclass data. arXiv preprint arXiv:2411.01350, 2024.
  • Rosset et al. (2003) Rosset, S., Zhu, J., and Hastie, T. J. Margin maximizing loss functions. In NIPS, 2003.
  • Salehi et al. (2019) Salehi, F., Abbasi, E., and Hassibi, B. The impact of regularization on high-dimensional logistic regression. arXiv preprint arXiv:1906.03761, 2019.
  • Soudry et al. (2018) Soudry, D., Hoffer, E., Nacson, M. S., Gunasekar, S., and Srebro, N. The implicit bias of gradient descent on separable data. Journal of Machine Learning Research, 19(70):1–57, 2018.
  • Sun et al. (2022) Sun, H., Ahn, K., Thrampoulidis, C., and Azizan, N. Mirror descent maximizes generalized margin and can be implemented efficiently. Advances in Neural Information Processing Systems, 35:31089–31101, 2022.
  • Sun et al. (2023) Sun, H., Gatmiry, K., Ahn, K., and Azizan, N. A unified approach to controlling implicit regularization via mirror descent. Journal of Machine Learning Research, 24(393):1–58, 2023.
  • Tarzanagh et al. (2023a) Tarzanagh, D. A., Li, Y., Thrampoulidis, C., and Oymak, S. Transformers as support vector machines, 2023a.
  • Tarzanagh et al. (2023b) Tarzanagh, D. A., Li, Y., Zhang, X., and Oymak, S. Max-margin token selection in attention mechanism, 2023b.
  • Telgarsky (2013) Telgarsky, M. Margins, shrinkage, and boosting. In International Conference on Machine Learning, pp.  307–315. PMLR, 2013.
  • Thrampoulidis (2024) Thrampoulidis, C. Implicit optimization bias of next-token prediction in linear models. In The Thirty-eighth Annual Conference on Neural Information Processing Systems, 2024.
  • Tsilivis et al. (2024) Tsilivis, N., Vardi, G., and Kempe, J. Flavors of margin: Implicit bias of steepest descent in homogeneous neural networks. arXiv preprint arXiv:2410.22069, 2024.
  • Vardi (2023) Vardi, G. On the implicit bias in deep-learning algorithms. Communications of the ACM, 66(6):86–93, 2023.
  • Varma & Hassibi (2024) Varma, K. N. and Hassibi, B. Benefits of stochastic mirror descent in high-dimensional binary classification. In 2024 IEEE International Symposium on Information Theory (ISIT), pp.  196–201. IEEE, 2024.
  • Vasudeva et al. (2024) Vasudeva, B., Deora, P., and Thrampoulidis, C. Implicit bias and fast convergence rates for self-attention. arXiv preprint arXiv:2402.05738, 2024.
  • Wang et al. (2021a) Wang, B., Meng, Q., Chen, W., and Liu, T.-Y. The implicit bias for adaptive optimization algorithms on homogeneous neural networks. In International Conference on Machine Learning, pp.  10849–10858. PMLR, 2021a.
  • Wang et al. (2022) Wang, B., Meng, Q., Zhang, H., Sun, R., Chen, W., Ma, Z.-M., and Liu, T.-Y. Does momentum change the implicit regularization on separable data? Advances in Neural Information Processing Systems, 35:26764–26776, 2022.
  • Wang et al. (2023) Wang, G., Hu, Z., Muthukumar, V., and Abernethy, J. D. Faster margin maximization rates for generic optimization methods. Advances in Neural Information Processing Systems, 36:62488–62518, 2023.
  • Wang et al. (2021b) Wang, N., Qin, Z., Yan, L., Zhuang, H., Wang, X., Bendersky, M., and Najork, M. Rank4class: a ranking formulation for multiclass classification. arXiv preprint arXiv:2112.09727, 2021b.
  • Wang & Scott (2024) Wang, Y. and Scott, C. Unified binary and multiclass margin-based classification. Journal of Machine Learning Research, 25(143):1–51, 2024.
  • Wu et al. (2024a) Wu, J., Bartlett, P. L., Telgarsky, M., and Yu, B. Large stepsize gradient descent for logistic loss: Non-monotonicity of the loss improves optimization efficiency. arXiv preprint arXiv:2402.15926, 2024a.
  • Wu et al. (2024b) Wu, J., Braverman, V., and Lee, J. D. Implicit bias of gradient descent for logistic regression at the edge of stability. Advances in Neural Information Processing Systems, 36, 2024b.
  • Xie & Li (2024) Xie, S. and Li, Z. Implicit bias of adamw: L-infinity norm constrained optimization. arXiv preprint arXiv:2404.04454, 2024.
  • Zhang et al. (2017) Zhang, C., Bengio, S., Hardt, M., Recht, B., and Vinyals, O. Understanding deep learning requires rethinking generalization, 2017.
  • Zhang et al. (2024) Zhang, C., Zou, D., and Cao, Y. The implicit bias of adam on separable data. arXiv preprint arXiv:2406.10650, 2024.

Appendix A Facts about CE loss and Softmax

Lemma A.1 is on the gradient of the cross-entropy loss. It will be used in Lemma B.1 for showing that 𝒢(𝑾){\mathcal{G}}({\bm{W}}) in (7) lower bounds (𝑾)sum{\left\|\nabla\mathcal{L}({\bm{W}})\right\|_{\rm{sum}}} up to the margin constant γ\gamma.

Lemma A.1 (Gradient).

Let CE loss

(𝑾):=1ni[n]log(𝕊yi(𝑾𝒉i)).\mathcal{L}({\bm{W}}):=-\frac{1}{n}\sum_{i\in[n]}\log\big{(}\mathbb{S}_{{y_{i}}}({\bm{W}}\bm{h}_{i})\big{)}.

For any 𝐖{\bm{W}}, it holds

  • (𝑾)=1ni[n](𝒆yi𝒔i)𝒉i=1n(𝒀𝑺)𝑯\nabla\mathcal{L}({\bm{W}})=-\frac{1}{n}\sum_{i\in[n]}\left(\bm{e}_{y_{i}}-\bm{s}_{i}\right)\bm{h}_{i}^{\top}=-\frac{1}{n}(\bm{Y}-\bm{S}){\bm{H}}^{\top}

  • 𝟙k(𝑾)=0\mathds{1}_{k}^{\top}\nabla\mathcal{L}({\bm{W}})=0

  • For any matrix 𝑨k×d{\bm{A}}\in\mathbb{R}^{k\times d},

    𝑨,(𝑾)\displaystyle\langle\bm{A},-\nabla\mathcal{L}({\bm{W}})\rangle =1ni(1siyi)(𝒆yi𝑨𝒉icyisic𝒆c𝑨𝒉i(1siyi))\displaystyle=\frac{1}{n}\sum_{i}\left(1-s_{i{y_{i}}}\right)\left(\bm{e}_{y_{i}}^{\top}\bm{A}\bm{h}_{i}-\frac{\sum_{c\neq{y_{i}}}s_{ic}\,\bm{e}_{c}^{\top}\bm{A}\bm{h}_{i}}{\left(1-s_{i{y_{i}}}\right)}\right)
    =1ni[n]cyisic(𝒆yi𝒆c)𝑨𝒉i\displaystyle=\frac{1}{n}\sum_{i\in[n]}{\sum_{c\neq{y_{i}}}s_{ic}\,(\bm{e}_{{y_{i}}}-\bm{e}_{c})^{\top}\bm{A}\bm{h}_{i}} (11)

where we simplify 𝐒:=𝕊(𝐖𝐇)=[𝐬1,,𝐬n]k×n\bm{S}:=\mathbb{S}({\bm{W}}{\bm{H}})=[\bm{s}_{1},\ldots,\bm{s}_{n}]\in\mathbb{R}^{k\times n}. The last statement yields

𝑨,(𝑾)1ni[n](1siyi)mincyi(𝒆yi𝒆c)𝑨𝒉i.\displaystyle\langle\bm{A},-\nabla\mathcal{L}({\bm{W}})\rangle\geq\frac{1}{n}\sum_{i\in[n]}\left(1-s_{i{y_{i}}}\right)\,\cdot\,\min_{c\neq{y_{i}}}\left(\bm{e}_{y_{i}}-\bm{e}_{c}\right)^{\top}\bm{A}\bm{h}_{i}. (12)
Proof.

The first bullet is by direct calculation. The second bullet uses the fact that 𝟙(𝒚i𝒔i)=11=0\mathds{1}^{\top}(\bm{y}_{i}-\bm{s}_{i})=1-1=0 since 𝟙𝒔i=1\mathds{1}^{\top}\bm{s}_{i}=1. The third bullet follows by direct calculation and writing 𝒔i𝑨𝒉i=(csic𝒆c)𝑨𝒉i=csic𝒆c𝑨𝒉i\bm{s}_{i}^{\top}\bm{A}\bm{h}_{i}=(\sum_{c}s_{ic}\bm{e}_{c})^{\top}\bm{A}\bm{h}_{i}=\sum_{c}s_{ic}\,\bm{e}_{c}^{\top}\bm{A}\bm{h}_{i}. ∎
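
As a quick sanity check of Lemma A.1 (not part of the paper's development), the closed-form gradient can be compared against finite differences on random data. The Python sketch below uses made-up dimensions and our own hypothetical helpers (softmax, ce_loss); it is illustrative only.

import numpy as np

rng = np.random.default_rng(0)
n, d, k = 5, 4, 3
H = rng.normal(size=(d, n))            # columns are the embeddings h_i
y = rng.integers(0, k, size=n)         # labels y_i
W = rng.normal(size=(k, d))

def softmax(Z):
    Z = Z - Z.max(axis=0, keepdims=True)
    E = np.exp(Z)
    return E / E.sum(axis=0, keepdims=True)

def ce_loss(W):
    S = softmax(W @ H)                 # k x n matrix with columns s_i
    return -np.mean(np.log(S[y, np.arange(n)]))

# Closed form from Lemma A.1: grad = -(1/n) (Y - S) H^T
S = softmax(W @ H)
Y = np.zeros((k, n)); Y[y, np.arange(n)] = 1.0
grad_formula = -(Y - S) @ H.T / n

# Central finite differences for comparison
eps = 1e-6
grad_fd = np.zeros_like(W)
for c in range(k):
    for j in range(d):
        P = np.zeros_like(W); P[c, j] = eps
        grad_fd[c, j] = (ce_loss(W + P) - ce_loss(W - P)) / (2 * eps)

print(np.abs(grad_formula - grad_fd).max())    # close to zero (finite-difference error only)
print(np.abs(grad_formula.sum(axis=0)).max())  # ~0: columns sum to zero (second bullet)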

Lemma A.2 is on the Taylor expansion of the loss. It will be used in showing the descent properties of SignGD (Lemma C.1) and Adam (Lemma D.5).

Lemma A.2 (Hessian).

Let perturbation 𝚫k×d\bm{\Delta}\in\mathbb{R}^{k\times d} and denote 𝐖=𝐖+𝚫{\bm{W}}^{\prime}={\bm{W}}+\bm{\Delta}. Then,

(𝑾)\displaystyle\mathcal{L}({\bm{W}}^{\prime}) =(𝑾)1ni[n](𝒆yi𝕊(𝑾𝒉i))𝒉i,𝚫\displaystyle=\mathcal{L}({\bm{W}})-\frac{1}{n}\sum_{i\in[n]}\langle(\bm{e}_{y_{i}}-\mathbb{S}({\bm{W}}\bm{h}_{i}))\bm{h}_{i}^{\top},\bm{\Delta}\rangle
+12ni[n]𝒉i𝚫(diag(𝕊(𝑾𝒉i))𝕊(𝑾𝒉i)𝕊(𝑾𝒉i))𝚫𝒉i+o(𝚫3).\displaystyle\quad+\frac{1}{2n}\sum_{i\in[n]}\bm{h}_{i}^{\top}\bm{\Delta}^{\top}\left(\operatorname{diag}(\mathbb{S}({\bm{W}}\bm{h}_{i}))-\mathbb{S}({\bm{W}}\bm{h}_{i})\mathbb{S}({\bm{W}}\bm{h}_{i})^{\top}\right)\bm{\Delta}\,\bm{h}_{i}+o(\|\bm{\Delta}\|^{3})\,. (13)
Proof.

Define function y:k\ell_{y}:\mathbb{R}^{k}\rightarrow\mathbb{R} parameterized by y[k]y\in[k] as follows:

y(𝒍):=log(𝕊y(𝒍)).\ell_{y}(\bm{l}):=-\log(\mathbb{S}_{y}(\bm{l}))\,.

From Lemma A.1,

y(𝒍)=(𝒆y𝕊(𝒍)).\nabla\ell_{y}(\bm{l})=-(\bm{e}_{y}-\mathbb{S}(\bm{l}))\,.

Thus,

2y(𝒍)=𝕊(𝒍)=diag(𝕊(𝒍))𝕊(𝒍)𝕊(𝒍)\nabla^{2}\ell_{y}(\bm{l})=\nabla\mathbb{S}(\bm{l})=\operatorname{diag}(\mathbb{S}(\bm{l}))-\mathbb{S}(\bm{l})\mathbb{S}(\bm{l})^{\top}

Combining these, the second-order Taylor expansion of y\ell_{y} reads as follows for any 𝒍,𝜹k\bm{l},\bm{\delta}\in\mathbb{R}^{k}:

y(𝒍+𝜹)=y(𝒍)(𝒆y𝕊(𝒍))𝜹+12𝜹(diag(𝕊(𝒍))𝕊(𝒍)𝕊(𝒍))𝜹+o(𝜹3).\displaystyle\ell_{y}(\bm{l}+\bm{\delta})=\ell_{y}(\bm{l})-(\bm{e}_{y}-\mathbb{S}(\bm{l}))^{\top}\bm{\delta}+\frac{1}{2}\bm{\delta}^{\top}\left(\operatorname{diag}(\mathbb{S}(\bm{l}))-\mathbb{S}(\bm{l})\mathbb{S}(\bm{l})^{\top}\right)\bm{\delta}+o(\|\bm{\delta}\|^{3})\,.

To evaluate this with respect to a change on the classifier parameters, set 𝒍=𝑾𝒉\bm{l}={\bm{W}}\bm{h} and 𝜹=𝚫𝒉\bm{\delta}=\bm{\Delta}\bm{h} for 𝚫k×d\bm{\Delta}\in\mathbb{R}^{k\times d}. Denoting 𝑾=𝑾+𝚫{\bm{W}}^{\prime}={\bm{W}}+\bm{\Delta}, we then have

y(𝑾)=y(𝑾)(𝒆y𝕊(𝒍))𝒉,𝚫+12𝒉𝚫(diag(𝕊(𝒍))𝕊(𝒍)𝕊(𝒍))𝚫𝒉+o(𝚫3).\displaystyle\ell_{y}({\bm{W}}^{\prime})=\ell_{y}({\bm{W}})-\langle(\bm{e}_{y}-\mathbb{S}(\bm{l}))\bm{h}^{\top},\bm{\Delta}\rangle+\frac{1}{2}\bm{h}^{\top}\bm{\Delta}^{\top}\left(\operatorname{diag}(\mathbb{S}(\bm{l}))-\mathbb{S}(\bm{l})\mathbb{S}(\bm{l})^{\top}\right)\bm{\Delta}\bm{h}+o(\|\bm{\Delta}\|^{3})\,.

This shows the desired result since n(𝑾):=i[n]yi(𝑾𝒉i)n\mathcal{L}({\bm{W}}):=\sum_{i\in[n]}\ell_{y_{i}}({\bm{W}}\bm{h}_{i}). Moreover, using the exact (Lagrange) form of the second-order remainder, we can further obtain

y(𝑾)=y(𝑾)(𝒆y𝕊(𝒍))𝒉,𝚫+12𝒉𝚫(diag(𝕊(𝒍))𝕊(𝒍)𝕊(𝒍))𝚫𝒉,\displaystyle\ell_{y}({\bm{W}}^{\prime})=\ell_{y}({\bm{W}})-\langle(\bm{e}_{y}-\mathbb{S}(\bm{l}))\bm{h}^{\top},\bm{\Delta}\rangle+\frac{1}{2}\bm{h}^{\top}\bm{\Delta}^{\top}\left(\operatorname{diag}(\mathbb{S}(\bm{l}^{\prime}))-\mathbb{S}(\bm{l}^{\prime})\mathbb{S}(\bm{l}^{\prime})^{\top}\right)\bm{\Delta}\bm{h}, (14)

where 𝒍=𝒍+ζ𝜹\bm{l}^{\prime}=\bm{l}+\zeta\bm{\delta} for some ζ[0,1]\zeta\in[0,1]. ∎
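
Similarly, the softmax Hessian diag(𝕊(𝒍))−𝕊(𝒍)𝕊(𝒍)^⊤ used in Lemma A.2 can be checked against a finite-difference approximation. The Python sketch below is purely illustrative and assumes nothing beyond the definitions above.

import numpy as np

rng = np.random.default_rng(1)
k, y = 4, 1
l = rng.normal(size=k)

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def ell(z):                      # ell_y(l) = -log S_y(l)
    return -np.log(softmax(z)[y])

s = softmax(l)
hess_formula = np.diag(s) - np.outer(s, s)

eps = 1e-5
hess_fd = np.zeros((k, k))
for a in range(k):
    for b in range(k):
        ea = np.zeros(k); ea[a] = eps
        eb = np.zeros(k); eb[b] = eps
        hess_fd[a, b] = (ell(l + ea + eb) - ell(l + ea - eb)
                         - ell(l - ea + eb) + ell(l - ea - eb)) / (4 * eps ** 2)

print(np.abs(hess_formula - hess_fd).max())   # small (~1e-6), i.e. finite-difference error only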

Lemma A.3 is used in bounding the second order term in the Taylor expansion of (𝑾)\mathcal{L}({\bm{W}}).

Lemma A.3.

For any 𝐬Δk1\bm{s}\in\Delta^{k-1} in the kk-dimensional simplex, any index c[k]c\in[k], and any 𝐯k{\bm{v}}\in\mathbb{R}^{k} it holds:

𝒗(diag(𝒔)𝒔𝒔)𝒗4𝒗2(1sc){\bm{v}}^{\top}\left(\operatorname{diag}(\bm{s})-\bm{s}\bm{s}^{\top}\right){\bm{v}}\leq 4\,\|{\bm{v}}\|_{\infty}^{2}\,(1-s_{c})
Proof.

By Hölder's inequality (the ℓ1–ℓ∞ duality),

𝒗(diag(𝒔)𝒔𝒔)𝒗\displaystyle{\bm{v}}^{\top}\left(\operatorname{diag}(\bm{s})-\bm{s}\bm{s}^{\top}\right){\bm{v}} =vec(diag(𝒔)𝒔𝒔)vec(𝒗𝒗)\displaystyle=\operatorname{vec}\big{(}\operatorname{diag}(\bm{s})-\bm{s}\bm{s}^{\top}\big{)}^{\top}\operatorname{vec}\big{(}{\bm{v}}{\bm{v}}^{\top}\big{)}
vec(diag(𝒔)𝒔𝒔)1vec(𝒗𝒗)\displaystyle\leq\|\operatorname{vec}\big{(}\operatorname{diag}(\bm{s})-\bm{s}\bm{s}^{\top}\big{)}\|_{1}\|\operatorname{vec}\big{(}{\bm{v}}{\bm{v}}^{\top}\big{)}\|_{\infty}
𝒗2vec(diag(𝒔)𝒔𝒔)1.\displaystyle\leq\|{\bm{v}}\|_{\infty}^{2}\,\|\operatorname{vec}\big{(}\operatorname{diag}(\bm{s})-\bm{s}\bm{s}^{\top}\big{)}\|_{1}\,.

But,

vec(diag(𝒔)𝒔𝒔)1\displaystyle\|\operatorname{vec}\big{(}\operatorname{diag}(\bm{s})-\bm{s}\bm{s}^{\top}\big{)}\|_{1} =c[k]sc(1sc)+cc[k]scsc\displaystyle=\sum_{c\in[k]}s_{c}(1-s_{c})+\sum_{c^{\prime}\neq c\in[k]}s_{c}s_{c^{\prime}}
=2c[k]sc(1sc)\displaystyle=2\sum_{c\in[k]}s_{c}(1-s_{c})

where the last line uses \sum_{c^{\prime}\neq c}s_{c^{\prime}}=1-s_{c}.

The proof is completed by applying Lemma A.4 to the above. ∎

Lemma A.4 is used in the proof of Lemma A.3.

Lemma A.4.

For any 𝐬Δk1\bm{s}\in\Delta^{k-1} in the kk-dimensional simplex and any index c[k]c\in[k] it holds that

csc(1sc)2(1sc).\sum_{c^{\prime}}s_{c^{\prime}}(1-s_{c^{\prime}})\leq 2(1-s_{c})\,.
Proof.

With a bit of algebra and using ccsc=1sc\sum_{c^{\prime}\neq c}s_{c^{\prime}}=1-s_{c} the claim becomes equivalent to

ccsc2+sc22sc+10.\sum_{c^{\prime}\neq c}s_{c^{\prime}}^{2}+s_{c}^{2}-2s_{c}+1\geq 0.

Since sc22sc+1=(1sc)2s_{c}^{2}-2s_{c}+1=(1-s_{c})^{2}, the left-hand side is a sum of squares and the inequality holds, which proves the lemma. ∎
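
The bounds of Lemmas A.3 and A.4 are elementary and can also be verified numerically on random points of the simplex. The Python sketch below is illustrative only.

import numpy as np

rng = np.random.default_rng(2)
k = 6
for _ in range(1000):
    s = rng.dirichlet(np.ones(k))                 # a point of the simplex
    v = rng.normal(size=k)
    c = rng.integers(0, k)
    quad = v @ (np.diag(s) - np.outer(s, s)) @ v
    assert quad <= 4 * np.max(np.abs(v)) ** 2 * (1 - s[c]) + 1e-12   # Lemma A.3
    assert np.sum(s * (1 - s)) <= 2 * (1 - s[c]) + 1e-12             # Lemma A.4
print("bounds hold on all sampled instances")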

Appendix B Lemmas on (𝑾)\mathcal{L}({\bm{W}}) and 𝒢(𝑾){\mathcal{G}}({\bm{W}})

Lemma B.1 shows that 𝒢(𝑾){\mathcal{G}}({\bm{W}}) lower bounds the loss gradient and that the bound becomes tight as the loss approaches zero.

Lemma B.1 (𝒢(𝑾){\mathcal{G}}({\bm{W}}) as proxy to the loss-gradient norm).

Recall the margin γ\gamma and the assumption 𝐡i1B\|\bm{h}_{i}\|_{1}\leq B. Then, for any 𝐖{\bm{W}} it holds that

2B𝒢(𝑾)(𝑾)sumγ𝒢(𝑾).2B\cdot{\mathcal{G}}({\bm{W}})\geq{\left\|\nabla\mathcal{L}({\bm{W}})\right\|_{\rm{sum}}}\geq\gamma\cdot{\mathcal{G}}({\bm{W}})\,.
Proof.

First, we prove the lower bound. By duality and direct application of (12)

(𝑾)sum\displaystyle{\left\|\nabla\mathcal{L}({\bm{W}})\right\|_{\rm{sum}}} =max𝑨max1𝑨,(𝑾)\displaystyle=\max_{{\left\|\bm{A}\right\|_{\max}}\leq 1}\langle\bm{A},-\nabla\mathcal{L}({\bm{W}})\rangle
max𝑨max11ni[n](1siyi)mincyi(𝒆yi𝒆c)T𝑨𝒉i\displaystyle\geq\max_{{\left\|\bm{A}\right\|_{\max}}\leq 1}\frac{1}{n}\sum_{i\in[n]}(1-s_{iy_{i}})\min_{c\neq y_{i}}(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}\bm{A}\bm{h}_{i}
1ni[n](1siyi)max𝑨max1mini[n],cyi(𝒆yi𝒆c)T𝑨𝒉i.\displaystyle\geq\frac{1}{n}\sum_{i\in[n]}(1-s_{iy_{i}})\cdot\max_{{\left\|\bm{A}\right\|_{\max}}\leq 1}\min_{i\in[n],c\neq y_{i}}(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}\bm{A}\bm{h}_{i}.

The max-min in the last display equals the margin γ\gamma by its definition, so the right-hand side equals γ𝒢(𝑾)\gamma\cdot{\mathcal{G}}({\bm{W}}), which proves the lower bound.

Second, for the upper bound, start with noting by triangle inequality that

(𝑾)sum1ni[n]i(𝑾)sum,{\left\|\nabla\mathcal{L}({\bm{W}})\right\|_{\rm{sum}}}\leq\frac{1}{n}\sum_{i\in[n]}{\left\|\nabla\ell_{i}({\bm{W}})\right\|_{\rm{sum}}}\,,

where i(𝑾)=log(𝕊yi(𝑾𝒉i))\ell_{i}({\bm{W}})=-\log(\mathbb{S}_{y_{i}}({\bm{W}}\bm{h}_{i})). Recall that

i(𝑾)=(𝒆yi𝕊(𝑾𝒉i))𝒉i,\nabla\ell_{i}({\bm{W}})=-(\bm{e}_{y_{i}}-\mathbb{S}({\bm{W}}\bm{h}_{i}))\bm{h}_{i}^{\top},

and, for two vectors 𝒗,𝒖{\bm{v}},\bm{u}: 𝒖𝒗sum=𝒖1𝒗1{\left\|\bm{u}{\bm{v}}^{\top}\right\|_{\rm{sum}}}=\|\bm{u}\|_{1}\|{\bm{v}}\|_{1}. Combining these and noting that

𝒆yi𝕊(𝑾𝒉i)1=2(1siyi)\|\bm{e}_{y_{i}}-\mathbb{S}({\bm{W}}\bm{h}_{i})\|_{1}=2(1-s_{iy_{i}})

together with the assumption 𝒉i1B\|\bm{h}_{i}\|_{1}\leq B yields the advertised upper bound. ∎
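
The upper bound of Lemma B.1 admits a direct numerical check (the lower bound involves the margin γ, whose computation we omit here). The Python sketch below uses made-up data and restates the softmax and gradient formulas from Lemma A.1; it is illustrative only.

import numpy as np

rng = np.random.default_rng(3)
n, d, k = 8, 5, 4
H = rng.normal(size=(d, n))
y = rng.integers(0, k, size=n)
W = rng.normal(size=(k, d))

def softmax(Z):
    Z = Z - Z.max(axis=0, keepdims=True)
    E = np.exp(Z)
    return E / E.sum(axis=0, keepdims=True)

S = softmax(W @ H)
Y = np.zeros((k, n)); Y[y, np.arange(n)] = 1.0
grad = -(Y - S) @ H.T / n                      # Lemma A.1

B = np.abs(H).sum(axis=0).max()                # B = max_i ||h_i||_1
G = np.mean(1.0 - S[y, np.arange(n)])          # G(W) = (1/n) sum_i (1 - s_{i y_i})
assert np.abs(grad).sum() <= 2 * B * G + 1e-12 # ||grad L(W)||_sum <= 2 B G(W)
print(np.abs(grad).sum(), 2 * B * G)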

Building on Lemma B.1, we obtain a simple bound on the difference between losses at two different points using the max-norm of their difference.

Lemma B.2.

For any 𝑾,𝑾0k×d{\bm{W}},{\bm{W}}_{0}\in\mathbb{R}^{k\times d}, using that (𝑾)\mathcal{L}({\bm{W}}) is convex, we have

|(𝑾)(𝑾0)|2B𝑾𝑾0max.\displaystyle|\mathcal{L}({\bm{W}})-\mathcal{L}({\bm{W}}_{0})|\leq 2B{\left\|{\bm{W}}-{\bm{W}}_{0}\right\|_{\max}}.
Proof.

By convexity of \mathcal{L}, we have

(𝑾0)(𝑾)(𝑾0),𝑾0𝑾(𝑾0)sum𝑾0𝑾max2B𝑾0𝑾max,\displaystyle\mathcal{L}({\bm{W}}_{0})-\mathcal{L}({\bm{W}})\leq\langle\nabla\mathcal{L}({\bm{W}}_{0}),{\bm{W}}_{0}-{\bm{W}}\rangle\leq{\left\|\nabla\mathcal{L}({\bm{W}}_{0})\right\|_{\rm{sum}}}{\left\|{\bm{W}}_{0}-{\bm{W}}\right\|_{\max}}\leq 2B{\left\|{\bm{W}}_{0}-{\bm{W}}\right\|_{\max}}\,,

where the second inequality is by Hölder's inequality (duality of the sum and max norms), and the last inequality is by Lemma B.1 together with 𝒢(𝑾0)1{\mathcal{G}}({\bm{W}}_{0})\leq 1. Similarly, we can also show that (𝑾)(𝑾0)2B𝑾0𝑾max\mathcal{L}({\bm{W}})-\mathcal{L}({\bm{W}}_{0})\leq 2B{\left\|{\bm{W}}_{0}-{\bm{W}}\right\|_{\max}}. ∎

Lemma B.3 shows the close relationship between 𝒢(𝑾){\mathcal{G}}({\bm{W}}) and (𝑾)\mathcal{L}({\bm{W}}). 𝒢(𝑾){\mathcal{G}}({\bm{W}}) not only lower bounds (𝑾)\mathcal{L}({\bm{W}}), but also upper bounds (𝑾)\mathcal{L}({\bm{W}}) up to a constant provided that the loss is sufficiently small. Moreover, the rate at which the ratio 𝒢(𝑾)(𝑾)\frac{{\mathcal{G}}({\bm{W}})}{\mathcal{L}({\bm{W}})} converges to 11 depends on the rate of decrease of the loss.

Lemma B.3 (𝒢(𝑾){\mathcal{G}}({\bm{W}}) as proxy to the loss).

Let 𝐖k×d{\bm{W}}\in\mathbb{R}^{k\times d}, we have

  1. (i)

    1𝒢(𝑾)(𝑾)1n(𝑾)21\geq\frac{{\mathcal{G}}({\bm{W}})}{\mathcal{L}({\bm{W}})}\geq 1-\frac{n\mathcal{L}({\bm{W}})}{2}

  2. (ii)

    Suppose that 𝑾{\bm{W}} satisfies (𝑾)log2n\mathcal{L}({\bm{W}})\leq\frac{\log 2}{n} or 𝒢(𝑾)12n{\mathcal{G}}({\bm{W}})\leq\frac{1}{2n}, then (𝑾)2𝒢(𝑾).\mathcal{L}({\bm{W}})\leq 2{\mathcal{G}}({\bm{W}}).

Proof.

(i) Denote for simplicity si:=siyi=𝕊yi(𝑾𝒉i)s_{i}:=s_{i{y_{i}}}=\mathbb{S}_{{y_{i}}}({\bm{W}}\bm{h}_{i}), thus (𝑾)=1ni[n]log(1/si)\mathcal{L}({\bm{W}})=\frac{1}{n}\sum_{i\in[n]}\log(1/s_{i}) and 𝒢(W)=1ni[n](1si){\mathcal{G}}(W)=\frac{1}{n}\sum_{i\in[n]}(1-s_{i}). For the upper bound, simply use the fact that ex1xe^{x-1}\geq x for all x(0,1]x\in(0,1], which gives log(1/si)1si\log(1/s_{i})\geq 1-s_{i} for all i[n]i\in[n].

The lower bound can be proved using the exact same arguments in the proof of Zhang et al. (2024, Lemma C.7) for the binary case. For completeness, we provide an alternative elementary proof. It suffices to prove for n=1n=1 that for s(0,1)s\in(0,1):

1slog(1/s)12log2(1/s).\displaystyle 1-s\geq\log(1/s)-\frac{1}{2}\log^{2}(1/s). (15)

The general case follows by summing over s=sis=s_{i} and using i[n]log2(1/si)(i[n]log(1/si))2\sum_{i\in[n]}\log^{2}(1/s_{i})\leq\left(\sum_{i\in[n]}\log(1/s_{i})\right)^{2} since log(1/si)>0\log(1/s_{i})>0. For (15), let x=log(1/s)>0x=\log(1/s)>0. The inequality becomes ex1x+x2/2e^{-x}\leq 1-x+x^{2}/2, which holds for x>0x>0 by the second-order Taylor expansion of exe^{-x} around 0.

(ii) The sufficiency of (𝑾)log2n\mathcal{L}({\bm{W}})\leq\frac{\log 2}{n} (to guarantee that (𝑾)2𝒢(𝑾)\mathcal{L}({\bm{W}})\leq 2{\mathcal{G}}({\bm{W}})) follows from (i) and (𝑾)log2n1n\mathcal{L}({\bm{W}})\leq\frac{\log 2}{n}\leq\frac{1}{n}. The inequality log(1x)2(1x)\log(\frac{1}{x})\leq 2(1-x) holds when x[0.2032,1]x\in[0.2032,1]. This translates to the following sufficient condition on siyis_{iy_{i}}

si=ei[yi]c[k]ei[c]=11+c[k],cyiei[c]i[yi]0.2032.\displaystyle s_{i}=\frac{e^{\bm{\ell}_{i}[y_{i}]}}{\sum_{c\in[k]}e^{\bm{\ell}_{i}[c]}}=\frac{1}{1+\sum_{c\in[k],c\neq y_{i}}e^{\bm{\ell}_{i}[c]-\bm{\ell}_{i}[y_{i}]}}\geq 0.2032.

Under the assumption 𝒢(𝑾)12n{\mathcal{G}}({\bm{W}})\leq\frac{1}{2n}, we have 1sii[n](1si)=n𝒢(𝑾)121-s_{i}\leq\sum_{i\in[n]}(1-s_{i})=n{\mathcal{G}}({\bm{W}})\leq\frac{1}{2}, from which we obtain si120.2032s_{i}\geq\frac{1}{2}\geq 0.2032 for all i[n]i\in[n]. ∎
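
The scalar inequalities behind Lemma B.3, namely (15) and the threshold log(1/x) ≤ 2(1−x) for x ∈ [0.2032, 1], admit a direct numerical check. The Python sketch below is illustrative only.

import numpy as np

s = np.linspace(1e-4, 1 - 1e-4, 100_000)
assert np.all(1 - s >= np.log(1 / s) - 0.5 * np.log(1 / s) ** 2)   # inequality (15)

x = np.linspace(0.2032, 1.0, 100_000)
assert np.all(np.log(1 / x) <= 2 * (1 - x) + 1e-9)                 # threshold used in part (ii)
print("both scalar inequalities hold on the sampled grids")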

Lemma B.4 shows that weights 𝑾{\bm{W}} of low loss separate the data. It is used in deriving the lower bound on the unnormalized margin.

Lemma B.4 (Low (𝑾)\mathcal{L}({\bm{W}}) implies separability).

Suppose that there exists 𝐖k×d{\bm{W}}\in\mathbb{R}^{k\times d} such that (𝐖)log2n\mathcal{L}({\bm{W}})\leq\frac{\log 2}{n}, then we have

(𝒆yi𝒆c)T𝑾𝒉i0,for all i[n] and for all c[k] such that cyi.\displaystyle(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}{\bm{W}}\bm{h}_{i}\geq 0,\quad\text{for all $i\in[n]$ and for all $c\in[k]$ such that $c\neq y_{i}$}. (16)
Proof.

We rewrite the loss into the form:

(𝑾)=1ni[n]log(ei[yi]c[k]ei[c])=1ni[n]log(1+cyie(i[yi]i[c])).\displaystyle\mathcal{L}({\bm{W}})=-\frac{1}{n}\sum_{i\in[n]}\log(\frac{e^{\bm{\ell}_{i}[y_{i}]}}{\sum_{c\in[k]}e^{\bm{\ell}_{i}[c]}})=\frac{1}{n}\sum_{i\in[n]}\log(1+\sum_{c\neq y_{i}}e^{-(\bm{\ell}_{i}[y_{i}]-\bm{\ell}_{i}[c])}).

Fix any i[n]i\in[n], by the assumption that (𝑾)log2n\mathcal{L}({\bm{W}})\leq\frac{\log 2}{n}, we have the following:

log(1+cyie(i[yi]i[c]))n(𝑾)log(2).\displaystyle\log(1+\sum_{c\neq y_{i}}e^{-(\bm{\ell}_{i}[y_{i}]-\bm{\ell}_{i}[c])})\leq n\mathcal{L}({\bm{W}})\leq\log(2).

This implies:

emincyi(i[yi]i[c])=maxcyie(i[yi]i[c])cyie(i[yi]i[c])1.\displaystyle e^{-\min_{c\neq y_{i}}(\bm{\ell}_{i}[y_{i}]-\bm{\ell}_{i}[c])}=\max_{c\neq y_{i}}e^{-(\bm{\ell}_{i}[y_{i}]-\bm{\ell}_{i}[c])}\leq\sum_{c\neq y_{i}}e^{-(\bm{\ell}_{i}[y_{i}]-\bm{\ell}_{i}[c])}\leq 1.

After taking log\log on both sides, we obtain the following: i[yi]i[c]=(𝒆yi𝒆c)T𝑾𝒉i0\bm{\ell}_{i}[y_{i}]-\bm{\ell}_{i}[c]=(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}{\bm{W}}\bm{h}_{i}\geq 0 for any c[k]c\in[k] such that cyic\neq y_{i}. ∎

Lemma B.5 shows that the ratio of 𝒢(𝑾){\mathcal{G}}({\bm{W}}) at two points can be bounded by the exponential of the max-norm of their difference. It is used in handling the second-order term in the Taylor expansion of the loss.

Lemma B.5 (Ratio of 𝒢(𝑾){\mathcal{G}}({\bm{W}})).

For any ψ[0,1]\psi\in[0,1], we have the following:

𝒢(𝑾+ψ𝑾)𝒢(𝑾)e2Bψ𝑾max\displaystyle\frac{{\mathcal{G}}({\bm{W}}+\psi\triangle{\bm{W}})}{{\mathcal{G}}({\bm{W}})}\leq e^{2B\psi{\left\|\triangle{\bm{W}}\right\|_{\max}}}
Proof.

By the definition of 𝒢(𝑾){\mathcal{G}}({\bm{W}}), we have:

𝒢(𝑾+ψ𝑾)𝒢(𝑾)=i[n](1𝕊yi((𝑾+ψ𝑾)𝒉i))i[n](1𝕊yi(𝑾𝒉i)).\displaystyle\frac{{\mathcal{G}}({\bm{W}}+\psi\triangle{\bm{W}})}{{\mathcal{G}}({\bm{W}})}=\frac{\sum_{i\in[n]}\bigl{(}1-\mathbb{S}_{y_{i}}(({\bm{W}}+\psi\triangle{\bm{W}})\bm{h}_{i})\bigr{)}}{\sum_{i\in[n]}\bigl{(}1-\mathbb{S}_{y_{i}}({\bm{W}}\bm{h}_{i})\bigr{)}}.

For any c[k]c\in[k] and 𝒗,𝒗k{\bm{v}},{\bm{v}}^{\prime}\in\mathbb{R}^{k}, we have:

1𝕊c(𝒗)1𝕊c(𝒗)\displaystyle\frac{1-\mathbb{S}_{c}({\bm{v}}^{\prime})}{1-\mathbb{S}_{c}({\bm{v}})} =1evci[k]evi1evci[k]evi\displaystyle=\frac{1-\frac{e^{v^{\prime}_{c}}}{\sum_{i\in[k]}e^{v^{\prime}_{i}}}}{1-\frac{e^{v_{c}}}{\sum_{i\in[k]}e^{v_{i}}}}
=j[k],jcevji[k]evij[k],jcevji[k]evi\displaystyle=\frac{\frac{\sum_{j\in[k],j\neq c}e^{v^{\prime}_{j}}}{\sum_{i\in[k]}e^{v^{\prime}_{i}}}}{\frac{\sum_{j\in[k],j\neq c}e^{v_{j}}}{\sum_{i\in[k]}e^{v_{i}}}}
=j[k],jci[k]evj+vij[k],jci[k]evj+vi\displaystyle=\frac{\sum_{j\in[k],j\neq c}\sum_{i\in[k]}e^{v^{\prime}_{j}+v_{i}}}{\sum_{j\in[k],j\neq c}\sum_{i\in[k]}e^{v_{j}+v^{\prime}_{i}}}
e2𝒗𝒗.\displaystyle\leq e^{2\lVert{\bm{v}}-{\bm{v}}^{\prime}\rVert_{\infty}}.

The last inequality is because evj+vievj+vie|vjvj|+|vivi|e2𝒗𝒗\frac{e^{v^{\prime}_{j}+v_{i}}}{e^{v_{j}+v^{\prime}_{i}}}\leq e^{|v^{\prime}_{j}-v_{j}|+|v_{i}-v^{\prime}_{i}|}\leq e^{2\lVert{\bm{v}}-{\bm{v}}^{\prime}\rVert_{\infty}}, which implies that j[k],jci[k]evj+vie2𝒗𝒗j[k],jci[k]evj+vi\sum_{j\in[k],j\neq c}\sum_{i\in[k]}e^{v^{\prime}_{j}+v_{i}}\leq e^{2\lVert{\bm{v}}-{\bm{v}}^{\prime}\rVert_{\infty}}\sum_{j\in[k],j\neq c}\sum_{i\in[k]}e^{v_{j}+v^{\prime}_{i}}. Next, we specialize this result to 𝒗=(𝑾+ψ𝑾)𝒉i{\bm{v}}^{\prime}=({\bm{W}}+\psi\triangle{\bm{W}})\bm{h}_{i}, 𝒗=𝑾𝒉i{\bm{v}}={\bm{W}}\bm{h}_{i}, and c=yic=y_{i} for any i[n]i\in[n] to obtain:

1𝕊yi((𝑾+ψ𝑾)𝒉i))1𝕊yi(𝑾𝒉i)e2ψ𝑾𝒉ie2Bψ𝑾max.\displaystyle\frac{1-\mathbb{S}_{y_{i}}(({\bm{W}}+\psi\triangle{\bm{W}})\bm{h}_{i})\bigr{)}}{1-\mathbb{S}_{y_{i}}({\bm{W}}\bm{h}_{i})}\leq e^{2\psi\lVert\triangle{\bm{W}}\bm{h}_{i}\rVert_{\infty}}\leq e^{2B\psi{\left\|\triangle{\bm{W}}\right\|_{\max}}}.

Then, we rearrange and sum over i[n]i\in[n] to obtain: i[n](1𝕊yi((𝑾+ψ𝑾)𝒉i))e2Bψ𝑾maxi[n](1𝕊yi(𝑾𝒉i))\sum_{i\in[n]}\bigl{(}1-\mathbb{S}_{y_{i}}(({\bm{W}}+\psi\triangle{\bm{W}})\bm{h}_{i})\bigr{)}\leq e^{2B\psi{\left\|\triangle{\bm{W}}\right\|_{\max}}}\sum_{i\in[n]}\bigl{(}1-\mathbb{S}_{y_{i}}({\bm{W}}\bm{h}_{i})\bigr{)}, from which the desired inequality follows. ∎
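
The ratio bound of Lemma B.5 reduces to the per-sample bound (1−𝕊_c(𝒗′))/(1−𝕊_c(𝒗)) ≤ e^{2∥𝒗−𝒗′∥_∞}, which the following Python sketch checks on random logits (illustrative only).

import numpy as np

rng = np.random.default_rng(4)
k = 5

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

for _ in range(1000):
    v = rng.normal(size=k)
    vp = v + rng.normal(size=k)                    # a perturbed logit vector
    c = rng.integers(0, k)
    ratio = (1 - softmax(vp)[c]) / (1 - softmax(v)[c])
    assert ratio <= np.exp(2 * np.max(np.abs(vp - v))) * (1 + 1e-12)
print("per-sample ratio bound holds on all sampled instances")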

Proof Overview

We consider a decaying learning rate schedule of the form ηt=Θ(1ta)\eta_{t}=\Theta(\frac{1}{t^{a}}) where a(0,1]a\in(0,1]. The first step is to show that the loss decreases monotonically after a certain time, at a rate that depends on 𝒢(𝑾){\mathcal{G}}({\bm{W}}). To obtain this, we apply Lemma B.1 and Lemma A.3 to upper bound the first-order and second-order terms in the Taylor expansion (17) of the loss, respectively. The difference between SignGD and Adam is that Adam involves an additional term |(𝑾t),𝑴t𝑽t(𝑾t)|(𝑾t)||\bigl{|}\langle\nabla\mathcal{L}({\bm{W}}_{t}),\frac{\bm{M}_{t}}{\sqrt{{\bm{V}}_{t}}}-\frac{\nabla\mathcal{L}({\bm{W}}_{t})}{|\nabla\mathcal{L}({\bm{W}}_{t})|}\rangle\bigr{|}. To handle it, we apply Lemma D.2 and Lemma D.3 to bound the first-moment (𝑴t\bm{M}_{t}) and second-moment (𝑽t{\bm{V}}_{t}) buffers of Adam using 𝒢(𝑾){\mathcal{G}}({\bm{W}}). Next, we use loss monotonicity to derive a lower bound on the unnormalized margin, which involves the ratio 𝒢(𝑾)(𝑾)\frac{{\mathcal{G}}({\bm{W}})}{\mathcal{L}({\bm{W}})}. A crucial step is to find a time t¯2\bar{t}_{2} such that separability (36) holds for all tt¯2t\geq\bar{t}_{2}; the existence of t¯2\bar{t}_{2} is guaranteed by loss monotonicity, since (𝑾t)log2n\mathcal{L}({\bm{W}}_{t})\leq\frac{\log 2}{n} implies separability (36), as proved in Lemma B.4.

Then, we argue that the ratio 𝒢(Wt)(Wt)\frac{{\mathcal{G}}({\bm{W}}_{t})}{\mathcal{L}({\bm{W}}_{t})} converges to 11 exponentially fast (recalling that 1𝒢(𝑾t)(𝑾t)1n(𝑾t)21\geq\frac{{\mathcal{G}}({\bm{W}}_{t})}{\mathcal{L}({\bm{W}}_{t})}\geq 1-\frac{n\mathcal{L}({\bm{W}}_{t})}{2}) by proving the same rate for the decrease of the loss (𝑾t)\mathcal{L}({\bm{W}}_{t}). We first choose a time t1t_{1} after t0t_{0} (recall that t0t_{0} is the time that satisfies Assumption 3.4) such that (𝑾t+1)(𝑾t)ηtγ2𝒢(𝑾t)\mathcal{L}({\bm{W}}_{t+1})\leq\mathcal{L}({\bm{W}}_{t})-\frac{\eta_{t}\gamma}{2}{\mathcal{G}}({\bm{W}}_{t}) for all tt1t\geq t_{1}. Next, we lower bound 𝒢(𝑾t){\mathcal{G}}({\bm{W}}_{t}) using (𝑾t)\mathcal{L}({\bm{W}}_{t}). By Lemma B.3, there are two sufficient conditions (namely, (𝑾t)log2n~\mathcal{L}({\bm{W}}_{t})\leq\frac{\log 2}{n}\eqqcolon\tilde{\mathcal{L}} or 𝒢(𝑾t)12n{\mathcal{G}}({\bm{W}}_{t})\leq\frac{1}{2n}) that guarantee (𝑾t)2𝒢(𝑾t)\mathcal{L}({\bm{W}}_{t})\leq 2{\mathcal{G}}({\bm{W}}_{t}). We choose a time t2t_{2} (after t1t_{1}) that is sufficiently large such that there exists t[t1,t2]t^{*}\in[t_{1},t_{2}] for which we have 𝒢(𝑾t)~212n{\mathcal{G}}({\bm{W}}_{t^{*}})\leq\frac{\tilde{\mathcal{L}}}{2}\leq\frac{1}{2n}. This not only guarantees that (𝑾t)2𝒢(𝑾t)\mathcal{L}({\bm{W}}_{t^{*}})\leq 2{\mathcal{G}}({\bm{W}}_{t^{*}}) at time tt^{*}, but also (crucially due to monotonicity) implies that (𝑾t)(𝑾t)2𝒢(𝑾t)log2n\mathcal{L}({\bm{W}}_{t})\leq\mathcal{L}({\bm{W}}_{t^{*}})\leq 2{\mathcal{G}}({\bm{W}}_{t^{*}})\leq\frac{\log 2}{n} for all tt2t\geq t_{2}. Thus, we observe that the other sufficient condition (𝑾t)log2n\mathcal{L}({\bm{W}}_{t})\leq\frac{\log 2}{n} is satisfied, from which we conclude that (𝑾t)2𝒢(𝑾t)\mathcal{L}({\bm{W}}_{t})\leq 2{\mathcal{G}}({\bm{W}}_{t}) for all tt2t\geq t_{2}. We remark that the choice of t2t_{2} depends on (𝑾t1)\mathcal{L}({\bm{W}}_{t_{1}}) (whose magnitude is bounded using Lemma B.2), and t2t_{2} can be used as t¯2\bar{t}_{2} above. To recap, t1t_{1} is the time (after t0t_{0}) after which the successive loss decrease is lower bounded by the product ηtγ𝒢(Wt)\eta_{t}\gamma{\mathcal{G}}({\bm{W}}_{t}); t2t_{2} (after t1t_{1}) is the time after which (Wt)log2n\mathcal{L}({\bm{W}}_{t})\leq\frac{\log 2}{n} (thus, both (Wt)2𝒢(Wt)\mathcal{L}({\bm{W}}_{t})\leq 2{\mathcal{G}}({\bm{W}}_{t}) and separability condition (36) hold for all tt2t\geq t_{2}).

The next step is to bound the max-norm of the iterates in terms of the learning rates. The SignGD case is straightforward given sign((𝑾))max=1{\left\|\texttt{sign}(\nabla\mathcal{L}({\bm{W}}))\right\|_{\max}}=1. In the case of Adam, the important step is to show that 𝑽t[c,j]𝒪((𝑾t)2){\bm{V}}_{t}[c,j]\leq\mathcal{O}(\mathcal{L}({\bm{W}}_{t})^{2}) for all c[k]c\in[k] and j[d]j\in[d]; thus, 𝑽t[c,j]1{\bm{V}}_{t}[c,j]\leq 1 eventually holds, again due to loss monotonicity. To prove this, we use Lemma B.1, which upper bounds (𝑾)max{\left\|\nabla\mathcal{L}({\bm{W}})\right\|_{\max}} by 𝒢(𝑾){\mathcal{G}}({\bm{W}}) up to a constant. Here, we provide a summary of the lemmas (in Sections A and B) on the properties of (𝑾)\mathcal{L}({\bm{W}}) and 𝒢(𝑾){\mathcal{G}}({\bm{W}}) that are used in proving the various key steps in the Proof Overview (a small numerical illustration of the resulting margin convergence is sketched after this list):

  • Loss Monotonicity: Lemma A.2 is on the Taylor expansion of the loss; Lemma A.3 and B.5 are for handling the second-order term in the expansion; Lemma B.1 proves that (𝑾)sumγ𝒢(𝑾){\left\|\nabla\mathcal{L}({\bm{W}})\right\|_{\rm{sum}}}\geq\gamma{\mathcal{G}}({\bm{W}}).

  • Unnormalized Margin: Lemma B.4 is on the separability condition (36) implied by a low loss ((𝑾)log2n\mathcal{L}({\bm{W}})\leq\frac{\log 2}{n}); Lemma B.3 proves that 𝒢(𝑾)(𝑾)1\frac{{\mathcal{G}}({\bm{W}})}{\mathcal{L}({\bm{W}})}\leq 1.

  • Convergence of 𝒢(Wt)(Wt)\frac{{\mathcal{G}}({\bm{W}}_{t})}{\mathcal{L}({\bm{W}}_{t})}: Lemma B.3 proves that 𝒢(𝑾t)(𝑾t)1n(𝑾t)2\frac{{\mathcal{G}}({\bm{W}}_{t})}{\mathcal{L}({\bm{W}}_{t})}\geq 1-\frac{n\mathcal{L}({\bm{W}}_{t})}{2}, and it provides the sufficient conditions that guarantee (𝑾)2𝒢(𝑾)\mathcal{L}({\bm{W}})\leq 2{\mathcal{G}}({\bm{W}}); Lemma B.2 provides a simple bound on the loss using the max-norm of the iterates.

  • Iterate Bound: Lemma B.1 implies that (𝑾)max(𝑾)sum2B𝒢(𝑾){\left\|\nabla\mathcal{L}({\bm{W}})\right\|_{\max}}\leq{\left\|\nabla\mathcal{L}({\bm{W}})\right\|_{\rm{sum}}}\leq 2B\cdot{\mathcal{G}}({\bm{W}}).
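
To make the preceding outline concrete, the following Python sketch (a toy illustration on synthetic data, not one of the paper's experiments) runs SignGD with η_t = 1/(t+2)^{1/2} on a small, assumed-separable 3-class problem and prints the normalized margin min_{i, c≠y_i}(𝒆_{y_i}−𝒆_c)^⊤𝑾_t𝒉_i / ∥𝑾_t∥_max, which should increase toward the max-norm margin γ of the dataset.

import numpy as np

rng = np.random.default_rng(5)
n, d, k = 30, 10, 3
mu = 3.0 * rng.normal(size=(d, k))                 # well-separated class means
y = rng.integers(0, k, size=n)
H = mu[:, y] + 0.1 * rng.normal(size=(d, n))       # (assumed) linearly separable clusters

def softmax(Z):
    Z = Z - Z.max(axis=0, keepdims=True)
    E = np.exp(Z)
    return E / E.sum(axis=0, keepdims=True)

def grad(W):
    S = softmax(W @ H)
    Y = np.zeros((k, n)); Y[y, np.arange(n)] = 1.0
    return -(Y - S) @ H.T / n

W = np.zeros((k, d))
for t in range(20_000):
    eta = 1.0 / np.sqrt(t + 2)                     # eta_t = Theta(1/t^a) with a = 1/2
    W -= eta * np.sign(grad(W))                    # SignGD update
    if (t + 1) % 5_000 == 0:
        logits = W @ H
        others = np.where(np.eye(k)[:, y].astype(bool), -np.inf, logits).max(axis=0)
        margin = (logits[y, np.arange(n)] - others).min()
        print(t + 1, margin / np.abs(W).max())     # normalized margin, increasing toward gamma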

Appendix C Proof of SignGD

In this section, we break the proof of implicit bias of SignGD into several parts following the arguments in the Proof Overview. Lemma C.1 shows the descent properties of SignGD. It is used in Lemma C.2 to lower bound the unnormalized margin, and in the proof of Theorem C.4 to show the convergence of 𝒢(𝑾t)(𝑾t)\frac{{\mathcal{G}}({\bm{W}}_{t})}{\mathcal{L}({\bm{W}}_{t})}.

Lemma C.1 (SignGD Descent).

Under the same setting as Theorem C.4, it holds for all t0t\geq 0,

(𝑾t+1)\displaystyle\mathcal{L}({\bm{W}}_{t+1}) (𝑾t)γηt(1αs1ηt)𝒢(𝑾t),\displaystyle\leq\mathcal{L}({\bm{W}}_{t})-\gamma\eta_{t}(1-\alpha_{s_{1}}\eta_{t}){\mathcal{G}}({\bm{W}}_{t}),

where αs1\alpha_{s_{1}} is some constant that depends on BB and γ\gamma.

Proof.

By Lemma A.2, we let 𝑾=𝑾t+1{\bm{W}}^{\prime}={\bm{W}}_{t+1}, 𝑾=𝑾t{\bm{W}}={\bm{W}}_{t}, 𝚫t=𝑾t+1𝑾t\bm{\Delta}_{t}={\bm{W}}_{t+1}-{\bm{W}}_{t}, and define 𝑾t,t+1,ζ:=𝑾t+ζ(𝑾t+1𝑾t){\bm{W}}_{t,t+1,\zeta}:={\bm{W}}_{t}+\zeta({\bm{W}}_{t+1}-{\bm{W}}_{t}). We choose ζ\zeta^{*} such that 𝑾t,t+1,ζ{\bm{W}}_{t,t+1,\zeta^{*}} satisfies (14). Then, we have:

(𝑾t+1)\displaystyle\mathcal{L}({\bm{W}}_{t+1}) =(𝑾t)+(𝑾t),𝚫tt\displaystyle=\mathcal{L}({\bm{W}}_{t})+\underbrace{\langle\nabla\mathcal{L}({\bm{W}}_{t}),\bm{\Delta}_{t}\rangle}_{\spadesuit_{t}}
+12ni[n]𝒉i𝚫t(diag(𝕊(𝑾t,t+1,ζ𝒉i))𝕊(𝑾t,t+1,ζ𝒉i)𝕊(𝑾t,t+1,ζ𝒉i))𝚫t𝒉it.\displaystyle\quad+\frac{1}{2n}\sum_{i\in[n]}\underbrace{\bm{h}_{i}^{\top}\bm{\Delta}_{t}^{\top}\left(\operatorname{diag}(\mathbb{S}({\bm{W}}_{t,t+1,\zeta^{*}}\bm{h}_{i}))-\mathbb{S}({\bm{W}}_{t,t+1,\zeta^{*}}\bm{h}_{i})\mathbb{S}({\bm{W}}_{t,t+1,\zeta^{*}}\bm{h}_{i})^{\top}\right)\bm{\Delta}_{t}\,\bm{h}_{i}}_{\clubsuit_{t}}\,. (17)

For the t\spadesuit_{t} term, we have by Lemma B.1:

t=(𝑾t),ηt(𝑾t)2(𝑾t)=ηt(𝑾t)sumηtγG(𝑾t)\displaystyle\spadesuit_{t}=\langle\nabla\mathcal{L}({\bm{W}}_{t}),-\eta_{t}\frac{\nabla\mathcal{L}({\bm{W}}_{t})}{\sqrt{\nabla\mathcal{L}^{2}({\bm{W}}_{t})}}\rangle=-\eta_{t}{\left\|\nabla\mathcal{L}({\bm{W}}_{t})\right\|_{\rm{sum}}}\leq-\eta_{t}\gamma G({\bm{W}}_{t})

For the t\clubsuit_{t} term, we let 𝒗=𝚫t𝒉i{\bm{v}}=\bm{\Delta}_{t}\bm{h}_{i} and 𝒔=𝕊(𝑾t,t+1,ζ𝒉i)\bm{s}=\mathbb{S}({\bm{W}}_{t,t+1,\zeta^{*}}\bm{h}_{i}), and apply Lemma A.3 to obtain

t4𝚫t𝒉i2(1𝕊yi(𝑾t,t+1,ζ𝒉i))4ηt2B2(1𝕊yi(𝑾t,t+1,ζ𝒉i)),\displaystyle\clubsuit_{t}\leq 4\|\bm{\Delta}_{t}\bm{h}_{i}\|_{\infty}^{2}(1-\mathbb{S}_{y_{i}}({\bm{W}}_{t,t+1,\zeta^{*}}\bm{h}_{i}))\leq 4\eta_{t}^{2}B^{2}(1-\mathbb{S}_{y_{i}}({\bm{W}}_{t,t+1,\zeta^{*}}\bm{h}_{i})),

where in the second inequality we have used 𝚫t𝒉i𝚫tmax𝒉i1\|\bm{\Delta}_{t}\bm{h}_{i}\|_{\infty}\leq{\left\|\bm{\Delta}_{t}\right\|_{\max}}\|\bm{h}_{i}\|_{1}, 𝒉i1B\|\bm{h}_{i}\|_{1}\leq B, and 𝚫tmax=ηtsign((𝑾t))maxηt{\left\|\bm{\Delta}_{t}\right\|_{\max}}=\eta_{t}{\left\|\texttt{sign}(\nabla\mathcal{L}({\bm{W}}_{t}))\right\|_{\max}}\leq\eta_{t}. Putting these two pieces together, we obtain

(𝑾t+1)\displaystyle\mathcal{L}({\bm{W}}_{t+1}) (𝑾t)γηt𝒢(𝑾t)+2ηt2B21ni[n](1𝕊yi(𝑾t,t+1,ζ𝒉i))\displaystyle\leq\mathcal{L}({\bm{W}}_{t})-\gamma\eta_{t}{\mathcal{G}}({\bm{W}}_{t})+2\eta_{t}^{2}B^{2}\frac{1}{n}\sum_{i\in[n]}(1-\mathbb{S}_{y_{i}}({\bm{W}}_{t,t+1,\zeta^{*}}\bm{h}_{i}))
=(𝑾t)γηt𝒢(𝑾t)+2ηt2B2𝒢(𝑾t,t+1,ζ)\displaystyle=\mathcal{L}({\bm{W}}_{t})-\gamma\eta_{t}{\mathcal{G}}({\bm{W}}_{t})+2\eta_{t}^{2}B^{2}{\mathcal{G}}({\bm{W}}_{t,t+1,\zeta^{*}})
(𝑾t)γηt𝒢(𝑾t)+2ηt2B2supζ[0,1]𝒢(𝑾t,t+1,ζ)\displaystyle\leq\mathcal{L}({\bm{W}}_{t})-\gamma\eta_{t}{\mathcal{G}}({\bm{W}}_{t})+2\eta_{t}^{2}B^{2}\sup_{\zeta\in[0,1]}{\mathcal{G}}({\bm{W}}_{t,t+1,\zeta})
=(𝑾t)γηt𝒢(𝑾t)+2ηt2B2𝒢(𝑾t)supζ[0,1]𝒢(𝑾t+ζ𝚫t)𝒢(𝑾t)\displaystyle=\mathcal{L}({\bm{W}}_{t})-\gamma\eta_{t}{\mathcal{G}}({\bm{W}}_{t})+2\eta_{t}^{2}B^{2}{\mathcal{G}}({\bm{W}}_{t})\sup_{\zeta\in[0,1]}\frac{{\mathcal{G}}({\bm{W}}_{t}+\zeta\bm{\Delta}_{t})}{{\mathcal{G}}({\bm{W}}_{t})}
(a)(𝑾t)γηt𝒢(𝑾t)+2ηt2B2𝒢(𝑾t)supζ[0,1]e2Bζ𝚫tmax\displaystyle\stackrel{{\scriptstyle(a)}}{{\leq}}\mathcal{L}({\bm{W}}_{t})-\gamma\eta_{t}{\mathcal{G}}({\bm{W}}_{t})+2\eta_{t}^{2}B^{2}{\mathcal{G}}({\bm{W}}_{t})\sup_{\zeta\in[0,1]}e^{2B\zeta{\left\|\bm{\Delta}_{t}\right\|_{\max}}}
(b)(𝑾t)γηt𝒢(𝑾t)+2ηt2B2e2Bη0𝒢(𝑾t),\displaystyle\stackrel{{\scriptstyle(b)}}{{\leq}}\mathcal{L}({\bm{W}}_{t})-\gamma\eta_{t}{\mathcal{G}}({\bm{W}}_{t})+2\eta_{t}^{2}B^{2}e^{2B\eta_{0}}{\mathcal{G}}({\bm{W}}_{t}), (18)

where (a) is by Lemma B.5 and (b) is by 𝚫tmaxηtη0{\left\|\bm{\Delta}_{t}\right\|_{\max}}\leq\eta_{t}\leq\eta_{0}. Letting αs1=2B2e2Bη0γ\alpha_{s_{1}}=\frac{2B^{2}e^{2B\eta_{0}}}{\gamma}, Eq. (18) simplifies to:

(𝑾t+1)\displaystyle\mathcal{L}({\bm{W}}_{t+1}) (𝑾t)γηt(1αs1ηt)𝒢(𝑾t),\displaystyle\leq\mathcal{L}({\bm{W}}_{t})-\gamma\eta_{t}(1-\alpha_{s_{1}}\eta_{t}){\mathcal{G}}({\bm{W}}_{t}),

from which we observe that the loss starts to monotonically decrease after ηt\eta_{t} satisfies ηt1αs1\eta_{t}\leq\frac{1}{\alpha_{s_{1}}} for a decreasing learning rate schedule. ∎

For a decaying learning rate schedule, Lemma C.1 implies that the loss monotonically decreases after a certain time. Thus, we know that the assumption of Lemma C.2 can be satisfied. In the proof of Theorem C.4, we will specify a concrete form of t~\tilde{t} in Lemma C.2.

Lemma C.2 (SignGD Unnormalized Margin).

Suppose that there exists t~\tilde{t} such that (𝑾t)log2n\mathcal{L}({\bm{W}}_{t})\leq\frac{\log 2}{n} for all tt~t\geq\tilde{t}, then we have

mini[n],cyi(𝒆yi𝒆c)T𝑾t𝒉iγs=t~t1ηs𝒢(𝑾s)(𝑾s)αs2s=t~t1ηs2,\displaystyle\min_{i\in[n],c\neq y_{i}}(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}{\bm{W}}_{t}\bm{h}_{i}\geq\gamma\sum_{s=\tilde{t}}^{t-1}\eta_{s}\frac{{\mathcal{G}}({\bm{W}}_{s})}{\mathcal{L}({\bm{W}}_{s})}-\alpha_{s_{2}}\sum_{s=\tilde{t}}^{t-1}\eta_{s}^{2},

where αs2\alpha_{s_{2}} is some constant that depends on BB.

Proof.

We let αs2=2B2e2Bη0\alpha_{s_{2}}=2B^{2}e^{2B\eta_{0}}, then from (18), we have for t>t~t>\tilde{t}:

(𝑾t+1)\displaystyle\mathcal{L}({\bm{W}}_{t+1}) (𝑾t)γηt𝒢(𝑾t)+αs2ηt2𝒢(𝑾t)\displaystyle\leq\mathcal{L}({\bm{W}}_{t})-\gamma\eta_{t}{\mathcal{G}}({\bm{W}}_{t})+\alpha_{s_{2}}\eta_{t}^{2}{\mathcal{G}}({\bm{W}}_{t})
=(𝑾t)(1γηt𝒢(𝑾t)(𝑾t)+αs2ηt2𝒢(𝑾t)(𝑾t))\displaystyle=\mathcal{L}({\bm{W}}_{t})\bigl{(}1-\gamma\eta_{t}\frac{{\mathcal{G}}({\bm{W}}_{t})}{\mathcal{L}({\bm{W}}_{t})}+\alpha_{s_{2}}\eta_{t}^{2}\frac{{\mathcal{G}}({\bm{W}}_{t})}{\mathcal{L}({\bm{W}}_{t})}\bigr{)}
(𝑾t)exp(γηt𝒢(𝑾t)(𝑾t)+αs2ηt2𝒢(𝑾t)(𝑾t))\displaystyle\leq\mathcal{L}({\bm{W}}_{t})\exp\bigl{(}-\gamma\eta_{t}\frac{{\mathcal{G}}({\bm{W}}_{t})}{\mathcal{L}({\bm{W}}_{t})}+\alpha_{s_{2}}\eta_{t}^{2}\frac{{\mathcal{G}}({\bm{W}}_{t})}{\mathcal{L}({\bm{W}}_{t})}\bigr{)}
(𝑾t~)exp(γs=t~tηs𝒢(𝑾s)(𝑾s)+αs2s=t~tηs2).\displaystyle\leq\mathcal{L}({\bm{W}}_{\tilde{t}})\exp\bigl{(}-\gamma\sum_{s=\tilde{t}}^{t}\eta_{s}\frac{{\mathcal{G}}({\bm{W}}_{s})}{\mathcal{L}({\bm{W}}_{s})}+\alpha_{s_{2}}\sum_{s=\tilde{t}}^{t}\eta_{s}^{2}\bigr{)}.
log2nexp(γs=t~tηs𝒢(𝑾s)(𝑾s)+αs2s=t~tηs2),\displaystyle\leq\frac{\log 2}{n}\exp\bigl{(}-\gamma\sum_{s=\tilde{t}}^{t}\eta_{s}\frac{{\mathcal{G}}({\bm{W}}_{s})}{\mathcal{L}({\bm{W}}_{s})}+\alpha_{s_{2}}\sum_{s=\tilde{t}}^{t}\eta_{s}^{2}\bigr{)}, (19)

where the penultimate inequality uses Lemma B.3, and the last inequality uses the assumption that (𝑾t)log2n\mathcal{L}({\bm{W}}_{t})\leq\frac{\log 2}{n} for all tt~t\geq\tilde{t}. Then, we have for all t>t~t>\tilde{t}:

emini[n],cyi(𝒆yi𝒆c)T𝑾t𝒉i\displaystyle e^{-\min_{i\in[n],c\neq y_{i}}(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}{\bm{W}}_{t}\bm{h}_{i}} =maxi[n]emincyi(𝒆yi𝒆c)T𝑾t𝒉i\displaystyle=\max_{i\in[n]}e^{-\min_{c\neq y_{i}}(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}{\bm{W}}_{t}\bm{h}_{i}}
(a)maxi[n]1log2log(1+emincyi(𝒆yi𝒆c)T𝑾t𝒉i)\displaystyle\stackrel{{\scriptstyle(a)}}{{\leq}}\max_{i\in[n]}\frac{1}{\log 2}\log\bigl{(}1+e^{-\min_{c\neq y_{i}}(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}{\bm{W}}_{t}\bm{h}_{i}}\bigr{)}
maxi[n]1log2log(1+cyie(𝒆yi𝒆c)T𝑾t𝒉i)n(𝑾t)log2\displaystyle\leq\max_{i\in[n]}\frac{1}{\log 2}\log(1+\sum_{c\neq y_{i}}e^{-(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}{\bm{W}}_{t}\bm{h}_{i}})\leq\frac{n\mathcal{L}({\bm{W}}_{t})}{\log 2}
(b)exp(γs=t~t1ηs𝒢(𝑾s)(𝑾s)+αs2s=t~t1ηs2).\displaystyle\stackrel{{\scriptstyle(b)}}{{\leq}}\exp\bigl{(}-\gamma\sum_{s=\tilde{t}}^{t-1}\eta_{s}\frac{{\mathcal{G}}({\bm{W}}_{s})}{\mathcal{L}({\bm{W}}_{s})}+\alpha_{s_{2}}\sum_{s=\tilde{t}}^{t-1}\eta_{s}^{2}\bigr{)}.

(a) is by the following: the assumption (𝑾t)log2n\mathcal{L}({\bm{W}}_{t})\leq\frac{\log 2}{n} implies that mincyi(𝒆yi𝒆c)T𝑾t𝒉i0\min_{c\neq y_{i}}(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}{\bm{W}}_{t}\bm{h}_{i}\geq 0 for all i[n]i\in[n] by Lemma B.4. We also know the inequality log(1+ez)ezlog2\frac{\log(1+e^{-z})}{e^{-z}}\geq\log 2 holds for any z0z\geq 0. Then, for any i[n]i\in[n], we can set z=mincyi(𝒆yi𝒆c)T𝑾t𝒉iz=\min_{c\neq y_{i}}(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}{\bm{W}}_{t}\bm{h}_{i} to obtain the desired inequality; and (b) is by (19). Finally, taking log\log on both sides leads to the result. ∎

The next lemma upper bounds the max-norm of the SignGD iterates using the learning rates. It is used in the proof of Theorem C.4.

Lemma C.3 (SignGD 𝑾tmax{\left\|{\bm{W}}_{t}\right\|_{\max}}).

For SignGD, we have for any t>0t>0 that

𝑾tmax𝑾0max+s=0t1ηs.\displaystyle{\left\|{\bm{W}}_{t}\right\|_{\max}}\leq{\left\|{\bm{W}}_{0}\right\|_{\max}}+\sum_{s=0}^{t-1}\eta_{s}.
Proof.

By the SignGD update rule (4), we have

𝑾t+1=𝑾0s=0tηssign((𝑾s)).\displaystyle{\bm{W}}_{t+1}={\bm{W}}_{0}-\sum_{s=0}^{t}\eta_{s}\texttt{sign}(\nabla\mathcal{L}({\bm{W}}_{s})).

This leads to 𝑾tmax𝑾0max+s=0t1ηs{\left\|{\bm{W}}_{t}\right\|_{\max}}\leq{\left\|{\bm{W}}_{0}\right\|_{\max}}+\sum_{s=0}^{t-1}\eta_{s} . ∎

The main step in the proof of Theorem C.4 is to determine the time that satisfies the assumption in Lemma C.2 and show the convergence of 𝒢(𝑾t)(𝑾t)\frac{{\mathcal{G}}({\bm{W}}_{t})}{\mathcal{L}({\bm{W}}_{t})}. After this, Lemma C.2 and Lemma C.3 will be combined to obtain the final result.

Theorem C.4.

Suppose that Assumptions 3.1, 3.3, and 3.5 hold; then there exists ts2=ts2(n,γ,B,𝑾0)t_{s_{2}}=t_{s_{2}}(n,\gamma,B,{\bm{W}}_{0}) such that SignGD achieves the following for all t>ts2t>t_{s_{2}}

|mini[n],cyi(𝒆yi𝒆c)T𝑾t𝒉i𝑾tmaxγ|\displaystyle\left|\frac{\min_{i\in[n],c\neq y_{i}}(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}{\bm{W}}_{t}\bm{h}_{i}}{{\left\|{\bm{W}}_{t}\right\|_{\max}}}-\gamma\right| 𝒪(s=ts2t1ηseγ4τ=ts2s1ητ+s=0ts21ηs+s=ts2t1ηs2s=0t1ηs).\displaystyle\leq\mathcal{O}\Bigg{(}\frac{\sum_{s=t_{s_{2}}}^{t-1}\eta_{s}e^{-\frac{\gamma}{4}\sum_{\tau=t_{s_{2}}}^{s-1}\eta_{\tau}}+\sum_{s=0}^{t_{s_{2}}-1}\eta_{s}+\sum_{s=t_{s_{2}}}^{t-1}\eta_{s}^{2}}{\sum_{s=0}^{t-1}\eta_{s}}\Bigg{)}.
Proof.

Determination of ts1t_{s_{1}}. In Lemma C.1 we choose ts1t_{s_{1}} such that ηt12αs1\eta_{t}\leq\frac{1}{2\alpha_{s_{1}}} for all tts1t\geq t_{s_{1}}. Considering ηt=Θ(1ta)\eta_{t}=\Theta(\frac{1}{t^{a}}) (where a(0,1]a\in(0,1]), we set ts1=(2αs1)1a=(4B2e2Bη0γ)1at_{s_{1}}=(2\alpha_{s_{1}})^{\frac{1}{a}}=(\frac{4B^{2}e^{2B\eta_{0}}}{\gamma})^{\frac{1}{a}}. Then, we have for all tts1t\geq t_{s_{1}}

(𝑾t+1)(𝑾t)ηtγ2𝒢(𝑾t).\displaystyle\mathcal{L}({\bm{W}}_{t+1})\leq\mathcal{L}({\bm{W}}_{t})-\frac{\eta_{t}\gamma}{2}{\mathcal{G}}({\bm{W}}_{t}). (20)

Rearranging this equation and using non-negativity of the loss we obtain γs=ts1tηs𝒢(𝑾s)2(𝑾ts1)\gamma\sum_{s=t_{s_{1}}}^{t}\eta_{s}{\mathcal{G}}({\bm{W}}_{s})\leq 2\mathcal{L}({\bm{W}}_{t_{s_{1}}}).
Determination of ts2t_{s_{2}}. By Lemma B.2, we can bound (𝑾ts1)\mathcal{L}({\bm{W}}_{t_{s_{1}}}) as follows

|(𝑾ts1)(𝑾0)|2B𝑾ts1𝑾0max2Bs=0ts11ηssign((𝑾s))max2Bs=0ts11ηs,\displaystyle|\mathcal{L}({\bm{W}}_{t_{s_{1}}})-\mathcal{L}({\bm{W}}_{0})|\leq 2B{\left\|{\bm{W}}_{t_{s_{1}}}-{\bm{W}}_{0}\right\|_{\max}}\leq 2B\sum_{s=0}^{t_{s_{1}}-1}\eta_{s}{\left\|\texttt{sign}(\nabla\mathcal{L}({\bm{W}}_{s}))\right\|_{\max}}\leq 2B\sum_{s=0}^{t_{s_{1}}-1}\eta_{s},

where the last inequality holds since each entry of sign((𝑾s))\texttt{sign}(\nabla\mathcal{L}({\bm{W}}_{s})) has magnitude at most 11. Combining this with the result above and letting ~log2n\tilde{\mathcal{L}}\coloneqq\frac{\log 2}{n}, we obtain

𝒢(𝑾t)=mins[ts1,ts2]𝒢(𝑾s)2(𝑾0)+4Bs=0ts11ηsγs=ts1ts2ηs~212n,\displaystyle{\mathcal{G}}({\bm{W}}_{t^{*}})=\min_{s\in[t_{s_{1}},t_{s_{2}}]}{\mathcal{G}}({\bm{W}}_{s})\leq\frac{2\mathcal{L}({\bm{W}}_{0})+4B\sum_{s=0}^{t_{s_{1}}-1}\eta_{s}}{\gamma\sum_{s=t_{s_{1}}}^{t_{s_{2}}}\eta_{s}}\leq\frac{\tilde{\mathcal{L}}}{2}\leq\frac{1}{2n},

from which we derive the sufficient condition on ts2t_{s_{2}} to be s=ts1ts2ηs4(𝑾0)+8Bs=0ts11ηsγ~\sum_{s=t_{s_{1}}}^{t_{s_{2}}}\eta_{s}\geq\frac{4\mathcal{L}({\bm{W}}_{0})+8B\sum_{s=0}^{t_{s_{1}}-1}\eta_{s}}{\gamma\tilde{\mathcal{L}}}.
Convergence of 𝒢(Wt)(Wt)\frac{{\mathcal{G}}({\bm{W}}_{t})}{\mathcal{L}({\bm{W}}_{t})}. Given 𝒢(𝑾t)~212n{\mathcal{G}}({\bm{W}}_{t^{*}})\leq\frac{\tilde{\mathcal{L}}}{2}\leq\frac{1}{2n}, we obtain that (𝑾t)(𝑾t)2𝒢(𝑾t)~\mathcal{L}({\bm{W}}_{t})\leq\mathcal{L}({\bm{W}}_{t^{*}})\leq 2{\mathcal{G}}({\bm{W}}_{t^{*}})\leq\tilde{\mathcal{L}} for all tts2t\geq t_{s_{2}}, where the first and second inequalities are due to the monotonicity of the loss and Lemma B.3, respectively. Thus, the other sufficient condition (𝑾t)log2n\mathcal{L}({\bm{W}}_{t})\leq\frac{\log 2}{n} in Lemma B.3 is satisfied, from which we conclude that (𝑾t)2𝒢(𝑾t)\mathcal{L}({\bm{W}}_{t})\leq 2{\mathcal{G}}({\bm{W}}_{t}) for all tts2t\geq t_{s_{2}}. Substituting this into (20), we obtain for all t>ts2t>t_{s_{2}}

(𝑾t)(1γηt14)(𝑾t1)(𝑾ts2)eγ4s=ts2t1ηs~eγ4s=ts2t1ηs\displaystyle\mathcal{L}({\bm{W}}_{t})\leq(1-\frac{\gamma\eta_{t-1}}{4})\mathcal{L}({\bm{W}}_{t-1})\leq\mathcal{L}({\bm{W}}_{t_{s_{2}}})e^{-\frac{\gamma}{4}\sum_{s=t_{s_{2}}}^{t-1}\eta_{s}}\leq\tilde{\mathcal{L}}e^{-\frac{\gamma}{4}\sum_{s=t_{s_{2}}}^{t-1}\eta_{s}}

Then, by Lemma B.3, we obtain

𝒢(𝑾t)(𝑾t)1n(𝑾t)21n~eγ4s=ts2t1ηs21eγ4s=ts2t1ηs.\displaystyle\frac{{\mathcal{G}}({\bm{W}}_{t})}{\mathcal{L}({\bm{W}}_{t})}\geq 1-\frac{n\mathcal{L}({\bm{W}}_{t})}{2}\geq 1-\frac{n\tilde{\mathcal{L}}e^{-\frac{\gamma}{4}\sum_{s=t_{s_{2}}}^{t-1}\eta_{s}}}{2}\geq 1-e^{-\frac{\gamma}{4}\sum_{s=t_{s_{2}}}^{t-1}\eta_{s}}. (21)

Margin Convergence. Finally, we combine Lemma C.2, Lemma C.3, and (21) to obtain

|mini[n],cyi(𝒆yi𝒆c)T𝑾t𝒉i𝑾tmaxγ|\displaystyle|\frac{\min_{i\in[n],c\neq y_{i}}(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}{\bm{W}}_{t}\bm{h}_{i}}{{\left\|{\bm{W}}_{t}\right\|_{\max}}}-\gamma| γ(𝑾0max+s=ts2t1ηseγ4τ=ts2s1ητ+s=0ts21ηs)+αs2s=ts2t1ηs2𝑾0max+s=0t1ηs\displaystyle\leq\frac{\gamma\bigl{(}{\left\|{\bm{W}}_{0}\right\|_{\max}}+\sum_{s=t_{s_{2}}}^{t-1}\eta_{s}e^{-\frac{\gamma}{4}\sum_{\tau=t_{s_{2}}}^{s-1}\eta_{\tau}}+\sum_{s=0}^{t_{s_{2}}-1}\eta_{s}\bigr{)}+\alpha_{s_{2}}\sum_{s=t_{s_{2}}}^{t-1}\eta_{s}^{2}}{{\left\|{\bm{W}}_{0}\right\|_{\max}}+\sum_{s=0}^{t-1}\eta_{s}}
𝒪(s=ts2t1ηseγ4τ=ts2s1ητ+s=0ts21ηs+s=ts2t1ηs2s=0t1ηs)\displaystyle\leq\mathcal{O}(\frac{\sum_{s=t_{s_{2}}}^{t-1}\eta_{s}e^{-\frac{\gamma}{4}\sum_{\tau=t_{s_{2}}}^{s-1}\eta_{\tau}}+\sum_{s=0}^{t_{s_{2}}-1}\eta_{s}+\sum_{s=t_{s_{2}}}^{t-1}\eta_{s}^{2}}{\sum_{s=0}^{t-1}\eta_{s}})

Next, we explicitly upper bound ts2t_{s_{2}} in Theorem C.4 and derive the margin convergence rates of SignGD.

Corollary C.5.

Consider a learning rate schedule of the form ηt=Θ(1ta)\eta_{t}=\Theta(\frac{1}{t^{a}}) where a(0,1]a\in(0,1]. Under the same setting as Theorem C.4, we have for SignGD

|mini[n],cyi(𝒆yi𝒆c)T𝑾t𝒉i𝑾tmaxγ|={𝒪(t12a+nt1a)ifa<12𝒪(logt+nt1/2)ifa=12𝒪(nt1a)if12<a<1𝒪(nlogt)ifa=1|\frac{\min_{i\in[n],c\neq y_{i}}(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}{\bm{W}}_{t}\bm{h}_{i}}{{\left\|{\bm{W}}_{t}\right\|_{\max}}}-\gamma|=\left\{\begin{array}[]{ll}\mathcal{O}(\frac{t^{1-2a}+n}{t^{1-a}})&\text{if}\quad a<\frac{1}{2}\\ \mathcal{O}(\frac{\log t+n}{t^{1/2}})&\text{if}\quad a=\frac{1}{2}\\ \mathcal{O}(\frac{n}{t^{1-a}})&\text{if}\quad\frac{1}{2}<a<1\\ \mathcal{O}(\frac{n}{\log t})&\text{if}\quad a=1\end{array}\right.
Proof.

Recall that ts1=(4B2e2Bη0γ)1a=:Cs1t_{s_{1}}=(\frac{4B^{2}e^{2B\eta_{0}}}{\gamma})^{\frac{1}{a}}=:C_{s_{1}}, and the condition on ts2t_{s_{2}} is s=ts1ts2ηs4(𝑾0)+8Bs=0ts11ηsγ~\sum_{s=t_{s_{1}}}^{t_{s_{2}}}\eta_{s}\geq\frac{4\mathcal{L}({\bm{W}}_{0})+8B\sum_{s=0}^{t_{s_{1}}-1}\eta_{s}}{\gamma\tilde{\mathcal{L}}}, where L~=log2n\tilde{L}=\frac{\log 2}{n}. We can apply integral approximations to the terms that involve sums of learning rates to obtain

ts2Cs2n11ats1+Cs3n11a(𝑾0)11a.\displaystyle t_{s_{2}}\leq C_{s_{2}}n^{\frac{1}{1-a}}t_{s_{1}}+C_{s_{3}}n^{\frac{1}{1-a}}\mathcal{L}({\bm{W}}_{0})^{\frac{1}{1-a}}.

Given ts1t_{s_{1}} is some constant, this further implies that

s=0ts21ηs=𝒪(ts21a)=𝒪(n+n(𝑾0)).\displaystyle\sum_{s=0}^{t_{s_{2}}-1}\eta_{s}=\mathcal{O}(t_{s_{2}}^{1-a})=\mathcal{O}(n+n\mathcal{L}({\bm{W}}_{0})).

Next, we focus on the term s=ts2t1ηs2\sum_{s=t_{s_{2}}}^{t-1}\eta_{s}^{2}. For a>12a>\frac{1}{2}, this term can be bounded by some constant. For a<12a<\frac{1}{2}, we have s=ts2t1ηs2=𝒪(t12a)\sum_{s=t_{s_{2}}}^{t-1}\eta_{s}^{2}=\mathcal{O}(t^{1-2a}), and it evaluates to 𝒪(logt)\mathcal{O}(\log t) for a=12a=\frac{1}{2}. Finally, we have that s=0t1ηs=𝒪(t1a)\sum_{s=0}^{t-1}\eta_{s}=\mathcal{O}(t^{1-a}) for a<1a<1 and s=0t1ηs=𝒪(logt)\sum_{s=0}^{t-1}\eta_{s}=\mathcal{O}(\log t) for a=1a=1. The remaining arguments can be found in Zhang et al. (2024, Corollary 4.7 and Lemma C.1), including showing that the learning rate schedule ηt=1(t+2)a\eta_{t}=\frac{1}{(t+2)^{a}} satisfies Assumption 3.4 and that the term s=ts2t1ηseγ4τ=ts2s1ητ\sum_{s=t_{s_{2}}}^{t-1}\eta_{s}e^{-\frac{\gamma}{4}\sum_{\tau=t_{s_{2}}}^{s-1}\eta_{\tau}} is bounded by some constant for all a(0,1]a\in(0,1]. ∎

Appendix D Proof of Adam

The proof for Adam follows a similar approach to that of SignGD. The key challenge is to connect 𝑴t\bm{M}_{t} and 𝑽t{\bm{V}}_{t} to a per-class decomposition of 𝒢(𝑾t){\mathcal{G}}({\bm{W}}_{t}). The following lemma from Zhang et al. (2024, Lemma 6.5) is useful. It provides an entry-wise bound on the ratio between the first moment and the square root of the second moment.

Lemma D.1.

Considering the Adam updates given in (3a), (3b), and (3c), suppose that β1β2\beta_{1}\leq\beta_{2} and set α=β2(1β1)2(1β2)(β2β12)2\alpha=\sqrt{\frac{\beta_{2}(1-\beta_{1})^{2}}{(1-\beta_{2})(\beta_{2}-\beta_{1}^{2})^{2}}}, then we obtain |𝐌t[c,j]|α𝐕t[c,j]|\bm{M}_{t}[c,j]|\leq\alpha\cdot\sqrt{{\bm{V}}_{t}[c,j]} for all c[k]c\in[k] and j[d]j\in[d].
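
For intuition, the entry-wise bound of Lemma D.1 can be checked on arbitrary scalar gradient trajectories. The Python sketch below mirrors the exponential moving averages of (3a) and (3b) on one coordinate, uses the stated value of α, and is illustrative only.

import numpy as np

rng = np.random.default_rng(6)
beta1, beta2 = 0.9, 0.99
alpha = np.sqrt(beta2 * (1 - beta1) ** 2 / ((1 - beta2) * (beta2 - beta1 ** 2) ** 2))

for _ in range(100):
    g = rng.normal(size=200)          # an arbitrary gradient trajectory for a single entry
    m = v = 0.0
    for gt in g:
        m = beta1 * m + (1 - beta1) * gt          # first-moment buffer, as in (3a)
        v = beta2 * v + (1 - beta2) * gt ** 2     # second-moment buffer, as in (3b)
        assert abs(m) <= alpha * np.sqrt(v) + 1e-12
print("entry-wise moment ratio bound holds on all sampled trajectories")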

The following Lemma bounds the first moment buffer (𝑴t\bm{M}_{t}) of Adam in terms of the product of ηt\eta_{t} with 𝒢c(𝑾t){\mathcal{G}}_{c}({\bm{W}}_{t}) and 𝒬c(𝑾t)\mathcal{Q}_{c}({\bm{W}}_{t}). It is used in the proof of Lemma D.4.

Lemma D.2.

Let c[k]c\in[k]. Under the same setting as Theorem D.8, there exists a time t0t_{0} such that the following holds for all tt0t\geq t_{0}

|𝐌t[c,j](1β1t+1)(𝑾t)[c,j]|\displaystyle|\mathbf{M}_{t}[c,j]-(1-\beta_{1}^{t+1})\nabla\mathcal{L}({\bm{W}}_{t})[c,j]| αMηt(𝒢c(𝑾t)+𝒬c(𝑾t)),\displaystyle\leq\alpha_{M}\eta_{t}({\mathcal{G}}_{c}({\bm{W}}_{t})+\mathcal{Q}_{c}({\bm{W}}_{t})),

where j[d]j\in[d] and αM\alpha_{M} is some constant that depends on BB and β1\beta_{1}.

Proof.

For any fixed c[k]c\in[k] and j[d]j\in[d],

|𝐌t[c,j](1β1t+1)(𝑾t)[c,j]|\displaystyle|\mathbf{M}_{t}[c,j]-(1-\beta_{1}^{t+1})\nabla\mathcal{L}({\bm{W}}_{t})[c,j]| =|τ=0t(1β1)β1τ((𝑾tτ)[c,j](𝑾t)[c,j])|\displaystyle=|\sum_{\tau=0}^{t}(1-\beta_{1})\beta_{1}^{\tau}\bigl{(}\nabla\mathcal{L}({\bm{W}}_{t-\tau})[c,j]-\nabla\mathcal{L}({\bm{W}}_{t})[c,j]\bigr{)}|
τ=0t(1β1)β1τ|(𝑾tτ)[c,j](𝑾t)[c,j]|.\displaystyle\leq\sum_{\tau=0}^{t}(1-\beta_{1})\beta_{1}^{\tau}\underbrace{|\nabla\mathcal{L}({\bm{W}}_{t-\tau})[c,j]-\nabla\mathcal{L}({\bm{W}}_{t})[c,j]|}_{\clubsuit}. (22)

We first notice that for any 𝑾k×d{\bm{W}}\in\mathbb{R}^{k\times d}, we have (𝑾)[c,j]=𝒆cT(𝑾)𝒆j=1ni[n]𝒆cT(𝒆yi𝕊(𝑾𝒉i))𝒉iT𝒆j=1ni[n]𝒆cT(𝒆yi𝕊(𝑾𝒉i))hij\nabla\mathcal{L}({\bm{W}})[c,j]=\bm{e}_{c}^{T}\nabla\mathcal{L}({\bm{W}})\bm{e}_{j}=-\frac{1}{n}\sum_{i\in[n]}\bm{e}_{c}^{T}\bigl{(}\bm{e}_{y_{i}}-\mathbb{S}({\bm{W}}\bm{h}_{i})\bigr{)}\bm{h}_{i}^{T}\bm{e}_{j}=-\frac{1}{n}\sum_{i\in[n]}\bm{e}_{c}^{T}\bigl{(}\bm{e}_{y_{i}}-\mathbb{S}({\bm{W}}\bm{h}_{i})\bigr{)}h_{ij}. Then, the gradient difference term becomes

\displaystyle\clubsuit =|1ni[n]𝒆cT(𝒆yi𝕊(𝑾tτ𝒉i))hij+1ni[n]𝒆cT(𝒆yi𝕊(𝑾t𝒉i))hij|\displaystyle=|-\frac{1}{n}\sum_{i\in[n]}\bm{e}_{c}^{T}\bigl{(}\bm{e}_{y_{i}}-\mathbb{S}({\bm{W}}_{t-\tau}\bm{h}_{i})\bigr{)}h_{ij}+\frac{1}{n}\sum_{i\in[n]}\bm{e}_{c}^{T}\bigl{(}\bm{e}_{y_{i}}-\mathbb{S}({\bm{W}}_{t}\bm{h}_{i})\bigr{)}h_{ij}|
=|1ni[n]𝒆cT(𝕊(𝑾tτ𝒉i)𝕊(𝑾t𝒉i))hij|\displaystyle=|\frac{1}{n}\sum_{i\in[n]}\bm{e}_{c}^{T}\bigl{(}\mathbb{S}({\bm{W}}_{t-\tau}\bm{h}_{i})-\mathbb{S}({\bm{W}}_{t}\bm{h}_{i})\bigr{)}h_{ij}|
=|1ni[n](𝕊c(𝑾tτ𝒉i)𝕊c(𝑾t𝒉i))hij|\displaystyle=|\frac{1}{n}\sum_{i\in[n]}\bigl{(}\mathbb{S}_{c}({\bm{W}}_{t-\tau}\bm{h}_{i})-\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{i})\bigr{)}h_{ij}|
B1ni[n]|𝕊c(𝑾tτ𝒉i)𝕊c(𝑾t𝒉i)|\displaystyle\leq B\frac{1}{n}\sum_{i\in[n]}|\mathbb{S}_{c}({\bm{W}}_{t-\tau}\bm{h}_{i})-\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{i})|
=B1ni[n],yic|𝕊c(𝑾tτ𝒉i)𝕊c(𝑾t𝒉i)|1+B1ni[n],yi=c|𝕊c(𝑾tτ𝒉i)𝕊c(𝑾t𝒉i)|2\displaystyle=B\underbrace{\frac{1}{n}\sum_{i\in[n],y_{i}\neq c}|\mathbb{S}_{c}({\bm{W}}_{t-\tau}\bm{h}_{i})-\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{i})|}_{\clubsuit_{1}}+B\underbrace{\frac{1}{n}\sum_{i\in[n],y_{i}=c}|\mathbb{S}_{c}({\bm{W}}_{t-\tau}\bm{h}_{i})-\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{i})|}_{\clubsuit_{2}}

Next, we link the 1\clubsuit_{1} and 2\clubsuit_{2} terms with 𝒢(𝑾){\mathcal{G}}({\bm{W}}). Starting with the first term, we obtain:

1\displaystyle\clubsuit_{1} =1ni[n],yic𝕊c(𝑾t𝒉i)|𝕊c(𝑾tτ𝒉i)𝕊c(𝑾t𝒉i)1|\displaystyle=\frac{1}{n}\sum_{i\in[n],y_{i}\neq c}\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{i})|\frac{\mathbb{S}_{c}({\bm{W}}_{t-\tau}\bm{h}_{i})}{\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{i})}-1|
(a)1ni[n],yic𝕊c(𝑾t𝒉i)(e2(𝑾tτ𝑾t)𝒉i1)\displaystyle\stackrel{{\scriptstyle(a)}}{{\leq}}\frac{1}{n}\sum_{i\in[n],y_{i}\neq c}\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{i})(e^{2\lVert({\bm{W}}_{t-\tau}-{\bm{W}}_{t})\bm{h}_{i}\rVert_{\infty}}-1)
(b)1ni[n],yic𝕊c(𝑾t𝒉i)(e2B𝑾tτ𝑾tmax1)\displaystyle\stackrel{{\scriptstyle(b)}}{{\leq}}\frac{1}{n}\sum_{i\in[n],y_{i}\neq c}\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{i})(e^{2B{\left\|{\bm{W}}_{t-\tau}-{\bm{W}}_{t}\right\|_{\max}}}-1)
(c)(e2Bs=1τηts𝑴ts𝑽tsmax1)(1ni[n],yic𝕊c(𝑾t𝒉i))\displaystyle\stackrel{{\scriptstyle(c)}}{{\leq}}\bigl{(}e^{2B\sum_{s=1}^{\tau}\eta_{t-s}{\left\|\frac{\bm{M}_{t-s}}{\sqrt{{\bm{V}}_{t-s}}}\right\|_{\max}}}-1\bigr{)}\bigl{(}\frac{1}{n}\sum_{i\in[n],y_{i}\neq c}\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{i})\bigr{)}
(d)(e2αBs=1τηts1)𝒬c(𝑾t),\displaystyle\stackrel{{\scriptstyle(d)}}{{\leq}}\bigl{(}e^{2\alpha B\sum_{s=1}^{\tau}\eta_{t-s}}-1\bigr{)}\mathcal{Q}_{c}({\bm{W}}_{t}),

where (a) is by Lemma D.10, (b) is by 𝒉i1B\lVert\bm{h}_{i}\rVert_{1}\leq B for all i[n]i\in[n], (c) is by (3c) and the triangle inequality, and (d) is by Lemma D.1 and the definition of 𝒬c(𝑾t)\mathcal{Q}_{c}({\bm{W}}_{t}). For the second term, we obtain:

2\displaystyle\clubsuit_{2} =1ni[n],yi=c|𝕊c(𝑾tτ𝒉i)𝕊c(𝑾t𝒉i)|\displaystyle=\frac{1}{n}\sum_{i\in[n],y_{i}=c}|\mathbb{S}_{c}({\bm{W}}_{t-\tau}\bm{h}_{i})-\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{i})|
=1ni[n],yi=c|𝕊yi(𝑾tτ𝒉i)1+1𝕊yi(𝑾t𝒉i)|\displaystyle=\frac{1}{n}\sum_{i\in[n],y_{i}=c}|\mathbb{S}_{y_{i}}({\bm{W}}_{t-\tau}\bm{h}_{i})-1+1-\mathbb{S}_{y_{i}}({\bm{W}}_{t}\bm{h}_{i})|
=1ni[n],yi=c(1𝕊yi(𝑾t𝒉i))|𝕊yi(𝑾tτ𝒉i)11𝕊yi(𝑾t𝒉i)+1|\displaystyle=\frac{1}{n}\sum_{i\in[n],y_{i}=c}\bigl{(}1-\mathbb{S}_{y_{i}}({\bm{W}}_{t}\bm{h}_{i})\bigr{)}|\frac{\mathbb{S}_{y_{i}}({\bm{W}}_{t-\tau}\bm{h}_{i})-1}{1-\mathbb{S}_{y_{i}}({\bm{W}}_{t}\bm{h}_{i})}+1|
=1ni[n],yi=c(1𝕊yi(𝑾t𝒉i))|1𝕊yi(𝑾tτ𝒉i)1𝕊yi(𝑾t𝒉i)1|\displaystyle=\frac{1}{n}\sum_{i\in[n],y_{i}=c}\bigl{(}1-\mathbb{S}_{y_{i}}({\bm{W}}_{t}\bm{h}_{i})\bigr{)}|\frac{1-\mathbb{S}_{y_{i}}({\bm{W}}_{t-\tau}\bm{h}_{i})}{1-\mathbb{S}_{y_{i}}({\bm{W}}_{t}\bm{h}_{i})}-1|
(e)1ni[n],yi=c(1𝕊yi(𝑾t𝒉i))(e2(𝑾tτ𝑾t)𝒉i1)\displaystyle\stackrel{{\scriptstyle(e)}}{{\leq}}\frac{1}{n}\sum_{i\in[n],y_{i}=c}\bigl{(}1-\mathbb{S}_{y_{i}}({\bm{W}}_{t}\bm{h}_{i})\bigr{)}(e^{2\lVert({\bm{W}}_{t-\tau}-{\bm{W}}_{t})\bm{h}_{i}\rVert_{\infty}}-1)
(f)(e2αBs=1τηts1)𝒢c(𝑾t),\displaystyle\stackrel{{\scriptstyle(f)}}{{\leq}}\bigl{(}e^{2\alpha B\sum_{s=1}^{\tau}\eta_{t-s}}-1\bigr{)}{\mathcal{G}}_{c}({\bm{W}}_{t}),

where (e) is by Lemma D.10, and (f) is by the same approach taken for 1\clubsuit_{1}. Based on the upper bounds for 1\clubsuit_{1} and 2\clubsuit_{2}, we obtain the following: B(e2αBs=1τηts1)(𝒢c(𝑾t)+𝒬c(𝑾t))\clubsuit\leq B\bigl{(}e^{2\alpha B\sum_{s=1}^{\tau}\eta_{t-s}}-1\bigr{)}({\mathcal{G}}_{c}({\bm{W}}_{t})+\mathcal{Q}_{c}({\bm{W}}_{t})). Then, we substitute this into (22) to obtain:

|𝐌t[c,j](1β1t+1)(𝑾t)[c,j]|\displaystyle|\mathbf{M}_{t}[c,j]-(1-\beta_{1}^{t+1})\nabla\mathcal{L}({\bm{W}}_{t})[c,j]| B(1β1)(𝒢c(𝑾t)+𝒬c(𝑾t))τ=0tβ1τ(e2αBs=1τηts1)\displaystyle\leq B(1-\beta_{1})({\mathcal{G}}_{c}({\bm{W}}_{t})+\mathcal{Q}_{c}({\bm{W}}_{t}))\sum_{\tau=0}^{t}\beta_{1}^{\tau}\bigl{(}e^{2\alpha B\sum_{s=1}^{\tau}\eta_{t-s}}-1\bigr{)}
(g)B(1β1)c2ηt(𝒢c(𝑾t)+𝒬c(𝑾t)),\displaystyle\stackrel{{\scriptstyle(g)}}{{\leq}}B(1-\beta_{1})c_{2}\eta_{t}({\mathcal{G}}_{c}({\bm{W}}_{t})+\mathcal{Q}_{c}({\bm{W}}_{t})),

where (g) is by Assumption 3.4. ∎

The following lemma bounds the second moment buffer (𝑽t{\bm{V}}_{t}) of Adam in terms of the product of ηt\sqrt{\eta_{t}} with 𝒢c(𝑾t){\mathcal{G}}_{c}({\bm{W}}_{t}) and 𝒬c(𝑾t)\mathcal{Q}_{c}({\bm{W}}_{t}). It is used in the proof of Lemma D.4.

Lemma D.3.

Let c[k]c\in[k]. Under the same setting as Theorem D.8, there exists a time t0t_{0} such that the following holds for all tt0t\geq t_{0}

|𝑽t[c,j](1β2t+1)|(𝑾t)[c,j]||\displaystyle\bigm{|}\sqrt{{\bm{V}}_{t}[c,j]}-\sqrt{(1-\beta_{2}^{t+1})}|\nabla\mathcal{L}({\bm{W}}_{t})[c,j]|\bigm{|} αVηt(𝒬c(𝑾t)+𝒢c(𝑾t)),\displaystyle\leq\alpha_{V}\sqrt{\eta_{t}}(\mathcal{Q}_{c}({\bm{W}}_{t})+{\mathcal{G}}_{c}({\bm{W}}_{t})),

where j[d]j\in[d], and αV\alpha_{V} is some constant that depends on BB and β2\beta_{2}.

Proof.

Consider any fixed c[k]c\in[k] and j[d]j\in[d],

|𝑽t[c,j](1β2t+1)(𝑾t)[c,j]2|\displaystyle|{\bm{V}}_{t}[c,j]-(1-\beta_{2}^{t+1})\nabla\mathcal{L}({\bm{W}}_{t})[c,j]^{2}| =|τ=0t(1β2)β2τ((𝑾tτ)[c,j]2(𝑾t)[c,j]2)|\displaystyle=|\sum_{\tau=0}^{t}(1-\beta_{2})\beta_{2}^{\tau}\bigl{(}\nabla\mathcal{L}({\bm{W}}_{t-\tau})[c,j]^{2}-\nabla\mathcal{L}({\bm{W}}_{t})[c,j]^{2}\bigr{)}|
τ=0t(1β2)β2τ|(𝑾tτ)[c,j]2(𝑾t)[c,j]2|.\displaystyle\leq\sum_{\tau=0}^{t}(1-\beta_{2})\beta_{2}^{\tau}\underbrace{|\nabla\mathcal{L}({\bm{W}}_{t-\tau})[c,j]^{2}-\nabla\mathcal{L}({\bm{W}}_{t})[c,j]^{2}|}_{\spadesuit}. (23)

For any 𝑾k×d{\bm{W}}\in\mathbb{R}^{k\times d}, recall that (𝑾)[c,j]=1ni[n]𝒆cT(𝒆yi𝕊(𝑾𝒉i))hij\nabla\mathcal{L}({\bm{W}})[c,j]=-\frac{1}{n}\sum_{i\in[n]}\bm{e}_{c}^{T}\bigl{(}\bm{e}_{y_{i}}-\mathbb{S}({\bm{W}}\bm{h}_{i})\bigr{)}h_{ij}. Then, we can obtain (𝑾)[c,j]2=1n2i[n]p[n]hijhpj(δcyi𝕊c(𝑾𝒉i))(δcyp𝕊c(𝑾𝒉p))\nabla\mathcal{L}({\bm{W}})[c,j]^{2}=\frac{1}{n^{2}}\sum_{i\in[n]}\sum_{p\in[n]}h_{ij}h_{pj}(\delta_{cy_{i}}-\mathbb{S}_{c}({\bm{W}}\bm{h}_{i}))(\delta_{cy_{p}}-\mathbb{S}_{c}({\bm{W}}\bm{h}_{p})) where δcy=1\delta_{cy}=1 if and only if c=yc=y. Next, we define the function fc,i,pf_{c,i,p} to be fc,i,p(𝑾):=(δcyi𝕊c(𝑾𝒉i))(δcyp𝕊c(𝑾𝒉p))f_{c,i,p}({\bm{W}}):=(\delta_{cy_{i}}-\mathbb{S}_{c}({\bm{W}}\bm{h}_{i}))(\delta_{cy_{p}}-\mathbb{S}_{c}({\bm{W}}\bm{h}_{p})). Then, we have

fc,i,p(𝑾tτ)fc,i,p(𝑾t)\displaystyle f_{c,i,p}({\bm{W}}_{t-\tau})-f_{c,i,p}({\bm{W}}_{t}) =δcyi(𝕊c(𝑾t𝒉p)𝕊c(𝑾tτ𝒉p))+δcyp(𝕊c(𝑾t𝒉i)𝕊c(𝑾tτ𝒉i))\displaystyle=\delta_{cy_{i}}\bigl{(}\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{p})-\mathbb{S}_{c}({\bm{W}}_{t-\tau}\bm{h}_{p})\bigr{)}+\delta_{cy_{p}}\bigl{(}\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{i})-\mathbb{S}_{c}({\bm{W}}_{t-\tau}\bm{h}_{i})\bigr{)}
+(𝕊c(𝑾tτ𝒉i)𝕊c(𝑾tτ𝒉p)𝕊c(𝑾t𝒉i)𝕊c(𝑾t𝒉p))\displaystyle\quad\quad+\bigl{(}\mathbb{S}_{c}({\bm{W}}_{t-\tau}\bm{h}_{i})\mathbb{S}_{c}({\bm{W}}_{t-\tau}\bm{h}_{p})-\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{i})\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{p})\bigr{)}

We can substitute this result into \spadesuit to obtain

\displaystyle\spadesuit =|1n2i[n]p[n]hijhpj(fc,i,p(𝑾tτ)fc,i,p(𝑾t))|\displaystyle=|\frac{1}{n^{2}}\sum_{i\in[n]}\sum_{p\in[n]}h_{ij}h_{pj}(f_{c,i,p}({\bm{W}}_{t-\tau})-f_{c,i,p}({\bm{W}}_{t}))|
B2n2i[n]p[n]|fc,i,p(𝑾tτ)fc,i,p(𝑾t)|\displaystyle\leq\frac{B^{2}}{n^{2}}\sum_{i\in[n]}\sum_{p\in[n]}|f_{c,i,p}({\bm{W}}_{t-\tau})-f_{c,i,p}({\bm{W}}_{t})|
=B21n2i[n],yicp[n],ypc|fc,i,p(𝑾tτ)fc,i,p(𝑾t)|1\displaystyle=B^{2}\underbrace{\frac{1}{n^{2}}\sum_{i\in[n],y_{i}\neq c}\sum_{p\in[n],y_{p}\neq c}|f_{c,i,p}({\bm{W}}_{t-\tau})-f_{c,i,p}({\bm{W}}_{t})|}_{\spadesuit_{1}}
+B21n2i[n],yicp[n],yp=c|fc,i,p(𝑾tτ)fc,i,p(𝑾t)|2\displaystyle\quad\quad+B^{2}\underbrace{\frac{1}{n^{2}}\sum_{i\in[n],y_{i}\neq c}\sum_{p\in[n],y_{p}=c}|f_{c,i,p}({\bm{W}}_{t-\tau})-f_{c,i,p}({\bm{W}}_{t})|}_{\spadesuit_{2}}
+B21n2i[n],yi=cp[n],ypc|fc,i,p(𝑾tτ)fc,i,p(𝑾t)|3\displaystyle\quad\quad+B^{2}\underbrace{\frac{1}{n^{2}}\sum_{i\in[n],y_{i}=c}\sum_{p\in[n],y_{p}\neq c}|f_{c,i,p}({\bm{W}}_{t-\tau})-f_{c,i,p}({\bm{W}}_{t})|}_{\spadesuit_{3}}
+B21n2i[n],yi=cp[n],yp=c|fc,i,p(𝑾tτ)fc,i,p(𝑾t)|4\displaystyle\quad\quad+B^{2}\underbrace{\frac{1}{n^{2}}\sum_{i\in[n],y_{i}=c}\sum_{p\in[n],y_{p}=c}|f_{c,i,p}({\bm{W}}_{t-\tau})-f_{c,i,p}({\bm{W}}_{t})|}_{\spadesuit_{4}}

We deal with the 44 terms 1,2,3\spadesuit_{1},\spadesuit_{2},\spadesuit_{3}, and 4\spadesuit_{4} separately. Starting with the first term, we have

1\displaystyle\spadesuit_{1} =1n2i[n],yicp[n],ypc|𝕊c(𝑾tτ𝒉i)𝕊c(𝑾tτ𝒉p)𝕊c(𝑾t𝒉i)𝕊c(𝑾t𝒉p)|\displaystyle=\frac{1}{n^{2}}\sum_{i\in[n],y_{i}\neq c}\sum_{p\in[n],y_{p}\neq c}|\mathbb{S}_{c}({\bm{W}}_{t-\tau}\bm{h}_{i})\mathbb{S}_{c}({\bm{W}}_{t-\tau}\bm{h}_{p})-\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{i})\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{p})|
=1n2i[n],yicp[n],ypc𝕊c(𝑾t𝒉i)𝕊c(𝑾t𝒉p)|𝕊c(𝑾tτ𝒉i)𝕊c(𝑾tτ𝒉p)𝕊c(𝑾t𝒉i)𝕊c(𝑾t𝒉p)1|\displaystyle=\frac{1}{n^{2}}\sum_{i\in[n],y_{i}\neq c}\sum_{p\in[n],y_{p}\neq c}\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{i})\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{p})|\frac{\mathbb{S}_{c}({\bm{W}}_{t-\tau}\bm{h}_{i})\mathbb{S}_{c}({\bm{W}}_{t-\tau}\bm{h}_{p})}{\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{i})\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{p})}-1|
(a)1n2i[n],yicp[n],ypc𝕊c(𝑾t𝒉i)𝕊c(𝑾t𝒉p)(e2((𝑾tτ𝑾t)𝒉i+(𝑾tτ𝑾t)𝒉p)1)\displaystyle\stackrel{{\scriptstyle(a)}}{{\leq}}\frac{1}{n^{2}}\sum_{i\in[n],y_{i}\neq c}\sum_{p\in[n],y_{p}\neq c}\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{i})\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{p})\bigl{(}e^{2\bigl{(}\lVert({\bm{W}}_{t-\tau}-{\bm{W}}_{t})\bm{h}_{i}\rVert_{\infty}+\lVert({\bm{W}}_{t-\tau}-{\bm{W}}_{t})\bm{h}_{p}\rVert_{\infty}\bigr{)}}-1\bigr{)}
(b)1n2i[n],yicp[n],ypc𝕊c(𝑾t𝒉i)𝕊c(𝑾t𝒉p)(e4B𝑾tτ𝑾tmax1)\displaystyle\stackrel{{\scriptstyle(b)}}{{\leq}}\frac{1}{n^{2}}\sum_{i\in[n],y_{i}\neq c}\sum_{p\in[n],y_{p}\neq c}\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{i})\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{p})\bigl{(}e^{4B{\left\|{\bm{W}}_{t-\tau}-{\bm{W}}_{t}\right\|_{\max}}}-1\bigr{)}
(c)(e4Bs=1τηts𝑴ts𝑽tsmax1)1n2i[n],yicp[n],ypc𝕊c(𝑾t𝒉i)𝕊c(𝑾t𝒉p)\displaystyle\stackrel{{\scriptstyle(c)}}{{\leq}}\bigl{(}e^{4B\sum_{s=1}^{\tau}\eta_{t-s}{\left\|\frac{\bm{M}_{t-s}}{\sqrt{{\bm{V}}_{t-s}}}\right\|_{\max}}}-1\bigr{)}\frac{1}{n^{2}}\sum_{i\in[n],y_{i}\neq c}\sum_{p\in[n],y_{p}\neq c}\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{i})\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{p})
(d)(e4Bαs=1τηts1)𝒬c(𝑾t)2,\displaystyle\stackrel{{\scriptstyle(d)}}{{\leq}}\bigl{(}e^{4B\alpha\sum_{s=1}^{\tau}\eta_{t-s}}-1\bigr{)}\mathcal{Q}_{c}({\bm{W}}_{t})^{2},

where (a) is by Lemma D.10, (b) is by $\lVert\bm{h}_{i}\rVert_{1}\leq B$ for all $i\in[n]$, (c) is by (3c) and the triangle inequality, and (d) is by Lemma D.1 and the definition of $\mathcal{Q}_{c}({\bm{W}}_{t})$. For the second term, we have

2\displaystyle\spadesuit_{2} =1n2i[n],yicp[n],yp=c|(𝕊c(𝑾t𝒉i)𝕊c(𝑾tτ𝒉i))\displaystyle=\frac{1}{n^{2}}\sum_{i\in[n],y_{i}\neq c}\sum_{p\in[n],y_{p}=c}|\bigl{(}\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{i})-\mathbb{S}_{c}({\bm{W}}_{t-\tau}\bm{h}_{i})\bigr{)}
+(𝕊c(𝑾tτ𝒉i)𝕊c(𝑾tτ𝒉p)𝕊c(𝑾t𝒉i)𝕊c(𝑾t𝒉p))|\displaystyle\quad\quad\quad\quad+\bigl{(}\mathbb{S}_{c}({\bm{W}}_{t-\tau}\bm{h}_{i})\mathbb{S}_{c}({\bm{W}}_{t-\tau}\bm{h}_{p})-\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{i})\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{p})\bigr{)}|
=1n2i[n],yicp[n],yp=c|𝕊c(𝑾t𝒉i)(1𝕊c(𝑾t𝒉p))(1𝕊c(𝑾tτ𝒉p))𝕊c(𝑾tτ𝒉i)|\displaystyle=\frac{1}{n^{2}}\sum_{i\in[n],y_{i}\neq c}\sum_{p\in[n],y_{p}=c}|\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{i})\bigl{(}1-\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{p})\bigr{)}-\bigl{(}1-\mathbb{S}_{c}({\bm{W}}_{t-\tau}\bm{h}_{p})\bigr{)}\mathbb{S}_{c}({\bm{W}}_{t-\tau}\bm{h}_{i})|
=1n2i[n],yicp[n],yp=c𝕊c(𝑾t𝒉i)(1𝕊c(𝑾t𝒉p))|1(1𝕊c(𝑾tτ𝒉p))𝕊c(𝑾tτ𝒉i)(1𝕊c(𝑾t𝒉p))𝕊c(𝑾t𝒉i)|\displaystyle=\frac{1}{n^{2}}\sum_{i\in[n],y_{i}\neq c}\sum_{p\in[n],y_{p}=c}\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{i})\bigl{(}1-\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{p})\bigr{)}|1-\frac{\bigl{(}1-\mathbb{S}_{c}({\bm{W}}_{t-\tau}\bm{h}_{p})\bigr{)}\mathbb{S}_{c}({\bm{W}}_{t-\tau}\bm{h}_{i})}{\bigl{(}1-\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{p})\bigr{)}\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{i})}|
(e4Bαs=1τηts1)𝒬c(𝑾t)𝒢c(𝑾t),\displaystyle\leq\bigl{(}e^{4B\alpha\sum_{s=1}^{\tau}\eta_{t-s}}-1\bigr{)}\mathcal{Q}_{c}({\bm{W}}_{t}){\mathcal{G}}_{c}({\bm{W}}_{t}),

where the last inequality is by Lemma D.10 and the same steps taken for $\spadesuit_{1}$. The third term can be bounded in the same way as the second, and we obtain the same bound:

\displaystyle\spadesuit_{3} =\frac{1}{n^{2}}\sum_{i\in[n],y_{i}=c}\sum_{p\in[n],y_{p}\neq c}\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{p})\bigl{(}1-\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{i})\bigr{)}\Bigl{|}1-\frac{\bigl{(}1-\mathbb{S}_{c}({\bm{W}}_{t-\tau}\bm{h}_{i})\bigr{)}\mathbb{S}_{c}({\bm{W}}_{t-\tau}\bm{h}_{p})}{\bigl{(}1-\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{i})\bigr{)}\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{p})}\Bigr{|}
(e4Bαs=1τηts1)𝒬c(𝑾t)𝒢c(𝑾t).\displaystyle\leq\bigl{(}e^{4B\alpha\sum_{s=1}^{\tau}\eta_{t-s}}-1\bigr{)}\mathcal{Q}_{c}({\bm{W}}_{t}){\mathcal{G}}_{c}({\bm{W}}_{t}).

For the fourth term, we obtain:

4\displaystyle\spadesuit_{4} =1n2i[n],yi=cp[n],yp=c|(1𝕊c(𝑾tτ𝒉i))(1𝕊c(𝑾tτ𝒉p))(1𝕊c(𝑾t𝒉i))(1𝕊c(𝑾t𝒉p))|\displaystyle=\frac{1}{n^{2}}\sum_{i\in[n],y_{i}=c}\sum_{p\in[n],y_{p}=c}|\bigl{(}1-\mathbb{S}_{c}({\bm{W}}_{t-\tau}\bm{h}_{i})\bigr{)}\bigl{(}1-\mathbb{S}_{c}({\bm{W}}_{t-\tau}\bm{h}_{p})\bigr{)}-\bigl{(}1-\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{i})\bigr{)}\bigl{(}1-\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{p})\bigr{)}|
=1n2i[n],yi=cp[n],yp=c(1𝕊c(𝑾t𝒉i))(1𝕊c(𝑾t𝒉p))|(1𝕊c(𝑾tτ𝒉i))(1𝕊c(𝑾tτ𝒉p))(1𝕊c(𝑾t𝒉i))(1𝕊c(𝑾t𝒉p))1|\displaystyle=\frac{1}{n^{2}}\sum_{i\in[n],y_{i}=c}\sum_{p\in[n],y_{p}=c}\bigl{(}1-\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{i})\bigr{)}\bigl{(}1-\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{p})\bigr{)}|\frac{\bigl{(}1-\mathbb{S}_{c}({\bm{W}}_{t-\tau}\bm{h}_{i})\bigr{)}\bigl{(}1-\mathbb{S}_{c}({\bm{W}}_{t-\tau}\bm{h}_{p})\bigr{)}}{\bigl{(}1-\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{i})\bigr{)}\bigl{(}1-\mathbb{S}_{c}({\bm{W}}_{t}\bm{h}_{p})\bigr{)}}-1|
(e4Bαs=1τηts1)𝒢c(𝑾t)2,\displaystyle\leq\bigl{(}e^{4B\alpha\sum_{s=1}^{\tau}\eta_{t-s}}-1\bigr{)}{\mathcal{G}}_{c}({\bm{W}}_{t})^{2},

where the last inequality is by Lemma D.10 and the same steps taken for $\spadesuit_{1}$. Combining the bounds for $\spadesuit_{1}$, $\spadesuit_{2}$, $\spadesuit_{3}$, and $\spadesuit_{4}$, we obtain $\spadesuit\leq B^{2}\bigl{(}e^{4B\alpha\sum_{s=1}^{\tau}\eta_{t-s}}-1\bigr{)}({\mathcal{G}}_{c}({\bm{W}}_{t})+\mathcal{Q}_{c}({\bm{W}}_{t}))^{2}$. Then, we substitute this into (23) to obtain:

|𝑽t[c,j](1β2t+1)(𝑾t)[c,j]2|\displaystyle|{\bm{V}}_{t}[c,j]-(1-\beta_{2}^{t+1})\nabla\mathcal{L}({\bm{W}}_{t})[c,j]^{2}| B2(1β2)(𝒬c(𝑾t)+𝒢c(𝑾t))2τ=0tβ2τ(e4αBs=1τηts1)\displaystyle\leq B^{2}(1-\beta_{2})(\mathcal{Q}_{c}({\bm{W}}_{t})+{\mathcal{G}}_{c}({\bm{W}}_{t}))^{2}\sum_{\tau=0}^{t}\beta_{2}^{\tau}\bigl{(}e^{4\alpha B\sum_{s=1}^{\tau}\eta_{t-s}}-1\bigr{)}
B2(1β2)c2ηt(𝒬c(𝑾t)+𝒢c(𝑾t))2,\displaystyle\leq B^{2}(1-\beta_{2})c_{2}\eta_{t}(\mathcal{Q}_{c}({\bm{W}}_{t})+{\mathcal{G}}_{c}({\bm{W}}_{t}))^{2},

where the last inequality is by Assumption 3.4. The final result follows from the fact that $|p-q|^{2}\leq|p^{2}-q^{2}|$ when both $p$ and $q$ are non-negative. ∎

The following lemma bounds the term $\bigl{|}\langle\nabla\mathcal{L}({\bm{W}}_{t}),\frac{\bm{M}_{t}}{\sqrt{{\bm{V}}_{t}}}-\frac{\nabla\mathcal{L}({\bm{W}}_{t})}{|\nabla\mathcal{L}({\bm{W}}_{t})|}\rangle\bigr{|}$ using ${\mathcal{G}}({\bm{W}}_{t})$. It is used in Lemma D.5 to show the decrease of the risk. The proof is similar to that of Zhang et al. (2024, Lemma A.3), but here we need to carefully track the index $c\in[k]$ using both ${\mathcal{G}}_{c}({\bm{W}})$ and $\mathcal{Q}_{c}({\bm{W}})$ to avoid a dependence on $k$. The final result crucially relies on the decomposition ${\mathcal{G}}({\bm{W}}_{t})=\sum_{c\in[k]}{\mathcal{G}}_{c}({\bm{W}}_{t})=\sum_{c\in[k]}\mathcal{Q}_{c}({\bm{W}}_{t})$.
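As a quick numerical sanity check of this decomposition, the following Python sketch evaluates the per-class quantities for the softmax on random data. The definitions of ${\mathcal{G}}_{c}$, $\mathcal{Q}_{c}$, and ${\mathcal{G}}$ spelled out in the comments are our assumed forms of the corresponding quantities from Appendix B; the data, shapes, and seed are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, k = 32, 5, 4
H = rng.normal(size=(n, d))        # rows are the embeddings h_i
y = rng.integers(0, k, size=n)     # labels
W = rng.normal(size=(k, d))        # classifier matrix

def softmax(Z):
    Z = Z - Z.max(axis=-1, keepdims=True)
    E = np.exp(Z)
    return E / E.sum(axis=-1, keepdims=True)

S = softmax(H @ W.T)               # S[i, c] = S_c(W h_i)

# Assumed per-class split of G(W) (cf. Appendix B):
#   G_c(W) = (1/n) * sum_{i: y_i = c} (1 - S_c(W h_i))
#   Q_c(W) = (1/n) * sum_{i: y_i != c}     S_c(W h_i)
#   G(W)   = (1/n) * sum_i (1 - S_{y_i}(W h_i))
G_c = np.array([(1.0 - S[y == c, c]).sum() / n for c in range(k)])
Q_c = np.array([S[y != c, c].sum() / n for c in range(k)])
G = (1.0 - S[np.arange(n), y]).sum() / n

print(np.allclose(G_c.sum(), G))   # sum_c G_c(W) == G(W)
print(np.allclose(Q_c.sum(), G))   # sum_c Q_c(W) == G(W)
```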

Lemma D.4.

Under the same setting as Theorem C.4, we have

|(𝑾t),𝑴t𝑽t(𝑾t)|(𝑾t)||\displaystyle\bigm{|}\langle\nabla\mathcal{L}({\bm{W}}_{t}),\frac{\bm{M}_{t}}{\sqrt{{\bm{V}}_{t}}}-\frac{\nabla\mathcal{L}({\bm{W}}_{t})}{|\nabla\mathcal{L}({\bm{W}}_{t})|}\rangle\bigm{|} 4β1t+11β2t+1(𝑾t)sum+\displaystyle\leq 4\sqrt{\frac{\beta_{1}^{t+1}}{1-\beta_{2}^{t+1}}}{\left\|\nabla\mathcal{L}({\bm{W}}_{t})\right\|_{\rm{sum}}}+
2d1β2(6αV1β2t+1ηt+3αMηt)𝒢(𝑾t).\displaystyle\quad\quad\frac{2d}{\sqrt{1-\beta_{2}}}\bigl{(}\frac{6\alpha_{V}}{\sqrt{1-\beta_{2}^{t+1}}}\sqrt{\eta_{t}}+3\alpha_{M}\eta_{t}\bigr{)}{\mathcal{G}}({\bm{W}}_{t}).
Proof.

For simplicity, we drop the subscripts tt. Denote 𝒯c(𝑾)𝒢c(𝑾)+𝒬c(𝑾)\mathcal{T}_{c}({\bm{W}})\coloneqq{\mathcal{G}}_{c}({\bm{W}})+\mathcal{Q}_{c}({\bm{W}}). Then, by Lemmas D.2 and D.3, we have for any c[k]c\in[k] and j[d]j\in[d]:

𝑴[c,j]\displaystyle\bm{M}[c,j] =(1β1t+1)(𝑾)[c,j]+αMηt𝒯c(𝑾)ϵm,c,j,\displaystyle=(1-\beta_{1}^{t+1})\nabla\mathcal{L}({\bm{W}})[c,j]+\alpha_{M}\eta_{t}\mathcal{T}_{c}({\bm{W}})\epsilon_{m,c,j}, (24)
𝑽[c,j]\displaystyle\sqrt{{\bm{V}}[c,j]} =1β2t+1|(𝑾)[c,j]|+αVηt𝒯c(𝑾)ϵv,c,j,\displaystyle=\sqrt{1-\beta_{2}^{t+1}}|\nabla\mathcal{L}({\bm{W}})[c,j]|+\alpha_{V}\sqrt{\eta_{t}}\mathcal{T}_{c}({\bm{W}})\epsilon_{v,c,j}, (25)

where $|\epsilon_{m,c,j}|\leq 1$ and $|\epsilon_{v,c,j}|\leq 1$ are some residual terms. We denote $\psi_{c,j}\coloneqq\nabla\mathcal{L}({\bm{W}})[c,j]\bigl{(}\frac{\bm{M}[c,j]}{\sqrt{{\bm{V}}[c,j]}}-\frac{\nabla\mathcal{L}({\bm{W}})[c,j]}{|\nabla\mathcal{L}({\bm{W}})[c,j]|}\bigr{)}$, the index set $E_{c,j}\coloneqq\{j\in[d]\bigm{|}\sqrt{1-\beta_{2}^{t+1}}|\nabla\mathcal{L}({\bm{W}})[c,j]|\geq 2\alpha_{V}\sqrt{\eta_{t}}\mathcal{T}_{c}({\bm{W}})|\epsilon_{v,c,j}|\}$, and its complement $E_{c,j}^{c}=[d]\backslash E_{c,j}$. The goal is to bound $|\psi_{c,j}|$ for $j\in E_{c,j}^{c}$ and for $j\in E_{c,j}$ using $\mathcal{T}_{c}({\bm{W}})$. We start with the indices in $E_{c,j}^{c}$:

jEc,jc|ψc,j|\displaystyle\sum_{j\in E_{c,j}^{c}}|\psi_{c,j}| jEc,jc|(𝑾)[c,j]|(|𝑴[c,j]|𝑽[c,j]+1)\displaystyle\leq\sum_{j\in E_{c,j}^{c}}|\nabla\mathcal{L}({\bm{W}})[c,j]|\bigl{(}\frac{|\bm{M}[c,j]|}{\sqrt{{\bm{V}}[c,j]}}+1\bigr{)}
(a)jEc,jc|(𝑾)[c,j]|((1β1t+1)|(𝑾)[c,j]|+αMηt𝒯c(𝑾)1β2|(𝑾)[c,j]|+1)\displaystyle\stackrel{{\scriptstyle(a)}}{{\leq}}\sum_{j\in E_{c,j}^{c}}|\nabla\mathcal{L}({\bm{W}})[c,j]|\bigl{(}\frac{(1-\beta_{1}^{t+1})|\nabla\mathcal{L}({\bm{W}})[c,j]|+\alpha_{M}\eta_{t}\mathcal{T}_{c}({\bm{W}})}{\sqrt{1-\beta_{2}}|\nabla\mathcal{L}({\bm{W}})[c,j]|}+1\bigr{)}
(b)jEc,jc(1β1t+11β2+1)2αVηt𝒯c(𝑾)1β2t+1+αMηt𝒯c(𝑾)1β2\displaystyle\stackrel{{\scriptstyle(b)}}{{\leq}}\sum_{j\in E_{c,j}^{c}}(\frac{1-\beta_{1}^{t+1}}{\sqrt{1-\beta_{2}}}+1)\frac{2\alpha_{V}\sqrt{\eta_{t}}\mathcal{T}_{c}({\bm{W}})}{\sqrt{1-\beta_{2}^{t+1}}}+\frac{\alpha_{M}\eta_{t}\mathcal{T}_{c}({\bm{W}})}{\sqrt{1-\beta_{2}}}
d1β2(4αV1β2t+1ηt+αMηt)𝒯c(𝑾),\displaystyle\leq\frac{d}{\sqrt{1-\beta_{2}}}\bigl{(}\frac{4\alpha_{V}}{\sqrt{1-\beta_{2}^{t+1}}}\sqrt{\eta_{t}}+\alpha_{M}\eta_{t}\bigr{)}\mathcal{T}_{c}({\bm{W}}),

where (a) is by (24), |ϵm,c,j|1|\epsilon_{m,c,j}|\leq 1, and 𝑽[c,j](1β2)(𝑾)[c,j]2{\bm{V}}[c,j]\geq(1-\beta_{2})\nabla\mathcal{L}({\bm{W}})[c,j]^{2}; and (b) is by jEc,jcj\in E^{c}_{c,j} s.t. |(𝑾)[c,j]|2αVηt𝒯c(𝑾)1β2t+1|\nabla\mathcal{L}({\bm{W}})[c,j]|\leq\frac{2\alpha_{V}\sqrt{\eta_{t}}\mathcal{T}_{c}({\bm{W}})}{\sqrt{1-\beta_{2}^{t+1}}}. Next, we focus on the indices jEc,jj\in E_{c,j}. In this case, we have

\displaystyle\psi_{c,j} =\nabla\mathcal{L}({\bm{W}})[c,j]\bigl{(}\frac{\bm{M}[c,j]}{\sqrt{1-\beta_{2}^{t+1}}|\nabla\mathcal{L}({\bm{W}})[c,j]|+\alpha_{V}\sqrt{\eta_{t}}\mathcal{T}_{c}({\bm{W}})\epsilon_{v,c,j}}-\frac{\nabla\mathcal{L}({\bm{W}})[c,j]}{|\nabla\mathcal{L}({\bm{W}})[c,j]|}\bigr{)}
=(𝑾)[c,j]𝑴[c,j]|(𝑾)[c,j]|(1β2t+1|(𝑾)[c,j]|+αVηt𝒯c(𝑾)ϵv,c,j)(𝑾)[c,j](1β2t+1|(𝑾)[c,j]|+αVηt𝒯c(𝑾)ϵv,c,j)|(𝑾)[c,j]|12,\displaystyle=\nabla\mathcal{L}({\bm{W}})[c,j]\underbrace{\frac{\bm{M}[c,j]|\nabla\mathcal{L}({\bm{W}})[c,j]|-\bigl{(}\sqrt{1-\beta_{2}^{t+1}}|\nabla\mathcal{L}({\bm{W}})[c,j]|+\alpha_{V}\sqrt{\eta_{t}}\mathcal{T}_{c}({\bm{W}})\epsilon_{v,c,j}\bigr{)}\nabla\mathcal{L}({\bm{W}})[c,j]}{\bigl{(}\sqrt{1-\beta_{2}^{t+1}}|\nabla\mathcal{L}({\bm{W}})[c,j]|+\alpha_{V}\sqrt{\eta_{t}}\mathcal{T}_{c}({\bm{W}})\epsilon_{v,c,j}\bigr{)}|\nabla\mathcal{L}({\bm{W}})[c,j]|}}_{\frac{\spadesuit_{1}}{\spadesuit_{2}}},

where

|1|\displaystyle\Bigm{|}\spadesuit_{1}\Bigm{|} =|(1β1t+11β2t+1)(𝑾)[c,j]|(𝑾)[c,j]|+\displaystyle=\Bigm{|}\bigl{(}1-\beta_{1}^{t+1}-\sqrt{1-\beta_{2}^{t+1}}\bigr{)}\nabla\mathcal{L}({\bm{W}})[c,j]|\nabla\mathcal{L}({\bm{W}})[c,j]|+
αMηt𝒯c(𝑾)ϵm,c,j|(𝑾)[c,j]|αVηt𝒯c(𝑾)ϵv,c,j(𝑾)[c,j]|\displaystyle\quad\quad\alpha_{M}\eta_{t}\mathcal{T}_{c}({\bm{W}})\epsilon_{m,c,j}|\nabla\mathcal{L}({\bm{W}})[c,j]|-\alpha_{V}\sqrt{\eta_{t}}\mathcal{T}_{c}({\bm{W}})\epsilon_{v,c,j}\nabla\mathcal{L}({\bm{W}})[c,j]\Bigm{|}
\displaystyle\stackrel{{\scriptstyle(c)}}{{\leq}}\bigl{|}1-\beta_{1}^{t+1}-\sqrt{1-\beta_{2}^{t+1}}\bigr{|}\,|\nabla\mathcal{L}({\bm{W}})[c,j]|^{2}+(\alpha_{M}\eta_{t}+\alpha_{V}\sqrt{\eta_{t}})\mathcal{T}_{c}({\bm{W}})|\nabla\mathcal{L}({\bm{W}})[c,j]|,

and

|2|=2(d)121β2t+1|(𝑾)[c,j]|2.\displaystyle\Bigm{|}\spadesuit_{2}\Bigm{|}=\spadesuit_{2}\stackrel{{\scriptstyle(d)}}{{\geq}}\frac{1}{2}\sqrt{1-\beta_{2}^{t+1}}|\nabla\mathcal{L}({\bm{W}})[c,j]|^{2}.

Inequality (c) is by $|\epsilon_{m,c,j}|\leq 1$ and $|\epsilon_{v,c,j}|\leq 1$, and (d) is by $\alpha_{V}\sqrt{\eta_{t}}\mathcal{T}_{c}({\bm{W}})\epsilon_{v,c,j}\geq-\frac{1}{2}\sqrt{1-\beta_{2}^{t+1}}|\nabla\mathcal{L}({\bm{W}})[c,j]|$ for any $j\in E_{c,j}$. Putting these two pieces together via $|\psi_{c,j}|\leq|\nabla\mathcal{L}({\bm{W}})[c,j]|\,|\spadesuit_{1}|/\spadesuit_{2}$, we obtain

jEc,j|ψc,j|\displaystyle\sum_{j\in E_{c,j}}|\psi_{c,j}| jEc,j|1β1t+11β2t+1||(𝑾)[c,j]|3+(αMηt+αVηt)𝒯c(𝑾)|(𝑾)[c,j]|2121β2t+1|(𝑾)[c,j]|2\displaystyle\leq\sum_{j\in E_{c,j}}\frac{|1-\beta_{1}^{t+1}-\sqrt{1-\beta_{2}^{t+1}}||\nabla\mathcal{L}({\bm{W}})[c,j]|^{3}+(\alpha_{M}\eta_{t}+\alpha_{V}\sqrt{\eta_{t}})\mathcal{T}_{c}({\bm{W}})|\nabla\mathcal{L}({\bm{W}})[c,j]|^{2}}{\frac{1}{2}\sqrt{1-\beta_{2}^{t+1}}|\nabla\mathcal{L}({\bm{W}})[c,j]|^{2}}
(e)(jEc,j4β1t+11β2t+1|(𝑾)[c,j]|)+d(2αV1β2t+1ηt+2αM1β2t+1ηt)𝒯c(𝑾)\displaystyle\stackrel{{\scriptstyle(e)}}{{\leq}}\Bigl{(}\sum_{j\in E_{c,j}}4\sqrt{\frac{\beta_{1}^{t+1}}{1-\beta_{2}^{t+1}}}|\nabla\mathcal{L}({\bm{W}})[c,j]|\Bigr{)}+d\bigl{(}\frac{2\alpha_{V}}{\sqrt{1-\beta_{2}^{t+1}}}\sqrt{\eta_{t}}+\frac{2\alpha_{M}}{\sqrt{1-\beta_{2}^{t+1}}}\eta_{t}\bigr{)}\mathcal{T}_{c}({\bm{W}})
4β1t+11β2t+1(𝑾)[c,:]sum+d1β2(2αV1β2t+1ηt+2αMηt)𝒯c(𝑾),\displaystyle\leq 4\sqrt{\frac{\beta_{1}^{t+1}}{1-\beta_{2}^{t+1}}}{\left\|\nabla\mathcal{L}({\bm{W}})[c,:]\right\|_{\rm{sum}}}+\frac{d}{\sqrt{1-\beta_{2}}}\bigl{(}\frac{2\alpha_{V}}{\sqrt{1-\beta_{2}^{t+1}}}\sqrt{\eta_{t}}+2\alpha_{M}\eta_{t}\bigr{)}\mathcal{T}_{c}({\bm{W}}),

where (e) is by aab+b\sqrt{a}\leq\sqrt{a-b}+\sqrt{b} implying 11β2t+1β1t+121-\sqrt{1-\beta_{2}^{t+1}}\leq\beta_{1}^{\frac{t+1}{2}}, and (𝑾)[c,:]\nabla\mathcal{L}({\bm{W}})[c,:] denotes the ccth row of (𝑾)\nabla\mathcal{L}({\bm{W}}). Finally, we note that |(𝑾),𝑴𝑽(𝑾)|(𝑾)||=|(c,j)ψc,j|(c,j)|ψc,j||\langle\nabla\mathcal{L}({\bm{W}}),\frac{\bm{M}}{\sqrt{{\bm{V}}}}-\frac{\nabla\mathcal{L}({\bm{W}})}{|\nabla\mathcal{L}({\bm{W}})|}\rangle|=|\sum_{(c,j)}\psi_{c,j}|\leq\sum_{(c,j)}|\psi_{c,j}|. Then, we obtain

c,j|ψc,j|\displaystyle\sum_{c,j}|\psi_{c,j}| =c[k](jEc,jc|ψc,j|+jEc,j|ψc,j|)\displaystyle=\sum_{c\in[k]}\bigl{(}\sum_{j\in E^{c}_{c,j}}|\psi_{c,j}|+\sum_{j\in E_{c,j}}|\psi_{c,j}|\bigr{)}
\displaystyle\leq\sum_{c\in[k]}4\sqrt{\frac{\beta_{1}^{t+1}}{1-\beta_{2}^{t+1}}}{\left\|\nabla\mathcal{L}({\bm{W}})[c,:]\right\|_{\rm{sum}}}+\sum_{c\in[k]}\frac{d}{\sqrt{1-\beta_{2}}}\bigl{(}\frac{2\alpha_{V}}{\sqrt{1-\beta_{2}^{t+1}}}\sqrt{\eta_{t}}+2\alpha_{M}\eta_{t}\bigr{)}\mathcal{T}_{c}({\bm{W}})
=(f)4β1t+11β2t+1(𝑾)sum+2d1β2(2αV1β2t+1ηt+2αMηt)𝒢(𝑾),\displaystyle\stackrel{{\scriptstyle(f)}}{{=}}4\sqrt{\frac{\beta_{1}^{t+1}}{1-\beta_{2}^{t+1}}}{\left\|\nabla\mathcal{L}({\bm{W}})\right\|_{\rm{sum}}}+\frac{2d}{\sqrt{1-\beta_{2}}}\bigl{(}\frac{2\alpha_{V}}{\sqrt{1-\beta_{2}^{t+1}}}\sqrt{\eta_{t}}+2\alpha_{M}\eta_{t}\bigr{)}{\mathcal{G}}({\bm{W}}),

where (f) is by $\sum_{c\in[k]}\mathcal{T}_{c}({\bm{W}})=\sum_{c\in[k]}\bigl{(}\mathcal{Q}_{c}({\bm{W}})+{\mathcal{G}}_{c}({\bm{W}})\bigr{)}=2{\mathcal{G}}({\bm{W}})$. ∎

Lemma D.5 (Adam Descent).

Under the same setting as Theorem D.8, set tA2log(1β24)log(β1)t_{A}\coloneqq\frac{2\log(\frac{\sqrt{1-\beta_{2}}}{4})}{\log(\beta_{1})}, then we have for all ttAt\geq t_{A}

(𝑾t+1)\displaystyle\mathcal{L}({\bm{W}}_{t+1}) (𝑾t)ηtγ(1αa1β1t/2αa2dηt12αa3dηt)𝒢(𝑾t),\displaystyle\leq\mathcal{L}({\bm{W}}_{t})-\eta_{t}\gamma\bigl{(}1-\alpha_{a_{1}}\beta_{1}^{t/2}-\alpha_{a_{2}}d\eta_{t}^{\frac{1}{2}}-\alpha_{a_{3}}d\eta_{t}\bigr{)}{\mathcal{G}}({\bm{W}}_{t}),

where αa1\alpha_{a_{1}}, αa2\alpha_{a_{2}}, and αa3\alpha_{a_{3}} are some constants that depend on BB, γ\gamma,β1\beta_{1}, and β2\beta_{2}.

Proof.

We follow the same notation and strategy as in Lemma C.1, and recall the definitions $\spadesuit_{t}=\langle\nabla\mathcal{L}({\bm{W}}_{t}),\bm{\Delta}_{t}\rangle$ and $\clubsuit_{t}=\bm{h}_{i}^{\top}\bm{\Delta}_{t}^{\top}\left(\operatorname{diag}(\mathbb{S}({\bm{W}}_{t,t+1,\gamma}\bm{h}_{i}))-\mathbb{S}({\bm{W}}_{t,t+1,\zeta^{*}}\bm{h}_{i})\mathbb{S}({\bm{W}}_{t,t+1,\zeta^{*}}\bm{h}_{i})^{\top}\right)\bm{\Delta}_{t}\,\bm{h}_{i}$. In the case of Adam, we have $\bm{\Delta}_{t}=\frac{\bm{M}_{t}}{\sqrt{{\bm{V}}_{t}}}$. We bound $\spadesuit_{t}$ and $\clubsuit_{t}$ separately. Starting with $\spadesuit_{t}$, we have for all $t\geq t_{A}$

t\displaystyle\spadesuit_{t} =ηt(𝑾t),𝑴t𝑽t\displaystyle=-\eta_{t}\langle\nabla\mathcal{L}({\bm{W}}_{t}),\frac{\bm{M}_{t}}{\sqrt{{\bm{V}}_{t}}}\rangle
=ηt((𝑾t),𝑴t𝑽t(𝑾t)|(𝑾t)|+(𝑾t),(𝑾t)|(𝑾t)|)\displaystyle=-\eta_{t}\bigl{(}\langle\nabla\mathcal{L}({\bm{W}}_{t}),\frac{\bm{M}_{t}}{\sqrt{{\bm{V}}_{t}}}-\frac{\nabla\mathcal{L}({\bm{W}}_{t})}{|\nabla\mathcal{L}({\bm{W}}_{t})|}\rangle+\langle\nabla\mathcal{L}({\bm{W}}_{t}),\frac{\nabla\mathcal{L}({\bm{W}}_{t})}{|\nabla\mathcal{L}({\bm{W}}_{t})|}\rangle\bigr{)}
ηt(𝑾t)sum+ηt|(𝑾t),𝑴t𝑽t(𝑾t)|(𝑾t)||\displaystyle\leq-\eta_{t}{\left\|\nabla\mathcal{L}({\bm{W}}_{t})\right\|_{\rm{sum}}}+\eta_{t}\bigm{|}\langle\nabla\mathcal{L}({\bm{W}}_{t}),\frac{\bm{M}_{t}}{\sqrt{{\bm{V}}_{t}}}-\frac{\nabla\mathcal{L}({\bm{W}}_{t})}{|\nabla\mathcal{L}({\bm{W}}_{t})|}\rangle\bigm{|}
(a)ηt(14β1t+11β2t+1)(𝑾t)sum+2d1β2(6αV1β2t+1ηt32+3αMηt2)𝒢(𝑾t)\displaystyle\stackrel{{\scriptstyle(a)}}{{\leq}}-\eta_{t}\bigl{(}1-4\sqrt{\frac{\beta_{1}^{t+1}}{1-\beta_{2}^{t+1}}}\bigr{)}{\left\|\nabla\mathcal{L}({\bm{W}}_{t})\right\|_{\rm{sum}}}+\frac{2d}{\sqrt{1-\beta_{2}}}\bigl{(}\frac{6\alpha_{V}}{\sqrt{1-\beta_{2}^{t+1}}}\eta_{t}^{\frac{3}{2}}+3\alpha_{M}\eta_{t}^{2}\bigr{)}{\mathcal{G}}({\bm{W}}_{t})
ηt(14β1t21β2)(𝑾t)sum+12αV1β2dηt3/2𝒢(𝑾t)+6αM1β2dηt2𝒢(𝑾t)\displaystyle\leq-\eta_{t}\bigl{(}1-4\frac{\beta_{1}^{\frac{t}{2}}}{\sqrt{1-\beta_{2}}}\bigr{)}{\left\|\nabla\mathcal{L}({\bm{W}}_{t})\right\|_{\rm{sum}}}+\frac{12\alpha_{V}}{1-\beta_{2}}d\eta_{t}^{3/2}{\mathcal{G}}({\bm{W}}_{t})+\frac{6\alpha_{M}}{\sqrt{1-\beta_{2}}}d\eta_{t}^{2}{\mathcal{G}}({\bm{W}}_{t})
(b)ηtγ(14β1t21β2)𝒢(𝑾t)+12αV1β2dηt3/2𝒢(𝑾t)+6αM1β2dηt2𝒢(𝑾t),\displaystyle\stackrel{{\scriptstyle(b)}}{{\leq}}-\eta_{t}\gamma\bigl{(}1-4\frac{\beta_{1}^{\frac{t}{2}}}{\sqrt{1-\beta_{2}}}\bigr{)}{\mathcal{G}}({\bm{W}}_{t})+\frac{12\alpha_{V}}{1-\beta_{2}}d\eta_{t}^{3/2}{\mathcal{G}}({\bm{W}}_{t})+\frac{6\alpha_{M}}{\sqrt{1-\beta_{2}}}d\eta_{t}^{2}{\mathcal{G}}({\bm{W}}_{t}),

where (a) is by Lemma D.4, and (b) is by Lemma B.1. For t\clubsuit_{t}, we apply Lemma A.3 to obtain

t4𝚫t𝒉i2(1𝕊yi(𝑾t,t+1,ζ𝒉i))4ηt2α2B2(1𝕊yi(𝑾t,t+1,ζ𝒉i)),\displaystyle\clubsuit_{t}\leq 4\|\bm{\Delta}_{t}\bm{h}_{i}\|_{\infty}^{2}(1-\mathbb{S}_{y_{i}}({\bm{W}}_{t,t+1,\zeta^{*}}\bm{h}_{i}))\leq 4\eta_{t}^{2}\alpha^{2}B^{2}(1-\mathbb{S}_{y_{i}}({\bm{W}}_{t,t+1,\zeta^{*}}\bm{h}_{i})),

where in the second inequality we have used $\|\bm{\Delta}_{t}\bm{h}_{i}\|_{\infty}\leq{\left\|\bm{\Delta}_{t}\right\|_{\max}}\|\bm{h}_{i}\|_{1}$, $\|\bm{h}_{i}\|_{1}\leq B$, and ${\left\|\bm{\Delta}_{t}\right\|_{\max}}=\eta_{t}{\left\|\frac{\bm{M}_{t}}{\sqrt{{\bm{V}}_{t}}}\right\|_{\max}}\leq\eta_{t}\alpha$ by Lemma D.1, given that $t\geq t_{A}$ implies $1\geq 4\frac{\beta_{1}^{\frac{t}{2}}}{\sqrt{1-\beta_{2}}}$. Combining this with Lemma A.3, we obtain

12ni[n]𝒉i𝚫t\displaystyle\frac{1}{2n}\sum_{i\in[n]}\bm{h}_{i}^{\top}\bm{\Delta}_{t}^{\top} (diag(𝕊(𝑾t,t+1,γ𝒉i))𝕊(𝑾t,t+1,ζ𝒉i)𝕊(𝑾t,t+1,ζ𝒉i))𝚫t𝒉i\displaystyle\left(\operatorname{diag}(\mathbb{S}({\bm{W}}_{t,t+1,\gamma}\bm{h}_{i}))-\mathbb{S}({\bm{W}}_{t,t+1,\zeta^{*}}\bm{h}_{i})\mathbb{S}({\bm{W}}_{t,t+1,\zeta^{*}}\bm{h}_{i})^{\top}\right)\bm{\Delta}_{t}\,\bm{h}_{i}
12ni[n]4ηt2α2B2(1𝕊yi(𝑾t,t+1,ζ𝒉i))2α2ηt2B2e2Bη0𝒢(𝑾t),\displaystyle\leq\frac{1}{2n}\sum_{i\in[n]}4\eta_{t}^{2}\alpha^{2}B^{2}(1-\mathbb{S}_{y_{i}}({\bm{W}}_{t,t+1,\zeta^{*}}\bm{h}_{i}))\leq 2\alpha^{2}\eta_{t}^{2}B^{2}e^{2B\eta_{0}}{\mathcal{G}}({\bm{W}}_{t}),

where the derivation of the second inequality can be found in the derivation of (18). Putting everything together, we obtain

(𝑾t+1)\displaystyle\mathcal{L}({\bm{W}}_{t+1}) (𝑾t)ηtγ𝒢(𝑾t)+4β1t21β2γηt𝒢(𝑾t)+12αV1β2dηt3/2𝒢(𝑾t)+\displaystyle\leq\mathcal{L}({\bm{W}}_{t})-\eta_{t}\gamma{\mathcal{G}}({\bm{W}}_{t})+4\frac{\beta_{1}^{\frac{t}{2}}}{\sqrt{1-\beta_{2}}}\gamma\eta_{t}{\mathcal{G}}({\bm{W}}_{t})+\frac{12\alpha_{V}}{1-\beta_{2}}d\eta_{t}^{3/2}{\mathcal{G}}({\bm{W}}_{t})+
(6αM1β2+2α2B2e2Bη0)dηt2𝒢(𝑾t)\displaystyle\quad\quad\quad\bigl{(}\frac{6\alpha_{M}}{\sqrt{1-\beta_{2}}}+2\alpha^{2}B^{2}e^{2B\eta_{0}}\bigr{)}d\eta_{t}^{2}{\mathcal{G}}({\bm{W}}_{t})
=(𝑾t)ηtγ(1αa1β1t/2αa2dηt12αa3dηt)𝒢(𝑾t),\displaystyle=\mathcal{L}({\bm{W}}_{t})-\eta_{t}\gamma\bigl{(}1-\alpha_{a_{1}}\beta_{1}^{t/2}-\alpha_{a_{2}}d\eta_{t}^{\frac{1}{2}}-\alpha_{a_{3}}d\eta_{t}\bigr{)}{\mathcal{G}}({\bm{W}}_{t}),

where we have defined αa141β2\alpha_{a_{1}}\coloneqq\frac{4}{\sqrt{1-\beta_{2}}}, αa212αVγ(1β2)\alpha_{a_{2}}\coloneqq\frac{12\alpha_{V}}{\gamma(1-\beta_{2})}, and αa36αMγ1β2+2α2B2e2Bη0γ\alpha_{a_{3}}\coloneqq\frac{6\alpha_{M}}{\gamma\sqrt{1-\beta_{2}}}+\frac{2\alpha^{2}B^{2}e^{2B\eta_{0}}}{\gamma}. ∎

Building on Lemma D.5, we can further lower bound the unnormalized margin of the Adam iterates for sufficiently large $t$. The proof is similar to that of SignGD (i.e., Lemma C.2), and crucially depends on the separability condition obtained after achieving a low loss (Lemma B.4). The time $\tilde{t}_{A}$ will be specified in the proof of Theorem D.8.

Lemma D.6 (Adam Unnormalized Margin).

Under the same setting as Theorem D.8, suppose that there exists $\tilde{t}$ such that $\mathcal{L}({\bm{W}}_{t})\leq\frac{\log 2}{n}$ for all $t>\tilde{t}$. Then we have for all $t\geq\tilde{t}_{A}\coloneqq\max\{t_{A},\tilde{t}\}$

mini[n],cyi(𝒆yi𝒆c)T𝑾t𝒉iγs=t~At1ηs𝒢(𝑾s)(𝑾s)αa5ds=t~At1ηs32αa6ds=t~At1ηs2αa7,\displaystyle\min_{i\in[n],c\neq y_{i}}(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}{\bm{W}}_{t}\bm{h}_{i}\geq\gamma\sum_{s=\tilde{t}_{A}}^{t-1}\eta_{s}\frac{{\mathcal{G}}({\bm{W}}_{s})}{\mathcal{L}({\bm{W}}_{s})}-\alpha_{a_{5}}d\sum_{s=\tilde{t}_{A}}^{t-1}\eta_{s}^{\frac{3}{2}}-\alpha_{a_{6}}d\sum_{s=\tilde{t}_{A}}^{t-1}\eta_{s}^{2}-\alpha_{a_{7}},

where tA=2log(1β24)log(β1)t_{A}=\frac{2\log(\frac{\sqrt{1-\beta_{2}}}{4})}{\log(\beta_{1})}, and αa5\alpha_{a_{5}}, αa6\alpha_{a_{6}}, and αa7\alpha_{a_{7}} are some constants that depend on BB, β1\beta_{1}, and β2\beta_{2}.

Proof.

We denote αa441β2\alpha_{a_{4}}\coloneqq\frac{4}{\sqrt{1-\beta_{2}}}, αa512αV1β2\alpha_{a_{5}}\coloneqq\frac{12\alpha_{V}}{1-\beta_{2}}, and αa66αM1β2+2α2B2e2Bη0\alpha_{a_{6}}\coloneqq\frac{6\alpha_{M}}{\sqrt{1-\beta_{2}}}+2\alpha^{2}B^{2}e^{2B\eta_{0}}. Under the assumption that (𝑾t)log2n\mathcal{L}({\bm{W}}_{t})\leq\frac{\log 2}{n} for all tt~t\geq\tilde{t}, we have for all tt~Amax{tA,t~}t\geq\tilde{t}_{A}\coloneqq\max\{t_{A},\tilde{t}\} (recall that tA=2log(1β24)log(β1)t_{A}=\frac{2\log(\frac{\sqrt{1-\beta_{2}}}{4})}{\log(\beta_{1})})

(𝑾t+1)\displaystyle\mathcal{L}({\bm{W}}_{t+1}) (a)(𝑾t)ηtγ𝒢(𝑾t)+αa4β1t2γηt𝒢(𝑾t)+αa5kdηt32𝒢(𝑾t)+αa6kdηt2𝒢(𝑾t)\displaystyle\stackrel{{\scriptstyle(a)}}{{\leq}}\mathcal{L}({\bm{W}}_{t})-\eta_{t}\gamma{\mathcal{G}}({\bm{W}}_{t})+\alpha_{a_{4}}\beta_{1}^{\frac{t}{2}}\gamma\eta_{t}{\mathcal{G}}({\bm{W}}_{t})+\alpha_{a_{5}}kd\eta_{t}^{\frac{3}{2}}{\mathcal{G}}({\bm{W}}_{t})+\alpha_{a_{6}}kd\eta_{t}^{2}{\mathcal{G}}({\bm{W}}_{t})
(b)(𝑾t)(1ηtγ𝒢(𝑾t)(𝑾t)+αa4β1t2γηt+αa5dηt32+αa6dηt2)\displaystyle\stackrel{{\scriptstyle(b)}}{{\leq}}\mathcal{L}({\bm{W}}_{t})\bigl{(}1-\eta_{t}\gamma\frac{{\mathcal{G}}({\bm{W}}_{t})}{\mathcal{L}({\bm{W}}_{t})}+\alpha_{a_{4}}\beta_{1}^{\frac{t}{2}}\gamma\eta_{t}+\alpha_{a_{5}}d\eta_{t}^{\frac{3}{2}}+\alpha_{a_{6}}d\eta_{t}^{2}\bigr{)}
(𝑾t~A)exp(γs=t~Atηs𝒢(𝑾s)(𝑾s)+αa4γs=t~Atβ1s2ηs+αa5ds=t~Atηs32+αa6ds=t~Atηs2)\displaystyle\leq\mathcal{L}({\bm{W}}_{\tilde{t}_{A}})\exp\bigl{(}-\gamma\sum_{s=\tilde{t}_{A}}^{t}\eta_{s}\frac{{\mathcal{G}}({\bm{W}}_{s})}{\mathcal{L}({\bm{W}}_{s})}+\alpha_{a_{4}}\gamma\sum_{s=\tilde{t}_{A}}^{t}\beta_{1}^{\frac{s}{2}}\eta_{s}+\alpha_{a_{5}}d\sum_{s=\tilde{t}_{A}}^{t}\eta_{s}^{\frac{3}{2}}+\alpha_{a_{6}}d\sum_{s=\tilde{t}_{A}}^{t}\eta_{s}^{2}\bigr{)}
(c)log2nexp(γs=t~Atηs𝒢(𝑾s)(𝑾s)+αa5ds=t~Atηs32+αa6ds=t~Atηs2+αa7),\displaystyle\stackrel{{\scriptstyle(c)}}{{\leq}}\frac{\log 2}{n}\exp\bigl{(}-\gamma\sum_{s=\tilde{t}_{A}}^{t}\eta_{s}\frac{{\mathcal{G}}({\bm{W}}_{s})}{\mathcal{L}({\bm{W}}_{s})}+\alpha_{a_{5}}d\sum_{s=\tilde{t}_{A}}^{t}\eta_{s}^{\frac{3}{2}}+\alpha_{a_{6}}d\sum_{s=\tilde{t}_{A}}^{t}\eta_{s}^{2}+\alpha_{a_{7}}\bigr{)},

where (a) is by Lemma D.5, (b) is by 𝒢(𝑾t)(𝑾t)1\frac{{\mathcal{G}}({\bm{W}}_{t})}{\mathcal{L}({\bm{W}}_{t})}\leq 1 (shown in Lemma B.3), and (c) is by (𝑾t~A)log2n\mathcal{L}({\bm{W}}_{\tilde{t}_{A}})\leq\frac{\log 2}{n} and αa4γs=t~Atβ1s2ηsαa4γη01β112αa7\alpha_{a_{4}}\gamma\sum_{s=\tilde{t}_{A}}^{t}\beta_{1}^{\frac{s}{2}}\eta_{s}\leq\frac{\alpha_{a_{4}}\gamma\eta_{0}}{1-\beta_{1}^{\frac{1}{2}}}\eqqcolon\alpha_{a_{7}} . The rest of the proof follows the same arguments in Lemma C.2. Namely, the assumption (𝑾t)log2n\mathcal{L}({\bm{W}}_{t})\leq\frac{\log 2}{n} implies that mincyi(𝒆yi𝒆c)T𝑾t𝒉i0\min_{c\neq y_{i}}(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}{\bm{W}}_{t}\bm{h}_{i}\geq 0 for all i[n]i\in[n]. This separability condition can be used further to show that for all tt~At\geq\tilde{t}_{A}

emini[n],cyi(𝒆yi𝒆c)T𝑾t𝒉iexp(γs=t~At1ηs𝒢(𝑾s)(𝑾s)+αa5ds=t~At1ηs32+αa6ds=t~At1ηs2+αa7).\displaystyle e^{-\min_{i\in[n],c\neq y_{i}}(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}{\bm{W}}_{t}\bm{h}_{i}}\leq\exp\bigl{(}-\gamma\sum_{s=\tilde{t}_{A}}^{t-1}\eta_{s}\frac{{\mathcal{G}}({\bm{W}}_{s})}{\mathcal{L}({\bm{W}}_{s})}+\alpha_{a_{5}}d\sum_{s=\tilde{t}_{A}}^{t-1}\eta_{s}^{\frac{3}{2}}+\alpha_{a_{6}}d\sum_{s=\tilde{t}_{A}}^{t-1}\eta_{s}^{2}+\alpha_{a_{7}}\bigr{)}.

Taking the logarithm of both sides and negating leads to the final result. ∎

The next lemma upper bounds the max-norm of the Adam iterates. It involves showing that the risk upper bounds the entry-wise second moment, which becomes small once the risk starts to decrease monotonically. Its proof can be found in Zhang et al. (2024, Lemma 6.4); here, we only show the steps that are specific to our setting.

Lemma D.7 (Adam 𝑾tmax{\left\|{\bm{W}}_{t}\right\|_{\max}}).

Under the same setting as Theorem D.8, suppose that there exists t~B>log(1ω)\tilde{t}_{B}>\log(\frac{1}{\omega}) such that (𝐖t)14B2+αVη0\mathcal{L}({\bm{W}}_{t})\leq\frac{1}{\sqrt{4B^{2}+\alpha_{V}\eta_{0}}} for all tt~Bt\geq\tilde{t}_{B}, then we have

𝑾tmaxαa8s=0t~B1ηs+s=t~Bt1ηs+𝑾0max,\displaystyle{\left\|{\bm{W}}_{t}\right\|_{\max}}\leq\alpha_{a_{8}}\sum_{s=0}^{\tilde{t}_{B}-1}\eta_{s}+\sum_{s=\tilde{t}_{B}}^{t-1}\eta_{s}+{\left\|{\bm{W}}_{0}\right\|_{\max}},

where αa8\alpha_{a_{8}} is some constant that depends on BB, β1\beta_{1}, and β2\beta_{2}.

Proof.

For any c[k]c\in[k] and j[d]j\in[d], we have for all tt~Bt\geq\tilde{t}_{B}

𝑽t[c,j]\displaystyle{\bm{V}}_{t}[c,j] (a)(1β2t+1)(𝑾t)[c,j]2+αVηt𝒢(𝑾t)2\displaystyle\stackrel{{\scriptstyle(a)}}{{\leq}}(1-\beta_{2}^{t+1})\nabla\mathcal{L}({\bm{W}}_{t})[c,j]^{2}+\alpha_{V}\eta_{t}{\mathcal{G}}({\bm{W}}_{t})^{2}
(𝑾t)[c,j]2+αVηt𝒢(𝑾t)2\displaystyle\leq\nabla\mathcal{L}({\bm{W}}_{t})[c,j]^{2}+\alpha_{V}\eta_{t}{\mathcal{G}}({\bm{W}}_{t})^{2}
(b)4B2𝒢(𝑾t)2+αVη0𝒢(𝑾t)2\displaystyle\stackrel{{\scriptstyle(b)}}{{\leq}}4B^{2}{\mathcal{G}}({\bm{W}}_{t})^{2}+\alpha_{V}\eta_{0}{\mathcal{G}}({\bm{W}}_{t})^{2}
(c)(4B2+αVη0)(𝑾t)2(d)1,\displaystyle\stackrel{{\scriptstyle(c)}}{{\leq}}(4B^{2}+\alpha_{V}\eta_{0})\mathcal{L}({\bm{W}}_{t})^{2}\stackrel{{\scriptstyle(d)}}{{\leq}}1,

where (a) is by (the proof of) Lemma D.3, (b) is by Lemma B.1, (c) is by Lemma B.3, and (d) is by the assumption on $\mathcal{L}({\bm{W}}_{t})$. This implies that for all $t\geq\tilde{t}_{B}$

\displaystyle 0\geq\log({\bm{V}}_{t}[c,j])\geq\log(\beta_{2}^{t}(1-\beta_{2})\nabla\mathcal{L}({\bm{W}}_{0})[c,j]^{2})\stackrel{{\scriptstyle(e)}}{{\geq}}t\log(\beta_{2})+\log(1-\beta_{2})+\log(\omega),

where (e) is by Assumption 3.2. The rest of the proof follows the same arguments as in Zhang et al. (2024, Lemma 6.4). ∎

Theorem D.8.

Suppose that Assumptions 3.1, 3.2, 3.3, 3.4, and 3.5 hold, and that $\beta_{1}\leq\beta_{2}$. Then there exists $t_{a_{2}}=t_{a_{2}}(n,d,\gamma,B,{\bm{W}}_{0},\beta_{1},\beta_{2},\omega)$ such that Adam achieves the following for all $t>t_{a_{2}}$

|mini[n],cyi(𝒆yi𝒆c)T𝑾t𝒉i𝑾tmaxγ|\displaystyle\left|\frac{\min_{i\in[n],c\neq y_{i}}(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}{\bm{W}}_{t}\bm{h}_{i}}{{\left\|{\bm{W}}_{t}\right\|_{\max}}}-\gamma\right| 𝒪(s=ta2t1ηseγ4τ=ta2s1ητ+s=0ta21ηs+ds=ta2t1ηs3/2s=0t1ηs).\displaystyle\leq\mathcal{O}(\frac{\sum_{s=t_{a_{2}}}^{t-1}\eta_{s}e^{-\frac{\gamma}{4}\sum_{\tau=t_{a_{2}}}^{s-1}\eta_{\tau}}+\sum_{s=0}^{t_{a_{2}}-1}\eta_{s}+d\sum_{s=t_{a_{2}}}^{t-1}\eta_{s}^{3/2}}{\sum_{s=0}^{t-1}\eta_{s}}).
Proof.

Determination of $t_{a_{1}}$. Here, we consider a learning-rate schedule of the form $\eta_{t}=\Theta(\frac{1}{t^{a}})$ with $a\in(0,1]$. We choose $t_{a_{1}}$ larger than $\max\{t_{0},t_{A},\log(\frac{1}{\omega})\}$ (where $t_{0}$ satisfies Assumption 3.4 and $t_{A}=\frac{2\log(\frac{\sqrt{1-\beta_{2}}}{4})}{\log\beta_{1}}$) such that the following conditions are met: $\alpha_{a_{1}}\beta_{1}^{t/2}\leq\frac{1}{6}$, $\alpha_{a_{2}}d\eta_{t}^{1/2}\leq\frac{1}{6}$, and $\alpha_{a_{3}}d\eta_{t}\leq\frac{1}{6}$. Concretely, we can set $t_{a_{1}}=\max\{\frac{-2\log(6\alpha_{a_{1}})}{\log\beta_{1}},(36\alpha_{a_{2}}^{2}d^{2})^{1/a},(6\alpha_{a_{3}}d)^{1/a}\}=\Theta(d^{2/a})$. Then, we have for all $t\geq t_{a_{1}}$

(𝑾t+1)(𝑾t)ηtγ2𝒢(𝑾t).\displaystyle\mathcal{L}({\bm{W}}_{t+1})\leq\mathcal{L}({\bm{W}}_{t})-\frac{\eta_{t}\gamma}{2}{\mathcal{G}}({\bm{W}}_{t}). (26)

Rearranging this equation and using non-negativity of the loss we obtain γs=ta1tηs𝒢(𝑾s)2(𝑾ta1)\gamma\sum_{s=t_{a_{1}}}^{t}\eta_{s}{\mathcal{G}}({\bm{W}}_{s})\leq 2\mathcal{L}({\bm{W}}_{t_{a_{1}}}).
Determination of $t_{a_{2}}$. By Lemma B.2, we can bound $\mathcal{L}({\bm{W}}_{t_{a_{1}}})$ as follows

|(𝑾ta1)(𝑾0)|2B𝑾ta1𝑾0max2Bs=0ta11ηs𝑴s𝑽smax2Bαs=0ta11ηs,\displaystyle|\mathcal{L}({\bm{W}}_{t_{a_{1}}})-\mathcal{L}({\bm{W}}_{0})|\leq 2B{\left\|{\bm{W}}_{t_{a_{1}}}-{\bm{W}}_{0}\right\|_{\max}}\leq 2B\sum_{s=0}^{t_{a_{1}}-1}\eta_{s}{\left\|\frac{\bm{M}_{s}}{\sqrt{{\bm{V}}_{s}}}\right\|_{\max}}\leq 2B\alpha\sum_{s=0}^{t_{a_{1}}-1}\eta_{s},

where the last inequality is by Lemma D.1. Combining this with the result above and letting $\tilde{\mathcal{L}}=\min\{\frac{\log 2}{n},\frac{1}{\sqrt{4B^{2}+\alpha_{V}\eta_{0}}}\}$, we obtain

𝒢(𝑾t)=mins[ta1,ta2]𝒢(𝑾s)2(𝑾0)+4Bαs=1ta11ηsγs=ta1ta2ηs~212n,\displaystyle{\mathcal{G}}({\bm{W}}_{t^{*}})=\min_{s\in[t_{a_{1}},t_{a_{2}}]}{\mathcal{G}}({\bm{W}}_{s})\leq\frac{2\mathcal{L}({\bm{W}}_{0})+4B\alpha\sum_{s=1}^{t_{a_{1}}-1}\eta_{s}}{\gamma\sum_{s=t_{a_{1}}}^{t_{a_{2}}}\eta_{s}}\leq\frac{\tilde{\mathcal{L}}}{2}\leq\frac{1}{2n},

from which we derive the sufficient condition on ta2t_{a_{2}} to be s=ta1ta2ηs4(𝑾0)+8Bαs=1ta11ηsγ~\sum_{s=t_{a_{1}}}^{t_{a_{2}}}\eta_{s}\geq\frac{4\mathcal{L}({\bm{W}}_{0})+8B\alpha\sum_{s=1}^{t_{a_{1}}-1}\eta_{s}}{\gamma\tilde{\mathcal{L}}}.
Convergence of $\frac{{\mathcal{G}}({\bm{W}}_{t})}{\mathcal{L}({\bm{W}}_{t})}$. We follow the same arguments as in the proof for SignGD (Theorem C.4) to conclude that

𝒢(𝑾t)(𝑾t)1eγ4s=ta2t1ηs.\displaystyle\frac{{\mathcal{G}}({\bm{W}}_{t})}{\mathcal{L}({\bm{W}}_{t})}\geq 1-e^{-\frac{\gamma}{4}\sum_{s=t_{a_{2}}}^{t-1}\eta_{s}}. (27)

We note that ta2t_{a_{2}} satisfies the assumptions in Lemma D.6 and Lemma D.7.
Margin Convergence. Finally, we combine Lemma D.6, Lemma D.7, and (27) to obtain

|mini[n],cyi(𝒆yi𝒆c)T𝑾t𝒉i𝑾tmaxγ|\displaystyle|\frac{\min_{i\in[n],c\neq y_{i}}(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}{\bm{W}}_{t}\bm{h}_{i}}{{\left\|{\bm{W}}_{t}\right\|_{\max}}}-\gamma| 𝒪(s=ta2t1ηseγ4τ=ta2s1ητ+s=0ta21ηs+ds=ta2t1ηs3/2+ds=ta2t1ηs2s=0t1ηs)\displaystyle\leq\mathcal{O}(\frac{\sum_{s=t_{a_{2}}}^{t-1}\eta_{s}e^{-\frac{\gamma}{4}\sum_{\tau=t_{a_{2}}}^{s-1}\eta_{\tau}}+\sum_{s=0}^{t_{a_{2}}-1}\eta_{s}+d\sum_{s=t_{a_{2}}}^{t-1}\eta_{s}^{3/2}+d\sum_{s=t_{a_{2}}}^{t-1}\eta_{s}^{2}}{\sum_{s=0}^{t-1}\eta_{s}})
𝒪(s=ta2t1ηseγ4τ=ta2s1ητ+s=0ta21ηs+ds=ta2t1ηs3/2s=0t1ηs)\displaystyle\leq\mathcal{O}(\frac{\sum_{s=t_{a_{2}}}^{t-1}\eta_{s}e^{-\frac{\gamma}{4}\sum_{\tau=t_{a_{2}}}^{s-1}\eta_{\tau}}+\sum_{s=0}^{t_{a_{2}}-1}\eta_{s}+d\sum_{s=t_{a_{2}}}^{t-1}\eta_{s}^{3/2}}{\sum_{s=0}^{t-1}\eta_{s}})

This concludes the proof. ∎

Similar to the case of SignGD, we can derive the margin convergence rates for Adam.

Corollary D.9.

Consider a learning-rate schedule of the form $\eta_{t}=\Theta(\frac{1}{t^{a}})$ with $a\in(0,1]$. Under the same setting as Theorem D.8, we have for Adam

|mini[n],cyi(𝒆yi𝒆c)T𝑾t𝒉i𝑾tmaxγ|={𝒪(dt13a2+nd2(1a)a+n(𝑾0)+[log(1/ω)]1at1a)ifa<23𝒪(dlog(t)+nd+n(𝑾0)+[log(1/ω)]1/3t1/3)ifa=23𝒪(d+nd2(1a)a+n(𝑾0)+[log(1/ω)]1at1a)if23<a<1𝒪(d+nlog(d)+n(𝑾0)+loglog(1/ω)logt)ifa=1|\frac{\min_{i\in[n],c\neq y_{i}}(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}{\bm{W}}_{t}\bm{h}_{i}}{{\left\|{\bm{W}}_{t}\right\|_{\max}}}-\gamma|=\left\{\begin{array}[]{ll}\mathcal{O}(\frac{dt^{1-\frac{3a}{2}}+nd^{\frac{2(1-a)}{a}}+n\mathcal{L}({\bm{W}}_{0})+[\log(1/\omega)]^{1-a}}{t^{1-a}})&\text{if}\quad a<\frac{2}{3}\\ \mathcal{O}(\frac{d\log(t)+nd+n\mathcal{L}({\bm{W}}_{0})+[\log(1/\omega)]^{1/3}}{t^{1/3}})&\text{if}\quad a=\frac{2}{3}\\ \mathcal{O}(\frac{d+nd^{\frac{2(1-a)}{a}}+n\mathcal{L}({\bm{W}}_{0})+[\log(1/\omega)]^{1-a}}{t^{1-a}})&\text{if}\quad\frac{2}{3}<a<1\\ \mathcal{O}(\frac{d+n\log(d)+n\mathcal{L}({\bm{W}}_{0})+\log\log(1/\omega)}{\log t})&\text{if}\quad a=1\end{array}\right.
Proof.

Recall that $t_{a_{1}}=\Theta(d^{2/a})=C_{a_{1}}d^{2/a}$, and the condition on $t_{a_{2}}$ is $\frac{2\mathcal{L}({\bm{W}}_{0})+4B\alpha\sum_{s=1}^{t_{a_{1}}-1}\eta_{s}}{\gamma\sum_{s=t_{a_{1}}}^{t_{a_{2}}}\eta_{s}}\leq\frac{\tilde{\mathcal{L}}}{2}$, where $\tilde{\mathcal{L}}=\min\{\frac{\log 2}{n},\frac{1}{\sqrt{4B^{2}+\alpha_{V}\eta_{0}}}\}$. Then, we apply integral approximations; the rest of the proof can be found in Zhang et al. (2024, Corollary 4.7 and Lemma C.1). ∎
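To illustrate Theorem D.8 and Corollary D.9 empirically, the sketch below runs the Adam variant analyzed above (update $\bm{\Delta}_{t}=\bm{M}_{t}/\sqrt{{\bm{V}}_{t}}$, no bias correction and no stability constant) with a decaying step size $\eta_{t}=\eta_{0}/(t+1)^{a}$ on synthetic linearly separable multiclass data, and tracks the loss together with the normalized margin $\min_{i,c\neq y_{i}}(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}{\bm{W}}_{t}\bm{h}_{i}/{\left\|{\bm{W}}_{t}\right\|_{\max}}$. The data-generating procedure and all hyperparameters are illustrative assumptions, not the paper's experiments.

```python
import numpy as np

rng = np.random.default_rng(1)
n, d, k = 60, 10, 3
# Synthetic separable data: well-separated class means plus small noise.
means = 5.0 * rng.normal(size=(k, d))
y = rng.integers(0, k, size=n)
H = means[y] + 0.1 * rng.normal(size=(n, d))

def softmax(Z):
    Z = Z - Z.max(axis=-1, keepdims=True)
    E = np.exp(Z)
    return E / E.sum(axis=-1, keepdims=True)

def ce_loss_and_grad(W):
    S = softmax(H @ W.T)                          # (n, k)
    loss = -np.log(S[np.arange(n), y]).mean()
    E = np.eye(k)[y]                              # one-hot labels
    grad = -(E - S).T @ H / n                     # (k, d)
    return loss, grad

def normalized_margin(W):
    logits = H @ W.T
    margins = logits[np.arange(n), y][:, None] - logits
    margins[np.arange(n), y] = np.inf             # ignore c = y_i
    return margins.min() / np.abs(W).max()

beta1, beta2, eta0, a = 0.9, 0.99, 0.1, 0.9
W = np.zeros((k, d))
M = np.zeros_like(W)
V = np.zeros_like(W)
for t in range(20001):
    loss, g = ce_loss_and_grad(W)
    M = beta1 * M + (1.0 - beta1) * g
    V = beta2 * V + (1.0 - beta2) * g**2
    # Adam step without bias correction or epsilon, matching the analyzed update;
    # assumes no gradient entry is exactly zero (cf. Assumption 3.2).
    W = W - eta0 / (t + 1.0)**a * M / np.sqrt(V)
    if t % 5000 == 0:
        print(f"t={t:6d}  loss={loss:.3e}  margin/max-norm={normalized_margin(W):.4f}")
```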

Lemma D.10.

For any 𝐯,𝐯,𝐪,𝐪k{\bm{v}},{\bm{v}}^{\prime},{\bm{q}},{\bm{q}}^{\prime}\in\mathbb{R}^{k} and c[k]c\in[k], the following inequalities hold:

  1. (i) $\bigl{|}\frac{\mathbb{S}_{c}({\bm{v}}^{\prime})}{\mathbb{S}_{c}({\bm{v}})}-1\bigr{|}\leq e^{2\lVert{\bm{v}}-{\bm{v}}^{\prime}\rVert_{\infty}}-1$

  2. (ii) $\bigl{|}\frac{1-\mathbb{S}_{c}({\bm{v}}^{\prime})}{1-\mathbb{S}_{c}({\bm{v}})}-1\bigr{|}\leq e^{2\lVert{\bm{v}}-{\bm{v}}^{\prime}\rVert_{\infty}}-1$

  3. (iii) $\bigl{|}\frac{\mathbb{S}_{c}({\bm{v}}^{\prime})\mathbb{S}_{c}({\bm{q}}^{\prime})}{\mathbb{S}_{c}({\bm{v}})\mathbb{S}_{c}({\bm{q}})}-1\bigr{|}\leq e^{2(\lVert{\bm{v}}^{\prime}-{\bm{v}}\rVert_{\infty}+\lVert{\bm{q}}^{\prime}-{\bm{q}}\rVert_{\infty})}-1$

  4. (iv) $\bigl{|}\frac{\mathbb{S}_{c}({\bm{v}}^{\prime})(1-\mathbb{S}_{c}({\bm{q}}^{\prime}))}{\mathbb{S}_{c}({\bm{v}})(1-\mathbb{S}_{c}({\bm{q}}))}-1\bigr{|}\leq e^{2(\lVert{\bm{v}}^{\prime}-{\bm{v}}\rVert_{\infty}+\lVert{\bm{q}}^{\prime}-{\bm{q}}\rVert_{\infty})}-1$

  5. (v) $\bigl{|}\frac{(1-\mathbb{S}_{c}({\bm{v}}^{\prime}))(1-\mathbb{S}_{c}({\bm{q}}^{\prime}))}{(1-\mathbb{S}_{c}({\bm{v}}))(1-\mathbb{S}_{c}({\bm{q}}))}-1\bigr{|}\leq e^{2(\lVert{\bm{v}}^{\prime}-{\bm{v}}\rVert_{\infty}+\lVert{\bm{q}}^{\prime}-{\bm{q}}\rVert_{\infty})}-1$

Proof.

We prove each inequality:

(i) First, observe that

|𝕊c(𝒗)𝕊c(𝒗)1|\displaystyle|\frac{\mathbb{S}_{c}({\bm{v}}^{\prime})}{\mathbb{S}_{c}({\bm{v}})}-1| =|evcevci[k]evii[k]evi1|\displaystyle=|\frac{e^{v^{\prime}_{c}}}{e^{v_{c}}}\frac{\sum_{i\in[k]}e^{v_{i}}}{\sum_{i\in[k]}e^{v^{\prime}_{i}}}-1|
=|i[k]evc+vii[k]evc+vii[k]evc+vi|\displaystyle=|\frac{\sum_{i\in[k]}e^{v^{\prime}_{c}+v_{i}}-\sum_{i\in[k]}e^{v_{c}+v^{\prime}_{i}}}{\sum_{i\in[k]}e^{v_{c}+v^{\prime}_{i}}}|
i[k]|evc+vievc+vi|i[k]evc+vi\displaystyle\leq\frac{\sum_{i\in[k]}|e^{v^{\prime}_{c}+v_{i}}-e^{v_{c}+v^{\prime}_{i}}|}{\sum_{i\in[k]}e^{v_{c}+v^{\prime}_{i}}}

For any i[k]i\in[k], we have |evc+vievc+vi|evc+vi=|evcvc+vivi1|e|vcvc+vivi|1e2𝒗𝒗1\frac{|e^{v^{\prime}_{c}+v_{i}}-e^{v_{c}+v^{\prime}_{i}}|}{e^{v_{c}+v^{\prime}_{i}}}=|e^{v^{\prime}_{c}-v_{c}+v_{i}-v^{\prime}_{i}}-1|\leq e^{|v^{\prime}_{c}-v_{c}+v_{i}-v^{\prime}_{i}|}-1\leq e^{2\lVert{\bm{v}}-{\bm{v}}^{\prime}\rVert_{\infty}}-1. This implies i[k]|evc+vievc+vi|(e2𝒗𝒗1)i[k]evc+vi\sum_{i\in[k]}|e^{v^{\prime}_{c}+v_{i}}-e^{v_{c}+v^{\prime}_{i}}|\leq\bigl{(}e^{2\lVert{\bm{v}}-{\bm{v}}^{\prime}\rVert_{\infty}}-1\bigr{)}\sum_{i\in[k]}e^{v_{c}+v^{\prime}_{i}}, from which we obtain the desired inequality.

(ii) For the second inequality:

|1𝕊c(𝒗)1𝕊c(𝒗)1|\displaystyle|\frac{1-\mathbb{S}_{c}({\bm{v}}^{\prime})}{1-\mathbb{S}_{c}({\bm{v}})}-1| =|1evci[k]evi1evci[k]evi1|\displaystyle=|\frac{1-\frac{e^{v^{\prime}_{c}}}{\sum_{i\in[k]}e^{v^{\prime}_{i}}}}{1-\frac{e^{v_{c}}}{\sum_{i\in[k]}e^{v_{i}}}}-1|
=|(j[k],jcevj)(i[k]evi)(j[k],jcevj)(i[k]evi)1|\displaystyle=|\frac{(\sum_{j\in[k],j\neq c}e^{v^{\prime}_{j}})(\sum_{i\in[k]}e^{v_{i}})}{(\sum_{j\in[k],j\neq c}e^{v_{j}})(\sum_{i\in[k]}e^{v^{\prime}_{i}})}-1|
=|j[k],jci[k][evj+vievj+vi]j[k],jci[k]evj+vi|\displaystyle=|\frac{\sum_{j\in[k],j\neq c}\sum_{i\in[k]}\bigl{[}e^{v^{\prime}_{j}+v_{i}}-e^{v_{j}+v^{\prime}_{i}}\bigl{]}}{\sum_{j\in[k],j\neq c}\sum_{i\in[k]}e^{v_{j}+v^{\prime}_{i}}}|
j[k],jci[k]|evj+vievj+vi|j[k],jci[k]evj+vi\displaystyle\leq\frac{\sum_{j\in[k],j\neq c}\sum_{i\in[k]}|e^{v^{\prime}_{j}+v_{i}}-e^{v_{j}+v^{\prime}_{i}}|}{\sum_{j\in[k],j\neq c}\sum_{i\in[k]}e^{v_{j}+v^{\prime}_{i}}}

For any j[k]j\in[k], jcj\neq c, and i[k]i\in[k], we have |evj+vievj+vi|evj+vie2𝒗𝒗1\frac{|e^{v_{j}^{\prime}+v_{i}}-e^{v_{j}+v_{i}^{\prime}}|}{e^{v_{j}+v_{i}^{\prime}}}\leq e^{2\lVert{\bm{v}}-{\bm{v}}^{\prime}\rVert_{\infty}}-1. This implies that j[k],jci[k]|evj+vievj+vi|(e2𝒗𝒗1)j[k],jci[k]evj+vi\sum_{j\in[k],j\neq c}\sum_{i\in[k]}|e^{v_{j}^{\prime}+v_{i}}-e^{v_{j}+v_{i}^{\prime}}|\leq(e^{2\lVert{\bm{v}}-{\bm{v}}^{\prime}\rVert_{\infty}}-1)\sum_{j\in[k],j\neq c}\sum_{i\in[k]}e^{v_{j}+v_{i}^{\prime}}, from which the result follows.

(iii) For the third inequality:

|𝕊c(𝒗)𝕊c(𝒒)𝕊c(𝒗)𝕊c(𝒒)1|\displaystyle|\frac{\mathbb{S}_{c}({\bm{v}}^{\prime})\mathbb{S}_{c}({\bm{q}}^{\prime})}{\mathbb{S}_{c}({\bm{v}})\mathbb{S}_{c}({\bm{q}})}-1| =|evci[k]evieqci[k]eqievci[k]evieqci[k]eqi1|\displaystyle=|\frac{\frac{e^{v^{\prime}_{c}}}{\sum_{i\in[k]}e^{v^{\prime}_{i}}}\frac{e^{q^{\prime}_{c}}}{\sum_{i\in[k]}e^{q^{\prime}_{i}}}}{\frac{e^{v_{c}}}{\sum_{i\in[k]}e^{v_{i}}}\frac{e^{q_{c}}}{\sum_{i\in[k]}e^{q_{i}}}}-1|
=|evceqci[k]evij[k]eqjevceqci[k]evij[k]eqjevceqci[k]evij[k]eqjevceqci[k]evij[k]eqj|\displaystyle=|\frac{e^{v^{\prime}_{c}}e^{q^{\prime}_{c}}\sum_{i\in[k]}e^{v_{i}}\sum_{j\in[k]}e^{q_{j}}}{e^{v_{c}}e^{q_{c}}\sum_{i\in[k]}e^{v^{\prime}_{i}}\sum_{j\in[k]}e^{q^{\prime}_{j}}}-\frac{e^{v_{c}}e^{q_{c}}\sum_{i\in[k]}e^{v^{\prime}_{i}}\sum_{j\in[k]}e^{q^{\prime}_{j}}}{e^{v_{c}}e^{q_{c}}\sum_{i\in[k]}e^{v^{\prime}_{i}}\sum_{j\in[k]}e^{q^{\prime}_{j}}}|
=|i[k]j[k][evc+vi+qc+qjevc+vi+qc+qj]i[k]j[k]evc+vi+qc+qj|\displaystyle=|\frac{\sum_{i\in[k]}\sum_{j\in[k]}\bigl{[}e^{v^{\prime}_{c}+v_{i}+q^{\prime}_{c}+q_{j}}-e^{v_{c}+v^{\prime}_{i}+q_{c}+q^{\prime}_{j}}\bigr{]}}{\sum_{i\in[k]}\sum_{j\in[k]}e^{v_{c}+v^{\prime}_{i}+q_{c}+q^{\prime}_{j}}}|
i[k]j[k]|evc+vi+qc+qjevc+vi+qc+qj|i[k]j[k]evc+vi+qc+qj\displaystyle\leq\frac{\sum_{i\in[k]}\sum_{j\in[k]}|e^{v^{\prime}_{c}+v_{i}+q^{\prime}_{c}+q_{j}}-e^{v_{c}+v^{\prime}_{i}+q_{c}+q^{\prime}_{j}}|}{\sum_{i\in[k]}\sum_{j\in[k]}e^{v_{c}+v^{\prime}_{i}+q_{c}+q^{\prime}_{j}}}

For any i[k]i\in[k] and j[k]j\in[k], |evc+vi+qc+qjevc+vi+qc+qj|evc+vi+qc+qj=|evcvc+vivi+qcqc+qjqj1|e|vcvc|+|vivi|+|qcqc|+|qjqj|1e2(𝒗𝒗+𝒒𝒒)1\frac{|e^{v^{\prime}_{c}+v_{i}+q^{\prime}_{c}+q_{j}}-e^{v_{c}+v^{\prime}_{i}+q_{c}+q^{\prime}_{j}}|}{e^{v_{c}+v^{\prime}_{i}+q_{c}+q^{\prime}_{j}}}=|e^{v^{\prime}_{c}-v_{c}+v_{i}-v^{\prime}_{i}+q^{\prime}_{c}-q_{c}+q_{j}-q^{\prime}_{j}}-1|\leq e^{|v^{\prime}_{c}-v_{c}|+|v_{i}-v^{\prime}_{i}|+|q^{\prime}_{c}-q_{c}|+|q_{j}-q^{\prime}_{j}|}-1\leq e^{2(\lVert{\bm{v}}^{\prime}-{\bm{v}}\rVert_{\infty}+\lVert{\bm{q}}^{\prime}-{\bm{q}}\rVert_{\infty})}-1. Then, rearranging and summing over ii and jj leads to the result.

(iv) For the fourth inequality:

|𝕊c(𝒗)(1𝕊c(𝒒))𝕊c(𝒗)(1𝕊c(𝒒))1|\displaystyle|\frac{\mathbb{S}_{c}({\bm{v}}^{\prime})(1-\mathbb{S}_{c}({\bm{q}}^{\prime}))}{\mathbb{S}_{c}({\bm{v}})(1-\mathbb{S}_{c}({\bm{q}}))}-1| =|evcs[k]evs(1eqct[k]eqt)evcs[k]evs(1eqct[k]eqt)1|\displaystyle=|\frac{\frac{e^{v^{\prime}_{c}}}{\sum_{s\in[k]}e^{v^{\prime}_{s}}}(1-\frac{e^{q^{\prime}_{c}}}{\sum_{t\in[k]}e^{q^{\prime}_{t}}})}{\frac{e^{v_{c}}}{\sum_{s\in[k]}e^{v_{s}}}(1-\frac{e^{q_{c}}}{\sum_{t\in[k]}e^{q_{t}}})}-1|
\displaystyle=|\frac{\frac{e^{v^{\prime}_{c}}}{\sum_{s\in[k]}e^{v^{\prime}_{s}}}\frac{\sum_{i\in[k],i\neq c}e^{q^{\prime}_{i}}}{\sum_{t\in[k]}e^{q^{\prime}_{t}}}}{\frac{e^{v_{c}}}{\sum_{s\in[k]}e^{v_{s}}}\frac{\sum_{i\in[k],i\neq c}e^{q_{i}}}{\sum_{t\in[k]}e^{q_{t}}}}-1|
=|i[k],ict[k]s[k]evc+qi+vs+qti[k],ict[k]s[k]evc+qi+vs+qt1|\displaystyle=|\frac{\sum_{i\in[k],i\neq c}\sum_{t\in[k]}\sum_{s\in[k]}e^{v^{\prime}_{c}+q^{\prime}_{i}+v_{s}+q_{t}}}{\sum_{i\in[k],i\neq c}\sum_{t\in[k]}\sum_{s\in[k]}e^{v_{c}+q_{i}+v^{\prime}_{s}+q^{\prime}_{t}}}-1|
i[k],ict[k]s[k]|evc+qi+vs+qtevc+qi+vs+qt|i[k],ict[k]s[k]evc+qi+vs+qt\displaystyle\leq\frac{\sum_{i\in[k],i\neq c}\sum_{t\in[k]}\sum_{s\in[k]}|e^{v^{\prime}_{c}+q^{\prime}_{i}+v_{s}+q_{t}}-e^{v_{c}+q_{i}+v^{\prime}_{s}+q^{\prime}_{t}}|}{\sum_{i\in[k],i\neq c}\sum_{t\in[k]}\sum_{s\in[k]}e^{v_{c}+q_{i}+v^{\prime}_{s}+q^{\prime}_{t}}}

For each i[k],ici\in[k],i\neq c, s[k]s\in[k], and t[k]t\in[k], we obtain |evc+qi+vs+qtevc+qi+vs+qt|evc+qi+vs+qte2(𝒗𝒗+𝒒𝒒)1\frac{|e^{v^{\prime}_{c}+q^{\prime}_{i}+v_{s}+q_{t}}-e^{v_{c}+q_{i}+v^{\prime}_{s}+q^{\prime}_{t}}|}{e^{v_{c}+q_{i}+v^{\prime}_{s}+q^{\prime}_{t}}}\leq e^{2(\lVert{\bm{v}}^{\prime}-{\bm{v}}\rVert_{\infty}+\lVert{\bm{q}}^{\prime}-{\bm{q}}\rVert_{\infty})}-1. Then, rearranging and summing over ii, ss, and tt leads to the result.

(v) Finally, for the fifth inequality:

|(1𝕊c(𝒗))(1𝕊c(𝒒))(1𝕊c(𝒗))(1𝕊c(𝒒))1|\displaystyle|\frac{(1-\mathbb{S}_{c}({\bm{v}}^{\prime}))(1-\mathbb{S}_{c}({\bm{q}}^{\prime}))}{(1-\mathbb{S}_{c}({\bm{v}}))(1-\mathbb{S}_{c}({\bm{q}}))}-1| =|(1evcs[k]evs)(1eqct[k]eqt)(1evcs[k]evs)(1eqct[k]eqt)1|\displaystyle=|\frac{(1-\frac{e^{v^{\prime}_{c}}}{\sum_{s\in[k]}e^{v^{\prime}_{s}}})(1-\frac{e^{q^{\prime}_{c}}}{\sum_{t\in[k]}e^{q^{\prime}_{t}}})}{(1-\frac{e^{v_{c}}}{\sum_{s\in[k]}e^{v_{s}}})(1-\frac{e^{q_{c}}}{\sum_{t\in[k]}e^{q_{t}}})}-1|
=|j[k],jcevjs[k]evsi[k],iceqit[k]eqtj[k],jcevjs[k]evsi[k],iceqit[k]eqt1|\displaystyle=|\frac{\frac{\sum_{j\in[k],j\neq c}e^{v^{\prime}_{j}}}{\sum_{s\in[k]}e^{v^{\prime}_{s}}}\frac{\sum_{i\in[k],i\neq c}e^{q^{\prime}_{i}}}{\sum_{t\in[k]}e^{q^{\prime}_{t}}}}{\frac{\sum_{j\in[k],j\neq c}e^{v_{j}}}{\sum_{s\in[k]}e^{v_{s}}}\frac{\sum_{i\in[k],i\neq c}e^{q_{i}}}{\sum_{t\in[k]}e^{q_{t}}}}-1|
=|j[k],jci[k],ict[k]s[k]evj+qi+vs+qtj[k],jci[k],ict[k]s[k]evj+qi+vs+qt1|\displaystyle=|\frac{\sum_{j\in[k],j\neq c}\sum_{i\in[k],i\neq c}\sum_{t\in[k]}\sum_{s\in[k]}e^{v^{\prime}_{j}+q^{\prime}_{i}+v_{s}+q_{t}}}{\sum_{j\in[k],j\neq c}\sum_{i\in[k],i\neq c}\sum_{t\in[k]}\sum_{s\in[k]}e^{v_{j}+q_{i}+v^{\prime}_{s}+q^{\prime}_{t}}}-1|
j[k],jci[k],ict[k]s[k]|evj+qi+vs+qtevj+qi+vs+qt|j[k],jci[k],ict[k]s[k]evj+qi+vs+qt.\displaystyle\leq\frac{\sum_{j\in[k],j\neq c}\sum_{i\in[k],i\neq c}\sum_{t\in[k]}\sum_{s\in[k]}|e^{v^{\prime}_{j}+q^{\prime}_{i}+v_{s}+q_{t}}-e^{v_{j}+q_{i}+v^{\prime}_{s}+q^{\prime}_{t}}|}{\sum_{j\in[k],j\neq c}\sum_{i\in[k],i\neq c}\sum_{t\in[k]}\sum_{s\in[k]}e^{v_{j}+q_{i}+v^{\prime}_{s}+q^{\prime}_{t}}}.

For each j[k]j\in[k] (jcj\neq c), i[k]i\in[k] (ici\neq c), s[k]s\in[k], and t[k]t\in[k], we have

|evj+qi+vs+qtevj+qi+vs+qt|evj+qi+vs+qt\displaystyle\frac{|e^{v^{\prime}_{j}+q^{\prime}_{i}+v_{s}+q_{t}}-e^{v_{j}+q_{i}+v^{\prime}_{s}+q^{\prime}_{t}}|}{e^{v_{j}+q_{i}+v^{\prime}_{s}+q^{\prime}_{t}}} =|evjvj+qiqi+vsvs+qtqt1|\displaystyle=|e^{v^{\prime}_{j}-v_{j}+q^{\prime}_{i}-q_{i}+v_{s}-v^{\prime}_{s}+q_{t}-q^{\prime}_{t}}-1|
e|vjvj|+|qiqi|+|vsvs|+|qtqt|1\displaystyle\leq e^{|v^{\prime}_{j}-v_{j}|+|q^{\prime}_{i}-q_{i}|+|v_{s}-v^{\prime}_{s}|+|q_{t}-q^{\prime}_{t}|}-1
e2(𝒗𝒗+𝒒𝒒)1\displaystyle\leq e^{2(\lVert{\bm{v}}^{\prime}-{\bm{v}}\rVert_{\infty}+\lVert{\bm{q}}^{\prime}-{\bm{q}}\rVert_{\infty})}-1

Then, rearranging and summing over jj, ii, ss, and tt leads to the result. ∎
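The ratio bounds of Lemma D.10 are straightforward to stress-test numerically. The sketch below draws random logit vectors and checks (i)–(v); the sample size, dimension, and tolerance are arbitrary illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(2)

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

k = 6
for _ in range(1000):
    v, vp, q, qp = (rng.normal(size=k) for _ in range(4))
    c = rng.integers(k)
    s, sp = softmax(v)[c], softmax(vp)[c]
    r, rp = softmax(q)[c], softmax(qp)[c]
    dv, dq = np.abs(v - vp).max(), np.abs(q - qp).max()
    tol = 1e-12
    assert abs(sp / s - 1) <= np.exp(2 * dv) - 1 + tol                      # (i)
    assert abs((1 - sp) / (1 - s) - 1) <= np.exp(2 * dv) - 1 + tol          # (ii)
    bound = np.exp(2 * (dv + dq)) - 1 + tol
    assert abs(sp * rp / (s * r) - 1) <= bound                              # (iii)
    assert abs(sp * (1 - rp) / (s * (1 - r)) - 1) <= bound                  # (iv)
    assert abs((1 - sp) * (1 - rp) / ((1 - s) * (1 - r)) - 1) <= bound      # (v)
print("Lemma D.10 checks passed on random instances")
```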

Appendix E Normalized $p$-Norm Steepest Descent

Let \|\cdot\| denote (entrywise) pp-norm with p1p\geq 1 and \|\cdot\|_{\star} denote its dual.

Consider normalized steepest-descent with respect to the pp-norm:

𝑾t+1=𝑾tηt𝚫t,where𝚫t:=argmax𝚫1(𝑾t),𝚫.\displaystyle{\bm{W}}_{t+1}={\bm{W}}_{t}-\eta_{t}\bm{\Delta}_{t},\qquad\text{where}~{}~{}\bm{\Delta}_{t}:=\arg\max_{\|\bm{\Delta}\|\leq 1}\langle\nabla\mathcal{L}({\bm{W}}_{t}),\bm{\Delta}\rangle\,. (28)
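For $p>1$, the argmax in (28) has the standard entrywise closed form $\bm{\Delta}_{t}=\operatorname{sign}(\nabla\mathcal{L}({\bm{W}}_{t}))\,|\nabla\mathcal{L}({\bm{W}}_{t})|^{q-1}/\|\nabla\mathcal{L}({\bm{W}}_{t})\|_{q}^{q-1}$ with $q=p/(p-1)$ the dual exponent, which recovers SignGD as $p\to\infty$ and normalized gradient descent at $p=2$. A minimal sketch of one such step is given below; the function name and the dummy gradient are ours, and $p=1$ (whose maximizer is a coordinate vertex) is not handled.

```python
import numpy as np

def steepest_descent_step(W, grad, eta, p):
    """One step of normalized steepest descent w.r.t. the entrywise p-norm, eq. (28).

    For p > 1, the maximizer of <grad, Delta> over ||Delta||_p <= 1 is
        Delta = sign(grad) * |grad|**(q - 1) / ||grad||_q**(q - 1),  q = p / (p - 1),
    which reduces to SignGD as p -> inf and to normalized GD at p = 2.
    """
    if np.isinf(p):
        delta = np.sign(grad)                                  # SignGD
    else:
        q = p / (p - 1.0)                                      # dual exponent
        num = np.sign(grad) * np.abs(grad) ** (q - 1.0)
        delta = num / np.linalg.norm(grad.ravel(), q) ** (q - 1.0)
    return W - eta * delta

# Illustrative usage with a dummy gradient: the recovered Delta_t has unit p-norm.
rng = np.random.default_rng(3)
W, g = rng.normal(size=(3, 4)), rng.normal(size=(3, 4))
for p in (2.0, 3.0, np.inf):
    delta = W - steepest_descent_step(W, g, eta=1.0, p=p)
    print(p, np.round(np.linalg.norm(delta.ravel(), p), 6))    # prints 1.0
```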

Lemma E.1 generalizes Lemma B.1.

Lemma E.1 (𝒢(𝑾){\mathcal{G}}({\bm{W}}) as proxy to the loss-gradient norm).

Define the margin $\gamma$ with respect to the $p$-norm for $p\geq 1$, and denote by $\|\cdot\|_{\star}$ its dual norm. Then, for any ${\bm{W}}$ it holds that

2B𝒢(𝑾)(𝑾)γ𝒢(𝑾).2B\cdot{\mathcal{G}}({\bm{W}})\geq\|{\nabla\mathcal{L}({\bm{W}})}\|_{\star}\geq\gamma\cdot{\mathcal{G}}({\bm{W}})\,.
Proof.

First, we prove the lower bound. By duality and direct application of (12)

(𝑾)\displaystyle\|{\nabla\mathcal{L}({\bm{W}})}\|_{\star} =max𝑨1𝑨,(𝑾)\displaystyle=\max_{\|{\bm{A}}\|\leq 1}\langle\bm{A},-\nabla\mathcal{L}({\bm{W}})\rangle
max𝑨11ni[n](1siyi)mincyi(𝒆yi𝒆c)T𝑨𝒉i\displaystyle\geq\max_{\|\bm{A}\|\leq 1}\frac{1}{n}\sum_{i\in[n]}(1-s_{iy_{i}})\min_{c\neq y_{i}}(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}\bm{A}\bm{h}_{i}
1ni[n](1siyi)max𝑨1mini[n],cyi(𝒆yi𝒆c)T𝑨𝒉i.\displaystyle\geq\frac{1}{n}\sum_{i\in[n]}(1-s_{iy_{i}})\cdot\max_{\|\bm{A}\|\leq 1}\min_{i\in[n],c\neq y_{i}}(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}\bm{A}\bm{h}_{i}.

The upper bound follows because for pp-norms with p1p\geq 1

(𝑾)(𝑾)sum\|\nabla\mathcal{L}({\bm{W}})\|_{\star}\leq{\left\|\nabla\mathcal{L}({\bm{W}})\right\|_{\rm{sum}}}

and we can use the bound for (𝑾)sum{\left\|\nabla\mathcal{L}({\bm{W}})\right\|_{\rm{sum}}}. ∎

A direct consequence of Lemma E.1 is the following.

Lemma E.2.

For any ${\bm{W}},{\bm{W}}_{0}\in\mathbb{R}^{k\times d}$, if $\mathcal{L}({\bm{W}})$ is convex, we have

|(𝑾)(𝑾0)|2B𝑾𝑾0.\displaystyle|\mathcal{L}({\bm{W}})-\mathcal{L}({\bm{W}}_{0})|\leq 2B\lVert{\bm{W}}-{\bm{W}}_{0}\rVert.
Proof.

We replace the term ${\left\|\nabla\mathcal{L}({\bm{W}}_{0})\right\|_{\rm{sum}}}{\left\|{\bm{W}}_{0}-{\bm{W}}\right\|_{\max}}$ in Lemma B.2 with $\lVert\nabla\mathcal{L}({\bm{W}}_{0})\rVert_{\star}\lVert{\bm{W}}_{0}-{\bm{W}}\rVert$. The rest of the proof follows the same steps as Lemma B.2. ∎

Lemma E.3 generalizes Lemma A.3.

Lemma E.3.

For any $\bm{s}\in\Delta^{k-1}$ in the probability simplex, any index $c\in[k]$, and any ${\bm{v}}\in\mathbb{R}^{k}$ it holds:

𝒗(diag(𝒔)𝒔𝒔)𝒗4𝒗2(1sc){\bm{v}}^{\top}\left(\operatorname{diag}(\bm{s})-\bm{s}\bm{s}^{\top}\right){\bm{v}}\leq 4\,\|{\bm{v}}\|^{2}\,(1-s_{c})
Proof.

By norm duality (Hölder's inequality),

𝒗(diag(𝒔)𝒔𝒔)𝒗\displaystyle{\bm{v}}^{\top}\left(\operatorname{diag}(\bm{s})-\bm{s}\bm{s}^{\top}\right){\bm{v}} =vec(diag(𝒔)𝒔𝒔)vec(𝒗𝒗)\displaystyle=\operatorname{vec}\big{(}\operatorname{diag}(\bm{s})-\bm{s}\bm{s}^{\top}\big{)}^{\top}\operatorname{vec}\big{(}{\bm{v}}{\bm{v}}^{\top}\big{)}
vec(diag(𝒔)𝒔𝒔)vec(𝒗𝒗)\displaystyle\leq\|\operatorname{vec}\big{(}\operatorname{diag}(\bm{s})-\bm{s}\bm{s}^{\top}\big{)}\|_{\star}\|\operatorname{vec}\big{(}{\bm{v}}{\bm{v}}^{\top}\big{)}\|
𝒗2vec(diag(𝒔)𝒔𝒔).\displaystyle\leq\|{\bm{v}}\|^{2}\,\|\operatorname{vec}\big{(}\operatorname{diag}(\bm{s})-\bm{s}\bm{s}^{\top}\big{)}\|_{*}\,.

But,

vec(diag(𝒔)𝒔𝒔)vec(diag(𝒔)𝒔𝒔)1\displaystyle\|\operatorname{vec}\big{(}\operatorname{diag}(\bm{s})-\bm{s}\bm{s}^{\top}\big{)}\|_{\star}\leq\|\operatorname{vec}\big{(}\operatorname{diag}(\bm{s})-\bm{s}\bm{s}^{\top}\big{)}\|_{1}

and we can use Lemma A.3. ∎

When applying the above lemma to bound the Hessian term we also need to use the following:

𝚫𝒉𝚫𝒉,\displaystyle\|\bm{\Delta}\bm{h}\|\leq\|\bm{\Delta}\|\|\bm{h}\|_{\star}, (29)

which is true because, for $q=p/(p-1)$:

\displaystyle\|\bm{\Delta}\bm{h}\|^{p}=\|\bm{\Delta}\bm{h}\|_{p}^{p}=\sum_{j}|\bm{e}_{j}^{\top}\bm{\Delta}\bm{h}|^{p}\leq\sum_{j}\|\bm{e}_{j}^{\top}\bm{\Delta}\|_{p}^{p}\|\bm{h}\|_{q}^{p}=\|\bm{h}\|_{q}^{p}\sum_{ij}|\bm{\Delta}[i,j]|^{p}=\|\bm{h}\|_{q}^{p}\|\bm{\Delta}\|_{p}^{p} (30)
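A quick numerical check of (29), with the dual exponent $q=p/(p-1)$, on random matrices (the shapes and number of trials are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(4)
for _ in range(200):
    p = rng.uniform(1.1, 5.0)
    q = p / (p - 1.0)                            # dual exponent
    D = rng.normal(size=(4, 7))
    h = rng.normal(size=7)
    lhs = np.linalg.norm(D @ h, p)               # ||Delta h||_p
    rhs = np.linalg.norm(D.ravel(), p) * np.linalg.norm(h, q)
    assert lhs <= rhs + 1e-10
print("inequality (29) holds on all random instances")
```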

Appendix F Other multiclass loss functions

F.1 Exponential Loss

The multiclass exponential loss is given as

exp(𝑾):=1ni[n]cyiexp((𝒆yi𝒆c)𝑾𝒉i).\mathcal{L}_{\rm{exp}}({\bm{W}}):=\frac{1}{n}\sum_{i\in[n]}\sum_{c\neq y_{i}}\exp\left(-(\bm{e}_{y_{i}}-\bm{e}_{c})^{\top}{\bm{W}}\bm{h}_{i}\right)\,.

The gradient of exp(𝑾)\mathcal{L}_{\exp}({\bm{W}}) is

exp(𝑾)=1ni[n]cyiexp((𝒆yi𝒆c)T𝑾𝒉i)(𝒆yi𝒆c)𝒉iT.\displaystyle\nabla\mathcal{L}_{\exp}({\bm{W}})=\frac{1}{n}\sum_{i\in[n]}\sum_{c\neq y_{i}}-\exp(-(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}{\bm{W}}\bm{h}_{i})(\bm{e}_{y_{i}}-\bm{e}_{c})\bm{h}_{i}^{T}.

Thus, for any matrix 𝑨k×d\bm{A}\in\mathbb{R}^{k\times d}, we have

𝑨,exp(𝑾)=1ni[n]cyiexp((𝒆yi𝒆c)𝑾𝒉i)(𝒆yi𝒆c)𝑨𝒉i.\displaystyle\langle{\bm{A}},-\nabla\mathcal{L}_{\exp}({\bm{W}})\rangle=\frac{1}{n}\sum_{i\in[n]}\sum_{c\neq y_{i}}\exp\left(-(\bm{e}_{y_{i}}-\bm{e}_{c})^{\top}{\bm{W}}\bm{h}_{i}\right)\cdot\left(\bm{e}_{y_{i}}-\bm{e}_{c}\right)^{\top}{\bm{A}}\bm{h}_{i}\,.

This motivates us to define 𝒢(𝑾){\mathcal{G}}({\bm{W}}) as

𝒢exp(𝑾)=1ni[n]cyiexp((𝒆yi𝒆c)𝑾𝒉i),\displaystyle{\mathcal{G}}_{\exp}({\bm{W}})=\frac{1}{n}\sum_{i\in[n]}\sum_{c\neq y_{i}}\exp\left(-(\bm{e}_{y_{i}}-\bm{e}_{c})^{\top}{\bm{W}}\bm{h}_{i}\right),

from which we recognize that 𝒢exp(𝑾)=exp(𝑾){\mathcal{G}}_{\exp}({\bm{W}})=\mathcal{L}_{\exp}({\bm{W}}) and the rest follows.
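For concreteness, the sketch below evaluates $\mathcal{L}_{\exp}$ and $\nabla\mathcal{L}_{\exp}$ on random data and checks one gradient entry by a forward finite difference; the data, shapes, and seed are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(5)
n, d, k = 20, 6, 4
H = rng.normal(size=(n, d))
y = rng.integers(0, k, size=n)
W = 0.1 * rng.normal(size=(k, d))

def exp_loss_and_grad(W):
    logits = H @ W.T                                        # (n, k)
    m = logits[np.arange(n), y][:, None] - logits           # (e_{y_i} - e_c)^T W h_i
    E = np.exp(-m)
    E[np.arange(n), y] = 0.0                                # drop the c = y_i terms
    loss = E.sum() / n
    # gradient: -(1/n) sum_{i, c != y_i} exp(-m_ic) (e_{y_i} - e_c) h_i^T
    coef = -E.copy()                                        # weight multiplying row e_c
    coef[np.arange(n), y] = E.sum(axis=1)                   # weight multiplying row e_{y_i}
    grad = -(coef.T @ H) / n
    return loss, grad

loss, grad = exp_loss_and_grad(W)
eps = 1e-6
Wp = W.copy()
Wp[0, 0] += eps
fd = (exp_loss_and_grad(Wp)[0] - loss) / eps                # forward difference
print(grad[0, 0], fd)                                       # the two values should agree closely
```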

F.2 PairLogLoss

The PairLogLoss loss (Wang et al., 2021b) is given as

pll(𝑾):=1ni[n]cyilog(1+exp((𝒆yi𝒆c)𝑾𝒉i)).\mathcal{L}_{\rm{pll}}({\bm{W}}):=\frac{1}{n}\sum_{i\in[n]}\sum_{c\neq y_{i}}\log\left(1+\exp\left(-(\bm{e}_{y_{i}}-\bm{e}_{c})^{\top}{\bm{W}}\bm{h}_{i}\right)\right)\,.

Note that $\mathcal{L}_{\rm{pll}}=\frac{1}{n}\sum_{i\in[n]}\sum_{c\neq y_{i}}f\left((\bm{e}_{y_{i}}-\bm{e}_{c})^{\top}{\bm{W}}\bm{h}_{i}\right)$, where $f(t):=\log(1+e^{-t})$ denotes the logistic loss. Therefore, the Taylor expansion of PLL reads:

pll(𝑾+𝚫)\displaystyle\mathcal{L}_{\rm{pll}}({\bm{W}}+\bm{\Delta}) =(𝑾)+1ni[n]cyif((𝒆yi𝒆c)𝑾𝒉i)(𝒆yi𝒆c)𝚫𝒉i\displaystyle=\mathcal{L}({\bm{W}})+\frac{1}{n}\sum_{i\in[n]}\sum_{c\neq y_{i}}f^{\prime}\left((\bm{e}_{y_{i}}-\bm{e}_{c})^{\top}{\bm{W}}\bm{h}_{i}\right)\cdot(\bm{e}_{y_{i}}-\bm{e}_{c})^{\top}\bm{\Delta}\bm{h}_{i}
\displaystyle\qquad+\frac{1}{2n}\sum_{i\in[n]}\sum_{c\neq y_{i}}f^{\prime\prime}\left((\bm{e}_{y_{i}}-\bm{e}_{c})^{\top}{\bm{W}}\bm{h}_{i}\right)\cdot\bm{h}_{i}^{\top}\bm{\Delta}^{\top}(\bm{e}_{y_{i}}-\bm{e}_{c})(\bm{e}_{y_{i}}-\bm{e}_{c})^{\top}\bm{\Delta}\bm{h}_{i}+o\left(\|\bm{\Delta}\|^{2}\right)\,. (31)

From the above, the gradient of the PLL loss is:

pll(𝑾)\displaystyle\nabla\mathcal{L}_{\rm{pll}}({\bm{W}}) =1ni[n]cyif((𝒆yi𝒆c)𝑾𝒉i)(𝒆yi𝒆c)𝒉i\displaystyle=\frac{1}{n}\sum_{i\in[n]}\sum_{c\neq y_{i}}f^{\prime}\left((\bm{e}_{y_{i}}-\bm{e}_{c})^{\top}{\bm{W}}\bm{h}_{i}\right)\cdot\left(\bm{e}_{y_{i}}-\bm{e}_{c}\right)\bm{h}_{i}^{\top}
=1ni[n]cyiexp((𝒆yi𝒆c)𝑾𝒉i)1+exp((𝒆yi𝒆c)𝑾𝒉i)(𝒆yi𝒆c)𝒉i\displaystyle=\frac{1}{n}\sum_{i\in[n]}\sum_{c\neq y_{i}}\frac{-\exp\left(-(\bm{e}_{y_{i}}-\bm{e}_{c})^{\top}{\bm{W}}\bm{h}_{i}\right)}{1+\exp\left(-(\bm{e}_{y_{i}}-\bm{e}_{c})^{\top}{\bm{W}}\bm{h}_{i}\right)}\left(\bm{e}_{y_{i}}-\bm{e}_{c}\right)\bm{h}_{i}^{\top} (32)

Thus, for any matrix 𝑨k×d{\bm{A}}\in\mathbb{R}^{k\times d},

𝑨,pll(𝑾)=1ni[n]cyi|f((𝒆yi𝒆c)𝑾𝒉i)|(𝒆yi𝒆c)𝑨𝒉i.\displaystyle\langle{\bm{A}},-\nabla\mathcal{L}_{\rm{pll}}({\bm{W}})\rangle=\frac{1}{n}\sum_{i\in[n]}\sum_{c\neq y_{i}}|f^{\prime}\left((\bm{e}_{y_{i}}-\bm{e}_{c})^{\top}{\bm{W}}\bm{h}_{i}\right)|\cdot\left(\bm{e}_{y_{i}}-\bm{e}_{c}\right)^{\top}{\bm{A}}\bm{h}_{i}\,. (33)

This motivates us to define

\displaystyle{\mathcal{G}}_{\rm{pll}}({\bm{W}})=\frac{1}{n}\sum_{i\in[n]}\sum_{c\neq y_{i}}\left|f^{\prime}\left((\bm{e}_{y_{i}}-\bm{e}_{c})^{\top}{\bm{W}}\bm{h}_{i}\right)\right|=\frac{1}{n}\sum_{i\in[n]}\sum_{c\neq y_{i}}\frac{\exp\left(-(\bm{e}_{y_{i}}-\bm{e}_{c})^{\top}{\bm{W}}\bm{h}_{i}\right)}{1+\exp\left(-(\bm{e}_{y_{i}}-\bm{e}_{c})^{\top}{\bm{W}}\bm{h}_{i}\right)} (34)
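The quantities in (32)–(34) are straightforward to compute directly. The sketch below evaluates $\mathcal{L}_{\rm{pll}}$, $\nabla\mathcal{L}_{\rm{pll}}$, and ${\mathcal{G}}_{\rm{pll}}$ on random data and numerically checks the bound $\|\nabla\mathcal{L}_{\rm{pll}}\|_{\rm{sum}}\leq 2B\,{\mathcal{G}}_{\rm{pll}}$ and the inequalities of Lemma F.2(i) established below; the data and $B=\max_{i}\|\bm{h}_{i}\|_{1}$ are illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(6)
n, d, k = 25, 6, 4
H = rng.normal(size=(n, d))
y = rng.integers(0, k, size=n)
W = 0.1 * rng.normal(size=(k, d))

def pll_quantities(W):
    logits = H @ W.T
    m = logits[np.arange(n), y][:, None] - logits        # l_ic = (e_{y_i} - e_c)^T W h_i
    mask = np.ones_like(m, dtype=bool)
    mask[np.arange(n), y] = False                        # keep only c != y_i
    loss = np.log1p(np.exp(-m[mask])).sum() / n          # L_pll
    P = np.zeros_like(m)
    P[mask] = 1.0 / (1.0 + np.exp(m[mask]))              # |f'(l_ic)|
    G = P.sum() / n                                      # G_pll, eq. (34)
    coef = P.copy()                                      # gradient, eq. (32)
    coef[np.arange(n), y] = -P.sum(axis=1)
    grad = coef.T @ H / n
    return loss, G, grad

loss, G, grad = pll_quantities(W)
B = np.abs(H).sum(axis=1).max()                          # B >= ||h_i||_1 for all i
print(np.abs(grad).sum() <= 2 * B * G)                   # upper bound in Lemma F.1
print(1.0 >= G / loss >= 1.0 - n * loss / 2.0)           # Lemma F.2 (i)
```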
Lemma F.1 (Analogue of Lemma B.1 for PLL).

For any 𝐖{\bm{W}}, the PairLogLoss (PLL) satisfies:

2B𝒢pll(𝑾)pll(𝑾)sumγ𝒢pll(𝑾).2B\cdot{\mathcal{G}}_{\rm{pll}}({\bm{W}})\geq{\left\|\nabla\mathcal{L}_{\rm{pll}}({\bm{W}})\right\|_{\rm{sum}}}\geq\gamma\cdot{\mathcal{G}}_{\rm{pll}}({\bm{W}})\,.
Proof.

The lower bound follows immediately from (33) and expressing ${\left\|\nabla\mathcal{L}_{\rm{pll}}({\bm{W}})\right\|_{\rm{sum}}}=\max_{{\left\|{\bm{A}}\right\|_{\max}}\leq 1}\langle{\bm{A}},-\nabla\mathcal{L}_{\rm{pll}}({\bm{W}})\rangle$. The upper bound follows from the triangle inequality applied to (32):

{\left\|\nabla\mathcal{L}_{\rm{pll}}({\bm{W}})\right\|_{\rm{sum}}}\leq\frac{1}{n}\sum_{i\in[n]}\sum_{c\neq y_{i}}\left|f^{\prime}\left((\bm{e}_{y_{i}}-\bm{e}_{c})^{\top}{\bm{W}}\bm{h}_{i}\right)\right|\|\bm{e}_{y_{i}}-\bm{e}_{c}\|_{1}\|\bm{h}_{i}\|_{1}\leq 2B\cdot{\mathcal{G}}_{\rm{pll}}({\bm{W}})\,.

To bound the second-order term in the Taylor expansion of the PLL in terms of ${\mathcal{G}}_{\rm{pll}}({\bm{W}})$, note the following. First, for all $i\in[n],c\neq y_{i}$:

𝒉i𝚫(𝒆yi𝒆c)(𝒆yi𝒆c)𝚫𝒉i\displaystyle\bm{h}_{i}^{\top}\bm{\Delta}^{\top}(\bm{e}_{y_{i}}-\bm{e}_{c})(\bm{e}_{y_{i}}-\bm{e}_{c})^{\top}\bm{\Delta}\bm{h}_{i} =(𝒆yi𝒆c)(𝒆yi𝒆c),𝚫𝒉i𝒉i𝚫T(𝒆yi𝒆c)(𝒆yi𝒆c)sum𝚫𝒉i𝒉i𝚫Tmax\displaystyle=\langle(\bm{e}_{y_{i}}-\bm{e}_{c})(\bm{e}_{y_{i}}-\bm{e}_{c})^{\top},\bm{\Delta}\bm{h}_{i}\bm{h}_{i}^{\top}\bm{\Delta}^{T}\rangle\leq{\left\|(\bm{e}_{y_{i}}-\bm{e}_{c})(\bm{e}_{y_{i}}-\bm{e}_{c})^{\top}\right\|_{\rm{sum}}}{\left\|\bm{\Delta}\bm{h}_{i}\bm{h}_{i}^{\top}\bm{\Delta}^{T}\right\|_{\max}}
\displaystyle\leq\|\bm{e}_{y_{i}}-\bm{e}_{c}\|_{1}^{2}\cdot\|\bm{\Delta}\bm{h}_{i}\|_{\infty}^{2}
4(𝚫max)2𝒉i124B2(𝚫max)2.\displaystyle\leq 4\cdot\left({\left\|\bm{\Delta}\right\|_{\max}}\right)^{2}\cdot\|\bm{h}_{i}\|_{1}^{2}\leq 4B^{2}\left({\left\|\bm{\Delta}\right\|_{\max}}\right)^{2}\,.

Second, we use the (easy to check) property of the logistic loss that $f^{\prime\prime}(t)\leq|f^{\prime}(t)|$. Putting these together:

\frac{1}{n}\sum_{i\in[n]}\sum_{c\neq y_{i}}f^{\prime\prime}\left((\bm{e}_{y_{i}}-\bm{e}_{c})^{\top}{\bm{W}}\bm{h}_{i}\right)\cdot\bm{h}_{i}^{\top}\bm{\Delta}^{\top}(\bm{e}_{y_{i}}-\bm{e}_{c})(\bm{e}_{y_{i}}-\bm{e}_{c})^{\top}\bm{\Delta}\bm{h}_{i}\leq 4B^{2}\cdot{\mathcal{G}}_{\rm{pll}}({\bm{W}})\cdot\left({\left\|\bm{\Delta}\right\|_{\max}}\right)^{2}\,. ∎
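The two elementary facts used in this proof, the bound on the rank-one quadratic form and the property $f^{\prime\prime}(t)\leq|f^{\prime}(t)|$ of the logistic loss, can be spot-checked numerically with the short sketch below; the sizes are illustrative and $B$ is taken to be $\|\bm{h}\|_{1}$ for the sampled vector, consistent with how $B$ enters the displays above.

    import numpy as np

    rng = np.random.default_rng(1)

    # (a) logistic loss: f''(t) <= |f'(t)| for all t
    t = np.linspace(-20.0, 20.0, 4001)
    abs_fprime = np.exp(-t) / (1.0 + np.exp(-t))      # |f'(t)|
    f_second = np.exp(-t) / (1.0 + np.exp(-t)) ** 2   # f''(t)
    assert np.all(f_second <= abs_fprime + 1e-15)

    # (b) h^T D^T (e_y - e_c)(e_y - e_c)^T D h <= 4 B^2 ||D||_max^2 with B = ||h||_1
    k, d = 3, 4
    for _ in range(1000):
        D = rng.normal(size=(k, d))
        h = rng.normal(size=d)
        e = np.zeros(k)
        e[0], e[1] = 1.0, -1.0                        # e_{y_i} - e_c for y_i = 0, c = 1
        lhs = float(h @ D.T @ np.outer(e, e) @ D @ h)
        B = np.abs(h).sum()
        assert lhs <= 4.0 * B**2 * np.abs(D).max() ** 2 + 1e-9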

Next, we verify that the PLL satisfies the analogue of Lemma B.3.

Lemma F.2 (Analogue of Lemma B.3 for PLL).

Let ${\bm{W}}\in\mathbb{R}^{k\times d}$. Then:

(i) $1\geq\frac{{\mathcal{G}}_{\rm{pll}}({\bm{W}})}{\mathcal{L}_{\rm{pll}}({\bm{W}})}\geq 1-\frac{n\mathcal{L}_{\rm{pll}}({\bm{W}})}{2}$;

(ii) if ${\bm{W}}$ satisfies $\mathcal{L}_{\rm{pll}}({\bm{W}})\leq\frac{\log 2}{n}$ or ${\mathcal{G}}_{\rm{pll}}({\bm{W}})\leq\frac{1}{2n}$, then $\mathcal{L}_{\rm{pll}}({\bm{W}})\leq 2{\mathcal{G}}_{\rm{pll}}({\bm{W}})$.

Proof.

(i) The upper bound follows from the well-known self-boundedness property of the logistic loss, namely $|f^{\prime}(t)|\leq f(t)$.

To prove the lower bound, it suffices to show that for all $x>0$:

x1+xlog(1+x)12log2(1+x).\displaystyle\frac{x}{1+x}\geq\log(1+x)-\frac{1}{2}\log^{2}(1+x). (35)

The general case follows by summing over $x_{ic}=\exp\left(-(\bm{e}_{y_{i}}-\bm{e}_{c})^{\top}{\bm{W}}\bm{h}_{i}\right)$, $i\in[n]$, $c\neq y_{i}$, since then we have

n{\mathcal{G}}_{\rm{pll}}({\bm{W}})=\sum_{i\in[n]}\sum_{c\neq y_{i}}\frac{x_{ic}}{1+x_{ic}}\geq\sum_{i\in[n]}\sum_{c\neq y_{i}}\log(1+x_{ic})-\frac{1}{2}\sum_{i\in[n]}\sum_{c\neq y_{i}}\log^{2}(1+x_{ic})
\geq\sum_{i\in[n]}\sum_{c\neq y_{i}}\log(1+x_{ic})-\frac{1}{2}\left(\sum_{i\in[n]}\sum_{c\neq y_{i}}\log(1+x_{ic})\right)^{2}=n\mathcal{L}_{\rm{pll}}({\bm{W}})-\frac{1}{2}\left(n\mathcal{L}_{\rm{pll}}({\bm{W}})\right)^{2}\,,

where the last inequality used $\log(1+x_{ic})\geq 0$ and the final equality used $\sum_{i\in[n]}\sum_{c\neq y_{i}}\log(1+x_{ic})=n\mathcal{L}_{\rm{pll}}({\bm{W}})$. Dividing both sides by $n\mathcal{L}_{\rm{pll}}({\bm{W}})$ gives the lower bound in (i). For (35), let $a=\log(1+x)>0$. The inequality becomes $e^{-a}\leq 1-a+a^{2}/2$, which holds for $a>0$ by Taylor's theorem with Lagrange remainder, since $e^{-a}=1-a+a^{2}/2-e^{-\xi}a^{3}/6$ for some $\xi\in(0,a)$.

(ii) Denote $\mathcal{L}\coloneqq\mathcal{L}_{\rm{pll}}$ and ${\mathcal{G}}\coloneqq{\mathcal{G}}_{\rm{pll}}$. Given $\mathcal{L}\leq\frac{\log 2}{n}\leq\frac{1}{n}$, we have $1-\frac{n\mathcal{L}}{2}\geq\frac{1}{2}$, so the first part follows from (i). For the second part, denote $l_{ic}:=(\bm{e}_{y_{i}}-\bm{e}_{c})^{\top}{\bm{W}}\bm{h}_{i},i\in[n],c\neq y_{i}$. For $\mathcal{L}\leq 2{\mathcal{G}}$ to hold, it is sufficient to show that $\log(1+e^{-l_{ic}})\leq 2\frac{e^{-l_{ic}}}{1+e^{-l_{ic}}}$ for all $i\in[n],c\neq y_{i}$. This holds whenever $l_{ic}\geq-1.366$, which is satisfied here since the assumption ${\mathcal{G}}\leq\frac{1}{2n}$ implies $\frac{e^{-l_{ic}}}{1+e^{-l_{ic}}}\leq\frac{1}{2}$, i.e., $l_{ic}\geq 0$, for every $i\in[n],c\neq y_{i}$. ∎
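The scalar inequalities invoked in this proof, namely (35), its reformulation $e^{-a}\leq 1-a+a^{2}/2$, and the threshold condition used in part (ii), can be verified numerically over a grid; the grid ranges and the small slack in the second assertion are implementation choices, not part of the argument.

    import numpy as np

    # inequality (35): x/(1+x) >= log(1+x) - 0.5*log(1+x)^2 for x > 0
    x = np.linspace(1e-6, 100.0, 400001)
    assert np.all(x / (1 + x) >= np.log1p(x) - 0.5 * np.log1p(x) ** 2)

    # equivalent form used at the end of part (i): e^{-a} <= 1 - a + a^2/2 for a > 0
    # (the small slack absorbs floating-point cancellation for tiny a)
    a = np.linspace(1e-6, 50.0, 400001)
    assert np.all(np.exp(-a) <= 1.0 - a + 0.5 * a**2 + 1e-12)

    # part (ii): log(1 + e^{-l}) <= 2 e^{-l} / (1 + e^{-l}) whenever l >= -1.366
    l = np.linspace(-1.366, 30.0, 400001)
    assert np.all(np.log1p(np.exp(-l)) <= 2.0 * np.exp(-l) / (1.0 + np.exp(-l)))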

Lemma F.3 (Analogue of Lemma B.5 for PLL).

For any ${\bm{W}},\bm{\Delta}{\bm{W}}\in\mathbb{R}^{k\times d}$ and $\psi\in[0,1]$, we have the following:

\frac{{\mathcal{G}}_{\rm{pll}}({\bm{W}}+\psi\bm{\Delta}{\bm{W}})}{{\mathcal{G}}_{\rm{pll}}({\bm{W}})}\leq e^{2B\psi{\left\|\bm{\Delta}{\bm{W}}\right\|_{\max}}}+2\,.
Proof.

Since the logistic loss $f(z)=\log(1+e^{-z})$ has derivative $f^{\prime}(z)=-\frac{1}{1+e^{z}}$, for any $z_{1},z_{2}\in\mathbb{R}$ we have

|f(z1)f(z2)|=|1+ez21+ez1|\displaystyle\bigm{|}\frac{f^{\prime}(z_{1})}{f^{\prime}(z_{2})}\bigm{|}=\bigm{|}\frac{1+e^{z_{2}}}{1+e^{z_{1}}}\bigm{|} =|1+ez2ez1+ez11+ez1|\displaystyle=\bigm{|}\frac{1+e^{z_{2}}-e^{z_{1}}+e^{z_{1}}}{1+e^{z_{1}}}\bigm{|}
=|ez2ez11+ez1+1||ez2ez11+ez1|+1\displaystyle=\bigm{|}\frac{e^{z_{2}}-e^{z_{1}}}{1+e^{z_{1}}}+1\bigm{|}\leq\bigm{|}\frac{e^{z_{2}}-e^{z_{1}}}{1+e^{z_{1}}}\bigm{|}+1
|ez2z11|+1\displaystyle\leq\bigm{|}e^{z_{2}-z_{1}}-1\bigm{|}+1
e|z2z1|+2.\displaystyle\leq e^{|z_{2}-z_{1}|}+2.

Denote $x_{ic}^{{\bm{W}}}:=(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}{\bm{W}}\bm{h}_{i}$ and $x_{ic}^{{\bm{W}}^{\prime}}:=(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}({\bm{W}}+\psi\bm{\Delta}{\bm{W}})\bm{h}_{i}$. Then, for every $i\in[n]$ and $c\neq y_{i}$,

f(xic𝑾)f(xic𝑾)=|f(xic𝑾)f(xic𝑾)|e|xic𝑾xic𝑾|+2\displaystyle\frac{f^{\prime}(x_{ic}^{{\bm{W}}^{\prime}})}{f^{\prime}(x_{ic}^{{\bm{W}}})}=|\frac{f^{\prime}(x_{ic}^{{\bm{W}}^{\prime}})}{f^{\prime}(x_{ic}^{{\bm{W}}})}|\leq e^{|x_{ic}^{{\bm{W}}}-x_{ic}^{{\bm{W}}^{\prime}}|}+2 =eψ|(𝒆c𝒆yi)T𝚫𝑾𝒉i|+2=eψ|𝚫𝑾,(𝒆c𝒆yi)𝒉iT|+2\displaystyle=e^{\psi|(\bm{e}_{c}-\bm{e}_{y_{i}})^{T}\bm{\Delta}{\bm{W}}\bm{h}_{i}|}+2=e^{\psi|\langle\bm{\Delta}{\bm{W}},(\bm{e}_{c}-\bm{e}_{y_{i}})\bm{h}_{i}^{T}\rangle|}+2
eψ𝚫𝑾max(𝒆c𝒆yi)𝒉iTsum+2\displaystyle\leq e^{\psi\lVert\bm{\Delta}{\bm{W}}\rVert_{\max}{\left\|(\bm{e}_{c}-\bm{e}_{y_{i}})\bm{h}_{i}^{T}\right\|_{\rm{sum}}}}+2
\displaystyle=e^{\psi\lVert\bm{\Delta}{\bm{W}}\rVert_{\max}\|\bm{e}_{c}-\bm{e}_{y_{i}}\|_{1}\|\bm{h}_{i}\|_{1}}+2
e2Bψ𝚫𝑾max+2.\displaystyle\leq e^{2B\psi{\left\|\bm{\Delta}{\bm{W}}\right\|_{\max}}}+2.

Since $f^{\prime}(x_{ic}^{{\bm{W}}})<0$ for all $i\in[n],c\neq y_{i}$, this gives $\sum_{i\in[n]}\sum_{c\neq y_{i}}|f^{\prime}(x_{ic}^{{\bm{W}}^{\prime}})|\leq(e^{2B\psi{\left\|\bm{\Delta}{\bm{W}}\right\|_{\max}}}+2)\sum_{i\in[n]}\sum_{c\neq y_{i}}|f^{\prime}(x_{ic}^{{\bm{W}}})|$. Dividing both sides by $n$ and using the definition (34) of ${\mathcal{G}}_{\rm{pll}}$, we obtain the desired bound. ∎
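The key scalar estimate of this proof, $|f^{\prime}(z_{1})/f^{\prime}(z_{2})|\leq e^{|z_{1}-z_{2}|}+2$, can be checked on random pairs $(z_{1},z_{2})$ as in the sketch below; the sampling range is an arbitrary illustrative choice.

    import numpy as np

    rng = np.random.default_rng(2)

    def fprime(z):
        # f'(z) = -e^{-z}/(1 + e^{-z}) = -1/(1 + e^{z})
        return -1.0 / (1.0 + np.exp(z))

    z1 = rng.uniform(-8.0, 8.0, size=200000)
    z2 = rng.uniform(-8.0, 8.0, size=200000)
    ratio = np.abs(fprime(z1) / fprime(z2))
    assert np.all(ratio <= np.exp(np.abs(z1 - z2)) + 2.0)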

Lemma F.4 (Analogue of Lemma B.4 for PLL).

Suppose that ${\bm{W}}\in\mathbb{R}^{k\times d}$ satisfies $\mathcal{L}_{\rm{pll}}({\bm{W}})\leq\frac{\log 2}{n}$. Then

(𝒆yi𝒆c)T𝑾𝒉i0,for all i[n] and for all c[k] such that cyi.\displaystyle(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}{\bm{W}}\bm{h}_{i}\geq 0,\quad\text{for all $i\in[n]$ and for all $c\in[k]$ such that $c\neq y_{i}$}. (36)
Proof.

Denote xic=(𝒆yi𝒆c)T𝑾𝒉ix_{ic}=(\bm{e}_{y_{i}}-\bm{e}_{c})^{T}{\bm{W}}\bm{h}_{i}. Then, by the assumption, we have for any i[n],cyii\in[n],c\neq y_{i}

\log(1+e^{-x_{ic}})\leq\sum_{i\in[n]}\sum_{c\neq y_{i}}\log(1+e^{-x_{ic}})=n\mathcal{L}_{\rm{pll}}({\bm{W}})\leq\log 2\,.

This implies $e^{-x_{ic}}\leq 1$, i.e., $x_{ic}\geq 0$, for all $i\in[n],c\neq y_{i}$. ∎
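The contrapositive of the final step, that any strictly negative margin already incurs a per-pair loss above $\log 2$, can be confirmed on a grid of negative values; the endpoints below are illustrative.

    import numpy as np

    # any strictly negative pairwise margin x_ic alone contributes a per-pair loss
    # exceeding log 2, which would contradict L_pll(W) <= (log 2)/n
    x = np.linspace(-20.0, -1e-4, 200001)
    assert np.all(np.log1p(np.exp(-x)) > np.log(2.0))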

Lemma F.5 (Analogue of Lemma B.2 for PLL).

For any ${\bm{W}},{\bm{W}}_{0}\in\mathbb{R}^{k\times d}$, using that $\mathcal{L}_{\rm{pll}}$ is convex, we have

|pll(𝑾)pll(𝑾0)|2B𝑾𝑾0max.\displaystyle|\mathcal{L}_{pll}({\bm{W}})-\mathcal{L}_{pll}({\bm{W}}_{0})|\leq 2B{\left\|{\bm{W}}-{\bm{W}}_{0}\right\|_{\max}}.
Proof.

This lemma is a direct consequence of Lemma F.1 and can be proved in the same way as Lemma B.2. ∎

Thus, we have established all the lemmas relating ${\mathcal{G}}_{\rm{pll}}({\bm{W}})$ to $\mathcal{L}_{\rm{pll}}({\bm{W}})$, analogous to those in Section B. The proofs for SignGD (4) and NSD (5) with the PairLogLoss then follow the same steps as for the cross-entropy loss given in Section C.