Expand source code
from sklearn.base import RegressorMixin, BaseEstimator, is_regressor
class DistilledRegressor(BaseEstimator, RegressorMixin):
    """
    Class to implement distillation. Currently only supports regression.

    The teacher is fit on the training data, its outputs on that same data
    become the regression targets of the student, and every prediction made
    by this estimator comes from the student.

    Params
    ------
    teacher: initial model to be trained
        must be a regressor or a binary classifier
    student: model to be distilled from teacher's predictions
        must be a regressor
    n_iters_teacher: number of teacher fitting passes run by ``fit``
        after each pass the targets are replaced by the teacher's own
        outputs, so later passes refit the teacher on its previous outputs
    """

    def __init__(self, teacher: BaseEstimator, student: BaseEstimator,
                 n_iters_teacher: int = 1):
        self.teacher = teacher
        self.student = student
        self.n_iters_teacher = n_iters_teacher
        self._validate_student()
        self._check_teacher_type()

    def _validate_student(self):
        """Raise ValueError unless the student can be used as a regressor."""
        if is_regressor(self.student):
            return
        # Not recognised by scikit-learn as a regressor: fall back to the
        # ``prediction_task`` attribute (presumably the imodels convention,
        # going by the error message below).
        if not hasattr(self.student, "prediction_task"):
            raise ValueError("Student must be either a scikit-learn or imodels regressor")
        if self.student.prediction_task == "classification":
            raise ValueError("Student must be a regressor")

    def _check_teacher_type(self):
        """Record in ``self.teacher_type`` how teacher outputs are obtained.

        ``fit`` uses ``predict`` when the value is "regression" and
        ``predict_proba`` for anything else.
        """
        teacher = self.teacher
        if hasattr(teacher, "prediction_task"):
            self.teacher_type = teacher.prediction_task
        elif hasattr(teacher, "_estimator_type"):
            self.teacher_type = "regression" if is_regressor(teacher) else "classification"
        else:
            # Neither marker is present. Without this branch ``teacher_type``
            # was never assigned and ``fit`` failed with an AttributeError;
            # infer the type from the methods the teacher actually offers.
            self.teacher_type = (
                "classification" if hasattr(teacher, "predict_proba") else "regression"
            )

    def set_teacher_params(self, **params):
        """Forward ``params`` to ``teacher.set_params``."""
        self.teacher.set_params(**params)

    def set_student_params(self, **params):
        """Forward ``params`` to ``student.set_params``."""
        self.student.set_params(**params)

    def fit(self, X, y, **kwargs):
        """Fit the teacher, then fit the student on the teacher's outputs.

        Params
        ------
        X: training inputs, passed unchanged to teacher and student
        y: training targets for the first teacher pass
        kwargs: extra keyword arguments forwarded to ``teacher.fit`` only

        Returns
        -------
        self, so calls can be chained
        """
        # fit teacher
        for _ in range(self.n_iters_teacher):
            self.teacher.fit(X, y, **kwargs)
            if self.teacher_type == "regression":
                y = self.teacher.predict(X)
            else:
                y = self.teacher.predict_proba(X)[:, 1]  # assumes binary classifier

        # fit student
        self.student.fit(X, y)
        return self

    def predict(self, X):
        """Return the student's predictions for ``X``."""
        return self.student.predict(X)
Classes
class DistilledRegressor (teacher: sklearn.base.BaseEstimator, student: sklearn.base.BaseEstimator, n_iters_teacher: int = 1)
-
Class to implement distillation. Currently only supports regression.
Params
teacher: initial model to be trained; must be a regressor or a binary classifier.
student: model to be distilled from the teacher's predictions; must be a regressor.
Expand source code
class DistilledRegressor(BaseEstimator, RegressorMixin):
    """
    Class to implement distillation. Currently only supports regression.

    The teacher is fit on the training data, its outputs on that same data
    become the regression targets of the student, and every prediction made
    by this estimator comes from the student.

    Params
    ------
    teacher: initial model to be trained
        must be a regressor or a binary classifier
    student: model to be distilled from teacher's predictions
        must be a regressor
    n_iters_teacher: number of teacher fitting passes run by ``fit``
        after each pass the targets are replaced by the teacher's own
        outputs, so later passes refit the teacher on its previous outputs
    """

    def __init__(self, teacher: BaseEstimator, student: BaseEstimator,
                 n_iters_teacher: int = 1):
        self.teacher = teacher
        self.student = student
        self.n_iters_teacher = n_iters_teacher
        self._validate_student()
        self._check_teacher_type()

    def _validate_student(self):
        """Raise ValueError unless the student can be used as a regressor."""
        if is_regressor(self.student):
            return
        # Not recognised by scikit-learn as a regressor: fall back to the
        # ``prediction_task`` attribute (presumably the imodels convention,
        # going by the error message below).
        if not hasattr(self.student, "prediction_task"):
            raise ValueError("Student must be either a scikit-learn or imodels regressor")
        if self.student.prediction_task == "classification":
            raise ValueError("Student must be a regressor")

    def _check_teacher_type(self):
        """Record in ``self.teacher_type`` how teacher outputs are obtained.

        ``fit`` uses ``predict`` when the value is "regression" and
        ``predict_proba`` for anything else.
        """
        teacher = self.teacher
        if hasattr(teacher, "prediction_task"):
            self.teacher_type = teacher.prediction_task
        elif hasattr(teacher, "_estimator_type"):
            self.teacher_type = "regression" if is_regressor(teacher) else "classification"
        else:
            # Neither marker is present. Without this branch ``teacher_type``
            # was never assigned and ``fit`` failed with an AttributeError;
            # infer the type from the methods the teacher actually offers.
            self.teacher_type = (
                "classification" if hasattr(teacher, "predict_proba") else "regression"
            )

    def set_teacher_params(self, **params):
        """Forward ``params`` to ``teacher.set_params``."""
        self.teacher.set_params(**params)

    def set_student_params(self, **params):
        """Forward ``params`` to ``student.set_params``."""
        self.student.set_params(**params)

    def fit(self, X, y, **kwargs):
        """Fit the teacher, then fit the student on the teacher's outputs.

        Params
        ------
        X: training inputs, passed unchanged to teacher and student
        y: training targets for the first teacher pass
        kwargs: extra keyword arguments forwarded to ``teacher.fit`` only

        Returns
        -------
        self, so calls can be chained
        """
        # fit teacher
        for _ in range(self.n_iters_teacher):
            self.teacher.fit(X, y, **kwargs)
            if self.teacher_type == "regression":
                y = self.teacher.predict(X)
            else:
                y = self.teacher.predict_proba(X)[:, 1]  # assumes binary classifier

        # fit student
        self.student.fit(X, y)
        return self

    def predict(self, X):
        """Return the student's predictions for ``X``."""
        return self.student.predict(X)
Ancestors
- sklearn.base.BaseEstimator
- sklearn.utils._estimator_html_repr._HTMLDocumentationLinkMixin
- sklearn.utils._metadata_requests._MetadataRequester
- sklearn.base.RegressorMixin
Methods
def fit(self, X, y, **kwargs)
-
Expand source code
def fit(self, X, y, **kwargs):
    """Fit the teacher, then fit the student on the teacher's outputs.

    Params
    ------
    X: training inputs, passed unchanged to teacher and student
    y: training targets for the first teacher pass
    kwargs: extra keyword arguments forwarded to ``teacher.fit`` only

    Returns
    -------
    self, so calls can be chained (previously nothing was returned)
    """
    # fit teacher; after each pass the targets become the teacher's own
    # outputs, so any further pass refits the teacher on those outputs
    for _ in range(self.n_iters_teacher):
        self.teacher.fit(X, y, **kwargs)
        if self.teacher_type == "regression":
            y = self.teacher.predict(X)
        else:
            y = self.teacher.predict_proba(X)[:, 1]  # assumes binary classifier

    # fit student
    self.student.fit(X, y)
    return self
def predict(self, X)
-
Expand source code
def predict(self, X):
    """Return the predictions of the distilled student model for ``X``."""
    distilled = self.student
    return distilled.predict(X)
def set_score_request(self: DistilledRegressor, *, sample_weight: Union[bool, ForwardRef(None), str] = '$UNCHANGED$') ‑> DistilledRegressor
-
Request metadata passed to the `score` method.
Note that this method is only relevant if `enable_metadata_routing=True` (see :func:`sklearn.set_config`). Please see :ref:`User Guide <metadata_routing>` on how the routing mechanism works.
The options for each parameter are:
- `True`: metadata is requested, and passed to `score` if provided. The request is ignored if metadata is not provided.
- `False`: metadata is not requested and the meta-estimator will not pass it to `score`.
- `None`: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
- `str`: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (`sklearn.utils.metadata_routing.UNCHANGED`) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version: 1.3
Note
This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a :class:`~sklearn.pipeline.Pipeline`. Otherwise it has no effect.
Parameters
sample_weight : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED
- Metadata routing for `sample_weight` parameter in `score`.
Returns
self
:object
- The updated object.
Expand source code
def func(*args, **kw):
    """Updates the request for provided parameters

    This docstring is overwritten below.
    See REQUESTER_DOC for expected functionality
    """
    if not _routing_enabled():
        raise RuntimeError(
            "This method is only available when metadata routing is enabled."
            " You can enable it using"
            " sklearn.set_config(enable_metadata_routing=True)."
        )

    # Reject keyword names outside the accepted set, but only when key
    # validation is switched on for this request method.
    if self.validate_keys:
        if set(kw) - set(self.keys):
            raise TypeError(
                f"Unexpected args: {set(kw) - set(self.keys)} in {self.name}. "
                f"Accepted arguments are: {set(self.keys)}"
            )

    # This makes it possible to use the decorated method as an unbound method,
    # for instance when monkeypatching.
    # https://github.com/scikit-learn/scikit-learn/issues/28632
    if instance is not None:
        target = instance
    else:
        target, args = args[0], args[1:]

    # Replicating python's behavior when positional args are given other than
    # `self`, and `self` is only allowed if this method is unbound.
    if args:
        raise TypeError(
            f"set_{self.name}_request() takes 0 positional argument but"
            f" {len(args)} were given"
        )

    routing = target._get_metadata_request()
    method_request = getattr(routing, self.name)

    for prop, alias in kw.items():
        if alias is UNCHANGED:
            # UNCHANGED keeps whatever request is already recorded.
            continue
        method_request.add_request(param=prop, alias=alias)

    target._metadata_request = routing
    return target
-
def set_student_params(self, **params)
-
Expand source code
def set_student_params(self, **params):
    """Pass the given keyword parameters on to the student's ``set_params``."""
    student_model = self.student
    student_model.set_params(**params)
def set_teacher_params(self, **params)
-
Expand source code
def set_teacher_params(self, **params):
    """Pass the given keyword parameters on to the teacher's ``set_params``."""
    teacher_model = self.teacher
    teacher_model.set_params(**params)