Shengwei Guo¹

¹College of Electronic Engineering, Heilongjiang University, Haxi Street, Harbin 150080, Heilongjiang, China

An Interpretable and Efficient Sleep Staging Algorithm: DetectSleepNet

Abstract

Sleep quality directly affects human health and quality of life, so accurate sleep staging is essential for assessing it. However, traditional methods rely on experts manually labeling each sleep epoch, which is inefficient and time-consuming. Automated sleep staging, in contrast, can assess sleep quality directly and help sleep specialists analyze sleep status, substantially improving efficiency and reducing the cost of sleep monitoring, especially continuous monitoring. Most existing models, however, fall short in computational efficiency, lightweight design, and interpretability. In this paper, we propose a neural network architecture informed by the prior knowledge of sleep experts. Specifically, we 1) propose an end-to-end model named DetectSleepNet that uses single-channel EEG signals without additional data processing and achieves 88.0% accuracy on the SHHS dataset and 80.9% accuracy on the Physio2018 dataset; 2) construct an efficient, lightweight sleep staging model named DetectSleepNet-tiny, based on DetectSleepNet, which has only 6% of the parameters of existing models while retaining over 99% of the accuracy of state-of-the-art models; and 3) introduce a dedicated inference head that assesses the attention paid to each EEG segment within a sleep frame, making the model's decisions more transparent. Our model has fewer parameters than existing ones, and we further explore its interpretability to facilitate its application in healthcare. The code is available at https://github.com/komdec/DetectSleepNet.git.

keywords:
sleep staging, EEG signals, interpretability of neural networks, deep learning

1 Introduction

Sleep consists of non-rapid eye movement (NREM) and rapid eye movement (REM) sleep, which sleep experts further divide into stages based on electroencephalography (EEG), electrooculography (EOG), and submental (chin) electromyography (chin EMG)[1]. Since 1968, sleep staging studies have followed the guidelines developed by Rechtschaffen and Kales in A Manual of Standardized Terminology, Techniques and Scoring System for Sleep Stages of Human Subjects (the R&K manual)[2]. It divides NREM sleep into stages 1, 2, 3, and 4 and refers to REM sleep as stage REM. Later, the American Academy of Sleep Medicine (AASM) Manual for the Scoring of Sleep and Associated Events merged NREM stages 3 and 4 into a single stage, N3[3]. Sleep staging and assessment currently depend primarily on sleep experts analyzing and interpreting polysomnography (PSG) recordings, which they usually divide into 30-second segments (each referred to as one sleep frame). This manual process is known as sleep stage scoring or sleep stage classification. However, it is time-consuming and prone to human error, and it cannot be applied at scale because the number of experts is limited. Moreover, under this manual approach, patients with sleep disorders must be monitored regularly in a specialized environment with expensive equipment and supervision by specialists, which is both tiring and financially burdensome. Hence, automated sleep staging is essential.

Breakthroughs in deep learning across many fields have drawn significant attention from researchers. Deep learning can classify sleep stages automatically without manual feature extraction, and it offers powerful representations and adaptability to different types of data. Recent research shows increasing interest in applying deep learning to sleep staging. A. Supratak et al. proposed DeepSleepNet, which uses a convolutional neural network (CNN) to learn local features and then recurrent neural networks (RNNs) to learn the temporal transition rules among these features[4]. A. Supratak further proposed TinySleepNet, which applies a simpler single-branch CNN and a unidirectional long short-term memory network (LSTM) to improve classification accuracy while reducing the number of parameters[5]. H. Seo et al. proposed IITNet, which divides each 30-second EEG epoch into overlapping subsegments and encodes them into corresponding feature vectors; it uses a modified ResNet-50 for feature extraction and a Bi-LSTM for sleep stage classification[6]. U-Time uses a fully convolutional encoder-decoder network to analyze raw EEG signals of any length for sleep stage classification[7]. SeqSleepNet first converts PSG signals into log-scaled spectrograms using the short-time Fourier transform (STFT), computed with a Hamming window and the fast Fourier transform (FFT), and then suppresses unimportant frequency subbands with a filterbank module; finally, an RNN with an attention module classifies multiple sleep stages[8]. XSleepNet uses both raw EEG signals and power spectra as inputs for sleep staging with CNNs and RNNs[9]. SleepTransformer uses power spectra as inputs and, for interpretability, analyzes the confidence level of each element throughout the night[10]. SleePyCo first obtains a pre-trained model through contrastive learning and then fine-tunes it with a Transformer that learns sequence information[11].

Although deep learning has driven significant advances in automatic sleep stage classification, with performance steadily improving through convolutional neural networks, RNNs, Transformers, and other architectures, these methods still face several challenges in practical applications:

(1) Complexity and Computational Requirements: some methods, for example, use the windowed Fourier transform (WFT) to convert one-dimensional physiological signals into two-dimensional time-frequency images for classification. This increases the computational resources required for training and inference, making it difficult to embed such models directly in small devices with limited computing power.

(2) Model Interpretability: Deep learning models are often regarded as "black boxes". This is a particular concern in medicine and healthcare, where the decision-making process must be transparent to both professionals and patients to build trust and enhance usefulness.

Hence, research should prioritize developing lightweight models with enhanced interpretability. This involves simplifying deep learning model structures, reducing the number of parameters for deployment on small devices, and exploring new interpretability techniques. Such improvements can make automated sleep stage classification more effective in clinical settings and provide more accurate tools for advancing sleep medicine.

This research aims to use deep learning to identify the key signal features and transition rules of sleep and to propose an accurate, lightweight algorithm for automatic sleep stage classification. The algorithm is intended to replace laborious manual sleep staging, enabling its widespread use in daily life and contributing to AI-driven healthcare services. The main contributions of this work are as follows:

(1) We develop a concise end-to-end neural network model named DetectSleepNet, which uses only single-channel EEG signals as input and achieves excellent classification accuracy on two large, publicly available sleep staging datasets.

(2) We develop a lightweight model named DetectSleepNet-tiny to address the large number of parameters of existing models and the resulting difficulty of deploying them on mobile devices. Its parameter count is only 6% of that of existing sleep staging models, while it retains over 99% of the accuracy of state-of-the-art (SOTA) models. This significantly improves efficiency and deployability.

(3) To visualize the decision-making process, we introduce a dedicated inference head that assesses how much attention the model pays to specific segments of the EEG data. This structure makes the "black box" more transparent and persuasive in real-world applications and assists sleep experts in performing sleep staging quickly and effectively.

2 DetectSleepNet

EEG features play a crucial role in sleep staging. As shown in Figure 1, which presents frames of the EEG signal from the O2-M1 channel, the EEG during wakefulness with eyes open is dominated by β waves and α waves (8–13 Hz, typically most pronounced in the occipital leads). If α waves are present when the eyes are closed and α activity occupies more than 50% of the frame in the O2-M1 channel, the frame is scored as the wake (W) stage. Figure 1 also shows that if α activity occupies less than 50% of the frame and is replaced by low-amplitude, mixed-frequency (LAMF) activity, the frame is scored as the N1 stage.
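To make the 50% rule concrete, the following Python sketch scores a frame as W or N1 from the fraction of its duration dominated by α activity. It is a heuristic illustration only, not the AASM visual scoring procedure and not how DetectSleepNet works: the 1-second windows, the power-ratio criterion for "α-dominated", the naive DFT, the 100 Hz sampling rate, and the helper names (`band_power`, `alpha_fraction`, `score_w_or_n1`) are all assumptions. It also ignores eye state and the LAMF check.

```python
import math

ALPHA_BAND = (8.0, 13.0)  # Hz, the alpha range given in the text

def band_power(window, fs, lo, hi):
    """Power of `window` in the [lo, hi] Hz band, from a naive O(n^2) DFT (DC bin skipped)."""
    n = len(window)
    power = 0.0
    for k in range(1, n // 2 + 1):
        freq = k * fs / n
        if lo <= freq <= hi:
            re = sum(x * math.cos(2 * math.pi * k * t / n) for t, x in enumerate(window))
            im = -sum(x * math.sin(2 * math.pi * k * t / n) for t, x in enumerate(window))
            power += re * re + im * im
    return power

def alpha_fraction(epoch, fs, win_sec=1.0):
    """Fraction of fixed-length windows in which alpha carries more than half the power."""
    win = int(fs * win_sec)
    windows = [epoch[s:s + win] for s in range(0, len(epoch) - win + 1, win)]
    dominated = 0
    for window in windows:
        total = band_power(window, fs, 0.5, fs / 2)
        if total > 0 and band_power(window, fs, *ALPHA_BAND) / total > 0.5:
            dominated += 1
    return dominated / len(windows)

def score_w_or_n1(epoch, fs):
    """Hypothetical helper: 'W' if alpha occupies more than 50% of the frame, else 'N1'."""
    return "W" if alpha_fraction(epoch, fs) > 0.5 else "N1"

fs = 100                                                               # assumed sampling rate (Hz)
alpha = [math.sin(2 * math.pi * 10 * t / fs) for t in range(30 * fs)]  # 10 Hz for 30 s
theta = [math.sin(2 * math.pi * 5 * t / fs) for t in range(30 * fs)]   # 5 Hz for 30 s
labels = (score_w_or_n1(alpha, fs), score_w_or_n1(theta, fs))
print(*labels)  # W N1
```

A frame that is only one-third α (for example, 10 s of α followed by 20 s of slower activity) falls below the threshold and would be scored N1 by this rule.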

Figure 1: Analysis of the raw single-channel EEG signal. (a) O2-M1 EEG signal during stage W; (b) O2-M1 EEG signal during stage N1

2.1 Model Architecture and Mathematical Derivation

Figure 2 depicts the architecture of DetectSleepNet. DetectSleepNet first applies multi-channel convolutional layers to the one-dimensional raw EEG data (a time series), which introduces an additional (channel) dimension during feature extraction. Then, in the representation learning stage, the model uses global average pooling to condense the two-dimensional EEG feature matrix of each time segment into a one-dimensional feature vector, so that each time segment is described by a single vector.
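The two operations in this paragraph can be illustrated with a small, dependency-free Python sketch: a bank of C convolution kernels turns a 1-D segment into a C × L feature matrix (the added channel dimension), and global average pooling (Avg1) collapses each channel over time into one number. The kernels and the function names `conv1d_bank` and `global_avg_pool` are hypothetical; the real layers are learned and also include normalization and activation.

```python
def conv1d_bank(signal, kernels):
    """'Valid' 1-D cross-correlation of one signal with C kernels -> C x L' feature matrix."""
    out = []
    for k in kernels:
        width = len(k)
        out.append([sum(k[j] * signal[t + j] for j in range(width))
                    for t in range(len(signal) - width + 1)])
    return out

def global_avg_pool(feature_matrix):
    """Avg1: average each channel (row) over time, turning C x L into a length-C vector."""
    return [sum(row) / len(row) for row in feature_matrix]

segment = [1.0, 2.0, 3.0, 4.0]                       # one toy time segment x_i^t
features = conv1d_bank(segment, [[1.0, 1.0], [1.0, -1.0]])
print(features)                   # [[3.0, 5.0, 7.0], [-1.0, -1.0, -1.0]]
print(global_avg_pool(features))  # [5.0, -1.0]
```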

Figure 2: The architecture of DetectSleepNet

In the sequence learning stage, DetectSleepNet again applies global average pooling, this time to merge the feature vectors of the multiple time segments into a single feature vector representing the entire sleep frame. The process is as follows:

$$\left\{x_i^t \mid i\in\{1,2,\ldots,N\}\right\} = \mathrm{split}\left(X^t\right) \tag{1}$$

$$f\_rep_i^t = \mathrm{Avg1}\left(\mathrm{Rep}\left(x_i^t\right)\right) \tag{2}$$

$$\left\{f\_gs_i^t\right\} = \mathrm{Global\_Seq}\left(\left\{f\_rep_i^t \mid i\in\{1,2,\ldots,N\},\ t\in\{1,2,\ldots,M\}\right\}\right) \tag{3}$$

$$\left\{f\_ls_i^t\right\} = \mathrm{Local\_Seq}\left(\left\{f\_gs_i^t \mid i\in\{1,2,\ldots,N\}\right\}\right) \tag{4}$$

$$f\_seq^t = \mathrm{Avg2}\left(\left\{f\_ls_i^t \mid i\in\{1,2,\ldots,N\}\right\}\right) \tag{5}$$

$$\left\{c^t\right\} = \left\{\mathrm{classifier}\left(f\_seq^t\right) \mid t\in\{1,2,\ldots,M\}\right\} \tag{6}$$

Here, $X^t$ denotes the single-channel EEG data of the $t$-th sleep frame. The function $\mathrm{split}$ divides it into $N$ subsets $x_i^t$. For each subset $x_i^t$, the representation learning structure $\mathrm{Rep}$ extracts a feature matrix, which the global average pooling $\mathrm{Avg1}$ then condenses into a one-dimensional feature vector $f\_rep_i^t$.
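The segmentation rule of $\mathrm{split}$ is not spelled out here. A minimal Python sketch, assuming equal-length, non-overlapping segments and a 100 Hz sampling rate (so a 30-second frame has 3,000 samples), might look like this; `N = 10` and the toy sequence length `M = 3` are hypothetical choices, not values from the paper.

```python
def split(frame, n_segments):
    """Split one sleep frame X^t into N equal, non-overlapping segments x_1^t ... x_N^t."""
    if len(frame) % n_segments != 0:
        raise ValueError("frame length must be divisible by n_segments")
    seg_len = len(frame) // n_segments
    return [frame[i * seg_len:(i + 1) * seg_len] for i in range(n_segments)]

FS, FRAME_SEC, N = 100, 30, 10                      # assumed sampling rate, hypothetical N
M = 3                                               # a short toy sequence of sleep frames
frames = [[float(t)] * (FS * FRAME_SEC) for t in range(M)]  # X^1 ... X^M (constant toy data)
segments = [split(x, N) for x in frames]            # M x N segments of 300 samples each
print(len(segments), len(segments[0]), len(segments[0][0]))  # 3 10 300
```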

The global sequence learning structure $\mathrm{Global\_Seq}$ and the local sequence learning structure $\mathrm{Local\_Seq}$ process these feature vectors to derive the vector $f\_ls_i^t$ corresponding to each subset $x_i^t$. The global average pooling $\mathrm{Avg2}$ then produces the feature vector $f\_seq^t$ of each sleep frame. Finally, the $\mathrm{classifier}$ transforms the feature vector of each sleep frame into the category vector $c^t$, from which the sleep stage is identified. The collection of category vectors for $M$ sleep frames is $\left\{c^t \mid t\in\{1,2,\ldots,M\}\right\}$.
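The last two steps, Equations (5) and (6), can be sketched in plain Python as below. The sketch assumes that the classifier is a single linear layer followed by softmax and that the predicted stage is the arg-max of $c^t$; the weights are hypothetical placeholders, since the classifier's internals are not specified here.

```python
import math

STAGES = ["W", "N1", "N2", "N3", "R"]

def avg2(segment_vectors):
    """Avg2: average the N vectors f_ls_i^t of one frame into f_seq^t."""
    n = len(segment_vectors)
    return [sum(v[d] for v in segment_vectors) / n for d in range(len(segment_vectors[0]))]

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def classifier(f_seq, weights, bias):
    """Assumed linear layer + softmax: returns the category vector c^t (stage probabilities)."""
    logits = [sum(w * x for w, x in zip(row, f_seq)) + b for row, b in zip(weights, bias)]
    return softmax(logits)

f_seq = avg2([[1.0, 0.0], [3.0, 2.0]])               # two toy segment vectors -> [2.0, 1.0]
weights = [[0, 0], [0, 0], [1, 0], [0, 0], [0, 1]]   # hypothetical 5 x 2 weight matrix
c_t = classifier(f_seq, weights, [0.0] * 5)
print(STAGES[c_t.index(max(c_t))])                   # N2
```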

To extract these features effectively, we designed a parameter-shared local feature extractor as the basic structure for representation learning. Each local time series is processed by this feature extractor to generate feature vectors that capture the local waveform characteristics. A sequence learning architecture then integrates these features to produce the final sleep stage categories. Figure 3 shows a simplified view of the model architecture, covering the entire process from feature extraction to classification.
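Parameter sharing means that one extractor, with one set of weights, is applied to every segment $x_i^t$ of every frame. The Python sketch below illustrates this with a hypothetical one-kernel extractor (`make_extractor` is not from the paper) and compares its parameter count with the unshared alternative of one extractor per segment position.

```python
def make_extractor(kernel):
    """Build a toy local feature extractor: one valid convolution followed by average pooling."""
    def extract(segment):
        k = len(kernel)
        conv = [sum(kernel[j] * segment[t + j] for j in range(k))
                for t in range(len(segment) - k + 1)]
        return [sum(conv) / len(conv)]
    return extract

kernel = [0.5, -0.5]
shared = make_extractor(kernel)                      # one set of weights...
frames = [[[1.0, 3.0, 2.0], [0.0, 4.0, 4.0]],        # M = 2 frames, N = 2 segments each
          [[1.0, 3.0, 2.0], [2.0, 2.0, 2.0]]]
features = [[shared(seg) for seg in frame] for frame in frames]  # ...reused for every x_i^t

n_segments = len(frames[0])
print("shared parameters:", len(kernel))                                           # 2
print("unshared (one extractor per segment position):", len(kernel) * n_segments)  # 4
```

Because the weights are shared, identical segments produce identical features regardless of where they occur, and the parameter count does not grow with N.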

Figure 3: The architecture of DetectSleepNet

Given a series of $M$ consecutive sleep frames, each subdivided into $N$ time segments, the forward computation proceeds as follows:

$$\left\{x_i^t \mid i\in\{1,2,\ldots,N\}\right\} = \mathrm{split}\left(X^t\right) \tag{7}$$

$$f\_rep_i^t = \mathrm{Rep}\left(x_i^t\right) \tag{8}$$

$$\left\{f\_seq^t \mid t\in\{1,2,\ldots,M\}\right\} = \mathrm{Seq}\left(\left\{f\_rep_i^t \mid i\in\{1,2,\ldots,N\},\ t\in\{1,2,\ldots,M\}\right\}\right) \tag{9}$$

$$c^t = \mathrm{classifier}\left(f\_seq^t\right) \tag{10}$$

where $i\in\{1,2,\ldots,N\}$ and $t\in\{1,2,\ldots,M\}$ are the indices of the time segments and sleep frames, respectively. $X^t$ denotes the single-channel EEG data of the $t$-th sleep frame, and $x_i^t$ denotes the EEG data of the $i$-th time segment of the $t$-th sleep frame after segmentation. $\mathrm{Rep}$ is the representation learning structure that converts the data of each time segment $x_i^t$ into the corresponding feature vector $f\_rep_i^t$. $\mathrm{Seq}$ is the sequence learning structure that integrates the feature vectors of consecutive sleep frames to generate the composite feature vector $f\_seq^t$ for each sleep frame. The $\mathrm{classifier}$ converts the feature vector of each sleep frame into the corresponding category vector $c^t$.

The pipeline thus consists of four stages: segmentation, feature extraction, sequence learning, and classification.

2.2 Representation Learning

To capture the distinctive frequency characteristics of different EEG rhythms, we introduce a multi-scale feature extraction module (MCFEM). As illustrated in Figure 4, the module uses multiple convolutional layers and feature extraction paths to capture and integrate features from different receptive fields.

Figure 4: The multi-scale feature extraction module

In the MCFEM, the first convolutional layer (Conv Layer 1) uses a kernel size of 3 to capture features from small receptive fields. Two dilated convolutional layers (Dilated Conv 1 and Dilated Conv 2), with dilation rates of 3 and 5, respectively, cover broader receptive fields and yield a more comprehensive hierarchical feature representation. Conv Layer 2 uses a kernel size of 1, mainly to adjust the number of channels (feature dimensions) so that the outputs of the different paths can be fused efficiently. Each feature extraction path comprises two convolutional layers, which increases the module's fitting capacity.
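The receptive-field argument can be checked with a short Python sketch. For a stride-1 convolution with kernel size k and dilation d, one layer sees (k − 1)·d + 1 samples, and each stacked layer adds (k − 1)·d. Only the dilation rates are given above, so the kernel size of 3 for the dilated layers below is an assumption, as are the helper names.

```python
def dilated_conv1d(x, kernel, dilation=1):
    """'Valid' dilated 1-D cross-correlation: y[t] = sum_j kernel[j] * x[t + j * dilation]."""
    span = (len(kernel) - 1) * dilation + 1          # samples covered by one output
    return [sum(kernel[j] * x[t + j * dilation] for j in range(len(kernel)))
            for t in range(len(x) - span + 1)]

def receptive_field(kernel_sizes, dilations):
    """Receptive field of stacked stride-1 convolutions: 1 + sum((k - 1) * d)."""
    return 1 + sum((k - 1) * d for k, d in zip(kernel_sizes, dilations))

# Single layers: plain kernel-3 conv vs. assumed kernel-3 dilated convs (d = 3 and d = 5)
print(receptive_field([3], [1]), receptive_field([3], [3]), receptive_field([3], [5]))  # 3 7 11
# A dilation-3 kernel [1, 0, -1] compares samples 6 steps apart
print(dilated_conv1d([0, 1, 2, 3, 4, 5, 6], [1, 0, -1], dilation=3))  # [-6]
```

Dilation widens the view without adding weights: the dilation-5 layer spans 11 samples with the same 3 parameters as the plain kernel-3 layer.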

The module uses batch normalization and nonlinear activation functions to improve training efficiency and overall effectiveness. It fuses the features from the different receptive fields by summation, which avoids increasing the feature dimension. A final pooling layer reduces the feature dimensions and makes the model more robust to variations in the input data, improving its ability to generalize. Together, these design elements make the extracted features more comprehensive and better suited to complex EEG signal analysis.
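The summation-based fusion, activation, and pooling described above can be sketched in plain Python as follows. ReLU6 is the activation named in Section 3.2; batch normalization is omitted, and non-overlapping max pooling for the final pooling layer is an assumption, since the pooling type is not stated here. The concatenation function is included only to show why summation keeps the channel count unchanged.

```python
def relu6(v):
    """ReLU6: 0 for negative inputs, identity on [0, 6], capped at 6."""
    return min(max(0.0, v), 6.0)

def fuse_by_sum(branches):
    """Element-wise sum of equally shaped C x L branch outputs; the result stays C x L."""
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*branches)]

def fuse_by_concat(branches):
    """Alternative for comparison: stacking channels yields len(branches) * C rows."""
    return [row for branch in branches for row in branch]

def max_pool1d(row, size):
    """Non-overlapping max pooling along time (assumed pooling type)."""
    return [max(row[s:s + size]) for s in range(0, len(row) - size + 1, size)]

# Three hypothetical branch outputs (C = 2 channels, L = 4 time steps each)
b1 = [[1.0, -2.0, 3.0, 0.5], [0.0, 1.0, 1.0, 2.0]]
b2 = [[2.0, 1.0, 4.0, -1.0], [1.0, 1.0, 0.0, 3.0]]
b3 = [[-1.0, 0.5, 1.0, 0.5], [2.0, 0.0, 1.0, 4.0]]
fused = fuse_by_sum([b1, b2, b3])                    # still 2 x 4
out = [max_pool1d([relu6(v) for v in row], 2) for row in fused]
print(len(fused), len(fuse_by_concat([b1, b2, b3])))  # 2 6
print(out)                                            # [[2.0, 6.0], [3.0, 6.0]]
```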

2.3 Sequence Learning

This paper uses a six-layer bidirectional gated recurrent unit (Bi-GRU) structure for sequence learning. Five layers capture global sequence information by analyzing the connections between consecutive sleep frames, while the remaining layer processes local sequence information by examining the detailed relationships between the time segments within a single sleep frame. The computation is as follows:

$$\left\{f\_gs_i^t\right\} = \mathrm{Global\_Seq}\left(\left\{f\_rep_i^t \mid i\in\{1,2,\ldots,N\},\ t\in\{1,2,\ldots,M\}\right\}\right) \tag{11}$$

$$\left\{f\_ls_i^t\right\} = \mathrm{Local\_Seq}\left(\left\{f\_gs_i^t \mid i\in\{1,2,\ldots,N\}\right\}\right) \tag{12}$$

Here, the set $\left\{f\_rep_i^t \mid i\in\{1,2,\ldots,N\},\ t\in\{1,2,\ldots,M\}\right\}$ denotes the features extracted from an $N \times M$ grid of time segments, each containing single-channel EEG data. The global sequence learning structure $\mathrm{Global\_Seq}$ integrates the sequence features of multiple sleep frames from a global perspective, yielding the global sequence features $\left\{f\_gs_i^t\right\}$. The local sequence learning structure $\mathrm{Local\_Seq}$ further processes these global features to refine the sequence dynamics across the time segments of a single sleep frame, producing the local sequence features $\left\{f\_ls_i^t\right\}$. This design allows the model to capture both global dynamics and local variations, improving the accuracy and reliability of sleep stage determination.
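The split between global and local sequence learning can be illustrated with a dependency-free Python sketch. It uses a bidirectional GRU with hidden size 1 (gate equations in the form used by common frameworks such as PyTorch), one layer per stage instead of the five global layers and one local layer described above, and fixed hypothetical weights (`toy_params`) in place of trained ones. The global stage runs over all M·N segment features as one sequence, so information crosses frame boundaries; the local stage then runs separately over the N segments of each frame before Avg2.

```python
import math

def _sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

def _dot(w, x):
    return sum(wi * xi for wi, xi in zip(w, x))

def gru_scan(seq, p):
    """Run a GRU with hidden size 1 over a list of input vectors; returns the hidden states."""
    h, states = 0.0, []
    for x in seq:
        z = _sigmoid(_dot(p["wz"], x) + p["uz"] * h + p["bz"])         # update gate
        r = _sigmoid(_dot(p["wr"], x) + p["ur"] * h + p["br"])         # reset gate
        n = math.tanh(_dot(p["wn"], x) + r * (p["un"] * h) + p["bn"])  # candidate state
        h = (1.0 - z) * n + z * h
        states.append(h)
    return states

def bi_gru(seq, p_fwd, p_bwd):
    """Bidirectional layer: pair each forward state with the matching backward state."""
    fwd = gru_scan(seq, p_fwd)
    bwd = gru_scan(seq[::-1], p_bwd)[::-1]
    return [[f, b] for f, b in zip(fwd, bwd)]

def toy_params(in_dim, s):
    """Hypothetical fixed weights; a trained model would learn these."""
    return {"wz": [s] * in_dim, "wr": [s] * in_dim, "wn": [s] * in_dim,
            "uz": s, "ur": s, "un": s, "bz": 0.0, "br": 0.0, "bn": 0.0}

def global_then_local(f_rep):
    """f_rep: M frames x N segments x D features -> M frame vectors f_seq^t."""
    m, n, d = len(f_rep), len(f_rep[0]), len(f_rep[0][0])
    flat = [v for frame in f_rep for v in frame]          # Global_Seq: one sequence of length M*N
    gs = bi_gru(flat, toy_params(d, 0.5), toy_params(d, -0.5))
    gs = [gs[t * n:(t + 1) * n] for t in range(m)]        # regroup into M x N x 2
    f_seq = []
    for frame in gs:                                      # Local_Seq: N segments of one frame
        ls = bi_gru(frame, toy_params(2, 0.3), toy_params(2, 0.3))
        f_seq.append([sum(v[k] for v in ls) / n for k in range(2)])  # Avg2
    return f_seq

f_rep = [[[0.1 * (t + i)] for i in range(4)] for t in range(3)]  # M = 3, N = 4, D = 1
print(global_then_local(f_rep))
```

Changing only the last frame changes the output for the first frame, because the backward pass of the global stage carries information across frame boundaries.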

3 Experiment

3.1 Datasets

In this paper, we use the SHHS and Physio2018 datasets[12, 13, 14]. Their large sample sizes allow comprehensive testing and validation of model performance and stability. Moreover, because these datasets are standardized, results can be compared directly with other studies, which improves the reliability and reproducibility of the research. Table 1 summarizes the datasets used for evaluation.

Table 1: Number of samples

| Dataset | Scoring manual | Samples | Channel |
|---|---|---|---|
| Physio2018 | AASM | 994 | C3-A2 |
| SHHS | R&K | 5,793 | C4-M1 |

Both datasets were preprocessed following existing research. Specifically, for the SHHS dataset, sleep stages 3 and 4 were merged into a single N3 stage, and movement periods and unknown stages were removed to keep the data consistent and precise. The evaluation scheme is detailed in subsequent sections. Table 2 presents the distribution of the labeled categories of the sleep epochs.

Table 2: Distribution Information of Physio2018 and SHHS Datasets (number of sleep frames per category, with percentage of the total)

| Dataset | W | N1 | N2 | N3 | R | Total |
|---|---|---|---|---|---|---|
| Physio2018 | 157,945 (17.7%) | 136,978 (15.4%) | 377,870 (42.3%) | 102,592 (11.5%) | 116,877 (13.1%) | 892,262 (100%) |
| SHHS | 1,691,288 (28.8%) | 217,583 (3.7%) | 2,397,460 (40.9%) | 739,403 (12.6%) | 817,473 (13.9%) | 5,863,207 (100%) |
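The label merging and the percentages in Table 2 can be reproduced with a short Python sketch. The R&K-style label strings in `RK_TO_AASM` are hypothetical placeholders (actual SHHS annotation files use their own codes), and movement and unknown epochs are simply dropped, as described above. The usage lines recompute the SHHS percentages from the counts.

```python
from collections import Counter

STAGES = ["W", "N1", "N2", "N3", "R"]
# Hypothetical R&K label strings; anything not listed (e.g. movement, unknown) is dropped.
RK_TO_AASM = {"W": "W", "S1": "N1", "S2": "N2", "S3": "N3", "S4": "N3", "REM": "R"}

def to_aasm(labels):
    """Merge R&K stages 3 and 4 into N3 and remove unmapped epochs."""
    return [RK_TO_AASM[lab] for lab in labels if lab in RK_TO_AASM]

def distribution(counts):
    """Return the total and each stage's share of it in percent, rounded to 0.1."""
    total = sum(counts[s] for s in STAGES)
    return total, {s: round(100.0 * counts[s] / total, 1) for s in STAGES}

print(Counter(to_aasm(["W", "S2", "S3", "S4", "MT", "REM", "?"])))

shhs = {"W": 1_691_288, "N1": 217_583, "N2": 2_397_460, "N3": 739_403, "R": 817_473}
print(distribution(shhs))
# (5863207, {'W': 28.8, 'N1': 3.7, 'N2': 40.9, 'N3': 12.6, 'R': 13.9})
```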

3.2 Experimental Setting

In this paper, the SHHS dataset is preprocessed to conform to the AASM scoring manual. The sleep staging model was trained on sequences of 20 consecutive single-channel EEG sleep frames, using data from both the Physio2018 and SHHS datasets. We chose the ReLU6 activation function, which outputs zero for negative inputs, is linear for inputs between 0 and 6, and is capped at 6; its identity region for positive inputs helps mitigate the vanishing gradient problem and improves training efficiency. The model uses the cross-entropy loss function, as shown in the formula:

$$CE_{Loss}\left(y,\hat{y}\right) = -\sum_{k=1}^{K} y_k \log \hat{y}_k \tag{13}$$

where $K = 5$ is the number of sleep stages, $y$ is the one-hot ground-truth label, and $\hat{y}$ is the predicted probability vector. For $K = 2$, this reduces to the binary form $-\left[y\log\hat{y} + (1-y)\log(1-\hat{y})\right]$.
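A minimal pure-Python sketch of Equation (13) for one sleep frame is shown below; in training, the loss would typically be averaged over all frames in a batch. The small `eps` clamp that guards against log(0) is an implementation assumption, and the binary helper is included only to check that the two-class case matches the binary cross-entropy form.

```python
import math

def softmax(logits):
    """Convert raw scores into probabilities (max subtracted for numerical stability)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(y_onehot, y_hat, eps=1e-12):
    """Categorical cross-entropy: -sum_k y_k * log(y_hat_k)."""
    return -sum(y * math.log(max(p, eps)) for y, p in zip(y_onehot, y_hat))

def binary_cross_entropy(y, y_hat):
    """Two-class form: -[y log y_hat + (1 - y) log(1 - y_hat)]."""
    return -(y * math.log(y_hat) + (1 - y) * math.log(1 - y_hat))

# Five stages (W, N1, N2, N3, R); the true stage is N2 and the model is maximally uncertain
uniform = softmax([0.0] * 5)
print(cross_entropy([0, 0, 1, 0, 0], uniform))   # log(5), about 1.609
# With K = 2 the categorical form equals the binary form
print(cross_entropy([1, 0], [0.8, 0.2]), binary_cross_entropy(1, 0.8))
```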

We use the AdamW optimizer, which extends the original Adam optimizer with decoupled weight decay[15, 16]. Weight decay helps reduce overfitting and improve the model's generalization, and AdamW combines adaptive learning rates with momentum, improving the stability and efficiency of training. In addition, we use the CyclicLR learning rate schedule, which cyclically varies the learning rate between a lower and an upper bound during training; this helps the model escape local minima and accelerates convergence[17].
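The optimizer and schedule can be illustrated with a scalar Python sketch: an AdamW step (Adam moment estimates plus weight decay applied directly to the parameter rather than added to the gradient) driven by a triangular cyclic learning rate, the basic CyclicLR policy. The hyperparameters below (base and maximum learning rates, step size, weight decay, and the toy quadratic objective) are hypothetical, since they are not stated here.

```python
import math

def triangular_lr(it, base_lr, max_lr, step_size):
    """Triangular cyclic LR: linear rise from base_lr to max_lr over step_size iterations, then back."""
    cycle = math.floor(1 + it / (2 * step_size))
    x = abs(it / step_size - 2 * cycle + 1)
    return base_lr + (max_lr - base_lr) * max(0.0, 1 - x)

class AdamW:
    """AdamW for a single scalar parameter."""

    def __init__(self, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=1e-2):
        self.b1, self.b2, self.eps, self.wd = beta1, beta2, eps, weight_decay
        self.m = self.v = 0.0
        self.t = 0

    def step(self, theta, grad, lr):
        self.t += 1
        theta -= lr * self.wd * theta                             # decoupled weight decay
        self.m = self.b1 * self.m + (1 - self.b1) * grad          # first-moment estimate
        self.v = self.b2 * self.v + (1 - self.b2) * grad * grad   # second-moment estimate
        m_hat = self.m / (1 - self.b1 ** self.t)                  # bias corrections
        v_hat = self.v / (1 - self.b2 ** self.t)
        return theta - lr * m_hat / (math.sqrt(v_hat) + self.eps)

# Toy objective (theta - 3)^2 with gradient 2 * (theta - 3)
opt, theta = AdamW(), 0.0
for it in range(2000):
    lr = triangular_lr(it, base_lr=1e-3, max_lr=1e-1, step_size=100)
    theta = opt.step(theta, 2 * (theta - 3.0), lr)
print(round(theta, 2))  # close to 3; weight decay pulls it slightly toward 0
```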

Overall, combining AdamW with CyclicLR optimizes the learning process, and together with the chosen loss and activation functions, it improves the model's ability to handle complex sleep data. This improves prediction accuracy and generalization, resulting in more precise and stable performance in sleep staging tasks.

3.3 Performance Comparison and Analysis

DetectSleepNet performs strongly on the Physio2018 and SHHS datasets. Table 3 compares it with the top algorithms for single-channel sleep staging.

Table 3: Comparison between DetectSleepNet and other state-of-the-art methods (overall metrics: OA, overall accuracy in %; MF1, macro-averaged F1 score in %; κ, Cohen's kappa; W, N1, N2, N3, R: per-class F1 scores in %)

| Dataset | Method | OA | MF1 | κ | W | N1 | N2 | N3 | R |
|---|---|---|---|---|---|---|---|---|---|
| Physio2018 | DetectSleepNet | \textbf{80.9} | \textbf{79.0} | \textbf{0.739} | \textbf{84.6} | \underline{59.0} | \underline{85.1} | \textbf{80.2} | \textbf{86.3} |
| | SleePyCo[11] | \textbf{80.9} | \underline{78.9} | \underline{0.737} | \underline{84.4} | \textbf{59.3} | \textbf{85.3} | \underline{79.4} | \textbf{86.3} |
| | XSleepNet[9] | \underline{80.3} | 78.6 | 0.732 | - | - | - | - | - |
| | SeqSleepNet[8] | 79.4 | 77.6 | 0.719 | - | - | - | - | - |
| | U-time[7] | 78.8 | 77.4 | 0.714 | 82.5 | \underline{59.0} | 83.1 | 79.0 | \underline{83.5} |
| SHHS | DetectSleepNet | \textbf{88.0} | \textbf{80.7} | \textbf{0.831} | \textbf{92.9} | 48.5 | \textbf{88.5} | 84.8 | \textbf{88.7} |
| | SleePyCo | \underline{87.9} | \textbf{80.7} | \underline{0.830} | \underline{92.6} | \underline{49.2} | \textbf{88.5} | 84.5 | \underline{88.6} |
| | SleepTransformer[10] | 87.7 | \underline{80.1} | 0.828 | 92.2 | 46.1 | 88.3 | \textbf{85.2} | \underline{88.6} |
| | XSleepNet | 87.6 | \textbf{80.7} | 0.826 | 92.0 | \textbf{49.9} | 88.3 | \underline{85.0} | 88.2 |
| | IITNet[6] | 86.7 | 79.8 | 0.812 | 90.1 | 48.1 | \underline{88.4} | \textbf{85.2} | 87.2 |
| | SeqSleepNet | 86.5 | 78.5 | 0.81 | - | - | - | - | - |

In Table 3, bold marks the best result and underline the second-best. DetectSleepNet obtains the best or tied-best overall accuracy (OA), macro-F1 (MF1) and Cohen's $\kappa$ on both datasets. In contrast, SleePyCo, XSleepNet and SleepTransformer rely on more elaborate pipelines: SleePyCo pre-trains with a contrastive learning strategy instead of using a direct end-to-end architecture, while XSleepNet and SleepTransformer convert raw EEG signals into time-frequency representations via Fourier transforms before extracting features. DetectSleepNet, by comparison, is trained end to end directly on the raw EEG waveform. It therefore avoids the time-frequency preprocessing step while matching or exceeding the accuracy of these more complex methods, which indicates that the raw waveform carries enough information for competitive single-channel sleep staging.

Despite the overall effectiveness of current automatic sleep staging methods, classifying the N1 stage remains difficult: in Table 3 the N1 F1 scores (46.1 to 59.3) are far below those of the other stages. Several factors contribute to this:

1. Annotation bias: Sleep experts decide between W and N1 from the proportion of alpha rhythm in the occipital derivations (O1-M2 or O2-M1). Following the AASM scoring manual, an epoch in which alpha rhythm occupies more than 50% of the epoch is scored as W, and one in which it occupies less than 50% is scored as N1. Epochs close to the 50% threshold are therefore prone to inconsistent manual labels.

2. Signal modality: About 10% of people do not produce alpha rhythm, so other signals (e.g., EMG signals) must be used to identify the N1 stage, and sleep experts may interpret these signals differently depending on their experience and preferences.

3. EEG lead selection: This paper mainly uses central derivations (C4-M2 or C3-M1) for single-channel sleep staging. Although central leads capture EEG activity from several regions, the proportion of alpha rhythm they record can differ from that in the occipital leads on which the W/N1 rule is defined, which can reduce classification accuracy.
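To make the 50% rule in item 1 concrete, the sketch below estimates, for one 30 s epoch, the fraction of 2 s windows in which 8-12 Hz (alpha) power dominates the 0.5-30 Hz power, and then applies the W/N1 threshold. The window length, band edges, the 0.5 dominance ratio and all function names are illustrative assumptions; this is not the scoring procedure of sleep experts or of DetectSleepNet, only a way to see why epochs near 50% are fragile.

```python
import math

def band_power(window, fs, lo, hi):
    """Sum of squared DFT magnitudes over bins with frequency in [lo, hi] Hz (naive DFT)."""
    n = len(window)
    power = 0.0
    for k in range(1, n // 2):
        freq = k * fs / n
        if lo <= freq <= hi:
            re = sum(x * math.cos(2 * math.pi * k * j / n) for j, x in enumerate(window))
            im = -sum(x * math.sin(2 * math.pi * k * j / n) for j, x in enumerate(window))
            power += re * re + im * im
    return power

def alpha_fraction(epoch, fs, win_s=2.0, ratio=0.5):
    """Fraction of non-overlapping windows whose alpha power exceeds `ratio` of 0.5-30 Hz power."""
    step = int(win_s * fs)
    windows = [epoch[s:s + step] for s in range(0, len(epoch) - step + 1, step)]
    if not windows:
        raise ValueError("epoch shorter than one window")
    flagged = 0
    for w in windows:
        alpha = band_power(w, fs, 8.0, 12.0)
        total = band_power(w, fs, 0.5, 30.0)
        if total > 0 and alpha / total > ratio:
            flagged += 1
    return flagged / len(windows)

def score_w_or_n1(fraction):
    """Simplified AASM-style rule: more than 50% alpha -> W, otherwise N1."""
    return "W" if fraction > 0.5 else "N1"
```

With a synthetic 100 Hz epoch made of a 10 Hz sine followed by a 5 Hz sine, 16 s of alpha gives a fraction of 8/15 and the label W, whereas 14 s of alpha gives 7/15 and the label N1: shifting a single 2 s window across the threshold flips the stage.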

Figure 5 shows the confusion matrix of DetectSleepNet, i.e., how the model's predictions on the test set are distributed over the true labels. It summarizes the prediction accuracy for each sleep stage together with the misclassifications, from which the true positives, false positives, true negatives and false negatives of every stage can be read. The confusion matrix therefore gives a direct view of where the model performs well and which stages it tends to confuse.
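The overall metrics in Table 3 (OA, MF1 and Cohen's $\kappa$) can all be derived from a confusion matrix such as the one in Figure 5. The sketch below uses the standard definitions; the convention that rows are true labels and columns are predicted labels, and the 3-class toy matrix, are assumptions for illustration and not DetectSleepNet results.

```python
def metrics_from_confusion(cm):
    """cm[i][j] = number of epochs with true stage i predicted as stage j."""
    k = len(cm)
    total = sum(sum(row) for row in cm)
    correct = sum(cm[i][i] for i in range(k))
    oa = correct / total  # overall accuracy

    f1_scores = []
    for c in range(k):
        tp = cm[c][c]
        fp = sum(cm[r][c] for r in range(k)) - tp  # predicted c, true stage differs
        fn = sum(cm[c]) - tp                        # true c, predicted stage differs
        denom = 2 * tp + fp + fn
        f1_scores.append(2 * tp / denom if denom else 0.0)
    mf1 = sum(f1_scores) / k  # macro-F1: unweighted mean over stages

    # Cohen's kappa: observed agreement corrected for chance agreement p_e.
    p_e = sum(sum(cm[c]) * sum(cm[r][c] for r in range(k)) for c in range(k)) / total ** 2
    kappa = (oa - p_e) / (1 - p_e)
    return oa, f1_scores, mf1, kappa
```

For the toy matrix `[[50, 5, 0], [10, 30, 10], [0, 5, 90]]`, OA is 0.85, the per-class F1 scores are about 0.870, 0.667 and 0.923, MF1 is about 0.820, and $\kappa$ is 0.48/0.63, about 0.762. MF1 weights every stage equally, which is why a weak N1 class lowers MF1 more than OA.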

Figure 5: The module of multi scale feature extraction

Thanks to the proposed multi-scale feature extraction module, DetectSleepNet can process one-dimensional EEG signals recorded at different sampling frequencies. Figure 6 shows the comparison experiments. In these experiments, TinySleepNet has to resize its convolutional kernels according to the sampling frequency of the input EEG signal, which changes its number of parameters, whereas the architecture of DetectSleepNet stays the same.

Figure 6: The accuracy comparison of TinySleepNet and DetectSleepNet

4 Lightweight Processing

Earlier experiments showed that the sequence learning part of the model contains far more parameters than the representation learning part. This imbalance resembles vision-language multimodal tasks, where the parameter count of the visual feature extractor exceeds that of the text feature extractor. To address it, we apply a transfer learning strategy to DetectSleepNet: the weights of the representation learning component are frozen, and only the sequence learning component is retrained. Figure 7 details this procedure and the resulting architecture.

Figure 7: The architecture of DetectSleepNet-tiny

Figure 7 shows the architecture of DetectSleepNet-tiny, a lightweight version of DetectSleepNet. The original six-layer bidirectional gated recurrent unit (Bi-GRU) is replaced by a single-layer Bi-GRU, which substantially reduces the model size and speeds up training. The computation proceeds as follows:
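To see why removing Bi-GRU layers shrinks the model so much, the sketch below counts the parameters of a stacked bidirectional GRU, following the parameter layout of PyTorch's `nn.GRU` (input-to-hidden and hidden-to-hidden weights plus two bias vectors for the three gates). The feature and hidden sizes are hypothetical placeholders, since the actual DetectSleepNet sizes are not given in this section.

```python
def gru_layer_params(input_size, hidden_size):
    """One GRU direction: W_ih (3H x I), W_hh (3H x H), b_ih (3H), b_hh (3H)."""
    h = hidden_size
    return 3 * h * input_size + 3 * h * h + 3 * h + 3 * h

def bigru_params(input_size, hidden_size, num_layers):
    """Stacked bidirectional GRU; layers after the first receive 2 * hidden_size inputs."""
    total = 0
    for layer in range(num_layers):
        in_size = input_size if layer == 0 else 2 * hidden_size
        total += 2 * gru_layer_params(in_size, hidden_size)  # forward + backward directions
    return total

# Hypothetical sizes, for illustration only.
feat_dim, hidden = 32, 32
six_layers = bigru_params(feat_dim, hidden, num_layers=6)
one_layer = bigru_params(feat_dim, hidden, num_layers=1)
print(six_layers, one_layer)
```

With these placeholder sizes, six layers need 106,752 parameters and one layer needs 12,672. Every layer after the first also sees a doubled input width, so each extra layer costs more than the first one.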

$$\left\{x_{i}^{t}\mid i\in\left\{1,2,\ldots,N\right\}\right\}=split\left(X^{t}\right) \tag{14}$$
$$f\_rep_{i}^{t}=Avg1\left(Rep\left(x_{i}^{t}\right)\right) \tag{15}$$
$$\left\{f\_ts_{i}^{t}\right\}=Tiny\_Seq\left(\left\{f\_rep_{i}^{t}\mid i\in\left\{1,2,\ldots,N\right\},\ t\in\left\{1,2,\ldots,M\right\}\right\}\right) \tag{16}$$
$$f\_avg^{t}=Avg2\left(\left\{f\_ts_{i}^{t}\mid i\in\left\{1,2,\ldots,N\right\}\right\}\right) \tag{17}$$
$$\left\{c^{t}\right\}=\left\{classifier\left(f\_avg^{t}\right)\mid t\in\left\{1,2,\ldots,M\right\}\right\} \tag{18}$$
$$\left\{\hat{y}^{t}\right\}=\left\{argmax\left(c^{t}\right)\right\} \tag{19}$$
$$L=\frac{1}{M}\sum_{t=1}^{M}CELoss\left(\left\{\hat{y}^{t}\right\},\left\{y^{t}\right\}\right) \tag{20}$$
$$\theta_{Tiny\_Seq}\leftarrow\theta_{Tiny\_Seq}-\eta\cdot\frac{\partial L}{\partial\theta_{Tiny\_Seq}} \tag{21}$$

where $Tiny\_Seq$ denotes the new sequence learning structure introduced in this section, and $Avg1$ and $Avg2$ are average-pooling operations ($Avg2$ averages over the $N$ sample blocks of a frame). $f\_ts_{i}^{t}$ is the feature vector of sample block $x_{i}^{t}$ after processing by $Tiny\_Seq$, and $f\_avg^{t}$ is the feature vector of the whole sleep frame. The function $argmax$ returns the class label with the highest score. $\left\{\hat{y}^{t}\right\}$ is the set of predicted labels and $\left\{y^{t}\right\}$ the set of true labels of the $M$ sleep frames. $CELoss$ is the cross-entropy loss, $L$ is its mean over the frames, and $\eta$ is the learning rate. Because $argmax$ is not differentiable, the cross-entropy in Eq. (20) should be read as computed from the class scores $c^{t}$ against $y^{t}$. The parameters $\theta_{Tiny\_Seq}$ of $Tiny\_Seq$ are updated iteratively by gradient descent (Eq. (21)), while the weights of the representation learning component remain frozen.
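A minimal sketch of Eqs. (17)-(20) for a sequence of $M$ frames follows. It assumes the $Tiny\_Seq$ outputs $f\_ts_{i}^{t}$ are already available as plain lists and that the classifier is a single linear layer (an assumption; the classifier is not specified here). Because argmax is not differentiable, the loss is computed as softmax cross-entropy on the class scores $c^{t}$, which is the usual way such a loss is implemented; all names and numbers are hypothetical.

```python
import math

def frame_logits(f_ts_blocks, weight, bias):
    """Eqs. (17)-(18): average the N block features (Avg2), then apply a linear classifier."""
    n = len(f_ts_blocks)
    dim = len(f_ts_blocks[0])
    f_avg = [sum(block[d] for block in f_ts_blocks) / n for d in range(dim)]
    return [sum(w_d * f_d for w_d, f_d in zip(row, f_avg)) + b for row, b in zip(weight, bias)]

def argmax(values):
    """Eq. (19): index of the largest class score."""
    return max(range(len(values)), key=lambda k: values[k])

def cross_entropy(logits, label):
    """Softmax cross-entropy for one frame, computed stably via log-sum-exp."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum - logits[label]

def mean_loss(batch_blocks, labels, weight, bias):
    """Eq. (20): mean cross-entropy over the M frames of a sequence."""
    losses = [cross_entropy(frame_logits(blocks, weight, bias), y)
              for blocks, y in zip(batch_blocks, labels)]
    return sum(losses) / len(losses)
```

In a framework with automatic differentiation, only the $Tiny\_Seq$ parameters would be handed to the optimizer, which is what Eq. (21) expresses; the frozen $Rep$ weights never receive updates.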

Table 5 compares the accuracy of DetectSleepNet and DetectSleepNet-tiny. Although accuracy drops slightly on the Physio2018 and SHHS datasets, DetectSleepNet-tiny still retains 99.5% and 99.3% of DetectSleepNet's accuracy, respectively. This shows that DetectSleepNet-tiny remains effective for sleep staging despite its much smaller number of parameters. In the table, the values in parentheses give accuracy relative to DetectSleepNet; for example, on Physio2018 DetectSleepNet-tiny reaches 80.5% versus 80.9% for DetectSleepNet.

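Table 4: Number of model parameters
\begin{tabular}{lc}
\toprule
Method & Parameters \\
\midrule
DetectSleepNet-tiny & 0.049M \\
DetectSleepNet & 0.43M \\
SalientSleepNet [18] & 0.9M \\
U-time [7] & 1.1M \\
TinySleepNet [5] & 1.3M \\
SleepEEGNet [19] & 2.1M \\
DeepSleepNet [4] & 21M \\
\botrule
\end{tabular}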

To further assess how lightweight DetectSleepNet-tiny is, Table 4 compares its parameter count with those of other recently published models. DetectSleepNet-tiny has only 0.049M parameters, far fewer than any of the compared models. This indicates that DetectSleepNet-tiny improves efficiency while keeping accuracy high, which makes it well suited to resource-limited scenarios and mobile deployment.
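The size ratios behind this comparison can be checked directly from Table 4. The short script below divides DetectSleepNet-tiny's parameter count by that of each other model; the numbers are copied from the table and nothing else is assumed.

```python
# Parameter counts from Table 4, in millions.
params_m = {
    "DetectSleepNet": 0.43,
    "SalientSleepNet": 0.9,
    "U-time": 1.1,
    "TinySleepNet": 1.3,
    "SleepEEGNet": 2.1,
    "DeepSleepNet": 21.0,
}
tiny = 0.049

# Size of DetectSleepNet-tiny as a percentage of each other model.
ratios = {name: 100 * tiny / count for name, count in params_m.items()}
for name, pct in ratios.items():
    print(f"{name}: {pct:.1f}%")
```

For example, DetectSleepNet-tiny needs about 11.4% of the parameters of DetectSleepNet and about 5.4% of those of SalientSleepNet.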

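Table 5: Accuracy (%) of DetectSleepNet-tiny and DetectSleepNet (accuracy relative to DetectSleepNet in parentheses)
\begin{tabular}{llc}
\toprule
Dataset & Method & Accuracy \\
\midrule
Physio2018 & DetectSleepNet-tiny & 80.5 (99.5\%) \\
 & DetectSleepNet & 80.9 (100\%) \\
\midrule
SHHS & DetectSleepNet-tiny & 87.4 (99.3\%) \\
 & DetectSleepNet & 88.0 (100\%) \\
\botrule
\end{tabular}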
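The relative figures in parentheses follow from dividing the two accuracies on each dataset, as the snippet below reproduces from the values in Table 5.

```python
# Accuracy (%) from Table 5: (DetectSleepNet-tiny, DetectSleepNet).
accuracy = {"Physio2018": (80.5, 80.9), "SHHS": (87.4, 88.0)}

# Tiny-model accuracy as a percentage of the full model's accuracy.
relative = {ds: 100 * tiny / full for ds, (tiny, full) in accuracy.items()}
for ds, pct in relative.items():
    print(f"{ds}: {pct:.1f}% of DetectSleepNet accuracy")
```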

These results show that DetectSleepNet-tiny achieves a lightweight design while maintaining strong performance, which highlights its potential for sleep staging applications. By simplifying the sequence learning component, DetectSleepNet-tiny reduces the computational load, making the model better suited to resource-constrained environments such as mobile devices or remote medical monitoring systems. Overall, DetectSleepNet-tiny offers a practical balance between performance and resource efficiency for future sleep monitoring technologies.

5 Interpretability Analysis

The interpretability of neural networks is receiving growing attention in deep learning, especially in medicine, where transparent and verifiable decisions are essential. Although researchers have developed various techniques to explain the decisions of convolutional neural networks in image classification, the application of these methods to sleep staging remains underexplored [20, 21].

In this section, we aim to improve the interpretability of the DetectSleepNet model introduced in Section 2. We retain its representation learning component and build different reasoning heads on top of it to strengthen the model's explanatory capability. Each reasoning head is designed to expose a clear decision logic, making the model's prediction process more transparent.

5.1 Voting-Based Decision Model

In the vote-based decision model, each sleep frame $X^{t}$ is divided into $N$ sample blocks $x_{i}^{t}$. Each sample block $x_{i}^{t}$ passes independently through the representation learning module, which extracts its feature vector $f\_rep_{i}^{t}$. These feature vectors are fed separately into the classifier, giving an independent classification for each time segment. The class of the sleep frame is then decided by voting, implemented as global average pooling over these classification results. Figure 8 illustrates the forward process.

Figure 8: The forward process of the voting based decision model

The detailed computation of this model is as follows:

$$\left\{x_{i}^{t}\mid i\in\left\{1,2,\ldots,N\right\}\right\}=split\left(X^{t}\right) \tag{22}$$
$$f\_rep_{i}^{t}=Rep\left(x_{i}^{t}\right) \tag{23}$$
$$c_{i}^{t}=classifier\left(f\_rep_{i}^{t}\right) \tag{24}$$
$$c^{t}=\frac{1}{N}\sum_{i=1}^{N}c_{i}^{t} \tag{25}$$
$$pred^{t}=argmax\left(c^{t}\right) \tag{26}$$

where $f\_rep_{i}^{t}$ is the feature vector extracted from sample block $x_{i}^{t}$ by the representation learning module $Rep$, and $c_{i}^{t}$ is the classification vector (i.e., the classification result) that the classifier produces from $f\_rep_{i}^{t}$ for the $i$-th time segment of the sleep frame. The class-score vector of the sleep frame, $c^{t}$, is obtained by averaging (i.e., voting over) the classification results $c_{i}^{t}$ of the $N$ time segments. Applying the index-of-maximum function $argmax$ to $c^{t}$ gives $pred^{t}$, a scalar that is the final classification result of the frame.
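Eqs. (22)-(26) can be sketched as follows. The sketch assumes that $split$ cuts the frame into equal, non-overlapping blocks and passes $Rep$ and the classifier in as plain functions; these are illustrative assumptions, and the function names are hypothetical.

```python
def split(frame, n):
    """Eq. (22), assumed form: cut one sleep frame into n equal, non-overlapping blocks."""
    size = len(frame) // n
    return [frame[i * size:(i + 1) * size] for i in range(n)]

def vote(blocks, rep, classifier):
    """Voting-based head, Eqs. (23)-(26).

    blocks:     list of N sample blocks x_i^t from one sleep frame
    rep:        stand-in for the frozen representation module Rep
    classifier: maps one feature vector to a class-score vector
    """
    block_scores = [classifier(rep(x)) for x in blocks]            # c_i^t, Eqs. (23)-(24)
    n, num_classes = len(block_scores), len(block_scores[0])
    frame_scores = [sum(s[k] for s in block_scores) / n            # c^t, Eq. (25)
                    for k in range(num_classes)]
    pred = max(range(num_classes), key=lambda k: frame_scores[k])  # pred^t, Eq. (26)
    return block_scores, frame_scores, pred
```

For instance, with identity functions for `rep` and `classifier` and the toy frame `[3, 0, 0, 0, 1, 0, 0, 0, 1, 2, 0, 0]` split into 4 blocks, the frame scores are `[1.25, 0.25, 0.25]` and the frame is assigned class 0, even though the second block on its own favours class 1. This per-block disagreement is what the decision vector in Eq. (27) exposes.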

This process makes the model's decisions transparent, so that the specific contribution of each sample block to the final decision can be analyzed and verified. Because $Rep$ processes each sample block on its own, the feature vector $f\_rep_{i}^{t}$ of a block does not depend on the other blocks, which makes it possible to trace the model's decision block by block across the entire sleep frame. The decision vector is defined as follows:

$$Att_{x}^{t}=Att_{c}^{t}=\left\{c_{i}^{t}\left[pred^{t}\right]\mid i\in\left\{1,2,\ldots,N\right\}\right\} \tag{27}$$

where $c_{i}^{t}\left[pred^{t}\right]$ is the score that block $i$ assigns to the predicted class $pred^{t}$. It need not be the largest entry of $c_{i}^{t}$, because $pred^{t}$ is the argmax of the averaged vector $c^{t}$ rather than of each individual $c_{i}^{t}$. $Att_{c}^{t}$ is the decision vector of the class-vector group $\left\{c_{i}^{t}\mid i\in\left\{1,2,\ldots,N\right\}\right\}$, and $Att_{x}^{t}$ is the corresponding decision vector of the sample set $\left\{x_{i}^{t}\mid i\in\left\{1,2,\ldots,N\right\}\right\}$. Likewise, replacing $pred^{t}$ with another class index shows how the model would support classifying the sleep frame into that category.
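A sketch of Eq. (27) follows: given the per-block class vectors $c_{i}^{t}$ (made-up numbers here), it extracts one score per block for a chosen class, and choosing a class other than $pred^{t}$ gives the alternative view described above. The function name `decision_vector` is hypothetical.

```python
def decision_vector(block_scores, target_class):
    """Eq. (27): the score each block assigns to target_class (pred^t in the text)."""
    return [scores[target_class] for scores in block_scores]

# Toy class vectors c_i^t for N = 4 blocks and 3 classes (made-up numbers).
block_scores = [[3.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [2.0, 0.0, 0.0]]
num_classes = len(block_scores[0])

# c^t and pred^t as in Eqs. (25)-(26).
frame_scores = [sum(s[k] for s in block_scores) / len(block_scores) for k in range(num_classes)]
pred = max(range(num_classes), key=lambda k: frame_scores[k])

att_pred = decision_vector(block_scores, pred)  # support for the predicted class
att_alt = decision_vector(block_scores, 1)      # support for an alternative class
```

Here `pred` is 0 and `att_pred` is `[3.0, 0.0, 0.0, 2.0]`, so blocks 1 and 4 drive the decision. The second block's own maximum (1.0, for class 1) differs from its entry for the predicted class (0.0), which is why $c_{i}^{t}[pred^{t}]$ is not in general the maximum of $c_{i}^{t}$.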

In this way, researchers can change the class being inspected to uncover alternative decision paths. Analyzing which sample blocks and features the model relies on when assigning a sleep frame to each category reveals how the model reaches its decisions and which factors matter in different situations. Such analysis is essential for understanding the model's decision behavior and how it perceives and distinguishes the different categories.

To display the decision vector $Att_{x}^{t}$ as a heatmap, it is transformed as follows:

$$Att_{show}^{t}=interp\left(relu\left(\frac{Att_{x}^{t}}{\max\left(Att_{x}^{t}\right)}\right)\right) \tag{28}$$

where $\max$ returns the largest element of the decision vector $Att_{x}^{t}$ (a one-dimensional vector), and $relu$ sets the negative entries of a one-dimensional vector to zero. Provided that $\max\left(Att_{x}^{t}\right)>0$, the expression $relu\left(\frac{Att_{x}^{t}}{\max\left(Att_{x}^{t}\right)}\right)$ maps the decision vector $Att_{x}^{t}$ into the interval $[0,1]$. $interp$ is an interpolation function that resizes the vector to increase the resolution of the heatmap, and $Att_{show}^{t}$ is the final one-dimensional decision vector plotted as a heatmap.
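Eq. (28) can be sketched as below. The interpolation method is not specified, so linear interpolation to a hypothetical output length (for example, one value per EEG sample of the frame) is assumed here; the function name `att_show` is also hypothetical.

```python
def att_show(att_x, out_len):
    """Eq. (28): divide by the maximum, clip negatives (relu), then linearly interpolate."""
    peak = max(att_x)
    if peak <= 0:
        raise ValueError("Eq. (28) assumes max(Att_x) > 0")
    normed = [max(0.0, a / peak) for a in att_x]  # relu(Att_x / max(Att_x)), now in [0, 1]
    if out_len == 1 or len(normed) == 1:
        return [normed[0]] * out_len
    out = []
    for j in range(out_len):
        # Position of output sample j on the scale of the input indices 0..N-1.
        pos = j * (len(normed) - 1) / (out_len - 1)
        left = int(pos)
        right = min(left + 1, len(normed) - 1)
        frac = pos - left
        out.append(normed[left] * (1 - frac) + normed[right] * frac)
    return out
```

For example, `att_show([2.0, -1.0, 4.0], 5)` first normalizes to `[0.5, 0.0, 1.0]` (the negative entry is clipped) and then upsamples to `[0.5, 0.25, 0.0, 0.5, 1.0]`. The explicit check reflects that dividing by a non-positive maximum would not yield values in $[0,1]$.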

5.2 Sleep Staging Decision Model Based on Feature Vectors

Building on Section 5.1, this section develops a feature vector-based decision model that addresses a limitation of the voting (global average pooling) structure of the vote-based model. In the vote-based model, global average pooling is applied to short per-block vectors whose length equals that of the class vector, so little information is retained when the decision is made. The feature vector-based decision model instead uses a simple architecture that integrates the full feature vectors of all sample blocks, producing a richer and more comprehensive global feature representation.

5.2.1 Calculating Decision Vectors Based on Forward Propagation

The feature vector-based decision model divides each sleep frame $X^{t}$ into $N$ sample blocks $x_{i}^{t}$. The representation learning module transforms each sample block into its feature vector $f\_rep_{i}^{t}$. Unlike the previous model, the sample blocks are not classified individually; instead, all feature vectors are merged into a global feature vector by global average pooling, and a fully connected layer produces the final classification result. Figure 9 depicts the process.

Figure 9: The forward process

The detailed forward process of the model architecture is as follows:

$$f_{global}^{t}=\frac{1}{N}\sum_{i=1}^{N}f\_rep_{i}^{t} \tag{29}$$
$$c^{t}=classifier\left(f_{global}^{t}\right) \tag{30}$$
$$pred^{t}=argmax\left(c^{t}\right) \tag{31}$$

where the feature vectors $f\_rep_{i}^{t}$ of all sample blocks $x_{i}^{t}$ are averaged to obtain $f_{global}^{t}$, the global feature vector of this sleep frame. The class vector $c^{t}$ is obtained by passing $f_{global}^{t}$ through the single-layer fully connected classifier $classifier$, and the final classification result $pred^{t}$ is the index of the maximum value in $c^{t}$. The fully connected classifier $classifier$ computes:

$$classifier\left(f_{global}^{t}\right)=f_{global}^{t}\,\mathbf{W}_{\mathbf{fc}}^{\mathrm{T}}+\mathbf{b}_{\mathbf{fc}} \tag{32}$$

where $\mathbf{W}_{\mathbf{fc}}$ and $\mathbf{b}_{\mathbf{fc}}$ are the weight and bias parameters of the classifier's fully connected layer. If $L$ denotes the length of $f_{global}^{t}$ and $C$ the length of $c^{t}$, then $\mathbf{W}_{\mathbf{fc}}$ is a two-dimensional matrix of size $C\times L$ and $\mathbf{b}_{\mathbf{fc}}$ is a vector of length $C$.
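Eqs. (29)-(32) amount to global average pooling followed by one affine layer. The sketch below writes them out with plain lists, taking the block features $f\_rep_{i}^{t}$ as given; $\mathbf{W}_{\mathbf{fc}}$ is stored as $C$ rows of length $L$, matching the $C\times L$ shape above. The function name and the toy values are hypothetical.

```python
def feature_head(block_features, w_fc, b_fc):
    """Eqs. (29)-(32): average block features, apply f W^T + b, take argmax."""
    n, length = len(block_features), len(block_features[0])
    f_global = [sum(f[d] for f in block_features) / n for d in range(length)]  # Eq. (29)
    # Row k of W_fc (length L) dotted with f_global, plus b_fc[k], is the score of class k.
    c = [sum(w * x for w, x in zip(row, f_global)) + b                        # Eqs. (30), (32)
         for row, b in zip(w_fc, b_fc)]
    pred = max(range(len(c)), key=lambda k: c[k])                             # Eq. (31)
    return f_global, c, pred
```

With two blocks `[[1.0, 2.0], [3.0, 0.0]]`, $L = 2$, $C = 3$, `w_fc = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]` and `b_fc = [0.0, 0.5, -0.5]`, the global feature is `[2.0, 1.0]`, the class vector is `[2.0, 1.5, 2.5]`, and the prediction is class 2.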

\begin{align}
Att_{f\_global}^{t} &= f_{global}^{t}\cdot\mathbf{W}_{\mathbf{fc}}[pred^{t}]^{\mathrm{T}} + \mathbf{b}_{\mathbf{fc}}[pred^{t}] \tag{33}\\
&\sim f_{global}^{t}\cdot\mathbf{W}_{\mathbf{fc}}[pred^{t}]^{\mathrm{T}} \tag{34}\\
Att_{x}^{t} = Att_{f}^{t} &= \frac{1}{N}\sum_{i=1}^{N}\left\{ f_{i}^{t}/f_{global}^{t}\cdot Att_{f\_global}^{t}\right\} \tag{35}\\
&\sim \frac{1}{N}\sum_{i=1}^{N}\left\{ f_{i}^{t}\cdot\mathbf{W}_{\mathbf{fc}}[pred^{t}]^{\mathrm{T}}\right\} \tag{36}
\end{align}

where $Att_{f\_global}^{t}$ is the decision vector of the global feature vector $f_{global}^{t}$ for the classification result $pred^{t}$; it has length $L$, so $\cdot$ here denotes element-wise multiplication. $\mathbf{W}_{\mathbf{fc}}[pred^{t}]$ is the $pred^{t}$-th row of the weight matrix of the fully connected layer, a one-dimensional vector of length $L$, and $\mathbf{b}_{\mathbf{fc}}[pred^{t}]$ is the $pred^{t}$-th element of the bias vector, a scalar that is constant during model inference.
$Att_{f}^{t}$ is the decision vector of the feature vector group $\{f_{i}^{t} \mid i\in\{1,2,\ldots,N\}\}$ (the block feature vectors $f\_rep_{i}^{t}$), and $Att_{x}^{t}$ is the decision vector of the sample set $\{x_{i}^{t} \mid i\in\{1,2,\ldots,N\}\}$.
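The reduction from Eq. (35) to Eq. (36), and the reason the per-block terms decompose the global decision vector exactly, can be written out as follows. This is a sketch assuming that $\cdot$ and $/$ act element-wise, that $f_{i}^{t}$ is the block feature vector $f\_rep_{i}^{t}$, and that every component of $f_{global}^{t}$ is non-zero so the division is defined; the shorthand $\mathbf{w}=\mathbf{W}_{\mathbf{fc}}[pred^{t}]^{\mathrm{T}}$ and $b=\mathbf{b}_{\mathbf{fc}}[pred^{t}]$ is introduced here only for brevity.

```latex
\begin{align*}
\frac{f_{i}^{t}}{f_{global}^{t}}\cdot Att_{f\_global}^{t}
  &= \frac{f_{i}^{t}}{f_{global}^{t}}\cdot\left(f_{global}^{t}\cdot\mathbf{w} + b\right)
   = f_{i}^{t}\cdot\mathbf{w} + b\,\frac{f_{i}^{t}}{f_{global}^{t}}
   \;\sim\; f_{i}^{t}\cdot\mathbf{w},\\
\frac{1}{N}\sum_{i=1}^{N} f_{i}^{t}\cdot\mathbf{w}
  &= \Big(\frac{1}{N}\sum_{i=1}^{N} f_{i}^{t}\Big)\cdot\mathbf{w}
   = f_{global}^{t}\cdot\mathbf{w}.
\end{align*}
```

With the bias dropped, the global decision vector of Eq. (34) is therefore exactly the mean of the per-block terms $f_{i}^{t}\cdot\mathbf{w}$, and each term measures how much block $i$ contributes to the decision for class $pred^{t}$.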

In Eqs. (34) and (36), the bias term $\mathbf{b}_{\mathbf{fc}}[pred^{t}]$ is omitted for simplicity. The decision vector for the feature vector of each sample block can be computed in this way and then visualized to show the model's decision basis and focus at different sleep stages. Heatmaps are generated with the method described previously to display the contribution of each sample block to the final decision.
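The forward-propagation decision vectors of Eqs. (32)-(36) can be sketched in plain Python as below, assuming average pooling followed by a single linear layer as described above and treating $\cdot$ as element-wise multiplication. The function names and toy numbers are hypothetical, and reducing each block's length-$L$ term to one score by averaging over $L$ (mirroring the $Avg_L$ step of the backpropagation variant in Section 5.2.2) is an illustrative assumption rather than the authors' code.

```python
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def fc_classifier(f: Vector, W: Matrix, b: Vector) -> Vector:
    """Eq. (32): c = f W^T + b, where W is C x L and b has length C."""
    return [sum(fj * wj for fj, wj in zip(f, row)) + bc for row, bc in zip(W, b)]


def forward_decision_vectors(
    f_rep: Matrix, W: Matrix, b: Vector
) -> Tuple[int, Vector, Matrix, Vector]:
    """Decision vectors of an average-pooling + linear model via forward propagation."""
    n, length = len(f_rep), len(f_rep[0])
    # f_global: element-wise mean of the N block feature vectors f_rep_i.
    f_global = [sum(f[j] for f in f_rep) / n for j in range(length)]
    c = fc_classifier(f_global, W, b)
    pred = max(range(len(c)), key=lambda k: c[k])  # argmax of c
    w_pred = W[pred]  # pred-th row of W_fc, length L
    # Eq. (34): element-wise product with the bias omitted -> length-L vector.
    att_global = [fg * w for fg, w in zip(f_global, w_pred)]
    # Eq. (36) summands: one length-L term f_rep_i * W_fc[pred] per block.
    att_blocks = [[fj * w for fj, w in zip(f, w_pred)] for f in f_rep]
    # Assumption: one score per block = mean of its term over the L dimension.
    att_x = [sum(row) / length for row in att_blocks]
    return pred, att_global, att_blocks, att_x


if __name__ == "__main__":
    # Hypothetical toy values: N = 3 blocks, L = 2 features, C = 2 classes.
    f_rep = [[1.0, 0.0], [0.0, 2.0], [3.0, 1.0]]
    W = [[1.0, -1.0], [0.5, 0.5]]
    b = [0.0, 0.1]
    pred, att_global, att_blocks, att_x = forward_decision_vectors(f_rep, W, b)
    print(pred, att_x)  # 1 [0.25, 0.5, 1.0]
```

Averaging `att_blocks` over the blocks reproduces `att_global`, which is the decomposition stated by Eqs. (35) and (36); the third block, with the largest features along the winning class's weights, receives the highest score.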

5.2.2 Calculating Decision Vectors Based on Back Propagation

For complex neural network models, such as those with multi-layer fully connected structures or recurrent architectures, decision vectors often cannot be derived directly through forward propagation. Such models are highly non-linear and have complex parameter structures, which makes their decision process difficult to decompose.

Neural networks are typically trained with the backpropagation algorithm, which computes the gradient of the loss function with respect to the model parameters and uses it to adjust those parameters so that the loss is minimized. Backpropagation can compute gradients not only with respect to the model parameters but also with respect to the input or intermediate features; these gradients quantify how strongly each feature influences the model's prediction and therefore offer insight into its decision process. To apply this idea, the model is abstracted into the following general form:

\begin{align}
f\_rep_{i}^{t} &= encoder(x_{i}^{t}) \tag{37}\\
c^{t} &= decoder\left(\{f\_rep_{i}^{t} \mid i\in\{1,2,\ldots,N\}\}\right) \tag{38}\\
pred^{t} &= argmax(c^{t}) \tag{39}
\end{align}

where each sample block $x_{i}^{t}$ is processed by the feature extractor $encoder$ to obtain the corresponding feature vector $f\_rep_{i}^{t}$. The feature vectors of all sample blocks $\{x_{i}^{t} \mid i\in\{1,2,\ldots,N\}\}$ in a single sleep frame are then processed by the classifier $decoder$ to obtain the classification vector $c^{t}$ of that sleep frame, and the final classification result for the sleep frame is $pred^{t}$. The decision vector of the model is derived according to the following paradigm:

\begin{align}
Att_{f\_set}^{t} &= -\frac{\partial c^{t}[pred]}{\partial f\_set^{t}} \tag{40}\\
Att_{x}^{t} &= Avg_{L}\left(f\_set^{t}\cdot Att_{f\_set}^{t}\right) \notag
\end{align}

where $f\_set^{t}=\{f\_rep_{i}^{t} \mid i\in\{1,2,\ldots,N\}\}$ denotes the feature vectors of all sample blocks. If $L$ is the length of each feature vector (equal to the length of $f_{global}^{t}$), then $f\_set^{t}$ can also be viewed as a feature matrix of size $N\times L$. $Att_{f\_set}^{t}$ is obtained from the partial derivative of $c^{t}[pred]$ with respect to $f\_set^{t}$; it represents the sensitivity of the classifier $decoder$ to $f\_set^{t}$ and has the same size as $f\_set^{t}$.
$Avg_{L}$ averages the decision matrix $f\_set^{t}\cdot Att_{f\_set}^{t}$ along the $L$ dimension, yielding the decision vector $Att_{x}^{t}$ of length $N$.
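The backpropagation paradigm of Eqs. (37)-(40) can be sketched without an autograd library by estimating $\partial c^{t}[pred]/\partial f\_set^{t}$ with central finite differences; in practice, a framework's automatic differentiation would compute this gradient exactly. The decoder `mean_pool_linear`, the helper `grad_of_score`, and the `sign` parameter are hypothetical names for illustration; any callable that maps an $N\times L$ feature matrix to a class-score vector can be passed as the decoder.

```python
from typing import Callable, List, Tuple

Vector = List[float]
Matrix = List[List[float]]
Decoder = Callable[[Matrix], Vector]


def mean_pool_linear(W: Matrix, b: Vector) -> Decoder:
    """Hypothetical decoder: average pooling over blocks, then c = g W^T + b."""
    def decoder(f_set: Matrix) -> Vector:
        g = [sum(col) / len(f_set) for col in zip(*f_set)]
        return [sum(x * w for x, w in zip(g, row)) + bc for row, bc in zip(W, b)]
    return decoder


def grad_of_score(score: Callable[[Matrix], float], f_set: Matrix, eps: float = 1e-5) -> Matrix:
    """Central finite differences; stands in for autograd in this sketch."""
    grad = [[0.0] * len(row) for row in f_set]
    work = [row[:] for row in f_set]  # copy so the caller's data is untouched
    for i, row in enumerate(work):
        for j, value in enumerate(row):
            row[j] = value + eps
            up = score(work)
            row[j] = value - eps
            down = score(work)
            row[j] = value  # restore before moving to the next entry
            grad[i][j] = (up - down) / (2 * eps)
    return grad


def backprop_decision_vector(
    decoder: Decoder, f_set: Matrix, sign: float = -1.0
) -> Tuple[int, Matrix, Vector]:
    c = decoder(f_set)
    pred = max(range(len(c)), key=lambda k: c[k])  # Eq. (39)
    grad = grad_of_score(lambda fs: decoder(fs)[pred], f_set)
    # Eq. (40) as printed carries a leading minus sign; pass sign=1.0 for the plain gradient.
    att_f_set = [[sign * g for g in row] for row in grad]
    # Avg_L of the element-wise product f_set * Att_f_set -> one score per block (length N).
    att_x = [
        sum(f * a for f, a in zip(f_row, a_row)) / len(f_row)
        for f_row, a_row in zip(f_set, att_f_set)
    ]
    return pred, att_f_set, att_x
```

With this hypothetical linear decoder, the gradient of $c^{t}[pred]$ with respect to every block is $\mathbf{W}_{\mathbf{fc}}[pred^{t}]/N$, so with `sign=1.0` each block's score is the $L$-average of the Eq. (36) term $f\_rep_{i}^{t}\cdot\mathbf{W}_{\mathbf{fc}}[pred^{t}]$ divided by $N$. For a recurrent or multi-layer decoder, no such closed form exists, which is where the gradient route becomes necessary.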

5.3 Sleep Staging Decision Model Based on Temporal Information

Sleep staging requires capturing the temporal information of adjacent sleep frames. For sequential data such as time series, the decision models described above cannot capture the temporal correlations and dynamic changes between adjacent sleep frames. This section therefore introduces a temporally correlated decision model designed specifically for the characteristics of sequential data.

The temporally correlated decision model uses the temporal information in sequential data to infer and predict events or states within the sequence more accurately. Unlike plain feature aggregation followed by a global decision, a temporal model accounts for the order of the data, which lets it adapt better to the dynamics and trends of time series. The model is illustrated in Figure 10.

Figure 10: Decision models with temporal correlation

The forward propagation process is as follows:

\begin{align}
f_{global}^{t} &= Seq\_local\left(\{f\_rep_{i}^{t} \mid i\in\{1,2,\ldots,N\}\}\right) \tag{41}\\
f_{seq}^{t} &= Seq\_global\left(\overrightarrow{f_{seq}^{(t-1)}},\, f_{global}^{t},\, \overleftarrow{f_{seq}^{(t+1)}}\right) \tag{42}\\
c^{t} &= classifier(f_{seq}^{t}) \tag{43}\\
pred^{t} &= argmax(c^{t}) \tag{44}
\end{align}

where the local temporal learning structure $Seq\_local$ first integrates the feature vectors $\{f\_rep_{i}^{t} \mid i\in\{1,2,\ldots,N\}\}$ of the sample blocks within a single sleep frame $X^{t}$ into a global feature vector $f_{global}^{t}$ that contains temporal information. The global temporal learning structure $Seq\_global$ then processes $f_{global}^{t}$ together with the forward sequence feature of the previous sleep frame $\overrightarrow{f_{seq}^{(t-1)}}$ and the backward sequence feature of the next sleep frame $\overleftarrow{f_{seq}^{(t+1)}}$, producing a sequence feature $f_{seq}^{t}$ that incorporates temporal dependencies in both directions.
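The data flow of Eqs. (41) and (42) can be sketched as below. The internals of $Seq\_local$ and $Seq\_global$ are not specified here, so the element-wise tanh recurrent cell `rnn_step`, its fixed weight `a`, keeping the last state in `seq_local`, and concatenating the two directions in `seq_global` are all hypothetical choices; a real model would more likely use learned recurrent layers such as a GRU or LSTM. What the sketch does reflect is the dependency structure: the forward state of frame $t$ depends on frame $t-1$, and the backward state depends on frame $t+1$.

```python
import math
from typing import List

Vector = List[float]


def rnn_step(h_prev: Vector, x: Vector, a: float = 0.5) -> Vector:
    """Hypothetical element-wise recurrent cell: h_t = tanh(a * h_{t-1} + x_t)."""
    return [math.tanh(a * h + xi) for h, xi in zip(h_prev, x)]


def seq_local(f_rep: List[Vector]) -> Vector:
    """Stand-in for Seq_local (Eq. 41): scan the N blocks of one frame, keep the last state."""
    h = [0.0] * len(f_rep[0])
    for f in f_rep:
        h = rnn_step(h, f)
    return h


def seq_global(f_globals: List[Vector]) -> List[Vector]:
    """Stand-in for Seq_global (Eq. 42): bidirectional scan over the T frames."""
    T, L = len(f_globals), len(f_globals[0])
    fwd: List[Vector] = [[] for _ in range(T)]
    bwd: List[Vector] = [[] for _ in range(T)]
    h = [0.0] * L
    for t in range(T):  # forward state at t depends on frames 0..t
        h = rnn_step(h, f_globals[t])
        fwd[t] = h
    h = [0.0] * L
    for t in reversed(range(T)):  # backward state at t depends on frames t..T-1
        h = rnn_step(h, f_globals[t])
        bwd[t] = h
    # f_seq^t concatenates both directions, so it sees past and future context.
    return [fwd[t] + bwd[t] for t in range(T)]
```

Because of the backward direction, changing a later frame alters $f_{seq}^{t}$ of earlier frames, which is how adjacent-frame context enters the classifier of Eq. (43).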

5.4 Analysis of Interpretability Model Results

This section analyzes and compares the three models introduced above: the vote-based decision model, the feature-vector-based decision model, and the temporally correlated decision model. All three were evaluated on the first fold of a five-fold cross-validation split of the Physio2018 dataset, and their accuracies are listed in Table 6.

Table 6: Accuracy of the decision models on the first fold of Physio2018
Method | Accuracy (%)
Decision model based on voting | 73.2
Decision model based on feature vectors | 73.2
Decision model related to time series | 81.5

In the experiment, the decision models based on voting and on feature vectors achieved lower accuracy (73.2%) than the temporally correlated model (81.5%), mainly because neither can capture local or global temporal features. Although their structures differ, both are similarly limited in how they aggregate the input features during forward propagation, which results in identical accuracy.

Figure 11 shows the decision vectors extracted from part of the Physio2018 dataset, namely those of the 6th, 8th, 70th, and 76th sleep frames of recording tr03-0005. These frames were chosen because they show typical waveform characteristics, which makes them easier to examine.

The sleep stages are labeled W (wakefulness), N1 (light sleep), N2 (intermediate sleep), and N3 (deep sleep). Single-channel EEG from the central lead C4-M1 is used as the input channel. Heatmaps a, b, c, and d correspond to different models: the vote-based model, the feature-vector-based forward propagation model, the feature-vector-based backpropagation model, and the temporally correlated decision model. Darker regions in the heatmaps mark the features the models focus on, which can be compared with the sleep staging criteria of the AASM scoring manual. As Figure 11 shows, the results of forward propagation (heatmap b) and backpropagation (heatmap c) are notably similar, indicating that backpropagation can reliably explain the model's decision-making process.

In Figure 11(a), decision vectors a, b, c, and d all highlight the alpha rhythm. This agrees with the way sleep experts identify stage W according to the AASM scoring manual.

In Figure 11(b), decision vectors a, b, c, and d all point to the low-amplitude mixed-frequency (LAMF) waves, although decision vector d covers only part of the LAMF waves, indicating a narrower focus.

(a) Comparison of interpretable methods for stage W
(b) Comparison of interpretable methods for stage N1
(c) Comparison of interpretable methods for stage N2
(d) Comparison of interpretable methods for stage N3
Figure 11: The analysis of single-channel raw EEG signals

In Figures 11(c) and 11(d), all models identify the slow waves. However, the temporally correlated model covers only part of the slow-wave regions, possibly because it relies too heavily on temporal information.

Although it achieves higher accuracy, the temporally correlated decision model incorporates a large amount of temporal information from adjacent frames, which reduces the consistency of its decisions with the AASM scoring manual. Conversely, the non-temporal decision models align more closely with the AASM scoring manual.

6 Conclusion

Sleep is a crucial physiological phenomenon for humans, and the quality of sleep affects various aspects of daily life. Good sleep quality can promote metabolism, maintain cardiovascular health, and improve attention and reaction capabilities. Sleep monitoring is an important means of analyzing sleep quality, with sleep stage classification (sleep staging) being a key component. Therefore, an efficient automatic sleep staging algorithm is of great significance for health monitoring.

This paper proposes a neural network architecture based on the prior knowledge of sleep experts and analyzes both its lightweight design and its interpretability in depth. In addition, an interpretable sleep staging system has been developed to achieve low-cost, efficient automatic sleep stage classification and to describe the sleep cycle, with the aim of making sleep monitoring more widespread and efficient in both home and clinical environments. Using deep learning, this work not only raises the level of automation in sleep staging but also makes the system's decisions interpretable, helping medical professionals understand and trust the model's decision-making process. The main contributions of this paper are:

1. High-performance sleep staging model: The DetectSleepNet model developed in this study uses single-channel EEG signals without additional processing and achieves competitive accuracy on large public sleep staging datasets (80.9% on SHHS and 88.0% on Physio2018). This shows that the proposed method simplifies the processing pipeline while maintaining high accuracy, providing a robust tool for sleep research and clinical applications.

2. Lightweight model DetectSleepNet-tiny: To meet the requirements of home and mobile devices, this study introduces a lightweight model with only about 6% of the parameters of existing models. It substantially reduces computational requirements while retaining high accuracy.

3. Decision visualization: An inference head that evaluates how much attention the model gives to specific EEG segments within each sleep frame makes the results more interpretable and provides medical professionals with an intuitive decision-support tool.

Through these contributions, we hope to extend automatic sleep staging from clinical laboratories to homes and mobile devices, serving public health more broadly and supporting the development of sleep science.

Acknowledgements

We extend our sincerest gratitude to Associate Professor Sun Guobing of Heilongjiang University for the invaluable guidance provided.

Declarations

References

  • Berry et al. [2012] Berry, R.B., Brooks, R., Gamaldo, C.E., Harding, S.M., Marcus, C., Vaughn, B.V., et al.: The AASM manual for the scoring of sleep and associated events: Rules, terminology and technical specifications. American Academy of Sleep Medicine, Darien, Illinois (2012)
  • Wolpert [1969] Wolpert, E.A.: A manual of standardized terminology, techniques and scoring system for sleep stages of human subjects. Archives of General Psychiatry 20(2), 246–247 (1969)
  • Iber [2007] Iber, C.: The AASM manual for the scoring of sleep and associated events: Rules, terminology, and technical specification. American Academy of Sleep Medicine (2007)
  • Supratak et al. [2017] Supratak, A., Dong, H., Wu, C., Guo, Y.: DeepSleepNet: A model for automatic sleep stage scoring based on raw single-channel EEG. IEEE Transactions on Neural Systems and Rehabilitation Engineering 25(11), 1998–2008 (2017)
  • Supratak and Guo [2020] Supratak, A., Guo, Y.: TinySleepNet: An efficient deep learning model for sleep stage scoring based on raw single-channel EEG. In: 2020 42nd Annual International Conference of the IEEE Engineering in Medicine & Biology Society (EMBC), pp. 641–644 (2020). IEEE
  • Seo et al. [2020] Seo, H., Back, S., Lee, S., Park, D., Kim, T., Lee, K.: Intra- and inter-epoch temporal context network (IITNet) using sub-epoch features for automatic sleep scoring on raw single-channel EEG. Biomedical Signal Processing and Control 61, 102037 (2020)
  • Perslev et al. [2019] Perslev, M., Jensen, M., Darkner, S., Jennum, P.J., Igel, C.: U-Time: A fully convolutional network for time series segmentation applied to sleep staging. Advances in Neural Information Processing Systems 32 (2019)
  • Phan et al. [2019] Phan, H., Andreotti, F., Cooray, N., Chén, O.Y., De Vos, M.: SeqSleepNet: End-to-end hierarchical recurrent neural network for sequence-to-sequence automatic sleep staging. IEEE Transactions on Neural Systems and Rehabilitation Engineering 27(3), 400–410 (2019)
  • Liu et al. [2021] Liu, F., Huang, X., Chen, Y., Suykens, J.A.: Random features for kernel approximation: A survey on algorithms, theory, and beyond. IEEE Transactions on Pattern Analysis and Machine Intelligence 44(10), 7128–7148 (2021)
  • Phan et al. [2022] Phan, H., Mikkelsen, K., Chén, O.Y., Koch, P., Mertins, A., De Vos, M.: SleepTransformer: Automatic sleep staging with interpretability and uncertainty quantification. IEEE Transactions on Biomedical Engineering 69(8), 2456–2467 (2022)
  • Lee et al. [2024] Lee, S., Yu, Y., Back, S., Seo, H., Lee, K.: SleePyCo: Automatic sleep scoring with feature pyramid and contrastive learning. Expert Systems with Applications 240, 122551 (2024)
  • Zhang et al. [2018] Zhang, G.-Q., Cui, L., Mueller, R., Tao, S., Kim, M., Rueschman, M., Mariani, S., Mobley, D., Redline, S.: The national sleep research resource: towards a sleep data commons. Journal of the American Medical Informatics Association 25(10), 1351–1358 (2018)
  • Quan et al. [1997] Quan, S.F., Howard, B.V., Iber, C., Kiley, J.P., Nieto, F.J., O’Connor, G.T., Rapoport, D.M., Redline, S., Robbins, J., Samet, J.M., et al.: The sleep heart health study: design, rationale, and methods. Sleep 20(12), 1077–1085 (1997)
  • Ghassemi et al. [2018] Ghassemi, M.M., Moody, B.E., Lehman, L.-W.H., Song, C., Li, Q., Sun, H., Mark, R.G., Westover, M.B., Clifford, G.D.: You snooze, you win: the PhysioNet/Computing in Cardiology Challenge 2018. In: 2018 Computing in Cardiology Conference (CinC), vol. 45, pp. 1–4 (2018). IEEE
  • Kingma and Ba [2014] Kingma, D.P., Ba, J.: Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980 (2014)
  • Loshchilov and Hutter [2018] Loshchilov, I., Hutter, F.: Fixing weight decay regularization in adam (2018)
  • Smith [2017] Smith, L.N.: Cyclical learning rates for training neural networks. In: 2017 IEEE Winter Conference on Applications of Computer Vision (WACV), pp. 464–472 (2017). IEEE
  • Jia et al. [2021] Jia, Z., Lin, Y., Wang, J., Wang, X., Xie, P., Zhang, Y.: SalientSleepNet: Multimodal salient wave detection network for sleep staging. arXiv preprint arXiv:2105.13864 (2021)
  • Mousavi et al. [2019] Mousavi, S., Afghah, F., Acharya, U.R.: SleepEEGNet: Automated sleep stage scoring with sequence to sequence deep learning approach. PLoS ONE 14(5), e0216456 (2019)
  • Selvaraju et al. [2017] Selvaraju, R.R., Cogswell, M., Das, A., Vedantam, R., Parikh, D., Batra, D.: Grad-CAM: Visual explanations from deep networks via gradient-based localization. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 618–626 (2017)
  • Zhou et al. [2016] Zhou, B., Khosla, A., Lapedriza, A., Oliva, A., Torralba, A.: Learning deep features for discriminative localization. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2921–2929 (2016)