License: arXiv.org perpetual non-exclusive license
arXiv:2404.06253v1 [cs.CV] 09 Apr 2024
\jmlryear{2024}
\jmlrworkshop{Full Paper – MIDL 2024 submission}
\midlauthor{\Name{Yitong Li}\midljointauthortext{Contributed equally}\nametag{$^{1,2}$} \Email{[email protected]}\\
\Name{Tom Nuno Wolf}\midlotherjointauthor\nametag{$^{1,2}$} \Email{[email protected]}\\
\Name{Sebastian Pölsterl}\nametag{$^{1}$} \Email{[email protected]}\\
\Name{Igor Yakushev}\nametag{$^{3}$} \Email{[email protected]}\\
\Name{Dennis M. Hedderich}\nametag{$^{4}$} \Email{[email protected]}\\
\Name{Christian Wachinger}\nametag{$^{1,2}$} \Email{[email protected]}\\
\addr{$^{1}$ Laboratory for Artificial Intelligence in Medical Imaging, Department of Radiology, Technical University of Munich (TUM), Germany}\\
\addr{$^{2}$ Munich Center for Machine Learning (MCML), Germany}\\
\addr{$^{3}$ Department of Nuclear Medicine, Klinikum rechts der Isar, TUM, Germany}\\
\addr{$^{4}$ Department of Neuroradiology, Klinikum rechts der Isar, TUM, Germany}}

From Barlow Twins to Triplet Training:
Differentiating Dementia with Limited Data

Abstract

Differential diagnosis of dementia is challenging due to overlapping symptoms, with structural magnetic resonance imaging (MRI) being the primary method for diagnosis. Despite the clinical value of computer-aided differential diagnosis, research has been limited, mainly due to the absence of public datasets that contain diverse types of dementia. This leaves researchers with small in-house datasets that are insufficient for training deep neural networks (DNNs). Self-supervised learning shows promise for utilizing unlabeled MRI scans in training, but the small batch sizes feasible for volumetric brain scans make its application challenging. To address these issues, we propose Triplet Training for differential diagnosis with limited target data. It consists of three key stages: (i) self-supervised pre-training on unlabeled data with Barlow Twins, (ii) self-distillation on task-related data, and (iii) fine-tuning on the target dataset. Our approach significantly outperforms traditional training strategies, achieving a balanced accuracy of 75.6%. We further provide insights into the training process by visualizing changes in the latent space after each step. Finally, we validate the robustness of Triplet Training with respect to its individual components in a comprehensive ablation study. Our code is available at \url{https://github.com/ai-med/TripletTraining}.

keywords:
differential diagnosis, dementia, transfer learning, limited data.

1 Introduction

The number of patients suffering from dementia is expected to increase to 152.8 million by 2050 [Nichols et al.(2022)Nichols, Steinmetz, Vollset, et al.], with Alzheimer’s Disease (AD) accounting for 60–80% of affected patients. Frontotemporal dementia (FTD) is the second most common type of dementia in people younger than 65 years [Young et al.(2018)Young, Lavakumar, Tampi, Balachandran, and Tampi]. Accurately diagnosing different dementia types is challenging because symptoms overlap, but it is crucial for patient management, therapy, and prognosis. In clinical routine, differential diagnosis incorporates structural magnetic resonance imaging (sMRI) to evaluate distinct atrophy patterns. Despite the clinical importance of differential diagnosis, there is limited research on computer-aided diagnosis for this task compared to classifying AD versus cognitively normal (CN) subjects, largely because related public MRI datasets are lacking. Accessing in-house data from hospitals is an alternative; however, even when such data are available, they are typically too small to train DNNs successfully.

Figure 1: Triplet Training for differential diagnosis of dementia: 1) self-supervised pre-training on task-unrelated data, 2) self-distillation on task-related data, and 3) fine-tuning on the training part of the target dataset, followed by evaluation on its test part.

At the same time, public datasets exist that focus on single types of dementia. For AD, the Alzheimer’s Disease Neuroimaging Initiative (ADNI, adni.loni.usc.edu) provides an extensive resource [Jack et al.(2008)Jack, Bernstein, Fox, Thompson, Alexander, Harvey, Borowski, Britson, Whitwell, Ward, Dale, Felmlee, Gunter, Hill, Killiany, Schuff, Fox-Bosetti, Lin, Studholme, and Weiner]. Similarly, the initiative on Neuroimaging in Frontotemporal Dementia (NIFD, 4rtni-ftldni.ini.usc.edu) collected data for FTD. As a result, previous research on the differential diagnosis of AD and FTD combined the two datasets [Ma et al.(2020)Ma, Lu, Popuri, Wang, and Beg, Hu et al.(2021)Hu, Qing, Liu, Zhang, Lv, Wang, Wang, He, and Gao, Nguyen et al.(2022)Nguyen, Clément, et al.]. An inherent limitation of such a combination is that dataset and diagnosis are confounded, potentially yielding shortcut learning that differentiates datasets instead of diagnoses [Geirhos et al.(2020)Geirhos, Jacobsen, Michaelis, Zemel, Brendel, Bethge, and Wichmann]. While evaluating on such a merged dataset can easily inflate estimates of classification accuracy, the merged data can still be a valuable resource for training.

Population imaging studies, e.g., UK Biobank [Miller et al.(2016)Miller, Alfaro-Almagro, Bangerter, Thomas, Yacoub, Xu, Bartsch, Jbabdi, Sotiropoulos, Andersson, Griffanti, Douaud, Okell, Weale, Dragonu, Garratt, Hudson, Collins, Jenkinson, and Smith], provide an even larger resource of MRI data for training, but they do not contain task-related labels. Recent advances in self-supervised learning (SSL) offer a means to benefit from such data without labels, but they have not yet been incorporated into differential diagnosis. A challenge for applying common SSL methods like SimCLR [Chen et al.(2020)Chen, Kornblith, Norouzi, and Hinton] or SwAV [Caron et al.(2020)Caron, Misra, Mairal, Goyal, Bojanowski, and Joulin] to 3D brain MRI data is their need for large batch sizes, and hence GPU memory, as they rely on hard negative samples to avoid collapse. Barlow Twins [Zbontar et al.(2021)Zbontar, Jing, Misra, LeCun, and Deny] is an alternative that eliminates the need for negative samples and naturally avoids collapse through redundancy reduction. As a result, it is more robust to small batch sizes, which makes it well suited for SSL in neuroimaging.

We introduce Triplet Training for differential diagnosis with limited target data. Triplet Training (Figure 1) combines three learning strategies to include all relevant MRI data in training. First, self-supervision trains the network on task-unrelated data without target labels (UK Biobank). Second, we apply self-distillation on a task-related dataset created by merging data from ADNI and NIFD. Third, we fine-tune the model on a training set of the small in-house clinical data. Our results demonstrate that Triplet Training outperforms competing methods and is robust with respect to its individual components.

To summarize, our key contributions are:

  • Triplet Training for learning DNNs with limited target data.

  • Adapting Barlow Twins as an efficient SSL algorithm on volumetric brain MRI data.

  • Self-distillation to distill knowledge from the SSL-trained teacher network in combination with task-related labels.

  • Reporting of test accuracy for differential diagnosis of AD and FTD on a well-characterized single-site clinical dataset.

1.1 Related Work

Differential Diagnosis of AD and FTD with DNNs.

One line of research for differential diagnosis performs brain segmentation [Ma et al.(2020)Ma, Lu, Popuri, Wang, and Beg, Nguyen et al.(2022)Nguyen, Clément, et al.] and uses volume and thickness measurements for the classification. Such an approach may restrict learning general dementia-specific features across the entire brain. Motivated by the success of using a 3D-ResNet50 encoder-decoder on MRI [Hu et al.(2021)Hu, Qing, Liu, Zhang, Lv, Wang, Wang, He, and Gao] to extract latent representations for classification, we selected a 3D-ResNet as the backbone for our work.

As no public dataset comprises both AD and FTD patients, these methods combined ADNI and NIFD. The fundamental problem of such an approach is that dataset coincides with diagnosis; hence, it cannot be determined whether the network inadvertently learns to differentiate datasets instead of pathology [Geirhos et al.(2020)Geirhos, Jacobsen, Michaelis, Zemel, Brendel, Bethge, and Wichmann]. Thus, we incorporate ADNI and NIFD in Triplet Training for pre-training and evaluate on the in-house single-site dataset to allow for a reliable performance assessment.

Self-Supervised Learning and Self-Distillation in Medical Image Analysis.

In summary, self-supervised pre-training and self-distillation on medical images improve the performance of the downstream task, with domain-related datasets providing further benefits. Such approaches have not yet been explored for differential diagnosis, nor combined into a three-stage scheme like Triplet Training. Moreover, research on Barlow Twins has been limited despite its attractive properties for volumetric medical images.

2 Methods

In this section, we present the details of Triplet Training, which tackles the limited data availability for the target task. In the first step, we use SSL with Barlow Twins to integrate task-unrelated data. In the second step, we include task-related data via self-distillation. Self-distillation makes full use of the SSL step by aligning, via the Kullback-Leibler (KL) divergence, the distribution of latent features extracted by the student network with those learned during SSL. This reuses the knowledge acquired during SSL and reduces the risk of overfitting to the task-related dataset. Finally, we fine-tune the model on the target dataset. Before going into technical details, we introduce notation and datasets.

2.1 Preliminaries and Datasets

We define a 3D image as $\mathcal{I}\in\mathbb{R}^{H\times W\times D}$, with $H$, $W$, $D$ as height, width, and depth, respectively. A dataset consists of $N$ 3D images $\mathcal{I}_i$, $i=1,\dots,N$, and class labels $y_i$ if available. Our model consists of a feature extractor $f:\mathbb{R}^{H\times W\times D}\rightarrow\mathbb{R}^{Z}$, with $Z$ the latent space dimension, and a projection head $g:\mathbb{R}^{Z}\rightarrow\mathbb{R}^{C}$, which maps the latent vectors to outputs of dimension $C$. We select a 3D-ResNet backbone for the feature extractor $f$ and a two-layer MLP for the projection head $g$ (implementation details in \sectionref{sec:architecture}).
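To make the shape contract of $g\circ f$ concrete, the following NumPy sketch replaces the 3D-ResNet with a hypothetical stand-in (average pooling into a $4\times4\times4$ grid followed by a linear map) and implements the projection head as a two-layer MLP (linear, ReLU, linear). The latent size $Z=512$, the hidden width 64, and the ReLU activation are illustrative assumptions, not the configuration used in this work; the stand-in only mirrors the input and output dimensions.

```python
import numpy as np

def stand_in_feature_extractor(image: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for f: R^{H x W x D} -> R^Z (NOT a 3D-ResNet).

    Average-pools the volume into a 4x4x4 grid and maps the 64 pooled values to Z dims.
    """
    h, w, d = (s - s % 4 for s in image.shape)  # trim so each axis splits into 4 blocks
    vol = image[:h, :w, :d]
    pooled = vol.reshape(4, h // 4, 4, w // 4, 4, d // 4).mean(axis=(1, 3, 5))
    return weight @ pooled.ravel()  # weight: (Z, 64) -> latent vector of shape (Z,)

def projection_head(z: np.ndarray, w1, b1, w2, b2) -> np.ndarray:
    """Two-layer MLP g: R^Z -> R^C (linear, ReLU, linear)."""
    hidden = np.maximum(0.0, w1 @ z + b1)
    return w2 @ hidden + b2

rng = np.random.default_rng(0)
Z, HIDDEN, C = 512, 64, 3          # illustrative sizes; C = 3 classes (CN, AD, FTD)
image = rng.random((55, 55, 55))   # volume size after pre-processing in this work
z = stand_in_feature_extractor(image, rng.normal(size=(Z, 64)))
logits = projection_head(z, rng.normal(size=(HIDDEN, Z)), np.zeros(HIDDEN),
                         rng.normal(size=(C, HIDDEN)), np.zeros(C))
print(z.shape, logits.shape)       # (512,) (3,)
```

Keeping $f$ and $g$ as separate functions mirrors the notation above: the later stages reuse or replace the head while the feature extractor carries over between steps.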

We utilize three datasets:

  1. The unlabeled dataset $\mathcal{U}$ comprises $N=39{,}560$ samples $X^{\mathcal{U}}_i=(\mathcal{I}_i^{\mathcal{U}})$ extracted from the UK Biobank [Miller et al.(2016)Miller, Alfaro-Almagro, Bangerter, Thomas, Yacoub, Xu, Bartsch, Jbabdi, Sotiropoulos, Andersson, Griffanti, Douaud, Okell, Weale, Dragonu, Garratt, Hudson, Collins, Jenkinson, and Smith].

  2. The labeled, task-related dataset $\mathcal{D}$ consists of $N=1{,}305$ samples $X_i^{\mathcal{D}}=(\mathcal{I}_i^{\mathcal{D}},y_i^{\mathcal{D}})$, $y_i^{\mathcal{D}}\in\{\mathrm{CN},\mathrm{AD},\mathrm{FTD}\}$, from ADNI and NIFD.

  3. The labeled target in-house dataset $\mathcal{T}$ consists of $N=329$ samples $X_i^{\mathcal{T}}=(\mathcal{I}_i^{\mathcal{T}},y_i^{\mathcal{T}})$, $y_i^{\mathcal{T}}\in\{\mathrm{CN},\mathrm{AD},\mathrm{FTD}\}$, from the hospital Klinikum rechts der Isar, Munich, Germany.

Table 1 reports demographic statistics for all three datasets.

Table 1: Statistics for the unlabeled ($\mathcal{U}$), task-related ($\mathcal{D}$), and target ($\mathcal{T}$) datasets. MMSE denotes the Mini Mental State Examination score.

\begin{tabular}{llrrcc}
Dataset & Diagnosis & \# Samples & \% Female & Age (years) & MMSE \\
$\mathcal{U}$ = UK Biobank & N/A & 39,560 & 52.6 & 63.6 $\pm$ 7.5 & N/A \\
$\mathcal{D}$ = ADNI+NIFD & CN & 766 & 56.9 & 71.9 $\pm$ 7.1 & 29.0 $\pm$ 1.2 \\
 & AD & 489 & 44.2 & 74.4 $\pm$ 7.7 & 22.0 $\pm$ 4.1 \\
 & FTD & 50 & 28.0 & 60.8 $\pm$ 6.3 & 24.1 $\pm$ 5.8 \\
$\mathcal{T}$ = In-House & CN & 143 & 46.9 & 64.2 $\pm$ 9.9 & N/A \\
 & AD & 110 & 50.0 & 67.3 $\pm$ 8.4 & N/A \\
 & FTD & 76 & 50.0 & 64.6 $\pm$ 9.4 & N/A \\
\end{tabular}

2.2 Triplet Training

1. Self-Supervised Learning. The self-supervision task proposed in Barlow Twins (BT) de-correlates features in latent space and has been shown to be relatively robust with respect to the batch size [Zbontar et al.(2021)Zbontar, Jing, Misra, LeCun, and Deny]. This benefits training with 3D medical images, whose large size limits batch sizes. Hence, BT is a promising approach for the initial step of Triplet Training.

To pre-train the feature extractor $f^{\theta}$ with trainable parameters $\theta$ on the unlabeled dataset $\mathcal{U}$, two different augmentations $A$ and $B$ of an input image $\mathcal{I}_i^{\mathcal{U}}$ are required. These augmented images $A(\mathcal{I}_i^{\mathcal{U}})$ and $B(\mathcal{I}_i^{\mathcal{U}})$ are fed into a neural network consisting of the feature extractor $f^{\theta}$ and a projection head $g^{\theta}$, yielding two output latent vectors $z^A_i=g^{\theta}(f^{\theta}(A(\mathcal{I}_i^{\mathcal{U}})))$ and $z^B_i=g^{\theta}(f^{\theta}(B(\mathcal{I}_i^{\mathcal{U}})))$, with $z^A_i, z^B_i\in\mathbb{R}^{C}$. The model is optimized by driving the cross-correlation $\mathcal{C}_{cc}$ between corresponding features of the two augmentations towards one and the cross-correlations $\mathcal{C}_{cj}$, $j\neq c$, between the remaining components towards zero:

$$\mathcal{L}_{\text{BT}}=\sum_{c}\left(1-\mathcal{C}_{cc}\right)^{2}+\lambda_{1}\sum_{c}\sum_{j\neq c}\mathcal{C}_{cj}^{2}, \quad\text{with}\quad \mathcal{C}_{cj}=\frac{\sum_{i} z^{A}_{i,c}\, z^{B}_{i,j}}{\sqrt{\sum_{i}\left(z^{A}_{i,c}\right)^{2}}\,\sqrt{\sum_{i}\left(z^{B}_{i,j}\right)^{2}}}$$

with $c, j\in\{1,\dots,C\}$ indexing the latent space dimensions, $i$ the index of a sample within a mini-batch drawn from $\mathcal{U}$, and $\lambda_1$ a constant hyper-parameter. The first term makes embeddings invariant to the distortions introduced by the augmentations, while the second term reduces redundant information across embedding dimensions. We denote the resulting weights after this self-supervised pre-training step as $\theta'$.
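The loss $\mathcal{L}_{\text{BT}}$ can be sketched in NumPy for a mini-batch of embeddings $z^A, z^B\in\mathbb{R}^{N\times C}$. As in the original Barlow Twins formulation, each feature is centered along the batch axis before the normalized cross-correlation is computed, and a small `eps` guards against division by zero. The default `lambda1=5e-3` is an assumed placeholder, not necessarily the value used in this work; gradients and the network itself are outside the scope of the sketch.

```python
import numpy as np

def cross_correlation(z_a: np.ndarray, z_b: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """C x C matrix with C[c, j] = normalized cross-correlation of feature c (view A)
    and feature j (view B), computed over the N samples of the batch."""
    # Center every feature along the batch axis (original Barlow Twins convention).
    z_a = z_a - z_a.mean(axis=0, keepdims=True)
    z_b = z_b - z_b.mean(axis=0, keepdims=True)
    numerator = z_a.T @ z_b                   # sum_i z^A_{i,c} * z^B_{i,j}
    norm_a = np.sqrt((z_a ** 2).sum(axis=0))  # sqrt(sum_i (z^A_{i,c})^2), shape (C,)
    norm_b = np.sqrt((z_b ** 2).sum(axis=0))  # sqrt(sum_i (z^B_{i,j})^2), shape (C,)
    return numerator / (np.outer(norm_a, norm_b) + eps)

def barlow_twins_loss(z_a: np.ndarray, z_b: np.ndarray, lambda1: float = 5e-3) -> float:
    c = cross_correlation(z_a, z_b)
    diag = np.diag(c)
    invariance = np.sum((1.0 - diag) ** 2)           # sum_c (1 - C_cc)^2
    redundancy = np.sum(c ** 2) - np.sum(diag ** 2)  # sum_c sum_{j != c} C_cj^2
    return float(invariance + lambda1 * redundancy)

rng = np.random.default_rng(0)
z_a = rng.normal(size=(8, 16))               # small batch: N = 8 embeddings of size C = 16
z_b = z_a + 0.1 * rng.normal(size=(8, 16))   # a second, slightly perturbed "view"
print(barlow_twins_loss(z_a, z_b))
```

Because the loss only needs the $C\times C$ matrix of per-feature correlations, it involves no negative pairs; views that agree feature by feature push the diagonal towards one.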

Figure 2: Overview of the three stages of Triplet Training.

2. Self-Distillation. This step uses the feature extractor $f^{\theta}$, with the pre-trained weights $\theta=\theta'$ from the previous step, as a teacher. We freeze the teacher network $f^{\theta}$ during training to reduce the risk of overfitting to the task-related dataset $\mathcal{D}$. We randomly initialize a student network $f^{\psi}$ with the same architecture as the teacher and trainable parameters $\psi$, together with an additional projection head $g^{\psi}$. Inspired by Tian et al.(2020)Tian, Wang, Krishnan, Tenenbaum, and Isola, the student is trained on the task-related dataset $\mathcal{D}$ by minimizing the KL divergence $\mathcal{L}_{\text{KL}}$ between the outputs of the feature extractors, $f^{\theta}(\mathcal{I}_i^{\mathcal{D}})$ and $f^{\psi}(\mathcal{I}_i^{\mathcal{D}})$, and the cross-entropy $\mathcal{L}_{\text{CE}}$ between the student's predictions $g^{\psi}(f^{\psi}(\mathcal{I}_i^{\mathcal{D}}))$ and the corresponding class labels $y_i^{\mathcal{D}}$:

$$\mathcal{L}_{\text{SD}}=\lambda_{2}\,\mathcal{L}_{\text{KL}}\left(\mathcal{Z}^{\psi}\sim f^{\psi}(\mathcal{I}_{i}^{\mathcal{D}}),\ \mathcal{Z}^{\theta}\sim f^{\theta}(\mathcal{I}_{i}^{\mathcal{D}})\right)+(1-\lambda_{2})\sum_{i}\mathcal{L}_{\text{CE}}\left(g^{\psi}(f^{\psi}(\mathcal{I}_{i}^{\mathcal{D}})),\ y_{i}^{\mathcal{D}}\right),$$

with $\mathcal{Z}^{\theta}$ and $\mathcal{Z}^{\psi}$ random variables sampled via forward passes of the samples from the dataset $\mathcal{D}$ through the teacher and the student, respectively, and $\lambda_2$ a constant hyper-parameter trading off the importance of the first and second terms of $\mathcal{L}_{\text{SD}}$. The resulting weights of the student network are denoted as $\psi'$.
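The exact form of $\mathcal{L}_{\text{KL}}$ between feature vectors is not spelled out here. A common choice in knowledge distillation, assumed in the NumPy sketch below, is to turn each feature vector into a categorical distribution with a temperature-scaled softmax and compute $\mathrm{KL}(\text{teacher}\,\|\,\text{student})$, averaged over the batch. The cross-entropy term is summed over samples as in $\mathcal{L}_{\text{SD}}$; the same `cross_entropy_sum` helper also gives $\mathcal{L}_{\text{FT}}$ in the fine-tuning step. `temperature` and the default `lambda2=0.5` are hypothetical placeholders, and gradient handling (including freezing the teacher) is left out.

```python
import numpy as np

def log_softmax(x: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the last axis."""
    shifted = x - x.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def kl_teacher_student(f_student: np.ndarray, f_teacher: np.ndarray,
                       temperature: float = 1.0) -> float:
    """Batch-mean KL(teacher || student) between softmax-normalized feature vectors."""
    log_p = log_softmax(f_teacher / temperature)  # teacher distribution (frozen)
    log_q = log_softmax(f_student / temperature)  # student distribution
    return float(np.mean(np.sum(np.exp(log_p) * (log_p - log_q), axis=-1)))

def cross_entropy_sum(logits: np.ndarray, labels: np.ndarray) -> float:
    """Cross-entropy summed over samples; labels are integer class indices."""
    log_probs = log_softmax(logits)
    return float(-log_probs[np.arange(len(labels)), labels].sum())

def self_distillation_loss(f_student, f_teacher, logits, labels,
                           lambda2: float = 0.5, temperature: float = 1.0) -> float:
    kl = kl_teacher_student(f_student, f_teacher, temperature)
    return lambda2 * kl + (1.0 - lambda2) * cross_entropy_sum(logits, labels)

rng = np.random.default_rng(0)
f_teacher = rng.normal(size=(4, 32))  # teacher features for a batch of 4 scans
f_student = rng.normal(size=(4, 32))  # student features for the same scans
logits = rng.normal(size=(4, 3))      # student predictions for CN / AD / FTD
labels = np.array([0, 1, 2, 0])
print(self_distillation_loss(f_student, f_teacher, logits, labels))
```

With `lambda2=1.0` only the distillation term remains, and with `lambda2=0.0` the step reduces to supervised training on $\mathcal{D}$.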

3. Fine-Tuning. In the final step, we initialize the student network $f^{\psi}$, $g^{\psi}$ with the pre-trained weights $\psi=\psi'$ from the previous step and fine-tune it on the in-house dataset $\mathcal{T}$ for the target task using the cross-entropy loss:

$$\mathcal{L}_{\text{FT}}=\sum_{i}\mathcal{L}_{\text{CE}}\left(g^{\psi}(f^{\psi}(\mathcal{I}^{\mathcal{T}}_{i})),\ y^{\mathcal{T}}_{i}\right).$$
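The data flow between the three stages, in particular that the student of stage 2 starts from a random initialization rather than from $\theta'$, can be sketched with plain Python callables. The stage functions (`ssl_pretrain`, `self_distill`, `fine_tune`) and the dictionary-of-parameters representation are hypothetical placeholders for the actual training loops; freezing the teacher is modeled with a read-only `MappingProxyType` view.

```python
from types import MappingProxyType
from typing import Callable, Dict, Mapping

Params = Dict[str, float]  # hypothetical: parameter name -> value

def triplet_training(
    init_params: Callable[[], Params],
    ssl_pretrain: Callable[[Params], Params],
    self_distill: Callable[[Mapping[str, float], Params], Params],
    fine_tune: Callable[[Params], Params],
) -> Params:
    # Stage 1: Barlow Twins pre-training on the unlabeled data U yields theta'.
    theta_prime = ssl_pretrain(init_params())
    # The teacher with weights theta' is frozen: a read-only view rejects updates.
    teacher = MappingProxyType(dict(theta_prime))
    # Stage 2: a randomly initialized student is trained on D with L_SD -> psi'.
    psi_prime = self_distill(teacher, init_params())
    # Stage 3: the student, initialized with psi', is fine-tuned on T with L_FT.
    return fine_tune(dict(psi_prime))
```

Only the student's weights leave stage 2; the teacher serves purely as a fixed target for the KL term.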

3 Experiments

Pre-processing and Data Augmentation: Each T1-weighted MRI scan is pre-processed using SPM\footnote{\url{https://www.fil.ion.ucl.ac.uk/spm/software/spm12}} and the VBM pipeline of CAT12 [Gaser et al.(2022)Gaser, Dahnke, Thompson, Kurth, and Luders]. The results are gray-matter density volumes (samples with a quality control score lower than B– are discarded), which are min-max rescaled, center-cropped, and resampled to a spatial dimension of $55\times 55\times 55$ (for training convenience without sacrificing model performance). \sectionref{sec:data_augmentation} reports details about the data augmentation strategy.
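The intensity and geometry steps listed above (min-max rescaling, center cropping, resampling) can be sketched in NumPy as follows. The crop size is not given in this section, so `crop_shape` is a hypothetical parameter, and nearest-neighbour resampling is a simplification chosen to keep the example dependency-free; the actual pipeline may use a different interpolation. The SPM/CAT12 segmentation and quality control are not reproduced here.

```python
import numpy as np

def min_max_rescale(volume: np.ndarray) -> np.ndarray:
    """Rescale intensities linearly to [0, 1]."""
    vmin, vmax = float(volume.min()), float(volume.max())
    if vmax == vmin:
        return np.zeros(volume.shape, dtype=np.float64)
    return (volume - vmin) / (vmax - vmin)

def center_crop(volume: np.ndarray, crop_shape) -> np.ndarray:
    """Keep the central crop_shape region (each crop size must not exceed the volume)."""
    starts = [(s - c) // 2 for s, c in zip(volume.shape, crop_shape)]
    return volume[tuple(slice(st, st + c) for st, c in zip(starts, crop_shape))]

def resample_nearest(volume: np.ndarray, out_shape) -> np.ndarray:
    """Nearest-neighbour resampling to out_shape (a simplification of interpolation)."""
    idx = [np.minimum((np.arange(o) + 0.5) * s / o, s - 1).astype(int)
           for s, o in zip(volume.shape, out_shape)]
    return volume[np.ix_(*idx)]

def preprocess(gm_density: np.ndarray, crop_shape, out_shape=(55, 55, 55)) -> np.ndarray:
    # Order as described in the text: rescale, then crop, then resample.
    return resample_nearest(center_crop(min_max_rescale(gm_density), crop_shape), out_shape)

rng = np.random.default_rng(0)
volume = rng.random((70, 80, 70)) * 5 + 2  # synthetic stand-in for a gray-matter density map
print(preprocess(volume, crop_shape=(60, 66, 60)).shape)  # (55, 55, 55)
```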

Evaluation: As the target dataset $\mathcal{T}$ is relatively small, we perform 5-fold cross-validation with ratios of 65%, 15%, and 20% for training, validation, and test sets, respectively, stratified by age, gender, and diagnosis to prevent biased results [Barnes et al.(2010)Barnes, Ridgway, Bartlett, Henley, Lehmann, Hobbs, Clarkson, MacManus, Ourselin, and Fox]. Additionally, we hold out a balanced 20% portion of the task-related dataset $\mathcal{D}$ as a test set for further evaluation on the same task.
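One way to obtain 65%/15%/20% training/validation/test splits stratified by diagnosis, gender, and age is to distribute the samples of each stratum round-robin over 20 chunks of about 5% each and, for fold $k$, use 4 chunks for testing, 3 for validation, and the remaining 13 for training. The sketch below implements this scheme; the 10-year age bins, the chunk rotation, and the helper names are assumptions for illustration, not the exact splitting procedure used in this work.

```python
import random
from collections import defaultdict
from typing import Hashable, List, Sequence, Tuple

def stratum_key(diagnosis: str, gender: str, age: float, age_bin: float = 10.0) -> tuple:
    """Hypothetical stratum: diagnosis x gender x age bin (10-year bins assumed)."""
    return (diagnosis, gender, int(age // age_bin))

def stratified_chunks(keys: Sequence[Hashable], n_chunks: int = 20,
                      seed: int = 0) -> List[List[int]]:
    """Spread the sample indices of every stratum round-robin over n_chunks chunks."""
    rng = random.Random(seed)
    strata = defaultdict(list)
    for idx, key in enumerate(keys):
        strata[key].append(idx)
    chunks: List[List[int]] = [[] for _ in range(n_chunks)]
    pos = 0  # running position, so small strata do not all start in chunk 0
    for key in sorted(strata):
        members = strata[key]
        rng.shuffle(members)
        for idx in members:
            chunks[pos % n_chunks].append(idx)
            pos += 1
    return chunks

def fold_split(chunks: List[List[int]], fold: int, n_test: int = 4,
               n_val: int = 3) -> Tuple[List[int], List[int], List[int]]:
    """Return (train, val, test) indices; with 20 chunks, 13/3/4 chunks give 65%/15%/20%."""
    n = len(chunks)
    test_ids = {(fold * n_test + j) % n for j in range(n_test)}
    val_ids = {(fold * n_test + n_test + j) % n for j in range(n_val)}
    train_ids = set(range(n)) - test_ids - val_ids

    def gather(ids):
        return sorted(i for c in ids for i in chunks[c])

    return gather(train_ids), gather(val_ids), gather(test_ids)
```

Across the five folds the test chunks are disjoint, so every sample is tested exactly once, while the validation chunks rotate along with them.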

Miscellaneous: Hyper-parameters for the individual training steps and search spaces of baseline methods are reported in \sectionref{sec:hyperparams}. We implement models with PyTorch [Paszke et al.(2019)Paszke, Gross, Massa, Lerer, Bradbury, Chanan, Killeen, et al.] and train on one NVIDIA GeForce RTX 3090 with 24 GB of memory. We train the model for 29,300 self-supervised iterations (24 hours), followed by 600 self-distillation iterations (2.5 hours) and 150 fine-tuning iterations with early stopping (40 minutes).

4 Results

As a non-deep-learning baseline for the differential diagnosis on $\mathcal{T}$, we extract FreeSurfer [Fischl(2012)] volume and thickness features from the MRI scans and train an XGBoost classifier, which achieves a balanced accuracy (BAcc) of 66.46 $\pm$ 3.45%.

Table 2: Mean, standard deviation, and pairwise p-values of the balanced accuracy (BAcc), true positive rate per class (TPR), and macro-F1 score (F1) across splits for 3-class differential diagnosis. Columns $\mathcal{U}$, $\mathcal{D}$, $\mathcal{T}$ mark the datasets used for training; the subscript of BAcc and F1 denotes the evaluation set ($\mathcal{T}$: target test splits; $\mathcal{D}$: held-out test set of $\mathcal{D}$). p-values compare each method with Triplet Training.

\begin{tabular}{lcccccccccc}
Training Strategy & $\mathcal{U}$ & $\mathcal{D}$ & $\mathcal{T}$ & BAcc$_{\mathcal{T}}$ & p-value & TPR$_{\mathrm{CN}}$ & TPR$_{\mathrm{AD}}$ & TPR$_{\mathrm{FTD}}$ & F1$_{\mathcal{T}}$ & BAcc$_{\mathcal{D}}$ \\
Supervised & & & \checkmark & 67.15 $\pm$ 5.36 & 0.011 & 69.9 & 65.5 & 65.8 & 66.94 $\pm$ 5.52 & -- \\
Supervised & & \checkmark & \checkmark & 68.44 $\pm$ 4.63 & 0.016 & 79.7 & 66.4 & 59.2 & 69.78 $\pm$ 4.26 & 78.2 \\
Self-Supervised (SimCLR) [Chen et al.(2020)Chen, Kornblith, Norouzi, and Hinton] & \checkmark & & \checkmark & 63.47 $\pm$ 4.38 & 0.001 & 86.0 & 50.0 & 54.0 & 64.44 $\pm$ 4.13 & -- \\
Self-Supervised (VICReg) [Bardes et al.(2022)Bardes, Ponce, and LeCun] & \checkmark & & \checkmark & 68.94 $\pm$ 3.42 & 0.012 & 72.7 & 70.0 & 64.5 & 69.22 $\pm$ 2.78 & -- \\
Self-Supervised (DiRA) [Haghighi et al.(2022)Haghighi, Taher, Gotway, and Liang] & \checkmark & & \checkmark & 66.78 $\pm$ 0.89 & 0.001 & 80.4 & 60.9 & 59.2 & 67.21 $\pm$ 2.03 & -- \\
Self-Supervised (BT) [Zbontar et al.(2021)Zbontar, Jing, Misra, LeCun, and Deny] & \checkmark & & \checkmark & 71.36 $\pm$ 4.18 & 0.072 & 79.7 & 68.2 & 65.8 & 72.24 $\pm$ 3.78 & -- \\
Triplet Training (Ours) & \checkmark & \checkmark & \checkmark & 75.57 $\pm$ 3.62 & -- & 81.8 & 71.8 & 73.7 & 75.32 $\pm$ 4.51 & 85.6 \\
\end{tabular}
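Balanced accuracy is the mean of the per-class true positive rates, and macro-F1 is the unweighted mean of the per-class F1 scores, so both give the small FTD group the same weight as CN and AD. A minimal, dependency-free sketch of both metrics for a single test split follows; the helper names are hypothetical, and the sketch assumes every class occurs in `y_true`. The table then reports the mean and standard deviation of such per-split values across folds.

```python
from collections import Counter
from typing import Dict, Sequence

def per_class_tpr(y_true: Sequence[str], y_pred: Sequence[str],
                  classes: Sequence[str]) -> Dict[str, float]:
    """True positive rate (recall) of each class; assumes every class occurs in y_true."""
    support = Counter(y_true)
    hits = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    return {c: hits[c] / support[c] for c in classes}

def balanced_accuracy(y_true, y_pred, classes) -> float:
    tpr = per_class_tpr(y_true, y_pred, classes)
    return sum(tpr.values()) / len(classes)

def macro_f1(y_true, y_pred, classes) -> float:
    scores = []
    for c in classes:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        denom = 2 * tp + fp + fn  # F1 = 2TP / (2TP + FP + FN)
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / len(classes)

classes = ("CN", "AD", "FTD")
y_true = ["CN", "CN", "AD", "AD", "FTD", "FTD"]
y_pred = ["CN", "CN", "AD", "CN", "FTD", "FTD"]
print(round(balanced_accuracy(y_true, y_pred, classes), 4),
      round(macro_f1(y_true, y_pred, classes), 4))  # 0.8333 0.8222
```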

As seen in Table 2, training a DNN on the target dataset $\mathcal{T}$ alone results in a BAcc of only 67.15%, likely due to overfitting on the small task-specific data. Pre-training the model on the task-related dataset $\mathcal{D}$ improves performance only marginally, by 1.29 percentage points. Pre-training on the unlabeled dataset $\mathcal{U}$ with established SSL methods (SimCLR [Chen et al.(2020)Chen, Kornblith, Norouzi, and Hinton], VICReg [Bardes et al.(2022)Bardes, Ponce, and LeCun], DiRA [Haghighi et al.(2022)Haghighi, Taher, Gotway, and Liang], and Barlow Twins [Zbontar et al.(2021)Zbontar, Jing, Misra, LeCun, and Deny]) and then fine-tuning on $\mathcal{T}$ yields mixed results: with Barlow Twins, it outperforms supervised pre-training on $\mathcal{D}$ by 2.92 percentage points, whereas SimCLR and DiRA fall below it. Triplet Training, which adds a self-distillation step on $\mathcal{D}$ after self-supervised pre-training, outperforms all competing approaches on the target dataset, achieving a BAcc of 75.57 $\pm$ 3.62% with the highest true positive rates for both types of dementia; the improvement is significant at $p<0.05$ for all competitors except Barlow Twins pre-training ($p=0.072$).

Additionally, we evaluate the model obtained after self-distillation on $\mathcal{D}$ on the hold-out test set of $\mathcal{D}$ (denoted as $\textrm{BAcc}_{\mathcal{D}}$ in Table 2). It clearly outperforms supervised training on $\mathcal{D}$ alone (85.6% vs. 78.2%, i.e., +7.4%). This suggests that Triplet Training mitigates overfitting when training with limited data and thus extracts features that generalize well.
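Balanced accuracy, the main metric in Tables 2 and 3, is (assuming the standard definition) the unweighted mean of the per-class true positive rates $\textrm{TPR}_{\textrm{CN}}$, $\textrm{TPR}_{\textrm{AD}}$, and $\textrm{TPR}_{\textrm{FTD}}$, so a classifier cannot score well simply by favoring the majority class. The minimal sketch below computes it from label lists; the function name and the toy labels are hypothetical and not taken from the authors' code.

```python
from collections import Counter


def balanced_accuracy(y_true, y_pred):
    """Return (BAcc, per-class TPR), both in percent.

    BAcc is the unweighted mean of the per-class true positive rates (recall).
    """
    totals = Counter(y_true)                                     # samples per class
    hits = Counter(t for t, p in zip(y_true, y_pred) if t == p)  # correct predictions per class
    tprs = {c: 100.0 * hits[c] / totals[c] for c in totals}
    return sum(tprs.values()) / len(tprs), tprs


# Hypothetical, class-imbalanced toy labels for CN / AD / FTD.
y_true = ["CN"] * 6 + ["AD"] * 3 + ["FTD"] * 3
y_pred = ["CN"] * 6 + ["AD", "AD", "CN"] + ["FTD", "AD", "FTD"]

bacc, tprs = balanced_accuracy(y_true, y_pred)
print(tprs["CN"], round(tprs["AD"], 2), round(tprs["FTD"], 2))  # 100.0 66.67 66.67
print(round(bacc, 2))                                           # 77.78
```

On the same toy data, plain accuracy would be 10/12 ≈ 83.3%, because the perfectly classified majority class (CN) dominates; balanced accuracy weights the two dementia classes equally with CN.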


[Figure 3 panels: (a) Pre-training on $\mathcal{U}$; (b) Self-distillation on $\mathcal{D}$; (c) Fine-tuning on $\mathcal{T}$; (d) Pre-training on $\mathcal{U}$ (only $\mathcal{U}$ colored); (e) Self-distillation on $\mathcal{D}$ (only $\mathcal{D}$ colored); (f) Fine-tuning on $\mathcal{T}$ (only $\mathcal{T}$ colored).]

Figure 3: Changes in the latent space after each step of Triplet Training, visualized with UMAP, for all datasets (first row) and with only the dataset used in the respective step colored (second row). $\mathcal{U}$: unlabeled (purple; a representative fraction of samples is shown to improve readability); task-related $\mathcal{D}$: CN (dark blue), AD (red), FTD (dark grey); in-house $\mathcal{T}$: CN (light blue), AD (orange), FTD (light grey).

Visualization of the latent space.

We argue that the high accuracy of Triplet Training stems from classifier decision boundaries that depend less on the population a sample is drawn from. To examine this, we visualize the latent features of all three datasets $\mathcal{U}$, $\mathcal{D}$, and $\mathcal{T}$ after each step of Triplet Training using UMAP (McInnes et al., 2018) in Figure 3. After self-supervised pre-training on $\mathcal{U}$, samples of all classes from the three datasets are intermixed. After self-distillation on $\mathcal{D}$, CN, AD, and FTD samples from all datasets begin to separate. The unlabeled samples from $\mathcal{U}$ overlap considerably with the CN samples, which is expected because the majority of UK Biobank participants are healthy individuals. After the full Triplet Training, the extracted features are well separated by class, independent of the dataset, with a particularly clean cluster of FTD samples from $\mathcal{D}$ and $\mathcal{T}$. Moreover, CN and AD samples of $\mathcal{D}$ remain clearly separated, indicating that the network did not forget previously acquired knowledge while adapting to the new domain. This property is crucial in continual learning and domain adaptation and suggests that Triplet Training generalizes well even when only limited data are available for the target task.

Ablation Study 1: Hyper-parameters.

As shown in the original work (Zbontar et al., 2021), Barlow Twins is relatively robust to the batch size. However, the batch sizes evaluated there, up to 4,096, are infeasible for volumetric images. We therefore examine the robustness of Triplet Training with respect to batch sizes typically used for DNNs in medical image analysis. As seen in Figure 4(a), Triplet Training consistently surpasses both supervised training on $\mathcal{T}$ and supervised pre-training on $\mathcal{D}$ followed by fine-tuning on $\mathcal{T}$ across all batch sizes, with a batch size of 128 (used for all other experiments) performing marginally better than the others. Hence, Triplet Training benefits from a moderate increase in batch size, yet it outperforms the competing methods regardless of the batch size, demonstrating considerable robustness to batch-size variation. Figure 4(b) shows that Triplet Training also outperforms the baseline methods for a wide range of $\lambda_2$, a constant hyper-parameter used during self-distillation.
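The batch-size dependence comes from how the Barlow Twins objective is estimated. Following Zbontar et al. (2021), the loss is $\mathcal{L} = \sum_i (1 - C_{ii})^2 + \lambda_1 \sum_i \sum_{j \neq i} C_{ij}^2$, where $C_{ij} = \frac{1}{N}\sum_{b=1}^{N} \hat{z}^{A}_{b,i}\,\hat{z}^{B}_{b,j}$ is the cross-correlation between the batch-standardized embeddings $\hat{z}^{A}$, $\hat{z}^{B}$ of two augmented views of the same $N$ scans. Every entry of $C$ is therefore an average over the batch, and smaller batches give noisier estimates. The pure-Python sketch below is illustrative only; it assumes that $\lambda_1 = 0.005$ from Table 4 is the off-diagonal weight of this loss, and the function names are hypothetical (a practical implementation would vectorize this with a tensor library).

```python
import math


def standardize(z, eps=1e-5):
    """Zero-mean, unit-variance normalization of every dimension over the batch."""
    n, d = len(z), len(z[0])
    out = [[0.0] * d for _ in range(n)]
    for j in range(d):
        col = [row[j] for row in z]
        mean = sum(col) / n
        var = sum((x - mean) ** 2 for x in col) / n  # biased variance, as in batch norm
        std = math.sqrt(var + eps)
        for i in range(n):
            out[i][j] = (z[i][j] - mean) / std
    return out


def barlow_twins_loss(z_a, z_b, lambd=0.005):
    """Barlow Twins loss for two (N x D) batches of embeddings of the same N scans."""
    za, zb = standardize(z_a), standardize(z_b)
    n, d = len(za), len(za[0])
    loss = 0.0
    for i in range(d):
        for j in range(d):
            # Cross-correlation C_ij, estimated from the N samples of the batch.
            c_ij = sum(za[k][i] * zb[k][j] for k in range(n)) / n
            if i == j:
                loss += (1.0 - c_ij) ** 2  # invariance term: pull C_ii towards 1
            else:
                loss += lambd * c_ij ** 2  # redundancy reduction: push C_ij towards 0
    return loss


# Toy batch: N = 4 samples, D = 2 decorrelated dimensions.
z = [[1.0, 1.0], [1.0, -1.0], [-1.0, 1.0], [-1.0, -1.0]]
z_neg = [[-x for x in row] for row in z]
print(round(barlow_twins_loss(z, z), 6))      # 0.0 (identical, decorrelated views)
print(round(barlow_twins_loss(z, z_neg), 3))  # 8.0 (anti-correlated views)
```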

[Figure 4 panels: (a) BAcc for different batch sizes; (b) BAcc for different values of $\lambda_2$.]

Figure 4: Ablation studies of hyper-parameters in Triplet Training.

Ablation Study 2: Benchmark Self-Supervised Approaches.

We replace the SSL algorithm (Barlow Twins) in the initial step of Triplet Training with three state-of-the-art alternatives (SimCLR, VICReg, and DiRA). As reported in Table 3, Triplet Training achieves high and consistent accuracy with all SSL methods (BAcc between 73.44% and 75.57%), highlighting its robustness and generalizability. Among them, Barlow Twins and SimCLR perform best and introduce fewer additional hyper-parameters than the other methods. We consider Barlow Twins the preferable choice because it has been shown to be robust to the batch size (Zbontar et al., 2021).

Table 3: Balanced accuracy (BAcc; mean ± standard deviation), true positive rate (TPR), and macro-F1 score (F1; mean ± standard deviation) for different SSL approaches in the initial step of Triplet Training; $\textrm{BAcc}_{\mathcal{D}}$ is measured on the hold-out test set of $\mathcal{D}$. We propose to use Barlow Twins (BT) in Triplet Training.
SSL in Triplet Training $\textrm{BAcc}_{\mathcal{T}}$ $\textrm{TPR}_{\textrm{CN}}$ $\textrm{TPR}_{\textrm{AD}}$ $\textrm{TPR}_{\textrm{FTD}}$ $\textrm{F1}_{\mathcal{T}}$ $\textrm{BAcc}_{\mathcal{D}}$
SimCLR (Chen et al., 2020) 75.22 ± 2.80 86.7 69.1 69.7 75.64 ± 2.74 86.0
VICReg (Bardes et al., 2022) 73.44 ± 4.92 83.9 69.1 67.1 74.15 ± 4.91 85.5
DiRA (Haghighi et al., 2022) 74.49 ± 4.14 86.7 65.5 71.1 74.85 ± 4.03 85.4
BT (Zbontar et al., 2021) 75.57 ± 3.62 81.8 71.8 73.7 75.32 ± 4.51 85.6

5 Conclusion

We introduced Triplet Training for the differential diagnosis of dementia, which enhances predictive performance for tasks with limited data availability. Triplet Training consists of three steps that leverage large-scale unlabeled data, task-related data, and a limited amount of target data, achieving a BAcc of 75.6% on a well-characterized clinical dataset while showing strong generalizability. Ablation studies confirmed the robustness of Triplet Training to varying hyper-parameters and to the choice of SSL method in the initial step.

References

  • Shekoofeh Azizi, Basil Mustafa, Fiona Ryan, Zachary Beaver, Jan Freyberg, Jonathan Deaton, Aaron Loh, Alan Karthikesalingam, Simon Kornblith, Ting Chen, Vivek Natarajan, and Mohammad Norouzi. Big self-supervised models advance medical image classification. In ICCV, 2021.
  • Adrien Bardes, Jean Ponce, and Yann LeCun. VICReg: Variance-invariance-covariance regularization for self-supervised learning. In ICLR, 2022.
  • Josephine Barnes, Gerard R. Ridgway, Jonathan Bartlett, Susie M.D. Henley, Manja Lehmann, Nicola Hobbs, Matthew J. Clarkson, David G. MacManus, Sebastien Ourselin, and Nick C. Fox. Head size, age and gender adjustment in MRI studies: a necessary nuisance? NeuroImage, 53(4):1244–1255, 2010. https://doi.org/10.1016/j.neuroimage.2010.06.025.
  • Mathilde Caron, Ishan Misra, Julien Mairal, Priya Goyal, Piotr Bojanowski, and Armand Joulin. Unsupervised learning of visual features by contrasting cluster assignments. In Advances in Neural Information Processing Systems, volume 33, pages 9912–9924, 2020.
  • Krishna Chaitanya, Ertunc Erdil, Neerav Karani, and Ender Konukoglu. Contrastive learning of global and local features for medical image segmentation with limited annotations. In Advances in Neural Information Processing Systems, 2020.
  • Liang Chen, Paul Bentley, Kensaku Mori, Kazunari Misawa, Michitaka Fujiwara, and Daniel Rueckert. Self-supervised learning for medical image analysis using image context restoration. Medical Image Analysis, 58:101539, 2019.
  • Ting Chen, Simon Kornblith, Mohammad Norouzi, and Geoffrey E. Hinton. A simple framework for contrastive learning of visual representations. In ICML, 2020.
  • Bruce Fischl. FreeSurfer. NeuroImage, 62(2):774–781, 2012. https://doi.org/10.1016/j.neuroimage.2012.01.021.
  • Christian Gaser, Robert Dahnke, Paul Thompson, Florian Kurth, and Eileen Luders. CAT – a computational anatomy toolbox for the analysis of structural MRI data. bioRxiv, 2022.
  • Robert Geirhos, Jörn-Henrik Jacobsen, Claudio Michaelis, Richard Zemel, Wieland Brendel, Matthias Bethge, and Felix A. Wichmann. Shortcut learning in deep neural networks. Nature Machine Intelligence, 2(11):665–673, 2020.
  • Fatemeh Haghighi, Mohammad Reza Hosseinzadeh Taher, Michael B. Gotway, and Jianming Liang. DiRA: Discriminative, restorative, and adversarial learning for self-supervised medical image analysis. In CVPR, 2022.
  • Mohammad Reza Hosseinzadeh Taher, Fatemeh Haghighi, Ruibin Feng, et al. A systematic benchmarking analysis of transfer learning for medical image analysis. Pages 3–13, Cham, 2021. Springer International Publishing.
  • Jingjing Hu, Zhao Qing, Renyuan Liu, Xin Zhang, Pin Lv, Maoxue Wang, Yang Wang, Kelei He, and Yang Gao. Deep learning-based classification and voxel-based visualization of frontotemporal dementia and Alzheimer's disease. Frontiers in Neuroscience, 14, 2021.
  • Clifford Jack, Matt Bernstein, Nick Fox, Paul Thompson, Gene Alexander, Danielle Harvey, Bret Borowski, Paula Britson, Jennifer Whitwell, Chadwick Ward, Anders Dale, Joel Felmlee, Jeffrey Gunter, Derek Hill, Ron Killiany, Norbert Schuff, Sabrina Fox-Bosetti, Chen Lin, Colin Studholme, and Michael Weiner. The Alzheimer's Disease Neuroimaging Initiative (ADNI): MRI methods. Journal of Magnetic Resonance Imaging: JMRI, 27:685–691, 2008.
  • Hongchao Jiang and Chunyan Miao. Pre-training 3D convolutional neural networks for prodromal Alzheimer's disease classification. In IJCNN, pages 1–8, 2022.
  • Guang Li, Ren Togo, Takahiro Ogawa, and Miki Haseyama. Self-knowledge distillation based self-supervised learning for COVID-19 detection from chest X-ray images. In ICASSP, pages 1371–1375, 2022.
  • Hongwei Li, Fei-Fei Xue, Krishna Chaitanya, et al. Imbalance-aware self-supervised learning for 3D radiomic representations. In MICCAI. Springer, 2021.
  • Da Ma, Donghuan Lu, Karteek Popuri, Lei Wang, and Mirza Faisal Beg. Differential diagnosis of frontotemporal dementia, Alzheimer's disease, and normal aging using a multi-scale multi-type feature generative adversarial deep neural network on structural magnetic resonance images. Frontiers in Neuroscience, 14, 2020.
  • Leland McInnes, John Healy, and James Melville. UMAP: Uniform manifold approximation and projection for dimension reduction. arXiv preprint arXiv:1802.03426, 2018.
  • Karla Miller, Fidel Alfaro-Almagro, Neal Bangerter, David Thomas, Essa Yacoub, Junqian Xu, Andreas Bartsch, Saad Jbabdi, Stamatios Sotiropoulos, Jesper Andersson, Ludovica Griffanti, Gwenaëlle Douaud, Thomas Okell, Peter Weale, Iulius Dragonu, Steve Garratt, Sarah Hudson, Rory Collins, Mark Jenkinson, and Stephen Smith. Multimodal population brain imaging in the UK Biobank prospective epidemiological study. Nature Neuroscience, 19, 2016.
  • Huy-Dung Nguyen, Michaël Clément, et al. Interpretable differential diagnosis for Alzheimer's disease and frontotemporal dementia. In MICCAI, pages 55–65. Springer Nature Switzerland, 2022.
  • Emma Nichols, Jaimie D. Steinmetz, Stein Emil Vollset, et al. Estimation of the global prevalence of dementia in 2019 and forecasted prevalence in 2050: an analysis for the Global Burden of Disease Study 2019. Lancet Public Health, 7(2):e105–e125, 2022.
  • Naveen Paluru, Hariharan Ravishankar, Sharat Hegde, and Phaneendra K. Yalavarthy. Self distillation for improving the generalizability of retinal disease diagnosis using optical coherence tomography images. IEEE Journal of Selected Topics in Quantum Electronics, 29(4: Biophotonics):1–12, 2023.
  • Adam Paszke, Sam Gross, Francisco Massa, Adam Lerer, James Bradbury, Gregory Chanan, Trevor Killeen, et al. PyTorch: An imperative style, high-performance deep learning library. In Advances in Neural Information Processing Systems 32, pages 8024–8035, 2019.
  • Jinghan Sun, Dong Wei, Kai Ma, Liansheng Wang, and Yefeng Zheng. Unsupervised representation learning meets pseudo-label supervised self-distillation: A new approach to rare disease classification. In MICCAI, pages 519–529. Springer, 2021.
  • Aiham Taleb, Winfried Loetzsch, Noel Danz, Julius Severin, Thomas Gaertner, Benjamin Bergner, and Christoph Lippert. 3D self-supervised methods for medical imaging. In Advances in Neural Information Processing Systems, volume 33, pages 18158–18172. Curran Associates, Inc., 2020.
  • Yonglong Tian, Yue Wang, Dilip Krishnan, Joshua B. Tenenbaum, and Phillip Isola. Rethinking few-shot image classification: A good embedding is all you need? In ECCV, pages 266–282. Springer International Publishing, 2020.
  • Manuel Tran, Sophia J. Wagner, Melanie Boxberg, and Tingying Peng. S5CL: Unifying fully-supervised, self-supervised, and semi-supervised learning through hierarchical contrastive learning. In MICCAI, pages 99–108. Springer Nature, 2022.
  • Yiwen Ye, Jianpeng Zhang, Ziyang Chen, and Yong Xia. DeSD: Self-supervised learning with deep self-distillation for 3D medical image segmentation. In MICCAI, pages 545–555. Springer, 2022.
  • Juan Joseph Young, Mallika Lavakumar, Deena Tampi, Silpa Balachandran, and Rajesh R. Tampi. Frontotemporal dementia: latest evidence and clinical implications. Therapeutic Advances in Psychopharmacology, 8:33–48, 2018.
  • Jure Zbontar, Li Jing, Ishan Misra, Yann LeCun, and Stéphane Deny. Barlow Twins: Self-supervised learning via redundancy reduction. In ICML, 2021.
  • Hong-Yu Zhou, Shuang Yu, Cheng Bian, Yifan Hu, Kai Ma, and Yefeng Zheng. Comparing to learn: Surpassing ImageNet pretraining on radiographs by comparing image representations. In MICCAI, 2020.
  • Zongwei Zhou, Vatsal Sodha, Md Mahfuzur Rahman Siddiquee, Ruibin Feng, Nima Tajbakhsh, Michael B. Gotway, and Jianming Liang. Models Genesis: Generic autodidactic models for 3D medical image analysis. In MICCAI, pages 384–393. Springer, 2019.

Appendix A Architecture

Figure 5: We use a 3D ResNet as the feature extractor $f$ for all models. It consists of six residual blocks, each comprising two convolutional layers followed by batch normalization and a ReLU non-linearity. Each of the last five residual blocks starts with a convolutional layer with stride two.

Figure 6: Projection head $g$ for: (a) self-supervision (Barlow Twins); (b) self-distillation and fine-tuning.
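To make the downsampling of the feature extractor $f$ in Figure 5 concrete, the sketch below traces the spatial size of the feature maps through the six residual blocks using the standard convolution output-size formula. It rests on several assumptions not stated in the figure: kernel size 3 with padding 1 for every convolution, no stem layer before the first block, and a cubic input of $55^3$ voxels (the crop output size in Table 5). The function names are hypothetical.

```python
def conv_out(size, kernel=3, stride=1, padding=1):
    """Output length along one axis for a convolution (standard formula)."""
    return (size + 2 * padding - kernel) // stride + 1


def resnet_feature_sizes(input_size=55, n_blocks=6):
    """Spatial size after each residual block of the 3D ResNet in Figure 5 (sketch)."""
    sizes, size = [], input_size
    for block in range(n_blocks):
        stride = 1 if block == 0 else 2       # the last five blocks start with a stride-2 conv
        size = conv_out(size, stride=stride)  # first conv of the block
        size = conv_out(size)                 # second conv keeps the size (stride 1)
        sizes.append(size)
    return sizes


print(resnet_feature_sizes())  # [55, 28, 14, 7, 4, 2]
```

Under these assumptions, five stride-two stages reduce a $55^3$ input to a $2^3$ feature map before pooling; a different kernel, padding, or stem would change these numbers.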

Appendix B Training Details

B.1 Hyper-parameters

Table 4: Hyper-parameters of the different training strategies. The number of iterations for each step is chosen based on convergence on the validation set. Where available, we use the hyper-parameters proposed in the original work.
Training Strategy Hyper-Parameter Value
Supervised Training ($\mathcal{T}$) Learning rate 0.01
Weight decay 0.00001
Batch size 64
Training iterations 150
Supervised Pre-Training ($\mathcal{D}$) Learning rate 0.01
Weight decay 0.0000015
Batch size 128
Training iterations 600
Triplet Training (Self-Supervision) Learning rate 0.5
Weight decay 0.0000015
Batch size 128
Training iterations 29,300
$\lambda_1$ 0.005
Triplet Training (Self-Distillation) Learning rate 0.01
Weight decay 0.0000015
Batch size 128
Training iterations 600
$\lambda_2$ 0.001
Triplet Training (Fine-Tuning) Learning rate 0.0005
Weight decay 0.00001
Batch size 64
Training iterations 150
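For reproduction scripts, Table 4 can be transcribed into a typed configuration. The sketch below is one possible encoding; the `StepConfig` class and the dictionary keys are hypothetical and are not taken from the authors' repository, while all numeric values come from Table 4.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class StepConfig:
    learning_rate: float
    weight_decay: float
    batch_size: int
    iterations: int
    lam: Optional[float] = None  # lambda_1 (self-supervision) or lambda_2 (self-distillation)


# Values transcribed from Table 4; the dictionary keys are hypothetical names.
HPARAMS = {
    "supervised_T": StepConfig(0.01, 1e-5, 64, 150),
    "supervised_pretrain_D": StepConfig(0.01, 1.5e-6, 128, 600),
    "triplet_self_supervision": StepConfig(0.5, 1.5e-6, 128, 29_300, lam=0.005),
    "triplet_self_distillation": StepConfig(0.01, 1.5e-6, 128, 600, lam=0.001),
    "triplet_fine_tuning": StepConfig(0.0005, 1e-5, 64, 150),
}

triplet_steps = ["triplet_self_supervision", "triplet_self_distillation", "triplet_fine_tuning"]
print(sum(HPARAMS[s].iterations for s in triplet_steps))  # 30050
```

Summing the three Triplet Training steps shows that self-supervised pre-training accounts for the vast majority (29,300 of 30,050) of the training iterations.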

B.2 Data Augmentation

Table 5: Data augmentations used in Triplet Training.
Training Strategy Augmentation Values
Self-Supervision Rescale Intensity intensity range = (0, 1)
Random Cropping with Resizing crop scale = (0.5, 1.0)
output size = (55, 55, 55)
random center = True
Random Flipping axes = (0, 1, 2)
probability = 0.5
Random Affine Transformation rotation range = (-90°, +90°)
translation range = (-8 pixels, +8 pixels)
probability = 0.5
Self-Distillation Rescale Intensity intensity range = (0, 1)
Random Affine Transformation rotation range = (-8°, +8°)
translation range = (-8 pixels, +8 pixels)
probability = 0.5
Fine-Tuning Rescale Intensity intensity range = (0, 1)
Random Affine Transformation rotation range = (-8°, +8°)
translation range = (-8 pixels, +8 pixels)
probability = 0.5
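The self-supervision augmentations in Table 5 can be read as a recipe for sampling the parameters of each augmented view; Barlow Twins needs two such views per scan, sampled independently. The sketch below only samples parameters and does not resample any voxels. Several points are assumptions rather than details from the table: the crop scale is interpreted as a fraction of the volume (analogous to 2D random resized crops), flips are drawn independently per axis, rotation and translation are drawn independently per axis, and the input shape is hypothetical. Intensity rescaling to (0, 1) and the final resize to (55, 55, 55) would be applied separately; the self-distillation and fine-tuning steps would keep only the affine branch with the narrower ±8° rotation range.

```python
import random


def sample_crop_box(shape, scale_range=(0.5, 1.0), rng=random):
    """Random 3D crop whose volume is a fraction `scale` of the input (assumed semantics)."""
    scale = rng.uniform(*scale_range)
    side_frac = scale ** (1.0 / 3.0)  # equal relative side length on every axis (assumption)
    size = tuple(max(1, round(s * side_frac)) for s in shape)
    # A random start index per axis gives a random crop center inside the volume.
    start = tuple(rng.randint(0, s - c) for s, c in zip(shape, size))
    return start, size


def sample_ssl_view(shape, rng=random):
    """Parameters for one augmented view in the self-supervision step (cf. Table 5)."""
    start, size = sample_crop_box(shape, rng=rng)
    # Each axis is flipped with probability 0.5 (independence across axes is an assumption).
    flip_axes = tuple(ax for ax in (0, 1, 2) if rng.random() < 0.5)
    affine = None
    if rng.random() < 0.5:  # affine transformation applied with probability 0.5
        affine = {
            "rotation_deg": tuple(rng.uniform(-90.0, 90.0) for _ in range(3)),
            "translation_px": tuple(rng.uniform(-8.0, 8.0) for _ in range(3)),
        }
    return {"crop_start": start, "crop_size": size, "resize_to": (55, 55, 55),
            "flip_axes": flip_axes, "affine": affine}


rng = random.Random(0)
shape = (113, 137, 113)  # hypothetical input volume shape
view_a = sample_ssl_view(shape, rng)  # first view of a scan
view_b = sample_ssl_view(shape, rng)  # second, independently sampled view of the same scan
```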