MNT n_features_in_ consistency in decomposition #18557
Conversation
@@ -324,4 +323,5 @@ def test_strict_mode_parametrize_with_checks(estimator, check):
 @pytest.mark.parametrize("estimator", N_FEATURES_IN_AFTER_FIT_ESTIMATORS,
                          ids=_get_check_estimator_ids)
 def test_check_n_features_in_after_fitting(estimator):
+    _set_checking_parameters(estimator)
@thomasjpfan if you work on other modules you might need this, as long as this check is not part of the list of standard checks.
Thanks for fixing the title of this PR :)
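For context, the consistency this check enforces can be sketched with a toy estimator (hypothetical class, not sklearn's actual base class or check): `fit` records `n_features_in_` (the `reset=True` behaviour of `_validate_data`), and `transform` validates against it (the `reset=False` behaviour).

```python
import numpy as np

class ToyTransformer:
    """Minimal sketch of the n_features_in_ pattern; illustrative only."""

    def fit(self, X):
        X = np.asarray(X)
        # reset=True behaviour: remember the feature count seen at fit time
        self.n_features_in_ = X.shape[1]
        return self

    def transform(self, X):
        X = np.asarray(X)
        # reset=False behaviour: reject inputs whose width changed after fit
        if X.shape[1] != self.n_features_in_:
            raise ValueError(
                f"X has {X.shape[1]} features, but this estimator was "
                f"fitted with {self.n_features_in_} features."
            )
        return X

est = ToyTransformer().fit(np.zeros((5, 3)))
est.transform(np.zeros((2, 3)))  # same width: passes
```

Calling `est.transform` with a different number of columns raises a `ValueError`, which is what the parametrized test asserts for each estimator.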
Force-pushed from 108c7a4 to 2fb837d.
@@ -1347,6 +1347,9 @@ def transform(self, X):
             Transformed data.
         """
         check_is_fitted(self)
+        X = self._validate_data(X, accept_sparse=('csr', 'csc'),
+                                dtype=[np.float64, np.float32],
+                                reset=False)

         W, _, n_iter_ = non_negative_factorization(
non_negative_factorization also calls check_array internally, so calling _validate_data here causes some performance overhead. For simplicity's sake I don't want to optimize this as part of this PR, but we might want to add a kwarg to non_negative_factorization to skip input validation.
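The suggested kwarg could look roughly like this (hypothetical signature and name; the real `non_negative_factorization` has many more parameters):

```python
import numpy as np

def nmf_like(X, check_input=True):
    """Hypothetical sketch of the suggested kwarg: callers that already
    validated X (e.g. via _validate_data in transform) can pass
    check_input=False to skip the redundant check_array-style pass."""
    if check_input:
        X = np.asarray(X, dtype=np.float64)
        if X.min() < 0:
            raise ValueError("Negative values in data passed to NMF")
    # ... the actual factorization would run here ...
    return X

X = np.arange(6.0).reshape(2, 3)
nmf_like(X)                      # validated path
nmf_like(X, check_input=False)   # caller vouches for X; no second check
```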
maybe write that as a TODO comment?
This is most likely going to appear in other places. Another solution would be to have a private _non_negative_factorization that does not call check_array, but still checks for non-negative values.
I added a comment.
Thanks @ogrisel, a few questions but looks good
@@ -124,7 +123,7 @@ def transform(self, X):
         """
         check_is_fitted(self)

-        X = check_array(X)
+        X = self._validate_data(X, dtype=[np.float64, np.float32], reset=False)
do we need to pass the dtype here?
I see no reason why this code would not work properly in float32, so this is a slight performance improvement.
Note that the fit of PCA accepts float32 without upcasting so this change makes transform consistent with fit.
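The `dtype=[np.float64, np.float32]` semantics (as in `check_array`) keep the input dtype when it is in the list and otherwise cast to the first entry; a numpy-only sketch of that rule (function name is illustrative):

```python
import numpy as np

def coerce_dtype(X, allowed=(np.float64, np.float32)):
    """Sketch of check_array's dtype-list rule: keep X's dtype when it is
    in the allowed list, otherwise cast to the first entry. This is what
    lets float32 input flow through transform without an upcast."""
    X = np.asarray(X)
    if X.dtype in [np.dtype(d) for d in allowed]:
        return X  # e.g. float32 stays float32: no silent upcast
    return X.astype(allowed[0])  # e.g. int64 input becomes float64
```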
        Returns
        -------
        X_new : ndarray of shape (n_samples, n_components)
        """
        check_is_fitted(self)

-        X = check_array(X, copy=copy, dtype=FLOAT_DTYPES)
+        X = self._validate_data(X, copy=(copy and self.whiten),
Isn't this a change of behavior?
This is a small memory efficiency optimization :)
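The idea behind `copy=(copy and self.whiten)`: only the whitening path mutates the array in place, so a defensive copy is needed only when both flags are true. A sketch under that assumption (names are illustrative, not PCA's actual code):

```python
import numpy as np

def center_and_maybe_whiten(X, mean, scale, copy=True, whiten=False):
    """Sketch of the copy=(copy and whiten) idea: only the whitening
    branch mutates X in place, so copying is only needed then."""
    if whiten:
        if copy:
            X = X.copy()  # protect the caller's array before in-place ops
        X -= mean         # in-place centering
        X /= scale        # in-place scaling (stand-in for the PCA scaling)
        return X
    return X - mean       # out-of-place: caller's X is untouched anyway

X = np.ones((2, 2))
out = center_and_maybe_whiten(X, mean=1.0, scale=2.0, copy=True, whiten=True)
# X is still all ones; without the copy it would have been overwritten
```

Skipping the copy when `whiten=False` saves one full-size allocation per `transform` call without changing the result.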
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
thanks @ogrisel, LGTM when green!
LGTM
Early PR to let @thomasjpfan (and others) know that I started working on this module.
Builds on #18514.