Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit dc6c01c

Browse filesBrowse files
authored
array API support for mean_absolute_percentage_error (#29300)
1 parent 1813b4a commit dc6c01c
Copy full SHA for dc6c01c

File tree

4 files changed

+21
-6
lines changed
Filter options

4 files changed

+21
-6
lines changed

‎doc/modules/array_api.rst

Copy file name to clipboardExpand all lines: doc/modules/array_api.rst
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ Metrics
117117
- :func:`sklearn.metrics.d2_tweedie_score`
118118
- :func:`sklearn.metrics.max_error`
119119
- :func:`sklearn.metrics.mean_absolute_error`
120+
- :func:`sklearn.metrics.mean_absolute_percentage_error`
120121
- :func:`sklearn.metrics.mean_gamma_deviance`
121122
- :func:`sklearn.metrics.mean_squared_error`
122123
- :func:`sklearn.metrics.mean_tweedie_deviance`

‎doc/whats_new/v1.6.rst

Copy file name to clipboardExpand all lines: doc/whats_new/v1.6.rst
+2-1Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ See :ref:`array_api` for more details.
3737
- :func:`sklearn.metrics.max_error` :pr:`29212` by :user:`Edoardo Abati <EdAbati>`;
3838
- :func:`sklearn.metrics.mean_absolute_error` :pr:`27736` by :user:`Edoardo Abati <EdAbati>`
3939
and :pr:`29143` by :user:`Tialo <Tialo>` and :user:`Loïc Estève <lesteve>`;
40-
- :func:`sklearn.metrics.mean_gamma_deviance` :pr:`29239` by :usser:`Emily Chen <EmilyXinyi>`;
40+
- :func:`sklearn.metrics.mean_absolute_percentage_error` :pr:`29300` by :user:`Emily Chen <EmilyXinyi>`;
41+
- :func:`sklearn.metrics.mean_gamma_deviance` :pr:`29239` by :user:`Emily Chen <EmilyXinyi>`;
4142
- :func:`sklearn.metrics.mean_squared_error` :pr:`29142` by :user:`Yaroslav Korobko <Tialo>`;
4243
- :func:`sklearn.metrics.mean_tweedie_deviance` :pr:`28106` by :user:`Thomas Li <lithomas1>`;
4344
- :func:`sklearn.metrics.pairwise.additive_chi2_kernel` :pr:`29144` by :user:`Yaroslav Korobko <Tialo>`;

‎sklearn/metrics/_regression.py

Copy file name to clipboardExpand all lines: sklearn/metrics/_regression.py
+14-5Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -395,21 +395,30 @@ def mean_absolute_percentage_error(
395395
>>> mean_absolute_percentage_error(y_true, y_pred)
396396
112589990684262.48
397397
"""
398+
input_arrays = [y_true, y_pred, sample_weight, multioutput]
399+
xp, _ = get_namespace(*input_arrays)
400+
dtype = _find_matching_floating_dtype(y_true, y_pred, sample_weight, xp=xp)
401+
398402
y_type, y_true, y_pred, multioutput = _check_reg_targets(
399403
y_true, y_pred, multioutput
400404
)
401405
check_consistent_length(y_true, y_pred, sample_weight)
402-
epsilon = np.finfo(np.float64).eps
403-
mape = np.abs(y_pred - y_true) / np.maximum(np.abs(y_true), epsilon)
404-
output_errors = np.average(mape, weights=sample_weight, axis=0)
406+
epsilon = xp.asarray(xp.finfo(xp.float64).eps, dtype=dtype)
407+
y_true_abs = xp.asarray(xp.abs(y_true), dtype=dtype)
408+
mape = xp.asarray(xp.abs(y_pred - y_true), dtype=dtype) / xp.maximum(
409+
y_true_abs, epsilon
410+
)
411+
output_errors = _average(mape, weights=sample_weight, axis=0)
405412
if isinstance(multioutput, str):
406413
if multioutput == "raw_values":
407414
return output_errors
408415
elif multioutput == "uniform_average":
409-
# pass None as weights to np.average: uniform mean
416+
# pass None as weights to _average: uniform mean
410417
multioutput = None
411418

412-
return np.average(output_errors, weights=multioutput)
419+
mean_absolute_percentage_error = _average(output_errors, weights=multioutput)
420+
assert mean_absolute_percentage_error.shape == ()
421+
return float(mean_absolute_percentage_error)
413422

414423

415424
@validate_params(

‎sklearn/metrics/tests/test_common.py

Copy file name to clipboardExpand all lines: sklearn/metrics/tests/test_common.py
+4Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2016,6 +2016,10 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
20162016
additive_chi2_kernel: [check_array_api_metric_pairwise],
20172017
mean_gamma_deviance: [check_array_api_regression_metric],
20182018
max_error: [check_array_api_regression_metric],
2019+
mean_absolute_percentage_error: [
2020+
check_array_api_regression_metric,
2021+
check_array_api_regression_metric_multioutput,
2022+
],
20192023
chi2_kernel: [check_array_api_metric_pairwise],
20202024
cosine_distances: [check_array_api_metric_pairwise],
20212025
euclidean_distances: [check_array_api_metric_pairwise],

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.