Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 9f44f1f

Browse filesBrowse files
authored
ENH Add Array API compatibility to mean_absolute_error (#27736)
1 parent 28c9f50 commit 9f44f1f
Copy full SHA for 9f44f1f

File tree

4 files changed

+34
-5
lines changed
Filter options

4 files changed

+34
-5
lines changed

‎doc/modules/array_api.rst

Copy file name to clipboardExpand all lines: doc/modules/array_api.rst
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ Metrics
106106
-------
107107

108108
- :func:`sklearn.metrics.accuracy_score`
109+
- :func:`sklearn.metrics.mean_absolute_error`
109110
- :func:`sklearn.metrics.mean_tweedie_deviance`
110111
- :func:`sklearn.metrics.r2_score`
111112
- :func:`sklearn.metrics.zero_one_loss`

‎doc/whats_new/v1.6.rst

Copy file name to clipboardExpand all lines: doc/whats_new/v1.6.rst
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ See :ref:`array_api` for more details.
3535
- :func:`sklearn.metrics.mean_tweedie_deviance` now supports Array API compatible
3636
inputs.
3737
:pr:`28106` by :user:`Thomas Li <lithomas1>`
38+
- :func:`sklearn.metrics.mean_absolute_error` :pr:`27736` by :user:`Edoardo Abati <EdAbati>`.
3839

3940
**Classes:**
4041

‎sklearn/metrics/_regression.py

Copy file name to clipboardExpand all lines: sklearn/metrics/_regression.py
+21-5Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def mean_absolute_error(
189189
190190
Returns
191191
-------
192-
loss : float or ndarray of floats
192+
loss : float or array of floats
193193
If multioutput is 'raw_values', then mean absolute error is returned
194194
for each output separately.
195195
If multioutput is 'uniform_average' or an ndarray of weights, then the
@@ -213,19 +213,35 @@ def mean_absolute_error(
213213
>>> mean_absolute_error(y_true, y_pred, multioutput=[0.3, 0.7])
214214
0.85...
215215
"""
216-
y_type, y_true, y_pred, multioutput = _check_reg_targets(
217-
y_true, y_pred, multioutput
216+
input_arrays = [y_true, y_pred, sample_weight, multioutput]
217+
xp, _ = get_namespace(*input_arrays)
218+
219+
dtype = _find_matching_floating_dtype(y_true, y_pred, sample_weight, xp=xp)
220+
221+
_, y_true, y_pred, multioutput = _check_reg_targets(
222+
y_true, y_pred, multioutput, dtype=dtype, xp=xp
218223
)
219224
check_consistent_length(y_true, y_pred, sample_weight)
220-
output_errors = np.average(np.abs(y_pred - y_true), weights=sample_weight, axis=0)
225+
226+
output_errors = _average(
227+
xp.abs(y_pred - y_true), weights=sample_weight, axis=0, xp=xp
228+
)
221229
if isinstance(multioutput, str):
222230
if multioutput == "raw_values":
223231
return output_errors
224232
elif multioutput == "uniform_average":
225233
# pass None as weights to np.average: uniform mean
226234
multioutput = None
227235

228-
return np.average(output_errors, weights=multioutput)
236+
# Average across the outputs (if needed).
237+
mean_absolute_error = _average(output_errors, weights=multioutput)
238+
239+
# Since `y_pred.ndim <= 2` and `y_true.ndim <= 2`, the second call to _average
240+
# should always return a scalar array that we convert to a Python float to
241+
# consistently return the same eager evaluated value, irrespective of the
242+
# Array API implementation.
243+
assert mean_absolute_error.shape == ()
244+
return float(mean_absolute_error)
229245

230246

231247
@validate_params(

‎sklearn/metrics/tests/test_common.py

Copy file name to clipboardExpand all lines: sklearn/metrics/tests/test_common.py
+11Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1879,6 +1879,13 @@ def check_array_api_regression_metric_multioutput(
18791879
)
18801880

18811881

1882+
def check_array_api_multioutput_regression_metric(
1883+
metric, array_namespace, device, dtype_name
1884+
):
1885+
metric = partial(metric, multioutput="raw_values")
1886+
check_array_api_regression_metric(metric, array_namespace, device, dtype_name)
1887+
1888+
18821889
array_api_metric_checkers = {
18831890
accuracy_score: [
18841891
check_array_api_binary_classification_metric,
@@ -1893,6 +1900,10 @@ def check_array_api_regression_metric_multioutput(
18931900
check_array_api_regression_metric,
18941901
check_array_api_regression_metric_multioutput,
18951902
],
1903+
mean_absolute_error: [
1904+
check_array_api_regression_metric,
1905+
check_array_api_multioutput_regression_metric,
1906+
],
18961907
}
18971908

18981909

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.