In the year 2024 (when I am writing this post), two classes of algorithms rule the supervised learning world. You have probably guessed that one is neural networks, the driving force behind powerful tools such as ChatGPT and generative models such as Stable Diffusion. Neural networks are intricate, mysterious, often enormous in size, and require vast amounts of data and computational resources. Moreover, we can't really tell why a neural network prefers one representation over another; it's a black box.
At the top of the machine learning hall of fame, the other entry is the family of decision trees. As opposed to neural nets, they are
simple as 1-2-3,
fast as lightning,
require only a small amount of data,
and we can precisely tell how they make decisions.
Let's make them a permanent tool under our belt.
This post is the next chapter of my upcoming Mathematics of Machine Learning book, available in early access.
New chapters are available to the premium subscribers of The Palindrome as well, but 60+ chapters (~650 pages) are available exclusively to early access members.
Petals, sepals, and directed acyclic graphs
As usual, let's start with the Iris dataset, one of my favorite toy problems for classification.
import numpy as np
from sklearn.datasets import load_iris
data = load_iris()
X, y = data["data"], data["target"]
X_with_y = np.hstack([X, y.reshape(-1, 1)])
Recall that Iris has four features: petal width, petal length, sepal width, and sepal length, which are used to classify three species of the iris genus. Here are its features and classes, plotted against each other.
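If you'd like to reproduce such a pairwise plot yourself, here is a minimal sketch using matplotlib; it's not part of the original snippet, just one way to draw the feature-by-feature scatter plots.

import matplotlib.pyplot as plt

n_features = X.shape[1]
fig, axes = plt.subplots(n_features, n_features, figsize=(10, 10))

for i in range(n_features):
    for j in range(n_features):
        ax = axes[i, j]
        if i == j:
            ax.hist(X[:, i], bins=20)  # feature distribution on the diagonal
        else:
            ax.scatter(X[:, j], X[:, i], c=y, s=10)  # color by species
        if i == n_features - 1:
            ax.set_xlabel(data["feature_names"][j])
        if j == 0:
            ax.set_ylabel(data["feature_names"][i])

plt.tight_layout()
plt.show()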
We've looked at several classification algorithms before, but check this out. Following these simple rules, we can solve this problem with 96% accuracy.
Is the petal length less than 2.45 cm? If yes, then it is a setosa.
If not, then is the petal width less than 1.75 cm? If yes, then it is a versicolor.
If not, then it is a virginica.
In pseudocode, this is what we're talking about:
if petal_length < 2.45:
    return "setosa"
elif petal_width < 1.75:
    return "versicolor"
else:
    return "virginica"
Here's the actual code to back up the claims.
from mlfz.classical.tree.cart import ClassificationTree
from mlfz.scores import accuracy
tree = ClassificationTree(max_depth=2)
y_pred = tree.fit_predict(X, y)
accuracy(y_pred, y)
(Run it on your local machine to see that the accuracy is 96%. By the way, mlfz is a machine learning library built and maintained by me, aimed at teaching machine learning from zero.)
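If you don't have mlfz installed, scikit-learn's own decision tree offers a quick cross-check; this is just a sketch with DecisionTreeClassifier, not the implementation we'll study here.

from sklearn.tree import DecisionTreeClassifier

sk_tree = DecisionTreeClassifier(max_depth=2, random_state=42)
sk_tree.fit(X, y)
sk_tree.score(X, y)  # training accuracy, roughly 0.96 as well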
We can visualize the graph represented by tree via the tree.digraph attribute.
Trust me when I say this: decision trees are godlike. Variants like XGBoost are state-of-the-art for a wide array of problems, frequently dominating entire Kaggle leaderboards.
It's time to pop that hood and take a look inside.
Growing a decision tree
Let's cut back to two dimensions and generate a toy classification dataset with two features.
from sklearn.datasets import make_classification
X, y = make_classification(
    n_samples=500,
    n_features=2,
    n_informative=2,
    n_redundant=0,
    n_classes=2,
    n_clusters_per_class=2,
    flip_y=0.1,
    random_state=42
)
Here’s how it looks.
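To reproduce the scatter plot, a minimal matplotlib sketch (again, not part of the original code) is enough:

import matplotlib.pyplot as plt

plt.scatter(X[:, 0], X[:, 1], c=y, s=15)  # color the points by class
plt.xlabel("X[0]")
plt.ylabel("X[1]")
plt.show()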
We are already familiar with the scikit-learn interface, so we can comfortably use the previously introduced ClassificationTree class to illustrate the learning process. (If not, check out an earlier post about this.)
Each decision tree has a crucial parameter that governs how deep it can reach. In our implementation, this is named `max_depth`. First, we'll train a stump, that is, a tree of max_depth=1.
tree = ClassificationTree(max_depth=1)
y_pred = tree.fit_predict(X, y)
If the dataset is simple enough (like in our case), the accuracy can be quite decent.
In [1]: accuracy(y_pred, y)
Out[1]: np.float64(0.852)
Here's the stump.
An inner node in a decision tree (such as the root node above) represents a split, a division of the feature space. Check it out in the figure below.
During training, we find the optimal combination of a feature and a threshold to split the feature space. In this example, the splitting feature is X[0] and the threshold is 0.05.
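As a quick check of this interpretation, we can reproduce the stump's prediction by hand, assuming it predicts the majority class on each side of the split; the result should match (or come very close to) the 0.852 above.

left = X[:, 0] < 0.05                         # samples falling on the left of the split
left_class = np.bincount(y[left]).argmax()    # majority class on the left
right_class = np.bincount(y[~left]).argmax()  # majority class on the right

y_manual = np.where(left, left_class, right_class)
accuracy(y_manual, y)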
(Pseudo-)algorithmically speaking, this is how the best split is found.
X  # the training data, shape (n_samples, n_features)
y  # the training labels

split_feature = 0    # the feature along which we'll split the data
split_threshold = 0  # the threshold to use for the split
best_score = np.inf

for i in features:
    for c in thresholds:
        left_split = X[:, i] < c     # mask of samples going left
        right_split = ~left_split    # everything else goes right
        score = split_score([y[left_split], y[right_split]])
        if score < best_score:
            best_score = score
            split_feature = i
            split_threshold = c
No vectorization yet for simplicity. Hold your horses. We also haven't specified how we plan to rank the splits; that is, how the mysterious split_score function is defined.
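To give you a feel for what such a function might look like, here's a sketch based on weighted Gini impurity, a standard choice in CART; treat it as an illustration rather than the exact definition mlfz uses.

def split_score(splits):
    """Weighted Gini impurity of a candidate split.

    `splits` is a list of label arrays, one per branch; lower is better.
    """
    n_total = sum(len(labels) for labels in splits)
    score = 0.0
    for labels in splits:
        if len(labels) == 0:
            continue
        _, counts = np.unique(labels, return_counts=True)
        p = counts / len(labels)                 # class probabilities
        gini = 1.0 - np.sum(p**2)                # impurity of this branch
        score += (len(labels) / n_total) * gini  # weight by branch size
    return score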
Let's go deeper by performing an additional split; that is, train a model of max_depth=2. Here we go.
tree = ClassificationTree(max_depth=2)
tree.fit(X, y)
Here's the tree representation.
What happened? We kept the first split but replaced the leaf nodes with stumps. This has a definite recursion vibe! Think of each non-leaf node as a filter, passing a subset of the data to the children.
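To make the filter picture concrete, here's a small sketch: split the data at the root's threshold and train a stump on each half. Up to implementation details, this is what the second level of the tree above amounts to.

left = X[:, 0] < 0.05   # the root's split from the previous section

left_stump = ClassificationTree(max_depth=1)
left_stump.fit(X[left], y[left])      # trained only on the left subset

right_stump = ClassificationTree(max_depth=1)
right_stump.fit(X[~left], y[~left])   # trained only on the right subset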
Let's plot the decision boundary to check what happened.
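If you want to draw the decision boundary yourself, here's one way to do it, assuming the tree exposes a predict method (which fit_predict strongly suggests): evaluate the tree on a dense grid and color the plane accordingly.

import matplotlib.pyplot as plt

xx, yy = np.meshgrid(
    np.linspace(X[:, 0].min() - 1, X[:, 0].max() + 1, 200),
    np.linspace(X[:, 1].min() - 1, X[:, 1].max() + 1, 200),
)
grid = np.c_[xx.ravel(), yy.ravel()]
zz = np.asarray(tree.predict(grid)).reshape(xx.shape)  # predicted class at each grid point

plt.contourf(xx, yy, zz, alpha=0.3)
plt.scatter(X[:, 0], X[:, 1], c=y, s=15)
plt.show()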
Keep going and add one more level. Will we keep the existing splits and replace the leaves with stumps?
I won't leave you hanging for long: this is exactly what happens. Check it out.
tree = ClassificationTree(max_depth=3)
tree.fit(X, y)
And here’s the actual tree.
This confirms our belief that decision trees are built recursively:
1. first, we train a stump,
2. then train a stump on each leaf,
3. and then repeat this process until the maximum depth is reached. (Or some other stopping condition is met. We'll see all of them later.)
Here's the pseudocode.
tree  # a ClassificationTree(max_depth)

for node in tree.leaves:
    if node.should_stop:
        # if a stopping condition is met, we find the most likely
        # class for the data this node represents
        node.predicted_class = find_prediction(y)
    else:
        # otherwise, we split the data at this node and keep building;
        # each child only sees its share of the data
        left_split = X[:, node.split_feature] < node.split_threshold
        node.left_child = ClassificationTree(
            max_depth=tree.max_depth - 1
        ).fit(X[left_split], y[left_split])
        node.right_child = ClassificationTree(
            max_depth=tree.max_depth - 1
        ).fit(X[~left_split], y[~left_split])
Now, we'll let it rip and crunch out an enormous tree of eight levels.
Just kidding about the enormous part; it's not that big. (Decision trees shouldn't be big anyway: keeping them small helps prevent overfitting. We'll see other techniques that increase the expressivity of trees.)
Still, it's big enough to show you a few things. Check this out.
tree = ClassificationTree(max_depth=8)
tree.fit(X, y)
Notice that this binary tree is not complete. Leaves can be found on various levels, so the recursion doesn't always reach max depth. Let's see the decision boundaries first; then we'll talk about the stopping conditions.
When should we stop growing the tree along a branch? There are three conditions: if
the split contains only a single class,
the split contains fewer than a given number of data points,
or the maximum depth is reached.
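In code, such a check might look like the following minimal sketch; the helper name and the minimum-sample default are my own choices for illustration, not mlfz's.

def should_stop(y_node, depth, max_depth, min_samples=2):
    # stop if the node is pure, has too few samples, or is too deep
    return (
        len(np.unique(y_node)) == 1
        or len(y_node) < min_samples
        or depth >= max_depth
    )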
Now that we understand in broad strokes how decision trees work, it's time to finally build one. Let's make those pseudocode sketches work!
See you at the next episode.