¹ Beijing Key Laboratory of Mobile Computing and Pervasive Device, Institute of Computing Technology, Chinese Academy of Sciences, Beijing 100190, China. Email: [email protected]
² University of Chinese Academy of Sciences, Beijing 100086, China.

SCMIL: Sparse Context-aware Multiple Instance Learning for Predicting Cancer Survival Probability Distribution in Whole Slide Images

Zekang Yang¹,²    Hong Liu¹ (🖂)    Xiangdong Wang¹
Abstract

Cancer survival prediction is a challenging task that involves analyzing the tumor microenvironment within Whole Slide Images (WSIs). Previous methods cannot effectively capture the intricate interactions among instances within local areas of a WSI. Moreover, existing WSI-based methods for cancer survival prediction often fail to provide clinically meaningful predictions. To overcome these challenges, we propose a Sparse Context-aware Multiple Instance Learning (SCMIL) framework for predicting cancer survival probability distributions. SCMIL groups patches into clusters based on their morphological features and spatial locations, and then uses sparse self-attention to learn the relationships between these patches from a context-aware perspective. Because many patches are irrelevant to the task, we introduce a learnable patch filtering module called SoftFilter, which ensures that only interactions between task-relevant patches are considered. To enhance the clinical relevance of our predictions, we propose a register-based mixture density network to forecast the survival probability distribution of individual patients. We evaluate SCMIL on two public WSI datasets from The Cancer Genome Atlas (TCGA), focusing on lung adenocarcinoma (LUAD) and kidney renal clear cell carcinoma (KIRC). Our experimental results indicate that SCMIL outperforms current state-of-the-art methods for survival prediction and offers more clinically meaningful and interpretable outcomes. Our code is available at https://github.com/yang-ze-kang/SCMIL.

Keywords:
Whole slide image · Survival prediction · Context interaction · Sparse attention.

1 Introduction

Using Whole Slide Images (WSIs) to predict a patient's cancer survival risk is crucial for health monitoring and personalized treatment in clinical settings. Pathologists typically examine WSIs manually to identify relevant biological features for diagnosis, but the high resolution of WSIs means that this analysis demands considerable time and effort. Automatic diagnosis using deep learning has the potential to significantly reduce the workload of pathologists, and many studies have been conducted on this subject [3, 16, 24]. Because fine-grained annotations are difficult to obtain for high-resolution WSIs, the problem is often treated as a weakly supervised learning task. In recent years, researchers have developed various methods to address this challenge and achieved commendable results in cancer diagnosis. Unlike cancer diagnosis, survival risk prediction requires not only extracting biomorphological features but also modeling the interactions between cells and tissues within the tumor microenvironment. Furthermore, providing predictions with greater clinical relevance poses an additional challenge for survival prediction [6].

Due to the high resolution of WSIs, it is common practice to divide them into fixed-size patches. A feature extractor, such as an ImageNet-pretrained ResNet50 [9], is then used to extract features from all patches, followed by multiple instance learning [11] for predictive analysis. Methods like AMIL [11], CLAM [16], and DSMIL [15] make predictions by identifying key patches. However, these methods neglect the interactions among patches, which is insufficient for survival prediction. Approaches such as WSISA [25] and DeepAttnMISL [23] use clustering to divide patches into phenotypes and then extract the features of each phenotype separately. While these methods consider the morphological relationships between patches, they disregard spatial connections. Methods like PatchGCN [2] and HGT [10] treat WSIs as point clouds in which each patch is a node, and use Graph Convolutional Networks (GCNs) [7, 14, 22] to explore the relationships among patches. In these methods, each patch aggregates information only from neighboring patches, so deeper networks are required to cover a wider area; however, increasing depth significantly raises computational cost and GPU memory usage. Moreover, how well inter-patch relationships are captured also depends on the choice of aggregation function. TransMIL [18] employs a self-attention mechanism along with the PPEG module to investigate inter-patch relationships. However, to mitigate GPU memory constraints, its authors use a linear approximation of self-attention, which results in coarse-grained attention between patches.

To address the aforementioned challenges, we propose a Sparse Context-aware Multiple Instance Learning (SCMIL) framework for predicting patient survival probability distributions. Our primary contributions are as follows: (1) We design a patch filtering module called SoftFilter that identifies task-relevant patches and can be trained through backpropagation. (2) We propose Sparse Context-aware Self-Attention (SCSA), which uses sparse self-attention to learn the interactions among local patches while incorporating both spatial and morphological information to guide the learning of patch interactions within specific areas. (3) We present the Register-based Mixture Density Network (RegisterMDN), which learns the parameters of each component of a Gaussian Mixture Model from the cancer patient cohort and uses an individual patient's data to predict the weights of these components. This approach enables the prediction of a tailored survival probability curve for each patient and enhances the interpretability and clinical significance of the model's predictions.

2 Methodology

Figure 1 depicts the pipeline of our proposed Sparse Context-aware Multiple Instance Learning (SCMIL) framework. WSIs are divided into fixed-size patches of 256×256 pixels, and irrelevant patches are filtered out. We then use a ViT [5] feature extractor ($F(x)$ in Figure 1), pre-trained on a large-scale collection of WSIs using self-supervised learning [12], to extract the features $Feat \in \mathbb{R}^{n\times d}$ of all $n$ patches. The fundamental principle of SCMIL is to identify the regions within a high-resolution WSI that are most informative for predicting patient survival risk. Within these regions, we identify biomarkers associated with survival risk. By integrating survival information from the cancer patient cohort, we then generate a survival probability distribution for the patient. The SCMIL framework consists of three main components: SoftFilter, Sparse Context-aware Self-Attention (SCSA), and the Register-based Mixture Density Network (RegisterMDN). SoftFilter helps SCSA focus on task-relevant areas, and RegisterMDN predicts the survival probability distribution from the WSI-level feature.

Figure 1: Overview of the proposed Sparse Context-aware Multiple Instance Learning (SCMIL) framework for predicting cancer survival probability distribution.

2.1 SoftFilter

Within each WSI, there are numerous patches that are irrelevant to the task at hand. To address this problem, we design a learnable patch filtering module termed SoftFilter. SoftFilter feeds the patch features into a Multilayer Perceptron (MLP) followed by a Sigmoid activation to predict the patches' importance scores $IS \in \mathbb{R}^{n\times 1}$:

$$IS = \mathrm{Sigmoid}(\mathrm{MLP}(Feat)) \qquad (1)$$

Subsequently, the feature of each patch is multiplied by that patch's importance score to obtain the new features $H \in \mathbb{R}^{n\times d}$. Because the scores rescale the features directly, the SoftFilter module can be trained without patch-level supervision. $H$ is then partitioned into task-relevant features $H_{high}$ and task-irrelevant features $H_{low}$ according to the IS threshold $Thre$. The task-relevant features are passed to the SCSA module to learn the interactions among patches, while the task-irrelevant features bypass this stage.
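The SoftFilter computation in Eq. (1) and the subsequent split can be sketched in plain Python as follows. This is an illustrative sketch, not the authors' implementation: the one-hidden-layer ReLU MLP, the helper names, and the use of `>=` at the threshold are assumptions; only the sigmoid scoring, the rescaling of each feature by its score, and the default threshold of 0.5 come from the text.

```python
import math
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]  # row-major: one row per output unit


def linear(x: Vector, weight: Matrix, bias: Vector) -> Vector:
    """Affine map y_j = sum_k weight[j][k] * x[k] + bias[j]."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def soft_filter(feats: Matrix, hid_w: Matrix, hid_b: Vector, out_w: Vector, out_b: float,
                thre: float = 0.5) -> Tuple[Matrix, Matrix, Vector]:
    """Return (H_high, H_low, IS) for one WSI (hypothetical helper).

    IS_i = sigmoid(MLP(Feat_i)); H_i = IS_i * Feat_i; patches with IS_i >= thre go to H_high.
    """
    h_high, h_low, scores = [], [], []
    for f in feats:
        hidden = [max(0.0, v) for v in linear(f, hid_w, hid_b)]         # assumed ReLU hidden layer
        s = sigmoid(sum(w * h for w, h in zip(out_w, hidden)) + out_b)  # importance score in (0, 1)
        h = [s * v for v in f]                                          # rescale the patch feature
        scores.append(s)
        (h_high if s >= thre else h_low).append(h)
    return h_high, h_low, scores
```

The hard split by `thre` is not differentiable, but gradients still reach the MLP through the multiplication of each feature by its score, which is why no patch-level labels are needed.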

2.2 Sparse Context-aware Self-Attention (SCSA)

After obtaining the task-relevant features, we devise a Sparse Context-aware Self-Attention (SCSA) module to explore the interactions among patches. SCSA first groups potentially interacting patches into $C$ clusters $\{L_1, L_2, \ldots, L_C\}$ based on their morphological features and spatial positions. Specifically, we apply the K-Means clustering algorithm to the task-relevant patches, where the similarity between patches is a weighted sum of the cosine similarity of their morphological features and the normalized Euclidean distance between their spatial positions, with weights $w_1$ and $w_2$, respectively. To accommodate WSIs of varying sizes, we fix the cluster size and derive the number of clusters $C$ from it. We then use Multi-Head Self-Attention (MHSA) [19] to learn the relationships within each cluster and obtain refined features $L_i'$:

$$L_i' = \mathrm{MHSA}(L_i) + L_i, \quad i = 1, 2, \ldots, C \qquad (2)$$
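The two steps of SCSA (clustering, then attention restricted to each cluster) can be sketched in plain Python as below. This is a hedged sketch, not the released code, and all helper names are hypothetical. Assumptions: (i) the weighted similarity is realized by running standard K-Means on a concatenated embedding with $w_2 = 1 - w_1$; since $\lVert a - b\rVert^2 = 2(1 - \cos(a, b))$ for unit vectors, squared distances in this embedding equal $w_1 \cdot 2(1 - \cos) + w_2 \cdot (\text{normalized spatial distance})^2$; (ii) the number of clusters is $\lceil n / \text{cluster size} \rceil$; (iii) a single attention head without an output projection stands in for MHSA. Plain K-Means does not enforce equal cluster sizes, and how the authors handle this on top of cuML is not specified here.

```python
import math
import random
from typing import List, Sequence, Tuple

Vec = List[float]
Mat = List[List[float]]


def embed_patches(feats: Sequence[Vec], coords: Sequence[Tuple[float, float]],
                  w1: float = 0.8, w2: float = 0.2) -> Mat:
    """Embed patches so that squared distances mix cosine and spatial terms."""
    x0, y0 = min(c[0] for c in coords), min(c[1] for c in coords)
    scale = max(max(c[0] for c in coords) - x0, max(c[1] for c in coords) - y0) or 1.0
    out = []
    for f, (x, y) in zip(feats, coords):
        norm = math.sqrt(sum(v * v for v in f)) or 1.0
        morph = [math.sqrt(w1) * v / norm for v in f]  # scaled unit-length feature
        spatial = [math.sqrt(w2) * (x - x0) / scale, math.sqrt(w2) * (y - y0) / scale]
        out.append(morph + spatial)
    return out


def sq_dist(a: Vec, b: Vec) -> float:
    return sum((u - v) ** 2 for u, v in zip(a, b))


def kmeans(points: Mat, k: int, iters: int = 50, seed: int = 0) -> List[int]:
    """Lloyd's algorithm with random initial centers drawn from the points."""
    rng = random.Random(seed)
    centers = [list(p) for p in rng.sample(points, k)]
    labels = [0] * len(points)
    for _ in range(iters):
        labels = [min(range(k), key=lambda c: sq_dist(p, centers[c])) for p in points]
        for c in range(k):
            members = [p for p, lab in zip(points, labels) if lab == c]
            if members:  # an empty cluster keeps its previous center
                centers[c] = [sum(col) / len(members) for col in zip(*members)]
    return labels


def matmul(a: Mat, b: Mat) -> Mat:
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def attend_with_residual(x: Mat, wq: Mat, wk: Mat, wv: Mat) -> Mat:
    """Single-head X + softmax(Q K^T / sqrt(d)) V over the patches of one cluster."""
    q, k, v = matmul(x, wq), matmul(x, wk), matmul(x, wv)
    d = len(q[0])
    rows = []
    for qi in q:
        logits = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        m = max(logits)
        w = [math.exp(z - m) for z in logits]
        s = sum(w)
        rows.append([sum(wj / s * vj[c] for wj, vj in zip(w, v)) for c in range(len(v[0]))])
    return [[xi + oi for xi, oi in zip(xr, orow)] for xr, orow in zip(x, rows)]


def scsa(feats, coords, wq, wk, wv, cluster_size=64, w1=0.8, seed=0):
    """Cluster the task-relevant patches, then attend only within each cluster."""
    k = max(1, math.ceil(len(feats) / cluster_size))  # assumed rule for C
    labels = kmeans(embed_patches(feats, coords, w1, 1.0 - w1), k, seed=seed)
    clusters = [[f for f, lab in zip(feats, labels) if lab == c] for c in range(k)]
    return [attend_with_residual(cl, wq, wk, wv) for cl in clusters if cl], labels
```

Restricting attention to clusters reduces the number of pairwise scores from $n^2$ to roughly $n \cdot s$ for cluster size $s$. As a back-of-envelope figure, if all 12,097 patches of an average TCGA-LUAD slide were task-relevant, that is about $1.46 \times 10^8$ scores for full attention versus about $7.7 \times 10^5$ with clusters of 64.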

Compared with linear self-attention methods [18, 21], our sparse self-attention approach enables finer-grained attention to the relationships among patches. The refined features from all clusters are then concatenated with the task-irrelevant features, and the WSI-level feature $Feat'$ is obtained through attention-based pooling [11]:

$$H' = \mathrm{Concat}(L_1', L_2', \ldots, L_C', H_{low}) \qquad (3)$$
$$\alpha_i = \frac{\exp\left(a^{T}\left(\tanh(V H_i'^{T}) \odot \sigma(U H_i'^{T})\right)\right)}{\sum_{k=1}^{n} \exp\left(a^{T}\left(\tanh(V H_k'^{T}) \odot \sigma(U H_k'^{T})\right)\right)} \qquad (4)$$
$$Feat' = \sum_{i=1}^{n} \alpha_i H_i' \qquad (5)$$

where $U$, $V$, and $a$ are learnable parameters, $n$ is the number of patches in the WSI, $\odot$ denotes element-wise multiplication, $\tanh(\cdot)$ is the hyperbolic tangent function, and $\sigma(\cdot)$ is the sigmoid function. The WSI-level feature $Feat'$ aggregates the biomarkers relevant to patient survival risk and serves as the input to the subsequent survival prediction.
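The gated attention pooling of Eqs. (3) to (5) can be sketched in plain Python as below. This is a hedged illustration: the function name and argument layout are hypothetical, $V$ and $U$ are given as row-major matrices with one row per hidden unit, and the same vector $a$ is used in the numerator and denominator of Eq. (4), so that $\alpha$ is a softmax over patches.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def matvec(m: Matrix, x: List[float]) -> List[float]:
    return [sum(a * b for a, b in zip(row, x)) for row in m]


def gated_attention_pool(h: Matrix, V: Matrix, U: Matrix,
                         a: List[float]) -> Tuple[List[float], List[float]]:
    """Return (Feat', alpha) for the concatenated patch features h (n x d)."""
    logits = []
    for hi in h:
        # element-wise tanh(V h_i) * sigmoid(U h_i), then project with a^T
        gated = [math.tanh(p) * sigmoid(q) for p, q in zip(matvec(V, hi), matvec(U, hi))]
        logits.append(sum(aj * gj for aj, gj in zip(a, gated)))
    m = max(logits)                          # subtract the max for numerical stability
    weights = [math.exp(z - m) for z in logits]
    total = sum(weights)
    alpha = [w / total for w in weights]     # Eq. (4): softmax over patches
    feat = [sum(al * hi[j] for al, hi in zip(alpha, h)) for j in range(len(h[0]))]  # Eq. (5)
    return feat, alpha
```

With $V$ set to zero every gate is zero, so all patches receive equal weight and $Feat'$ reduces to mean pooling; training moves $\alpha$ away from uniform toward the informative patches.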

2.3 RegisterMDN

Previous studies [1, 2, 10, 23] that predict survival risk from WSIs mainly focus on predicting a time-independent risk value. On its own, the risk value of an individual patient provides little useful information; a more comprehensive prognosis should also take into account the risk values and survival times of other patients in the cancer patient cohort. To provide more clinically meaningful predictions, we design the Register-based Mixture Density Network (RegisterMDN), inspired by SurvivalMDN [8], to predict the survival probability distribution of an individual patient.

A Mixture Density Network (MDN) maps its input to a probability distribution. We adopt Gaussian distributions as the MDN components and denote the number of components by $K$. The inputs of RegisterMDN are the WSI-level feature $Feat'$, the mean vector $P_m$, and the standard deviation vector $P_v$. Both $P_m$ and $P_v$ are learnable parameters that capture the survival risk characteristics of the specific cancer during training. $Feat'$, $P_m$, and $P_v$ are passed through neural networks to produce the weights $\lambda_i(Feat')$, means $\mu_i(P_m)$, and standard deviations $\sigma_i(P_v)$ of the mixture model, respectively. This yields the Probability Density Function (PDF):

$$PDF(y \mid Feat', P_m, P_v) = \sum_{i=1}^{K} \lambda_i(Feat')\, \mathcal{N}\big(y \mid \mu_i(P_m), \sigma_i(P_v)\big) \qquad (6)$$
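A minimal sketch of the RegisterMDN head and Eq. (6) follows. Only the division of roles comes from the text: patient-specific weights from $Feat'$, and cohort-level means and standard deviations from the learnable registers $P_m$ and $P_v$. The concrete maps are assumptions for illustration (a single linear layer plus softmax for $\lambda$, the identity for $\mu$, and softplus for $\sigma$), since the networks are not specified here, and the helper names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def softmax(z: Sequence[float]) -> List[float]:
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def softplus(z: float) -> float:
    # log(1 + e^z) without overflow for large |z|
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def normal_pdf(y: float, mu: float, sigma: float) -> float:
    return math.exp(-0.5 * ((y - mu) / sigma) ** 2) / (sigma * math.sqrt(2.0 * math.pi))


def register_mdn(feat, w, b, p_m, p_v) -> Tuple[List[float], List[float], List[float]]:
    """Return (lambda, mu, sigma) for K Gaussian components (hypothetical helper)."""
    logits = [sum(wk * f for wk, f in zip(row, feat)) + bk for row, bk in zip(w, b)]
    lam = softmax(logits)                      # patient-specific weights from Feat'
    mu = list(p_m)                             # assumed: identity map from the mean register
    sigma = [softplus(v) + 1e-6 for v in p_v]  # assumed: softplus keeps every sigma positive
    return lam, mu, sigma


def mixture_pdf(y: float, lam, mu, sigma) -> float:
    """Eq. (6): PDF(y) = sum_i lambda_i * N(y | mu_i, sigma_i)."""
    return sum(l * normal_pdf(y, m, s) for l, m, s in zip(lam, mu, sigma))
```

Because $\mu_i$ and $\sigma_i$ do not depend on the slide, the $K$ components act as a shared register of survival-time modes for the cohort, and each patient only chooses how to weight them.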

Because a patient's survival time is positive, we define the survival time as $t = g(y) = \log(1 + \exp(y))$ (the softplus function). This transformation allows us to formulate the patient's Death Probability Density Function (DPDF) and Death Cumulative Distribution Function (DCDF):

$$DPDF(t \mid Feat', P_m, P_v) = \left|\frac{\mathrm{d}g^{-1}(t)}{\mathrm{d}t}\right| \sum_{i=1}^{K} \lambda_i(Feat')\, \mathcal{N}\big(g^{-1}(t) \mid \mu_i(P_m), \sigma_i(P_v)\big) \qquad (7)$$
$$DCDF(t \mid Feat', P_m, P_v) = \sum_{i=1}^{K} \lambda_i(Feat')\, \frac{1}{2}\left[1 + \mathrm{erf}\left(\frac{g^{-1}(t) - \mu_i(P_m)}{\sqrt{2}\,\sigma_i(P_v)}\right)\right] \qquad (8)$$

where $\mathrm{erf}(\cdot)$ is the Gaussian error function. The patient's Survival Cumulative Distribution Function, $SCDF(t \mid Feat', P_m, P_v) = 1 - DCDF(t \mid Feat', P_m, P_v)$, is the final predicted survival probability distribution.
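The Jacobian factor in Eq. (7) and the Gaussian CDF in Eq. (8) follow directly from the softplus transformation. The short derivation below spells this out using only the definitions above; the step $P(T \le t) = P(Y \le g^{-1}(t))$ holds because $g$ is strictly increasing.

```latex
\begin{aligned}
t &= g(y) = \log\left(1 + e^{y}\right), \qquad y \in \mathbb{R},\ t > 0,\\
g^{-1}(t) &= \log\left(e^{t} - 1\right),\\
\left|\frac{\mathrm{d}g^{-1}(t)}{\mathrm{d}t}\right| &= \frac{e^{t}}{e^{t} - 1} = \frac{1}{1 - e^{-t}},\\
DCDF(t) &= P\left(Y \le g^{-1}(t)\right)
  = \sum_{i=1}^{K} \lambda_i\, \Phi\!\left(\frac{g^{-1}(t) - \mu_i}{\sigma_i}\right),
\qquad \Phi(z) = \tfrac{1}{2}\left[1 + \mathrm{erf}\left(z / \sqrt{2}\right)\right].
\end{aligned}
```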

Let $c$ be the patient's right-censoring indicator ($c = 1$ for uncensored data and $c = 0$ for censored data), $d$ the duration from diagnosis to death, and $o$ the time from diagnosis to the last follow-up. The observed time $td$ equals $d$ if $c = 1$ and $o$ if $c = 0$. Using maximum likelihood estimation, we define the loss function of RegisterMDN as the negative log-likelihood:

$$loss = -c \cdot \log\big(DPDF(td \mid Feat', P_m, P_v)\big) - (1 - c) \cdot \log\big(SCDF(td \mid Feat', P_m, P_v)\big) \qquad (9)$$
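Eqs. (7) to (9) can be evaluated with the Python standard library as sketched below. This is an illustration under assumptions, not the released code: the mixture parameters are passed in directly, the helper names are hypothetical, each component's CDF is the standard Gaussian $\Phi(z) = \frac{1}{2}[1 + \mathrm{erf}(z/\sqrt{2})]$, and the small `eps` inside the logarithms is an added safeguard against $\log 0$ that the paper does not mention.

```python
import math

SQRT_2PI = math.sqrt(2.0 * math.pi)


def inv_softplus(t: float) -> float:
    """g^{-1}(t) = log(e^t - 1), written as t + log(1 - e^{-t}) for stability (t > 0)."""
    return t + math.log(-math.expm1(-t))


def normal_pdf(y: float, mu: float, sigma: float) -> float:
    return math.exp(-0.5 * ((y - mu) / sigma) ** 2) / (sigma * SQRT_2PI)


def normal_cdf(z: float) -> float:
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def dpdf(t, lam, mu, sigma):
    """Eq. (7): density of the survival time t > 0."""
    y = inv_softplus(t)
    jacobian = 1.0 / (-math.expm1(-t))  # |d g^{-1}/dt| = 1 / (1 - e^{-t})
    return jacobian * sum(l * normal_pdf(y, m, s) for l, m, s in zip(lam, mu, sigma))


def dcdf(t, lam, mu, sigma):
    """Eq. (8): probability of death by time t."""
    y = inv_softplus(t)
    return sum(l * normal_cdf((y - m) / s) for l, m, s in zip(lam, mu, sigma))


def scdf(t, lam, mu, sigma):
    """Survival function SCDF(t) = 1 - DCDF(t)."""
    return 1.0 - dcdf(t, lam, mu, sigma)


def survival_nll(td, c, lam, mu, sigma, eps=1e-12):
    """Eq. (9): c = 1 for an observed death at td, c = 0 for censoring at td."""
    return -(c * math.log(dpdf(td, lam, mu, sigma) + eps)
             + (1 - c) * math.log(scdf(td, lam, mu, sigma) + eps))
```

Censored patients contribute only through the survival function, so the loss rewards placing probability mass beyond their last follow-up time $o$ without requiring a death time.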

3 Experiments

3.1 Experimental Settings

3.1.1 Dataset.

We evaluate the effectiveness of our method on two datasets from The Cancer Genome Atlas (TCGA): lung adenocarcinoma (LUAD) with 452 cases and kidney renal clear cell carcinoma (KIRC) with 512 cases. All WSIs are analyzed at 20× magnification and cropped into 256 × 256 patches. The average number of patches per WSI is 12,097 for TCGA-LUAD and 14,249 for TCGA-KIRC, and the largest WSI, a TCGA-KIRC sample, contains 84,365 patches.

3.1.2 Implementation Details.

In our implementation, we set the cluster size in SCSA (the number of patches per cluster) to 64, the threshold $Thre$ in SoftFilter to 0.5, and the number of components $K$ in RegisterMDN to 100. We use cuML [17] to accelerate the K-Means algorithm on the GPU. For all comparison and ablation experiments, we keep the hyperparameters consistent: a learning rate of 2e-4 with a weight decay of 1e-3, the Adam optimizer, a dropout rate of 0.1, a batch size of 1, and 20 training epochs. 5-fold cross-validation is used for all datasets and models.
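For reference, the stated settings can be collected in a small configuration object, together with a simple 5-fold split over cases. This is a hypothetical sketch: the class and function names are invented here, and the paper does not say how folds are formed (for example, whether they are stratified by censoring status), so a plain seeded random split is assumed.

```python
import random
from dataclasses import dataclass
from typing import Iterator, List, Tuple


@dataclass(frozen=True)
class SCMILConfig:
    cluster_size: int = 64      # patches per cluster in SCSA
    thre: float = 0.5           # SoftFilter threshold
    num_components: int = 100   # K in RegisterMDN
    lr: float = 2e-4
    weight_decay: float = 1e-3
    dropout: float = 0.1
    batch_size: int = 1
    epochs: int = 20
    n_folds: int = 5


def kfold_indices(n_cases: int, n_folds: int = 5, seed: int = 0) -> List[List[int]]:
    """Shuffle case indices once, then deal them round-robin into n_folds folds."""
    idx = list(range(n_cases))
    random.Random(seed).shuffle(idx)
    return [idx[f::n_folds] for f in range(n_folds)]


def cv_splits(n_cases: int, n_folds: int = 5, seed: int = 0) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (train, validation) index lists, one pair per fold."""
    folds = kfold_indices(n_cases, n_folds, seed)
    for k, val in enumerate(folds):
        train = [i for j, fold in enumerate(folds) if j != k for i in fold]
        yield train, val
```

In PyTorch, the optimizer would be built as `torch.optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)`. With the 512 KIRC cases, this split gives validation folds of 103, 103, 102, 102, and 102 cases.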

3.1.3 Evaluation Metric.

The conventional concordance index (C-Index) [20] is limited in its ability to provide a comprehensive comparison between methods, so we introduce additional evaluation metrics. We use a time-dependent version of the concordance estimator (TDC) within a pre-specified time span $[0, \tau]$. TDC measures the proportion of patient pairs whose survival risks are correctly ranked at multiple time points in $[0, \tau]$. The Brier score (BS) calculates the mean squared error between the ground truth and the predicted probability, and mainly measures calibration. To consider all times, we use the integrated BS (IBS) over the time interval $[0, \tau]$. Models with higher TDC and lower IBS perform better. Results are reported as mean ± standard deviation.
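The two metrics can be sketched as follows. This is an illustrative implementation under stated assumptions, not the evaluation code behind Table 1. TDC follows Antolini's time-dependent concordance: a pair is comparable when patient $i$ has an observed death at $t_i \le \tau$ and patient $j$ survives beyond $t_i$, and ties count as 0.5. IBS uses the inverse-probability-of-censoring weighted Brier score of Graf et al., with the censoring distribution $G$ estimated by Kaplan-Meier (evaluated at $t_i$ rather than just before it, as a simplification) and integrated with the trapezoidal rule. Each patient's prediction is a callable returning $S(t)$, for instance the SCDF of RegisterMDN; all names are hypothetical.

```python
import bisect
from typing import Callable, List, Sequence

SurvFn = Callable[[float], float]  # maps time t to a predicted survival probability S(t)


def km_censoring(times: Sequence[float], events: Sequence[int]) -> SurvFn:
    """Kaplan-Meier estimate G(t) of the censoring survival function."""
    cens_times = sorted({t for t, e in zip(times, events) if e == 0})
    surv, values = 1.0, []
    for u in cens_times:
        at_risk = sum(1 for t in times if t >= u)
        n_cens = sum(1 for t, e in zip(times, events) if t == u and e == 0)
        surv *= 1.0 - n_cens / at_risk
        values.append(surv)

    def G(t: float) -> float:
        k = bisect.bisect_right(cens_times, t)
        return 1.0 if k == 0 else values[k - 1]
    return G


def td_concordance(times, events, surv_fns: List[SurvFn], tau: float) -> float:
    concordant, comparable = 0.0, 0
    for i, (ti, ei) in enumerate(zip(times, events)):
        if ei != 1 or ti > tau:
            continue  # only observed deaths inside [0, tau] anchor a pair
        si = surv_fns[i](ti)
        for j, tj in enumerate(times):
            if tj > ti:  # patient j was still alive when patient i died
                comparable += 1
                sj = surv_fns[j](ti)
                concordant += 1.0 if si < sj else 0.5 if si == sj else 0.0
    return concordant / comparable if comparable else float("nan")


def brier_ipcw(t, times, events, surv_fns, G) -> float:
    total = 0.0
    for ti, ei, S in zip(times, events, surv_fns):
        if ti <= t and ei == 1:
            total += S(t) ** 2 / G(ti)          # died by t: target survival is 0
        elif ti > t:
            total += (1.0 - S(t)) ** 2 / G(t)   # still at risk at t: target is 1
    return total / len(times)                   # censored before t contributes 0


def integrated_brier(times, events, surv_fns, tau: float, steps: int = 100) -> float:
    G = km_censoring(times, events)
    grid = [tau * k / steps for k in range(steps + 1)]
    bs = [brier_ipcw(t, times, events, surv_fns, G) for t in grid]
    area = sum((bs[k] + bs[k + 1]) / 2.0 * (grid[k + 1] - grid[k]) for k in range(steps))
    return area / tau
```

As a sanity check, a predictor that steps from 1 to 0 exactly at each patient's death time scores TDC = 1 and IBS = 0 on uncensored data, while a constant prediction of 0.5 scores TDC = 0.5 and IBS = 0.25.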

Table 1: Evaluation of all models on TCGA-KIRC and TCGA-LUAD with the time-dependent concordance index (TDC) and integrated Brier score (IBS). Best results are marked in bold.

| Method | KIRC TDC ↑ | KIRC IBS ↓ | LUAD TDC ↑ | LUAD IBS ↓ |
|---|---|---|---|---|
| AMIL [11] | 0.627 ± 0.063 | 0.287 ± 0.014 | 0.612 ± 0.042 | 0.305 ± 0.045 |
| CLAM [16] | 0.664 ± 0.037 | 0.289 ± 0.031 | 0.592 ± 0.070 | 0.308 ± 0.044 |
| DSMIL [15] | 0.642 ± 0.045 | 0.289 ± 0.015 | 0.581 ± 0.075 | 0.322 ± 0.044 |
| PatchGCN [2] | 0.674 ± 0.049 | 0.279 ± 0.026 | 0.582 ± 0.055 | 0.307 ± 0.045 |
| TransMIL [18] | 0.629 ± 0.041 | 0.290 ± 0.017 | 0.512 ± 0.082 | 0.319 ± 0.033 |
| HIPT [1] | 0.635 ± 0.041 | 0.270 ± 0.021 | 0.540 ± 0.025 | 0.289 ± 0.068 |
| HGT [10] | 0.634 ± 0.058 | 0.269 ± 0.033 | 0.601 ± 0.042 | 0.289 ± 0.052 |
| SCMIL w/o SoftFilter | 0.659 ± 0.038 | 0.278 ± 0.015 | 0.546 ± 0.046 | 0.318 ± 0.043 |
| SCMIL w/o SCSA | 0.651 ± 0.020 | 0.274 ± 0.015 | 0.589 ± 0.042 | 0.318 ± 0.028 |
| SCMIL | **0.688 ± 0.037** | **0.268 ± 0.021** | **0.622 ± 0.015** | **0.288 ± 0.060** |

3.2 Experiments and Results

3.2.1 Comparison with State-of-the-Art Methods.

To compare SCMIL's ability to learn features related to cancer survival risk with that of existing methods, we select several state-of-the-art methods: AMIL [11], CLAM [16], DSMIL [15], PatchGCN [2], TransMIL [18], HIPT [1], and HGT [10]. To ensure a fair comparison, we add the RegisterMDN module to each of these methods so that they also predict the patient's survival probability distribution. Unlike methods based on key patches [11, 16, 15], SCMIL learns the interactions between related patches. Compared with GCN-based methods [2, 10], which focus on adjacent patches, SCMIL offers a more adaptable attention scope. Compared with Transformer-based methods [1, 18], which attend to patches globally, SCMIL focuses more effectively on local regions of interest. The experimental results are presented in Table 1. SCMIL achieves the best TDC and IBS on both WSI datasets, indicating a stronger ability than previous methods to learn features associated with cancer survival risk from WSIs.

3.2.2 Ablation Analysis.

Table 1 also presents the results of SCMIL with the SoftFilter module or the SCSA module removed. Removing either module leads to a decline in performance, underscoring the role of both modules. Notably, without the SoftFilter module the TDC on the LUAD dataset drops from 0.622 to 0.546, which suggests that many patches in this dataset may be irrelevant to the task.

Figure 2: Comparison of different clustering methods.
Table 2: Comparison of different probability distribution prediction methods. Best results are marked in bold, second best results are underlined.

| Method | TDC ↑ | IBS ↓ |
|---|---|---|
| Predicted Vector [8] | 0.653 ± 0.100 | **0.255 ± 0.017** |
| Fixed Vector | 0.683 ± 0.018 | 0.280 ± 0.023 |
| Learnable Vector | **0.688 ± 0.037** | 0.268 ± 0.021 |

We conduct further experiments on the KIRC dataset to assess how the morphological similarity weight $w_1$ and the spatial location similarity weight $(1 - w_1)$ used during clustering affect model performance. Figure 2 shows the results, with the blue dotted line indicating the performance of random clustering. An 8:2 ratio of morphological to spatial similarity yields the best performance, whereas models that cluster using only morphological information or only spatial location information perform worse. We further evaluate different approaches for predicting the survival probability distribution: (1) Predicted Vector, which predicts the parameters of each MDN component from $Feat'$; (2) Fixed Vector, which predefines the parameters of each component; and (3) Learnable Vector, our proposed approach, which learns these parameters. As shown in Table 2, the Learnable Vector achieves the highest TDC and a lower IBS than the Fixed Vector, while the Predicted Vector attains the lowest IBS but a lower and much more variable TDC (0.653 ± 0.100). Overall, the Learnable Vector offers the best discriminative power with competitive calibration.

Figure 3: Interpretability of the SCMIL.
Figure 4: Survival probability distribution prediction and actual survival time.

3.3 Interpretability of the Proposed Method

We conduct an interpretability analysis of each module of SCMIL and present the visualization results in Figure 3. The original image is in the top left, the heatmap of IS in the top right, the cluster distribution image in the bottom left, and a zoomed-in view in the bottom right. In the IS heatmap, the color spectrum from red through yellow to blue represents decreasing IS values; areas closer to red are considered more valuable for the task. In the cluster distribution image, the model divides the task-relevant patches into clusters, with each color representing a different cluster. To determine which areas the model primarily focuses on for patch interactions, we calculate the average IS value of the patches within each cluster. The bottom-right image in Figure 3 is an enlarged view of the region containing the cluster with the highest average IS value. The figure reveals that the model pays more attention to the perivascular area. Clinical studies have identified angiogenesis and blood vessel invasion as significant factors in predicting cancer risk [4, 13], so the knowledge acquired by our model is consistent with clinical findings. Figure 4 illustrates the actual survival times of two patients and the survival probability distributions predicted by our model. The blue curve is the Kaplan-Meier curve of the patient cohort. Our model can estimate a patient's survival probability at any given time, and in these examples it distinguishes between patients with different survival risks.
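The cluster-selection step behind the zoomed-in view can be sketched as below: average the importance scores within each cluster and keep the cluster with the highest mean. The function name is hypothetical, and the sketch assumes one cluster label and one IS value per task-relevant patch.

```python
from collections import defaultdict
from typing import Dict, Hashable, Sequence, Tuple


def top_cluster_by_mean_is(labels: Sequence[Hashable],
                           scores: Sequence[float]) -> Tuple[Hashable, Dict[Hashable, float]]:
    """Return (best cluster label, mean IS per cluster)."""
    sums: Dict[Hashable, float] = defaultdict(float)
    counts: Dict[Hashable, int] = defaultdict(int)
    for lab, s in zip(labels, scores):
        sums[lab] += s
        counts[lab] += 1
    means = {lab: sums[lab] / counts[lab] for lab in sums}
    best = max(means, key=means.get)  # cluster whose patches are, on average, most task-relevant
    return best, means
```

Averaging rather than summing keeps large clusters from winning simply because they contain more patches.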

4 Conclusion

In this paper, we propose SCMIL, a method designed to effectively identify instances related to survival risk among numerous instances and to capture the interactions among instances within regions of interest. Moreover, our method synthesizes information from the cancer patient cohort to predict a more clinically meaningful survival probability distribution for each individual patient. Experimental results on two public WSI datasets demonstrate that our method achieves superior performance and richer interpretability compared to existing methods. In the future, we will extend our model to tasks such as predicting cancer recurrence and improve its efficiency.


4.0.1 Acknowledgements

This work is supported by the National Natural Science Foundation of China (62276250) and the National Key R&D Program of China (2022YFF1203303).

4.0.2 Disclosure of Interests

We have no competing interests to declare.

References

  • [1] Chen, R.J., Chen, C., Li, Y., Chen, T.Y., Trister, A.D., Krishnan, R.G., Mahmood, F.: Scaling vision transformers to gigapixel images via hierarchical self-supervised learning. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. pp. 16144–16155 (2022)
  • [2] Chen, R.J., Lu, M.Y., Shaban, M., Chen, C., Chen, T.Y., Williamson, D.F., Mahmood, F.: Whole slide images are 2D point clouds: Context-aware survival prediction using patch-based graph convolutional networks. In: Medical Image Computing and Computer Assisted Intervention–MICCAI 2021: 24th International Conference, Strasbourg, France, September 27–October 1, 2021, Proceedings, Part VIII 24. pp. 339–349. Springer (2021)
  • [3] Chen, R.J., Lu, M.Y., Williamson, D.F., Chen, T.Y., Lipkova, J., Noor, Z., Shaban, M., Shady, M., Williams, M., Joo, B., et al.: Pan-cancer integrative histology-genomic analysis via multimodal deep learning. Cancer Cell 40(8), 865–878 (2022)
  • [4] D’Aniello, C., Berretta, M., Cavaliere, C., Rossetti, S., Facchini, B.A., Iovane, G., Mollo, G., Capasso, M., Pepa, C.D., Pesce, L., et al.: Biomarkers of prognosis and efficacy of anti-angiogenic therapy in metastatic clear cell renal cancer. Frontiers in Oncology 9, 1400 (2019)
  • [5] Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., et al.: An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929 (2020)
  • [6] Haider, H., Hoehn, B., Davis, S., Greiner, R.: Effective ways to build and evaluate individual survival distributions. The Journal of Machine Learning Research 21(1), 3289–3351 (2020)
  • [7] Hamilton, W., Ying, Z., Leskovec, J.: Inductive representation learning on large graphs. Advances in Neural Information Processing Systems 30 (2017)
  • [8] Han, X., Goldstein, M., Ranganath, R.: Survival mixture density networks. In: Machine Learning for Healthcare Conference. pp. 224–248. PMLR (2022)
  • [9] He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. pp. 770–778 (2016)
  • [10] Hou, W., He, Y., Yao, B., Yu, L., Yu, R., Gao, F., Wang, L.: Multi-scope analysis driven hierarchical graph transformer for whole slide image based cancer survival prediction. In: International Conference on Medical Image Computing and Computer-Assisted Intervention. pp. 745–754. Springer (2023)
  • [11] Ilse, M., Tomczak, J., Welling, M.: Attention-based deep multiple instance learning. In: International Conference on Machine Learning. pp. 2127–2136. PMLR (2018)
  • [12] Kang, M., Song, H., Park, S., Yoo, D., Pereira, S.: Benchmarking self-supervised learning on diverse pathology datasets. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. pp. 3344–3354 (2023)
  • [13] Kato, T., Kameoka, S., Kimura, T., Nishikawa, T., Kobayashi, M.: The combination of angiogenesis and blood vessel invasion as a prognostic indicator in primary breast cancer. British Journal of Cancer 88(12), 1900–1908 (2003)
  • [14] Kipf, T.N., Welling, M.: Semi-supervised classification with graph convolutional networks. arXiv preprint arXiv:1609.02907 (2016)
  • [15] Li, B., Li, Y., Eliceiri, K.W.: Dual-stream multiple instance learning network for whole slide image classification with self-supervised contrastive learning. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. pp. 14318–14328 (2021)
  • [16] Lu, M.Y., Williamson, D.F., Chen, T.Y., Chen, R.J., Barbieri, M., Mahmood, F.: Data-efficient and weakly supervised computational pathology on whole-slide images. Nature Biomedical Engineering 5(6), 555–570 (2021)
  • [17] Raschka, S., Patterson, J., Nolet, C.: Machine learning in Python: Main developments and technology trends in data science, machine learning, and artificial intelligence. arXiv preprint arXiv:2002.04803 (2020)
  • [18] Shao, Z., Bian, H., Chen, Y., Wang, Y., Zhang, J., Ji, X., et al.: TransMIL: Transformer based correlated multiple instance learning for whole slide image classification. Advances in Neural Information Processing Systems 34, 2136–2147 (2021)
  • [19] Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N., Kaiser, Ł., Polosukhin, I.: Attention is all you need. Advances in Neural Information Processing Systems 30 (2017)
  • [20] Wang, P., Li, Y., Reddy, C.K.: Machine learning for survival analysis: A survey. ACM Computing Surveys (CSUR) 51(6), 1–36 (2019)
  • [21] Xiong, Y., Zeng, Z., Chakraborty, R., Tan, M., Fung, G., Li, Y., Singh, V.: Nyströmformer: A Nyström-based algorithm for approximating self-attention. In: Proceedings of the AAAI Conference on Artificial Intelligence. vol. 35, pp. 14138–14148 (2021)
  • [22] Xu, K., Hu, W., Leskovec, J., Jegelka, S.: How powerful are graph neural networks? arXiv preprint arXiv:1810.00826 (2018)
  • [23] Yao, J., Zhu, X., Huang, J.: Deep multi-instance learning for survival prediction from whole slide images. In: Medical Image Computing and Computer Assisted Intervention–MICCAI 2019: 22nd International Conference, Shenzhen, China, October 13–17, 2019, Proceedings, Part I 22. pp. 496–504. Springer (2019)
  • [24] Yao, J., Zhu, X., Jonnagaddala, J., Hawkins, N., Huang, J.: Whole slide images based cancer survival prediction using attention guided deep multiple instance learning networks. Medical Image Analysis 65, 101789 (2020)
  • [25] Zhu, X., Yao, J., Zhu, F., Huang, J.: WSISA: Making survival prediction from whole slide histopathological images. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. pp. 7234–7242 (2017)