Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 33860e6

Browse filesBrowse files
Rewrite complex log1p to improve precision.
Closes gh-22609
1 parent 2f3da0d commit 33860e6
Copy full SHA for 33860e6

File tree

7 files changed

+535
-10
lines changed
Filter options

7 files changed

+535
-10
lines changed

‎numpy/_core/meson.build

Copy file name to clipboardExpand all lines: numpy/_core/meson.build
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1153,6 +1153,7 @@ src_umath = umath_gen_headers + [
11531153
src_file.process('src/umath/scalarmath.c.src'),
11541154
'src/umath/ufunc_object.c',
11551155
'src/umath/umathmodule.c',
1156+
src_file.process('src/umath/clog1p_wrappers.cpp.src'),
11561157
'src/umath/special_integer_comparisons.cpp',
11571158
'src/umath/string_ufuncs.cpp',
11581159
'src/umath/stringdtype_ufuncs.cpp',
+32Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#include <Python.h>
2+
#define NPY_NO_DEPRECATED_API NPY_API_VERSION
3+
#include "numpy/npy_math.h"
4+
#include "numpy/ndarraytypes.h"
5+
#include <cmath>
6+
#include <complex>
7+
#include "log1p_complex.h"
8+
9+
extern "C" {
10+
11+
/**begin repeat
12+
* #c_fp_type = float, double, long double#
13+
* #np_cplx_type = npy_cfloat,npy_cdouble,npy_clongdouble#
14+
* #c = f, , l#
15+
*/
16+
17+
/*
18+
* C wrapper for log1p_complex(z). This function is to be used only
19+
* when the input is close to the unit circle centered at -1+0j.
20+
*/
21+
NPY_NO_EXPORT void
22+
clog1p@c@(@np_cplx_type@ *x, @np_cplx_type@ *r)
23+
{
24+
std::complex<@c_fp_type@> z{npy_creal@c@(*x), npy_cimag@c@(*x)};
25+
auto w = log1p_complex::log1p_complex(z);
26+
npy_csetreal@c@(r, w.real());
27+
npy_csetimag@c@(r, w.imag());
28+
}
29+
30+
/**end repeat**/
31+
32+
} // extern "C"
+19Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#ifndef CLOG1P_WRAPPERS_H
2+
#define CLOG1P_WRAPPERS_H
3+
4+
// This header is to be included in umathmodule.c,
5+
// so it will be processed by a C compiler.
6+
//
7+
// This file assumes that the numpy header files have
8+
// already been included.
9+
10+
NPY_NO_EXPORT void
11+
clog1pf(npy_cfloat *z, npy_cfloat *r);
12+
13+
NPY_NO_EXPORT void
14+
clog1p(npy_cdouble *z, npy_cdouble *r);
15+
16+
NPY_NO_EXPORT void
17+
clog1pl(npy_clongdouble *z, npy_clongdouble *r);
18+
19+
#endif // CLOG1P_WRAPPERS_H

‎numpy/_core/src/umath/funcs.inc.src

Copy file name to clipboardExpand all lines: numpy/_core/src/umath/funcs.inc.src
+25-4Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -310,10 +310,31 @@ nc_log@c@(@ctype@ *x, @ctype@ *r)
310310
static void
311311
nc_log1p@c@(@ctype@ *x, @ctype@ *r)
312312
{
313-
@ftype@ l = npy_hypot@c@(npy_creal@c@(*x) + 1,npy_cimag@c@(*x));
314-
npy_csetimag@c@(r, npy_atan2@c@(npy_cimag@c@(*x), npy_creal@c@(*x) + 1));
315-
npy_csetreal@c@(r, npy_log@c@(l));
316-
return;
313+
@ftype@ xr = npy_creal@c@(*x);
314+
@ftype@ xi = npy_cimag@c@(*x);
315+
@ctype@ xp1 = npy_cpack@c@(xr + 1, xi);
316+
@ftype@ delta = 0.001;
317+
/*
318+
* Should we use the high precision calculation? First check if
319+
* x is within a bounding box, and then check if x is close to the
320+
* unit circle centered at -1+0j. The bounding box is checked first
321+
* to avoid overflow that might occur in npy_cabs@c@(xp1) when xr or
322+
* xi is very close to the maximum value of @ftype@.
323+
*/
324+
if (-2 - delta < xr && xr < delta && -1 - delta < xi && xi < 1 + delta) {
325+
@ftype@ m = npy_cabs@c@(xp1);
326+
if (1 - delta < m && m < 1 + delta) {
327+
/* Use the high precision calculation in this region. */
328+
clog1p@c@(x, r);
329+
return;
330+
}
331+
}
332+
/*
333+
* Use clog(1 + x).
334+
* This branch of the code will also be used if either
335+
* xr or xi is nan.
336+
*/
337+
nc_log@c@(&xp1, r);
317338
}
318339

319340
static void

‎numpy/_core/src/umath/log1p_complex.h

Copy file name to clipboard
+260Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
#ifndef LOG1P_COMPLEX_H
2+
#define LOG1P_COMPLEX_H
3+
4+
#include <Python.h>
5+
#include "numpy/ndarraytypes.h"
6+
#include "numpy/npy_math.h"
7+
8+
#include <cmath>
9+
#include <complex>
10+
#include <limits>
11+
12+
// For memcpy
13+
#include <cstring>
14+
15+
16+
//
17+
// Trivial C++ wrappers for several npy_* functions.
18+
//
19+
20+
#define CPP_WRAP1(name) \
21+
inline float name(const float x) \
22+
{ \
23+
return ::npy_ ## name ## f(x); \
24+
} \
25+
inline double name(const double x) \
26+
{ \
27+
return ::npy_ ## name(x); \
28+
} \
29+
inline long double name(const long double x) \
30+
{ \
31+
return ::npy_ ## name ## l(x); \
32+
} \
33+
34+
// CPP_WRAP2(name) defines three overloads of `name` taking two arguments
// of the same type (float, double, long double).  Each forwards to the
// global function npy_<name>f, npy_<name> or npy_<name>l respectively.
#define CPP_WRAP2(name)                                             \
    inline float name(const float x, const float y)                 \
    { return ::npy_ ## name ## f(x, y); }                           \
    inline double name(const double x, const double y)              \
    { return ::npy_ ## name(x, y); }                                \
    inline long double name(const long double x,                    \
                            const long double y)                    \
    { return ::npy_ ## name ## l(x, y); }
50+
51+
52+
namespace npy {
53+
54+
CPP_WRAP1(fabs)
55+
CPP_WRAP1(log)
56+
CPP_WRAP1(log1p)
57+
CPP_WRAP2(atan2)
58+
CPP_WRAP2(hypot)
59+
60+
}
61+
62+
namespace log1p_complex
63+
{
64+
65+
//
// A pair of values of type T.  The functions below use it to carry a
// result as the two terms `upper` and `lower`, where `lower` holds the
// part of the value that does not fit in `upper`.
//
template<typename T>
struct doubled_t {
    T upper;
    T lower;
};

//
// There are three functions below where it is crucial that the
// expressions are not optimized.  E.g. `t - (t - x)` must not be
// simplified by the compiler to just `x`.  The NO_OPT macro defines
// an attribute that should turn off optimization for the function.
//
// The inclusion of `gnu::target("fpmath=sse")` when __GNUC__ and
// __i386 are defined also turns off the use of the floating-point
// unit '387'.  It is important that when the type is, for example,
// `double`, these functions compute their results with 64 bit
// precision, and not with 80 bit extended precision.
//
// For compilers that are neither clang nor GNU, NO_OPT expands to
// nothing.
//

#if defined(__clang__)
#define NO_OPT [[clang::optnone]]
#elif defined(__GNUC__)
#if defined(__i386)
#define NO_OPT [[gnu::optimize(0),gnu::target("fpmath=sse")]]
#else
#define NO_OPT [[gnu::optimize(0)]]
#endif
#else
#define NO_OPT
#endif

//
// Dekker splitting.  See, for example, Theorem 1 of
//
//   Seppo Linnainmaa, Software for Doubled-Precision Floating-Point
//   Computations, ACM Transactions on Mathematical Software, Vol 7, No 3,
//   September 1981, pages 272-283.
//
// or Theorem 17 of
//
//   J. R. Shewchuk, Adaptive Precision Floating-Point Arithmetic and
//   Fast Robust Geometric Predicates, CMU-CS-96-140R, from Discrete &
//   Computational Geometry 18(3):305-363, October 1997.
//
// On return, out.upper + out.lower == x, with the leading bits of x in
// out.upper and the remaining bits in out.lower.
//
template<typename T>
NO_OPT inline void
split(T x, doubled_t<T>& out)
{
    if (std::numeric_limits<T>::digits == 106) {
        // Special case: IBM double-double format.  The value is already
        // stored as two doubles, so no arithmetic is needed; each double
        // becomes one half of the result.
        //
        // Only sizeof(x) bytes may be read from x.  Copying sizeof(out)
        // bytes (two whole T objects) would read past the end of x and
        // put the entire value in out.upper.  The size is clamped so this
        // branch is also well-formed when instantiated (but never
        // executed) for types smaller than two doubles.
        //
        // NOTE(review): assumes the high-order double is stored first --
        // confirm against the target ABI.
        double halves[2] = {0.0, 0.0};
        const std::size_t nbytes =
            sizeof(halves) < sizeof(x) ? sizeof(halves) : sizeof(x);
        std::memcpy(halves, &x, nbytes);
        out.upper = static_cast<T>(halves[0]);
        out.lower = static_cast<T>(halves[1]);
    }
    else {
        constexpr int halfprec = (std::numeric_limits<T>::digits + 1)/2;
        T t = ((1ull << halfprec) + 1)*x;
        // The compiler must not be allowed to simplify this expression:
        out.upper = t - (t - x);
        out.lower = x - out.upper;
    }
}
126+
127+
template<typename T>
128+
NO_OPT inline void
129+
two_sum_quick(T x, T y, doubled_t<T>& out)
130+
{
131+
T r = x + y;
132+
T e = y - (r - x);
133+
out.upper = r;
134+
out.lower = e;
135+
}
136+
137+
template<typename T>
138+
NO_OPT inline void
139+
two_sum(T x, T y, doubled_t<T>& out)
140+
{
141+
T s = x + y;
142+
T v = s - x;
143+
T e = (x - (s - v)) + (y - v);
144+
out.upper = s;
145+
out.lower = e;
146+
}
147+
148+
template<typename T>
149+
inline void
150+
double_sum(const doubled_t<T>& x, const doubled_t<T>& y,
151+
doubled_t<T>& out)
152+
{
153+
two_sum<T>(x.upper, y.upper, out);
154+
out.lower += x.lower + y.lower;
155+
two_sum_quick<T>(out.upper, out.lower, out);
156+
}
157+
158+
template<typename T>
159+
inline void
160+
square(T x, doubled_t<T>& out)
161+
{
162+
doubled_t<T> xsplit;
163+
out.upper = x*x;
164+
split(x, xsplit);
165+
out.lower = xsplit.lower*xsplit.lower
166+
- ((out.upper - xsplit.upper*xsplit.upper)
167+
- 2*xsplit.lower*xsplit.upper);
168+
}
169+
170+
//
171+
// As the name makes clear, this function computes x**2 + 2*x + y**2.
172+
// It uses doubled_t<T> for the intermediate calculations.
173+
// (That is, we give the floating point type T an upgrayedd, spelled with
174+
// two d's for a double dose of precision.)
175+
//
176+
// The function is used in log1p_complex() to avoid the loss of
177+
// precision that can occur in the expression when x**2 + y**2 ≈ -2*x.
178+
//
179+
template<typename T>
180+
inline T
181+
xsquared_plus_2x_plus_ysquared_dd(T x, T y)
182+
{
183+
doubled_t<T> x2, y2, twox, sum1, sum2;
184+
185+
square<T>(x, x2); // x2 = x**2
186+
square<T>(y, y2); // y2 = y**2
187+
twox.upper = 2*x; // twox = 2*x
188+
twox.lower = 0.0;
189+
double_sum<T>(x2, twox, sum1); // sum1 = x**2 + 2*x
190+
double_sum<T>(sum1, y2, sum2); // sum2 = x**2 + 2*x + y**2
191+
return sum2.upper;
192+
}
193+
194+
//
// float overload of x**2 + 2*x + y**2.
//
// The expression is evaluated in double and the result converted back
// to float, so doubled_t<float> is not needed here.
//
inline float
xsquared_plus_2x_plus_ysquared(float x, float y)
{
    const double wide_x = static_cast<double>(x);
    const double wide_y = static_cast<double>(y);
    const double total = wide_x*(2.0 + wide_x) + wide_y*wide_y;
    return static_cast<float>(total);
}
205+
206+
//
207+
// For double, we used doubled_t<double> if long double doesn't have
208+
// at least 106 bits of precision.
209+
//
210+
inline double
211+
xsquared_plus_2x_plus_ysquared(double x, double y)
212+
{
213+
if (std::numeric_limits<long double>::digits >= 106) {
214+
// Cast to long double for the calculation.
215+
long double xd = x;
216+
long double yd = y;
217+
return xd*(2.0L + xd) + yd*yd;
218+
}
219+
else {
220+
// Use doubled_t<double> for the calculation.
221+
return xsquared_plus_2x_plus_ysquared_dd<double>(x, y);
222+
}
223+
}
224+
225+
//
226+
// For long double, we always use doubled_t<long double> for the
227+
// calculation.
228+
//
229+
inline long double
230+
xsquared_plus_2x_plus_ysquared(long double x, long double y)
231+
{
232+
return xsquared_plus_2x_plus_ysquared_dd<long double>(x, y);
233+
}
234+
235+
//
236+
// Implement log1p(z) for complex inputs that are near the unit circle
237+
// centered at -1+0j.
238+
//
239+
// The function assumes that neither component of z is nan.
240+
//
241+
template<typename T>
242+
inline std::complex<T>
243+
log1p_complex(std::complex<T> z)
244+
{
245+
T x = z.real();
246+
T y = z.imag();
247+
// The input is close to the unit circle centered at -1+0j.
248+
// Compute x**2 + 2*x + y**2 with higher precision than T.
249+
// The calculation here is equivalent to log(hypot(x+1, y)),
250+
// since
251+
// log(hypot(x+1, y)) = 0.5*log(x**2 + 2*x + 1 + y**2)
252+
// = 0.5*log1p(x**2 + 2*x + y**2)
253+
T t = xsquared_plus_2x_plus_ysquared(x, y);
254+
T lnr = 0.5*npy::log1p(t);
255+
return std::complex<T>(lnr, npy::atan2(y, x + static_cast<T>(1)));
256+
}
257+
258+
} // namespace log1p_complex
259+
260+
#endif

‎numpy/_core/src/umath/umathmodule.c

Copy file name to clipboardExpand all lines: numpy/_core/src/umath/umathmodule.c
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "string_ufuncs.h"
3131
#include "stringdtype_ufuncs.h"
3232
#include "special_integer_comparisons.h"
33+
#include "clog1p_wrappers.h"
3334
#include "extobj.h" /* for _extobject_contextvar exposure */
3435
#include "ufunc_type_resolution.h"
3536

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.