Add AF_WITH_FAST_MATH handling to the CUDA and OpenCL backends.

What this patch does when AF_WITH_FAST_MATH is defined:
  * CUDA runtime-compiled modules get "--use_fast_math" and
    "-DAF_WITH_FAST_MATH"; OpenCL programs get
    " -cl-fast-relaxed-math -DAF_WITH_FAST_MATH".
  * The JIT __pow macro uses static_cast conversions around pow() instead of
    the __int2float_rn/__float2int_rn intrinsics.
  * maxval()/minval() return numeric_limits max()/lowest() for every type;
    without the define, types matching std::is_floating_point_v still get
    +/-infinity. The per-type specialisations are replaced by if constexpr.
  * Each cuBLAS handle is set to CUBLAS_TF32_TENSOR_OP_MATH with
    CUBLAS_ATOMICS_ALLOWED.
  * Test Reduce.nanval_issue_3255 calls SKIP_IF_FAST_MATH_ENABLED().
A "-use_fast_math" entry is also added to afcuda's compile options.

Review fix applied here: the fast-math __pow macro no longer ends in ';', so
it expands to an expression like the #else __pow and the __powll/__powul
macros next to it.

NOTE(review): text inside angle brackets (template arguments, CMake
generator-expression conditions) is missing from this copy of the patch and
has been left as found -- restore it from the original commit before applying.
NOTE(review): the removed OpenCL common::half specialisations returned
+/-infinity; confirm std::is_floating_point_v is true for common::half,
otherwise half now gets max()/lowest() even without fast math.
NOTE(review): CUBLAS_TF32_TENSOR_OP_MATH presumably needs cuBLAS 11 or newer;
confirm the minimum supported toolkit or add a version guard.

diff --git a/src/backend/cuda/CMakeLists.txt b/src/backend/cuda/CMakeLists.txt
index 8490c541a0..ece17d962f 100644
--- a/src/backend/cuda/CMakeLists.txt
+++ b/src/backend/cuda/CMakeLists.txt
@@ -636,6 +636,7 @@ endif()
 
 target_compile_options(afcuda
   PRIVATE
+    $<$:$<$:-use_fast_math>>
     $<$:--expt-relaxed-constexpr>
     $<$:-Xcudafe --diag_suppress=unrecognized_gcc_pragma>
     $<$: $<$: -Xcompiler=/wd4251
diff --git a/src/backend/cuda/compile_module.cpp b/src/backend/cuda/compile_module.cpp
index 3f5bd17d84..de22e8c493 100644
--- a/src/backend/cuda/compile_module.cpp
+++ b/src/backend/cuda/compile_module.cpp
@@ -261,6 +261,10 @@ Module compileModule(const string &moduleKey, span sources,
         arch.data(),
         "--std=c++14",
         "--device-as-default-execution-space",
+#ifdef AF_WITH_FAST_MATH
+        "--use_fast_math",
+        "-DAF_WITH_FAST_MATH",
+#endif
 #if !(defined(NDEBUG) || defined(__aarch64__) || defined(__LP64__))
         "--device-debug",
         "--generate-line-info"
diff --git a/src/backend/cuda/kernel/jit.cuh b/src/backend/cuda/kernel/jit.cuh
index 4681c151ed..cf69146114 100644
--- a/src/backend/cuda/kernel/jit.cuh
+++ b/src/backend/cuda/kernel/jit.cuh
@@ -59,8 +59,13 @@ typedef cuDoubleComplex cdouble;
 #define __rem(lhs, rhs) ((lhs) % (rhs))
 #define __mod(lhs, rhs) ((lhs) % (rhs))
 
+#ifdef AF_WITH_FAST_MATH
+#define __pow(lhs, rhs) \
+    static_cast(pow(static_cast(lhs), static_cast(rhs)))
+#else
 #define __pow(lhs, rhs) \
     __float2int_rn(pow(__int2float_rn((int)lhs), __int2float_rn((int)rhs)))
+#endif
 #define __powll(lhs, rhs) \
     __double2ll_rn(pow(__ll2double_rn(lhs), __ll2double_rn(rhs)))
 #define __powul(lhs, rhs) \
diff --git a/src/backend/cuda/math.hpp b/src/backend/cuda/math.hpp
index 5987017fa7..23aa1a449b 100644
--- a/src/backend/cuda/math.hpp
+++ b/src/backend/cuda/math.hpp
@@ -32,6 +32,12 @@
 
 namespace cuda {
 
+#ifdef AF_WITH_FAST_MATH
+constexpr bool fast_math = true;
+#else
+constexpr bool fast_math = false;
+#endif
+
 template
 static inline __DH__ T abs(T val) {
     return ::abs(val);
@@ -138,29 +144,22 @@ __DH__ static To scalar(Ti real, Ti imag) {
 }
 
 #ifndef __CUDA_ARCH__
+
 template
 inline T maxval() {
-    return std::numeric_limits::max();
+    if constexpr (std::is_floating_point_v && !fast_math) {
+        return std::numeric_limits::infinity();
+    } else {
+        return std::numeric_limits::max();
+    }
 }
 template
 inline T minval() {
-    return std::numeric_limits::min();
-}
-template<>
-inline float maxval() {
-    return std::numeric_limits::infinity();
-}
-template<>
-inline double maxval() {
-    return std::numeric_limits::infinity();
-}
-template<>
-inline float minval() {
-    return -std::numeric_limits::infinity();
-}
-template<>
-inline double minval() {
-    return -std::numeric_limits::infinity();
+    if constexpr (std::is_floating_point_v && !fast_math) {
+        return -std::numeric_limits::infinity();
+    } else {
+        return std::numeric_limits::lowest();
+    }
 }
 #else
 template
diff --git a/src/backend/cuda/platform.cpp b/src/backend/cuda/platform.cpp
index 7e82f76843..d3b7c2efd9 100644
--- a/src/backend/cuda/platform.cpp
+++ b/src/backend/cuda/platform.cpp
@@ -94,6 +94,12 @@ unique_handle *cublasManager(const int deviceId) {
         // call outside of call_once scope.
         CUBLAS_CHECK(
             cublasSetStream(handles[deviceId], cuda::getStream(deviceId)));
+#ifdef AF_WITH_FAST_MATH
+        CUBLAS_CHECK(
+            cublasSetMathMode(handles[deviceId], CUBLAS_TF32_TENSOR_OP_MATH));
+        CUBLAS_CHECK(
+            cublasSetAtomicsMode(handles[deviceId], CUBLAS_ATOMICS_ALLOWED));
+#endif
     });
 
     return &handles[deviceId];
diff --git a/src/backend/opencl/compile_module.cpp b/src/backend/opencl/compile_module.cpp
index 83d66eb740..f931bb554a 100644
--- a/src/backend/opencl/compile_module.cpp
+++ b/src/backend/opencl/compile_module.cpp
@@ -126,6 +126,10 @@ Program buildProgram(span kernelSources,
         ostringstream options;
         for (auto &opt : compileOpts) { options << opt; }
 
+#ifdef AF_WITH_FAST_MATH
+        options << " -cl-fast-relaxed-math -DAF_WITH_FAST_MATH";
+#endif
+
         retVal.build({device}, (cl_std + defaults + options.str()).c_str());
     } catch (Error &err) {
         if (err.err() == CL_BUILD_PROGRAM_FAILURE) {
diff --git a/src/backend/opencl/math.hpp b/src/backend/opencl/math.hpp
index e1e9c28f12..e7cf8d1928 100644
--- a/src/backend/opencl/math.hpp
+++ b/src/backend/opencl/math.hpp
@@ -106,40 +106,27 @@ static To scalar(Ti real, Ti imag) {
     return cval;
 }
 
+#ifdef AF_WITH_FAST_MATH
+constexpr bool fast_math = true;
+#else
+constexpr bool fast_math = false;
+#endif
+
 template
 inline T maxval() {
-    return std::numeric_limits::max();
+    if constexpr (std::is_floating_point_v && !fast_math) {
+        return std::numeric_limits::infinity();
+    } else {
+        return std::numeric_limits::max();
+    }
 }
 template
 inline T minval() {
-    return std::numeric_limits::min();
-}
-template<>
-inline float maxval() {
-    return std::numeric_limits::infinity();
-}
-template<>
-inline double maxval() {
-    return std::numeric_limits::infinity();
-}
-
-template<>
-inline common::half maxval() {
-    return std::numeric_limits::infinity();
-}
-
-template<>
-inline float minval() {
-    return -std::numeric_limits::infinity();
-}
-
-template<>
-inline double minval() {
-    return -std::numeric_limits::infinity();
-}
-template<>
-inline common::half minval() {
-    return -std::numeric_limits::infinity();
+    if constexpr (std::is_floating_point_v && !fast_math) {
+        return -std::numeric_limits::infinity();
+    } else {
+        return std::numeric_limits::lowest();
+    }
 }
 
 static inline double real(cdouble in) { return in.s[0]; }
diff --git a/test/reduce.cpp b/test/reduce.cpp
index 5afdf70648..ef5b33bb1c 100644
--- a/test/reduce.cpp
+++ b/test/reduce.cpp
@@ -2296,6 +2296,7 @@ TEST(Reduce, Test_Sum_Global_Array_nanval) {
 }
 
 TEST(Reduce, nanval_issue_3255) {
+    SKIP_IF_FAST_MATH_ENABLED();
     char *info_str;
     af_array ikeys, ivals, okeys, ovals;
     dim_t dims[1] = {8};