Continuing with the theme of neat tricks that can be very useful in Machine Learning, I wanted to share another insight that was pointed out to me by Jamie Townsend, who I think may have heard it from Matthew Johnson. Jamie and Matthew are both contributors/authors of the Python autograd package, which lets you differentiate native Python and NumPy code. It is hugely useful for prototyping.
The trick that Jamie pointed out to me is that for exponential family models, because the expected sufficient statistics are related straightforwardly to the derivative of the log-normaliser, you can do a lot of your inference with no extra effort by just leveraging automatic differentiation. I’ll very briefly recap the exponential family and the connection of derivatives to sufficient statistics before giving a quick example of just how easy this can be.
Exponential Family distributions
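An exponential family distribution is just one that can be written in the following form:
$$p(x \mid \theta) = h(x)\, g(\theta)\, \exp\left(\eta(\theta)^\top T(x)\right)$$
where $\theta$ are the standard parameters of the distribution, $\eta(\theta)$ is a function of those parameters and is known as the natural parameters, and $T(x)$ are the sufficient statistics. $h(x)$ is a base measure that does not depend on the parameters, and $g(\theta)$ is simply the reciprocal of the normaliser. Many if not most of the standard distributions we deal with can be written in this way, e.g. Normal, Poisson, Exponential, Laplacian (with fixed location), Bernoulli, Beta…
The Derivatives of Log-normaliser are the Expected Sufficient Stats
There is a well known identity that says for the exponential family:
$$\nabla_\eta \log Z(\eta) = \mathbb{E}_{p(x \mid \eta)}\left[T(x)\right]$$
where $Z(\eta) = 1/g$ is the normaliser written as a function of the natural parameters $\eta = \eta(\theta)$. This is fairly straightforward to show as follows:
$$\nabla_\eta \log Z(\eta) = \frac{1}{Z(\eta)}\, \nabla_\eta Z(\eta)$$
and we know that $g$ is the inverse of the normaliser (with the integral replaced by a sum for discrete $x$):
$$g = \frac{1}{Z(\eta)}, \qquad Z(\eta) = \int h(x) \exp\left(\eta^\top T(x)\right) dx$$
and so:
$$\nabla_\eta Z(\eta) = \int T(x)\, h(x) \exp\left(\eta^\top T(x)\right) dx$$
Thus:
$$\nabla_\eta \log Z(\eta) = \int T(x)\, h(x)\, g \exp\left(\eta^\top T(x)\right) dx = \mathbb{E}_{p(x \mid \eta)}\left[T(x)\right]$$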
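As a quick sanity check of this identity, here is a minimal sketch for a Poisson distribution with rate $\lambda$, assuming the parameterisation $h(x) = 1/x!$, $T(x) = x$ and natural parameter $\eta = \log\lambda$. The function names, the truncation of the sum at `max_x` and the step size are illustrative assumptions, and a central finite difference stands in for automatic differentiation so that the snippet needs only the standard library.

```python
import math

def log_normaliser(eta, max_x=150):
    """log Z(eta) for a Poisson: Z = sum_x h(x) exp(eta * T(x)), h(x) = 1/x!, T(x) = x."""
    # Work with log terms for numerical stability; lgamma(x + 1) = log(x!).
    log_terms = [eta * x - math.lgamma(x + 1) for x in range(max_x)]
    biggest = max(log_terms)
    return biggest + math.log(sum(math.exp(t - biggest) for t in log_terms))

def expected_sufficient_stat(eta, eps=1e-5):
    """Central finite difference of log Z with respect to the natural parameter."""
    return (log_normaliser(eta + eps) - log_normaliser(eta - eps)) / (2 * eps)

rate = 3.5                       # standard parameter lambda (made-up value)
eta = math.log(rate)             # natural parameter
mean_from_gradient = expected_sufficient_stat(eta)
print(mean_from_gradient, rate)  # the two values should agree closely
```

The derivative of the log-normaliser recovers the mean $\lambda$, which is the expected value of the sufficient statistic $x$.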
Using auto-diff to do your inference
Often when doing inference in graphical models, the expected sufficient statistics are exactly what we want. For example, when we run the EM algorithm on jointly exponential family models, the E-step only requires us to find the expectations of the sufficient statistics under the posterior on the latents.
The autograd library includes a nice example that takes advantage of this fact to do inference in an HMM. Traditionally you would do inference in an HMM using the forward-backward algorithm (the E-step of Baum-Welch), but the autograd example takes advantage of the exponential family structure and computes the expected statistics by differentiating the log-normaliser.
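For an HMM with observed variables $y_t$, discrete latents $x_t$, initial state probabilities $\pi$, transition probabilities $A$ and emission probabilities $B$, the joint probability over all the $x_t$ and $y_t$ is given by:
$$p(x_{0:T}, y_{1:T}) = \exp\left( x_0^\top (\log \pi) + \sum_{t=1}^{T} x_{t-1}^\top (\log A)\, x_t + \sum_{t=1}^{T} x_t^\top (\log B)\, y_t \right)$$
where we have encoded the $x_t$ and $y_t$ using one-hot vectors and, to match the indexing of the code below, $\pi$ is the distribution of the state $x_0$ that precedes the first observation. This is clearly in the exponential family with sufficient statistics $x_0$, $\sum_t x_{t-1} x_t^\top$ and $\sum_t x_t y_t^\top$.
To do inference what we need is the expectation of these statistics under the posterior $p(x_{0:T} \mid y_{1:T})$. In fact all we really need is the log normaliser $\log Z = \log p(y_{1:T})$, viewed as a function of the natural parameters $\log \pi$, $\log A$ and $\log B$. Once we have the log normaliser we can take its derivative and this will automatically take care of all our backwards message passing for us. To calculate the log normaliser, we do the usual forwards pass and then sum the entries of the final forward message.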
That's exactly what the autograd library does in the following example. The rest of the code is in the examples that ship with the autograd repository.
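# First get the log partition function
from functools import partial
# Older autograd releases expose logsumexp under autograd.scipy.misc instead
from autograd.scipy.special import logsumexp

def log_partition_function(natural_params, data):
    # A list of sequences: add up the log partition function of each sequence
    if isinstance(data, list):
        return sum(map(partial(log_partition_function, natural_params), data))
    log_pi, log_A, log_B = natural_params
    # Forward pass in log space
    log_alpha = log_pi
    for y in data:
        log_alpha = logsumexp(log_alpha[:,None] + log_A, axis=0) + log_B[:,y]
    # Sum out the final forward message
    return logsumexp(log_alpha)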
and then you can do all the inference simply using gradients:
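import autograd.numpy as np
from autograd import value_and_grad

# data, callback and normalize come from the enclosing scope in the full example
def EM_update(params):
    natural_params = list(map(np.log, params))
    # E step: the gradient of the log partition function gives the expected sufficient statistics
    loglike, E_stats = value_and_grad(log_partition_function)(natural_params, data)
    if callback: callback(loglike, params)
    # M step: normalise the expected statistics back into probabilities
    return list(map(normalize, E_stats))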
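To see that the gradient really is a posterior expectation, the following sketch checks one component on a toy HMM: the derivative of the log partition function with respect to `log_pi` should equal the posterior distribution of the initial state. The parameter values, the helper names and the brute-force enumeration are illustrative assumptions and are not part of the autograd example. Finite differences stand in for autodiff so that the snippet has no dependencies, and it follows the same convention as the code above, where a transition is applied before the first observation.

```python
import itertools
import math

def logsumexp(values):
    biggest = max(values)
    return biggest + math.log(sum(math.exp(v - biggest) for v in values))

def log_partition(log_pi, log_A, log_B, data):
    """Forward pass in log space, mirroring log_partition_function above."""
    K = len(log_pi)
    log_alpha = list(log_pi)
    for y in data:
        log_alpha = [
            logsumexp([log_alpha[i] + log_A[i][j] for i in range(K)]) + log_B[j][y]
            for j in range(K)
        ]
    return logsumexp(log_alpha)

def grad_wrt_log_pi(log_pi, log_A, log_B, data, eps=1e-6):
    """Central finite differences of the log partition function w.r.t. log_pi."""
    grads = []
    for k in range(len(log_pi)):
        up, down = list(log_pi), list(log_pi)
        up[k] += eps
        down[k] -= eps
        diff = log_partition(up, log_A, log_B, data) - log_partition(down, log_A, log_B, data)
        grads.append(diff / (2 * eps))
    return grads

def posterior_initial_state(pi, A, B, data):
    """Brute force: enumerate every latent path and accumulate its joint probability."""
    K = len(pi)
    weights = [0.0] * K
    for path in itertools.product(range(K), repeat=len(data) + 1):
        w = pi[path[0]]
        for t, y in enumerate(data):
            w *= A[path[t]][path[t + 1]] * B[path[t + 1]][y]
        weights[path[0]] += w
    total = sum(weights)
    return [w / total for w in weights]

def elementwise_log(rows):
    return [[math.log(p) for p in row] for row in rows]

# Toy model: 2 hidden states, 2 observation symbols (made-up numbers).
pi = [0.6, 0.4]
A = [[0.7, 0.3], [0.2, 0.8]]
B = [[0.9, 0.1], [0.3, 0.7]]
data = [0, 1, 1, 0]

log_pi = [math.log(p) for p in pi]
grads = grad_wrt_log_pi(log_pi, elementwise_log(A), elementwise_log(B), data)
posterior = posterior_initial_state(pi, A, B, data)
print(grads)      # gradient of the log partition function w.r.t. log_pi
print(posterior)  # posterior over the initial state, by enumeration
```

The two printed lists should agree up to finite-difference error. The same check could be repeated for `log_A` and `log_B`, whose gradients correspond to expected transition and emission counts.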
Why does this work?
At first it might seem a bit mysterious that we can replace a forwards and backwards message passing algorithm with auto-diff alone, but it works because the sum-product algorithm and reverse-mode auto-diff are in fact very similar. Reverse-mode auto-diff first does a forward pass to calculate the value of the function and then retraces the computation graph in a backward pass, multiplying gradients along branches and summing where branches merge. This is highly analogous to the standard forward-backward algorithm.
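One way to make the analogy concrete, offered as an illustration rather than as something the autograd example does explicitly: the derivative of the likelihood with respect to a forward message is the corresponding backward message. The sketch below works in probability space rather than log space, uses made-up toy parameters and hypothetical helper names, and again uses finite differences in place of autodiff. Because the likelihood is linear in each forward message, those differences are exact up to rounding.

```python
def step(alpha, A, B, y):
    """One forward update: propagate through A, then weight by the emission of y."""
    K = len(alpha)
    return [sum(alpha[i] * A[i][j] for i in range(K)) * B[j][y] for j in range(K)]

def forward_messages(pi, A, B, data):
    """alphas[t] for t = 0..T, with alphas[0] = pi."""
    alphas = [list(pi)]
    for y in data:
        alphas.append(step(alphas[-1], A, B, y))
    return alphas

def backward_messages(A, B, data, K):
    """betas[t][i] = p(y_{t+1:T} | x_t = i), built from the end of the sequence."""
    betas = [[1.0] * K]
    for y in reversed(data):
        nxt = betas[0]
        betas.insert(0, [sum(A[i][j] * B[j][y] * nxt[j] for j in range(K)) for i in range(K)])
    return betas

def likelihood_from(alpha_t, A, B, remaining):
    """Finish the forward pass from a given message and sum the final message."""
    alpha = list(alpha_t)
    for y in remaining:
        alpha = step(alpha, A, B, y)
    return sum(alpha)

def adjoint_of_message(alphas, t, A, B, data, eps=1e-4):
    """d(likelihood) / d(alphas[t][i]) for each i, by central finite differences."""
    grads = []
    for i in range(len(alphas[t])):
        up, down = list(alphas[t]), list(alphas[t])
        up[i] += eps
        down[i] -= eps
        diff = likelihood_from(up, A, B, data[t:]) - likelihood_from(down, A, B, data[t:])
        grads.append(diff / (2 * eps))
    return grads

# Toy model: 2 hidden states, 2 observation symbols (made-up numbers).
pi = [0.6, 0.4]
A = [[0.7, 0.3], [0.2, 0.8]]
B = [[0.9, 0.1], [0.3, 0.7]]
data = [0, 1, 1, 0]

alphas = forward_messages(pi, A, B, data)
betas = backward_messages(A, B, data, K=len(pi))
adjoints = [adjoint_of_message(alphas, t, A, B, data) for t in range(len(alphas))]
```

Here `adjoints[t]` matches `betas[t]` at every step. Multiplying each forward message by its adjoint and dividing by the likelihood gives the posterior state marginals, which is what the backward pass of forward-backward delivers.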