
License: CC BY-SA 4.0
arXiv:2403.11375v1 [cs.CV] 18 Mar 2024

Path-GPTOmic: A Balanced Multi-modal Learning Framework for Survival Outcome Prediction

Abstract

For predicting cancer survival outcomes, standard approaches in clinical research are often based on two main modalities: pathology images for observing cell morphology features, and genomics (e.g., bulk RNA-seq) for quantifying gene expression. However, existing pathology-genomic multi-modal algorithms face two significant challenges: (1) valuable biological insights regarding genes and gene-gene interactions are frequently overlooked; (2) one modality often dominates the optimization process, leaving the other modality inadequately trained. In this paper, we introduce a new multi-modal “Path-GPTOmic” framework for cancer survival outcome prediction. First, to extract valuable biological insights, we regulate the embedding space of a foundation model, scGPT, which was initially trained on single-cell RNA-seq data, making it adaptable to bulk RNA-seq data. Second, to address the imbalance between modalities, we propose a gradient modulation mechanism tailored to the Cox partial likelihood loss for survival prediction. The contributions of the modalities are dynamically monitored and adjusted during training, encouraging both modalities to be sufficiently trained. Evaluated on two TCGA (The Cancer Genome Atlas) datasets, our model achieves improved survival prediction accuracy (C-Index) over the compared baselines.

Index Terms—  Multi-modal, Survival outcome prediction, Pathology images, Genomics.

1 Introduction

Pathology images and genomic assays are the two main data sources for predicting the survival outcome of cancer patients. Pathology images contain information on cell mitosis, cell morphology, and the micro-environment. Genomic data from RNA-seq measure the abundance of RNA transcripts, providing critical biological insights into cell identity, cellular activity, stage of development and differentiation, and cell functionality [1]. In clinical practice, bulk RNA-seq, which measures the average global gene expression across cells, is more cost-effective than single-cell RNA-seq and is widely employed for analyzing cancer initiation and progression, as well as for predicting survival outcomes [2].

To jointly utilize these two complementary sources for automatic cancer survival prediction, researchers have developed a series of multi-modal deep learning algorithms. For example, in [3, 4], genomic data were processed with self-normalizing networks (SNN) and fused with pathology image embeddings via the Kronecker product. Transformer models have been employed to capture genotype-phenotype interactions through an attention mechanism [5]. Other work projects the different modalities into a shared latent space to reduce the distances between multi-modal embeddings of the same patient [6, 7].

However, these methods still face two major challenges. First, with limited training data, SNN or Transformer models cannot fully exploit the biological information in genomic data, and external knowledge about human cells is not used to compute more accurate genomic embeddings. Second, researchers have found that the dominant modality (the one with better performance) may suppress the training of the other modality [8, 9]. Consequently, the suppressed modality may not generalize well to test data.

In this paper, we propose a new multi-modal Path-GPTOmic framework, which combines pathology images and genomic data from patient specimens to predict cancer survival outcomes. To address the first challenge, we seek to learn a smooth latent space for bulk RNA-seq embeddings by incorporating a foundation model, scGPT [1]. This model was originally trained on single-cell RNA-seq data from large human cell atlases [10], and has demonstrated superior performance in cell functionality analysis (e.g., cell type annotation and genetic perturbation prediction). However, in our preliminary experiments, we observed that directly applying scGPT to bulk RNA-seq improved downstream performance only marginally. One possible reason is that direct application maps bulk RNA-seq into a non-smooth latent space, where distances between embedding vectors do not accurately reflect RNA similarities. To address this issue, we adopt practices from generative modeling [11, 12, 13] and smooth the latent space with mix-up regulation. Specifically, we append a smoothing module (i.e., a multi-layer perceptron (MLP) network) after scGPT. We then train it on simulated bulk RNA-seq profiles obtained by mixing individual single-cell RNA expression profiles [2]. This approach yields an interpolatable latent space for gene expression, enhancing generalizability to bulk RNA-seq data.

To tackle the second challenge, we combine the genomic and image branches of the model and closely monitor the contribution of each modality during training. We find that the genomic branch outperforms the image branch and therefore contributes less to the Cox partial likelihood loss [14]. As a result, the genomic branch dominates the training process, leaving the image branch under-optimized. To alleviate this issue in survival outcome prediction, we control the loss optimization by dynamically adjusting the gradient. Specifically, we assess the contributions of the two branches and scale down the gradient of the dominant branch, giving the under-optimized image branch room to train.

Our contributions are threefold. (1) Bulk RNA-seq data are more cost-effective and easier to acquire in clinical research; to the best of our knowledge, we are the first to extend the single-cell foundation model scGPT to process bulk RNA-seq data of patient samples. (2) We take pioneering steps to address the training imbalance problem in multi-modal pathology-genomics fusion. (3) We evaluate our method on two TCGA datasets and obtain improved performance over the baselines.

2 Method

Fig. 1 gives an overview of our method. In Section 2.1, we show how to smooth the scGPT-derived single-cell RNA-seq embedding space so that scGPT can be adapted to produce bulk RNA-seq embeddings. In Section 2.2, we describe how to balance the contributions of the genomics branch and the image branch when optimizing the Cox partial likelihood loss, encouraging both modalities to be well optimized.


Fig. 1: Illustration of our Path-GPTOmic pipeline. First, we train MLP-A to regulate the latent space for bulk RNA-seq embeddings, with the scGPT parameters frozen. Second, we train the SNN and MLP-B for genomics embedding and the Image Encoder (T2T-ViT [15]) for embedding pathology image features, with both the scGPT and MLP-A parameters frozen.

2.1 Regulating the Genomics Embedding Space

Cancer survival outcome prediction often requires profiling transcript abundance through RNA-seq, including both single-cell RNA-seq (scRNAseq) and bulk RNA-seq (average global gene expression). On the one hand, scRNAseq provides fine-grained, cell-level evidence for cancer. Trained as a Transformer on large single-cell datasets, scGPT has learned gene-gene interactions and achieves superior performance on various downstream tasks. On the other hand, bulk RNA-seq is more cost-effective and widely used in clinical research. Hence, it is desirable to extend scGPT to bulk RNA-seq data for survival outcome prediction.

Our preliminary results show that directly applying scGPT to bulk RNA-seq data is not significantly beneficial for downstream tasks (see Section 3). One possible reason is that, because scGPT was not trained on sufficient averaged (bulk-like) RNA-seq input, the intermediate interpolation regions of its latent space are abrupt and do not generalize well to globally averaged data (i.e., bulk RNA-seq). A similar phenomenon has been observed in generation tasks. For example, as shown in [16, 17], a model $f(\cdot)$ that performs well on inputs $x_1, x_2$ may collapse and yield artifacts for the interpolated input $\lambda x_1 + (1-\lambda) x_2$, with $\lambda \in [0, 1]$.

To apply scGPT to bulk RNA-seq, we regulate its latent space using Mixup-based regulation [11, 18], encouraging “average” RNA-seq input to yield reasonable embeddings that benefit downstream tasks. We choose cell type annotation [1] as the probing task, as it is closely related to survival outcomes. Our pipeline is outlined in the top part of Fig. 1. Two randomly selected scRNAseq samples are interpolated by a mix layer and then processed sequentially by a pre-trained scGPT module, a multi-layer perceptron (MLP) network (called MLP-A), and a linear layer serving as the classifier. By encouraging the model output to match the correspondingly mixed cell type labels, we expect the embedding space (i.e., the output of MLP-A) to be regulated to represent averaged cell information.

Our model is pre-trained as follows. First, in each optimization step, we randomly select two scRNAseq samples from the human cell scRNAseq dataset assembled by scGPT. Formally, we denote the input scRNAseq samples as $x_i, x_j$, with one-hot cell type label encodings $\mathbf{t}_i, \mathbf{t}_j$, respectively. Second, we simulate bulk RNA-seq by interpolating $x_i$ and $x_j$ as $\lambda x_i + (1-\lambda) x_j$, where $\lambda$ is a scalar sampled uniformly at random from $[0, 1]$. Third, to avoid excessively perturbing the well-trained scGPT model, we freeze the scGPT parameters and append a three-layer MLP-A to regulate the scGPT output. We then optimize the output toward the regression target $\lambda \mathbf{t}_i + (1-\lambda) \mathbf{t}_j$. In this way, distances in the latent space can reflect the weighted contributions of the RNA-seq profiles of the two input samples.
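The mixing step can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not the authors' implementation: each expression profile is treated as a plain list of floats (scGPT's own input tokenization is not modeled), the helper names `mix_pseudo_bulk`, `one_hot`, and `mse` are hypothetical, and mean squared error is assumed as the regression loss because the text specifies a regression target but not the loss function. The frozen scGPT, MLP-A, and linear-layer forward pass that would produce the prediction is left out.

```python
import random
from typing import List, Optional, Tuple


def one_hot(index: int, num_classes: int) -> List[float]:
    """One-hot cell type label encoding t_i."""
    return [1.0 if c == index else 0.0 for c in range(num_classes)]


def mix_pseudo_bulk(
    x_i: List[float],
    x_j: List[float],
    t_i: List[float],
    t_j: List[float],
    lam: Optional[float] = None,
    rng: Optional[random.Random] = None,
) -> Tuple[List[float], List[float], float]:
    """Hypothetical helper: mixup of two single-cell profiles and their labels.

    Returns (lam*x_i + (1-lam)*x_j, lam*t_i + (1-lam)*t_j, lam).
    """
    if lam is None:
        # lambda drawn uniformly from [0, 1]
        lam = (rng or random.Random()).uniform(0.0, 1.0)
    x_mix = [lam * a + (1.0 - lam) * b for a, b in zip(x_i, x_j)]
    t_mix = [lam * a + (1.0 - lam) * b for a, b in zip(t_i, t_j)]
    return x_mix, t_mix, lam


def mse(pred: List[float], target: List[float]) -> float:
    """Assumed regression loss between predicted and mixed cell type vectors."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)


# Example: two cells with 3 genes each and 3 cell types
x_mix, t_mix, lam = mix_pseudo_bulk(
    [2.0, 0.0, 4.0], [0.0, 6.0, 0.0], one_hot(0, 3), one_hot(2, 3), lam=0.25
)
print(x_mix, t_mix)  # [0.5, 4.5, 1.0] [0.25, 0.0, 0.75]
```

The soft target stays a valid distribution over cell types (its entries sum to 1), which is what lets the regressed output be read as a mixture of the two cells' types.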

2.2 Balanced Multi-modal Training

In the optimization of multi-modal deep learning models, researchers have observed that the dominant modality, which displays better performance, tends to suppress the optimization of the others [8, 9]. While imbalanced training has been addressed for the cross-entropy loss, it has not been thoroughly investigated for pathology-genomics multi-modal fusion optimized with the Cox partial likelihood loss in survival prediction.

Framework. To facilitate the discussion of the multi-modal imbalance training problem, we use a common prototype model, as shown in Fig. 1. In this model, the Image Encoder submodule can be flexibly replaced with alternatives such as ResNet [3], Transformer [6], or T2T-ViT [15]. We define the training set as $\mathcal{D} = \{(g_k, p_k)\}_{k=1}^{K}$, where $K$ is the number of pairs of genomic data $g_k$ and pathology images $p_k$. The genomic data consist of copy number variations (CNV), mutations, and RNA-seq. We process the CNV and mutations with an SNN of the same structure as in [3], yielding a vector $\mathbf{G}^{(1)}_k$. The RNA-seq is processed by scGPT and MLP-A sequentially, with the network parameters pre-trained and frozen as described in Section 2.1, yielding a vector $\mathbf{G}^{(2)}_k$. These vectors are then concatenated, and we train a new MLP-B to produce a genomic feature $\mathbf{G}_k = \mathrm{MLP}_B([\mathbf{G}^{(1)}_k \,\|\, \mathbf{G}^{(2)}_k])$.
In parallel, the pathology image $p_k$ is processed by the Image Encoder to obtain an image feature $\mathbf{P}_k$. We then concatenate $\mathbf{G}_k$ and $\mathbf{P}_k$ as $[\mathbf{G}_k \,\|\, \mathbf{P}_k]$. The log hazard ratio $\theta_k$ for patient $k$ is derived using a linear layer with trainable weights $W$ and bias $b$, as $\theta_k = W [\mathbf{G}_k \,\|\, \mathbf{P}_k] + b$.

Following [8], we split $W$ into two blocks $W^G$ and $W^P$, and rewrite the equation as:

$$\theta_k = W^G \cdot \mathbf{G}_k + W^P \cdot \mathbf{P}_k + b. \tag{1}$$
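Because $W$ acts on the concatenation $[\mathbf{G}_k \,\|\, \mathbf{P}_k]$, splitting it into blocks simply separates the two modality terms of Equation (1). The short Python check below illustrates this with small made-up vectors; the function names are hypothetical and plain lists stand in for feature tensors.

```python
from typing import List, Tuple


def theta_from_concat(w: List[float], g: List[float], p: List[float], b: float) -> float:
    """theta_k = W [G_k || P_k] + b for a single output unit."""
    z = g + p  # list concatenation plays the role of [G_k || P_k]
    return sum(wi * zi for wi, zi in zip(w, z)) + b


def theta_by_modality(
    w: List[float], g: List[float], p: List[float], b: float
) -> Tuple[float, float, float]:
    """Split W into W^G (first len(g) weights) and W^P (the rest), as in Eq. (1)."""
    w_g, w_p = w[: len(g)], w[len(g):]
    s_g = sum(a * x for a, x in zip(w_g, g))  # W^G . G_k
    s_p = sum(a * x for a, x in zip(w_p, p))  # W^P . P_k
    return s_g, s_p, s_g + s_p + b


w, g, p, b = [0.5, -1.0, 2.0], [1.0, 2.0], [3.0], 0.1
s_g, s_p, theta = theta_by_modality(w, g, p, b)
print(s_g, s_p)  # -1.5 6.0
```

The two per-modality scalars are exactly the quantities that the balancing step later compares.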

To predict survival outcomes, the Cox partial log-likelihood [14, 3] is used as the training objective (in practice, its negative is minimized):

$$L_{Cox} = \sum_{C(k)=1} \left( \theta_k - \log \sum_{j \in R(t_k)} \exp(\theta_j) \right), \tag{2}$$

where $C(k)=1$ denotes patients with uncensored events, and $R(t_k)$ is the risk set at time $t_k$, i.e., the patients whose observed survival time is at least $t_k$. The gradient with respect to $\theta_k$ is:

$$\frac{\partial L_{Cox}}{\partial \theta_k} = \sum_{C(k)=1} \left( 1 - \frac{\exp(\theta_k)}{\sum_{j \in R(t_k)} \exp(\theta_j)} \right) \tag{3}$$
$$= \sum_{C(k)=1} \left( 1 - \frac{e^{W^G \cdot \mathbf{G}_k + W^P \cdot \mathbf{P}_k + b}}{\sum_{j \in R(t_k)} e^{W^G \cdot \mathbf{G}_j + W^P \cdot \mathbf{P}_j + b}} \right). \tag{4}$$

As Equation (4) shows, when one modality (say, genomics) performs well, its term $W^G \cdot \mathbf{G}_k$ dominates the exponent. The overall loss then approaches zero, so little optimization signal reaches the image modality. Even when the model converges, the image modality may remain inadequately trained.
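To make Equations (2) to (4) concrete, the following stdlib-only Python sketch computes the Cox partial log-likelihood, negated so that it can be minimized, together with its gradient with respect to each $\theta$. It is an illustrative sketch, not the authors' code: risk sets follow the common Breslow convention (all patients with observed time at least $t_k$, no special tie handling), the function names are hypothetical, and the gradient helper differentiates the full sum, so each patient also receives terms from every risk set it belongs to. The last lines reproduce the effect described above: when one branch already orders the patients well, both the loss and its gradient become small.

```python
import math
from typing import List


def neg_cox_partial_log_likelihood(
    theta: List[float], times: List[float], events: List[int]
) -> float:
    """-sum_{C(k)=1} (theta_k - log sum_{j in R(t_k)} exp(theta_j))."""
    total = 0.0
    for k, (t_k, e_k) in enumerate(zip(times, events)):
        if e_k != 1:  # censored patients contribute no term of their own
            continue
        risk = [theta[j] for j, t_j in enumerate(times) if t_j >= t_k]  # R(t_k)
        m = max(risk)  # log-sum-exp shift for numerical stability
        log_denom = m + math.log(sum(math.exp(r - m) for r in risk))
        total += theta[k] - log_denom
    return -total


def neg_cox_grad(theta: List[float], times: List[float], events: List[int]) -> List[float]:
    """Gradient of neg_cox_partial_log_likelihood with respect to each theta_m."""
    n = len(theta)
    grad = [0.0] * n
    for k in range(n):
        if events[k] != 1:
            continue
        risk_idx = [j for j in range(n) if times[j] >= times[k]]
        denom = sum(math.exp(theta[j]) for j in risk_idx)
        grad[k] -= 1.0  # derivative of -theta_k
        for j in risk_idx:  # derivative of +log sum exp(theta_j) over R(t_k)
            grad[j] += math.exp(theta[j]) / denom
    return grad


# Two uncensored patients; patient 0 dies first, so it should get the higher risk.
times, events = [1.0, 2.0], [1, 1]
print(neg_cox_partial_log_likelihood([0.0, 0.0], times, events))  # log(2), about 0.693
print(neg_cox_partial_log_likelihood([5.0, 0.0], times, events))  # about 0.0067: near-zero loss
print(neg_cox_grad([5.0, 0.0], times, events))  # entries of magnitude about 0.0067
```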

Balanced Training. To alleviate the imbalance training problem, we evaluate the contribution discrepancy ratios $\rho^G$ and $\rho^P$ for the Cox partial log-likelihood loss as:

$$\rho^G = \left( \frac{e^{W^G \cdot \mathbf{G}_k + \frac{b}{2}}}{\sum_{j \in R(t_k)} e^{W^G \cdot \mathbf{G}_j + \frac{b}{2}}} \right) \Big/ \left( \frac{e^{W^P \cdot \mathbf{P}_k + \frac{b}{2}}}{\sum_{j \in R(t_k)} e^{W^P \cdot \mathbf{P}_j + \frac{b}{2}}} \right), \tag{5}$$
$$\rho^P = 1/\rho^G. \tag{6}$$
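A minimal Python sketch of Equations (5) and (6) follows. It rests on assumptions that should be checked against the authors' code: the numerator of each fraction is read as the exponentiated score, mirroring Equation (4); since Equation (5) is written for a single patient $k$, the sketch sums each modality's score over the uncensored patients of a batch before taking the ratio, in the spirit of the batch-level scores in [8]; and the inputs `contrib_g` and `contrib_p` (hypothetical names) hold the scalars $W^G \cdot \mathbf{G}_k$ and $W^P \cdot \mathbf{P}_k$ for each patient. Under this reading a shared $b/2$ cancels inside each fraction, so it does not affect the ratio.

```python
import math
from typing import List, Tuple


def modality_score(
    contrib: List[float], b: float, times: List[float], events: List[int]
) -> float:
    """Sum over uncensored k of exp(c_k + b/2) / sum_{j in R(t_k)} exp(c_j + b/2)."""
    score = 0.0
    for k, e_k in enumerate(events):
        if e_k != 1:
            continue
        risk = [j for j, t_j in enumerate(times) if t_j >= times[k]]  # R(t_k)
        denom = sum(math.exp(contrib[j] + b / 2) for j in risk)
        score += math.exp(contrib[k] + b / 2) / denom
    return score


def discrepancy_ratios(
    contrib_g: List[float],
    contrib_p: List[float],
    b: float,
    times: List[float],
    events: List[int],
) -> Tuple[float, float]:
    """Return (rho_G, rho_P) with rho_P = 1 / rho_G."""
    s_g = modality_score(contrib_g, b, times, events)
    s_p = modality_score(contrib_p, b, times, events)
    if s_g == 0.0 or s_p == 0.0:  # e.g. no uncensored patients in the batch
        return 1.0, 1.0  # assumption: treat the modalities as balanced
    rho_g = s_g / s_p
    return rho_g, 1.0 / rho_g


# Genomics orders the three patients correctly; the image branch is uninformative.
times, events = [1.0, 2.0, 3.0], [1, 1, 1]
rho_g, rho_p = discrepancy_ratios([3.0, 0.0, -3.0], [0.0, 0.0, 0.0], 0.1, times, events)
print(round(rho_g, 3), round(rho_p, 3))  # rho_G > 1 and rho_P < 1 here
```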

Formally, inspired by [8], we modulate the gradient according to the estimated contributions of the two modalities. We update the model parameters at each iteration $t$ as:

$$\phi^G_{t+1} = \phi^G_t - \eta \cdot \min\big(1 - \tanh(\rho^G - 1),\, 1\big) \cdot g(\phi^G_t), \tag{7}$$
$$\phi^P_{t+1} = \phi^P_t - \eta \cdot \min\big(1 - \tanh(\rho^P - 1),\, 1\big) \cdot g(\phi^P_t), \tag{8}$$

where $\phi^G_t$ and $g(\phi^G_t)$ denote the parameters and gradient of the MLP in the genomics network at iteration $t$, $\phi^P_t$ and $g(\phi^P_t)$ denote the parameters and gradient of the pathology image network, and $\eta$ is the learning rate. Since $\rho^P = 1/\rho^G$, at most one of the two ratios exceeds 1, and $\tanh(\rho - 1) > 0$ exactly when $\rho > 1$. Hence the effective learning rate of the dominant modality (the one with $\rho > 1$) is suppressed, while the other modality keeps the full learning rate $\eta$.
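The update rule in Equations (7) and (8) amounts to scaling each branch's step by a coefficient that depends only on its own ratio. The sketch below shows this for plain gradient descent on lists of parameters; the helper names are hypothetical, and the `rho` values would come from Equations (5) and (6). In a PyTorch loop with plain SGD, a comparable effect could be obtained by multiplying each branch's `.grad` tensors by the coefficient before calling `optimizer.step()`; with adaptive optimizers such as Adam, rescaling gradients is not equivalent to rescaling the learning rate, so the two should not be assumed interchangeable.

```python
import math
from typing import List


def modulation_coeff(rho: float) -> float:
    """min(1 - tanh(rho - 1), 1): below 1 only for the dominant modality (rho > 1)."""
    return min(1.0 - math.tanh(rho - 1.0), 1.0)


def modulated_step(params: List[float], grads: List[float], lr: float, rho: float) -> List[float]:
    """phi_{t+1} = phi_t - lr * coeff(rho) * g(phi_t) for one branch."""
    coeff = modulation_coeff(rho)
    return [p - lr * coeff * g for p, g in zip(params, grads)]


rho_g = 2.0
rho_p = 1.0 / rho_g  # Eq. (6)
print(modulation_coeff(rho_g))  # about 0.238: the dominant genomics branch is slowed
print(modulation_coeff(rho_p))  # 1.0: the image branch keeps the full learning rate
```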

3 Experiments

3.1 Datasets and Implementation Details

Our experiments use two datasets [3] that consist of hematoxylin and eosin (H&E)-stained pathology images, corresponding genomic features (mutations, copy number variations (CNV), and RNA-seq), and patient survival outcomes. Specifically, the TCGA-GBMLGG dataset contains 1,505 glioma samples (tumors of the brain and spinal cord), and the TCGA-KIRC dataset contains 1,251 clear cell renal cell carcinoma samples. Following the experimental protocol of [3], we evaluate model performance with 15-fold cross-validation.

In our pipeline, the CNV and mutation information is handled by an SNN with the same architecture as in [3]. The bulk RNA-seq data are processed with scGPT [1]. Both MLP-A and MLP-B are three-layer perceptron networks with a hidden dimension of 128. The classifier used in pre-training is a linear layer mapping the MLP-A feature vector to 17 cell type categories. The batch size is 32. Our model is implemented in PyTorch and trained on an NVIDIA A10 GPU.

3.2 Results

Main Results. Table 1 and Table 2 show our main results. We compare our model with the baselines SCNN and GSCNN [19], Pathomic Fusion [3], and the supervised multi-modal setting in [6]. For a fair comparison, we also replace the CNN in Pathomic Fusion [3] with the T2T-ViT [15] backbone to extract features from pathology images. We use the C-Index to measure performance. First, our model outperforms all the baselines on both datasets. Second, as in the setting “Pathomic Fusion [3] (T2T-ViT [15] + SNN)”, we use T2T-ViT and SNN as backbones to process the image features and the CNV and mutation features; against this setting, our model improves the C-Index by 0.022 on TCGA-GBMLGG (0.848 vs. 0.826) and by 0.027 on TCGA-KIRC (0.754 vs. 0.727). This suggests the effectiveness of our design for regulating scGPT embeddings and modulating gradients.
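For reference, a minimal sketch of Harrell's concordance index, a common form of the C-Index, is given below. The text does not specify the exact variant used, so treat the details as assumptions: a pair is comparable when the patient with the shorter time had an observed event, a higher predicted risk (for example, a larger $\theta_k$) should mean shorter survival, tied predictions count as one half, and pairs with tied times are skipped. The function name is hypothetical.

```python
from typing import List


def harrell_c_index(risk: List[float], times: List[float], events: List[int]) -> float:
    """Fraction of comparable patient pairs whose predicted risks are correctly ordered."""
    concordant, comparable = 0.0, 0
    for i in range(len(risk)):
        if events[i] != 1:
            continue  # the earlier patient of a comparable pair must have an observed event
        for j in range(len(risk)):
            if times[j] > times[i]:  # patient j outlived patient i
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0
                elif risk[i] == risk[j]:
                    concordant += 0.5  # tied predictions count as half
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


times, events = [1.0, 2.0, 3.0, 4.0], [1, 0, 1, 1]
print(harrell_c_index([4.0, 3.0, 2.0, 1.0], times, events))  # 1.0: perfectly ordered
```

A value of 0.5 corresponds to random ordering and 1.0 to a perfect ranking, which is the scale on which the tables below report results.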

Table 1: Comparison of C-Index performance on the TCGA-GBMLGG dataset (p < 0.05).
| Method | C-Index |
|---|---|
| SCNN (Histology Only) [19] | 0.754 |
| Histology CNN [3] | 0.792 ± 0.014 |
| Histology T2T-ViT [15] | 0.803 ± 0.016 |
| Genomic SNN [3] | 0.808 ± 0.014 |
| GSCNN (Histology + Genomic) [19] | 0.781 |
| Pathomic Fusion [3] (CNN + SNN) | 0.820 ± 0.009 |
| Pathomic Fusion [3] (T2T-ViT [15] + SNN) | 0.826 ± 0.010 |
| PathOmics [6] | 0.833 ± 0.012 |
| Ours | 0.848 ± 0.014 |

Table 2: Comparison of C-Index performance on the TCGA-KIRC dataset (p < 0.05).
| Method | C-Index |
|---|---|
| Histology CNN [3] | 0.671 ± 0.023 |
| Histology T2T-ViT [15] | 0.683 ± 0.023 |
| Genomic SNN [3] | 0.684 ± 0.025 |
| Pathomic Fusion [3] (CNN + SNN) | 0.719 ± 0.031 |
| Pathomic Fusion [3] (T2T-ViT [15] + SNN) | 0.727 ± 0.033 |
| PathOmics [6] | 0.736 ± 0.024 |
| Ours | 0.754 ± 0.030 |

Ablation Study. We evaluate the impact of each key component of our model in Table 3. First, directly applying scGPT is a sub-optimal solution, as it improves the C-Index by only 0.005 (Exp. 1 vs. Exp. 2). Second, incorporating our mix-up regulation module to smooth the latent space raises the C-Index from 0.826 to 0.838 (Exp. 1 vs. Exp. 4), indicating that scGPT can then process bulk RNA-seq data more effectively. Third, our gradient modulation (GradMod.) improves performance over both direct concatenation and the Kronecker product-based fusion of [3] (Exp. 2 vs. Exp. 3; Exp. 4, Exp. 5, and Exp. 6).

Table 3: Ablation study on the TCGA-GBMLGG dataset.
| Exp. | RNA-seq encoding | Fusion | C-Index |
|---|---|---|---|
| 1 | SNN | Concat | 0.826 ± 0.010 |
| 2 | scGPT w/o smooth | Concat | 0.831 ± 0.012 |
| 3 | scGPT w/o smooth | GradMod. | 0.840 ± 0.013 |
| 4 | scGPT w/ smooth | Concat | 0.838 ± 0.013 |
| 5 | scGPT w/ smooth | Kro. Prod. [3] | 0.840 ± 0.013 |
| 6 | scGPT w/ smooth | GradMod. | 0.848 ± 0.014 |

4 Conclusions

In this paper, we addressed two main challenges in fusing pathology images and genomics data. First, we showed how to effectively use the foundation model scGPT, originally designed for single-cell RNA-seq, to process bulk RNA-seq data. Second, we tackled the imbalanced training between the image and genomics modalities by proposing gradient modulation for the Cox partial likelihood loss. Evaluated on two TCGA datasets, our model improved the C-Index over the baseline models.

References

  • [1] Haotian Cui, Chloe Wang, Hassaan Maan, Kuan Pang, Fengning Luo, and Bo Wang, “scGPT: Towards building a foundation model for single-cell multi-omics using generative AI,” bioRxiv, pp. 2023–04, 2023.
  • [2] Xinmin Li and Cun-Yu Wang, “From bulk, single-cell to spatial RNA sequencing,” International Journal of Oral Science, vol. 13, no. 1, pp. 36, 2021.
  • [3] Richard J Chen, Ming Y Lu, Jingwen Wang, Drew FK Williamson, Scott J Rodig, Neal I Lindeman, and Faisal Mahmood, “Pathomic fusion: An integrated framework for fusing histopathology and genomic features for cancer diagnosis and prognosis,” IEEE Transactions on Medical Imaging, vol. 41, no. 4, pp. 757–770, 2020.
  • [4] Richard J Chen, Ming Y Lu, Drew FK Williamson, Tiffany Y Chen, Jana Lipkova, Zahra Noor, Muhammad Shaban, Maha Shady, Mane Williams, Bumjin Joo, et al., “Pan-cancer integrative histology-genomic analysis via multimodal deep learning,” Cancer Cell, vol. 40, no. 8, pp. 865–878, 2022.
  • [5] Richard J Chen, Ming Y Lu, Wei-Hung Weng, Tiffany Y Chen, Drew FK Williamson, Trevor Manz, Maha Shady, and Faisal Mahmood, “Multimodal co-attention Transformer for survival prediction in gigapixel whole slide images,” in Proceedings of the IEEE/CVF International Conference on Computer Vision, 2021, pp. 4015–4025.
  • [6] Kexin Ding, Mu Zhou, Dimitris N Metaxas, and Shaoting Zhang, “Pathology-and-genomics multimodal Transformer for survival outcome prediction,” in International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, 2023, pp. 622–631.
  • [7] Anika Cheerla and Olivier Gevaert, “Deep learning with multimodal representation for pancancer prognosis prediction,” Bioinformatics, vol. 35, no. 14, pp. i446–i454, 2019.
  • [8] Xiaokang Peng, Yake Wei, Andong Deng, Dong Wang, and Di Hu, “Balanced multimodal learning via on-the-fly gradient modulation,” in Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2022, pp. 8238–8247.
  • [9] Ruize Xu, Ruoxuan Feng, Shi-Xiong Zhang, and Di Hu, “MMCosine: Multi-modal cosine loss towards balanced audio-visual fine-grained learning,” in ICASSP 2023-2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP). IEEE, 2023, pp. 1–5.
  • [10] Aviv Regev, Sarah A Teichmann, Eric S Lander, Ido Amit, Christophe Benoist, Ewan Birney, Bernd Bodenmiller, Peter J Campbell, Piero Carninci, Menna Clatworthy, et al., “Science forum: The human cell atlas,” eLife, vol. 6, pp. e27041, 2017.
  • [11] Yahui Liu, Enver Sangineto, Yajing Chen, Linchao Bao, Haoxian Zhang, Nicu Sebe, Bruno Lepri, and Marco De Nadai, “Smooth image-to-image translations with latent space interpolations,” arXiv preprint arXiv:2210.00841, 2022.
  • [12] Hongxiao Wang, Hao Zheng, Jianxu Chen, Lin Yang, Yizhe Zhang, and Danny Z Chen, “Unlabeled data guided semi-supervised histopathology image segmentation,” in 2020 IEEE International Conference on Bioinformatics and Biomedicine (BIBM). IEEE, 2020, pp. 815–820.
  • [13] Christopher Beckham, Sina Honari, Vikas Verma, Alex M Lamb, Farnoosh Ghadiri, R Devon Hjelm, Yoshua Bengio, and Chris Pal, “On adversarial mixup resynthesis,” Advances in Neural Information Processing Systems, vol. 32, 2019.
  • [14] Travers Ching, Xun Zhu, and Lana X Garmire, “Cox-nnet: An artificial neural network method for prognosis prediction of high-throughput omics data,” PLoS Computational Biology, vol. 14, no. 4, pp. e1006076, 2018.
  • [15] Li Yuan, Yunpeng Chen, Tao Wang, Weihao Yu, Yujun Shi, Zi-Hang Jiang, Francis EH Tay, Jiashi Feng, and Shuicheng Yan, “Tokens-to-Token ViT: Training Vision Transformers from scratch on ImageNet,” in Proceedings of the IEEE/CVF International Conference on Computer Vision, 2021, pp. 558–567.
  • [16] Yahui Liu, Enver Sangineto, Yajing Chen, Linchao Bao, Haoxian Zhang, Nicu Sebe, Bruno Lepri, Wei Wang, and Marco De Nadai, “Smoothing the disentangled latent style space for unsupervised image-to-image translation,” in Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2021, pp. 10785–10794.
  • [17] Momchil Peychev, Anian Ruoss, Mislav Balunović, Maximilian Baader, and Martin Vechev, “Latent space smoothing for individually fair representations,” in European Conference on Computer Vision. Springer, 2022, pp. 535–554.
  • [18] David Berthelot, Nicholas Carlini, Ian Goodfellow, Nicolas Papernot, Avital Oliver, and Colin A Raffel, “MixMatch: A holistic approach to semi-supervised learning,” Advances in Neural Information Processing Systems, vol. 32, 2019.
  • [19] Pooya Mobadersany, Safoora Yousefi, Mohamed Amgad, David A Gutman, Jill S Barnholtz-Sloan, José E Velázquez Vega, Daniel J Brat, and Lee AD Cooper, “Predicting cancer outcomes from histology and genomics using convolutional networks,” Proceedings of the National Academy of Sciences, vol. 115, no. 13, pp. E2970–E2979, 2018.