Cross-modality Attention-based Multimodal Fusion for Non-small Cell Lung Cancer (NSCLC) Patient Survival Prediction
Abstract
Cancer prognosis and survival outcome predictions are crucial for estimating therapeutic response and for stratifying patients into treatment groups. Medical domains concerned with cancer prognosis are rich in data from multiple modalities, including pathological images and non-image data such as genomic information. Multimodal learning has shown potential to improve clinical prediction models by extracting and aggregating information from different modalities of the same subject, and it could outperform single-modality learning, thereby improving computer-aided diagnosis and prognosis in numerous medical applications. In this work, we propose a cross-modality attention-based multimodal fusion pipeline that integrates modality-specific knowledge for patient survival prediction in non-small cell lung cancer (NSCLC). Instead of merely concatenating or summing the features from different modalities, our method weighs the importance of each modality according to cross-modality relationships when fusing the multimodal features. Single-modality models achieved c-indices of 0.5772 and 0.5885 using only tissue image data or only RNA-seq data, respectively, whereas the proposed fusion approach achieved a c-index of 0.6587 in our experiments, showcasing its capability to assimilate modality-specific knowledge from varied modalities.
Keywords: Multimodal Learning, Multiple-instance Learning, Survival Prediction, Attention Mechanism
1 Description of purpose
Cancer prognosis and survival outcome predictions are crucial for forecasting therapeutic response and stratifying patients into distinct treatment groups. Previous works [1, 2, 3] have demonstrated that incorporating multiple data modalities into survival prediction models can strengthen their predictive capabilities, benefiting both clinical research and practice. However, these methods simply concatenate or sum the features from different modalities, without modeling inter-modality interactions during the fusion process. Recent attention-based methods [4, 5] show promising fusion performance by capturing the relationships between modalities. In this research, we aimed to predict the survival outcomes of patients with NSCLC by combining histopathology and genomics. We employed a cross-modality attention-based multimodal fusion (CM-MMF) approach that integrates image and RNA-seq modalities to improve patient survival prediction, showcasing the potential of assimilating modality-specific knowledge from varied sources. The attention scores derived from the fusion layer highlight the contribution of each modality during fusion for clinical diagnosis.
2 Method
2.1 Unimodal Embedding
The Attention Multiple Instance Learning (AMIL) module from the PORPOISE [6] pipeline is used to build the patient-level image representation from the tiles of a whole slide image (WSI). Each tile is first encoded as a 1024-channel latent feature vector by an ImageNet-pretrained ResNet-50 model, and AMIL maps these tile features into a 128-channel feature vector. The AMIL [4] module computes an attention score for each tile based on its relevance to patient-level prognostic prediction, which allows it to emphasize pivotal tiles when aggregating the patient-level image feature representation. A self-normalizing neural network (SNN) [7] was employed to transform the RNA-seq data into a 128-channel omic feature representation. SNN was chosen because of its demonstrated superior performance in the unimodal setting with sequencing data.
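The two unimodal encoders can be illustrated with the minimal pure-Python sketch below. It assumes the non-gated attention pooling of Ilse et al. [4] for AMIL and plain linear + SELU layers for the SNN; the PORPOISE implementation may differ (for example, gated attention, dropout, or extra fully connected layers). The projection `proj`, the attention parameters `V` and `w`, the SNN layer weights, and all toy dimensions are hypothetical stand-ins for learned weights and for the 1024 → 128 and genes → 128 mappings.

```python
import math
from typing import List, Tuple

# Illustrative sketch only: helper names, weights, and sizes are hypothetical.
Vector = List[float]
Matrix = List[List[float]]

# SELU constants from Klambauer et al. (2017).
SELU_ALPHA = 1.6732632423543772
SELU_SCALE = 1.0507009873554805


def matvec(m: Matrix, x: Vector) -> Vector:
    """Multiply a (rows x cols) matrix by a length-cols vector."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in m]


def softmax(scores: Vector) -> Vector:
    top = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def amil_pool(tiles: List[Vector], proj: Matrix, V: Matrix, w: Vector) -> Tuple[Vector, Vector]:
    """Attention-MIL pooling: a_k = softmax_k(w^T tanh(V h_k)), z = sum_k a_k h_k."""
    # Project each tile feature (e.g. 1024-d) to the embedding size (e.g. 128-d), then ReLU.
    h = [[max(0.0, v) for v in matvec(proj, t)] for t in tiles]
    scores = [sum(wi * math.tanh(vi) for wi, vi in zip(w, matvec(V, hk))) for hk in h]
    attn = softmax(scores)
    pooled = [sum(a * hk[d] for a, hk in zip(attn, h)) for d in range(len(h[0]))]
    return pooled, attn


def selu(x: Vector) -> Vector:
    return [SELU_SCALE * (v if v > 0 else SELU_ALPHA * (math.exp(v) - 1.0)) for v in x]


def snn_encode(genes: Vector, layers: List[Matrix]) -> Vector:
    """Stack of linear + SELU layers (alpha-dropout omitted for brevity)."""
    out = genes
    for m in layers:
        out = selu(matvec(m, out))
    return out


# Toy usage with hypothetical weights: 3 tiles with 4-d features projected to 2-d.
tiles = [[1.0, 0.0, 2.0, 1.0], [0.5, 1.0, 0.0, 0.0], [2.0, 2.0, 1.0, 0.0]]
proj = [[0.1, 0.2, 0.0, 0.3], [0.0, 0.1, 0.4, 0.1]]
V = [[1.0, -1.0], [0.5, 0.5]]
w = [1.0, -0.5]
image_embedding, tile_attention = amil_pool(tiles, proj, V, w)

genes = [0.3, -1.2, 0.8]
rna_embedding = snn_encode(genes, [[[0.2, 0.1, -0.3], [0.5, 0.0, 0.4]], [[1.0, -1.0], [0.3, 0.3]]])
```

The per-tile weights in `tile_attention` sum to one, so the patient-level embedding is a convex combination of the projected tile features.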
2.2 Cross-modality Attention-based Multimodal Fusion
After the feature representations carrying modality-specific knowledge are obtained from the two modalities, they are normalized as a preprocessing step prior to fusion. Inspired by an attention-based architecture for multi-scale disease classification [8], Cross-modality Attention-based Multimodal Fusion (CM-MMF) is introduced to weigh the significance of each modality for survival prediction. The CM-MMF comprises two fully convolutional layers with a kernel size of 1×1, accompanied by a Tanh activation function. The kernel weights are shared across modalities, promoting holistic learning of the importance of modality-specific knowledge through cross-modality relationships. The cross-modality attention score $a_m$ of modality $m$ can be expressed as:
$$a_m = \frac{\exp\{\mathbf{w}^{\top}\tanh(\mathbf{V}\mathbf{h}_m)\}}{\sum_{j=1}^{M}\exp\{\mathbf{w}^{\top}\tanh(\mathbf{V}\mathbf{h}_j)\}} \tag{1}$$
Here, $\mathbf{w} \in \mathbb{R}^{L \times 1}$ and $\mathbf{V} \in \mathbb{R}^{L \times D}$ are trainable parameters of the CM-MMF, with $D$ representing the size of the unimodal embedding output $\mathbf{h}_m$. Additionally, $L$ is the number of output channels of the first CM-MMF layer, $\tanh(\cdot)$ denotes the element-wise hyperbolic tangent non-linearity, and $M$ signifies the number of modalities in the dataset.
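Equation (1) can be sketched in pure Python as follows. This assumes bias terms are omitted and that the two 1×1 convolutions act as the shared linear maps $\mathbf{V}$ and $\mathbf{w}$ applied to each modality's $D$-dimensional embedding (a 1×1 convolution over the modality axis applies the same weights at every position). The function name, toy sizes ($D = 3$, $L = 2$), and weight values are hypothetical.

```python
import math
from typing import List

# Illustrative sketch of Eq. (1); names and weights are hypothetical.
Vector = List[float]
Matrix = List[List[float]]


def cross_modality_attention(embeddings: List[Vector], V: Matrix, w: Vector) -> Vector:
    """Return a_m = exp(s_m) / sum_j exp(s_j), with s_m = w^T tanh(V h_m).

    The same V (L x D) and w (length L) are applied to every modality h_m,
    mirroring the shared-kernel design; biases are omitted (an assumption).
    """
    scores = []
    for h in embeddings:
        hidden = [math.tanh(sum(vij * hj for vij, hj in zip(row, h))) for row in V]
        scores.append(sum(wi * zi for wi, zi in zip(w, hidden)))
    top = max(scores)  # stabilise the softmax over the M modalities
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Toy usage with M = 2 modalities whose embeddings are already scaled to [0, 1].
h_image = [0.2, 0.9, 0.4]
h_rna = [0.7, 0.1, 1.0]
V = [[0.5, -0.3, 0.8], [-0.2, 0.6, 0.1]]
w = [1.2, -0.7]
a_image, a_rna = cross_modality_attention([h_image, h_rna], V, w)
```

Because the softmax runs over modalities rather than tiles, the two scores always sum to one, and two identical embeddings receive equal weight under shared kernels.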
2.3 Multimodal Survival Prediction
The cross-modality attention scores $a_m$ are subsequently multiplied with the corresponding modality features $\mathbf{h}_m$ to yield a unified cross-modality representation $\mathbf{h}_{\mathrm{fused}}$, as given in Equation 2:
$$\mathbf{h}_{\mathrm{fused}} = \sum_{m=1}^{M} a_m \mathbf{h}_m \tag{2}$$
For the final prediction, a one-layer classifier adapted from PORPOISE [6] takes the cross-modality embedding $\mathbf{h}_{\mathrm{fused}}$ as input and produces patient-wise survival predictions.
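The fusion in Equation (2) followed by a discrete-time survival head can be sketched as below. We assume the one-layer classifier outputs one logit per survival bin, hazards are obtained with a sigmoid, the survival function is the cumulative product of (1 − hazard), and the negative sum of survival probabilities serves as the patient risk score; this reflects our reading of the PORPOISE codebase rather than a verified copy of it. Function names and weights are hypothetical.

```python
import math
from typing import List, Tuple

# Illustrative sketch; function names and weights are hypothetical.
Vector = List[float]


def fuse(embeddings: List[Vector], attention: Vector) -> Vector:
    """Equation (2): attention-weighted sum of the modality embeddings."""
    dim = len(embeddings[0])
    return [sum(a * h[d] for a, h in zip(attention, embeddings)) for d in range(dim)]


def survival_head(fused: Vector, W: List[Vector], b: Vector) -> Tuple[Vector, Vector, float]:
    """One linear layer -> per-bin hazards -> survival curve -> scalar risk."""
    logits = [sum(wi * xi for wi, xi in zip(row, fused)) + bi for row, bi in zip(W, b)]
    hazards = [1.0 / (1.0 + math.exp(-z)) for z in logits]
    survival, s = [], 1.0
    for h in hazards:
        s *= 1.0 - h  # probability of surviving through this bin
        survival.append(s)
    risk = -sum(survival)  # higher risk = lower expected survival
    return hazards, survival, risk


# Toy usage: two 2-d modality embeddings, four survival bins, all-zero weights.
fused = fuse([[1.0, 0.0], [0.0, 1.0]], [0.25, 0.75])
W = [[0.0, 0.0]] * 4
b = [0.0] * 4
hazards, survival, risk = survival_head(fused, W, b)
```

With all-zero weights every hazard is 0.5, so the survival curve halves at each bin; the scalar `risk` is what a c-index computation would consume.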
3 Data & Experiments
3.1 Data
In this study, we used data from patients who received atezolizumab plus carboplatin plus paclitaxel (also termed ARM-A) as first-line treatment for metastatic nonsquamous NSCLC in the IMpower150 study [9], a phase 3 clinical trial that evaluated the efficacy of adding PD-L1-targeted treatment, compared with the current standard of care, in NSCLC. To develop our multimodal framework, we worked with anonymized histopathology images from 270 patients alongside bulk RNA-seq data.
Image Data. H&E-stained WSIs were scanned at 20× (0.5 µm/pixel) with an Aperio scanner. For tile filtering, we first applied our in-house pretrained U-Net model to classify regions into tumor and stroma. Tiles of 512 × 512 pixels were then extracted from these regions and embedded into 1024-channel feature vectors using the CLAM [10] feature extraction pipeline, which uses an ImageNet-pretrained ResNet-50 model.
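The tile filtering step can be illustrated with the sketch below. It assumes a hypothetical downsampled binary mask produced by the segmentation model (1 = tumor or stroma), non-overlapping tiles, and a hypothetical minimum tumor/stroma coverage `min_frac`; the text does not specify the tiling stride or a coverage threshold, and the CLAM feature extraction itself is not reproduced.

```python
from typing import List, Tuple

# Illustrative sketch; `select_tiles`, `cell`, and `min_frac` are hypothetical.


def select_tiles(
    mask: List[List[int]], cell: int, tile: int = 512, min_frac: float = 0.5
) -> List[Tuple[int, int]]:
    """Return level-0 (x, y) top-left corners of non-overlapping tiles.

    mask: downsampled segmentation map, 1 = tumor or stroma, 0 = other
    cell: number of 20x pixels covered by one mask cell (must divide `tile`)
    min_frac: minimum fraction of tumor/stroma cells required to keep a tile
    """
    if tile % cell != 0:
        raise ValueError("tile size must be a multiple of the mask cell size")
    k = tile // cell  # mask cells along one tile edge
    rows, cols = len(mask), len(mask[0])
    coords = []
    for r0 in range(0, rows - k + 1, k):
        for c0 in range(0, cols - k + 1, k):
            hits = sum(mask[r][c] for r in range(r0, r0 + k) for c in range(c0, c0 + k))
            if hits / (k * k) >= min_frac:
                coords.append((c0 * cell, r0 * cell))
    return coords


# Toy usage: a 4 x 4 mask where each cell covers 256 x 256 pixels.
mask = [[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 0]]
coords = select_tiles(mask, cell=256)
```

Lowering `min_frac` admits tiles that only partially overlap tumor or stroma.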
RNA-seq Data. RNA-seq data, containing gene expression values along with Ensembl gene IDs, were obtained from the CID CIT Data MART. Given the relatively small number of patients, 154 of the approximately 19K available genes were preselected by fitting an elastic-net Cox model [11] with 10-fold cross-validation on the same dataset.
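For reference, the objective minimized by a standard elastic-net Cox fit can be sketched as follows: the negative log partial likelihood (Breslow handling of ties) averaged over patients, plus the penalty $\lambda[\alpha\lVert\beta\rVert_1 + \tfrac{1-\alpha}{2}\lVert\beta\rVert_2^2]$, following the glmnet convention. The $\lambda$ and $\alpha$ used for gene selection are not reported; in a typical workflow the retained genes are those with nonzero coefficients at the cross-validated $\lambda$. The function name is hypothetical, and the path algorithm of [11] is not reproduced.

```python
import math
from typing import List, Sequence

# Illustrative sketch; `cox_elastic_net_objective` is a hypothetical helper.


def cox_elastic_net_objective(
    X: List[List[float]],
    times: Sequence[float],
    events: Sequence[int],
    beta: Sequence[float],
    lam: float,
    alpha: float,
) -> float:
    """Mean negative log partial likelihood plus an elastic-net penalty."""
    n = len(times)
    eta = [sum(b * x for b, x in zip(beta, row)) for row in X]  # linear predictors
    nll = 0.0
    for i in range(n):
        if events[i]:  # only observed events contribute a term
            risk_set = sum(math.exp(eta[j]) for j in range(n) if times[j] >= times[i])
            nll -= eta[i] - math.log(risk_set)
    nll /= n
    l1 = sum(abs(b) for b in beta)
    l2 = sum(b * b for b in beta)
    return nll + lam * (alpha * l1 + 0.5 * (1.0 - alpha) * l2)


# Toy usage: three patients, one gene, the third patient censored.
X = [[0.0], [0.0], [0.0]]
times = [1.0, 2.0, 3.0]
events = [1, 1, 0]
obj_null = cox_elastic_net_objective(X, times, events, [0.0], lam=1.0, alpha=0.5)
obj_pen = cox_elastic_net_objective(X, times, events, [0.5], lam=1.0, alpha=1.0)
```

With $\alpha = 1$ the penalty reduces to the lasso, which drives many gene coefficients exactly to zero; smaller $\alpha$ keeps correlated genes together.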
3.2 Loss Function
The survival loss function from PORPOISE [6] is used to optimize the output of the fusion architecture. The continuous overall survival time (in days) is partitioned into four non-overlapping bins, and the negative log-likelihood (NLL) survival loss supervises training as a classification task, using both the censorship status and the 4-bin interval labels.
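A minimal sketch of the discrete-time NLL survival loss for a single patient is given below. It assumes `hazards` are per-bin hazard probabilities (as from a sigmoid output head), `censored` is 1 for censored patients, and bins are indexed 0 to 3. The cut points passed to `assign_bin` are hypothetical (for example, quantiles of the training survival times), and any additional weighting between censored and uncensored terms that the reference implementation may apply is omitted.

```python
import math
from bisect import bisect_right
from typing import Sequence

# Illustrative sketch; `assign_bin` and `nll_survival_loss` are hypothetical helpers.


def assign_bin(time_days: float, cut_points: Sequence[float]) -> int:
    """Map a survival time to one of len(cut_points) + 1 non-overlapping bins."""
    return bisect_right(cut_points, time_days)


def nll_survival_loss(hazards: Sequence[float], bin_idx: int, censored: int, eps: float = 1e-7) -> float:
    # s_padded[k] = probability of surviving bins 0..k-1; s_padded[0] = 1.
    s_padded = [1.0]
    for h in hazards:
        s_padded.append(s_padded[-1] * (1.0 - h))
    if censored:
        # Censored: the patient is only known to survive through bin `bin_idx`.
        return -math.log(max(s_padded[bin_idx + 1], eps))
    # Event: survive all earlier bins, then experience the event in bin `bin_idx`.
    return -(math.log(max(s_padded[bin_idx], eps)) + math.log(max(hazards[bin_idx], eps)))


# Toy usage with four bins defined by three hypothetical cut points (days).
cuts = [200.0, 400.0, 800.0]
y = assign_bin(450.0, cuts)  # -> 2
loss_event = nll_survival_loss([0.5, 0.5, 0.5, 0.5], bin_idx=1, censored=0)
loss_censored = nll_survival_loss([0.5, 0.5, 0.5, 0.5], bin_idx=1, censored=1)
```

Censored patients contribute only a survival term, so they still inform training without requiring an observed event time.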
3.3 Evaluation Metrics
We assess survival prediction using the concordance index (c-index), where higher values indicate better performance. The c-index measures the proportion of all comparable pairs of patients for which the model's predicted risks correctly order the observed survival times. All results for the baseline methods and our proposed method are mean c-index values computed with 5-fold cross-validation on the same data splits.
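A minimal sketch of Harrell's c-index follows. It assumes a pair (i, j) is comparable when patient i had an observed event and patient j survived strictly longer; tied risk scores count as half-concordant, and pairs with tied survival times are skipped for simplicity. Library implementations (e.g. lifelines or scikit-survival) treat ties and censoring in more detail, so small numerical differences are possible. The function name is hypothetical.

```python
from typing import Sequence

# Illustrative sketch; `concordance_index` here is a hypothetical helper.


def concordance_index(times: Sequence[float], events: Sequence[int], risks: Sequence[float]) -> float:
    """Fraction of comparable pairs whose predicted risks are correctly ordered."""
    concordant = 0.0
    comparable = 0
    n = len(times)
    for i in range(n):
        if not events[i]:
            continue  # a censored patient cannot anchor a comparable pair
        for j in range(n):
            if times[j] > times[i]:
                comparable += 1
                if risks[i] > risks[j]:
                    concordant += 1.0  # earlier event was assigned higher risk
                elif risks[i] == risks[j]:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


# Toy usage: four patients, the last one censored.
times = [1.0, 2.0, 3.0, 4.0]
events = [1, 1, 1, 0]
perfect = concordance_index(times, events, [4.0, 3.0, 2.0, 1.0])
```

A c-index of 0.5 corresponds to random ordering, which puts the reported unimodal values of about 0.58 to 0.59 and the fused 0.6587 in context.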
3.4 Experiment Details
To improve training robustness, Gaussian noise was added to the image and RNA-seq features before they were fed into the model. All models were trained for 55 epochs with a learning rate of 0.01 and a batch size of 1 using the Adam optimizer. For each patient, all tiles from stroma and tumor regions were used for training with the survival loss. The RNA-seq features were standardized, and the feature vectors of all modalities were normalized to the range [0, 1] before being passed to the fusion architecture. All remaining settings followed the official PORPOISE pipeline (https://github.com/mahmoodlab/PORPOISE).
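The preprocessing described above can be sketched as follows. We assume per-gene z-score standardization across patients, additive zero-mean Gaussian noise with a hypothetical standard deviation `sigma` (the value used is not reported), and per-vector min-max scaling to [0, 1]; whether normalization was applied per vector or per feature is not stated, so this is one plausible reading. All function names are hypothetical.

```python
import random
import statistics
from typing import List

# Illustrative sketch; function names and sigma are hypothetical.
Vector = List[float]


def standardize_columns(rows: List[Vector]) -> List[Vector]:
    """Z-score each column (gene) across patients; constant columns map to 0."""
    cols = list(zip(*rows))
    means = [statistics.fmean(c) for c in cols]
    stds = [statistics.pstdev(c) or 1.0 for c in cols]  # avoid division by zero
    return [[(v - m) / s for v, m, s in zip(row, means, stds)] for row in rows]


def add_gaussian_noise(vec: Vector, sigma: float, rng: random.Random) -> Vector:
    """Training-time augmentation: add i.i.d. N(0, sigma^2) noise to each entry."""
    return [v + rng.gauss(0.0, sigma) for v in vec]


def min_max(vec: Vector) -> Vector:
    """Rescale a feature vector to [0, 1]; a constant vector maps to zeros."""
    lo, hi = min(vec), max(vec)
    if hi == lo:
        return [0.0] * len(vec)
    return [(v - lo) / (hi - lo) for v in vec]


# Toy usage: two patients x two genes, then noise and scaling for one patient.
rna = standardize_columns([[1.0, 10.0], [3.0, 10.0]])
rng = random.Random(0)
noisy = add_gaussian_noise(rna[0], sigma=0.1, rng=rng)  # sigma is hypothetical
scaled = min_max(noisy)
```

Noise is applied only during training; at evaluation time the features would be passed through standardization and scaling without augmentation.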
4 Results
AMIL [6] and a custom AMIL with additional batch-normalization layers and ReLU activations were implemented as image unimodal encoders, while two deep learning networks, including SNN [7], were deployed as RNA-seq unimodal encoders. AMIL and SNN were selected as the backbones for all fusion designs because they performed better in unimodal training. Six existing fusion approaches [1, 2, 3, 4, 5], ranging from simple concatenation to attention-based architectures, were implemented to compare their multimodal fusion capabilities with those of the proposed method.
4.1 Multimodal-fusion Results
As shown in Table 1, most multimodal fusion designs outperformed unimodal learning, demonstrating the benefit of fusing modality-specific knowledge from different modalities. The proposed CM-MMF achieved the best fusion performance under survival-loss supervision, showcasing the value of multimodal fusion that considers cross-modality relationships. Directly using the raw RNA-seq data in the fusion stage (Raw-concatenation) yields consistently superior performance, which can be attributed to the RNA-seq modality being the primary contributor.
4.2 Ablation Study
Inspired by [12] and [4], we explored several attention mechanism designs with different activation functions and evaluated them on the NSCLC dataset. We implemented the CM-MMF with two weight strategies, differing in whether the kernel weights are shared across modalities when learning the embedding features, each combined with either a ReLU or a Tanh activation. As shown by the survival prediction performance in Table 2, sharing the kernel weights in the CM-MMF with a Tanh activation function achieved the highest mean c-index.
Strategy | Modality Attention Layer (kernel weights) | Activation Function | C-index |
1 | Non-sharing | ReLU | 0.6416 ± 0.0181 |
2 | Sharing | ReLU | 0.5735 ± 0.0114 |
3 | Non-sharing | Tanh | 0.6329 ± 0.0356 |
4* | Sharing | Tanh | 0.6587 ± 0.0266 |
*: The proposed design, which achieved the best survival prediction performance on the NSCLC dataset.
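To make the ablation concrete, the sketch below computes modality attention under the four strategies in Table 2: one shared (V, w) pair versus a separate pair per modality, each with a ReLU or Tanh hidden activation. It follows the form of Equation (1) with the activation swapped; the function name and all parameter values are hypothetical.

```python
import math
from typing import Callable, Dict, List, Tuple

# Illustrative sketch; names and weights are hypothetical.
Vector = List[float]
Matrix = List[List[float]]
Params = Tuple[Matrix, Vector]  # (V, w) for one attention branch

ACTIVATIONS: Dict[str, Callable[[float], float]] = {
    "relu": lambda x: max(0.0, x),
    "tanh": math.tanh,
}


def modality_attention(embeddings: List[Vector], params: List[Params], activation: str, shared: bool) -> Vector:
    """Softmax over modalities of w^T act(V h_m).

    shared=True: params holds a single (V, w) reused for every modality.
    shared=False: params holds one (V, w) per modality (non-sharing strategy).
    """
    act = ACTIVATIONS[activation]
    scores = []
    for m, h in enumerate(embeddings):
        V, w = params[0] if shared else params[m]
        hidden = [act(sum(v * x for v, x in zip(row, h))) for row in V]
        scores.append(sum(wi * zi for wi, zi in zip(w, hidden)))
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Identical toy embeddings make the effect of weight sharing visible.
h = [0.6, 0.2]
p_a = ([[1.0, 0.5], [-0.5, 1.0]], [0.8, -0.3])
p_b = ([[0.2, -1.0], [1.0, 0.3]], [-0.6, 0.9])
results = {
    (act_name, shared): modality_attention([h, h], [p_a] if shared else [p_a, p_b], act_name, shared)
    for act_name in ("relu", "tanh")
    for shared in (True, False)
}
```

With shared kernels, modality weights can differ only through the embeddings themselves, whereas non-sharing kernels can also learn a per-modality bias in importance.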
4.3 Limitations
In this study, we relied solely on one in-house dataset to assess performance and draw observations. Additional data should be evaluated to obtain a more generalizable view of fusion performance. Meanwhile, other loss functions (e.g., the Cox loss) and performance metrics should be explored to evaluate supervision across different modalities more comprehensively. We deployed only a few backbones for the unimodal data, which might not be optimal for unimodal representation; implementing more embedding backbones could provide further insight into the capability of each unimodal learning method.
5 New or Breakthrough Work to be Presented
In this study, we propose a cross-modality attention-based multimodal fusion architecture (CM-MMF) to integrate knowledge from WSI and RNA-seq for enhanced lung cancer survival prediction. Within the fusion design, our method assesses the importance of each modality for feature fusion, considering cross-modality relationships when amalgamating the multimodal features. This approach achieved the highest c-index (0.6587) in our experiments.
6 Conclusion
The proposed cross-modality attention-based multimodal fusion (CM-MMF) method outperformed other fusion designs and unimodal learning methods in this study. This underscores its capability to integrate modality-specific knowledge from various sources and highlights the value of multimodal fusion that takes cross-modality relationships into account. The attention scores from the fusion layer enable us to illustrate the significance of each modality for diagnosis, and the instance-level attention from AMIL can indicate the contribution of each image tile. We will investigate these attention patterns in detail in future work.
7 ACKNOWLEDGMENTS
The work has not been submitted for publication or presentation elsewhere. We thank Jennifer Giltnane and Raghavan Venugopa for their support in providing the tissue image segmentation algorithm to facilitate tile extraction. We also thank the following individuals for their expertise and assistance throughout all aspects of this study and for their help leading to this manuscript: Ravi Kamble, Kamalakar Kodali, Qiangqiang Gu, Auranuch (Ney) Lorsakul, Xingwei Wang, and Marghoob Mohiyuddin.
References
- [1] Mobadersany, P., Yousefi, S., Amgad, M., Gutman, D. A., Barnholtz-Sloan, J. S., Velázquez Vega, J. E., Brat, D. J., and Cooper, L. A., “Predicting cancer outcomes from histology and genomics using convolutional networks,” Proceedings of the National Academy of Sciences 115(13), E2970–E2979 (2018).
- [2] Cheerla, A. and Gevaert, O., “Deep learning with multimodal representation for pancancer prognosis prediction,” Bioinformatics 35(14), i446–i454 (2019).
- [3] Chen, R. J., Lu, M. Y., Weng, W.-H., Chen, T. Y., Williamson, D. F., Manz, T., Shady, M., and Mahmood, F., “Multimodal co-attention transformer for survival prediction in gigapixel whole slide images,” in [Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV) ], 4015–4025 (October 2021).
- [4] Ilse, M., Tomczak, J., and Welling, M., “Attention-based deep multiple instance learning,” in [International conference on machine learning ], 2127–2136, PMLR (2018).
- [5] Jaume, G., Vaidya, A., Chen, R., Williamson, D., Liang, P., and Mahmood, F., “Modeling dense multimodal interactions between biological pathways and histology for survival prediction,” arXiv preprint arXiv:2304.06819 (2023).
- [6] Chen, R. J., Lu, M. Y., Williamson, D. F., Chen, T. Y., Lipkova, J., Noor, Z., Shaban, M., Shady, M., Williams, M., Joo, B., et al., “Pan-cancer integrative histology-genomic analysis via multimodal deep learning,” Cancer Cell 40(8), 865–878 (2022).
- [7] Klambauer, G., Unterthiner, T., Mayr, A., and Hochreiter, S., “Self-normalizing neural networks,” Advances in neural information processing systems 30 (2017).
- [8] Deng, R., Cui, C., Remedios, L. W., Bao, S., Womick, R. M., Chiron, S., Li, J., Roland, J. T., Lau, K. S., Liu, Q., et al., “Cross-scale attention guided multi-instance learning for Crohn’s disease diagnosis with pathological images,” in [International Workshop on Multiscale Multimodal Medical Imaging ], 24–33, Springer (2022).
- [9] Socinski, M. A., Jotte, R. M., Cappuzzo, F., Orlandi, F., Stroyakovskiy, D., Nogami, N., Rodríguez-Abreu, D., Moro-Sibilot, D., Thomas, C. A., Barlesi, F., et al., “Atezolizumab for first-line treatment of metastatic nonsquamous NSCLC,” New England Journal of Medicine 378(24), 2288–2301 (2018).
- [10] Lu, M. Y., Williamson, D. F., Chen, T. Y., Chen, R. J., Barbieri, M., and Mahmood, F., “Data-efficient and weakly supervised computational pathology on whole-slide images,” Nature Biomedical Engineering 5(6), 555–570 (2021).
- [11] Wu, Y., “Elastic net for Cox’s proportional hazards model with a solution path algorithm,” Statistica Sinica 22, 27 (2012).
- [12] Yao, J., Zhu, X., Jonnagaddala, J., Hawkins, N., and Huang, J., “Whole slide images based cancer survival prediction using attention guided deep multiple instance learning networks,” Medical Image Analysis 65, 101789 (2020).