Unleashing the Power of Meta-tuning for Few-shot Generalization
Through Sparse Interpolated Experts

Shengzhuang Chen    Jihoon Tack    Yunqiao Yang    Yee Whye Teh    Jonathan Richard Schwarz    Ying Wei
Abstract

Recent successes suggest that parameter-efficient fine-tuning of foundation models is becoming the state-of-the-art method for transfer learning in vision, gradually replacing the rich literature of alternatives such as meta-learning. In trying to harness the best of both worlds, meta-tuning introduces a subsequent optimization stage of foundation models but has so far only shown limited success and crucially tends to underperform on out-of-distribution (OOD) tasks. In this paper, we introduce Sparse MetA-Tuning (SMAT), a method inspired by sparse mixture-of-experts approaches and trained to isolate subsets of pre-trained parameters automatically for meta-tuning on each task. SMAT successfully overcomes OOD sensitivity and delivers on the promise of enhancing the transfer abilities of vision foundation models beyond parameter-efficient fine-tuning. We establish new state-of-the-art results on a challenging combination of Meta-Dataset augmented with additional OOD tasks in both zero-shot and gradient-based adaptation settings. In addition, we provide a thorough analysis of the superiority of learned over hand-designed sparsity patterns for sparse expert methods and the pivotal importance of the sparsity level in balancing between in-distribution and out-of-distribution generalization. Our code and models are publicly available.

Machine Learning, ICML

1 Introduction

Refer to caption
Figure 1: Average testing accuracy during meta-training for meta-tuning methods. SMAT yields better ID and OOD results and shows an attractive learning speedup.

The emergence of foundation models (Bommasani et al., 2021) has marked a new chapter in machine learning, with pre-trained models in established domains (e.g., vision or language) becoming virtually indispensable and a vibrant research landscape develo** around the design and training of foundation models for new modalities and problems, ranging from the life sciences (Lin et al., 2023) to spectral data (Hong et al., 2023), time series (Yeh et al., 2023), graphs (Liu et al., 2023a) and combinations thereof in multi-modal systems (Yu et al., 2023). Foundation models have seemingly also signaled the convergence of decades of research on transfer learning (see Zhuang et al. (2020) for a survey) to the simple yet powerful paradigm of full or parameter-efficient fine-tuning (Perez et al., 2018; Hu et al., 2021) of the best foundation model available. As with many breakthroughs in science, this convergence runs contrary to an attractive hypothesis: That the explicit formulation of objectives, algorithms, and optimization procedures targeted directly at downstream performance will result in the best transfer learner (most directly advocated for in the learning to learn, or meta-learning community (Thrun & Pratt, 2012; Finn et al., 2017)). Instead, the field’s general belief has shifted towards self-supervised objectives such as autoregressive losses (e.g. Mikolov et al., 2013) or contrastive learning (e.g. Radford et al., 2021) with large models and big data as the best strategy for generalist models with the potential to transfer to a wide variety of tasks.

In this paper, rather than committing fully to this view, we instead join a nascent group of researchers aiming to find a middle point between both paradigms, harnessing their strengths and aiming to find a synergy. Indeed, the recent popularity of instruction tuning for large language models (Zhang et al., 2023) takes a similar view and has emerged as a promising avenue to not only narrow the gap between pre-training and downstream objectives but also enhance zero-shot generalization of pre-trained models. Similarly, meta-tuning aims to enhance the transferability of foundation models through a secondary meta-learning stage initiated once pre-training has converged. Indeed, existing research in the field of Natural Language Processing (NLP) has substantiated the advantages of meta-tuning over traditional fine-tuning and transfer learning approaches, particularly in zero-shot and few-shot testing scenarios (Gao et al., 2021; Min et al., 2021; Chen et al., 2021).

Despite initial progress made, the exploration of meta-tuning in vision still remains notably limited to date. Hu et al. propose a pre-training \rightarrow meta-training \rightarrow fine-tuning pipeline, dubbed PMF (2022), for enhancing the few-shot learning performance of the resulting model relative to the default pre-training \rightarrow fine-tuning approach. With the design principle of simplicity in mind, PMF meta-trains all parameters in a vision transformer using Prototypical Networks (Snell et al., 2017) starting from a pre-trained initialization, yielding the state-of-the-art performance on popular meta-learning benchmarks such as Meta-Dataset (Triantafillou et al., 2020). Despite such promising reported results, we find that this intuitive approach tends to underperform on downstream few-shot tasks, particularly when testing for out-of-distribution (OOD) tasks (i.e., tasks dissimilar to the ones presented during the meta-training stage).

We hypothesize that this low generalization performance on OOD tasks stems from two major factors. (i) The strong emphasis on learning from small amounts of data using a limited number of optimization steps in meta-learning can lead to algorithms that are “greedy” w.r.t. the distribution of tasks presented, sacrificing more generalizable features for performance on the distribution at hand. This leads to a risk of meta-overfitting, a phenomenon previously observed (e.g. Zintgraf et al., 2019; Yao et al., 2021; Chen et al., 2022). (ii) When meta-tuning tasks are diverse, the default setting of updating all parameters suffers from task interference, making optimization unstable and thereby reducing generalization performance.

In introducing our method, we thus explicitly design core model components to overcome these challenges. We address (i) by taking inspiration from recent work (Ilharco et al., 2022; Panigrahi et al., 2023a; Wortsman et al., 2021), noticing that interpolation between pre-trained and fine-tuned weights leads to a trade-off between ID and OOD generalization performance, with an optimal point usually existing between the extremes. We implement this trade-off through a learned gated interpolation implemented with a sparsity constraint. This particular choice also has the added benefit of addressing (ii) by considering a Mixture-of-Experts inspired approach (with each expert defined through sparse masks), which guarantees expressiveness while alleviating task interference. Finally, sparsity has an additional regularizing effect, further reducing the chance of meta-overfitting and thus counteracting poor OOD generalization observed for standard Meta-Tuning (see Figure 1 for a direct comparison with the aforementioned PMF).

In summary, we propose a reformulation of meta-tuning as a process wherein a hypernetwork undergoes meta-training to select a combination of sparse experts based on few-shot examples, which are subsequently interpolated with the pre-trained model to tailor a powerful foundation model for downstream performance on each specific task. The integration of an interpolation strategy alongside specialized experts not only preserves the pre-trained model’s generalization capabilities but also consolidates the knowledge acquired from all meta-tuned tasks without interference. This synergy contributes significantly to our strong performance across both in-distribution and out-of-distribution few-shot generalization scenarios. Furthermore, we showcase the interpretability on task relationship through the experts selected, and compatibility of our approach with both full fine-tuning and parameter-efficient fine-tuning methods, such as LoRA (Hu et al., 2021).

2 Related Work

Few-shot learning and meta-tuning.    Much of few-shot learning (FSL) relies on extracting transferable prior knowledge from a collection of few-shot training task episodes through meta-learning (Hospedales et al., 2020), which can then be utilized for data-efficient learning on unseen but related downstream FSL tasks at test time. Meta-learned inductive biases may take the form of a model initialization (Finn et al., 2017), a learned metric (Snell et al., 2017), a Bayesian prior (Grant et al., 2018) or an optimization strategy (Li et al., 2017). A particular subdomain of FSL, namely, cross-domain FSL algorithms (Li et al., 2022; Triantafillou et al., 2020; Liu et al., 2021; Bateni et al., 2019), explicitly deals with task-distribution shift between meta-training and -testing. Nevertheless, most architectures used in existing work are limited in scale and without large-scale pre-training. Transitioning into the LLM era,  (Min et al., 2021) first study meta-training a pre-trained LLM on a large collection of few-shot in-context learning tasks. Their results highlight the effectiveness of meta-training on improving few-shot in-context learning generalization of powerful pre-trained transformers; motivating several follow-up works in the field of meta-tuning in NLP (Gao et al., 2021; Min et al., 2021; Chen et al., 2021). In computer vision,  Hu et al. propose the simple pre-training, meta-training then fine-tuning (PMF) pipeline (2022) and achieve SOTA performance. Concurrent to our work, Eustratiadis et al. explore meta-tuning from an orthogonal direction to ours by proposing a neural architecture search algorithm (2024) designed to find the optimal model configuration (e.g., optimal arrangement of adapters) for fine-tuning a meta-tuned model for downstream adaptation.

Sparse mixture-of-experts (MoE)    The key idea of sparse MoE is the selective activation of expert modules, usually MLP layers, for each input token during training and inference, thereby achieving graceful scaling. Earlier MoE methods make discrete expert-to-token assignments through a token-choice router scoring experts and selecting the top-k for each token (Shazeer et al., 2017; Riquelme et al., 2021; Lepikhin et al., 2021; Fedus et al., 2021). Alternatively, methods may choose the top-k scored tokens for each expert (Zhou et al., 2022), use stochastic or fixed routers (Roller et al., 2021; Zuo et al., 2022), and more advanced routing techniques (Lewis et al., 2021; Liu et al., 2023b). However, the discrete nature of the assignment poses a serious challenge to stability in training and optimization (Mustafa et al., 2022; Dai et al., 2022). To this end, more recent works on soft MoE consider a soft relaxation or approximation to the otherwise discrete expert-token assignment (Puigcerver et al., 2023), as well as other MoE works that employ a weighted sum of experts in the parameter space (Muqeeth et al., 2023).

Partitioned meta-learning.    The importance of isolating a subset of parameters with high plasticity for optimization-based meta-learning is well-established. A common feature is thus the partitioning of parameters into a set of shared parameters optimized in the outer loop and a (typically smaller) parameter set that implements task adaptation, thus reducing meta-overfitting and memory usage. Early work in this direction (e.g. Raghu et al., 2019; Oh et al., 2020) rely on partitioning heuristics (such as only updating the last layer) or introduce additional context parameters (Zintgraf et al., 2019) which are concatenated with the input vector.

Since then, an increasing amount of attention has been placed on adapters, i.e. compact, parameter-efficient modules which have been shown to be particularly impactful when fine-tuning foundation models, thus being particularly suitable for Meta-Tuning. Popular adapter types such as FiLM (Perez et al., 2018), LoRA (Hu et al., 2021) as well as various alternatives feature in various episodic-training methods (e.g. Requeima et al., 2019; Triantafillou et al., 2021a; Shysheya et al., 2022; Schwarz et al., 2023).

Most closely related to our method is a line of work utilizing sparsification to isolate and train a subset of parameters for rapid adaptation, thus increasing their plasticity for fine-tuning. Most closely related to our method is the aforementioned MSCN (Schwarz & Teh, 2022), although the authors focus their experiments on a more specialized compression problem and do not address the problem of how to tackle a specific sparsity level. Alternative approaches feature sparsification in the outer loop through magnitude-based pruning (Lee et al., 2021), which, while simple, may overly constrain the representational capacity. Similar to over hyper-network inspired approach (Schwarz et al., 2023) predict sparse masks that index into model weights, although they still rely on second-order meta-learning. Finally, the work in (Von Oswald et al., 2021) presents a first-order method for gradient sparsity, demonstrating the approach in traditional meta-learning as well as Continual Learning.

3 Preliminary: Meta-tuning

Meta-tuning aims to improve the few-show learning performance of a pre-trained model on downstream few-shot testing tasks - usually by directly meta-training the pre-trained model over a collection of labeled few-shot training task episodes  (Hu et al., 2022; Min et al., 2021). Specifically, we assume the availability of a pre-trained model f𝜽𝚙𝚛𝚎subscript𝑓superscript𝜽𝚙𝚛𝚎f_{\bm{\theta}^{\mathtt{pre}}}italic_f start_POSTSUBSCRIPT bold_italic_θ start_POSTSUPERSCRIPT typewriter_pre end_POSTSUPERSCRIPT end_POSTSUBSCRIPT as initialization for meta-training, and a training task distribution 𝒫ID(𝒯)subscript𝒫𝐼𝐷𝒯{\mathcal{P}}_{ID}({\mathcal{T}})caligraphic_P start_POSTSUBSCRIPT italic_I italic_D end_POSTSUBSCRIPT ( caligraphic_T ) from which we may sample fully labeled few-shot training tasks 𝒯i𝒫ID(𝒯)similar-tosubscript𝒯𝑖subscript𝒫𝐼𝐷𝒯{\mathcal{T}}_{i}\sim{\mathcal{P}}_{ID}({\mathcal{T}})caligraphic_T start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ∼ caligraphic_P start_POSTSUBSCRIPT italic_I italic_D end_POSTSUBSCRIPT ( caligraphic_T ). Note that we explicitly denote this as in-distribution (ID𝐼𝐷IDitalic_I italic_D). In particular, in the supervised setting, each training task 𝒯isubscript𝒯𝑖{\mathcal{T}}_{i}caligraphic_T start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT takes the form of 𝒯i:={i,𝒯is,𝒯iq}assignsubscript𝒯𝑖subscript𝑖superscriptsubscript𝒯𝑖𝑠superscriptsubscript𝒯𝑖𝑞{\mathcal{T}}_{i}:=\{{\mathcal{L}}_{i},{\mathcal{T}}_{i}^{s},{\mathcal{T}}_{i}% ^{q}\}caligraphic_T start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT := { caligraphic_L start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , caligraphic_T start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_s end_POSTSUPERSCRIPT , caligraphic_T start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT }, where isubscript𝑖{\mathcal{L}}_{i}caligraphic_L start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT is the task loss to be minimized, 𝒯is:={𝐱i,js,𝐲i,js}j=1Nisassignsuperscriptsubscript𝒯𝑖𝑠superscriptsubscriptsuperscriptsubscript𝐱𝑖𝑗𝑠superscriptsubscript𝐲𝑖𝑗𝑠𝑗1superscriptsubscript𝑁𝑖𝑠{\mathcal{T}}_{i}^{s}:=\{\mathbf{x}_{i,j}^{s},\mathbf{y}_{i,j}^{s}\}_{j=1}^{N_% {i}^{s}}caligraphic_T start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_s end_POSTSUPERSCRIPT := { bold_x start_POSTSUBSCRIPT italic_i , italic_j end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_s end_POSTSUPERSCRIPT , bold_y start_POSTSUBSCRIPT italic_i , italic_j end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_s end_POSTSUPERSCRIPT } start_POSTSUBSCRIPT italic_j = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_N start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_s end_POSTSUPERSCRIPT end_POSTSUPERSCRIPT and 𝒯iq:={𝐱i,jq,𝐲i,jq}j=1Niqassignsuperscriptsubscript𝒯𝑖𝑞superscriptsubscriptsuperscriptsubscript𝐱𝑖𝑗𝑞superscriptsubscript𝐲𝑖𝑗𝑞𝑗1superscriptsubscript𝑁𝑖𝑞{\mathcal{T}}_{i}^{q}\!:=\!\{\mathbf{x}_{i,j}^{q},\mathbf{y}_{i,j}^{q}\}_{j=1}% ^{N_{i}^{q}}caligraphic_T start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT := { bold_x start_POSTSUBSCRIPT italic_i , italic_j end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT , bold_y start_POSTSUBSCRIPT italic_i , italic_j end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT } start_POSTSUBSCRIPT italic_j = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_N start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT end_POSTSUPERSCRIPT are labeled support and query sets of Nissuperscriptsubscript𝑁𝑖𝑠N_{i}^{s}italic_N start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_s end_POSTSUPERSCRIPT, Niqsuperscriptsubscript𝑁𝑖𝑞N_{i}^{q}italic_N start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT input-target pairs, respectively. We use the shorthand notations 𝐗𝐗\mathbf{X}bold_X and 𝐘𝐘\mathbf{Y}bold_Y to represent a set of inputs and labels, respectively. Meta-tuning is then realized through the typical episodic-learning setting familiar from meta-learning, i.e., the direct optimization of 𝜽𝜽\bm{\theta}bold_italic_θ on the few-shot learning objective which considers minimizing the task loss on the query predictions given information of the support set, i.e., 𝜽:=argmin𝜽Θ𝔼𝒫ID[i(f𝜽(𝐗iq,𝒯is),𝐘iq)]assignsuperscript𝜽subscriptargmin𝜽Θsubscript𝔼subscript𝒫𝐼𝐷delimited-[]subscript𝑖subscript𝑓𝜽subscriptsuperscript𝐗𝑞𝑖subscriptsuperscript𝒯𝑠𝑖subscriptsuperscript𝐘𝑞𝑖\bm{\theta}^{*}:=\operatorname*{arg\,min}_{\bm{\theta}\in\Theta}\mathbb{E}_{{% \mathcal{P}}_{ID}}[{\mathcal{L}}_{i}(f_{\bm{\theta}}(\mathbf{X}^{q}_{i},{% \mathcal{T}}^{s}_{i}),\mathbf{Y}^{q}_{i})]bold_italic_θ start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT := start_OPERATOR roman_arg roman_min end_OPERATOR start_POSTSUBSCRIPT bold_italic_θ ∈ roman_Θ end_POSTSUBSCRIPT blackboard_E start_POSTSUBSCRIPT caligraphic_P start_POSTSUBSCRIPT italic_I italic_D end_POSTSUBSCRIPT end_POSTSUBSCRIPT [ caligraphic_L start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ( italic_f start_POSTSUBSCRIPT bold_italic_θ end_POSTSUBSCRIPT ( bold_X start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , caligraphic_T start_POSTSUPERSCRIPT italic_s end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) , bold_Y start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) ], over training task episodes sampled from 𝒫ID(𝒯)subscript𝒫𝐼𝐷𝒯{\mathcal{P}}_{ID}({\mathcal{T}})caligraphic_P start_POSTSUBSCRIPT italic_I italic_D end_POSTSUBSCRIPT ( caligraphic_T ). At test time, we expect to encounter both ID and OOD testing tasks i.e., 𝒯~:={𝒯~is,𝐗~iq}𝒫ID𝒫OODassign~𝒯subscriptsuperscript~𝒯𝑠𝑖superscriptsubscript~𝐗𝑖𝑞similar-tosubscript𝒫𝐼𝐷subscript𝒫𝑂𝑂𝐷\tilde{{\mathcal{T}}}:=\{\tilde{{\mathcal{T}}}^{s}_{i},\tilde{\mathbf{X}}_{i}^% {q}\}\sim{\mathcal{P}}_{ID}\cup{\mathcal{P}}_{OOD}over~ start_ARG caligraphic_T end_ARG := { over~ start_ARG caligraphic_T end_ARG start_POSTSUPERSCRIPT italic_s end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , over~ start_ARG bold_X end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT } ∼ caligraphic_P start_POSTSUBSCRIPT italic_I italic_D end_POSTSUBSCRIPT ∪ caligraphic_P start_POSTSUBSCRIPT italic_O italic_O italic_D end_POSTSUBSCRIPT, where 𝒫OOD𝒫IDsubscript𝒫𝑂𝑂𝐷subscript𝒫𝐼𝐷{\mathcal{P}}_{OOD}\neq{\mathcal{P}}_{ID}caligraphic_P start_POSTSUBSCRIPT italic_O italic_O italic_D end_POSTSUBSCRIPT ≠ caligraphic_P start_POSTSUBSCRIPT italic_I italic_D end_POSTSUBSCRIPT is an unseen OOD task distribution. For each testing task, we evaluate the few-shot generalization performance of the meta-tuned model by predicting query labels. Hence, our objective is to develop a meta-tuning algorithm that enables the meta-tuned 𝜽superscript𝜽\bm{\theta}^{*}bold_italic_θ start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT to attain optimal few-shot generalization performance across both ID and OOD testing tasks.

Refer to caption
Figure 2: Overview of the proposed Sparse MetA-Tuning approach, showing meta-training and inference procedures for a single task 𝒯isubscript𝒯𝑖{\mathcal{T}}_{i}caligraphic_T start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT. SMAT meta-learns a shared knowledge pool \mathcal{M}caligraphic_M consisting of |||\mathcal{M}|| caligraphic_M | sparse interpolated experts characterized by a common, learnable set of dense parameters 𝜽δsuperscript𝜽𝛿\bm{\theta}^{\delta}bold_italic_θ start_POSTSUPERSCRIPT italic_δ end_POSTSUPERSCRIPT and distinct, learnable sets of gating masks {𝒛m}m=1||subscriptsuperscriptsubscript𝒛𝑚𝑚1\{\bm{z}_{m}\}^{|\mathcal{M}|}_{m=1}{ bold_italic_z start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT } start_POSTSUPERSCRIPT | caligraphic_M | end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_m = 1 end_POSTSUBSCRIPT with sparsity constraints. To construct each task-specific model 𝜽isubscript𝜽𝑖\bm{\theta}_{i}bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT for both meta-training and inference, (1) SMAT first combines experts via a weighted-sum with merging weights 𝜶isubscript𝜶𝑖\bm{\alpha}_{i}bold_italic_α start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT generated by a meta-learned hypernetwork h𝜻subscript𝜻h_{\bm{\zeta}}italic_h start_POSTSUBSCRIPT bold_italic_ζ end_POSTSUBSCRIPT based on the task’s support set 𝒯issuperscriptsubscript𝒯𝑖𝑠\mathcal{T}_{i}^{s}caligraphic_T start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_s end_POSTSUPERSCRIPT. (2) The experts are then subsequently combined with the frozen pre-trained model 𝜽𝚙𝚛𝚎superscript𝜽𝚙𝚛𝚎\bm{\theta}^{\mathtt{pre}}bold_italic_θ start_POSTSUPERSCRIPT typewriter_pre end_POSTSUPERSCRIPT to enhance both in-distribution (ID) and out-of-distribution (OOD) generalization performance. Alongside (3) the query prediction loss 𝒯iq𝚌𝚎(𝜽i)subscriptsuperscript𝚌𝚎superscriptsubscript𝒯𝑖𝑞subscript𝜽𝑖{\mathcal{L}}^{\mathtt{ce}}_{\mathcal{T}_{i}^{q}}(\bm{\theta}_{i})caligraphic_L start_POSTSUPERSCRIPT typewriter_ce end_POSTSUPERSCRIPT start_POSTSUBSCRIPT caligraphic_T start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT end_POSTSUBSCRIPT ( bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ), (4) knowledge distillation with task-specific dense teachers 𝒯iq𝚔𝚍(𝜽i,𝜽i𝚝𝚛)subscriptsuperscript𝚔𝚍superscriptsubscript𝒯𝑖𝑞subscript𝜽𝑖superscriptsubscript𝜽𝑖𝚝𝚛{\mathcal{L}}^{\mathtt{kd}}_{\mathcal{T}_{i}^{q}}(\bm{\theta}_{i},\bm{\theta}_% {i}^{\mathtt{tr}})caligraphic_L start_POSTSUPERSCRIPT typewriter_kd end_POSTSUPERSCRIPT start_POSTSUBSCRIPT caligraphic_T start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT end_POSTSUBSCRIPT ( bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT typewriter_tr end_POSTSUPERSCRIPT ) is introduced during meta-training to promote specialization and cooperation of the sparse interpolated experts, ensuring optimization success.

4 SMAT: Sparse MetA-Tuning

4.1 Meta-training

As discussed in Section 1, naively sharing and updating all pre-trained parameters across all tasks in meta-tuning leads to task interference in optimization (Yu et al., 2020; Wang et al., 2020). To address this issue, we instead hypothesize that the solution for each task (ID or OOD) comprises a task-specific mixture of a common pool of knowledge covering a broad range of tasks. The knowledge pool is represented by distinct sets of model parameters (i.e., experts), which can be combined cooperatively as a complementary addition to the pre-trained model to promote systematic generalization. Formally, we assume that each task-specific model 𝜽isubscript𝜽𝑖\bm{\theta}_{i}bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT is derived from aggregating the experts via a task-specific weighted sum in the parameter space:

𝜽i=𝜽𝚙𝚛𝚎+m=1||αi,m𝜽mδ,subscript𝜽𝑖superscript𝜽𝚙𝚛𝚎superscriptsubscript𝑚1subscript𝛼𝑖𝑚subscriptsuperscript𝜽𝛿𝑚\bm{\theta}_{i}=\bm{\theta}^{\mathtt{pre}}+\sum_{m=1}^{|{\mathcal{M}}|}\alpha_% {i,m}\bm{\theta}^{\delta}_{m},bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT = bold_italic_θ start_POSTSUPERSCRIPT typewriter_pre end_POSTSUPERSCRIPT + ∑ start_POSTSUBSCRIPT italic_m = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT | caligraphic_M | end_POSTSUPERSCRIPT italic_α start_POSTSUBSCRIPT italic_i , italic_m end_POSTSUBSCRIPT bold_italic_θ start_POSTSUPERSCRIPT italic_δ end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT , (1)

where 𝜽mδsuperscriptsubscript𝜽𝑚𝛿{\bm{\theta}}_{m}^{\delta}bold_italic_θ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_δ end_POSTSUPERSCRIPT represents the m𝑚mitalic_m-th expert in the pool {\mathcal{M}}caligraphic_M, and 𝜶i:=[αi,1,αi,2,,αi,||]assignsubscript𝜶𝑖subscript𝛼𝑖1subscript𝛼𝑖2subscript𝛼𝑖\bm{\alpha}_{i}:=[\alpha_{i,1},\alpha_{i,2},...,\alpha_{i,|{\mathcal{M}}|}]bold_italic_α start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT := [ italic_α start_POSTSUBSCRIPT italic_i , 1 end_POSTSUBSCRIPT , italic_α start_POSTSUBSCRIPT italic_i , 2 end_POSTSUBSCRIPT , … , italic_α start_POSTSUBSCRIPT italic_i , | caligraphic_M | end_POSTSUBSCRIPT ] are the weights. This way of merging experts in Eqn. (1) is appealing due to its high expressiveness, as supported by recent findings revealing that merging multiple sets of specialized parameters through simple arithmetic can lead to a better overall multi-task solution (Ilharco et al., 2023; Matena & Raffel, 2021). Moreover, the complementary addition of experts to the pre-trained backbone effectively decouples the two for optimization - meta-training {𝜽mδ}||superscriptsubscriptsuperscript𝜽𝛿𝑚\{\bm{\theta}^{\delta}_{m}\}^{|{\mathcal{M}}|}{ bold_italic_θ start_POSTSUPERSCRIPT italic_δ end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT } start_POSTSUPERSCRIPT | caligraphic_M | end_POSTSUPERSCRIPT while kee** 𝜽𝚙𝚛𝚎superscript𝜽𝚙𝚛𝚎\bm{\theta}^{\mathtt{pre}}bold_italic_θ start_POSTSUPERSCRIPT typewriter_pre end_POSTSUPERSCRIPT frozen provides sufficient capacity and preserves pre-trained knowledge.

That said, meta-tuning, in our case, is seen as a procedure for discovering a generalizable selection rule for assigning appropriate experts to tasks. We present an overview of our proposed method in Fig. 2.

Sparsification of experts.    At the core of our method lies the question on where and how to find a set of experts in the model during meta-tuning. Previous studies on partitioned meta-learning, as well as recent works on Mixture-of-Experts, propose a seemingly reasonable solution: Experts are incorporated into a heuristically (and a-priori) chosen subset of modules in the model. For example, the batch-norm (BN) layers in ResNets (Triantafillou et al., 2021b) or the MLP layers in Vision Transformers (ViT) (Puigcerver et al., 2023). Although these approaches based on fixed partitioning are more memory-efficient and generally result in improved performance compared to their non-expert counterparts, they suffer from a significant bias due to their manually crafted selection (i.e., the concrete choice of 𝜽mδsuperscriptsubscript𝜽𝑚𝛿\bm{\theta}_{m}^{\delta}bold_italic_θ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_δ end_POSTSUPERSCRIPT). This bias may be suboptimal when it conflicts with the intricate dynamics of meta-training and, in addition, is not always model-agnostic, e.g., there are no BN layers in ViT.

Thus, instead of specifying experts as prior knowledge, we propose to automatically identify the sparsity patterns in experts through meta-tuning using a maximum likelihood objective while considering sparsity constraints on the experts’ capacity, thus encouraging specialization. Meta-learning sparsity patterns not only shows minimal bias but also allows for differentiation in the architecture of experts, giving our model a wider range of capabilities to better handle various types of distribution shifts. (Lee et al., 2022).

Specifically, we induce a learnable sparsity pattern in the m𝑚mitalic_m-th expert in the form of a sparse reparameterization 𝜽mδ𝒛mdirect-productsubscriptsuperscript𝜽𝛿𝑚subscript𝒛𝑚\bm{\theta}^{\delta}_{m}\odot\bm{z}_{m}bold_italic_θ start_POSTSUPERSCRIPT italic_δ end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT ⊙ bold_italic_z start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT of the dense expert, where 𝒛msubscript𝒛𝑚\bm{z}_{m}bold_italic_z start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT denotes a learnable binary mask with dimension |𝒛m|=|𝜽mδ|subscript𝒛𝑚subscriptsuperscript𝜽𝛿𝑚|\bm{z}_{m}|=|\bm{\theta}^{\delta}_{m}|| bold_italic_z start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT | = | bold_italic_θ start_POSTSUPERSCRIPT italic_δ end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT |, and direct-product\odot denotes element-wise multiplication. To learn binary masks through gradient-based optimization, we take inspiration from (Schwarz & Teh, 2022), and reparameterize the m𝑚mitalic_m-th binary mask using an underlying continuous distribution qϕmsubscript𝑞subscriptbold-italic-ϕ𝑚q_{\bm{\phi}_{m}}italic_q start_POSTSUBSCRIPT bold_italic_ϕ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_POSTSUBSCRIPT with parameters ϕmsubscriptbold-italic-ϕ𝑚\bm{\phi}_{m}bold_italic_ϕ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT. The reparameterization samples, i.e., 𝐬mqϕm(𝐬)similar-tosubscript𝐬𝑚subscript𝑞subscriptbold-italic-ϕ𝑚𝐬\mathbf{s}_{m}\!\sim\!q_{\bm{\phi}_{m}}(\mathbf{s})bold_s start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT ∼ italic_q start_POSTSUBSCRIPT bold_italic_ϕ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_POSTSUBSCRIPT ( bold_s ), can be transformed to have values exactly 00 and 1111 through a hard rectification 𝐳m=g(𝐬m):=min(1,max(0,𝐬m))subscript𝐳𝑚𝑔subscript𝐬𝑚assign10subscript𝐬𝑚\mathbf{z}_{m}\!=\!g(\mathbf{s}_{m})\!:=\!\min(1,\max(0,\mathbf{s}_{m}))bold_z start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT = italic_g ( bold_s start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT ) := roman_min ( 1 , roman_max ( 0 , bold_s start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT ) ). As a result, sparsity in 𝒛msubscript𝒛𝑚\bm{z}_{m}bold_italic_z start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT can be enforced by limiting the probability of 𝒔msubscript𝒔𝑚\bm{s}_{m}bold_italic_s start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT being non-zero which can be straight-forwardly expressed through the CDF Q𝑄Qitalic_Q of qϕmsubscript𝑞subscriptbold-italic-ϕ𝑚q_{\bm{\phi}_{m}}italic_q start_POSTSUBSCRIPT bold_italic_ϕ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_POSTSUBSCRIPT, i.e. 1Qϕm(𝐬0)1subscript𝑄subscriptbold-italic-ϕ𝑚𝐬01-Q_{\bm{\phi}_{m}}(\mathbf{s}\leq 0)1 - italic_Q start_POSTSUBSCRIPT bold_italic_ϕ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_POSTSUBSCRIPT ( bold_s ≤ 0 ). Choosing qϕmsubscript𝑞subscriptbold-italic-ϕ𝑚q_{\bm{\phi}_{m}}italic_q start_POSTSUBSCRIPT bold_italic_ϕ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_POSTSUBSCRIPT as the stretched hard concrete distribution (Louizos et al., 2018) enables both gradient-based optimization through reparameterization as well as analytical evaluation of the CDF.

From sparse experts to sparse interpolated experts.    Though the aforementioned sparse reparameterization eventually reduces parameter counts (e.g., by removing zeros in expert parameters), it leads to a strong increase in the number of parameters compared to its non-sparse counterpart at the beginning of meta-tuning – due to learning 𝒛msubscript𝒛𝑚\bm{z}_{m}bold_italic_z start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT and 𝜽mδsuperscriptsubscript𝜽𝑚𝛿\bm{\theta}_{m}^{\delta}bold_italic_θ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_δ end_POSTSUPERSCRIPT concurrently. Moreover, assigning each expert its own dense underlying parameters 𝜽mδsuperscriptsubscript𝜽𝑚𝛿\bm{\theta}_{m}^{\delta}bold_italic_θ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_δ end_POSTSUPERSCRIPT greatly hinders knowledge transfer among experts, which contradicts the principles of partitioned meta-learning emphasizing collaboration of both task-agnostic and task-specific components.

To tackle these issues, we propose sharing dense parameters across sparse reparameterization of different experts, i.e., 𝜽mδ=𝜽δ,msuperscriptsubscript𝜽𝑚𝛿superscript𝜽𝛿for-all𝑚\bm{\theta}_{m}^{\delta}=\bm{\theta}^{\delta},\forall mbold_italic_θ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_δ end_POSTSUPERSCRIPT = bold_italic_θ start_POSTSUPERSCRIPT italic_δ end_POSTSUPERSCRIPT , ∀ italic_m. By rearranging Eqn. (1) slightly, the experts now essentially become different sparse interpolations between the same pre-trained and meta-tuned models (more details in Appendix D). Although this formulation of MoE may initially appear bold, it is well-supported by recent works (Panigrahi et al., 2023b) which suggest that multiple task-specific optimal points can coexist between the same set of pre-trained and meta-tuned models, offering favorable trade-offs for both in-distribution (ID) and some out-of-distribution (OOD) performance. Therefore, our approach of learning sparse interpolated experts can be seen as an inductive bias that promotes the recovery of these optimal interpolation points through meta-tuning. Favorably, this partitioning allows for knowledge transfer among experts through the overlap** regions in their masks.

Meta-learning expert selection through a hypernetwork.     In theory, there are two possible approaches to achieve task-specific inference for the expert merging scores 𝜶isubscript𝜶𝑖\bm{\alpha}_{i}bold_italic_α start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT. The first one is a meta-learned merging score initialization 𝜶0subscript𝜶0\bm{\alpha}_{0}bold_italic_α start_POSTSUBSCRIPT 0 end_POSTSUBSCRIPT combined with inner-loop gradient-descent on the task support set; the second involves a meta-learned global hypernetwork, h𝜻(𝒯is)subscript𝜻superscriptsubscript𝒯𝑖𝑠h_{\bm{\zeta}}({\mathcal{T}}_{i}^{s})italic_h start_POSTSUBSCRIPT bold_italic_ζ end_POSTSUBSCRIPT ( caligraphic_T start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_s end_POSTSUPERSCRIPT ), parameterized by 𝜻𝜻\bm{\zeta}bold_italic_ζ, that directly outputs 𝜶isubscript𝜶𝑖\bm{\alpha}_{i}bold_italic_α start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT conditioned on the task support set 𝒯issuperscriptsubscript𝒯𝑖𝑠{\mathcal{T}}_{i}^{s}caligraphic_T start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_s end_POSTSUPERSCRIPT. We opt for the latter approach as it scales better with larger model sizes. For implementation, we use the pre-trained model f𝜽𝚙𝚛𝚎subscript𝑓superscript𝜽𝚙𝚛𝚎f_{\bm{\theta}^{\mathtt{pre}}}italic_f start_POSTSUBSCRIPT bold_italic_θ start_POSTSUPERSCRIPT typewriter_pre end_POSTSUPERSCRIPT end_POSTSUBSCRIPT to encode each support image 𝐱i,js𝒯issuperscriptsubscript𝐱𝑖𝑗𝑠superscriptsubscript𝒯𝑖𝑠\mathbf{x}_{i,j}^{s}\in\mathcal{T}_{i}^{s}bold_x start_POSTSUBSCRIPT italic_i , italic_j end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_s end_POSTSUPERSCRIPT ∈ caligraphic_T start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_s end_POSTSUPERSCRIPT into a vector embedding f𝜽𝚙𝚛𝚎(𝐱i,js)subscript𝑓superscript𝜽𝚙𝚛𝚎superscriptsubscript𝐱𝑖𝑗𝑠f_{\bm{\theta}^{\mathtt{pre}}}(\mathbf{x}_{i,j}^{s})italic_f start_POSTSUBSCRIPT bold_italic_θ start_POSTSUPERSCRIPT typewriter_pre end_POSTSUPERSCRIPT end_POSTSUBSCRIPT ( bold_x start_POSTSUBSCRIPT italic_i , italic_j end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_s end_POSTSUPERSCRIPT ). The support set embeddings are aggregated into class prototypes which are then concatenated into a sequence and fed into a single trainable transformer block to obtain 𝜶i||superscriptsubscript𝜶𝑖superscript\bm{\alpha}_{i}^{\prime}\in\mathbb{R}^{|{\mathcal{M}}|}bold_italic_α start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ∈ blackboard_R start_POSTSUPERSCRIPT | caligraphic_M | end_POSTSUPERSCRIPT as the output. We treat 𝜶isuperscriptsubscript𝜶𝑖\bm{\alpha}_{i}^{\prime}bold_italic_α start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT as the logits for activating the experts. Instead of choosing the top-k activated ones, which can cause training instability issues and, more importantly, restrict the number of experts per task, we employ the Gumbel-Sigmoid trick (Jang et al., 2017) to sample a soft activation value (0,1)absent01\in(0,1)∈ ( 0 , 1 ) for each expert, followed by normalization to obtain 𝜶i(0,1)||subscript𝜶𝑖superscript01\bm{\alpha}_{i}\in(0,1)^{|{\mathcal{M}}|}bold_italic_α start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ∈ ( 0 , 1 ) start_POSTSUPERSCRIPT | caligraphic_M | end_POSTSUPERSCRIPT.

Enhancing expert specialization through task-specific dense teachers.    Specialization and cooperation among experts play a crucial role in a MoE model. One way to promote specialization is by penalizing the similarity between experts. For example, an orthogonal penalty can be applied to pairs of experts. However, incorporating such explicit penalties makes the optimization problem in the meta-objective more challenging, as it introduces a trade-off with respect to the few-shot prediction performance. To circumvent this trade-off, we propose an alternative approach that utilizes a knowledge distillation loss 𝒯iq𝚔𝚍(𝜽i,𝜽i𝚝𝚛)subscriptsuperscript𝚔𝚍superscriptsubscript𝒯𝑖𝑞subscript𝜽𝑖superscriptsubscript𝜽𝑖𝚝𝚛\mathcal{L}^{\mathtt{kd}}_{\mathcal{T}_{i}^{q}}(\bm{\theta}_{i},\bm{\theta}_{i% }^{\mathtt{tr}})caligraphic_L start_POSTSUPERSCRIPT typewriter_kd end_POSTSUPERSCRIPT start_POSTSUBSCRIPT caligraphic_T start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT end_POSTSUBSCRIPT ( bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT typewriter_tr end_POSTSUPERSCRIPT ) (Hinton et al., 2015) between the merged MoE model 𝜽isubscript𝜽𝑖\bm{\theta}_{i}bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT and a teacher network 𝜽i𝚝𝚛superscriptsubscript𝜽𝑖𝚝𝚛\bm{\theta}_{i}^{\mathtt{tr}}bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT typewriter_tr end_POSTSUPERSCRIPT. By using a highly task-specific teacher, the distillation loss imposes the weighted-sum of experts that constitute each 𝜽isubscript𝜽𝑖\bm{\theta}_{i}bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT to mimic the behaviour of the teacher, thus enhancing both specialization and cooperation implicitly through knowledge transfer. Furthermore, the predictive performance of 𝜽isubscript𝜽𝑖\bm{\theta}_{i}bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT is not impeded as the loss consistently encourages the current student to achieve a better task-specific generalization on the query set, similar to that of the teacher.

We generate the teacher 𝜽i𝚝𝚛superscriptsubscript𝜽𝑖𝚝𝚛\bm{\theta}_{i}^{\mathtt{tr}}bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT typewriter_tr end_POSTSUPERSCRIPT dynamically in each meta-training episode by performing K𝐾Kitalic_K-step gradient descent starting from 𝜽isubscript𝜽𝑖\bm{\theta}_{i}bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT on the query loss 𝒯iq𝚌𝚎subscriptsuperscript𝚌𝚎superscriptsubscript𝒯𝑖𝑞\mathcal{L}^{\mathtt{ce}}_{\mathcal{T}_{i}^{q}}caligraphic_L start_POSTSUPERSCRIPT typewriter_ce end_POSTSUPERSCRIPT start_POSTSUBSCRIPT caligraphic_T start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT end_POSTSUBSCRIPT (we find K=1𝐾1K=1italic_K = 1 is sufficient in practice). It is worth noting that unlike 𝜽isubscript𝜽𝑖\bm{\theta}_{i}bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT, we do not limit the capacity of 𝜽i𝚝𝚛superscriptsubscript𝜽𝑖𝚝𝚛\bm{\theta}_{i}^{\mathtt{tr}}bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT typewriter_tr end_POSTSUPERSCRIPT through sparse regularization. This results in a teacher model that is both dense in modulation (i.e., 𝜽i𝚝𝚛𝜽𝚙𝚛𝚎superscriptsubscript𝜽𝑖𝚝𝚛superscript𝜽𝚙𝚛𝚎\bm{\theta}_{i}^{\mathtt{tr}}-\bm{\theta}^{\mathtt{pre}}bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT typewriter_tr end_POSTSUPERSCRIPT - bold_italic_θ start_POSTSUPERSCRIPT typewriter_pre end_POSTSUPERSCRIPT), making it expressive and highly task-specific. Importantly, we do not propagate the loss gradients through 𝜽i𝚝𝚛𝜽isuperscriptsubscript𝜽𝑖𝚝𝚛subscript𝜽𝑖\bm{\theta}_{i}^{\mathtt{tr}}\rightarrow\bm{\theta}_{i}bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT typewriter_tr end_POSTSUPERSCRIPT → bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT. This approach is thus similar to the use of bootstrap** in optimization-based meta-learning (Flennerhag et al., 2022; Tack et al., 2024).

Meta-optimization with controlled expert sparsity.    We are now ready to state our optimization problem for meta-tuning. To enable precise control of expert sparsity- which is pivotal for controlling the trade-offs between ID and OOD generalization performance, we solve an optimization problem for few-shot learning performance under sparsity constraints, namely, 1L0(𝒛m)dim(𝒛m)τm,m[||]formulae-sequence1subscript𝐿0subscript𝒛𝑚dimsubscript𝒛𝑚subscript𝜏𝑚for-all𝑚delimited-[]1-\frac{L_{0}(\bm{z}_{m})}{\text{dim}(\bm{z}_{m})}\geq\tau_{m},\forall m\in% \big{[}|{\mathcal{M}}|\big{]}1 - divide start_ARG italic_L start_POSTSUBSCRIPT 0 end_POSTSUBSCRIPT ( bold_italic_z start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT ) end_ARG start_ARG dim ( bold_italic_z start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT ) end_ARG ≥ italic_τ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT , ∀ italic_m ∈ [ | caligraphic_M | ] where τm[0,1]subscript𝜏𝑚01\tau_{m}\in[0,1]italic_τ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT ∈ [ 0 , 1 ] are the targeted sparsity levels. In practice, we optimize the Lagrangian associated with the constraint optimization problem (with Lagrangian multipliers 𝝀𝝀\bm{\lambda}bold_italic_λ) during meta-tuning:

min𝜽δ,𝜻,𝚽max𝝀𝟎subscriptsuperscript𝜽𝛿𝜻𝚽subscript𝝀0\displaystyle\min\limits_{\bm{\theta}^{\delta},\bm{\zeta},\bm{\Phi}}\max% \limits_{\bm{\lambda}\geq\mathbf{0}}\ roman_min start_POSTSUBSCRIPT bold_italic_θ start_POSTSUPERSCRIPT italic_δ end_POSTSUPERSCRIPT , bold_italic_ζ , bold_Φ end_POSTSUBSCRIPT roman_max start_POSTSUBSCRIPT bold_italic_λ ≥ bold_0 end_POSTSUBSCRIPT 𝔼𝒯i𝒫ID[𝒯iq𝚌𝚎(𝜽i)+𝒯iq𝚔𝚍(𝜽i,𝜽i𝚝𝚛)]subscript𝔼similar-tosubscript𝒯𝑖subscript𝒫𝐼𝐷delimited-[]subscriptsuperscript𝚌𝚎superscriptsubscript𝒯𝑖𝑞subscript𝜽𝑖subscriptsuperscript𝚔𝚍superscriptsubscript𝒯𝑖𝑞subscript𝜽𝑖superscriptsubscript𝜽𝑖𝚝𝚛\displaystyle\mathbb{E}_{{\mathcal{T}}_{i}\sim{\mathcal{P}}_{ID}}\big{[}{% \mathcal{L}}^{\mathtt{ce}}_{\mathcal{T}_{i}^{q}}(\bm{\theta}_{i})+\mathcal{L}^% {\mathtt{kd}}_{\mathcal{T}_{i}^{q}}(\bm{\theta}_{i},\bm{\theta}_{i}^{\mathtt{% tr}})\big{]}blackboard_E start_POSTSUBSCRIPT caligraphic_T start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ∼ caligraphic_P start_POSTSUBSCRIPT italic_I italic_D end_POSTSUBSCRIPT end_POSTSUBSCRIPT [ caligraphic_L start_POSTSUPERSCRIPT typewriter_ce end_POSTSUPERSCRIPT start_POSTSUBSCRIPT caligraphic_T start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT end_POSTSUBSCRIPT ( bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) + caligraphic_L start_POSTSUPERSCRIPT typewriter_kd end_POSTSUPERSCRIPT start_POSTSUBSCRIPT caligraphic_T start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT end_POSTSUBSCRIPT ( bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT typewriter_tr end_POSTSUPERSCRIPT ) ]
+m=1||λm(1|ϕm|k=1|ϕm|τQϕm(sk0)),superscriptsubscript𝑚1subscript𝜆𝑚1subscriptbold-italic-ϕ𝑚superscriptsubscript𝑘1subscriptbold-italic-ϕ𝑚𝜏subscript𝑄subscriptbold-italic-ϕ𝑚subscript𝑠𝑘0\displaystyle+\sum_{m=1}^{|{\mathcal{M}}|}\lambda_{m}(\frac{1}{|\bm{\phi}_{m}|% }\sum_{k=1}^{|\bm{\phi}_{m}|}\tau-Q_{\bm{\phi}_{m}}(s_{k}\leq 0)),+ ∑ start_POSTSUBSCRIPT italic_m = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT | caligraphic_M | end_POSTSUPERSCRIPT italic_λ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT ( divide start_ARG 1 end_ARG start_ARG | bold_italic_ϕ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT | end_ARG ∑ start_POSTSUBSCRIPT italic_k = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT | bold_italic_ϕ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT | end_POSTSUPERSCRIPT italic_τ - italic_Q start_POSTSUBSCRIPT bold_italic_ϕ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_POSTSUBSCRIPT ( italic_s start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT ≤ 0 ) ) ,
where 𝜽i=𝜽𝚙𝚛𝚎+m=1||αi,m(𝒛m𝜽δ),subscript𝜽𝑖superscript𝜽𝚙𝚛𝚎superscriptsubscript𝑚1subscript𝛼𝑖𝑚direct-productsubscript𝒛𝑚superscript𝜽𝛿\displaystyle\bm{\theta}_{i}=\bm{\theta}^{\mathtt{pre}}+\sum_{m=1}^{|{\mathcal% {M}}|}\alpha_{i,m}(\bm{z}_{m}\odot\bm{\theta}^{\delta}),bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT = bold_italic_θ start_POSTSUPERSCRIPT typewriter_pre end_POSTSUPERSCRIPT + ∑ start_POSTSUBSCRIPT italic_m = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT | caligraphic_M | end_POSTSUPERSCRIPT italic_α start_POSTSUBSCRIPT italic_i , italic_m end_POSTSUBSCRIPT ( bold_italic_z start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT ⊙ bold_italic_θ start_POSTSUPERSCRIPT italic_δ end_POSTSUPERSCRIPT ) ,
𝒛mqϕm;𝜶ih𝜻(𝒯is),formulae-sequencesimilar-tosubscript𝒛𝑚subscript𝑞subscriptbold-italic-ϕ𝑚similar-tosubscript𝜶𝑖subscript𝜻superscriptsubscript𝒯𝑖𝑠\displaystyle\bm{z}_{m}\sim q_{\bm{\phi}_{m}};\bm{\alpha}_{i}\sim h_{\bm{\zeta% }}({\mathcal{T}}_{i}^{s}),bold_italic_z start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT ∼ italic_q start_POSTSUBSCRIPT bold_italic_ϕ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_POSTSUBSCRIPT ; bold_italic_α start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ∼ italic_h start_POSTSUBSCRIPT bold_italic_ζ end_POSTSUBSCRIPT ( caligraphic_T start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_s end_POSTSUPERSCRIPT ) , (2)

in which the objective is the minimization problem in the first line while the aforementioned sparsity constraints translate to the maximization problem in the second. We put the sparsity constraints on individual masks, hence experts, as opposed to the overall 𝜽isubscript𝜽𝑖\bm{\theta}_{i}bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT to allow task-dependent model capacity in 𝜽isubscript𝜽𝑖\bm{\theta}_{i}bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT through selective merging. For simplicity, we set a common constraint for all masks i.e., τm=τ,msubscript𝜏𝑚𝜏for-all𝑚\tau_{m}=\tau,\forall mitalic_τ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT = italic_τ , ∀ italic_m, which is treated as a hyperparameter. We use simultaneous gradient descent and projected gradient ascent for optimizing Eqn. (4.1). To avoid over-penalizing the model capacity from surpassing the sparsity constraints, we reset λmsubscript𝜆𝑚\lambda_{m}italic_λ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT to zero after its associated sparsity constraint is met (Gallego-Posada et al., 2022). This results in final sparsity close to the target τ𝜏\tauitalic_τ, enabling precise control of the expert sparsity levels. The pseudocodes for meta-training can be found in Section A.1.

4.2 Meta-testing

In the next two sections, we outline the procedures for fine-tuning meta-tuned models using SMAT. SMAT is fully compatible with existing off-the-shelf fine-tuning techniques. Additionally, we introduce a gradient-free fine-tuning method specifically designed for SMAT, which can be optionally employed during downstream task adaptation at test time in more computation-restricted scenarios. Our pseudocodes and ablation results for various fine-tuning techniques with SMAT are provided in Section A.2 and Section C.2 .

Gradient-free optimization for expert selection.    Prior work has discovered that further task-specific adaptation of a meta-trained model is essential for improving its performance on OOD tasks during meta-testing (Hu et al., 2022; Chen et al., 2023; Li et al., 2022). Although effective, one major limitation of these adaptation methods is the reliance on back-propagation of the gradients, which can be expensive, making these methods inefficient and potentially impractical due to poor scaling with model size. To this end, we propose an adaptation strategy, specifically designed for SMAT, that bypasses the gradient computation. At the core of our method lies the hypothesis that each expert selection score can be descreatzed i.e, either 0 or 1, which aligns with the intuition that each expert knowledge is either required or not for solving each tasks. We can thus optimize the expert selection score in a binary hypothesize space i.e, 𝜶i{0,1}||subscript𝜶𝑖superscript01\bm{\alpha}_{i}\in\{0,1\}^{|{\mathcal{M}}|}bold_italic_α start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ∈ { 0 , 1 } start_POSTSUPERSCRIPT | caligraphic_M | end_POSTSUPERSCRIPT.

Gradient-based fine-tuning with SMAT.    Following (Hu et al., 2022), we also consider gradient-based fine-tuning of our meta-tuned model at meta-testing time. With only few changes, our method is fully compatible with existing full fine-tuning (i.e., fine-tuning the entire model), and parameter-efficient fine-tuning (PEFT) techniques. Specifically, we use 𝜽isubscript𝜽𝑖\bm{\theta}_{i}bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT in Eqn. (4.1) as the task-specific model initialization, before applying any off-the-shelf fine-tuning technique for further optimizing 𝜽isubscript𝜽𝑖\bm{\theta}_{i}bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT on the support set of each task.

5 Experiment

We now verify the efficacy and competitiveness of SMAT on standard meta-learning benchmarks. Additional details and results can be found in Appendix B and C.

Setup.  We conduct meta-tuning experiments on Meta-dataset (MD(Triantafillou et al., 2020), which is a widely studied large-scale cross-domain few-shot learning benchmark. As in PMF (Hu et al., 2022), we adhere to the official guidelines and employ the standard meta-training and meta-testing splits for meta-tuning and meta-testing. We select all hyperparameters and the meta-tuned checkpoint for testing using the official meta-validation split. In order to obtain a more comprehensive evaluation of the meta-tuned models, we introduce additional OOD datasets for meta-testing only, which were not used during the meta-tuning process on MD.

Table 1: Few-shot testing results on the Meta-dataset benchmark and additional OOD testing datasets for methods using DINO-ViT-Small backbone. and respectively indicate published results in (Hu et al., 2022) and (Basu et al., 2023). Gray indicates our method.
w/o fine-tuning with gradient-based fine-tuning
Datasets Pre PMF SoftMerge SMAT Pre+full PMF+full SoftMerge+full SMAT+full Pre+LoRA PMF+LoRA SMAT+LoRA
ImageNet 73.48 73.54 74.33 74.69 73.54 74.59 74.71 75.24 74.22 73.54 75.72
Aircraft 62.17 88.33 88.80 89.78 75.4 88.33 90.60 90.01 80.8 89.75 90.71
Omniglot 54.33 91.79 91.24 89.84 78.7 91.79 92.01 90.83 80.8 92.78 90.99
CUB 85.37 91.02 91.54 92.57 85.4 91.02 91.95 92.57 85.8 91.17 92.57
DTD 83.67 81.64 80.98 86.29 86.9 86.61 86.84 88.41 86.8 86.73 88.28
Quickdraw 60.59 79.23 78.98 79.17 73.6 79.23 79.90 79.17 72.7 79.23 78.83
Fungi 56.26 74.2 72.40 73.31 54.7 74.20 72.40 73.31 59.8 75.44 73.31
VGGFlower 94.45 94.12 96.89 97.22 94.2 94.12 97.01 97.22 94.8 96.05 97.25
ID Avg 71.29 84.23 84.40 85.36 77.81 84.99 85.56 85.84 79.47 85.59 85.88
TrafficSig 53.7 54.37 56.21 57.72 87.3 88.85 89.91 91.33 88.1 89.14 90.18
MSCOCO 54.58 57.04 55.75 58.81 61.5 62.59 62.15 63.11 62.1 61.71 63.38
Cifar10 85.64 80.82 84.58 87.05 92.48 89.61 91.84 92.21 93.33 91.53 92.46
Cifar100 76.86 69.11 70.85 77.46 86.13 82.54 85.88 86.12 86.17 85.06 85.88
MNIST 78.57 93.33 94.16 94.43 92.54 96.44 96.20 96.73 94.98 96.41 96.46
Sketch 47.25 41.10 43.30 47.76 56.39 49.65 53.85 56.67 57.34 47.59 55.63
Pet 91.73 91.37 89.84 91.97 92.03 91.73 90.48 91.97 92.06 92.01 92.31
Clipart 55.19 53.92 54.83 58.97 67.18 62.83 65.50 65.79 66.51 60.6 66.07
Food 62.64 61.89 63.04 65.59 65.08 62.97 63.36 66.99 65.06 62.71 67.77
Cars 34.58 38.00 36.21 36.79 40.98 40.07 41.62 42.39 39.49 42.37 40.05
OOD Avg 64.07 64.10 64.87 67.65 74.16 72.73 74.08 75.32 74.51 72.91 75.02

Baselines.  We compare SMAT to two types of baselines: (a) the Pre-trained model without meta-tuning, and (b) meta-tuning methods: PMF (Hu et al., 2022), which is the SOTA on MD. To compare against a MoE baseline for meta-tuning, we adopt the recently proposed SMEAR (Muqeeth et al., 2023) which implements soft merging of experts in the parameter space, and denote this baseline by SoftMerge. All methods use DINO-ViT-Small (Caron et al., 2021) as the pre-trained backbone.

Evaluation.  At meta-testing time, we resort to the ProtoNet (Snell et al., 2017) classifier for performing direct inference on each few-shot testing tasks without further adaptation. When considering task-specific fine-tuning on the support sets, we follow the same protocols in PMF (Hu et al., 2022) for all models, using Full (fine-tuning the entire model) and LoRA (Hu et al., 2021) as fine-tuning methods. Namely, for each dataset, we perform a hyperparameter search on a few validation tasks to obtain the optimal learning rate for task-specific fine-tuning using the Adam optimizer for 50 steps for each testing tasks from that dataset.

5.1 SMAT achieves new SOTA performance

In Tab. 1, SMAT consistently achieves the highest overall few-shot classification accuracy for both ID and OOD meta-testing with and without adaptation on task support sets. Results in Table A4 for supervised pre-trained backbone further demonstrate the superiority of our approach over baselines. More specifically, results in Tab. 1 show that:

SMAT is a better out-of-the-box few-shot learner.    SMAT attains the best few-shot learning performance in 5/8 ID and 7/10 OOD datasets without adaptation, outperforming the baseline, PMF, by 0.91% and 3.17% on average in ID and OOD evaluation settings, respectively.

SMAT is a transferable initialization for few-shot fine-tuning.    When considering task-specific adaptation through fine-tuning on the support set of each task, SMAT shows great compatibility with off-the-shelf fine-tuning techniques. Specifically, fine-tuning starting from SMAT’s 𝜽isubscript𝜽𝑖\bm{\theta}_{i}bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT leads to the best performance among all baselines when applying the same fine-tuning technique. SMAT improves on PMF by as much as 0.82% (ID) and 2.48% (OOD) when fully fine-tuning the entire meta-tuned model as initialization.

SMAT achieves superior OOD generalization performance.    While PMF exhibits relatively lower OOD performance w.r.t. the pre-trained baselines, SMAT, in contrast, achieves improved generalization performance, outperforming the pre-trained by 3.2% and at least 0.5% (when both +LoRA) for without and with adaptation, respectively.

5.2 Roles of sparsity in meta-tuning with SMAT

ID vs. OOD tradeoff through controlled sparsity levels.    We observe that adjusting the expert sparsity τ𝜏\tauitalic_τ allows us to balance the trade-off between in-domain (ID) and out-of-domain (OOD) performance of our meta-tuned SMAT models. In Fig. 3, we see that the OOD performance generally improves while the ID performance decreases as the expert sparsity level τ𝜏\tauitalic_τ increases. We hypothesize that this result is due to the stronger intrinsic meta-regularization effect associated with higher sparsity constraints, as well as the better preservation of the more generic pre-trained features through weight interpolation between the meta-tuned and pre-trained parameters in our formulation in sparse interpolated experts (see Eqn. (4.1)). Both of these factors help to mitigate meta-overfitting to the ID meta-training tasks hence improve meta-generalization.

Sparsity in experts encourages specialization.    Meta-learned sparsity patterns in the experts 𝒛m𝜽δdirect-productsubscript𝒛𝑚superscript𝜽𝛿\bm{z}_{m}\odot\bm{\theta}^{\delta}bold_italic_z start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT ⊙ bold_italic_θ start_POSTSUPERSCRIPT italic_δ end_POSTSUPERSCRIPT induce sparsity in the meta-gradients which alleviates harmful task-interference, thereby encouraging specialization among the experts. In Fig. 4, we compare the average alignment of meta-gradients between tasks during training for two SMAT models with different expert sparsity levels τ𝜏\tauitalic_τ, where the alignment of gradients, defined as 𝔼𝒯i,𝒯j𝒫ID[cos(𝜽i,𝜽j)]subscript𝔼similar-tosubscript𝒯𝑖subscript𝒯𝑗subscript𝒫𝐼𝐷delimited-[]cossubscript𝜽subscript𝑖subscript𝜽subscript𝑗\mathbb{E}_{{\mathcal{T}}_{i},{\mathcal{T}}_{j}\sim{\mathcal{P}}_{ID}}[\text{% cos}(\nabla_{\bm{\theta}}{\mathcal{L}}_{i},\nabla_{\bm{\theta}}{\mathcal{L}}_{% j})]blackboard_E start_POSTSUBSCRIPT caligraphic_T start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , caligraphic_T start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT ∼ caligraphic_P start_POSTSUBSCRIPT italic_I italic_D end_POSTSUBSCRIPT end_POSTSUBSCRIPT [ cos ( ∇ start_POSTSUBSCRIPT bold_italic_θ end_POSTSUBSCRIPT caligraphic_L start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , ∇ start_POSTSUBSCRIPT bold_italic_θ end_POSTSUBSCRIPT caligraphic_L start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT ) ], is respectively computed for 𝜽{𝜽δ,𝒛m𝜽δ,m[||]\bm{\theta}\in\{\bm{\theta}^{\delta},\bm{z}_{m}\odot\bm{\theta}^{\delta},\ % \forall m\in[|{\mathcal{M}}|]bold_italic_θ ∈ { bold_italic_θ start_POSTSUPERSCRIPT italic_δ end_POSTSUPERSCRIPT , bold_italic_z start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT ⊙ bold_italic_θ start_POSTSUPERSCRIPT italic_δ end_POSTSUPERSCRIPT , ∀ italic_m ∈ [ | caligraphic_M | ], i.e., the overall meta-tuned parameter and each expert individually. The results show that the higher sparsity level τ=0.9𝜏0.9\tau=0.9italic_τ = 0.9 in SMAT can lead to greater alignment of meta-gradients between tasks. Moreover, the alignments in the experts’ meta-gradients (which are sparse) are generally higher than that of the overall one, (i.e., w.r.t. 𝜽δsuperscript𝜽𝛿\bm{\theta}^{\delta}bold_italic_θ start_POSTSUPERSCRIPT italic_δ end_POSTSUPERSCRIPT in black) - a sign for development of each expert into highly specialized parameters.

Table 2: Few-shot testing results on the Meta-dataset benchmark for various SMAT meta-tuned Vision Transformer backbones.
with gradient-based full fine-tuning
SMAT meta-tuned ViT backbone Params. In-domain Out-of-domain Avg
INet Acraft Omglot CUB DTD Qdraw Fungi Flower Sign Coco Cifar10 Cifar100 Mnist
DINO Small (Caron et al., 2021) 21M 75.24 90.01 90.83 92.57 88.41 79.17 73.31 97.21 91.33 63.11 92.08 85.91 96.73 86.42
DINO Base (Caron et al., 2021) 86M 78.28 92.19 92.89 93.41 86.97 80.35 75.59 96.93 91.23 64.85 92.57 86.95 97.70 87.22
Sup21k Small (Dosovitskiy et al., 2020) 21M 82.10 89.16 88.56 94.98 89.05 79.36 74.50 99.19 90.13 68.19 94.87 90.32 96.76 87.81
Sup21k Base (Dosovitskiy et al., 2020) 86M 85.01 87.41 88.05 95.52 87.59 79.78 74.01 99.34 91.38 68.51 93.61 92.84 96.88 87.96
Sup21k Large (Dosovitskiy et al., 2020) 307M 87.38 88.00 88.16 95.59 88.56 81.02 74.34 99.09 91.77 69.50 95.27 92.62 96.79 88.43
Refer to caption
Figure 3: Average performance tradeoff on sampled ID vs OOD tasks as a function of (color) expert sparsity level τ𝜏\tauitalic_τ, and (marker) number of experts.
Refer to caption
Figure 4: Meta-gradients alignment between tasks throughout for SMAT with low and high sparsity levels. Meta-gradients are calculated w.r.t. the parameters shown in the legend.
Refer to caption
Figure 5: (a-b) Meta-learned model capacity after meta-tuning (i.e., number of non-zero parameters) grouped by a): layer types, and b): layer depth. (c-d) Expert specialisation. c) Dendrogram of task similarity based on expert selection scores. d) Overlap between masks.

5.3 Ablation studies

Importance of the different components.  In Tab. 3, we present several ablated variants of SMAT where we replace or remove certain components. We perform the study by meta-tuning on the DINO ViT-Small backbone and report the meta-testing results without adaptation. Overall, we observe that SMAT performs better than all ablated models on average, demonstrating the effectiveness of each proposed component. We notice that incorporating MoE at the MLP layers of the ViT (index 5), hence predetermining the sparsity patterns, leads to a marginally better ID performance; however, at the cost of a 1% drop in its OOD performance compared to SMAT. The results indicate the advantages of explicitly meta-learning the sparsity patterns in the experts for generalization.

Table 3: Ablation studies on different components of SMAT. MLS: meta-learned sparsity, Meta: Meta-training using support and query splits (otherwise no split), DT: dense teachers. IE: interpolated experts
ID Model MLS Meta DT IE ID OOD Avg
1 SMAT 85.14 67.27 75.21
2 85.07 66.44 74.74
3 84.77 67.02 74.90
4 82.35 63.64 71.95
5 85.21 66.21 74.75
6 PMF 84.23 64.09 73.05

Number of experts in SMAT.    In Fig. 3, we see that having more experts, hence higher model capacity given the same expert sparsity τ𝜏\tauitalic_τ, generally increases both ID and OOD performance of our model. The aforementioned ID vs. OOD tradeoff still exists for different numbers of experts; however, the OOD-to-ID tradeoff ratio (defined as ΔOOD accΔID accsubscriptΔOOD accsubscriptΔID acc\frac{\Delta_{\text{OOD acc}}}{\Delta_{\text{ID acc}}}divide start_ARG roman_Δ start_POSTSUBSCRIPT OOD acc end_POSTSUBSCRIPT end_ARG start_ARG roman_Δ start_POSTSUBSCRIPT ID acc end_POSTSUBSCRIPT end_ARG) varies - with ||=44|{\mathcal{M}}|=4| caligraphic_M | = 4 experts having the worst tradeoff ratio, and increasing |||{\mathcal{M}}|| caligraphic_M | from 48484\rightarrow 84 → 8 leads to the most significant gain in the ratio while the improvement seems to saturate when further increasing |||{\mathcal{M}}|| caligraphic_M | from 8168168\rightarrow 168 → 16.

Scale of Meta-tuning datasets.    In Fig. 6, we investigate the impact of the quantity and diversity of tasks observed during meta-tuning on the generalization performance of meta-tuned models on unseen meta-testing tasks. As anticipated, the overall generalization performance (on ID and OOD testing tasks), of both PMF and our model, improves as the scale of the meta-tuning datasets increases along the x-axis. However, even with increased quantity and diversity in the meta-tuning tasks, the OOD performance of PMF is not always better than that of the pre-trained model, which we conjecture is due to both meta-overfitting and harmful task interference.

Refer to caption
Figure 6: Relative testing performance w.r.t. the pre-trained initialization (Avg acc meta-tunedAvg acc pre-trainedAvg acc meta-tunedAvg acc pre-trained\frac{\text{Avg acc meta-tuned}}{\text{Avg acc pre-trained}}divide start_ARG Avg acc meta-tuned end_ARG start_ARG Avg acc pre-trained end_ARG) for SMAT vs PMF using meta-training tasks of increasing diversity.

In contrast, SMAT consistently achieves better OOD and overall performance compared to both PMF and pre-trained models, with a noticeable similar-to\sim4% improvement in relative OOD performance even in the low-diversity (IO) scenario. It is worth noting that this is also where the largest improvement in OOD performance for SMAT occurs. Intuitively, this is because Omniglot is very different from ImageNet, which was the only training source prior to its addition, resulting in a significant increase in task diversity. With only these two datasets, SMAT achieves comparable OOD performance to its full-MD version. These results highlight the effectiveness of using SMAT for meta-tuning in low data diversity settings, as well as its ability to achieve improved few-shot generalization by better leveraging task diversity during meta-tuning.

Scale of Vision Transformer backbones.  In Tab. 2, we present the meta-testing results on MD of various Vision Transformer backbones meta-tuned with SMAT. Our findings indicate that the larger models tend to offer superior overall performance but may require substantially more computational resources due to the number of parameters. These results can guide the selection of appropriate transformer architectures based on specific application requirements and resource constraints.

5.4 Qualitative visualization

Patterned sparsity emerges through meta-tuning.    In Fig. 5, we visualize the sparsity patterns on masks identified through meta-tuning. We observe that the sparsity levels vary significantly depending on the layer types (a), and depths (b). Specifically, the intermediate layers (5-9) have lower per-layer sparsity, while the first few layers are highly sparsified, with sparsity levels as high as 95%. Among the different layer types, we find that three types of layers retain most of their modulation parameters (non-zeros): (1) the first input embedding layers, (2) values of the attention module, and (3) linear layers of attention and feedforward modules. Across different masks, we notice that the standard deviations of sparsity levels are particularly larger for the first embedding layers and throughout layers at all depths. By examining the overlap** ratio (defined as |(𝒛i𝒛j)0||(𝒛i𝒛j)0|\frac{|(\bm{z}_{i}\cap\bm{z}_{j})\neq 0|}{|(\bm{z}_{i}\cup\bm{z}_{j})\neq 0|}divide start_ARG | ( bold_italic_z start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ∩ bold_italic_z start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT ) ≠ 0 | end_ARG start_ARG | ( bold_italic_z start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ∪ bold_italic_z start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT ) ≠ 0 | end_ARG), as shown in Fig. 5(d), we find that different masks, hence experts, generally have a small overlap. This indicates that SMAT has indeed discovered a diverse set of sparse interpolated experts through meta-tuning.

Learned expert merging rule encodes task relationship.    A closer look at the average expert selection scores 𝜶𝜶\bm{\alpha}bold_italic_α by datasets reveals that both specialized experts and a meaningful selection rule have been meta-learned by SMAT, as evident in the Fig. 5(c). We first note that overall, every expert has been utilized by some domains. More interestingly, the dendrogram produced by the similarity of mask selection scores clearly shows hierarchical clustering according to visual similarities between domains. Furthermore, we note sparsity and discreteness of expert selection generally inversely correlates with tasks complexity, with more sparse and discrete selection for intuitively simpler tasks (e.g., Omniglot, Dtd) than the more complex ones (e.g., Ilsvrc_2012).

6 Conclusion

We introduced a simple-yet-effective meta-tuning framework coined SMAT that accommodates to each task through an interpolation of the pre-trained model and a learned combination of sparse experts. Our experiments conclusively demonstrate SMAT’s effectiveness in delivering a more generalizable pre-trained model, resulting in state-of-the-art performance on out-of-distribution datasets. Notably, SMAT seamlessly integrates with cutting-edge parameter-efficient fine-tuning methods, and analyses of sparsity patterns underscore the specialization of the learned experts.

Acknowledgement

This work was supported by the Research Matching Grant Scheme (RMGS 9229111) founded by the University Grants Committee of Hong Kong, and the Institute of Information & communications Technology Planning & Evaluation (IITP) grant funded by the Korea government (MSIT) (No.2022-0-00713, Meta-learning applicable to real-world problems)

Impact Statement

The proposed meta-tuning strategy exhibits a broad applicability, poised for extension or adaptation beyond vision pre-trained models. We envision this project as pioneering a new research trajectory for improving few-shot generalization of foundational models across diverse domains, including natural language processing, life sciences, time series, and more.

By augmenting few-shot generalization in pre-trained models, particularly in real-world applications often situated outside of the pre-training distribution, the proposed meta-tuning approach stands poised to substantially impact various downstream tasks like medical imaging analysis, self-driving, wildlife monitoring and etc.

References

  • Basu et al. (2023) Basu, S., Massiceti, D., Hu, S. X., and Feizi, S. Strong Baselines for Parameter Efficient Few-Shot Fine-tuning, April 2023. arXiv:2304.01917 [cs].
  • Bateni et al. (2019) Bateni, P., Goyal, R., Masrani, V., Wood, F. D., and Sigal, L. Improved few-shot visual classification. 2020 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp.  14481–14490, 2019.
  • Bommasani et al. (2021) Bommasani, R., Hudson, D. A., Adeli, E., Altman, R., Arora, S., von Arx, S., Bernstein, M. S., Bohg, J., Bosselut, A., Brunskill, E., et al. On the opportunities and risks of foundation models. arXiv preprint arXiv:2108.07258, 2021.
  • Caron et al. (2021) Caron, M., Touvron, H., Misra, I., Jégou, H., Mairal, J., Bojanowski, P., and Joulin, A. Emerging properties in self-supervised vision transformers. In IEEE International Conference on Computer Vision, 2021.
  • Chen et al. (2022) Chen, L., Lu, S., and Chen, T. Understanding benign overfitting in gradient-based meta learning. Advances in Neural Information Processing Systems, 35:19887–19899, 2022.
  • Chen et al. (2023) Chen, S., Huang, L.-K., Schwarz, J. R., Du, Y., and Wei, Y. Secure out-of-distribution task generalization with energy-based models. In Thirty-seventh Conference on Neural Information Processing Systems, 2023.
  • Chen et al. (2021) Chen, Y., Zhong, R., Zha, S., Karypis, G., and He, H. Meta-learning via language model in-context tuning. ArXiv, abs/2110.07814, 2021.
  • Dai et al. (2022) Dai, D., Dong, L., Ma, S., Zheng, B., Sui, Z., Chang, B., and Wei, F. StableMoE: Stable routing strategy for mixture of experts. In Muresan, S., Nakov, P., and Villavicencio, A. (eds.), Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pp.  7085–7095, Dublin, Ireland, May 2022. Association for Computational Linguistics. doi: 10.18653/v1/2022.acl-long.489.
  • Dosovitskiy et al. (2020) Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., Uszkoreit, J., and Houlsby, N. An image is worth 16x16 words: Transformers for image recognition at scale. ArXiv, abs/2010.11929, 2020.
  • Eustratiadis et al. (2024) Eustratiadis, P., Dudziak, Ł., Li, D., and Hospedales, T. Neural fine-tuning search for few-shot learning. In The Twelfth International Conference on Learning Representations, 2024.
  • Fedus et al. (2021) Fedus, W., Zoph, B., and Shazeer, N. M. Switch transformers: Scaling to trillion parameter models with simple and efficient sparsity. J. Mach. Learn. Res., 23:120:1–120:39, 2021.
  • Finn et al. (2017) Finn, C., Abbeel, P., and Levine, S. Model-agnostic meta-learning for fast adaptation of deep networks. In International Conference on Machine Learning, 2017.
  • Flennerhag et al. (2022) Flennerhag, S., Schroecker, Y., Zahavy, T., van Hasselt, H., Silver, D., and Singh, S. Bootstrapped meta-learning. In International Conference on Learning Representations, 2022.
  • Gallego-Posada et al. (2022) Gallego-Posada, J., Ramirez, J., Erraqabi, A., Bengio, Y., and Lacoste-Julien, S. Controlled sparsity via constrained optimization or: How i learned to stop tuning penalties and love constraints. In Thirty-Sixth Conference on Neural Information Processing Systems, 2022.
  • Gao et al. (2021) Gao, T., Fisch, A., and Chen, D. Making pre-trained language models better few-shot learners. In Annual Meeting of the Association for Computational Linguistics, 2021.
  • Grant et al. (2018) Grant, E., Finn, C., Levine, S., Darrell, T., and Griffiths, T. Recasting gradient-based meta-learning as hierarchical bayes. In International Conference on Learning Representations, 2018.
  • Hinton et al. (2015) Hinton, G., Vinyals, O., and Dean, J. Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531, 2015.
  • Hong et al. (2023) Hong, D., Zhang, B., Li, X., Li, Y., Li, C., Yao, J., Yokoya, N., Li, H., Jia, X., Plaza, A., et al. Spectralgpt: Spectral foundation model. arXiv preprint arXiv:2311.07113, 2023.
  • Hospedales et al. (2020) Hospedales, T. M., Antoniou, A., Micaelli, P., and Storkey, A. J. Meta-learning in neural networks: A survey. IEEE Transactions on Pattern Analysis and Machine Intelligence, 44:5149–5169, 2020.
  • Hu et al. (2021) Hu, E. J., Shen, Y., Wallis, P., Allen-Zhu, Z., Li, Y., Wang, S., Wang, L., and Chen, W. Lora: Low-rank adaptation of large language models. arXiv preprint arXiv:2106.09685, 2021.
  • Hu et al. (2022) Hu, S. X., Li, D., Stühmer, J., Kim, M., and Hospedales, T. M. Pushing the limits of simple pipelines for few-shot learning: External data and fine-tuning make a difference. In IEEE Conference on Computer Vision and Pattern Recognition, 2022.
  • Ilharco et al. (2022) Ilharco, G., Wortsman, M., Gadre, S. Y., Song, S., Hajishirzi, H., Kornblith, S., Farhadi, A., and Schmidt, L. Patching open-vocabulary models by interpolating weights. ArXiv, abs/2208.05592, 2022.
  • Ilharco et al. (2023) Ilharco, G., Ribeiro, M. T., Wortsman, M., Schmidt, L., Hajishirzi, H., and Farhadi, A. Editing models with task arithmetic. In The Eleventh International Conference on Learning Representations, 2023.
  • Jang et al. (2017) Jang, E., Gu, S., and Poole, B. Categorical reparameterization with gumbel-softmax. In International Conference on Learning Representations, 2017.
  • Lee et al. (2021) Lee, J., Tack, J., Lee, N., and Shin, J. Meta-learning sparse implicit neural representations. Advances in Neural Information Processing Systems, 34:11769–11780, 2021.
  • Lee et al. (2022) Lee, Y., Chen, A. S., Tajwar, F., Kumar, A., Yao, H., Liang, P., and Finn, C. Surgical fine-tuning improves adaptation to distribution shifts. ArXiv, abs/2210.11466, 2022.
  • Lepikhin et al. (2021) Lepikhin, D., Lee, H., Xu, Y., Chen, D., Firat, O., Huang, Y., Krikun, M., Shazeer, N., and Chen, Z. {GS}hard: Scaling giant models with conditional computation and automatic sharding. In International Conference on Learning Representations, 2021.
  • Lewis et al. (2021) Lewis, M., Bhosale, S., Dettmers, T., Goyal, N., and Zettlemoyer, L. Base layers: Simplifying training of large, sparse models. In International Conference on Machine Learning, 2021.
  • Li et al. (2022) Li, W.-H., Liu, X., and Bilen, H. Cross-domain few-shot learning with task-specific adapters. In IEEE/CVF International Conference on Computer Vision and Pattern Recognition (CVPR), June 2022.
  • Li et al. (2017) Li, Z., Zhou, F., Chen, F., and Li, H. Meta-sgd: Learning to learn quickly for few shot learning. ArXiv, abs/1707.09835, 2017.
  • Lin et al. (2023) Lin, Z., Akin, H., Rao, R., Hie, B., Zhu, Z., Lu, W., Smetanin, N., Verkuil, R., Kabeli, O., Shmueli, Y., et al. Evolutionary-scale prediction of atomic-level protein structure with a language model. Science, 379(6637):1123–1130, 2023.
  • Liu et al. (2023a) Liu, J., Yang, C., Lu, Z., Chen, J., Li, Y., Zhang, M., Bai, T., Fang, Y., Sun, L., Yu, P. S., et al. Towards graph foundation models: A survey and beyond. arXiv preprint arXiv:2310.11829, 2023a.
  • Liu et al. (2023b) Liu, T., Puigcerver, J., and Blondel, M. Sparsity-constrained optimal transport. In The Eleventh International Conference on Learning Representations, 2023b.
  • Liu et al. (2021) Liu, Y., Lee, J., Zhu, L., Chen, L., Shi, H., and Yang, Y. A multi-mode modulator for multi-domain few-shot classification. In 2021 IEEE/CVF International Conference on Computer Vision (ICCV), pp.  8433–8442, 2021. doi: 10.1109/ICCV48922.2021.00834.
  • Louizos et al. (2018) Louizos, C., Welling, M., and Kingma, D. P. Learning sparse neural networks through l0subscript𝑙0l_{0}italic_l start_POSTSUBSCRIPT 0 end_POSTSUBSCRIPT regularization. In International Conference on Learning Representations, 2018.
  • Matena & Raffel (2021) Matena, M. and Raffel, C. Merging models with fisher-weighted averaging. ArXiv, abs/2111.09832, 2021.
  • Mikolov et al. (2013) Mikolov, T., Sutskever, I., Chen, K., Corrado, G. S., and Dean, J. Distributed representations of words and phrases and their compositionality. Advances in neural information processing systems, 26, 2013.
  • Min et al. (2021) Min, S., Lewis, M., Zettlemoyer, L., and Hajishirzi, H. Metaicl: Learning to learn in context. ArXiv, abs/2110.15943, 2021.
  • Muqeeth et al. (2023) Muqeeth, M., Liu, H., and Raffel, C. Soft merging of experts with adaptive routing. ArXiv, abs/2306.03745, 2023.
  • Mustafa et al. (2022) Mustafa, B., Ruiz, C. R., Puigcerver, J., Jenatton, R., and Houlsby, N. Multimodal contrastive learning with LIMoe: the language-image mixture of experts. In Oh, A. H., Agarwal, A., Belgrave, D., and Cho, K. (eds.), Advances in Neural Information Processing Systems, 2022.
  • Oh et al. (2020) Oh, J., Yoo, H., Kim, C., and Yun, S.-Y. Boil: Towards representation change for few-shot learning. arXiv preprint arXiv:2008.08882, 2020.
  • Panigrahi et al. (2023a) Panigrahi, A., Saunshi, N., Zhao, H., and Arora, S. Task-specific skill localization in fine-tuned language models. In International Conference on Machine Learning, 2023a.
  • Panigrahi et al. (2023b) Panigrahi, A., Saunshi, N., Zhao, H., and Arora, S. Task-specific skill localization in fine-tuned language models. arXiv preprint arXiv:2302.06600, 2023b.
  • Perez et al. (2018) Perez, E., Strub, F., De Vries, H., Dumoulin, V., and Courville, A. Film: Visual reasoning with a general conditioning layer. In Proceedings of the AAAI conference on artificial intelligence, 2018.
  • Puigcerver et al. (2023) Puigcerver, J., Riquelme, C., Mustafa, B., and Houlsby, N. From sparse to soft mixtures of experts. ArXiv, abs/2308.00951, 2023.
  • Radford et al. (2021) Radford, A., Kim, J. W., Hallacy, C., Ramesh, A., Goh, G., Agarwal, S., Sastry, G., Askell, A., Mishkin, P., Clark, J., et al. Learning transferable visual models from natural language supervision. In International conference on machine learning, pp.  8748–8763. PMLR, 2021.
  • Raghu et al. (2019) Raghu, A., Raghu, M., Bengio, S., and Vinyals, O. Rapid learning or feature reuse? towards understanding the effectiveness of maml. arXiv preprint arXiv:1909.09157, 2019.
  • Requeima et al. (2019) Requeima, J., Gordon, J., Bronskill, J., Nowozin, S., and Turner, R. E. Fast and flexible multi-task classification using conditional neural adaptive processes. Advances in Neural Information Processing Systems, 32, 2019.
  • Riquelme et al. (2021) Riquelme, C., Puigcerver, J., Mustafa, B., Neumann, M., Jenatton, R., Pinto, A. S., Keysers, D., and Houlsby, N. Scaling vision with sparse mixture of experts. In Neural Information Processing Systems, 2021.
  • Roller et al. (2021) Roller, S., Sukhbaatar, S., Szlam, A., and Weston, J. Hash layers for large sparse models. In Neural Information Processing Systems, 2021.
  • Schwarz & Teh (2022) Schwarz, J. R. and Teh, Y. W. Meta-learning sparse compression networks. arXiv preprint arXiv:2205.08957, 2022.
  • Schwarz et al. (2023) Schwarz, J. R., Tack, J., Teh, Y. W., Lee, J., and Shin, J. Modality-agnostic variational compression of implicit neural representations. arXiv preprint arXiv:2301.09479, 2023.
  • Shazeer et al. (2017) Shazeer, N., Mirhoseini, A., Maziarz, K., Davis, A., Le, Q., Hinton, G., and Dean, J. Outrageously large neural networks: The sparsely-gated mixture-of-experts layer. In International Conference on Learning Representations, 2017.
  • Shysheya et al. (2022) Shysheya, A., Bronskill, J., Patacchiola, M., Nowozin, S., and Turner, R. E. Fit: Parameter efficient few-shot transfer learning for personalized and federated image classification. arXiv preprint arXiv:2206.08671, 2022.
  • Snell et al. (2017) Snell, J., Swersky, K., and Zemel, R. Prototypical Networks for Few-shot Learning. In Advances in Neural Information Processing Systems, volume 30. Curran Associates, Inc., 2017.
  • Tack et al. (2024) Tack, J., Kim, S., Yu, S., Lee, J., Shin, J., and Schwarz, J. R. Learning large-scale neural fields via context pruned meta-learning. Advances in Neural Information Processing Systems, 36, 2024.
  • Thrun & Pratt (2012) Thrun, S. and Pratt, L. Learning to learn. Springer Science & Business Media, 2012.
  • Triantafillou et al. (2020) Triantafillou, E., Zhu, T., Dumoulin, V., Lamblin, P., Evci, U., Xu, K., Goroshin, R., Gelada, C., Swersky, K., Manzagol, P.-A., and Larochelle, H. Meta-dataset: A dataset of datasets for learning to learn from few examples. In International Conference on Learning Representations, 2020.
  • Triantafillou et al. (2021a) Triantafillou, E., Larochelle, H., Zemel, R., and Dumoulin, V. Learning a universal template for few-shot dataset generalization. In International Conference on Machine Learning, pp.  10424–10433. PMLR, 2021a.
  • Triantafillou et al. (2021b) Triantafillou, E., Larochelle, H., Zemel, R. S., and Dumoulin, V. Learning a universal template for few-shot dataset generalization. ArXiv, abs/2105.07029, 2021b.
  • Von Oswald et al. (2021) Von Oswald, J., Zhao, D., Kobayashi, S., Schug, S., Caccia, M., Zucchet, N., and Sacramento, J. Learning where to learn: Gradient sparsity in meta and continual learning. Advances in Neural Information Processing Systems, 2021.
  • Wang et al. (2020) Wang, Z., Tsvetkov, Y., Firat, O., and Cao, Y. Gradient vaccine: Investigating and improving multi-task optimization in massively multilingual models. ArXiv, abs/2010.05874, 2020.
  • Wortsman et al. (2021) Wortsman, M., Ilharco, G., Li, M., Kim, J. W., Hajishirzi, H., Farhadi, A., Namkoong, H., and Schmidt, L. Robust fine-tuning of zero-shot models. 2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp.  7949–7961, 2021.
  • Yao et al. (2021) Yao, H., Wang, Y., Wei, Y., Zhao, P., Mahdavi, M., Lian, D., and Finn, C. Meta-learning with an adaptive task scheduler. Advances in Neural Information Processing Systems, 34:7497–7509, 2021.
  • Yeh et al. (2023) Yeh, C.-C. M., Dai, X., Chen, H., Zheng, Y., Fan, Y., Der, A., Lai, V., Zhuang, Z., Wang, J., Wang, L., et al. Toward a foundation model for time series data. In Proceedings of the 32nd ACM International Conference on Information and Knowledge Management, pp.  4400–4404, 2023.
  • Yu et al. (2023) Yu, L., Shi, B., Pasunuru, R., Muller, B., Golovneva, O., Wang, T., Babu, A., Tang, B., Karrer, B., Sheynin, S., et al. Scaling autoregressive multi-modal models: Pretraining and instruction tuning. arXiv preprint arXiv:2309.02591, 2023.
  • Yu et al. (2020) Yu, T., Kumar, S., Gupta, A., Levine, S., Hausman, K., and Finn, C. Gradient surgery for multi-task learning. ArXiv, abs/2001.06782, 2020.
  • Zhang et al. (2023) Zhang, S., Dong, L., Li, X., Zhang, S., Sun, X., Wang, S., Li, J., Hu, R., Zhang, T., Wu, F., et al. Instruction tuning for large language models: A survey. arXiv preprint arXiv:2308.10792, 2023.
  • Zhou et al. (2022) Zhou, Y., Lei, T., Liu, H., Du, N., Huang, Y., Zhao, V. Y., Dai, A. M., Chen, Z., Le, Q. V., and Laudon, J. Mixture-of-experts with expert choice routing. In Oh, A. H., Agarwal, A., Belgrave, D., and Cho, K. (eds.), Advances in Neural Information Processing Systems, 2022.
  • Zhuang et al. (2020) Zhuang, F., Qi, Z., Duan, K., Xi, D., Zhu, Y., Zhu, H., Xiong, H., and He, Q. A comprehensive survey on transfer learning. Proceedings of the IEEE, 109(1):43–76, 2020.
  • Zintgraf et al. (2019) Zintgraf, L. M., Shiarlis, K., Kurin, V., Hofmann, K., and Whiteson, S. Caml: Fast context adaptation via meta-learning. In International Conference on Machine Learning, 2019.
  • Zuo et al. (2022) Zuo, S., Liu, X., Jiao, J., Kim, Y. J., Hassan, H., Zhang, R., Gao, J., and Zhao, T. Taming sparsely activated transformer with stochastic experts. In International Conference on Learning Representations, 2022.

Appendix A Pseudocode for SMAT

A.1 SMAT for meta-training

The pseudocode for meta-tuning using SMAT can be found in Alg. 1 below. Our implementation is publicly available at github.com/szc12153/sparse_meta_tuning.

Algorithm 1 SMAT: Meta-training
  Data: Meta-training tasks 𝕋:={𝒯1,𝒯2,𝒯N}𝒫IDassign𝕋subscript𝒯1subscript𝒯2subscript𝒯𝑁similar-tosubscript𝒫𝐼𝐷{\mathbb{T}}:=\{{\mathcal{T}}_{1},{\mathcal{T}}_{2},...{\mathcal{T}}_{N}\}\sim% \mathcal{P}_{ID}blackboard_T := { caligraphic_T start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT , caligraphic_T start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT , … caligraphic_T start_POSTSUBSCRIPT italic_N end_POSTSUBSCRIPT } ∼ caligraphic_P start_POSTSUBSCRIPT italic_I italic_D end_POSTSUBSCRIPT
  Require: Pre-trained initialization 𝜽𝚙𝚛𝚎superscript𝜽𝚙𝚛𝚎\bm{\theta}^{\mathtt{pre}}bold_italic_θ start_POSTSUPERSCRIPT typewriter_pre end_POSTSUPERSCRIPT; Target expert sparsity level τm=τsubscript𝜏𝑚𝜏\tau_{m}=\tauitalic_τ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT = italic_τ; Number of experts M𝑀Mitalic_M
  Output: Variational distribution Φ={ϕm}m=1||Φsubscriptsuperscriptsubscriptitalic-ϕ𝑚𝑚1\Phi=\{\phi_{m}\}^{|{\mathcal{M}}|}_{m=1}roman_Φ = { italic_ϕ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT } start_POSTSUPERSCRIPT | caligraphic_M | end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_m = 1 end_POSTSUBSCRIPT for the sparse masks; Hypernetwork hψsubscript𝜓h_{\psi}italic_h start_POSTSUBSCRIPT italic_ψ end_POSTSUBSCRIPT; Meta-tuned parameters 𝜽δsuperscript𝜽𝛿\bm{\theta}^{\delta}bold_italic_θ start_POSTSUPERSCRIPT italic_δ end_POSTSUPERSCRIPT
  Initialize ζ,Φ,θδ𝜁Φsuperscript𝜃𝛿\zeta,\ \Phi,\ \theta^{\delta}italic_ζ , roman_Φ , italic_θ start_POSTSUPERSCRIPT italic_δ end_POSTSUPERSCRIPT
  Initialize the Lagrangian multipliers λm=0subscript𝜆𝑚0\lambda_{m}=0italic_λ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT = 0 for sparsity constraint 1L0(𝒛m)dim(𝒛m)τ,m[||]formulae-sequence1subscript𝐿0subscript𝒛𝑚dimsubscript𝒛𝑚𝜏for-all𝑚delimited-[]1-\frac{L_{0}(\bm{z}_{m})}{\text{dim}(\bm{z}_{m})}\geq\tau,\ \ \forall m\in[|{% \mathcal{M}}|]1 - divide start_ARG italic_L start_POSTSUBSCRIPT 0 end_POSTSUBSCRIPT ( bold_italic_z start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT ) end_ARG start_ARG dim ( bold_italic_z start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT ) end_ARG ≥ italic_τ , ∀ italic_m ∈ [ | caligraphic_M | ].
  while not converged do
     𝔹𝕋similar-to𝔹𝕋{\mathbb{B}}\sim{\mathbb{T}}blackboard_B ∼ blackboard_T# Sample a batch of tasks
     for i=1,2,,|𝔹|𝑖12𝔹i=1,2,...,|{\mathbb{B}}|italic_i = 1 , 2 , … , | blackboard_B | do
        𝒯i𝒯is,𝒯iqsubscript𝒯𝑖subscriptsuperscript𝒯𝑠𝑖superscriptsubscript𝒯𝑖𝑞{\mathcal{T}}_{i}\rightarrow{\mathcal{T}}^{s}_{i},{\mathcal{T}}_{i}^{q}caligraphic_T start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT → caligraphic_T start_POSTSUPERSCRIPT italic_s end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , caligraphic_T start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT# Split into support and query sets
        𝜶𝒊𝙶𝚞𝚖𝚋𝚕𝚎𝚂𝚒𝚐𝚖𝚘𝚒𝚍(hζ(𝒯is))similar-tosuperscriptsubscript𝜶𝒊bold-′𝙶𝚞𝚖𝚋𝚕𝚎𝚂𝚒𝚐𝚖𝚘𝚒𝚍subscript𝜁subscriptsuperscript𝒯𝑠𝑖\bm{\alpha_{i}^{\prime}}\sim\mathtt{GumbleSigmoid}(h_{\zeta}({\mathcal{T}}^{s}% _{i}))bold_italic_α start_POSTSUBSCRIPT bold_italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT bold_′ end_POSTSUPERSCRIPT ∼ typewriter_GumbleSigmoid ( italic_h start_POSTSUBSCRIPT italic_ζ end_POSTSUBSCRIPT ( caligraphic_T start_POSTSUPERSCRIPT italic_s end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) )# Sample expert merging scores from the hypernetwork
        αi,m=αi,mmMαi,msubscript𝛼𝑖𝑚superscriptsubscript𝛼𝑖𝑚superscriptsubscript𝑚𝑀superscriptsubscript𝛼𝑖𝑚\alpha_{i,m}=\frac{\alpha_{i,m}^{\prime}}{\sum_{m}^{M}\alpha_{i,m}^{\prime}}italic_α start_POSTSUBSCRIPT italic_i , italic_m end_POSTSUBSCRIPT = divide start_ARG italic_α start_POSTSUBSCRIPT italic_i , italic_m end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT end_ARG start_ARG ∑ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_M end_POSTSUPERSCRIPT italic_α start_POSTSUBSCRIPT italic_i , italic_m end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT end_ARG# Normalize the merging weights
        𝒛mqϕm,m[M]formulae-sequencesimilar-tosubscript𝒛𝑚subscript𝑞subscriptitalic-ϕ𝑚for-all𝑚delimited-[]𝑀\bm{z}_{m}\sim q_{\phi_{m}},\ \forall m\in[M]bold_italic_z start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT ∼ italic_q start_POSTSUBSCRIPT italic_ϕ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_POSTSUBSCRIPT , ∀ italic_m ∈ [ italic_M ]# Sample sparse masks from the variational distribution
        𝜽i=𝜽𝚙𝚛𝚎+𝜽δmMαi,m𝒛msubscript𝜽𝑖superscript𝜽𝚙𝚛𝚎direct-productsuperscript𝜽𝛿superscriptsubscript𝑚𝑀subscript𝛼𝑖𝑚subscript𝒛𝑚\bm{\theta}_{i}=\bm{\theta}^{\mathtt{pre}}+\bm{\theta}^{\delta}\odot\sum_{m}^{% M}\alpha_{i,m}\bm{z}_{m}bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT = bold_italic_θ start_POSTSUPERSCRIPT typewriter_pre end_POSTSUPERSCRIPT + bold_italic_θ start_POSTSUPERSCRIPT italic_δ end_POSTSUPERSCRIPT ⊙ ∑ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_M end_POSTSUPERSCRIPT italic_α start_POSTSUBSCRIPT italic_i , italic_m end_POSTSUBSCRIPT bold_italic_z start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT# Weighted-sum of sparse experts
        𝜽i𝚝𝚛𝚂𝚝𝚘𝚙𝙶𝚛𝚊𝚍𝚜(𝜽i)subscriptsuperscript𝜽𝚝𝚛𝑖𝚂𝚝𝚘𝚙𝙶𝚛𝚊𝚍𝚜subscript𝜽𝑖\bm{\theta}^{\mathtt{tr}}_{i}\leftarrow\mathtt{StopGrads}(\bm{\theta}_{i})bold_italic_θ start_POSTSUPERSCRIPT typewriter_tr end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ← typewriter_StopGrads ( bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT )
        for k=1,2,,K𝑘12𝐾k=1,2,...,Kitalic_k = 1 , 2 , … , italic_K  do
            # Task-specific dense teacher
           𝜽i𝚝𝚛𝙶𝚛𝚊𝚍𝚒𝚎𝚗𝚝𝙳𝚎𝚜𝚌𝚎𝚗𝚝(𝜽i𝚝𝚛𝒯iq𝚌𝚎(𝜽i𝚝𝚛))subscriptsuperscript𝜽𝚝𝚛𝑖𝙶𝚛𝚊𝚍𝚒𝚎𝚗𝚝𝙳𝚎𝚜𝚌𝚎𝚗𝚝subscriptsubscriptsuperscript𝜽𝚝𝚛𝑖superscriptsubscriptsuperscriptsubscript𝒯𝑖𝑞𝚌𝚎subscriptsuperscript𝜽𝚝𝚛𝑖\bm{\theta}^{\mathtt{tr}}_{i}\leftarrow\mathtt{GradientDescent}(\nabla_{\bm{% \theta}^{\mathtt{tr}}_{i}}{\mathcal{L}}_{{\mathcal{T}}_{i}^{q}}^{\mathtt{ce}}(% \bm{\theta}^{\mathtt{tr}}_{i}))bold_italic_θ start_POSTSUPERSCRIPT typewriter_tr end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ← typewriter_GradientDescent ( ∇ start_POSTSUBSCRIPT bold_italic_θ start_POSTSUPERSCRIPT typewriter_tr end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_POSTSUBSCRIPT caligraphic_L start_POSTSUBSCRIPT caligraphic_T start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT end_POSTSUBSCRIPT start_POSTSUPERSCRIPT typewriter_ce end_POSTSUPERSCRIPT ( bold_italic_θ start_POSTSUPERSCRIPT typewriter_tr end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) )
        end for
        i:=β𝒯iq𝚌𝚎(𝜽i)+(1β)𝒯iq𝚔𝚍(𝜽i,𝜽i𝚝𝚛)assignsubscript𝑖𝛽subscriptsuperscript𝚌𝚎superscriptsubscript𝒯𝑖𝑞subscript𝜽𝑖1𝛽subscriptsuperscript𝚔𝚍superscriptsubscript𝒯𝑖𝑞subscript𝜽𝑖superscriptsubscript𝜽𝑖𝚝𝚛{\mathcal{L}}_{i}:=\beta{\mathcal{L}}^{\mathtt{ce}}_{\mathcal{T}_{i}^{q}}(\bm{% \theta}_{i})+(1-\beta)\mathcal{L}^{\mathtt{kd}}_{\mathcal{T}_{i}^{q}}(\bm{% \theta}_{i},\bm{\theta}_{i}^{\mathtt{tr}})caligraphic_L start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT := italic_β caligraphic_L start_POSTSUPERSCRIPT typewriter_ce end_POSTSUPERSCRIPT start_POSTSUBSCRIPT caligraphic_T start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT end_POSTSUBSCRIPT ( bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) + ( 1 - italic_β ) caligraphic_L start_POSTSUPERSCRIPT typewriter_kd end_POSTSUPERSCRIPT start_POSTSUBSCRIPT caligraphic_T start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT end_POSTSUBSCRIPT ( bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT typewriter_tr end_POSTSUPERSCRIPT )# task’s meta-loss, β𝛽\betaitalic_β is a weighting coefficient (0,1)absent01\in(0,1)∈ ( 0 , 1 )
     end for
     𝒞om:=λdim(ϕm)jdim(ϕm)Qϕm(sj0)τ,m[M]formulae-sequenceassign𝒞subscript𝑜𝑚𝜆dimsubscriptbold-italic-ϕ𝑚superscriptsubscript𝑗dimsubscriptbold-italic-ϕ𝑚subscript𝑄subscriptbold-italic-ϕ𝑚subscript𝑠𝑗0𝜏for-all𝑚delimited-[]𝑀{\mathcal{C}}o_{m}:=\frac{\lambda}{\text{\small dim}(\bm{\phi}_{m})}\sum_{j}^{% \text{dim}(\bm{\phi}_{m})}Q_{\bm{\phi}_{m}}(s_{j}\leq 0)-\tau,\ \forall m\in[M]caligraphic_C italic_o start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT := divide start_ARG italic_λ end_ARG start_ARG dim ( bold_italic_ϕ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT ) end_ARG ∑ start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT start_POSTSUPERSCRIPT dim ( bold_italic_ϕ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT ) end_POSTSUPERSCRIPT italic_Q start_POSTSUBSCRIPT bold_italic_ϕ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_POSTSUBSCRIPT ( italic_s start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT ≤ 0 ) - italic_τ , ∀ italic_m ∈ [ italic_M ]# sparsity loss for each expert
     [𝜻,𝚽,𝜽δ]𝙶𝚛𝚊𝚍𝚒𝚎𝚗𝚝𝙳𝚎𝚜𝚌𝚎𝚗𝚝([𝜻,𝚽,𝜽δ](1|𝔹|i=1|𝔹|i+m=1Mλm𝒞om))𝜻𝚽superscript𝜽𝛿𝙶𝚛𝚊𝚍𝚒𝚎𝚗𝚝𝙳𝚎𝚜𝚌𝚎𝚗𝚝subscript𝜻𝚽superscript𝜽𝛿1𝔹superscriptsubscript𝑖1𝔹subscript𝑖superscriptsubscript𝑚1𝑀subscript𝜆𝑚𝒞subscript𝑜𝑚[\bm{\zeta},\bm{\Phi},\bm{\theta}^{\delta}]\leftarrow\mathtt{GradientDescent}(% \nabla_{[\bm{\zeta},\bm{\Phi},\bm{\theta}^{\delta}]}\big{(}\frac{1}{|{\mathbb{% B}}|}\sum_{i=1}^{|{\mathbb{B}}|}\mathcal{L}_{i}+\sum_{m=1}^{M}\lambda_{m}{% \mathcal{C}}o_{m})\big{)}[ bold_italic_ζ , bold_Φ , bold_italic_θ start_POSTSUPERSCRIPT italic_δ end_POSTSUPERSCRIPT ] ← typewriter_GradientDescent ( ∇ start_POSTSUBSCRIPT [ bold_italic_ζ , bold_Φ , bold_italic_θ start_POSTSUPERSCRIPT italic_δ end_POSTSUPERSCRIPT ] end_POSTSUBSCRIPT ( divide start_ARG 1 end_ARG start_ARG | blackboard_B | end_ARG ∑ start_POSTSUBSCRIPT italic_i = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT | blackboard_B | end_POSTSUPERSCRIPT caligraphic_L start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT + ∑ start_POSTSUBSCRIPT italic_m = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_M end_POSTSUPERSCRIPT italic_λ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT caligraphic_C italic_o start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT ) )# Minimization problem in Eqn. (4.1)
     for m=1,2,,M𝑚12𝑀m=1,2,...,Mitalic_m = 1 , 2 , … , italic_M do
        if 𝒞om<0𝒞subscript𝑜𝑚0{\mathcal{C}}o_{m}<0caligraphic_C italic_o start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT < 0  then
            # expected expert sparsity is lower than the constraint τ𝜏\tauitalic_τ
           λm𝙶𝚛𝚊𝚍𝚒𝚎𝚗𝚝𝙰𝚜𝚌𝚎𝚗𝚝(λm𝒞omi𝔹αi,m|𝔹|)subscript𝜆𝑚𝙶𝚛𝚊𝚍𝚒𝚎𝚗𝚝𝙰𝚜𝚌𝚎𝚗𝚝subscriptsubscript𝜆𝑚𝒞subscript𝑜𝑚subscript𝑖𝔹subscript𝛼𝑖𝑚𝔹\lambda_{m}\leftarrow\mathtt{GradientAscent}(\nabla_{\lambda_{m}}{\mathcal{C}}% o_{m}\frac{\sum_{i\in\mathbb{B}}\alpha_{i,m}}{|\mathbb{B}|})italic_λ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT ← typewriter_GradientAscent ( ∇ start_POSTSUBSCRIPT italic_λ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_POSTSUBSCRIPT caligraphic_C italic_o start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT divide start_ARG ∑ start_POSTSUBSCRIPT italic_i ∈ blackboard_B end_POSTSUBSCRIPT italic_α start_POSTSUBSCRIPT italic_i , italic_m end_POSTSUBSCRIPT end_ARG start_ARG | blackboard_B | end_ARG )# Maximization problem in Eqn. (4.1)
        else
            # expected expert sparsity is at least τ𝜏\tauitalic_τ, the constraint is satisfied
           λm0subscript𝜆𝑚0\lambda_{m}\leftarrow 0italic_λ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT ← 0 # Reset the Lagrangian multiplier for the m𝑚mitalic_m-th constraint
        end if
     end for
  end while
  Return 𝚽,𝜻,𝜽𝜹𝚽𝜻subscript𝜽𝜹\bm{\Phi},\bm{\zeta},\bm{\theta_{\delta}}bold_Φ , bold_italic_ζ , bold_italic_θ start_POSTSUBSCRIPT bold_italic_δ end_POSTSUBSCRIPT

A.2 SMAT for meta-testing

We propose a heuristic in Alg. 2 for optimizing the task-specific expert selection during meta-testing time without the need for expensive gradient computation. Specifically, we restrict the normalized expert selection score estimated by our hypernetwork hζsubscript𝜁h_{\zeta}italic_h start_POSTSUBSCRIPT italic_ζ end_POSTSUBSCRIPT in binary states i.e., αi,m{0,1},m[M]formulae-sequencesuperscriptsubscript𝛼𝑖𝑚01for-all𝑚delimited-[]𝑀\alpha_{i,m}^{\prime}\in\{0,1\},\ \forall m\in[M]italic_α start_POSTSUBSCRIPT italic_i , italic_m end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ∈ { 0 , 1 } , ∀ italic_m ∈ [ italic_M ], and optimize αi,msubscript𝛼𝑖𝑚\alpha_{i,m}italic_α start_POSTSUBSCRIPT italic_i , italic_m end_POSTSUBSCRIPT in this binary space by minimizing the loss on the meta-testing task’s support set. The intuition behind this is that each expert is either needed or discarded for each few-shot learning tasks, which is supported by our empirical observations on αi,msuperscriptsubscript𝛼𝑖𝑚\alpha_{i,m}^{\prime}italic_α start_POSTSUBSCRIPT italic_i , italic_m end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT being very discrete (close to 0 or 1) in most cases.

Algorithm 2 SMAT: Meta-testing time gradient-free expert selection (for a single task 𝒯~isubscript~𝒯𝑖\tilde{\mathcal{T}}_{i}over~ start_ARG caligraphic_T end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT)
  Input: Testing support set 𝒯~issuperscriptsubscript~𝒯𝑖𝑠\tilde{\mathcal{T}}_{i}^{s}over~ start_ARG caligraphic_T end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_s end_POSTSUPERSCRIPT and query inputs 𝐗~iqsuperscriptsubscript~𝐗𝑖𝑞\tilde{\mathbf{X}}_{i}^{q}over~ start_ARG bold_X end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT, Meta-trained ζ,θδ,Φ𝜁superscript𝜃𝛿Φ\zeta,\theta^{\delta},\Phiitalic_ζ , italic_θ start_POSTSUPERSCRIPT italic_δ end_POSTSUPERSCRIPT , roman_Φ, Pre-trained θ𝚙𝚛𝚎superscript𝜃𝚙𝚛𝚎\theta^{\mathtt{pre}}italic_θ start_POSTSUPERSCRIPT typewriter_pre end_POSTSUPERSCRIPT
  𝜶i𝙷𝚊𝚛𝚍𝙶𝚞𝚖𝚋𝚕𝚎𝚂𝚒𝚐𝚖𝚘𝚒𝚍(hζ(𝒯~is))similar-tosuperscriptsubscript𝜶𝑖𝙷𝚊𝚛𝚍𝙶𝚞𝚖𝚋𝚕𝚎𝚂𝚒𝚐𝚖𝚘𝚒𝚍subscript𝜁subscriptsuperscript~𝒯𝑠𝑖\bm{\alpha}_{i}^{\prime}\sim\mathtt{HardGumbleSigmoid}(h_{\zeta}(\tilde{% \mathcal{T}}^{s}_{i}))bold_italic_α start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ∼ typewriter_HardGumbleSigmoid ( italic_h start_POSTSUBSCRIPT italic_ζ end_POSTSUBSCRIPT ( over~ start_ARG caligraphic_T end_ARG start_POSTSUPERSCRIPT italic_s end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) )# Initialize expert merging scores using the hypernetwork and round to [0,1]
  ll*italic_l ∗ = positive infinity  # use to record the lowest support loss during exploration
  𝒛mqϕm,m[M]formulae-sequencesimilar-tosubscript𝒛𝑚subscript𝑞subscriptitalic-ϕ𝑚for-all𝑚delimited-[]𝑀\bm{z}_{m}\sim q_{\phi_{m}},\ \forall m\in[M]bold_italic_z start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT ∼ italic_q start_POSTSUBSCRIPT italic_ϕ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_POSTSUBSCRIPT , ∀ italic_m ∈ [ italic_M ]# Sample sparse masks once at the start
  for r=1,2,,R𝑟12𝑅r=1,2,...,Ritalic_r = 1 , 2 , … , italic_R  do
      # repeat for R rounds of sampling
     for m=1,2,,M𝑚12𝑀m=1,2,...,Mitalic_m = 1 , 2 , … , italic_M do
         # iterate through each score in αisubscript𝛼𝑖\alpha_{i}italic_α start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT
        Flip αi,m, 01superscriptsubscript𝛼𝑖𝑚 01\alpha_{i,m}^{\prime},\ 0\leftrightarrow 1italic_α start_POSTSUBSCRIPT italic_i , italic_m end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT , 0 ↔ 1# Generate candidate score for the m𝑚mitalic_m-th expert
        αi,m=αi,mmMαi,msubscript𝛼𝑖𝑚superscriptsubscript𝛼𝑖𝑚superscriptsubscript𝑚𝑀superscriptsubscript𝛼𝑖𝑚\alpha_{i,m}=\frac{\alpha_{i,m}^{\prime}}{\sum_{m}^{M}\alpha_{i,m}^{\prime}}italic_α start_POSTSUBSCRIPT italic_i , italic_m end_POSTSUBSCRIPT = divide start_ARG italic_α start_POSTSUBSCRIPT italic_i , italic_m end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT end_ARG start_ARG ∑ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_M end_POSTSUPERSCRIPT italic_α start_POSTSUBSCRIPT italic_i , italic_m end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT end_ARG# Normalize the merging weights
        𝜽i=𝜽𝚙𝚛𝚎+𝜽δmMαi,m𝒛msubscript𝜽𝑖superscript𝜽𝚙𝚛𝚎direct-productsuperscript𝜽𝛿superscriptsubscript𝑚𝑀subscript𝛼𝑖𝑚subscript𝒛𝑚\bm{\theta}_{i}=\bm{\theta}^{\mathtt{pre}}+\bm{\theta}^{\delta}\odot\sum_{m}^{% M}\alpha_{i,m}\bm{z}_{m}bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT = bold_italic_θ start_POSTSUPERSCRIPT typewriter_pre end_POSTSUPERSCRIPT + bold_italic_θ start_POSTSUPERSCRIPT italic_δ end_POSTSUPERSCRIPT ⊙ ∑ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_M end_POSTSUPERSCRIPT italic_α start_POSTSUBSCRIPT italic_i , italic_m end_POSTSUBSCRIPT bold_italic_z start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT# Weighted-sum of sparse experts
        ~i:=𝒯~is𝚌𝚎(𝜽i)assignsubscript~𝑖subscriptsuperscript𝚌𝚎superscriptsubscript~𝒯𝑖𝑠subscript𝜽𝑖\tilde{\mathcal{L}}_{i}:={\mathcal{L}}^{\mathtt{ce}}_{\tilde{\mathcal{T}}_{i}^% {s}}(\bm{\theta}_{i})over~ start_ARG caligraphic_L end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT := caligraphic_L start_POSTSUPERSCRIPT typewriter_ce end_POSTSUPERSCRIPT start_POSTSUBSCRIPT over~ start_ARG caligraphic_T end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_s end_POSTSUPERSCRIPT end_POSTSUBSCRIPT ( bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT )# Evaluate the support loss which only requires forward passes
        if ~i<lsubscript~𝑖𝑙\tilde{\mathcal{L}}_{i}<lover~ start_ARG caligraphic_L end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT < italic_l  then
            # Rejection sampling
           Accept the candidate αi,msubscript𝛼𝑖𝑚\alpha_{i,m}italic_α start_POSTSUBSCRIPT italic_i , italic_m end_POSTSUBSCRIPT with ρ𝜌\rhoitalic_ρ and record l=l=~il*=l=\tilde{\mathcal{L}}_{i}italic_l ∗ = italic_l = over~ start_ARG caligraphic_L end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT, otherwise reject
        else
           Accept the candidate αi,msubscript𝛼𝑖𝑚\alpha_{i,m}italic_α start_POSTSUBSCRIPT italic_i , italic_m end_POSTSUBSCRIPT with (1ρ)1𝜌(1-\rho)( 1 - italic_ρ ) and record l=~i𝑙subscript~𝑖l=\tilde{\mathcal{L}}_{i}italic_l = over~ start_ARG caligraphic_L end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT otherwise reject
        end if
     end for
  end for
  Return: 𝜶𝒊subscript𝜶𝒊\bm{\alpha_{i}}bold_italic_α start_POSTSUBSCRIPT bold_italic_i end_POSTSUBSCRIPT at the lowest support loss lsuperscript𝑙l^{*}italic_l start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT, Which is then used for final prediction on the query 𝐘^iq=f(𝐗~iq;𝜽i=𝜽𝚙𝚛𝚎+𝜽δmMαi,m𝒛m\hat{\mathbf{Y}}^{q}_{i}=f(\tilde{\mathbf{X}}_{i}^{q};\bm{\theta}_{i}=\bm{% \theta}^{\mathtt{pre}}+\bm{\theta}^{\delta}\odot\sum_{m}^{M}\alpha_{i,m}\bm{z}% _{m}over^ start_ARG bold_Y end_ARG start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT = italic_f ( over~ start_ARG bold_X end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT ; bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT = bold_italic_θ start_POSTSUPERSCRIPT typewriter_pre end_POSTSUPERSCRIPT + bold_italic_θ start_POSTSUPERSCRIPT italic_δ end_POSTSUPERSCRIPT ⊙ ∑ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_M end_POSTSUPERSCRIPT italic_α start_POSTSUBSCRIPT italic_i , italic_m end_POSTSUBSCRIPT bold_italic_z start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT.
Algorithm 3 SMAT: Meta-testing time full fine-tuning using θisubscript𝜃𝑖\theta_{i}italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT as an model initialization (for a single task 𝒯~isubscript~𝒯𝑖\tilde{\mathcal{T}}_{i}over~ start_ARG caligraphic_T end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT)
  Input: Testing support set 𝒯~issuperscriptsubscript~𝒯𝑖𝑠\tilde{\mathcal{T}}_{i}^{s}over~ start_ARG caligraphic_T end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_s end_POSTSUPERSCRIPT and query inputs 𝐗~iqsuperscriptsubscript~𝐗𝑖𝑞\tilde{\mathbf{X}}_{i}^{q}over~ start_ARG bold_X end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT; Meta-trained ζ,θδ,Φ𝜁superscript𝜃𝛿Φ\zeta,\theta^{\delta},\Phiitalic_ζ , italic_θ start_POSTSUPERSCRIPT italic_δ end_POSTSUPERSCRIPT , roman_Φ, Pre-trained θ𝚙𝚛𝚎superscript𝜃𝚙𝚛𝚎\theta^{\mathtt{pre}}italic_θ start_POSTSUPERSCRIPT typewriter_pre end_POSTSUPERSCRIPT
  𝜶i𝙶𝚞𝚖𝚋𝚕𝚎𝚂𝚒𝚐𝚖𝚘𝚒𝚍(hζ(𝒯~is))similar-tosuperscriptsubscript𝜶𝑖𝙶𝚞𝚖𝚋𝚕𝚎𝚂𝚒𝚐𝚖𝚘𝚒𝚍subscript𝜁subscriptsuperscript~𝒯𝑠𝑖\bm{\alpha}_{i}^{\prime}\sim\mathtt{GumbleSigmoid}(h_{\zeta}(\tilde{\mathcal{T% }}^{s}_{i}))bold_italic_α start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ∼ typewriter_GumbleSigmoid ( italic_h start_POSTSUBSCRIPT italic_ζ end_POSTSUBSCRIPT ( over~ start_ARG caligraphic_T end_ARG start_POSTSUPERSCRIPT italic_s end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) )# Obtain expert merging scores using the hypernetwork
  αi,m=αi,mmMαi,msubscript𝛼𝑖𝑚superscriptsubscript𝛼𝑖𝑚superscriptsubscript𝑚𝑀superscriptsubscript𝛼𝑖𝑚\alpha_{i,m}=\frac{\alpha_{i,m}^{\prime}}{\sum_{m}^{M}\alpha_{i,m}^{\prime}}italic_α start_POSTSUBSCRIPT italic_i , italic_m end_POSTSUBSCRIPT = divide start_ARG italic_α start_POSTSUBSCRIPT italic_i , italic_m end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT end_ARG start_ARG ∑ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_M end_POSTSUPERSCRIPT italic_α start_POSTSUBSCRIPT italic_i , italic_m end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT end_ARG# Normalize the merging weights
  𝜽i=𝜽𝚙𝚛𝚎+𝜽δmMαi,m𝒛msubscript𝜽𝑖superscript𝜽𝚙𝚛𝚎direct-productsuperscript𝜽𝛿superscriptsubscript𝑚𝑀subscript𝛼𝑖𝑚subscript𝒛𝑚\bm{\theta}_{i}=\bm{\theta}^{\mathtt{pre}}+\bm{\theta}^{\delta}\odot\sum_{m}^{% M}\alpha_{i,m}\bm{z}_{m}bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT = bold_italic_θ start_POSTSUPERSCRIPT typewriter_pre end_POSTSUPERSCRIPT + bold_italic_θ start_POSTSUPERSCRIPT italic_δ end_POSTSUPERSCRIPT ⊙ ∑ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_M end_POSTSUPERSCRIPT italic_α start_POSTSUBSCRIPT italic_i , italic_m end_POSTSUBSCRIPT bold_italic_z start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT# Weighted-sum of sparse experts
  𝜽i,0𝚂𝚝𝚘𝚙𝙶𝚛𝚊𝚍𝚜(𝜽i)subscript𝜽𝑖0𝚂𝚝𝚘𝚙𝙶𝚛𝚊𝚍𝚜subscript𝜽𝑖\bm{\theta}_{i,0}\leftarrow\mathtt{StopGrads}(\bm{\theta}_{i})bold_italic_θ start_POSTSUBSCRIPT italic_i , 0 end_POSTSUBSCRIPT ← typewriter_StopGrads ( bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT )
  for k=0,2,(K1)𝑘02𝐾1k=0,2,...(K-1)italic_k = 0 , 2 , … ( italic_K - 1 ) do
     𝜽i,k+1𝙶𝚛𝚊𝚍𝚒𝚎𝚗𝚝𝙳𝚎𝚜𝚌𝚎𝚗𝚝(𝜽i,k𝒯~is𝚌𝚎𝜽i,k)subscript𝜽𝑖𝑘1𝙶𝚛𝚊𝚍𝚒𝚎𝚗𝚝𝙳𝚎𝚜𝚌𝚎𝚗𝚝subscriptsubscript𝜽𝑖𝑘superscriptsubscriptsuperscriptsubscript~𝒯𝑖𝑠𝚌𝚎subscript𝜽𝑖𝑘\bm{\theta}_{i,k+1}\leftarrow\mathtt{GradientDescent}(\nabla_{\bm{\theta}_{i,k% }}{\mathcal{L}}_{\tilde{\mathcal{T}}_{i}^{s}}^{\mathtt{ce}}\bm{\theta}_{i,k})bold_italic_θ start_POSTSUBSCRIPT italic_i , italic_k + 1 end_POSTSUBSCRIPT ← typewriter_GradientDescent ( ∇ start_POSTSUBSCRIPT bold_italic_θ start_POSTSUBSCRIPT italic_i , italic_k end_POSTSUBSCRIPT end_POSTSUBSCRIPT caligraphic_L start_POSTSUBSCRIPT over~ start_ARG caligraphic_T end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_s end_POSTSUPERSCRIPT end_POSTSUBSCRIPT start_POSTSUPERSCRIPT typewriter_ce end_POSTSUPERSCRIPT bold_italic_θ start_POSTSUBSCRIPT italic_i , italic_k end_POSTSUBSCRIPT )# finetune θisubscript𝜃𝑖\theta_{i}italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT on the support set for K𝐾Kitalic_K steps
  end for
  Return: Final prediction on the query set 𝐘^iq=f(𝐗~iq;𝜽i,K)subscriptsuperscript^𝐘𝑞𝑖𝑓superscriptsubscript~𝐗𝑖𝑞subscript𝜽𝑖𝐾\hat{\mathbf{Y}}^{q}_{i}=f(\tilde{\mathbf{X}}_{i}^{q};\bm{\theta}_{i,K})over^ start_ARG bold_Y end_ARG start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT = italic_f ( over~ start_ARG bold_X end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT ; bold_italic_θ start_POSTSUBSCRIPT italic_i , italic_K end_POSTSUBSCRIPT )

Appendix B Experiment details

B.1 Implementation of baselines

PMF (Hu et al., 2022).    In Tab. 1, we report published results by PMF in their paper (Hu et al., 2022) whenever they are available. For extra meta-testing datasets in Tab. 1, which were not included in their paper, we produce these results using the official meta-trained PMF model checkpoint, which is publicly available with their code on Github.

B.2 Details for meta-testing

W/O fine-tuning.    To perform direct few-shot inference for a given testing task, 𝒯~i:={𝐗~is,𝐘~is,𝐗~iq}assignsubscript~𝒯𝑖superscriptsubscript~𝐗𝑖𝑠subscriptsuperscript~𝐘𝑠𝑖superscriptsubscript~𝐗𝑖𝑞\tilde{{\mathcal{T}}}_{i}:=\{\tilde{\mathbf{X}}_{i}^{s},\tilde{\mathbf{Y}}^{s}% _{i},\tilde{\mathbf{X}}_{i}^{q}\}over~ start_ARG caligraphic_T end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT := { over~ start_ARG bold_X end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_s end_POSTSUPERSCRIPT , over~ start_ARG bold_Y end_ARG start_POSTSUPERSCRIPT italic_s end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , over~ start_ARG bold_X end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT }, we consider using the prototypical network (Snell et al., 2017), which first constructs class centroids in the feature space of the model using the labeled support set, before performing the nearest centroid classification on the query input.

Denote the feature backbone (e.g., a pre-trained or meta-tuned ViT) by f𝜽:𝒳D:subscript𝑓𝜽𝒳superscript𝐷f_{\bm{\theta}}:{\mathcal{X}}\rightarrow\mathbb{R}^{D}italic_f start_POSTSUBSCRIPT bold_italic_θ end_POSTSUBSCRIPT : caligraphic_X → blackboard_R start_POSTSUPERSCRIPT italic_D end_POSTSUPERSCRIPT parameterized by 𝜽𝜽\bm{\theta}bold_italic_θ, which essentially is a map** from the input pixel space to a vector space of dimension D𝐷Ditalic_D. The centroid for each unique class 𝐜k,k=[1,2,3K]subscript𝐜𝑘𝑘123𝐾\mathbf{c}_{k},k=[1,2,3...K]bold_c start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT , italic_k = [ 1 , 2 , 3 … italic_K ] in 𝒯~isubscript~𝒯𝑖\tilde{\mathcal{T}}_{i}over~ start_ARG caligraphic_T end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT is calculated from the support set by:

𝐜k=1|j:yi,js=ck|j:yi,js=ckf𝜽(𝐱i,js).\mathbf{c}_{k}=\frac{1}{|j:y_{i,j}^{s}=c_{k}|}\sum_{j:y_{i,j}^{s}=c_{k}}f_{\bm% {\theta}}(\mathbf{x}_{i,j}^{s}).bold_c start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT = divide start_ARG 1 end_ARG start_ARG | italic_j : italic_y start_POSTSUBSCRIPT italic_i , italic_j end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_s end_POSTSUPERSCRIPT = italic_c start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT | end_ARG ∑ start_POSTSUBSCRIPT italic_j : italic_y start_POSTSUBSCRIPT italic_i , italic_j end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_s end_POSTSUPERSCRIPT = italic_c start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT end_POSTSUBSCRIPT italic_f start_POSTSUBSCRIPT bold_italic_θ end_POSTSUBSCRIPT ( bold_x start_POSTSUBSCRIPT italic_i , italic_j end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_s end_POSTSUPERSCRIPT ) .

To this end, the predicted label for each query input is given by

p(y=k|𝐱q)=exp(d(f𝜽(𝐱q),𝐜k))k=[1,2..]Kexp(d(f𝜽(𝐱q),𝐜k)),p(y=k|\mathbf{x}^{q})=\frac{\text{exp}(-d(f_{\bm{\theta}}(\mathbf{x}^{q}),% \mathbf{c}_{k}))}{\sum_{k^{\prime}=[1,2..]}^{K}\text{exp}(-d(f_{\bm{\theta}}(% \mathbf{x}^{q}),\mathbf{c}_{k}^{\prime}))},italic_p ( italic_y = italic_k | bold_x start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT ) = divide start_ARG exp ( - italic_d ( italic_f start_POSTSUBSCRIPT bold_italic_θ end_POSTSUBSCRIPT ( bold_x start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT ) , bold_c start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT ) ) end_ARG start_ARG ∑ start_POSTSUBSCRIPT italic_k start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT = [ 1 , 2 . . ] end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_K end_POSTSUPERSCRIPT exp ( - italic_d ( italic_f start_POSTSUBSCRIPT bold_italic_θ end_POSTSUBSCRIPT ( bold_x start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT ) , bold_c start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ) ) end_ARG ,

where d(,)𝑑d(\cdot,\cdot)italic_d ( ⋅ , ⋅ ) is some distance metric e.g., cosine distance.

Appendix C Additional experimental results

C.1 Meta-dataset meta-testing results for meta-tuned supervised pre-trained backbones

Table 4 provides the results of the meta-tuned models on MD using PMF (Hu et al., 2022) and SMAT with Sup21K-ViT-Small backbone.

SMAT attains the best performance.     Although using a different pre-trained backbone, we again observe that SMAT outperforms all baselines in both ID and OOD meta-testing. These results, together with our main results in Tab. 1, validate the efficay of our approach.

Comparing to few-shot meta-testing results for meta-tuned DINO backbone. Meta-tuning on self-supervised and supervised pre-trained backbones produces vastly different generalization results on OOD tasks. While for both PMF and ours, meta-tuning with the Sup21K backbone generally improves few-shot testing performance over the pre-trained backbone even on unseen tasks and domains, meta-tuning with a self-supervised pre-trained backbone (e.g., DINO) requires taking more care in the design of meta-tuning strategy - noticeably, naively meta-tuning with PMF can be outperformed by simple pre-trained + fine-tuned baselines, particularly on OOD tasks; in contrast, SMAT, still maintains the high transferability in its meta-tuned feature representation which leads to better, if not better, fine-tuning performance when using ours meta-tuned model as the fine-tune initialization.

Table 4: Few-shot testing results on the Meta-dataset benchmark and additional OOD testing datasets for methods using Sup21k-ViT-Small backbone (Dosovitskiy et al., 2020). SMAT indicates our method.
Gradient-free Gradient-based
Meta-dataset Pre PMF Ours Pre+full PM+full SMAT+full Pre+LoRA PMF+LoRA SMAT+LoRA
ImageNet 68.45 79.57 81.39 78.96 80.37 82.10 76.30 80.21 81.89
Aircraft 52.57 85.55 87.05 83.20 88.48 89.16 80.73 87.40 88.40
Omniglot 37.02 86.63 86.14 78.25 88.58 88.56 73.75 88.05 87.91
CUB 84.65 94.72 94.98 91.20 94.72 94.98 90.18 94.72 94.98
DTD 80.75 84.84 86.27 87.20 88.60 89.05 86.59 87.85 88.50
Quickdraw 55.05 78.55 79.35 75.75 80.01 79.36 74.07 78.55 79.35
Fungi 44.20 73.02 74.50 56.29 73.02 74.50 55.88 73.02 74.50
VGGFlower 94.11 98.97 99.09 98.01 99.09 99.19 96.97 99.06 99.14
Avg ID 64.60 85.23 86.10 81.11 86.61 87.11 79.31 86.11 86.83
TrafficSig 48.14 55.80 60.10 90.02 90.13 90.13 89.04 88.89 89.52
MSCOCO 52.39 63.77 63.91 64.64 67.02 68.19 64.08 67.12 67.65
Cifar10 79.33 87.50 91.10 93.40 93.78 94.87 92.80 92.31 93.62
Cifar100 68.53 79.11 82.02 88.54 88.81 90.32 87.69 88.81 88.80
MNIST 73.53 93.90 94.01 95.16 96.70 96.76 94.82 96.51 96.67
Avg OOD 64.38 76.02 78.23 86.35 87.29 88.05 85.69 86.76 87.25

C.2 Different fine-tuning strategies for SMAT at meta-testing

Performance.    In Tab 5 below, we evaluate various meta-testing fine-tuning strategies for SMAT and compare their performance. We first note the effectiveness of our proposed gradient-free expert selection method (see Section 4.2), as evidenced by its improved performance compared to directly using SMAT with ProtoNet (Snell et al., 2017) for meta-testing. Second, using θisubscript𝜃𝑖\theta_{i}italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT as the initialization for full fine-tuning, which has the same capacity as the pre-trained model θ𝚙𝚛𝚎superscript𝜃𝚙𝚛𝚎\theta^{\mathtt{pre}}italic_θ start_POSTSUPERSCRIPT typewriter_pre end_POSTSUPERSCRIPT, leads to improved performance over fine-tuning the full SMAT model jointly (i.e., (4)), as it is sufficiently expressive while much more parameter-efficient than the latter hence avoids potential over-fitting issues. We provide more explaination as follows.

Gradient-based fine-tuning outperforms gradient-free fine-tuning primarily because it is much more flexible. The extra flexibility of gradient-based fine-tuning (Alg. 3, Appendix) stems from the fact that the entire model θisubscript𝜃𝑖\theta_{i}italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT is allowed to be updated, whereas gradient-free fine-tuning (Alg. 2, Appendix) only allows updates on the expert selection weights αisubscript𝛼𝑖\alpha_{i}italic_α start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT - with dimensions equal to the number of experts in SMAT; while all other parameters remain frozen. That said, the hypothesis space of our gradient-free fine-tuning algorithm is much more tightly constrained around the model meta-tuned on ID tasks. As a result, the effectiveness of gradient-free fine-tuning is limited on certain OOD tasks that exhibit noticeable distribution shift, such as TrafficSign. In such cases, the meta-tuned ID model may become inadequate, requiring significant parameter updates.

Table 5: Different fine-tuning strategies for SMAT on a subset of few-shot testing tasks. (1): Direct inference by ProtoNet, reported in Tab. 1 as SMAT; (2): Gradient-free expert selection (Alg. 2); (3) Full fine-tuning using θisubscript𝜃𝑖\theta_{i}italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT as initialization (Alg. 3), reported in Tab. 1 as SMAT+Full. (4): Full fine-tuning the entire SMAT model, i.e., 𝜻,𝜽𝜹,𝚽𝜻superscript𝜽𝜹𝚽\bm{\zeta},\bm{\theta^{\delta}},\bm{\Phi}bold_italic_ζ , bold_italic_θ start_POSTSUPERSCRIPT bold_italic_δ end_POSTSUPERSCRIPT , bold_Φ jointly.
Gradient-free Gradient-based
Datasets (1) (2) (3) (4)
TrafficSign 58.51 59.59 90.83 89.93
MSCOCO 57.35 58.78 63.07 62.76
Cifar10 83.95 86.95 92.08 91.97
Cifar100 74.85 77.20 85.91 85.95
MNIST 94.53 94.63 96.73 96.70

Computational cost.     We present results below in Tab. 6 for a quantitative comparison of computational cost, in terms of time and GPU memory, between w/o fine-tuning, gradient-free fine-tuning, and gradient-based full fine-tuning for SMAT. All fine-tuning are carried out in FP16 mixed precision. In particular, our gradient-free fine-tuning method offers a +1.591.59+1.59+ 1.59% improvement on average over w/o fine-tuning at no additional memory cost, while saves 3/4343/43 / 4 of the total memory cost of gradient-based fine-tuning. Gradient-based fine-tuning, however, outperforms both by at least +10.3110.31+10.31+ 10.31% despite requiring 4x the GPU memory of both, and 2x the time cost compared to gradient-free fine-tuning.

Table 6: Computational cost for fine-tuning a SMAT model (ViT-DINO-Small backbone) at meta-testing time with different fine-tuning strategies.
Method Time (sec./task) GPU memory (MiB) Avg. Acc (Tab. 5)
w/o fine-tuning (ProtoNet) 0.2 4332 73.83
Alg.( 2). gradient-free fine-tuning (50steps) 6.4 4332 75.43
Alg.( 3). gradient-based full fine-tuning (50steps) 11.8 17264 85.74

C.3 Performance vs number of parameters.

Details on parameter counts: SMAT: We use a very naive compression scheme to remove exact zeros in our model. We use a (𝚟𝚊𝚕𝚞𝚎,𝚙𝚘𝚜𝚒𝚝𝚒𝚘𝚗)𝚟𝚊𝚕𝚞𝚎𝚙𝚘𝚜𝚒𝚝𝚒𝚘𝚗(\mathtt{value},\mathtt{position})( typewriter_value , typewriter_position ) tuple to represent each non-zero parameter in our model after flattening all parameters in a single long vector. Thus, the total number of parameters left in the experts is equal to two-times the number of non-zero parameters remained at the end of meta-tuning. We point out that there are perhaps more memory efficient ways for representing sparse weights e.g., PyTorch sparse tensors, which could potentially result in a more significant saving in terms of number of binary bits.

Refer to caption
Figure 7: Average testing performance of models vs model parameter counts. We observe SMAT attains the best overall performance while requiring less number of parameters comparing to other MoE implementation variants including SoftMoE (Puigcerver et al., 2023), and incorporating experts in the MLP layers in ViT (SMAT MLP experts).

Appendix D Sparse interpolated experts

As previously stated, by rearranging Eqn. (1) slightly, the experts now essentially become different sparse interpolations between the same pre-trained and meta-tuned models. Here are the details.

Starting from Eqn. (1), we have:

𝜽isubscript𝜽𝑖\displaystyle\bm{\theta}_{i}bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT =𝜽𝚙𝚛𝚎+m=1||αi,m𝜽mδabsentsuperscript𝜽𝚙𝚛𝚎superscriptsubscript𝑚1subscript𝛼𝑖𝑚subscriptsuperscript𝜽𝛿𝑚\displaystyle=\bm{\theta}^{\mathtt{pre}}+\sum_{m=1}^{|{\mathcal{M}}|}\alpha_{i% ,m}\bm{\theta}^{\delta}_{m}= bold_italic_θ start_POSTSUPERSCRIPT typewriter_pre end_POSTSUPERSCRIPT + ∑ start_POSTSUBSCRIPT italic_m = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT | caligraphic_M | end_POSTSUPERSCRIPT italic_α start_POSTSUBSCRIPT italic_i , italic_m end_POSTSUBSCRIPT bold_italic_θ start_POSTSUPERSCRIPT italic_δ end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT (3)
=𝜽𝚙𝚛𝚎+m=1||αi,m(𝒛m𝜽δ)absentsuperscript𝜽𝚙𝚛𝚎superscriptsubscript𝑚1subscript𝛼𝑖𝑚direct-productsubscript𝒛𝑚superscript𝜽𝛿\displaystyle=\bm{\theta}^{\mathtt{pre}}+\sum_{m=1}^{|{\mathcal{M}}|}\alpha_{i% ,m}(\bm{z}_{m}\odot\bm{\theta}^{\delta})= bold_italic_θ start_POSTSUPERSCRIPT typewriter_pre end_POSTSUPERSCRIPT + ∑ start_POSTSUBSCRIPT italic_m = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT | caligraphic_M | end_POSTSUPERSCRIPT italic_α start_POSTSUBSCRIPT italic_i , italic_m end_POSTSUBSCRIPT ( bold_italic_z start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT ⊙ bold_italic_θ start_POSTSUPERSCRIPT italic_δ end_POSTSUPERSCRIPT ) (4)
=𝜽𝚙𝚛𝚎+m=1||αi,m(𝒛m𝜽δ+𝚙𝚛𝚎)m=1||αi,m(𝒛m𝜽𝚙𝚛𝚎)absentsuperscript𝜽𝚙𝚛𝚎superscriptsubscript𝑚1subscript𝛼𝑖𝑚direct-productsubscript𝒛𝑚superscript𝜽𝛿𝚙𝚛𝚎superscriptsubscript𝑚1subscript𝛼𝑖𝑚direct-productsubscript𝒛𝑚superscript𝜽𝚙𝚛𝚎\displaystyle=\bm{\theta}^{\mathtt{pre}}+\sum_{m=1}^{|{\mathcal{M}}|}\alpha_{i% ,m}(\bm{z}_{m}\odot\bm{\theta}^{\delta+\mathtt{pre}})-\sum_{m=1}^{|{\mathcal{M% }}|}\alpha_{i,m}(\bm{z}_{m}\odot\bm{\theta}^{\mathtt{pre}})= bold_italic_θ start_POSTSUPERSCRIPT typewriter_pre end_POSTSUPERSCRIPT + ∑ start_POSTSUBSCRIPT italic_m = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT | caligraphic_M | end_POSTSUPERSCRIPT italic_α start_POSTSUBSCRIPT italic_i , italic_m end_POSTSUBSCRIPT ( bold_italic_z start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT ⊙ bold_italic_θ start_POSTSUPERSCRIPT italic_δ + typewriter_pre end_POSTSUPERSCRIPT ) - ∑ start_POSTSUBSCRIPT italic_m = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT | caligraphic_M | end_POSTSUPERSCRIPT italic_α start_POSTSUBSCRIPT italic_i , italic_m end_POSTSUBSCRIPT ( bold_italic_z start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT ⊙ bold_italic_θ start_POSTSUPERSCRIPT typewriter_pre end_POSTSUPERSCRIPT ) (5)
=m=1||αi,m(𝟏𝜽𝚙𝚛𝚎)+m=1||αi,m(𝒛m𝜽δ+𝚙𝚛𝚎)m=1||αi,m(𝒛m𝜽𝚙𝚛𝚎)absentsuperscriptsubscript𝑚1subscript𝛼𝑖𝑚direct-product1superscript𝜽𝚙𝚛𝚎superscriptsubscript𝑚1subscript𝛼𝑖𝑚direct-productsubscript𝒛𝑚superscript𝜽𝛿𝚙𝚛𝚎superscriptsubscript𝑚1subscript𝛼𝑖𝑚direct-productsubscript𝒛𝑚superscript𝜽𝚙𝚛𝚎\displaystyle=\sum_{m=1}^{|{\mathcal{M}}|}\alpha_{i,m}(\bm{1}\odot\bm{\theta}^% {\mathtt{pre}})+\sum_{m=1}^{|{\mathcal{M}}|}\alpha_{i,m}(\bm{z}_{m}\odot\bm{% \theta}^{\delta+\mathtt{pre}})-\sum_{m=1}^{|{\mathcal{M}}|}\alpha_{i,m}(\bm{z}% _{m}\odot\bm{\theta}^{\mathtt{pre}})= ∑ start_POSTSUBSCRIPT italic_m = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT | caligraphic_M | end_POSTSUPERSCRIPT italic_α start_POSTSUBSCRIPT italic_i , italic_m end_POSTSUBSCRIPT ( bold_1 ⊙ bold_italic_θ start_POSTSUPERSCRIPT typewriter_pre end_POSTSUPERSCRIPT ) + ∑ start_POSTSUBSCRIPT italic_m = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT | caligraphic_M | end_POSTSUPERSCRIPT italic_α start_POSTSUBSCRIPT italic_i , italic_m end_POSTSUBSCRIPT ( bold_italic_z start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT ⊙ bold_italic_θ start_POSTSUPERSCRIPT italic_δ + typewriter_pre end_POSTSUPERSCRIPT ) - ∑ start_POSTSUBSCRIPT italic_m = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT | caligraphic_M | end_POSTSUPERSCRIPT italic_α start_POSTSUBSCRIPT italic_i , italic_m end_POSTSUBSCRIPT ( bold_italic_z start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT ⊙ bold_italic_θ start_POSTSUPERSCRIPT typewriter_pre end_POSTSUPERSCRIPT ) (6)
=m=1||αi,m((𝟏𝒛m)𝜽𝚙𝚛𝚎+𝒛m𝜽δ+𝚙𝚛𝚎).absentsuperscriptsubscript𝑚1subscript𝛼𝑖𝑚direct-product1subscript𝒛𝑚superscript𝜽𝚙𝚛𝚎direct-productsubscript𝒛𝑚superscript𝜽𝛿𝚙𝚛𝚎\displaystyle=\sum_{m=1}^{|{\mathcal{M}}|}\alpha_{i,m}((\bm{1}-\bm{z}_{m})% \odot\bm{\theta}^{\mathtt{pre}}+\bm{z}_{m}\odot\bm{\theta}^{\delta+\mathtt{pre% }}).= ∑ start_POSTSUBSCRIPT italic_m = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT | caligraphic_M | end_POSTSUPERSCRIPT italic_α start_POSTSUBSCRIPT italic_i , italic_m end_POSTSUBSCRIPT ( ( bold_1 - bold_italic_z start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT ) ⊙ bold_italic_θ start_POSTSUPERSCRIPT typewriter_pre end_POSTSUPERSCRIPT + bold_italic_z start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT ⊙ bold_italic_θ start_POSTSUPERSCRIPT italic_δ + typewriter_pre end_POSTSUPERSCRIPT ) . (7)

We have assumed that m=1||αi,m=1superscriptsubscript𝑚1subscript𝛼𝑖𝑚1\sum_{m=1}^{|{\mathcal{M}}|}\alpha_{i,m}=1∑ start_POSTSUBSCRIPT italic_m = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT | caligraphic_M | end_POSTSUPERSCRIPT italic_α start_POSTSUBSCRIPT italic_i , italic_m end_POSTSUBSCRIPT = 1 which we have ensured through normalizing the expert activation in Alg. 1. The result shows that each task model 𝜽isubscript𝜽𝑖\bm{\theta}_{i}bold_italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT can now be interpreted as a weighted sum of different experts, where each expert (i.e.,(𝟏𝒛m)𝜽𝚙𝚛𝚎+𝒛m𝜽δ+𝚙𝚛𝚎)(\bm{1}-\bm{z}_{m})\odot\bm{\theta}^{\mathtt{pre}}+\bm{z}_{m}\odot\bm{\theta}^% {\delta+\mathtt{pre}})( bold_1 - bold_italic_z start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT ) ⊙ bold_italic_θ start_POSTSUPERSCRIPT typewriter_pre end_POSTSUPERSCRIPT + bold_italic_z start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT ⊙ bold_italic_θ start_POSTSUPERSCRIPT italic_δ + typewriter_pre end_POSTSUPERSCRIPT )) is a sparse interpolation between pre-trained θ𝚙𝚛𝚎superscript𝜃𝚙𝚛𝚎\theta^{\mathtt{pre}}italic_θ start_POSTSUPERSCRIPT typewriter_pre end_POSTSUPERSCRIPT and meta-tuned θ𝚙𝚛𝚎+δsuperscript𝜃𝚙𝚛𝚎𝛿\theta^{\mathtt{pre}+\delta}italic_θ start_POSTSUPERSCRIPT typewriter_pre + italic_δ end_POSTSUPERSCRIPT models in the parameter space.