Commit b55aba5

ENH Exposes latent mean and variance for GPCs (scikit-learn#22227)

Co-authored-by: antoinebaker <antoinebaker@users.noreply.github.com>

1 parent f29c100 · commit b55aba5
File tree

4 files changed: +119 -11 lines changed

doc/modules/gaussian_process.rst (+7 -2)
@@ -106,11 +106,11 @@ The :class:`GaussianProcessClassifier` implements Gaussian processes (GP) for
 classification purposes, more specifically for probabilistic classification,
 where test predictions take the form of class probabilities.
 GaussianProcessClassifier places a GP prior on a latent function :math:`f`,
-which is then squashed through a link function to obtain the probabilistic
+which is then squashed through a link function :math:`\pi` to obtain the probabilistic
 classification. The latent function :math:`f` is a so-called nuisance function,
 whose values are not observed and are not relevant by themselves.
 Its purpose is to allow a convenient formulation of the model, and :math:`f`
-is removed (integrated out) during prediction. GaussianProcessClassifier
+is removed (integrated out) during prediction. :class:`GaussianProcessClassifier`
 implements the logistic link function, for which the integral cannot be
 computed analytically but is easily approximated in the binary case.

@@ -134,6 +134,11 @@ that have been chosen randomly from the range of allowed values.
 If the initial hyperparameters should be kept fixed, `None` can be passed as
 optimizer.

+In some scenarios, information about the latent function :math:`f` is desired
+(i.e. the mean :math:`\bar{f_*}` and the variance :math:`\text{Var}[f_*]` described
+in Eqs. (3.21) and (3.24) of [RW2006]_). The :class:`GaussianProcessClassifier`
+provides access to these quantities via the `latent_mean_and_variance` method.
+
 :class:`GaussianProcessClassifier` supports multi-class classification
 by performing either one-versus-rest or one-versus-one based training and
 prediction. In one-versus-rest, one binary Gaussian process classifier is
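
For orientation, here is a minimal usage sketch of the new method; the toy data and kernel choice are illustrative assumptions, not part of this commit:

import numpy as np
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.gaussian_process.kernels import RBF

# Toy binary classification problem (assumption, for illustration only)
X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = np.array([0, 0, 1, 1])

gpc = GaussianProcessClassifier(kernel=1.0 * RBF(length_scale=1.0)).fit(X, y)

# Mean and variance of the latent function f at the query points,
# i.e. Eqs. (3.21) and (3.24) of RW2006
latent_mean, latent_var = gpc.latent_mean_and_variance(X)
print(latent_mean.shape, latent_var.shape)  # (4,) (4,)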
A changelog entry is added in a new file:

@@ -0,0 +1 @@
+- :class:`gaussian_process.GaussianProcessClassifier` now includes a `latent_mean_and_variance` method that exposes the mean and the variance of the latent function, :math:`f`, used in the Laplace approximation. By :user:`Miguel González Duque <miguelgondu>`

sklearn/gaussian_process/_gpc.py (+76 -9)
@@ -306,12 +306,9 @@ def predict_proba(self, X):
         """
         check_is_fitted(self)

-        # Based on Algorithm 3.2 of GPML
-        K_star = self.kernel_(self.X_train_, X)  # K_star = k(x_star)
-        f_star = K_star.T.dot(self.y_train_ - self.pi_)  # Line 4
-        v = solve(self.L_, self.W_sr_[:, np.newaxis] * K_star)  # Line 5
-        # Line 6 (compute np.diag(v.T.dot(v)) via einsum)
-        var_f_star = self.kernel_.diag(X) - np.einsum("ij,ij->j", v, v)
+        # Compute the mean and variance of the latent function
+        # (Lines 4-6 of Algorithm 3.2 of GPML)
+        latent_mean, latent_var = self.latent_mean_and_variance(X)

         # Line 7:
         # Approximate \int log(z) * N(z | f_star, var_f_star)
@@ -320,12 +317,12 @@ def predict_proba(self, X):
         # sigmoid by a linear combination of 5 error functions.
         # For information on how this integral can be computed see
         # blitiri.blogspot.de/2012/11/gaussian-integral-of-error-function.html
-        alpha = 1 / (2 * var_f_star)
-        gamma = LAMBDAS * f_star
+        alpha = 1 / (2 * latent_var)
+        gamma = LAMBDAS * latent_mean
         integrals = (
             np.sqrt(np.pi / alpha)
             * erf(gamma * np.sqrt(alpha / (alpha + LAMBDAS**2)))
-            / (2 * np.sqrt(var_f_star * 2 * np.pi))
+            / (2 * np.sqrt(latent_var * 2 * np.pi))
         )
         pi_star = (COEFS * integrals).sum(axis=0) + 0.5 * COEFS.sum()

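The block above is the erf-based closed-form approximation of the binary predictive integral p(y=1 | x) = \int sigmoid(z) N(z | latent_mean, latent_var) dz. As a sketch of what it computes, one can cross-check predict_proba against brute-force numerical quadrature over the latent Gaussian returned by the new method (toy data and kernel are assumptions):

import numpy as np
from scipy.integrate import quad
from scipy.stats import norm
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.gaussian_process.kernels import RBF

rng = np.random.RandomState(0)
X = rng.uniform(-3, 3, size=(30, 1))   # toy 1-D inputs (assumption)
y = (np.sin(X[:, 0]) > 0).astype(int)

gpc = GaussianProcessClassifier(kernel=1.0 * RBF(1.0)).fit(X, y)
X_test = np.linspace(-3, 3, 5).reshape(-1, 1)

latent_mean, latent_var = gpc.latent_mean_and_variance(X_test)

# p(y=1 | x) = \int sigmoid(z) N(z | latent_mean, latent_var) dz,
# evaluated here by quadrature instead of the erf expansion
p_quad = np.array([
    quad(lambda z, m=m, v=v: norm.pdf(z, m, np.sqrt(v)) / (1 + np.exp(-z)),
         -20, 20)[0]
    for m, v in zip(latent_mean, latent_var)
])

p_gpc = gpc.predict_proba(X_test)[:, 1]
print(np.max(np.abs(p_quad - p_gpc)))  # small: the two should agree closely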
@@ -410,6 +407,39 @@ def log_marginal_likelihood(

         return Z, d_Z

+    def latent_mean_and_variance(self, X):
+        """Compute the mean and variance of the latent function values.
+
+        Based on algorithm 3.2 of [RW2006]_, this function returns the latent
+        mean (Line 4) and variance (Line 6) of the Gaussian process
+        classification model.
+
+        Note that this function is only supported for binary classification.
+
+        Parameters
+        ----------
+        X : array-like of shape (n_samples, n_features) or list of object
+            Query points where the GP is evaluated for classification.
+
+        Returns
+        -------
+        latent_mean : array-like of shape (n_samples,)
+            Mean of the latent function values at the query points.
+
+        latent_var : array-like of shape (n_samples,)
+            Variance of the latent function values at the query points.
+        """
+        check_is_fitted(self)
+
+        # Based on Algorithm 3.2 of GPML
+        K_star = self.kernel_(self.X_train_, X)  # K_star = k(x_star)
+        latent_mean = K_star.T.dot(self.y_train_ - self.pi_)  # Line 4
+        v = solve(self.L_, self.W_sr_[:, np.newaxis] * K_star)  # Line 5
+        # Line 6 (compute np.diag(v.T.dot(v)) via einsum)
+        latent_var = self.kernel_.diag(X) - np.einsum("ij,ij->j", v, v)
+
+        return latent_mean, latent_var
+
     def _posterior_mode(self, K, return_temporaries=False):
         """Mode-finding for binary Laplace GPC and fixed kernel.

@@ -902,3 +932,40 @@ def log_marginal_likelihood(
                 "Obtained theta with shape %d."
                 % (n_dims, n_dims * self.classes_.shape[0], theta.shape[0])
             )
+
+    def latent_mean_and_variance(self, X):
+        """Compute the mean and variance of the latent function.
+
+        Based on algorithm 3.2 of [RW2006]_, this function returns the latent
+        mean (Line 4) and variance (Line 6) of the Gaussian process
+        classification model.
+
+        Note that this function is only supported for binary classification.
+
+        Parameters
+        ----------
+        X : array-like of shape (n_samples, n_features) or list of object
+            Query points where the GP is evaluated for classification.
+
+        Returns
+        -------
+        latent_mean : array-like of shape (n_samples,)
+            Mean of the latent function values at the query points.
+
+        latent_var : array-like of shape (n_samples,)
+            Variance of the latent function values at the query points.
+        """
+        if self.n_classes_ > 2:
+            raise ValueError(
+                "Returning the mean and variance of the latent function f "
+                "is only supported for binary classification, received "
+                f"{self.n_classes_} classes."
+            )
+        check_is_fitted(self)
+
+        if self.kernel is None or self.kernel.requires_vector_input:
+            X = validate_data(self, X, ensure_2d=True, dtype="numeric", reset=False)
+        else:
+            X = validate_data(self, X, ensure_2d=False, dtype=None, reset=False)
+
+        return self.base_estimator_.latent_mean_and_variance(X)
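
To make the wrapper's delegation concrete, the following sketch recomputes Lines 4-6 of Algorithm 3.2 by hand from the fitted binary estimator's attributes (X_train_, y_train_, pi_, W_sr_, L_, exactly as used in the diff above) and checks the result against the public method. The toy data is an assumption:

import numpy as np
from scipy.linalg import solve
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.gaussian_process.kernels import RBF

rng = np.random.RandomState(0)
X = rng.uniform(-2, 2, size=(20, 1))   # toy data (assumption)
y = (X[:, 0] > 0).astype(int)

gpc = GaussianProcessClassifier(kernel=1.0 * RBF(1.0)).fit(X, y)
est = gpc.base_estimator_  # the fitted binary Laplace estimator delegated to above

# Lines 4-6 of Algorithm 3.2 in GPML, written out with the fitted attributes
K_star = est.kernel_(est.X_train_, X)                            # k(x_star)
mean_manual = K_star.T.dot(est.y_train_ - est.pi_)               # Line 4
v = solve(est.L_, est.W_sr_[:, np.newaxis] * K_star)             # Line 5
var_manual = est.kernel_.diag(X) - np.einsum("ij,ij->j", v, v)   # Line 6

mean_api, var_api = gpc.latent_mean_and_variance(X)
assert np.allclose(mean_manual, mean_api)
assert np.allclose(var_manual, var_api)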

sklearn/gaussian_process/tests/test_gpc.py (+35)
@@ -283,3 +283,38 @@ def test_gpc_fit_error(params, error_type, err_msg):
     gpc = GaussianProcessClassifier(**params)
     with pytest.raises(error_type, match=err_msg):
         gpc.fit(X, y)
+
+
+@pytest.mark.parametrize("kernel", kernels)
+def test_gpc_latent_mean_and_variance_shape(kernel):
+    """Checks that the latent mean and variance have the right shape."""
+    gpc = GaussianProcessClassifier(kernel=kernel)
+    gpc.fit(X, y)
+
+    # Check that the latent mean and variance have the right shape
+    latent_mean, latent_variance = gpc.latent_mean_and_variance(X)
+    assert latent_mean.shape == (X.shape[0],)
+    assert latent_variance.shape == (X.shape[0],)
+
+
+def test_gpc_latent_mean_and_variance_complain_on_more_than_2_classes():
+    """Checks that a ValueError is raised for multi-class problems."""
+    gpc = GaussianProcessClassifier(kernel=RBF())
+    gpc.fit(X, y_mc)
+
+    # Calling on more than two classes should raise an informative error
+    with pytest.raises(
+        ValueError,
+        match="Returning the mean and variance of the latent function f "
+        "is only supported for binary classification",
+    ):
+        gpc.latent_mean_and_variance(X)
+
+
+def test_latent_mean_and_variance_works_on_structured_kernels():
+    X = ["A", "AB", "B"]
+    y = np.array([True, False, True])
+    kernel = MiniSeqKernel(baseline_similarity_bounds="fixed")
+    gpc = GaussianProcessClassifier(kernel=kernel).fit(X, y)
+
+    gpc.latent_mean_and_variance(X)
