QDA Classifier: divide by zero if data not of full rank  #14997

Open
@thielepaul

Description

When the QDA classifier is fitted on data in which one class does not have a full-rank covariance matrix, prediction does not work and divide-by-zero warnings are raised. An example dataset can be found in the code below.
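As a minimal sketch of what "not of full rank" means here, the snippet below rebuilds only class 0 of the example dataset and inspects its covariance. It uses NumPy alone; the variable names (`latent`, `X0`, `cov0`, `rank0`) are illustrative and do not come from scikit-learn.

```python
import numpy as np

rng = np.random.RandomState(0)
n = 300
C = np.array([[0.0, -0.23], [0.83, 0.23]])

# Class 0 as constructed in the report: the second latent coordinate is
# always zero, so after the linear map C every sample lies on one line.
latent = np.concatenate((rng.randn(n, 1), np.zeros((n, 1))), axis=1)
X0 = latent @ C

# Covariance of a 2-feature class that only varies along one direction.
cov0 = np.cov(X0, rowvar=False)
rank0 = np.linalg.matrix_rank(cov0)

# Singular values of the centered data; one of them collapses to zero.
singular_values = np.linalg.svd(X0 - X0.mean(axis=0), compute_uv=False)

print("rank of class-0 covariance:", rank0)
print("singular values of centered class-0 data:", singular_values)
```

The vanishing singular value is the quantity that later ends up inside a negative power and a logarithm.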

Steps/Code to Reproduce

import numpy as np
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis

np.seterr(all="warn")


def dataset_fixed_cov():
    n, dim = 300, 2
    np.random.seed(0)
    C = np.array([[0.0, -0.23], [0.83, 0.23]])
    X = np.r_[
        np.dot(np.concatenate((np.random.randn(n, 1), np.zeros((n, 1))), axis=1), C),
        np.dot(np.random.randn(n, dim), C) + np.array([1, 1]),
    ]
    y = np.hstack((np.zeros(n), np.ones(n)))
    return X, y


X, y = dataset_fixed_cov()
qda = QuadraticDiscriminantAnalysis(store_covariance=True)
y_pred = qda.fit(X, y).predict(X)

print("Correct: {}/{}".format(np.sum(y == y_pred), len(y)))

Expected Results

No divide by zero errors occur and the prediction works correctly.

Actual Results

Prediction fails for all samples of one class, and the following warnings are raised:

/usr/local/lib/python3.7/site-packages/sklearn/discriminant_analysis.py:693: UserWarning: Variables are collinear
  warnings.warn("Variables are collinear")
/usr/local/lib/python3.7/site-packages/sklearn/discriminant_analysis.py:717: RuntimeWarning: divide by zero encountered in power
  X2 = np.dot(Xm, R * (S ** (-0.5)))
/usr/local/lib/python3.7/site-packages/sklearn/discriminant_analysis.py:717: RuntimeWarning: invalid value encountered in multiply
  X2 = np.dot(Xm, R * (S ** (-0.5)))
/usr/local/lib/python3.7/site-packages/sklearn/discriminant_analysis.py:720: RuntimeWarning: divide by zero encountered in log
  u = np.asarray([np.sum(np.log(s)) for s in self.scalings_])
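The three runtime warnings can be reproduced in isolation. The sketch below is not scikit-learn's code: it applies the expressions quoted in the warning output (`S ** (-0.5)`, `R * ...`, `np.log(s)`) to a hypothetical scalings vector with one zero entry and an identity rotation. The final step shows an assumed shrinkage of the scalings toward 1, similar in spirit to the estimator's `reg_param` argument; whether that argument fully resolves this report is not confirmed here.

```python
import numpy as np

# Hypothetical per-class scalings; the zero stands in for the collapsed
# direction of a rank-deficient class.
scalings = np.array([0.05, 0.0])
rotation = np.eye(2)  # assumed rotation, chosen so it contains exact zeros

with np.errstate(divide="ignore", invalid="ignore"):
    inv_sqrt = scalings ** (-0.5)        # 0 ** -0.5 gives inf
    scaled = rotation * inv_sqrt         # 0 * inf gives nan
    log_det = np.sum(np.log(scalings))   # log(0) gives -inf

print(inv_sqrt)
print(scaled)
print(log_det)

# Assumed regularization: shrink every scaling toward 1 by a small amount.
reg_param = 1e-3  # illustrative value
regularized = (1 - reg_param) * scalings + reg_param
reg_inv_sqrt = regularized ** (-0.5)
reg_log_det = np.sum(np.log(regularized))

print(reg_inv_sqrt)
print(reg_log_det)
```

Once every scaling is strictly positive, both the inverse square root and the log-determinant term stay finite, so the discriminant scores can be compared across classes.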

Versions

System:
    python: 3.7.4 (default, Sep 12 2019, 15:40:15)  [GCC 8.3.0]
executable: /usr/local/bin/python
   machine: Linux-5.2.14-arch2-1-ARCH-x86_64-with-debian-10.1

Python deps:
       pip: 19.2.3
setuptools: 41.2.0
   sklearn: 0.21.3
     numpy: 1.17.2
     scipy: 1.3.1
    Cython: None
    pandas: None
