Expand source code
from inspect import isclass
def check_is_fitted(estimator, attributes=None, *, msg=None, all_or_any=all):
    """Perform is_fitted validation for estimator.

    Checks whether the estimator is fitted by verifying the presence of
    fitted attributes (names ending with a trailing underscore) and returns
    the result. This function does NOT raise ``NotFittedError`` for an
    unfitted estimator; callers must inspect the returned value.

    If an estimator does not set any attributes with a trailing underscore,
    it can define a ``__sklearn_is_fitted__`` method returning a boolean to
    specify whether the estimator is fitted.

    Parameters
    ----------
    estimator : estimator instance
        Estimator instance for which the check is performed. Must be an
        instance (not a class) that has a ``fit`` attribute.

    attributes : str, list or tuple of str, default=None
        Attribute name(s) given as a string or a list/tuple of strings,
        e.g. ``["coef_", "estimator_", ...]`` or ``"coef_"``.
        If ``None``, the result of ``estimator.__sklearn_is_fitted__()`` is
        used when that method exists; otherwise ``estimator`` is considered
        fitted if its instance attributes (``vars(estimator)``) include a
        name that ends with an underscore and does not start with a double
        underscore.

    msg : str, default=None
        Not used: this function reports the fitted state through its return
        value and never formats or raises an error message. It is kept as a
        keyword-only parameter so callers that pass it keep working.

    all_or_any : callable, {all, any}, default=all
        Specify whether all or any of the given attributes must exist.
        Only consulted when ``attributes`` is not ``None``.

    Returns
    -------
    fitted : bool
        Whether the estimator is considered fitted. When ``attributes`` is
        given, this is ``all_or_any`` applied to the per-attribute presence
        flags; when ``__sklearn_is_fitted__`` is used, its return value is
        passed through unchanged.

    Raises
    ------
    TypeError
        If ``estimator`` is a class rather than an instance, or if it has no
        ``fit`` attribute.
    """
    if isclass(estimator):
        raise TypeError("{} is a class, not an instance.".format(estimator))
    if not hasattr(estimator, "fit"):
        # Wrap in a 1-tuple so tuple-valued inputs are rendered as a single
        # value rather than being unpacked as %-format arguments.
        raise TypeError("%s is not an estimator instance." % (estimator,))

    if attributes is not None:
        # Accept a single attribute name as well as a list/tuple of names.
        if not isinstance(attributes, (list, tuple)):
            attributes = [attributes]
        return all_or_any([hasattr(estimator, attr) for attr in attributes])

    if hasattr(estimator, "__sklearn_is_fitted__"):
        return estimator.__sklearn_is_fitted__()

    # Fallback: any instance attribute with a trailing underscore (excluding
    # dunder-prefixed names) marks the estimator as fitted.
    return any(
        name.endswith("_") and not name.startswith("__")
        for name in vars(estimator)
    )
Functions
def check_is_fitted(estimator, attributes=None, *, msg=None, all_or_any=<built-in function all>)
Perform is_fitted validation for estimator.

Checks whether the estimator is fitted by verifying the presence of fitted attributes (names ending with a trailing underscore) and returns the result as a boolean; it does not raise a NotFittedError. If an estimator does not set any attributes with a trailing underscore, it can define a ``__sklearn_is_fitted__`` method returning a boolean to specify whether the estimator is fitted.

Parameters
estimator : estimator instance
    Estimator instance for which the check is performed.
attributes : str, list or tuple of str, default=None
    Attribute name(s) given as a string or a list/tuple of strings, e.g. ``["coef_", "estimator_", ...]`` or ``"coef_"``. If ``None``, ``estimator`` is considered fitted if there exists an attribute that ends with an underscore and does not start with a double underscore.
msg : str, default=None
    Currently unused: this function returns a boolean and never formats or raises this message. Intended meaning: the default error message is "This %(name)s instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator." For custom messages, if "%(name)s" is present in the message string, it is substituted with the estimator name, e.g. "Estimator, %(name)s, must be fitted before sparsifying".
all_or_any : callable, {all, any}, default=all
    Specify whether all or any of the given attributes must exist.
Returns
-------
fitted : bool
    Whether the estimator is considered fitted.
Expand source code
def check_is_fitted(estimator, attributes=None, *, msg=None, all_or_any=all):
    """Perform is_fitted validation for estimator.

    Checks whether the estimator is fitted by verifying the presence of
    fitted attributes (names ending with a trailing underscore) and returns
    the result. This function does NOT raise ``NotFittedError`` for an
    unfitted estimator; callers must inspect the returned value.

    If an estimator does not set any attributes with a trailing underscore,
    it can define a ``__sklearn_is_fitted__`` method returning a boolean to
    specify whether the estimator is fitted.

    Parameters
    ----------
    estimator : estimator instance
        Estimator instance for which the check is performed. Must be an
        instance (not a class) that has a ``fit`` attribute.

    attributes : str, list or tuple of str, default=None
        Attribute name(s) given as a string or a list/tuple of strings,
        e.g. ``["coef_", "estimator_", ...]`` or ``"coef_"``.
        If ``None``, the result of ``estimator.__sklearn_is_fitted__()`` is
        used when that method exists; otherwise ``estimator`` is considered
        fitted if its instance attributes (``vars(estimator)``) include a
        name that ends with an underscore and does not start with a double
        underscore.

    msg : str, default=None
        Not used: this function reports the fitted state through its return
        value and never formats or raises an error message. It is kept as a
        keyword-only parameter so callers that pass it keep working.

    all_or_any : callable, {all, any}, default=all
        Specify whether all or any of the given attributes must exist.
        Only consulted when ``attributes`` is not ``None``.

    Returns
    -------
    fitted : bool
        Whether the estimator is considered fitted. When ``attributes`` is
        given, this is ``all_or_any`` applied to the per-attribute presence
        flags; when ``__sklearn_is_fitted__`` is used, its return value is
        passed through unchanged.

    Raises
    ------
    TypeError
        If ``estimator`` is a class rather than an instance, or if it has no
        ``fit`` attribute.
    """
    if isclass(estimator):
        raise TypeError("{} is a class, not an instance.".format(estimator))
    if not hasattr(estimator, "fit"):
        # Wrap in a 1-tuple so tuple-valued inputs are rendered as a single
        # value rather than being unpacked as %-format arguments.
        raise TypeError("%s is not an estimator instance." % (estimator,))

    if attributes is not None:
        # Accept a single attribute name as well as a list/tuple of names.
        if not isinstance(attributes, (list, tuple)):
            attributes = [attributes]
        return all_or_any([hasattr(estimator, attr) for attr in attributes])

    if hasattr(estimator, "__sklearn_is_fitted__"):
        return estimator.__sklearn_is_fitted__()

    # Fallback: any instance attribute with a trailing underscore (excluding
    # dunder-prefixed names) marks the estimator as fitted.
    return any(
        name.endswith("_") and not name.startswith("__")
        for name in vars(estimator)
    )