Gaussian Mixture - weighted implementation #17130

Open · wants to merge 11 commits into main
139 changes: 109 additions & 30 deletions in sklearn/mixture/_base.py
@@ -16,7 +16,7 @@
from ..base import DensityMixin
from ..exceptions import ConvergenceWarning
from ..utils import check_array, check_random_state
from ..utils.validation import check_is_fitted, _check_sample_weight


def _check_shape(param, param_shape, name):
@@ -62,6 +62,23 @@ def _check_X(X, n_components=None, n_features=None, ensure_min_samples=1):
return X


def _check_normalize_sample_weight(sample_weight, X):
    """Validate sample_weight and normalize it to sum to n_samples."""
    # record this *before* validation: the original code set the flag after
    # replacing None with an array of ones, so it was always False
    sample_weight_was_none = sample_weight is None

    # _check_sample_weight returns an array of ones when sample_weight is
    # None, and validates dtype and shape otherwise
    sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)
    if not sample_weight_was_none:
        # normalize the weights to sum up to n_samples; an array of
        # ones (i.e. sample_weight was None) is already normalized
        n_samples = len(sample_weight)
        scale = n_samples / sample_weight.sum()
        sample_weight *= scale
    return sample_weight
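For intuition: the helper leaves implicit unit weights alone and rescales explicit weights so they sum to n_samples while preserving their proportions. A minimal sketch of that arithmetic in plain NumPy (without calling the private helper):

import numpy as np

X = np.zeros((4, 1))                # n_samples == 4
w = np.array([1.0, 1.0, 2.0, 4.0])  # raw weights, sum == 8.0

scale = X.shape[0] / w.sum()        # 4 / 8 == 0.5
print(w * scale)                    # [0.5 0.5 1.  2. ]  -> sums to 4.0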


class BaseMixture(DensityMixin, BaseEstimator, metaclass=ABCMeta):
"""Base class for mixture models.

@@ -129,7 +146,7 @@ def _check_parameters(self, X):
"""
pass

def _initialize_parameters(self, X, sample_weight, random_state):
"""Initialize the model parameters.

Parameters
@@ -139,13 +156,20 @@ def _initialize_parameters(self, X, random_state):
        sample_weight : array-like, shape (n_samples,)
            The weights for each observation in X.

        random_state : RandomState
            A random number generator instance that controls the random seed
            used for the method chosen to initialize the parameters.
"""
n_samples, _ = X.shape
        # defensive copy: KMeans.fit has modified sample_weight in place in
        # some scikit-learn versions, so do not hand it our working array
        sample_weight_copy = sample_weight.copy()

if self.init_params == 'kmeans':
resp = np.zeros((n_samples, self.n_components))
label = cluster.KMeans(
n_clusters=self.n_components, n_init=1,
random_state=random_state).fit(
X, sample_weight=sample_weight_copy).labels_
resp[np.arange(n_samples), label] = 1
elif self.init_params == 'random':
resp = random_state.rand(n_samples, self.n_components)
@@ -154,21 +178,23 @@ def _initialize_parameters(self, X, random_state):
raise ValueError("Unimplemented initialization method '%s'"
% self.init_params)

self._initialize(X, sample_weight, resp)
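The k-means branch above hands the copied weights to KMeans.fit, whose sample_weight parameter already exists in released scikit-learn. A standalone sketch of that call:

import numpy as np
from sklearn.cluster import KMeans

rng = np.random.RandomState(0)
X = rng.randn(100, 2)
w = rng.rand(100)

# heavily weighted points pull the initial centroids toward themselves
labels = KMeans(n_clusters=3, n_init=1,
                random_state=0).fit(X, sample_weight=w).labels_
print(labels[:10])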

@abstractmethod
def _initialize(self, X, sample_weight, resp):
"""Initialize the model parameters of the derived class.

Parameters
----------
X : array-like, shape (n_samples, n_features)

        sample_weight : array-like, shape (n_samples,)
            The weights for each observation in X.

        resp : array-like, shape (n_samples, n_components)
"""
pass

def fit(self, X, y=None, sample_weight=None):
"""Estimate model parameters with the EM algorithm.

The method fits the model ``n_init`` times and sets the parameters with
@@ -186,14 +212,18 @@ def fit(self, X, y=None):
List of n_features-dimensional data points. Each row
corresponds to a single data point.

sample_weight : array-like, shape (n_samples,), optional
The weights for each observation in X. If None, all observations
are assigned equal weight (default: None).

Returns
-------
self
"""
self.fit_predict(X, y, sample_weight)
return self
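End-user code would then pass weights straight to fit. A hypothetical usage sketch (GaussianMixture is the concrete subclass; the sample_weight argument only exists on this branch):

import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.RandomState(42)
X = np.vstack([rng.randn(200, 2), rng.randn(50, 2) + 5.0])
# up-weight the small cluster so both components carry comparable mass
w = np.r_[np.ones(200), 4.0 * np.ones(50)]

gmm = GaussianMixture(n_components=2, random_state=42)
gmm.fit(X, sample_weight=w)  # hypothetical: requires this branch
print(gmm.weights_)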

def fit_predict(self, X, y=None, sample_weight=None):
"""Estimate model parameters using X and predict the labels for X.

The method fits the model n_init times and sets the parameters with
@@ -212,6 +242,13 @@ def fit_predict(self, X, y=None):
List of n_features-dimensional data points. Each row
corresponds to a single data point.

y : Ignored
Not used, present here for API consistency by convention.

sample_weight : array-like, shape (n_samples,), optional
The weights for each observation in X. If None, all observations
are assigned equal weight (default: None).

Returns
-------
labels : array, shape (n_samples,)
@@ -221,6 +258,8 @@
self._check_n_features(X, reset=True)
self._check_initial_parameters(X)

sample_weight = _check_normalize_sample_weight(sample_weight, X)

# if we enable warm_start, we will have a unique initialisation
do_init = not(self.warm_start and hasattr(self, 'converged_'))
n_init = self.n_init if do_init else 1
@@ -235,15 +274,15 @@
self._print_verbose_msg_init_beg(init)

if do_init:
self._initialize_parameters(X, random_state)
self._initialize_parameters(X, sample_weight, random_state)

lower_bound = (-np.infty if do_init else self.lower_bound_)

for n_iter in range(1, self.max_iter + 1):
prev_lower_bound = lower_bound

log_prob_norm, log_resp = self._e_step(X, sample_weight)
self._m_step(X, sample_weight, log_resp)
lower_bound = self._compute_lower_bound(
log_resp, log_prob_norm)

@@ -275,17 +314,20 @@ def fit_predict(self, X, y=None):
# Always do a final e-step to guarantee that the labels returned by
# fit_predict(X) are always consistent with fit(X).predict(X)
# for any value of max_iter and tol (and any random_state).
_, log_resp = self._e_step(X, sample_weight)

return log_resp.argmax(axis=1)
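The final E-step exists so that fit_predict(X) agrees with fit(X).predict(X) for any max_iter and tol. A quick check of that invariant (again assuming this branch; X and w as in the previous sketch):

gmm = GaussianMixture(n_components=2, max_iter=3, random_state=0)
labels_fp = gmm.fit_predict(X, sample_weight=w)
labels_p = gmm.fit(X, sample_weight=w).predict(X)
assert np.array_equal(labels_fp, labels_p)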

def _e_step(self, X, sample_weight):
"""E step.

Parameters
----------
X : array-like, shape (n_samples, n_features)

sample_weight : array-like, shape (n_samples,)
The weights for each observation in X.

Returns
-------
log_prob_norm : float
@@ -295,11 +337,12 @@ def _e_step(self, X):
Logarithm of the posterior probabilities (or responsibilities) of
the point of each sample in X.
"""
log_prob_norm, log_resp = self._estimate_log_prob_resp(X,
sample_weight)
return np.mean(log_prob_norm), log_resp
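For reference, the log_resp returned here is the usual Bayes-rule normalization carried out in log space:

\[
\log \gamma_{ik} = \log \pi_k + \log \mathcal{N}(x_i \mid \mu_k, \Sigma_k)
    - \log \sum_{j=1}^{K} \pi_j \, \mathcal{N}(x_i \mid \mu_j, \Sigma_j)
\]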

@abstractmethod
def _m_step(self, X, sample_weight, log_resp):
"""M step.

Parameters
@@ -309,6 +352,9 @@ def _m_step(self, X, log_resp):
        sample_weight : array-like, shape (n_samples,)
            The weights for each observation in X.

        log_resp : array-like, shape (n_samples, n_components)
            Logarithm of the posterior probabilities (or responsibilities) of
            the point of each sample in X.
"""
pass
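Concrete subclasses supply these updates. Under the standard weighted-EM derivation, where each sample effectively counts w_i times (an assumption about this PR's intent, consistent with its weighted E-step), the Gaussian M-step becomes:

\[
N_k = \sum_i w_i \gamma_{ik}, \qquad
\pi_k = \frac{N_k}{\sum_i w_i}, \qquad
\mu_k = \frac{1}{N_k} \sum_i w_i \gamma_{ik} x_i, \qquad
\Sigma_k = \frac{1}{N_k} \sum_i w_i \gamma_{ik}
    (x_i - \mu_k)(x_i - \mu_k)^{\top}
\]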

@@ -320,7 +366,7 @@ def _get_parameters(self):
def _set_parameters(self, params):
pass

def score_samples(self, X, sample_weight=None):
"""Compute the weighted log probabilities for each sample.

Parameters
@@ -329,6 +375,10 @@ def score_samples(self, X):
List of n_features-dimensional data points. Each row
corresponds to a single data point.

sample_weight : array-like, shape (n_samples,), optional
The weights for each observation in X. If None, all observations
are assigned equal weight (default: None).

Returns
-------
log_prob : array, shape (n_samples,)
@@ -337,9 +387,12 @@ def score_samples(self, X):
check_is_fitted(self)
X = _check_X(X, None, self.means_.shape[1])

sample_weight = _check_normalize_sample_weight(sample_weight, X)

return sample_weight * logsumexp(
self._estimate_weighted_log_prob(X, sample_weight), axis=1)
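The logsumexp term is the ordinary per-sample mixture log-likelihood. A hand-rolled version of that quantity for a toy two-component model, mirroring what the logsumexp computes before the sample_weight scaling (and ignoring any weighting inside _estimate_log_prob, which is subclass-specific):

import numpy as np
from scipy.special import logsumexp
from scipy.stats import multivariate_normal

pis = np.array([0.6, 0.4])                   # mixing weights
mus = [np.zeros(2), np.full(2, 5.0)]         # component means
covs = [np.eye(2), np.eye(2)]                # component covariances

X = np.array([[0.0, 0.0], [5.0, 5.0]])
log_prob = np.column_stack([
    multivariate_normal.logpdf(X, mean=m, cov=c)
    for m, c in zip(mus, covs)])
print(logsumexp(np.log(pis) + log_prob, axis=1))  # log p(x_i)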

def score(self, X, y=None, sample_weight=None):
"""Compute the per-sample average log-likelihood of the given data X.

Parameters
Expand All @@ -348,14 +401,18 @@ def score(self, X, y=None):
List of n_features-dimensional data points. Each row
corresponds to a single data point.

sample_weight : array-like, shape (n_samples,), optional
The weights for each observation in X. If None, all observations
are assigned equal weight (default: None).

Returns
-------
log_likelihood : float
Log likelihood of the Gaussian mixture given X.
"""
return self.score_samples(X, sample_weight).mean()
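Note that because _check_normalize_sample_weight rescales the weights to sum to n_samples, taking the plain mean of the weighted per-sample scores yields the weighted average log-likelihood: mean(w_i * log p(x_i)) equals (1 / sum_i w_i) * sum_i w_i * log p(x_i).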

def predict(self, X, sample_weight=None):
"""Predict the labels for the data samples in X using trained model.

Parameters
Expand All @@ -364,16 +421,24 @@ def predict(self, X):
List of n_features-dimensional data points. Each row
corresponds to a single data point.

sample_weight : array-like, shape (n_samples,), optional
The weights for each observation in X. If None, all observations
are assigned equal weight (default: None).

Returns
-------
labels : array, shape (n_samples,)
Component labels.
"""
check_is_fitted(self)
X = _check_X(X, None, self.means_.shape[1])
sample_weight = _check_normalize_sample_weight(sample_weight, X)

return self._estimate_weighted_log_prob(X,
sample_weight).argmax(axis=1)

def predict_proba(self, X, sample_weight=None):
"""Predict posterior probability of each component given the data.

Parameters
Expand All @@ -382,6 +447,10 @@ def predict_proba(self, X):
List of n_features-dimensional data points. Each row
corresponds to a single data point.

sample_weight : array-like, shape (n_samples,), optional
The weights for each observation in X. If None, all observations
are assigned equal weight (default: None).

Returns
-------
resp : array, shape (n_samples, n_components)
@@ -390,7 +459,10 @@ def predict_proba(self, X):
"""
check_is_fitted(self)
X = _check_X(X, None, self.means_.shape[1])

sample_weight = _check_normalize_sample_weight(sample_weight, X)

_, log_resp = self._estimate_log_prob_resp(X, sample_weight)
return np.exp(log_resp)

def sample(self, n_samples=1):
@@ -438,22 +510,25 @@ def sample(self, n_samples=1):
self.means_, self.covariances_, n_samples_comp)])

y = np.concatenate([np.full(sample, j, dtype=int)
for j, sample in enumerate(n_samples_comp)])

return (X, y)

def _estimate_weighted_log_prob(self, X, sample_weight):
"""Estimate the weighted log-probabilities, log P(X | Z) + log weights.

Parameters
----------
X : array-like, shape (n_samples, n_features)

sample_weight : array-like, shape (n_samples,)

Returns
-------
weighted_log_prob : array, shape (n_samples, n_component)
"""
return (self._estimate_log_prob(X, sample_weight)
+ self._estimate_log_weights())

@abstractmethod
def _estimate_log_weights(self):
@@ -466,7 +541,7 @@ def _estimate_log_weights(self):
pass

@abstractmethod
def _estimate_log_prob(self, X, sample_weight):
"""Estimate the log-probabilities log P(X | Z).

Compute the log-probabilities per each component for each sample.
@@ -475,13 +550,15 @@ def _estimate_log_prob(self, X):
----------
X : array-like, shape (n_samples, n_features)

sample_weight : array-like, shape (n_samples,)

Returns
-------
log_prob : array, shape (n_samples, n_component)
"""
pass

def _estimate_log_prob_resp(self, X, sample_weight):
"""Estimate log probabilities and responsibilities for each sample.

Compute the log probabilities, weighted log probabilities per
Expand All @@ -492,6 +569,8 @@ def _estimate_log_prob_resp(self, X):
----------
X : array-like, shape (n_samples, n_features)

sample_weight : array-like, shape (n_samples,)

Returns
-------
log_prob_norm : array, shape (n_samples,)
@@ -500,7 +579,7 @@ def _estimate_log_prob_resp(self, X):
log_responsibilities : array, shape (n_samples, n_components)
logarithm of the responsibilities
"""
weighted_log_prob = self._estimate_weighted_log_prob(X, sample_weight)
log_prob_norm = logsumexp(weighted_log_prob, axis=1)
with np.errstate(under='ignore'):
# ignore underflow
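The hunk is truncated inside the errstate block; upstream, the next line forms log_resp = weighted_log_prob - log_prob_norm[:, np.newaxis] and returns both arrays. A tiny numeric sketch of that normalization (responsibilities exponentiate back to rows summing to one):

import numpy as np
from scipy.special import logsumexp

weighted_log_prob = np.log([[0.2, 0.1],
                            [0.01, 0.4]])
log_prob_norm = logsumexp(weighted_log_prob, axis=1)
with np.errstate(under='ignore'):
    log_resp = weighted_log_prob - log_prob_norm[:, np.newaxis]
print(np.exp(log_resp).sum(axis=1))  # [1. 1.]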