License: arXiv.org perpetual non-exclusive license
arXiv:2312.13486v2 [cs.LG] 18 Jan 2024

Meta-Learning with Versatile Loss Geometries
for Fast Adaptation Using Mirror Descent

Abstract

Utilizing task-invariant prior knowledge extracted from related tasks, meta-learning is a principled framework that empowers learning a new task especially when data records are limited. A fundamental challenge in meta-learning is how to quickly “adapt” the extracted prior in order to train a task-specific model within a few optimization steps. Existing approaches deal with this challenge using a preconditioner that enhances convergence of the per-task training process. Though effective in representing locally a quadratic training loss, these simple linear preconditioners can hardly capture complex loss geometries. The present contribution addresses this limitation by learning a nonlinear mirror map, which induces a versatile distance metric to enable capturing and optimizing a wide range of loss geometries, hence facilitating the per-task training. Numerical tests on few-shot learning datasets demonstrate the superior expressiveness and convergence of the advocated approach.

Index Terms—  Meta-learning, bilevel optimization, mirror descent, loss geometries

1 Introduction

The success of deep learning relies heavily on large-scale and high-dimensional models, which require extensive training on large amounts of data. However, this “data-driven learning” approach is not feasible in applications where data are scarce due to costly data collection and labeling. Examples of such applications include drug discovery [1], machine translation [2], and robot manipulation [3].

In contrast, meta-learning offers a powerful approach for learning a task in data-limited setups. Specifically, meta-learning extracts task-invariant prior information from a collection of given tasks, which can subsequently aid the learning of a new, albeit related, task. Although this new task may have limited training data, the prior serves as a strong inductive bias that effectively transfers knowledge to aid its learning. In image classification for instance, a feature extractor learned from a collection of given tasks can act as a common prior, and thus benefit a variety of other image classification tasks.

Depending on how this “data-limited learning” is performed, meta-learning algorithms can be categorized into neural network (NN)- and optimization-based ones. In NN-based ones, the per-task learning is viewed as an NN mapping from its training data to task-specific model parameters [4, 5]. The prior information is encoded in the NN weights, which are shared and optimized across tasks. Although NNs can universally approximate complex mappings, their black-box structure challenges their reliability and interpretability. On the other hand, optimization-based meta-learning alternatives interpret “data-limited learning” as a cascade of a few optimization iterations (a.k.a. adaptation) over the model parameters. The prior here is captured by the shared hyperparameters of the iterative optimizer. A representative of these alternatives is model-agnostic meta-learning (MAML) [6], which views the prior as a learnable task-invariant initialization of the optimizer. By starting from an informative initial point, the model parameters can rapidly converge to local minima within a few gradient descent (GD) steps. Building upon MAML, a series of variants have been proposed to learn different priors [7, 8, 9].

While optimization-based meta-learning has proven effective numerically, recent studies suggest that its generalization and stability heavily rely on convergence of the per-task optimization [7, 9]. This motivates one to increase the number of descent iterations. However, this can be infeasible as the overall complexity of meta-learning scales linearly with the number of GD steps [7]. Besides, using accelerated first-order optimizers, such as Adam [10], introduces extra backpropagation complexity when optimizing the prior. To improve the per-task convergence without markedly adding to the complexity, another line of research focuses on second-order optimization using a learnable precondition matrix with a simple form [11, 12, 13, 14, 15, 16]. In fact, the precondition matrix captures the local quadratic curvature of the training loss, and linearly transforms the gradient based on this curvature. To acquire more expressive preconditioners, recent advances suggest replacing the linear matrix multiplication with a nonlinear NN transformation [17]. However, convergence of this NN-manipulated GD remains uncharted territory.

The present work advocates learning a generic distance metric induced by a strictly increasing nonlinear mirror map, which enables efficient optimization over generic loss geometries. All in all, our contribution is three-fold.

  i) Broadening linear preconditioners with guaranteed per-task convergence.

  ii) Blockwise inverse autoregressive flow (blockIAF) ensuring monotonicity and scalability of the mirror map.

  iii) Numerical tests showing superior performance and improved convergence compared to linear preconditioners.

2 Problem setup

To enable “data-limited learning” of a new task, meta-learning forms task-invariant priors using a collection of given tasks indexed by $t=1,\ldots,T$. Each task comprises a dataset $\mathcal{D}_t := \{(\mathbf{x}_t^n, y_t^n)\}_{n=1}^{N_t}$ consisting of $N_t$ data-label pairs, which are split into a training subset $\mathcal{D}^{\mathrm{trn}}_t$, and a disjoint validation subset $\mathcal{D}^{\mathrm{val}}_t$.
The new task, indexed by $\star$, contains a training subset $\mathcal{D}^{\mathrm{trn}}_\star$, and a set of test data $\{\mathbf{x}_\star^n\}_{n=1}^{N_\star^{\mathrm{tst}}}$ for which the corresponding labels $\{y_\star^n\}_{n=1}^{N_\star^{\mathrm{tst}}}$ are to be predicted. The key premise of meta-learning is that all the aforementioned tasks share related model structures or data distributions. Thus, one can postulate a large model shared across all tasks, along with distinct model parameters $\boldsymbol{\phi}_t \in \mathbb{R}^d$ per individual task.
But since the cardinality $N_t^{\mathrm{trn}} := |\mathcal{D}^{\mathrm{trn}}_t|$ can be much smaller than $d$, learning a task by directly optimizing $\boldsymbol{\phi}_t$ over $\mathcal{D}^{\mathrm{trn}}_t$ is impractical. Fortunately, since $T$ is considerably large, a task-invariant prior can be learned using $\{\mathcal{D}^{\mathrm{val}}_t\}_{t=1}^{T}$ to render per-task learning well posed.

Letting $\boldsymbol{\theta} \in \mathbb{R}^{d'}$ denote the vector parameter of the prior, the meta-learning objective can be formulated as a bilevel optimization problem. The lower-level trains each task-specific model by optimizing $\boldsymbol{\phi}_t$ using $\mathcal{D}^{\mathrm{trn}}_t$ and $\boldsymbol{\theta}$ from the upper-level. The upper-level adjusts $\boldsymbol{\theta}$ by evaluating the optimized $\boldsymbol{\phi}_t$ on the validation sets $\{\mathcal{D}^{\mathrm{val}}_t\}_{t=1}^{T}$. The two levels depend on each other and yield the following nested objective

$$\min_{\boldsymbol{\theta}} \sum_{t=1}^{T} \mathcal{L}(\boldsymbol{\phi}_t^{*}(\boldsymbol{\theta}); \mathcal{D}^{\mathrm{val}}_t) \tag{1a}$$
$$\text{s.t.}~~ \boldsymbol{\phi}_t^{*}(\boldsymbol{\theta}) = \operatorname*{argmin}_{\boldsymbol{\phi}_t} \mathcal{L}(\boldsymbol{\phi}_t; \mathcal{D}^{\mathrm{trn}}_t) + \mathcal{R}(\boldsymbol{\phi}_t; \boldsymbol{\theta}),~\forall t \tag{1b}$$

where $\mathcal{L}$ is the loss function capturing each task-specific model fit, and $\mathcal{R}$ is the regularizer accounting for the task-invariant prior. From the Bayesian viewpoint, $\mathcal{L}$ and $\mathcal{R}$ represent the negative log-likelihood (nll), $-\log p(\mathbf{y}_t^{\mathrm{trn}} | \boldsymbol{\phi}_t; \mathbf{X}_t^{\mathrm{trn}})$, and the negative log-prior (nlp), $-\log p(\boldsymbol{\phi}_t; \boldsymbol{\theta})$, where $\mathbf{X}_t^{\mathrm{trn}} := [\mathbf{x}_t^{1}, \ldots, \mathbf{x}_t^{N_t^{\mathrm{trn}}}]$ and $\mathbf{y}_t^{\mathrm{trn}} := [y_t^{1}, \ldots, y_t^{N_t^{\mathrm{trn}}}]^{\top}$ ($^{\top}$ denotes transpose). Bayes’ rule then implies that $\boldsymbol{\phi}_t^{*} = \operatorname*{argmin}_{\boldsymbol{\phi}_t} -\log p(\boldsymbol{\phi}_t | \mathbf{y}_t^{\mathrm{trn}}; \mathbf{X}_t^{\mathrm{trn}}, \boldsymbol{\theta})$ is the maximum a posteriori (MAP) estimator.
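
The MAP reading of (1b) follows from one application of Bayes’ rule. The step is spelled out here under the assumption, implicit in the notation above, that the likelihood does not depend on $\boldsymbol{\theta}$ once $\boldsymbol{\phi}_t$ is given and that the prior does not depend on the inputs:

$$
\begin{aligned}
-\log p(\boldsymbol{\phi}_t | \mathbf{y}_t^{\mathrm{trn}}; \mathbf{X}_t^{\mathrm{trn}}, \boldsymbol{\theta})
&= -\log \frac{p(\mathbf{y}_t^{\mathrm{trn}} | \boldsymbol{\phi}_t; \mathbf{X}_t^{\mathrm{trn}})\, p(\boldsymbol{\phi}_t; \boldsymbol{\theta})}{p(\mathbf{y}_t^{\mathrm{trn}}; \mathbf{X}_t^{\mathrm{trn}}, \boldsymbol{\theta})} \\
&= \underbrace{-\log p(\mathbf{y}_t^{\mathrm{trn}} | \boldsymbol{\phi}_t; \mathbf{X}_t^{\mathrm{trn}})}_{\mathcal{L}(\boldsymbol{\phi}_t; \mathcal{D}^{\mathrm{trn}}_t)}
\;\underbrace{-\log p(\boldsymbol{\phi}_t; \boldsymbol{\theta})}_{\mathcal{R}(\boldsymbol{\phi}_t; \boldsymbol{\theta})}
\;+\; \log p(\mathbf{y}_t^{\mathrm{trn}}; \mathbf{X}_t^{\mathrm{trn}}, \boldsymbol{\theta})
\end{aligned}
$$

The last term is constant in $\boldsymbol{\phi}_t$, so minimizing the negative log-posterior over $\boldsymbol{\phi}_t$ coincides with the lower-level problem (1b).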

Reaching the global optimum $\boldsymbol{\phi}_t^{*}$ is generally infeasible because the task-specific model is nonlinear. Hence, a prudent remedy is to rely on an approximate solver $\hat{\boldsymbol{\phi}}_t \approx \boldsymbol{\phi}_t^{*}$ obtained by a tractable optimizer. For instance, MAML replaces (1b) with a $K$-step GD minimizing the nll:

$$\boldsymbol{\phi}_t^{(k)}(\boldsymbol{\theta}) = \boldsymbol{\phi}_t^{(k-1)}(\boldsymbol{\theta}) - \alpha \nabla \mathcal{L}(\boldsymbol{\phi}_t^{(k-1)}(\boldsymbol{\theta}); \mathcal{D}^{\mathrm{trn}}_t),~\forall t \tag{2}$$

where $k=1,\ldots,K$ indexes iterations; initialization $\boldsymbol{\phi}_t^{(0)} = \boldsymbol{\phi}^{(0)} = \boldsymbol{\theta}$; approximate solver $\hat{\boldsymbol{\phi}}_t(\boldsymbol{\theta}) = \boldsymbol{\phi}_t^{(K)}(\boldsymbol{\theta})$; and $\alpha$ denotes the step size. Although $\mathcal{R}(\boldsymbol{\phi}_t; \boldsymbol{\theta}) = 0$ in MAML, it has been shown that the GD solver satisfies [18]

$$\hat{\boldsymbol{\phi}}_t(\boldsymbol{\theta}) \approx \boldsymbol{\phi}_t^{*}(\boldsymbol{\theta}) = \operatorname*{argmin}_{\boldsymbol{\phi}_t} \mathcal{L}(\boldsymbol{\phi}_t; \mathcal{D}^{\mathrm{trn}}_t) + \frac{1}{2} \|\boldsymbol{\phi}_t - \boldsymbol{\theta}\|_{\mathbf{\Lambda}_t}^{2},~\forall t$$

where the precision matrix $\mathbf{\Lambda}_t$ is determined by $\nabla^{2} \mathcal{L}(\boldsymbol{\theta}; \mathcal{D}^{\mathrm{trn}}_t)$, $\alpha$, and $K$. This indicates that MAML’s optimization strategy (2) is approximately tantamount to an implicit Gaussian prior probability density function (pdf) $p(\boldsymbol{\phi}_t; \boldsymbol{\theta}) = \mathcal{N}(\boldsymbol{\theta}, \mathbf{\Lambda}_t^{-1})$, with the task-invariant initialization serving as the mean vector. Alongside implicit priors, their explicit counterparts have also been investigated with various prior pdfs [7, 9].
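
As a minimal sketch of the implicit-prior interpretation, consider a hypothetical one-dimensional quadratic loss $\mathcal{L}(\phi) = \frac{a}{2}(\phi - c)^2$. In this special case, $K$ GD steps from $\theta$ land exactly on the minimizer of the regularized objective with scalar precision $\lambda = a r / (1 - r)$, where $r = (1 - \alpha a)^K$ and $0 < \alpha a < 1$, so $\lambda$ depends only on the curvature $a$, $\alpha$, and $K$. The toy loss and the names `maml_inner_loop`, `implicit_precision`, and `map_solution` are illustrative assumptions, not taken from the paper or from [18], whose general result is only approximate.

```python
def grad_loss(phi, a, c):
    """Gradient of the toy loss L(phi) = 0.5 * a * (phi - c)**2."""
    return a * (phi - c)


def maml_inner_loop(theta, a, c, alpha, K):
    """Run K plain GD steps as in (2), starting from phi^(0) = theta."""
    phi = theta
    for _ in range(K):
        phi = phi - alpha * grad_loss(phi, a, c)
    return phi


def implicit_precision(a, alpha, K):
    """Scalar precision for which K-step GD equals the regularized minimizer.

    Derived for the toy quadratic loss only; requires 0 < alpha * a < 1.
    """
    r = (1.0 - alpha * a) ** K
    return a * r / (1.0 - r)


def map_solution(theta, a, c, lam):
    """Closed-form argmin of 0.5*a*(phi-c)**2 + 0.5*lam*(phi-theta)**2."""
    return (a * c + lam * theta) / (a + lam)


# Illustrative values (assumptions): initialization 0, curvature 2, optimum 3.
theta, a, c, alpha, K = 0.0, 2.0, 3.0, 0.1, 5
phi_hat = maml_inner_loop(theta, a, c, alpha, K)
phi_map = map_solution(theta, a, c, implicit_precision(a, alpha, K))
print(phi_hat, phi_map)  # the two values agree up to rounding error
```

In this toy setting a smaller $K$ or $\alpha$ yields a larger precision, i.e., a tighter implicit prior around the initialization $\theta$.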

For both implicit and explicit priors, numerical studies [11, 13] and theoretical analyses [7, 9] demonstrate that the gradient error for optimizing $\boldsymbol{\theta}$ in (1a) relies on the convergence accuracy of $\hat{\boldsymbol{\phi}}_t$ relative to a stationary point. In addition, employing a large $K$ or complicated optimizers could prohibitively escalate the overall complexity for solving (2). As a consequence, attention has been directed towards preconditioned GD (PGD) solvers, as in the update

$$\boldsymbol{\phi}_t^{(k)}(\boldsymbol{\theta}) = \boldsymbol{\phi}_t^{(k-1)}(\boldsymbol{\theta}) - \alpha \mathbf{P}(\boldsymbol{\theta}_P) \nabla \mathcal{L}(\boldsymbol{\phi}_t^{(k-1)}(\boldsymbol{\theta}); \mathcal{D}^{\mathrm{trn}}_t) \tag{3}$$

where $\boldsymbol{\theta}_P$ parametrizes $\mathbf{P} \in \mathbb{R}^{d \times d}$, and the prior parameter is augmented as $\boldsymbol{\theta} := [\boldsymbol{\phi}^{(0)\top}, \boldsymbol{\theta}_P^{\top}]^{\top}$. To ensure (3) incurs affordable complexity after preconditioning, $\mathbf{P}$ must have a simple enough structure so that $\mathbf{P}(\boldsymbol{\theta}_P) \nabla \mathcal{L}(\boldsymbol{\phi}_t^{(k-1)}; \mathcal{D}^{\mathrm{trn}}_t)$ incurs computational complexity $\mathcal{O}(d)$. Examples of such structures include diagonal [11, 12], block-diagonal [13, 14], and NN-based [15] matrices.
A more generic preconditioner can be formed by replacing the linear transformation $\mathbf{P}(\boldsymbol{\theta}_P) \nabla \mathcal{L}(\boldsymbol{\phi}_t^{(k-1)}; \mathcal{D}^{\mathrm{trn}}_t)$ with a nonlinear NN $f(\nabla \mathcal{L}(\boldsymbol{\phi}_t^{(k-1)}; \mathcal{D}^{\mathrm{trn}}_t); \boldsymbol{\theta}_P)$ [17], but unfortunately convergence of this alternative iterate may not be guaranteed.
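
The following is a minimal sketch of update (3) with a diagonal preconditioner, one of the simple structures listed above. Storing only the diagonal entries (`p_diag`, a hypothetical name) turns $\mathbf{P}\nabla\mathcal{L}$ into an elementwise product costing $\mathcal{O}(d)$. The toy quadratic loss, the hand-picked diagonal standing in for a learned one, and all function names are assumptions made for illustration; this is not the setup of [11, 12].

```python
def grad_quadratic(phi, curvatures, center):
    """Gradient of L(phi) = 0.5 * sum_i curvatures[i] * (phi[i] - center[i])**2."""
    return [h * (p - c) for h, p, c in zip(curvatures, phi, center)]


def pgd_step(phi, grad, p_diag, alpha):
    """One step of (3) with P = diag(p_diag): an O(d) elementwise update."""
    return [p - alpha * w * g for p, w, g in zip(phi, p_diag, grad)]


def run_pgd(phi0, curvatures, center, p_diag, alpha, K):
    """Apply K preconditioned GD steps starting from phi0."""
    phi = list(phi0)
    for _ in range(K):
        grad = grad_quadratic(phi, curvatures, center)
        phi = pgd_step(phi, grad, p_diag, alpha)
    return phi


def sq_dist(u, v):
    """Squared Euclidean distance between two equal-length lists."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


curvatures = [100.0, 1.0]   # ill-conditioned toy loss
center = [1.0, -2.0]        # its minimizer
phi0 = [0.0, 0.0]
K = 5

# Plain GD (P = I): the step size is capped by the largest curvature.
plain = run_pgd(phi0, curvatures, center, [1.0, 1.0], alpha=0.009, K=K)
# Diagonal P set to the inverse curvatures; a larger step size is then stable.
inv_curv = [1.0 / h for h in curvatures]
precond = run_pgd(phi0, curvatures, center, inv_curv, alpha=0.9, K=K)

print(sq_dist(plain, center), sq_dist(precond, center))
```

With the same budget of $K = 5$ steps, the preconditioned iterate ends up far closer to the minimizer in this toy problem, because the diagonal rescaling equalizes the per-coordinate curvature.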

Essentially, GD conducts a per-step greedy search with a quadratic loss approximation. To see this, let $\text{lin}(\mathcal{L}(\boldsymbol{\phi}_t), \bar{\boldsymbol{\phi}}_t) := \mathcal{L}(\bar{\boldsymbol{\phi}}_t; \mathcal{D}^{\mathrm{trn}}_t) + (\boldsymbol{\phi}_t - \bar{\boldsymbol{\phi}}_t)^{\top} \nabla \mathcal{L}(\bar{\boldsymbol{\phi}}_t; \mathcal{D}^{\mathrm{trn}}_t)$. Using this linearization of $\mathcal{L}$ at $\bar{\boldsymbol{\phi}}_t \in \mathbb{R}^{d}$, the GD update reduces to (cf. (2))

$$\boldsymbol{\phi}_t^{(k)} = \operatorname*{argmin}_{\boldsymbol{\phi}_t} \text{lin}(\mathcal{L}(\boldsymbol{\phi}_t), \boldsymbol{\phi}_t^{(k-1)}) + \frac{1}{2\alpha} \|\boldsymbol{\phi}_t - \boldsymbol{\phi}_t^{(k-1)}\|_{2}^{2} \tag{4}$$

where dependencies on $\boldsymbol{\theta}$ are dropped hereafter for notational brevity. The term $\frac{1}{2\alpha} \|\boldsymbol{\phi}_t - \boldsymbol{\phi}_t^{(k-1)}\|_{2}^{2}$ implies the isotropic approximation $\nabla^{2} \mathcal{L}(\boldsymbol{\phi}_t^{(k-1)}; \mathcal{D}^{\mathrm{trn}}_t) \approx \frac{1}{\alpha} \mathbf{I}_d$, while (3) refines the approximation as a more informative matrix $\frac{1}{\alpha} \mathbf{P}^{-1}$ (if invertible). This quadratic local approximation is particularly effective when $K$ is large and $\alpha$ is small, which gradually drives $\boldsymbol{\phi}_t^{(k)}$ toward a stationary point.
In meta-learning, however, the standard setup relies on a small $K$ (e.g., $1$ or $5$) and a sufficiently large $\alpha$, so that the model can quickly adapt to the task with low complexity. This tradeoff highlights the need for learning more expressive loss geometries.
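
For completeness, the link between the proximal form (4) and the updates (2) and (3) can be checked from the first-order optimality condition. The preconditioned case additionally assumes that $\mathbf{P}$ is symmetric positive definite, an assumption made here only for the derivation. Setting the gradient of the objective in (4) to zero gives

$$
\nabla \mathcal{L}(\boldsymbol{\phi}_t^{(k-1)}; \mathcal{D}^{\mathrm{trn}}_t) + \frac{1}{\alpha}\big(\boldsymbol{\phi}_t^{(k)} - \boldsymbol{\phi}_t^{(k-1)}\big) = \mathbf{0}
\;\Longrightarrow\;
\boldsymbol{\phi}_t^{(k)} = \boldsymbol{\phi}_t^{(k-1)} - \alpha \nabla \mathcal{L}(\boldsymbol{\phi}_t^{(k-1)}; \mathcal{D}^{\mathrm{trn}}_t)
$$

which is (2). Replacing the squared $\ell_2$-norm with $\|\mathbf{v}\|_{\mathbf{P}^{-1}}^{2} := \mathbf{v}^{\top} \mathbf{P}^{-1} \mathbf{v}$ yields instead

$$
\nabla \mathcal{L}(\boldsymbol{\phi}_t^{(k-1)}; \mathcal{D}^{\mathrm{trn}}_t) + \frac{1}{\alpha}\mathbf{P}^{-1}\big(\boldsymbol{\phi}_t^{(k)} - \boldsymbol{\phi}_t^{(k-1)}\big) = \mathbf{0}
\;\Longrightarrow\;
\boldsymbol{\phi}_t^{(k)} = \boldsymbol{\phi}_t^{(k-1)} - \alpha \mathbf{P} \nabla \mathcal{L}(\boldsymbol{\phi}_t^{(k-1)}; \mathcal{D}^{\mathrm{trn}}_t)
$$

which recovers the PGD update (3).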

3 Loss Geometries using Mirror Descent

Instead of quadratic approximations of the local loss induced by certain norms (e.g., $\|\cdot\|_{2}$ and $\|\cdot\|_{\mathbf{P}^{-1}}$), our fresh idea is a data-driven distance metric that captures a broader spectrum of loss geometries. This is accomplished by learning the so-termed “mirror map,” which will be introduced first. All the proofs are deferred to Appendix A.

3.1 Modeling the loss geometry using the mirror map

To generalize the (P)GD, we will replace the $\ell_2$-norm in (4) with a generic metric $D_h$ to arrive at

$$\boldsymbol{\phi}_t^{(k)} = \operatorname*{argmin}_{\boldsymbol{\phi}_t} \text{lin}(\mathcal{L}(\boldsymbol{\phi}_t), \boldsymbol{\phi}_t^{(k-1)}) + \frac{1}{\alpha} D_h(\boldsymbol{\phi}_t, \boldsymbol{\phi}_t^{(k-1)}) \tag{5}$$

where $D_{h}(\boldsymbol{\phi}_{t},\boldsymbol{\phi}_{t}^{(k-1)}):=h(\boldsymbol{\phi}_{t})-\text{lin}(h(\boldsymbol{\phi}_{t}),\boldsymbol{\phi}_{t}^{(k-1)})$ is the Bregman divergence, and the associated distance-generating function $h:\mathbb{R}^{d}\mapsto\mathbb{R}$ is strongly convex to ensure the existence and uniqueness of the minimizer.
As a result, $\nabla h$ is strictly increasing, and thus invertible. (Footnote 1: When $\nabla h$ is discontinuous but $h$ is proper, the inverse $(\nabla h)^{-1}$ is defined as $\nabla h^{*}(\mathbf{z}):=\operatorname*{argmax}_{\boldsymbol{\phi}}\boldsymbol{\phi}^{\top}\mathbf{z}-h(\boldsymbol{\phi})$, where $h^{*}(\mathbf{z}):=\sup_{\boldsymbol{\phi}}\boldsymbol{\phi}^{\top}\mathbf{z}-h(\boldsymbol{\phi})$ is the Fenchel conjugate of $h$.) Then, applying the optimality condition leads to the mirror descent (MD) update

$$\boldsymbol{\phi}_{t}^{(k)}=(\nabla h)^{-1}\big(\nabla h(\boldsymbol{\phi}_{t}^{(k-1)})-\alpha\nabla\mathcal{L}(\boldsymbol{\phi}_{t}^{(k-1)};\mathcal{D}^{\mathrm{trn}}_{t})\big). \tag{6}$$
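
For completeness, the step from (5) to (6) can be written out as below. This is only a sketch: it assumes that $\text{lin}(f(\boldsymbol{\phi}),\boldsymbol{\phi}')$ denotes the first-order Taylor expansion $f(\boldsymbol{\phi}')+\nabla f(\boldsymbol{\phi}')^{\top}(\boldsymbol{\phi}-\boldsymbol{\phi}')$ and that $h$ is differentiable. The symbol $J$ is introduced here purely for illustration.

```latex
\begin{align*}
J(\boldsymbol{\phi}_t)
  &:= \mathcal{L}(\boldsymbol{\phi}_t^{(k-1)};\mathcal{D}_t^{\mathrm{trn}})
     + \nabla\mathcal{L}(\boldsymbol{\phi}_t^{(k-1)};\mathcal{D}_t^{\mathrm{trn}})^{\top}
       (\boldsymbol{\phi}_t-\boldsymbol{\phi}_t^{(k-1)}) \\
  &\quad + \frac{1}{\alpha}\Big[h(\boldsymbol{\phi}_t)-h(\boldsymbol{\phi}_t^{(k-1)})
     - \nabla h(\boldsymbol{\phi}_t^{(k-1)})^{\top}(\boldsymbol{\phi}_t-\boldsymbol{\phi}_t^{(k-1)})\Big], \\
\nabla J(\boldsymbol{\phi}_t)
  &= \nabla\mathcal{L}(\boldsymbol{\phi}_t^{(k-1)};\mathcal{D}_t^{\mathrm{trn}})
     + \frac{1}{\alpha}\Big[\nabla h(\boldsymbol{\phi}_t)-\nabla h(\boldsymbol{\phi}_t^{(k-1)})\Big], \\
\nabla J(\boldsymbol{\phi}_t^{(k)}) = \mathbf{0}
  &\iff \nabla h(\boldsymbol{\phi}_t^{(k)})
     = \nabla h(\boldsymbol{\phi}_t^{(k-1)})
     - \alpha\nabla\mathcal{L}(\boldsymbol{\phi}_t^{(k-1)};\mathcal{D}_t^{\mathrm{trn}}).
\end{align*}
```

Since $J$ is the sum of an affine function and a strongly convex one, the stationarity condition is both necessary and sufficient, and applying $(\nabla h)^{-1}$ to both sides of the last line gives the MD update.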

The invertible $\nabla h$, dubbed mirror map, connects $\boldsymbol{\phi}_{t}$ in the primal space to $\nabla\mathcal{L}$ in the dual space under the endowed metric $D_{h}$. As a special case, when choosing $h(\cdot)=\frac{1}{2}\|\cdot\|_{2}^{2}$, it is easy to verify that (6) boils down to (2) due to the self-duality of the $\ell_{2}$-norm. Likewise, (3) can be obtained with $h(\cdot)=\frac{1}{2}\|\cdot\|_{\mathbf{P}^{-1}}^{2}$, where $\nabla h$ reduces to a linear map. Function $h$ reflects our prior knowledge about the geometry of $\mathcal{L}$.
In particular, letting $h(\cdot)=\mathcal{L}(\cdot;\mathcal{D}^{\mathrm{trn}}_{t})$ (even when $\mathcal{L}$ is not strongly convex) and $\alpha=1$ in (5) gives $\boldsymbol{\phi}_{t}^{(k)}=\operatorname*{argmin}_{\boldsymbol{\phi}_{t}}\mathcal{L}(\boldsymbol{\phi}_{t};\mathcal{D}^{\mathrm{trn}}_{t})$, which is precisely the original nll minimization solved in (2) and (3). Thus, an ideal choice of $h$ would yield $h\approx\mathcal{L}$ (up to a constant) within a sufficiently large region around $\boldsymbol{\phi}_{t}^{(k-1)}$.
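
The two special cases above can be checked numerically. The following is a minimal sketch rather than the paper's implementation: the helper `md_step`, the diagonal preconditioner `p_diag`, and the numerical values are all hypothetical and chosen only to show that the MD update with a linear mirror map coincides with a plain or preconditioned gradient step.

```python
def md_step(phi, grad, alpha, mirror, inv_mirror):
    """One MD step: phi_new = (grad h)^{-1}(grad h(phi) - alpha * grad)."""
    z = mirror(phi)                                        # map to the dual space
    z_new = [zi - alpha * gi for zi, gi in zip(z, grad)]   # gradient step in the dual space
    return inv_mirror(z_new)                               # map back to the primal space


# h = 0.5 * ||.||_2^2: the mirror map and its inverse are both the identity.
identity = lambda v: list(v)

# h = 0.5 * ||.||_{P^{-1}}^2 with a diagonal P (assumed here for simplicity):
# grad h(phi) = P^{-1} phi and (grad h)^{-1}(z) = P z.
p_diag = [2.0, 0.5]
mirror_p = lambda v: [vi / pi for vi, pi in zip(v, p_diag)]
inv_mirror_p = lambda z: [zi * pi for zi, pi in zip(z, p_diag)]

phi = [1.0, -3.0]     # current iterate (hypothetical values)
grad = [0.4, -1.0]    # loss gradient at phi (hypothetical values)
alpha = 0.1

gd = md_step(phi, grad, alpha, identity, identity)        # equals phi - alpha * grad
pgd = md_step(phi, grad, alpha, mirror_p, inv_mirror_p)   # equals phi - alpha * P * grad
```

In both cases the mirror map is linear, so the update stays a (preconditioned) gradient step; a nonlinear mirror map is what lets MD go beyond this quadratic picture.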

Different from past works that rely on a simple preselected $h$ to model loss geometries, we here acquire a data-driven $h$ by learning a strictly increasing $\nabla h$ that best fits the given tasks. Interestingly, (6) can be reformulated to yield an update of the dual vector $\mathbf{z}_{t}:=\nabla h(\boldsymbol{\phi}_{t})$ as

$$\mathbf{z}_{t}^{(k)}=\mathbf{z}_{t}^{(k-1)}-\alpha\nabla\mathcal{L}\big((\nabla h)^{-1}(\mathbf{z}_{t}^{(k-1)});\mathcal{D}^{\mathrm{trn}}_{t}\big) \tag{7}$$

with $\mathbf{z}_{t}^{(0)}=\nabla h(\boldsymbol{\phi}^{(0)})$ and $\hat{\boldsymbol{\phi}}_{t}=(\nabla h)^{-1}(\mathbf{z}_{t}^{(K)})$. Hence, it suffices to learn a strictly increasing $(\nabla h)^{-1}$ and a task-invariant dual initialization $\mathbf{z}^{(0)}:=\nabla h(\boldsymbol{\phi}^{(0)})$, thus removing the need for directly calculating $\nabla h$.
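
The dual-space recursion (7) only ever calls the inverse mirror map, which the following sketch makes explicit. It is an illustration under stated assumptions, not the authors' code: `dual_mirror_descent`, the toy separable quadratic loss, and the hand-picked inverse mirror maps are hypothetical. The "matched" map corresponds to an $h$ that agrees with the toy loss up to an affine term, which leaves the Bregman divergence unchanged.

```python
def dual_mirror_descent(z0, grad_loss, inv_mirror, alpha, num_steps):
    """Run K = num_steps dual updates (7), then map the result back to the primal space."""
    z = list(z0)
    for _ in range(num_steps):
        phi = inv_mirror(z)                  # primal iterate (grad h)^{-1}(z)
        grad = grad_loss(phi)                # training-loss gradient at phi
        z = [zi - alpha * gi for zi, gi in zip(z, grad)]
    return inv_mirror(z)                     # phi_hat = (grad h)^{-1}(z^(K))


# Toy loss 0.5 * sum_i c_i * (phi_i - target_i)^2 with very different curvatures.
curvature = [10.0, 0.1]
target = [0.5, 3.0]
grad_loss = lambda phi: [c * (p - m) for c, p, m in zip(curvature, phi, target)]

phi0 = [2.0, -1.0]

# Inverse mirror map matched to the loss geometry: h(phi) = 0.5 * sum_i c_i * phi_i^2.
matched_inv = lambda z: [zi / c for zi, c in zip(z, curvature)]
z0_matched = [c * p for c, p in zip(curvature, phi0)]   # z^(0) = grad h(phi^(0))
phi_md = dual_mirror_descent(z0_matched, grad_loss, matched_inv, alpha=1.0, num_steps=1)

# Identity mirror map (plain GD) with a step size small enough to remain stable.
identity = lambda z: list(z)
phi_gd = dual_mirror_descent(phi0, grad_loss, identity, alpha=0.1, num_steps=1)
```

On this toy problem the matched map reaches the minimizer in a single step, whereas plain GD barely moves along the low-curvature coordinate, which is the behavior one hopes a learned $(\nabla h)^{-1}$ can approximate on real losses.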

3.2 Learning the inverse mirror map via blockIAF

Inspired by this observation, a prudent option is to model $(\nabla h)^{-1}$ as an inverse autoregressive flow (IAF) [19]. The notable benefit of IAF lies in its efficient parallelization of forward computation, which makes it considerably faster than computing its inverse. However, directly applying the dimension-wise IAF to the high-dimensional $\mathbf{z}_{t}\in\mathbb{R}^{d}$ will incur prohibitively high complexity of $\Omega(d^{2})$. For this reason, we introduce a novel blockIAF model that effectively reduces complexity by performing block-wise (nonlinear) autoregression on a low-dimensional space encoding $\mathbf{z}_{t}$. To this end, let $\{\mathcal{B}_{i}\}_{i=1}^{B}$ be a partition of the index set $\{1,\ldots,d\}$, and $[\mathbf{z}_{t}]_{\mathcal{B}_{i}}$ denote the subvector of $\mathbf{z}_{t}$ restricted to the block $\mathcal{B}_{i}$.
The blockIAF model transforms $\mathbf{z}_{t}$ to $\boldsymbol{\phi}_{t}$ through

$$[\boldsymbol{\phi}_{t}]_{\mathcal{B}_{i}}=[\mathbf{z}_{t}]_{\mathcal{B}_{i}}\odot\sigma(\boldsymbol{\alpha}_{i})+\boldsymbol{\mu}_{i} \tag{8a}$$
$$[\boldsymbol{\alpha}_{i}^{\top},\boldsymbol{\mu}_{i}^{\top}]^{\top}=d_{i}\big(\{e_{j}([\mathbf{z}_{t}]_{\mathcal{B}_{j}})\}_{j=1}^{i-1}\big),~i=1,\ldots,B \tag{8b}$$

where nonlinearity $\sigma$ is positive and upper bounded (e.g., logistic function), $\sigma(\boldsymbol{\alpha}_{i}),\boldsymbol{\mu}_{i}\in\mathbb{R}^{|\mathcal{B}_{i}|}$ are the scale and shift of $[\mathbf{z}_{t}]_{\mathcal{B}_{i}}$, $e_{i}$ and $d_{i}$ denote learnable encoder and decoder for the $i$-th block, and $\odot$ is the Hadamard (element-wise) product. In our implementation, $\{e_{i}\}_{i=1}^{B-1}$ and $\{d_{i}\}_{i=1}^{B}$ are multilayer perceptrons (MLPs) with ReLU activations. To further reduce complexity, all linear layers in MLPs are implemented by tensor mode product [13]. This technique is equivalent to a low-rank Kronecker approximation to MLPs’ weight matrices, and lowers the per-step MD complexity to $\mathcal{O}(d)$.
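
The forward pass (8a)-(8b) can be sketched in a few lines of plain Python. This is a simplified illustration, not the implementation used in the paper: the encoders are single ReLU layers and the decoders single linear layers (instead of MLPs built from tensor mode products), weights are drawn at random, and the names `init_block_iaf`, `block_iaf_forward`, and `code_dim` are hypothetical. For the first block the decoder receives an empty input, so its scale and shift reduce to learnable constants.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def affine(weights, bias, x):
    """Return W x + b for a list-of-rows matrix W (W may have zero columns)."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]


def relu(v):
    return [max(0.0, vi) for vi in v]


def init_block_iaf(blocks, code_dim, seed=0):
    """Random parameters: one encoder e_i and one decoder d_i per block."""
    rng = random.Random(seed)
    rand = lambda rows, cols: [[rng.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]
    params = []
    for i, block in enumerate(blocks):
        n = len(block)
        params.append({
            "enc_w": rand(code_dim, n), "enc_b": [0.0] * code_dim,
            # Decoder input: codes of all previous blocks; output: [alpha_i, mu_i].
            "dec_w": rand(2 * n, i * code_dim),
            "dec_b": [rng.gauss(0.0, 0.5) for _ in range(2 * n)],
        })
    return params


def block_iaf_forward(z, blocks, params):
    """Map the dual vector z to the primal vector phi block by block, as in (8a)-(8b)."""
    phi = [0.0] * len(z)
    codes = []                                   # encodings of blocks 1, ..., i-1
    for block, p in zip(blocks, params):
        z_block = [z[j] for j in block]
        out = affine(p["dec_w"], p["dec_b"], codes)
        n = len(block)
        alpha, mu = out[:n], out[n:]
        for j, zj, a, m in zip(block, z_block, alpha, mu):
            phi[j] = zj * sigmoid(a) + m         # positive, bounded scale plus shift
        # Encode the current block for later blocks (unused after the last block).
        codes = codes + relu(affine(p["enc_w"], p["enc_b"], z_block))
    return phi


blocks = [[0, 1], [2, 3, 4], [5]]                # a partition of {0, ..., 5}
params = init_block_iaf(blocks, code_dim=2)
z = [0.3, -1.2, 0.5, 2.0, -0.7, 1.1]
phi = block_iaf_forward(z, blocks, params)
```

Because the scale and shift of block $i$ depend only on earlier blocks, all blocks can be evaluated from $\mathbf{z}_{t}$ without any sequential inversion, and perturbing a later block leaves the outputs of earlier blocks unchanged.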

The following theorem characterizes two important properties of the proposed blockIAF model.

Theorem 1.

Let $g:\mathbb{R}^{d}\mapsto\mathbb{R}^{d}$ be the blockIAF model (3.2). For any partition $\{\mathcal{B}_{i}\}_{i=1}^{B}$, $g$ is strictly increasing, that is

$$(\mathbf{z}_{t}-\mathbf{z}_{t}^{\prime})^{\top}(g(\mathbf{z}_{t})-g(\mathbf{z}_{t}^{\prime}))>0,~~\forall\mathbf{z}_{t}\neq\mathbf{z}_{t}^{\prime}. \tag{9}$$

Moreover, there exists a constant $C>0$ such that

$$\nabla(g^{-1})(\boldsymbol{\phi}_{t})\succeq C. \tag{10}$$

Theorem 1 asserts that with $(\nabla h)^{-1}=g$, one ensures the desired strict monotonicity, and strong convexity of the induced $h$ (by noting that $\nabla^{2}h=\nabla(g^{-1})$). As a result, the per-task optimization (7) enjoys the standard convergence guarantee of MD. Although the convergence rate of MD is of the same order as that of GD, it outperforms GD markedly in the constant factor when $d$ is large [20], and relies on more relaxed assumptions [21].

Table 1: Comparison of meta-learning algorithms with different loss geometry models on the $5$-class miniImageNet dataset. Maximum and mean accuracies within its $95\%$ confidence interval are in bold. (No model ensembling for a fair comparison.)

| Method | Lower-level optimizer | Loss geometry model | $5$-class accuracy, $1$-shot | $5$-class accuracy, $5$-shot |
|---|---|---|---|---|
| MAML [6] | GD | identity matrix | $48.70\pm 1.84\%$ | $63.11\pm 0.92\%$ |
| MetaSGD [11] | PGD | diag. matrix | $50.47\pm 1.87\%$ | $64.03\pm 0.94\%$ |
| MT-net [14] | PGD | block diag. matrix | $51.70\pm 1.84\%$ | -- |
| WarpGrad [15] | PGD | NN-based low-rank matrix | $52.3\pm 0.8\%$ | $68.4\pm 0.6\%$ |
| MetaCurvature [13] | PGD | block diag. & Kron. (low-rank) matrix | $54.23\pm 0.88\%$ | $67.99\pm 0.73\%$ |
| MetaKFO [17] | NN-transformed GD | NN-based gradient transformation | -- | $64.9\%$ |
| ECML [16] | PGD | Gauss-Newton approximation | $48.94\pm 0.80\%$ | $65.26\pm 0.67\%$ |
| This paper’s method | MD | blockIAF-based mirror map | $\mathbf{56.10\pm 1.43\%}$ | $\mathbf{69.59\pm 0.71\%}$ |
[Figure 1, panel (a): $\mathcal{L}(\boldsymbol{\phi}_{\star}^{(k)};\mathcal{D}^{\mathrm{trn}}_{\star})$ versus $k$; panel (b): $\|\nabla\mathcal{L}(\boldsymbol{\phi}_{\star}^{(k)};\mathcal{D}^{\mathrm{trn}}_{\star})\|_{2}$ versus $k$]
Fig. 1: Convergence comparison on randomly sampled new tasks.

The meta-learning objective (2) is solved using alternating optimization. With $\boldsymbol{\theta}_{g}$ denoting the blockIAF parameters, let $\boldsymbol{\theta}:=[\mathbf{z}^{(0)\top},\boldsymbol{\theta}_{g}^{\top}]^{\top}$ be the prior parameter vector. In the $r$-th iteration of (1a), the optimizer has access to $\boldsymbol{\theta}^{(r-1)}$ provided by its last iteration, and a batch of randomly sampled tasks $\mathcal{T}^{(r)}\subset\{1,\ldots,T\}$. The optimizer first solves for $\hat{\boldsymbol{\phi}}_{t}(\boldsymbol{\theta}^{(r-1)})$ for each $t\in\mathcal{T}^{(r)}$ leveraging the $K$-step MD (7). Then, $\boldsymbol{\theta}^{(r-1)}$ is updated using mini-batch stochastic GD with step size $\beta$:

$$\boldsymbol{\theta}^{(r)}=\boldsymbol{\theta}^{(r-1)}-\beta\frac{T}{|\mathcal{T}^{(r)}|}\sum_{t\in\mathcal{T}^{(r)}}\nabla_{\boldsymbol{\theta}^{(r-1)}}\mathcal{L}(\hat{\boldsymbol{\phi}}_{t}(\boldsymbol{\theta}^{(r-1)});\mathcal{D}^{\mathrm{val}}_{t}).$$

A summary of the algorithm can be found in Appendix B.
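
To illustrate how the lower-level MD and the upper-level stochastic GD interleave, a toy version of the alternating scheme is sketched below. It is not the algorithm of Appendix B: the scalar quadratic tasks, the scalar affine stand-in for the inverse mirror map, the reuse of the training loss as validation loss, and the finite-difference hypergradient (replacing backpropagation through the $K$ MD steps) are all simplifying assumptions, and every function name is hypothetical.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def adapt(theta, target, alpha=1.0, num_steps=1):
    """K-step dual MD on the toy training loss 0.5 * (phi - target)^2."""
    z, a, mu = theta                        # theta = [z^(0), theta_g] with theta_g = (a, mu)
    g = lambda v: sigmoid(a) * v + mu       # scalar stand-in for the inverse mirror map
    for _ in range(num_steps):
        z = z - alpha * (g(z) - target)     # toy loss gradient evaluated at phi = g(z)
    return g(z)


def val_loss(theta, target):
    """Validation loss of the adapted parameter (same toy loss, for simplicity)."""
    return 0.5 * (adapt(theta, target) - target) ** 2


def hypergrad(theta, target, eps=1e-5):
    """Central finite differences, standing in for backpropagation through adapt()."""
    grad = []
    for i in range(len(theta)):
        up, down = list(theta), list(theta)
        up[i] += eps
        down[i] -= eps
        grad.append((val_loss(up, target) - val_loss(down, target)) / (2 * eps))
    return grad


def meta_step(theta, targets, batch, beta):
    """One mini-batch update of theta, including the T / |batch| scaling."""
    scale = len(targets) / len(batch)
    total = [0.0] * len(theta)
    for t in batch:
        total = [s + g for s, g in zip(total, hypergrad(theta, targets[t]))]
    return [th - beta * scale * s for th, s in zip(theta, total)]


targets = [-1.0, 0.5, 2.0, 3.0]             # one scalar target per toy task
mean_loss = lambda th: sum(val_loss(th, m) for m in targets) / len(targets)

rng = random.Random(0)
theta = [0.0, 0.0, 0.0]
initial_loss = mean_loss(theta)
for _ in range(300):
    batch = rng.sample(range(len(targets)), 2)   # sampled task batch
    theta = meta_step(theta, targets, batch, beta=0.05)
final_loss = mean_loss(theta)
```

In this toy setting the meta-update pushes the learned scale toward the one matching the loss curvature, so a single MD step adapts better after meta-training than before; in practice the hypergradient would be obtained with automatic differentiation rather than finite differences.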

4 Numerical tests

Here we compare the empirical performance of optimization-based meta-learning using different lower-level optimizers, on the standard few-shot classification dataset miniImageNet [22], where “shots” signify the number of per-class training samples for each task $t$. The task-specific model is a standard $4$-layer convolutional NN (CNN) [22, 6]. Each layer comprises a $3\times 3$ convolution of $64$ channels, batch normalization, ReLU activation, and a $2\times 2$ max pooling module. After the convolutional layers, a linear regressor with softmax activation is appended to perform classification. Subset $\mathcal{B}_{i}$ is formed by the weight indices of the $i$-th CNN layer. The autoregression in (8b) implies that “how to optimize weights of the $i$-th layer” depends on “how weights of previous layers have been optimized.” This choice enables blockIAF to model the optimization dependency of high-level features (e.g., textures and patterns) on low-level ones (e.g., colors and edges). Test setups and hyperparameters can be found in Appendix C.

Table 1 lists various loss geometry models, where classification accuracy on new tasks is the figure of merit. For fairness, MAML is the backbone of all methods. By utilizing a more versatile loss geometry model, our approach outperforms the state-of-the-art ones by a large margin.

To further gauge the performance gain achieved by our novel approach, Fig. 1 visualizes the convergence of $\mathcal{L}(\boldsymbol{\phi}_{\star}^{(k)};\mathcal{D}^{\mathrm{trn}}_{\star})$ averaged on 1,000 random new tasks. The proposed method results in faster convergence to a lower and more stable nll compared with all three competitors. Moreover, Fig. 1(a) reveals that both the proposed method and MetaCurvature improve the initialization compared to MAML and MetaSGD. This confirms that convergence and generalization of (1a) rely on the convergence accuracy of $\hat{\boldsymbol{\phi}}_{t}$ [7, 9]. Fig. 1(b) further illustrates that although the initial gradients of different methods have comparable norms $\|\nabla\mathcal{L}(\boldsymbol{\phi}_{\star}^{(0)};\mathcal{D}^{\mathrm{trn}}_{\star})\|_{2}$, our method can make better use of the gradient, leading to a rapid reduction of the nll as well as its gradient norm at $k=1$. This improved gradient utilization highlights our method’s superior modeling of loss geometries.

5 Conclusions and outlook

Versatile loss geometry models can accelerate the lower-level convergence in meta-learning. A novel blockIAF model is introduced to learn the inverse mirror map $(\nabla h)^{-1}$ induced by a strongly convex $h$. The resultant algorithm generalizes preconditioning-based meta-learning, captures versatile loss geometries, and improves lower-level convergence. Effectiveness of the novel approach was validated on a standard few-shot dataset. Future research includes bi-level convergence guarantees for the proposed method, and development of more expressive yet scalable inverse mirror maps.

References

  • [1] Han Altae-Tran, Bharath Ramsundar, Aneesh S. Pappu, and Vijay Pande, “Low data drug discovery with one-shot learning,” ACS Central Science, vol. 3, no. 4, pp. 283–293, 2017.
  • [2] Jiatao Gu, Yong Wang, Yun Chen, Kyunghyun Cho, and Victor OK Li, “Meta-learning for low-resource neural machine translation,” arXiv preprint arXiv:1808.08437, 2018.
  • [3] Ignasi Clavera, Anusha Nagabandi, Simin Liu, Ronald S. Fearing, Pieter Abbeel, Sergey Levine, and Chelsea Finn, “Learning to adapt in dynamic, real-world environments through meta-reinforcement learning,” in Proc. Int. Conf. Learn. Represent., 2019.
  • [4] Sachin Ravi and Hugo Larochelle, “Optimization as a model for few-shot learning,” in Proc. Int. Conf. Learn. Represent., 2017.
  • [5] Nikhil Mishra, Mostafa Rohaninejad, Xi Chen, and Pieter Abbeel, “A simple neural attentive meta-learner,” in Proc. Int. Conf. Learn. Represent., 2018.
  • [6] Chelsea Finn, Pieter Abbeel, and Sergey Levine, “Model-agnostic meta-learning for fast adaptation of deep networks,” in Proc. Int. Conf. Mach. Learn., 2017, vol. 70, pp. 1126–1135.
  • [7] Aravind Rajeswaran, Chelsea Finn, Sham M Kakade, and Sergey Levine, “Meta-learning with implicit gradients,” in Proc. Adv. Neural Inf. Process. Syst., 2019, vol. 32.
  • [8] Kwonjoon Lee, Subhransu Maji, Avinash Ravichandran, and Stefano Soatto, “Meta-learning with differentiable convex optimization,” in Proc. IEEE/CVF Conf. on Comp. Vis. and Pat. Recog., 2019.
  • [9] Yilang Zhang, Bingcong Li, Shijian Gao, and Georgios B. Giannakis, “Scalable Bayesian meta-learning through generalized implicit gradients,” in Proc. AAAI Conf. Artif. Intel., 2023, vol. 37(9), pp. 11298–11306.
  • [10] Diederik P. Kingma and Jimmy Ba, “Adam: A method for stochastic optimization,” in Proc. Int. Conf. Learn. Represent., 2015.
  • [11] Zhenguo Li, Fengwei Zhou, Fei Chen, and Hang Li, “Meta-SGD: Learning to learn quickly for few-shot learning,” arXiv preprint arXiv:1707.09835, 2017.
  • [12] Boyan Gao, Henry Gouk, Hae Beom Lee, and Timothy M Hospedales, “Meta mirror descent: Optimiser learning for fast convergence,” arXiv preprint arXiv:2203.02711, 2022.
  • [13] Eunbyung Park and Junier B Oliva, “Meta-curvature,” in Proc. Adv. Neural Inf. Process. Syst., 2019, vol. 32.
  • [14] Yoonho Lee and Seungjin Choi, “Gradient-based meta-learning with learned layerwise metric and subspace,” in Proc. Int. Conf. Mach. Learn., 2018, vol. 80, pp. 2927–2936.
  • [15] Sebastian Flennerhag, Andrei A. Rusu, Razvan Pascanu, Francesco Visin, Hujun Yin, and Raia Hadsell, “Meta-learning with warped gradient descent,” in Proc. Int. Conf. Learn. Represent., 2020.
  • [16] Markus Hiller, Mehrtash Harandi, and Tom Drummond, “On enforcing better conditioned meta-learning for rapid few-shot adaptation,” in Proc. Adv. Neural Inf. Process. Syst., 2022, vol. 35, pp. 4059–4071.
  • [17] Sébastien M. R. Arnold, Shariq Iqbal, and Fei Sha, “When MAML can adapt fast and how to assist when it cannot,” in Proc. Int. Conf. Artif. Intel. and Stats., 2021, vol. 130, pp. 244–252.
  • [18] Erin Grant, Chelsea Finn, Sergey Levine, Trevor Darrell, and Thomas Griffiths, “Recasting gradient-based meta-learning as hierarchical Bayes,” in Proc. Int. Conf. Learn. Represent., 2018.
  • [19] Durk P Kingma, Tim Salimans, Rafal Jozefowicz, Xi Chen, Ilya Sutskever, and Max Welling, “Improved variational inference with inverse autoregressive flow,” in Proc. Adv. Neural Inf. Process. Syst., 2016, vol. 29.
  • [20] Aharon Ben-Tal, Tamar Margalit, and Arkadi Nemirovski, “The ordered subsets mirror descent optimization method with applications to tomography,” SIAM Journal on Optimization, vol. 12, no. 1, pp. 79–108, 2001.
  • [21] Zhengyuan Zhou, Panayotis Mertikopoulos, Nicholas Bambos, Stephen Boyd, and Peter W Glynn, “Stochastic mirror descent in variationally coherent optimization problems,” in Proc. Adv. Neural Inf. Process. Syst., 2017, vol. 30.
  • [22] Oriol Vinyals, Charles Blundell, Timothy Lillicrap, Koray Kavukcuoglu, and Daan Wierstra, “Matching networks for one shot learning,” in Proc. Adv. Neural Inf. Process. Syst., 2016, vol. 29.
  • [23] Ashish Bora, Ajil Jalal, Eric Price, and Alexandros G. Dimakis, “Compressed sensing using generative models,” in Proc. Int. Conf. Mach. Learn., Doina Precup and Yee Whye Teh, Eds., 2017, vol. 70, pp. 537–546.

Appendix

Appendix A Proof of Theorem 1

Theorem 1 (Restated).

Let g:ddnormal-:𝑔maps-tosuperscript𝑑superscript𝑑g:\mathbb{R}^{d}\mapsto\mathbb{R}^{d}italic_g : blackboard_R start_POSTSUPERSCRIPT italic_d end_POSTSUPERSCRIPT ↦ blackboard_R start_POSTSUPERSCRIPT italic_d end_POSTSUPERSCRIPT denote the blockIAF model (3.2). For any partition {i}i=1Bsuperscriptsubscriptsubscript𝑖𝑖1𝐵\{\mathcal{B}_{i}\}_{i=1}^{B}{ caligraphic_B start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT } start_POSTSUBSCRIPT italic_i = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_B end_POSTSUPERSCRIPT, g𝑔gitalic_g is strictly increasing, that is

(𝐳t𝐳t)(g(𝐳t)g(𝐳t))>0,𝐳t𝐳t.formulae-sequencesuperscriptsubscript𝐳𝑡superscriptsubscript𝐳𝑡top𝑔subscript𝐳𝑡𝑔superscriptsubscript𝐳𝑡0for-allsubscript𝐳𝑡superscriptsubscript𝐳𝑡(\mathbf{z}_{t}-\mathbf{z}_{t}^{\prime})^{\top}(g(\mathbf{z}_{t})-g(\mathbf{z}% _{t}^{\prime}))>0,~{}~{}\forall\mathbf{z}_{t}\neq\mathbf{z}_{t}^{\prime}.( bold_z start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT - bold_z start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ) start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT ( italic_g ( bold_z start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT ) - italic_g ( bold_z start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ) ) > 0 , ∀ bold_z start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT ≠ bold_z start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT .

Moreover, there exists a constant $C>0$ such that

\[
\nabla(g^{-1})(\boldsymbol{\phi}_{t})\succeq C.
\]
Proof.

Let $\pi:=[\mathcal{B}_{1},\ldots,\mathcal{B}_{B}]$ denote a permutation of $\{1,\ldots,n\}$, and $\mathbf{Q}_{\pi}\in\mathbb{R}^{d\times d}:=\big[[\mathbf{I}_{d}]_{\mathcal{B}_{1}},\ldots,[\mathbf{I}_{d}]_{\mathcal{B}_{B}}\big]$ the permutation matrix under $\pi$, where $[\mathbf{I}_{d}]_{\mathcal{B}_{i}}$ is the submatrix of the identity $\mathbf{I}_{d}\in\mathbb{R}^{d\times d}$ restricted to the columns indexed by $\mathcal{B}_{i}$.

Consider the partial derivatives (cf. (3.2))

\[
\frac{\partial[g(\mathbf{z}_{t})]_{\mathcal{B}_{i}}}{\partial[\mathbf{z}_{t}]_{\mathcal{B}_{j}}}
=\frac{\partial[\boldsymbol{\phi}_{t}(\mathbf{z}_{t})]_{\mathcal{B}_{i}}}{\partial[\mathbf{z}_{t}]_{\mathcal{B}_{j}}}
=\begin{cases}
\text{a $|\mathcal{B}_{i}|\times|\mathcal{B}_{j}|$ matrix}, & \text{if}~i>j\\
\text{diag}(\sigma(\boldsymbol{\alpha}_{i})), & \text{if}~i=j\\
\mathbf{0}_{d}, & \text{otherwise}
\end{cases}.
\tag{11}
\]

It can be verified that the Jacobian $\nabla_{[\mathbf{z}_{t}]_{\pi}}[g(\mathbf{z}_{t})]_{\pi}$ of the permuted parameters is block-upper-triangular, with the $i$-th diagonal block given by $\frac{\partial[g(\mathbf{z}_{t})]_{\mathcal{B}_{i}}}{\partial[\mathbf{z}_{t}]_{\mathcal{B}_{i}}}=\text{diag}(\sigma(\boldsymbol{\alpha}_{i}))\succ 0$. It thus holds that $\nabla_{[\mathbf{z}_{t}]_{\pi}}[g(\mathbf{z}_{t})]_{\pi}\succ 0$, or equivalently, $\mathbf{Q}_{\pi}^{\top}\nabla g(\mathbf{z}_{t})\mathbf{Q}_{\pi}\succ 0$, which implies that

\[
\nabla g(\mathbf{z}_{t})\succ 0,\quad\forall\,\mathbf{z}_{t}\in\mathbb{R}^{d}.
\tag{12}
\]

Letting $\tilde{g}(\alpha):=g(\alpha\mathbf{z}_{t}+(1-\alpha)\mathbf{z}_{t}^{\prime})$, it holds for all $\mathbf{z}_{t}\neq\mathbf{z}_{t}^{\prime}$ that

\begin{align}
(\mathbf{z}_{t}-\mathbf{z}_{t}^{\prime})^{\top}\big(g(\mathbf{z}_{t})-g(\mathbf{z}_{t}^{\prime})\big)
&=(\mathbf{z}_{t}-\mathbf{z}_{t}^{\prime})^{\top}\big(\tilde{g}(1)-\tilde{g}(0)\big) \nonumber\\
&=(\mathbf{z}_{t}-\mathbf{z}_{t}^{\prime})^{\top}\int_{0}^{1}\tilde{g}^{\prime}(\alpha)\,d\alpha \nonumber\\
&=\int_{0}^{1}(\mathbf{z}_{t}-\mathbf{z}_{t}^{\prime})^{\top}\nabla g\big(\alpha\mathbf{z}_{t}+(1-\alpha)\mathbf{z}_{t}^{\prime}\big)(\mathbf{z}_{t}-\mathbf{z}_{t}^{\prime})\,d\alpha>0
\tag{13}
\end{align}

where the inequality follows from (12).
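The passage from the second to the third line of (13) relies on the chain rule. Spelled out, and reading $\nabla g$ as the Jacobian of $g$ (an assumption about the notation, which is not restated in this appendix), the step is:

```latex
\begin{align*}
\tilde{g}'(\alpha)
  &= \frac{d}{d\alpha}\, g\big(\alpha\mathbf{z}_t + (1-\alpha)\mathbf{z}_t'\big)
   = \nabla g\big(\alpha\mathbf{z}_t + (1-\alpha)\mathbf{z}_t'\big)\,(\mathbf{z}_t - \mathbf{z}_t'), \\
\tilde{g}(1) - \tilde{g}(0)
  &= \int_0^1 \tilde{g}'(\alpha)\, d\alpha .
\end{align*}
```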

Next, upper bounding $\sigma\leq 1/C$, we will show that $\nabla(g^{-1})(\boldsymbol{\phi}_{t})\succeq C$ for some constant $C>0$. To obtain the inverse $g^{-1}$, notice that (8a) can be readily rewritten as

\[
[\mathbf{z}_{t}]_{\mathcal{B}_{i}}=\big([\boldsymbol{\phi}_{t}]_{\mathcal{B}_{i}}-\boldsymbol{\mu}_{i}\big)\odot 1/\sigma(\boldsymbol{\alpha}_{i}),
\]

where $/$ is the element-wise division. Similar to (11), it can be easily verified that the Jacobian $\nabla_{[\boldsymbol{\phi}_{t}]_{\pi}}[(g^{-1})(\boldsymbol{\phi}_{t})]_{\pi}$ is also block-upper-triangular, with $i$-th diagonal block $\frac{\partial[(g^{-1})(\boldsymbol{\phi}_{t})]_{\mathcal{B}_{i}}}{\partial[\boldsymbol{\phi}_{t}]_{\mathcal{B}_{i}}}=\text{diag}^{-1}(\sigma(\boldsymbol{\alpha}_{i}))\succeq C$.

As a result, we have that

\[
\nabla(g^{-1})(\boldsymbol{\phi}_{t})\succeq\min_{i=1,\ldots,B}1/\|\sigma(\boldsymbol{\alpha}_{i})\|_{\infty}\succeq C
\]

which completes the proof. ∎
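The block structure used in the proof can be illustrated with a minimal sketch. It assumes the forward map has the affine form $[\boldsymbol{\phi}_{t}]_{\mathcal{B}_{i}}=\sigma(\boldsymbol{\alpha}_{i})\odot[\mathbf{z}_{t}]_{\mathcal{B}_{i}}+\boldsymbol{\mu}_{i}$, which is inferred here from the inverse formula above rather than quoted from (8a), and it replaces the encoder/decoder MLPs with a hypothetical `conditioner` function whose outputs depend only on earlier blocks of $\mathbf{z}_{t}$.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def conditioner(prev, i, size):
    """Hypothetical stand-in for the encoder/decoder MLPs: returns (mu_i, alpha_i),
    which may depend only on the earlier blocks z_{B_1}, ..., z_{B_{i-1}}."""
    s = sum(prev)  # summary of all earlier coordinates (empty sum is 0)
    mu = [math.tanh(s + 0.1 * (i + k)) for k in range(size)]
    alpha = [math.sin(s) + 0.3 * k for k in range(size)]
    return mu, alpha


def forward(z_blocks):
    """g: z -> phi, with [phi]_{B_i} = sigma(alpha_i) * [z]_{B_i} + mu_i."""
    phi_blocks, prev = [], []
    for i, z_i in enumerate(z_blocks):
        mu, alpha = conditioner(prev, i, len(z_i))
        phi_blocks.append([sigmoid(a) * z + m for z, m, a in zip(z_i, mu, alpha)])
        prev = prev + z_i  # later blocks are conditioned on this block of z
    return phi_blocks


def inverse(phi_blocks):
    """g^{-1}: phi -> z, recovered block by block via
    [z]_{B_i} = ([phi]_{B_i} - mu_i) / sigma(alpha_i)."""
    z_blocks, prev = [], []
    for i, phi_i in enumerate(phi_blocks):
        mu, alpha = conditioner(prev, i, len(phi_i))
        z_i = [(p - m) / sigmoid(a) for p, m, a in zip(phi_i, mu, alpha)]
        z_blocks.append(z_i)
        prev = prev + z_i
    return z_blocks


z = [[0.5, -1.0], [2.0], [0.3, 0.7, -0.2]]  # d = 6 split into B = 3 blocks
phi = forward(z)
z_rec = inverse(phi)
max_err = max(abs(a - b) for zb, rb in zip(z, z_rec) for a, b in zip(zb, rb))
```

In this sketch, perturbing a later block of `z` leaves the earlier blocks of `phi` unchanged, and the derivative of block $i$ of `phi` with respect to block $i$ of `z` is exactly `sigmoid(alpha_i)`, which mirrors the diagonal blocks $\text{diag}(\sigma(\boldsymbol{\alpha}_{i}))$ in (11).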

Appendix B Summary of the algorithm

Algorithm 1: Meta-learning with MD and blockIAF
Input: $\{\mathcal{D}_{t}\}_{t=1}^{T}$, step sizes $\alpha$ and $\beta$, maximum number of iterations $K$ and $R$, and blockIAF mirror map $\nabla h$.
Initialization: randomly initialize $\boldsymbol{\theta}^{(0)}=[\mathbf{z}^{(0)\top},\boldsymbol{\theta}_{g}^{\top}]^{\top}$.
for $r=1,\ldots,R$ do
    Randomly sample a mini-batch of tasks $\mathcal{T}^{(r)}\subset\{1,\ldots,T\}$;
    for $t\in\mathcal{T}^{(r)}$ do
        Initialize $\mathbf{z}_{t}^{(0)}=\mathbf{z}^{(0)}$;
        for $k=1,\ldots,K$ do
            Map $\boldsymbol{\phi}_{t}^{(k-1)}(\boldsymbol{\theta}^{(r-1)})=(\nabla h)^{-1}\big(\mathbf{z}_{t}^{(k-1)}(\boldsymbol{\theta}^{(r-1)});\boldsymbol{\theta}_{g}^{(r-1)}\big)$;
            Descend $\mathbf{z}_{t}^{(k)}(\boldsymbol{\theta}^{(r-1)})=\mathbf{z}_{t}^{(k-1)}(\boldsymbol{\theta}^{(r-1)})-\alpha\nabla\mathcal{L}\big(\boldsymbol{\phi}_{t}^{(k-1)}(\boldsymbol{\theta}^{(r-1)});\mathcal{D}^{\mathrm{trn}}_{t}\big)$;
        end for
        Map $\hat{\boldsymbol{\phi}}_{t}(\boldsymbol{\theta}^{(r-1)})=(\nabla h)^{-1}\big(\mathbf{z}_{t}^{(K)}(\boldsymbol{\theta}^{(r-1)});\boldsymbol{\theta}_{g}^{(r-1)}\big)$;
    end for
    Update $\boldsymbol{\theta}^{(r)}=\boldsymbol{\theta}^{(r-1)}-\beta\frac{T}{|\mathcal{T}^{(r)}|}\sum_{t\in\mathcal{T}^{(r)}}\nabla_{\boldsymbol{\theta}^{(r-1)}}\mathcal{L}\big(\hat{\boldsymbol{\phi}}_{t}(\boldsymbol{\theta}^{(r-1)});\mathcal{D}^{\mathrm{val}}_{t}\big)$;
end for
Output: $\hat{\boldsymbol{\theta}}=\boldsymbol{\theta}^{(R)}$.
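The control flow of Algorithm 1 can be traced on a toy problem. The sketch below is not the released PyTorch implementation: it assumes scalar parameters, a single block (so $\boldsymbol{\mu}$ and $\boldsymbol{\alpha}$ reduce to learnable numbers), quadratic task losses with hypothetical targets, identical training and validation data per task, and central finite differences in place of backpropagation for the upper-level gradient. The step sizes are chosen for the toy and differ from Table 2.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def inv_mirror_map(z, theta_g):
    """(grad h)^{-1}: dual z -> primal phi. One scalar block, so mu and alpha
    are plain learnable numbers here (an assumption of this toy)."""
    alpha_g, mu = theta_g
    return sigmoid(alpha_g) * z + mu


def loss(phi, target):
    return 0.5 * (phi - target) ** 2


def grad_loss(phi, target):
    return phi - target


def adapt(theta, target, alpha, K):
    """Lower level: K mirror-descent steps in the dual variable z."""
    z, theta_g = theta[0], theta[1:]
    for _ in range(K):
        phi = inv_mirror_map(z, theta_g)        # Map
        z = z - alpha * grad_loss(phi, target)  # Descend
    return inv_mirror_map(z, theta_g)           # final Map


def batch_objective(theta, batch, targets, alpha, K):
    """(T / |batch|) * sum of validation losses over the sampled tasks."""
    scale = len(targets) / len(batch)
    return scale * sum(loss(adapt(theta, targets[t], alpha, K), targets[t]) for t in batch)


def meta_train(targets, alpha=0.5, beta=0.05, K=5, R=200, batch_size=2, seed=0):
    rng = random.Random(seed)
    theta = [0.0, 0.0, 0.0]  # [z^(0), alpha_g, mu]
    eps = 1e-5
    for _ in range(R):
        batch = rng.sample(range(len(targets)), batch_size)
        grad = []
        for j in range(len(theta)):  # central differences replace backprop
            hi = theta[:j] + [theta[j] + eps] + theta[j + 1:]
            lo = theta[:j] + [theta[j] - eps] + theta[j + 1:]
            grad.append((batch_objective(hi, batch, targets, alpha, K)
                         - batch_objective(lo, batch, targets, alpha, K)) / (2 * eps))
        theta = [p - beta * g for p, g in zip(theta, grad)]  # Update
    return theta


targets = [1.0, 2.0, 3.0, 4.0]  # one scalar target per task; trn and val coincide
all_tasks = list(range(len(targets)))
before = batch_objective([0.0, 0.0, 0.0], all_tasks, targets, 0.5, 5)
theta_hat = meta_train(targets)
after = batch_objective(theta_hat, all_tasks, targets, 0.5, 5)
```

Comparing `before` and `after` shows whether the learned initialization and mirror-map parameters reduce the post-adaptation loss on this toy. In a realistic setting the finite-difference loop would be replaced by automatic differentiation through the $K$ lower-level steps.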

Appendix C Numerical setups

This section elaborates further on the dataset and setups of the numerical tests.

The miniImageNet dataset is a few-shot classification dataset comprising natural images from $100$ classes, each containing $600$ samples. All images are cropped and resized to $84\times 84$, as suggested by [4]. The $100$ classes are disjointly divided into $3$ groups with corresponding size $64$, $20$ and $16$, which are available to the meta-training, meta-validation, and meta-testing phases, respectively. The task setups follow from the standard $M$-class $N$-shot few-shot learning protocol [4, 6]. In particular, $\mathcal{D}^{\mathrm{trn}}_{t}$ per task $t$ contains $M$ classes randomly drawn from the dataset, each consisting of $N$ labeled data. It is easy to see that $|\mathcal{D}^{\mathrm{trn}}_{t}|=MN,~\forall t$. Likewise, $\mathcal{D}^{\mathrm{val}}_{t}$ is constructed in a manner akin to $\mathcal{D}^{\mathrm{trn}}_{t}$, albeit with each class comprising $15$ labeled data.
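The task construction described above can be sketched as follows. This is an illustrative sampler, not the data pipeline of the released code: the function name `sample_task`, the string image identifiers, and the choice to draw $\mathcal{D}^{\mathrm{val}}_{t}$ from the same $M$ classes as $\mathcal{D}^{\mathrm{trn}}_{t}$ without sample overlap are assumptions.

```python
import random


def sample_task(class_to_items, M=5, N=1, n_val=15, rng=random):
    """Draw one M-class N-shot task: returns (D_trn, D_val) as lists of (item, label)."""
    classes = rng.sample(sorted(class_to_items), M)
    d_trn, d_val = [], []
    for label, cls in enumerate(classes):
        items = rng.sample(class_to_items[cls], N + n_val)  # without replacement
        d_trn += [(x, label) for x in items[:N]]
        d_val += [(x, label) for x in items[N:]]
    return d_trn, d_val


# Hypothetical stand-in for the 64 meta-training classes with 600 images each;
# the strings play the role of image identifiers.
pool = {c: [f"class{c}_img{i}" for i in range(600)] for c in range(64)}
d_trn, d_val = sample_task(pool, M=5, N=5, rng=random.Random(0))
```

For a 5-class 5-shot task this gives $|\mathcal{D}^{\mathrm{trn}}_{t}|=MN=25$ and, with $15$ validation samples per class, $|\mathcal{D}^{\mathrm{val}}_{t}|=75$.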

The hyperparameters used in the tests are the same as those used by MAML [6], and are listed in Table 2. Our implementation relies on PyTorch, and the code is available at https://github.com/zhangyilang/MetaMirrorDescent.

Table 2: Hyperparameter setup for the numerical tests.
| Hyperparameter | Notation | Value |
| --- | --- | --- |
| Lower-level iterations | $K$ | $5$ |
| Lower-level learning rate | $\alpha$ | $10^{-2}$ |
| Upper-level iterations | $R$ | $60{,}000$ |
| Upper-level learning rate | $\beta$ | $10^{-3}$ |
| Upper-level SGD batch size | $|\mathcal{T}^{(r)}|$ | $4$ |

All the MLPs used in blockIAF have three fully-connected layers with ReLU nonlinearity, and with the weight matrix of each layer Kronecker factorized [13]. Let $\text{size}_{i}:=d_{i,1}\times d_{i,2}\times\ldots\times d_{i,O_{i}}$ be the size of the original tensor corresponding to the vector $[\boldsymbol{\phi}_{t}]_{\mathcal{B}_{i}}$, where $O_{i}$ is the total order of the tensor, and $\prod_{j=1}^{O_{i}}d_{i,j}=|\mathcal{B}_{i}|$. Each layer of the encoder $e_{i}$ outputs a tensor with dimensionality of half size. This implies that the output tensor of the $l$-th layer of $e_{i}$ has size $\lfloor\frac{\text{size}_{i}}{2^{l}}\rfloor:=\lfloor\frac{d_{i,1}}{2^{l}}\rfloor\times\lfloor\frac{d_{i,2}}{2^{l}}\rfloor\times\ldots\times\lfloor\frac{d_{i,O_{i}}}{2^{l}}\rfloor,~l=1,2,3$. The decoder $d_{i}$ first vectorizes and concatenates the embeddings provided by $\{e_{j}\}_{j=1}^{i-1}$, maps this concatenated embedding vector to $\lfloor\frac{\text{size}_{i}}{8}\rfloor$, and recovers the tensor to $\text{size}_{i}$ by performing the inverse size operations of $e_{i}$; that is, its $l$-th layer changes the tensor size from $\lfloor\frac{\text{size}_{i}}{2^{4-l}}\rfloor$ to $\lfloor\frac{\text{size}_{i}}{2^{3-l}}\rfloor$.
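The size bookkeeping above can be made concrete with a small sketch. The helper names `encoder_sizes` and `decoder_sizes` and the example shape `(64, 32)` are hypothetical; only the halving rule and the layer indexing are taken from the text.

```python
def encoder_sizes(size, num_layers=3):
    """Output size of each encoder layer: every mode is halved (floored) per layer."""
    return [tuple(d // 2 ** l for d in size) for l in range(1, num_layers + 1)]


def decoder_sizes(size, num_layers=3):
    """(input size, output size) of each decoder layer, mirroring the encoder."""
    return [
        (tuple(d // 2 ** (num_layers + 1 - l) for d in size),
         tuple(d // 2 ** (num_layers - l) for d in size))
        for l in range(1, num_layers + 1)
    ]


size_i = (64, 32)  # hypothetical weight-matrix shape for one block
enc = encoder_sizes(size_i)
dec = decoder_sizes(size_i)
```

For this shape the encoder layers output sizes `(32, 16)`, `(16, 8)`, and `(8, 4)`, and the decoder layers reverse the sequence from `(8, 4)` back to `(64, 32)`.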

Appendix D Complexity analysis

Next, a complexity comparison is carried out to assess the computational overhead of the introduced blockIAF model. Numerical complexities are measured on the 5-class 5-shot miniImageNet task. In this test, the blockIAF-based mirror map incurs a $9.1\%$ increase in forward- and back-propagation time compared to the basic GD update in MAML. This modest increase supports the claimed low complexity of the proposed approach.