Expand source code
import json
import warnings

import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator
from sklearn.utils import validation

from imodels import GreedyTreeClassifier
from imodels.tree.gosdt.pygosdt_helper import TreeClassifier
from imodels.util import rule

try:
    import gosdt

    gosdt_supported = True
except ImportError:
    gosdt_supported = False


# When the gosdt extension is unavailable, inherit from GreedyTreeClassifier so that
# fit() can fall back to a (non-optimal) greedy decision tree.
class OptimalTreeClassifier(GreedyTreeClassifier if not gosdt_supported else BaseEstimator):
    def __init__(self,
                 balance=False,
                 cancellation=True,
                 look_ahead=True,
                 similar_support=True,
                 feature_exchange=True,
                 continuous_feature_exchange=True,
                 rule_list=False,
                 diagnostics=False,
                 verbose=False,
                 regularization=0.05,
                 uncertainty_tolerance=0.0,
                 upperbound=0.0,
                 model_limit=1,
                 precision_limit=0,
                 stack_limit=0,
                 tile_limit=0,
                 time_limit=0,
                 worker_limit=1,
                 random_state=None,
                 costs="",
                 model="",
                 profile="",
                 timing="",
                 trace="",
                 tree=""):
        super().__init__()
        self.balance = balance
        self.cancellation = cancellation
        self.look_ahead = look_ahead
        self.similar_support = similar_support
        self.feature_exchange = feature_exchange
        self.continuous_feature_exchange = continuous_feature_exchange
        self.rule_list = rule_list
        self.diagnostics = diagnostics
        self.verbose = verbose
        self.regularization = regularization
        self.uncertainty_tolerance = uncertainty_tolerance
        self.upperbound = upperbound
        self.model_limit = model_limit
        self.precision_limit = precision_limit
        self.stack_limit = stack_limit
        self.tile_limit = tile_limit
        self.time_limit = time_limit
        self.worker_limit = worker_limit
        self.costs = costs
        self.model = model
        self.profile = profile
        self.timing = timing
        self.trace = trace
        self.tree = tree
        self.tree_type = 'gosdt'
        self.random_state = random_state
        if random_state is not None:
            np.random.seed(random_state)

    def load(self, path):
        """
        Parameters
        ---
        path : string
            path to a JSON file representing a model
        """
        with open(path, 'r') as model_source:
            result = model_source.read()
        result = json.loads(result)
        self.tree_ = TreeClassifier(result[0])

    def fit(self, X, y, feature_names=None):
        """
        Parameters
        ---
        X : matrix-like, shape = [n_samples, m_features]
            matrix containing the training samples and features
        y : array-like, shape = [n_samples, 1]
            column containing the correct label for each sample in X

        Modifies
        ---
        trains the model so that this model instance is ready for prediction
        """
        try:
            import gosdt

            if not isinstance(X, pd.DataFrame):
                self.feature_names_ = list(rule.get_feature_dict(X.shape[1], feature_names).keys())
                X = pd.DataFrame(X, columns=self.feature_names_)
            else:
                self.feature_names_ = X.columns

            if not isinstance(y, pd.DataFrame):
                y = pd.DataFrame(y, columns=['target'])

            # gosdt extension expects serialized CSV, which we generate via pandas
            dataset_with_target = pd.concat((X, y), axis=1)

            # Perform C++ extension calls to train the model
            configuration = self._get_configuration()
            gosdt.configure(json.dumps(configuration, separators=(',', ':')))
            result = gosdt.fit(dataset_with_target.to_csv(index=False))

            result = json.loads(result)
            self.tree_ = TreeClassifier(result[0])

            # Record the training time, number of iterations, and graph size required
            self.time_ = gosdt.time()
            self.iterations_ = gosdt.iterations()
            self.size_ = gosdt.size()
        except ImportError:
            warnings.warn(
                "Should install gosdt extension. On x86_64 linux or macOS: "
                "'pip install gosdt-deprecated'. On other platforms, see "
                "https://github.com/keyan3/GeneralizedOptimalSparseDecisionTrees. "
                "Defaulting to Non-optimal DecisionTreeClassifier."
            )

            # dtree = DecisionTreeClassifierWithComplexity()
            # dtree.fit(X, y)
            # self.tree_ = dtree
            super().fit(X, y, feature_names=feature_names)
            self.tree_type = 'dt'

        return self

    def predict(self, X):
        """
        Parameters
        ---
        X : matrix-like, shape = [n_samples, m_features]
            a matrix where each row is a sample to be predicted and each column is a feature to
            be used for prediction

        Returns
        ---
        array-like, shape = [n_samples, 1] : a column where each element is the prediction
            associated with each row
        """
        validation.check_is_fitted(self)
        if self.tree_type == 'gosdt':
            if type(self.tree_) is TreeClassifier and not isinstance(X, pd.DataFrame):
                X = pd.DataFrame(X, columns=self.feature_names_)
            return self.tree_.predict(X)
        else:
            return super().predict(X)

    def predict_proba(self, X):
        validation.check_is_fitted(self)
        if self.tree_type == 'gosdt':
            if type(self.tree_) is TreeClassifier and not isinstance(X, pd.DataFrame):
                X = pd.DataFrame(X, columns=self.feature_names_)
            probs = np.expand_dims(self.tree_.confidence(X), axis=1)
            return np.hstack((1 - probs, probs))
        else:
            return super().predict_proba(X)

    def score(self, X, y, weight=None):
        """
        Parameters
        ---
        X : matrix-like, shape = [n_samples, m_features]
            an n-by-m matrix of samples and their features
        y : array-like, shape = [n_samples,]
            an n-by-1 column of labels associated with each sample
        weight : shape = [n_samples,]
            an n-by-1 column of weights to apply to each sample's misclassification

        Returns
        ---
        real number : the accuracy produced by applying this model over the given dataset,
            optionally weighted by the given sample weights
        """
        validation.check_is_fitted(self)
        if type(self.tree_) is TreeClassifier:
            if not isinstance(X, pd.DataFrame):
                X = pd.DataFrame(X, columns=self.feature_names_)
            return self.tree_.score(X, y, weight=weight)
        else:
            return self.tree_.score(X, y, sample_weight=weight)

    def __len__(self):
        """
        Returns
        ---
        natural number : The number of terminal nodes present in this tree
        """
        validation.check_is_fitted(self)
        if type(self.tree_) is TreeClassifier:
            return len(self.tree_)
        else:
            warnings.warn("Using DecisionTreeClassifier due to absence of gosdt package. "
                          "DecisionTreeClassifier does not have this method.")
            return None

    def leaves(self):
        """
        Returns
        ---
        natural number : The number of terminal nodes present in this tree
        """
        validation.check_is_fitted(self)
        if type(self.tree_) is TreeClassifier:
            return self.tree_.leaves()
        else:
            return self.tree_.get_n_leaves()

    def nodes(self):
        """
        Returns
        ---
        natural number : The number of nodes present in this tree
        """
        validation.check_is_fitted(self)
        if type(self.tree_) is TreeClassifier:
            return self.tree_.nodes()
        else:
            warnings.warn("Using DecisionTreeClassifier due to absence of gosdt package. "
                          "DecisionTreeClassifier does not have this method.")
            return None

    def max_depth(self):
        """
        Returns
        ---
        natural number : the length of the longest decision path in this tree. A single-node tree
            will return 1.
        """
        validation.check_is_fitted(self)
        if type(self.tree_) is TreeClassifier:
            return self.tree_.maximum_depth()
        else:
            return self.tree_.get_depth()

    def latex(self):
        """
        Note
        ---
        This method doesn't work well for label headers that contain underscores due to underscore
        being a reserved character in LaTeX

        Returns
        ---
        string : A LaTeX string representing the model
        """
        validation.check_is_fitted(self)
        if type(self.tree_) is TreeClassifier:
            return self.tree_.latex()
        else:
            warnings.warn("Using DecisionTreeClassifier due to absence of gosdt package. "
                          "DecisionTreeClassifier does not have this method.")
            return None

    def json(self):
        """
        Returns
        ---
        string : A JSON string representing the model
        """
        validation.check_is_fitted(self)
        if type(self.tree_) is TreeClassifier:
            return self.tree_.json()
        else:
            warnings.warn("Using DecisionTreeClassifier due to absence of gosdt package. "
                          "DecisionTreeClassifier does not have this method.")
            return None

    def _get_configuration(self):
        # Serialized to JSON and passed to gosdt.configure() before fitting
        return {
            "balance": self.balance,
            "cancellation": self.cancellation,
            "look_ahead": self.look_ahead,
            "similar_support": self.similar_support,
            "feature_exchange": self.feature_exchange,
            "continuous_feature_exchange": self.continuous_feature_exchange,
            "rule_list": self.rule_list,
            "diagnostics": self.diagnostics,
            "verbose": self.verbose,
            "regularization": self.regularization,
            "uncertainty_tolerance": self.uncertainty_tolerance,
            "upperbound": self.upperbound,
            "model_limit": self.model_limit,
            "precision_limit": self.precision_limit,
            "stack_limit": self.stack_limit,
            "tile_limit": self.tile_limit,
            "time_limit": self.time_limit,
            "worker_limit": self.worker_limit,
            "costs": self.costs,
            "model": self.model,
            "profile": self.profile,
            "timing": self.timing,
            "trace": self.trace,
            "tree": self.tree
        }
Classes
class OptimalTreeClassifier (balance=False, cancellation=True, look_ahead=True, similar_support=True, feature_exchange=True, continuous_feature_exchange=True, rule_list=False, diagnostics=False, verbose=False, regularization=0.05, uncertainty_tolerance=0.0, upperbound=0.0, model_limit=1, precision_limit=0, stack_limit=0, tile_limit=0, time_limit=0, worker_limit=1, random_state=None, costs='', model='', profile='', timing='', trace='', tree='')
-
Optimal sparse decision tree classifier. Wraps the GOSDT (Generalized Optimal Sparse Decision Trees) C++ extension when it is installed and falls back to the greedy, sklearn-based GreedyTreeClassifier otherwise.
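A minimal usage sketch (not part of the generated reference; it assumes the top-level imodels import and a small binarized dataset — when the gosdt extension is missing, fit transparently falls back to the greedy tree):
import numpy as np
from imodels import OptimalTreeClassifier

# Tiny binary feature matrix; GOSDT expects binarized features
X = np.array([[1, 0], [1, 1], [0, 1], [0, 0]])
y = np.array([1, 1, 0, 0])

clf = OptimalTreeClassifier(regularization=0.05, time_limit=60)
clf.fit(X, y, feature_names=["feat_a", "feat_b"])
print(clf.predict(X))        # predicted labels
print(clf.predict_proba(X))  # per-class probabilities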
Ancestors
- GreedyTreeClassifier
- sklearn.tree._classes.DecisionTreeClassifier
- sklearn.base.ClassifierMixin
- sklearn.tree._classes.BaseDecisionTree
- sklearn.base.MultiOutputMixin
- sklearn.base.BaseEstimator
- sklearn.utils._estimator_html_repr._HTMLDocumentationLinkMixin
- sklearn.utils._metadata_requests._MetadataRequester
Methods
def fit(self, X, y, feature_names=None)
-
Parameters
X : matrix-like, shape = [n_samples, m_features]
    matrix containing the training samples and features
y : array-like, shape = [n_samples, 1]
    column containing the correct label for each sample in X
Modifies
    trains the model so that this model instance is ready for prediction
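A sketch of inspecting the solver statistics that fit records when the gosdt extension handled the training (attribute names come from the source above; X_train and y_train are placeholders for binarized training data):
clf = OptimalTreeClassifier(regularization=0.05).fit(X_train, y_train)
if clf.tree_type == 'gosdt':
    # only populated on the gosdt path
    print(clf.time_, clf.iterations_, clf.size_)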
def json(self)
-
Returns
string : A JSON string representing the model
def latex(self)
-
Note
This method doesn't work well for label headers that contain underscores due to underscore being a reserved character in LaTeX
Returns
string : A LaTeX string representing the model
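A sketch of exporting a fitted model with json() and latex(); both return strings for a GOSDT tree and, as shown above, warn and return None in the greedy fallback (clf stands for a fitted OptimalTreeClassifier):
model_json = clf.json()    # JSON string, or None in the fallback case
model_latex = clf.latex()  # LaTeX string, or None in the fallback case
if model_json is not None:
    with open("exported_model.json", "w") as f:
        f.write(model_json)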
def leaves(self)
-
Returns
natural number : The number of terminal nodes present in this tree
def load(self, path)
-
Parameters
path : string
    path to a JSON file representing a model
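A sketch of restoring a model from disk, assuming "model.json" is a gosdt output file (a JSON list whose first element describes a tree); after load, pass a pandas DataFrame whose columns match the training features, since feature_names_ is not set by this method:
clf = OptimalTreeClassifier()
clf.load("model.json")      # hypothetical path to a gosdt model file
preds = clf.predict(X_new)  # X_new: a DataFrame with the original feature columns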
def max_depth(self)
-
Returns
natural number : the length of the longest decision path in this tree. A single-node tree will return 1.
def nodes(self)
-
Returns
natural number : The number of nodes present in this tree
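The size helpers can be combined to inspect model complexity; a sketch (clf denotes a fitted OptimalTreeClassifier; nodes() and len() are GOSDT-only, as noted above):
print("leaves:", clf.leaves())     # available for both GOSDT and fallback trees
print("depth:", clf.max_depth())   # likewise
if clf.tree_type == 'gosdt':
    print("nodes:", clf.nodes())   # fallback would warn and return None
    print("terminal nodes:", len(clf))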
def predict(self, X)
-
Parameters
X : matrix-like, shape = [n_samples, m_features]
    a matrix where each row is a sample to be predicted and each column is a feature to be used for prediction
Returns
array-like, shape = [n_samples, 1] : a column where each element is the prediction associated with each row
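On the GOSDT path a plain array is wrapped into a DataFrame using the feature names recorded at fit time, so the number of columns must match; a sketch (clf is a fitted classifier as in the class-level example):
import pandas as pd
X_new = pd.DataFrame([[1, 0], [0, 1]], columns=list(clf.feature_names_))
print(clf.predict(X_new))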
def predict_proba(self, X)
-
Predict class probabilities of the input samples X.
The predicted class probability is the fraction of samples of the same class in a leaf.
Parameters
X : {array-like, sparse matrix} of shape (n_samples, n_features)
    The input samples. Internally, it will be converted to dtype=np.float32 and if a sparse matrix is provided to a sparse csr_matrix.
check_input : bool, default=True
    Allow to bypass several input checking. Don't use this parameter unless you know what you're doing.
Returns
proba : ndarray of shape (n_samples, n_classes) or list of n_outputs such arrays if n_outputs > 1
    The class probabilities of the input samples. The order of the classes corresponds to that in the attribute classes_.
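For a GOSDT tree the override builds a two-column matrix from the leaf confidence returned by TreeClassifier.confidence: column 1 holds the confidence and column 0 its complement. A sketch (clf and X_new as in the predict example):
proba = clf.predict_proba(X_new)   # shape (n_samples, 2)
print(proba[:, 1])                 # leaf confidence per sample
print(proba.sum(axis=1))           # each row sums to 1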
def score(self, X, y, weight=None)
-
Parameters
X : matrix-like, shape = [n_samples, m_features]
    an n-by-m matrix of samples and their features
y : array-like, shape = [n_samples,]
    an n-by-1 column of labels associated with each sample
weight : shape = [n_samples,]
    an n-by-1 column of weights to apply to each sample's misclassification
Returns
real number : the accuracy produced by applying this model over the given dataset, optionally weighted by the given sample weights
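A sketch of weighted scoring (X_test, y_test and w are placeholders; the GOSDT tree forwards the weights as weight=, the fallback tree as sample_weight=):
import numpy as np
w = np.ones(len(y_test))
print(clf.score(X_test, y_test))             # plain accuracy
print(clf.score(X_test, y_test, weight=w))   # weighted accuracy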
def set_fit_request(self: OptimalTreeClassifier, *, feature_names: Union[bool, None, str] = '$UNCHANGED$') -> OptimalTreeClassifier
-
Request metadata passed to the fit method.
Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config). Please see the User Guide on metadata routing for how the routing mechanism works.
The options for each parameter are:
- True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.
- False: metadata is not requested and the meta-estimator will not pass it to fit.
- None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
- str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
Note
This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a sklearn.pipeline.Pipeline. Otherwise it has no effect.
Parameters
feature_names : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED
    Metadata routing for the feature_names parameter in fit.
Returns
self : object
    The updated object.
-
def set_score_request(self: OptimalTreeClassifier, *, weight: Union[bool, None, str] = '$UNCHANGED$') -> OptimalTreeClassifier
-
Request metadata passed to the score method.
Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config). Please see the User Guide on metadata routing for how the routing mechanism works.
The options for each parameter are:
- True: metadata is requested, and passed to score if provided. The request is ignored if metadata is not provided.
- False: metadata is not requested and the meta-estimator will not pass it to score.
- None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
- str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
Note
This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a sklearn.pipeline.Pipeline. Otherwise it has no effect.
Parameters
weight : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED
    Metadata routing for the weight parameter in score.
Returns
self : object
    The updated object.
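A hedged sketch of turning these requests on (requires scikit-learn >= 1.3; without sklearn.set_config(enable_metadata_routing=True) the calls raise a RuntimeError):
import sklearn
sklearn.set_config(enable_metadata_routing=True)

# Ask meta-estimators (e.g. a Pipeline or a cross-validation helper) to route
# the 'feature_names' metadata to fit and the 'weight' metadata to score.
clf = OptimalTreeClassifier().set_fit_request(feature_names=True).set_score_request(weight=True)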
-
Inherited members