Commit b84afe5

Authored by afonari (Sasha Fonari), cmarmo (Chiara Marmo) and glemaitre (Guillaume Lemaitre)

FIX prevent division by zero with constant target in GPR (#19703)

Co-authored-by: Sasha Fonari <fonari@schrodinger.com>
Co-authored-by: Chiara Marmo <cmarmo@users.noreply.github.com>
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>

1 parent: d5ebdca

File tree

3 files changed: +38 additions, -1 deletion

doc/whats_new/v0.24.rst (+11 lines, -0)

@@ -32,6 +32,17 @@ Changelog
 - |Fix| Fixed a bug in :class:`decomposition.KernelPCA`'s
   ``inverse_transform``. :pr:`19732` by :user:`Kei Ishikawa <kstoneriv3>`.
 
+:mod:`sklearn.gaussian_process`
+...............................
+
+- |Fix| Avoid division by zero when scaling a constant target in
+  :class:`gaussian_process.GaussianProcessRegressor`. It was caused by a
+  std. dev. equal to 0. Such a case is now detected and the std. dev. is
+  set to 1, avoiding a division by zero and thus the presence of NaN
+  values in the normalized target.
+  :pr:`19703` by :user:`sobkevich`, :user:`Boris Villazón-Terrazas <boricles>`
+  and :user:`Alexandr Fonari <afonari>`.
+
 :mod:`sklearn.linear_model`
 ...........................
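The failure mode described in the changelog entry is easy to reproduce with plain NumPy. The snippet below is a standalone sketch, not scikit-learn's actual code: it shows why a zero std. dev. turns every normalized target value into NaN, and how substituting 1 for the zero scale makes the division a harmless no-op.

```python
import numpy as np

# A constant target has zero standard deviation, so naive z-scoring
# computes 0/0 for every sample and fills the target with NaN.
y = np.ones(5)
std = y.std(axis=0)                      # 0.0
with np.errstate(invalid="ignore"):
    naive = (y - y.mean(axis=0)) / std   # NaN everywhere
print(np.isnan(naive).all())             # True

# The fix: treat a zero scale as 1 so the division is a no-op.
safe_std = std if std != 0.0 else 1.0
fixed = (y - y.mean(axis=0)) / safe_std
print(fixed)                             # [0. 0. 0. 0. 0.]
```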

sklearn/gaussian_process/_gpr.py (+4 lines, -1)

@@ -14,6 +14,7 @@
 from ..base import BaseEstimator, RegressorMixin, clone
 from ..base import MultiOutputMixin
 from .kernels import RBF, ConstantKernel as C
+from ..preprocessing._data import _handle_zeros_in_scale
 from ..utils import check_random_state
 from ..utils.optimize import _check_optimize_result
 from ..utils.validation import _deprecate_positional_args
@@ -197,7 +198,9 @@ def fit(self, X, y):
         # Normalize target value
         if self.normalize_y:
             self._y_train_mean = np.mean(y, axis=0)
-            self._y_train_std = np.std(y, axis=0)
+            self._y_train_std = _handle_zeros_in_scale(
+                np.std(y, axis=0), copy=False
+            )
 
             # Remove mean and make unit variance
             y = (y - self._y_train_mean) / self._y_train_std
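`_handle_zeros_in_scale` is a private scikit-learn helper (its exact signature may vary between versions), but the idea it implements is small. Below is a minimal stand-in, written for illustration only: any zero entry in the scale is replaced by 1, so dividing by the returned scale never produces NaN or infinity.

```python
import numpy as np

def handle_zeros_in_scale(scale):
    """Simplified stand-in for the private helper
    ``sklearn.preprocessing._data._handle_zeros_in_scale``:
    replace zero scale entries with 1 so division is a no-op."""
    if np.isscalar(scale):
        return 1.0 if scale == 0.0 else scale
    scale = np.asarray(scale, dtype=float).copy()
    scale[scale == 0.0] = 1.0
    return scale

print(handle_zeros_in_scale(0.0))                   # 1.0
print(handle_zeros_in_scale(np.array([2.0, 0.0])))  # [2. 1.]
```

The real helper also takes a `copy` flag (used as `copy=False` in the diff above) to avoid allocating a new array; the sketch always copies for simplicity.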

sklearn/gaussian_process/tests/test_gpr.py (+23 lines, -0)

@@ -546,3 +546,26 @@ def test_bound_check_fixed_hyperparameter():
                                  periodicity_bounds="fixed")  # seasonal component
     kernel = k1 + k2
     GaussianProcessRegressor(kernel=kernel).fit(X, y)
+
+
+# FIXME: we should test for multi-target as well. However, GPR is broken:
+# see: https://github.com/scikit-learn/scikit-learn/pull/19706
+@pytest.mark.parametrize('kernel', kernels)
+def test_constant_target(kernel):
+    """Check that the std. dev. is set to 1 when normalizing a constant
+    target.
+
+    Non-regression test for:
+    https://github.com/scikit-learn/scikit-learn/issues/18318
+    NaN values were introduced in the target when scaling, because the
+    std. dev. of a constant target is zero.
+    """
+    y_constant = np.ones(X.shape[0], dtype=np.float64)
+
+    gpr = GaussianProcessRegressor(kernel=kernel, normalize_y=True)
+    gpr.fit(X, y_constant)
+    assert gpr._y_train_std == pytest.approx(1.0)
+
+    y_pred, y_cov = gpr.predict(X, return_cov=True)
+    assert_allclose(y_pred, y_constant)
+    # set atol because we compare to zero
+    assert_allclose(np.diag(y_cov), 0., atol=1e-9)
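From a user's point of view, the fix means that fitting on a constant target with `normalize_y=True` now yields finite predictions equal to that constant, instead of NaN. A short usage sketch, assuming a scikit-learn version that includes this fix (0.24.2 or later):

```python
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor

# Constant target: its normalized version is all zeros, so the GP's
# posterior mean is 0, which de-normalizes back to the constant.
X = np.arange(5.0).reshape(-1, 1)
y = np.ones(X.shape[0])

gpr = GaussianProcessRegressor(normalize_y=True).fit(X, y)
y_pred = gpr.predict(X)
print(np.isfinite(y_pred).all())   # True
print(np.allclose(y_pred, 1.0))    # True
```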
