Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 29 additions & 46 deletions sycl/include/sycl/stl_wrappers/__sycl_cmath_wrapper_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,27 @@
#define __SYCL_DEVICE_C \
extern "C" __attribute__((sycl_device_only, always_inline))

// For std::enable_if, std::is_integral, std::is_floating_point, std::is_same,
// and std::conjunction
#include <type_traits>

// Promotion templates: the C++ standard library provides overloads that allow
// arguments of math functions to be promoted. Any floating-point argument is
// allowed to accept any integer type, which should then be promoted to double.
// When multiple floating point arguments are available passing arguments with
// different precision should promote to the larger type. The template helpers
// below provide the machinery to define these promoting overloads.
template <typename T, bool = (std::is_integral<T>::value ||
std::is_floating_point<T>::value)>
template <typename T,
bool = (std::is_integral_v<T> || std::is_floating_point_v<T>)>
struct __sycl_promote {
private:
// Integer types are promoted to double.
template <typename U>
static typename std::enable_if<std::is_integral<U>::value, double>::type
test();
static std::enable_if_t<std::is_integral_v<U>, double> test();

// Floating point types are used as-is.
template <typename U>
static typename std::enable_if<std::is_floating_point<U>::value, U>::type
test();
static std::enable_if_t<std::is_floating_point_v<U>, U> test();

public:
// We rely on dummy templated methods and decltype to select the right type
Expand All @@ -48,29 +50,17 @@ struct __sycl_promote {
// Variant without ::type to allow SFINAE for non-promotable types.
template <typename T> struct __sycl_promote<T, false> {};

// With a single paramter we only need to promote integers.
template <typename T>
using __sycl_promote_1 = std::enable_if<std::is_integral<T>::value, double>;

// With two or three parameters we need to promote integers and possibly
// floating point types. We rely on operator+ with decltype to deduce the
// overall promotion type. This is only needed if at least one of the parameter
// is an integer, or if there's multiple different floating point types.
template <typename T, typename U>
using __sycl_promote_2 =
std::enable_if<!std::is_same<T, U>::value || std::is_integral<T>::value ||
std::is_integral<U>::value,
decltype(typename __sycl_promote<T>::type(0) +
typename __sycl_promote<U>::type(0))>;

template <typename T, typename U, typename V>
using __sycl_promote_3 =
std::enable_if<!(std::is_same<T, U>::value && std::is_same<U, V>::value) ||
std::is_integral<T>::value ||
std::is_integral<U>::value || std::is_integral<V>::value,
decltype(typename __sycl_promote<T>::type(0) +
typename __sycl_promote<U>::type(0) +
typename __sycl_promote<V>::type(0))>;
template <typename T, typename... Ts>
using __sycl_promote_t =
std::enable_if_t<!std::conjunction_v<std::is_same<T, Ts>...> ||
std::is_integral_v<T> ||
(std::is_integral_v<Ts> || ...),
decltype((typename __sycl_promote<Ts>::type(0) + ... +
typename __sycl_promote<T>::type(0)))>;

// For each math built-in we need to define float and double overloads, an
// extern "C" float variant with the 'f' suffix, and a version that promotes
Expand All @@ -85,8 +75,7 @@ using __sycl_promote_3 =
__SYCL_DEVICE_C float NAME##f(float x) { return __spirv_ocl_##NAME(x); } \
__SYCL_DEVICE float NAME(float x) { return __spirv_ocl_##NAME(x); } \
__SYCL_DEVICE double NAME(double x) { return __spirv_ocl_##NAME(x); } \
template <typename T> \
__SYCL_DEVICE typename __sycl_promote_1<T>::type NAME(T x) { \
template <typename T> __SYCL_DEVICE __sycl_promote_t<T> NAME(T x) { \
return __spirv_ocl_##NAME((double)x); \
}

Expand All @@ -101,8 +90,8 @@ using __sycl_promote_3 =
return __spirv_ocl_##NAME(x, y); \
} \
template <typename T, typename U> \
__SYCL_DEVICE __sycl_promote_2<T, U>::type NAME(T x, U y) { \
typedef typename __sycl_promote_2<T, U>::type type; \
__SYCL_DEVICE __sycl_promote_t<T, U> NAME(T x, U y) { \
using type = __sycl_promote_t<T, U>; \
return __spirv_ocl_##NAME((type)x, (type)y); \
}

Expand All @@ -127,8 +116,7 @@ __SYCL_DEVICE double abs(double x) { return x < 0 ? -x : x; }
__SYCL_DEVICE float fabs(float x) { return x < 0 ? -x : x; }
__SYCL_DEVICE_C float fabsf(float x) { return x < 0 ? -x : x; }
__SYCL_DEVICE double fabs(double x) { return x < 0 ? -x : x; }
template <typename T>
__SYCL_DEVICE typename __sycl_promote_1<T>::type fabs(T x) {
template <typename T> __SYCL_DEVICE __sycl_promote_t<T> fabs(T x) {
return x < 0 ? -x : x;
}

Expand All @@ -145,8 +133,8 @@ __SYCL_DEVICE double remquo(double x, double y, int *q) {
return __spirv_ocl_remquo(x, y, q);
}
template <typename T, typename U>
__SYCL_DEVICE typename __sycl_promote_2<T, U>::type remquo(T x, U y, int *q) {
typedef typename __sycl_promote_2<T, U>::type type;
__SYCL_DEVICE __sycl_promote_t<T, U> remquo(T x, U y, int *q) {
using type = __sycl_promote_t<T, U>;
return __spirv_ocl_remquo((type)x, (type)y, q);
}

Expand All @@ -160,8 +148,8 @@ __SYCL_DEVICE double fma(double x, double y, double z) {
return __spirv_ocl_fma(x, y, z);
}
template <typename T, typename U, typename V>
__SYCL_DEVICE typename __sycl_promote_3<T, U, V>::type fma(T x, U y, V z) {
typedef typename __sycl_promote_3<T, U, V>::type type;
__SYCL_DEVICE __sycl_promote_t<T, U, V> fma(T x, U y, V z) {
using type = __sycl_promote_t<T, U, V>;
return __spirv_ocl_fma((type)x, (type)y, (type)z);
}

Expand Down Expand Up @@ -256,8 +244,7 @@ __SYCL_DEVICE float frexp(float x, int *exp) {
__SYCL_DEVICE double frexp(double x, int *exp) {
return __spirv_ocl_frexp(x, exp);
}
template <typename T>
__SYCL_DEVICE typename __sycl_promote_1<T>::type frexp(T x, int *exp) {
template <typename T> __SYCL_DEVICE __sycl_promote_t<T> frexp(T x, int *exp) {
return __spirv_ocl_frexp((double)x, exp);
}

Expand All @@ -270,8 +257,7 @@ __SYCL_DEVICE float ldexp(float x, int exp) {
__SYCL_DEVICE double ldexp(double x, int exp) {
return __spirv_ocl_ldexp(x, exp);
}
template <typename T>
__SYCL_DEVICE typename __sycl_promote_1<T>::type ldexp(T x, int exp) {
template <typename T> __SYCL_DEVICE __sycl_promote_t<T> ldexp(T x, int exp) {
return __spirv_ocl_ldexp((double)x, exp);
}

Expand All @@ -286,7 +272,7 @@ __SYCL_DEVICE double modf(double x, double *intpart) {
}
// modf only supports integer x when the intpart is double.
template <typename T>
__SYCL_DEVICE typename __sycl_promote_1<T>::type modf(T x, double *intpart) {
__SYCL_DEVICE __sycl_promote_t<T> modf(T x, double *intpart) {
return __spirv_ocl_modf((double)x, intpart);
}

Expand All @@ -299,8 +285,7 @@ __SYCL_DEVICE float scalbn(float x, int exp) {
__SYCL_DEVICE double scalbn(double x, int exp) {
return __spirv_ocl_ldexp(x, exp);
}
template <typename T>
__SYCL_DEVICE typename __sycl_promote_1<T>::type scalbn(T x, int exp) {
template <typename T> __SYCL_DEVICE __sycl_promote_t<T> scalbn(T x, int exp) {
return __spirv_ocl_ldexp((double)x, exp);
}

Expand All @@ -313,8 +298,7 @@ __SYCL_DEVICE float scalbln(float x, long exp) {
__SYCL_DEVICE double scalbln(double x, long exp) {
return __spirv_ocl_ldexp(x, (int)exp);
}
template <typename T>
__SYCL_DEVICE typename __sycl_promote_1<T>::type scalbln(T x, long exp) {
template <typename T> __SYCL_DEVICE __sycl_promote_t<T> scalbln(T x, long exp) {
return __spirv_ocl_ldexp((double)x, (int)exp);
}

Expand All @@ -323,8 +307,7 @@ __SYCL_DEVICE int ilogb(float x) { return __spirv_ocl_ilogb(x); }
__SYCL_DEVICE int ilogb(double x) { return __spirv_ocl_ilogb(x); }
// ilogb needs a special template since its signature doesn't include the
// promoted type anywhere, so it needs to be specialized differently.
template <typename T, typename std::enable_if<std::is_integral<T>::value,
bool>::type = true>
template <typename T, std::enable_if_t<std::is_integral_v<T>, bool> = true>
__SYCL_DEVICE int ilogb(T x) {
return __spirv_ocl_ilogb((double)x);
}
Expand Down
Loading