Expand source code
from sklearn.base import RegressorMixin, BaseEstimator, is_regressor


class DistilledRegressor(BaseEstimator, RegressorMixin):
    """
    Class to implement distillation. Currently only supports regression.

    Params
    ------
    teacher: initial model to be trained
        must be a regressor or a binary classifier
    student: model to be distilled from teacher's predictions
        must be a regressor
    n_iters_teacher: number of times the teacher is fit. The first fit uses
        the original targets; each later fit uses the teacher's predictions
        from the previous round. The student is fit on the teacher's final
        predictions (or on the original targets when this is 0).

    Raises
    ------
    ValueError
        At construction, if the student is not a regressor.
    """

    def __init__(self, teacher: BaseEstimator, student: BaseEstimator,
                 n_iters_teacher: int = 1):
        self.teacher = teacher
        self.student = student
        self.n_iters_teacher = n_iters_teacher
        # Validate eagerly so a misconfigured student fails at construction.
        self._validate_student()
        self._check_teacher_type()

    def _validate_student(self):
        """Raise ValueError unless the student is a regressor.

        Accepts scikit-learn regressors, and imodels-style models exposing a
        ``prediction_task`` attribute other than ``"classification"``.
        """
        if is_regressor(self.student):
            return
        if not hasattr(self.student, "prediction_task"):
            raise ValueError("Student must be either a scikit-learn or imodels regressor")
        if self.student.prediction_task == "classification":
            raise ValueError("Student must be a regressor")

    def _check_teacher_type(self):
        """Set ``self.teacher_type`` from the current teacher.

        An imodels-style ``prediction_task`` attribute takes precedence.
        Otherwise scikit-learn's ``is_regressor`` decides between
        ``"regression"`` and ``"classification"``, so every teacher ends up
        with a type.
        """
        if hasattr(self.teacher, "prediction_task"):
            self.teacher_type = self.teacher.prediction_task
        elif is_regressor(self.teacher):
            self.teacher_type = "regression"
        else:
            self.teacher_type = "classification"

    def set_teacher_params(self, **params):
        """Forward ``params`` to the teacher's ``set_params``."""
        self.teacher.set_params(**params)

    def set_student_params(self, **params):
        """Forward ``params`` to the student's ``set_params``."""
        self.student.set_params(**params)

    def fit(self, X, y, **kwargs):
        """Fit the teacher (``n_iters_teacher`` times), then the student on its outputs.

        Extra keyword arguments are forwarded to every ``teacher.fit`` call,
        not to the student.

        Returns
        -------
        self
        """
        # Re-derive the type: ``set_params(teacher=...)`` may have replaced the
        # teacher after __init__, which would leave ``teacher_type`` stale.
        self._check_teacher_type()
        for _ in range(self.n_iters_teacher):
            # NOTE(review): for a classification teacher with n_iters_teacher > 1,
            # later rounds refit the classifier on probabilities; confirm intended.
            self.teacher.fit(X, y, **kwargs)
            if self.teacher_type == "regression":
                y = self.teacher.predict(X)
            else:
                # Positive-class probability; assumes a binary classifier.
                y = self.teacher.predict_proba(X)[:, 1]

        self.student.fit(X, y)
        # scikit-learn estimators return self from fit so calls can be chained.
        return self

    def predict(self, X):
        """Predict with the distilled student; the teacher is not used."""
        return self.student.predict(X)

Classes

class DistilledRegressor (teacher: sklearn.base.BaseEstimator, student: sklearn.base.BaseEstimator, n_iters_teacher: int = 1)

Class to implement distillation. Currently only supports regression.

Params


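- teacher: initial model to be trained; must be a regressor or a binary classifier.
- student: model to be distilled from the teacher's predictions; must be a regressor.
- n_iters_teacher: number of times the teacher is fit. Each fit after the first uses the teacher's own previous predictions as targets.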

Expand source code
class DistilledRegressor(BaseEstimator, RegressorMixin):
    """
    Class to implement distillation. Currently only supports regression.

    Params
    ------
    teacher: initial model to be trained
        must be a regressor or a binary classifier
    student: model to be distilled from teacher's predictions
        must be a regressor
    n_iters_teacher: number of times the teacher is fit. The first fit uses
        the original targets; each later fit uses the teacher's predictions
        from the previous round. The student is fit on the teacher's final
        predictions (or on the original targets when this is 0).

    Raises
    ------
    ValueError
        At construction, if the student is not a regressor.
    """

    def __init__(self, teacher: BaseEstimator, student: BaseEstimator,
                 n_iters_teacher: int = 1):
        self.teacher = teacher
        self.student = student
        self.n_iters_teacher = n_iters_teacher
        # Validate eagerly so a misconfigured student fails at construction.
        self._validate_student()
        self._check_teacher_type()

    def _validate_student(self):
        """Raise ValueError unless the student is a regressor.

        Accepts scikit-learn regressors, and imodels-style models exposing a
        ``prediction_task`` attribute other than ``"classification"``.
        """
        if is_regressor(self.student):
            return
        if not hasattr(self.student, "prediction_task"):
            raise ValueError("Student must be either a scikit-learn or imodels regressor")
        if self.student.prediction_task == "classification":
            raise ValueError("Student must be a regressor")

    def _check_teacher_type(self):
        """Set ``self.teacher_type`` from the current teacher.

        An imodels-style ``prediction_task`` attribute takes precedence.
        Otherwise scikit-learn's ``is_regressor`` decides between
        ``"regression"`` and ``"classification"``, so every teacher ends up
        with a type.
        """
        if hasattr(self.teacher, "prediction_task"):
            self.teacher_type = self.teacher.prediction_task
        elif is_regressor(self.teacher):
            self.teacher_type = "regression"
        else:
            self.teacher_type = "classification"

    def set_teacher_params(self, **params):
        """Forward ``params`` to the teacher's ``set_params``."""
        self.teacher.set_params(**params)

    def set_student_params(self, **params):
        """Forward ``params`` to the student's ``set_params``."""
        self.student.set_params(**params)

    def fit(self, X, y, **kwargs):
        """Fit the teacher (``n_iters_teacher`` times), then the student on its outputs.

        Extra keyword arguments are forwarded to every ``teacher.fit`` call,
        not to the student.

        Returns
        -------
        self
        """
        # Re-derive the type: ``set_params(teacher=...)`` may have replaced the
        # teacher after __init__, which would leave ``teacher_type`` stale.
        self._check_teacher_type()
        for _ in range(self.n_iters_teacher):
            # NOTE(review): for a classification teacher with n_iters_teacher > 1,
            # later rounds refit the classifier on probabilities; confirm intended.
            self.teacher.fit(X, y, **kwargs)
            if self.teacher_type == "regression":
                y = self.teacher.predict(X)
            else:
                # Positive-class probability; assumes a binary classifier.
                y = self.teacher.predict_proba(X)[:, 1]

        self.student.fit(X, y)
        # scikit-learn estimators return self from fit so calls can be chained.
        return self

    def predict(self, X):
        """Predict with the distilled student; the teacher is not used."""
        return self.student.predict(X)

Ancestors

  • sklearn.base.BaseEstimator
  • sklearn.utils._estimator_html_repr._HTMLDocumentationLinkMixin
  • sklearn.utils._metadata_requests._MetadataRequester
  • sklearn.base.RegressorMixin

Methods

def fit(self, X, y, **kwargs)
Expand source code
def fit(self, X, y, **kwargs):
    """Fit the teacher (``n_iters_teacher`` times), then the student on its outputs.

    Extra keyword arguments are forwarded to every ``teacher.fit`` call,
    not to the student. With ``n_iters_teacher == 0`` the student is fit
    directly on the original targets.

    Returns
    -------
    self
    """
    # Re-derive the type: ``set_params(teacher=...)`` may have replaced the
    # teacher after __init__, which would leave ``teacher_type`` stale.
    self._check_teacher_type()
    for _ in range(self.n_iters_teacher):
        # NOTE(review): for a classification teacher with n_iters_teacher > 1,
        # later rounds refit the classifier on probabilities; confirm intended.
        self.teacher.fit(X, y, **kwargs)
        if self.teacher_type == "regression":
            y = self.teacher.predict(X)
        else:
            # Positive-class probability; assumes a binary classifier.
            y = self.teacher.predict_proba(X)[:, 1]

    self.student.fit(X, y)
    # scikit-learn estimators return self from fit so calls can be chained.
    return self
def predict(self, X)
Expand source code
def predict(self, X):
    """Return the distilled student's predictions for ``X``; the teacher is not consulted."""
    student_model = self.student
    predictions = student_model.predict(X)
    return predictions
def set_score_request(self: DistilledRegressor, *, sample_weight: Union[bool, None, str] = '$UNCHANGED$') -> DistilledRegressor

Request metadata passed to the score method.

Note that this method is only relevant if `enable_metadata_routing=True` (see `sklearn.set_config`). See the scikit-learn User Guide section on metadata routing for how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to score if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to score.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in version: 1.3

Note

This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. inside a `sklearn.pipeline.Pipeline`. Otherwise it has no effect.

Parameters

sample_weight : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED
Metadata routing for sample_weight parameter in score.

Returns

self : object
The updated object.
Expand source code
def func(*args, **kw):
    """Updates the request for provided parameters

    This docstring is overwritten below.
    See REQUESTER_DOC for expected functionality

    Each keyword in ``kw`` names a parameter of the method ``self.name``.
    Its value is recorded as the alias via ``add_request``, unless the value
    is the ``UNCHANGED`` sentinel, in which case it is skipped. Returns the
    instance whose metadata request was updated.

    Raises RuntimeError when metadata routing is not enabled. Raises
    TypeError for unexpected keyword names (only when ``self.validate_keys``
    is truthy) and for any positional arguments beyond the unbound instance.

    NOTE(review): ``self``, ``instance``, ``_routing_enabled`` and
    ``UNCHANGED`` are free names in this snippet. They are presumably closure
    variables of the enclosing scikit-learn descriptor, so this code cannot
    run standalone.
    """
    # Refuse to run unless metadata routing is globally enabled.
    if not _routing_enabled():
        raise RuntimeError(
            "This method is only available when metadata routing is enabled."
            " You can enable it using"
            " sklearn.set_config(enable_metadata_routing=True)."
        )

    # Optionally reject keyword names outside the accepted set ``self.keys``.
    if self.validate_keys and (set(kw) - set(self.keys)):
        raise TypeError(
            f"Unexpected args: {set(kw) - set(self.keys)} in {self.name}. "
            f"Accepted arguments are: {set(self.keys)}"
        )

    # This makes it possible to use the decorated method as an unbound method,
    # for instance when monkeypatching.
    # https://github.com/scikit-learn/scikit-learn/issues/28632
    if instance is None:
        _instance = args[0]
        args = args[1:]
    else:
        _instance = instance

    # Replicating python's behavior when positional args are given other than
    # `self`, and `self` is only allowed if this method is unbound.
    if args:
        raise TypeError(
            f"set_{self.name}_request() takes 0 positional argument but"
            f" {len(args)} were given"
        )

    # Fetch the instance's request object and select the entry for this method.
    requests = _instance._get_metadata_request()
    method_metadata_request = getattr(requests, self.name)

    # Record each alias, leaving parameters passed as UNCHANGED untouched.
    for prop, alias in kw.items():
        if alias is not UNCHANGED:
            method_metadata_request.add_request(param=prop, alias=alias)
    # Store the updated requests back on the instance.
    _instance._metadata_request = requests

    return _instance
def set_student_params(self, **params)
Expand source code
def set_student_params(self, **params):
    """Pass every keyword argument through to the student's ``set_params``; returns None."""
    target = self.student
    target.set_params(**params)
def set_teacher_params(self, **params)
Expand source code
def set_teacher_params(self, **params):
    """Pass every keyword argument through to the teacher's ``set_params``; returns None."""
    target = self.teacher
    target.set_params(**params)