diff --git a/tests/checks.py b/tests/checks.py deleted file mode 100644 index 851509e25..000000000 --- a/tests/checks.py +++ /dev/null @@ -1,166 +0,0 @@ -# tests.checks -# Performs checking that visualizers adhere to Yellowbrick conventions. -# -# Author: Benjamin Bengfort -# Created: Mon May 22 11:18:06 2017 -0700 -# -# Copyright (C) 2017 District Data Labs -# For license information, see LICENSE.txt -# -# ID: checks.py [4131cb1] benjamin@bengfort.com $ - -""" -Performs checking that visualizers adhere to Yellowbrick conventions. -""" - -########################################################################## -## Imports -########################################################################## - -import sys -sys.path.append("..") - -import numpy as np -import matplotlib.pyplot as plt - -from yellowbrick.base import ModelVisualizer, ScoreVisualizer -from yellowbrick.classifier.base import ClassificationScoreVisualizer -from yellowbrick.cluster.base import ClusteringScoreVisualizer -from yellowbrick.features.base import FeatureVisualizer, DataVisualizer -from yellowbrick.regressor.base import RegressionScoreVisualizer -from yellowbrick.text.base import TextVisualizer - - -########################################################################## -## Checking runable -########################################################################## - -def check_visualizer(Visualizer): - """ - Check if visualizer adheres to Yellowbrick conventions. - - This function runs an extensive test-suite for input validation, return - values, exception handling, and more. Additional tests for scoring or - tuning visualizers will be run if the Visualizer clss inherits from the - corresponding object. - """ - name = Visualizer.__name__ - for check in _yield_all_checks(name, Visualizer): - check(name, Visualizer) - - -########################################################################## -## Generate the specific per-visualizer checking -########################################################################## - -def _yield_all_checks(name, Visualizer): - """ - Composes the checks required for the specific visualizer. 
- """ - - # Global Checks - yield check_instantiation - yield check_estimator_api - - # Visualizer Type Checks - if issubclass(Visualizer, RegressionScoreVisualizer): - for check in _yield_regressor_checks(name, Visualizer): - yield check - - if issubclass(Visualizer, ClassificationScoreVisualizer): - for check in _yield_classifier_checks(name, Visualizer): - yield check - - if issubclass(Visualizer, ClusteringScoreVisualizer): - for check in _yield_clustering_checks(name, Visualizer): - yield check - - if issubclass(Visualizer, FeatureVisualizer): - for check in _yield_feature_checks(name, Visualizer): - yield check - - if issubclass(Visualizer, TextVisualizer): - for check in _yield_text_checks(name, Visualizer): - yield check - - # Other checks - - -def _yield_regressor_checks(name, Visualizer): - """ - Checks for regressor visualizers - """ - pass - - -def _yield_classifier_checks(name, Visualizer): - """ - Checks for classifier visualizers - """ - pass - - -def _yield_clustering_checks(name, Visualizer): - """ - Checks for clustering visualizers - """ - pass - - -def _yield_feature_checks(name, Visualizer): - """ - Checks for feature visualizers - """ - pass - - -def _yield_text_checks(name, Visualizer): - """ - Checks for text visualizers - """ - pass - - -########################################################################## -## Checking Functions -########################################################################## - -def check_instantiation(name, Visualizer, args, kwargs): - # assert that visualizers can be passed an axes object. - ax = plt.gca() - - viz = Visualizer(*args, **kwargs) - assert viz.ax == ax - - -def check_estimator_api(name, Visualizer): - X = np.random.rand((5, 10)) - y = np.random.randint(0,2, 10) - - # Ensure fit returns self. - viz = Visualizer() - self = viz.fit(X, y) - assert viz == self - - -if __name__ == '__main__': - import sys - sys.path.append("..") - - from yellowbrick.classifier import * - from yellowbrick.cluster import * - from yellowbrick.features import * - from yellowbrick.regressor import * - from yellowbrick.text import * - - visualizers = [ - ClassBalance, ClassificationReport, ConfusionMatrix, ROCAUC, - KElbowVisualizer, SilhouetteVisualizer, - ScatterVisualizer, JointPlotVisualizer, Rank2D, RadViz, ParallelCoordinates, - AlphaSelection, ManualAlphaSelection, - PredictionError, ResidualsPlot, - TSNEVisualizer, FreqDistVisualizer, PosTagVisualizer - ] - - for visualizer in visualizers: - check_visualizer(visualizer) diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 000000000..81e2ae669 --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,337 @@ +# tests.test_api +# Ensures that standard visualizers adhere to the Yellowbrick API. +# +# Author: Benjamin Bengfort +# Created: Mon May 22 11:18:06 2017 -0700 +# +# Copyright (C) 2017 The scikit-yb developers +# For license information, see LICENSE.txt +# +# ID: checks.py [4131cb1] benjamin@bengfort.com $ + +""" +Ensures that standard visualizers adhere to the Yellowbrick API. + +This module runs a full suite of checks against all of our documented Visualizers to +ensure they conform to our API. Visualizers that are considered "complete" should be +added to this test suite to ensure that they meet the requirements of the checks. 
+""" + +########################################################################## +## Imports +########################################################################## + +import pytest +import unittest.mock as mock +import matplotlib.pyplot as plt + +from yellowbrick.base import * +from yellowbrick.pipeline import * +from yellowbrick.classifier import * +from yellowbrick.cluster import * +from yellowbrick.features import * +from yellowbrick.gridsearch import * +from yellowbrick.regressor import * +from yellowbrick.text import * +from yellowbrick.target import * +from yellowbrick.model_selection import * + +from tests.fixtures import Dataset, Split +from yellowbrick.datasets import load_hobbies + +from sklearn.svm import LinearSVC +from sklearn.naive_bayes import GaussianNB +from sklearn.cluster import MiniBatchKMeans +from sklearn.linear_model import Lasso, LassoCV +from sklearn.feature_extraction.text import CountVectorizer +from sklearn.model_selection import train_test_split as tts +from sklearn.datasets import make_blobs, make_classification, make_regression + + +BASES = [ + Visualizer, # green + ModelVisualizer, # green + ScoreVisualizer, # yellow + ClassificationScoreVisualizer, # red -> green + RegressionScoreVisualizer, # red -> yellow (needs tests) + ClusteringScoreVisualizer, # green + TargetVisualizer, # yellow (needs tests) + FeatureVisualizer, # green + MultiFeatureVisualizer, # green + DataVisualizer, # green + RankDBase, # green + ProjectionVisualizer, # green + TextVisualizer, # green + GridSearchVisualizer, # red (prototype, no tests) +] + +OTHER = [ + Wrapper, # green + VisualizerGrid, # red (prototype) + VisualPipeline, # red (prototype) +] + +CLASSIFICATION_VISUALZERS = [ + ClassPredictionError, # yellow + ClassificationReport, # green + ConfusionMatrix, # green + PrecisionRecallCurve, # green + ROCAUC, # green + DiscriminationThreshold, # green +] + +CLUSTERING_VISUALIZERS = [ + KElbowVisualizer, # yellow (problems with kneed) + InterclusterDistance, # yellow + SilhouetteVisualizer, # yellow --> green (quick method) +] + +FEATURE_VISUALIZERS = [ + ExplainedVariance, # red (undocumented, no tests, still a prototype) + FeatureImportances, # yellow --> green (getting moved) + Rank1D, # green + Rank2D, # green + RFECV, # yellow --> green (getting moved) + JointPlot, # red + Manifold, # green + PCA, # yellow (needs much better documentation) + ParallelCoordinates, # green + RadialVisualizer, # green +] + +MODEL_SELECTION_VISUALIZERS = [ + CVScores, # green (style) + LearningCurve, # green + ValidationCurve, # green + GridSearchColorPlot, # red (prototype, untested, undocumented) +] + +REGRESSOR_VISUALIZERS = [ + AlphaSelection, # red -> yellow (quick method) + ManualAlphaSelection, # red + CooksDistance, # green + PredictionError, # green + ResidualsPlot, # green +] + +TARGET_VISUALIZERS = [ + BalancedBinningReference, # yellow (style) + ClassBalance, # green + FeatureCorrelation, # green (its fine) +] + +TEXT_VISUALIZERS = [ + DispersionPlot, # green + FreqDistVisualizer, # yellow (needs better test coverage) + PosTagVisualizer, # green + TSNEVisualizer, # green + UMAPVisualizer, # green +] + +VISUALIZERS = ( + CLASSIFICATION_VISUALZERS + + CLUSTERING_VISUALIZERS + + FEATURE_VISUALIZERS + + MODEL_SELECTION_VISUALIZERS + + REGRESSOR_VISUALIZERS + + TARGET_VISUALIZERS + + TEXT_VISUALIZERS +) + +QUICK_METHODS = { + ClassPredictionError: class_prediction_error, + ClassificationReport: classification_report, + ConfusionMatrix: confusion_matrix, + PrecisionRecallCurve: 
precision_recall_curve, + ROCAUC: roc_auc, + DiscriminationThreshold: discrimination_threshold, + KElbowVisualizer: kelbow_visualizer, + InterclusterDistance: intercluster_distance, + SilhouetteVisualizer: silhouette_visualizer, + ExplainedVariance: explained_variance_visualizer, + FeatureImportances: feature_importances, + Rank1D: rank1d, + Rank2D: rank2d, + RFECV: rfecv, + JointPlot: joint_plot, # raises not implemented error + Manifold: manifold_embedding, + PCA: pca_decomposition, + ParallelCoordinates: parallel_coordinates, + RadialVisualizer: radviz, + CVScores: cv_scores, + LearningCurve: learning_curve, + ValidationCurve: validation_curve, + AlphaSelection: alphas, + CooksDistance: cooks_distance, + PredictionError: prediction_error, + ResidualsPlot: residuals_plot, + BalancedBinningReference: balanced_binning_reference, + ClassBalance: class_balance, + FeatureCorrelation: feature_correlation, + DispersionPlot: dispersion, + FreqDistVisualizer: freqdist, + PosTagVisualizer: postag, + TSNEVisualizer: tsne, + UMAPVisualizer: umap, + GridSearchVisualizer: gridsearch_color_plot, +} + +ALIASES = [ + PRCurve, + KElbow, + ICDM, + JointPlotVisualizer, + PCADecomposition, + RadViz, + FrequencyVisualizer, +] + +SKIPS = { + GridSearchColorPlot: "prototype only, no tests", + ManualAlphaSelection: "prototype only, cannot be instantiated", + ValidationCurve: "has required positional args: 'param_name', 'param_range'", + DispersionPlot: "has required positional arg: 'target_words'", + FrequencyVisualizer: "has required positional arg: 'features'", +} + + +########################################################################## +## Fixtures +########################################################################## + +@pytest.fixture(scope='function') +def figure(): + fig, ax = plt.subplots() + yield fig, ax + plt.close(fig) + + +def get_model_for_visualizer(Viz): + """ + Helper function to return the appropriate model for the visualizer class + """ + if not issubclass(Viz, ModelVisualizer): + return None + + if issubclass(Viz, ClassificationScoreVisualizer) or Viz is DiscriminationThreshold: + return GaussianNB + + if issubclass(Viz, RegressionScoreVisualizer): + if Viz is AlphaSelection: + return LassoCV + return Lasso + + if issubclass(Viz, ClusteringScoreVisualizer): + return MiniBatchKMeans + + if Viz in MODEL_SELECTION_VISUALIZERS: + return LinearSVC + + if Viz in {RFECV, FeatureImportances}: + return GaussianNB + + raise TypeError("unknown model for type {}".format(Viz.__name__)) + + +def get_dataset_for_visualizer(Viz): + """ + Helper function to return the appropriate dataset for the visualizer class + """ + X, y = None, None + + if issubclass(Viz, RegressionScoreVisualizer): + X, y = make_regression(random_state=842) + + if issubclass(Viz, ClusteringScoreVisualizer): + X, y = make_blobs(random_state=112) + + if issubclass(Viz, ClassificationScoreVisualizer): + X, y = make_classification(random_state=49) + + if Viz in TEXT_VISUALIZERS: + corpus = load_hobbies() + X, y = CountVectorizer().fit_transform(corpus.data), corpus.target + + if X is None or y is None: + # Current default dataset is binary classification + X, y = make_classification(random_state=23) + + X_train, X_test, y_train, y_test = tts(X, y, test_size=0.2, random_state=982) + return Dataset(Split(X_train, X_test), Split(y_train, y_test)) + + +########################################################################## +## Base API Tests +########################################################################## + 
+@pytest.mark.parametrize("Viz", VISUALIZERS)
+def test_instantiation(Viz, figure):
+    """
+    Ensure all visualizers are instantiated correctly
+    """
+    if Viz in SKIPS:
+        pytest.skip(SKIPS[Viz])
+
+    fig, ax = figure
+    kwargs = {
+        "ax": ax,
+        "fig": fig,
+        "size": (9, 6),
+        "title": "foo title",
+    }
+
+    model = get_model_for_visualizer(Viz)
+    if model is not None:
+        oz = Viz(model(), **kwargs)
+    else:
+        oz = Viz(**kwargs)
+
+    assert oz.ax is ax
+    assert oz.fig is fig
+    assert oz.size == (9, 6)
+    assert oz.title == "foo title"
+
+
+@pytest.mark.skip("too many edge cases, is tested in most visualizer-specific tests")
+@pytest.mark.parametrize("Viz", VISUALIZERS)
+def test_fit(Viz):
+    """
+    Ensure that fit returns self and sets up the visualization
+    """
+    if Viz in SKIPS:
+        pytest.skip(SKIPS[Viz])
+
+    kwargs = {
+        "ax": mock.MagicMock(),
+        "fig": mock.MagicMock(),
+    }
+
+    model = get_model_for_visualizer(Viz)
+    data = get_dataset_for_visualizer(Viz)
+    oz = Viz(model(), **kwargs) if model is not None else Viz(**kwargs)
+
+    assert oz.fit(data.X.train, data.y.train) is oz
+
+
+@pytest.mark.xfail(reason="quick methods aren't primetime yet")
+@pytest.mark.parametrize("Viz, method", list(QUICK_METHODS.items()))
+def test_quickmethod(Viz, method):
+    """
+    Ensures the quick method accepts standard arguments and returns the visualizer
+    """
+    if Viz in SKIPS:
+        pytest.skip(SKIPS[Viz])
+
+    kwargs = {
+        "ax": mock.MagicMock(),
+        "fig": mock.MagicMock(),
+    }
+
+    model = get_model_for_visualizer(Viz)
+    pargs = [] if model is None else [model]
+
+    data = get_dataset_for_visualizer(Viz)
+    pargs.extend([data.X.train, data.y.train])
+
+    oz = method(*pargs, **kwargs)
+    assert isinstance(oz, Viz)
diff --git a/tests/test_base.py b/tests/test_base.py
index d27c712f8..df285d473 100644
--- a/tests/test_base.py
+++ b/tests/test_base.py
@@ -34,6 +34,8 @@
 from tests.base import IS_WINDOWS_OR_CONDA, VisualTestCase
 from sklearn.svm import LinearSVC
+from sklearn.decomposition import PCA
+from sklearn.pipeline import Pipeline
 from sklearn.datasets import make_classification
 ##########################################################################
@@ -161,6 +163,60 @@ class CustomVisualizer(Visualizer):
         viz = CustomVisualizer()
         assert viz.poof() is not None
+
+##########################################################################
+## ModelVisualizer Cases
+##########################################################################
+
+
+class TestModelVisualizer(VisualTestCase):
+    """
+    Tests for the ModelVisualizer
+    """
+
+    def test_final_estimator(self):
+        """
+        Ensure the final estimator can be returned from a Pipeline or directly
+        """
+        svc = LinearSVC()
+        assert ModelVisualizer(svc)._final_estimator() is svc
+
+        short = Pipeline([("clf", svc)])
+        assert ModelVisualizer(short)._final_estimator() is svc
+
+        pipe = Pipeline([("pca", PCA()), ("clf", svc)])
+        assert ModelVisualizer(pipe)._final_estimator() is svc
+
+    @pytest.mark.parametrize(
+        "model", [LinearSVC(), Pipeline([("pca", PCA()), ("clf", LinearSVC())])]
+    )
+    def test_get_learned_attr(self, model):
+        """
+        Ensure a learned attribute can be fetched from a model
+        """
+        X, y = load_occupancy(return_dataset=True).to_numpy()
+        assert not hasattr(model, "classes_")
+
+        viz = ModelVisualizer(model)
+        assert viz.fit(X, y) is viz
+        assert hasattr(model, "classes_")
+
+        assert viz._get_learned_attr("classes_") is model.classes_
+
+    @pytest.mark.parametrize(
+        "model", [LinearSVC(), Pipeline([("pca", PCA()), ("clf", LinearSVC())])]
+    )
+    def test_get_learned_attr_not_fitted(self, model):
+        """
+        Assert NotFitted is raised by _get_learned_attr when the attribute does not exist
+        """
+        assert not hasattr(model, "classes_")
+        viz = ModelVisualizer(model)
+
+        with pytest.raises(NotFitted, match="object has no attribute 'classes_'"):
+            viz._get_learned_attr("classes_")
+
+
 ##########################################################################
 ## ScoreVisualizer Cases
 ##########################################################################
@@ -170,6 +226,7 @@ class MockVisualizer(ScoreVisualizer):
     """
     Mock for a downstream score visualizer
     """
+
     def fit(self, X, y):
         super(MockVisualizer, self).fit(X, y)
@@ -178,6 +235,7 @@ class TestScoreVisualizer(VisualTestCase):
     """
    Tests for the ScoreVisualizer
     """
+
     def test_with_fitted(self):
         """
         Test that visualizer properly handles an already-fitted model
diff --git a/tests/test_utils/test_types.py b/tests/test_utils/test_types.py
index 17d5ff033..1fbe5e26c 100644
--- a/tests/test_utils/test_types.py
+++ b/tests/test_utils/test_types.py
@@ -25,7 +25,7 @@
 try:
     import pandas as pd
-except:
+except ImportError:
     pd = None
 # Yellowbrick Utilities
@@ -41,8 +41,15 @@
 from sklearn.linear_model import Ridge, RidgeCV, Lasso, LassoCV
 REGRESSORS = [
-    SVR, DecisionTreeRegressor, MLPRegressor, LinearRegression,
-    RandomForestRegressor, Ridge, RidgeCV, Lasso, LassoCV,
+    SVR,
+    DecisionTreeRegressor,
+    MLPRegressor,
+    LinearRegression,
+    RandomForestRegressor,
+    Ridge,
+    RidgeCV,
+    Lasso,
+    LassoCV,
 ]
 # Import Classifiers
@@ -55,8 +62,13 @@
 from sklearn.naive_bayes import MultinomialNB, GaussianNB
 CLASSIFIERS = [
-    SVC, DecisionTreeClassifier, MLPClassifier, LogisticRegression,
-    RandomForestClassifier, GradientBoostingClassifier, MultinomialNB,
+    SVC,
+    DecisionTreeClassifier,
+    MLPClassifier,
+    LogisticRegression,
+    RandomForestClassifier,
+    GradientBoostingClassifier,
+    MultinomialNB,
     GaussianNB,
 ]
@@ -64,17 +76,13 @@
 from sklearn.cluster import KMeans, MiniBatchKMeans
 from sklearn.cluster import AffinityPropagation, Birch
-CLUSTERERS = [
-    KMeans, MiniBatchKMeans, AffinityPropagation, Birch,
-]
+CLUSTERERS = [KMeans, MiniBatchKMeans, AffinityPropagation, Birch]
 # Import Decompositions
 from sklearn.decomposition import PCA
 from sklearn.decomposition import TruncatedSVD
-DECOMPOSITIONS = [
-    PCA, TruncatedSVD
-]
+DECOMPOSITIONS = [PCA, TruncatedSVD]
 # Import Transformers
 from sklearn.feature_extraction.text import TfidfVectorizer
@@ -84,7 +92,10 @@
 from sklearn.impute import SimpleImputer
 TRANSFORMERS = [
-    DictVectorizer, QuantileTransformer, StandardScaler, SimpleImputer,
+    DictVectorizer,
+    QuantileTransformer,
+    StandardScaler,
+    SimpleImputer,
     TfidfVectorizer,
 ]
@@ -92,16 +103,12 @@
 from sklearn.pipeline import Pipeline, FeatureUnion
-PIPELINES = [
-    Pipeline, FeatureUnion,
-]
+PIPELINES = [Pipeline, FeatureUnion]
 # Import GridSearch Utilities
 from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
-SEARCH = [
-    GridSearchCV, RandomizedSearchCV,
-]
+SEARCH = [GridSearchCV, RandomizedSearchCV]
 # Other Groups
@@ -120,6 +127,7 @@ def obj_name(obj):
 ## Model type checking test cases
 ##########################################################################
+
 class TestModelTypeChecking(object):
     """
     Test model type checking utilities
@@ -146,9 +154,9 @@ def test_is_estimator(self, model):
         obj = model()
         assert is_estimator(obj)
-    @pytest.mark.parametrize("cls", [
-        list, dict, tuple, set, str, bool, int, float
-    ], ids=obj_name)
+    @pytest.mark.parametrize(
+        "cls", [list, dict, tuple, set, str, bool, 
int, float], ids=obj_name + ) def test_not_is_estimator(self, cls): """ Assert Python objects are not estimators @@ -166,10 +174,7 @@ def test_is_estimator_pipeline(self): assert is_estimator(Pipeline) assert is_estimator(FeatureUnion) - model = Pipeline([ - ('reduce_dim', PCA()), - ('linreg', LinearRegression()) - ]) + model = Pipeline([("reduce_dim", PCA()), ("linreg", LinearRegression())]) assert is_estimator(model) @@ -180,14 +185,18 @@ def test_is_estimator_search(self): assert is_estimator(GridSearchCV) assert is_estimator(RandomizedSearchCV) - model = GridSearchCV(SVR(), {'kernel': ['linear', 'rbf']}) + model = GridSearchCV(SVR(), {"kernel": ["linear", "rbf"]}) assert is_estimator(model) - @pytest.mark.parametrize("viz,params", [ - (Visualizer, {}), - (ScoreVisualizer, {'model': LinearRegression()}), - (ModelVisualizer, {'model': LogisticRegression()}) - ], ids=["Visualizer", "ScoreVisualizer", "ModelVisualizer"]) + @pytest.mark.parametrize( + "viz,params", + [ + (Visualizer, {}), + (ScoreVisualizer, {"model": LinearRegression()}), + (ModelVisualizer, {"model": LogisticRegression()}), + ], + ids=["Visualizer", "ScoreVisualizer", "ModelVisualizer"], + ) def test_is_estimator_visualizer(self, viz, params): """ Test that is_estimator works for Visualizers @@ -198,6 +207,32 @@ def test_is_estimator_visualizer(self, viz, params): obj = viz(**params) assert is_estimator(obj) + ##//////////////////////////////////////////////////////////////////// + ## is_pipeline testing + ##//////////////////////////////////////////////////////////////////// + + def test_is_pipeline_alias(self): + """ + Assert ispipeline aliases is_pipeline + """ + assert ispipeline is is_pipeline + + def test_is_pipeline(self): + assert is_pipeline(Pipeline) + + obj = Pipeline([("pca", PCA()), ("clf", LogisticRegression())]) + assert is_pipeline(obj) + + @pytest.mark.parametrize( + "model", CLASSIFIERS + CLUSTERERS + TRANSFORMERS + DECOMPOSITIONS, ids=obj_name + ) + def test_not_is_pipeline(self, model): + assert inspect.isclass(model) + assert not is_pipeline(model) + + obj = model() + assert not is_pipeline(obj) + ##//////////////////////////////////////////////////////////////////// ## is_regressor testing ##//////////////////////////////////////////////////////////////////// @@ -219,9 +254,9 @@ def test_is_regressor(self, model): obj = model() assert is_regressor(obj) - @pytest.mark.parametrize("model", - CLASSIFIERS+CLUSTERERS+TRANSFORMERS+DECOMPOSITIONS, - ids=obj_name) + @pytest.mark.parametrize( + "model", CLASSIFIERS + CLUSTERERS + TRANSFORMERS + DECOMPOSITIONS, ids=obj_name + ) def test_not_is_regressor(self, model): """ Test that is_regressor does not match non-regressor estimators @@ -239,10 +274,7 @@ def test_is_regressor_pipeline(self): assert not is_regressor(Pipeline) assert not is_regressor(FeatureUnion) - model = Pipeline([ - ('reduce_dim', PCA()), - ('linreg', LinearRegression()) - ]) + model = Pipeline([("reduce_dim", PCA()), ("linreg", LinearRegression())]) assert is_regressor(model) @@ -254,13 +286,17 @@ def test_is_regressor_search(self): assert is_regressor(GridSearchCV) assert is_regressor(RandomizedSearchCV) - model = GridSearchCV(SVR(), {'kernel': ['linear', 'rbf']}) + model = GridSearchCV(SVR(), {"kernel": ["linear", "rbf"]}) assert is_regressor(model) - @pytest.mark.parametrize("viz,params", [ - (ScoreVisualizer, {'model': LinearRegression()}), - (ModelVisualizer, {'model': Ridge()}) - ], ids=["ScoreVisualizer", "ModelVisualizer"]) + @pytest.mark.parametrize( + "viz,params", + [ + 
(ScoreVisualizer, {"model": LinearRegression()}), + (ModelVisualizer, {"model": Ridge()}), + ], + ids=["ScoreVisualizer", "ModelVisualizer"], + ) def test_is_regressor_visualizer(self, viz, params): """ Test that is_regressor works on visualizers @@ -292,9 +328,9 @@ def test_is_classifier(self, model): obj = model() assert is_classifier(obj) - @pytest.mark.parametrize("model", - REGRESSORS+CLUSTERERS+TRANSFORMERS+DECOMPOSITIONS, - ids=obj_name) + @pytest.mark.parametrize( + "model", REGRESSORS + CLUSTERERS + TRANSFORMERS + DECOMPOSITIONS, ids=obj_name + ) def test_not_is_classifier(self, model): """ Test that is_classifier does not match non-classifier estimators @@ -312,10 +348,7 @@ def test_classifier_pipeline(self): assert not is_classifier(Pipeline) assert not is_classifier(FeatureUnion) - model = Pipeline([ - ('reduce_dim', PCA()), - ('linreg', LogisticRegression()) - ]) + model = Pipeline([("reduce_dim", PCA()), ("linreg", LogisticRegression())]) assert is_classifier(model) @@ -327,13 +360,17 @@ def test_is_classifier_search(self): assert is_classifier(GridSearchCV) assert is_classifier(RandomizedSearchCV) - model = GridSearchCV(SVC(), {'kernel': ['linear', 'rbf']}) + model = GridSearchCV(SVC(), {"kernel": ["linear", "rbf"]}) assert is_classifier(model) - @pytest.mark.parametrize("viz,params", [ - (ScoreVisualizer, {'model': MultinomialNB()}), - (ModelVisualizer, {'model': MLPClassifier()}) - ], ids=["ScoreVisualizer", "ModelVisualizer"]) + @pytest.mark.parametrize( + "viz,params", + [ + (ScoreVisualizer, {"model": MultinomialNB()}), + (ModelVisualizer, {"model": MLPClassifier()}), + ], + ids=["ScoreVisualizer", "ModelVisualizer"], + ) def test_is_classifier_visualizer(self, viz, params): """ Test that is_classifier works on visualizers @@ -365,9 +402,9 @@ def test_is_clusterer(self, model): obj = model() assert is_clusterer(obj) - @pytest.mark.parametrize("model", - REGRESSORS+CLASSIFIERS+TRANSFORMERS+DECOMPOSITIONS, - ids=obj_name) + @pytest.mark.parametrize( + "model", REGRESSORS + CLASSIFIERS + TRANSFORMERS + DECOMPOSITIONS, ids=obj_name + ) def test_not_is_clusterer(self, model): """ Test that is_clusterer does not match non-clusterer estimators @@ -385,16 +422,13 @@ def test_clusterer_pipeline(self): assert not is_clusterer(Pipeline) assert not is_clusterer(FeatureUnion) - model = Pipeline([ - ('reduce_dim', PCA()), - ('kmeans', KMeans()) - ]) + model = Pipeline([("reduce_dim", PCA()), ("kmeans", KMeans())]) assert is_clusterer(model) - @pytest.mark.parametrize("viz,params", [ - (ModelVisualizer, {'model': KMeans()}) - ], ids=["ModelVisualizer"]) + @pytest.mark.parametrize( + "viz,params", [(ModelVisualizer, {"model": KMeans()})], ids=["ModelVisualizer"] + ) def test_is_clusterer_visualizer(self, viz, params): """ Test that is_clusterer works on visualizers @@ -426,8 +460,9 @@ def test_is_gridsearch(self, model): obj = model(SVC, {"C": [0.5, 1, 10]}) assert is_gridsearch(obj) - @pytest.mark.parametrize("model", - [MLPRegressor, MLPClassifier, SimpleImputer], ids=obj_name) + @pytest.mark.parametrize( + "model", [MLPRegressor, MLPClassifier, SimpleImputer], ids=obj_name + ) def test_not_is_gridsearch(self, model): """ Test that is_gridsearch does not match non grid searches @@ -448,10 +483,19 @@ def test_probabilistic_alias(self): """ assert isprobabilistic is is_probabilistic - @pytest.mark.parametrize("model", [ - MultinomialNB, GaussianNB, LogisticRegression, SVC, - RandomForestClassifier, GradientBoostingClassifier, MLPClassifier, - ], ids=obj_name) + 
@pytest.mark.parametrize( + "model", + [ + MultinomialNB, + GaussianNB, + LogisticRegression, + SVC, + RandomForestClassifier, + GradientBoostingClassifier, + MLPClassifier, + ], + ids=obj_name, + ) def test_is_probabilistic(self, model): """ Test that is_probabilistic works correctly @@ -462,10 +506,11 @@ def test_is_probabilistic(self, model): obj = model() assert is_probabilistic(obj) - @pytest.mark.parametrize("model", [ - MLPRegressor, SimpleImputer, StandardScaler, KMeans, - RandomForestRegressor, - ], ids=obj_name) + @pytest.mark.parametrize( + "model", + [MLPRegressor, SimpleImputer, StandardScaler, KMeans, RandomForestRegressor], + ids=obj_name, + ) def test_not_is_probabilistic(self, model): """ Test that is_probabilistic does not match non probablistic estimators @@ -481,6 +526,7 @@ def test_not_is_probabilistic(self, model): ## Data type checking test cases ########################################################################## + class TestDataTypeChecking(object): """ Test data type checking utilities @@ -501,22 +547,24 @@ def test_is_dataframe(self): """ Test that is_dataframe works correctly """ - df = pd.DataFrame([ - {'a': 1, 'b': 2.3, 'c': 'Hello'}, - {'a': 2, 'b': 3.14, 'c': 'World'}, - ]) + df = pd.DataFrame( + [{"a": 1, "b": 2.3, "c": "Hello"}, {"a": 2, "b": 3.14, "c": "World"}] + ) assert is_dataframe(df) - @pytest.mark.parametrize("obj", [ - np.array([ - (1,2.,'Hello'), (2,3.,"World")], - dtype=[('foo', 'i4'),('bar', 'f4'), ('baz', 'S10')] - ), - np.array([[1,2,3], [1,2,3]]), - [[1,2,3], [1,2,3]], - ], - ids=["structured array", "array", "list"]) + @pytest.mark.parametrize( + "obj", + [ + np.array( + [(1, 2.0, "Hello"), (2, 3.0, "World")], + dtype=[("foo", "i4"), ("bar", "f4"), ("baz", "S10")], + ), + np.array([[1, 2, 3], [1, 2, 3]]), + [[1, 2, 3], [1, 2, 3]], + ], + ids=["structured array", "array", "list"], + ) def test_not_is_dataframe(self, obj): """ Test that is_dataframe does not match non-dataframes @@ -542,15 +590,18 @@ def test_is_series(self): assert is_series(df) - @pytest.mark.parametrize("obj", [ - np.array([ - (1,2.,'Hello'), (2,3.,"World")], - dtype=[('foo', 'i4'),('bar', 'f4'), ('baz', 'S10')] - ), - np.array([1,2,3]), - [1, 2, 3], - ], - ids=["structured array", "array", "list"]) + @pytest.mark.parametrize( + "obj", + [ + np.array( + [(1, 2.0, "Hello"), (2, 3.0, "World")], + dtype=[("foo", "i4"), ("bar", "f4"), ("baz", "S10")], + ), + np.array([1, 2, 3]), + [1, 2, 3], + ], + ids=["structured array", "array", "list"], + ) def test_not_is_series(self, obj): """ Test that is_series does not match non-dataframes @@ -571,18 +622,16 @@ def test_is_structured_array(self): """ Test that is_structured_array works correctly """ - x = np.array([ - (1,2.,'Hello'), (2,3.,"World")], - dtype=[('foo', 'i4'),('bar', 'f4'), ('baz', 'S10')] + x = np.array( + [(1, 2.0, "Hello"), (2, 3.0, "World")], + dtype=[("foo", "i4"), ("bar", "f4"), ("baz", "S10")], ) assert is_structured_array(x) - @pytest.mark.parametrize("obj", [ - np.array([[1,2,3], [1,2,3]]), - [[1,2,3], [1,2,3]], - ], - ids=obj_name) + @pytest.mark.parametrize( + "obj", [np.array([[1, 2, 3], [1, 2, 3]]), [[1, 2, 3], [1, 2, 3]]], ids=obj_name + ) def test_not_is_structured_array(self, obj): """ Test that is_structured_array does not match non-structured-arrays diff --git a/yellowbrick/base.py b/yellowbrick/base.py index c329b86db..1a4fd8ced 100644 --- a/yellowbrick/base.py +++ b/yellowbrick/base.py @@ -22,8 +22,9 @@ from yellowbrick.utils import get_model_name from yellowbrick.utils.wrapper import Wrapper 
+from yellowbrick.utils.types import is_pipeline
 from yellowbrick.utils.helpers import check_fitted
-from yellowbrick.exceptions import YellowbrickWarning
+from yellowbrick.exceptions import YellowbrickWarning, NotFitted
 from yellowbrick.exceptions import YellowbrickValueError, YellowbrickTypeError
@@ -350,6 +351,27 @@ def fit(self, X, y=None, **kwargs):
         self.estimator.fit(X, y, **kwargs)
         return self
+    def _final_estimator(self):
+        """
+        If the wrapped estimator is a Pipeline, return the final estimator in
+        the Pipeline, otherwise return the wrapped estimator.
+        """
+        if is_pipeline(self.estimator):
+            return self.estimator.steps[-1][1]
+        return self.estimator
+
+    def _get_learned_attr(self, attr):
+        """
+        Get a learned attribute (e.g. an attribute that is not available until
+        after fit() is called) from the underlying estimator. If the attribute
+        does not exist, a NotFitted exception is raised. If the estimator is a
+        Pipeline, the attribute is fetched from the final estimator.
+        """
+        try:
+            return getattr(self._final_estimator(), attr)
+        except AttributeError as e:
+            raise NotFitted(str(e))
+
 ##########################################################################
 ## Score Visualizers
diff --git a/yellowbrick/features/__init__.py b/yellowbrick/features/__init__.py
index 7b7fc0fc2..dfca97cf5 100644
--- a/yellowbrick/features/__init__.py
+++ b/yellowbrick/features/__init__.py
@@ -17,13 +17,18 @@
 ## Imports
 ##########################################################################
+## Hoist base classes into the features namespace
+from .base import FeatureVisualizer, MultiFeatureVisualizer, DataVisualizer
+from .projection import ProjectionVisualizer
+
 ## Hoist visualizers into the features namespace
 from .pcoords import ParallelCoordinates, parallel_coordinates
 from .radviz import RadialVisualizer, RadViz, radviz
-from .rankd import Rank1D, rank1d, Rank2D, rank2d
+from .rankd import RankDBase, Rank1D, rank1d, Rank2D, rank2d
 from .jointplot import JointPlot, JointPlotVisualizer, joint_plot
 from .pca import PCA, PCADecomposition, pca_decomposition
 from .manifold import Manifold, manifold_embedding
+from .decomposition import ExplainedVariance, explained_variance_visualizer
 # Alias the TargetType defined in yellowbrick.utils.target
 from yellowbrick.utils.target import TargetType
@@ -31,4 +36,4 @@
 # RFECV and Feature Importances moved to model selection module as of YB v1.0
 from yellowbrick.model_selection.rfecv import RFECV, rfecv
 from yellowbrick.model_selection.importances import FeatureImportances
-from yellowbrick.model_selection.importances import feature_importances
+from yellowbrick.model_selection.importances import feature_importances
\ No newline at end of file
diff --git a/yellowbrick/gridsearch/__init__.py b/yellowbrick/gridsearch/__init__.py
index 80e14a7d3..9683a46cc 100644
--- a/yellowbrick/gridsearch/__init__.py
+++ b/yellowbrick/gridsearch/__init__.py
@@ -18,4 +18,5 @@
 ##########################################################################
 ## Hoist visualizers into the gridsearch namespace
+from .base import *
 from .pcolor import *
diff --git a/yellowbrick/target/__init__.py b/yellowbrick/target/__init__.py
index ca5959795..b38699509 100644
--- a/yellowbrick/target/__init__.py
+++ b/yellowbrick/target/__init__.py
@@ -18,6 +18,9 @@
 ## Imports
 ##########################################################################
+# Hoist base classes into top level
+from .base import TargetVisualizer
+
 # Hoist visualizers into the top level of the target package
 from .class_balance 
import ClassBalance, class_balance from .binning import BalancedBinningReference, balanced_binning_reference diff --git a/yellowbrick/text/__init__.py b/yellowbrick/text/__init__.py index 3a69f3f40..6bcf1e4dc 100644 --- a/yellowbrick/text/__init__.py +++ b/yellowbrick/text/__init__.py @@ -17,8 +17,9 @@ ## Imports ########################################################################## +from .base import TextVisualizer from .tsne import TSNEVisualizer, tsne from .umap_vis import UMAPVisualizer, umap -from .freqdist import FreqDistVisualizer, freqdist +from .freqdist import FrequencyVisualizer, freqdist, FreqDistVisualizer from .postag import PosTagVisualizer from .dispersion import DispersionPlot, dispersion diff --git a/yellowbrick/utils/types.py b/yellowbrick/utils/types.py index 04bb67cec..c6807419f 100644 --- a/yellowbrick/utils/types.py +++ b/yellowbrick/utils/types.py @@ -20,6 +20,7 @@ import inspect import numpy as np +from sklearn.pipeline import Pipeline from sklearn.base import BaseEstimator @@ -28,26 +29,44 @@ ########################################################################## -def is_estimator(model): +def is_estimator(obj): """ Determines if a model is an estimator using issubclass and isinstance. Parameters ---------- - estimator : class or instance - The object to test if it is a Scikit-Learn clusterer, especially a - Scikit-Learn estimator or Yellowbrick visualizer + obj : class or instance + The object to test if it is a scikit-learn estimator. """ - if inspect.isclass(model): - return issubclass(model, BaseEstimator) + if inspect.isclass(obj): + return issubclass(obj, BaseEstimator) - return isinstance(model, BaseEstimator) + return isinstance(obj, BaseEstimator) # Alias for closer name to isinstance and issubclass isestimator = is_estimator +def is_pipeline(obj): + """ + Determines if a model is a pipeline using issubclass and isinstance. + + Parameters + ---------- + obj : class or instance + The object to test if it is a scikit-learn Pipeline. + """ + if inspect.isclass(obj): + return issubclass(obj, Pipeline) + + return isinstance(obj, Pipeline) + + +# Alias for closer name to isinstance and issubclass +ispipeline = is_pipeline + + def is_classifier(estimator): """ Returns True if the given estimator is (probably) a classifier. @@ -55,8 +74,7 @@ def is_classifier(estimator): Parameters ---------- estimator : class or instance - The object to test if it is a Scikit-Learn clusterer, especially a - Scikit-Learn estimator or Yellowbrick visualizer + The object to test if it is a scikit-learn classifier. See also -------- @@ -79,8 +97,7 @@ def is_regressor(estimator): Parameters ---------- estimator : class or instance - The object to test if it is a Scikit-Learn clusterer, especially a - Scikit-Learn estimator or Yellowbrick visualizer + The object to test if it is a scikit-learn regressor. See also -------- @@ -103,8 +120,7 @@ def is_clusterer(estimator): Parameters ---------- estimator : class or instance - The object to test if it is a Scikit-Learn clusterer, especially a - Scikit-Learn estimator or Yellowbrick visualizer + The object to test if it is a scikit-learn clusterer. """ # Test the _estimator_type property @@ -122,8 +138,7 @@ def is_gridsearch(estimator): Parameters ---------- estimator : class or instance - The object to test if it is a Scikit-Learn clusterer, especially a - Scikit-Learn estimator or Yellowbrick visualizer + The object to test if it is a scikit-learn grid search. """ from sklearn.model_selection import GridSearchCV, RandomizedSearchCV