Auto-Encoding Variational Bayes

[Figure: drawing.png]

This is the Variational Auto-Encoder paper by Kingma & Welling.

For a simple PyTorch implementation, check out my GitHub repo, where I auto-encoded CryptoPunks.

Problem statement

There is a random variable $x$ that is generated by:

  • sampling a latent variable $z\sim p_{\theta^*}(z)$
  • then sampling $x\sim p_{\theta^*}(x\vert z)$

where the true parameters $\theta^*$ of the distribution and the latent variable $z$ are hidden.
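To make the generative story concrete, here is a minimal Python sketch. The model is an assumption chosen for illustration, not one from the paper: $p_{\theta^*}(z)=\mathcal{N}(0,1)$ and $p_{\theta^*}(x\vert z)=\mathcal{N}(z+\theta^*,1)$, which makes the marginal of $x$ equal to $\mathcal{N}(\theta^*, 2)$. The helper `sample_dataset` is hypothetical.

```python
import random
import statistics


def sample_dataset(n, theta_star, seed=0):
    """Ancestral sampling from a hypothetical toy model:
    z ~ N(0, 1), then x ~ N(z + theta_star, 1). Only x is returned; z stays hidden."""
    rng = random.Random(seed)
    xs = []
    for _ in range(n):
        z = rng.gauss(0.0, 1.0)             # latent variable, never observed
        x = rng.gauss(z + theta_star, 1.0)  # observation drawn given z
        xs.append(x)
    return xs


xs = sample_dataset(10_000, theta_star=1.0)
# The marginal of x is N(theta_star, 2), so these should be roughly 1.0 and 2.0.
print(statistics.fmean(xs), statistics.variance(xs))
```

An observer only sees `xs`. Recovering $\theta^*$ and reasoning about the hidden $z$ from those samples is the problem the rest of the post addresses.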

Reminder on the evidence lower bound

Suppose that we want to estimate the distribution $p_{\theta^*}$ by estimating parameters $\theta$. $\theta$ is used both to generate $z$ and then to generate $x$ (by sampling $p_\theta(x\vert z)$).

Suppose we have found some appropriate $\theta$. It is still intractable to evaluate $p_\theta(x)$ for a given $x$, because we need to marginalize over all values of $z$:

$$p_\theta(x)=\int p_\theta(x\vert z)\,p_\theta(z)\,dz$$

However, we also know that, for any $z$: $p_\theta(x,z) = p_\theta(x\vert z)p_\theta(z)$ and $p_\theta(x, z) = p_\theta(z\vert x)p_\theta(x)$.

Combining the two:

$$p_\theta(x)=\frac{p_\theta(z)\,p_\theta(x\vert z)}{p_\theta(z\vert x)}\quad\text{for any } z.$$
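Here is a quick numerical check of this identity as a Python sketch. It assumes a conjugate toy model where every term has a closed form (not a model from the paper): $p(z)=\mathcal{N}(0,1)$ and $p(x\vert z)=\mathcal{N}(z,1)$, so $p(x)=\mathcal{N}(0,2)$ and $p(z\vert x)=\mathcal{N}(x/2, 1/2)$. The computation is done in log space, and the right-hand side should give the same value for whatever $z$ we plug in.

```python
import math


def log_normal(x, mean, var):
    """Log-density of the univariate Gaussian N(mean, var) at x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)


def log_px_via_bayes(x, z):
    """ln p(z) + ln p(x|z) - ln p(z|x) for the toy model p(z)=N(0,1), p(x|z)=N(z,1)."""
    return log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0) - log_normal(z, x / 2, 0.5)


x = 1.3
exact = log_normal(x, 0.0, 2.0)  # ln p(x), since p(x) = N(0, 2)
for z in (-2.0, 0.0, 0.7, 3.5):
    print(z, log_px_via_bayes(x, z), exact)  # last two columns match for every z
```

The check only works because the exact posterior is known in closed form here. In realistic models it is not.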

If we can find a good approximation of $p_\theta(z\vert x)$, we can compute $p_\theta(x)$. However:

  • we might not have a closed-form solution
  • $p_\theta(z\vert x) = p_\theta(x\vert z)p_\theta(z)/p_\theta(x)$, and we already established that $p_\theta(x)$ is intractable.

Therefore, we'll use another distribution family $q_\phi(z\vert x)$ to approximate $p_\theta(z\vert x)$.

Now, to estimate the best model $p_\theta$, we wish to maximize its likelihood, i.e. the probability density $p_\theta(x)$ of the observed data. For convenience, we usually maximize the log-likelihood instead. The optimization problem is equivalent because $\log$ is a monotonically increasing function.

$$\ln p_\theta(x) = \ln \int p_\theta(x, z)\,dz$$

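Recall that this is intractable because of the integral over $z$. As in importance sampling, we can rewrite the integral as an expectation under $q_\phi(z\vert x)$ and then apply Jensen's inequality ($\ln$ is concave):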

$$\ln p_\theta(x) = \ln \mathbb{E}_{q_\phi(z\vert x)}\frac{p_\theta(x, z)}{q_\phi(z\vert x)} \geq \mathbb{E}_{q_\phi(z\vert x)}\ln\frac{p_\theta(x, z)}{q_\phi(z\vert x)}\quad\text{(Jensen's inequality)}$$
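The bound can be checked numerically. The following minimal Python sketch rests on assumptions of our own, not the paper's: a toy conjugate model $p(z)=\mathcal{N}(0,1)$, $p(x\vert z)=\mathcal{N}(z,1)$, chosen so that $\ln p(x)$ is known exactly, and a deliberately imperfect $q(z\vert x)=\mathcal{N}(0.2, 1)$. It draws $z$ from $q$ and computes two quantities: the importance-sampling estimate of $\ln p(x)$ (the left-hand side) and the Monte Carlo ELBO (the right-hand side).

```python
import math
import random


def log_normal(x, mean, var):
    """Log-density of the univariate Gaussian N(mean, var) at x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)


def log_joint(x, z):
    """ln p(x, z) for the toy model p(z) = N(0, 1), p(x|z) = N(z, 1)."""
    return log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)


def bound_estimates(x, q_mean, q_var, n_samples=100_000, seed=0):
    """Return (Monte Carlo ELBO, importance-sampling estimate of ln p(x)) using z ~ q."""
    rng = random.Random(seed)
    log_w = []
    for _ in range(n_samples):
        z = rng.gauss(q_mean, math.sqrt(q_var))                        # z ~ q(z|x)
        log_w.append(log_joint(x, z) - log_normal(z, q_mean, q_var))  # ln [p(x,z) / q(z|x)]
    elbo = sum(log_w) / n_samples                                     # E_q ln(p/q)
    m = max(log_w)                                                    # log-sum-exp trick
    log_is = m + math.log(sum(math.exp(lw - m) for lw in log_w) / n_samples)  # ln E_q(p/q)
    return elbo, log_is


x = 1.3
elbo, log_is = bound_estimates(x, q_mean=0.2, q_var=1.0)
exact = log_normal(x, 0.0, 2.0)  # p(x) = N(0, 2) for this toy model
print(elbo, log_is, exact)       # elbo sits below exact; log_is is close to exact
```

In this setup the gap between the ELBO and $\ln p(x)$ comes out at about 0.356. That value is the KL divergence from $\mathcal{N}(0.2, 1)$ to the true posterior $\mathcal{N}(0.65, 0.5)$.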

The right-hand side, $\mathbb{E}_{q_\phi(z\vert x)}\ln\frac{p_\theta(x, z)}{q_\phi(z\vert x)}$, is what we call the evidence lower bound (ELBO) $\mathcal{L}(\theta, \phi)$. But that's not all. Let's move it to the left-hand side, next to the log-likelihood term $\ln p_\theta(x)$:

$$\ln p_\theta(x) - \mathbb{E}_{q_\phi(z\vert x)}\ln\frac{p_\theta(x, z)}{q_\phi(z\vert x)} \geq 0$$

where

$$
\begin{aligned}
\ln p_\theta(x) - \mathbb{E}_{q_\phi(z\vert x)}\ln\frac{p_\theta(x, z)}{q_\phi(z\vert x)}
& = -\mathbb{E}_{q_\phi(z\vert x)}\left[\ln\frac{p_\theta(x, z)}{q_\phi(z\vert x)} - \ln p_\theta(x)\right] \\
& = -\mathbb{E}_{q_\phi(z\vert x)}\left[\ln\frac{p_\theta(x, z)/p_\theta(x)}{q_\phi(z\vert x)}\right] \\
& = -\mathbb{E}_{q_\phi(z\vert x)}\left[\ln\frac{p_\theta(z\vert x)}{q_\phi(z\vert x)}\right] \\
& = D_{KL}(q_\phi(z\vert x)\Vert p_\theta(z\vert x))
\end{aligned}
$$

So we basically have:

$$\underbrace{\ln p_\theta(x)}_{\text{log-likelihood}} - \underbrace{\mathcal{L}(\theta, \phi)}_{\text{ELBO}} = \underbrace{D_{KL}(q_\phi(z\vert x)\Vert p_\theta(z\vert x))}_{\text{KL divergence}} \geq 0$$
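Every term in this identity has a closed form when $q_\phi(z\vert x)=\mathcal{N}(m, s^2)$ and the model is the toy $p(z)=\mathcal{N}(0,1)$, $p(x\vert z)=\mathcal{N}(z,1)$ (an assumption for illustration), so the identity can be checked exactly. In this Python sketch the gap $\ln p(x) - \mathcal{L}$ should match the KL divergence to the true posterior $\mathcal{N}(x/2, 1/2)$, and it should vanish when $q$ equals that posterior.

```python
import math


def log_px(x):
    """ln p(x) for the toy model, where p(x) = N(0, 2)."""
    return -0.5 * (math.log(2 * math.pi * 2.0) + x ** 2 / 2.0)


def elbo_closed_form(x, m, s2):
    """ELBO for q(z|x) = N(m, s2) under p(z) = N(0, 1), p(x|z) = N(z, 1)."""
    e_log_lik = -0.5 * math.log(2 * math.pi) - 0.5 * ((x - m) ** 2 + s2)  # E_q[ln p(x|z)]
    e_log_prior = -0.5 * math.log(2 * math.pi) - 0.5 * (m ** 2 + s2)      # E_q[ln p(z)]
    entropy = 0.5 * math.log(2 * math.pi * math.e * s2)                   # -E_q[ln q(z|x)]
    return e_log_lik + e_log_prior + entropy


def kl_gauss(m1, v1, m2, v2):
    """KL(N(m1, v1) || N(m2, v2)) for univariate Gaussians."""
    return 0.5 * (math.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1.0)


x = 1.3
for m, s2 in [(0.2, 1.0), (x / 2, 0.5), (-1.0, 0.1)]:
    gap = log_px(x) - elbo_closed_form(x, m, s2)
    print(m, s2, gap, kl_gauss(m, s2, x / 2, 0.5))  # gap equals the KL; both are 0 for the exact posterior
```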

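Therefore, the ELBO is a lower bound on the log-likelihood, and the gap is exactly the KL divergence between the approximate and the true posterior. Since $\ln p_\theta(x)$ does not depend on $\phi$, maximizing the ELBO with respect to $\phi$ is the same as minimizing that KL divergence, which tightens the bound. Maximizing it with respect to $\theta$ raises the bound on the log-likelihood. The point of the paper is to find a way to differentiate and optimize the ELBO with low variance.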

To recap, here's our encoder: $q_\phi(z\vert x)\approx p_\theta(z\vert x)$; here's our decoder: $p_\theta(x\vert z)$. We'll learn $\phi$ and $\theta$ jointly.

We want the algo to work in the case of:

  • intractability: we can't compute the marginal likelihood $p_\theta(x) = \int p_\theta(z)p_\theta(x\vert z)\,dz$ or the true posterior $p_\theta(z\vert x)=\frac{p_\theta(x\vert z)p_\theta(z)}{p_\theta(x)}$, so EM can't be used, and the integrals required by mean-field variational Bayes are intractable too. This already happens with moderately complicated likelihoods $p_\theta(x\vert z)$, e.g. a neural network with a nonlinear hidden layer.
  • large dataset: batch optimization is too costly, and sampling-based methods such as Monte Carlo EM would be too slow (an expensive sampling loop per datapoint)

Some relevant applications:

  • the parameters $\theta$ can be of interest if we're analyzing some natural process or want to generate artificial data (by sampling $p(x\vert \theta)$). We want efficient maximum likelihood estimation (maximize the probability of seeing the data given the model) or maximum a posteriori (MAP) estimation (maximize the probability of the model given the data; requires a prior over the model) of the parameters $\theta$.
  • representing data (e.g. generating image embeddings): posterior inference of $z$ given $x$
  • marginal inference of $x$ for tasks that require a prior over $x$, like image denoising, inpainting, and super-resolution

Stochastic Gradient Variational Bayes

By splitting $\ln p_\theta(x, z) = \ln p_\theta(x\vert z) + \ln p_\theta(z)$ inside the expectation, the ELBO can also be written as:

$$\mathcal{L}(\theta, \phi)=-D_{KL}(q_\phi(z\vert x)\Vert p_\theta(z)) + \mathbb{E}_{q_\phi(z\vert x)}[\log p_\theta(x\vert z)]$$

The first term is the (negative) KL divergence, which can be seen as a regularization term that pulls the approximate posterior towards the prior $p_\theta(z)$. The second term is the expected reconstruction log-likelihood (the negative reconstruction loss): given $z$, we want to reconstruct $x$.

Naively, we could use a Monte Carlo gradient estimator (the score-function, or REINFORCE, estimator) for either term:

$$\nabla_\phi \mathbb{E}_{q_\phi(z)}[f(z)]=\nabla_\phi \int_z q_\phi(z)f(z)\,dz=\int_z \nabla_\phi q_\phi(z)\, f(z)\,dz =\int_z q_\phi(z)\underbrace{\frac{1}{q_\phi(z)}\nabla_\phi q_\phi(z)}_{\nabla_\phi \log q_\phi(z)} f(z)\,dz = \mathbb{E}_{q_\phi(z)}[f(z)\nabla_\phi \log q_\phi(z)]$$

where $f(z) = \log\frac{p_\theta(z)}{q_\phi(z\vert x)}$ or $f(z)=\log p_\theta(x\vert z)$.

Thus: $\nabla_\phi \mathbb{E}_{q_\phi(z)}[f(z)] \approx \frac{1}{L}\sum_{l=1}^{L} f(z^{(l)})\nabla_\phi \log q_\phi(z^{(l)})$, where $z^{(l)}\sim q_\phi(z\vert x)$.

However, this estimator has very high variance.
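To see the variance problem on a case small enough to analyze, here is a Python sketch with an illustrative choice that is not from the paper: $q_\phi(z)=\mathcal{N}(\mu,1)$ with $\phi=\mu$, and $f(z)=z^2$, so the true gradient is $2\mu$. Because $\nabla_\mu \log q_\phi(z) = z-\mu$, each sample of the estimator is $f(z)(z-\mu)$. The estimator is unbiased, but a short moment calculation gives a per-sample variance of $\mu^4+14\mu^2+15$, which is 87 at $\mu=2$.

```python
import random
import statistics


def score_function_grads(mu, n, seed=0):
    """Per-sample score-function gradients f(z) * d/dmu ln q(z) for q = N(mu, 1), f(z) = z**2."""
    rng = random.Random(seed)
    grads = []
    for _ in range(n):
        z = rng.gauss(mu, 1.0)       # z ~ q_phi(z)
        grad_log_q = z - mu          # d/dmu ln N(z; mu, 1)
        grads.append(z ** 2 * grad_log_q)
    return grads


mu = 2.0
g = score_function_grads(mu, 100_000)
print(statistics.fmean(g))     # near the true gradient 2 * mu = 4
print(statistics.variance(g))  # near mu**4 + 14 * mu**2 + 15 = 87
```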

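Instead, we can reparameterize the random variable $\tilde z\sim q_\phi(z\vert x)$ as $\tilde z = g_\phi(\epsilon , x)$ with $\epsilon\sim p(\epsilon)$, where $g_\phi$ is a differentiable transformation of an auxiliary noise variable $\epsilon$.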

For instance, if $z\sim\mathcal{N}(\mu, \sigma^2)$, then $\tilde z=\mu + \sigma \epsilon$ with $\epsilon\sim\mathcal{N}(0, 1)$.

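$$\mathbb{E}_{q_\phi(z\vert x)}[f(z)]=\mathbb{E}_{p(\epsilon)}[f(g_\phi(\epsilon, x))]\approx \frac{1}{L}\sum_{l=1}^{L} f(g_\phi(\epsilon^{(l)}, x)),\quad \epsilon^{(l)}\sim p(\epsilon)$$

The noise distribution $p(\epsilon)$ does not depend on $\phi$. The gradient can therefore move inside the expectation, and this Monte Carlo estimate can be differentiated with respect to $\phi$ directly.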
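Take an illustrative case (an assumption, not from the paper): $q_\phi(z)=\mathcal{N}(\mu,1)$ and $f(z)=z^2$. The reparameterization $z=\mu+\epsilon$ gives per-sample gradients $\nabla_\mu f(\mu+\epsilon) = 2(\mu+\epsilon)$. These are unbiased for the true gradient $2\mu$, and their variance is 4 for every $\mu$. For the same target, the score-function estimator's per-sample variance works out to $\mu^4+14\mu^2+15$. A Python sketch:

```python
import random
import statistics


def reparam_grads(mu, n, seed=0):
    """Per-sample reparameterized gradients for q = N(mu, 1), f(z) = z**2."""
    rng = random.Random(seed)
    grads = []
    for _ in range(n):
        eps = rng.gauss(0.0, 1.0)  # eps ~ p(eps) = N(0, 1), independent of mu
        z = mu + eps               # z = g(eps, mu)
        grads.append(2.0 * z)      # chain rule: df/dz = 2z, dz/dmu = 1
    return grads


mu = 2.0
g = reparam_grads(mu, 100_000)
print(statistics.fmean(g))     # near the true gradient 2 * mu = 4
print(statistics.variance(g))  # near 4
```

The gain comes from differentiating through $f$ itself rather than weighting $f$ by the score. This is only possible because $f$ is differentiable in $z$ and $g_\phi$ is differentiable in $\phi$.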

The KL divergence can often be integrated analytically, e.g. when the prior $p_\theta(z)$ and the posterior $q_\phi(z\vert x)$ are Gaussian (see Appendix B in the paper). It can be interpreted as a regularizer that encourages $q_\phi(z\vert x)$ to stay close to the prior. Only the expected reconstruction error $\mathbb{E}_{q_\phi(z\vert x)}[\log p_\theta(x\vert z)]$ then requires estimation by sampling.
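Appendix B of the paper gives a closed form for a diagonal Gaussian posterior $q_\phi(z\vert x)=\mathcal{N}(\mu,\operatorname{diag}(\sigma^2))$ and a standard normal prior: $-D_{KL} = \frac{1}{2}\sum_j\left(1+\log\sigma_j^2-\mu_j^2-\sigma_j^2\right)$. The Python sketch below implements that formula and compares it with a sampling-based estimate, purely as a sanity check. The function names are hypothetical.

```python
import math
import random


def kl_to_standard_normal(mu, log_var):
    """Closed-form KL(N(mu, diag(exp(log_var))) || N(0, I)), following Appendix B."""
    return -0.5 * sum(1 + lv - m ** 2 - math.exp(lv) for m, lv in zip(mu, log_var))


def kl_monte_carlo(mu, log_var, n=50_000, seed=0):
    """Sampling-based estimate of E_q[ln q(z) - ln p(z)], for comparison only."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        for m, lv in zip(mu, log_var):  # dimensions are independent under both q and p
            eps = rng.gauss(0.0, 1.0)
            z = m + math.exp(0.5 * lv) * eps
            log_q = -0.5 * (math.log(2 * math.pi) + lv + eps ** 2)
            log_p = -0.5 * (math.log(2 * math.pi) + z ** 2)
            total += log_q - log_p
    return total / n


mu, log_var = [0.5, -1.0], [0.0, math.log(0.25)]
print(kl_to_standard_normal(mu, log_var), kl_monte_carlo(mu, log_var))  # close to each other
```

Parameterizing by $\log\sigma^2$ rather than $\sigma$ keeps the variance positive without constraints. The decoder described below outputs $\log\sigma^2$ for the same reason.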

Example: Variational Auto-Encoder

  • we set a prior over the latent variables: $p_\theta(z)=\mathcal{N}(z; 0, I)$.
  • For binary data, the decoder $p_\theta(x\vert z)$ is a multivariate Bernoulli distribution:

$$\log p(x\vert z) = \sum_i x_i \log y_i + (1-x_i)\log(1-y_i)$$

(each pixel is modeled independently: $p(x_i\vert z) = y_i$ if $x_i=1$ and $1-y_i$ otherwise)

where $y=\text{sigmoid}(W_2 \tanh(W_1 z +b_1) + b_2)$, with the sigmoid applied elementwise.

  • For real-valued data, the decoder $p_\theta(x\vert z)$ is a multivariate Gaussian with diagonal covariance:

$$\log p(x\vert z) = \log \mathcal{N}(x; \mu, \sigma^2 I)$$

where $\mu=W_1 h + b_1$, $\log \sigma^2 = W_2 h + b_2$ and $h =\tanh(W_0 z + b_0)$.

  • The true posterior $p_\theta(z\vert x)$ is intractable, so we approximate it with the encoder $q_\phi(z\vert x)$. This is the same Gaussian MLP as above (with different parameters $\phi$), with the roles of $z$ and $x$ swapped.
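The MLPs above can be written out in plain Python as a sketch, under these assumptions: vectors are Python lists, weight matrices are lists of rows, and the helper names (`affine`, `bernoulli_decoder_log_lik`, `gaussian_mlp`, `gaussian_log_density`) are hypothetical. A real implementation would use a tensor library such as PyTorch. `gaussian_mlp` serves both as the real-valued decoder (input $z$) and as the encoder (input $x$), matching the swap described in the last bullet.

```python
import math


def affine(W, b, v):
    """W v + b, with W given as a list of rows."""
    return [sum(w * vj for w, vj in zip(row, v)) + bi for row, bi in zip(W, b)]


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def bernoulli_decoder_log_lik(x, z, W1, b1, W2, b2):
    """ln p(x|z) for binary x, with y = sigmoid(W2 tanh(W1 z + b1) + b2)."""
    h = [math.tanh(a) for a in affine(W1, b1, z)]
    y = [sigmoid(a) for a in affine(W2, b2, h)]
    return sum(xi * math.log(yi) + (1 - xi) * math.log(1 - yi) for xi, yi in zip(x, y))


def gaussian_mlp(v, W0, b0, W1, b1, W2, b2):
    """(mu, log sigma^2) with h = tanh(W0 v + b0), mu = W1 h + b1, log sigma^2 = W2 h + b2.
    Used as the decoder with v = z, or as the encoder q_phi(z|x) with v = x."""
    h = [math.tanh(a) for a in affine(W0, b0, v)]
    return affine(W1, b1, h), affine(W2, b2, h)


def gaussian_log_density(x, mu, log_var):
    """ln N(x; mu, diag(exp(log_var)))."""
    return -0.5 * sum(math.log(2 * math.pi) + lv + (xi - m) ** 2 / math.exp(lv)
                      for xi, m, lv in zip(x, mu, log_var))


def zeros(rows, cols):
    return [[0.0] * cols for _ in range(rows)]


# Tiny hypothetical sizes (latent 2, hidden 3, data 4) with all-zero weights, so the
# results are easy to check by hand: y_i = 0.5, mu = 0 and sigma^2 = 1 everywhere.
z = [0.3, -1.2]
x_bin = [1, 0, 1, 1]
print(bernoulli_decoder_log_lik(x_bin, z, zeros(3, 2), [0.0] * 3, zeros(4, 3), [0.0] * 4))  # 4 * ln(0.5)
mu, log_var = gaussian_mlp(z, zeros(3, 2), [0.0] * 3, zeros(4, 3), [0.0] * 4, zeros(4, 3), [0.0] * 4)
print(gaussian_log_density([0.1, -0.2, 0.0, 0.5], mu, log_var))
```

This direct translation of the formulas takes `log` of a sigmoid, which fails once $y_i$ saturates to exactly 0 or 1. Library implementations usually compute the Bernoulli log-likelihood from the pre-sigmoid logits instead.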

The KL divergence can be computed and differentiated analytically. For the reconstruction term, we sample $z = \mu + \sigma \odot \epsilon$ with $\epsilon\sim\mathcal{N}(0,I)$, where $\mu$ and $\sigma$ are the encoder outputs.

[Figure: vae_estimator.png]
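Putting the pieces together, the per-datapoint estimator is the analytic $-D_{KL}$ plus $\frac{1}{L}\sum_{l}\log p_\theta(x\vert z^{(l)})$, where $z^{(l)}=\mu+\sigma\odot\epsilon^{(l)}$. The Python sketch below covers the estimator only. `sgvb_estimate` is a hypothetical name, the encoder outputs are passed in directly, and the decoder is any callable that returns $\log p_\theta(x\vert z)$. It does not compute gradients; in practice those come from automatic differentiation. The check uses a toy model with a known posterior: $p(x\vert z)=\mathcal{N}(z,1)$, with posterior $\mathcal{N}(x/2, 1/2)$. With $q$ set to that posterior, the estimate should approach $\ln p(x)$.

```python
import math
import random


def sgvb_estimate(mu, log_var, log_lik, n_samples=1, rng=None):
    """Per-datapoint SGVB estimate of the ELBO:
    analytic -KL(q(z|x) || N(0, I)) plus a Monte Carlo average of ln p(x|z).
    mu, log_var: encoder outputs for one datapoint; log_lik: callable z -> ln p(x|z)."""
    rng = rng or random.Random()
    neg_kl = 0.5 * sum(1 + lv - m ** 2 - math.exp(lv) for m, lv in zip(mu, log_var))
    recon = 0.0
    for _ in range(n_samples):
        # reparameterization: z = mu + sigma * eps, eps ~ N(0, I)
        z = [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]
        recon += log_lik(z)
    return neg_kl + recon / n_samples


x = 1.3


def log_lik(z):
    """Toy decoder: ln N(x; z, 1)."""
    return -0.5 * (math.log(2 * math.pi) + (x - z[0]) ** 2)


est = sgvb_estimate([x / 2], [math.log(0.5)], log_lik, n_samples=20_000, rng=random.Random(0))
print(est, -0.5 * (math.log(4 * math.pi) + x ** 2 / 2))  # close: q is the exact posterior here
```

During training the paper uses a small number of samples per datapoint (even $L=1$) when the minibatch is large enough. The large `n_samples` here only serves to make the check precise.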

AEVB algorithm:

[Figure: aevb.png]
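Algorithm 1 in the paper repeats four steps until convergence: draw a random minibatch, draw noise $\epsilon\sim p(\epsilon)$, compute the gradient of the minibatch SGVB estimator with respect to $\theta$ and $\phi$, and update the parameters (e.g. with SGD or Adagrad). Below is a sketch of that loop in plain Python. It uses a toy model chosen so that the gradients can be written by hand (an assumption; a real VAE would rely on autodiff, e.g. in PyTorch):

- prior $\mathcal{N}(0,1)$
- decoder $p_\theta(x\vert z)=\mathcal{N}(z+\theta,1)$
- encoder $q_\phi(z\vert x)=\mathcal{N}(ax+c, e^{2s})$ with $\phi=(a,c,s)$

For this model the exact posterior lies in the encoder family. Training should therefore drive $\theta$ towards the maximum-likelihood value (the data mean) and $q$ towards $\mathcal{N}((x-\theta)/2, 1/2)$.

```python
import math
import random
import statistics


def aevb(data, epochs=100, batch_size=100, lr=0.05, seed=0):
    """Minibatch AEVB on a hypothetical 1-D toy model:
    p(z) = N(0, 1), p_theta(x|z) = N(z + theta, 1), q_phi(z|x) = N(a*x + c, exp(2*s)).
    Gradients of the one-sample SGVB estimator are derived by hand."""
    rng = random.Random(seed)
    data = list(data)  # shuffle a copy, not the caller's list
    theta, a, c, s = 0.0, 0.0, 0.0, 0.0
    for _ in range(epochs):
        rng.shuffle(data)
        for start in range(0, len(data), batch_size):
            batch = data[start:start + batch_size]
            g_theta = g_a = g_c = g_s = 0.0
            for x in batch:
                mu, sigma = a * x + c, math.exp(s)
                eps = rng.gauss(0.0, 1.0)        # noise sample, eps ~ p(eps)
                z = mu + sigma * eps             # reparameterized z ~ q_phi(z|x)
                r = x - z - theta                # d ln p(x|z)/dz, also d ln p(x|z)/dtheta
                g_mu = r - mu                    # reconstruction part + d(-KL)/dmu
                g_theta += r
                g_a += g_mu * x
                g_c += g_mu
                g_s += r * sigma * eps + 1.0 - sigma ** 2  # reconstruction + d(-KL)/ds
            n = len(batch)
            # gradient *ascent*, since we maximize the ELBO
            theta += lr * g_theta / n
            a += lr * g_a / n
            c += lr * g_c / n
            s += lr * g_s / n
    return theta, a, c, math.exp(2 * s)


gen = random.Random(1)
xs = [gen.gauss(gen.gauss(0.0, 1.0) + 1.0, 1.0) for _ in range(2000)]  # true theta = 1
theta, a, c, q_var = aevb(xs)
print(theta, statistics.fmean(xs))  # theta should end up near the data mean
print(a, c, q_var)                  # roughly 0.5, -theta/2 and 0.5
```

A fixed-step SGD is used here only to keep the sketch short. With a constant step size, the parameters keep fluctuating slightly around the optimum instead of settling exactly.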