Expand source code
import numpy as np
from sklearn.base import BaseEstimator
from sklearn.base import RegressorMixin, ClassifierMixin
from sklearn.linear_model import LogisticRegression, ElasticNet, Ridge
from sklearn.model_selection import GridSearchCV, check_cv, train_test_split
from sklearn.pipeline import Pipeline
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

import imodels
from imodels import (
    RuleFitClassifier,
    TreeGAMClassifier,
    FIGSClassifier,
    HSTreeClassifier,
    RuleFitRegressor,
    TreeGAMRegressor,
    FIGSRegressor,
    HSTreeRegressor,
)


class AutoInterpretableModel(BaseEstimator):
    """Automatically fit and select an interpretable model.

    This is basically a wrapper around GridSearchCV over a single-step
    pipeline whose ``"est"`` step is chosen from some preselected
    interpretable models (decision trees, RuleFit, TreeGAM, HSTree, FIGS and
    regularized linear models). Use ``AutoInterpretableClassifier`` or
    ``AutoInterpretableRegressor``: the mixin of the concrete subclass decides
    which default grid and which scoring metric are used.
    Note that all preprocessing should be done beforehand.

    Parameters
    ----------
    param_grid : list of dict or dict, optional
        Grid passed to GridSearchCV. Keys address the pipeline step ``"est"``
        (``"est"`` selects the estimator, ``"est__<param>"`` its settings).
        When None, ``PARAM_GRID_DEFAULT_CLASSIFICATION`` or
        ``PARAM_GRID_DEFAULT_REGRESSION`` is used at fit time.
    refit : bool, default=True
        Forwarded to GridSearchCV. GridSearchCV only exposes predict /
        predict_proba / score after a refit.

    Attributes
    ----------
    pipe_ : Pipeline
        Single-step pipeline with a placeholder estimator that the search
        replaces with each candidate.
    est_ : GridSearchCV
        The fitted grid search.
    """

    def __init__(self, param_grid=None, refit=True):
        # Store constructor arguments unchanged (scikit-learn convention) so
        # get_params()/clone() round-trip and even a bare instance has a repr;
        # the default grid is resolved in fit().
        self.param_grid = param_grid
        self.refit = refit

    def fit(self, X, y, cv=5):
        """Run the grid search and keep the best interpretable estimator.

        Parameters
        ----------
        X : array-like
            Training features (already preprocessed).
        y : array-like
            Training targets.
        cv : int, cross-validation generator or iterable, default=5
            Forwarded to GridSearchCV. For classifiers an integer ``cv`` yields
            stratified folds.

        Returns
        -------
        self

        Raises
        ------
        TypeError
            If the instance is neither a classifier nor a regressor.
        """
        if isinstance(self, ClassifierMixin):
            default_grid = self.PARAM_GRID_DEFAULT_CLASSIFICATION
            # The plain "roc_auc" scorer only supports binary targets.
            scoring = "roc_auc" if np.unique(y).size <= 2 else "roc_auc_ovr"
            is_classifier = True
        elif isinstance(self, RegressorMixin):
            default_grid = self.PARAM_GRID_DEFAULT_REGRESSION
            scoring = "r2"
            is_classifier = False
        else:
            raise TypeError(
                f"{type(self).__name__} must also inherit from ClassifierMixin "
                "or RegressorMixin; use AutoInterpretableClassifier or "
                "AutoInterpretableRegressor."
            )
        param_grid = default_grid if self.param_grid is None else self.param_grid

        # GridSearchCV inspects the placeholder pipeline below, which is not a
        # classifier, so it would fall back to unstratified KFold. Resolve the
        # splitter here so classification folds are stratified.
        cv = check_cv(cv, y, classifier=is_classifier)

        self.pipe_ = Pipeline([("est", BaseEstimator())])  # Placeholder Estimator
        self.est_ = GridSearchCV(
            self.pipe_, param_grid, scoring=scoring, cv=cv, refit=self.refit)
        self.est_.fit(X, y)
        return self

    def predict(self, X):
        """Predict with the best estimator found by the grid search."""
        return self.est_.predict(X)

    def predict_proba(self, X):
        """Probabilities from the best estimator.

        Only available when the selected estimator exposes predict_proba
        (i.e. classifiers).
        """
        return self.est_.predict_proba(X)

    def score(self, X, y):
        """Score ``X, y`` with the grid search's scorer (ROC AUC or R^2)."""
        return self.est_.score(X, y)

    # Candidate grids. Each dict targets the single pipeline step "est":
    # "est" picks the estimator and "est__<param>" sets its hyper-parameters.
    PARAM_GRID_LINEAR_CLASSIFICATION = [
        {
            "est": [
                LogisticRegression(
                    solver="saga", penalty="elasticnet", max_iter=100, random_state=42)
            ],
            "est__C": [0.1, 1, 10],
            "est__l1_ratio": [0, 0.5, 1],
        },
    ]

    PARAM_GRID_DEFAULT_CLASSIFICATION = [
        {
            "est": [DecisionTreeClassifier(random_state=42)],
            "est__max_leaf_nodes": [2, 5, 10],
        },
        {
            "est": [RuleFitClassifier(random_state=42)],
            "est__max_rules": [10, 100],
            "est__n_estimators": [20],
        },
        {
            "est": [TreeGAMClassifier(random_state=42)],
            "est__n_boosting_rounds": [10, 100],
        },
        {
            "est": [HSTreeClassifier(random_state=42)],
            "est__max_leaf_nodes": [5, 10],
        },
        {
            "est": [FIGSClassifier(random_state=42)],
            "est__max_rules": [5, 10],
        },
    ] + PARAM_GRID_LINEAR_CLASSIFICATION

    PARAM_GRID_LINEAR_REGRESSION = [
        {
            "est": [
                ElasticNet(max_iter=100, random_state=42)
            ],
            "est__alpha": [0.1, 1, 10],
            "est__l1_ratio": [0.5, 1],
        },
        {
            "est": [
                Ridge(max_iter=100, random_state=42)
            ],
            "est__alpha": [0, 0.1, 1, 10],
        },
    ]

    # Seeded like the classification grid so regression runs are reproducible.
    PARAM_GRID_DEFAULT_REGRESSION = [
        {
            "est": [DecisionTreeRegressor(random_state=42)],
            "est__max_leaf_nodes": [2, 5, 10],
        },
        {
            "est": [HSTreeRegressor(random_state=42)],
            "est__max_leaf_nodes": [5, 10],
        },
        {
            "est": [RuleFitRegressor(random_state=42)],
            "est__max_rules": [10, 100],
            "est__n_estimators": [20],
        },
        {
            "est": [TreeGAMRegressor(random_state=42)],
            "est__n_boosting_rounds": [10, 100],
        },
        {
            "est": [FIGSRegressor(random_state=42)],
            "est__max_rules": [5, 10],
        },
    ] + PARAM_GRID_LINEAR_REGRESSION


class AutoInterpretableClassifier(ClassifierMixin, AutoInterpretableModel):
    """Interpretable classifier selected by cross-validated ROC AUC.

    Searches ``PARAM_GRID_DEFAULT_CLASSIFICATION`` unless ``param_grid`` is
    given. ClassifierMixin is listed before the BaseEstimator-derived base, as
    scikit-learn requires, so the estimator is tagged as a classifier.
    """

    def score(self, X, y):
        """Score with the grid search's scorer (ROC AUC), not accuracy.

        With ClassifierMixin first in the MRO its accuracy-based ``score``
        would otherwise shadow the grid-search score of the base class.
        """
        return AutoInterpretableModel.score(self, X, y)


class AutoInterpretableRegressor(RegressorMixin, AutoInterpretableModel):
    """Interpretable regressor selected by cross-validated R^2.

    Searches ``PARAM_GRID_DEFAULT_REGRESSION`` unless ``param_grid`` is given.
    RegressorMixin is listed before the BaseEstimator-derived base, as
    scikit-learn requires, so the estimator is tagged as a regressor.
    """

    def score(self, X, y):
        """Score with the grid search's scorer rather than RegressorMixin's.

        With RegressorMixin first in the MRO its ``score`` would otherwise
        shadow the grid-search score of the base class.
        """
        return AutoInterpretableModel.score(self, X, y)


if __name__ == "__main__":
    # Demo: pick an interpretable classifier for the "heart" dataset.
    X, y, _feature_names = imodels.get_clean_dataset("heart")

    print("shapes", X.shape, y.shape, "nunique", np.unique(y).size)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, random_state=42, test_size=0.2
    )

    m = AutoInterpretableClassifier()
    # m = AutoInterpretableRegressor()
    m.fit(X_train, y_train)

    print("best params", m.est_.best_params_)
    print("best score", m.est_.best_score_)
    print("best estimator", m.est_.best_estimator_)
    print("best estimator params", m.est_.best_estimator_.get_params())
    # Evaluate on the held-out split, which was otherwise created but unused.
    print("test score", m.score(X_test, y_test))
    print("best estimator params", m.est_.best_estimator_.get_params())

Classes

class AutoInterpretableClassifier (param_grid=None, refit=True)

Automatically fit and select a classifier that is interpretable. Note that all preprocessing should be done beforehand. This is basically a wrapper around GridSearchCV, with some preselected models.

Expand source code
class AutoInterpretableClassifier(ClassifierMixin, AutoInterpretableModel):
    """Interpretable classifier selected by cross-validated ROC AUC.

    Searches ``PARAM_GRID_DEFAULT_CLASSIFICATION`` unless ``param_grid`` is
    given. ClassifierMixin is listed before the BaseEstimator-derived base, as
    scikit-learn requires, so the estimator is tagged as a classifier.
    """

    def score(self, X, y):
        """Score with the grid search's scorer (ROC AUC), not accuracy.

        With ClassifierMixin first in the MRO its accuracy-based ``score``
        would otherwise shadow the grid-search score of the base class.
        """
        return AutoInterpretableModel.score(self, X, y)

Ancestors

  • AutoInterpretableModel
  • sklearn.base.BaseEstimator
  • sklearn.utils._estimator_html_repr._HTMLDocumentationLinkMixin
  • sklearn.utils._metadata_requests._MetadataRequester
  • sklearn.base.ClassifierMixin

Inherited members

class AutoInterpretableModel (param_grid=None, refit=True)

Automatically fit and select an interpretable model (classifier or regressor, depending on the subclass). Note that all preprocessing should be done beforehand. This is basically a wrapper around GridSearchCV, with some preselected models.

Expand source code
class AutoInterpretableModel(BaseEstimator):
    """Automatically fit and select an interpretable model.

    This is basically a wrapper around GridSearchCV over a single-step
    pipeline whose ``"est"`` step is chosen from some preselected
    interpretable models (decision trees, RuleFit, TreeGAM, HSTree, FIGS and
    regularized linear models). Use ``AutoInterpretableClassifier`` or
    ``AutoInterpretableRegressor``: the mixin of the concrete subclass decides
    which default grid and which scoring metric are used.
    Note that all preprocessing should be done beforehand.

    Parameters
    ----------
    param_grid : list of dict or dict, optional
        Grid passed to GridSearchCV. Keys address the pipeline step ``"est"``
        (``"est"`` selects the estimator, ``"est__<param>"`` its settings).
        When None, ``PARAM_GRID_DEFAULT_CLASSIFICATION`` or
        ``PARAM_GRID_DEFAULT_REGRESSION`` is used at fit time.
    refit : bool, default=True
        Forwarded to GridSearchCV. GridSearchCV only exposes predict /
        predict_proba / score after a refit.

    Attributes
    ----------
    pipe_ : Pipeline
        Single-step pipeline with a placeholder estimator that the search
        replaces with each candidate.
    est_ : GridSearchCV
        The fitted grid search.
    """

    def __init__(self, param_grid=None, refit=True):
        # Store constructor arguments unchanged (scikit-learn convention) so
        # get_params()/clone() round-trip and even a bare instance has a repr;
        # the default grid is resolved in fit().
        self.param_grid = param_grid
        self.refit = refit

    def fit(self, X, y, cv=5):
        """Run the grid search and keep the best interpretable estimator.

        Parameters
        ----------
        X : array-like
            Training features (already preprocessed).
        y : array-like
            Training targets.
        cv : int, cross-validation generator or iterable, default=5
            Forwarded to GridSearchCV. For classifiers an integer ``cv`` yields
            stratified folds.

        Returns
        -------
        self

        Raises
        ------
        TypeError
            If the instance is neither a classifier nor a regressor.
        """
        if isinstance(self, ClassifierMixin):
            default_grid = self.PARAM_GRID_DEFAULT_CLASSIFICATION
            # The plain "roc_auc" scorer only supports binary targets.
            scoring = "roc_auc" if np.unique(y).size <= 2 else "roc_auc_ovr"
            is_classifier = True
        elif isinstance(self, RegressorMixin):
            default_grid = self.PARAM_GRID_DEFAULT_REGRESSION
            scoring = "r2"
            is_classifier = False
        else:
            raise TypeError(
                f"{type(self).__name__} must also inherit from ClassifierMixin "
                "or RegressorMixin; use AutoInterpretableClassifier or "
                "AutoInterpretableRegressor."
            )
        param_grid = default_grid if self.param_grid is None else self.param_grid

        # GridSearchCV inspects the placeholder pipeline below, which is not a
        # classifier, so it would fall back to unstratified KFold. Resolve the
        # splitter here so classification folds are stratified.
        cv = check_cv(cv, y, classifier=is_classifier)

        self.pipe_ = Pipeline([("est", BaseEstimator())])  # Placeholder Estimator
        self.est_ = GridSearchCV(
            self.pipe_, param_grid, scoring=scoring, cv=cv, refit=self.refit)
        self.est_.fit(X, y)
        return self

    def predict(self, X):
        """Predict with the best estimator found by the grid search."""
        return self.est_.predict(X)

    def predict_proba(self, X):
        """Probabilities from the best estimator.

        Only available when the selected estimator exposes predict_proba
        (i.e. classifiers).
        """
        return self.est_.predict_proba(X)

    def score(self, X, y):
        """Score ``X, y`` with the grid search's scorer (ROC AUC or R^2)."""
        return self.est_.score(X, y)

    # Candidate grids. Each dict targets the single pipeline step "est":
    # "est" picks the estimator and "est__<param>" sets its hyper-parameters.
    PARAM_GRID_LINEAR_CLASSIFICATION = [
        {
            "est": [
                LogisticRegression(
                    solver="saga", penalty="elasticnet", max_iter=100, random_state=42)
            ],
            "est__C": [0.1, 1, 10],
            "est__l1_ratio": [0, 0.5, 1],
        },
    ]

    PARAM_GRID_DEFAULT_CLASSIFICATION = [
        {
            "est": [DecisionTreeClassifier(random_state=42)],
            "est__max_leaf_nodes": [2, 5, 10],
        },
        {
            "est": [RuleFitClassifier(random_state=42)],
            "est__max_rules": [10, 100],
            "est__n_estimators": [20],
        },
        {
            "est": [TreeGAMClassifier(random_state=42)],
            "est__n_boosting_rounds": [10, 100],
        },
        {
            "est": [HSTreeClassifier(random_state=42)],
            "est__max_leaf_nodes": [5, 10],
        },
        {
            "est": [FIGSClassifier(random_state=42)],
            "est__max_rules": [5, 10],
        },
    ] + PARAM_GRID_LINEAR_CLASSIFICATION

    PARAM_GRID_LINEAR_REGRESSION = [
        {
            "est": [
                ElasticNet(max_iter=100, random_state=42)
            ],
            "est__alpha": [0.1, 1, 10],
            "est__l1_ratio": [0.5, 1],
        },
        {
            "est": [
                Ridge(max_iter=100, random_state=42)
            ],
            "est__alpha": [0, 0.1, 1, 10],
        },
    ]

    # Seeded like the classification grid so regression runs are reproducible.
    PARAM_GRID_DEFAULT_REGRESSION = [
        {
            "est": [DecisionTreeRegressor(random_state=42)],
            "est__max_leaf_nodes": [2, 5, 10],
        },
        {
            "est": [HSTreeRegressor(random_state=42)],
            "est__max_leaf_nodes": [5, 10],
        },
        {
            "est": [RuleFitRegressor(random_state=42)],
            "est__max_rules": [10, 100],
            "est__n_estimators": [20],
        },
        {
            "est": [TreeGAMRegressor(random_state=42)],
            "est__n_boosting_rounds": [10, 100],
        },
        {
            "est": [FIGSRegressor(random_state=42)],
            "est__max_rules": [5, 10],
        },
    ] + PARAM_GRID_LINEAR_REGRESSION

Ancestors

  • sklearn.base.BaseEstimator
  • sklearn.utils._estimator_html_repr._HTMLDocumentationLinkMixin
  • sklearn.utils._metadata_requests._MetadataRequester

Subclasses

  • AutoInterpretableClassifier
  • AutoInterpretableRegressor

Class variables

var PARAM_GRID_DEFAULT_CLASSIFICATION
var PARAM_GRID_DEFAULT_REGRESSION
var PARAM_GRID_LINEAR_CLASSIFICATION
var PARAM_GRID_LINEAR_REGRESSION

Methods

def fit(self, X, y, cv=5)
Expand source code
def fit(self, X, y, cv=5):
    """Run the grid search and keep the best interpretable estimator.

    Parameters
    ----------
    X : array-like
        Training features (already preprocessed).
    y : array-like
        Training targets.
    cv : int, cross-validation generator or iterable, default=5
        Forwarded to GridSearchCV. For classifiers an integer ``cv`` yields
        stratified folds.

    Returns
    -------
    self

    Raises
    ------
    TypeError
        If the instance is neither a classifier nor a regressor.
    """
    if isinstance(self, ClassifierMixin):
        default_grid = self.PARAM_GRID_DEFAULT_CLASSIFICATION
        # The plain "roc_auc" scorer only supports binary targets.
        scoring = "roc_auc" if np.unique(y).size <= 2 else "roc_auc_ovr"
        is_classifier = True
    elif isinstance(self, RegressorMixin):
        default_grid = self.PARAM_GRID_DEFAULT_REGRESSION
        scoring = "r2"
        is_classifier = False
    else:
        raise TypeError(
            f"{type(self).__name__} must also inherit from ClassifierMixin "
            "or RegressorMixin; use AutoInterpretableClassifier or "
            "AutoInterpretableRegressor."
        )
    param_grid = default_grid if self.param_grid is None else self.param_grid

    # GridSearchCV inspects the placeholder pipeline below, which is not a
    # classifier, so it would fall back to unstratified KFold. Resolve the
    # splitter here so classification folds are stratified.
    cv = check_cv(cv, y, classifier=is_classifier)

    self.pipe_ = Pipeline([("est", BaseEstimator())])  # Placeholder Estimator
    self.est_ = GridSearchCV(
        self.pipe_, param_grid, scoring=scoring, cv=cv, refit=self.refit)
    self.est_.fit(X, y)
    return self
def predict(self, X)
Expand source code
def predict(self, X):
    """Return predictions from the estimator chosen by the grid search."""
    search = self.est_
    return search.predict(X)
def predict_proba(self, X)
Expand source code
def predict_proba(self, X):
    """Return class probabilities from the estimator chosen by the search."""
    search = self.est_
    return search.predict_proba(X)
def score(self, X, y)
Expand source code
def score(self, X, y):
    """Delegate scoring of ``X, y`` to the fitted grid search."""
    search = self.est_
    return search.score(X, y)
def set_fit_request(self: AutoInterpretableModel, *, cv: Union[bool, None, str] = '$UNCHANGED$') -> AutoInterpretableModel

Request metadata passed to the fit method.

Note that this method is only relevant if enable_metadata_routing=True (see `sklearn.set_config`). See the scikit-learn User Guide section on metadata routing for how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to fit.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in version: 1.3

Note

This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a `sklearn.pipeline.Pipeline`. Otherwise it has no effect.

Parameters

cv : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED
Metadata routing for cv parameter in fit.

Returns

self : object
The updated object.
Expand source code
def func(**kw):
    """Updates the request for provided parameters

    This docstring is overwritten below.
    See REQUESTER_DOC for expected functionality
    """
    # NOTE(review): this is scikit-learn's generated set_*_request closure.
    # ``self`` and ``instance`` come from the enclosing scope (not shown here);
    # ``self`` presumably is the request-method descriptor and ``instance``
    # the estimator -- confirm against sklearn.utils._metadata_requests.

    # Refuse to run unless metadata routing is globally enabled.
    if not _routing_enabled():
        raise RuntimeError(
            "This method is only available when metadata routing is enabled."
            " You can enable it using"
            " sklearn.set_config(enable_metadata_routing=True)."
        )

    # When key validation is on, reject keyword names outside ``self.keys``.
    if self.validate_keys and (set(kw) - set(self.keys)):
        raise TypeError(
            f"Unexpected args: {set(kw) - set(self.keys)}. Accepted arguments"
            f" are: {set(self.keys)}"
        )

    # Fetch the current routing requests and the entry for this method
    # (``self.name``; presumably "fit" for set_fit_request).
    requests = instance._get_metadata_request()
    method_metadata_request = getattr(requests, self.name)

    # Record each alias; the UNCHANGED sentinel leaves that parameter as-is.
    for prop, alias in kw.items():
        if alias is not UNCHANGED:
            method_metadata_request.add_request(param=prop, alias=alias)
    # Store the updated requests back on the instance.
    instance._metadata_request = requests

    # Return the instance so calls can be chained.
    return instance
class AutoInterpretableRegressor (param_grid=None, refit=True)

Automatically fit and select a regressor that is interpretable. Note that all preprocessing should be done beforehand. This is basically a wrapper around GridSearchCV, with some preselected models.

Expand source code
class AutoInterpretableRegressor(RegressorMixin, AutoInterpretableModel):
    """Interpretable regressor selected by cross-validated R^2.

    Searches ``PARAM_GRID_DEFAULT_REGRESSION`` unless ``param_grid`` is given.
    RegressorMixin is listed before the BaseEstimator-derived base, as
    scikit-learn requires, so the estimator is tagged as a regressor.
    """

    def score(self, X, y):
        """Score with the grid search's scorer rather than RegressorMixin's.

        With RegressorMixin first in the MRO its ``score`` would otherwise
        shadow the grid-search score of the base class.
        """
        return AutoInterpretableModel.score(self, X, y)

Ancestors

  • AutoInterpretableModel
  • sklearn.base.BaseEstimator
  • sklearn.utils._estimator_html_repr._HTMLDocumentationLinkMixin
  • sklearn.utils._metadata_requests._MetadataRequester
  • sklearn.base.RegressorMixin

Inherited members