Problems with score_samples in BayesianGaussianMixture #10148

Open
@rdturnermtl

Description


I have noticed that on many problems the score_samples() function in sklearn.mixture.BayesianGaussianMixture consistently gives lower test-set log-likelihoods than simply taking the posterior mean parameters from the fitted mixture (weights_, means_, covariances_) and plugging them into an ordinary Gaussian mixture likelihood.

I have coded up a demo on a toy problem to illustrate this:

import numpy as np
import scipy.stats as ss
from scipy.special import logsumexp
from sklearn.mixture import BayesianGaussianMixture, GaussianMixture


def loglik_mixture(X, w, means, covs):
    """Log-likelihood of each row of X under the Gaussian mixture (w, means, covs)."""
    N = X.shape[0]

    w = w / np.sum(w)  # Normalize the weights just to be sure

    loglik = np.zeros((N, len(w)))
    for ii in range(len(w)):
        mu = means[ii, :]
        S = covs[ii, :, :]
        gauss_part = ss.multivariate_normal.logpdf(X, mu, S)
        loglik[:, ii] = np.log(w[ii]) + gauss_part
    loglik = logsumexp(loglik, axis=1)
    return loglik


def simple_data():
    # 50/50 mixture of N(+[1, 1], I) and N(-[1, 1], I) in 2-D
    x = np.random.randn(1000, 2) + 1.0
    idx = np.random.rand(1000) <= 0.5
    x[idx, :] = -1 * x[idx, :]
    return x

np.random.seed(1234)

delta = []
for _ in range(500):
    x_train = simple_data()
    x_test = simple_data()

    gmm = GaussianMixture(n_components=2, covariance_type='full')
    gmm.fit(x_train)
    loglik0 = gmm.score_samples(x_test)
    w, means, covs = gmm.weights_, gmm.means_, gmm.covariances_
    loglik1 = loglik_mixture(x_test, w, means, covs)
    # Demonstrate that loglik_mixture() is correct
    np.testing.assert_allclose(loglik0, loglik1)
    print('-' * 10)
    print(f'MLE GMM loglik {np.mean(loglik0)}')

    bgmm = BayesianGaussianMixture(n_components=2, covariance_type='full')
    bgmm.fit(x_train)
    loglik_bayes0 = bgmm.score_samples(x_test)
    w, means, covs = bgmm.weights_, bgmm.means_, bgmm.covariances_
    loglik_bayes1 = loglik_mixture(x_test, w, means, covs)
    print(f'VB GMM loglik (built-in) {np.mean(loglik_bayes0)}')
    print(f'improvement {np.mean(loglik_bayes0) - np.mean(loglik0)}')
    print(f'VB GMM loglik {np.mean(loglik_bayes1)}')
    print(f'improvement {np.mean(loglik_bayes1) - np.mean(loglik0)}')

    # Gap between the plug-in likelihood and the built-in score
    delta.append(np.mean(loglik_bayes1) - np.mean(loglik_bayes0))
print(np.mean(delta))                # average gap in nats
print(np.mean(np.array(delta) > 0))  # fraction of runs where the plug-in wins

This gives, for instance:

MLE GMM loglik -3.362221
VB GMM loglik (built-in) -3.367131
improvement -0.004911
VB GMM loglik -3.361099
improvement 0.001121

The built-in score for the VB GMM is consistently worse than the MLE GMM, but the VB GMM is often better when its posterior mean parameters are plugged in as a point estimate via loglik_mixture(). On this problem, the loglik_mixture() log-likelihood is consistently around 0.006 nats higher than the built-in score function.
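
As a quick follow-up check (just a sketch, reusing bgmm, x_test, and loglik_mixture from the script above), one can look at whether the gap is a roughly constant per-sample offset rather than noise:

gap = (loglik_mixture(x_test, bgmm.weights_, bgmm.means_, bgmm.covariances_)
       - bgmm.score_samples(x_test))
# If the gap were a pure normalization constant, the std would be ~0
print(f'per-sample gap: mean {np.mean(gap)}, std {np.std(gap)}')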

This makes me wonder whether there is a normalization issue in the likelihood computed in _estimate_log_prob(). I am not sure where the derivation behind the code in lines 690-698 comes from.
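
One way to dig in (again only a sketch, and note that _estimate_log_prob is a private sklearn method that may change between versions) is to compare the per-component terms against the plain Gaussian log-density at the posterior mean parameters:

per_comp_vb = bgmm._estimate_log_prob(x_test)  # shape (n_samples, n_components)
per_comp_plugin = np.stack(
    [ss.multivariate_normal.logpdf(x_test, bgmm.means_[k], bgmm.covariances_[k])
     for k in range(bgmm.n_components)], axis=1)
diff = per_comp_vb - per_comp_plugin
# Mean and spread of the per-component offset (in nats)
print(np.mean(diff, axis=0), np.std(diff, axis=0))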

The documentation says the implementation follows Blei and Jordan (2006), which gives the posterior predictive in eqn (23). But that is only an approximation and is not necessarily even normalized.

If score_samples is used for evaluation purposes, it is not fair to compare the VB GMM to other models when its likelihood is not even properly normalized! Some more investigation is needed here.
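
For what it's worth, a rough way to check the normalization numerically on this 2-D toy problem (a sketch only, reusing the fitted bgmm from above; the grid bounds and resolution are arbitrary) is to integrate exp(score_samples) over a grid wide enough to cover essentially all of the mass:

grid = np.linspace(-10.0, 10.0, 400)
xx, yy = np.meshgrid(grid, grid)
points = np.column_stack([xx.ravel(), yy.ravel()])
density = np.exp(bgmm.score_samples(points))
cell = (grid[1] - grid[0]) ** 2
# Riemann-sum approximation of the total probability mass; ~1.0 if normalized
print(density.sum() * cell)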
