Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 8c6a045

Browse filesBrowse files
glemaitrethomasjpfanadrinjalali
authored
ENH/DEP add class method and deprecate plot function for confusion matrix (#18543)
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com> Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>
1 parent 7404a82 commit 8c6a045
Copy full SHA for 8c6a045

File tree

7 files changed

+719
-11
lines changed
Filter options

7 files changed

+719
-11
lines changed

‎doc/modules/model_evaluation.rst

Copy file name to clipboardExpand all lines: doc/modules/model_evaluation.rst
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,7 @@ predicted to be in group :math:`j`. Here is an example::
613613
[0, 0, 1],
614614
[1, 0, 2]])
615615

616-
:func:`plot_confusion_matrix` can be used to visually represent a confusion
616+
:class:`ConfusionMatrixDisplay` can be used to visually represent a confusion
617617
matrix as shown in the
618618
:ref:`sphx_glr_auto_examples_model_selection_plot_confusion_matrix.py`
619619
example, which creates the following figure:

‎doc/whats_new/v1.0.rst

Copy file name to clipboardExpand all lines: doc/whats_new/v1.0.rst
+11Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,17 @@ Changelog
9191
:pr:`17743` by :user:`Maria Telenczuk <maikia>` and
9292
:user:`Alexandre Gramfort <agramfort>`.
9393

94+
:mod:`sklearn.metrics`
95+
......................
96+
97+
- |API| :class:`metrics.ConfusionMatrixDisplay` exposes two class methods
98+
:func:`~metrics.ConfusionMatrixDisplay.from_estimator` and
99+
:func:`~metrics.ConfusionMatrixDisplay.from_predictions` allowing to create
100+
a confusion matrix plot using an estimator or the predictions.
101+
:func:`metrics.plot_confusion_matrix` is deprecated in favor of these two
102+
class methods and will be removed in 1.2.
103+
:pr:`18543` by `Guillaume Lemaitre`_.
104+
94105
:mod:`sklearn.naive_bayes`
95106
..........................
96107

‎examples/classification/plot_digits_classification.py

Copy file name to clipboardExpand all lines: examples/classification/plot_digits_classification.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@
9595
# We can also plot a :ref:`confusion matrix <confusion_matrix>` of the
9696
# true digit values and the predicted digit values.
9797

98-
disp = metrics.plot_confusion_matrix(clf, X_test, y_test)
98+
disp = metrics.ConfusionMatrixDisplay.from_predictions(y_test, predicted)
9999
disp.figure_.suptitle("Confusion Matrix")
100100
print(f"Confusion matrix:\n{disp.confusion_matrix}")
101101

‎examples/model_selection/plot_confusion_matrix.py

Copy file name to clipboardExpand all lines: examples/model_selection/plot_confusion_matrix.py
+5-5Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
from sklearn import svm, datasets
3333
from sklearn.model_selection import train_test_split
34-
from sklearn.metrics import plot_confusion_matrix
34+
from sklearn.metrics import ConfusionMatrixDisplay
3535

3636
# import some data to play with
3737
iris = datasets.load_iris()
@@ -52,10 +52,10 @@
5252
titles_options = [("Confusion matrix, without normalization", None),
5353
("Normalized confusion matrix", 'true')]
5454
for title, normalize in titles_options:
55-
disp = plot_confusion_matrix(classifier, X_test, y_test,
56-
display_labels=class_names,
57-
cmap=plt.cm.Blues,
58-
normalize=normalize)
55+
disp = ConfusionMatrixDisplay.from_estimator(
56+
classifier, X_test, y_test, display_labels=class_names,
57+
cmap=plt.cm.Blues, normalize=normalize
58+
)
5959
disp.ax_.set_title(title)
6060

6161
print(title)

‎sklearn/metrics/_plot/confusion_matrix.py

Copy file name to clipboardExpand all lines: sklearn/metrics/_plot/confusion_matrix.py
+286-4Lines changed: 286 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from .. import confusion_matrix
66
from ...utils import check_matplotlib_support
7+
from ...utils import deprecated
78
from ...utils.multiclass import unique_labels
89
from ...utils.validation import _deprecate_positional_args
910
from ...base import is_classifier
@@ -12,7 +13,9 @@
1213
class ConfusionMatrixDisplay:
1314
"""Confusion Matrix visualization.
1415
15-
It is recommend to use :func:`~sklearn.metrics.plot_confusion_matrix` to
16+
It is recommend to use
17+
:func:`~sklearn.metrics.ConfusionMatrixDisplay.from_estimator` or
18+
:func:`~sklearn.metrics.ConfusionMatrixDisplay.from_predictions` to
1619
create a :class:`ConfusionMatrixDisplay`. All parameters are stored as
1720
attributes.
1821
@@ -161,7 +164,274 @@ def plot(self, *, include_values=True, cmap='viridis',
161164
self.ax_ = ax
162165
return self
163166

167+
@classmethod
168+
def from_estimator(
169+
cls,
170+
estimator,
171+
X,
172+
y,
173+
*,
174+
labels=None,
175+
sample_weight=None,
176+
normalize=None,
177+
display_labels=None,
178+
include_values=True,
179+
xticks_rotation="horizontal",
180+
values_format=None,
181+
cmap="viridis",
182+
ax=None,
183+
colorbar=True,
184+
):
185+
"""Plot Confusion Matrix given an estimator and some data.
186+
187+
Read more in the :ref:`User Guide <confusion_matrix>`.
188+
189+
.. versionadded:: 1.0
164190
191+
Parameters
192+
----------
193+
estimator : estimator instance
194+
Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
195+
in which the last estimator is a classifier.
196+
197+
X : {array-like, sparse matrix} of shape (n_samples, n_features)
198+
Input values.
199+
200+
y : array-like of shape (n_samples,)
201+
Target values.
202+
203+
labels : array-like of shape (n_classes,), default=None
204+
List of labels to index the confusion matrix. This may be used to
205+
reorder or select a subset of labels. If `None` is given, those
206+
that appear at least once in `y_true` or `y_pred` are used in
207+
sorted order.
208+
209+
sample_weight : array-like of shape (n_samples,), default=None
210+
Sample weights.
211+
212+
normalize : {'true', 'pred', 'all'}, default=None
213+
Either to normalize the counts display in the matrix:
214+
215+
- if `'true'`, the confusion matrix is normalized over the true
216+
conditions (e.g. rows);
217+
- if `'pred'`, the confusion matrix is normalized over the
218+
predicted conditions (e.g. columns);
219+
- if `'all'`, the confusion matrix is normalized by the total
220+
number of samples;
221+
- if `None` (default), the confusion matrix will not be normalized.
222+
223+
display_labels : array-like of shape (n_classes,), default=None
224+
Target names used for plotting. By default, `labels` will be used
225+
if it is defined, otherwise the unique labels of `y_true` and
226+
`y_pred` will be used.
227+
228+
include_values : bool, default=True
229+
Includes values in confusion matrix.
230+
231+
xticks_rotation : {'vertical', 'horizontal'} or float, \
232+
default='horizontal'
233+
Rotation of xtick labels.
234+
235+
values_format : str, default=None
236+
Format specification for values in confusion matrix. If `None`, the
237+
format specification is 'd' or '.2g' whichever is shorter.
238+
239+
cmap : str or matplotlib Colormap, default='viridis'
240+
Colormap recognized by matplotlib.
241+
242+
ax : matplotlib Axes, default=None
243+
Axes object to plot on. If `None`, a new figure and axes is
244+
created.
245+
246+
colorbar : bool, default=True
247+
Whether or not to add a colorbar to the plot.
248+
249+
Returns
250+
-------
251+
display : :class:`~sklearn.metrics.ConfusionMatrixDisplay`
252+
253+
See Also
254+
--------
255+
ConfusionMatrixDisplay.from_predictions : Plot the confusion matrix
256+
given the true and predicted labels.
257+
258+
Examples
259+
--------
260+
>>> import matplotlib.pyplot as plt # doctest: +SKIP
261+
>>> from sklearn.datasets import make_classification
262+
>>> from sklearn.metrics import ConfusionMatrixDisplay
263+
>>> from sklearn.model_selection import train_test_split
264+
>>> from sklearn.svm import SVC
265+
>>> X, y = make_classification(random_state=0)
266+
>>> X_train, X_test, y_train, y_test = train_test_split(
267+
... X, y, random_state=0)
268+
>>> clf = SVC(random_state=0)
269+
>>> clf.fit(X_train, y_train)
270+
SVC(random_state=0)
271+
>>> ConfusionMatrixDisplay.from_estimator(
272+
... clf, X_test, y_test) # doctest: +SKIP
273+
>>> plt.show() # doctest: +SKIP
274+
"""
275+
method_name = f"{cls.__name__}.from_estimator"
276+
check_matplotlib_support(method_name)
277+
if not is_classifier(estimator):
278+
raise ValueError(f"{method_name} only supports classifiers")
279+
y_pred = estimator.predict(X)
280+
281+
return cls.from_predictions(
282+
y,
283+
y_pred,
284+
sample_weight=sample_weight,
285+
labels=labels,
286+
normalize=normalize,
287+
display_labels=display_labels,
288+
include_values=include_values,
289+
cmap=cmap,
290+
ax=ax,
291+
xticks_rotation=xticks_rotation,
292+
values_format=values_format,
293+
colorbar=colorbar,
294+
)
295+
296+
@classmethod
297+
def from_predictions(
298+
cls,
299+
y_true,
300+
y_pred,
301+
*,
302+
labels=None,
303+
sample_weight=None,
304+
normalize=None,
305+
display_labels=None,
306+
include_values=True,
307+
xticks_rotation="horizontal",
308+
values_format=None,
309+
cmap="viridis",
310+
ax=None,
311+
colorbar=True,
312+
):
313+
"""Plot Confusion Matrix given true and predicted labels.
314+
315+
Read more in the :ref:`User Guide <confusion_matrix>`.
316+
317+
.. versionadded:: 0.24
318+
319+
Parameters
320+
----------
321+
y_true : array-like of shape (n_samples,)
322+
True labels.
323+
324+
y_pred : array-like of shape (n_samples,)
325+
The predicted labels given by the method `predict` of an
326+
classifier.
327+
328+
labels : array-like of shape (n_classes,), default=None
329+
List of labels to index the confusion matrix. This may be used to
330+
reorder or select a subset of labels. If `None` is given, those
331+
that appear at least once in `y_true` or `y_pred` are used in
332+
sorted order.
333+
334+
sample_weight : array-like of shape (n_samples,), default=None
335+
Sample weights.
336+
337+
normalize : {'true', 'pred', 'all'}, default=None
338+
Either to normalize the counts display in the matrix:
339+
340+
- if `'true'`, the confusion matrix is normalized over the true
341+
conditions (e.g. rows);
342+
- if `'pred'`, the confusion matrix is normalized over the
343+
predicted conditions (e.g. columns);
344+
- if `'all'`, the confusion matrix is normalized by the total
345+
number of samples;
346+
- if `None` (default), the confusion matrix will not be normalized.
347+
348+
display_labels : array-like of shape (n_classes,), default=None
349+
Target names used for plotting. By default, `labels` will be used
350+
if it is defined, otherwise the unique labels of `y_true` and
351+
`y_pred` will be used.
352+
353+
include_values : bool, default=True
354+
Includes values in confusion matrix.
355+
356+
xticks_rotation : {'vertical', 'horizontal'} or float, \
357+
default='horizontal'
358+
Rotation of xtick labels.
359+
360+
values_format : str, default=None
361+
Format specification for values in confusion matrix. If `None`, the
362+
format specification is 'd' or '.2g' whichever is shorter.
363+
364+
cmap : str or matplotlib Colormap, default='viridis'
365+
Colormap recognized by matplotlib.
366+
367+
ax : matplotlib Axes, default=None
368+
Axes object to plot on. If `None`, a new figure and axes is
369+
created.
370+
371+
colorbar : bool, default=True
372+
Whether or not to add a colorbar to the plot.
373+
374+
Returns
375+
-------
376+
display : :class:`~sklearn.metrics.ConfusionMatrixDisplay`
377+
378+
See Also
379+
--------
380+
ConfusionMatrixDisplay.from_estimator : Plot the confusion matrix
381+
given an estimator, the data, and the label.
382+
383+
Examples
384+
--------
385+
>>> import matplotlib.pyplot as plt # doctest: +SKIP
386+
>>> from sklearn.datasets import make_classification
387+
>>> from sklearn.metrics import ConfusionMatrixDisplay
388+
>>> from sklearn.model_selection import train_test_split
389+
>>> from sklearn.svm import SVC
390+
>>> X, y = make_classification(random_state=0)
391+
>>> X_train, X_test, y_train, y_test = train_test_split(
392+
... X, y, random_state=0)
393+
>>> clf = SVC(random_state=0)
394+
>>> clf.fit(X_train, y_train)
395+
SVC(random_state=0)
396+
>>> y_pred = clf.predict(X_test)
397+
>>> ConfusionMatrixDisplay.from_predictions(
398+
... y_test, y_pred) # doctest: +SKIP
399+
>>> plt.show() # doctest: +SKIP
400+
"""
401+
check_matplotlib_support(f"{cls.__name__}.from_predictions")
402+
403+
if display_labels is None:
404+
if labels is None:
405+
display_labels = unique_labels(y_true, y_pred)
406+
else:
407+
display_labels = labels
408+
409+
cm = confusion_matrix(
410+
y_true,
411+
y_pred,
412+
sample_weight=sample_weight,
413+
labels=labels,
414+
normalize=normalize,
415+
)
416+
417+
disp = cls(confusion_matrix=cm, display_labels=display_labels)
418+
419+
return disp.plot(
420+
include_values=include_values,
421+
cmap=cmap,
422+
ax=ax,
423+
xticks_rotation=xticks_rotation,
424+
values_format=values_format,
425+
colorbar=colorbar,
426+
)
427+
428+
429+
@deprecated(
430+
"Function plot_confusion_matrix is deprecated in 1.0 and will be "
431+
"removed in 1.2. Use one of the class methods: "
432+
"ConfusionMatrixDisplay.from_predictions or "
433+
"ConfusionMatrixDisplay.from_estimator."
434+
)
165435
@_deprecate_positional_args
166436
def plot_confusion_matrix(estimator, X, y_true, *, labels=None,
167437
sample_weight=None, normalize=None,
@@ -173,6 +443,12 @@ def plot_confusion_matrix(estimator, X, y_true, *, labels=None,
173443
174444
Read more in the :ref:`User Guide <confusion_matrix>`.
175445
446+
.. deprecated:: 1.0
447+
`plot_confusion_matrix` is deprecated in 1.0 and will be removed in
448+
1.2. Use one of the following class methods:
449+
:func:`~sklearn.metrics.ConfusionMatrixDisplay.from_predictions` or
450+
:func:`~sklearn.metrics.ConfusionMatrixDisplay.from_estimator`.
451+
176452
Parameters
177453
----------
178454
estimator : estimator instance
@@ -194,9 +470,15 @@ def plot_confusion_matrix(estimator, X, y_true, *, labels=None,
194470
Sample weights.
195471
196472
normalize : {'true', 'pred', 'all'}, default=None
197-
Normalizes confusion matrix over the true (rows), predicted (columns)
198-
conditions or all the population. If None, confusion matrix will not be
199-
normalized.
473+
Either to normalize the counts display in the matrix:
474+
475+
- if `'true'`, the confusion matrix is normalized over the true
476+
conditions (e.g. rows);
477+
- if `'pred'`, the confusion matrix is normalized over the
478+
predicted conditions (e.g. columns);
479+
- if `'all'`, the confusion matrix is normalized by the total
480+
number of samples;
481+
- if `None` (default), the confusion matrix will not be normalized.
200482
201483
display_labels : array-like of shape (n_classes,), default=None
202484
Target names used for plotting. By default, `labels` will be used if

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.