Description
When the QDA classifier is fitted on data that is not of full rank, classification fails and divide-by-zero warnings are raised. In the example dataset below, every sample of the first class has a first coordinate of exactly 0, so that class's covariance matrix is singular.
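A minimal NumPy sketch of why rank-deficient data triggers these warnings. It uses a hypothetical class (not the dataset below) whose samples all have a first coordinate of 0, so its sample covariance has a zero eigenvalue. The two operations that appear in the warnings, raising variances to the power -0.5 and taking their log, then produce inf and -inf. Because the degenerate direction here is axis-aligned, the covariance is diagonal and its eigenvalues are simply the per-column variances; QDA itself works from an SVD, which this sketch does not reproduce.

```python
import numpy as np

# Hypothetical class whose samples all lie on the line x0 == 0, so its
# 2x2 sample covariance has rank 1 instead of 2.
rng = np.random.default_rng(0)
Xg = np.column_stack([np.zeros(50), rng.standard_normal(50)])
cov = np.cov(Xg, rowvar=False)

# The covariance is diagonal here, so its eigenvalues are the per-column
# variances; the first one is exactly 0.
variances = np.diag(cov)

with np.errstate(divide="ignore"):
    whitening = variances ** -0.5        # 0 ** -0.5 -> inf
    log_det = np.sum(np.log(variances))  # log(0) -> -inf

print(np.linalg.matrix_rank(cov))  # 1
print(whitening[0], log_det)       # inf -inf
```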
Steps/Code to Reproduce
import numpy as np
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis

# Report floating-point problems as warnings instead of silently ignoring them.
np.seterr(all="warn")

def dataset_fixed_cov():
    n, dim = 300, 2
    np.random.seed(0)
    C = np.array([[0.0, -0.23], [0.83, 0.23]])
    X = np.r_[
        # Class 0: rows are [a, 0] @ C == [0, -0.23 * a], i.e. rank 1.
        np.dot(np.concatenate((np.random.randn(n, 1), np.zeros((n, 1))), axis=1), C),
        # Class 1: full-rank Gaussian shifted by (1, 1).
        np.dot(np.random.randn(n, dim), C) + np.array([1, 1]),
    ]
    y = np.hstack((np.zeros(n), np.ones(n)))
    return X, y

X, y = dataset_fixed_cov()
qda = QuadraticDiscriminantAnalysis(store_covariance=True)
y_pred = qda.fit(X, y).predict(X)
print("Correct: {}/{}".format(np.sum(y == y_pred), len(y)))
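One way to confirm which class is rank-deficient is to compute the rank of each class's mean-centered samples. The sketch below rebuilds the same dataset with a seeded local `RandomState` and uses a hypothetical helper, `per_class_rank`, built on `np.linalg.matrix_rank`; it needs only NumPy, not scikit-learn.

```python
import numpy as np


def dataset_fixed_cov(n=300, seed=0):
    """Same construction as the reproduction script, with a local RNG."""
    rng = np.random.RandomState(seed)
    C = np.array([[0.0, -0.23], [0.83, 0.23]])
    # Class 0: [a, 0] @ C == [0, -0.23 * a], so its first column is all zeros.
    X0 = np.dot(np.concatenate((rng.randn(n, 1), np.zeros((n, 1))), axis=1), C)
    # Class 1: full-rank Gaussian, shifted by (1, 1).
    X1 = np.dot(rng.randn(n, 2), C) + np.array([1.0, 1.0])
    X = np.r_[X0, X1]
    y = np.hstack((np.zeros(n), np.ones(n)))
    return X, y


def per_class_rank(X, y):
    """Hypothetical helper: rank of each class's mean-centered samples."""
    ranks = {}
    for label in np.unique(y):
        Xg = X[y == label]
        ranks[float(label)] = int(np.linalg.matrix_rank(Xg - Xg.mean(axis=0)))
    return ranks


X, y = dataset_fixed_cov()
print(per_class_rank(X, y))  # {0.0: 1, 1.0: 2}
```

A rank below the number of features (2 here) means that class's covariance estimate is singular, which is exactly the case that produces the "Variables are collinear" warning.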
Expected Results
No divide-by-zero warnings are raised, and the classifier predicts correctly.
Actual Results
Prediction fails for every sample of one of the two classes, and the following warnings are raised:
/usr/local/lib/python3.7/site-packages/sklearn/discriminant_analysis.py:693: UserWarning: Variables are collinear
  warnings.warn("Variables are collinear")
/usr/local/lib/python3.7/site-packages/sklearn/discriminant_analysis.py:717: RuntimeWarning: divide by zero encountered in power
  X2 = np.dot(Xm, R * (S ** (-0.5)))
/usr/local/lib/python3.7/site-packages/sklearn/discriminant_analysis.py:717: RuntimeWarning: invalid value encountered in multiply
  X2 = np.dot(Xm, R * (S ** (-0.5)))
/usr/local/lib/python3.7/site-packages/sklearn/discriminant_analysis.py:720: RuntimeWarning: divide by zero encountered in log
  u = np.asarray([np.sum(np.log(s)) for s in self.scalings_])
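The failing lines above raise per-direction variances to the power -0.5 and take their log, so any zero variance becomes inf or -inf. scikit-learn's QDA documents a `reg_param` parameter that shrinks each class's covariance estimate as S2 = (1 - reg_param) * S2 + reg_param. The sketch below applies that formula to a hypothetical vector of variances with plain NumPy, assuming it matches what the estimator does internally; the helper name `regularized_variances` is hypothetical.

```python
import numpy as np


def regularized_variances(variances, reg_param):
    """Hypothetical helper: shrink per-direction variances toward 1.

    Assumption: this follows the documented reg_param transformation
    S2 = (1 - reg_param) * S2 + reg_param, where S2 holds a class's
    variances along its principal axes.
    """
    if not 0.0 <= reg_param <= 1.0:
        raise ValueError("reg_param must be in [0, 1]")
    return (1.0 - reg_param) * np.asarray(variances, dtype=float) + reg_param


# Hypothetical variances for a rank-deficient class: one direction is flat.
variances = np.array([0.0529, 0.0])
S = regularized_variances(variances, reg_param=0.01)

# Every entry is now at least reg_param > 0, so the operations from the
# trace (power -0.5 and log) stay finite and emit no divide-by-zero warning.
print(np.all(np.isfinite(S ** -0.5)), np.isfinite(np.sum(np.log(S))))  # True True
```

In the reproduction script, the corresponding change would be `QuadraticDiscriminantAnalysis(reg_param=0.01, store_covariance=True)`; the value 0.01 is arbitrary, and whether it yields the expected accuracy on this dataset is untested.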
Versions
System:
python: 3.7.4 (default, Sep 12 2019, 15:40:15) [GCC 8.3.0]
executable: /usr/local/bin/python
machine: Linux-5.2.14-arch2-1-ARCH-x86_64-with-debian-10.1
Python deps:
pip: 19.2.3
setuptools: 41.2.0
sklearn: 0.21.3
numpy: 1.17.2
scipy: 1.3.1
Cython: None
pandas: None