Probability distributions over many variables are inherently exponentially-sized objects; we need to make assumptions about their structure.
Naive Bayes assumption: the features are conditionally independent given the class variable.
Three major parts in this course:
Representation: how to specify a (tractable) model?
Inference: given a probabilistic model, how to determine the marginal or conditional probabilities of certain events.
marginal inference: what is the probability of a given variable in the model?
maximum a posteriori (MAP) inference: most likely assignment of variables given the data.
Learning: fitting a model to a dataset. Learning and inference are inherently linked, as inference is a key subroutine that we repeatedly call within learning algorithms.
Applications:
Image generation (see Radford et al.). Sampling new images x̂ ∼ p(x) from a learned probability distribution that assigns high probability to images resembling those in the training set.
In-painting: given a patch, sample from p(image∣patch) to complete the image.
Image denoising: given an image corrupted by noise, model the posterior distribution p(original image∣noisy image) and sample/use exact inference to predict the original image.
Language models:
generation: sample from a probability distribution over sequences of words or characters.
machine translation: p(target language sentence∣source language sentence)
Audio models:
upsampling or super-resolution: increase resolution of an audio signal by calculating signal values at intermediate points. Sampling/perform inference on p(intermediate values∣observed values)
speech synthesis
speech recognition
Sidenotes
Judea Pearl was awarded the 2011 Turing award (Nobel prize of computer science) for founding the field of probabilistic graphical modeling.
For a philosophical discussion of why one should use probability theory as opposed to something else: see the Dutch book argument for probabilism.
Representation: Bayesian networks
Bayesian networks are directed graphical models in which each factor depends only on a small number of ancestor variables:
p(xi∣xi−1,…,x1)=p(xi∣xAi)
where Ai is a subset of {x1,…,xi−1}.
When the variables are discrete, we may think of p(xi∣xAi) as a probability table. If each variable takes d values and has at most k ancestors, each table has at most d^(k+1) entries. With one table per variable, the entire probability distribution can be compactly described with only O(n d^(k+1)) parameters, compared to O(d^n) for the naive (fully tabular) approach.
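A quick worked example: with n = 30 binary variables (d = 2) and at most k = 3 parents per node, the tables require at most 30 · 2⁴ = 480 parameters, versus 2³⁰ ≈ 10⁹ entries for the full joint distribution.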
Edges indicate dependency relationships between a node xi and its ancestors Ai.
Formal definition
A Bayesian network is a directed graph G=(V,E) together with:
a random variable xi for each node i∈V
one conditional probability distribution p(xi∣xAi) per node
A probability distribution p factorizes over a DAG (directed acyclic graph) G if it can be decomposed into a product of factors, as specified by G.
We can show by counter-example that when G contains cycles, its associated probability may not sum to one.
Dependencies
Let I(p) be the set of all independencies that hold for a joint distribution p. E.g., if p(x,y)=p(x)p(y), then x⊥y∈I(p)
Independencies can be recovered from the graph by looking at three types of structures:
Common parent:
if G is of the form A←B→C, and B is observed, then A⊥C∣B (A is independent of C given B).
However, if B is unobserved, then A ⊥̸ C (A and C are in general dependent).
Intuitively, B contains all the information that determines the outcomes of A and C; once it is observed, nothing else affects their outcomes (knowing A tells us nothing more about C, and vice versa).
Cascade:
if G equals A→B→C and B is observed, then A⊥C∣B.
However, if B is unobserved, then A ⊥̸ C.
V-structure:
if G is A→C←B, then knowing C couples A and B. I.e. A⊥B if C is unobserved, but A ⊥̸ B∣C if C is observed.
E.g. suppose C is true if the lawn is wet and false otherwise. A (it rained) and B (the sprinkler turned on) are two explanations for it being wet. If we know that C is true (grass is wet) and B is false (the sprinkler didn't go on), then p(A)=1 (only other possible explanation). Hence, A and B are not independent given C.
We can extend these structures by applying them recursively to any larger Bayesian net. This leads to a notion called d-separation (where d stands for directed).
Let Q, W and O be three sets of nodes in a Bayesian Network G. We say that Q and W are d-separated given O (i.e. the nodes in O are observed) if Q and W are not connected by an active path. An undirected path in G is called active given observed variables O if for every consecutive triple of variables X,Y,Z on the path, one of the following holds:
X←Y←Z and Y is unobserved Y∈/O
X→Y→Z and Y is unobserved Y∈/O
X←Y→Z and Y is unobserved Y∈/O
X→Y←Z and Y or any of its descendants are observed.
(i.e. there is pairwise dependency between consecutive variables on the path)
In the following example, X1 and X6 are d-separated given X2,X3 (you cannot infer X6 from X1 given X2,X3):
However, in the next example, X2,X3 are not d-separated given X1,X6. There is an active path which passes through the V-structure that becomes active when X6 is observed:
Let I(G)={(X⊥Y∣Z) : X,Y are d-separated given Z} denote the set of independencies that d-separation reads off from G.
I-map: If p factorizes over G, then I(G)⊆I(p). We say that G is an I-map (independence map) for p.
The intuition is that if X,Y and Y,Z are mutually dependent, so are X,Z. Thus, we can look at adjacent nodes and propagate dependencies.
In other words, all the independencies encoded in G are sound: variables that are d-separated in G are truly independent in p. However, the converse is not true: a distribution may factorize over G, yet have independencies that are not captured in G.
Note that if p(x,y)=p(x)p(y) then this distribution still factorizes over the graph y→x since we can always write it as p(x,y)=p(x∣y)p(y) where p(x∣y) does not actually vary with y. However, we can construct a graph that matches the structure of p by simply removing that unnecessary edge.
Representational power of directed graphs
Can directed graphs express all the independencies of any distribution p? More formally, given a distribution p, can we construct a graph G such that I(G)=I(p)?
First, note that it is easy to construct a G such that I(G)⊆I(p); A fully connected DAG G is an I-map for any distribution since I(G)=∅ (there are no variables which are d-separated in G since one can always find an active path; the trivial one is the one that connects the two variables and thus creates the dependency).
A more interesting question is whether we can find a minimal I-map. We may start with a fully connected G and remove edges until G is no longer an I-map (i.e., G would then encode independencies that are not in p, so p would no longer factorize over G). One pruning method consists of following the natural topological ordering of the graph and removing ancestors of each node until this is no longer possible (see structure learning at the end of the course).
Does every probability distribution p admit a perfect map G for which I(p)=I(G)? Unfortunately, the answer is no.
For example, consider the following distribution p over three variables X,Y,Z (the noisy-xor example):
we sample X,Y∼Ber(0.5) from a Bernoulli distribution, and we set Z = X xor Y (Z=1 exactly when one of X or Y equals one, but not both).
{X⊥Y, Z⊥Y, X⊥Z} ⊆ I(p) but Z⊥{Y,X} ∉ I(p) (Z can be deduced from the other two). Thus, X→Z←Y is an I-map for p (every independency read off from G holds in p), but Z⊥Y and X⊥Z are not in I(G) (since each edge automatically creates an active path). The only way to capture all three pairwise independencies would be to leave all three nodes disconnected, but that graph would not be an I-map (it would also claim Z⊥{X,Y}). None of the 3-node graph structures that we discussed perfectly describes I(p), and hence this distribution does not have a perfect map.
A related question is whether perfect maps are unique when they exist. Again, this is not the case, as X→Y and X←Y encode the same independencies, yet form different graphs. Two Bayes nets G1,G2 are I-equivalent if they encode the same independencies: I(G1)=I(G2).
When are two Bayesian nets I-equivalent?
Each of the graphs below has the same skeleton (if we drop the directionality of the arrows, we obtain the same undirected graph).
a, b, and c are symmetric (X⊥Y∣Z but X ⊥̸ Y). They encode exactly the same independencies, and the directionality does not matter as long as we don't turn them into a V-structure (d). The V-structure is the only one that encodes X⊥Y but X ⊥̸ Y∣Z.
General result on I-equivalence: If G,G′ have the same skeleton and the same V-structures, then I(G)=I(G′).
Intuitively, two graphs are I-equivalent if the d-separation between variables is the same. We can flip the directionality of any edge, unless it forms a v-structure, and the d-connectivity of the graph will be unchanged. See the textbook of Koller and Friedman for a full proof in Theorem 3.7 (page 77).
Representation: Markov random fields
Some distributions have independence structure that a Bayesian network cannot represent; in the directed case, unless we want to introduce false independencies among the variables, we must then fall back to a less compact representation (with additional, unnecessary edges). This leads to extra parameters in the model and makes it more difficult to learn them.
Markov Random Fields (MRFs) are based on undirected graphs. They can compactly represent independence assumptions that directed models cannot.
Unlike in the directed case, we are not saying anything about how one variable is generated from another set of variables (as a conditional probability distribution would). We simply indicate a level of coupling between dependent variables in the graph. This requires less prior knowledge, as we no longer have to specify a full generative story of how certain variables are constructed from others. We simply identify dependent variables and define the strength of their interactions. This defines an energy landscape over the space of possible assignments, which we convert to a probability via a normalization constant.
Formal definition
A Markov Random Field is a probability distribution p over variables x1,…,xn defined by an undirected graph G.
p(x1,…,xn) = (1/Z) ∏_{c∈C} ϕc(xc)
C denotes the set of cliques. A clique is a fully connected subgraph, i.e. two distinct vertices in the clique are adjacent. It can be a single node, an edge, a triangle, etc.
Each factor ϕc is a non-negative function over the variables in a clique.
The partition function Z = ∑_{x1,…,xn} ∏_{c∈C} ϕc(xc) is a normalizing constant.
Note: we do not need to specify a factor for each clique.
Example
We are modeling voting preferences among A, B, C, D. The pairs (A,B), (B,C), (C,D), (D,A) are friends, and friends tend to have similar voting preferences.
p̃(A,B,C,D) = ϕ(A,B) ϕ(B,C) ϕ(C,D) ϕ(D,A)
where ϕ(X,Y) = 10 if X=Y=1, 5 if X=Y=0, and 1 otherwise.
The final probability is:
p(A,B,C,D) = (1/Z) p̃(A,B,C,D)
where Z = ∑_{A,B,C,D} p̃(A,B,C,D)
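Since this example has only four binary variables, Z and the normalized probabilities can be computed by brute force. A minimal sketch (the factor values are the ones from the example; everything else is illustrative):

```python
import itertools

def phi(x, y):
    # Pairwise factor from the example: 10 if both vote 1, 5 if both vote 0, 1 otherwise.
    if x == y == 1:
        return 10
    if x == y == 0:
        return 5
    return 1

def p_tilde(a, b, c, d):
    # Unnormalized probability: product of the factors over the cliques (A,B), (B,C), (C,D), (D,A).
    return phi(a, b) * phi(b, c) * phi(c, d) * phi(d, a)

# Partition function: sum over all 2^4 joint assignments.
Z = sum(p_tilde(*x) for x in itertools.product([0, 1], repeat=4))

# Normalized probability of a particular assignment.
p_1111 = p_tilde(1, 1, 1, 1) / Z
print(Z, p_1111)
```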
Comparison to Bayesian networks
In the previous example, we had a distribution over A,B,C,D that satisfied A⊥C∣{B,D} and B⊥D∣{A,C} (since only friends directly influence a person's vote). We can check by counter-example that these independencies cannot be perfectly represented by a Bayesian network. However, the MRF turns out to be a perfect map for this distribution.
Advantages:
can be applied to a wider range of problems in which there is no natural directionality in variable dependencies
can succinctly express certain dependencies that Bayesian nets cannot (converse is also true)
Drawbacks:
computing the normalization constant Z requires summing over a potentially exponential number of assignments. NP-hard in the general case; thus many undirected models are intractable and require approximation techniques.
difficult to interpret
much easier to generate data from a Bayesian network.
Bayesian networks are a special case of MRFs with a very specific type of clique factor (one that corresponds to a conditional probability distribution and implies a directed acyclic structure in the graph), and a normalizing constant of one.
Moralization
A Bayesian network can always be converted into an undirected network with normalization constant one by connecting all parents of each node with edges and dropping the directionality of all edges.
The converse is also possible, but may be intractable and may produce a very large directed graph (e.g. fully connected).
A general rule of thumb is to use Bayesian networks whenever possible and only switch to MRFs if there is no natural way to model the problem with a directed graph (like the voting example).
Independencies in Markov Random Fields
Variables x and y are dependent if they are connected by a path of unobserved variables. However, if x's neighbors are all observed, then x is independent of all the other variables (since they influence x only via its neighbors, referred to as the Markov blanket of x).
If a set of observed variables forms a cut-set between two halves of the graph, then variables in one half are independent from ones in the other.
We define the Markov blanket U of a variable X as the minimal set of nodes such that X is independent from the rest of the graph if U is observed. This holds for both directed and undirected models. For undirected models, the Markov blanket is simply equal to a node's neighborhood.
Just as in the directed case, I(G) ⊆ I(p), but the converse is not necessarily true. E.g.:
Conditional Random Fields
A special case of Markov Random Fields, used to model a conditional probability distribution p(y∣x) where x∈X and y∈Y are vector-valued variables. This common setting in supervised learning is also known as structured prediction.
Formal definition
A CRF is a Markov network over variables X∪Y which specifies a conditional distribution:
P(y∣x) = (1/Z(x)) ∏_{c∈C} ϕc(xc,yc)
with partition function:
Z(x) = ∑_{y∈Y} ∏_{c∈C} ϕc(xc,yc)
The partition function depends on x, so p(y∣x) encodes a different probability distribution for each x. Therefore, a conditional random field effectively instantiates a new Markov Random Field for each input x.
Example
Recognize a word from a sequence of black-and-white character images xi∈[0,1]^{d×d} (pixel matrices of size d×d). The output is a sequence of alphabet letters yi∈{a,b,…,z}.
We could train a classifier to separately predict each yi from its xi. However, since the letters together form a word, the predictions across different positions ought to inform each other.
p(y∣x) is a CRF with two types of factors:
image factors ϕ(xi,yi) for i=1,…,n, which assign higher values to letters yi that are consistent with the input image xi. They can be seen as p(yi∣xi) given by, e.g., standard softmax (logistic) regression.
pairwise factors ϕ(yi,yi+1) for i=1,…,n−1. They can be seen as empirical frequencies of letter co-occurrences obtained from a large corpus of text.
We can jointly infer the structured label y using MAP inference:
argmax_y ϕ(x1,y1) ∏_{i=2}^{n} ϕ(yi−1,yi) ϕ(xi,yi)
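For a chain like this, the argmax can be computed by dynamic programming (a max-product / Viterbi pass). A minimal sketch, assuming tabulated factor values; the toy tables are made up:

```python
import numpy as np

def map_chain(unary, pairwise):
    """MAP assignment for a chain with unary factors phi(x_i, y_i) (one row per
    position, one column per state) and a shared pairwise factor phi(y_{i-1}, y_i)."""
    n, K = unary.shape
    m = np.zeros((n, K))                       # m[i, y] = best score of any prefix ending in y_i = y
    back = np.zeros((n, K), dtype=int)         # back-pointers for decoding
    m[0] = unary[0]
    for i in range(1, n):
        scores = m[i - 1][:, None] * pairwise  # (K, K): previous state x transition
        back[i] = scores.argmax(axis=0)
        m[i] = scores.max(axis=0) * unary[i]
    # Decode by following back-pointers from the best final state.
    y = np.zeros(n, dtype=int)
    y[-1] = m[-1].argmax()
    for i in range(n - 1, 0, -1):
        y[i - 1] = back[i, y[i]]
    return y

# Toy example: 3 positions, 2 states, made-up factor values.
unary = np.array([[2.0, 1.0], [1.0, 1.5], [3.0, 0.5]])
pairwise = np.array([[2.0, 1.0], [1.0, 2.0]])   # favors consecutive labels agreeing
print(map_chain(unary, pairwise))
```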
CRF features
In most practical applications, we assume that the factors ϕc(xc,yc) are of the form:
ϕc(xc,yc) = exp(wc^⊤ fc(xc,yc))
where fc(xc,yc) is an arbitrary set of features describing the compatibility between xc and yc.
In the above example:
f(xi,yi) may be the probability of letter yi produced by logistic regression or a neural network evaluated on pixels xi.
f(yi,yi+1) = I(yi=l1, yi+1=l2) where l1,l2 are two letters of the alphabet. The CRF would then learn weights w that assign higher weight to pairs of consecutive letters (l1,l2) that are frequent in text (useful when yi is ambiguous given xi alone).
CRF features can be arbitrarily complex. We can define a model with factors ϕ(x,yi) = exp(wi^⊤ f(x,yi)) that depend on the entire input x. This does not affect computational performance, since x is observed (it does not change the scope of the factors, just their values).
An MRF would model the joint distribution p(x,y), i.e. it must fit both p(y∣x) and p(x). If we only want to predict y given x, modeling p(x) is unnecessary and can even be disadvantageous:
statistically: there may not be enough data to fit both p(y∣x) and p(x), and since the two share parameters (the same graph), fitting one well may not yield the best parameters for the other
computationally: we would also need to make simplifying assumptions so that p(x) is tractable.
CRFs forgo this assumption and often perform better on prediction tasks.
Factor Graphs
Bipartite graph where one group is the variables (round nodes) in the distribution being modeled and the other group is the factors defined on these variables (square nodes). Useful to see factor dependencies between variables explicitly, and facilitates computation.
Inference: Variable elimination (VE)
Marginal inference: probability of a given variable, e.g. p(y=1) = ∑_{x1} ⋯ ∑_{xn} p(y=1, x1, …, xn)
Maximum a posteriori (MAP) inference: most likely assignment to the variables in the model, possibly conditioned on evidence: max_{x1,…,xn} p(y=1, x1, …, xn)
NP-hard: whether inference is tractable depends on the structure of the graph. If problem is intractable, we use approximate inference methods.
Variable elimination is an exact inference algorithm. Let xi be discrete variables taking k possible values each (also extends to continuous distributions).
Each τ has complexity O(k^2) and we calculate n of them, thus the overall complexity is O(nk^2) (much better than O(k^n)).
Formal Definition
Graphical model as product of factors: p(x1,…,xn)=∏c∈Cϕc(xc)
Variable elimination algorithm (instance of dynamic programming) performs two factor operations:
product: ϕ3(xc) = ϕ1(xc^(1)) × ϕ2(xc^(2)), where xc^(i) denotes an assignment to the variables in the scope of ϕi, e.g. ϕ3(a,b,c) = ϕ1(a,b) × ϕ2(b,c)
marginalization: τ(x)=∑yϕ(x,y). τ is the marginalized factor and does not necessarily correspond to a probability distribution even if ϕ was a CPD.
Example: marginalizing B from a factor ϕ(A,B,C); e.g. the resulting entry for (a1,c1) is 0.25 + 0.08 = 0.33.
Variable elimination requires an ordering over the variables according to which variables will be eliminated (e.g. ordering implied by the DAG, see example).
the ordering affects the running time (as the variables become more coupled)
finding the best ordering is NP-hard
Let the ordering O be fixed. According to O, for each variable Xi:
Multiply all factors Φi containing Xi
Marginalize out Xi to obtain new factor τ
replace factors Φi with τ
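A minimal sketch of the elimination loop, representing factors as dictionaries from assignments to values (the chain example and helper names are illustrative, not from the notes):

```python
import itertools
from functools import reduce

# A factor is (scope, table): scope is a tuple of variable names, table maps
# assignments (tuples of values, aligned with scope) to non-negative numbers.

def factor_product(f1, f2, domains):
    (s1, t1), (s2, t2) = f1, f2
    scope = tuple(dict.fromkeys(s1 + s2))                  # union of scopes, keeping order
    table = {}
    for vals in itertools.product(*(domains[v] for v in scope)):
        assign = dict(zip(scope, vals))
        table[vals] = (t1[tuple(assign[v] for v in s1)]
                       * t2[tuple(assign[v] for v in s2)])
    return scope, table

def marginalize(f, var):
    scope, table = f
    i = scope.index(var)
    new_scope = scope[:i] + scope[i + 1:]
    new_table = {}
    for vals, val in table.items():
        key = vals[:i] + vals[i + 1:]
        new_table[key] = new_table.get(key, 0.0) + val     # sum out `var`
    return new_scope, new_table

def variable_elimination(factors, order, domains):
    """Eliminate the variables in `order`; returns the remaining joint factor."""
    for var in order:
        involved = [f for f in factors if var in f[0]]
        rest = [f for f in factors if var not in f[0]]
        prod = reduce(lambda a, b: factor_product(a, b, domains), involved)
        rest.append(marginalize(prod, var))
        factors = rest
    return reduce(lambda a, b: factor_product(a, b, domains), factors)

# Toy chain p(a)p(b|a)p(c|b): eliminate a then b to obtain p(c).
domains = {"a": [0, 1], "b": [0, 1], "c": [0, 1]}
pa = (("a",), {(0,): 0.6, (1,): 0.4})
pba = (("a", "b"), {(0, 0): 0.9, (0, 1): 0.1, (1, 0): 0.2, (1, 1): 0.8})
pcb = (("b", "c"), {(0, 0): 0.7, (0, 1): 0.3, (1, 0): 0.5, (1, 1): 0.5})
print(variable_elimination([pa, pba, pcb], ["a", "b"], domains))
```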
Running time of Variable Elimination is O(n k^(M+1)) where M is the maximum size (number of variables) of any factor τ formed during the elimination process and n is the number of variables.
Choosing optimal VE ordering is NP-hard. In practice, we resort to heuristics:
min-neighbors: choose variable with fewest dependent variables
min-weight: choose variables to minimize the product of cardinalities of its dependent variables
min-fill: choose vertices to minimize the size of the factor that will be added to the graph
Junction Tree (JT) algorithm
When computing marginals, VE produces many intermediate factors τ as a side-product. These factors are the same as the ones that we need to answer other marginal queries. By caching them, we can answer new marginal queries at no additional cost.
VE and JT algorithm are two flavors of dynamic programming: top-down DP vs bottom-up table filling.
JT first executes two runs of VE to initialize a particular data structure: bottom-up to get root probabilities τ(x)=∑yϕ(x,y); and top-down to get leaf probabilities τ(y)=∑xϕ(x,y). It can then answer marginal queries in O(1) time.
Two variants: belief propagation (BP) (applies to tree-structured graphs) and full junction tree method (general networks).
Belief propagation
Variable elimination as message passing
Consider running VE algorithm on a tree to compute marginal p(xi). Optimal ordering: rooting the tree at xi and iterating in post-order (leaves to root s.t. nodes are visited after their children). This ordering is optimal because the largest clique formed during VE has size 2.
At each step, eliminate xj by computing τk(xk)=∑xjϕ(xk,xj)τj(xj) (where xk is parent of xj). τj(xj) can be thought of as a message sent from xj to xk that summarizes all the information from the subtree rooted at xj.
Say, after computing p(xi) we want to compute p(xk); we would run VE again with xk as root. However, we already computed the messages received by xk when xi was root (since there is only one path connecting two nodes in a tree).
Message passing algorithm
(assumes tree structure)
a node xi sends a message to a neighbor xj whenever it has received messages from all nodes besides xj.
There will always be a node with a message to send, unless all the messages have been sent out.
Since each edge can receive one message in both directions, all messages will have been sent out after 2∣E∣ steps (to get all marginals, we need all incoming messages for every node: bottom-up and then top-down).
Messages are defined as intermediate factors in the VE algorithm
Two variants: sum-product message passing (for marginal inference) and max-product message passing (for MAP inference)
Sum-product message passing
m_{i→j}(xj) = ∑_{xi} ϕ(xi) ϕ(xi,xj) ∏_{l∈N(i)∖j} m_{l→i}(xi)
(i.e. sum over all possible values taken by xi)
After computing all messages, we may answer any marginal query over xi in constant time:
p(xi) ∝ ϕ(xi) ∏_{l∈N(i)} m_{l→i}(xi)
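A minimal sketch for the special case of a chain MRF with unary and pairwise factors (one forward and one backward message pass, then beliefs); the factor tables are made up:

```python
import numpy as np

def chain_marginals(unary, pairwise):
    """Sum-product on a chain MRF with unary factors phi_i(x_i) (rows of `unary`)
    and one shared pairwise factor phi(x_i, x_{i+1}). Returns all node marginals."""
    n, K = unary.shape
    fwd = np.ones((n, K))   # fwd[i] = message from node i-1 into node i
    bwd = np.ones((n, K))   # bwd[i] = message from node i+1 into node i
    for i in range(1, n):
        fwd[i] = (fwd[i - 1] * unary[i - 1]) @ pairwise
    for i in range(n - 2, -1, -1):
        bwd[i] = pairwise @ (bwd[i + 1] * unary[i + 1])
    beliefs = unary * fwd * bwd
    return beliefs / beliefs.sum(axis=1, keepdims=True)   # normalize each marginal

unary = np.array([[2.0, 1.0], [1.0, 1.0], [1.0, 3.0]])
pairwise = np.array([[2.0, 1.0], [1.0, 2.0]])
print(chain_marginals(unary, pairwise))
```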
Sum-product message passing for factor trees
Recall that a factor graph is a bipartite graph with edges going between variables and factors.
As long as there is a factor (or variable) ready to transmit a message to a variable (or factor), send the message as defined above. Thus, each edge eventually carries exactly two messages: one from variable to factor and one from factor to variable.
Max-product message passing
Since MRF factors are non-negative, the max operator distributes over products, just like the sum. (In general, max only distributes over products of non-negative factors: inside the max, two large negative factors can multiply into a large positive value, even though their individual maxima are small.)
Since both problems decompose the same way we can reuse the same approach as for marginal inference (also applies to factor trees).
If we also want the argmax, we can keep back-pointers during the optimization procedure.
Junction tree algorithm
If the graph is not a tree, message passing as described above no longer applies, and exact inference may be intractable.
However, we can partition the graph into a tree of clusters that are amenable to the variable elimination algorithm.
Then we simply perform message-passing on this tree.
Within a cluster, variables could be highly coupled but interactions among clusters will have a tree structure.
yields tractable global solutions if the local problems can be solved exactly.
Suppose we have undirected graphical model G (if directed, take the moralized graph). A junction tree T=(C,ET) over G=(χ,EG) is a tree whose nodes c∈C are associated with subsets xc⊆χ of the graph vertices (i.e. sets of variables). Must satisfy:
family preservation: for each factor ϕ there is a cluster c such that scope[ϕ] ⊆ xc (the cluster contains all the variables that ϕ takes as input)
running intersection: for every pair of clusters c(i),c(j), every cluster on the path between c(i),c(j) contains xc(i)∩xc(j)
Example of MRF with graph G and junction tree T:
MRF potentials (i.e. product of the factors) are denoted using different colors
circles are nodes of the junction tree
rectangles are sepsets (separation sets) (= sets of variables shared by neighboring clusters)
Trivial junction tree: one node containing all the variables in G (useless, since it amounts to brute-force marginalization).
Optimal trees make the clusters as small and modular as possible; unfortunately, finding the optimal tree is NP-hard.
Special case: when G itself is a tree, define a cluster for each edge.
Example of invalid junction tree that does not satisfy running intersection property (green cluster should contain intersection of red and purple):
Algorithm
Let us define potential ψc(xc) of each cluster c as the product of all the factors ϕ in G that have been assigned to c.
By family preservation property, this is well defined and we may assume distribution in the form:
p(x1,…,xn) = (1/Z) ∏_{c∈C} ψc(xc)
At each step of the algorithm, we choose a pair of adjacent clusters c^(i), c^(j) in T and compute a message whose scope is the sepset S_{ij} between the two clusters: m_{i→j}(S_{ij}) = ∑_{x_{c^(i)} ∖ S_{ij}} ψ_{c^(i)}(x_{c^(i)}) ∏_{l∈N(i)∖j} m_{l→i}(S_{li}).
We choose c(i),c(j) only if c(i) has received messages from all of its neighbors except c(j). Terminates in 2∣ET∣ steps just as in belief propagation (bottom-up then top-down to get all messages).
We then define belief of each cluster based on all the messages that it receives:
β_{c^(i)}(x_{c^(i)}) = ψ_{c^(i)}(x_{c^(i)}) ∏_{l∈N(i)} m_{l→i}(S_{li})
(often referred to as Shafer-Shenoy updates)
Beliefs will be proportional to the marginal probabilities over their scopes: βc(xc)∝p(xc).
We answer queries of the form p~(x),x∈xc by marginalizing out the variable in its belief:
p~(x)=∑xc∖xβc(xc)
(requires brute force sum over all variables in xc)
Normalize by partition function Z=∑xcβc(xc) (sum of all the beliefs in a cluster).
Running time is exponential in the size of the largest cluster (because we need to marginalize out variables from the cluster; must be done using brute force).
Variable elimination over a junction tree
Running VE with a certain ordering is equivalent to performing message passing on the junction tree.
We may prove correctness of the JT algorithm through an induction argument on the number of factors ψ. The key property that makes this argument possible is the running intersection property (assures that it's safe to eliminate a variable from a leaf cluster that is not found in that cluster's sepset since it cannot occur anywhere except that one cluster)
The caching approach used for belief propagation extends to junction trees.
Finding a good junction tree
by hand: model has regular structure for which there is an obvious solution (e.g. when model is a grid, clusters are pairs of adjacent rows)
using variable elimination: running the VE elimination algorithm implicitly generates a junction tree over the variables. Thus, it is possible to use the heuristics discussed earlier
Loopy belief propagation
Technique for performing inference on complex (non-tree) graphs.
Running time of JT is exponential in size of the largest cluster. In some cases, we can give quick approximate solution.
Suppose we are given an MRF with pairwise potentials. The main idea is to disregard loops and perform message passing anyway, using the updates m^{t+1}_{i→j}(xj) = ∑_{xi} ϕ(xi) ϕ(xi,xj) ∏_{l∈N(i)∖j} m^t_{l→i}(xi).
Keep performing these updates for a fixed number of steps or until the messages converge. All messages are typically initialized with the uniform distribution.
Performs surprisingly well in practice
Provably converges on trees and on graphs with at most one cycle
However it may not converge and beliefs may not necessarily be equal to true marginals
special case of variational inference algorithm
MAP Inference
max_x log p(x) = max_x ∑_c log ϕc(xc) − log Z
The intractable partition function Z does not depend on x and can be ignored for MAP.
Marginal inference sums over all assignments, one of which is the MAP assignment.
For MAP we could replace the summation with maximization; however, more efficient methods exist.
Many intractable problems arise as special cases of MAP inference.
E.g. 3-SAT: for each clause c=(x∨y∨¬z), define a factor θc(x,y,z)=1 if c is satisfied and 0 otherwise. The 3-SAT instance is satisfiable iff the value of the MAP assignment equals the number of clauses.
We may use a similar construction to prove that marginal inference is NP-hard: add an additional variable X that equals 1 when all clauses are satisfied and 0 otherwise. Its marginal probability will be > 0 iff the 3-SAT instance is satisfiable.
Example: image segmentation (input x∈[0,1]^{d×d}; predict a label y∈{0,1}^{d×d} indicating whether each pixel belongs to the object we want to recover). Intuitively, neighboring pixels should have similar values. This prior knowledge can be modeled via an Ising model (see box 4.C p. 128 in Koller and Friedman).
Graph cuts
See Koller and Friedman 13.6
Efficient MAP inference algorithm for certain Potts models over binary-valued variables. Returns optimal solution in polynomial time regardless of structural complexity of the underlying graph.
A graph cut of undirected graph is a partition of nodes into 2 disjoint sets Vs, Vt. Let each edge be associated with a non-negative cost. Cost of a graph cut is the sum of the costs of the edges that cross between the two partitions:
cost(Vs,Vt) = ∑_{v1∈Vs, v2∈Vt} cost(v1,v2)
min-cut problem is finding the partition that minimizes the cost of the graph cut. See algorithms textbooks for details.
Reduction of MAP inference on a particular class of MRFs to the min-cut problem
See Metric MRFs model box 4.D p. 127 in Koller and Friedman
MRF over binary variables with pairwise factors in which edge energies (i.e., negative log-edge factors) take the form:
E_{uv}(xu,xv) = 0 if xu = xv, and λuv (the cost of the edge mismatch) if xu ≠ xv
Each node has a unary potential described by an energy function Eu(xu), normalized by subtracting min_x Eu(x) so that Eu ≥ 0 and either Eu(1)=0 or Eu(0)=0.
p(x) = (1/Z) exp(−[∑_u Eu(xu) + ∑_{(u,v)∈E} E_{uv}(xu,xv)])
Motivation comes from image segmentation: reduce discordance between adjacent variables.
Formulate as a min-cut problem in augmented graph: add special source and sink nodes s,t
s represents the object (assignment of 0) and t the background (assignment of 1)
remember that either Eu(1)=0 or Eu(0)=0 since we normalized
node s is connected to nodes u with Eu(0)=0 by an edge with weight Eu(1)
node t is connected to nodes v with Ev(1)=0 by edge with weight Ev(0)
thus every node is either connected to the source or the sink
an edge (u,v) only contributes to the cost of the cut if u and v are on opposite sides, and its contribution is then the λuv of the edge.
The cost of the cut (unary potentials and edge contributions) is precisely the energy of the assignment.
Segmentation task in a 2x2 MRF as a graph cut problem:
Refer to Koller and Friedman textbook for more general models with submodular edge potentials.
Linear programming approach
Graph-cut only applicable in restricted classes of MRFs.
Linear Programming (LP)
Linear programming (a.k.a. linear optimization) refers to problems of the form:
min_x c⋅x subject to Ax ≤ b
with x, c ∈ R^n, b ∈ R^m and A ∈ R^{m×n}.
Has been extensively studied since the 1930s
Major breakthrough of applied mathematics in the 1980s was the development of polynomial-time algorithms for linear programming
Practical tools like CPLEX that can solve very large LP instances (100,000 variables or more) in reasonable time.
Integer Linear Programming (ILP)
extension of linear programming which also requires x∈{0,1}n
NP-hard in general
Nonetheless, many heuristics exist such as rounding:
relaxed constraint 0≤x≤1 then round solution
works surprisingly well in practice and has theoretical guarantees for some classes of ILPs
Formulating MAP inference as ILP
Introduce two types of indicator variables:
μi(xi) for each i∈V and state xi
μij(xi,xj) for each edge (i,j)∈E and pair of states xi,xj
Suppose we can find dual variables δ̄ such that the local maximizers of θ̄_i^δ̄(xi) and θ̄_f^δ̄(xf) agree; in other words, we can find an x̄ such that x̄i ∈ argmax_{xi} θ̄_i^δ̄(xi) and x̄f ∈ argmax_{xf} θ̄_f^δ̄(xf). Then we have that:
The second equality follows because the terms involving the Lagrange multipliers cancel out when x̄ and x̄f agree.
On the other hand, by definition of p∗:
∑_{i∈V} θi(x̄i) + ∑_{f∈F} θf(x̄f) ≤ p∗ ≤ L(δ̄)
which implies that L(δˉ)=p∗.
Thus:
bound given by Lagrangian can be made tight for the right choice of δ
we can compute p∗ by finding a δ at which the local sub-problems agree with each other
Minimizing the objective
L(δ) is continuous and convex (a point-wise max of a set of affine functions, see EE227C), so we may minimize it using subgradient descent or block coordinate descent (faster). The objective is not strongly convex, so the minimizer is not necessarily unique.
Recovering MAP assignment
As shown above, if a solution x,xf agrees for some δ, it is optimal.
If each θ̄_i^{δ∗} has a unique maximum, the problem is decodable. If some variables do not have a unique maximum, we fix the variables that can be uniquely decoded to their optimal values and use exact inference to find the values of the remaining variables. (NP-hard in general, but usually not a big problem in practice.)
How can we decouple the variables from each other? Isn't that impossible due to the edge costs?
Local search
Start with an arbitrary assignment and perform "moves" on the joint assignment that locally increase the probability. No guarantees, but prior knowledge can make the moves effective.
Branch and bound
Exhaustive search over the space of assignments, while pruning branches that can be provably shown not to contain a MAP assignment (similar to backtracking). The LP relaxation or its dual can be used to obtain upper bounds and prune branches.
Simulated annealing
Read more about this
Use sampling methods (e.g. Metropolis-Hastings) to sample from:
pt(x) ∝ exp((1/t) ∑_{c∈C} θc(xc))
t is called the temperature:
As t→∞, pt is close to the uniform distribution, which is easy to sample from
As t→0, pt places more weight on argmaxx∑c∈Cθc(xc) (quantity we want to recover). However, since the distribution is highly peaked, it is difficult to sample from.
Idea of simulated annealing is to run sampling algorithm starting with high t and gradually decrease it. If "cooling rate" is sufficiently slow (requires lots of tuning), we are guaranteed to find the mode of the distribution.
Sampling methods
Interesting classes of models may not admit exact polynomial-time solutions at all.
Two families of approximate algorithms:
variational methods (take their name from calculus of variations = optimizing functions that take other functions as arguments): formulate inference as an optimization problem
sampling methods: historically the main way of performing approximate inference, before variational methods emerged over the past 15 years as viable and often superior alternatives
Forward (aka ancestral) sampling
Works for a Bayesian network with multinomial variables. Sample the variables in topological order:
start by sampling variables with no parents
then sample from the next generation by conditioning on the values sampled before
proceed until all n variables have been sampled
linear O(n) time by taking exactly 1 multinomial sample from each CPD
Assume a multinomial distribution with k outcomes and associated probabilities θ1,…,θk. Subdivide the unit interval into k regions, where region i has size θi, 1≤i≤k, and sample uniformly from [0,1] (a sketch follows):
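A minimal sketch of this inverse-CDF sampling scheme (θ and the sample size are illustrative):

```python
import numpy as np

def sample_categorical(theta, rng):
    """Sample one outcome by subdividing [0, 1] into k regions of size theta_i
    and seeing where a uniform draw lands."""
    u = rng.uniform()
    cumulative = np.cumsum(theta)
    return int(np.searchsorted(cumulative, u))

rng = np.random.default_rng(0)
theta = [0.2, 0.5, 0.3]
counts = np.bincount([sample_categorical(theta, rng) for _ in range(10000)], minlength=3)
print(counts / counts.sum())   # should be close to theta
```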
Forward sampling can also be performed on undirected models if the model can be represented by a clique tree with a small number of variables per node:
calibrate clique tree; gives us marginal distribution over each node and choose node to be root
marginalize over variables in the root node to get the marginal for a single variable
sample from each variable in the node, each time incorporating the newly sampled values as evidence, e.g. x3 ∼ p(X3=x3 ∣ X1=x1, X2=x2)
move down the tree to sample from other nodes and send updated message containing values of the sampled variables.
Monte Carlo estimation
Name refers to famous casino in Monaco. Term was coined as codeword by physicists working on the atomic bomb as part of the Manhattan project.
Constructs solutions based on a large number of samples. E.g., consider approximating the following expectation:
E_{x∼p}[f(x)] ≈ I_T = (1/T) ∑_{t=1}^{T} f(x^t), where x^1,…,x^T ∼ p are i.i.d. samples.
Since samples are i.i.d., MC estimate is unbiased and variance is inversely proportional to T.
Rejection sampling
Special case of Monte Carlo integration. Compute area of a region R by sampling in a larger region with known area and recording the fraction of samples that falls within R.
E.g., for a Bayesian network over a set of variables X = Z ∪ E, we can use rejection sampling to compute marginal probabilities p(E=e): draw complete samples by forward sampling and estimate p(E=e) as the fraction of samples consistent with the evidence e (a sketch follows).
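A minimal sketch on a hypothetical two-variable network Z → E with made-up CPDs:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical Bayesian network Z -> E with made-up CPDs.
p_z = np.array([0.7, 0.3])                    # p(Z)
p_e_given_z = np.array([[0.9, 0.1],           # p(E | Z=0)
                        [0.4, 0.6]])          # p(E | Z=1)

def forward_sample():
    z = rng.choice(2, p=p_z)
    e = rng.choice(2, p=p_e_given_z[z])
    return z, e

# Estimate p(E = 1) as the fraction of forward samples that match the evidence.
T = 100_000
matches = sum(1 for _ in range(T) if forward_sample()[1] == 1)
print(matches / T)                            # exact value: 0.7*0.1 + 0.3*0.6 = 0.25
```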
Importance sampling
Suppose we want to compute μ=E[f(X)] where f(x) is nearly zero outside a region A for which P(X∈A) is small. A plain Monte Carlo sample would be very wasteful and could fail to have even one point inside the region A.
Let q be a known probability density function (the importance sampling distribution), and estimate μ̂_q = (1/n) ∑_{i=1}^{n} f(x_i) p(x_i)/q(x_i) with x_i ∼ q.
We have E_q(μ̂_q) = μ (an unbiased estimate) and: Var_q(μ̂_q) = Var((1/n) ∑ f p/q) = σ_q²/n
where σ_q² = ∫ (f(x)p(x))²/q(x) dx − μ² = ∫ (f(x)p(x) − μ q(x))²/q(x) dx
Therefore, the numerator is small when q is nearly proportional to fp. Small values of q greatly magnify whatever lack of proportionality appears in the numerator. It is good for q to have spikes in the same places that fp does.
Having q follow a uniform distribution collapses to plain Monte Carlo.
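A minimal sketch comparing plain Monte Carlo with importance sampling for a tail-probability expectation under a standard normal p; the choice q = N(5, 1) is an illustrative assumption:

```python
import numpy as np

rng = np.random.default_rng(0)

# Estimate mu = E_{x~p}[f(x)] with p = N(0, 1) and f concentrated far in the tail.
f = lambda x: (x > 4.0).astype(float)                 # indicator of a rare region
p_pdf = lambda x: np.exp(-x**2 / 2) / np.sqrt(2 * np.pi)
# Importance distribution q = N(5, 1): spikes roughly where f*p does.
q_pdf = lambda x: np.exp(-(x - 5.0)**2 / 2) / np.sqrt(2 * np.pi)

T = 100_000
# Plain Monte Carlo: most samples contribute nothing.
x_p = rng.normal(0.0, 1.0, T)
plain = f(x_p).mean()
# Importance sampling: sample from q, reweight each sample by p/q.
x_q = rng.normal(5.0, 1.0, T)
weighted = (f(x_q) * p_pdf(x_q) / q_pdf(x_q)).mean()
print(plain, weighted)   # true value P(x > 4) is about 3.2e-5
```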
Normalized importance sampling
When estimating a ratio of two probabilities (e.g. a conditional probability), if we estimate the numerator and denominator with separate samples, the errors compound and the variance can be high. If we use the same samples for both, the estimator is biased but asymptotically unbiased, and we avoid the issue of compounding errors.
Markov chain Monte Carlo
MCMC is used to perform marginal and MAP inference, as opposed to computing expectations.
We construct a Markov chain whose states are joint assignments of the variables and whose stationary distribution equals the model probability p.
Run the Markov chain from an initial state for B burn-in steps (the number of steps needed to converge to the stationary distribution; see mixing time).
Run the Markov chain for N sampling steps.
We produce Monte Carlo estimates of marginal probabilities. We then take the sample with highest probability to perform MAP inference.
Two algorithms:
Metropolis-Hastings
Gibbs sampling (special case of Metropolis-Hastings)
For distributions that have narrow modes, the chain will, with high probability, keep sampling from a single mode for a long time, so convergence will be slow.
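A minimal Gibbs-sampling sketch for the voting MRF from the earlier example: each update resamples one variable from its conditional given its Markov blanket (its two neighbors on the ring); the burn-in and sample counts are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)

def phi(x, y):
    return 10.0 if x == y == 1 else 5.0 if x == y == 0 else 1.0

neighbors = {0: (3, 1), 1: (0, 2), 2: (1, 3), 3: (2, 0)}   # ring A-B-C-D

def gibbs(num_samples, burn_in=1000):
    x = rng.integers(0, 2, size=4)          # arbitrary initial assignment
    samples = []
    for t in range(burn_in + num_samples):
        for i in range(4):
            l, r = neighbors[i]
            # Unnormalized conditional p(x_i | Markov blanket): product of the factors touching x_i.
            w = np.array([phi(v, x[l]) * phi(v, x[r]) for v in (0, 1)])
            x[i] = rng.choice(2, p=w / w.sum())
        if t >= burn_in:
            samples.append(x.copy())
    return np.array(samples)

samples = gibbs(5000)
print(samples.mean(axis=0))   # Monte Carlo estimates of the marginals p(X_i = 1)
```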
Variational inference
Inference as an optimization problem: given an intractable distribution p and a class of tractable distributions Q, find the q∈Q that is most similar to p.
Unlike sampling-based methods:
variational approaches will not find globally optimal solution
but we always know if they have converged and even have bounds on their accuracy
scale better and more amenable to techniques like stochastic gradient optimization, parallelization over multiple processors, acceleration using GPUs
Kullback-Leibler divergence
Need to choose approximation family Q and optimization objective J(q) that captures similarity between q and p. Information theory provides us with Kullback-Leibler (KL) divergence.
For two distributions with discrete support:
KL(q∥p) = ∑_x q(x) log (q(x)/p(x))
with properties:
KL(q∥p) ≥ 0 for all q, p
KL(q∥p)=0 iff q=p
It is however asymmetric: KL(q∥p) ≠ KL(p∥q) in general. That is why it is called a divergence and not a distance.
Variational lower bound
We use the unnormalized distribution p̃ instead of p, because evaluating KL(q∥p) directly is intractable due to the normalization constant.
−J(q) is called the variational lower bound or evidence lower bound (ELBO) and often written in the form:
log Z(θ) ≥ E_{q(x)}[log p̃(x) − log q(x)]
E.g.: If we are trying to compute marginal probability p(x∣D)=p(x,D)/p(D), minimizing J(q) amounts to maximizing a lower bound on log-likelihood logp(D) of the observed data.
By maximizing the evidence lower bound, we are minimizing KL(q∥p), which is exactly the gap between −J(q) and log Z(θ).
KL(q∥p) or KL(p∥q)?
Computationally, computing KL(p∥q) involves an expectation with respect to p which is typically intractable to evaluate
KL(q∥p) is called the I-projection (information projection) and is infinite when p(x)=0 and q(x)>0. Therefore, if p(x)=0 we must have q(x)=0. We say that KL(q∥p) is zero-forcing for q and that it under-estimates the support of p. It is called the exclusive KL divergence.
KL(p∥q) is called the M-projection (moment projection) and is infinite if q(x)=0 and p(x)>0, thus if p(x)>0 we must have q(x)>0. KL(p∥q) is zero-avoiding for q and it over-estimates the support of p. It is called the inclusive KL divergence.
E.g.: fitting unimodal approximating distribution (red) to multimodal (blue). KL(p∥q) leads to a) and KL(q∥p) leads to b) and c).
Mean-field inference
How to choose approximating family Q?
One of the most widely used classes is the set of fully factored distributions q(x) = q1(x1) q2(x2) ⋯ qn(xn), where each qi is a categorical distribution over a single discrete variable. This is called mean-field inference and consists of solving:
min_{q1,…,qn} J(q)
via coordinate descent over the qj.
For one coordinate qj, the optimization has a closed-form solution: log qj(xj) ← E_{q−j}[log p̃(x)] + const, where q−j denotes the product of all the marginals except qj.
To compute E_{q−j}[log p̃(x)], we only need the factors belonging to the Markov blanket of xj. If the variables are discrete with K possible values, and there are F factors and N variables in the Markov blanket of xj, then computing the expectation takes O(K·F·K^N) time: for each value of xj we sum over all K^N assignments of the N variables and, in each case, sum over the F factors.
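A minimal mean-field sketch for the pairwise voting MRF from earlier: coordinate ascent repeatedly sets each qi from the expected log-factors over its Markov blanket (the sweep count is illustrative):

```python
import numpy as np

def phi(x, y):
    return 10.0 if x == y == 1 else 5.0 if x == y == 0 else 1.0

log_phi = np.log([[phi(a, b) for b in (0, 1)] for a in (0, 1)])
neighbors = {0: (3, 1), 1: (0, 2), 2: (1, 3), 3: (2, 0)}   # voting MRF ring from earlier

# Fully factored approximation q(x) = prod_i q_i(x_i); q[i] is a length-2 vector.
q = np.full((4, 2), 0.5)
for _ in range(50):                          # coordinate ascent sweeps
    for i in range(4):
        # E_{q_{-i}}[log p~(x)], keeping only the factors in x_i's Markov blanket.
        logits = sum(q[j] @ log_phi.T for j in neighbors[i])
        q[i] = np.exp(logits - logits.max())
        q[i] /= q[i].sum()
print(q)   # approximate singleton marginals
```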
Learning in directed models
Two different learning tasks:
parameter learning, graph structure is known and we want to estimate the factors
structure learning, estimate the graph
Possible use cases:
density estimation: want full distribution so that we can compute conditional probabilities
specific prediction task
structure or knowledge discovery (interested in the model itself)
Maximum likelihood
We want to construct p as close as possible to p∗ in order to perform density estimation, measuring closeness with the KL divergence: KL(p∗∥p) = E_{x∼p∗}[log (p∗(x)/p(x))] = −H(p∗) − E_{x∼p∗}[log p(x)].
Hence minimizing KL divergence is equivalent to maximizing expected log-likelihood Ex∼p∗[logp(x)] (p must assign high probability to instances sampled from p∗)
Since we do not know p∗, we use a Monte-Carlo estimate of the expected log-likelihood and maximum likelihood learning is defined as:
max_{p∈M} (1/|D|) ∑_{x∈D} log p(x), where M is the family of models
Likelihood, Loss and Risk
Minimize expected loss or risk: Ex∼p∗L(x,p)
Loss that corresponds to maximum likelihood is the log loss: −logp(x)
For CRFs we use conditional log likelihood −logp(y∣x)
For prediction, we can use the classification error E_{(x,y)∼p∗}[ I{∃ y′ ≠ y : p(y′∣x) ≥ p(y∣x)} ] (the probability of predicting the wrong assignment).
A better choice might be the Hamming loss (fraction of variables whose MAP assignment differs from the ground truth).
There is also work on generalizing the hinge loss (from SVMs) to CRFs, which leads to structured support vector machines.
Maximum likelihood learning in Bayesian networks
Given Bayesian network p(x)=∏i=1nθxi∣xparents(i) and i.i.d. samples D={x(1),…,x(m)} (θ parameters are the conditional probabilities)
Likelihood is L(θ,D)=∏i=1n∏j=1mθxi(j)∣xparents(i)(j)
Taking the log and grouping identical values: log L(θ,D) = ∑_{i=1}^{n} ∑_{x_parents(i)} ∑_{xi} #(xi, x_parents(i)) · log θ_{xi∣x_parents(i)}, where #(xi, x_parents(i)) is the number of samples with that configuration. The maximum likelihood estimate is then the empirical conditional frequency: θ*_{xi∣x_parents(i)} = #(xi, x_parents(i)) / #(x_parents(i)).
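A minimal sketch of this counting-based MLE on a hypothetical two-variable network A → B (the dataset is made up):

```python
from collections import Counter

# Hypothetical network A -> B; D is a list of observed (a, b) samples.
D = [(0, 0), (0, 0), (0, 1), (1, 1), (1, 1), (1, 0), (0, 0), (1, 1)]

# theta_{b | a} = #(a, b) / #(a): empirical conditional frequencies maximize the likelihood.
joint = Counter(D)
parent = Counter(a for a, _ in D)
theta_b_given_a = {(a, b): joint[(a, b)] / parent[a] for (a, b) in joint}
theta_a = {a: parent[a] / len(D) for a in parent}
print(theta_a, theta_b_given_a)
```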
We can write an MRF in exponential-family form p(x;θ) = exp(θ^⊤ f(x)) / Z(θ), where f(x) is a vector of indicator functions and θ is the set of all model parameters, defined by the log ϕc(x′c).
Partition function Z(θ) equals to one for Bayesian networks but MRFs do not make this assumption.
Exponential families
See Stanford CS229 course
exponential families are log-concave in their natural parameters θ
f(x) is called the vector of sufficient statistics. E.g., if p is Gaussian, f(x) = (x, x²), from which the mean and variance can be recovered.
exponential families make the fewest unnecessary assumptions about the data distribution. Formally, the distribution maximizing the entropy H(p) under the constraint E_p[ϕ(x)]=α is in the exponential family.
admit conjugate priors which makes them applicable in variational inference
Maximum likelihood learning of MRFs
Since we are working with an exponential family, the maximum likelihood objective is concave.
The Hessian of log Z(θ) is the covariance matrix of the sufficient statistics f(x) under the model; covariance matrices are always positive semi-definite, which is why log Z(θ) is convex.
Usually non-convexity is what makes optimization intractable, but in this case it is the computation of the gradient (which requires inference) that is intractable.
Approximate learning techniques
Maximum-likelihood learning reduces to repeatedly using inference to compute the gradient and then updating the model weights using gradient ascent.
Gibbs sampling from the distribution at each step of gradient ascent; then approximate the gradient using Monte Carlo.
persistent contrastive divergence (used to train Restricted Boltzmann Machines; see link) which re-uses same Markov Chain between iterations (since model has changed very little)
Pseudo-likelihood
For each example x, the pseudo-likelihood makes the approximation log p(x;θ) ≈ ∑_i log p(xi ∣ x_{N(i)}; θ), where xi is the i-th variable in x and N(i) is the Markov blanket of i (i.e. its neighbors). Since each term involves only one variable, we only need to sum over the values of one variable to compute its partition function (tractable).
Pseudo-likelihood objective assumes that xi depends mainly on its neighbors in the graph.
Pseudo-likelihood converges towards true likelihood as number of data points increases.
Moment matching
Recall the log-likelihood of an MRF: (1/|D|) log p(D;θ) = (1/|D|) ∑_{x∈D} θ^⊤ f(x) − log Z(θ)
Taking the gradient (see the intro of the chapter for details): (1/|D|) ∑_{x∈D} f(x) − E_{x∼p}[f(x)]
This is precisely the difference between the expectations of the sufficient statistics under the empirical (data) distribution and under the model distribution.
f is a vector of indicator functions for the variables of a clique: one entry equals I(xc = x̄c) for some fixed assignment x̄c.
The log-likelihood objective forces model marginals to match empirical marginals. This property is called moment matching.
When minimizing inclusive KL-divergence KL(p∥q), minimizer q∗ will match the moments of the sufficient statistics to the corresponding moments of p.
MLE estimate is minimizer of KL(p^∥q) where p^ is the empirical distribution of the data. In variational inference, minimization over q in smaller set of distributions Q is known as M-projection ("moment projection").
Learning in conditional random fields
p(y∣x) = (1/Z(x,φ)) ∏_{c∈C} ϕc(yc, x; φ)
where: Z(x,φ) = ∑_{y1,…,yn} ∏_{c∈C} ϕc(yc, x; φ)
We can reparameterize it as we did for MRFs: p(y∣x) = exp(θ^⊤ f(x,y)) / Z(x,θ)
The log-likelihood given a dataset D is: (1/|D|) log p(D;θ) = (1/|D|) ∑_{(x,y)∈D} θ^⊤ f(x,y) − (1/|D|) ∑_{x∈D} log Z(x,θ)
There is a different partition function for each data point since it is dependent on x.
The gradient is now (1/|D|) ∑_{(x,y)∈D} f(x,y) − (1/|D|) ∑_{x∈D} E_{y∼p(y∣x)}[f(x,y)]
And the Hessian is the covariance matrix covy∼p(y∣x)[f(x,y)]
Conditional log-likelihood is still concave, but computing the gradient now requires one inference per training data point, therefore gradient ascent is more expensive than for MRFs.
One should try to limit the number of variables or make sure that the model's graph is not too densely connected.
A popular objective for training CRFs is the max-margin loss, a generalization of the objective for training SVMs. Models trained with this loss are called structured support vector machines or max-margin networks. This loss is widely used in practice, in part because it only requires MAP inference rather than general (e.g. marginal) inference.
Learning in latent variables models (LVMs)
LVMs enable us to leverage prior knowledge. E.g., a language model of news articles: we know our set of news articles is a mixture of K distinct distributions (one per topic). Let x be an article and t a topic (an unobserved variable); we may build a more accurate model p(x∣t)p(t), learning a separate p(x∣t) for each topic rather than trying to model everything with one p(x). However, since t is unobserved, we cannot directly use the previous learning methods.
LVMs also increase the expressive power of the model.
Formally, a latent variable model is a probability distribution p(x, z; θ) over two sets of variables x, z, where the x variables are observed at learning time and the z variables are latent (never observed).
Whereas a single exponential family distribution has concave log-likelihood, the log of a weighted mixture of such distributions is no longer concave or convex.
The class website notes use the fact that an exponential family is entirely described by its sufficient statistics to derive the optimal assignment of μ and Σ in the GMM example.
EM as variational inference
Why does EM converge?
Consider posterior inference problem for p(z∣x). We apply our variational inference framework by taking p(x,z) to be the unnormalized distribution; in that case, p(x) will be the normalization constant (maximizing it maximizes the likelihood).
Recall that variational inference maximizes the evidence lower bound (ELBO):
L(p,q)=Eq(z)[logp(x,z;θ)−logq(z)]
over distributions q. The ELBO satisfies the equation
log p(x;θ) = KL(q(z) ∥ p(z∣x;θ)) + L(p,q).
Hence, L(p,q) is maximized when q=p(z∣x); in that case the KL term becomes zero and the lower bound is tight: logp(x;θ)=L(p,q).
The EM algorithm can be seen as iteratively optimizing the ELBO over q (at the E step) and over θ (at the M) step.
Starting at some θt, we compute the posterior p(z∣x;θt) at the E step. We evaluate the ELBO for q = p(z∣x;θt); this makes the ELBO tight: log p(x;θt) = E_{p(z∣x;θt)}[log p(x,z;θt) − log p(z∣x;θt)]. The M step then maximizes E_{p(z∣x;θt)}[log p(x,z;θ)] over θ.
This is precisely the optimization problem solved at the M step of EM (in the above equation, there is an additive constant independent of θ).
Solving this problem increases the ELBO. However, since we fixed q to p(z∣x;θt), the ELBO evaluated at the new θt+1 is no longer tight. But since the ELBO was equal to logp(x;θt) before optimization, we know that the true log-likelihood logp(x;θt+1) must have increased.
Every step increases the marginal likelihood logp(x;θt), which is what we wanted to show.
Since the marginal likelihood is upper-bounded by its true global maximum, EM must eventually converge (however objective is non-convex so we have no guarantee to find the global optimum; heavily dependent on initial θ0)
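A minimal EM sketch for a 1-D Gaussian mixture, alternating the E step (posterior responsibilities q(z∣x)) and the M step (re-estimating θ from the expected complete log-likelihood); the data and initialization are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(-2, 1, 200), rng.normal(3, 1, 200)])   # toy data

K = 2
pi, mu, sigma = np.full(K, 1 / K), np.array([-1.0, 1.0]), np.ones(K)
for _ in range(100):
    # E step: responsibilities r[n, k] = p(z = k | x_n; theta_t).
    log_lik = (-0.5 * ((x[:, None] - mu) / sigma) ** 2
               - np.log(sigma) - 0.5 * np.log(2 * np.pi) + np.log(pi))
    r = np.exp(log_lik - log_lik.max(axis=1, keepdims=True))
    r /= r.sum(axis=1, keepdims=True)
    # M step: re-estimate the parameters from the weighted data.
    Nk = r.sum(axis=0)
    pi = Nk / len(x)
    mu = (r * x[:, None]).sum(axis=0) / Nk
    sigma = np.sqrt((r * (x[:, None] - mu) ** 2).sum(axis=0) / Nk)
print(pi, mu, sigma)
```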
Bayesian learning
Example of limits of maximum likelihood estimation:
we're not taking into account the confidence in our estimate considering the sample size
no prior distribution of the parameters: e.g. for out of vocabulary words, their probability will be zero.
In Bayesian learning both observed variables X and parameters θ are random variables. We're taking uncertainty over the parameters into account.
Prior distribution p(θ) encodes our initial beliefs. Choice of prior is subjective.
Useful when we want to provide uncertainty estimates about model parameters or when we encounter out of sample data.
Conjugate priors
For some choices of the prior p(θ), the posterior p(θ∣D) can be computed in closed form.
Suppose:
P(θ) = Beta(θ ∣ αH, αT), where αH, αT control the shape of the Beta distribution
Then:
P(θ∣D) = Beta(θ ∣ αH + NH, αT + NT), where NH and NT are the numbers of observed heads and tails
No need to compute: p(θ∣D)=p(D∣θ)p(θ)/p(D) where the integral p(D)=∫θp(D∣θ)p(θ)dθ is intractable.
Beta distribution
p(θ) = Beta(θ ∣ αH, αT) = θ^(αH−1) (1−θ)^(αT−1) / B(αH, αT)
where B is a normalization constant.
If αH>αT, we believe heads are more likely:
E[θ] = αH / (αH + αT)
As αH or αT increase, variance decreases, i.e. we are more certain about the value of θ:
Var(θ) = αH αT / ((αH + αT)² (αH + αT + 1))
The Beta distribution family is a conjugate prior to the Bernoulli distribution family, i.e. if p(θ) is a Beta distribution and p(X∣θ) is a Bernoulli distribution then p(θ∣D=(X1,…,XN)) is still a Beta distribution.
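A minimal sketch of the conjugate update for a coin: the posterior just adds the observed counts to the prior pseudo-counts (numbers are illustrative):

```python
# Conjugate update for a coin: prior Beta(alpha_H, alpha_T), data = observed flips.
alpha_H, alpha_T = 2.0, 2.0                 # prior pseudo-counts
flips = [1, 1, 0, 1, 1, 1, 0, 1]            # 1 = heads, 0 = tails
N_H, N_T = sum(flips), len(flips) - sum(flips)

post_H, post_T = alpha_H + N_H, alpha_T + N_T     # posterior is Beta(alpha_H + N_H, alpha_T + N_T)
posterior_mean = post_H / (post_H + post_T)
mle = N_H / len(flips)
print(posterior_mean, mle)   # the posterior mean is pulled toward the prior mean 0.5
```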
Categorical Distribution
Parameter of the categorical distribution:
θ=(θ1,…,θK):=(P(X=1),…,P(X=K))
where θ sums to 1.
The Dirichlet distribution is the conjugate prior for the categorical distribution. It is defined by a parameter α=(α1,…,αK) and its pdf is p(θ) ∝ ∏_{k=1}^{K} θk^(αk−1), normalized by the multivariate Beta function B(α).
Used in topic modeling (e.g. latent Dirichlet allocation).
Limits of conjugate priors
restricts the kind of priors we can use
for more complex distributions, posterior can still be NP hard to compute
Practitioners should compare it with other tools such as MCMC or variational inference.
Structure learning for Bayesian networks
Learn structure of Directed Acyclic Graph from data. Two approaches: score-based and constraint-based.
Score-based approach
Search over space of DAGs to maximize score metric.
Score metric: penalized log-likelihood
score(G:D)=LL(G:D)−ϕ(∣D∣)∥G∥
LL(G:D): log-likelihood of the data D under graph structure G
the parameters of G are estimated via MLE, and the log-likelihood score is computed using those parameters
if we only considered log-likelihood, we would end up with a complete graph (overfitting). Second term penalizes over-complicated structures.
∣D∣ is the number of samples and ∥G∥ is the number of parameters in G
For AIC: ϕ(t)=1; for BIC: ϕ(t)=log(t)/2. (For BIC, the relative influence of the model-complexity penalty decreases as the number of samples grows, allowing the log-likelihood term to dominate the score.)
Ni,π,j count of variable i taking value j with parent configuration π
N′ counts in the prior
With prior for graph structure P(ΘG), BD score:
logP(D,ΘG)=logP(D∣ΘG)+logP(ΘG)
Overfitting is implicitly penalized via integral over parameter space.
Chow-Liu Algorithm
Finds maximum-likelihood tree where each node has at most one parent. We can simply use maximum likelihood as a score (already penalizing complexity by restricting ourselves to tree structures).
The log-likelihood (per sample) decomposes into a sum of mutual-information terms between each node and its parent, minus entropy terms (of the form −∑ p log p).
Entropy is independent of tree structure (so we eliminate it from the maximization). Mutual information is symmetric, therefore edge orientation does not matter.
3 steps:
compute mutual information for all pairs of variables X,U (edges):
MI(X,U) = ∑_{x,u} p̂(x,u) log [ p̂(x,u) / (p̂(x) p̂(u)) ]
where p̂(x,u) = counts(x,u) / (number of data points)
complexity: O(n²)
find a maximum-weight spanning tree (a tree that connects all vertices) using Kruskal's or Prim's algorithm (complexity: O(n²))
pick a node to be root variable, assign directions going outward of this node.
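A minimal Chow-Liu sketch on made-up binary data: compute empirical pairwise mutual information, grow a maximum-weight spanning tree (Prim-style), and orient edges away from a chosen root:

```python
import numpy as np
from itertools import combinations

def mutual_information(data, i, j):
    """Empirical MI between binary columns i and j of `data` (rows are samples)."""
    mi = 0.0
    for a in (0, 1):
        for b in (0, 1):
            p_ab = np.mean((data[:, i] == a) & (data[:, j] == b))
            p_a, p_b = np.mean(data[:, i] == a), np.mean(data[:, j] == b)
            if p_ab > 0:
                mi += p_ab * np.log(p_ab / (p_a * p_b))
    return mi

rng = np.random.default_rng(0)
x0 = rng.integers(0, 2, 500)
noise = (rng.random(500) < 0.1).astype(int)
data = np.column_stack([x0, x0 ^ noise, rng.integers(0, 2, 500)])   # column 1 depends on column 0

n = data.shape[1]
weights = {(i, j): mutual_information(data, i, j) for i, j in combinations(range(n), 2)}

# Prim's algorithm for the maximum-weight spanning tree, growing from root node 0.
in_tree, edges = {0}, []
while len(in_tree) < n:
    i, j = max(((i, j) for (i, j) in weights
                if (i in in_tree) != (j in in_tree)), key=lambda e: weights[e])
    edges.append((i, j) if i in in_tree else (j, i))   # orient the edge away from the tree (root)
    in_tree |= {i, j}
print(edges)   # directed edges of the learned tree
```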
Search algorithms
local search: start with an empty (or complete) graph. At each step, perform a single operation on the graph structure (add, remove, or reverse an edge, while preserving acyclicity). If the score increases, adopt the change; otherwise, try another operation.
greedy search (K3 algorithm): assume a topological order of the graph in advance (for every directed edge uv from vertex u to vertex v, u comes before v in the ordering). Restrict each variable's candidate parent set to variables earlier in the order. While searching for the parent set of each variable, add the parent that increases the score the most, until no improvement can be made. (Doesn't that create a complete graph? No: parents are restricted by the topological order, and additions that don't improve the penalized score are rejected.) When the specified topological order is a poor one, this results in a bad graph structure (low graph score).
The search space is highly non-convex and both algorithms might get stuck in sub-optimal regions.
Constraint-based approach
Employ independence tests to identify a set of edge constraints for the graph, and then find the best DAG that satisfies the constraints.
E.g., distinguish a V-structure from a fork structure by testing the independence of the two side variables conditioned on the middle variable. This requires many data samples to guarantee testing power.
Recent advances
order search (OS): search over topological orders and graph space at the same time. It swaps the order of two adjacent variables at each step and employs K3 algorithm as a sub-routine.
ILP: encode the graph structure, the score, and the acyclicity constraint into an integer linear program and use an off-the-shelf solver. This approach requires a bound on the maximum number of parents of any node; otherwise, the number of constraints explodes and the problem becomes intractable.
Variational auto-encoder
Deep learning technique for learning latent representations.
Directed latent-variable model: p(x,z)=p(x∣z)p(z)
Deep generative model with m layers: p(x,z1,…,zm)=p(x∣z1)p(z1∣z2)…p(zm−1∣zm)p(zm)
Objectives:
learning parameters θ of p
approximate posterior inference over z (given image x, what are its latent factors)
approximate marginal inference over x (given image x with missing parts, how do we fill them in)
Assumptions:
intractable posterior probability p(z∣x)
dataset is too large to fit in memory
Standard approaches?
EM: need to compute posterior p(z∣x) (intractable) + would need to use online version of EM to perform M-step.
mean field: requires us to compute expectation. Time complexity scales exponentially with size of Markov blanket of target variable. For z, if at least one component of x depends on each component of z, this introduces V-structure (x explains away differences among z) and Markov blanket of some zi contains all the other z-variables (intractable)
sampling-based methods: authors (Kingma and Welling) found that they don't scale well on large datasets
Auto-encoding variational Bayes
seminal paper by Kingma and Welling
the variational auto-encoder is one instantiation of this algorithm
ELBO:
L(pθ,qϕ)=Eqϕ(z∣x)[logpθ(x,z)−logqϕ(z∣x)]
satisfies equation:
logpθ(x)=KL(qϕ(z∣x)∥p(z∣x))+L(pθ,qϕ)
We are conditioning q(z) on x. Could we optimize over q(z∣x) using mean field? No, assumption that q is fully factored is too strong.
Approach: black-box variational inference
gradient descent over ϕ (the only assumption is that qϕ is differentiable in ϕ, as opposed to the coordinate updates used in mean field)
simultaneously perform learning via gradient descent on both ϕ (to keep the ELBO tight around log p(x)) and θ (to push up the lower bound, and hence log p(x)). This is similar to the EM algorithm.
We need to compute gradient:
∇θ,ϕEqϕ(z)[logpθ(x,z)−logqϕ(z)]
For the gradient over θ, we can swap the gradient and the expectation and estimate E_{qϕ(z)}[∇θ log pθ(x,z)] via Monte Carlo. The gradient over ϕ is harder, since the expectation is itself taken with respect to qϕ.
Calling q the encoder: both terms involve sampling z ∼ q(z∣x), which we interpret as a code describing x.
first term is log-likelihood of observed x given code z. p is called decoder network and the term is called the reconstruction error
the second term is the divergence between q(z∣x) and the prior p(z), which is fixed to be a unit Normal. It encourages the codes z to look Gaussian and acts as a regularizer that prevents q(z∣x) from encoding a simple identity mapping.
reminiscent of auto-encoder neural networks (learn xˉ=f(g(x)) by minimizing reconstruction loss ∥xˉ−x∥), hence the name.
The main contribution of the paper is a low-variance gradient estimator based on the reparameterization trick: we write z = gϕ(ϵ, x), so that ∇ϕ E_{qϕ(z∣x)}[f(z)] = E_{p(ϵ)}[∇ϕ f(gϕ(ϵ, x))],
where the noise variable ϵ is sampled from a simple distribution p(ϵ) (e.g. a standard normal) and the deterministic transformation gϕ(ϵ, x) maps the random noise into samples from the more complex distribution qϕ(z∣x).
This approach has much lower variance (see appendix of this paper by Rezende et al.) than the score function estimator.
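A minimal numpy sketch contrasting the score-function estimator with the reparameterized estimator on a toy objective E_{qϕ(z)}[f(z)] with qϕ = N(ϕ, 1) and f(z) = z² (everything here is illustrative, not the paper's model):

```python
import numpy as np

rng = np.random.default_rng(0)
f = lambda z: z ** 2                      # toy objective; true gradient of E[f] w.r.t. mu is 2*mu
mu, T = 1.5, 10_000

# Score-function (REINFORCE) estimator: grad = E_q[ f(z) * d/dmu log q(z) ], z ~ N(mu, 1).
z = rng.normal(mu, 1.0, T)
score_grad = (f(z) * (z - mu)).mean()

# Reparameterization: z = mu + eps with eps ~ N(0, 1), so grad = E_eps[ f'(mu + eps) ].
eps = rng.normal(0.0, 1.0, T)
reparam_grad = (2 * (mu + eps)).mean()

print(score_grad, reparam_grad)           # both near 3.0; the reparameterized one has lower variance
```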
q and p are parameterized by neural networks (bridge between classical machine learning method (approximate Bayesian inference here) and modern deep learning):
q(z∣x)=N(z;μ(x),diag(σ(x))2)
p(x∣z)=N(x;μ(z),diag(σ(z))2)
p(z)=N(z;0,I)
where the μ and σ functions are neural nets (two dense hidden layers of 500 units each).
Experimental results:
Monte-Carlo EM and hybrid Monte-Carlo are quite accurate, but don’t scale well to large datasets.
Wake-sleep is a variational inference algorithm that scales much better; however it does not use the exact gradient of the ELBO (it uses an approximation), and hence it is not as accurate as AEVB.