License: arXiv.org perpetual non-exclusive license
arXiv:2307.09883v2 [cs.LG] 12 Mar 2024

 

Symmetric Equilibrium Learning of VAEs


 


Boris Flach (Czech Technical University in Prague), Dmitrij Schlesinger (Dresden University of Technology), Alexander Shekhovtsov (Czech Technical University in Prague)

Abstract

We view variational autoencoders (VAEs) as decoder–encoder pairs, which map distributions in the data space to distributions in the latent space and vice versa. The standard learning approach for VAEs is the maximisation of the evidence lower bound (ELBO). It is asymmetric in that it aims at learning a latent variable model while using the encoder only as an auxiliary means. Moreover, it requires a closed-form prior distribution of the latent variables. This limits its applicability in more complex scenarios, such as general semi-supervised learning and employing complex generative models as priors. We propose a Nash equilibrium learning approach, which is symmetric with respect to the encoder and decoder and allows learning VAEs in situations where both the data and the latent distributions are accessible only by sampling. The flexibility and simplicity of this approach allow its application to a wide range of learning scenarios and downstream tasks.

1 INTRODUCTION

Variational autoencoders (Kingma and Welling, 2014; Rezende et al., 2014) are a well-established and well-analysed approach to learning latent variable models of the form $p(x)=\sum_{z}p(z)\,p(x\,|\,z)$. Given a distribution $\pi(x)$, $x\in\mathcal{X}$, in the data space and an assumed distribution $p(z)$, $z\in\mathcal{Z}$, in the latent space, a VAE combines a pair of parametrised distributions $p_{\theta}(x\,|\,z)$ (the decoder) and $q_{\varphi}(z\,|\,x)$ (the encoder), which are usually modelled by deep networks. The standard way to learn this encoder–decoder pair is to maximise the evidence lower bound (ELBO) of the data log-likelihood,

$$L_{B}(\theta,\varphi)=\mathbb{E}_{\pi(x)}\Bigl[\mathbb{E}_{q_{\varphi}(z\,|\,x)}\log p_{\theta}(x\,|\,z)-D_{\rm KL}\bigl(q_{\varphi}(z\,|\,x)\,\big\|\,p(z)\bigr)\Bigr]. \tag{1}$$
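To make (1) concrete, the following minimal sketch evaluates the per-example ELBO exactly for a toy model with a discrete latent variable, where every sum can be enumerated. The prior, the decoder table and the encoder output below are hypothetical numbers chosen only for illustration. The sketch checks two standard facts: the ELBO lower-bounds $\log p(x)$, and the bound is tight when the encoder equals the decoder posterior $p(z\,|\,x)$.

```python
import math

# Hypothetical toy model (numbers chosen for illustration): z in {0, 1, 2}, x in {0, 1}.
prior = [0.5, 0.3, 0.2]                          # p(z)
decoder = [[0.9, 0.1], [0.4, 0.6], [0.2, 0.8]]  # decoder[z][x] = p(x|z)


def log_evidence(x):
    """log p(x) = log sum_z p(z) p(x|z), by enumeration over z."""
    return math.log(sum(pz * row[x] for pz, row in zip(prior, decoder)))


def elbo(x, q):
    """Per-example bound (1): E_q[log p(x|z)] - KL(q(z|x) || p(z)); q is a list over z."""
    rec = sum(qz * math.log(row[x]) for qz, row in zip(q, decoder) if qz > 0)
    kl = sum(qz * math.log(qz / pz) for qz, pz in zip(q, prior) if qz > 0)
    return rec - kl


def posterior(x):
    """Exact decoder posterior p(z|x) = p(z) p(x|z) / p(x)."""
    joint = [pz * row[x] for pz, row in zip(prior, decoder)]
    total = sum(joint)
    return [j / total for j in joint]


x = 1
q_arbitrary = [0.2, 0.3, 0.5]   # some hypothetical encoder output for this x
print(log_evidence(x), elbo(x, q_arbitrary), elbo(x, posterior(x)))
```

The gap between the two printed bounds and $\log p(x)$ is exactly $D_{\rm KL}(q(z\,|\,x)\,\|\,p(z\,|\,x))$, which vanishes for the posterior.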

This learning formulation is particularly well suited to situations where only the generative model $p(x)$ is of interest. Research in this area in recent years has culminated in deep hierarchical VAEs (Vahdat and Kautz, 2020) and diffusion models (Ho et al., 2020; Rombach et al., 2022), which can also be viewed as hierarchical VAEs. The encoder's role in the ELBO is auxiliary, and in diffusion models it is even fixed to a simple noisy shrinkage. However, a learned encoder is often of interest in its own right: it can provide compact representations that are useful for downstream tasks (e.g. for semantic hashing, Dadaneh et al. 2020). Furthermore, while only samples from $\pi(x)$ are needed in (1), an explicit model of $p(z)$ is required in order to compute (and differentiate) the KL-divergence term. Although solutions to the latter problem have been proposed, they come with other limitations (discussed in detail in Section 5).

The asymmetries of the standard VAE learning approach pointed out above make it difficult to use in semi-supervised training scenarios and in situations where both spaces $\mathcal{X}$ and $\mathcal{Z}$ are complex and possibly structured, as for instance in semantic segmentation with images $x$ and segmentations $z$. Learning an encoder–decoder pair in such a scenario would naturally allow solving inference problems in both directions between $x$ and $z$, as well as building more complex models. The requirement to model $p(z)$ by a simple and tractable density then becomes a significant limitation.

In this work, we propose a symmetric learning approach inspired by game theory, which leads to a simple learning algorithm. The method can handle implicitly given marginal distributions $\pi(x)$ and $\pi(z)$. It does not require gradients of expectations with respect to the parameters of the distribution over which the expectation is taken (such as the gradient of the ELBO w.r.t. the encoder parameters), and therefore no reparametrisation is needed. Consequently, handling discrete or continuous variables is equally simple. The method gives a novel view of the well-known wake-sleep algorithm (Hinton et al., 1995), as discussed in Section 5. It can be applied to models with structured latent spaces, such as hierarchical VAEs, and extended to models consisting of three or more groups of variables. In the latter case, the model consists of several inference networks, one for each group of variables. They are learned jointly and can address an extended range of tasks at inference time, as we demonstrate experimentally.

The rest of the paper is organised as follows. In the next two sections we derive and analyse our novel learning approach. In the following section we exemplify its application to advanced models and learning setups. In the final experimental section we compare it with ELBO learning, show that it provides comparable model estimates, and demonstrate its applicability to more complex models not addressable by ELBO.

2 PROBLEM FORMULATION

We propose a generic learning approach, whose primary goal is to learn a decoder $p(x\,|\,z)$ and an encoder $q(z\,|\,x)$ in the following training scenarios:

Semi-supervised learning: We assume training samples $x\sim\pi(x)$ and $z\sim\pi(z)$, and possibly also joint samples $(x,z)\sim\pi(x,z)$, drawn i.i.d. from an unknown distribution $\pi(x,z)$ and its marginals.

Unsupervised learning: Only samples $x\sim\pi(x)$ are observed. In this case the space $\mathcal{Z}$ is a free modelling choice.

Similar to VAE learning, the choice of the models for the decoder and encoder is dictated by the need to be able to evaluate (or at least differentiate) their respective log-densities and to sample from them. We will assume that the decoder and encoder belong to parametric exponential families of the form

$$p_{\theta}(x\,|\,z)\propto\exp\bigl[\langle\phi(x),f_{\theta}(z)\rangle\bigr], \tag{2a}$$
$$q_{\varphi}(z\,|\,x)\propto\exp\bigl[\langle\psi(z),g_{\varphi}(x)\rangle\bigr], \tag{2b}$$

where $\phi\colon\mathcal{X}\to\mathbb{R}^{n}$ and $\psi\colon\mathcal{Z}\to\mathbb{R}^{m}$ are fixed sufficient statistics. The mappings $f$ and $g$ are usually modelled by deep networks, parametrised by $\theta$ and $\varphi$, respectively. Notice that the variables $x$, $z$ can be either discrete or continuous, depending on the chosen exponential family. Common choices are, e.g., Bernoulli or Gaussian models.
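As an illustration of (2a), consider binary data $x\in\{0,1\}^n$ with $\phi(x)=x$. The normaliser then factorises over the components, giving a product of Bernoulli distributions whose logits are the natural parameters $f_{\theta}(z)$. The sketch below assumes these logits have already been computed by some network (a plain list stands in for $f_{\theta}(z)$ at one hypothetical $z$) and implements the two operations the learning method needs: evaluating the log-density and drawing a sample.

```python
import itertools
import math
import random


def softplus(t):
    """Numerically stable log(1 + exp(t))."""
    return max(t, 0.0) + math.log1p(math.exp(-abs(t)))


def bernoulli_log_prob(x, logits):
    """log p(x|z) for p(x|z) proportional to exp(<x, logits>) on {0,1}^n (phi(x) = x).
    The normaliser factorises, giving log Z = sum_i softplus(logit_i)."""
    return sum(xi * li - softplus(li) for xi, li in zip(x, logits))


def bernoulli_sample(logits, rng):
    """Draw each x_i independently with P(x_i = 1) = sigmoid(logit_i)."""
    return [int(rng.random() < 1.0 / (1.0 + math.exp(-li))) for li in logits]


logits = [1.5, -0.3, 0.0]   # hypothetical stand-in for f_theta(z) at one z
rng = random.Random(0)
x = bernoulli_sample(logits, rng)
# Sanity check: the probabilities of all 2^n configurations sum to one.
total = sum(math.exp(bernoulli_log_prob(cfg, logits))
            for cfg in itertools.product([0, 1], repeat=len(logits)))
```

A Gaussian decoder would follow the same pattern with $\phi(x)=(x,x^2)$ and its own log-normaliser.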

3 SYMMETRIC EQUILIBRIUM LEARNING

We present our general approach and theoretical analysis for the semi-supervised learning task from the previous section, which naturally calls for a symmetric formulation.

For simplicity of exposition, let us assume that only the marginal empirical distributions $\pi(x)$ and $\pi(z)$ are given, and no joint observations are available. The goal is to learn an encoder–decoder pair $q_{\varphi}(z\,|\,x)$ and $p_{\theta}(x\,|\,z)$ by (i) optimising the likelihood of the observed data and (ii) enforcing encoder–decoder consistency at the same time. We formulate the learning task symmetrically as finding a Nash equilibrium of a two-player game. The strategy of the first player is represented by the decoder $p_{\theta}$. Similarly, the strategy of the second player is represented by the encoder $q_{\varphi}$. The utility function of a player is the likelihood of the training data w.r.t. its strategy, where the missing part of each training example is completed by the strategy of the other player. For example, the missing information in the examples $x\sim\pi(x)$ for the decoder likelihood is completed by the encoder strategy: $z\sim q_{\varphi}(z\,|\,x)$. Proceeding in the same way for the encoder, we obtain the utility functions

$$L_{p}(\theta,\varphi)=\mathbb{E}_{\pi(x)}\mathbb{E}_{q_{\varphi}(z\,|\,x)}\bigl[\log p_{\theta}(x\,|\,z)\bigr], \tag{3a}$$
$$L_{q}(\theta,\varphi)=\mathbb{E}_{\pi(z)}\mathbb{E}_{p_{\theta}(x\,|\,z)}\bigl[\log q_{\varphi}(z\,|\,x)\bigr]. \tag{3b}$$

As we will see later, the game aims at maximising the decoder likelihood and the encoder likelihood of the training data simultaneously, whereby the mutual completion reinforces decoder-encoder consistency.

A Nash equilibrium of the game is a pair $(\theta_{*},\varphi_{*})$ such that

$$L_{p}(\theta_{*},\varphi_{*})\geqslant L_{p}(\theta,\varphi_{*})\quad\forall\theta,$$
$$L_{q}(\theta_{*},\varphi_{*})\geqslant L_{q}(\theta_{*},\varphi)\quad\forall\varphi, \tag{4}$$

i.e. a point at which neither player can improve its utility by unilaterally changing its own strategy. Towards finding an equilibrium, we consider a simple gradient algorithm in which each player tries to improve its utility w.r.t. its own strategy:

$$\theta:=\theta+\alpha\nabla_{\theta}L_{p}(\theta,\varphi);\qquad\varphi:=\varphi+\alpha\nabla_{\varphi}L_{q}(\theta,\varphi). \tag{5}$$

These updates may be executed in parallel or sequentially. Stochastic unbiased estimates of the required gradients are readily obtained by differentiating Monte Carlo estimates of the expectations (3), using as few as a single sample. Unlike the ELBO, the expectation $L_{p}(\theta,\varphi)$ does not need to be differentiated with respect to the encoder parameters $\varphi$, and similarly $L_{q}(\theta,\varphi)$ does not need to be differentiated with respect to the decoder parameters $\theta$. Hence, neither the reparametrisation trick (for continuous variables) nor specialised gradient estimators through discrete samples (for discrete variables) are needed.
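To illustrate (5) with single-sample estimates, the sketch below replaces the deep networks with tabular softmax parametrisations on small discrete spaces, $p_\theta(x\,|\,z)=\mathrm{softmax}(\theta_z)_x$ and $q_\varphi(z\,|\,x)=\mathrm{softmax}(\varphi_x)_z$. This parametrisation, the marginals and all sizes are assumptions made so that everything fits in plain Python. Each player differentiates only its own log-probability; the sample drawn from the other player is treated as a constant, so no reparametrisation or score-function estimator appears.

```python
import math
import random


def softmax(logits):
    m = max(logits)
    e = [math.exp(v - m) for v in logits]
    s = sum(e)
    return [v / s for v in e]


def utilities(theta, phi, pi_x, pi_z):
    """Exact utilities (3a), (3b) by enumeration over small discrete spaces."""
    dec = [softmax(row) for row in theta]   # dec[z][x] = p_theta(x|z)
    enc = [softmax(row) for row in phi]     # enc[x][z] = q_phi(z|x)
    lp = sum(pi_x[x] * enc[x][z] * math.log(dec[z][x])
             for x in range(len(pi_x)) for z in range(len(pi_z)))
    lq = sum(pi_z[z] * dec[z][x] * math.log(enc[x][z])
             for x in range(len(pi_x)) for z in range(len(pi_z)))
    return lp, lq


def draw(probs, rng):
    return rng.choices(range(len(probs)), weights=probs)[0]


def game_step(theta, phi, xs, zs, alpha, rng):
    """One parallel update (5) from single-sample Monte Carlo estimates of (3).
    Sampled completions are constants: no gradient flows from L_p into phi
    or from L_q into theta."""
    n_x, n_z = len(phi), len(theta)
    g_theta = [[0.0] * n_x for _ in range(n_z)]
    g_phi = [[0.0] * n_z for _ in range(n_x)]
    for x in xs:                          # x ~ pi(x), completed by z ~ q(z|x)
        z = draw(softmax(phi[x]), rng)
        p = softmax(theta[z])
        for k in range(n_x):              # gradient of log softmax(theta[z])[x]
            g_theta[z][k] += (float(k == x) - p[k]) / len(xs)
    for z in zs:                          # z ~ pi(z), completed by x ~ p(x|z)
        x = draw(softmax(theta[z]), rng)
        q = softmax(phi[x])
        for k in range(n_z):              # gradient of log softmax(phi[x])[z]
            g_phi[x][k] += (float(k == z) - q[k]) / len(zs)
    theta = [[t + alpha * g for t, g in zip(rt, rg)] for rt, rg in zip(theta, g_theta)]
    phi = [[f + alpha * g for f, g in zip(rf, rg)] for rf, rg in zip(phi, g_phi)]
    return theta, phi


# Hypothetical marginals on |X| = 3, |Z| = 2 and random initial parameters.
rng = random.Random(0)
pi_x, pi_z = [0.5, 0.3, 0.2], [0.6, 0.4]
theta = [[rng.gauss(0, 1) for _ in pi_x] for _ in pi_z]
phi = [[rng.gauss(0, 1) for _ in pi_z] for _ in pi_x]
row_sums = [sum(r) for r in theta]
for _ in range(200):
    xs = [draw(pi_x, rng) for _ in range(32)]
    zs = [draw(pi_z, rng) for _ in range(32)]
    theta, phi = game_step(theta, phi, xs, zs, alpha=0.5, rng=rng)
lp, lq = utilities(theta, phi, pi_x, pi_z)
```

With a deep-learning framework, the same effect is obtained by sampling the completions without gradient tracking and back-propagating each utility only into its own player's parameters. A side effect of the softmax parametrisation is that each update leaves the row sums of the logit tables unchanged, since the gradient of a log-softmax sums to zero.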

Uniqueness

It is well known that nonzero-sum games can have multiple and even infinitely many Nash equilibria. It is therefore crucial to analyse uniqueness of the solution as well as the convergence properties of the algorithm (5).

Extending the decoder and encoder to joint models via

$$p_{\theta}(x,z)=p_{\theta}(x\,|\,z)\,\pi(z);\qquad q_{\varphi}(x,z)=q_{\varphi}(z\,|\,x)\,\pi(x) \tag{6}$$

the game utilities (3) can be compactly written, up to additive terms that do not depend on the respective player's own parameters, as

$$\mathbb{E}_{q_{\varphi}(x,z)}\log p_{\theta}(x,z);\qquad\mathbb{E}_{p_{\theta}(x,z)}\log q_{\varphi}(x,z). \tag{7}$$
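Substituting (6) makes the relation between (3) and (7) explicit. The short derivation below (our own expansion of this step; the names $c_p$ and $c_q$ are introduced here only for illustration) shows that each utility in (7) differs from its counterpart in (3) by a term that depends only on the other player's parameters. Such terms do not change a player's best response, so both formulations have the same equilibria:

```latex
\begin{aligned}
\mathbb{E}_{q_\varphi(x,z)}\log p_\theta(x,z)
  &= \mathbb{E}_{\pi(x)}\mathbb{E}_{q_\varphi(z\,|\,x)}\log p_\theta(x\,|\,z)
   + \mathbb{E}_{q_\varphi(x,z)}\log\pi(z)
   = L_p(\theta,\varphi) + c_p(\varphi),\\
\mathbb{E}_{p_\theta(x,z)}\log q_\varphi(x,z)
  &= \mathbb{E}_{\pi(z)}\mathbb{E}_{p_\theta(x\,|\,z)}\log q_\varphi(z\,|\,x)
   + \mathbb{E}_{p_\theta(x,z)}\log\pi(x)
   = L_q(\theta,\varphi) + c_q(\theta).
\end{aligned}
```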

This game is hard to analyse because of the non-linear mappings involved.

To allow for theoretical analysis we will enlarge the spaces of feasible joint distributions by considering the following canonical exponential families

$$p_{u}(x,z)=\pi(z)\exp\bigl[\langle\phi(x,z),u\rangle-A(u)\bigr], \tag{8a}$$
$$q_{v}(x,z)=\pi(x)\exp\bigl[\langle\psi(x,z),v\rangle-B(v)\bigr], \tag{8b}$$

where $\phi(x,z)$ and $\psi(x,z)$ are sufficient statistics on $(x,z)$, $u$ and $v$ are free parameter vectors, and $A$ and $B$ are cumulant functions ensuring normalisation. The models (8) are log-linear in $u$ and $v$ by design. At the same time, with sufficiently rich $\phi(x,z)$ and $\psi(x,z)$, they can represent or approximate all models from the original families, which were parametrised in terms of neural networks.

We explain this model relaxation for the case of binary-valued vectors $z$ and $x$. The components of the vector of natural parameters $f_{\theta}(z)$ in (2) and the corresponding cumulant function are then pseudo-Boolean functions and can be written as polynomials in the components of $z$. The same holds for the components of the sufficient statistic vector $\phi(x)$. This means that if we take the components of $\phi(x,z)$ in the relaxed class to contain all base monomials, then for any $\theta$ there is a corresponding parameter vector $u$ making the models equal. Notice that only under this correspondence does the exponent part in (8a) match the conditional distribution $p_{\theta}(x\,|\,z)$; this is not true for a generic $u$.
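The claim that every pseudo-Boolean function is a polynomial in the components of $z$ can be checked directly: the monomial coefficients follow from Möbius inversion over subsets, $c_S=\sum_{T\subseteq S}(-1)^{|S|-|T|}f(\mathbf{1}_T)$. The sketch below is an illustration, not code from the paper. It computes these coefficients for a function on $\{0,1\}^n$ given as a table and reconstructs the function from its multilinear expansion; the example function, a Bernoulli cumulant $\log(1+e^{w_0+\langle w,z\rangle})$ with hypothetical weights, stands in for one component of a network-defined cumulant.

```python
import itertools
import math


def monomial_coefficients(f, n):
    """Moebius inversion: returns c with f(z) = sum_S c[S] * prod_{i in S} z_i on {0,1}^n.
    f is a dict from 0/1 tuples to reals; a subset S is encoded by its 0/1 indicator tuple."""
    coeffs = {}
    for s in itertools.product([0, 1], repeat=n):
        support = [i for i in range(n) if s[i]]
        c = 0.0
        for r in range(len(support) + 1):
            for t in itertools.combinations(support, r):
                point = tuple(1 if i in t else 0 for i in range(n))
                c += (-1) ** (len(support) - r) * f[point]
        coeffs[s] = c
    return coeffs


def evaluate(coeffs, z):
    """Evaluate the multilinear polynomial: monomial S equals 1 iff z_i = 1 for all i in S."""
    return sum(c for s, c in coeffs.items() if all(zi >= si for zi, si in zip(z, s)))


# Hypothetical example: a Bernoulli cumulant log(1 + exp(w0 + <w, z>)) on binary z.
n = 3
w0, w = 0.3, [1.2, -0.7, 0.5]
cumulant = {z: math.log1p(math.exp(w0 + sum(wi * zi for wi, zi in zip(w, z))))
            for z in itertools.product([0, 1], repeat=n)}
coeffs = monomial_coefficients(cumulant, n)
```

In general all $2^n$ coefficients can be non-zero, which is why the relaxed statistics $\phi(x,z)$ must contain all base monomials to represent arbitrary networks.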

Theorem 1.

The two-player game with utility functions

$$L_{p}(u,v)=\mathbb{E}_{q_{v}(x,z)}\log p_{u}(x,z), \tag{9a}$$
$$L_{q}(u,v)=\mathbb{E}_{p_{u}(x,z)}\log q_{v}(x,z) \tag{9b}$$

and strategies given by the exponential family distributions (8) has a unique, asymptotically stable equilibrium.

The proof is given in Appendix A. The idea is to construct a dual formulation of the game, which maximises the entropy under moment-matching constraints. In this reformulation, it is then easy to prove the diagonal strict concavity condition (Rosen, 1965), which is a sufficient condition for uniqueness. Following Theorems 7-10 in (Rosen, 1965), the theorem implies that the simple gradient ascent algorithm (5) converges to the unique equilibrium point.

The theorem applies to the log-linear models (8) with free natural parameters $u$ and $v$, and guarantees that the proposed algorithm converges to a unique equilibrium in this case. This has direct applicability to, e.g., EF-Harmonium models, which are, however, outside our scope. Its value for VAEs defined in terms of neural networks is rather indirect: if the algorithm works in the lifted space, this gives more confidence that it also makes sense in a subspace with a non-linear parametrisation.

Consistency

Finally, we discuss the question of encoder–decoder consistency. We say that models $p(x\,|\,z)$ and $q(z\,|\,x)$ are consistent if there exists a joint distribution $m(x,z)$ of which they are the conditional distributions (see also Liu et al. 2021). Since we model $p_{\theta}(x\,|\,z)$ and $q_{\varphi}(z\,|\,x)$ independently, they are in general inconsistent. Enforcing consistency strictly, while keeping the models in the exponential families (2), forces the joint $m(x,z)$ to be an EF-Harmonium (Arnold and Strauss 1991, Shekhovtsov et al. 2022), which is a severe limitation. However, encouraging consistency can serve as a useful regularisation and can improve learning efficiency.

We observe that our game formulation implicitly encourages consistency.

Proposition 1.

With the definition of the joint distributions $p_{\theta}(x,z)$ and $q_{\varphi}(x,z)$ in (6) and their respective marginals, the game (3) is equivalent to the game with utilities:

$$L'_{p}=\mathbb{E}_{\pi(x)}\bigl[\log p_{\theta}(x)-D_{\rm KL}\bigl(q_{\varphi}(z\,|\,x)\,\|\,p_{\theta}(z\,|\,x)\bigr)\bigr],$$
$$L'_{q}=\mathbb{E}_{\pi(z)}\bigl[\log q_{\varphi}(z)-D_{\rm KL}\bigl(p_{\theta}(x\,|\,z)\,\|\,q_{\varphi}(x\,|\,z)\bigr)\bigr].$$

See details in Appendix A. The utility $L'_{p}$ is an alternative decomposition of the ELBO (with prior $\pi(z)$) into the data likelihood part and the encoder–posterior divergence, which encourages consistency. The utility $L'_{q}$ is its symmetric counterpart. The difference from ELBO learning is that $L'_{p}$ is optimised over $\theta$ only and not over $\varphi$, and vice versa for $L'_{q}$.
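Proposition 1 can be checked numerically in a small tabular setting. In the sketch below, hypothetical marginals and randomly drawn conditional tables stand in for the networks. With the encoder fixed, $L'_p-L_p$ comes out the same for two different decoders; with the decoder fixed, $L'_q-L_q$ is the same for two different encoders. Expanding the definitions shows why: $L'_p-L_p=\mathbb{E}_{\pi(x)}\mathbb{E}_{q_\varphi(z\,|\,x)}[\log\pi(z)-\log q_\varphi(z\,|\,x)]$ depends on $\varphi$ only, and symmetrically $L'_q-L_q$ depends on $\theta$ only.

```python
import math
import random


def softmax(v):
    m = max(v)
    e = [math.exp(a - m) for a in v]
    s = sum(e)
    return [a / s for a in e]


def kl(a, b):
    """KL divergence between two discrete distributions given as lists."""
    return sum(ai * math.log(ai / bi) for ai, bi in zip(a, b) if ai > 0)


def game_terms(dec, enc, pi_x, pi_z):
    """Exact L_p, L_q from (3) and L'_p, L'_q from Proposition 1, by enumeration.
    dec[z][x] = p_theta(x|z), enc[x][z] = q_phi(z|x); joints are formed as in (6)."""
    X, Z = range(len(pi_x)), range(len(pi_z))
    lp = sum(pi_x[x] * enc[x][z] * math.log(dec[z][x]) for x in X for z in Z)
    lq = sum(pi_z[z] * dec[z][x] * math.log(enc[x][z]) for x in X for z in Z)
    p_x = [sum(dec[z][x] * pi_z[z] for z in Z) for x in X]             # p_theta(x)
    q_z = [sum(enc[x][z] * pi_x[x] for x in X) for z in Z]             # q_phi(z)
    p_post = [[dec[z][x] * pi_z[z] / p_x[x] for z in Z] for x in X]    # p_theta(z|x)
    q_post = [[enc[x][z] * pi_x[x] / q_z[z] for x in X] for z in Z]    # q_phi(x|z)
    lp_prime = sum(pi_x[x] * (math.log(p_x[x]) - kl(enc[x], p_post[x])) for x in X)
    lq_prime = sum(pi_z[z] * (math.log(q_z[z]) - kl(dec[z], q_post[z])) for z in Z)
    return lp, lq, lp_prime, lq_prime


rng = random.Random(1)
pi_x, pi_z = [0.5, 0.3, 0.2], [0.6, 0.4]   # hypothetical empirical marginals


def random_table(rows, cols):
    return [softmax([rng.gauss(0, 1) for _ in range(cols)]) for _ in range(rows)]


enc = random_table(3, 2)                     # fixed encoder, two decoders
dec_a, dec_b = random_table(2, 3), random_table(2, 3)
lp_a, _, lpp_a, _ = game_terms(dec_a, enc, pi_x, pi_z)
lp_b, _, lpp_b, _ = game_terms(dec_b, enc, pi_x, pi_z)

enc_a, enc_b = random_table(3, 2), random_table(3, 2)   # fixed decoder, two encoders
_, lq_a, _, lqp_a = game_terms(dec_a, enc_a, pi_x, pi_z)
_, lq_b, _, lqp_b = game_terms(dec_a, enc_b, pi_x, pi_z)
print(lpp_a - lp_a, lpp_b - lp_b, lqp_a - lq_a, lqp_b - lq_b)
```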

As with ELBO learning, there is no guarantee that the proposed learning approach will result in a consistent decoder–encoder pair defining a unique joint distribution. However, the need for such a joint distribution may be dictated by the application for which the VAE is learned, or it may arise if the learned VAE is only part of a larger model that requires such a joint distribution. In such cases, we may consider the distribution (e.g. Liu et al. 2021)

$$m(x,z)=\tfrac{1}{2}\,m(z)\,p(x\,|\,z)+\tfrac{1}{2}\,m(x)\,q(z\,|\,x) \tag{10}$$

with implicitly defined marginals $m(x)$ and $m(z)$. They must satisfy $m(x)=\sum_{z}m(x,z)$ and $m(z)=\sum_{x}m(x,z)$, which leads to the equations

$$m(x)=\sum_{z}p(x\,|\,z)\,m(z), \tag{11a}$$
$$m(z)=\sum_{x}q(z\,|\,x)\,m(x). \tag{11b}$$

While it is usually not possible to compute these marginals in closed form, it is nevertheless possible to sample from them, and from the joint $m(x,z)$, as the limiting distributions of a Markov chain that alternates sampling $x\sim p(x\,|\,z)$ and $z\sim q(z\,|\,x)$, as considered by Lamb et al. (2017).
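For small discrete spaces, the marginals defined by (11) can be computed directly and compared with the alternating chain. The sketch below uses hypothetical decoder and encoder tables (illustrative assumptions, not learned models). It solves (11) by power iteration and runs the chain $z\sim q(z\,|\,x)$, $x\sim p(x\,|\,z)$; the empirical frequencies of $x$ along the chain should approach $m(x)$.

```python
import random


def stationary_marginals(dec, enc, iters=500):
    """Approximate the solution of (11a)-(11b) by power iteration.
    dec[z][x] = p(x|z), enc[x][z] = q(z|x) on small discrete spaces."""
    n_x, n_z = len(enc), len(dec)
    m_x = [1.0 / n_x] * n_x
    for _ in range(iters):
        m_z = [sum(enc[x][z] * m_x[x] for x in range(n_x)) for z in range(n_z)]  # (11b)
        m_x = [sum(dec[z][x] * m_z[z] for z in range(n_z)) for x in range(n_x)]  # (11a)
    m_z = [sum(enc[x][z] * m_x[x] for x in range(n_x)) for z in range(n_z)]
    return m_x, m_z


def chain_frequencies(dec, enc, steps, rng, x=0):
    """Alternate z ~ q(z|x) and x ~ p(x|z); return the empirical frequencies of x."""
    counts = [0] * len(enc)
    for _ in range(steps):
        z = rng.choices(range(len(dec)), weights=enc[x])[0]
        x = rng.choices(range(len(enc)), weights=dec[z])[0]
        counts[x] += 1
    return [c / steps for c in counts]


# Hypothetical decoder/encoder tables: |Z| = 2, |X| = 3.
dec = [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]]
enc = [[0.8, 0.2], [0.5, 0.5], [0.1, 0.9]]
m_x, m_z = stationary_marginals(dec, enc)
freq = chain_frequencies(dec, enc, steps=20000, rng=random.Random(0))
```

Pairs $(z,x)$ of consecutive states are distributed as $m(z)\,p(x\,|\,z)$ and pairs $(x,z)$ as $m(x)\,q(z\,|\,x)$, so mixing the two kinds of pairs with equal weight yields samples from the joint (10).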

4 ADVANCED MODELS AND LEARNING SETUPS

In this section we exemplify the application of the proposed learning approach to several practically relevant learning setups and more complex models.

Semi-Supervised Learning with Mixed Data

We extend the model and learning setup from Section 3 in two respects. First, we assume that in addition to the empirical distributions $\pi(x)$ and $\pi(z)$ we also have complete training examples, i.e., matching pairs $(x,z)$, forming an empirical distribution $\pi(x,z)$. Note that here the $\pi$-s are empirical distributions; hence, e.g., $\pi(x)$ need not be a marginal of $\pi(x,z)$. Second, we assume that the decoder's joint distribution is defined using its own parametrised prior for $z$, i.e. $p_{\theta}(x,z)=p_{\theta}(z)\,p_{\theta}(x\,|\,z)$.

The utility function of the decoder sums the $p$-likelihoods of the training set, of which the likelihoods of the examples $(x,z)\sim\pi(x,z)$ and $z\sim\pi(z)$ are tractable. The missing information in the examples $x\sim\pi(x)$, whose $p$-likelihood is intractable, is completed by the encoder strategy $q_{\varphi}(z\,|\,x)$. Proceeding in the same way for the encoder, we get the utility functions

$$L_{p}(\theta,\varphi)=\mathbb{E}_{\pi(x,z)}\bigl[\log p_{\theta}(x,z)\bigr]+\mathbb{E}_{\pi(z)}\bigl[\log p_{\theta}(z)\bigr]+\mathbb{E}_{\pi(x)}\mathbb{E}_{q_{\varphi}(z\,|\,x)}\bigl[\log p_{\theta}(x,z)\bigr], \tag{12a}$$
$$L_{q}(\theta,\varphi)=\mathbb{E}_{\pi(x,z)}\bigl[\log q_{\varphi}(z\,|\,x)\bigr]+\mathbb{E}_{\pi(z)}\mathbb{E}_{p_{\theta}(x\,|\,z)}\bigl[\log q_{\varphi}(z\,|\,x)\bigr]. \tag{12b}$$
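A minimal sketch of how the terms of (12) can be estimated from mini-batches, assuming hypothetical tabular softmax models for $p_\theta(z)$, $p_\theta(x\,|\,z)$ and $q_\varphi(z\,|\,x)$. Terms involving complete pairs are evaluated directly, as is the prior term for z-only examples in the decoder utility. The x-only examples (for the decoder) and the z-only examples (for the encoder) are completed by a single sample from the other player. In a gradient-based implementation, each player would differentiate only its own log-probabilities, with these samples held fixed.

```python
import math
import random


def softmax(v):
    m = max(v)
    e = [math.exp(a - m) for a in v]
    s = sum(e)
    return [a / s for a in e]


def mean(values):
    return sum(values) / len(values) if values else 0.0


def mc_utilities(prior_logits, theta, phi, pairs, xs, zs, rng):
    """Single-sample Monte Carlo estimates of (12a) and (12b) for tabular models.
    Hypothetical parametrisation: p_theta(z) = softmax(prior_logits),
    p_theta(x|z) = softmax(theta[z]), q_phi(z|x) = softmax(phi[x]).
    pairs ~ pi(x,z), xs ~ pi(x) and zs ~ pi(z) are the observed examples."""
    log_pz = [math.log(p) for p in softmax(prior_logits)]

    def log_p(x, z):                      # log p_theta(x, z)
        return log_pz[z] + math.log(softmax(theta[z])[x])

    def log_q(z, x):                      # log q_phi(z|x)
        return math.log(softmax(phi[x])[z])

    def draw(probs):
        return rng.choices(range(len(probs)), weights=probs)[0]

    l_p = (mean([log_p(x, z) for x, z in pairs])                    # complete pairs
           + mean([log_pz[z] for z in zs])                          # z-only examples
           + mean([log_p(x, draw(softmax(phi[x]))) for x in xs]))   # x completed by q
    l_q = (mean([log_q(z, x) for x, z in pairs])
           + mean([log_q(z, draw(softmax(theta[z]))) for z in zs]))  # z completed by p
    return l_p, l_q


rng = random.Random(0)
prior_logits = [0.2, -0.1]
theta = [[0.5, 0.0, -0.5], [-0.3, 0.2, 0.4]]     # |Z| = 2, |X| = 3
phi = [[0.1, -0.1], [0.0, 0.3], [-0.4, 0.2]]
pairs, xs, zs = [(0, 1), (2, 0)], [1, 2, 0], [0, 1, 1]
l_p, l_q = mc_utilities(prior_logits, theta, phi, pairs, xs, zs, rng)
```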

Although we follow the same symmetric approach as before, the utilities (12) are not entirely symmetric due to the model asymmetry: $p_{\theta}(x,z)$ has its own parametrised prior $p_{\theta}(z)$, whereas $q_{\varphi}(z\,|\,x)$ lacks a prior model for $x$.

Unsupervised Learning

By unsupervised learning we mean the case where only $x\sim\pi(x)$ is observed. The choice and interpretation of the space $\mathcal{Z}$ and of the respective distribution are then completely free. We are interested in learning a decoder model $p_{\theta}(x,z)=p_{\theta}(x\,|\,z)\,p_{\theta}(z)$ and an encoder $q_{\varphi}(z\,|\,x)$ approximating $p_{\theta}(z\,|\,x)$.

The utility function of the decoder is its log-likelihood on examples $x\sim\pi(x)$, with the latent part completed by the encoder. To form a likelihood for the encoder, we use examples $(x,z)$ generated by the decoder model. The resulting utility functions are

$$\begin{aligned}
L_{p}(\theta,\varphi) &= \mathbb{E}_{\pi(x)}\mathbb{E}_{q_{\varphi}(z\,|\,x)}\bigl[\log p_{\theta}(x,z)\bigr],\\
L_{q}(\theta,\varphi) &= \mathbb{E}_{p_{\theta}(x,z)}\bigl[\log q_{\varphi}(z\,|\,x)\bigr].
\end{aligned} \tag{13}$$

In contrast to the ELBO approach, the required stochastic gradients of the log-likelihoods are easy to compute, as discussed in Section 3. Note that the algorithm also applies when $p(z)$ is fixed and implicit, i.e. accessible by sampling only.
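To make the two utilities in (13) concrete, here is a minimal sketch for a toy model in which $x$ and $z$ take only a few values, so that all expectations can be computed exactly by enumeration instead of by sampling. The following are assumptions for illustration, not the paper's setup: tabular softmax parametrisations of $p(z)$, $p(x\,|\,z)$ and $q(z\,|\,x)$, and plain simultaneous gradient ascent of each player's own utility as a stand-in for the learning algorithm (5). The names `train_symmetric`, `softmax`, `a`, `b`, `c` are hypothetical.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax of a list of logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def train_symmetric(pi, n_z=2, lr=1.0, steps=5000, seed=0):
    """Simultaneous gradient ascent on the utilities (13) for a tabular toy VAE.

    pi  : data distribution over x in {0, ..., len(pi) - 1} (sums to 1).
    a   : decoder prior logits, p(z) = softmax(a).
    b[z]: decoder likelihood logits, p(x | z) = softmax(b[z]).
    c[x]: encoder logits, q(z | x) = softmax(c[x]).
    """
    rng = random.Random(seed)
    n_x = len(pi)
    a = [rng.gauss(0.0, 1.0) for _ in range(n_z)]
    b = [[rng.gauss(0.0, 1.0) for _ in range(n_x)] for _ in range(n_z)]
    c = [[rng.gauss(0.0, 1.0) for _ in range(n_z)] for _ in range(n_x)]
    for _ in range(steps):
        p_z = softmax(a)
        p_x_given_z = [softmax(row) for row in b]
        q = [softmax(row) for row in c]
        # Decoder utility L_p = sum_x pi(x) sum_z q(z|x) log p(x, z):
        # data points x are completed with latents z from the encoder.
        w = [[pi[x] * q[x][z] for x in range(n_x)] for z in range(n_z)]
        w_z = [sum(row) for row in w]
        grad_a = [w_z[z] - p_z[z] for z in range(n_z)]
        grad_b = [[w[z][x] - w_z[z] * p_x_given_z[z][x] for x in range(n_x)]
                  for z in range(n_z)]
        # Encoder utility L_q = sum_{x,z} p(x, z) log q(z|x):
        # the encoder is fitted to (x, z) pairs generated by the decoder.
        p_xz = [[p_z[z] * p_x_given_z[z][x] for z in range(n_z)] for x in range(n_x)]
        grad_c = [[p_xz[x][z] - sum(p_xz[x]) * q[x][z] for z in range(n_z)]
                  for x in range(n_x)]
        # Each player ascends its own utility, using the current strategies.
        a = [a[z] + lr * grad_a[z] for z in range(n_z)]
        b = [[b[z][x] + lr * grad_b[z][x] for x in range(n_x)] for z in range(n_z)]
        c = [[c[x][z] + lr * grad_c[x][z] for z in range(n_z)] for x in range(n_x)]
    return softmax(a), [softmax(row) for row in b], [softmax(row) for row in c]


pi = [0.1, 0.2, 0.3, 0.4]
p_z, p_x_given_z, q = train_symmetric(pi)
# Decoder marginal p(x) and the decoder's exact posterior p(z | x).
p_x = [sum(p_z[z] * p_x_given_z[z][x] for z in range(len(p_z))) for x in range(len(pi))]
posterior = [[p_z[z] * p_x_given_z[z][x] / p_x[x] for z in range(len(p_z))]
             for x in range(len(pi))]
```

At a stationary point of these updates all gradients vanish, which for this tabular model gives $p(z)\,p(x\,|\,z)=\pi(x)\,q(z\,|\,x)$, hence $p(x)=\pi(x)$, together with $q(z\,|\,x)=p(z\,|\,x)$; the last lines compute `p_x` and `posterior` so that both conditions can be checked. In an actual VAE the enumerations would be replaced by samples: $x$ from the data and $z\sim q_{\varphi}(z\,|\,x)$ for the decoder step, and $(x,z)\sim p_{\theta}(x,z)$ for the encoder step.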

Hierarchical VAEs

Finally, we show that our unsupervised learning approach generalises to hierarchical/autoregressive VAEs. We assume that the hidden state $z$ consists of parts $z_0,z_1,\dots,z_m$ and that $x\sim\pi(x)$ is observed. Such models come in two variants. In the first, the factorisation order of the encoder is the reverse of that of the decoder; examples are Helmholtz machines (Hinton et al., 1995) and deep belief networks (Hinton et al., 2006). Here we consider the second variant, in which the encoder and decoder have the same factorisation order:

$$p(x,z) = p(z_0)\prod_{i=1}^{m} p(z_i\,|\,z_{<i})\;p(x\,|\,z), \tag{14a}$$
$$q(z\,|\,x) = q(z_0\,|\,x)\prod_{i=1}^{m} q(z_i\,|\,z_{<i},x). \tag{14b}$$

The encoder of such models can share parameters with the decoder; in particular, Sønderby et al. (2016) proposed to define the encoder by

$$q_{\theta,\varphi}(z_i\,|\,z_{<i},x) \propto p_{\theta}(z_i\,|\,z_{<i})\, f_i\bigl(z_i; d_i(x,\varphi)\bigr), \tag{15}$$

where $f_i$ is a factorised function of $z_i$ and $d_i(x,\varphi)$ are the hidden layer outputs of a deterministic encoder network $x\mapsto d_m\mapsto d_{m-1}\mapsto\dots\mapsto d_0$, parametrised by $\varphi$. The strategy of the first player is represented by the decoder parameters $\theta$, while the strategy of the second player is represented by the encoder parameters $\varphi$. The utility functions for unsupervised learning are as in (4). Thanks to the factorisation of the decoder and encoder, they decompose into sums over the blocks $p(z_i\,|\,z_{<i})$ and $q(z_i\,|\,z_{<i},x)$ and are tractable.
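For binary latent variables (as in the MNIST experiments below), one concrete reading of (15) is the following. Assume, for illustration only, that $p_{\theta}(z_i\,|\,z_{<i})$ is a factorised Bernoulli with logits $\alpha$ and that $f_i(z_i; d_i)=\exp(\langle\beta, z_i\rangle)$ with logits $\beta$ computed from $d_i(x,\varphi)$. The product is then proportional to $\exp(\langle\alpha+\beta, z_i\rangle)$, i.e. a factorised Bernoulli with logits $\alpha+\beta$. The sketch below (hypothetical function names) checks this closed form against explicit normalisation.

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def combine_bernoulli(prior_logits, factor_logits):
    """Encoder layer q(z_i | z_<i, x) proportional to p(z_i | z_<i) * f_i(z_i; d_i).

    Assumes p(z_ij = 1 | z_<i) = sigmoid(prior_logits[j]) and the factorised term
    f_i(z_ij) = exp(factor_logits[j] * z_ij); the normalised product is then a
    factorised Bernoulli with logits prior_logits[j] + factor_logits[j].
    """
    return [sigmoid(a + b) for a, b in zip(prior_logits, factor_logits)]


def combine_by_enumeration(prior_logits, factor_logits):
    """Same probabilities, obtained by normalising p * f explicitly over z_ij in {0, 1}."""
    probs = []
    for a, b in zip(prior_logits, factor_logits):
        p1 = sigmoid(a)
        unnorm0 = (1.0 - p1) * math.exp(b * 0)
        unnorm1 = p1 * math.exp(b * 1)
        probs.append(unnorm1 / (unnorm0 + unnorm1))
    return probs


print(combine_bernoulli([0.0, 1.0], [0.0, -1.0]))  # [0.5, 0.5]
```

In this form the decoder's conditional acts as a prior that the encoder only corrects by additive logits, which is what makes the parameter sharing cheap.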

The model can also be learned in a “partially” semi-supervised way by assuming that, besides training examples $x\sim\pi(x)$, we have access to a (usually smaller) set of training examples $(x,z_0)\sim\pi(x,z_0)$. This is relevant, for example, when $z_0$ represents some hidden state(s) like classes or segmentations, on which we want to condition the decoder $p(x,z)$. The additional training examples will add

$$\mathbb{E}_{\pi(x,z_0)}\mathbb{E}_{q(z_{>0}\,|\,z_0,x)}\bigl[\log p(x,z)\bigr], \tag{16a}$$
$$\mathbb{E}_{\pi(x,z_0)}\bigl[\log q(z_0\,|\,x)\bigr] \tag{16b}$$

to the respective utility functions.
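Written out, and assuming that both data sources enter with equal weight, the utilities for this partially supervised setting combine (13) with (16a) and (16b). Here the encoder is written as $q_{\varphi}$ for brevity; with parameter sharing as in (15) it also depends on $\theta$:

```latex
\begin{aligned}
L_p(\theta,\varphi) &= \mathbb{E}_{\pi(x)}\mathbb{E}_{q_{\varphi}(z\,|\,x)}\bigl[\log p_{\theta}(x,z)\bigr]
  + \mathbb{E}_{\pi(x,z_0)}\mathbb{E}_{q_{\varphi}(z_{>0}\,|\,z_0,x)}\bigl[\log p_{\theta}(x,z)\bigr],\\
L_q(\theta,\varphi) &= \mathbb{E}_{p_{\theta}(x,z)}\bigl[\log q_{\varphi}(z\,|\,x)\bigr]
  + \mathbb{E}_{\pi(x,z_0)}\bigl[\log q_{\varphi}(z_0\,|\,x)\bigr].
\end{aligned}
```

Both added terms are tractable because of the factorisation (14b): $\log q(z_0\,|\,x)$ is its first factor, and $q(z_{>0}\,|\,z_0,x)$ is the product of the remaining factors, so $z_{>0}$ can be sampled with $z_0$ clamped to its observed value.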

5 RELATED WORK

Wake-Sleep

The learning algorithm (5) with utility functions (4) in the unsupervised case turns out to be equivalent to the wake-sleep (WS) algorithm first proposed by Hinton et al. (1995). However, we arrived at it from a conceptually new game-theoretic formulation, allowing for new analysis and generalisation to other settings (semi-supervised, partial observation scenarios). In Appendix B we give a brief overview of the original WS and follow-up works.

Implicit Prior

An important advantage of the proposed method is that it allows the prior $\pi(z)$ to be implicit, i.e. accessible via samples only. Several works have extended VAEs to handle implicit encoders and priors. Mescheder et al. (2017) and Huszár (2017) estimate the log-density ratio $\log\frac{q_{\varphi}(z\,|\,x)}{\pi(z)}$ in the ELBO by learning a logistic regression discriminator. As in GANs, this requires an inner loop with a possibly complex discriminator. Molchanov et al. (2019) allow both the encoder and the prior to be intractable mixtures of tractable densities. At training time, a finite sample from the mixture is used to form a density estimate of $\pi(z)$ and a lower bound on the ELBO. These approaches are substantially more complex than ours and have further limitations. The prior can also be made completely implicit by assuming that the encoder–decoder model is consistent and hence defines a joint distribution and its marginals symmetrically. To this end, Liu et al. (2021) explicitly optimise consistency together with an expression that matches the likelihood under the consistency assumption.

Symmetric Learning

The asymmetry of the ELBO formulation has motivated several alternative approaches. Dumoulin et al. (2017) minimise the Jensen–Shannon divergence between the joint encoder distribution $q(x,z)=\pi(x)\,q(z\,|\,x)$ and the decoder $p(x,z)$. To estimate this divergence, a discriminator of joint samples is learned alongside, as in GANs. Pu et al. (2017) use a similar approach to minimise the symmetrised KL divergence. Lamb et al. (2017) learn an MCMC encoder–decoder sampler by using a discriminator between data-clamped and free-running chains. An important difference to our work is that the game in these approaches is played between the discriminator and the model, not between the decoder and the encoder.

Unsupervised and Semi-Supervised VAEs

Unsupervised equilibrium learning with utilities (4) can be reinterpreted to facilitate a theoretical comparison with ELBO alongside Proposition 1. Furthermore, the hierarchical model with observed $z_0$ (4) is closely related to semi-supervised learning with ELBO (Kingma et al., 2014). These connections are detailed in Appendix C.

6 EXPERIMENTS

Hierarchical VAE (MNIST)

[Figure 1 panels, images not reproduced. Columns: Random Latent Codes, Limiting Distribution. ELBO: FID = 5.17, FID = 83.30. Symmetric: FID = 1.73, FID = 3.63.]
Figure 1: Ladder VAE (MNIST): FID scores and images generated from random latent codes and from limiting distributions of models learned by maximising ELBO and by symmetric equilibrium learning (images are shown by probabilities for better visibility).

Figure 2: MNIST: tSNE embeddings for the VAE with class labels. Points are coloured by digit classes. See text for explanation.

The goal of this experiment is to compare symmetric equilibrium learning and ELBO learning on a simple dataset: MNIST images binarised by a suitably chosen threshold. We consider two hierarchical VAE model variants, each with two groups of binary valued latent variables $z_0\in\mathcal{B}^{30}$ and $z_1\in\mathcal{B}^{100}$. The decoder model is $p(x,z_0,z_1)=p(z_0)\,p(z_1\,|\,z_0)\,p(x\,|\,z_1)$, where we assume a uniform distribution $p(z_0)$. The encoder of the first model variant (similar to ladder VAEs) factorises in the same order as the decoder, i.e. $q(z_0,z_1\,|\,x)=q(z_0\,|\,x)\,q(z_1\,|\,z_0,x)$, and shares parameters with the decoder as described in Sec. 3.
The encoder in the second model variant factorises in reverse order, i.e. $q(z_0,z_1\,|\,x)=q(z_1\,|\,x)\,q(z_0\,|\,z_1)$, and shares no parameters with the decoder. The networks used in the encoders and decoders are standard deep convolutional networks of decreasing and increasing spatial resolution, respectively. More details are provided in Appendix E.¹ Training such models with ELBO requires a specialised gradient estimator for differentiating expectations in $q$ w.r.t. its parameters. We use the estimator by Gregor et al. (2014), which is superior to the straight-through estimator and comparable to complex unbiased estimators for VAEs (Gu et al., 2016). Note again that no such approximation is required for symmetric equilibrium learning.

¹ The code is available at https://github.com/dschles70/symvae-aistats2024
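Why no estimator is needed can be seen on a single factorised Bernoulli layer. In symmetric learning, the encoder step differentiates $\log q_{\varphi}(z\,|\,x)$ at pairs $(x,z)$ that were already sampled from the decoder, so no gradient has to flow through the sampling of $z$. A minimal sketch, assuming $q(z_j=1\,|\,x)=\sigma(l_j)$ with logits $l_j$ produced by some encoder network (the function names are hypothetical): the gradient with respect to each logit is simply $z_j-\sigma(l_j)$, and the chain rule then carries it into the network parameters.

```python
import math


def bernoulli_log_prob(z, logits):
    """log q(z | x) for a factorised Bernoulli with the given logits; z is a 0/1 list."""
    total = 0.0
    for zj, lj in zip(z, logits):
        # log sigmoid(l) if z = 1, and log(1 - sigmoid(l)) = log sigmoid(-l) if z = 0.
        total += -math.log1p(math.exp(-lj)) if zj == 1 else -math.log1p(math.exp(lj))
    return total


def bernoulli_log_prob_grad(z, logits):
    """d log q(z | x) / d logits = z - sigmoid(logits), evaluated at a fixed sample z."""
    return [zj - 1.0 / (1.0 + math.exp(-lj)) for zj, lj in zip(z, logits)]


# A (hypothetical) latent sample drawn from the decoder, and the encoder's logits for x.
z = [1, 0, 1]
logits = [0.3, -1.2, 2.0]
grad = bernoulli_log_prob_grad(z, logits)
```

A central finite difference of `bernoulli_log_prob` reproduces `grad`, which is a quick way to check such hand-written gradients.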

Besides validating the generative capabilities of the two resulting hierarchical VAEs, we want to analyse the consistency of their decoder–encoder pairs. We therefore generate images (i) from the decoder model $p$ and (ii) from the limiting distribution $m(x)$ (see Sec. 3 for an explanation). Fig. 1 and Table 1 indicate that the models obtained by symmetric learning achieve better consistency while also having better FID scores. This is further supported by tSNE embeddings of $z$ samples from the two models (see Appendix E).
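Why the limiting distribution probes consistency can be seen on a toy discrete model. Section 3, which defines $m(x)$, is not part of this excerpt; the sketch assumes $m(x)$ is the stationary distribution of the Markov chain that alternately samples $z\sim q(z\,|\,x)$ and $x'\sim p(x'\,|\,z)$. If $q$ is the exact posterior of the decoder, then $\sum_x p(x)\sum_z p(z\,|\,x)\,p(x'\,|\,z)=p(x')$, so $m=p$; an encoder that disagrees with the decoder shifts $m$ away from $p$. The tables and function names are hypothetical.

```python
def transition_matrix(q, p_x_given_z):
    """T[x][x2] = sum_z q(z | x) p(x2 | z): encode x, then decode back to x2."""
    n_x, n_z = len(q), len(p_x_given_z)
    return [[sum(q[x][z] * p_x_given_z[z][x2] for z in range(n_z)) for x2 in range(n_x)]
            for x in range(n_x)]


def limiting_distribution(q, p_x_given_z, iters=1000):
    """Stationary distribution m(x) of the encode-decode chain, by power iteration."""
    T = transition_matrix(q, p_x_given_z)
    n_x = len(T)
    m = [1.0 / n_x] * n_x
    for _ in range(iters):
        m = [sum(m[x] * T[x][x2] for x in range(n_x)) for x2 in range(n_x)]
    return m


p_z = [0.5, 0.5]
p_x_given_z = [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]]
p_x = [sum(p_z[z] * p_x_given_z[z][x] for z in range(2)) for x in range(3)]
# Consistent encoder: the exact posterior p(z | x) of the decoder.
q_exact = [[p_z[z] * p_x_given_z[z][x] / p_x[x] for z in range(2)] for x in range(3)]
# Inconsistent encoder: ignores x and prefers z = 0.
q_biased = [[0.9, 0.1] for _ in range(3)]

m_exact = limiting_distribution(q_exact, p_x_given_z)    # equals p_x = [0.4, 0.25, 0.35]
m_biased = limiting_distribution(q_biased, p_x_given_z)  # approx. [0.64, 0.21, 0.15]
```

With a learned model the chain is run by sampling rather than power iteration, and the gap between samples from $p$ and from $m$ (here measured by FID) reflects how well the encoder and decoder agree.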

Table 1: MNIST FID scores (lower is better)

| Model, algorithm | Random latent codes | Limiting distribution |
|---|---|---|
| LVAE, ELBO | 5.17 | 83.30 |
| LVAE, symmetric | 1.73 | 3.63 |
| RVAE, ELBO | 5.83 | 29.59 |
| RVAE, symmetric | 0.81 | 5.40 |

To further strengthen this finding, we conducted similar experiments for the Fashion-MNIST dataset. Results and details are given in Appendix F.

The next experiment aims to show that the internal representations of a hierarchical VAE can be learned to have good generative and discriminative capabilities at the same time, even without “supervised” terms in the encoder objective as in (4). For this, we extend $z_0$ by ten additional binary variables, which encode the class labels (one-hot encoding). This means that $z_0=(l,c)$ combines latent variables $l$ with class labels $c$. We learn the model by symmetric learning from labelled examples $(x,c)$, but use the following utility functions

$$\begin{aligned}
L_{p}(\theta,\varphi) &= \mathbb{E}_{\pi(x,c)}\mathbb{E}_{q_{\varphi}(l\,|\,x)}\mathbb{E}_{q_{\varphi}(z_{>0}\,|\,x,z_0)}\bigl[\log p_{\theta}(x,z)\bigr],\\
L_{q}(\theta,\varphi) &= \mathbb{E}_{p_{\theta}(x,z)}\bigl[\log q_{\varphi}(z\,|\,x)\bigr].
\end{aligned} \tag{17}$$

This means that the class information is used only when learning the decoder (note that $q_{\varphi}(c,l\,|\,x)$ factorises w.r.t. $c$ and $l$). The encoder is learned solely on examples generated from the decoder, i.e. without any discriminative terms. The learned encoder achieves 99% classification accuracy on the MNIST validation set, with almost no degradation of the FID scores for the generated images (2.9 when sampled from the decoder and 4.0 when sampled from the limiting distribution). We also analyse tSNE embeddings of samples of the latent part $l$ of $z_0=(l,c)$ and of samples of $z_1$, both from the prior distribution $p(z)$ and from the limiting distribution $m(z\,|\,c)$. Fig. 2 reveals that the latent part of $z_0$ is fully class agnostic, whereas $z_1$ is clearly clustered w.r.t. the digit classes. This can be interpreted as follows: the latent part $l$ of $z_0=(l,c)$ is “transversal” to the class labels $c$ and presumably encodes image properties like stroke width, slant etc., whereas the internal representations in $z_1$ are clustered by digit class and encode appearance properties separately for each class.

Semantic Segmentation (CelebA)

The following experiments illustrate the flexibility of the proposed approach on an application that is not accessible to ELBO learning. We consider the task of semantic segmentation with the goal of building a generative image segmentation model that can (i) generate image and segmentation pairs, (ii) segment given images, and (iii) generate images given a segmentation.

We use the CelebA-HQ dataset (Karras et al., 2018) and downscale its images and segmentations to $64\times 64$ pixels for simplicity.

Figure 3: Given images and segmentations $(x_i,s_i)$ from the training set ($x_i$ are shown in the leftmost column), latent codes $z_{2i}$ are sampled from $q_{\varphi_2}(z_2\,|\,x_i,s_i)$. Given segmentations $s_j$ shown in the top row, images $x_{i,j}$ are sampled from $p_{\theta_2}(x\,|\,s_j,z_{2i})$. Images are shown by mean values of the respective Gaussians for better visibility.

Let $x\in\mathbb{R}^{3\times 64\times 64}$ be an image and $s\in\{1,\ldots,K\}^{64\times 64}$ a segmentation (a categorical variable for each pixel). In order to model a distribution $p(x,s)$, we might try to learn a VAE with a decoder $p_{\theta}(x,s\,|\,z)$ and an encoder $q_{\varphi}(z\,|\,x,s)$, assuming e.g. a uniform prior distribution for the vector of binary latent variables $z\in\mathcal{B}^{m}$. However, this alone will not meet our goals, because we cannot access the resulting distributions $p(s\,|\,x)$ and $p(x\,|\,s)$. We propose to model $p_{\theta}(x,s\,|\,z)$ as the limiting distribution of a pair of parametrised conditional probability distributions $p_{\theta_1}(s\,|\,x,z)$ and $p_{\theta_2}(x\,|\,s,z)$ (see (10)). This means that the marginal probability distributions $p_{\theta}(x\,|\,z)$ and $p_{\theta}(s\,|\,z)$ are defined implicitly through the corresponding marginalisation constraints.

To summarise, the whole model consists of three learnable conditional probability distributions $p_{\theta_1}(s\,|\,x,z)$, $p_{\theta_2}(x\,|\,s,z)$ and $q_{\varphi}(z\,|\,x,s)$. This defines a nested game with three players, whose respective strategies are represented by $\theta_1$, $\theta_2$ and $\varphi$. Their utility functions are

$$L_{\theta_1}(\theta_1,\theta_2,\varphi) = \mathbb{E}_{\pi(x,s)}\mathbb{E}_{q_{\varphi}(z\,|\,x,s)}\Bigl[\log p_{\theta_1}(s\,|\,x,z) + \mathbb{E}_{p_{\theta_2}(x'\,|\,s,z)}\log p_{\theta_1}(s\,|\,x',z)\Bigr], \tag{18a}$$
$$L_{\theta_2}(\theta_1,\theta_2,\varphi) = \mathbb{E}_{\pi(x,s)}\mathbb{E}_{q_{\varphi}(z\,|\,x,s)}\Bigl[\log p_{\theta_2}(x\,|\,s,z) + \mathbb{E}_{p_{\theta_1}(s'\,|\,x,z)}\log p_{\theta_2}(x\,|\,s',z)\Bigr], \tag{18b}$$
$$L_{\varphi}(\theta_1,\theta_2,\varphi) = \mathbb{E}_{\pi(z)}\mathbb{E}_{p_{\theta}(x,s\,|\,z)}\bigl[\log q_{\varphi}(z\,|\,x,s)\bigr], \tag{18c}$$

where Gibbs sampling is used to obtain pairs $(x,s)\sim p_{\theta}(x,s\,|\,z)$ in the last utility function (see Appendix D for a detailed explanation).
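The sampler needed for (18c) can be illustrated on tiny discrete spaces. The sketch below fixes one latent code $z$, replaces the networks by conditional tables, and assumes a systematic scan that draws $s\sim p(s\,|\,x)$ and then $x\sim p(x\,|\,s)$; the paper's definition (10) and Appendix D are not reproduced here, and discrete $x$ stands in for the continuous images. All names are hypothetical. When both conditionals are derived from one joint table they are compatible, so the chain's limiting distribution is that joint, and the empirical frequencies recover it.

```python
import random


def conditionals_from_joint(joint):
    """p(s | x) and p(x | s) tables from a joint table joint[x][s]."""
    n_x, n_s = len(joint), len(joint[0])
    p_s_given_x = [[joint[x][s] / sum(joint[x]) for s in range(n_s)] for x in range(n_x)]
    col = [sum(joint[x][s] for x in range(n_x)) for s in range(n_s)]
    p_x_given_s = [[joint[x][s] / col[s] for x in range(n_x)] for s in range(n_s)]
    return p_s_given_x, p_x_given_s


def gibbs_pairs(p_s_given_x, p_x_given_s, n_steps, burn_in=1000, seed=0):
    """Empirical distribution of the (x, s) pairs visited by a Gibbs chain that
    alternates s ~ p(s | x) and x ~ p(x | s); the fixed z is implicit in the tables."""
    rng = random.Random(seed)
    n_x, n_s = len(p_s_given_x), len(p_x_given_s)
    x = 0
    counts = [[0] * n_s for _ in range(n_x)]
    for t in range(burn_in + n_steps):
        s = rng.choices(range(n_s), weights=p_s_given_x[x])[0]
        x = rng.choices(range(n_x), weights=p_x_given_s[s])[0]
        if t >= burn_in:
            counts[x][s] += 1
    return [[cnt / n_steps for cnt in row] for row in counts]


joint = [[0.3, 0.1], [0.2, 0.4]]  # toy p(x, s | z) for one fixed z
p_s_given_x, p_x_given_s = conditionals_from_joint(joint)
estimate = gibbs_pairs(p_s_given_x, p_x_given_s, n_steps=100000)
```

In learning, the pairs visited by such a chain (one chain per sampled $z$) would serve as training examples for the encoder $q_{\varphi}(z\,|\,x,s)$ in (18c).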

To ease training, we first pre-train the model parts for $p(s)$ and $p(x\,|\,s)$ separately. For the former, we introduce latent variables $z_1\in\mathcal{B}^{50}$, which should encode segmentation shapes, and define $p(s)=\sum_{z_1}p(z_1)\,p_{\theta_1}(s\,|\,z_1)$ with a uniform prior $p(z_1)$. The model for $p(x\,|\,s)$ is a latent variable model $p(x\,|\,s)=\sum_{z_2}p(z_2)\,p_{\theta_2}(x\,|\,s,z_2)$ with latent variables $z_2\in\mathcal{B}^{100}$, also uniformly distributed a priori, which should encode appearance properties such as segment colours, textures and characteristic shadows. Both $p_{\theta_1}(s\,|\,z_1)$ and $p_{\theta_2}(x\,|\,s,z_2)$ are equipped with corresponding encoders, $q_{\varphi_1}(z_1\,|\,s)$ and $q_{\varphi_2}(z_2\,|\,x,s)$, and are trained by symmetric learning, which is straightforward. All conditional probability distributions $p$ and $q$ are implemented as feed-forward networks of moderate complexity that output the parameters of the corresponding probability distribution.
For example, $p_{\theta_2}(x\,|\,s,z_2)=\mathcal{N}\bigl(\mu_{\theta_2}(s,z_2),\sigma_{\theta_2}(s,z_2)\bigr)$ is a diagonal normal distribution whose means $\mu$ and standard deviations $\sigma$ are provided by the corresponding network.
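As an illustration of such an output head, the following NumPy sketch evaluates $\log p_{\theta_2}(x\,|\,s,z_2)$ for a diagonal normal distribution. The arrays `mu` and `sigma` are hypothetical stand-ins for the network outputs $\mu_{\theta_2}(s,z_2)$ and $\sigma_{\theta_2}(s,z_2)$, and the function name is ours.

```python
import numpy as np

def diag_normal_log_prob(x, mu, sigma):
    """log N(x; mu, diag(sigma**2)), summed over the last (pixel) axis."""
    x, mu, sigma = (np.asarray(a, dtype=float) for a in (x, mu, sigma))
    z = (x - mu) / sigma
    per_dim = -0.5 * z**2 - np.log(sigma) - 0.5 * np.log(2.0 * np.pi)
    return per_dim.sum(axis=-1)

# Toy batch of two 3-pixel "images"; mu and sigma stand in for the
# (hypothetical) decoder outputs mu_theta2(s, z2) and sigma_theta2(s, z2).
x = np.array([[0.1, 0.5, 0.9], [0.2, 0.2, 0.2]])
mu = np.array([[0.0, 0.5, 1.0], [0.2, 0.2, 0.2]])
sigma = np.array([[0.1, 0.2, 0.1], [1.0, 1.0, 1.0]])
log_p = diag_normal_log_prob(x, mu, sigma)   # one log-likelihood per image
```

In practice such networks commonly predict $\log\sigma$ (or pass it through a softplus) to keep $\sigma$ positive; this is a typical design choice, not something the text specifies.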

Results for the learned $p_{\theta_2}(x\,|\,s)$ are illustrated in Fig. 3 as follows. We consider pairs of training examples, each consisting of an image and its segmentation. The first example is encoded by $q_{\varphi_2}(z_2\,|\,x,s)$, and the sampled latent code $z_2$ is used to decode the segmentation of the second example into an image by $p_{\theta_2}(x\,|\,s,z_2)$.

After pre-training, we extend the model part $p_{\theta_1}(s\,|\,z_1)$ learned in the previous step to represent $p_{\theta_1}(s\,|\,x,z)$ by adding an “additional branch”, i.e. we define

$$p(s\,|\,x,z)\propto\exp\bigl\langle f_1(z_1)+f_2(x,z_2),\,s_{oh}\bigr\rangle, \tag{19}$$

where $f_1$ is the pre-trained network, $s_{oh}$ denotes the segmentation in one-hot encoding, and $f_2$ is the additional network, which makes $s$ dependent on $x$ and $z_2$ as well. Its initial weights are chosen so that it outputs zeros at the beginning.
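Because the inner product with the one-hot encoding sums over pixels, (19) normalises to an independent softmax per pixel over the summed logits $f_1(z_1)+f_2(x,z_2)$. The NumPy sketch below illustrates this together with the zero initialisation of $f_2$; the logit arrays are random stand-ins for the network outputs and all function names are hypothetical. A brute-force sum over all label maps of a tiny image confirms the factorisation.

```python
import itertools
import numpy as np

def log_softmax(a, axis=-1):
    """Numerically stable log-softmax."""
    a = a - a.max(axis=axis, keepdims=True)
    return a - np.log(np.exp(a).sum(axis=axis, keepdims=True))

def segmentation_log_probs(f1_logits, f2_logits):
    """Per-pixel log p(s_i = k | x, z) under (19); arrays of shape (pixels, labels)."""
    return log_softmax(f1_logits + f2_logits, axis=-1)

def joint_log_prob_bruteforce(logits, s):
    """log p(s) = <logits, s_oh> - log Z, with Z summed over ALL label maps."""
    n_pixels, n_labels = logits.shape
    def score(labels):
        return sum(logits[i, k] for i, k in enumerate(labels))
    scores = np.array([score(c) for c in
                       itertools.product(range(n_labels), repeat=n_pixels)])
    m = scores.max()
    return score(s) - (m + np.log(np.exp(scores - m).sum()))

rng = np.random.default_rng(0)
n_pixels, n_labels = 4, 3
f1 = rng.normal(size=(n_pixels, n_labels))    # stand-in for f1(z1)
f2 = np.zeros((n_pixels, n_labels))           # zero-initialised branch f2(x, z2)
log_p = segmentation_log_probs(f1, f2)        # equals log_softmax(f1) at the start
```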

Finally, the model (6) is initialised with the pre-trained components and trained towards a Nash equilibrium of the three-player game as explained above. Fig. 5 shows a few results. The model achieves 95.2% segmentation accuracy on the training set and 90.7% on the validation set.

Figure 4: First two rows: training data $(x,s)$. Third and fourth rows: reconstructed images and segmentations, sampled from $p_{\theta}(x\,|\,s,z_2)$ and from $p_{\theta}(s\,|\,x,z)$, respectively, with $z\sim q_{\varphi}(z\,|\,x,s)$. Last two rows: image–segmentation pairs sampled from the full limiting distribution.
Figure 5: Segmentation from incomplete data. First row: original images from the validation set with hidden parts depicted as black squares. Second row: predicted segmentations. Third row: ground-truth segmentations. Fourth row: “in-painting”, i.e. the average over all images obtained during generation (with the visible part clamped).

An important property of the obtained model is its ability to complete missing information for any subset of its variables. Given a partial observation, e.g. an image part, a segmentation part, or a combination of such parts, we can complete the missing data by sampling from the corresponding limiting distribution. We illustrate this property on the example of inference from incomplete images $x$. Let $x=(x_o,x_h)$ consist of two parts: an observed part $x_o$ and a hidden part $x_h$. In order to segment such an image by the maximum marginal decision strategy, we need to compute the marginal probabilities $p(s_i\,|\,x_o)$ for each pixel $i$. They can be estimated by Gibbs sampling, which alternately draws all unobserved random variables, including $x_h$. We accumulate segmentation label frequencies for each pixel during sampling and finally decide for the label with the highest frequency. A few results are presented in Fig. 5. Compared to segmentation from complete images, the segmentation accuracy drops from 95.2% to 92.8% on the training set and from 90.7% to 88.8% on the validation set. We consider this drop minor, because the segmentations inferred for the hidden image parts need not coincide with the ground truth; they only need to be “plausible”, which is seen in the figure.
Although not the primary goal of this experiment, Gibbs sampling simultaneously allows us to reconstruct the image content in the hidden parts (i.e. in-painting). For this, we employ a mean-marginal decision, i.e. we average all image values sampled during Gibbs sampling. Although the results are not always perfect (see the last row in Fig. 5), they suffice to infer reasonable segmentations.
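The completion procedure described above can be illustrated on a deliberately simple toy model. The sketch below is not the paper's model: it assumes hypothetical per-pixel conditionals (uniform label prior, $x_i\,|\,s_i\sim\mathcal{N}(\mu_{s_i},1)$) and omits the latent $z$, so that both conditionals are exact. It clamps the observed pixels, alternately samples the labels and the hidden pixels, and returns the max-marginal labels and the mean-marginal in-painting.

```python
import numpy as np

def gibbs_complete(x, observed, mu=(-3.0, 3.0), n_iter=500, burn_in=50, seed=0):
    """Complete a partially observed 1-D toy 'image' by Gibbs sampling.

    Hypothetical toy model: labels s_i are uniform a priori and
    x_i | s_i ~ N(mu[s_i], 1), independently per pixel. Observed pixels stay
    clamped; labels and hidden pixels are resampled in turn.
    """
    rng = np.random.default_rng(seed)
    mu = np.asarray(mu, dtype=float)
    x = np.array(x, dtype=float)                # copy, caller's array is untouched
    hidden = ~np.asarray(observed, dtype=bool)
    x[hidden] = rng.normal(size=hidden.sum())   # arbitrary initialisation of x_h
    counts = np.zeros((x.size, mu.size))
    x_sum = np.zeros_like(x)
    n_kept = 0
    for t in range(n_iter):
        # s ~ p(s | x): per-pixel posterior proportional to exp(-(x_i - mu_k)^2 / 2)
        logits = -0.5 * (x[:, None] - mu[None, :]) ** 2
        probs = np.exp(logits - logits.max(axis=1, keepdims=True))
        probs /= probs.sum(axis=1, keepdims=True)
        u = rng.random(x.size)[:, None]
        s = np.minimum((u > probs.cumsum(axis=1)).sum(axis=1), mu.size - 1)
        # x_h ~ p(x_h | s): only hidden pixels are redrawn
        x[hidden] = rng.normal(mu[s[hidden]], 1.0)
        if t >= burn_in:
            counts[np.arange(x.size), s] += 1
            x_sum += x
            n_kept += 1
    labels = counts.argmax(axis=1)              # max-marginal decision per pixel
    inpainted = x_sum / n_kept                  # mean-marginal (in-painting)
    return labels, inpainted

x_obs = np.array([-3.0, -2.8, 3.1, 0.0, 0.0, 2.9])
observed = np.array([True, True, True, False, False, True])
labels, inpainted = gibbs_complete(x_obs, observed)
```

In the real model the two sampling steps would be replaced by draws from the learned conditionals, including the latent variables $z$.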

7 CONCLUSION

We propose an alternative learning approach for variational autoencoders. For this, we view VAEs as decoder–encoder pairs and derive a symmetric learning formulation inspired by game theory, which leads to a simple learning algorithm for finding a Nash equilibrium. We prove the uniqueness of this equilibrium under fairly general assumptions. The proposed method can be applied to various learning scenarios and to models with complex, possibly structured latent spaces. This includes implicit distributions in the latent space as well as discrete latent variables. We show experimentally that the models learned by this method are comparable to those obtained by ELBO learning and demonstrate its applicability to tasks that are not accessible to standard VAE learning.

Acknowledgements

We would like to thank our colleagues Tomas Werner and Denis Barucic for their continued interest in this work and for their valuable comments and discussions, which helped to improve the manuscript. We also thank the reviewers for their critical remarks, which encouraged us to present more experiments and to resolve remaining ambiguities. B.F. gratefully acknowledges support by the Czech OP VVV project “Research Center for Informatics” (CZ.02.1.01/0.0/0.0/16_019/0000765). D.S. was supported by the German Federal Ministry of Education and Research (BMBF) project 01/S18026A-F and by the German Federal Ministry for Economic Affairs and Climate Action (BMWK) project 01MN23021A. A.S. was supported by the Czech Science Foundation grant GA24-12697S. The authors would like to thank the Center for Information Services and HPC (ZIH) at TU Dresden for providing computing resources.

References

  • Arnold and Strauss (1991) B. C. Arnold and D. J. Strauss. Bivariate distributions with conditionals in prescribed exponential families. Journal of the Royal Statistical Society Series B (Methodological), 53(2), 1991.
  • Bornschein and Bengio (2015) Jörg Bornschein and Yoshua Bengio. Reweighted wake-sleep. ArXiv, abs/1406.2751, 2015.
  • Burda et al. (2016) Yuri Burda, Roger B. Grosse, and Ruslan Salakhutdinov. Importance weighted autoencoders. In ICLR, 2016.
  • Dadaneh et al. (2020) Siamak Zamani Dadaneh, Shahin Boluki, Mingzhang Yin, Mingyuan Zhou, and Xiaoning Qian. Pairwise supervised hashing with Bernoulli variational auto-encoder and self-control gradient estimator. In UAI, volume 124, 2020.
  • Dumoulin et al. (2017) Vincent Dumoulin, Ishmael Belghazi, Ben Poole, Alex Lamb, Martin Arjovsky, Olivier Mastropietro, and Aaron Courville. Adversarially learned inference. In ICLR, 2017.
  • Gregor et al. (2014) Karol Gregor, Ivo Danihelka, Andriy Mnih, Charles Blundell, and Daan Wierstra. Deep autoregressive networks. In ICML, 2014.
  • Gu et al. (2016) Shixiang Gu, Sergey Levine, Ilya Sutskever, and Andriy Mnih. MuProp: Unbiased backpropagation for stochastic neural networks. In ICLR, May 2016.
  • Hinton et al. (1995) Geoffrey E. Hinton, Peter Dayan, Brendan J. Frey, and Radford M. Neal. The "wake-sleep" algorithm for unsupervised neural networks. Science, 268(5214), May 1995.
  • Hinton et al. (2006) Geoffrey E. Hinton, Simon Osindero, and Yee-Whye Teh. A fast learning algorithm for deep belief nets. Neural Comput., 18(7), Jul 2006.
  • Ho et al. (2020) Jonathan Ho, Ajay Jain, and Pieter Abbeel. Denoising diffusion probabilistic models. In NeurIPS, volume 33, 2020.
  • Huszár (2017) Ferenc Huszár. Variational inference using implicit distributions. ArXiv, abs/1702.08235, 2017.
  • Ikeda et al. (1998) Shiro Ikeda, Shun-ichi Amari, and Hiroyuki Nakahara. Convergence of the wake-sleep algorithm. In NeurIPS, volume 11, 1998.
  • Karras et al. (2018) Tero Karras, Timo Aila, Samuli Laine, and Jaakko Lehtinen. Progressive growing of GANs for improved quality, stability, and variation. In ICLR, 2018.
  • Kingma and Welling (2014) Diederik P Kingma and Max Welling. Auto-encoding variational Bayes. In ICLR, 2014.
  • Kingma et al. (2014) Diederik P. Kingma, Danilo J. Rezende, Shakir Mohamed, and Max Welling. Semi-supervised learning with deep generative models. In NeurIPS, 2014.
  • Lamb et al. (2017) Alex M Lamb, Devon Hjelm, Yaroslav Ganin, Joseph Paul Cohen, Aaron C Courville, and Yoshua Bengio. Gibbsnet: Iterative adversarial inference for deep graphical models. In NeurIPS, volume 30, 2017.
  • Le et al. (2020) Tuan Anh Le, Adam R. Kosiorek, N. Siddharth, Yee Whye Teh, and Frank Wood. Revisiting reweighted wake-sleep for models with stochastic control flow. In UAI, volume 115, 2020.
  • Liu et al. (2021) Chang Liu, Haoyue Tang, Tao Qin, **tao Wang, and Tie-Yan Liu. On the generative utility of cyclic conditionals. In NeurIPS, 2021.
  • Mescheder et al. (2017) Lars Mescheder, Sebastian Nowozin, and Andreas Geiger. Adversarial variational Bayes: Unifying variational autoencoders and generative adversarial networks. In ICML, 2017.
  • Molchanov et al. (2019) Dmitry Molchanov, Valery Kharitonov, Artem Sobolev, and Dmitry P. Vetrov. Doubly semi-implicit variational inference. In AISTATS, volume 89, 2019.
  • Pu et al. (2017) Yuchen Pu, Weiyao Wang, Ricardo Henao, Liqun Chen, Zhe Gan, Chunyuan Li, and Lawrence Carin. Adversarial symmetric variational autoencoder. In NeurIPS, volume 30, 2017.
  • Rezende et al. (2014) Danilo Jimenez Rezende, Shakir Mohamed, and Daan Wierstra. Stochastic backpropagation and approximate inference in deep generative models. In ICML, 2014.
  • Rombach et al. (2022) Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, and Björn Ommer. High-resolution image synthesis with latent diffusion models. In CVPR, 2022.
  • Rosen (1965) J. B. Rosen. Existence and uniqueness of equilibrium points for concave n-person games. Econometrica, 33(3), 1965. doi: 10.2307/1911749.
  • Shekhovtsov et al. (2022) Alexander Shekhovtsov, Dmitrij Schlesinger, and Boris Flach. VAE approximation error: ELBO and exponential families. In ICLR, 2022.
  • Sønderby et al. (2016) Casper Kaae Sønderby, Tapani Raiko, Lars Maaløe, Søren Kaae Sønderby, and Ole Winther. Ladder variational autoencoders. In NeurIPS, volume 29, 2016.
  • Vahdat and Kautz (2020) Arash Vahdat and Jan Kautz. NVAE: A deep hierarchical variational autoencoder. In NeurIPS, 2020.
  • Vértes and Sahani (2018) Eszter Vértes and Maneesh Sahani. Flexible and accurate inference and learning for deep generative models. In NeurIPS, volume 31, 2018.
  • Wenliang et al. (2020) Li Wenliang, Theodore Moskovitz, Heishiro Kanagawa, and Maneesh Sahani. Amortised learning by wake-sleep. In ICML, volume 119, 13–18 Jul 2020.

Appendix A PROOFS

In this section, we provide proofs of the formal claims regarding uniqueness and consistency enforcement. See 1

Proof.

We repeat the model assumptions (3) here for convenience:

$$p_u(x,z)=\pi(z)\exp\bigl[\langle\phi(x,z),u\rangle-A(u)\bigr] \tag{20a}$$
$$q_v(x,z)=\pi(x)\exp\bigl[\langle\psi(x,z),v\rangle-B(v)\bigr]. \tag{20b}$$

Our proof relies on the classic result of Rosen (1965), who showed that games satisfying diagonal strict concavity (DSC), a condition stronger than concavity, have unique Nash equilibria.

Since the log-partition function of an exponential family is convex in its natural parameters, the game utilities are concave in their own strategies. A sufficient condition for the stronger DSC criterion is that the symmetrised Jacobian of the mapping

$$\begin{bmatrix}u\\ v\end{bmatrix}\mapsto\begin{bmatrix}\nabla_u L_p(u,v)\\ \nabla_v L_q(u,v)\end{bmatrix} \tag{21}$$

is negative definite. The most convenient way to prove this condition is to “dualise” the game. Maximising $L_p(u,v)$ w.r.t. $u$ is equivalent to finding the exponential family model whose expected sufficient statistic $\mathbb{E}_{p_u}[\phi(x,z)]$ coincides with $\mathbb{E}_{q_v}[\phi(x,z)]$. This follows from

$$\nabla_u\,\mathbb{E}_{q_v(x,z)}\log p_u(x,z)=\mathbb{E}_{q_v(x,z)}[\phi(x,z)]-\mathbb{E}_{p_u(x,z)}[\phi(x,z)]. \tag{22}$$
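Equation (22) can be sanity-checked on a finite space. In the NumPy sketch below, the statistic $\phi$, the base measure $\pi(z)$ and the fixed joint $q_v$ are random toy values (assumptions for illustration only). `grad` implements the right-hand side of (22), and plain gradient ascent on $u$ should drive $\mathbb{E}_{p_u}[\phi]$ towards $\mathbb{E}_{q_v}[\phi]$.

```python
import numpy as np

rng = np.random.default_rng(1)
n_x, n_z, dim = 3, 2, 4
phi = rng.normal(size=(n_x, n_z, dim))        # toy sufficient statistic phi(x, z)
log_pi_z = np.log(np.array([0.3, 0.7]))       # toy base measure pi(z)
q = rng.random((n_x, n_z))
q /= q.sum()                                   # fixed joint q_v(x, z)

def log_p(u):
    """log p_u(x, z) = log pi(z) + <phi(x, z), u> - A(u) on the finite grid."""
    s = log_pi_z[None, :] + phi @ u
    m = s.max()
    return s - (m + np.log(np.exp(s - m).sum()))   # subtract A(u) via log-sum-exp

def objective(u):
    """E_{q_v}[log p_u(x, z)], the utility of player p for a fixed q_v."""
    return np.sum(q * log_p(u))

def grad(u):
    """Right-hand side of (22): E_q[phi] - E_{p_u}[phi]."""
    p = np.exp(log_p(u))
    return np.tensordot(q - p, phi, axes=([0, 1], [0, 1]))

u = np.zeros(dim)
for _ in range(2000):
    u += 0.05 * grad(u)     # small step: the objective is concave and smooth
```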

The corresponding dual task reads

$$F_p(p)=\sum_{x,z}p(x,z)\bigl[\log p(x,z)-\log\pi(z)\bigr]\rightarrow\min_p \tag{23a}$$
$$\text{s.t.}\quad\begin{cases}\mathbb{E}_p[\phi(x,z)]=\mathbb{E}_q[\phi(x,z)]\\ \sum_{x,z}p(x,z)=1.\end{cases} \tag{23b}$$

This can be seen by noticing that (23) is a convex optimisation task with linear constraints. Hence, we can apply Fenchel duality:

$$\inf_p\bigl\{F_p(p)\bigm| Ap=b\bigr\}=\sup_\gamma\bigl\{\langle b,\gamma\rangle-F^*_p(A^T\gamma)\bigr\}, \tag{24}$$

where $F^*_p$ denotes the Fenchel conjugate of $F_p$. In our case, $b=(\mathbb{E}_q[\phi(x,z)],1)$ and the corresponding dual variables are $\gamma=(u,\lambda)$. The Fenchel conjugate of the function $f(p)=p\log p-p\log\pi$ is $f^*(w)=\pi e^{w-1}$; since $F_p$ is a sum of such terms over $(x,z)$, its conjugate is $F^*_p(w)=\sum_{x,z}\pi(z)\,e^{w(x,z)-1}$. Substituting all terms into the right-hand side of (24) and solving for $\lambda$, we obtain the task $\mathbb{E}_{q_v(x,z)}\log p_u(x,z)\rightarrow\max_u$.
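The conjugate $f^*(w)=\pi e^{w-1}$ follows from setting the derivative of $pw-p\log p+p\log\pi$ to zero, which gives the maximiser $p^*=\pi e^{w-1}$ and the maximal value $p^*$. A quick numerical check by grid search over $p>0$, with toy values of $w$ and $\pi$ chosen for illustration, is sketched below.

```python
import numpy as np

def f(p, pi):
    """Per-entry term of F_p in (23a): f(p) = p log p - p log pi."""
    return p * np.log(p) - p * np.log(pi)

def conjugate_numeric(w, pi, grid=np.linspace(1e-6, 5.0, 200_001)):
    """Approximate f*(w) = sup_{p>0} [p w - f(p)] by a dense grid search."""
    return np.max(grid * w - f(grid, pi))

def conjugate_closed_form(w, pi):
    """Closed form f*(w) = pi * exp(w - 1)."""
    return pi * np.exp(w - 1.0)
```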

Applying the same dualisation to $L_q(u,v)$, we obtain the following “dual” game. The strategy of the first player is represented by $p(x,z)$ and the strategy of the second player by $q(x,z)$. The utility functions $-F_p(p)$ and $-F_q(q)$ of the players depend only on their respective strategies. The game has additional linear constraints, for which we assume the existence of an interior feasible point $(p,q)$. The assertion of the theorem follows from Theorems 3, 4 and 9 in Rosen (1965) if we prove that the symmetrised Jacobian of the mapping

$$\begin{bmatrix}p\\ q\end{bmatrix}\mapsto\begin{bmatrix}\nabla_p F_p(p)\\ \nabla_q F_q(q)\end{bmatrix} \tag{25}$$

is positive definite. This is trivial, since the Jacobian is diagonal with elements $1/p(x,z)$ in the first half of the diagonal and $1/q(x,z)$ in the second half. ∎
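The diagonal form can be verified directly. Assuming, by analogy with (23a), that the second player's dual objective is $F_q(q)=\sum_{x,z}q(x,z)\bigl[\log q(x,z)-\log\pi(x)\bigr]$ (our reading of the dualisation applied to (20b)), the first and second derivatives are

$$\frac{\partial F_p}{\partial p(x,z)}=\log p(x,z)+1-\log\pi(z),\qquad \frac{\partial^2 F_p}{\partial p(x,z)\,\partial p(x',z')}=\frac{\delta_{xx'}\,\delta_{zz'}}{p(x,z)},$$

and analogously for $F_q$ with $1/q(x,z)$ on the diagonal. Mixed derivatives between $p$ and $q$ vanish because each $F$ depends only on its own strategy, so for any nonzero direction $(\Delta p,\Delta q)$ the quadratic form $\sum_{x,z}\Delta p(x,z)^2/p(x,z)+\sum_{x,z}\Delta q(x,z)^2/q(x,z)$ is strictly positive at interior points.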

See 1

Proof.

For completeness, we show that $L^{\prime}_p$ is an alternative form of the ELBO. This is verified as follows:

$$\log p_{\theta}(x)-D_{\rm KL}\bigl(q_{\varphi}(z\,|\,x)\,\|\,p_{\theta}(z\,|\,x)\bigr) \tag{26a}$$
$$=\log p_{\theta}(x)-\mathbb{E}_{q_{\varphi}(z\,|\,x)}\Big[\log\frac{q_{\varphi}(z\,|\,x)}{p_{\theta}(z\,|\,x)}\Big] \tag{26b}$$
$$=\mathbb{E}_{q_{\varphi}(z\,|\,x)}\Big[\log p_{\theta}(x)-\log\frac{q_{\varphi}(z\,|\,x)}{p_{\theta}(z\,|\,x)}\Big] \tag{26c}$$
$$=\mathbb{E}_{q_{\varphi}(z\,|\,x)}\Big[\log\frac{p_{\theta}(x)\,p_{\theta}(z\,|\,x)}{q_{\varphi}(z\,|\,x)}\Big] \tag{26d}$$
$$=\mathbb{E}_{q_{\varphi}(z\,|\,x)}\Big[\log\frac{p_{\theta}(x\,|\,z)\,\pi(z)}{q_{\varphi}(z\,|\,x)}\Big] \tag{26e}$$
$$=\mathbb{E}_{q_{\varphi}(z\,|\,x)}\big[\log p_{\theta}(x\,|\,z)\big]-D_{\rm KL}\bigl(q_{\varphi}(z\,|\,x)\,\|\,\pi(z)\bigr).$$

Therefore,

$$L_p^{\prime}=L_p-\mathbb{E}_{\pi(x)}\bigl[D_{\rm KL}\bigl(q_{\varphi}(z\,|\,x)\,\|\,\pi(z)\bigr)\bigr]. \tag{27}$$

Since the difference term does not depend on $\theta$, for a fixed $\varphi$ the utilities $L_p^{\prime}$ and $L_p$ share all local and global optima in $\theta$. It is straightforward to see that $(\theta_*,\varphi_*)$ is an equilibrium of the game with utilities $L_p^{\prime}$ and $L_q^{\prime}$ iff it is an equilibrium of the game with utilities (3). ∎
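The identity (26) can be checked numerically for a discrete latent space. In the sketch below, the prior $\pi(z)$, the likelihood values $p_\theta(x\,|\,z)$ for one fixed $x$, and the encoder $q_\varphi(z\,|\,x)$ are random toy tables (hypothetical values); the script computes the first and the last line of (26).

```python
import numpy as np

rng = np.random.default_rng(2)
n_z = 5
prior = rng.random(n_z)
prior /= prior.sum()                          # pi(z)
lik = rng.random(n_z)                         # p_theta(x | z) for one fixed x
q = rng.random(n_z)
q /= q.sum()                                  # q_phi(z | x)

def kl(a, b):
    """KL(a || b) for strictly positive probability vectors."""
    return np.sum(a * (np.log(a) - np.log(b)))

log_px = np.log(np.sum(prior * lik))          # log p_theta(x)
post = prior * lik / np.exp(log_px)           # p_theta(z | x) by Bayes' rule
lhs = log_px - kl(q, post)                    # line (26a)
rhs = np.sum(q * np.log(lik)) - kl(q, prior)  # last line of (26)
```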

Appendix B WAKE SLEEP

In this section we give a brief overview of the original wake-sleep (WS) algorithm and follow-up works.

Hinton et al. (1995) considered a multilayer network of stochastic neurons. The “recognition” (encoder) connections are used to convert the input vector into a representation in one or more layers of hidden units. The “generative” (decoder) connections are then used to reconstruct an approximation to the input vector from its underlying representation. In the wake phase of WS, given an observed sample $x$ from the training dataset, a sample of hidden states $z$ is obtained from the encoder network, and the decoder is learned on the joint sample $(x,z)$. In the sleep phase, a joint sample is drawn from the decoder model and the encoder is learned on it.

The model initially assumed binary units and a factorised encoder and decoder. For a hierarchical encoder–decoder model, learning decouples over layers and no back-propagation is needed. Extended to a deep exponential family model (Vértes and Sahani, 2018), it is equivalent to a hierarchical VAE with the reverse encoder structure.
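To make the two phases concrete, here is a minimal NumPy sketch of wake-sleep for a single layer of binary stochastic units with a factorised encoder and decoder. The parametrisation (logistic Bernoulli layers, plain stochastic gradient steps, learning rate and epoch count) is a hypothetical choice for illustration, not the exact setup of Hinton et al. (1995). Each update is a local delta rule, so no back-propagation through the sampling is required.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def wake_sleep(data, n_latent=4, lr=0.05, n_epochs=200, seed=0):
    """Minimal wake-sleep for one layer of binary stochastic units (a sketch).

    Hypothetical toy parametrisation:
      decoder  p(z) = Bern(sigmoid(b)),  p(x | z) = Bern(sigmoid(W z + c))
      encoder  q(z | x) = Bern(sigmoid(V x + d))
    """
    rng = np.random.default_rng(seed)
    n, dim = data.shape
    W = 0.01 * rng.normal(size=(dim, n_latent))
    c, b = np.zeros(dim), np.zeros(n_latent)
    V = 0.01 * rng.normal(size=(n_latent, dim))
    d = np.zeros(n_latent)
    for _ in range(n_epochs):
        for x in data[rng.permutation(n)]:
            # Wake phase: z ~ q(z | x), then ascend log p(x, z) w.r.t. decoder params.
            z = (rng.random(n_latent) < sigmoid(V @ x + d)).astype(float)
            err_x = x - sigmoid(W @ z + c)
            W += lr * np.outer(err_x, z)
            c += lr * err_x
            b += lr * (z - sigmoid(b))
            # Sleep phase: (x', z') ~ p(z) p(x | z), then ascend log q(z' | x').
            z_s = (rng.random(n_latent) < sigmoid(b)).astype(float)
            x_s = (rng.random(dim) < sigmoid(W @ z_s + c)).astype(float)
            err_z = z_s - sigmoid(V @ x_s + d)
            V += lr * np.outer(err_z, x_s)
            d += lr * err_z
    return W, c, b, V, d

# Toy data: two binary prototypes, repeated.
data = np.array([[1, 1, 1, 0, 0, 0], [0, 0, 0, 1, 1, 1]] * 10, dtype=float)
W, c, b, V, d = wake_sleep(data, n_epochs=50)
```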

Bornschein and Bengio (2015) use importance sampling, similar to IWAE (Burda et al., 2016), to tighten the bound for the decoder, and introduce a wake-phase (importance-weighted) update of the encoder, which tightens the ELBO as well (as in VAE).

Vértes and Sahani (2018) and Wenliang et al. (2020) showed that the encoder in WS can be specified implicitly by its mean parameters, which allows for encoders that are not conditionally independent. This makes the encoders more flexible, so that higher-quality decoders can be trained, but it impairs inference.

The advantage of not requiring differentiation through discrete sampling has been explored by Le et al. (2020) for models with stochastic control flow.

To our knowledge, prior work has neither extended WS to the semi-supervised setting nor discussed why it is a reasonable algorithm. The only attempt at an analysis, by Ikeda et al. (1998), is limited to a strictly consistent encoder–decoder pair in a simple special case.

Appendix C (DIS-)SIMILARITIES TO ELBO

In this section we elaborate on the similarities and differences between symmetric learning and ELBO learning, in the unsupervised as well as in the semi-supervised case (Kingma et al., 2014).

Unsupervised

Recall that in the unsupervised case we consider the utility functions

$$\begin{aligned} L_p(\theta,\varphi) &= \mathbb{E}_{\pi(x)}\,\mathbb{E}_{q_\varphi(z\,|\,x)}\big[\log p_\theta(x,z)\big],\\ L_q(\theta,\varphi) &= \mathbb{E}_{p_\theta(x,z)}\big[\log q_\varphi(z\,|\,x)\big]. \end{aligned} \tag{28}$$
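For intuition, the utilities (28) can be evaluated exactly when $x$ and $z$ are small discrete variables and both players are unrestricted tables. The following sketch (our own toy setting, with the decoder stored as a full joint table $p(x,z)=p(z)\,p(x\,|\,z)$) computes both utilities and each player's closed-form best response. By Gibbs' inequality, the decoder's best response to $q$ is $p(x,z)=\pi(x)\,q(z\,|\,x)$ and the encoder's best response to $p$ is the posterior $p(z\,|\,x)$.

```python
import numpy as np


def utilities(pi_x, p_xz, q_z_given_x):
    """Exact utilities (28) for tabular models.

    p_xz:        decoder joint p(x, z) = p(z) p(x|z), shape (n_x, n_z)
    q_z_given_x: encoder q(z|x), shape (n_x, n_z)
    """
    l_p = np.sum(pi_x[:, None] * q_z_given_x * np.log(p_xz))
    l_q = np.sum(p_xz * np.log(q_z_given_x))
    return l_p, l_q


def decoder_best_response(pi_x, q_z_given_x):
    # argmax over all joints of E_{pi(x) q(z|x)} log p(x, z) is p(x, z) = pi(x) q(z|x)
    return pi_x[:, None] * q_z_given_x


def encoder_best_response(p_xz):
    # argmax over q of E_{p(x, z)} log q(z|x) is the decoder posterior p(z|x)
    return p_xz / p_xz.sum(axis=1, keepdims=True)
```

Starting from an arbitrary table `q`, one decoder best response already gives a consistent pair with $p(x)=\pi(x)$, and `encoder_best_response` then returns `q` unchanged, so neither player can improve unilaterally. With restricted network families these best responses are no longer available in closed form.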

As discussed in Proposition 1, the decoder utility can be equivalently replaced with the common ELBO $L_B(\theta,\varphi)$ (both have the same dependence on $\theta$). The difference from the VAE of Kingma and Welling (2014) therefore lies only in the encoder learning. In VAE, the encoder is learned to tighten the ELBO, i.e. to minimise the so-called reverse KL divergence in expectation over the data distribution:

$$\mathbb{E}_{\pi(x)}\Big[D_{\rm KL}\big(q_\varphi(z\,|\,x)\,\|\,p_\theta(z\,|\,x)\big)\Big]. \tag{29}$$

In equilibrium learning, maximising $L_q$ in (28) w.r.t. the encoder is equivalent to minimising

$$\mathbb{E}_{p_\theta(x)}\Big[D_{\rm KL}\big(p_\theta(z\,|\,x)\,\|\,q_\varphi(z\,|\,x)\big)\Big], \tag{30}$$

which is the forward KL divergence between the same conditional distributions, with the expectation taken over the generative model $p_\theta(x)=\sum_z p_\theta(z)\,p_\theta(x\,|\,z)$. Choosing the encoder as the true posterior, $q(z\,|\,x)=p(z\,|\,x)$, whenever possible (i.e. for consistent models), is optimal for both ELBO and symmetric learning. In general, however, when the encoder family cannot represent the posterior, $L_q$ prefers different solutions than the ELBO.
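A toy illustration of this difference, for a single fixed $x$ and ignoring that (29) and (30) average over different distributions of $x$: we assume a hypothetical bimodal posterior $p(z\,|\,x)$ on a discretised one-dimensional latent space and restrict the encoder to single Gaussians (our assumption, not the paper's model). A grid search then finds the encoder preferred by the reverse KL (29) and by the forward KL (30).

```python
import numpy as np

z = np.linspace(-6.0, 6.0, 1201)  # discretised one-dimensional latent space


def log_normal(z, mu, sigma):
    # unnormalised Gaussian log-density; normalisation happens on the grid below
    return -0.5 * ((z - mu) / sigma) ** 2 - np.log(sigma)


def normalise(log_w):
    return log_w - np.logaddexp.reduce(log_w)  # log of a distribution on the grid


def kl(log_a, log_b):
    return np.sum(np.exp(log_a) * (log_a - log_b))


# Hypothetical decoder posterior p(z|x): equal-weight mixture of two narrow modes.
log_p = normalise(np.logaddexp(log_normal(z, -2.0, 0.5), log_normal(z, 2.0, 0.5)))

mus = np.linspace(-3.0, 3.0, 61)
sigmas = np.linspace(0.2, 3.0, 57)
best = {}
for name, objective in [
    ("reverse (29)", lambda log_q: kl(log_q, log_p)),
    ("forward (30)", lambda log_q: kl(log_p, log_q)),
]:
    scores = [(objective(normalise(log_normal(z, m, s))), m, s) for m in mus for s in sigmas]
    best[name] = min(scores)[1:]  # (mu, sigma) of the best Gaussian encoder
```

The reverse KL selects a narrow Gaussian on one of the two modes (mode seeking), whereas the forward KL selects a wide Gaussian centred between them that covers both modes (moment matching).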

Semi-Supervised

Semi-supervised learning of VAEs was previously considered by Kingma et al. (2014). The hierarchical model (14a) is a generalisation of their generative model: the state $z$ consists of two parts $(z_0,z_1)$, where $z_0$ is the image label, available only for some of the images. As in the unsupervised case, when learning the decoder for a fixed encoder, the learning objective (Kingma et al., 2014, Eq. 8) is equivalent to our $L_p$.

Only the learning of the encoder differs. In their formulation, the encoder minimises

$$\mathbb{E}_{\pi(x)}\,D_{\rm KL}\big(q(z\,|\,x)\,\|\,p(z\,|\,x)\big) + \mathbb{E}_{\pi(x,z_0)}\,D_{\rm KL}\big(q(z_1\,|\,x,z_0)\,\|\,p(x,z)\big) - \alpha\,\mathbb{E}_{\pi(x,z_0)}\log q(z_0\,|\,x), \tag{31}$$

where $\alpha$ is an empirical coefficient. If there are no unlabelled examples, the first term disappears and the ELBO learning approach (Kingma et al., 2014) decouples into learning a conditional VAE (decoder and encoder conditioned on $z_0$: $p(x\,|\,z_1,z_0)$, $q(z_1\,|\,x,z_0)$) and an independent discriminative learning of the encoder part $q(z_0\,|\,x)$ from the labelled data only. Thus, the generative part of the model has no effect on learning the recognition part (unless parameters are shared).

In our formulation the encoder maximises

$$\mathbb{E}_{p(x,z)}\log q(z\,|\,x) + \mathbb{E}_{\pi(x,z_0)}\log q(z_0\,|\,x). \tag{32}$$

This objective is more homogeneous, because both terms correspond to forward KL divergences. When there are no unlabelled examples, the objective stays the same, and the encoder part $q(z_0\,|\,x)$ still has to fulfil two goals: to approximate the decoder posterior $p(z_0\,|\,x)$ (in expectation over the generated distribution $p(x)$, as in the unsupervised case), and to approximate the empirical distribution $\pi(z_0\,|\,x)$ (in expectation over $\pi(x)$). A weighting coefficient might be appropriate here as well to balance the two objectives. Our semi-supervised MNIST experiment in Section 6 with utilities (6) shows that the encoder still learns to classify efficiently even when the discriminative term is switched off.
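The trade-off between the two goals can be made concrete in a toy tabular setting. The sketch below restricts (32) to the label part $z_0$ and multiplies the second term by a hypothetical weight `alpha` (the coefficient suggested above, not a value from our experiments). For an unrestricted table $q(z_0\,|\,x)$, a weighted sum $\sum_{x,z_0} w(x,z_0)\log q(z_0\,|\,x)$ is maximised by $q(z_0\,|\,x)\propto w(x,z_0)$, so the best encoder mixes the decoder joint and the labelled data.

```python
import numpy as np


def encoder_objective(q_z0_given_x, p_x_z0, pi_x_z0, alpha=1.0):
    """Objective (32) restricted to the label part z0, with a hypothetical weight alpha.

    q_z0_given_x: encoder q(z0|x), shape (n_x, n_z0)
    p_x_z0:       decoder joint p(x, z0), shape (n_x, n_z0)
    pi_x_z0:      labelled data distribution pi(x, z0), shape (n_x, n_z0)
    """
    log_q = np.log(q_z0_given_x)
    return np.sum(p_x_z0 * log_q) + alpha * np.sum(pi_x_z0 * log_q)


def encoder_best_response(p_x_z0, pi_x_z0, alpha=1.0):
    # sum_{x, z0} w(x, z0) log q(z0|x) is maximised by q(z0|x) proportional to w(x, z0)
    w = p_x_z0 + alpha * pi_x_z0
    return w / w.sum(axis=1, keepdims=True)
```

With `alpha=0` the best encoder is the decoder posterior $p(z_0\,|\,x)$. For very large `alpha` it approaches the empirical label distribution on the labelled inputs.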

Appendix D LEARNING MODELS WITH IMPLICIT MARGINALS

Here we give a more detailed derivation of learning in situations where a joint model is given only by its conditional distributions, i.e. the marginal distributions are given implicitly. In particular, we used it in our CelebA experiments to define and learn $p(x,s\,|\,z)$, where $x$ are images, $s$ are segmentations, and $z$ are latent variables. Since everything is conditioned on $z$, we omit it for clarity and use $x$ and $s$ as the variables of interest, in line with our experiments.

With this convention, we want to learn two conditional probability distributions $p_\theta(x\,|\,s)$ and $q_\varphi(s\,|\,x)$. As both images and segmentations are rather complex, it is desirable to avoid any assumptions about the prior (marginal) distributions $p(s)$ and $q(x)$. To this end, we consider the MCMC process that starts from a random state and alternately samples from $p_\theta(x\,|\,s)$ and $q_\varphi(s\,|\,x)$. This process defines two limiting joint distributions, depending on which variable was sampled last:

$$m(s)\,p_\theta(x\,|\,s) \quad\text{and}\quad m(x)\,q_\varphi(s\,|\,x), \tag{33}$$

where $m(x)$ and $m(s)$ are solutions of the stationary equations

$$m(x) = \sum_s p_\theta(x\,|\,s)\,m(s), \tag{34a}$$
$$m(s) = \sum_x q_\varphi(s\,|\,x)\,m(x). \tag{34b}$$
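For small discrete $x$ and $s$, the stationary equations (34) can be solved by propagating distributions through the alternating chain instead of sampling it. The following sketch assumes conditional tables with strictly positive entries (so the chain is ergodic); table layouts and names are our own choices.

```python
import numpy as np


def limiting_marginals(p_x_given_s, q_s_given_x, n_iter=1000):
    """Solve (34a)-(34b) by iterating the alternating chain x -> s -> x on distributions.

    p_x_given_s: decoder p(x|s), shape (n_s, n_x)
    q_s_given_x: encoder q(s|x), shape (n_x, n_s)
    """
    m_x = np.full(p_x_given_s.shape[1], 1.0 / p_x_given_s.shape[1])  # arbitrary start
    for _ in range(n_iter):
        m_s = m_x @ q_s_given_x  # (34b): m(s) = sum_x q(s|x) m(x)
        m_x = m_s @ p_x_given_s  # (34a): m(x) = sum_s p(x|s) m(s)
    return m_x, m_s
```

At the fixed point both limiting joints (33), `m_s[:, None] * p_x_given_s` and `(m_x[:, None] * q_s_given_x).T`, have the same marginals $m(x)$ and $m(s)$, even when the two conditionals are not consistent with each other.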

It is natural to consider the mixture of these two limiting distributions

$$m(x,s) = \frac{1}{2}\Big[m(s)\,p_\theta(x\,|\,s) + m(x)\,q_\varphi(s\,|\,x)\Big], \tag{35}$$

as suggested in (10). Our goal is therefore to maximise the likelihood of the data $\pi(x,s)$ under this mixture joint model. By concavity of the logarithm, the log-likelihood can be lower-bounded in terms of the mixture components as

$$\mathbb{E}_{\pi(x,s)}\log m(x,s) \geq \frac{1}{2}\Big[\mathbb{E}_{\pi(s)}\log m(s) + \mathbb{E}_{\pi(x,s)}\log p_\theta(x\,|\,s) + \mathbb{E}_{\pi(x)}\log m(x) + \mathbb{E}_{\pi(x,s)}\log q_\varphi(s\,|\,x)\Big]. \tag{36}$$
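The bound (36) follows pointwise from $\log\frac{a+b}{2}\ge\frac{1}{2}(\log a+\log b)$ with $a=m(s)\,p_\theta(x\,|\,s)$ and $b=m(x)\,q_\varphi(s\,|\,x)$, and it is tight when the two components coincide. A small numerical check, assuming discrete $x$, $s$ and strictly positive tables (all names hypothetical):

```python
import numpy as np


def stationary(P, Q, n_iter=1000):
    """Marginals m(x), m(s) of (34); P = p(x|s) of shape (n_s, n_x), Q = q(s|x) of shape (n_x, n_s)."""
    m_x = np.full(P.shape[1], 1.0 / P.shape[1])
    for _ in range(n_iter):
        m_s = m_x @ Q
        m_x = m_s @ P
    return m_x, m_s


def mixture_loglik_and_bound(pi_sx, P, Q):
    """Left- and right-hand side of (36); pi_sx is the data distribution pi(x, s) stored as (n_s, n_x)."""
    m_x, m_s = stationary(P, Q)
    comp_p = m_s[:, None] * P        # m(s) p(x|s)
    comp_q = (m_x[:, None] * Q).T    # m(x) q(s|x), transposed to (n_s, n_x)
    lhs = np.sum(pi_sx * np.log(0.5 * (comp_p + comp_q)))
    # the four expectations in (36) combine into one expectation over pi(x, s)
    rhs = 0.5 * np.sum(pi_sx * (np.log(comp_p) + np.log(comp_q)))
    return lhs, rhs
```

For a consistent pair derived from a common joint table $J(s,x)$, i.e. `P = J / J.sum(axis=1, keepdims=True)` and `Q = (J / J.sum(axis=0, keepdims=True)).T`, both components equal $J$ and the two returned values coincide.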

Note that this lower bound is tight if the mixture components coincide, i.e. if $p_\theta(x\,|\,s)$ and $q_\varphi(s\,|\,x)$ are consistent. The terms in (36) corresponding to $p_\theta$ and $q_\varphi$ are tractable under assumption (2). However, $m(x)$ and $m(s)$ are not given in closed form and depend on both $\theta$ and $\varphi$. We approximate their defining equations (34) as

$$m_\theta(x) = \sum_s p_\theta(x\,|\,s)\,\pi(s), \tag{37a}$$
$$m_\varphi(s) = \sum_x q_\varphi(s\,|\,x)\,\pi(x) \tag{37b}$$

and use these expressions in the mixture model (35). With this approximation, the bound (36) becomes a sum of data log-likelihood terms of the separate model components $p_\theta(x\,|\,s)$, $m_\theta(x)$, $q_\varphi(s\,|\,x)$ and $m_\varphi(s)$. Hence, optimising this sum decouples into optimising the two objectives

$$\begin{aligned} L_p &= \mathbb{E}_{\pi(x,s)}\log p_\theta(x\,|\,s) + \mathbb{E}_{\pi(x)}\log m_\theta(x),\\ L_q &= \mathbb{E}_{\pi(x,s)}\log q_\varphi(s\,|\,x) + \mathbb{E}_{\pi(s)}\log m_\varphi(s) \end{aligned} \tag{38}$$

independently in $\theta$ and $\varphi$, respectively. It remains to explain how to handle $\log m_\theta(x)$ and $\log m_\varphi(s)$, which are still intractable. Substituting (37) and lower-bounding $\log m_\theta(x)$ w.r.t. the summation over $s$ (Jensen's inequality with $q_\varphi(s\,|\,x)$ as the averaging distribution) gives

$$\mathbb{E}_{\pi(x)}\log m_\theta(x) \geq \mathbb{E}_{\pi(x)}\,\mathbb{E}_{q_\varphi(s\,|\,x)}\Big[\log p_\theta(x\,|\,s) + \log\pi(s) - \log q_\varphi(s\,|\,x)\Big]. \tag{39}$$
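The bound (39) holds for every $x$ separately and is tight when $q_\varphi(s\,|\,x)\propto p_\theta(x\,|\,s)\,\pi(s)$. A quick numerical check for small discrete $x$ and $s$, assuming strictly positive tables (names and layouts are our own):

```python
import numpy as np


def log_m_theta(P, pi_s):
    # (37a): m_theta(x) = sum_s p(x|s) pi(s); P = p(x|s) has shape (n_s, n_x)
    return np.log(pi_s @ P)


def bound_39(P, pi_s, Q):
    # inner term of (39) for every x: E_{q(s|x)}[log p(x|s) + log pi(s) - log q(s|x)]
    # Q = q(s|x) has shape (n_x, n_s)
    return np.sum(Q * (np.log(P.T) + np.log(pi_s)[None, :] - np.log(Q)), axis=1)
```

For an arbitrary encoder table the bound lies below $\log m_\theta(x)$. Setting `Q = P.T * pi_s` normalised over $s$ makes the two sides equal.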

In the equilibrium learning approach, the objective $L_p$ is optimised only w.r.t. its own parameters $\theta$, and we can therefore drop the terms $\log\pi(s)$ and $\log q_\varphi(s\,|\,x)$, which do not depend on $\theta$. Applying similar steps to $\mathbb{E}_{\pi(s)}\log m_\varphi(s)$ leads to the following effective equilibrium learning objectives:

$$\tilde L_p(\theta,\varphi) = \mathbb{E}_{\pi(x,s)}\log p_\theta(x\,|\,s) + \mathbb{E}_{\pi(x)}\,\mathbb{E}_{q_\varphi(s\,|\,x)}\log p_\theta(x\,|\,s), \tag{40}$$
$$\tilde L_q(\theta,\varphi) = \mathbb{E}_{\pi(x,s)}\log q_\varphi(s\,|\,x) + \mathbb{E}_{\pi(s)}\,\mathbb{E}_{p_\theta(x\,|\,s)}\log q_\varphi(s\,|\,x). \tag{41}$$

Note that the first terms in these utilities correspond to the pseudo-likelihood objective, whereas the mutual completion in the second terms additionally enforces consistency.
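To see how each player reacts to the other under (40) and (41), consider the toy case of discrete $x$, $s$ and unrestricted conditional tables (our assumption; in the experiments both players are networks trained by gradient steps). Each utility is then a weighted sum of log-probabilities, so each player's best response is the normalised weight table: $p(x\,|\,s)\propto\pi(x,s)+\pi(x)\,q(s\,|\,x)$ and $q(s\,|\,x)\propto\pi(x,s)+\pi(s)\,p(x\,|\,s)$.

```python
import numpy as np


def utility_p(P, Q, pi_sx):
    """Utility (40); P = p(x|s) (n_s, n_x), Q = q(s|x) (n_x, n_s), pi_sx = pi(x, s) as (n_s, n_x)."""
    pi_x = pi_sx.sum(axis=0)
    return np.sum(pi_sx * np.log(P)) + np.sum(pi_x[:, None] * Q * np.log(P.T))


def utility_q(P, Q, pi_sx):
    """Utility (41) with the same table layouts."""
    pi_s = pi_sx.sum(axis=1)
    return np.sum(pi_sx.T * np.log(Q)) + np.sum(pi_s[:, None] * P * np.log(Q.T))


def best_response_p(Q, pi_sx):
    # weight of log p(x|s) is pi(x, s) + pi(x) q(s|x); normalise over x for each s
    w = pi_sx + (pi_sx.sum(axis=0)[:, None] * Q).T
    return w / w.sum(axis=1, keepdims=True)


def best_response_q(P, pi_sx):
    # weight of log q(s|x) is pi(x, s) + pi(s) p(x|s); normalise over s for each x
    w = pi_sx.T + (pi_sx.sum(axis=1)[:, None] * P).T
    return w / w.sum(axis=1, keepdims=True)
```

The first weight term is the pseudo-likelihood data term; the second is the completion produced by the other player. Alternating these best responses is one way to search for an equilibrium in this toy setting; we make no claim about its convergence in general.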

Appendix E ADDITIONAL DETAILS FOR MNIST EXPERIMENTS

Figure 6: MNIST network architecture.
Figure 7: MNIST training with ELBO and symmetric learning. Top: data term and KL term for ELBO learning; bottom: negative utilities for symmetric learning.
Figure 8: Fashion MNIST. Left: ELBO learning, right: symmetric learning. In each image: top row: original images; second row: means of the reconstructed images, i.e. $\mu(z_1)$ with $z_1\sim q(z_1\,|\,x)$; third row: images generated from random codes (means visualised); fourth row: sampling from the limiting distribution including image noise; last row: sampling from the limiting distribution, means visualised.
Figure 9 panels (columns: Random Latent Codes, Limiting Distribution). ELBO: FID = 32.37 (random latent codes), FID = 48.89 (limiting distribution). Symmetric: FID = 40.57 (random latent codes), FID = 37.85 (limiting distribution).
Figure 9: Fashion MNIST. FID scores and images generated from random latent codes and from limiting distributions of models learned by maximising ELBO and by symmetric equilibrium learning (images are shown by means for better visibility).

Here, we provide additional implementation details for the HVAE models used for symmetric learning and for ELBO optimisation in the first MNIST experiment. The first model variant is defined by the decoder $p_\theta(z_0,z_1,x)=p(z_0)\,p_\theta(z_1\,|\,z_0)\,p_\theta(x\,|\,z_1)$ and the encoder $q_\varphi(z_0,z_1,x)=\pi(x)\,q_\varphi(z_0\,|\,x)\,q_\varphi(z_1\,|\,z_0,x)$, where $p(z_0)$ is uniform and $\pi(x)$ is the data distribution. The network architecture is shown in Fig. 6.
The one-dimensional components are connected by multi-layer perceptrons (MLPs) with two hidden layers of 600 units each. Connections between $z_1$ and $x$ are implemented by standard convolutional encoder/decoder architectures with decreasing and increasing spatial resolution, respectively. Both encoder and decoder have 6 hidden layers connected by 2D convolutions; some of the convolutions are strided to reduce the spatial dimension. We used the $\tanh$ activation function everywhere. The network weights are learned with the Adam optimiser.

The hierarchical decoder consists of two “separate” networks, an MLP and a decoder, representing $p_\theta(z_1\,|\,z_0)$ and $p_\theta(x\,|\,z_1)$, respectively. The encoder corresponding to the direct factorisation order (shown in the figure) is a multi-head network. Its common part is an encoder that produces intermediate features, whereas the heads are an MLP for $f_0(x)$ and a single fully connected layer for $f_1(x)$. The two network outputs $f_0(x)$ and $f_1(x)$ serve as multipliers to the hierarchical decoder model, so that $q_\varphi(z_0\,|\,x)=f_0(x)$ and $q_\varphi(z_1\,|\,z_0,x)\propto p_\theta(z_1\,|\,z_0)\cdot f_1(x)$. For the reverse factorisation order, we keep the hierarchical encoder architecture basically the same but split it into two separate networks: the encoder for $q_\varphi(z_1\,|\,x)$ and the MLP for $q_\varphi(z_0\,|\,z_1)$.
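The product form $q_\varphi(z_1\,|\,z_0,x)\propto p_\theta(z_1\,|\,z_0)\cdot f_1(x)$ stays tractable under the following assumptions, which are ours and are not stated above: $z_1$ is a binary vector, $p_\theta(z_1\,|\,z_0)$ is a factorised Bernoulli distribution with logits $a(z_0)$, and the head output acts on $z_1$ as $\exp(\langle b(x), z_1\rangle)$. The product is then again a factorised Bernoulli whose logits are simply $a(z_0)+b(x)$. The sketch checks this against brute-force normalisation over all states of a small $z_1$.

```python
import itertools

import numpy as np


def bernoulli_logprob(z, logits):
    """Log-probability of binary vector z under a factorised Bernoulli with the given logits."""
    return np.sum(z * logits - np.logaddexp(0.0, logits))


def encoder_logits(prior_logits, f1_logits):
    """Hypothetical: logits of q(z1|z0,x) ∝ p(z1|z0) * exp(<f1_logits, z1>); the logits add up."""
    return prior_logits + f1_logits


def brute_force_log_q(prior_logits, f1_logits):
    """Normalise p(z1|z0) * exp(<f1_logits, z1>) by enumerating all binary z1 (small sizes only)."""
    states = np.array(list(itertools.product([0, 1], repeat=len(prior_logits))), dtype=float)
    log_w = np.array([bernoulli_logprob(z, prior_logits) + z @ f1_logits for z in states])
    return states, log_w - np.logaddexp.reduce(log_w)
```

Under these assumptions the encoder head only needs to output logits that are added to the decoder's prior logits for the sampled $z_0$; no explicit normalisation over $z_1$ is required.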

Fig. 7 shows the learning curves of the losses/utilities for ELBO learning and symmetric learning, respectively, as a function of gradient update steps. For better clarity, all values are normalised by the number of corresponding elements, e.g. we show the per-pixel data loss for ELBO. The convergence behaviour is very similar in both cases: all values very quickly approach their final values, followed by a long period in which they change much more slowly. However, we observed that the quality of generated images keeps improving even after the losses/utilities have almost saturated. Hence, we ran all our experiments with a small learning rate of $10^{-4}$ for 1M gradient update steps (note: only the first 100k steps are shown in Fig. 7 for better visibility).

We further compare the HVAE models obtained by symmetric learning and by ELBO optimisation by embedding with tSNE, for each of the two models, samples of $z_0$ and $z_1$ drawn from (i) the prior distributions $p(z_0)$, $p_\theta(z_1)$, (ii) the posterior distributions $q_\varphi(z_0)$, $q_\varphi(z_1)$, and (iii) the limiting distributions $m_{\theta,\varphi}(z_0)$ and $m_{\theta,\varphi}(z_1)$. Fig. 10 shows that all three sets of samples match well for the model learned by symmetric learning. This is not the case for the model learned by ELBO.

Figure 10: MNIST: tSNE embeddings of latent variables $z_0$ and $z_1$ for ELBO maximisation and symmetric learning.

Appendix F FASHION MNIST

We also tested our approach for the HVAE with the direct encoder factorisation order on the Fashion MNIST dataset. The model is the same as the one used in our first MNIST experiment, except for the following:

  • Images are now grey-valued. We model them by a Gaussian, where the means of all pixels are computed by a network, and the standard deviation is shared by all pixels and does not depend on $z$, i.e. $p_{\theta,\sigma}(x\,|\,z_1)=\mathcal{N}(x;\mu_\theta(z_1),\sigma)$. The network architecture for $\mu_\theta(z_1)$ is the same as the decoder in the MNIST experiment, and $\sigma$ is learned alongside the network weights.

  • We observed that the overall results are slightly better (especially for ELBO) when using ReLU activations in $p(x\,|\,z_1)$ instead of the $\tanh$ used for MNIST.
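The Gaussian pixel likelihood with one shared standard deviation can be sketched as follows. We assume that $\sigma$ is parameterised by its logarithm (our choice). In training, $\sigma$ is learned by gradient steps together with the network weights; the closed form below only shows the value it is pulled towards for fixed means: the maximum-likelihood $\sigma^2$ is the mean squared residual over all pixels and images.

```python
import numpy as np


def gaussian_loglik(x, mu, log_sigma):
    """log N(x; mu, sigma) summed over pixels, with one scalar sigma shared by all pixels."""
    sigma2 = np.exp(2.0 * log_sigma)
    return np.sum(-0.5 * (x - mu) ** 2 / sigma2 - log_sigma - 0.5 * np.log(2.0 * np.pi), axis=-1)


def optimal_log_sigma(x, mu):
    # For fixed means, the maximum-likelihood shared sigma is the root mean squared residual.
    return 0.5 * np.log(np.mean((x - mu) ** 2))
```

Here `x` and `mu` are arrays of shape (batch, pixels), e.g. flattened 28×28 images, and `mu` stands for the network output $\mu_\theta(z_1)$.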

The results are shown in Figs. 8 and 9. They confirm our finding that ELBO and symmetric learning are on par, while the latter produces more consistent encoder/decoder pairs.