TensorFlow Decision Forests – Train your favorite tree-based models using Keras
Yes, you read that right – the same API for both Neural Networks and tree-based models!
In this article, I will briefly describe what decision forests are and how to train tree-based models (such as Random Forest or Gradient Boosted Trees) using the same Keras API as you would normally use for Neural Networks. Let’s dive into it!
What is a Decision Forest?
I will get straight to the point: it is not another fancy algorithm like XGBoost, LightGBM, or CatBoost. Decision forests are simply a family of machine learning algorithms built from many decision trees. That includes many of your favorites like Random Forest and various flavors of gradient-boosted trees.
TensorFlow Decision Forests (TF-DF)
Until now there was a clear split between machine and deep learning libraries. For traditional machine learning, use scikit-learn. For deep learning, use TensorFlow/PyTorch. And if you want to be really nit-picky, there is a neural_network module in scikit-learn and it contains Multi-layer Perceptrons. However, that is more of a fun fact and not something that is used extensively in practice.
Ending the small digression, TensorFlow Decision Forests takes the first step to bridge the gap between the frameworks. The following paragraph from the library’s release note does a great job at concisely describing the main idea:
TF-DF is a collection of production-ready state-of-the-art algorithms for training, serving and interpreting decision forest models (including random forests and gradient boosted trees). You can now use these models for classification, regression and ranking tasks – with the flexibility and composability of the TensorFlow and Keras.
That sounds really nice, but why bother with having another library to train the same models? Good question. Here are some advantages of having different classes of models available in the same API:
- if you already had an architecture built for some project using neural networks in TensorFlow, now you can easily experiment with an entirely different class of models. The tree-based models are often favored over NNs (both in terms of performance and speed), especially when working with structured tabular data.
- you can deploy various models using the same set of tools, for example, TensorFlow Serving.
- you can use a variety of interpretability tools and techniques available for the tree-based models.
- using the library, it is easy to combine neural networks and decision forests, for example, a tree-based model can consume the output of the Neural Network.
- under the hood, TF-DF is a wrapper around Yggdrasil Decision Forests, a C++ framework containing many of the decision forest algorithms.
- the TF-DF implementation of the models can not only do classification and regression, but can also solve ranking problems.
Decision Forests in practice
Setup
Unfortunately, at the moment of writing this article, the setup is not as easy as installing and importing the library. As of August 2021, it is not possible to install the library on macOS and Windows. The authors of the library are currently working on making it possible.
In the meantime, the easiest way to play around with TF-DF is to use Google Colab. We need to start off by installing two libraries: tensorflow_decision_forests and wurlitzer.
Note: If you are having trouble installing the library on Colab, downgrade TensorFlow to tensorflow==2.5.1.
After doing so, we can proceed as usual and import the required libraries.
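As a rough sketch of the setup step: in a Colab cell the installation is typically `!pip install tensorflow_decision_forests wurlitzer`. The small helper below (`missing_packages` is my own illustrative name, not part of any library) checks which of the required packages still need installing, and only runs the imports when everything is available.

```python
import importlib.util

# Packages this walkthrough relies on (pandas is assumed for the DataFrames).
REQUIRED = ("tensorflow", "tensorflow_decision_forests", "wurlitzer", "pandas")


def missing_packages(names=REQUIRED):
    """Return the subset of `names` that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_packages()
if missing:
    print("Install first: pip install " + " ".join(missing))
else:
    import pandas as pd
    import tensorflow as tf
    import tensorflow_decision_forests as tfdf
```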
Data
We will once again use the Pokémon data set. It contains all the Pokémon from eight generations, including their types, battle stats, and some more meta information. Previously, we have explored the correlations between the features of this data set using the new correlation metric – 𝜙k. We can load the data as follows.
We only keep some columns that might be relevant for building a machine learning model. As for the goal itself, there are a few options here. We could try to build a classifier to recognize if a given Pokémon is legendary, or maybe solve a multi-class problem to recognize the generation from which given Pokémon comes. However, based on the correlation analysis it might also be interesting to try a regression problem and predict the Pokémon’s attack based on the rest of its stats and information.
As the first step, we split the data into training and test sets, while stratifying by generation.
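In practice this is a single call to scikit-learn's `train_test_split` with the `stratify` argument. To make it clear what stratifying actually does, here is a minimal standard-library sketch that splits row indices so that each generation keeps roughly the same share in both sets. The 20% test size, the seed, and the toy `generations` list are assumptions for illustration only.

```python
import random
from collections import defaultdict


def stratified_split(strata, test_size=0.2, seed=42):
    """Split row indices so each stratum contributes ~test_size of its rows to the test set."""
    by_stratum = defaultdict(list)
    for idx, value in enumerate(strata):
        by_stratum[value].append(idx)

    rng = random.Random(seed)
    train_idx, test_idx = [], []
    for value in sorted(by_stratum):
        indices = by_stratum[value]
        rng.shuffle(indices)
        # Keep at least one row per stratum in the test set.
        n_test = max(1, round(len(indices) * test_size))
        test_idx.extend(indices[:n_test])
        train_idx.extend(indices[n_test:])
    return sorted(train_idx), sorted(test_idx)


# Toy stand-in for the generation column: 10 rows of gen 1, 5 of gen 2, 5 of gen 3.
generations = [1] * 10 + [2] * 5 + [3] * 5
train_idx, test_idx = stratified_split(generations)
```

With a pandas DataFrame, the resulting indices can be used as `df.iloc[train_idx]` and `df.iloc[test_idx]`.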
Training Decision Forests models
A very nice and handy feature of TF-DF is that it does not require any preprocessing of the data. It automatically handles numerical and categorical features, as well as missing values. Our data set has missing values in the secondary_type feature, as not all Pokémon have two types.
To work with Decision Forests’ models, we just need to convert pandas DataFrames into TensorFlow Datasets. While doing so, we indicate which column contains the target and what kind of task we are working on. The default is classification, but for our use case, we need to switch to regression.
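A minimal sketch of the conversion, assuming the target column is called `attack` (the exact column name in the data set is my assumption). The helper name `to_regression_dataset` is hypothetical; the library is imported inside the function so the snippet can be read and loaded even where TF-DF is not installed.

```python
def to_regression_dataset(df, label="attack"):
    """Convert a pandas DataFrame into a TensorFlow Dataset for a regression task."""
    import tensorflow_decision_forests as tfdf

    return tfdf.keras.pd_dataframe_to_tf_dataset(
        df,
        label=label,                       # column holding the target
        task=tfdf.keras.Task.REGRESSION,   # the default would be classification
    )


# Usage, with train_df / test_df being the two DataFrames from the split:
# train_ds = to_regression_dataset(train_df)
# test_ds = to_regression_dataset(test_df)
```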
Then, we instantiate the Decision Forest model. For this example, we use a Random Forest model (currently, we could also use Gradient Boosted Trees or CART). We complete the optional compilation step in which we add some additional metrics of interest – MSE and MAPE. Then, we fit the model to the training data. Unless specified otherwise, the training will use all the features available in the data set.
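The three steps described above (instantiate, compile, fit) could look roughly like this. It is a sketch rather than the exact code from the original notebook, and `train_random_forest` is just an illustrative wrapper name.

```python
def train_random_forest(train_ds):
    """Fit a Random Forest regressor on a TensorFlow Dataset and return the model."""
    import tensorflow_decision_forests as tfdf

    # Random Forest for a regression task, default hyperparameters.
    model = tfdf.keras.RandomForestModel(task=tfdf.keras.Task.REGRESSION)

    # Compiling is optional; here it only registers extra evaluation metrics.
    model.compile(metrics=["mse", "mape"])

    # All features in the data set are used unless specified otherwise.
    model.fit(x=train_ds)
    return model


# Usage:
# model_rf = train_random_forest(train_ds)
```

Swapping `RandomForestModel` for `GradientBoostedTreesModel` or `CartModel` is all it takes to try the other model types.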
In the logs, we can see some useful information about the features used for training, including summary statistics, % of missing values, etc. Further on in the logs (not included for brevity), we also see the hyperparameters used for training the model (default ones), how the RMSE (the default metric) changes after fitting X trees, and the out-of-bag score.
Evaluating the results
The next step is to evaluate the performance of the fitted model on the test set. To do so, we use the evaluate method.
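A small sketch of the evaluation step. It relies on the standard Keras `evaluate` method with `return_dict=True` and assumes the model was compiled with the `mse` and `mape` metrics; the RMSE is simply the square root of the reported MSE. The helper name `report_scores` is my own.

```python
import math


def report_scores(model, dataset):
    """Evaluate a fitted model and return the raw scores plus formatted lines."""
    scores = model.evaluate(dataset, return_dict=True)
    lines = [
        f"MSE: {scores['mse']:.2f}",
        f"RMSE: {math.sqrt(scores['mse']):.2f}",  # RMSE derived from MSE
        f"MAPE: {scores['mape']:.2f}",
    ]
    return scores, lines


# Usage:
# scores, lines = report_scores(model_rf, test_ds)
# print(scores)
# print("\n".join(lines))
```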
This returns the following scores:
{'loss': 0.0, 'mse': 431.4039611816406, 'mape': 24.62486457824707}
MSE: 431.40
RMSE: 20.77
MAPE: 24.62
Please note that the results would be the same if we created predictions using the predict method and then manually calculated the scores, for example, using the functions from scikit-learn.
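To illustrate the manual route, here is a plain-Python sketch of the same metrics computed from true values and predictions. One thing to keep in mind when comparing: Keras reports MAPE as a percentage (for example 24.62), while scikit-learn's `mean_absolute_percentage_error` returns a fraction (roughly 0.2462 for the same predictions). The function name is illustrative, and the sketch assumes no target value is zero.

```python
import math


def regression_scores(y_true, y_pred):
    """Compute MSE, RMSE and MAPE (in percent, as Keras reports it)."""
    errors = [t - p for t, p in zip(y_true, y_pred)]
    mse = sum(e * e for e in errors) / len(errors)
    # Assumes every true value is non-zero (attack stats are positive).
    mape = 100 * sum(abs(e) / abs(t) for e, t in zip(errors, y_true)) / len(errors)
    return {"mse": mse, "rmse": math.sqrt(mse), "mape": mape}


# Usage (the "attack" column name is an assumption):
# y_pred = model_rf.predict(test_ds).ravel()
# regression_scores(test_df["attack"], y_pred)
```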
Note: We will not spend much time on analyzing the model’s performance, as the goal of this article is just to showcase a new library and not to achieve the best possible performance for the task at hand.
Interpreting the models
TF-DF provides some nice functionalities to interpret the models. There are many reasons why interpretability should be near the top of the things we should analyze after we have fitted a model. Some of them include the ability to understand the model’s decisions and potentially explain them to stakeholders, or debug the model when we see "weird" predictions. I have covered the topic of Random Forest’s feature importance extensively in another article.
We start by plotting a single decision tree from the Random Forest. We can do so using the following snippet:
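A sketch of what such a snippet can look like, using the Colab plotting helper that ships with TF-DF. The wrapper name `plot_first_tree` is hypothetical; `tree_idx=0` selects the first tree and `max_depth=3` limits the plot to the top three levels.

```python
def plot_first_tree(model, max_depth=3):
    """Render the first tree of a fitted TF-DF model inside a Colab notebook."""
    import tensorflow_decision_forests as tfdf

    return tfdf.model_plotter.plot_model_in_colab(
        model,
        tree_idx=0,           # which tree of the forest to draw
        max_depth=max_depth,  # how many levels to show
    )


# Usage:
# plot_first_tree(model_rf)
```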
We have plotted only the first 3 levels, as the default settings of RandomForestModel allow for a maximum depth of 16. What is definitely interesting to see is that categories are grouped together. For example, the primary_type at the top of the image shows a split based on 4 possible groups, not like in the scikit-learn implementation which would show a single, one-hot encoded primary_type.
This kind of plot might be especially interesting for a CART model or for inspecting the first tree of a GBT model.
We can dive even deeper into the feature importances and the model’s structure using the summary method. However, the output can be a bit overwhelming, so we will access the very same elements one step at a time.
We start by inspecting the features used in the model.
# inspect the features used in the model
model_rf.make_inspector().features()
Then, we move forward to feature importances. TF-DF provides a few different ways of calculating the feature importance which depend on the type of the model. The first one available for the Random Forest is based on the MEAN_MIN_DEPTH. The minimal depth for a feature in a tree corresponds to the depth of the node which splits the observations on that feature and is the closest to the root of the tree. A low value indicates that a lot of observations are categorized based on this feature.
The smaller the mean minimal depth, the more important the variable. From the image above we can see that HP (health points) seems to be the most important feature, at least when using this criterion.
The next available metric is NUM_AS_ROOT, which simply indicates how many times the given feature was the root of the tree. In this case, the higher the better, and HP was again the most important one. A quick sanity check shows that there was a total of 300 trees in the RF model, which is indeed the default hyperparameter value.
Lastly, there is the NUM_NODES metric. It shows how many times a given feature was used as a node for splitting the observations in a tree. Naturally, a feature can be used multiple times in a tree, so those do not sum up to the total number of trees. Using this metric, primary_type was the most important feature.
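In TF-DF these values come from `model_rf.make_inspector().variable_importances()`. To make the three definitions concrete, here is a toy, library-free sketch that computes them for a hand-written forest of three tiny trees. The forest and its feature names are made up for illustration, and averaging the minimal depth only over the trees in which a feature appears is a simplifying assumption of mine, not necessarily what the library does.

```python
from collections import Counter, defaultdict

# A toy tree is either None (a leaf) or a tuple (feature, left_subtree, right_subtree).
forest = [
    ("hp", ("defense", None, None), ("primary_type", None, None)),
    ("hp", ("primary_type", ("defense", None, None), None), None),
    ("primary_type", ("hp", None, None), ("primary_type", None, None)),
]


def walk(tree, depth=0):
    """Yield (feature, depth) for every split node of a tree."""
    if tree is None:
        return
    feature, left, right = tree
    yield feature, depth
    yield from walk(left, depth + 1)
    yield from walk(right, depth + 1)


def importances(forest):
    """Return (mean_min_depth, num_as_root, num_nodes) for a toy forest."""
    num_as_root = Counter(tree[0] for tree in forest)
    num_nodes = Counter()
    min_depths = defaultdict(list)
    for tree in forest:
        shallowest = {}
        for feature, depth in walk(tree):
            num_nodes[feature] += 1
            shallowest[feature] = min(depth, shallowest.get(feature, depth))
        for feature, depth in shallowest.items():
            min_depths[feature].append(depth)
    mean_min_depth = {f: sum(d) / len(d) for f, d in min_depths.items()}
    return mean_min_depth, num_as_root, num_nodes


mean_min_depth, num_as_root, num_nodes = importances(forest)
```

In this toy forest, the root counts add up to the number of trees, while the node counts do not, which mirrors the sanity check described above.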
We can also have a look at the out-of-bag RMSE. To get the value, we can use the following snippet:
# get the out-of-bag score
model_rf.make_inspector().evaluation()
This returns the following:
Evaluation(num_examples=718, accuracy=None, loss=None, rmse=20.664861230679822, ndcg=None, aucs=None)
Lastly, we can also dive a bit deeper and see how the out-of-bag RMSE evolved with the number of trees trained. To do so, we need to access the training logs, also using the make_inspector method.
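A sketch of how the logs can be turned into a curve: each entry returned by `training_logs()` carries the number of trees and an evaluation object like the one shown above. The second helper is my own addition for judging the size/quality trade-off, and its 1% tolerance is an arbitrary assumption.

```python
def oob_rmse_curve(training_logs):
    """Turn TF-DF training logs into a list of (num_trees, out-of-bag RMSE) pairs."""
    return [(log.num_trees, log.evaluation.rmse) for log in training_logs]


def smallest_forest_within(curve, tolerance=0.01):
    """Smallest number of trees whose RMSE is within `tolerance` (relative) of the final RMSE."""
    final_rmse = curve[-1][1]
    for num_trees, rmse in curve:
        if rmse <= final_rmse * (1 + tolerance):
            return num_trees


# Usage:
# curve = oob_rmse_curve(model_rf.make_inspector().training_logs())
# smallest_forest_within(curve)
# The pairs in `curve` can be passed to any plotting library to draw the curve.
```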
From the image above we see that the score stabilized after around 100 trees.
In general, the training logs show the quality of the model as the model keeps on growing the trees. We can use this information to evaluate the balance between model size and model quality.
There is much more information available in the output of the summary method and I highly recommend giving it a try.
TensorFlow vs. scikit-learn
Will TensorFlow Decision Forests replace the good, old scikit-learn? Probably not. The main reasons in favor of that statement are:
- it is not as straightforward as scikit-learn,
- less out-of-the-box functionality for tuning the models (think Grid Search),
- the documentation is nowhere near as comprehensive,
- so far, there are not that many models available (there are more available in Yggdrasil Decision Forests, which will probably be added soon).
I would say that the Decision Forests will be useful for those who already have a project built in TensorFlow and would like to easily compare their current neural network solutions to an entirely different class of models. Then, they can easily do so using the new library, while keeping their entire architecture pretty much the same.
Takeaways
- Decision Forests are a family of algorithms built from many decision trees,
- TensorFlow Decision Forests allows us to train Random Forest or Gradient Boosted Trees using the familiar TensorFlow API,
- While a lot of functionality is provided in the library, it is probably not enough to ditch scikit-learn in favor of the new library. It’s more of an additional opportunity for those who already have a TensorFlow architecture set up for some project.
You can find the code used for this article on my GitHub. Also, any constructive feedback is welcome. You can reach out to me on Twitter or in the comments.