Skip to content

Commit

Permalink
ENH Adds n_features_in_ checks to linear and svm modules (scikit-lear…
Browse files Browse the repository at this point in the history
…n#18578)

Co-authored-by: Olivier Grisel <[email protected]>
Co-authored-by: Olivier Grisel <[email protected]>
  • Loading branch information
3 people authored Jan 2, 2021
1 parent 8f72c2a commit 5946f8b
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 32 deletions.
11 changes: 3 additions & 8 deletions sklearn/linear_model/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,8 @@ def fit(self, X, y):
def _decision_function(self, X):
check_is_fitted(self)

X = check_array(X, accept_sparse=['csr', 'csc', 'coo'])
X = self._validate_data(X, accept_sparse=['csr', 'csc', 'coo'],
reset=False)
return safe_sparse_dot(X, self.coef_.T,
dense_output=True) + self.intercept_

Expand Down Expand Up @@ -281,13 +282,7 @@ class would be predicted.
"""
check_is_fitted(self)

X = check_array(X, accept_sparse='csr')

n_features = self.coef_.shape[1]
if X.shape[1] != n_features:
raise ValueError("X has %d features per sample; expecting %d"
% (X.shape[1], n_features))

X = self._validate_data(X, accept_sparse='csr', reset=False)
scores = safe_sparse_dot(X, self.coef_.T,
dense_output=True) + self.intercept_
return scores.ravel() if scores.shape[1] == 1 else scores
Expand Down
13 changes: 6 additions & 7 deletions sklearn/linear_model/_glm/glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import scipy.optimize

from ...base import BaseEstimator, RegressorMixin
from ...utils import check_array, check_X_y
from ...utils.optimize import _check_optimize_result
from ...utils.validation import check_is_fitted, _check_sample_weight
from ..._loss.glm_distribution import (
Expand Down Expand Up @@ -221,9 +220,9 @@ def fit(self, X, y, sample_weight=None):
family = self._family_instance
link = self._link_instance

X, y = check_X_y(X, y, accept_sparse=['csc', 'csr'],
dtype=[np.float64, np.float32],
y_numeric=True, multi_output=False)
X, y = self._validate_data(X, y, accept_sparse=['csc', 'csr'],
dtype=[np.float64, np.float32],
y_numeric=True, multi_output=False)

weights = _check_sample_weight(sample_weight, X)

Expand Down Expand Up @@ -311,9 +310,9 @@ def _linear_predictor(self, X):
Returns predicted values of linear predictor.
"""
check_is_fitted(self)
X = check_array(X, accept_sparse=['csr', 'csc', 'coo'],
dtype=[np.float64, np.float32], ensure_2d=True,
allow_nd=False)
X = self._validate_data(X, accept_sparse=['csr', 'csc', 'coo'],
dtype=[np.float64, np.float32], ensure_2d=True,
allow_nd=False, reset=False)
return X @ self.coef_ + self.intercept_

def predict(self, X):
Expand Down
18 changes: 9 additions & 9 deletions sklearn/linear_model/_stochastic_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ._base import LinearClassifierMixin, SparseCoefMixin
from ._base import make_dataset
from ..base import BaseEstimator, RegressorMixin
from ..utils import check_array, check_random_state, check_X_y
from ..utils import check_random_state
from ..utils.extmath import safe_sparse_dot
from ..utils.multiclass import _check_partial_fit_first_call
from ..utils.validation import check_is_fitted, _check_sample_weight
Expand Down Expand Up @@ -491,8 +491,10 @@ def _partial_fit(self, X, y, alpha, C,
loss, learning_rate, max_iter,
classes, sample_weight,
coef_init, intercept_init):
X, y = check_X_y(X, y, accept_sparse='csr', dtype=np.float64,
order="C", accept_large_sparse=False)
first_call = not hasattr(self, "classes_")
X, y = self._validate_data(X, y, accept_sparse='csr', dtype=np.float64,
order="C", accept_large_sparse=False,
reset=first_call)

n_samples, n_features = X.shape

Expand Down Expand Up @@ -1138,22 +1140,20 @@ def __init__(self, loss="squared_loss", *, penalty="l2", alpha=0.0001,

def _partial_fit(self, X, y, alpha, C, loss, learning_rate,
max_iter, sample_weight, coef_init, intercept_init):
first_call = getattr(self, "coef_", None) is None
X, y = self._validate_data(X, y, accept_sparse="csr", copy=False,
order='C', dtype=np.float64,
accept_large_sparse=False)
accept_large_sparse=False, reset=first_call)
y = y.astype(np.float64, copy=False)

n_samples, n_features = X.shape

sample_weight = _check_sample_weight(sample_weight, X)

# Allocate datastructures from input arguments
if getattr(self, "coef_", None) is None:
if first_call:
self._allocate_parameter_mem(1, n_features, coef_init,
intercept_init)
elif n_features != self.coef_.shape[-1]:
raise ValueError("Number of features %d does not match previous "
"data %d." % (n_features, self.coef_.shape[-1]))
if self.average > 0 and getattr(self, "_average_coef", None) is None:
self._average_coef = np.zeros(n_features,
dtype=np.float64,
Expand Down Expand Up @@ -1269,7 +1269,7 @@ def _decision_function(self, X):
"""
check_is_fitted(self)

X = check_array(X, accept_sparse='csr')
X = self._validate_data(X, accept_sparse='csr', reset=False)

scores = safe_sparse_dot(X, self.coef_.T,
dense_output=True) + self.intercept_
Expand Down
9 changes: 3 additions & 6 deletions sklearn/svm/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,8 +471,9 @@ def _validate_for_predict(self, X):
check_is_fitted(self)

if not callable(self.kernel):
X = check_array(X, accept_sparse='csr', dtype=np.float64,
order="C", accept_large_sparse=False)
X = self._validate_data(X, accept_sparse='csr', dtype=np.float64,
order="C", accept_large_sparse=False,
reset=False)

if self._sparse and not sp.isspmatrix(X):
X = sp.csr_matrix(X)
Expand All @@ -489,10 +490,6 @@ def _validate_for_predict(self, X):
raise ValueError("X.shape[1] = %d should be equal to %d, "
"the number of samples at training time" %
(X.shape[1], self.shape_fit_[0]))
elif not callable(self.kernel) and X.shape[1] != self.shape_fit_[1]:
raise ValueError("X.shape[1] = %d should be equal to %d, "
"the number of features at training time" %
(X.shape[1], self.shape_fit_[1]))
return X

@property
Expand Down
2 changes: 0 additions & 2 deletions sklearn/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,6 @@ def test_search_cv(estimator, check, request):
'feature_extraction',
'feature_selection',
'isotonic',
'linear_model',
'manifold',
'mixture',
'model_selection',
Expand All @@ -284,7 +283,6 @@ def test_search_cv(estimator, check, request):
'pipeline',
'random_projection',
'semi_supervised',
'svm',
}

N_FEATURES_IN_AFTER_FIT_ESTIMATORS = [
Expand Down

0 comments on commit 5946f8b

Please sign in to comment.