
FedSC: Provable Federated Self-supervised Learning
with Spectral Contrastive Objective over Non-i.i.d. Data

Shusen Jing    Anlan Yu    Shuai Zhang    Songyang Zhang
Abstract

Recent efforts have been made to integrate self-supervised learning (SSL) with the framework of federated learning (FL). One unique challenge of federated self-supervised learning (FedSSL) is that the global objective of FedSSL usually does not equal the weighted sum of local SSL objectives. Consequently, conventional approaches, such as federated averaging (FedAvg), fail to precisely minimize the FedSSL global objective, often resulting in suboptimal performance, especially when data is non-i.i.d. To fill this gap, we propose a provable FedSSL algorithm, named FedSC, based on the spectral contrastive objective. In FedSC, clients periodically share correlation matrices of data representations in addition to model weights, which enables inter-client contrast of data samples in addition to intra-client contrast and contraction, resulting in improved quality of data representations. Differential privacy (DP) protection is deployed to control the additional privacy leakage on local datasets when correlation matrices are shared. We also provide theoretical analysis of the convergence and the extra privacy leakage. Experimental results validate the effectiveness of our proposed algorithm.


1 Introduction

As a type of unsupervised learning, self-supervised learning (SSL) aims to learn, from unlabeled data, a structured representation space in which data similarity can be measured by simple metrics such as cosine or Euclidean distance (Chen et al., 2020; Chen & He, 2021; Grill et al., 2020; He et al., 2020; Zbontar et al., 2021; Bardes et al., 2021; HaoChen et al., 2021). On top of a foundation model trained with SSL, a simple linear layer, also known as a linear probe, is sufficient to perform well on a wide range of downstream tasks with minimal labeled data. Owing to its high label efficiency, SSL has been adopted in a variety of applications, such as natural language processing (He et al., 2021; Brown et al., 2020) and computer vision (Ravi & Larochelle, 2016; Hu et al., 2021).

However, SSL algorithms are often executed on massive amounts of unlabeled data that may be dispersed across various locations. Moreover, the progressively tightening privacy-protection regulations frequently inhibit the centralization of data. Within this context, the federated learning (FL) framework is often favored, wherein a central server can learn from private data located on clients without the data being shared directly (McMahan et al., 2017; Stich, 2018; Li et al., 2019).

Despite the extensive study and theoretical guarantees (Stich, 2018; Li et al., 2019) associated with conventional FL, its generalization to incorporate SSL is not straightforward. The fundamental challenge arises from the fact that, unlike FL with supervised learning, the global objective of FedSSL usually does not equal the weighted sum of local SSL objectives. Consequently, conventional FL approaches, e.g., federated averaging (FedAvg), cannot minimize the exact global objective of FedSSL, especially when data is non-independent and identically distributed (non-i.i.d.). From the perspective of contrastive learning, FedAvg only contrasts data samples within the same client (intra-client) rather than those across different clients (inter-client). Therefore, the learned representation might not be as effective at distinguishing inter-client data samples as it is with intra-client data samples.

Although recent works on FedSSL have shown great numerical success (Zhuang et al., 2021, 2022; Zhang et al., 2023; Han et al., 2022), the majority of them either overlook the previously mentioned challenge or fail to offer a theoretical analysis. FedU (Zhuang et al., 2021) and FedEMA (Zhuang et al., 2022) lack a formulation of the global objective and thus fail to provide theoretical analysis. FedCA (Zhang et al., 2023) notices the unique challenge and proposes to share data representations, which, however, results in significant privacy leakage and communication overhead. Unlike FedU and FedEMA, which involve sharing predictors, and FedCA, which shares data representations, our proposed FedSC incurs much lower communication costs, since sharing correlation matrices requires transmitting far fewer parameters than predictors or data representations. FedX (Han et al., 2022) does not share additional information besides encoders, but still lacks theoretical analysis. Among all these works, only our proposed FedSC deploys differential privacy (DP) protection to mitigate the extra privacy leakage from components other than encoders. Moreover, FedSC is the only provable FedSSL method to the best knowledge of the authors. Table 1 summarizes the differences between this work and the state of the art (SOTAs).

Table 1: A comparison with SOTAs: FedSC (proposed) is the only one applying DP mechanism on components other than encoder. Moreover, FedSC is the only provable method among them.
Method    Info. shared besides encoder    Privacy Protection    Provable
FedU      predictor                       ×                     ×
FedEMA    predictor                       ×                     ×
FedX      N/A                             ×                     ×
FedCA     representations                 ×                     ×
FedSC     correlation matrices            ✓                     ✓

Contribution. In this work, we propose a novel FedSSL formulation based on the spectral contrastive (SC) objective (HaoChen et al., 2021). The formulation clarifies all the necessary components in FedSSL, encompassing intra-client contraction, intra-client contrast, and inter-client contrast. Building upon this formulation, we propose the first provable FedSSL method, namely FedSC, with a convergence guarantee toward the solutions of centralized SSL. Unlike FedAvg, clients in FedSC share correlation matrices of their local data representations in addition to the weights of local models. By leveraging the aggregated correlation matrix from the server, inter-client contrast of data samples, which is overlooked in FedAvg, can be performed in addition to local contrast and contraction. To better control and quantify the extra privacy leakage, we apply a DP mechanism to correlation matrices when they are shared. We provide a theoretical analysis of FedSC, demonstrating the convergence of the global objective and the efficacy of our method. Our contributions are summarized as follows:

• We propose a novel FedSSL formulation delineating all essential components of FedSSL, which encompasses intra-client contraction, intra-client contrast and inter-client contrast. This highlights the limitations of FedAvg due to its neglect of the inter-client contrast.

• We propose FedSC, in which clients are able to perform inter-client contrast of data samples by leveraging the correlation matrices of data representations shared by others, resulting in improved quality of data representations.

• DP protection is applied, which effectively constrains the privacy leakage resulting from sharing correlation matrices, with only negligible utility degradation.

• Theoretical analysis of FedSC is provided, quantifying the extra privacy leakage and giving a convergence guarantee for the global FedSSL objective. We prove that FedSC achieves an $\mathcal{O}(1/\sqrt{T})$ convergence rate, while FedAvg has a constant error floor.

• Through extensive experimentation involving 3 datasets and 4 SOTAs, we affirm that FedSC achieves superior or comparable performance compared with other methods.

2 Related Works

Self-supervised learning. SSL can be mainly categorized into contrastive and non-contrastive SSL. The mechanisms and explicit objectives of non-contrastive SSL algorithms are still not fully understood despite a few recent attempts (Halvagal et al., 2023; Tian et al., 2021; Zhang et al., 2022). In contrast, contrastive SSL is more intuitive and explainable. Contrastive SSL explicitly penalizes the distance between positive pairs (two samples sharing the same semantic meaning), while encouraging distance between negative pairs (two samples with different semantic meanings). For example, the SimCLR objective (Chen et al., 2020) accounts for the mutual information between positive pairs (Tschannen et al., 2019) preserved by representations. The SC objective (HaoChen et al., 2021) is equivalent to performing a spectral decomposition of the augmentation graph.

Federated Self-supervised Learning. In FedU (Zhuang et al., 2021), upon receiving the global model from the server, each client decides whether its local model should be replaced by the global one based on the distance between the two sets of model weights. As a follow-up, FedEMA (Zhuang et al., 2022) is proposed, in which the hard decision in FedU is replaced with a weighted combination of local and global models. FedX (Han et al., 2022) designs local and global objectives using the idea of cross knowledge distillation to mitigate the effects of non-i.i.d. data. The authors of FedCA (Zhang et al., 2023) propose to share features of individual data samples in addition to local model weights for inter-client contrast, which, however, results in significant privacy leakage and communication overhead.

Differential Privacy. Gaussian and Laplace mechanisms are the most common DP approaches for protecting a dataset from membership inference attacks (Dwork, 2006). To better analyze DP, (Mironov, 2017) proposes Rényi differential privacy (RDP), which characterizes operations on mechanisms, such as composition, in a more elegant way, and establishes the conversion between RDP and DP. Currently, DP has been widely deployed in FL (Wei et al., 2020; Truex et al., 2020; Hu et al., 2020; Geyer et al., 2017; Noble et al., 2022).

Figure 1: Diagram of the proposed FedSC. 1) The server synchronizes local models with the global model. 2) Clients compute their local correlation matrices of dataset and send them to the server. 3) The server distributes the aggregated global correlation matrices back to the clients. 4) The clients proceed to update their local models in accordance with the local objective specified in Eq. (5). 5) The server aggregates the local models and initiates the next iteration.

3 Preliminaries: Spectral Contrastive (SC) Self-supervised Learning

Spectral contrastive (SC) SSL is proposed in (HaoChen et al., 2021) with the following objective:

\mathcal{L}^{SC}(\theta;\mathcal{D})\triangleq\frac{1}{2}\mathsf{E}_{x,x^{-}\sim\mathcal{A}(\cdot|\mathcal{D})}\left[\left(z(x;\theta)^{T}z(x^{-};\theta)\right)^{2}\right]-\mathsf{E}_{\bar{x}\sim\mathcal{D}}\mathsf{E}_{x,x^{+}\sim\mathcal{A}(\cdot|\bar{x})}\left[z(x;\theta)^{T}z(x^{+};\theta)\right],

where $\mathcal{D}$ is the dataset; $z(\cdot;\theta):\mathbb{R}^{d}\rightarrow\mathbb{R}^{H}$ is the representation mapping parameterized by $\theta$; $\mathsf{E}$ denotes the expectation operator; $\mathcal{A}(\cdot|\bar{x})$ is the augmentation kernel, which is essentially a conditional distribution, and $\mathcal{A}(\cdot|\mathcal{D})\triangleq\mathsf{E}_{\bar{x}\sim\mathcal{D}}\mathcal{A}(\cdot|\bar{x})$. We use $(x,x^{-})$ to denote negative pairs, where samples $x$ and $x^{-}$ have different semantic meanings, and $(x,x^{+})$ to denote positive pairs, where $x$ and $x^{+}$ have the same semantic meaning. Intuitively, minimizing the SC objective encourages orthogonality of the representations of a negative pair, and simultaneously promotes linear alignment of the representations of a positive pair. It has been proved that solving this optimization problem is equivalent to performing a spectral decomposition of a well-defined augmentation graph, whose nodes are augmented images, i.e., samples from $\mathcal{A}(\cdot|\mathcal{D})$, and whose edges describe the semantic similarity of two images determined by the kernel $\mathcal{A}(\cdot|\cdot)$, which results in high-quality and explainable data (node) representations (HaoChen et al., 2021).

After rearranging the original SC objective, we first propose the following equivalent form not reported in (HaoChen et al., 2021).

\mathcal{L}^{SC}(\theta;\mathcal{D})=-\mathsf{E}_{\bar{x}\sim\mathcal{D}}Tr\{R^{+}(\bar{x};\theta)\}+\frac{1}{2}\left\lVert\mathsf{E}_{\bar{x}\sim\mathcal{D}}R(\bar{x};\theta)\right\rVert_{F}^{2},    (1)

where $R^{+}(\bar{x};\theta)\in\mathbb{R}^{H\times H}$ and $R(\bar{x};\theta)\in\mathbb{R}^{H\times H}$ are defined respectively as

R^{+}(\bar{x};\theta)\triangleq\mathsf{E}_{x,x^{+}\sim\mathcal{A}(\cdot|\bar{x})}\left[z(x;\theta)z(x^{+};\theta)^{T}\right],\qquad R(\bar{x};\theta)\triangleq\mathsf{E}_{x\sim\mathcal{A}(\cdot|\bar{x})}\left[z(x;\theta)z(x;\theta)^{T}\right].

The detailed derivations are given in Appendix A.
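To make the correlation-matrix form concrete, the following is a minimal PyTorch-style sketch (not from the paper's code release) of an empirical estimator of Eq. (1), assuming two augmented views per image; note that the plug-in estimate of $\left\lVert\mathsf{E}_{\bar{x}\sim\mathcal{D}}R(\bar{x};\theta)\right\rVert_{F}^{2}$ from a finite batch is biased, which is exactly the source of the $1/V$ and batch-size terms analyzed in Sec. 6.

```python
import torch

def spectral_contrastive_loss(z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
    """z1, z2: (B, H) representations of two augmented views of the same B images."""
    B = z1.shape[0]
    # First term of Eq. (1): -E[Tr{R^+}], estimated from positive pairs.
    pos = -(z1 * z2).sum(dim=1).mean()
    # Second term: (1/2)||E[R]||_F^2, with R estimated from the pooled 2B views.
    z = torch.cat([z1, z2], dim=0)        # (2B, H)
    R = z.T @ z / (2 * B)                 # (H, H) empirical correlation matrix
    neg = 0.5 * (R ** 2).sum()
    return pos + neg
```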

4 Problem Formulation

In an FL system consisting of a server and $J$ clients, the $j$-th client owns a private local dataset $\mathcal{D}_{j}$ disjoint from the others. The goal of FedSSL is to optimize the SSL model (the SC model in this work) over the union of all local datasets, i.e.,

\min_{\theta}\mathcal{L}^{SC}(\theta;\mathcal{D}),    (2)

where $\mathcal{D}=\cup_{j=1}^{J}\mathcal{D}_{j}$. Like the majority of SSL objectives, the global SC objective typically does not equal the weighted sum of local SC objectives, especially with non-i.i.d. data distributions. For the purpose of rigor, we state this as an assumption rather than a claim in this work:

\mathcal{L}^{SC}(\theta;\mathcal{D})\neq\sum_{j=1}^{J}q_{j}\mathcal{L}^{SC}(\theta;\mathcal{D}_{j}),    (3)

where $\{q_{j}\}$ are weights depending on the amount of local data. As a result, FedAvg is not guaranteed to minimize the global objective $\mathcal{L}^{SC}(\theta;\mathcal{D})$ when data is non-i.i.d.

In addition, we adopt the SC framework for the following reasons. First, SC has solid theoretical derivations and simultaneously achieves performance comparable to SOTA SSL methods (HaoChen et al., 2021). Second, the SC objective suggests that correlation matrices of data representations are sufficient for contrasting negative pairs. Sharing correlation matrices only incurs a negligible constant extra communication overhead and a quantifiable privacy leakage.

5 FedSC: A Provable FedSSL Method

For notational simplicity, we denote by $R^{+}_{j}(\theta)=\mathsf{E}_{\bar{x}\sim\mathcal{D}_{j}}R^{+}(\bar{x};\theta)$ and $R_{j}(\theta)=\mathsf{E}_{\bar{x}\sim\mathcal{D}_{j}}R(\bar{x};\theta)$ the positive correlation matrix and the correlation matrix of client $j$, respectively. We start by manipulating the global objective

\mathcal{L}^{SC}(\theta;\mathcal{D})=-\sum_{j=1}^{J}q_{j}Tr\{R^{+}_{j}(\theta)\}+\frac{1}{2}\left\lVert\sum_{j=1}^{J}q_{j}R_{j}(\theta)\right\rVert_{F}^{2}=\sum_{j=1}^{J}q_{j}\biggl(\underbrace{-Tr\{R^{+}_{j}(\theta)\}}_{\text{intra-client contraction}}+\underbrace{\frac{1}{2}q_{j}\left\lVert R_{j}(\theta)\right\rVert_{F}^{2}}_{\text{intra-client contrast}}+\underbrace{\frac{1}{2}(1-q_{j})Tr\{R_{j}(\theta)R_{-j}(\theta)\}}_{\text{inter-client contrast}}\biggr),    (4)

where $R_{-j}(\theta)\triangleq\frac{1}{1-q_{j}}\sum_{j^{\prime}\neq j}q_{j^{\prime}}R_{j^{\prime}}(\theta)$. From Eq. (4), we notice that $\mathcal{L}^{SC}(\theta;\mathcal{D})$ can be decomposed into a weighted sum of $J$ terms corresponding to the $J$ clients, where each term consists of three sub-terms accounting for intra-client contraction (of positive pairs), intra-client contrast (of negative pairs), and inter-client contrast (of negative pairs), respectively. Inspired by this decomposition, we construct the following local objective:

\mathcal{L}^{SC}_{j}(\theta;\bar{R}_{-j})=-Tr\{R^{+}_{j}(\theta)\}+\frac{1}{2}q_{j}\left\lVert R_{j}(\theta)\right\rVert_{F}^{2}+(1-q_{j})Tr\{R_{j}(\theta)\bar{R}_{-j}\},    (5)

where $\bar{R}_{-j}\in\mathbb{R}^{H\times H}$ is an estimate of $R_{-j}(\theta)$, whose updates rely on communication with the server. Since $\bar{R}_{-j}$ is treated as a constant (stop gradient) in the local objectives, we intentionally remove the coefficient $1/2$ before the third term to align the gradients of the local and global objectives. That is to say, when $\bar{R}_{-j}=R_{-j}(\theta)$, we have

\nabla\mathcal{L}^{SC}(\theta;\mathcal{D})=\sum_{j=1}^{J}q_{j}\nabla\mathcal{L}^{SC}_{j}(\theta,\bar{R}_{-j}).    (6)

Note that directly applying FedAvg results in a misalignment of gradients, which is inherited from the fact that the global objective of FedSSL does not equal the weighted sum of local objectives, as suggested in Eq. (3).
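For concreteness, a minimal sketch of the local objective in Eq. (5) is given below; the tensor names are ours, and the stop-gradient on $\bar{R}_{-j}$ is made explicit via `.detach()`.

```python
import torch

def local_sc_loss(R_plus_j: torch.Tensor, R_j: torch.Tensor,
                  R_bar_neg_j: torch.Tensor, q_j: float) -> torch.Tensor:
    """Hedged sketch of Eq. (5); all inputs are (H, H) except the scalar q_j."""
    intra_contraction = -torch.trace(R_plus_j)                 # -Tr{R_j^+}
    intra_contrast = 0.5 * q_j * (R_j ** 2).sum()              # (q_j / 2) * ||R_j||_F^2
    # R_bar_neg_j is the server-provided estimate of R_{-j}, treated as a constant.
    inter_contrast = (1 - q_j) * torch.trace(R_j @ R_bar_neg_j.detach())
    return intra_contraction + intra_contrast + inter_contrast
```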

The process of FedSC is similar to FedAvg, except that local correlation matrices $\tilde{R}_{j}^{t}$ are shared and aggregated besides model weights. To begin with, the server synchronizes local models with the global model. Subsequently, clients compute their local correlation matrices and send them to the server. Following this, the server distributes the aggregated global correlation matrix back to the clients. The clients then proceed to update their local models in accordance with the local objective specified in Eq. (5). Finally, the server aggregates the local models and initiates the next iteration. The process is summarized in Fig. 1.

The detailed algorithm of FedSC is shown in Algorithm 1. Here, clients use Algorithm 2 (DP-CalR) to calculate the local correlation matrices to be shared, $\tilde{R}_{j}^{t}$, with differential privacy (DP) protection, which is detailed in Sec. 5.1. During local training, clients minimize $\mathcal{L}^{SC}_{j}$ through stochastic gradient descent (SGD), which is detailed in Sec. 5.2. It can be noticed that both the clients and the server maintain knowledge of the global correlation matrix $\tilde{R}^{t}$.

Intuitively, since the averaged local gradients align with the global gradient as shown in Eq. (6), the drift and the variance of local gradients contribute $\mathcal{O}(1/T)$ and $\mathcal{O}(1/\sqrt{T})$ terms to the convergence rate, respectively, which has been extensively studied in previous works on FedAvg. The difference is that the shared correlation matrix $\tilde{R}_{j}^{t}$ introduces additional perturbation due to its aging (compared with the instant correlation matrix $R_{-j}(\theta)$) and DP noise. The perturbation caused by aging is proportional to the movement of the weights, which is proportional to the squared learning rate $\eta^{2}$, thus contributing an additional $\mathcal{O}(1/T)$ factor to the convergence rate. This is what motivates the design of FedSC.

Algorithm 1 FedSC
1:  Initialization: $\theta^{0}$ and a set of clients $[J]$
2:  for $t=1...T$ do
3:     Server samples a subset of clients $\mathcal{J}^{t}\subset[J]$
4:     if $t=1$ then
5:        for $j\in[J]$ do
6:           Server sends $\theta^{t-1}$ to client $j$.
7:           Client $j$ uploads $\tilde{R}^{t}_{j}=\texttt{DP-CalR}(\theta^{t-1},\mathcal{D}_{j})$.
8:        end for
9:        Server sends $\tilde{R}^{t}=\sum_{j=1}^{J}q_{j}\tilde{R}^{t}_{j}$ to all clients.
10:     else
11:        for $j\in\mathcal{J}^{t}$ do
12:           Server sends $\theta^{t-1}$ to client $j$.
13:           Client $j$ uploads $\tilde{R}^{t}_{j}=\texttt{DP-CalR}(\theta^{t-1},\mathcal{D}_{j})$.
14:           Server updates $\tilde{R}^{t}=\tilde{R}^{t-1}-q_{j}\tilde{R}_{j}^{t-1}+q_{j}\tilde{R}_{j}^{t}$.
15:        end for
16:        Server sends $\tilde{R}^{t}$ to all clients.
17:        Server sets $\tilde{R}_{j}^{t}=\tilde{R}_{j}^{t-1}$ for $j\notin\mathcal{J}^{t}$.
18:     end if
19:     for $j\in\mathcal{J}^{t}$ do
20:        Client $j$ calculates $\tilde{R}^{t}_{-j}=\frac{1}{1-q_{j}}(\tilde{R}^{t}-q_{j}\tilde{R}^{t}_{j})$.
21:        Client $j$ trains the local model with the procedure in Sec. 5.2, and returns updated weights $\theta_{j}^{t-1}$.
22:     end for
23:     Server aggregation: $\theta^{t}=\frac{1}{|\mathcal{J}^{t}|}\sum_{j\in\mathcal{J}^{t}}\theta_{j}^{t-1}$.
24:  end for
25:  return: $\theta^{T}$
Algorithm 2 DP-CalR
1:  Inputs: $\theta$ and local dataset $\mathcal{D}_{j}$
2:  $\tilde{R}^{t}_{j}=0$
3:  for $\bar{x}\in\mathcal{D}_{j}$ do
4:     Sample $x_{1},x_{2},...,x_{V}\sim\mathcal{A}(\cdot|\bar{x})$
5:     Calculate $\check{z}_{v}=\texttt{NormClip}(z(x_{v};\theta),\sqrt{\mu})$, for $v=1,2,...,V$.
6:     $\tilde{R}^{t}_{j}=\tilde{R}^{t}_{j}+\frac{1}{|\mathcal{D}_{j}|V}\sum_{v=1}^{V}\check{z}_{v}\check{z}_{v}^{T}$
7:  end for
8:  $\tilde{R}^{t}_{j}=\tilde{R}^{t}_{j}+\mathcal{N}(0,\sigma^{2}\mathbf{I})$.
9:  return: $\tilde{R}^{t}_{j}$

5.1 Correlation Matrices Sharing

DP protection is applied when correlation matrices are shared in order to mitigate the additional privacy leakage on local datasets. A typical Gaussian mechanism is adopted, with parameters $\mu$ and $\sigma^{2}$ controlling the sensitivity and the noise scale, respectively. The process is summarized in Algorithm 2.
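A minimal Python sketch of Algorithm 2 is given below; `encoder`, `augment` and the dataset iterable are placeholders of ours, not names from the paper. Clipping each representation to norm $\sqrt{\mu}$ bounds the per-sample sensitivity by $\mu/|\mathcal{D}_{j}|$ before Gaussian noise is added.

```python
import torch

def norm_clip(z: torch.Tensor, c: float) -> torch.Tensor:
    """Clip each row of z to l2 norm at most c."""
    norms = z.norm(dim=1, keepdim=True).clamp(min=1e-12)
    return z * torch.clamp(c / norms, max=1.0)

def dp_cal_r(encoder, dataset, augment, V: int, mu: float, sigma: float) -> torch.Tensor:
    """Hedged sketch of DP-CalR: clipped correlation matrix plus Gaussian noise."""
    n = len(dataset)
    R = None
    for x_bar in dataset:
        views = torch.stack([augment(x_bar) for _ in range(V)])  # (V, ...)
        z = norm_clip(encoder(views), mu ** 0.5)                 # (V, H), norm <= sqrt(mu)
        contrib = z.T @ z / (n * V)                              # this sample's share of R_j
        R = contrib if R is None else R + contrib
    return R + sigma * torch.randn_like(R)                       # Gaussian mechanism
```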

5.2 Local Training

The local training process follows mini-batch stochastic gradient descent (SGD). At each iteration, consider a batch of $B$ samples $\bar{X}\sim\mathcal{D}_{j}$. Let $X_{1},X_{2},...,X_{2V}\sim\mathcal{A}(\bar{X})$ be $2V$ views augmented from $\bar{X}$. The empirical correlation matrices are calculated as follows:

\hat{R}_{j}^{+}(\{X_{v}\}_{v};\theta_{j})=\frac{1}{2BV}\sum_{v=1}^{V}\left[Z(X_{v};\theta_{j})Z(X_{v+V};\theta_{j})^{T}+Z(X_{v+V};\theta_{j})Z(X_{v};\theta_{j})^{T}\right];\qquad\hat{R}_{j}(\{X_{v}\}_{v};\theta_{j})=\frac{1}{2BV}\sum_{v=1}^{2V}Z(X_{v};\theta_{j})Z(X_{v};\theta_{j})^{T}.

The batch loss $\hat{\mathcal{L}}^{SC}_{j}(\theta_{j};\{X_{v}\}_{v},\bar{R}_{-j})$ is obtained by substituting $R^{+}_{j}(\theta)$ and $R_{j}(\theta)$ in Eq. (5) with $\hat{R}_{j}^{+}(\{X_{v}\}_{v},\theta_{j})$ and $\hat{R}_{j}(\{X_{v}\}_{v},\theta_{j})$, respectively. Local training then proceeds by back-propagating the batch loss and updating the model weights iteratively.
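A minimal sketch of the empirical correlation matrices above is shown next; `Z_views` is assumed to be a list of $2V$ tensors of shape $(B, H)$, where views $v$ and $v+V$ are positive pairs. The outputs can be plugged into the local objective of Eq. (5), e.g., the `local_sc_loss` sketch given earlier.

```python
import torch

def batch_correlation_matrices(Z_views):
    """Compute empirical R_j^+ and R_j from 2V augmented views of a batch."""
    V = len(Z_views) // 2
    B = Z_views[0].shape[0]
    # Symmetrized positive correlation matrix over the V positive pairs.
    R_plus = sum(Z_views[v].T @ Z_views[v + V] + Z_views[v + V].T @ Z_views[v]
                 for v in range(V)) / (2 * B * V)
    # Correlation matrix pooled over all 2V views.
    R = sum(Z.T @ Z for Z in Z_views) / (2 * B * V)
    return R_plus, R
```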

5.3 Comparison with existing FedSSL frameworks

In this subsection, we discuss the privacy leakage and communication overhead of FedSC in comparison with other FedSSL frameworks.

Sharing correlation matrices only results in negligible communication overhead: Although FedSC additionally shares correlation matrices, it still incurs less communication overhead than SOTA non-contrastive FedSSL frameworks (Zhuang et al., 2021, 2022; He et al., 2020), due to the predictors employed in non-contrastive SSL methods. For example, the feature dimension is $H=512$ in our experiments, so the correlation matrices yield $H\times H\approx 260{,}000$ additional parameters to be communicated. In contrast, the structure of the predictor is often a three-layer multilayer perceptron (MLP), which contains a multiple of the parameters of a correlation matrix; in our case, a typical size of $(512-1024-512)$ results in about $1{,}000{,}000$ parameters. The overhead of correlation matrices is also negligible compared with the encoders. Therefore, even compared with contrastive SSL, which does not have a predictor, the communication overhead resulting from sharing correlation matrices is not a concern.
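The numbers above can be checked with a quick back-of-the-envelope calculation (weights only, biases omitted; $H=512$ and the $512$-$1024$-$512$ MLP predictor are taken from the text):

```python
H = 512
corr_matrix_params = H * H                          # 262,144  (~260,000)
mlp_predictor_params = 512 * 1024 + 1024 * 512      # 1,048,576 (~1,000,000)
print(corr_matrix_params, mlp_predictor_params)
```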

The extra privacy leakage is probably comparable to that caused by sharing predictors: The predictors in non-contrastive SSL also lead to potential privacy leakage. Although a theoretical characterization has not been established, recent works shed light on the operational meaning of the predictors (Halvagal et al., 2023; Tian et al., 2021), suggesting what information is probably leaked. In particular, (Tian et al., 2021) reports that linear predictors in BYOL align with the correlation matrices $R(\theta)$ during training. This interesting finding suggests that predictors probably contain information similar to the correlation matrices.

6 Theoretical Analysis

In this section, we analyze the additional privacy leakage and the convergence of FedSC. Our findings are summarized as follows:

• We prove that sharing correlation matrices $T_{j}$ times through DP-CalR results in $\left(\frac{T_{j}\mu^{2}}{2\sigma^{2}|\mathcal{D}_{j}|^{2}}+\sqrt{\frac{2T_{j}\mu^{2}\log(1/\delta)}{\sigma^{2}|\mathcal{D}_{j}|^{2}}},\delta\right)$-DP.

• We provide a convergence analysis of FedSC. Specifically, with a large batch size $B$, a large number of views $V$ and a small DP noise scale $\sigma$, we can achieve a convergence rate close to $\mathcal{O}(1/\sqrt{T})$.

• The analysis indicates the superior performance of FedSC over FedAvg, whose convergence is dominated by an $\mathcal{O}(1)$ constant, i.e., an error floor.

6.1 Additional Privacy Leakage

In this subsection, we analyze the Gaussian mechanism in Algorithm 2 DP-CalR. We start with definitions of variations of differential privacy (DP).

Definition 6.1 ($(\epsilon,\delta)$-DP).

A mechanism $\mathcal{M}:\mathcal{X}\rightarrow\mathcal{Y}$ is $(\epsilon,\delta)$-DP if, for any neighboring $X,X^{\prime}\in\mathcal{X}$ and any $\mathcal{S}\subset\mathcal{Y}$, the following inequality is satisfied:

Pr(\mathcal{M}(X)\in\mathcal{S})\leq e^{\epsilon}Pr(\mathcal{M}(X^{\prime})\in\mathcal{S})+\delta.

DP protects the inputs of a mechanism from membership inference attacks. For a mechanism satisfying DP, we expect that one can hardly tell whether the input contains a certain entry by only looking at the output. In our case, we do not want the server to know whether a local dataset contains a particular data point.

Definition 6.2 ($(\alpha,\epsilon)$-RDP (Mironov, 2017)).

A mechanism $\mathcal{M}:\mathcal{X}\rightarrow\mathcal{Y}$ has $(\alpha,\epsilon)$-Rényi differential privacy if, for any neighboring $X,X^{\prime}\in\mathcal{X}$, $Y=\mathcal{M}(X)$ and $Y^{\prime}=\mathcal{M}(X^{\prime})$, the following inequality is satisfied:

D_{\alpha}(P_{Y}||P_{Y^{\prime}})\leq\epsilon,

where $D_{\alpha}(P_{Y}||P_{Y^{\prime}})$ is the Rényi divergence of order $\alpha>1$:

D_{\alpha}(P_{Y}||P_{Y^{\prime}})\triangleq\frac{1}{\alpha-1}\log\mathsf{E}_{Y^{\prime}}\left(\frac{P_{Y}(Y^{\prime})}{P_{Y^{\prime}}(Y^{\prime})}\right)^{\alpha}.

RDP is a variant of DP with many useful properties, which are summarized in the following lemmas.

Lemma 6.3 (Gaussian Mechanism of RDP (Mironov, 2017)).

Let $f:\mathcal{X}\rightarrow\mathbb{R}^{n}$ be a function with $l_{2}$ sensitivity $W$. Then the Gaussian mechanism $G_{f}(\cdot)=f(\cdot)+\mathcal{N}(0,\mathbf{I}_{n}\sigma^{2})$ is $(\alpha,\frac{\alpha W^{2}}{2\sigma^{2}})$-RDP.

Lemma 6.4 (Composition of RDP (Mironov, 2017)).

Let $\mathcal{M}_{1}:\mathcal{X}\rightarrow\mathcal{Y}$ be $(\alpha,\epsilon_{1})$-RDP, and $\mathcal{M}_{2}:\mathcal{X}\times\mathcal{Y}\rightarrow\mathcal{Z}$ be $(\alpha,\epsilon_{2})$-RDP. Then the mechanism $\mathcal{M}_{3}:\mathcal{X}\rightarrow\mathcal{Y}\times\mathcal{Z}$, $X\mapsto\left(\mathcal{M}_{1}(X),\mathcal{M}_{2}(X,\mathcal{M}_{1}(X))\right)$, is $(\alpha,\epsilon_{1}+\epsilon_{2})$-RDP with respect to $X$.

Lemma 6.5 ((Mironov, 2017)).

If a mechanism is $(\alpha,\epsilon)$-RDP, then it is $(\epsilon+\frac{\log(1/\delta)}{\alpha-1},\delta)$-DP.

With all these preparations, we use the following proposition to characterize the additional privacy leakage of FedSC.

Proposition 6.6 (Additional Privacy Leakage of FedSC).

Sharing correlation matrices $T_{j}$ times through Algorithm 2 (DP-CalR) results in $\left(\frac{T_{j}\mu^{2}}{2\sigma^{2}|\mathcal{D}_{j}|^{2}}+\sqrt{\frac{2T_{j}\mu^{2}\log(1/\delta)}{\sigma^{2}|\mathcal{D}_{j}|^{2}}},\delta\right)$-DP.

Proof.

We start with the sensitivity of DP-CalR.

\left\lVert\tilde{R}^{t}_{j}\right\rVert_{F}=\left\lVert\frac{1}{|\mathcal{D}_{j}|}\sum_{\bar{x}\in\mathcal{D}_{j}}\frac{1}{V}\sum_{v=1}^{V}\check{z}_{v}(\bar{x})\check{z}_{v}(\bar{x})^{T}\right\rVert_{F}\leq\left\lVert\frac{1}{|\mathcal{D}_{j}|}\sum_{\bar{x}^{\prime}\in\mathcal{D}_{j}/\bar{x}}\frac{1}{V}\sum_{v=1}^{V}\check{z}_{v}(\bar{x}^{\prime})\check{z}_{v}(\bar{x}^{\prime})^{T}\right\rVert_{F}+\frac{1}{|\mathcal{D}_{j}|}\frac{1}{V}\sum_{v=1}^{V}\left\lVert\check{z}_{v}(\bar{x})\check{z}_{v}(\bar{x})^{T}\right\rVert_{F},

where $\check{z}_{v}(\bar{x})$ is the representation of the $v$-th view of data $\bar{x}$. Notice that for any $\bar{x}$,

\left\lVert\check{z}_{v}(\bar{x})\check{z}_{v}(\bar{x})^{T}\right\rVert_{F}=\sqrt{Tr\left\{\check{z}_{v}(\bar{x})\check{z}_{v}(\bar{x})^{T}\check{z}_{v}(\bar{x})\check{z}_{v}(\bar{x})^{T}\right\}}\leq\mu.

The sensitivity is thus bounded by $\mu/|\mathcal{D}_{j}|$. With Lemma 6.3, DP-CalR is $\left(\alpha,\frac{\alpha\mu^{2}}{2\sigma^{2}|\mathcal{D}_{j}|^{2}}\right)$-RDP. With Lemma 6.4, sharing correlation matrices $T_{j}$ times results in $\left(\alpha,\frac{T_{j}\alpha\mu^{2}}{2\sigma^{2}|\mathcal{D}_{j}|^{2}}\right)$-RDP, which is $\left(\frac{T_{j}\mu^{2}}{2\sigma^{2}|\mathcal{D}_{j}|^{2}}+\sqrt{\frac{2T_{j}\mu^{2}\log(1/\delta)}{\sigma^{2}|\mathcal{D}_{j}|^{2}}},\delta\right)$-DP by Lemma 6.5 with $\alpha=\sqrt{\frac{2\sigma^{2}|\mathcal{D}_{j}|^{2}\log(1/\delta)}{T_{j}\mu^{2}}}+1$. ∎

From this result, we notice that for arbitrary $\delta$, $T_{j}$ and $\sigma$, the parameter $\epsilon=\frac{T_{j}\mu^{2}}{2\sigma^{2}|\mathcal{D}_{j}|^{2}}+\sqrt{\frac{2T_{j}\mu^{2}\log(1/\delta)}{\sigma^{2}|\mathcal{D}_{j}|^{2}}}$ approaches zero as the size of the local dataset approaches infinity, indicating no differential privacy leakage.
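To illustrate how the bound scales with the local dataset size, a small helper evaluating the $\epsilon$ of Proposition 6.6 is sketched below; all parameter values are hypothetical and chosen only for illustration.

```python
import math

def fedsc_epsilon(T_j: int, mu: float, sigma: float, n_j: int, delta: float) -> float:
    """Epsilon from Proposition 6.6 for sharing the correlation matrix T_j times,
    with clipping parameter mu, noise scale sigma and local dataset size n_j."""
    a = T_j * mu ** 2 / (2 * sigma ** 2 * n_j ** 2)
    b = math.sqrt(2 * T_j * mu ** 2 * math.log(1 / delta) / (sigma ** 2 * n_j ** 2))
    return a + b

# Example (hypothetical values): epsilon shrinks as the local dataset grows.
print(fedsc_epsilon(T_j=200, mu=1.0, sigma=0.01, n_j=5_000, delta=1e-2))   # ~0.90
print(fedsc_epsilon(T_j=200, mu=1.0, sigma=0.01, n_j=50_000, delta=1e-2))  # ~0.09
```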

Table 2: Performance comparison between FedSC and SOTAs on benchmark tasks: FedSC outperforms most of the SOTAs under different settings. Here we use bold and underline to mark the highest and second highest accuracy, respectively.
                   SVHN                    CIFAR10                 CIFAR100                SVHN                    CIFAR10                 CIFAR100
Participation      5/5                     10/10                   20/20                   2/5                     2/10                    4/20
FedAvg + BYOL      87.85±0.49              68.14±0.51              43.54±1.12              87.10±0.79              65.28±0.83              38.77±1.04
FedAvg + SC        90.52±0.42              77.82±0.82              56.24±0.19              89.89±0.94              75.36±0.36              42.95±0.52
FedU               87.92±0.31              68.39±0.69              43.81±0.98              87.40±0.75              65.52±0.51              39.11±0.92
FedEMA             \mathbf{91.87}±0.30     68.78±0.25              44.18±0.73              88.97±0.82              65.93±0.63              39.78±1.20
FedX               74.60±0.72              59.17±0.93              39.70±0.39              73.34±0.88              57.42±0.91              33.54±0.67
FedCA              89.92±0.14              78.22±0.22              52.35±0.09              89.28±0.44              \mathbf{77.22}±0.65     51.58±0.18
FedSC (Proposed)   \underline{91.78}±0.49  \mathbf{80.06}±0.35     \mathbf{58.35}±0.15     \mathbf{91.03}±0.58     \underline{77.12}±0.44  \mathbf{56.64}±0.65
Centralized SC     93.17±0.13              90.21±0.08              64.32±0.05              -                       -                       -

6.2 Convergence of FedSC

This subsection presents the convergence of FedSC. We begin with the following assumptions.

Assumption 6.7.

For any $\theta$ and $x$, the NN's output is bounded in norm: $||z(x,\theta)||_{2}\leq A_{0}$ for some constant $A_{0}$.

Assumption 6.8.

For any $\theta$ and $x$, the Jacobian matrix of the NN's output is bounded in norm: $||\nabla_{\theta}z(x,\theta)||_{F}\leq A_{1}$ for some constant $A_{1}$.

Assumption 6.9.

The function represented by the NN has bounded second-order derivatives, i.e., for any $\theta$ and $x$,

\sum_{m,p}\left\lVert\partial_{m}\partial_{p}z(x;\theta)\right\rVert_{2}^{2}\leq A_{2}^{2}

for some constant $A_{2}$, where $\partial_{p}$ denotes the partial derivative with respect to the $p$-th entry of $\theta$.

Assumption 6.7 can be satisfied when the NN has a normalization layer at the end or uses bounded activation functions, such as sigmoid, at the output layer. Assumption 6.8 accounts for the Lipschitz continuity of $z(x,\theta)$, which is often the case when the hidden layers of the NN use activation functions such as tanh, sigmoid and ReLU. Note that Assumption 6.8 is weaker than the bounded-gradient-norm assumption used in previous works (Li et al., 2019; Noble et al., 2022); however, in our case it leads to a bounded gradient norm due to the structure of the SC objective, as detailed in the appendices. Assumption 6.9 accounts for the strong smoothness of the NN, which is widely adopted in existing works (Karimireddy et al., 2020b; Li et al., 2019; Karimireddy et al., 2020a). With these assumptions, we demonstrate the convergence of FedSC in the following theorem.

Theorem 6.10.

Let Assumptions 6.7, 6.8 and 6.9 hold. Choose $\mu>A_{0}^{2}$ and the local learning rate $\eta=\mathcal{O}\left(\frac{1}{\sqrt{TE}}\right)$, where $T$ and $E$ are the numbers of communication rounds and local updates, respectively. Then FedSC achieves

\frac{1}{TE}\sum_{t=0}^{T-1}\sum_{e=0}^{E-1}\mathsf{E}\left[\left\lVert\nabla\mathcal{L}^{SC}(\theta^{t,e};\mathcal{D})\right\rVert^{2}\right]\leq\mathcal{O}\Biggl(\frac{E^{2}(H^{2}\sigma^{2}+C_{2})}{TE}+\frac{\sqrt{E}\sqrt{\frac{J/|\mathcal{J}|-1}{J-1}}}{\sqrt{T}}+\frac{\left(\frac{1}{V}+\max_{j}\frac{|\mathcal{D}_{j}|/B-1}{|\mathcal{D}_{j}|-1}\right)\left(H^{2}\sigma^{2}+C_{4}\right)}{\sqrt{TE}}+\frac{1}{V}+\max_{j}\frac{|\mathcal{D}_{j}|/B-1}{|\mathcal{D}_{j}|-1}+H^{2}\sigma^{2}\Biggr)    (7)

where $\theta^{t,e}\triangleq\frac{1}{|\mathcal{J}^{t}|}\sum_{j\in\mathcal{J}^{t}}\theta_{j}^{t,e}$ is the virtual averaged weight, and $\theta_{j}^{t,e}$ is the local weight of client $j$ at the $e$-th step of the $t$-th round; $B$ and $V$ are the batch size and the number of augmented views, respectively; $C_{2}$ and $C_{4}$ are constants depending only on $A_{0}$, $A_{1}$ and $A_{2}$.

6.2.1 Superior performance of FedSC

The convergence rate of FedSC is dominated by the following term as $T$ approaches infinity:

\gamma=\mathcal{O}\left(\frac{1}{V}+\max_{j}\frac{|\mathcal{D}_{j}|/B-1}{|\mathcal{D}_{j}|-1}+H^{2}\sigma^{2}\right).    (8)

The bias of the local batch gradients results in the term $\frac{1}{V}+\max_{j}\frac{|\mathcal{D}_{j}|/B-1}{|\mathcal{D}_{j}|-1}$; specifically, sampling the dataset $\mathcal{D}_{j}$ (without replacement) leads to $\frac{|\mathcal{D}_{j}|/B-1}{|\mathcal{D}_{j}|-1}$, and sampling the augmentation kernel $\mathcal{A}(\cdot|\bar{x})$ results in $\frac{1}{V}$. Note that this bias also exists in centralized SSL training and does not result from federation. The sampling variance, from the augmentation kernel, in the shared correlation matrix contributes $\frac{1}{V}$ to the convergence. The DP noise contributes $H^{2}\sigma^{2}$. If we set the batch size $B=|\mathcal{D}_{j}|$, generate an infinite number of views $V=\infty$ and apply no DP protection, i.e., $\sigma^{2}=0$, then the error floor disappears, resulting in a convergence rate similar to that of FedAvg in supervised FL. In comparison, if we directly use the SC objective without modification and apply FedAvg, there will be a constant error floor independent of the batch size $B$ and the number of views $V$, due to the misalignment between the averaged local objectives and the global objective.

6.2.2 Sketch of Proof

We begin with the case of full client participation. The convergence is determined by two terms: 1) the squared norm of the bias of the averaged local gradient, and 2) the variance of the averaged local gradient. The norm of the bias can be factorized into three components. 1.a) The "drift" of the local weights, leading to a factor of $\mathcal{O}(1/T)$. 1.b) The bias in the batch gradient of the local objectives $\mathcal{L}^{SC}_{j}(\theta;\bar{R}_{-j})$, contributing a factor of $\mathcal{O}\left(\frac{1}{V}+\max_{j}\frac{|\mathcal{D}_{j}|/B-1}{|\mathcal{D}_{j}|-1}\right)$; this bias is due to the fact that the SSL objective cannot be written as a sum of per-sample losses as in supervised learning. 1.c) The impact of aging, sampling variance, and DP noise in the shared correlation matrices $\tilde{R}_{-j}^{t}$, which remain constant during local training. The aging (compared with $R_{-j}(\theta^{t,e})$) leads to a bias proportional to the drift of the local weights, resulting in a factor of $\mathcal{O}(1/T)$; the sampling variance in $\tilde{R}_{-j}^{t}$ contributes a factor of $\mathcal{O}(1/V)$; and the DP noise contributes a factor of $\mathcal{O}(H^{2}\sigma^{2})$, where $H$ is the dimension of the representation. The variance of the averaged local gradient contributes a factor of $\mathcal{O}(1/\sqrt{T})$, accounting for 2.a) the variance of batch gradient sampling and 2.b) the DP noise in $\tilde{R}_{-j}^{t}$. For partial client participation, we additionally need to consider the variance in aggregation and the additional aging of $\tilde{R}_{-j}^{t}$. Given a bounded gradient norm, the variance due to client sampling is $\mathcal{O}\left(\sqrt{E}\sqrt{\frac{J/|\mathcal{J}|-1}{T(J-1)}}\right)$. The additional aging is proportional to the extra drift, leading to a rate of $\mathcal{O}(1/T)$.

Table 3: FedSC with different levels of DP protection ($\delta=10^{-2}$): deploying DP leads to negligible performance degradation.
                    SVHN           CIFAR10        SVHN           CIFAR10
Participation       5/5            10/10          2/5            2/10
FedAvg+SC           90.52±0.42     77.82±0.82     89.89±0.94     75.36±0.36
$\epsilon=3$        90.90±0.52     79.21±0.59     90.00±0.72     76.97±0.64
$\epsilon=6$        91.32±0.87     79.42±0.63     90.12±0.61     77.08±0.46
$\epsilon=\infty$   91.78±0.49     80.06±0.35     91.03±0.58     77.12±0.44

7 Experiments

7.1 Experimental Setup

Datasets: Three datasets, SVHN, CIFAR10 and CIFAR100, are used for evaluation. SVHN is split into 5 disjoint local datasets, each of which contains 2 classes. CIFAR10 is split into 10 disjoint local datasets according to its 10 classes. CIFAR100 is split into 20 disjoint local datasets, each of which contains 5 classes. Therefore, the local dataset sizes for the SVHN, CIFAR10 and CIFAR100 tasks are 10,000, 5,000 and 2,500, respectively.
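A hedged sketch of this class-based non-i.i.d. partition is given below; integer labels and numpy are assumed, and subsampling to the stated local dataset sizes is omitted.

```python
import numpy as np

def split_by_class(labels: np.ndarray, num_clients: int, classes_per_client: int):
    """Assign each client all samples from a disjoint set of classes."""
    classes = np.unique(labels)
    assert len(classes) >= num_clients * classes_per_client
    client_indices = []
    for j in range(num_clients):
        own = classes[j * classes_per_client:(j + 1) * classes_per_client]
        client_indices.append(np.where(np.isin(labels, own))[0])
    return client_indices  # e.g., CIFAR10: split_by_class(y, 10, 1)
```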

Models: For SVHN and CIFAR10, we use a modified version of ResNet20 as backbones. For CIFAR100, the backbone is a modified version of ResNet50.

Hyper-parameters: For all three tasks, the number of communication rounds is $T=200$, and the number of local epochs is $E=5$. For SVHN and CIFAR10, the batch size is $B=512$. For CIFAR100, the batch size is $B=256$. The number of views is $V=2$ for all experiments. For correlation matrix sharing, the number of views is set to $V=5$.

Benchmarks: Besides FedAvg+BYOL and FedAvg+SC, we also compare with the following state-of-the-art methods: FedU (Zhuang et al., 2021), FedEMA (Zhuang et al., 2022), FedX (Han et al., 2022) and FedCA (Zhang et al., 2023).

(a) SVHN
(b) CIFAR10
(c) CIFAR100
Figure 2: Convergence of FedSC and FedAvg+SC. 1) FedAvg+SC tends to experience either a high error floor or overfitting. 2) FedSC is able to consistently enhance KNN accuracy. This observation validates our theoretical analysis in Sec. 6.

7.2 Experimental Results

Comparison with SOTA approaches: Table 2 presents the performance comparison of various algorithms under linear evaluation, where centralized SC serves as an ideal upper bound. We make the following three observations: (1) Our proposed algorithm, FedSC, demonstrates better or comparable performance across different tasks compared to the other 6 methods. (2) FedBYOL, FedU, and FedEMA show good results on SVHN but underperform on CIFAR10 and CIFAR100. We believe that this disparity is caused by the larger local dataset size in SVHN, leading to more local updates. Since these methods incorporate momentum updates in the target encoder, a larger number of updates might be necessary to effectively initiate local training. (3) FedSC and FedCA exhibit less performance degradation when switched to the partial client participation case. We believe this is because clients in both FedSC and FedCA have extra global information about representations. Additionally, the predictors in FedBYOL, FedU, and FedEMA are affected by client sampling, hindering their provision of global information.

Table 4: FedSC with different levels of DP protection ($\delta=10^{-2}$): deploying DP leads to negligible performance degradation.
                    CIFAR100
Participation       20/20          4/20
FedAvg+SC           56.24±0.19     42.95±0.52
$\epsilon=6$        57.10±0.82     54.87±0.62
$\epsilon=12$       57.63±0.57     55.76±0.58
$\epsilon=\infty$   58.35±0.15     56.64±0.65
Table 5: FedSC with different levels of DP protection ($\delta=10^{-4}$): deploying DP leads to negligible performance degradation.
                    SVHN           CIFAR10        CIFAR100
Participation       2/5            2/10           4/20
FedAvg+SC           89.89±0.94     75.36±0.36     42.95±0.52
$\epsilon=3$        89.95±0.81     76.75±0.62     54.22±0.72
$\epsilon=8$        90.12±0.61     77.08±0.46     54.87±0.62
$\epsilon=\infty$   91.03±0.58     77.12±0.44     56.64±0.65

DP Impact: Tables 3, 4 and 5 illustrate the impact of the DP mechanism on FedSC's performance. It is shown that, with a reasonable degree of DP protection, there is only a modest decline in FedSC's performance, which remains better than that of FedAvg+SC. Given that our focus is on data-level DP, the extra privacy leakage shown in the tables is typically insignificant compared to the leakage resulting from the encoders. On the other hand, according to the analysis in Sec. 6.1, a smaller dataset necessitates a higher level of DP noise to maintain the same degree of privacy protection. The local dataset sizes for the SVHN, CIFAR10, and CIFAR100 tasks are 10,000, 5,000, and 2,500, respectively. As a result, for the CIFAR100 task, we choose a slightly higher privacy budget compared to the other two tasks.

Convergence: Fig. 2 compares the convergence of the proposed FedSC and FedAvg+SC, in terms of communication rounds and KNN accuracy. The figures reveal that FedAvg+SC tends to experience either a high error floor or overfitting as the number of communication rounds grows. In contrast, FedSC consistently enhances KNN accuracy. This validates our theoretical analysis in Sec. 6.

8 Conclusion

In this paper, we proposed FedSC, a novel FedSSL framework based on the spectral contrastive objective. In FedSC, clients periodically share correlation matrices in addition to local weights. With the shared correlation matrices, clients are able to perform inter-client contrast of samples in addition to intra-client contrast and contraction. To mitigate the extra privacy leakage on local datasets, we adopted a DP mechanism on the shared correlation matrices. We provided theoretical analysis on privacy leakage and convergence, demonstrating the efficacy of FedSC. To the best knowledge of the authors, this is the first provable FedSSL method.

9 Impact Statements

This paper presents work whose goal is to advance the field of Machine Learning. There are many potential societal consequences of our work, none of which we feel must be specifically highlighted here.

References

  • Bardes et al. (2021) Bardes, A., Ponce, J., and LeCun, Y. Vicreg: Variance-invariance-covariance regularization for self-supervised learning. arXiv preprint arXiv:2105.04906, 2021.
  • Brown et al. (2020) Brown, T., Mann, B., Ryder, N., Subbiah, M., Kaplan, J. D., Dhariwal, P., Neelakantan, A., Shyam, P., Sastry, G., Askell, A., et al. Language models are few-shot learners. Advances in neural information processing systems, 33:1877–1901, 2020.
  • Chen et al. (2020) Chen, T., Kornblith, S., Norouzi, M., and Hinton, G. A simple framework for contrastive learning of visual representations. In International conference on machine learning, pp.  1597–1607. PMLR, 2020.
  • Chen & He (2021) Chen, X. and He, K. Exploring simple siamese representation learning. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pp.  15750–15758, 2021.
  • Dwork (2006) Dwork, C. Differential privacy. In International colloquium on automata, languages, and programming, pp.  1–12. Springer, 2006.
  • Geyer et al. (2017) Geyer, R. C., Klein, T., and Nabi, M. Differentially private federated learning: A client level perspective. arXiv preprint arXiv:1712.07557, 2017.
  • Grill et al. (2020) Grill, J.-B., Strub, F., Altché, F., Tallec, C., Richemond, P., Buchatskaya, E., Doersch, C., Avila Pires, B., Guo, Z., Gheshlaghi Azar, M., et al. Bootstrap your own latent-a new approach to self-supervised learning. Advances in neural information processing systems, 33:21271–21284, 2020.
  • Halvagal et al. (2023) Halvagal, M. S., Laborieux, A., and Zenke, F. Implicit variance regularization in non-contrastive ssl. In Thirty-seventh Conference on Neural Information Processing Systems, 2023.
  • Han et al. (2022) Han, S., Park, S., Wu, F., Kim, S., Wu, C., Xie, X., and Cha, M. Fedx: Unsupervised federated learning with cross knowledge distillation. In European Conference on Computer Vision, pp.  691–707. Springer, 2022.
  • HaoChen et al. (2021) HaoChen, J. Z., Wei, C., Gaidon, A., and Ma, T. Provable guarantees for self-supervised deep learning with spectral contrastive loss. Advances in Neural Information Processing Systems, 34:5000–5011, 2021.
  • He et al. (2021) He, J., Zhou, C., Ma, X., Berg-Kirkpatrick, T., and Neubig, G. Towards a unified view of parameter-efficient transfer learning. In International Conference on Learning Representations, 2021.
  • He et al. (2020) He, K., Fan, H., Wu, Y., Xie, S., and Girshick, R. Momentum contrast for unsupervised visual representation learning. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pp.  9729–9738, 2020.
  • Hu et al. (2021) Hu, E. J., Shen, Y., Wallis, P., Allen-Zhu, Z., Li, Y., Wang, S., Wang, L., and Chen, W. Lora: Low-rank adaptation of large language models. arXiv preprint arXiv:2106.09685, 2021.
  • Hu et al. (2020) Hu, R., Guo, Y., Li, H., Pei, Q., and Gong, Y. Personalized federated learning with differential privacy. IEEE Internet of Things Journal, 7(10):9530–9539, 2020.
  • Karimireddy et al. (2020a) Karimireddy, S. P., Jaggi, M., Kale, S., Mohri, M., Reddi, S. J., Stich, S. U., and Suresh, A. T. Mime: Mimicking centralized stochastic algorithms in federated learning. arXiv preprint arXiv:2008.03606, 2020a.
  • Karimireddy et al. (2020b) Karimireddy, S. P., Kale, S., Mohri, M., Reddi, S., Stich, S., and Suresh, A. T. Scaffold: Stochastic controlled averaging for federated learning. In International conference on machine learning, pp.  5132–5143. PMLR, 2020b.
  • Li et al. (2019) Li, X., Huang, K., Yang, W., Wang, S., and Zhang, Z. On the convergence of fedavg on non-iid data. arXiv preprint arXiv:1907.02189, 2019.
  • McMahan et al. (2017) McMahan, B., Moore, E., Ramage, D., Hampson, S., and Arcas, B. A. y. Communication-Efficient Learning of Deep Networks from Decentralized Data. In Singh, A. and Zhu, J. (eds.), Proceedings of the 20th International Conference on Artificial Intelligence and Statistics, volume 54 of Proceedings of Machine Learning Research, pp.  1273–1282. PMLR, 20–22 Apr 2017.
  • Mironov (2017) Mironov, I. Rényi differential privacy. In 2017 IEEE 30th computer security foundations symposium (CSF), pp.  263–275. IEEE, 2017.
  • Noble et al. (2022) Noble, M., Bellet, A., and Dieuleveut, A. Differentially private federated learning on heterogeneous data. In International Conference on Artificial Intelligence and Statistics, pp.  10110–10145. PMLR, 2022.
  • Ravi & Larochelle (2016) Ravi, S. and Larochelle, H. Optimization as a model for few-shot learning. In International conference on learning representations, 2016.
  • Stich (2018) Stich, S. U. Local sgd converges fast and communicates little. arXiv preprint arXiv:1805.09767, 2018.
  • Tian et al. (2021) Tian, Y., Chen, X., and Ganguli, S. Understanding self-supervised learning dynamics without contrastive pairs. In International Conference on Machine Learning, pp.  10268–10278. PMLR, 2021.
  • Truex et al. (2020) Truex, S., Liu, L., Chow, K.-H., Gursoy, M. E., and Wei, W. Ldp-fed: Federated learning with local differential privacy. In Proceedings of the Third ACM International Workshop on Edge Systems, Analytics and Networking, pp.  61–66, 2020.
  • Tschannen et al. (2019) Tschannen, M., Djolonga, J., Rubenstein, P. K., Gelly, S., and Lucic, M. On mutual information maximization for representation learning. arXiv preprint arXiv:1907.13625, 2019.
  • Wei et al. (2020) Wei, K., Li, J., Ding, M., Ma, C., Yang, H. H., Farokhi, F., Jin, S., Quek, T. Q., and Poor, H. V. Federated learning with differential privacy: Algorithms and performance analysis. IEEE Transactions on Information Forensics and Security, 15:3454–3469, 2020.
  • Zbontar et al. (2021) Zbontar, J., Jing, L., Misra, I., LeCun, Y., and Deny, S. Barlow twins: Self-supervised learning via redundancy reduction. In International Conference on Machine Learning, pp.  12310–12320. PMLR, 2021.
  • Zhang et al. (2022) Zhang, C., Zhang, K., Zhang, C., Pham, T. X., Yoo, C. D., and Kweon, I. S. How does simsiam avoid collapse without negative samples? a unified understanding with self-supervised contrastive learning. arXiv preprint arXiv:2203.16262, 2022.
  • Zhang et al. (2023) Zhang, F., Kuang, K., Chen, L., You, Z., Shen, T., Xiao, J., Zhang, Y., Wu, C., Wu, F., Zhuang, Y., et al. Federated unsupervised representation learning. Frontiers of Information Technology & Electronic Engineering, 24(8):1181–1193, 2023.
  • Zhuang et al. (2021) Zhuang, W., Gan, X., Wen, Y., Zhang, S., and Yi, S. Collaborative unsupervised visual representation learning from decentralized data. In Proceedings of the IEEE/CVF international conference on computer vision, pp.  4912–4921, 2021.
  • Zhuang et al. (2022) Zhuang, W., Wen, Y., and Zhang, S. Divergence-aware federated self-supervised learning. arXiv preprint arXiv:2204.04385, 2022.

Appendix A Derivation of SC objective

\mathcal{L}^{SC}(\theta;\mathcal{D})\triangleq-\mathsf{E}_{\bar{x}\sim\mathcal{D}}\mathsf{E}_{x,x^{+}\sim\mathcal{A}(\cdot|\bar{x})}\left[z(x;\theta)^{T}z(x^{+};\theta)\right]+\frac{1}{2}\mathsf{E}_{x,x^{-}\sim\mathcal{A}(\cdot|\mathcal{D})}\left[\left(z(x;\theta)^{T}z(x^{-};\theta)\right)^{2}\right]
=-\mathsf{E}_{\bar{x}\sim\mathcal{D}}\mathsf{E}_{x,x^{+}\sim\mathcal{A}(\cdot|\bar{x})}\left[Tr\left\{z(x^{+};\theta)z(x;\theta)^{T}\right\}\right]+\frac{1}{2}\mathsf{E}_{x,x^{-}\sim\mathcal{A}(\cdot|\mathcal{D})}\left[Tr\left\{z(x;\theta)z(x;\theta)^{T}z(x^{-};\theta)z(x^{-};\theta)^{T}\right\}\right]
=-\mathsf{E}_{\bar{x}\sim\mathcal{D}}\left[Tr\left\{R^{+}(\bar{x};\theta)\right\}\right]+\frac{1}{2}Tr\left\{\mathsf{E}_{x\sim\mathcal{A}(\cdot|\mathcal{D})}z(x;\theta)z(x;\theta)^{T}\,\mathsf{E}_{x^{-}\sim\mathcal{A}(\cdot|\mathcal{D})}z(x^{-};\theta)z(x^{-};\theta)^{T}\right\}
=-\mathsf{E}_{\bar{x}\sim\mathcal{D}}\left[Tr\left\{R^{+}(\bar{x};\theta)\right\}\right]+\frac{1}{2}\left\lVert\mathsf{E}_{x\sim\mathcal{A}(\cdot|\mathcal{D})}z(x;\theta)z(x;\theta)^{T}\right\rVert_{F}^{2}
=-\mathsf{E}_{\bar{x}\sim\mathcal{D}}Tr\{R^{+}(\bar{x};\theta)\}+\frac{1}{2}\left\lVert\mathsf{E}_{\bar{x}\sim\mathcal{D}}R(\bar{x};\theta)\right\rVert_{F}^{2}    (9)

Appendix B Proof of Theorem 6.10

B.1 Additional Notations

Let $\theta_{j}^{t,e}$ and $v_{j}^{t,e}$ be the local weights and the local SGD direction, respectively, at the $e$-th update in the $t$-th communication round. Denote by $\theta^{t,e}\triangleq\sum_{j}q_{j}\theta_{j}^{t,e}$ and $v^{t,e}\triangleq\sum_{j}q_{j}v_{j}^{t,e}$ the virtual averaged weights and moving direction, respectively. Since the server aggregates periodically, we have $\theta^{t}=\theta^{t,0}$. For simplicity, we drop the superscript "SC" in $\mathcal{L}^{SC}(\theta)$ and $\mathcal{L}^{SC}_{j}(\theta,\tilde{R}^{t}_{-j})$ without ambiguity.

B.2 Assumptions

Assumption B.1.

For any $\theta$ and $x$, the NN's output is bounded in norm: $||z(x,\theta)||_{2}<A_{0}$.

Assumption B.2.

For any $\theta$ and $x$, the Jacobian of the NN's output is bounded in norm: $||\nabla z(x,\theta)||_{F}<A_{1}$.

Assumption B.3.

The function represented by the NN has bounded second-order derivatives, i.e., for any $\theta$ and $x$,

\sum_{m,p}\left\lVert\partial_{m}\partial_{p}z(x;\theta)\right\rVert_{2}^{2}\leq A_{2}^{2}    (10)

B.3 Lemmas

Lemma B.4.

For any x¯\bar{x}, {Xv}v\{X_{v}\}_{v}, θ\theta and jj, the following inequalities hold

R(x¯,θ)F2,R(θ)F2,Rj(θ)F2,R^({Xv}v,θ)F2A04\displaystyle\left\lVert R(\bar{x},\theta)\right\rVert_{F}^{2},\left\lVert R(\theta)\right\rVert_{F}^{2},\left\lVert R_{j}(\theta)\right\rVert_{F}^{2},\left\lVert\hat{R}(\{X_{v}\}_{v},\theta)\right\rVert_{F}^{2}\leq A_{0}^{4} (11)
ppR(x¯,θ)F2,ppR(θ)F2,ppRj(θ)F2,ppR^({Xv}v,θ)F24A12A02.\displaystyle\sum_{p}\left\lVert\partial_{p}R(\bar{x},\theta)\right\rVert_{F}^{2},\sum_{p}\left\lVert\partial_{p}R(\theta)\right\rVert_{F}^{2},\sum_{p}\left\lVert\partial_{p}R_{j}(\theta)\right\rVert_{F}^{2},\sum_{p}\left\lVert\partial_{p}\hat{R}(\{X_{v}\}_{v},\theta)\right\rVert_{F}^{2}\leq 4A_{1}^{2}A_{0}^{2}. (12)
Proof.
R(x¯,θ)F2=𝖤x𝒜(|x¯)z(x;θ)zT(x;θ)F2𝖤x𝒜(|x¯)z(x;θ)2z(x;θ)2A04\displaystyle\begin{aligned} \left\lVert R(\bar{x},\theta)\right\rVert_{F}^{2}&=\left\lVert\mathsf{E}_{x\sim\mathcal{A}(\cdot|\bar{x})}z(x;\theta)z^{T}(x;\theta)\right\rVert_{F}^{2}\\ &\leq\mathsf{E}_{x\sim\mathcal{A}(\cdot|\bar{x})}\left\lVert z(x;\theta)\right\rVert^{2}\left\lVert z(x;\theta)\right\rVert^{2}\\ &\leq A_{0}^{4}\end{aligned} (13)
ppR(x¯,θ)F2p𝖤x𝒜(|x¯)pz(x;θ)zT(x;θ)+z(x;θ)pzT(x;θ)F24p𝖤x𝒜(|x¯)pz(x;θ)2z(x;θ)24A12A02\displaystyle\begin{aligned} \sum_{p}\left\lVert\partial_{p}R(\bar{x},\theta)\right\rVert_{F}^{2}&\leq\sum_{p}\mathsf{E}_{x\sim\mathcal{A}(\cdot|\bar{x})}\left\lVert\partial_{p}z(x;\theta)z^{T}(x;\theta)+z(x;\theta)\partial_{p}z^{T}(x;\theta)\right\rVert_{F}^{2}\\ &\leq 4\sum_{p}\mathsf{E}_{x\sim\mathcal{A}(\cdot|\bar{x})}\left\lVert\partial_{p}z(x;\theta)\right\rVert^{2}\left\lVert z(x;\theta)\right\rVert^{2}\\ &\leq 4A_{1}^{2}A_{0}^{2}\end{aligned} (14)

The remaining results follow directly from Jensen's inequality. ∎
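A small NumPy check of the bounds in eq. (11) and eq. (12), using random representations and Jacobians that satisfy Assumptions B.1 and B.2 by construction; this is only an illustrative sketch, with all sizes and names chosen arbitrarily.

```python
import numpy as np

rng = np.random.default_rng(1)
N, H, P = 64, 8, 5            # samples, representation dim, number of parameters
A0, A1 = 2.0, 1.5             # bounds assumed in Assumptions B.1 and B.2

# Random representations with ||z|| <= A0 and Jacobians with total Frobenius norm <= A1
z = rng.normal(size=(N, H))
z *= (A0 * rng.uniform(size=(N, 1))) / np.linalg.norm(z, axis=1, keepdims=True)
dz = rng.normal(size=(N, P, H))
dz *= (A1 * rng.uniform(size=(N, 1, 1))) / np.linalg.norm(dz.reshape(N, -1), axis=1)[:, None, None]

R = np.einsum('nh,nk->hk', z, z) / N            # E[z z^T]
dR = np.einsum('nph,nk->phk', dz, z) / N        # E[(d_p z) z^T]
dR = dR + dR.transpose(0, 2, 1)                 # + E[z (d_p z)^T]

print(np.sum(R ** 2) <= A0 ** 4)                # eq. (11), expected: True
print(np.sum(dR ** 2) <= 4 * A1 ** 2 * A0 ** 2) # eq. (12), expected: True
```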

Lemma B.5.

The function $u^{t,e}_{j}(\theta)$, whose $p$-th entry is defined as

ujt,e(θ)[p]=pTr{Rj+(θ)}+Tr{pRj(θ)jqjRj(θ)}\displaystyle u^{t,e}_{j}(\theta)[p]=-\partial_{p}Tr\left\{R_{j}^{+}(\theta)\right\}+Tr\left\{\partial_{p}R_{j}(\theta)\sum_{j^{\prime}}q_{j^{\prime}}R_{j^{\prime}}(\theta)\right\} (15)

is β\beta-Lipschitz continuous with

β2=24(A04+1)(A14+A22A02)+48A14A04.\displaystyle\beta^{2}=24(A_{0}^{4}+1)(A_{1}^{4}+A_{2}^{2}A_{0}^{2})+48A_{1}^{4}A_{0}^{4}. (16)
Proof.

We start with the derivative of ujt,e(θ)[p]u^{t,e}_{j}(\theta)[p]

mujt,e(θ)[p]=mpTr{Rj+(θ)}+Tr{mpRj(θ)jqjRj(θ)}+Tr{pRj(θ)jqjmRj(θ)}\displaystyle\begin{aligned} \partial_{m}u^{t,e}_{j}(\theta)[p]&=-\partial_{m}\partial_{p}Tr\left\{R_{j}^{+}(\theta)\right\}+Tr\left\{\partial_{m}\partial_{p}R_{j}(\theta)\sum_{j^{\prime}}q_{j^{\prime}}R_{j^{\prime}}(\theta)\right\}+Tr\left\{\partial_{p}R_{j}(\theta)\sum_{j^{\prime}}q_{j^{\prime}}\partial_{m}R_{j^{\prime}}(\theta)\right\}\end{aligned} (17)

Using AM-GM, we have

(mujt,e(θ)[p])23(mpTr{Rj+(θ)})2+3Tr{mpRj(θ)jqjRj(θ)}2+3Tr{pRj(θ)jqjmRj(θ)}2\displaystyle\begin{aligned} (\partial_{m}u^{t,e}_{j}(\theta)[p])^{2}\leq 3\left(\partial_{m}\partial_{p}Tr\left\{R_{j}^{+}(\theta)\right\}\right)^{2}+3Tr\left\{\partial_{m}\partial_{p}R_{j}(\theta)\sum_{j^{\prime}}q_{j^{\prime}}R_{j^{\prime}}(\theta)\right\}^{2}+3Tr\left\{\partial_{p}R_{j}(\theta)\sum_{j^{\prime}}q_{j^{\prime}}\partial_{m}R_{j^{\prime}}(\theta)\right\}^{2}\end{aligned} (18)

For the first term, recalling the definition of $R_{j}^{+}(\theta)$, we have

mpTr{Rj+(θ)}=2𝖤x¯𝒟jp𝖤x𝒜(|x¯)zT(x;θ)m𝖤x𝒜(|x¯)z(x;θ)+2𝖤x¯𝒟jmp𝖤x𝒜(|x¯)zT(x;θ)𝖤x𝒜(|x¯)z(x;θ).\displaystyle\begin{aligned} \partial_{m}\partial_{p}Tr\{R_{j}^{+}(\theta)\}&=2\mathsf{E}_{\bar{x}\sim\mathcal{D}_{j}}\partial_{p}\mathsf{E}_{x\sim\mathcal{A}(\cdot|\bar{x})}z^{T}(x;\theta)\partial_{m}\mathsf{E}_{x\sim\mathcal{A}(\cdot|\bar{x})}z(x;\theta)\\ &+2\mathsf{E}_{\bar{x}\sim\mathcal{D}_{j}}\partial_{m}\partial_{p}\mathsf{E}_{x\sim\mathcal{A}(\cdot|\bar{x})}z^{T}(x;\theta)\mathsf{E}_{x\sim\mathcal{A}(\cdot|\bar{x})}z(x;\theta).\end{aligned} (19)

Consequently, we have

m,p(mpTr{Rj+(θ)})28(𝖤x¯𝒟jp𝖤x𝒜(|x¯)zT(x;θ)m𝖤x𝒜(|x¯)z(x;θ))2+8(𝖤x¯𝒟jmp𝖤x𝒜(|x¯)zT(x;θ)𝖤x𝒜(|x¯)z(x;θ))28m,p𝖤x¯𝒟j𝖤x1,x2𝒜(x¯)((pzT(x1;θ)mz(x2;θ))2+(mpzT(x1;θ)z(x2;θ))2)8m,p𝖤x¯𝒟j𝖤x1,x2𝒜(x¯)(pz(x1;θ)22mz(x2;θ)22+mpz(x1;θ)22z(x2;θ)22)8(A14+A22A02)\displaystyle\begin{aligned} &\sum_{m,p}(\partial_{m}\partial_{p}Tr\left\{R_{j}^{+}(\theta)\right\})^{2}\\ &\leq 8\left(\mathsf{E}_{\bar{x}\sim\mathcal{D}_{j}}\partial_{p}\mathsf{E}_{x\sim\mathcal{A}(\cdot|\bar{x})}z^{T}(x;\theta)\partial_{m}\mathsf{E}_{x\sim\mathcal{A}(\cdot|\bar{x})}z(x;\theta)\right)^{2}+8\left(\mathsf{E}_{\bar{x}\sim\mathcal{D}_{j}}\partial_{m}\partial_{p}\mathsf{E}_{x\sim\mathcal{A}(\cdot|\bar{x})}z^{T}(x;\theta)\mathsf{E}_{x\sim\mathcal{A}(\cdot|\bar{x})}z(x;\theta)\right)^{2}\\ &\leq 8\sum_{m,p}\mathsf{E}_{\bar{x}\sim\mathcal{D}_{j}}\mathsf{E}_{x_{1},x_{2}\sim\mathcal{A}(\bar{x})}\left(\left(\partial_{p}z^{T}(x_{1};\theta)\partial_{m}z(x_{2};\theta)\right)^{2}+\left(\partial_{m}\partial_{p}z^{T}(x_{1};\theta)z(x_{2};\theta)\right)^{2}\right)\\ &\leq 8\sum_{m,p}\mathsf{E}_{\bar{x}\sim\mathcal{D}_{j}}\mathsf{E}_{x_{1},x_{2}\sim\mathcal{A}(\bar{x})}\left(\left\lVert\partial_{p}z(x_{1};\theta)\right\rVert_{2}^{2}\left\lVert\partial_{m}z(x_{2};\theta)\right\rVert_{2}^{2}+\left\lVert\partial_{m}\partial_{p}z(x_{1};\theta)\right\rVert_{2}^{2}\left\lVert z(x_{2};\theta)\right\rVert_{2}^{2}\right)\\ &\leq 8(A_{1}^{4}+A_{2}^{2}A_{0}^{2})\end{aligned} (20)

where the first inequality uses AM-GM, and the second inequality uses Jensen’s inequality.

For the second term, we have

Tr{mpRj(θ)jqjRj(θ)}2mpRj(θ)F2jqjRj(θ)F2.\displaystyle\begin{aligned} Tr\left\{\partial_{m}\partial_{p}R_{j}(\theta)\sum_{j^{\prime}}q_{j^{\prime}}R_{j^{\prime}}(\theta)\right\}^{2}\leq\left\lVert\partial_{m}\partial_{p}R_{j}(\theta)\right\rVert_{F}^{2}\left\lVert\sum_{j^{\prime}}q_{j^{\prime}}R_{j^{\prime}}(\theta)\right\rVert_{F}^{2}.\end{aligned} (21)

Notice that

mpRj(θ)=𝖤x𝒜(|𝒟j)mpz(x;θ)zT(x;θ)+z(x;θ)mpzT(x;θ)+mz(x;θ)pzT(x;θ)+pz(x;θ)mzT(x;θ)\displaystyle\begin{aligned} \partial_{m}\partial_{p}R_{j}(\theta)&=\mathsf{E}_{x\sim\mathcal{A}(\cdot|\mathcal{D}_{j})}\partial_{m}\partial_{p}z(x;\theta)z^{T}(x;\theta)+z(x;\theta)\partial_{m}\partial_{p}z^{T}(x;\theta)\\ &+\partial_{m}z(x;\theta)\partial_{p}z^{T}(x;\theta)+\partial_{p}z(x;\theta)\partial_{m}z^{T}(x;\theta)\end{aligned} (22)

We have

m,pmpRj(θ)F24𝖤x𝒜(|𝒟j)mpz(x;θ)zT(x;θ)F2+4𝖤x𝒜(|𝒟j)z(x;θ)mpzT(x;θ)F2+4𝖤x𝒜(|𝒟j)mz(x;θ)pzT(x;θ)F2+4𝖤x𝒜(|𝒟j)pz(x;θ)mzT(x;θ)F24𝖤x𝒜(|𝒟j)[mpz(x;θ)zT(x;θ)F2+z(x;θ)mpzT(x;θ)F2+mz(x;θ)pzT(x;θ)F2+pz(x;θ)mzT(x;θ)F2]8m,p𝖤x𝒜(|𝒟j)(mpz(x;θ)22z(x;θ)22+mz(x;θ)22pz(x;θ)22)8(A14+A22A02)\displaystyle\begin{aligned} \sum_{m,p}\left\lVert\partial_{m}\partial_{p}R_{j}(\theta)\right\rVert_{F}^{2}&\leq 4\left\lVert\mathsf{E}_{x\sim\mathcal{A}(\cdot|\mathcal{D}_{j})}\partial_{m}\partial_{p}z(x;\theta)z^{T}(x;\theta)\right\rVert_{F}^{2}+4\left\lVert\mathsf{E}_{x\sim\mathcal{A}(\cdot|\mathcal{D}_{j})}z(x;\theta)\partial_{m}\partial_{p}z^{T}(x;\theta)\right\rVert_{F}^{2}\\ &+4\left\lVert\mathsf{E}_{x\sim\mathcal{A}(\cdot|\mathcal{D}_{j})}\partial_{m}z(x;\theta)\partial_{p}z^{T}(x;\theta)\right\rVert_{F}^{2}+4\left\lVert\mathsf{E}_{x\sim\mathcal{A}(\cdot|\mathcal{D}_{j})}\partial_{p}z(x;\theta)\partial_{m}z^{T}(x;\theta)\right\rVert_{F}^{2}\\ &\leq 4\mathsf{E}_{x\sim\mathcal{A}(\cdot|\mathcal{D}_{j})}\biggl{[}\left\lVert\partial_{m}\partial_{p}z(x;\theta)z^{T}(x;\theta)\right\rVert_{F}^{2}+\left\lVert z(x;\theta)\partial_{m}\partial_{p}z^{T}(x;\theta)\right\rVert_{F}^{2}\\ &+\left\lVert\partial_{m}z(x;\theta)\partial_{p}z^{T}(x;\theta)\right\rVert_{F}^{2}+\left\lVert\partial_{p}z(x;\theta)\partial_{m}z^{T}(x;\theta)\right\rVert_{F}^{2}\biggr{]}\\ &\leq 8\sum_{m,p}\mathsf{E}_{x\sim\mathcal{A}(\cdot|\mathcal{D}_{j})}\left(\left\lVert\partial_{m}\partial_{p}z(x;\theta)\right\rVert^{2}_{2}\left\lVert z(x;\theta)\right\rVert^{2}_{2}+\left\lVert\partial_{m}z(x;\theta)\right\rVert_{2}^{2}\left\lVert\partial_{p}z(x;\theta)\right\rVert_{2}^{2}\right)\\ &\leq 8(A_{1}^{4}+A_{2}^{2}A_{0}^{2})\end{aligned} (23)

Applying Lemma B.4, we have

jqjRj(θ)F2jqjRj(θ)F2A04\displaystyle\begin{aligned} \left\lVert\sum_{j^{\prime}}q_{j^{\prime}}R_{j^{\prime}}(\theta)\right\rVert_{F}^{2}&\leq\sum_{j^{\prime}}q_{j^{\prime}}\left\lVert R_{j^{\prime}}(\theta)\right\rVert_{F}^{2}\leq A_{0}^{4}\end{aligned} (24)

Substituting eq. (23) and eq. (24) into eq. (21), we have

\displaystyle\begin{aligned}\sum_{m,p}Tr\left\{\partial_{m}\partial_{p}R_{j}(\theta)\sum_{j^{\prime}}q_{j^{\prime}}R_{j^{\prime}}(\theta)\right\}^{2}\leq 8A_{0}^{4}(A_{1}^{4}+A_{2}^{2}A_{0}^{2})\end{aligned} (25)

For the third term, we have

m,pTr{pRj(θ)jqjmRj(θ)}2m,ppRj(θ)F2jqjmRj(θ)F2m,ppRj(θ)F2jqjmRj(θ)F216A14A04\displaystyle\begin{aligned} \sum_{m,p}Tr\left\{\partial_{p}R_{j}(\theta)\sum_{j^{\prime}}q_{j^{\prime}}\partial_{m}R_{j^{\prime}}(\theta)\right\}^{2}&\leq\sum_{m,p}\left\lVert\partial_{p}R_{j}(\theta)\right\rVert_{F}^{2}\left\lVert\sum_{j^{\prime}}q_{j^{\prime}}\partial_{m}R_{j^{\prime}}(\theta)\right\rVert_{F}^{2}\\ &\leq\sum_{m,p}\left\lVert\partial_{p}R_{j}(\theta)\right\rVert_{F}^{2}\sum_{j^{\prime}}q_{j^{\prime}}\left\lVert\partial_{m}R_{j^{\prime}}(\theta)\right\rVert_{F}^{2}\\ &\leq 16A_{1}^{4}A_{0}^{4}\end{aligned} (26)

where the last inequality uses Lemma B.4.

Combining eq. (20), eq. (25) and eq. (26), we have

ujt,e(θ)F2=m,p(mujt,e(θ)[p])224(A04+1)(A14+A22A02)+48A14A04\displaystyle\left\lVert\nabla u^{t,e}_{j}(\theta)\right\rVert_{F}^{2}=\sum_{m,p}(\partial_{m}u^{t,e}_{j}(\theta)[p])^{2}\leq 24(A_{0}^{4}+1)(A_{1}^{4}+A_{2}^{2}A_{0}^{2})+48A_{1}^{4}A_{0}^{4} (27)

and thus $\beta^{2}=24(A_{0}^{4}+1)(A_{1}^{4}+A_{2}^{2}A_{0}^{2})+48A_{1}^{4}A_{0}^{4}$. ∎

Corollary B.6.

The global loss (θ)\mathcal{L}(\theta) is β\beta-smooth.

Proof.

Notice that $\nabla\mathcal{L}(\theta)=\sum_{j}q_{j}u^{t,e}_{j}(\theta)$. The result then follows from Lemma B.5. ∎

Lemma B.7.

For any random matrices $X$ and $Y$, we have

𝖵𝖺𝗋[Tr{XY}]2𝖤[Y]F2𝖵𝖺𝗋[X]+2𝖤[X]F2𝖵𝖺𝗋[Y]\displaystyle\mathsf{Var}\left[Tr\left\{XY\right\}\right]\leq 2\left\lVert\mathsf{E}[Y]\right\rVert_{F}^{2}\mathsf{Var}[X]+2\left\lVert\mathsf{E}[X]\right\rVert_{F}^{2}\mathsf{Var}[Y] (28)
Proof.
𝖵𝖺𝗋[Tr{XY}]2𝖵𝖺𝗋[Tr{(X𝖤[X])Y}]+2𝖵𝖺𝗋[Tr{𝖤[X]Y}]2𝖤[Tr{(X𝖤[X])Y}2]+2𝖤[X]F2𝖵𝖺𝗋[Y]2𝖤[Y]F2𝖵𝖺𝗋[X]+2𝖤[X]F2𝖵𝖺𝗋[Y]\displaystyle\begin{aligned} \mathsf{Var}\left[Tr\left\{XY\right\}\right]&\leq 2\mathsf{Var}[Tr\{(X-\mathsf{E}[X])Y\}]+2\mathsf{Var}[Tr\{\mathsf{E}[X]Y\}]\\ &\leq 2\mathsf{E}[Tr\{(X-\mathsf{E}[X])Y\}^{2}]+2\left\lVert\mathsf{E}[X]\right\rVert_{F}^{2}\mathsf{Var}[Y]\\ &\leq 2\left\lVert\mathsf{E}[Y]\right\rVert_{F}^{2}\mathsf{Var}[X]+2\left\lVert\mathsf{E}[X]\right\rVert_{F}^{2}\mathsf{Var}[Y]\end{aligned} (29)

Lemma B.8.

The local stochastic gradient with pp-th entry defined as

vjt,e[p]=pTr{R^j+({Xv}v,θjt,e)}+qjTr{R^j({Xv}v,θjt,e)pR^j({Xv}v,θjt,e)}+(1qj)Tr{pR^j({Xv}v,θjt,e))R~jt}.\displaystyle\begin{aligned} v_{j}^{t,e}[p]&=-\partial_{p}Tr\left\{\hat{R}_{j}^{+}(\{X_{v}\}_{v},\theta^{t,e}_{j})\right\}+q_{j}Tr\left\{\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})\partial_{p}\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})\right\}\\ &+(1-q_{j})Tr\left\{\partial_{p}\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j}))\tilde{R}^{t}_{-j}\right\}.\end{aligned} (30)

has bounded norm

𝖤[vjt,e2|t,0]12A12A02(H2σ2+1)+12A12A06.\displaystyle\begin{aligned} \mathsf{E}\!\left[\left\lVert v_{j}^{t,e}\right\rVert^{2}|\mathcal{F}^{t,0}\right]\leq 12A_{1}^{2}A_{0}^{2}(H^{2}\sigma^{2}+1)+12A_{1}^{2}A_{0}^{6}.\end{aligned} (31)

where t,0\mathcal{F}^{t,0} is the history before the tt-th round; σ2\sigma^{2} is the variance of the DP noise and HH is the dimension of the representation z(x,θ)z(x,\theta).

Proof.
p(vjt,e[p])22pTr{pR^j+({Xv}v,θjt,e)}2+2pTr{pR^j({Xv}v,θjt,e))(qjR^j({Xv}v,θjt,e))+(1qj)R~jt)}2.\displaystyle\begin{aligned} \sum_{p}(v_{j}^{t,e}[p])^{2}&\leq 2\sum_{p}Tr\left\{\partial_{p}\hat{R}_{j}^{+}(\{X_{v}\}_{v},\theta^{t,e}_{j})\right\}^{2}\\ &+2\sum_{p}Tr\left\{\partial_{p}\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j}))\left(q_{j}\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j}))+(1-q_{j})\tilde{R}^{t}_{-j}\right)\right\}^{2}.\end{aligned} (32)

For the first term,

Tr{R^j+({Xv}v,θjt,e)}=1BVb=1Bv=1VzT(xb,v,θjt,e)z(xb,v+V,θjt,e).\displaystyle Tr\left\{\hat{R}_{j}^{+}(\{X_{v}\}_{v},\theta^{t,e}_{j})\right\}=\frac{1}{BV}\sum_{b=1}^{B}\sum_{v=1}^{V}z^{T}(x_{b,v},\theta^{t,e}_{j})z(x_{b,v+V},\theta^{t,e}_{j}). (33)

Then we have

\displaystyle Tr\left\{\partial_{p}\hat{R}_{j}^{+}(\{X_{v}\}_{v},\theta^{t,e}_{j})\right\}=\frac{1}{BV}\sum_{b=1}^{B}\sum_{v=1}^{V}\partial_{p}z(x_{b,v},\theta^{t,e}_{j})^{T}z(x_{b,v+V},\theta^{t,e}_{j})+z(x_{b,v},\theta^{t,e}_{j})^{T}\partial_{p}z(x_{b,v+V},\theta^{t,e}_{j}) (34)

and

pTr{pR^j+({Xv}v,θjt,e)}2p2BVb=1Bv=1Vpz(xb,v,θjt,e)2z(xb,v+V,θjt,e)2+z(xb,v,θjt,e)2pz(xb,v+V,θjt,e)24A12A02\displaystyle\begin{aligned} &\sum_{p}Tr\left\{\partial_{p}\hat{R}_{j}^{+}(\{X_{v}\}_{v},\theta^{t,e}_{j})\right\}^{2}\\ &\leq\sum_{p}\frac{2}{BV}\sum_{b=1}^{B}\sum_{v=1}^{V}\left\lVert\partial_{p}z(x_{b,v},\theta^{t,e}_{j})\right\rVert^{2}\left\lVert z(x_{b,v+V},\theta^{t,e}_{j})\right\rVert^{2}+\left\lVert z(x_{b,v},\theta^{t,e}_{j})\right\rVert^{2}\left\lVert\partial_{p}z(x_{b,v+V},\theta^{t,e}_{j})\right\rVert^{2}\\ &\leq 4A_{1}^{2}A_{0}^{2}\end{aligned} (35)

where the inequality uses Jensen's inequality and AM-GM. For the second term,

pTr{pR^j({Xv}v,θjt,e))(qjR^j({Xv}v,θjt,e))+(1qj)R~jt)}2qjR^j({Xv}v,θjt,e))+(1qj)R~jtF2ppR^j({Xv}v,θjt,e)F2qjR^j({Xv}v,θjt,e))+(1qj)R~jtF24A12A02\displaystyle\begin{aligned} &\sum_{p}Tr\left\{\partial_{p}\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j}))\left(q_{j}\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j}))+(1-q_{j})\tilde{R}^{t}_{-j}\right)\right\}^{2}\\ &\leq\left\lVert q_{j}\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j}))+(1-q_{j})\tilde{R}^{t}_{-j}\right\rVert_{F}^{2}\sum_{p}\left\lVert\partial_{p}\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})\right\rVert_{F}^{2}\\ &\leq\left\lVert q_{j}\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j}))+(1-q_{j})\tilde{R}^{t}_{-j}\right\rVert_{F}^{2}4A_{1}^{2}A_{0}^{2}\end{aligned} (36)

where the last inequality uses Lemma B.4. Combining the above results, we have

𝖤[vjt,e2|t,0]8A12A02+8A12A02𝖤[qjR^j({Xv}v,θjt,e))+(1qj)R~jtF2|t,0]=8A12A02+8A12A02(A04+jjqj2H2σ2)8A12A02+8A12A02(A04+H2σ2)\displaystyle\begin{aligned} \mathsf{E}\!\left[\left\lVert v_{j}^{t,e}\right\rVert^{2}|\mathcal{F}^{t,0}\right]&\leq 8A_{1}^{2}A_{0}^{2}+8A_{1}^{2}A_{0}^{2}\mathsf{E}\!\left[\left\lVert q_{j}\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j}))+(1-q_{j})\tilde{R}^{t}_{-j}\right\rVert_{F}^{2}|\mathcal{F}^{t,0}\right]\\ &=8A_{1}^{2}A_{0}^{2}+8A_{1}^{2}A_{0}^{2}\left(A_{0}^{4}+\sum_{j^{\prime}\neq j}q_{j^{\prime}}^{2}H^{2}\sigma^{2}\right)\\ &\leq 8A_{1}^{2}A_{0}^{2}+8A_{1}^{2}A_{0}^{2}\left(A_{0}^{4}+H^{2}\sigma^{2}\right)\end{aligned} (37)

where we use the fact that $q_{j}\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})+(1-q_{j})\tilde{R}^{t}_{-j}$ is essentially a correlation matrix plus DP noise with variance $\sum_{j^{\prime}\neq j}q_{j^{\prime}}^{2}H^{2}\sigma^{2}$; this implies the stated bound. ∎

B.4 Proof for the Full Participation Case

From the β\beta-smoothness of (θ)\mathcal{L}(\theta) given by Corollary B.6, we have

(θt,e+1)(θt,e)η(θt,e),vt,e+βη22vt,e2.\displaystyle\mathcal{L}(\theta^{t,e+1})\leq\mathcal{L}(\theta^{t,e})-\eta\langle\nabla\mathcal{L}(\theta^{t,e}),v^{t,e}\rangle+\frac{\beta\eta^{2}}{2}\left\lVert v^{t,e}\right\rVert^{2}. (38)

Denote the history of the optimization process as t,e\mathcal{F}^{t,e}, then we have

𝖤[(θt,e+1)|t,e](θt,e)η(θt,e),𝖤[vt,e|t,e]+βη22𝖤[vt,e2|t,e].\displaystyle\mathsf{E}\!\left[\mathcal{L}(\theta^{t,e+1})|\mathcal{F}^{t,e}\right]\leq\mathcal{L}(\theta^{t,e})-\eta\langle\nabla\mathcal{L}(\theta^{t,e}),\mathsf{E}\!\left[v^{t,e}|\mathcal{F}^{t,e}\right]\rangle+\frac{\beta\eta^{2}}{2}\mathsf{E}\!\left[\left\lVert v^{t,e}\right\rVert^{2}|\mathcal{F}^{t,e}\right]. (39)

Let v¯t,e=𝖤[vt,e|t,e]\bar{v}^{t,e}=\mathsf{E}\!\left[v^{t,e}|\mathcal{F}^{t,e}\right] and v¯jt,e=𝖤[vjt,e|t,e]\bar{v}_{j}^{t,e}=\mathsf{E}\!\left[v_{j}^{t,e}|\mathcal{F}^{t,e}\right], we have

𝖤[(θt,e+1)|t,e](θt,e)+βη22𝖤[vt,e2|t,e]+η2[(θt,e)2v¯t,e2+(θt,e)v¯t,e2].\displaystyle\begin{aligned} \mathsf{E}\!\left[\mathcal{L}(\theta^{t,e+1})|\mathcal{F}^{t,e}\right]&\leq\mathcal{L}(\theta^{t,e})+\frac{\beta\eta^{2}}{2}\mathsf{E}\!\left[\left\lVert v^{t,e}\right\rVert^{2}|\mathcal{F}^{t,e}\right]\\ &+\frac{\eta}{2}\left[-\left\lVert\nabla\mathcal{L}(\theta^{t,e})\right\rVert^{2}-\left\lVert\bar{v}^{t,e}\right\rVert^{2}+\left\lVert\nabla\mathcal{L}(\theta^{t,e})-\bar{v}^{t,e}\right\rVert^{2}\right].\end{aligned} (40)
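For clarity, the inner product in eq. (39) is expanded via the elementary identity

\displaystyle -\langle a,b\rangle=\frac{1}{2}\left(\left\lVert a-b\right\rVert^{2}-\left\lVert a\right\rVert^{2}-\left\lVert b\right\rVert^{2}\right),

applied with $a=\nabla\mathcal{L}(\theta^{t,e})$ and $b=\bar{v}^{t,e}$, which yields eq. (40).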

By the choice of η1/β\eta\leq 1/\beta, we have

𝖤[(θt,e+1)|t,e](θt,e)η2(θt,e)2+η2(θt,e)v¯t,e2T1+βη22𝖵𝖺𝗋[vt,e|t,e]T2\displaystyle\begin{aligned} \mathsf{E}\!\left[\mathcal{L}(\theta^{t,e+1})|\mathcal{F}^{t,e}\right]\leq\mathcal{L}(\theta^{t,e})-\frac{\eta}{2}\left\lVert\nabla\mathcal{L}(\theta^{t,e})\right\rVert^{2}+\frac{\eta}{2}\underbrace{\left\lVert\nabla\mathcal{L}(\theta^{t,e})-\bar{v}^{t,e}\right\rVert^{2}}_{T_{1}}+\frac{\beta\eta^{2}}{2}\underbrace{\mathsf{Var}\left[v^{t,e}|\mathcal{F}^{t,e}\right]}_{T_{2}}\end{aligned} (41)

B.4.1 Bounding the term $T_{1}$

Recall the definition of local batch loss

^jSC(θjt,e)=Tr{R^j+({Xv}v,θjt,e)}+12qjR^j({Xv}v,θjt,e)F2+(1qj)Tr{R^j({Xv}v,θjt,e)R~jt}\displaystyle\hat{\mathcal{L}}_{j}^{SC}(\theta^{t,e}_{j})=-Tr\{\hat{R}_{j}^{+}(\{X_{v}\}_{v},\theta^{t,e}_{j})\}+\frac{1}{2}q_{j}\left\lVert\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})\right\rVert_{F}^{2}+(1-q_{j})Tr\{\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})\tilde{R}^{t}_{-j}\} (42)

The pp-th entry of vjt,e=^jSC(θjt,e)v_{j}^{t,e}=\nabla\hat{\mathcal{L}}_{j}^{SC}(\theta^{t,e}_{j}) is

vjt,e[p]=pTr{R^j+({Xv}v,θjt,e)}+qjTr{R^j({Xv}v,θjt,e)pR^j({Xv}v,θjt,e)}+(1qj)Tr{pR^j({Xv}v,θjt,e))R~jt}.\displaystyle\begin{aligned} v_{j}^{t,e}[p]&=-\partial_{p}Tr\left\{\hat{R}_{j}^{+}(\{X_{v}\}_{v},\theta^{t,e}_{j})\right\}+q_{j}Tr\left\{\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})\partial_{p}\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})\right\}\\ &+(1-q_{j})Tr\left\{\partial_{p}\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j}))\tilde{R}^{t}_{-j}\right\}.\end{aligned} (43)

Taking the expectation over $\{X_{v}\}_{v}$, we have

v¯jt,e[p]=pTr{Rj+(θjt,e)}+qj𝖤{Xv}v[Tr{R^j({Xv}v,θjt,e)pR^j({Xv}v,θjt,e)}]+(1qj)Tr{pRj(θjt,e)R~jt}.\displaystyle\begin{aligned} \bar{v}_{j}^{t,e}[p]&=-\partial_{p}Tr\left\{R_{j}^{+}(\theta^{t,e}_{j})\right\}+q_{j}\mathsf{E}_{\{X_{v}\}_{v}}\left[Tr\left\{\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})\partial_{p}\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})\right\}\right]\\ &+(1-q_{j})Tr\left\{\partial_{p}R_{j}(\theta^{t,e}_{j})\tilde{R}^{t}_{-j}\right\}.\end{aligned} (44)

The pp-th entry of the global loss gradient is

p(θt,e)=jqjpTr{Rj+(θt,e)}+jqjTr{pRj(θt,e)jqjRj(θt,e)}=jqj(pTr{Rj+(θt,e)}+qjTr{pRj(θt,e)Rj(θt,e)}+Tr{pRj(θt,e)jjqjRj(θt,e)})\displaystyle\begin{aligned} \partial_{p}\mathcal{L}(\theta^{t,e})&=-\sum_{j}q_{j}\partial_{p}Tr\left\{R_{j}^{+}(\theta^{t,e})\right\}+\sum_{j}q_{j}Tr\left\{\partial_{p}R_{j}(\theta^{t,e})\sum_{j^{\prime}}q_{j^{\prime}}R_{j^{\prime}}(\theta^{t,e})\right\}\\ &=\sum_{j}q_{j}\left(-\partial_{p}Tr\left\{R_{j}^{+}(\theta^{t,e})\right\}+q_{j}Tr\left\{\partial_{p}R_{j}(\theta^{t,e})R_{j}(\theta^{t,e})\right\}+Tr\left\{\partial_{p}R_{j}(\theta^{t,e})\sum_{j^{\prime}\neq j}q_{j^{\prime}}R_{j^{\prime}}(\theta^{t,e})\right\}\right)\\ \end{aligned} (45)

Decompose $\bar{v}_{j}^{t,e}[p]=u^{t,e}_{j}(\theta^{t,e}_{j})[p]+q_{j}b_{j}^{t,e}[p]+c_{j}^{t,e}[p]$, where the terms are defined as follows.

ujt,e(θ)[p]=pTr{Rj+(θ)}+Tr{pRj(θ)jqjRj(θ)}bjt,e[p]=𝖤{Xv}v[Tr{R^j({Xv}v,θjt,e)pR^j({Xv}v,θjt,e)}]Tr{Rj(θjt,e)pRj(θjt,e)}cjt,e[p]=(1qj)Tr{pRj(θjt,e)R~jt}Tr{pRj(θjt,e)jjqjRj(θjt,e)}\displaystyle\begin{aligned} u^{t,e}_{j}(\theta)[p]&=-\partial_{p}Tr\left\{R_{j}^{+}(\theta)\right\}+Tr\left\{\partial_{p}R_{j}(\theta)\sum_{j^{\prime}}q_{j^{\prime}}R_{j^{\prime}}(\theta)\right\}\\ b_{j}^{t,e}[p]&=\mathsf{E}_{\{X_{v}\}_{v}}\left[Tr\left\{\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})\partial_{p}\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})\right\}\right]-Tr\left\{R_{j}(\theta_{j}^{t,e})\partial_{p}R_{j}(\theta_{j}^{t,e})\right\}\\ c_{j}^{t,e}[p]&=(1-q_{j})Tr\left\{\partial_{p}R_{j}(\theta_{j}^{t,e})\tilde{R}^{t}_{-j}\right\}-Tr\left\{\partial_{p}R_{j}(\theta_{j}^{t,e})\sum_{j^{\prime}\neq j}q_{j^{\prime}}R_{j^{\prime}}(\theta_{j}^{t,e})\right\}\end{aligned} (46)

Then we have

v¯t,e[p]p(θt,e)=jqj(ujt,e(θjt,e)[p]ujt,e(θt,e)[p]+qjbjt,e[p]+cjt,e[p]).\displaystyle\begin{aligned} \bar{v}^{t,e}[p]-\partial_{p}\mathcal{L}(\theta^{t,e})&=\sum_{j}q_{j}\left(u^{t,e}_{j}(\theta^{t,e}_{j})[p]-u^{t,e}_{j}(\theta^{t,e})[p]+q_{j}b_{j}^{t,e}[p]+c_{j}^{t,e}[p]\right).\end{aligned} (47)

The term T1T_{1} can be written as follows

T1=p(v¯t,e[p]p(θt,e))2=p(jqj(ujt,e(θjt,e)[p]ujt,e(θt,e)[p]+qjbjt,e[p]+cjt,e[p]))2jqjp(ujt,e(θjt,e)[p]ujt,e(θt,e)[p]+qjbjt,e[p]+cjt,e[p])23jqj[p(ujt,e(θjt,e)[p]ujt,e(θt,e)[p])2T3+qj2p(bjt,e[p])2T4+p(cjt,e[p])2T5]\displaystyle\begin{aligned} T_{1}&=\sum_{p}\left(\bar{v}^{t,e}[p]-\partial_{p}\mathcal{L}(\theta^{t,e})\right)^{2}\\ &=\sum_{p}\left(\sum_{j}q_{j}\left(u^{t,e}_{j}(\theta^{t,e}_{j})[p]-u^{t,e}_{j}(\theta^{t,e})[p]+q_{j}b_{j}^{t,e}[p]+c_{j}^{t,e}[p]\right)\right)^{2}\\ &\leq\sum_{j}q_{j}\sum_{p}\left(u^{t,e}_{j}(\theta^{t,e}_{j})[p]-u^{t,e}_{j}(\theta^{t,e})[p]+q_{j}b_{j}^{t,e}[p]+c_{j}^{t,e}[p]\right)^{2}\\ &\leq 3\sum_{j}q_{j}\left[\underbrace{\sum_{p}(u^{t,e}_{j}(\theta^{t,e}_{j})[p]-u^{t,e}_{j}(\theta^{t,e})[p])^{2}}_{T_{3}}+q^{2}_{j}\underbrace{\sum_{p}(b_{j}^{t,e}[p])^{2}}_{T_{4}}+\underbrace{\sum_{p}(c_{j}^{t,e}[p])^{2}}_{T_{5}}\right]\\ \end{aligned} (48)

where the third line uses Jensen's inequality and the last line uses AM-GM. By Lemma B.5, we have

T3β2θt,eθjt,e2.\displaystyle T_{3}\leq\beta^{2}\left\lVert\theta^{t,e}-\theta_{j}^{t,e}\right\rVert^{2}. (49)

For the term T4T_{4}, we have

T4=pTr{𝖤{Xv}v[R^j({Xv}v,θjt,e)Rj(θjt,e)][pR^j({Xv}v,θjt,e)pRj(θjt,e)]}2p𝖤{Xv}v[R^j({Xv}v,θjt,e)Rj(θjt,e)F2pR^j({Xv}v,θjt,e)pRj(θjt,e)F2]2𝖤{Xv}v[(R^j({Xv}v,θjt,e)F2+Rj(θjt,e)F2)ppR^j({Xv}v,θjt,e)pRj(θjt,e)F2]4A04𝖤{Xv}vppR^j({Xv}v,θjt,e)pRj(θjt,e)F2=4A04𝖤X¯𝒟j𝖤{Xv}v|X¯ppR^j({Xv}v,θjt,e)pR(X¯jt,e;θjt,e)+pR(X¯jt,e;θjt,e)pRj(θjt,e)F2\displaystyle\begin{aligned} T_{4}&=\sum_{p}Tr\left\{\mathsf{E}_{\{X_{v}\}_{v}}\left[\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})-R_{j}(\theta_{j}^{t,e})\right]\left[\partial_{p}\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})-\partial_{p}R_{j}(\theta_{j}^{t,e})\right]\right\}^{2}\\ &\leq\sum_{p}\mathsf{E}_{\{X_{v}\}_{v}}\left[\left\lVert\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})-R_{j}(\theta_{j}^{t,e})\right\rVert_{F}^{2}\left\lVert\partial_{p}\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})-\partial_{p}R_{j}(\theta_{j}^{t,e})\right\rVert_{F}^{2}\right]\\ &\leq 2\mathsf{E}_{\{X_{v}\}_{v}}\left[\left(\left\lVert\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})\right\rVert_{F}^{2}+\left\lVert R_{j}(\theta_{j}^{t,e})\right\rVert_{F}^{2}\right)\sum_{p}\left\lVert\partial_{p}\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})-\partial_{p}R_{j}(\theta_{j}^{t,e})\right\rVert_{F}^{2}\right]\\ &\leq 4A_{0}^{4}\mathsf{E}_{\{X_{v}\}_{v}}\sum_{p}\left\lVert\partial_{p}\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})-\partial_{p}R_{j}(\theta_{j}^{t,e})\right\rVert_{F}^{2}\\ &=4A_{0}^{4}\mathsf{E}_{\bar{X}\sim\mathcal{D}_{j}}\mathsf{E}_{\{X_{v}\}_{v}|\bar{X}}\sum_{p}\left\lVert\partial_{p}\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})-\partial_{p}R(\bar{X}_{j}^{t,e};\theta_{j}^{t,e})+\partial_{p}R(\bar{X}_{j}^{t,e};\theta_{j}^{t,e})-\partial_{p}R_{j}(\theta_{j}^{t,e})\right\rVert_{F}^{2}\end{aligned} (50)

where the third inequality uses Lemma B.4; $\bar{X}$ is a batch of samples drawn from $\mathcal{D}_{j}$, and $\{X_{v}\}_{v}$ are the augmented views of $\bar{X}$. Notice that

𝖤{Xv}v|X¯pR^j({Xv}v,θjt,e)pR(X¯jt,e;θjt,e),pR(X¯jt,e;θjt,e)pRj(θjt,e)=0\displaystyle\mathsf{E}_{\{X_{v}\}_{v}|\bar{X}}\langle\partial_{p}\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})-\partial_{p}R(\bar{X}_{j}^{t,e};\theta_{j}^{t,e}),\partial_{p}R(\bar{X}_{j}^{t,e};\theta_{j}^{t,e})-\partial_{p}R_{j}(\theta_{j}^{t,e})\rangle=0 (51)

we have

T44A04p𝖤X¯𝒟j𝖤{Xv}v|X¯(pR^j({Xv}v,θjt,e)pR(X¯jt,e;θjt,e)F2+pR(X¯jt,e;θjt,e)pRj(θjt,e)F2)=4A04p(𝖤X¯𝒟j𝖵𝖺𝗋{Xv}v|X¯[pR^j({Xv}v,θjt,e)]+𝖵𝖺𝗋X¯𝒟j[pR(X¯;θjt,e)])16A12A06(12V+|𝒟j|/B1|𝒟j|1)\displaystyle\begin{aligned} T_{4}&\leq 4A_{0}^{4}\sum_{p}\mathsf{E}_{\bar{X}\sim\mathcal{D}_{j}}\mathsf{E}_{\{X_{v}\}_{v}|\bar{X}}\biggl{(}\left\lVert\partial_{p}\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})-\partial_{p}R(\bar{X}_{j}^{t,e};\theta_{j}^{t,e})\right\rVert_{F}^{2}+\left\lVert\partial_{p}R(\bar{X}_{j}^{t,e};\theta_{j}^{t,e})-\partial_{p}R_{j}(\theta_{j}^{t,e})\right\rVert_{F}^{2}\biggr{)}\\ &=4A_{0}^{4}\sum_{p}\left(\mathsf{E}_{\bar{X}\sim\mathcal{D}_{j}}\mathsf{Var}_{\{X_{v}\}_{v}|\bar{X}}\left[\partial_{p}\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})\right]+\mathsf{Var}_{\bar{X}\sim\mathcal{D}_{j}}\left[\partial_{p}R(\bar{X};\theta_{j}^{t,e})\right]\right)\\ &\leq 16A_{1}^{2}A_{0}^{6}\left(\frac{1}{2V}+\frac{|\mathcal{D}_{j}|/B-1}{|\mathcal{D}_{j}|-1}\right)\end{aligned} (52)

where the last inequality uses Lemma B.4, the fact $\mathsf{Var}[X]\leq\mathsf{E}\left\lVert X\right\rVert^{2}$, and the variances of sampling with and without replacement.

For the term T5T_{5}, we have

T5=Tr{pRj(θjt,e)((1qj)R~jtjjqjRj(θjt,e))}2=pTr{pRj(θjt,e)jjqj(R~jtRj(θjt,e))}2ppRj(θjt,e)F2jjqj(R~jtRj(θjt,e))F24A12A02jjqj(R~jtRj(θjt,e))F24(1qj)A12A02jjqjR~jtRj(θjt,e)F2=4(1qj)A12A02jjqjR~jtRj(θt)+Rj(θt)Rj(θjt,e)F28(1qj)A12A02jjqj(R~jtRj(θt)F2+Rj(θt)Rj(θjt,e)F2)\displaystyle\begin{aligned} T_{5}&=Tr\left\{\partial_{p}R_{j}(\theta_{j}^{t,e})\left((1-q_{j})\tilde{R}^{t}_{-j}-\sum_{j^{\prime}\neq j}q_{j^{\prime}}R_{j^{\prime}}(\theta_{j}^{t,e})\right)\right\}^{2}\\ &=\sum_{p}Tr\left\{\partial_{p}R_{j}(\theta_{j}^{t,e})\sum_{j^{\prime}\neq j}q_{j^{\prime}}\left(\tilde{R}^{t}_{j^{\prime}}-R_{j^{\prime}}(\theta_{j}^{t,e})\right)\right\}^{2}\\ &\leq\sum_{p}\left\lVert\partial_{p}R_{j}(\theta_{j}^{t,e})\right\rVert_{F}^{2}\left\lVert\sum_{j^{\prime}\neq j}q_{j^{\prime}}\left(\tilde{R}^{t}_{j^{\prime}}-R_{j^{\prime}}(\theta_{j}^{t,e})\right)\right\rVert_{F}^{2}\\ &\leq 4A_{1}^{2}A_{0}^{2}\left\lVert\sum_{j^{\prime}\neq j}q_{j^{\prime}}\left(\tilde{R}^{t}_{j^{\prime}}-R_{j^{\prime}}(\theta_{j}^{t,e})\right)\right\rVert_{F}^{2}\\ &\leq 4(1-q_{j})A_{1}^{2}A_{0}^{2}\sum_{j^{\prime}\neq j}q_{j^{\prime}}\left\lVert\tilde{R}^{t}_{j^{\prime}}-R_{j^{\prime}}(\theta_{j}^{t,e})\right\rVert_{F}^{2}\\ &=4(1-q_{j})A_{1}^{2}A_{0}^{2}\sum_{j^{\prime}\neq j}q_{j^{\prime}}\left\lVert\tilde{R}^{t}_{j^{\prime}}-R_{j^{\prime}}(\theta^{t})+R_{j^{\prime}}(\theta^{t})-R_{j^{\prime}}(\theta_{j}^{t,e})\right\rVert_{F}^{2}\\ &\leq 8(1-q_{j})A_{1}^{2}A_{0}^{2}\sum_{j^{\prime}\neq j}q_{j^{\prime}}\left(\left\lVert\tilde{R}^{t}_{j^{\prime}}-R_{j^{\prime}}(\theta^{t})\right\rVert_{F}^{2}+\left\lVert R_{j^{\prime}}(\theta^{t})-R_{j^{\prime}}(\theta_{j}^{t,e})\right\rVert_{F}^{2}\right)\\ \end{aligned} (53)

where the second inequality uses Lemma B.4, and the third inequality uses Jensen’s inequality.

Using Lemma B.4 and the mean-value theorem, we have

Rj(θt)Rj(θjt,e)F24A12A02θtθjt,e22.\displaystyle\left\lVert R_{j^{\prime}}(\theta^{t})-R_{j^{\prime}}(\theta_{j}^{t,e})\right\rVert_{F}^{2}\leq 4A_{1}^{2}A_{0}^{2}\left\lVert\theta^{t}-\theta_{j}^{t,e}\right\rVert_{2}^{2}. (54)

Denoting $D_{j}^{t}=\frac{1}{1-q_{j}}\sum_{j^{\prime}\neq j}q_{j^{\prime}}\left\lVert\tilde{R}^{t}_{j^{\prime}}-R_{j^{\prime}}(\theta^{t})\right\rVert_{F}^{2}$, we have

T58(1qj)2A12A02(4A12A02θtθjt,e2+Djt)\displaystyle\begin{aligned} T_{5}&\leq 8(1-q_{j})^{2}A_{1}^{2}A_{0}^{2}\left(4A_{1}^{2}A_{0}^{2}\left\lVert\theta^{t}-\theta_{j}^{t,e}\right\rVert^{2}+D_{j}^{t}\right)\\ \end{aligned} (55)

Notice that

jqjθt,eθjt,e2=jqjθt,eθt+θtθjt,e2jqjθtθjt,e2\displaystyle\begin{aligned} \sum_{j}q_{j}\left\lVert\theta^{t,e}-\theta_{j}^{t,e}\right\rVert^{2}&=\sum_{j}q_{j}\left\lVert\theta^{t,e}-\theta^{t}+\theta^{t}-\theta_{j}^{t,e}\right\rVert^{2}\\ &\leq\sum_{j}q_{j}\left\lVert\theta^{t}-\theta_{j}^{t,e}\right\rVert^{2}\end{aligned} (56)
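This step follows from the variance decomposition around the weighted mean $\theta^{t,e}=\sum_{j}q_{j}\theta_{j}^{t,e}$:

\displaystyle \sum_{j}q_{j}\left\lVert\theta_{j}^{t,e}-\theta^{t,e}\right\rVert^{2}=\sum_{j}q_{j}\left\lVert\theta_{j}^{t,e}-\theta^{t}\right\rVert^{2}-\left\lVert\theta^{t,e}-\theta^{t}\right\rVert^{2}\leq\sum_{j}q_{j}\left\lVert\theta^{t}-\theta_{j}^{t,e}\right\rVert^{2}.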

Then, substituting eq. (49), eq. (52) and eq. (55) into eq. (48), we have

T13(β2jqjθt,eθjt,e2+16A12A06jqj3(12V+|𝒟j|/B1|𝒟j|1)+8A12A02jqj(1qj)2(4A12A02θtθjt,e2+Djt))3((β2+32A16A06)jqjθtθjt,e2+16A12A06(12V+maxj|𝒟j|/B1|𝒟j|1)+8A12A02jqjDjt)\displaystyle\begin{aligned} T_{1}&\leq 3\Biggl{(}\beta^{2}\sum_{j}q_{j}\left\lVert\theta^{t,e}-\theta_{j}^{t,e}\right\rVert^{2}+16A_{1}^{2}A_{0}^{6}\sum_{j}q_{j}^{3}\left(\frac{1}{2V}+\frac{|\mathcal{D}_{j}|/B-1}{|\mathcal{D}_{j}|-1}\right)\\ &+8A_{1}^{2}A_{0}^{2}\sum_{j}q_{j}(1-q_{j})^{2}\left(4A_{1}^{2}A_{0}^{2}\left\lVert\theta^{t}-\theta_{j}^{t,e}\right\rVert^{2}+D_{j}^{t}\right)\Biggr{)}\\ &\leq 3\Biggl{(}\left(\beta^{2}+32A_{1}^{6}A_{0}^{6}\right)\sum_{j}q_{j}\left\lVert\theta^{t}-\theta_{j}^{t,e}\right\rVert^{2}+16A_{1}^{2}A_{0}^{6}\left(\frac{1}{2V}+\max_{j}\frac{|\mathcal{D}_{j}|/B-1}{|\mathcal{D}_{j}|-1}\right)+8A_{1}^{2}A_{0}^{2}\sum_{j}q_{j}D_{j}^{t}\Biggr{)}\\ \end{aligned} (57)

B.4.2 Bounding the term $T_{2}$

T2=𝖵𝖺𝗋[vt,e|t,e]=𝖵𝖺𝗋[jqjvjt,e|t,e]=jqj2𝖵𝖺𝗋[vjt,e|t,e]\displaystyle\begin{aligned} T_{2}=\mathsf{Var}\left[v^{t,e}|\mathcal{F}^{t,e}\right]=\mathsf{Var}\left[\sum_{j}q_{j}v_{j}^{t,e}|\mathcal{F}^{t,e}\right]=\sum_{j}q_{j}^{2}\mathsf{Var}\left[v_{j}^{t,e}|\mathcal{F}^{t,e}\right]\end{aligned} (58)

Comparing eq. (43) and eq. (44), we have

vjt,e[p]v¯jt,e[p]=pTr{R^j+({Xv}v,θjt,e)}+pTr{Rj+(θjt,e)}+qjTr{R^j({Xv}v,θjt,e)pR^j({Xv}v,θjt,e)}qj𝖤{Xv}v[Tr{R^j({Xv}v,θjt,e)pR^j({Xv}v,θjt,e)}]+(1qj)Tr{pR^j({Xv}v,θjt,e))R~jt}(1qj)Tr{pRj(θjt,e)R~jt}\displaystyle\begin{aligned} v_{j}^{t,e}[p]-\bar{v}_{j}^{t,e}[p]&=-\partial_{p}Tr\left\{\hat{R}_{j}^{+}(\{X_{v}\}_{v},\theta^{t,e}_{j})\right\}+\partial_{p}Tr\left\{R_{j}^{+}(\theta^{t,e}_{j})\right\}\\ &+q_{j}Tr\left\{\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})\partial_{p}\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})\right\}-q_{j}\mathsf{E}_{\{X_{v}\}_{v}}\left[Tr\left\{\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})\partial_{p}\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})\right\}\right]\\ &+(1-q_{j})Tr\left\{\partial_{p}\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j}))\tilde{R}^{t}_{-j}\right\}-(1-q_{j})Tr\left\{\partial_{p}R_{j}(\theta^{t,e}_{j})\tilde{R}^{t}_{-j}\right\}\end{aligned} (59)
𝖵𝖺𝗋[vjt,e|t,e]=𝖤{Xv}v[vjt,ev¯jt,e2]3𝖤{Xv}vp(pTr{R^j+({Xv}v,θjt,e)}pTr{Rj+(θjt,e)})2T6+3qj2𝖤{Xv}vp(Tr{R^j({Xv}v,θjt,e)pR^j({Xv}v,θjt,e)}𝖤{Xv}v[Tr{R^j({Xv}v,θjt,e)pR^j({Xv}v,θjt,e)}])2T7+3(1qj)2𝖤{Xv}vp(Tr{pR^j({Xv}v,θjt,e))R~jt}Tr{pRj(θjt,e)R~jt})2T8\displaystyle\begin{aligned} &\mathsf{Var}\left[v_{j}^{t,e}|\mathcal{F}^{t,e}\right]=\mathsf{E}_{\{X_{v}\}_{v}}\left[\left\lVert v_{j}^{t,e}-\bar{v}_{j}^{t,e}\right\rVert^{2}\right]\\ &\leq 3\underbrace{\mathsf{E}_{\{X_{v}\}_{v}}\sum_{p}\left(\partial_{p}Tr\left\{\hat{R}_{j}^{+}(\{X_{v}\}_{v},\theta^{t,e}_{j})\right\}-\partial_{p}Tr\left\{R_{j}^{+}(\theta^{t,e}_{j})\right\}\right)^{2}}_{T_{6}}\\ &+3q_{j}^{2}\underbrace{\mathsf{E}_{\{X_{v}\}_{v}}\sum_{p}\left(Tr\left\{\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})\partial_{p}\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})\right\}-\mathsf{E}_{\{X_{v}\}_{v}}\left[Tr\left\{\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})\partial_{p}\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})\right\}\right]\right)^{2}}_{T_{7}}\\ &+3(1-q_{j})^{2}\underbrace{\mathsf{E}_{\{X_{v}\}_{v}}\sum_{p}\left(Tr\left\{\partial_{p}\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j}))\tilde{R}^{t}_{-j}\right\}-Tr\left\{\partial_{p}R_{j}(\theta^{t,e}_{j})\tilde{R}^{t}_{-j}\right\}\right)^{2}}_{T_{8}}\end{aligned} (60)

For the term $T_{6}$, we have

T6𝖤X¯𝒟j𝖤{Xv}v|X¯p(pTr{R^j+({Xv}v,θjt,e)}pTr{R+(X¯,θjt,e)}+pTr{R+(X¯,θjt,e)}pTr{Rj+(θjt,e)})2=𝖤X¯𝒟j𝖤{Xv}v|X¯p(pTr{R^j+({Xv}v,θjt,e)}pTr{R+(X¯,θjt,e)})2+𝖤X¯𝒟j(pTr{R+(X¯,θjt,e)}pTr{Rj+(θjt,e)})2=p(𝖤X¯𝒟j𝖵𝖺𝗋{Xv}v|X¯[pTr{R^j+({Xv}v,θjt,e)}]+𝖵𝖺𝗋X¯𝒟j[pTr{R+(X¯,θjt,e)}])4A12A02(1V+|𝒟j|/B1|𝒟j|1)\displaystyle\begin{aligned} T_{6}&\leq\mathsf{E}_{\bar{X}\sim\mathcal{D}_{j}}\mathsf{E}_{\{X_{v}\}_{v}|\bar{X}}\sum_{p}\Bigl{(}\partial_{p}Tr\left\{\hat{R}_{j}^{+}(\{X_{v}\}_{v},\theta^{t,e}_{j})\right\}-\partial_{p}Tr\left\{R^{+}(\bar{X},\theta_{j}^{t,e})\right\}\\ &+\partial_{p}Tr\left\{R^{+}(\bar{X},\theta_{j}^{t,e})\right\}-\partial_{p}Tr\left\{R_{j}^{+}(\theta_{j}^{t,e})\right\}\Bigr{)}^{2}\\ &=\mathsf{E}_{\bar{X}\sim\mathcal{D}_{j}}\mathsf{E}_{\{X_{v}\}_{v}|\bar{X}}\sum_{p}\left(\partial_{p}Tr\left\{\hat{R}_{j}^{+}(\{X_{v}\}_{v},\theta^{t,e}_{j})\right\}-\partial_{p}Tr\left\{R^{+}(\bar{X},\theta_{j}^{t,e})\right\}\right)^{2}\\ &+\mathsf{E}_{\bar{X}\sim\mathcal{D}_{j}}\left(\partial_{p}Tr\left\{R^{+}(\bar{X},\theta_{j}^{t,e})\right\}-\partial_{p}Tr\left\{R_{j}^{+}(\theta_{j}^{t,e})\right\}\right)^{2}\\ &=\sum_{p}\left(\mathsf{E}_{\bar{X}\sim\mathcal{D}_{j}}\mathsf{Var}_{\{X_{v}\}_{v}|\bar{X}}\left[\partial_{p}Tr\left\{\hat{R}_{j}^{+}(\{X_{v}\}_{v},\theta^{t,e}_{j})\right\}\right]+\mathsf{Var}_{\bar{X}\sim\mathcal{D}_{j}}\left[\partial_{p}Tr\left\{R^{+}(\bar{X},\theta_{j}^{t,e})\right\}\right]\right)\\ &\leq 4A_{1}^{2}A_{0}^{2}\left(\frac{1}{V}+\frac{|\mathcal{D}_{j}|/B-1}{|\mathcal{D}_{j}|-1}\right)\end{aligned} (61)

where the first inequality inserts the intermediate term $\partial_{p}Tr\{R^{+}(\bar{X},\theta_{j}^{t,e})\}$ and the cross term vanishes; the last inequality uses $\mathsf{Var}[X]\leq\mathsf{E}\left\lVert X\right\rVert^{2}$, Lemma B.4, and the variance under sampling with and without replacement.

For the term T7T_{7}, we have

T7=p𝖵𝖺𝗋{Xv}v[Tr{R^j({Xv}v,θjt,e)pR^j({Xv}v,θjt,e)}]2Rj(θjt,e)F2p𝖵𝖺𝗋{Xv}v[pR^j({Xv}v,θjt,e)]+2ppRj(θjt,e)F2𝖵𝖺𝗋{Xv}v[R^j({Xv}v,θjt,e)]2A04p𝖵𝖺𝗋{Xv}v[pR^j({Xv}v,θjt,e)]+8A12A02𝖵𝖺𝗋{Xv}v[R^j({Xv}v,θjt,e)]\displaystyle\begin{aligned} T_{7}&=\sum_{p}\mathsf{Var}_{\{X_{v}\}_{v}}\left[Tr\left\{\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})\partial_{p}\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})\right\}\right]\\ &\leq 2\left\lVert R_{j}(\theta_{j}^{t,e})\right\rVert^{2}_{F}\sum_{p}\mathsf{Var}_{\{X_{v}\}_{v}}\left[\partial_{p}\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})\right]+2\sum_{p}\left\lVert\partial_{p}R_{j}(\theta_{j}^{t,e})\right\rVert_{F}^{2}\mathsf{Var}_{\{X_{v}\}_{v}}\left[\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})\right]\\ &\leq 2A_{0}^{4}\sum_{p}\mathsf{Var}_{\{X_{v}\}_{v}}\left[\partial_{p}\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})\right]+8A_{1}^{2}A_{0}^{2}\mathsf{Var}_{\{X_{v}\}_{v}}\left[\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})\right]\end{aligned} (62)

Notice that

p𝖵𝖺𝗋{Xv}v[pR^j({Xv}v,θjt,e)]=p𝖤X¯𝒟j𝖵𝖺𝗋{Xv}v|X¯[pR^j({Xv}v,θjt,e)]+p𝖵𝖺𝗋X¯𝒟j[pR(X¯;θjt,e)]4A12A02(12V+|𝒟j|/B1|𝒟j|1)\displaystyle\begin{aligned} &\sum_{p}\mathsf{Var}_{\{X_{v}\}_{v}}\left[\partial_{p}\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})\right]\\ &=\sum_{p}\mathsf{E}_{\bar{X}\sim\mathcal{D}_{j}}\mathsf{Var}_{\{X_{v}\}_{v}|\bar{X}}\left[\partial_{p}\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})\right]+\sum_{p}\mathsf{Var}_{\bar{X}\sim\mathcal{D}_{j}}\left[\partial_{p}R(\bar{X};\theta_{j}^{t,e})\right]\\ &\leq 4A_{1}^{2}A_{0}^{2}\left(\frac{1}{2V}+\frac{|\mathcal{D}_{j}|/B-1}{|\mathcal{D}_{j}|-1}\right)\end{aligned} (63)

where the second line uses $\mathsf{Var}[Y]=\mathsf{E}_{X}\mathsf{Var}[Y|X]+\mathsf{Var}[\mathsf{E}[Y|X]]$; the third line uses $\mathsf{Var}[X]\leq\mathsf{E}\left\lVert X\right\rVert_{F}^{2}$, Lemma B.4, and the variances of sampling with and without replacement. Similarly, we have

𝖵𝖺𝗋{Xv}v[R^j({Xv}v,θjt,e)]=𝖤X¯𝒟j𝖵𝖺𝗋{Xv}v|X¯[R^j({Xv}v,θjt,e)]+𝖵𝖺𝗋X¯𝒟j[R(X¯;θjt,e)]A04(12V+|𝒟j|/B1|𝒟j|1)\displaystyle\begin{aligned} &\mathsf{Var}_{\{X_{v}\}_{v}}\left[\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})\right]\\ &=\mathsf{E}_{\bar{X}\sim\mathcal{D}_{j}}\mathsf{Var}_{\{X_{v}\}_{v}|\bar{X}}\left[\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})\right]+\mathsf{Var}_{\bar{X}\sim\mathcal{D}_{j}}\left[R(\bar{X};\theta_{j}^{t,e})\right]\\ &\leq A_{0}^{4}\left(\frac{1}{2V}+\frac{|\mathcal{D}_{j}|/B-1}{|\mathcal{D}_{j}|-1}\right)\end{aligned} (64)

Plugging eq. (63) and eq. (64) into eq. (62), we have

T78A12A06(12V+|𝒟j|/B1|𝒟j|1)\displaystyle T_{7}\leq 8A_{1}^{2}A_{0}^{6}\left(\frac{1}{2V}+\frac{|\mathcal{D}_{j}|/B-1}{|\mathcal{D}_{j}|-1}\right) (65)

For the term T8T_{8}, using eq. (63) we have

T8p𝖵𝖺𝗋{Xv}v[pR^j({Xv}v,θjt,e)]R~jtF24A12A02(12V+|𝒟j|/B1|𝒟j|1)R~jtF2\displaystyle\begin{aligned} T_{8}&\leq\sum_{p}\mathsf{Var}_{\{X_{v}\}_{v}}\left[\partial_{p}\hat{R}_{j}(\{X_{v}\}_{v},\theta^{t,e}_{j})\right]\left\lVert\tilde{R}^{t}_{-j}\right\rVert^{2}_{F}\\ &\leq 4A_{1}^{2}A_{0}^{2}\left(\frac{1}{2V}+\frac{|\mathcal{D}_{j}|/B-1}{|\mathcal{D}_{j}|-1}\right)\left\lVert\tilde{R}^{t}_{-j}\right\rVert^{2}_{F}\end{aligned} (66)

Substituting eq. (60), eq. (61), eq. (65) and eq. (66) into eq. (58), we have

T212A12A02jqj2(1V+|𝒟j|/B1|𝒟j|1)+24A12A06jqj4(1V+|𝒟j|/B1|𝒟j|1)+12A12A02jqj2(1qj)2(1V+|𝒟j|/B1|𝒟j|1)R~jtF2(1V+maxj|𝒟j|/B1|𝒟j|1)(12A12A02+24A12A06+12A12A02jqj2(1qj)2R~jtF2)\displaystyle\begin{aligned} T_{2}&\leq 12A_{1}^{2}A_{0}^{2}\sum_{j}q_{j}^{2}\left(\frac{1}{V}+\frac{|\mathcal{D}_{j}|/B-1}{|\mathcal{D}_{j}|-1}\right)+24A_{1}^{2}A_{0}^{6}\sum_{j}q_{j}^{4}\left(\frac{1}{V}+\frac{|\mathcal{D}_{j}|/B-1}{|\mathcal{D}_{j}|-1}\right)\\ &+12A_{1}^{2}A_{0}^{2}\sum_{j}q_{j}^{2}(1-q_{j})^{2}\left(\frac{1}{V}+\frac{|\mathcal{D}_{j}|/B-1}{|\mathcal{D}_{j}|-1}\right)\left\lVert\tilde{R}^{t}_{-j}\right\rVert^{2}_{F}\\ &\leq\left(\frac{1}{V}+\max_{j}\frac{|\mathcal{D}_{j}|/B-1}{|\mathcal{D}_{j}|-1}\right)\left(12A_{1}^{2}A_{0}^{2}+24A_{1}^{2}A_{0}^{6}+12A_{1}^{2}A_{0}^{2}\sum_{j}q^{2}_{j}(1-q_{j})^{2}\left\lVert\tilde{R}^{t}_{-j}\right\rVert^{2}_{F}\right)\end{aligned} (67)

B.4.3 Combining the results

Taking the expectation on both sides of eq. (41) conditioned on $\mathcal{F}^{t,0}$, we have

\displaystyle\begin{aligned}\mathsf{E}\!\left[\mathcal{L}(\theta^{t,e+1})|\mathcal{F}^{t,0}\right]\leq\mathsf{E}\!\left[\mathcal{L}(\theta^{t,e})|\mathcal{F}^{t,0}\right]-\frac{\eta}{2}\mathsf{E}\!\left[\left\lVert\nabla\mathcal{L}(\theta^{t,e})\right\rVert^{2}|\mathcal{F}^{t,0}\right]+\frac{\eta}{2}\mathsf{E}\!\left[T_{1}|\mathcal{F}^{t,0}\right]+\frac{\beta\eta^{2}}{2}\mathsf{E}\!\left[T_{2}|\mathcal{F}^{t,0}\right].\end{aligned} (68)

Notice that

𝖤[θtθjt,e2|t,0]=𝖤[e=0e1ηvjt,e22|t,0]Ee=0e1η2𝖤[vjt,e22|t,0]E2η2(8A12A02+8A12A02(A04+H2σ2))\displaystyle\begin{aligned} \mathsf{E}\!\left[\left\lVert\theta^{t}-\theta_{j}^{t,e}\right\rVert^{2}|\mathcal{F}^{t,0}\right]&=\mathsf{E}\!\left[\left\lVert\sum_{e^{\prime}=0}^{e-1}\eta v_{j}^{t,e^{\prime}}\right\rVert^{2}_{2}|\mathcal{F}^{t,0}\right]\\ &\leq E\sum_{e^{\prime}=0}^{e-1}\eta^{2}\mathsf{E}\!\left[\left\lVert v_{j}^{t,e}\right\rVert_{2}^{2}|\mathcal{F}^{t,0}\right]\\ &\leq E^{2}\eta^{2}(8A_{1}^{2}A_{0}^{2}+8A_{1}^{2}A_{0}^{2}(A_{0}^{4}+H^{2}\sigma^{2}))\end{aligned} (69)

where the last line uses Lemma B.8. Also notice that

𝖤[Djt|t,0]=11qjjjqj𝖤[R~jtRj(θt)F2|t,0]=11qjjjqj[𝖤[R~jtRˇj(θt)F2|t,0]+𝖤[RˇjRj(θt)F2|t,0]]=H2σ2+12VA04\displaystyle\begin{aligned} \mathsf{E}\!\left[D_{j}^{t}|\mathcal{F}^{t,0}\right]&=\frac{1}{1-q_{j}}\sum_{j^{\prime}\neq j}q_{j^{\prime}}\mathsf{E}\!\left[\left\lVert\tilde{R}^{t}_{j^{\prime}}-R_{j^{\prime}}(\theta^{t})\right\rVert_{F}^{2}|\mathcal{F}^{t,0}\right]\\ &=\frac{1}{1-q_{j}}\sum_{j^{\prime}\neq j}q_{j^{\prime}}\left[\mathsf{E}\!\left[\left\lVert\tilde{R}^{t}_{j^{\prime}}-\check{R}_{j^{\prime}}(\theta^{t})\right\rVert_{F}^{2}|\mathcal{F}^{t,0}\right]+\mathsf{E}\!\left[\left\lVert\check{R}_{j^{\prime}}-R_{j^{\prime}}(\theta^{t})\right\rVert_{F}^{2}|\mathcal{F}^{t,0}\right]\right]\\ &=H^{2}\sigma^{2}+\frac{1}{2V}A_{0}^{4}\end{aligned} (70)

where Rˇj\check{R}_{j^{\prime}} is the empirical correlation matrix before applying DP noise. Then we have

𝖤[T1|t,0]3((β2+32A16A06)E2η2(8A12A02+8A12A02(A04+H2σ2))+16A12A06(1V+maxj|𝒟j|/B1|𝒟j|1)+8A12A02(H2σ2+12VA04)).\displaystyle\begin{aligned} \mathsf{E}\!\left[T_{1}|\mathcal{F}^{t,0}\right]&\leq 3\Biggl{(}\left(\beta^{2}+32A_{1}^{6}A_{0}^{6}\right)E^{2}\eta^{2}\left(8A_{1}^{2}A_{0}^{2}+8A_{1}^{2}A_{0}^{2}\left(A_{0}^{4}+H^{2}\sigma^{2}\right)\right)\\ &+16A_{1}^{2}A_{0}^{6}\left(\frac{1}{V}+\max_{j}\frac{|\mathcal{D}_{j}|/B-1}{|\mathcal{D}_{j}|-1}\right)+8A_{1}^{2}A_{0}^{2}\left(H^{2}\sigma^{2}+\frac{1}{2V}A_{0}^{4}\right)\Biggr{)}.\end{aligned} (71)

For the term 𝖤[T2|t,0]\mathsf{E}\!\left[T_{2}|\mathcal{F}^{t,0}\right], we have

𝖤[R~jtF2|t,0]=𝖤[R¯jtF2|t,0]+𝖤[njtF2|t,0]=A04+1(1qj)2jjqj2H2σ2\displaystyle\begin{aligned} \mathsf{E}\!\left[\left\lVert\tilde{R}^{t}_{-j}\right\rVert^{2}_{F}|\mathcal{F}^{t,0}\right]&=\mathsf{E}\!\left[\left\lVert\bar{R}^{t}_{-j}\right\rVert^{2}_{F}|\mathcal{F}^{t,0}\right]+\mathsf{E}\!\left[\left\lVert n^{t}_{-j}\right\rVert^{2}_{F}|\mathcal{F}^{t,0}\right]\\ &=A_{0}^{4}+\frac{1}{(1-q_{j})^{2}}\sum_{j^{\prime}\neq j}q_{j}^{\prime 2}H^{2}\sigma^{2}\end{aligned} (72)
𝖤[T2|t,0](1V+maxj|𝒟j|/B1|𝒟j|1)(12A12A02+26A12A06+12A12A02H2σ2)\displaystyle\begin{aligned} \mathsf{E}\!\left[T_{2}|\mathcal{F}^{t,0}\right]\leq\left(\frac{1}{V}+\max_{j}\frac{|\mathcal{D}_{j}|/B-1}{|\mathcal{D}_{j}|-1}\right)\left(12A_{1}^{2}A_{0}^{2}+26A_{1}^{2}A_{0}^{6}+12A_{1}^{2}A_{0}^{2}H^{2}\sigma^{2}\right)\end{aligned} (73)

where we use the fact that $q_{j}(1-q_{j})^{2}\leq\frac{4}{27}$. Combining eq. (68), eq. (71) and eq. (73), we have

𝖤[(θt,e+1)|t,0]𝖤[(θt,e)|t,0]η2𝖤[(θt,e)2|t,0]+η32C1E2(H2σ2+C2)+η22C3(1V+maxj|𝒟j|/B1|𝒟j|1)(H2σ2+C4)+η2C5(1V+maxj|𝒟j|/B1|𝒟j|1)+η2C6(H2σ2+1VC7)\displaystyle\begin{aligned} &\mathsf{E}\!\left[\mathcal{L}(\theta^{t,e+1})|\mathcal{F}^{t,0}\right]-\mathsf{E}\!\left[\mathcal{L}(\theta^{t,e})|\mathcal{F}^{t,0}\right]\\ &\leq-\frac{\eta}{2}\mathsf{E}\!\left[\left\lVert\nabla\mathcal{L}(\theta^{t,e})\right\rVert^{2}|\mathcal{F}^{t,0}\right]+\frac{\eta^{3}}{2}C_{1}E^{2}(H^{2}\sigma^{2}+C_{2})+\frac{\eta^{2}}{2}C_{3}\left(\frac{1}{V}+\max_{j}\frac{|\mathcal{D}_{j}|/B-1}{|\mathcal{D}_{j}|-1}\right)\left(H^{2}\sigma^{2}+C_{4}\right)\\ &+\frac{\eta}{2}C_{5}\left(\frac{1}{V}+\max_{j}\frac{|\mathcal{D}_{j}|/B-1}{|\mathcal{D}_{j}|-1}\right)+\frac{\eta}{2}C_{6}\left(H^{2}\sigma^{2}+\frac{1}{V}C_{7}\right)\end{aligned} (74)

where $C_{1},C_{2},\ldots,C_{7}$ are constants depending on $A_{0}$, $A_{1}$ and $A_{2}$. Telescoping eq. (74) and taking the expectation, we have

1TEt=0T1e=0E1𝖤[(θt,e)2]2ηTE(𝖤[(θt,0)]𝖤[(θt,E)])+η2C1E2(H2σ2+C2)+ηC3(1V+maxj|𝒟j|/B1|𝒟j|1)(H2σ2+C4)+C5(1V+maxj|𝒟j|/B1|𝒟j|1)+C6(H2σ2+1VC7).\displaystyle\begin{aligned} &\frac{1}{TE}\sum_{t=0}^{T-1}\sum_{e=0}^{E-1}\mathsf{E}\!\left[\left\lVert\nabla\mathcal{L}(\theta^{t,e})\right\rVert^{2}\right]\\ &\leq\frac{2}{\eta TE}\left(\mathsf{E}\!\left[\mathcal{L}(\theta^{t,0})\right]-\mathsf{E}\!\left[\mathcal{L}(\theta^{t,E})\right]\right)+\eta^{2}C_{1}E^{2}(H^{2}\sigma^{2}+C_{2})+\eta C_{3}\left(\frac{1}{V}+\max_{j}\frac{|\mathcal{D}_{j}|/B-1}{|\mathcal{D}_{j}|-1}\right)\left(H^{2}\sigma^{2}+C_{4}\right)\\ &+C_{5}\left(\frac{1}{V}+\max_{j}\frac{|\mathcal{D}_{j}|/B-1}{|\mathcal{D}_{j}|-1}\right)+C_{6}\left(H^{2}\sigma^{2}+\frac{1}{V}C_{7}\right).\end{aligned} (75)

Recalling the choice $\eta=\mathcal{O}\left(\frac{1}{\sqrt{TE}}\right)$, we have

1TEt=0T1e=0E1𝖤[(θt,e)2]𝒪(E2(H2σ2+C2)TE+(θ0)+(1V+maxj|𝒟j|/B1|𝒟j|1)(H2σ2+C4)TE+1V+maxj|𝒟j|/B1|𝒟j|1+H2σ2)\displaystyle\begin{aligned} &\frac{1}{TE}\sum_{t=0}^{T-1}\sum_{e=0}^{E-1}\mathsf{E}\!\left[\left\lVert\nabla\mathcal{L}(\theta^{t,e})\right\rVert^{2}\right]\\ &\leq\mathcal{O}\left(\frac{E^{2}(H^{2}\sigma^{2}+C_{2})}{TE}+\frac{\mathcal{L}(\theta^{0})+\left(\frac{1}{V}+\max_{j}\frac{|\mathcal{D}_{j}|/B-1}{|\mathcal{D}_{j}|-1}\right)\left(H^{2}\sigma^{2}+C_{4}\right)}{\sqrt{TE}}+\frac{1}{V}+\max_{j}\frac{|\mathcal{D}_{j}|/B-1}{|\mathcal{D}_{j}|-1}+H^{2}\sigma^{2}\right)\end{aligned} (76)

B.5 Partial Participation Case

Partial participation introduces an additional perturbation into the aggregation step:

𝖤[(θ𝒥tt,E)2(θt,E)2|𝒥t]=2𝖤𝒥t[T(θt,E+λh)2(θt,E+λh,𝒟)h]2Wβ𝖤𝒥t[θt,Eθ𝒥tt,E]2Wβ𝖤𝒥t[θt,Eθ𝒥tt,E22]2ηβEW2J/|𝒥t|1J1\displaystyle\begin{aligned} &\mathsf{E}\left[\left\lVert\nabla\mathcal{L}(\theta^{t,E}_{\mathcal{J}^{t}})\right\rVert^{2}-\left\lVert\nabla\mathcal{L}(\theta^{t,E})\right\rVert^{2}|\mathcal{J}_{t}\right]\\ &=2\mathsf{E}_{\mathcal{J}_{t}}\left[\nabla\mathcal{L}^{T}(\theta^{t,E}+\lambda h)\nabla^{2}\mathcal{L}(\theta^{t,E}+\lambda h,\mathcal{D})h\right]\\ &\leq 2W\beta\mathsf{E}_{\mathcal{J}_{t}}\left[\left\lVert\theta^{t,E}-\theta^{t,E}_{\mathcal{J}^{t}}\right\rVert\right]\\ &\leq 2W\beta\sqrt{\mathsf{E}_{\mathcal{J}^{t}}\left[\left\lVert\theta^{t,E}-\theta^{t,E}_{\mathcal{J}^{t}}\right\rVert_{2}^{2}\right]}\\ &\leq 2\eta\beta EW^{2}\sqrt{\frac{J/|\mathcal{J}^{t}|-1}{J-1}}\end{aligned} (77)

where one can easily verify from Lemma B.8 that $W=8A_{1}^{2}A_{0}^{2}+8A_{1}^{2}A_{0}^{6}$ serves as a bound on the gradient norm, and the second line uses the mean-value theorem. Another aspect is that $\tilde{R}_{j}$ is updated less frequently on the server. Therefore, the term $\mathsf{E}\!\left[D_{j}^{t}|\mathcal{F}^{t,0}\right]$ should involve an additional term accounting for the aging of the correlation matrices,

𝖤[Djt|t,0]H2σ2+12VA04+C8η2.\displaystyle\begin{aligned} \mathsf{E}\!\left[D_{j}^{t}|\mathcal{F}^{t,0}\right]\leq H^{2}\sigma^{2}+\frac{1}{2V}A_{0}^{4}+C_{8}\eta^{2}.\end{aligned} (78)

The reason is that the difference between the current and the old correlation matrices is proportional to the distance between the current and the old model weights (shown in eq. (54)), which is in turn proportional to $E^{2}\eta^{2}$ (shown in eq. (69)). Thus, we finally have

1TEt=0T1e=0E1𝖤[(θt,e)2]𝒪(E2(H2σ2+C2)TE+(θ0)+(1V+maxj|𝒟j|/B1|𝒟j|1)(H2σ2+C4)+EJ/|𝒥t|1J1TE+1V+maxj|𝒟j|/B1|𝒟j|1+H2σ2)\displaystyle\begin{aligned} \frac{1}{TE}\sum_{t=0}^{T-1}\sum_{e=0}^{E-1}\mathsf{E}\!\left[\left\lVert\nabla\mathcal{L}(\theta^{t,e})\right\rVert^{2}\right]&\leq\mathcal{O}\Biggl{(}\frac{E^{2}(H^{2}\sigma^{2}+C_{2})}{TE}\\ &+\frac{\mathcal{L}(\theta^{0})+\left(\frac{1}{V}+\max_{j}\frac{|\mathcal{D}_{j}|/B-1}{|\mathcal{D}_{j}|-1}\right)\left(H^{2}\sigma^{2}+C_{4}\right)+E\sqrt{\frac{J/|\mathcal{J}^{t}|-1}{J-1}}}{\sqrt{TE}}\\ &+\frac{1}{V}+\max_{j}\frac{|\mathcal{D}_{j}|/B-1}{|\mathcal{D}_{j}|-1}+H^{2}\sigma^{2}\Biggr{)}\end{aligned} (79)
Table 6: Implementation details of FedSC with DP protection: full client participation
Dataset $(\epsilon,\delta)$ | $\mu$ | $\sigma$ | round indices | local dataset size
SVHN $(\epsilon=3,\delta=10^{-2})$ | 2 | 0.0034 | $t>100$ | 10,000
SVHN $(\epsilon=6,\delta=10^{-2})$ | 2 | 0.0018 | $t>100$ | 10,000
SVHN $(\epsilon=3,\delta=10^{-4})$ | 2 | 0.0048 | $t>100$ | 10,000
SVHN $(\epsilon=8,\delta=10^{-4})$ | 2 | 0.0018 | $t>100$ | 10,000
CIFAR10 $(\epsilon=3,\delta=10^{-2})$ | 4 | 0.01 | $t>150$ | 5,000
CIFAR10 $(\epsilon=6,\delta=10^{-2})$ | 4 | 0.0052 | $t>100,\ t\%2=0$ | 5,000
CIFAR10 $(\epsilon=3,\delta=10^{-4})$ | 4 | 0.012 | $t>150$ | 5,000
CIFAR10 $(\epsilon=8,\delta=10^{-4})$ | 4 | 0.0051 | $t>100,\ t\%2=0$ | 5,000
CIFAR100 $(\epsilon=6,\delta=10^{-2})$ | 5 | 0.013 | $t>100,\ t\%2=0$ | 2,500
CIFAR100 $(\epsilon=12,\delta=10^{-2})$ | 5 | 0.0075 | $t>100,\ t\%2=0$ | 2,500
CIFAR100 $(\epsilon=3,\delta=10^{-4})$ | 5 | 0.04 | $t>100,\ t\%2=0$ | 2,500
CIFAR100 $(\epsilon=8,\delta=10^{-4})$ | 5 | 0.013 | $t>100,\ t\%2=0$ | 2,500

Appendix C Detailed Implementation of FedSC

Recall the local objective

jSC(θ;R¯j)=Tr{Rj+(θ)}+12αjRj(θ)F2+(1αj)Tr{Rj(θ)R¯j}\displaystyle\begin{aligned} \mathcal{L}^{SC}_{j}(\theta;\bar{R}_{-j})&=-Tr\{R^{+}_{j}(\theta)\}+\frac{1}{2}\alpha_{j}\left\lVert R_{j}(\theta)\right\rVert_{F}^{2}+(1-\alpha_{j})Tr\{R_{j}(\theta)\bar{R}_{-j}\}\end{aligned} (80)

Here we replace $q_{j}$ with a general coefficient $\alpha_{j}$ and decay it linearly from $1$ to $0.2$ over the communication rounds. The motivation is as follows. At the beginning of training, the update directions given by the global objective and by the weighted average of the local objectives tend to align closely. Moreover, the correlation matrices of the clients are not yet stable at this stage, which makes the inter-client contrast less critical early on. Therefore, we choose a large $\alpha_{j}$ for a quicker start. Conversely, the correlation matrices converge and become stable near the end of training, so we give the inter-client contrast a larger weight, i.e., a smaller $\alpha_{j}$.
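As a concrete illustration, the following is a minimal PyTorch sketch of the local objective in eq. (80) together with the linear decay of $\alpha_{j}$ described above. The function names, the two-views-per-sample batching, and the treatment of both correlation matrices as symmetric are our own illustrative assumptions, not the authors' released implementation.

```python
import torch

def alpha_schedule(t, T, start=1.0, end=0.2):
    """Linear decay of alpha_j from `start` to `end` over T communication rounds."""
    frac = min(max(t / max(T - 1, 1), 0.0), 1.0)
    return start + (end - start) * frac

def local_fedsc_loss(z1, z2, R_minus_j, alpha_j):
    """Local objective of eq. (80) for one batch, with two augmented views per sample.

    z1, z2    : (B, H) representations of two augmented views of the same B samples
    R_minus_j : (H, H) aggregated (noisy) correlation matrix received from the server
    alpha_j   : trade-off coefficient replacing q_j
    """
    # Attraction term: empirical Tr{R^+_j} over positive pairs
    pos = (z1 * z2).sum(dim=1).mean()
    # Empirical correlation matrix R_j over all views in the batch
    z = torch.cat([z1, z2], dim=0)                      # (2B, H)
    R_j = z.T @ z / z.shape[0]
    # Intra-client repulsion: (alpha_j / 2) * ||R_j||_F^2
    intra = 0.5 * alpha_j * (R_j ** 2).sum()
    # Inter-client contrast: (1 - alpha_j) * Tr{R_j R_{-j}} (both matrices treated as symmetric)
    inter = (1.0 - alpha_j) * (R_j * R_minus_j).sum()
    return -pos + intra + inter
```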

We also make modifications when DP protection is applied. Based on the above analysis, we start sharing the correlation matrices only in the middle or late stages of training to save the privacy budget. The detailed settings are listed in Table 6. For partial client participation, we only adjust $\sigma$ according to the participation ratio.
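Below is a minimal sketch of the gated, DP-protected correlation sharing described above, assuming entry-wise Gaussian noise with the $\sigma$ values of Table 6 and round-index gating such as $t>100,\ t\%2=0$. The function name and signature are ours; clipping and the exact privacy accounting follow the main text and are omitted here.

```python
import numpy as np

def share_correlation(R_j, t, t_start, every, sigma, rng):
    """Gated, DP-protected sharing of client j's local correlation matrix (sketch).

    R_j     : (H, H) local correlation matrix
    t_start : communication round after which sharing begins (e.g. 100 or 150 in Table 6)
    every   : share only every `every`-th round (e.g. 2 for the "t % 2 = 0" entries)
    sigma   : standard deviation of the entry-wise Gaussian DP noise
    """
    if t <= t_start or t % every != 0:
        return None                                   # no sharing in this round
    noise = rng.normal(scale=sigma, size=R_j.shape)   # Gaussian mechanism
    return R_j + noise                                # noisy matrix sent to the server

# Illustrative call mimicking the CIFAR10 (eps=6, delta=1e-2) row of Table 6:
# share_correlation(R_j, t, t_start=100, every=2, sigma=0.0052, rng=np.random.default_rng(0))
```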