Machine learning is full of seemingly simple mathematical “tricks” that are disproportionately useful relative to their complexity. The reparameterisation trick is one of these. In the next few posts I want to go into more detail about variational auto-encoders, where the trick will come up again, but for the time being I just want to present a neat and useful “trick” that I think can be confusing when you first come across it.

Often in machine learning we want to estimate the gradient of an expectation using sampling. This happens in reinforcement learning when doing policy gradients, and in stochastic variational inference. The problem we are trying to solve is to find:

$$\nabla_\theta \, \mathbb{E}_{x \sim p(x)}\left[f_\theta(x)\right]$$
As long as $p(x)$ is independent of $\theta$ there's no problem here: we can use the Leibniz rule and simply bring the derivative inside the expectation as follows:

$$\nabla_\theta \, \mathbb{E}_{x \sim p(x)}\left[f_\theta(x)\right] = \mathbb{E}_{x \sim p(x)}\left[\nabla_\theta f_\theta(x)\right]$$
and then construct a Monte Carlo estimate of the gradient:

$$\mathbb{E}_{x \sim p(x)}\left[\nabla_\theta f_\theta(x)\right] \approx \frac{1}{N}\sum_{i=1}^{N} \nabla_\theta f_\theta(x_i)$$
where each $x_i$ is sampled from $p(x)$. If, however, as is often the case, the distribution depends on $\theta$, i.e. $p_\theta(x)$, then the above trick won't work, because after you bring the derivative inside the integral (or sum) the new integral is no longer an expectation, and we can't construct a straightforward Monte Carlo estimate.

There are two ways to get around this. The first is used in the REINFORCE algorithm and is known as the log derivative trick. It relies on the following identity, which is easy to verify:

$$\nabla_\theta \, p_\theta(x) = p_\theta(x) \nabla_\theta \log p_\theta(x)$$
If we substitute this identity for $\nabla_\theta \, p_\theta(x)$ when expanding the gradient of the expectation, then we get:

$$\nabla_\theta \, \mathbb{E}_{x \sim p_\theta(x)}\left[f(x)\right] = \int f(x) \nabla_\theta \, p_\theta(x) \, dx = \int f(x) p_\theta(x) \nabla_\theta \log p_\theta(x) \, dx = \mathbb{E}_{x \sim p_\theta(x)}\left[f(x) \nabla_\theta \log p_\theta(x)\right]$$
and since this is still an expectation, we can again construct a simple Monte Carlo estimate of the gradient:

$$\nabla_\theta \, \mathbb{E}_{x \sim p_\theta(x)}\left[f(x)\right] \approx \frac{1}{N}\sum_{i=1}^{N} f(x_i) \nabla_\theta \log p_\theta(x_i)$$
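As a concrete sketch, here is the score-function estimator in a few lines of NumPy. The objective $f(x) = x^2$ and the distribution $p_\mu(x) = \mathcal{N}(\mu, 1)$ are toy choices of mine for illustration; they are not part of the derivation above.

```python
import numpy as np

rng = np.random.default_rng(0)

mu = 1.5       # the parameter (theta) we differentiate with respect to
N = 200_000    # number of Monte Carlo samples

# Sample x_i ~ N(mu, 1), then average f(x_i) * grad_mu log p_mu(x_i).
# For a unit-variance Gaussian, grad_mu log p_mu(x) = (x - mu).
x = rng.normal(mu, 1.0, size=N)
grad_estimate = np.mean(x**2 * (x - mu))

# The true gradient is d/dmu E[x^2] = d/dmu (mu^2 + 1) = 2 * mu.
print(grad_estimate)
```

Note that nothing here differentiated through the sampling step; the gradient information comes entirely from the $\nabla_\mu \log p_\mu(x_i)$ weights.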

Job done, right? Well, sometimes yes, but sometimes this naive substitution yields a gradient estimate whose variance is too high to be useful. In those cases we can turn to the reparameterisation trick.

The idea behind the reparameterisation trick is to write $x$ as a function of a variable $\epsilon$, where the distribution of $\epsilon$ does not depend on $\theta$. If we can do this, then we are back in the first situation we considered, where the gradient operator can simply be brought inside the expectation without any complication. Put another way, we want to find a differentiable function $g_\theta$ such that:

$$x = g_\theta(\epsilon)$$

and

$$\mathbb{E}_{x \sim p_\theta(x)}\left[f(x)\right] = \mathbb{E}_{\epsilon \sim q(\epsilon)}\left[f(g_\theta(\epsilon))\right]$$
where $q(\epsilon)$ is a distribution that has no $\theta$ dependence. An example of such a function would be the following reparameterisation of a normal distribution $x \sim \mathcal{N}(\mu, \sigma^2)$, where what we were calling $\theta$ is now $\mu$ and $\sigma$. In this case, if we define:

$$g_{\mu,\sigma}(\epsilon) = \mu + \sigma\epsilon$$

and then sample $\epsilon \sim \mathcal{N}(0, 1)$, we then have that $x = g_{\mu,\sigma}(\epsilon) \sim \mathcal{N}(\mu, \sigma^2)$.
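As a quick sanity check (a minimal sketch; the particular values of $\mu$ and $\sigma$ are arbitrary), sampling $\epsilon \sim \mathcal{N}(0, 1)$ and pushing it through $g_{\mu,\sigma}$ really does produce samples from $\mathcal{N}(\mu, \sigma^2)$:

```python
import numpy as np

rng = np.random.default_rng(0)

mu, sigma = 2.0, 0.5   # arbitrary illustrative parameters

# Sample epsilon from a fixed, parameter-free distribution...
eps = rng.normal(0.0, 1.0, size=100_000)

# ...and apply g_{mu,sigma}(eps) = mu + sigma * eps.
x = mu + sigma * eps

# The empirical mean and standard deviation should match mu and sigma.
print(x.mean(), x.std())
```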

Once we've made this reparameterisation, we can return to our original problem and see that the issues have disappeared:

$$\nabla_\theta \, \mathbb{E}_{x \sim p_\theta(x)}\left[f(x)\right] = \nabla_\theta \, \mathbb{E}_{\epsilon \sim q(\epsilon)}\left[f(g_\theta(\epsilon))\right] = \mathbb{E}_{\epsilon \sim q(\epsilon)}\left[\nabla_\theta f(g_\theta(\epsilon))\right] \approx \frac{1}{N}\sum_{i=1}^{N} \nabla_\theta f(g_\theta(\epsilon_i))$$
with $\epsilon_i \sim q(\epsilon)$. Or, more concretely, for our Gaussian case:

$$\nabla_{\mu,\sigma} \, \mathbb{E}_{x \sim \mathcal{N}(\mu, \sigma^2)}\left[f(x)\right] \approx \frac{1}{N}\sum_{i=1}^{N} \nabla_{\mu,\sigma} f(\mu + \sigma\epsilon_i)$$
where $\epsilon_i \sim \mathcal{N}(0, 1)$.
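Continuing the same toy example as before ($f(x) = x^2$, which is my illustrative choice rather than anything from the derivation), the reparameterised estimate of $\nabla_\mu \mathbb{E}[f(x)]$ is just an average of pathwise derivatives:

```python
import numpy as np

rng = np.random.default_rng(0)

mu, sigma = 1.5, 1.0
N = 200_000

eps = rng.normal(0.0, 1.0, size=N)

# f(x) = x^2 with x = mu + sigma * eps, so df/dmu = 2 * (mu + sigma * eps).
pathwise_grads = 2.0 * (mu + sigma * eps)
grad_estimate = pathwise_grads.mean()

# True gradient: d/dmu E[x^2] = d/dmu (mu^2 + sigma^2) = 2 * mu.
print(grad_estimate)

# For this toy problem, the per-sample variance is much lower than that
# of the score-function estimator of the same quantity.
print(pathwise_grads.var())
```

The key difference from the score-function estimator is that here the gradient flows through the sample itself, which is what typically makes the variance so much lower.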

It might at first seem that constructing these reparameterisations would be very hard. For discrete distributions it is indeed very hard, and in a later post we may discuss ways around this, such as the Gumbel-Softmax. However, for continuous distributions, when you consider the fact that most random number generators first generate uniformly distributed variables and then transform them, you realise that a lot of complex distributions can be built straightforwardly from simpler distributions.
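For example (a sketch using an exponential distribution with an arbitrarily chosen rate $\lambda$), inverse-CDF sampling is itself a reparameterisation: $x = -\log(1 - u)/\lambda$ with $u \sim \mathrm{Uniform}(0, 1)$, and this transform is differentiable in $\lambda$:

```python
import numpy as np

rng = np.random.default_rng(0)

lam = 2.0  # arbitrary rate parameter

# u has no dependence on lambda...
u = rng.uniform(0.0, 1.0, size=100_000)

# ...and the inverse CDF of the exponential maps it to x ~ Exp(lambda).
x = -np.log(1.0 - u) / lam

# The empirical mean should be close to 1 / lambda = 0.5.
print(x.mean())
```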