The problem
Machine learning models are trained on a set of sampled data (the training set). Data scientists use these trained models to make predictions from new data. For example, a recommender system might be trained on a data set of movies people have watched, then used to recommend movies a person might like to watch. Key to the success of machine learning models is their accuracy; recommending the wrong movie, predicting the wrong sales volume, or misdiagnosing a medical image all have moral and financial consequences.
There are two causes of machine learning failure closely related to model training: underfitting and overfitting.
Underfitting is where the model is too simple to correctly represent the data. The symptom is a poor fit to the training data set. This chart shows the problem.
Overfitting is where the model is too complex, meaning it tries to fit the noise instead of just the underlying trends. The symptoms are an excellent fit to the training data, but poor results when the model is exposed to real data or extrapolated. This chart shows the problem. The curve (the red dotted line) was overfit, so when it's extrapolated, it produces nonsense.
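If you want to see this for yourself, here's a minimal sketch, assuming a noisy quadratic as the "true" process (the data and degrees are made up for illustration): a degree-1 polynomial underfits, and a degree-15 polynomial overfits.

```python
# Minimal under/overfitting demo: fit polynomials of increasing degree
# to noisy quadratic data, then extrapolate beyond the training range.
import numpy as np

rng = np.random.default_rng(42)
x = np.linspace(0, 1, 20)
y = 2 * x**2 + 0.1 * rng.standard_normal(20)  # quadratic trend plus noise

for degree in (1, 2, 15):
    coeffs = np.polyfit(x, y, degree)
    train_mse = np.mean((np.polyval(coeffs, x) - y) ** 2)
    prediction = np.polyval(coeffs, 1.5)  # extrapolate past the data
    print(f"degree {degree:2d}: train MSE {train_mse:.4f}, "
          f"prediction at x=1.5: {prediction:10.2f}")
```

The degree-1 fit has the worst training error (underfitting); the degree-15 fit has the best training error, but its extrapolation is typically wildly wrong (overfitting).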
At another company, I saw an analyst try to forecast sales. He used a highly complex data set and an extremely complex model, and it fit the data beautifully. Unfortunately, it gave clearly wrong sales predictions for the next year (e.g., negative sales). He tweaked the model until the predictions looked saner, but as it turned out, they were still way off. He had overfit his data: the original model produced nonsense when extrapolated to the next year, and the tweaked model gave less obviously bad, but still badly wrong, forecasts.
Like all disciplines, machine learning has a set of terminology aimed at keeping outsiders out. Underfitting is associated with bias and overfitting with variance. These are not helpful terms in my view, but we're stuck with them. I'm going to use both the formal terminology (bias and variance) and the more straightforward terms (underfitting and overfitting) for clarity in this blog post.
Let’s look at how machine learning copes with this problem by using regularization.
Regularization
Let's start with a simple linear machine learning model where we have a set of \(m\) features (\(X = \{x_1, x_2, ..., x_m\}\)) and we're trying to model a target variable \(y\) with \(n\) observations. \(\hat{y}\) is our estimate of \(y\) using the features \(X\), so we have:
\[\hat{y}^{(i)} = w \cdot x^{(i)} + b\]
where \(i\) varies from 1 to \(n\) and indexes the observations, and \(w\) is a vector of \(m\) weights.
The cost function measures the difference between our model's predictions and the actual values; here it's the mean squared error (the factor of one half is a convention that simplifies the gradients).
\[J(w, b) = \frac{1}{2n}\sum_{i=1}^{n}( \hat{y}^{(i)} - y^{(i)} )^2\]
To find the model parameters \(w\) and \(b\), we minimize the cost function (typically using gradient descent, Adam, or a similar optimizer). Overfitting manifests itself when some of the \(w\) parameters grow too large.
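To make that concrete, here's a minimal gradient descent sketch for the single-feature case, assuming numpy and made-up data; the gradients follow directly from the cost function above.

```python
# Plain gradient descent on the unregularized cost J(w, b).
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(0, 1, 100)
y = 3.0 * x + 1.0 + 0.05 * rng.standard_normal(100)  # true w=3, b=1

w, b = 0.0, 0.0
learning_rate = 0.1

for _ in range(5000):
    y_hat = w * x + b
    dw = np.mean((y_hat - y) * x)  # dJ/dw
    db = np.mean(y_hat - y)        # dJ/db
    w -= learning_rate * dw
    b -= learning_rate * db

print(f"w = {w:.3f}, b = {b:.3f}")  # should land close to 3 and 1
```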
The idea behind regularization is that it introduces a penalty for adding more complexity to the model, which means keeping the \(w\) values as small as possible. With the right choices, we can make the model fit the 'baseline' without being too distracted by the noise.
As we'll see in a minute, there are several different types of regularization. For the simple machine learning model we're using here, we'll use the popular L2 form of regularization.
Regularization means altering the cost function to penalize more complicated models. Specifically, it introduces an extra term to the cost function, called the regularization term.
\[J(w, b) = \frac{1}{2n}\sum_{i=1}^{n}( \hat{y}^{(i)} - y^{(i)} )^2 + \frac{\lambda}{2n}\sum_{j=1}^{m} w_{j}^{2}\]
\(\lambda\) is the regularization parameter and we set \(\lambda > 0\). Because \(\lambda > 0\), we're penalizing the cost function for higher values of \(w\), so gradient descent will tend to avoid them when we're minimizing. The regularization term is a sum of squared weights; this modified cost function gives ridge regression, the L2 form of regularization.
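As a sanity check on the formula, here's the regularized cost written out in Python. This is a minimal sketch assuming numpy, and ridge_cost is just an illustrative name.

```python
import numpy as np

def ridge_cost(w, b, X, y, lam):
    """L2-regularized cost: the squared-error term plus a penalty on the weights.
    X is an (n, m) matrix of observations, w is an m-vector, lam is lambda."""
    n = len(y)
    y_hat = X @ w + b
    error_term = np.sum((y_hat - y) ** 2) / (2 * n)
    penalty_term = lam * np.sum(w ** 2) / (2 * n)
    return error_term + penalty_term
```

The only change gradient descent sees is an extra \(\frac{\lambda}{n} w_j\) in each weight's gradient, which nudges every weight toward zero on every step.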
You might think that regularization would reduce some of the \(w\) parameters to zero, but in practice, that’s not what happens. It reduces their contribution substantially, but often not totally. You can still end up with a model that’s more computationally complex than it needs to be, but it won’t overfit.
You probably noticed that \(b\) appears in the model but not in the regularization term. That's because in practice, penalizing \(b\) makes very little difference, but it does complicate the math, so I'm ignoring it here to make our lives easier.
Types of regularization
This is the penalty term for ridge regression, the L2 form of regularization (the one we saw in the previous section):
\[\frac{\lambda}{2n}\sum_{j=1}^{m} w_{j}^{2}\]
The L1 form is a bit simpler; it's sometimes known as the lasso, an acronym for Least Absolute Shrinkage and Selection Operator. It penalizes the absolute values of the weights rather than their squares, which tends to drive some weights all the way to zero:
\[\frac{\lambda}{2n}\sum_{j=1}^{m} |w_{j}|\]
Of course, you can combine the L1 and L2 penalties, which gives elastic net regularization. It can be more accurate than L1 or L2 alone, but the computational cost is higher.
A more complex form of regularization is entropy regularization, which is used a lot in reinforcement learning.
For most cases, the L2 form works just fine.
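Here's a minimal comparison sketch using scikit-learn (the data and parameter values are made up for illustration). Notice that Ridge shrinks the noise coefficients but rarely zeroes them, exactly as described above, while Lasso drives them to zero.

```python
import numpy as np
from sklearn.linear_model import ElasticNet, Lasso, Ridge

rng = np.random.default_rng(1)
X = rng.standard_normal((200, 5))
# Only the first two features matter; the other three are pure noise.
y = 3 * X[:, 0] + 2 * X[:, 1] + 0.1 * rng.standard_normal(200)

for model in (Ridge(alpha=1.0), Lasso(alpha=0.1), ElasticNet(alpha=0.1, l1_ratio=0.5)):
    model.fit(X, y)
    print(f"{type(model).__name__:>10}: {np.round(model.coef_, 3)}")
```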
Regularization in more complex machine learning models - dropping out
Linear machine learning models are very simple, but what about logistic regression or the more complex neural networks? As it turns out, regularization works for neural nets and other complex models too.
Overfitting in neural nets can occur through over-reliance on a small number of nodes and their connections. To regularize the network, we randomly drop out nodes during the training process. This is called dropout regularization, and for once, we have a well-named piece of jargon. The net effect of dropout regularization is a "smoother" network that models the baseline and not the noise.
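As one concrete example, here's a minimal sketch of dropout in a small network, assuming TensorFlow/Keras (the layer sizes and dropout rate are arbitrary).

```python
import tensorflow as tf

# During training, each Dropout layer randomly zeroes half of the
# activations passing through it; at inference time it does nothing.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(20,)),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")
```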
Regularization in Python
The scikit-learn package has the functionality you need. In particular, check out the Lasso, Ridge, ElasticNet, and GridSearchCV classes. Dropout regularization in neural networks is a bit more complicated, and in my view it needs a little more standardization across the libraries (which is a fancy way of saying you'll need to check the current state of the documentation).
Seeking \(\lambda\)
Given that \(\lambda\) is an important hyperparameter, how do we find it? The answer is cross-validation: we can either set up a search or step through various \(\lambda\) values to see which gives the lowest error on held-out data. This probably doesn't seem very satisfactory to you, and frankly, it isn't. How to find \(\lambda\) cheaply is an area of active research, so maybe we'll have better answers in a few years' time.
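In scikit-learn, \(\lambda\) is called alpha, and the search can be set up with GridSearchCV. Here's a minimal sketch on made-up data.

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV

rng = np.random.default_rng(2)
X = rng.standard_normal((200, 5))
y = X[:, 0] - 2 * X[:, 1] + 0.1 * rng.standard_normal(200)

# Try lambda values spanning several orders of magnitude,
# scored by 5-fold cross-validation.
search = GridSearchCV(Ridge(), param_grid={"alpha": np.logspace(-3, 3, 13)}, cv=5)
search.fit(X, y)
print("best alpha:", search.best_params_["alpha"])
```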
The bottom line
Underfitting (bias) and overfitting (variance) can kill machine learning models (and models in general). Regularization is a powerful method for preventing these problems. Despite the large equations, it's actually quite easy to implement.