Expand source code
from typing import Union, Tuple
from .data import Data
from .split import Split, SplitCondition
class TreeNode(object):
"""
A representation of a node in the Tree
Contains two main types of information:
- Data relevant for the node
- Links to children nodes
"""
def __init__(self, split: Split, depth: int, left_child: 'TreeNode' = None, right_child: 'TreeNode' = None):
self.depth = depth
self._split = split
self._left_child = left_child
self._right_child = right_child
@property
def data(self) -> Data:
return self._split.data
@property
def left_child(self) -> 'TreeNode':
return self._left_child
@property
def right_child(self) -> 'TreeNode':
return self._right_child
@property
def split(self):
return self._split
def update_y(self, y):
self.data.update_y(y)
if self.left_child is not None:
self.left_child.update_y(y)
self.right_child.update_y(y)
class LeafNode(TreeNode):
"""
A representation of a leaf node in the tree
In addition to the normal work of a `Node`, a `LeafNode` is responsible for:
- Interacting with `Data`
- Making predictions
"""
def __init__(self, split: Split, depth=0, value=0.0):
self._value = value
super().__init__(split, depth, None, None)
def set_value(self, value: float) -> None:
self._value = value
def set_mean_response(self, value: float) -> None:
self._mean_response = value
@property
def current_value(self):
return self._value
def predict(self) -> float:
return self.current_value
def is_splittable(self) -> bool:
return self.data.X.is_at_least_one_splittable_variable()
@property
def n_obs(self):
return self.data.X.n_obsv
@property
def mean_response(self):
return self._mean_response
class DecisionNode(TreeNode):
"""
A `DecisionNode` encapsulates internal node in the tree
Unlike a `LeafNode`, it contains very little actual logic beyond tying the tree together
"""
def __init__(self, split: Split, left_child_node: TreeNode, right_child_node: TreeNode, depth=0):
super().__init__(split, depth, left_child_node, right_child_node)
def is_prunable(self) -> bool:
return type(self.left_child) == LeafNode and type(self.right_child) == LeafNode
def most_recent_split_condition(self) -> SplitCondition:
return self.left_child.split.most_recent_split_condition()
@property
def n_obs(self):
n_l = self.left_child.n_obs
n_r = self.right_child.n_obs
return n_l + n_r
@property
def current_value(self):
n_l = self.left_child.n_obs
n_r = self.right_child.n_obs
l_val = self.left_child.current_value if type(self.left_child) == DecisionNode else self.left_child.mean_response
r_val = self.right_child.current_value if type(self.right_child) == DecisionNode else self.right_child.mean_response
l_sum = l_val * n_l
r_sum = r_val * n_r
return (l_sum + r_sum) / self.n_obs
def split_node(node: LeafNode, split_conditions: Tuple[SplitCondition, SplitCondition]) -> DecisionNode:
"""
Converts a `LeafNode` into an internal `DecisionNode` by applying the split condition
The left node contains all values for the splitting variable less than the splitting value
"""
left_split = node.split + split_conditions[0]
split_conditions[1].carry_n_obsv = node.data.X.n_obsv - left_split.data.X.n_obsv
split_conditions[1].carry_y_sum = node.data.y.summed_y() - left_split.data.y.summed_y()
right_split = node.split + split_conditions[1]
return DecisionNode(node.split,
LeafNode(left_split, depth=node.depth + 1),
LeafNode(right_split, depth=node.depth + 1),
depth=node.depth)
def deep_copy_node(node: TreeNode):
if type(node) == LeafNode:
node: LeafNode = node
return LeafNode(node.split.out_of_sample_conditioner(), value=node.current_value, depth=node.depth)
elif type(node) == DecisionNode:
node: DecisionNode = node
return DecisionNode(node.split.out_of_sample_conditioner(), node.left_child, node.right_child, depth=node.depth)
else:
raise TypeError("Unsupported node type")
Functions
def deep_copy_node(node: TreeNode)
-
Expand source code
def deep_copy_node(node: TreeNode): if type(node) == LeafNode: node: LeafNode = node return LeafNode(node.split.out_of_sample_conditioner(), value=node.current_value, depth=node.depth) elif type(node) == DecisionNode: node: DecisionNode = node return DecisionNode(node.split.out_of_sample_conditioner(), node.left_child, node.right_child, depth=node.depth) else: raise TypeError("Unsupported node type")
def split_node(node: LeafNode, split_conditions: Tuple[SplitCondition, SplitCondition]) ‑> DecisionNode
-
Converts a
LeafNode
into an internalDecisionNode
by applying the split condition The left node contains all values for the splitting variable less than the splitting valueExpand source code
def split_node(node: LeafNode, split_conditions: Tuple[SplitCondition, SplitCondition]) -> DecisionNode: """ Converts a `LeafNode` into an internal `DecisionNode` by applying the split condition The left node contains all values for the splitting variable less than the splitting value """ left_split = node.split + split_conditions[0] split_conditions[1].carry_n_obsv = node.data.X.n_obsv - left_split.data.X.n_obsv split_conditions[1].carry_y_sum = node.data.y.summed_y() - left_split.data.y.summed_y() right_split = node.split + split_conditions[1] return DecisionNode(node.split, LeafNode(left_split, depth=node.depth + 1), LeafNode(right_split, depth=node.depth + 1), depth=node.depth)
Classes
class DecisionNode (split: Split, left_child_node: TreeNode, right_child_node: TreeNode, depth=0)
-
A
DecisionNode
encapsulates internal node in the tree Unlike aLeafNode
, it contains very little actual logic beyond tying the tree togetherExpand source code
class DecisionNode(TreeNode): """ A `DecisionNode` encapsulates internal node in the tree Unlike a `LeafNode`, it contains very little actual logic beyond tying the tree together """ def __init__(self, split: Split, left_child_node: TreeNode, right_child_node: TreeNode, depth=0): super().__init__(split, depth, left_child_node, right_child_node) def is_prunable(self) -> bool: return type(self.left_child) == LeafNode and type(self.right_child) == LeafNode def most_recent_split_condition(self) -> SplitCondition: return self.left_child.split.most_recent_split_condition() @property def n_obs(self): n_l = self.left_child.n_obs n_r = self.right_child.n_obs return n_l + n_r @property def current_value(self): n_l = self.left_child.n_obs n_r = self.right_child.n_obs l_val = self.left_child.current_value if type(self.left_child) == DecisionNode else self.left_child.mean_response r_val = self.right_child.current_value if type(self.right_child) == DecisionNode else self.right_child.mean_response l_sum = l_val * n_l r_sum = r_val * n_r return (l_sum + r_sum) / self.n_obs
Ancestors
Instance variables
var current_value
-
Expand source code
@property def current_value(self): n_l = self.left_child.n_obs n_r = self.right_child.n_obs l_val = self.left_child.current_value if type(self.left_child) == DecisionNode else self.left_child.mean_response r_val = self.right_child.current_value if type(self.right_child) == DecisionNode else self.right_child.mean_response l_sum = l_val * n_l r_sum = r_val * n_r return (l_sum + r_sum) / self.n_obs
var n_obs
-
Expand source code
@property def n_obs(self): n_l = self.left_child.n_obs n_r = self.right_child.n_obs return n_l + n_r
Methods
def is_prunable(self) ‑> bool
-
Expand source code
def is_prunable(self) -> bool: return type(self.left_child) == LeafNode and type(self.right_child) == LeafNode
def most_recent_split_condition(self) ‑> SplitCondition
-
Expand source code
def most_recent_split_condition(self) -> SplitCondition: return self.left_child.split.most_recent_split_condition()
class LeafNode (split: Split, depth=0, value=0.0)
-
A representation of a leaf node in the tree In addition to the normal work of a
Node
, aLeafNode
is responsible for: - Interacting withData
- Making predictionsExpand source code
class LeafNode(TreeNode): """ A representation of a leaf node in the tree In addition to the normal work of a `Node`, a `LeafNode` is responsible for: - Interacting with `Data` - Making predictions """ def __init__(self, split: Split, depth=0, value=0.0): self._value = value super().__init__(split, depth, None, None) def set_value(self, value: float) -> None: self._value = value def set_mean_response(self, value: float) -> None: self._mean_response = value @property def current_value(self): return self._value def predict(self) -> float: return self.current_value def is_splittable(self) -> bool: return self.data.X.is_at_least_one_splittable_variable() @property def n_obs(self): return self.data.X.n_obsv @property def mean_response(self): return self._mean_response
Ancestors
Instance variables
var current_value
-
Expand source code
@property def current_value(self): return self._value
var mean_response
-
Expand source code
@property def mean_response(self): return self._mean_response
var n_obs
-
Expand source code
@property def n_obs(self): return self.data.X.n_obsv
Methods
def is_splittable(self) ‑> bool
-
Expand source code
def is_splittable(self) -> bool: return self.data.X.is_at_least_one_splittable_variable()
def predict(self) ‑> float
-
Expand source code
def predict(self) -> float: return self.current_value
def set_mean_response(self, value: float) ‑> None
-
Expand source code
def set_mean_response(self, value: float) -> None: self._mean_response = value
def set_value(self, value: float) ‑> None
-
Expand source code
def set_value(self, value: float) -> None: self._value = value
class TreeNode (split: Split, depth: int, left_child: TreeNode = None, right_child: TreeNode = None)
-
A representation of a node in the Tree Contains two main types of information: - Data relevant for the node - Links to children nodes
Expand source code
class TreeNode(object): """ A representation of a node in the Tree Contains two main types of information: - Data relevant for the node - Links to children nodes """ def __init__(self, split: Split, depth: int, left_child: 'TreeNode' = None, right_child: 'TreeNode' = None): self.depth = depth self._split = split self._left_child = left_child self._right_child = right_child @property def data(self) -> Data: return self._split.data @property def left_child(self) -> 'TreeNode': return self._left_child @property def right_child(self) -> 'TreeNode': return self._right_child @property def split(self): return self._split def update_y(self, y): self.data.update_y(y) if self.left_child is not None: self.left_child.update_y(y) self.right_child.update_y(y)
Subclasses
Instance variables
var data : Data
-
Expand source code
@property def data(self) -> Data: return self._split.data
var left_child : TreeNode
-
Expand source code
@property def left_child(self) -> 'TreeNode': return self._left_child
var right_child : TreeNode
-
Expand source code
@property def right_child(self) -> 'TreeNode': return self._right_child
var split
-
Expand source code
@property def split(self): return self._split
Methods
def update_y(self, y)
-
Expand source code
def update_y(self, y): self.data.update_y(y) if self.left_child is not None: self.left_child.update_y(y) self.right_child.update_y(y)