
PromptSync: Bridging Domain Gaps in Vision-Language Models through Class-Aware Prototype Alignment and Discrimination

Anant Khandelwal
Glance AI
[email protected]
Abstract

The potential for zero-shot generalization in vision-language (V-L) models such as CLIP has spurred their widespread adoption in addressing numerous downstream tasks. Previous methods have employed test-time prompt tuning to adapt the model to unseen domains, but they overlooked the issue of imbalanced class distributions. In this study, we explicitly address this problem by employing class-aware prototype alignment weighted by mean class probabilities obtained for the test sample and filtered augmented views. Additionally, we ensure that the class probabilities are as accurate as possible by performing prototype discrimination using contrastive learning. The combination of alignment and discriminative loss serves as a geometric regularizer, preventing the prompt representation from collapsing onto a single class and effectively bridging the distribution gap between the source and test domains. Our method, named PromptSync, synchronizes the prompts for each test sample on both the text and vision branches of the V-L model. In empirical evaluations on the domain generalization benchmark, our method outperforms previous best methods by 2.33% in overall performance, by 1% in base-to-novel generalization, and by 2.84% in cross-dataset transfer tasks.

1 Introduction

Training Vision-Language Models (VLMs) with large-scale image-text pairs is known for imparting robust generalization capabilities across diverse downstream tasks [31, 20, 44, 41, 42, 1]. However, training these models from scratch for each downstream task is very time-consuming. Moreover, the essence of pre-training with a large-scale dataset is lost when the pre-trained model is not generalizable across downstream tasks. This happens because of unexpected changes in data distribution; sensitivity to these shifts leads to a decline in performance [16, 32, 30]. Four techniques are commonly used to tackle this: fine-tuning [28], prompt tuning [47], adapters [13], and LoRA [19]. Among these, prompt tuning is a simple, recent, and widely used technique for foundation models [47, 46, 21, 22, 43]. However, prompt learning/tuning approaches are used during the training phase to learn representative prompts based on the training data for the downstream task. This approach does not specifically address the distribution shift present in the dataset. Recent methods, TPT [35] and PromptAlign [33], adjust the learnable prompt tokens dynamically during testing to enable test-time adaptation and align the context of the test sample with the distribution seen by the model. Specifically, TPT [35] updates the learnable prompt tokens (keeping the model parameters frozen) by minimizing the entropy of the top-$N_k$ confidently predicted samples, acquired through diverse augmented views of the incoming test sample. Additionally, PromptAlign [33] aligns the token distribution of the test sample in the visual branch with the pre-computed statistics of the complete proxy source dataset, irrespective of the fact that one class distribution may have a different mean and variance than the other classes.

In this work, we demonstrate the multi-modal test-time adaptation of prompts. In contrast to PromptAlign, which aligns the distribution of the complete source dataset with the test sample, we propose class-aware prototype alignment to address the distributional shift on a class-wise basis. For instance, in an open world there are 360 different breeds of dogs compared to only 71 for cats, leading to one class having higher variance than the others. For each test sample, we obtain randomly augmented views (for both text and image) that are fed to the model for prompt tuning on both the textual and visual branches. We adapt the learnable prompt tokens by aligning the prototypes of the test sample and its confident augmented views with the pre-computed class prototypes (obtained from the proxy source dataset), weighted by the mean probability of each class obtained from the confident augmented views. Before alignment, we update the prompt tokens on both the text and visual branches using prototype discrimination and then use the updated prompts to align the test sample and augmented views with the class prototypes using the mean class probabilities. This is based on the idea that a prototype vector can capture the complete information of the mean and variance of each class distribution, and hence mitigates class collapse (during test-time adaptation) caused by the high variance of particular classes. Empirical evaluation of our method shows state-of-the-art Top-1 accuracy on three tasks: domain generalization, base-to-novel generalization, and cross-dataset transfer. This validates the effectiveness of our method in enhancing zero-shot generalization. Our contributions can be summarized as follows:

  • We propose a class-aware prototype alignment technique for individual test samples to align the context of each test sample with the source distribution on a class-wise basis, thereby mitigating the effects of distributional shift between classes.

  • We propose class-aware prototype discrimination to discover the class distribution for efficient alignment. Additionally, we propose the offline computation of class prototypes from a proxy source dataset for foundation V-L models.

  • We propose multi-modal test-time prompt tuning for both text and visual branches. Empirical evaluation on base-to-novel generalization, domain generalization, and cross-dataset transfer shows the efficiency of our method over existing methods.

2 Related Work

Vision-Language (V-L) foundation models like CLIP [31] and ALIGN [20] have emerged as robust zero-shot generalizable models. They integrate image and text modalities through pre-training on extensive image-text pairs. However, adapting these models to specific downstream tasks with limited data remains challenging. Recent methods explore prompt tuning in CLIP-like models, treating prompts as continuous learnable vectors and fine-tuning them while keeping the model parameters frozen. CoOp [47] proposed fine-tuning CLIP by learning a set of prompts in the text encoder. CoCoOp [46], an improvement over CoOp, dynamically conditions the text prompts by the image embeddings. MaPLe [21] is a deep prompting baseline that tunes prompts on both text and image branches, further conditioning image prompts on text prompts using a V-L coupling function. However, these approaches necessitate training data for prompt learning, limiting adaptation to novel datasets during test time. Recent approaches like TPT [35] aim to learn prompts exclusively at test time but encounter challenges in handling distribution misalignment between CLIP’s pre-training data and downstream test data. PromptAlign [33] addresses this by introducing token distribution alignment in the image branch. However, it does not account for the potential variance in class distributions. In contrast, our method, inspired by a multi-modal prompting variant [21], actively aligns class prototypes by leveraging a proxy dataset as a substitute for unavailable CLIP pre-training data. To our knowledge, our approach is the first to explicitly address class-aware distribution misalignment in V-L foundational models during test time.

3 Methodology

Revisiting CLIP: Our approach is based on the pre-trained V-L model Contrastive Language-Image Pre-Training (CLIP). It consists of a text encoder $\mathcal{F}_t$ and a visual encoder $\mathcal{F}_v$, with pre-trained parameters $\theta_{\textsc{CLIP}} = \{\theta_t, \theta_v\}$, which map the text and image to vector representations. The input image $\mathbf{X}$ is divided into $M$ patches, and the [CLS] token is prepended to these $M$ patch tokens, which are projected to produce $\tilde{\mathbf{X}}_v = \{\mathbf{e}_{\textsc{[CLS]}}, \mathbf{e}_1, \mathbf{e}_2, \ldots, \mathbf{e}_M\}$, where $e_i$ is the embedding of the corresponding patch token in $\mathbf{X}$.
The image encoder produces the latent visual feature representation $\tilde{\boldsymbol{f}}_v = \mathcal{F}_v(\tilde{\mathbf{X}}_v, \theta_v)$ with transformer blocks from $\tilde{\mathbf{X}}_v$. The class label $y$ is embedded within a text template, such as "a photo of a <CLS>", resulting in $\tilde{\mathbf{X}}_t = \{SOS, \mathbf{t}_1, \mathbf{t}_2, \ldots, \mathbf{t}_L, \mathbf{c}_k, EOS\}$, where SOS and EOS are the start and end token embeddings and $\mathbf{t}_l|_{l=1}^{L}$ and $\mathbf{c}_k$ are the token embeddings corresponding to the text template and the class label, respectively.
Similarly, the text encoder $\mathcal{F}_t$ encodes $\tilde{\mathbf{X}}_t$ with transformer blocks to produce the latent text feature representation $\tilde{\boldsymbol{f}}_t = \mathcal{F}_t(\tilde{\mathbf{X}}_t, \theta_t)$. For zero-shot inference, each text feature for class labels $y = \{1, 2, \ldots, C\}$ is paired with an image feature to compute the similarity score $s_i = \textrm{sim}(\tilde{\boldsymbol{f}}_{t_i} \cdot \tilde{\boldsymbol{f}}_v)$, where $\textrm{sim}(\cdot)$ denotes cosine similarity.
The predicted probability on $\mathbf{X}$ for each $y_i$ is given as $p(y_i|\mathbf{X}) = \frac{e^{\textrm{sim}(\tilde{\boldsymbol{f}}_{t_i} \cdot \tilde{\boldsymbol{f}}_v)/\tau}}{\sum_{j=1}^{C} e^{\textrm{sim}(\tilde{\boldsymbol{f}}_{t_j} \cdot \tilde{\boldsymbol{f}}_v)/\tau}}$, where $\tau$ is the temperature of softmax.
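As an illustration of the zero-shot rule above, the following minimal sketch computes cosine similarities between per-class text features and one image feature, then applies a temperature-scaled softmax. The feature vectors are toy values, the function names are hypothetical, and the temperature `tau=0.01` is an assumed value rather than one stated in this section.

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def zero_shot_probabilities(text_features, image_feature, tau=0.01):
    """Softmax over temperature-scaled similarities, one entry per class.

    tau is an assumed temperature; the section does not state its value.
    """
    logits = [cosine_similarity(f_t, image_feature) / tau for f_t in text_features]
    max_logit = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(logit - max_logit) for logit in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Toy features (illustrative only): three class text features, one image feature.
text_features = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]
image_feature = [0.9, 0.1]

probs = zero_shot_probabilities(text_features, image_feature)
predicted_class = max(range(len(probs)), key=probs.__getitem__)
```

The predicted class is simply the index with the highest probability; a smaller `tau` sharpens the distribution without changing that index.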
Prompt Tuning: CLIP integrates a considerable pool of knowledge derived from its training on millions of image-text pairs characterized by varying degrees of noise. Prompt tuning methods aim to extract the rich features learned by the CLIP model. Recent approaches [46, 47, 21, 3, 43] append extra learnable prompts to the inputs of the image and text encoders while keeping the encoders frozen. Modified input prompts with frozen encoders generate undistorted and rich CLIP features, where prompt tuning tries to map the context to the source distribution, i.e., the CLIP pre-training dataset. In our work, we use a recent multi-modal prompting baseline [21] where prompt tuning is performed on both the text and image encoders. Specifically, the image and text encoders process the inputs $\tilde{\mathbf{X}}_v = \{\mathbf{e}_{\textsc{[CLS]}}, \mathbf{p}_v, \mathbf{e}_1, \mathbf{e}_2, \ldots, \mathbf{e}_M\}$ and $\tilde{\mathbf{X}}_t = \{SOS, \mathbf{p}_t, \mathbf{t}_1, \mathbf{t}_2, \ldots, \mathbf{t}_L, \mathbf{c}_k, EOS\}$, respectively. The learnable prompts $\mathbf{p}_v$ and $\mathbf{p}_t$ represent the $V$ visual and $T$ textual tokens, respectively. We refer to the prompts $\mathbf{p}_t$ and $\mathbf{p}_v$ jointly as $\mathbf{p}$. Our approach is based on deep prompting, as in [21], with text and image prompts at subsequent transformer blocks. We refer the reader to [21] for more details on the baseline architecture.
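The token layout of the prompted inputs can be sketched as follows. This is only a schematic of the first-layer sequences described above, using string placeholders instead of embedding vectors; the helper names are hypothetical, and the deeper-layer prompts of [21] are not modeled.

```python
def build_visual_input(cls_token, visual_prompts, patch_tokens):
    """[CLS] embedding, then the V learnable visual prompts, then the M patch embeddings."""
    return [cls_token, *visual_prompts, *patch_tokens]

def build_text_input(sos, text_prompts, template_tokens, class_token, eos):
    """SOS, then the T learnable text prompts, the L template tokens, the class token, EOS."""
    return [sos, *text_prompts, *template_tokens, class_token, eos]

# Placeholders stand in for embeddings (V = 2, M = 3, T = 2, L = 4 are arbitrary choices).
visual_input = build_visual_input("e_CLS", ["p_v1", "p_v2"], ["e_1", "e_2", "e_3"])
text_input = build_text_input(
    "SOS", ["p_t1", "p_t2"], ["t_1", "t_2", "t_3", "t_4"], "c_k", "EOS"
)
```

Only the prompt entries would be trainable in such a layout; every other token embedding and the encoders themselves stay frozen.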
Test Time Adaptation: Test-time adaptation aims to boost generalization in a zero-shot manner. Two existing methods, test-time prompt tuning (TPT) [35] and PromptAlign [33], provide the model with context customized for each individual test sample in order to extract rich knowledge from CLIP. For both methods, several augmented views $\mathcal{H}(\mathbf{X}_{test})$ are generated from the given test sample $\mathbf{X}_{test}$. The entropy of the class probabilities averaged over the filtered views (selected using a confidence threshold) is then used to update the prompts $\mathbf{p}$ using the following unsupervised objective:

$$\mathcal{L}_{ent} = \underset{\mathbf{p}}{\arg\min} \; -\sum_{i=1}^{C} \tilde{p}_{\mathbf{p}}(y_i|\mathbf{X}_{test}) \log \tilde{p}_{\mathbf{p}}(y_i|\mathbf{X}_{test}) \tag{1}$$

where $\tilde{p}_{\mathbf{p}}(y_i|\mathbf{X}_{test})$ is the vector of class probabilities produced by the model, averaged over the filtered augmented views. Additionally, PromptAlign uses a distribution alignment loss, which aligns the mean and variance of the filtered augmented views of the test sample with the source statistics across layers of the model.
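The entropy objective of Eq. 1 can be sketched as below. The sketch assumes that confidence filtering keeps the fraction of views with the lowest per-view entropy; `keep_fraction` and the function names are hypothetical, and the probabilities are toy values. It only evaluates the loss; in practice the gradient of this quantity with respect to the prompts would drive the update.

```python
import math

def entropy(probs):
    """Shannon entropy of a probability vector (natural log)."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)

def marginal_entropy(view_probs, keep_fraction=0.5):
    """Entropy of the class probabilities averaged over confident views.

    Assumption: a view counts as confident if its own entropy is among the
    lowest keep_fraction of all views.
    """
    ranked = sorted(view_probs, key=entropy)
    n_keep = max(1, int(len(ranked) * keep_fraction))
    confident = ranked[:n_keep]
    num_classes = len(confident[0])
    mean_probs = [sum(v[i] for v in confident) / n_keep for i in range(num_classes)]
    return entropy(mean_probs), mean_probs

# Toy class probabilities for four augmented views of one test sample.
view_probs = [
    [0.90, 0.05, 0.05],
    [0.80, 0.10, 0.10],
    [0.34, 0.33, 0.33],
    [0.40, 0.35, 0.25],
]
loss, mean_probs = marginal_entropy(view_probs, keep_fraction=0.5)
```

In this toy case the two near-uniform views are discarded, so the averaged prediction stays peaked on the first class and the loss is well below the maximum entropy of a three-class distribution.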

Figure 1: Architecture of the proposed PromptSync method for zero-shot generalization in CLIP. During test time, we update the learnable prompts using discrimination and alignment of class prototypes. For a single test example, we generate multiple augmented views and obtain the mean class probabilities after the parameter update with the discriminating loss. The mean class probabilities act as weights in the class-prototype alignment with the filtered augmented views. Gradients are accumulated over multiple iterations before the final update to the learnable prompts.

3.1 Proposed Method: PromptSync

The multi-modal test-time prompt tuning method PromptAlign [33] updates text and visual prompts using an entropy loss and a distribution alignment loss with highly confident augmented views (obtained from a test sample $\mathbf{X}_{\text{test}}$). Although PromptAlign considers the distribution, it does not take into account that the distribution of each class/domain can be entirely different from those of other classes/domains; hence, using the source statistics of mean and variance for distribution alignment can still be suboptimal. Inspired by prototype learning [38] and Extreme-Multi-PatchSSL (EMP-SSL) [39], which establish a prototype/benchmark for each class/sample, we propose class-wise prototype alignment between original and augmented views for both source and test samples. The architecture of PromptSync is shown in Figure 1. We use the parameter update from prototype discrimination to generate the class probabilities for the test sample and its augmented views. We accumulate the average of the gradients of the prototype alignment loss, weighted by class probabilities, over the confident augmented views. The gradient accumulated over multiple iterations is then applied for prompt tuning during test-time adaptation.

3.2 Class-aware Prototype Generation

We generate prototypes for each class for both the text and visual branches. The prototype for each class is computed using the proxy source dataset. A prototype vector is also generated for a test sample $\mathbf{X}_{test}$ and for each of its $N_k$ random views (generated using a set of augmentations $\mathcal{H}$ on $\mathbf{X}_{test}$). Let us denote the features of token $e_i$ of a sample $x \in \{\mathbf{X}_{test} + \mathcal{H}(\mathbf{X}_{test})\}$ at the output of the text encoder and visual encoder as $E_T(x, e_i)$ and $E_V(x, e_i)$, respectively. The prototype for a sample from the text and visual branches is given as:

$$\begin{aligned} h_x^{\{t,v\}} &= \frac{1}{|P|} \sum_{i=1}^{|P|} E_{\{T,V\}}(x, e_i) \\ h_{\textrm{CLS},x}^{v} &= E_V(x, e_{\textrm{CLS}}) \end{aligned} \tag{2}$$

where $|P|$ represents the total number of tokens (learnable and non-learnable, for both text and visual), excluding EOS, SOS, and CLS, and $t, v$ denote the textual and visual branches, respectively. For the proxy source dataset, the class-aware prototype is obtained as:

$$h_{c_k}^{t} = \frac{1}{|\mathcal{D}(c_k)|} \sum_{x \in \mathcal{D}(c_k)} h_x^{t}, \tag{3}$$
$$h_{c_k}^{v} = \frac{1}{|\mathcal{D}(c_k)|} \sum_{x \in \mathcal{D}(c_k)} h_x^{v}, \tag{4}$$
$$h_{\textrm{CLS},c_k}^{v} = \frac{1}{|\mathcal{D}(c_k)|} \sum_{x \in \mathcal{D}(c_k)} h_{\textrm{CLS},x}^{v} \tag{5}$$

where $\mathcal{D}(c_k)$ contains all samples for class $c_k$. The prototypes for augmented views are calculated using the augmented samples for each class $c_k$, denoted as $\mathcal{D}^{aug}(c_k)$, and the corresponding prototypes are denoted as $h_{c_k}^{aug,t}$, $h_{c_k}^{aug,v}$, and $h_{\textrm{CLS},c_k}^{aug,v}$, respectively.
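The prototype computations in Eqs. 2 to 5 amount to two nested means, which the following minimal sketch makes explicit for a single branch. The token features are toy 2-dimensional vectors, the class names and function names are hypothetical, and the special tokens (SOS, EOS, CLS) are assumed to have been removed before averaging.

```python
def mean_vector(vectors):
    """Element-wise mean of a list of equal-length vectors."""
    n = len(vectors)
    dim = len(vectors[0])
    return [sum(v[d] for v in vectors) / n for d in range(dim)]

def sample_prototype(token_features):
    """Eq. 2: average the encoder output features over the |P| tokens of one sample."""
    return mean_vector(token_features)

def class_prototypes(samples):
    """Eqs. 3-5: average the sample prototypes over all samples of each class.

    samples is a list of (class_label, token_features) pairs from the proxy dataset.
    """
    by_class = {}
    for label, token_features in samples:
        by_class.setdefault(label, []).append(sample_prototype(token_features))
    return {label: mean_vector(protos) for label, protos in by_class.items()}

# Toy proxy dataset: two "dog" samples and one "cat" sample, two tokens each.
proxy_samples = [
    ("dog", [[1.0, 0.0], [3.0, 2.0]]),  # sample prototype [2.0, 1.0]
    ("dog", [[0.0, 1.0], [0.0, 3.0]]),  # sample prototype [0.0, 2.0]
    ("cat", [[4.0, 4.0], [2.0, 0.0]]),  # sample prototype [3.0, 2.0]
]
prototypes = class_prototypes(proxy_samples)
```

Because the class prototypes depend only on the proxy source dataset, they can be computed once offline; the same routine applied to augmented samples would give the augmented prototypes.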

3.3 Prototype Discriminating Loss

The discriminating loss is responsible for training the learnable prompts to distinguish the context of samples of one class from that of other classes. This goal is achieved by pushing the class prototype $h_{c_k}^{m}$, $m \in \{t, v\}$, for both text and visual branches away from the prototype $h_{c_j}^{m}$, $m \in \{t, v\}$, of any other class $c_j$, where $c_k \neq c_j$. Likewise, we pull together the prototypes $h_{c_k}^{m}$ and $h_{c_k}^{aug,m}$ of the same class and push away the augmented prototypes of classes $c_j \neq c_k$. In this regard, contrastive learning [7, 14, 8, 23] offers a way to pull together the prototypes of positive pairs and push apart those of negative pairs. Following [34], we propose our discriminating loss $\mathcal{L}_D$, formally expressed as:

$$\mathcal{L}_{pos}(c_k) = \frac{1}{|\mathcal{H}|} \sum_{aug \in \mathcal{H}} e^{\textrm{sim}(h_{c_k}^{m}, h_{c_k}^{aug,m})/\tau} \tag{6}$$
$$\mathcal{L}_{neg}(c_k) = \frac{1}{|\mathcal{H}|} \sum_{aug \in \mathcal{H}} \sum_{c=1, c \neq c_k}^{C} e^{\textrm{sim}(h_{c_k}^{m}, h_{c}^{m})/\tau} + e^{\textrm{sim}(h_{c_k}^{aug,m}, h_{c}^{m})/\tau} + e^{\textrm{sim}(h_{c_k}^{m}, h_{c}^{aug,m})/\tau} \tag{7}$$
$$\mathcal{L}_{D} = -\frac{1}{|m| * C} \sum_{\forall m} \sum_{c=1}^{C} \log \frac{\mathcal{L}_{pos}}{\mathcal{L}_{neg}} \tag{8}$$

where $c_k \in [1, C]$ and $m \in \{t, v\}$, i.e., $|m| = 2$. The prototypes $h_{c_k}^{m}$ and $h_{c_k}^{aug,m}$ additionally contain $h_{\mathrm{CLS},c_k}^{v}$ and $h_{\mathrm{CLS},c_k}^{aug,v}$ when $m = v$. The resulting prompt update $\mathbf{p} \rightarrow \hat{\mathbf{p}}$ (learnable prompt tokens) is obtained after applying the gradients of the discriminating loss. Since the proxy dataset remains the same for all test instances, the updated prompt can be saved and restored for each incoming test sample. We present a study of performance and latency with and without saving these updated prompts in Appendix 10. For the rest of the paper, we describe our method without requiring these updated prompts to be saved.
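As a rough illustration of Eqs. 6 to 8, the following minimal sketch evaluates the discrimination loss on prototypes stored as plain Python lists. It is not the authors' implementation. It assumes that sim is cosine similarity, that all three exponential terms of Eq. 7 sit inside the double sum, and that the temperature is 0.1; the function names and the data layout (`protos`, `aug_protos`) are hypothetical.

```python
import math

def cosine(u, v):
    """Cosine similarity, assumed here for sim(., .)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def discrimination_loss(protos, aug_protos, tau=0.1):
    """Sketch of Eq. 8.

    protos[m][c]        -> prototype h_c^m for modality m and class c
    aug_protos[m][a][c] -> prototype h_c^{aug,m} for augmentation a
    tau                 -> temperature (0.1 is an illustrative value)
    """
    total = 0.0
    terms = 0
    for m in protos:  # m ranges over the modalities, e.g. "t" and "v"
        num_classes = len(protos[m])
        num_augs = len(aug_protos[m])
        for k in range(num_classes):
            pos = 0.0
            neg = 0.0
            for a in range(num_augs):
                h_k = protos[m][k]
                h_k_aug = aug_protos[m][a][k]
                # Eq. 6: pull a class prototype towards its augmented version.
                pos += math.exp(cosine(h_k, h_k_aug) / tau)
                for c in range(num_classes):
                    if c == k:
                        continue
                    h_c = protos[m][c]
                    h_c_aug = aug_protos[m][a][c]
                    # Eq. 7: push it away from every other class.
                    neg += (math.exp(cosine(h_k, h_c) / tau)
                            + math.exp(cosine(h_k_aug, h_c) / tau)
                            + math.exp(cosine(h_k, h_c_aug) / tau))
            pos /= num_augs
            neg /= num_augs
            total += math.log(pos / neg)
            terms += 1  # ends up as |m| * C
    return -total / terms
```

Under these assumptions, well-separated prototypes give a lower loss than prototypes that have collapsed onto one direction; in the fully collapsed two-class case every similarity is 1 and the loss reduces to log 3.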

3.4 Prototype Alignment Loss

Although the loss $\mathcal{L}_{D}$ can effectively separate different classes, it cannot tune the prompt for the test sample, which comes from a different distribution than the source distribution. Hence, we propose prototype alignment of the test sample (and its augmented views) with the class prototypes obtained from the source distribution. We weight the prototype alignment by the probability of the test sample belonging to each class. Let $\tilde{p}_{\hat{\mathbf{p}}}[c]$ (as mentioned in Eq. 1) denote the mean of the probabilities for class $c$ produced with the updated prompt $\hat{\mathbf{p}}$ across the set $F$, which contains the filtered augmented views (those preserved by the confidence selection filter) and the test sample. The amplitude and angle alignment of a sample $x_i$ with the class prototypes, for both the text and visual branches, is calculated as follows:

$$\mathcal{L}'_{amp}(x_i) = \sum_{c=1}^{C} \tilde{p}_{\hat{\mathbf{p}}}[c]\, \| p_{x_i}^{m} - p_{c}^{m} \|^{2} \tag{9}$$

$$\mathcal{L}'_{ang}(x_i) = \sum_{c=1}^{C} \tilde{p}_{\hat{\mathbf{p}}}[c]\, \mathrm{sim}(p_{x_i}^{m},\, p_{c}^{m}) \tag{10}$$

where $m \in \{t, v\}$. However, the MSE loss has a drawback: it gives an equal penalty (e.g., an increase of 0.1 in $\mathcal{L}'_{amp}$) for an increase from 1.2 to 1.3 and for an increase from 1.7 to 1.8. We want to penalise the increase from 1.2 to 1.3 more, since an increase in MSE in the smaller range should be penalised more to preserve the base class performance. Hence we penalise with a logarithm, i.e., we use $\mathcal{L}_{amp} = \log \mathcal{L}'_{amp}$. Similarly, for angle alignment the penalty is applied to $1/\mathcal{L}'_{ang}$.
The updated amplitude and angle alignment losses are $\mathcal{L}_{amp}(x_i) = \log \mathcal{L}'_{amp}$ and $\mathcal{L}_{ang}(x_i) = \log(1/\mathcal{L}'_{ang}) = -\log \mathcal{L}'_{ang}$, respectively. We combine the amplitude and angle losses with equal importance, and hence the prototype alignment loss $\mathcal{L}_{A}$ is given as:

$$\mathcal{L}_{A} = \frac{1}{|F|} \sum_{x_i \in F} \left( \mathcal{L}_{amp}(x_i) + \mathcal{L}_{ang}(x_i) \right) = -\frac{1}{|F|} \sum_{x_i \in F} \log \frac{\mathcal{L}'_{ang}(x_i)}{\mathcal{L}'_{amp}(x_i)} \tag{11}$$
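The alignment loss of Eqs. 9 to 11 can be sketched in a few lines of plain Python. This is a minimal illustration, not the authors' code. It handles a single branch $m$ (how the text and visual branches are combined is not spelled out here), assumes sim is cosine similarity, and assumes both weighted sums are strictly positive so that the logarithms are defined. All function and argument names are hypothetical.

```python
import math

def cosine(u, v):
    """Cosine similarity, assumed here for sim(., .)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def squared_distance(u, v):
    """Squared Euclidean distance ||u - v||^2."""
    return sum((a - b) ** 2 for a, b in zip(u, v))

def alignment_loss(view_feats, class_protos, mean_probs):
    """Sketch of Eq. 11 for one branch m.

    view_feats   -> list of feature vectors p_{x_i}^m, one per x_i in F
    class_protos -> list of class prototypes p_c^m
    mean_probs   -> list of weights p~[c], one per class
    """
    total = 0.0
    for p_x in view_feats:
        # Eq. 9: probability-weighted squared distance to each prototype.
        amp = sum(w * squared_distance(p_x, p_c)
                  for w, p_c in zip(mean_probs, class_protos))
        # Eq. 10: probability-weighted similarity to each prototype.
        ang = sum(w * cosine(p_x, p_c)
                  for w, p_c in zip(mean_probs, class_protos))
        if amp <= 0.0 or ang <= 0.0:
            raise ValueError("this sketch assumes amp > 0 and ang > 0")
        # log(amp) + log(1 / ang), i.e. L_amp(x_i) + L_ang(x_i).
        total += math.log(amp) - math.log(ang)
    return total / len(view_feats)
```

The effect of the logarithm can be checked by hand: moving the raw amplitude term from 1.2 to 1.3 changes its log by about 0.080, whereas moving it from 1.7 to 1.8 changes it by about 0.057, so the same raw increase costs more in the smaller range.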

3.5 Algorithm Details

In order to compute the prototype discriminating loss on the source dataset, we require the pre-training dataset of the CLIP model. However, CLIP was trained on over 400 million image-text pairs, which are not publicly available. Nevertheless, in previous works [2, 33], CLIP has been heavily tuned on the ImageNet [11] dataset to achieve excellent zero-shot performance. Hence, we use ImageNet as the proxy for the source dataset to compute prototypes for each class. These prototypes are computed offline for both the samples and their augmented views, and they are used directly during test-time adaptation. During each iteration of test-time adaptation, the meta-train stage is entered first. Training starts with the prototype discriminating objective $\underset{\boldsymbol{p}}{\arg\min}\ \mathcal{L}_{D}$, and gradients are calculated, resulting in the prompt update $\mathbf{p} \rightarrow \hat{\mathbf{p}}$ ($i^{th}$ iteration). Subsequently, the meta-test stage is executed. Here, the augmented views are first filtered using a confidence threshold over the probabilities predicted with the updated prompts $\hat{\mathbf{p}}$. The mean probabilities $\tilde{p}_{\hat{\mathbf{p}}}$ are computed over $F$ and used as weights in $\mathcal{L}_{A}$.
The model is trained on $F$, and the gradient of the prototype alignment loss, $\nabla_{\mathbf{p}} \mathcal{L}_{A}$, is calculated. We average the gradients over all samples in $F$. Finally, the prompts $\boldsymbol{p}$ are updated using the combined objective $\mathcal{L}_{ent} + \underset{\boldsymbol{p}}{\arg\min}\ \mathcal{L}_{A}$. For $n > 1$ iterations, we accumulate the averaged gradients before the final prompt update.
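The filtering step of the meta-test stage can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' code: confidence is measured by prediction entropy (lower is more confident), a fixed fraction of the augmented views is kept, and the test sample itself is always included in $F$. The function names and the `keep_fraction` parameter are hypothetical.

```python
import math

def entropy(probs):
    """Shannon entropy of one predicted class distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)

def build_filtered_set(test_probs, aug_probs, keep_fraction=0.1):
    """Return F: the test sample plus the most confident augmented views.

    test_probs    -> class probabilities for the original test sample
    aug_probs     -> list of class probabilities, one per augmented view
    keep_fraction -> share of augmented views to keep (assumed value)
    """
    keep = max(1, int(len(aug_probs) * keep_fraction))
    most_confident = sorted(aug_probs, key=entropy)[:keep]
    return [test_probs] + most_confident

def mean_class_probs(filtered):
    """Mean probability per class over F, used as the weights p~[c]."""
    num_classes = len(filtered[0])
    return [sum(p[c] for p in filtered) / len(filtered)
            for c in range(num_classes)]
```

The resulting list of per-class means plays the role of $\tilde{p}_{\hat{\mathbf{p}}}$, so a view that the model is unsure about contributes nothing to the alignment weights.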

| Method | ImageNetV2 | ImageNet-Sketch | ImageNet-A | ImageNet-R | OOD Avg |
|---|---|---|---|---|---|
| CLIP | 60.86 | 46.09 | 47.87 | 73.98 | 57.20 |
| CLIP+TPT | 64.35 | 47.94 | 54.77 | 77.06 | 60.81 |
| CoOp | 64.20 | 47.99 | 49.71 | 75.21 | 59.28 |
| CoOp+TPT | 66.83 | 49.29 | 57.95 | 77.27 | 62.84 |
| CoCoOp | 64.07 | 48.75 | 50.63 | 76.18 | 59.91 |
| CoCoOp+TPT | 64.85 | 48.27 | 58.47 | 78.65 | 62.61 |
| MaPLe | 64.07 | 49.15 | 50.90 | 76.98 | 60.28 |
| MaPLe+TPT | 64.87 | 48.16 | 58.08 | 78.12 | 62.31 |
| PromptAlign | 65.29 | 50.23 | 59.37 | 79.33 | 63.55 |
| PromptSync | 67.54 | 53.42 | 61.92 | 80.64 | 65.88 |

Table 1: Comparison on the domain generalization setting. Prompt tuning methods are trained on ImageNet and evaluated on datasets with domain shifts.

| Method | Camera (Yaw / Pitch / Roll) | Pose (Yaw / Pitch / Roll) | Scale | Texture | Lighting | Worlds |
|---|---|---|---|---|---|---|
| MaPLe | 48.73 / 39.93 / 32.13 | 48.10 / 28.40 / 27.80 | 46.90 | 37.90 | 15.50 | 32.13 |
| MaPLe+TPT | 57.04 / 45.99 / 39.23 | 56.26 / 35.64 / 33.26 | 54.87 | 43.73 | 22.52 | 42.00 |
| PromptAlign | 58.14 / 46.93 / 40.45 | 57.43 / 36.31 / 34.32 | 56.18 | 44.97 | 23.06 | 43.24 |
| PromptSync | 59.84 / 48.54 / 41.92 | 59.72 / 38.84 / 36.64 | 58.12 | 45.98 | 25.02 | 44.84 |

Table 2: Comparison on the domain generalization setting for distribution alignment. MaPLe is trained on ImageNet and evaluated on an OOD dataset, i.e., PUG.

4 Experiments

We have evaluated PromptSync on different benchmark settings (Appendix 9) with different datasets described below:
Datasets: For the domain generalisation setting, we follow PromptAlign [33] and evaluate our method on four out-of-distribution (OOD) variants of ImageNet [11]: ImageNetV2 [32], ImageNet-Sketch [40], ImageNet-A [18] and ImageNet-R [17]. We also evaluate on a recent and challenging benchmark, the Photorealistic Unreal Graphics (PUG) dataset [4], which comprises different textures, sizes, orientations, and backgrounds. For the cross-dataset transfer setting, we follow TPT [35] and evaluate the performance on 10 diverse image classification datasets with varying complexities for visual recognition tasks. These include Caltech101 [12] for generic objects and five fine-grained datasets (spanning images of animals, flowers and transportation): StanfordCars [24], Food101 [5], Flowers102 [27], FGVC-Aircraft [25], and OxfordPets [29]. Moreover, four datasets, namely SUN397 [37], DTD [10], UCF101 [36], and EuroSAT [15], comprise scenes, textures, human actions, and satellite imagery, respectively. For base-to-novel generalisation, we follow [21] and evaluate our method on ImageNet and the 10 image classification datasets.
Baselines: We compared PromptSync with existing few-shot prompt learning methods for CLIP adaptation; these are CoOp [47], CoCoOp [46], TPT [35], and PromptAlign [33]. MaPLe [21] is a multi-modal prompt learning baseline that adapts CLIP by learning prompts on both text and visual branches. TPT [35] and PromptAlign [33] are the test-time prompt tuning methods that tune the prompt for each incoming test sample, achieving state-of-the-art performance in prompt learning.
Implementation Details: We ran all experiments on a single NVIDIA A100 40GB GPU. Following [21], we trained on ImageNet with 16-shot training data selected at random for each class, using 2 prompt tokens for a depth of 3 layers (on the CLIP ViT-B/16 backbone architecture). We optimized the prompts on both the text and visual branches using a single test image. We augmented each test image with 127 different views using random resized crops, background substitution, horizontal flip augmentations, and visual corruption. For text augmentation, we used hyponyms, synonyms, and meronyms from WordNet [26]. Moreover, we generated various text prompts from pre-trained LLMs [6]. Additionally, we randomly masked one of the learnable tokens for 15% of the augmented views. We computed the gradients of the alignment loss for a batch of 128 images, including the original image. During the meta-train stage, we updated the original parameters (using a single iteration) and then optimized the prompts in the meta-test stage by calculating the gradients of the alignment loss w.r.t. the updated parameters, accumulated for a single ($n = 1$) iteration to facilitate a one-to-one comparison with the baselines. We obtained the top 10% most confident predictions of the augmented views based on the lowest entropy. We used the AdamW optimizer with a learning rate $\beta$ of $5 \times 10^{-4}$ for the fine-grained datasets and $0.04$ for the rest of the datasets.

4.1 Domain Generalization

We demonstrate that all test-time adaptation methods exhibit better performance (Table 1) compared to the pre-trained CLIP model, highlighting the advantage of tuning V-L models at test time. PromptSync achieves the highest Top-1 accuracy averaged across all the domains of ImageNet variants. Furthermore, we evaluated the ImageNet-trained model on various out-of-distribution (OOD) datasets and observed consistent improvement in performance compared to existing state-of-the-art (SOTA) approaches. The detailed results for each domain dataset are presented in Tables 1 and 2. This confirms that alignment and discriminative training with augmented views on both the text and visual branches enhance the generalization performance of V-L models like CLIP.

| Dataset | Set | CoOp (IJCV22) | CoCoOp (CVPR22) | ProDA (CVPR22) | MaPLe (CVPR23) | MaPLe + TPT (CVPR23) | PromptAlign (NIPS23) | PromptSync |
|---|---|---|---|---|---|---|---|---|
| Average | Base | 82.38 | 80.47 | 81.56 | 82.24 | 82.16 | 83.19 | 84.17 |
| Average | Novel | 67.96 | 71.69 | 72.30 | 75.09 | 74.95 | 75.88 | 77.17 |
| ImageNet | Base | 76.46 | 75.98 | 75.40 | 76.67 | 77.73 | 78.26 | 79.23 |
| ImageNet | Novel | 66.31 | 70.43 | 70.23 | 70.54 | 72.24 | 72.59 | 73.84 |
| Caltech101 | Base | 97.80 | 97.96 | 98.27 | 98.00 | 98.54 | 98.60 | 98.62 |
| Caltech101 | Novel | 93.27 | 93.81 | 93.23 | 94.27 | 94.29 | 94.50 | 94.67 |
| OxfordPets | Base | 94.47 | 95.20 | 95.43 | 95.43 | 95.23 | 95.38 | 95.44 |
| OxfordPets | Novel | 96.00 | 97.69 | 97.83 | 97.80 | 97.37 | 97.56 | 97.83 |
| Stanford Cars | Base | 75.67 | 70.49 | 74.70 | 72.90 | 74.00 | 75.02 | 76.42 |
| Stanford Cars | Novel | 67.53 | 73.59 | 71.20 | 73.97 | 75.20 | 75.71 | 77.21 |
| Flowers102 | Base | 97.27 | 94.87 | 97.70 | 95.93 | 96.24 | 96.61 | 97.73 |
| Flowers102 | Novel | 67.13 | 71.75 | 68.68 | 72.40 | 72.10 | 72.34 | 73.78 |
| Food101 | Base | 89.37 | 90.70 | 90.30 | 90.70 | 91.13 | 91.63 | 92.39 |
| Food101 | Novel | 88.77 | 91.29 | 88.57 | 92.07 | 92.03 | 92.68 | 92.95 |
| FGVC Aircraft | Base | 39.67 | 33.41 | 36.90 | 37.27 | 34.31 | 37.21 | 40.91 |
| FGVC Aircraft | Novel | 31.23 | 23.71 | 34.13 | 35.53 | 35.81 | 37.27 | 39.31 |
| SUN397 | Base | 80.85 | 79.74 | 78.67 | 80.80 | 81.15 | 81.57 | 84.28 |
| SUN397 | Novel | 68.34 | 76.86 | 76.93 | 78.70 | 79.18 | 79.48 | 83.01 |
| DTD | Base | 79.97 | 77.01 | 80.67 | 80.30 | 82.20 | 82.60 | 83.49 |
| DTD | Novel | 48.60 | 56.00 | 56.48 | 59.23 | 59.91 | 60.55 | 62.03 |
| EuroSAT | Base | 90.10 | 87.49 | 83.90 | 93.63 | 91.02 | 94.10 | 94.63 |
| EuroSAT | Novel | 53.00 | 60.04 | 66.00 | 72.87 | 68.96 | 72.71 | 73.19 |
| UCF101 | Base | 84.53 | 82.33 | 85.23 | 82.97 | 82.23 | 84.11 | 85.75 |
| UCF101 | Novel | 67.37 | 73.45 | 71.97 | 78.57 | 77.34 | 79.30 | 81.29 |

Table 3: Comparison on the base-to-novel generalization setting. PromptSync shows consistent improvement on both base and novel classes over previous methods.

4.2 Base to Novel Generalization

Table 3 presents the detailed performance report of PromptSync on base and novel classes across 11 recognition datasets. On average, our strategy outperforms the previous best method, PromptAlign, by nearly 1% on base classes (84.17 vs. 83.19) and by 1.29% on novel classes (77.17 vs. 75.88). We observe that PromptAlign, based on a distribution alignment strategy, performs better on novel classes in most cases, with an average improvement of 0.79% over the best-performing earlier model. However, this margin of improvement is very low. In contrast, with TPT, the performance drops in some instances, such as for OxfordPets, EuroSAT, and UCF101. This demonstrates that 1) test-source alignment is crucial for prompt tuning, and 2) prompt tuning in the text branch alone is not sufficient for zero-shot generalization. Since distribution alignment does not promote discriminative learning and the entropy loss on the test dataset is noisy, PromptSync performs better thanks to class-aware prototype discrimination and alignment across different augmented views. Averaging the gradients further encourages domain-agnostic prompt tuning on both the text and visual branches. This enhances the zero-shot generalization of the V-L model compared to other state-of-the-art approaches. Moreover, our strategy for prompt tuning does not lose information for base classes.

| Method | Caltech | Pets | Cars | Flowers | Food101 | Aircraft | SUN397 | DTD | EuroSAT | UCF101 | Average |
|---|---|---|---|---|---|---|---|---|---|---|---|
| CLIP | 93.35 | 88.25 | 65.48 | 67.44 | 83.65 | 23.67 | 62.59 | 44.27 | 42.01 | 65.13 | 63.58 |
| CLIP+TPT | 94.16 | 87.79 | 66.87 | 68.98 | 84.67 | 24.78 | 65.50 | 47.75 | 42.44 | 68.04 | 65.10 |
| CoOp | 93.70 | 89.14 | 64.51 | 68.71 | 85.30 | 18.47 | 64.15 | 41.92 | 46.39 | 66.55 | 63.88 |
| CoOp + TPT | 93.15 | 89.48 | 66.77 | 68.48 | 86.48 | 20.51 | 66.06 | 43.32 | 37.73 | 68.91 | 64.08 |
| CoCoOp | 93.79 | 90.46 | 64.90 | 70.85 | 83.97 | 22.29 | 66.89 | 45.45 | 39.23 | 68.44 | 64.63 |
| CoCoOp + TPT | 88.57 | 85.33 | 59.68 | 55.31 | 80.64 | 16.89 | 60.24 | 38.93 | 48.55 | 63.35 | 59.75 |
| ProDA | 86.70 | 88.20 | 60.10 | 77.50 | 80.80 | 22.20 | - | 50.90 | 58.50 | - | 65.62 |
| MaPLe | 93.53 | 90.49 | 65.57 | 72.23 | 86.20 | 24.74 | 67.01 | 46.49 | 48.06 | 68.69 | 66.30 |
| MaPLe+TPT | 93.59 | 90.72 | 66.50 | 72.37 | 86.64 | 24.70 | 67.54 | 45.87 | 47.80 | 69.19 | 66.50 |
| PromptAlign | 94.01 | 90.76 | 68.50 | 72.39 | 86.65 | 24.80 | 67.54 | 47.24 | 47.86 | 69.47 | 66.92 |
| PromptSync | 95.78 | 91.89 | 69.24 | 77.68 | 87.72 | 25.91 | 67.98 | 50.99 | 59.36 | 71.04 | 69.76 |

Table 4: Comparison on the cross-dataset transfer setting. Prompt tuning methods are trained on ImageNet and evaluated on cross-datasets.

Method | Entropy Loss | Alignment Loss | Discriminative Loss | Top-1 Acc.
MaPLe 50.90
MaPLe+TPT \checkmark 58.08
PromptAlign \checkmark 50.85
PromptAlign \checkmark \checkmark 59.37
PromptSync \checkmark 56.67
PromptSync \checkmark \checkmark \checkmark 61.92
Table 5: Ablation Study. Analysis of Alignment, Discriminative and Entropy minimization loss. The average of Top-1 accuracy(%) across three seeds is reported

4.3 Cross-Dataset Transfer

In Table 4, we compare the transfer performance of PromptSync with existing state-of-the-art prompt learning methods across diverse cross-datasets. PromptSync consistently outperforms the previous best method, PromptAlign [33], across all cross-datasets, providing an average improvement of 2.84%. Whereas PromptAlign outperforms the earlier method MaPLe + TPT by a very small margin of 0.42%, our method shows a significant average improvement of 3.26% over MaPLe + TPT. Other methods, CoOp and CoCoOp, on average perform worse than zero-shot CLIP + TPT (except ProDA [45]). This affirms that both text-visual alignment and domain-agnostic parameter updates result in better transfer generalization across cross-datasets in V-L models. As opposed to our method, the previous approaches were not consistent in performance across all datasets, which further affirms the advantage of a domain-agnostic training strategy.
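The reported averages and the 2.84% margin can be reproduced from the per-dataset numbers in Table 4. The short check below is only an arithmetic illustration; the variable names are hypothetical and the values are copied from the PromptSync and PromptAlign rows of the table.

```python
def mean(values):
    """Arithmetic mean of a list of accuracies."""
    return sum(values) / len(values)

# Per-dataset top-1 accuracy (%), in the column order of Table 4.
prompt_sync = [95.78, 91.89, 69.24, 77.68, 87.72,
               25.91, 67.98, 50.99, 59.36, 71.04]
prompt_align = [94.01, 90.76, 68.50, 72.39, 86.65,
                24.80, 67.54, 47.24, 47.86, 69.47]

sync_avg = round(mean(prompt_sync), 2)
align_avg = round(mean(prompt_align), 2)
margin = round(mean(prompt_sync) - mean(prompt_align), 2)

# True only if PromptSync is ahead on every one of the ten datasets.
wins_everywhere = all(s > a for s, a in zip(prompt_sync, prompt_align))

print(sync_avg, align_avg, margin, wins_everywhere)
# prints: 69.76 66.92 2.84 True
```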

| Method | $\mathcal{L}_{amp}$ | $\mathcal{L}_{ang}$ | sub | sum_exp | $\mathcal{L}_{A}$ |
|---|---|---|---|---|---|
| PromptSync | 59.84 | 58.81 | 57.86 | 59.83 | 61.92 |

Table 6: Ablation analysis of alignment loss variants. All results are on the ImageNet-A dataset.

5 Ablation

Class-Aware Prototype Alignment: Table 5 summarizes the comparison between two alignment strategies: distribution alignment of the test sample with the class-agnostic source distribution (PromptAlign), and class-aware prototype alignment (PromptSync). All results are on the ImageNet-A dataset. PromptAlign adopted distribution alignment along with averaged cross-entropy for prompt tuning. In contrast, we perform domain-agnostic parameter updates with class-aware prototype alignment for the test sample. As shown in Table 5, PromptAlign without entropy loss is only as good as vanilla MaPLe. This is because distributional alignment does not promote any discriminative learning in the absence of entropy loss. However, because entropy loss is noisy due to the poor performance of the vanilla zero-shot V-L model, we propose the stronger discriminative loss of class prototype alignment for prompt tuning with source and test samples and their augmented views. PromptSync without entropy loss outperforms its PromptAlign counterpart. This is because the class-aware prototype alignment has both alignment and discriminative properties, thus improving test-time adaptation on its own. With additional signals from the predicted probabilities for each class, the class-aware prototype alignment acts as a geometric regularizer, mitigating class collapse in the prompt representation.
Loss variants: We conducted an ablation study on the amplitude and angle losses for the class-aware prototype alignment objective. Table 6 compares three loss choices: 1) amplitude loss, 2) angle loss, and 3) amplitude + angle loss. Clearly, the combination of amplitude and angle performs better than the other choices. The formulation for the combination of amplitude and angle loss is the same as in Equation 11. We further investigated two other variants that combine the raw terms without taking the log: 1) the difference between amplitude and angle, $\mathcal{L}'_{amp} - \mathcal{L}'_{ang}$ (sub), and 2) the sum of the exponentials of both losses, $\exp(\mathcal{L}'_{amp}) + \exp(1/\mathcal{L}'_{ang})$ (sum_exp). Clearly, the formulation in Equation 11 ($\mathcal{L}_{A}$) performs best among these variants. An ablation on the proxy dataset is given in Appendix 12, and an ablation on performance and latency with and without saving updated prompts is provided in Appendix 10. We also compared the number of augmented views and prompt updates in Appendix 11.
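For reference, the five variants compared in Table 6 differ only in how the two raw terms are combined for a single sample. The helper below is a minimal sketch with a hypothetical name and interface; it takes the raw amplitude and angle terms as positive numbers and returns each combination.

```python
import math

def loss_variants(amp_raw, ang_raw):
    """Combine the raw amplitude and angle terms in the five ways of Table 6.

    amp_raw -> raw amplitude term (Eq. 9), assumed > 0
    ang_raw -> raw angle term (Eq. 10), assumed > 0
    """
    return {
        "L_amp": math.log(amp_raw),                      # amplitude only
        "L_ang": -math.log(ang_raw),                     # angle only
        "sub": amp_raw - ang_raw,                        # no log, difference
        "sum_exp": math.exp(amp_raw) + math.exp(1.0 / ang_raw),
        "L_A": math.log(amp_raw) - math.log(ang_raw),    # per-sample Eq. 11
    }
```

By construction the `L_A` entry equals the sum of the `L_amp` and `L_ang` entries, which is the equal-importance combination described above.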

6 Performance and Latency

The experiments presented in Table 7 (Appendix) compare four methods: MaPLe + TPT, PromptAlign, PromptSync*, and PromptSync. In these experiments, we evaluated the top-1 average accuracy (%) and the latency (in hours for a single prompt update) of each method. Specifically, we investigated PromptSync with and without saving the updated prompt obtained after prototype discrimination; the variant denoted PromptSync* adapts the prompt tokens for test samples after restoring the saved prompt tokens.

The results, as shown in Table 7, include latency measurements represented in hours for a single prompt update, and all evaluations are conducted on the ImageNet-A dataset. Notably, the PromptSync* variant demonstrates a faster processing time compared to the full PromptSync method, with only a marginal drop in performance. This outcome underscores the achieved generalization through prototype alignment. Furthermore, in comparison to previous methods such as MaPLe + TPT and PromptAlign, the PromptSync* variant exhibits only a slight increase in latency (0.03 hours) while still improving overall performance.

7 Sensitivity Comparison

We further compared the sensitivity of our method with that of other state-of-the-art baselines. In the Appendix, Figure 2(a) shows performance during test-time adaptation as the number of views increases. All results are on the ImageNet-A dataset. The performance of PromptAlign and MaPLe + TPT almost plateaus around 64 views, with insignificant improvement beyond that, while PromptSync improves consistently as the number of views increases, with insignificant improvement beyond 128. This indicates the generalizability achieved by our method, since it optimises base CLIP over a larger number of possible shifts in the dataset, resulting in better performance. Figure 2(b) shows the performance comparison as the number of prompt update steps increases. All the methods improve with more steps; however, our method shows better adaptation to the test sample with more steps than PromptAlign and MaPLe + TPT. For an apples-to-apples comparison, we perform a single-step update (128 views) following TPT [35].

8 LAION400M Proxy Dataset Analysis

Given CLIP’s impressive zero-shot performance on ImageNet, we opted for ImageNet as a viable proxy source dataset, in line with prior research [33]. Furthermore, we carried out an ablation study on the alignment strategy using LAION400M, a dataset known to mirror CLIP’s training dataset [9], as the source dataset. We worked with a subset of LAION400M comprising 2.5 million images (2 times the size of ImageNet). The results of this ablation study are shown in Table 8 (Appendix). Notably, performance remains consistent when this subset of LAION400M is used alongside ImageNet. Source class prototypes are computed on the proxy source data to derive the distribution for alignment during test time. As this proxy dataset aligns with the model’s training set, this offline computation remains unchanged despite environmental shifts and needs to be performed only once.

Conclusion

In summary, PromptSync significantly improves zero-shot generalization in vision-language models. Our approach, addressing class dominance and variance, outperforms existing methods by 2.33% overall, with a 1% boost in base-to-novel generalization and 2.84% in cross-dataset transfer on a domain generalization benchmark. This underscores PromptSync’s effectiveness in enhancing the robustness of vision-language models.

References

  • Alayrac et al. [2022] Jean-Baptiste Alayrac, Jeff Donahue, Pauline Luc, Antoine Miech, Iain Barr, Yana Hasson, Karel Lenc, Arthur Mensch, Katherine Millican, Malcolm Reynolds, et al. Flamingo: a visual language model for few-shot learning. Advances in Neural Information Processing Systems, 35:23716–23736, 2022.
  • Bahng et al. [2022a] Hyojin Bahng, Ali Jahanian, Swami Sankaranarayanan, and Phillip Isola. Exploring visual prompts for adapting large-scale models. arXiv preprint arXiv:2203.17274, 2022a.
  • Bahng et al. [2022b] Hyojin Bahng, Ali Jahanian, Swami Sankaranarayanan, and Phillip Isola. Visual prompting: Modifying pixel space to adapt pre-trained models. arXiv preprint arXiv:2203.17274, 3:11–12, 2022b.
  • Bordes et al. [2023] Florian Bordes, Shashank Shekhar, Mark Ibrahim, Diane Bouchacourt, Pascal Vincent, and Ari S Morcos. Pug: Photorealistic and semantically controllable synthetic data for representation learning. arXiv preprint arXiv:2308.03977, 2023.
  • Bossard et al. [2014] Lukas Bossard, Matthieu Guillaumin, and Luc Van Gool. Food-101–mining discriminative components with random forests. In Computer Vision–ECCV 2014: 13th European Conference, Zurich, Switzerland, September 6-12, 2014, Proceedings, Part VI 13, pages 446–461. Springer, 2014.
  • Brown et al. [2020] Tom Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared D Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, et al. Language models are few-shot learners. Advances in neural information processing systems, 33:1877–1901, 2020.
  • Chen et al. [2020a] Ting Chen, Simon Kornblith, Mohammad Norouzi, and Geoffrey Hinton. A simple framework for contrastive learning of visual representations. In International conference on machine learning, pages 1597–1607. PMLR, 2020a.
  • Chen et al. [2020b] Xinlei Chen, Haoqi Fan, Ross Girshick, and Kaiming He. Improved baselines with momentum contrastive learning. arXiv preprint arXiv:2003.04297, 2020b.
  • Cherti et al. [2023] Mehdi Cherti, Romain Beaumont, Ross Wightman, Mitchell Wortsman, Gabriel Ilharco, Cade Gordon, Christoph Schuhmann, Ludwig Schmidt, and Jenia Jitsev. Reproducible scaling laws for contrastive language-image learning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 2818–2829, 2023.
  • Cimpoi et al. [2014] Mircea Cimpoi, Subhransu Maji, Iasonas Kokkinos, Sammy Mohamed, and Andrea Vedaldi. Describing textures in the wild. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 3606–3613, 2014.
  • Deng et al. [2009] Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei. Imagenet: A large-scale hierarchical image database. In 2009 IEEE conference on computer vision and pattern recognition, pages 248–255. IEEE, 2009.
  • Fei-Fei et al. [2004] Li Fei-Fei, Rob Fergus, and Pietro Perona. Learning generative visual models from few training examples: An incremental bayesian approach tested on 101 object categories. In 2004 conference on computer vision and pattern recognition workshop, pages 178–178. IEEE, 2004.
  • He et al. [2021] Junxian He, Chunting Zhou, Xuezhe Ma, Taylor Berg-Kirkpatrick, and Graham Neubig. Towards a unified view of parameter-efficient transfer learning. arXiv preprint arXiv:2110.04366, 2021.
  • He et al. [2020] Kaiming He, Haoqi Fan, Yuxin Wu, Saining Xie, and Ross Girshick. Momentum contrast for unsupervised visual representation learning. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pages 9729–9738, 2020.
  • Helber et al. [2019] Patrick Helber, Benjamin Bischke, Andreas Dengel, and Damian Borth. Eurosat: A novel dataset and deep learning benchmark for land use and land cover classification. IEEE Journal of Selected Topics in Applied Earth Observations and Remote Sensing, 12(7):2217–2226, 2019.
  • Hendrycks and Dietterich [2019] Dan Hendrycks and Thomas Dietterich. Benchmarking neural network robustness to common corruptions and perturbations. arXiv preprint arXiv:1903.12261, 2019.
  • Hendrycks et al. [2021a] Dan Hendrycks, Steven Basart, Norman Mu, Saurav Kadavath, Frank Wang, Evan Dorundo, Rahul Desai, Tyler Zhu, Samyak Parajuli, Mike Guo, et al. The many faces of robustness: A critical analysis of out-of-distribution generalization. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pages 8340–8349, 2021a.
  • Hendrycks et al. [2021b] Dan Hendrycks, Kevin Zhao, Steven Basart, Jacob Steinhardt, and Dawn Song. Natural adversarial examples. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 15262–15271, 2021b.
  • Hu et al. [2021] Edward J Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, and Weizhu Chen. Lora: Low-rank adaptation of large language models. arXiv preprint arXiv:2106.09685, 2021.
  • Jia et al. [2021] Chao Jia, Yinfei Yang, Ye Xia, Yi-Ting Chen, Zarana Parekh, Hieu Pham, Quoc Le, Yun-Hsuan Sung, Zhen Li, and Tom Duerig. Scaling up visual and vision-language representation learning with noisy text supervision. In International conference on machine learning, pages 4904–4916. PMLR, 2021.
  • Khattak et al. [2023a] Muhammad Uzair Khattak, Hanoona Rasheed, Muhammad Maaz, Salman Khan, and Fahad Shahbaz Khan. Maple: Multi-modal prompt learning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 19113–19122, 2023a.
  • Khattak et al. [2023b] Muhammad Uzair Khattak, Syed Talal Wasim, Muzammal Naseer, Salman Khan, Ming-Hsuan Yang, and Fahad Shahbaz Khan. Self-regulating prompts: Foundational model adaptation without forgetting. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pages 15190–15200, 2023b.
  • Khosla et al. [2020] Prannay Khosla, Piotr Teterwak, Chen Wang, Aaron Sarna, Yonglong Tian, Phillip Isola, Aaron Maschinot, Ce Liu, and Dilip Krishnan. Supervised contrastive learning. Advances in neural information processing systems, 33:18661–18673, 2020.
  • Krause et al. [2013] Jonathan Krause, Michael Stark, Jia Deng, and Li Fei-Fei. 3d object representations for fine-grained categorization. In Proceedings of the IEEE international conference on computer vision workshops, pages 554–561, 2013.
  • Maji et al. [2013] Subhransu Maji, Esa Rahtu, Juho Kannala, Matthew Blaschko, and Andrea Vedaldi. Fine-grained visual classification of aircraft. arXiv preprint arXiv:1306.5151, 2013.
  • Miller [1995] George A Miller. Wordnet: a lexical database for english. Communications of the ACM, 38(11):39–41, 1995.
  • Nilsback and Zisserman [2008] Maria-Elena Nilsback and Andrew Zisserman. Automated flower classification over a large number of classes. In 2008 Sixth Indian conference on computer vision, graphics & image processing, pages 722–729. IEEE, 2008.
  • Oquab et al. [2014] Maxime Oquab, Leon Bottou, Ivan Laptev, and Josef Sivic. Learning and transferring mid-level image representations using convolutional neural networks. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 1717–1724, 2014.
  • Parkhi et al. [2012] Omkar M Parkhi, Andrea Vedaldi, Andrew Zisserman, and CV Jawahar. Cats and dogs. In 2012 IEEE conference on computer vision and pattern recognition, pages 3498–3505. IEEE, 2012.
  • Quinonero-Candela et al. [2008] Joaquin Quinonero-Candela, Masashi Sugiyama, Anton Schwaighofer, and Neil D Lawrence. Dataset shift in machine learning. MIT Press, 2008.
  • Radford et al. [2021] Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, et al. Learning transferable visual models from natural language supervision. In International conference on machine learning, pages 8748–8763. PMLR, 2021.
  • Recht et al. [2019] Benjamin Recht, Rebecca Roelofs, Ludwig Schmidt, and Vaishaal Shankar. Do imagenet classifiers generalize to imagenet? In International conference on machine learning, pages 5389–5400. PMLR, 2019.
  • Samadh et al. [2023] Jameel Hassan Abdul Samadh, Hanan Gani, Noor Hazim Hussein, Muhammad Uzair Khattak, Muzammal Naseer, Fahad Khan, and Salman Khan. Align your prompts: Test-time prompting with distribution alignment for zero-shot generalization. In Thirty-seventh Conference on Neural Information Processing Systems, 2023.
  • Sha et al. [2023] Zeyang Sha, Xinlei He, Ning Yu, Michael Backes, and Yang Zhang. Can’t steal? cont-steal! contrastive stealing attacks against image encoders. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 16373–16383, 2023.
  • Shu et al. [2022] Manli Shu, Weili Nie, De-An Huang, Zhiding Yu, Tom Goldstein, Anima Anandkumar, and Chaowei Xiao. Test-time prompt tuning for zero-shot generalization in vision-language models. Advances in Neural Information Processing Systems, 35:14274–14289, 2022.
  • Soomro et al. [2012] Khurram Soomro, Amir Roshan Zamir, and Mubarak Shah. A dataset of 101 human action classes from videos in the wild. Center for Research in Computer Vision, 2(11), 2012.
  • Sun et al. [2020] Yu Sun, Xiaolong Wang, Zhuang Liu, John Miller, Alexei Efros, and Moritz Hardt. Test-time training with self-supervision for generalization under distribution shifts. In International conference on machine learning, pages 9229–9248. PMLR, 2020.
  • Tan et al. [2022] Yue Tan, Guodong Long, Lu Liu, Tianyi Zhou, Qinghua Lu, Jing Jiang, and Chengqi Zhang. Fedproto: Federated prototype learning across heterogeneous clients. In Proceedings of the AAAI Conference on Artificial Intelligence, pages 8432–8440, 2022.
  • Tong et al. [2023] Shengbang Tong, Yubei Chen, Yi Ma, and Yann Lecun. Emp-ssl: Towards self-supervised learning in one training epoch. arXiv preprint arXiv:2304.03977, 2023.
  • Wang et al. [2019] Haohan Wang, Songwei Ge, Zachary Lipton, and Eric P Xing. Learning robust global representations by penalizing local predictive power. Advances in Neural Information Processing Systems, 32, 2019.
  • Yao et al. [2021] Lewei Yao, Runhui Huang, Lu Hou, Guansong Lu, Minzhe Niu, Hang Xu, Xiaodan Liang, Zhenguo Li, Xin Jiang, and Chunjing Xu. Filip: Fine-grained interactive language-image pre-training. arXiv preprint arXiv:2111.07783, 2021.
  • Yuan et al. [2021] Lu Yuan, Dongdong Chen, Yi-Ling Chen, Noel Codella, Xiyang Dai, Jianfeng Gao, Houdong Hu, Xuedong Huang, Boxin Li, Chunyuan Li, et al. Florence: A new foundation model for computer vision. arXiv preprint arXiv:2111.11432, 2021.
  • Zang et al. [2022] Yuhang Zang, Wei Li, Kaiyang Zhou, Chen Huang, and Chen Change Loy. Unified vision and language prompt learning. arXiv preprint arXiv:2210.07225, 2022.
  • Zhai et al. [2022] Xiaohua Zhai, Xiao Wang, Basil Mustafa, Andreas Steiner, Daniel Keysers, Alexander Kolesnikov, and Lucas Beyer. Lit: Zero-shot transfer with locked-image text tuning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 18123–18133, 2022.
  • Zhang et al. [2021] Pan Zhang, Bo Zhang, Ting Zhang, Dong Chen, Yong Wang, and Fang Wen. Prototypical pseudo label denoising and target structure learning for domain adaptive semantic segmentation. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pages 12414–12424, 2021.
  • Zhou et al. [2022a] Kaiyang Zhou, Jingkang Yang, Chen Change Loy, and Ziwei Liu. Conditional prompt learning for vision-language models. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 16816–16825, 2022a.
  • Zhou et al. [2022b] Kaiyang Zhou, Jingkang Yang, Chen Change Loy, and Ziwei Liu. Learning to prompt for vision-language models. International Journal of Computer Vision, 130(9):2337–2348, 2022b.

Supplementary Material

9 Benchmark Settings

| Method | Top-1 Average Accuracy (%) | Latency (hours) |
|---|---|---|
| MaPLe + TPT | 58.08 | 0.41 |
| PromptAlign | 59.37 | 0.46 |
| PromptSync* | 61.88 | 0.49 |
| PromptSync | 61.92 | 0.65 |

Table 7: Performance and latency. Comparison of PromptSync with state-of-the-art baselines and with its variant PromptSync*, which reuses the prompt tokens learned after prototype discrimination without learning them for each incoming test sample.
Figure 2: Sensitivity comparison. (a) Top-1 accuracy improves with the number of augmented views. (b) Top-1 accuracy improves consistently with the number of prompt update steps.
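| Proxy source | Flowers | DTD | Pets | Cars | UCF | Caltech | Food | SUN | Aircraft | Eurosat | Avg |
|---|---|---|---|---|---|---|---|---|---|---|---|
| ImageNet | 77.68 | 50.99 | 91.89 | 69.24 | 71.04 | 95.78 | 87.72 | 67.98 | 25.91 | 59.36 | 69.74 |
| LAION | 77.68 | 51.00 | 91.88 | 69.25 | 71.03 | 95.79 | 87.75 | 68.00 | 25.90 | 59.35 | 69.76 |

Table 8: Performance impact analysis using ImageNet and the LAION400M subset as proxy source datasets.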

Base-to-Novel Generalisation: Following MaPLe [21], we evaluate PromptSync in a zero-shot setting. We split the classes of each dataset into base and novel classes. The model is trained only on the base classes in a few-shot setting and is evaluated on both the base and the novel classes.
Cross-dataset Transfer: We evaluate the ImageNet [11]-trained PromptSync model on other datasets to determine its transfer performance. Following CoCoOp [46], our model is trained on all 1000 ImageNet classes in a few-shot manner.
Domain Generalisation: We evaluate PromptSync on out-of-distribution (OOD) datasets to assess domain generalizability. As in the cross-dataset setting, we evaluate our ImageNet-trained model directly on the OOD datasets, which are described in Section 4.
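
As an illustration of the base-to-novel protocol, the sketch below splits an ordered class list into two groups. It assumes an equal split in which the first half of the classes is treated as base and the rest as novel; the function name `split_base_novel` and the toy class names are hypothetical, and the exact split used in the experiments is not specified in this section.

```python
def split_base_novel(class_names):
    """Split an ordered class list into base (first half) and novel (remainder).

    Assumption: an equal split, with the base group taking the extra class
    when the number of classes is odd.
    """
    midpoint = (len(class_names) + 1) // 2
    return class_names[:midpoint], class_names[midpoint:]


# Toy class list (hypothetical); a real dataset would supply its own classes.
classes = ["daisy", "rose", "tulip", "orchid", "lily"]
base, novel = split_base_novel(classes)

# Few-shot training would see only `base`; evaluation covers both groups.
print("base:", base)
print("novel:", novel)
```

Keeping the two groups disjoint is what makes accuracy on the novel classes a measure of generalization rather than of memorized training classes.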

10 Performance and Latency

Table 7 compares MaPLe + TPT, PromptAlign, PromptSync*, and PromptSync in terms of top-1 average accuracy (%) and latency (in hours for a single prompt update). Specifically, we evaluate PromptSync with and without saving the updated prompt obtained after prototype discrimination; PromptSync* denotes the variant that adapts the prompt tokens for test samples after restoring the saved prompt tokens.

All results in Table 7 are measured on the ImageNet-A dataset, with latency reported in hours for a single prompt update. The PromptSync* variant is notably faster than the full PromptSync method (0.49 vs. 0.65 hours) with only a marginal drop in accuracy (61.88% vs. 61.92%). This outcome suggests that the generalization achieved through prototype alignment carries over when the saved prompt is reused. Compared with previous methods, PromptSync* adds only a small amount of latency (0.08 hours over MaPLe + TPT and 0.03 hours over PromptAlign) while still improving accuracy over both.
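
The trade-off discussed above can be checked with a few lines of arithmetic. The sketch below copies the values from Table 7 and computes the accuracy gain and added latency of PromptSync* relative to each other method; the dictionary `TABLE_7` and the helper `delta` are illustrative names, not part of any released code.

```python
# Values copied from Table 7 (ImageNet-A):
# (top-1 average accuracy in %, latency in hours for a single prompt update).
TABLE_7 = {
    "MaPLe + TPT": (58.08, 0.41),
    "PromptAlign": (59.37, 0.46),
    "PromptSync*": (61.88, 0.49),
    "PromptSync": (61.92, 0.65),
}


def delta(method, reference):
    """Return (accuracy gain, added latency) of `method` relative to `reference`."""
    acc_m, lat_m = TABLE_7[method]
    acc_r, lat_r = TABLE_7[reference]
    # Round to two decimals to hide floating-point noise in the subtraction.
    return round(acc_m - acc_r, 2), round(lat_m - lat_r, 2)


for reference in ("MaPLe + TPT", "PromptAlign", "PromptSync"):
    gain, extra = delta("PromptSync*", reference)
    print(f"PromptSync* vs {reference}: {gain:+.2f} points, {extra:+.2f} h")
```

A negative latency value in the last comparison means PromptSync* is faster than the full PromptSync method.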

11 Sensitivity Comparison

We further compare the sensitivity of our method with that of other state-of-the-art baselines. Figure 2(a) shows test-time adaptation performance as the number of augmented views increases. All results are on the ImageNet-A dataset. The performance of PromptAlign and MaPLe + TPT almost plateaus at around 64 views, with insignificant improvement thereafter, whereas PromptSync improves consistently as the number of views increases and shows insignificant improvement only beyond 128 views. This indicates the generalizability achieved by our method: it optimises base CLIP over a larger number of possible shifts in the dataset, resulting in better performance. Figure 2(b) shows performance as the number of prompt update steps increases. All methods improve with more steps; however, our method adapts better to the test sample with more steps than PromptAlign and MaPLe + TPT. For an apples-to-apples comparison, we perform a single-step update (with 128 views), following TPT [35].

12 LAION400M Proxy Dataset Analysis

Given CLIP’s impressive zero-shot performance on ImageNet, we opted for ImageNet as a viable proxy source dataset, in line with prior research [33]. In addition, we carried out an ablation study on the alignment strategy using LAION400M, a dataset known to mirror CLIP’s training data [9], as the source dataset. For this ablation we worked with a subset of LAION400M comprising 2.5 million images (twice the size of ImageNet). The results of this ablation study are shown in Table 8. Notably, performance is nearly identical with ImageNet and with this LAION400M subset as the proxy source. Source class prototypes are computed on the proxy source data to derive the distribution used for alignment at test time. Because the proxy dataset approximates the model’s training set, this offline computation is unaffected by environmental shifts and needs to be performed only once.
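
The one-time offline computation of source class prototypes can be sketched as follows. This is a minimal illustration that assumes a prototype is the per-class mean of encoder features, which is a common definition but is not spelled out in this section; the function `class_prototypes` and the toy features are hypothetical stand-ins for features extracted from the proxy source dataset.

```python
from collections import defaultdict


def class_prototypes(features, labels):
    """Average the feature vectors of each class to obtain one prototype per class.

    Assumption: prototype = per-class mean of the feature vectors.
    """
    sums = {}
    counts = defaultdict(int)
    for vec, label in zip(features, labels):
        if label not in sums:
            sums[label] = [0.0] * len(vec)
        sums[label] = [s + v for s, v in zip(sums[label], vec)]
        counts[label] += 1
    return {
        label: [s / counts[label] for s in total]
        for label, total in sums.items()
    }


# Toy stand-ins for encoder features of proxy source images (hypothetical values).
proxy_features = [[1.0, 0.0], [3.0, 2.0], [0.0, 4.0]]
proxy_labels = ["cat", "cat", "dog"]

# Computed once offline and then reused unchanged for every test sample.
SOURCE_PROTOTYPES = class_prototypes(proxy_features, proxy_labels)
print(SOURCE_PROTOTYPES)
```

Because the prototypes depend only on the proxy source data, they can be stored on disk and loaded at test time instead of being recomputed for each target domain.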