From 624f14f2003ff26dcabee01db73e0a4e78304c54 Mon Sep 17 00:00:00 2001 From: Anton <100830759+antonwolfy@users.noreply.github.com> Date: Mon, 7 Jul 2025 17:35:03 +0200 Subject: [PATCH 1/7] Fix compilation warnings (#2517) This PR resolves compilation warnings: > warning: class template argument deduction for alias templates is a C++20 extension [-Wc++20-extensions] Currently DPNP is built with the C++17 standard, where Class Template Argument Deduction is available only for a class template, but not for a type alias. Support for the latter was added in the C++20 standard. Since there are currently no plans to move to the C++20 standard, the code has been updated to conform to the C++17 standard. --- CHANGELOG.md | 1 + dpnp/backend/extensions/statistics/histogram.cpp | 9 ++++++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b73e53543e3f..f1677aa78f7c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Updated existing GitHub workflows to add testing with Python 3.13 [#2510](https://github.com/IntelPython/dpnp/pull/2510) * Aligned the license expression with `PEP-639` [#2511](https://github.com/IntelPython/dpnp/pull/2511) * Bumped oneMKL version up to `v0.8` [#2514](https://github.com/IntelPython/dpnp/pull/2514) +* Removed the use of class template argument deduction for alias template to conform to the C++17 standard [#2517](https://github.com/IntelPython/dpnp/pull/2517) ### Deprecated diff --git a/dpnp/backend/extensions/statistics/histogram.cpp b/dpnp/backend/extensions/statistics/histogram.cpp index 5e05a44858f2..be752566d514 100644 --- a/dpnp/backend/extensions/statistics/histogram.cpp +++ b/dpnp/backend/extensions/statistics/histogram.cpp @@ -139,12 +139,14 @@ struct HistogramF auto dispatch_edges = [&](uint32_t local_mem, const auto &weights, auto &hist) { if (device.is_gpu() && (local_mem >= bins_count + 1)) { - auto edges = 
CachedEdges(bins_edges, bins_count + 1, cgh); + auto edges = + CachedEdges(bins_edges, bins_count + 1, cgh); submit_histogram(in, size, dims, WorkPI, hist, edges, weights, nd_range, cgh); } else { - auto edges = UncachedEdges(bins_edges, bins_count + 1, cgh); + auto edges = + UncachedEdges(bins_edges, bins_count + 1, cgh); submit_histogram(in, size, dims, WorkPI, hist, edges, weights, nd_range, cgh); } @@ -165,7 +167,8 @@ struct HistogramF } else { auto hist = HistGlobalMemory(out); - auto edges = UncachedEdges(bins_edges, bins_count + 1, cgh); + auto edges = + UncachedEdges(bins_edges, bins_count + 1, cgh); submit_histogram(in, size, dims, WorkPI, hist, edges, weights, nd_range, cgh); } From d15d395fe584b9d820fd48a077ab251c03751946 Mon Sep 17 00:00:00 2001 From: Anton <100830759+antonwolfy@users.noreply.github.com> Date: Tue, 8 Jul 2025 11:29:18 +0200 Subject: [PATCH 2/7] Temporary mute tests failing in public CI (#2518) The PR temporarily mutes tests which are sporadically failing in the GitHub workflow on Windows. 
--- dpnp/tests/test_product.py | 8 ++++++++ dpnp/tests/test_umath.py | 4 ++++ 2 files changed, 12 insertions(+) diff --git a/dpnp/tests/test_product.py b/dpnp/tests/test_product.py index 5fa74edfe2d8..983fd2161d2d 100644 --- a/dpnp/tests/test_product.py +++ b/dpnp/tests/test_product.py @@ -12,7 +12,9 @@ assert_dtype_allclose, generate_random_numpy_array, get_all_dtypes, + is_gpu_device, is_ptl, + is_win_platform, numpy_version, ) from .third_party.cupy import testing @@ -1441,6 +1443,9 @@ class TestMatvec: def setup_method(self): numpy.random.seed(42) + @pytest.mark.skipif( + is_win_platform() and not is_gpu_device(), reason="SAT-8073" + ) @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) @pytest.mark.parametrize( "shape1, shape2", @@ -2167,6 +2172,9 @@ class TestVecmat: def setup_method(self): numpy.random.seed(42) + @pytest.mark.skipif( + is_win_platform() and not is_gpu_device(), reason="SAT-8073" + ) @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) @pytest.mark.parametrize( "shape1, shape2", diff --git a/dpnp/tests/test_umath.py b/dpnp/tests/test_umath.py index 56f55de2f1c7..0039d74789f7 100644 --- a/dpnp/tests/test_umath.py +++ b/dpnp/tests/test_umath.py @@ -23,6 +23,7 @@ has_support_aspect64, is_cuda_device, is_gpu_device, + is_win_platform, ) # full list of umaths @@ -121,6 +122,9 @@ def test_umaths(test_cases): pytest.skip("dpnp.modf is not supported with dpnp.float16") elif is_cuda_device(): pytest.skip("dpnp.modf is not supported on CUDA device") + elif umath in ["vecmat", "matvec"]: + if is_win_platform() and not is_gpu_device(): + pytest.skip("SAT-8073") expected = getattr(numpy, umath)(*args) result = getattr(dpnp, umath)(*iargs) From 9b73305babb28ceed6fd951675d8c2f8fd29416d Mon Sep 17 00:00:00 2001 From: Anton <100830759+antonwolfy@users.noreply.github.com> Date: Thu, 10 Jul 2025 12:56:45 +0200 Subject: [PATCH 3/7] Add implementation of `dpnp.ndarray.view` method (#2520) This PR adds implementation of `dpnp.ndarray.view` 
method. All places in the code with connected TODO comments were updated properly. --- CHANGELOG.md | 1 + dpnp/dpnp_array.py | 124 +++++++++++++++++- dpnp/dpnp_utils/dpnp_utils_einsum.py | 4 +- dpnp/linalg/dpnp_utils_linalg.py | 11 +- dpnp/tests/test_ndarray.py | 30 +++++ .../core_tests/test_ndarray_copy_and_view.py | 23 ++-- dpnp/tests/third_party/cupy/testing/_array.py | 5 +- 7 files changed, 171 insertions(+), 27 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f1677aa78f7c..a8d4adca875a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Added `--target-cuda[=ARCH]` option to replace the deprecated `--target=cuda`, allowing users to build for CUDA devices with optional architecture selection using [CodePlay oneAPI plug-in](https://developer.codeplay.com/products/oneapi/nvidia/home/) [#2478](https://github.com/IntelPython/dpnp/pull/2478) * Added several new `pre-commit` rules, including protection against direct commits to master/maintenance branches [#2500](https://github.com/IntelPython/dpnp/pull/2500) +* Added implementation of `dpnp.ndarray.view` method [#2520](https://github.com/IntelPython/dpnp/pull/2520) ### Changed diff --git a/dpnp/dpnp_array.py b/dpnp/dpnp_array.py index 48df4acf3b81..f47383619dc2 100644 --- a/dpnp/dpnp_array.py +++ b/dpnp/dpnp_array.py @@ -25,6 +25,7 @@ # ***************************************************************************** import dpctl.tensor as dpt +import dpctl.tensor._type_utils as dtu from dpctl.tensor._numpy_helper import AxisError import dpnp @@ -1979,5 +1980,126 @@ def var( correction=correction, ) + def view(self, dtype=None, *, type=None): + """ + New view of array with the same data. + + For full documentation refer to :obj:`numpy.ndarray.view`. + + Parameters + ---------- + dtype : {None, str, dtype object}, optional + The desired data type of the returned view, e.g. :obj:`dpnp.float32` + or :obj:`dpnp.int16`. 
By default, it results in the view having the + same data type. + + Default: ``None``. + + Notes + ----- + Passing ``None`` for `dtype` is the same as omitting the parameter, + opposite to NumPy where they have different meaning. + + ``view(some_dtype)`` or ``view(dtype=some_dtype)`` constructs a view of + the array's memory with a different data type. This can cause a + reinterpretation of the bytes of memory. + + Only the last axis has to be contiguous. -# 'view' + Limitations + ----------- + Parameter `type` is supported only with default value ``None``. + Otherwise, the function raises ``NotImplementedError`` exception. + + Examples + -------- + >>> import dpnp as np + >>> x = np.ones((4,), dtype=np.float32) + >>> xv = x.view(dtype=np.int32) + >>> xv[:] = 0 + >>> xv + array([0, 0, 0, 0], dtype=int32) + + However, views that change dtype are totally fine for arrays with a + contiguous last axis, even if the rest of the axes are not C-contiguous: + + >>> x = np.arange(2 * 3 * 4, dtype=np.int8).reshape(2, 3, 4) + >>> x.transpose(1, 0, 2).view(np.int16) + array([[[ 256, 770], + [3340, 3854]], + + [[1284, 1798], + [4368, 4882]], + + [[2312, 2826], + [5396, 5910]]], dtype=int16) + + """ + + if type is not None: + raise NotImplementedError( + "Keyword argument `type` is supported only with " + f"default value ``None``, but got {type}." 
+ ) + + old_sh = self.shape + old_strides = self.strides + + if dtype is None: + return dpnp_array(old_sh, buffer=self, strides=old_strides) + + new_dt = dpnp.dtype(dtype) + new_dt = dtu._to_device_supported_dtype(new_dt, self.sycl_device) + + new_itemsz = new_dt.itemsize + old_itemsz = self.dtype.itemsize + if new_itemsz == old_itemsz: + return dpnp_array( + old_sh, dtype=new_dt, buffer=self, strides=old_strides + ) + + ndim = self.ndim + if ndim == 0: + raise ValueError( + "Changing the dtype of a 0d array is only supported " + "if the itemsize is unchanged" + ) + + # resize on last axis only + axis = ndim - 1 + if old_sh[axis] != 1 and self.size != 0 and old_strides[axis] != 1: + raise ValueError( + "To change to a dtype of a different size, " + "the last axis must be contiguous" + ) + + # normalize strides whenever itemsize changes + if old_itemsz > new_itemsz: + new_strides = list( + el * (old_itemsz // new_itemsz) for el in old_strides + ) + else: + new_strides = list( + el // (new_itemsz // old_itemsz) for el in old_strides + ) + new_strides[axis] = 1 + new_strides = tuple(new_strides) + + new_dim = old_sh[axis] * old_itemsz + if new_dim % new_itemsz != 0: + raise ValueError( + "When changing to a larger dtype, its size must be a divisor " + "of the total size in bytes of the last axis of the array" + ) + + # normalize shape whenever itemsize changes + new_sh = list(old_sh) + new_sh[axis] = new_dim // new_itemsz + new_sh = tuple(new_sh) + + return dpnp_array( + new_sh, + dtype=new_dt, + buffer=self, + strides=new_strides, + ) diff --git a/dpnp/dpnp_utils/dpnp_utils_einsum.py b/dpnp/dpnp_utils/dpnp_utils_einsum.py index 322d7dd2c148..12baacac3dc1 100644 --- a/dpnp/dpnp_utils/dpnp_utils_einsum.py +++ b/dpnp/dpnp_utils/dpnp_utils_einsum.py @@ -945,7 +945,6 @@ def _transpose_ex(a, axeses): stride = sum(a.strides[axis] for axis in axes) strides.append(stride) - # TODO: replace with a.view() when it is implemented in dpnp return dpnp_array( shape, dtype=a.dtype, 
@@ -1151,8 +1150,7 @@ def dpnp_einsum( operands[idx] = operands[idx].sum(axis=sum_axes, dtype=result_dtype) if returns_view: - # TODO: replace with a.view() when it is implemented in dpnp - operands = [a for a in operands] + operands = [a.view() for a in operands] else: operands = [ dpnp.astype(a, result_dtype, copy=False, casting=casting) diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index b694b730c97c..51cebb2815bb 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -1290,15 +1290,8 @@ def _nrm2_last_axis(x): """ real_dtype = _real_type(x.dtype) - # TODO: use dpnp.sum(dpnp.square(dpnp.view(x)), axis=-1, dtype=real_dtype) - # w/a since dpnp.view() in not implemented yet - # Сalculate and sum the squares of both real and imaginary parts for - # compelex array. - if dpnp.issubdtype(x.dtype, dpnp.complexfloating): - y = dpnp.abs(x) ** 2 - else: - y = dpnp.square(x) - return dpnp.sum(y, axis=-1, dtype=real_dtype) + x = dpnp.ascontiguousarray(x) + return dpnp.sum(dpnp.square(x.view(real_dtype)), axis=-1) def _real_type(dtype, device=None): diff --git a/dpnp/tests/test_ndarray.py b/dpnp/tests/test_ndarray.py index 0a4fea422fc9..eaccf689a795 100644 --- a/dpnp/tests/test_ndarray.py +++ b/dpnp/tests/test_ndarray.py @@ -74,6 +74,36 @@ def test_attributes(self): assert_equal(self.two.itemsize, self.two.dtype.itemsize) +class TestView: + def test_none_dtype(self): + a = numpy.ones((1, 2, 4), dtype=numpy.int32) + ia = dpnp.array(a) + + expected = a.view() + result = ia.view() + assert_allclose(result, expected) + + expected = a.view() # numpy returns dtype(None) otherwise + result = ia.view(None) + assert_allclose(result, expected) + + @pytest.mark.parametrize("dt", [bool, int, float, complex]) + def test_python_types(self, dt): + a = numpy.ones((8, 4), dtype=numpy.complex64) + ia = dpnp.array(a) + + result = ia.view(dt) + if not has_support_aspect64() and dt in [float, complex]: + dt = result.dtype + 
expected = a.view(dt) + assert_allclose(result, expected) + + def test_type_error(self): + x = dpnp.ones(4, dtype="i4") + with pytest.raises(NotImplementedError): + x.view("i2", type=dpnp.ndarray) + + @pytest.mark.parametrize( "arr", [ diff --git a/dpnp/tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py b/dpnp/tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py index eaf01d1b345c..25d30b69607c 100644 --- a/dpnp/tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py +++ b/dpnp/tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py @@ -25,7 +25,6 @@ def get_strides(xp, a): return a.strides -@pytest.mark.skip("'dpnp_array' object has no attribute 'view' yet") class TestView: @testing.numpy_cupy_array_equal() @@ -98,9 +97,9 @@ def test_view_relaxed_contiguous(self, xp, dtype): ) @testing.numpy_cupy_equal() def test_view_flags_smaller(self, xp, order, shape): - a = xp.zeros(shape, numpy.int32, order) + a = xp.zeros(shape, dtype=numpy.int32, order=order) b = a.view(numpy.int16) - return b.flags.c_contiguous, b.flags.f_contiguous, b.flags.owndata + return b.flags.c_contiguous, b.flags.f_contiguous # , b.flags.owndata @pytest.mark.parametrize( ("order", "shape"), @@ -112,7 +111,7 @@ def test_view_flags_smaller(self, xp, order, shape): @testing.with_requires("numpy>=1.23") def test_view_flags_smaller_invalid(self, order, shape): for xp in (numpy, cupy): - a = xp.zeros(shape, numpy.int32, order) + a = xp.zeros(shape, dtype=numpy.int32, order=order) with pytest.raises(ValueError): a.view(numpy.int16) @@ -121,7 +120,7 @@ def test_view_flags_smaller_invalid(self, order, shape): [ ("C", (6,)), ("C", (3, 10)), - ("C", (0,)), + # ("C", (0,)), # dpctl-2119 ("C", (1, 6)), ("C", (3, 2)), ], @@ -129,9 +128,9 @@ def test_view_flags_smaller_invalid(self, order, shape): ) @testing.numpy_cupy_equal() def test_view_flags_larger(self, xp, order, shape): - a = xp.zeros(shape, numpy.int16, order) + a = xp.zeros(shape, dtype=numpy.int16, order=order) 
b = a.view(numpy.int32) - return b.flags.c_contiguous, b.flags.f_contiguous, b.flags.owndata + return b.flags.c_contiguous, b.flags.f_contiguous # , b.flags.owndata @pytest.mark.parametrize( ("order", "shape"), @@ -144,7 +143,7 @@ def test_view_flags_larger(self, xp, order, shape): @testing.with_requires("numpy>=1.23") def test_view_flags_larger_invalid(self, order, shape): for xp in (numpy, cupy): - a = xp.zeros(shape, numpy.int16, order) + a = xp.zeros(shape, dtype=numpy.int16, order=order) with pytest.raises(ValueError): a.view(numpy.int32) @@ -161,7 +160,7 @@ def test_view_smaller_dtype_multiple(self, xp): @testing.numpy_cupy_array_equal() def test_view_smaller_dtype_multiple2(self, xp): # x is non-contiguous, and stride[-1] != 0 - x = xp.ones((3, 4), xp.int32)[:, :1:2] + x = xp.ones((3, 4), dtype=xp.int32)[:, :1:2] return x.view(xp.int16) @testing.with_requires("numpy>=1.23") @@ -184,7 +183,7 @@ def test_view_non_c_contiguous(self, xp): @testing.numpy_cupy_array_equal() def test_view_larger_dtype_zero_sized(self, xp): - x = xp.ones((3, 20), xp.int16)[:0, ::2] + x = xp.ones((3, 20), dtype=xp.int16)[:0, ::2] return x.view(xp.int32) @@ -387,7 +386,7 @@ def test_astype_strides_broadcast(self, xp, src_dtype, dst_dtype): dst = astype_without_warning(src, dst_dtype, order="K") return get_strides(xp, dst) - @pytest.mark.skip("'dpnp_array' object has no attribute 'view' yet") + @pytest.mark.skip("dpctl-2121") @testing.numpy_cupy_array_equal() def test_astype_boolean_view(self, xp): # See #4354 @@ -454,7 +453,7 @@ def __array_finalize__(self, obj): self.info = getattr(obj, "info", None) -@pytest.mark.skip("'dpnp_array' object has no attribute 'view' yet") +@pytest.mark.skip("subclass array is not supported") class TestSubclassArrayView: def test_view_casting(self): diff --git a/dpnp/tests/third_party/cupy/testing/_array.py b/dpnp/tests/third_party/cupy/testing/_array.py index beecaac16e58..f2f8d455dd8e 100644 --- a/dpnp/tests/third_party/cupy/testing/_array.py +++ 
b/dpnp/tests/third_party/cupy/testing/_array.py @@ -171,13 +171,14 @@ def assert_array_equal( ) if strides_check: - if actual.strides != desired.strides: + strides = tuple(el // desired.itemsize for el in desired.strides) + if actual.strides != strides: msg = ["Strides are not equal:"] if err_msg: msg = [msg[0] + " " + err_msg] if verbose: msg.append(" x: {}".format(actual.strides)) - msg.append(" y: {}".format(desired.strides)) + msg.append(" y: {}".format(strides)) raise AssertionError("\n".join(msg)) From 425ec0c6865e153a08701ad48c425055b5ca8599 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad <120411540+vtavana@users.noreply.github.com> Date: Fri, 11 Jul 2025 08:36:44 -0500 Subject: [PATCH 4/7] Unmute FFT tests that are fixed (#2523) Unmute FFT tests that are fixed. --- dpnp/tests/test_fft.py | 29 ++------- .../third_party/cupy/fft_tests/test_fft.py | 64 +++++++++---------- 2 files changed, 37 insertions(+), 56 deletions(-) diff --git a/dpnp/tests/test_fft.py b/dpnp/tests/test_fft.py index 4c360db476ef..6ab1a0b253a7 100644 --- a/dpnp/tests/test_fft.py +++ b/dpnp/tests/test_fft.py @@ -551,10 +551,7 @@ def test_basic(self, func, dtype, axes): class TestHfft: - # TODO: include boolean dtype when mkl_fft-gh-180 is merged - @pytest.mark.parametrize( - "dtype", get_all_dtypes(no_none=True, no_bool=True) - ) + @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) @pytest.mark.parametrize("n", [None, 5, 18]) @pytest.mark.parametrize("norm", [None, "backward", "forward", "ortho"]) def test_basic(self, dtype, n, norm): @@ -563,10 +560,7 @@ def test_basic(self, dtype, n, norm): result = dpnp.fft.hfft(ia, n=n, norm=norm) expected = numpy.fft.hfft(a, n=n, norm=norm) - # TODO: change to the commented line when mkl_fft-2.0.0 is released - # and being used with Intel NumPy >= 2.0.0 - flag = True - # flag = True if numpy_version() < "2.0.0" else False + flag = True if numpy_version() < "2.0.0" else False assert_dtype_allclose( result, expected, factor=24, 
check_only_type_kind=flag ) @@ -609,10 +603,7 @@ def test_basic(self, dtype, n, norm): result = dpnp.fft.irfft(ia, n=n, norm=norm) expected = numpy.fft.irfft(a, n=n, norm=norm) - # TODO: change to the commented line when mkl_fft-2.0.0 is released - # and being used with Intel NumPy >= 2.0.0 - flag = True - # flag = True if numpy_version() < "2.0.0" else False + flag = True if numpy_version() < "2.0.0" else False assert_dtype_allclose( result, expected, factor=24, check_only_type_kind=flag ) @@ -779,8 +770,7 @@ def test_float16(self): expected = numpy.fft.rfft(a) result = dpnp.fft.rfft(ia) - # TODO: change to the commented line when mkl_fft-2.0.0 is released - # and being used with Intel NumPy >= 2.0.0 + # TODO: change to the commented line when mkl_fft-gh-204 is resolved flag = True # flag = True if numpy_version() < "2.0.0" else False assert_dtype_allclose(result, expected, check_only_type_kind=flag) @@ -800,11 +790,10 @@ def test_validate_out(self): class TestRfft2: - # TODO: add other axes when mkl_fft gh-119 is addressed @pytest.mark.parametrize( "dtype", get_all_dtypes(no_none=True, no_complex=True) ) - @pytest.mark.parametrize("axes", [(0, 1)]) # (1, 2),(0, 2),(2, 1),(2, 0) + @pytest.mark.parametrize("axes", [(0, 1), (1, 2), (0, 2), (2, 1), (2, 0)]) @pytest.mark.parametrize("norm", [None, "backward", "forward", "ortho"]) @pytest.mark.parametrize("order", ["C", "F"]) def test_basic(self, dtype, axes, norm, order): @@ -859,12 +848,11 @@ def test_error(self, xp): class TestRfftn: - # TODO: add additional axes when mkl_fft gh-119 is addressed @pytest.mark.parametrize( "dtype", get_all_dtypes(no_none=True, no_complex=True) ) @pytest.mark.parametrize( - "axes", [(0, 1, 2), (-2, -4, -1, -3)] # (-1, -4, -2) + "axes", [(0, 1, 2), (-2, -4, -1, -3), (-1, -4, -2)] ) @pytest.mark.parametrize("norm", [None, "backward", "forward", "ortho"]) @pytest.mark.parametrize("order", ["C", "F"]) @@ -965,8 +953,5 @@ def test_1d_array(self): result = dpnp.fft.irfftn(ia) expected = 
numpy.fft.irfftn(a) - # TODO: change to the commented line when mkl_fft-2.0.0 is released - # and being used with Intel NumPy >= 2.0.0 - flag = True - # flag = True if numpy_version() < "2.0.0" else False + flag = True if numpy_version() < "2.0.0" else False assert_dtype_allclose(result, expected, check_only_type_kind=flag) diff --git a/dpnp/tests/third_party/cupy/fft_tests/test_fft.py b/dpnp/tests/third_party/cupy/fft_tests/test_fft.py index 5c0368278f97..534b474363f1 100644 --- a/dpnp/tests/third_party/cupy/fft_tests/test_fft.py +++ b/dpnp/tests/third_party/cupy/fft_tests/test_fft.py @@ -378,21 +378,21 @@ def test_fft_allocate(self): {"shape": (3, 4), "s": (1, 5), "axes": (-2, -1)}, {"shape": (3, 4), "s": None, "axes": (-2, -1)}, {"shape": (3, 4), "s": None, "axes": (-1, -2)}, - # {"shape": (3, 4), "s": None, "axes": (0,)}, # mkl_fft gh-109 + {"shape": (3, 4), "s": None, "axes": (0,)}, {"shape": (3, 4), "s": None, "axes": None}, - # {"shape": (3, 4), "s": None, "axes": ()}, # mkl_fft gh-108 + {"shape": (3, 4), "s": None, "axes": ()}, {"shape": (2, 3, 4), "s": None, "axes": None}, {"shape": (2, 3, 4), "s": (1, 4, 10), "axes": (-2, -1)}, {"shape": (2, 3, 4), "s": None, "axes": (-3, -2, -1)}, {"shape": (2, 3, 4), "s": None, "axes": (-1, -2, -3)}, - # {"shape": (2, 3, 4), "s": None, "axes": (0, 1)}, # mkl_fft gh-109 + {"shape": (2, 3, 4), "s": None, "axes": (0, 1)}, {"shape": (2, 3, 4), "s": None, "axes": None}, - # {"shape": (2, 3, 4), "s": None, "axes": ()}, # mkl_fft gh-108 - # {"shape": (2, 3, 4), "s": (2, 3), "axes": (0, 1, 2)}, # mkl_fft gh-109 + {"shape": (2, 3, 4), "s": None, "axes": ()}, + {"shape": (2, 3, 4), "s": (2, 3), "axes": (0, 1, 2)}, {"shape": (2, 3, 4, 5), "s": None, "axes": None}, - # {"shape": (0, 5), "s": None, "axes": None}, # mkl_fft gh-110 - # {"shape": (2, 0, 5), "s": None, "axes": None}, # mkl_fft gh-110 - # {"shape": (0, 0, 5), "s": None, "axes": None}, # mkl_fft gh-110 + {"shape": (0, 5), "s": None, "axes": None}, + {"shape": (2, 0, 5), 
"s": None, "axes": None}, + {"shape": (0, 0, 5), "s": None, "axes": None}, {"shape": (3, 4), "s": (0, 5), "axes": (-2, -1)}, {"shape": (3, 4), "s": (1, 0), "axes": (-2, -1)}, ], @@ -468,23 +468,23 @@ def test_ifft2(self, xp, dtype, order, enable_nd): {"shape": (3, 4), "s": None, "axes": (-2, -1)}, {"shape": (3, 4), "s": None, "axes": (-1, -2)}, {"shape": (3, 4), "s": None, "axes": [-1, -2]}, - # {"shape": (3, 4), "s": None, "axes": (0,)}, # mkl_fft gh-109 - # {"shape": (3, 4), "s": None, "axes": ()}, # mkl_fft gh-108 + {"shape": (3, 4), "s": None, "axes": (0,)}, + {"shape": (3, 4), "s": None, "axes": ()}, {"shape": (3, 4), "s": None, "axes": None}, {"shape": (2, 3, 4), "s": None, "axes": None}, {"shape": (2, 3, 4), "s": (1, 4, 10), "axes": (-3, -2, -1)}, {"shape": (2, 3, 4), "s": None, "axes": (-3, -2, -1)}, {"shape": (2, 3, 4), "s": None, "axes": (-1, -2, -3)}, - # {"shape": (2, 3, 4), "s": None, "axes": (-1, -3)}, # mkl_fft gh-109 - # {"shape": (2, 3, 4), "s": None, "axes": (0, 1)}, # mkl_fft gh-109 + {"shape": (2, 3, 4), "s": None, "axes": (-1, -3)}, + {"shape": (2, 3, 4), "s": None, "axes": (0, 1)}, {"shape": (2, 3, 4), "s": None, "axes": None}, - # {"shape": (2, 3, 4), "s": None, "axes": ()}, # mkl_fft gh-108 - # {"shape": (2, 3, 4), "s": (2, 3), "axes": (0, 1, 2)}, # mkl_fft gh-109 + {"shape": (2, 3, 4), "s": None, "axes": ()}, + {"shape": (2, 3, 4), "s": (2, 3), "axes": (0, 1, 2)}, {"shape": (2, 3, 4), "s": (4, 3, 2), "axes": (2, 0, 1)}, {"shape": (2, 3, 4, 5), "s": None, "axes": None}, - # {"shape": (0, 5), "s": None, "axes": None}, # mkl_fft gh-110 - # {"shape": (2, 0, 5), "s": None, "axes": None}, # mkl_fft gh-110 - # {"shape": (0, 0, 5), "s": None, "axes": None}, # mkl_fft gh-110 + {"shape": (0, 5), "s": None, "axes": None}, + {"shape": (2, 0, 5), "s": None, "axes": None}, + {"shape": (0, 0, 5), "s": None, "axes": None}, ], testing.product({"norm": [None, "backward", "ortho", "forward"]}), ) @@ -912,8 +912,7 @@ def test_rfft(self, xp, dtype): atol=2e-6, 
accept_error=ValueError, contiguous_check=False, - # TODO: replace with has_support_aspect64() when mkl_fft-gh-180 is merged - type_check=False, + type_check=has_support_aspect64(), ) def test_irfft(self, xp, dtype): a = testing.shaped_random(self.shape, xp, dtype) @@ -1002,14 +1001,14 @@ def test_rfft_error_on_wrong_plan(self, dtype): {"shape": (3, 4), "s": None, "axes": (-1, -2)}, {"shape": (3, 4), "s": None, "axes": (0,)}, {"shape": (3, 4), "s": None, "axes": None}, - # {"shape": (2, 3, 4), "s": None, "axes": None}, # mkl_fft gh-116 - # {"shape": (2, 3, 4), "s": (1, 4, 10), "axes": (-3, -2, -1)}, # mkl_fft gh-115 - # {"shape": (2, 3, 4), "s": None, "axes": (-3, -2, -1)}, # mkl_fft gh-116 - # {"shape": (2, 3, 4), "s": None, "axes": (-1, -2, -3)}, # mkl_fft gh-116 + {"shape": (2, 3, 4), "s": None, "axes": None}, + {"shape": (2, 3, 4), "s": (1, 4, 10), "axes": (-3, -2, -1)}, + {"shape": (2, 3, 4), "s": None, "axes": (-3, -2, -1)}, + {"shape": (2, 3, 4), "s": None, "axes": (-1, -2, -3)}, {"shape": (2, 3, 4), "s": None, "axes": (0, 1)}, {"shape": (2, 3, 4), "s": None, "axes": None}, {"shape": (2, 3, 4), "s": (2, 3), "axes": (0, 1, 2)}, - # {"shape": (2, 3, 4, 5), "s": None, "axes": None}, # mkl_fft gh-109 and gh-116 + {"shape": (2, 3, 4, 5), "s": None, "axes": None}, ], testing.product( {"norm": [None, "backward", "ortho", "forward", ""]} @@ -1044,8 +1043,7 @@ def test_rfft2(self, xp, dtype, order, enable_nd): atol=1e-7, accept_error=ValueError, contiguous_check=False, - # TODO: replace with has_support_aspect64() when mkl_fft-gh-180 is merged - type_check=False, + type_check=has_support_aspect64(), ) def test_irfft2(self, xp, dtype, order, enable_nd): # assert config.enable_nd_planning == enable_nd @@ -1090,13 +1088,13 @@ def test_irfft2(self, dtype): {"shape": (3, 4), "s": None, "axes": (0,)}, {"shape": (3, 4), "s": None, "axes": None}, {"shape": (2, 3, 4), "s": None, "axes": None}, - # {"shape": (2, 3, 4), "s": (1, 4, 10), "axes": (-3, -2, -1)}, # mkl_fft gh-115 - 
# {"shape": (2, 3, 4), "s": None, "axes": (-3, -2, -1)}, # mkl_fft gh-116 - # {"shape": (2, 3, 4), "s": None, "axes": (-1, -2, -3)}, # mkl_fft gh-116 + {"shape": (2, 3, 4), "s": (1, 4, 10), "axes": (-3, -2, -1)}, + {"shape": (2, 3, 4), "s": None, "axes": (-3, -2, -1)}, + {"shape": (2, 3, 4), "s": None, "axes": (-1, -2, -3)}, {"shape": (2, 3, 4), "s": None, "axes": (0, 1)}, {"shape": (2, 3, 4), "s": None, "axes": None}, {"shape": (2, 3, 4), "s": (2, 3), "axes": (0, 1, 2)}, - # {"shape": (2, 3, 4, 5), "s": None, "axes": None}, # mkl_fft gh-109 and gh-116 + {"shape": (2, 3, 4, 5), "s": None, "axes": None}, ], testing.product( {"norm": [None, "backward", "ortho", "forward", ""]} @@ -1131,8 +1129,7 @@ def test_rfftn(self, xp, dtype, order, enable_nd): atol=1e-7, accept_error=ValueError, contiguous_check=False, - # TODO: replace with has_support_aspect64() when mkl_fft-gh-180 is merged - type_check=False, + type_check=has_support_aspect64(), ) def test_irfftn(self, xp, dtype, order, enable_nd): # assert config.enable_nd_planning == enable_nd @@ -1326,8 +1323,7 @@ class TestHfft: atol=2e-6, accept_error=ValueError, contiguous_check=False, - # TODO: replace with has_support_aspect64() when mkl_fft-gh-180 is merged - type_check=False, + type_check=has_support_aspect64(), ) def test_hfft(self, xp, dtype): a = testing.shaped_random(self.shape, xp, dtype) From 71c2e8ef4b5f4179a21cf1b865a9cce5fbf46ab1 Mon Sep 17 00:00:00 2001 From: Anton <100830759+antonwolfy@users.noreply.github.com> Date: Sat, 12 Jul 2025 00:47:08 +0200 Subject: [PATCH 5/7] Remove dummy setting of `compat` module in the tests (#2525) `numpy.compat` was deprecated in 1.26.0 and removed in 2.3.0. This PR removes the dummy setting of `dpnp.compat` in the external tests. 
--- tests_external/numpy/runtests.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests_external/numpy/runtests.py b/tests_external/numpy/runtests.py index f40e86cfa1f9..82e21395e2cd 100644 --- a/tests_external/numpy/runtests.py +++ b/tests_external/numpy/runtests.py @@ -291,9 +291,6 @@ def wrapper(*args, **kwargs): dpnp.unicode = str dpnp.unicode_ = dpnp.str_ -dpnp.compat = dummymodule -dpnp.compat.unicode = dpnp.unicode - dpnp.core = dpnp.core.umath = dpnp dpnp.core._exceptions = dummymodule From 1a7ce2207739d1d387fbd300ad8fee31397c3c92 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad <120411540+vtavana@users.noreply.github.com> Date: Sat, 12 Jul 2025 07:42:32 -0500 Subject: [PATCH 6/7] simplify blas backend when using with `USE_ONEMATH_CUBLAS` (#2522) simplify blas backend when using with `USE_ONEMATH_CUBLAS`. --- dpnp/backend/extensions/blas/gemm.cpp | 18 +----------------- dpnp/backend/extensions/blas/gemm_batch.cpp | 19 +------------------ dpnp/backend/extensions/blas/gemv.cpp | 17 +---------------- 3 files changed, 3 insertions(+), 51 deletions(-) diff --git a/dpnp/backend/extensions/blas/gemm.cpp b/dpnp/backend/extensions/blas/gemm.cpp index af18ab3002fb..c343c232b7af 100644 --- a/dpnp/backend/extensions/blas/gemm.cpp +++ b/dpnp/backend/extensions/blas/gemm.cpp @@ -55,9 +55,7 @@ typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue &, const std::int64_t, char *, const std::int64_t, -#if !defined(USE_ONEMATH_CUBLAS) const bool, -#endif // !USE_ONEMATH_CUBLAS const std::vector &); static gemm_impl_fn_ptr_t gemm_dispatch_table[dpctl_td_ns::num_types] @@ -76,9 +74,7 @@ static sycl::event gemm_impl(sycl::queue &exec_q, const std::int64_t ldb, char *resultC, const std::int64_t ldc, -#if !defined(USE_ONEMATH_CUBLAS) const bool is_row_major, -#endif // !USE_ONEMATH_CUBLAS const std::vector &depends) { type_utils::validate_type_for_device(exec_q); @@ -100,11 +96,6 @@ static sycl::event gemm_impl(sycl::queue &exec_q, const Tab *a, const std::int64_t lda, const 
Tab *b, const std::int64_t ldb, Tab beta, Tc *c, const std::int64_t ldc, const std::vector &deps) -> sycl::event { -#if defined(USE_ONEMATH_CUBLAS) - return mkl_blas::column_major::gemm(q, transA, transB, m, n, k, - alpha, a, lda, b, ldb, beta, c, - ldc, deps); -#else if (is_row_major) { return mkl_blas::row_major::gemm(q, transA, transB, m, n, k, alpha, a, lda, b, ldb, beta, c, @@ -115,7 +106,6 @@ static sycl::event gemm_impl(sycl::queue &exec_q, alpha, a, lda, b, ldb, beta, c, ldc, deps); } -#endif // USE_ONEMATH_CUBLAS }; gemm_event = gemm_func( exec_q, @@ -242,7 +232,7 @@ std::tuple // cuBLAS supports only column-major storage #if defined(USE_ONEMATH_CUBLAS) - const bool is_row_major = false; + constexpr bool is_row_major = false; transA = is_matrixA_c_contig ? oneapi::mkl::transpose::T : oneapi::mkl::transpose::N; @@ -320,15 +310,9 @@ std::tuple const char *b_typeless_ptr = matrixB.get_data(); char *r_typeless_ptr = resultC.get_data(); -#if defined(USE_ONEMATH_CUBLAS) - sycl::event gemm_ev = - gemm_fn(exec_q, transA, transB, m, n, k, a_typeless_ptr, lda, - b_typeless_ptr, ldb, r_typeless_ptr, ldc, depends); -#else sycl::event gemm_ev = gemm_fn(exec_q, transA, transB, m, n, k, a_typeless_ptr, lda, b_typeless_ptr, ldb, r_typeless_ptr, ldc, is_row_major, depends); -#endif // USE_ONEMATH_CUBLAS sycl::event args_ev = dpctl::utils::keep_args_alive( exec_q, {matrixA, matrixB, resultC}, {gemm_ev}); diff --git a/dpnp/backend/extensions/blas/gemm_batch.cpp b/dpnp/backend/extensions/blas/gemm_batch.cpp index 1e210aede9fa..95f5a1aaf328 100644 --- a/dpnp/backend/extensions/blas/gemm_batch.cpp +++ b/dpnp/backend/extensions/blas/gemm_batch.cpp @@ -60,9 +60,7 @@ typedef sycl::event (*gemm_batch_impl_fn_ptr_t)( const char *, const char *, char *, -#if !defined(USE_ONEMATH_CUBLAS) const bool, -#endif // !USE_ONEMATH_CUBLAS const std::vector &); static gemm_batch_impl_fn_ptr_t @@ -85,9 +83,7 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q, const char *matrixA, const 
char *matrixB, char *resultC, -#if !defined(USE_ONEMATH_CUBLAS) const bool is_row_major, -#endif // !USE_ONEMATH_CUBLAS const std::vector &depends) { type_utils::validate_type_for_device(exec_q); @@ -112,11 +108,6 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q, Tc *c, const std::int64_t ldc, const std::int64_t stridec, const std::int64_t batch_size, const std::vector &deps) -> sycl::event { -#if defined(USE_ONEMATH_CUBLAS) - return mkl_blas::column_major::gemm_batch( - q, transA, transB, m, n, k, alpha, a, lda, stridea, b, ldb, - strideb, beta, c, ldc, stridec, batch_size, deps); -#else if (is_row_major) { return mkl_blas::row_major::gemm_batch( q, transA, transB, m, n, k, alpha, a, lda, stridea, b, ldb, @@ -127,7 +118,6 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q, q, transA, transB, m, n, k, alpha, a, lda, stridea, b, ldb, strideb, beta, c, ldc, stridec, batch_size, deps); } -#endif // USE_ONEMATH_CUBLAS }; gemm_batch_event = gemm_batch_func( exec_q, @@ -317,7 +307,7 @@ std::tuple // cuBLAS supports only column-major storage #if defined(USE_ONEMATH_CUBLAS) - const bool is_row_major = false; + constexpr bool is_row_major = false; transA = A_base_is_c_contig ? 
oneapi::mkl::transpose::T : oneapi::mkl::transpose::N; @@ -396,17 +386,10 @@ std::tuple const char *b_typeless_ptr = matrixB.get_data(); char *r_typeless_ptr = resultC.get_data(); -#if defined(USE_ONEMATH_CUBLAS) - sycl::event gemm_batch_ev = - gemm_batch_fn(exec_q, m, n, k, batch_size, lda, ldb, ldc, stridea, - strideb, stridec, transA, transB, a_typeless_ptr, - b_typeless_ptr, r_typeless_ptr, depends); -#else sycl::event gemm_batch_ev = gemm_batch_fn(exec_q, m, n, k, batch_size, lda, ldb, ldc, stridea, strideb, stridec, transA, transB, a_typeless_ptr, b_typeless_ptr, r_typeless_ptr, is_row_major, depends); -#endif // USE_ONEMATH_CUBLAS sycl::event args_ev = dpctl::utils::keep_args_alive( exec_q, {matrixA, matrixB, resultC}, {gemm_batch_ev}); diff --git a/dpnp/backend/extensions/blas/gemv.cpp b/dpnp/backend/extensions/blas/gemv.cpp index 91057893aa5f..29bce7a10990 100644 --- a/dpnp/backend/extensions/blas/gemv.cpp +++ b/dpnp/backend/extensions/blas/gemv.cpp @@ -53,9 +53,7 @@ typedef sycl::event (*gemv_impl_fn_ptr_t)(sycl::queue &, const std::int64_t, char *, const std::int64_t, -#if !defined(USE_ONEMATH_CUBLAS) const bool, -#endif // !USE_ONEMATH_CUBLAS const std::vector &); static gemv_impl_fn_ptr_t gemv_dispatch_vector[dpctl_td_ns::num_types]; @@ -71,9 +69,7 @@ static sycl::event gemv_impl(sycl::queue &exec_q, const std::int64_t incx, char *vectorY, const std::int64_t incy, -#if !defined(USE_ONEMATH_CUBLAS) const bool is_row_major, -#endif // !USE_ONEMATH_CUBLAS const std::vector &depends) { type_utils::validate_type_for_device(exec_q); @@ -93,10 +89,6 @@ static sycl::event gemv_impl(sycl::queue &exec_q, const std::int64_t lda, const T *x, const std::int64_t incx, T beta, T *y, const std::int64_t incy, const std::vector &deps) -> sycl::event { -#if defined(USE_ONEMATH_CUBLAS) - return mkl_blas::column_major::gemv(q, transA, m, n, alpha, a, lda, - x, incx, beta, y, incy, deps); -#else if (is_row_major) { return mkl_blas::row_major::gemv(q, transA, m, n, alpha, a, 
lda, x, incx, beta, y, incy, deps); @@ -106,7 +98,6 @@ static sycl::event gemv_impl(sycl::queue &exec_q, lda, x, incx, beta, y, incy, deps); } -#endif // USE_ONEMATH_CUBLAS }; gemv_event = gemv_func( exec_q, @@ -196,7 +187,7 @@ std::pair // cuBLAS supports only column-major storage #if defined(USE_ONEMATH_CUBLAS) - const bool is_row_major = false; + constexpr bool is_row_major = false; std::int64_t m; std::int64_t n; @@ -304,15 +295,9 @@ std::pair y_typeless_ptr -= (y_shape[0] - 1) * std::abs(incy) * y_elemsize; } -#if defined(USE_ONEMATH_CUBLAS) - sycl::event gemv_ev = - gemv_fn(exec_q, transA, m, n, a_typeless_ptr, lda, x_typeless_ptr, incx, - y_typeless_ptr, incy, depends); -#else sycl::event gemv_ev = gemv_fn(exec_q, transA, m, n, a_typeless_ptr, lda, x_typeless_ptr, incx, y_typeless_ptr, incy, is_row_major, depends); -#endif // USE_ONEMATH_CUBLAS sycl::event args_ev = dpctl::utils::keep_args_alive( exec_q, {matrixA, vectorX, vectorY}, {gemv_ev}); From afd5c6d65ae751265f670e5f6fd2c2ac8583381d Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad <120411540+vtavana@users.noreply.github.com> Date: Sat, 12 Jul 2025 12:26:52 -0500 Subject: [PATCH 7/7] using `syrk` for performing special cases of matrix multiplication (#2509) In this PR, the `syrk` routines from oneMKL is used to perform a rank-k update which is used for a specialized matrix multiplication where the result is a symmetric matrix. 
--- CHANGELOG.md | 1 + dpnp/backend/extensions/blas/CMakeLists.txt | 2 + dpnp/backend/extensions/blas/blas_py.cpp | 22 +- dpnp/backend/extensions/blas/dot_common.hpp | 3 +- dpnp/backend/extensions/blas/gemm.cpp | 11 +- dpnp/backend/extensions/blas/gemm_batch.cpp | 3 +- dpnp/backend/extensions/blas/gemv.cpp | 49 ++- dpnp/backend/extensions/blas/gemv.hpp | 1 - dpnp/backend/extensions/blas/syrk.cpp | 357 ++++++++++++++++++ dpnp/backend/extensions/blas/syrk.hpp | 42 +++ dpnp/backend/extensions/blas/types_matrix.hpp | 25 ++ .../extensions/lapack/evd_batch_common.hpp | 3 +- dpnp/backend/extensions/lapack/evd_common.hpp | 3 +- dpnp/dpnp_utils/dpnp_utils_linearalgebra.py | 75 +++- dpnp/tests/test_product.py | 113 ++++-- dpnp/tests/test_sycl_queue.py | 79 ++-- dpnp/tests/test_usm_type.py | 79 ++-- 17 files changed, 703 insertions(+), 165 deletions(-) create mode 100644 dpnp/backend/extensions/blas/syrk.cpp create mode 100644 dpnp/backend/extensions/blas/syrk.hpp diff --git a/CHANGELOG.md b/CHANGELOG.md index a8d4adca875a..613338e8d3bc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Added `--target-cuda[=ARCH]` option to replace the deprecated `--target=cuda`, allowing users to build for CUDA devices with optional architecture selection using [CodePlay oneAPI plug-in](https://developer.codeplay.com/products/oneapi/nvidia/home/) [#2478](https://github.com/IntelPython/dpnp/pull/2478) * Added several new `pre-commit` rules, including protection against direct commits to master/maintenance branches [#2500](https://github.com/IntelPython/dpnp/pull/2500) * Added implementation of `dpnp.ndarray.view` method [#2520](https://github.com/IntelPython/dpnp/pull/2520) +* Added a new backend routine `syrk` from oneMKL to perform symmetric rank-k update which is used for a specialized matrix multiplication where the result is a symmetric matrix 
[#2509](https://github.com/IntelPython/dpnp/pull/2509) ### Changed diff --git a/dpnp/backend/extensions/blas/CMakeLists.txt b/dpnp/backend/extensions/blas/CMakeLists.txt index 24b8457ffebc..3f81253169d1 100644 --- a/dpnp/backend/extensions/blas/CMakeLists.txt +++ b/dpnp/backend/extensions/blas/CMakeLists.txt @@ -30,6 +30,7 @@ set(_module_src ${CMAKE_CURRENT_SOURCE_DIR}/gemm.cpp ${CMAKE_CURRENT_SOURCE_DIR}/gemm_batch.cpp ${CMAKE_CURRENT_SOURCE_DIR}/gemv.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/syrk.cpp ) pybind11_add_module(${python_module_name} MODULE ${_module_src}) @@ -61,6 +62,7 @@ set_target_properties(${python_module_name} PROPERTIES CMAKE_POSITION_INDEPENDEN target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../include) target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../src) +target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../common) target_include_directories(${python_module_name} PUBLIC ${Dpctl_INCLUDE_DIRS}) target_include_directories(${python_module_name} PUBLIC ${Dpctl_TENSOR_INCLUDE_DIR}) diff --git a/dpnp/backend/extensions/blas/blas_py.cpp b/dpnp/backend/extensions/blas/blas_py.cpp index 3393315ffe19..850c1e784c7b 100644 --- a/dpnp/backend/extensions/blas/blas_py.cpp +++ b/dpnp/backend/extensions/blas/blas_py.cpp @@ -36,6 +36,7 @@ #include "dotu.hpp" #include "gemm.hpp" #include "gemv.hpp" +#include "syrk.hpp" namespace blas_ns = dpnp::extensions::blas; namespace py = pybind11; @@ -48,6 +49,7 @@ void init_dispatch_vectors_tables(void) blas_ns::init_gemm_batch_dispatch_table(); blas_ns::init_gemm_dispatch_table(); blas_ns::init_gemv_dispatch_vector(); + blas_ns::init_syrk_dispatch_vector(); } static dot_impl_fn_ptr_t dot_dispatch_vector[dpctl_td_ns::num_types]; @@ -73,7 +75,7 @@ PYBIND11_MODULE(_blas_impl, m) }; m.def("_dot", dot_pyapi, - "Call `dot` from OneMKL BLAS library to compute " + "Call `dot` from oneMKL BLAS library to compute " "the dot 
product of two real-valued vectors.", py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"), py::arg("result"), py::arg("depends") = py::list()); @@ -91,7 +93,7 @@ PYBIND11_MODULE(_blas_impl, m) }; m.def("_dotc", dotc_pyapi, - "Call `dotc` from OneMKL BLAS library to compute " + "Call `dotc` from oneMKL BLAS library to compute " "the dot product of two complex vectors, " "conjugating the first vector.", py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"), @@ -110,7 +112,7 @@ PYBIND11_MODULE(_blas_impl, m) }; m.def("_dotu", dotu_pyapi, - "Call `dotu` from OneMKL BLAS library to compute " + "Call `dotu` from oneMKL BLAS library to compute " "the dot product of two complex vectors.", py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"), py::arg("result"), py::arg("depends") = py::list()); @@ -118,7 +120,7 @@ PYBIND11_MODULE(_blas_impl, m) { m.def("_gemm", &blas_ns::gemm, - "Call `gemm` from OneMKL BLAS library to compute " + "Call `gemm` from oneMKL BLAS library to compute " "the matrix-matrix product with 2-D matrices.", py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"), py::arg("resultC"), py::arg("depends") = py::list()); @@ -126,7 +128,7 @@ PYBIND11_MODULE(_blas_impl, m) { m.def("_gemm_batch", &blas_ns::gemm_batch, - "Call `gemm_batch` from OneMKL BLAS library to compute " + "Call `gemm_batch` from oneMKL BLAS library to compute " "the matrix-matrix product for a batch of 2-D matrices.", py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"), py::arg("resultC"), py::arg("depends") = py::list()); @@ -134,13 +136,21 @@ PYBIND11_MODULE(_blas_impl, m) { m.def("_gemv", &blas_ns::gemv, - "Call `gemv` from OneMKL BLAS library to compute " + "Call `gemv` from oneMKL BLAS library to compute " "the matrix-vector product with a general matrix.", py::arg("sycl_queue"), py::arg("matrixA"), py::arg("vectorX"), py::arg("vectorY"), py::arg("transpose"), py::arg("depends") = py::list()); } + { + m.def("_syrk", &blas_ns::syrk, + "Call 
`syrk` from oneMKL BLAS library to compute " + "the symmetric rank-k update of a matrix.", + py::arg("sycl_queue"), py::arg("matrixA"), py::arg("resultC"), + py::arg("depends") = py::list()); + } + { m.def( "_using_onemath", diff --git a/dpnp/backend/extensions/blas/dot_common.hpp b/dpnp/backend/extensions/blas/dot_common.hpp index fb9a1f078c53..169421a2464c 100644 --- a/dpnp/backend/extensions/blas/dot_common.hpp +++ b/dpnp/backend/extensions/blas/dot_common.hpp @@ -128,7 +128,8 @@ std::pair dot_impl_fn_ptr_t dot_fn = dot_dispatch_vector[type_id]; if (dot_fn == nullptr) { throw py::value_error( - "Types of input vectors and result array are mismatched."); + "No dot implementation is available for the specified data type " + "of the input and output arrays."); } char *x_typeless_ptr = vectorX.get_data(); diff --git a/dpnp/backend/extensions/blas/gemm.cpp b/dpnp/backend/extensions/blas/gemm.cpp index c343c232b7af..5a411d94ba1f 100644 --- a/dpnp/backend/extensions/blas/gemm.cpp +++ b/dpnp/backend/extensions/blas/gemm.cpp @@ -119,8 +119,7 @@ static sycl::event gemm_impl(sycl::queue &exec_q, Tab(1), // Scaling factor for the product of matrices A and B. a, // Pointer to matrix A. lda, // Leading dimension of matrix A, which is the - // stride between successive rows (for row major - // layout). + // stride between successive rows (for row major layout). b, // Pointer to matrix B. ldb, // Leading dimension of matrix B, similar to lda. Tab(0), // Scaling factor for matrix C. 
@@ -158,7 +157,8 @@ std::tuple const int resultC_nd = resultC.get_ndim(); if ((matrixA_nd != 2) || (matrixB_nd != 2) || (resultC_nd != 2)) { - throw py::value_error("Input matrices must be two-dimensional."); + throw py::value_error( + "Input and output matrices must be two-dimensional."); } auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); @@ -276,6 +276,8 @@ std::tuple } } else { + // both A and B are f_contig so using column-major gemm and + // no transpose is needed transA = oneapi::mkl::transpose::N; transB = oneapi::mkl::transpose::N; lda = m; @@ -303,7 +305,8 @@ std::tuple gemm_dispatch_table[matrixAB_type_id][resultC_type_id]; if (gemm_fn == nullptr) { throw py::value_error( - "Types of input matrices and result matrix are mismatched."); + "No gemm implementation is available for the specified data type " + "of the input and output arrays."); } const char *a_typeless_ptr = matrixA.get_data(); diff --git a/dpnp/backend/extensions/blas/gemm_batch.cpp b/dpnp/backend/extensions/blas/gemm_batch.cpp index 95f5a1aaf328..f329f8f4be58 100644 --- a/dpnp/backend/extensions/blas/gemm_batch.cpp +++ b/dpnp/backend/extensions/blas/gemm_batch.cpp @@ -379,7 +379,8 @@ std::tuple gemm_batch_dispatch_table[matrixAB_type_id][resultC_type_id]; if (gemm_batch_fn == nullptr) { throw py::value_error( - "Types of input matrices and result matrix are mismatched."); + "No gemm_batch implementation is available for the specified data " + "type of the input and output arrays."); } const char *a_typeless_ptr = matrixA.get_data(); diff --git a/dpnp/backend/extensions/blas/gemv.cpp b/dpnp/backend/extensions/blas/gemv.cpp index 29bce7a10990..37da858c8e69 100644 --- a/dpnp/backend/extensions/blas/gemv.cpp +++ b/dpnp/backend/extensions/blas/gemv.cpp @@ -109,8 +109,7 @@ static sycl::event gemv_impl(sycl::queue &exec_q, T(1), // Scaling factor for the matrix-vector product. a, // Pointer to the input matrix A. 
lda, // Leading dimension of matrix A, which is the - // stride between successive rows (for row major - // layout). + // stride between successive rows (for row major layout). x, // Pointer to the input vector x. incx, // The stride of vector x. T(0), // Scaling factor for vector y. @@ -181,6 +180,26 @@ std::pair const py::ssize_t *a_shape = matrixA.get_shape_raw(); const py::ssize_t *x_shape = vectorX.get_shape_raw(); const py::ssize_t *y_shape = vectorY.get_shape_raw(); + if (transpose) { + if (a_shape[0] != x_shape[0]) { + throw py::value_error("The number of rows in A must be equal to " + "the number of elements in X."); + } + if (a_shape[1] != y_shape[0]) { + throw py::value_error("The number of columns in A must be equal to " + "the number of elements in Y."); + } + } + else { + if (a_shape[1] != x_shape[0]) { + throw py::value_error("The number of columns in A must be equal to " + "the number of elements in X."); + } + if (a_shape[0] != y_shape[0]) { + throw py::value_error("The number of rows in A must be equal to " + "the number of elements in Y."); + } + } oneapi::mkl::transpose transA; std::size_t src_nelems; @@ -234,27 +253,6 @@ std::pair } #endif // USE_ONEMATH_CUBLAS - if (transpose) { - if (a_shape[0] != x_shape[0]) { - throw py::value_error("The number of rows in A must be equal to " - "the number of elements in X."); - } - if (a_shape[1] != y_shape[0]) { - throw py::value_error("The number of columns in A must be equal to " - "the number of elements in Y."); - } - } - else { - if (a_shape[1] != x_shape[0]) { - throw py::value_error("The number of columns in A must be equal to " - "the number of elements in X."); - } - if (a_shape[0] != y_shape[0]) { - throw py::value_error("The number of rows in A must be equal to " - "the number of elements in Y."); - } - } - const std::int64_t lda = is_row_major ? 
n : m; dpctl::tensor::validation::CheckWritable::throw_if_not_writable(vectorY); dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(vectorY, @@ -275,10 +273,11 @@ std::pair gemv_impl_fn_ptr_t gemv_fn = gemv_dispatch_vector[type_id]; if (gemv_fn == nullptr) { throw py::value_error( - "Types of input arrays and result array are mismatched."); + "No gemv implementation is available for the specified data type " + "of the input and output arrays."); } - char *a_typeless_ptr = matrixA.get_data(); + const char *a_typeless_ptr = matrixA.get_data(); char *x_typeless_ptr = vectorX.get_data(); char *y_typeless_ptr = vectorY.get_data(); diff --git a/dpnp/backend/extensions/blas/gemv.hpp b/dpnp/backend/extensions/blas/gemv.hpp index 88e9f9c5c6f0..094cdafdc483 100644 --- a/dpnp/backend/extensions/blas/gemv.hpp +++ b/dpnp/backend/extensions/blas/gemv.hpp @@ -41,5 +41,4 @@ extern std::pair const std::vector &depends); extern void init_gemv_dispatch_vector(void); -extern void init_gemv_batch_dispatch_vector(void); } // namespace dpnp::extensions::blas diff --git a/dpnp/backend/extensions/blas/syrk.cpp b/dpnp/backend/extensions/blas/syrk.cpp new file mode 100644 index 000000000000..d432a810ce6b --- /dev/null +++ b/dpnp/backend/extensions/blas/syrk.cpp @@ -0,0 +1,357 @@ +//***************************************************************************** +// Copyright (c) 2025, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#include +#include + +#include + +#include "ext/common.hpp" + +// dpctl tensor headers +#include "utils/memory_overlap.hpp" +#include "utils/output_validation.hpp" +#include "utils/type_utils.hpp" + +#include "syrk.hpp" +#include "types_matrix.hpp" + +#include "dpnp_utils.hpp" + +using ext::common::Align; + +namespace dpnp::extensions::blas +{ +namespace mkl_blas = oneapi::mkl::blas; +namespace py = pybind11; +namespace type_utils = dpctl::tensor::type_utils; + +typedef sycl::event (*syrk_impl_fn_ptr_t)(sycl::queue &, + const oneapi::mkl::transpose, + const std::int64_t, + const std::int64_t, + const char *, + const std::int64_t, + char *, + const std::int64_t, + const bool, + const std::vector &); + +static syrk_impl_fn_ptr_t syrk_dispatch_vector[dpctl_td_ns::num_types]; + +template +constexpr void copy_to_lower_triangle(T *res, + const std::size_t i, + const std::size_t j, + const std::int64_t ldc, + const std::size_t n, + const bool is_row_major) +{ + if (i < n && j < n && i > j) { + // result from row_major::syrk is row major and result from + // column_major::syrk is column major, so 
copying upper + // triangle to lower triangle is different for each case + if (is_row_major) { + res[i * ldc + j] = res[j * ldc + i]; + } + else { + res[j * ldc + i] = res[i * ldc + j]; + } + } +} + +template +class copy_kernel; + +template +void submit_copy_kernel(T *res, + const std::int64_t ldc, + const std::size_t n, + const bool is_row_major, + sycl::handler &cgh) +{ + using KernelName = copy_kernel; + + if constexpr (use_wg) { + static constexpr std::size_t tile_sz = 8; + sycl::range<2> global_range(Align(n, tile_sz), Align(n, tile_sz)); + sycl::range<2> local_range(tile_sz, tile_sz); + + cgh.parallel_for( + sycl::nd_range<2>(global_range, local_range), + [=](sycl::nd_item<2> item) { + std::size_t i = item.get_global_id(0); + std::size_t j = item.get_global_id(1); + + copy_to_lower_triangle(res, i, j, n, ldc, is_row_major); + }); + } + else { + cgh.parallel_for( + sycl::range<2>{n, n}, [=](sycl::id<2> idx) { + std::size_t i = idx[0]; + std::size_t j = idx[1]; + + copy_to_lower_triangle(res, i, j, n, ldc, is_row_major); + }); + } +} + +// kernel to copy upper triangle to lower triangle +template +sycl::event run_copy(sycl::queue &exec_q, + T *res, + const std::int64_t ldc, + const std::int64_t n, + const bool is_row_major, + sycl::event &depends) +{ + const sycl::device &dev = exec_q.get_device(); + + return exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + // two separate kernels are used to have better performance compared + // to gemm on both CPU and GPU + if (dev.is_gpu()) { + submit_copy_kernel(res, ldc, n, is_row_major, cgh); + } + else { + assert(dev.is_cpu()); + submit_copy_kernel(res, ldc, n, is_row_major, cgh); + } + }); +} + +template +static sycl::event syrk_impl(sycl::queue &exec_q, + const oneapi::mkl::transpose transA, + const std::int64_t n, + const std::int64_t k, + const char *matrixA, + const std::int64_t lda, + char *resultC, + const std::int64_t ldc, + const bool is_row_major, + const std::vector &depends) +{ + 
type_utils::validate_type_for_device(exec_q); + + const T *a = reinterpret_cast(matrixA); + T *res = reinterpret_cast(resultC); + + std::stringstream error_msg; + bool is_exception_caught = false; + + sycl::event syrk_event; + try { + auto syrk_func = + [&](sycl::queue &q, oneapi::mkl::uplo upper_lower, + oneapi::mkl::transpose transA, const std::int64_t n, + const std::int64_t k, T alpha, const T *a, + const std::int64_t lda, T beta, T *c, const std::int64_t ldc, + const std::vector &deps) -> sycl::event { + if (is_row_major) { + return mkl_blas::row_major::syrk(q, upper_lower, transA, n, k, + alpha, a, lda, beta, c, ldc, + deps); + } + else { + return mkl_blas::column_major::syrk(q, upper_lower, transA, n, + k, alpha, a, lda, beta, c, + ldc, deps); + } + }; + + // we pass beta = 0, so passing upper or lower does not matter + static constexpr auto uplo = oneapi::mkl::uplo::upper; + syrk_event = syrk_func( + exec_q, + uplo, // Specifies whether C’s data is stored in its upper + // or lower triangle + transA, // Defines the transpose operation for matrix A: + // 'N' indicates no transpose, 'T' for transpose, + // or 'C' for a conjugate transpose. + n, // Number of rows in op(A). + // Number of rows and columns in C. + k, // Number of columns in op(A). + T(1), // Scaling factor for the rank-k update. + a, // Pointer to the input matrix A. + lda, // Leading dimension of matrix A, which is the + // stride between successive rows (for row major layout). + T(0), // Scaling factor for matrix C. + res, // Pointer to output matrix c, where the result is stored. + ldc, // Leading dimension of matrix C. 
+ depends); + } catch (oneapi::mkl::exception const &e) { + error_msg << "Unexpected MKL exception caught during syrk() " + "call:\nreason: " + << e.what(); + is_exception_caught = true; + } catch (sycl::exception const &e) { + error_msg << "Unexpected SYCL exception caught during syrk() call:\n" + << e.what(); + is_exception_caught = true; + } + + if (is_exception_caught) // an unexpected error occurs + { + throw std::runtime_error(error_msg.str()); + } + + return run_copy(exec_q, res, ldc, n, is_row_major, syrk_event); +} + +std::pair + syrk(sycl::queue &exec_q, + const dpctl::tensor::usm_ndarray &matrixA, + const dpctl::tensor::usm_ndarray &resultC, + const std::vector &depends) +{ + const int matrixA_nd = matrixA.get_ndim(); + const int resultC_nd = resultC.get_ndim(); + + if ((matrixA_nd != 2) || (resultC_nd != 2)) { + throw py::value_error("The given arrays have incorrect dimensions."); + } + + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(matrixA, resultC)) { + throw py::value_error("Input and output matrices are overlapping " + "segments of memory"); + } + + if (!dpctl::utils::queues_are_compatible( + exec_q, {matrixA.get_queue(), resultC.get_queue()})) + { + throw py::value_error( + "USM allocations are not compatible with the execution queue."); + } + + const py::ssize_t *a_shape = matrixA.get_shape_raw(); + const py::ssize_t *c_shape = resultC.get_shape_raw(); + if (c_shape[0] != c_shape[1]) { + throw py::value_error("The output matrix should be square."); + } + if (a_shape[0] != c_shape[0]) { + throw py::value_error("The number of rows in A must be equal to " + "the number of rows in result array."); + } + + const bool is_matrixA_f_contig = matrixA.is_f_contiguous(); + const bool is_matrixA_c_contig = matrixA.is_c_contiguous(); + if (!is_matrixA_f_contig && !is_matrixA_c_contig) { + throw py::value_error( + "Input matrix is not c-contiguous nor f-contiguous."); + } + + oneapi::mkl::transpose transA; + std::size_t 
src_nelems; + +// cuBLAS supports only column-major storage +#if defined(USE_ONEMATH_CUBLAS) + constexpr bool is_row_major = false; + std::int64_t n; + std::int64_t k; + + if (is_matrixA_f_contig) { + transA = oneapi::mkl::transpose::N; + n = a_shape[0]; + k = a_shape[1]; + src_nelems = n * n; + } + else { + transA = oneapi::mkl::transpose::T; + k = a_shape[0]; + n = a_shape[1]; + src_nelems = k * k; + } +#else + bool is_row_major = true; + if (is_matrixA_f_contig) { + is_row_major = false; + } + + transA = oneapi::mkl::transpose::N; + const std::int64_t n = a_shape[0]; + const std::int64_t k = a_shape[1]; + src_nelems = n * n; +#endif // USE_ONEMATH_CUBLAS + + const std::int64_t lda = is_row_major ? k : n; + const std::int64_t ldc = n; + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(resultC); + dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(resultC, + src_nelems); + + const int matrixA_typenum = matrixA.get_typenum(); + const int resultC_typenum = resultC.get_typenum(); + if (matrixA_typenum != resultC_typenum) { + throw py::value_error("Given arrays must be of the same type."); + } + + auto array_types = dpctl_td_ns::usm_ndarray_types(); + const int type_id = array_types.typenum_to_lookup_id(matrixA_typenum); + syrk_impl_fn_ptr_t syrk_fn = syrk_dispatch_vector[type_id]; + if (syrk_fn == nullptr) { + throw py::value_error("No syrk implementation is available for the " + "specified data type " + "of the input and output arrays."); + } + + const char *a_typeless_ptr = matrixA.get_data(); + char *r_typeless_ptr = resultC.get_data(); + + sycl::event syrk_ev = syrk_fn(exec_q, transA, n, k, a_typeless_ptr, lda, + r_typeless_ptr, ldc, is_row_major, depends); + + sycl::event args_ev = + dpctl::utils::keep_args_alive(exec_q, {matrixA, resultC}, {syrk_ev}); + + return std::make_pair(args_ev, syrk_ev); +} + +template +struct SyrkContigFactory +{ + fnT get() + { + if constexpr (types::SyrkTypePairSupportFactory::is_defined) { + return syrk_impl; 
+ } + else { + return nullptr; + } + } +}; + +void init_syrk_dispatch_vector(void) +{ + dpctl_td_ns::DispatchVectorBuilder + contig; + contig.populate_dispatch_vector(syrk_dispatch_vector); +} +} // namespace dpnp::extensions::blas diff --git a/dpnp/backend/extensions/blas/syrk.hpp b/dpnp/backend/extensions/blas/syrk.hpp new file mode 100644 index 000000000000..7fd38a9abdb7 --- /dev/null +++ b/dpnp/backend/extensions/blas/syrk.hpp @@ -0,0 +1,42 @@ +//***************************************************************************** +// Copyright (c) 2025, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//***************************************************************************** + +#pragma once + +#include +#include + +#include + +namespace dpnp::extensions::blas +{ +extern std::pair + syrk(sycl::queue &exec_q, + const dpctl::tensor::usm_ndarray &matrixA, + const dpctl::tensor::usm_ndarray &resultC, + const std::vector &depends); + +extern void init_syrk_dispatch_vector(void); +} // namespace dpnp::extensions::blas diff --git a/dpnp/backend/extensions/blas/types_matrix.hpp b/dpnp/backend/extensions/blas/types_matrix.hpp index 7590364737bf..3d70255be313 100644 --- a/dpnp/backend/extensions/blas/types_matrix.hpp +++ b/dpnp/backend/extensions/blas/types_matrix.hpp @@ -186,4 +186,29 @@ struct GemvTypePairSupportFactory // fall-through dpctl_td_ns::NotDefinedEntry>::is_defined; }; + +/** + * @brief A factory to define pairs of supported types for which + * MKL BLAS library provides support in oneapi::mkl::blas::syrk + * function. + * + * @tparam T Type of input and output arrays. + */ +template +struct SyrkTypePairSupportFactory +{ + static constexpr bool is_defined = std::disjunction< + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + T, + std::complex>, + dpctl_td_ns::TypePairDefinedEntry, + T, + std::complex>, + // fall-through + dpctl_td_ns::NotDefinedEntry>::is_defined; +}; } // namespace dpnp::extensions::blas::types diff --git a/dpnp/backend/extensions/lapack/evd_batch_common.hpp b/dpnp/backend/extensions/lapack/evd_batch_common.hpp index 9610d6aa568a..3545db01458c 100644 --- a/dpnp/backend/extensions/lapack/evd_batch_common.hpp +++ b/dpnp/backend/extensions/lapack/evd_batch_common.hpp @@ -97,7 +97,8 @@ std::pair evd_batch_dispatch_table[eig_vecs_type_id][eig_vals_type_id]; if (evd_batch_fn == nullptr) { throw py::value_error( - "Types of input vectors and result array are mismatched."); + "No evd_batch implementation is available for the specified data " + "type of the input and output arrays."); 
} char *eig_vecs_data = eig_vecs.get_data(); diff --git a/dpnp/backend/extensions/lapack/evd_common.hpp b/dpnp/backend/extensions/lapack/evd_common.hpp index 5503d8f82052..3964943c5305 100644 --- a/dpnp/backend/extensions/lapack/evd_common.hpp +++ b/dpnp/backend/extensions/lapack/evd_common.hpp @@ -91,7 +91,8 @@ std::pair evd_dispatch_table[eig_vecs_type_id][eig_vals_type_id]; if (evd_fn == nullptr) { throw py::value_error( - "Types of input vectors and result array are mismatched."); + "No evd implementation is available for the specified data type " + "of the input and output arrays."); } char *eig_vecs_data = eig_vecs.get_data(); diff --git a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py index c80332ea8ebd..8e21e1ca4dac 100644 --- a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py +++ b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py @@ -50,7 +50,7 @@ ] -def _compute_res_dtype(*arrays, sycl_queue, dtype=None, out=None, casting="no"): +def _compute_res_dtype(*arrays, dtype=None, out=None, casting="no"): """ Determines the output array data type. If `dtype` and `out` are ``None``, the output array data type of the @@ -70,8 +70,6 @@ def _compute_res_dtype(*arrays, sycl_queue, dtype=None, out=None, casting="no"): If not ``None``, data type of the output array. casting : {"no", "equiv", "safe", "same_kind", "unsafe"}, optional Controls what kind of data casting may occur. - sycl_queue : {SyclQueue} - A SYCL queue to use for determining default floating point datat type. Returns ------- @@ -334,7 +332,7 @@ def _gemm_matmul(exec_q, x1, x2, res): def _gemm_special_case(x1, x2, res_dtype, call_flag): """ `gemm` and `gemm_batch` support these special cases of data types - while `gemv` does not. + while `gemv` or `syrk` do not. """ @@ -520,6 +518,29 @@ def _get_signature(func): return signature, distinct_core +def _is_syrk_compatible(x1, x2): + """ + Check to see if `syrk` can be called instead of `gemm`. 
+ Input arrays have already been validated to be 2-dimensional. + + """ + # Must share data (same base buffer) + if dpnp.get_usm_ndarray(x1)._pointer != dpnp.get_usm_ndarray(x2)._pointer: + return False + + # Result must be square + if x1.shape[0] != x2.shape[1]: + return False + + # Strides must match transpose pattern + x1_strides = x1.strides + x2_strides = x2.strides + if x1_strides[0] != x2_strides[1] or x1_strides[1] != x2_strides[0]: + return False + + return True + + def _shape_error(shape1, shape2, func, err_msg): """Validate the shapes of input and output arrays.""" @@ -765,9 +786,7 @@ def dpnp_dot(a, b, /, out=None, *, casting="same_kind", conjugate=False): _validate_out_array(out, exec_q) # Determine the appropriate data types - res_dtype = _compute_res_dtype( - a, b, out=out, casting=casting, sycl_queue=exec_q - ) + res_dtype = _compute_res_dtype(a, b, out=out, casting=casting) result = _create_result_array( a, b, out, (), res_dtype, res_usm_type, exec_q @@ -918,7 +937,7 @@ def dpnp_multiplication( # Determine the appropriate data types res_dtype = _compute_res_dtype( - x1, x2, dtype=dtype, out=out, casting=casting, sycl_queue=exec_q + x1, x2, dtype=dtype, out=out, casting=casting ) call_flag = None @@ -928,8 +947,6 @@ def dpnp_multiplication( x1_is_2D, x1_is_1D, x1_base_is_1D = _define_dim_flags(x1, axis=-1) x2_is_2D, x2_is_1D, x2_base_is_1D = _define_dim_flags(x2, axis=-2) - # TODO: investigate usage of syrk function from BLAS in - # case of a.T @ a and a @ a.T to gain performance. 
if numpy.prod(result_shape) == 0: res_shape = result_shape elif x1_shape[-1] == 1: @@ -966,6 +983,14 @@ def dpnp_multiplication( x1 = dpnp.reshape(x1, x1_shape[-2:]) x2 = dpnp.reshape(x2, x2_shape[-2:]) res_shape = (x1_shape[-2], x2_shape[-1]) + if _is_syrk_compatible(x1, x2): + call_flag = "syrk" + res_dtype_orig = res_dtype + # for exact dtypes, use syrk implementation unlike general approach + # where dpctl implementation is used for exact dtypes for better + # performance + if not dpnp.issubdtype(res_dtype, dpnp.inexact): + res_dtype = dpnp.default_float_type(x1.device) elif x1_base_is_1D: # TODO: implement gemv_batch to use it here with transpose call_flag = "gemm_batch" @@ -1046,12 +1071,13 @@ def dpnp_multiplication( dtype=res_dtype, order=res_order, ) - x2 = _copy_array( - x2, - copy_flag=not x2_contig_flag, - dtype=res_dtype, - order=res_order, - ) + if call_flag != "syrk": + x2 = _copy_array( + x2, + copy_flag=not x2_contig_flag, + dtype=res_dtype, + order=res_order, + ) if call_flag == "gemv": if transpose: @@ -1062,7 +1088,6 @@ def dpnp_multiplication( x_usm = dpnp.get_usm_ndarray(x2) _manager = dpu.SequentialOrderManager[exec_q] - ht_ev, gemv_ev = bi._gemv( exec_q, a_usm, @@ -1072,6 +1097,15 @@ def dpnp_multiplication( depends=_manager.submitted_events, ) _manager.add_event_pair(ht_ev, gemv_ev) + elif call_flag == "syrk": + _manager = dpu.SequentialOrderManager[exec_q] + ht_ev, gemv_ev = bi._syrk( + exec_q, + dpnp.get_usm_ndarray(x1), + dpnp.get_usm_ndarray(result), + depends=_manager.submitted_events, + ) + _manager.add_event_pair(ht_ev, gemv_ev) elif call_flag == "gemm": result = _gemm_matmul(exec_q, x1, x2, result) else: @@ -1101,6 +1135,9 @@ def dpnp_multiplication( elif res_shape != result_shape: result = dpnp.reshape(result, result_shape) + if call_flag == "syrk" and res_dtype_orig != res_dtype: + result = result.astype(res_dtype_orig) + if out is None: if axes is not None: # Move the data back to the appropriate axes of the result array @@ 
-1217,7 +1254,7 @@ def dpnp_vecdot( if axis is not None: raise TypeError("cannot specify both `axis` and `axes`.") - axes_x1, axes_x2, axes_res = _validate_axes(x1, x2, axes, "vecdot") + axes_x1, axes_x2, _ = _validate_axes(x1, x2, axes, "vecdot") # Move the axes that are going to be used in dot product, # to the end of "x1" and "x2" @@ -1241,7 +1278,7 @@ def dpnp_vecdot( # Determine the appropriate data types res_dtype = _compute_res_dtype( - x1, x2, dtype=dtype, out=out, casting=casting, sycl_queue=exec_q + x1, x2, dtype=dtype, out=out, casting=casting ) _, x1_is_1D, _ = _define_dim_flags(x1, axis=-1) diff --git a/dpnp/tests/test_product.py b/dpnp/tests/test_product.py index 983fd2161d2d..9963c85f737c 100644 --- a/dpnp/tests/test_product.py +++ b/dpnp/tests/test_product.py @@ -25,9 +25,6 @@ class TestCross: - def setup_method(self): - numpy.random.seed(42) - @pytest.mark.parametrize("axis", [None, 0]) @pytest.mark.parametrize("axisc", [-1, 0]) @pytest.mark.parametrize("axisb", [-1, 0]) @@ -182,9 +179,6 @@ def test_linalg_error(self): class TestDot: - def setup_method(self): - numpy.random.seed(42) - @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) def test_ones(self, dtype): n = 10**5 @@ -433,9 +427,6 @@ def test_out_error(self, shape1, shape2, out_shape): class TestInner: - def setup_method(self): - numpy.random.seed(42) - @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) def test_scalar(self, dtype): a = 2 @@ -599,9 +590,6 @@ def test_order(self, order): class TestMatmul: - def setup_method(self): - numpy.random.seed(42) - @pytest.mark.parametrize("dtype", _selected_dtypes) @pytest.mark.parametrize( "order1, order2", [("C", "C"), ("C", "F"), ("F", "C"), ("F", "F")] @@ -1062,17 +1050,19 @@ def test_strided_vec_mat(self, dtype, func, incx, incy, transpose): @pytest.mark.parametrize("dtype", _selected_dtypes) def test_out_order1(self, order1, order2, out_order, dtype): # test gemm with out keyword - a = generate_random_numpy_array((5, 
4), dtype, low=-5, high=5) - b = generate_random_numpy_array((4, 7), dtype, low=-5, high=5) - a = numpy.array(a, order=order1) - b = numpy.array(b, order=order2) + a = generate_random_numpy_array( + (5, 4), dtype, order=order1, low=-5, high=5 + ) + b = generate_random_numpy_array( + (4, 7), dtype, order=order2, low=-5, high=5 + ) ia, ib = dpnp.array(a), dpnp.array(b) - iout = dpnp.empty((5, 7), dtype=dtype, order=out_order) + out = numpy.empty((5, 7), dtype=dtype, order=out_order) + iout = dpnp.array(out) result = dpnp.matmul(ia, ib, out=iout) assert result is iout - out = numpy.empty((5, 7), dtype=dtype, order=out_order) expected = numpy.matmul(a, b, out=out) assert result.flags.c_contiguous == expected.flags.c_contiguous assert result.flags.f_contiguous == expected.flags.f_contiguous @@ -1185,6 +1175,75 @@ def test_special_case(self, dt_out, shape1, shape2): result = dpnp.matmul(ia, ib, out=iout) assert_dtype_allclose(result, expected) + @pytest.mark.parametrize("dt", get_all_dtypes(no_none=True)) + def test_syrk(self, dt): + a = generate_random_numpy_array((6, 9), dtype=dt, low=-5, high=5) + ia = dpnp.array(a) + + result = dpnp.matmul(ia, ia.mT) + expected = numpy.matmul(a, a.T) + assert_dtype_allclose(result, expected) + + iout = dpnp.empty(result.shape, dtype=dt) + result = dpnp.matmul(ia, ia.mT, out=iout) + assert result is iout + assert_dtype_allclose(result, expected) + + result = ia.mT @ ia + expected = a.T @ a + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize("dt", [dpnp.int32, dpnp.float32]) + def test_syrk_strided(self, dt): + a = generate_random_numpy_array((20, 30), dtype=dt) + ia = dpnp.array(a) + a = a[::2, ::2] + ia = ia[::2, ::2] + + result = dpnp.matmul(ia, ia.mT) + expected = numpy.matmul(a, a.T) + assert_dtype_allclose(result, expected) + + result = ia.mT @ ia + expected = a.T @ a + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize( + "order, out_order", + [("C", "C"), ("C", "F"), ("F", "C"), ("F", "F")], 
+ ) + def test_syrk_out_order(self, order, out_order): + a = generate_random_numpy_array((5, 4), order=order, low=-5, high=5) + out = numpy.empty((5, 5), dtype=a.dtype, order=out_order) + ia, iout = dpnp.array(a), dpnp.array(out) + + expected = numpy.matmul(a, a.T, out=out) + result = dpnp.matmul(ia, ia.mT, out=iout) + assert result is iout + assert result.flags.c_contiguous == expected.flags.c_contiguous + assert result.flags.f_contiguous == expected.flags.f_contiguous + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize("order", ["F", "C"]) + def test_syrk_order(self, order): + a = generate_random_numpy_array((4, 6), order=order, low=-5, high=5) + ia = dpnp.array(a) + expected = numpy.matmul(a, a.T) + result = dpnp.matmul(ia, ia.mT) + assert_dtype_allclose(result, expected) + + # added for coverage + def test_not_syrk(self): + a = generate_random_numpy_array((20, 20), low=-5, high=5) + ia = dpnp.array(a) + + # Result must be square + b = a.T[:, ::2] + ib = ia.mT[:, ::2] + expected = numpy.matmul(a, b) + result = dpnp.matmul(ia, ib) + assert_dtype_allclose(result, expected) + def test_bool(self): a = generate_random_numpy_array((3, 4), dtype=dpnp.bool) b = generate_random_numpy_array((4, 5), dtype=dpnp.bool) @@ -1440,9 +1499,6 @@ def test_invalid_axes(self, xp): @testing.with_requires("numpy>=2.2") class TestMatvec: - def setup_method(self): - numpy.random.seed(42) - @pytest.mark.skipif( is_win_platform() and not is_gpu_device(), reason="SAT-8073" ) @@ -1505,9 +1561,6 @@ def test_error(self, xp): class TestMultiDot: - def setup_method(self): - numpy.random.seed(70) - @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) @pytest.mark.parametrize( "shapes", @@ -1662,9 +1715,6 @@ def test_error(self): class TestTensordot: - def setup_method(self): - numpy.random.seed(87) - @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) def test_scalar(self, dtype): a = 2 @@ -1796,9 +1846,6 @@ def test_error(self): class TestVdot: - def 
setup_method(self): - numpy.random.seed(42) - @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) def test_scalar(self, dtype): a = numpy.array([3.5], dtype=dtype) @@ -1889,9 +1936,6 @@ def test_error(self): @testing.with_requires("numpy>=2.0") class TestVecdot: - def setup_method(self): - numpy.random.seed(42) - @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) @pytest.mark.parametrize( "shape1, shape2", @@ -2169,9 +2213,6 @@ def test_error(self, xp): @testing.with_requires("numpy>=2.2") class TestVecmat: - def setup_method(self): - numpy.random.seed(42) - @pytest.mark.skipif( is_win_platform() and not is_gpu_device(), reason="SAT-8073" ) diff --git a/dpnp/tests/test_sycl_queue.py b/dpnp/tests/test_sycl_queue.py index c501bcb169e6..0316c8a7510f 100644 --- a/dpnp/tests/test_sycl_queue.py +++ b/dpnp/tests/test_sycl_queue.py @@ -415,9 +415,6 @@ def test_1in_1out(func, data, device): pytest.param("ldexp", [5, 5, 5, 5, 5], [0, 1, 2, 3, 4]), pytest.param("logaddexp", [-1, 2, 5, 9], [4, -3, 2, -8]), pytest.param("logaddexp2", [-1, 2, 5, 9], [4, -3, 2, -8]), - pytest.param( - "matmul", [[1.0, 0.0], [0.0, 1.0]], [[4.0, 1.0], [1.0, 2.0]] - ), pytest.param("maximum", [2.0, 3.0, 4.0], [1.0, 5.0, 2.0]), pytest.param("minimum", [2.0, 3.0, 4.0], [1.0, 5.0, 2.0]), pytest.param( @@ -632,40 +629,50 @@ def test_bitwise_op_2in(op, device): assert_sycl_queue_equal(zy.sycl_queue, y.sycl_queue) -@pytest.mark.parametrize("device", valid_dev, ids=dev_ids) -@pytest.mark.parametrize( - "shape1, shape2", - [ - ((2, 4), (4,)), - ((4,), (4, 3)), - ((2, 4), (4, 3)), - ((2, 0), (0, 3)), - ((2, 4), (4, 0)), - ((4, 2, 3), (4, 3, 5)), - ((4, 2, 3), (4, 3, 1)), - ((4, 1, 3), (4, 3, 5)), - ((6, 7, 4, 3), (6, 7, 3, 5)), - ], - ids=[ - "((2, 4), (4,))", - "((4,), (4, 3))", - "((2, 4), (4, 3))", - "((2, 0), (0, 3))", - "((2, 4), (4, 0))", - "((4, 2, 3), (4, 3, 5))", - "((4, 2, 3), (4, 3, 1))", - "((4, 1, 3), (4, 3, 5))", - "((6, 7, 4, 3), (6, 7, 3, 5))", - ], -) -def 
test_matmul(device, shape1, shape2): - a = dpnp.arange(numpy.prod(shape1), device=device).reshape(shape1) - b = dpnp.arange(numpy.prod(shape2), device=device).reshape(shape2) - result = dpnp.matmul(a, b) +class TestMatmul: + @pytest.mark.parametrize("device", valid_dev, ids=dev_ids) + @pytest.mark.parametrize("dtype", [dpnp.int32, dpnp.float32]) + @pytest.mark.parametrize( + "shape1, shape2", + [ + ((2, 4), (4,)), + ((4,), (4, 3)), + ((2, 4), (4, 3)), + ((2, 0), (0, 3)), + ((2, 4), (4, 0)), + ((4, 2, 3), (4, 3, 5)), + ((4, 2, 3), (4, 3, 1)), + ((4, 1, 3), (4, 3, 5)), + ((6, 7, 4, 3), (6, 7, 3, 5)), + ], + ids=[ + "((2, 4), (4,))", + "((4,), (4, 3))", + "((2, 4), (4, 3))", + "((2, 0), (0, 3))", + "((2, 4), (4, 0))", + "((4, 2, 3), (4, 3, 5))", + "((4, 2, 3), (4, 3, 1))", + "((4, 1, 3), (4, 3, 5))", + "((6, 7, 4, 3), (6, 7, 3, 5))", + ], + ) + def test_matmul(self, device, dtype, shape1, shape2): + # int32 checks dpctl implementation and float32 checks oneMKL + a = dpnp.arange(numpy.prod(shape1), dtype=dtype, device=device) + b = dpnp.arange(numpy.prod(shape2), dtype=dtype, device=device) + a, b = a.reshape(shape1), b.reshape(shape2) + result = dpnp.matmul(a, b) - result_queue = result.sycl_queue - assert_sycl_queue_equal(result_queue, a.sycl_queue) - assert_sycl_queue_equal(result_queue, b.sycl_queue) + result_queue = result.sycl_queue + assert_sycl_queue_equal(result_queue, a.sycl_queue) + assert_sycl_queue_equal(result_queue, b.sycl_queue) + + @pytest.mark.parametrize("device", valid_dev, ids=dev_ids) + def test_matmul_syrk(self, device): + a = dpnp.arange(20, dtype=dpnp.float32, device=device).reshape(4, 5) + result = dpnp.matmul(a, a.mT) + assert_sycl_queue_equal(result.sycl_queue, a.sycl_queue) @pytest.mark.parametrize("device", valid_dev, ids=dev_ids) diff --git a/dpnp/tests/test_usm_type.py b/dpnp/tests/test_usm_type.py index df88071e39e5..aed316eca533 100644 --- a/dpnp/tests/test_usm_type.py +++ b/dpnp/tests/test_usm_type.py @@ -403,41 +403,52 @@ def 
test_bitwise_op_2in(op, usm_type_x, usm_type_y): assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y]) -@pytest.mark.parametrize("usm_type_x", list_of_usm_types) -@pytest.mark.parametrize("usm_type_y", list_of_usm_types) -@pytest.mark.parametrize( - "shape1, shape2", - [ - ((2, 4), (4,)), - ((4,), (4, 3)), - ((2, 4), (4, 3)), - ((2, 0), (0, 3)), - ((2, 4), (4, 0)), - ((4, 2, 3), (4, 3, 5)), - ((4, 2, 3), (4, 3, 1)), - ((4, 1, 3), (4, 3, 5)), - ((6, 7, 4, 3), (6, 7, 3, 5)), - ], - ids=[ - "((2, 4), (4,))", - "((4,), (4, 3))", - "((2, 4), (4, 3))", - "((2, 0), (0, 3))", - "((2, 4), (4, 0))", - "((4, 2, 3), (4, 3, 5))", - "((4, 2, 3), (4, 3, 1))", - "((4, 1, 3), (4, 3, 5))", - "((6, 7, 4, 3), (6, 7, 3, 5))", - ], -) -def test_matmul(usm_type_x, usm_type_y, shape1, shape2): - x = dpnp.arange(numpy.prod(shape1), usm_type=usm_type_x).reshape(shape1) - y = dpnp.arange(numpy.prod(shape2), usm_type=usm_type_y).reshape(shape2) - z = dpnp.matmul(x, y) +class TestMatmul: + @pytest.mark.parametrize("usm_type_x", list_of_usm_types) + @pytest.mark.parametrize("usm_type_y", list_of_usm_types) + @pytest.mark.parametrize("dtype", [dpnp.int32, dpnp.float32]) + @pytest.mark.parametrize( + "shape1, shape2", + [ + ((2, 4), (4,)), + ((4,), (4, 3)), + ((2, 4), (4, 3)), + ((2, 0), (0, 3)), + ((2, 4), (4, 0)), + ((4, 2, 3), (4, 3, 5)), + ((4, 2, 3), (4, 3, 1)), + ((4, 1, 3), (4, 3, 5)), + ((6, 7, 4, 3), (6, 7, 3, 5)), + ], + ids=[ + "((2, 4), (4,))", + "((4,), (4, 3))", + "((2, 4), (4, 3))", + "((2, 0), (0, 3))", + "((2, 4), (4, 0))", + "((4, 2, 3), (4, 3, 5))", + "((4, 2, 3), (4, 3, 1))", + "((4, 1, 3), (4, 3, 5))", + "((6, 7, 4, 3), (6, 7, 3, 5))", + ], + ) + def test_basic(self, usm_type_x, usm_type_y, dtype, shape1, shape2): + # int32 checks dpctl implementation and float32 checks oneMKL + x = dpnp.arange(numpy.prod(shape1), dtype=dtype, usm_type=usm_type_x) + y = dpnp.arange(numpy.prod(shape2), dtype=dtype, usm_type=usm_type_y) + x, y = x.reshape(shape1), 
y.reshape(shape2) + z = dpnp.matmul(x, y) - assert x.usm_type == usm_type_x - assert y.usm_type == usm_type_y - assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y]) + assert x.usm_type == usm_type_x + assert y.usm_type == usm_type_y + assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y]) + + @pytest.mark.parametrize("usm_type", list_of_usm_types) + def test_syrk(self, usm_type): + x = dpnp.arange(20, dtype=dpnp.float32, usm_type=usm_type).reshape(4, 5) + y = dpnp.matmul(x, x.mT) + + assert y.usm_type == usm_type @pytest.mark.parametrize("usm_type_x", list_of_usm_types)