ENH remove _xfail_checks, pass directly to check runners, return structured output from check_estimator (scikit-learn#30149)
adrinjalali authored Nov 8, 2024
1 parent a71860a commit 9012b78
Showing 43 changed files with 1,025 additions and 571 deletions.
2 changes: 2 additions & 0 deletions doc/api_reference.py
@@ -388,6 +388,7 @@ def _get_submodule(module_name, submodule_name):
"InconsistentVersionWarning",
"NotFittedError",
"UndefinedMetricWarning",
"EstimatorCheckFailedWarning",
],
},
],
@@ -1298,6 +1299,7 @@ def _get_submodule(module_name, submodule_name):
"autosummary": [
"estimator_checks.check_estimator",
"estimator_checks.parametrize_with_checks",
"estimator_checks.estimator_checks_generator",
],
},
{
3 changes: 2 additions & 1 deletion doc/glossary.rst
@@ -198,7 +198,8 @@ General Concepts
This refers to the tests run on almost every estimator class in
Scikit-learn to check they comply with basic API conventions. They are
available for external use through
:func:`utils.estimator_checks.check_estimator`, with most of the
:func:`utils.estimator_checks.check_estimator` or
:func:`utils.estimator_checks.parametrize_with_checks`, with most of the
implementation in ``sklearn/utils/estimator_checks.py``.

Note: Some exceptions to the common testing regime are currently
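
The glossary entry above points external developers at the common-check entry points. For orientation, this is the standard pytest pattern for running the common checks against estimators via parametrize_with_checks; the estimators listed here are only placeholders:

    from sklearn.linear_model import LogisticRegression
    from sklearn.tree import DecisionTreeRegressor
    from sklearn.utils.estimator_checks import parametrize_with_checks


    @parametrize_with_checks([LogisticRegression(), DecisionTreeRegressor()])
    def test_sklearn_compatible_estimator(estimator, check):
        # Each (estimator, check) pair becomes one pytest test case.
        check(estimator)
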
23 changes: 23 additions & 0 deletions doc/whats_new/upcoming_changes/sklearn.utils/30149.enhancement.rst
@@ -0,0 +1,23 @@
- Changes to :func:`~utils.estimator_checks.check_estimator` and
:func:`~utils.estimator_checks.parametrize_with_checks`.

- :func:`~utils.estimator_checks.check_estimator` introduces new arguments:
``on_skip``, ``on_fail``, and ``callback`` to control the behavior of the check
runner. Refer to the API documentation for more details.

- ``generate_only=True`` is deprecated in
:func:`~utils.estimator_checks.check_estimator`. Use
:func:`~utils.estimator_checks.estimator_checks_generator` instead.

- The ``_xfail_checks`` estimator tag is removed. To indicate which checks an
  estimator is expected to fail, pass a dictionary to
  :func:`~utils.estimator_checks.check_estimator` as the ``expected_failed_checks``
  parameter. Similarly, :func:`~utils.estimator_checks.parametrize_with_checks`
  accepts an ``expected_failed_checks`` parameter as a callable returning a
  dictionary of the form::

{
"check_name": "reason to mark this check as xfail",
}

By `Adrin Jalali`_
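
As a quick illustration of the entry above, here is a minimal sketch of calling check_estimator with the new parameters. It assumes only what this commit shows elsewhere: expected_failed_checks maps check names to reasons, and on_skip=None / on_fail=None make the runner return structured records with the check_name and status fields used by maint_tools/check_xfailed_checks.py below. The estimator and the expected failure are placeholders:

    from sklearn.linear_model import LogisticRegression
    from sklearn.utils.estimator_checks import check_estimator

    # Hypothetical expectation; real entries depend on the estimator under test.
    expected_failed_checks = {
        "check_sample_weight_equivalence": (
            "sample_weight is not equivalent to removing/repeating samples."
        ),
    }

    # on_skip=None / on_fail=None keep the runner quiet so the structured
    # records can be inspected directly, as in the maintenance script below.
    results = check_estimator(
        LogisticRegression(),
        expected_failed_checks=expected_failed_checks,
        on_skip=None,
        on_fail=None,
    )
    for record in results:
        if record["status"] == "failed":
            print(record["check_name"], record["status"])
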
37 changes: 37 additions & 0 deletions maint_tools/check_xfailed_checks.py
@@ -0,0 +1,37 @@
# This script checks that the common tests marked with xfail are actually
# failing.
# Note that in some cases, a test might be marked with xfail because it is
# failing on certain machines, and might not be triggered by this script.

import contextlib
import io

from sklearn.utils._test_common.instance_generator import (
_get_expected_failed_checks,
_tested_estimators,
)
from sklearn.utils.estimator_checks import check_estimator

for estimator in _tested_estimators():
# calling check_estimator w/o passing expected_failed_checks will find
# all the failing tests in your environment.
# suppress stdout/stderr while running checks
with (
contextlib.redirect_stdout(io.StringIO()),
contextlib.redirect_stderr(io.StringIO()),
):
check_results = check_estimator(estimator, on_skip=None, on_fail=None)
failed_tests = [e for e in check_results if e["status"] == "failed"]
failed_test_names = set(e["check_name"] for e in failed_tests)
expected_failed_tests = set(_get_expected_failed_checks(estimator).keys())
unexpected_failures = failed_test_names - expected_failed_tests
if unexpected_failures:
print(f"{estimator.__class__.__name__} failed with unexpected failures:")
for failure in unexpected_failures:
print(f" {failure}")

expected_but_not_raised = expected_failed_tests - failed_test_names
if expected_but_not_raised:
print(f"{estimator.__class__.__name__} did not fail expected failures:")
for failure in expected_but_not_raised:
print(f" {failure}")
25 changes: 0 additions & 25 deletions sklearn/cluster/_bicluster.py
@@ -193,20 +193,6 @@ def _k_means(self, data, n_clusters):
labels = model.labels_
return centroid, labels

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags._xfail_checks = {
"check_estimators_dtypes": "raises nan error",
"check_fit2d_1sample": "_scale_normalize fails",
"check_fit2d_1feature": "raises apply_along_axis error",
"check_estimator_sparse_matrix": "does not fail gracefully",
"check_estimator_sparse_array": "does not fail gracefully",
"check_methods_subset_invariance": "empty array passed inside",
"check_dont_overwrite_parameters": "empty array passed inside",
"check_fit2d_predict1d": "empty array passed inside",
}
return tags


class SpectralCoclustering(BaseSpectral):
"""Spectral Co-Clustering algorithm (Dhillon, 2001).
@@ -362,17 +348,6 @@ def _fit(self, X):
[self.column_labels_ == c for c in range(self.n_clusters)]
)

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags._xfail_checks.update(
{
# ValueError: Found array with 0 feature(s) (shape=(23, 0))
# while a minimum of 1 is required.
"check_dict_unchanged": "FIXME",
}
)
return tags


class SpectralBiclustering(BaseSpectral):
"""Spectral biclustering (Kluger, 2003).
10 changes: 0 additions & 10 deletions sklearn/cluster/_kmeans.py
@@ -1177,16 +1177,6 @@ def score(self, X, y=None, sample_weight=None):
)
return -scores

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
# TODO: replace by a statistical test, see meta-issue #16298
tags._xfail_checks = {
"check_sample_weight_equivalence": (
"sample_weight is not equivalent to removing/repeating samples."
),
}
return tags


class KMeans(_BaseKMeans):
"""K-Means clustering.
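
The reason strings removed from __sklearn_tags__ in hunks like the one above now travel with the check runner instead of the estimator. A hedged sketch of the equivalent declaration, using KMeans as a stand-in and mirroring the reason deleted from _kmeans.py; the callable form follows the changelog entry and the _get_expected_failed_checks(estimator) usage in the maintenance script, and it is not asserted here which checks actually fail in any given environment:

    from sklearn.cluster import KMeans
    from sklearn.utils.estimator_checks import parametrize_with_checks


    def expected_failed_checks(estimator):
        # Mirrors the reason string removed from _kmeans.py above.
        if isinstance(estimator, KMeans):
            return {
                "check_sample_weight_equivalence": (
                    "sample_weight is not equivalent to removing/repeating samples."
                ),
            }
        return {}


    @parametrize_with_checks([KMeans()], expected_failed_checks=expected_failed_checks)
    def test_estimators(estimator, check):
        check(estimator)
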
14 changes: 0 additions & 14 deletions sklearn/compose/_column_transformer.py
@@ -1315,20 +1315,6 @@ def get_metadata_routing(self):

return router

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags._xfail_checks = {
"check_estimators_empty_data_messages": "FIXME",
"check_estimators_nan_inf": "FIXME",
"check_estimator_sparse_array": "FIXME",
"check_estimator_sparse_matrix": "FIXME",
"check_fit1d": "FIXME",
"check_fit2d_predict1d": "FIXME",
"check_complex_data": "FIXME",
"check_fit2d_1feature": "FIXME",
}
return tags


def _check_X(X):
"""Use check_array only when necessary, e.g. on lists and other non-array-likes."""
4 changes: 0 additions & 4 deletions sklearn/dummy.py
@@ -425,10 +425,6 @@ def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.classifier_tags.poor_score = True
tags.no_validation = True
tags._xfail_checks = {
"check_methods_subset_invariance": "fails for the predict method",
"check_methods_sample_order_invariance": "fails for the predict method",
}
return tags

def score(self, X, y, sample_weight=None):
6 changes: 0 additions & 6 deletions sklearn/ensemble/_bagging.py
@@ -628,12 +628,6 @@ def _get_estimator(self):
def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.input_tags.allow_nan = get_tags(self._get_estimator()).input_tags.allow_nan
# TODO: replace by a statistical test, see meta-issue #16298
tags._xfail_checks = {
"check_sample_weight_equivalence": (
"sample_weight is not equivalent to removing/repeating samples."
),
}
return tags


30 changes: 0 additions & 30 deletions sklearn/ensemble/_forest.py
@@ -1557,16 +1557,6 @@ def __init__(
self.monotonic_cst = monotonic_cst
self.ccp_alpha = ccp_alpha

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
# TODO: replace by a statistical test, see meta-issue #16298
tags._xfail_checks = {
"check_sample_weight_equivalence": (
"sample_weight is not equivalent to removing/repeating samples."
),
}
return tags


class RandomForestRegressor(ForestRegressor):
"""
@@ -1928,16 +1918,6 @@ def __init__(
self.ccp_alpha = ccp_alpha
self.monotonic_cst = monotonic_cst

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
# TODO: replace by a statistical test, see meta-issue #16298
tags._xfail_checks = {
"check_sample_weight_equivalence": (
"sample_weight is not equivalent to removing/repeating samples."
),
}
return tags


class ExtraTreesClassifier(ForestClassifier):
"""
@@ -3012,13 +2992,3 @@ def transform(self, X):
"""
check_is_fitted(self)
return self.one_hot_encoder_.transform(self.apply(X))

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
# TODO: replace by a statistical test, see meta-issue #16298
tags._xfail_checks = {
"check_sample_weight_equivalence": (
"sample_weight is not equivalent to removing/repeating samples."
),
}
return tags
20 changes: 0 additions & 20 deletions sklearn/ensemble/_gb.py
@@ -1725,16 +1725,6 @@ def staged_predict_proba(self, X):
"loss=%r does not support predict_proba" % self.loss
) from e

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
# TODO: investigate failure see meta-issue #16298
tags._xfail_checks = {
"check_sample_weight_equivalence": (
"sample_weight is not equivalent to removing/repeating samples."
),
}
return tags


class GradientBoostingRegressor(RegressorMixin, BaseGradientBoosting):
"""Gradient Boosting for regression.
@@ -2191,13 +2181,3 @@ def apply(self, X):
leaves = super().apply(X)
leaves = leaves.reshape(X.shape[0], self.estimators_.shape[0])
return leaves

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
# TODO: investigate failure see meta-issue #16298
tags._xfail_checks = {
"check_sample_weight_equivalence": (
"sample_weight is not equivalent to removing/repeating samples."
),
}
return tags
6 changes: 0 additions & 6 deletions sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
@@ -1389,12 +1389,6 @@ def _compute_partial_dependence_recursion(self, grid, target_features):
def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.input_tags.allow_nan = True
# TODO: replace by a statistical test, see meta-issue #16298
tags._xfail_checks = {
"check_sample_weight_equivalence": (
"sample_weight is not equivalent to removing/repeating samples."
),
}
return tags

@abstractmethod
6 changes: 0 additions & 6 deletions sklearn/ensemble/_iforest.py
@@ -633,12 +633,6 @@ def _compute_score_samples(self, X, subsample_features):

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
# TODO: replace by a statistical test, see meta-issue #16298
tags._xfail_checks = {
"check_sample_weight_equivalence": (
"sample_weight is not equivalent to removing/repeating samples."
),
}
tags.input_tags.allow_nan = True
return tags

20 changes: 0 additions & 20 deletions sklearn/ensemble/_weight_boosting.py
@@ -858,16 +858,6 @@ def predict_log_proba(self, X):
"""
return np.log(self.predict_proba(X))

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
# TODO: replace by a statistical test, see meta-issue #16298
tags._xfail_checks = {
"check_sample_weight_equivalence": (
"sample_weight is not equivalent to removing/repeating samples."
),
}
return tags


class AdaBoostRegressor(_RoutingNotSupportedMixin, RegressorMixin, BaseWeightBoosting):
"""An AdaBoost regressor.
@@ -1176,13 +1166,3 @@ def staged_predict(self, X):

for i, _ in enumerate(self.estimators_, 1):
yield self._get_median_predict(X, limit=i)

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
# TODO: replace by a statistical test, see meta-issue #16298
tags._xfail_checks = {
"check_sample_weight_equivalence": (
"sample_weight is not equivalent to removing/repeating samples."
),
}
return tags
58 changes: 58 additions & 0 deletions sklearn/exceptions.py
@@ -14,6 +14,7 @@
"UndefinedMetricWarning",
"PositiveSpectrumWarning",
"UnsetMetadataPassedError",
"EstimatorCheckFailedWarning",
]


@@ -189,3 +190,60 @@ def __str__(self):
"https://scikit-learn.org/stable/model_persistence.html"
"#security-maintainability-limitations"
)


class EstimatorCheckFailedWarning(UserWarning):
"""Warning raised when an estimator check from the common tests fails.
Parameters
----------
estimator : estimator object
Estimator instance for which the test failed.
check_name : str
Name of the check that failed.
exception : Exception
Exception raised by the failed check.
status : str
Status of the check.
expected_to_fail : bool
Whether the check was expected to fail.
expected_to_fail_reason : str
Reason for the expected failure.
"""

def __init__(
self,
*,
estimator,
check_name: str,
exception: Exception,
status: str,
expected_to_fail: bool,
expected_to_fail_reason: str,
):
self.estimator = estimator
self.check_name = check_name
self.exception = exception
self.status = status
self.expected_to_fail = expected_to_fail
self.expected_to_fail_reason = expected_to_fail_reason

def __repr__(self):
expected_to_fail_str = (
f"Expected to fail: {self.expected_to_fail_reason}"
if self.expected_to_fail
else "Not expected to fail"
)
return (
f"Test {self.check_name} failed for estimator {self.estimator!r}.\n"
f"Expected to fail reason: {expected_to_fail_str}\n"
f"Exception: {self.exception}"
)

def __str__(self):
return self.__repr__()
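
A short sketch of how a caller might capture this warning and read the attributes documented above. The on_fail="warn" value is an assumption about how the check runner is told to emit warnings rather than raise; consult the check_estimator API reference for the accepted values:

    import warnings

    from sklearn.exceptions import EstimatorCheckFailedWarning
    from sklearn.linear_model import LogisticRegression
    from sklearn.utils.estimator_checks import check_estimator

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always", EstimatorCheckFailedWarning)
        # on_fail="warn" is assumed to report check failures as warnings
        # instead of raising; see the check_estimator documentation.
        check_estimator(LogisticRegression(), on_fail="warn")

    for w in caught:
        if isinstance(w.message, EstimatorCheckFailedWarning):
            print(w.message.check_name, w.message.status, w.message.expected_to_fail)
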