diff --git a/doc/modules/array_api.rst b/doc/modules/array_api.rst index a51ee60e47e04..9e1efdd44428e 100644 --- a/doc/modules/array_api.rst +++ b/doc/modules/array_api.rst @@ -91,7 +91,7 @@ Estimators - :class:`decomposition.PCA` (with `svd_solver="full"`, `svd_solver="randomized"` and `power_iteration_normalizer="QR"`) -- :class:`linear_model.Ridge` (with `solver="svd"`) +- :class:`linear_model.Ridge` (with `solver="cholesky"` or `solver="svd"`) - :class:`discriminant_analysis.LinearDiscriminantAnalysis` (with `solver="svd"`) - :class:`preprocessing.KernelCenterer` - :class:`preprocessing.MaxAbsScaler` diff --git a/doc/whats_new/v1.6.rst b/doc/whats_new/v1.6.rst index e1e7a1f01f2f8..a3f354649d633 100644 --- a/doc/whats_new/v1.6.rst +++ b/doc/whats_new/v1.6.rst @@ -55,6 +55,8 @@ See :ref:`array_api` for more details. :class:`model_selection.HalvingRandomSearchCV` now support Array API compatible inputs when their base estimators do. :pr:`27096` by :user:`Tim Head ` and :user:`Olivier Grisel `. +- :class:`linear_model.Ridge` with `solver="cholesky"` now supports Array API + compatible inputs. :pr:`29318` by :user:`Olivier Grisel `. Metadata Routing ---------------- diff --git a/sklearn/linear_model/_ridge.py b/sklearn/linear_model/_ridge.py index c9143389739af..4931862474e75 100644 --- a/sklearn/linear_model/_ridge.py +++ b/sklearn/linear_model/_ridge.py @@ -198,57 +198,93 @@ def _solve_lsqr( return coefs, n_iter -def _solve_cholesky(X, y, alpha): +def _solve_cholesky(X, y, alpha, xp=None): # w = inv(X^t X + alpha*Id) * X.T y + if xp is None: + xp, _ = get_namespace(X, y) n_features = X.shape[1] n_targets = y.shape[1] A = safe_sparse_dot(X.T, X, dense_output=True) Xy = safe_sparse_dot(X.T, y, dense_output=True) - one_alpha = np.array_equal(alpha, len(alpha) * [alpha[0]]) + one_alpha = bool(xp.all(alpha == alpha[0])) + if _is_numpy_namespace(xp): + # A.flat is guaranteed to be a view even when A is Fortran-ordered + # which typically happens when X is a CSR datastructure. + A_flat = A.flat + linalg_solve = partial(linalg.solve, assume_a="pos", overwrite_a=one_alpha) + else: + # XXX: ideally one would like to pass copy=False explicitly to + # xp.reshape, but this is not supported by PyTorch at the time of + # writing. + A_flat = xp.reshape(A, (-1,)) + linalg_solve = xp.linalg.solve if one_alpha: - A.flat[:: n_features + 1] += alpha[0] - return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T + A_flat[:: n_features + 1] += alpha[0] + return linalg_solve(A, Xy).T else: - coefs = np.empty([n_targets, n_features], dtype=X.dtype) - for coef, target, current_alpha in zip(coefs, Xy.T, alpha): - A.flat[:: n_features + 1] += current_alpha - coef[:] = linalg.solve(A, target, assume_a="pos", overwrite_a=False).ravel() - A.flat[:: n_features + 1] -= current_alpha + coefs = xp.empty([n_targets, n_features], dtype=X.dtype, device=device(X)) + for target_idx, current_alpha in enumerate(alpha): + coef = coefs[target_idx, :] + target = Xy[:, target_idx] + A_flat[:: n_features + 1] += current_alpha + coef[:] = _ravel(linalg_solve(A, target)) + A_flat[:: n_features + 1] -= current_alpha return coefs -def _solve_cholesky_kernel(K, y, alpha, sample_weight=None, copy=False): +def _solve_cholesky_kernel(K, y, alpha, sample_weight=None, copy=False, xp=None): # dual_coef = inv(X X^t + alpha*Id) y + if xp is None: + xp, _ = get_namespace(K, y, sample_weight) n_samples = K.shape[0] n_targets = y.shape[1] if copy: K = K.copy() - alpha = np.atleast_1d(alpha) - one_alpha = (alpha == alpha[0]).all() - has_sw = isinstance(sample_weight, np.ndarray) or sample_weight not in [1.0, None] + one_alpha = bool(xp.all(alpha == alpha[0])) + has_sw = sample_weight is not None if has_sw: # Unlike other solvers, we need to support sample_weight directly # because K might be a pre-computed kernel. - sw = np.sqrt(np.atleast_1d(sample_weight)) - y = y * sw[:, np.newaxis] - K *= np.outer(sw, sw) + sw = xp.sqrt(sample_weight) + y = y * sw[:, None] + K *= sw[:, None] @ sw[None, :] # outer product + + if _is_numpy_namespace(xp): + # K.flat is guaranteed to be a view even when K is Fortran-ordered + # which typically happens for the linear kernel X @ X.T with X being a + # CSR datastructure. + K_flat = K.flat + + # Note: we must use overwrite_a=False in order to be able to use the + # fall-back solution below in case a LinAlgError is raised. + linalg_solve = partial(linalg.solve, assume_a="pos", overwrite_a=False) + else: + # XXX: ideally one would like to pass copy=False explicitly to + # xp.reshape, but this is not supported by PyTorch at the time of + # writing. + K_flat = xp.reshape(K, (-1,)) + linalg_solve = xp.linalg.solve if one_alpha: # Only one penalty, we can solve multi-target problems in one time. - K.flat[:: n_samples + 1] += alpha[0] + K_flat[:: n_samples + 1] += alpha[0] try: - # Note: we must use overwrite_a=False in order to be able to - # use the fall-back solution below in case a LinAlgError - # is raised - dual_coef = linalg.solve(K, y, assume_a="pos", overwrite_a=False) + dual_coef = linalg_solve(K, y) except np.linalg.LinAlgError: + # XXX: this exception is numpy specific. If another + # xp.linalg.LinAlgError is raised instead and if the caller is + # _ridge_regression, the caller should catch it and fall back to + # the SVD solution instead. + # + # TODO: find out which call-back we want in case the caller is a + # non-linear kernel ridge instead. warnings.warn( "Singular matrix in solving dual problem. Using " "least-squares solution instead." @@ -257,27 +293,24 @@ def _solve_cholesky_kernel(K, y, alpha, sample_weight=None, copy=False): # K is expensive to compute and store in memory so change it back in # case it was user-given. - K.flat[:: n_samples + 1] -= alpha[0] + K_flat[:: n_samples + 1] -= alpha[0] if has_sw: - dual_coef *= sw[:, np.newaxis] + dual_coef *= sw[:, None] return dual_coef else: # One penalty per target. We need to solve each target separately. - dual_coefs = np.empty([n_targets, n_samples], K.dtype) - - for dual_coef, target, current_alpha in zip(dual_coefs, y.T, alpha): - K.flat[:: n_samples + 1] += current_alpha - - dual_coef[:] = linalg.solve( - K, target, assume_a="pos", overwrite_a=False - ).ravel() - - K.flat[:: n_samples + 1] -= current_alpha + dual_coefs = xp.empty([n_targets, n_samples], dtype=K.dtype, device=device(K)) + for target_idx, current_alpha in enumerate(alpha): + dual_coef = dual_coefs[target_idx, :] + target = y[:, target_idx] + K_flat[:: n_samples + 1] += current_alpha + dual_coef[:] = _ravel(linalg_solve(K, target)) + K_flat[:: n_samples + 1] -= current_alpha if has_sw: - dual_coefs *= sw[np.newaxis, :] + dual_coefs *= sw[None, :] return dual_coefs.T @@ -625,10 +658,10 @@ def _ridge_regression( if is_numpy_namespace and not X_is_sparse: X = np.asarray(X) - if not is_numpy_namespace and solver != "svd": + if not is_numpy_namespace and solver not in ("svd", "cholesky"): raise ValueError( f"Array API dispatch to namespace {xp.__name__} only supports " - f"solver 'svd'. Got '{solver}'." + f"solver 'svd' and 'cholesky'. Got '{solver}'." ) if positive and solver != "lbfgs": @@ -684,16 +717,8 @@ def _ridge_regression( # we implement sample_weight via a simple rescaling. X, y, sample_weight_sqrt = _rescale_data(X, y, sample_weight) - # Some callers of this method might pass alpha as single - # element array which already has been validated. - if alpha is not None and not isinstance(alpha, type(xp.asarray([0.0]))): - alpha = check_scalar( - alpha, - "alpha", - target_type=numbers.Real, - min_val=0.0, - include_boundaries="left", - ) + if alpha is not None: + alpha = xp.asarray(alpha, dtype=X.dtype, device=device_) # There should be either 1 or n_targets penalties alpha = _ravel(xp.asarray(alpha, device=device_, dtype=X.dtype), xp=xp) @@ -739,15 +764,14 @@ def _ridge_regression( if n_features > n_samples: K = safe_sparse_dot(X, X.T, dense_output=True) try: - dual_coef = _solve_cholesky_kernel(K, y, alpha) - + dual_coef = _solve_cholesky_kernel(K, y, alpha, xp=xp) coef = safe_sparse_dot(X.T, dual_coef, dense_output=True).T except linalg.LinAlgError: # use SVD solver if matrix is singular solver = "svd" else: try: - coef = _solve_cholesky(X, y, alpha) + coef = _solve_cholesky(X, y, alpha, xp=xp) except linalg.LinAlgError: # use SVD solver if matrix is singular solver = "svd" @@ -810,8 +834,6 @@ def _ridge_regression( if ravel: coef = _ravel(coef) - coef = xp.asarray(coef) - if return_n_iter and return_intercept: res = coef, n_iter, intercept elif return_intercept: @@ -837,23 +859,25 @@ def resolve_solver(solver, positive, return_intercept, is_sparse, xp): if positive: raise ValueError( "The solvers that support positive fitting do not support " - f"Array API dispatch to namespace {xp.__name__}. Please " - "either disable Array API dispatch, or use a numpy-like " + f"array API dispatch to namespace {xp.__name__}. Please " + "either disable array API dispatch, or use a numpy-like " "namespace, or set `positive=False`." ) - # At the moment, Array API dispatch only supports the "svd" solver. - solver = "svd" - if solver != auto_solver_np: - warnings.warn( - f"Using Array API dispatch to namespace {xp.__name__} with " - f"`solver='auto'` will result in using the solver '{solver}'. " - "The results may differ from those when using a Numpy array, " - f"because in that case the preferred solver would be {auto_solver_np}. " - f"Set `solver='{solver}'` to suppress this warning." + resolved_solver = auto_solver_np + if auto_solver_np != "cholesky": + # The only way to end-up here is if the solver is 'auto' and the + # namespace is not numpy, and ridge_regression was called with + # return_intercept=True. + assert return_intercept + raise ValueError( + "The solvers that support fitting fit intercept without preprocessing " + f"do not support array API dispatch to namespace {xp.__name__}. Please " + "either disable array API dispatch, or use Ridge().fit_transform(X, y) " + "instead of ridge_regression(X, y)." ) - return solver + return resolved_solver def resolve_solver_for_numpy(positive, return_intercept, is_sparse): @@ -872,7 +896,7 @@ def resolve_solver_for_numpy(positive, return_intercept, is_sparse): class _BaseRidge(LinearModel, metaclass=ABCMeta): _parameter_constraints: dict = { - "alpha": [Interval(Real, 0, None, closed="left"), np.ndarray], + "alpha": [Interval(Real, 0, None, closed="left"), "array-like"], "fit_intercept": ["boolean"], "copy_X": ["boolean"], "max_iter": [Interval(Integral, 1, None, closed="left"), None], diff --git a/sklearn/linear_model/tests/test_ridge.py b/sklearn/linear_model/tests/test_ridge.py index 167ce0bac4cba..dc9f0ce008d74 100644 --- a/sklearn/linear_model/tests/test_ridge.py +++ b/sklearn/linear_model/tests/test_ridge.py @@ -1,4 +1,5 @@ import warnings +from functools import partial from itertools import product import numpy as np @@ -1216,24 +1217,65 @@ def _test_tolerance(sparse_container): assert score >= score2 -def check_array_api_attributes(name, estimator, array_namespace, device, dtype_name): +def check_array_api_attributes( + name, + estimator, + array_namespace, + device, + dtype_name, + data_shape="tall", + multi_output=False, + use_sample_weight=False, + rank_deficient=False, +): + rng = np.random.RandomState(0) xp = _array_api_for_tests(array_namespace, device) - X_iris_np = X_iris.astype(dtype_name) - y_iris_np = y_iris.astype(dtype_name) + if data_shape == "tall": + X_np = X_iris.astype(dtype_name) + if rank_deficient: + # Introduce redundant features to make the covariance matrix rank + # deficient + X_np = np.hstack([X_np] * 2) + y_np = y_iris.astype(dtype_name) + else: + n_samples, n_features = 10, 100 + X_np = rng.randn(n_samples, n_features).astype(dtype_name) + w = rng.randn(100).astype(dtype_name) + y_np = X_np @ w + 0.01 * rng.randn(n_samples).astype(dtype_name) + if rank_deficient: + # Duplicated some rows to make the kernel matrix rank deficient + X_np = np.vstack([X_np] * 2) + y_np = np.hstack([y_np] * 2) + + if multi_output: + y_np = np.column_stack([y_np, y_np]) + + # Set different alphas for each target to increase test coverage. + estimator = clone(estimator) + estimator.set_params(alpha=[1e-6, 1e6]) + + if use_sample_weight: + sample_weight_np = rng.rand(X_np.shape[0]).astype(dtype_name) + sample_weight_xp = xp.asarray(sample_weight_np, device=device) + else: + sample_weight_np = sample_weight_xp = None - X_iris_xp = xp.asarray(X_iris_np, device=device) - y_iris_xp = xp.asarray(y_iris_np, device=device) + X_xp = xp.asarray(X_np, device=device) + y_xp = xp.asarray(y_np, device=device) - estimator.fit(X_iris_np, y_iris_np) + estimator.fit(X_np, y_np, sample_weight_np) coef_np = estimator.coef_ intercept_np = estimator.intercept_ with config_context(array_api_dispatch=True): - estimator_xp = clone(estimator).fit(X_iris_xp, y_iris_xp) + estimator_xp = clone(estimator).fit(X_xp, y_xp, sample_weight=sample_weight_xp) coef_xp = estimator_xp.coef_ - assert coef_xp.shape == (4,) - assert coef_xp.dtype == X_iris_xp.dtype + if multi_output: + assert coef_xp.shape == (2, X_xp.shape[1]) + else: + assert coef_xp.shape == (X_xp.shape[1],) + assert coef_xp.dtype == X_xp.dtype assert_allclose( _convert_to_numpy(coef_xp, xp=xp), @@ -1241,8 +1283,11 @@ def check_array_api_attributes(name, estimator, array_namespace, device, dtype_n atol=_atol_for_type(dtype_name), ) intercept_xp = estimator_xp.intercept_ - assert intercept_xp.shape == () - assert intercept_xp.dtype == X_iris_xp.dtype + if multi_output: + assert intercept_xp.shape == (2,) + else: + assert intercept_xp.shape == () + assert intercept_xp.dtype == X_xp.dtype assert_allclose( _convert_to_numpy(intercept_xp, xp=xp), @@ -1256,12 +1301,29 @@ def check_array_api_attributes(name, estimator, array_namespace, device, dtype_n ) @pytest.mark.parametrize( "check", - [check_array_api_input_and_values, check_array_api_attributes], + [ + check_array_api_input_and_values, + partial(check_array_api_attributes, data_shape="tall", use_sample_weight=True), + partial( + check_array_api_attributes, + data_shape="tall", + multi_output=True, + ), + partial(check_array_api_attributes, data_shape="tall", rank_deficient=True), + partial(check_array_api_attributes, data_shape="wide"), + partial( + check_array_api_attributes, + data_shape="wide", + multi_output=True, + use_sample_weight=True, + ), + partial(check_array_api_attributes, data_shape="wide", rank_deficient=True), + ], ids=_get_check_estimator_ids, ) @pytest.mark.parametrize( "estimator", - [Ridge(solver="svd")], + [Ridge(solver="svd"), Ridge(solver="cholesky"), Ridge(solver="cholesky", alpha=0)], ids=_get_check_estimator_ids, ) def test_ridge_array_api_compliance( @@ -1281,11 +1343,11 @@ def test_array_api_error_and_warnings_for_solver_parameter(array_namespace): y_iris_xp = xp.asarray(y_iris[:5]) available_solvers = Ridge._parameter_constraints["solver"][0].options - for solver in available_solvers - {"auto", "svd"}: + for solver in available_solvers - {"auto", "svd", "cholesky"}: ridge = Ridge(solver=solver, positive=solver == "lbfgs") expected_msg = ( f"Array API dispatch to namespace {xp.__name__} only supports " - f"solver 'svd'. Got '{solver}'." + f"solver 'svd' and 'cholesky'. Got '{solver}'." ) with pytest.raises(ValueError, match=expected_msg): @@ -1295,8 +1357,8 @@ def test_array_api_error_and_warnings_for_solver_parameter(array_namespace): ridge = Ridge(solver="auto", positive=True) expected_msg = ( "The solvers that support positive fitting do not support " - f"Array API dispatch to namespace {xp.__name__}. Please " - "either disable Array API dispatch, or use a numpy-like " + f"array API dispatch to namespace {xp.__name__}. Please " + "either disable array API dispatch, or use a numpy-like " "namespace, or set `positive=False`." ) @@ -1304,16 +1366,18 @@ def test_array_api_error_and_warnings_for_solver_parameter(array_namespace): with config_context(array_api_dispatch=True): ridge.fit(X_iris_xp, y_iris_xp) - ridge = Ridge() - expected_msg = ( - f"Using Array API dispatch to namespace {xp.__name__} with `solver='auto'` " - "will result in using the solver 'svd'. The results may differ from those " - "when using a Numpy array, because in that case the preferred solver would " - "be cholesky. Set `solver='svd'` to suppress this warning." - ) - with pytest.warns(UserWarning, match=expected_msg): - with config_context(array_api_dispatch=True): - ridge.fit(X_iris_xp, y_iris_xp) + with config_context(array_api_dispatch=True): + ridge_regression( + X_iris_xp, y_iris_xp, alpha=1.0, return_intercept=False + ) # no error + + expected_msg = ( + "The solvers that support fitting fit intercept without preprocessing " + f"do not support array API dispatch to namespace {xp.__name__}." + ) + + with pytest.raises(ValueError, match=expected_msg): + ridge_regression(X_iris_xp, y_iris_xp, alpha=1.0, return_intercept=True) @pytest.mark.parametrize("array_namespace", sorted(_NUMPY_NAMESPACE_NAMES)) @@ -1323,18 +1387,12 @@ def test_array_api_numpy_namespace_no_warning(array_namespace): X_iris_xp = xp.asarray(X_iris[:5]) y_iris_xp = xp.asarray(y_iris[:5]) - ridge = Ridge() - expected_msg = ( - "Results might be different than when Array API dispatch is " - "disabled, or when a numpy-like namespace is used" - ) - with warnings.catch_warnings(): - warnings.filterwarnings("error", message=expected_msg, category=UserWarning) + warnings.filterwarnings("error", category=UserWarning) with config_context(array_api_dispatch=True): - ridge.fit(X_iris_xp, y_iris_xp) + Ridge().fit(X_iris_xp, y_iris_xp) - # All numpy namespaces are compatible with all solver, in particular + # All NumPy namespaces are compatible with all solver, in particular # solvers that support `positive=True` (like 'lbfgs') should work. with config_context(array_api_dispatch=True): Ridge(solver="auto", positive=True).fit(X_iris_xp, y_iris_xp) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index c222e26fcc82c..dfab8aacb3ba4 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -427,7 +427,16 @@ def reshape(self, x, shape, *, copy=None): if copy is True: x = x.copy() - return numpy.reshape(x, shape) + + output = numpy.reshape(x, shape) + if copy is False and not numpy.shares_memory(x, output): + # See the following ref in the spec for the meaning of copy=False: + # https://data-apis.org/array-api/latest/API_specification/generated/array_api.reshape.html + raise ValueError( + f"reshape with copy=False is not compatible with shape {shape} " + "for the memory layout of the input array." + ) + return output def isdtype(self, dtype, kind): return isdtype(dtype, kind, xp=self) @@ -560,20 +569,50 @@ def get_namespace(*arrays, remove_none=True, remove_types=(str,), xp=None): return namespace, is_array_api_compliant -def get_namespace_and_device(*array_list, remove_none=True, remove_types=(str,)): - """Combination into one single function of `get_namespace` and `device`.""" - array_list = _remove_non_arrays( - *array_list, remove_none=remove_none, remove_types=remove_types +def get_namespace_and_device(*arrays, remove_none=True, remove_types=(str,), xp=None): + """Combination into one single function of `get_namespace` and `device`. + + Parameters + ---------- + *arrays : array objects + Array objects. + + remove_none : bool, default=True + Whether to ignore None objects passed in arrays. + + remove_types : tuple or list, default=(str,) + Types to ignore in the arrays. + + xp : module, default=None + Precomputed array namespace module. When passed, typically from a caller + that has already performed inspection of its own inputs, skips array + namespace inspection. + + Returns + ------- + namespace : module + Namespace shared by array objects. If any of the `arrays` are not arrays, + the namespace defaults to NumPy. + + is_array_api_compliant : bool + True if the arrays are containers that implement the Array API spec. + Always False when array_api_dispatch=False. + + device : device + `device` object (see the "Device Support" section of the array API spec). + """ + arrays = _remove_non_arrays( + *arrays, remove_none=remove_none, remove_types=remove_types ) skip_remove_kwargs = dict(remove_none=False, remove_types=[]) - xp, is_array_api = get_namespace(*array_list, **skip_remove_kwargs) + xp, is_array_api = get_namespace(*arrays, xp=xp, **skip_remove_kwargs) if is_array_api: return ( xp, is_array_api, - device(*array_list, **skip_remove_kwargs), + device(*arrays, **skip_remove_kwargs), ) else: return xp, False, None @@ -824,6 +863,8 @@ def _estimator_with_converted_arrays(estimator, converter): def _atol_for_type(dtype): """Return the absolute tolerance for a given numpy dtype.""" + if dtype is None: + dtype = numpy.float64 return numpy.finfo(dtype).eps * 100 diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index 25913e7f54846..0ce998be7fc64 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -454,6 +454,13 @@ def test_reshape_behavior(): with pytest.raises(TypeError, match="shape must be a tuple"): xp.reshape(X, -1) + X_fortran = numpy.asfortranarray(X) + with pytest.raises(ValueError, match="reshape with copy=False is not compatible"): + xp.reshape(X_fortran, (-1,), copy=False) + + X_copy = xp.reshape(X_fortran, (-1,)) + assert X_copy.base is not X_fortran.base + @pytest.mark.parametrize("wrapper", [_ArrayAPIWrapper, _NumPyAPIWrapper]) def test_get_namespace_array_api_isdtype(wrapper): diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 228fbe76a25e1..6d1c90bcc0dec 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -17,7 +17,12 @@ from .. import get_config as _get_config from ..exceptions import DataConversionWarning, NotFittedError, PositiveSpectrumWarning -from ..utils._array_api import _asarray_with_order, _is_numpy_namespace, get_namespace +from ..utils._array_api import ( + _asarray_with_order, + _is_numpy_namespace, + get_namespace, + get_namespace_and_device, +) from ..utils.fixes import ComplexWarning, _preserve_dia_indices_dtype from ._isfinite import FiniteStatus, cy_isfinite from .fixes import _object_dtype_isnan @@ -2022,18 +2027,19 @@ def _check_sample_weight( sample_weight : ndarray of shape (n_samples,) Validated sample weight. It is guaranteed to be "C" contiguous. """ + xp, _, device_ = get_namespace_and_device(X) n_samples = _num_samples(X) - if dtype is not None and dtype not in [np.float32, np.float64]: - dtype = np.float64 + if dtype is not None and dtype not in [xp.float32, xp.float64]: + dtype = xp.float64 if sample_weight is None: - sample_weight = np.ones(n_samples, dtype=dtype) + sample_weight = xp.ones(n_samples, dtype=dtype, device=device_) elif isinstance(sample_weight, numbers.Number): - sample_weight = np.full(n_samples, sample_weight, dtype=dtype) + sample_weight = xp.full(n_samples, sample_weight, dtype=dtype, device=device_) else: if dtype is None: - dtype = [np.float64, np.float32] + dtype = [xp.float64, xp.float32] sample_weight = check_array( sample_weight, accept_sparse=False,