
License: arXiv.org perpetual non-exclusive license
arXiv:2403.10581v2 [q-bio.QM] 22 Mar 2024

Large Language Model-informed ECG Dual Attention Network for Heart Failure Risk Prediction

Chen Chen, Lei Li, Marcel Beetz, Abhirup Banerjee, Ramneek Gupta, Vicente Grau

Manuscript received March 22, 2024. This work is supported by the Novo Nordisk collaborative research fund. Chen Chen, Lei Li, Marcel Beetz, Abhirup Banerjee, and Vicente Grau are with the Institute of Biomedical Engineering, Department of Engineering Science, University of Oxford, Oxford, United Kingdom. Chen Chen is also an Honorary Research Associate at Imperial College London and a Lecturer in Computer Vision at the University of Sheffield; this work was mainly done during her time at Oxford. Ramneek Gupta is with the Novo Nordisk Research Centre Oxford (NNRCO). (Corresponding author: Chen Chen, email: [email protected]) A. Banerjee is a Royal Society University Research Fellow and is supported by the Royal Society Grant No. URF\R1\221314. The work of A. Banerjee and V. Grau was partially supported by the British Heart Foundation Project under Grant PG/20/21/35082 and by the CompBioMed 2 Centre of Excellence in Computational Biomedicine (European Commission Horizon 2020 research and innovation programme, grant agreement No. 823712).
Abstract

Heart failure (HF) poses a significant public health challenge, with a rising global mortality rate. Early detection and prevention of HF could significantly reduce its impact. We introduce a novel methodology for predicting HF risk using 12-lead electrocardiograms (ECGs). Specifically, we present a lightweight dual-attention ECG network designed to capture complex ECG features essential for early HF risk prediction, despite the notable imbalance between low- and high-risk groups. This network incorporates a cross-lead attention module and twelve lead-specific temporal attention modules, which focus on cross-lead interactions and on each lead's local dynamics, respectively. To further alleviate model overfitting, we leverage a large language model (LLM) with a public ECG-Report dataset for pretraining on an ECG-report alignment task. The network is then fine-tuned for HF risk prediction using two specific cohorts from the UK Biobank study, focusing on patients with hypertension (UKB-HYP) and those who have had a myocardial infarction (UKB-MI). The results reveal that LLM-informed pretraining substantially enhances HF risk prediction in these cohorts. The dual-attention design improves not only interpretability but also predictive accuracy, outperforming existing competitive methods with C-index scores of 0.6349 for UKB-HYP and 0.5805 for UKB-MI. This demonstrates our method's potential for advancing HF risk assessment with complex clinical ECG data.

Index Terms:
Large language model, multi-modal learning, heart failure, risk prediction, interpretable artificial intelligence, electrocardiogram.

I Introduction

Heart failure (HF) is a complex cardiovascular syndrome in which the heart fails to pump sufficient blood to meet the body's demands. Common causes of HF are cardiac structural and/or functional abnormalities, including heart attack, cardiomyopathy, and high blood pressure. HF is a chronic and progressive disease. In England, admissions due to HF have escalated notably, from 65,025 in 2013/14 to 86,474 in 2018/19, a 33% increase, as reported by the British Heart Foundation [1]. Around 50% of deaths in HF patients have been found to occur suddenly and unexpectedly [2, 3], which places a tremendous burden on patients with HF, their families, and healthcare systems worldwide.

Preventing HF early is crucial to reduce its health and economic impacts, yet HF is often diagnosed late, when patients have already developed serious symptoms [4, 5]. A promising strategy for improving HF management is the development of risk prediction models for future HF events. These models generate a risk score for each patient over a specific timeframe, taking the patient's specific characteristics into account. With such a personalized risk assessment, more tailored HF management strategies and/or treatment recommendations can be provided. In this context, the low-cost 12-lead electrocardiogram (ECG), a test commonly used in clinical practice, serves as a valuable resource for evaluating a patient's cardiovascular health and uncovering HF risk. Recent studies have found that several markers detected from clinically acquired 12-lead ECGs are associated with future HF events, such as prolonged QRS duration [6, 7, 8, 9, 10] and conduction disorders (left and right bundle-branch blocks) [11, 12]. A significant limitation of this previous research lies in its reliance on a limited set of biomarkers (e.g., QRS duration), which are identified through predetermined rules applied to ECG data. Moreover, much of this research has employed simple linear models to model the risk ratio associated with HF. While linear models offer a straightforward and interpretable framework for risk assessment, they may fail to capture the subtle, complex patterns embedded within the ECG that matter for early-stage HF risk prediction.

In recent years, deep learning-based approaches with neural networks have shown great capacity to automatically extract features from raw data and use them to perform a wide range of tasks. In ECG analysis, deep neural networks have shown competitive performance on tasks including disease classification, waveform prediction, rhythm detection, mortality prediction, and automated report generation, achieving higher accuracy than traditional approaches based on handcrafted features [13, 14, 15]. Yet, two prevalent limitations of deep learning-based methods are overfitting and poor interpretability. Overfitting occurs when a model memorizes the training data instead of learning task-relevant features, reducing its ability to perform well on new, unseen data. The large number of parameters and layers in deep neural networks can further exacerbate the risk of overfitting, especially when the available training data are limited. Additionally, the complex architecture of these networks makes it difficult to trace and understand the decision-making process behind their predictions, which hampers interpretability.

In this work, we aim to develop a deep learning-based HF risk prediction model with improved feature learning, higher data efficiency, and better explainability. To this end, we design a novel, lightweight ECG dual attention network. This network captures intricate cross-lead interactions as well as local temporal dynamics within each lead. The dual attention mechanism also enables the visualization of lead-wise attention maps and of temporal activation across each lead, improving the explainability of the network's behavior. To further alleviate overfitting, improve data efficiency, and, more importantly, learn clinically useful representations from ECG data for higher precision, we adopt a two-stage training scheme. We first employ a large language model (LLM) to pretrain the network on a large public ECG-Report dataset covering a wide spectrum of diseases, and then finetune the network for the HF risk prediction task on two cohorts with specific risk factors from the UK Biobank study. Specifically, we employ ClinicalBERT [16] to extract text embeddings from ECG reports and, at the pretraining stage, force the extracted ECG features to align with the corresponding text features. We hypothesize that such an ECG-text feature alignment paradigm helps the deep neural network capture clinically useful patterns, which provide a more holistic picture of patients and potentially yield more accurate risk predictions. To the best of our knowledge, this is the first work that applies LLM-informed pretraining to benefit the training of a downstream ECG-based HF risk prediction model.

To summarize, our contribution is two-fold:

  1. A novel deep neural network architecture that enhances both the representation learning of ECG features and the model's interpretability: The proposed network not only yields a quantitative risk score but also offers qualitative, interpretable insights into the network's reasoning through a dual attention mechanism. This feature acts as a transparent medium, enabling clinicians and readers to observe and understand the intricate relationships between different ECG leads as well as the dynamic temporal patterns within each lead, highlighting the leads and time segments that are most important for reliable risk prediction.

  2. An effective model pretraining strategy with LLMs: We design an LLM-informed multi-modal pretraining task so that clinical knowledge can be transferred to the downstream risk prediction task. Training a deep risk prediction network is challenging due to a lack of sufficient event data. In this work, we advocate the strategic use of large language models, coupled with structured ECG reports with confidence scores, to guide the pretraining process toward more data-efficient and accurate risk prediction models.

II Related work

II-A Risk prediction

Risk prediction aims to estimate the hazard $h(t)$, i.e., the instantaneous rate at which a patient experiences a certain event at time $t$, given that the event has not occurred before $t$. Denoting the event time by $D$, $h(t)=\lim_{\Delta t\rightarrow 0}\frac{p(D\in[t,t+\Delta t)\mid D\geq t)}{\Delta t}$. A common approach is the Cox proportional hazards (CoxPH) formulation [17]. The CoxPH model is a statistical model for time-to-event outcomes which assumes that the features or variables associated with the subject act multiplicatively on a shared baseline hazard, so that the hazard ratio between any two subjects is constant over time. Mathematically, it can be defined as $h(t|x)=h_{0}(t)\exp(g_{\theta}(x))$, where $h(t|x)$ models the hazard rate at which events occur at time $t$ given the subject information $x$, $h_{0}(t)$ is the baseline hazard function shared by all observations, which depends only on the time $t$, and $\exp(g_{\theta}(x))$ is the risk score for observation $x$.
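To make the proportional-hazards property concrete, the following minimal Python sketch evaluates $h(t|x)=h_{0}(t)\exp(g_{\theta}(x))$ for two patients. The Weibull-shaped baseline hazard and the log-risk values `g_a` and `g_b` are hypothetical choices for illustration only; the point is that the hazard ratio between the two patients equals $\exp(g_a-g_b)$ at every $t$, whatever $h_0(t)$ is.

```python
import math

def cox_hazard(t, g_x, baseline_hazard):
    """CoxPH hazard h(t | x) = h0(t) * exp(g(x))."""
    return baseline_hazard(t) * math.exp(g_x)

def weibull_baseline(t, k=1.5, lam=1.0):
    """Hypothetical baseline hazard h0(t) (Weibull shape k, scale lam); increases with t for k > 1."""
    return (k / lam) * (t / lam) ** (k - 1)

g_a, g_b = 0.8, 0.1  # hypothetical log-risk scores g(x) for two patients

for t in (1.0, 5.0, 20.0):
    ratio = cox_hazard(t, g_a, weibull_baseline) / cox_hazard(t, g_b, weibull_baseline)
    # h0(t) cancels, so the ratio is exp(g_a - g_b) at every t
    print(f"t = {t:>4}: hazard ratio = {ratio:.4f}")
```

Because the baseline cancels in the ratio, the model only needs to learn a relative risk score $g_\theta(x)$ to rank patients.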
Conventional CoxPH models assume that the input observation is a set of covariates (e.g., sex, age, medical history, smoking) and constrain the function $g(\cdot)$ to the linear form $g_{\theta}(x)=\theta^{\intercal}x=\theta_{1}x_{1}+\theta_{2}x_{2}+\cdots+\theta_{p}x_{p}$, with $\theta$ being the weights of the $p$ covariates in $x=\{x_{1},x_{2},\dots,x_{p}\}$ [17]. In the existing HF risk prediction literature, the predominant approach is to construct a linear model on a predefined set of variables and then retain the variables that are strongly correlated with the outcome for risk prediction. Variables used in this way include prolonged QRS duration [6, 7, 8, 9, 10], conduction disorders (specifically left and right bundle-branch blocks) [11, 12], prolonged QT intervals [7], abnormalities in QRS/T angles and T-wave patterns [18, 19, 20, 21, 7], and ST-segment depression in the V5 precordial lead [7, 22].
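For the conventional linear model, $\theta$ is typically estimated by maximizing the Cox partial likelihood. The NumPy sketch below computes the linear log-risk $g_{\theta}(x)=\theta^{\intercal}x$ and the negative log partial likelihood averaged over subjects with an observed event. It assumes there are no tied event times, the example data are hypothetical, and it illustrates standard CoxPH fitting rather than the exact training loss of this work.

```python
import numpy as np

def linear_log_risk(X, theta):
    """g_theta(x) = theta^T x for each row of X (n_subjects x p covariates)."""
    return X @ theta

def cox_neg_log_partial_likelihood(log_risk, time, event):
    """Negative log Cox partial likelihood, averaged over subjects with an observed event.

    Assumes no tied event times. The risk set of subject i contains every
    subject j with time[j] >= time[i].
    """
    log_risk = np.asarray(log_risk, dtype=float)
    time = np.asarray(time, dtype=float)
    event = np.asarray(event).astype(bool)
    at_risk = time[None, :] >= time[:, None]               # at_risk[i, j]: j still at risk at t_i
    masked = np.where(at_risk, log_risk[None, :], -np.inf)
    m = masked.max(axis=1)                                 # shift for a stable log-sum-exp
    log_denom = m + np.log(np.exp(masked - m[:, None]).sum(axis=1))
    return -np.mean((log_risk - log_denom)[event])

# Hypothetical example: 3 subjects, p = 2 covariates, follow-up times in months.
X = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
theta = np.array([0.5, -0.2])
time = np.array([12.0, 30.0, 24.0])
event = np.array([1, 0, 1])
loss = cox_neg_log_partial_likelihood(linear_log_risk(X, theta), time, event)
print(f"negative log partial likelihood: {loss:.4f}")
```

Censored subjects (event 0) contribute no term of their own but still appear in the risk sets of earlier events, which is how the partial likelihood uses right-censored data.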

More recently, deep learning approaches such as DeepSurv [23] replace the linear function $g(\cdot)$ with a neural network $f$, a deep architecture parameterized by the network weights $\theta$. This relaxes the risk model to a non-linear form; moreover, the input can be high-dimensional and need not consist of linearly independent variables, as the neural network can extract hierarchical features for risk score estimation in a non-linear fashion. Such approaches have been found to outperform traditional approaches [23] on various risk and survival prediction tasks. For example, researchers have successfully applied neural networks to automatically discover latent features from high-dimensional data for risk/survival prediction, such as whole-slide pathology images [24, 25] and 4D cardiac shape plus motion information [26]. In this work, we follow a similar philosophy and build neural networks that autonomously extract features from complex 12-lead ECG waveforms, bypassing the reliance on a predefined set of ECG parameters.
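As a minimal sketch of this idea, the NumPy code below replaces the linear $g_\theta$ with a small fully connected network that maps a high-dimensional input to one scalar log-risk per subject. The layer sizes, ReLU activations, random initialization, and the helper names `init_mlp` and `mlp_log_risk` are illustrative assumptions, not DeepSurv's exact configuration; in practice such a network is trained by minimizing the negative Cox partial log-likelihood of its outputs.

```python
import numpy as np

def init_mlp(sizes, rng):
    """Hypothetical helper: random (W, b) pairs for layers sizes[0] -> ... -> sizes[-1]."""
    return [(rng.normal(0.0, 1.0 / np.sqrt(m), (m, n)), np.zeros(n))
            for m, n in zip(sizes[:-1], sizes[1:])]

def mlp_log_risk(X, params):
    """Non-linear g_theta(x): ReLU hidden layers, linear scalar output per row of X."""
    h = X
    for W, b in params[:-1]:
        h = np.maximum(0.0, h @ W + b)
    W, b = params[-1]
    return (h @ W + b).ravel()

rng = np.random.default_rng(0)
params = init_mlp([64, 32, 1], rng)   # 64-dim hypothetical input features
X = rng.normal(size=(5, 64))          # 5 hypothetical subjects
r = mlp_log_risk(X, params)           # one log-risk score per subject, shape (5,)
```

With a single layer (`init_mlp([64, 1], rng)`) the function reduces to the linear CoxPH form, which makes the relationship between the two models explicit.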

II-B Large language model for healthcare

Large language models (LLMs) are a class of artificial intelligence (AI) algorithms designed to understand and generate human language. They can answer questions, provide summaries or translations, and create stories or poems [27]. Recent studies have found that LLMs can effectively guide representation learning on image data [28], enabling knowledge transfer to several downstream vision tasks such as image classification, object detection, and segmentation. In the medical domain, this kind of multi-modal pretraining has been exploited for a better understanding of imaging data, such as chest X-rays and magnetic resonance images, to benefit downstream disease classification and medical image segmentation tasks [29, 30, 31, 32, 33]. Beyond image data, concurrent works have very recently explored the connection between natural language and physiological signals (ECG, EEG) for better disease classification [34, 35, 36]. To the best of our knowledge, no existing risk prediction work explores the benefit of combining ECG reports with ECG waveforms for better representation learning. Compared to disease classification, risk prediction is more challenging due to the scarcity of event records for training, which amplifies the overfitting issue associated with deep neural networks. Our work provides a promising transfer learning approach that alleviates the need for a large amount of event data by using LLMs and additional large public ECG-report datasets for multi-modal pretraining.

Figure 1: (a) Overview of the ECG dual attention encoder-based risk prediction network (ECG-DAN). A 12-lead ECG recording $\mathbf{x}$ is sent to the ECG dual attention encoder, which simultaneously extracts cross-lead relationships and temporal dynamic patterns within each lead for better feature aggregation. The features from the two routes are then added and sent to a max-average pooling layer, producing a flattened feature vector $\mathbf{z}_{ecg}$. Finally, we employ a multi-layer perceptron (MLP) module to map from the high-dimensional feature space to a scalar risk score $r$ for heart failure. (b) Overview of the core attention module used in the lead attention and temporal attention modules. See Sec. III-B for more details.

III Methods

III-A Overview

Assume we have a dataset of $N$ triplets $\{\mathbf{x}_{i},\delta_{i},t_{i}\}_{i=1}^{N}$ recording the HF events in a population. Here, $\mathbf{x}_{i}\in\mathbb{R}^{12\times T}$ is a 12-lead ECG signal (I, II, III, AVL, AVR, AVF, V1-6) with a recording length of $T$; $\delta_{i}$ indicates whether an HF event with a known date was observed; and $t_{i}$ is either the number of months to the censoring time if no HF event was reported during the follow-up period ($\delta_{i}=0$, right-censored) or the number of months until the patient was diagnosed with HF during the follow-up ($\delta_{i}=1$, uncensored). The objective is to learn a risk prediction model $f$ parameterized by $\theta$ that predicts a patient's HF risk from the patient's ECG data. To this end, we design an ECG dual attention network (ECG-DAN), as shown in Fig. 1(a), where the input is a 12-lead ECG median wave recording $\mathbf{x}$ and the output of the network is a single node $r$, which estimates the risk score $r=f(\mathbf{x};\theta)$.
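The abstract reports performance as a C-index. Assuming Harrell's variant (this section does not specify which one is used), the concordance index can be computed directly from predicted risks $r_i$ and the $(\delta_i, t_i)$ pairs defined above, as in this plain-Python sketch: a pair $(i, j)$ is comparable when $\delta_i=1$ and $t_i<t_j$, and concordant when the model assigns $i$ the higher risk. A value of 0.5 corresponds to random ranking. The example data are hypothetical.

```python
def harrell_c_index(risk, time, event):
    """Harrell's concordance index for right-censored data.

    A pair (i, j) is comparable when subject i had an observed event
    (event[i] == 1) and time[i] < time[j]; it is concordant when
    risk[i] > risk[j]. Tied risks count as half-concordant.
    """
    concordant, comparable = 0.0, 0
    for i in range(len(risk)):
        if not event[i]:
            continue
        for j in range(len(risk)):
            if time[i] < time[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0
                elif risk[i] == risk[j]:
                    concordant += 0.5
    return concordant / comparable

# Hypothetical example: predicted risks r_i, follow-up t_i in months, event indicators delta_i.
risk = [0.9, 0.1, 0.5]
time = [12, 30, 24]
event = [1, 0, 1]
print(harrell_c_index(risk, time, event))             # 1.0: every comparable pair is ranked correctly
print(harrell_c_index([0.3, 0.3, 0.3], time, event))  # 0.5: a constant risk carries no information
```

The quadratic loop keeps the definition transparent; for large cohorts a sorted or library implementation would be preferable.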

III-B ECG-DAN network

The signature component of ECG-DAN is a dual attention module, designed to extract morphological and spatial changes and relationships across different leads as well as the temporal dynamics within each lead, for a more comprehensive understanding of the heart's electrical activity. We first process each lead signal with a group of 1D convolution layers for noise filtering and feature extraction, which yields $C_{in}$-channel features for each lead at each time point. We then employ a set of $K$ residual blocks with $K$ $2\times$ down-sampling layers along the temporal dimension to extract features at different scales. The output for each lead consists of feature maps with $C_{out}$ channels and a reduced time dimension of $\frac{T}{2^{K}}$: $\mathbf{h}_{i}\in\mathbb{R}^{C_{out}\times\frac{T}{2^{K}}}$. The features from all 12 leads, $\mathbf{h}:\{\mathbf{h}_{1},\mathbf{h}_{2},\dots,\mathbf{h}_{12}\}$, are then concatenated and sent to the Lead Attention (LA) block, which learns cross-lead interactions to enhance global feature aggregation.
Concurrently, a 12-lead Temporal Attention (TA) component captures crucial temporal patterns within each lead $\mathbf{h}_{i}$ along the time dimension. Here, for each lead $\mathbf{h}_{i}$, we apply an individual temporal attention module $\text{TA}_{i}$ across the time domain separately, since different leads correspond to different directions of cardiac activation in three-dimensional space. The outputs of the two attention modules are added and then passed to a concat-pooling module, which concatenates flattened features from max-pooling and average-pooling operations, following common practice in previous ECG analysis work [13]:

$$\begin{aligned}\mathbf{h}^{\prime}&=\mathbf{h}^{LA}+\mathbf{h}^{TA}\\ \mathbf{z}_{ecg}&=\textrm{Concat}(\textrm{Maxpooling}(\mathbf{h}^{\prime});\textrm{Avg-pooling}(\mathbf{h}^{\prime})).\end{aligned}\tag{1}$$
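The following NumPy sketch traces the shapes through Eq. (1). It assumes, as one plausible reading, that max- and average-pooling act over the reduced temporal axis of length $T/2^K$ for every lead and channel, and that the pooled results are flattened before concatenation. The sizes `C_out`, `T`, `K` and the random stand-ins for $\mathbf{h}^{LA}$ and $\mathbf{h}^{TA}$ are hypothetical.

```python
import numpy as np

def concat_pool(h_prime):
    """Max- and average-pool over the temporal axis (assumed), flatten, concatenate."""
    max_pooled = h_prime.max(axis=-1).ravel()   # (12 * C_out,)
    avg_pooled = h_prime.mean(axis=-1).ravel()  # (12 * C_out,)
    return np.concatenate([max_pooled, avg_pooled])

C_out, T, K = 16, 600, 3                  # hypothetical sizes, not taken from the paper
l = T // 2 ** K                           # reduced temporal length T / 2^K (= 75 here)
rng = np.random.default_rng(0)
h_la = rng.normal(size=(12, C_out, l))    # stand-in for the lead-attention output
h_ta = rng.normal(size=(12, C_out, l))    # stand-in for the temporal-attention output
h_prime = h_la + h_ta                     # h' = h^LA + h^TA
z_ecg = concat_pool(h_prime)              # z_ecg, shape (2 * 12 * C_out,) = (384,)
```

Adding the two branches requires both attention outputs to share the same (leads, channels, time) layout, which is why each branch restores that shape before the sum.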

The latent feature $\mathbf{z}_{ecg}$ is then sent to a multi-layer perceptron (MLP) consisting of three linear layers, which first reduces the feature dimension by half, then projects the feature into a 3-dimensional space, and finally regresses it to a risk score $r$.
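A minimal NumPy sketch of this three-layer head is shown below. The layer widths (halve, then 3, then 1) follow the text; the ReLU activations between layers, the random initialization, the input size, and the helper names are assumptions for illustration.

```python
import numpy as np

def init_linear(rng, n_in, n_out):
    """Random weights and zero bias for one linear layer (hypothetical initialization)."""
    return rng.normal(0.0, 1.0 / np.sqrt(n_in), (n_in, n_out)), np.zeros(n_out)

def make_risk_head(rng, d):
    """Three linear layers: d -> d // 2 -> 3 -> 1."""
    return [init_linear(rng, d, d // 2), init_linear(rng, d // 2, 3), init_linear(rng, 3, 1)]

def risk_head(z_ecg, layers):
    """Map a 1D feature vector z_ecg to a scalar risk score r (ReLU between layers assumed)."""
    h = z_ecg
    for k, (W, b) in enumerate(layers):
        h = h @ W + b
        if k < len(layers) - 1:
            h = np.maximum(0.0, h)
    return float(h[0])

rng = np.random.default_rng(0)
layers = make_risk_head(rng, d=384)          # e.g. 2 * 12 leads * 16 channels (hypothetical)
r = risk_head(rng.normal(size=384), layers)  # scalar risk score
```

The final layer has no activation, so $r$ is unbounded, matching its role as the log-risk $g_\theta(x)$ inside $\exp(\cdot)$ of the CoxPH model.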

We now describe the two key modules for improved feature learning and model explainability, lead attention and temporal attention, in more detail.

III-B1 Lead attention

We apply the multi-head attention-based encoder module introduced in transformers [37] to capture the cross-lead interactions. Specifically, we first reshape the feature matrix into $\mathbf{h}\in\mathbb{R}^{12\times\frac{C_{out}T}{2^{K}}}$ (12 is the number of leads), where each lead feature is treated as a token. As shown in Fig. 1(b), a sinusoidal positional encoding $\operatorname{PE}$ [37] is added to the input feature, followed by layer normalization [38], $\mathbf{h}\leftarrow\text{LayerNorm}(\mathbf{h}+\operatorname{PE}(\mathbf{h}))$, to encode each token's position; a multi-head self-attention module then captures cross-lead interactions. At a high level, the inputs to the attention are a key matrix $\mathbf{K}\in\mathbb{R}^{12\times\frac{C_{out}T}{2^{K}}}$, a query matrix $\mathbf{Q}\in\mathbb{R}^{12\times\frac{C_{out}T}{2^{K}}}$, and a value matrix $\mathbf{V}\in\mathbb{R}^{12\times\frac{C_{out}T}{2^{K}}}$, where the key and query matrices are used to compute the weights that re-weight the value matrix. Mathematically, this can be expressed as:

$$\mathbf{h}^{LA}=\operatorname{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^{\top}}{\sqrt{D_{k}}}\right)\mathbf{V}=\mathbf{A}^{LA}\mathbf{V}\tag{2}$$

where $\mathbf{A}^{LA}\in\mathbb{R}^{12\times 12}$ is the attention weight matrix and $D_{k}$ is the feature dimension ($\frac{C_{out}T}{2^{K}}$). We use the same input $\mathbf{h}$ for computing $\mathbf{K},\mathbf{Q},\mathbf{V}$.¹

¹In practice, for computational efficiency and following [37], we apply the multi-head attention trick: the queries, keys, and values are linearly projected $h$ times ($h$ being the number of heads) with different learnable projection matrices to a lower dimension ($D_{k}/h$), an attention matrix is computed for each head to re-weight the corresponding projected value matrix, and the outputs of all $h$ heads are concatenated to recover the feature dimension.
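The lead-attention computation can be sketched in NumPy as follows: sinusoidal positional encoding and layer normalization over the 12 lead tokens, then multi-head scaled dot-product self-attention as in Eq. (2) and its footnote. Assumptions: the feature size `D` stands in for $C_{out}T/2^K$ and is chosen small; projections are random rather than learned; each head is scaled by the square root of its own dimension $D_k/h$, as in [37]; and the output projection that [37] applies after concatenating the heads is omitted. Averaging the per-head weights into one $12\times12$ map is shown only as one possible way to visualize $\mathbf{A}^{LA}$.

```python
import numpy as np

def sinusoidal_pe(n_tokens, d):
    """Sinusoidal positional encoding of [37]: sin on even, cos on odd feature indices."""
    pos = np.arange(n_tokens)[:, None]
    i = np.arange(d)[None, :]
    angle = pos / np.power(10000.0, (2 * (i // 2)) / d)
    return np.where(i % 2 == 0, np.sin(angle), np.cos(angle))

def layer_norm(x, eps=1e-5):
    """Normalize each token over its feature axis (learnable scale/shift omitted)."""
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def multi_head_self_attention(h, Wq, Wk, Wv, n_heads):
    """h: (n_tokens, D). Returns re-weighted features (n_tokens, D) and
    attention weights (n_heads, n_tokens, n_tokens)."""
    n, D = h.shape
    d_head = D // n_heads

    def split(x):  # (n, D) -> (n_heads, n, d_head)
        return x.reshape(n, n_heads, d_head).transpose(1, 0, 2)

    Q, K, V = split(h @ Wq), split(h @ Wk), split(h @ Wv)
    A = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(d_head))   # (n_heads, n, n)
    out = (A @ V).transpose(1, 0, 2).reshape(n, D)            # concatenate the heads
    return out, A

rng = np.random.default_rng(0)
n_leads, D, n_heads = 12, 64, 4                  # hypothetical sizes
h = rng.normal(size=(n_leads, D))                # one token per lead
h = layer_norm(h + sinusoidal_pe(n_leads, D))    # h <- LayerNorm(h + PE(h))
Wq, Wk, Wv = (rng.normal(0.0, D ** -0.5, (D, D)) for _ in range(3))
h_la, A = multi_head_self_attention(h, Wq, Wk, Wv, n_heads)
A_lead = A.mean(axis=0)                          # one 12 x 12 map for visualization
```

Each row of an attention matrix sums to one, so row $i$ of `A_lead` can be read as how much lead $i$ draws on every other lead.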

In addition to self-attention, a fully connected feed-forward network (FFN) with a residual connection is applied to refine the re-weighted features:

$$\begin{aligned}\mathbf{h}^{LA}&=\text{LayerNorm}(\mathbf{h}+\mathbf{h}^{LA})\\ \mathbf{h}^{LA}&=\operatorname{FFN}(\mathbf{h}^{LA})=\mathbf{h}^{LA}+\max(0,\mathbf{h}^{LA}W_{1}+b_{1})W_{2}+b_{2}\end{aligned}\tag{3}$$

where $W_{1},W_{2}$ and $b_{1},b_{2}$ are the weights and biases of the two linear layers in the FFN, and LayerNorm denotes layer normalization [38]. In this way, each lead's output feature is adjusted by taking the information from all other leads into account.
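Eq. (3) can be sketched in NumPy as follows. The FFN uses the position-wise form of [37], $\max(0, xW_1+b_1)W_2+b_2$, added back to its input as a residual. The hidden width, the random weights, the random stand-in for the attention output, and the omission of LayerNorm's learnable scale and shift are assumptions.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """Normalize each token over its feature axis (learnable scale/shift omitted)."""
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def ffn_residual(x, W1, b1, W2, b2):
    """x + max(0, x W1 + b1) W2 + b2, applied independently to every token (row)."""
    return x + np.maximum(0.0, x @ W1 + b1) @ W2 + b2

def refine_lead_attention(h, h_attn, W1, b1, W2, b2):
    """Eq. (3): add and normalize the attention output, then refine it with the FFN."""
    return ffn_residual(layer_norm(h + h_attn), W1, b1, W2, b2)

rng = np.random.default_rng(0)
D, hidden = 64, 128                          # hypothetical sizes
h = rng.normal(size=(12, D))                 # lead tokens entering attention
h_attn = rng.normal(size=(12, D))            # stand-in for softmax(QK^T / sqrt(D_k)) V
W1, b1 = rng.normal(0.0, D ** -0.5, (D, hidden)), np.zeros(hidden)
W2, b2 = rng.normal(0.0, hidden ** -0.5, (hidden, D)), np.zeros(D)
h_la = refine_lead_attention(h, h_attn, W1, b1, W2, b2)   # shape (12, D)
```

Because of the residual path, setting the second FFN layer to zero leaves the normalized attention output unchanged, so the FFN only has to learn a correction.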

III-B2 12-lead temporal attention (TA)

The structure of the temporal attention module is very similar to that of the lead attention module. The only difference is that there are now 12 separate attention modules for the different lead features, so that each module is adapted to a specific lead to extract temporal dynamic features locally. Specifically, we split the feature $\mathbf{h}$ along the lead dimension, and for each lead feature $\mathbf{h}_{i}$, we first add sinusoidal positional encoding $\operatorname{PE}$ along the time domain and then process each lead with a separate temporal attention module in which each time point is treated as a token. In other words, the input becomes a sequence of $C_{out}$-dimensional features with a sequence length of $l=\frac{T}{2^{K}}$, and the corresponding temporal attention matrix for lead $i$ is $\mathbf{A}^{TA}_{i}\in\mathbb{R}^{l\times l}$. Similar to the lead attention module, we re-weight and normalize the lead feature $\mathbf{h}_{i}$ with the generated temporal attention matrix and then employ an FFN module for feature refinement for each lead. The output of the temporal attention is the concatenation of the outputs of the 12 lead-specific temporal attention modules. The whole process can be defined as follows:

$$
\begin{aligned}
\mathbf{h}_{i} &\leftarrow \text{LayerNorm}\big(\mathbf{h}_{i} + \operatorname{PE}(\mathbf{h}_{i})\big)\\
\mathbf{h}_{i}^{TA} &= \text{LayerNorm}\big(\mathbf{h}_{i} + \mathbf{A}^{TA}_{i}\mathbf{h}_{i}^{\top}\big)\\
\mathbf{h}_{i}^{TA} &= \operatorname{FFN}_{i}\big(\mathbf{h}_{i}^{TA}\big)\\
\mathbf{h}^{TA} &= \operatorname{Concat}\big(\mathbf{h}_{1}^{TA}, \mathbf{h}_{2}^{TA}, \dots, \mathbf{h}_{12}^{TA}\big)^{\top}.
\end{aligned}
\tag{4}
$$
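The following NumPy code is a minimal sketch of the per-lead temporal attention in Eq. (4). It is not the authors' implementation and rests on several assumptions:

- Each lead feature is stored as an $l \times C_{out}$ token matrix, so the reweighting reads $\mathbf{A}^{TA}_i \mathbf{h}_i$. Whether a transpose appears, as in Eq. (4), depends on the layout convention.
- The attention matrix comes from single-head scaled dot-product attention with hypothetical random query/key projections.
- LayerNorm has no learnable scale or shift.
- Each $\operatorname{FFN}_i$ is a hypothetical two-layer ReLU MLP.

The sizes $T = 1024$, $K = 5$ and $C_{out} = 16$ follow the implementation details.

```python
import numpy as np

def sinusoid_pe(length, dim):
    """Sinusoidal positional encoding of shape (length, dim)."""
    pos = np.arange(length)[:, None]
    i = np.arange(dim)[None, :]
    angle = pos / np.power(10000.0, (2 * (i // 2)) / dim)
    return np.where(i % 2 == 0, np.sin(angle), np.cos(angle))

def layer_norm(x, eps=1e-5):
    """LayerNorm over the feature axis, without learnable scale/shift (simplification)."""
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

class LeadTemporalAttention:
    """Temporal attention for ONE lead; weights are random placeholders (hypothetical)."""

    def __init__(self, dim, rng):
        s = 1.0 / np.sqrt(dim)
        self.wq = rng.normal(0, s, (dim, dim))      # hypothetical query projection
        self.wk = rng.normal(0, s, (dim, dim))      # hypothetical key projection
        self.w1 = rng.normal(0, s, (dim, 2 * dim))  # hypothetical FFN_i weights
        self.w2 = rng.normal(0, s, (2 * dim, dim))

    def __call__(self, h):                          # h: (l, C_out), one token per time point
        l, d = h.shape
        h = layer_norm(h + sinusoid_pe(l, d))       # h_i <- LayerNorm(h_i + PE(h_i))
        q, k = h @ self.wq, h @ self.wk
        A = softmax(q @ k.T / np.sqrt(d))           # A_i^{TA}: (l, l), each row sums to 1
        h_ta = layer_norm(h + A @ h)                # residual reweighting + LayerNorm
        h_ta = np.maximum(h_ta @ self.w1, 0.0) @ self.w2   # h_i^{TA} <- FFN_i(h_i^{TA})
        return h_ta, A

def temporal_attention(h, modules):
    """h: (12, l, C_out). Each lead goes through its own module; outputs are re-stacked."""
    outs, maps = zip(*(m(h[i]) for i, m in enumerate(modules)))
    return np.stack(outs), np.stack(maps)

rng = np.random.default_rng(0)
T, K, C_out = 1024, 5, 16
l = T // 2 ** K                                     # 32 time tokens per lead
modules = [LeadTemporalAttention(C_out, rng) for _ in range(12)]
h = rng.normal(size=(12, l, C_out))
h_ta, A = temporal_attention(h, modules)
print(h_ta.shape, A.shape)                          # (12, 32, 16) (12, 32, 32)
```

Keeping 12 independent modules instead of one shared module lets each lead learn its own temporal weighting. The cost is 12 times as many temporal-attention parameters.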

III-C Training

Figure 2: Training overview. Our model is a) first pretrained on the ECG-Report alignment task and the signal reconstruction task on a large-scale public dataset (PTB-XL [39, 40]), and then b) finetuned on the heart failure risk prediction task with two specific cohorts from the UK Biobank for which future HF event data are available. In the PTB-XL dataset, each report has been abstracted into a set of SCP codes, each with an SCP-ECG statement description and a confidence score (annotated by human experts). We construct a structured report based on the SCP-ECG protocol [39] and send it to a frozen LLM to extract clinical knowledge that guides representation learning. As one ECG may have multiple relevant SCP-code statements, we extract text features for each statement separately and then aggregate them by confidence-weighted summation. See the text for more details.

Deep learning-based risk prediction often struggles with small datasets, particularly when events are rare. In our case, HF events occur in fewer than 5% of subjects. To overcome this and prevent overfitting, we adopt a two-stage training approach (Fig. 2). First, we pretrain our network on a large, diverse public ECG dataset. Of note, this dataset does not contain any HF records for risk prediction. We then initialize the model with the pretrained parameters and fine-tune it for HF risk prediction in specific populations with documented HF outcomes from the UK Biobank study [41]. The pretraining incorporates human-verified ECG reports to align features with clinical knowledge. The aim is to improve the model's ability to discern pathological ECG patterns relevant to HF risk.

III-C1 Large language model informed model pre-training

For pretraining, we employed the PTB-XL dataset [39, 40], a large, expert-verified collection of 21,799 clinical 12-lead ECGs with accompanying text reports. It features annotations by cardiologists according to the SCP-ECG standard and classifies waveforms into five categories: Normal (NORM), Myocardial Infarction (MI), ST/T Change (STTC), Conduction Disturbance (CD), and Hypertrophy. These categories may overlap due to concurrent conditions. We followed the dataset creators' protocol, using folds 1-8 for training and folds 9-10 for validation and testing during pretraining [39].
Extracting the latent text code $\mathbf{z}_{text}$ from reports using a large language model: We employ an LLM to extract the knowledge embedded in ECG reports. Specifically, we use the medical-domain language model BioClinical BERT [42] (https://huggingface.co/emilyalsentzer/Bio_ClinicalBERT), which has been trained on a large number of electronic health records from MIMIC-III [43]. We consider two ways of extracting the text embedding $\mathbf{z}_{text}$:

  • Latent text code from the raw ECG report (raw): Given an ECG report $y$ in English, we simply feed it to the LLM to obtain $\mathbf{z}_{text} = \operatorname{LLM}(y)$, following [35, 32]. The original reports are written in a mixture of German and English. We therefore first used the open-source machine translation tool EasyNMT (https://github.com/UKPLab/EasyNMT) for batch translation following [34]. We then used ChatGPT (https://chat.openai.com/) to refine the failed cases with the prompt: "translate the ECG report into English: #text".

  • Latent text code from the structured ECG report weighted by confidence (structured with confidence): The original PTB-XL dataset also provides SCP-ECG codes derived from the ECG reports. Specifically, each report has been abstracted into a set of SCP codes, each with an SCP-ECG statement description and a confidence score (annotated by human experts). For each SCP code, we build a structured sentence that links its statement category and its SCP description: $\mathbf{y}(SCP)$ = '{#Statement Category(SCP)}: {#SCP-ECG Statement Description(SCP)}'. (The database containing the SCP code-statement mappings can be found at https://physionet.org/content/ptb-xl/1.0.1/scp_statements.csv and https://physionet.org/content/ptb-xl/1.0.1/ptbxl_database.csv.) For example, as shown in Fig. 2(a), given the SCP code LNGQT, the input becomes: 'other ST-T descriptive statements: long QT-interval'. We send this structured input to the LLM, which returns an embedding $\mathbf{z}^{SCP} = \operatorname{LLM}(\mathbf{y}(SCP))$ for each SCP code. Note that a single recording may include multiple SCP codes, in which case multiple SCP embeddings are first derived.
These embeddings are then aggregated with weights given by the corresponding confidence scores $c$ to obtain the text embedding for the ECG: $\mathbf{z}_{text} = \sum_{j} \frac{c_{j}\mathbf{z}^{SCP}_{j}}{c^{sum}}$. Here $c^{sum}$ is the sum of the confidence scores of all statements associated with the input ECG signal.
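The following code is a minimal sketch of the structured-sentence construction and confidence-weighted aggregation described above. It is not the authors' code, and two points are assumptions:

- The toy vectors stand in for the frozen-LLM outputs $\mathbf{z}^{SCP}_j$. How a sentence embedding is pooled from BioClinical BERT is not specified in this section.
- When all confidences are zero, the function falls back to an unweighted mean. This is our own choice, because the weighting formula is undefined in that case.

```python
import numpy as np

def structured_sentence(category, description):
    """y(SCP) = '{statement category}: {SCP-ECG statement description}'."""
    return f"{category}: {description}"

def aggregate_scp_embeddings(embeddings, confidences):
    """z_text = sum_j c_j * z_j / c_sum, a confidence-weighted mean of SCP embeddings.

    If every confidence is 0, fall back to an unweighted mean (our assumption;
    the weighting formula is undefined in that case).
    """
    z = np.asarray(embeddings, dtype=float)    # (num_codes, dim)
    c = np.asarray(confidences, dtype=float)   # (num_codes,)
    c_sum = c.sum()
    if c_sum == 0:
        return z.mean(axis=0)
    return (c[:, None] * z).sum(axis=0) / c_sum

print(structured_sentence("other ST-T descriptive statements", "long QT-interval"))
# other ST-T descriptive statements: long QT-interval

# Toy 4-d vectors standing in for LLM(y(SCP)) outputs of two codes (hypothetical values).
z_scp = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
z_text = aggregate_scp_embeddings(z_scp, confidences=[100.0, 50.0])
print(z_text)   # approximately [0.667, 0.333, 0.0, 0.0]
```

A weighted mean (rather than a plain sum) keeps $\mathbf{z}_{text}$ on the same scale regardless of how many SCP codes a recording has.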

III-C2 ECG-report alignment loss

To align each ECG with its report, similar to [28], we first project the latent ECG code $\mathbf{z}_{ecg}$ and the latent text code $\mathbf{z}_{text}$ to $\mathbf{e}_{ecg}$ and $\mathbf{e}_{text}$. We use two learnable projection functions, $p_{x}$ and $p_{y}$ respectively, so that the two embeddings $\mathbf{e}_{ecg} = p_{x}(\mathbf{z}_{ecg})$ and $\mathbf{e}_{text} = p_{y}(\mathbf{z}_{text})$ have the same dimension, as shown in Fig. 2(a). We use a distance loss $\mathcal{D}$ to quantify the dissimilarity between the two:

$$\mathcal{L}_{align} = \frac{1}{n}\sum_{i=1}^{n}\left[\mathcal{D}\big((\mathbf{e}_{ecg}, \mathbf{e}_{text})_{i}\big)\right], \tag{5}$$

where $\mathcal{D}$ is the cosine embedding distance $\mathcal{D}(\mathbf{e}_{ecg}, \mathbf{e}_{text}) = 1 - \cos(\mathbf{e}_{ecg}, \mathbf{e}_{text})$, a metric commonly used to measure the distance between two embeddings. The loss is averaged over a batch of $n$ paired embeddings.
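The following NumPy code sketches the projection step and the cosine alignment loss of Eq. (5). The linear projection heads $p_x$ and $p_y$, with random weights, are hypothetical stand-ins; the actual projector design is given in the Appendix. The 512- and 128-dimensional sizes follow the implementation details. The 768-dimensional text input is an assumption, chosen because BioClinical BERT is a BERT-base model.

```python
import numpy as np

def cosine_alignment_loss(e_ecg, e_text, eps=1e-8):
    """L_align: mean over the batch of 1 - cos(e_ecg_i, e_text_i)."""
    dot = np.sum(e_ecg * e_text, axis=1)
    norms = np.linalg.norm(e_ecg, axis=1) * np.linalg.norm(e_text, axis=1)
    return float(np.mean(1.0 - dot / np.maximum(norms, eps)))

rng = np.random.default_rng(0)
n, d_ecg, d_text, d_proj = 8, 512, 768, 128
W_x = rng.normal(0.0, d_ecg ** -0.5, (d_ecg, d_proj))    # hypothetical linear p_x
W_y = rng.normal(0.0, d_text ** -0.5, (d_text, d_proj))  # hypothetical linear p_y
z_ecg = rng.normal(size=(n, d_ecg))                      # stand-in encoder outputs
z_text = rng.normal(size=(n, d_text))                    # stand-in frozen-LLM outputs
e_ecg, e_text = z_ecg @ W_x, z_text @ W_y                # both (n, 128)
print(cosine_alignment_loss(e_ecg, e_text))              # close to 1 for unrelated pairs
```

The loss lies in $[0, 2]$: it is 0 for perfectly aligned pairs, about 1 for unrelated pairs, and 2 for opposite directions. It ignores embedding magnitudes.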

III-C3 Pretraining loss

The total loss for pretraining is a combination of the ECG-report alignment loss and a signal reconstruction loss, defined as:

$$
\begin{split}
\mathcal{L}_{pretraining} &= \mathcal{L}_{recon} + \mathcal{L}_{align}\\
\mathcal{L}_{recon} &= \frac{1}{n}\sum_{i=1}^{n}\big(\|\mathbf{x}_{i} - \hat{\mathbf{x}}_{i}\|^{2}\big)
\end{split}
\tag{6}
$$

where we compute the squared error between each input signal $\mathbf{x}_{i}$ in a batch and its reconstruction $\hat{\mathbf{x}}_{i}$, then average over the batch (a mean-squared-error loss). We add the signal reconstruction loss for two reasons: it can help uncover latent generic features in the ECG signal, and it serves as a regularization term that helps avoid latent-space collapse. During pretraining, we update only the parameters of the ECG encoder, the ECG decoder, and the two projectors. The parameters of the language model stay frozen for training stability and efficiency, as suggested by prior work [29]. Detailed network structures can be found in the Appendix.
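Below is a minimal NumPy sketch of the pretraining objective in Eq. (6), with the alignment loss passed in as a precomputed scalar. It assumes the squared norm in $\mathcal{L}_{recon}$ is summed over all leads and time points of each sample. Default framework losses (e.g. PyTorch's `nn.MSELoss`) instead average over all elements, which only rescales $\mathcal{L}_{recon}$ relative to $\mathcal{L}_{align}$. The frozen LLM does not appear here because it only supplies fixed text embeddings.

```python
import numpy as np

def recon_loss(x, x_hat):
    """L_recon: squared L2 error per sample (summed over leads and time), averaged over the batch."""
    diff = (np.asarray(x, float) - np.asarray(x_hat, float)).reshape(len(x), -1)
    return float(np.mean(np.sum(diff ** 2, axis=1)))

def pretraining_loss(x, x_hat, l_align):
    """Eq. (6): unweighted sum of the reconstruction loss and a precomputed alignment loss."""
    return recon_loss(x, x_hat) + l_align

x = np.zeros((2, 12, 1024))               # batch of two padded 12-lead beats
x_hat = np.full((2, 12, 1024), 0.1)       # reconstruction off by 0.1 everywhere
print(recon_loss(x, x_hat))               # about 12 * 1024 * 0.01 = 122.88
print(np.mean((x - x_hat) ** 2))          # element-wise MSE instead: about 0.01
```

The two printed values differ by a factor of $12 \times 1024$. Which normalization is used therefore changes the effective balance between the two terms of Eq. (6).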

III-C4 Finetuning loss

After pretraining, we copy the model weights to initialize the risk prediction model (Fig. 2(b)) and then finetune the risk prediction network. The finetuning loss is also a multi-task loss. It combines the self-supervised signal reconstruction loss with a risk loss [23] that minimizes the average negative log partial likelihood over the set of uncensored patients ($\delta_{i} = 1$: developed HF during follow-up). The risk loss is defined as:

$$\mathcal{L}_{risk} = -\frac{1}{n_{\delta=1}}\sum_{i:\delta_{i}=1}\left[f_{\theta}(\boldsymbol{x}_{i}) - \log\sum_{j:t_{j}\geq t_{i}}\exp\big(f_{\theta}(\boldsymbol{x}_{j})\big)\right] \tag{7}$$

where $n_{\delta=1}$ is the number of uncensored subjects in a batch, $t_{i}$ is the follow-up time of subject $i$, and $f_{\theta}(\boldsymbol{x})$ is the predicted risk score.
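The following NumPy code sketches the risk loss of Eq. (7). Tied follow-up times are handled Breslow-style: every $j$ with $t_j \ge t_i$ is in the risk set, exactly as written in Eq. (7). The log-sum-exp is computed in a numerically stable way. Returning 0 when a batch contains no uncensored subject is our own placeholder, since Eq. (7) is undefined in that case.

```python
import numpy as np

def cox_risk_loss(risk, time, event):
    """Average negative log partial likelihood over uncensored subjects (Eq. 7).

    risk:  (n,) predicted risk scores f_theta(x)
    time:  (n,) follow-up times t
    event: (n,) indicator delta (1 = developed HF, 0 = censored)
    """
    risk, time, event = (np.asarray(a, dtype=float) for a in (risk, time, event))
    uncensored = np.flatnonzero(event == 1)
    if uncensored.size == 0:
        return 0.0                                  # Eq. (7) undefined; placeholder (assumption)
    terms = []
    for i in uncensored:
        r = risk[time >= time[i]]                   # risk set {j : t_j >= t_i}
        log_sum = r.max() + np.log(np.exp(r - r.max()).sum())   # stable log-sum-exp
        terms.append(risk[i] - log_sum)
    return float(-np.mean(terms))

# Two subjects with equal scores; the first has an event at t=1, the second is censored at t=2.
print(cox_risk_loss([0.0, 0.0], [1.0, 2.0], [1, 0]))   # log(2), about 0.693
```

Raising the score of the subject who had the earlier event, relative to those still at risk, lowers the loss. This is the ranking behaviour the C-index later evaluates.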

The total loss is then defined as:

$$\mathcal{L}_{finetuning} = \alpha\mathcal{L}_{recon} + (1-\alpha)\mathcal{L}_{risk} \tag{8}$$

where $\alpha$ is a trade-off parameter that balances the contributions of the two losses.

During model optimization, $\mathcal{L}_{risk}$ can be very sensitive to the number of uncensored subjects $n_{\delta=1}$ in the training batch, i.e., subjects who develop HF during follow-up. Training becomes unstable if this number drops to zero (the loss is then undefined) or swings between large and small values from batch to batch. In real-world datasets, this issue is amplified by the high class imbalance between censored and uncensored subjects. To address this problem, we use a modified version of stochastic gradient descent in which each batch of $n$ observations is drawn by stratified random sampling. Every batch then maintains a ratio of censored to uncensored observations comparable to that of the entire training population.
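The stratified batching described above can be sketched as follows. This is one possible implementation, not necessarily the authors'. Each class is shuffled and split into the same number of nearly equal chunks, so every batch receives roughly the population's share of uncensored subjects. As a result, batch sizes only approximately equal the requested `batch_size`.

```python
import numpy as np

def stratified_batches(event, batch_size, rng):
    """Yield shuffled index batches whose event ratio mirrors the whole training set."""
    event = np.asarray(event)
    pos = rng.permutation(np.flatnonzero(event == 1))   # uncensored (developed HF)
    neg = rng.permutation(np.flatnonzero(event == 0))   # censored
    n_batches = max(1, len(event) // batch_size)
    # Split each class into n_batches near-equal chunks and pair them up.
    for p, q in zip(np.array_split(pos, n_batches), np.array_split(neg, n_batches)):
        yield rng.permutation(np.concatenate([p, q]))

rng = np.random.default_rng(0)
ev = np.array([1] * 40 + [0] * 960)                    # 4% events, 1000 subjects
batches = list(stratified_batches(ev, 128, rng))
print([int(ev[b].sum()) for b in batches])             # [6, 6, 6, 6, 6, 5, 5]
```

With plain random batches of 128 at a 4% event rate, the number of events per batch would vary noticeably and could occasionally be zero. Stratification keeps it nearly constant.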

IV Experiments

IV-A Study population

In this study, we focus on two populations: patients with hypertension (HYP) and patients with a history of myocardial infarction (MI). Both conditions are strongly associated with progression to HF. Subjects are selected from the UK Biobank (UKB), a large-scale biomedical database of approximately 500,000 subjects. It contains genetic, demographic, and disease information and is regularly updated with comprehensive follow-up. The UKB includes a large proportion of healthy subjects as well as subjects with a range of cardiovascular and other diseases. To assess the future risk of HF, we restrict our analysis to individuals who have both electrocardiogram (ECG) and imaging data available and who had not been diagnosed with HF before or at the time of the ECG examination.

IV-A1 UKB-HYP

We studied subjects from the UKB dataset who were HF-free at baseline and had prevalent HYP, i.e., HYP diagnosed before or at the ECG examination [41]. We identified 11,581 such subjects. Follow-up time was defined as the time from the baseline ECG measurement until HF diagnosis, death, or the end of follow-up (January 5, 2023), whichever came first. Most ECG recordings, together with images, were taken between 2014 and 2021. Records with less than two years of follow-up time were excluded. Among the 11,581 participants, 162 (1.4%) developed HF. The median follow-up time is 56 months (4.7 years), and the maximum follow-up time is 87 months (7.3 years).
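To make the follow-up definition concrete, the following code sketches a hypothetical helper that derives the survival label $(t_i, \delta_i)$ for one subject. The function and argument names are illustrative. Two details are our assumptions:

- The two-year exclusion is applied to the potential follow-up window, from the ECG date to January 5, 2023.
- An HF diagnosis on the same day as death counts as an event.

```python
from datetime import date

END_OF_FOLLOW_UP = date(2023, 1, 5)
DAYS_PER_MONTH = 365.25 / 12

def survival_label(ecg_date, hf_date=None, death_date=None, min_years=2.0):
    """Return (follow-up months, event) for one subject, or None if excluded."""
    if hf_date is not None and hf_date <= ecg_date:
        return None                                   # prevalent HF: not HF-free at baseline
    if (END_OF_FOLLOW_UP - ecg_date).days / 365.25 < min_years:
        return None                                   # assumed form of the two-year rule
    candidates = [(END_OF_FOLLOW_UP, 0)]              # administrative censoring
    if death_date is not None:
        candidates.append((death_date, 0))            # death without HF: censored
    if hf_date is not None:
        candidates.append((hf_date, 1))               # incident HF: event
    # Earliest date ends follow-up; on a tie, prefer the HF event (assumption).
    end, event = min(candidates, key=lambda c: (c[0], -c[1]))
    return round((end - ecg_date).days / DAYS_PER_MONTH, 1), event

print(survival_label(date(2015, 1, 1), hf_date=date(2019, 1, 1)))   # (48.0, 1)
print(survival_label(date(2021, 6, 1)))                             # None (< 2 years possible)
```

Death is treated as censoring rather than as a competing event, which matches the single-event Cox loss used for training.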

IV-A2 UKB-MI

We studied subjects who were HF-free at baseline and had prevalent MI records. Using the same selection procedure as above, we identified 800 subjects. Among them, 32 (4%) developed HF during the follow-up period. The median follow-up time is 53 months (4.4 years), and the maximum follow-up time is 83 months (6.9 years).

TABLE I: Statistics of the studied populations.
Cohort | Characteristic | Total | Developed HF during follow-up: Yes | Developed HF during follow-up: No
UKB-HYP | Age at examination# | 66.31 (45-82) | 69.57 (49-81) | 66.27 (45-82)
UKB-HYP | Sex (men/women) | 6730/4851 | 109/53 (1.6%/1.1%) | 6621/4798
UKB-MI | Age at examination# | 68.21 (49-81) | 67.28 (50-79) | 68.25 (49-81)
UKB-MI | Sex (men/women) | 650/150 | 27/5 (4.3%/3.4%) | 623/145
# Values represent mean (minimum-maximum).

Code lists used for the retrieval of disease information can be found in the Appendix.

IV-B Implementation Details

For ECG signals, we used 12-lead median waveforms (covering a single beat) sampled at 500 Hz as input. Each lead is preprocessed with $z$-score normalization and then zero-padded to a length of 1024. We employ 5 residual blocks ($K=5$, $C_{in}=4$, $C_{out}=16$) to obtain multi-scale features. The dimension of the latent ECG feature $\mathbf{z}_{ecg}$ is 512, and the dimension of the projected features $\mathbf{e}_{ecg}$ and $\mathbf{e}_{text}$ is 128. The default convolutional kernel size is 5. Further network details are provided in the Appendix. We used a batch size of 128 ($n=128$) for model updates during pretraining, with random lead masking for data augmentation [44]. The AdamW optimizer [45] with stratified batches was used for training. For fine-tuning in post-MI risk prediction (UKB-MI), we kept the batch size at 128. Given the low HF incidence ($<2\%$) in UKB-HYP, we increased the batch size to 1024 to include enough uncensored subjects for calculating $\mathcal{L}_{risk}$. Batch-wise dropout [24] was used to stabilize fine-tuning and to avoid overfitting. The weighting parameter $\alpha$ of the loss function is set to 0.5. Pretraining and fine-tuning were run for 300 and 100 epochs, respectively, to ensure convergence.
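The input preprocessing above can be sketched as follows. Several details are assumptions:

- Padding is appended at the end of each lead.
- A small epsilon guards against flat leads.
- The 600-sample beat is only an example length.
- The per-lead masking probability `p` is a hypothetical parameter, since the augmentation settings of [44] are not given here.

The last line checks the token arithmetic $l = T/2^{K} = 1024/2^{5} = 32$.

```python
import numpy as np

def preprocess_median_beat(x, target_len=1024, eps=1e-8):
    """Per-lead z-score normalization, then zero-padding at the end to target_len."""
    x = np.asarray(x, dtype=float)                       # (12, T_raw), T_raw <= target_len
    x = (x - x.mean(axis=1, keepdims=True)) / (x.std(axis=1, keepdims=True) + eps)
    out = np.zeros((x.shape[0], target_len))
    out[:, : x.shape[1]] = x
    return out

def random_lead_mask(x, p, rng):
    """Zero each lead independently with probability p (hypothetical augmentation setting)."""
    keep = rng.random(x.shape[0]) >= p
    return x * keep[:, None]

rng = np.random.default_rng(0)
beat = rng.normal(size=(12, 600))       # hypothetical 600-sample beat (1.2 s at 500 Hz)
x = preprocess_median_beat(beat)
x_aug = random_lead_mask(x, p=0.1, rng=rng)
K = 5
print(x.shape, x.shape[1] // 2 ** K)    # (12, 1024) 32 -> l = T / 2^K tokens per lead
```

Normalizing before padding keeps the padded zeros at the per-lead mean, so the padding does not shift the lead statistics.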

IV-C Evaluation metrics and evaluation method

We report the concordance index (C-index) [46] as the primary evaluation metric. This metric measures how accurately the model ranks subjects by time to event, based on the ratio of concordant pairs: $\text{C-index} = \frac{\#\,\text{concordant pairs}}{\#\,\text{concordant pairs} + \#\,\text{discordant pairs}}$. A pair is concordant when the ranking of two subjects by predicted risk agrees with their actual ranking by time to event, and discordant otherwise. Specifically, a pair $(i, j)$ is concordant if $t_{i} < t_{j}$ and the risk scores satisfy $r_{i} > r_{j}$, or vice versa with the roles of $i$ and $j$ swapped. A value of 1 denotes perfect prediction, while a value of 0.5 indicates prediction quality equivalent to random guessing.
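Harrell's C-index can be computed directly from the definition above, as in the sketch below. The sketch makes explicit a detail the formula leaves implicit for censored data: a pair is only comparable when the subject with the shorter time actually had the event. Counting tied risk scores as half-concordant is a common convention but an assumption here; dropping ties recovers the formula above exactly.

Library implementations also exist, e.g. `lifelines.utils.concordance_index`. That function expects scores where higher means longer survival, so risk scores are usually negated before being passed in.

```python
import numpy as np

def harrell_c_index(time, event, risk):
    """Harrell's C-index; higher risk should mean earlier event."""
    time, event, risk = (np.asarray(a) for a in (time, event, risk))
    concordant, comparable = 0.0, 0
    for i in range(len(time)):
        if event[i] != 1:          # only subjects with an observed event can anchor a pair
            continue
        for j in range(len(time)):
            if time[i] < time[j]:  # j outlived i, so (i, j) is comparable
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0
                elif risk[i] == risk[j]:
                    concordant += 0.5   # tie convention (assumption)
    return concordant / comparable if comparable else float("nan")

print(harrell_c_index([1, 2, 3], [1, 1, 0], [3, 2, 1]))   # 1.0: perfect ranking
```

The double loop is $O(N^2)$, which is fine for a few thousand test subjects. Larger cohorts would call for a sorted, vectorized version.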

Robust evaluation with repeated two-fold stratified cross-validation. Uncensored subjects (with HF events) are scarce. Conventional deep learning data splits (e.g., five-fold cross-validation, or a 7:1:2 split for training, validation, and testing) would therefore leave too few HF events for reliable C-index evaluation. Hence, we opted for a 1:1 training-to-testing ratio with cross-validation stratified by patient HF status during follow-up. This constitutes a two-fold stratified cross-validation, which also keeps the proportion of each class in every split the same as in the original dataset. To further avoid overfitting, 20% of the training data was randomly allocated as a validation set for hyperparameter search (e.g., the choice of learning rate) and model selection. Specifically, we split the dataset into two folds, using one for training and validation and the other for testing. This process is then repeated with the roles of the two folds reversed. To ensure that the results are not biased by a particular split, we repeat this procedure five times with different random splits on each of the two datasets (UKB-HYP and UKB-MI). The final model performance is reported as the average (with standard deviation) of these 10 trials.
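The repeated two-fold protocol can be sketched as follows to make the index bookkeeping explicit. This is a hypothetical helper, not the authors' code. Stratifying the 20% validation split by event status is our assumption; the text only says the split is drawn randomly.

```python
import numpy as np

def stratified_halves(event, rng):
    """Split indices into two halves with (nearly) equal event proportions."""
    fold_a, fold_b = [], []
    for cls in (0, 1):
        idx = rng.permutation(np.flatnonzero(event == cls))
        half = len(idx) // 2
        fold_a.append(idx[:half])
        fold_b.append(idx[half:])
    return np.concatenate(fold_a), np.concatenate(fold_b)

def repeated_two_fold(event, n_repeats=5, val_frac=0.2, seed=0):
    """Yield (train, val, test) index arrays: n_repeats x 2 folds trials."""
    event = np.asarray(event)
    rng = np.random.default_rng(seed)
    for _ in range(n_repeats):
        fold_a, fold_b = stratified_halves(event, rng)
        for dev, test in ((fold_a, fold_b), (fold_b, fold_a)):   # swap roles
            train, val = [], []
            for cls in (0, 1):   # stratified validation split (assumption)
                idx = rng.permutation(dev[event[dev] == cls])
                n_val = int(round(val_frac * len(idx)))
                val.append(idx[:n_val])
                train.append(idx[n_val:])
            yield np.concatenate(train), np.concatenate(val), test

event = np.array([1] * 32 + [0] * 768)      # toy labels with UKB-MI-like counts (32/800)
trials = list(repeated_two_fold(event))
print(len(trials), [int(event[te].sum()) for _, _, te in trials])   # 10 trials, 16 events each
```

With UKB-MI-sized counts, every test fold holds 16 HF events. A 7:1:2 split would leave only about 6 events in the test set.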

V Results

V-A LLM-informed pre-training improves the accuracy of the downstream risk prediction

TABLE II: Comparison of risk prediction performance using different pretraining tasks. All experiments were performed using the same proposed network architecture. Reported values are the average C-index (standard deviation) over five repetitions of two-fold cross-validation with different random splits.
Pretraining tasks | Hypertension (UKB-HYP) | Myocardial Infarction (UKB-MI)
- | 0.6122 (0.0190) | 0.5065 (0.0776)
SR | 0.6327 (0.0165) | 0.5069 (0.0770)
SR + Classification | 0.6370 (0.0216) | 0.5220 (0.0475)
SR + ECG-R Alignment (raw) | 0.6088 (0.0189) | 0.5796 (0.0570)
SR + ECG-R Alignment (structured with confidence) | 0.6349 (0.0156) | 0.5805 (0.0580)
SR: Signal Reconstruction; Classification: ECG Disease Classification; ECG-R: ECG-Report.

We first compared our proposed pretraining strategy against various pretraining tasks:

  • No pretraining (represented by a hyphen), serving as a baseline, with networks trained for 400 epochs (equivalent to the sum of pretraining and finetuning epochs);

  • Pretraining on signal reconstruction only;

  • Pretraining on signal reconstruction combined with multi-label disease classification (NORM, MI, STTC, CD, hypertrophy), using a cross-entropy loss for classification;

  • Pretraining on signal reconstruction and ECG-Report alignment using raw text reports;

  • Our proposed method, pretraining on signal reconstruction and ECG-Report alignment with structured report and confidence information, as detailed in Sec. III-C1.

Results are shown in Table II. Models pretrained with the proposed strategy (signal reconstruction + ECG-Report alignment, structured with confidence) obtain the highest C-index on UKB-MI, 0.5805 (0.0580). On UKB-HYP, their C-index of 0.6349 (0.0156) is within 0.0021 of the best value (0.6370, signal reconstruction + classification). The proposed strategy is therefore the only one that performs near the top on both cohorts.

V-B Comparison study on risk prediction using different network architectures

TABLE III: Comparison of risk prediction performance using different types of deep neural networks. With LLM-informed pretraining, a network is initialized with weights pretrained on the proposed ECG-Report alignment and reconstruction tasks (300 epochs) and then finetuned on the risk prediction task (100 epochs). For a fair comparison, models without any pretraining were trained for 400 epochs (300 + 100). To avoid overfitting, the model with the highest C-index on the validation set was selected as the final model. Reported values are the average C-index (standard deviation) over five repetitions of two-fold cross-validation with different random splits (10 trials in total).
Network architecture | # parameters | LLM-informed pretraining | Hypertension (UKB-HYP) | Myocardial Infarction (UKB-MI)
CNN-VAE [47, 48, 49, 50] | 7.0M | No | 0.6215 (0.0237) | 0.5675 (0.0353)
CNN-VAE [47, 48, 49, 50] | 7.0M | Yes | 0.6346 (0.0218) | 0.5484 (0.0269)
XResNet1D [13] | 1.9M | No | 0.5601 (0.0273) | 0.5357 (0.0485)
XResNet1D [13] | 1.9M | Yes | 0.6234 (0.0143) | 0.5431 (0.0387)
ECG dual attention (the proposed) | 1.4M | No | 0.6122 (0.0190) | 0.5065 (0.0776)
ECG dual attention (the proposed) | 1.4M | Yes | 0.6349 (0.0156) | 0.5805 (0.0580)

We further compare model performance using two alternative encoder architectures:

1) the encoder of a variational autoencoder (VAE)-like network, which previous works found effective for ECG feature representation learning [47, 48, 49, 50];
2) XResNet1D [13], a top-performing architecture on a wide range of ECG analysis tasks, such as disease classification, age regression, and form/rhythm prediction, on the public PTB-XL benchmark dataset [39] and the ICBEB2018 dataset [51].

Table III reports their performance when trained without and with the proposed language model-informed pretraining strategy using structured SCP reports. The proposed ECG dual attention network has the fewest parameters. Yet, combined with the proposed pretraining strategy, it obtains the highest C-index on both cohorts.

V-C Comparison study between traditional ECG parameter-based risk prediction vs ECG dual attention

We further compare our method to a traditional approach based on well-established ECG parameters. Specifically, we collect a set of ECG parameters that previous studies have identified for incident HF prediction and mortality estimation [6, 8, 9, 11, 12, 18, 19, 20, 21, 7, 22]. All of these parameters were automatically extracted by the ECG device (GE CardioSoft V6; https://biobank.ndph.ox.ac.uk/ukb/ukb/docs/CardiosoftFormatECG.pdf), with supporting evidence from previous work: 1) Ventricular rate; 2) Left-axis deviation; 3) Right-axis deviation; 4) Prolonged P-wave duration (>120 ms) [22]; 5) Prolonged PR interval (>200 ms) [22]; 6) Prolonged QRS duration (>100 ms) [22]; 7) Prolonged QT interval (≥460 ms in women / ≥450 ms in men, corrected using the Framingham formula) [52]; 8) Delayed intrinsicoid deflection (DID time; maximum value in leads V5 and V6 >50 ms) [53]; 9) Abnormal P-wave axis (values outside the range 0° to 75°) [54]; 10) Left ventricular hypertrophy [55]; 11) Abnormal QRS-T angle (>77° in women / >88° in men) [22]; 12) Low QRS voltage [56]; 13) ST-T abnormality [56]; 14) Right bundle-branch block [56]; 15) Left bundle-branch block [56].
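The thresholded criteria in the list above can be sketched as a small helper. The sketch has several limitations:

- The dictionary keys are hypothetical names, not CardioSoft field names.
- Only the items with explicit thresholds are covered.
- The QT correction uses the Framingham formula $QT_c = QT + 0.154\,(1 - RR)$ with times in seconds, i.e. $154\,(1 - RR)$ in milliseconds.
- Treating both ends of the 0° to 75° P-axis range as normal is an assumption.

```python
def framingham_qtc(qt_ms, heart_rate_bpm):
    """Framingham correction: QTc = QT + 154 * (1 - RR) in ms, with RR in seconds."""
    rr_s = 60.0 / heart_rate_bpm
    return qt_ms + 154.0 * (1.0 - rr_s)

def binarize_ecg_parameters(p, sex):
    """Map continuous measurements (hypothetical keys) to the binary criteria above."""
    women = sex == "women"
    return {
        "prolonged_p_wave": p["p_duration_ms"] > 120,
        "prolonged_pr": p["pr_ms"] > 200,
        "prolonged_qrs": p["qrs_ms"] > 100,
        "prolonged_qt": framingham_qtc(p["qt_ms"], p["heart_rate"]) >= (460 if women else 450),
        "delayed_did": max(p["did_v5_ms"], p["did_v6_ms"]) > 50,
        "abnormal_p_axis": not (0 <= p["p_axis_deg"] <= 75),
        "abnormal_qrs_t_angle": p["qrs_t_angle_deg"] > (77 if women else 88),
    }

example = {"p_duration_ms": 125, "pr_ms": 180, "qrs_ms": 96, "qt_ms": 455, "heart_rate": 60,
           "did_v5_ms": 40, "did_v6_ms": 55, "p_axis_deg": 80, "qrs_t_angle_deg": 80}
print(binarize_ecg_parameters(example, "women"))
```

These binary indicators, together with ventricular rate and the device-reported statements (items 2, 3, 10 and 12-15), would then form the covariates of the linear CoxPH model.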

TABLE IV: Comparison of risk prediction performance between the traditional ECG parameter-based approach and our deep learning model (ECG dual attention). Reported values are the average C-index (standard deviation) over five repetitions of two-fold cross-validation with different random splits (10 trials in total).
Method | Input | Model type | Hypertension (UKB-HYP) | Myocardial Infarction (UKB-MI)
Traditional | 15 biomarkers | Linear | 0.6149 (0.0125) | 0.5398 (0.0200)
ECG dual attention (the proposed) | 12-lead median waves | Nonlinear | 0.6349 (0.0156) | 0.5805 (0.0580)
Figure 3: Kaplan-Meier risk curves for a) the conventional parameter model using a composite of ECG measurements and b) the proposed deep learning-based risk prediction model with ECG dual attention blocks, pretrained on the ECG-report alignment task. For both models, patients were divided into low- and high-risk groups using a cutoff at the 98th percentile (UKB-HYP) or the 96th percentile (UKB-MI) of the risk scores predicted by the model, reflecting the statistics of the datasets in Table I.
Figure 4: 3D visualization of the three-dimensional features from the last hidden layer of the risk prediction subnetwork. The input ECG waves with the lowest (dark purple) and highest (light yellow) predicted risk scores are also shown, for a) UKB-HYP and b) UKB-MI.

Following [22], we fit these ECG parameters with the conventional linear Cox proportional hazards (CoxPH) model [17] and use it as a strong competing model for HF risk prediction. Table IV compares the performance of our approach with that of the traditional ECG parameter-based approach. Fig. 3 shows Kaplan-Meier plots of survival probability over time, stratified by the risk groups defined by each model's predictions. Fig. 4 plots the features learned in the last hidden layer ($\in \mathbb{R}^{3}$) of the risk prediction branch of the proposed network, together with representative ECG waves with the lowest and highest risk scores.
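The Kaplan-Meier curves and the percentile-based risk groups can be sketched as follows. This is a bare-bones product-limit estimator for illustration only. The percentile values (98 and 96) come from the text; treating scores strictly above the cutoff as high risk is our assumption.

```python
import numpy as np

def kaplan_meier(time, event):
    """Product-limit estimate: returns distinct event times and S(t) at each of them."""
    time = np.asarray(time, float)
    event = np.asarray(event, int)
    times = np.unique(time[event == 1])
    surv, s = [], 1.0
    for t in times:
        at_risk = np.sum(time >= t)                  # still under observation at t
        events = np.sum((time == t) & (event == 1))  # HF events occurring at t
        s *= 1.0 - events / at_risk
        surv.append(s)
    return times, np.array(surv)

def high_risk_mask(risk, percentile):
    """True for subjects whose predicted risk lies above the given percentile."""
    risk = np.asarray(risk, float)
    return risk > np.percentile(risk, percentile)

t, s = kaplan_meier([1, 2, 2, 3, 4], [1, 1, 0, 1, 0])
print(t, s)   # event times [1, 2, 3], survival [0.8, 0.6, 0.3]
```

In use, `high = high_risk_mask(risk, 98)` would split a cohort, and `kaplan_meier(time[high], event[high])` and `kaplan_meier(time[~high], event[~high])` give the two curves.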

V-D Dual attention mechanism improves model stability and interpretability

V-D1 Quantitative ablation study

TABLE V: Ablation experiments. † marks the highest value in each column and ‡ the second highest.
Method | Hypertension (UKB-HYP) | Myocardial Infarction (UKB-MI)
w/o time attention module | 0.6038 (0.0243) | 0.5855 (0.0353) †
w/o lead attention module | 0.6018 (0.0418) | 0.4920 (0.0884)
w/o lead + time attention modules | 0.6167 (0.0105) ‡ | 0.5195 (0.0539)
The proposed | 0.6349 (0.0156) † | 0.5805 (0.0580) ‡

We also evaluate whether the dual attention modules improve the accuracy of risk prediction. We conduct ablation experiments using the same training strategy and network, but with the lead and/or temporal attention module removed. Results are shown in Table V. The proposed model with both attention modules performs most consistently across the two datasets. It achieves the best performance on UKB-HYP, the second-best performance on UKB-MI, and thus the best average performance across the two datasets.

V-D2 Lead attention matrix visualization

Figure 5: Visualization of lead attention patterns and of the differences between low-risk and high-risk groups across two different populations: UKB-HYP and UKB-MI.

We further visualize the average lead attention matrix at the population level by averaging the matrices within each risk group. For both datasets, patients are divided into low- and high-risk groups. The cutoff is the 98th percentile (UKB-HYP) or the 96th percentile (UKB-MI) of the model-predicted risk scores, with the threshold chosen based on the statistics of each dataset. Of note, all risk scores and attention matrices are obtained with the same cross-validation strategy, in which each subject's predicted risk is computed by averaging the predictions of models trained on data excluding that subject. Figure 5 illustrates how much each lead contributes to the re-weighting of the 12-lead features in each population. Similar patterns emerge in both populations, even though the underlying study cohorts differ.
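The grouping described above can be sketched in a few lines of NumPy. The function name is hypothetical, and two details are our assumptions: the inputs are out-of-fold attention matrices and risk scores, and the 'Difference' panel of Fig. 5 is taken as the high-risk mean minus the low-risk mean.

```python
import numpy as np

def lead_attention_by_risk_group(attn, risk, percentile):
    """Population-level lead attention for low- and high-risk groups.

    attn: (n_subjects, 12, 12) out-of-fold lead attention matrices.
    risk: (n_subjects,) out-of-fold predicted risk scores.
    percentile: risk-score cutoff, e.g. 98 (UKB-HYP) or 96 (UKB-MI).
    Returns the low-risk mean, the high-risk mean, and high minus low.
    """
    attn, risk = np.asarray(attn, float), np.asarray(risk, float)
    cutoff = np.percentile(risk, percentile)
    high = risk >= cutoff                  # roughly the top (100 - percentile)% of scores
    low_mean = attn[~high].mean(axis=0)
    high_mean = attn[high].mean(axis=0)
    return low_mean, high_mean, high_mean - low_mean

# Synthetic example: subject i has risk i and an attention matrix filled with i.
risk = np.arange(100.0)
attn = np.broadcast_to(risk[:, None, None], (100, 12, 12))
low, high, diff = lead_attention_by_risk_group(attn, risk, percentile=98)
```

With 100 subjects and a 98th-percentile cutoff, only the two highest-scoring subjects fall into the high-risk group. This illustrates how small the high-risk group is under these thresholds.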

V-D3 Temporal attention activation map visualization

Figure 6: Visualization of cross-lead (a, b) and 12-lead temporal attention maps (d, e) obtained from a high-HF-risk ECG with HYP (a, d) and a high-HF-risk ECG with MI (b, e). (c) is a schematic standard ECG for illustrative purposes. Source: Wikimedia Commons.

To visualize the temporal attention matrix $\mathbf{A}_{i}^{TA}$ over the original ECG time length $T$, we adapt Grad-CAM to ECG leads to highlight the model's focus. First, we condense the attention matrix into one attention score per time step by summing its columns. We then compute a Grad-CAM map of the ECG input for each time step and combine these maps, weighting each by its attention score. The resulting maps reveal the regions of the ECG on which the network focuses. More implementation details can be found in the Appendix. The visualization is presented in Fig. 6.

VI Discussion

Figure 7: UMAP visualization of latent code embeddings $\mathbf{z}^{SCP}$ extracted by the large language model from different structured SCP statements. Colors indicate the disease category of each statement.

Importance of language-informed pretraining: In the experiments, we first studied the impact of different pretraining tasks on downstream risk prediction and highlighted the value of LLM-informed pretraining in Table II. In general, pretraining improves risk prediction performance compared with no pretraining, especially on the smaller dataset (UKB-MI). Interestingly, without proper pretraining, the deep risk prediction model (average C-index: 0.5065) can underperform the traditional approach (average C-index: 0.5398); see Tables II and IV. This suggests that pretraining is important for alleviating model overfitting. Our study further highlights the critical role of targeted pretraining in identifying both generic and pathological features for downstream HF risk prediction. Pretraining on structured ECG reports outperforms pretraining with simple disease labels or unstructured text inputs. We attribute this improvement to the detailed context that structured reports provide, together with their supplementary confidence information. Figure 7 shows the UMAP [57] visualization of the latent text codes $\mathbf{z}^{SCP}$ of structured reports extracted by the LLM. It suggests that the LLM distinguishes between disease-specific embeddings (e.g., separated clusters) and captures the interrelations among various diseases. For instance, the embeddings for the normal or healthy category (red dots) lie closer to the ST/T wave change embeddings (blue dots) and notably far from those for MI (green dots). This is consistent with the fact that ST/T variations can be non-pathological, unlike the distinct disease pattern of MI.

ECG dual attention network exhibits high effectiveness of risk stratification: The results in Table III show that the proposed approach surpasses leading deep learning approaches in both computational cost and accuracy. Interestingly, language-informed pretraining consistently boosts the performance of the ResNet-based structure and of our attention model, but does not improve the performance of the CNN-VAE-based model. This is probably attributable to over-parameterization in the CNN-VAE network (7.0M parameters). In such a model, strong regularization that shapes the latent representation toward a Gaussian distribution to encourage feature independence [58] becomes crucial, and this regularization might conflict with the language-informed training process.

Results in Table IV also show that our approach surpasses a traditional methodology, which relies on a predefined set of ECG parameters with a simple linear model, as well as current leading deep learning-based models, in terms of both computational cost and performance. With the traditional approach in Fig. 3, the high-risk and low-risk curves for the MI population (bottom left) overlap heavily. This reveals how difficult risk stratification is when the underlying population is limited. In contrast, our method consistently stratifies patients into risk groups with a clear gap between the curves. Figure 4 demonstrates that our network effectively discriminates between low-risk (upper right, purple dots) and high-risk ECGs (bottom left, yellow dots) in the latent space. The high-risk ECGs show clinical markers such as prolonged QRS duration and longer QT intervals, in line with established studies [7, 6, 8, 9, 10].
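The Kaplan-Meier curves in Fig. 3 are product-limit estimates computed separately for each risk group, so the overlap discussed above is an overlap between two such step functions. A minimal NumPy sketch of the estimator follows. The function names are hypothetical, and splitting at a single risk-score cutoff is an illustrative assumption; the groups in Fig. 3 follow each model's predictions as described earlier.

```python
import numpy as np

def kaplan_meier(time, event):
    """Product-limit estimate of the survival function S(t).

    time: (n,) follow-up times; event: (n,) 1 = HF event, 0 = censored.
    Returns the distinct event times and S(t) just after each of them.
    """
    time, event = np.asarray(time, float), np.asarray(event, bool)
    event_times = np.unique(time[event])
    surv, s = [], 1.0
    for t in event_times:
        at_risk = np.sum(time >= t)              # subjects still under observation
        n_events = np.sum((time == t) & event)   # HF events occurring at time t
        s *= 1.0 - n_events / at_risk
        surv.append(s)
    return event_times, np.array(surv)

def km_by_risk_group(time, event, risk, cutoff):
    """Kaplan-Meier curves for subjects below and at/above a risk-score cutoff."""
    time, event, risk = map(np.asarray, (time, event, risk))
    high = risk >= cutoff
    return kaplan_meier(time[~high], event[~high]), kaplan_meier(time[high], event[high])

t, s = kaplan_meier([1, 2, 3, 4, 5], [1, 0, 1, 1, 0])  # events at t = 1, 3, 4
```

The censored subject at t = 2 does not lower the curve. It only leaves the risk set, which is why the drop at t = 3 is 1/3 rather than 1/4.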

ECG dual attention network exhibits high interpretability: Furthermore, the two attention modules of the proposed dual attention design offer model interpretability, a property highly valued in clinical settings. The dual attention mechanism serves as a window into how the network processes the ECG. It allows clinicians and readers alike to visualize and comprehend the complex interactions between different ECG leads, as shown in Fig. 5. First, we found that the three augmented leads (aVR, aVL, aVF) contributed the least (see blue regions). We believe this may be because aVR, aVL, and aVF can be derived from leads I, II, and III using Goldberger's equations [59] and thus contain redundant information for feature aggregation and risk prediction. Second, we found that the network typically pays more attention to limb lead I and the precordial leads (V2-V6) (see red regions). In clinical practice, comparing morphological changes across the precordial leads (V1-V6) can help identify ST/T wave abnormalities, which could be an indicator of future HF [18, 19, 20, 21, 7] and sudden cardiac death [60]. We also compared the attention patterns of the high-risk groups with those of the low-risk groups, as illustrated in the column titled 'Difference' in Fig. 5, and observed elevated activation values within the high-risk groups. This observation suggests that the network is more sensitive to abnormal features in high-risk ECGs.
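The redundancy argument can be checked numerically. Under the ideal lead definitions, lead III follows from Einthoven's law (III = II - I) and the augmented leads follow from Goldberger's equations. All six limb leads are therefore linear combinations of leads I and II. The NumPy check below uses random signals as stand-ins for leads I and II. This is an assumption for illustration: on real recordings the relations hold only up to noise and device processing. The helper name is hypothetical.

```python
import numpy as np

def limb_leads_from_I_II(lead_I, lead_II):
    """Derive lead III and the augmented limb leads from leads I and II.

    Einthoven's law: III = II - I. Goldberger's equations:
    aVR = -(I + II) / 2, aVL = I - II / 2, aVF = II - I / 2.
    """
    lead_III = lead_II - lead_I
    aVR = -(lead_I + lead_II) / 2
    aVL = lead_I - lead_II / 2
    aVF = lead_II - lead_I / 2
    return np.stack([lead_I, lead_II, lead_III, aVR, aVL, aVF])

rng = np.random.default_rng(0)
limb = limb_leads_from_I_II(rng.standard_normal(1024), rng.standard_normal(1024))
# The six limb leads span only a 2-D subspace, and aVR + aVL + aVF = 0.
print(limb.shape, np.linalg.matrix_rank(limb))  # (6, 1024) 2
```

The precordial leads V1-V6, by contrast, cannot be derived from the limb leads. A standard 12-lead ECG therefore carries eight independent leads (I, II, V1-V6).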

For better understanding, we visualize the cross-lead and temporal activation maps of two high-risk cases in Fig. 6. For the temporal maps, we apply Grad-CAM [61] to map the low-resolution temporal attention maps back to the original input signal level (see the Appendix for more details). The cross-lead module provides an overview: it identifies leads with distinctive patterns and captures synergies and interactions among leads. The temporal attention, in contrast, identifies the local regions within each lead that are important for the prediction. For example, leads III and II stand out in the cross-lead attention maps (a) and (b), respectively, highlighting patterns that set them apart from the other leads. These are a prolonged negative QRS complex in lead III (see (d)) and strong abnormal noise in lead II (see the original signals in (e)). Additionally, the orange diagonal clusters in panel (b) showcase good R wave progression from V2 to V6 (see (e)). The temporal attention maps in (d) and (e) reveal more intricate details. They show that the network can focus on clinically significant features such as the P, Q, S, and T waves and the R peaks, even in the presence of strong noise. The maps also highlight pathological abnormalities such as abnormal R peaks, T wave irregularities, and prolonged QRS and QT intervals. This suggests that the dual attention network can autonomously discover clinically relevant biomarkers from ECG data without being explicitly taught to do so during training. Future research will involve collaborating with medical professionals to verify potential novel biomarkers using these visualization tools.

Limitations: One limitation of the current work is that it only considers information from ECG signals. In the future, we will consider extending our approach to a broader spectrum of data. Additional input features could include blood test results, demographics (age, sex, ethnicity) [47], smoking status, chronic conditions such as diabetes, imaging-derived features [50, 49, 62], and genetic information. Together, these would create a more holistic and accurate characterization of the patient [8, 9, 63, 64]. Moreover, it would be valuable to increase the inclusivity and diversity of our study by considering a broader population base.

Broader Impact: We believe the proposed LLM-informed pretraining scheme is not limited to the risk prediction task. It could also inspire new approaches and applications for ECG signal analysis, such as disease diagnosis and ECG segmentation. By processing textual reports, the LLM acts as an additional source of contextual supervision that guides the ECG network to better understand and extract complex patterns. We would also like to highlight that the public pretraining dataset we used, which covers a diverse and comprehensive range of pathologies, is crucial for model generalization. In the future, it would also be interesting to further enhance model interpretability by incorporating the generative capabilities of LLMs, for example to provide structured reports as further explanation.

VII Conclusion

This paper presents a novel ECG dual attention network for ECG-based HF risk prediction. The network is both lightweight and efficient, and it outperforms the existing models it was compared against. A key feature is its ability to provide not only a quantitative risk score but also a qualitative interpretation of its internal processes. This is achieved through attention visualization maps, both across and within individual leads, which offer a granular view of the network's focus and decision-making process. We hope to contribute to the development of more transparent and interpretable AI-assisted systems, fostering trust and facilitating broader adoption in clinical settings. Additionally, the study presents a multi-modal pretraining approach for risk prediction models. It combines external public ECG reports and confidence scores from diverse populations with advanced large language models to address the challenges posed by limited future event data in risk prediction tasks.

Acknowledgment

This research has been conducted using the UK Biobank Resource under Application Number ‘40161’. The authors declare no conflict of interest.

References

  • [1] British Heart Foundation “Heart failure hospital admissions rise by a third in five years” (Accessed on Oct, 2023) In British Heart Foundation News Archive, 2019 URL: https://www.bhf.org.uk/what-we-do/news-from-the-bhf/news-archive/2019/november/heart-failure-hospital-admissions-rise-by-a-third-in-five-years
  • [2] Gordon F Tomaselli and Douglas P Zipes “What causes sudden death in heart failure?” In Circulation Research 95.8 Am Heart Assoc, 2004, pp. 754–763
  • [3] Rebecca E Lane, Martin R Cowie and Anthony WC Chow “Prediction and prevention of sudden cardiac death in heart failure” In Heart 91.5 BMJ Publishing Group Ltd, 2005, pp. 674–680
  • [4] Clyde W Yancy et al. “Clinical presentation, management, and in-hospital outcomes of patients admitted with acute decompensated heart failure with preserved systolic function: a report from the Acute Decompensated Heart Failure National Registry (ADHERE) Database” In Journal of the American College of Cardiology 47.1, 2006, pp. 76–84
  • [5] Brenda S Thompson and Clyde W Yancy “Immediate vs delayed diagnosis of heart failure: is there a difference in outcomes? results of a harris interactive® Patient Survey” In Journal of Cardiac Failure 10.4 Elsevier, 2004, pp. S125
  • [6] Rose Mary Ferreira Lisboa Silva, Nadya Mendes Kazzaz, Rosália Morais Torres and Maria da Consolação Vieira Moreira “P-wave dispersion and left atrial volume index as predictors in heart failure” In Arquivos Brasileiros de Cardiologia 100.1, 2013, pp. 67–74
  • [7] Beth Triola et al. “Electrocardiographic predictors of cardiovascular outcome in women: the National Heart, Lung, and Blood Institute-sponsored Women’s Ischemia Syndrome Evaluation (WISE) study” In Journal of the American College of Cardiology 46 Elsevier, 2005, pp. 51–56
  • [8] Sadiya S. Khan et al. “10-Year Risk Equations for Incident Heart Failure in the General Population” In Journal of the American College of Cardiology 73.19, 2019, pp. 2388–2397 DOI: 10.1016/j.jacc.2019.02.057
  • [9] Sadiya S Khan et al. “Development and Validation of a Long-Term Incident Heart Failure Risk Model” In Circulation Research 130.2, 2022, pp. 200–209
  • [10] Leonard Ilkhanoff et al. “Association of QRS duration with left ventricular structure and function and risk of heart failure in middle-aged and older adults: the Multi-Ethnic Study of Atherosclerosis (MESA)” In European Journal of Heart Failure 14.11, 2012, pp. 1285–1292
  • [11] Zhu-Ming Zhang et al. “Ventricular conduction defects and the risk of incident heart failure in the Atherosclerosis Risk in Communities (ARIC) Study” In Journal of Cardiac Failure 21.4, 2015, pp. 307–312
  • [12] Zhe-Ming Zhang et al. “Different patterns of bundle-branch blocks and the risk of incident heart failure in the Women’s Health Initiative (WHI) study” In Circulation: Heart Failure 6 Lippincott Williams & Wilkins, 2013, pp. 655–661
  • [13] Nils Strodthoff, Patrick Wagner, Tobias Schaeffter and Wojciech Samek “Deep Learning for ECG Analysis: Benchmarks and Insights from PTB-XL” In IEEE Journal of Biomedical and Health Informatics 25.5, 2021, pp. 1519–1528
  • [14] Venkata Anuhya Ardeti, Venkata Ratnam Kolluru, George Tom Varghese and Rajesh Kumar Patjoshi “An overview on state-of-the-art electrocardiogram signal processing methods: Traditional to AI-based approaches” In Expert Systems with Applications 217, 2023, pp. 119561
  • [15] J Weston Hughes et al. “A deep learning-based electrocardiogram risk score for long term cardiovascular death and disease” In NPJ digital medicine 6.1, 2023, pp. 169
  • [16] Emily Alsentzer et al. “Publicly Available Clinical BERT Embeddings” In Proceedings of the 2nd Clinical Natural Language Processing Workshop Minneapolis, Minnesota, USA: Association for Computational Linguistics, 2019, pp. 72–78 DOI: 10.18653/v1/W19-1909
  • [17] David R Cox “Regression models and life-tables” In Journal of the Royal Statistical Society. Series B (Methodological) 34.2 Wiley Online Library, 1972, pp. 187–220
  • [18] Pentti M Rautaharju et al. “Electrocardiographic predictors of incident heart failure in men and women free from manifest cardiovascular disease (from the Atherosclerosis Risk in Communities [ARIC] Study)” In The American Journal of Cardiology 112 Elsevier, 2013, pp. 843–849
  • [19] Pentti M Rautaharju, Charles Kooperberg, Jennifer C Larson and Andrea LaCroix “Electrocardiographic predictors of incident congestive heart failure and all-cause mortality in postmenopausal women: the Women’s Health Initiative” In Circulation 113 Lippincott Williams & Wilkins, 2006, pp. 481–489
  • [20] Pentti M Rautaharju et al. “Electrocardiographic predictors of new-onset heart failure in men and in women free of coronary heart disease (from the Atherosclerosis in Communities [ARIC] Study)” In The American Journal of Cardiology 100 Elsevier, 2007, pp. 1437–1441
  • [21] Zhu-Ming Zhang et al. “Comparison of the prognostic significance of the electrocardiographic QRS/T angles in predicting incident coronary heart disease and total mortality (from the atherosclerosis risk in communities study)” In The American journal of cardiology 100.5, 2007, pp. 844–849
  • [22] Wesley T O’Neal et al. “Electrocardiographic Predictors of Heart Failure With Reduced Versus Preserved Ejection Fraction: The Multi-Ethnic Study of Atherosclerosis” In Journal of the American Heart Association 6.6, 2017
  • [23] Jared Katzman et al. “DeepSurv: Personalized Treatment Recommender System Using A Cox Proportional Hazards Deep Neural Network” In BMC Medical Research Methodology, 2018
  • [24] Shuai Jiang, Arief A Suriawinata and Saeed Hassanpour “MHAttnSurv: Multi-Head Attention for Survival Prediction Using Whole-Slide Pathology Images” In Computers in Biology and Medicine 158, 2023
  • [25] Zhe Li et al. “Survival Prediction via Hierarchical Multimodal Co-Attention Transformer: A Computational Histology-Radiology Solution” In IEEE Transactions on Medical Imaging PP, 2023
  • [26] Ghalib A Bello et al. “Deep learning cardiac motion analysis for human survival prediction” In Nature Machine Intelligence 1, 2019, pp. 95–104 DOI: 10.1038/s42256-019-0019-2
  • [27] Karan Singhal et al. “Large Language Models Encode Clinical Knowledge” In Nature 620.7972, 2023, pp. 172–180
  • [28] Alec Radford et al. “Learning Transferable Visual Models From Natural Language Supervision”, 2021 arXiv:2103.00020 [cs.CV]
  • [29] Che Liu et al. “M-FLAG: Medical Vision-Language Pre-training with Frozen Language Models and Latent Space Geometry Optimization” In Medical Image Computing and Computer Assisted Intervention, 2023 eprint: 2307.08347
  • [30] Xiaoman Zhang et al. “Knowledge-enhanced visual-language pre-training on chest radiology images” In Nature Communications 14.1, 2023, pp. 4542
  • [31] Jie Liu et al. “CLIP-Driven Universal Model for Organ Segmentation and Tumor Detection” In International Conference on Computer Vision, 2023
  • [32] Jielin Qiu et al. “Multimodal Representation Learning of Cardiovascular Magnetic Resonance Imaging”, 2023 arXiv:2304.07675 [cs.CV]
  • [33] Ö. Turgut, P. Müller, P. Hager, S. Shit, S. Starck, M. J. Menten, E. Martens and D. Rueckert “Unlocking the Diagnostic Potential of ECG through Knowledge Transfer from Cardiac MRI”, 2023 arXiv:2308.05764
  • [34] Jielin Qiu et al. “Transfer Knowledge from Natural Language to Electrocardiography: Can We Detect Cardiovascular Disease Through Language Models?” In Findings of the Association for Computational Linguistics: EACL 2023 Dubrovnik, Croatia: Association for Computational Linguistics, 2023, pp. 442–453
  • [35] Che Liu et al. “ETP: Learning Transferable ECG Representations via ECG-Text Pre-training”, 2023 arXiv:2309.07145 [eess.SP]
  • [36] William Han et al. “Can Brain Signals Reveal Inner Alignment with Human Languages?”, 2023 arXiv:2208.06348 [q-bio.NC]
  • [37] Ashish Vaswani et al. “Attention is All you Need” In Conference on Neural Information Processing Systems Curran Associates, Inc., 2017, pp. 5998–6008
  • [38] Jimmy Lei Ba, Jamie Ryan Kiros and Geoffrey E Hinton “Layer Normalization”, 2016 arXiv:1607.06450 [stat.ML]
  • [39] Patrick Wagner et al. “PTB-XL, a large publicly available electrocardiography dataset” In Scientific Data 7.1, 2020, pp. 154
  • [40] Nils Strodthoff et al. “PTB-XL+, a comprehensive electrocardiographic feature dataset” In Scientific Data 10.1, 2023, pp. 279 DOI: 10.1038/s41597-023-02153-8
  • [41] Cathie Sudlow et al. “UK biobank: an open access resource for identifying the causes of a wide range of complex diseases of middle and old age” In PLoS Medicine 12.3, 2015, pp. e1001779
  • [42] Emily Alsentzer et al. “Publicly Available Clinical BERT Embeddings” In Proceedings of the 2nd Clinical Natural Language Processing Workshop Minneapolis, Minnesota, USA: Association for Computational Linguistics, 2019, pp. 72–78 DOI: 10.18653/v1/W19-1909
  • [43] Alistair E W Johnson et al. “MIMIC-III, a freely accessible critical care database” In Scientific Data 3, 2016, pp. 160035
  • [44] Petr Nejedly et al. “Classification of ECG Using Ensemble of Residual CNNs with Attention Mechanism” In 2021 Computing in Cardiology (CinC) 48, 2021, pp. 1–4 DOI: 10.23919/CinC53138.2021.9662723
  • [45] Ilya Loshchilov and Frank Hutter “Decoupled Weight Decay Regularization” In ICLR, 2019
  • [46] F E Harrell et al. “Evaluating the yield of medical tests” In JAMA: the Journal of the American Medical Association 247.18, 1982, pp. 2543–2546
  • [47] Yuling Sang, Marcel Beetz and Vicente Grau “Generation of 12-Lead Electrocardiogram with Subject-Specific, Image-Derived Characteristics Using a Conditional Variational Autoencoder” In International Symposium on Biomedical Imaging IEEE, 2022, pp. 1–5 DOI: 10.1109/ISBI52829.2022.9761431
  • [48] Marcel Beetz, Abhirup Banerjee, Yuling Sang and Vicente Grau “Combined Generation of Electrocardiogram and Cardiac Anatomy Models Using Multi-Modal Variational Autoencoders” In International Symposium on Biomedical Imaging IEEE, 2022, pp. 1–4 DOI: 10.1109/ISBI52829.2022.9761590
  • [49] Marcel Beetz, Abhirup Banerjee and Vicente Grau “Multi-Domain Variational Autoencoders for Combined Modeling of MRI-Based Biventricular Anatomy and ECG-Based Cardiac Electrophysiology” In Frontiers in Physiology 13, 2022 DOI: 10.3389/fphys.2022.886723
  • [50] Lei Li et al. “Deep Computational Model for the Inference of Ventricular Activation Properties” In International Workshop on Statistical Atlases and Computational Models of the Heart 13593 Springer, 2022, pp. 369–380 DOI: 10.1007/978-3-031-23443-9_34
  • [51] Feifei Liu et al. “An Open Access Database for Evaluating the Algorithms of Electrocardiogram Rhythm and Morphology Abnormality Detection” In Journal of Medical Imaging and Health Informatics 8.7, 2018, pp. 1368–1373
  • [52] A Sagie et al. “An improved method for adjusting the QT interval for heart rate (the Framingham Heart Study)” In The American Journal of Cardiology 70.7, 1992, pp. 797–801
  • [53] Wesley T O’Neal et al. “Electrocardiographic Time to Intrinsicoid Deflection and Heart Failure: The Multi-Ethnic Study of Atherosclerosis” In Clinical Cardiology 39.9, 2016, pp. 531–536
  • [54] Yabing Li, Amit J Shah and Elsayed Z Soliman “Effect of electrocardiographic P-wave axis on mortality” In The American Journal of Cardiology 113.2, 2014, pp. 372–376
  • [55] R B Devereux et al. “Electrocardiographic detection of left ventricular hypertrophy using echocardiographic determination of left ventricular mass as the reference standard. Comparison of standard criteria, computer diagnosis and physician interpretation” In Journal of the American College of Cardiology 3.1, 1984, pp. 82–87
  • [56] Ronald J Prineas, Richard S Crow and Zhu-Ming Zhang “The Minnesota Code Manual of Electrocardiographic Findings” Springer London, 2009
  • [57] Leland McInnes, John Healy and James Melville “UMAP: Uniform Manifold Approximation and Projection for Dimension Reduction”, 2018 arXiv:1802.03426 [stat.ML]
  • [58] Diederik P Kingma and Max Welling “Auto-Encoding Variational Bayes” In International Conference on Learning Representations, 2014, pp. 1–14
  • [59] “EKG/ECG Leads, Electrodes, Limb Leads, Chest (Precordial) Leads” Accessed on: 2023-07-01, 2023 ECG Waves URL: https://ecgwaves.com/topic/ekg-ecg-leads-electrodes-systems-limb-chest-precordial/
  • [60] Jani T Tikkanen et al. “Electrocardiographic T Wave Abnormalities and the Risk of Sudden Cardiac Death: The Finnish Perspective” In Annals of Noninvasive Electrocardiology: the Official Journal of the International Society for Holter and Noninvasive Electrocardiology, Inc 20.6, 2015, pp. 526–533
  • [61] Ramprasaath R. Selvaraju et al. “Grad-CAM: Visual Explanations from Deep Networks via Gradient-Based Localization” In IEEE International Conference on Computer Vision, ICCV 2017, Venice, Italy, October 22-29, 2017 IEEE Computer Society, 2017, pp. 618–626 DOI: 10.1109/ICCV.2017.74
  • [62] Ciaran Grafton-Clarke et al. “Cardiac magnetic resonance left ventricular filling pressure is linked to symptoms, signs and prognosis in heart failure” In ESC heart failure 10.5, 2023, pp. 3067–3076
  • [63] Maja Cikes et al. “Machine learning-based phenogrouping in heart failure to identify responders to cardiac resynchronization therapy” In European Journal of Heart Failure 21.1, 2019, pp. 74–85
  • [64] Shany Biton et al. “Atrial fibrillation risk prediction from the 12-lead electrocardiogram using digital biomarkers and deep representation learning” In European Heart Journal. Digital health 2.4, 2021, pp. 576–585

Appendix A Implementation of temporal activation maps

To visualize the temporal attention matrix $\mathbf{A}_{i}^{TA}\in\mathbb{R}^{l\times l}$, with $l=\frac{T}{2^{K}}<T$, in the original input space of an ECG with time length $T$, we employ a modified version of Grad-CAM (Gradient-weighted Class Activation Mapping) [61]. This technique allows us to generate visualizations for each ECG lead covering the entire duration of the signal. Specifically, we first sum the values in each column of the attention matrix to obtain a single attention score for every time point along the reduced time dimension. Then, for each of these time points, we take the corresponding input feature from the ECG and use it to create a gradient activation map, mapping it back to the original input space. We repeat this process for every time point and combine all the resulting Grad-CAM maps. Each map is weighted by its attention score, so that time points with higher attention scores have a greater influence on the final visualization.
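The aggregation step can be sketched as follows. This is a simplified illustration, not the modified Grad-CAM used in this work. It assumes the per-time-step activations and gradients have already been obtained from an autograd framework (random arrays stand in for them here), and it uses standard Grad-CAM with linear upsampling. The function names are hypothetical.

```python
import numpy as np

def gradcam_1d(fmap, grads, T):
    """Standard Grad-CAM for a 1-D feature map, linearly upsampled to length T.

    fmap, grads: (channels, L) activations and the gradients of the target
    (here: the attended feature at one time step) with respect to them,
    assumed to come from an autograd framework.
    """
    alpha = grads.mean(axis=1)                                 # channel weights
    cam = np.maximum((alpha[:, None] * fmap).sum(axis=0), 0)   # ReLU(sum_k alpha_k A^k)
    return np.interp(np.linspace(0, 1, T), np.linspace(0, 1, cam.size), cam)

def attention_weighted_cam(A, cams):
    """Combine per-time-step Grad-CAM maps using temporal attention scores.

    A: (l, l) temporal attention matrix of one lead.
    cams: (l, T) Grad-CAM map (in input space) for each of the l time steps.
    """
    scores = A.sum(axis=0)                      # column sums: one score per time step
    combined = scores @ cams                    # attention-weighted sum, shape (T,)
    return combined / (combined.max() + 1e-12)  # rescale to [0, 1] for display

# Shapes from Table A2: T = 1024 samples, K = 5 downsampling blocks -> l = 32.
T, K = 1024, 5
l = T // 2**K
rng = np.random.default_rng(0)
A = rng.random((l, l))
A /= A.sum(axis=1, keepdims=True)               # row-normalized, like a softmax output
cams = np.stack([gradcam_1d(rng.standard_normal((16, l)),
                            rng.standard_normal((16, l)), T) for _ in range(l)])
heatmap = attention_weighted_cam(A, cams)       # (1024,) values in [0, 1]
```

Because the weights are column sums, time steps that many other steps attend to dominate the final map. A lead whose attention is spread evenly produces a map close to the plain average of its Grad-CAM maps.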

TABLE A1: The code lists used for the retrieval of different disease records in the UK Biobank database.
Population | Category | Codes
Heart Failure (HF) | Algorithmically defined HF | Field ID: 131354
| Self-report | Field ID: 20002, code: 1076
| ICD-9 | Field ID: 41271, codes: 4280, 4281, 4289; date field: 41281
| ICD-10 | Field ID: 41270, codes: I500, I501, I509, I110, I130, I132; date field: 41280
Myocardial Infarction (MI) | Algorithmically defined MI | Field ID: 131298
| Self-report | Field ID: 20002, code: 1075
| ICD-9 | Field ID: 41271, codes: 410, 411, 412, 436; date field: 41281
| ICD-10 | Field ID: 41270, codes: I21, I22, I23, I24.1, I25.2; date field: 41280
Hypertension (HYP) | Algorithmically defined HYP | Field ID: 131286, 131288, 131289, 131292, 131293, 131294
| Self-report | Field ID: 20002, code: 1065
| ICD-9 | Field ID: 41271, codes: 4010, 4011, 4019; date field: 41281
| ICD-10 | Field ID: 41270, codes: I110, I13.0, I13.1, I13.2, I13.9, I15.1, I15.2, I15.8, I15.9; date field: 41280
TABLE A2: Detailed configurations of the ECG dual attention-based encoder.
ECG signal dual attention encoder
Layer name | Input size | Output size | PyTorch-like structure | Description
Conv1D | 12×1×1024 | 12×4×1024 | Conv1d(in_ch=12, out_ch=12×4, ks=5, stride=1, p=2, groups=12); GroupNorm(12, 12×4); nn.GELU(); Conv1d(in_ch=12×4, out_ch=12×4, ks=5, stride=1, p=2, groups=12); GroupNorm(12, 12×4); nn.GELU() | Group convolution and group normalization, so that each lead has a separate set of filters for signal preprocessing and feature extraction
Reshape | 12×4×1024 | 4×12×1024 | Reshape | Reshape to a tensor of shape (n_feature, n_leads, n_length)
ResConv1D (K=5; downsample) | 4×12×1024 | 8×12×512 | ResBlock(in_ch=4, out_ch=8, ks=5, stride=1, p=2, downsample=2) | Feature extraction along the time dimension (the last) while keeping the lead dimension unchanged
| 8×12×512 | 16×12×256 | ResBlock(in_ch=8, out_ch=16, ks=5, stride=1, p=2, downsample=2) |
| 16×12×256 | 16×12×128 | ResBlock(in_ch=16, out_ch=16, ks=5, stride=1, p=2, downsample=2) |
| 16×12×128 | 16×12×64 | ResBlock(in_ch=16, out_ch=16, ks=5, stride=1, p=2, downsample=2) |
| 16×12×64 | 16×12×32 | ResBlock(in_ch=16, out_ch=16, ks=5, stride=1, p=2, downsample=2) |
DualAttention | 16×12×32 | 16×12×32 | LeadAttention(input=12×[16×32]) + Concat(12-lead TimeAttention(input=32×16)) | Dual attention module
Max-pooling | [16×12]×32 | 256 | Dropout(p=0.2); Conv1d(in_ch=16×12, out_ch=256, ks=1, stride=1, p=0); AdaptiveMaxPool(output_size=1); Flatten() | Feature dimension reduction
AVG-pooling | [16×12]×32 | 256 | Dropout(p=0.2); Conv1d(in_ch=16×12, out_ch=256, ks=1, stride=1, p=0); AdaptiveAVGPool(output_size=1); Flatten() | Feature dimension reduction
Concat | 256; 256 | 512 | Concat | Concatenate features from max- and average-pooling
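The grouped convolutions in the first rows give each lead its own bank of four filters, so leads are processed independently until the dual attention module. A NumPy sketch that mirrors the channel-grouping semantics of PyTorch's `nn.Conv1d` (stride 1, bias omitted for brevity) illustrates this. The function name is hypothetical, and the axis swap used for the Reshape row is our assumption about how lead identity is preserved.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def grouped_conv1d(x, w, groups, pad):
    """1-D grouped cross-correlation (stride 1, no bias), mirroring nn.Conv1d.

    x: (in_ch, T); w: (out_ch, in_ch // groups, k). Group g reads input
    channels [g*cin_g, (g+1)*cin_g) and writes output channels
    [g*cout_g, (g+1)*cout_g), so with groups == 12 the leads never mix.
    """
    out_ch, cin_g, k = w.shape
    cout_g = out_ch // groups
    xp = np.pad(x, ((0, 0), (pad, pad)))            # zero padding along time
    win = sliding_window_view(xp, k, axis=1)        # (in_ch, T_out, k)
    y = np.empty((out_ch, win.shape[1]))
    for g in range(groups):
        xg = win[g * cin_g:(g + 1) * cin_g]         # (cin_g, T_out, k)
        wg = w[g * cout_g:(g + 1) * cout_g]         # (cout_g, cin_g, k)
        y[g * cout_g:(g + 1) * cout_g] = np.einsum("ctk,ock->ot", xg, wg)
    return y

rng = np.random.default_rng(0)
x = rng.standard_normal((12, 1024))          # one 12-lead ECG, 1024 samples
w = rng.standard_normal((12 * 4, 1, 5))      # weights of Conv1d(12, 48, ks=5, groups=12)
y = grouped_conv1d(x, w, groups=12, pad=2)   # (48, 1024): 4 feature maps per lead
# Assumed realization of the "Reshape" row: split channels into (lead, feature),
# then swap axes so each lead's 4 features stay together -> (4, 12, 1024).
z = y.reshape(12, 4, 1024).transpose(1, 0, 2)
```

Perturbing a single lead in `x` changes only the four output channels that belong to that lead. This is the per-lead filtering described in the table.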
TABLE A3: Detailed configurations of the ECG decoder for signal reconstruction.
Signal decoder
Layer name | Input size | Output size | PyTorch-like structure | Description
Linear | 512 | 4×12×32 | Linear(512, 12×4×32) | Feature refinement and expansion
ResConv1D (K=5; upsample) | 4×12×32 | 4×12×64 | ResBlock(in_ch=4, out_ch=4, ks=5, stride=1, p=2, upsample=2) | Feature expansion along the time dimension
| 4×12×64 | 4×12×128 | ResBlock(in_ch=4, out_ch=4, ks=5, stride=1, p=2, upsample=2) |
| 4×12×128 | 4×12×256 | ResBlock(in_ch=4, out_ch=4, ks=5, stride=1, p=2, upsample=2) |
| 4×12×256 | 4×12×512 | ResBlock(in_ch=4, out_ch=4, ks=5, stride=1, p=2, upsample=2) |
| 4×12×512 | 1×12×1024 | ResBlock(in_ch=4, out_ch=1, ks=5, stride=1, p=2, upsample=2) |
Conv1D | 12×1×1024 | 12×1×1024 | Conv1d(in_ch=12, out_ch=12, ks=1, stride=1, p=0, groups=12); InstanceNorm1d(12) | Smoothing and signal normalization
TABLE A4: Detailed configurations of the risk prediction branch.
Risk prediction branch
Layer name | Input size | Output size | PyTorch-like structure | Description
Linear | 512 | 3 | BatchNorm(512); BatchwiseDropout(0.25); Linear(512, 3); ReLU | Dimension reduction
Risk prediction | 3 | 1 | BatchNorm(3); Linear(3, 1) | Regression to risk score
TABLE A5: Detailed configurations of the two feature projectors.
ECG projector $p_{x}$
Layer name | Input size | Output size | PyTorch-like structure | Description
Projection | 512 | 128 | Linear(512, 256); ReLU; Linear(256, 128) | Feature reduction for feature alignment
Text projector $p_{y}$
Layer name | Input size | Output size | PyTorch-like structure | Description
Projection | 768 | 128 | Linear(768, 256); ReLU; Linear(256, 128) | Feature reduction for feature alignment
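The two projectors map the 512-d ECG feature and the 768-d text embedding into a shared 128-d space for ECG-report alignment. The alignment objective is not restated in this appendix. The NumPy sketch below therefore assumes a CLIP-style symmetric contrastive (InfoNCE) loss in the spirit of [28]. It uses a hypothetical temperature of 0.1, random untrained weights and random stand-in features, and it shows only the forward computation; training would use an autograd framework. All function names are hypothetical.

```python
import numpy as np

def projector(h, W1, b1, W2, b2):
    """Linear -> ReLU -> Linear, following the layer list in Table A5."""
    return np.maximum(h @ W1 + b1, 0.0) @ W2 + b2

def log_softmax(logits, axis):
    shifted = logits - logits.max(axis=axis, keepdims=True)   # numerical stability
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))

def symmetric_infonce(zx, zy, temperature=0.1):
    """CLIP-style loss: the i-th ECG should match the i-th report in the batch."""
    zx = zx / np.linalg.norm(zx, axis=1, keepdims=True)
    zy = zy / np.linalg.norm(zy, axis=1, keepdims=True)
    logits = zx @ zy.T / temperature          # (batch, batch) scaled cosine similarities
    idx = np.arange(len(zx))                  # positive pairs lie on the diagonal
    ecg_to_text = -log_softmax(logits, axis=1)[idx, idx].mean()
    text_to_ecg = -log_softmax(logits, axis=0)[idx, idx].mean()
    return 0.5 * (ecg_to_text + text_to_ecg)

def init_linear(rng, n_in, n_out):
    """Random weights and zero bias for a hypothetical, untrained layer."""
    return rng.standard_normal((n_in, n_out)) / np.sqrt(n_in), np.zeros(n_out)

rng = np.random.default_rng(0)
batch = 8
ecg_feat = rng.standard_normal((batch, 512))   # stand-in for encoder output (Table A2)
txt_feat = rng.standard_normal((batch, 768))   # stand-in for 768-d LLM text embeddings
zx = projector(ecg_feat, *init_linear(rng, 512, 256), *init_linear(rng, 256, 128))
zy = projector(txt_feat, *init_linear(rng, 768, 256), *init_linear(rng, 256, 128))
loss = symmetric_infonce(zx, zy)
```

Perfectly aligned pairs (for example `zx` with itself) give a much lower loss than these unrelated random pairs. That gap is what the alignment objective rewards.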