Expand source code
import numpy as np
import pandas as pd

from collections import namedtuple

from sklearn import __version__
from sklearn.base import ClassifierMixin, RegressorMixin
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.tree._tree import Tree

# Column-oriented container: each field holds one list with an entry per tree node.
# The field order matches the structured dtype built in extract_sklearn_tree_from_figs.
TreeData = namedtuple(
    'TreeData',
    [
        'left_child',
        'right_child',
        'feature',
        'threshold',
        'impurity',
        'n_node_samples',
        'weighted_n_node_samples',
        'missing_go_to_left',
    ],
)

def _extract_arrays_from_figs_tree(figs_tree):
    """Flatten a FIGS tree into per-node column lists used to build a sklearn tree.

    Nodes are visited in pre-order: a node first, then its whole left subtree,
    then its right subtree. Each visited node adds one entry to every list in the
    returned ``TreeData`` and one entry to the value array.

    A node is treated as a leaf when ``node.left is None``. Leaves get child ids
    of -1 and feature/threshold of -2. Internal nodes take their child ids from
    ``node.left.node_id`` / ``node.right.node_id`` and their split from
    ``node.feature`` / ``node.threshold``.

    Both ``n_node_samples`` and ``weighted_n_node_samples`` are set to
    ``np.sum(node.value_sklearn)``. ``missing_go_to_left`` is always 1.

    NOTE(review): the caller uses list position as the sklearn node index, so
    this presumably relies on FIGS ``node_id`` values matching pre-order
    position. TODO confirm.

    Parameters
    ----------
    figs_tree : FIGS node or None
        Root of the tree. If None, every returned list is empty.

    Returns
    -------
    tuple
        ``(TreeData of lists, np.array of the collected value_sklearn entries)``,
        both in visit order.
    """
    columns = {field: [] for field in TreeData._fields}
    collected_values = []

    # Explicit stack; the right child is pushed before the left so the left
    # subtree is fully processed first (pre-order).
    pending = [figs_tree]
    while pending:
        current = pending.pop()
        if current is None:
            continue

        value_sklearn = current.value_sklearn
        is_leaf = current.left is None
        if is_leaf:
            left_id, right_id = -1, -1
            split_feature, split_threshold = -2, -2
        else:
            left_id = current.left.node_id
            right_id = current.right.node_id
            split_feature = current.feature
            split_threshold = current.threshold

        sample_total = np.sum(value_sklearn)
        row = TreeData(
            left_child=left_id,
            right_child=right_id,
            feature=split_feature,
            threshold=split_threshold,
            impurity=current.impurity,
            n_node_samples=sample_total,
            weighted_n_node_samples=sample_total,  # TODO add sample weights
            missing_go_to_left=1,
        )
        for field, entry in row._asdict().items():
            columns[field].append(entry)
        collected_values.append(value_sklearn)

        if not is_leaf:
            pending.append(current.right)
            pending.append(current.left)

    return TreeData(**columns), np.array(collected_values)

def extract_sklearn_tree_from_figs(figs, tree_num, n_classes, with_leaf_predictions=False):
    """Takes in a FIGS model and converts tree number ``tree_num`` to a sklearn decision tree.

    The selected FIGS tree is flattened with ``_extract_arrays_from_figs_tree``
    into the node and value arrays accepted by ``sklearn.tree._tree.Tree.__setstate__()``.
    The resulting ``Tree`` is then installed through ``__setstate__()`` into a new
    ``DecisionTreeClassifier`` (classifier models) or ``DecisionTreeRegressor``
    (regressor models).

    Parameters
    ----------
    figs : fitted FIGS model
        Must provide ``trees_`` (indexed by ``tree_num``) and ``n_features_in_``.
    tree_num : int
        Index into ``figs.trees_`` of the tree to convert.
    n_classes : int
        Number of classes given to the sklearn ``Tree``. When it is 1, only the
        first column of each node's ``value_sklearn`` is kept.
    with_leaf_predictions : bool, default=False
        When True, also return a dict mapping each leaf's ``node_id`` to
        ``node.value[0][0]``.

    Returns
    -------
    dt : DecisionTreeClassifier or DecisionTreeRegressor
        Returned on its own when ``with_leaf_predictions`` is False.
    (dt, leaf_values_dict) : tuple
        Returned when ``with_leaf_predictions`` is True.

    Raises
    ------
    AttributeError
        If ``figs.trees_[tree_num]`` cannot be loaded. The underlying error is chained.
    TypeError
        If ``figs`` is neither a ``ClassifierMixin`` nor a ``RegressorMixin``.
    Exception
        If the estimator's ``__setstate__()`` fails, e.g. after a sklearn update.
    """
    try:
        figs_tree = figs.trees_[tree_num]
    except Exception as err:
        raise AttributeError(f'Can not load tree_num = {tree_num}!') from err

    # Choose the target estimator before doing any work, so an unsupported model
    # type fails with a clear TypeError rather than leaving `dt` unbound.
    if isinstance(figs, ClassifierMixin):
        estimator_cls = DecisionTreeClassifier
    elif isinstance(figs, RegressorMixin):
        estimator_cls = DecisionTreeRegressor
    else:
        raise TypeError(f'Can only translate classifier or regressor FIGS models, got {type(figs)}')

    tree_data_namedtuple, value_sklearn_array = _extract_arrays_from_figs_tree(figs_tree)

    # Turn the per-field lists into one tuple per node, then into the structured
    # node array that sklearn expects for use with Tree.__setstate__().
    _dtypes = np.dtype([('left_child', 'i8'), ('right_child', 'i8'), ('feature', 'i8'), ('threshold', 'f8'), ('impurity', 'f8'), ('n_node_samples', 'i8'), ('weighted_n_node_samples', 'f8'), ('missing_go_to_left', 'u1')])
    tree_data_array = np.array(list(zip(*tree_data_namedtuple)), dtype=_dtypes)

    # sklearn's Tree requires its values as a C-contiguous float64 array of shape
    # (n_nodes, n_outputs=1, n_values). The explicit cast means integer-valued
    # value_sklearn entries are not rejected.
    value_sklearns = np.asarray(value_sklearn_array, dtype=np.float64)
    value_sklearns = value_sklearns.reshape(value_sklearns.shape[0], 1, value_sklearns.shape[1])
    if n_classes == 1:
        # keep only the first value column to match the 1-class Tree built below
        value_sklearns = value_sklearns[:, :, 0:1]
    value_sklearns = np.ascontiguousarray(value_sklearns)

    def get_max_depth(node):
        # a lone leaf has depth 0; a missing child counts as -1
        if node is None:
            return -1
        return 1 + max(get_max_depth(node.left), get_max_depth(node.right))

    max_depth = get_max_depth(figs_tree)

    # Other values needed for the sklearn.tree._tree.Tree constructor and __setstate__() calls.
    node_count = len(tree_data_array)
    # Size the Tree for every input feature, not just those used in splits.
    # node.feature holds raw column indices, and sklearn indexes per-feature arrays
    # of length n_features with them (e.g. when computing feature importances).
    # This also keeps tree_.n_features consistent with n_features_in_ below.
    n_features = figs.n_features_in_
    # sklearn builds this array with dtype=np.intp; plain `int` may be narrower on some platforms.
    n_classes_array = np.array([n_classes], dtype=np.intp)
    n_outputs = 1

    # make dict to pass to __setstate__()
    _state = {'max_depth': max_depth,
        'node_count': node_count,
        'nodes': tree_data_array,
        'values': value_sklearns,
        'n_features_in_': figs.n_features_in_,
        # WARNING this circumvents
        # UserWarning: Trying to unpickle estimator DecisionTreeClassifier from version pre-0.18 when using version
        # https://github.com/scikit-learn/scikit-learn/blob/53acd0fe52cb5d8c6f5a86a1fc1352809240b68d/sklearn/base.py#L279
        '_sklearn_version': __version__,
    }

    tree = Tree(n_features=n_features, n_classes=n_classes_array, n_outputs=n_outputs)
    # https://github.com/scikit-learn/scikit-learn/blob/3850935ea610b5231720fdf865c837aeff79ab1b/sklearn/tree/_tree.pyx#L677
    tree.__setstate__(_state)

    # add the tree_ for the dt __setstate__()
    # note the trailing underscore also trips the sklearn_is_fitted protections
    _state['tree_'] = tree
    _state['classes_'] = np.arange(n_classes)
    _state['n_outputs_'] = n_outputs

    dt = estimator_cls(max_depth=max_depth)
    try:
        dt.__setstate__(_state)
    except Exception as err:
        raise Exception(f'Did not successfully run __setstate__() when translating to {type(dt)}, did sklearn update?') from err

    if not with_leaf_predictions:
        return dt

    # Collect the leaf values (node.value[0][0]) keyed by FIGS node_id.
    leaf_values_dict = {}

    def _read_node(node):
        if node is None:
            return
        if node.left is None and node.right is None:
            leaf_values_dict[node.node_id] = node.value[0][0]
        _read_node(node.left)
        _read_node(node.right)

    _read_node(figs_tree)

    return dt, leaf_values_dict

Functions

def extract_sklearn_tree_from_figs(figs, tree_num, n_classes, with_leaf_predictions=False)

Takes in a FIGS model and converts tree number tree_num to a scikit-learn decision tree.

Expand source code
def extract_sklearn_tree_from_figs(figs, tree_num, n_classes, with_leaf_predictions=False):
    """Takes in a FIGS model and converts tree number ``tree_num`` to a sklearn decision tree.

    The selected FIGS tree is flattened with ``_extract_arrays_from_figs_tree``
    into the node and value arrays accepted by ``sklearn.tree._tree.Tree.__setstate__()``.
    The resulting ``Tree`` is then installed through ``__setstate__()`` into a new
    ``DecisionTreeClassifier`` (classifier models) or ``DecisionTreeRegressor``
    (regressor models).

    Parameters
    ----------
    figs : fitted FIGS model
        Must provide ``trees_`` (indexed by ``tree_num``) and ``n_features_in_``.
    tree_num : int
        Index into ``figs.trees_`` of the tree to convert.
    n_classes : int
        Number of classes given to the sklearn ``Tree``. When it is 1, only the
        first column of each node's ``value_sklearn`` is kept.
    with_leaf_predictions : bool, default=False
        When True, also return a dict mapping each leaf's ``node_id`` to
        ``node.value[0][0]``.

    Returns
    -------
    dt : DecisionTreeClassifier or DecisionTreeRegressor
        Returned on its own when ``with_leaf_predictions`` is False.
    (dt, leaf_values_dict) : tuple
        Returned when ``with_leaf_predictions`` is True.

    Raises
    ------
    AttributeError
        If ``figs.trees_[tree_num]`` cannot be loaded. The underlying error is chained.
    TypeError
        If ``figs`` is neither a ``ClassifierMixin`` nor a ``RegressorMixin``.
    Exception
        If the estimator's ``__setstate__()`` fails, e.g. after a sklearn update.
    """
    try:
        figs_tree = figs.trees_[tree_num]
    except Exception as err:
        raise AttributeError(f'Can not load tree_num = {tree_num}!') from err

    # Choose the target estimator before doing any work, so an unsupported model
    # type fails with a clear TypeError rather than leaving `dt` unbound.
    if isinstance(figs, ClassifierMixin):
        estimator_cls = DecisionTreeClassifier
    elif isinstance(figs, RegressorMixin):
        estimator_cls = DecisionTreeRegressor
    else:
        raise TypeError(f'Can only translate classifier or regressor FIGS models, got {type(figs)}')

    tree_data_namedtuple, value_sklearn_array = _extract_arrays_from_figs_tree(figs_tree)

    # Turn the per-field lists into one tuple per node, then into the structured
    # node array that sklearn expects for use with Tree.__setstate__().
    _dtypes = np.dtype([('left_child', 'i8'), ('right_child', 'i8'), ('feature', 'i8'), ('threshold', 'f8'), ('impurity', 'f8'), ('n_node_samples', 'i8'), ('weighted_n_node_samples', 'f8'), ('missing_go_to_left', 'u1')])
    tree_data_array = np.array(list(zip(*tree_data_namedtuple)), dtype=_dtypes)

    # sklearn's Tree requires its values as a C-contiguous float64 array of shape
    # (n_nodes, n_outputs=1, n_values). The explicit cast means integer-valued
    # value_sklearn entries are not rejected.
    value_sklearns = np.asarray(value_sklearn_array, dtype=np.float64)
    value_sklearns = value_sklearns.reshape(value_sklearns.shape[0], 1, value_sklearns.shape[1])
    if n_classes == 1:
        # keep only the first value column to match the 1-class Tree built below
        value_sklearns = value_sklearns[:, :, 0:1]
    value_sklearns = np.ascontiguousarray(value_sklearns)

    def get_max_depth(node):
        # a lone leaf has depth 0; a missing child counts as -1
        if node is None:
            return -1
        return 1 + max(get_max_depth(node.left), get_max_depth(node.right))

    max_depth = get_max_depth(figs_tree)

    # Other values needed for the sklearn.tree._tree.Tree constructor and __setstate__() calls.
    node_count = len(tree_data_array)
    # Size the Tree for every input feature, not just those used in splits.
    # node.feature holds raw column indices, and sklearn indexes per-feature arrays
    # of length n_features with them (e.g. when computing feature importances).
    # This also keeps tree_.n_features consistent with n_features_in_ below.
    n_features = figs.n_features_in_
    # sklearn builds this array with dtype=np.intp; plain `int` may be narrower on some platforms.
    n_classes_array = np.array([n_classes], dtype=np.intp)
    n_outputs = 1

    # make dict to pass to __setstate__()
    _state = {'max_depth': max_depth,
        'node_count': node_count,
        'nodes': tree_data_array,
        'values': value_sklearns,
        'n_features_in_': figs.n_features_in_,
        # WARNING this circumvents
        # UserWarning: Trying to unpickle estimator DecisionTreeClassifier from version pre-0.18 when using version
        # https://github.com/scikit-learn/scikit-learn/blob/53acd0fe52cb5d8c6f5a86a1fc1352809240b68d/sklearn/base.py#L279
        '_sklearn_version': __version__,
    }

    tree = Tree(n_features=n_features, n_classes=n_classes_array, n_outputs=n_outputs)
    # https://github.com/scikit-learn/scikit-learn/blob/3850935ea610b5231720fdf865c837aeff79ab1b/sklearn/tree/_tree.pyx#L677
    tree.__setstate__(_state)

    # add the tree_ for the dt __setstate__()
    # note the trailing underscore also trips the sklearn_is_fitted protections
    _state['tree_'] = tree
    _state['classes_'] = np.arange(n_classes)
    _state['n_outputs_'] = n_outputs

    dt = estimator_cls(max_depth=max_depth)
    try:
        dt.__setstate__(_state)
    except Exception as err:
        raise Exception(f'Did not successfully run __setstate__() when translating to {type(dt)}, did sklearn update?') from err

    if not with_leaf_predictions:
        return dt

    # Collect the leaf values (node.value[0][0]) keyed by FIGS node_id.
    leaf_values_dict = {}

    def _read_node(node):
        if node is None:
            return
        if node.left is None and node.right is None:
            leaf_values_dict[node.node_id] = node.value[0][0]
        _read_node(node.left)
        _read_node(node.right)

    _read_node(figs_tree)

    return dt, leaf_values_dict

Classes

class TreeData (left_child, right_child, feature, threshold, impurity, n_node_samples, weighted_n_node_samples, missing_go_to_left)

TreeData(left_child, right_child, feature, threshold, impurity, n_node_samples, weighted_n_node_samples, missing_go_to_left)

Ancestors

  • builtins.tuple

Instance variables

var feature

Alias for field number 2

var impurity

Alias for field number 4

var left_child

Alias for field number 0

var missing_go_to_left

Alias for field number 7

var n_node_samples

Alias for field number 5

var right_child

Alias for field number 1

var threshold

Alias for field number 3

var weighted_n_node_samples

Alias for field number 6