How to implement a decision tree
Beating neural networks one if-else condition at a time, part 2
Last time, we stopped the decision tree train right when we finally understood how to train one.
Now, we continue straight where we left off. In today’s post, we’ll have our very own implementation of decision trees ready!
Note: this post is a direct continuation of the How to grow a decision tree post, so read that first to understand this one!
Scoring a split
Before diving into actually growing a decision tree, let's take a step back and examine a split in detail. Each split is determined by a feature and a threshold, for example:
left_leaf = X[:, 0] < 0
right_leaf = ~left_leaf  # equivalent to right_leaf = X[:, 0] >= 0
Each split creates two leaves in our decision tree. In our case, these are represented by the boolean arrays `left_leaf` and `right_leaf`. These boolean arrays tell us whether the first feature (represented by `X[:, 0]`) is larger or smaller than the threshold (0 here). Let's plot this.
How do you score the quality of such a split? Let's focus on the leaves. In general, if pᵢ denotes the relative frequency of the i-th label in the leaf, then the so-called Gini impurity, defined by the formula

G = 1 − Σᵢ pᵢ²,

is a good measure of how "pure" the label distribution of the leaf is. Think about it: if the split is perfect, the relative frequencies will be zero for all but one class. The relative frequency of the only class is one, making the Gini impurity zero. On the other hand, if the label distribution is completely even across k classes, the Gini impurity reaches its maximum of 1 − 1/k, which gets closer to one as the number of classes grows.
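For a quick sanity check: in a binary problem with an even label distribution, p = (0.5, 0.5), so the Gini impurity is 1 − 0.5² − 0.5² = 0.5, exactly the maximum 1 − 1/2. For a perfectly pure leaf, p = (1, 0), and the impurity is 1 − 1² − 0² = 0.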
How do we implement this? It's easy in NumPy. Check this out.
import numpy as np

def gini_impurity(p):
    return 1 - (p**2).sum()
Let's visualize this by plotting it for a binary classification problem. As you can see, the closer p is to zero or one, the smaller `gini_impurity(p)` is.
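If you'd like to reproduce the plot yourself, a minimal matplotlib sketch could look like this. (The plotting code isn't part of the original snippets; it only assumes the gini_impurity function defined above.)

import matplotlib.pyplot as plt

# relative frequency of the first class in a binary problem
p_first = np.linspace(0, 1, 101)

# Gini impurity of the two-class distribution (p, 1 - p)
impurities = [gini_impurity(np.array([p, 1 - p])) for p in p_first]

plt.plot(p_first, impurities)
plt.xlabel("relative frequency of the first class")
plt.ylabel("Gini impurity")
plt.show()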
To compute the Gini impurity for the class labels passed to a leaf, we have to turn the labels into relative frequencies. This can be done with NumPy's `unique` function.
def leaf_gini_impurity(y):
labels, counts = np.unique(y, return_counts=True)
freq = counts / len(y)
return gini_impurity(freq)
In [1]: leaf_gini_impurity(y[left_leaf])
Out[1]: np.float64(0.26680547293277823)
(Recall that `X` and `y` represent our training data and labels.)
Finally, we score the split by weighting the leaf impurities according to the size of the leaves.
def score_split(ys):
n_samples = sum([len(y) for y in ys])
return sum([ (len(y) / n_samples) * leaf_gini_impurity(y) for y in ys])
In [2]: score_split([y[left_leaf], y[right_leaf]])
Out[2]: np.float64(0.26331553677741504)
There is an abundance of leaf-scoring methods, such as the Shannon entropy. Now, with the Gini impurity in our toolkit, we are ready to grow the tree!
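For reference, here is what an entropy-based leaf score might look like. This is just a sketch with a name of my own choosing; we won't use it in the rest of the post:

def leaf_entropy(y):
    # Shannon entropy of the label distribution in a leaf;
    # zero for a pure leaf, larger for more mixed leaves
    _, counts = np.unique(y, return_counts=True)
    freq = counts / len(y)
    return -(freq * np.log2(freq)).sum()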
Implementing a decision tree
What kind of attributes does our `ClassificationTree` class need? In our previous reverse-engineering study, we found that a decision tree has
a maximum depth,
a minimum number of samples required to split a node,
a feature to split along,
a threshold for the split,
and perhaps a left and right child, and if not, a class prediction.
(Recall that if there are no left and right children, or equivalently, there's a predicted class, the node is considered a leaf.) The first two are given upon initialization, while the others are determined during training.
So, the object initialization is ready for implementation. While we're at it, we might as well add a nice string representation.
from mlfz.classical.base import Model
class ClassificationTree(Model):
def __init__(
self,
max_depth=None,
min_samples_split=2,
):
self.max_depth = max_depth
self.min_samples_split = min_samples_split
self.split_feature_idx = 0
self.threshold = 0
self.left_child = None
self.right_child = None
self.predicted_class = None
def __repr__(self):
return f"{self.__class__.__name__}(max_depth={self.max_depth}, min_samples_split={self.min_samples_split})"
@property
def is_leaf(self):
        return self.predicted_class is not None
def fit(self, X, Y):
pass
def predict(self, X):
pass
What is the training process? Let's recap. Upon calling `ClassificationTree.fit`, the root node receives

a training dataset `X`; that is, a NumPy array of shape (n, m),
and the training labels `Y`; that is, a NumPy array of shape (n,).
First, we decide to either grow the tree by splitting the data or settle on a prediction and call it a day leaf. The training stops when either

`Y` only contains a single class label; that is, `np.unique(Y)` returns an array with a single element,
`Y` contains fewer labels than the minimum number of samples to split; that is, `len(Y) < min_samples_split`,
or the maximum depth has been reached; that is, `max_depth == 0`.
We encode all this logic into the `ClassificationTree._should_stop` method.
class ClassificationTree(ClassificationTree):
def _should_stop(self, Y: np.ndarray):
return (
(len(np.unique(Y)) == 1)
or (len(Y) < self.min_samples_split)
or (self.max_depth == 0)
)
If we decide to keep the node a leaf — that is, if `_should_stop` returns `True` — we determine the prediction by taking a majority vote. This will be done using the `most_frequent_label` function below.
def most_frequent_label(Y):
Y_unique, counts = np.unique(Y, return_counts=True)
return Y_unique[np.argmax(counts)]
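A quick sanity check on a made-up array of labels:

most_frequent_label(np.array([0, 1, 1, 2, 1]))  # returns 1
# ties are broken in favor of the smallest label, since np.unique sorts the labels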
Now, if the algorithm decides to grow the tree, we have to define the search space for the `(split_feature_idx, threshold)` tuple that determines the split. The simplest idea is to try all possible combinations. As the threshold is represented by a float for ordinal data (such as ours), we can take the arithmetic means of consecutive values in each sorted feature column.
It's easier to show this in code, so check it out. First, we sort all feature columns.
In [3]: X_sorted = np.sort(X, axis=0)
In [4]: X_sorted[:10]
Out[4]: array([[-3.60167075, -4.88758428],
[-3.31469894, -4.50345197],
[-3.13798361, -3.9132429 ],
[-3.12049845, -3.37056089],
[-3.03368792, -3.2190333 ],
[-2.74955241, -3.04883149],
[-2.7385946 , -3.00132329],
[-2.6883164 , -2.90979919],
[-2.65545839, -2.87220712],
[-2.60013506, -2.8599603 ]])
Now, we take the arithmetic means of the consecutive elements.
In [5]: thresholds = (X_sorted[1:] + X_sorted[:-1]) / 2
In [6]: thresholds[:10]
Out[6]: array([[-3.45818485, -4.69551812],
[-3.22634128, -4.20834744],
[-3.12924103, -3.6419019 ],
[-3.07709318, -3.29479709],
[-2.89162017, -3.13393239],
[-2.7440735 , -3.02507739],
[-2.7134555 , -2.95556124],
[-2.6718874 , -2.89100315],
[-2.62779673, -2.86608371],
[-2.55690005, -2.84546943]])
We are ready to take all possible combinations of features and thresholds, scoring the splits one by one. Each element of the `thresholds` array represents a split. Think about it: the column index encodes the feature to split along, while the value encodes the threshold itself.
i, feature_idx = 249, 0
left_idx = X[:, feature_idx] < thresholds[i, feature_idx]
right_idx = ~left_idx
In [7]: left_idx[:10]
Out[7]: array([ True, True, True, True, True, True, False, True, False, False])
With that in hand, the `fit` method is pretty straightforward, though it's longer than the previous snippets. Fear not: I'll mark each step with a comment to help you navigate. So, let's put all of the above together!
class ClassificationTree(ClassificationTree):
def fit(self, X: np.ndarray, Y: np.ndarray):
# stopping the training if needed
if self._should_stop(Y):
self.predicted_class = most_frequent_label(Y)
return self
# building the search space for splits
X_sorted = np.sort(X, axis=0)
thresholds = (X_sorted[1:] + X_sorted[:-1]) / 2
# this is where we'll store the scores
scores = np.zeros_like(thresholds)
# scoring the splits one by one
for (i, feature_idx), c in np.ndenumerate(thresholds):
left_idx = X[:, feature_idx] < c
right_idx = ~left_idx
split = [Y[left_idx], Y[right_idx]]
scores[i, feature_idx] = score_split(split)
# identifying the split with the lowest score
row_idx, self.split_feature_idx = np.unravel_index(
np.argmin(scores), scores.shape
)
self.threshold = thresholds[row_idx, self.split_feature_idx]
# recursively training a decision tree
left_idx = X[:, self.split_feature_idx] < self.threshold
right_idx = ~left_idx
        # adding the left child
        # (max_depth=None means unlimited depth, so we only decrement it when set)
        self.left_child = ClassificationTree(
            max_depth=self.max_depth - 1 if self.max_depth is not None else None,
            min_samples_split=self.min_samples_split,
        ).fit(X[left_idx], Y[left_idx])
        # adding the right child
        self.right_child = ClassificationTree(
            max_depth=self.max_depth - 1 if self.max_depth is not None else None,
            min_samples_split=self.min_samples_split,
        ).fit(X[right_idx], Y[right_idx])
# returning self to chain instantiation and training
return self
Should we test this? Currently, there's no `predict` method to check the accuracy, nor a `digraph` attribute for visualization. So, there's work to be done. Let's start with the prediction.
As the training is recursive, so is the prediction. If the `ClassificationTree` object is a leaf, we return a prediction. If not, we pass the data to the children along the split.
class ClassificationTree(ClassificationTree):
def predict(self, X: np.ndarray):
# returning the prediction if the node is a leaf
if self.is_leaf:
return np.full(X.shape[0], self.predicted_class)
# splitting the input and passing the splits to the
# proper child for prediction
predictions = np.zeros(X.shape[0], dtype=np.int32)
left_idx = X[:, self.split_feature_idx] < self.threshold
right_idx = ~left_idx
predictions[left_idx] = (
self.left_child.predict(X[left_idx]) if self.left_child else None
)
predictions[right_idx] = (
self.right_child.predict(X[right_idx]) if self.right_child else None
)
return predictions
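One helper the snippets above don't define is the accuracy function used in the test below. Whether it comes from `mlfz` or you write it by hand, a minimal version could simply measure the fraction of matching labels:

def accuracy(y_true, y_pred):
    # fraction of predictions that match the true labels
    return np.mean(y_true == y_pred)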
Time to test!
tree = ClassificationTree(max_depth=16)
tree.fit(X, y)
y_pred = tree.predict(X)
In [8]: accuracy(y, y_pred)
Out[8]: np.float64(1.0)
It works perfectly, but you might wonder about testing the accuracy on the training data. It's against all sane machine learning practices, but trust me, it has its place. Overfitting a model on the training data is an excellent sanity check — if this is not possible, there must be a bug in the training code. Our `ClassificationTree` passes this test.
(Of course, this doesn't mean our code is bug-free. It's just an indication that the code produces a meaningful model.)
Now, depending on the version of `mlfz`, this is what you see when you open up the source.
class ClassificationTree(DecisionTree):
def __init__(self, max_depth=None, min_samples_split=2, **kwargs):
super().__init__(
most_frequent_label,
gini_impurity_leaf,
max_depth,
min_samples_split
)
These few lines cover the entire implementation of `ClassificationTree`. What's going on? This is what we'll answer next.
Generalizing classification trees
To generalize classification trees, let's see another special case: regression trees. We'll compare how they work, spot the differences, and abstract the rest.
Let's generate some toy data first! We'll use the function f(x) = 0.5x² + 10 sin(2.5x) and add some noise.
X = 10 * np.random.rand(200)
y = 0.5 * (X ** 2) + 10 * np.sin(X * 2.5) + 0.5 * np.random.normal(0, 5, size=X.shape[0])
`y` describes a sine wave superimposed on a quadratic drift, plus a bit of noise.
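To get a feel for the data, a quick scatter plot does the trick (matplotlib assumed; the plotting code is not part of the original snippets):

import matplotlib.pyplot as plt

plt.scatter(X, y, s=10)
plt.xlabel("x")
plt.ylabel("y")
plt.show()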
The task of regression is to predict the target from the data; that is, `y` from `X`. We can easily adapt classification trees for that purpose: instead of returning a class label prediction, leaves return a float. Check it out.
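A minimal sketch of such a single-split model, using the `RegressionTree` class we'll define by the end of this post:

tree = RegressionTree(max_depth=1)
tree.fit(X.reshape(-1, 1), y)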
Let's plot this.
That's a pretty rudimentary model, so let's move one level deeper. Here's a regression tree of depth 2 this time.
tree = RegressionTree(max_depth=2)
tree.fit(X.reshape(-1, 1), y)
Regression trees are also built recursively: the first split is exactly the same as before; we've just added two more splits.
This is a better fit, but it's still not there. However, decision trees rapidly improve as their depth increases: adding one more level potentially doubles the number of splits. Check it out for `max_depth=1`, `2`, `4`, and `8`.
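Here is one way such a comparison could be generated; a sketch assuming matplotlib and the `RegressionTree` class defined later in this post:

import matplotlib.pyplot as plt

X_grid = np.linspace(0, 10, 500).reshape(-1, 1)

fig, axes = plt.subplots(1, 4, figsize=(16, 4), sharey=True)
for ax, depth in zip(axes, [1, 2, 4, 8]):
    tree = RegressionTree(max_depth=depth)
    tree.fit(X.reshape(-1, 1), y)
    ax.scatter(X, y, s=10, alpha=0.5)
    ax.plot(X_grid, tree.predict(X_grid), c="red")
    ax.set_title(f"max_depth={depth}")
plt.show()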
So, what makes a regression tree different from the classification tree above?
Two things:
the way we score splits,
and the way we select a leaf's prediction.
Let's generalize all of the above in the new class `DecisionTree`. Unsurprisingly, these are stored in the new attributes `leaf_vote` and `leaf_score`.
class DecisionTree(Model):
def __init__(
self,
leaf_vote, # new attribute
leaf_score, # new attribute
max_depth=None,
min_samples_split=2,
):
self.max_depth = max_depth
self.min_samples_split = min_samples_split
self.leaf_vote = leaf_vote # new attribute
self.leaf_score = leaf_score # new attribute
self.split_feature_idx = 0
self.threshold = 0
self.left_child = None
self.right_child = None
self.predicted_class = None
def __repr__(self):
return f"{self.__class__.__name__}(max_depth={self.max_depth}, min_samples_split={self.min_samples_split})"
@property
def is_leaf(self):
        return self.predicted_class is not None
Besides the stopping logic, let's implement a couple of helper methods: one for making the split and the other for creating a child.
class DecisionTree(DecisionTree):
def _should_stop(self, Y: np.ndarray):
return (
(len(np.unique(Y)) == 1)
or (len(Y) < self.min_samples_split)
or (self.max_depth == 0)
)
def _split_idx(self, X: np.ndarray):
left_idx = X[:, self.split_feature_idx] < self.threshold
right_idx = ~left_idx
return left_idx, right_idx
    def _create_child(self):
        # max_depth=None means unlimited depth, so we only decrement it when set
        return self.__class__(
            max_depth=self.max_depth - 1 if self.max_depth is not None else None,
            min_samples_split=self.min_samples_split,
            leaf_score=self.leaf_score,
            leaf_vote=self.leaf_vote,
        )
Now, the `fit` and `predict` methods. They're almost the same as before.
class DecisionTree(DecisionTree):
def fit(self, X: np.ndarray, Y: np.ndarray):
if self._should_stop(Y):
self.predicted_class = self.leaf_vote(Y) # change
return self
X_sorted = np.sort(X, axis=0)
thresholds = (X_sorted[1:] + X_sorted[:-1]) / 2
scores = np.zeros_like(thresholds)
for (i, feature_idx), c in np.ndenumerate(thresholds):
left_idx = X[:, feature_idx] < c
right_idx = ~left_idx
split = [Y[left_idx], Y[right_idx]]
scores[i, feature_idx] = sum([(len(y) / len(Y)) * self.leaf_score(y) for y in split])
row_idx, self.split_feature_idx = np.unravel_index(
np.argmin(scores), scores.shape
)
self.threshold = thresholds[row_idx, self.split_feature_idx]
left_idx, right_idx = self._split_idx(X)
self.left_child = self._create_child().fit(
X[left_idx], Y[left_idx]
)
self.right_child = self._create_child().fit(
X[right_idx], Y[right_idx]
)
return self
def predict(self, X: np.ndarray):
if self.is_leaf:
return np.full(X.shape[0], self.predicted_class)
        predictions = np.zeros(X.shape[0])  # float dtype by default, so regression predictions aren't truncated to integers
left_idx, right_idx = self._split_idx(X)
predictions[left_idx] = (
self.left_child.predict(X[left_idx]) if self.left_child else None
)
predictions[right_idx] = (
self.right_child.predict(X[right_idx]) if self.right_child else None
)
return predictions
With `DecisionTree`, defining classification and regression trees is as simple as passing the proper callables for scoring and voting. We can even further customize the stopping logic by adding an early stopping condition.
def average_label(Y):
return np.mean(Y)
def mean_squared_error(Y):
    # mean squared deviation from the leaf's mean prediction
    # (in other words, the variance of Y)
    m = np.mean(Y)
    return np.mean((Y - m) ** 2)
class RegressionTree(DecisionTree):
def __init__(
self,
max_depth=None,
min_samples_split=2,
min_score=1,
**kwargs
):
super().__init__(
max_depth=max_depth,
min_samples_split=min_samples_split,
leaf_vote=average_label,
leaf_score=mean_squared_error,
)
self.min_score = min_score
def _should_stop(self, Y: np.ndarray):
return (
(len(np.unique(Y)) == 1)
or (self.leaf_score(Y) < self.min_score)
or (len(Y) < self.min_samples_split)
or (self.max_depth == 0)
)
Let's try this out.
tree = RegressionTree(max_depth=4)
tree.fit(X.reshape(-1, 1), y)
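As a quick numerical sanity check (the exact value depends on the randomly generated data), we can compute the training error of the fit:

y_pred = tree.predict(X.reshape(-1, 1))
np.mean((y - y_pred) ** 2)  # training MSE of the depth-4 tree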
With this, `ClassificationTree` takes the following form.
class ClassificationTree(DecisionTree):
def __init__(self, max_depth=None, min_samples_split=2, **kwargs):
super().__init__(
max_depth=max_depth,
min_samples_split=min_samples_split,
leaf_vote=most_frequent_label,
leaf_score=leaf_gini_impurity,
)
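As a sanity check, the refactored class should reproduce the earlier experiment; assuming the same X, y, and accuracy helper as before:

tree = ClassificationTree(max_depth=16)
tree.fit(X, y)
accuracy(y, tree.predict(X))  # should give 1.0 again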
And with that, we're done! Of course, these vanilla classification and regression trees are just the tip of the decision tree iceberg. There's much more: bagging, boosting, and so on. We will leave it at that for now and return to these topics later.