223 lines
8.7 KiB
Python
223 lines
8.7 KiB
Python
|
"""
|
||
|
=========================================
|
||
|
Understanding the decision tree structure
|
||
|
=========================================
|
||
|
|
||
|
The decision tree structure can be analysed to gain further insight on the
|
||
|
relation between the features and the target to predict. In this example, we
|
||
|
show how to retrieve:
|
||
|
|
||
|
- the binary tree structure;
|
||
|
- the depth of each node and whether or not it's a leaf;
|
||
|
- the nodes that were reached by a sample using the ``decision_path`` method;
|
||
|
- the leaf that was reached by a sample using the ``apply`` method;
|
||
|
- the rules that were used to predict a sample;
|
||
|
- the decision path shared by a group of samples.
|
||
|
|
||
|
"""
|
||
|
|
||
|
import numpy as np
|
||
|
from matplotlib import pyplot as plt
|
||
|
|
||
|
from sklearn import tree
|
||
|
from sklearn.datasets import load_iris
|
||
|
from sklearn.model_selection import train_test_split
|
||
|
from sklearn.tree import DecisionTreeClassifier
|
||
|
|
||
|
##############################################################################
|
||
|
# Train tree classifier
|
||
|
# ---------------------
|
||
|
# First, we fit a :class:`~sklearn.tree.DecisionTreeClassifier` using the
|
||
|
# :func:`~sklearn.datasets.load_iris` dataset.
|
||
|
|
||
|
# Load the iris dataset and pull out the feature matrix and the class labels.
iris = load_iris()
X, y = iris.data, iris.target

# Hold out a test split; ``random_state`` is pinned to 0.
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Cap the tree at three leaves so the printed structure stays small.
clf = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
clf.fit(X_train, y_train)
|
||
|
|
||
|
##############################################################################
|
||
|
# Tree structure
|
||
|
# --------------
|
||
|
#
|
||
|
# The decision classifier has an attribute called ``tree_`` which allows access
|
||
|
# to low level attributes such as ``node_count``, the total number of nodes,
|
||
|
# and ``max_depth``, the maximal depth of the tree. The
|
||
|
# ``tree_.compute_node_depths()`` method computes the depth of each node in the
|
||
|
# tree. `tree_` also stores the entire binary tree structure, represented as a
|
||
|
# number of parallel arrays. The i-th element of each array holds information
|
||
|
# about the node ``i``. Node 0 is the tree's root. Some of the arrays only
|
||
|
# apply to either leaves or split nodes. In this case the values of the nodes
|
||
|
# of the other type are arbitrary. For example, the arrays ``feature`` and
|
||
|
# ``threshold`` only apply to split nodes. The values for leaf nodes in these
|
||
|
# arrays are therefore arbitrary.
|
||
|
#
|
||
|
# Among these arrays, we have:
|
||
|
#
|
||
|
# - ``children_left[i]``: id of the left child of node ``i`` or -1 if leaf
|
||
|
# node
|
||
|
# - ``children_right[i]``: id of the right child of node ``i`` or -1 if leaf
|
||
|
# node
|
||
|
# - ``feature[i]``: feature used for splitting node ``i``
|
||
|
# - ``threshold[i]``: threshold value at node ``i``
|
||
|
# - ``n_node_samples[i]``: the number of training samples reaching node
|
||
|
# ``i``
|
||
|
# - ``impurity[i]``: the impurity at node ``i``
|
||
|
# - ``weighted_n_node_samples[i]``: the weighted number of training samples
|
||
|
# reaching node ``i``
|
||
|
# - ``value[i, j, k]``: the summary of the training samples that reached node i for
|
||
|
# output j and class k (for regression tree, class is set to 1).
|
||
|
#
|
||
|
# Using the arrays, we can traverse the tree structure to compute various
|
||
|
# properties. Below, we will compute the depth of each node and whether or not
|
||
|
# it is a leaf.
|
||
|
|
||
|
# Short module-level aliases for the parallel arrays stored on ``clf.tree_``;
# element ``i`` of each array describes node ``i``.
n_nodes = clf.tree_.node_count
children_left, children_right = clf.tree_.children_left, clf.tree_.children_right
feature, threshold = clf.tree_.feature, clf.tree_.threshold
values = clf.tree_.value
|
||
|
|
||
|
# Per-node results, filled in by the traversal below.
node_depth = np.zeros(shape=n_nodes, dtype=np.int64)
is_leaves = np.zeros(shape=n_nodes, dtype=bool)

# Depth-first walk driven by an explicit stack of (node id, depth) pairs,
# seeded with the root node (id 0) at depth 0.
stack = [(0, 0)]
while stack:
    # Each pushed pair is popped exactly once, so each node is visited once.
    node_id, depth = stack.pop()
    node_depth[node_id] = depth

    left_child, right_child = children_left[node_id], children_right[node_id]
    # A split node has two different children; when both entries are equal
    # the node is a leaf.
    is_split_node = left_child != right_child
    if not is_split_node:
        is_leaves[node_id] = True
        continue

    # Schedule both children one level deeper.
    stack.extend([(left_child, depth + 1), (right_child, depth + 1)])
|
||
|
|
||
|
# Report the node count, then one line per node, indented by one tab per
# level of depth.
header = (
    "The binary tree structure has {n} nodes and has "
    "the following tree structure:\n"
)
print(header.format(n=n_nodes))

leaf_template = "{space}node={node} is a leaf node with value={value}."
split_template = (
    "{space}node={node} is a split node with value={value}: "
    "go to node {left} if X[:, {feature}] <= {threshold} "
    "else to node {right}."
)
for i in range(n_nodes):
    indent = node_depth[i] * "\t"
    if is_leaves[i]:
        line = leaf_template.format(space=indent, node=i, value=values[i])
    else:
        # Split nodes also show the test applied and both destinations.
        line = split_template.format(
            space=indent,
            node=i,
            left=children_left[i],
            feature=feature[i],
            threshold=threshold[i],
            right=children_right[i],
            value=values[i],
        )
    print(line)
|
||
|
|
||
|
# %%
|
||
|
# What is the values array used here?
|
||
|
# -----------------------------------
|
||
|
# The `tree_.value` array is a 3D array of shape
|
||
|
# [``n_nodes``, ``n_outputs``, ``n_classes``] which provides the count of samples
|
||
|
# reaching a node for each class and for each output. Each node has a ``value``
|
||
|
# array which is the number of weighted samples reaching this
|
||
|
# node for each output and class.
|
||
|
#
|
||
|
# For example, in the above tree built on the iris dataset, the root node has
|
||
|
# ``value = [37, 34, 41]``, indicating there are 37 samples
|
||
|
# of class 0, 34 samples of class 1, and 41 samples of class 2 at the root node.
|
||
|
# Traversing the tree, the samples are split and as a result, the ``value`` array
|
||
|
# reaching each node changes. The left child of the root node has ``value = [37, 0, 0]``
|
||
|
# because all 37 samples in the left child node are from class 0.
|
||
|
#
|
||
|
# Note: In this example, `n_outputs=1`, but the tree classifier can also handle
|
||
|
# multi-output problems. The `value` array at each node would just be a 2D
|
||
|
# array instead.
|
||
|
|
||
|
##############################################################################
|
||
|
# We can compare the above output to the plot of the decision tree.
|
||
|
|
||
|
# Draw the fitted tree and display the resulting figure.
tree.plot_tree(clf)
plt.show()
|
||
|
|
||
|
##############################################################################
|
||
|
# Decision path
|
||
|
# -------------
|
||
|
#
|
||
|
# We can also retrieve the decision path of samples of interest. The
|
||
|
# ``decision_path`` method outputs an indicator matrix that allows us to
|
||
|
# retrieve the nodes the samples of interest traverse through. A non zero
|
||
|
# element in the indicator matrix at position ``(i, j)`` indicates that
|
||
|
# the sample ``i`` goes through the node ``j``. Or, for one sample ``i``, the
|
||
|
# positions of the non zero elements in row ``i`` of the indicator matrix
|
||
|
# designate the ids of the nodes that sample goes through.
|
||
|
#
|
||
|
# The leaf ids reached by samples of interest can be obtained with the
|
||
|
# ``apply`` method. This returns an array of the node ids of the leaves
|
||
|
# reached by each sample of interest. Using the leaf ids and the
|
||
|
# ``decision_path`` we can obtain the splitting conditions that were used to
|
||
|
# predict a sample or a group of samples. First, let's do it for one sample.
|
||
|
# Note that ``node_indicator`` is a sparse matrix.
|
||
|
|
||
|
# Indicator matrix of the nodes each test sample goes through, and the id of
# the leaf each test sample ends in.
node_indicator = clf.decision_path(X_test)
leaf_id = clf.apply(X_test)

sample_id = 0
# Row `sample_id` of the indicator matrix lives in
# `indices[indptr[sample_id]:indptr[sample_id + 1]]`; those entries are the
# ids of the nodes that the sample goes through.
row_start = node_indicator.indptr[sample_id]
row_end = node_indicator.indptr[sample_id + 1]
node_index = node_indicator.indices[row_start:row_end]
|
||
|
|
||
|
print("Rules used to predict sample {id}:\n".format(id=sample_id))

rule_template = (
    "decision node {node} : (X_test[{sample}, {feature}] = {value}) "
    "{inequality} {threshold})"
)
reached_leaf = leaf_id[sample_id]
for node_id in node_index:
    # The leaf carries no split rule, so there is nothing to print for it.
    if node_id == reached_leaf:
        continue

    split_feature = feature[node_id]
    split_threshold = threshold[node_id]
    sample_value = X_test[sample_id, split_feature]

    # Record on which side of the threshold sample `sample_id` falls.
    threshold_sign = "<=" if sample_value <= split_threshold else ">"

    print(
        rule_template.format(
            node=node_id,
            sample=sample_id,
            feature=split_feature,
            value=sample_value,
            inequality=threshold_sign,
            threshold=split_threshold,
        )
    )
|
||
|
|
||
|
##############################################################################
|
||
|
# For a group of samples, we can determine the common nodes the samples go
|
||
|
# through.
|
||
|
|
||
|
sample_ids = [0, 1]
# For every node, count how many of the selected samples go through it; a node
# is shared when that count equals the number of selected samples.
visits_per_node = node_indicator.toarray()[sample_ids].sum(axis=0)
common_nodes = visits_per_node == len(sample_ids)
# Turn the boolean mask into node ids using each entry's position.
common_node_id = np.arange(n_nodes)[common_nodes]

shared_message = (
    "\nThe following samples {samples} share the node(s) {nodes} in the tree."
)
print(shared_message.format(samples=sample_ids, nodes=common_node_id))

shared_percentage = 100 * len(common_node_id) / n_nodes
print("This is {prop}% of all nodes.".format(prop=shared_percentage))
|