Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 864ea19

Browse filesBrowse files
committed
MNT use np.linspace backport in SplineTransformer
1 parent 831cba6 commit 864ea19
Copy full SHA for 864ea19

File tree

Expand file treeCollapse file tree

1 file changed

+9
-26
lines changed
Filter options
Expand file treeCollapse file tree

1 file changed

+9
-26
lines changed

‎sklearn/preprocessing/_data.py

Copy file name to clipboardExpand all lines: sklearn/preprocessing/_data.py
+9-26Lines changed: 9 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from ..utils import assert_all_finite, check_array
2626
from ..utils.extmath import row_norms
2727
from ..utils.extmath import _incremental_mean_and_var
28-
from ..utils.fixes import np_version, parse_version
28+
from ..utils.fixes import linspace
2929
from ..utils.sparsefuncs_fast import (inplace_csr_row_normalize_l1,
3030
inplace_csr_row_normalize_l2)
3131
from ..utils.sparsefuncs import (inplace_column_scale,
@@ -1916,19 +1916,8 @@ def get_base_knot_positions(X, n_knots=10, knots='uniform'):
19161916
# knots == 'uniform':
19171917
x_min = np.amin(X, axis=0)
19181918
x_max = np.amax(X, axis=0)
1919-
1920-
# FIXME: to be removed if min version becomes numpy 1.16
1921-
# start and stop arrays for linspace logspace and geomspace
1922-
# https://github.com/numpy/numpy/pull/12388
1923-
if np_version < parse_version('1.16'):
1924-
n_features = X.shape[1]
1925-
knots = np.empty((n_knots, n_features))
1926-
for j in range(n_features):
1927-
knots[:, j] = np.linspace(start=x_min[j], stop=x_max[j],
1928-
num=n_knots, endpoint=True)
1929-
else:
1930-
knots = np.linspace(start=x_min, stop=x_max, num=n_knots,
1931-
endpoint=True)
1919+
knots = linspace(start=x_min, stop=x_max, num=n_knots,
1920+
endpoint=True)
19321921

19331922
return knots
19341923

@@ -2023,18 +2012,12 @@ def fit(self, X, y=None):
20232012
# Instead, we reuse the distance of the 2 fist/last knots.
20242013
dist_min = base_knots[1] - base_knots[0]
20252014
dist_max = base_knots[-1] - base_knots[-2]
2026-
# FIXME: to be removed if min version becomes numpy 1.16
2027-
# start and stop arrays for linspace logspace and geomspace
2028-
# https://github.com/numpy/numpy/pull/12388
2029-
if np_version < parse_version('1.16'):
2030-
2031-
else:
2032-
knots = np.r_[np.linspace(base_knots[0] - degree * dist_min,
2033-
base_knots[0] - dist_min, num=degree),
2034-
base_knots,
2035-
np.linspace(base_knots[-1] + dist_max,
2036-
base_knots[-1] + degree * dist_max,
2037-
num=degree)]
2015+
knots = np.r_[linspace(base_knots[0] - degree * dist_min,
2016+
base_knots[0] - dist_min, num=degree),
2017+
base_knots,
2018+
np.linspace(base_knots[-1] + dist_max,
2019+
base_knots[-1] + degree * dist_max,
2020+
num=degree)]
20382021

20392022
# With a diagonal coefficient matrix, we get back the spline basis
20402023
# elements, i.e. the design matrix of the spline.

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.