Sum Product Algorithm and Graph Factorization
This post describes the Sum Product algorithm, a neat idea that leverages a handful of important results from across the field and extends to more complex models. I found it to be a very nice gateway to more complex models, and it is in fact a generalization of Markov chains, the Kalman filter, the fast Fourier transform, the forward-backward algorithm, and more. I'm aware that the approach I take may not suit everyone; in that case, take a look at the references section for some excellent resources.
Minimum Minimorum
I assume familiarity with the two fundamental rules of probability, the sum rule and the product rule given below, and the associated notions of joint distributions and marginals. We will be making extensive use of these.

$$p(x) = \sum_y p(x, y) \tag{sum}\label{sum}$$

$$p(x, y) = p(y \mid x)\, p(x) \tag{product}\label{product}$$
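As a quick warm-up, here is a small sketch of both rules on a made-up $2 \times 2$ joint table (the values are arbitrary):

```python
import numpy as np

# A made-up joint distribution p(x, y) over two binary variables,
# laid out as a table with rows indexed by x and columns by y.
p_xy = np.array([[0.3, 0.1],
                 [0.2, 0.4]])

# Sum rule: marginalize y out to get p(x).
p_x = p_xy.sum(axis=1)            # -> [0.4, 0.6]

# Product rule: the joint is p(y | x) * p(x).
p_y_given_x = p_xy / p_x[:, None]
assert np.allclose(p_y_given_x * p_x[:, None], p_xy)
```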
What and Why - Sum Product Algorithm?
It is used for inference, a frequently used word in statistics that here means marginalizing a joint distribution so we can learn something unknown from the other, known variables. An issue with marginalizing a joint is that it quickly becomes intractable, i.e., computationally infeasible due to the sheer number of terms involved. For instance, say you have a model with 100 binary variables (each variable is 0 or 1): you are now faced with marginalizing a joint distribution with $2^{100}$ terms. Here we can make a simplifying assumption that our joint distribution is not a fully general joint but a factorized distribution, where each factor depends only on a few local variables that are 'close' to it. This is the underlying assumption that makes fast inference through Sum Product possible.
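To make the payoff concrete, here is a small sketch of my own (made-up pairwise factors, not yet the full algorithm): for a chain of 100 binary variables, pushing each sum inward computes a marginal with a few hundred operations, whereas the naive sum has on the order of $2^{100}$ terms.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100

# A made-up chain-factorized model over 100 binary variables:
# p(x_1, ..., x_n) is proportional to the product of f_i(x_i, x_{i+1}),
# with each f_i an arbitrary 2x2 table of positive values.
factors = [rng.random((2, 2)) for _ in range(n - 1)]

# Naively, marginalizing for p(x_n) sums over 2**99 terms, one per
# assignment of the other 99 variables. Because each factor is local,
# the sums can be pushed inward instead:
# m_{i+1}(x_{i+1}) = sum over x_i of m_i(x_i) * f_i(x_i, x_{i+1}).
msg = np.ones(2)
for f in factors:
    msg = msg @ f  # one 2x2 matrix-vector product per variable summed out

p_last = msg / msg.sum()  # normalized marginal p(x_n)
print(p_last)
```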
Factorization
I have been mentioning factors, and all it means here is that there is a different way of specifying a joint distribution, wherein the function $p(x_1,\dots,x_n)$ factors into a product of several local functions, each of which contains only a subset of the variables:

$$p(x_1, \dots, x_n) = \prod_s f_s(X_s)$$

where $X_s$ is a subset of the variables.
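As a concrete (and entirely hypothetical) encoding, such a factorized distribution can be stored as a list of factors, each a numpy table together with the names of the variables in its scope:

```python
import numpy as np

# A hypothetical encoding of p(a, b, c) = f1(a, b) * f2(b, c): each factor
# is a (scope, table) pair with the table indexed in scope order. All
# variables are binary here and the entries are made up.
factors = [
    (("a", "b"), np.array([[0.9, 0.1],
                           [0.4, 0.6]])),
    (("b", "c"), np.array([[0.7, 0.3],
                           [0.2, 0.8]])),
]

def joint(assignment, factors):
    """Evaluate the (unnormalized) joint at one assignment such as {'a': 0, 'b': 1, 'c': 1}."""
    value = 1.0
    for scope, table in factors:
        value *= table[tuple(assignment[v] for v in scope)]
    return value

print(joint({"a": 0, "b": 1, "c": 1}, factors))  # 0.1 * 0.8 = 0.08
```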
Factor Graphs
A factorization can be represented explicitly through a factor graph, which makes the structure visible by adding dedicated nodes for the factors.
A factor graph is a bipartite graph that expresses the structure
of the factorization. A factor graph has a variable node for each variable $x_i$,
a factor node for each local function $f_s$, and an edge connecting variable node
$x_i$ to factor node $f_s$ if and only if $x_i$ is an argument of $f_s$.
An example factorization and its corresponding graph are given below.
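Continuing the hypothetical encoding from above, the bipartite structure can be read straight off the factor scopes; a sketch:

```python
# Hypothetical scopes for the factorization p(a, b, c) = f1(a, b) * f2(b, c).
scopes = {"f1": ("a", "b"), "f2": ("b", "c")}

def build_graph(scopes):
    """One node per variable and per factor; an edge (x, f) iff x is an argument of f."""
    neighbors = {}
    for f, scope in scopes.items():
        neighbors.setdefault(f, set())
        for x in scope:
            neighbors.setdefault(x, set())
            neighbors[f].add(x)
            neighbors[x].add(f)
    return neighbors

print(build_graph(scopes))
# {'f1': {'a', 'b'}, 'a': {'f1'}, 'b': {'f1', 'f2'}, 'f2': {'b', 'c'}, 'c': {'f2'}}
```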
Factor graphs as Expression Trees
If the factor graph doesn't contain cycles, then it can be represented as a tree, and the computation can be simplified using the distributive law of multiplication,

$$ab + ac = a(b + c) \tag{distributive}\label{distributive}$$

(note that the right-hand side takes two operations where the left takes three). We can view this with a 'message-passing' analogy, whereby the marginal is the 'product' of 'messages'. This idea is made clear in the next section. To convert a function representing $p(x_1,\dots,x_n)$ to the corresponding expression tree for $p(x_i)$, rearrange the factor graph as a rooted tree with $x_i$ as the root.
Algorithm in action
I found that the best way to learn the algorithm was to see it in execution, given the basic ingredients we have gathered so far. There are missing pieces, which will be explained as they appear.
Problem - Find the marginal
The factor graph describes the factorization

$$p(a, b, c, d, e) = f_1(a)\, f_2(b)\, f_3(a, b, c)\, f_4(c, d)\, f_5(c, e) \tag{cc}\label{cc}$$
And we want to find the marginal $p(c)$. Using \ref{sum},

$$p(c) = \sum_a \sum_b \sum_d \sum_e p(a, b, c, d, e) \tag{dd}\label{dd}$$
We can represent the factor graph as a tree, and this gives us the ground on which to build an intuition for message passing. Notice that the tree is simply the factor graph rearranged to reflect our problem of finding $p(c)$, with $c$ as the root.
Message Passing
Using the message-passing analogy, picture the marginal as a message composed of several messages gathered along the branches of the tree. Substituting \ref{cc} in \ref{dd} and applying \ref{distributive}, we get a form that we can start to work on:

$$p(c) = \Big( \sum_a \sum_b f_1(a)\, f_2(b)\, f_3(a, b, c) \Big) \Big( \sum_d f_4(c, d) \Big) \Big( \sum_e f_5(c, e) \Big)$$
- The marginal of variable $c$ is composed of the 3 messages it received from its neighboring factors. To get the marginal of a variable node, simply multiply all the incoming messages from neighboring factor nodes: $$p(c) = m_{f_3 \rightarrow c}(c)\, m_{f_4 \rightarrow c}(c)\, m_{f_5 \rightarrow c}(c)$$ where $m_{x \rightarrow y}(z)$ represents a message sent from node $x$ to node $y$ that is a function of variable $z$, because the other variables have been summed out.
- A natural question now is what the messages at the factor nodes are. Let's go through the first factor. In case you are left wondering how something came to be the way it is, remember that everything is a combination of \ref{distributive} and \ref{sum}. To initiate messages at leaf nodes, the following rules are used, depending on whether the leaf is a variable node or a factor node:
- $m_{x \rightarrow f}(x) = 1$: a leaf variable node sends the unit function.
- $m_{f \rightarrow x}(x) = f(x)$: a leaf factor node sends a description of its function to its parent.
For an interior factor node, the outgoing message is computed in three steps (a numpy sketch follows this list):
- Take the product of the incoming messages into the factor node.
- Multiply by the factor associated with the node.
- Marginalize over all variables associated with the incoming messages by pulling out the summations.
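Here is a minimal numpy sketch of those three steps for a single, hypothetical factor $f(a, b, c)$ sending its message to $c$; the factor table and incoming messages are made up, and all variables are binary:

```python
import numpy as np

rng = np.random.default_rng(1)

# Made-up factor f(a, b, c) over three binary variables,
# table axes ordered (a, b, c).
f = rng.random((2, 2, 2))

# Made-up incoming messages from the neighboring variable nodes a and b.
m_a = rng.random(2)
m_b = rng.random(2)

# Steps 1 & 2: product of incoming messages, times the factor.
# Broadcasting aligns m_a with axis 0 and m_b with axis 1.
prod = f * m_a[:, None, None] * m_b[None, :, None]

# Step 3: marginalize out a and b, leaving a function of c alone.
m_f_to_c = prod.sum(axis=(0, 1))
print(m_f_to_c)  # shape (2,): the message m_{f -> c}(c)
```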
By recursively applying the rules we have seen, the two incoming messages for $f_3$ are

$$m_{a \rightarrow f_3}(a) = m_{f_1 \rightarrow a}(a) = f_1(a)$$

$$m_{b \rightarrow f_3}(b) = m_{f_2 \rightarrow b}(b) = f_2(b)$$
Note that the right-hand side of both the above equations can be seen as the message from the leaf factor node, since variable nodes simply multiply the messages of factor nodes. And what we are left with is

$$m_{f_3 \rightarrow c}(c) = \sum_a \sum_b f_3(a, b, c)\, f_1(a)\, f_2(b) \tag{3-c}\label{3-c}$$
In the original paper, the authors propose a different notation for equations like the above, called 'not-sum' or summary notation, which gives \ref{3-c} the form

$$m_{f_3 \rightarrow c}(c) = \sum_{\sim \{c\}} f_3(a, b, c)\, f_1(a)\, f_2(b)$$

Instead of indicating the variables being summed over, we indicate those variables not being summed over.
The message sent from a node $v$ on an edge $e$ is the product of the local function at $v$ (or the unit function if $v$ is a variable node) with all messages received at $v$ on edges other than $e$, summarized for the variable associated with $e$. Variable to local function: $$m_{x \to f} (x) = \prod_{h \in n(x) \backslash \{f\}} m_{h \to x} (x)$$ Local function to variable: $$m_{f \to x} (x) = \sum_{\sim \{x\}} \Big( f(X) \prod_{y \in n(f) \backslash \{x\}} m_{y \to f} (y) \Big)$$ where $n(v)$ denotes the set of neighbors of node $v$.
We now know everything required to complete the marginal. Applying the equations above, the final form is

$$p(c) = \Big( \sum_{\sim \{c\}} f_3(a, b, c)\, f_1(a)\, f_2(b) \Big) \Big( \sum_{\sim \{c\}} f_4(c, d) \Big) \Big( \sum_{\sim \{c\}} f_5(c, e) \Big)$$
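To check the whole derivation end to end, here is a self-contained sketch with made-up binary factor tables and hypothetical helper names. It implements the two update rules quoted above and computes $p(c)$ both by message passing and by brute force over the full joint, verifying that the two agree:

```python
import numpy as np
from itertools import product

rng = np.random.default_rng(2)

# Made-up tables for p(a,b,c,d,e) ∝ f1(a) f2(b) f3(a,b,c) f4(c,d) f5(c,e),
# all variables binary. Each factor is (scope, table in scope order).
factors = {
    "f1": (("a",), rng.random(2)),
    "f2": (("b",), rng.random(2)),
    "f3": (("a", "b", "c"), rng.random((2, 2, 2))),
    "f4": (("c", "d"), rng.random((2, 2))),
    "f5": (("c", "e"), rng.random((2, 2))),
}
variables = "abcde"
neighbors = {x: {f for f, (s, _) in factors.items() if x in s} for x in variables}

def msg_var_to_factor(x, f):
    # Product of messages from x's other neighboring factors (1 if x is a leaf).
    out = np.ones(2)
    for g in neighbors[x] - {f}:
        out = out * msg_factor_to_var(g, x)
    return out

def msg_factor_to_var(f, x):
    # Sum over all assignments of the factor's other variables (the not-sum).
    scope, table = factors[f]
    others = [v for v in scope if v != x]
    out = np.zeros(2)
    for assign in product(range(2), repeat=len(others)):
        env = dict(zip(others, assign))
        for xv in range(2):
            env[x] = xv
            term = table[tuple(env[v] for v in scope)]
            for v in others:
                term *= msg_var_to_factor(v, f)[env[v]]
            out[xv] += term
    return out

# Sum-product marginal: product of incoming factor messages, normalized.
m = np.ones(2)
for f in neighbors["c"]:
    m = m * msg_factor_to_var(f, "c")
p_c = m / m.sum()

# Brute force over all 2**5 joint assignments for comparison.
bf = np.zeros(2)
for assign in product(range(2), repeat=5):
    env = dict(zip(variables, assign))
    val = 1.0
    for scope, table in factors.values():
        val *= table[tuple(env[v] for v in scope)]
    bf[env["c"]] += val
bf /= bf.sum()

assert np.allclose(p_c, bf)
print(p_c)
```

On a tree this agreement is guaranteed; the recursive calls simply mirror the expression tree rooted at $c$.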
Marginal for every node
Having done the work of finding the marginal $p(x)$, suppose we go on to calculate the marginal $p(y)$ for another variable in the same factor graph. One might notice redundancies, because many of the sub-computations for evaluating messages are the same. We can take advantage of this by picking any node as root, propagating messages from the leaves to the root as shown above, and then from the root back to the leaves, so that every edge has carried a message in both directions, caching the evaluations all along. For a slight increase in computation, we now have all the marginals.
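As a concrete special case of this two-pass schedule, here is a minimal sketch on a chain with made-up pairwise factors; the forward and backward passes together yield every marginal, and this is exactly the forward-backward algorithm mentioned at the start:

```python
import numpy as np

rng = np.random.default_rng(3)
n = 6
# Made-up chain model: p(x_1..x_n) ∝ prod_i f_i(x_i, x_{i+1}), binary states.
fs = [rng.random((2, 2)) for _ in range(n - 1)]

# Pass 1: forward messages; alpha[i] arrives at x_{i+1} from the left.
alpha = [np.ones(2)]
for f in fs:
    alpha.append(alpha[-1] @ f)

# Pass 2: backward messages; beta[i] arrives at x_{i+1} from the right.
beta = [np.ones(2)]
for f in reversed(fs):
    beta.append(f @ beta[-1])
beta.reverse()

# Every marginal is the product of the two messages meeting at that node.
marginals = [a * b / (a * b).sum() for a, b in zip(alpha, beta)]
print(marginals[2])  # e.g. p(x_3)
```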
Summary
Hopefully this post introduced factor graphs and the Sum Product algorithm and provided an intuition for the ideas. It is essentially a method for exact inference of marginals given a factorized distribution on an acyclic graph.
Next Steps
I strongly recommend reading the original paper, as this post covered less than half of it. It's very accessible, with familiar examples, and contains a lot more information, including applications.
I also avoided talking about cyclic graphs, which is where most of the interesting problems lie; the paper discusses interesting ways of working around cycles with algorithms like the Junction Tree and Loopy Belief Propagation, but you already know everything needed to understand them.
Another avenue for thought is working with continuous variables, i.e., when the messages are intractable. Interesting techniques like Monte Carlo methods and variational inference are used in such cases, each of which is whole books' worth of content on its own.
Here is a nice Python implementation of the algorithm for the code-savvy learners.
References
Factor Graphs and the Sum-Product Algorithm
Christopher Bishop’s presentation video
Ilya’s Sum Product Python implementation
Sergey Dovgal’s post