This is the Variational Auto-Encoder paper by Kingma & Welling.
For a simple PyTorch implementation, check out my GitHub repo, where I auto-encoded CryptoPunks.
Problem statement
There is a random variable X that is generated by:
sampling a latent variable z∼pθ∗(z)
sampling x∼pθ∗(x∣z)
where the true parameters θ∗ of the distribution and the latent variable z are hidden.
Reminder on the evidence lower bound
Suppose that we want to estimate the distribution pθ∗ by estimating parameters θ. θ is used both to generate z (by sampling pθ(z)) and then x (by sampling pθ(x∣z)).
Even if we find some appropriate θ, it is intractable to evaluate pθ(x) for a given x because we need to marginalize over all possible z:
pθ(x)=∫pθ(x∣z)pθ(z)dz
However, we also know that, for any z: pθ(x,z)=pθ(x∣z)pθ(z) and pθ(x,z)=pθ(z∣x)pθ(x).
Combining the two:
pθ(x)=pθ(z)pθ(x∣z)/pθ(z∣x) for any z.
If we can find a good approximation of pθ(z∣x), we can compute pθ(x). However:
we might not have a closed-form solution
pθ(z∣x)=pθ(x∣z)pθ(z)/pθ(x), and we already established that pθ(x) is intractable.
Therefore, we'll use another distribution family qϕ(z∣x) to approximate pθ(z∣x).
Now, to estimate the best model pθ we wish to maximize its likelihood, i.e. the probability density pθ(x) evaluated at the observed data. For convenience, we usually work with the log-likelihood; the optimization problem is equivalent since log is a monotonically increasing function.
lnpθ(x)=ln∫pθ(x,z)dz
Recall that this is intractable because of the integral over z. We can rewrite it using importance sampling with qϕ(z∣x) as the proposal distribution, then apply Jensen's inequality (ln is concave):
lnpθ(x)=lnEqϕ(z∣x)[pθ(x,z)/qϕ(z∣x)]≥Eqϕ(z∣x)[ln(pθ(x,z)/qϕ(z∣x))]
The right-hand side, L(θ,ϕ)=Eqϕ(z∣x)[lnpθ(x,z)−lnqϕ(z∣x)], is what we call the evidence lower bound (ELBO). But it's not over. Moving it to the left-hand side and combining with the log-likelihood lnpθ(x):
lnpθ(x)−L(θ,ϕ)=DKL(qϕ(z∣x)∥pθ(z∣x))≥0
Therefore, since the ELBO lower-bounds lnpθ(x) (and the gap between the two is exactly DKL(qϕ(z∣x)∥pθ(z∣x))), maximizing the ELBO pushes up the log-likelihood while tightening the approximation. The point of the paper is to find a way to differentiate and optimize the ELBO with low variance.
To recap, here's our encoder: qϕ(z∣x)≈pθ(z∣x); here's our decoder pθ(x∣z). We'll learn ϕ and θ jointly.
We want the algorithm to work in the case of:
intractability: we can't compute pθ(x)=∫pθ(z)pθ(x∣z)dz or pθ(z∣x)=pθ(x∣z)pθ(z)/pθ(x) (so EM can't be used). pθ(x) is intractable because of the high dimensionality of x (e.g. images); pθ(z∣x) is intractable due to the large number of latent variables (in the case of a neural net with a nonlinear hidden layer, for example).
large datasets: Monte Carlo EM would be too slow (it requires an expensive sampling loop per datapoint)
Some relevant applications:
the parameters θ can be of interest if we're analyzing some natural process or want to generate artificial data (by sampling p(x∣θ)). We want efficient maximum likelihood estimation (maximize the probability of the data given the model) or maximum a posteriori estimation (maximize the probability of the model given the data; requires a prior over models) of the parameters θ.
representing data (e.g. generating image embeddings): posterior inference of z given x
marginal inference of x for tasks where a prior over x is required like image denoising, inpainting, superresolution
Stochastic Gradient Variational Bayes
The ELBO can also be written as L(θ,ϕ)=−DKL(qϕ(z∣x)∥pθ(z))+Eqϕ(z∣x)[logpθ(x∣z)]
The first term is the KL divergence which can be seen as a regularization term with respect to the prior pθ(z). The second term is the reconstruction loss: given z we want to reconstruct x.
Naively, we can use a Monte Carlo gradient estimator for either term, via the log-derivative trick ∇ϕqϕ(z)=qϕ(z)∇ϕlogqϕ(z):
∇ϕEqϕ(z)[f(z)]=∇ϕ∫qϕ(z)f(z)dz=∫∇ϕqϕ(z)f(z)dz=∫qϕ(z)∇ϕlogqϕ(z)f(z)dz=Eqϕ(z)[f(z)∇ϕlogqϕ(z)]
where f(z)=log(pθ(z)/qϕ(z∣x)) for the KL term, or f(z)=logpθ(x∣z) for the reconstruction term.
Thus: ∇ϕEqϕ(z)[f(z)]≈(1/L)∑f(z(l))∇ϕlogqϕ(z(l)) where z(l)∼qϕ(z)
However, this estimator has very high variance.
We can reparameterize qϕ(z∣x) as z̃=gϕ(ϵ,x) with ϵ∼p(ϵ), so that Eqϕ(z∣x)[f(z)]=Ep(ϵ)[f(gϕ(ϵ,x))] and the gradient ∇ϕ can be taken inside an expectation over the fixed noise distribution p(ϵ).
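A minimal NumPy sketch of why this helps (a toy example, all names assumed: f(z)=z², a scalar qϕ(z)=N(μ,σ²), true gradient d/dμ E[f(z)]=2μ) compares the variance of the score-function estimator above against the reparameterized one:

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 1.0, 1.0              # parameters of q(z) = N(mu, sigma^2)
f = lambda z: z ** 2              # toy integrand; true gradient d/dmu E[f(z)] = 2 * mu
L = 200_000

eps = rng.standard_normal(L)
z = mu + sigma * eps              # reparameterization: z = g(eps) = mu + sigma * eps

# score-function estimator: f(z) * d/dmu log q(z) = f(z) * (z - mu) / sigma^2
score = f(z) * (z - mu) / sigma ** 2
# reparameterized estimator: d/dmu f(mu + sigma * eps) = 2 * z
reparam = 2 * z

print(score.mean(), reparam.mean())   # both close to the true gradient 2.0
print(score.var(), reparam.var())     # reparameterized variance is far lower
```

Both estimators are unbiased, but in this toy setting the per-sample variance of the reparameterized estimator is several times lower (roughly 4 vs 30), which is why it works with a small number of samples L in practice.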
The KL divergence can often be integrated analytically (e.g. when the prior pθ(z) and the approximate posterior qϕ(z∣x) are both Gaussian, see appendix B in the paper). It can be interpreted as a regularizer encouraging qϕ(z∣x) to stay close to the prior. Only the expected reconstruction error Eqϕ(z∣x)[logpθ(x∣z)] requires sampling.
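As a sanity check, the closed-form KL between a diagonal Gaussian qϕ(z∣x)=N(μ,σ²I) and the prior N(0,I) can be verified against a Monte Carlo estimate (a NumPy sketch; the μ and logσ² values below are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
mu = np.array([0.5, -1.0, 2.0])       # arbitrary posterior mean
log_var = np.array([0.1, -0.3, 0.0])  # arbitrary posterior log-variance
sigma2 = np.exp(log_var)

# closed form (appendix B of the paper): KL(N(mu, diag(sigma^2)) || N(0, I))
kl_analytic = 0.5 * np.sum(mu**2 + sigma2 - 1.0 - log_var)

# Monte Carlo check: E_q[log q(z) - log p(z)] with z sampled from q
z = mu + np.sqrt(sigma2) * rng.standard_normal((500_000, 3))
log_q = -0.5 * np.sum(np.log(2 * np.pi) + log_var + (z - mu)**2 / sigma2, axis=1)
log_p = -0.5 * np.sum(np.log(2 * np.pi) + z**2, axis=1)
kl_mc = np.mean(log_q - log_p)

print(kl_analytic, kl_mc)  # the two agree closely
```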
Example: Variational Auto-Encoder
we set a prior over the latent pθ(z)=N(z;0,I).
For binary data, the decoder pθ(x∣z) is a multivariate Bernoulli distribution:
logp(x∣z)=∑i[xilogyi+(1−xi)log(1−yi)]
(product of each pixel independently, p(xi∣z)=yi if xi=1 and 1−yi otherwise)
where yi=sigmoid(W2tanh(W1z+b1)+b2)
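A NumPy sketch of this Bernoulli decoder (toy layer sizes and random weights assumed; a real implementation, e.g. in PyTorch, would learn W1, b1, W2, b2):

```python
import numpy as np

rng = np.random.default_rng(0)
sigmoid = lambda a: 1.0 / (1.0 + np.exp(-a))

d_z, d_h, d_x = 2, 8, 16  # toy sizes (latent, hidden, data), assumed here
W1 = rng.normal(0, 0.1, (d_h, d_z)); b1 = np.zeros(d_h)
W2 = rng.normal(0, 0.1, (d_x, d_h)); b2 = np.zeros(d_x)

def bernoulli_decoder_logp(x, z):
    """log p(x|z) for binary x under the MLP Bernoulli decoder."""
    y = sigmoid(W2 @ np.tanh(W1 @ z + b1) + b2)  # y_i = p(x_i = 1 | z)
    return np.sum(x * np.log(y) + (1 - x) * np.log(1 - y))

z = rng.standard_normal(d_z)
x = rng.integers(0, 2, d_x).astype(float)  # a random binary "image"
print(bernoulli_decoder_logp(x, z))
```

Since each yi ∈ (0,1), every term in the sum is negative, so logp(x∣z) is always negative and larger (closer to 0) for better reconstructions.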
For real-valued data, the decoder pθ(x∣z) is a multivariate Gaussian:
logp(x∣z)=logN(x;μ,σ2I)
where μ=W1h+b1,logσ2=W2h+b2,h=tanh(W0z+b0).
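The same kind of NumPy sketch for the Gaussian decoder (again, toy sizes and random weights are assumptions); note that the density peaks at x=μ:

```python
import numpy as np

rng = np.random.default_rng(1)
d_z, d_h, d_x = 2, 8, 4  # toy sizes, assumed
W0 = rng.normal(0, 0.1, (d_h, d_z)); b0 = np.zeros(d_h)
W1 = rng.normal(0, 0.1, (d_x, d_h)); b1 = np.zeros(d_x)
W2 = rng.normal(0, 0.1, (d_x, d_h)); b2 = np.zeros(d_x)

def gaussian_decoder_logp(x, z):
    """log p(x|z) = log N(x; mu, sigma^2 I) with MLP-parameterized mu and log sigma^2."""
    h = np.tanh(W0 @ z + b0)
    mu = W1 @ h + b1
    log_var = W2 @ h + b2
    return -0.5 * np.sum(np.log(2 * np.pi) + log_var + (x - mu)**2 / np.exp(log_var))

z = rng.standard_normal(d_z)
mu = W1 @ np.tanh(W0 @ z + b0) + b1
# the log-density is maximal when x equals the predicted mean mu:
print(gaussian_decoder_logp(mu, z) > gaussian_decoder_logp(mu + 1.0, z))  # True
```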
The true posterior pθ(z∣x) is intractable, so we approximate it with the encoder qϕ(z∣x). For qϕ(z∣x) we use the same multivariate-Gaussian MLP form (with its own parameters ϕ), with the roles of z and x swapped.
The KL divergence can be computed and differentiated analytically. For the reconstruction loss, we sample z=μ+σ⊙ϵ where ϵ∼N(0,I).
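Putting the pieces together, a single-datapoint forward pass of the ELBO — encoder, reparameterized sample, Bernoulli decoder, analytic KL — can be sketched in NumPy as follows (all sizes and weight names are toy assumptions; a real implementation would compute gradients with autograd and ascend this objective):

```python
import numpy as np

rng = np.random.default_rng(0)
sigmoid = lambda a: 1.0 / (1.0 + np.exp(-a))
d_x, d_h, d_z = 16, 8, 2  # toy sizes, assumed

# encoder q_phi(z|x): x -> (mu, log sigma^2)
We = rng.normal(0, 0.1, (d_h, d_x)); be = np.zeros(d_h)
Wm = rng.normal(0, 0.1, (d_z, d_h)); bm = np.zeros(d_z)
Wv = rng.normal(0, 0.1, (d_z, d_h)); bv = np.zeros(d_z)
# decoder p_theta(x|z): z -> Bernoulli probabilities y
Wd = rng.normal(0, 0.1, (d_h, d_z)); bd = np.zeros(d_h)
Wy = rng.normal(0, 0.1, (d_x, d_h)); by = np.zeros(d_x)

def elbo(x):
    # encode x into the parameters of q_phi(z|x)
    h = np.tanh(We @ x + be)
    mu, log_var = Wm @ h + bm, Wv @ h + bv
    # reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)
    eps = rng.standard_normal(d_z)
    z = mu + np.exp(0.5 * log_var) * eps
    # decode (Bernoulli likelihood): expected reconstruction term
    y = sigmoid(Wy @ np.tanh(Wd @ z + bd) + by)
    recon = np.sum(x * np.log(y) + (1 - x) * np.log(1 - y))
    # analytic KL(q_phi(z|x) || N(0, I)) regularization term
    kl = 0.5 * np.sum(mu**2 + np.exp(log_var) - 1.0 - log_var)
    return recon - kl  # single-sample estimate of L(theta, phi)

x = rng.integers(0, 2, d_x).astype(float)  # a random binary "image"
print(elbo(x))
```

The only randomness (ϵ) enters through the deterministic map z=μ+σϵ, so in an autograd framework like PyTorch this whole function is differentiable end-to-end in both ϕ and θ.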