Dwarf

Haozhe Luo

1 Introduction

The importance of interpretability in medical imaging stems from the critical need for transparency and trust in healthcare applications of AI. Historically, medical imaging analysis focused primarily on accuracy, but as AI integration has grown, so has the emphasis on understandable and explainable AI systems. This evolution addresses the inherent complexities of medical data, potential risks of misdiagnosis, and ethical concerns. Current challenges include achieving a balance between interpretability and accuracy, handling diverse and intricate medical datasets, and navigating the legal and ethical ramifications of AI deployment in medical settings. The advancement of explainable AI (XAI) aims to render AI decision-making processes in medical imaging more comprehensible, promoting reliability and enabling healthcare professionals to effectively integrate AI tools into clinical practice. Current XAI methods interpret the output of the model through different means. However, due to the uncertainty and complexity of the patterns learned by deep learning models, it is difficult to translate them into an intuitive interpretation for the user. In this work, we address the "human out of the loop" and "trustworthiness" issues in medical image analysis for interpretability. We incorporate medical professionals into the loop, utilizing their insights to enhance interpretability maps, aiming to align deep learning explanations more closely with medical intuition. This approach improves the relevance and utility of deep learning interpretations in medical diagnostics by leveraging expert feedback.

2 Definition of the problem

The goal is to optimize a multi-label classification model in medical imaging, incorporating the use of cross-attention feature maps to enhance interpretability. The optimization problem can be formulated as follows:

$$\min_{\theta}\mathcal{L}_{total}(\theta)=\mathcal{L}_{cls}(\theta)+(1-\lambda)\mathcal{L}_{atten}(\theta) \tag{1}$$

where:

  • $\theta$ represents the parameters of the model.

  • $\mathcal{L}_{total}(\theta)$ is the total loss function to be minimized.

  • $\mathcal{L}_{cls}(\theta)$ is the loss function associated with the multi-label classification accuracy.

  • $\mathcal{L}_{atten}(\theta)$ corresponds to the loss function for the interpretability, which measures the effectiveness of the cross-attention feature maps in highlighting relevant features.

  • $\lambda$ is a hyperparameter that balances the trade-off between classification accuracy and interpretability.

The objective is to simultaneously enhance the classification performance and the reliability of the interpretability provided by the cross-attention maps. This involves refining the model’s focus on critical areas in the images, which are significant for accurate diagnosis and transparent decision-making processes.
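As a minimal illustration of Eq. (1), the sketch below combines two scalar loss values with the weight (1 − λ). The function name `total_loss`, the assumed range of λ, and the example values are hypothetical and serve only to make the weighting concrete.

```python
def total_loss(cls_loss: float, atten_loss: float, lam: float) -> float:
    """Combine the two terms of Eq. (1): L_cls + (1 - lambda) * L_atten."""
    if not 0.0 <= lam <= 1.0:
        # Assumption: lambda lies in [0, 1]; the text does not state its range.
        raise ValueError("lam is assumed to lie in [0, 1]")
    return cls_loss + (1.0 - lam) * atten_loss


# Hypothetical values: a classification loss of 0.5 and an attention loss of 0.2.
example = total_loss(cls_loss=0.5, atten_loss=0.2, lam=0.25)
```

Under this form, λ = 1 removes the attention term entirely, while λ = 0 gives both terms equal weight.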

3 Literature Review

3.1 General Overview of Explainability Approaches in Deep Learning

Explainable Artificial Intelligence (XAI) has emerged as a multifaceted field aimed at illuminating the inner workings of deep learning models. It encompasses a broad spectrum of strategies designed to make neural networks’ decisions transparent, especially in sensitive sectors such as healthcare. Model-agnostic techniques, such as LIME (Local Interpretable Model-agnostic Explanations) [ribeiro2016should] and SHAP (SHapley Additive exPlanations) [walia2022using], offer insights independent of model architecture, while model-specific methods, such as DeepLIFT (Deep Learning Important FeaTures) and layer-wise relevance propagation, delve into the architecture’s internals. Visual tools like Grad-CAM (Gradient-weighted Class Activation Mapping) [selvaraju2017grad] provide saliency maps to highlight decision-critical regions in image data. Example-based methods, including case-based reasoning [gu2020case, duck2022analogy] and prototype learning [li2021adaptive, xu2020attribute], offer comparative insights by aligning model decisions with known examples. Collectively, these methods strive to bridge the gap between AI’s complexity and the need for human-intelligible explanations, fostering trust and facilitating model validation and improvement.

3.2 Explanation by Saliency

Explanation by Saliency in deep learning offers a visual approach to decipher complex model decisions[selvaraju2017grad, malhi2019explaining, rio2020understanding, lin2020covid], particularly within the realm of medical imaging. The primary strength of this method lies in its ability to produce saliency maps that highlight the areas within the input that significantly influence the model’s output. These visual cues enhance interpretability, allowing for a more intuitive understanding of the model’s focus and considerations. In clinical practice, such explanations can bridge the gap between AI models and human experts, fostering greater trust and collaboration. However, the method is not without its challenges. Saliency maps can sometimes be misinterpreted or might not provide a full explanation for a decision, especially when dealing with the high-dimensional data in medical applications. Additionally, the generation of saliency maps can introduce its own biases or highlight features that, while statistically significant, may lack practical clinical relevance, thus necessitating cautious application and expert oversight.

3.3 Explanation by Feature Attribution

Explanation by Feature Attribution in deep learning is pivotal for understanding complex model decisions[malhi2019explaining, eitel2019testing, wang2021interpretability]. Its primary advantage is enhancing interpretability by pinpointing influential features, fostering model transparency and trust, especially in critical healthcare applications. This method demystifies the model’s logic, facilitating clinical validation and offering insights into the model’s learning process. However, it also faces significant challenges. The complexity of medical data can lead to oversimplification in interpretations, potentially omitting critical nuances. Moreover, there’s a risk of bias in feature selection, where the model might focus on features that are not clinically relevant, leading to misleading interpretations[geirhos2020shortcut].

3.4 Explanation by Text

Explanation by Text in deep learning[zhang2017mdnet, wang2018tienet, li2018hybrid, gale2019producing], especially for medical imaging, plays a crucial role in conveying complex model decisions in an understandable format. Its main advantage lies in its ability to translate intricate data interpretations into readable text, enhancing interpretability and user-friendliness for medical professionals. This approach helps in clarifying the reasoning behind model predictions, making AI tools more transparent and trustworthy in clinical settings. However, it also encounters significant hurdles. The simplification necessary for textual explanations can lead to a loss of detail, potentially overlooking crucial subtleties of medical data[monshi2020deep]. Furthermore, the reliance on language models raises concerns about the accuracy and relevance of these textual interpretations, which might not always align perfectly with clinical realities or model intricacies.

3.5 Explanation by Examples

Explanation by Examples in the context of deep learning for medical imaging[tschandl2019diagnostic, lamy2019explainable, barnett2021interpretable, barata2021improving] is essential for illustrating model decisions through practical demonstrations. Its primary strength lies in offering concrete cases or scenarios where the model’s predictions can be observed and analyzed, providing a tangible and relatable understanding for clinicians. This method helps in contextualizing the AI’s decision-making process, making it more approachable and verifiable in a clinical environment. Nonetheless, this technique faces its own set of challenges. The diversity and complexity of medical cases mean that examples might not always represent the full scope of the model’s capabilities or might focus on atypical scenarios. Additionally, there’s a risk of reinforcing biases if the selected examples are not representative or diverse enough, potentially leading to skewed interpretations and applications.

This integration of XAI into medical imaging marks a crucial step towards transparent, interpretable AI in healthcare. The following sections explore different methodologies applied to XAI in medical images.

4 Datasets

4.1 ChestX-Det Dataset

The ChestX-Det dataset is an expertly curated subset of the public NIH ChestX-ray14 dataset, designed specifically to support the development and evaluation of automated diagnostic models in medical imaging, particularly for chest X-rays. This dataset comprises approximately 3,500 high-quality images across 13 common thoracic disease categories. Each image in the dataset has been annotated with instance-level details (both bounding boxes and masks) by three board-certified radiologists to ensure the accuracy and reliability of the annotations. The ChestX-Det dataset is unique in its provision of detailed instance annotations, as opposed to the more common class labels, thereby enabling more precise training and validation of models designed for detailed localization and identification of thoracic diseases in X-ray images. This dataset serves as a critical resource for advancing research in automated chest X-ray analysis by providing a robust benchmark for evaluating the performance of novel diagnostic algorithms.

4.2 CheXlocalize Dataset

The CheXlocalize dataset serves as a specialized grounding verification subset of the CheXpert dataset, specifically curated for enhancing the validation of localization models in chest X-ray (CXR) interpretation. This dataset includes a subset of CXRs specifically chosen from the CheXpert dataset to focus on the precise localization of thoracic diseases. It comprises detailed annotations that go beyond the typical classification of pathologies, including exact coordinates for disease manifestations.

4.3 Vindr-CXR Dataset

The Vindr-CXR dataset is designed to support the development and evaluation of diagnostic models for chest X-ray imaging. It includes around 18,000 high-resolution images with detailed annotations for various thoracic conditions. Each image is annotated with bounding boxes and segmentation masks by experienced radiologists, ensuring high accuracy. This dataset is essential for training and validating models that focus on the localization and identification of thoracic pathologies and serves as a benchmark for evaluating the performance of automated chest X-ray analysis algorithms.

5 Method

Although the XAI methods mentioned in related work each have their advantages, the significant variability in the features learned by different models means that the areas of focus during prediction cannot be adequately delineated solely by information from partial layers. Furthermore, due to a lack of prior knowledge about classes, these methods are insensitive to the specific features of different categories. In this section, we introduce a prior-knowledge-injected feature map refinement approach capable of enhancing any pre-trained model. By employing a series of fine-tuning steps, this method enables the model to generate more compact, consistent, and precise feature maps for predictions across different classes.

To enhance clinicians’ focus during the diagnostic process, we utilize annotation maps (bounding boxes or segmentation maps) as auxiliary information. This ensures that our model generates a more compact and accurate attention map, thereby improving clinicians’ confidence in the classification results.

5.1 Construction of amendments

To better formulate amendments that help training, we divide each amendment into three parts. The first part is the task assignment: a prompt such as "Refine saliency map". The second part is the disease-specific adjustment: we add the specific disease name to retain the disease-related saliency map preference. The third part is the description. During the training phase, we mix samples that carry a description written by clinicians with samples that carry none (the text input then contains only parts 1 and 2). During the inference phase, we use only parts 1 and 2 to refine the saliency map.
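The three-part structure can be sketched as a small helper. This is only an illustration under stated assumptions: the function name `build_amendment`, the ". " separator, and the example description string are hypothetical and not taken from the authors' code.

```python
from typing import Optional

# Part 1: the task-assignment prompt quoted in the text.
TASK_PROMPT = "Refine saliency map"


def build_amendment(disease: str, description: Optional[str] = None) -> str:
    """Join task prompt, disease name, and optional description.

    Assumption: parts are joined with '. '; the text does not give the format.
    """
    parts = [TASK_PROMPT, disease.strip()]  # parts 1 and 2 are always present
    if description:                         # part 3 is used for some training samples
        parts.append(description.strip())
    return ". ".join(parts)


# Inference-style input: parts 1 and 2 only.
inference_prompt = build_amendment("Cardiomegaly")
# Training-style input with a placeholder (made-up) description.
training_prompt = build_amendment("Cardiomegaly", "enlarged cardiac silhouette")
```

Leaving the description optional mirrors the mixed training setup, where some samples carry a description and others do not.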

To alleviate the workload of clinical practitioners and generate more impactful prompts, we propose the following amendments to optimize the structure of prompts using ChatGPT:

  • Language Style Modification: Transitioning from a formal to an informal style, or vice versa, to better align with the communication preferences of the target audience.

  • Information Quantity Adjustment: Increasing or decreasing the amount of information contained within the prompt to enhance clarity and relevance.

  • Structural Alteration: Modifying the prompt’s structure, such as transforming a question into a command, to streamline interaction and comprehension.

  • Incorporation of Domain-specific Terminology: Adding or modifying key terms related to the specific application scenario to ensure precision and applicability.

These strategies aim to refine the efficacy and applicability of prompts in clinical settings, thereby supporting medical professionals in their decision-making processes.

5.2 Network Architecture

The objective of our study is to refine the attention map by learning from disease-specific instances. We introduce a disease-weighted attention map refinement network (Dwarf), which is based on a Vision-Language Model (VLM). The input image, denoted as $x$, is fed into the image encoder $f_{\theta_{i}}$, resulting in the embedding $v^{I}=f_{\theta_{i}}(x)$. This embedding, along with disease-specific textual information from a frozen text encoder, is processed through cross-attention layers, denoted as $\text{CrossAttention}_{\theta_{\text{ca}}}$. Specifically, we use the embedding $e_{M_{c}}$ from the image as the query and the text embedding $e_{t}$ as the key and value, resulting in the cross-attention output:

$$e_{M_{c}t}=\text{CrossAttention}_{\theta_{\text{ca}}}(e_{M_{c}},e_{t})$$

The final classification output $y$ is generated, and class-wise saliency maps $M_{c}$ for each class $c$ are produced, enhancing the interpretability of the model by highlighting relevant regions in the input image.
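To make the query/key/value roles concrete, here is a minimal single-head sketch of scaled dot-product cross-attention in plain Python, with image tokens as queries and text tokens as keys and values. It is an assumption-laden simplification: the learned projections implied by the cross-attention parameters, multiple heads, and the classification head are all omitted, and the names `softmax` and `cross_attention` are illustrative.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    """Numerically stable softmax over one row of scores."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attention(queries: Matrix, keys: Matrix, values: Matrix) -> Tuple[Matrix, Matrix]:
    """Scaled dot-product attention without learned projections.

    queries: image tokens (n_img x d); keys and values: text tokens (n_txt x d).
    Returns (output, weights), where weights[i][j] is how strongly image
    token i attends to text token j.
    """
    scale = 1.0 / math.sqrt(len(queries[0]))
    weights = [
        softmax([scale * sum(a * b for a, b in zip(q, k)) for k in keys])
        for q in queries
    ]
    out_dim = len(values[0])
    output = [
        [sum(w * v[j] for w, v in zip(w_row, values)) for j in range(out_dim)]
        for w_row in weights
    ]
    return output, weights
```

Each row of `weights` sums to one; reshaping the column that belongs to a disease token back onto the image grid is one possible way to obtain a class-wise map, though the text does not specify how the maps are read out.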

5.3 Finding-related Attention Map Refinement Network

To refine the saliency map with finding-related prior knowledge, we introduce our Dwarf module, as shown in Fig. 2. The overall structure of Dwarf consists of a pretrained Vision-Language Model (VLM), denoted as $f_{\text{vlm}}$, and expert heads $f_{\text{heads}}$.

For each epoch, given an input attention map $M_{c}\in\mathbb{R}^{h\times w}$ of class $c$, we first use a finding-specific head to project the attention map from its original embedding space to the visualization space for clinicians. This yields the segmentation map $M_{c}^{\prime}=f_{\text{head}}(M_{c})$.

5.4 Cyclic Training for Finding-Specific Knowledge Accumulation

To accumulate finding-specific knowledge effectively, we introduce a cyclic training process. The cyclic training mechanism is designed to iteratively refine the network’s understanding and segmentation of specific findings. During each epoch, the process involves:
1. **Data Collection**: Initially, each multi-label classification task is decomposed into multiple single-label tasks. For each finding (e.g., Atelectasis, Cardiomegaly, Consolidation), the corresponding images and their labels are extracted, creating a dataset focused on one specific finding at a time. This allows the model to concentrate on learning the characteristics and segmentation of each individual finding without interference from others.
2. **Segmentation**: For a given finding, an initial segmentation map $M_{c}$ is generated using the corresponding segmentation expert head $f_{\text{head}_{c}}$. This map highlights the regions of interest related to the specific finding.
3. **Classification and Segmentation Feedback Loop**: The initial segmentation map $M_{c}$ and classification outputs are evaluated against the ground truth. The discrepancies between the initial predictions and the ground truth are used to calculate the loss for both classification and segmentation tasks. This loss is then used to update the network parameters.
4. **Iterative Refinement**: This process is repeated cyclically for each finding. The model iteratively updates its parameters to improve the segmentation maps $M_{c}$ and classification accuracy. By focusing on one finding at a time, the network continuously refines its ability to accurately identify and segment specific medical conditions.

By incorporating cyclic training, the network can effectively refine its ability to identify and segment specific medical findings, leading to improved diagnostic performance.
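The data-collection and cycling steps can be sketched as follows. This is a simplified illustration, not the authors' pipeline: it assumes that only positive samples are kept for each finding and that findings are visited in round-robin order, one per epoch; the names `split_by_finding` and `cyclic_schedule` are hypothetical.

```python
from typing import Dict, List, Sequence, Set, Tuple

Sample = Tuple[str, Set[str]]  # (image id, set of findings present in the image)


def split_by_finding(samples: Sequence[Sample], findings: Sequence[str]) -> Dict[str, List[str]]:
    """Decompose a multi-label dataset into one single-finding subset per finding.

    Assumption: each subset keeps only the images that are positive for the finding.
    """
    return {f: [img for img, labels in samples if f in labels] for f in findings}


def cyclic_schedule(findings: Sequence[str], num_epochs: int) -> List[str]:
    """Pick the finding trained in each epoch, cycling through the list.

    Assumption: plain round-robin order; the text does not specify the order.
    """
    return [findings[epoch % len(findings)] for epoch in range(num_epochs)]
```

With three findings and five epochs, `cyclic_schedule` visits the first two findings twice and the third once, which illustrates why a fixed epoch budget is shared among the findings.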

5.5 Distill information to classification model

Besides being supervised by ground truth, Dwarf can also use experts as teachers to impart knowledge about human preferences and disease characteristics to the classification model. As illustrated in Fig. 2, each input image $x$ is fed into the classification model, generating a corresponding saliency map $M_{c}$, which is then upsampled. This upsampled map $M_{c}$ is input into a frozen Dwarf module, producing a revised saliency map $M_{c}^{\prime}$. A Dice loss is calculated between $M_{c}$ and $M_{c}^{\prime}$ to distill knowledge from the Dwarf model into the classification model, transferring crucial information regarding human preferences and disease-specific attributes. We ablate the performance against ground-truth supervision in Tab. 8.
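A rough sketch of this distillation step is given below, using plain nested lists as maps. Several details are assumptions made for illustration: nearest-neighbour upsampling, a plain soft Dice loss, and a frozen refinement module modelled as an arbitrary callable `refine`. Gradient handling is not modelled at all.

```python
from typing import Callable, List

Map2D = List[List[float]]


def upsample_nearest(m: Map2D, factor: int) -> Map2D:
    """Repeat every pixel `factor` times along both axes."""
    return [[v for v in row for _ in range(factor)] for row in m for _ in range(factor)]


def dice_loss(a: Map2D, b: Map2D, eps: float = 1e-6) -> float:
    """Soft Dice loss: 1 - (2 * intersection) / (sum of both maps)."""
    inter = sum(x * y for ra, rb in zip(a, b) for x, y in zip(ra, rb))
    card = sum(x for r in a for x in r) + sum(y for r in b for y in r)
    return 1.0 - (2.0 * inter + eps) / (card + eps)


def distill_loss(student_map: Map2D, refine: Callable[[Map2D], Map2D], factor: int) -> float:
    """Dice loss between the upsampled student map and the refined (teacher) map."""
    upsampled = upsample_nearest(student_map, factor)
    target = refine(upsampled)  # the frozen module acts as a fixed function here
    return dice_loss(upsampled, target)
```

If the refinement module returns its input unchanged, the loss is zero, so the classification model is only pushed where the teacher disagrees with its saliency map.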

5.6 Losses and network initialization

5.6.1 Loss

In our framework, we adopt the cross-entropy loss as the multi-label classification loss and a modified Dice loss for optimizing the attention map.
In medical image analysis, attention maps are usually expected to be sensitive to disease-related markers. Unlike traditional segmentation tasks, the training and validation of attention maps often use only positive samples, which can inadvertently lead the model to overestimate the presence of certain features or conditions, producing false positives. To counteract this, we implement a False Positive Suppression technique that adjusts the model’s scoring function.

The standard metric used is the Soft Dice Score, defined as:

$$\text{Soft Dice Score}=\frac{2\times\text{intersection}+\text{smooth}+\varepsilon}{\text{cardinality}+\text{smooth}+\varepsilon}$$

To mitigate the issue of false positives, we introduce a penalty for false predictions:

$$\text{Soft Dice Score with FP Penalty}=\frac{2\times\text{intersection}+\text{smooth}+\varepsilon}{\text{adjusted cardinality}+\text{smooth}+\varepsilon}$$

where the adjusted cardinality is modified to penalize false positives:

$$\text{adjusted cardinality}=\text{cardinality}+(\text{weight}_{\text{fp}}-1)\times\text{false positives}$$
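The penalized score can be written out in a few lines. The sketch below assumes the usual soft definitions, which the text does not spell out: intersection is the sum of p·g, cardinality is the sum of p plus the sum of g, and false positives is the sum of p·(1 − g), for predictions p and ground truth g given as flattened lists. The default values of `smooth` and `eps` are placeholders.

```python
from typing import Sequence


def soft_dice_fp(pred: Sequence[float], target: Sequence[float],
                 weight_fp: float = 1.0, smooth: float = 0.0, eps: float = 1e-7) -> float:
    """Soft Dice score with a false-positive penalty in the denominator.

    Assumed definitions (not stated in the text):
      intersection    = sum(p * g)
      cardinality     = sum(p) + sum(g)
      false positives = sum(p * (1 - g))
    """
    intersection = sum(p * g for p, g in zip(pred, target))
    cardinality = sum(pred) + sum(target)
    false_positives = sum(p * (1.0 - g) for p, g in zip(pred, target))
    adjusted = cardinality + (weight_fp - 1.0) * false_positives
    return (2.0 * intersection + smooth + eps) / (adjusted + smooth + eps)
```

With the false-positive weight set to 1 the adjusted cardinality equals the plain cardinality and the standard Soft Dice Score is recovered; weights above 1 lower the score whenever the prediction spills outside the annotation.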

5.6.2 Network parameter initialization

The image encoder, text encoder, and cross-attention layers of the pretrained network are based on DeViDe. In addition, we introduce a disease-related segmentation head for each specific disease.

The Identity Enhanced Initialization (IEI) technique is proposed to overcome the shortcomings of random or simplistic initializations that often lead to suboptimal learning paths. IEI focuses on initializing the model parameters in a way that enhances its sensitivity to the structures relevant to specific diseases, steering the learning process towards more effective feature recognition from the onset.

This initialization ensures that the predictions are based not merely on the easiest or most apparent features (shortcut paths) but rather on a comprehensive analysis of the images, facilitated by enhanced parameter settings. The initial parameter settings are thus crucial for the effective training of the model, especially in complex scenarios where accurate segmentation plays a critical role in diagnosis and treatment planning. A qualitative comparison is shown in Fig. 1.

Through these methodologies, the model’s performance in distinguishing and accurately segmenting medical images is significantly improved, reducing false positives and enhancing the reliability of the medical diagnosis process.

Figure 1: With random initialization, the model tends to directly learn shortcut results that always highlight the same area, whereas with IEI, the model can start from the pretrained VLM’s attention to refine its focus.

6 Training Details

With ViT-B as the visual backbone and Med-KEBERT as the textual backbone, we fine-tune on the ChestX-Det dataset [lian2021structure] at an image size of 224. We use the AdamW optimizer with a learning rate of $lr=5\times 10^{-5}$. We train on V100 16G GPUs with a total batch size of 32 for a total of 500 epochs.

Figure 2: Flow chart of fine-tuning the classification model. Our method trains on only a single disease in each epoch, with the disease name as the prompt. For each disease, we add an additional head that maps the original attention to a refined segmentation map.

7 Evaluation

To systematically evaluate our proposed framework, we employed two validation methodologies: segmentation metrics and classification metrics.

7.1 Segmentation Metrics (Grounding Metrics)

7.1.1 Dice Coefficient

The Dice Coefficient assesses the alignment between the predicted regions ($\hat{R}$) and expert annotations ($R$). It is defined as:

$$\text{Dice}=\frac{2|\hat{R}\cap R|}{|\hat{R}|+|R|}$$

This metric emphasizes both precision and recall, capturing spatial overlap effectively. Compared to the Intersection over Union (IoU), defined as:

$$\text{IoU}=\frac{|\hat{R}\cap R|}{|\hat{R}\cup R|}$$

the Dice Coefficient generally yields higher values for partial overlaps and penalizes small mismatches less severely.
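The two metrics are tied by the identity Dice = 2·IoU / (1 + IoU), so Dice is never smaller than IoU for the same pair of regions. The short sketch below computes both on regions represented as sets of pixel coordinates; this representation and the convention of returning 1.0 for two empty regions are choices made here for illustration.

```python
from typing import Set, Tuple

Region = Set[Tuple[int, int]]  # a region as a set of (row, col) pixel coordinates


def dice(pred: Region, ref: Region) -> float:
    """Dice = 2|A ∩ B| / (|A| + |B|)."""
    if not pred and not ref:
        return 1.0  # convention chosen here: two empty regions match perfectly
    return 2.0 * len(pred & ref) / (len(pred) + len(ref))


def iou(pred: Region, ref: Region) -> float:
    """IoU = |A ∩ B| / |A ∪ B|."""
    if not pred and not ref:
        return 1.0  # same convention as above
    return len(pred & ref) / len(pred | ref)
```

For example, two three-pixel regions that share two pixels give an IoU of 0.5 and a Dice of about 0.667.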

7.1.2 Hit-rate Metric

The Hit-rate metric evaluates the extent to which algorithmically identified key areas overlap with expert annotations. A higher Hit-rate indicates better alignment between the model’s focus and expert annotations, highlighting improvements in interpretability.
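The text does not give a formula for the Hit-rate. One common way to operationalize such a metric is the pointing-game convention, in which a sample counts as a hit when the most salient pixel falls inside the expert annotation. The sketch below follows that convention purely as an assumption; the function names are hypothetical and the authors' definition may differ.

```python
from typing import List, Sequence, Tuple

Map2D = List[List[float]]


def argmax_2d(m: Map2D) -> Tuple[int, int]:
    """Return (row, col) of the largest value; the first maximum wins on ties."""
    best_r, best_c = 0, 0
    for r, row in enumerate(m):
        for c, v in enumerate(row):
            if v > m[best_r][best_c]:
                best_r, best_c = r, c
    return best_r, best_c


def hit_rate(saliency_maps: Sequence[Map2D], masks: Sequence[Map2D]) -> float:
    """Fraction of samples whose most salient pixel lies inside the annotation mask.

    Assumption: pointing-game style definition, not confirmed by the text.
    """
    hits = 0
    for saliency, mask in zip(saliency_maps, masks):
        r, c = argmax_2d(saliency)
        if mask[r][c]:
            hits += 1
    return hits / len(saliency_maps)
```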

7.2 Classification Metrics

7.2.1 AUC (Area Under the Curve)

The AUC-ROC measures the model’s ability to distinguish between classes across different thresholds. An AUC value close to 1 indicates excellent performance, while a value around 0.5 suggests performance no better than random chance, demonstrating the model’s discriminative power.
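For a single class, the AUC equals the probability that a randomly chosen positive sample receives a higher score than a randomly chosen negative one, with ties counted as one half. The sketch below computes it directly from that pairwise definition; it is quadratic in the number of samples and meant only for illustration. How per-class values are aggregated in the multi-label setting is not stated in the text.

```python
from typing import Sequence


def roc_auc(scores: Sequence[float], labels: Sequence[int]) -> float:
    """AUC as the fraction of (positive, negative) pairs ranked correctly."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative sample")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))
```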

7.2.2 MCC (Matthews Correlation Coefficient)

The MCC evaluates binary classification quality, providing a balanced score even with unbalanced classes. Values range from -1 (total disagreement) to +1 (perfect prediction), with 0 indicating random guessing. MCC is particularly useful in fields like clinical diagnostics where balanced accuracy is critical.
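The MCC is computed from the four entries of the binary confusion matrix. A minimal sketch follows; returning 0 when the denominator vanishes is a common convention adopted here, not something stated in the text.

```python
import math


def mcc(tp: int, tn: int, fp: int, fn: int) -> float:
    """Matthews correlation coefficient from confusion-matrix counts."""
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    if denom == 0:
        return 0.0  # convention used here when any marginal count is zero
    return (tp * tn - fp * fn) / denom
```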

7.2.3 F1 Score

The F1 Score is the harmonic mean of precision and recall, ranging from 0 to 1. A score of 1 indicates perfect precision and recall. This metric is crucial in scenarios where both false positives and false negatives carry significant costs and performs well under uneven class distribution.
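Written out from confusion-matrix counts, the harmonic mean looks as follows. Returning 0 when there are no true positives is a convention chosen here to avoid division by zero.

```python
def f1_score(tp: int, fp: int, fn: int) -> float:
    """F1 = 2 * precision * recall / (precision + recall)."""
    if tp == 0:
        return 0.0  # convention used here: no true positives gives an F1 of 0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2.0 * precision * recall / (precision + recall)
```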

8 Results

In this section, we evaluate our method with three prominent datasets and compare it with the baselines.

8.1 Finding specific grounding experts

In our derived version, we trained disease-specific experts to mimic human attention on categorized datasets without annotations, as shown in Tab. 1. We consider three versions. The four-finding version includes the common findings Atelectasis, Cardiomegaly, Consolidation, and Effusion. The seven-finding version additionally includes Diffuse Nodule, Emphysema, and Mass; these seven findings have relatively high performance. The full version is the original ChestX-Det dataset with 14 findings.

Table 1: Segmentation Experts for each disease

| Finding | Dice score | Finding | Dice score |
| --- | --- | --- | --- |
| Atelectasis | 0.3031 | Fracture | 0.0755 |
| Calcification | 0.0073 | Mass | 0.4583 |
| Cardiomegaly | 0.8342 | Nodule | 0.0355 |
| Consolidation | 0.4288 | Pleural Thickening | 0.1355 |
| Diffuse Nodule | 0.4441 | Pneumothorax | 0.0707 |
| Effusion | 0.3525 | | |
| Emphysema | 0.4355 | | |
| Fibrosis | 0.1621 | | |

Table 2: Comparison between Dwarf and naive training with CLS Loss and Seg Loss

| Method | Disease number | Dice | AUC | Max Dice | Max AUC | F1/MCC |
| --- | --- | --- | --- | --- | --- | --- |
| CLS Loss only | 4 | 0.1438 | 0.8680 | 0.1438 | 0.8680 | - |
| CLS Loss only | 7 | 0.1903 | 0.8519 | 0.1903 | 0.8519 | - |
| CLS Loss only | 14 | 0.1385 | 0.8090 | 0.1390 | 0.8090 | 0.4857/0.4265 |
| Dwarf | 4 | 0.3854 | 0.8729 | 0.4147 | 0.8871 | - |
| Dwarf | 7 | 0.3492 | 0.8508 | 0.3559 | 0.8717 | 0.6017/0.5201 |
| Dwarf | 14 | 0.1646 | 0.8114 | 0.1805 | 0.8157 | 0.5344/0.4992 |

Table 3: Comparison between Dwarf and naive training with CLS Loss and Seg Loss

| Method | Dataset | Max AUC | Max Dice | F1 score | MCC |
| --- | --- | --- | --- | --- | --- |
| w/o Dwarf | ChestX-Det | 0.8090 | 0.1390 | 0.4857 | 0.4265 |
| Dwarf | ChestX-Det | 0.8157 | 0.1805 | 0.5344 | 0.4992 |
| w/o Dwarf | CheXlocalize | 0.8364 | 0.1191 | 0.6286 | 0.5018 |
| Dwarf | CheXlocalize | 0.8488 | 0.1289 | 0.6364 | 0.5047 |
| w/o Dwarf | Vindr-CXR | 0.7851 | 0.0723 | 0.4520 | 0.3648 |
| Dwarf | Vindr-CXR | 0.7977 | 0.1062 | 0.4591 | 0.3848 |

Table 4: Comparison between cyclic training and direct multi-label Segmentation Loss

| Method | Disease number | Max Dice | Max AUC |
| --- | --- | --- | --- |
| CLS Loss + Seg Loss | 7 | 0.2335 | 0.8532 |
| Dwarf | 7 | 0.3559 | 0.871 |

8.2 Ablations

In this section, we conduct an ablation study to investigate the contribution of the disease-specific head and our cyclic training strategy. With our cyclic training framework, the Dice score of multi-label segmentation improved from 0.2335 to 0.3492 and the max AUC score improved from 0.8532 to 0.8717, as shown in Tab. 4. We also explored the scalability and stability of Dwarf. By comparing Dwarf with naive training using only CLS Loss and Seg Loss, we observed significant improvements across different numbers of diseases. For 7 diseases, the Dice score improved from 0.1903 to 0.3492, and the max AUC score increased from 0.8519 to 0.8717. Similarly, for 4 diseases, the Dice score improved from 0.1438 to 0.3854, and the max AUC score increased from 0.8660 to 0.8871. These results, as shown in Tab. 2, highlight the consistent enhancement provided by Dwarf across different numbers of diseases. Furthermore, we investigate the performance of different prompts (Tab. 6), specifically those involving the disease name and visual cues created by radiologists. Directly using the disease name yields better performance, likely because the pretraining prompt consists only of the disease name. Finally, since our model is trained on only one disease per epoch, each disease may receive insufficient training within a fixed number of epochs. Therefore, we extended the training from 500 epochs to 1000 epochs to explore the scalability of the model’s performance. We found that the Dice score improved from 0.1805 to 0.2302, as shown in Tab. 7.

Table 5: Ablations of disease-specific head

| Method | Max DICE | Max AUC |
| --- | --- | --- |
| Directly optimize | 0.2288 | 0.8663 |
| Introducing disease-specific head | 0.3559 | 0.8732 |

Table 6: Ablations of text input context

| Method | Dataset | Max DICE | Max AUC |
| --- | --- | --- | --- |
| Disease name only | ChestX-Det | 0.1805 | 0.8157 |
| With detailed visual cues | ChestX-Det | 0.1769 | 0.8125 |

Table 7: Ablations of training epochs

| Method | Dataset | Max DICE | Max AUC |
| --- | --- | --- | --- |
| 500 epochs | ChestX-Det | 0.1805 | 0.8157 |
| 1000 epochs | ChestX-Det | 0.2302 | 0.8231 |

Table 8: Ablations of expert supervision

| Method | Disease numbers | DICE | AUC | Max DICE | Max AUC |
| --- | --- | --- | --- | --- | --- |
| CLS Loss only | 7 | 0.1438 | 0.8680 | 0.1438 | 0.8680 |
| Dwarf (expert teachers) | 7 | 0.3171 | 0.8473 | 0.3694 | 0.8757 |
| Dwarf | 7 | 0.3856 | 0.8578 | 0.3911 | 0.8766 |

8.3 Qualitative Result

To clarify the visual explainability of our method, we illustrate the attention maps on the test set of the ChestX-Det dataset. As depicted in Fig. 3, our Dwarf method significantly improves the focus of the classification model’s attention. This enhancement allows the model to more precisely highlight the relevant areas that form the basis for its classification decisions. We also illustrate the attention maps for different findings in Fig. 4.

Figure 3: Qualitative results of training with and without the Dwarf architecture demonstrate that utilizing our Dwarf framework consistently enhances the aggregation of feature maps and provides prior region information.
Figure 4: Qualitative results of our method. Our method can precisely locate the diseases using only attention maps as explanations.
Figure 5: Qualitative results of more findings.

9 Conclusion

In this research, we have developed a two-stage saliency map revision strategy. This approach effectively integrates disease-related knowledge and clinicians’ preferences into the generation of saliency maps. By incorporating this methodology, we are also introducing clinicians into the AI training loop. This inclusion fosters enhanced human-computer interactions, making the AI more intuitive and relevant for medical professionals. This strategy not only improves the accuracy of the AI but also makes it more user-friendly for clinicians, ensuring that their expertise and insights are reflected in the AI’s learning process.