4
4
5
5
from .. import confusion_matrix
6
6
from ...utils import check_matplotlib_support
7
+ from ...utils import deprecated
7
8
from ...utils .multiclass import unique_labels
8
9
from ...utils .validation import _deprecate_positional_args
9
10
from ...base import is_classifier
12
13
class ConfusionMatrixDisplay :
13
14
"""Confusion Matrix visualization.
14
15
15
- It is recommend to use :func:`~sklearn.metrics.plot_confusion_matrix` to
16
+ It is recommend to use
17
+ :func:`~sklearn.metrics.ConfusionMatrixDisplay.from_estimator` or
18
+ :func:`~sklearn.metrics.ConfusionMatrixDisplay.from_predictions` to
16
19
create a :class:`ConfusionMatrixDisplay`. All parameters are stored as
17
20
attributes.
18
21
@@ -161,7 +164,274 @@ def plot(self, *, include_values=True, cmap='viridis',
161
164
self .ax_ = ax
162
165
return self
163
166
167
+ @classmethod
168
+ def from_estimator (
169
+ cls ,
170
+ estimator ,
171
+ X ,
172
+ y ,
173
+ * ,
174
+ labels = None ,
175
+ sample_weight = None ,
176
+ normalize = None ,
177
+ display_labels = None ,
178
+ include_values = True ,
179
+ xticks_rotation = "horizontal" ,
180
+ values_format = None ,
181
+ cmap = "viridis" ,
182
+ ax = None ,
183
+ colorbar = True ,
184
+ ):
185
+ """Plot Confusion Matrix given an estimator and some data.
186
+
187
+ Read more in the :ref:`User Guide <confusion_matrix>`.
188
+
189
+ .. versionadded:: 1.0
164
190
191
+ Parameters
192
+ ----------
193
+ estimator : estimator instance
194
+ Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
195
+ in which the last estimator is a classifier.
196
+
197
+ X : {array-like, sparse matrix} of shape (n_samples, n_features)
198
+ Input values.
199
+
200
+ y : array-like of shape (n_samples,)
201
+ Target values.
202
+
203
+ labels : array-like of shape (n_classes,), default=None
204
+ List of labels to index the confusion matrix. This may be used to
205
+ reorder or select a subset of labels. If `None` is given, those
206
+ that appear at least once in `y_true` or `y_pred` are used in
207
+ sorted order.
208
+
209
+ sample_weight : array-like of shape (n_samples,), default=None
210
+ Sample weights.
211
+
212
+ normalize : {'true', 'pred', 'all'}, default=None
213
+ Either to normalize the counts display in the matrix:
214
+
215
+ - if `'true'`, the confusion matrix is normalized over the true
216
+ conditions (e.g. rows);
217
+ - if `'pred'`, the confusion matrix is normalized over the
218
+ predicted conditions (e.g. columns);
219
+ - if `'all'`, the confusion matrix is normalized by the total
220
+ number of samples;
221
+ - if `None` (default), the confusion matrix will not be normalized.
222
+
223
+ display_labels : array-like of shape (n_classes,), default=None
224
+ Target names used for plotting. By default, `labels` will be used
225
+ if it is defined, otherwise the unique labels of `y_true` and
226
+ `y_pred` will be used.
227
+
228
+ include_values : bool, default=True
229
+ Includes values in confusion matrix.
230
+
231
+ xticks_rotation : {'vertical', 'horizontal'} or float, \
232
+ default='horizontal'
233
+ Rotation of xtick labels.
234
+
235
+ values_format : str, default=None
236
+ Format specification for values in confusion matrix. If `None`, the
237
+ format specification is 'd' or '.2g' whichever is shorter.
238
+
239
+ cmap : str or matplotlib Colormap, default='viridis'
240
+ Colormap recognized by matplotlib.
241
+
242
+ ax : matplotlib Axes, default=None
243
+ Axes object to plot on. If `None`, a new figure and axes is
244
+ created.
245
+
246
+ colorbar : bool, default=True
247
+ Whether or not to add a colorbar to the plot.
248
+
249
+ Returns
250
+ -------
251
+ display : :class:`~sklearn.metrics.ConfusionMatrixDisplay`
252
+
253
+ See Also
254
+ --------
255
+ ConfusionMatrixDisplay.from_predictions : Plot the confusion matrix
256
+ given the true and predicted labels.
257
+
258
+ Examples
259
+ --------
260
+ >>> import matplotlib.pyplot as plt # doctest: +SKIP
261
+ >>> from sklearn.datasets import make_classification
262
+ >>> from sklearn.metrics import ConfusionMatrixDisplay
263
+ >>> from sklearn.model_selection import train_test_split
264
+ >>> from sklearn.svm import SVC
265
+ >>> X, y = make_classification(random_state=0)
266
+ >>> X_train, X_test, y_train, y_test = train_test_split(
267
+ ... X, y, random_state=0)
268
+ >>> clf = SVC(random_state=0)
269
+ >>> clf.fit(X_train, y_train)
270
+ SVC(random_state=0)
271
+ >>> ConfusionMatrixDisplay.from_estimator(
272
+ ... clf, X_test, y_test) # doctest: +SKIP
273
+ >>> plt.show() # doctest: +SKIP
274
+ """
275
+ method_name = f"{ cls .__name__ } .from_estimator"
276
+ check_matplotlib_support (method_name )
277
+ if not is_classifier (estimator ):
278
+ raise ValueError (f"{ method_name } only supports classifiers" )
279
+ y_pred = estimator .predict (X )
280
+
281
+ return cls .from_predictions (
282
+ y ,
283
+ y_pred ,
284
+ sample_weight = sample_weight ,
285
+ labels = labels ,
286
+ normalize = normalize ,
287
+ display_labels = display_labels ,
288
+ include_values = include_values ,
289
+ cmap = cmap ,
290
+ ax = ax ,
291
+ xticks_rotation = xticks_rotation ,
292
+ values_format = values_format ,
293
+ colorbar = colorbar ,
294
+ )
295
+
296
+ @classmethod
297
+ def from_predictions (
298
+ cls ,
299
+ y_true ,
300
+ y_pred ,
301
+ * ,
302
+ labels = None ,
303
+ sample_weight = None ,
304
+ normalize = None ,
305
+ display_labels = None ,
306
+ include_values = True ,
307
+ xticks_rotation = "horizontal" ,
308
+ values_format = None ,
309
+ cmap = "viridis" ,
310
+ ax = None ,
311
+ colorbar = True ,
312
+ ):
313
+ """Plot Confusion Matrix given true and predicted labels.
314
+
315
+ Read more in the :ref:`User Guide <confusion_matrix>`.
316
+
317
+ .. versionadded:: 0.24
318
+
319
+ Parameters
320
+ ----------
321
+ y_true : array-like of shape (n_samples,)
322
+ True labels.
323
+
324
+ y_pred : array-like of shape (n_samples,)
325
+ The predicted labels given by the method `predict` of an
326
+ classifier.
327
+
328
+ labels : array-like of shape (n_classes,), default=None
329
+ List of labels to index the confusion matrix. This may be used to
330
+ reorder or select a subset of labels. If `None` is given, those
331
+ that appear at least once in `y_true` or `y_pred` are used in
332
+ sorted order.
333
+
334
+ sample_weight : array-like of shape (n_samples,), default=None
335
+ Sample weights.
336
+
337
+ normalize : {'true', 'pred', 'all'}, default=None
338
+ Either to normalize the counts display in the matrix:
339
+
340
+ - if `'true'`, the confusion matrix is normalized over the true
341
+ conditions (e.g. rows);
342
+ - if `'pred'`, the confusion matrix is normalized over the
343
+ predicted conditions (e.g. columns);
344
+ - if `'all'`, the confusion matrix is normalized by the total
345
+ number of samples;
346
+ - if `None` (default), the confusion matrix will not be normalized.
347
+
348
+ display_labels : array-like of shape (n_classes,), default=None
349
+ Target names used for plotting. By default, `labels` will be used
350
+ if it is defined, otherwise the unique labels of `y_true` and
351
+ `y_pred` will be used.
352
+
353
+ include_values : bool, default=True
354
+ Includes values in confusion matrix.
355
+
356
+ xticks_rotation : {'vertical', 'horizontal'} or float, \
357
+ default='horizontal'
358
+ Rotation of xtick labels.
359
+
360
+ values_format : str, default=None
361
+ Format specification for values in confusion matrix. If `None`, the
362
+ format specification is 'd' or '.2g' whichever is shorter.
363
+
364
+ cmap : str or matplotlib Colormap, default='viridis'
365
+ Colormap recognized by matplotlib.
366
+
367
+ ax : matplotlib Axes, default=None
368
+ Axes object to plot on. If `None`, a new figure and axes is
369
+ created.
370
+
371
+ colorbar : bool, default=True
372
+ Whether or not to add a colorbar to the plot.
373
+
374
+ Returns
375
+ -------
376
+ display : :class:`~sklearn.metrics.ConfusionMatrixDisplay`
377
+
378
+ See Also
379
+ --------
380
+ ConfusionMatrixDisplay.from_estimator : Plot the confusion matrix
381
+ given an estimator, the data, and the label.
382
+
383
+ Examples
384
+ --------
385
+ >>> import matplotlib.pyplot as plt # doctest: +SKIP
386
+ >>> from sklearn.datasets import make_classification
387
+ >>> from sklearn.metrics import ConfusionMatrixDisplay
388
+ >>> from sklearn.model_selection import train_test_split
389
+ >>> from sklearn.svm import SVC
390
+ >>> X, y = make_classification(random_state=0)
391
+ >>> X_train, X_test, y_train, y_test = train_test_split(
392
+ ... X, y, random_state=0)
393
+ >>> clf = SVC(random_state=0)
394
+ >>> clf.fit(X_train, y_train)
395
+ SVC(random_state=0)
396
+ >>> y_pred = clf.predict(X_test)
397
+ >>> ConfusionMatrixDisplay.from_predictions(
398
+ ... y_test, y_pred) # doctest: +SKIP
399
+ >>> plt.show() # doctest: +SKIP
400
+ """
401
+ check_matplotlib_support (f"{ cls .__name__ } .from_predictions" )
402
+
403
+ if display_labels is None :
404
+ if labels is None :
405
+ display_labels = unique_labels (y_true , y_pred )
406
+ else :
407
+ display_labels = labels
408
+
409
+ cm = confusion_matrix (
410
+ y_true ,
411
+ y_pred ,
412
+ sample_weight = sample_weight ,
413
+ labels = labels ,
414
+ normalize = normalize ,
415
+ )
416
+
417
+ disp = cls (confusion_matrix = cm , display_labels = display_labels )
418
+
419
+ return disp .plot (
420
+ include_values = include_values ,
421
+ cmap = cmap ,
422
+ ax = ax ,
423
+ xticks_rotation = xticks_rotation ,
424
+ values_format = values_format ,
425
+ colorbar = colorbar ,
426
+ )
427
+
428
+
429
+ @deprecated (
430
+ "Function plot_confusion_matrix is deprecated in 1.0 and will be "
431
+ "removed in 1.2. Use one of the class methods: "
432
+ "ConfusionMatrixDisplay.from_predictions or "
433
+ "ConfusionMatrixDisplay.from_estimator."
434
+ )
165
435
@_deprecate_positional_args
166
436
def plot_confusion_matrix (estimator , X , y_true , * , labels = None ,
167
437
sample_weight = None , normalize = None ,
@@ -173,6 +443,12 @@ def plot_confusion_matrix(estimator, X, y_true, *, labels=None,
173
443
174
444
Read more in the :ref:`User Guide <confusion_matrix>`.
175
445
446
+ .. deprecated:: 1.0
447
+ `plot_confusion_matrix` is deprecated in 1.0 and will be removed in
448
+ 1.2. Use one of the following class methods:
449
+ :func:`~sklearn.metrics.ConfusionMatrixDisplay.from_predictions` or
450
+ :func:`~sklearn.metrics.ConfusionMatrixDisplay.from_estimator`.
451
+
176
452
Parameters
177
453
----------
178
454
estimator : estimator instance
@@ -194,9 +470,15 @@ def plot_confusion_matrix(estimator, X, y_true, *, labels=None,
194
470
Sample weights.
195
471
196
472
normalize : {'true', 'pred', 'all'}, default=None
197
- Normalizes confusion matrix over the true (rows), predicted (columns)
198
- conditions or all the population. If None, confusion matrix will not be
199
- normalized.
473
+ Either to normalize the counts display in the matrix:
474
+
475
+ - if `'true'`, the confusion matrix is normalized over the true
476
+ conditions (e.g. rows);
477
+ - if `'pred'`, the confusion matrix is normalized over the
478
+ predicted conditions (e.g. columns);
479
+ - if `'all'`, the confusion matrix is normalized by the total
480
+ number of samples;
481
+ - if `None` (default), the confusion matrix will not be normalized.
200
482
201
483
display_labels : array-like of shape (n_classes,), default=None
202
484
Target names used for plotting. By default, `labels` will be used if
0 commit comments