
Bayesian Deep Learning Via Expectation Maximization and Turbo Deep Approximate Message Passing

Wei Xu, An Liu, Senior Member, IEEE, Yiting Zhang and Vincent Lau, Fellow, IEEE Wei Xu, An Liu and Yiting Zhang are with the College of Information Science and Electronic Engineering, Zhejiang University, Hangzhou 310027, China (email: anliu@zju.edu.cn). Vincent Lau is with the Department of ECE, The Hong Kong University of Science and Technology (email: eeknlau@ust.hk).
Abstract

Efficient learning and model compression algorithms for deep neural networks (DNNs) are key workhorses behind the rise of deep learning (DL). In this work, we propose a message passing based Bayesian deep learning algorithm called EM-TDAMP to avoid the drawbacks of traditional stochastic gradient descent (SGD) based learning algorithms and regularization-based model compression methods. Specifically, we formulate the problem of DNN learning and compression as a sparse Bayesian inference problem, in which a group sparse prior is employed to achieve structured model compression. Then, we propose an expectation maximization (EM) framework to estimate posterior distributions for parameters (E-step) and update hyperparameters (M-step), where the E-step is realized by a newly proposed turbo deep approximate message passing (TDAMP) algorithm. We further extend EM-TDAMP and propose a novel Bayesian federated learning framework, in which the clients perform TDAMP to efficiently compute the local posterior distributions based on the local data, and the central server first aggregates the local posterior distributions to update the global posterior distributions and then updates the hyperparameters based on EM to accelerate convergence. We detail the application of EM-TDAMP to Boston housing price prediction and handwriting recognition, and present extensive numerical results to demonstrate the advantages of EM-TDAMP.

Index Terms:
Bayesian deep learning, DNN model compression, expectation maximization, turbo deep approximate message passing, Bayesian federated learning

I Introduction

Deep learning (DL) has become increasingly important in various artificial intelligence (AI) applications. In DL, a deep neural network (DNN), which is a type of neural network modeled as a multilayer perceptron (MLP), is trained with algorithms to learn representations from data sets without any manual design of feature extractors. It is well known that the training algorithm is one of the pillars behind the success of DL. Traditional deep learning methods first construct a loss function (e.g., mean square error (MSE) or cross-entropy) and then iteratively update the parameters through back propagation (BP) and stochastic gradient descent (SGD). Furthermore, to mitigate the computational load of DNN inference for large models, researchers have proposed several model compression techniques. Early regularization methods generate networks with random sparse connectivity, which still requires storing and operating on high-dimensional matrices. Group sparse regularization has been introduced to eliminate redundant neurons, features and filters [13, 14]. A recent work has addressed neuron-wise, feature-wise and filter-wise groupings within a single sparse regularization term [15].

However, the traditional deep learning and model compression methods have several drawbacks. For example, for regularization-based pruning methods, it is difficult to achieve an exact compression ratio after training. Another drawback is their tendency to be overconfident in their predictions, which can be problematic in applications such as autonomous driving and medical diagnostics [38, 47, 33], where silent failure can lead to dramatic outcomes. To overcome these problems, Bayesian deep learning has been proposed, allowing for uncertainty quantification [19]. Bayesian deep learning formulates DNN training as a Bayesian inference problem, where the DNN parameters with a prior distribution serve as hypotheses, and the training set \boldsymbol{D} consists of features \boldsymbol{D}_{x} and labels \boldsymbol{D}_{y}. Calculating the exact Bayesian posterior distribution p\left(\boldsymbol{\theta}|\boldsymbol{D}\right) for a DNN is extremely challenging, and a widely used method is variational Bayesian inference (VBI), where a variational distribution q_{\varphi}\left(\boldsymbol{\theta}\right) with parameters \varphi is proposed to approximate the exact posterior p\left(\boldsymbol{\theta}|\boldsymbol{D}\right) [36]. However, most VBI algorithms still rely on SGD to optimize the variational distribution parameters \varphi, where the loss function is often defined as the Kullback-Leibler divergence between q_{\varphi}\left(\boldsymbol{\theta}\right) and p\left(\boldsymbol{\theta}|\boldsymbol{D}\right) [36, 42].

The abovementioned training methods are all based on SGD, and thus have several limitations, including vanishing and exploding gradients [9, 10], the risk of getting stuck in suboptimal solutions [11], and slow convergence. Although there have been attempts to address these issues through advanced optimization techniques such as Adam [7], the overall convergence remains slow when high training accuracy is required, which limits the application scenarios of SGD-based DL. To avoid the drawbacks of SGD, [27] utilizes message-passing algorithms, e.g., belief propagation (BP), BP-inspired (BPI) message passing, mean-field (MF) message passing, and approximate message passing (AMP), for training. The experiments show that these message-passing based algorithms have similar performance and are slightly better than the SGD-based baseline in some cases. However, the existing message-passing algorithms in [27] cannot achieve efficient model compression and may have numerical stability issues.

In recent years, federated learning has become a major scenario for deep learning applications with the growing computation power of edge devices. Federated learning (FL) is a machine learning paradigm in which the clients train models with decentralized data while the central server handles aggregation and scheduling. Modern federated learning methods typically perform the following three steps iteratively [37].

1. Broadcast: The central server sends current model parameters and a training program to clients.

2. Local training: Each client locally computes an update to the model by executing the training program, which might for example run SGD on the local data.

3. Aggregation: The server aggregates the local results and updates the global model using, e.g., federated averaging (FedAvg) [8] or its variations [39, 35, 48], as sketched after this list.
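To make step 3 concrete, the following is a minimal sketch of FedAvg-style aggregation. It is not code from [8]; the function and variable names are illustrative, and each client's parameters are assumed to be flattened into a single vector.

```python
import numpy as np

def fedavg_aggregate(client_params, client_sizes):
    """Weighted average of client parameter vectors (FedAvg, step 3).

    client_params: list of K flattened parameter vectors, one per client.
    client_sizes:  number of local samples I_k held by each client.
    """
    weights = np.asarray(client_sizes, dtype=float)
    weights /= weights.sum()                       # I_k / I
    stacked = np.stack(client_params, axis=0)      # K x P
    return np.tensordot(weights, stacked, axes=1)  # sum_k (I_k/I) * theta_k

# toy usage: three clients with different data volumes
theta_global = fedavg_aggregate(
    [np.ones(4), 2 * np.ones(4), 4 * np.ones(4)], client_sizes=[10, 20, 10]
)
```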

However, most existing federated learning algorithms inherit the abovementioned drawbacks because the local training still relies on traditional deep learning methods.

To overcome the drawbacks of existing deep learning, model compression and federated learning methods, we propose a novel message passing based Bayesian deep learning algorithm called Expectation Maximization and Turbo Deep Approximate Message Passing (EM-TDAMP). The main contributions are summarized as follows.

  • We propose a novel Bayesian deep learning algorithm, EM-TDAMP, to enable efficient learning and structured compression for DNNs: Firstly, we formulate the DNN learning problem as Bayesian inference of the DNN parameters. Then we propose a group sparse prior distribution to achieve efficient neuron-level pruning during training. We further incorporate zero-mean Gaussian noise in the likelihood function to control the learning rate through the noise variance. The proposed Bayesian deep learning algorithm EM-TDAMP is based on the expectation maximization (EM) framework, where the E-step estimates the posterior distribution and the M-step adaptively updates the hyperparameters. In the E-step, we cannot directly apply the standard sum-product rule due to the many loops in the DNN factor graph and the high computational complexity. Although various approximate message passing methods have been proposed to reduce the complexity of message passing in the compressed sensing literature [51, 52], to the best of our knowledge, there is no efficient message passing algorithm available for training a DNN with both multiple layers and structured sparse parameters. Therefore, we propose a new TDAMP algorithm to realize the E-step, which iterates between two modules: Module B performs message passing over the group sparse prior distribution, and Module A performs deep approximate message passing (DAMP) over the DNN using the independent prior distribution from Module B. The proposed EM-TDAMP overcomes the aforementioned drawbacks of SGD-based training algorithms, showing faster convergence and superior inference performance in simulations. It also improves on the AMP-based training methods in [27] in several aspects: we introduce a group sparse prior and utilize the turbo framework to enable structured model compression; we add zero-mean Gaussian noise at the output and construct a soft likelihood function to ensure numerical stability; and we update the prior parameters and noise variance via EM to accelerate convergence.

  • We propose a Bayesian federated learning framework by extending the EM-TDAMP algorithm to federated learning scenarios: The proposed framework also contains the abovementioned three steps (Broadcast, Local training, Aggregation). In step 1 (Broadcast), the central server broadcasts the hyperparameters in the prior distribution and likelihood function to the clients. In step 2 (Local training), each client performs TDAMP to compute the local posterior distribution. In step 3 (Aggregation), the central server aggregates the local posterior parameters and updates the hyperparameters via EM. Compared to the conventional FedAvg [8], the proposed Bayesian federated learning framework achieves more structured sparsity and reduces the number of communication rounds, as shown in the simulations.

The rest of the paper is organized as follows. Section II presents the problem formulation for Bayesian deep learning with structured model compression. Section III derives the EM-TDAMP algorithm and discusses various implementation issues. Section IV extends the proposed EM-TDAMP to federated learning scenarios. Section V details the application of EM-TDAMP to Boston housing price prediction and handwriting recognition. Finally, the conclusion is given in Section VI.

II Problem Formulation for Bayesian Deep Learning

II-A DNN Model and Standard Training Procedure

A general DNN consists of one input layer, multiple hidden layers, and one output layer. In this paper, we focus on feedforward DNNs for easy illustration. Let \boldsymbol{z}_{L}=\phi\left(\boldsymbol{u}_{0};\boldsymbol{\theta}\right) be a DNN with L layers that maps the input vector \boldsymbol{u}_{0}=\boldsymbol{x}\in\mathbb{R}^{N_{0}} to the output vector \boldsymbol{z}_{L}\in\mathbb{R}^{N_{L}} with a set of parameters \boldsymbol{\theta}. The input and output of each layer, denoted as \boldsymbol{u}_{l-1}\in\mathbb{R}^{N_{l-1}} and \boldsymbol{z}_{l}\in\mathbb{R}^{N_{l}} respectively, can be expressed as follows:

\boldsymbol{z}_{l}=\boldsymbol{W}_{l}\boldsymbol{u}_{l-1}+\boldsymbol{b}_{l},\quad l=1,\ldots,L,
\boldsymbol{u}_{l}=\zeta_{l}\left(\boldsymbol{z}_{l}\right),\quad l=1,\ldots,L-1,

where \boldsymbol{W}_{l}\in\mathbb{R}^{N_{l}\times N_{l-1}}, \boldsymbol{b}_{l}\in\mathbb{R}^{N_{l}} and \zeta_{l}\left(\cdot\right) denote the weight matrix, the bias vector and the activation function in layer l, respectively. As is common practice, we set \zeta_{l}\left(\cdot\right) to be the rectified linear unit (ReLU) defined as:

\zeta_{l}\left(z\right)=\begin{cases}z&z>0\\ 0&z\leq 0\end{cases}. (1)

For a classification model, the output \boldsymbol{z}_{L} is converted into a predicted class u_{L} from the set of possible labels/classes \left\{1,\ldots,N_{L}\right\} using the argmax layer:

u_{L}=\zeta_{L}\left(\boldsymbol{z}_{L}\right)=\mathop{\arg\max}\limits_{m}z_{L,m},

where z_{L,m} represents the output related to the m-th label. However, the derivative of the argmax activation function is discontinuous, which may lead to numerical instability. As a result, it is usually replaced with softmax when using SGD-based algorithms to train the DNN. In the proposed framework, to facilitate message passing algorithm design, we add zero-mean Gaussian noise to \boldsymbol{z}_{L}, which will be further discussed in Subsection II-B2.

The set of parameters \boldsymbol{\theta} is defined as \boldsymbol{\theta}\triangleq\left\{\boldsymbol{W}_{l},\boldsymbol{b}_{l}|l=1,\ldots,L\right\}. In practice, the DNN parameters \boldsymbol{\theta} are usually obtained through a deep learning/training algorithm, which is the process of regressing the parameters \boldsymbol{\theta} on some training data \boldsymbol{D}\triangleq\left\{\left(\boldsymbol{x}^{i},\boldsymbol{y}^{i}\right)|i=1,\ldots,I\right\}, usually a series of inputs \boldsymbol{D}_{x}\triangleq\left\{\boldsymbol{x}^{i}|i=1,\ldots,I\right\} and their corresponding labels \boldsymbol{D}_{y}\triangleq\left\{\boldsymbol{y}^{i}|i=1,\ldots,I\right\}. The standard approach is to minimize a loss function L\left(\boldsymbol{\theta}\right) to find a point estimate of \boldsymbol{\theta} using SGD-based algorithms. In regression models, the loss function is often defined as the mean square error (MSE) on the training set as in (2), sometimes with a regularization term that penalizes certain parametrizations or compresses the DNN model, as in (3) where an l_{1}-norm regularization function is chosen to prune the DNN weights. It is also possible to use more sophisticated sparse regularization functions to remove redundant neurons, features and filters [13, 14]. However, the standard training procedure above has several drawbacks as discussed in the introduction. Therefore, in this paper, we propose a Bayesian learning formulation to overcome those drawbacks.

L_{MSE}\left(\boldsymbol{\theta},\boldsymbol{D}\right)=\sum_{\left\{\boldsymbol{x}^{i},\boldsymbol{y}^{i}\right\}\in\boldsymbol{D}}\left\|\boldsymbol{y}^{i}-\phi\left(\boldsymbol{x}^{i};\boldsymbol{\theta}\right)\right\|^{2}. (2)
L_{MSE,l_{1}}\left(\boldsymbol{\theta},\boldsymbol{D}\right)=L_{MSE}\left(\boldsymbol{\theta},\boldsymbol{D}\right)+\lambda\left\|\boldsymbol{\theta}\right\|_{1}. (3)
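To make the model and losses above concrete, the following is a minimal NumPy sketch of the forward map \phi and the losses (2) and (3); all shapes and names are illustrative, not from the paper's implementation.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def forward(x, weights, biases):
    """Feedforward DNN z_L = phi(x; theta): z_l = W_l u_{l-1} + b_l with
    ReLU activations on the hidden layers and no activation at layer L."""
    u = x
    for l, (W, b) in enumerate(zip(weights, biases)):
        z = W @ u + b
        u = relu(z) if l < len(weights) - 1 else z
    return z

def loss_mse(theta, data):
    """MSE loss (2); data is an iterable of (x, y) pairs."""
    weights, biases = theta
    return sum(np.sum((y - forward(x, weights, biases)) ** 2) for x, y in data)

def loss_mse_l1(theta, data, lam):
    """Loss (3): MSE plus an l1 penalty on all weights and biases."""
    weights, biases = theta
    l1 = sum(np.abs(W).sum() for W in weights) + sum(np.abs(b).sum() for b in biases)
    return loss_mse(theta, data) + lam * l1
```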

II-B Problem Formulation for Bayesian Deep Learning with Structured Model Compression

In the proposed Bayesian deep learning algorithm, the parameters \boldsymbol{\theta} are treated as random variables. The goal of the proposed framework is to obtain the Bayesian posterior distribution p\left(\boldsymbol{\theta}|\boldsymbol{D}\right), which can be used to predict the output distribution (i.e., both a point estimate and the uncertainty of the output) on test data through forward propagation similar to that in the training process. The joint posterior distribution p\left(\boldsymbol{\theta},\boldsymbol{z}_{L}|\boldsymbol{D}\right) can be factorized as in (4):

p\left(\boldsymbol{\theta},\boldsymbol{z}_{L}|\boldsymbol{D}\right) \propto p\left(\boldsymbol{\theta},\boldsymbol{z}_{L},\boldsymbol{D}_{y}|\boldsymbol{D}_{x}\right)
=p\left(\boldsymbol{\theta}\right)p\left(\boldsymbol{z}_{L}|\boldsymbol{D}_{x},\boldsymbol{\theta}\right)p\left(\boldsymbol{D}_{y}|\boldsymbol{z}_{L}\right). (4)

The prior distribution p\left(\boldsymbol{\theta}\right) is set to be group sparse to achieve model compression, as will be detailed in Subsection II-B1. The likelihood function p\left(\boldsymbol{D}_{y}|\boldsymbol{z}_{L}\right) is chosen as Gaussian/Probit-product to prevent numerical instability, as will be detailed in Subsection II-B2.

II-B1 Group Sparse Prior Distribution for DNN Parameters

Different applications often have varying requirements regarding the structure of DNN parameters. In the following, we shall introduce a group sparse prior distribution to capture structured sparsity that may arise in practical scenarios. Specifically, the joint prior distribution p\left(\boldsymbol{\theta}\right) is given by

p\left(\boldsymbol{\theta}\right)=\prod_{i=1}^{Q}\left(\rho_{i}\prod_{j\in\mathcal{N}_{i}}g_{j}\left(\theta_{j}\right)+\left(1-\rho_{i}\right)\prod_{j\in\mathcal{N}_{i}}\delta\left(\theta_{j}\right)\right), (5)

where Q represents the number of groups, \rho_{i} represents the active probability of the i-th group, \mathcal{N}_{i} represents the set of indexes of \theta in the i-th group, and g_{j}\left(\theta_{j}\right) represents the probability density function (PDF) of \theta_{j},j\in\mathcal{N}_{i} when active, which is chosen in this paper as a Gaussian distribution with expectation \mu_{j} and variance v_{j}, denoted as N\left(\theta_{j};\mu_{j},v_{j}\right). Here we shall focus on the following group sparse prior distribution to enable structured model compression.

Independent Sparse Prior for Bias Pruning

To impose a simple sparse structure on the bias parameters for random dropout, we assume the elements b_{m},m=1,\ldots,Q_{b}\triangleq\sum_{l=1}^{L}N_{l} have independent prior distributions:

p\left(\boldsymbol{b}\right)=\prod_{m=1}^{Q_{b}}p\left(b_{m}\right),

where

p\left(b_{m}\right)=\rho_{m}^{b}N\left(b_{m};\mu_{m}^{b},v_{m}^{b}\right)+\left(1-\rho_{m}^{b}\right)\delta\left(b_{m}\right),

\rho_{m}^{b} represents the active probability, and \mu_{m}^{b} and v_{m}^{b} represent the expectation and variance when active.

Group Sparse Prior for Neuron Pruning

In most DNNs, a weight group is often defined as the outgoing weights of a neuron to promote neuron-level sparsity. Note that there are a total of \sum_{l=1}^{L}N_{l-1} input and hidden neurons in the DNN. In order to force all outgoing connections from a single neuron (corresponding to a group) to be either simultaneously zero or not, we divide the weight parameters into Q_{W}\triangleq\sum_{l=1}^{L}N_{l-1} groups, such that the i-th group for i=1,\ldots,Q_{W} corresponds to the weights associated with the i-th neuron. Specifically, for the i-th weight group \boldsymbol{W}_{i}, we denote the active probability as \rho_{i}^{W}, and the expectation and variance related to the n-th element W_{i,n},n\in\mathcal{N}_{i}^{W} as \mu_{i,n}^{W} and v_{i,n}^{W}. The joint prior distribution can be decomposed as:

p\left(\boldsymbol{W}\right)=\prod_{i=1}^{Q_{W}}p\left(\boldsymbol{W}_{i}\right),

where

p\left(\boldsymbol{W}_{i}\right)=\rho_{i}^{W}\prod_{n\in\mathcal{N}_{i}^{W}}N\left(W_{i,n};\mu_{i,n}^{W},v_{i,n}^{W}\right)+\left(1-\rho_{i}^{W}\right)\prod_{n\in\mathcal{N}_{i}^{W}}\delta\left(W_{i,n}\right).

Note that a parameter \theta_{j} corresponds to either a bias parameter b_{m} or a weight parameter W_{i,n}, and thus we have Q=Q_{b}+Q_{W}. For convenience, we define \boldsymbol{\psi} as the set consisting of \rho_{m}^{b},\mu_{m}^{b},v_{m}^{b} for m=1,\cdots,Q_{b} and \rho_{i}^{W},\mu_{i,n}^{W},v_{i,n}^{W} for i=1,\cdots,Q_{W},n\in\mathcal{N}_{i}, which will be updated to accelerate convergence as discussed later. Please refer to Fig. 1 for an illustration of group sparsity. It is also possible to design other sparse priors to achieve more structured model compression, such as the burst sparse prior, which is widely used in the literature on sparse channel estimation [45, 34]. Specifically, the burst sparse prior introduces a Markov distributed sparse support vector to drive the active neurons in each layer to concentrate on a few clusters [45, 34]. The detailed derivation with the burst sparse prior is omitted due to limited space.

Figure 1: Illustration of group sparsity, where we show the elements in the l-th layer. The gray elements are preserved, while the white elements are set to zero. In the figure, the 2nd and 6th input neurons are deactivated because the related weight columns are set to zero.
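As a sanity check on the priors above, the following is a minimal sketch of drawing one layer's parameters from the group sparse prior (columns of W_l active or zero as a group) and the independent bias prior; the hyperparameter values and function names are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_weight_matrix(N_out, N_in, rho, mu=0.0, v=1.0):
    """Draw W_l from the group sparse prior: each column (the outgoing
    weights of one input neuron) is jointly active w.p. rho, else all zero."""
    active = rng.random(N_in) < rho                      # one Bernoulli per group
    W = mu + np.sqrt(v) * rng.standard_normal((N_out, N_in))
    return W * active[None, :]                           # zero out inactive columns

def sample_bias(N, rho_b, mu_b=0.0, v_b=1.0):
    """Draw b_l from the independent sparse prior, element by element."""
    active = rng.random(N) < rho_b
    return (mu_b + np.sqrt(v_b) * rng.standard_normal(N)) * active

W1 = sample_weight_matrix(N_out=2, N_in=6, rho=0.5)  # cf. Fig. 1
b1 = sample_bias(N=2, rho_b=0.9)
```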

II-B2 Likelihood Function for the Last Layer

In the Bayesian inference problem, the observation can be represented as a likelihood function p\left(\boldsymbol{y}^{i}|\boldsymbol{z}_{L}^{i}\right) for i=1,\ldots,I, where we define \boldsymbol{z}_{L}^{i}=\phi\left(\boldsymbol{x}^{i};\boldsymbol{\theta}\right). Directly assuming p\left(\boldsymbol{y}^{i}|\boldsymbol{z}_{L}^{i}\right)=\delta\left(\boldsymbol{y}^{i}-\zeta_{L}\left(\boldsymbol{z}_{L}^{i}\right)\right) may lead to numerical instability. To avoid this problem, we add zero-mean Gaussian noise with variance v to the output \boldsymbol{z}_{L}^{i}. The noise variance v is treated as a hyperparameter that is adaptively updated to control the learning rate. In the following, we take the regression model and the classification model as examples to illustrate the modified likelihood function.

Gaussian Likelihood Function for Regression Model

For the regression model, after adding Gaussian noise at the output, the likelihood function becomes joint Gaussian:

p\left(\boldsymbol{y}^{i}|\boldsymbol{z}_{L}^{i}\right)=\prod_{m=1}^{N_{L}}N\left(y_{m}^{i};z_{L,mi},v\right), (6)

where v, y_{m}^{i} and z_{L,mi} represent the noise variance, the m-th element of \boldsymbol{y}^{i} and the m-th element of \boldsymbol{z}_{L}^{i}, respectively.

Probit-product Likelihood Function for Classification Model

For the classification model, we consider one-hot labels, where y^{i} refers to the label of the i-th training sample. Instead of directly using the argmax layer [27], to prevent messages from vanishing or exploding, we add Gaussian noise to z_{L,mi} for m=1,\ldots,N_{L} and obtain the following likelihood function, which is a product of probit functions as mentioned in [28]:

p\left(y^{i}|\boldsymbol{z}_{L}^{i}\right) \approx\sum_{u_{L}^{i}=1}^{N_{L}}\delta\left(y^{i}-u_{L}^{i}\right)\prod_{m\neq u_{L}^{i}}p\left(z_{L,mi}<z_{L,u_{L}^{i}i}\right)
=\prod_{m\neq y^{i}}Q\left(\frac{z_{L,mi}-z_{L,y^{i}i}}{\sqrt{v}}\right), (7)

where we approximate z_{L,mi}-z_{L,u_{L}^{i}i},i=1,\ldots,I,m\neq u_{L}^{i} as independent to simplify the message passing, as will be detailed in Appendix -A2. Extensive simulations verify that such an approximation can achieve good classification performance. Besides, we define Q\left(\cdot\right)=1-F\left(\cdot\right), where F\left(\cdot\right) represents the cumulative distribution function of the standard normal random variable.
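Under the independence approximation above, evaluating (7) is straightforward; the following is a minimal sketch (function names are illustrative), using SciPy's norm.sf for the tail probability Q.

```python
import numpy as np
from scipy.stats import norm

def probit_product_likelihood(z_L, y, v):
    """Evaluate (7): p(y | z_L) = prod_{m != y} Q((z_{L,m} - z_{L,y}) / sqrt(v)),
    where Q is the standard normal tail probability (norm.sf)."""
    diffs = np.delete(z_L - z_L[y], y)      # z_{L,m} - z_{L,y} for m != y
    return np.prod(norm.sf(diffs / np.sqrt(v)))

# toy check: a confidently correct output yields likelihood close to 1
z = np.array([5.0, -1.0, 0.0])
print(probit_product_likelihood(z, y=0, v=1.0))   # ~1
print(probit_product_likelihood(z, y=1, v=1.0))   # ~0
```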

III EM-TDAMP Algorithm Derivation

III-A Bayesian Deep Learning Framework Based on EM

To accelerate convergence, we update the hyperparameters in the prior distribution and the likelihood function based on the EM algorithm [29], where the expectation step (E-step) computes the posterior distribution (4) by performing turbo deep approximate message passing (TDAMP), as will be detailed in Subsection III-B, while the maximization step (M-step) updates the hyperparameters \boldsymbol{\psi} and v by maximizing the expectation in (8), taken w.r.t. the posterior distributions p\left(\boldsymbol{\theta}|\boldsymbol{D}\right) and p\left(\boldsymbol{z}_{L}|\boldsymbol{D}\right), as will be detailed in Subsection III-C.

\left\{\boldsymbol{\psi},v\right\} =\mathop{\arg\max}\limits_{\boldsymbol{\psi},v}E\left(\log p\left(\boldsymbol{\theta},\boldsymbol{z}_{L},\boldsymbol{D}\right)\right)
=\left\{\mathop{\arg\max}\limits_{\boldsymbol{\psi}}E\left(\log p\left(\boldsymbol{\theta}\right)\right),\mathop{\arg\max}\limits_{v}E\left(\log p\left(\boldsymbol{D}_{y}|\boldsymbol{z}_{L}\right)\right)\right\}. (8)

III-B E-step (TDAMP Algorithm)

To compute the expectation in (8), the E-step performs TDAMP to compute the global posterior distributions p\left(\boldsymbol{\theta}|\boldsymbol{D}\right) and p\left(\boldsymbol{z}_{L}|\boldsymbol{D}\right) with the prior distribution p\left(\boldsymbol{\theta}\right) and the likelihood function p\left(\boldsymbol{D}|\boldsymbol{\theta}\right). In order to accelerate convergence for large datasets \boldsymbol{D}, we divide \boldsymbol{D} into R minibatches, and for r=1,\ldots,R, we define \boldsymbol{D}^{r}\triangleq\left\{\left(\boldsymbol{x}^{i},\boldsymbol{y}^{i}\right)|i\in\mathcal{I}_{r}\right\} with \cup_{r=1}^{R}\mathcal{I}_{r}=\mathcal{I}. In the following, we first elaborate the TDAMP algorithm to compute the posterior distributions for each minibatch \boldsymbol{D}^{r}. Then we present the PasP rule to update the prior distribution p\left(\boldsymbol{\theta}\right).

III-B1 Top-Level Factor Graph

The joint PDF associated with minibatch \boldsymbol{D}^{r} can be factorized as follows:

p\left(\boldsymbol{\theta},\left\{\boldsymbol{u}_{l-1}^{r},\boldsymbol{z}_{l}^{r}|l=1,\ldots,L\right\},\boldsymbol{D}_{y}^{r}|\boldsymbol{D}_{x}^{r}\right)
=p\left(\boldsymbol{u}_{0}^{r}|\boldsymbol{D}_{x}^{r}\right)\times\prod_{l=1}^{L}\left(p\left(\boldsymbol{\theta}_{l}\right)p\left(\boldsymbol{z}_{l}^{r}|\boldsymbol{\theta}_{l},\boldsymbol{u}_{l-1}^{r}\right)\right)\times\prod_{l=1}^{L-1}p\left(\boldsymbol{u}_{l}^{r}|\boldsymbol{z}_{l}^{r}\right)p\left(\boldsymbol{D}_{y}^{r}|\boldsymbol{z}_{L}^{r}\right), (9)

where for l=1,\ldots,L, we denote \boldsymbol{z}_{l}^{r}=\left\{\boldsymbol{z}_{l}^{i}|i\in\mathcal{I}_{r}\right\}\in\mathbb{R}^{N_{l}\times|\mathcal{I}_{r}|} and \boldsymbol{u}_{l}^{r}=\left\{\boldsymbol{u}_{l}^{i}|i\in\mathcal{I}_{r}\right\}, and thus:

p\left(\boldsymbol{u}_{0}^{r}|\boldsymbol{D}_{x}^{r}\right)=\prod_{i\in\mathcal{I}_{r}}\delta\left(\boldsymbol{u}_{0}^{i}-\boldsymbol{x}^{i}\right),
p\left(\boldsymbol{D}_{y}^{r}|\boldsymbol{z}_{L}^{r}\right)=\prod_{i\in\mathcal{I}_{r}}p\left(\boldsymbol{y}^{i}|\boldsymbol{z}_{L}^{i}\right),
p\left(\boldsymbol{\theta}_{l}\right)=p\left(\boldsymbol{W}_{l}\right)p\left(\boldsymbol{b}_{l}\right),
p\left(\boldsymbol{z}_{l}^{r}|\boldsymbol{\theta}_{l},\boldsymbol{u}_{l-1}^{r}\right)=\prod_{i\in\mathcal{I}_{r}}\delta\left(\boldsymbol{z}_{l}^{i}-\boldsymbol{W}_{l}\boldsymbol{u}_{l-1}^{i}-\boldsymbol{b}_{l}\right),
p\left(\boldsymbol{u}_{l}^{r}|\boldsymbol{z}_{l}^{r}\right)=\prod_{i\in\mathcal{I}_{r}}\delta\left(\boldsymbol{u}_{l}^{i}-\zeta_{l}\left(\boldsymbol{z}_{l}^{i}\right)\right).

Based on (9), the detailed structure of \mathcal{G}_{r} is illustrated in Fig. 2, where the superscript/subscript r is omitted for conciseness because there is no ambiguity.

Figure 2: The structure of \mathcal{G}_{r} (r=1,\ldots,R). The specific expressions of the factor nodes are summarized in Table I.
Factor Distribution Functional form
h_{0} p\left(\boldsymbol{u}_{0}^{r}|\boldsymbol{x}^{r}\right) \prod_{i\in\mathcal{I}_{r}}\prod_{n=1}^{N_{0}}\delta\left(u_{0,ni}-x_{n}^{i}\right)
h_{l} p\left(\boldsymbol{u}_{l}^{r}|\boldsymbol{z}_{l}^{r}\right) \prod_{i\in\mathcal{I}_{r}}\prod_{m=1}^{N_{l}}\delta\left(u_{l,mi}-\zeta_{l}\left(z_{l,mi}\right)\right)
h_{L} p\left(\boldsymbol{y}^{r}|\boldsymbol{z}_{L}^{r}\right) \begin{cases}\prod_{i\in\mathcal{I}_{r}}\prod_{m=1}^{N_{L}}N\left(y_{m}^{i};z_{L,mi},v\right)&\text{Regression}\\ \prod_{i\in\mathcal{I}_{r}}\prod_{m\neq y^{i}}Q\left(\frac{z_{L,mi}-z_{L,y^{i}i}}{\sqrt{v}}\right)&\text{Classification}\end{cases}
h_{l}^{W} p\left(\boldsymbol{W}_{l}\right) \prod_{n=1}^{N_{l-1}}\left(\rho_{l,n}\prod_{m=1}^{N_{l}}N\left(W_{l,mn};\mu_{l,mn},v_{l,mn}\right)+\left(1-\rho_{l,n}\right)\prod_{m=1}^{N_{l}}\delta\left(W_{l,mn}\right)\right)
h_{l}^{b} p\left(\boldsymbol{b}_{l}\right) \prod_{m=1}^{N_{l}}\left(\rho_{l,m}^{b}N\left(b_{l,m};\mu_{l,m}^{b},v_{l,m}^{b}\right)+\left(1-\rho_{l,m}^{b}\right)\delta\left(b_{l,m}\right)\right)
f_{l} p\left(\boldsymbol{z}_{l}^{r}|\boldsymbol{W}_{l},\boldsymbol{u}_{l-1}^{r},\boldsymbol{b}_{l}\right) \prod_{i\in\mathcal{I}_{r}}\prod_{m=1}^{N_{l}}\delta\left(z_{l,mi}-\left(\sum_{n=1}^{N_{l-1}}W_{l,mn}u_{l-1,ni}+b_{l,m}\right)\right)
TABLE I: Factors, distributions and functional forms in Fig. 2.

Each iteration of the message passing procedure on the factor graph \mathcal{G}_{r} in Fig. 2 consists of a forward message passing from the first layer to the last layer, followed by a backward message passing from the last layer to the first layer. However, the standard sum-product rule is infeasible on the DNN factor graph due to its high complexity. We propose DAMP to reduce the complexity, as will be detailed in Subsection III-B3. DAMP requires the prior distribution to be independent, so we follow the turbo approach [25] to decouple the factor graph into Module A and Module B, which compute messages with an independent prior distribution and deal with the group sparse prior separately. Notice that the turbo framework we utilize is equivalent to EP [44, 43, 50] in most inference problems, as illustrated in [49]. However, in this paper, the two frameworks are not equivalent because EP needs to project the posterior distribution and extrinsic messages to be Gaussian, while we apply the standard sum-product rule in Module B without projection. As such, the turbo framework can achieve slightly better performance than EP for the problem considered in this paper.

III-B2 Turbo Framework to Deal with Group Sparse Prior

To achieve neuron-level pruning, each weight group is a column of a weight matrix, as discussed in Subsection II-B1. Specifically, we denote the n-th column of \boldsymbol{W}_{l} by \boldsymbol{W}_{l,n}, where l=1,\ldots,L,n=1,\ldots,N_{l-1}, and the corresponding factor graph is shown in Fig. 3. The TDAMP algorithm iterates between two modules A and B. Module A consists of the factor nodes f_{l}^{i},i\in\mathcal{I}_{r} that connect the weight parameters with the observation model, the weight parameters W_{l,mn},m=1,\ldots,N_{l}, and the factor nodes h_{l,mn},m=1,\ldots,N_{l} that represent the extrinsic messages from Module B, denoted as \triangle_{l,mn}^{B\rightarrow A}. Module B consists of the factor node h_{l,n} that represents the group sparse prior distribution, the parameters W_{l,mn},m=1,\ldots,N_{l}, and the factor nodes h_{l,mn},m=1,\ldots,N_{l} that represent the extrinsic messages from Module A, denoted as \triangle_{l,mn}^{A\rightarrow B}. Module A updates the messages by performing the DAMP algorithm with the observations and the independent prior distribution from Module B. Module B updates the independent prior distributions for Module A by performing the sum-product message passing (SPMP) algorithm over the group sparse prior. In the following, we elaborate Module A and Module B.

Figure 3: Turbo framework factor graph related to \boldsymbol{W}_{l,n}.

III-B3 DAMP in Module A

We compute the approximated marginal posterior distributions by performing DAMP. Based on the turbo approach, in Module A, for all l,m,n, the prior factor nodes for the weight matrices represent messages extracted from Module B:

h_{l,mn}^{W}\triangleq\triangle_{l,mn}^{B\rightarrow A}.

The factor graph for the l-th layer in \mathcal{G}_{r} is shown in Fig. 4, where u_{l-1,ni} and z_{l,mi} represent the n-th element of \boldsymbol{u}_{l-1}^{i} and the m-th element of \boldsymbol{z}_{l}^{i}, respectively.

Figure 4: Detailed structure of the l-th layer related to the i-th sample, where we set N_{l}=2,N_{l-1}=3. The specific expressions of the factor nodes are summarized in Table II.
Factor Distribution Functional form
h_{l-1,ni} \begin{cases}p\left(u_{0,ni}|x_{n}^{i}\right)&l=1\\ p\left(u_{l-1,ni}|z_{l-1,ni}\right)&l=2,\dots,L\end{cases} \begin{cases}\delta\left(u_{0,ni}-x_{n}^{i}\right)&l=1\\ \delta\left(u_{l-1,ni}-\zeta_{l-1}\left(z_{l-1,ni}\right)\right)&l=2,\dots,L\end{cases}
f_{l,mi} p\left(z_{l,mi}|\boldsymbol{W}_{l,n},\boldsymbol{u}_{l-1}^{i},b_{l,m}\right) \delta\left(z_{l,mi}-\left(\sum_{n=1}^{N_{l-1}}W_{l,mn}u_{l-1,ni}+b_{l,m}\right)\right)
h_{l,m}^{b} p\left(b_{l,m}\right) \rho_{l,m}^{b}N\left(b_{l,m};\mu_{l,m}^{b},v_{l,m}^{b}\right)+\left(1-\rho_{l,m}^{b}\right)\delta\left(b_{l,m}\right)
h_{l,mn}^{W} \exp\left(\triangle_{l,mn}^{B\rightarrow A}\right) \rho_{l,mn}^{B\rightarrow A}N\left(W_{l,mn};\mu_{l,mn}^{B\rightarrow A},v_{l,mn}^{B\rightarrow A}\right)+\left(1-\rho_{l,mn}^{B\rightarrow A}\right)\delta\left(W_{l,mn}\right)
TABLE II: Factors, distributions and functional forms in Fig. 4.

In the proposed DAMP, the messages between layers are updated in turn. For convenience, in the following, we denote by \triangle_{a\rightarrow b} the message from node a to node b, and by \triangle_{c} the marginal log-posterior computed at variable node c.

In the forward message passing, layer l=1,\dots,L outputs the messages \triangle_{f_{l,mi}\rightarrow z_{l,mi}} given the input messages \triangle_{h_{l-1,ni}\rightarrow u_{l-1,ni}}:

\triangle_{h_{0,ni}\rightarrow u_{0,ni}}=\delta\left(u_{0,ni}-x_{n}^{i}\right),

and for l=1,\dots,L-1,

\triangle_{h_{l,ni}\rightarrow u_{l,ni}}=\log\int_{z_{l,ni}}\exp\left(\triangle_{f_{l,ni}\rightarrow z_{l,ni}}\right)\times\delta\left(u_{l,ni}-\zeta_{l}\left(z_{l,ni}\right)\right).

In the backward message passing, layer l=L,\dots,1 outputs the messages \triangle_{u_{l-1,ni}\rightarrow h_{l-1,ni}} given the input messages \triangle_{z_{l,mi}\rightarrow f_{l,mi}}:

\triangle_{z_{L,mi}\rightarrow f_{L,mi}}=\log\int_{z_{L,m^{\prime}i},m^{\prime}\neq m}\exp\left(\sum_{m^{\prime}\neq m}\triangle_{f_{L,m^{\prime}i}\rightarrow z_{L,m^{\prime}i}}\right)\times p\left(\boldsymbol{y}^{i}|\boldsymbol{z}_{L}^{i}\right),

and for l=L-1,\dots,1,

\triangle_{z_{l,mi}\rightarrow f_{l,mi}}=\log\int_{u_{l,mi}}\exp\left(\triangle_{u_{l,mi}\rightarrow h_{l,mi}}\right)\times\delta\left(u_{l,mi}-\zeta_{l}\left(z_{l,mi}\right)\right).

Notice that the factor graph of a layer, as illustrated in Fig. 4, has a similar structure to the bilinear model discussed in [31]. Therefore, we follow the general idea of the BiG-AMP framework in [31] to approximate the messages within each layer. The detailed derivation is presented in the supplementary file of this paper, and the schedule of the approximated messages is summarized in Algorithm 1. In particular, the messages \triangle_{u_{l-1,ni}},\triangle_{z_{l,mi}} are related to the nonlinear steps, which will be detailed in Appendix -A.
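The nonlinear steps push Gaussian messages through the activation \zeta_{l}. The exact rules are in Appendix -A; as a hedged illustration of the kind of computation involved, the sketch below gives the standard closed-form moments of ReLU(z) for Gaussian z, one common way to realize such a moment-matching step (not necessarily identical to the paper's rules).

```python
import numpy as np
from scipy.stats import norm

def relu_moments(mu, v):
    """Mean and variance of u = ReLU(z) for z ~ N(mu, v), i.e. the
    moment matching used when a Gaussian message crosses an activation node.

    E[ReLU(z)]   = mu*Phi(a) + s*phi(a),            a = mu/s, s = sqrt(v)
    E[ReLU(z)^2] = (mu^2 + v)*Phi(a) + mu*s*phi(a)
    """
    s = np.sqrt(v)
    a = mu / s
    mean = mu * norm.cdf(a) + s * norm.pdf(a)
    second = (mu ** 2 + v) * norm.cdf(a) + mu * s * norm.pdf(a)
    return mean, second - mean ** 2

m, var = relu_moments(mu=0.0, v=1.0)   # E[ReLU(z)] = 1/sqrt(2*pi) for N(0,1)
print(m, var)
```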

III-B4 SPMP in Module B

Module B further exploits the structured sparsity to achieve structured model compression by performing the SPMP algorithm. Note that Module B has a tree structure, and thus the SPMP is exact. For all l,m,n, the input factor nodes for Module B are defined as the output messages of Module A:

h_{l,mn}\triangleq\triangle_{l,mn}^{A\rightarrow B}=\triangle_{W_{l,mn}\rightarrow h_{l,mn}^{W}}.

Based on SPMP, we give the updating rule (10) for the output message as follows:

\exp\left(\triangle_{l,mn}^{B\rightarrow A}\right) \propto\int_{W_{l,m^{\prime}n},m^{\prime}\neq m}p\left(\boldsymbol{W}_{l,n}\right)\exp\left(\sum_{m^{\prime}\neq m}\triangle_{l,m^{\prime}n}^{A\rightarrow B}\right)
\propto\rho_{l,mn}^{B\rightarrow A}N\left(W_{l,mn};\mu_{l,mn}^{B\rightarrow A},v_{l,mn}^{B\rightarrow A}\right)+\left(1-\rho_{l,mn}^{B\rightarrow A}\right)\delta\left(W_{l,mn}\right), (10)

where

\mu_{l,mn}^{B\rightarrow A}=\mu_{l,mn},\quad v_{l,mn}^{B\rightarrow A}=v_{l,mn},
\rho_{l,mn}^{B\rightarrow A}=\frac{\rho_{l,n}}{\rho_{l,n}+\left(1-\rho_{l,n}\right)\prod_{m^{\prime}\neq m}\eta_{l,m^{\prime}n}},
\eta_{l,mn}=\frac{N\left(0;\mu_{l,mn}^{A\rightarrow B},v_{l,mn}^{A\rightarrow B}\right)}{N\left(0;\mu_{l,mn}^{A\rightarrow B}-\mu_{l,mn},v_{l,mn}^{A\rightarrow B}+v_{l,mn}\right)}.

The posterior distribution for \boldsymbol{W}_{l,n} is given by (11), which will be used in Subsection III-B5 to update the prior distribution.

p\left(\boldsymbol{W}_{l,n}|\boldsymbol{D}^{r}\right) \propto p\left(\boldsymbol{W}_{l,n}\right)\times\exp\left(\sum_{m}\triangle_{W_{l,mn}\rightarrow h_{l,mn}^{W}}\right)
\propto\rho_{l,n}^{post}\prod_{m=1}^{N_{l}}N\left(W_{l,mn};\mu_{l,mn}^{post},v_{l,mn}^{post}\right)+\left(1-\rho_{l,n}^{post}\right)\prod_{m=1}^{N_{l}}\delta\left(W_{l,mn}\right), (11)

where

\rho_{l,n}^{post}=\frac{\rho_{l,n}}{\rho_{l,n}+\left(1-\rho_{l,n}\right)\prod_{m}\eta_{l,mn}},
\mu_{l,mn}^{post}=\frac{\frac{\mu_{l,mn}}{v_{l,mn}}+\frac{\mu_{l,mn}^{A\rightarrow B}}{v_{l,mn}^{A\rightarrow B}}}{\frac{1}{v_{l,mn}}+\frac{1}{v_{l,mn}^{A\rightarrow B}}},\quad v_{l,mn}^{post}=\frac{1}{\frac{1}{v_{l,mn}}+\frac{1}{v_{l,mn}^{A\rightarrow B}}}.
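The following is a minimal sketch of the Module B computation for one weight group, under the reading of (10)-(11) above: the Gaussian density in \eta is evaluated at zero, and the posterior combines the prior with the Module-A messages. Function names are illustrative, and the leave-one-out products are computed naively (a practical implementation would work in the log domain for stability).

```python
import numpy as np

def gauss0(mu, v):
    """Gaussian density N(x; mu, v) evaluated at x = 0."""
    return np.exp(-mu ** 2 / (2 * v)) / np.sqrt(2 * np.pi * v)

def spmp_group(mu_ab, v_ab, rho, mu, v):
    """SPMP for one group W_{l,n}: mu_ab, v_ab are the per-element Gaussian
    messages from Module A; rho, mu, v are the group sparse prior parameters.
    Returns the extrinsic activity probabilities of (10) and the posterior
    parameters of (11)."""
    # eta_m: "inactive vs active" evidence ratio for element m
    eta = gauss0(mu_ab, v_ab) / gauss0(mu_ab - mu, v_ab + v)
    prod_all = np.prod(eta)
    prod_loo = prod_all / eta                 # leave-one-out products
    rho_ba = rho / (rho + (1 - rho) * prod_loo)     # extrinsic (10)
    rho_post = rho / (rho + (1 - rho) * prod_all)   # posterior activity (11)
    v_post = 1.0 / (1.0 / v + 1.0 / v_ab)           # Gaussian product rule
    mu_post = v_post * (mu / v + mu_ab / v_ab)
    return rho_ba, rho_post, mu_post, v_post
```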

III-B5 PasP Rule to Update the Prior Distribution p\left(\boldsymbol{\theta}\right)

To accelerate convergence and fuse the information among minibatches, we update the joint prior distribution p\left(\boldsymbol{\theta}\right) after processing each minibatch. Specifically, after updating the joint posterior distribution based on the r-th minibatch, we set the prior distribution to the posterior distribution. This mechanism, called PasP in [27], is given in (12):

p\left(\boldsymbol{\theta}\right)=\left(p\left(\boldsymbol{\theta}|\boldsymbol{D}^{r}\right)\right)^{\lambda}, (12)

where the posterior distributions for the biases are computed through DAMP in Module A, while the posterior distributions for the weights are computed through (11) in Module B. By doing so, the information from all the previous minibatches is incorporated in the updated prior distribution. In practice, \lambda plays a role similar to the learning rate in SGD and is typically set close to 1 [27]. For convenience, we fix \lambda=1 in the simulations.

III-C M-step

In the M-step, we update the hyperparameters \boldsymbol{\psi} and v in the prior distribution and the likelihood function by maximizing E\left(\log p\left(\boldsymbol{\theta}\right)\right) and E\left(\log p\left(\boldsymbol{D}_{y}|\boldsymbol{z}_{L}\right)\right) respectively, where the expectations are computed based on the results of the E-step as discussed above.

Updating rules for the prior hyperparameters \boldsymbol{\psi}

We observe that the posterior distribution p\left(\boldsymbol{\theta}|\boldsymbol{D}\right) computed through TDAMP can be factorized in the same form as p\left(\boldsymbol{\theta}\right); thus, maximizing E\left(\log p\left(\boldsymbol{\theta}\right)\right) is equivalent to updating \boldsymbol{\psi} to the corresponding parameters in p\left(\boldsymbol{\theta}|\boldsymbol{D}\right). However, directly updating the prior sparsity parameters \rho_{i}^{W} based on EM cannot achieve neuron-level pruning with the target sparsity \rho. It is also not good practice to fix \rho_{i}^{W}=\rho throughout the iterations, because this usually slows down the convergence, as observed in the simulations. In order to control the network sparsity and prune the network during training without affecting the convergence, we introduce the following modified updating rule for the \rho_{i}^{W}. Specifically, after each M-step, we calculate S, the number of weight groups that are highly likely to be active, i.e.,

S=\sum_{i=1}^{Q_{W}}1\left(\rho_{i}^{W}>\rho_{th}\right),

where \rho_{th} is a certain threshold set close to 1. If S exceeds the target number of neurons \rho Q_{W}, we reset the \rho_{i}^{W} as follows:

\rho_{i}^{W}=\begin{cases}\rho_{0},&\rho_{i}^{W}\geq\rho_{th}\\ 0,&\rho_{i}^{W}<\rho_{th}\end{cases},

where ρ0\rho_{0} is the initial sparsity. Extensive simulations have shown that this method works well.
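A minimal sketch of this reset rule follows (names and default values are illustrative):

```python
import numpy as np

def reset_group_sparsity(rho_W, rho_target, rho_th=0.99, rho_0=0.5):
    """M-step reset of the group activity probabilities rho_i^W: once the
    number S of almost-surely-active groups exceeds the target rho*Q_W,
    groups above the threshold are reset to rho_0 and the rest are pruned."""
    rho_W = np.asarray(rho_W, dtype=float)
    S = np.sum(rho_W > rho_th)
    if S > rho_target * rho_W.size:
        rho_W = np.where(rho_W >= rho_th, rho_0, 0.0)
    return rho_W
```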

Updating rules for the noise variance v

We take the regression model and the classification model as examples to derive the updating rule for v.

For the regression model (6), by setting the derivative of E\left(\log p\left(\boldsymbol{D}_{y}|\boldsymbol{z}_{L}\right)\right) w.r.t. v equal to zero, we obtain:

v^{*}=\sum_{i=1}^{I}\sum_{m=1}^{N_{L}}\frac{\left(y_{m}^{i}-\mu_{z_{L,mi}}\right)^{2}+v_{z_{L,mi}}}{N_{L}I}. (13)
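Since (13) averages the squared residuals plus the posterior variances over all N_{L}I output entries, it is a one-liner in practice; a minimal sketch (array names are illustrative):

```python
import numpy as np

def update_v_regression(y, mu_z, v_z):
    """EM update (13): y, mu_z, v_z are N_L x I arrays of labels and
    posterior means/variances of z_L; the mean runs over all N_L*I entries."""
    return np.mean((y - mu_z) ** 2 + v_z)
```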

For the classification model (7), we define:

i=1,,I,myi:ξmi=zL,mizL,yii,\forall i=1,\ldots,I,m\neq y^{i}:\xi_{mi}=z_{L,mi}-z_{L,y^{i}i},

with expectation and variance given by

\mu_{\xi_{mi}}=\mu_{z_{L,mi}}-\mu_{z_{L,y^{i}i}},\quad v_{\xi_{mi}}=v_{z_{L,mi}}+v_{z_{L,y^{i}i}}. (14)

Then E\left(\log p\left(\boldsymbol{D}_{y}|\boldsymbol{z}_{L}\right)\right) can be approximated as follows:

E\left(\log p\left(\boldsymbol{D}_{y}|\boldsymbol{z}_{L}\right)\right)=\int_{\xi_{mi}}\log Q\left(\frac{\xi_{mi}}{\sqrt{v}}\right)\times\sum_{i=1}^{I}\sum_{m\neq y^{i}}N\left(\xi_{mi};\mu_{\xi_{mi}},v_{\xi_{mi}}\right)\approx\int_{\xi}G\left(\xi;\alpha_{\xi},\beta_{\xi}\right)\log Q\left(\frac{\xi}{\sqrt{v}}\right), (15)

where we approximate \xi\sim\sum_{i=1}^{I}\sum_{m\neq y^{i}}N\left(\xi;\mu_{\xi_{mi}},v_{\xi_{mi}}\right) as a Gumbel distribution G\left(\xi;\alpha_{\xi},\beta_{\xi}\right) with location parameter \alpha_{\xi} and scale parameter \beta_{\xi}. Based on moment matching, we estimate \alpha_{\xi} and \beta_{\xi} as follows:

\beta_{\xi}=\frac{\sqrt{6}}{\pi}\sqrt{E-\mu^{2}},\quad\alpha_{\xi}=\mu+\gamma\beta_{\xi},

where \gamma\approx 0.5772 is Euler's constant, and we define \mu and E using (14) as follows:

\mu\triangleq\frac{\sum_{i=1}^{I}\sum_{m\neq y^{i}}\mu_{\xi_{mi}}}{\left(N_{L}-1\right)I},\quad E\triangleq\frac{\sum_{i=1}^{I}\sum_{m\neq y^{i}}\left(\mu_{\xi_{mi}}^{2}+v_{\xi_{mi}}\right)}{\left(N_{L}-1\right)I}. (16)

The effectiveness of this approximation will be justified in Fig. 8 and Fig. 9 in the simulation section. To find the optimal v based on (15), we define a special function

F\left(\mu\right)=\mathop{\arg\max}\limits_{v}\int_{\xi}G\left(\xi;\mu,1\right)\log Q\left(\frac{\xi}{\sqrt{v}}\right),

which can be calculated numerically and stored in a table for practical implementation. Then the optimal v is given by

v_{0}^{*}=\beta_{\xi}^{2}F\left(\frac{\alpha_{\xi}}{\beta_{\xi}}\right).

Considering the error introduced by the above approximation, we use the damping technique [31] with damping factor 0.5 to smooth the update of v in the experiments:

v^{*}=0.5v_{0}^{*}+0.5v. (17)

Compared to a direct numerical solution of \mathop{\arg\max}\limits_{v}E\left(\log p\left(\boldsymbol{D}_{y}|\boldsymbol{z}_{L}\right)\right), the proposed method greatly reduces the complexity. Experiments show that the method is stable, as will be detailed in Subsection V-B.
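The following sketch assembles the classification M-step (16), F, and (17). It assumes, consistently with \alpha_{\xi}=\mu+\gamma\beta_{\xi}, that G is the left-skewed (minimum) Gumbel, whose mean is \alpha-\gamma\beta; F is evaluated here by direct numerical integration and one-dimensional search rather than the table lookup the paper suggests, so it is an illustrative stand-in rather than the paper's implementation.

```python
import numpy as np
from scipy import integrate, optimize
from scipy.stats import norm, gumbel_l

GAMMA = 0.5772  # Euler's constant, as in the text

def gumbel_fit(mu_xi, v_xi):
    """Moment matching (16): mu_xi, v_xi collect mu_{xi_mi}, v_{xi_mi}
    over all i and m != y^i; returns (alpha_xi, beta_xi)."""
    mu = np.mean(mu_xi)
    E = np.mean(mu_xi ** 2 + v_xi)
    beta = np.sqrt(6.0) / np.pi * np.sqrt(E - mu ** 2)
    return mu + GAMMA * beta, beta

def F(mu):
    """Numerical stand-in for the special function F(mu):
    maximize int G(xi; mu, 1) log Q(xi/sqrt(v)) dxi over v (search in log v)."""
    def neg_obj(log_v):
        val, _ = integrate.quad(
            lambda x: gumbel_l.pdf(x, loc=mu) * norm.logsf(x * np.exp(-0.5 * log_v)),
            mu - 40.0, mu + 10.0)
        return -val
    res = optimize.minimize_scalar(neg_obj, bounds=(-10.0, 10.0), method="bounded")
    return float(np.exp(res.x))

def update_v_classification(mu_xi, v_xi, v_old):
    alpha, beta = gumbel_fit(mu_xi, v_xi)
    v0 = beta ** 2 * F(alpha / beta)    # v_0^* = beta^2 F(alpha/beta)
    return 0.5 * v0 + 0.5 * v_old       # damped update (17)
```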

III-D Summary of the EM-TDAMP Algorithm

To sum up, the proposed EM-TDAMP algorithm is implemented as Algorithm 1, where \tau_{max} represents the maximum number of iterations.

Algorithm 1 EM-TDAMP algorithm

Input: dataset \boldsymbol{D}.

Output: p\left(\boldsymbol{\theta}|\boldsymbol{D}\right),p\left(\boldsymbol{z}_{L}|\boldsymbol{D}\right)

Initialization: Hyperparameters \boldsymbol{\psi},v, \forall l,m,i:\triangle_{s_{l,mi}}=0,

\forall n,i:\triangle_{u_{l-1,ni}\rightarrow h_{l-1,ni}}=\begin{cases}\log\delta\left(u_{0,ni}-x_{n}^{i}\right)&l=1\\ 0&l>1\end{cases},

\forall l,m:\triangle_{b_{l,m}\rightarrow h_{l,m}^{b}}=0, \forall l,m,n:\triangle_{W_{l,mn}\rightarrow h_{l,mn}^{W}}=0.

1:  for \tau=1,\ldots,\tau_{\max} do
2:     \bullet E-step:
3:     Set the prior distribution p\left(\boldsymbol{\theta}\right) and the likelihood function p\left(\boldsymbol{D}_{y}|\boldsymbol{z}_{L}\right) based on \boldsymbol{\psi} and v.
4:     for r=1,\ldots,R do
5:        Module B (SPMP)
6:        Update the output messages for Module A \forall l,m,n:\triangle_{l,mn}^{B\rightarrow A} as in (10) and the posterior distributions for the weight groups \forall l,n:p\left(\boldsymbol{W}_{l,n}|\boldsymbol{D}^{r}\right) as in (11).
7:        Module A (DAMP)
8:        Set the prior distributions \forall l,m,n:\triangle_{h_{l,mn}^{W}\rightarrow W_{l,mn}}=\triangle_{l,mn}^{B\rightarrow A},\forall l,m:\triangle_{h_{l,m}^{b}\rightarrow b_{l,m}}=p\left(b_{l,m}\right) and the likelihood function p\left(\boldsymbol{D}_{y}^{r}|\boldsymbol{z}_{L}^{r}\right) with noise variance v.
9:        %Forward message passing
10:        for l=1,\ldots,L do
11:           Update the input messages \forall n,i:\triangle_{h_{l-1,ni}\rightarrow u_{l-1,ni}}.
12:           Update the posterior messages \forall m,n,i:\triangle_{b_{l,m}},\triangle_{W_{l,mn}},\triangle_{u_{l-1,ni}}.
13:           Update the forward messages \forall m,i:\triangle_{f_{l,mi}\rightarrow z_{l,mi}}.
14:        end for
15:        %Backward message passing
16:        for l=L,\ldots,1 do
17:           Update the input messages \forall m,i:\triangle_{z_{l,mi}\rightarrow f_{l,mi}}.
18:           Update the posterior messages \forall m,i:\triangle_{z_{l,mi}}.
19:           Update the aggregated backward messages \forall m,n,i:\triangle_{b_{l,m}\rightarrow h_{l,m}^{b}},\triangle_{W_{l,mn}\rightarrow h_{l,mn}^{W}},\triangle_{u_{l-1,ni}\rightarrow h_{l-1,ni}}.
20:        end for
21:        PasP
22:        Update the prior distribution p\left(\boldsymbol{\theta}\right) as in (12).
23:     end for
24:     \bullet M-step:
25:     Update \boldsymbol{\psi} to the corresponding parameters in p\left(\boldsymbol{\theta}|\boldsymbol{D}\right).
26:     Update the noise variance v through (13)/(17) in the regression/classification model.
27:  end for
28:  Output p\left(\boldsymbol{\theta}|\boldsymbol{D}\right)=p\left(\boldsymbol{\theta}\right) and p\left(\boldsymbol{z}_{L}|\boldsymbol{D}\right)\propto\exp\left(\sum_{m=1}^{N_{L}}\sum_{i=1}^{I}\triangle_{z_{L,mi}}\right).

IV Extension of EM-TDAMP to Federated Learning Scenarios

IV-A Outline of Bayesian Federated Learning (BFL) Framework

In this section, we consider a general federated/distributed learning scenario, which includes centralized learning as a special case. There is a central server and K clients, where each client k=1,\ldots,K possesses a subset of data (a local data set) indexed by \mathcal{I}_{k}: \boldsymbol{D}^{k}\triangleq\left\{\boldsymbol{D}_{x}^{k},\boldsymbol{D}_{y}^{k}\right\} with \boldsymbol{D}_{x}^{k}\triangleq\left\{\boldsymbol{x}^{i}|i\in\mathcal{I}_{k}\right\}, \boldsymbol{D}_{y}^{k}\triangleq\left\{\boldsymbol{y}^{i}|i\in\mathcal{I}_{k}\right\} and \cup_{k=1}^{K}\mathcal{I}_{k}=\left\{1,2,\ldots,I\right\}. The process of the proposed BFL framework contains three steps, as illustrated in Fig. 5. Firstly, the central server sends the prior hyperparameters \boldsymbol{\psi} and the likelihood hyperparameter v (i.e., the noise variance) to the clients to initialize the local prior distribution p\left(\boldsymbol{\theta}\right) and likelihood function p\left(\boldsymbol{D}_{y}^{k}|\boldsymbol{z}_{L}^{k}\right), where \boldsymbol{z}_{L}^{k}\in\mathbb{R}^{N_{L}\times I_{k}} represents the output corresponding to the local data \boldsymbol{D}_{x}^{k}. Afterwards, the clients compute in parallel the local posterior distributions p\left(\boldsymbol{\theta}|\boldsymbol{D}^{k}\right) and p\left(\boldsymbol{z}_{L}^{k}|\boldsymbol{D}^{k}\right) by performing turbo deep approximate message passing (TDAMP), as detailed in Subsection III-B, and extract the local posterior parameters \boldsymbol{\varphi}^{k} and \sigma^{k} for uplink communication. Lastly, the central server aggregates the local posterior parameters to update the hyperparameters \boldsymbol{\psi} and v by maximizing the expectation in (8), as will be detailed in Subsection IV-B, where we define the local posterior parameters \boldsymbol{\varphi}^{k},\sigma^{k} and approximate \boldsymbol{\psi},v as functions of \boldsymbol{\varphi}^{k},\sigma^{k},k=1,\ldots,K.

Figure 5: Illustration of the federated learning framework, where f_{k}\left(\boldsymbol{\theta}\right) and g_{k}\left(\boldsymbol{\theta}\right) represent p\left(\boldsymbol{z}_{L}^{k}|\boldsymbol{D}_{x}^{k},\boldsymbol{\theta}\right) and p\left(\boldsymbol{D}_{y}^{k}|\boldsymbol{z}_{L}^{k}\right), respectively, for k=1,\ldots,K.

IV-B Updating Rules At the Central Server

In the proposed EM-based BFL framework, the central server computes the global posterior distributions p\left(\boldsymbol{\theta}|\boldsymbol{D}\right) and p\left(\boldsymbol{z}_{L}|\boldsymbol{D}\right) by aggregating the local posterior distributions in the E-step, and updates the hyperparameters \boldsymbol{\psi},v by maximizing the objective function (8) in the M-step. The specific aggregation mechanism and updating rules are elaborated as follows:

IV-B1 Aggregation Mechanism for p\left(\boldsymbol{\theta}|\boldsymbol{D}\right) and p\left(\boldsymbol{z}_{L}|\boldsymbol{D}\right)

Aggregation Mechanism for p\left(\boldsymbol{\theta}|\boldsymbol{D}\right)

We approximate p\left(\boldsymbol{\theta}|\boldsymbol{D}\right) as the weighted geometric average of the local posterior distributions p\left(\boldsymbol{\theta}|\boldsymbol{D}^{k}\right),k=1,\ldots,K [40]:

p\left(\boldsymbol{\theta}|\boldsymbol{D}\right)\approx\prod_{k=1}^{K}\left(p\left(\boldsymbol{\theta}|\boldsymbol{D}^{k}\right)\right)^{\frac{I_{k}}{I}}, (18)

where I_{k}=\left|\mathcal{I}_{k}\right|.

The proposed weighted geometric average of p\left(\boldsymbol{\theta}|\boldsymbol{D}^{k}\right),k=1,\ldots,K in (18) is more likely to approach the globally optimal posterior distribution than the widely used weighted algebraic average (19):

p_{AA}\left(\boldsymbol{\theta}|\boldsymbol{D}\right)=\sum_{k=1}^{K}\frac{I_{k}}{I}p\left(\boldsymbol{\theta}|\boldsymbol{D}^{k}\right). (19)

For easy illustration of this point, we consider the special case when all the local posterior distributions are Gaussian (note that a Gaussian is a special case of the Bernoulli-Gaussian). The posterior distribution aggregated through the weighted geometric average (WGA) (18) is still Gaussian, whose expectation \mu_{\theta,WGA} is the average of the local posterior expectations \mu_{\theta,k} weighted by the corresponding variances v_{\theta,k} as in (20), while the posterior distribution aggregated through the weighted algebraic average (WAA) (19) is a Gaussian mixture, whose expectation is a simple weighted average of the local posterior expectations as in (21). Therefore, WGA is more reliable than WAA because it utilizes the local variances for the posterior expectation aggregation, which is consistent with the experimental results in [40].

\mu_{\theta,WGA}=\left(\sum_{k=1}^{K}\frac{I_{k}}{I}\frac{1}{v_{\theta,k}}\right)^{-1}\left(\sum_{k=1}^{K}\frac{I_{k}}{I}\frac{\mu_{\theta,k}}{v_{\theta,k}}\right) (20)
\mu_{\theta,WAA}=\sum_{k=1}^{K}\frac{I_{k}}{I}\mu_{\theta,k}. (21)
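The contrast between (20) and (21) is easy to see numerically; a minimal sketch for the Gaussian special case follows (names are illustrative), where a confident client with small variance dominates the WGA but not the WAA.

```python
import numpy as np

def aggregate_gaussian_wga(mu, v, sizes):
    """Weighted geometric average (18)/(20) of K Gaussian posteriors:
    precision-weighted mean, with weights I_k / I."""
    w = np.asarray(sizes, dtype=float)
    w /= w.sum()
    prec = np.sum(w / v)                  # aggregated precision
    mean = np.sum(w * mu / v) / prec
    return mean, 1.0 / prec

def aggregate_gaussian_waa_mean(mu, sizes):
    """Mean of the weighted algebraic average (19)/(21): ignores variances."""
    w = np.asarray(sizes, dtype=float)
    w /= w.sum()
    return np.sum(w * mu)

mu = np.array([0.0, 1.0])
v = np.array([0.01, 1.0])                 # client 1 is much more confident
print(aggregate_gaussian_wga(mu, v, sizes=[1, 1]))   # mean close to 0.0
print(aggregate_gaussian_waa_mean(mu, sizes=[1, 1])) # 0.5
```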

The WGA based aggregation in (18) can also be explained from a loss function perspective. In Bayesian learning, we estimate the parameters based on MAP, which can also be interpreted as minimizing a loss function L_{NLP}\left(\boldsymbol{\theta},\boldsymbol{D}\right) if we define L_{NLP}\left(\boldsymbol{\theta},\boldsymbol{D}\right) as the negative log-posterior -\log\left(p\left(\boldsymbol{\theta}|\boldsymbol{D}\right)\right):

\hat{\boldsymbol{\theta}} =\text{argmax}_{\boldsymbol{\theta}}p\left(\boldsymbol{\theta}|\boldsymbol{D}\right)
=\text{argmin}_{\boldsymbol{\theta}}-\log\left(p\left(\boldsymbol{\theta}|\boldsymbol{D}\right)\right)
=\text{argmin}_{\boldsymbol{\theta}}-\log\left(p\left(\boldsymbol{D}|\boldsymbol{\theta}\right)\right)-\log\left(p\left(\boldsymbol{\theta}\right)\right)
=\text{argmin}_{\boldsymbol{\theta}}\frac{1}{2v}\sum_{\left\{\boldsymbol{x}^{i},\boldsymbol{y}^{i}\right\}\in\boldsymbol{D}}\left\|\boldsymbol{y}^{i}-\phi\left(\boldsymbol{x}^{i};\boldsymbol{\theta}\right)\right\|^{2}-\log p\left(\boldsymbol{\theta}\right).

Note that L_{MSE}\left(\boldsymbol{\theta},\boldsymbol{D}\right) in (2) and L_{MSE,l_{1}}\left(\boldsymbol{\theta},\boldsymbol{D}\right) in (3) are special cases of L_{NLP}\left(\boldsymbol{\theta},\boldsymbol{D}\right). Specifically, after setting v to \frac{1}{2}, if we set p\left(\boldsymbol{\theta}\right) as a uniform distribution (i.e., -\log p\left(\boldsymbol{\theta}\right) is constant), L_{NLP}\left(\boldsymbol{\theta},\boldsymbol{D}\right) becomes L_{MSE}\left(\boldsymbol{\theta},\boldsymbol{D}\right), while if we set p\left(\boldsymbol{\theta}\right) as a Laplace distribution (i.e., p\left(\theta\right)=\frac{1}{2b}\exp\left(-\frac{|\theta-a|}{b}\right)) with a=0,b=\frac{1}{\lambda}, L_{NLP}\left(\boldsymbol{\theta},\boldsymbol{D}\right) becomes L_{MSE,l_{1}}\left(\boldsymbol{\theta},\boldsymbol{D}\right). Therefore, L_{NLP}\left(\boldsymbol{\theta},\boldsymbol{D}\right) can be seen as a loss function in Bayesian learning algorithms.

In federated learning algorithms, the global loss function is normally formulated as a weighted sum of the loss functions at the clients, i.e., (22) as used in [40]:

L_{NLP}\left(\boldsymbol{\theta},\boldsymbol{D}\right)=\sum_{k=1}^{K}\frac{I_{k}}{I}L_{NLP}\left(\boldsymbol{\theta},\boldsymbol{D}^{k}\right), (22)

where the negative log-posterior loss function is used under the Bayesian framework. The loss function aggregation in (22) is equivalent to the weighted geometric average aggregation mechanism in (18), which provides another justification for WGA.

The specific derivation of the parameters of $p\left(\boldsymbol{\theta}|\boldsymbol{D}\right)$ according to (18) is detailed in the supplementary file.

Aggregation Mechanism for $p\left(\boldsymbol{z}_{L}|\boldsymbol{D}\right)$

We assume $\boldsymbol{z}_{L}^{k},k=1,\ldots,K$ are independent and approximate $p\left(\boldsymbol{z}_{L}|\boldsymbol{D}\right)$ as

p\left(\boldsymbol{z}_{L}|\boldsymbol{D}\right)=\prod_{k=1}^{K}p\left(\boldsymbol{z}_{L}^{k}|\boldsymbol{D}^{k}\right), (23)

which is reasonable since $\boldsymbol{z}_{L}^{k}$ is mainly determined by the $k$-th local data set $\boldsymbol{D}^{k}$, which is independent of the other local data sets $\boldsymbol{D}^{k^{\prime}},k^{\prime}\neq k$. The local posterior distribution $p\left(\boldsymbol{z}_{L}^{k}|\boldsymbol{D}^{k}\right)$ is the output of DAMP, which is approximated as the product of Gaussian marginal posterior distributions (24):

p\left(\boldsymbol{z}_{L}^{k}|\boldsymbol{D}^{k}\right)\approx\prod_{i\in\mathcal{I}_{k}}\prod_{m=1}^{N_{L}}N\left(z_{L,mi};\mu_{z_{L,mi}},v_{z_{L,mi}}\right), (24)

as detailed in the supplementary file. By plugging (24) into (23), we obtain the global posterior distribution for $\boldsymbol{z}_{L}$.

IV-B2 Updating Rules for $\boldsymbol{\psi}$ and $v$

Updating Rules for $\boldsymbol{\psi}$

As detailed in the supplementary file, the aggregated global posterior distribution $p\left(\boldsymbol{\theta}|\boldsymbol{D}\right)$ can be factorized in the same form as $p\left(\boldsymbol{\theta}\right)$; thus, maximizing $E\left(\log p\left(\boldsymbol{\theta}\right)\right)$ is equivalent to setting $\boldsymbol{\psi}$ to the corresponding parameters of $p\left(\boldsymbol{\theta}|\boldsymbol{D}\right)$. In the supplementary file, we define $\boldsymbol{\varphi}^{k},k=1,\cdots,K$ as the local posterior parameters for uplink communication and give $\boldsymbol{\psi}$ as a function of $\boldsymbol{\varphi}^{k},k=1,\cdots,K$.

Updating Rules for $v$

In federated learning, based on (23), the expectation $E\left(\log p\left(\boldsymbol{D}_{y}|\boldsymbol{z}_{L}\right)\right)$ can be written as:

E\left(\log p\left(\boldsymbol{D}_{y}|\boldsymbol{z}_{L}\right)\right)=\sum_{k=1}^{K}E\left(\log p\left(\boldsymbol{D}_{y}^{k}|\boldsymbol{z}_{L}^{k}\right)\right),

where the expectation is taken w.r.t. (23) and can be computed based on the local posterior distributions (24). In practice, the maximum point of $E\left(\log p\left(\boldsymbol{D}_{y}|\boldsymbol{z}_{L}\right)\right)$ w.r.t. $v$ can be expressed as a function of a few local posterior parameters. This means that the clients only need to send these parameters, denoted by $\sigma^{k},k=1,\ldots,K$, instead of the full posterior distributions $p\left(\boldsymbol{z}_{L}^{k}|\boldsymbol{D}^{k}\right),k=1,\ldots,K$, to the central server. In the following, we take the regression model and the classification model as examples to derive the updating rule for $v$ and to define the parameters $\sigma^{k}$ at client $k$ that compress the uplink communication.

For the regression model (6), $v^{*}$ in (13) equals the weighted sum of the $\sigma_{k}$:

v^{*}=\sum_{k=1}^{K}\frac{I_{k}}{I}\sigma_{k}, (25)

where we define

\sigma_{k}\triangleq\sum_{i\in\mathcal{I}_{k}}\sum_{m=1}^{N_{L}}\frac{\left(y_{m}^{i}-\mu_{z_{L,mi}}\right)^{2}+v_{z_{L,mi}}}{N_{L}I_{k}}. (26)
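A minimal sketch of the client/server split implied by (25) and (26) is given below; the function names and array shapes are our own conventions, not the paper's code:

```python
import numpy as np

def local_sigma(y, mu_z, v_z):
    """Client-side statistic sigma_k in (26).

    y, mu_z, v_z: arrays of shape (I_k, N_L) holding the labels and the
    posterior means/variances of the output z_L at client k.
    """
    I_k, N_L = y.shape
    return float(np.sum((y - mu_z) ** 2 + v_z) / (N_L * I_k))

def aggregate_v(sigmas, weights):
    """Server-side update (25): v* = sum_k (I_k / I) * sigma_k."""
    return float(np.dot(weights, sigmas))
```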

For the classification model (7), $\mu$ and $E$ in (16) equal the weighted sums of $\mu_{k}$ and $E_{k}$, respectively:

\mu=\sum_{k=1}^{K}\frac{I_{k}}{I}\mu_{k},\quad E=\sum_{k=1}^{K}\frac{I_{k}}{I}E_{k}, (27)

where we define

\mu_{k}\triangleq\frac{\sum_{i\in\mathcal{I}_{k}}\sum_{m\neq y^{i}}\mu_{\xi_{mi}}}{\left(N_{L}-1\right)I_{k}},\quad E_{k}\triangleq\frac{\sum_{i\in\mathcal{I}_{k}}\sum_{m\neq y^{i}}\left(\mu_{\xi_{mi}}^{2}+v_{\xi_{mi}}\right)}{\left(N_{L}-1\right)I_{k}}. (28)

Based on the aggregated $\mu$ and $E$, $v^{*}$ can be updated via (17).
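The corresponding client-side computation of $\left(\mu_{k},E_{k}\right)$ in (28) can be sketched as follows (the function name and array layout are again our own conventions):

```python
import numpy as np

def local_mu_E(mu_xi, v_xi, labels):
    """Client-side statistics (mu_k, E_k) in (28).

    mu_xi, v_xi: arrays of shape (I_k, N_L) with the posterior means and
    variances of xi_{mi}; labels: length-I_k integer array of true classes y^i.
    """
    I_k, N_L = mu_xi.shape
    mask = np.ones_like(mu_xi, dtype=bool)
    mask[np.arange(I_k), labels] = False          # exclude the m = y^i terms
    mu_k = mu_xi[mask].sum() / ((N_L - 1) * I_k)
    E_k = (mu_xi[mask] ** 2 + v_xi[mask]).sum() / ((N_L - 1) * I_k)
    return mu_k, E_k
```

The server then forms $\mu$ and $E$ as the data-weighted averages in (27) and updates $v^{*}$ via (17).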

IV-C Summary of the Entire Bayesian Federated Learning Algorithm

The entire EM-TDAMP Bayesian federated learning algorithm is summarized in Algorithm 2, where $T_{max}$ denotes the maximum number of communication rounds.

Algorithm 2 EM-TDAMP Bayesian Federated Learning Algorithm

Input: Training sets $\boldsymbol{D}^{k},k=1,\ldots,K$.

Output: $p\left(\boldsymbol{\theta}|\boldsymbol{D}\right)$.

1:  Initialization: Hyperparameters $\boldsymbol{\psi},v$.
2:  for $t=1\ldots T_{max}$ do
3:     • Step 1 (Broadcast)
4:     The central server sends $\boldsymbol{\psi}$ and $v$ to the clients.
5:     • Step 2 (Local training)
6:     for each client $k=1\cdots K$ do
7:        Update the local posterior distributions $p\left(\boldsymbol{\theta}|\boldsymbol{D}^{k}\right)$ and $p\left(\boldsymbol{z}_{L}^{k}|\boldsymbol{D}^{k}\right)$ by performing TDAMP as in Algorithm 1 with input hyperparameters $\boldsymbol{\psi},v$.
8:        Extract $\boldsymbol{\varphi}^{k}$ from $p\left(\boldsymbol{\theta}|\boldsymbol{D}^{k}\right)$.
9:        Compute $\sigma^{k}$ through (26)/(28) for the regression/classification model.
10:        Send $\boldsymbol{\varphi}^{k},\sigma^{k}$ to the central server.
11:     end for
12:     • Step 3 (Aggregation)
13:     The central server computes the posterior distribution $p\left(\boldsymbol{\theta}|\boldsymbol{D}\right)$ through the aggregation (18) and extracts the hyperparameters $\boldsymbol{\psi}$.
14:     The central server computes the noise variance $v$ based on the $\sigma^{k}$'s through (25)/(17) for the regression/classification model.
15:  end for
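The control flow of Algorithm 2 can be sketched as follows. This is a structural outline only: `tdamp_local`, `extract_phi`, `local_stats`, `aggregate_posterior` and `update_v` are hypothetical callables standing in for Algorithm 1, the uplink compression, the WGA aggregation (18) and the noise-variance updates (25)/(17), and are passed in explicitly so the skeleton stays self-contained:

```python
def em_tdamp_federated(local_datasets, psi, v, T_max,
                       tdamp_local, extract_phi, local_stats,
                       aggregate_posterior, update_v):
    # weights[k] = I_k / I, the fraction of data held by client k
    total = sum(len(D) for D in local_datasets)
    weights = [len(D) / total for D in local_datasets]

    global_posterior = None
    for t in range(T_max):
        phis, sigmas = [], []
        for D_k in local_datasets:                   # Steps 1-2: broadcast + local TDAMP
            post_theta_k, post_zL_k = tdamp_local(D_k, psi, v)
            phis.append(extract_phi(post_theta_k))   # varphi^k for the uplink
            sigmas.append(local_stats(post_zL_k, D_k))  # sigma^k via (26)/(28)
        # Step 3: server-side aggregation and EM hyperparameter update
        global_posterior, psi = aggregate_posterior(phis, weights)  # (18)
        v = update_v(sigmas, weights)                # (25) or (17)
    return global_posterior
```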

V Performance Evaluation

In this section, we evaluate the performance of the proposed EM-TDAMP through simulations. We consider two commonly used application scenarios with publicly available datasets: Boston housing price prediction and handwriting recognition, selected to evaluate the performance of our algorithm on regression and classification problems, respectively.

We consider the group sparse prior and compare against four baseline algorithms: AMP in [27] (since the message passing algorithms in [27] show similar performance, we only include the AMP-based training algorithm in the comparison), standard SGD, SGD with a group sparse regularizer [13], and group SNIP (for a fair comparison, we extend SNIP [32] to prune neurons). For convenience, we use the Adam optimizer [7] for the SGD-based baseline algorithms. We set the damping factor $\alpha=0.8$ and utilize a random Boolean mask for pruning as in [27].

Two cases are considered in the simulations. First, we consider the centralized learning case to compare EM-TDAMP with AMP in [27] and the SGD-based algorithms. Then, we consider the federated learning case to demonstrate the superiority of the aggregation mechanism in Subsection IV-B1 over the SGD-based baseline algorithms with the widely used FedAvg algorithm [8] for aggregation.

Before presenting the simulation results, we briefly compare the computational complexity. Here we neglect element-wise operations and only count the multiplications in matrix multiplications, which occupy most of the running time in both the SGD-based algorithms and the proposed EM-TDAMP algorithm. It can be shown that both SGD and EM-TDAMP require $O\left(I\sum_{l=1}^{L}N_{l-1}N_{l}\right)$ multiplications per iteration, and thus they have similar complexity orders.
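For instance, under this counting rule, the per-iteration multiplication count for the three-layer regression network of Section V-A1 can be tallied as follows (a worked example of the formula above, not the paper's code):

```python
# Count the dominant multiplications O(I * sum_l N_{l-1} N_l) for an MLP.
layers = [13, 64, 64, 1]   # N_0, ..., N_L for the Boston network below
I = 404                    # training-set size
mults = I * sum(a * b for a, b in zip(layers[:-1], layers[1:]))
print(mults)               # 404 * 4992 = 2,016,768 multiplications per pass
```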

In the following simulations, we focus on comparing the convergence speed and converged performance of the algorithms. Specifically, we show the loss on the test data during the training process to evaluate the convergence speed (group SNIP reduces to standard SGD when $\rho=1$, so we set $\rho=0.5$ when showing the training process), and we also show the converged performance under varying sparsity/pruning ratios to compare the proposed algorithm with the baselines comprehensively. To achieve the target group sparsity for the baseline algorithms, we manually prune the parameter groups based on their energy after training [7]. When calculating the loss (NMSE for the regression model and error rate for the classification model) on the test data for the proposed EM-TDAMP, we fix the parameters to their posterior expectations (i.e., we use the MMSE point estimate of the parameters). Each result is averaged over 10 experiments.

V-A Description of Models

V-A1 Boston Housing Price Prediction

For the regression model, we train a DNN on the Boston housing price dataset. The training set consists of 404 past housing prices, each associated with 13 predictive attributes. For convenience, we set the batch size to 101 in the following simulations. The test set contains 102 samples. We set the architecture as follows: the network comprises three layers, including two hidden layers, each with 64 output neurons and ReLU activation, and an output layer with one output neuron. Before training, we normalize the data for stability. We evaluate the prediction performance using the normalized mean square error (NMSE).
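For reference, a PyTorch rendering of this architecture is sketched below; it only pins down the layer sizes, since the paper trains the same topology with EM-TDAMP rather than autograd-based SGD:

```python
import torch.nn as nn

# The regression network described above: 13 inputs, two 64-neuron
# ReLU hidden layers, and a single linear output neuron.
boston_net = nn.Sequential(
    nn.Linear(13, 64), nn.ReLU(),   # hidden layer 1
    nn.Linear(64, 64), nn.ReLU(),   # hidden layer 2
    nn.Linear(64, 1),               # output: predicted (normalized) price
)
```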

V-A2 Handwriting Recognition

For the classification model, we train a DNN on the MNIST dataset, which is widely used in machine learning for handwritten digit recognition. The training set consists of 60,000 individual handwritten digits collected from postal codes, each labeled from 0 to 9. The images are grayscale with $28\times 28$ pixels. In our experiments, we set the batch size to 100. The test set consists of 10,000 digits. Before training, each digit is flattened into a column vector and divided by the maximum pixel value of 255. We use a two-layer network, where the first layer has 128 output neurons and a ReLU activation function, and the second layer has 10 output neurons, followed by a softmax activation function for the baseline algorithms and the Probit-product likelihood function for the proposed algorithm. We use the error rate on the test data to evaluate the performance.
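A corresponding sketch of the classification network and preprocessing (again, only as a shape reference for the baselines) is:

```python
import numpy as np
import torch.nn as nn

# The two-layer classification network described above.
mnist_net = nn.Sequential(
    nn.Linear(28 * 28, 128), nn.ReLU(),  # first layer: 128 output neurons
    nn.Linear(128, 10),                  # second layer: 10 class outputs
)
# Baselines apply a softmax to the 10 outputs; EM-TDAMP instead uses the
# Probit-product likelihood (7).

def preprocess(img):
    """Flatten a 28x28 uint8 digit and scale to [0, 1] by the max value 255."""
    return img.reshape(-1).astype(np.float32) / 255.0
```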

V-B Simulation Results

We start by evaluating the performance of EM-TDAMP in the centralized learning scenario. The training curves and test loss-sparsity curves for both the regression and classification models are depicted in Fig. 6 and Fig. 7. Fig. 6 shows the training curves of the proposed EM-TDAMP and the baselines, where we set $\rho=1$, i.e., a Gaussian prior for EM-TDAMP (the Bernoulli-Gaussian prior in (5) becomes Gaussian when $\rho=1$). The results show that the proposed EM-TDAMP achieves a faster training speed and also the best performance after enough rounds compared to Adam and AMP in [27]. There are two main reasons. First, compared to Adam, the message passing procedure updates the variances of the parameters during the iterations, which makes the inference more accurate after the same number of rounds and leads to faster convergence. Second, the noise variance can be automatically learned by the EM algorithm, which adaptively controls the learning rate and avoids the manual parameter tuning required by Adam. AMP in [27] lacks a flexible updating rule for the noise variance during the iterations and relies on a fixed damping factor $\alpha$ to control the learning rate, which leads to numerical instability and slow convergence in our experiments. Fig. 7 shows the test loss of the algorithms at different sparsity levels, where sparsity refers to the fraction of neurons that remain. From the results at $\rho=1$ (on the right edge of the figures), we see that EM-TDAMP with a Gaussian prior outperforms AMP in [27] and Adam after convergence when pruning is not considered, which is consistent with the training curves in Fig. 6. Then, as the compression ratio grows (moving from the right points to the left points), the performance gap between EM-TDAMP and the baselines widens, because the proposed EM-TDAMP prunes the groups based on sparsity during training, which is more efficient than the baseline methods that prune based on energy or gradients.

Figure 6: In the centralized learning case, training curves of the proposed EM-TDAMP compared to baselines. (a) Test NMSE in Boston housing price prediction. (b) Test error in handwriting recognition.
Figure 7: In the centralized learning case, converged performance of the proposed EM-TDAMP compared to baselines at different sparsity levels. (a) Test NMSE in Boston housing price prediction. (b) Test error in handwriting recognition.

Next, we verify the efficiency of the updating rule for the noise variance in the classification model discussed in Subsection III-C. The Gumbel approximation (15) is illustrated in Fig. 8, where we compare the distributions at $t=1$ and $t=30$. Since scaling does not affect the solution for the noise variance $v$ in (15), we scale the distributions so that their maxima equal 1 and compare only the shapes. We observe that both distributions have similar skewed shapes.

Figure 8: PDF of the mixed Gaussian distribution and the approximating Gumbel distribution in (15), where MG represents the mixed Gaussian distribution and G represents the Gumbel distribution.
Figure 9: Comparison of different noise variance updating methods during training. (a) Noise variance curve. (b) Test error curve.

Furthermore, in Fig. 9, we compare the training performance achieved by different noise variance updating methods, where we use a large initial value to make the differences during the iterations more visible. From Fig. 9a, we observe that the proposed updating rule is stable and tracks the numerical solution of the noise variance closely. Fig. 9b shows that the proposed method achieves a training speed comparable to the numerical solution, and both outperform the fixed noise variance case.

In the subsequent experiments, we consider the federated learning case to evaluate the aggregation mechanism. For convenience, we allocate an equal amount of data to each client, i.e., $I_{k}=\frac{I}{K}$ for $k=1,\cdots,K$. In the Boston housing price prediction and handwriting recognition tasks, we set $K=4$ and $K=10$, respectively. To reduce the number of communication rounds, we set $\tau_{max}=10$ in both cases ($\tau_{max}$ denotes the number of TDAMP inner iterations with fixed hyperparameters at each client in each round). The training curves and test loss-sparsity curves are shown in Fig. 10 and Fig. 11, respectively. Similar to the previous results, EM-TDAMP performs best among all the algorithms, which demonstrates the efficiency of the proposed aggregation method.

Figure 10: In the federated learning case, training curves of the proposed EM-TDAMP compared to baselines. (a) Test NMSE in Boston housing price prediction. (b) Test error in handwriting recognition.
Figure 11: In the federated learning case, converged performance of the proposed EM-TDAMP compared to baselines at different sparsity levels. (a) Test NMSE in Boston housing price prediction. (b) Test error in handwriting recognition.

VI Conclusions

In this work, we propose an EM-TDAMP algorithm to achieve efficient Bayesian deep learning and compression, and extend EM-TDAMP to federated learning scenarios. In the problem formulation, we propose a group sparse prior to promote neuron-level compression and introduce Gaussian noise at the output to prevent numerical instability. Then, we propose a novel Bayesian deep learning framework based on EM and approximate message passing. In the E-step, we compute the posterior distribution by performing TDAMP, which consists of a Module $B$ to handle the group sparse prior distribution, a Module $A$ to enable efficient approximate message passing over the DNN, and a PasP method to automatically tune the local prior distribution. In the M-step, we update the hyperparameters to accelerate convergence. Moreover, we extend the proposed EM-TDAMP to federated learning scenarios and propose a novel Bayesian federated learning framework, in which the clients compute the local posterior distributions via TDAMP, while the central server computes the global posterior distribution through aggregation and updates the hyperparameters via EM. Simulations show that the proposed EM-TDAMP achieves faster convergence and better training performance than well-known structured pruning methods with the Adam optimizer and the existing multilayer AMP algorithm in [27], especially when the compression ratio is high. Besides, the proposed EM-TDAMP greatly reduces the number of communication rounds in federated learning scenarios, making it attractive for practical applications. In the future, we will apply the proposed EM-TDAMP framework to design better training algorithms for more general DNNs, such as those with convolutional layers.

-A Nonlinear Steps

In this section, we discuss the updating rules of $\triangle_{u_{l-1,ni}}$ and $\triangle_{z_{l,mi}}$ when they involve nonlinear factors. Here we only provide the derivation of the messages; the specific updating rules for the expectations and variances are detailed in the supplementary file.

-A1 ReLU Activation Function

ReLU is an element-wise function defined in (1). In this part, we give the updating rules for the posterior messages of $u_{l,mi}$ and $z_{l,mi}$ for all $m,i$ when $\zeta_{l}\left(\cdot\right)$ is ReLU. Based on the sum-product rule, we obtain:

\exp\left(\triangle_{u}\right) \propto\delta\left(u\right)Q\left(\frac{\mu_{f\rightarrow z}}{\sqrt{v_{f\rightarrow z}}}\right)N\left(\mu_{u\rightarrow h},v_{u\rightarrow h}\right)
+U\left(u\right)N\left(\mu_{f\rightarrow z}-\mu_{u\rightarrow h},v_{f\rightarrow z}+v_{u\rightarrow h}\right)
\times N\left(u;\frac{\mu_{u\rightarrow h}v_{f\rightarrow z}+\mu_{f\rightarrow z}v_{u\rightarrow h}}{v_{f\rightarrow z}+v_{u\rightarrow h}},\frac{v_{f\rightarrow z}v_{u\rightarrow h}}{v_{f\rightarrow z}+v_{u\rightarrow h}}\right),
\exp\left(\triangle_{z}\right) \propto U\left(-z\right)N\left(\mu_{u\rightarrow h},v_{u\rightarrow h}\right)N\left(z;\mu_{f\rightarrow z},v_{f\rightarrow z}\right)
+U\left(z\right)N\left(\mu_{u\rightarrow h}-\mu_{f\rightarrow z},v_{f\rightarrow z}+v_{u\rightarrow h}\right)
\times N\left(z;\frac{\mu_{u\rightarrow h}v_{f\rightarrow z}+\mu_{f\rightarrow z}v_{u\rightarrow h}}{v_{f\rightarrow z}+v_{u\rightarrow h}},\frac{v_{f\rightarrow z}v_{u\rightarrow h}}{v_{f\rightarrow z}+v_{u\rightarrow h}}\right),

where, for convenience, we omit the subscripts $l,mi$ and define $U\left(\cdot\right)$ as the unit step function.
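As an illustration, the Gaussian projection of $\exp\left(\triangle_{z}\right)$ can be checked numerically by grid integration. The snippet below is our own sanity check with assumed toy message values (not the algorithm's closed-form updating rules), reading $N\left(\mu,v\right)$ without an argument as the corresponding Gaussian density evaluated at zero:

```python
import numpy as np
from scipy.stats import norm

mu_fz, v_fz = 0.5, 1.0        # toy message f -> z (assumed)
mu_uh, v_uh = -0.2, 2.0       # toy message u -> h (assumed)

# Parameters of the two branches of exp(Delta_z) from the display above.
v_t = v_fz * v_uh / (v_fz + v_uh)
mu_t = (mu_uh * v_fz + mu_fz * v_uh) / (v_fz + v_uh)
c_neg = norm.pdf(0.0, mu_uh, np.sqrt(v_uh))            # N(mu_uh, v_uh) at zero
c_pos = norm.pdf(0.0, mu_uh - mu_fz, np.sqrt(v_fz + v_uh))

z = np.linspace(-12, 12, 40001)
p = np.where(z < 0, c_neg * norm.pdf(z, mu_fz, np.sqrt(v_fz)),
             c_pos * norm.pdf(z, mu_t, np.sqrt(v_t)))
p /= np.trapz(p, z)                                    # normalize on the grid
mean = np.trapz(z * p, z)
var = np.trapz((z - mean) ** 2 * p, z)
print(mean, var)   # moments defining the Gaussian projection of exp(Delta_z)
```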

-A2 Probit-product Likelihood Function

Here, we briefly introduce the message passing related to the output $z_{L,mi}$ for all $m,i$ in the classification model. The factor graph of the Probit-product likelihood function (7) is given in Fig. 12, where we omit the subscripts $L,i$ for simplicity.

Figure 12: The factor graph of the Probit-product likelihood function, where $h_{m}=Q\left(\frac{z_{m}-z_{y}}{\sqrt{v}}\right)$.

First, to deal with $\triangle_{h_{m}\rightarrow z_{y}}$ for all $m\neq y$, we define

\triangle_{z_{y,m}} \triangleq\triangle_{h_{m}\rightarrow z_{y}}+\triangle_{f_{y}\rightarrow z_{y}}
=\log\int_{z_{m}}\exp\left(\triangle_{f_{m}\rightarrow z_{m}}+\triangle_{f_{y}\rightarrow z_{y}}\right)Q\left(\frac{z_{m}-z_{y}}{\sqrt{v}}\right)
=\log\left(\exp\left(\triangle_{f_{y}\rightarrow z_{y}}\right)Q\left(\frac{\mu_{f_{m}\rightarrow z_{m}}-z_{y}}{\sqrt{v+v_{f_{m}\rightarrow z_{m}}}}\right)\right),

where $\exp\left(\triangle_{z_{y,m}}\right)$ is a skew-normal distribution and will be approximated as a Gaussian based on moment matching. Then,

\triangle_{h_{m}\rightarrow z_{y}}=\triangle_{z_{y,m}}-\triangle_{f_{y}\rightarrow z_{y}}

is also approximated as the logarithm of a Gaussian. Next, based on the sum-product rule, we obtain:

\triangle_{z_{y}}=\triangle_{f_{y}\rightarrow z_{y}}+\sum_{m\neq y}\triangle_{h_{m}\rightarrow z_{y}},
\triangle_{z_{y}\rightarrow h_{m}}=\triangle_{z_{y}}-\triangle_{h_{m}\rightarrow z_{y}}.

Finally, for $m\neq y$, we approximate

\triangle_{z_{m}} =\triangle_{h_{m}\rightarrow z_{m}}+\triangle_{f_{m}\rightarrow z_{m}}
=\log\int_{z_{y}}\exp\left(\triangle_{z_{y}\rightarrow h_{m}}+\triangle_{f_{m}\rightarrow z_{m}}\right)Q\left(\frac{z_{m}-z_{y}}{\sqrt{v}}\right)
=\log\left(\exp\left(\triangle_{f_{m}\rightarrow z_{m}}\right)Q\left(\frac{z_{m}-\mu_{z_{y}\rightarrow h_{m}}}{\sqrt{v+v_{z_{y}\rightarrow h_{m}}}}\right)\right)

as the logarithm of a Gaussian based on moment matching again.
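The moment matching step can likewise be checked numerically. The snippet below (our sketch, with assumed toy message values) grid-integrates the skew-normal $\exp\left(\triangle_{z_{y,m}}\right)$ and reports the matched Gaussian moments; `scipy.stats.norm.sf` implements the tail function $Q(x)$:

```python
import numpy as np
from scipy.stats import norm

mu_fy, v_fy = 1.0, 0.5        # toy message f_y -> z_y (assumed)
mu_fm, v_fm = 0.2, 0.8        # toy message f_m -> z_m (assumed)
v = 0.3                       # output noise variance (assumed)

zy = np.linspace(-8, 10, 40001)
# Unnormalized skew-normal density exp(Delta_{z_y,m}) from the display above.
p = norm.pdf(zy, mu_fy, np.sqrt(v_fy)) \
    * norm.sf((mu_fm - zy) / np.sqrt(v + v_fm))
p /= np.trapz(p, zy)
mean = np.trapz(zy * p, zy)
var = np.trapz((zy - mean) ** 2 * p, zy)
# (mean, var) parameterize the Gaussian that moment matching substitutes
# for exp(Delta_{z_y,m}); closed forms follow from standard Gaussian
# cdf moment identities.
print(mean, var)
```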

References

  • [1] T. Li, A. K. Sahu, A. Talwalkar, and V. Smith, “Federated learning: Challenges, methods, and future directions,” IEEE Signal Processing Magazine, vol. 37, no. 3, pp. 50–60, 2020.
  • [2] W. Oh and G. N. Nadkarni, “Federated learning in health care using structured medical data,” Advances in Kidney Disease and Health, vol. 30, no. 1, pp. 4–16.
  • [3] A. Nguyen, T. Do, M. Tran, B. X. Nguyen, C. Duong, T. Phan, E. Tjiputra, and Q. D. Tran, “Deep federated learning for autonomous driving,” in 2022 IEEE Intelligent Vehicles Symposium (IV), pp. 1824–1830.
  • [4] G. Ananthanarayanan, P. Bahl, P. Bodík, K. Chintalapudi, M. Philipose, L. Ravindranath, and S. Sinha, “Real-time video analytics: The killer app for edge computing,” Computer, vol. 50, no. 10, pp. 58–67.
  • [5] W. Y. B. Lim, N. C. Luong, D. T. Hoang, Y. Jiao, Y.-C. Liang, Q. Yang, D. Niyato, and C. Miao, “Federated learning in mobile edge networks: A comprehensive survey,” IEEE Communications Surveys & Tutorials, vol. 22, no. 3, pp. 2031–2063.
  • [6] S. Ruder, “An overview of gradient descent optimization algorithms,” CoRR, vol. abs/1609.04747, 2016.
  • [7] D. P. Kingma and J. Ba, “Adam: A method for stochastic optimization,” in 3rd International Conference on Learning Representations, ICLR 2015, San Diego, CA, USA, May 7-9, 2015, Conference Track Proceedings, 2015.
  • [8] B. McMahan, E. Moore, D. Ramage, S. Hampson, and B. A. y Arcas, “Communication-efficient learning of deep networks from decentralized data,” in Proceedings of the 20th International Conference on Artificial Intelligence and Statistics, AISTATS 2017, 20-22 April 2017, Fort Lauderdale, FL, USA, ser. Proceedings of Machine Learning Research, vol. 54.   PMLR, 2017, pp. 1273–1282.
  • [9] B. Hanin, “Which neural net architectures give rise to exploding and vanishing gradients?” in Advances in Neural Information Processing Systems, vol. 31, 2018.
  • [10] S. Hochreiter, “The vanishing gradient problem during learning recurrent neural nets and problem solutions,” Int. J. Uncertain. Fuzziness Knowl. Based Syst., vol. 6, no. 2, pp. 107–116, 1998.
  • [11] A. Choromanska, M. Henaff, M. Mathieu, G. B. Arous, and Y. LeCun, “The loss surfaces of multilayer networks,” in Proceedings of the Eighteenth International Conference on Artificial Intelligence and Statistics, AISTATS 2015, San Diego, California, USA, May 9-12, 2015, vol. 38.
  • [12] M. Yuan and Y. Lin, “Model selection and estimation in regression with grouped variables,” Journal of the Royal Statistical Society: Series B (Statistical Methodology), vol. 68, no. 1, pp. 49–67, Feb. 2006.
  • [13] S. Scardapane, D. Comminiello, A. Hussain, and A. Uncini, “Group sparse regularization for deep neural networks,” Neurocomputing, vol. 241, pp. 81–89, Jun. 2017.
  • [14] S. Kim and E. P. Xing, “Tree-guided group lasso for multi-response regression with structured sparsity, with an application to eqtl mapping,” The Annals of Applied Statistics, vol. 6, no. 3, pp. 1095–1117.
  • [15] K. Mitsuno, J. Miyao, and T. Kurita, “Hierarchical group sparse regularization for deep convolutional neural networks,” 2020 International Joint Conference on Neural Networks (IJCNN), pp. 1–8, 2020.
  • [16] G. Litjens, T. Kooi, B. E. Bejnordi, A. A. A. Setio, F. Ciompi, M. Ghafoorian, J. A. W. M. van der Laak, B. van Ginneken, and C. I. Sánchez, “A survey on deep learning in medical image analysis,” Medical Image Analysis, vol. 42, pp. 60–88, Dec. 2017.
  • [17] J. Kocic, N. S. Jovicic, and V. Drndarevic, “An end-to-end deep neural network for autonomous driving designed for embedded automotive platforms,” Sensors, vol. 19, no. 9, p. 2064, 2019.
  • [18] X. Jiang, M. Osl, J. Kim, and L. Ohno-Machado, “Calibrating predictive model estimates to support personalized medicine,” Journal of the American Medical Informatics Association: JAMIA, vol. 19, no. 2, pp. 263–274, 2012.
  • [19] H. Wang and D. Yeung, “A survey on bayesian deep learning,” ACM Comput. Surv., vol. 53, no. 5, pp. 108:1–108:37, 2021.
  • [20] J. L. Puga, M. Krzywinski, and N. Altman, “Bayesian networks,” Nature Methods, vol. 12, no. 9, pp. 799–800, Sep. 2015.
  • [21] J. T. Springenberg, A. Klein, S. Falkner, and F. Hutter, “Bayesian optimization with robust bayesian neural networks,” in Advances in Neural Information Processing Systems, vol. 29, 2016.
  • [22] T. M. Fragoso and F. L. Neto, “Bayesian model averaging: A systematic review and conceptual classification,” International Statistical Review, vol. 86, no. 1, pp. 1–28, Apr. 2018.
  • [23] M. Rani, S. B. Dhok, and R. B. Deshmukh, “A systematic review of compressive sensing: Concepts, implementations and applications,” IEEE Access, vol. 6, pp. 4875–4894, 2018.
  • [24] A. Montanari, Graphical models concepts in compressed sensing.   Cambridge University Press, 2012, pp. 394–438.
  • [25] J. Ma, X. Yuan, and L. Ping, “Turbo compressed sensing with partial dft sensing matrix,” IEEE Signal Processing Letters, vol. 22, no. 2, pp. 158–161, 2015.
  • [26] P. Simard, D. Steinkraus, and J. Platt, “Best practices for convolutional neural networks applied to visual document analysis,” in Seventh International Conference on Document Analysis and Recognition, 2003. Proceedings., Aug. 2003, pp. 958–963.
  • [27] C. Lucibello, F. Pittorino, G. Perugini, and R. Zecchina, “Deep learning via message passing algorithms based on belief propagation,” Machine Learning: Science and Technology, vol. 3, no. 3, p. 035005, Sep. 2022.
  • [28] P. McCullagh, “Generalized linear models,” European Journal of Operational Research, vol. 16, no. 3, pp. 285–292, Jun. 1984.
  • [29] T. Moon, “The expectation-maximization algorithm,” IEEE Signal Processing Magazine, vol. 13, no. 6, pp. 47–60, 1996.
  • [30] L. Liu and F. Zheng, “A bayesian federated learning framework with multivariate gaussian product,” CoRR, vol. abs/2102.01936, 2021.
  • [31] J. T. Parker, P. Schniter, and V. Cevher, “Bilinear generalized approximate message passing—part i: Derivation,” IEEE Transactions on Signal Processing, vol. 62, no. 22, pp. 5839–5853, Nov. 2014.
  • [32] N. Lee, T. Ajanthan, and P. H. S. Torr, “Snip: single-shot network pruning based on connection sensitivity,” in 7th International Conference on Learning Representations, ICLR 2019, New Orleans, LA, USA, May 6-9, 2019, 2019.
  • [33] A. A. Abdullah, M. M. Hassan, and Y. T. Mustafa, “A review on bayesian deep learning in healthcare: Applications and challenges,” IEEE Access, vol. 10, pp. 36 538–36 562, 2022.
  • [34] X. Bai and Q. Peng, “A probabilistic model based-tracking method for mmwave massive MIMO channel estimation,” IEEE Trans. Veh. Technol., vol. 72, no. 12, pp. 16 777–16 782, Dec. 2023.
  • [35] C. Briggs, Z. Fan, and P. Andras, “Federated learning with hierarchical clustering of local updates to improve training on non-iid data,” in 2020 International Joint Conference on Neural Networks (IJCNN), Jul. 2020, pp. 1–9.
  • [36] L. V. Jospin, H. Laga, F. Boussaid, W. Buntine, and M. Bennamoun, “Hands-on bayesian neural networks—a tutorial for deep learning users,” IEEE Comput. Intell. Mag., vol. 17, no. 2, pp. 29–48, May 2022.
  • [37] P. Kairouz, H. B. McMahan, B. Avent et al., “Advances and open problems in federated learning,” Found. Trends Mach. Learn., vol. 14, no. 1-2, pp. 1–210, 2021.
  • [38] S. Kuutti, R. Bowden, Y. Jin, P. Barber, and S. Fallah, “A survey of deep learning applications to autonomous vehicle control,” IEEE Trans. Intell. Transp. Syst., vol. 22, no. 2, pp. 712–733, Feb. 2021.
  • [39] T. Li, A. K. Sahu, M. Zaheer, M. Sanjabi, A. Talwalkar, and V. Smith, “Federated optimization in heterogeneous networks,” Proceedings of Machine Learning and Systems, vol. 2, pp. 429–450, Mar. 2020.
  • [40] L. Liu, X. Jiang, F. Zheng, H. Chen, G.-J. Qi, H. Huang, and L. Shao, “A bayesian federated learning framework with online laplace approximation,” IEEE Trans. Pattern Anal. Mach. Intell., vol. 46, no. 1, pp. 1–16, Jan. 2024.
  • [41] ——, “A bayesian federated learning framework with online laplace approximation,” IEEE Trans. Pattern Anal. Mach. Intell., vol. 46, no. 1, pp. 1–16, Jan. 2024.
  • [42] M. Magris and A. Iosifidis, “Bayesian learning for neural networks: An algorithmic survey,” Artif. Intell. Rev., vol. 56, no. 10, pp. 11 773–11 823, Oct. 2023.
  • [43] X. Meng and J. Zhu, “Bilinear adaptive generalized vector approximate message passing,” IEEE Access, vol. 7, pp. 4807–4815, 2019.
  • [44] X. Meng, S. Wu, and J. Zhu, “A unified bayesian inference framework for generalized linear models,” IEEE Signal Process. Lett., vol. 25, no. 3, pp. 398–402, Mar. 2018.
  • [45] M. Rashid and M. Naraghi-Pour, “Clustered sparse channel estimation for massive MIMO systems by expectation maximization-propagation (EM-EP),” IEEE Trans. Veh. Technol., vol. 72, no. 7, pp. 9145–9159, Jul. 2023.
  • [46] S. Ray and B. Lindsay, “The topography of multivariate normal mixtures,” The Annals of Statistics, vol. 33, Mar. 2006.
  • [47] D. Wang, F. Weiping, Q. Song, and J. Zhou, “Potential risk assessment for safe driving of autonomous vehicles under occluded vision,” Scientific Reports, vol. 12, Mar. 2022.
  • [48] C. Zheng, S. Liu, Y. Huang, W. Zhang, and L. Yang, “Unsupervised recurrent federated learning for edge popularity prediction in privacy-preserving mobile-edge computing networks,” IEEE Internet Things J., vol. 9, no. 23, pp. 24 328–24 345, Dec. 2022.
  • [49] J. Zhu, “A comment on the "a unified bayesian inference framework for generalized linear models",” Apr. 2019.
  • [50] J. Zhu, C.-K. Wen, J. Tong, C. Xu, and S. Jin, “Grid-less variational bayesian channel estimation for antenna array systems with low resolution adcs,” IEEE Trans. Wirel. Commun., vol. 19, no. 3, pp. 1549–1562, Mar. 2020.
  • [51] J. Ziniel, P. Schniter, and P. Sederberg, “Binary linear classification and feature selection via generalized approximate message passing,” IEEE Trans. Signal Process., vol. 63, no. 8, pp. 2020–2032, Apr. 2015.
  • [52] Q. Zou, H. Zhang, and H. Yang, “Multi-layer bilinear generalized approximate message passing,” IEEE Trans. Signal Process., vol. 69, pp. 4529–4543, 2021.