Personalized Interpretation on Federated Learning: A Virtual Concepts approach

Peng Yan, Guodong Long, **g Jiang, Michael Blumenstein
Abstract

Tackling non-IID data is an open challenge in federated learning (FL) research. Existing FL methods, including robust FL and personalized FL, are designed to improve model performance without considering how to interpret the non-IID data across clients. This paper aims to design a novel FL method that is both robust to and able to interpret non-IID data across clients. Specifically, we interpret each client’s dataset as a mixture of conceptual vectors, each of which represents an interpretable concept to end-users. These conceptual vectors can be pre-defined, refined in a human-in-the-loop process, or learnt via the optimization procedure of the federated learning system. In addition to interpretability, the clarity of client-specific personalization can also be applied to enhance the robustness of the training process of the FL system. The effectiveness of the proposed method has been validated on benchmark datasets.

Motivation

A critical challenge in personalised federated learning (PerFL) is the absence of well-defined concepts of personalisation. Client preferences and personalised properties are implicit in the training data and enclosed on each client. They could be a client’s preference for specific classes or a specific noise mixed into the input features. The only tangible information is the shift in data distribution across clients.

Meanwhile, most machine learning models, e.g., DNNs, are trained in an end-to-end paradigm. They are optimised by back-propagating supervised information, e.g., classification loss, from the output layer to the input layer. Personalisation is performed indirectly when a model is tuned for tasks like classification. This learning schema is less efficient in PerFL. The on-device training tends to overfit a client’s local data due to limited and unbalanced training samples. The aggregation step on the server, in turn, will neutralise personalised information when synthesising the global model, e.g., by averaging local updates.

However, although there is no supervised information about explicitly defined client properties, one feature distinguishing model personalisation from unsupervised tasks is that data in PerFL are explicitly partitioned. Samples from the same client will demonstrate a client-specific bias toward certain properties. One may therefore assume that there are invisible client labels that induce on-device training to progress toward a client’s preferences, i.e., personalisation. The client-based data partition essentially supervises PerFL’s training process, so this research calls the learning paradigm Client-Supervised Learning.

Based on the thought above, this research introduces Virtual Concepts (VC) to explicate client-supervised information. The VCs are representations of potential structure information extracted from training data. They can be learned independently of downstream classification tasks by a novel FedVC algorithm, which facilitates understanding client properties and boosts model personalisation.

Specifically, FedVC assumes that there is a set of vectors (virtual concepts), each describing a type of client property. A client’s preferences are then represented by a combination of VCs, which is utilised as supervised information to guide the training progress of the global model. Figure 1 illustrates the proposed FedVC.

Figure 1: Illustration of FedVC. (a) data distribution in an FL system; (b) virtual concepts (pentagon, plus and triangle) are vectors indicating underlying cluster structures of data, e.g., cluster centres; (c) a client’s preference (star) is represented by a combination of virtual concepts; (d) client-supervised loss requires sample representations on the same client (data points within the circle) to be close to each other as they share the identical client preference.

To learn the VCs, FedVC evaluates the underlying distribution structure in data by formulating the learning task as a Gaussian Mixture Model (GMM) that can be solved by most unsupervised learning methods, e.g., the Expectation-Maximisation (EM) algorithm.

Experiments on real-world datasets show that the VCs can work as supervised information to train a global model that is robust to changing distributions. Further study demonstrates that the VCs are useful in exploring meaningful client properties by discovering distribution structures implied in training data.

The main contributions are summarised as follows:

  • The research proposes virtual concepts describing client preferences. The VCs are representations of distribution structure extracted from training data. They provide us with a way to explore meaningful client properties relevant to model personalisation.

  • The research proposes a novel client-supervised PerFL framework that utilises virtual concept vectors as supervised information to train the global model. The VCs will allow an FL algorithm to simultaneously learn class and client knowledge so that the learned global model can achieve on-deployment personalisation, where the global model will not require an extra fine-tuning process at the test stage.

  • The research formulates the learning task of VCs into a Gaussian Mixture Model that most unsupervised learning methods can solve. The proposed FedVC framework is compatible with most FL methods, where they can be integrated as an add-on to improve personalisation performance and model interpretability.

  • Comparison with baseline methods shows that FL models trained with VCs can simultaneously learn class and client knowledge. They achieve competitive personalisation performance without requiring extra fine-tuning steps or personal parameters.

  • Empirical studies show that VCs can discover meaningful distribution structures implied in training, facilitating the uncovering of client properties related to model personalisation.

Methodology

Client-supervised PerFL

Let $\mathcal{C}=\{c_{1},...,c_{M}\}$ denote $M$ virtual concept vectors. A client’s preference is then represented by $p^{(k)}=\sum_{m=1}^{M}\upsilon^{(k)}_{m}c_{m}$, where $k$ is the client index and $\upsilon^{(k)}_{m}$ is a factor measuring how relevant the client is to $c_{m}$, i.e., how typically the client exhibits the property of $c_{m}$. FedVC aims to utilise $p^{(k)}$ as supervised information to guide FL’s learning process so that the global model can learn client knowledge explicitly.

Figure 2: Projection head

Specifically, FedVC adds a projection head to FL’s global model to extract a representation $\hat{z}_{i}^{(k)}$ of potential client properties (see Figure 2). One can evaluate a sample’s relevance to each concept by a similarity function, e.g., Equation 1, and derive an estimated client preference $\hat{p}_{i}^{(k)}=\sum_{m=1}^{M}s_{i,m}^{(k)}c_{m}$, where $i$ is the sample index and $\iota$ is a hyperparameter.

$$s_{i,m}^{(k)}=\frac{\upsilon_{m}^{(k)}\exp(-\iota\|\hat{z}_{i}^{(k)}-c_{m}\|^{2})}{\sum_{m'=1}^{M}\upsilon_{m'}^{(k)}\exp(-\iota\|\hat{z}_{i}^{(k)}-c_{m'}\|^{2})}\tag{1}$$

Then, there is a supervised loss regarding client preferences, i.e., $l_{p}(\hat{p}_{i}^{(k)},p^{(k)})=\|\hat{p}_{i}^{(k)}-p^{(k)}\|^{2}$. It can be integrated into any FL framework and solved by gradient-based methods. Details of the learning algorithm are in Algorithm 1.
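The similarity of Equation 1 and the client-supervised loss can be sketched in NumPy as follows. This is a minimal illustration rather than the authors’ implementation: the function names, the averaging of the loss over samples, and the random inputs are assumptions made for the example, and the concept factors are assumed to be positive.

```python
import numpy as np

def concept_similarity(z_hat, concepts, upsilon, iota=0.1):
    """Equation 1: similarity of each sample to each virtual concept.

    z_hat:    (N, D) projected representations of one client's samples
    concepts: (M, D) virtual concept vectors c_m
    upsilon:  (M,)   the client's concept factors (assumed positive)
    """
    # Squared Euclidean distance between every sample and every concept.
    sq_dist = ((z_hat[:, None, :] - concepts[None, :, :]) ** 2).sum(axis=-1)
    # Subtracting the row minimum leaves the ratio unchanged and limits underflow.
    shifted = sq_dist - sq_dist.min(axis=1, keepdims=True)
    weights = upsilon[None, :] * np.exp(-iota * shifted)
    return weights / weights.sum(axis=1, keepdims=True)

def client_preference_loss(z_hat, concepts, upsilon, iota=0.1):
    """Mean over samples of ||p_hat_i - p||^2 (the averaging is an assumption)."""
    s = concept_similarity(z_hat, concepts, upsilon, iota)  # (N, M)
    p_hat = s @ concepts      # estimated preference per sample, (N, D)
    p = upsilon @ concepts    # client preference, (D,)
    return float(((p_hat - p) ** 2).sum(axis=1).mean())

# Illustrative inputs only: 8 samples, 10 concepts of dimension 10.
rng = np.random.default_rng(0)
z_hat = rng.normal(size=(8, 10))
concepts = rng.normal(size=(10, 10))
upsilon = np.full(10, 0.1)
s = concept_similarity(z_hat, concepts, upsilon)
loss = client_preference_loss(z_hat, concepts, upsilon)
```

Each row of `s` sums to one, so the estimated preference is a convex combination of the concept vectors.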

Virtual Concepts

As virtual concepts correspond to client properties, a sample is assumed to be generated by a random process involving a mixture of multiple client properties. FedVC formulates this assumption as a Gaussian Mixture Model (GMM). For any sample $\mathbf{z}^{(k)}$ on the $k$-th client,

$$\mathbf{z}^{(k)}\sim\mathcal{P}^{(k)}(\mathbf{z})=\sum_{m=1}^{M}\upsilon^{(k)}_{m}\mathcal{N}(\mathbf{z};c_{m},\Sigma_{m})\tag{2}$$

where the covariance $\Sigma_{m}$ is set to the identity matrix $I$ for simplicity.

Let $\mathcal{C}=\{c_{1},...,c_{M}\}$ denote the set of VCs and $\Upsilon=\{\{\upsilon^{(1)}_{m}\}_{m=1}^{M},...,\{\upsilon^{(K)}_{m}\}_{m=1}^{M}\}$ denote the set of client preferences. The collaborative learning task for $\mathcal{C}$ and $\Upsilon$ is formulated as

$$\mathcal{C}^{*},\Upsilon^{*}=\arg\max_{\mathcal{C},\Upsilon}\sum_{k=1}^{K}\sum_{i=1}^{N_{k}}\log\mathcal{P}^{(k)}(z_{i}^{(k)})\tag{3}$$
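For concreteness, the log-likelihood objective with identity covariance can be evaluated as in the sketch below. The function names and the toy data are assumptions made for illustration. The log-sum-exp shift is a standard numerical device and not something the paper specifies.

```python
import numpy as np

def client_log_likelihood(z, concepts, upsilon):
    """Sum over a client's samples of log P^(k)(z_i), with Sigma_m = I.

    z: (N, D) samples, concepts: (M, D) concept vectors, upsilon: (M,) mixing factors.
    """
    dim = z.shape[1]
    sq_dist = ((z[:, None, :] - concepts[None, :, :]) ** 2).sum(axis=-1)  # (N, M)
    # log N(z; c_m, I) = -0.5 * ||z - c_m||^2 - 0.5 * D * log(2 * pi)
    log_comp = -0.5 * sq_dist - 0.5 * dim * np.log(2.0 * np.pi)
    # Log-sum-exp over the mixture components, weighted by upsilon.
    shift = log_comp.max(axis=1, keepdims=True)
    mix = (upsilon[None, :] * np.exp(log_comp - shift)).sum(axis=1)
    return float((shift[:, 0] + np.log(mix)).sum())

def federated_objective(client_data, concepts, client_upsilons):
    """The double sum over clients and samples that is maximised."""
    return sum(client_log_likelihood(z, concepts, u)
               for z, u in zip(client_data, client_upsilons))

# Toy example: two clients, two concepts in two dimensions.
rng = np.random.default_rng(0)
client_data = [rng.normal(size=(5, 2)), rng.normal(loc=3.0, size=(4, 2))]
concepts = np.array([[0.0, 0.0], [3.0, 3.0]])
client_upsilons = [np.array([0.9, 0.1]), np.array([0.2, 0.8])]
objective = federated_objective(client_data, concepts, client_upsilons)
```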

FedVC solves Equation 3 with the EM framework below:

  • E-step: Given $\mathcal{C}$ and $\Upsilon$, clients estimate local samples’ $s_{i,m}^{(k)}$ by Equation 1

  • M-step: Clients update $\mathcal{C}$ and $\Upsilon$ collaboratively by Equation 4 and Equation 5

    $$\upsilon_{m}^{(k)}=\frac{1}{N_{k}}\sum_{i=1}^{N_{k}}s_{i,m}^{(k)}\tag{4}$$
    $$c_{m}=\frac{\sum_{k=1}^{K}\sum_{i=1}^{N_{k}}s_{i,m}^{(k)}\hat{z}_{i}^{(k)}}{\sum_{k=1}^{K}\sum_{i=1}^{N_{k}}s_{i,m}^{(k)}}\tag{5}$$
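One full-batch EM round over all clients might look like the sketch below. It is centralised for readability, so it ignores how the sums would actually be communicated in a federated system. The function names and toy data are illustrative assumptions. With identity covariance, the exact GMM posterior corresponds to Equation 1 with $\iota=0.5$, which is the default used here.

```python
import numpy as np

def responsibilities(z_hat, concepts, upsilon, iota=0.5):
    """E-step (Equation 1) for one client's samples."""
    sq_dist = ((z_hat[:, None, :] - concepts[None, :, :]) ** 2).sum(axis=-1)
    shifted = sq_dist - sq_dist.min(axis=1, keepdims=True)  # limits underflow
    weights = upsilon[None, :] * np.exp(-iota * shifted)
    return weights / weights.sum(axis=1, keepdims=True)

def em_round(client_data, concepts, client_upsilons, iota=0.5):
    """One E-step followed by the M-step updates of Equations 4 and 5."""
    numerator = np.zeros_like(concepts)        # sum of s * z_hat over all clients
    denominator = np.zeros(concepts.shape[0])  # sum of s over all clients
    new_upsilons = []
    for z_hat, upsilon in zip(client_data, client_upsilons):
        s = responsibilities(z_hat, concepts, upsilon, iota)
        new_upsilons.append(s.mean(axis=0))    # Equation 4
        numerator += s.T @ z_hat
        denominator += s.sum(axis=0)
    # Equation 5; the small floor only guards against an empty concept.
    new_concepts = numerator / np.maximum(denominator, 1e-12)[:, None]
    return new_concepts, new_upsilons

# Two clients whose samples sit around (0, 0) and (10, 10) respectively.
base = np.array([[-1.0, 0.0], [1.0, 0.0], [0.0, -1.0], [0.0, 1.0]])
client_data = [base, base + 10.0]
concepts = np.array([[1.0, 1.0], [9.0, 9.0]])
client_upsilons = [np.array([0.5, 0.5]), np.array([0.5, 0.5])]
new_concepts, new_upsilons = em_round(client_data, concepts, client_upsilons)
```

In this toy case the clusters are far apart, so each concept moves to the mean of one client’s samples and each client’s factors concentrate on a single concept.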

However, Equation 4 and Equation 5 cannot be applied directly when working with minibatches in FL settings. FedVC uses exponential moving averages as an alternative:

$$S_{m}^{\prime(k)}=S_{m}^{(k)}\cdot\kappa+\sum_{i\in\mathbb{B}}s_{i,m}^{(k)}\cdot(1-\kappa)\tag{6}$$
$$C_{m}^{\prime(k)}=C_{m}^{(k)}\cdot\kappa+\sum_{i\in\mathbb{B}}s_{i,m}^{(k)}\hat{z}_{i}^{(k)}\cdot(1-\kappa)\tag{7}$$
$$N^{\prime}_{k}=N_{k}\cdot\kappa+|\mathbb{B}|\cdot(1-\kappa)\tag{8}$$

where $\mathbb{B}$ denotes a minibatch of samples, $|\mathbb{B}|$ denotes the batch size, and $\kappa$ is a smoothing hyperparameter between 0 and 1. Then,

$$\upsilon_{m}^{(k)}=\frac{S_{m}^{\prime(k)}}{N^{\prime}_{k}}\tag{9}$$
$$c_{m}=\frac{\sum_{k=1}^{K}C_{m}^{\prime(k)}}{\sum_{k=1}^{K}S_{m}^{\prime(k)}}\tag{10}$$

The overall learning algorithm is described in Algorithm 1.
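The moving-average updates of Equations 6 to 10 can be sketched as a small per-client accumulator plus an aggregation step. The class and function names are hypothetical, and initialising the running statistics to zero is an assumption, since the text does not state how they start.

```python
import numpy as np

class ClientConceptStats:
    """Running statistics S_m^(k), C_m^(k) and N_k kept by one client."""

    def __init__(self, num_concepts, dim, kappa=0.05):
        self.kappa = kappa
        self.S = np.zeros(num_concepts)         # assumed zero initialisation
        self.C = np.zeros((num_concepts, dim))
        self.N = 0.0

    def update(self, s, z_hat):
        """s: (B, M) similarities from Equation 1; z_hat: (B, D) representations."""
        k = self.kappa
        self.S = self.S * k + s.sum(axis=0) * (1.0 - k)   # Equation 6
        self.C = self.C * k + (s.T @ z_hat) * (1.0 - k)   # Equation 7
        self.N = self.N * k + s.shape[0] * (1.0 - k)      # Equation 8

    def upsilon(self):
        return self.S / self.N                            # Equation 9

def aggregate_concepts(all_stats):
    """Equation 10: combine the clients' statistics into the concept vectors."""
    numerator = sum(st.C for st in all_stats)
    denominator = sum(st.S for st in all_stats)
    return numerator / denominator[:, None]

# One client, one minibatch with hard assignments, for a result that is easy to check.
stats = ClientConceptStats(num_concepts=2, dim=2)
s = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
z_hat = np.array([[0.0, 0.0], [2.0, 2.0], [5.0, 5.0]])
stats.update(s, z_hat)
factors = stats.upsilon()
centres = aggregate_concepts([stats])
```

Because the same weights are applied to the similarity sums and to the batch size, the factors keep summing to one after every update under the zero initialisation assumed here.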

Unified Learning Process

It is worth noting that the client preference $p^{(k)}=\sum_{m=1}^{M}\upsilon^{(k)}_{m}c_{m}$ can be viewed as a function of the virtual concepts $\mathcal{C}$, and so can the loss $l_{p}(\hat{p},p)$. The learning processes for $\mathcal{C}$ and the global model $\omega$ can then be formulated as a unified optimisation task that is solved end to end, rather than alternately as in EM-based methods.

Figure 3: FedVC architecture.

Concretely, as described in Figure 3, $l_{p}(\hat{p},p)$ simultaneously provides supervised information for optimising the virtual concepts and the model. The unified learning objective is formulated as

$$\omega^{*},\mathcal{C}^{*}=\arg\min_{\omega,\mathcal{C}}\sum_{k=1}^{K}\alpha_{k}\mathcal{L}_{k}(\omega,\mathcal{C})\tag{11}$$

where

$$\begin{aligned}\mathcal{L}_{k}(\omega,\mathcal{C})={}&\frac{1}{N_{k}}\sum_{i=1}^{N_{k}}\Big[l_{cls}(\hat{y}_{i}^{(k)},y_{i}^{(k)})\\&+l_{p}(\hat{p}_{i}^{(k)},\text{sg}[p^{(k)}])+\gamma\,l_{p}(\text{sg}[\hat{p}_{i}^{(k)}],p^{(k)})\Big]\end{aligned}\tag{12}$$

Here $\text{sg}[\cdot]$ is the stop-gradient operator (Van Den Oord, Vinyals et al. 2017): the operand is fed forward as normal but has zero partial derivatives, so it acts as a non-updated constant. $\gamma$ is a hyperparameter balancing the two preference losses. The corresponding learning process is summarised in Algorithm 2.
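The effect of the stop-gradient operator on the two preference terms can be illustrated with hand-written gradients. In this sketch the estimated preference and the client preference are treated as free vectors, which is a simplification: in FedVC both depend on the model and the concepts. The function name and the value of `gamma` are assumptions for the example.

```python
import numpy as np

def preference_terms_and_grads(p_hat, p, gamma=0.25):
    """Value and gradients of l_p(p_hat, sg[p]) + gamma * l_p(sg[p_hat], p).

    The forward value uses both operands as usual. In the backward pass,
    the first term only sends gradient to p_hat and the second only to p.
    """
    diff = p_hat - p
    sq_norm = float(diff @ diff)
    value = (1.0 + gamma) * sq_norm
    grad_p_hat = 2.0 * diff           # from l_p(p_hat, sg[p]); p is held constant
    grad_p = -2.0 * gamma * diff      # from gamma * l_p(sg[p_hat], p); p_hat is held constant
    return value, grad_p_hat, grad_p

value, grad_p_hat, grad_p = preference_terms_and_grads(
    np.array([1.0, 0.0]), np.array([0.0, 0.0]), gamma=0.5
)
```

The first term therefore pulls the sample’s estimated preference toward the client preference, while the second, scaled by `gamma`, pulls the client preference toward the estimates.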

Experiments

This section empirically studies the advantages of FedVC in learning from clients with non-I.I.D. data. FedVC can learn an FL global model that is robust to the changing data distributions of unseen/test clients. FedVC’s global model can be directly deployed to the test clients while achieving comparable performance to other personalised FL methods that require model adaptation.

Non-I.I.D settings

Target Shift: MNIST is used as a benchmark to simulate the non-I.I.D. environment. The experiment allocates samples of each class individually according to a posterior of the Dirichlet distribution (Hsu, Qi, and Brown 2019), which divides clients into five groups with different class distributions. Three groups of clients participate in the collaborative training process, and the remaining two are held out for testing. An illustration of client settings is in Figure 4. Class distributions are shown in Figure 5.
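A Dirichlet-based label split of this kind is often simulated along the following lines. The sketch draws one class distribution per client group and then samples a client’s labels from it. The concentration value `alpha`, the function names and the sample counts are hypothetical, and the exact allocation procedure used in the experiments may differ.

```python
import numpy as np

def group_class_distributions(num_groups=5, num_classes=10, alpha=0.5, seed=0):
    """Draw one class distribution per client group from a symmetric Dirichlet."""
    rng = np.random.default_rng(seed)
    return rng.dirichlet(np.full(num_classes, alpha), size=num_groups)

def sample_client_labels(class_probs, num_samples, rng):
    """Draw the class labels a client would hold under its group's distribution."""
    return rng.choice(len(class_probs), size=num_samples, p=class_probs)

probs = group_class_distributions()            # shape (5, 10), rows sum to 1
rng = np.random.default_rng(1)
labels = sample_client_labels(probs[0], num_samples=200, rng=rng)
counts = np.bincount(labels, minlength=10)     # class histogram of one client
```

Smaller values of `alpha` give more skewed class distributions, so the groups differ more strongly from one another.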

Figure 4: Clients in the target shift setting. Each bar denotes a client. Each colour indicates one type of distribution. Samples on each client are split into a training set and a test set.
Figure 5: Class distributions on clients. Each bar denotes the class distribution on a client. Each colour corresponds to a class and the length indicates its proportion on the client.

Feature Shift: The research utilises the Digit-5 dataset to evaluate FedVC’s performance on feature-shift data. Digit-5 consists of digits from five different domains (MNIST, MNIST-M, SVHN, USPS and Synth Digits). The experiment assigns samples of each domain to six clients, where five clients participate in training the global model and one is held out for testing. Classes are evenly distributed on each client. In addition, it randomly draws samples from all domains to compose five mixed datasets for the remaining test clients. An illustration of client settings is in Figure 6.

Figure 6: Clients in the feature shift setting. Each bar denotes a client. Each colour indicates a domain. Samples of each client are split into a training set and a test set.

Models and Hyperparameters

The research applies convolutional neural networks (CNNs) as the base models and supervises the training process with virtual concepts. By default, in each communication round, ten clients are sampled to update the global model and virtual concepts, and the global model is subsequently synchronised to all clients to evaluate its performance. The learning rate of a client’s local training step is initialised as 0.005 and decays at a rate of 0.8 every 10 communication rounds. During each communication round, a client tunes the global model on its local data for two epochs with a batch size of 10.
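The learning-rate schedule can be written as a one-line step decay. The sketch assumes a zero-based round index and a staircase decay, in which the rate changes only at multiples of 10 rounds. The text does not say whether the decay is applied this way or continuously.

```python
def local_learning_rate(round_idx, base_lr=0.005, decay=0.8, step=10):
    """Step decay: multiply the base rate by `decay` once every `step` rounds."""
    return base_lr * decay ** (round_idx // step)

# Learning rates at the start of rounds 0, 10 and 25.
schedule = [local_learning_rate(r) for r in (0, 10, 25)]
```

Under these assumptions the rate is 0.005 for rounds 0 to 9, 0.004 for rounds 10 to 19, and 0.0032 for rounds 20 to 29.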

For FedVC, the default number of virtual concepts is 10 and the dimension of each virtual concept is 10. The similarity parameter $\iota$ is 0.1, and the smoothing parameter $\kappa$ is 0.05.

Baseline Methods

Several PerFL strategies are compared as baselines, including:

  • Local Only: models trained locally on each client

  • FedAvg + FT: personalisation by fine-tuning the global model on local data (Cheng, Chadha, and Duchi 2021; Collins et al. 2022)

  • FedBN: a global model with private BatchNormalisation layers (Li et al. 2021b)

  • FedProx: leverages a global model to regularise the local training process (Li et al. 2018)

  • Ditto: leverages a global model to regularise the local training process while learning a local model for each client (Li et al. 2021a)

  • FedRep: personalisation by training local classification heads (Collins et al. 2021)

  • FedDual: personalisation by training a global and a local feature extractor (Pillutla et al. 2022)

Performance

This section first reports model performance averaged over all clients, which shows that a global model learned with FedVC achieves comparable performance to other personalised FL methods that require model adaptation. Then, it examines group-wise metrics to evaluate a model’s performance on different distributions. Results show that the global model learned with FedVC is more robust to changing distributions. The learned global model can be directly deployed on test clients without extra adaptation.

Target Shift Settings

For target shift settings, the averaged accuracy, weighted AUC score and weighted F1 score are applied to evaluate model performance (see https://scikit-learn.org/stable/index.html). Table 1 and Table 2 respectively report the averaged performance over the training clients and the test clients. Figure 9 shows the group-wise performance.
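To make the "weighted" averaging concrete, the sketch below computes a support-weighted F1 score with NumPy alone. This is the averaging that scikit-learn's `f1_score(..., average="weighted")` performs. The function name and the toy labels are illustrative, and how scores are further averaged across clients is not shown.

```python
import numpy as np

def weighted_f1(y_true, y_pred, num_classes):
    """Per-class F1 averaged with weights equal to each class's true count."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    total = 0.0
    for c in range(num_classes):
        tp = np.sum((y_pred == c) & (y_true == c))
        fp = np.sum((y_pred == c) & (y_true != c))
        fn = np.sum((y_pred != c) & (y_true == c))
        denom = 2 * tp + fp + fn
        f1 = 2 * tp / denom if denom > 0 else 0.0  # F1 is taken as 0 for an absent class
        total += (tp + fn) * f1                    # weight by class support
    return total / len(y_true)

# Toy labels: one sample of class 0 is misclassified as class 1.
score = weighted_f1([0, 0, 1, 1], [0, 1, 1, 1], num_classes=2)
```

Weighting by support means that frequent classes on a client dominate the score, which matters when class distributions are skewed as in the target shift setting.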

Overall performance
Table 1 demonstrates models’ performance on the MNIST dataset on the training clients (tr-clients). A global model trained by FedVC achieves the best performance under the target shift setting. It outperforms locally fine-tuned global models (FedAvg+FT) and models with client-specific parameters (FedBN, FedProx, FedRep and FedDual). Table 2 demonstrates models’ performance on the test clients (ts-clients). All baseline methods are fine-tuned on the test clients to adapt to each client’s local distribution. The model learned by FedVC generalises well to the unseen clients, even though it is not fine-tuned. Note that locally trained models (Local Only and Ditto) cannot be generalised to unseen clients.

| Method | avg. Acc (%) ↑ | w. AUC (%) ↑ | w. F1 (%) ↑ |
|---|---|---|---|
| Local Only | 95.79 (1.00) | 99.69 (0.11) | 93.21 (0.96) |
| FedAvg+FT | 97.92 (0.98) | 99.90 (0.04) | 95.49 (1.07) |
| FedBN | 98.43 (0.86) | 99.90 (0.04) | 95.71 (0.94) |
| FedProx | 98.07 (0.90) | 99.89 (0.04) | 95.48 (0.89) |
| Ditto | 95.86 (1.19) | 99.71 (0.09) | 93.25 (1.17) |
| FedRep | 93.61 (2.55) | 99.53 (0.18) | 91.05 (2.59) |
| FedDual | 96.87 (0.99) | 99.84 (0.06) | 94.14 (1.25) |
| FedVC | 98.56 (0.56) | 99.90 (0.05) | 95.83 (0.96) |
| FedVC-sg | 98.51 (0.62) | 99.90 (0.03) | 95.84 (1.10) |

Table 1: Overall performance on the MNIST dataset on the training clients. The standard deviation of each metric is reported in parentheses. avg. is the abbreviation of ’averaged’ and w. denotes ’weighted’. The ↑ denotes that the higher the metric is, the better performance a model achieved, and the best performance is highlighted in bold.

| Method | avg. Acc (%) ↑ | w. AUC (%) ↑ | w. F1 (%) ↑ |
|---|---|---|---|
| FedAvg+FT | 98.42 (0.84) | 99.91 (0.04) | 95.71 (0.74) |
| FedBN | 98.48 (0.84) | 99.91 (0.04) | 95.64 (0.84) |
| FedProx | 98.19 (0.98) | 99.90 (0.04) | 95.53 (0.94) |
| FedRep | 88.98 (1.37) | 99.00 (0.24) | 86.11 (1.85) |
| FedDual | 97.80 (0.51) | 99.88 (0.03) | 95.08 (0.57) |
| FedVC | 98.79 (0.62) | 99.91 (0.03) | 95.97 (0.87) |
| FedVC-sg | 98.76 (0.67) | 99.91 (0.04) | 95.92 (0.99) |

Table 2: Overall performance on the MNIST dataset on the test clients. The standard deviation of each metric is reported in parentheses. avg. is the abbreviation of ’averaged’ and w. denotes ’weighted’. The ↑ denotes that the higher the metric is, the better performance a model achieved, and the best performance is highlighted in bold.

Group-wise performance
Figure 9 shows the averaged accuracy of clients within different groups, i.e., data distributions. It shows that the global model trained by FedVC is more robust among different distributions, and it generalises well to unseen distributions (client groups 4-5). Fluctuation in the learning curves indicates that the fine-tuned models (FedAvg+FT) and models with personalised parameters (FedBN, FedProx, FedDual) are slightly unstable. Locally trained models (Local Only and Ditto) and FedRep have significant performance gaps among clients.

Feature Shift Settings

This section presents evaluations on feature-shift data. Table 3 shows that FedVC achieves the best accuracy, AUC and F1 score under this setting. Other models are less robust than FedVC and their performances vary significantly among clients (higher standard deviations). Group-wise performance in Figure 10 shows that FedVC has a smaller performance gap between different domains and is more robust, as there is less fluctuation in the learning curves.

Method | avg. Acc (%) $\uparrow$ | w. AUC (%) $\uparrow$ | w. F1 (%) $\uparrow$
Local Only | 74.43 (13.60) | 93.98 (4.53) | 71.34 (13.00)
FedAvg+FT | 80.92 (9.37) | 96.75 (2.32) | 77.40 (8.87)
FedBN | 84.34 (10.61) | 97.49 (2.19) | 80.99 (10.02)
FedProx | 80.50 (11.32) | 96.46 (2.77) | 76.93 (10.82)
Ditto | 67.30 (18.50) | 92.10 (6.63) | 63.95 (18.54)
FedRep | 55.44 (20.40) | 85.55 (11.28) | 51.86 (20.74)
FedDual | 70.36 (15.59) | 93.20 (5.34) | 66.99 (15.57)
FedVC | 85.42 (8.95) | 97.55 (1.81) | 81.88 (8.53)
FedVC-sg | 85.82 (8.47) | 97.59 (1.88) | 82.27 (7.99)
Table 3: Overall performance on the Digit-5 dataset on the training clients. The standard deviation of each metric is reported in parentheses. "avg." abbreviates "averaged" and "w." denotes "weighted". $\uparrow$ indicates that a higher value of the metric means better performance; the best performance is highlighted in bold.
Method | avg. Acc (%) $\uparrow$ | w. AUC (%) $\uparrow$ | w. F1 (%) $\uparrow$
FedAvg+FT | 77.85 (7.71) | 96.26 (1.99) | 74.57 (7.42)
FedBN | 83.30 (7.05) | 97.45 (1.32) | 79.69 (6.67)
FedProx | 76.90 (7.92) | 96.19 (2.02) | 73.55 (7.31)
FedRep | 34.85 (18.75) | 73.08 (11.66) | 30.24 (18.39)
FedDual | 67.15 (11.46) | 92.74 (4.66) | 63.85 (11.62)
FedVC | 86.20 (5.62) | 97.61 (1.38) | 82.92 (5.15)
FedVC-sg | 85.10 (5.92) | 97.68 (1.39) | 81.61 (5.70)
Table 4: Overall performance on the Digit-5 dataset on the test clients. The standard deviation of each metric is reported in parentheses. "avg." abbreviates "averaged" and "w." denotes "weighted". $\uparrow$ indicates that a higher value of the metric means better performance; the best performance is highlighted in bold.

Ablation Study

This section evaluates the effectiveness of FedVC through experiments on the Digit-5 dataset. It first validates the capability of virtual concepts as supervised information for personalisation by visualising the distribution of estimated client preferences ($\hat{p}$). Then, it analyses the behaviour of the hyperparameters through ablation experiments.

Interpreting Personalisation

Figure 7 compares the latent representations learned by FedAvg and FedVC. FedVC succeeds in supervising the learning process with client preferences, so that the distribution of the estimated client preferences $\hat{p}$ is consistent with the ground-truth group knowledge, i.e., samples from the same group (colour) are closer to each other.

$\iota$ in Equation 1 is a hyperparameter that weights the importance of the difference $|\hat{z}-c|$ when estimating the client preference $\hat{p}$. Figure 8(a) shows that client preferences (colours) are unrecognisable with a model learned with a small $\iota$, i.e., $\iota=0.001$. As $\iota$ increases, the estimated $\hat{p}$ shows structure consistent with the client preferences (Figure 8(b-d)). This validates the effectiveness of the supervision from virtual concepts $c$. The superior performance of FedVC indicates that such supervision does improve the performance of a global model, and that virtual concepts are indicators that can be used to interpret personalisation.
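
The qualitative effect of $\iota$ can be illustrated with a small sketch. Equation 1 is not reproduced in this section, so the form used here, a softmax over the negative distance $|\hat{z}-c_m|$ scaled by $\iota$, is an assumption made purely for illustration; the function name `concept_weights`, the two-dimensional concepts and the representation `z` are all hypothetical.

```python
import math


def concept_weights(z, concepts, iota):
    """Soft assignment of a representation z to each virtual concept.

    ASSUMPTION: softmax over -iota * L1 distance. This stands in for
    Equation 1 of the paper, whose exact form is not shown here.
    """
    dists = [sum(abs(zi - ci) for zi, ci in zip(z, c)) for c in concepts]
    logits = [-iota * d for d in dists]
    top = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(l - top) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


concepts = [[0.0, 0.0], [4.0, 4.0], [-4.0, 4.0]]  # hypothetical concepts
z = [3.5, 4.2]                                     # hypothetical representation

flat = concept_weights(z, concepts, iota=0.001)  # nearly uniform weights
sharp = concept_weights(z, concepts, iota=1.0)   # mass on the nearest concept
```

Under this assumed form, a very small $\iota$ makes every concept almost equally likely, which matches the unstructured picture in Figure 8(a), while a larger $\iota$ lets the nearest concept dominate.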

Hyperparameters

The experiments study the behaviour of each hyperparameter by evaluating model performance under different values of the selected hyperparameter while holding the others at their default values. According to Table 5 and Table 6, model performance improves as the number and the dimension of virtual concepts increase. Table 7 shows that a larger weight for the similarity between $\hat{z}$ and $c$ generally increases model performance, which validates the effectiveness of the supervision from virtual concepts. In addition, Table 8 indicates that, under the moving average strategy, giving more weight to the newly estimated $S$, $C$ and $N$ outperforms relying on the older estimates. Table 9 suggests that $\gamma$ needs to be carefully selected to balance updating the global model against updating the virtual concepts.
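
The one-at-a-time protocol described above can be sketched as follows. The grids are copied from Tables 5 to 9. The default values are an assumption: they are guessed from the rows that repeat the same accuracy across Tables 5 to 8, and the default for `gamma` is a pure placeholder because this section does not state it. The function and dictionary names are hypothetical.

```python
def one_at_a_time(defaults, grids):
    """Yield (name, value, config), varying one hyperparameter at a time."""
    for name, values in grids.items():
        for value in values:
            config = dict(defaults)  # copy so the defaults stay untouched
            config[name] = value
            yield name, value, config


# Grids copied from Tables 5-9.
GRIDS = {
    "num_vc": [3, 6, 10],
    "dim_vc": [3, 6, 10],
    "iota": [0.001, 0.005, 0.01, 0.1],
    "kappa": [0.01, 0.05, 0.1, 0.5, 0.95],
    "gamma": [0.01, 0.1, 0.5, 0.95],
}

# ASSUMPTION: defaults inferred from rows repeated across Tables 5-8;
# the gamma default is a placeholder, not a value given in the paper.
DEFAULTS = {"num_vc": 10, "dim_vc": 10, "iota": 0.1, "kappa": 0.05, "gamma": 0.1}

runs = list(one_at_a_time(DEFAULTS, GRIDS))
```

Each generated configuration differs from the defaults in at most one entry, so any change in accuracy can be attributed to that single hyperparameter.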

# of VCs | avg. Acc (%) $\uparrow$ on training clients | avg. Acc (%) $\uparrow$ on test clients
3 | 85.22 (9.47) | 85.05 (6.79)
6 | 85.24 (9.10) | 85.15 (6.06)
10 | 85.42 (8.95) | 86.20 (5.62)
Table 5: Performance with different numbers of virtual concepts.
$d$-VC | avg. Acc (%) $\uparrow$ on training clients | avg. Acc (%) $\uparrow$ on test clients
3 | 83.46 (9.89) | 83.45 (6.95)
6 | 84.36 (9.75) | 83.80 (7.07)
10 | 85.42 (8.95) | 86.20 (5.62)
Table 6: Performance with different dimensions of virtual concepts.
$\iota$ | avg. Acc (%) $\uparrow$ on training clients | avg. Acc (%) $\uparrow$ on test clients
0.001 | 84.40 (9.30) | 85.00 (6.34)
0.005 | 84.56 (9.55) | 85.35 (6.44)
0.01 | 85.74 (9.12) | 85.75 (6.25)
0.1 | 85.42 (8.95) | 86.20 (5.62)
Table 7: Performance with different values of the similarity parameter $\iota$. The larger $\iota$ is, the more weight the difference $|\hat{z}-c|$ receives when estimating the client preference $\hat{p}$.
$\kappa$ | avg. Acc (%) $\uparrow$ on training clients | avg. Acc (%) $\uparrow$ on test clients
0.01 | 85.66 (8.51) | 85.40 (6.12)
0.05 | 85.42 (8.95) | 86.20 (5.62)
0.1 | 84.96 (9.23) | 85.45 (6.18)
0.5 | 84.44 (9.16) | 84.65 (6.25)
0.95 | 84.02 (9.51) | 83.85 (7.29)
Table 8: Performance with different values of the smoothing parameter $\kappa$. The larger $\kappa$ is, the more weight the previous estimates of $S$, $C$ and $N$ receive.
$\gamma$ | avg. Acc (%) $\uparrow$ on training clients | avg. Acc (%) $\uparrow$ on test clients
0.01 | 83.14 (10.47) | 83.40 (7.27)
0.1 | 85.24 (8.97) | 85.30 (6.86)
0.5 | 83.48 (10.63) | 82.35 (8.03)
0.95 | 85.46 (8.71) | 85.20 (6.02)
Table 9: Performance with different values of the balancing parameter $\gamma$. The larger $\gamma$ is, the more important the loss $l_p$ is to optimising the virtual concepts $c$.
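
The role of the smoothing parameter $\kappa$ from Table 8 can be made concrete with a short sketch of a moving average. Equations 6 to 8 are not reproduced in this section, so the convex-combination form below is an assumption that only follows the caption of Table 8 (a larger $\kappa$ gives more weight to the previous estimate); the function name `smooth` and the numbers are hypothetical.

```python
def smooth(previous, current, kappa):
    """Moving average of two estimates given as equal-length lists.

    ASSUMPTION: kappa weights the previous estimate and (1 - kappa)
    weights the newly computed one.
    """
    return [kappa * p + (1.0 - kappa) * c for p, c in zip(previous, current)]


previous = [1.0, 1.0]  # hypothetical old statistic
current = [0.0, 2.0]   # hypothetical new statistic

mostly_new = smooth(previous, current, kappa=0.05)  # close to `current`
mostly_old = smooth(previous, current, kappa=0.95)  # close to `previous`
```

With a small $\kappa$ the running statistic tracks the newest estimate closely, which is the regime that performs best in Table 8.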
Figure 7: Distribution of estimated client preferences. Colours indicate the client group, i.e., the domain, that samples belong to. (a) The aggregation process of vanilla FedAvg eliminates the information on client preferences, so that sample representations are mixed with respect to their domains. (b) Virtual concepts succeed in supervising the learning process with client preferences, so that the distribution of the estimated client preferences $\hat{p}$ is consistent with the domain knowledge, i.e., samples from the same domain are closer to each other.
Figure 8: Distribution of estimated client preferences with different $\iota$. The smaller $\iota$ is, the less weight the difference $|\hat{z}-c|$ receives when estimating the client preference $\hat{p}$.
Algorithm 1 FedVC

Input: communication rounds $R$, epochs in each round $E$, learning rate $\lambda$, batch size $B$, hyperparameters $\iota$, $\kappa$ and $\gamma$
Output: optimal parameters $\omega^{*}$, virtual concepts $\mathcal{C}^{*}$

1: server initialises parameters $\omega$ and virtual concepts $\mathcal{C}$
2: for $r$ from $0$ to $R$ do ▷ communication rounds
3:     server selects a set of clients $\mathbb{C}$
4:     for $k \in \mathbb{C}$ parallel do
5:         client $k$ synchronises $\omega$ and $\mathcal{C}$ from the server ▷ network traffic
6:         $\omega_k, S_m^{\prime(k)}, C_m^{\prime(k)} \leftarrow$ ClientUpdate($\omega$, $\mathcal{C}$)
7:     end for
8:     server collects local updates $\omega_k$, $S_m^{\prime(k)}$ and $C_m^{\prime(k)}$, $k \in \mathbb{C}$ ▷ network traffic
9:     $\omega \leftarrow \sum_{k \in \mathbb{C}} \alpha_k \omega_k$
10:     update $c_m \in \mathcal{C}$ by Equation 10
11: end for
12: return $\omega$, $\mathcal{C}$

ClientUpdate($\omega$, $\mathcal{C}$)

1: for each sample on the client do ▷ update client preferences $p^{(k)}$
2:     get model outputs by $\hat{y}, \hat{z} = f(x; \omega)$
3:     calculate $s_{i,m}^{(k)}$ by Equation 1
4:     update $v_m^{(k)}$ by Equation 9
5:     update client preference $p^{(k)} \leftarrow \sum_{m=1}^{M} v_m^{(k)} c_m$
6: end for
7: for $e$ from $0$ to $E$ do
8:     for $b$ from $0$ to $N_k / B$ do
9:         sample a batch of data $\mathbb{B}$
10:         $\omega \leftarrow \omega - \nabla_{\omega}(l_p + l_{cls})$ ▷ update model
11:         update $S$, $C$ and $N$ by Equations 6, 7 and 8 respectively
12:     end for
13: end for
14: return $\omega$, $S$ and $C$
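
Two steps of Algorithm 1 are simple enough to sketch directly: the server aggregation $\omega \leftarrow \sum_{k} \alpha_k \omega_k$ and the client preference $p^{(k)} \leftarrow \sum_{m} v_m^{(k)} c_m$. The sketch below is illustrative only. Parameters are represented as flat lists, the function names are hypothetical, and setting $\alpha_k$ proportional to the client's sample count is an assumption borrowed from common FedAvg practice, since this section does not define $\alpha_k$.

```python
def aggregate(client_params, alphas):
    """Server step: omega <- sum_k alpha_k * omega_k (flat parameter lists)."""
    dim = len(client_params[0])
    return [sum(a * w[i] for a, w in zip(alphas, client_params)) for i in range(dim)]


def client_preference(v, concepts):
    """Client step: p^(k) <- sum_m v_m^(k) * c_m, a mixture of virtual concepts."""
    dim = len(concepts[0])
    return [sum(vm * c[i] for vm, c in zip(v, concepts)) for i in range(dim)]


# ASSUMPTION: alpha_k proportional to hypothetical client sample counts.
sizes = [100, 300]
alphas = [n / sum(sizes) for n in sizes]

omega = aggregate([[1.0, 0.0], [0.0, 2.0]], alphas)
p = client_preference([0.5, 0.5], [[0.0, 2.0], [4.0, 0.0]])
```

The update of the virtual concepts themselves (Equation 10) is deliberately left out, because its definition lies outside this section.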
Algorithm 2 FedVC-unified

Input: communication rounds $R$, epochs in each round $E$, learning rate $\lambda$, batch size $B$, hyperparameters $\iota$, $\kappa$ and $\gamma$
Output: optimal parameters $\omega^{*}$, virtual concepts $\mathcal{C}^{*}$

1: server initialises parameters $\omega$ and virtual concepts $\mathcal{C}$
2: for $r$ from $0$ to $R$ do ▷ communication rounds
3:     server selects a set of clients $\mathbb{C}$
4:     for $k \in \mathbb{C}$ parallel do
5:         client $k$ synchronises $\omega$ and $\mathcal{C}$ from the server ▷ network traffic
6:         $\omega_k, \mathcal{C}_k \leftarrow$ ClientUpdate($\omega$, $\mathcal{C}$)
7:     end for
8:     server collects local updates $\omega_k, \mathcal{C}_k$, $k \in \mathbb{C}$ ▷ network traffic
9:     $\omega \leftarrow \sum_{k \in \mathbb{C}} \alpha_k \omega_k$
10:     $c_m \leftarrow \sum_{k \in \mathbb{C}} \alpha_k c_m^{(k)}$, $c_m^{(k)} \in \mathcal{C}_k$
11: end for
12: return $\omega$, $\mathcal{C}$

ClientUpdate($\omega$, $\mathcal{C}$)

1: for each sample on the client do ▷ update client preferences $p^{(k)}$
2:     get model outputs by $\hat{y}, \hat{z} = f(x; \omega)$
3:     calculate $s_{i,m}^{(k)}$ by Equation 1
4:     update $v_m^{(k)}$ by Equation 9
5:     update client preference $p^{(k)} \leftarrow \sum_{m=1}^{M} v_m^{(k)} c_m$
6: end for
7: for $e$ from $0$ to $E$ do
8:     for $b$ from $0$ to $N_k / B$ do
9:         sample a batch of data $\mathbb{B}$
10:         $\omega \leftarrow \omega - \nabla_{\omega} \mathcal{L}_k$
11:         $c_m \leftarrow c_m - \nabla_{c} \mathcal{L}_k$, $c_m \in \mathcal{C}$
12:     end for
13: end for
14: return $\omega$, $\mathcal{C}$
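
The control flow of Algorithm 2, in which clients update both $\omega$ and the concepts by gradient steps and the server averages both, can be sketched on a toy problem. Everything specific here is an assumption: the quadratic loss stands in for $\mathcal{L}_k$, which is defined outside this section; the model is a single scalar and each concept is one-dimensional; the step size `lr` plays the role of the learning rate $\lambda$ listed among the inputs; and all names and numbers are hypothetical.

```python
def client_update(omega, concepts, target_w, target_c, lr=0.5, steps=1):
    """Joint local gradient steps on omega and every concept c_m.

    ASSUMPTION: toy loss
        L_k = 0.5 * (omega - target_w)^2 + 0.5 * sum_m (c_m - target_c[m])^2
    used only so that the gradients are easy to read off.
    """
    concepts = list(concepts)  # do not mutate the server's copy
    for _ in range(steps):
        omega -= lr * (omega - target_w)  # gradient step on omega
        concepts = [c - lr * (c - t) for c, t in zip(concepts, target_c)]
    return omega, concepts


def server_round(omega, concepts, clients, alphas):
    """One communication round: local updates, then weighted averaging."""
    updates = [client_update(omega, concepts, tw, tc) for tw, tc in clients]
    new_omega = sum(a * w for a, (w, _) in zip(alphas, updates))
    new_concepts = [
        sum(a * cs[m] for a, (_, cs) in zip(alphas, updates))
        for m in range(len(concepts))
    ]
    return new_omega, new_concepts


# Two hypothetical clients, each with its own toy optimum.
clients = [(2.0, [1.0, -1.0]), (4.0, [3.0, 1.0])]
omega, concepts = server_round(0.0, [0.0, 0.0], clients, alphas=[0.5, 0.5])
```

In contrast to Algorithm 1, no separate server-side concept update is needed in this variant: the concepts are treated like ordinary parameters and averaged with the same weights $\alpha_k$.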
Figure 9: Group-wise accuracy on MNIST. The horizontal axis denotes communication rounds and the vertical axis denotes accuracy. Each colour corresponds to a client group, i.e., a data distribution. The shaded area indicates the standard deviation of accuracy among clients in the group.
Figure 10: Group-wise accuracy on Digit-5. The horizontal axis denotes communication rounds and the vertical axis denotes accuracy. Each colour corresponds to a client group, i.e., a data distribution. The shaded area indicates the standard deviation of accuracy among clients in the group.

Conclusions

This research proposes to use virtual concepts as client supervision information to learn a robust global model and to interpret the non-IID data across clients. Specifically, the proposed FedVC interprets each client's preferences as a mixture of conceptual vectors, each of which represents an interpretable concept to end-users. These conceptual vectors can be learnt via the optimisation procedure of the federated learning system. In addition to interpretability, the clarity of client-specific personalisation can also be used to enhance the robustness of the training process of the FL system. The effectiveness of the proposed methods has been validated on benchmark datasets.

References

  • Cheng, Chadha, and Duchi (2021) Cheng, G.; Chadha, K.; and Duchi, J. 2021. Fine-tuning is Fine in Federated Learning. arXiv preprint arXiv:2108.07313.
  • Collins et al. (2021) Collins, L.; Hassani, H.; Mokhtari, A.; and Shakkottai, S. 2021. Exploiting Shared Representations for Personalized Federated Learning. In Meila, M.; and Zhang, T., eds., Proceedings of the 38th International Conference on Machine Learning, volume 139 of Proceedings of Machine Learning Research, 2089–2099. PMLR.
  • Collins et al. (2022) Collins, L.; Hassani, H.; Mokhtari, A.; and Shakkottai, S. 2022. Fedavg with fine tuning: Local updates lead to representation learning. Advances in Neural Information Processing Systems, 35: 10572–10586.
  • Hsu, Qi, and Brown (2019) Hsu, T.-M. H.; Qi, H.; and Brown, M. 2019. Measuring the Effects of Non-Identical Data Distribution for Federated Visual Classification. arXiv:1909.06335.
  • Li et al. (2021a) Li, T.; Hu, S.; Beirami, A.; and Smith, V. 2021a. Ditto: Fair and robust federated learning through personalization. In International Conference on Machine Learning, 6357–6368. PMLR.
  • Li et al. (2018) Li, T.; Sahu, A. K.; Zaheer, M.; Sanjabi, M.; Talwalkar, A.; and Smith, V. 2018. Federated optimization in heterogeneous networks. arXiv preprint arXiv:1812.06127.
  • Li et al. (2021b) Li, X.; Jiang, M.; Zhang, X.; Kamp, M.; and Dou, Q. 2021b. Fedbn: Federated learning on non-iid features via local batch normalization. arXiv preprint arXiv:2102.07623.
  • Pillutla et al. (2022) Pillutla, K.; Malik, K.; Mohamed, A.; Rabbat, M.; Sanjabi, M.; and Xiao, L. 2022. Federated Learning with Partial Model Personalization. arXiv preprint arXiv:2204.03809.
  • Van Den Oord, Vinyals et al. (2017) Van Den Oord, A.; Vinyals, O.; et al. 2017. Neural discrete representation learning. Advances in neural information processing systems, 30.