Conditional densities and affine transformations
In this chapter we will learn how conditional Gaussian densities (ConditionalGaussianPDF) can be used in the GT. Imagine we have a set of random variables \(X\), whose marginal is given by \(p(X)={\cal N}(\mu_X,\Sigma_X)\).
[1]:
from jax import numpy as jnp
from matplotlib import pyplot as plt
import jax
from gaussian_toolbox import pdf
Dx = 1
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
# Random mean for the marginal p(X); the leading axis is the batch dimension
mu_X = jax.random.normal(subkey, (1, Dx))
Sigma_X = jnp.array([[[1.0]]])
p_X = pdf.GaussianPDF(Sigma=Sigma_X, mu=mu_X)
Now we have a second set of variables \(Y\), for which we only know the conditional \(p(Y\vert X)={\cal N}(\mu_{Y\vert X},\Sigma_{Y\vert X})\), where \(\mu_{Y\vert X} = MX + b\) and \(\Sigma_{Y\vert X}\) does not depend on \(X\).
Let's create such an object in the GT.
[2]:
from gaussian_toolbox import conditional
Dy = 1
key, subkey = jax.random.split(key)
M = jax.random.normal(subkey, (1, Dy, Dx))
key, subkey = jax.random.split(key)
b = jax.random.normal(subkey, (1, Dy))
Sigma_YcX = jnp.array([[[0.2]]])  # conditional covariance Sigma_{Y|X}
p_Y_given_X = conditional.ConditionalGaussianPDF(M=M, b=b, Sigma=Sigma_YcX)
print(p_Y_given_X)
Conditional Gaussian density p(y|x)
In such a setup one is often interested in the following affine operations:
Joint transformation: \(T_{\rm joint}[p(Y|X),p(X)] = p(X,Y)\)
Marginal transformation: \(T_{\rm marg}[p(Y|X),p(X)] = p(Y)\)
Conditional transformation: \(T_{\rm cond}[p(Y|X),p(X)] = p(X\vert Y)\)
When \(p(X)\) and \(p(Y\vert X)\) have the form described above, the GT enables us to perform these operations quickly. For example, let's do the joint transformation.
[3]:
p_XY = p_Y_given_X.affine_joint_transformation(p_X)
This transformation gives back a GaussianPDF object, which we encountered in the first part of the tutorial. Hence, we can solve integrals, evaluate it at different points, and sample from it (see the short check after the next cell). The same holds for the marginal transformation:
[4]:
p_Y = p_Y_given_X.affine_marginal_transformation(p_X)
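Both results are plain GaussianPDF objects, so their full interface is available. Here is a minimal sanity check; the exact signatures of sample and integrate are an assumption, taken to follow the API introduced in the first part of the tutorial.
[ ]:
# Sanity check on p(Y). NOTE: `sample` and `integrate` are assumed to follow
# the GaussianPDF API from the first part of the tutorial.
key, subkey = jax.random.split(key)
y_samples = p_Y.sample(subkey, 1000)  # draw 1000 samples from p(Y)
print(y_samples.mean(axis=0))         # empirical mean of the samples
print(p_Y.integrate("x"))             # exact first moment E[Y]
print(p_Y.mu)                         # should agree with both of the above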
[5]:
# Create plot grid
reso = 70
xrange = (jnp.linspace(-5,5,reso) + p_X.mu).T
yrange = (jnp.linspace(-5,5,reso) + p_Y.mu).T
xmesh, ymesh = jnp.meshgrid(xrange[:,0], yrange[:,0])
xy = jnp.vstack([xmesh.flatten(), ymesh.flatten()]).T
from matplotlib import gridspec
fig = plt.figure(figsize=(6,6))
grid = gridspec.GridSpec(6,6,fig)
ax1 = fig.add_subplot(grid[1:,:1])
ax1.plot(p_Y(yrange).T, yrange[:,0], 'C3')
ax1.set_ylim(yrange[jnp.array([0,-1]),0])
ax1.set_ylabel('$Y$')
ax1.set_xlabel('p(Y)')
ax2 = fig.add_subplot(grid[:1,1:])
ax2.plot(xrange,p_X(xrange).T)
ax2.set_xlim(xrange[jnp.array([0,-1]), 0])
ax2.set_ylabel('p(X)')
ax2.set_xticklabels([])
ax3 = fig.add_subplot(grid[1:,1:])
ax3.pcolor(xrange[:,0], yrange[:,0], p_XY(xy).reshape((reso,reso)), shading='nearest')
ax3.set_xlabel('$X$')
ax3.set_xlim(xrange[jnp.array([0,-1]),0])
ax3.set_ylim(yrange[jnp.array([0,-1]),0])
ax3.set_yticklabels([])
plt.show()
The conditional transformation is a bit different, in the sense that it does not return a GaussianPDF, but again a ConditionalGaussianPDF. To check consistency, let's verify whether \(T_{\rm marg}[T_{\rm cond}[p(Y|X),p(X)], p(Y)] = p(X)\).
[6]:
p_X_given_Y = p_Y_given_X.affine_conditional_transformation(p_X)
p_X2 = p_X_given_Y.affine_marginal_transformation(p_Y)
plt.plot(xrange[:,0], p_X(xrange).T, label='p(X)', lw=2)
plt.plot(xrange[:,0], p_X2(xrange).T, 'C3--', label=r'$T_{\rm marg}[T_{\rm cond}[p(Y|X),p(X)], p(Y)] $', lw=2)
plt.legend(loc=1)
plt.xlim(xrange[jnp.array([0,-1]),0])
plt.xlabel('$X$')
plt.ylabel('$p(X)$')
plt.show()
As we see, the transformation is consistent.
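We can also confirm this numerically. A quick sketch, assuming the moments are exposed as mu and Sigma attributes (as already used for p_X.mu above):
[ ]:
# Numerical consistency check: the recovered marginal should match p(X).
# Assumes `mu` and `Sigma` attributes, as used elsewhere in this chapter.
print(jnp.allclose(p_X.mu, p_X2.mu, atol=1e-5))
print(jnp.allclose(p_X.Sigma, p_X2.Sigma, atol=1e-5))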
Sometimes we also wish to condition on a specific value, i.e. \(p(Y\vert X=x)\), which turns the ConditionalGaussianPDF back into a GaussianPDF. Let's compute \(p(Y\vert X=4)\).
[7]:
# Condition on X = 4; the result is an ordinary GaussianPDF over Y
p_Y_given_4 = p_Y_given_X(jnp.array([[4.0]]))
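As a quick check, the mean of the conditioned density should equal \(Mx + b\); below is a small sketch of that comparison (using the mu attribute as before).
[ ]:
# The conditioned object is a plain GaussianPDF; its mean should be M x + b.
print(p_Y_given_4.mu)
print(M[0] @ jnp.array([4.0]) + b[0])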
And again, we can see by visual inspection that it makes sense:
[8]:
fig = plt.figure(figsize=(6,6))
grid = gridspec.GridSpec(6,6,fig)
ax1 = fig.add_subplot(grid[1:,:1])
ax1.plot(p_Y_given_4(yrange).T, yrange[:,0], 'C3')
ax1.set_ylim(yrange[jnp.array([0,-1]),0])
ax1.set_ylabel('$Y$')
ax1.set_xlabel(r'$p(Y\vert X=4)$')
ax2 = fig.add_subplot(grid[:1,1:])
ax2.plot(xrange,p_X(xrange).T)
ax2.set_xlim(xrange[jnp.array([0,-1]), 0])
ax2.set_ylabel('p(X)')
ax2.set_xticklabels([])
ax3 = fig.add_subplot(grid[1:,1:])
ax3.pcolor(xrange[:,0], yrange[:,0], p_XY(xy).reshape((reso,reso)), shading='nearest')
ax3.set_xlabel('$X$')
ax3.vlines(4, yrange[0], yrange[-1], color='C3')
ax3.set_xlim(xrange[jnp.array([0,-1]),0])
ax3.set_ylim(yrange[jnp.array([0,-1]),0])
ax3.set_yticklabels([])
plt.show()
Why are affine transformations so useful? Let's consider Bayes' rule
\(p(W\vert {\cal D}) = \frac{p({\cal D}\vert W)\,p(W)}{p({\cal D})},\)
where \({\cal D}\) is the data and \(W\) are the parameters.
We start with the likelihood \(p({\cal D}\vert W)\) (a ConditionalGaussianPDF) and the prior \(p(W)\) (a GaussianPDF). The affine transformations give us exactly the objects of interest in Bayesian inference:
Joint transformation: \(p({\cal D}, W) = T_{\rm joint}[p({\cal D}|W),p(W)]\), the joint density.
Marginal transformation: \(p({\cal D}) = T_{\rm marg}[p({\cal D}|W),p(W)]\), the marginal likelihood.
Conditional transformation: \(p(W\vert {\cal D}) = T_{\rm cond}[p({\cal D}|W),p(W)]\), the posterior.
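As a minimal sketch, we can replay this chain with the one-dimensional objects from above standing in for likelihood and prior; only operations already shown in this chapter are used, and the observed value is an arbitrary illustrative choice.
[ ]:
# Bayesian inference as affine transformations, with p_Y_given_X standing in
# for the likelihood p(D|W) and p_X for the prior p(W).
likelihood, prior = p_Y_given_X, p_X
evidence = likelihood.affine_marginal_transformation(prior)           # p(D)
posterior_cond = likelihood.affine_conditional_transformation(prior)  # p(W|D)
posterior = posterior_cond(jnp.array([[1.5]]))  # condition on observed D = 1.5 (illustrative)
print(posterior)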
In the next chapter, we will see, as an example, how easily linear regression can be done.