PathoWAve: A Deep Learning-based Weight Averaging Method for Improving Domain Generalization in Histopathology Images
Abstract
Recent advancements in deep learning (DL) have significantly advanced medical image analysis. In histopathology image analysis, however, variations in staining protocols and differences between scanners cause significant domain shift, which undermines the ability of models to generalize to data from unseen domains. This creates a need for effective domain generalization (DG) strategies that improve the consistency and reliability of automated cancer detection tools in diagnostic decision-making. In this paper, we introduce Pathology Weight Averaging (PathoWAve), a multi-source DG strategy for addressing the domain shift phenomenon of DL models in histopathology image analysis. By integrating a specific weight averaging technique with parallel training trajectories and a strategic combination of regular augmentations with histopathology-specific data augmentation methods, PathoWAve enables a comprehensive exploration of, and precise convergence within, the loss landscape. This method significantly enhances the generalization capabilities of DL models across new, unseen histopathology domains. To the best of our knowledge, PathoWAve is the first proposed weight averaging method for DG in histopathology image analysis. Our quantitative results on the Camelyon17 WILDS dataset demonstrate PathoWAve’s superiority over previously proposed methods for tackling the domain shift phenomenon in histopathology image processing. Our code is available at https://github.com/ParastooSotoudeh/PathoWAve
Index Terms:
Deep Learning, Medical Image Processing, Domain Generalization, Domain Shift, Weight Averaging, Histopathology images, Data Augmentation
I Introduction
Histopathology is key in diagnosing and prognosticating diseases, especially in oncology. Whole slide images (WSIs) from tissue sections offer critical insights into tissue morphology, cellular structures, and disease progression. However, WSI analysis faces challenges due to variability in staining protocols, scanners, tissue preparation, and imaging systems across medical centers, causing significant data distribution changes. This variability leads to domain shift, where DL models trained on data from one domain fail to generalize to unseen domains, compromising diagnoses in clinical applications [1]. DG techniques, which make models more invariant to changes in data distribution and so promise consistent performance across medical settings, are crucial for tackling the domain shift challenge. We group related methods into the following categories:
DG methods: CORAL [2] aligns domain feature distributions via covariance matching, foundational to domain adaptation advances. IRM [3] introduces a framework for learning domain invariances, enhancing generalization to unseen environments. Group DRO [4] focuses on worst-case domain performance, improving model resilience. FISH [5] targets feature-level domain discrepancies for domain-invariant representation learning. PLDG [6] employs pseudo labeling for cross-domain data variability, enhancing model adaptability. TFS-ViT [7] utilizes token-level feature stylization in Vision Transformers for robustness, marking progress in domain generalization (DG) techniques. These methods highlight the progression towards bridging the source-target domain gap, fostering more refined approaches.
Data Augmentation methods: Data augmentation plays a pivotal role in DG. Augmentation methods such as [8, 9] provide a simple yet effective way of introducing variability into the training process.
Test-Time Methods: Test-time training methods, like test-time image-to-image translation [10], offer a promising way of dynamically adjusting models in response to new domain characteristics encountered at inference.
In this paper, we propose PathoWAve, a domain generalization technique tailored specifically for histopathology image analysis. It is inspired by recent advances in weight averaging methods such as SWA [11], SWAD [12], Lookahead [13], and Lookaround [14], and by the understanding of loss landscapes in neural networks, particularly the concept of Linear Mode Connectivity (LMC) [15]. Our contributions through deploying PathoWAve on a histopathology dataset include:
- Introduction of PathoWAve for addressing domain shift in histopathology: PathoWAve, a multi-source domain generalization method, tackles domain shift challenges in histopathology image analysis caused by varied staining techniques and imaging conditions. It enhances generalization by training identical models on diversely augmented images in parallel, integrating a weight averaging strategy within the training cycle to ensure model diversity and locality.
- Strategic combination of regular and histopathology-specific augmentation methods: PathoWAve merges regular and histopathology-specific augmentation techniques, notably employing HEDJitter, a method designed for histopathology images [16], and demonstrates improved results and enhanced generalization through this combination.
- State of the art in weight averaging for DG in histopathology images: PathoWAve pioneers the use of weight averaging, combined with histopathology-specific and regular augmentations, to combat domain shift in histopathology images, marking a first in DG for histopathology analysis. Tested on the Camelyon17 WILDS dataset, our method outperforms existing DG techniques, demonstrating its efficacy in mitigating domain shift and boosting model robustness to variations in unseen data.
![Figure 1: Overall architecture of the proposed PathoWAve framework.](x1.png)
II Method
In the specific context of medical image processing, particularly histopathology image analysis, significant challenges arise from the variability in staining techniques, scanners, imaging conditions, and tissue processing methods. These factors contribute to domain shift challenges in real-world scenarios, where accurate and timely disease detection is critical for patient care. Consequently, there is a pressing need for robust domain generalization strategies that can effectively address these issues, ensuring models remain invariant to data distribution shifts and maintain their reliability in diagnostic decision-making. In the subsequent sections, we first delve into the intricacies of domain generalization, laying the groundwork for understanding its significance. Subsequently, we introduce and elaborate on our proposed method, PathoWAve, specifically designed to address domain generalization challenges in histopathology image analysis.
II-A Domain Generalization Objective
Domain generalization (DG) tackles the challenge of developing models that, when trained on multiple source domains, exhibit robust performance on previously unseen target domains. Each domain is characterized by its unique joint distribution over input space $\mathcal{X}$ and target space $\mathcal{Y}$, denoted as $P_{XY}$. The aim is to leverage the diversity among source domains to predict accurately in target domains that the model has not encountered during training.
Formally, the DG framework seeks to minimize the expected loss over unseen target domains, $\mathcal{T}$, using knowledge derived from a set of source domains. This objective can be expressed mathematically as:
$$\min_{f} \; \mathbb{E}_{(x, y) \sim P_{XY}^{\mathcal{T}}} \left[ \mathcal{L}\big(f(x), y\big) \right] \tag{1}$$
where $f$ is the predictive function designed to approximate the posterior distribution $P(Y \mid X)$, and $\mathcal{L}$ is the loss function measuring the discrepancy between predicted and actual labels.
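Given that direct access to $\mathcal{T}$ is not feasible during training, the strategy shifts towards minimizing the empirical risk over the source domains, represented as:
$$\min_{\theta} \; \frac{1}{N} \sum_{i=1}^{N} \mathcal{L}\big(f_{\theta}(x_i), y_i\big) \tag{2}$$
where $(x_i, y_i)$ is the $i$-th source sample, $N$ is the cumulative number of samples across all source domains, $\theta$ represents the parameters of model $f$, and $f_{\theta}$ signifies the model parameterized by $\theta$.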
This problem setup acknowledges the inherent variability across domains by assuming each source domain $S_k$ for $k = 1, \dots, K$ has a distinct joint distribution $P_{XY}^{(k)}$, reflecting the real-world scenario where training data might not encompass the full spectrum of variation present in unseen target data. The ultimate goal of DG is to construct a model that, despite the distributional differences encapsulated by $P_{XY}^{(k)} \neq P_{XY}^{(k')}$ for $k \neq k'$, can generalize effectively across novel domains, thereby ensuring reliable predictions for $\mathcal{T}$ without requiring explicit knowledge of its distribution.
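As a minimal sketch of the pooled empirical risk in Eq. (2), the snippet below averages a cross-entropy loss over every sample drawn from several source domains. The toy model, the sample values, and the helper names (`cross_entropy`, `empirical_risk`, `constant_model`) are illustrative assumptions and are not taken from the PathoWAve code.

```python
import math


def cross_entropy(probs, label):
    """Negative log-likelihood of the true label under predicted probabilities."""
    return -math.log(probs[label])


def empirical_risk(model, source_domains):
    """Average loss over all N samples pooled from every source domain."""
    pooled = [sample for domain in source_domains for sample in domain]
    total = sum(cross_entropy(model(x), y) for x, y in pooled)
    return total / len(pooled)


def constant_model(x):
    """Hypothetical toy model: ignores its input and predicts fixed probabilities."""
    return [0.5, 0.5]


# Three hypothetical source domains of different sizes, as (input, label) pairs.
domains = [
    [(0.1, 0), (0.4, 1)],
    [(0.9, 1)],
    [(0.3, 0), (0.7, 1), (0.2, 0)],
]
risk = empirical_risk(constant_model, domains)
```

Because the samples are pooled before averaging, larger source domains contribute proportionally more to the risk, which matches the single sum over $N$ samples in Eq. (2).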
II-B PathoWAve Framework
Our PathoWAve framework introduces a cyclical training regime that applies a specific weight averaging method during the training process and integrates both standard and histopathology-specific data augmentations. This combination aims to balance locality and diversity within the model’s learning process, ultimately enhancing generalization capabilities across unseen histopathology domains. Figure 1 illustrates the overall architecture of our proposed PathoWAve framework.
Inspired by recent advancements in weight averaging methods [11, 12, 13, 14], we explore the potential of a multi-training trajectory approach. Multiple training trajectories share an identical model architecture but are diversified through carefully selected data augmentations, including histopathology-specific ones designed to simulate a broad spectrum of staining variations and imaging conditions. This approach seeks to achieve beneficial model diversity within the scope of histopathology images and thus enhance generalization to unseen domains.
More specifically, our method employs the concurrent training of multiple identical neural networks, each initialized from a common point within the loss landscape. Each training path $a \in \{1, \dots, A\}$ is exposed to distinct batches ($B_a$) of data, with each batch sampled from the union of all source domain datasets $\bigcup_{k=1}^{K} S_k$. These batches are processed through a tailored suite of data augmentation strategies ($\mathrm{Aug}_a$) to introduce a broad spectrum of variations reflective of real-world histopathological conditions.
The training process for each network is mathematically captured by the update formula:
$$\theta_a^{(t)} = \bar{\theta}^{(t-1)} - \eta \, \nabla_{\bar{\theta}^{(t-1)}} \mathcal{L}\big(\bar{\theta}^{(t-1)}; \mathrm{Aug}_a(B_a)\big) \tag{3}$$
where $\theta_a^{(t)}$ represents the model parameters for the $a$-th network updated after training iteration $t$, $\eta$ denotes the learning rate, and $\nabla_{\bar{\theta}^{(t-1)}} \mathcal{L}$ signifies the gradient of the loss function with respect to the averaged model parameters $\bar{\theta}^{(t-1)}$, evaluated on the augmented batch $\mathrm{Aug}_a(B_a)$. This approach ensures that each network is exposed to a variety of data representations through its specific augmentation strategy $\mathrm{Aug}_a$, enhancing the overall diversity of the model’s learning experience and its ability to generalize across unseen domains.
Following this phase of individual training, we integrate the weights of the models from each training path through a weight averaging strategy applied directly within the training cycle, aiming to reach flatter minima within the loss landscape, which results in better generalization of the model. This process is formalized as follows:
$$\bar{\theta}^{(t)} = \frac{1}{A} \sum_{a=1}^{A} \theta_a^{(t)} \tag{4}$$
where $\bar{\theta}^{(t)}$ represents the unified set of weights obtained by averaging the parameters $\theta_a^{(t)}$ of each model $a$ out of the total $A$ models at iteration $t$. This integration forms a cohesive weight set, serving as the starting point for all models in subsequent iterations and enhancing their generalization capabilities. In this way, in addition to improving diversity through multiple training trajectories and specific augmentation methods, PathoWAve steers the models towards lower-loss regions throughout the training cycle, promoting model robustness and locality. Through the PathoWAve method, we implement the DG objective via a structured, iterative refinement process that targets low loss on unseen domains.
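The cycle described by Eqs. (3) and (4) can be sketched on a toy problem. The example below is an assumption-laden illustration, not the authors' implementation: it uses NumPy, a linear regression model in place of ResNet50, one gradient step per trajectory per cycle, and three hypothetical stand-in augmentations. The function names `mse_grad` and `pathowave_step` are invented for this sketch.

```python
import numpy as np


def mse_grad(theta, x, y):
    """Gradient of the mean squared error for a linear model y ~ x @ theta."""
    return 2.0 * x.T @ (x @ theta - y) / len(y)


def pathowave_step(theta_bar, batches, augmentations, lr):
    """One cycle: each trajectory steps from theta_bar (Eq. 3), then average (Eq. 4)."""
    updated = []
    for (x, y), augment in zip(batches, augmentations):
        x_aug = augment(x)  # trajectory-specific augmentation of its own batch
        updated.append(theta_bar - lr * mse_grad(theta_bar, x_aug, y))
    return np.mean(updated, axis=0)  # averaged weights seed the next cycle


rng = np.random.default_rng(0)
true_theta = np.array([1.5, -2.0])

# Hypothetical stand-ins for the augmentation suite (A = 3 trajectories).
augmentations = [
    lambda x: x,                                        # identity
    lambda x: x + 0.01 * rng.standard_normal(x.shape),  # small additive noise
    lambda x: x * 1.02,                                 # mild intensity scaling
]

theta_bar = np.zeros(2)
initial_error = np.linalg.norm(theta_bar - true_theta)
for _ in range(200):
    batches = []
    for _ in augmentations:
        x = rng.standard_normal((32, 2))  # a distinct batch per trajectory
        batches.append((x, x @ true_theta))
    theta_bar = pathowave_step(theta_bar, batches, augmentations, lr=0.05)
final_error = np.linalg.norm(theta_bar - true_theta)
```

Only `theta_bar` is needed at evaluation time, so the cost of running $A$ trajectories is confined to training.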
Choosing suitable augmentations plays a crucial role in the efficacy of the PathoWAve framework. Our approach uniquely combines general data augmentations—such as AutoAugment and RandAugment [8]—with histopathology-specific augmentations like HEDJitter [16] to address the broad spectrum of variability encountered in histopathological images. General augmentations introduce a wide range of variations in the dataset, fostering the model’s adaptability and robustness against common variations in image data. In contrast, the HEDJitter augmentation technique is meticulously designed for histopathology images, utilizing a predefined Optical Density (OD) matrix to transition images from RGB to a domain that emphasizes pathology stains—hematoxylin, eosin, and Diaminobenzidine (DAB). This technique’s capacity to independently adjust stain intensity levels simulates the diverse staining protocols found across laboratories, ensuring the preservation of crucial image features such as cell structures and tissue architecture.
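To make the stain-space perturbation concrete, here is a minimal NumPy sketch of an HED-style jitter. It is an illustration under stated assumptions rather than the HEDJitter implementation used in the experiments: the optical density matrix values are the commonly used Ruifrok and Johnston stain vectors, the per-stain scale and shift are drawn uniformly around 1 and 0, and the function name `hed_jitter` and its `strength` parameter are hypothetical.

```python
import numpy as np

# Stain vectors (rows: hematoxylin, eosin, DAB) in RGB optical-density space.
# Assumed values: the widely used Ruifrok and Johnston matrix.
RGB_FROM_HED = np.array([
    [0.65, 0.70, 0.29],
    [0.07, 0.99, 0.11],
    [0.27, 0.57, 0.78],
])
HED_FROM_RGB = np.linalg.inv(RGB_FROM_HED)


def hed_jitter(image, strength, rng):
    """Perturb stain concentrations of a float RGB image with values in [0, 1]."""
    eps = 1e-6
    od = -np.log(np.clip(image, eps, 1.0))  # RGB -> optical density
    hed = od @ HED_FROM_RGB                 # unmix into H, E, DAB channels
    alpha = rng.uniform(1.0 - strength, 1.0 + strength, size=3)  # per-stain scale
    beta = rng.uniform(-strength, strength, size=3)              # per-stain shift
    od_jittered = (hed * alpha + beta) @ RGB_FROM_HED            # remix to RGB OD
    return np.clip(np.exp(-od_jittered), 0.0, 1.0)


rng = np.random.default_rng(0)
patch = rng.uniform(0.1, 0.9, size=(96, 96, 3))  # hypothetical stand-in patch
jittered = hed_jitter(patch, strength=0.05, rng=rng)
unchanged = hed_jitter(patch, strength=0.0, rng=rng)
```

With `strength=0.0` the round trip through stain space returns the input patch, which is a convenient sanity check that the unmixing and remixing matrices are consistent. Because each stain channel is scaled and shifted by a single value per image, spatial structure is left untouched.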
The integration of both regular and histopathology-specific augmentations into the PathoWAve training regime is a strategic decision aimed at enhancing the model’s exposure to a wide array of data variations, thereby ensuring a more comprehensive learning experience. Regular augmentations prepare the model for a broad base of image variations, enhancing its adaptability and resilience to general shifts in input data distributions. Meanwhile, HEDJitter and similar histopathology-specific techniques target the nuanced challenges specific to histopathological imagery, such as staining variability, which are critical for achieving high diagnostic accuracy in unseen domains. This dual-strategy augmentation approach not only widens the model’s exposure but also fine-tunes its sensitivity to the unique challenges of histopathology image analysis. Consequently, PathoWAve exhibits outstanding generalization capabilities, setting a new benchmark in domain generalization for histopathology image analysis by leveraging this comprehensive augmentation strategy.
III Experiments
Dataset:
In our experiments, we used the Camelyon17 WILDS dataset [17], featuring patches from Whole Slide Images of lymph node sections across five medical centers with diverse staining protocols and scanners, to test DL models on metastatic breast cancer detection. This dataset is partitioned by medical center of origin to support developing models for cancerous tissue detection that generalize to unseen data. Training and in-distribution validation (id val) data come from three hospitals (30 WSIs and 302,436 patches for training, plus 33,560 patches for id val), while the validation (val) and test sets come from two further, previously unseen hospitals: val has 10 WSIs and 34,904 patches from one hospital, and test has 10 WSIs and 85,054 patches from another. This ensures models are assessed on their adaptability to new medical center data.
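The split sizes quoted above can be tallied with a few lines of Python. The dictionary keys are illustrative labels chosen for this sketch, and the counts are copied from the dataset description.

```python
# Patch counts per split, as stated in the dataset description.
splits = {
    "train": 302_436,
    "id_val": 33_560,
    "val": 34_904,
    "test": 85_054,
}

total = sum(splits.values())
seen_hospitals = splits["train"] + splits["id_val"]  # three training hospitals
unseen_hospitals = splits["val"] + splits["test"]    # two held-out hospitals
share_unseen = unseen_hospitals / total
```

Roughly a quarter of all patches come from hospitals that never appear in training, so the val and test figures measure out-of-distribution behavior rather than in-distribution fit.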
Implementation Details:
We used ResNet50 as our network and NVIDIA V100 32 GB GPUs for all of our experiments. The learning rate and batch size are set to and , respectively. To enhance the model’s robustness to staining variations, we incorporated several augmentation methods, including HEDJitter augmentation [16] with a jitter_strength of 0.05.
Method | Backbone | Validation % | Test % |
---|---|---|---|
CORAL† (2016) [2] | ResNet50 | 86.2 | 59.5 |
IRM† (2019) [3] | ResNet50 | 86.2 | 64.2 |
Group DRO† (2019) [4] | ResNet50 | 85.5 | 68.4 |
DomainMix (2020) [18] | ResNet50 | — | 69.7 |
MMLD‡ (2020) [19] | ResNet50 | — | 70.2 |
ERM (2021) [17] | ResNet50 | — | 70.3 |
FISH† (2021) [5] | ResNet50 | 83.9 | 74.7 |
V_REx (2021) [20] | ResNet50 | — | 71.5 |
IB-IRM (2021) [21] | ResNet50 | — | 68.9 |
LISA (2022) [22] | ResNet50 | — | 77.1 |
FuseStyle (2023) [23] | ResNet50 | — | 90.5 |
CORAL‡ (2016) [2] | ViT-Base | — | 71.8 |
DANN‡ (2016) [24] | ViT-Base | — | 83.5 |
IRM‡ (2019) [3] | ViT-Base | — | 75.0 |
ERM‡ (2021) [17] | ViT-Base | — | 73.1 |
SelfReg‡ (2021) [25] | ViT-Base | — | 70.4 |
PLDG‡ (2024) [6] | ViT-Base | — | 84.3 |
EPVT‡ (2024) [26] | ViT-Base | — | 86.4 |
PathoWAve (ours) | ResNet50 | 93.07 | 94.36 |
IV Results
Comparison with State of the Art: Our evaluation on the Camelyon17 WILDS dataset, presented in Tables I and II, illustrates PathoWAve’s capability to generalize across domain shifts within histopathology images. The comparison includes established domain generalization (DG) methods and shows PathoWAve achieving state-of-the-art performance. Notably, PathoWAve, using a straightforward ResNet architecture, outperforms more complex architectures, including those based on the vision transformer (ViT), which underscores the efficiency of our proposed method. A comparison with non-DG methods, including advanced training-time augmentation and test-time adaptation techniques, further highlights PathoWAve’s effectiveness. Specifically, PathoWAve outperforms methods like STRAP [27], which leverages non-histopathological data, and others employing dynamic adaptations during test time, as shown in our comparisons. Crucially, PathoWAve attains this level of generalization and accuracy without using any information from the test data, emphasizing the robustness of our proposed approach.
Method | #Independent Trajectories (Augmentations) | Test % |
---|---|---|
ERM | 1 (baseline with no weight averaging) | 70.3 |
PathoWAve | 2 (AutoAugment, RandomAugment) | 92.53 |
PathoWAve | 2 (RandomAugment, HEDJitter) | 92.98 |
PathoWAve | 2 (AutoAugment, HEDJitter) | 94.20 |
PathoWAve | 3 (AutoAugment, RandomAugment, AutoRandomRotation) | 89.80 |
PathoWAve | 3 (AutoAugment, RandomAugment, RandomGaussBlur) | 88.91 |
PathoWAve | 3 (AutoAugment, RandomAugment, RandomAffine) | 91.53 |
PathoWAve | 3 (AutoAugment, RandomAugment, HEDJitter) | 94.36 |
Ablation Analysis: Our ablation study, summarized in Table III, evaluates the impact of various augmentation strategies and the number of independent training trajectories on PathoWAve’s effectiveness on the Camelyon17 WILDS dataset. We first established a baseline with the ERM method, which uses a single training trajectory without weight averaging and yields a test accuracy of 70.3%. Introducing PathoWAve with two trajectories, each using a different augmentation, raises test accuracy to between 92.53% and 94.20%, highlighting the method’s responsiveness to diverse training signals.
Among the two-augmentation combinations, AutoAugment with HEDJitter performed best, achieving a test accuracy of 94.20%. This underscores the role of HEDJitter, a histopathology-specific augmentation, in bolstering the model’s generalization capability across unseen domains.
Further exploration with three augmentations revealed varying degrees of success. Adding AutoRandomRotation, RandomGaussBlur, or RandomAffine to the AutoAugment and RandomAugment mix led to lower test accuracies than the dual-augmentation setups, whereas incorporating HEDJitter alongside AutoAugment and RandomAugment within a three-trajectory framework achieved the highest performance at 94.36%. This result identifies the best augmentation combination among those tested and supports PathoWAve as the state of the art in domain generalization for histopathology images on this benchmark.
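The ablation figures can be summarized programmatically. The snippet below is a small sketch that copies the test accuracies from Table III and derives each configuration's gain over the ERM baseline, plus the mean accuracy with and without HEDJitter. The variable names are illustrative.

```python
erm_baseline = 70.3

# Test accuracy (%) per PathoWAve configuration, copied from Table III.
ablation = {
    ("AutoAugment", "RandomAugment"): 92.53,
    ("RandomAugment", "HEDJitter"): 92.98,
    ("AutoAugment", "HEDJitter"): 94.20,
    ("AutoAugment", "RandomAugment", "AutoRandomRotation"): 89.80,
    ("AutoAugment", "RandomAugment", "RandomGaussBlur"): 88.91,
    ("AutoAugment", "RandomAugment", "RandomAffine"): 91.53,
    ("AutoAugment", "RandomAugment", "HEDJitter"): 94.36,
}

# Absolute gain over the single-trajectory ERM baseline, in percentage points.
gains = {augs: round(acc - erm_baseline, 2) for augs, acc in ablation.items()}
best = max(ablation, key=ablation.get)


def mean(values):
    """Arithmetic mean of a non-empty list of numbers."""
    return sum(values) / len(values)


with_hed = mean([acc for augs, acc in ablation.items() if "HEDJitter" in augs])
without_hed = mean([acc for augs, acc in ablation.items() if "HEDJitter" not in augs])
```

In this table every configuration that includes HEDJitter scores higher than every configuration that omits it, so the comparison of the two means is not driven by a single outlier.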
It is worth mentioning that our proposed method’s training time is A times that of traditional one-trajectory methods, as we perform A augmentations in parallel per iteration before weight averaging. Importantly, this overhead is only during training; the testing time remains the same as other methods since we use the averaged weights for evaluation.
V Conclusion
Our study presents PathoWAve, a weight averaging methodology for DG in histopathology imaging, achieving significant domain shift mitigation. Utilizing a multi-trajectory training strategy and a tailored mix of histopathology-specific augmentation with other augmentation techniques, PathoWAve outperforms existing models in robustness and accuracy on the Camelyon17 WILDS dataset, demonstrating its potential to improve automated cancer detection. Future research will explore the performance of our method on other real-world benchmarks in addition to Camelyon17, as well as additional histopathology-specific augmentations to further enhance model generalization. This aims to develop more robust and reliable diagnostic tools for cancer detection.
References
- [1] K. Zhou, Z. Liu, Y. Qiao, T. Xiang, and C. C. Loy, “Domain generalization: A survey,” IEEE Transactions on Pattern Analysis and Machine Intelligence, vol. 45, no. 4, pp. 4396–4415, 2022.
- [2] B. Sun and K. Saenko, “Deep coral: Correlation alignment for deep domain adaptation,” in Computer Vision–ECCV 2016 Workshops: Amsterdam, The Netherlands, October 8-10 and 15-16, 2016, Proceedings, Part III 14. Springer, 2016, pp. 443–450.
- [3] M. Arjovsky, L. Bottou, I. Gulrajani, and D. Lopez-Paz, “Invariant risk minimization,” arXiv preprint arXiv:1907.02893, 2019.
- [4] S. Sagawa, P. W. Koh, T. B. Hashimoto, and P. Liang, “Distributionally robust neural networks for group shifts: On the importance of regularization for worst-case generalization,” arXiv preprint arXiv:1911.08731, 2019.
- [5] Y. Shi, J. Seely, P. H. Torr, N. Siddharth, A. Hannun, N. Usunier, and G. Synnaeve, “Gradient matching for domain generalization,” arXiv preprint arXiv:2104.09937, 2021.
- [6] S. Yan, C. Liu, Z. Yu, L. Ju, D. Mahapatra, B. Betz-Stablein, V. Mar, M. Janda, P. Soyer, and Z. Ge, “Prompt-driven latent domain generalization for medical image classification,” arXiv preprint arXiv:2401.03002, 2024.
- [7] M. Noori, M. Cheraghalikhani, A. Bahri, G. A. V. Hakim, D. Osowiechi, I. B. Ayed, and C. Desrosiers, “Tfs-vit: Token-level feature stylization for domain generalization,” Pattern Recognition, vol. 149, p. 110213, 2024.
- [8] E. D. Cubuk, B. Zoph, J. Shlens, and Q. V. Le, “Randaugment: Practical automated data augmentation with a reduced search space,” in Proceedings of the IEEE/CVF conference on computer vision and pattern recognition workshops, 2020, pp. 702–703.
- [9] K. Zhou, Y. Yang, T. Hospedales, and T. Xiang, “Learning to generate novel domains for domain generalization,” in Computer Vision–ECCV 2020: 16th European Conference, Glasgow, UK, August 23–28, 2020, Proceedings, Part XVI 16. Springer, 2020, pp. 561–578.
- [10] M. Scalbert, M. Vakalopoulou, and F. Couzinié-Devy, “Test-time image-to-image translation ensembling improves out-of-distribution generalization in histopathology,” in International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, 2022, pp. 120–129.
- [11] P. Izmailov, D. Podoprikhin, T. Garipov, D. Vetrov, and A. G. Wilson, “Averaging weights leads to wider optima and better generalization,” arXiv preprint arXiv:1803.05407, 2018.
- [12] J. Cha, S. Chun, K. Lee, H.-C. Cho, S. Park, Y. Lee, and S. Park, “Swad: Domain generalization by seeking flat minima,” Advances in Neural Information Processing Systems, vol. 34, pp. 22405–22418, 2021.
- [13] M. Zhang, J. Lucas, J. Ba, and G. E. Hinton, “Lookahead optimizer: k steps forward, 1 step back,” Advances in neural information processing systems, vol. 32, 2019.
- [14] J. Zhang, S. Liu, J. Song, T. Zhu, Z. Xu, and M. Song, “Lookaround optimizer: k steps around, 1 step average,” Advances in Neural Information Processing Systems, vol. 36, 2024.
- [15] R. Entezari, H. Sedghi, O. Saukh, and B. Neyshabur, “The role of permutation invariance in linear mode connectivity of neural networks,” arXiv preprint arXiv:2110.06296, 2021.
- [16] D. Tellez, M. Balkenhol, I. Otte-Höller, R. van de Loo, R. Vogels, P. Bult, C. Wauters, W. Vreuls, S. Mol, N. Karssemeijer et al., “Whole-slide mitosis detection in h&e breast histology using phh3 as a reference to train distilled stain-invariant convolutional networks,” IEEE transactions on medical imaging, vol. 37, no. 9, pp. 2126–2136, 2018.
- [17] P. W. Koh, S. Sagawa, H. Marklund, S. M. Xie, M. Zhang, A. Balsubramani, W. Hu, M. Yasunaga, R. L. Phillips, I. Gao et al., “Wilds: A benchmark of in-the-wild distribution shifts,” in International conference on machine learning. PMLR, 2021, pp. 5637–5664.
- [18] M. Xu, J. Zhang, B. Ni, T. Li, C. Wang, Q. Tian, and W. Zhang, “Adversarial domain adaptation with domain mixup,” in Proceedings of the AAAI conference on artificial intelligence, vol. 34, no. 04, 2020, pp. 6502–6509.
- [19] T. Matsuura and T. Harada, “Domain generalization using a mixture of multiple latent domains,” in Proceedings of the AAAI Conference on Artificial Intelligence, vol. 34, no. 07, 2020, pp. 11749–11756.
- [20] D. Krueger, E. Caballero, J.-H. Jacobsen, A. Zhang, J. Binas, D. Zhang, R. Le Priol, and A. Courville, “Out-of-distribution generalization via risk extrapolation (rex),” in International Conference on Machine Learning. PMLR, 2021, pp. 5815–5826.
- [21] K. Ahuja, E. Caballero, D. Zhang, J.-C. Gagnon-Audet, Y. Bengio, I. Mitliagkas, and I. Rish, “Invariance principle meets information bottleneck for out-of-distribution generalization,” Advances in Neural Information Processing Systems, vol. 34, pp. 3438–3450, 2021.
- [22] H. Yao, Y. Wang, S. Li, L. Zhang, W. Liang, J. Zou, and C. Finn, “Improving out-of-distribution robustness via selective augmentation,” in International Conference on Machine Learning. PMLR, 2022, pp. 25407–25437.
- [23] V. Khamankar, S. Bera, S. Bhattacharya, D. Sen, and P. K. Biswas, “Histopathological image analysis with style-augmented feature domain mixing for improved generalization,” in International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, 2023, pp. 285–294.
- [24] Y. Ganin, E. Ustinova, H. Ajakan, P. Germain, H. Larochelle, F. Laviolette, M. Marchand, and V. Lempitsky, “Domain-adversarial training of neural networks,” Journal of machine learning research, vol. 17, no. 59, pp. 1–35, 2016.
- [25] D. Kim, Y. Yoo, S. Park, J. Kim, and J. Lee, “Selfreg: Self-supervised contrastive regularization for domain generalization,” in Proceedings of the IEEE/CVF International Conference on Computer Vision, 2021, pp. 9619–9628.
- [26] S. Yan, C. Liu, Z. Yu, L. Ju, D. Mahapatra, V. Mar, M. Janda, P. Soyer, and Z. Ge, “Epvt: Environment-aware prompt vision transformer for domain generalization in skin lesion recognition,” in International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, 2023, pp. 249–259.
- [27] R. Yamashita, J. Long, S. Banda, J. Shen, and D. L. Rubin, “Learning domain-agnostic visual representation for computational pathology using medically-irrelevant style transfer augmentation,” IEEE Transactions on Medical Imaging, vol. 40, no. 12, pp. 3945–3954, 2021.