Commit 5f15950

Add a scorer for model evaluation
1 parent 51f8c15 commit 5f15950

File tree

3 files changed: +14 −5 lines changed

doc/modules/model_evaluation.rst (+1 −0)

@@ -60,6 +60,7 @@ Scoring Function
 **Classification**
 'accuracy'                :func:`metrics.accuracy_score`
 'balanced_accuracy'       :func:`metrics.balanced_accuracy_score`
+'top_k_accuracy'          :func:`metrics.top_k_accuracy_score`
 'average_precision'       :func:`metrics.average_precision_score`
 'neg_brier_score'         :func:`metrics.brier_score_loss`
 'f1'                      :func:`metrics.f1_score` for binary targets
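
With this table entry in place, 'top_k_accuracy' becomes a valid scoring string. A quick way to confirm the registration, shown here as an illustrative sketch rather than part of the commit, is to look the name up through sklearn.metrics.get_scorer:

    from sklearn.metrics import get_scorer

    # Resolves the string to the scorer object registered in _scorer.py;
    # the repr should mention top_k_accuracy_score and needs_threshold.
    scorer = get_scorer('top_k_accuracy')
    print(scorer)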

sklearn/metrics/_scorer.py (+9 −4)

@@ -27,9 +27,9 @@
 from . import (r2_score, median_absolute_error, max_error, mean_absolute_error,
                mean_squared_error, mean_squared_log_error,
                mean_poisson_deviance, mean_gamma_deviance, accuracy_score,
-               f1_score, roc_auc_score, average_precision_score,
-               precision_score, recall_score, log_loss,
-               balanced_accuracy_score, explained_variance_score,
+               top_k_accuracy_score, f1_score, roc_auc_score,
+               average_precision_score, precision_score, recall_score,
+               log_loss, balanced_accuracy_score, explained_variance_score,
                brier_score_loss, jaccard_score, mean_absolute_percentage_error)

 from .cluster import adjusted_rand_score

@@ -610,6 +610,9 @@ def make_scorer(score_func, *, greater_is_better=True, needs_proba=False,
 balanced_accuracy_scorer = make_scorer(balanced_accuracy_score)

 # Score functions that need decision values
+top_k_accuracy_scorer = make_scorer(top_k_accuracy_score,
+                                    greater_is_better=True,
+                                    needs_threshold=True)
 roc_auc_scorer = make_scorer(roc_auc_score, greater_is_better=True,
                              needs_threshold=True)
 average_precision_scorer = make_scorer(average_precision_score,

@@ -658,7 +661,9 @@ def make_scorer(score_func, *, greater_is_better=True, needs_proba=False,
               neg_root_mean_squared_error=neg_root_mean_squared_error_scorer,
               neg_mean_poisson_deviance=neg_mean_poisson_deviance_scorer,
               neg_mean_gamma_deviance=neg_mean_gamma_deviance_scorer,
-              accuracy=accuracy_scorer, roc_auc=roc_auc_scorer,
+              accuracy=accuracy_scorer,
+              top_k_accuracy=top_k_accuracy_scorer,
+              roc_auc=roc_auc_scorer,
               roc_auc_ovr=roc_auc_ovr_scorer,
               roc_auc_ovo=roc_auc_ovo_scorer,
               roc_auc_ovr_weighted=roc_auc_ovr_weighted_scorer,
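
Because make_scorer forwards extra keyword arguments to the underlying metric, the scorer registered here uses top_k_accuracy_score's default of k=2. A minimal sketch of a custom variant with an explicit k, following the same pattern as the registered scorer (illustrative, not part of the commit):

    from sklearn.metrics import make_scorer, top_k_accuracy_score

    # Extra keyword arguments (here k=3) are stored on the scorer and
    # passed through to top_k_accuracy_score at scoring time.
    top_3_accuracy_scorer = make_scorer(top_k_accuracy_score,
                                        greater_is_better=True,
                                        needs_threshold=True,
                                        k=3)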

sklearn/metrics/tests/test_score_objects.py (+4 −1)

@@ -53,7 +53,7 @@
                       'max_error', 'neg_mean_poisson_deviance',
                       'neg_mean_gamma_deviance']

-CLF_SCORERS = ['accuracy', 'balanced_accuracy',
+CLF_SCORERS = ['accuracy', 'balanced_accuracy', 'top_k_accuracy',
               'f1', 'f1_weighted', 'f1_macro', 'f1_micro',
               'roc_auc', 'average_precision', 'precision',
               'precision_weighted', 'precision_macro', 'precision_micro',

@@ -496,6 +496,9 @@ def test_classification_scorer_sample_weight():
         if name in REGRESSION_SCORERS:
             # skip the regression scores
             continue
+        if name == 'top_k_accuracy':
+            # in the binary case k > 1 will always lead to a perfect score
+            scorer._kwargs = {'k': 1}
         if name in MULTILABEL_ONLY_SCORERS:
             target = y_ml_test
         else:
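
The k=1 override exists because this test uses binary targets. A short sketch of the underlying metric behaviour, with illustrative values not taken from the test suite:

    import numpy as np
    from sklearn.metrics import top_k_accuracy_score

    # Multiclass: the true label only has to appear among the k highest
    # scored classes, not in first place.
    y_true = np.array([0, 1, 2, 2])
    y_score = np.array([[0.5, 0.2, 0.3],
                        [0.3, 0.4, 0.3],
                        [0.2, 0.4, 0.4],
                        [0.7, 0.2, 0.1]])
    print(top_k_accuracy_score(y_true, y_score, k=2))  # 0.75

    # Binary: with only two classes, k=2 can never miss, so the score is
    # always 1.0 (sklearn may also warn that k >= n_classes is
    # meaningless) -- hence the test pins k=1 via the private _kwargs.
    y_true = np.array([0, 1, 1, 0])
    y_score = np.array([0.1, 0.4, 0.35, 0.8])
    print(top_k_accuracy_score(y_true, y_score, k=2))  # 1.0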
