Skip to content

Commit

Permalink
test to sklearn perm imp with cv and sample weight
Browse files Browse the repository at this point in the history
  • Loading branch information
rg2410 committed Jan 20, 2020
1 parent 64bee30 commit f587bfa
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion tests/test_sklearn_permutation_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
from sklearn.base import is_classifier, is_regressor
from sklearn.svm import SVR, SVC
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.feature_selection import SelectFromModel
Expand Down Expand Up @@ -165,6 +165,7 @@ def test_explain_weights(iris_train):
for _expl in res:
assert "petal width (cm)" in _expl


def test_pandas_xgboost_support(iris_train):
xgboost = pytest.importorskip('xgboost')
pd = pytest.importorskip('pandas')
Expand All @@ -175,3 +176,17 @@ def test_pandas_xgboost_support(iris_train):
est.fit(X, y)
# we expect no exception to be raised here when using xgboost with pd.DataFrame
perm = PermutationImportance(est).fit(X, y)


def test_cv_sample_weight(iris_train):
X, y, feature_names, target_names = iris_train
weights_ones = np.ones(len(y))
model = RandomForestClassifier(random_state=42)

# we expect no exception to be raised when passing weights with a CV
perm_weights = PermutationImportance(model, cv=5, random_state=42).\
fit(X, y, sample_weight=weights_ones)
perm = PermutationImportance(model, cv=5, random_state=42).fit(X, y)

# passing a vector of weights filled with one should be the same as passing no weights
assert (perm.feature_importances_ == perm_weights.feature_importances_).all()

0 comments on commit f587bfa

Please sign in to comment.