Personalized Interpretation on Federated Learning: A Virtual Concepts approach
Abstract
Tackling non-IID data is an open challenge in federated learning (FL) research. Existing FL methods, including robust FL and personalized FL, are designed to improve model performance without considering how to interpret the non-IID data across clients. This paper aims to design a novel FL method that is robust to, and can interpret, the non-IID data across clients. Specifically, we interpret each client’s dataset as a mixture of conceptual vectors, each of which represents an interpretable concept to end-users. These conceptual vectors could be pre-defined, refined in a human-in-the-loop process, or learnt via the optimization procedure of the federated learning system. In addition to the interpretability, the clarity of client-specific personalization could also be applied to enhance the robustness of the training process of the FL system. The effectiveness of the proposed method has been validated on benchmark datasets.
Motivation
A critical challenge in personalised federated learning (PerFL) is the absence of well-defined concepts of personalisation. Client preferences and personalised properties are implicit in the training data and remain private to each client. They could be a client’s preference for specific classes or a specific noise mixed into the input features. The only tangible information is the shift in data distribution across clients.
Meanwhile, most machine learning models, e.g., DNNs, are trained in an end-to-end paradigm. They are optimised by back-propagating supervised information, e.g., classification loss, from the output layer to the input layer. Personalisation is performed only indirectly when a model is tuned for tasks like classification. This learning scheme is less efficient in PerFL. On-device training tends to overfit a client’s local data because training samples are limited and unbalanced. The aggregation step on the server, in turn, neutralises personalised information when synthesising the global model, e.g., by averaging local updates.
However, although there is no supervised information about well-defined client properties, one feature distinguishes model personalisation from unsupervised tasks: data in PerFL are explicitly partitioned. Samples from the same client demonstrate a client-specific bias toward certain properties. One may therefore assume that there are invisible client labels inducing the on-device training to progress toward a client’s preferences, i.e., personalisation. The client-based data partition essentially supervises PerFL’s training process, so this research calls the learning paradigm Client-Supervised Learning.
Based on the thought above, this research introduces Virtual Concepts (VC) to explicate client-supervised information. The VCs are representations of potential structure information extracted from training data. They can be learned independently of downstream classification tasks by a novel FedVC algorithm, which facilitates understanding client properties and boosts model personalisation.
Specifically, FedVC assumes that there is a set of vectors (virtual concepts), each describing a type of client property. A client’s preferences are then represented by a combination of VCs, which will be utilised as supervised information to guide the training progress of the global model. Figure 1 illustrates the proposed FedVC.
![Refer to caption](x1.png)
To learn the VCs, FedVC evaluates the underlying distribution structure in data by formulating the learning task as a Gaussian Mixture Model (GMM) that can be solved by most unsupervised learning methods, e.g., the Expectation-Maximisation (EM) algorithm.
Experiments on real-world datasets show that the VCs can work as supervised information to train a global model that is robust to changing distributions. Further study demonstrates that the VCs are useful in exploring meaningful client properties by discovering distribution structures implied in training data.
The main contributions are summarised as follows:
- The research proposes virtual concepts describing client preferences. The VCs are representations of distribution structure extracted from training data. They provide us with a way to explore meaningful client properties relevant to model personalisation.
- The research proposes a novel client-supervised PerFL framework that utilises virtual concept vectors as supervised information to train the global model. The VCs allow an FL algorithm to simultaneously learn class and client knowledge so that the learned global model can achieve on-deployment personalisation, where the global model does not require an extra fine-tuning process at the test stage.
- The research formulates the learning task of VCs as a Gaussian Mixture Model that most unsupervised learning methods can solve. The proposed FedVC framework is compatible with most FL methods, into which it can be integrated as an add-on to improve personalisation performance and model interpretability.
- Comparison with baseline methods shows that FL models trained with VCs can simultaneously learn class and client knowledge. They achieve competitive personalisation performance without requiring extra fine-tuning steps or personal parameters.
- Empirical studies show that VCs can discover meaningful distribution structures implied in training data, facilitating the uncovering of client properties related to model personalisation.
Methodology
Client-supervised PerFL
Let denote virtual concept vectors; a client’s preference is then represented by , where is the client index, and is a factor measuring the degree to which the client is relevant to , i.e., how typical the property of is for the client. FedVC aims to utilise as supervised information to guide FL’s learning process so that the global model can learn client knowledge explicitly.
![Refer to caption](x2.png)
Specifically, FedVC adds a projection head to FL’s global model to extract a representation of potential client properties (see Figure 2). One can evaluate a sample’s relevance to each concept by a similarity function, e.g., Equation 1, and derive an estimated client preference , where is the sample index and is a hyperparameter.
(1) |
Then, there will be a supervised loss regarding client preferences, i.e., . It can be integrated into any FL framework and solved by gradient-based methods. Details of the learning algorithm are in Algorithm 1.
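As a rough illustration of the kind of similarity function described above, the following sketch scores a projected representation against each virtual concept and normalises the scores into an estimated preference. The exact form of Equation 1 is not reproduced here: the softmax over negative squared Euclidean distances, the cross-entropy loss, and the names `estimate_preference`, `preference_loss`, `z`, `concepts` and `tau` are all assumptions made for illustration.

```python
import math

def estimate_preference(z, concepts, tau=0.1):
    """Soft assignment of a representation z to each virtual concept.

    Assumption: similarity is a softmax over negative squared Euclidean
    distances, scaled by a hyperparameter tau (hypothetical name).
    """
    scores = [-tau * sum((a - b) ** 2 for a, b in zip(z, v)) for v in concepts]
    top = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def preference_loss(estimated, target, eps=1e-12):
    """Cross-entropy between a client's preference and the estimate (assumed loss)."""
    return -sum(t * math.log(p + eps) for t, p in zip(target, estimated))

# Toy example: three 2-d concepts and one projected sample.
concepts = [[0.0, 0.0], [4.0, 4.0], [-4.0, 4.0]]
z = [3.5, 4.2]
p_hat = estimate_preference(z, concepts, tau=0.1)
loss = preference_loss(p_hat, target=[0.0, 1.0, 0.0])
```

In this toy case the sample lies closest to the second concept, so most of the estimated preference mass falls on it; a larger `tau` would sharpen the assignment further.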
Virtual Concepts
As virtual concepts correspond to client properties, a sample is then assumed to be generated by some random process involving a mixture of multiple client properties. FedVC formulates the assumption into a Gaussian Mixture Model (GMM). For any sample on the -th client, there is
(2) |
where the covariance is set to be the identity matrix for simplicity.
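To make the identity-covariance assumption concrete, the sketch below evaluates a generic Gaussian mixture with identity covariance and the posterior responsibility of each component for one sample. It is a textbook GMM computation rather than a transcription of Equation 2; the function names and the toy weights and means are hypothetical.

```python
import math

def log_gaussian_identity(x, mean):
    """Log density of N(mean, I) evaluated at x."""
    d = len(x)
    sq = sum((a - b) ** 2 for a, b in zip(x, mean))
    return -0.5 * (d * math.log(2.0 * math.pi) + sq)

def mixture_log_likelihood(x, weights, means):
    """log sum_k weights[k] * N(x | means[k], I), computed via log-sum-exp."""
    terms = [math.log(w) + log_gaussian_identity(x, m)
             for w, m in zip(weights, means) if w > 0]
    top = max(terms)
    return top + math.log(sum(math.exp(t - top) for t in terms))

def responsibilities(x, weights, means):
    """Posterior probability that x was generated by each component."""
    total = mixture_log_likelihood(x, weights, means)
    return [math.exp(math.log(w) + log_gaussian_identity(x, m) - total) if w > 0 else 0.0
            for w, m in zip(weights, means)]

# Toy client: mostly the first concept, with some of the second.
weights = [0.7, 0.3]
means = [[0.0, 0.0], [3.0, 3.0]]
resp = responsibilities([0.1, -0.2], weights, means)
```

With identity covariance, the responsibilities depend only on the mixture weights and the squared distances to the component means, which is what makes a distance-based similarity a natural choice for estimating preferences.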
Let denote the set of VCs and denote the set of client preferences; the collaborative learning task for and is formulated as
(3) |
FedVC solves it by the EM framework below:
- E-step: Given and , clients estimate local samples’ by Equation 1
- M-step: update the virtual concepts and client preferences by Equation 4 and Equation 5
However, Equation 4 and Equation 5 cannot be applied directly when working with minibatches in FL settings. FedVC uses exponential moving averages as an alternative:
(6) |
(7) |
(8) |
where denotes a minibatch of samples, denotes the batch size, and is a smoothing hyperparameter between 0 and 1. Then,
(9) |
(10) |
The overall learning algorithm is described in Algorithm 1.
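The minibatch updates described above can be sketched as a generic exponential-moving-average M-step. This is not a transcription of Equations 6 to 10: the running statistics (`counts`, `sums`), the function names, and in particular the convention that the smoothing hyperparameter `gamma` weights the old estimate are assumptions made for illustration.

```python
def ema_update(old, new, gamma):
    """Blend a running estimate with a fresh minibatch estimate.

    Assumption: gamma weights the OLD value, so a small gamma favours
    the newly estimated statistics.
    """
    return gamma * old + (1.0 - gamma) * new

def minibatch_m_step(counts, sums, batch, resp, gamma=0.05):
    """Update running statistics in place; return concepts and mixture weights.

    counts[k]: running soft count of samples assigned to concept k
    sums[k]:   running responsibility-weighted sum of representations
    batch:     list of representation vectors in the minibatch
    resp:      resp[n][k], responsibility of concept k for sample n
    """
    dim = len(batch[0])
    for k in range(len(counts)):
        n_k = sum(r[k] for r in resp)
        s_k = [sum(r[k] * x[j] for r, x in zip(resp, batch)) for j in range(dim)]
        counts[k] = ema_update(counts[k], n_k, gamma)
        sums[k] = [ema_update(o, s, gamma) for o, s in zip(sums[k], s_k)]
    total = sum(counts)
    concepts = [[s / max(c, 1e-12) for s in sums[k]] for k, c in enumerate(counts)]
    weights = [c / total for c in counts]
    return concepts, weights

# Toy example with two concepts and hard assignments.
counts = [1.0, 1.0]
sums = [[0.0, 0.0], [2.0, 2.0]]
batch = [[0.2, 0.0], [1.8, 2.2]]
resp = [[1.0, 0.0], [0.0, 1.0]]
concepts, weights = minibatch_m_step(counts, sums, batch, resp, gamma=0.5)
```

Keeping running counts and sums, rather than averaging the concept vectors directly, lets each concept remain a properly weighted mean even when minibatches assign it very different numbers of samples.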
Unified Learning Process
It is worth noting that the client preference can be viewed as a function of virtual concepts , and so can the loss . Then, the learning processes for and the global can be formulated into a unified optimisation task that can be solved in an end-to-end manner, rather than alternately as in EM-based methods.
![Refer to caption](x3.png)
Concretely, as described in Figure 3, will simultaneously provide supervised information for optimising virtual concepts and the model. The unified learning objective is formulated as
(11) |
where
(12) |
Here, is the stop-gradient operator (Van Den Oord, Vinyals et al. 2017), whose operand is fed forward as normal but has zero partial derivatives, so it is treated as a non-updated constant. is a hyperparameter balancing the two losses. The corresponding learning process is summarised in Algorithm 2.
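The effect of a stop-gradient operator can be illustrated with hand-written gradients. The sketch below uses a two-term squared-error objective in the style of the cited work, loss = ||sg(z) - v||^2 + beta * ||z - sg(v)||^2. This is an assumed stand-in and may differ from Equation 12; the names `z` (a representation), `v` (a virtual concept) and `beta`, and the choice of which term `beta` weights, are hypothetical.

```python
def unified_loss_and_grads(z, v, beta=0.1):
    """Two-term objective with stop-gradient semantics (illustrative only).

    loss = ||sg(z) - v||^2 + beta * ||z - sg(v)||^2
    where sg(.) passes its operand forward but blocks its gradient.
    """
    diff = [a - b for a, b in zip(z, v)]
    sq = sum(d * d for d in diff)
    loss = sq + beta * sq
    # First term: z is held constant, so only v receives a gradient.
    grad_v = [-2.0 * d for d in diff]
    # Second term: v is held constant, so only z receives a gradient.
    grad_z = [2.0 * beta * d for d in diff]
    return loss, grad_z, grad_v

loss, grad_z, grad_v = unified_loss_and_grads([1.0, 2.0], [0.0, 0.0], beta=0.5)
```

Both terms have the same forward value, but each one updates a different set of parameters: the first pulls the concept toward the representation, and the second pulls the representation toward the concept at a rate controlled by `beta`.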
Experiments
This section empirically studies the advantages of FedVC in learning from clients with non-I.I.D. data. FedVC can learn an FL global model that is robust to the changing data distributions of unseen/test clients. FedVC’s global model can be directly deployed to the test clients while achieving comparable performance to other personalised FL methods that require model adaptation.
Non-I.I.D. settings
Target Shift: MNIST is applied as a benchmark to simulate the non-I.I.D. environments. The experiment allocates samples of each class individually according to a posterior of the Dirichlet distribution (Hsu, Qi, and Brown 2019), which divides clients into five groups with different class distributions. Three groups of clients participate in the collaborative training process, and the remaining two are held out for testing. An illustration of client settings is in Figure 4. Class distributions are shown in Figure 5.
![Refer to caption](x4.png)
![Refer to caption](extracted/5697192/images/chapter6-class-proportion-mnist.png)
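A minimal sketch of how group-level class proportions could be drawn for the target-shift setting is given below. It samples one Dirichlet vector per client group by normalising Gamma draws. The concentration value, the seed, and the function names are assumptions; the paper's exact allocation procedure is not reproduced.

```python
import random

def dirichlet_sample(alpha, rng):
    """Draw one vector from Dirichlet(alpha) by normalising Gamma draws."""
    draws = [rng.gammavariate(a, 1.0) for a in alpha]
    total = sum(draws)
    return [d / total for d in draws]

def group_class_proportions(num_groups=5, num_classes=10, concentration=0.5, seed=0):
    """One class-proportion vector per client group (concentration is assumed)."""
    rng = random.Random(seed)
    return [dirichlet_sample([concentration] * num_classes, rng)
            for _ in range(num_groups)]

def split_groups(num_groups=5, num_train_groups=3):
    """First groups take part in training; the rest are held out for testing."""
    groups = list(range(num_groups))
    return groups[:num_train_groups], groups[num_train_groups:]

proportions = group_class_proportions()
train_groups, test_groups = split_groups()
```

A smaller concentration value yields more skewed class proportions per group, which is the usual way to control how strongly the label distributions differ.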
Feature Shift: The research utilises the Digit-5 dataset to evaluate FedVC’s performance on feature-shift data. Digit-5 consists of digits from five different domains (MNIST, MNIST-M, SVHN, USPS and Synth Digits). The experiment assigns samples of each domain to six clients, of which five participate in training the global model and one is held out for testing. Classes are evenly distributed on each client. In addition, the experiment randomly draws samples from all domains to compose five mixed datasets for the remaining test clients. An illustration of client settings is in Figure 6.
![Refer to caption](x5.png)
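One reading of the feature-shift client layout can be sketched as follows: six clients per domain (five for training, one held out) plus five mixed-domain test clients. The interpretation that each domain receives its own six clients, the client identifiers, and the dictionary fields are assumptions for illustration.

```python
DOMAINS = ["MNIST", "MNIST-M", "SVHN", "USPS", "SynthDigits"]

def build_client_layout(clients_per_domain=6, train_per_domain=5, num_mixed=5):
    """Return a list of client descriptors with their data source and role."""
    clients = []
    for domain in DOMAINS:
        for idx in range(clients_per_domain):
            role = "train" if idx < train_per_domain else "test"
            clients.append({"id": f"{domain}-{idx}", "source": domain, "role": role})
    # Mixed clients draw samples from all domains and are used only for testing.
    for idx in range(num_mixed):
        clients.append({"id": f"mixed-{idx}", "source": "mixed", "role": "test"})
    return clients

layout = build_client_layout()
num_train = sum(c["role"] == "train" for c in layout)
num_test = sum(c["role"] == "test" for c in layout)
```

Under this reading there are 25 training clients and 10 test clients, half of the latter from a single domain and half from the mixed pool.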
Models and Hyperparameters
The research applies convolutional neural networks (CNNs) as fundamental models and supervises the training process by virtual concepts. By default, in each communication round, ten clients are sampled to update the global model and virtual concepts, and subsequently, the global model is synchronised to all clients to evaluate its performance. The learning rate of a client’s local training step is initialised as 0.005 and decays at a rate of 0.8 every 10 communication rounds. During each communication round, a client tunes the global model on its local data for two epochs with a batch size of 10.
For the FedVC, the default number of virtual concepts is set to be 10 and the dimension of each virtual concept is 10. The similarity parameter is 0.1, and the smoothing parameter is 0.05.
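The default settings above can be collected into a small configuration object together with the learning-rate schedule. The class name, field names and the staircase form of the decay (a factor of 0.8 applied once every 10 rounds) are assumptions; only the numeric values come from the text.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class FedVCConfig:
    """Default hyperparameters as stated in the text (field names are hypothetical)."""
    clients_per_round: int = 10
    local_epochs: int = 2
    batch_size: int = 10
    initial_lr: float = 0.005
    lr_decay: float = 0.8
    decay_every: int = 10
    num_concepts: int = 10
    concept_dim: int = 10
    similarity: float = 0.1
    smoothing: float = 0.05

def learning_rate(round_idx, cfg=FedVCConfig()):
    """Staircase decay: multiply by lr_decay once every decay_every rounds."""
    return cfg.initial_lr * cfg.lr_decay ** (round_idx // cfg.decay_every)

lr_start = learning_rate(0)
lr_round_25 = learning_rate(25)
```

Under the staircase assumption, the learning rate stays at 0.005 for rounds 0 to 9, drops to 0.004 for rounds 10 to 19, and so on.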
Baseline Methods
Several PerFL strategies are compared as baselines, including:
- Local Only: models trained locally on each client
- FedAvg+FT: a FedAvg global model fine-tuned locally on each client (Cheng, Chadha, and Duchi 2021; Collins et al. 2022)
- FedBN: a global model with private BatchNormalisation layers (Li et al. 2021b)
- FedProx: leverages a global model to regularise the local training process (Li et al. 2018)
- Ditto: leverages a global model to regularise the local training process while learning a local model for each client (Li et al. 2021a)
- FedRep: personalisation by training local classification heads (Collins et al. 2021)
- FedDual: personalisation by training a global and a local feature extractor (Pillutla et al. 2022)
Performance
This section first reports averaged model performance on all clients, which shows that a global model learned with FedVC achieves comparable performance to other personalised FL methods that require model adaptation. Then, it examines group-wise metrics to evaluate a model’s performance on different distributions. Results show that the global model learned with FedVC is more robust to changing distributions. The learned global model can be directly deployed on test clients without extra adaptation.
Target Shift Settings
For target shift settings, the averaged accuracy, weighted AUC score and weighted F1 score are applied to evaluate model performance (https://scikit-learn.org/stable/index.html). Table 1 and Table 2 respectively report the averaged performance over the training clients and the test clients. Figure 9 shows the group-wise performance.
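For readers unfamiliar with the weighted metrics, the sketch below computes accuracy and a support-weighted F1 score from scratch, where each class's F1 is weighted by its share of the true labels. It is a plain-Python illustration of that averaging scheme, not the evaluation code used in the experiments; the function names and toy labels are hypothetical.

```python
from collections import Counter

def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

def weighted_f1(y_true, y_pred):
    """Per-class F1 averaged with weights proportional to class support."""
    support = Counter(y_true)
    total = len(y_true)
    score = 0.0
    for label, count in support.items():
        tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
        fp = sum(t != label and p == label for t, p in zip(y_true, y_pred))
        fn = count - tp
        denom = 2 * tp + fp + fn
        f1 = 2 * tp / denom if denom else 0.0
        score += (count / total) * f1
    return score

y_true = [0, 0, 0, 1, 1, 2]
y_pred = [0, 0, 1, 1, 1, 2]
acc = accuracy(y_true, y_pred)
f1 = weighted_f1(y_true, y_pred)
```

Weighting by support matters under target shift, because a client's rare classes would otherwise contribute as much to the score as its dominant ones.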
Overall performance
Table 1 reports models’ performance on the MNIST dataset on the training clients (tr-clients). A global model trained by FedVC achieves the best performance under the target shift setting. It outperforms the locally fine-tuned global model (FedAvg+FT) and models with client-specific parameters (FedBN, FedProx, FedRep and FedDual). Table 2 reports models’ performance on the test clients (ts-clients). All baseline methods are fine-tuned on the test clients to adapt to each client’s local distribution. The model learned by FedVC generalises well to the unseen clients, even though it is not fine-tuned. Note that locally trained models (Local Only and Ditto) cannot be generalised to unseen clients.
Table 1: Performance on the training clients (MNIST, target shift).

Method | avg. Acc (%) | w. AUC (%) | w. F1 (%) |
---|---|---|---|
Local Only | 95.79 (1.00) | 99.69 (0.11) | 93.21 (0.96) |
FedAvg+FT | 97.92 (0.98) | 99.90 (0.04) | 95.49 (1.07) |
FedBN | 98.43 (0.86) | 99.90 (0.04) | 95.71 (0.94) |
FedProx | 98.07 (0.90) | 99.89 (0.04) | 95.48 (0.89) |
Ditto | 95.86 (1.19) | 99.71 (0.09) | 93.25 (1.17) |
FedRep | 93.61 (2.55) | 99.53 (0.18) | 91.05 (2.59) |
FedDual | 96.87 (0.99) | 99.84 (0.06) | 94.14 (1.25) |
FedVC | 98.56 (0.56) | 99.90 (0.05) | 95.83 (0.96) |
FedVC-sg | 98.51 (0.62) | 99.90 (0.03) | 95.84 (1.10) |

Table 2: Performance on the test clients (MNIST, target shift).

Method | avg. Acc (%) | w. AUC (%) | w. F1 (%) |
---|---|---|---|
FedAvg+FT | 98.42 (0.84) | 99.91 (0.04) | 95.71 (0.74) |
FedBN | 98.48 (0.84) | 99.91 (0.04) | 95.64 (0.84) |
FedProx | 98.19 (0.98) | 99.90 (0.04) | 95.53 (0.94) |
FedRep | 88.98 (1.37) | 99.00 (0.24) | 86.11 (1.85) |
FedDual | 97.80 (0.51) | 99.88 (0.03) | 95.08 (0.57) |
FedVC | 98.79 (0.62) | 99.91 (0.03) | 95.97 (0.87) |
FedVC-sg | 98.76 (0.67) | 99.91 (0.04) | 95.92 (0.99) |

Group-wise performance
Figure 9 shows the averaged accuracy of clients within different groups, i.e., data distributions. It shows that the global model trained by FedVC is more robust across different distributions, and it generalises well to unseen distributions (client groups 4-5). Fluctuation in the learning curves indicates that the fine-tuned models (FedAvg+FT) and models with personalised parameters (FedBN, FedProx, FedDual) are slightly unstable. Locally trained models (Local Only and Ditto) and FedRep have significant performance gaps among clients.
Feature Shift Settings
This section presents evaluations on feature-shift data. Table 3 shows that FedVC achieves the best accuracy, AUC and F1 score under this setting. Other models are less robust than FedVC, and their performance varies significantly among clients (higher standard deviations). Group-wise performance in Figure 10 shows that FedVC has a smaller performance gap between different domains and is more robust, as there is less fluctuation in its learning curves.
Table 3: Performance on the training clients (Digit-5, feature shift).

Method | avg. Acc (%) | w. AUC (%) | w. F1 (%) |
---|---|---|---|
Local Only | 74.43 (13.60) | 93.98 (4.53) | 71.34 (13.00) |
FedAvg+FT | 80.92 (9.37) | 96.75 (2.32) | 77.40 (8.87) |
FedBN | 84.34 (10.61) | 97.49 (2.19) | 80.99 (10.02) |
FedProx | 80.50 (11.32) | 96.46 (2.77) | 76.93 (10.82) |
Ditto | 67.30 (18.50) | 92.10 (6.63) | 63.95 (18.54) |
FedRep | 55.44 (20.40) | 85.55 (11.28) | 51.86 (20.74) |
FedDual | 70.36 (15.59) | 93.20 (5.34) | 66.99 (15.57) |
FedVC | 85.42 (8.95) | 97.55 (1.81) | 81.88 (8.53) |
FedVC-sg | 85.82 (8.47) | 97.59 (1.88) | 82.27 (7.99) |

Table 4: Performance on the test clients (Digit-5, feature shift).

Method | avg. Acc (%) | w. AUC (%) | w. F1 (%) |
---|---|---|---|
FedAvg+FT | 77.85 (7.71) | 96.26 (1.99) | 74.57 (7.42) |
FedBN | 83.30 (7.05) | 97.45 (1.32) | 79.69 (6.67) |
FedProx | 76.90 (7.92) | 96.19 (2.02) | 73.55 (7.31) |
FedRep | 34.85 (18.75) | 73.08 (11.66) | 30.24 (18.39) |
FedDual | 67.15 (11.46) | 92.74 (4.66) | 63.85 (11.62) |
FedVC | 86.20 (5.62) | 97.61 (1.38) | 82.92 (5.15) |
FedVC-sg | 85.10 (5.92) | 97.68 (1.39) | 81.61 (5.70) |

Ablation Study
This section evaluates the effectiveness of FedVC through experiments on the Digit-5 dataset. The section first validates virtual concepts’ capability as supervised information for personalisation by visualising the distribution of estimated client preferences (). Then, it analyses the behaviours of hyperparameters by ablation experiments.
Interpreting Personalisation
Figure 7 compares the latent representations learned by FedAvg and FedVC. FedVC succeeds in supervising the learning process with client preferences, so that the distribution of the estimated client preferences is consistent with the ground-truth knowledge, i.e., samples from the same group (colours) are closer to each other.
The similarity parameter in Equation 1 is a hyperparameter that weights the importance of the difference when estimating the client preference . Figure 8(a) shows that client preferences (colours) are unrecognisable with a model learned with a small , i.e., . As increases, the estimated demonstrates structure consistent with the client preferences (Figure 8(b-d)). This validates the effectiveness of the supervision of virtual concepts . The superior performance of FedVC indicates that such supervision does improve the performance of a global model, and that virtual concepts are indicators that can be utilised to interpret personalisation.
Hyperparameters
The experiments study each hyperparameter’s behaviour by evaluating model performance under different values of the selected hyperparameter while holding the others at their default values. According to Table 5 and Table 6, model performance improves as the number and the dimension of virtual concepts increase. Table 7 shows that a larger weight for the similarity between and increases model performance, which validates the effectiveness of the supervision from virtual concepts. In addition, Table 8 indicates that the newly estimated , and outperform the older ones when using the moving average strategy. Table 9 suggests that needs to be carefully selected when balancing the updates of the global model and the virtual concepts.
Table 5: Effect of the number of virtual concepts.

# of VCs | avg. Acc (%) on tr | avg. Acc (%) on ts |
---|---|---|
3 | 85.22 (9.47) | 85.05 (6.79) |
6 | 85.24 (9.10) | 85.15 (6.06) |
10 | 85.42 (8.95) | 86.20 (5.62) |

Table 6: Effect of the dimension of virtual concepts.

Dim. of VCs | avg. Acc (%) on tr | avg. Acc (%) on ts |
---|---|---|
3 | 83.46 (9.89) | 83.45 (6.95) |
6 | 84.36 (9.75) | 83.80 (7.07) |
10 | 85.42 (8.95) | 86.20 (5.62) |

Table 7: Effect of the similarity parameter.

Similarity parameter | avg. Acc (%) on tr | avg. Acc (%) on ts |
---|---|---|
0.001 | 84.40 (9.30) | 85.00 (6.34) |
0.005 | 84.56 (9.55) | 85.35 (6.44) |
0.01 | 85.74 (9.12) | 85.75 (6.25) |
0.1 | 85.42 (8.95) | 86.20 (5.62) |

Table 8: Effect of the smoothing parameter.

Smoothing parameter | avg. Acc (%) on tr | avg. Acc (%) on ts |
---|---|---|
0.01 | 85.66 (8.51) | 85.40 (6.12) |
0.05 | 85.42 (8.95) | 86.20 (5.62) |
0.1 | 84.96 (9.23) | 85.45 (6.18) |
0.5 | 84.44 (9.16) | 84.65 (6.25) |
0.95 | 84.02 (9.51) | 83.85 (7.29) |

Table 9: Effect of the balancing parameter.

Balancing parameter | avg. Acc (%) on tr | avg. Acc (%) on ts |
---|---|---|
0.01 | 83.14 (10.47) | 83.40 (7.27) |
0.1 | 85.24 (8.97) | 85.30 (6.86) |
0.5 | 83.48 (10.63) | 82.35 (8.03) |
0.95 | 85.46 (8.71) | 85.20 (6.02) |

![Refer to caption](x6.png)
![Refer to caption](x7.png)
Algorithm 1
Input: communication rounds , epochs in each round , learning rate , batch size , hyperparameters , and
Output: optimal parameters , virtual concepts
ClientUpdate(, )
Algorithm 2
Input: communication rounds , epochs in each round , learning rate , batch size , hyperparameters , and
Output: optimal parameters , virtual concepts
ClientUpdate(, )
![Refer to caption](x8.png)
![Refer to caption](x9.png)
Conclusions
The research proposes to utilise virtual concepts as client supervision information to learn a robust global model and to interpret the non-IID data across clients. Specifically, the proposed FedVC interprets each client’s preferences as a mixture of conceptual vectors, each of which represents an interpretable concept to end-users. These conceptual vectors can be learnt via the optimisation procedure of the federated learning system. In addition to the interpretability, the clarity of client-specific personalisation could also be applied to enhance the robustness of the training process of the FL system. The effectiveness of the proposed method has been validated on benchmark datasets.
References
- Cheng, Chadha, and Duchi (2021) Cheng, G.; Chadha, K.; and Duchi, J. 2021. Fine-tuning is Fine in Federated Learning. arXiv preprint arXiv:2108.07313.
- Collins et al. (2021) Collins, L.; Hassani, H.; Mokhtari, A.; and Shakkottai, S. 2021. Exploiting Shared Representations for Personalized Federated Learning. In Meila, M.; and Zhang, T., eds., Proceedings of the 38th International Conference on Machine Learning, volume 139 of Proceedings of Machine Learning Research, 2089–2099. PMLR.
- Collins et al. (2022) Collins, L.; Hassani, H.; Mokhtari, A.; and Shakkottai, S. 2022. Fedavg with fine tuning: Local updates lead to representation learning. Advances in Neural Information Processing Systems, 35: 10572–10586.
- Hsu, Qi, and Brown (2019) Hsu, T.-M. H.; Qi, H.; and Brown, M. 2019. Measuring the Effects of Non-Identical Data Distribution for Federated Visual Classification. arXiv:1909.06335.
- Li et al. (2021a) Li, T.; Hu, S.; Beirami, A.; and Smith, V. 2021a. Ditto: Fair and robust federated learning through personalization. In International Conference on Machine Learning, 6357–6368. PMLR.
- Li et al. (2018) Li, T.; Sahu, A. K.; Zaheer, M.; Sanjabi, M.; Talwalkar, A.; and Smith, V. 2018. Federated optimization in heterogeneous networks. arXiv preprint arXiv:1812.06127.
- Li et al. (2021b) Li, X.; Jiang, M.; Zhang, X.; Kamp, M.; and Dou, Q. 2021b. Fedbn: Federated learning on non-iid features via local batch normalization. arXiv preprint arXiv:2102.07623.
- Pillutla et al. (2022) Pillutla, K.; Malik, K.; Mohamed, A.; Rabbat, M.; Sanjabi, M.; and Xiao, L. 2022. Federated Learning with Partial Model Personalization. arXiv preprint arXiv:2204.03809.
- Van Den Oord, Vinyals et al. (2017) Van Den Oord, A.; Vinyals, O.; et al. 2017. Neural discrete representation learning. Advances in neural information processing systems, 30.