FedNNNN: Norm-Normalized Neural Network Aggregation for Fast and Accurate
Federated Learning
Abstract
Federated learning (FL) is a distributed learning protocol in which a server aggregates a set of models trained locally by independent clients in order to proceed with the learning process. At present, model averaging, known as FedAvg, is one of the most widely adopted aggregation techniques. However, it is known to yield models with degraded prediction accuracy and slow convergence. In this work, we find that averaging models from different clients significantly diminishes the norm of the update vectors, resulting in slow learning and low prediction accuracy. Therefore, we propose a new aggregation method called FedNNNN. Instead of simple model averaging, we adjust the norm of the update vector and introduce momentum control techniques to improve the aggregation effectiveness of FL. As a demonstration, we evaluate FedNNNN on multiple datasets and scenarios with different neural network models, and observe accuracy improvements of up to 5%.
Keywords: Federated Learning, Distributed Learning, Deep Learning
1 Introduction
The use of neural networks (NNs) in applications such as image classification, natural language processing and speech recognition has become one of the core infrastructures of modern life. As the number of devices in the age of the internet of things (IoT) increases, a large amount of data can easily be collected to improve the accuracy of NN models. On the other hand, due to privacy and computational resource concerns, federated learning (FL) [10] has been attracting major attention recently. In contrast to conventional single-machine model training, the FL server aggregates models that are locally trained by clients, and FL can thus be considered a type of distributed learning [9, 18, 4, 3]. As model aggregation involves much less computation and communication bandwidth on the server compared to direct training, FL is a preferred strategy for providers of machine learning services.
Unfortunately, most existing FL frameworks are either unrealistic in their protocol construction or impractical in terms of their prediction accuracy. In particular, the model averaging technique proposed by federated averaging (FedAvg) [10] and adopted in most FL frameworks results in about 15% accuracy degradation over non-IID (not independent and identically distributed) datasets [19]. Notable improvements over FedAvg include [19, 14, 5, 16]. We defer a more detailed discussion of related works to Section 2.3, but point out that many existing works fail to improve the accuracy of FL across datasets. Those techniques that are successful in addressing the accuracy degradation problem generally rely on additional data sharing [19, 5, 16], which defeats the original purpose of FL in many practical applications.
In this paper, we propose FedNNNN, a norm-based neural network aggregation technique. In particular, we first propose norm-based weight divergence analysis (NWDA) to visualize and explain how FL proceeds, particularly under the FedAvg algorithm. We then introduce FedNNNN, and show that it improves the prediction accuracy of FL by up to 5%. The main contributions of this work are summarized as follows.
• NWDA and the Update Direction Divergence Problem: We point out that local weight updates in different (diverging) directions on the clients result in small update norms on the server, and we call this the weight updating direction divergence (WUDD) problem. WUDD is identified as the main reason for the slow convergence and degraded accuracy of FedAvg and the related works that build on it.
• Norm-Normalized Aggregation: We propose a normalization technique that targets the WUDD problem. In this technique, we apply a simple normalizing factor during model aggregation on the server, together with an additional momentum term, to force and accelerate the learning process over the communication rounds.
• Improved Accuracy with Negligible Overheads: Through rigorous experiments with the proposed technique, we observe accuracy improvements across datasets compared to state-of-the-art FL techniques, with extremely small computational overheads.
The rest of this paper is organized as follows. In Section 2, we explain FL and FedAvg in detail and discuss related works. In Section 3, we introduce our norm-based FL analysis method, NWDA, and in Section 4, our norm-based aggregation method, FedNNNN, is formulated. In Section 5, our method is evaluated experimentally on various datasets, and we conclude in Section 6.
2 Preliminaries and Related Works
2.1 Federated Learning
Federated Learning (FL) [10] is a distributed learning protocol that enables one to train a model with a massive amount of data obtained by IoT devices or smartphones without expensive training on a centralized server. Instead of collecting the data and training the model on a single machine, the FL server only combines models locally trained by edge devices to proceed with the learning process, and this combination procedure is known as model aggregation. The properties of FL have been extensively studied over the past few years [10, 7, 6, 2, 1, 17, 19, 14, 12, 15, 16, 11], with discussions on aspects of FL such as communication efficiency [7], adaptation to heterogeneous systems [12], performance over non-standard data distributions [19, 14, 5, 15, 16, 11], security properties [6, 2, 1], and more. In this work, we focus on improving the prediction accuracy of FL over non-IID distributions, which is one of the main problems associated with existing FL frameworks.
2.2 Federated Averaging
In this section, we outline the Federated Averaging (FedAvg) framework [10]. An overview is depicted in Figure 2.
[Figure: Overview of the FedAvg protocol.]
In FedAvg, we assume that the clients are in possession of the data, and a server aggregates the models received from the clients. Learning proceeds through the following steps.
• Step ➀: Let the total number of clients be $K$. The server first picks $m = \max(\lceil C \cdot K \rceil, 1)$ clients out of the $K$ total clients for some real number $C \in (0, 1]$. In the first round, the server initializes a model $w^0$ and distributes it to the selected clients. Otherwise, the server distributes the aggregated model $w^t$. We consider the model distribution to be the start of the $t$-th round of communication.
• Step ➁: Upon receiving the server model, each client locally trains its own model using $w^t$ as the initial model for $E$ epochs. The local model trained on the $k$-th client in the $t$-th communication round is referred to as $w_k^t$.
• Step ➂: Client $k$ returns its trained model $w_k^t$ to the server.
• Step ➃: Upon receiving the locally trained models from the clients, the server aggregates the models by simply taking the weighted average as

$w^{t+1} = \sum_{k \in S_t} \frac{n_k}{n} \, w_k^t$   (1)

where $S_t$ denotes the set of selected clients, $n_k$ is the size of the dataset on the $k$-th client, and $n = \sum_{k \in S_t} n_k$ (a minimal code sketch of this step is given after this list).
• Step ➄: The server evaluates the unified model $w^{t+1}$ with test data. Additional learning steps can be carried out by repeating steps ➀ to ➄.
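As referenced in step ➃, the following is a minimal sketch of the weighted averaging in Equation (1), written over PyTorch state dicts; the function name and the toy usage at the end are our own illustrative choices and not part of the FedAvg reference implementation.

```python
import torch

def fedavg_aggregate(client_models, client_sizes):
    """Weighted average of client models (Equation (1)).

    client_models: list of state_dicts (parameter name -> tensor)
    client_sizes:  list of local dataset sizes n_k
    """
    n = float(sum(client_sizes))
    aggregated = {}
    for name in client_models[0]:
        # sum_k (n_k / n) * w_k for every parameter tensor
        aggregated[name] = sum(
            (n_k / n) * m[name] for m, n_k in zip(client_models, client_sizes)
        )
    return aggregated

# Toy usage: two "clients", each holding a single 2-parameter weight vector.
clients = [{"w": torch.tensor([1.0, 0.0])}, {"w": torch.tensor([0.0, 1.0])}]
print(fedavg_aggregate(clients, [50, 50]))  # {'w': tensor([0.5, 0.5])}
```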
2.3 Improvements on FedAvg
As mentioned, the averaging aggregation utilized in FedAvg results in significant accuracy degradation during server model evaluation if the clients possess non-IID datasets. A line of works [19, 14, 5, 15, 16, 11] has been proposed to address this problem. Here we give a brief review of the existing methods.
The work in [19] is one of the first to point out that non-IID datasets result in drastically different local models, and that these models become difficult to aggregate with simple averaging. However, [19] only proposes to share auxiliary datasets with each client so that each client obtains a more IID dataset. [14] tries to adjust the loss function to improve upon [19], but our analysis shows that the improvements are not consistent across datasets. In Section 3.2, we take a deeper look at the exact reason for the degraded accuracy of FL over non-IID datasets.
Other optimization approaches include [5, 15, 16]. In [5], clients are grouped based on the label distribution of their datasets. In [15], the server aggregates the client models whenever a single weight update occurs on some particular client. These approaches clearly incur a large amount of communication between the server and the clients. FedCurv [16], which imposes a complex loss function on the clients, also induces communication overheads. In summary, existing works generally require additional datasets or complex communication protocols to directly solve the non-IID problem at the data level. In what follows, we introduce a quantitative analysis and a normalization-based technique to mitigate the impact of non-IID datasets without relying on auxiliary datasets or complex protocol modifications.
3 Aggregating Divergent Weights in Federated Averaging
In this section, we propose a norm-based weight divergence analysis (NWDA) technique to visualize how learning proceeds in FedAvg-based aggregation techniques over non-IID datasets.
3.1 Norm-based Weight Divergence Analysis
We start by defining a per-client update vector for the $t$-th communication round

$\Delta w_k^t = w_k^t - w^t$   (2)

where $w^t$ is the weight vector (i.e., the neural network model) distributed to each client at the start of the $t$-th communication round, and $w_k^t$ is the locally learned weight vector that client $k$ returns to the server in the $t$-th round. Hence, we can simply interpret the weight difference $\Delta w_k^t$ as the amount of learning performed in a single round of communication by client $k$.
Using Equation (2), we can reformulate the aggregation procedure of FedAvg in Equation (1) as

$w^{t+1} = w^t + \sum_{k \in S_t} \frac{n_k}{n} \, \Delta w_k^t$   (3)

In other words, since all clients share the same $w^t$, we can express the weight-averaging procedure in FedAvg as the sum of the distributed model and the weighted average of the local updates from each of the clients.
To quantitatively assess the impact of the update vectors, we define a pair of real scalars using the vector norm $\|\cdot\|$ as follows:

$D_t = \left\| w^{t+1} - w^t \right\|$   (4)

$U_t = \sum_{k \in S_t} \frac{n_k}{n} \left\| \Delta w_k^t \right\|$   (5)

Here, $D_t$ in Equation (4) is the distance that the server model moves from the $t$-th round to the $(t+1)$-th round, whereas $U_t$ in Equation (5) is the (weighted) average of the norms of the local weight update vectors across the clients. Consequently, we can think of $D_t$ as the amount of server model update, and $U_t$ as the (average) amount of local updates on the clients.
The following proposition expresses the relationship between $D_t$ and $U_t$.
Proposition 1
The following inequality holds

$D_t \le U_t$   (6)

for all $t \ge 0$.
As Proposition 1 follows from a recursive application of the triangle inequality, we leave a formal proof to the appendix. The main idea behind the proposition is that, while each client locally advances the learning process by an average of $U_t$, after model aggregation the server only advances by $D_t$, which is guaranteed to be at most $U_t$ by Proposition 1.
The learning behavior described above can be better illustrated through Figure 2. The important observation here is that, since each client learns locally without online communication, their update directions diverge. As the server averages these local models in FedAvg, by the formulation of Equation (3), the local updates tend to cancel each other out, resulting in an extremely small update vector on the server (i.e., $D_t \ll U_t$). We refer to this phenomenon as the weight updating direction divergence (WUDD) problem.
[Figure: Illustration of the weight updating direction divergence (WUDD) problem in FedAvg.]
As a demonstration of the NWDA technique, we show the calculated $D_t$ and $U_t$ values using the MNIST dataset. Figure 4 shows how the sizes of $D_t$ and $U_t$ change over the communication rounds (i.e., the learning epochs). Initially, the NN tries to learn the dataset through a series of weight updates, and we observe reasonably large update norms. Since MNIST is a small dataset, in the single-client case the learning soon converges, and both $D_t$ and $U_t$ become extremely small due to the learning convergence (note that $D_t = U_t$ in the single-client case). In contrast, while $U_t$ is large for non-IID FedAvg, due to the WUDD problem $D_t$ is relatively small throughout the communication rounds. We also emphasize that, as the learning proceeds, $D_t$ becomes convergent (extremely small) while $U_t$ remains large. This indicates that some learning updates are still available, but cannot be learned (or aggregated) by the server through the simple averaging technique in FedAvg.
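For concreteness, below is a minimal sketch of how the NWDA quantities in Equations (4) and (5) can be computed from PyTorch state dicts; flattening all parameters into a single vector and the function names are our own illustrative choices.

```python
import torch

def flatten(state_dict):
    # Concatenate all parameter tensors into one vector for norm computation.
    return torch.cat([p.reshape(-1).float() for p in state_dict.values()])

def nwda_norms(server_old, server_new, client_models, weights):
    """Return (D_t, U_t) as defined in Equations (4) and (5).

    weights: aggregation weights n_k / n (assumed to sum to 1).
    """
    w_old, w_new = flatten(server_old), flatten(server_new)
    d_t = torch.norm(w_new - w_old).item()                   # Eq. (4)
    u_t = sum(p * torch.norm(flatten(m) - w_old).item()      # Eq. (5)
              for m, p in zip(client_models, weights))
    return d_t, u_t
```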
To further study the norm properties of the single-client, IID, and non-IID FedAvg cases, we plot the cumulative amount of server updates with respect to the communication rounds, i.e., $\sum_{\tau \le t} D_\tau$, in Figure 4. The single-client case shows a clear trend of a fast and convergent learning curve, while both IID and non-IID FedAvg do not show clear signs of convergence across the communication rounds.
3.2 Applying NWDA in Existing Works
In Section 3.1, we analyzed the learning process of FL under FedAvg as a series of aggregations of the weight update vectors $\Delta w_k^t$. Based on our analysis, we can say that the average of the updates across clients (i.e., $\sum_{k \in S_t} \frac{n_k}{n}\,\Delta w_k^t$) obtains a reasonably large norm only when all clients update in the same direction. If the weight updates diverge (i.e., the WUDD problem), learning cannot be executed effectively and efficiently in a simple FedAvg setting.
While no existing works explicitly derive a norm analysis, some techniques [19, 5, 15, 16] try to implicitly address the WUDD problem. For example, as mentioned in Section 2.3, it is shown that if additional datasets (with different class labels) are distributed to the clients and learned along with the local datasets, the accuracy improves [19]. With NWDA, we can simply interpret this method as reducing the divergence between the update vectors of independent clients by introducing common datasets. Based on the observations from [19], FedProx is proposed in [14] to reduce the difference between client models. FedProx adds a regularization term to the loss function on each client as

$\ell_k(w) = F_k(w) + \frac{\mu}{2} \left\| w - w^t \right\|^2$   (7)

where $F_k(w)$ is the original loss and $\mu$ is a hyperparameter. The basic idea of FedProx is to penalize weight updates that largely increase the distance (i.e., the norm) between the locally learned model $w_k^t$ and the distributed model $w^t$. Essentially, by introducing a penalty term in the loss function, FedProx only reduces the size of $U_t$, and the diverging-direction problem is not addressed. However, $D_t$ is what actually matters to the evaluation accuracy of the aggregated model. As long as $D_t$ is much smaller than the amount of weight updates required to fully learn a dataset, the aggregated model cannot reach a reasonable level of accuracy.
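For reference, a minimal sketch of how the proximal term of Equation (7) could be added to a client's loss in PyTorch is shown below; the function and variable names are our own, and `global_params` is assumed to be a snapshot of the distributed model $w^t$.

```python
import torch

def proximal_penalty(local_model, global_params, mu):
    """(mu / 2) * ||w - w^t||^2, added to the original loss F_k(w)."""
    penalty = 0.0
    for p, g in zip(local_model.parameters(), global_params):
        penalty = penalty + torch.sum((p - g.detach()) ** 2)
    return 0.5 * mu * penalty

# Inside the client's training loop (sketch):
#   loss = criterion(local_model(x), y) + proximal_penalty(local_model, global_params, mu)
#   loss.backward(); optimizer.step()
```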
4 Norm-Normalized Neural Network Aggregation
In this section, we formally present the proposed norm-based aggregation (FedNNNN) technique that improves the prediction accuracy of FL without incurring large computational or communication overheads. We first introduce the slightly modified FL protocol adopted in FedNNNN, and then demonstrate that FedNNNN is more efficient and effective in improving the accuracy of FL, especially over non-IID datasets.
4.1 FedNNNN: The Protocol
[Figure: Overview of the FedNNNN protocol with separate average and normalized aggregation steps.]
The main difference between the FedNNNN protocol and the conventional FL protocol is that we separate the evaluation model from the aggregated model. As shown in Figure 6, after the clients locally produce their models, there are two aggregation steps.
• ➃ Average Aggregation: The same aggregation as in FedAvg is adopted to produce a single model capable of performing highly accurate inference based on the aggregated client models.
• ➅ Normalized Aggregation: The proposed normalized aggregation technique is used to resolve the WUDD problem, where $D_t \ll U_t$. More details on the complete FedNNNN aggregation are provided in Section 4.2.
All other steps remain the same as in the conventional FL protocol. Lastly, we note that the model produced by the proposed normalized aggregation only serves as a foundation for further learning. In fact, due to our modification of the norm of the weights (as seen in the next section), the prediction accuracy of the aggregated model from step ➅ is extremely poor.
4.2 FedNNNN: The Aggregation
As discussed in Section 3.2, according to our NWDA, the main reason that existing FL schemes fall short on non-IID datasets is that averaging models with divergent update directions significantly reduces the update norm of the aggregated model. Hence, in FedNNNN, we propose to normalize the update norm as

$\Delta \bar{w}^t = \alpha \cdot \frac{U_t}{D_t} \sum_{k \in S_t} \frac{n_k}{n} \, \Delta w_k^t$   (8)

Consequently, the norm of the normalized update vector becomes

$\left\| \Delta \bar{w}^t \right\| = \alpha \, U_t$   (9)
In order to control the impact of the norm normalization, we introduce the hyperparameter $\alpha$. In the experiments, we decide the value of $\alpha$ empirically.
A conceptual illustration of our normalized aggregation is depicted in Figure 6. We assume that the weighted average gives us a reasonably correct update direction, but with insufficient norm. We therefore amplify the norm of the update vector in that particular direction to the size of $U_t$. The reason that $U_t$ works as a proper normalization factor stems from our observation in Figure 2, where we see that the size of $U_t$ remains large after $D_t$ converges in the non-IID case. As a result, the amount of unfinished learning can be expressed through the size of $U_t$, and if $U_t$ approaches zero, we can safely conclude that no significant client update is available anymore. Note that when $D_t$ is extremely small, e.g., 0 or close to 0, we do not normalize the update vector and return $\sum_{k \in S_t} \frac{n_k}{n}\,\Delta w_k^t$ as is.
While the proposed normalized aggregation technique described above is successful in advancing the learning process in a given direction, we still find that, due to its democratic nature, FL is susceptible to local extrema. To avoid the weights getting stuck in a low-accuracy state, we add a momentum term to the aggregation process that is reminiscent of the momentum stochastic gradient descent (SGD) method [13]:
(10) | |||||
(11) |
Here, $\gamma$ is the same hyperparameter as in momentum SGD that expresses how much of the past weight changes is memorized. Equation (10) defines the $t$-th round momentum $m^t$ as an accumulation of the scaled average weight updates (note that $0 \le \gamma < 1$), and these weight updates accumulate as a driving force that prevents the learning from getting stuck in non-optimized states. Hence, the aggregation formula of our normalized aggregation with momentum adjustment becomes
(12) | |||||
(13) |
and we leave the complete description of the entire protocol to the appendix.
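To make the server-side procedure concrete, the following is a minimal sketch of the normalized aggregation with momentum described above, operating on flattened weight vectors; the function signature, the epsilon guard, and the handling of aggregation weights are our own illustrative choices rather than the authors' reference implementation.

```python
import torch

def fednnnn_aggregate(w_server, client_deltas, weights, momentum, alpha, gamma, eps=1e-12):
    """Server-side normalized aggregation with momentum (sketch).

    w_server:      flattened server weights w^t
    client_deltas: list of flattened client updates Delta w_k^t
    weights:       aggregation weights n_k / n
    momentum:      accumulated momentum m^{t-1} (same shape as w_server)
    """
    avg_update = sum(p * d for d, p in zip(client_deltas, weights))        # update term of Eq. (3)
    d_t = torch.norm(avg_update)                                           # D_t
    u_t = sum(p * torch.norm(d) for d, p in zip(client_deltas, weights))   # U_t
    if d_t > eps:
        normalized = alpha * (u_t / d_t) * avg_update   # rescale to norm alpha * U_t, Eq. (8)
    else:
        normalized = avg_update                         # skip normalization when D_t is ~0
    momentum = gamma * momentum + normalized            # momentum accumulation, Eq. (10)
    w_next = w_server + momentum                        # model distributed for the next round
    return w_next, momentum
```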
5 Experiments
In this section, we demonstrate the effectiveness of the proposed FedNNNN through experiments on two image datasets, MNIST and CIFAR10 [8], and four different FL data distribution settings. According to the size of each dataset, different CNN models have been implemented using the PyTorch framework. In this work, we compare our technique to FedAvg and FedProx, as they have similar computation and communication characteristics to FedNNNN.
5.1 Settings
We use the MNIST dataset, which contains 10 classes of grey-scale handwritten digit images, and the CIFAR10 dataset, which classifies RGB images into 10 classes. The parameters used in the following experiments are summarized in Table 1, where the notations are those commonly used in most FL methods: Total Rounds represents the total number of communication rounds between the server and the clients, and $\eta$ and $\lambda$ are the learning rate and weight decay parameters, respectively. We used SGD without momentum for optimization. Due to limited space, the exact values of $\mu$, $\alpha$, and $\gamma$ used in each FL method are listed in the appendix. We split each dataset into training and test data, where the training data are assumed to be inherently held by the clients and the test data are evaluated on the server side. The hyperparameters $\eta$, $\mu$, $\alpha$, and $\gamma$ were chosen to give the highest accuracy on the test data over the training rounds.
Here, we outline the architectures of the CNN models used in this section. For MNIST, we used a CNN with 2 convolutional layers followed by 2 fully connected layers. For CIFAR10, a CNN with 3 fully connected layers following 6 convolutional layers with batch normalization is used. The convolutional and fully connected layers of both the MNIST and the CIFAR10 architectures have bias terms. In Table 1, BN indicates whether batch normalization is used. As a separate preprocessing step, we normalized the input images by subtracting the mean and dividing by the standard deviation.
Parameters | Total Rounds | C | K | B | E | $\eta$ | $\lambda$ | BN
---|---|---|---|---|---|---|---|---
MNIST | 100 | 1 | 100 | 50 | 5 | 0.05 | 0 | no
CIFAR10 | 250 | 1 | 100 | 50 | 5 | 0.05 | | yes
Method | IID-B MNIST | IID-B CIFAR10 | non-IID-B MNIST | non-IID-B CIFAR10 | IID-UB MNIST | IID-UB CIFAR10 | non-IID-UB MNIST | non-IID-UB CIFAR10
---|---|---|---|---|---|---|---|---
FedAvg | 99.0 | 81.3 | 98.2 | 72.6 | 99.1 | 83.5 | 93.5 | 56.6
FedProx | 99.0 | 81.4 | 98.1 | 74.2 | 99.1 | 83.2 | 87.9 | 55.6
Norm-Norm | 99.1 | 81.2 | 99.0 | 76.7 | 99.2 | 83.3 | 97.1 | 55.5
Momentum | 99.1 | 83.1 | 99.2 | 74.3 | 99.1 | 84.2 | 96.9 | 62.0
FedNNNN | 99.2 | 84.7 | 99.1 | 76.7 | 99.2 | 84.0 | 98.9 | 61.9
[Figure: Accuracy and update norms ($U_t$, $D_t$) over the communication rounds (non-IID-B CIFAR10).]
In this work, we take the data distribution into account when performing the experiments. There are four types of data distribution: i) IID, ii) non-IID, iii) balanced, and iv) unbalanced. For i) IID, clients possess images of all of the classes defined in the dataset. In contrast, for ii) non-IID, clients have images of only 2 classes (the minimum number of classes that allows meaningful learning). For the iii) balanced condition, all clients have datasets of the same size. On the other hand, the iv) unbalanced condition assumes that the dataset sizes of the clients follow a power-law distribution, where a small number of clients hold most of the training images. For privacy reasons, however, we assume that the server does not know the data sizes of the clients and set the aggregation weight $n_k/n = 1/m$ for all $k$.
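For concreteness, the sketch below shows one common shard-based way to realize the non-IID condition (each client receives samples drawn from only 2 classes); this is an illustrative assumption and not necessarily the authors' exact partitioning procedure.

```python
import numpy as np

def non_iid_split(labels, num_clients, classes_per_client=2, seed=0):
    """Assign each client shards drawn from `classes_per_client` classes only."""
    rng = np.random.default_rng(seed)
    # One shard per (client, class) slot; each shard covers roughly one class.
    num_shards = num_clients * classes_per_client
    order = np.argsort(labels)                 # group sample indices by class
    shards = np.array_split(order, num_shards)
    shard_ids = rng.permutation(num_shards)
    return [np.concatenate([shards[s] for s in
                            shard_ids[c * classes_per_client:(c + 1) * classes_per_client]])
            for c in range(num_clients)]
```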
Experiments are conducted once for each dataset. For our parameter setting, the simulation over all communication rounds took about an hour for MNIST and 3.5 hours for CIFAR10, using a GeForce GTX 1080 Ti.
5.2 Experiment Results
We evaluated five FL aggregation methods on the two datasets with the aforementioned four types of dataset distribution. The accuracy comparisons between the various experimental conditions are summarized in Table 2. In the table, the Norm-Norm method uses Equation (8) only, and the Momentum method uses Equations (10) and (11) only. B and UB indicate the iii) balanced and iv) unbalanced conditions, respectively.
Regardless of the datasets and distributions, the proposed methods achieve the best prediction accuracy. Although momentum alone appears to be relatively strong, the accuracy can be further improved by combining it with norm normalization. The changes in accuracy as a function of the communication rounds are shown in Figure 9 for non-IID-B CIFAR10. We find that FedNNNN significantly improves the convergence rate, thus speeding up the learning process.
We expect the amount of server update in FedNNNN to be larger than that of FedAvg, as discussed in Section 4.2. Figure 9 confirms this expectation. In these figures, $U_t$ is the average client update norm and $D_t$ is the update norm on the server side. We can clearly observe that the amount of update on the server in FedNNNN is larger than that in FedAvg and FedProx. In Figure 9, we see that $U_t$ of FedNNNN is smaller than that of FedAvg and FedProx, indicating that the weight updates of the clients have become smaller because of better initial models. However, in Figure 9, $U_t$ of the convolutional layers in FedNNNN appears to be oscillating. This indicates that the weights of the convolutional layers are less likely to converge compared to the fully connected layers. A thorough evaluation of $U_t$ and $D_t$ for the fully connected and convolutional layers under all settings appears in the appendix.
6 Conclusion
In this paper, we introduced FedNNNN to improve the convergence speed and prediction accuracy of FL. We first defined a norm-based analysis that expresses the amount of model update by the norm of the sum of the update vectors, and identified that the small size of this norm causes slow convergence and accuracy degradation when clients hold non-IID datasets (the WUDD problem). To solve the WUDD problem, we proposed FedNNNN, an aggregation technique that normalizes the update vector according to the amount of unfinished learning. In the experiments, we observed that FedNNNN outperforms FedAvg and FedProx, state-of-the-art FL frameworks, in terms of both convergence speed and prediction accuracy. In particular, we achieve up to 5% accuracy improvement over FedAvg and FedProx on the CIFAR10 dataset.
References
- [1] Bagdasaryan, E., Veit, A., Hua, Y., Estrin, D., Shmatikov, V.: How to backdoor federated learning. CoRR abs/1807.00459 (2018), http://arxiv.org/abs/1807.00459
- [2] Bonawitz, K., Ivanov, V., Kreuter, B., Marcedone, A., McMahan, H.B., Patel, S., Ramage, D., Segal, A., Seth, K.: Practical secure aggregation for privacy-preserving machine learning. In: Proceedings of the 2017 ACM SIGSAC Conference on Computer and Communications Security. pp. 1175–1191. CCS ’17, ACM (2017)
- [3] Chilimbi, T., Suzue, Y., Apacible, J., Kalyanaraman, K.: Project adam: Building an efficient and scalable deep learning training system. In: 11th USENIX Symposium on Operating Systems Design and Implementation (OSDI 14). pp. 571–582. USENIX Association (2014)
- [4] Dean, J., Corrado, G., Monga, R., Chen, K., Devin, M., Mao, M., aurelio Ranzato, M., Senior, A., Tucker, P., Yang, K., Le, Q.V., Ng, A.Y.: Large scale distributed deep networks. In: Pereira, F., Burges, C.J.C., Bottou, L., Weinberger, K.Q. (eds.) Advances in Neural Information Processing Systems 25, pp. 1223–1231. Curran Associates, Inc. (2012)
- [5] Duan, M.: Astraea: Self-balancing federated learning for improving classification accuracy of mobile deep learning applications. CoRR abs/1907.01132 (2019), http://arxiv.org/abs/1907.01132
- [6] Geyer, R.C., Klein, T., Nabi, M.: Differentially private federated learning: A client level perspective. CoRR abs/1712.07557 (2017), http://arxiv.org/abs/1712.07557
- [7] Konecný, J., McMahan, H.B., Yu, F.X., Richtárik, P., Suresh, A.T., Bacon, D.: Federated learning: Strategies for improving communication efficiency. CoRR abs/1610.05492 (2016), http://arxiv.org/abs/1610.05492
- [8] Krizhevsky, A.: Learning multiple layers of features from tiny images. Tech. rep. (2009)
- [9] Ma, C., Konečný, J., Jaggi, M., Smith, V., Jordan, M.I., Richtárik, P., Takáč, M.: Distributed optimization with arbitrary local solvers. Optimization Methods and Software 32(4), 813–848 (2017). https://doi.org/10.1080/10556788.2016.1278445
- [10] McMahan, H.B., Moore, E., Ramage, D., Hampson, S., y Arcas, B.A.: Communication-efficient learning of deep networks from decentralized data. In: AISTATS (2016)
- [11] Mohri, M., Sivek, G., Suresh, A.T.: Agnostic federated learning. In: Chaudhuri, K., Salakhutdinov, R. (eds.) Proceedings of the 36th International Conference on Machine Learning. Proceedings of Machine Learning Research, vol. 97, pp. 4615–4625. PMLR (2019)
- [12] Nishio, T., Yonetani, R.: Client selection for federated learning with heterogeneous resources in mobile edge. In: ICC 2019 - 2019 IEEE International Conference on Communications (ICC). pp. 1–7 (2019)
- [13] Qian, N.: On the momentum term in gradient descent learning algorithms. Neural Networks 12(1), 145 – 151 (1999)
- [14] Sahu, A.K., Li, T., Sanjabi, M., Zaheer, M., Talwalkar, A., Smith, V.: On the convergence of federated optimization in heterogeneous networks. CoRR abs/1812.06127 (2018), http://arxiv.org/abs/1812.06127
- [15] Sattler, F., Wiedemann, S., Müller, K., Samek, W.: Robust and communication-efficient federated learning from non-iid data. CoRR abs/1903.02891 (2019), http://arxiv.org/abs/1903.02891
- [16] Shoham, N., Avidor, T., Keren, A., Israel, N., Benditkis, D., Mor-Yosef, L., Zeitak, I.: Overcoming forgetting in federated learning on non-iid data (2019)
- [17] Wang, S., Tuor, T., Salonidis, T., Leung, K.K., Makaya, C., He, T., Chan, K.: When edge meets learning: Adaptive control for resource-constrained distributed machine learning. IEEE INFOCOM 2018 - IEEE Conference on Computer Communications pp. 63–71 (2018)
- [18] Zhang, Y., Lin, X.: Disco: Distributed optimization for self-concordant empirical loss. In: Bach, F., Blei, D. (eds.) Proceedings of the 32nd International Conference on Machine Learning. Proceedings of Machine Learning Research, vol. 37, pp. 362–370. PMLR (2015)
- [19] Zhao, Y., Li, M., Lai, L., Suda, N., Civin, D., Chandra, V.: Federated learning with non-iid data. CoRR abs/1806.00582 (2018), http://arxiv.org/abs/1806.00582
Appendix
Complete Protocol of FedAvg
Here, we show the complete protocol of FedAvg. The server and the clients operate as shown in Algorithm 1 and Algorithm 2, respectively. The server executes Algorithm 1 as follows.
• Line 1: The server initializes the model $w^0$ and distributes it to each client.
• Line 2: For each round $t$, the following steps are repeated.
• Line 3–4: The number of clients to train in this round, $m = \max(\lceil C \cdot K \rceil, 1)$, is calculated, and the selected set of clients $S_t$ is generated.
• Line 5: For each client $k \in S_t$:
• Line 6: The server sends $w^t$ to the clients.
• Line 7: After the clients train their local models, the server receives the trained model $w_k^t$ from each client $k$.
• Line 9: The server aggregates the models by averaging them as $w^{t+1} = \sum_{k \in S_t} \frac{n_k}{n} \, w_k^t$.
On the client side, the following steps are implemented.
• Input: The client receives the server model $w^t$ and its client index $k$.
• Line 1: The client first obtains its local dataset (note that this dataset is not from the server). The client splits the entire dataset into mini-batches of size $B$. We denote the set of mini-batches as $\mathcal{B}$.
• Line 2: For each epoch from 1 to $E$:
• Line 3: For each mini-batch $b \in \mathcal{B}$:
• Line 4: The client updates the weights as $w \leftarrow w - \eta \nabla \ell(w; b)$.
• Line 7: The client sends the trained weights $w_k^t$ to the server.
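The client-side procedure above is standard local mini-batch SGD; a minimal PyTorch sketch is given below, where the model, data loader, and hyperparameter names are our own illustrative placeholders.

```python
import torch

def client_update(model, loader, epochs, lr, weight_decay=0.0):
    """Local training for E epochs with mini-batch SGD (Algorithm 2 sketch)."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay)
    model.train()
    for _ in range(epochs):                 # Line 2: for each local epoch
        for x, y in loader:                 # Line 3: for each mini-batch
            optimizer.zero_grad()
            loss = criterion(model(x), y)   # Line 4: w <- w - eta * grad
            loss.backward()
            optimizer.step()
    return model.state_dict()               # Line 7: return the trained weights
```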
Complete Protocol of FedNNNN
The protocol of FedNNNN is similar to the FedAvg protocol. First, we note that the client-side procedure remains unchanged, i.e., it is exactly the same as Algorithm 2. The server-side protocol is also similar, differing only in lines 9–12 of Algorithm 3.
• Line 9–10: $U_t$ and the normalized update vector $\Delta \bar{w}^t$ are calculated.
• Line 11–12: The server adds the update vector to the momentum as $m^t = \gamma\, m^{t-1} + \Delta \bar{w}^t$ and updates the weights as $w^{t+1} = w^t + m^t$.
Proof of Proposition 1
In this section, we give a formal proof of Proposition 1. Before delving into the proof, we first outline an important lemma.
Lemma 1
Let $m$ be a positive integer. The following inequality holds

$\left\| \sum_{i=1}^{m} x_i \right\| \le \sum_{i=1}^{m} \left\| x_i \right\|$   (14)

for all $m \ge 1$ and any real vectors $x_1, \dots, x_m$.
Proof
We prove Lemma 1 by induction on $m$.

Base case: Let $m = 1$. Then, Equation (14) becomes

$\left\| x_1 \right\| \le \left\| x_1 \right\|$   (15)

and the equality obviously holds for any $x_1$.

Inductive case: Assume that Equation (14) holds for $m = j$. For $m = j + 1$, the LHS of Equation (14) becomes

$\left\| \sum_{i=1}^{j+1} x_i \right\| = \left\| \sum_{i=1}^{j} x_i + x_{j+1} \right\|$   (16)

The triangle inequality states that, for any two real vectors $a$ and $b$, the following inequality holds:

$\left\| a + b \right\| \le \left\| a \right\| + \left\| b \right\|$   (17)

Then, we can derive the following inequality:

$\left\| \sum_{i=1}^{j} x_i + x_{j+1} \right\| \le \left\| \sum_{i=1}^{j} x_i \right\| + \left\| x_{j+1} \right\|$   (18)

Since Equation (14) holds for $m = j$ as assumed, Equation (18) becomes

$\left\| \sum_{i=1}^{j} x_i \right\| + \left\| x_{j+1} \right\| \le \sum_{i=1}^{j} \left\| x_i \right\| + \left\| x_{j+1} \right\|$   (19)

Since it is also obvious that

$\sum_{i=1}^{j} \left\| x_i \right\| + \left\| x_{j+1} \right\| = \sum_{i=1}^{j+1} \left\| x_i \right\|$   (20)

we consequently know that

$\left\| \sum_{i=1}^{j+1} x_i \right\| \le \sum_{i=1}^{j+1} \left\| x_i \right\|$   (21)

and the lemma follows.
Proposition 1 (restated)
The following inequality holds

$D_t \le U_t$   (22)

for all $t \ge 0$.
Proof
Since $D_t = \left\| \sum_{k \in S_t} \frac{n_k}{n}\, \Delta w_k^t \right\|$ and $U_t = \sum_{k \in S_t} \frac{n_k}{n} \left\| \Delta w_k^t \right\|$ as defined, let $x_k = \frac{n_k}{n}\, \Delta w_k^t$. Then, by Lemma 1, we know that

$D_t = \left\| \sum_{k \in S_t} x_k \right\| \le \sum_{k \in S_t} \left\| x_k \right\| = U_t$   (23)

and the proposition follows.
Experiments
Models
Table 3 and Table 4 show the models used for the MNIST and CIFAR10 experiments, respectively. We adopted the cross-entropy loss function in all models.
Layer | Input | Output | Kernel | Stride |
---|---|---|---|---|
Input | - | - | - | |
Conv1 | 5 | 1 | ||
ReLU | - | - | - | - |
Maxpool | 2 | 2 | ||
Conv2 | 5 | 1 | ||
ReLU | - | - | - | - |
Maxpool | 2 | 2 | ||
Fc1 | - | - | ||
ReLU | - | - | - | - |
Fc2 | - | - |
Layer | Input | Output | Kernel | Stride |
---|---|---|---|---|
Input | - | - | - | |
Conv11 | 3 | 1 | ||
BN+ReLU | - | - | - | - |
Conv12 | 3 | 1 | ||
BN+ReLU | - | - | - | - |
Maxpool | 2 | 2 | ||
Conv21 | 3 | 1 | ||
BN+ReLU | - | - | - | - |
Conv22 | 3 | 1 | ||
BN+ReLU | - | - | - | - |
Maxpool | 2 | 2 | ||
Conv31 | 3 | 1 | ||
BN+ReLU | - | - | - | - |
Conv32 | 3 | 1 | ||
BN+ReLU | - | - | - | - |
Maxpool | 2 | 2 | ||
Fc1 | - | - | ||
ReLU | - | - | - | - |
Fc2 | - | - | ||
ReLU | - | - | - | - |
Fc3 | - | - |
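For reference, a PyTorch sketch of the CIFAR10 architecture in Table 4 (6 convolutional layers with batch normalization followed by 3 fully connected layers) is given below; since the exact channel widths, hidden sizes, and padding are not preserved in the table, the values used here are illustrative assumptions only.

```python
import torch.nn as nn

class Cifar10CNN(nn.Module):
    """6 conv layers with BN + 3 FC layers (Table 4 sketch; widths are assumed)."""
    def __init__(self, num_classes=10):
        super().__init__()
        def block(cin, cout):
            # 3x3 conv (stride 1, assumed padding 1) followed by BN and ReLU
            return [nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1, bias=True),
                    nn.BatchNorm2d(cout), nn.ReLU(inplace=True)]
        self.features = nn.Sequential(
            *block(3, 64),    *block(64, 64),    nn.MaxPool2d(2, 2),   # Conv11, Conv12
            *block(64, 128),  *block(128, 128),  nn.MaxPool2d(2, 2),   # Conv21, Conv22
            *block(128, 256), *block(256, 256),  nn.MaxPool2d(2, 2),   # Conv31, Conv32
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * 4 * 4, 512), nn.ReLU(inplace=True),        # Fc1
            nn.Linear(512, 512), nn.ReLU(inplace=True),                # Fc2
            nn.Linear(512, num_classes),                               # Fc3
        )

    def forward(self, x):
        return self.classifier(self.features(x))
```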
Determination of Hyperparameters
In this section, we explain how we determined the parameters used in the experiments. The following heuristic procedure was adopted for deciding the hyperparameters.
- Step 1: First of all, divide the training dataset into two datasets. One dataset is used to train the neural network model, while the other, referred to as the validation dataset, is used to adjust the hyperparameters.
- Step 2: Decide the learning rate $\eta$. We performed several epochs of training (30 epochs on CIFAR10 and 50 epochs on MNIST) under the non-IID and B conditions with each of the candidate values of $\eta$, and the one with the best test accuracy was used in all methods and conditions. For MNIST, we chose $\eta$ from the range [0.02, 0.15] in steps of 0.01, and for CIFAR10 from the range [0.01, 0.11] in steps of 0.02.
- Step 3: While fixing $\eta$ at the best value from Step 2, for each dataset and condition, we varied $\mu$, $\alpha$, and $\gamma$ independently and chose the values with the best test accuracy (a schematic sketch of this procedure is given after this list).
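The selection procedure in Steps 2–3 amounts to a simple sequential grid search on a validation split. The sketch below illustrates the idea; the `evaluate` callback is a placeholder for a short training run (it is not part of the original code), and only the candidate ranges follow the text.

```python
import numpy as np

def sequential_grid_search(evaluate, lr_candidates, hp_candidates):
    """Step 2: pick the learning rate; Step 3: pick the method hyperparameter.

    evaluate(lr, hp) is assumed to return the validation accuracy obtained
    after a short training run with the given learning rate and hyperparameter.
    """
    best_lr = max(lr_candidates, key=lambda lr: evaluate(lr, None))     # Step 2
    best_hp = max(hp_candidates, key=lambda hp: evaluate(best_lr, hp))  # Step 3
    return best_lr, best_hp

# Candidate learning-rate grids stated in the text.
mnist_lr_grid = np.round(np.arange(0.02, 0.15 + 1e-9, 0.01), 2)
cifar_lr_grid = np.round(np.arange(0.01, 0.11 + 1e-9, 0.02), 2)
```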
Table 5 lists all the parameters used for each FL method under each condition. We chose $\mu$, $\alpha$, and $\gamma$ as well as the learning rate $\eta$. In FedNNNN, we varied $\alpha$ and $\gamma$ independently and selected the combination that gave the best accuracy.
Method (parameter) | IID-B MNIST | IID-B CIFAR10 | NonIID-B MNIST | NonIID-B CIFAR10 | IID-UB MNIST | IID-UB CIFAR10 | NonIID-UB MNIST | NonIID-UB CIFAR10
---|---|---|---|---|---|---|---|---
FedProx ($\mu$) | 0.005 | 0.015 | 0.015 | 0.015 | 0.005 | 0.005 | 0.02 | 0.01
Norm-Norm ($\alpha$) | 1.1 | 0.6 | 1.0 | 0.6 | 1.0 | 0.7 | 0.9 | 0.7
Momentum ($\gamma$) | 0.8 | 0.9 | 0.9 | 0.9 | 0.7 | 0.9 | 0.8 | 0.8
FedNNNN ($\alpha$) | 0.6 | 0.7 | 0.7 | 0.6 | 0.7 | 0.8 | 0.7 | 0.7
FedNNNN ($\gamma$) | 0.7 | 0.8 | 0.8 | 0.7 | 0.7 | 0.8 | 0.8 | 0.6
Experimental Results
Here, we show the accuracy curves and update norms on MNIST and CIFAR10. Figure 21 shows the accuracy and update norms on MNIST, and Figure 33 shows those on CIFAR10.
[Figures: Accuracy curves and update norms ($U_t$, $D_t$) over the communication rounds on MNIST and CIFAR10 for all conditions.]
We plot the following norms in the figures:

$U_t = \sum_{k \in S_t} \frac{n_k}{n} \left\| \Delta w_k^t \right\|$   (24)

$D_t = \left\| w^{t+1} - w^t \right\|$   (25)
The implications of these norm plots can be summarized as follows:
- Improvement of convergence rate and accuracy: The convergence of the proposed method is faster than that of the existing methods. The accuracy of the final result is also improved.
- Trend of the $D_t$ norm: The $D_t$ norm tends to be larger than that of the existing methods, especially in the early stages of the communication rounds. This means that larger weight updates are achieved in each round than with the existing methods.
- Trend of the $U_t$ norm: On the other hand, the $U_t$ norm of the proposed method tends to be smaller than that of the existing methods. This is due to the fact that the clients are gradually learning towards convergence. This trend is more pronounced in the fully connected layers, but not in the convolutional layers.
- Convolutional layers on CIFAR10: Observing the $U_t$ norm of the convolutional layers on CIFAR10, the norm of the proposed method appears to oscillate. We believe that these weights are harder to converge than those of the fully connected layers.