Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 96a96f1

Browse filesBrowse files
authored
ENH Adds n_features_in_ checking in cross_decomposition (#18741)
1 parent 71f7085 commit 96a96f1
Copy full SHA for 96a96f1

File tree

3 files changed

+11
-7
lines changed
Filter options

3 files changed

+11
-7
lines changed

‎sklearn/cross_decomposition/_pls.py

Copy file name to clipboardExpand all lines: sklearn/cross_decomposition/_pls.py
+3-3Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def transform(self, X, Y=None, copy=True):
317317
`x_scores` if `Y` is not given, `(x_scores, y_scores)` otherwise.
318318
"""
319319
check_is_fitted(self)
320-
X = check_array(X, copy=copy, dtype=FLOAT_DTYPES)
320+
X = self._validate_data(X, copy=copy, dtype=FLOAT_DTYPES, reset=False)
321321
# Normalize
322322
X -= self._x_mean
323323
X /= self._x_std
@@ -379,7 +379,7 @@ def predict(self, X, copy=True):
379379
space.
380380
"""
381381
check_is_fitted(self)
382-
X = check_array(X, copy=copy, dtype=FLOAT_DTYPES)
382+
X = self._validate_data(X, copy=copy, dtype=FLOAT_DTYPES, reset=False)
383383
# Normalize
384384
X -= self._x_mean
385385
X /= self._x_std
@@ -984,7 +984,7 @@ def transform(self, X, Y=None):
984984
`(X_transformed, Y_transformed)` otherwise.
985985
"""
986986
check_is_fitted(self)
987-
X = check_array(X, dtype=np.float64)
987+
X = self._validate_data(X, dtype=np.float64, reset=False)
988988
Xr = (X - self._x_mean) / self._x_std
989989
x_scores = np.dot(Xr, self.x_weights_)
990990
if Y is not None:

‎sklearn/tests/test_common.py

Copy file name to clipboardExpand all lines: sklearn/tests/test_common.py
-1Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,6 @@ def test_search_cv(estimator, check, request):
267267
'calibration',
268268
'compose',
269269
'covariance',
270-
'cross_decomposition',
271270
'discriminant_analysis',
272271
'ensemble',
273272
'feature_extraction',

‎sklearn/utils/estimator_checks.py

Copy file name to clipboardExpand all lines: sklearn/utils/estimator_checks.py
+8-3Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
load_iris,
6464
make_blobs,
6565
make_multilabel_classification,
66-
make_regression,
66+
make_regression
6767
)
6868

6969
REGRESSION_DATASET = None
@@ -646,6 +646,9 @@ def _set_checking_parameters(estimator):
646646
if name == 'OneHotEncoder':
647647
estimator.set_params(handle_unknown='ignore')
648648

649+
if name in CROSS_DECOMPOSITION:
650+
estimator.set_params(n_components=1)
651+
649652

650653
class _NotAnArray:
651654
"""An object that is convertible to an array.
@@ -3122,9 +3125,11 @@ def check_n_features_in_after_fitting(name, estimator_orig):
31223125
if 'warm_start' in estimator.get_params():
31233126
estimator.set_params(warm_start=False)
31243127

3125-
n_samples = 100
3126-
X = rng.normal(loc=100, size=(n_samples, 2))
3128+
n_samples = 150
3129+
X = rng.normal(size=(n_samples, 8))
3130+
X = _enforce_estimator_tags_x(estimator, X)
31273131
X = _pairwise_estimator_convert_X(X, estimator)
3132+
31283133
if is_regressor(estimator):
31293134
y = rng.normal(size=n_samples)
31303135
else:

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.