Where does the Mean Squared Error come from?

Tivadar Danka

The simplest and one of the most commonly used loss functions in machine learning is the mean squared error, defined by

$$
\text{MSE} = \frac{1}{N} \sum_{i=1}^{N} \big( y_i - f(x_i) \big)^2,
$$

which is just the mean of the squared distances between the ground truth yᵢ and the prediction f(xᵢ).
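As a quick sanity check, here is a minimal NumPy sketch of the formula above (the toy arrays are made up for illustration):

```python
import numpy as np

def mean_squared_error(y_true, y_pred):
    """Mean of the squared distances between ground truth and prediction."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    return np.mean((y_true - y_pred) ** 2)

# Toy example: three observations and a model's predictions for them.
y_true = np.array([1.0, 2.0, 3.0])
y_pred = np.array([1.1, 1.9, 3.2])
print(mean_squared_error(y_true, y_pred))  # (0.01 + 0.01 + 0.04) / 3 ≈ 0.02
```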

The mean squared error conveys a completely natural idea. Intuitively, the distance between ground truth and prediction should indicate how well our model fits the data. The closer they are, the better the fit.

Beyond this intuition, mean squared error has its roots in probability theory. For one, as the observations yᵢ can be thought of as samples from a probability distribution, MSE is the empirical expected value of

$$
\big( y - f(x) \big)^2.
$$
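To make the expected-value reading concrete, here is a small Monte Carlo sketch (the data-generating process is an assumption, purely for illustration): as the sample grows, the mean of the squared errors approaches the true expected value.

```python
import numpy as np

rng = np.random.default_rng(0)

f = lambda x: 2 * x + 1   # a fixed model
sigma = 0.5               # noise level, chosen for illustration

# Simulate a large sample of (x, y) pairs with y = f(x) + Gaussian noise.
x = rng.uniform(0, 1, size=1_000_000)
y = f(x) + rng.normal(0, sigma, size=x.size)

# The empirical mean of squared errors...
print(np.mean((y - f(x)) ** 2))
# ...approaches the true expected value E[(Y - f(X))^2] = sigma^2 = 0.25.
```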

Minimizing the MSE is the same as minimizing the expected value of the squared error between prediction and ground truth. However, this interpretation does not offer insight into why such a model would explain the dataset. Consider the simple example below.

[Figure: a scatter plot of a dataset where, around each x, the observations y fluctuate with high variance]

Notice that around each x, the observations y fluctuate and exhibit high variance. It seems that a deterministic model y = f(x) is not enough to "explain" the data. A parametrized function like f(x) = ax + b, as in linear regression, can capture the mean of the data but not its variance. So, let's model the data with probability distributions instead of just a function!
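A tiny simulation makes the point (the data-generating process below is made up): drawing several observations at the same x reveals a spread that no deterministic y = f(x) can reproduce.

```python
import numpy as np

rng = np.random.default_rng(42)

for x in [1.0, 2.0, 3.0]:
    # Ten observations at the same x, scattered around the "true" mean 0.5 * x.
    y = 0.5 * x + rng.normal(0, 2.0, size=10)
    print(f"x = {x}: mean(y) = {y.mean():.2f}, std(y) = {y.std():.2f}")

# A deterministic model can at best match mean(y) at each x;
# the std(y) column is the variance a probabilistic model should capture.
```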

Mean squared error as maximum likelihood

To explore this idea, let's go back to a simple special case and suppose that both our input and our output are single numbers. Say we are trying to predict the price of an apartment from its size in square feet.

Suppose that the data x and the observations y come from the random variables X and Y. What we are looking for is the conditional distribution of Y given X. As it is a full probability distribution, it provides a more refined understanding of the underlying process.

How can we approximate it? For one, we can model it with Gaussian distributions, so that

$$
Y \mid X = x \sim \mathcal{N}\big( f(x), \sigma^2 \big),
$$

where f(x) describes the expected value function. To fit such a model, we can turn to maximum likelihood estimation. Here, the likelihood function is

$$
L(f) = \prod_{i=1}^{N} p(y_i \mid x_i),
$$

where p(y ∣ x) denotes the density of Y given X = x.

After a bit of calculation, we obtain

$$
L(f) = \prod_{i=1}^{N} \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left( -\frac{\big( y_i - f(x_i) \big)^2}{2\sigma^2} \right).
$$

This is what we want to maximize. Now we do a trick that should be familiar by now: maximizing a function is the same as maximizing its logarithm, since the logarithm is monotonically increasing.

The purpose of this is to turn the product into a sum:

$$
\log L(f) = -\frac{N}{2} \log(2\pi\sigma^2) - \frac{1}{2\sigma^2} \sum_{i=1}^{N} \big( y_i - f(x_i) \big)^2.
$$

Because we want to maximize the above formula in f, we can omit the terms where it is not present. Moreover, maximizing a function is the same as minimizing its negative. In the end, we are left with

$$
\hat{f} = \arg\min_{f} \sum_{i=1}^{N} \big( y_i - f(x_i) \big)^2.
$$

This is almost the mean squared error! One thing is missing, though: scaling with the number of samples.
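We can verify the equivalence numerically. Below is a sketch (the synthetic data and the linear model family are assumptions for illustration): minimizing the Gaussian negative log-likelihood over a line's parameters lands on the same fit as minimizing the sum of squared errors, and scaling by the number of samples would not change the minimizer either.

```python
import numpy as np
from scipy.optimize import minimize

rng = np.random.default_rng(0)
x = rng.uniform(0, 10, size=200)
y = 3.0 * x + 2.0 + rng.normal(0, 1.0, size=x.size)  # a noisy line

def sum_of_squared_errors(params):
    a, b = params
    return np.sum((y - (a * x + b)) ** 2)

def negative_log_likelihood(params, sigma=1.0):
    a, b = params
    residuals = y - (a * x + b)
    # Gaussian NLL: a constant term plus the sum of squared errors / (2 sigma^2).
    return x.size / 2 * np.log(2 * np.pi * sigma**2) + np.sum(residuals**2) / (2 * sigma**2)

fit_sse = minimize(sum_of_squared_errors, x0=[0.0, 0.0]).x
fit_nll = minimize(negative_log_likelihood, x0=[0.0, 0.0]).x
print(fit_sse, fit_nll)  # the two minimizers agree up to solver tolerance
```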

Taking the average of norms

The mean squared error is the mean of squared errors. Why not just take the sum? The reason is to keep the loss independent of the dataset size. Consider the two scenarios:

  1. A model that fits poorly on a dataset with 10 samples. 
  2. A model that fits excellently on a dataset with 1000 samples.

Without taking the mean, the sum of squared errors can be the same in both cases. Roughly speaking,

$$
\sum_{i=1}^{10} \big( y_i - f(x_i) \big)^2 \approx \sum_{i=1}^{1000} \big( y_i - f(x_i) \big)^2,
$$

which is not something we would like. By scaling, we have

$$
\frac{1}{10} \sum_{i=1}^{10} \big( y_i - f(x_i) \big)^2 \gg \frac{1}{1000} \sum_{i=1}^{1000} \big( y_i - f(x_i) \big)^2,
$$

indicating that the second model fits better. There is an additional benefit: taking the mean also keeps the gradient in check. For large datasets, the gradient of a summed loss can be huge, throwing the optimization off by forcing it to take giant steps.
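This back-of-the-envelope comparison is easy to reproduce (the residuals below are simulated, purely for illustration): the sums of squared errors come out comparable, while the means differ by orders of magnitude.

```python
import numpy as np

rng = np.random.default_rng(1)

# Poor fit on a small dataset: residuals of typical size ~3 on 10 samples.
residuals_small = rng.normal(0, 3.0, size=10)
# Excellent fit on a large dataset: residuals of typical size ~0.3 on 1000 samples.
residuals_large = rng.normal(0, 0.3, size=1000)

print(np.sum(residuals_small**2), np.sum(residuals_large**2))    # sums: both around ~90
print(np.mean(residuals_small**2), np.mean(residuals_large**2))  # means: ~9 vs. ~0.09
```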

Understanding math is a superpower in machine learning.

I am writing a book about it to help you go from high school mathematics to neural networks.
Join me on this journey and let's do this together!