import copy
from typing import List

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.model_selection import cross_val_score

from imodels.tree.gosdt.pygosdt import OptimalTreeClassifier


def shrink_node(node, reg_param, parent_val, parent_num, cum_sum, scheme, constant):
    """Shrink the tree
    """

    left = node.get("false", None)
    right = node.get("true", None)
    is_leaf = "prediction" in node
    # if self.prediction_task == 'regression':
    val = node["probs"]
    is_root = parent_val is None and parent_num is None
    n_samples = node['n_obs'] if (scheme != "leaf_based" or is_root) else parent_num

    if is_root:
        val_new = val

    else:
        # shrink this node's deviation from its parent's value
        reg_term = reg_param if scheme == "constant" else reg_param / parent_num

        val_new = (val - parent_val) / (1 + reg_term)

    # running (telescoping) sum of shrunken deviations along the root-to-leaf path
    cum_sum += val_new

    if is_leaf:
        if scheme == "leaf_based":
            v = constant + (val - constant) / (1 + reg_param / node["n_obs"])
            node["probs"] = v
        else:
            # print(f"Changing {val} to {cum_sum}")
            node["probs"] = cum_sum

    else:
        shrink_node(left, reg_param, val, parent_num=n_samples, cum_sum=cum_sum, scheme=scheme, constant=constant)
        shrink_node(right, reg_param, val, parent_num=n_samples, cum_sum=cum_sum, scheme=scheme, constant=constant)

    return node


def _add_label(node, val):
    if "labels" in node:
        node['labels'].append(val)
        return
    node['labels'] = [val]


class HSOptimalTreeClassifier(BaseEstimator):
    def __init__(self, estimator_: OptimalTreeClassifier, reg_param: float = 1, shrinkage_scheme_: str = 'node_based'):
        """
        Params
        ------
        reg_param: float
            Higher is more regularization (can be arbitrarily large, should not be < 0)
        shrinkage_scheme_: str
            Experimental: selects among different forms of shrinkage. Options are:
                (i) node_based shrinks based on number of samples in parent node
                (ii) leaf_based only shrinks leaf nodes based on number of leaf samples
                (iii) constant shrinks every node by a constant lambda
        """
        super().__init__()
        self.reg_param = reg_param
        self.estimator_ = estimator_
        self.shrinkage_scheme_ = shrinkage_scheme_

    def _calc_probs(self, node):
        # labels routed to this node; leaves that saw no samples fall back to their prediction
        lbls = np.array([float(l) for l in node["labels"]]) if "labels" in node else np.array(
            [float(node['prediction'])])
        node['probs'] = np.mean(lbls == 1)  # empirical probability of class 1
        node['n_obs'] = len(node.get('labels', []))
        if "prediction" in node:  # leaf: refresh its hard prediction
            node['prediction'] = np.round(node['probs'])
            return
        self._calc_probs(node['true'])
        self._calc_probs(node['false'])

    def impute_nodes(self, X, y):
        """
        Returns
        ---
        the leaf by which this sample would be classified
        """
        source_node = self.estimator_.tree_.source
        for i in range(len(y)):
            sample, label = X[i, ...], y[i]
            _add_label(source_node, label)
            nodes = [source_node]
            while len(nodes) > 0:
                node = nodes.pop()
                if "prediction" in node:
                    continue
                else:
                    value = sample[node["feature"]]
                    reference = node["reference"]
                    relation = node["relation"]
                    if relation == "==":
                        is_true = value == reference
                    elif relation == ">=":
                        is_true = value >= reference
                    elif relation == "<=":
                        is_true = value <= reference
                    elif relation == "<":
                        is_true = value < reference
                    elif relation == ">":
                        is_true = value > reference
                    else:
                        raise "Unsupported relational operator {}".format(node["relation"])

                    next_node = node['true'] if is_true else node['false']
                    _add_label(next_node, label)
                    nodes.append(next_node)

        self._calc_probs(source_node)
        self.estimator_.tree_.source = source_node


    def shrink_tree(self):
        root = self.estimator_.tree_.source
        shrink_node(root, self.reg_param, None, None, 0, self.shrinkage_scheme_, 0)

    def predict_proba(self, X):
        probs = []
        for i in range(X.shape[0]):
            sample = X[i, ...]
            node = self.estimator_.tree_.__find_leaf__(sample)
            probs.append([1 - node["probs"], node["probs"]])  # [P(class 0), P(class 1)]
        return np.array(probs)

    def fit(self, *args, **kwargs):
        X = kwargs['X'] if "X" in kwargs else args[0]
        y = kwargs['y'] if "y" in kwargs else args[1]
        if not hasattr(self.estimator_, "tree_"):
            self.estimator_.fit(X, y)
        self.impute_nodes(X, y)
        self.shrink_tree()
        return self

    def predict(self, X):
        return self.estimator_.predict(X)

    def score(self, X, y, weight=None):
        return self.estimator_.score(X, y, weight)

    @property
    def complexity_(self):
        return self.estimator_.complexity_


class HSOptimalTreeClassifierCV(HSOptimalTreeClassifier):
    def __init__(self, estimator_: OptimalTreeClassifier,
                 reg_param_list: List[float] = [0.1, 1, 10, 50, 100, 500], shrinkage_scheme_: str = 'node_based',
                 cv: int = 3, scoring="accuracy", *args, **kwargs):
        """Note: args, kwargs are not used but left so that imodels-experiments can still pass redundant args
        """
        super().__init__(estimator_, reg_param=None)
        self.reg_param_list = np.array(reg_param_list)
        self.cv = cv
        self.scoring = scoring
        self.shrinkage_scheme_ = shrinkage_scheme_

    def fit(self, X, y, *args, **kwargs):
        self.scores_ = []
        opt = copy.deepcopy(self.estimator_)
        for reg_param in self.reg_param_list:
            est = HSOptimalTreeClassifier(opt, reg_param)
            cv_scores = cross_val_score(est, X, y, cv=self.cv, scoring=self.scoring)
            self.scores_.append(np.mean(cv_scores))
        self.reg_param = self.reg_param_list[np.argmax(self.scores_)]
        super().fit(X=X, y=y)
        return self

Functions

def shrink_node(node, reg_param, parent_val, parent_num, cum_sum, scheme, constant)

Apply hierarchical shrinkage to the subtree rooted at node, in place.

def shrink_node(node, reg_param, parent_val, parent_num, cum_sum, scheme, constant):
    """Shrink the tree
    """

    left = node.get("false", None)
    right = node.get("true", None)
    is_leaf = "prediction" in node
    # if self.prediction_task == 'regression':
    val = node["probs"]
    is_root = parent_val is None and parent_num is None
    n_samples = node['n_obs'] if (scheme != "leaf_based" or is_root) else parent_num

    if is_root:
        val_new = val

    else:
        # shrink this node's deviation from its parent's value
        reg_term = reg_param if scheme == "constant" else reg_param / parent_num

        val_new = (val - parent_val) / (1 + reg_term)

    # running (telescoping) sum of shrunken deviations along the root-to-leaf path
    cum_sum += val_new

    if is_leaf:
        if scheme == "leaf_based":
            v = constant + (val - constant) / (1 + reg_param / node["n_obs"])
            node["probs"] = v
        else:
            # print(f"Changing {val} to {cum_sum}")
            node["probs"] = cum_sum

    else:
        shrink_node(left, reg_param, val, parent_num=n_samples, cum_sum=cum_sum, scheme=scheme, constant=constant)
        shrink_node(right, reg_param, val, parent_num=n_samples, cum_sum=cum_sum, scheme=scheme, constant=constant)

    return node
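
To see what the node_based scheme does numerically, here is a minimal sketch on a hypothetical two-leaf tree (the dict keys match the source above; the numbers are made up):

# Hypothetical two-leaf tree; "probs"/"n_obs" as impute_nodes would fill them in.
leaf_l = {"prediction": 0, "probs": 0.2, "n_obs": 40}
leaf_r = {"prediction": 1, "probs": 0.9, "n_obs": 60}
root = {"probs": 0.62, "n_obs": 100, "false": leaf_l, "true": leaf_r}

shrink_node(root, reg_param=10.0, parent_val=None, parent_num=None,
            cum_sum=0, scheme="node_based", constant=0)

# Each leaf now holds 0.62 + (p_leaf - 0.62) / (1 + 10/100):
print(leaf_l["probs"], leaf_r["probs"])  # ~0.238, ~0.875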

Classes

class HSOptimalTreeClassifier (estimator_: OptimalTreeClassifier, reg_param: float = 1, shrinkage_scheme_: str = 'node_based')

Base class for all estimators in scikit-learn.

Inheriting from this class provides default implementations of:

  • setting and getting parameters used by GridSearchCV and friends;
  • textual and HTML representation displayed in terminals and IDEs;
  • estimator serialization;
  • parameters validation;
  • data validation;
  • feature names validation.

Read more in the scikit-learn User Guide (rolling your own estimator).

Notes

All estimators should specify all the parameters that can be set at the class level in their __init__ as explicit keyword arguments (no *args or **kwargs).

Examples

>>> import numpy as np
>>> from sklearn.base import BaseEstimator
>>> class MyEstimator(BaseEstimator):
...     def __init__(self, *, param=1):
...         self.param = param
...     def fit(self, X, y=None):
...         self.is_fitted_ = True
...         return self
...     def predict(self, X):
...         return np.full(shape=X.shape[0], fill_value=self.param)
>>> estimator = MyEstimator(param=2)
>>> estimator.get_params()
{'param': 2}
>>> X = np.array([[1, 2], [2, 3], [3, 4]])
>>> y = np.array([1, 0, 1])
>>> estimator.fit(X, y).predict(X)
array([2, 2, 2])
>>> estimator.set_params(param=3).fit(X, y).predict(X)
array([3, 3, 3])

Params

reg_param: float
    Higher is more regularization (can be arbitrarily large, should not be < 0)
shrinkage_scheme_: str
    Experimental: selects among different forms of shrinkage. Options are:
        (i) node_based shrinks based on the number of samples in the parent node
        (ii) leaf_based only shrinks leaf nodes, based on the number of leaf samples
        (iii) constant shrinks every node by a constant lambda
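
A usage sketch (the toy data and the top-level import path are assumptions; adjust the import to wherever this module lives in your install):

import numpy as np
from imodels import HSOptimalTreeClassifier, OptimalTreeClassifier  # assumed import path

X = np.random.randint(0, 2, size=(100, 5))   # toy binary features
y = (X[:, 0] & X[:, 1]).astype(int)          # toy labels

model = HSOptimalTreeClassifier(OptimalTreeClassifier(), reg_param=1.0)
model.fit(X, y)                 # fits the base tree if needed, then imputes and shrinks
probs = model.predict_proba(X)  # shape (n_samples, 2) of shrunken leaf probabilities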

class HSOptimalTreeClassifier(BaseEstimator):
    def __init__(self, estimator_: OptimalTreeClassifier, reg_param: float = 1, shrinkage_scheme_: str = 'node_based'):
        """
        Params
        ------
        reg_param: float
            Higher is more regularization (can be arbitrarily large, should not be < 0)
        shrinkage_scheme_: str
            Experimental: selects among different forms of shrinkage. Options are:
                (i) node_based shrinks based on number of samples in parent node
                (ii) leaf_based only shrinks leaf nodes based on number of leaf samples
                (iii) constant shrinks every node by a constant lambda
        """
        super().__init__()
        self.reg_param = reg_param
        self.estimator_ = estimator_
        self.shrinkage_scheme_ = shrinkage_scheme_

    def _calc_probs(self, node):
        # labels routed to this node; leaves that saw no samples fall back to their prediction
        lbls = np.array([float(l) for l in node["labels"]]) if "labels" in node else np.array(
            [float(node['prediction'])])
        node['probs'] = np.mean(lbls == 1)  # empirical probability of class 1
        node['n_obs'] = len(node.get('labels', []))
        if "prediction" in node:  # leaf: refresh its hard prediction
            node['prediction'] = np.round(node['probs'])
            return
        self._calc_probs(node['true'])
        self._calc_probs(node['false'])

    def impute_nodes(self, X, y):
        """
        Returns
        ---
        the leaf by which this sample would be classified
        """
        source_node = self.estimator_.tree_.source
        for i in range(len(y)):
            sample, label = X[i, ...], y[i]
            _add_label(source_node, label)
            nodes = [source_node]
            while len(nodes) > 0:
                node = nodes.pop()
                if "prediction" in node:
                    continue
                else:
                    value = sample[node["feature"]]
                    reference = node["reference"]
                    relation = node["relation"]
                    if relation == "==":
                        is_true = value == reference
                    elif relation == ">=":
                        is_true = value >= reference
                    elif relation == "<=":
                        is_true = value <= reference
                    elif relation == "<":
                        is_true = value < reference
                    elif relation == ">":
                        is_true = value > reference
                    else:
                        raise "Unsupported relational operator {}".format(node["relation"])

                    next_node = node['true'] if is_true else node['false']
                    _add_label(next_node, label)
                    nodes.append(next_node)

        self._calc_probs(source_node)
        self.estimator_.tree_.source = source_node


    def shrink_tree(self):
        root = self.estimator_.tree_.source
        shrink_node(root, self.reg_param, None, None, 0, self.shrinkage_scheme_, 0)

    def predict_proba(self, X):
        probs = []
        for i in range(X.shape[0]):
            sample = X[i, ...]
            node = self.estimator_.tree_.__find_leaf__(sample)
            probs.append([1 - node["probs"], node["probs"]])  # [P(class 0), P(class 1)]
        return np.array(probs)

    def fit(self, *args, **kwargs):
        X = kwargs['X'] if "X" in kwargs else args[0]
        y = kwargs['y'] if "y" in kwargs else args[1]
        if not hasattr(self.estimator_, "tree_"):
            self.estimator_.fit(X, y)
        self.impute_nodes(X, y)
        self.shrink_tree()
        return self

    def predict(self, X):
        return self.estimator_.predict(X)

    def score(self, X, y, weight=None):
        return self.estimator_.score(X, y, weight)

    @property
    def complexity_(self):
        return self.estimator_.complexity_

Ancestors

  • sklearn.base.BaseEstimator
  • sklearn.utils._estimator_html_repr._HTMLDocumentationLinkMixin
  • sklearn.utils._metadata_requests._MetadataRequester

Subclasses

  • HSOptimalTreeClassifierCV

Instance variables

var complexity_
@property
def complexity_(self):
    return self.estimator_.complexity_

Methods

def fit(self, *args, **kwargs)
def fit(self, *args, **kwargs):
    X = kwargs['X'] if "X" in kwargs else args[0]
    y = kwargs['y'] if "y" in kwargs else args[1]
    if not hasattr(self.estimator_, "tree_"):
        self.estimator_.fit(X, y)
    self.impute_nodes(X, y)
    self.shrink_tree()
    return self
def impute_nodes(self, X, y)

Routes each training sample down the tree, recording its label at every node it visits, then recomputes node probabilities from those labels.
def impute_nodes(self, X, y):
    """
    Returns
    ---
    the leaf by which this sample would be classified
    """
    source_node = self.estimator_.tree_.source
    for i in range(len(y)):
        sample, label = X[i, ...], y[i]
        _add_label(source_node, label)
        nodes = [source_node]
        while len(nodes) > 0:
            node = nodes.pop()
            if "prediction" in node:
                continue
            else:
                value = sample[node["feature"]]
                reference = node["reference"]
                relation = node["relation"]
                if relation == "==":
                    is_true = value == reference
                elif relation == ">=":
                    is_true = value >= reference
                elif relation == "<=":
                    is_true = value <= reference
                elif relation == "<":
                    is_true = value < reference
                elif relation == ">":
                    is_true = value > reference
                else:
                    raise "Unsupported relational operator {}".format(node["relation"])

                next_node = node['true'] if is_true else node['false']
                _add_label(next_node, label)
                nodes.append(next_node)

    self._calc_probs(source_node)
    self.estimator_.tree_.source = source_node
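
The traversal above assumes GOSDT-style node dicts. A hypothetical pair of nodes illustrating the keys it reads (values are made up):

leaf_yes = {"prediction": 1}   # leaves are marked by a "prediction" key
leaf_no = {"prediction": 0}
internal = {
    "feature": 2,          # column index tested at this node
    "relation": ">=",      # one of ==, >=, <=, <, >
    "reference": 1,        # value the feature is compared against
    "true": leaf_yes,      # child taken when the relation holds
    "false": leaf_no,      # child taken otherwise
}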
def predict(self, X)
def predict(self, X):
    return self.estimator_.predict(X)
def predict_proba(self, X)
def predict_proba(self, X):
    probs = []
    for i in range(X.shape[0]):
        sample = X[i, ...]
        node = self.estimator_.tree_.__find_leaf__(sample)
        probs.append([1 - node["probs"], node["probs"]])  # [P(class 0), P(class 1)]
    return np.array(probs)
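
Note that predict_proba assumes binary classification: each row is [P(class 0), P(class 1)]. A quick usage line, reusing `model` from the earlier sketch:

p1 = model.predict_proba(X)[:, 1]   # shrunken P(class 1) per sample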
def score(self, X, y, weight=None)
def score(self, X, y, weight=None):
    return self.estimator_.score(X, y, weight)
def set_score_request(self: HSOptimalTreeClassifier, *, weight: Union[bool, None, str] = '$UNCHANGED$') -> HSOptimalTreeClassifier

Request metadata passed to the score method.

Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config). Please see the scikit-learn User Guide on metadata routing for how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to score if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to score.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in version: 1.3

Note

This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a sklearn.pipeline.Pipeline. Otherwise it has no effect.

Parameters

weight : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED
Metadata routing for weight parameter in score.

Returns

self : object
The updated object.
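
A short sketch of how this is typically used (requires scikit-learn >= 1.3; `clf` is assumed to be an HSOptimalTreeClassifier instance):

import sklearn

sklearn.set_config(enable_metadata_routing=True)  # routing must be enabled first
clf = clf.set_score_request(weight=True)          # request that `weight` be routed to score()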
def func(*args, **kw):
    """Updates the request for provided parameters

    This docstring is overwritten below.
    See REQUESTER_DOC for expected functionality
    """
    if not _routing_enabled():
        raise RuntimeError(
            "This method is only available when metadata routing is enabled."
            " You can enable it using"
            " sklearn.set_config(enable_metadata_routing=True)."
        )

    if self.validate_keys and (set(kw) - set(self.keys)):
        raise TypeError(
            f"Unexpected args: {set(kw) - set(self.keys)} in {self.name}. "
            f"Accepted arguments are: {set(self.keys)}"
        )

    # This makes it possible to use the decorated method as an unbound method,
    # for instance when monkeypatching.
    # https://github.com/scikit-learn/scikit-learn/issues/28632
    if instance is None:
        _instance = args[0]
        args = args[1:]
    else:
        _instance = instance

    # Replicating python's behavior when positional args are given other than
    # `self`, and `self` is only allowed if this method is unbound.
    if args:
        raise TypeError(
            f"set_{self.name}_request() takes 0 positional argument but"
            f" {len(args)} were given"
        )

    requests = _instance._get_metadata_request()
    method_metadata_request = getattr(requests, self.name)

    for prop, alias in kw.items():
        if alias is not UNCHANGED:
            method_metadata_request.add_request(param=prop, alias=alias)
    _instance._metadata_request = requests

    return _instance
def shrink_tree(self)
def shrink_tree(self):
    root = self.estimator_.tree_.source
    shrink_node(root, self.reg_param, None, None, 0, self.shrinkage_scheme_, 0)
class HSOptimalTreeClassifierCV (estimator_: OptimalTreeClassifier, reg_param_list: List[float] = [0.1, 1, 10, 50, 100, 500], shrinkage_scheme_: str = 'node_based', cv: int = 3, scoring='accuracy', *args, **kwargs)


Note: args, kwargs are not used but left so that imodels-experiments can still pass redundant args

class HSOptimalTreeClassifierCV(HSOptimalTreeClassifier):
    def __init__(self, estimator_: OptimalTreeClassifier,
                 reg_param_list: List[float] = [0.1, 1, 10, 50, 100, 500], shrinkage_scheme_: str = 'node_based',
                 cv: int = 3, scoring="accuracy", *args, **kwargs):
        """Note: args, kwargs are not used but left so that imodels-experiments can still pass redundant args
        """
        super().__init__(estimator_, reg_param=None)
        self.reg_param_list = np.array(reg_param_list)
        self.cv = cv
        self.scoring = scoring
        self.shrinkage_scheme_ = shrinkage_scheme_

    def fit(self, X, y, *args, **kwargs):
        self.scores_ = []
        opt = copy.deepcopy(self.estimator_)
        for reg_param in self.reg_param_list:
            est = HSOptimalTreeClassifier(opt, reg_param)
            cv_scores = cross_val_score(est, X, y, cv=self.cv, scoring=self.scoring)
            self.scores_.append(np.mean(cv_scores))
        self.reg_param = self.reg_param_list[np.argmax(self.scores_)]
        super().fit(X=X, y=y)
        return self

Ancestors

  • HSOptimalTreeClassifier
  • sklearn.base.BaseEstimator
  • sklearn.utils._estimator_html_repr._HTMLDocumentationLinkMixin
  • sklearn.utils._metadata_requests._MetadataRequester

Methods

def fit(self, X, y, *args, **kwargs)
def fit(self, X, y, *args, **kwargs):
    self.scores_ = []
    opt = copy.deepcopy(self.estimator_)
    for reg_param in self.reg_param_list:
        est = HSOptimalTreeClassifier(opt, reg_param)
        cv_scores = cross_val_score(est, X, y, cv=self.cv, scoring=self.scoring)
        self.scores_.append(np.mean(cv_scores))
    self.reg_param = self.reg_param_list[np.argmax(self.scores_)]
    super().fit(X=X, y=y)
    return self
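
A usage sketch for the CV variant (toy data and import path are assumptions, as above):

import numpy as np
from imodels import HSOptimalTreeClassifierCV, OptimalTreeClassifier  # assumed import path

rng = np.random.default_rng(0)
X = rng.integers(0, 2, size=(200, 5))        # toy binary features
y = (X[:, 0] | X[:, 2]).astype(int)

cv_model = HSOptimalTreeClassifierCV(OptimalTreeClassifier(),
                                     reg_param_list=[0.1, 1, 10, 100], cv=3)
cv_model.fit(X, y)
print(cv_model.reg_param)   # reg_param with the best mean CV score
print(cv_model.scores_)     # mean CV score for each candidate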

Inherited members