Skip to content

Commit

Permalink
Fixing sklearn error when using RandomizedSearchCV
Browse files Browse the repository at this point in the history
Signed-off-by: Luis França <[email protected]>
  • Loading branch information
luisffranca committed Jan 14, 2022
1 parent 2a76307 commit b7a99af
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 2 deletions.
19 changes: 17 additions & 2 deletions python/interpret-core/interpret/glassbox/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
from abc import abstractmethod
from sklearn.base import is_classifier
import numpy as np
from sklearn.base import ClassifierMixin, RegressorMixin
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
from sklearn.linear_model import LogisticRegression as SKLogistic
from sklearn.linear_model import Lasso as SKLinear


class BaseLinear:
class BaseLinear(BaseEstimator):
""" Base linear model.
Currently wrapper around linear models in scikit-learn.
Expand Down Expand Up @@ -43,11 +43,26 @@ def __init__(
self.linear_class = linear_class
self.kwargs = kwargs

for key, value in self.kwargs.items():
setattr(self, key, value)

@abstractmethod
def _model(self):
# This method should be overridden.
return None

# get_params and set_params are usually inherited from BaseEstimator, but they will
# fail here due to the **kwargs in the __init__. Therefore, we implement them.
def get_params(self, deep=True):
    """Return the current value of every keyword argument held in ``self.kwargs``.

    Only the names stored in ``self.kwargs`` are reported; for each name the
    value is read from the attribute of the same name on this instance.

    Args:
        deep: Accepted so the signature matches scikit-learn's ``get_params``;
            it is not used by this implementation.

    Returns:
        A dict mapping each keyword-argument name to its current attribute value.
    """
    current = {}
    for name in self.kwargs:
        current[name] = getattr(self, name)
    return current

def set_params(self, **parameters):
    """Set parameters on this estimator and keep ``self.kwargs`` in sync.

    Each parameter is stored as an attribute of the same name. In addition,
    the value is recorded in ``self.kwargs`` so that ``get_params`` (which
    enumerates ``self.kwargs``) reports parameters that were introduced via
    ``set_params`` rather than through the constructor. Without this, a
    ``set_params(alpha=...)`` on an estimator built with no keyword arguments
    would be invisible to ``get_params``.

    Args:
        **parameters: Parameter names mapped to their new values.

    Returns:
        This instance, to allow call chaining.
    """
    for name, value in parameters.items():
        # Names that already exist as attributes but are not keyword
        # arguments (e.g. ``linear_class``, assigned in ``__init__``) are
        # only updated as attributes; they must not leak into ``kwargs``.
        if name in self.kwargs or not hasattr(self, name):
            self.kwargs[name] = value
        setattr(self, name, value)

    return self

def fit(self, X, y):
""" Fits model to provided instances.
Expand Down
59 changes: 59 additions & 0 deletions python/interpret-core/interpret/glassbox/test/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from sklearn.datasets import load_breast_cancer, load_boston
from sklearn.linear_model import LogisticRegression as SKLogistic
from sklearn.linear_model import Lasso as SKLinear
from sklearn.model_selection import RandomizedSearchCV
import numpy as np


Expand Down Expand Up @@ -38,6 +39,35 @@ def test_linear_regression():
assert global_viz is not None


def test_linear_regression_sklearn_compatibility():
    """Run RandomizedSearchCV on sklearn's Lasso and on our LinearRegression.

    Both searches use the same parameter distributions and random state, and
    the predictions of the two fitted searches must agree.
    """
    dataset = load_boston()
    features, target = dataset.data, dataset.target

    param_space = {
        'max_iter': [250, 500],
        'alpha': [0.1, 0.25, 0.5, 1],
    }

    # Reference estimator first, then ours; both are built before any fitting.
    estimators = (SKLinear(), LinearRegression())

    predictions = []
    for estimator in estimators:
        search = RandomizedSearchCV(
            estimator=estimator,
            param_distributions=param_space,
            random_state=2022,
        )
        search.fit(features, target)
        predictions.append(search.predict(features))

    reference_pred, wrapper_pred = predictions
    assert np.allclose(reference_pred, wrapper_pred)


def test_logistic_regression():
cancer = load_breast_cancer()
X, y = cancer.data, cancer.target
Expand Down Expand Up @@ -72,6 +102,35 @@ def test_logistic_regression():
assert global_viz is not None


def test_logistic_regression_sklearn_compatibility():
    """Run RandomizedSearchCV on sklearn's and our LogisticRegression.

    Both searches use the same parameter distributions and random state, and
    the predicted probabilities of the two fitted searches must agree.
    """
    dataset = load_breast_cancer()
    features, target = dataset.data, dataset.target

    param_space = {
        'penalty': ['l1', 'l2'],
        'C': [1, 0.5, 0.1, 0.05, 0.01],
    }

    # Reference estimator first, then ours; both are built before any fitting.
    estimators = (SKLogistic(), LogisticRegression())

    probabilities = []
    for estimator in estimators:
        search = RandomizedSearchCV(
            estimator=estimator,
            param_distributions=param_space,
            random_state=2022,
        )
        search.fit(features, target)
        probabilities.append(search.predict_proba(features))

    reference_proba, wrapper_proba = probabilities
    assert np.allclose(reference_proba, wrapper_proba)


def test_sorting():
cancer = load_breast_cancer()
X, y = cancer.data, cancer.target
Expand Down

0 comments on commit b7a99af

Please sign in to comment.