diff --git a/.github/workflows/build-with-clang.yml b/.github/workflows/build-with-clang.yml index 498bede..727bb88 100644 --- a/.github/workflows/build-with-clang.yml +++ b/.github/workflows/build-with-clang.yml @@ -73,5 +73,5 @@ jobs: - name: Run mkl_fft tests run: | source ${{ env.ONEAPI_ROOT }}/setvars.sh - pip install scipy mkl-service pytest + pip install pytest mkl-service scipy dask pytest -s -v --pyargs mkl_fft diff --git a/.github/workflows/conda-package-cf.yml b/.github/workflows/conda-package-cf.yml index cd7282a..2b59b0a 100644 --- a/.github/workflows/conda-package-cf.yml +++ b/.github/workflows/conda-package-cf.yml @@ -132,7 +132,7 @@ jobs: - name: Install mkl_fft run: | CHANNELS="-c $GITHUB_WORKSPACE/channel ${{ env.CHANNELS }}" - conda create -n ${{ env.TEST_ENV_NAME }} python=${{ matrix.python_ver }} ${{ matrix.numpy }} $PACKAGE_NAME pytest scipy $CHANNELS + conda create -n ${{ env.TEST_ENV_NAME }} python=${{ matrix.python_ver }} ${{ matrix.numpy }} $PACKAGE_NAME pytest scipy dask $CHANNELS # Test installed packages conda list -n ${{ env.TEST_ENV_NAME }} @@ -298,7 +298,7 @@ jobs: FOR /F "tokens=* USEBACKQ" %%F IN (`python -c "%SCRIPT%"`) DO ( SET PACKAGE_VERSION=%%F ) - SET "TEST_DEPENDENCIES=pytest scipy" + SET "TEST_DEPENDENCIES=pytest scipy dask" conda install -n ${{ env.TEST_ENV_NAME }} ${{ env.PACKAGE_NAME }}=%PACKAGE_VERSION% %TEST_DEPENDENCIES% python=${{ matrix.python }} ${{ matrix.numpy }} -c ${{ env.workdir }}/channel ${{ env.CHANNELS }} - name: Report content of test environment diff --git a/.github/workflows/conda-package.yml b/.github/workflows/conda-package.yml index 48f09ff..7d92a99 100644 --- a/.github/workflows/conda-package.yml +++ b/.github/workflows/conda-package.yml @@ -140,7 +140,7 @@ jobs: - name: Install mkl_fft run: | CHANNELS="-c $GITHUB_WORKSPACE/channel ${{ env.CHANNELS }}" - conda create -n ${{ env.TEST_ENV_NAME }} $PACKAGE_NAME=${{ env.PACKAGE_VERSION }} python=${{ matrix.python }} pytest "scipy>=1.10" $CHANNELS + conda create -n ${{ env.TEST_ENV_NAME }} $PACKAGE_NAME=${{ env.PACKAGE_VERSION }} python=${{ matrix.python }} pytest dask "scipy>=1.10" $CHANNELS # Test installed packages conda list -n ${{ env.TEST_ENV_NAME }} @@ -307,7 +307,7 @@ jobs: FOR /F "tokens=* USEBACKQ" %%F IN (`python -c "%SCRIPT%"`) DO ( SET PACKAGE_VERSION=%%F ) - SET "TEST_DEPENDENCIES=pytest scipy" + SET "TEST_DEPENDENCIES=pytest scipy dask" conda install -n ${{ env.TEST_ENV_NAME }} ${{ env.PACKAGE_NAME }}=%PACKAGE_VERSION% %TEST_DEPENDENCIES% python=${{ matrix.python }} -c ${{ env.workdir }}/channel ${{ env.CHANNELS }} - name: Report content of test environment diff --git a/CHANGELOG.md b/CHANGELOG.md index 106b7c3..26edb1a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,16 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [2.0.0] (05/DD/2025) +## [dev] (MM/DD/YYYY) + +### Added +* Added a new interface for FFT module of Dask accessible through `mkl_fft.interfaces.dask_fft` [gh-???](https://github.com/IntelPython/mkl_fft/pull/???) + +### Changed + +### Fixed + +## [2.0.0] (06/DD/2025) ### Added * Added Hermitian FFT functions to SciPy interface `mkl_fft.interfaces.scipy_fft`: `hfft`, `ihfft`, `hfftn`, `ihfftn`, `hfft2`, and `ihfft2` [gh-161](https://github.com/IntelPython/mkl_fft/pull/161) diff --git a/conda-recipe-cf/meta.yaml b/conda-recipe-cf/meta.yaml index 20c1a62..bff7b24 100644 --- a/conda-recipe-cf/meta.yaml +++ b/conda-recipe-cf/meta.yaml @@ -33,11 +33,13 @@ test: requires: - pytest - scipy >=1.10 + - dask imports: - mkl_fft - mkl_fft.interfaces - mkl_fft.interfaces.numpy_fft - mkl_fft.interfaces.scipy_fft + - mkl_fft.interfaces.dask_fft about: home: http://github.com/IntelPython/mkl_fft diff --git a/conda-recipe/meta.yaml b/conda-recipe/meta.yaml index 2eeccb2..8b1e5b2 100644 --- a/conda-recipe/meta.yaml +++ b/conda-recipe/meta.yaml @@ -34,11 +34,13 @@ test: requires: - pytest - scipy >=1.10 + - dask imports: - mkl_fft - mkl_fft.interfaces - mkl_fft.interfaces.numpy_fft - mkl_fft.interfaces.scipy_fft + - mkl_fft.interfaces.dask_fft about: home: http://github.com/IntelPython/mkl_fft diff --git a/mkl_fft/interfaces/README.md b/mkl_fft/interfaces/README.md index 0e8969f..6806b39 100644 --- a/mkl_fft/interfaces/README.md +++ b/mkl_fft/interfaces/README.md @@ -1,5 +1,5 @@ # Interfaces -The `mkl_fft` package provides interfaces that serve as drop-in replacements for equivalent functions in NumPy and SciPy. +The `mkl_fft` package provides interfaces that serve as drop-in replacements for equivalent functions in NumPy, SciPy, and Dask. --- @@ -125,3 +125,43 @@ with mkl_fft.set_workers(4): y = scipy.signal.fftconvolve(a, a) # Note that Nthr:4 # MKL_VERBOSE FFT(dcbo256x128,input_strides:{0,128,1},output_strides:{0,128,1},bScale:3.05176e-05,tLim:4,unaligned_output,desc:0x563aefe86180) 187.37us CNR:OFF Dyn:1 FastMM:1 TID:0 NThr:4 ``` + +--- + +## Dask interface - `mkl_fft.interfaces.dask_fft` + +This interface is a drop-in replacement for the [`dask.fft`](https://dask.pydata.org/en/latest/array-api.html#fast-fourier-transforms) module and includes **all** the functions available there: + +* complex-to-complex FFTs: `fft`, `ifft`, `fft2`, `ifft2`, `fftn`, `ifftn`. + +* real-to-complex and complex-to-real FFTs: `rfft`, `irfft`, `rfft2`, `irfft2`, `rfftn`, `irfftn`. + +* Hermitian FFTs: `hfft`, `ihfft`. + +* Helper routines: `fft_wrap`, `fftfreq`, `rfftfreq`, `fftshift`, `ifftshift`. These routines serve as a fallback to the Dask implementation and are included for completeness. + +The following example shows how to use this interface for calculating a 2D FFT. + +```python +import numpy, dask +import mkl_fft.interfaces.dask_fft as dask_fft + +a = numpy.random.randn(128, 64) + 1j*numpy.random.randn(128, 64) +x = dask.array.from_array(a, chunks=(64, 64)) +lazy_res = dask_fft.fft(x) +mkl_res = lazy_res.compute() +np_res = numpy.fft.fft(a) +numpy.allclose(mkl_res, np_res) +# True + +# There are two chunks in this example based on the size of input array (128, 64) and chunk size (64, 64) +# to confirm that MKL FFT is called twice, turn on verbosity +import mkl +mkl.verbose(1) +# True + +mkl_res = lazy_res.compute() # MKL_VERBOSE FFT is shown twice below which means MKL FFT is called twice +# MKL_VERBOSE oneMKL 2024.0 Update 2 Patch 2 Product build 20240823 for Intel(R) 64 architecture Intel(R) Advanced Vector Extensions 512 (Intel(R) AVX-512) with support for INT8, BF16, FP16 (limited) instructions, and Intel(R) Advanced Matrix Extensions (Intel(R) AMX) with INT8 and BF16, Lnx 3.80GHz intel_thread +# MKL_VERBOSE FFT(dcfo64*64,input_strides:{0,1},output_strides:{0,1},input_distance:64,output_distance:64,bScale:0.015625,tLim:32,unaligned_input,desc:0x7fd000010e40) 432.84us CNR:OFF Dyn:1 FastMM:1 TID:0 NThr:112 +# MKL_VERBOSE FFT(dcfo64*64,input_strides:{0,1},output_strides:{0,1},input_distance:64,output_distance:64,bScale:0.015625,tLim:32,unaligned_input,desc:0x7fd480011300) 499.00us CNR:OFF Dyn:1 FastMM:1 TID:0 NThr:112 +``` diff --git a/mkl_fft/interfaces/__init__.py b/mkl_fft/interfaces/__init__.py index 1988ba8..14d3081 100644 --- a/mkl_fft/interfaces/__init__.py +++ b/mkl_fft/interfaces/__init__.py @@ -31,3 +31,10 @@ pass else: from . import scipy_fft + +try: + import dask.fft +except ImportError: + pass +else: + from . import dask_fft diff --git a/mkl_fft/interfaces/dask_fft.py b/mkl_fft/interfaces/dask_fft.py new file mode 100644 index 0000000..6dad605 --- /dev/null +++ b/mkl_fft/interfaces/dask_fft.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python +# Copyright (c) 2025, Intel Corporation +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of Intel Corporation nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from dask.array.fft import fft_wrap, fftfreq, fftshift, ifftshift, rfftfreq + +from . import numpy_fft as _numpy_fft + +__all__ = [ + "fft", + "ifft", + "fft2", + "ifft2", + "fftn", + "ifftn", + "rfft", + "irfft", + "rfft2", + "irfft2", + "rfftn", + "irfftn", + "hfft", + "ihfft", + "fftshift", + "ifftshift", + "fftfreq", + "rfftfreq", + "fft_wrap", +] + + +fft = fft_wrap(_numpy_fft.fft) +ifft = fft_wrap(_numpy_fft.ifft) +fft2 = fft_wrap(_numpy_fft.fft2) +ifft2 = fft_wrap(_numpy_fft.ifft2) +fftn = fft_wrap(_numpy_fft.fftn) +ifftn = fft_wrap(_numpy_fft.ifftn) +rfft = fft_wrap(_numpy_fft.rfft) +irfft = fft_wrap(_numpy_fft.irfft) +rfft2 = fft_wrap(_numpy_fft.rfft2) +irfft2 = fft_wrap(_numpy_fft.irfft2) +rfftn = fft_wrap(_numpy_fft.rfftn) +irfftn = fft_wrap(_numpy_fft.irfftn) +hfft = fft_wrap(_numpy_fft.hfft) +ihfft = fft_wrap(_numpy_fft.ihfft) diff --git a/mkl_fft/tests/test_interfaces.py b/mkl_fft/tests/test_interfaces.py index 226d462..ff0b3e9 100644 --- a/mkl_fft/tests/test_interfaces.py +++ b/mkl_fft/tests/test_interfaces.py @@ -34,13 +34,21 @@ except AttributeError: scipy_fft = None +try: + dask_fft = mfi.dask_fft +except AttributeError: + dask_fft = None + interfaces = [] ids = [] +interfaces.append(mfi.numpy_fft) +ids.append("numpy") if scipy_fft is not None: interfaces.append(scipy_fft) ids.append("scipy") -interfaces.append(mfi.numpy_fft) -ids.append("numpy") +if dask_fft is not None: + interfaces.append(dask_fft) + ids.append("dask") @pytest.mark.parametrize("norm", [None, "forward", "backward", "ortho"]) @@ -189,3 +197,8 @@ def test_axes(interface): ) def test_interface_helper_functions(interface, func): assert hasattr(interface, func) + + +def test_dask_fftwrap(): + pytest.importorskip("dask", reason="requires dask") + assert hasattr(mfi.dask_fft, "fft_wrap") diff --git a/mkl_fft/tests/third_party/dask/test_dask_fft.py b/mkl_fft/tests/third_party/dask/test_dask_fft.py new file mode 100644 index 0000000..43cca27 --- /dev/null +++ b/mkl_fft/tests/third_party/dask/test_dask_fft.py @@ -0,0 +1,166 @@ +# This file includes tests from dask.fft module: +# https://github.com/dask/dask/blob/main/dask/array/tests/test_fft.py + +import contextlib +from itertools import combinations_with_replacement + +import numpy as np +import pytest + +try: + import dask +except ImportError: + pytest.skip("This test file needs dask", allow_module_level=True) +else: + import dask.array as da + from dask.array.numpy_compat import NUMPY_GE_200 + from dask.array.utils import assert_eq, same_keys + + import mkl_fft.interfaces.dask_fft as dask_fft + +requires_dask_2024_8_2 = pytest.mark.skipif( + dask.__version__ < "2024.8.2", + reason="norm kwarg requires Dask >= 2024.8.2", +) + +all_1d_funcnames = ["fft", "ifft", "rfft", "irfft", "hfft", "ihfft"] + +all_nd_funcnames = [ + "fft2", + "ifft2", + "fftn", + "ifftn", + "rfft2", + "irfft2", + "rfftn", + "irfftn", +] + +if not da._array_expr_enabled(): + + nparr = np.arange(100).reshape(10, 10) + darr = da.from_array(nparr, chunks=(1, 10)) + darr2 = da.from_array(nparr, chunks=(10, 1)) + darr3 = da.from_array(nparr, chunks=(10, 10)) + + +@pytest.mark.parametrize("funcname", all_1d_funcnames) +def test_cant_fft_chunked_axis(funcname): + da_fft = getattr(dask_fft, funcname) + + bad_darr = da.from_array(nparr, chunks=(5, 5)) + for i in range(bad_darr.ndim): + with pytest.raises(ValueError): + da_fft(bad_darr, axis=i) + + +@pytest.mark.parametrize("funcname", all_1d_funcnames) +def test_fft(funcname): + da_fft = getattr(dask_fft, funcname) + np_fft = getattr(np.fft, funcname) + + # pylint: disable=possibly-used-before-assignment + assert_eq(da_fft(darr), np_fft(nparr)) + + +@pytest.mark.parametrize("funcname", all_nd_funcnames) +def test_fft2n_shapes(funcname): + da_fft = getattr(dask_fft, funcname) + np_fft = getattr(np.fft, funcname) + + # pylint: disable=possibly-used-before-assignment + assert_eq(da_fft(darr3), np_fft(nparr)) + assert_eq( + da_fft(darr3, (8, 9), axes=(1, 0)), np_fft(nparr, (8, 9), axes=(1, 0)) + ) + assert_eq( + da_fft(darr3, (12, 11), axes=(1, 0)), + np_fft(nparr, (12, 11), axes=(1, 0)), + ) + + if NUMPY_GE_200 and funcname.endswith("fftn"): + ctx = pytest.warns( + DeprecationWarning, + match="`axes` should not be `None` if `s` is not `None`", + ) + else: + ctx = contextlib.nullcontext() + with ctx: + expect = np_fft(nparr, (8, 9)) + with ctx: + actual = da_fft(darr3, (8, 9)) + assert_eq(expect, actual) + + +@requires_dask_2024_8_2 +@pytest.mark.parametrize("funcname", all_1d_funcnames) +def test_fft_n_kwarg(funcname): + da_fft = getattr(dask_fft, funcname) + np_fft = getattr(np.fft, funcname) + + assert_eq(da_fft(darr, 5), np_fft(nparr, 5)) + assert_eq(da_fft(darr, 13), np_fft(nparr, 13)) + assert_eq( + da_fft(darr, 13, norm="backward"), np_fft(nparr, 13, norm="backward") + ) + assert_eq(da_fft(darr, 13, norm="ortho"), np_fft(nparr, 13, norm="ortho")) + assert_eq( + da_fft(darr, 13, norm="forward"), np_fft(nparr, 13, norm="forward") + ) + # pylint: disable=possibly-used-before-assignment + assert_eq(da_fft(darr2, axis=0), np_fft(nparr, axis=0)) + assert_eq(da_fft(darr2, 5, axis=0), np_fft(nparr, 5, axis=0)) + assert_eq( + da_fft(darr2, 13, axis=0, norm="backward"), + np_fft(nparr, 13, axis=0, norm="backward"), + ) + assert_eq( + da_fft(darr2, 12, axis=0, norm="ortho"), + np_fft(nparr, 12, axis=0, norm="ortho"), + ) + assert_eq( + da_fft(darr2, 12, axis=0, norm="forward"), + np_fft(nparr, 12, axis=0, norm="forward"), + ) + + +@pytest.mark.parametrize("funcname", all_1d_funcnames) +def test_fft_consistent_names(funcname): + da_fft = getattr(dask_fft, funcname) + + assert same_keys(da_fft(darr, 5), da_fft(darr, 5)) + assert same_keys(da_fft(darr2, 5, axis=0), da_fft(darr2, 5, axis=0)) + assert not same_keys(da_fft(darr, 5), da_fft(darr, 13)) + + +@pytest.mark.parametrize("funcname", all_nd_funcnames) +@pytest.mark.parametrize("dtype", ["float32", "float64"]) +def test_nd_ffts_axes(funcname, dtype): + np_fft = getattr(np.fft, funcname) + da_fft = getattr(dask_fft, funcname) + + shape = (7, 8, 9) + chunk_size = (3, 3, 3) + a = np.arange(np.prod(shape), dtype=dtype).reshape(shape) + d = da.from_array(a, chunks=chunk_size) + + for num_axes in range(1, d.ndim): + for axes in combinations_with_replacement(range(d.ndim), num_axes): + cs = list(chunk_size) + for i in axes: + cs[i] = shape[i] + d2 = d.rechunk(cs) + if len(set(axes)) < len(axes): + with pytest.raises(ValueError): + da_fft(d2, axes=axes) + else: + r = da_fft(d2, axes=axes) + er = np_fft(a, axes=axes) + if np.lib.NumpyVersion(np.__version__) >= "2.0.0": + check_dtype = True + assert r.dtype == er.dtype + else: + check_dtype = False + assert r.shape == er.shape + + assert_eq(r, er, check_dtype=check_dtype, rtol=1e-6, atol=1e-4) diff --git a/pyproject.toml b/pyproject.toml index 55f31da..61681ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ readme = {file = "README.md", content-type = "text/markdown"} requires-python = ">=3.9,<3.13" [project.optional-dependencies] +dask_interface = ["dask"] scipy_interface = ["scipy>=1.10"] test = ["pytest", "scipy>=1.10"]