Expand source code
from operator import le, gt
from typing import Callable, List, Mapping, Optional, Tuple

import numpy as np

from ...errors import NoSplittableVariableException, NoPrunableNodeException
from ...mutation import TreeMutation, GrowMutation, PruneMutation
from ...node import LeafNode, DecisionNode, split_node
from ...samplers.scalar import DiscreteSampler
from ...samplers.treemutation import TreeMutationProposer
from ...split import SplitCondition
from ...tree import Tree


def grow_mutations(tree: Tree) -> List[TreeMutation]:
    """
    Build one grow mutation for every leaf node of the tree
    Each leaf is paired with the replacement node produced by `sample_split_node`
    """
    proposals = []
    for leaf in tree.leaf_nodes:
        replacement = sample_split_node(leaf)
        proposals.append(GrowMutation(leaf, replacement))
    return proposals


def prune_mutations(tree: Tree) -> List[TreeMutation]:
    """
    Build one prune mutation for every prunable decision node of the tree
    Each such node is paired with a new leaf that keeps the node's split and depth
    """
    proposals = []
    for decision_node in tree.prunable_decision_nodes:
        collapsed = LeafNode(decision_node.split, depth=decision_node.depth)
        proposals.append(PruneMutation(decision_node, collapsed))
    return proposals


class UniformMutationProposer(TreeMutationProposer):
    """
    Proposes tree mutations by first drawing a kind of mutation (grow or prune)
    and then building that kind of mutation for the given tree
    """

    def __init__(self,
                 p_grow: float=0.5,
                 p_prune: float=0.5):
        """
        Set up the sampler that chooses between `grow_mutations` and `prune_mutations`
        p_grow and p_prune are handed to the DiscreteSampler alongside those two
        functions, in that order (presumably as their selection probabilities)
        """
        builders = [grow_mutations, prune_mutations]
        weights = [p_grow, p_prune]
        self.method_sampler = DiscreteSampler(builders, weights, cache_size=1000)

    def propose(self, tree: Tree) -> TreeMutation:
        """
        Draw a mutation builder from the sampler and apply it to the tree
        If the builder raises NoSplittableVariableException or NoPrunableNodeException
        the proposal is retried from scratch through a recursive call

        NOTE(review): both builders are annotated as returning List[TreeMutation],
        so the value returned here looks like a list despite the annotation - confirm
        """
        build_mutations = self.method_sampler.sample()
        try:
            proposal = build_mutations(tree)
        except (NoSplittableVariableException, NoPrunableNodeException):
            # Nothing to grow / prune with this draw, so start over with a fresh draw
            return self.propose(tree)
        return proposal


def sample_split_condition(node: LeafNode) -> Optional[Tuple[SplitCondition, SplitCondition]]:
    """
    Draw a random splitting rule for a leaf node
    Two random draws are made on the covariates held by the node

      - a splittable variable
      - a splittable value of that variable

    Returns a pair of conditions on the drawn variable and value, the first
    built with `le` and the second with `gt`
    Returns None if no splittable value could be drawn
    """
    covariates = node.data.X
    variable = covariates.random_splittable_variable()
    value = covariates.random_splittable_value(variable)
    if value is not None:
        return SplitCondition(variable, value, le), SplitCondition(variable, value, gt)
    return None


def sample_split_node(node: LeafNode) -> DecisionNode:
    """
    Split a leaf node into a decision node with two leaf children
    The variable and value to split on are determined by sampling from their respective distributions

    A leaf that reports itself as not splittable is still expanded: both children
    are new leaves that reuse the parent's split, one level deeper

    Raises NoSplittableVariableException if the leaf is splittable but
    `sample_split_condition` could not produce a split
    """
    if not node.is_splittable():
        child_depth = node.depth + 1
        return DecisionNode(node.split,
                            LeafNode(node.split, depth=child_depth),
                            LeafNode(node.split, depth=child_depth),
                            depth=node.depth)

    conditions = sample_split_condition(node)
    if conditions is None:
        # sample_split_condition signals "no usable split" with None; surface that as the
        # exception UniformMutationProposer.propose already retries on, rather than
        # handing None to split_node
        raise NoSplittableVariableException()
    return split_node(node, conditions)

Functions

def grow_mutations(tree: Tree) ‑> List[TreeMutation]
Expand source code
def grow_mutations(tree: Tree) -> List[TreeMutation]:
    """
    Build one grow mutation for every leaf node of the tree
    Each leaf is paired with the replacement node produced by `sample_split_node`
    """
    proposals = []
    for leaf in tree.leaf_nodes:
        replacement = sample_split_node(leaf)
        proposals.append(GrowMutation(leaf, replacement))
    return proposals
def prune_mutations(tree: Tree) ‑> List[TreeMutation]
Expand source code
def prune_mutations(tree: Tree) -> List[TreeMutation]:
    """
    Build one prune mutation for every prunable decision node of the tree
    Each such node is paired with a new leaf that keeps the node's split and depth
    """
    proposals = []
    for decision_node in tree.prunable_decision_nodes:
        collapsed = LeafNode(decision_node.split, depth=decision_node.depth)
        proposals.append(PruneMutation(decision_node, collapsed))
    return proposals
def sample_split_condition(node: LeafNode) ‑> Optional[Tuple[SplitCondition, SplitCondition]]

Randomly sample a splitting rule for a particular leaf node. Works based on two random draws:

  • draw a variable to split on based on a multinomial distribution
  • draw an observation within that variable to split on

Returns None if there isn't a possible non-degenerate split

Expand source code
def sample_split_condition(node: LeafNode) -> Optional[Tuple[SplitCondition, SplitCondition]]:
    """
    Draw a random splitting rule for a leaf node
    Two random draws are made on the covariates held by the node

      - a splittable variable
      - a splittable value of that variable

    Returns a pair of conditions on the drawn variable and value, the first
    built with `le` and the second with `gt`
    Returns None if no splittable value could be drawn
    """
    covariates = node.data.X
    variable = covariates.random_splittable_variable()
    value = covariates.random_splittable_value(variable)
    if value is not None:
        return SplitCondition(variable, value, le), SplitCondition(variable, value, gt)
    return None
def sample_split_node(node: LeafNode) ‑> DecisionNode

Split a leaf node into a decision node with two leaf children. The variable and value to split on are determined by sampling from their respective distributions.

Expand source code
def sample_split_node(node: LeafNode) -> DecisionNode:
    """
    Split a leaf node into a decision node with two leaf children
    The variable and value to split on are determined by sampling from their respective distributions

    A leaf that reports itself as not splittable is still expanded: both children
    are new leaves that reuse the parent's split, one level deeper

    Raises NoSplittableVariableException if the leaf is splittable but
    `sample_split_condition` could not produce a split
    """
    if not node.is_splittable():
        child_depth = node.depth + 1
        return DecisionNode(node.split,
                            LeafNode(node.split, depth=child_depth),
                            LeafNode(node.split, depth=child_depth),
                            depth=node.depth)

    conditions = sample_split_condition(node)
    if conditions is None:
        # sample_split_condition signals "no usable split" with None; surface that as the
        # exception UniformMutationProposer.propose already retries on, rather than
        # handing None to split_node
        raise NoSplittableVariableException()
    return split_node(node, conditions)

Classes

class UniformMutationProposer (p_grow: float = 0.5, p_prune: float = 0.5)

A TreeMutationProposer is responsible for generating samples from tree space. It is capable of generating proposed TreeMutations.

Expand source code
class UniformMutationProposer(TreeMutationProposer):
    """
    Proposes tree mutations by first drawing a kind of mutation (grow or prune)
    and then building that kind of mutation for the given tree
    """

    def __init__(self,
                 p_grow: float=0.5,
                 p_prune: float=0.5):
        """
        Set up the sampler that chooses between `grow_mutations` and `prune_mutations`
        p_grow and p_prune are handed to the DiscreteSampler alongside those two
        functions, in that order (presumably as their selection probabilities)
        """
        builders = [grow_mutations, prune_mutations]
        weights = [p_grow, p_prune]
        self.method_sampler = DiscreteSampler(builders, weights, cache_size=1000)

    def propose(self, tree: Tree) -> TreeMutation:
        """
        Draw a mutation builder from the sampler and apply it to the tree
        If the builder raises NoSplittableVariableException or NoPrunableNodeException
        the proposal is retried from scratch through a recursive call

        NOTE(review): both builders are annotated as returning List[TreeMutation],
        so the value returned here looks like a list despite the annotation - confirm
        """
        build_mutations = self.method_sampler.sample()
        try:
            proposal = build_mutations(tree)
        except (NoSplittableVariableException, NoPrunableNodeException):
            # Nothing to grow / prune with this draw, so start over with a fresh draw
            return self.propose(tree)
        return proposal

Ancestors

Inherited members