Regularization: Avoiding Overfitting in Machine Learning
How Regularization Works and When to Use It
What is regularization?
Regularization is a technique used in machine learning to address a problem we all face in this space: a model that performs well on training data but poorly on new, unseen data, a problem known as overfitting.
One of the telltale signs that I have fallen into the trap of overfitting (and thus need regularization) is when the model performs great on the training data but terribly on the test data. The reason this happens is that the model learns the intricacies of the training data in too much detail, noise included, which means that it can’t generalize to unseen data.
Regularization is one way to solve this problem and works by penalising a model for having too many parameters with large values. Using a penalty term like this means that the model is encouraged to learn only the most important patterns in the data, and avoid getting bogged down in the noise specific to the training set.
Or, at least, that’s the idea 👀 Let’s dig a bit further to see how it works.
How Does Regularization Work?
Generally speaking, in machine learning, we are trying to learn a model (function) that takes some input features and outputs a number (or vector of numbers in a multi-class classification scenario, for example). And the way we know if the model is doing a good job or not is by calculating some type of error, which is a function of our model output and the true value y. So, if we pass in some input x and get an output h(x), we can compare it with y to calculate the error/cost associated with that input.
However, if we also want to penalise a model that is overly complex, we can add another element to the cost function: a penalty term that adds to the cost when the model has many large weights. As a result, our cost function is now a function of three things: our model output, the true value y, and the parameters of the model.
The penalty term is generally calculated based on the magnitude of the model’s parameters, and it increases the cost as the parameters get larger. This means that the model has to choose what features to give weight to wisely, and to reduce or eliminate the weight on less important features. By doing this, regularization helps to prevent overfitting, and can lead to better performance on new, unseen data.
An Example
An example of a regularization term is the L2 regularization term, which adds a penalty based on the sum of the squares of the model’s parameters. We’ll talk more about this later but for now, let’s just see how it is implemented in the cost function.
Consider a generic mean squared error (MSE) cost function J(θ), which looks like this:

$$J(\theta) = \frac{1}{m} \sum_{i=1}^{m} \left( h(x^{(i)};\theta) - y^{(i)} \right)^2$$

where m denotes the number of training samples, h(x;θ) denotes our model output for an input x for some model h with parameters θ, and y is the true value. Here we see the two things (h(x;θ) and y) I mentioned earlier that the cost function (before regularization) takes as inputs. Based on the output of this cost function, we would then update the model parameters θ to minimize the cost and could do this using an algorithm like stochastic gradient descent.
As an aside on mean squared error: h(x;θ)-y tells us how far away our model’s prediction was from the truth y, and we square it because we want to penalise both predicting too high and too low (and squaring something makes everything positive ➕).
So, now we want to penalise parameters in the model that are large in magnitude, which is why we square each element of θ. We can do this by adding a term to the loss function. Remember the vector of model parameters is denoted by θ. Writing n for the number of parameters, the cost function with an L2 regularization term looks like this:

$$J(\theta) = \frac{1}{m} \sum_{i=1}^{m} \left( h(x^{(i)};\theta) - y^{(i)} \right)^2 + \lambda \sum_{j=1}^{n} \theta_j^2$$

where λ is the regularization parameter that controls how harsh the regularization is and needs to be chosen by you. Adding this term makes the loss J(θ) larger when the model parameter weights are larger. And so, in the optimization of J(θ), smaller parameter values are encouraged.
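To make the L2-regularized cost concrete, here is a minimal sketch in plain Python. It assumes a linear model for h(x;θ) and a penalty of λ times the sum of squared parameters; the function names and the toy data are made up for illustration. The two features are deliberately identical, so two very different weight vectors fit the data equally well and only the penalty tells them apart.

```python
def predict(theta, x):
    """Assumed linear model: h(x; theta) = sum_j theta_j * x_j."""
    return sum(t * xj for t, xj in zip(theta, x))


def mse_cost(theta, X, y):
    """Mean squared error over the m training samples."""
    m = len(X)
    return sum((predict(theta, xi) - yi) ** 2 for xi, yi in zip(X, y)) / m


def l2_regularized_cost(theta, X, y, lam):
    """MSE plus lam times the sum of squared parameters."""
    penalty = lam * sum(t ** 2 for t in theta)
    return mse_cost(theta, X, y) + penalty


# Toy data (made up): both features are identical and y = 2 * x.
X = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
y = [2.0, 4.0, 6.0]

small_theta = [1.0, 1.0]    # fits the data exactly with small weights
large_theta = [5.0, -3.0]   # also fits exactly, but with large weights

for theta in (small_theta, large_theta):
    print(theta, mse_cost(theta, X, y), l2_regularized_cost(theta, X, y, lam=0.5))
```

Both weight vectors have zero MSE, but the regularized cost is 1.0 for the small weights and 17.0 for the large ones, so an optimizer minimizing J(θ) would prefer the smaller weights.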
Tuning λ is important here. If we choose a value of λ too high, then we could make the regularization part of the cost function have a higher influence than the original MSE portion. This would be a big problem since it essentially amounts to sacrificing model performance just to have smaller model weights.
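The effect of λ can be seen in a small sketch. It assumes the simplest possible model, a single weight with h(x;θ) = θx, for which the minimiser of the MSE plus λθ² has the closed form θ = Σxy / (Σx² + mλ). The data values and function names are made up for illustration.

```python
def fit_one_weight(xs, ys, lam):
    """Minimiser of (1/m) * sum((theta*x - y)^2) + lam * theta^2."""
    m = len(xs)
    sxy = sum(x * y for x, y in zip(xs, ys))
    sxx = sum(x * x for x in xs)
    return sxy / (sxx + m * lam)


def mse(theta, xs, ys):
    """Mean squared error of the one-weight model theta * x."""
    return sum((theta * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


# Made-up data, roughly y = 2x.
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.1, 3.9, 6.2, 7.8]

for lam in [0.0, 0.1, 1.0, 10.0, 100.0]:
    theta = fit_one_weight(xs, ys, lam)
    print(f"lambda={lam:>6}: theta={theta:.3f}, training MSE={mse(theta, xs, ys):.3f}")
```

As λ grows, the fitted weight shrinks towards zero and the training error rises, which is the over-regularized situation described above.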
When to use L1 & L2 Regularization?
There are two main types of regularization: L1 regularization and L2 regularization.
L1 regularization, also known as LASSO (Least Absolute Shrinkage and Selection Operator), adds a penalty term to the cost function that is proportional to the sum of the absolute values of the model’s parameters (the example we discussed above used the squares of the model’s parameters). This encourages the model to use only a subset of the available parameters and can result in some parameters being set to exactly zero, effectively removing them from the model (think feature selection here 💭 ).
L2 regularization, also known as Ridge Regression, adds a penalty term proportional to the sum of the squares of the model’s parameters. This encourages the model to keep all of the parameters but to reduce their values, resulting in a model that is less complex and less prone to overfitting.
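One way to see why L1 can zero out parameters while L2 only shrinks them is to look at what each penalty does to a single weight in isolation. The sketch below is a simplified per-weight picture, not a full LASSO or ridge solver: it assumes we minimise 0.5·(u − w)² plus the penalty on u, which gives soft-thresholding for L1 and a proportional rescaling for L2. The function names, weights, and strength t are made up for illustration.

```python
def l1_penalty(theta, lam):
    """lam times the sum of absolute parameter values."""
    return lam * sum(abs(t) for t in theta)


def l2_penalty(theta, lam):
    """lam times the sum of squared parameter values."""
    return lam * sum(t * t for t in theta)


def l1_shrink(w, t):
    """Soft-thresholding: minimiser over u of 0.5*(u - w)^2 + t*|u|."""
    if w > t:
        return w - t
    if w < -t:
        return w + t
    return 0.0


def l2_shrink(w, t):
    """Minimiser over u of 0.5*(u - w)^2 + t*u^2."""
    return w / (1.0 + 2.0 * t)


weights = [3.0, -0.2, 0.05, -1.5]
t = 0.25

after_l1 = [l1_shrink(w, t) for w in weights]
after_l2 = [l2_shrink(w, t) for w in weights]

print("L1:", after_l1)  # small weights become exactly 0.0
print("L2:", after_l2)  # every weight is scaled down, none reaches zero
```

Under these assumptions, the L1 step removes the two small weights entirely, while the L2 step divides every weight by the same factor and keeps all of them.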
When is Regularization Useful?
In general, regularization is most effective when the training data is limited or when the model has a high complexity, such as a deep neural network with many parameters. In these cases, the model is more likely to overfit, and regularization can help to prevent this by encouraging the model to learn only the most important patterns in the data.
Also, since regularization (L1 in particular) encourages a model to rely on only a subset of features, it can improve interpretability and lead to interesting insights. For example, applying regularization in the context of linear regression can have the added benefit of highlighting the most important predictor variables as those remaining with the largest weights in the model.
Challenges with Regularization
One challenge with regularization is choosing the right regularization parameter, usually denoted λ. This parameter controls the strength of the regularization, and as we mentioned before, it needs to be set carefully. The regularization component should be weighted enough to be useful, but not so much that it overpowers the actual error part of your cost function. Finding the right value for λ can be challenging, and it requires experimentation using a validation set. ⚖️
Another challenge with regularization is the extra computation it brings. The penalty itself is cheap to evaluate, since it is just a sum over the parameters that is added to the cost function at each iteration of training. The larger cost usually comes from tuning: every candidate value of λ means training the model again, which adds up quickly for large models with many parameters. L1 regularization can also be harder to optimize than L2, because the absolute value is not differentiable at zero.
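The search for λ on a validation set can be sketched as a simple loop over candidate values. This is a minimal illustration, assuming a one-weight model h(x;θ) = θx with the L2 closed-form fit θ = Σxy / (Σx² + mλ); the data is made up so that the training targets are slightly inflated by noise, and the candidate grid is an arbitrary choice.

```python
def fit_one_weight(xs, ys, lam):
    """Minimiser of (1/m) * sum((theta*x - y)^2) + lam * theta^2."""
    m = len(xs)
    sxy = sum(x * y for x, y in zip(xs, ys))
    sxx = sum(x * x for x in xs)
    return sxy / (sxx + m * lam)


def mse(theta, xs, ys):
    """Mean squared error of the one-weight model theta * x."""
    return sum((theta * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


# Made-up data: the underlying relation is y = 2x, training targets are noisy.
x_train, y_train = [1.0, 2.0, 3.0], [2.5, 4.6, 6.9]
x_val, y_val = [1.5, 2.5, 3.5], [3.0, 5.0, 7.0]

candidates = [0.0, 0.01, 0.1, 1.0, 10.0, 100.0]

results = []
for lam in candidates:
    theta = fit_one_weight(x_train, y_train, lam)    # fit on training data only
    results.append((mse(theta, x_val, y_val), lam, theta))  # score on validation data

best_val_mse, best_lam, best_theta = min(results)
print(f"best lambda={best_lam}, theta={best_theta:.3f}, validation MSE={best_val_mse:.4f}")
```

The key design point is that λ is never chosen using the training error, which would always favour λ = 0; a grid spaced on a log scale, as here, is a common starting point.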
Implementation Tips
Despite these challenges, regularization can be a powerful tool for improving the performance of machine-learning models and preventing overfitting. The following are a few key takeaways to keep in mind when implementing regularization:
- Choose the right type of regularization. Both L1 and L2 help to prevent overfitting; L1 is the better fit when you also want feature selection, while L2 is a sensible choice when you want to keep all of the features but shrink their weights.
- Set the regularization parameter, λ. This parameter controls the strength of the regularization, and it needs to be set carefully in order to achieve the desired balance between fitting the training data and keeping the model simple. It may be necessary to experiment with different values of λ in order to find the best value for your model.
- Incorporate regularization into your cost function. In order to use regularization, the regularization term needs to be added to the cost function that the model is optimizing. This can be done by simply adding the regularization term to the existing cost function or by using a pre-built regularization function provided by a machine learning library. There is often no need to reinvent the wheel!
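As a rough sketch of the last tip, the snippet below adds an L2 term to a hand-written batch gradient descent loop for a linear model. It assumes the cost is the MSE plus λ times the sum of squared weights, so the penalty contributes 2λθ_j to each gradient component. The function names, data, learning rate, and step count are all illustrative choices, and a library routine would normally handle this for you.

```python
def gradient_step(theta, X, y, lam, lr):
    """One batch gradient-descent step on MSE + lam * sum(theta_j^2)."""
    m = len(X)
    preds = [sum(t * xj for t, xj in zip(theta, xi)) for xi in X]
    new_theta = []
    for j, t in enumerate(theta):
        # Gradient of the MSE term with respect to theta_j.
        grad_mse = (2.0 / m) * sum((p - yi) * xi[j] for p, xi, yi in zip(preds, X, y))
        # Gradient of the L2 penalty with respect to theta_j.
        grad_penalty = 2.0 * lam * t
        new_theta.append(t - lr * (grad_mse + grad_penalty))
    return new_theta


def train(X, y, lam, lr=0.05, steps=2000):
    """Run gradient descent from all-zero weights."""
    theta = [0.0] * len(X[0])
    for _ in range(steps):
        theta = gradient_step(theta, X, y, lam, lr)
    return theta


# Made-up data: y depends mostly on the first feature.
X = [[1.0, 0.5], [2.0, -0.3], [3.0, 0.8], [4.0, -0.1]]
y = [2.0, 4.1, 5.9, 8.2]

plain = train(X, y, lam=0.0)
ridge = train(X, y, lam=1.0)

print("no regularization:", [round(t, 3) for t in plain])
print("L2, lambda = 1.0: ", [round(t, 3) for t in ridge])
```

The only change needed to regularize the training loop is the extra `grad_penalty` term; with λ = 0 the loop reduces to ordinary least-squares gradient descent.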
Conclusion
I hope that this article has given you a better understanding of how regularization can be a valuable tool to have in your machine learning toolbox. It is by no means a silver bullet, and it may not work in all situations, but if a model is experiencing overfitting, it can often be a good place to start.