diff --git a/dpctl/apis/include/dpctl4pybind11.hpp b/dpctl/apis/include/dpctl4pybind11.hpp index 3181291b8d..0bea3538e6 100644 --- a/dpctl/apis/include/dpctl4pybind11.hpp +++ b/dpctl/apis/include/dpctl4pybind11.hpp @@ -36,6 +36,9 @@ #include #include +#define SYCL_EXT_ONEAPI_COMPLEX +#include + namespace py = pybind11; namespace dpctl @@ -722,6 +725,43 @@ template <> struct type_caster PYBIND11_TYPE_CASTER(sycl::half, _("float")); }; +/* This type caster associates + * ``sycl:ext::oneapi::complex`` C++ classes with Python :class:`complex` for + * the purposes of generation of Python bindings by pybind11. + */ +template +struct type_caster> +{ +public: + bool load(handle src, bool convert) + { + if (!src) { + return false; + } + if (!convert && !PyComplex_Check(src.ptr())) { + return false; + } + Py_complex result = PyComplex_AsCComplex(src.ptr()); + if (result.real == -1.0 && PyErr_Occurred()) { + PyErr_Clear(); + return false; + } + value = sycl::ext::oneapi::experimental::complex( + static_cast(result.real), static_cast(result.imag)); + return true; + } + + static handle cast(const sycl::ext::oneapi::experimental::complex &src, + return_value_policy /* policy */, + handle /* parent */) + { + return PyComplex_FromDoubles((double)src.real(), (double)src.imag()); + } + + PYBIND11_TYPE_CASTER(sycl::ext::oneapi::experimental::complex, + const_name("sycl_complex")); +}; + } // namespace detail } // namespace pybind11 diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp index 48c8e3e4dd..34f0b24d55 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp @@ -32,6 +32,7 @@ #include #include "cabs_impl.hpp" +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -118,8 +119,8 @@ template struct AbsOutputType td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, float>, - td_ns::TypeMapResultEntry, double>, + td_ns::TypeMapResultEntry, float>, + td_ns::TypeMapResultEntry, double>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp index 7dbfb6618c..4801e1a1d7 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp @@ -72,9 +72,10 @@ template struct AcosFunctor using realT = typename argT::value_type; constexpr realT q_nan = std::numeric_limits::quiet_NaN(); - - const realT x = std::real(in); - const realT y = std::imag(in); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); if (std::isnan(x)) { /* acos(NaN + I*+-Inf) = NaN + I*-+Inf */ @@ -106,12 +107,10 @@ template struct AcosFunctor constexpr realT r_eps = realT(1) / std::numeric_limits::epsilon(); if (sycl::fabs(x) > r_eps || sycl::fabs(y) > r_eps) { - using sycl_complexT = exprm_ns::complex; - sycl_complexT log_in = - exprm_ns::log(exprm_ns::complex(in)); + sycl_complexT log_z = exprm_ns::log(z); - const realT wx = log_in.real(); - const realT wy = log_in.imag(); + const realT wx = log_z.real(); + const realT wy = log_z.imag(); const realT rx = sycl::fabs(wy); realT ry = wx + sycl::log(realT(2)); @@ -119,7 +118,7 @@ template struct AcosFunctor } /* ordinary cases */ - return exprm_ns::acos(exprm_ns::complex(in)); // acos(in); + return exprm_ns::acos(z); // acos(z); } else { static_assert(std::is_floating_point_v || @@ -152,8 +151,8 @@ template struct AcosOutputType td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry>, - td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp index a81ff3da99..e85d20b545 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp @@ -77,17 +77,19 @@ template struct AcoshFunctor * where the sign is chosen so Re(acosh(in)) >= 0. * So, we first calculate acos(in) and then acosh(in). */ - const realT x = std::real(in); - const realT y = std::imag(in); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); - resT acos_in; + sycl_complexT acos_z; if (std::isnan(x)) { /* acos(NaN + I*+-Inf) = NaN + I*-+Inf */ if (std::isinf(y)) { - acos_in = resT{q_nan, -y}; + acos_z = resT{q_nan, -y}; } else { - acos_in = resT{q_nan, q_nan}; + acos_z = resT{q_nan, q_nan}; } } else if (std::isnan(y)) { @@ -95,15 +97,15 @@ template struct AcoshFunctor constexpr realT inf = std::numeric_limits::infinity(); if (std::isinf(x)) { - acos_in = resT{q_nan, -inf}; + acos_z = resT{q_nan, -inf}; } /* acos(0 + I*NaN) = Pi/2 + I*NaN with inexact */ else if (x == realT(0)) { const realT pi_half = sycl::atan(realT(1)) * 2; - acos_in = resT{pi_half, q_nan}; + acos_z = resT{pi_half, q_nan}; } else { - acos_in = resT{q_nan, q_nan}; + acos_z = resT{q_nan, q_nan}; } } @@ -113,23 +115,21 @@ template struct AcoshFunctor * For large x or y including acos(+-Inf + I*+-Inf) */ if (sycl::fabs(x) > r_eps || sycl::fabs(y) > r_eps) { - using sycl_complexT = typename exprm_ns::complex; - const sycl_complexT log_in = exprm_ns::log(sycl_complexT(in)); - const realT wx = log_in.real(); - const realT wy = log_in.imag(); + const sycl_complexT log_z = exprm_ns::log(z); + const realT wx = log_z.real(); + const realT wy = log_z.imag(); const realT rx = sycl::fabs(wy); realT ry = wx + sycl::log(realT(2)); - acos_in = resT{rx, (sycl::signbit(y)) ? ry : -ry}; + acos_z = resT{rx, (sycl::signbit(y)) ? ry : -ry}; } else { /* ordinary cases */ - acos_in = - exprm_ns::acos(exprm_ns::complex(in)); // acos(in); + acos_z = exprm_ns::acos(z); // acos(z); } /* Now we calculate acosh(z) */ - const realT rx = std::real(acos_in); - const realT ry = std::imag(acos_in); + const realT rx = exprm_ns::real(acos_z); + const realT ry = exprm_ns::imag(acos_z); /* acosh(NaN + I*NaN) = NaN + I*NaN */ if (std::isnan(rx) && std::isnan(ry)) { @@ -145,7 +145,7 @@ template struct AcoshFunctor return resT{ry, ry}; } /* ordinary cases */ - const realT res_im = sycl::copysign(rx, std::imag(in)); + const realT res_im = sycl::copysign(rx, exprm_ns::imag(z)); return resT{sycl::fabs(ry), res_im}; } else { @@ -179,8 +179,8 @@ template struct AcoshOutputType td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry>, - td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp index 476e7b52b9..d4c67e6b38 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp @@ -185,15 +185,15 @@ template struct AddOutputType td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, - std::complex, - std::complex>, + exprm_ns::complex, + exprm_ns::complex>, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, - std::complex, - std::complex>, + exprm_ns::complex, + exprm_ns::complex>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; @@ -515,13 +515,13 @@ template struct AddInplaceTypePairSupport td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, + exprm_ns::complex, resTy, - std::complex>, + exprm_ns::complex>, td_ns::TypePairDefinedEntry, + exprm_ns::complex, resTy, - std::complex>, + exprm_ns::complex>, // fall-through td_ns::NotDefinedEntry>::is_defined; }; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/angle.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/angle.hpp index 726f90ba81..d0921376b6 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/angle.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/angle.hpp @@ -95,8 +95,8 @@ using AngleStridedFunctor = elementwise_common:: template struct AngleOutputType { using value_type = typename std::disjunction< - td_ns::TypeMapResultEntry, float>, - td_ns::TypeMapResultEntry, double>, + td_ns::TypeMapResultEntry, float>, + td_ns::TypeMapResultEntry, double>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp index 70b48895b4..815f8fa06a 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp @@ -80,8 +80,10 @@ template struct AsinFunctor * y = imag(I * conj(in)) = real(in) * and then return {imag(w), real(w)} which is asin(in) */ - const realT x = std::imag(in); - const realT y = std::real(in); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::imag(z); + const realT y = exprm_ns::real(z); if (std::isnan(x)) { /* asinh(NaN + I*+-Inf) = opt(+-)Inf + I*NaN */ @@ -120,26 +122,24 @@ template struct AsinFunctor constexpr realT r_eps = realT(1) / std::numeric_limits::epsilon(); if (sycl::fabs(x) > r_eps || sycl::fabs(y) > r_eps) { - using sycl_complexT = exprm_ns::complex; - const sycl_complexT z{x, y}; + const sycl_complexT z1{x, y}; realT wx, wy; if (!sycl::signbit(x)) { - const auto log_z = exprm_ns::log(z); - wx = log_z.real() + sycl::log(realT(2)); - wy = log_z.imag(); + const auto log_z1 = exprm_ns::log(z1); + wx = log_z1.real() + sycl::log(realT(2)); + wy = log_z1.imag(); } else { - const auto log_mz = exprm_ns::log(-z); - wx = log_mz.real() + sycl::log(realT(2)); - wy = log_mz.imag(); + const auto log_mz1 = exprm_ns::log(-z1); + wx = log_mz1.real() + sycl::log(realT(2)); + wy = log_mz1.imag(); } const realT asinh_re = sycl::copysign(wx, x); const realT asinh_im = sycl::copysign(wy, y); return resT{asinh_im, asinh_re}; } /* ordinary cases */ - return exprm_ns::asin( - exprm_ns::complex(in)); // sycl::asin(in); + return exprm_ns::asin(z); // sycl::asin(z); } else { static_assert(std::is_floating_point_v || @@ -172,8 +172,8 @@ template struct AsinOutputType td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry>, - td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp index 420ba3246c..52273012f6 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp @@ -72,9 +72,10 @@ template struct AsinhFunctor using realT = typename argT::value_type; constexpr realT q_nan = std::numeric_limits::quiet_NaN(); - - const realT x = std::real(in); - const realT y = std::imag(in); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); if (std::isnan(x)) { /* asinh(NaN + I*+-Inf) = opt(+-)Inf + I*NaN */ @@ -109,12 +110,10 @@ template struct AsinhFunctor realT(1) / std::numeric_limits::epsilon(); if (sycl::fabs(x) > r_eps || sycl::fabs(y) > r_eps) { - using sycl_complexT = exprm_ns::complex; - sycl_complexT log_in = (sycl::signbit(x)) - ? exprm_ns::log(sycl_complexT(-in)) - : exprm_ns::log(sycl_complexT(in)); - realT wx = log_in.real() + sycl::log(realT(2)); - realT wy = log_in.imag(); + sycl_complexT log_in = + (sycl::signbit(x)) ? exprm_ns::log(-z) : exprm_ns::log(z); + realT wx = exprm_ns::real(log_in) + sycl::log(realT(2)); + realT wy = exprm_ns::imag(log_in); const realT res_re = sycl::copysign(wx, x); const realT res_im = sycl::copysign(wy, y); @@ -122,7 +121,7 @@ template struct AsinhFunctor } /* ordinary cases */ - return exprm_ns::asinh(exprm_ns::complex(in)); // asinh(in); + return exprm_ns::asinh(z); // asinh(z); } else { static_assert(std::is_floating_point_v || @@ -155,8 +154,8 @@ template struct AsinhOutputType td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry>, - td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp index 29c4941d76..55b889b245 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp @@ -83,8 +83,11 @@ template struct AtanFunctor * y = imag(I * conj(in)) = real(in) * and then return {imag(w), real(w)} which is atan(in) */ - const realT x = std::imag(in); - const realT y = std::real(in); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::imag(z); + const realT y = exprm_ns::real(z); + if (std::isnan(x)) { /* atanh(NaN + I*+-Inf) = sign(NaN)*0 + I*+-Pi/2 */ if (std::isinf(y)) { @@ -132,7 +135,7 @@ template struct AtanFunctor return resT{atanh_im, atanh_re}; } /* ordinary cases */ - return exprm_ns::atan(exprm_ns::complex(in)); // atan(in); + return exprm_ns::atan(z); // atan(z); } else { static_assert(std::is_floating_point_v || @@ -165,8 +168,8 @@ template struct AtanOutputType td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry>, - td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp index 39f11e0f90..a737b70489 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp @@ -73,8 +73,10 @@ template struct AtanhFunctor using realT = typename argT::value_type; constexpr realT q_nan = std::numeric_limits::quiet_NaN(); - const realT x = std::real(in); - const realT y = std::imag(in); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); if (std::isnan(x)) { /* atanh(NaN + I*+-Inf) = sign(NaN)0 + I*+-PI/2 */ @@ -123,7 +125,7 @@ template struct AtanhFunctor return resT{res_re, res_im}; } /* ordinary cases */ - return exprm_ns::atanh(exprm_ns::complex(in)); // atanh(in); + return exprm_ns::atanh(z); // atanh(z); } else { static_assert(std::is_floating_point_v || @@ -156,8 +158,8 @@ template struct AtanhOutputType td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry>, - td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cabs_impl.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cabs_impl.hpp index afa83a64cb..1fddf84400 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cabs_impl.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cabs_impl.hpp @@ -24,7 +24,6 @@ #pragma once #include -#include #include #include "sycl_complex.hpp" @@ -38,7 +37,7 @@ namespace kernels namespace detail { -template realT cabs(std::complex const &z) +template realT cabs(exprm_ns::complex const &z) { // Special values for cabs( x + y * 1j): // * If x is either +infinity or -infinity and y is any value @@ -51,8 +50,10 @@ template realT cabs(std::complex const &z) // * If x is a finite number and y is NaN, the result is NaN. // * If x is NaN and y is NaN, the result is NaN. - const realT x = std::real(z); - const realT y = std::imag(z); + using sycl_complexT = exprm_ns::complex; + sycl_complexT _z = exprm_ns::complex(z); + const realT x = exprm_ns::real(_z); + const realT y = exprm_ns::imag(_z); constexpr realT q_nan = std::numeric_limits::quiet_NaN(); constexpr realT p_inf = std::numeric_limits::infinity(); @@ -60,11 +61,8 @@ template realT cabs(std::complex const &z) const realT res = std::isinf(x) ? p_inf - : ((std::isinf(y) - ? p_inf - : ((std::isnan(x) - ? q_nan - : exprm_ns::abs(exprm_ns::complex(z)))))); + : ((std::isinf(y) ? p_inf + : ((std::isnan(x) ? q_nan : exprm_ns::abs(_z))))); return res; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/conj.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/conj.hpp index 19a95df5a1..a2403fec92 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/conj.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/conj.hpp @@ -115,8 +115,8 @@ template struct ConjOutputType td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry>, - td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp index 5940315c62..fbfcaaf0a4 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp @@ -72,30 +72,31 @@ template struct CosFunctor using realT = typename argT::value_type; constexpr realT q_nan = std::numeric_limits::quiet_NaN(); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const realT z_re = exprm_ns::real(z); + const realT z_im = exprm_ns::imag(z); - realT const &in_re = std::real(in); - realT const &in_im = std::imag(in); - - const bool in_re_finite = std::isfinite(in_re); - const bool in_im_finite = std::isfinite(in_im); + const bool z_re_finite = std::isfinite(z_re); + const bool z_im_finite = std::isfinite(z_im); /* * Handle the nearly-non-exceptional cases where * real and imaginary parts of input are finite. */ - if (in_re_finite && in_im_finite) { - return exprm_ns::cos(exprm_ns::complex(in)); // cos(in); + if (z_re_finite && z_im_finite) { + return exprm_ns::cos(z); // cos(z); } /* - * since cos(in) = cosh(I * in), for special cases, - * we return cosh(I * in). + * since cos(z) = cosh(I * z), for special cases, + * we return cosh(I * z). */ - const realT x = -in_im; - const realT y = in_re; + const realT x = -z_im; + const realT y = z_re; - const bool xfinite = in_im_finite; - const bool yfinite = in_re_finite; + const bool xfinite = z_im_finite; + const bool yfinite = z_re_finite; /* * cosh(+-0 +- I Inf) = dNaN + I sign(d(+-0, dNaN))0. * The sign of 0 in the result is unspecified. Choice = normally @@ -187,9 +188,12 @@ template struct CosOutputType td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, std::complex>, - td_ns:: - TypeMapResultEntry, std::complex>, + td_ns::TypeMapResultEntry, + exprm_ns::complex>, + td_ns::TypeMapResultEntry, + exprm_ns::complex>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp index 59468428d1..affa8b3a21 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp @@ -73,8 +73,10 @@ template struct CoshFunctor constexpr realT q_nan = std::numeric_limits::quiet_NaN(); - const realT x = std::real(in); - const realT y = std::imag(in); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); const bool xfinite = std::isfinite(x); const bool yfinite = std::isfinite(y); @@ -84,8 +86,7 @@ template struct CoshFunctor * real and imaginary parts of input are finite. */ if (xfinite && yfinite) { - return exprm_ns::cosh( - exprm_ns::complex(in)); // cosh(in); + return exprm_ns::cosh(z); // cosh(z); } /* @@ -177,8 +178,8 @@ template struct CoshOutputType td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry>, - td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp index a53f6412de..6f191d0b40 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp @@ -178,14 +178,14 @@ template struct EqualOutputType td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, - std::complex, + exprm_ns::complex, bool>, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, - std::complex, + exprm_ns::complex, bool>, td_ns::DefaultResultEntry>::result_type; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp index 00f8213251..cfcec5e660 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp @@ -72,12 +72,13 @@ template struct ExpFunctor constexpr realT q_nan = std::numeric_limits::quiet_NaN(); - const realT x = std::real(in); - const realT y = std::imag(in); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); if (std::isfinite(x)) { if (std::isfinite(y)) { - return exprm_ns::exp( - exprm_ns::complex(in)); // exp(in); + return exprm_ns::exp(z); // exp(z); } else { return resT{q_nan, q_nan}; @@ -86,7 +87,7 @@ template struct ExpFunctor else if (std::isnan(x)) { /* x is nan */ if (y == realT(0)) { - return resT{in}; + return resT{z}; } else { return resT{x, q_nan}; @@ -146,8 +147,8 @@ template struct ExpOutputType td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry>, - td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp index 22291101ca..e975758e3e 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp @@ -71,15 +71,18 @@ template struct Exp2Functor if constexpr (is_complex::value) { using realT = typename argT::value_type; - const argT tmp = in * sycl::log(realT(2)); - constexpr realT q_nan = std::numeric_limits::quiet_NaN(); - const realT x = std::real(tmp); - const realT y = std::imag(tmp); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); + + const sycl_complexT tmp = z * sycl::log(realT(2)); + if (std::isfinite(x)) { if (std::isfinite(y)) { - return exprm_ns::exp(exprm_ns::complex(tmp)); + return exprm_ns::exp(tmp); } else { return resT{q_nan, q_nan}; @@ -148,8 +151,8 @@ template struct Exp2OutputType td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry>, - td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp index d1d64f4904..1d06fe0e81 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp @@ -31,6 +31,7 @@ #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -73,8 +74,10 @@ template struct Expm1Functor using realT = typename argT::value_type; // expm1(x + I*y) = expm1(x)*cos(y) - 2*sin(y / 2)^2 + // I*exp(x)*sin(y) - const realT x = std::real(in); - const realT y = std::imag(in); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); // special cases if (std::isinf(x)) { @@ -160,9 +163,12 @@ template struct Expm1OutputType td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, std::complex>, - td_ns:: - TypeMapResultEntry, std::complex>, + td_ns::TypeMapResultEntry, + exprm_ns::complex>, + td_ns::TypeMapResultEntry, + exprm_ns::complex>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater.hpp index 4d0b7fb94f..d9d01b8c97 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater.hpp @@ -30,6 +30,7 @@ #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "utils/math_utils.hpp" @@ -179,14 +180,14 @@ template struct GreaterOutputType td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, - std::complex, + exprm_ns::complex, bool>, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, - std::complex, + exprm_ns::complex, bool>, td_ns::DefaultResultEntry>::result_type; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater_equal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater_equal.hpp index b149158ee0..2b4bbab9fe 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater_equal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater_equal.hpp @@ -30,6 +30,7 @@ #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "utils/math_utils.hpp" @@ -180,14 +181,14 @@ template struct GreaterEqualOutputType td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, - std::complex, + exprm_ns::complex, bool>, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, - std::complex, + exprm_ns::complex, bool>, td_ns::DefaultResultEntry>::result_type; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp index 89adabff41..bd41ba4e8a 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp @@ -25,12 +25,12 @@ #pragma once #include -#include #include #include #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -53,14 +53,16 @@ using dpctl::tensor::ssize_t; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; +using dpctl::tensor::type_utils::is_complex_v; template struct ImagFunctor { // is function constant for given argT - using is_constant = typename std::false_type; + using is_constant = + typename std::is_same, std::false_type>; // constant value, if constant - // constexpr resT constant_value = resT{}; + static constexpr resT constant_value = resT{0}; // is function defined for sycl::vec using supports_vec = typename std::false_type; // do both argTy and resTy support sugroup store/load operation @@ -69,12 +71,14 @@ template struct ImagFunctor resT operator()(const argT &in) const { - if constexpr (is_complex::value) { - return std::imag(in); + if constexpr (is_complex_v) { + using realT = typename argT::value_type; + using sycl_complexT = typename exprm_ns::complex; + return exprm_ns::real(sycl_complexT(in)); } else { static_assert(std::is_same_v); - return resT{0}; + return constant_value; } } }; @@ -111,8 +115,8 @@ template struct ImagOutputType td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, float>, - td_ns::TypeMapResultEntry, double>, + td_ns::TypeMapResultEntry, float>, + td_ns::TypeMapResultEntry, double>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; @@ -173,7 +177,7 @@ template struct ImagContigFactory template struct ImagTypeMapFactory { - /*! @brief get typeid for output type of std::imag(T x) */ + /*! @brief get typeid for output type of imag(T x) */ std::enable_if_t::value, int> get() { using rT = typename ImagOutputType::value_type; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp index b0651a4d8b..32f6addf2f 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp @@ -30,6 +30,7 @@ #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "utils/offset_utils.hpp" @@ -69,8 +70,11 @@ template struct IsFiniteFunctor resT operator()(const argT &in) const { if constexpr (is_complex::value) { - const bool real_isfinite = std::isfinite(std::real(in)); - const bool imag_isfinite = std::isfinite(std::imag(in)); + using realT = typename argT::value_type; + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const bool real_isfinite = std::isfinite(exprm_ns::real(z)); + const bool imag_isfinite = std::isfinite(exprm_ns::imag(z)); return (real_isfinite && imag_isfinite); } else if constexpr (std::is_same::value || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp index ec78746143..87a215bd9f 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp @@ -30,6 +30,7 @@ #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -69,8 +70,11 @@ template struct IsInfFunctor resT operator()(const argT &in) const { if constexpr (is_complex::value) { - const bool real_isinf = std::isinf(std::real(in)); - const bool imag_isinf = std::isinf(std::imag(in)); + using realT = typename argT::value_type; + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const bool real_isinf = std::isinf(exprm_ns::real(z)); + const bool imag_isinf = std::isinf(exprm_ns::imag(z)); return (real_isinf || imag_isinf); } else if constexpr (std::is_same::value || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp index fbf6ef9383..abac44be84 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp @@ -29,6 +29,7 @@ #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -70,8 +71,11 @@ template struct IsNanFunctor resT operator()(const argT &in) const { if constexpr (is_complex::value) { - const bool real_isnan = sycl::isnan(std::real(in)); - const bool imag_isnan = sycl::isnan(std::imag(in)); + using realT = typename argT::value_type; + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const bool real_isnan = sycl::isnan(exprm_ns::real(z)); + const bool imag_isnan = sycl::isnan(exprm_ns::imag(z)); return (real_isnan || imag_isnan); } else if constexpr (std::is_same::value || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp index 523410a161..05dd3994c0 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp @@ -29,6 +29,7 @@ #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "utils/math_utils.hpp" @@ -177,14 +178,14 @@ template struct LessOutputType td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, - std::complex, + exprm_ns::complex, bool>, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, - std::complex, + exprm_ns::complex, bool>, td_ns::DefaultResultEntry>::result_type; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less_equal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less_equal.hpp index 5827d350a3..e7a30de4ce 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less_equal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less_equal.hpp @@ -30,6 +30,7 @@ #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "utils/math_utils.hpp" @@ -178,14 +179,14 @@ template struct LessEqualOutputType td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, - std::complex, + exprm_ns::complex, bool>, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, - std::complex, + exprm_ns::complex, bool>, td_ns::DefaultResultEntry>::result_type; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp index 84471a5ef4..f0a64a9b3b 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp @@ -24,7 +24,6 @@ #pragma once #include -#include #include #include #include @@ -102,9 +101,12 @@ template struct LogOutputType td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, std::complex>, - td_ns:: - TypeMapResultEntry, std::complex>, + td_ns::TypeMapResultEntry, + exprm_ns::complex>, + td_ns::TypeMapResultEntry, + exprm_ns::complex>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log10.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log10.hpp index d308c85ac9..202627543d 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log10.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log10.hpp @@ -25,7 +25,6 @@ #pragma once #include -#include #include #include #include @@ -121,9 +120,12 @@ template struct Log10OutputType td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, std::complex>, - td_ns:: - TypeMapResultEntry, std::complex>, + td_ns::TypeMapResultEntry, + exprm_ns::complex>, + td_ns::TypeMapResultEntry, + exprm_ns::complex>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp index b8d993dd94..941326c148 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp @@ -30,6 +30,7 @@ #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -78,8 +79,11 @@ template struct Log1pFunctor // = log1p(x^2 + 2x + y^2) / 2 // + I * atan2(y, x + 1) using realT = typename argT::value_type; - const realT x = std::real(in); - const realT y = std::imag(in); + using realT = typename argT::value_type; + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); // imaginary part of result const realT res_im = sycl::atan2(y, x + 1); @@ -126,9 +130,12 @@ template struct Log1pOutputType td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, std::complex>, - td_ns:: - TypeMapResultEntry, std::complex>, + td_ns::TypeMapResultEntry, + exprm_ns::complex>, + td_ns::TypeMapResultEntry, + exprm_ns::complex>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log2.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log2.hpp index 42c837cfa3..a90ea08c7b 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log2.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log2.hpp @@ -25,7 +25,6 @@ #pragma once #include -#include #include #include #include @@ -122,9 +121,12 @@ template struct Log2OutputType td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, std::complex>, - td_ns:: - TypeMapResultEntry, std::complex>, + td_ns::TypeMapResultEntry, + exprm_ns::complex>, + td_ns::TypeMapResultEntry, + exprm_ns::complex>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_and.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_and.hpp index 0aa1f61b90..7c4da12e5c 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_and.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_and.hpp @@ -30,6 +30,7 @@ #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "utils/offset_utils.hpp" @@ -148,14 +149,14 @@ template struct LogicalAndOutputType td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, - std::complex, + exprm_ns::complex, bool>, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, - std::complex, + exprm_ns::complex, bool>, td_ns::DefaultResultEntry>::result_type; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_or.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_or.hpp index 1fdcd84f60..85c3c3d749 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_or.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_or.hpp @@ -30,6 +30,7 @@ #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "utils/offset_utils.hpp" @@ -147,14 +148,14 @@ template struct LogicalOrOutputType td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, - std::complex, + exprm_ns::complex, bool>, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, - std::complex, + exprm_ns::complex, bool>, td_ns::DefaultResultEntry>::result_type; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_xor.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_xor.hpp index 0ef3b17dff..a3a4337430 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_xor.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_xor.hpp @@ -30,6 +30,7 @@ #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "utils/offset_utils.hpp" @@ -149,14 +150,14 @@ template struct LogicalXorOutputType td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, - std::complex, + exprm_ns::complex, bool>, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, - std::complex, + exprm_ns::complex, bool>, td_ns::DefaultResultEntry>::result_type; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/maximum.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/maximum.hpp index 799cbb1d8c..78b507c934 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/maximum.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/maximum.hpp @@ -29,6 +29,7 @@ #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "utils/math_utils.hpp" @@ -181,15 +182,15 @@ template struct MaximumOutputType td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, - std::complex, - std::complex>, + exprm_ns::complex, + exprm_ns::complex>, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, - std::complex, - std::complex>, + exprm_ns::complex, + exprm_ns::complex>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/minimum.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/minimum.hpp index 9a672e539f..cb842331d3 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/minimum.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/minimum.hpp @@ -29,6 +29,7 @@ #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "utils/math_utils.hpp" @@ -181,15 +182,15 @@ template struct MinimumOutputType td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, - std::complex, - std::complex>, + exprm_ns::complex, + exprm_ns::complex>, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, - std::complex, - std::complex>, + exprm_ns::complex, + exprm_ns::complex>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp index ca24383b44..49fa4c306a 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp @@ -24,7 +24,6 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include @@ -173,15 +172,15 @@ template struct MultiplyOutputType td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, - std::complex, - std::complex>, + exprm_ns::complex, + exprm_ns::complex>, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, - std::complex, - std::complex>, + exprm_ns::complex, + exprm_ns::complex>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; @@ -475,13 +474,13 @@ template struct MultiplyInplaceTypePairSupport td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, + exprm_ns::complex, resTy, - std::complex>, + exprm_ns::complex>, td_ns::TypePairDefinedEntry, + exprm_ns::complex, resTy, - std::complex>, + exprm_ns::complex>, // fall-through td_ns::NotDefinedEntry>::is_defined; }; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/negative.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/negative.hpp index 47707a5f04..35c49cdbb2 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/negative.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/negative.hpp @@ -30,6 +30,7 @@ #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -93,8 +94,8 @@ template struct NegativeOutputType td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry>, - td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/not_equal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/not_equal.hpp index f21bc678fd..99be19a8dd 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/not_equal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/not_equal.hpp @@ -29,6 +29,7 @@ #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "utils/offset_utils.hpp" @@ -162,14 +163,14 @@ template struct NotEqualOutputType td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, - std::complex, + exprm_ns::complex, bool>, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, - std::complex, + exprm_ns::complex, bool>, td_ns::DefaultResultEntry>::result_type; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/positive.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/positive.hpp index df6c04021f..f9fe225b5e 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/positive.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/positive.hpp @@ -30,6 +30,7 @@ #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -108,8 +109,8 @@ template struct PositiveOutputType td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry>, - td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp index d7b0ed909e..0301d15bff 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp @@ -225,15 +225,15 @@ template struct PowOutputType td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, - std::complex, - std::complex>, + exprm_ns::complex, + exprm_ns::complex>, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, - std::complex, - std::complex>, + exprm_ns::complex, + exprm_ns::complex>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; @@ -485,13 +485,13 @@ template struct PowInplaceTypePairSupport td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, + exprm_ns::complex, resTy, - std::complex>, + exprm_ns::complex>, td_ns::TypePairDefinedEntry, + exprm_ns::complex, resTy, - std::complex>, + exprm_ns::complex>, // fall-through td_ns::NotDefinedEntry>::is_defined; }; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp index df5edface1..d12ffadcc6 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp @@ -25,13 +25,13 @@ #pragma once #include -#include #include #include #include #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -70,8 +70,11 @@ template struct ProjFunctor resT operator()(const argT &in) const { using realT = typename argT::value_type; - const realT x = std::real(in); - const realT y = std::imag(in); + using realT = typename argT::value_type; + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); if (std::isinf(x)) { return value_at_infinity(y); @@ -85,10 +88,11 @@ template struct ProjFunctor } private: - template std::complex value_at_infinity(const T &y) const + template + exprm_ns::complex value_at_infinity(const T &y) const { const T res_im = sycl::copysign(T(0), y); - return std::complex{std::numeric_limits::infinity(), res_im}; + return exprm_ns::complex{std::numeric_limits::infinity(), res_im}; } }; @@ -112,8 +116,8 @@ using ProjStridedFunctor = elementwise_common:: template struct ProjOutputType { using value_type = typename std::disjunction< - td_ns::TypeMapResultEntry>, - td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; @@ -166,7 +170,7 @@ template struct ProjContigFactory return fn; } else { - if constexpr (std::is_same_v>) { + if constexpr (std::is_same_v>) { fnT fn = proj_contig_impl; return fn; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp index bb22352907..7258a29b56 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp @@ -31,6 +31,7 @@ #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -53,6 +54,7 @@ using dpctl::tensor::ssize_t; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; +using dpctl::tensor::type_utils::is_complex_v; template struct RealFunctor { @@ -69,8 +71,10 @@ template struct RealFunctor resT operator()(const argT &in) const { - if constexpr (is_complex::value) { - return std::real(in); + if constexpr (is_complex_v) { + using realT = typename argT::value_type; + using sycl_complexT = typename exprm_ns::complex; + return exprm_ns::real(sycl_complexT(in)); } else { static_assert(std::is_same_v); @@ -111,8 +115,8 @@ template struct RealOutputType td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, float>, - td_ns::TypeMapResultEntry, double>, + td_ns::TypeMapResultEntry, float>, + td_ns::TypeMapResultEntry, double>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; @@ -173,7 +177,7 @@ template struct RealContigFactory template struct RealTypeMapFactory { - /*! @brief get typeid for output type of std::real(T x) */ + /*! @brief get typeid for output type of real(T x) */ std::enable_if_t::value, int> get() { using rT = typename RealOutputType::value_type; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/reciprocal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/reciprocal.hpp index 0e46acba39..bbae8abbce 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/reciprocal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/reciprocal.hpp @@ -26,7 +26,6 @@ #pragma once #include -#include #include #include #include @@ -108,8 +107,8 @@ template struct ReciprocalOutputType td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry>, - td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp index 7fbb20ae32..7a5857a6b1 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp @@ -29,6 +29,7 @@ #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -67,14 +68,15 @@ template struct RoundFunctor resT operator()(const argT &in) const { - if constexpr (std::is_integral_v) { return in; } else if constexpr (is_complex::value) { using realT = typename argT::value_type; - return resT{round_func(std::real(in)), - round_func(std::imag(in))}; + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + return resT{round_func(exprm_ns::real(z)), + round_func(exprm_ns::imag(z))}; } else { return round_func(in); @@ -119,8 +121,8 @@ template struct RoundOutputType td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry>, - td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sign.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sign.hpp index baa224942f..4c1be9b568 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sign.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sign.hpp @@ -31,6 +31,7 @@ #include #include "cabs_impl.hpp" +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -131,8 +132,8 @@ template struct SignOutputType td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry>, - td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp index e075a90a88..7ada68fb12 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp @@ -72,8 +72,11 @@ template struct SinFunctor constexpr realT q_nan = std::numeric_limits::quiet_NaN(); - realT const &in_re = std::real(in); - realT const &in_im = std::imag(in); + using realT = typename argT::value_type; + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + realT const &in_re = exprm_ns::real(z); + realT const &in_im = exprm_ns::imag(z); const bool in_re_finite = std::isfinite(in_re); const bool in_im_finite = std::isfinite(in_im); @@ -82,8 +85,7 @@ template struct SinFunctor * real and imaginary parts of input are finite. */ if (in_re_finite && in_im_finite) { - resT res = - exprm_ns::sin(exprm_ns::complex(in)); // sin(in); + resT res = exprm_ns::sin(z); // sin(z); if (in_re == realT(0)) { res.real(sycl::copysign(realT(0), in_re)); } @@ -91,9 +93,9 @@ template struct SinFunctor } /* - * since sin(in) = -I * sinh(I * in), for special cases, - * we calculate real and imaginary parts of z = sinh(I * in) and - * then return { imag(z) , -real(z) } which is sin(in). + * since sin(z) = -I * sinh(I * z), for special cases, + * we calculate real and imaginary parts of z = sinh(I * z) and + * then return { imag(z) , -real(z) } which is sin(z). */ const realT x = -in_im; const realT y = in_re; @@ -210,8 +212,8 @@ template struct SinOutputType td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry>, - td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp index 23b3588a3b..cfe0f0c11e 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp @@ -70,9 +70,10 @@ template struct SinhFunctor { if constexpr (is_complex::value) { using realT = typename argT::value_type; - - const realT x = std::real(in); - const realT y = std::imag(in); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); const bool xfinite = std::isfinite(x); const bool yfinite = std::isfinite(y); @@ -82,7 +83,7 @@ template struct SinhFunctor * real and imaginary parts of input are finite. */ if (xfinite && yfinite) { - return exprm_ns::sinh(exprm_ns::complex(in)); + return exprm_ns::sinh(z); } /* * sinh(+-0 +- I Inf) = sign(d(+-0, dNaN))0 + I dNaN. @@ -179,8 +180,8 @@ template struct SinhOutputType td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry>, - td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp index b83ff72495..726432cac9 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp @@ -25,7 +25,6 @@ #pragma once #include -#include #include #include #include @@ -104,9 +103,12 @@ template struct SqrtOutputType td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry, std::complex>, - td_ns:: - TypeMapResultEntry, std::complex>, + td_ns::TypeMapResultEntry, + exprm_ns::complex>, + td_ns::TypeMapResultEntry, + exprm_ns::complex>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/square.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/square.hpp index f9d9d848c0..8b09b81884 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/square.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/square.hpp @@ -130,8 +130,8 @@ template struct SquareOutputType td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry>, - td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp index 51a3955142..0e7cba08d5 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp @@ -29,6 +29,7 @@ #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "utils/offset_utils.hpp" @@ -159,15 +160,15 @@ template struct SubtractOutputType td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, - std::complex, - std::complex>, + exprm_ns::complex, + exprm_ns::complex>, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, - std::complex, - std::complex>, + exprm_ns::complex, + exprm_ns::complex>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; @@ -473,13 +474,13 @@ template struct SubtractInplaceTypePairSupport td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, + exprm_ns::complex, resTy, - std::complex>, + exprm_ns::complex>, td_ns::TypePairDefinedEntry, + exprm_ns::complex, resTy, - std::complex>, + exprm_ns::complex>, // fall-through td_ns::NotDefinedEntry>::is_defined; }; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp index 770518f918..a215eb0f8d 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp @@ -24,7 +24,6 @@ #pragma once #include -#include #include #include #include @@ -75,12 +74,14 @@ template struct TanFunctor constexpr realT q_nan = std::numeric_limits::quiet_NaN(); /* - * since tan(in) = -I * tanh(I * in), for special cases, - * we calculate real and imaginary parts of z = tanh(I * in) and - * return { imag(z) , -real(z) } which is tan(in). + * since tan(z) = -I * tanh(I * z), for special cases, + * we calculate real and imaginary parts of z = tanh(I * z) and + * return { imag(z) , -real(z) } which is tan(z). */ - const realT x = -std::imag(in); - const realT y = std::real(in); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const realT x = -exprm_ns::imag(z); + const realT y = exprm_ns::real(z); /* * tanh(NaN + i 0) = NaN + i 0 * @@ -121,7 +122,7 @@ template struct TanFunctor return resT{q_nan, q_nan}; } /* ordinary cases */ - return exprm_ns::tan(exprm_ns::complex(in)); // tan(in); + return exprm_ns::tan(z); // tan(z); } else { static_assert(std::is_floating_point_v || @@ -154,8 +155,8 @@ template struct TanOutputType td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry>, - td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp index 1d06fd3c4f..2e57dc15f5 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp @@ -25,7 +25,6 @@ #pragma once #include -#include #include #include #include @@ -75,8 +74,10 @@ template struct TanhFunctor constexpr realT q_nan = std::numeric_limits::quiet_NaN(); - const realT x = std::real(in); - const realT y = std::imag(in); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); /* * tanh(NaN + i 0) = NaN + i 0 * @@ -115,7 +116,7 @@ template struct TanhFunctor return resT{q_nan, q_nan}; } /* ordinary cases */ - return exprm_ns::tanh(exprm_ns::complex(in)); // tanh(in); + return exprm_ns::tanh(z); // tanh(z); } else { static_assert(std::is_floating_point_v || @@ -148,8 +149,8 @@ template struct TanhOutputType td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, - td_ns::TypeMapResultEntry>, - td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp index de6c9a8723..b3f9657ddb 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp @@ -146,35 +146,35 @@ template struct TrueDivideOutputType td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, - std::complex, - std::complex>, + exprm_ns::complex, + exprm_ns::complex>, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, float, - std::complex>, + exprm_ns::complex>, td_ns::BinaryTypeMapResultEntry, - std::complex>, + exprm_ns::complex, + exprm_ns::complex>, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, - std::complex, - std::complex>, + exprm_ns::complex, + exprm_ns::complex>, td_ns::BinaryTypeMapResultEntry, - std::complex>, + exprm_ns::complex, + exprm_ns::complex>, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, double, - std::complex>, + exprm_ns::complex>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; @@ -470,16 +470,20 @@ struct TrueDivideInplaceTypePairSupport td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry>, + td_ns:: + TypePairDefinedEntry>, td_ns::TypePairDefinedEntry, + exprm_ns::complex, resTy, - std::complex>, - td_ns::TypePairDefinedEntry>, + exprm_ns::complex>, td_ns::TypePairDefinedEntry, + double, resTy, - std::complex>, + exprm_ns::complex>, + td_ns::TypePairDefinedEntry, + resTy, + exprm_ns::complex>, // fall-through td_ns::NotDefinedEntry>::is_defined; }; diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp index 71e2c15b6b..a430ddbd6d 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp @@ -1014,7 +1014,7 @@ sycl::event dot_product_tree_impl(sycl::queue &exec_q, } else { constexpr resTy identity_val = - sycl::known_identity::value; + su_ns::Identity::value; // more than one work-groups is needed, requires a temporary std::size_t reduction_groups = @@ -1261,7 +1261,7 @@ dot_product_contig_tree_impl(sycl::queue &exec_q, } else { constexpr resTy identity_val = - sycl::known_identity::value; + su_ns::Identity::value; // more than one work-groups is needed, requires a temporary std::size_t reduction_groups = diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp index 4ad4eb142a..6213bbf977 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -24,6 +24,9 @@ #pragma once +#define SYCL_EXT_ONEAPI_COMPLEX +#include + #include #include #include @@ -47,6 +50,8 @@ namespace tensor namespace kernels { +namespace exprm_ns = sycl::ext::oneapi::experimental; + using dpctl::tensor::ssize_t; namespace gemm_detail @@ -1187,14 +1192,14 @@ struct GemmBatchFunctorThreadNM_vecm_HyperParametersSelector } else if constexpr (sizeof(resT) == 8) { // 8 * 2 * 1 * 4 == 64 - if constexpr (std::is_same_v>) { + if constexpr (std::is_same_v>) { return GemmBatchFunctorThreadNM_vecm_HyperParameters(2, 4, 1); } else { return GemmBatchFunctorThreadNM_vecm_HyperParameters(2, 1, 4); } } - else if constexpr (std::is_same_v>) { + else if constexpr (std::is_same_v>) { // 16 * 2 * 2 * 1 == 64 return GemmBatchFunctorThreadNM_vecm_HyperParameters(2, 2, 1); } @@ -2316,7 +2321,7 @@ gemm_batch_tree_k_impl(sycl::queue &exec_q, sycl::logical_or, sycl::plus>::type; constexpr resTy identity_val = - sycl::known_identity::value; + su_ns::Identity::value; std::size_t iter_nelems = batch_nelems * n * m; std::size_t reduction_nelems = @@ -2612,7 +2617,7 @@ gemm_batch_tree_nm_impl(sycl::queue &exec_q, sycl::logical_or, sycl::plus>::type; constexpr resTy identity_val = - sycl::known_identity::value; + su_ns::Identity::value; std::size_t iter_nelems = batch_nelems * n * m; std::size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; @@ -2985,7 +2990,7 @@ gemm_batch_contig_tree_k_impl(sycl::queue &exec_q, sycl::logical_or, sycl::plus>::type; constexpr resTy identity_val = - sycl::known_identity::value; + su_ns::Identity::value; std::size_t iter_nelems = batch_nelems * n * m; std::size_t reduction_nelems = @@ -3173,7 +3178,7 @@ gemm_batch_contig_tree_nm_impl(sycl::queue &exec_q, sycl::logical_or, sycl::plus>::type; constexpr resTy identity_val = - sycl::known_identity::value; + su_ns::Identity::value; std::size_t iter_nelems = batch_nelems * n * m; std::size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; @@ -3544,7 +3549,7 @@ sycl::event gemm_tree_k_impl(sycl::queue &exec_q, sycl::logical_or, sycl::plus>::type; constexpr resTy identity_val = - sycl::known_identity::value; + su_ns::Identity::value; std::size_t iter_nelems = n * m; std::size_t reduction_nelems = @@ -3698,7 +3703,7 @@ sycl::event gemm_tree_nm_impl(sycl::queue &exec_q, sycl::logical_or, sycl::plus>::type; constexpr resTy identity_val = - sycl::known_identity::value; + su_ns::Identity::value; std::size_t iter_nelems = n * m; std::size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; @@ -3934,7 +3939,7 @@ sycl::event gemm_contig_tree_k_impl(sycl::queue &exec_q, sycl::logical_or, sycl::plus>::type; constexpr resTy identity_val = - sycl::known_identity::value; + su_ns::Identity::value; std::size_t iter_nelems = n * m; std::size_t reduction_nelems = @@ -4073,7 +4078,7 @@ sycl::event gemm_contig_tree_nm_impl(sycl::queue &exec_q, sycl::logical_or, sycl::plus>::type; constexpr resTy identity_val = - sycl::known_identity::value; + su_ns::Identity::value; std::size_t iter_nelems = n * m; std::size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; diff --git a/dpctl/tensor/libtensor/include/kernels/reductions.hpp b/dpctl/tensor/libtensor/include/kernels/reductions.hpp index f056d246c9..fdbcabd721 100644 --- a/dpctl/tensor/libtensor/include/kernels/reductions.hpp +++ b/dpctl/tensor/libtensor/include/kernels/reductions.hpp @@ -40,6 +40,9 @@ #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" +#define SYCL_EXT_ONEAPI_COMPLEX +#include + namespace dpctl { namespace tensor @@ -50,6 +53,7 @@ namespace kernels using dpctl::tensor::ssize_t; namespace td_ns = dpctl::tensor::type_dispatch; namespace su_ns = dpctl::tensor::sycl_utils; +namespace exprm_ns = sycl::ext::oneapi::experimental; namespace reduction_detail { @@ -1914,8 +1918,8 @@ struct SequentialSearchReduction using dpctl::tensor::math_utils::less_complex; // less_complex always returns false for NaNs, so check if (less_complex(val, red_val) || - std::isnan(std::real(val)) || - std::isnan(std::imag(val))) + std::isnan(exprm_ns::real(val)) || + std::isnan(exprm_ns::imag(val))) { red_val = val; idx_val = static_cast(m); @@ -1941,8 +1945,8 @@ struct SequentialSearchReduction if constexpr (is_complex::value) { using dpctl::tensor::math_utils::greater_complex; if (greater_complex(val, red_val) || - std::isnan(std::real(val)) || - std::isnan(std::imag(val))) + std::isnan(exprm_ns::real(val)) || + std::isnan(exprm_ns::imag(val))) { red_val = val; idx_val = static_cast(m); @@ -2230,8 +2234,8 @@ struct CustomSearchReduction // less_complex always returns false for NaNs, so // check if (less_complex(val, local_red_val) || - std::isnan(std::real(val)) || - std::isnan(std::imag(val))) + std::isnan(exprm_ns::real(val)) || + std::isnan(exprm_ns::imag(val))) { local_red_val = val; if constexpr (!First) { @@ -2277,8 +2281,8 @@ struct CustomSearchReduction if constexpr (is_complex::value) { using dpctl::tensor::math_utils::greater_complex; if (greater_complex(val, local_red_val) || - std::isnan(std::real(val)) || - std::isnan(std::imag(val))) + std::isnan(exprm_ns::real(val)) || + std::isnan(exprm_ns::imag(val))) { local_red_val = val; if constexpr (!First) { @@ -2330,8 +2334,8 @@ struct CustomSearchReduction if constexpr (is_complex::value) { // equality does not hold for NaNs, so check here local_idx = (red_val_over_wg == local_red_val || - std::isnan(std::real(local_red_val)) || - std::isnan(std::imag(local_red_val))) + std::isnan(exprm_ns::real(local_red_val)) || + std::isnan(exprm_ns::imag(local_red_val))) ? local_idx : idx_identity_; } diff --git a/dpctl/tensor/libtensor/include/utils/math_utils.hpp b/dpctl/tensor/libtensor/include/utils/math_utils.hpp index a49b56b6ba..ecd9c1fe18 100644 --- a/dpctl/tensor/libtensor/include/utils/math_utils.hpp +++ b/dpctl/tensor/libtensor/include/utils/math_utils.hpp @@ -24,7 +24,8 @@ #pragma once #include -#include +#define SYCL_EXT_ONEAPI_COMPLEX +#include #include namespace dpctl @@ -34,13 +35,18 @@ namespace tensor namespace math_utils { +namespace exprm_ns = sycl::ext::oneapi::experimental; + template bool less_complex(const T &x1, const T &x2) { using realT = typename T::value_type; - realT real1 = std::real(x1); - realT real2 = std::real(x2); - realT imag1 = std::imag(x1); - realT imag2 = std::imag(x2); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z1 = sycl_complexT(x1); + sycl_complexT z2 = sycl_complexT(x2); + realT real1 = exprm_ns::real(z1); + realT real2 = exprm_ns::real(z2); + realT imag1 = exprm_ns::imag(z1); + realT imag2 = exprm_ns::imag(z2); return (real1 == real2) ? (imag1 < imag2) @@ -50,10 +56,13 @@ template bool less_complex(const T &x1, const T &x2) template bool greater_complex(const T &x1, const T &x2) { using realT = typename T::value_type; - realT real1 = std::real(x1); - realT real2 = std::real(x2); - realT imag1 = std::imag(x1); - realT imag2 = std::imag(x2); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z1 = sycl_complexT(x1); + sycl_complexT z2 = sycl_complexT(x2); + realT real1 = exprm_ns::real(z1); + realT real2 = exprm_ns::real(z2); + realT imag1 = exprm_ns::imag(z1); + realT imag2 = exprm_ns::imag(z2); return (real1 == real2) ? (imag1 > imag2) @@ -63,10 +72,13 @@ template bool greater_complex(const T &x1, const T &x2) template bool less_equal_complex(const T &x1, const T &x2) { using realT = typename T::value_type; - realT real1 = std::real(x1); - realT real2 = std::real(x2); - realT imag1 = std::imag(x1); - realT imag2 = std::imag(x2); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z1 = sycl_complexT(x1); + sycl_complexT z2 = sycl_complexT(x2); + realT real1 = exprm_ns::real(z1); + realT real2 = exprm_ns::real(z2); + realT imag1 = exprm_ns::imag(z1); + realT imag2 = exprm_ns::imag(z2); return (real1 == real2) ? (imag1 <= imag2) @@ -76,10 +88,13 @@ template bool less_equal_complex(const T &x1, const T &x2) template bool greater_equal_complex(const T &x1, const T &x2) { using realT = typename T::value_type; - realT real1 = std::real(x1); - realT real2 = std::real(x2); - realT imag1 = std::imag(x1); - realT imag2 = std::imag(x2); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z1 = sycl_complexT(x1); + sycl_complexT z2 = sycl_complexT(x2); + realT real1 = exprm_ns::real(z1); + realT real2 = exprm_ns::real(z2); + realT imag1 = exprm_ns::imag(z1); + realT imag2 = exprm_ns::imag(z2); return (real1 == real2) ? (imag1 >= imag2) @@ -89,10 +104,13 @@ template bool greater_equal_complex(const T &x1, const T &x2) template T max_complex(const T &x1, const T &x2) { using realT = typename T::value_type; - realT real1 = std::real(x1); - realT real2 = std::real(x2); - realT imag1 = std::imag(x1); - realT imag2 = std::imag(x2); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z1 = sycl_complexT(x1); + sycl_complexT z2 = sycl_complexT(x2); + realT real1 = exprm_ns::real(z1); + realT real2 = exprm_ns::real(z2); + realT imag1 = exprm_ns::imag(z1); + realT imag2 = exprm_ns::imag(z2); bool isnan_imag1 = std::isnan(imag1); bool gt = (real1 == real2) @@ -104,10 +122,13 @@ template T max_complex(const T &x1, const T &x2) template T min_complex(const T &x1, const T &x2) { using realT = typename T::value_type; - realT real1 = std::real(x1); - realT real2 = std::real(x2); - realT imag1 = std::imag(x1); - realT imag2 = std::imag(x2); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z1 = sycl_complexT(x1); + sycl_complexT z2 = sycl_complexT(x2); + realT real1 = exprm_ns::real(z1); + realT real2 = exprm_ns::real(z2); + realT imag1 = exprm_ns::imag(z1); + realT imag2 = exprm_ns::imag(z2); bool isnan_imag1 = std::isnan(imag1); bool lt = (real1 == real2) diff --git a/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp b/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp index ece8852643..78837b51c6 100644 --- a/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp +++ b/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp @@ -32,12 +32,18 @@ #include "math_utils.hpp" +#define SYCL_EXT_ONEAPI_COMPLEX +#include + namespace dpctl { namespace tensor { namespace sycl_utils { + +namespace exprm_ns = sycl::ext::oneapi::experimental; + namespace detail { @@ -79,7 +85,7 @@ template struct IsContained : std::false_type template struct IsComplex : std::false_type { }; -template struct IsComplex> : std::true_type +template struct IsComplex> : std::true_type { }; @@ -380,11 +386,12 @@ struct GetIdentity::value>> template struct GetIdentity, - std::enable_if_t, Op>::value>> + exprm_ns::complex, + std::enable_if_t, Op>::value>> { - static constexpr std::complex value{-std::numeric_limits::infinity(), - -std::numeric_limits::infinity()}; + static constexpr exprm_ns::complex value{ + -std::numeric_limits::infinity(), + -std::numeric_limits::infinity()}; }; // Minimum @@ -413,11 +420,11 @@ struct GetIdentity::value>> template struct GetIdentity, - std::enable_if_t, Op>::value>> + exprm_ns::complex, + std::enable_if_t, Op>::value>> { - static constexpr std::complex value{std::numeric_limits::infinity(), - std::numeric_limits::infinity()}; + static constexpr exprm_ns::complex value{ + std::numeric_limits::infinity(), std::numeric_limits::infinity()}; }; // Plus @@ -429,6 +436,15 @@ using IsPlus = std::bool_constant> || template using IsSyclPlus = std::bool_constant>>; +template +struct GetIdentity, + std::enable_if_t, Op>::value>> +{ + static constexpr exprm_ns::complex value{static_cast(0), + static_cast(0)}; +}; + // Multiplies template diff --git a/dpctl/tensor/libtensor/include/utils/type_dispatch_building.hpp b/dpctl/tensor/libtensor/include/utils/type_dispatch_building.hpp index 4394d9a4b1..a0db2f4dc5 100644 --- a/dpctl/tensor/libtensor/include/utils/type_dispatch_building.hpp +++ b/dpctl/tensor/libtensor/include/utils/type_dispatch_building.hpp @@ -29,6 +29,8 @@ #include #include +#define SYCL_EXT_ONEAPI_COMPLEX +#include #include namespace dpctl @@ -39,6 +41,8 @@ namespace tensor namespace type_dispatch { +namespace exprm_ns = sycl::ext::oneapi::experimental; + enum class typenum_t : int { BOOL = 0, @@ -81,8 +85,8 @@ class DispatchTableBuilder factory{}.get(), factory{}.get(), factory{}.get(), - factory>{}.get(), - factory>{}.get()}; + factory>{}.get(), + factory>{}.get()}; assert(per_dstTy.size() == _num_types); return per_dstTy; } @@ -93,20 +97,21 @@ class DispatchTableBuilder void populate_dispatch_table(funcPtrT table[][_num_types]) const { - const auto map_by_dst_type = {row_per_dst_type(), - row_per_dst_type(), - row_per_dst_type(), - row_per_dst_type(), - row_per_dst_type(), - row_per_dst_type(), - row_per_dst_type(), - row_per_dst_type(), - row_per_dst_type(), - row_per_dst_type(), - row_per_dst_type(), - row_per_dst_type(), - row_per_dst_type>(), - row_per_dst_type>()}; + const auto map_by_dst_type = { + row_per_dst_type(), + row_per_dst_type(), + row_per_dst_type(), + row_per_dst_type(), + row_per_dst_type(), + row_per_dst_type(), + row_per_dst_type(), + row_per_dst_type(), + row_per_dst_type(), + row_per_dst_type(), + row_per_dst_type(), + row_per_dst_type(), + row_per_dst_type>(), + row_per_dst_type>()}; assert(map_by_dst_type.size() == _num_types); int dst_id = 0; for (const auto &row : map_by_dst_type) { @@ -139,20 +144,21 @@ class DispatchVectorBuilder void populate_dispatch_vector(funcPtrT vector[]) const { - const auto fn_map_by_type = {func_per_type(), - func_per_type(), - func_per_type(), - func_per_type(), - func_per_type(), - func_per_type(), - func_per_type(), - func_per_type(), - func_per_type(), - func_per_type(), - func_per_type(), - func_per_type(), - func_per_type>(), - func_per_type>()}; + const auto fn_map_by_type = { + func_per_type(), + func_per_type(), + func_per_type(), + func_per_type(), + func_per_type(), + func_per_type(), + func_per_type(), + func_per_type(), + func_per_type(), + func_per_type(), + func_per_type(), + func_per_type(), + func_per_type>(), + func_per_type>()}; assert(fn_map_by_type.size() == _num_types); int ty_id = 0; for (const auto &fn : fn_map_by_type) { @@ -229,10 +235,10 @@ template struct GetTypeid else if constexpr (std::is_same_v) { return static_cast(typenum_t::DOUBLE); } - else if constexpr (std::is_same_v>) { + else if constexpr (std::is_same_v>) { return static_cast(typenum_t::CFLOAT); } - else if constexpr (std::is_same_v>) { + else if constexpr (std::is_same_v>) { return static_cast(typenum_t::CDOUBLE); } else if constexpr (std::is_same_v) { // special token diff --git a/dpctl/tensor/libtensor/include/utils/type_utils.hpp b/dpctl/tensor/libtensor/include/utils/type_utils.hpp index e9493aebfe..24c04c3c7e 100644 --- a/dpctl/tensor/libtensor/include/utils/type_utils.hpp +++ b/dpctl/tensor/libtensor/include/utils/type_utils.hpp @@ -29,6 +29,8 @@ #include #include #include +#define SYCL_EXT_ONEAPI_COMPLEX +#include namespace dpctl { @@ -37,6 +39,8 @@ namespace tensor namespace type_utils { +namespace exprm_ns = sycl::ext::oneapi::experimental; + template struct is_complex : public std::false_type { @@ -45,8 +49,9 @@ struct is_complex : public std::false_type template struct is_complex< T, - std::enable_if_t, std::complex> || - std::is_same_v, std::complex>>> + std::enable_if_t< + std::is_same_v, exprm_ns::complex> || + std::is_same_v, exprm_ns::complex>>> : public std::true_type { }; @@ -91,7 +96,7 @@ template void validate_type_for_device(const sycl::device &d) " does not support type 'float64'"); } } - else if constexpr (std::is_same_v>) { + else if constexpr (std::is_same_v>) { if (!d.has(sycl::aspect::fp64)) { throw std::runtime_error("Device " + d.get_info() + diff --git a/dpctl/tensor/libtensor/source/accumulators/cumulative_prod.cpp b/dpctl/tensor/libtensor/source/accumulators/cumulative_prod.cpp index 045b1b330e..d2fb618a29 100644 --- a/dpctl/tensor/libtensor/source/accumulators/cumulative_prod.cpp +++ b/dpctl/tensor/libtensor/source/accumulators/cumulative_prod.cpp @@ -35,6 +35,9 @@ #include "utils/sycl_utils.hpp" #include "utils/type_dispatch_building.hpp" +#define SYCL_EXT_ONEAPI_COMPLEX +#include + namespace py = pybind11; namespace dpctl @@ -46,6 +49,7 @@ namespace py_internal namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; namespace impl { @@ -117,16 +121,16 @@ struct TypePairSupportDataForProdAccumulation // input double td_ns::TypePairDefinedEntry, - // input std::complex + // input exprm_ns::complex td_ns::TypePairDefinedEntry, + exprm_ns::complex, outTy, - std::complex>, + exprm_ns::complex>, td_ns::TypePairDefinedEntry, + exprm_ns::complex, outTy, - std::complex>, + exprm_ns::complex>, // fall-through td_ns::NotDefinedEntry>::is_defined; diff --git a/dpctl/tensor/libtensor/source/accumulators/cumulative_sum.cpp b/dpctl/tensor/libtensor/source/accumulators/cumulative_sum.cpp index e44678e15f..9227b6f238 100644 --- a/dpctl/tensor/libtensor/source/accumulators/cumulative_sum.cpp +++ b/dpctl/tensor/libtensor/source/accumulators/cumulative_sum.cpp @@ -35,6 +35,9 @@ #include "utils/sycl_utils.hpp" #include "utils/type_dispatch_building.hpp" +#define SYCL_EXT_ONEAPI_COMPLEX +#include + namespace py = pybind11; namespace dpctl @@ -46,6 +49,7 @@ namespace py_internal namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; namespace impl { @@ -117,16 +121,16 @@ struct TypePairSupportDataForSumAccumulation // input double td_ns::TypePairDefinedEntry, - // input std::complex + // input exprm_ns::complex td_ns::TypePairDefinedEntry, + exprm_ns::complex, outTy, - std::complex>, + exprm_ns::complex>, td_ns::TypePairDefinedEntry, + exprm_ns::complex, outTy, - std::complex>, + exprm_ns::complex>, // fall-through td_ns::NotDefinedEntry>::is_defined; diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/true_divide.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/true_divide.cpp index 0e3fb38015..256fbbef8e 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/true_divide.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/true_divide.cpp @@ -48,6 +48,9 @@ #include "kernels/elementwise_functions/common_inplace.hpp" #include "kernels/elementwise_functions/true_divide.hpp" +#define SYCL_EXT_ONEAPI_COMPLEX +#include + namespace py = pybind11; namespace dpctl @@ -58,6 +61,7 @@ namespace py_internal { namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; namespace ew_cmn_ns = dpctl::tensor::kernels::elementwise_common; using ew_cmn_ns::binary_contig_impl_fn_ptr_t; @@ -332,13 +336,13 @@ py_divide_by_scalar(const dpctl::tensor::usm_ndarray &src, } case complex64_typeid: { - fn = divide_by_scalar, float>; + fn = divide_by_scalar, float>; std::ignore = new (scalar_alloc) float(scalar); break; } case complex128_typeid: { - fn = divide_by_scalar, double>; + fn = divide_by_scalar, double>; std::ignore = new (scalar_alloc) double(scalar); break; } diff --git a/dpctl/tensor/libtensor/source/linalg_functions/dot_dispatch.hpp b/dpctl/tensor/libtensor/source/linalg_functions/dot_dispatch.hpp index 2437ed40bb..3047b452ae 100644 --- a/dpctl/tensor/libtensor/source/linalg_functions/dot_dispatch.hpp +++ b/dpctl/tensor/libtensor/source/linalg_functions/dot_dispatch.hpp @@ -32,6 +32,9 @@ #include "kernels/linalg_functions/gemm.hpp" #include "utils/type_dispatch_building.hpp" +#define SYCL_EXT_ONEAPI_COMPLEX +#include + namespace dpctl { namespace tensor @@ -40,6 +43,7 @@ namespace py_internal { namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; template struct DotAtomicOutputType { @@ -146,20 +150,20 @@ template struct DotNoAtomicOutputType td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, - std::complex, - std::complex>, + exprm_ns::complex, + exprm_ns::complex>, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, - std::complex, - std::complex>, + exprm_ns::complex, + exprm_ns::complex>, td_ns::BinaryTypeMapResultEntry, + exprm_ns::complex, T2, - std::complex, - std::complex>, + exprm_ns::complex, + exprm_ns::complex>, td_ns::DefaultResultEntry>::result_type; static constexpr bool is_defined = !std::is_same_v; diff --git a/dpctl/tensor/libtensor/source/reductions/argmax.cpp b/dpctl/tensor/libtensor/source/reductions/argmax.cpp index 2e6bcfddd3..fb5f86e305 100644 --- a/dpctl/tensor/libtensor/source/reductions/argmax.cpp +++ b/dpctl/tensor/libtensor/source/reductions/argmax.cpp @@ -35,6 +35,9 @@ #include "utils/sycl_utils.hpp" #include "utils/type_dispatch_building.hpp" +#define SYCL_EXT_ONEAPI_COMPLEX +#include + namespace py = pybind11; namespace dpctl @@ -46,6 +49,7 @@ namespace py_internal namespace td_ns = dpctl::tensor::type_dispatch; namespace su_ns = dpctl::tensor::sycl_utils; +namespace exprm_ns = sycl::ext::oneapi::experimental; namespace impl { @@ -102,14 +106,14 @@ struct TypePairSupportForArgmaxReductionTemps // input double td_ns::TypePairDefinedEntry, - // input std::complex + // input exprm_ns::complex td_ns::TypePairDefinedEntry, + exprm_ns::complex, outTy, std::int64_t>, td_ns::TypePairDefinedEntry, + exprm_ns::complex, outTy, std::int64_t>, diff --git a/dpctl/tensor/libtensor/source/reductions/argmin.cpp b/dpctl/tensor/libtensor/source/reductions/argmin.cpp index 883ec1d397..c16202748a 100644 --- a/dpctl/tensor/libtensor/source/reductions/argmin.cpp +++ b/dpctl/tensor/libtensor/source/reductions/argmin.cpp @@ -36,6 +36,9 @@ #include "utils/sycl_utils.hpp" #include "utils/type_dispatch_building.hpp" +#define SYCL_EXT_ONEAPI_COMPLEX +#include + namespace py = pybind11; namespace dpctl @@ -47,6 +50,7 @@ namespace py_internal namespace td_ns = dpctl::tensor::type_dispatch; namespace su_ns = dpctl::tensor::sycl_utils; +namespace exprm_ns = sycl::ext::oneapi::experimental; namespace impl { @@ -103,14 +107,14 @@ struct TypePairSupportForArgminReductionTemps // input double td_ns::TypePairDefinedEntry, - // input std::complex + // input exprm_ns::complex td_ns::TypePairDefinedEntry, + exprm_ns::complex, outTy, std::int64_t>, td_ns::TypePairDefinedEntry, + exprm_ns::complex, outTy, std::int64_t>, diff --git a/dpctl/tensor/libtensor/source/reductions/max.cpp b/dpctl/tensor/libtensor/source/reductions/max.cpp index 55fff60f9b..035edbfcd7 100644 --- a/dpctl/tensor/libtensor/source/reductions/max.cpp +++ b/dpctl/tensor/libtensor/source/reductions/max.cpp @@ -37,6 +37,9 @@ #include "reduction_atomic_support.hpp" #include "reduction_over_axis.hpp" +#define SYCL_EXT_ONEAPI_COMPLEX +#include + namespace py = pybind11; namespace dpctl @@ -48,6 +51,7 @@ namespace py_internal namespace td_ns = dpctl::tensor::type_dispatch; namespace su_ns = dpctl::tensor::sycl_utils; +namespace exprm_ns = sycl::ext::oneapi::experimental; namespace impl { @@ -133,16 +137,16 @@ struct TypePairSupportDataForMaxReductionTemps // input double td_ns::TypePairDefinedEntry, - // input std::complex + // input exprm_ns::complex td_ns::TypePairDefinedEntry, + exprm_ns::complex, outTy, - std::complex>, + exprm_ns::complex>, td_ns::TypePairDefinedEntry, + exprm_ns::complex, outTy, - std::complex>, + exprm_ns::complex>, // fall-through td_ns::NotDefinedEntry>::is_defined; diff --git a/dpctl/tensor/libtensor/source/reductions/min.cpp b/dpctl/tensor/libtensor/source/reductions/min.cpp index 1ff5331bf0..7846d41f6e 100644 --- a/dpctl/tensor/libtensor/source/reductions/min.cpp +++ b/dpctl/tensor/libtensor/source/reductions/min.cpp @@ -37,6 +37,9 @@ #include "reduction_atomic_support.hpp" #include "reduction_over_axis.hpp" +#define SYCL_EXT_ONEAPI_COMPLEX +#include + namespace py = pybind11; namespace dpctl @@ -48,6 +51,7 @@ namespace py_internal namespace td_ns = dpctl::tensor::type_dispatch; namespace su_ns = dpctl::tensor::sycl_utils; +namespace exprm_ns = sycl::ext::oneapi::experimental; namespace impl { @@ -133,16 +137,16 @@ struct TypePairSupportDataForMinReductionTemps // input double td_ns::TypePairDefinedEntry, - // input std::complex + // input exprm_ns::complex td_ns::TypePairDefinedEntry, + exprm_ns::complex, outTy, - std::complex>, + exprm_ns::complex>, td_ns::TypePairDefinedEntry, + exprm_ns::complex, outTy, - std::complex>, + exprm_ns::complex>, // fall-through td_ns::NotDefinedEntry>::is_defined; diff --git a/dpctl/tensor/libtensor/source/reductions/prod.cpp b/dpctl/tensor/libtensor/source/reductions/prod.cpp index 7c768ce179..db136eafd8 100644 --- a/dpctl/tensor/libtensor/source/reductions/prod.cpp +++ b/dpctl/tensor/libtensor/source/reductions/prod.cpp @@ -36,6 +36,9 @@ #include "reduction_atomic_support.hpp" #include "reduction_over_axis.hpp" +#define SYCL_EXT_ONEAPI_COMPLEX +#include + namespace py = pybind11; namespace dpctl @@ -46,6 +49,7 @@ namespace py_internal { namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; namespace impl { @@ -193,37 +197,46 @@ struct TypePairSupportDataForProductReductionTemps td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, - td_ns:: - TypePairDefinedEntry>, td_ns::TypePairDefinedEntry>, + exprm_ns::complex>, + td_ns::TypePairDefinedEntry>, // input float td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry>, - td_ns::TypePairDefinedEntry>, + td_ns:: + TypePairDefinedEntry>, + td_ns::TypePairDefinedEntry>, // input double td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry>, + td_ns::TypePairDefinedEntry>, - // input std::complex + // input exprm_ns::complex td_ns::TypePairDefinedEntry, + exprm_ns::complex, outTy, - std::complex>, + exprm_ns::complex>, td_ns::TypePairDefinedEntry, + exprm_ns::complex, outTy, - std::complex>, + exprm_ns::complex>, td_ns::TypePairDefinedEntry, + exprm_ns::complex, outTy, - std::complex>, + exprm_ns::complex>, // fall-through td_ns::NotDefinedEntry>::is_defined; diff --git a/dpctl/tensor/libtensor/source/reductions/sum.cpp b/dpctl/tensor/libtensor/source/reductions/sum.cpp index f449a6cde3..2fd5780167 100644 --- a/dpctl/tensor/libtensor/source/reductions/sum.cpp +++ b/dpctl/tensor/libtensor/source/reductions/sum.cpp @@ -36,6 +36,9 @@ #include "reduction_atomic_support.hpp" #include "reduction_over_axis.hpp" +#define SYCL_EXT_ONEAPI_COMPLEX +#include + namespace py = pybind11; namespace dpctl @@ -46,6 +49,7 @@ namespace py_internal { namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; namespace impl { @@ -193,37 +197,46 @@ struct TypePairSupportDataForSumReductionTemps td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, - td_ns:: - TypePairDefinedEntry>, td_ns::TypePairDefinedEntry>, + exprm_ns::complex>, + td_ns::TypePairDefinedEntry>, // input float td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry>, - td_ns::TypePairDefinedEntry>, + td_ns:: + TypePairDefinedEntry>, + td_ns::TypePairDefinedEntry>, // input double td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry>, + td_ns::TypePairDefinedEntry>, - // input std::complex + // input exprm_ns::complex td_ns::TypePairDefinedEntry, + exprm_ns::complex, outTy, - std::complex>, + exprm_ns::complex>, td_ns::TypePairDefinedEntry, + exprm_ns::complex, outTy, - std::complex>, + exprm_ns::complex>, td_ns::TypePairDefinedEntry, + exprm_ns::complex, outTy, - std::complex>, + exprm_ns::complex>, // fall-through td_ns::NotDefinedEntry>::is_defined; diff --git a/dpctl/tensor/libtensor/source/sorting/rich_comparisons.hpp b/dpctl/tensor/libtensor/source/sorting/rich_comparisons.hpp index 2aaa1cfafa..817578811e 100644 --- a/dpctl/tensor/libtensor/source/sorting/rich_comparisons.hpp +++ b/dpctl/tensor/libtensor/source/sorting/rich_comparisons.hpp @@ -24,7 +24,9 @@ #pragma once +#define SYCL_EXT_ONEAPI_COMPLEX #include "sycl/sycl.hpp" +#include #include namespace dpctl @@ -53,6 +55,8 @@ template struct ExtendedRealFPGreater } }; +namespace exprm_ns = sycl::ext::oneapi::experimental; + template struct ExtendedComplexFPLess { /* [(R, R), (R, nan), (nan, R), (nan, nan)] */ @@ -60,15 +64,17 @@ template struct ExtendedComplexFPLess bool operator()(const cT &v1, const cT &v2) const { using realT = typename cT::value_type; - - const realT real1 = std::real(v1); - const realT real2 = std::real(v2); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z1 = sycl_complexT(v1); + sycl_complexT z2 = sycl_complexT(v2); + const realT real1 = exprm_ns::real(z1); + const realT real2 = exprm_ns::real(z2); const bool r1_nan = std::isnan(real1); const bool r2_nan = std::isnan(real2); - const realT imag1 = std::imag(v1); - const realT imag2 = std::imag(v2); + const realT imag1 = exprm_ns::imag(z1); + const realT imag2 = exprm_ns::imag(z2); const bool i1_nan = std::isnan(imag1); const bool i2_nan = std::isnan(imag2); @@ -112,9 +118,9 @@ template struct AscendingSorter std::less>; }; -template struct AscendingSorter> +template struct AscendingSorter> { - using type = ExtendedComplexFPLess>; + using type = ExtendedComplexFPLess>; }; template struct DescendingSorter @@ -124,9 +130,9 @@ template struct DescendingSorter std::greater>; }; -template struct DescendingSorter> +template struct DescendingSorter> { - using type = ExtendedComplexFPGreater>; + using type = ExtendedComplexFPGreater>; }; } // end of namespace py_internal diff --git a/examples/pybind11/onemkl_gemv/sycl_gemm/_onemkl.cpp b/examples/pybind11/onemkl_gemv/sycl_gemm/_onemkl.cpp index 2dddd415ce..67717c8745 100644 --- a/examples/pybind11/onemkl_gemv/sycl_gemm/_onemkl.cpp +++ b/examples/pybind11/onemkl_gemv/sycl_gemm/_onemkl.cpp @@ -34,9 +34,12 @@ #include #include #include "dpctl4pybind11.hpp" +#define SYCL_EXT_ONEAPI_COMPLEX +#include // clang-format on namespace py = pybind11; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::utils::keep_args_alive; @@ -119,7 +122,7 @@ py_gemv(sycl::queue &q, res_ev = gemv_ev; } else if (v_typenum == api.UAR_CDOUBLE_) { - using T = std::complex; + using T = exprm_ns::complex; sycl::event gemv_ev = oneapi::mkl::blas::row_major::gemv( q, oneapi::mkl::transpose::nontrans, n, m, T(1), reinterpret_cast(mat_typeless_ptr), m, @@ -128,7 +131,7 @@ py_gemv(sycl::queue &q, res_ev = gemv_ev; } else if (v_typenum == api.UAR_CFLOAT_) { - using T = std::complex; + using T = exprm_ns::complex; sycl::event gemv_ev = oneapi::mkl::blas::row_major::gemv( q, oneapi::mkl::transpose::nontrans, n, m, T(1), reinterpret_cast(mat_typeless_ptr), m, @@ -228,12 +231,12 @@ py_sub(sycl::queue q, out_r_typeless_ptr, depends); } else if (out_r_typenum == api.UAR_CDOUBLE_) { - using T = std::complex; + using T = exprm_ns::complex; res_ev = sub_impl(q, n, in_v1_typeless_ptr, in_v2_typeless_ptr, out_r_typeless_ptr, depends); } else if (out_r_typenum == api.UAR_CFLOAT_) { - using T = std::complex; + using T = exprm_ns::complex; res_ev = sub_impl(q, n, in_v1_typeless_ptr, in_v2_typeless_ptr, out_r_typeless_ptr, depends); } @@ -329,12 +332,12 @@ py_axpby_inplace(sycl::queue q, y_typeless_ptr, depends); } else if (x_typenum == api.UAR_CDOUBLE_) { - using T = std::complex; + using T = exprm_ns::complex; res_ev = axpby_inplace_impl(q, n, a, x_typeless_ptr, b, y_typeless_ptr, depends); } else if (x_typenum == api.UAR_CFLOAT_) { - using T = std::complex; + using T = exprm_ns::complex; res_ev = axpby_inplace_impl(q, n, a, x_typeless_ptr, b, y_typeless_ptr, depends); } @@ -367,8 +370,8 @@ T complex_norm_squared_blocking_impl( const char *r_typeless, const std::vector &depends = {}) { - const std::complex *r = - reinterpret_cast *>(r_typeless); + const exprm_ns::complex *r = + reinterpret_cast *>(r_typeless); return cg_solver::detail::complex_norm_squared_blocking(q, nelems, r, depends); @@ -419,13 +422,13 @@ py::object py_norm_squared_blocking(sycl::queue q, res = py::float_(n_sq); } else if (r_typenum == api.UAR_CDOUBLE_) { - using T = std::complex; + using T = exprm_ns::complex; double n_sq = complex_norm_squared_blocking_impl( q, n, r_typeless_ptr, depends); res = py::float_(n_sq); } else if (r_typenum == api.UAR_CFLOAT_) { - using T = std::complex; + using T = exprm_ns::complex; float n_sq = complex_norm_squared_blocking_impl( q, n, r_typeless_ptr, depends); res = py::float_(n_sq); @@ -506,7 +509,7 @@ py::object py_dot_blocking(sycl::queue q, res = py::float_(res_v); } else if (v1_typenum == api.UAR_CDOUBLE_) { - using T = std::complex; + using T = exprm_ns::complex; T *res_usm = sycl::malloc_device(1, q); sycl::event dotc_ev = oneapi::mkl::blas::row_major::dotc( q, n, reinterpret_cast(v1_typeless_ptr), 1, @@ -517,7 +520,7 @@ py::object py_dot_blocking(sycl::queue q, res = py::cast(res_v); } else if (v1_typenum == api.UAR_CFLOAT_) { - using T = std::complex; + using T = exprm_ns::complex; T *res_usm = sycl::malloc_device(1, q); sycl::event dotc_ev = oneapi::mkl::blas::row_major::dotc( q, n, reinterpret_cast(v1_typeless_ptr), 1, diff --git a/examples/pybind11/onemkl_gemv/sycl_gemm/cg_solver.hpp b/examples/pybind11/onemkl_gemv/sycl_gemm/cg_solver.hpp index 6de14aa9ec..eb551fd9ee 100644 --- a/examples/pybind11/onemkl_gemv/sycl_gemm/cg_solver.hpp +++ b/examples/pybind11/onemkl_gemv/sycl_gemm/cg_solver.hpp @@ -29,12 +29,16 @@ #include #include +#define SYCL_EXT_ONEAPI_COMPLEX +#include namespace cg_solver { namespace detail { +namespace exprm_ns = sycl::ext::oneapi::experimental; + template class sub_kern; template @@ -122,7 +126,7 @@ template class complex_norm_squared_blocking_kern; template T complex_norm_squared_blocking(sycl::queue &q, size_t nelems, - const std::complex *r, + const exprm_ns::complex *r, const std::vector &depends = {}) { sycl::buffer sum_sq_buf(sycl::range<1>{1});