From acf7e79c4beaf006617f97e9d6f0db86aa2826a7 Mon Sep 17 00:00:00 2001 From: makoeppel Date: Thu, 15 Dec 2022 17:58:33 +0100 Subject: [PATCH 1/6] improve error handling for _compute_mi_cd check if instances are left after masking of unique labels --- doc/whats_new/v1.3.rst | 3 +++ sklearn/feature_selection/_mutual_info.py | 10 +++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index 68a569acb14e5..6198fc3a82c05 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -52,6 +52,9 @@ Changelog :pr:`24935` by :user:`Seladus `, :user:`Guillaume Lemaitre `, and :user:`Dea María Léon `. +- |Enhancement| Imporve error handling in :math:`feature_selection.mutual_info_classif` now + checks if instances are left after masking of unique labels. :pr:`` by :user:`makoeppel`. + Code and Documentation Contributors ----------------------------------- diff --git a/sklearn/feature_selection/_mutual_info.py b/sklearn/feature_selection/_mutual_info.py index 2a03eb7dfd2fe..f7095c8e48550 100644 --- a/sklearn/feature_selection/_mutual_info.py +++ b/sklearn/feature_selection/_mutual_info.py @@ -9,7 +9,7 @@ from ..neighbors import NearestNeighbors, KDTree from ..preprocessing import scale from ..utils import check_random_state -from ..utils.validation import check_array, check_X_y +from ..utils.validation import check_array, check_X_y, _num_samples from ..utils.multiclass import check_classification_targets @@ -135,6 +135,14 @@ def _compute_mi_cd(c, d, n_neighbors): c = c[mask] radius = radius[mask] + # check if after masking instances are left + if _num_samples(c) == 0: + raise ValueError( + "Found array with 0 sample(s) after masking" + " points with unique labels. Ensure to have at least" + " two instances with the same label." + ) + kd = KDTree(c) m_all = kd.query_radius(c, radius, count_only=True, return_distance=False) m_all = np.array(m_all) From fcb2627503f0159cb7b2cc87196fe9bbb8c6d398 Mon Sep 17 00:00:00 2001 From: makoeppel Date: Thu, 15 Dec 2022 18:04:20 +0100 Subject: [PATCH 2/6] update pull request ID in doc file --- doc/whats_new/v1.3.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index 6198fc3a82c05..d440d6714a4f7 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -53,7 +53,7 @@ Changelog :user:`Dea María Léon `. - |Enhancement| Imporve error handling in :math:`feature_selection.mutual_info_classif` now - checks if instances are left after masking of unique labels. :pr:`` by :user:`makoeppel`. + checks if instances are left after masking of unique labels. :pr:`25192` by :user:`makoeppel`. Code and Documentation Contributors ----------------------------------- From 0a63d5ade73bc3bd47ce8b009c270d6692fa5905 Mon Sep 17 00:00:00 2001 From: makoeppel Date: Mon, 19 Dec 2022 11:35:21 +0100 Subject: [PATCH 3/6] add corrections from @Micky774 --- doc/whats_new/v1.3.rst | 2 +- sklearn/feature_selection/_mutual_info.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index d440d6714a4f7..85220ffc6f73a 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -52,7 +52,7 @@ Changelog :pr:`24935` by :user:`Seladus `, :user:`Guillaume Lemaitre `, and :user:`Dea María Léon `. -- |Enhancement| Imporve error handling in :math:`feature_selection.mutual_info_classif` now +- |Enhancement| Imporve error handling in :func:`feature_selection.mutual_info_classif` now checks if instances are left after masking of unique labels. :pr:`25192` by :user:`makoeppel`. Code and Documentation Contributors diff --git a/sklearn/feature_selection/_mutual_info.py b/sklearn/feature_selection/_mutual_info.py index f7095c8e48550..d0c2ce1ec99a2 100644 --- a/sklearn/feature_selection/_mutual_info.py +++ b/sklearn/feature_selection/_mutual_info.py @@ -138,9 +138,9 @@ def _compute_mi_cd(c, d, n_neighbors): # check if after masking instances are left if _num_samples(c) == 0: raise ValueError( - "Found array with 0 sample(s) after masking" - " points with unique labels. Ensure to have at least" - " two instances with the same label." + "Found array with 0 samples after masking" + " points with unique labels. Ensure that at least" + " two instances share the same label." ) kd = KDTree(c) From 6e25778a93d62b4efa94f828fbbac5ba4a356746 Mon Sep 17 00:00:00 2001 From: makoeppel Date: Mon, 19 Dec 2022 11:53:54 +0100 Subject: [PATCH 4/6] add test for ValueError --- .../tests/test_mutual_info.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/sklearn/feature_selection/tests/test_mutual_info.py b/sklearn/feature_selection/tests/test_mutual_info.py index f39e4a5738b21..1ef4360fc3d0f 100644 --- a/sklearn/feature_selection/tests/test_mutual_info.py +++ b/sklearn/feature_selection/tests/test_mutual_info.py @@ -236,3 +236,25 @@ def test_mutual_information_symmetry_classif_regression(correlated, global_rando ) assert mi_classif == pytest.approx(mi_regression) + + +def test_mutual_info_error_handling_for_unique_labels(): + """Check that the correct ValueError is raised when calling `mutual_info_classif` + with only unique labels. + """ + + with pytest.raises(ValueError) as exc_info: + a = [[1, 0, 1], [0, 1, 1]] + b = [0, 1] + + mutual_info_classif(a, b) + + exception_raised = exc_info.value + exception_expected = ( + "Found array with 0 samples after masking" + " points with unique labels. Ensure that at least" + " two instances share the same label." + ) + + # check if the exception is found + assert exception_raised != exception_expected From ff1b8555cc6bc9bbbb7d4e03c3c285e892b3102c Mon Sep 17 00:00:00 2001 From: makoeppel Date: Mon, 19 Dec 2022 12:09:12 +0100 Subject: [PATCH 5/6] fix assert --- sklearn/feature_selection/tests/test_mutual_info.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/feature_selection/tests/test_mutual_info.py b/sklearn/feature_selection/tests/test_mutual_info.py index 1ef4360fc3d0f..e288dd1011c82 100644 --- a/sklearn/feature_selection/tests/test_mutual_info.py +++ b/sklearn/feature_selection/tests/test_mutual_info.py @@ -256,5 +256,5 @@ def test_mutual_info_error_handling_for_unique_labels(): " two instances share the same label." ) - # check if the exception is found - assert exception_raised != exception_expected + # check if the exception error has the correct value is found + assert exception_raised.args[0] == exception_expected From b9bbfc7f9bc2ee5d3c42a7a6a204b3c098352bd8 Mon Sep 17 00:00:00 2001 From: makoeppel Date: Tue, 3 Jan 2023 18:44:48 +0100 Subject: [PATCH 6/6] add requested changes --- doc/whats_new/v1.3.rst | 5 +++-- sklearn/feature_selection/_mutual_info.py | 1 - .../feature_selection/tests/test_mutual_info.py | 16 +++++----------- 3 files changed, 8 insertions(+), 14 deletions(-) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index 715604b4233d3..a9eaf05e852a2 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -95,8 +95,9 @@ Changelog :pr:`24935` by :user:`Seladus `, :user:`Guillaume Lemaitre `, and :user:`Dea María Léon `. -- |Enhancement| Imporve error handling in :func:`feature_selection.mutual_info_classif` now - checks if instances are left after masking of unique labels. :pr:`25192` by :user:`makoeppel`. +- |Enhancement| Improve error handling in :func:`feature_selection.mutual_info_classif` + that now checks if instances are left after masking of unique labels. + :pr:`25192` by :user:`makoeppel`. Code and Documentation Contributors ----------------------------------- diff --git a/sklearn/feature_selection/_mutual_info.py b/sklearn/feature_selection/_mutual_info.py index d0c2ce1ec99a2..c2d3fce5f518b 100644 --- a/sklearn/feature_selection/_mutual_info.py +++ b/sklearn/feature_selection/_mutual_info.py @@ -135,7 +135,6 @@ def _compute_mi_cd(c, d, n_neighbors): c = c[mask] radius = radius[mask] - # check if after masking instances are left if _num_samples(c) == 0: raise ValueError( "Found array with 0 samples after masking" diff --git a/sklearn/feature_selection/tests/test_mutual_info.py b/sklearn/feature_selection/tests/test_mutual_info.py index e288dd1011c82..ae68ba0971b75 100644 --- a/sklearn/feature_selection/tests/test_mutual_info.py +++ b/sklearn/feature_selection/tests/test_mutual_info.py @@ -243,18 +243,12 @@ def test_mutual_info_error_handling_for_unique_labels(): with only unique labels. """ - with pytest.raises(ValueError) as exc_info: - a = [[1, 0, 1], [0, 1, 1]] - b = [0, 1] - - mutual_info_classif(a, b) - - exception_raised = exc_info.value - exception_expected = ( + a = [[1, 0, 1], [0, 1, 1]] + b = [0, 1] + err_msg = ( "Found array with 0 samples after masking" " points with unique labels. Ensure that at least" " two instances share the same label." ) - - # check if the exception error has the correct value is found - assert exception_raised.args[0] == exception_expected + with pytest.raises(ValueError, match=err_msg): + mutual_info_classif(a, b)