License: CC BY-NC-SA 4.0
arXiv:2307.13716v4 [cs.LG] 19 Mar 2024

FedDRL: A Trustworthy Federated Learning Model Fusion Method Based on Staged Reinforcement Learning

Leiming Chen (Corresponding author: [email protected]), China University of Petroleum (East China), China
Weishan Zhang, China University of Petroleum (East China), China
Cihao Dong, China University of Petroleum (East China), China
Sibo Qiao, China University of Petroleum (East China), China
Ziling Huang, China University of Petroleum (East China), China
Yuming Nie, China University of Petroleum (East China), China
Zhaoxiang Hou, Digital Research Institute, ENN Group, China
Chee Wei Tan (Corresponding author: [email protected]), Nanyang Technological University, Singapore

Abstract

Federated learning facilitates collaborative data analysis among multiple participants while preserving user privacy. However, conventional federated learning approaches, which typically employ weighted averaging for model fusion, confront two significant challenges: (1) The inclusion of malicious models in the fusion process can drastically undermine the accuracy of the aggregated global model. (2) Because devices and data are heterogeneous, the number of client samples alone should not determine the weight of a model. To address these challenges, we propose a trustworthy model fusion method based on reinforcement learning (FedDRL), which includes two stages. In the first stage, we propose a reliable client selection mechanism to exclude malicious models from the fusion process. In the second stage, we propose an adaptive model fusion method that dynamically assigns weights based on model quality to aggregate the best global model. Finally, we validate our approach against five distinct model fusion scenarios, demonstrating that our algorithm significantly enhances reliability without compromising accuracy.

1 Introduction

With the advent of deep learning technologies, various industries have been integrating these technologies into their sectors, promoting the development of intelligent transportation, smart logistics, and healthcare systems. These technologies are crucial in reducing production and management costs, enhancing operational efficiency, and accelerating industry digitization. However, supervised learning remains the primary method for training deep learning models, where the volume and diversity of samples are essential for creating high-quality models. Consequently, acquiring extensive and varied data samples has become the first step in training deep learning models. This has led to sample sources expanding from single industries to collaborations across multiple sectors to build large-scale datasets. To enable multi-party joint data analysis while protecting data security and privacy, Google first proposed federated learning. Although federated learning addresses user privacy protection, traditional federated learning algorithms assume that all participants are trustworthy. In practice, however, if participants behave maliciously and intentionally contribute harmful models to the fusion process, they can significantly disrupt the global model’s convergence. Thus, creating adaptive defenses for federated learning systems becomes increasingly crucial [9]. Identifying methods to remove malicious models in federated learning model fusion has become a critical issue. Likewise, when a client submits a low-quality model for fusion, determining how to adaptively adjust each model’s fusion weight based on its quality is also an urgent problem. Some studies have applied reinforcement learning techniques to address these weighting issues. For instance, the Favor [10] method uses the DDPG algorithm to assign weights to participant models. Additional research has applied reinforcement learning to address device selection [11] [12], resource optimization [13] [14], and communication optimization in IoT federated learning contexts.

Reinforcement learning (RL) employs a trial-and-error strategy. The essence of this approach is training an intelligent agent that interacts with the external environment through varied actions. The environment then provides feedback in the form of rewards and penalties based on the agent’s actions, guiding the agent toward optimal action selection by maximizing reward value. However, employing reinforcement learning presents certain challenges. Firstly, continuous training is required for sample collection through environmental interaction. When the cost of such interactions is prohibitive or unacceptable (for example, in our scenario, where the server must frequently calculate the global model’s parameters), the efficiency of sample collection significantly impacts the reinforcement learning training duration. Secondly, when the agent’s action space is vast and continuous, it leads to prolonged sampling periods. These issues mean traditional single-agent reinforcement learning training approaches can be exceedingly time-consuming. Applying reinforcement learning in federated learning requires addressing these problems, as increasing participant numbers escalates agent training time. Therefore, optimizing the action space for reinforcement learning to expedite the agent training process is an essential challenge to address.

Why opt for phased reinforcement learning? An example illustrates the point. Consider a robot learning to cook through reinforcement learning, with the process divided into washing, chopping, and cooking stages. The robot must master each stage to prepare a successful dish. Traditional reinforcement learning aims to identify the optimal action across all stages simultaneously; however, mastering the initial stage is essential before progressing. By adopting a phased learning approach, the robot sequentially masters each stage, streamlining the learning process and leading to more effective outcomes. Similarly, if malicious models are not initially filtered out, the agent’s trial-and-error costs in weight assignment for these models will increase. To resolve these issues, we propose a staged reinforcement learning algorithm (FedDRL). The contributions of the paper are as follows.

  • We design a federated learning framework that employs reinforcement learning for model fusion, designed to select trustworthy clients and optimally assign model weights.

  • We propose an adaptive client selection strategy based on the A2C algorithm, dynamically identifying and selecting trustworthy clients while excluding malicious ones from the model fusion process based on situational analysis.

  • We propose an adaptive weight assignment method that adjusts the model fusion weights according to the quality of each client’s uploaded model.

  • We present five types of model fusion scenarios to validate the performance of each algorithm. We also compare the performance of our algorithm with the baseline algorithms on three public datasets.

2 Related Work

2.1 Federated Learning

Research in federated learning primarily aims to address two challenges: enhancing the generalization of the global model on the server side and personalizing the model on the client side. Consequently, federated learning algorithms are divided into server-side and client-side optimization strategies. Google initially introduced the FedAvg algorithm [2] to address the problem of server-side global model fusion. To improve global model convergence, Karimireddy et al. developed the Scaffold method [3], which mitigates client-side drift by integrating a control variable. Similarly, Li et al. introduced FedProx [4], applying a regularization function to client models to correct deviations. Additionally, Wang et al. proposed FedNova [5], addressing global model convergence issues by normalizing parameters on both client and server ends. Furthermore, Li et al. introduced the MOON [6] technique, leveraging model-contrastive learning to enhance global model convergence. Chen et al. [29] also proposed a client identification method based on model parameter features to achieve trustworthy federated learning.

While those approaches enhance the global model’s convergence speed, practical federated learning situations reveal variances in the quality of models trained by individual participants. These discrepancies stem from the diversity in computational resources and the calibre of data samples available to each participant. Additionally, variations arise due to the quantity and type of samples possessed by each participant, a phenomenon known as Non-IID (Non-Independent and Identically Distributed). Consequently, these factors complicate the attainment of optimal global model aggregation in the Non-IID environments.

2.2 Challenges of Non-IID Data Distribution

The Non-IID data issue significantly impacts the convergence of federated learning models. Zhao et al. explored the performance of various federated learning methods on Non-IID datasets, demonstrating significant accuracy challenges [15]. Accordingly, several studies have addressed the Non-IID dilemma in federated learning. For instance, Zhang et al. proposed the FedPD approach [16], optimizing models and communication for non-convex objective functions. Moreover, Gong et al. introduced AutoCFL [17], utilizing a weighted voting client clustering strategy to mitigate the effects of Non-IID and imbalanced data. Huang et al. developed FedAMP [18], which addresses client-side model personalization issues induced by Non-IID data through personalized model updates. Li et al. devised FedBN [19], incorporating a batch normalization layer into local models to address feature shift challenges due to data heterogeneity. Briggs et al. suggested a hierarchical clustering method (FL+HC) [1], improving model performance on Non-IID datasets by grouping clients for independent model training. Additionally, Gao et al. offered the FedDC approach [20], bridging client and global model parameter disparities through a control variable. Mu et al. introduced FedProc [21], directing client model training by integrating a contrastive loss between client and global models. Lastly, Chen et al. [7] proposed a federated learning method based on adaptive knowledge distillation to improve accuracy in heterogeneous model scenarios.

Although these methodologies advance Non-IID issue mitigation in federated learning, they typically assign uniform fusion weights to all clients, failing to exclude malicious or low-quality model contributions. Consequently, dynamically selecting clients for fusion and adaptively calculating each model’s weight remains critical for successful global model integration.

2.3 Federated Reinforcement Learning

Given the adaptive learning potential of reinforcement learning, its application within federated learning contexts has garnered interest. Some research has concentrated on leveraging reinforcement learning to boost global model performance. For instance, Wang et al. introduced the Favor method [10], which adaptively selects clients for model fusion. Sun et al. developed the PG-FFL framework [22], addressing the challenge of client weight computation during model fusion. Additional studies have applied reinforcement learning for device optimization within federated IoT frameworks. For example, Zhang et al. utilized the DDPG algorithm [11] for optimal device selection. Zhang also formulated the FedMarl strategy [23], employing multi-agent reinforcement learning for node selection. Similarly, Yang et al. proposed a digital twin architecture (DTEI) [12], applying reinforcement learning to device selection. Other investigations have addressed resource optimization and scheduling challenges within IoT contexts, such as the RoF methodology of Zhang et al. [13], which leverages multi-agent reinforcement learning for optimal resource scheduling. Additionally, Rjoub et al. have developed trusted device selection techniques [24] and the DDQN-Trust method [14], utilizing Q-learning to assess devices’ credit scores for optimal scheduling. To ameliorate federated learning communication issues, Yang et al. introduced a reinforcement learning-based model evaluation method [25], selecting optimal devices for training and fusion. Nevertheless, these efforts predominantly focus on IoT applications (device selection, resource optimization, and communication enhancement) and seldom address the model weight calculation challenges of federated learning. To address this, Zhang et al. proposed the $R^{2}$Fed framework [27], employing the DDPG reinforcement learning method for adaptive client weight calculation. Chen et al. [28] designed a task platform for implementing trustworthy federated learning.

Although current research addresses the issue of weight allocation in federated learning, it often neglects the training efficiency of the agents. Therefore, optimizing the training efficiency of agents is a significant challenge that needs attention.

3 Method

3.1 Problem Definition

In this section, we scrutinize the prevailing challenges of the current federated learning approach and subsequently propose a solution. In federated learning, the objective is to get the global model by amalgamating local models from all clients through server-side aggregation. We define n clients as involved in model fusion, and the client is denoted as $C_i$, where $C_i \in \{C_1, C_2, C_3, \ldots, C_n\}$. Each client has a network model $M_i$, where $M_i \in \{M_1, M_2, M_3, \ldots, M_n\}$. Each client has its private data $D_i$, where $D_i \in \{D_1, D_2, D_3, \ldots, D_n\}$. The number of samples in each dataset is $S_i$, where $S_i \in \{S_1, S_2, S_3, \ldots, S_n\}$. The total number of samples is $\sum_{i=1}^{N} S_i$. We define the $\theta_i$ as a model parameter of $M_i$, where $\theta_i \in \{\theta_1, \theta_2, \theta_3, \ldots, \theta_n\}$.

Additionally, the server-side model aggregation process per round is defined as shown in equation 1:

$$\theta_{\text{global}} = \sum_{i=1}^{N} w_i \theta_i, \quad \text{where } w_i = \frac{S_i}{\sum_{i=1}^{N} S_i}, \quad w_i \geq 0, \quad \sum_{i=1}^{N} w_i = 1 \tag{1}$$

The $w_i$ is the fusion weight of each model parameter.
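As a minimal sketch of the aggregation in equation 1, the following Python snippet fuses client parameters using sample-size weights. It assumes each model's parameters are flattened into a plain list of floats; the function name `fuse_by_sample_size` and the toy values are hypothetical and chosen only for illustration.

```python
from typing import List, Sequence


def fuse_by_sample_size(
    client_params: Sequence[Sequence[float]],
    sample_counts: Sequence[int],
) -> List[float]:
    """Weighted average of client parameters with w_i = S_i / sum(S) (equation 1)."""
    if len(client_params) != len(sample_counts) or not client_params:
        raise ValueError("need one sample count per client model")
    total = sum(sample_counts)
    if total <= 0:
        raise ValueError("total number of samples must be positive")
    weights = [s / total for s in sample_counts]  # w_i >= 0 and sum(w_i) == 1
    dim = len(client_params[0])
    fused = [0.0] * dim
    for w, params in zip(weights, client_params):
        if len(params) != dim:
            raise ValueError("all clients must share the same parameter shape")
        for j, value in enumerate(params):
            fused[j] += w * value
    return fused


# Three toy clients; the third holds half of all samples and dominates the result.
theta_global = fuse_by_sample_size(
    [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]],
    [25, 25, 50],
)
print(theta_global)  # [1.25, 1.25]
```

Note that the weights depend only on the sample counts, so a client with many samples keeps a large weight regardless of how good its model is.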

Traditional Federated Learning typically employs a weighted average approach for computing model fusion weights, with each model’s weight determined by its corresponding client’s data sample size relative to the total. Thus, clients contributing more data exert a greater influence on the aggregated model. However, this method fails to consider the quality of each client’s model and the potential inclusion of malicious models in real-world scenarios. We illustrate the deficiencies of the traditional federated fusion algorithm through two scenarios:

Scenario 1: A client’s data represents 20% of the total, yet its model’s accuracy is merely 53%. Employing the conventional federated fusion algorithm in this case would detrimentally impact the global model’s accuracy.

Scenario 2: A client engaged in model fusion launches malicious attacks, intentionally skewing its model’s output to reflect a mere 10% accuracy. If such malicious models are incorporated through the standard fusion process, the accuracy of the global model would be severely compromised.

Addressing these challenges necessitates an adaptive weight calculation strategy capable of nullifying malicious models by assigning them a weight of zero, thus excluding them from the fusion process. Concurrently, this approach should dynamically adjust the weights of each client’s model, prioritizing those of higher quality to enhance the global model’s overall accuracy.

Adopting a single-agent reinforcement learning strategy to tackle these issues introduces new challenges. As the number of clients increases, so too does the agent’s action space, prolonging the training duration. Additionally, a single-agent framework is limited to interacting with just one environment, further extending the sampling period. We propose a bifurcated solution inspired by hierarchical reinforcement learning to mitigate these concerns, thereby streamlining the lengthy reinforcement learning training process. This solution comprises two primary stages: the selection of trustworthy clients and the assignment of optimal weights.

Figure 1: The Process of FedDRL framework

Stage 1: During this phase, the objective is to identify K trustworthy models from a pool of N for inclusion in the global model fusion. Identifying clients who have uploaded malicious models is challenging. We address this by employing reinforcement learning to dynamically select and autonomously screen client models, as delineated in equation 2.

$$\{M_a, M_b, \ldots, M_k\} \leftarrow \text{SelectTrustworthyModel}\left(\{M_1, M_2, \ldots, M_n\}\right) \tag{2}$$

Stage 2: Building on the first step, we then allocate optimal weights to the verified models to bolster the global model’s accuracy, formalized in equation 3.

$$\{W_1, W_2, \ldots, W_n\} \leftarrow \operatorname{AdaptCalculateWeight}\left(\{M_a, M_b, \ldots, M_k\}\right) \tag{3}$$

Here, $\operatorname{AdaptCalculateWeight}(\cdot)$ signifies a method for adaptive weight computation, and $W_i$ represents the optimal computational weight assigned to each client’s output.
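The two stages can be pictured as a select-then-weight pipeline. The sketch below is only an illustration of that data flow: the reinforcement learning agents are replaced by simple stand-ins (a fixed accuracy threshold for Stage 1 and accuracy-proportional weights for Stage 2), which are assumptions and not the method proposed in this paper. The function names, the per-client validation accuracies, and the threshold of 0.5 are all hypothetical.

```python
from typing import Dict, List


def select_trustworthy_models(
    accuracies: Dict[str, float], threshold: float = 0.5
) -> List[str]:
    """Stage 1 stand-in: keep clients whose validation accuracy passes a threshold.

    The paper trains an RL agent for this step; the fixed threshold here is
    only an illustrative assumption.
    """
    return [cid for cid, acc in accuracies.items() if acc >= threshold]


def adapt_calculate_weight(
    accuracies: Dict[str, float], selected: List[str]
) -> Dict[str, float]:
    """Stage 2 stand-in: weights proportional to accuracy among selected clients.

    Excluded clients implicitly receive weight zero; the remaining weights
    are non-negative and sum to one.
    """
    total = sum(accuracies[cid] for cid in selected)
    if total <= 0:
        raise ValueError("no selected client has positive accuracy")
    return {cid: accuracies[cid] / total for cid in selected}


# Hypothetical per-client validation accuracies; client "c3" mimics Scenario 2.
acc = {"c1": 0.90, "c2": 0.60, "c3": 0.10}
trusted = select_trustworthy_models(acc)
weights = adapt_calculate_weight(acc, trusted)
print(trusted)
print(weights)
```

Splitting the decision this way means the weighting step only has to search over the K selected clients rather than all N, which is the motivation for staging given above.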

3.2 A Trustworthy Federated Learning Approach Based on Staged Reinforcement Learning

To address these challenges, we introduce a trusted federated learning framework anchored in staged reinforcement learning (FedDRL). This framework unfolds across two distinct phases. In the first phase, we propose an adaptive client selection strategy aimed at identifying and selecting trustworthy clients for participation in model fusion. Subsequently, in the second phase, we formulate a model weight assignment algorithm designed to dynamically allocate fusion weight values to models based on the prevailing fusion environment. The process is depicted in Figure 1.

3.2.1 Adaptive client selection method

Once the basic elements of reinforcement learning are defined, we use a distributed A2C approach to train the agent; A2C is an improved method based on the A3C algorithm [26]. Figure 1 shows the A2C architecture, which consists of a central node and $K$ workers. Each worker contains an Actor and a Critic network, where the Actor network generates actions, and the Critic network evaluates each action and gives the corresponding reward. Meanwhile, each worker independently interacts with its own environment to sample data and train its Actor and Critic networks. In addition, the Actor and Critic networks of the central node are used to synchronize the network information of each worker and to fuse and share the network parameters of multiple workers.

Therefore, our main objective is to train Actor and Critic networks. We define the Actor-network parameters as $\pi(\theta)$ and the Critic network parameters as $V(w)$. The process of the worker and the central node is as follows.

Step 1: Each worker initializes the local network by pulling the global network model parameters from the centre node. Then, each worker trains the Actor and Critic networks by interacting with the environment independently. Finally, the two networks are uploaded to the central node.

Step 2: After the central node collects the network parameters uploaded by all workers, it updates the global model by the weighted averaging method. Then, the server sends the two networks to each worker.

Steps 1 and 2 are repeated for the specified total number of rounds to obtain the final global model.

The training process for the networks in Step 1 is as follows. The gradient of the policy (Actor) network is calculated as in equation 4.

$$\nabla_{\theta} J(\theta) = \nabla_{\theta} \log \pi\left(a_t \mid s_t; \theta\right) A\left(s_t, a_t; w\right) \tag{4}$$

Here, $A\left(s_t, a_t; \theta_v\right)$ is the advantage function. The k-step sampling strategy is used in the A2C algorithm to calculate the advantage function, so the definition is expressed as equation 5.

$$A\left(s_t, a_t; \theta_v\right) = \sum_{i=0}^{k-1} \gamma^{i} r_{t+i} + \gamma^{k} V\left(s_{t+k}; w\right) - V\left(s_t; w\right) \tag{5}$$
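To make the k-step advantage of equation 5 concrete, here is a small Python sketch. The function name `k_step_advantage`, the discount factor, and the toy numbers are assumptions for illustration; the two value arguments stand for the Critic's estimates $V(s_t; w)$ and $V(s_{t+k}; w)$.

```python
from typing import Sequence


def k_step_advantage(
    rewards: Sequence[float],
    value_start: float,
    value_end: float,
    gamma: float = 0.99,
) -> float:
    """A = sum_{i=0}^{k-1} gamma^i * r_{t+i} + gamma^k * V(s_{t+k}) - V(s_t)."""
    k = len(rewards)
    discounted = sum((gamma ** i) * r for i, r in enumerate(rewards))
    return discounted + (gamma ** k) * value_end - value_start


# Toy numbers: k = 3 rewards, with critic estimates for s_t and s_{t+k}.
adv = k_step_advantage([1.0, 0.0, 2.0], value_start=1.0, value_end=4.0, gamma=0.5)
print(adv)  # 1.0
```

With these values the discounted rewards sum to 1.5, the bootstrapped tail contributes 0.125 × 4 = 0.5, and subtracting the baseline of 1.0 leaves an advantage of 1.0.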

The loss gradient of the Actor network is calculated as in equation 6, and that of the Critic network is calculated as in equation 7.

$$\nabla_{\theta} J(\theta) = \nabla_{\theta} \log \pi\left(a_t \mid s_t; \theta\right) \left( \sum_{i=0}^{k-1} \gamma^{i} r_{t+i} + \gamma^{k} V\left(s_{t+k}; w\right) - V\left(s_t; w\right) \right) \tag{6}$$

$$\nabla_{w} J(w) = \nabla_{w} \left( \sum_{i=0}^{k-1} \gamma^{i} r_{t+i} + \gamma^{k} V\left(s_{t+k}; w\right) - V\left(s_t; w\right) \right)^{2} \tag{7}$$
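The following sketch illustrates the shape of the gradients in equations 6 and 7 for a deliberately simple case. It assumes a softmax policy parameterized directly by one logit per action (rather than a neural network), and it treats the k-step target as a constant when differentiating the Critic term, which is a common simplification and not something stated in this paper. All function names are hypothetical.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def actor_gradient(
    logits: Sequence[float], action: int, advantage: float
) -> List[float]:
    """Gradient of log pi(a|s) w.r.t. the logits, scaled by the advantage.

    Uses d log pi(a) / d logit_j = 1[j == a] - pi_j for a softmax policy.
    """
    probs = softmax(logits)
    return [
        advantage * ((1.0 if j == action else 0.0) - p)
        for j, p in enumerate(probs)
    ]


def critic_gradient(k_step_target: float, value: float) -> float:
    """d/dV of (target - V)^2 with the target held fixed: -2 * (target - V)."""
    return -2.0 * (k_step_target - value)


# Uniform policy over two actions, positive advantage for action 0.
g_actor = actor_gradient([0.0, 0.0], action=0, advantage=2.0)
g_critic = critic_gradient(k_step_target=3.0, value=1.0)
print(g_actor, g_critic)  # [1.0, -1.0] -4.0
```

A positive advantage pushes probability toward the chosen action and away from the others, while the Critic gradient points in the direction that moves $V(s_t)$ toward the k-step target when followed downhill.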

We update the Actor and Critic network parameters using the derivative formula as equation 8.

$$w \leftarrow w + \nabla_{w} J(w), \quad \theta \leftarrow \theta + \nabla_{\theta} J(\theta) \tag{8}$$

Finally, each worker uploads its Actor and Critic networks to the server. Then, the network parameters of the server are calculated using the weighted average method, as given in equation 9.

$$w_{global} = \frac{1}{n} \sum_{i=1}^{n} w_i, \quad \theta_{global} = \frac{1}{n} \sum_{i=1}^{n} \theta_i, \quad i \in [1, n] \tag{9}$$
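Equation 9 amounts to an element-wise mean over the workers' parameters. A minimal sketch follows, assuming each worker's Actor and Critic parameters are flattened into lists of equal length; the function name `average_parameters` and the toy values are hypothetical.

```python
from typing import List, Sequence


def average_parameters(worker_params: Sequence[Sequence[float]]) -> List[float]:
    """Element-wise mean over workers: (1/n) * sum_i params_i (equation 9)."""
    n = len(worker_params)
    if n == 0:
        raise ValueError("at least one worker is required")
    dim = len(worker_params[0])
    if any(len(p) != dim for p in worker_params):
        raise ValueError("all workers must share the same parameter shape")
    return [sum(p[j] for p in worker_params) / n for j in range(dim)]


# Hypothetical flattened Actor (theta) and Critic (w) parameters from 3 workers.
actor_workers = [[0.0, 3.0], [1.0, 3.0], [2.0, 0.0]]
critic_workers = [[1.0], [2.0], [3.0]]
theta_global = average_parameters(actor_workers)
w_global = average_parameters(critic_workers)
print(theta_global, w_global)  # [1.0, 2.0] [2.0]
```

Each worker contributes equally here, which matches the $\frac{1}{n}$ factor in equation 9.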

When the parameters of the Actor and Critic networks at the central node are updated, the central node sends these two networks to all workers, and each worker uses the updated networks to continue interacting with the external environment. The process is repeated for the specified number of rounds until the agent at the central node obtains a stable reward value. The process is shown in Algorithm 1.

Algorithm 1 The process of trustworthy client selection
1: Input: Client Models $\{m_1^t, m_2^t, m_3^t, \ldots, m_n^t\}$, Round $T$, Worker Number $K$, Sampling Step Length $S$
2: Output: Chosen Credible Client Model List $M = \{m_2^t, m_3^t, \ldots, m_k^t\}$
3: /* Each Worker Training Step */
4: worker $(\theta, w) \leftarrow \mathrm{GetGlobalParameter}(\theta_{global}, w_{global})$
5: The Client Uploads the Current Epoch Model, Turn to State $s_0$, $t_{start} = t = 1$
6: for $e$ from 1 to $S$ do
7:     According to Current State $s_0$ Randomly Choose Action $a_t$
8:     $s_t, a_t, r, s_{t+1} \leftarrow \mathrm{Step}(a_t)$ // Execute Action $a_t$ to Acquire Reward $r$ and Next State $s_{t+1}$
9:     $t_{start} = t_{start} + 1$
10:     if $s_t$ != terminal: $R \leftarrow V(s_t; w)$ else: $R = 0$
11: for $i \in \{t-1, \ldots, t_{start}\}$ do
12:     $R \leftarrow r_i + \gamma R$ // Compute Target TD
13:     $\nabla_{\theta} J(\theta) = \nabla_{\theta} \log \pi_{\theta}(a_t \mid s_t)\,(R - v(s_i; w))$ // Compute Strategy Gradient
14:     $\nabla_{w} J(w) = \nabla_{w} (R - v(s_i; w))^2$ // Compute Critic Network Gradient
15:     Update Actor Network Parameters: $\theta \leftarrow \theta + \nabla_{\theta} J(\theta)$
16:     Update Critic Network Parameters: $w \leftarrow w + \nabla_{w} J(w)$
17: /* Center Node Process */
18: for $round$ from 1 to $T$ do
19:     for $\mathrm{worker}_i$ from 1 to $K$ do
20:         Receive Each Worker Parameters $(\theta, w)$
21:         $\mathrm{Global}(\theta_{global}, w_{global}) \leftarrow \mathbf{Agg}(\{(\theta_1, w_1), (\theta_2, w_2), \ldots\})$ // Aggregate Parameters
22:         $\mathrm{worker}_i \leftarrow \mathrm{SendGlobal}(\theta_{global}, w_{global})$ // Send New Parameters to Worker
23: /* Results Process */
24: Output Trusted Client Model List $M = \{m_1^t, m_2^t, \ldots, m_k^t\}$
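
The backward return computation in the worker step of Algorithm 1 ($R \leftarrow r_i + \gamma R$, followed by the term $R - v(s_i; w)$) can be sketched as follows. This is a minimal illustration, assuming the rewards and value estimates are already available as plain floats. The function names `n_step_returns` and `advantages` and all numeric values are hypothetical and not taken from the paper.

```python
from typing import List, Sequence


def n_step_returns(rewards: Sequence[float], bootstrap_value: float, gamma: float) -> List[float]:
    """Compute R_i = r_i + gamma * R_{i+1}, walking backwards through the rollout.

    bootstrap_value plays the role of V(s_t; w) for a non-terminal final state
    and should be 0.0 when the final state is terminal.
    """
    returns = [0.0] * len(rewards)
    running = bootstrap_value
    for i in reversed(range(len(rewards))):
        running = rewards[i] + gamma * running
        returns[i] = running
    return returns


def advantages(returns: Sequence[float], values: Sequence[float]) -> List[float]:
    """The factor (R - v(s_i; w)) that scales the policy gradient at each step."""
    return [r - v for r, v in zip(returns, values)]


# Hypothetical three-step rollout that ends in a terminal state.
rollout_returns = n_step_returns([1.0, 0.0, 2.0], bootstrap_value=0.0, gamma=0.5)
rollout_advantages = advantages(rollout_returns, [1.0, 1.0, 1.0])
```

In this sketch the squared advantage is also the per-step Critic loss, which matches the Critic gradient line of the algorithm.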

3.2.2 Adaptive model weight calculation method

In this phase, our main objective is to find the optimal weight assignment for each model. For each communication round, we assume that $K$ trustworthy client models have been selected. We train the agent in each communication round and use the weights output by the agent to perform the global model fusion. We first describe the process of global model fusion based on the agent's actions. We define $\theta_i$ as the $i$-th client model and the set of all client models as $\{\theta_1, \theta_2, \cdots, \theta_k\}$. We also define $s_i$ as the number of samples of the $i$-th client. The process is as follows:

(1) In this step, the agent outputs a weight value for each model. The action adopted by the agent at the $t$-th time step is defined in equation 10.

$$W^{t}=\{w_1^t, w_2^t, w_3^t, \ldots, w_k^t\} \tag{10}$$

Here, $w_i^t$ is the weight value output by the agent for the $i$-th model at time step $t$.

(2) We aggregate the global model based on the model weights assigned by the agent; the process is expressed in equation 11.

$$\theta_{global}^{t}=\sum_{i=1}^{k} w_i^t\,\theta_i \tag{11}$$
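
The two steps above amount to a weighted sum of the selected client models. A minimal sketch follows, assuming each client model is a flat list of floats and that the agent's weights sum to 1. The function name `fuse_models` and the example values are illustrative assumptions.

```python
import math
from typing import List, Sequence


def fuse_models(client_params: Sequence[Sequence[float]], weights: Sequence[float]) -> List[float]:
    """Weighted sum of k client parameter vectors using the agent's weights."""
    if len(client_params) != len(weights):
        raise ValueError("one weight is required per client model")
    if not math.isclose(sum(weights), 1.0, abs_tol=1e-9):
        raise ValueError("fusion weights are expected to sum to 1")
    dim = len(client_params[0])
    # Coordinate j of the global model is the weighted sum of coordinate j of each client.
    return [sum(w * p[j] for w, p in zip(weights, client_params)) for j in range(dim)]


# Hypothetical example with two client models and weights 0.25 and 0.75.
theta_global = fuse_models([[1.0, 2.0], [3.0, 4.0]], [0.25, 0.75])
```

Setting every weight to $1/k$ recovers plain averaging, which is the baseline the agent is later rewarded for beating.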

We aim to train the agent so that it can output the optimal fusion weight values based on the quality of each model. To accomplish this goal, we first describe the basic elements of reinforcement learning as follows:

Environment: The external environment is the server-side global model fusion module, which fuses the global model based on the actions output by the agent and then verifies the accuracy of the global model on the reserved dataset on the server side. Finally, the server side feeds back to the agent the corresponding reward and punishment values based on the accuracy of the global model.

State: We define the agent's state to include the number of samples of each client, the accuracy of each client's model, and the accuracy of the global model fused using the weights output by the agent, as shown in equation 12.

$$S^{t}=\{s_1^t, s_2^t, s_3^t, \ldots, s_k^t,\; acc_1^t, acc_2^t, acc_3^t, \ldots, acc_k^t,\; acc_{global}^t\} \tag{12}$$

Here, $acc_{global}^{t}$ is the accuracy of the global model fused using the weights output by the agent.
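
To make the layout of the state concrete, the following sketch concatenates the three groups of entries from equation 12 into one flat vector of length $2k+1$. It assumes raw sample counts and accuracies are passed in as-is; any normalization of the sample counts would be an additional design choice that the text does not specify. The function name `build_state` and the values are hypothetical.

```python
from typing import List, Sequence


def build_state(sample_counts: Sequence[int],
                client_accuracies: Sequence[float],
                global_accuracy: float) -> List[float]:
    """Concatenate sample counts, client accuracies, and global accuracy (equation 12)."""
    if len(sample_counts) != len(client_accuracies):
        raise ValueError("one sample count and one accuracy are needed per client")
    return [float(s) for s in sample_counts] + list(client_accuracies) + [global_accuracy]


# Hypothetical state for k = 2 clients.
state = build_state([100, 250], [0.5, 0.75], 0.625)
```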

Action: In each stage, the agent assigns a weight to each model based on the model's quality. The action space is shown in equation 13, where $a_i^t$ denotes the weight value assigned to the $i$-th client in state $t$, and the weight values of all clients sum to 1.

$$A^{t}=\{a_1^t, a_2^t, \ldots, a_k^t\},\quad \sum_{i=1}^{k} a_i^t=1,\quad a_i^t\in(0,1) \tag{13}$$
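
The text does not state how the constraints in equation 13 are enforced. One common option, shown here purely as an assumption, is to pass the raw outputs of the Actor network through a softmax, which yields values strictly between 0 and 1 that sum to 1. The function name `action_to_weights` is hypothetical.

```python
import math
from typing import List, Sequence


def action_to_weights(raw_action: Sequence[float]) -> List[float]:
    """Map unconstrained Actor outputs onto the simplex with a softmax (an assumed choice)."""
    shift = max(raw_action)  # subtract the maximum for numerical stability
    exps = [math.exp(x - shift) for x in raw_action]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical raw outputs for k = 3 clients.
weights = action_to_weights([1.0, 2.0, 3.0])
```

A larger raw output gives a larger fusion weight, so the ordering chosen by the agent is preserved after normalization.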

Reward: We define $Acc_{all}$ as the accuracy of the model aggregated using the average method, where each model weight is $\frac{1}{N}$. At the $m$-th time step, we define the weight set output by the agent as $W$. We then use this weight set to fuse the global model and define the accuracy of the resulting global model as $Acc_m$. We calculate the reward value from the difference $Acc_m - Acc_{all}$. If the difference is greater than zero, the weights assigned by the agent improve the accuracy of the global model, and we give a positive reward. Otherwise, we give a penalty. $\varphi$ and $\phi$ denote the reward and penalty factors, respectively. The reward is defined in equation 14.

$$Reward=\begin{cases}\varphi\cdot(Acc_m-Acc_{all}), & Acc_m>Acc_{all}\\ \phi\cdot(Acc_m-Acc_{all}), & Acc_m\leq Acc_{all}\end{cases} \tag{14}$$
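
The piecewise reward of equation 14 can be sketched directly. The factor values used in the example are placeholders chosen for illustration, since this section does not give concrete values for $\varphi$ and $\phi$. The function name `fusion_reward` is likewise hypothetical.

```python
def fusion_reward(acc_m: float, acc_all: float,
                  reward_factor: float, penalty_factor: float) -> float:
    """Reward of equation 14: scale the accuracy gain over plain averaging.

    reward_factor corresponds to varphi and penalty_factor to phi.
    """
    diff = acc_m - acc_all
    if diff > 0:
        return reward_factor * diff   # agent weights beat the averaging baseline
    return penalty_factor * diff      # zero or negative: baseline matched or not reached


# Hypothetical factors: varphi = 2.0, phi = 4.0.
gain = fusion_reward(0.75, 0.5, reward_factor=2.0, penalty_factor=4.0)
loss = fusion_reward(0.25, 0.5, reward_factor=2.0, penalty_factor=4.0)
```

Choosing a penalty factor larger than the reward factor, as in this example, would punish a drop in accuracy more strongly than it rewards an equal gain.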

Having defined the basic elements, we implement a distributed reinforcement learning approach based on TD3 [8] to train the agent. The training process is shown in figure 1. This stage includes a central Learner and multiple Worker nodes. Each worker corresponds to a parallel environment. The workflow of each worker is as follows: first, the worker performs global model fusion based on the assigned weights; it then verifies the accuracy of the global model by interacting with its parallel environment; next, it receives the reward value fed back by the parallel environment; finally, it stores the corresponding sample in the sampling buffer pool. The workers interact with their environments independently, which enables parallel sampling and improves sampling efficiency. After the workers have collected a batch of samples, the Learner trains the agent on sample data drawn from the buffer pool.
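
The shared buffer pool described above can be sketched as a simple replay buffer that workers write to and the Learner samples from. This is a single-process illustration under stated assumptions: the fixed capacity, the uniform random sampling, and the names `Transition` and `ReplayBuffer` are not specified in the paper, and a real multi-worker setup would also need synchronization.

```python
import random
from collections import deque
from typing import Deque, List, NamedTuple, Sequence


class Transition(NamedTuple):
    """One sample: state, action, reward, next state."""
    state: Sequence[float]
    action: Sequence[float]
    reward: float
    next_state: Sequence[float]


class ReplayBuffer:
    """Bounded buffer pool; the oldest samples are dropped once capacity is reached."""

    def __init__(self, capacity: int, seed: int = 0) -> None:
        self._items: Deque[Transition] = deque(maxlen=capacity)
        self._rng = random.Random(seed)

    def store(self, transition: Transition) -> None:
        self._items.append(transition)

    def sample(self, batch_size: int) -> List[Transition]:
        # Uniform sampling without replacement from the stored samples.
        return self._rng.sample(list(self._items), batch_size)

    def __len__(self) -> int:
        return len(self._items)


# Five hypothetical workers each store one sample; the Learner then draws a batch.
buffer = ReplayBuffer(capacity=4)
for worker_id in range(5):
    buffer.store(Transition([float(worker_id)], [1.0], 0.0, [float(worker_id) + 1.0]))
batch = buffer.sample(2)
```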

The TD3 algorithm consists of six network models: an Actor network $P(w)$, two Critic networks $Q_1(\theta_1)$ and $Q_2(\theta_2)$, a target Actor network $P'(w')$, and two target Critic networks $Q'_1(\theta'_1)$ and $Q'_2(\theta'_2)$. Each network is shown in figure 1. At regular intervals, the Learner randomly draws $N$ batches of sample data from the buffer pool to train the model. The training process is as follows.

(1) First, select the action $a_{t+1}$ using the target Actor network $P'(s_{t+1})$. The state $s_{t+1}$ and action $a_{t+1}$ are input to the target Critic networks $Q'_1(\theta'_1)$ and $Q'_2(\theta'_2)$, which calculate the predicted rewards $q_1$ and $q_2$, respectively.

(2) The TD target value is calculated using equation 15, where $\mathrm{Min}(q_1, q_2)$ takes the smaller of the two values.

$$y_t \leftarrow r+\gamma\,\mathrm{Min}(q_1, q_2) \tag{15}$$

(3) Select the action using the Actor network, input the state and action into each of the two Critic networks, and let the two networks output the corresponding predicted rewards $q_{1,t}$ and $q_{2,t}$.

(4) Calculate the TD errors using equation 16.

$$\delta_{1,t}=q_{1,t}-y_t,\quad \delta_{2,t}=q_{2,t}-y_t \tag{16}$$

(5) Update the two Critic networks using equation 17.

$$\begin{array}{l}\theta_1 \leftarrow \theta_1-\alpha\cdot\delta_{1,t}\cdot\nabla_{\theta_1}Q_1(s_t,a_t;\theta_1)\\ \theta_2 \leftarrow \theta_2-\alpha\cdot\delta_{2,t}\cdot\nabla_{\theta_2}Q_2(s_t,a_t;\theta_2)\end{array} \tag{17}$$

(6) Every $d$ rounds, update the policy (Actor) network based on the action output by the Actor network, using equation 18.

$$w \leftarrow w+\beta\cdot\nabla_{w}P(s_t; w)\cdot\nabla_{w}Q_1(s_t,a_t;\theta_1) \tag{18}$$

(7) Update the target Actor and Critic network parameters every d rounds as equation 19.

$$\begin{array}{l}w' \leftarrow \tau w+(1-\tau)w'\\ \theta'_1 \leftarrow \tau\theta_1+(1-\tau)\theta'_1\\ \theta'_2 \leftarrow \tau\theta_2+(1-\tau)\theta'_2\end{array} \tag{19}$$
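
The scalar parts of the steps above (the TD target of equation 15, the TD errors of equation 16, the delayed update schedule, and the soft target update of equation 19) can be sketched without any neural network library. This is a minimal illustration, assuming the Critic outputs are already available as floats and the parameters are flat lists. All function names and numeric values are hypothetical.

```python
from typing import List, Sequence, Tuple


def td_target(reward: float, q1_next: float, q2_next: float, gamma: float) -> float:
    """Equation 15: bootstrap from the smaller of the two target Critic outputs."""
    return reward + gamma * min(q1_next, q2_next)


def td_errors(q1: float, q2: float, y: float) -> Tuple[float, float]:
    """Equation 16: one TD error per Critic network."""
    return q1 - y, q2 - y


def is_delayed_update_round(round_index: int, d: int) -> bool:
    """Actor and target networks are only updated every d rounds."""
    return round_index % d == 0


def soft_update(target: Sequence[float], online: Sequence[float], tau: float) -> List[float]:
    """Equation 19: move target parameters a fraction tau towards the online ones."""
    return [tau * o + (1.0 - tau) * t for o, t in zip(online, target)]


# Hypothetical values for a single sampled transition.
y = td_target(reward=1.0, q1_next=2.0, q2_next=4.0, gamma=0.5)
delta_1, delta_2 = td_errors(q1=2.5, q2=1.5, y=y)
new_target = soft_update(target=[0.0, 0.0], online=[1.0, 2.0], tau=0.25)
```

Taking the minimum of the two target Critic outputs is what limits overestimation of the value, and a small $\tau$ keeps the target networks changing slowly between updates.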

By repeating the above steps for the specified number of rounds, we obtain the trained agent. Finally, we output the optimal weight value of each model through the agent. The process is shown in Algorithm 2.

Algorithm 2 The process of model weight calculation
1: Client Models {θ1t,θ2t,θ3t,θnt}superscriptsubscript𝜃1𝑡superscriptsubscript𝜃2𝑡superscriptsubscript𝜃3𝑡superscriptsubscript𝜃n𝑡\left\{\theta_{1}^{t},\theta_{2}^{t},\theta_{3}^{t},\ldots\theta_{\mathrm{n}}^% {t}\right\}{ italic_θ start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_t end_POSTSUPERSCRIPT , italic_θ start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_t end_POSTSUPERSCRIPT , italic_θ start_POSTSUBSCRIPT 3 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_t end_POSTSUPERSCRIPT , … italic_θ start_POSTSUBSCRIPT roman_n end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_t end_POSTSUPERSCRIPT }, Round R, Worker Number N, Buffer Memory Pool M Initialize Learner Parameters: Actor Parameter P(w)𝑤\left(w\right)( italic_w ), Critic Network Q1(θ1),Q2(θ2)subscript𝑄1subscript𝜃1subscript𝑄2subscript𝜃2Q_{1}\left(\theta_{1}\right),Q_{2}\left(\theta_{2}\right)italic_Q start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT ( italic_θ start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT ) , italic_Q start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT ( italic_θ start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT ) Target Actor Parameter P(w)superscript𝑃superscript𝑤P^{\prime}\left(w^{\prime}\right)italic_P start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_w start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ), Target Critic Network Q1(θ1),Q2(θ2)superscriptsubscript𝑄1superscriptsubscript𝜃1superscriptsubscript𝑄2superscriptsubscript𝜃2Q_{1}^{\prime}\left(\theta_{1}^{\prime}\right),Q_{2}^{\prime}\left(\theta_{2}^% {\prime}\right)italic_Q start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_θ start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ) , italic_Q start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_θ start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ) 
ww,θ1θ1,θ2θ2formulae-sequencesuperscript𝑤𝑤formulae-sequencesuperscriptsubscript𝜃1subscript𝜃1superscriptsubscript𝜃2subscript𝜃2w^{\prime}\leftarrow w,\theta_{1}^{\prime}\leftarrow\theta_{1},\theta_{2}^{% \prime}\leftarrow\theta_{2}italic_w start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ← italic_w , italic_θ start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ← italic_θ start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT , italic_θ start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ← italic_θ start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT
2: Optimized Client Model Weight W={w1t,w2t,wkt}𝑊superscriptsubscript𝑤1𝑡superscriptsubscript𝑤2𝑡superscriptsubscript𝑤k𝑡W=\left\{w_{1}^{t},w_{2}^{t},\ldots w_{\mathrm{k}}^{t}\right\}italic_W = { italic_w start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_t end_POSTSUPERSCRIPT , italic_w start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_t end_POSTSUPERSCRIPT , … italic_w start_POSTSUBSCRIPT roman_k end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_t end_POSTSUPERSCRIPT }
3:/* Each Worker Sampling Step */
4:for worker𝑤𝑜𝑟𝑘𝑒𝑟workeritalic_w italic_o italic_r italic_k italic_e italic_r from 1 to N𝑁Nitalic_N do
5:     atP(st,w)subscript𝑎𝑡𝑃subscript𝑠𝑡𝑤a_{t}\leftarrow P\left(s_{t},w\right)italic_a start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT ← italic_P ( italic_s start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT , italic_w ) // Randomly Choose an Act from P(st,w)𝑃subscript𝑠𝑡𝑤P\left(s_{t},w\right)italic_P ( italic_s start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT , italic_w )
6:     {w1t,w2t,w3twkt}Step(at)superscriptsubscript𝑤1𝑡superscriptsubscript𝑤2𝑡superscriptsubscript𝑤3𝑡superscriptsubscript𝑤𝑘𝑡𝑆𝑡𝑒𝑝subscript𝑎𝑡\left\{w_{1}^{t},w_{2}^{t},w_{3}^{t}\ldots w_{k}^{t}\right\}\leftarrow Step% \left(a_{t}\right){ italic_w start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_t end_POSTSUPERSCRIPT , italic_w start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_t end_POSTSUPERSCRIPT , italic_w start_POSTSUBSCRIPT 3 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_t end_POSTSUPERSCRIPT … italic_w start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_t end_POSTSUPERSCRIPT } ← italic_S italic_t italic_e italic_p ( italic_a start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT )
7:     θglobaltAgg(i=1kwitθi)superscriptsubscript𝜃𝑔𝑙𝑜𝑏𝑎𝑙𝑡𝐴𝑔𝑔superscriptsubscript𝑖1𝑘superscriptsubscript𝑤𝑖𝑡subscript𝜃𝑖\theta_{global}^{t}\leftarrow Agg\left(\sum_{i=1}^{k}w_{i}^{t}\theta_{i}\right)italic_θ start_POSTSUBSCRIPT italic_g italic_l italic_o italic_b italic_a italic_l end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_t end_POSTSUPERSCRIPT ← italic_A italic_g italic_g ( ∑ start_POSTSUBSCRIPT italic_i = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_k end_POSTSUPERSCRIPT italic_w start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_t end_POSTSUPERSCRIPT italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT )
8:     RtCaculateReward(ACCtACCavg)subscript𝑅𝑡𝐶𝑎𝑐𝑢𝑙𝑎𝑡𝑒𝑅𝑒𝑤𝑎𝑟𝑑𝐴𝐶subscript𝐶𝑡𝐴𝐶subscript𝐶𝑎𝑣𝑔R_{t}\leftarrow CaculateReward\left(ACC_{t}-ACC_{avg}\right)italic_R start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT ← italic_C italic_a italic_c italic_u italic_l italic_a italic_t italic_e italic_R italic_e italic_w italic_a italic_r italic_d ( italic_A italic_C italic_C start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT - italic_A italic_C italic_C start_POSTSUBSCRIPT italic_a italic_v italic_g end_POSTSUBSCRIPT )
9:     MStore(<St,At,Rt,St+1>)M\leftarrow Store\left(<S_{t},A_{t},R_{t},S_{t+1}>\right)italic_M ← italic_S italic_t italic_o italic_r italic_e ( < italic_S start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT , italic_A start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT , italic_R start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT , italic_S start_POSTSUBSCRIPT italic_t + 1 end_POSTSUBSCRIPT > )
10:/* Center Learner Training Step */
11:for $r$ from 1 to $R$ do
12:     Randomly sample N batches of data from M
13:     $a_{t+1}' \leftarrow P'(s_{t+1})$
14:     $y \leftarrow r + \gamma \min\left(Q_1'(s_{t+1}, a_{t+1}'), Q_2'(s_{t+1}, a_{t+1}')\right)$
15:     Update Critic Network $\theta_1 \leftarrow \arg\min_{\theta_1} \frac{1}{N} \sum \left(y - Q_{\theta_1}(s, a)\right)^2$
16:     Update Critic Network $\theta_2 \leftarrow \arg\min_{\theta_2} \frac{1}{N} \sum \left(y - Q_{\theta_2}(s, a)\right)^2$
17:     Every d Rounds:
18:     Update Actor-Network: $\nabla_w J(w) = N^{-1} \sum \nabla_w Q_{\theta_1}(s, a) \mid_{a = P(s)} \nabla_w P(s)$
19:     Update Target Critic Network: $\theta_1' \leftarrow \tau \theta_1 + (1 - \tau) \theta_1', \quad \theta_2' \leftarrow \tau \theta_2 + (1 - \tau) \theta_2'$
20:     Update Target Actor-Network: $w' \leftarrow \tau w + (1 - \tau) w'$
21:After R Rounds, Save Trained Model
22:Output Optimized Model Weight $W = \{w_1^t, w_2^t, \ldots, w_k^t\}$
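
The target computation and the soft target updates in the learner step of Algorithm 2 can be illustrated with a minimal pure-Python sketch. The function names `td3_target` and `soft_update`, the default `gamma`, and the toy stand-ins for the target networks are assumptions made for illustration; they are not taken from the authors' implementation.

```python
from typing import Callable, List, Sequence

Vector = Sequence[float]


def td3_target(reward: float,
               next_state: Vector,
               target_actor: Callable[[Vector], Vector],
               target_q1: Callable[[Vector, Vector], float],
               target_q2: Callable[[Vector, Vector], float],
               gamma: float = 0.99) -> float:
    """Clipped double-Q target: y = r + gamma * min(Q1'(s', a'), Q2'(s', a'))."""
    next_action = target_actor(next_state)  # a' = P'(s_{t+1})
    q_min = min(target_q1(next_state, next_action),
                target_q2(next_state, next_action))
    return reward + gamma * q_min


def soft_update(online: List[float], target: List[float], tau: float) -> List[float]:
    """Soft target update: target <- tau * online + (1 - tau) * target."""
    return [tau * o + (1.0 - tau) * t for o, t in zip(online, target)]


if __name__ == "__main__":
    # Toy stand-ins for the target actor and the two target critics (hypothetical).
    actor = lambda s: [0.5 for _ in s]
    q1 = lambda s, a: 2.0
    q2 = lambda s, a: 1.0
    print(td3_target(1.0, [0.0, 0.0], actor, q1, q2, gamma=0.9))
    print(soft_update([1.0, 1.0], [0.0, 0.0], tau=0.1))
```

Taking the minimum of the two target critics is what limits overestimation of the value, and a small `tau` makes the target networks track the online networks slowly.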

4 SYSTEM DESIGN

To establish a reliable federated learning process, we developed FedDRL, a framework for trustworthy federated learning based on staged reinforcement learning. In the first stage, we train an agent to select the trustworthy clients that participate in global model fusion. In the second stage, we train a second agent to dynamically adjust the fusion weight of each model and thereby realize the optimal global model fusion. The framework workflow consists of six steps, as shown in Figure 2.

Step 1 (Local Model Training): Each client downloads the global model, initializes its parameters accordingly, and conducts model training using local private data.

Step 2 (Upload Model): After local model training, each client uploads its model parameters to the server.

Step 3 (Select Trustworthy Clients): Upon receiving client model parameters, the server employs the $SelectTrustClient(.)$ algorithm to train an agent. Subsequently, the trained agent selects trustworthy clients.

Figure 2: The system architecture of FedDRL

Step 4 (Assigning Model Weights): The server uses the models from the trustworthy clients for global model fusion. It employs the $AdaptCalculateWeight(.)$ algorithm to train an agent, which optimizes the weight assigned to each client model.

Step 5 (Fusing Global Model): The server fuses the global model using the calculated weights from the previous step.

Step 6 (Distribute Global Model): The server disseminates the global model to all clients, initiating the subsequent federation task.

The federation task is set to execute a specified number of communication rounds until the final global model is obtained. This process is shown in Algorithm 3.

Algorithm 3 The FedDRL framework
1: Private Dataset $\{D_1, D_2, \ldots, D_n\}$, communication round $E$
2: The Global model $\{M_{global}\}$
3:/* Client Process */
4:for $C_i$ from 1 to $N$ do
5:     $M_i \leftarrow \operatorname{GetGlobalModel}(round = i)$ // Get the global model and init client model
6:     $M_i \leftarrow \operatorname{TrainLocalModel}(D_i)$ // Train model $M_i$ based on dataset $\{D_i\}$
7:     $\text{Server} \leftarrow \operatorname{Send}(M_i)$
8:/* Server Process */
9:for $e$ from 1 to $E$ do
10:     Store $(\{M_1, M_2, \ldots, M_n\}) \leftarrow \operatorname{Receive}(M_i)$ // Receive Client Model /* FedDRL Algorithm Process */
11:     Train the Stage 1 Agent
12:     Update the $\operatorname{SelectTrustClient}(.)$ Algorithm parameters // According to Algorithm 1
13:     $\{M_a, M_b, \ldots, M_k\} \leftarrow \operatorname{SelectTrustClient}(\{M_1, M_2, \ldots, M_n\})$
14:     Train the Stage 2 Agent
15:     Update the $\operatorname{AdaptCalculateWeight}(.)$ Algorithm parameters // According to Algorithm 2
16:     $\{W_1, W_2, \ldots, W_n\} \leftarrow \operatorname{AdaptCalculateWeight}(\{M_a, M_b, \ldots, M_k\})$
17:     $M_{global} \leftarrow \operatorname{FusionGlobalModel}(\{W_1, W_2, \ldots, W_n\})$
18:     $C_i \leftarrow \operatorname{SendGlobalModel}(M_{global})$
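
The server-side part of one communication round (select, weight, fuse) can be sketched in Python as follows. This is only an illustration under strong assumptions: the two trained agents are replaced by simple hypothetical rules (a fixed accuracy threshold for selection and accuracy-proportional weights), and the function names, the threshold of 0.2, and the toy parameter vectors are invented for the example.

```python
from typing import Dict, List


def select_trust_clients(accuracies: Dict[str, float], threshold: float = 0.2) -> List[str]:
    """Hypothetical stand-in for SelectTrustClient(.): keep clients whose
    accuracy exceeds a fixed threshold (the paper trains an agent instead)."""
    return [cid for cid, acc in accuracies.items() if acc > threshold]


def adapt_calculate_weight(accuracies: Dict[str, float], selected: List[str]) -> Dict[str, float]:
    """Hypothetical stand-in for AdaptCalculateWeight(.): weights proportional
    to accuracy and normalised to sum to 1 (the paper trains an agent instead)."""
    if not selected:
        raise ValueError("no trustworthy clients selected")
    total = sum(accuracies[cid] for cid in selected)
    return {cid: accuracies[cid] / total for cid in selected}


def fusion_global_model(models: Dict[str, List[float]], weights: Dict[str, float]) -> List[float]:
    """Weighted average of the selected clients' parameter vectors."""
    dim = len(next(iter(models.values())))
    fused = [0.0] * dim
    for cid, w in weights.items():
        for j, p in enumerate(models[cid]):
            fused[j] += w * p
    return fused


if __name__ == "__main__":
    # Toy parameter vectors and accuracies; "c1" plays the malicious client.
    models = {"c1": [9.0, 9.0], "c2": [1.0, 2.0], "c3": [3.0, 4.0]}
    accs = {"c1": 0.10, "c2": 0.60, "c3": 0.40}
    selected = select_trust_clients(accs)
    weights = adapt_calculate_weight(accs, selected)
    print(selected, weights, fusion_global_model(models, weights))
```

Because the excluded client never receives a weight, its parameters cannot influence the fused model, which is the purpose of running the selection stage before the weighting stage.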

5 EXPERIMENT

5.1 Experiment setup

5.1.1 Experiment datasets

We evaluated the FedDRL framework using three distinct image classification datasets:

Fashion-MNIST: This dataset includes 60,000 training samples and 10,000 test samples, each a 28x28 grayscale image, classified into one of 10 categories.

CIFAR-10: The CIFAR-10 dataset comprises 60,000 32x32 colour images, evenly distributed across ten classes, each containing 6,000 images.

CIFAR-100: Similar in size to CIFAR-10 but with a broader spectrum, CIFAR-100 features 100 classes with 600 images each, totalling 60,000 colour images.

Data Set Partitioning: To simulate a non-IID data distribution among clients, we used the Dirichlet distribution to partition each open-source dataset across the clients. The degree of skew is controlled by the alpha parameter. As alpha approaches zero, each client's data distribution is skewed towards a few specific classes. Conversely, as alpha increases towards infinity, each client's class distribution approaches that of the full dataset (the IID case). We set alpha to 1 and divided each of the three datasets among ten clients. The resulting data distribution is illustrated in Figure 3, where different categories are represented by distinct colours and the length of each segment reflects the sample count within that category.
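
A Dirichlet-based partition of this kind can be sketched with the Python standard library alone. The function name `dirichlet_partition`, the class-by-class splitting strategy, and the toy labels are assumptions for illustration; the paper does not describe the exact partitioning code.

```python
import random
from typing import Dict, List


def dirichlet_partition(labels: List[int], num_clients: int, alpha: float,
                        seed: int = 0) -> Dict[int, List[int]]:
    """Split sample indices among clients, class by class, using
    Dirichlet(alpha) proportions. Smaller alpha gives more skewed clients."""
    rng = random.Random(seed)
    by_class: Dict[int, List[int]] = {}
    for idx, y in enumerate(labels):
        by_class.setdefault(y, []).append(idx)

    clients: Dict[int, List[int]] = {c: [] for c in range(num_clients)}
    for _, idxs in sorted(by_class.items()):
        rng.shuffle(idxs)
        # A Dirichlet sample is a vector of Gamma(alpha, 1) draws, normalised.
        draws = [rng.gammavariate(alpha, 1.0) for _ in range(num_clients)]
        total = sum(draws)
        start, cumulative = 0, 0.0
        for c in range(num_clients):
            cumulative += draws[c] / total
            # The last client takes the remainder so every index is assigned.
            end = len(idxs) if c == num_clients - 1 else int(round(cumulative * len(idxs)))
            clients[c].extend(idxs[start:end])
            start = end
    return clients


if __name__ == "__main__":
    toy_labels = [i % 10 for i in range(1000)]  # 10 balanced toy classes
    parts = dirichlet_partition(toy_labels, num_clients=10, alpha=1.0)
    print({c: len(ix) for c, ix in parts.items()})
```

Rerunning the sketch with a much smaller `alpha` (for example 0.1) concentrates each class on a few clients, while a large `alpha` spreads every class almost evenly.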

Figure 3: The non-IID distribution of 10 clients (alpha=1). Panels: (a) Fashion-MNIST; (b) CIFAR-10; (c) CIFAR-100.

5.1.2 Comparison of Methods

We contrasted the FedDRL algorithm with two established federated learning approaches.

FedAvg[2]: Serving as the foundational benchmark in federated learning, the FedAvg method determines the weight of each client model based on the proportion of samples contributed by the client relative to the aggregate sample size.

FedProx[4]: Enhancing the FedAvg approach, FedProx incorporates a regularization term within the client model, thereby refining federated learning performance.
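
The two baselines can be summarised in a short Python sketch. It assumes the usual formulations: FedAvg weights each client by its share of the total sample count, and FedProx adds a proximal term (mu / 2) * ||w - w_global||^2 to the local loss. The function names, the coefficient name `mu`, and the toy numbers are illustrative assumptions.

```python
from typing import List


def fedavg_weights(sample_counts: List[int]) -> List[float]:
    """FedAvg fusion weight of client i: n_i / sum_j n_j."""
    total = sum(sample_counts)
    return [n / total for n in sample_counts]


def fedprox_local_loss(task_loss: float, local_params: List[float],
                       global_params: List[float], mu: float) -> float:
    """FedProx local objective: task loss + (mu / 2) * ||w - w_global||^2."""
    sq_dist = sum((w - g) ** 2 for w, g in zip(local_params, global_params))
    return task_loss + 0.5 * mu * sq_dist


if __name__ == "__main__":
    # Hypothetical sample counts for three clients.
    print(fedavg_weights([100, 300, 600]))
    # Hypothetical loss value and parameter vectors.
    print(fedprox_local_loss(1.0, [1.0, 2.0], [0.0, 0.0], mu=0.1))
```

Note that the FedAvg weight depends only on the sample count, so a client with many samples but a poor or malicious model still receives a large weight.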

5.1.3 Experimental Metrics

We employed accuracy as the metric to gauge the performance of the global model in multi-classification tasks and across individual clients. Assuming $n$ clients engage in model fusion with $m$ communication rounds, the accuracy of the global model in the $t$-th round is denoted as $A_{\text{global}}^{t}$. The collective global model accuracies across all rounds are represented as follows:

$$A_{\text{global}} = \{A_{\text{global}}^{1}, A_{\text{global}}^{2}, \ldots, A_{\text{global}}^{m}\} \tag{20}$$

We denote the accuracy of the $c$-th client's model as $A_c$. Additionally, we document each client model's accuracy per round, compiling these as follows:

$$A_{c} = \{A_{c}^{1}, A_{c}^{2}, \ldots, A_{c}^{m}\}, \quad c \in [1, 2, \ldots, n] \tag{21}$$

5.1.4 Experimental Configuration

Hardware Configuration: The experiments were conducted on a workstation equipped with an Intel i9-12900k CPU, 64GB RAM, and an NVIDIA RTX3090 GPU.

Software Configuration: We utilized two distinct frameworks for the Federated Learning and Reinforcement Learning experiments. Federated Learning trials were carried out using FedBolt, our custom-built framework, which enables simulation of varied client numbers and data distributions. For reinforcement learning model training, we employed the Stable-Baselines3 framework, designing two distinct algorithms for trusted client selection and model weight assignment.

Network Setup: We implemented different network architectures tailored to each dataset. For CIFAR-10 and CIFAR-100, a 6-layer CNN was utilized for model training. Conversely, a 4-layer MLP was developed for the Fashion-MNIST dataset.

Agent Network Setup: Implementing a staged reinforcement learning strategy necessitated the training of two distinct agents. The initial phase, adhering to Section 3.2.1, utilizes the A2C algorithm, with each worker and the central node comprising a 6-layer MLP actor and critic network. For the second phase, the TD3 algorithm outlined in Section 3.2.2 was employed for agent training, where each module within the TD3 setup incorporates a 6-layer MLP, with further details available in Section 3.2.2.

5.2 Experimental Results

We evaluate the FedDRL framework through four experiments: client attack scenarios, low-quality model fusion, hybrid scenarios, and multi-agent training efficiency. The client attack experiments assess the efficacy of the trustworthy client selection algorithm (stage 1). The low-quality model fusion experiment examines the adaptive weight calculation method (stage 2). The hybrid experiments, combining client attacks and low-quality model elements, validate the comprehensive performance of FedDRL. The final experiment focuses on the training efficiency of multi-agents.

5.2.1 Malicious Client Attack Experiment

In this experiment, we define three types of client-side attacks in federated learning to evaluate our FedDRL framework under adversarial conditions. The experiment spans different client numbers and attack types across three datasets, detailed in Table 1.

Type 1: The client directly uploads the initialized model or makes the model accuracy less than 10% by modifying the model’s hyperparameters.

Type 2: We use falsified data to perform the attack. A certain percentage of forged data participates in model training (e.g., we mix the CIFAR-10 dataset with 80% of CIFAR-100 data and relabel these CIFAR-100 samples with CIFAR-10 label types). Training the client's local model on the forged samples reduces the client model's accuracy.

Type 3: We select some clients to simulate the attack and divide the training process of these clients into standard rounds and attack rounds. In a standard round, the client does not attack. In an attack round, the client deliberately uploads a prepared malicious model. When several such clients are present, they initiate attacks alternately.

Table 1: Experimental setup for malicious attack scenarios

| Number | Attack Type | Malicious ID | Number of samples | Accuracy of models |
|---|---|---|---|---|
| Clients=5 | Type 1 | Client1 | 7750 | $A \leq 10\%$ |
| Clients=5 | Type 2 | Client1 | 7750 | $10\% \leq A \leq 20\%$ |
| Clients=5 | Type 3 | Client1 | 7750 | Attack round $A \leq 10\%$ |
| Clients=10 | Type 1 | Client1, Client6 | 4222, 4938 | $A \leq 10\%$ |
| Clients=10 | Type 2 | Client1, Client6 | 4222, 4938 | $10\% \leq A \leq 20\%$ |
| Clients=10 | Type 3 | Client1, Client6 | 4222, 4938 | Attack round $A \leq 10\%$ |
| Clients=15 | Type 1 | Client1, Client6, Client11 | 3670, 3314, 4454 | $A \leq 10\%$ |
| Clients=15 | Type 2 | Client1, Client6, Client11 | 3670, 3314, 4453 | $10\% \leq A \leq 20\%$ |
| Clients=15 | Type 3 | Client1, Client6, Client11 | 3670, 3314, 4453 | Attack round $A \leq 15\%$ |

Following this experimental setup, we compared the FedDRL algorithm with FedAvg and FedProx. In the attack experiments, we set the total number of communication rounds to 100, and each client performs one epoch of local model training per round. To show the attack behavior of each client and the accuracy of the different algorithms in more detail, we recorded the accuracy of each client's local model and of the server-side global model in each communication round. The experimental results are shown in Table 2.

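Table 2: Accuracy of each algorithm under different malicious attack scenarios

| Dataset | Method | Clients=5 Type 1 | Clients=5 Type 2 | Clients=5 Type 3 | Clients=10 Type 1 | Clients=10 Type 2 | Clients=10 Type 3 | Clients=15 Type 1 | Clients=15 Type 2 | Clients=15 Type 3 |
|---|---|---|---|---|---|---|---|---|---|---|
| Fashion-MNIST | FedAvg | 0.862 | 0.875 | 0.863 | 0.792 | 0.881 | 0.821 | 0.776 | 0.878 | 0.812 |
| Fashion-MNIST | FedProx | 0.873 | 0.884 | 0.864 | 0.791 | 0.882 | 0.824 | 0.764 | 0.879 | 0.811 |
| Fashion-MNIST | Ours | 0.885 | 0.878 | 0.883 | 0.877 | 0.886 | 0.887 | 0.881 | 0.886 | 0.882 |
| CIFAR-10 | FedAvg | 0.596 | 0.691 | 0.751 | 0.335 | 0.681 | 0.314 | 0.139 | 0.664 | 0.197 |
| CIFAR-10 | FedProx | 0.586 | 0.732 | 0.775 | 0.331 | 0.716 | 0.363 | 0.122 | 0.719 | 0.202 |
| CIFAR-10 | Ours | 0.731 | 0.701 | 0.747 | 0.694 | 0.711 | 0.727 | 0.679 | 0.702 | 0.689 |
| CIFAR-100 | FedAvg | 0.376 | 0.412 | 0.298 | 0.273 | 0.416 | 0.287 | 0.162 | 0.398 | 0.176 |
| CIFAR-100 | FedProx | 0.398 | 0.442 | 0.321 | 0.223 | 0.436 | 0.208 | 0.172 | 0.426 | 0.183 |
| CIFAR-100 | Ours | 0.412 | 0.438 | 0.431 | 0.426 | 0.432 | 0.421 | 0.423 | 0.412 | 0.422 |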

To show the effect of the FedDRL algorithm on global model fusion at each communication round, we conducted experiments using the CIFAR-10 dataset on 5, 10, and 15 clients. We compared FedDRL with the FedAvg and FedProx algorithms for global model accuracy.

We analyze the experimental results for different numbers of client models and different client data. In attack type 1, our algorithm outperforms both FedAvg and FedProx in every setting reported in Table 2. In attack type 2 (malicious data), our algorithm outperforms FedAvg and slightly underperforms FedProx on the CIFAR datasets. In the attack type 3 scenario, our algorithm outperforms the comparison algorithms in most cases, especially when multiple malicious clients are involved in model fusion.

Figure 4: The accuracy of the global model for different numbers of clients in attack type 1. Panels: (a) 10 clients, FedAvg; (b) 10 clients, FedProx; (c) 10 clients, FedDRL; (d) 15 clients, FedAvg; (e) 15 clients, FedProx; (f) 15 clients, FedDRL.

To show the relationship between the global model's accuracy and the client models' accuracy in each attack scenario, we conducted more detailed experiments on the CIFAR-10 dataset.

In attack type 1, the global model accuracy under FedAvg and FedProx plummets as the number of malicious clients increases, dropping below 40% and 20% in the 10-client and 15-client setups, respectively. In contrast, our trained agent dynamically selects trusted clients for model fusion and prevents malicious models from participating, so our algorithm is more reliable. The experimental results are shown in Figure 4.

In attack type 2, our algorithm is better than FedAvg but lower than FedProx. FedProx uses a control parameter to force each client's model to converge towards the global model, which improves the malicious model's accuracy to some extent and thereby the global model's accuracy. Our trained agent, by contrast, needs several communication rounds before it filters the low-accuracy models out of the fusion. The experimental results are shown in Figure 5.

Figure 5: The accuracy of the global model for different numbers of clients in attack type 2. Panels: (a) 10 clients, FedAvg; (b) 10 clients, FedProx; (c) 10 clients, FedDRL; (d) 15 clients, FedAvg; (e) 15 clients, FedProx; (f) 15 clients, FedDRL.

Figure 6: The accuracy of the global model for different numbers of clients in attack type 3. Panels: (a) 10 clients, FedAvg; (b) 10 clients, FedProx; (c) 10 clients, FedDRL; (d) 15 clients, FedAvg; (e) 15 clients, FedProx; (f) 15 clients, FedDRL.

In attack type 3 scenarios, the FedAvg and FedProx algorithms experience significant fluctuations in global model accuracy due to alternating attack behaviors by malicious clients. Conversely, the agent within the FedDRL framework adaptively selects trusted clients, effectively excluding malicious entities from participating in model fusion, thereby enabling the FedDRL algorithm to operate with stability. The experimental results are shown in Figure 6.

5.2.2 Low-quality Model Fusion Experiments

In evaluating our FedDRL framework, we undertook validation using the Fashion-MNIST, CIFAR-10, and CIFAR-100 datasets. Given their open-source nature, these datasets are of high quality, leading to minimal variance in model accuracy among clients utilizing them directly. Thus, to simulate real-world conditions, we incorporated low-quality models into the global fusion process. We established a model accuracy threshold, ensuring that models uploaded by low-quality clients did not exceed this threshold in any communication round.

Experiments were carried out on the three datasets, with client groups of varying sizes—5, 10, and 15—participating in the global model fusion. We applied a Dirichlet distribution with parameter alpha=1 to achieve dataset segmentation among clients. We set some clients to upload low-quality models; after several communication rounds, we controlled these client models’ accuracy in global fusion, ensuring it remained within the 40% to 55% range.

Details of these low-quality model experiment configurations are specified in Table 3. The FedDRL algorithm was compared against the FedAvg and FedProx methods across 100 communication rounds, with each client executing one epoch of local model training. Results are summarized in Table 4.

Table 3: Experimental settings for low-quality model experiments

| Number | Dataset | Low-quality Model ID | Number of samples | Accuracy of models (≤) |
|---|---|---|---|---|
| Clients=5 | Fashion-MNIST | Client1 | 9061 | 53% |
| Clients=5 | CIFAR-10 | Client1 | 7750 | 52% |
| Clients=5 | CIFAR-100 | Client1 | 9278 | 22% |
| Clients=10 | Fashion-MNIST | Client1, Client5 | 5071, 7245 | 51%, 52% |
| Clients=10 | CIFAR-10 | Client1, Client5 | 4222, 6039 | 50%, 54% |
| Clients=10 | CIFAR-100 | Client1, Client5 | 4191, 5491 | 22% |
| Clients=15 | Fashion-MNIST | Client1, Client5, Client10 | 4405, 3752, 1809 | 52%, 51%, 53% |
| Clients=15 | CIFAR-10 | Client1, Client5, Client10 | 3670, 3128, 1509 | 49%, 52%, 55% |
| Clients=15 | CIFAR-100 | Client1, Client5, Client10 | 3073, 3494, 2910 | 22%, 19%, 23% |

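Table 4: Accuracy of each algorithm for low-quality modeling experiments

| Method | Fashion-MNIST C=5 | Fashion-MNIST C=10 | Fashion-MNIST C=15 | CIFAR-10 C=5 | CIFAR-10 C=10 | CIFAR-10 C=15 | CIFAR-100 C=5 | CIFAR-100 C=10 | CIFAR-100 C=15 |
|---|---|---|---|---|---|---|---|---|---|
| FedAvg | 0.857 | 0.858 | 0.841 | 0.705 | 0.664 | 0.602 | 0.386 | 0.373 | 0.365 |
| FedProx | 0.865 | 0.861 | 0.829 | 0.714 | 0.652 | 0.607 | 0.402 | 0.391 | 0.386 |
| Ours | 0.885 | 0.887 | 0.884 | 0.725 | 0.706 | 0.698 | 0.422 | 0.418 | 0.407 |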

Employing the CIFAR-10 dataset for illustrative purposes, we performed comparative analyses for setups with 10 and 15 clients, respectively; the findings are depicted in Figure 7. The experiments indicate that the accuracy of the FedAvg and FedProx methods deteriorates as the prevalence of low-quality models increases. This decline can be attributed to these algorithms’ reliance on sample count for determining the fusion weight values of the models, where the inclusion of low-quality models adversely impacts the global model’s accuracy. Conversely, FedDRL surpasses both methodologies in terms of global model convergence speed and accuracy. This is because FedDRL adaptively recalibrates the weights assigned to each client’s model based on quality, thereby diminishing the adverse effects of low-quality models on the global model’s accuracy and consequently hastening the global model’s convergence rate.

Figure 7: The accuracy of the global model for different numbers of clients in the low-quality scenario. Panels: (a) 10 clients, FedAvg; (b) 10 clients, FedProx; (c) 10 clients, FedDRL; (d) 15 clients, FedAvg; (e) 15 clients, FedProx; (f) 15 clients, FedDRL.

5.2.3 Hybrid experiment

In this section, we establish a hybrid scenario incorporating two types of attacking clients (type 1 and type 3) alongside clients submitting low-quality models. We assess the effectiveness of the FedDRL algorithm within this mixed scenario and benchmark it against the FedAvg and FedProx approaches.

Employing the CIFAR-10 dataset, we ran two configurations, with 10 and 15 clients participating in global model fusion. Client 1 uploads only the initial model at each round. Client 6 emulates the submission of low-quality models for fusion. Client 10 (in the 10-client configuration) or Client 11 (in the 15-client configuration) attacks during odd communication rounds but participates normally during even rounds. The remaining clients contribute normally to each round of the federated learning task. The experimental setup specifics are given in Table 5.
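
The per-round behaviour of each client in this hybrid scenario can be written down as a small Python sketch. The function name `hybrid_client_behaviour` and the string labels it returns are hypothetical; only the client roles and the odd/even round rule come from the description above and Table 5.

```python
def hybrid_client_behaviour(client_id: int, round_idx: int, num_clients: int) -> str:
    """Return what a client uploads in a given communication round (1-based)
    under the hybrid scenario."""
    # Table 5: the type 3 attacker is Client10 with 10 clients, Client11 with 15.
    type3_attacker = 10 if num_clients == 10 else 11
    if client_id == 1:
        return "initial_model"       # attack type 1: always uploads the initial model
    if client_id == 6:
        return "low_quality_model"   # low-quality participant
    if client_id == type3_attacker:
        # attack type 3: attack in odd rounds, behave normally in even rounds
        return "malicious_model" if round_idx % 2 == 1 else "normal_model"
    return "normal_model"


if __name__ == "__main__":
    for rnd in (1, 2, 3):
        print(rnd, [hybrid_client_behaviour(c, rnd, 10) for c in range(1, 11)])
```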

Table 5: Experimental settings for hybrid scenarios

| Number | Client ID | Type | Number of samples | Model Accuracy |
|---|---|---|---|---|
| Clients=10 | Client1 | Attack Type 1 | 4222 | $A \leq 10\%$ |
| Clients=10 | Client6 | Low-quality Model | 4938 | $45\% \leq A \leq 50\%$ |
| Clients=10 | Client10 | Attack Type 3 | 3560 | Attack round $A \leq 15\%$ |
| Clients=15 | Client1 | Attack Type 1 | 3670 | $A \leq 10\%$ |
| Clients=15 | Client6 | Low-quality Model | 3314 | $45\% \leq A \leq 50\%$ |
| Clients=15 | Client11 | Attack Type 3 | 4453 | Attack round $A \leq 15\%$ |

Figure 8: Comparison of global model accuracy between different algorithms. Panels: (a) 10 clients, FedAvg; (b) 10 clients, FedProx; (c) 10 clients, FedDRL; (d) 15 clients, FedAvg; (e) 15 clients, FedProx; (f) 15 clients, FedDRL.

After completing 100 communication rounds, we present the global model accuracy for each algorithm in Table 6. The comparative global model accuracies and individual client model accuracies per communication round, as determined by these three algorithms, are depicted in Figure 8. The experimental outcomes from the hybrid scenario reveal that the FedAvg and FedProx algorithms falter in properly conducting global model fusion due to the adversarial behavior of certain clients. Incorporating malicious models under traditional algorithmic frameworks significantly degrades the global model’s accuracy.

Table 6: Accuracy of each algorithm for hybrid scenarios experiments

| Method | Fashion-MNIST Clients=10 | Fashion-MNIST Clients=15 | CIFAR-10 Clients=10 | CIFAR-10 Clients=15 | CIFAR-100 Clients=10 | CIFAR-100 Clients=15 |
|---|---|---|---|---|---|---|
| FedAvg | 0.835 | 0.823 | 0.368 | 0.348 | 0.223 | 0.238 |
| FedProx | 0.821 | 0.846 | 0.308 | 0.341 | 0.241 | 0.266 |
| Ours | 0.876 | 0.883 | 0.701 | 0.698 | 0.426 | 0.418 |

The experimental outcomes show that FedAvg and FedProx’s global model accuracies suffer from malicious attacks due to their weighted average-based fusion, which doesn’t block harmful participants. Conversely, the FedDRL algorithm, through its two-stage approach, initially filters out malicious models from fusion and subsequently applies an adaptive weight strategy to diminish the impact of substandard models. Consequently, our algorithm maintains operational integrity even within this complex scenario.

5.2.4 Agent Training Efficiency in the FedDRL Framework

In this segment, our primary objective is to assess the training efficiency of agents within the FedDRL framework. To expedite the training process, we have implemented optimizations in two key areas. Initially, we adopted a distributed reinforcement learning methodology, enabling multi-agents to interact concurrently with the external environment. Concurrently, we introduced a memory cache module designed to prevent redundant sampling by multiple agents.

Experimental Scenarios: Our investigation encompasses varied attack scenarios across two datasets: Fashion-MNIST and CIFAR-10. The federated task involves either 10 clients, 2 of which are malicious, or 15 clients, 3 of which are malicious.

Comparison Experiments: To ascertain the efficacy of the FedDRL framework, we initiated experiments featuring 1, 5, 10, and 20 agents. To guarantee the stability of the reward values acquired by the final agents, we designated the number of iterations for each experimental group to be 10,000, 15,000, 20,000, and 25,000, correspondingly.

Experimental metrics: Our evaluation counts the iterations necessary for reinforcement learning to reach stable rewards across different agent counts. We employ a sliding window approach to compute the average reward, depicting the progression of rewards attained by the agents. We define $r_t$ as the reward the agent obtains in the $t$-th interaction and $W$ as the size of the sliding window. The average reward is calculated as in Equation 22:

$$\bar{R} = \frac{1}{W} \sum_{t=1}^{W} r_t \tag{22}$$
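
A minimal Python sketch of this sliding-window average is given below. The function name `sliding_average_rewards` is an assumption, as is the warm-up behaviour: until `window` rewards have been observed, the sketch averages over the rewards seen so far, which the paper does not specify.

```python
from collections import deque
from typing import Iterable, List


def sliding_average_rewards(rewards: Iterable[float], window: int) -> List[float]:
    """Average reward over the most recent `window` interactions (window >= 1).
    Before `window` rewards are available, average over what has been seen."""
    buf = deque(maxlen=window)
    running_sum = 0.0
    averages = []
    for r in rewards:
        if len(buf) == window:
            running_sum -= buf[0]  # the oldest reward is about to be evicted
        buf.append(r)
        running_sum += r
        averages.append(running_sum / len(buf))
    return averages


if __name__ == "__main__":
    print(sliding_average_rewards([1.0, 2.0, 3.0, 4.0], window=2))
```

Keeping a running sum avoids re-summing the window at every step, which matters when the reward curve spans tens of thousands of iterations.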

Reward Parameter Setting: Our reward function comprises two components: the global model accuracy reward and the reward for the number of credible nodes. For this experiment, these parameters are set to 100 and 10, respectively.

Table 7: Iterations required to obtain stable rewards for different numbers of agents (N)

| Dataset | Attack Type | N=1 | N=5 | N=10 | N=20 |
|---|---|---|---|---|---|
| Fashion-MNIST | Type-1 | 25000 | 20000 | 16000 | 8000 |
| Fashion-MNIST | Type-2 | 25000 | 21000 | 15000 | 9000 |
| CIFAR-10 | Type-1 | 25000 | 20000 | 12000 | 10000 |
| CIFAR-10 | Type-2 | 25000 | 20000 | 14000 | 9000 |
Figure 9: Global model accuracy in 15 clients attack type 1. Panels: (a) Fashion-MNIST (Attack Type 1); (b) Fashion-MNIST (Attack Type 2); (c) CIFAR-10 (Attack Type 1); (d) CIFAR-10 (Attack Type 2).

Experimental Results: Following this setup, we recorded the reward values at each iteration of the agents, as shown in Figure 9. Table 7 summarizes the number of iterations each experiment needed to reach a stable reward.

The data reveal a notable trend. A single agent does not obtain the optimal reward in some attack scenarios because it easily falls into a local optimum. Meanwhile, increasing the number of agents reduces the iterations required to reach a stable reward. This relationship is not strictly proportional, because the agents train their own Actor and Critic networks independently, and each agent needs a different number of iterations for its networks to stabilize. Nevertheless, letting multiple agents interact with the environment simultaneously markedly reduces the sampling time. The approach therefore trades additional computational resources for shorter training time.
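The non-proportional relationship can be made concrete with the values reported in Table 7. The sketch below computes the reduction in iterations relative to a single agent and divides it by the number of agents; the names `speedup` and `per_agent_efficiency` are illustrative, and the ratios are based on iteration counts, not on measured wall-clock time.

```python
# Iterations to reach a stable reward, copied from Table 7.
TABLE_7 = {
    ("Fashion-MNIST", "Type-1"): {1: 25000, 5: 20000, 10: 16000, 20: 8000},
    ("Fashion-MNIST", "Type-2"): {1: 25000, 5: 21000, 10: 15000, 20: 9000},
    ("CIFAR-10", "Type-1"): {1: 25000, 5: 20000, 10: 12000, 20: 10000},
    ("CIFAR-10", "Type-2"): {1: 25000, 5: 20000, 10: 14000, 20: 9000},
}


def speedup(iterations: dict, n_agents: int) -> float:
    """Ratio of single-agent iterations to iterations with n_agents."""
    return iterations[1] / iterations[n_agents]


def per_agent_efficiency(iterations: dict, n_agents: int) -> float:
    """Speedup divided by the number of agents (1.0 would be proportional)."""
    return speedup(iterations, n_agents) / n_agents


for (dataset, attack), iterations in TABLE_7.items():
    for n in (5, 10, 20):
        print(
            f"{dataset} {attack} N={n}: "
            f"speedup {speedup(iterations, n):.2f}x, "
            f"efficiency {per_agent_efficiency(iterations, n):.3f}"
        )
```

For Fashion-MNIST under Type-1 attacks, 20 agents reduce the iterations from 25000 to 8000, a factor of 3.125 for 20 times the agents, which is far from proportional scaling.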

6 CONCLUSION

To realize trustworthy federated learning, we propose FedDRL, a trusted model fusion framework based on staged reinforcement learning. The framework comprises two phases: trusted client selection and adaptive weight assignment. In the first phase, we design a reward strategy to train the agent so that, based on the environment, the trained agent excludes malicious client models from model fusion and adaptively selects trustworthy clients. In the second phase, we design a dynamic weight calculation method that adaptively computes each client's weight from the quality of its model. In addition, we propose a distributed reinforcement learning method to accelerate agent training. Finally, we design five model fusion scenarios to validate our approach, and the experiments show that the proposed algorithm works reliably across these scenarios while maintaining global model accuracy.

Although a multi-agent distributed reinforcement learning approach accelerates agent training, it trades computational resources for computation time. In future work, we will continue to explore more lightweight and trustworthy federated learning methods, as well as more efficient reinforcement learning methods for credible federated learning.

References

  • [1] Briggs C, Fan Z, Andras P. Federated learning with hierarchical clustering of local updates to improve training on non-IID data[C]//2020 International Joint Conference on Neural Networks (IJCNN). IEEE, 2020: 1-9.
  • [2] McMahan B, Moore E, Ramage D, et al. Communication-efficient learning of deep networks from decentralized data[C]//Artificial intelligence and statistics. PMLR, 2017: 1273-1282.
  • [3] Karimireddy S P, Kale S, Mohri M, et al. Scaffold: Stochastic controlled averaging for federated learning[C]//International conference on machine learning. PMLR, 2020: 5132-5143.
  • [4] Li T, Sahu A K, Zaheer M, et al. Federated optimization in heterogeneous networks[J]. Proceedings of Machine learning and systems, 2020, 2: 429-450.
  • [5] Wang J, Liu Q, Liang H, et al. Tackling the objective inconsistency problem in heterogeneous federated optimization[J]. Advances in neural information processing systems, 2020, 33: 7611-7623.
  • [6] Li Q, He B, Song D. Model-contrastive federated learning[C]//Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2021: 10713-10722.
  • [7] Chen L, Zhang W, Dong C, et al. FedTKD: A Trustworthy Heterogeneous Federated Learning Based on Adaptive Knowledge Distillation[J]. Entropy, 2024, 26(1): 96.
  • [8] Fujimoto S, Hoof H, Meger D. Addressing function approximation error in actor-critic methods[C]//International conference on machine learning. PMLR, 2018: 1587-1596.
  • [9] Li H, Sun X, Zheng Z. Learning to attack federated learning: A model-based reinforcement learning attack framework[J]. Advances in Neural Information Processing Systems, 2022, 35: 35007-35020.
  • [10] Wang H, Kaplan Z, Niu D, et al. Optimizing federated learning on non-iid data with reinforcement learning[C]//IEEE INFOCOM 2020-IEEE conference on computer communications. IEEE, 2020: 1698-1707.
  • [11] Zhang P, Wang C, Jiang C, et al. Deep reinforcement learning assisted federated learning algorithm for data management of IIoT[J]. IEEE Transactions on Industrial Informatics, 2021, 17(12): 8475-8484.
  • [12] Yang W, Xiang W, Yang Y, et al. Optimizing federated learning with deep reinforcement learning for digital twin empowered industrial IoT[J]. IEEE Transactions on Industrial Informatics, 2022, 19(2): 1884-1893.
  • [13] Zhang W, Yang D, Wu W, et al. Optimizing federated learning in distributed industrial IoT: A multi-agent approach[J]. IEEE Journal on Selected Areas in Communications, 2021, 39(12): 3688-3703.
  • [14] Rjoub G, Wahab O A, Bentahar J, et al. Trust-driven reinforcement selection strategy for federated learning on IoT devices[J]. Computing, 2022: 1-23.
  • [15] Zhao Y, Li M, Lai L, et al. Federated learning with non-iid data[J]. arXiv preprint arXiv:1806.00582, 2018.
  • [16] Zhang X, Hong M, Dhople S, et al. Fedpd: A federated learning framework with adaptivity to non-iid data[J]. IEEE Transactions on Signal Processing, 2021, 69: 6055-6070.
  • [17] Gong B, Xing T, Liu Z, et al. Adaptive client clustering for efficient federated learning over non-iid and imbalanced data[J]. IEEE Transactions on Big Data, 2022.
  • [18] Huang Y, Chu L, Zhou Z, et al. Personalized cross-silo federated learning on non-iid data[C]//Proceedings of the AAAI conference on artificial intelligence. 2021, 35(9): 7865-7873.
  • [19] Li X, Jiang M, Zhang X, et al. Fedbn: Federated learning on non-iid features via local batch normalization[J]. arXiv preprint arXiv:2102.07623, 2021.
  • [20] Gao L, Fu H, Li L, et al. Feddc: Federated learning with non-iid data via local drift decoupling and correction[C]//Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2022: 10112-10121.
  • [21] Mu X, Shen Y, Cheng K, et al. Fedproc: Prototypical contrastive federated learning on non-iid data[J]. Future Generation Computer Systems, 2023, 143: 93-104.
  • [22] Sun Y, Si S, Wang J, et al. A fair federated learning framework with reinforcement learning[C]//2022 International Joint Conference on Neural Networks (IJCNN). IEEE, 2022: 1-8.
  • [23] Zhang S Q, Lin J, Zhang Q. A multi-agent reinforcement learning approach for efficient client selection in federated learning[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2022, 36(8): 9091-9099.
  • [24] Rjoub G, Wahab O A, Bentahar J, et al. Trust-augmented deep reinforcement learning for federated learning client selection[J]. Information Systems Frontiers, 2022: 1-18.
  • [25] Yang N, Wang S, Chen M, et al. Model-based reinforcement learning for quantized federated learning performance optimization[C]//GLOBECOM 2022-2022 IEEE Global Communications Conference. IEEE, 2022: 5063-5068.
  • [26] Mnih V, Badia A P, Mirza M, et al. Asynchronous methods for deep reinforcement learning[C]//International conference on machine learning. PMLR, 2016: 1928-1937.
  • [27] Zhang W, Yu F, Wang X, et al. R2superscript𝑅2R^{2}italic_R start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPTFed: Resilient Reinforcement Federated Learning for Industrial Applications[J]. IEEE Transactions on Industrial Informatics, 2022.
  • [28] Chen L, Zhang W, Xu L, et al. A Federated Parallel Data Platform for Trustworthy AI[C]//2021 IEEE 1st International Conference on Digital Twins and Parallel Intelligence (DTPI). IEEE, 2021: 344-347.
  • [29] Chen L, Zhao D, Tao L, et al. A Credible and Fair Federated Learning Framework Based on Blockchain[J]. IEEE Transactions on Artificial Intelligence, 2024.