Commit fefd804

FIX Use cho_solve when return_std=True for GaussianProcessRegressor (#19939)

iwhalvic authored and glemaitre committed (1 parent: eda2153)

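For context before the diffs: the "# Line N" comments in the code changes below refer to the steps of the Cholesky-based GP prediction algorithm (Algorithm 2.1 in Rasmussen & Williams, "Gaussian Processes for Machine Learning", which the estimator's docstring cites as its reference). A minimal NumPy/SciPy sketch of those steps; gpml_predict is a hypothetical helper for illustration, not scikit-learn code:

import numpy as np
from scipy.linalg import cholesky, cho_solve, solve_triangular

def gpml_predict(K, k_star, y, prior_var_diag):
    """K: (n, n) training covariance; k_star: (m, n) cross-covariance;
    y: (n,) targets; prior_var_diag: (m,) prior variances k(x_*, x_*)."""
    L = cholesky(K, lower=True)                     # Line 2
    alpha = cho_solve((L, True), y)                 # Line 3: alpha = K^{-1} y
    f_mean = k_star @ alpha                         # Line 4: posterior mean
    v = solve_triangular(L, k_star.T, lower=True)   # Line 5: v = L \ k_star
    f_var = prior_var_diag - np.einsum("ij,ij->j", v, v)  # Line 6
    return f_mean, f_var

The patch below keeps this structure but merges Lines 5 and 6 into a single cho_solve against the full K (V = K^{-1} k_star^T), which is algebraically equivalent because K^{-1} = L^{-T} L^{-1}.
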
File tree: 3 files changed (50 additions, 41 deletions)

doc/whats_new/v0.24.rst (8 additions, 1 deletion)

@@ -45,6 +45,13 @@ Changelog
 :mod:`sklearn.gaussian_process`
 ...............................

+- |Fix| Avoid explicitly forming inverse covariance matrix in
+  :class:`gaussian_process.GaussianProcessRegressor` when set to output
+  standard deviation. With certain covariance matrices this inverse is unstable
+  to compute explicitly. Calling Cholesky solver mitigates this issue in
+  computation.
+  :pr:`19939` by :user:`Ian Halvic <iwhalvic>`.
+
 - |Fix| Avoid division by zero when scaling constant target in
   :class:`gaussian_process.GaussianProcessRegressor`. It was due to a std. dev.
   equal to 0. Now, such case is detected and the std. dev. is affected to 1
@@ -59,7 +66,7 @@ Changelog
 - |Fix|: Fixed a bug in :class:`linear_model.LogisticRegression`: the
   sample_weight object is not modified anymore. :pr:`19182` by
   :user:`Yosuke KOBAYASHI <m7142yosuke>`.
-
+
 :mod:`sklearn.metrics`
 ......................
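
To make the changelog entry above concrete: the old code built the inverse covariance matrix explicitly from the inverted Cholesky factor, while the fix solves a linear system against the factorization instead. A minimal sketch of the two paths on a deliberately ill-conditioned kernel matrix; the rbf helper and the toy data are illustrative assumptions, not scikit-learn API:

import numpy as np
from scipy.linalg import cholesky, cho_solve, solve_triangular

rng = np.random.default_rng(0)
X_train = rng.normal(size=(6, 2))
X_test = rng.normal(size=(4, 2))

def rbf(A, B, length_scale=5.0):
    # Toy RBF kernel; the large length scale makes K nearly singular.
    d2 = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)
    return np.exp(-0.5 * d2 / length_scale ** 2)

K = rbf(X_train, X_train) + 1e-10 * np.eye(6)   # ill-conditioned covariance
K_trans = rbf(X_test, X_train)                  # test/train cross-covariance
L = cholesky(K, lower=True)

# Old path: explicitly form K^{-1} from the inverted Cholesky factor.
L_inv = solve_triangular(L.T, np.eye(L.shape[0]))
K_inv = L_inv @ L_inv.T
var_old = rbf(X_test, X_test).diagonal() - np.einsum(
    "ij,ij->i", K_trans @ K_inv, K_trans)

# New path: solve K @ V = K_trans.T directly from the factorization.
V = cho_solve((L, True), K_trans.T)
var_new = rbf(X_test, X_test).diagonal() - np.einsum("ij,ji->i", K_trans, V)

# With conditioning this bad the two variances can disagree noticeably;
# the cho_solve path is the numerically trustworthy one.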

sklearn/gaussian_process/_gpr.py (9 additions, 15 deletions)

@@ -8,7 +8,7 @@
 from operator import itemgetter

 import numpy as np
-from scipy.linalg import cholesky, cho_solve, solve_triangular
+from scipy.linalg import cholesky, cho_solve
 import scipy.optimize

 from ..base import BaseEstimator, RegressorMixin, clone
@@ -271,8 +271,6 @@ def obj_func(theta, eval_gradient=True):
         K[np.diag_indices_from(K)] += self.alpha
         try:
             self.L_ = cholesky(K, lower=True)  # Line 2
-            # self.L_ changed, self._K_inv needs to be recomputed
-            self._K_inv = None
         except np.linalg.LinAlgError as exc:
             exc.args = ("The kernel, %s, is not returning a "
                         "positive definite matrix. Try gradually "
@@ -345,31 +343,27 @@ def predict(self, X, return_std=False, return_cov=False):
         else:  # Predict based on GP posterior
             K_trans = self.kernel_(X, self.X_train_)
             y_mean = K_trans.dot(self.alpha_)  # Line 4 (y_mean = f_star)
-
             # undo normalisation
             y_mean = self._y_train_std * y_mean + self._y_train_mean

             if return_cov:
-                v = cho_solve((self.L_, True), K_trans.T)  # Line 5
-                y_cov = self.kernel_(X) - K_trans.dot(v)  # Line 6
+                # Solve K @ V = K_trans.T
+                V = cho_solve((self.L_, True), K_trans.T)  # Line 5
+                y_cov = self.kernel_(X) - K_trans.dot(V)  # Line 6

                 # undo normalisation
                 y_cov = y_cov * self._y_train_std**2

                 return y_mean, y_cov
             elif return_std:
-                # cache result of K_inv computation
-                if self._K_inv is None:
-                    # compute inverse K_inv of K based on its Cholesky
-                    # decomposition L and its inverse L_inv
-                    L_inv = solve_triangular(self.L_.T,
-                                             np.eye(self.L_.shape[0]))
-                    self._K_inv = L_inv.dot(L_inv.T)
+                # Solve K @ V = K_trans.T
+                V = cho_solve((self.L_, True), K_trans.T)  # Line 5

                 # Compute variance of predictive distribution
+                # Use einsum to avoid explicitly forming the large matrix
+                # K_trans @ V just to extract its diagonal afterward.
                 y_var = self.kernel_.diag(X)
-                y_var -= np.einsum("ij,ij->i",
-                                   np.dot(K_trans, self._K_inv), K_trans)
+                y_var -= np.einsum("ij,ji->i", K_trans, V)

                 # Check if any of the variances is negative because of
                 # numerical issues. If yes: set the variance to 0.
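
A side note on the einsum change above: np.einsum("ij,ji->i", A, B) returns the diagonal of A @ B without materializing the full product, so the return_std path never allocates an (n_test, n_test) intermediate. A quick self-contained check of that identity:

import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 6))   # plays the role of K_trans (n_test, n_train)
B = rng.normal(size=(6, 4))   # plays the role of V       (n_train, n_test)

diag_full = np.diag(A @ B)               # forms the full 4x4 product first
diag_lazy = np.einsum("ij,ji->i", A, B)  # computes only the diagonal
assert np.allclose(diag_full, diag_lazy)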

sklearn/gaussian_process/tests/test_gpr.py (33 additions, 25 deletions)

@@ -19,11 +19,12 @@
 from sklearn.gaussian_process.tests._mini_sequence_kernel import MiniSeqKernel
 from sklearn.exceptions import ConvergenceWarning

-from sklearn.utils._testing \
-    import (assert_array_less,
-            assert_almost_equal, assert_raise_message,
-            assert_array_almost_equal, assert_array_equal,
-            assert_allclose, assert_warns_message)
+from sklearn.utils._testing import (
+    assert_array_less,
+    assert_almost_equal,
+    assert_array_almost_equal,
+    assert_allclose
+)


 def f(x):
@@ -185,7 +186,8 @@ def test_no_optimizer():


 @pytest.mark.parametrize('kernel', kernels)
-def test_predict_cov_vs_std(kernel):
+@pytest.mark.parametrize("target", [y, np.ones(X.shape[0], dtype=np.float64)])
+def test_predict_cov_vs_std(kernel, target):
     if sys.maxsize <= 2 ** 32 and sys.version_info[:2] == (3, 6):
         pytest.xfail("This test may fail on 32bit Py3.6")

@@ -452,25 +454,6 @@ def test_no_fit_default_predict():
     assert_array_almost_equal(y_cov1, y_cov2)


-@pytest.mark.parametrize('kernel', kernels)
-def test_K_inv_reset(kernel):
-    y2 = f(X2).ravel()
-
-    # Test that self._K_inv is reset after a new fit
-    gpr = GaussianProcessRegressor(kernel=kernel).fit(X, y)
-    assert hasattr(gpr, '_K_inv')
-    assert gpr._K_inv is None
-    gpr.predict(X, return_std=True)
-    assert gpr._K_inv is not None
-    gpr.fit(X2, y2)
-    assert gpr._K_inv is None
-    gpr.predict(X2, return_std=True)
-    gpr2 = GaussianProcessRegressor(kernel=kernel).fit(X2, y2)
-    gpr2.predict(X2, return_std=True)
-    # the value of K_inv should be independent of the first fit
-    assert_array_equal(gpr._K_inv, gpr2._K_inv)
-
-
 def test_warning_bounds():
     kernel = RBF(length_scale_bounds=[1e-5, 1e-3])
     gpr = GaussianProcessRegressor(kernel=kernel)
@@ -566,3 +549,28 @@ def test_constant_target(kernel):
     assert_allclose(y_pred, y_constant)
     # set atol because we compare to zero
     assert_allclose(np.diag(y_cov), 0., atol=1e-9)
+
+
+def test_gpr_consistency_std_cov_non_invertible_kernel():
+    """Check the consistency between the returned std. dev. and the covariance.
+    Non-regression test for:
+    https://github.com/scikit-learn/scikit-learn/issues/19936
+    Inconsistencies were observed when the kernel cannot be inverted (or
+    numerically stable).
+    """
+    kernel = (C(8.98576054e+05, (1e-12, 1e12)) *
+              RBF([5.91326520e+02, 1.32584051e+03], (1e-12, 1e12)) +
+              WhiteKernel(noise_level=1e-5))
+    gpr = GaussianProcessRegressor(kernel=kernel, alpha=0, optimizer=None)
+    X_train = np.array([[0., 0.], [1.54919334, -0.77459667], [-1.54919334, 0.],
+                        [0., -1.54919334], [0.77459667, 0.77459667],
+                        [-0.77459667, 1.54919334]])
+    y_train = np.array([[-2.14882017e-10], [-4.66975823e+00], [4.01823986e+00],
+                        [-1.30303674e+00], [-1.35760156e+00],
+                        [3.31215668e+00]])
+    gpr.fit(X_train, y_train)
+    X_test = np.array([[-1.93649167, -1.93649167], [1.93649167, -1.93649167],
+                       [-1.93649167, 1.93649167], [1.93649167, 1.93649167]])
+    pred1, std = gpr.predict(X_test, return_std=True)
+    pred2, cov = gpr.predict(X_test, return_cov=True)
+    assert_allclose(std, np.sqrt(np.diagonal(cov)), rtol=1e-5)
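
A user-level version of the non-regression test above: after this fix, the return_std and return_cov prediction paths agree on a fitted model. A quick sketch with made-up data (any smooth target would do):

import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel

rng = np.random.default_rng(0)
X = rng.uniform(-2, 2, size=(30, 2))
y = np.sin(X[:, 0]) * np.cos(X[:, 1])

gpr = GaussianProcessRegressor(kernel=RBF() + WhiteKernel(1e-5),
                               alpha=0, optimizer=None).fit(X, y)
X_new = rng.uniform(-2, 2, size=(5, 2))
_, std = gpr.predict(X_new, return_std=True)
_, cov = gpr.predict(X_new, return_cov=True)

# Standard deviations must match the square root of the covariance diagonal.
np.testing.assert_allclose(std, np.sqrt(np.diag(cov)), rtol=1e-5)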
