URL: https://towardsdatascience.com/decision-trees-introduction-intuition-dac9592f4b7f/

Decision Trees: Introduction & Intuition | Towards Data Science


Decision Trees: Introduction & Intuition

Making data-informed decisions with Python

Photo by niko photos on Unsplash peppered with thinking emojis.

This is the first article in a series on Decision Trees. In this post, I introduce decision trees and describe how to grow them using data. The post concludes with example Python code showing how to create and use a decision tree to help make medical prognoses.

Key Points:

  • Decision Trees are a widely used and intuitive machine learning technique for solving prediction problems.
  • We can grow decision trees from data.
  • Hyperparameter tuning can be used to help avoid the overfitting problem.

What are Decision Trees?

Decision trees are a widely used and intuitive machine learning technique. Typically, they are used to solve prediction problems, such as forecasting tomorrow’s weather or estimating an individual’s probability of developing heart disease.

They work through a series of yes-no questions, which are used to narrow down possible choices and arrive at an outcome. A simple example of a decision tree is shown below.

Example decision tree to predict whether I will drink tea or coffee. Image by author.

As shown in the above figure, a decision tree consists of nodes connected by directed edges. Each internal (non-leaf) node corresponds to a conditional statement based on a predictor variable, while each leaf node holds an outcome.

At the top of the decision tree shown above is the root node, which performs the initial split of the data records. Here we evaluate whether it is after 4 PM or not. Each possible response (yes or no) follows a different path in our tree.

If yes, we follow the left branch and end up at a leaf node (also called a terminal node). No further splits are required to determine the outcome at this type of node. In this case, we go with tea over coffee so we can get to bed at a reasonable hour.

Conversely, if it is 4 PM or earlier, we follow the right branch and end up at a so-called splitting node. These nodes further split data records based on conditional statements. From here, we evaluate whether the hours of sleep from last night were more than 6 hours. If yes, we go with tea again, but if no, we go with coffee ☕️.
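The walk-through above can be written directly as nested conditionals. The sketch below is a minimal Python version of the tea-or-coffee tree; the function name choose_beverage and the encoding of time as an hour on a 24-hour clock (so 4 PM is 16) are assumptions made for illustration.

```python
def choose_beverage(hour: float, hours_of_sleep: float) -> str:
    """Hand-coded tea-or-coffee decision tree.

    hour is the time of day on a 24-hour clock (16 = 4 PM); this encoding
    is an assumption made for the sketch.
    """
    # Root node: is it after 4 PM?
    if hour > 16:
        return "tea"  # leaf node: tea, so I can get to bed at a reasonable hour
    # Splitting node (reached at 4 PM or earlier): more than 6 hours of sleep?
    if hours_of_sleep > 6:
        return "tea"  # leaf node
    return "coffee"  # leaf node


print(choose_beverage(hour=17, hours_of_sleep=5))  # tea
print(choose_beverage(hour=10, hours_of_sleep=5))  # coffee
```

Each return statement corresponds to one leaf, and each if statement to one node holding a conditional statement.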

Using Decision Trees

In practice, we rarely use decision trees the way we just did (i.e., by looking at the tree and following along for a particular data record). Rather, we have a computer evaluate the data for us. All we have to do is give the computer the data it needs in the form of a table.

An example of this is shown below. Here we have tabular data with two variables: time of day and hours of sleep from the previous night (blue columns). Then, using the decision tree above, we can assign an appropriate caffeinated beverage to each record (green column).

Example table of input data and the resulting decision tree prediction. Image by author.
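As a minimal sketch of handing the computer a table, the snippet below evaluates the same tree on every row of a pandas DataFrame at once with numpy.where. The four records and the column names hour and hours_of_sleep are hypothetical; they are not the table in the figure.

```python
import numpy as np
import pandas as pd

# Hypothetical input table: time of day (24-hour clock) and hours of sleep last night
df = pd.DataFrame({"hour": [9, 14, 18, 11], "hours_of_sleep": [5, 8, 4, 7]})

# Evaluate the tea-or-coffee tree on every row at once:
# after 4 PM -> tea; otherwise tea if more than 6 hours of sleep, else coffee
df["beverage"] = np.where(df["hour"] > 16, "tea",
                          np.where(df["hours_of_sleep"] > 6, "tea", "coffee"))
print(df)
```

The nested numpy.where calls mirror the tree: the outer call is the root node and the inner call is the splitting node.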

Graphical View of a Decision Tree

Another way to think about decision trees is graphically. (This is the intuition I personally carry around for decision trees.)

Imagine we take the two predictor variables from the example decision tree and visualize them on a 2D plot. We can then represent the decision tree splits as lines that divide our plot into different sections. This allows us to identify the beverage choice simply by looking at which section a data point lies in.

Intuitively, this is all a decision tree does: it partitions the predictor space into sections and assigns a label (or probability) to each section.

Graphical view of decision tree predictions for tea or coffee example. Image by author.
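The same picture can be checked numerically. This sketch, assuming plotting ranges of 0 to 24 for the hour and 0 to 12 for hours of sleep, evaluates the tree over a grid of points and confirms that each split is an axis-parallel line, so the coffee leaf covers a single rectangle.

```python
import numpy as np

# Evaluate the tea-or-coffee tree on a grid covering the 2D predictor space
# (assumed ranges: 0-24 for the hour of day, 0-12 for hours of sleep)
hour_grid, sleep_grid = np.meshgrid(np.arange(0, 24.5, 0.5), np.arange(0, 12.5, 0.5))
labels = np.where(hour_grid > 16, "tea", np.where(sleep_grid > 6, "tea", "coffee"))

# Both splits are thresholds on a single variable, so the boundaries are straight
# lines parallel to the axes and every leaf owns a rectangle of the plane.
# Here the coffee leaf is exactly the rectangle hour <= 16 and sleep <= 6.
coffee_rectangle = (hour_grid <= 16) & (sleep_grid <= 6)
print(np.array_equal(labels == "coffee", coffee_rectangle))  # True
```

To draw the figure itself, one option is plt.contourf(hour_grid, sleep_grid, (labels == "coffee").astype(int)) with matplotlib.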

How to Grow a Decision Tree?

Decision trees are an intuitive way to partition data. However, drawing an appropriate decision tree by hand from data may not be easy. In such cases, we can use machine learning strategies to learn the "best" decision tree for a given dataset.

Data can be used to grow decision trees in an optimization process called training. Training requires a training dataset consisting of predictor variables pre-labeled with target values.

A standard strategy for training a decision tree uses something called Greedy Search. This is a popular technique in optimization, where we simplify a more complicated optimization problem by finding locally optimal solutions instead of globally optimal ones. (I give an intuition for greedy search in a previous article on causal discovery.)

In the case of decision trees, the Greedy Search determines the gain from each possible splitting option and then chooses the one that provides the greatest gain [1,2]. Here "gain" is determined by the split criterion, which can be based on quantities such as Gini impurity, information gain, or mean squared error (MSE). This process is repeated recursively until the decision tree is fully grown.

For example, if using Gini impurity (1 − Σ p_k², where p_k is the fraction of records in a group that belong to class k), data records are recursively split into two groups such that the weighted average impurity of the resulting groups is minimized. This splitting procedure can continue until all data partitions are pure, meaning all data records in a given partition correspond to a single target value.
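To make one step of the greedy search concrete, here is a from-scratch sketch (not scikit-learn's implementation) that scores every candidate threshold on every predictor by the weighted Gini impurity of the two resulting groups and keeps the lowest. The toy records are hypothetical, labeled with the tea-or-coffee tree; growing a full tree would mean applying best_split recursively to each resulting group.

```python
import numpy as np

def gini(labels):
    """Gini impurity 1 - sum_k p_k**2, where p_k is the share of records in class k."""
    if len(labels) == 0:
        return 0.0
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def best_split(X, y):
    """One greedy step: try every feature and threshold, keep the split with the
    lowest weighted average Gini impurity of the two resulting groups."""
    n_samples, n_features = X.shape
    best_feature, best_threshold, best_impurity = None, None, gini(y)  # "no split" baseline
    for j in range(n_features):
        values = np.unique(X[:, j])
        for threshold in (values[:-1] + values[1:]) / 2:  # midpoints between distinct values
            left, right = y[X[:, j] <= threshold], y[X[:, j] > threshold]
            impurity = (len(left) * gini(left) + len(right) * gini(right)) / n_samples
            if impurity < best_impurity:
                best_feature, best_threshold, best_impurity = j, threshold, impurity
    return best_feature, best_threshold, best_impurity

# Toy records: columns are [hour of day, hours of sleep]
X = np.array([[17, 5], [18, 4], [20, 8], [10, 8], [11, 4], [8, 5], [14, 3]])
y = np.array(["tea", "tea", "tea", "tea", "coffee", "coffee", "coffee"])
print(best_split(X, y))  # -> feature 0 (hour), threshold 15.5, impurity 3/14
```

On these records the search picks the time-of-day split at the root, placing the threshold at 15.5, the midpoint between the observed hours 14 and 17; any threshold in that gap separates the data identically.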

Although this means a decision tree can fit its training data perfectly, growing it this far results in overfitting: the trained decision tree would not perform well on data sufficiently different from the training dataset.
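A quick way to see this is to let scikit-learn grow a tree with no constraints on synthetic, deliberately noisy data. The dataset below comes from make_classification with 20% of records assigned a random class (flip_y=0.2), an assumption chosen only to make overfitting visible.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Hypothetical noisy data: flip_y=0.2 assigns a random class to 20% of records
X, y = make_classification(n_samples=1000, n_features=5, n_informative=3,
                           flip_y=0.2, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

# No hyperparameters set: the tree keeps splitting until every leaf is pure
full_tree = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)

train_acc = full_tree.score(X_train, y_train)
test_acc = full_tree.score(X_test, y_test)
print(f"train accuracy = {train_acc:.3f}, test accuracy = {test_acc:.3f}")
```

The fully grown tree scores 1.0 on its own training data because every leaf ends up pure, while its test accuracy is lower; that gap is the overfitting described above.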

Hyperparameter Tuning

One way to combat the overfitting problem is hyperparameter tuning. Hyperparameters are values that constrain the growth of a decision tree.

Common decision tree hyperparameters are the maximum number of splits, the minimum leaf size, and the number of predictor variables considered at each split. The main effect of setting these hyperparameters is to limit the tree’s size, which can help avoid overfitting and improve generalizability.
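In scikit-learn, these hyperparameters map roughly onto DecisionTreeClassifier arguments: max_leaf_nodes caps the number of leaves (a binary tree with L leaves has L − 1 splits), min_samples_leaf sets the minimum leaf size, and max_features limits how many predictor variables are considered at each split. The sketch below, on hypothetical synthetic data, compares a constrained tree with an unconstrained one; the specific values (8 leaves, 20 records, 3 features) are arbitrary choices for illustration.

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Hypothetical noisy synthetic data
X, y = make_classification(n_samples=1000, n_features=5, n_informative=3,
                           flip_y=0.2, random_state=0)

constrained = DecisionTreeClassifier(
    max_leaf_nodes=8,     # at most 8 leaves, i.e. at most 7 splits
    min_samples_leaf=20,  # every leaf must hold at least 20 training records
    max_features=3,       # consider 3 randomly chosen predictors at each split
    random_state=0,
).fit(X, y)
full = DecisionTreeClassifier(random_state=0).fit(X, y)

print("full tree:       ", full.get_n_leaves(), "leaves, depth", full.get_depth())
print("constrained tree:", constrained.get_n_leaves(), "leaves, depth", constrained.get_depth())
```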

Alternative Training Strategies

While the training process I have described above is widely used for decision trees, there are alternative approaches we can use.

Pruning – One such approach is called pruning [3]. In a sense, pruning is the opposite of growing a decision tree. Instead of starting from a root node and recursively adding nodes, we start with a fully grown tree and iteratively remove nodes.

While the pruning process can be done in multiple ways, it commonly drops nodes whose removal does not significantly increase model error. This is an alternative way to avoid overfitting in lieu of hyperparameter tuning to limit tree growth [3].
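scikit-learn implements one pruning method, minimal cost-complexity pruning (the approach described by Breiman et al. [1]), through the ccp_alpha parameter. The sketch below grows a full tree on hypothetical synthetic data, computes its pruning path, and keeps the pruned tree that scores best on a held-out validation split; choosing alpha by validation accuracy is one reasonable option among several.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Hypothetical noisy synthetic data
X, y = make_classification(n_samples=1000, n_features=5, n_informative=3,
                           flip_y=0.2, random_state=0)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.25, random_state=0)

# Effective alphas at which cost-complexity pruning removes subtrees from the
# fully grown tree; the largest one prunes everything down to the root, so drop it
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(X_train, y_train)
ccp_alphas = [max(a, 0.0) for a in path.ccp_alphas[:-1]]  # guard against tiny negative round-off

# Refit one tree per alpha (larger alpha = more pruning) and keep the best on held-out data
trees = [DecisionTreeClassifier(random_state=0, ccp_alpha=a).fit(X_train, y_train)
         for a in ccp_alphas]
full_tree = trees[0]  # alpha = 0 keeps the fully grown tree
pruned_tree = max(trees, key=lambda t: t.score(X_val, y_val))
print(f"fully grown: {full_tree.get_n_leaves()} leaves, "
      f"pruned: {pruned_tree.get_n_leaves()} leaves "
      f"(validation accuracy {pruned_tree.score(X_val, y_val):.3f})")
```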

Maximum Likelihood – We can train a decision tree using the maximum likelihood framework [4]. While this approach is less well known, it rests on a strong theoretical foundation. It allows us to use information criteria such as AIC and BIC to objectively balance the number of parameters in the tree against its performance, which helps side-step the need for extensive hyperparameter tuning.
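The snippet below is not the method of Su et al. [4], which builds regression trees by maximum likelihood; it is only a rough sketch of how an information criterion can pick a tree size. It treats each leaf's class proportions as Bernoulli probabilities, counts one parameter per leaf plus one threshold per split (both assumptions; other parameter counts are possible), and compares AIC = 2k − 2 ln L across depth limits on hypothetical synthetic data.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Hypothetical noisy synthetic data with labels 0/1
X, y = make_classification(n_samples=1000, n_features=5, n_informative=3,
                           flip_y=0.2, random_state=0)

def aic(tree, X, y):
    """AIC = 2k - 2 ln(L) for a fitted binary classification tree.

    Illustrative assumptions: ln(L) is the log-likelihood of the training labels
    under each leaf's class proportions; k = one probability per leaf plus one
    threshold per split.
    """
    proba = tree.predict_proba(X)[np.arange(len(y)), y]  # P(observed label | its leaf) > 0
    log_likelihood = np.sum(np.log(proba))
    n_leaves = tree.get_n_leaves()
    k = n_leaves + (n_leaves - 1)
    return 2 * k - 2 * log_likelihood

scores = {d: aic(DecisionTreeClassifier(max_depth=d, random_state=0).fit(X, y), X, y)
          for d in range(1, 11)}
best_depth = min(scores, key=scores.get)
print("AIC by max_depth:", {d: round(s, 1) for d, s in scores.items()})
print("depth with lowest AIC:", best_depth)
```

Deeper trees always raise the likelihood on the training data, so the 2k penalty is what stops AIC from simply favoring the largest tree.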

Example code: Sepsis Survival Prediction Using a Decision Tree

Now, with a basic understanding of decision trees and how we can develop one from data, let’s dive into a concrete example using Python. Here we will use a dataset from the UCI machine learning repository to train a decision tree to predict whether a patient will survive based on their age, sex, and the number of sepsis episodes they’ve experienced [5,6].

For the decision tree training, we will use the sklearn Python library [7]. The code for this example is freely available in the GitHub repository.

We start by importing some helpful libraries.

# import modules
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn import tree
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.metrics import precision_score, recall_score, f1_score
from imblearn.over_sampling import SMOTE

Next, we load our data from a .csv file and do some data preparation.

# read data from csv
df = pd.read_csv('raw/s41598-020-73558-3_sepsis_survival_primary_cohort.csv')

# look at data distributions
plt.rcParams.update({'font.size': 16})

# plot histograms
df.hist(figsize=(12,8))
Histograms for each variable in dataset. Image by author.

Notice in the bottom-right histogram we have many more alive records than dead. This is called an imbalanced dataset. For a simple decision tree classifier, learning from imbalanced data can lead to the decision tree over-predicting the majority class.

To handle this situation, we can over-sample the minority class to make our data more balanced. One way to do this is the Synthetic Minority Over-sampling Technique (SMOTE), which generates synthetic minority-class records from existing ones. While I will leave further details of SMOTE for a future article, for now it suffices to say that this helps us balance our data and can improve our decision tree model.

# Balance data using SMOTE

# define predictor and target variable names
X_var_names = df.columns[:3]
y_var_name = df.columns[3]

# create predictor and target arrays
X = df[X_var_names]
y = df[y_var_name]

# oversample minority class using smote
X_resampled, y_resampled = SMOTE().fit_resample(X, y)

# plot resulting outcome histogram
y_resampled.hist(figsize=(6,4))
plt.title('hospital_outcome_1alive_0dead\n(balanced)')
Outcome histogram after SMOTE. Image by author.
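For intuition about what SMOTE().fit_resample does, here is a simplified sketch of its core idea: each synthetic record is placed at a random point on the line segment between a minority-class record and one of its k nearest minority-class neighbors. The function smote_sketch is hypothetical and skips details handled by imbalanced-learn's SMOTE; k = 5 mirrors that class's default k_neighbors.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

def smote_sketch(X_minority, n_new, k=5, seed=0):
    """Simplified SMOTE (hypothetical helper): interpolate between random
    minority records and one of their k nearest minority-class neighbors."""
    rng = np.random.default_rng(seed)
    # Ask for k + 1 neighbors because each point's nearest neighbor is itself
    nn = NearestNeighbors(n_neighbors=k + 1).fit(X_minority)
    _, neighbor_idx = nn.kneighbors(X_minority)
    base = rng.integers(0, len(X_minority), size=n_new)            # random minority records
    pick = neighbor_idx[base, rng.integers(1, k + 1, size=n_new)]  # one of their k neighbors
    gap = rng.random((n_new, 1))                                   # position along the segment
    return X_minority[base] + gap * (X_minority[pick] - X_minority[base])

X_min = np.random.default_rng(1).normal(size=(20, 3))
X_new = smote_sketch(X_min, n_new=30)
print(X_new.shape)  # (30, 3)
```

Because SMOTE interpolates, it can produce fractional values for a binary column such as sex; imbalanced-learn's SMOTENC variant is designed for data that mixes categorical and continuous predictors.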

The final step for our data preparation is to split our resampled data into training and testing datasets. The training data will be used to grow the decision tree, and testing data will be used to evaluate its performance. Here we use an 80–20 train-test split.

# create train and test datasets
X_train, X_test, y_train, y_test = train_test_split(
    X_resampled, y_resampled, test_size=0.2, random_state=0)

Now, with our training data, we can create our decision tree. Sklearn makes this easy: with just two lines of code, we have a decision tree.

# Training
clf = tree.DecisionTreeClassifier(random_state=0)
clf = clf.fit(X_train, y_train)

Let’s take a look at the result.

# Display decision tree
plt.figure(figsize=(24,16))

tree.plot_tree(clf)
plt.savefig('visuals/fully_grown_decision_tree.png',facecolor='white',bbox_inches="tight")
plt.show()
Fully grown decision tree. Image by author.

Needless to say, this is a very big decision tree, which can make interpreting the results difficult. However, let’s put that point aside for now and evaluate the model’s performance.

To evaluate performance, we use a confusion matrix, which displays the number of true positives (TP), true negatives (TN), false positives (FP), and false negatives (FN).

I won’t get into a discussion of confusion matrices here, but for now what we want is for the on-diagonal numbers to be big and the off-diagonal terms to be small.

Confusion matrices for fully-grown decision tree. (Left) Training data set. (Right) Testing dataset. Image by author.
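For reference, scikit-learn's confusion_matrix puts true labels on the rows and predicted labels on the columns, both in sorted order. With 0 = dead and 1 = alive, and alive treated as the positive class, the matrix reads [[TN, FP], [FN, TP]], which the hypothetical toy labels below make explicit.

```python
from sklearn.metrics import confusion_matrix

# Hypothetical toy labels: 0 = dead, 1 = alive, matching display_labels=['dead', 'alive']
y_true = [0, 0, 1, 1, 1]
y_pred = [0, 1, 1, 1, 0]

# Rows = true labels, columns = predicted labels, both ordered (0, 1),
# so with 1 as the positive class the layout is [[TN, FP], [FN, TP]]
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
print(tn, fp, fn, tp)  # 1 1 1 2
```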

We can take the numbers from our confusion matrices and compute three different performance metrics: precision, recall, and F1-score. Briefly, precision = TP / (TP + FP), recall = TP / (TP + FN), and the F1-score is the harmonic mean of precision and recall, i.e. F1 = 2 × precision × recall / (precision + recall).

Three performance metrics for fully-grown decision tree. Image by author.
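These formulas can be checked by hand against scikit-learn's metric functions. The toy labels are hypothetical; note that precision_score, recall_score, and f1_score treat label 1 (alive, in this dataset) as the positive class by default, so the metrics reported in this post describe predictions of survival.

```python
from sklearn.metrics import f1_score, precision_score, recall_score

# Hypothetical toy labels: 0 = dead, 1 = alive (the positive class)
y_true = [0, 0, 1, 1, 1]
y_pred = [0, 1, 1, 1, 0]

# Counts for the positive class, read straight from the label pairs
tp = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))  # 2
fp = sum(t == 0 and p == 1 for t, p in zip(y_true, y_pred))  # 1
fn = sum(t == 1 and p == 0 for t, p in zip(y_true, y_pred))  # 1

precision = tp / (tp + fp)                          # 2/3
recall = tp / (tp + fn)                             # 2/3
f1 = 2 * precision * recall / (precision + recall)  # harmonic mean: 2/3

# scikit-learn's defaults (pos_label=1) give the same numbers
print(precision, precision_score(y_true, y_pred))
print(recall, recall_score(y_true, y_pred))
print(f1, f1_score(y_true, y_pred))
```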

The code to generate these results is given below.

# Function to plot a confusion matrix and print precision, recall, and F1-score
def evaluateModel(clf, X, y):

    # confusion matrix
    y_pred = clf.predict(X)
    cm = confusion_matrix(y, y_pred)
    cm_disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['dead', 'alive'])
    cm_disp.plot()

    # print metrics (label 1 = alive is the positive class by default)
    print("Precision = " + str(np.round(precision_score(y, y_pred), 3)))
    print("Recall = " + str(np.round(recall_score(y, y_pred), 3)))
    print("F1 = " + str(np.round(f1_score(y, y_pred), 3)))

# evaluate on the training data (left matrix) and the testing data (right matrix)
evaluateModel(clf, X_train, y_train)
evaluateModel(clf, X_test, y_test)

Hyperparameter Tuning

While the decision tree has decent performance on the data used here, there is still the lingering issue of interpretability. Looking at the decision tree displayed previously, it would be challenging for a clinician to extract any meaningful insights from the decision tree’s logic.

This is where hyperparameter tuning can help. To do this with sklearn, we can simply add input arguments to our decision tree training step.

Here we will try setting max_depth=3.

# train model with max_depth set to 3
clf_tuned = tree.DecisionTreeClassifier(random_state=0, max_depth=3)
clf_tuned = clf_tuned.fit(X_train, y_train)

Now, let’s take a look at the resulting decision tree.

Tuned decision tree so that max_depth=3. Image by author.

Since we constrained the max depth of the tree, we can plainly see what splits are happening here.
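To read the splits as text rather than from a figure, sklearn.tree.export_text prints a fitted tree as nested if/else rules. The sketch below fits a depth-3 tree to hypothetical synthetic records with three predictors named after the ones in this dataset; the column names and the made-up outcome rule are assumptions, so the printed rules say nothing about real sepsis patients.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_text

# Hypothetical synthetic records with three predictors like those in the sepsis data
rng = np.random.default_rng(0)
n = 500
age = rng.integers(18, 90, size=n)
sex = rng.integers(0, 2, size=n)
episodes = rng.integers(1, 5, size=n)
# Made-up outcome rule so the tree has something to learn: older patients die more often
alive = (rng.random(n) > (age - 18) / 100).astype(int)

X = np.column_stack([age, sex, episodes])
clf_tuned = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, alive)

# export_text prints each split as an indented rule, one line per node
rules = export_text(clf_tuned, feature_names=["age", "sex", "episode_number"])
print(rules)
```

On the article's own model, export_text(clf_tuned, feature_names=list(X_var_names)) would label the rules with the dataset's actual column names.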

We again evaluate the model’s performance using confusion matrices and the same three performance metrics as before.

Confusion matrices for hyperparameter tuned decision tree. (Left) Training data set. (Right) Testing dataset. Image by author
3 performance metrics for hyperparameter tuned decision tree. Image by author.

Although the fully grown tree may seem preferable to the hyperparameter-tuned one, this goes back to the discussion on overfitting. Yes, the fully grown tree performs better on the current data, but I would not expect this to hold for other data.

Put another way, although the simpler decision tree has worse performance here, it will likely generalize better than the fully-grown tree.

This hypothesis can be tested by applying each model to the other two datasets available in the GitHub repo.
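When no second dataset is at hand, k-fold cross-validation gives a rough version of the same test: each record is scored by a model that never saw it during training. The sketch below compares a fully grown tree with a max_depth=3 tree on hypothetical synthetic data using 5-fold F1 scores; it says nothing about how either model would fare on the sepsis cohorts in the repo.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

# Hypothetical noisy synthetic data standing in for a real cohort
X, y = make_classification(n_samples=1000, n_features=5, n_informative=3,
                           flip_y=0.2, random_state=0)

models = {
    "fully grown": DecisionTreeClassifier(random_state=0),
    "max_depth=3": DecisionTreeClassifier(max_depth=3, random_state=0),
}

# F1 on each of 5 held-out folds; every record is scored by a model that never trained on it
results = {name: cross_val_score(model, X, y, cv=5, scoring="f1")
           for name, model in models.items()}
for name, scores in results.items():
    print(f"{name}: mean F1 = {scores.mean():.3f} (std {scores.std():.3f})")
```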

YouTube-Blog/decision-tree/decision_tree at main · ShawhinT/YouTube-Blog

Decision Tree Ensembles

While hyperparameter tuning can improve the generalizability of a decision tree, it still leaves something to be desired in terms of performance. In our example above, the tuned decision tree still mislabelled the training data 35% of the time, which is a big deal when life and death are at stake, as in this example.

A popular solution to this problem is to use an ensemble of trees rather than a single decision tree to make predictions. These are called decision tree ensembles and will be the topic of the next article in this series.

10 Decision Trees are Better Than 1


Resources



[1] Classification and Regression Trees by Breiman et al.

[2] Decision trees: a recent overview by Kotsiantis, S. B.

[3] A comparative analysis of methods for pruning decision trees by Esposito et al.

[4] Maximum likelihood regression trees by Su et al.

[5] Dua, D. and Graff, C. (2019). UCI Machine Learning Repository [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California, School of Information and Computer Science. (CC BY 4.0)

[6] Survival prediction of patients with sepsis from age, sex, and septic episode number alone by Chicco & Jurman

[7] Scikit-learn: Machine Learning in Python, Pedregosa et al., JMLR 12, pp. 2825–2830, 2011.


Written By

Shaw Talebi
