Multimodal Prototyping for cancer survival prediction

Andrew H. Song    Richard J. Chen    Guillaume Jaume    Anurag Vaidya    Alexander S. Baras    Faisal Mahmood
Abstract

Multimodal survival methods combining gigapixel histology whole-slide images (WSIs) and transcriptomic profiles are particularly promising for patient prognostication and stratification. Current approaches involve tokenizing the WSIs into smaller patches ($>10^{4}$ patches) and transcriptomics into gene groups, which are then integrated using a Transformer for predicting outcomes. However, this process generates many tokens, which leads to high memory requirements for computing attention and complicates post-hoc interpretability analyses. Instead, we hypothesize that we can: (1) effectively summarize the morphological content of a WSI by condensing its constituent tokens using morphological prototypes, achieving more than 300× compression; and (2) accurately characterize cellular functions by encoding the transcriptomic profile with biological pathway prototypes, all in an unsupervised fashion. The resulting multimodal tokens are then processed by a fusion network, either a Transformer or an optimal transport cross-alignment, which now operates on a small and fixed number of tokens without approximations. Extensive evaluation on six cancer types shows that our framework outperforms state-of-the-art methods with much less computation while unlocking new interpretability analyses. The code is available at https://github.com/mahmoodlab/MMP.

Machine Learning, ICML

1 Introduction

Patient prognostication – the task of predicting the progression of a disease – is a cornerstone of clinical research and can help identify novel biomarkers indicative of disease progression (Song et al., 2023, 2024b). Prognostication is often cast as predicting survival based on a series of assays describing the patient’s medical state. Due to the complexity and diverse aspects of prognostication, multimodal approaches that combine histology and omics data, such as transcriptomics (Acosta et al., 2022), are particularly promising. Histology is represented through whole-slide images (WSIs), which offer a detailed spatial depiction of the tissue, such as a tumor, with resolutions that can exceed $10^{5}\times 10^{5}$ pixels. In contrast, transcriptomics is often measured with bulk RNA sequencing, which provides insights into gene expression. The complementary information in both modalities has been shown to be predictive of survival and can be used to inform disease progression (Chen et al., 2022; Lipkova et al., 2022; Steyaert et al., 2023). However, the distinct characteristics of each modality make it challenging to integrate them effectively.

WSI modeling is typically done with multiple instance learning (MIL) (Ilse et al., 2018; Campanella et al., 2019; Lu et al., 2021). This method involves (1) extracting the set of patches that constitute the WSI ($>10^{4}$ per WSI), (2) feeding them through a pre-trained patch encoder to generate patch embeddings, and (3) aggregating the patch embeddings with a pooling network. In contrast, transcriptomics modeling can be done using a feed-forward neural network, treating each gene expression as a tabular data entry, or by grouping them into coarse gene families (Chen et al., 2021; Zhou & Chen, 2023; Xu & Chen, 2023) or biological pathways (Elmarakeby et al., 2021; Jaume et al., 2024). The set of patch embeddings and gene groups can then be seen as tokens, which can be fed to a Transformer to derive a multimodal representation used for outcome prediction.

However, fusing large numbers of tokens with a Transformer is computationally expensive, and most approaches resort to attention approximation (Shao et al., 2021; Jaume et al., 2024), cross-attention (Chen et al., 2021; Zhou & Chen, 2023; Xu & Chen, 2023), or token subsampling (Wulczyn et al., 2020; Xu & Chen, 2023). Even when employing alternatives to Transformers, such as optimal transport (OT) cross-modal alignment (Duan et al., 2022; Pramanick et al., 2022), handling such large token sets remains challenging. The small size of multimodal cohorts, often just a few hundred samples, intensifies this issue, resulting in a large-p (large input dimensionality), small-n (small sample size) problem. Moreover, interpreting how thousands of tokens interact and contribute to patient-level predictions, which is crucial for clinical insight, is a significant challenge.

Instead, we hypothesize that we can summarize the patch embeddings using morphological prototypes. Indeed, due to the inherent morphological redundancy of human tissue, the histology patches that constitute a WSI can be assumed to be variations of key morphologies, e.g., clear cell tumor, necrosis, benign stroma, etc., which we can extract and encode. Similarly, in molecular pathology, decades of research have identified biological pathways that encode specific cellular functions (Liberzon et al., 2015; Elmarakeby et al., 2021), which we can leverage to define pathway prototypes. This drastically reduces the number of tokens before multimodal fusion, thereby enabling the seamless integration of diverse fusion strategies and greatly simplifying interpretability. The challenge then becomes extracting and encoding meaningful multimodal prototypes.

Here, we introduce a MultiModal Prototyping framework for patient prognostication (MMP). Inspired by prototype-based aggregation (Mialon et al., 2021; Kim, 2022; Song et al., 2024a), we construct an unsupervised and compact WSI representation with a Gaussian mixture model, where the mixture parameters define the slide summary, each mapping to a morphological prototype (16 to 32 prototypes). Following existing work (Jaume et al., 2024), we transform transcriptomics into a set of 50 Cancer Hallmark pathway prototypes (Liberzon et al., 2015). With significantly fewer tokens, we show that multimodal Transformers can be readily applied to the joint set of histology and pathway tokens without relying on approximations. In addition, we establish a connection between Optimal Transport (OT) cross-alignment, a popular alternative for cross-modal alignment, and Transformer cross-attention, thereby unifying both under a single framework. On six cancer cohorts from The Cancer Genome Atlas (TCGA), MMP outperforms nearly all uni- and multimodal baselines with a much smaller number of operations, demonstrating the predictive performance and efficiency of prototype-based approaches. Finally, the tractable number of tokens allows visualization of bi-directional interactions between the morphological and pathway prototypes, unlike previous multimodal frameworks, which rely on uni-directional interpretation.

To summarize, our contributions are (1) a method for summarizing slides using morphological prototypes and transcriptomic profiles using established biological pathway prototypes; (2) a unified and memory-efficient multimodal fusion framework; (3) extensive evaluation and ablation experiments on six cancer cohorts highlighting the predictive power of MMP; and (4) a multimodal patient representation that enables new interpretability analyses.

Figure 1: Overview of MMP. (A) The tessellated WSI patches (tokens) are projected to low-dimensional embeddings with a pretrained patch encoder. The patch embeddings ($N_{\text{h.}}>10^{4}$) are aggregated into a slide summary using a small set of prototypes ($C_{\text{h.}}\leq 32$). (B) The transcriptomic data are projected onto a set of binary vectors indicating the presence of specific genes in each pathway, forming a pathway summary. (C) The post-aggregation embeddings from both modalities are first matched to the same dimension. Cross-modal interactions between histology and transcriptomics are learned with a Transformer or with Optimal Transport (OT), while intra-modal interactions are learned with Transformer-based self-attention. The attended embeddings are aggregated to form a patient-level embedding used for risk prediction.

2 Related Work

2.1 Representing sets with prototypes

With NLP and bioinformatics producing more datasets represented as sets, recent approaches have explored set representation learning with prototypes, i.e., interpretable exemplars that can encode distinct concepts (Snell et al., 2017; Lee et al., 2019; Mialon et al., 2021; Kim, 2022; dan Guo et al., 2022; Lee et al., 2024). In computational pathology, AttnMISL (Yao et al., 2020) and H2T (Vu et al., 2023) perform K-means clustering within each WSI and use the cluster centroid embeddings as prototypes (hard clustering). The proportion of patches in each cluster has also been used to represent the cluster (Quiros et al., 2023). MMP extends PANTHER (Song et al., 2024a) to a multimodal setting, formalizes the prototype-based set representation mathematically, and generalizes this concept to include Gaussian mixture models, OT, and clustering.

2.2 Prognostication with multimodal fusion

Late fusion. Early works exploring multimodal survival employed late fusion techniques based on merging unimodal representations, for instance using concatenation or the Kronecker product (Chen et al., 2020b). While initial frameworks incorporated histology with small regions of interest (Mobadersany et al., 2018; Wang et al., 2021), the development of MIL has enabled slide-level prognostication studies combined with omics (Chen et al., 2022; Ding et al., 2023; Volinsky-Fremond et al., 2024). However, late fusion methods are limited in modeling local cross-modal interactions that are potentially predictive of prognosis.
Transformer fusion. Transformers have facilitated progress in multimodal fusion by modeling interactions between cross-modal tokens (or “early fusion”). To address the computational complexity of dealing with a large number of tokens, token subsampling can be performed (Wulczyn et al., 2020; Xu & Chen, 2023), or fusion can be simplified to cross-attention (Chen et al., 2021; Jaume et al., 2024; Zhou & Chen, 2023). Present strategies allow the use of a single Transformer for merging tokens and facilitating early fusion across various modalities (Jaegle et al., 2022; Girdhar et al., 2022; Wang et al., 2023; Liang et al., 2023; Wang et al., 2024). However, these works remain limited by the large number of histology tokens ($>10^{4}$). In contrast, MMP does not require approximation, because the prototype-based formulation significantly reduces the number of tokens. A recent work (Zhang et al., 2024) employs prototypes to remove intra- and inter-modality redundancy, but is specific to a time-discretized survival formulation, whereas MMP allows a flexible formulation of the survival problem.
Optimal Transport fusion. OT-based cross-alignment between sets of multimodal tokens (Chen et al., 2020a; Cao et al., 2022; Duan et al., 2022; Pramanick et al., 2023) has gained interest as an alternative to Transformers. In multimodal prognosis, MOTCat (Xu et al., 2023) and MMP replace Transformer-based cross-attention with OT-based cross-alignment.

3 Methods

We introduce MMP, a MultiModal Prototyping framework for survival prediction. We describe the construction of morphological and pathway prototypes (Section 3.1) and the multimodal fusion mechanism (Section 3.2). We finish by describing survival prediction (Section 3.3) and prototype-specific designs (Section 3.4).

As for notation, $z$ represents a scalar, $\mathbf{z}$ a vector, and $\mathbf{Z}$ a matrix. For a set of vectors $\{\mathbf{z}_{c}\}_{c=1}^{C}$ with $\mathbf{z}_{c}\in\mathbb{R}^{d}$, $\mathbf{Z}\in\mathbb{R}^{C\times d}$ represents the corresponding matrix with $\mathbf{z}_{c}$ as the $c^{\text{th}}$ row. The notation $\big[\mathbf{X},\mathbf{Z}\big]$ indicates concatenation.

3.1 Prototype-based encoding

3.1.1 Morphological prototypes (Histology)

Preprocessing. Given a WSI, we divide it into $N_{\text{h.}}$ non-overlapping patches at 20× magnification (0.5 μm/pixel), forming a set of histology patches $\{\mathbf{x}_{i,\text{h.}}\}_{i=1}^{N_{\text{h.}}}$ of varying cardinality, typically with $N_{\text{h.}}>10^{4}$. Each $\mathbf{x}_{i,\text{h.}}$ is then mapped to a low-dimensional embedding using a pretrained patch encoder $f_{\text{enc}}(\cdot)$, such that $\mathbf{z}_{i,\text{h.}}=f_{\text{enc}}(\mathbf{x}_{i,\text{h.}})\in\mathbb{R}^{D}$.
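A minimal sketch of the tessellation step follows, written in Python with NumPy. It assumes a slide already loaded as an in-memory array, a hypothetical patch size of 256 pixels, and a random-projection stand-in (`toy_encoder`, hypothetical) in place of the pretrained patch encoder $f_{\text{enc}}$; real WSIs are read lazily at a chosen magnification rather than held in memory.

```python
import numpy as np


def patch_grid(height: int, width: int, patch: int = 256) -> list[tuple[int, int]]:
    """Top-left (row, col) coordinates of non-overlapping patch x patch tiles.

    Partial tiles at the right/bottom border are dropped (an assumption).
    """
    return [(r, c)
            for r in range(0, height - patch + 1, patch)
            for c in range(0, width - patch + 1, patch)]


def toy_encoder(patches: np.ndarray, dim: int = 8, seed: int = 0) -> np.ndarray:
    """Hypothetical stand-in for f_enc: flatten each patch, apply a fixed random projection."""
    rng = np.random.default_rng(seed)
    flat = patches.reshape(len(patches), -1).astype(np.float64)
    proj = rng.standard_normal((flat.shape[1], dim)) / np.sqrt(flat.shape[1])
    return flat @ proj


# Toy "slide": a 1024 x 768 RGB array (real WSIs can exceed 10^5 x 10^5 pixels).
slide = np.random.default_rng(1).random((1024, 768, 3))
coords = patch_grid(*slide.shape[:2], patch=256)
patches = np.stack([slide[r:r + 256, c:c + 256] for r, c in coords])
Z = toy_encoder(patches)  # patch embeddings, shape (N_h, D) = (12, 8)
```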
Aggregation. The standard practice for WSI-based outcome prediction uses Multiple Instance Learning (MIL) to aggregate patch embeddings into a slide embedding: the patch embeddings $\mathbb{S}_{\text{h.}}=\{\mathbf{z}_{i,\text{h.}}\}_{i=1}^{N_{\text{h.}}}$ are pooled using a learnable function $\phi^{\text{h.}}(\cdot):\mathbb{S}_{\text{h.}}\rightarrow\mathbb{R}^{D}$, forming a post-aggregation slide embedding $\mathbf{z}_{\text{h.}}^{\text{agg.}}=\phi^{\text{h.}}(\mathbb{S}_{\text{h.}})$. In the prototype-based approach, we take an alternate route by defining $C_{\text{h.}}$ aggregation functions $\phi_{c}^{\text{h.}}(\cdot):\mathbb{S}_{\text{h.}}\rightarrow\mathbb{R}^{d_{\text{h.}}}$, such that $\mathbf{z}_{c,\text{h.}}^{\text{agg.}}=\phi_{c}^{\text{h.}}(\mathbb{S}_{\text{h.}})$, where $C_{\text{h.}}$ is the number of prototypes. This produces a set $\mathbb{S}_{\text{slide}}=\{\mathbf{z}_{c,\text{h.}}^{\text{agg.}}\}_{c=1}^{C_{\text{h.}}}$, referred to as the slide summary.
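For contrast with the prototype-based route, the sketch below implements one common choice of learnable pooling $\phi^{\text{h.}}$: attention-based MIL pooling in the style of Ilse et al. (2018), which maps a bag of any size to a single $D$-dimensional slide embedding. The weights `V` and `w` are random, untrained placeholders, and the hidden size `H` is an assumption.

```python
import numpy as np


def attention_mil_pool(Z: np.ndarray, V: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Attention-based MIL pooling: sum_i alpha_i z_i, alpha = softmax_i(w^T tanh(V z_i)).

    Z: (N, D) patch embeddings, V: (H, D), w: (H,). Returns a (D,) slide embedding.
    """
    scores = np.tanh(Z @ V.T) @ w   # (N,) unnormalized attention scores
    scores -= scores.max()          # shift for a numerically stable softmax
    e = np.exp(scores)
    alpha = e / e.sum()             # attention weights over patches, sum to 1
    return alpha @ Z                # attention-weighted average of the embeddings


rng = np.random.default_rng(0)
D, H = 16, 8
V, w = rng.standard_normal((H, D)), rng.standard_normal(H)
for n in (100, 5000):               # bags of different cardinality ...
    z_agg = attention_mil_pool(rng.standard_normal((n, D)), V, w)
    print(n, z_agg.shape)           # ... each collapse to a single (D,) vector
```

Whatever the bag size, this pooling yields one vector, whereas the prototype-based functions $\phi_{c}^{\text{h.}}$ yield $C_{\text{h.}}$ tokens that can still interact in a downstream fusion model.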
Prototypes. We define prototypes $\{\mathbf{a}_{c,\text{h.}}\}_{c=1}^{C_{\text{h.}}}$, $\mathbf{a}_{c,\text{h.}}\in\mathbb{R}^{D}$, such that each prototype exemplifies a distinct morphology from the training set. Specifically, we apply K-means clustering to all patch embeddings from the training dataset to extract $C_{\text{h.}}$ cluster centroids. The prototypes $\{\mathbf{a}_{c,\text{h.}}\}_{c=1}^{C_{\text{h.}}}$ are then used as parameters for the prototype-specific aggregation functions $\phi_{c}^{\text{h.}}$ (Fig. 1A).
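The prototype-extraction step can be sketched as plain Lloyd's K-means over pooled training embeddings. This dependency-free version is illustrative only: the random initialization, iteration count, and empty-cluster handling are assumptions, and for millions of patch embeddings an optimized K-means implementation would be preferable.

```python
import numpy as np


def kmeans_prototypes(Z: np.ndarray, C: int, n_iter: int = 50, seed: int = 0) -> np.ndarray:
    """Lloyd's K-means on pooled training patch embeddings Z (N, D); returns C centroids (C, D)."""
    rng = np.random.default_rng(seed)
    centroids = Z[rng.choice(len(Z), size=C, replace=False)].astype(float)  # init at data points
    for _ in range(n_iter):
        # Squared Euclidean distance from every embedding to every centroid: (N, C)
        d2 = ((Z[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
        labels = d2.argmin(axis=1)
        for c in range(C):
            members = Z[labels == c]
            if len(members):  # keep the previous centroid if a cluster becomes empty
                centroids[c] = members.mean(axis=0)
    return centroids


# Toy training set: two well-separated "morphologies" in a 2-D embedding space.
rng = np.random.default_rng(42)
Z_train = np.vstack([rng.normal(0.0, 0.5, (200, 2)), rng.normal(10.0, 0.5, (200, 2))])
A_h = kmeans_prototypes(Z_train, C=2)  # prototypes a_{c,h.}, shape (C_h, D) = (2, 2)
```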
Slide summary. Given the prototype $\mathbf{a}_{c,\text{h.}}$ and patch embeddings $\mathbb{S}_{\text{h.}}$, we can express $\phi_{c}^{\text{h.}}$ as

$$\mathbf{z}_{c,\text{h.}}^{\text{agg.}}=\phi_{c}^{\text{h.}}(\mathbb{S}_{\text{h.}},\mathbf{a}_{c,\text{h.}})=\sum_{i=1}^{N_{\text{h.}}}g(\mathbf{z}_{i,\text{h.}},\mathbf{a}_{c,\text{h.}}),\quad\forall c. \tag{1}$$

This allows the large, variable-length set of patch embeddings to be represented by a small, fixed-length set of prototypical tokens. We use $C_{\text{h.}}\leq 32$, whereas typically $N_{\text{h.}}>10^{4}$, achieving more than 300× reduction. Eq. 1 implies that aggregation is performed by summing the contributions of all embeddings in $\mathbb{S}_{\text{h.}}$, the exact manner of which is determined by the mapping $g$.
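Eq. 1 can be sketched with a pluggable map $g$. Purely for illustration, we assume a hard-clustering choice in which $g(\mathbf{z}_{i,\text{h.}},\mathbf{a}_{c,\text{h.}})$ equals $\mathbf{z}_{i,\text{h.}}/N_{\text{h.}}$ when $\mathbf{a}_{c,\text{h.}}$ is the nearest prototype and zero otherwise, so each token is a slide-normalized sum of its assigned patches. Deciding "nearest" requires all prototypes, so the hypothetical `g_hard` takes the full prototype matrix; the exact HC, OT, and GMM definitions are given in Appendix A and may differ. The last line reproduces the compression arithmetic from the text.

```python
import numpy as np


def g_hard(S: np.ndarray, A: np.ndarray, c: int) -> np.ndarray:
    """Hypothetical hard-clustering g, evaluated for all patches at once.

    Row i is g(z_i, a_c) = z_i / N_h if prototype c is the nearest prototype to z_i, else 0.
    """
    d2 = ((S[:, None, :] - A[None, :, :]) ** 2).sum(axis=-1)  # (N_h, C_h) squared distances
    nearest = d2.argmin(axis=1)
    return np.where((nearest == c)[:, None], S / len(S), 0.0)  # (N_h, D) contributions


def prototype_summary(S: np.ndarray, A: np.ndarray, g=g_hard) -> np.ndarray:
    """Eq. 1: z_c^agg = sum_i g(z_i, a_c) for every prototype c; returns (C_h, D)."""
    return np.stack([g(S, A, c).sum(axis=0) for c in range(len(A))])


rng = np.random.default_rng(0)
S_h = rng.standard_normal((10_000, 4))  # N_h = 10^4 toy patch embeddings with D = 4
A_h = rng.standard_normal((32, 4))      # C_h = 32 prototypes
Z_agg = prototype_summary(S_h, A_h)
print(Z_agg.shape, len(S_h) / len(A_h))  # (32, 4) 312.5
```

Because every patch is assigned to exactly one prototype under this choice of $g$, the summary tokens sum to the mean patch embedding; soft choices such as GMM instead spread each patch over all tokens.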

We explore three strategies for defining $g$: hard clustering (HC), OT, and our preferred choice, Gaussian mixture models (GMM). We briefly explain GMM, similarly to the developments in PANTHER (Song et al., 2024a), and defer detailed explanations of the other strategies to Appendix A.
GMM-based slide summarization. With GMM as the generative model, the probability distribution for $\mathbf{z}_{i,\text{h.}}$ is

$$\begin{aligned}
p(\mathbf{z}_{i,\text{h.}};\theta)&=\sum_{c=1}^{C_{\text{h.}}}p(c_{i}=c)\cdot p(\mathbf{z}_{i,\text{h.}}\mid c_{i}=c)\\
&=\sum_{c=1}^{C_{\text{h.}}}\pi_{c}\cdot\mathcal{N}(\mathbf{z}_{i,\text{h.}};\boldsymbol{\mu}_{c},\Sigma_{c}),
\end{aligned} \tag{2}$$

where $c_{i}$ denotes the mixture identity of $\mathbf{z}_{i,\text{h.}}$ and $\theta=\{\pi_{c},\boldsymbol{\mu}_{c},\Sigma_{c}\}_{c=1}^{C_{\text{h.}}}$ denotes the mixture probabilities, means, and diagonal covariances. Intuitively, $\boldsymbol{\mu}_{c}$ and $\pi_{c}$ represent a morphological exemplar and the proportion of similar patterns in the WSI, respectively. The posterior distribution $p(c_{i}=c\mid\mathbf{z}_{i,\text{h.}})$ indirectly represents the distance between $\mathbf{z}_{i,\text{h.}}$ and $\mathbf{a}_{c,\text{h.}}$, and consequently its contribution to each element of the slide summary.
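The log-likelihood in Eq. 2 and the posterior $p(c_{i}=c\mid\mathbf{z}_{i,\text{h.}})$ can be computed in log space for numerical stability. The sketch below assumes diagonal covariances stored as per-dimension variance vectors; the function names are hypothetical.

```python
import numpy as np


def log_gaussian_diag(Z: np.ndarray, mu: np.ndarray, var: np.ndarray) -> np.ndarray:
    """log N(z_i; mu_c, diag(var_c)) for all i, c. Z: (N, D); mu, var: (C, D) -> (N, C)."""
    diff2 = (Z[:, None, :] - mu[None, :, :]) ** 2 / var[None, :, :]
    return -0.5 * (diff2.sum(axis=-1) + np.log(2 * np.pi * var).sum(axis=-1)[None, :])


def gmm_loglik_and_posterior(Z, weights, mu, var):
    """Per-patch log p(z_i; theta) from Eq. 2 and the posterior p(c_i = c | z_i)."""
    log_joint = np.log(weights)[None, :] + log_gaussian_diag(Z, mu, var)  # log pi_c + log N
    m = log_joint.max(axis=1, keepdims=True)                              # log-sum-exp shift
    log_px = m[:, 0] + np.log(np.exp(log_joint - m).sum(axis=1))          # (N,)
    posterior = np.exp(log_joint - log_px[:, None])                       # (N, C), rows sum to 1
    return log_px, posterior


# A patch halfway between two equally weighted components is split evenly between them.
Z = np.full((1, 3), 0.5)
weights = np.array([0.5, 0.5])
mu = np.stack([np.zeros(3), np.ones(3)])
var = np.ones((2, 3))
log_px, post = gmm_loglik_and_posterior(Z, weights, mu, var)
```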
We obtain the maximum-likelihood estimate, $\widehat{\theta}=\arg\max_{\theta}\sum_{i=1}^{N_{\text{h.}}}\log p(\mathbf{z}_{i,\text{h.}};\theta)$, via expectation-maximization (EM), which can be performed as a feedforward network operation (Dempster et al., 1977; Kim, 2022).

The estimated GMM parameters are concatenated to form a post-aggregation embedding, $\mathbf{z}_{c,\text{h.}}^{\text{agg.}}=[\widehat{\pi}_{c},\widehat{\boldsymbol{\mu}}_{c},\widehat{\Sigma}_{c}]\in\mathbb{R}^{d_{\text{h.}}}$ with $d_{\text{h.}}=2D+1$. During EM-based inference, the prototypes $\mathbf{a}_{c,\text{h.}}$ are used as the initial parameters for the mixture means, $\boldsymbol{\mu}_{c}^{(0)}=\mathbf{a}_{c,\text{h.}}$, with $\Sigma_{c}$ initialized to the identity matrix. Owing to GMM’s soft clustering nature, the posterior $p(c_{i}=c\mid\mathbf{z}_{i,\text{h.}})$ is non-zero for every $c$, implying that all elements of $\mathbb{S}_{\text{h.}}$ contribute to $\mathbf{z}_{c,\text{h.}}^{\text{agg.}}$ (Eq. 1). We note that deriving the slide summary ($\mathbb{S}_{\text{slide}}$ or $\mathbf{Z}_{\text{h.}}^{\text{agg.}}\in\mathbb{R}^{C_{\text{h.}}\times d_{\text{h.}}}$) from the patch embeddings is done in an unsupervised manner and drastically reduces the input size to the multimodal fusion model.
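Putting the pieces together, the following sketch runs a few EM iterations for one slide, initializing the means at the prototypes and the covariances at the identity as described above, and returns the $C_{\text{h.}}\times(2D+1)$ slide summary. The uniform initial mixture weights, the fixed number of iterations, and the variance floor `eps` are assumptions, and this plain NumPy loop only stands in for the feedforward EM formulation cited above.

```python
import numpy as np


def gmm_slide_summary(S: np.ndarray, A: np.ndarray, n_iter: int = 5, eps: float = 1e-6) -> np.ndarray:
    """Diagonal-GMM EM on one slide's patch embeddings, initialized at the prototypes.

    S: (N, D) patch embeddings, A: (C, D) prototypes.
    Returns Z_agg of shape (C, 2D + 1); row c is [pi_c, mu_c, diag(Sigma_c)].
    """
    N, D = S.shape
    C = A.shape[0]
    weights = np.full(C, 1.0 / C)  # pi_c^(0): uniform (an assumption)
    mu = A.astype(float).copy()    # mu_c^(0) = a_c
    var = np.ones((C, D))          # Sigma_c^(0) = identity, stored as its diagonal
    for _ in range(n_iter):
        # E-step: responsibilities p(c_i = c | z_i), computed in log space.
        diff2 = (S[:, None, :] - mu[None, :, :]) ** 2 / var[None, :, :]
        log_joint = np.log(weights)[None, :] - 0.5 * (
            diff2.sum(axis=-1) + np.log(2 * np.pi * var).sum(axis=-1)[None, :])
        log_joint -= log_joint.max(axis=1, keepdims=True)
        R = np.exp(log_joint)
        R /= R.sum(axis=1, keepdims=True)  # (N, C); every patch weighs on every component
        # M-step: closed-form updates from soft counts.
        Nc = R.sum(axis=0) + eps
        weights = Nc / N
        mu = (R.T @ S) / Nc[:, None]
        var = np.maximum((R.T @ S ** 2) / Nc[:, None] - mu ** 2, eps)  # floored at eps
    return np.concatenate([weights[:, None], mu, var], axis=1)


rng = np.random.default_rng(0)
S = np.vstack([rng.normal(0.0, 1.0, (500, 2)), rng.normal(8.0, 1.0, (500, 2))])
A = np.array([[0.0, 0.0], [8.0, 8.0]])  # prototypes, e.g., from K-means on training data
Z_agg = gmm_slide_summary(S, A)
print(Z_agg.shape)                      # (2, 5) = (C_h, 2D + 1)
```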

3.1.2 Pathway prototypes (Genomics)

We aim to define a similarly compact prototypical representation of the transcriptomic profile. The transcriptomic profile consists of $N_{\text{g.}}$ gene expression values for each tissue sample and is described as $\{x_{i,\text{g.}}\}_{i=1}^{N_{\text{g.}}}$, with $x_{i,\text{g.}}\in\mathbb{R}$.
Prototypes. We tokenize gene expression into biological pathway prototypes, i.e., groups of genes that interact in specific ways to carry out previously described cellular processes (Liberzon et al., 2015; Reimand et al., 2019). Unlike for histology, the number of prototypes $C_{\text{g.}}$ (e.g., $C_{\text{g.}}=50$ for Hallmark pathways) and their composition are fixed and can be defined using existing biological pathway databases.

We define the prototypes as $\{\mathbf{a}_{c,\text{g.}}\}_{c=1}^{C_{\text{g.}}}$, where each binary vector $\mathbf{a}_{c,\text{g.}}\in\{0,1\}^{N_{\text{g.}}}$ uses 1 and 0 to indicate, respectively, the presence and absence of a specific gene in pathway $c$.
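Building the binary prototypes from a gene-set collection amounts to a membership lookup. In this sketch the gene identifiers and pathway definitions are hypothetical toy values (the actual Hallmark sets are distributed through MSigDB), and genes missing from the expression assay are simply skipped, which is an assumption. Note that one gene can belong to several pathways.

```python
import numpy as np


def pathway_masks(gene_names: list[str], pathways: dict[str, set[str]]) -> np.ndarray:
    """Binary prototypes a_c in {0,1}^{N_g}: entry j is 1 iff gene j belongs to pathway c."""
    index = {g: j for j, g in enumerate(gene_names)}
    A = np.zeros((len(pathways), len(gene_names)), dtype=np.int8)
    for c, members in enumerate(pathways.values()):
        for g in members:
            if g in index:  # genes absent from the assay are skipped
                A[c, index[g]] = 1
    return A


genes = ["G1", "G2", "G3", "G4", "G5"]  # hypothetical gene identifiers (N_g = 5)
toy_pathways = {                        # hypothetical gene sets (C_g = 2)
    "PATHWAY_A": {"G1", "G3"},
    "PATHWAY_B": {"G2", "G3", "G5"},
}
A_g = pathway_masks(genes, toy_pathways)
# A_g -> [[1, 0, 1, 0, 0], [0, 1, 1, 0, 1]]
```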
Pathway summary. Denoting by $N_{c,\text{g.}}$ the number of genes in pathway $c$, we can construct $\mathbf{z}_{c,\text{g.}}^{\text{agg.}}\in\mathbb{R}^{N_{c,\text{g.}}}$ and the pathway summary $\mathbb{S}_{\text{path.}}=\{\mathbf{z}_{c,\text{g.}}^{\text{agg.}}\}_{c=1}^{C_{\text{g.}}}$ as

$$\mathbf{z}_{c,\text{g.}}^{\text{agg.}}=\phi^{\text{g.}}(\mathbf{x}_{\text{g.}},\mathbf{a}_{c,\text{g.}})=R(\mathbf{x}_{\text{g.}}\odot\mathbf{a}_{c,\text{g.}})\in\mathbb{R}^{N_{c,\text{g.}}}, \tag{3}$$

where $\mathbf{x}_{\text{g.}}\in\mathbb{R}^{N_{\text{g.}}}$ is the vector representation of $\{x_{i,\text{g.}}\}_{i=1}^{N_{\text{g.}}}$, $\odot$ denotes element-wise multiplication, and $R$ densifies the pathway representation by removing the entries masked out by $\mathbf{a}_{c,\text{g.}}$ (Fig. 1B). In our work, $N_{\text{g.}}\simeq 3\times 10^{4}$ and $N_{c,\text{g.}}<200$, achieving a more than $20\times$ reduction.
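As a minimal plain-Python sketch of Eq. 3, the snippet below builds the binary membership vectors $\mathbf{a}_{c,\text{g.}}$ from gene lists and then applies $R(\mathbf{x}_{\text{g.}}\odot\mathbf{a}_{c,\text{g.}})$. The helper names, gene list, and pathway memberships are hypothetical toy values, not the pathway database used in this work. We also assume that $R$ selects entries by the mask rather than by value, so a gene whose expression happens to be zero still keeps its slot and the output length is exactly $N_{c,\text{g.}}$.

```python
from typing import Dict, List, Sequence


def pathway_masks(genes: Sequence[str], pathways: Dict[str, Sequence[str]]) -> Dict[str, List[int]]:
    """Binary vectors a_c in {0,1}^{N_g}: 1 if gene i belongs to pathway c, else 0."""
    masks = {}
    for name, members in pathways.items():
        member_set = set(members)
        masks[name] = [1 if g in member_set else 0 for g in genes]
    return masks


def pathway_summary(x: Sequence[float], mask: Sequence[int]) -> List[float]:
    """z_c = R(x ⊙ a_c): keep the expression values of the genes in pathway c.

    R is implemented by selecting positions where the mask is 1 (assumption), so
    zero-valued expressions inside the pathway are kept and len(z_c) == N_{c,g}.
    """
    if len(x) != len(mask):
        raise ValueError("x and mask must both have length N_g")
    return [xi * ai for xi, ai in zip(x, mask) if ai == 1]


# Toy example: gene names and pathway memberships are made up for illustration.
genes = ["TP53", "MYC", "EGFR", "KRAS", "BRCA1"]
pathways = {"pathway_A": ["TP53", "BRCA1"], "pathway_B": ["MYC", "EGFR", "KRAS"]}
x = [0.5, 2.0, 0.0, 1.5, 3.0]  # one patient's expression vector x_g (N_g = 5)

masks = pathway_masks(genes, pathways)
summary = {name: pathway_summary(x, a_c) for name, a_c in masks.items()}
```

Here `summary["pathway_B"]` keeps the zero expression of the third gene, so its length matches the pathway size rather than the number of nonzero values.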

To summarize, morphological and pathway prototypes are used to extract a slide summary $\mathbb{S}_{\text{slide}}$ and a pathway summary $\mathbb{S}_{\text{path.}}$, respectively. The morphological prototypes are defined in the patch embedding space and yield fixed-length tokens $\mathbf{z}_{c,\text{h.}}^{\text{agg.}}\in\mathbb{R}^{d_{\text{h.}}}$ that encode distinct morphological attributes. In contrast, the pathway prototypes are defined in the raw data space and yield variable-length tokens $\mathbf{z}_{c,\text{g.}}^{\text{agg.}}\in\mathbb{R}^{N_{c,\text{g.}}}$ that encode specific biological pathways. Both approaches are unsupervised and thus do not require patient outcomes.

3.2 Multimodal fusion

3.2.1 Token dimension matching

Prior to multimodal fusion, we first match the dimensions of the tokens from each modality. For histology, we use a linear projection $\mathbf{z}_{c,\text{h.}}^{\text{pre}}=f_{\text{h.}}^{\text{pre}}(\mathbf{z}_{c,\text{h.}}^{\text{agg.}})\in\mathbb{R}^{d}$. For pathways, we use a per-prototype MLP or self-normalizing neural network (SNN) (Klambauer et al., 2017) $f_{c,\text{g.}}^{\text{pre}}$ to map the variable-length representations to a common length, $\mathbf{z}_{c,\text{g.}}^{\text{pre}}=f_{c,\text{g.}}^{\text{pre}}(\mathbf{z}_{c,\text{g.}}^{\text{agg.}})\in\mathbb{R}^{d}$. The parameters of $f_{\text{h.}}^{\text{pre}}$ and $\{f_{c,\text{g.}}^{\text{pre}}\}_{c=1}^{C_{\text{g.}}}$ are learned for each downstream task.
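The sketch below illustrates the shape bookkeeping of token dimension matching: one shared linear map for the morphological tokens and one small MLP per pathway that maps its $N_{c,\text{g.}}$ inputs to a common width $d$. It is plain Python with random, untrained weights; the widths, the hidden size, and the ReLU activation (an SNN would use SELU instead) are illustrative assumptions rather than the settings used in MMP.

```python
import math
import random
from typing import List

Vector = List[float]


class Linear:
    """Dense layer y = W x + b with small random (untrained) weights."""

    def __init__(self, d_in: int, d_out: int, rng: random.Random):
        scale = 1.0 / math.sqrt(d_in)
        self.W = [[rng.uniform(-scale, scale) for _ in range(d_in)] for _ in range(d_out)]
        self.b = [0.0] * d_out

    def __call__(self, x: Vector) -> Vector:
        return [sum(w * xi for w, xi in zip(row, x)) + bj for row, bj in zip(self.W, self.b)]


class PathwayMLP:
    """f_{c,g}^pre: R^{N_{c,g}} -> R^d as Linear -> ReLU -> Linear (an SNN would use SELU)."""

    def __init__(self, n_c: int, d: int, hidden: int, rng: random.Random):
        self.fc1 = Linear(n_c, hidden, rng)
        self.fc2 = Linear(hidden, d, rng)

    def __call__(self, z: Vector) -> Vector:
        return self.fc2([max(0.0, h) for h in self.fc1(z)])


rng = random.Random(0)
d, d_h = 8, 16               # hypothetical fusion width d and histology feature width d_h
pathway_lengths = [3, 5, 2]  # toy N_{c,g}, one entry per pathway prototype

f_hist = Linear(d_h, d, rng)  # one shared f_h^pre for all morphological tokens
f_path = [PathwayMLP(n, d, 16, rng) for n in pathway_lengths]  # one MLP per pathway

z_hist = [[rng.gauss(0, 1) for _ in range(d_h)] for _ in range(4)]  # C_h = 4 toy tokens
z_path = [[rng.gauss(0, 1) for _ in range(n)] for n in pathway_lengths]

tokens_h = [f_hist(z) for z in z_hist]             # each in R^d
tokens_g = [f(z) for f, z in zip(f_path, z_path)]  # each in R^d despite varying N_{c,g}
```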

3.2.2 Multimodal fusion

Inspired by multimodal early fusion methods, we learn dense intra- and cross-modal interactions between the histology and pathway tokens. We explore two strategies: Transformer attention and OT cross-alignment (Fig. 1C).
Transformer attention. We introduce three learnable query, key, and value matrices $\mathbf{W}_Q,\mathbf{W}_K,\mathbf{W}_V\in\mathbb{R}^{d\times d}$. Let $\mathbf{Z}_{\text{g.}}^{\text{pre}}\in\mathbb{R}^{C_{\text{g.}}\times d}$ and $\mathbf{Z}_{\text{h.}}^{\text{pre}}\in\mathbb{R}^{C_{\text{h.}}\times d}$ stack the tokens $\mathbf{z}_{c,\text{g.}}^{\text{pre}}$ and $\mathbf{z}_{c,\text{h.}}^{\text{pre}}$ as rows. Denoting $\mathbf{Q}=(\mathbf{Q}_{\text{g.}}^{\text{T}}\;\mathbf{Q}_{\text{h.}}^{\text{T}})^{\text{T}}=\big(\mathbf{Z}_{\text{g.}}^{\text{pre},\text{T}}\;\mathbf{Z}_{\text{h.}}^{\text{pre},\text{T}}\big)^{\text{T}}\mathbf{W}_Q\in\mathbb{R}^{(C_{\text{g.}}+C_{\text{h.}})\times d}$, and likewise for $\mathbf{K}$ and $\mathbf{V}$, we can define the standard Transformer attention (Vaswani et al., 2017; Xu et al., 2023)

$$\begin{aligned}\mathbf{Z}^{\text{post}}_{\text{g.}+\text{h.}}&=\begin{pmatrix}\mathbf{Z}^{\text{post}}_{\text{g.}}\\ \mathbf{Z}^{\text{post}}_{\text{h.}}\end{pmatrix}=\sigma\left(\frac{\mathbf{Q}\mathbf{K}^{\text{T}}}{\sqrt{d}}\right)\mathbf{V}\in\mathbb{R}^{(C_{\text{g.}}+C_{\text{h.}})\times d}\\ &=\sigma\left(\frac{1}{\sqrt{d}}\begin{pmatrix}\mathbf{Q}_{\text{g.}}\mathbf{K}_{\text{g.}}^{\text{T}} & \mathbf{Q}_{\text{g.}}\mathbf{K}_{\text{h.}}^{\text{T}}\\ \mathbf{Q}_{\text{h.}}\mathbf{K}_{\text{g.}}^{\text{T}} & \mathbf{Q}_{\text{h.}}\mathbf{K}_{\text{h.}}^{\text{T}}\end{pmatrix}\right)\begin{pmatrix}\mathbf{V}_{\text{g.}}\\ \mathbf{V}_{\text{h.}}\end{pmatrix},\end{aligned}\tag{4}$$

where $\sigma(\cdot)$ denotes the row-wise softmax. Eq. 4 shows how multimodal attention decomposes into intra-modal self-attention ($\text{g.}\rightarrow\text{g.}$, $\text{h.}\rightarrow\text{h.}$) and cross-modal cross-attention ($\text{g.}\rightarrow\text{h.}$, $\text{h.}\rightarrow\text{g.}$). In MMP, the complexity of computing attention is reduced to $\mathcal{O}((C_{\text{g.}}+C_{\text{h.}})^{2})$, a considerable reduction from the $\mathcal{O}((N_{\text{g.}}+N_{\text{h.}})^{2})$ of most multimodal fusion methods that do not use prototypes.
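To make the block structure of Eq. 4 concrete, the following plain-Python sketch runs single-head attention over the stacked pathway and histology tokens and then reads off two of the blocks. The random weights, the token counts, and the single-head setup are assumptions for illustration only; note that the attention matrix has $(C_{\text{g.}}+C_{\text{h.}})^2$ entries, which is what keeps the computation cheap when both counts are small.

```python
import math
import random
from typing import List, Tuple

Matrix = List[List[float]]


def matmul(A: Matrix, B: Matrix) -> Matrix:
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(A: Matrix) -> Matrix:
    return [list(col) for col in zip(*A)]


def softmax_rows(S: Matrix) -> Matrix:
    """Row-wise softmax sigma(.), shifted by the row max for numerical stability."""
    out = []
    for row in S:
        m = max(row)
        e = [math.exp(s - m) for s in row]
        total = sum(e)
        out.append([v / total for v in e])
    return out


def multimodal_attention(Z_g: Matrix, Z_h: Matrix, W_Q: Matrix, W_K: Matrix,
                         W_V: Matrix) -> Tuple[Matrix, Matrix]:
    """Single-head attention over the stacked tokens (Z_g; Z_h), as in Eq. 4."""
    Z = Z_g + Z_h  # list concatenation stacks the rows: (C_g + C_h) x d
    Q, K, V = matmul(Z, W_Q), matmul(Z, W_K), matmul(Z, W_V)
    d = len(W_Q[0])
    scores = [[s / math.sqrt(d) for s in row] for row in matmul(Q, transpose(K))]
    A = softmax_rows(scores)
    return matmul(A, V), A


def rand_matrix(rng: random.Random, rows: int, cols: int) -> Matrix:
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
C_g, C_h, d = 3, 4, 5
Z_g, Z_h = rand_matrix(rng, C_g, d), rand_matrix(rng, C_h, d)
W_Q, W_K, W_V = (rand_matrix(rng, d, d) for _ in range(3))
Z_post, A = multimodal_attention(Z_g, Z_h, W_Q, W_K, W_V)

# Blocks of A, following the layout of Eq. 4.
A_gg = [row[:C_g] for row in A[:C_g]]  # Q_g K_g^T block (intra-modal)
A_gh = [row[C_g:] for row in A[:C_g]]  # Q_g K_h^T block (cross-modal)
```

Because the softmax is taken over the full row, each pathway query distributes its attention jointly over pathway and histology keys, so `A_gg` and `A_gh` rows together sum to one.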

Optimal Transport cross-alignment. Modeling cross-modal interactions can also be approached from the point of view of OT, where we aim to learn the transport plan $\mathbf{T}\in\mathbb{R}_{+}^{C_{\text{g.}}\times C_{\text{h.}}}$ with the minimal total cost between the empirical distributions $\hat{p}(\mathbf{z}_{\text{g.}}^{\text{pre}})=\frac{1}{C_{\text{g.}}}\sum_{c=1}^{C_{\text{g.}}}\delta(\mathbf{z}_{c,\text{g.}}^{\text{pre}})$ and $\hat{p}(\mathbf{z}_{\text{h.}}^{\text{pre}})=\frac{1}{C_{\text{h.}}}\sum_{c'=1}^{C_{\text{h.}}}\delta(\mathbf{z}_{c',\text{h.}}^{\text{pre}})$, where $\delta(\cdot)$ is the Dirac delta function. The pairwise cost $\mathbf{D}_{c,c'}$ between two tokens is typically computed as the $L_2$ distance or the negative dot product. The estimate $\widehat{\mathbf{T}}$ is given as the solution to the entropic-regularized OT problem (Kolouri et al., 2017),

$$\begin{aligned}&\min_{\mathbf{T}}\sum_{c,c'}\big(\mathbf{D}_{c,c'}\mathbf{T}_{c,c'}+\varepsilon\,\mathbf{T}_{c,c'}\log\mathbf{T}_{c,c'}\big)\\ &\;\text{s.t.}\;\sum_{c=1}^{C_{\text{g.}}}\mathbf{T}_{c,c'}=1/C_{\text{h.}}\quad\text{and}\quad\sum_{c'=1}^{C_{\text{h.}}}\mathbf{T}_{c,c'}=1/C_{\text{g.}},\end{aligned}\tag{5}$$

where $\varepsilon$ is the regularization parameter. The optimal plan $\widehat{\mathbf{T}}$ can be obtained with the Sinkhorn algorithm (Cuturi, 2013), which is differentiable (Genevay et al., 2018). This enables joint learning of $f_{\text{h.}}^{\text{pre}}$ and $\{f_{c,\text{g.}}^{\text{pre}}\}_{c=1}^{C_{\text{g.}}}$ along with the plan $\widehat{\mathbf{T}}$. Cross-alignment with $\widehat{\mathbf{T}}$, i.e., $\widehat{\mathbf{T}}\mathbf{Z}_{\text{h.}}^{\text{pre}}$, performs $\text{h.}\rightarrow\text{g.}$ attention, while $\widehat{\mathbf{T}}^{\text{T}}\mathbf{Z}_{\text{g.}}^{\text{pre}}$ performs $\text{g.}\rightarrow\text{h.}$ attention. After the alignment, we learn intra-modal interactions using Transformer self-attention. Denoting $\mathbf{Q}_{\text{g.}}=(\widehat{\mathbf{T}}\mathbf{Z}_{\text{h.}}^{\text{pre}})\mathbf{W}_Q$ and $\mathbf{Q}_{\text{h.}}=(\widehat{\mathbf{T}}^{\text{T}}\mathbf{Z}_{\text{g.}}^{\text{pre}})\mathbf{W}_Q$, and likewise for $\mathbf{K}_{\text{g.}},\mathbf{V}_{\text{g.}},\mathbf{K}_{\text{h.}},\mathbf{V}_{\text{h.}}$, we obtain

$$\mathbf{Z}_{\text{h.}}^{\text{post}}=\sigma\left(\frac{\mathbf{Q}_{\text{h.}}\mathbf{K}_{\text{h.}}^{\text{T}}}{\sqrt{d}}\right)\mathbf{V}_{\text{h.}},\quad\mathbf{Z}_{\text{g.}}^{\text{post}}=\sigma\left(\frac{\mathbf{Q}_{\text{g.}}\mathbf{K}_{\text{g.}}^{\text{T}}}{\sqrt{d}}\right)\mathbf{V}_{\text{g.}}.\tag{6}$$
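A minimal plain-Python sketch of Eq. 5 follows: Sinkhorn iterations (Cuturi, 2013) on the Gibbs kernel $\exp(-\mathbf{D}/\varepsilon)$ with uniform marginals, followed by the two cross-alignments $\widehat{\mathbf{T}}\mathbf{Z}_{\text{h.}}^{\text{pre}}$ and $\widehat{\mathbf{T}}^{\text{T}}\mathbf{Z}_{\text{g.}}^{\text{pre}}$. The negative dot-product cost, $\varepsilon=1$, the fixed iteration count, and the toy token values are assumptions; a trainable version would run the same iterations in an autodiff framework (preferably in the log domain when $\varepsilon$ is small) so that gradients reach $f^{\text{pre}}$.

```python
import math
import random
from typing import List

Matrix = List[List[float]]


def sinkhorn(D: Matrix, eps: float, n_iters: int = 200) -> Matrix:
    """Entropic OT plan between uniform marginals (Eq. 5), via Sinkhorn scaling."""
    C_g, C_h = len(D), len(D[0])
    r = [1.0 / C_g] * C_g  # target: each row of T sums to 1/C_g
    c = [1.0 / C_h] * C_h  # target: each column of T sums to 1/C_h
    K = [[math.exp(-D[i][j] / eps) for j in range(C_h)] for i in range(C_g)]  # Gibbs kernel
    u, v = [1.0] * C_g, [1.0] * C_h
    for _ in range(n_iters):
        u = [r[i] / sum(K[i][j] * v[j] for j in range(C_h)) for i in range(C_g)]
        v = [c[j] / sum(K[i][j] * u[i] for i in range(C_g)) for j in range(C_h)]
    return [[u[i] * K[i][j] * v[j] for j in range(C_h)] for i in range(C_g)]


def matmul(A: Matrix, B: Matrix) -> Matrix:
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(A: Matrix) -> Matrix:
    return [list(col) for col in zip(*A)]


rng = random.Random(0)
C_g, C_h, d = 3, 4, 4
Z_g = [[rng.gauss(0, 0.5) for _ in range(d)] for _ in range(C_g)]  # toy Z_g^pre
Z_h = [[rng.gauss(0, 0.5) for _ in range(d)] for _ in range(C_h)]  # toy Z_h^pre

# Pairwise cost: negative dot product between pathway token c and histology token c'.
D = [[-sum(a * b for a, b in zip(zg, zh)) for zh in Z_h] for zg in Z_g]
T = sinkhorn(D, eps=1.0, n_iters=500)

Z_h_to_g = matmul(T, Z_h)             # T Z_h: (C_g x d), h. -> g. alignment
Z_g_to_h = matmul(transpose(T), Z_g)  # T^T Z_g: (C_h x d), g. -> h. alignment
```

Since each row of $\widehat{\mathbf{T}}$ sums to $1/C_{\text{g.}}$, each row of `Z_h_to_g` is a weighted average of the histology tokens scaled by $1/C_{\text{g.}}$.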

3.2.3 Connection between Transformer and optimal transport cross-alignment

Transformer cross-attention and OT cross-alignment model the attention matrix and the transport plan in similar ways. We formalize this similarity to make the connection explicit: under certain conditions, the OT plan is equivalent to the Transformer cross-attention matrix up to a multiplicative factor.

Lemma 3.1.

Let $\mathbf{Z}_{\text{g.}}\in\mathbb{R}^{C_{\text{g.}}\times d}$ and $\mathbf{Z}_{\text{h.}}\in\mathbb{R}^{C_{\text{h.}}\times d}$ be the matrix representations of the token sets $\{\mathbf{z}_{i,\text{g.}}\}_{i=1}^{C_{\text{g.}}}$ and $\{\mathbf{z}_{k,\text{h.}}\}_{k=1}^{C_{\text{h.}}}$. Let $\mathbf{Z}_{\text{g.}}\mathbf{W}_Q^{\text{T}}\in\mathbb{R}^{C_{\text{g.}}\times d}$ and $\mathbf{Z}_{\text{h.}}\mathbf{W}^{\text{T}}\in\mathbb{R}^{C_{\text{h.}}\times d}$ be the linear projections of both sets. Let $\widehat{\mathbf{T}}\in\mathbb{R}^{C_{\text{g.}}\times C_{\text{h.}}}_{+}$ be the optimal transport plan, i.e., the solution to the entropic-regularized, unbalanced optimal transport problem between the two projected sets. Then $\widehat{\mathbf{T}}$ is equivalent, up to a multiplicative factor, to the Transformer cross-attention matrix $\sigma(\mathbf{Z}_{\text{g.}}\mathbf{W}_Q^{\text{T}}\mathbf{W}\mathbf{Z}_{\text{h.}}^{\text{T}}/\sqrt{d})$, where $\sigma(\cdot)$ denotes the row-wise softmax, $\{\mathbf{W}_Q\mathbf{z}_{i,\text{g.}}\}_{i=1}^{C_{\text{g.}}}$ are the queries, and $\{\mathbf{W}\mathbf{z}_{k,\text{h.}}\}_{k=1}^{C_{\text{h.}}}$ are the keys.

Proof.

The detailed derivation can be found in Appendix B. ∎
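One way to see Lemma 3.1 is the short derivation below. It is a sketch under our own reading that "unbalanced" means only the row-marginal constraint of Eq. 5 is kept (the derivation in Appendix B may proceed differently), and it takes the cost to be the negative query-key dot product with $\varepsilon=\sqrt{d}$. Setting the derivative of the Lagrangian to zero yields a row-normalized Gibbs kernel, which is exactly a scaled row-wise softmax:

```latex
\begin{aligned}
&\min_{\mathbf{T}\ge 0}\ \sum_{i,k}\Big(\mathbf{D}_{i,k}\mathbf{T}_{i,k}+\varepsilon\,\mathbf{T}_{i,k}\log\mathbf{T}_{i,k}\Big)
\quad\text{s.t.}\quad \sum_{k=1}^{C_{\text{h.}}}\mathbf{T}_{i,k}=\frac{1}{C_{\text{g.}}}\quad\forall i,\\
&\text{stationarity: }\ \mathbf{D}_{i,k}+\varepsilon\big(\log\mathbf{T}_{i,k}+1\big)-\lambda_i=0
\ \Rightarrow\ \mathbf{T}_{i,k}=e^{\lambda_i/\varepsilon-1}\,e^{-\mathbf{D}_{i,k}/\varepsilon},\\
&\text{row constraint: }\ \widehat{\mathbf{T}}_{i,k}=\frac{1}{C_{\text{g.}}}\,
\frac{e^{-\mathbf{D}_{i,k}/\varepsilon}}{\sum_{k'}e^{-\mathbf{D}_{i,k'}/\varepsilon}},\\
&\text{with } \mathbf{D}_{i,k}=-(\mathbf{W}_Q\mathbf{z}_{i,\text{g.}})^{\text{T}}(\mathbf{W}\mathbf{z}_{k,\text{h.}}),\ \varepsilon=\sqrt{d}:\quad
\widehat{\mathbf{T}}=\frac{1}{C_{\text{g.}}}\,\sigma\!\left(\frac{\mathbf{Z}_{\text{g.}}\mathbf{W}_Q^{\text{T}}\mathbf{W}\mathbf{Z}_{\text{h.}}^{\text{T}}}{\sqrt{d}}\right).
\end{aligned}
```

Under these assumptions the multiplicative factor is $1/C_{\text{g.}}$. The same expression is what a single row-normalization step of Sinkhorn produces from the Gibbs kernel; enforcing the column marginal of Eq. 5 as well is what the remaining Sinkhorn iterations add.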

This result allows MMP to treat both approaches within a single framework rather than as fundamentally distinct mechanisms, and it provides a common footing for developing future multimodal fusion strategies.

3.3 Survival prediction

The post-attention embeddings pass through a post-attention feed-forward network $f^{\text{post}}$ with layer normalization (LN), are averaged within each modality, and are concatenated to form a patient-level embedding $\mathbf{z}_{\text{patient}}=\Big[\sum_{c=1}^{C_{\text{g.}}}\operatorname{LN}(f^{\text{post}}(\mathbf{z}_{c,\text{g.}}^{\text{post}})),\sum_{c=1}^{C_{\text{h.}}}\operatorname{LN}(f^{\text{post}}(\mathbf{z}_{c,\text{h.}}^{\text{post}}))\Big]$. The resulting embedding is fed to a linear predictor $f_{\text{pred.}}$ for patient-level risk prediction.
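The readout can be sketched as follows: apply $f^{\text{post}}$ and LN to every post-attention token, pool within each modality as in the formula above, concatenate, and apply a linear risk head. In this plain-Python sketch, $f^{\text{post}}$ is a hypothetical elementwise tanh standing in for the learned feed-forward network, LN has no learnable gain or bias, and all weights are random.

```python
import math
import random
from typing import Callable, List

Vector = List[float]


def layer_norm(x: Vector, eps: float = 1e-5) -> Vector:
    """LayerNorm without learnable gain and bias (kept minimal for the sketch)."""
    mu = sum(x) / len(x)
    var = sum((xi - mu) ** 2 for xi in x) / len(x)
    return [(xi - mu) / math.sqrt(var + eps) for xi in x]


def patient_embedding(Z_g_post: List[Vector], Z_h_post: List[Vector],
                      f_post: Callable[[Vector], Vector]) -> Vector:
    """z_patient = [sum_c LN(f_post(z_{c,g})), sum_c LN(f_post(z_{c,h}))]."""
    def pool(tokens: List[Vector]) -> Vector:
        normed = [layer_norm(f_post(z)) for z in tokens]
        return [sum(col) for col in zip(*normed)]
    return pool(Z_g_post) + pool(Z_h_post)  # list concatenation: length 2d


def f_post(z: Vector) -> Vector:
    """Hypothetical stand-in for the learned post-attention FFN."""
    return [math.tanh(v) for v in z]


rng = random.Random(0)
C_g, C_h, d = 3, 4, 5
Z_g_post = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(C_g)]
Z_h_post = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(C_h)]

z_patient = patient_embedding(Z_g_post, Z_h_post, f_post)
w_pred = [rng.gauss(0, 0.1) for _ in range(len(z_patient))]  # f_pred weights (random)
risk = sum(w * z for w, z in zip(w_pred, z_patient))         # scalar patient-level risk
```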

We use the Cox proportional hazards loss (Cox, 1972; Katzman et al., 2018; Carmichael et al., 2022), which requires training with batches of patients, since the loss compares the predicted risks of the patients within a batch. Because $N_{\text{h.}}$ is large, forming such batches is computationally challenging for non-prototype approaches. Appendix C provides a detailed explanation of the losses.
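For reference, the standard Cox negative partial log-likelihood over a batch fits in a few lines. This is a generic sketch, not necessarily the exact variant described in Appendix C: it uses Breslow-style risk sets $\{j: t_j\ge t_i\}$ without tie correction and averages over the uncensored patients. It also shows why batches are needed: each uncensored patient's term compares its risk with those of the other patients in the batch, so a batch containing a single patient always yields zero loss.

```python
import math
from typing import Sequence


def cox_ph_loss(risks: Sequence[float], times: Sequence[float], events: Sequence[int]) -> float:
    """Negative Cox partial log-likelihood, averaged over uncensored patients in the batch.

    risks:  predicted log-risk scores (higher means worse prognosis)
    times:  observed survival or censoring times
    events: 1 if the event was observed, 0 if the patient is censored
    Ties use the simple Breslow risk set {j : t_j >= t_i}.
    """
    n_events = sum(events)
    if n_events == 0:
        raise ValueError("batch contains no uncensored patients; loss is undefined")
    loss = 0.0
    for i in range(len(risks)):
        if events[i]:
            risk_set = [risks[j] for j in range(len(risks)) if times[j] >= times[i]]
            m = max(risk_set)  # log-sum-exp with max shift for stability
            log_denominator = m + math.log(sum(math.exp(r - m) for r in risk_set))
            loss -= risks[i] - log_denominator
    return loss / n_events
```

For example, ranking early deaths as higher risk, `cox_ph_loss([2.0, 1.0, 0.0], [1.0, 2.0, 3.0], [1, 1, 1])`, gives a lower loss than the reversed ranking `[0.0, 1.0, 2.0]`.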

3.4 Enhancing prototypes

Given that the identity of each prototype remains consistent across patients (e.g., prototype $c$ consistently represents the same morphological concept or pathway), we can exploit this property in the model design. Specifically, we incorporate (1) a prototype-specific encoding and (2) a prototype-specific post-attention feed-forward network.

Prototype encoding, denoted as $\mathbf{e}_c$, is related to modality-specific encodings (Jaegle et al., 2022; Liang et al., 2023). Specifically, we append the encodings to the embeddings before feeding them to the fusion network. We experiment with two approaches: 1) a fixed one-hot encoding $\mathbf{e}_c\in\{0,1\}^{d_e}$ with $d_e=C_{\text{g.}}+C_{\text{h.}}$, and 2) a randomly initialized, learnable embedding $\mathbf{e}_c\in\mathbb{R}^{d_e}$ with $d_e=32$. The modified embeddings are then given as $\mathbf{z}_{c,\text{h.}}^{\text{pre}}\leftarrow[\mathbf{z}_{c,\text{h.}}^{\text{pre}},\mathbf{e}_c]\in\mathbb{R}^{d+d_e}$, and likewise for $\mathbf{z}_{c,\text{g.}}^{\text{pre}}$.
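Both prototype-encoding variants amount to concatenating a per-prototype vector to each token. The plain-Python sketch below illustrates this; the ordering of the one-hot slots (pathway prototypes first, then morphological prototypes), the initialization scale of the learnable variant, and the helper names are assumptions, and in practice the learnable vectors would be updated by the optimizer.

```python
import random
from typing import List, Sequence

Vector = List[float]


def one_hot_encodings(C_g: int, C_h: int) -> List[Vector]:
    """Fixed e_c in {0,1}^{d_e} with d_e = C_g + C_h (pathway slots first, by assumption)."""
    d_e = C_g + C_h
    return [[1.0 if j == c else 0.0 for j in range(d_e)] for c in range(d_e)]


def learnable_encodings(n_prototypes: int, d_e: int = 32, seed: int = 0) -> List[Vector]:
    """Randomly initialized e_c in R^{d_e}; an optimizer would update these (hypothetical init)."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 0.02) for _ in range(d_e)] for _ in range(n_prototypes)]


def append_encoding(tokens: Sequence[Vector], encodings: Sequence[Vector]) -> List[Vector]:
    """z_c <- [z_c, e_c], giving tokens of width d + d_e."""
    if len(tokens) != len(encodings):
        raise ValueError("need exactly one encoding per prototype token")
    return [list(z) + list(e) for z, e in zip(tokens, encodings)]


rng = random.Random(0)
C_g, C_h, d = 3, 2, 4
tokens_g = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(C_g)]
tokens_h = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(C_h)]

onehot = one_hot_encodings(C_g, C_h)
tokens_g_enc = append_encoding(tokens_g, onehot[:C_g])  # pathway tokens use slots 0..C_g-1
tokens_h_enc = append_encoding(tokens_h, onehot[C_g:])  # histology tokens use the remaining slots

learned = learnable_encodings(C_g + C_h)
tokens_h_learned = append_encoding(tokens_h, learned[C_g:])
```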

We also apply a prototype-specific feed-forward network (FFN) $f_c^{\text{post}}$ to the post-attention embeddings to learn additional nonlinearity per prototype. This differs from previous works that share a single $f^{\text{post}}$ across all tokens, which might limit expressivity. These components cannot be used in non-prototype frameworks, since 1) patch identity is not preserved across patients and 2) the large $N_{\text{h.}}$ makes the use of $\mathbf{e}_c$ and $f_c^{\text{post}}$ infeasible.
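The difference between a shared $f^{\text{post}}$ and prototype-specific $f_c^{\text{post}}$ is mostly bookkeeping, as the plain-Python sketch below suggests: the per-prototype variant holds one parameter set per token position, so its parameter count grows linearly with the number of tokens. That is affordable with a small, fixed number of prototypes, but not with $N_{\text{h.}}$ patch tokens whose identities are not aligned across patients. The bias-free two-layer ReLU architecture and the widths are illustrative assumptions.

```python
import math
import random
from typing import List

Vector = List[float]


class FFN:
    """Bias-free two-layer block x -> W2 ReLU(W1 x) (illustrative architecture)."""

    def __init__(self, d: int, hidden: int, rng: random.Random):
        s1, s2 = 1.0 / math.sqrt(d), 1.0 / math.sqrt(hidden)
        self.W1 = [[rng.uniform(-s1, s1) for _ in range(d)] for _ in range(hidden)]
        self.W2 = [[rng.uniform(-s2, s2) for _ in range(hidden)] for _ in range(d)]

    def __call__(self, x: Vector) -> Vector:
        h = [max(0.0, sum(w * xi for w, xi in zip(row, x))) for row in self.W1]
        return [sum(w * hi for w, hi in zip(row, h)) for row in self.W2]

    def n_params(self) -> int:
        return len(self.W1) * len(self.W1[0]) + len(self.W2) * len(self.W2[0])


rng = random.Random(0)
C, d, hidden = 4, 6, 12  # C post-attention tokens, one per prototype
tokens = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(C)]

shared = FFN(d, hidden, rng)
out_shared = [shared(z) for z in tokens]  # a single f^post applied to every token

per_prototype = [FFN(d, hidden, rng) for _ in range(C)]
out_specific = [f_c(z) for f_c, z in zip(per_prototype, tokens)]  # f_c^post for token c
```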

Table 1: Survival prediction results of MMP and other baselines for disease-specific survival, measured with the C-Index. The clinical baseline includes age, sex, and cancer grade as reported in the TCGA cohort. All methods use the same histology feature encoder, UNI, a ViT-L/16 model pretrained on an internal histology dataset (Chen et al., 2024). All histology prototype-based methods share the same set of morphological prototypes with $C_{\text{h.}}=16$. Standard deviations are reported over five runs. m.p. and p.p. denote morphological prototype and pathway prototype, respectively. The best and second-best performances are denoted by bold and underlined, respectively.
| Modality | Model | m.p. | p.p. | BRCA | BLCA | LUAD | STAD | CRC | KIRC | Avg. (↑) |
|---|---|---|---|---|---|---|---|---|---|---|
| Clinical | Clinical | | | 0.563 ± 0.055 | 0.570 ± 0.033 | 0.528 ± 0.028 | 0.592 ± 0.044 | 0.655 ± 0.119 | 0.602 ± 0.066 | 0.585 |
| Transcriptomics | Gene exp. | | | 0.638 ± 0.090 | 0.627 ± 0.055 | 0.577 ± 0.057 | 0.562 ± 0.083 | 0.588 ± 0.105 | 0.681 ± 0.072 | 0.612 |
| Transcriptomics | Pathways | | ✓ | 0.615 ± 0.054 | 0.606 ± 0.084 | 0.626 ± 0.077 | 0.566 ± 0.080 | 0.590 ± 0.104 | 0.681 ± 0.090 | 0.614 |
| Histology | ABMIL | | | 0.570 ± 0.086 | 0.550 ± 0.039 | 0.571 ± 0.036 | 0.559 ± 0.059 | 0.660 ± 0.096 | 0.684 ± 0.115 | 0.599 |
| Histology | TransMIL | | | 0.601 ± 0.110 | 0.584 ± 0.057 | 0.547 ± 0.054 | 0.487 ± 0.057 | 0.555 ± 0.059 | 0.678 ± 0.191 | 0.575 |
| Histology | AttnMISL | ✓ | | 0.599 ± 0.117 | 0.493 ± 0.064 | 0.627 ± 0.076 | 0.533 ± 0.040 | 0.728 ± 0.110 | 0.648 ± 0.102 | 0.605 |
| Histology | IB-MIL | | | 0.511 ± 0.068 | 0.524 ± 0.051 | 0.578 ± 0.067 | 0.525 ± 0.061 | 0.576 ± 0.129 | 0.702 ± 0.081 | 0.569 |
| Histology | ILRA | | | 0.597 ± 0.124 | 0.581 ± 0.055 | 0.511 ± 0.077 | 0.550 ± 0.094 | 0.643 ± 0.124 | 0.651 ± 0.164 | 0.589 |
| Histology | MMP | ✓ | | 0.669 ± 0.119 | 0.593 ± 0.062 | 0.600 ± 0.039 | 0.488 ± 0.093 | 0.646 ± 0.111 | 0.701 ± 0.177 | 0.611 |
| Multimodal | MCAT | | ✓ | 0.648 ± 0.100 | 0.619 ± 0.048 | 0.615 ± 0.072 | 0.528 ± 0.114 | 0.578 ± 0.136 | 0.670 ± 0.235 | 0.610 |
| Multimodal | SurvPath | | ✓ | 0.709 ± 0.062 | 0.619 ± 0.052 | 0.612 ± 0.060 | 0.556 ± 0.136 | 0.539 ± 0.150 | 0.738 ± 0.131 | 0.629 |
| Multimodal | MOTCat | | ✓ | 0.717 ± 0.029 | 0.622 ± 0.064 | 0.589 ± 0.059 | 0.561 ± 0.075 | 0.590 ± 0.130 | 0.708 ± 0.104 | 0.631 |
| Multimodal | CMTA | | ✓ | 0.687 ± 0.077 | 0.605 ± 0.076 | 0.622 ± 0.059 | 0.547 ± 0.088 | 0.559 ± 0.195 | 0.720 ± 0.124 | 0.623 |
| Multimodal | MMP$_{\text{OT}}$ | ✓ | ✓ | 0.753 ± 0.069 | 0.628 ± 0.064 | 0.643 ± 0.013 | 0.580 ± 0.071 | 0.636 ± 0.120 | 0.748 ± 0.099 | 0.665 |
| Multimodal | MMP$_{\text{Trans.}}$ | ✓ | ✓ | 0.738 ± 0.069 | 0.635 ± 0.051 | 0.642 ± 0.037 | 0.598 ± 0.051 | 0.630 ± 0.125 | 0.747 ± 0.106 | 0.665 |

4 Experiments

4.1 Datasets

We use publicly available data from The Cancer Genome Atlas (TCGA) to evaluate MMP across six cancer types:

* bladder urothelial carcinoma (BLCA, $n = 359$);
* breast invasive carcinoma (BRCA, $n = 868$);
* lung adenocarcinoma (LUAD, $n = 412$);
* stomach adenocarcinoma (STAD, $n = 318$);
* colon and rectum adenocarcinoma (CRC, $n = 296$);
* kidney renal clear cell carcinoma (KIRC, $n = 340$).

We train the models to predict risks for disease-specific survival (DSS) (Liu et al., 2018). Following standard practice, we use 5-fold site-stratified cross-validation to mitigate batch effects (Howard et al., 2021). We evaluate MMP with the concordance index (C-Index). The C-Index measures how well the ordering of patients by survival time agrees with their ordering by predicted risk.
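The C-Index can be sketched in a few lines of Python following Harrell's definition (Harrell et al., 1982). Among comparable pairs, it is the fraction in which the patient with the shorter observed event time received the higher predicted risk. This is an illustrative sketch, not the evaluation code used here. Skipping pairs with tied survival times is a simplification, and established survival libraries treat ties more carefully.

```python
from itertools import combinations

def harrell_c_index(times, events, risks):
    """Harrell's concordance index for right-censored data.

    times  -- observed survival or censoring time per patient
    events -- 1 if the event (e.g. disease-specific death) was observed, 0 if censored
    risks  -- predicted risk; higher risk should mean shorter survival
    """
    concordant, comparable = 0.0, 0
    for i, j in combinations(range(len(times)), 2):
        if times[i] == times[j]:
            continue  # tied survival times are skipped (simplification)
        a, b = (i, j) if times[i] < times[j] else (j, i)  # a has the shorter observed time
        if not events[a]:
            continue  # a was censored first, so we cannot tell who fails first
        comparable += 1
        if risks[a] > risks[b]:
            concordant += 1.0
        elif risks[a] == risks[b]:
            concordant += 0.5  # tied predictions count as half-concordant
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable

times = [5, 10, 15, 20]
events = [1, 1, 0, 1]
risks = [0.9, 0.6, 0.3, 0.1]
print(harrell_c_index(times, events, risks))  # 1.0
```

In this toy cohort the predicted risks order every comparable pair correctly, which gives a C-Index of 1.0. Random predictions give about 0.5.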

Bulk RNA sequencing expression for all TCGA cohorts was accessed through the UCSC Xena database (Goldman et al., 2020), as log2-transformed transcripts per million (TPM). We use the $C_{\text{g.}} = 50$ Hallmark gene sets from the Molecular Signatures Database (MSigDB) (Subramanian et al., 2005; Liberzon et al., 2015) to select and organize genes into biological pathways. Hallmark gene sets (pathways) represent well-defined biological states in cancer. Organizing genes into Hallmark pathways yields 4,241 unique genes across the 50 pathways. The smallest pathway contains 31 genes and the largest 199. More dataset details can be found in Appendix D.
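The grouping step can be sketched as follows. Each pathway's input vector is built from the expression values of its member genes. A gene that belongs to several pathways is counted once in the unique-gene total. This is a minimal sketch: the pathway names and values are hypothetical toy data, and dropping genes missing from the expression table is an assumption.

```python
def organize_by_pathway(expression, gene_sets):
    """Split one patient's expression profile into per-pathway input vectors.

    expression -- dict: gene symbol -> log2-transformed TPM value
    gene_sets  -- dict: pathway name -> list of member gene symbols
    Genes missing from the expression table are skipped (an assumption).
    """
    pathway_inputs = {
        name: [expression[g] for g in genes if g in expression]
        for name, genes in gene_sets.items()
    }
    # a gene can belong to several pathways, so the number of unique genes
    # is smaller than the sum of pathway sizes
    unique_genes = sorted({g for genes in gene_sets.values() for g in genes if g in expression})
    sizes = {name: len(v) for name, v in pathway_inputs.items()}
    return pathway_inputs, unique_genes, sizes

# Toy data with hypothetical pathway names and values
gene_sets = {"PATHWAY_A": ["TP53", "MYC", "EGFR"], "PATHWAY_B": ["MYC", "KRAS"]}
expression = {"TP53": 2.0, "MYC": 5.5, "KRAS": 1.2}
inputs, genes, sizes = organize_by_pathway(expression, gene_sets)
```

Applied to the 50 Hallmark sets, the same bookkeeping produces the unique-gene count and the minimum and maximum pathway sizes reported above.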

4.2 Baselines

Histology. We employ Attention-based MIL (ABMIL) (Ilse et al., 2018), ABMIL with information bottleneck (ABMIL-IB) (Li et al., 2023), Transformer-based MIL (TransMIL) (Shao et al., 2021), low-rank MIL (ILRA) (Xiang & Zhang, 2023), and prototype-based MIL (AttnMISL) (Yao et al., 2020). We also use the unimodal version of MMP.
Transcriptomics. We employ a feed-forward neural network on the gene expression (2-layer MLP, non-prototype; "Gene exp." in Table 1). We also use a baseline that encodes each pathway with a pathway-specific self-normalizing network (SNN) (Klambauer et al., 2017; Jaume et al., 2024; Zhang et al., 2024) and concatenates the resulting embeddings ("Pathways" in Table 1).
Multimodal. We use MCAT (Chen et al., 2021), SurvPath (Jaume et al., 2024), MOTCat (Xu & Chen, 2023), and CMTA (Zhou & Chen, 2023). Each uses multimodal tokenization to derive histology and omics tokens (pathways in our evaluation), followed by a co-attention Transformer.
MMP variants. We test MMP with Transformer cross-attention (MMP$_{\text{Trans.}}$) and with OT cross-alignment (MMP$_{\text{OT}}$). The rest of the model comprises GMM histology aggregation with $C_{\text{h.}} = 16$, learnable random prototype encoding, and prototype-specific feedforward networks.
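The "Pathways" transcriptomics baseline encodes each pathway with its own SNN and concatenates the outputs. A minimal pure-Python sketch of that structure follows. The two-layer depth, embedding size, missing biases, and pathway names are hypothetical choices, not the configuration used in the experiments. The SELU constants and LeCun-normal initialization follow Klambauer et al. (2017).

```python
import math
import random

SELU_ALPHA = 1.6732632423543772
SELU_SCALE = 1.0507009873554805

def selu(x):
    return SELU_SCALE * (x if x > 0 else SELU_ALPHA * (math.exp(x) - 1.0))

def init_layer(n_in, n_out, rng):
    # LeCun-normal initialisation (std = 1/sqrt(n_in)), as recommended for SELU networks
    std = 1.0 / math.sqrt(n_in)
    return [[rng.gauss(0.0, std) for _ in range(n_in)] for _ in range(n_out)]

def snn_forward(layers, x):
    for w in layers:
        x = [selu(sum(wi * xi for wi, xi in zip(row, x))) for row in w]
    return x

def build_pathway_encoders(pathway_sizes, dim, rng):
    # one independent two-layer SNN per pathway; its input width equals the pathway size
    return {name: [init_layer(n, dim, rng), init_layer(dim, dim, rng)]
            for name, n in pathway_sizes.items()}

def encode_pathways(pathway_inputs, encoders):
    tokens = {name: snn_forward(encoders[name], x) for name, x in pathway_inputs.items()}
    # concatenate pathway embeddings in a fixed (sorted) order for a downstream risk head
    concatenated = [v for name in sorted(tokens) for v in tokens[name]]
    return tokens, concatenated

rng = random.Random(0)
DIM = 4  # hypothetical embedding size
pathway_inputs = {"PATHWAY_A": [2.0, 5.5, 0.3], "PATHWAY_B": [5.5, 1.2, 0.7, 3.1, 0.0]}
encoders = build_pathway_encoders({k: len(v) for k, v in pathway_inputs.items()}, DIM, rng)
tokens, concatenated = encode_pathways(pathway_inputs, encoders)
```

Per-pathway encoders let pathways of different sizes (31 to 199 genes here) map to equally sized tokens. The multimodal models then treat these tokens as their omics input.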

For the patch encoder, we use UNI (Chen et al., 2024), a DINOv2-based ViT-Large (Dosovitskiy et al., 2021; Oquab et al., 2023). UNI was pretrained on $1 \times 10^{8}$ patches sampled across $1 \times 10^{5}$ WSIs from Mass General Brigham. We also ablate with two other encoders. The first is CTransPath (Wang et al., 2022), a Swin Transformer pretrained on $3.2 \times 10^{4}$ WSIs from TCGA. The second is a ResNet50 pretrained on ImageNet (Deng et al., 2009). Further information on all baselines can be found in Appendix E.

4.3 Implementation

All models are trained for 20 epochs with a learning rate of $1 \times 10^{-4}$, a cosine decay scheduler, the AdamW optimizer, and a weight decay of $1 \times 10^{-5}$. MMP uses the Cox loss with a batch size of 64. Non-prototype baselines are trained with the NLL survival loss (Zadeh & Schmid, 2020) with a batch size of 1. When training MCAT, SurvPath, MOTCat, and CMTA, we randomly sample 4,096 patches per WSI to increase diversity and reduce memory. During inference, the whole WSI is used. All prototype baselines use $C_{\text{h.}} = 16$.
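The cosine decay schedule can be written in closed form. The sketch below evaluates it for the 20-epoch setting with a base learning rate of $1 \times 10^{-4}$. A final learning rate of 0 and one scheduler step per epoch are assumptions, not reported settings. The same closed form underlies PyTorch's `CosineAnnealingLR` with `eta_min=0`.

```python
import math

def cosine_lr(step, total_steps, base_lr=1e-4, min_lr=0.0):
    """Learning rate after `step` of `total_steps` under cosine decay.

    min_lr = 0 and per-epoch stepping are assumptions, not reported settings.
    """
    progress = min(step, total_steps) / total_steps
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

schedule = [cosine_lr(epoch, total_steps=20) for epoch in range(20)]
```

The rate starts at $1 \times 10^{-4}$, falls to half that value at epoch 10, and approaches 0 at the end of training.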

5 Results

5.1 Survival prediction

The results are shown in Table 1. Overall, MMP outperforms all baselines and ranks in the top 2 for 5 out of 6 cancer types. On average, it improves by +5.4% over the next-best multimodal model and by +7.8% over the next-best unimodal model. We highlight the main findings below.

Comparison with clinical baseline. All multimodal baselines outperform the clinical baseline. That baseline comprises important prognostic variables: age, sex, and cancer grade (Bonnier et al., 1995; Rakha et al., 2010; Tas et al., 2013). This demonstrates the clinical potential of multimodal frameworks for enhanced patient prognostication. Additional univariate clinical baselines are in Appendix F.
Unimodal vs. Multimodal. In average C-Index, all multimodal baselines except MCAT outperform the unimodal baselines (histology and transcriptomics). This aligns with previous multimodal literature showing that histology and transcriptomics contain complementary information that can be leveraged for better prognostication. However, for CRC, several unimodal histology models (e.g., AttnMISL and ABMIL) outperform all multimodal models. This indicates that challenges remain in the multimodal training dynamics of histology-omics models (Gat et al., 2020; Wang et al., 2020).
Prototypes vs. non-prototypes. MMP substantially outperforms all multimodal approaches that are not based on prototyping (+5.4% avg. over the next-best model, MOTCat). Every multimodal baseline uses early fusion, but MCAT and MOTCat learn only the uni-directional cross-modal interaction from transcriptomics to histology. SurvPath, in turn, omits histology-to-histology interactions in its self-attention computation to reduce computational requirements. We attribute the superior performance of MMP to three factors:

1. it retains and encodes morphological information predictive of prognosis in the morphological prototypes;
2. it models both intra- and cross-modal interactions without approximations;
3. it employs the Cox survival loss.

The unimodal setting reaffirms the quality of the morphological prototypes. There, the prototype-based AttnMISL and the unimodal MMP are the two best-performing histology models, outperforming all other histology-only approaches.
Transformer vs. OT-based cross-attention. MMP$_{\text{Trans.}}$ and MMP$_{\text{OT}}$ perform on par (both 0.665 average C-Index). This empirically supports the connection between both approaches highlighted in Section 3.2.3.
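A small numerical illustration shows how the two fusion variants relate. It is a simplified sketch, not the Section 3.2.3 derivation. The code compares row-wise softmax attention with an entropic OT plan computed by Sinkhorn iterations (Cuturi, 2013) on the same similarity scores. The choices cost = negative similarity, $\varepsilon = 1$, uniform marginals, and the toy tokens are all assumptions. Normalizing the rows of $K = \exp(\text{scores})$ once gives exactly softmax attention, scaled by $1/n$. Sinkhorn then keeps alternating row and column normalizations, so every pathway also receives equal total mass.

```python
import math

def softmax_rows(scores):
    out = []
    for row in scores:
        m = max(row)
        e = [math.exp(s - m) for s in row]
        z = sum(e)
        out.append([x / z for x in e])
    return out

def sinkhorn(cost, eps=1.0, n_iters=200):
    """Entropic OT plan between uniform marginals via Sinkhorn iterations."""
    n, m = len(cost), len(cost[0])
    K = [[math.exp(-c / eps) for c in row] for row in cost]
    u, v = [1.0] * n, [1.0] * m
    for _ in range(n_iters):
        u = [(1.0 / n) / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]  # fix row sums
        v = [(1.0 / m) / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]  # fix column sums
    return [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]

# Toy tokens (hypothetical): 3 histology prototypes, 4 pathways, dimension 2
hist = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
path = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [-1.0, 0.0]]
scores = [[sum(a * b for a, b in zip(h, g)) / math.sqrt(2) for g in path] for h in hist]

attn = softmax_rows(scores)                             # Transformer-style cross-attention
plan = sinkhorn([[-s for s in row] for row in scores])  # cost = negative similarity
```

Both outputs are transport-like weightings over the same scores. Attention constrains only the rows, whereas the OT plan constrains rows and columns jointly.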

5.2 Risk stratification

We perform log-rank tests (Bland & Altman, 2004) between high-risk and low-risk cohorts, for MMP and for MOTCat, the next-best performing model. For each model, the cohorts are split at the median (50th percentile) of its predicted risks. Specifically, we aggregate the predicted risks across all test folds to construct the cohort-level risk set. Table 2 shows the p-values of the log-rank tests for all 6 cancer types. At a significance threshold of 0.05, MMP significantly stratifies the high- and low-risk groups for all 6 cancer types. MOTCat does so for only 3 (BRCA, BLCA, and KIRC). This demonstrates the strength of MMP for risk stratification over MOTCat and reaffirms its clinical potential.
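The stratification procedure can be sketched in pure Python: a median split of pooled risks, followed by a two-group log-rank test. The test statistic is compared against a chi-square distribution with one degree of freedom. This sketch is not the analysis code used for Table 2, and the toy cohort is hypothetical.

```python
import math
from statistics import median

def median_split(risks):
    """1 = high-risk (above the median predicted risk), 0 = low-risk."""
    cut = median(risks)
    return [1 if r > cut else 0 for r in risks]

def logrank_p_value(times, events, group):
    """Two-group log-rank test; returns the p-value of the 1-df chi-square statistic."""
    o_minus_e, variance = 0.0, 0.0
    for t in sorted({ti for ti, e in zip(times, events) if e}):  # distinct event times
        at_risk = [i for i, ti in enumerate(times) if ti >= t]
        n = len(at_risk)
        n1 = sum(group[i] for i in at_risk)                          # high-risk patients at risk
        d = sum(1 for i in at_risk if times[i] == t and events[i])   # events at time t
        d1 = sum(1 for i in at_risk if times[i] == t and events[i] and group[i])
        o_minus_e += d1 - d * n1 / n                                 # observed minus expected
        if n > 1:
            variance += d * (n1 / n) * (1 - n1 / n) * (n - d) / (n - 1)
    chi2 = o_minus_e ** 2 / variance
    return math.erfc(math.sqrt(chi2 / 2.0))  # survival function of chi-square(1)

# Toy cohort (hypothetical): risks pooled across all test folds
times = [2, 3, 4, 5, 10, 12, 14, 16]
events = [1, 1, 1, 1, 1, 1, 0, 1]
risks = [0.9, 0.8, 0.85, 0.7, 0.2, 0.3, 0.1, 0.25]
group = median_split(risks)
p = logrank_p_value(times, events, group)
```

Here every predicted high-risk patient has an event before any low-risk patient does. The resulting p-value is below the 0.05 threshold.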

Table 2: Risk stratification. We report log-rank p-values for high- and low-risk patient cohorts for MMP and MOTCat. p-values below 0.05 are considered statistically significant.
| Model | BRCA | BLCA | LUAD | STAD | CRC | KIRC |
|---|---|---|---|---|---|---|
| MOTCat | $6.16 \times 10^{-5}$ | $8.60 \times 10^{-5}$ | $9.65 \times 10^{-1}$ | $5.70 \times 10^{-2}$ | $7.04 \times 10^{-1}$ | $2.40 \times 10^{-4}$ |
| MMP$_{\text{Trans.}}$ | $3.08 \times 10^{-5}$ | $4.50 \times 10^{-2}$ | $8.37 \times 10^{-5}$ | $1.40 \times 10^{-2}$ | $3.60 \times 10^{-2}$ | $2.59 \times 10^{-8}$ |

5.3 Ablation study

We perform extensive ablations of MMP (Table 3) and summarize our findings below.

1. Number of morphological prototypes: A larger number of morphological prototypes ($C_{\text{h.}}$) generally yields better performance up to 16, after which performance plateaus. Fewer prototypes are also easier to interpret, so we set $C_{\text{h.}} = 16$.
2. Feature encoder: UNI, a DINOv2-pretrained ViT-L encoder, substantially outperforms CTransPath and an ImageNet-pretrained ResNet50. This underscores the importance of a powerful vision encoder trained on large histology datasets.
3. Histology aggregation: GMM-based aggregation outperforms optimal transport (OT) and hard clustering (HC). We hypothesize that the GMM explicitly captures sufficient statistics of the patch embedding distribution, e.g., mixture probabilities and covariances. The other approaches cannot readily integrate these statistics.
4. Prototype encoding: Adding the prototype encoding $\mathbf{e}_c$ and the prototype-specific FFN $f_c^{\text{post}}$ improves performance. This suggests a benefit from prototype-specific parameters that leverage the consistent prototype identity across patients.
5. Fusion stage: Early fusion of tokens via cross-attention (MMP) outperforms late fusion. Late fusion concatenates the self-attended embeddings averaged within each modality, so it captures no cross-modal interactions.

We also perform unimodal MMP ablations in Appendix G, to isolate the impact of histology-related design choices.

Table 3: Ablation study. C-Index averaged across the six cohorts when a single model component is modified, and its relative change with respect to the full MMP model.
| Ablation | Component | Modification | Avg. C-Index |
|---|---|---|---|
| Full model | | MMP | 0.665 |
| Number of histo. proto. | $C_{\text{h.}}$ | 16 ⇒ 8 | 0.655 (−1.5%) |
| | | 16 ⇒ 32 | 0.662 (−0.5%) |
| Histo. enc. | $f_{\text{enc}}$ | UNI ⇒ ResNet50 | 0.620 (−6.8%) |
| | | UNI ⇒ CTransPath | 0.643 (−3.3%) |
| Histo. agg. | $\phi_c^{\text{h.}}$ | GMM ⇒ OT | 0.658 (−1.1%) |
| | | GMM ⇒ HC | 0.629 (−5.4%) |
| Proto. embed. | $\mathbf{e}_c$ | random ⇒ None | 0.652 (−2.0%) |
| | | random ⇒ One-hot | 0.660 (−0.8%) |
| FFN | $f_c^{\text{post}}$ | Indiv. ⇒ Shared | 0.658 (−1.1%) |
| Co-attention | | Trans. ⇒ OT | 0.665 (−0.0%) |
| Fusion | | Early ⇒ Late | 0.646 (−2.9%) |

5.4 Loss function

We evaluate MMP trained with either the Cox or the NLL loss (Table 9 in Appendix H). The NLL loss has been widely used in multimodal frameworks because it accommodates batches of a single patient. This is a necessity when each patient has many tokens. In contrast, the Cox loss requires batches with more than one patient. Its partial likelihood orders the patients in a batch by survival time and compares each patient's risk against those of the patients still at risk. The reduced computational requirements of MMP make training with either loss feasible. We observe that the Cox loss surpasses the NLL loss overall (average C-Index of 0.665 vs. 0.644). Furthermore, increasing the batch size (bs) with NLL improves performance (0.621 with bs=1 vs. 0.644 with bs=16), emphasizing the benefit of the reduced number of tokens.
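The following sketch of the Cox negative log partial likelihood for one mini-batch shows why a single-patient batch is uninformative. It is an illustrative version, not the training code. Breslow-style tie handling and averaging over events are assumptions.

```python
import math

def cox_loss(risks, times, events):
    """Mean negative Cox log partial likelihood over the events in one batch.

    Ties are handled Breslow-style (tied patients stay in each other's risk sets);
    averaging over events is an assumption.
    """
    total, n_events = 0.0, 0
    for i, (t_i, e_i) in enumerate(zip(times, events)):
        if not e_i:
            continue  # censored patients only appear inside other patients' risk sets
        risk_set = [risks[j] for j, t_j in enumerate(times) if t_j >= t_i]
        m = max(risk_set)
        log_denominator = m + math.log(sum(math.exp(r - m) for r in risk_set))
        total -= risks[i] - log_denominator
        n_events += 1
    return total / max(n_events, 1)

well_ordered = cox_loss([2.0, 1.0, 0.0], [1, 2, 3], [1, 1, 1])   # early deaths get high risk
badly_ordered = cox_loss([0.0, 1.0, 2.0], [1, 2, 3], [1, 1, 1])
single_patient = cox_loss([0.3], [5], [1])
```

With a single patient, the risk set contains only that patient. The loss is then identically zero and provides no gradient, so the Cox loss needs batches of several patients.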

5.5 Computational complexity

To assess the computational benefits of MMP, we measure the number of floating-point operations (FLOPs) for the cross-attention baselines (Table 4). MMP requires at least 5× fewer GFLOPs, demonstrating the superior efficiency of prototyping. The aggregation step (MMP$_{\text{agg.}}$), which maps $N_{\text{h.}}$ tokens to $C_{\text{h.}}$ prototypes, accounts for most of the MMP operations. The fusion step (MMP$_{\text{fusion}}$) requires far fewer operations because of the condensed token set.
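A quick consistency check on Table 4, using only the reported numbers, is sketched below. From LUAD to KIRC, the aggregation cost grows by the same factor as the number of patch tokens. The fusion cost stays fixed because it always operates on 16 histology tokens. Reading the matching ratios as linear scaling in $N_{\text{h.}}$ is our interpretation of the table.

```python
# Values copied from Table 4
tokens = {"LUAD": 4714, "KIRC": 12802}
mmp_agg = {"LUAD": 0.309, "KIRC": 0.839}
mmp_total = {"LUAD": 0.334, "KIRC": 0.864}
baselines = {  # cross-attention baselines, GFLOPs per patient
    "MCAT": {"LUAD": 2.10, "KIRC": 5.49},
    "SurvPath": {"LUAD": 2.00, "KIRC": 5.41},
    "CMTA": {"LUAD": 17.2, "KIRC": 40.1},
}

token_ratio = tokens["KIRC"] / tokens["LUAD"]          # about 2.72
agg_ratio = mmp_agg["KIRC"] / mmp_agg["LUAD"]          # about 2.72: aggregation tracks token count
compression = {k: n / 16 for k, n in tokens.items()}   # patch tokens per histology prototype token
speedups = [g[c] / mmp_total[c] for g in baselines.values() for c in ("LUAD", "KIRC")]
min_speedup = min(speedups)                            # smallest GFLOPs ratio (SurvPath on LUAD)
```

The smallest ratio is about 6×, which is consistent with the "at least 5×" statement. The KIRC token count corresponds to roughly 800× compression.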

Table 4: Computational complexity. Average number of tokens per WSI and average number of giga-FLOPs per patient.
| Model | LUAD tokens | LUAD GFLOPs (↓) | KIRC tokens | KIRC GFLOPs (↓) |
|---|---|---|---|---|
| MCAT | 4,714 | 2.10 | 12,802 | 5.49 |
| SurvPath | 4,714 | 2.00 | 12,802 | 5.41 |
| CMTA | 4,714 | 17.2 | 12,802 | 40.1 |
| MMP$_{\text{agg.}}$ | 4,714 | 0.309 | 12,802 | 0.839 |
| MMP$_{\text{fusion}}$ | 16 | 0.025 | 16 | 0.025 |
| MMP$_{\text{total}}$ | · | 0.334 | · | 0.864 |
Figure 2: Cross-modal interaction visualization. (A) A WSI for a BRCA patient. (B) The morphological prototype heatmap for $c = 13$ (C13), representing invasive ductal carcinoma (IDC), based on the posterior distribution for C13. (C) Prototype assignment map showing the closest morphological prototype for each patch in the WSI. (D) Top-3 patches for each morphological prototype and proportion of each prototype in the WSI. (E) The top-10 pathways attending to C13. (F) Top-6 morphological prototypes attending to the pathways in (E).

6 Interpretability

Unimodal: MMP represents each WSI with a compact set of 16 prototypes. For each prototype $c$, we can therefore directly visualize a heatmap of the patches most similar to it, using the posterior $p(c_i = c \mid \mathbf{z}_{i,\text{h.}})$ (Fig. 2A, B; additional examples in Appendix I). The prototype assignment map shows how all the prototypes are distributed across a given WSI (Fig. 2C). For a prototype $c$, we can also visualize its most representative patches by querying the patch embeddings closest to $\widehat{\boldsymbol{\mu}}_c$. Its proportion in the WSI is given by $\widehat{\pi}_c$ (Fig. 2D).
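The three unimodal views can be sketched from fitted per-WSI GMM parameters: the posterior heatmap, the assignment map, and the most representative patches. This is a minimal sketch, not the MMP code. It assumes all components share one isotropic variance, so the Gaussian normalizing constants cancel. The toy patches and parameters are hypothetical.

```python
import math

def gmm_posterior(z, pis, mus, var):
    """p(c_i = c | z_i) for a GMM whose components share one isotropic variance."""
    logits = [math.log(p) - sum((a - b) ** 2 for a, b in zip(z, mu)) / (2.0 * var)
              for p, mu in zip(pis, mus)]
    m = max(logits)
    w = [math.exp(l - m) for l in logits]
    s = sum(w)
    return [x / s for x in w]

def prototype_views(patches, pis, mus, var, c, k=3):
    """pis, mus: per-WSI estimates (pi_hat_c, mu_hat_c); c: prototype to inspect."""
    post = [gmm_posterior(z, pis, mus, var) for z in patches]
    heatmap = [p[c] for p in post]                                       # per-patch score for c
    assignment = [max(range(len(p)), key=p.__getitem__) for p in post]   # most probable prototype
    dist = [sum((a - b) ** 2 for a, b in zip(z, mus[c])) for z in patches]
    top_patches = sorted(range(len(patches)), key=dist.__getitem__)[:k]  # closest to mu_hat_c
    return heatmap, assignment, top_patches, pis[c]                      # pis[c]: proportion

# Toy WSI (hypothetical): 4 patch embeddings, 2 prototypes
mus = [[0.0, 0.0], [5.0, 5.0]]
pis = [0.5, 0.5]
patches = [[0.1, 0.0], [4.9, 5.2], [0.2, -0.1], [5.0, 5.0]]
heatmap, assignment, top_patches, proportion = prototype_views(patches, pis, mus, var=1.0, c=1, k=1)
```

The heatmap and the assignment map come from the same posterior. The representative patches depend only on the distance to $\widehat{\boldsymbol{\mu}}_c$.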

Multimodal: With a tractable number of histology tokens, we can visualize cross-modal interactions based on cross-attention scores in both directions: from histology to pathways (h. → g.) and from pathways to histology (g. → h.). MCAT and MOTCat only model and visualize g. → h. interactions, i.e., which patches correspond to a queried pathway (histology importance). MMP can additionally visualize h. → g. interactions, i.e., which pathways correspond to a queried prototype (pathway importance). SurvPath also models h. → g. via cross-attention. However, its histology tokens are thousands of redundant, non-prototypical patches, which makes visualizing h. → g. intractable.

As an example, consider prototype $c = 13$ (C13), which represents the dominant invasive ductal carcinoma morphology in the BRCA WSI. Its most highly-attended pathways (h. → g.) include bile acid metabolism, fatty acid metabolism, and cholesterol homeostasis, which are important oncogenic pathways in BRCA (Fig. 2E). This agrees with the literature highlighting the association between these pathways and breast cancer prognosis (Nelson et al., 2014; Koundouros & Poulogiannis, 2020; Režen et al., 2022). We can also visualize the most highly-attended morphological prototypes for these pathways (g. → h.). C13 is highly attended by bile acid metabolism (Fig. 2F). Other IDC variations, such as C1 and C8, are also highly attended by these pathways. Because MMP visualizes interactions in both directions, it can reveal tightly-linked relationships with strong cross-attention values in both directions (e.g., C13 and bile acid metabolism). Other methods have only visualized g. → h. (as in Fig. 2F) but not h. → g. (as in Fig. 2E), so this capability is unique to MMP. Further discussion and visualizations are available in Appendix I.
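Both interaction directions can be read off a single matrix of prototype-pathway scores by normalizing it once per row and once per column. The sketch below is simplified and is not the MMP implementation. It omits the learned query/key projections and derives both directions from one score matrix; both are assumptions. The prototype and pathway labels are hypothetical.

```python
import math

def softmax(row):
    m = max(row)
    e = [math.exp(x - m) for x in row]
    s = sum(e)
    return [x / s for x in e]

def bidirectional_attention(hist_tokens, path_tokens):
    """Scaled dot-product scores between prototypes and pathways, normalised both ways."""
    scale = 1.0 / math.sqrt(len(hist_tokens[0]))
    scores = [[scale * sum(a * b for a, b in zip(h, g)) for g in path_tokens] for h in hist_tokens]
    h_to_g = [softmax(row) for row in scores]              # prototype -> pathways (pathway importance)
    g_to_h = [softmax(list(col)) for col in zip(*scores)]  # pathway -> prototypes (histology importance)
    return h_to_g, g_to_h

def top_k(weights, names, k):
    order = sorted(range(len(weights)), key=lambda i: weights[i], reverse=True)
    return [names[i] for i in order[:k]]

proto_names = ["C0", "C1"]             # hypothetical prototype labels
path_names = ["PW_A", "PW_B", "PW_C"]  # hypothetical pathway labels
hist = [[1.0, 0.0], [0.0, 1.0]]
path = [[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]]
h_to_g, g_to_h = bidirectional_attention(hist, path)
```

In this toy example, C0's top pathway is PW_A and PW_A's top prototype is C0. This is the kind of mutually strong link described for C13 and bile acid metabolism.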

7 Conclusion & Future works

We introduced MMP, a prototype-based multimodal fusion framework for survival prediction in computational pathology. This framework introduces a prototype-based tokenization method that effectively reduces the number of tokens and the associated computational complexity common in multimodal fusion frameworks. Such reduction leads to improved overall prognostic performance and allows a bi-directional concept-based interpretation of how morphology and transcriptomes interact.

We consider this an essential step forward for future multimodal prognosis research, and we believe it can be extended and validated in several ways (further detailed in Appendix J). First, the number of prototypes could be determined in a data-driven manner, e.g., with Dirichlet process frameworks (Lee et al., 2020; Li et al., 2022). Next, transcriptomics modeling could leverage the latest advances in single-cell foundation models (Rosen et al., 2023; Theodoris et al., 2023; Cui et al., 2024) instead of a shallow MLP. Finally, two further steps will bring MMP closer to clinical translation. One is validation on different outcomes, such as progression-free interval and recurrence risk (Liu et al., 2018). The other is application to rare diseases, where avoiding overfitting to small cohorts is paramount.

Acknowledgements

The authors would like to thank Minyoung Kim at Samsung AI Center for helpful advice on prototype-based aggregation, and Ming Y. Lu and Tong Ding for setting up the supervised MIL benchmarks. The authors acknowledge funding support from the Brigham and Women's Hospital (BWH) President's Fund, Mass General Hospital (MGH) Pathology, and the National Institutes of Health (NIH) National Institute of General Medical Sciences (NIGMS) through R35GM138216.

Impact Statement

This manuscript details efforts to enhance cancer prognosis by integrating whole-slide imaging and gene expression profiling. Our study uses data from The Cancer Genome Atlas public database, which has been ethically and institutionally approved by all contributing sites. Our research may have societal impacts. We emphasize, however, that this study is designed solely for research applications and is not yet intended for clinical use. Validation on larger external cohorts will be required to realize the clinical potential of our work.

References

  • Acosta et al. (2022) Acosta, J. N., Falcone, G. J., Rajpurkar, P., and Topol, E. J. Multimodal biomedical AI. Nature Medicine, 28(9):1773–1784, 2022.
  • Benamou (2003) Benamou, J.-D. Numerical resolution of an “unbalanced” mass transport problem. ESAIM: Mathematical Modelling and Numerical Analysis, 37(5):851–868, 2003.
  • Bland & Altman (2004) Bland, J. M. and Altman, D. G. The logrank test. BMJ, 328(7447):1073, 2004.
  • Bonnier et al. (1995) Bonnier, P., Romain, S., Charpin, C., Lejeune, C., Tubiana, N., Martin, P.-M., and Piana, L. Age as a prognostic factor in breast cancer: relationship to pathologic and biologic features. International Journal of Cancer, 62(2):138–144, 1995.
  • Campanella et al. (2019) Campanella, G., Hanna, M. G., Geneslaw, L., Miraflor, A., Werneck Krauss Silva, V., Busam, K. J., Brogi, E., Reuter, V. E., Klimstra, D. S., and Fuchs, T. J. Clinical-grade computational pathology using weakly supervised deep learning on whole slide images. Nature Medicine, 25(8):1301–1309, 2019.
  • Cao et al. (2022) Cao, Z., Xu, Q., Yang, Z., He, Y., Cao, X., and Huang, Q. OTKGE: Multi-modal knowledge graph embeddings via optimal transport. Advances in Neural Information Processing Systems, 35:39090–39102, 2022.
  • Carmichael et al. (2022) Carmichael, I., Song, A. H., Chen, R. J., Williamson, D. F., Chen, T. Y., and Mahmood, F. Incorporating intratumoral heterogeneity into weakly-supervised deep learning models via variance pooling. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pp.  387–397. Springer, 2022.
  • Chen et al. (2020a) Chen, L., Gan, Z., Cheng, Y., Li, L., Carin, L., and Liu, J. Graph optimal transport for cross-domain alignment. In International Conference on Machine Learning, pp.  1542–1553. PMLR, 2020a.
  • Chen et al. (2020b) Chen, R. J., Lu, M. Y., Wang, J., Williamson, D. F., Rodig, S. J., Lindeman, N. I., and Mahmood, F. Pathomic fusion: an integrated framework for fusing histopathology and genomic features for cancer diagnosis and prognosis. IEEE Transactions on Medical Imaging, 41(4):757–770, 2020b.
  • Chen et al. (2021) Chen, R. J., Lu, M. Y., Weng, W.-H., Chen, T. Y., Williamson, D. F., Manz, T., Shady, M., and Mahmood, F. Multimodal co-attention transformer for survival prediction in gigapixel whole slide images. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pp.  4015–4025, 2021.
  • Chen et al. (2022) Chen, R. J., Lu, M. Y., Williamson, D. F., Chen, T. Y., Lipkova, J., Noor, Z., Shaban, M., Shady, M., Williams, M., Joo, B., et al. Pan-cancer integrative histology-genomic analysis via multimodal deep learning. Cancer Cell, 40(8):865–878, 2022.
  • Chen et al. (2024) Chen, R. J., Ding, T., Lu, M. Y., Williamson, D. F. K., Jaume, G., Song, A. H., Chen, B., Zhang, A., Shao, D., Shaban, M., Williams, M., Oldenburg, L., Weishaupt, L. L., Wang, J. J., Vaidya, A., Le, L. P., Gerber, G., Sahai, S., Williams, W., and Mahmood, F. Towards a general-purpose foundation model for computational pathology. Nature Medicine, 2024.
  • Chizat et al. (2018) Chizat, L., Peyré, G., Schmitzer, B., and Vialard, F.-X. Scaling algorithms for unbalanced optimal transport problems. Mathematics of Computation, 87(314):2563–2609, 2018.
  • Cox (1972) Cox, D. R. Regression models and life-tables. Journal of the Royal Statistical Society: Series B (Methodological), 34(2):187–202, 1972.
  • Cui et al. (2024) Cui, H., Wang, C., Maan, H., Pang, K., Luo, F., Duan, N., and Wang, B. scGPT: toward building a foundation model for single-cell multi-omics using generative AI. Nature Methods, pp.  1–11, 2024.
  • Cuturi (2013) Cuturi, M. Sinkhorn distances: Lightspeed computation of optimal transport. Advances in Neural Information Processing Systems, 26, 2013.
  • dan Guo et al. (2022) dan Guo, D., Tian, L., Zhang, M., Zhou, M., and Zha, H. Learning prototype-oriented set representations for meta-learning. In International Conference on Learning Representations, 2022.
  • Dempster et al. (1977) Dempster, A. P., Laird, N. M., and Rubin, D. B. Maximum likelihood from incomplete data via the EM algorithm. Journal of the Royal Statistical Society: Series B (Methodological), 39(1):1–22, 1977.
  • Deng et al. (2009) Deng, J., Dong, W., Socher, R., Li, L.-J., Li, K., and Fei-Fei, L. ImageNet: A large-scale hierarchical image database. In 2009 IEEE Conference on Computer Vision and Pattern Recognition, pp.  248–255. IEEE, 2009.
  • Ding et al. (2023) Ding, K., Zhou, M., Metaxas, D. N., and Zhang, S. Pathology-and-genomics multimodal transformer for survival outcome prediction. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pp.  622–631. Springer, 2023.
  • Dongre & Weinberg (2019) Dongre, A. and Weinberg, R. A. New insights into the mechanisms of epithelial–mesenchymal transition and implications for cancer. Nature Reviews Molecular Cell Biology, 20(2):69–84, 2019.
  • Dosovitskiy et al. (2021) Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., Uszkoreit, J., and Houlsby, N. An image is worth 16x16 words: Transformers for image recognition at scale. In International Conference on Learning Representations, 2021.
  • Duan et al. (2022) Duan, J., Chen, L., Tran, S., Yang, J., Xu, Y., Zeng, B., and Chilimbi, T. Multi-modal alignment using representation codebook. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp.  15651–15660, 2022.
  • Elmarakeby et al. (2021) Elmarakeby, H. A., Hwang, J., Arafeh, R., Crowdis, J., Gang, S., Liu, D., AlDubayan, S. H., Salari, K., Kregel, S., Richter, C., et al. Biologically informed deep neural network for prostate cancer discovery. Nature, 598(7880):348–352, 2021.
  • Gat et al. (2020) Gat, I., Schwartz, I., Schwing, A., and Hazan, T. Removing bias in multi-modal classifiers: Regularization by maximizing functional entropies. Advances in Neural Information Processing Systems, 33:3197–3208, 2020.
  • Genevay et al. (2018) Genevay, A., Peyre, G., and Cuturi, M. Learning generative models with sinkhorn divergences. In Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics, volume 84 of Proceedings of Machine Learning Research, pp.  1608–1617. PMLR, 09–11 Apr 2018.
  • Girdhar et al. (2022) Girdhar, R., Singh, M., Ravi, N., van der Maaten, L., Joulin, A., and Misra, I. Omnivore: A single model for many visual modalities. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp.  16102–16112, 2022.
  • Giudetti et al. (2019) Giudetti, A. M., De Domenico, S., Ragusa, A., Lunetti, P., Gaballo, A., Franck, J., Simeone, P., Nicolardi, G., De Nuccio, F., Santino, A., et al. A specific lipid metabolic profile is associated with the epithelial mesenchymal transition program. Biochimica et Biophysica Acta (BBA)-Molecular and Cell Biology of Lipids, 1864(3):344–357, 2019.
  • Goldman et al. (2020) Goldman, M. J., Craft, B., Hastie, M., Repečka, K., McDade, F., Kamath, A., Banerjee, A., Luo, Y., Rogers, D., Brooks, A. N., et al. Visualizing and interpreting cancer genomics data via the Xena platform. Nature Biotechnology, 38(6):675–678, 2020.
  • Harrell et al. (1982) Harrell, F. E., Califf, R. M., Pryor, D. B., Lee, K. L., and Rosati, R. A. Evaluating the yield of medical tests. JAMA, 247(18):2543–2546, 1982.
  • Howard et al. (2021) Howard, F. M., Dolezal, J., Kochanny, S., Schulte, J., Chen, H., Heij, L., Huo, D., Nanda, R., Olopade, O. I., Kather, J. N., et al. The impact of site-specific digital histology signatures on deep learning model accuracy and bias. Nature Communications, 12(1):4423, 2021.
  • Ilse et al. (2018) Ilse, M., Tomczak, J., and Welling, M. Attention-based deep multiple instance learning. In International Conference on Machine Learning, pp.  2127–2136. PMLR, 2018.
  • Ishay-Ronen et al. (2019) Ishay-Ronen, D., Diepenbruck, M., Kalathur, R. K. R., Sugiyama, N., Tiede, S., Ivanek, R., Bantug, G., Morini, M. F., Wang, J., Hess, C., et al. Gain fat—lose metastasis: converting invasive breast cancer cells into adipocytes inhibits cancer metastasis. Cancer Cell, 35(1):17–32, 2019.
  • Jaegle et al. (2022) Jaegle, A., Borgeaud, S., Alayrac, J.-B., Doersch, C., Ionescu, C., Ding, D., Koppula, S., Zoran, D., Brock, A., Shelhamer, E., Henaff, O. J., Botvinick, M., Zisserman, A., Vinyals, O., and Carreira, J. Perceiver IO: A general architecture for structured inputs & outputs. In International Conference on Learning Representations, 2022.
  • Jaume et al. (2024) Jaume, G., Vaidya, A., Chen, R., Williamson, D., Liang, P., and Mahmood, F. Modeling dense multimodal interactions between biological pathways and histology for survival prediction. Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2024.
  • Katzman et al. (2018) Katzman, J. L., Shaham, U., Cloninger, A., Bates, J., Jiang, T., and Kluger, Y. DeepSurv: personalized treatment recommender system using a Cox proportional hazards deep neural network. BMC medical research methodology, 18(1):1–12, 2018.
  • Kim (2022) Kim, M. Differentiable expectation-maximization for set representation learning. In International Conference on Learning Representations, 2022.
  • Klambauer et al. (2017) Klambauer, G., Unterthiner, T., Mayr, A., and Hochreiter, S. Self-normalizing neural networks. In Advances in Neural Information Processing Systems, volume 30. Curran Associates, Inc., 2017.
  • Kolouri et al. (2017) Kolouri, S., Park, S. R., Thorpe, M., Slepcev, D., and Rohde, G. K. Optimal mass transport: Signal processing and machine-learning applications. IEEE Signal Processing Magazine, 34(4):43–59, 2017. doi: 10.1109/MSP.2017.2695801.
  • Koundouros & Poulogiannis (2020) Koundouros, N. and Poulogiannis, G. Reprogramming of fatty acid metabolism in cancer. British journal of cancer, 122(1):4–22, 2020.
  • Kvamme et al. (2019) Kvamme, H., Borgan, Ø., and Scheel, I. Time-to-event prediction with neural networks and Cox regression. arXiv preprint arXiv:1907.00825, 2019.
  • Lee et al. (2024) Lee, D. B., Lee, S., Ko, J., Kawaguchi, K., Lee, J., and Hwang, S. J. Self-supervised dataset distillation for transfer learning. In The Twelfth International Conference on Learning Representations, 2024.
  • Lee et al. (2019) Lee, J., Lee, Y., Kim, J., Kosiorek, A., Choi, S., and Teh, Y. W. Set transformer: A framework for attention-based permutation-invariant neural networks. In International conference on machine learning, pp.  3744–3753. PMLR, 2019.
  • Lee et al. (2020) Lee, S., Ha, J., Zhang, D., and Kim, G. A neural Dirichlet process mixture model for task-free continual learning. In International Conference on Learning Representations, 2020.
  • Li & Dewey (2011) Li, B. and Dewey, C. N. RSEM: accurate transcript quantification from RNA-Seq data with or without a reference genome. BMC bioinformatics, 12:1–16, 2011.
  • Li et al. (2023) Li, H., Zhu, C., Zhang, Y., Sun, Y., Shui, Z., Kuang, W., Zheng, S., and Yang, L. Task-specific fine-tuning via variational information bottleneck for weakly-supervised pathology whole slide image classification. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp.  7454–7463, 2023.
  • Li et al. (2022) Li, N., Li, W., Jiang, Y., and Xia, S.-T. Deep Dirichlet process mixture models. In Uncertainty in Artificial Intelligence, pp.  1138–1147. PMLR, 2022.
  • Liang et al. (2023) Liang, P. P., Lyu, Y., Fan, X., Tsaw, J., Liu, Y., Mo, S., Yogatama, D., Morency, L.-P., and Salakhutdinov, R. High-modality multimodal transformer: Quantifying modality & interaction heterogeneity for high-modality representation learning. Transactions on Machine Learning Research, 2023. ISSN 2835-8856.
  • Liberzon et al. (2015) Liberzon, A., Birger, C., Thorvaldsdóttir, H., Ghandi, M., Mesirov, J. P., and Tamayo, P. The molecular signatures database hallmark gene set collection. Cell systems, 1(6):417–425, 2015.
  • Lipkova et al. (2022) Lipkova, J., Chen, R. J., Chen, B., Lu, M. Y., Barbieri, M., Shao, D., Vaidya, A. J., Chen, C., Zhuang, L., Williamson, D. F., et al. Artificial intelligence for multimodal data integration in oncology. Cancer cell, 40(10):1095–1110, 2022.
  • Liu et al. (2018) Liu, J., Lichtenberg, T., Hoadley, K. A., Poisson, L. M., Lazar, A. J., Cherniack, A. D., Kovatich, A. J., Benz, C. C., Levine, D. A., Lee, A. V., et al. An integrated TCGA pan-cancer clinical data resource to drive high-quality survival outcome analytics. Cell, 173(2):400–416, 2018.
  • Loo et al. (2021) Loo, S. Y., Toh, L. P., Xie, W. H., Pathak, E., Tan, W., Ma, S., Lee, M. Y., Shatishwaran, S., Yeo, J. Z. Z., Yuan, J., et al. Fatty acid oxidation is a druggable gateway regulating cellular plasticity for driving metastasis in breast cancer. Science Advances, 7(41):eabh2443, 2021.
  • Lu et al. (2021) Lu, M. Y., Williamson, D. F., Chen, T. Y., Chen, R. J., Barbieri, M., and Mahmood, F. Data-efficient and weakly supervised computational pathology on whole-slide images. Nature biomedical engineering, 5(6):555–570, 2021.
  • Mialon et al. (2021) Mialon, G., Chen, D., d’Aspremont, A., and Mairal, J. A trainable optimal transport embedding for feature aggregation and its relationship to attention. In International Conference on Learning Representations, 2021.
  • Mobadersany et al. (2018) Mobadersany, P., Yousefi, S., Amgad, M., Gutman, D. A., Barnholtz-Sloan, J. S., Velázquez Vega, J. E., Brat, D. J., and Cooper, L. A. Predicting cancer outcomes from histology and genomics using convolutional networks. Proceedings of the National Academy of Sciences, 115(13):E2970–E2979, 2018.
  • Nelson et al. (2014) Nelson, E. R., Chang, C.-y., and McDonnell, D. P. Cholesterol and breast cancer pathophysiology. Trends in Endocrinology & Metabolism, 25(12):649–655, 2014.
  • Olea-Flores et al. (2018) Olea-Flores, M., Juárez-Cruz, J. C., Mendoza-Catalán, M. A., Padilla-Benavides, T., and Navarro-Tito, N. Signaling pathways induced by leptin during epithelial–mesenchymal transition in breast cancer. International journal of molecular sciences, 19(11):3493, 2018.
  • Olea-Flores et al. (2020) Olea-Flores, M., Juárez-Cruz, J. C., Zuñiga-Eulogio, M. D., Acosta, E., García-Rodríguez, E., Zacapala-Gomez, A. E., Mendoza-Catalán, M. A., Ortiz-Ortiz, J., Ortuño-Pineda, C., and Navarro-Tito, N. New actors driving the epithelial–mesenchymal transition in cancer: The role of leptin. Biomolecules, 10(12):1676, 2020.
  • Oquab et al. (2023) Oquab, M., Darcet, T., Moutakanni, T., Vo, H., Szafraniec, M., Khalidov, V., Fernandez, P., Haziza, D., Massa, F., El-Nouby, A., et al. DINOv2: Learning robust visual features without supervision. arXiv preprint arXiv:2304.07193, 2023.
  • Pölsterl (2020) Pölsterl, S. scikit-survival: A library for time-to-event analysis built on top of scikit-learn. Journal of Machine Learning Research, 21(212):1–6, 2020.
  • Pramanick et al. (2022) Pramanick, S., Roy, A., and Patel, V. M. Multimodal learning using optimal transport for sarcasm and humor detection. In Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision, pp.  3930–3940, 2022.
  • Pramanick et al. (2023) Pramanick, S., Jing, L., Nag, S., Zhu, J., Shah, H. J., LeCun, Y., and Chellappa, R. VoLTA: Vision-language transformer with weakly-supervised local-feature alignment. Transactions on Machine Learning Research, 2023. ISSN 2835-8856.
  • Quiros et al. (2023) Quiros, A. C., Coudray, N., Yeaton, A., Yang, X., Liu, B., Le, H., Chiriboga, L., Karimkhan, A., Narula, N., Moore, D. A., Park, C. Y., Pass, H., Moreira, A. L., Quesne, J. L., Tsirigos, A., and Yuan, K. Mapping the landscape of histomorphological cancer phenotypes using self-supervised learning on unlabeled, unannotated pathology slides, 2023.
  • Rakha et al. (2010) Rakha, E. A., Reis-Filho, J. S., Baehner, F., Dabbs, D. J., Decker, T., Eusebi, V., Fox, S. B., Ichihara, S., Jacquemier, J., Lakhani, S. R., et al. Breast cancer prognostic classification in the molecular era: the role of histological grade. Breast cancer research, 12:1–12, 2010.
  • Reimand et al. (2019) Reimand, J., Isserlin, R., Voisin, V., Kucera, M., Tannus-Lopes, C., Rostamianfar, A., Wadi, L., Meyer, M., Wong, J., Xu, C., et al. Pathway enrichment analysis and visualization of omics data using g:Profiler, GSEA, Cytoscape and EnrichmentMap. Nature protocols, 14(2):482–517, 2019.
  • Režen et al. (2022) Režen, T., Rozman, D., Kovács, T., Kovács, P., Sipos, A., Bai, P., and Mikó, E. The role of bile acids in carcinogenesis. Cellular and molecular life sciences, 79(5):243, 2022.
  • Rosen et al. (2023) Rosen, Y., Roohani, Y., Agrawal, A., Samotorcan, L., Tabula Sapiens Consortium, Quake, S. R., and Leskovec, J. Universal cell embeddings: A foundation model for cell biology. bioRxiv, 2023.
  • Shao et al. (2021) Shao, Z., Bian, H., Chen, Y., Wang, Y., Zhang, J., Ji, X., et al. TransMIL: Transformer based correlated multiple instance learning for whole slide image classification. Advances in Neural Information Processing Systems, 34:2136–2147, 2021.
  • Snell et al. (2017) Snell, J., Swersky, K., and Zemel, R. Prototypical networks for few-shot learning. Advances in neural information processing systems, 30, 2017.
  • Song et al. (2023) Song, A. H., Jaume, G., Williamson, D. F., Lu, M. Y., Vaidya, A., Miller, T. R., and Mahmood, F. Artificial intelligence for digital and computational pathology. Nature Reviews Bioengineering, 1(12):930–949, 2023.
  • Song et al. (2024a) Song, A. H., Chen, R. J., Ding, T., Williamson, D. F., Jaume, G., and Mahmood, F. Morphological prototyping for unsupervised slide representation learning in computational pathology. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2024a.
  • Song et al. (2024b) Song, A. H., Williams, M., Williamson, D. F., Chow, S. S., Jaume, G., Gao, G., Zhang, A., Chen, B., Baras, A. S., Serafin, R., et al. Analysis of 3D pathology samples using weakly supervised AI. Cell, 187(10):2502–2520, 2024b.
  • Steyaert et al. (2023) Steyaert, S., Pizurica, M., Nagaraj, D., Khandelwal, P., Hernandez-Boussard, T., Gentles, A. J., and Gevaert, O. Multimodal data fusion for cancer biomarker discovery with deep learning. Nature Machine Intelligence, 5(4):351–362, 2023.
  • Subramanian et al. (2005) Subramanian, A., Tamayo, P., Mootha, V. K., Mukherjee, S., Ebert, B. L., Gillette, M. A., Paulovich, A., Pomeroy, S. L., Golub, T. R., Lander, E. S., et al. Gene set enrichment analysis: a knowledge-based approach for interpreting genome-wide expression profiles. Proceedings of the National Academy of Sciences, 102(43):15545–15550, 2005.
  • Tas et al. (2013) Tas, F., Ciftci, R., Kilic, L., and Karabulut, S. Age is a prognostic factor affecting survival in lung cancer patients. Oncology letters, 6(5):1507–1513, 2013.
  • Theodoris et al. (2023) Theodoris, C. V., Xiao, L., Chopra, A., Chaffin, M. D., Al Sayed, Z. R., Hill, M. C., Mantineo, H., Brydon, E. M., Zeng, Z., Liu, X. S., et al. Transfer learning enables predictions in network biology. Nature, 618(7965):616–624, 2023.
  • Uno et al. (2011) Uno, H., Cai, T., Pencina, M. J., D’Agostino, R. B., and Wei, L.-J. On the c-statistics for evaluating overall adequacy of risk prediction procedures with censored survival data. Statistics in medicine, 30(10):1105–1117, 2011.
  • Vaidya et al. (2024) Vaidya, A., Chen, R. J., Williamson, D. F., Song, A. H., Jaume, G., Yang, Y., Hartvigsen, T., Dyer, E. C., Lu, M. Y., Lipkova, J., et al. Demographic bias in misdiagnosis by computational pathology models. Nature Medicine, 30(4):1174–1190, 2024.
  • Vaswani et al. (2017) Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., and Polosukhin, I. Attention is all you need. Advances in neural information processing systems, 30, 2017.
  • Volinsky-Fremond et al. (2024) Volinsky-Fremond, S., Horeweg, N., Andani, S., Barkey Wolf, J., Lafarge, M. W., de Kroon, C. D., Ørtoft, G., Høgdall, E., Dijkstra, J., Jobsen, J. J., et al. Prediction of recurrence risk in endometrial cancer with multimodal deep learning. Nature Medicine, pp.  1–12, 2024.
  • Vu et al. (2023) Vu, Q. D., Rajpoot, K., Raza, S. E. A., and Rajpoot, N. Handcrafted Histological Transformer (H2T): Unsupervised representation of whole slide images. Medical Image Analysis, 85:102743, 2023. ISSN 1361-8415.
  • Wang et al. (2020) Wang, W., Tran, D., and Feiszli, M. What makes training multi-modal classification networks hard? In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pp.  12695–12705, 2020.
  • Wang et al. (2023) Wang, W., Bao, H., Dong, L., Bjorck, J., Peng, Z., Liu, Q., Aggarwal, K., Mohammed, O. K., Singhal, S., Som, S., et al. Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp.  19175–19186, 2023.
  • Wang et al. (2022) Wang, X., Yang, S., Zhang, J., Wang, M., Zhang, J., Yang, W., Huang, J., and Han, X. Transformer-based unsupervised contrastive learning for histopathological image classification. Medical image analysis, 81:102559, 2022.
  • Wang et al. (2012) Wang, Y.-Y., Lehuédé, C., Laurent, V., Dirat, B., Dauvillier, S., Bochet, L., Le Gonidec, S., Escourrou, G., Valet, P., and Muller, C. Adipose tissue and breast epithelial cells: a dangerous dynamic duo in breast cancer. Cancer letters, 324(2):142–151, 2012.
  • Wang et al. (2021) Wang, Z., Li, R., Wang, M., and Li, A. GPDBN: deep bilinear network integrating both genomic data and pathological images for breast cancer prognosis prediction. Bioinformatics, 37(18):2963–2970, 2021.
  • Wang et al. (2024) Wang, Z., Zhang, Y., Xu, Y., Imoto, S., Chen, H., and Song, J. Histo-genomic knowledge distillation for cancer prognosis from histopathology whole slide images, 2024.
  • Wong (1986) Wong, W. H. Theory of partial likelihood. The Annals of statistics, pp.  88–123, 1986.
  • Wu & Zhou (2010) Wu, Y.-d. and Zhou, B. TNF-α/NF-κB/Snail pathway in cancer cell migration and invasion. British journal of cancer, 102(4):639–644, 2010.
  • Wulczyn et al. (2020) Wulczyn, E., Steiner, D. F., Xu, Z., Sadhwani, A., Wang, H., Flament-Auvigne, I., Mermel, C. H., Chen, P.-H. C., Liu, Y., and Stumpe, M. C. Deep learning-based survival prediction for multiple cancer types using histopathology images. PloS ONE, 15(6), 2020.
  • Xiang & Zhang (2022) Xiang, J. and Zhang, J. Exploring low-rank property in multiple instance learning for whole slide image classification. In The Eleventh International Conference on Learning Representations, 2022.
  • Xiang & Zhang (2023) Xiang, J. and Zhang, J. Exploring low-rank property in multiple instance learning for whole slide image classification. In The Eleventh International Conference on Learning Representations, 2023.
  • Xiong et al. (2021) Xiong, Y., Zeng, Z., Chakraborty, R., Tan, M., Fung, G., Li, Y., and Singh, V. Nyströmformer: A Nyström-based algorithm for approximating self-attention. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 35, pp.  14138–14148, 2021.
  • Xu et al. (2023) Xu, P., Zhu, X., and Clifton, D. A. Multimodal learning with transformers: A survey. IEEE Transactions on Pattern Analysis and Machine Intelligence, 2023.
  • Xu & Chen (2023) Xu, Y. and Chen, H. Multimodal optimal transport-based co-attention transformer with global structure consistency for survival prediction. In Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV), pp.  21241–21251, October 2023.
  • Yao et al. (2019) Yao, J., Zhu, X., and Huang, J. Deep multi-instance learning for survival prediction from whole slide images. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pp.  496–504. Springer, 2019.
  • Yao et al. (2020) Yao, J., Zhu, X., Jonnagaddala, J., Hawkins, N., and Huang, J. Whole slide images based cancer survival prediction using attention guided deep multiple instance learning networks. Medical Image Analysis, 65:101789, 2020.
  • Yu et al. (2022) Yu, X. Q., Yap, M. L., Cheng, E. S., Ngo, P. J., Vaneckova, P., Karikios, D., Canfell, K., and Weber, M. F. Evaluating prognostic factors for sex differences in lung cancer survival: findings from a large Australian cohort. Journal of Thoracic Oncology, 17(5):688–699, 2022.
  • Zadeh & Schmid (2020) Zadeh, S. G. and Schmid, M. Bias in cross-entropy-based training of deep survival networks. IEEE transactions on pattern analysis and machine intelligence, 43(9):3126–3137, 2020.
  • Zadeh & Schmid (2021) Zadeh, S. G. and Schmid, M. Bias in cross-entropy-based training of deep survival networks. IEEE Transactions on Pattern Analysis and Machine Intelligence, 43(9):3126–3137, 2021. doi: 10.1109/TPAMI.2020.2979450.
  • Zhang et al. (2024) Zhang, Y., Xu, Y., Chen, J., Xie, F., and Chen, H. Prototypical information bottlenecking and disentangling for multimodal cancer survival prediction. In The Twelfth International Conference on Learning Representations, 2024.
  • Zhou & Chen (2023) Zhou, F. and Chen, H. Cross-modal translation and alignment for survival analysis. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pp.  21485–21494, 2023.

Appendix A Prototype-based histopathology baselines

In this section, we present the three histology prototype aggregation approaches that can be used by MMP, with particular emphasis on the Gaussian mixture model (GMM). The following prototype-based aggregation schemes can be embedded as a feed-forward module in our models.

A.1 Hard clustering (HC)

For each $\mathbf{z}_{i,\text{h.}}$, we identify the closest prototype $\mathbf{a}_{c,\text{h.}}$ under the $\mathcal{L}_2$ distance, i.e., $c_i = \arg\min_{c} \lVert \mathbf{z}_{i,\text{h.}} - \mathbf{a}_{c,\text{h.}} \rVert_2$, which determines the cluster assignment. The post-aggregation embedding $\mathbf{z}_{c,\text{h.}}^{\text{agg.}}$ is the average of all embeddings assigned to $c$,

$$\mathbf{z}_{c,\text{h.}}^{\text{agg.}} = \sum_{i=1}^{N_{\text{h.}}} \mathbf{1}_{c_i=c} \cdot \mathbf{z}_{i,\text{h.}} \Big/ \sum_{i=1}^{N_{\text{h.}}} \mathbf{1}_{c_i=c}, \qquad (7)$$

where $\mathbf{1}$ is the indicator function.
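A minimal NumPy sketch of Eq. (7), assuming the patch embeddings are stacked row-wise in an array `z` of shape $(N_{\text{h.}}, d)$ and the prototypes in `a` of shape $(C_{\text{h.}}, d)$; the function name `hard_cluster_aggregate` is hypothetical. Eq. (7) is undefined for a prototype that receives no patch, so the sketch returns a zero vector in that case, which is an illustrative choice rather than the paper's.

```python
import numpy as np

def hard_cluster_aggregate(z: np.ndarray, a: np.ndarray) -> np.ndarray:
    """Hard-clustering aggregation following Eq. (7).

    z: (N, d) patch embeddings, a: (C, d) prototypes.
    Returns a (C, d) array whose row c is the mean of the embeddings assigned to c.
    """
    # Pairwise L2 distances between every embedding and every prototype, shape (N, C)
    dists = np.linalg.norm(z[:, None, :] - a[None, :, :], axis=-1)
    # Closest prototype: arg min over the distance
    assign = dists.argmin(axis=1)
    # Indicator matrix 1_{c_i = c}, shape (N, C)
    one_hot = np.eye(a.shape[0])[assign]
    counts = one_hot.sum(axis=0)          # number of patches per prototype
    sums = one_hot.T @ z                  # sum of assigned embeddings per prototype
    # Divide only where a prototype has members; empty prototypes stay at zero (assumption)
    return np.divide(sums, counts[:, None], out=np.zeros_like(sums), where=counts[:, None] > 0)

z = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 12.0]])
a = np.array([[0.0, 0.0], [10.0, 10.0], [-50.0, -50.0]])
print(hard_cluster_aggregate(z, a))  # rows: [0, 0.5], [10, 11], [0, 0]
```

In this toy example the first two patches fall on the first prototype, the last two on the second, and the third prototype attracts nothing, so its aggregated row remains zero.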

A.2 Optimal transport (OT)

We can formulate aggregation as transporting mass from the empirical distribution $\hat{p}(\mathbf{z}_{\text{h.}}) = \frac{1}{N_{\text{h.}}} \sum_{i=1}^{N_{\text{h.}}} \delta(\mathbf{z}_{i,\text{h.}})$ to $\hat{p}(\mathbf{a}_{\text{h.}}) = \frac{1}{C_{\text{h.}}} \sum_{c=1}^{C_{\text{h.}}} \delta(\mathbf{a}_{c,\text{h.}})$. The transport plan $\mathbf{T} \in \mathbb{R}_{+}^{N_{\text{h.}} \times C_{\text{h.}}}$ is given as the solution to the following entropy-regularized optimal transport problem (Cuturi, 2013; Kolouri et al., 2017),

$$\min_{\mathbf{T}} \sum_{i,c} \lVert \mathbf{z}_{i,\text{h.}} - \mathbf{a}_{c,\text{h.}} \rVert_2 \cdot \mathbf{T}_{i,c} + \epsilon \cdot \mathbf{T}_{i,c} \log \mathbf{T}_{i,c}, \quad \text{such that } \sum_{i=1}^{N_{\text{h.}}} \mathbf{T}_{i,c} = 1/C_{\text{h.}} \;\; \forall c \;\; \text{and} \;\; \sum_{c=1}^{C_{\text{h.}}} \mathbf{T}_{i,c} = 1/N_{\text{h.}} \;\; \forall i, \qquad (8)$$

where $\epsilon$ is the regularization parameter. Based on the optimal transport plan $\widehat{\mathbf{T}}$ obtained by the widely used Sinkhorn algorithm (Cuturi, 2013), the post-aggregation embedding is given as $\mathbf{z}_{c,\text{h.}}^{\text{agg.}} = \sum_{i=1}^{N_{\text{h.}}} \widehat{\mathbf{T}}_{i,c} \cdot \mathbf{z}_{i,\text{h.}}$.
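The Sinkhorn step and the resulting aggregation can be sketched in NumPy as follows. This is an illustrative log-domain Sinkhorn iteration for Eq. (8) with the uniform marginals stated above, not the MMP implementation; the function names, the default `eps` and the fixed iteration count are assumptions. Working with log-potentials avoids the underflow that a direct `exp(-cost / eps)` kernel suffers when $\epsilon$ is small relative to the distances.

```python
import numpy as np

def _logsumexp(x: np.ndarray, axis: int) -> np.ndarray:
    # Numerically stable log(sum(exp(x))) along one axis
    m = x.max(axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.exp(x - m).sum(axis=axis))

def sinkhorn_aggregate(z: np.ndarray, a: np.ndarray, eps: float = 0.5, n_iters: int = 500):
    """Entropic OT between N patches (mass 1/N each) and C prototypes (mass 1/C each).

    Returns the plan T_hat of shape (N, C) and z_agg of shape (C, d),
    where row c of z_agg is sum_i T_hat[i, c] * z[i].
    """
    n, c = z.shape[0], a.shape[0]
    cost = np.linalg.norm(z[:, None, :] - a[None, :, :], axis=-1)  # L2 cost as in Eq. (8)
    log_k = -cost / eps                                            # log of the Gibbs kernel
    log_r = np.full(n, -np.log(n))                                 # log(1/N) row marginals
    log_c = np.full(c, -np.log(c))                                 # log(1/C) column marginals
    log_u, log_v = np.zeros(n), np.zeros(c)
    for _ in range(n_iters):
        # Alternately rescale rows and columns to match the marginals
        log_u = log_r - _logsumexp(log_k + log_v[None, :], axis=1)
        log_v = log_c - _logsumexp(log_k + log_u[:, None], axis=0)
    t_hat = np.exp(log_u[:, None] + log_k + log_v[None, :])
    z_agg = t_hat.T @ z
    return t_hat, z_agg
```

Because every column of $\widehat{\mathbf{T}}$ sums to $1/C_{\text{h.}}$, each row of `z_agg` is a weighted average of patch embeddings scaled by $1/C_{\text{h.}}$; multiplying by $C_{\text{h.}}$ recovers the weighted mean if that normalization is preferred.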

A.3 Gaussian Mixture Models

With the Gaussian mixture model (GMM) as the generative model for each token embedding, we provide a detailed derivation for estimating 1) the posterior probability of the prototype assignment $q(c \mid \mathbf{z}_{i,\text{h.}};\theta)$ and 2) the GMM parameters $\theta = \{\pi_c, \boldsymbol{\mu}_c, \Sigma_c\}$. Given the GMM specification,

$$
\begin{aligned}
p(\mathbf{z}_{i,\text{h.}};\theta) &= \sum_{c=1}^{C_{\text{h.}}} p(c_i=c;\theta) \cdot p(\mathbf{z}_{i,\text{h.}} \mid c_i=c;\theta) \\
&= \sum_{c=1}^{C_{\text{h.}}} \pi_c \cdot \mathcal{N}(\mathbf{z}_{i,\text{h.}};\boldsymbol{\mu}_c,\Sigma_c), \quad \text{s.t.} \sum_{c=1}^{C_{\text{h.}}} \pi_c = 1,
\end{aligned}
\qquad (9)
$$

the goal is to estimate the $\theta$ that maximizes the log-likelihood, $\max_{\theta} \sum_{i=1}^{N_{\text{h.}}} \log p(\mathbf{z}_{i,\text{h.}};\theta)$. We now present a detailed walkthrough of the expectation-maximization (EM) algorithm (Dempster et al., 1977; Kim, 2022; Song et al., 2024a) and show how it ultimately leads to $\mathbf{z}_{c,\text{h.}}^{\text{agg.}}$.
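To make this objective concrete, the sketch below evaluates the log-likelihood implied by Eq. (9) for a GMM with diagonal covariances $\Sigma_c = \mathrm{diag}(\boldsymbol{\sigma}_c^2)$; the diagonal restriction and the name `gmm_log_likelihood` are assumptions made for illustration. The inner sum over components is computed with the log-sum-exp trick so that very small Gaussian densities in high-dimensional embedding spaces do not underflow.

```python
import numpy as np

def gmm_log_likelihood(z: np.ndarray, pi: np.ndarray, mu: np.ndarray, var: np.ndarray) -> float:
    """sum_i log sum_c pi_c * N(z_i; mu_c, diag(var_c)).

    z: (N, d) embeddings, pi: (C,) mixture weights summing to 1,
    mu: (C, d) means, var: (C, d) per-dimension variances.
    """
    d = z.shape[1]
    diff = z[:, None, :] - mu[None, :, :]                                  # (N, C, d)
    # Log normalizing constant of each diagonal Gaussian, shape (C,)
    log_norm = -0.5 * (d * np.log(2 * np.pi) + np.log(var).sum(axis=1))
    log_gauss = log_norm[None, :] - 0.5 * (diff ** 2 / var[None, :, :]).sum(axis=2)
    log_joint = np.log(pi)[None, :] + log_gauss                            # log p(z_i, c_i = c)
    # Stable log-sum-exp over components, then sum over tokens
    m = log_joint.max(axis=1, keepdims=True)
    return float((m[:, 0] + np.log(np.exp(log_joint - m).sum(axis=1))).sum())
```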

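Since $\log$ is concave, Jensen's inequality lets us lower-bound the log-likelihood using any distribution over the assignments; here we take $q(c_i=c \mid \mathbf{z}_{i,\text{h.}};\theta_{\text{old}})$, computed with the current parameter estimate $\theta_{\text{old}}$: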

$$
\begin{aligned}
\sum_{i=1}^{N_{\text{h.}}} \log p(\mathbf{z}_{i,\text{h.}};\theta)
&= \sum_{i=1}^{N_{\text{h.}}} \log \sum_{c=1}^{C_{\text{h.}}} p(\mathbf{z}_{i,\text{h.}}, c_i=c;\theta) \\
&= \sum_{i=1}^{N_{\text{h.}}} \log \sum_{c=1}^{C_{\text{h.}}} q(c_i=c \mid \mathbf{z}_{i,\text{h.}};\theta_{\text{old}}) \cdot \frac{p(\mathbf{z}_{i,\text{h.}}, c_i=c;\theta)}{q(c_i=c \mid \mathbf{z}_{i,\text{h.}};\theta_{\text{old}})} \\
&\geq \sum_{i=1}^{N_{\text{h.}}} \sum_{c=1}^{C_{\text{h.}}} q(c_i=c \mid \mathbf{z}_{i,\text{h.}};\theta_{\text{old}}) \log \frac{p(\mathbf{z}_{i,\text{h.}}, c_i=c;\theta)}{q(c_i=c \mid \mathbf{z}_{i,\text{h.}};\theta_{\text{old}})} \\
&= \underbrace{\sum_{i=1}^{N_{\text{h.}}} E_{q(c_i \mid \mathbf{z}_{i,\text{h.}};\theta_{\text{old}})}\left[\log p(\mathbf{z}_{i,\text{h.}}, c_i;\theta)\right]}_{Q(\theta;\theta_{\text{old}})} - \underbrace{\sum_{i=1}^{N_{\text{h.}}} E_{q(c_i \mid \mathbf{z}_{i,\text{h.}};\theta_{\text{old}})}\left[\log q(c_i \mid \mathbf{z}_{i,\text{h.}};\theta_{\text{old}})\right]}_{-H(C;\theta_{\text{old}})}.
\end{aligned}
\qquad (10)
$$

Instead of maximizing the log-likelihood directly, we can now maximize a surrogate function, namely the lower bound given by Jensen's inequality. It can be shown that increasing this lower bound with respect to $\theta$ never decreases the actual log-likelihood, so the log-likelihood increases monotonically over iterations (Dempster et al., 1977). The optimization alternates between the E-step and the M-step and is therefore referred to as the Expectation-Maximization (EM) algorithm.
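To make the bound in Eq. 10 concrete, the following sketch evaluates both sides for a single scalar token under a toy two-component, one-dimensional Gaussian mixture (all parameter values are arbitrary and purely illustrative). An arbitrary assignment distribution $q$ gives a value no larger than the exact log-likelihood, and the bound becomes tight when $q$ is the posterior, which is why the E-step sets $q$ to the posterior.

```python
import math


def log_normal(z, mu, var):
    """Log density of a 1-D Gaussian N(z; mu, var)."""
    return -0.5 * (math.log(2.0 * math.pi * var) + (z - mu) ** 2 / var)


def log_likelihood(z, pis, mus, variances):
    """log p(z) = log sum_c pi_c N(z; mu_c, var_c), computed with log-sum-exp."""
    terms = [math.log(p) + log_normal(z, m, v) for p, m, v in zip(pis, mus, variances)]
    top = max(terms)
    return top + math.log(sum(math.exp(t - top) for t in terms))


def lower_bound(z, q, pis, mus, variances):
    """Jensen bound: sum_c q_c * [log p(z, c) - log q_c], i.e. the expected
    complete-data log-likelihood plus the entropy of q (0 log 0 treated as 0)."""
    total = 0.0
    for qc, p, m, v in zip(q, pis, mus, variances):
        if qc > 0.0:
            total += qc * (math.log(p) + log_normal(z, m, v) - math.log(qc))
    return total


def posterior(z, pis, mus, variances):
    """q(c | z) proportional to pi_c N(z; mu_c, var_c), normalized over c."""
    logs = [math.log(p) + log_normal(z, m, v) for p, m, v in zip(pis, mus, variances)]
    top = max(logs)
    weights = [math.exp(l - top) for l in logs]
    norm = sum(weights)
    return [w / norm for w in weights]


# Toy two-component mixture and a single scalar "token".
pis, mus, variances = [0.3, 0.7], [-1.0, 2.0], [0.5, 1.5]
z = 0.4
exact = log_likelihood(z, pis, mus, variances)
bound_uniform = lower_bound(z, [0.5, 0.5], pis, mus, variances)
bound_post = lower_bound(z, posterior(z, pis, mus, variances), pis, mus, variances)
print(exact, bound_uniform, bound_post)
```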

The surrogate function consists of two terms, $Q(\theta;\theta_{\text{old}})$ and $H(C;\theta_{\text{old}})$, which are expectations with respect to the posterior probability of prototype assignment, $q(c_i=c\mid\mathbf{z}_{i,\text{h.}};\theta_{\text{old}})$. In the E-step, we use Bayes' rule to compute this posterior probability and, consequently, the expectations,

$$
\begin{split}
q(c_i=c\mid\mathbf{z}_{i,\text{h.}};\theta_{\text{old}})
&=\frac{q(\mathbf{z}_{i,\text{h.}}\mid c_i=c;\theta_{\text{old}})\cdot q(c_i=c;\theta_{\text{old}})}{q(\mathbf{z}_{i,\text{h.}};\theta_{\text{old}})}\\
&=\frac{q(\mathbf{z}_{i,\text{h.}}\mid c_i=c;\theta_{\text{old}})\cdot q(c_i=c;\theta_{\text{old}})}{\sum_{c'=1}^{C_{\text{h.}}}q(\mathbf{z}_{i,\text{h.}}\mid c_i=c';\theta_{\text{old}})\cdot q(c_i=c';\theta_{\text{old}})}\\
&=\frac{\pi_c\cdot\mathcal{N}(\mathbf{z}_{i,\text{h.}};\boldsymbol{\mu}_c,\Sigma_c)}{\sum_{c'=1}^{C_{\text{h.}}}\pi_{c'}\cdot\mathcal{N}(\mathbf{z}_{i,\text{h.}};\boldsymbol{\mu}_{c'},\Sigma_{c'})}.
\end{split}
\tag{11}
$$
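A minimal sketch of this E-step in plain Python, assuming diagonal covariances (consistent with the $1+2D$ token size of the aggregated embedding) and nested lists in place of tensors. The helper names `log_diag_gaussian` and `e_step` are hypothetical. The posterior is computed in log space with the log-sum-exp trick, because Gaussian densities of high-dimensional patch embeddings easily underflow.

```python
import math
from typing import List, Sequence


def log_diag_gaussian(z: Sequence[float], mu: Sequence[float], var: Sequence[float]) -> float:
    """log N(z; mu, diag(var)) for a D-dimensional token."""
    return -0.5 * sum(math.log(2.0 * math.pi * v) + (x - m) ** 2 / v
                      for x, m, v in zip(z, mu, var))


def e_step(tokens: Sequence[Sequence[float]], pis: Sequence[float],
           mus: Sequence[Sequence[float]], variances: Sequence[Sequence[float]]
           ) -> List[List[float]]:
    """Return the N x C matrix of posteriors q(c_i = c | z_i; theta_old) (Eq. 11)."""
    resp = []
    for z in tokens:
        # log(pi_c) + log N(z; mu_c, Sigma_c) for every component c
        logs = [math.log(p) + log_diag_gaussian(z, m, v)
                for p, m, v in zip(pis, mus, variances)]
        top = max(logs)  # subtract the max before exponentiating (log-sum-exp)
        weights = [math.exp(l - top) for l in logs]
        norm = sum(weights)
        resp.append([w / norm for w in weights])
    return resp


tokens = [[0.0, 0.1], [0.2, -0.1], [3.0, 3.1]]
pis = [0.5, 0.5]
mus = [[0.0, 0.0], [3.0, 3.0]]
variances = [[1.0, 1.0], [1.0, 1.0]]
R = e_step(tokens, pis, mus, variances)
```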

In the M-step, we find $\theta_{\text{new}}$ that maximizes the surrogate function given the posterior probability computed in the E-step. Since $H(C;\theta_{\text{old}})$ depends only on $\theta_{\text{old}}$ and not on $\theta$, it is a constant in the M-step, so we only need to maximize $Q(\theta;\theta_{\text{old}})$ by setting its derivative with respect to each parameter to zero (for $\pi_c$, subject to the constraint $\sum_{c}\pi_c=1$, enforced with a Lagrange multiplier),

$$
\begin{split}
\sum_{i=1}^{N_{\text{h.}}}\frac{\partial Q(\theta;\theta_{\text{old}})}{\partial\pi_c}=0
&\Rightarrow\pi_c^{\text{new}}=\frac{\sum_{i=1}^{N_{\text{h.}}}q(c_i=c\mid\mathbf{z}_{i,\text{h.}};\theta_{\text{old}})}{N_{\text{h.}}}\\
\sum_{i=1}^{N_{\text{h.}}}\frac{\partial Q(\theta;\theta_{\text{old}})}{\partial\boldsymbol{\mu}_c}=0
&\Rightarrow\boldsymbol{\mu}_c^{\text{new}}=\frac{\sum_{i=1}^{N_{\text{h.}}}q(c_i=c\mid\mathbf{z}_{i,\text{h.}};\theta_{\text{old}})\cdot\mathbf{z}_{i,\text{h.}}}{\sum_{i=1}^{N_{\text{h.}}}q(c_i=c\mid\mathbf{z}_{i,\text{h.}};\theta_{\text{old}})}\\
\sum_{i=1}^{N_{\text{h.}}}\frac{\partial Q(\theta;\theta_{\text{old}})}{\partial\Sigma_c}=0
&\Rightarrow\Sigma_c^{\text{new}}=\frac{\sum_{i=1}^{N_{\text{h.}}}q(c_i=c\mid\mathbf{z}_{i,\text{h.}};\theta_{\text{old}})\cdot(\mathbf{z}_{i,\text{h.}}-\boldsymbol{\mu}_c^{\text{new}})^2}{\sum_{i=1}^{N_{\text{h.}}}q(c_i=c\mid\mathbf{z}_{i,\text{h.}};\theta_{\text{old}})}.
\end{split}
\tag{12}
$$
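The closed-form updates of Eq. 12 are weighted averages over patches. The sketch below uses a hypothetical `m_step` helper and assumes diagonal covariances, so the square in the $\Sigma_c$ update is taken elementwise. It takes the $N_{\text{h.}}\times C_{\text{h.}}$ posterior matrix from the E-step and returns $\pi_c$, $\boldsymbol{\mu}_c$, and the diagonal of $\Sigma_c$.

```python
from typing import List, Sequence, Tuple


def m_step(tokens: Sequence[Sequence[float]], resp: Sequence[Sequence[float]]
           ) -> Tuple[List[float], List[List[float]], List[List[float]]]:
    """Closed-form M-step (Eq. 12) for a diagonal-covariance GMM.

    tokens: N x D patch embeddings; resp: N x C posteriors from the E-step.
    Returns (pis, mus, variances), where variances[c] holds diag(Sigma_c).
    """
    n, dim, n_comp = len(tokens), len(tokens[0]), len(resp[0])
    pis, mus, variances = [], [], []
    for c in range(n_comp):
        n_c = sum(resp[i][c] for i in range(n))  # soft count of patches in component c
        pis.append(n_c / n)
        denom = max(n_c, 1e-12)  # guard against an empty component
        mu = [sum(resp[i][c] * tokens[i][d] for i in range(n)) / denom
              for d in range(dim)]
        # Uses the *new* mean, exactly as in Eq. 12; the square is elementwise.
        var = [sum(resp[i][c] * (tokens[i][d] - mu[d]) ** 2 for i in range(n)) / denom
               for d in range(dim)]
        mus.append(mu)
        variances.append(var)
    return pis, mus, variances


tokens = [[0.0], [2.0], [9.0], [11.0]]
resp = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]  # hard assignments for clarity
pis, mus, variances = m_step(tokens, resp)
```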

The E-step and M-step alternate until convergence. In our setting, we usually found a single round of EM to be sufficient. As initial parameters, we set $\pi_c^{(0)}=1/C_{\text{h.}}$, $\boldsymbol{\mu}_c^{(0)}=\mathbf{a}_{c,\text{h.}}$, and $\Sigma_c^{(0)}=\mathbf{I}$, which serves as a morphology-aware initialization for the algorithm. The prototypes $\{\mathbf{a}_{c,\text{h.}}\}_{c=1}^{C_{\text{h.}}}$ used as initial means are obtained by K-means clustering on the training set of patches, which is constructed by aggregating the token embeddings from all training slides in a disease cohort.
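One way to obtain such prototypes is sketched below: pool the patch embeddings of all training slides and run Lloyd's K-means on the pooled set. The function name `kmeans_prototypes`, the deterministic farthest-point seeding, and the iteration cap are assumptions made only to keep the example small and reproducible; the text does not specify which K-means variant is used, and a cohort with millions of patches would call for an optimized library implementation instead.

```python
from typing import List, Sequence

Vector = List[float]


def sq_dist(a: Sequence[float], b: Sequence[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def kmeans_prototypes(slides: Sequence[Sequence[Sequence[float]]], n_proto: int,
                      max_iter: int = 100) -> List[Vector]:
    """Pool patch embeddings from all training slides and run Lloyd's K-means.

    The returned centroids stand in for the prototypes a_{c,h.}.
    """
    pooled = [list(z) for slide in slides for z in slide]  # cohort-level patch set
    # Deterministic farthest-point seeding (an assumption, for reproducibility).
    centroids = [pooled[0]]
    while len(centroids) < n_proto:
        far = max(pooled, key=lambda z: min(sq_dist(z, c) for c in centroids))
        centroids.append(list(far))
    for _ in range(max_iter):
        clusters: List[List[Vector]] = [[] for _ in range(n_proto)]
        for z in pooled:  # assignment step: nearest centroid
            nearest = min(range(n_proto), key=lambda c: sq_dist(z, centroids[c]))
            clusters[nearest].append(z)
        updated = []
        for c, members in enumerate(clusters):  # update step: cluster means
            if members:
                updated.append([sum(col) / len(members) for col in zip(*members)])
            else:  # keep the centroid of an empty cluster unchanged
                updated.append(centroids[c])
        if updated == centroids:  # assignments have stabilized
            break
        centroids = updated
    return centroids


slides = [[[0.0, 0.0], [0.2, 0.1]], [[5.0, 5.0], [5.1, 4.9], [0.1, -0.1]]]
prototypes = sorted(kmeans_prototypes(slides, n_proto=2))
```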

Once $\widehat{\theta}$ is estimated, the post-aggregation embedding $\mathbf{z}_{c,\text{h.}}^{\text{agg.}}\in\mathbb{R}^{d_{\text{h.}}}$ with $d_{\text{h.}}=1+2D$ can be represented as the concatenation $\mathbf{z}_{c,\text{h.}}^{\text{agg.}}=[\widehat{\pi}_c,\widehat{\boldsymbol{\mu}}_c,\widehat{\Sigma}_c]$, where the diagonal of $\widehat{\Sigma}_c$ contributes $D$ entries. Denoting $q_i=q(c_i=c\mid\mathbf{z}_{i,\text{h.}};\theta_{\text{old}})$, we can express $\mathbf{z}_{c,\text{h.}}^{\text{agg.}}$ as in Eq. 1,

$$
\mathbf{z}_{c,\text{h.}}^{\text{agg.}}=\sum_{i=1}^{N_{\text{h.}}}\left[\frac{q_i}{N_{\text{h.}}},\;\frac{q_i\mathbf{z}_{i,\text{h.}}}{\sum_{j=1}^{N_{\text{h.}}}q_j},\;\frac{q_i\left(\mathbf{z}_{i,\text{h.}}-\sum_{j=1}^{N_{\text{h.}}}q_j\mathbf{z}_{j,\text{h.}}\big/\sum_{j=1}^{N_{\text{h.}}}q_j\right)^2}{\sum_{j=1}^{N_{\text{h.}}}q_j}\right],
\tag{13}
$$

which can indeed be expressed as a sum of the mapping function $g$ over $N_{\text{h.}}$ elements (although writing out the full expression is non-trivial due to the iterative nature of EM).
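Putting the pieces together, the following sketch summarizes one slide with a single EM round from the stated initialization ($\pi_c^{(0)}=1/C_{\text{h.}}$, $\boldsymbol{\mu}_c^{(0)}=\mathbf{a}_{c,\text{h.}}$, $\Sigma_c^{(0)}=\mathbf{I}$) and packs each component into a token $[\widehat{\pi}_c,\widehat{\boldsymbol{\mu}}_c,\widehat{\Sigma}_c]$ of length $1+2D$, as in Eq. 13. It assumes diagonal covariances and uses the hypothetical name `one_round_em_tokens`; it illustrates the computation and is not the authors' implementation.

```python
import math
from typing import List, Sequence


def one_round_em_tokens(patches: Sequence[Sequence[float]],
                        prototypes: Sequence[Sequence[float]]) -> List[List[float]]:
    """Summarize one slide's N patch embeddings into C tokens of size 1 + 2D."""
    n_comp, dim, n = len(prototypes), len(prototypes[0]), len(patches)
    # E-step (Eq. 11) with pi = 1/C and Sigma = I: the constant terms cancel,
    # so the posterior is a softmax over c of -0.5 * ||z_i - a_c||^2.
    resp = []
    for z in patches:
        logits = [-0.5 * sum((x - a) ** 2 for x, a in zip(z, proto)) for proto in prototypes]
        top = max(logits)
        w = [math.exp(l - top) for l in logits]
        s = sum(w)
        resp.append([v / s for v in w])
    # M-step (Eq. 12), then concatenate [pi_c, mu_c, diag(Sigma_c)] (Eq. 13).
    tokens = []
    for c in range(n_comp):
        n_c = sum(r[c] for r in resp)
        denom = max(n_c, 1e-12)  # avoid dividing by zero for an unused prototype
        mu = [sum(r[c] * z[d] for r, z in zip(resp, patches)) / denom for d in range(dim)]
        var = [sum(r[c] * (z[d] - mu[d]) ** 2 for r, z in zip(resp, patches)) / denom
               for d in range(dim)]
        tokens.append([n_c / n] + mu + var)
    return tokens


patches = [[0.0, 0.0], [0.2, 0.0], [5.0, 5.0], [5.0, 5.2]]
prototypes = [[0.1, 0.0], [5.0, 5.1]]
tokens = one_round_em_tokens(patches, prototypes)
```

Because the number of tokens is fixed at $C_{\text{h.}}$ regardless of how many patches a slide has, the downstream fusion network always sees an input of the same size.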

Appendix B Proof for similarity between OT-based cross-alignment and Transformer-based cross-attention

Lemma B.1.

Let $\mathbf{Z}_{\text{g.}}\in\mathbb{R}^{C_{\text{g.}}\times d}$ and $\mathbf{Z}_{\text{h.}}\in\mathbb{R}^{C_{\text{h.}}\times d}$ be the matrix representations of the token sets $\{\mathbf{z}_{i,\text{g.}}\}_{i=1}^{C_{\text{g.}}}$ and $\{\mathbf{z}_{k,\text{h.}}\}_{k=1}^{C_{\text{h.}}}$. Let $\mathbf{Z}_{\text{g.}}\mathbf{W}_Q^{\text{T}}\in\mathbb{R}^{C_{\text{g.}}\times d}$ and $\mathbf{Z}_{\text{h.}}\mathbf{W}^{\text{T}}\in\mathbb{R}^{C_{\text{h.}}\times d}$ be the linear projections of both sets. Let $\widehat{\mathbf{T}}\in\mathbb{R}_{+}^{C_{\text{g.}}\times C_{\text{h.}}}$ be the optimal transport plan, i.e., the solution to the entropic-regularized, unbalanced optimal transport problem between the two projected sets. Then $\widehat{\mathbf{T}}$ is equivalent to the Transformer cross-attention matrix $\sigma(\mathbf{Z}_{\text{g.}}\mathbf{W}_Q^{\text{T}}\mathbf{W}\mathbf{Z}_{\text{h.}}^{\text{T}}/\sqrt{d})$ up to a multiplicative factor of $1/C_{\text{g.}}$, where $\sigma(\cdot)$ denotes the row-wise softmax, $\{\mathbf{W}_Q\mathbf{z}_{i,\text{g.}}\}_{i=1}^{C_{\text{g.}}}$ are the queries, and $\{\mathbf{W}\mathbf{z}_{k,\text{h.}}\}_{k=1}^{C_{\text{h.}}}$ are the keys.
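As a numerical sanity check of the lemma, the sketch below draws random tokens and projection matrices, solves the row-constrained entropic OT problem (the formulation used in the proof) with a generic entropic mirror-descent solver that never uses the closed-form solution, and compares $C_{\text{g.}}\cdot\widehat{\mathbf{T}}$ with the row-wise softmax attention using $\varepsilon=\sqrt{d}$. The sizes, random seed, step size, iteration count, and helper names are arbitrary choices for illustration.

```python
import math
import random
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def gaussian_matrix(rng: random.Random, n_rows: int, n_cols: int) -> Matrix:
    return [[rng.gauss(0.0, 1.0) for _ in range(n_cols)] for _ in range(n_rows)]


def row_softmax(m: Matrix) -> Matrix:
    out = []
    for row in m:
        top = max(row)
        e = [math.exp(v - top) for v in row]
        s = sum(e)
        out.append([v / s for v in e])
    return out


def row_constrained_entropic_ot(cost: Matrix, eps: float, n_iter: int = 200) -> Matrix:
    """Minimize sum_ik D_ik T_ik + eps * T_ik log T_ik s.t. every row sums to 1/C_g.

    Generic entropic mirror descent: multiplicative gradient steps followed by a
    rescaling of each row onto its marginal. Column sums are left unconstrained.
    """
    n_rows, n_cols = len(cost), len(cost[0])
    step = 0.5 / eps  # the log-domain error shrinks by (1 - step * eps) = 0.5 per iteration
    log_t = [[-math.log(n_rows * n_cols)] * n_cols for _ in range(n_rows)]
    for _ in range(n_iter):
        for i in range(n_rows):
            grad = [cost[i][k] + eps * (log_t[i][k] + 1.0) for k in range(n_cols)]
            upd = [log_t[i][k] - step * grad[k] for k in range(n_cols)]
            top = max(upd)
            log_norm = top + math.log(sum(math.exp(u - top) for u in upd))
            log_t[i] = [u - log_norm - math.log(n_rows) for u in upd]  # row sum = 1/C_g
    return [[math.exp(v) for v in row] for row in log_t]


rng = random.Random(0)
c_g, c_h, d = 3, 4, 8
Z_g, Z_h = gaussian_matrix(rng, c_g, d), gaussian_matrix(rng, c_h, d)
W_Q, W = gaussian_matrix(rng, d, d), gaussian_matrix(rng, d, d)
queries = matmul(Z_g, transpose(W_Q))      # Z_g W_Q^T, one query per row
keys = matmul(Z_h, transpose(W))           # Z_h W^T, one key per row
scores = matmul(queries, transpose(keys))  # Z_g W_Q^T W Z_h^T
cost = [[-s for s in row] for row in scores]  # D_ik: negative dot product
T = row_constrained_entropic_ot(cost, eps=math.sqrt(d))
attn = row_softmax([[s / math.sqrt(d) for s in row] for row in scores])
max_gap = max(abs(c_g * T[i][k] - attn[i][k]) for i in range(c_g) for k in range(c_h))
print(f"max |C_g * T - softmax| = {max_gap:.2e}")
```

Only the row sums of `T` are constrained; its column sums are left free, which is what makes the problem unbalanced in the sense of the lemma.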

Proof.

This proof extends and adapts a lemma from (Kim, 2022) to our application. We use the negative dot-product similarity as the cost between the two sets of linearly projected tokens $\{\mathbf{W}_Q\mathbf{z}_{i,\text{g.}}\}_{i=1}^{C_{\text{g.}}}$ and $\{\mathbf{W}\mathbf{z}_{k,\text{h.}}\}_{k=1}^{C_{\text{h.}}}$, i.e., $\mathbf{D}_{i,k}=-\mathbf{z}_{i,\text{g.}}^{\text{T}}\mathbf{W}_Q^{\text{T}}\mathbf{W}\mathbf{z}_{k,\text{h.}}$. We can then formulate the entropic-regularized optimal transport problem for the transport plan $\mathbf{T}\in\mathbb{R}_{+}^{C_{\text{g.}}\times C_{\text{h.}}}$,

$$
\min_{\mathbf{T}}\sum_{i,k}\mathbf{D}_{i,k}\cdot\mathbf{T}_{i,k}+\varepsilon\mathbf{T}_{i,k}\log\mathbf{T}_{i,k},\quad\text{s.t.}\;\sum_{k=1}^{C_{\text{h.}}}\mathbf{T}_{i,k}=\frac{1}{C_{\text{g.}}},\;\forall i,
\tag{14}
$$

without the column-marginal constraint $\sum_{i=1}^{C_{\text{g.}}}\mathbf{T}_{i,k}=1/C_{\text{h.}}$. This can be viewed as an unbalanced OT problem (Benamou, 2003; Chizat et al., 2018), since Eq. 14 can be written as

$$
\min_{\mathbf{T}}\sum_{i,k}\left(\mathbf{D}_{i,k}\cdot\mathbf{T}_{i,k}+\varepsilon\mathbf{T}_{i,k}\log\mathbf{T}_{i,k}\right)+\lambda_1\cdot\operatorname{Div.}\left(\mathbf{T}^{\text{T}}\cdot\mathbf{1}_{C_{\text{g.}}},1/C_{\text{h.}}\cdot\mathbf{1}_{C_{\text{h.}}}\right)+\lambda_2\cdot\operatorname{Div.}\left(\mathbf{T}\cdot\mathbf{1}_{C_{\text{h.}}},1/C_{\text{g.}}\cdot\mathbf{1}_{C_{\text{g.}}}\right),
\tag{15}
$$

with $\lambda_1\rightarrow 0$ and $\lambda_2\rightarrow\infty$, where $\operatorname{Div.}$ is some divergence measure and $\mathbf{1}_{C_{\text{g.}}}$ is a $C_{\text{g.}}$-length vector of ones. We now solve Eq. 14 using Lagrange multipliers,

$$
\mathcal{L}=\sum_{i,k}\left(\mathbf{D}_{i,k}\cdot\mathbf{T}_{i,k}+\varepsilon\mathbf{T}_{i,k}\log\mathbf{T}_{i,k}\right)+\sum_{i=1}^{C_{\text{g.}}}\beta_i\left(\sum_{k=1}^{C_{\text{h.}}}\mathbf{T}_{i,k}-\frac{1}{C_{\text{g.}}}\right).
\tag{16}
$$

We proceed by taking the derivative of $\mathcal{L}$ with respect to $\mathbf{T}_{i,k}$ and setting it to 0,

$$
\frac{\partial\mathcal{L}}{\partial\mathbf{T}_{i,k}}=\mathbf{D}_{i,k}+\varepsilon(\log\mathbf{T}_{i,k}+1)+\beta_i=0\;\Rightarrow\;\mathbf{T}_{i,k}=\exp\left(-\mathbf{D}_{i,k}/\varepsilon+\gamma_i\right),
\tag{17}
$$

where $\gamma_i=-(\beta_i/\varepsilon+1)$ is a constant that depends only on the row index $i$. To solve for $\gamma_i$, we use the constraint $\sum_{k=1}^{C_{\text{h.}}}\mathbf{T}_{i,k}=\frac{1}{C_{\text{g.}}}$,

$$
\exp(\gamma_i)\sum_{k=1}^{C_{\text{h.}}}\exp(-\mathbf{D}_{i,k}/\varepsilon)=\frac{1}{C_{\text{g.}}}\;\Rightarrow\;\exp(\gamma_i)=\frac{1}{C_{\text{g.}}\cdot\sum_{k=1}^{C_{\text{h.}}}\exp(-\mathbf{D}_{i,k}/\varepsilon)},
\tag{18}
$$

and obtain $\widehat{\mathbf{T}}_{i,k}$ (also setting $\varepsilon=\sqrt{d}$),

$$\widehat{\mathbf{T}}_{i,k}=\frac{\exp(-\mathbf{D}_{i,k}/\sqrt{d})}{C_{\text{g.}}\cdot\sum_{k'=1}^{C_{\text{h.}}}\exp(-\mathbf{D}_{i,k'}/\sqrt{d})}=\frac{\exp(\mathbf{z}_{i,\text{g.}}^{\text{T}}\mathbf{W}_{Q}^{\text{T}}\mathbf{W}\mathbf{z}_{k,\text{h.}}/\sqrt{d})}{C_{\text{g.}}\cdot\sum_{k'=1}^{C_{\text{h.}}}\exp(\mathbf{z}_{i,\text{g.}}^{\text{T}}\mathbf{W}_{Q}^{\text{T}}\mathbf{W}\mathbf{z}_{k',\text{h.}}/\sqrt{d})},\tag{19}$$

with each row of $\widehat{\mathbf{T}}$ given by a softmax over the $C_{\text{h.}}$ histology tokens, scaled by $1/C_{\text{g.}}$. This is the Transformer cross-attention operation, with pathway tokens as queries and histology tokens as keys, up to a multiplicative factor of $1/C_{\text{g.}}$. ∎
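The equivalence in Eq. 19 can be checked numerically. The following is a hypothetical NumPy check, not the authors' code: random matrices stand in for the projected tokens $\mathbf{W}_Q\mathbf{z}_{i,\text{g.}}$ and $\mathbf{W}\mathbf{z}_{k,\text{h.}}$, the cost is assumed to be $\mathbf{D}_{i,k}=-\mathbf{z}_{i,\text{g.}}^{\text{T}}\mathbf{W}_Q^{\text{T}}\mathbf{W}\mathbf{z}_{k,\text{h.}}$ as implied by the last equality in Eq. 19, and the sizes are illustrative.

```python
import numpy as np

def ot_row_plan(D, eps, C_g):
    """Closed-form plan of Eq. 18: T_ik = exp(-D_ik/eps) / (C_g * sum_k' exp(-D_ik'/eps))."""
    logits = -D / eps
    logits = logits - logits.max(axis=1, keepdims=True)  # stabilizes; cancels in the ratio
    E = np.exp(logits)
    return E / (C_g * E.sum(axis=1, keepdims=True))

def cross_attention(Q, K):
    """Scaled dot-product attention weights softmax(Q K^T / sqrt(d)), row-wise."""
    logits = Q @ K.T / np.sqrt(Q.shape[1])
    logits = logits - logits.max(axis=1, keepdims=True)
    A = np.exp(logits)
    return A / A.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
C_g, C_h, d = 50, 16, 8
Q = rng.normal(size=(C_g, d))   # stands in for W_Q z_{i,g.} (pathway queries)
K = rng.normal(size=(C_h, d))   # stands in for W z_{k,h.} (histology keys)
D = -(Q @ K.T)                  # assumed cost: negative query-key similarity
T_hat = ot_row_plan(D, eps=np.sqrt(d), C_g=C_g)
A = cross_attention(Q, K)
print(np.allclose(T_hat, A / C_g))               # True
print(np.allclose(T_hat.sum(axis=1), 1 / C_g))   # True: the row constraint holds
```

Both checks hold by construction, which is the point of the derivation: with only the row marginal enforced, the entropic plan reduces to scaled softmax attention.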

Appendix C Survival loss functions

Survival analysis models the time to an event whose outcome is not always observed (i.e., censored). In cancer survival prediction, a censored observation corresponds to a patient known to be alive at the last follow-up time, whereas an uncensored observation corresponds to a patient death. Let $T$ be a continuous random variable representing the patient's survival time, and let the survival function $S(t)=P(T\geq t)$ be the probability that the patient survives beyond time $t$. The goal of survival analysis is to estimate the hazard function $\lambda(t)$, the instantaneous rate of the event at time $t$ given that the patient has survived up to $t$ (Cox, 1972). We now detail the Cox proportional hazards and negative log-likelihood survival losses.
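For reference, the standard relations between the hazard, the density $f(t)$ of $T$, and the survival function are given below. These are textbook identities for a continuous survival time, not anything specific to MMP:

$$\lambda(t)=\lim_{\Delta t\to 0}\frac{P(t\leq T < t+\Delta t\mid T\geq t)}{\Delta t}=\frac{f(t)}{S(t)},\qquad S(t)=\exp\Big(-\int_{0}^{t}\lambda(u)\,du\Big).$$

Because $\lambda(t)$ is a rate conditioned on survival up to $t$, it can exceed 1 and is not itself a probability.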

C.1 Cox proportional hazards loss

The Cox proportional hazards model parameterizes the hazard function as $\lambda(t\mid\bar{\mathbf{x}}_{\text{patient}})=\lambda_0(t)\exp(\bar{\mathbf{x}}_{\text{patient}}\theta)$, where $\lambda_0(t)$ is the baseline hazard function describing how the risk of an event changes over time, and $\theta$ are the model parameters describing how the hazard varies with the patient features $\bar{\mathbf{x}}_{\text{patient}}\in\mathbb{R}^{2d}$. Because $\lambda_0(t)$ is shared by all patients, it cancels in the Cox partial likelihood, so $\theta$ can be fitted without specifying $\lambda_0$ (Wong, 1986). The negative Cox partial log-likelihood over a set of patients is:

$$l(\theta,\bar{\mathbf{x}}_{\text{patient}})=-\sum_{i\in U}\Big(\bar{\mathbf{x}}_{\text{patient},i}\theta-\log\sum_{j\in R_i}\exp(\bar{\mathbf{x}}_{\text{patient},j}\theta)\Big),\tag{20}$$
$$\frac{\partial l(\theta,\bar{\mathbf{x}}_{\text{patient}})}{\partial\bar{\mathbf{x}}_{\text{patient},m}}=-\delta(m)\,\theta+\sum_{i\in U:\,m\in R_i}\frac{\theta\exp(\bar{\mathbf{x}}_{\text{patient},m}\theta)}{\sum_{j\in R_i}\exp(\bar{\mathbf{x}}_{\text{patient},j}\theta)},\tag{21}$$

where $U$ is the set of uncensored patients, $C$ is the set of censored patients, $R_i$ is the risk set of patient $i$, i.e., the patients (including $i$ itself) whose time of death or last follow-up is at or after that of $i$, and $\delta(m)=1$ if the death of patient $m$ is observed ($m\in U$) and $\delta(m)=0$ if $m$ is censored ($m\in C$).
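A minimal NumPy sketch of Eq. 20 for one batch is shown below, assuming the model outputs one scalar log-risk $\bar{\mathbf{x}}_{\text{patient},i}\theta$ per patient. The function name is hypothetical, and ties are handled simply by defining risk sets through $t_j \geq t_i$ (no Breslow or Efron correction); the authors' implementation may differ.

```python
import numpy as np

def cox_neg_partial_log_likelihood(risk, time, event):
    """Negative Cox partial log-likelihood (Eq. 20), summed over uncensored patients.

    risk:  (B,) predicted log-risks x_i @ theta
    time:  (B,) time of death or last follow-up
    event: (B,) 1 if death observed (i in U), 0 if censored
    """
    risk = np.asarray(risk, dtype=float)
    time = np.asarray(time, dtype=float)
    event = np.asarray(event, dtype=bool)
    # at_risk[i, j] is True when patient j belongs to R_i (t_j >= t_i, includes j = i)
    at_risk = time[None, :] >= time[:, None]
    # numerically stable log-sum-exp of the risks over each risk set
    masked = np.where(at_risk, risk[None, :], -np.inf)
    m = masked.max(axis=1)
    log_denom = m + np.log(np.exp(masked - m[:, None]).sum(axis=1))
    return -np.sum((risk - log_denom)[event])
```

For a batch holding a single uncensored patient, the risk set contains only that patient, so the numerator and denominator coincide and the loss (and its gradient) is zero; this is why the Cox loss needs batches of several patients.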

C.2 Negative log-likelihood loss

The negative log-likelihood (NLL) survival loss (Zadeh & Schmid, 2021) generalizes the NLL to censored data. The aim is to predict the survival of a patient from the learned patient-level embedding $\bar{\mathbf{x}}_{\text{patient}}\in\mathbb{R}^{2d}$. Following previous work (Zadeh & Schmid, 2021), the patient's survival state is defined by: (1) the censorship status $c$, where $c=0$ represents an observed patient death due to disease and $c=1$ corresponds to the patient's last known follow-up, and (2) a time-to-event $t_i$, which is the time between the patient's diagnosis and the observed death if $c=0$, or the last follow-up if $c=1$. Instead of predicting the observed time-to-event $t_i$ directly, we discretize it into non-overlapping time intervals $(t_{j-1},t_j)$, $j\in\{1,\ldots,n\}$, defined by the quartiles of the survival times (i.e., $n=4$), and denote the interval containing $t_i$ by $y_j$. The setup thus becomes a classification problem with censorship information, where each patient is described by $(\bar{\mathbf{x}}_{\text{patient}},y_j,c)$.
Next, we build a classifier whose output logits $\hat{y}_j$ each correspond to one time interval, and define the discrete hazard function $f_{\text{hazard}}(y_j\mid\bar{\mathbf{x}}_{\text{patient}})=\sigma(\hat{y}_j)$, where $\sigma$ is the sigmoid activation. Intuitively, $f_{\text{hazard}}(y_j\mid\bar{\mathbf{x}}_{\text{patient}})$ represents the probability that the patient dies during the time interval $(t_{j-1},t_j)$, given that the patient has survived up to $t_{j-1}$.
We further define the discrete survival function $f_{\text{surv}}(y_j\mid\bar{\mathbf{x}}_{\text{patient}})=\prod_{k=1}^{j}\big(1-f_{\text{hazard}}(y_k\mid\bar{\mathbf{x}}_{\text{patient}})\big)$, which represents the probability that the patient survives through the end of the time interval $(t_{j-1},t_j)$. The NLL survival loss is then formally defined as:

\begin{align}
&\mathcal{L}\Big(\{\bar{\mathbf{x}}^{(i)}_{\text{patient}},y^{(i)}_{j},c^{(i)}\}_{i=1}^{N_{D}}\Big)= \tag{22}\\
&\quad\sum_{i=1}^{N_{D}}\Big[-c^{(i)}\log f_{\text{surv}}\big(y_{j}^{(i)}\mid\bar{\mathbf{x}}^{(i)}_{\text{patient}}\big) \tag{23}\\
&\quad-(1-c^{(i)})\log f_{\text{surv}}\big(y_{j}^{(i)}-1\mid\bar{\mathbf{x}}^{(i)}_{\text{patient}}\big) \tag{24}\\
&\quad-(1-c^{(i)})\log f_{\text{hazard}}\big(y_{j}^{(i)}\mid\bar{\mathbf{x}}^{(i)}_{\text{patient}}\big)\Big] \tag{25}
\end{align}

where $N_D$ is the number of patients in the dataset and $f_{\text{surv}}(0\mid\cdot)=1$ (empty product). Eq. 23 encourages a high survival probability through the last follow-up interval for censored patients; Eq. 24 encourages survival up to the interval in which death was observed for patients who died; and Eq. 25 encourages a high hazard in that interval, so that the correct time interval is predicted for patients with an observed death. Because the NLL loss, unlike the Cox loss, does not require a set of patients and can be computed for a single patient, it has become the de facto loss function for cancer survival prediction from histology, where the large number of tokens per WSI makes forming batches of patients infeasible.
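The discretization and Eqs. 23-25 can be sketched in NumPy as follows. This is a minimal illustration, not the authors' implementation: the function names are hypothetical, interval indices are 0-based in code (so $f_{\text{surv}}(y_j-1)$ becomes column `y` of a padded array), the quartiles are computed over whatever times are passed in (which patients enter the quartile computation is an assumption), and `eps` is added only for numerical safety.

```python
import numpy as np

def quartile_bins(times, n_bins=4):
    """Discretize survival times into n_bins intervals cut at empirical quantiles."""
    times = np.asarray(times, dtype=float)
    cuts = np.quantile(times, np.linspace(0.0, 1.0, n_bins + 1)[1:-1])  # interior cut points
    return np.digitize(times, cuts)  # 0-based interval index in {0, ..., n_bins - 1}

def nll_survival_loss(logits, y, c, eps=1e-7):
    """Sum over patients of Eqs. 23-25.

    logits: (B, n) classifier outputs; y: (B,) 0-based interval; c: (B,) 1 if censored.
    """
    logits = np.asarray(logits, dtype=float)
    y = np.asarray(y, dtype=int)
    c = np.asarray(c, dtype=float)
    hazards = 1.0 / (1.0 + np.exp(-logits))           # f_hazard for every interval
    surv = np.cumprod(1.0 - hazards, axis=1)           # f_surv(y_j) = prod_{k<=j} (1 - f_hazard)
    surv_pad = np.concatenate([np.ones((len(y), 1)), surv], axis=1)  # column 0 is f_surv(0) = 1
    rows = np.arange(len(y))
    s_j = surv_pad[rows, y + 1]    # f_surv(y_j)
    s_prev = surv_pad[rows, y]     # f_surv(y_j - 1)
    h_j = hazards[rows, y]         # f_hazard(y_j)
    loss = (-c * np.log(s_j + eps)             # Eq. 23: censored patients
            - (1 - c) * np.log(s_prev + eps)   # Eq. 24: survived until the death interval
            - (1 - c) * np.log(h_j + eps))     # Eq. 25: died within that interval
    return loss.sum()
```

For example, `quartile_bins([1, 2, 3, 4, 5, 6, 7, 8])` assigns two patients to each of the four intervals, and with all logits at zero (every hazard 0.5) a patient who died in the first interval contributes $\log 2$ while a patient censored in the second interval contributes $-\log(0.5^2)=\log 4$.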

C.3 Concordance Index

The Concordance Index (C-Index) (Harrell et al., 1982) is a popular metric for evaluating survival prediction models (Chen et al., 2022; Jaume et al., 2024) that measures the rank correlation between the predicted risk scores and the observed times $t$. In prognosis prediction, the C-Index can be understood as assessing how well a model assigns a higher risk of adverse outcomes to patients with shorter survival times. Formally, the C-Index is the ratio of concordant pairs to all comparable pairs. Two patients $i$ and $j$ are comparable if the patient with the shorter observed time experienced the event (i.e., if $t_i > t_j$ then $\delta_j=1$, where $\delta_j$ is a binary indicator equal to 1 if the death of patient $j$ was observed and 0 if only the last follow-up time is known). A comparable pair $(i,j)$ is concordant if the survival model assigns a larger predicted risk $\hat{f}_{\text{risk}}$ to the patient with the shorter event time, i.e., $\hat{f}_{\text{risk},j} > \hat{f}_{\text{risk},i}$ given $t_j < t_i$; otherwise, the pair is discordant (Pölsterl, 2020).
While the C-Index allows easy comparison between models, it has known limitations; for example, it becomes overly optimistic as the proportion of censored patients in a dataset increases (Uno et al., 2011).
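The pair-counting definition above translates directly into a short $O(N^2)$ sketch. The function name is hypothetical, pairs with tied observed times are simply not counted as comparable, and ties in predicted risk are credited 0.5; that tie convention is a common choice in survival libraries but is an assumption here, since the text does not specify one.

```python
import numpy as np

def harrell_c_index(time, event, risk):
    """Fraction of comparable pairs whose predicted risks are correctly ordered."""
    time = np.asarray(time, dtype=float)
    event = np.asarray(event, dtype=bool)
    risk = np.asarray(risk, dtype=float)
    concordant, comparable = 0.0, 0
    for j in range(len(time)):
        if not event[j]:
            continue                    # the earlier patient of a pair must have an observed death
        for i in range(len(time)):
            if time[i] > time[j]:       # comparable pair: t_j < t_i and delta_j = 1
                comparable += 1
                if risk[j] > risk[i]:
                    concordant += 1.0   # shorter survival, higher predicted risk
                elif risk[j] == risk[i]:
                    concordant += 0.5   # assumed tie convention
    return concordant / comparable if comparable else float("nan")
```

A perfectly ordered prediction such as `harrell_c_index([1, 2, 3], [1, 1, 1], [3, 2, 1])` returns 1.0, the reversed ordering returns 0.0, and a constant predictor scores 0.5.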

Appendix D Datasets

D.1 TCGA cohort

We evaluate all baselines on six cancer cohorts from TCGA: Bladder Urothelial Carcinoma (BLCA), Breast Invasive Carcinoma (BRCA), Lung Adenocarcinoma (LUAD), Stomach Adenocarcinoma (STAD), Colon and Rectum Adenocarcinoma (CRC), and Kidney Renal Clear Cell Carcinoma (KIRC). Table 5 summarizes the cohort statistics. Each WSI is tessellated into non-overlapping patches (tokens) of $256\times256$ pixels at $20\times$ magnification ($0.5\,\mu m$/pixel).
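The tessellation step amounts to enumerating a non-overlapping grid. The pure-Python sketch below assumes the slide dimensions are already given at $20\times$ (reading the pixels would be done with a WSI reader such as OpenSlide; the authors' tooling is not stated here) and omits the tissue/background filtering that practical pipelines usually add. The helper name is hypothetical.

```python
def patch_grid(width, height, patch=256):
    """Top-left (x, y) coordinates of non-overlapping patch x patch tiles fully inside the image."""
    return [(x, y)
            for y in range(0, height - patch + 1, patch)
            for x in range(0, width - patch + 1, patch)]

MICRONS_PER_PIXEL = 0.5  # 20x magnification, as stated above

coords = patch_grid(1000, 600)
print(len(coords))                      # 6 tiles: 3 columns x 2 rows
print(256 * MICRONS_PER_PIXEL, "um")    # 128.0 um of tissue along each patch side
```

Partial tiles at the right and bottom borders are dropped, which is one of several reasonable conventions.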

Table 5: TCGA cohort statistics. The number of patients, the total number of WSIs, and the average number of patches (tokens) per WSI. A single patient can have multiple WSIs.

| Cohort | Num. of patients | Num. of slides | Avg. set size (patches per WSI) |
|---|---|---|---|
| BLCA | 359 | 423 | 16,312 |
| BRCA | 868 | 928 | 11,565 |
| LUAD | 412 | 463 | 4,714 |
| STAD | 318 | 318 | 10,955 |
| CRC | 296 | 300 | 9,127 |
| KIRC | 340 | 346 | 12,802 |

D.2 RNA-seq expression data

Bulk RNA-seq expression data for all TCGA cohorts, accessed from the UCSC Xena database (Goldman et al., 2020), were measured on the Illumina HiSeq 2000 RNA sequencing platform, normalized with RSEM (Li & Dewey, 2011), and $\log_2(x+1)$-transformed. The $C_{\text{g.}}=50$ Hallmark gene sets from the Molecular Signatures Database (MSigDB) (Subramanian et al., 2005; Liberzon et al., 2015) are used to select genes and organize them into biological pathways. Hallmark gene sets represent well-defined biological states in cancer. After organizing the genes into Hallmark gene sets, we obtained 4,241 unique genes across the 50 gene sets. Gene sets contain 142 genes on average, with a minimum of 31 and a maximum of 199.
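Grouping a patient's expression profile into pathway-level inputs can be sketched as follows. This is a hypothetical illustration: the function name, toy gene symbols, and toy pathway names are placeholders, the input is assumed to be untransformed RSEM values (the $\log_2$ step would be skipped for an already-transformed matrix), and each variable-length vector would subsequently be encoded into a fixed-size token by a pathway-specific network.

```python
import numpy as np

def pathway_tokens(expr, gene_names, gene_sets):
    """Split one patient's expression profile into per-pathway vectors.

    expr:      RSEM-normalized values aligned with gene_names
    gene_sets: dict mapping pathway name -> list of gene symbols (e.g., Hallmark sets)
    """
    index = {g: k for k, g in enumerate(gene_names)}
    log_expr = np.log2(np.asarray(expr, dtype=float) + 1.0)  # log2(x + 1)
    tokens = {}
    for name, genes in gene_sets.items():
        cols = [index[g] for g in genes if g in index]       # skip genes absent from the assay
        tokens[name] = log_expr[cols]                         # length varies per pathway
    return tokens

toy_sets = {"PATHWAY_A": ["G1", "G2"], "PATHWAY_B": ["G2", "G3", "G4", "G5"]}
tokens = pathway_tokens([0, 1, 3, 7], ["G1", "G2", "G3", "G4"], toy_sets)
print({k: v.tolist() for k, v in tokens.items()})
# {'PATHWAY_A': [0.0, 1.0], 'PATHWAY_B': [1.0, 2.0, 3.0]}
```

Gene `G2` appears in both toy pathways, just as a gene can belong to several Hallmark sets; this is why 50 sets averaging 142 genes (about 7,100 memberships) contain only 4,241 unique genes.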

Appendix E Baselines

E.1 Unimodal baselines

In this section, we explain the unimodal MIL baselines that we compare our proposed framework with.

  1. ABMIL (Ilse et al., 2018): Attention-based multiple instance learning (ABMIL) first assigns patch-level importance scores through a local attention mechanism, where the score of a patch depends only on the content of that patch. The attention-weighted sum of patch embeddings is used as the slide-level representation. Because patches are scored independently, ABMIL neglects correlations between different patches.

  2. TransMIL (Shao et al., 2021): Since ABMIL cannot learn patch-level correlations, Transformer-based multiple instance learning (TransMIL) was proposed. TransMIL first reshapes the sequence of low-dimensional patch representations into a square 2D grid, then applies a Pyramid Position Encoding Generator (PPEG) module to encode spatial information, and finally uses Nyström attention (Xiong et al., 2021) to approximate the self-attention scores between patches. The CLS token is taken as the slide-level representation.

  3. Low-rank MIL (Xiang & Zhang, 2022): While TransMIL learns slide-level representations by encoding patch correlations, it does not exploit the redundancy among patches in a WSI, which Xiang & Zhang (2022) leverage to propose iterative low-rank attention (ILRA). Each ILRA block consists of two layers: the first projects the sequence of patch representations to a low-rank space by cross-attending it with a latent matrix, and the second reconstructs the input. Max-pooling over the output of $k$ such blocks yields a low-rank slide-level representation.

  4. AttnMISL (Yao et al., 2020): In contrast with ABMIL, TransMIL, and ILRA, which learn slide-level representations directly from patch representations, AttnMISL first clusters patches into morphological prototypes using K-means clustering. Next, each prototype is encoded by a prototype-specific fully convolutional Siamese network (Yao et al., 2019). The slide-level representation is then obtained by local attention pooling over the prototypes.

  5. Information Bottleneck MIL (Li et al., 2023): The information bottleneck (IB) principle is used to compress a WSI by removing irrelevant instances. IB seeks a compressed representation that minimizes the mutual information with the input patches while retaining the information relevant to the prediction target. By keeping only the patch instances selected in this way, Li et al. (2023) argue that the most informative patches are retained, which can then be aggregated into a compact WSI representation.

  6. Unimodal MMP: Similar to Panther (Song et al., 2024a), a Gaussian mixture model (GMM) is used to map the histology patch embeddings onto a predefined set of morphological prototypes. However, whereas Panther concatenates the post-aggregation embeddings to form the slide representation, unimodal MMP employs $f_{\text{h.}}^{\text{pre}}$, $f_{c,\text{g.}}^{\text{pre}}$, and $f_{c}^{\text{post}}$, which are learned jointly with the downstream task.
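To make the "local attention" of ABMIL concrete, here is a minimal NumPy sketch of the non-gated attention pooling of Ilse et al. (2018), $a_k\propto\exp(\mathbf{w}^{\text{T}}\tanh(\mathbf{V}\mathbf{h}_k))$. The parameters are random placeholders rather than trained weights, and the sizes are toy values. Each score depends on one patch only, which is exactly the independence assumption noted above.

```python
import numpy as np

def abmil_pool(H, V, w):
    """Attention-based MIL pooling (non-gated variant).

    H: (N, D) patch embeddings; V: (L, D) and w: (L,) attention parameters.
    """
    scores = np.tanh(H @ V.T) @ w              # (N,) one score per patch, from that patch alone
    scores = scores - scores.max()             # numerical stability
    a = np.exp(scores) / np.exp(scores).sum()  # softmax over the N patches
    return a @ H, a                            # slide embedding (D,), attention weights (N,)

rng = np.random.default_rng(0)
N, D, L = 1000, 32, 16                         # toy sizes; WSIs in Table 5 have ~10^4 patches
H = rng.normal(size=(N, D))
V, w = rng.normal(size=(L, D)) * 0.1, rng.normal(size=L)
slide, a = abmil_pool(H, V, w)
print(slide.shape, np.isclose(a.sum(), 1.0))   # (32,) True
```

Because the pooling is a weighted sum, the slide embedding is invariant to the order of the patches.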

E.2 Multimodal baselines

We compare our proposed method MMP with several early-fusion multimodal survival baselines.

  1. MCAT (Chen et al., 2021): The Multimodal Co-Attention Transformer (MCAT) is an early-fusion technique that learns a dense co-attention mapping between histology and omic tokens. This mapping is used to compute omic-guided histology features, which are concatenated with the omic features to predict patient survival. Since it groups genes into 6 functional families, MCAT effectively uses omic prototypes.

  2. MOTCat (Xu & Chen, 2023): The Multimodal Optimal Transport-based Co-attention Transformer (MOTCat) uses optimal transport (OT) to learn a transport plan between histology tokens and gene tokens, with genes grouped into 6 functional families as in MCAT. The estimated OT plan is then used to select the most informative histology tokens.

  3. SurvPath (Jaume et al., 2024): Unlike MCAT and MOTCat, which are limited to six gene families, SurvPath introduces a transcriptomics tokenizer that encodes genes into biological pathways representing known cellular functions. The pathway tokens are then fused with histology patches via a memory-efficient Transformer, which learns pathway-to-pathway and pathway-to-histology interactions but not histology-to-histology interactions.

  4. CMTA (Zhou & Chen, 2023): The Cross-Modal Translation and Alignment (CMTA) framework uses two parallel Transformer encoder-decoder modules. The encoders extract intra-modal representations for each modality, and the decoders generate cross-modal representations. A cross-modal attention module between the two encoders facilitates learning cross-modal relations.
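MOTCat's transport plan between gene and histology tokens is an instance of entropy-regularized OT, the same family as the $\varepsilon$-regularized plan in Eqs. 18-19. The following is a minimal Sinkhorn-Knopp sketch under stated assumptions: uniform marginals, a random toy cost, and $\varepsilon=0.5$. It is not MOTCat's implementation, and the per-gene top-k selection at the end is only illustrative.

```python
import numpy as np

def sinkhorn(cost, a, b, eps=0.5, n_iter=1000):
    """Entropic OT plan P with row sums a and column sums b (Sinkhorn-Knopp iterations)."""
    K = np.exp(-cost / eps)        # Gibbs kernel
    u = np.ones_like(a)
    for _ in range(n_iter):
        v = b / (K.T @ u)          # rescale to match column marginals
        u = a / (K @ v)            # rescale to match row marginals
    return u[:, None] * K * v[None, :]

rng = np.random.default_rng(0)
n_genes, n_hist = 6, 40                    # toy: 6 gene-family tokens, 40 histology tokens
cost = rng.random((n_genes, n_hist))       # placeholder dissimilarity between token pairs
a = np.full(n_genes, 1 / n_genes)          # uniform marginals (an assumption)
b = np.full(n_hist, 1 / n_hist)
P = sinkhorn(cost, a, b)
# Illustrative selection: the 3 histology tokens receiving the most mass from each gene token
top_per_gene = np.argsort(-P, axis=1)[:, :3]
```

Unlike the row-only constraint of Eq. 18, the full Sinkhorn plan enforces both marginals, which requires iterating rather than a single closed-form softmax.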

Appendix F Clinical baselines

We assess how MMP performs relative to basic clinical information available in the patient metadata. Using age, sex, and grade, which have been empirically shown to be important prognostic factors (Bonnier et al., 1995; Rakha et al., 2010; Tas et al., 2013; Yu et al., 2022), we fit univariate (one variable at a time) and multivariate (all variables) linear Cox regression models as baselines. MMP outperforms these clinical baselines on average (0.665 vs. 0.585 average C-Index for the best clinical baseline, All; Table 6), with CRC being the only cohort where a clinical baseline is better (0.655 vs. 0.630), suggesting its clinical potential for patient prognosis.
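A linear Cox baseline of this kind can be sketched by minimizing Eq. 20 with respect to $\theta$ on the clinical covariates. The sketch below is hypothetical: the appendix does not state which library, optimizer, or regularization was used, the data are toy values rather than TCGA, and no tie correction is applied. The gradient used is $\partial l/\partial\theta=-\sum_{i\in U}\big(\mathbf{x}_i-\sum_{j\in R_i}w_{ij}\mathbf{x}_j\big)$, with $w_{ij}$ the softmax of the risks within $R_i$.

```python
import numpy as np

def fit_linear_cox(X, time, event, lr=0.5, n_iter=2000):
    """Fit theta of a linear Cox model by gradient descent on the negative partial log-likelihood."""
    X = np.asarray(X, dtype=float)
    time = np.asarray(time, dtype=float)
    event = np.asarray(event, dtype=bool)
    at_risk = time[None, :] >= time[:, None]    # at_risk[i, j]: patient j is in R_i
    theta = np.zeros(X.shape[1])
    for _ in range(n_iter):
        r = X @ theta
        w = np.where(at_risk, np.exp(r - r.max())[None, :], 0.0)
        w /= w.sum(axis=1, keepdims=True)       # softmax of risks within each risk set
        grad = -(X - w @ X)[event].sum(axis=0)  # gradient of Eq. 20 w.r.t. theta
        theta -= lr * grad / max(event.sum(), 1)
    return theta

# Toy univariate example: a larger covariate (e.g., standardized age) tends to mean earlier death
theta = fit_linear_cox([[3.0], [2.0], [1.0], [0.0]], [1, 3, 2, 4], [1, 1, 1, 1])
risk = np.array([[3.0], [2.0], [1.0], [0.0]]) @ theta   # predicted risks for the C-Index
```

A multivariate baseline would stack the encoded covariates as columns of `X`; how categorical variables such as sex and grade were encoded is not stated in the text.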

Table 6: Survival prediction with clinical variables. The clinical variables for the TCGA cohorts were downloaded from cBioPortal. "All" denotes the combination of age, sex, and grade.

| Predictor | BRCA | BLCA | LUAD | STAD | CRC | KIRC | Avg. (↑) |
|---|---|---|---|---|---|---|---|
| Age | 0.496 ± 0.086 | 0.578 ± 0.056 | 0.533 ± 0.063 | 0.449 ± 0.055 | 0.357 ± 0.161 | 0.554 ± 0.147 | 0.495 |
| Sex | 0.490 ± 0.011 | 0.489 ± 0.028 | 0.480 ± 0.049 | 0.529 ± 0.069 | 0.542 ± 0.070 | 0.437 ± 0.057 | 0.495 |
| Grade | 0.597 ± 0.078 | 0.515 ± 0.018 | N/A | 0.552 ± 0.055 | N/A | 0.594 ± 0.083 | N/A |
| All | 0.563 ± 0.055 | 0.570 ± 0.033 | 0.528 ± 0.028 | 0.592 ± 0.044 | 0.655 ± 0.119 | 0.602 ± 0.066 | 0.585 |
| MMP$_{\text{Trans.}}$ | 0.738 ± 0.069 | 0.635 ± 0.051 | 0.642 ± 0.037 | 0.598 ± 0.051 | 0.630 ± 0.125 | 0.747 ± 0.106 | 0.665 |

Appendix G Histology ablations

We perform additional experiments on four cancer types, varying (1) the number of histology prototypes ($C_{\text{h.}}=8, 16, 32$) and (2) the pretrained histology encoder (ResNet50, CTransPath, and UNI). The results are shown in Tables 7 and 8.

We observe that unimodal performance with UNI features is relatively consistent across $C_{\text{h.}}$ (average C-Index of 0.639, 0.627, and 0.619 for $C_{\text{h.}}=8$, 16, and 32, respectively), with $C_{\text{h.}}=32$ being the weakest. The choice of $C_{\text{h.}}=16$ was motivated by two factors: (1) it gave the best overall performance in the multimodal evaluation, and (2) $C_{\text{h.}}=8$ sometimes fails to distinguish between similar but subtly different morphological exemplars (grouping them into a single cluster), whereas $C_{\text{h.}}=32$ makes morphological interpretation harder due to the excessive number of exemplars. $C_{\text{h.}}=16$ thus offered the best trade-off.

Table 7: Ablation on the number of histology prototypes. Unimodal MMP trained with varying numbers of histology prototypes $C_{\text{h.}}$ for selected cancer types.

| Setting | BRCA | BLCA | LUAD | CRC | Avg. (↑) |
|---|---|---|---|---|---|
| $C_{\text{h.}}=8$ | 0.720 ± 0.06 | 0.601 ± 0.04 | 0.592 ± 0.04 | 0.641 ± 0.11 | 0.639 |
| $C_{\text{h.}}=16$ | 0.669 ± 0.12 | 0.593 ± 0.06 | 0.600 ± 0.04 | 0.646 ± 0.11 | 0.627 |
| $C_{\text{h.}}=32$ | 0.680 ± 0.09 | 0.590 ± 0.05 | 0.587 ± 0.04 | 0.617 ± 0.13 | 0.619 |
Table 8: Ablation on the histology encoder. Unimodal MMP trained with different histology encoders for selected cancer types.

| Encoder | BRCA | BLCA | LUAD | CRC | Avg. (↑) |
|---|---|---|---|---|---|
| ResNet50 | 0.574 ± 0.11 | 0.511 ± 0.05 | 0.600 ± 0.06 | 0.534 ± 0.18 | 0.555 |
| CTransPath | 0.653 ± 0.10 | 0.566 ± 0.05 | 0.578 ± 0.02 | 0.574 ± 0.14 | 0.593 |
| UNI | 0.669 ± 0.12 | 0.593 ± 0.06 | 0.600 ± 0.04 | 0.646 ± 0.11 | 0.627 |

Appendix H Survival loss ablation experiment

We assess how the training batch size affects performance under the Cox proportional hazards loss (Cox, 1972) and the NLL survival loss (Zadeh & Schmid, 2020), using the full MMP model (Table 9). For both losses, the C-Index increases with batch size until it peaks and then declines (peak for the Cox loss: 0.665 with batch size 64; for the NLL loss: 0.644 with batch size 16). The increase can be attributed to more stable training, since each batch contains more patients against which the predicted risks can be compared (Kvamme et al., 2019). The decrease is likely due to the smaller number of parameter updates within the same number of epochs. These results suggest the benefits of batch-based training for survival prediction, which is not available to non-prototype-based approaches: because of the large number of histology tokens $N_{\text{h.}}$, they rely on the NLL survival loss with single-patient batches. We also observe that the Cox loss gives better overall performance, which illustrates another benefit of forming patient batches in MMP with fewer tokens. We attribute the lower performance of the NLL survival loss to the discretization of time into coarse, non-overlapping bins, which may discard valuable survival information.
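The "fewer parameter updates" explanation can be made concrete with simple arithmetic. The sketch assumes roughly four fifths of the 868 BRCA patients (Table 5) are used for training in each cross-validation fold; the exact split, including any validation subset, is an assumption.

```python
import math

def updates_per_epoch(n_train, batch_size):
    """Optimizer steps in one epoch (a final partial batch counts as one step)."""
    return math.ceil(n_train / batch_size)

n_train = 868 * 4 // 5   # BRCA patients, assuming ~4/5 are used for training per fold
for B in (16, 32, 64, 128):
    print(B, updates_per_epoch(n_train, B))
# 16 44
# 32 22
# 64 11
# 128 6
```

Under this assumption, moving from $B=16$ to $B=128$ reduces the number of updates per epoch roughly sevenfold, consistent with the decline observed at large batch sizes for a fixed number of epochs.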

Table 9: Batch size ablation. Average C-Index across 5 cross-validation folds for varying patient batch sizes $B$ with the Cox and NLL losses. A single-patient batch cannot be used with the Cox loss, since the patient's risk set then contains only itself and the loss is constant.
| Loss | Batch size | BRCA | BLCA | LUAD | STAD | CRC | KIRC | Avg. (↑) |
|---|---|---|---|---|---|---|---|---|
| Cox | $B=1$ | N/A | N/A | N/A | N/A | N/A | N/A | N/A |
| Cox | $B=16$ | 0.711 | 0.642 | 0.648 | 0.558 | 0.635 | 0.730 | 0.654 |
| Cox | $B=32$ | 0.729 | 0.636 | 0.648 | 0.584 | 0.627 | 0.735 | 0.660 |
| Cox | $B=64$ | 0.738 | 0.635 | 0.645 | 0.598 | 0.630 | 0.744 | 0.665 |
| Cox | $B=128$ | 0.729 | 0.622 | 0.644 | 0.586 | 0.617 | 0.731 | 0.655 |
| NLL | $B=1$ | 0.664 | 0.602 | 0.616 | 0.508 | 0.627 | 0.712 | 0.621 |
| NLL | $B=16$ | 0.662 | 0.635 | 0.656 | 0.561 | 0.660 | 0.691 | 0.644 |
| NLL | $B=32$ | 0.618 | 0.622 | 0.646 | 0.570 | 0.554 | 0.690 | 0.617 |
| NLL | $B=64$ | 0.590 | 0.616 | 0.635 | 0.556 | 0.574 | 0.678 | 0.608 |
| NLL | $B=128$ | 0.587 | 0.610 | 0.623 | 0.523 | 0.523 | 0.640 | 0.584 |
Figure 3: Cross-modal interaction visualization. (A) BRCA WSIs with their prototype assignment maps (categorical assignment of each histology patch to its nearest prototype), and prototype heatmaps of the top-3 prominent tissue patterns in each WSI. (B) Morphological annotations, provided by a board-certified pathologist, of the histology patches nearest to each prototype. (C) For each prototype visualized in (A), its most highly attended pathways (h. $\rightarrow$ g.), i.e., which pathways correspond to the queried prototype (pathway importance).

Appendix I Additional interpretability results

Unimodal: Compared to Fig. 2, Fig. 3A visualizes multiple prototype heatmaps, illustrating how different prototypes reflect distinct morphological tissue patterns in the tumor microenvironment. Using the nearest histology patches of each prototype, a board-certified pathologist qualitatively assessed each prototype and captioned it with a general morphological description. We found that $C_{\text{h.}}=16$ may still yield redundant prototypes, as multiple prototypes (C1, C8, C9, C10, C13) describe the presence of invasive ductal carcinoma (IDC). In general, however, each prototype was still found to be semantically meaningful, delineating general tumor tissue, normal connective tissue and stroma, adipose tissue, and tissue with immune cell presence. This is reflected in the high performance of unimodal MMP over the other histology baselines in Table 1 (with $C_{\text{h.}}=8$ and $C_{\text{h.}}=32$, MMP performs worse; see Table 3).
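The prototype assignment map in Fig. 3A assigns each histology patch to its nearest prototype. The sketch below illustrates that step and one way to rank the most prominent prototypes in a WSI by the fraction of patches assigned to them. It assumes squared Euclidean distance in the patch-embedding space, and the function names are hypothetical rather than part of the MMP codebase.

```python
from collections import Counter
from typing import List, Sequence


def assign_to_prototypes(patches: Sequence[Sequence[float]],
                         prototypes: Sequence[Sequence[float]]) -> List[int]:
    """Index of the nearest prototype (squared Euclidean distance) for each patch."""
    def sq_dist(a: Sequence[float], b: Sequence[float]) -> float:
        return sum((x - y) ** 2 for x, y in zip(a, b))

    return [min(range(len(prototypes)), key=lambda c: sq_dist(p, prototypes[c]))
            for p in patches]


def prototype_proportions(assignments: Sequence[int], n_prototypes: int) -> List[float]:
    """Fraction of the WSI's patches assigned to each prototype."""
    counts = Counter(assignments)
    return [counts.get(c, 0) / len(assignments) for c in range(n_prototypes)]


def top_k_prototypes(proportions: Sequence[float], k: int = 3) -> List[int]:
    """Prototypes covering the most patches (ties keep prototype order)."""
    order = sorted(range(len(proportions)), key=lambda c: proportions[c], reverse=True)
    return order[:k]
```

If two prototypes are near-duplicates, as with the IDC-related prototypes above, patches of one tissue type are split between them, which is one way redundancy shows up in an assignment map.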

Multimodal: In Fig. 3C, we further visualize cross-modal histology-to-pathway interactions (h. → g.) in BRCA, a capability unique to MMP compared with prior works, which have only visualized pathway-to-histology interactions (g. → h.) (Chen et al., 2021; Xu et al., 2023; Jaume et al., 2024). Across all (h. → g.) visualizations for the prototypes shown in Fig. 3C, fatty acid metabolism and cholesterol homeostasis consistently have high cross-attention scores, which agrees with the biomedical literature on how cancer cells hijack these pathways for exogenous energy uptake from the tissue microenvironment, enabling tumorigenesis and cancer progression (Nelson et al., 2014; Koundouros & Poulogiannis, 2020). Other conserved, highly attended pathways include tumor necrosis factor (TNF)-α signaling and epithelial-mesenchymal transition (EMT), which are canonical markers related to tumor proliferation and invasion (Wu & Zhou, 2010; Dongre & Weinberg, 2019). SurvPath also found EMT to have high importance; however, there is a subtle difference: in SurvPath, EMT importance was derived from attribution-based interpretability with respect to the predicted survival risk, not from cross-attention that pinpoints a relationship with an exact morphological pattern. In MMP, we find that EMT not only attends to invasive tumor (C10) but is also the most highly attended pathway for adipose tissue (C5), which agrees with recent and accumulating evidence that adipose tissue is more than a passive bystander and contributes to inflammation and tumor progression (Wang et al., 2012; Olea-Flores et al., 2018; Giudetti et al., 2019; Ishay-Ronen et al., 2019; Olea-Flores et al., 2020; Loo et al., 2021).
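The (h. → g.) direction can be read off a cross-attention map in which histology prototype tokens act as queries and pathway tokens act as keys. The following is a minimal sketch of scaled dot-product attention weights and of ranking pathways per prototype. It assumes the learned query and key projections have already been applied (they are omitted here), and the function names, tokens, and pathway names are hypothetical placeholders rather than values from Fig. 3C.

```python
import math
from typing import List, Sequence


def softmax(xs: Sequence[float]) -> List[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attention(queries: Sequence[Sequence[float]],
                    keys: Sequence[Sequence[float]]) -> List[List[float]]:
    """Row i is the attention distribution of query i over all keys."""
    scale = math.sqrt(len(queries[0]))
    return [softmax([sum(q * k for q, k in zip(qv, kv)) / scale for kv in keys])
            for qv in queries]


def top_pathways(weights_row: Sequence[float], names: Sequence[str],
                 k: int = 3) -> List[str]:
    """Names of the k most highly attended pathways for one prototype."""
    order = sorted(range(len(names)), key=lambda j: weights_row[j], reverse=True)
    return [names[j] for j in order[:k]]
```

Calling `cross_attention(prototype_tokens, pathway_tokens)` and passing row `c` to `top_pathways` gives the pathway ranking for prototype `c`; swapping the roles of queries and keys gives the (g. → h.) view instead.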

Appendix J Limitations & recommendations for future directions

Multimodal interpretability: Due to potential redundancy among prototypes (which are meant to correspond to unique morphological patterns), (h. → g.) queries are not unique: many prototypes are associated with tumor cells of IDC morphology and thus query similar pathways. In Fig. 2, we also note potentially asymmetric histology-pathway correspondences, in which cholesterol homeostasis highly attends to C13 (top 3 pathways out of 50) but C13 does not attend as highly to cholesterol homeostasis (top 6 prototypes out of 16). Again, this may be due to prototype redundancy, with other IDC-related prototypes (C8) attending highly to cholesterol homeostasis instead. Although pathologist annotation found many clusters to correspond to similar tumor morphologies, there may be subtle differences in fine-grained features such as tumor grade, tumor invasiveness, and tumor colocalization with stroma, adipose tissue, and immune cells, which may have fine-grained interactions with pathways. Future directions include developing approaches that reduce prototype redundancy so that each prototype is unique, which may improve both survival modeling and cross-modal interpretability.
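One reason such asymmetries can arise is that the two directions rank over different candidate sets: a prototype ranks all pathways (50 here), whereas a pathway ranks all prototypes (16 here). The toy sketch below illustrates this with a single hypothetical prototype-by-pathway score matrix, so it only captures this ranking effect and does not model how the two attention directions are computed in MMP itself; the function names and scores are assumptions for illustration.

```python
from typing import Sequence


def rank_of(values: Sequence[float], idx: int) -> int:
    """1-based rank of values[idx] (1 = highest; ties share the better rank)."""
    return 1 + sum(1 for v in values if v > values[idx])


def pathway_rank_for_prototype(scores: Sequence[Sequence[float]],
                               proto: int, path: int) -> int:
    """Rank of `path` among all pathways for prototype `proto` (h. -> g. view)."""
    return rank_of(scores[proto], path)


def prototype_rank_for_pathway(scores: Sequence[Sequence[float]],
                               proto: int, path: int) -> int:
    """Rank of `proto` among all prototypes for pathway `path` (g. -> h. view)."""
    column = [row[path] for row in scores]
    return rank_of(column, proto)


# Hypothetical scores: rows are prototypes, columns are pathways.
scores = [
    [0.6, 0.9, 0.8],  # prototype 0 prefers pathways 1 and 2
    [0.2, 0.5, 0.4],
    [0.1, 0.3, 0.7],
]
```

Here pathway 0 ranks prototype 0 first among prototypes, yet prototype 0 ranks pathway 0 last among pathways, the same kind of asymmetry described above for cholesterol homeostasis and C13.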

Study designs involving TCGA: The TCGA is the largest publicly available pan-cancer atlas with paired histology-omics samples and has been an immeasurable resource for the CPath community in building computational tools for unimodal and multimodal cancer prognosis. Still, the TCGA has several limitations that warrant caution. First, in addition to issues such as site-specific H&E intensity bias (Howard et al., 2021) and demographic bias (Vaidya et al., 2024), encoders pretrained on the TCGA should be avoided when evaluating multimodal cancer prognosis tasks due to potential data contamination. Though UNI was not pretrained on TCGA (Chen et al., 2024), using UNI (or any pretrained ROI encoder) within non-parametric methods such as K-means clustering or GMMs may still lead to instances where all patches are assigned to a single prototype, as demonstrated in PANTHER (Song et al., 2024a). Second, careful consideration must be given to the different survival endpoints available for each TCGA cohort. For instance, the median time-to-event and time-to-censor for disease-specific survival (DSS) in TCGA-BRCA are 26 and 25 months, respectively, meaning that the follow-up time is too short to observe breast cancer-specific deaths. Other works that assessed the suitability of DSS as a survival endpoint in TCGA-BRCA were still able to show statistically significant differences between ER+ and ER- tumors, while also acknowledging its shortcomings (Liu et al., 2018).

Unimodal versus multimodal survival analysis: As emphasized in the Introduction and Related Work sections, multimodal survival analysis is a challenging clinical task that has seen significant interest in the biomedical, computer vision, and machine learning communities. Though multimodal integration generally outperforms unimodal baselines, better unimodal baselines may (or may not) close the performance gap for certain cancer types, which remains an area for further exploration. In PORPOISE (Chen et al., 2022) and MCAT (Chen et al., 2021), multimodal integration (using ResNet50 features transferred from ImageNet for histology and gene families from MSigDB for genomics) was found to improve performance in 9 out of 14 TCGA cancer types, with genomics generally outperforming histology among the unimodal baselines. SurvPath (Jaume et al., 2024), MOTCat (Xu et al., 2023), and PIBD (Zhang et al., 2024), which improved on the unimodal baselines in MCAT using CTransPath features and hallmark gene family features, found very similar trends of multimodal improvement. Interestingly, MCAT was shown to lag behind unimodal genomics in the analysis of SurvPath, which may be attributed not only to the stronger gene features used but also to the higher computational complexity of Transformer attention over the increased number of omics tokens (which makes computational efficiency necessary). In MMP, which further improves the unimodal histology baseline using UNI features, we observe that the unimodal ablation of MMP (based on GMM, 0.611 overall C-Index) catches up with the unimodal genomics baselines (0.612 to 0.614 overall C-Index) and with multimodal baselines such as MCAT (0.610 overall C-Index) (Table 1). We hypothesize that this is due to the simplicity of GMMs in representing WSIs as a fixed set of prototypes, which allows supervision with the Cox loss instead of the negative log-likelihood loss.
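The fixed-size WSI representation referred to above can be illustrated with a single expectation-maximization (EM) update of a Gaussian mixture whose components are initialized at the morphological prototypes: each WSI, whatever its number of patches, is summarized by one (mixture weight, mean) token per prototype. This is a simplified sketch assuming isotropic components with a shared, fixed variance `sigma2`, uniform initial weights, and one EM step; the function name is hypothetical, and the aggregation actually used in MMP may differ in these details.

```python
import math
from typing import List, Sequence, Tuple


def gmm_summarize(patches: Sequence[Sequence[float]],
                  prototypes: Sequence[Sequence[float]],
                  sigma2: float = 1.0) -> List[Tuple[float, List[float]]]:
    """One EM update of an isotropic GMM initialized at the prototypes.

    Returns one (mixture weight, mean) token per prototype, so the output size
    depends only on the number of prototypes, not on the number of patches.
    """
    n_protos, dim = len(prototypes), len(prototypes[0])
    log_pi = math.log(1.0 / n_protos)  # uniform initial mixture weights

    # E-step: posterior responsibility of each prototype for each patch.
    resp = []
    for x in patches:
        logits = [log_pi - sum((a - b) ** 2 for a, b in zip(x, mu)) / (2.0 * sigma2)
                  for mu in prototypes]
        m = max(logits)
        w = [math.exp(z - m) for z in logits]
        total = sum(w)
        resp.append([v / total for v in w])

    # M-step: re-estimate each component's weight and mean from the soft assignments.
    tokens = []
    for c in range(n_protos):
        n_c = sum(r[c] for r in resp)
        if n_c > 1e-12:
            mean = [sum(r[c] * x[k] for r, x in zip(resp, patches)) / n_c
                    for k in range(dim)]
        else:
            mean = list(prototypes[c])  # no support: keep the prototype itself
        tokens.append((n_c / len(patches), mean))
    return tokens
```

Because every slide yields the same number of tokens, patient representations can be stacked into batches of $B$ patients, which is what makes a batch-level Cox loss practical for prototype-based models.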
As better unimodal baselines are developed, we envision that new multimodal fusion techniques will also be needed, emphasizing simplicity and interpretability in developing easy-to-train survival methods for cancer prognostication in high-dimensional, low-sample-size regimes.