
URL: https://towardsdatascience.com/validation-curve-explained-plot-the-influence-of-a-single-hyperparameter-1ac4864deaf8/

Validation Curve Explained - Plot the influence of a single hyperparameter | Towards Data Science



Validation Curve Explained – Plot the influence of a single hyperparameter

Plot with less code! Save time to interpret it!

Photo by CHUTTERSNAP on Unsplash

In Machine Learning (ML), model validation is used to measure the effectiveness of an ML model. A good ML model not only fits the training data well but also generalizes to new input data.

Model hyperparameters play an important role in determining the effectiveness of an ML model. With a grid search or randomized search, we can find the optimal hyperparameter combination for an ML model. However, even without a grid search or randomized search, it is sometimes useful to find out how a single hyperparameter influences the model’s performance on the training and test data. This is where the validation curve comes in. A validation curve plots the training and cross-validation scores against a range of values of a single hyperparameter. By looking at this curve, you can determine whether the model is underfitting, overfitting or just right for a given range of hyperparameter values.

Prerequisites

To understand today’s content, you should have a good understanding of cross-validation. If you don’t, please read the "Using k-fold cross-validation for evaluating a model’s performance" section of the following article I wrote.

k-fold cross-validation explained in plain English

If you’re looking for content about hyperparameter tuning, read the "Using k-fold cross-validation for hyperparameter tuning" section of the above article. This is completely optional; you don’t need that part to understand today’s content.

In addition, familiarity with the Random Forests algorithm is helpful, because today we build a random forest model and plot its validation curve. Take a look at the following article, which explains random forests in plain English.

Random forests – An ensemble of decision trees

Enough intro. Let’s dive into the topic.

Get the data

We use the "heart_disease" dataset. You can download it by clicking this link. The first few rows of the dataset are:

First 5 rows of the "heart_disease" dataset (Image by author)

This dataset contains 303 samples and 13 features. The last column is the target column, which contains 0s (no heart disease) and 1s (heart disease).
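As a minimal sketch of getting the data into the X/y form used below, assuming the downloaded file is a CSV with a header row and the target in the last column (as described above), the split could look like this. The helper name load_features_and_target and the filename heart_disease.csv are hypothetical.

```python
import pandas as pd


def load_features_and_target(csv_path):
    """Split a CSV whose last column is the target into (X, y).

    Hypothetical helper: assumes every column except the last is a
    feature and the last column holds the 0/1 target.
    """
    df = pd.read_csv(csv_path)
    X = df.iloc[:, :-1].to_numpy()  # feature matrix: all but the last column
    y = df.iloc[:, -1].to_numpy()   # target vector: the last column
    return X, y
```

For example, `X, y = load_features_and_target("heart_disease.csv")` should give a 303 x 13 feature matrix if the file matches the description above.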

The method

We build a Random Forest Classifier on the above data. Then, we focus on the max_depth hyperparameter of that classifier. max_depth sets the maximum depth of each tree in the ensemble: a branch stops splitting once it reaches that many levels (it may stop earlier, for example when a node is pure). By varying max_depth from 1 to 10, we plot the cross-validated training and test scores. Then, we decide on the best value for max_depth.
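The following sketch illustrates what max_depth controls. It uses a synthetic stand-in generated with scikit-learn’s make_classification (303 samples and 13 features, so the shape matches the heart_disease data but the values do not) and checks that no tree in a forest fitted with max_depth=3 grows deeper than 3 levels. It also builds the 1-to-10 range of values described above.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Synthetic stand-in with the same shape as the heart_disease data (303 x 13).
X, y = make_classification(n_samples=303, n_features=13, random_state=0)

# The max_depth values explored in this article: 1, 2, ..., 10.
param_range = np.arange(1, 11)

# Fit one forest with a small max_depth and inspect its individual trees.
forest = RandomForestClassifier(max_depth=3, n_estimators=50, random_state=0)
forest.fit(X, y)

# max_depth caps every tree; a tree may end up shallower if its leaves become pure.
depths = [tree.get_depth() for tree in forest.estimators_]
print(param_range)
print(max(depths))
```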

Let’s code

After building the model, I use only 1 line of code to create the validation curve, taking advantage of the Yellowbrick machine learning visualization library. However, Yellowbrick doesn’t come with the Anaconda installer, so you need to install it manually. Open your Anaconda prompt and run the following command.

pip install yellowbrick

If that doesn’t work for you, try it with the --user flag.

pip install yellowbrick --user

or try it with the conda-forge channel.

conda install -c conda-forge yellowbrick

or try it with the DistrictDataLabs channel.

conda install -c districtdatalabs yellowbrick

Now, see the following code.

The output is:

Validation Curve on the max_depth hyperparameter (Image by author)

Let’s explain

When we execute the validation_curve() function, a lot of work happens behind the scenes. Its arguments are:

  • The first argument is a Scikit-learn estimator (here, a Random Forest Classifier). The second and third arguments are X (the feature matrix) and y (the target vector).
  • param_name is the name of the hyperparameter whose influence we want to measure.
  • param_range is a 1-dimensional NumPy array of the values to try. In our example, these must be integers starting from 1, because zero and negative values are not valid for max_depth.
  • cv defines the number of cross-validation folds. Common values are 3, 5, and 10.
  • scoring defines how the model is evaluated. In classification, "accuracy" and "roc_auc" are common choices; in regression, "r2" and "neg_mean_squared_error" are commonly used. Many other evaluation metrics are available. You can find all of them by visiting this link.
  • n_jobs=-1 means that all the cores of the processor are used to run the cross-validation in parallel.
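To see roughly what happens behind the scenes, here is a sketch using scikit-learn’s own validation_curve function, which Yellowbrick’s visualizer builds on. The data is a synthetic stand-in from make_classification with the heart_disease shape, and cv=5 is used here only so that the fold axis is easy to tell apart from the 10 hyperparameter values. The function returns one row of scores per hyperparameter value and one column per fold; the plotted curves are the row means.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import validation_curve

# Synthetic stand-in with the heart_disease shape (303 x 13).
X, y = make_classification(n_samples=303, n_features=13, random_state=0)
param_range = np.arange(1, 11)

# For every max_depth value and every fold: fit on the training folds,
# then score on those training folds and on the held-out fold.
train_scores, test_scores = validation_curve(
    RandomForestClassifier(random_state=0),
    X, y,
    param_name="max_depth",
    param_range=param_range,
    cv=5,
    scoring="accuracy",
    n_jobs=-1,
)

print(train_scores.shape)  # (10, 5): one row per max_depth value, one column per fold

# The two curves in the plot are the means across folds.
for depth, tr, te in zip(param_range, train_scores.mean(axis=1), test_scores.mean(axis=1)):
    print(f"max_depth={depth:2d}  train={tr:.3f}  cv={te:.3f}")
```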

Because the cross-validation procedure runs inside the validation_curve() function, we just pass X and y. We don’t need to split the dataset into X_train, y_train, X_test and y_test; the splitting is done internally based on the number of folds we specified in cv. Using cross-validation here reduces how much the model’s accuracy score depends on a single random data split. If you just use the train_test_split() function without cross-validation, the accuracy score can vary noticeably with the random_state you pass to train_test_split(). With cross-validation, the accuracy is the average over 10 (cv=10) such train/test iterations!
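The following sketch contrasts scores from several single train_test_split() runs with the mean of a 10-fold cross-validation. The data is a synthetic stand-in from make_classification, and test_size=0.25, the five seeds and max_depth=3 are arbitrary choices; since the exact numbers depend on the data, no output is shown.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score, train_test_split

# Synthetic stand-in with the heart_disease shape (303 x 13).
X, y = make_classification(n_samples=303, n_features=13, random_state=0)
model = RandomForestClassifier(max_depth=3, random_state=0)

# Single hold-out splits: the score depends on which rows land in the test set.
holdout_scores = []
for seed in range(5):
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.25, random_state=seed
    )
    holdout_scores.append(model.fit(X_train, y_train).score(X_test, y_test))
print("hold-out accuracies:", np.round(holdout_scores, 3))

# 10-fold cross-validation: every row is tested exactly once,
# and the reported score is the mean over the 10 folds.
cv_scores = cross_val_score(model, X, y, cv=10, scoring="accuracy")
print("10-fold mean accuracy:", round(cv_scores.mean(), 3))
```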

In k-fold cross-validation, we assume that each fold is representative of the whole dataset. If the rows are ordered in some way (for example, grouped by class), folds taken in order can be biased. That is why we first shuffle the dataset using the shuffle function.
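Assuming the shuffle function mentioned above is scikit-learn’s sklearn.utils.shuffle, shuffling X and y together might look like this sketch. An alternative is to let the splitter do the shuffling. Note that when cv is an integer and the estimator is a classifier, scikit-learn uses StratifiedKFold without shuffling, so a shuffling splitter has to be passed explicitly as cv if you want one. The toy arrays below are sorted by class on purpose.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.utils import shuffle

# Toy data whose rows are sorted by class: unshuffled, in-order folds
# would be unrepresentative here.
X = np.arange(20).reshape(10, 2)
y = np.array([0] * 5 + [1] * 5)

# Option 1: shuffle X and y together once, keeping every row paired with its label.
X_shuffled, y_shuffled = shuffle(X, y, random_state=0)

# Option 2: let the splitter shuffle; pass this object as cv=... instead of an integer.
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
for train_idx, test_idx in cv.split(X, y):
    print(test_idx, y[test_idx])
```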

Note: The same functionality of the validation_curve() function can be achieved using the ValidationCurve() class. Here, you first create a visualizer (an object of the ValidationCurve() class), then use the common .fit(X, y) paradigm. Here is the code.

Let’s interpret the validation curve

Now, let’s interpret the validation curve that we plotted previously. By looking at the curve, we can determine whether the model is underfitting, overfitting or just right for a given range of max_depth values. Note that, in the graph, the accuracy on the training folds is labeled "Training Score" and the accuracy on the held-out folds is labeled "Cross-Validation Score".

  • Underfitting: The accuracy scores of both the training and test sets are low. This indicates that the model is too simple or has been regularized too much. At max_depth values of 1 and 2, the random forest model is underfitting.
  • Overfitting: The training accuracy score is very high, but the test accuracy score is low. The model fits the training data very well, but it fails to generalize to new input data. For max_depth values of 4, 5, …, 10, the model is highly overfitted.
  • Just-right: Neither overfitting nor underfitting. At a max_depth value of 3, the model is just right: it fits the training data well and also generalizes to new input data. That’s what we want!
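The three cases can also be written down as a rule of thumb in code. This is a hypothetical helper, not part of Yellowbrick or scikit-learn: the thresholds low_score=0.80 and max_gap=0.05 are arbitrary illustrations, whereas the article reads the curve by eye. It takes the mean training and cross-validation scores per hyperparameter value, such as the train_scores_mean_ and test_scores_mean_ attributes of a fitted Yellowbrick ValidationCurve, and assumes higher scores are better (for example, accuracy).

```python
def diagnose(train_mean, cv_mean, low_score=0.80, max_gap=0.05):
    """Label each hyperparameter value with one of the three cases.

    Hypothetical helper; low_score and max_gap are illustrative
    thresholds, not values taken from the article.
    """
    labels = []
    for tr, cv in zip(train_mean, cv_mean):
        if tr < low_score and cv < low_score:
            labels.append("underfitting")  # both curves are low
        elif tr - cv > max_gap:
            labels.append("overfitting")   # training score far above CV score
        else:
            labels.append("just-right")
    return labels
```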

Be careful: When you use an error-based evaluation metric such as the Mean Squared Error (MSE), lower is better. So overfitting shows up as a very low (not high) training MSE together with a high (not low) test MSE.
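In scikit-learn, and therefore in Yellowbrick, MSE is requested with scoring="neg_mean_squared_error", which returns the negated MSE so that higher scores are always better. Negating the scores back gives MSE values you can read as described above. A sketch on synthetic regression data with a decision tree regressor (both chosen only for illustration):

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.model_selection import validation_curve
from sklearn.tree import DecisionTreeRegressor

# Synthetic regression data, used only for illustration.
X, y = make_regression(n_samples=200, n_features=5, noise=10.0, random_state=0)
param_range = np.arange(1, 11)

train_scores, test_scores = validation_curve(
    DecisionTreeRegressor(random_state=0),
    X, y,
    param_name="max_depth",
    param_range=param_range,
    cv=5,
    scoring="neg_mean_squared_error",
)

# The scorer returns -MSE; negate to get MSE, where lower is better.
train_mse = -train_scores.mean(axis=1)
cv_mse = -test_scores.mean(axis=1)
for d, tr, cv in zip(param_range, train_mse, cv_mse):
    print(f"max_depth={d:2d}  train MSE={tr:.1f}  CV MSE={cv:.1f}")
```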

Be careful: Here, we got an optimal max_depth value of 3. Keep in mind that this is what we got when we consider only the max_depth hyperparameter. When we tune several hyperparameters at once, as in Grid Search or Randomized Search, the optimal max_depth value may not be 3.
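For comparison, a grid search tunes max_depth jointly with other hyperparameters. In this sketch, n_estimators is an arbitrary second hyperparameter and the data is a synthetic stand-in from make_classification with the heart_disease shape, so the best values printed are not the article’s.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

# Synthetic stand-in with the heart_disease shape (303 x 13).
X, y = make_classification(n_samples=303, n_features=13, random_state=0)

# Search max_depth together with a second hyperparameter (illustrative choice).
param_grid = {
    "max_depth": np.arange(1, 11),
    "n_estimators": [50, 100],
}
search = GridSearchCV(
    RandomForestClassifier(random_state=0),
    param_grid=param_grid,
    cv=5,
    scoring="accuracy",
    n_jobs=-1,
)
search.fit(X, y)

# The best max_depth here is chosen jointly with n_estimators.
print(search.best_params_)
print(round(search.best_score_, 3))
```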

Summary

The validation curve is a great tool to have in your machine learning toolkit. It can be used to plot the influence of a single hyperparameter. It should not be used to tune the model; use a grid search or randomized search instead. When creating the curve, use cross-validation. How you interpret the curve depends on the evaluation metric you select. I recommend the Yellowbrick Python library for creating the validation curve. It is easy to use and saves a lot of coding time. Then use that time to interpret it!

Thanks for reading!

This tutorial was designed and created by Rukshan Pramoditha, the Author of Data Science 365 Blog.

Read my other articles at https://rukshanpramoditha.medium.com

2021–03–13

