Expand source code
from typing import List
import numpy as np
from .mutation import TreeMutation
from .node import TreeNode, LeafNode, DecisionNode, deep_copy_node
class Tree:
    """
    An encapsulation of the structure of a single decision tree.

    Contains no logic, but keeps track of 4 different kinds of nodes within the tree:
      - leaf nodes
      - decision nodes
      - splittable leaf nodes
      - prunable decision nodes

    Parameters
    ----------
    nodes: List[TreeNode]
        All nodes contained in the tree, i.e. decision and leaf nodes
    """

    def __init__(self, nodes: List[TreeNode]):
        self._nodes = nodes
        # Invalidated whenever the tree structure or the target values change
        self.cache_up_to_date = False
        # Lazily-allocated cache of the in-sample prediction array
        self._prediction = None

    @property
    def nodes(self) -> List[TreeNode]:
        """
        List of all nodes contained in the tree
        """
        return self._nodes

    @property
    def leaf_nodes(self) -> List[LeafNode]:
        """
        List of all of the leaf nodes in the tree
        """
        # isinstance is the idiomatic type check; LeafNode and DecisionNode
        # are sibling subclasses of TreeNode, so this matches the original
        # exact-type comparison unless a LeafNode subclass is introduced.
        return [x for x in self._nodes if isinstance(x, LeafNode)]

    @property
    def splittable_leaf_nodes(self) -> List[LeafNode]:
        """
        List of all leaf nodes in the tree which can be split in a non-degenerate way
        i.e. not all rows of the covariate matrix are duplicates
        """
        return [x for x in self.leaf_nodes if x.is_splittable()]

    @property
    def decision_nodes(self) -> List[DecisionNode]:
        """
        List of decision nodes in the tree.
        Decision nodes are internal split nodes, i.e. not leaf nodes
        """
        return [x for x in self._nodes if isinstance(x, DecisionNode)]

    @property
    def prunable_decision_nodes(self) -> List[DecisionNode]:
        """
        List of decision nodes in the tree that are suitable for pruning
        In particular, decision nodes that have two leaf node children
        """
        return [x for x in self.decision_nodes if x.is_prunable()]

    def update_y(self, y: np.ndarray) -> None:
        """
        Update the cached value of the target array in all nodes

        Used to pass in the residuals from the sum of all of the other trees
        """
        # New targets invalidate any cached prediction
        self.cache_up_to_date = False
        for node in self.nodes:
            node.update_y(y)

    def predict(self, X: np.ndarray = None) -> np.ndarray:
        """
        Generate a set of predictions with the same dimensionality as the target array

        Note that the prediction is from one tree, so represents only
        (1 / number_of_trees) of the target.

        Parameters
        ----------
        X: np.ndarray, optional
            If provided, predict for this out-of-sample covariate matrix
            instead of returning the cached in-sample prediction.

        Returns
        -------
        np.ndarray
        """
        if X is not None:
            return self._out_of_sample_predict(X)
        if self.cache_up_to_date:
            return self._prediction
        leaves = self.leaf_nodes
        # Allocate the cache once, outside of the loop; guard on `leaves`
        # so an (unusual) tree with no leaf nodes behaves as before.
        if self._prediction is None and leaves:
            self._prediction = np.zeros(self.nodes[0].data.X.n_obsv)
        for leaf in leaves:
            # Each leaf owns a disjoint subset of rows, selected by its split condition
            self._prediction[leaf.split.condition()] = leaf.predict()
        self.cache_up_to_date = True
        return self._prediction

    def _out_of_sample_predict(self, X) -> np.ndarray:
        """
        Prediction for a covariate matrix not used for training

        Note that this is quite slow

        Parameters
        ----------
        X: pd.DataFrame
            Covariates to predict for

        Returns
        -------
        np.ndarray
        """
        # np.zeros avoids materializing an intermediate Python list
        prediction = np.zeros(len(X))
        for leaf in self.leaf_nodes:
            prediction[leaf.split.condition(X)] = leaf.predict()
        return prediction

    def remove_node(self, node: TreeNode) -> None:
        """
        Remove a single node from the tree

        Note that this is non-recursive, only drops the node and not any children
        """
        self._nodes.remove(node)

    def add_node(self, node: TreeNode) -> None:
        """
        Add a node to the tree

        Note that this is non-recursive, only adds the node and not any children
        """
        self._nodes.append(node)
def mutate(tree: Tree, mutation: TreeMutation) -> None:
    """
    Apply a change to the structure of the tree

    Modifies not only the tree, but also the links between the TreeNodes

    Parameters
    ----------
    tree: Tree
        The tree to mutate
    mutation: TreeMutation
        The mutation to apply to the tree
    """
    # Any structural change invalidates the cached prediction
    tree.cache_up_to_date = False

    if mutation.kind == "prune":
        # Drop the pruned decision node along with both of its children,
        # then register the single leaf that replaces them.
        doomed = (mutation.existing_node,
                  mutation.existing_node.left_child,
                  mutation.existing_node.right_child)
        for victim in doomed:
            tree.remove_node(victim)
        tree.add_node(mutation.updated_node)
    elif mutation.kind == "grow":
        # Swap the grown leaf for a decision node plus its two new leaves.
        tree.remove_node(mutation.existing_node)
        tree.add_node(mutation.updated_node.left_child)
        tree.add_node(mutation.updated_node.right_child)
        tree.add_node(mutation.updated_node)

    # Re-point whichever parent referenced the replaced node at its replacement.
    for candidate in tree.nodes:
        if candidate.right_child == mutation.existing_node:
            candidate._right_child = mutation.updated_node
        if candidate.left_child == mutation.existing_node:
            candidate._left_child = mutation.updated_node
def deep_copy_tree(tree: Tree):
    """
    Efficiently create a copy of the tree for storage

    Creates a memory-light version of the tree with access to important information

    Parameters
    ----------
    tree: Tree
        Tree to copy

    Returns
    -------
    Tree
        Version of the tree optimized to be low memory
    """
    copied_nodes = list(map(deep_copy_node, tree.nodes))
    return Tree(copied_nodes)
Functions
def deep_copy_tree(tree: Tree)
-
Efficiently create a copy of the tree for storage. Creates a memory-light version of the tree with access to important information. Parameters
tree
:Tree
- Tree to copy
Returns
Tree
- Version of the tree optimized to be low memory
Expand source code
def deep_copy_tree(tree: Tree): """ Efficiently create a copy of the tree for storage Creates a memory-light version of the tree with access to important information Parameters ---------- tree: Tree Tree to copy Returns ------- Tree Version of the tree optimized to be low memory """ return Tree([deep_copy_node(x) for x in tree.nodes])
def mutate(tree: Tree, mutation: TreeMutation) ‑> None
-
Apply a change to the structure of the tree. Modifies not only the tree, but also the links between the TreeNodes.
Parameters
tree
:Tree
- The tree to mutate
mutation
:TreeMutation
- The mutation to apply to the tree
Expand source code
def mutate(tree: Tree, mutation: TreeMutation) -> None: """ Apply a change to the structure of the tree Modifies not only the tree, but also the links between the TreeNodes Parameters ---------- tree: Tree The tree to mutate mutation: TreeMutation The mutation to apply to the tree """ tree.cache_up_to_date = False if mutation.kind == "prune": tree.remove_node(mutation.existing_node) tree.remove_node(mutation.existing_node.left_child) tree.remove_node(mutation.existing_node.right_child) tree.add_node(mutation.updated_node) if mutation.kind == "grow": tree.remove_node(mutation.existing_node) tree.add_node(mutation.updated_node.left_child) tree.add_node(mutation.updated_node.right_child) tree.add_node(mutation.updated_node) for node in tree.nodes: if node.right_child == mutation.existing_node: node._right_child = mutation.updated_node if node.left_child == mutation.existing_node: node._left_child = mutation.updated_node
Classes
class Tree (nodes: List[TreeNode])
-
An encapsulation of the structure of a single decision tree. Contains no logic, but keeps track of 4 different kinds of nodes within the tree: leaf nodes, decision nodes, splittable leaf nodes, and prunable decision nodes.
Parameters
nodes
:List[Node]
- All nodes contained in the tree, i.e. decision and leaf nodes
Expand source code
class Tree: """ An encapsulation of the structure of a single decision tree Contains no logic, but keeps track of 4 different kinds of nodes within the tree: - leaf nodes - decision nodes - splittable leaf nodes - prunable decision nodes Parameters ---------- nodes: List[Node] All nodes contained in the tree, i.e. decision and leaf nodes """ def __init__(self, nodes: List[TreeNode]): self._nodes = nodes self.cache_up_to_date = False self._prediction = None @property def nodes(self) -> List[TreeNode]: """ List of all nodes contained in the tree """ return self._nodes @property def leaf_nodes(self) -> List[LeafNode]: """ List of all of the leaf nodes in the tree """ return [x for x in self._nodes if type(x) == LeafNode] @property def splittable_leaf_nodes(self) -> List[LeafNode]: """ List of all leaf nodes in the tree which can be split in a non-degenerate way i.e. not all rows of the covariate matrix are duplicates """ return [x for x in self.leaf_nodes if x.is_splittable()] @property def decision_nodes(self) -> List[DecisionNode]: """ List of decision nodes in the tree. Decision nodes are internal split nodes, i.e. 
not leaf nodes """ return [x for x in self._nodes if type(x) == DecisionNode] @property def prunable_decision_nodes(self) -> List[DecisionNode]: """ List of decision nodes in the tree that are suitable for pruning In particular, decision nodes that have two leaf node children """ return [x for x in self.decision_nodes if x.is_prunable()] def update_y(self, y: np.ndarray) -> None: """ Update the cached value of the target array in all nodes Used to pass in the residuals from the sum of all of the other trees """ self.cache_up_to_date = False for node in self.nodes: node.update_y(y) def predict(self, X: np.ndarray=None) -> np.ndarray: """ Generate a set of predictions with the same dimensionality as the target array Note that the prediction is from one tree, so represents only (1 / number_of_trees) of the target """ if X is not None: return self._out_of_sample_predict(X) if self.cache_up_to_date: return self._prediction for leaf in self.leaf_nodes: if self._prediction is None: self._prediction = np.zeros(self.nodes[0].data.X.n_obsv) self._prediction[leaf.split.condition()] = leaf.predict() self.cache_up_to_date = True return self._prediction def _out_of_sample_predict(self, X) -> np.ndarray: """ Prediction for a covariate matrix not used for training Note that this is quite slow Parameters ---------- X: pd.DataFrame Covariates to predict for Returns ------- np.ndarray """ prediction = np.array([0.] * len(X)) for leaf in self.leaf_nodes: prediction[leaf.split.condition(X)] = leaf.predict() return prediction def remove_node(self, node: TreeNode) -> None: """ Remove a single node from the tree Note that this is non-recursive, only drops the node and not any children """ self._nodes.remove(node) def add_node(self, node: TreeNode) -> None: """ Add a node to the tree Note that this is non-recursive, only adds the node and not any children """ self._nodes.append(node)
Instance variables
var decision_nodes : List[DecisionNode]
-
List of decision nodes in the tree. Decision nodes are internal split nodes, i.e. not leaf nodes
Expand source code
@property def decision_nodes(self) -> List[DecisionNode]: """ List of decision nodes in the tree. Decision nodes are internal split nodes, i.e. not leaf nodes """ return [x for x in self._nodes if type(x) == DecisionNode]
var leaf_nodes : List[LeafNode]
-
List of all of the leaf nodes in the tree
Expand source code
@property def leaf_nodes(self) -> List[LeafNode]: """ List of all of the leaf nodes in the tree """ return [x for x in self._nodes if type(x) == LeafNode]
var nodes : List[TreeNode]
-
List of all nodes contained in the tree
Expand source code
@property def nodes(self) -> List[TreeNode]: """ List of all nodes contained in the tree """ return self._nodes
var prunable_decision_nodes : List[DecisionNode]
-
List of decision nodes in the tree that are suitable for pruning In particular, decision nodes that have two leaf node children
Expand source code
@property def prunable_decision_nodes(self) -> List[DecisionNode]: """ List of decision nodes in the tree that are suitable for pruning In particular, decision nodes that have two leaf node children """ return [x for x in self.decision_nodes if x.is_prunable()]
var splittable_leaf_nodes : List[LeafNode]
-
List of all leaf nodes in the tree which can be split in a non-degenerate way i.e. not all rows of the covariate matrix are duplicates
Expand source code
@property def splittable_leaf_nodes(self) -> List[LeafNode]: """ List of all leaf nodes in the tree which can be split in a non-degenerate way i.e. not all rows of the covariate matrix are duplicates """ return [x for x in self.leaf_nodes if x.is_splittable()]
Methods
def add_node(self, node: TreeNode) ‑> None
-
Add a node to the tree Note that this is non-recursive, only adds the node and not any children
Expand source code
def add_node(self, node: TreeNode) -> None: """ Add a node to the tree Note that this is non-recursive, only adds the node and not any children """ self._nodes.append(node)
def predict(self, X: numpy.ndarray = None) ‑> numpy.ndarray
-
Generate a set of predictions with the same dimensionality as the target array Note that the prediction is from one tree, so represents only (1 / number_of_trees) of the target
Expand source code
def predict(self, X: np.ndarray=None) -> np.ndarray: """ Generate a set of predictions with the same dimensionality as the target array Note that the prediction is from one tree, so represents only (1 / number_of_trees) of the target """ if X is not None: return self._out_of_sample_predict(X) if self.cache_up_to_date: return self._prediction for leaf in self.leaf_nodes: if self._prediction is None: self._prediction = np.zeros(self.nodes[0].data.X.n_obsv) self._prediction[leaf.split.condition()] = leaf.predict() self.cache_up_to_date = True return self._prediction
def remove_node(self, node: TreeNode) ‑> None
-
Remove a single node from the tree Note that this is non-recursive, only drops the node and not any children
Expand source code
def remove_node(self, node: TreeNode) -> None: """ Remove a single node from the tree Note that this is non-recursive, only drops the node and not any children """ self._nodes.remove(node)
def update_y(self, y: numpy.ndarray) ‑> None
-
Update the cached value of the target array in all nodes Used to pass in the residuals from the sum of all of the other trees
Expand source code
def update_y(self, y: np.ndarray) -> None: """ Update the cached value of the target array in all nodes Used to pass in the residuals from the sum of all of the other trees """ self.cache_up_to_date = False for node in self.nodes: node.update_y(y)