MDNs take your boring old neural network and turn it into a prediction powerhouse. Why settle for just one prediction when you can have a whole buffet of potential outcomes?

## The basic idea

In an MDN, the probability density of the target variable *t* given the input *x* is represented as a linear combination of kernel functions, usually Gaussians, but not limited to them. In math-speak:

p(t|x) = Σᵢ 𝛼ᵢ(x) 𝜙ᵢ(t|x)

Where 𝛼ᵢ(x) are the mixing coefficients (and who doesn't love a good mix, right? 🎛️). They determine how much *weight* each component 𝜙ᵢ(t|x), a Gaussian in our case, carries in the model.

## Mix the Gaussians ☕

Each Gaussian component 𝜙ᵢ(t|x) has its own mean 𝜇ᵢ(x) and variance 𝜎ᵢ(x)²:

𝜙ᵢ(t|x) = 1 / (√(2π) 𝜎ᵢ(x)) · exp( −(t − 𝜇ᵢ(x))² / (2𝜎ᵢ(x)²) )

## Mixing 🎧 with coefficients

The mixing coefficients 𝛼ᵢ are crucial because they balance the influence of each Gaussian component. They are produced by a *softmax* over the corresponding raw network outputs zᵢ^𝛼, which makes sure they add up to 1:

𝛼ᵢ = exp(zᵢ^𝛼) / Σⱼ exp(zⱼ^𝛼)

## Magic parameters ✨: means and deviations

Means 𝜇ᵢ and standard deviations 𝜎ᵢ define each Gaussian. And guess what? The deviations must be positive! We achieve this by taking the exponential of the corresponding raw network outputs zᵢ^𝜎:

𝜎ᵢ = exp(zᵢ^𝜎)
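To see these constraints in action, here's a tiny standalone sketch (the raw output values are made up for illustration, not from the original post): a softmax over raw outputs gives valid mixing coefficients, and exponentiation gives strictly positive deviations.

```python
import torch
import torch.nn.functional as F

# Made-up raw network outputs for 3 mixture components
z_alpha = torch.tensor([0.5, -1.2, 2.0])
z_sigma = torch.tensor([-0.3, 0.0, 1.5])

alpha = F.softmax(z_alpha, dim=0)  # mixing coefficients: non-negative, sum to 1
sigma = torch.exp(z_sigma)         # deviations: strictly positive
```

Even strongly negative raw outputs (like the −1.2 here) come out as small but perfectly valid positive values.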

Okay, so how do we train this beast? Well, it's all about maximizing the likelihood of our observed data. Fancy terms, I know. Let's see it in action.

## The log-likelihood trick ✨

The likelihood of our data under the MDN model is the product of the probabilities assigned to each data point. In math-speak:

L = Πₙ p(tₙ|xₙ)

This basically says, *“Hey, what’s the chance we’d see this data given our model?”* But products can get messy, so we take the logarithm (because math loves logarithms), which turns our product into a sum:

log L = Σₙ log p(tₙ|xₙ)

Now here's the thing: we actually want to minimize the *negative* log-likelihood, because our optimization algorithms like to minimize things. So, plugging in the definition of *p(t|x)*, the error function we actually minimize is:

E = −Σₙ log( Σᵢ 𝛼ᵢ(xₙ) 𝜙ᵢ(tₙ|xₙ) )

This formula may look intimidating, but it simply says that we sum the log probabilities over all data points and then flip the sign, because minimization is our game.
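As a quick sanity check of the product-to-sum trick, here's a toy sketch with made-up per-point likelihoods (not part of the original code):

```python
import math

# Made-up per-point likelihoods p(t_n | x_n)
probs = [0.2, 0.05, 0.7]

log_of_product = math.log(math.prod(probs))    # log of the product
sum_of_logs = sum(math.log(p) for p in probs)  # sum of the logs

# The two agree, so minimizing -sum_of_logs maximizes the likelihood
nll = -sum_of_logs
```

Since each probability is below 1, every log is negative, and the negated sum is a positive quantity we can happily minimize.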

Now here's how to translate our magic into Python, and you can find the full code here:

## The loss function

```python
import torch

def mdn_loss(alpha, sigma, mu, target, eps=1e-8):
    # Expand target from (N, T) to (N, K, T) to match the mixture components
    target = target.unsqueeze(1).expand_as(mu)
    m = torch.distributions.Normal(loc=mu, scale=sigma)
    log_prob = m.log_prob(target)
    log_prob = log_prob.sum(dim=2)
    log_alpha = torch.log(alpha + eps)  # avoid log(0) disaster
    loss = -torch.logsumexp(log_alpha + log_prob, dim=1)
    return loss.mean()
```

Here is the breakdown:

- `target = target.unsqueeze(1).expand_as(mu)`: expand the target to match the shape of `mu`.
- `m = torch.distributions.Normal(loc=mu, scale=sigma)`: create a normal distribution for each component.
- `log_prob = m.log_prob(target)`: calculate the log probability.
- `log_prob = log_prob.sum(dim=2)`: sum the log probabilities over the output dimensions.
- `log_alpha = torch.log(alpha + eps)`: take the log of the mixing coefficients (`eps` avoids `log(0)`).
- `loss = -torch.logsumexp(log_alpha + log_prob, dim=1)`: combine log weights and log probabilities with log-sum-exp, then negate.
- `return loss.mean()`: return the average loss over the batch.
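Here's a minimal smoke test of the loss on random tensors. The shapes (a batch of 4, 3 mixtures, 2 output dimensions) are assumptions for illustration, and the function is repeated in compact form so the snippet runs standalone:

```python
import torch

def mdn_loss(alpha, sigma, mu, target, eps=1e-8):
    target = target.unsqueeze(1).expand_as(mu)
    m = torch.distributions.Normal(loc=mu, scale=sigma)
    log_prob = m.log_prob(target).sum(dim=2)
    log_alpha = torch.log(alpha + eps)
    return -torch.logsumexp(log_alpha + log_prob, dim=1).mean()

N, K, T = 4, 3, 2  # batch size, mixture components, output dimensions
alpha = torch.softmax(torch.randn(N, K), dim=-1)  # valid mixture weights
sigma = torch.exp(torch.randn(N, K, T))           # positive deviations
mu = torch.randn(N, K, T)
target = torch.randn(N, T)

loss = mdn_loss(alpha, sigma, mu, target)  # a single finite scalar
```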

## The neural network

Let's create a neural network ready to handle magic:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MDN(nn.Module):
    def __init__(self, input_dim, output_dim, num_hidden, num_mixtures):
        super(MDN, self).__init__()
        self.hidden = nn.Sequential(
            nn.Linear(input_dim, num_hidden),
            nn.Tanh(),
            nn.Linear(num_hidden, num_hidden),
            nn.Tanh(),
        )
        self.z_alpha = nn.Linear(num_hidden, num_mixtures)
        self.z_sigma = nn.Linear(num_hidden, num_mixtures * output_dim)
        self.z_mu = nn.Linear(num_hidden, num_mixtures * output_dim)
        self.num_mixtures = num_mixtures
        self.output_dim = output_dim

    def forward(self, x):
        hidden = self.hidden(x)
        alpha = F.softmax(self.z_alpha(hidden), dim=-1)
        sigma = torch.exp(self.z_sigma(hidden)).view(-1, self.num_mixtures, self.output_dim)
        mu = self.z_mu(hidden).view(-1, self.num_mixtures, self.output_dim)
        return alpha, sigma, mu
```

Notice the *softmax* applied to 𝛼ᵢ (`alpha = F.softmax(self.z_alpha(hidden), dim=-1)`) so that the coefficients sum to 1, and the exponential applied to 𝜎ᵢ (`sigma = torch.exp(self.z_sigma(hidden)).view(-1, self.num_mixtures, self.output_dim)`) to ensure they remain positive, as explained previously.
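To sanity-check the shapes coming out of the forward pass, here's a standalone sketch. The class is repeated in compact form, and the sizes (a batch of 8 five-dimensional inputs, 50 hidden units, 3 mixtures, one output dimension) are illustrative assumptions:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MDN(nn.Module):
    def __init__(self, input_dim, output_dim, num_hidden, num_mixtures):
        super().__init__()
        self.hidden = nn.Sequential(
            nn.Linear(input_dim, num_hidden), nn.Tanh(),
            nn.Linear(num_hidden, num_hidden), nn.Tanh(),
        )
        self.z_alpha = nn.Linear(num_hidden, num_mixtures)
        self.z_sigma = nn.Linear(num_hidden, num_mixtures * output_dim)
        self.z_mu = nn.Linear(num_hidden, num_mixtures * output_dim)
        self.num_mixtures, self.output_dim = num_mixtures, output_dim

    def forward(self, x):
        h = self.hidden(x)
        alpha = F.softmax(self.z_alpha(h), dim=-1)
        sigma = torch.exp(self.z_sigma(h)).view(-1, self.num_mixtures, self.output_dim)
        mu = self.z_mu(h).view(-1, self.num_mixtures, self.output_dim)
        return alpha, sigma, mu

model = MDN(input_dim=5, output_dim=1, num_hidden=50, num_mixtures=3)
x = torch.randn(8, 5)
alpha, sigma, mu = model(x)
# alpha: (8, 3); sigma and mu: (8, 3, 1); each row of alpha sums to 1
```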

## The prediction

Getting predictions from MDNs is a bit of a trick. Here's how to sample from the mixture model:

```python
import itertools

import torch

def get_sample_preds(alpha, sigma, mu, samples=10):
    N, K, T = mu.shape
    sampled_preds = torch.zeros(N, samples, T)
    uniform_samples = torch.rand(N, samples)
    cum_alpha = alpha.cumsum(dim=1)
    for i, j in itertools.product(range(N), range(samples)):
        u = uniform_samples[i, j]
        k = torch.searchsorted(cum_alpha[i], u).item()
        sampled_preds[i, j] = torch.normal(mu[i, k], sigma[i, k])
    return sampled_preds
```

Here is the breakdown:

- `N, K, T = mu.shape`: get the number of data points, mixture components, and output dimensions.
- `sampled_preds = torch.zeros(N, samples, T)`: initialize the tensor that stores the sampled predictions.
- `uniform_samples = torch.rand(N, samples)`: generate uniform random numbers for sampling.
- `cum_alpha = alpha.cumsum(dim=1)`: compute the cumulative sum of the mixture weights.
- `for i, j in itertools.product(range(N), range(samples))`: loop over each combination of data point and sample.
- `u = uniform_samples[i, j]`: get the random number for the current sample.
- `k = torch.searchsorted(cum_alpha[i], u).item()`: find the index of the mixture component selected by `u`.
- `sampled_preds[i, j] = torch.normal(mu[i, k], sigma[i, k])`: sample from the selected Gaussian component.
- `return sampled_preds`: return the tensor of sampled predictions.
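A standalone sketch exercising the sampler on random mixture parameters (the shapes are illustrative assumptions, the function is repeated so the snippet runs on its own, and the `min` guard is an addition to protect against a rare floating-point edge case where `u` exceeds the last cumulative weight):

```python
import itertools

import torch

def get_sample_preds(alpha, sigma, mu, samples=10):
    N, K, T = mu.shape
    sampled_preds = torch.zeros(N, samples, T)
    uniform_samples = torch.rand(N, samples)
    cum_alpha = alpha.cumsum(dim=1)
    for i, j in itertools.product(range(N), range(samples)):
        u = uniform_samples[i, j]
        k = torch.searchsorted(cum_alpha[i], u).item()
        k = min(k, K - 1)  # guard: rounding can leave cum_alpha[-1] just below u
        sampled_preds[i, j] = torch.normal(mu[i, k], sigma[i, k])
    return sampled_preds

N, K, T = 4, 3, 1
alpha = torch.softmax(torch.randn(N, K), dim=-1)  # valid mixture weights
sigma = torch.exp(torch.randn(N, K, T))           # positive deviations
mu = torch.randn(N, K, T)

preds = get_sample_preds(alpha, sigma, mu, samples=10)
# preds has shape (4, 10, 1): 10 sampled predictions per data point
```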

Let's apply the MDN to predict *“Apparent temperature”* on a simple weather dataset. I trained an MDN whose hidden layers have 50 units each, and guess what? It rocks! 🎸

Find the complete code here. Here are some results: