Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 35a905f

Browse filesBrowse files
authored
Merge pull request #13078 from eric-wieser/simplify-from-roots
MAINT: deduplicate fromroots in np.polynomial
2 parents d34d1f3 + 1bb279a commit 35a905f
Copy full SHA for 35a905f

File tree

7 files changed

+36
-90
lines changed
Filter options

7 files changed

+36
-90
lines changed

‎numpy/polynomial/chebyshev.py

Copy file name to clipboardExpand all lines: numpy/polynomial/chebyshev.py
+1-15Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -541,21 +541,7 @@ def chebfromroots(roots):
541541
array([1.5+0.j, 0. +0.j, 0.5+0.j])
542542
543543
"""
544-
if len(roots) == 0:
545-
return np.ones(1)
546-
else:
547-
[roots] = pu.as_series([roots], trim=False)
548-
roots.sort()
549-
p = [chebline(-r, 1) for r in roots]
550-
n = len(p)
551-
while n > 1:
552-
m, r = divmod(n, 2)
553-
tmp = [chebmul(p[i], p[i+m]) for i in range(m)]
554-
if r:
555-
tmp[0] = chebmul(tmp[0], p[-1])
556-
p = tmp
557-
n = m
558-
return p[0]
544+
return pu._fromroots(chebline, chebmul, roots)
559545

560546

561547
def chebadd(c1, c2):

‎numpy/polynomial/hermite.py

Copy file name to clipboardExpand all lines: numpy/polynomial/hermite.py
+1-15Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -286,21 +286,7 @@ def hermfromroots(roots):
286286
array([0.+0.j, 0.+0.j])
287287
288288
"""
289-
if len(roots) == 0:
290-
return np.ones(1)
291-
else:
292-
[roots] = pu.as_series([roots], trim=False)
293-
roots.sort()
294-
p = [hermline(-r, 1) for r in roots]
295-
n = len(p)
296-
while n > 1:
297-
m, r = divmod(n, 2)
298-
tmp = [hermmul(p[i], p[i+m]) for i in range(m)]
299-
if r:
300-
tmp[0] = hermmul(tmp[0], p[-1])
301-
p = tmp
302-
n = m
303-
return p[0]
289+
return pu._fromroots(hermline, hermmul, roots)
304290

305291

306292
def hermadd(c1, c2):

‎numpy/polynomial/hermite_e.py

Copy file name to clipboardExpand all lines: numpy/polynomial/hermite_e.py
+1-15Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -287,21 +287,7 @@ def hermefromroots(roots):
287287
array([0.+0.j, 0.+0.j])
288288
289289
"""
290-
if len(roots) == 0:
291-
return np.ones(1)
292-
else:
293-
[roots] = pu.as_series([roots], trim=False)
294-
roots.sort()
295-
p = [hermeline(-r, 1) for r in roots]
296-
n = len(p)
297-
while n > 1:
298-
m, r = divmod(n, 2)
299-
tmp = [hermemul(p[i], p[i+m]) for i in range(m)]
300-
if r:
301-
tmp[0] = hermemul(tmp[0], p[-1])
302-
p = tmp
303-
n = m
304-
return p[0]
290+
return pu._fromroots(hermeline, hermemul, roots)
305291

306292

307293
def hermeadd(c1, c2):

‎numpy/polynomial/laguerre.py

Copy file name to clipboardExpand all lines: numpy/polynomial/laguerre.py
+1-15Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -283,21 +283,7 @@ def lagfromroots(roots):
283283
array([0.+0.j, 0.+0.j])
284284
285285
"""
286-
if len(roots) == 0:
287-
return np.ones(1)
288-
else:
289-
[roots] = pu.as_series([roots], trim=False)
290-
roots.sort()
291-
p = [lagline(-r, 1) for r in roots]
292-
n = len(p)
293-
while n > 1:
294-
m, r = divmod(n, 2)
295-
tmp = [lagmul(p[i], p[i+m]) for i in range(m)]
296-
if r:
297-
tmp[0] = lagmul(tmp[0], p[-1])
298-
p = tmp
299-
n = m
300-
return p[0]
286+
return pu._fromroots(lagline, lagmul, roots)
301287

302288

303289
def lagadd(c1, c2):

‎numpy/polynomial/legendre.py

Copy file name to clipboardExpand all lines: numpy/polynomial/legendre.py
+1-15Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -314,21 +314,7 @@ def legfromroots(roots):
314314
array([ 1.33333333+0.j, 0.00000000+0.j, 0.66666667+0.j]) # may vary
315315
316316
"""
317-
if len(roots) == 0:
318-
return np.ones(1)
319-
else:
320-
[roots] = pu.as_series([roots], trim=False)
321-
roots.sort()
322-
p = [legline(-r, 1) for r in roots]
323-
n = len(p)
324-
while n > 1:
325-
m, r = divmod(n, 2)
326-
tmp = [legmul(p[i], p[i+m]) for i in range(m)]
327-
if r:
328-
tmp[0] = legmul(tmp[0], p[-1])
329-
p = tmp
330-
n = m
331-
return p[0]
317+
return pu._fromroots(legline, legmul, roots)
332318

333319

334320
def legadd(c1, c2):

‎numpy/polynomial/polynomial.py

Copy file name to clipboardExpand all lines: numpy/polynomial/polynomial.py
+1-15Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -188,21 +188,7 @@ def polyfromroots(roots):
188188
array([1.+0.j, 0.+0.j, 1.+0.j])
189189
190190
"""
191-
if len(roots) == 0:
192-
return np.ones(1)
193-
else:
194-
[roots] = pu.as_series([roots], trim=False)
195-
roots.sort()
196-
p = [polyline(-r, 1) for r in roots]
197-
n = len(p)
198-
while n > 1:
199-
m, r = divmod(n, 2)
200-
tmp = [polymul(p[i], p[i+m]) for i in range(m)]
201-
if r:
202-
tmp[0] = polymul(tmp[0], p[-1])
203-
p = tmp
204-
n = m
205-
return p[0]
191+
return pu._fromroots(polyline, polymul, roots)
206192

207193

208194
def polyadd(c1, c2):

‎numpy/polynomial/polyutils.py

Copy file name to clipboardExpand all lines: numpy/polynomial/polyutils.py
+30Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,3 +459,33 @@ def _vander3d(vander_f, x, y, z, deg):
459459
vz = vander_f(z, degz)
460460
v = vx[..., None, None]*vy[..., None,:, None]*vz[..., None, None,:]
461461
return v.reshape(v.shape[:-3] + (-1,))
462+
463+
464+
def _fromroots(line_f, mul_f, roots):
465+
"""
466+
Helper function used to implement the ``<type>fromroots`` functions.
467+
468+
Parameters
469+
----------
470+
line_f : function(float, float) -> ndarray
471+
The ``<type>line`` function, such as ``polyline``
472+
mul_f : function(array_like, array_like) -> ndarray
473+
The ``<type>mul`` function, such as ``polymul``
474+
roots :
475+
See the ``<type>fromroots`` functions for more detail
476+
"""
477+
if len(roots) == 0:
478+
return np.ones(1)
479+
else:
480+
[roots] = as_series([roots], trim=False)
481+
roots.sort()
482+
p = [line_f(-r, 1) for r in roots]
483+
n = len(p)
484+
while n > 1:
485+
m, r = divmod(n, 2)
486+
tmp = [mul_f(p[i], p[i+m]) for i in range(m)]
487+
if r:
488+
tmp[0] = mul_f(tmp[0], p[-1])
489+
p = tmp
490+
n = m
491+
return p[0]

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.