Expand source code
from inspect import isclass


def check_is_fitted(estimator, attributes=None, *, msg=None, all_or_any=all):
    """Report whether ``estimator`` looks fitted.

    This function does not raise a ``NotFittedError``; it returns the
    outcome of the check and leaves the reaction to the caller.

    The first applicable rule decides the result:

    1. If ``attributes`` is given, each name is probed with ``hasattr`` and
       the results are combined with ``all_or_any``.
    2. Otherwise, if the estimator has a ``__sklearn_is_fitted__`` attribute,
       it is called and its return value is returned as-is.
    3. Otherwise, the estimator counts as fitted when ``vars(estimator)``
       contains at least one name that ends with an underscore and does not
       start with a double underscore.

    Parameters
    ----------
    estimator : estimator instance
        Instance to inspect. It must expose a ``fit`` attribute.
    attributes : str, list or tuple of str, default=None
        Attribute name(s) to look for, e.g. ``"coef_"`` or
        ``["coef_", "estimator_"]``. A value that is not a list or tuple is
        treated as a single attribute name.
    msg : str, default=None
        Accepted but not used by this implementation.
        NOTE(review): presumably kept for signature compatibility with
        scikit-learn's ``check_is_fitted`` -- confirm with callers.
    all_or_any : callable, {all, any}, default=all
        Reducer applied to the list of per-attribute ``hasattr`` results;
        only consulted when ``attributes`` is given.

    Returns
    -------
    fitted : bool
        Result of the check. When ``__sklearn_is_fitted__`` is used, its
        return value is passed through unchanged.

    Raises
    ------
    TypeError
        If ``estimator`` is a class rather than an instance, or if it has no
        ``fit`` attribute.
    """
    if isclass(estimator):
        raise TypeError("{} is a class, not an instance.".format(estimator))

    if not hasattr(estimator, "fit"):
        # The one-element tuple is required: with a bare ``% estimator`` a
        # tuple argument would be unpacked as the format arguments and the
        # intended message would be replaced by a formatting error.
        raise TypeError("%s is not an estimator instance." % (estimator,))

    if attributes is not None:
        if not isinstance(attributes, (list, tuple)):
            attributes = [attributes]
        # A list (not a generator) is passed so custom reducers that need a
        # sized or re-iterable argument keep working.
        return all_or_any([hasattr(estimator, attr) for attr in attributes])

    if hasattr(estimator, "__sklearn_is_fitted__"):
        return estimator.__sklearn_is_fitted__()

    # Fall back to the trailing-underscore naming convention.
    return any(
        name.endswith("_") and not name.startswith("__")
        for name in vars(estimator)
    )

Functions

def check_is_fitted(estimator, attributes=None, *, msg=None, all_or_any=<built-in function all>)

Perform is_fitted validation for estimator. Checks if the estimator is fitted by verifying the presence of fitted attributes (ending with a trailing underscore) and returns the result as a boolean; the function shown here does not itself raise a NotFittedError. If an estimator does not set any attributes with a trailing underscore, it can define a __sklearn_is_fitted__ method returning a boolean to specify if the estimator is fitted or not.

Parameters


estimator : estimator instance
estimator instance for which the check is performed.
attributes : str, list or tuple of str, default=None
Attribute name(s) given as a string or a list/tuple of strings, e.g. ["coef_", "estimator_", ...] or "coef_". If None, estimator is considered fitted if there exists an attribute that ends with an underscore and does not start with a double underscore.
msg : str, default=None
The default error message is, "This %(name)s instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator." For custom messages if "%(name)s" is present in the message string, it is substituted for the estimator name. Eg. : "Estimator, %(name)s, must be fitted before sparsifying".
all_or_any : callable, {all, any}, default=all
Specify whether all or any of the given attributes must exist.

Returns

fitted : bool
 
Expand source code
def check_is_fitted(estimator, attributes=None, *, msg=None, all_or_any=all):
    """Report whether ``estimator`` looks fitted.

    This function does not raise a ``NotFittedError``; it returns the
    outcome of the check and leaves the reaction to the caller.

    The first applicable rule decides the result:

    1. If ``attributes`` is given, each name is probed with ``hasattr`` and
       the results are combined with ``all_or_any``.
    2. Otherwise, if the estimator has a ``__sklearn_is_fitted__`` attribute,
       it is called and its return value is returned as-is.
    3. Otherwise, the estimator counts as fitted when ``vars(estimator)``
       contains at least one name that ends with an underscore and does not
       start with a double underscore.

    Parameters
    ----------
    estimator : estimator instance
        Instance to inspect. It must expose a ``fit`` attribute.
    attributes : str, list or tuple of str, default=None
        Attribute name(s) to look for, e.g. ``"coef_"`` or
        ``["coef_", "estimator_"]``. A value that is not a list or tuple is
        treated as a single attribute name.
    msg : str, default=None
        Accepted but not used by this implementation.
        NOTE(review): presumably kept for signature compatibility with
        scikit-learn's ``check_is_fitted`` -- confirm with callers.
    all_or_any : callable, {all, any}, default=all
        Reducer applied to the list of per-attribute ``hasattr`` results;
        only consulted when ``attributes`` is given.

    Returns
    -------
    fitted : bool
        Result of the check. When ``__sklearn_is_fitted__`` is used, its
        return value is passed through unchanged.

    Raises
    ------
    TypeError
        If ``estimator`` is a class rather than an instance, or if it has no
        ``fit`` attribute.
    """
    if isclass(estimator):
        raise TypeError("{} is a class, not an instance.".format(estimator))

    if not hasattr(estimator, "fit"):
        # The one-element tuple is required: with a bare ``% estimator`` a
        # tuple argument would be unpacked as the format arguments and the
        # intended message would be replaced by a formatting error.
        raise TypeError("%s is not an estimator instance." % (estimator,))

    if attributes is not None:
        if not isinstance(attributes, (list, tuple)):
            attributes = [attributes]
        # A list (not a generator) is passed so custom reducers that need a
        # sized or re-iterable argument keep working.
        return all_or_any([hasattr(estimator, attr) for attr in attributes])

    if hasattr(estimator, "__sklearn_is_fitted__"):
        return estimator.__sklearn_is_fitted__()

    # Fall back to the trailing-underscore naming convention.
    return any(
        name.endswith("_") and not name.startswith("__")
        for name in vars(estimator)
    )