[MRG+1] TST Move roc_auc_score from METRIC_UNDEFINED_BINARY to METRIC_UNDEFINED_MULTICLASS #9786


Merged (27 commits, Sep 27, 2017)
Conversation

qinhanmin2014
Member

Reference Issue

Proposed in #9567 by @jnothman

What does this implement/fix? Explain your changes.

METRIC_UNDEFINED_BINARY contains metrics that don't support binary inputs, and METRIC_UNDEFINED_MULTICLASS contains metrics that don't support multiclass inputs, so it seems that roc_auc_score belongs in METRIC_UNDEFINED_MULTICLASS.
In order to pass the tests in test_common.py, I have to:
(1) add a check to ensure that the shape of sample_weight is [n_samples] (regression test already in test_common.py);
(2) carefully choose the scaling value in the test to reduce the minor errors Python introduces in floating-point operations (e.g., 4.2 + 2.1 == 6.3 is False); see the sketch below.
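A minimal sketch (not the PR's actual code) of both points; the helper name check_sample_weight_shape below is illustrative only:

import numpy as np

# (2) Floating-point addition is not exact, so rescaling weights by an
# arbitrary float can perturb intermediate sums:
print(4.2 + 2.1 == 6.3)        # False
print(abs((4.2 + 2.1) - 6.3))  # ~8.9e-16

# (1) A check in the spirit of "sample_weight must have shape [n_samples]":
def check_sample_weight_shape(sample_weight, n_samples):
    sample_weight = np.asarray(sample_weight)
    if sample_weight.ndim != 1 or sample_weight.shape[0] != n_samples:
        raise ValueError("sample_weight has shape %r, expected (%d,)"
                         % (sample_weight.shape, n_samples))
    return sample_weight

check_sample_weight_shape(np.ones(5), n_samples=5)  # passes silently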

Any other comments?

cc @jnothman

@jnothman
Member

What kinds of discrepancy between expected and actual do we get if we leave the scaling as it was?

@jnothman
Member

We may just need a higher tolerance in assert_almost_equal because we're now dealing with a threshold-based metric rather than discrete labels.

@qinhanmin2014
Member Author

@jnothman Thanks.
If we use the original scaling value, we have to set decimal=2 to pass the test.
Here is the code to reproduce the test:

import numpy as np
from sklearn.metrics import roc_auc_score
from sklearn.utils import check_random_state

n_samples = 50
random_state = check_random_state(0)
y_true = random_state.randint(0, 2, size=(n_samples, ))
y_pred = random_state.randint(0, 2, size=(n_samples, ))
y_score = random_state.random_sample(size=(n_samples,))
rng = np.random.RandomState(0)
sample_weight = rng.randint(1, 10, size=len(y_true))
roc_auc_score(y_true, y_score, sample_weight=sample_weight)
# 0.38036523593708349
roc_auc_score(y_true, y_score, sample_weight=sample_weight*2)
# 0.38036523593708349
roc_auc_score(y_true, y_score, sample_weight=sample_weight*0.3)
# 0.38733004532124787

I checked the implementation and it seems right. The difference seems to be introduced by Python's floating-point operations.
So should we keep the original scaling value and set decimal=2?

@jnothman
Member

As long as we don't have any estimators that require scores in (0, 1), we could use integer scores to encourage more stability...?

@jnothman
Member

Integer scores, where scores are likely to be equal, would also be a more challenging test to pass

@qinhanmin2014
Member Author

@jnothman Sorry, but I don't quite understand. What do you mean by 'integer scores'? Could you please provide more details? Thanks :)
From my perspective, I can only come up with two solutions:
(1) reduce the precision requirement;
(2) keep the default precision requirement and choose a 'good' scaling value (e.g., 0.5) for sample_weight.

@jnothman
Member

Currently we have y_score = random_state.random_sample(size=(n_samples,)) generating floats in (0,1). Instead we could use random_state.randint(10, size=(n_samples,)) to generate ints in [0,9].

@qinhanmin2014
Member Author

@jnothman
Do you mean something like this?

# random_state, rng, and n_samples as defined in the earlier snippet
y_true = random_state.randint(0, 2, size=(n_samples, ))
y_score = random_state.randint(0, 10, size=(n_samples, ))
sample_weight = rng.randint(1, 10, size=len(y_true))
roc_auc_score(y_true, y_score, sample_weight=sample_weight)

Such a method works for roc_auc_score, but brier_score_loss requires scores in (0, 1), so the common test cannot pass this way (see the sketch below).
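A minimal sketch (illustrative values, not from this PR) of why integer scores work for roc_auc_score but not for brier_score_loss:

import numpy as np
from sklearn.metrics import brier_score_loss, roc_auc_score

y_true = np.array([0, 1, 1, 0, 1])
int_scores = np.array([3, 7, 5, 1, 9])

print(roc_auc_score(y_true, int_scores))           # fine: only the ranking matters
print(brier_score_loss(y_true, int_scores / 9.0))  # fine once rescaled into [0, 1]
# brier_score_loss(y_true, int_scores) would be rejected: not valid probabilities.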

@jnothman
Member

jnothman commented Sep 18, 2017 via email

@jnothman (Member) left a comment:

LGTM

@jnothman jnothman changed the title [MRG] TST Move roc_auc_score from METRIC_UNDEFINED_BINARY to METRIC_UNDEFINED_MULTICLASS [MRG+1] TST Move roc_auc_score from METRIC_UNDEFINED_BINARY to METRIC_UNDEFINED_MULTICLASS Sep 18, 2017
@qinhanmin2014
Member Author

qinhanmin2014 commented Sep 18, 2017

@jnothman
I got a strange test failure with decimal=2 (it passes locally). Could you please help me? Thanks.
(Currently I use decimal=1 to pass the test.)

AssertionError: 
Arrays are not almost equal to 2 decimals
roc_auc_score sample_weight is not invariant under scaling
 ACTUAL: 0.38036523593708343
 DESIRED: 0.38733004532124793
>>  raise AssertionError('\nArrays are not almost equal to 2 decimals\nroc_auc_score sample_weight is not invariant under scaling\n ACTUAL: 0.38036523593708343\n DESIRED: 0.38733004532124793')

@jnothman
Member

jnothman commented Sep 18, 2017 via email

@qinhanmin2014
Member Author

yes on the face of it that looks weird. assert_allclose is a variant which gives more explicit control of tolerances if you'd rather.

@jnothman Thanks. assert_allclose works. CIs are green. Is this OK for you?
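For reference, a minimal sketch with illustrative values (not the exact test change): assert_allclose exposes rtol/atol explicitly, whereas assert_almost_equal only checks a fixed number of decimals.

from numpy.testing import assert_allclose

weighted = 0.380365235937   # illustrative score with sample_weight
rescaled = 0.380365235920   # illustrative score with sample_weight * scaling

assert_allclose(weighted, rescaled, rtol=1e-7, atol=1e-9,
                err_msg="roc_auc_score sample_weight is not invariant "
                        "under scaling")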

@jnothman
Member

Yes, but we'll wait for another review to be sure I'm not missing something!

@lesteve
Member

lesteve commented Sep 18, 2017

atol=1e-2 seems quite high and is only required for the roc_auc metrics (all the other metrics pass the tests with the default values, i.e. atol=0, rtol=1e-7); do we understand why this is the case?

With scaling=1/17 I get differences that are bigger than 1% (in relative terms), which does not seem like it can be explained by floating-point differences ...

E               AssertionError: 
E               Not equal to tolerance rtol=1e-07, atol=0
E               roc_auc_score sample_weight is not invariant under scaling
E               (mismatch 100.0%)
E                x: array(0.3803652359370835)
E                y: array(0.3757664622767263)

@jnothman
Member

jnothman commented Sep 18, 2017 via email

@qinhanmin2014
Member Author

qinhanmin2014 commented Sep 18, 2017

@lesteve Thanks for your review.
Also ping @jnothman.
I traced the whole calculation process of roc_auc_score again. The main source of the differences is the function auc in ranking.py. When entering this function, all the parameters (i.e., x and y) are almost the same (difference < 1e-7). But since we introduce a very small difference in x, we get a different sort order and therefore a much bigger difference from np.trapz. See the following code snippet:

import numpy as np

x1 = [0.6198347107438017, 0.6776859504132231, 0.6776859504132231,
      0.6776859504132231, 0.6776859504132231, 0.6776859504132231,
      0.74380165289256195, 0.74380165289256195, 0.8925619834710744, 
      0.92561983471074383, 1.0, 1.0, 
      1.0, 1.0, 1.0]
x2 = [0.61983471074380181, 0.6776859504132231, 0.67768595041322344, 
      0.67768595041322321, 0.67768595041322321, 0.67768595041322321,
      0.74380165289256217, 0.74380165289256217, 0.89256198347107418, 
      0.92561983471074349, 0.99999999999999956, 0.99999999999999967,
      0.99999999999999989, 0.99999999999999967, 1.0]
y1 = [0.62096774193548387, 0.62096774193548387, 0.62903225806451613, 
      0.66129032258064513, 0.717741935483871, 0.7338709677419355,
      0.7338709677419355, 0.75806451612903225, 0.75806451612903225,
      0.75806451612903225, 0.75806451612903225, 0.80645161290322576, 
      0.81451612903225812, 0.84677419354838712, 0.94354838709677424]
y2 = [0.62096774193548399, 0.62096774193548399, 0.62903225806451624,
      0.66129032258064524, 0.717741935483871, 0.7338709677419355, 
      0.7338709677419355, 0.75806451612903225, 0.75806451612903225,
      0.75806451612903225, 0.75806451612903225, 0.80645161290322587, 
      0.81451612903225812, 0.84677419354838712, 0.94354838709677424]
x1 = np.array(x1)
x2 = np.array(x2)
y1 = np.array(y1)
y2 = np.array(y2)
np.testing.assert_allclose(x1, x2) #pass
np.testing.assert_allclose(y1, y2) #pass
order = np.lexsort((y1, x1))
x1, y1 = x1[order], y1[order]
order = np.lexsort((y2, x2))
x2, y2 = x2[order], y2[order]
print(np.trapz(y1, x1))
# 0.27865902426
print(np.trapz(y2, x2))
# 0.275193281792

If we want to solve the problem, we will need a more robust sorting function (e.g., one that treats two floats as equal when |x1 - x2| < 1e-7) instead of simply calling np.lexsort. We may also need to modify the following statement to use np.allclose(..., 0):

optimal_idxs = np.where(np.r_[True,
    np.logical_or(np.diff(fps, 2), np.diff(tps, 2)), True])[0]

There is also a trickier way to achieve the goal: round x and y before sorting.

x = np.round(x, 10)
y = np.round(y, 10)
order = np.lexsort((y, x))

We get a very similar, stable result this way:

roc_auc_score(y_true, y_score, sample_weight=sample_weight)
# before(unstable):0.38036523593708349
# after(stable):0.38036523591966814

@jnothman
Member

jnothman commented Sep 18, 2017 via email

@lesteve
Member

lesteve commented Sep 19, 2017

@jnothman I trust you on this one. If you think that atol=1e-2 is fine, then let's do that. Maybe have a special case for auc in the test with a FIXME to show that the higher tolerance is only needed for roc_auc scores?

This reverts commit 6b2cf79.
@qinhanmin2014
Member Author

@jnothman So it seems the cause of this problem is clear (np.lexsort and np.trapz in the function auc), and I have come back to the original solution you proposed. Do you think we need a separate test for roc_auc_score here, as proposed by @lesteve?
BTW, @jnothman @lesteve I'm suddenly wondering why we need to sort in auc at all. Don't we already ensure that x (fpr) and y (tpr) are non-decreasing here? If so, this could be solved easily by setting the parameter reorder to False in auc? (See the sketch below.)
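A minimal sketch (not from the PR): when fpr and tpr returned by roc_curve are already non-decreasing, auc can integrate them directly with the trapezoidal rule, with no re-sorting needed (the reorder parameter itself comes up again in the follow-up suggestions below):

import numpy as np
from sklearn.metrics import auc

fpr = np.array([0.0, 0.25, 0.5, 1.0])
tpr = np.array([0.0, 0.5, 0.75, 1.0])

print(auc(fpr, tpr))       # trapezoidal area under the curve
print(np.trapz(tpr, fpr))  # same value, computed directly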

@lesteve
Member

lesteve commented Sep 19, 2017

Do you think we need a separate test for roc_auc_score here, as proposed by @lesteve?

Just for clarity, I was not advocating a completely separate test, but rather a special case, i.e. something like this:

# FIXME: roc_auc scores are more unstable than other scores
kwargs = {'atol': 1e-2} if 'roc_auc' in name else {}

for scaling in [2, 0.3]:
    np.testing.assert_allclose(
        weighted_score,
        metric(y1, y2, sample_weight=sample_weight * scaling),
        err_msg=("%s sample_weight is not invariant "
                 "under scaling" % name),
        **kwargs)

@jnothman (Member) left a comment:

Apart from nitpicks, LGTM

@@ -17,6 +17,8 @@ random sampling procedures.

- :class:`decomposition.IncrementalPCA` in Python 2 (bug fix)
- :class:`isotonic.IsotonicRegression` (bug fix)
- :class:`metrics.roc_auc_score` (enhancement)

@jnothman (Member):

Do you really think this belongs here? Do we believe this will change user results often enough to caution them here? This makes it seem like their ROC scores will have suddenly changed... I think "bug fix" is more appropriate in any case.

@@ -371,6 +371,20 @@ def test_roc_curve_drop_intermediate():
[1.0, 0.9, 0.7, 0.6, 0.])


def test_roc_curve_fpr_tpr_increasing():

@jnothman (Member):

I feel like the fact that elements are sorted for one random sample isn't a very strong assurance. There are edge cases that could be further tested (such as having repeated thresholds), too, but I'm not sure what reasonable edge cases for this test are.

Member:

Basically the edge cases are when the two definitions of fps below are not equal because of floating-point errors:

tps = stable_cumsum(y_true * weight)[threshold_idxs]
fps = stable_cumsum(weight)[threshold_idxs] - tps            # old: by subtraction
fps = stable_cumsum((1 - y_true) * weight)[threshold_idxs]   # new: by direct accumulation

It is not obvious to me how to construct a simple example that does not work, but maybe with a little bit of thought there is a way to put a simpler one together.

For full details the best is to look at the definition of _binary_clf_curve, especially how the other variables are defined.
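A minimal sketch (not from the PR) of how the two fps definitions above can drift apart through floating-point error with non-integer sample weights; stable_cumsum is approximated here by np.cumsum:

import numpy as np

rng = np.random.RandomState(0)
y_true = rng.randint(0, 2, size=50)
weight = rng.randint(1, 10, size=50) * 0.3      # non-integer weights

tps = np.cumsum(y_true * weight)
fps_subtraction = np.cumsum(weight) - tps             # old: total minus tps
fps_accumulation = np.cumsum((1 - y_true) * weight)   # new: accumulate negatives

print(np.abs(fps_subtraction - fps_accumulation).max())  # typically tiny but non-zero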

@lesteve (Member) commented Sep 27, 2017:

OK I found a simpler example:

import numpy as np
from sklearn.metrics import roc_curve
from sklearn.utils.testing import assert_equal


def test_roc_curve_fpr_tpr_increasing():
    # Ensure that fpr and tpr returned by roc_curve are increasing
    # Construct an edge case with float y_score and sample_weight
    # when some adjacent values of fpr and tpr are the same.
    y_true = [0, 0, 1, 1, 1]
    y_score = [0.1, 0.7, 0.3, 0.4, 0.5]
    sample_weight = np.repeat(0.2, 5)
    fpr, tpr, _ = roc_curve(y_true, y_score,
                            sample_weight=sample_weight)
    assert_equal((np.diff(fpr) < 0).sum(), 0)
    assert_equal((np.diff(tpr) < 0).sum(), 0)

Are you happier with this one @jnothman?

Member:

(Pdb) np.diff(fpr)
array([  5.00000000e-01,   0.00000000e+00,   2.22044605e-16,
        -3.33066907e-16,   5.00000000e-01])

@@ -371,6 +371,20 @@ def test_roc_curve_drop_intermediate():
[1.0, 0.9, 0.7, 0.6, 0.])


def test_roc_curve_fpr_tpr_increasing():
    # Ensure that fpr and tpr returned by roc_curve are increasing
    # Regression test for issue #9786

@jnothman (Member):

It's not really a regression test

y_score = rng.random_sample(size=(n_samples,))
sample_weight = rng.randint(1, 10, size=(n_samples, ))
fpr, tpr, _ = roc_curve(y_true, y_score,
                        sample_weight=sample_weight * 0.2)

@jnothman (Member):

The * 0.2 on sample weight is mysterious and deserves comment.

@qinhanmin2014
Member Author

@jnothman Thanks for your precious time :) Comments addressed.

Do you really think this belongs here? Do we believe this will change user results often enough to caution them here? This makes it seem like their ROC scores will have suddenly changed... I think "bug fix" is more appropriate in any case.

After consideration, I removed the statement about roc_curve (the change in the result is really small, most of the time much less than 1e-7) and moved the statement about roc_auc_score to the bug fix section.

It's not really a regression test

I removed the comment since I now think it is unnecessary. But from my perspective, this is a regression test (it fails on master): on master, we cannot ensure that fpr and tpr are increasing because we use subtraction instead of accumulation. Though the error is really small, it is actually the core reason for the issue (it causes the wrong sort order and the wrong roc_auc_score).

I feel like the fact that elements are sorted for one random sample isn't a very strong assurance. There are edge cases that could be further tested (such as having repeated thresholds), too, but I'm not sure what reasonable edge cases for this test are.

The test is constructed based on @lesteve's suggestion. I think it hits the edge case (float y_score and sample_weight where some adjacent values of fpr and tpr are the same). Now, for adjacent equal values, we no longer get a non-increasing pair (e.g., [a + 1e-10, a - 1e-10]). I have added a comment.

WDYT? Thanks :)

@jnothman
Member

Will merge when @lesteve approves. Might even be reasonable to include in 0.19.1

@jnothman jnothman added this to the 0.19.1 milestone Sep 26, 2017
@qinhanmin2014
Member Author

@lesteve Thanks for the review and the improvement. I have slightly updated the comment in the test.
It seems the test hits at least some of the edge cases (it fails on master). In these cases, y_score and sample_weight are floats and some adjacent values of fpr and tpr are actually the same. In Python, the small error introduced by floating-point calculation seems unavoidable. With this PR, at least we guarantee that fpr and tpr are increasing and avoid a significant error in roc_auc_score. WDYT?

@@ -108,6 +107,11 @@ Decomposition, manifold learning and clustering
- Fixed a bug in :func:`datasets.fetch_kddcup99`, where data were not properly
shuffled. :issue:`9731` by `Nicolas Goix`_.

Model evaluation and meta-estimators

Member:

This is not the right subsection; there should be one with "metrics" in it. Look in an older version's whats_new if you cannot find it in this file.

@@ -108,6 +107,11 @@ Decomposition, manifold learning and clustering
- Fixed a bug in :func:`datasets.fetch_kddcup99`, where data were not properly
shuffled. :issue:`9731` by `Nicolas Goix`_.

Model evaluation and meta-estimators

- Fixed a bug in :func:`metrics.roc_auc_score`, where float calculations sometimes

@lesteve (Member) commented Sep 27, 2017:

I think it can be written a bit better, maybe something like:

Fixed bug due to floating point error in :func:`metrics.roc_auc_score`
with non-integer sample weights. :issue:`9786` by :user:`Hanmin Qin
<qinhanmin2014>`.

@qinhanmin2014
Member Author

@lesteve Thanks a lot for your help :) I totally agree with all your suggestions and have learnt a lot.

@lesteve
Member

lesteve commented Sep 27, 2017

I think this is good to go, merging, thanks a lot!

Another piece of advice while I am at it: choose better names for your branches; you can surely do better than test-feature-3.

@lesteve lesteve merged commit 8fb648a into scikit-learn:master Sep 27, 2017
@lesteve
Member

lesteve commented Sep 27, 2017

Follow-up PRs, @qinhanmin2014, if you are up to it (I would rather have two separate PRs in this case):

  • with this change, reorder=False in the auc function is not used in our code. I think we should deprecate the reorder parameter and potentially (up for debate) add some threshold in case there are some small negative values in np.diff(tpr) or np.diff(fpr)
  • in this PR we spotted a place where check_consistent_length(X, y) was used where check_consistent_length(X, y, sample_weight) should have been called; it would be good to double-check that this error is not present in other places in our codebase (see the sketch below)
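A minimal sketch (not from the PR) of the second point: check_consistent_length only validates the arrays it is actually given, so sample_weight must be passed explicitly for its length to be checked at all.

import numpy as np
from sklearn.utils import check_consistent_length

X = np.zeros((5, 3))
y = np.zeros(5)
sample_weight = np.ones(4)        # deliberately the wrong length

check_consistent_length(X, y)     # passes: sample_weight is never inspected
try:
    check_consistent_length(X, y, sample_weight)
except ValueError as exc:
    print(exc)                    # reports inconsistent numbers of samples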

@jnothman
Member

jnothman commented Sep 27, 2017 via email

@qinhanmin2014
Member Author

@lesteve @jnothman Thanks a lot :) I'll try taking these issues.

@qinhanmin2014
Member Author

@lesteve
I have opened #9870 to address your second concern, since it might need more discussion. Also, it seems that part of it depends on #9828. The PR already has a +1, so if you have time, please have a look at it. Thanks a lot :)

maskani-moh pushed a commit to maskani-moh/scikit-learn that referenced this pull request Nov 15, 2017
* ensure fpr and tpr are increasing in roc_curve with non integer sample weights
* add tests and move roc_auc_score from METRIC_UNDEFINED_BINARY to METRIC_UNDEFINED_MULTICLASS
jwjohnson314 pushed a commit to jwjohnson314/scikit-learn that referenced this pull request Dec 18, 2017
* ensure fpr and tpr are increasing in roc_curve with non integer sample weights
* add tests and move roc_auc_score from METRIC_UNDEFINED_BINARY to METRIC_UNDEFINED_MULTICLASS