From 5a1f737c809d17f0d1a1cbf7630aed4aaf33832d Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 18 Feb 2025 18:34:14 +0100 Subject: [PATCH 01/10] Switch to lazy exports --- pyproject.toml | 1 + src/fast_array_utils/types.py | 92 +++++++++++++++++++++++++---------- 2 files changed, 66 insertions(+), 27 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0f6af6a..ce960e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,6 +76,7 @@ lint.ignore = [ ] lint.per-file-ignores."docs/**/*.py" = [ "INP001" ] # No __init__.py in docs lint.per-file-ignores."src/**/stats/*.py" = [ "A001", "A004" ] # Shadows builtins like `sum` +lint.per-file-ignores."src/fast_array_utils/types.py" = [ "N806" ] # We have variables that are classes here lint.per-file-ignores."stubs/**/*.pyi" = [ "F403", "F405", "N801" ] # Stubs don’t follow name conventions lint.per-file-ignores."tests/**/test_*.py" = [ "D100", # tests need no module docstrings diff --git a/src/fast_array_utils/types.py b/src/fast_array_utils/types.py index 8026123..413c261 100644 --- a/src/fast_array_utils/types.py +++ b/src/fast_array_utils/types.py @@ -3,10 +3,16 @@ from __future__ import annotations +from functools import cache from importlib.util import find_spec from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, runtime_checkable +if TYPE_CHECKING: + from collections.abc import Callable + from types import UnionType + + __all__ = [ "CSBase", "CupyArray", @@ -20,6 +26,39 @@ T_co = TypeVar("T_co", covariant=True) +# registry for lazy exports: + + +_REGISTRY: dict[str, str | Callable[[], UnionType]] = {} + + +def _register(name: str) -> Callable[[Callable[[], UnionType]], Callable[[], UnionType]]: + def _decorator(fn: Callable[[], UnionType]) -> Callable[[], UnionType]: + _REGISTRY[name] = fn + return fn + + return _decorator + + +@cache +def __getattr__(name: str) -> type | UnionType: + if (source := _REGISTRY.get(name)) is None: + # A name we don’t know about + raise AttributeError(name) from None + + if callable(source): + return source() + + try: + mod, name = source.rsplit(".", 1) + return getattr(__import__(mod, fromlist=[name]), name) # type: ignore[no-any-return] + except ImportError: # A name we can’t import + return type(name, (), {}) + + +# lazy exports: + + # scipy sparse if TYPE_CHECKING: from scipy.sparse import csc_array, csc_matrix, csr_array, csr_matrix @@ -28,53 +67,52 @@ CSMatrix = csr_matrix | csc_matrix CSBase = CSMatrix | CSArray else: - try: # cs?_array isn’t available in older scipy versions - from scipy.sparse import csc_array, csr_array - CSArray = csr_array | csc_array - except ImportError: # pragma: no cover - CSArray = type("CSArray", (), {}) - - try: # cs?_matrix is available when scipy is installed + @_register("CSMatrix") + def _get_cs_matrix() -> UnionType: + # cs?_matrix is available when scipy is installed from scipy.sparse import csc_matrix, csr_matrix - CSMatrix = csr_matrix | csc_matrix - except ImportError: # pragma: no cover - CSMatrix = type("CSMatrix", (), {}) + return csr_matrix | csc_matrix - CSBase = CSMatrix | CSArray + @_register("CSArray") + def _get_cs_array() -> UnionType: + from scipy.sparse import csc_array, csr_array + return csr_array | csc_array -if TYPE_CHECKING or find_spec("cupy"): - from cupy import ndarray as CupyArray -else: # pragma: no cover - CupyArray = type("ndarray", (), {}) + @_register("CSBase") + def _get_cs_base() -> UnionType: + return __getattr__("CSMatrix") | __getattr__("CSArray") -if TYPE_CHECKING or find_spec("cupyx"): +if TYPE_CHECKING: + from cupy import ndarray as CupyArray from cupyx.scipy.sparse import spmatrix as CupySparseMatrix -else: # pragma: no cover - CupySparseMatrix = type("spmatrix", (), {}) +else: + _REGISTRY["CupyArray"] = "cupy.ndarray" + _REGISTRY["CupySparseMatrix"] = "cupyx.scipy.sparse.spmatrix" if TYPE_CHECKING: # https://github.com/dask/dask/issues/8853 from dask.array.core import Array as DaskArray -elif find_spec("dask"): - from dask.array import Array as DaskArray -else: # pragma: no cover - DaskArray = type("array", (), {}) +else: + _REGISTRY["DaskArray"] = "dask.array.Array" -if TYPE_CHECKING or find_spec("h5py"): +if TYPE_CHECKING: from h5py import Dataset as H5Dataset -else: # pragma: no cover - H5Dataset = type("Dataset", (), {}) +else: + _REGISTRY["H5Dataset"] = "h5py.Dataset" if TYPE_CHECKING or find_spec("zarr"): from zarr import Array as ZarrArray -else: # pragma: no cover - ZarrArray = type("Array", (), {}) +else: + _REGISTRY["ZarrArray"] = "zarr.Array" + + +# protocols: @runtime_checkable From e66a8c22e1f55e0c3fff3f0acef83e42118cb088 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 18 Feb 2025 18:46:02 +0100 Subject: [PATCH 02/10] fix min job --- src/fast_array_utils/types.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/fast_array_utils/types.py b/src/fast_array_utils/types.py index 413c261..4f53e08 100644 --- a/src/fast_array_utils/types.py +++ b/src/fast_array_utils/types.py @@ -46,10 +46,10 @@ def __getattr__(name: str) -> type | UnionType: # A name we don’t know about raise AttributeError(name) from None - if callable(source): - return source() - try: + if callable(source): + return source() + mod, name = source.rsplit(".", 1) return getattr(__import__(mod, fromlist=[name]), name) # type: ignore[no-any-return] except ImportError: # A name we can’t import From ae53c2faac482bbb9136447e747b92d57850f289 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 18 Feb 2025 18:46:38 +0100 Subject: [PATCH 03/10] simplify --- src/fast_array_utils/types.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/fast_array_utils/types.py b/src/fast_array_utils/types.py index 4f53e08..8a4b9da 100644 --- a/src/fast_array_utils/types.py +++ b/src/fast_array_utils/types.py @@ -59,7 +59,6 @@ def __getattr__(name: str) -> type | UnionType: # lazy exports: -# scipy sparse if TYPE_CHECKING: from scipy.sparse import csc_array, csc_matrix, csr_array, csr_matrix From fa08b031374f6bd4687797bffbb230570397844e Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 18 Feb 2025 18:47:46 +0100 Subject: [PATCH 04/10] better comment --- src/fast_array_utils/types.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/fast_array_utils/types.py b/src/fast_array_utils/types.py index 8a4b9da..a854547 100644 --- a/src/fast_array_utils/types.py +++ b/src/fast_array_utils/types.py @@ -66,10 +66,11 @@ def __getattr__(name: str) -> type | UnionType: CSMatrix = csr_matrix | csc_matrix CSBase = CSMatrix | CSArray else: + # cs?_array isn’t available in older scipy versions, + # so we import them separately @_register("CSMatrix") def _get_cs_matrix() -> UnionType: - # cs?_matrix is available when scipy is installed from scipy.sparse import csc_matrix, csr_matrix return csr_matrix | csc_matrix From d277dc45dbd17f63b41c27193c40ee9b5450c041 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 20 Feb 2025 10:34:09 +0100 Subject: [PATCH 05/10] better import --- src/fast_array_utils/_import.py | 32 ++++++++++++++++++++++++++++++++ src/fast_array_utils/types.py | 7 ++++--- 2 files changed, 36 insertions(+), 3 deletions(-) create mode 100644 src/fast_array_utils/_import.py diff --git a/src/fast_array_utils/_import.py b/src/fast_array_utils/_import.py new file mode 100644 index 0000000..2627a35 --- /dev/null +++ b/src/fast_array_utils/_import.py @@ -0,0 +1,32 @@ +# SPDX-License-Identifier: MPL-2.0 +from __future__ import annotations + + +__all__ = ["_import_by_qualname"] + + +def _import_by_qualname(qualname: str) -> object: + from importlib import import_module + + parts = qualname.split(".") + + # import the module + obj = import_module(parts[0]) + for i, name in enumerate(parts[1:]): # noqa: B007 + try: + obj = import_module(f"{obj.__name__}.{name}") + except ModuleNotFoundError: + break + else: + i = len(parts) + + # get object if applicable + for name in parts[i + 1 :]: + try: + obj = getattr(obj, name) + except AttributeError: + msg = f"Could not import {name!r} from {'.'.join(parts[:i])} " + if i + 1 < len(parts): + msg += f"(trying to get {'.'.join(parts[i + 1 :])!r})" + raise ImportError(msg) from None + return obj diff --git a/src/fast_array_utils/types.py b/src/fast_array_utils/types.py index a854547..f796776 100644 --- a/src/fast_array_utils/types.py +++ b/src/fast_array_utils/types.py @@ -5,7 +5,9 @@ from functools import cache from importlib.util import find_spec -from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, runtime_checkable +from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast, runtime_checkable + +from ._import import _import_by_qualname if TYPE_CHECKING: @@ -50,8 +52,7 @@ def __getattr__(name: str) -> type | UnionType: if callable(source): return source() - mod, name = source.rsplit(".", 1) - return getattr(__import__(mod, fromlist=[name]), name) # type: ignore[no-any-return] + return cast(type, _import_by_qualname(source)) except ImportError: # A name we can’t import return type(name, (), {}) From 580957e4e18944afaede08d81e20a38d9a3249cc Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 20 Feb 2025 12:03:30 +0100 Subject: [PATCH 06/10] Actually lazy import --- src/fast_array_utils/_import.py | 88 +++++++++++++++++++++------ src/fast_array_utils/conv/_asarray.py | 44 ++++++++------ src/fast_array_utils/stats/_sum.py | 37 +++++------ src/fast_array_utils/types.py | 14 ++--- tests/test_asarray.py | 6 +- 5 files changed, 120 insertions(+), 69 deletions(-) diff --git a/src/fast_array_utils/_import.py b/src/fast_array_utils/_import.py index 2627a35..b3ed535 100644 --- a/src/fast_array_utils/_import.py +++ b/src/fast_array_utils/_import.py @@ -1,32 +1,80 @@ # SPDX-License-Identifier: MPL-2.0 from __future__ import annotations +from dataclasses import dataclass, field +from types import UnionType +from typing import TYPE_CHECKING, Generic, ParamSpec, TypeVar, cast, overload -__all__ = ["_import_by_qualname"] +if TYPE_CHECKING: + from collections.abc import Callable -def _import_by_qualname(qualname: str) -> object: +P = ParamSpec("P") +R = TypeVar("R") + + +__all__ = ["import_by_qualname", "lazy_singledispatch"] + + +def import_by_qualname(qualname: str) -> object: from importlib import import_module - parts = qualname.split(".") + mod_path, obj_path = qualname.split(":") - # import the module - obj = import_module(parts[0]) - for i, name in enumerate(parts[1:]): # noqa: B007 - try: - obj = import_module(f"{obj.__name__}.{name}") - except ModuleNotFoundError: - break - else: - i = len(parts) - - # get object if applicable - for name in parts[i + 1 :]: + mod = import_module(mod_path) + + # get object + obj = mod + for name in obj_path.split("."): try: obj = getattr(obj, name) - except AttributeError: - msg = f"Could not import {name!r} from {'.'.join(parts[:i])} " - if i + 1 < len(parts): - msg += f"(trying to get {'.'.join(parts[i + 1 :])!r})" - raise ImportError(msg) from None + except AttributeError as e: + msg = f"Could not import {'.'.join(obj_path)} from {'.'.join(mod_path)} " + raise ImportError(msg) from e return obj + + +@dataclass +class lazy_singledispatch(Generic[P, R]): # noqa: N801 + fallback: Callable[P, R] + + _lazy: dict[tuple[str, str], Callable[..., R]] = field(init=False, default_factory=dict) + _eager: dict[type | UnionType, Callable[..., R]] = field(init=False, default_factory=dict) + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: + for typ_, fn in self._eager.items(): + if isinstance(args[0], typ_): + return fn(*args, **kwargs) + for (import_qualname, host_mod_name), fn in self._lazy.items(): + for cls in type(args[0]).mro(): + if cls.__module__.startswith(host_mod_name): # can be deeper + cls_reg = cast(type, import_by_qualname(import_qualname)) + if isinstance(args[0], cls_reg): + return fn(*args, **kwargs) + return self.fallback(*args, **kwargs) + + @overload + def register( + self, qualname_or_type: str, /, host_mod_name: str | None = None + ) -> Callable[[Callable[..., R]], lazy_singledispatch[P, R]]: ... + @overload + def register( + self, qualname_or_type: type | UnionType, /, host_mod_name: None = None + ) -> Callable[[Callable[..., R]], lazy_singledispatch[P, R]]: ... + + def register( + self, qualname_or_type: str | type | UnionType, /, host_mod_name: str | None = None + ) -> Callable[[Callable[..., R]], lazy_singledispatch[P, R]]: + def decorator(fn: Callable[..., R]) -> lazy_singledispatch[P, R]: + match qualname_or_type, host_mod_name: + case str(), _: + hmn = qualname_or_type.split(":")[0] if host_mod_name is None else host_mod_name + self._lazy[(qualname_or_type, hmn)] = fn + case type() | UnionType(), None: + self._eager[qualname_or_type] = fn + case _: + msg = f"name_or_type {qualname_or_type!r} must be a str, type, or UnionType" + raise TypeError(msg) + return self + + return decorator diff --git a/src/fast_array_utils/conv/_asarray.py b/src/fast_array_utils/conv/_asarray.py index 76eddcd..a4b59f1 100644 --- a/src/fast_array_utils/conv/_asarray.py +++ b/src/fast_array_utils/conv/_asarray.py @@ -1,12 +1,12 @@ # SPDX-License-Identifier: MPL-2.0 from __future__ import annotations -from functools import singledispatch from typing import TYPE_CHECKING import numpy as np -from ..types import CSBase, CupyArray, CupySparseMatrix, DaskArray, H5Dataset, OutOfCoreDataset +from .._import import lazy_singledispatch +from ..types import OutOfCoreDataset if TYPE_CHECKING: @@ -14,13 +14,24 @@ from numpy.typing import ArrayLike, NDArray + from .. import types -__all__ = ["OutOfCoreDataset", "asarray"] + +__all__ = ["asarray"] # fallback’s arg0 type has to include types of registered functions -@singledispatch -def asarray(x: ArrayLike | CSBase | OutOfCoreDataset[Any]) -> NDArray[Any]: +@lazy_singledispatch +def asarray( + x: ArrayLike + | types.CSBase + | types.DaskArray + | types.OutOfCoreDataset[Any] + | types.H5Dataset + | types.ZarrArray + | types.CupyArray + | types.CupySparseMatrix, +) -> NDArray[Any]: """Convert x to a numpy array. Parameters @@ -36,33 +47,28 @@ def asarray(x: ArrayLike | CSBase | OutOfCoreDataset[Any]) -> NDArray[Any]: return np.asarray(x) -@asarray.register(CSBase) # type: ignore[call-overload,misc] -def _(x: CSBase) -> NDArray[Any]: +@asarray.register("fast_array_utils.types:CSBase", "scipy.sparse") +def _(x: types.CSBase) -> NDArray[Any]: from .scipy import to_dense return to_dense(x) -@asarray.register(DaskArray) -def _(x: DaskArray) -> NDArray[Any]: +@asarray.register("dask.array:Array") +def _(x: types.DaskArray) -> NDArray[Any]: return asarray(x.compute()) # type: ignore[no-untyped-call] @asarray.register(OutOfCoreDataset) -def _(x: OutOfCoreDataset[CSBase | NDArray[Any]]) -> NDArray[Any]: +def _(x: types.OutOfCoreDataset[types.CSBase | NDArray[Any]]) -> NDArray[Any]: return asarray(x.to_memory()) -@asarray.register(H5Dataset) -def _(x: H5Dataset) -> NDArray[Any]: - return x[...] # type: ignore[no-any-return] - - -@asarray.register(CupyArray) -def _(x: CupyArray) -> NDArray[Any]: +@asarray.register("cupy:ndarray") +def _(x: types.CupyArray) -> NDArray[Any]: return x.get() # type: ignore[no-any-return] -@asarray.register(CupySparseMatrix) -def _(x: CupySparseMatrix) -> NDArray[Any]: +@asarray.register("cupyx.scipy.sparse:spmatrix") +def _(x: types.CupySparseMatrix) -> NDArray[Any]: return x.toarray().get() # type: ignore[no-any-return] diff --git a/src/fast_array_utils/stats/_sum.py b/src/fast_array_utils/stats/_sum.py index 22da491..1092530 100644 --- a/src/fast_array_utils/stats/_sum.py +++ b/src/fast_array_utils/stats/_sum.py @@ -1,12 +1,12 @@ # SPDX-License-Identifier: MPL-2.0 from __future__ import annotations -from functools import partial, singledispatch +from functools import partial from typing import TYPE_CHECKING, overload import numpy as np -from ..types import CSBase, CSMatrix, DaskArray +from .._import import lazy_singledispatch if TYPE_CHECKING: @@ -14,6 +14,8 @@ from numpy.typing import ArrayLike, DTypeLike, NDArray + from .. import types + @overload def sum(x: ArrayLike, *, axis: None = None, dtype: DTypeLike | None = None) -> np.number[Any]: ... @@ -21,13 +23,13 @@ def sum(x: ArrayLike, *, axis: None = None, dtype: DTypeLike | None = None) -> n def sum(x: ArrayLike, *, axis: Literal[0, 1], dtype: DTypeLike | None = None) -> NDArray[Any]: ... @overload def sum( - x: DaskArray, *, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None -) -> DaskArray: ... + x: types.DaskArray, *, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None +) -> types.DaskArray: ... def sum( x: ArrayLike, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None -) -> NDArray[Any] | np.number[Any] | DaskArray: +) -> NDArray[Any] | np.number[Any] | types.DaskArray: """Sum over both or one axis. Returns @@ -43,32 +45,30 @@ def sum( return _sum(x, axis=axis, dtype=dtype) -@singledispatch +@lazy_singledispatch def _sum( - x: ArrayLike | CSBase | DaskArray, - *, - axis: Literal[0, 1, None] = None, - dtype: DTypeLike | None = None, -) -> NDArray[Any] | np.number[Any] | DaskArray: - assert not isinstance(x, CSBase | DaskArray) + x: ArrayLike, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None +) -> NDArray[Any] | np.number[Any] | types.DaskArray: return np.sum(x, axis=axis, dtype=dtype) # type: ignore[no-any-return] -@_sum.register(CSBase) # type: ignore[call-overload,misc] +@_sum.register("fast_array_utils.types:CSBase", "scipy.sparse") def _( - x: CSBase, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None + x: types.CSBase, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None ) -> NDArray[Any] | np.number[Any]: import scipy.sparse as sp + from ..types import CSMatrix + if isinstance(x, CSMatrix): x = sp.csr_array(x) if x.format == "csr" else sp.csc_array(x) return np.sum(x, axis=axis, dtype=dtype) # type: ignore[call-overload,no-any-return] -@_sum.register(DaskArray) +@_sum.register("dask.array:Array") def _( - x: DaskArray, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None -) -> DaskArray: + x: types.DaskArray, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None +) -> types.DaskArray: if TYPE_CHECKING: from dask.array.reductions import reduction else: @@ -79,7 +79,8 @@ def _( raise TypeError(msg) def sum_drop_keepdims( - a: NDArray[Any] | CSBase, + a: NDArray[Any] | types.CSBase, + /, *, axis: tuple[Literal[0], Literal[1]] | Literal[0, 1] | None = None, dtype: DTypeLike | None = None, diff --git a/src/fast_array_utils/types.py b/src/fast_array_utils/types.py index f796776..14faac2 100644 --- a/src/fast_array_utils/types.py +++ b/src/fast_array_utils/types.py @@ -7,7 +7,7 @@ from importlib.util import find_spec from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast, runtime_checkable -from ._import import _import_by_qualname +from ._import import import_by_qualname if TYPE_CHECKING: @@ -52,7 +52,7 @@ def __getattr__(name: str) -> type | UnionType: if callable(source): return source() - return cast(type, _import_by_qualname(source)) + return cast(type, import_by_qualname(source)) except ImportError: # A name we can’t import return type(name, (), {}) @@ -91,26 +91,26 @@ def _get_cs_base() -> UnionType: from cupy import ndarray as CupyArray from cupyx.scipy.sparse import spmatrix as CupySparseMatrix else: - _REGISTRY["CupyArray"] = "cupy.ndarray" - _REGISTRY["CupySparseMatrix"] = "cupyx.scipy.sparse.spmatrix" + _REGISTRY["CupyArray"] = "cupy:ndarray" + _REGISTRY["CupySparseMatrix"] = "cupyx.scipy.sparse:spmatrix" if TYPE_CHECKING: # https://github.com/dask/dask/issues/8853 from dask.array.core import Array as DaskArray else: - _REGISTRY["DaskArray"] = "dask.array.Array" + _REGISTRY["DaskArray"] = "dask.array:Array" if TYPE_CHECKING: from h5py import Dataset as H5Dataset else: - _REGISTRY["H5Dataset"] = "h5py.Dataset" + _REGISTRY["H5Dataset"] = "h5py:Dataset" if TYPE_CHECKING or find_spec("zarr"): from zarr import Array as ZarrArray else: - _REGISTRY["ZarrArray"] = "zarr.Array" + _REGISTRY["ZarrArray"] = "zarr:Array" # protocols: diff --git a/tests/test_asarray.py b/tests/test_asarray.py index e11ee2d..c7caa57 100644 --- a/tests/test_asarray.py +++ b/tests/test_asarray.py @@ -9,15 +9,11 @@ if TYPE_CHECKING: - from typing import Any - - from numpy.typing import NDArray - from testing.fast_array_utils import ToArray def test_asarray(to_array: ToArray) -> None: x = to_array([[1, 2, 3], [4, 5, 6]]) - arr: NDArray[Any] = asarray(x) # type: ignore[arg-type] + arr = asarray(x) assert isinstance(arr, np.ndarray) assert arr.shape == (2, 3) From c242f96f064d058560508107fe65672bb2844181 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 20 Feb 2025 12:05:05 +0100 Subject: [PATCH 07/10] sum type --- src/fast_array_utils/stats/_sum.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/fast_array_utils/stats/_sum.py b/src/fast_array_utils/stats/_sum.py index 1092530..9a3cf45 100644 --- a/src/fast_array_utils/stats/_sum.py +++ b/src/fast_array_utils/stats/_sum.py @@ -18,17 +18,25 @@ @overload -def sum(x: ArrayLike, *, axis: None = None, dtype: DTypeLike | None = None) -> np.number[Any]: ... +def sum( + x: ArrayLike, /, *, axis: None = None, dtype: DTypeLike | None = None +) -> np.number[Any]: ... @overload -def sum(x: ArrayLike, *, axis: Literal[0, 1], dtype: DTypeLike | None = None) -> NDArray[Any]: ... +def sum( + x: ArrayLike, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None +) -> NDArray[Any]: ... @overload def sum( - x: types.DaskArray, *, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None + x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None ) -> types.DaskArray: ... def sum( - x: ArrayLike, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None + x: ArrayLike | types.DaskArray, + /, + *, + axis: Literal[0, 1, None] = None, + dtype: DTypeLike | None = None, ) -> NDArray[Any] | np.number[Any] | types.DaskArray: """Sum over both or one axis. From 87fc20d3b86d563b7bc4f864f7d850a9bce94460 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 20 Feb 2025 12:17:04 +0100 Subject: [PATCH 08/10] cache --- src/fast_array_utils/_import.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/src/fast_array_utils/_import.py b/src/fast_array_utils/_import.py index b3ed535..e14e8be 100644 --- a/src/fast_array_utils/_import.py +++ b/src/fast_array_utils/_import.py @@ -2,6 +2,7 @@ from __future__ import annotations from dataclasses import dataclass, field +from functools import cache from types import UnionType from typing import TYPE_CHECKING, Generic, ParamSpec, TypeVar, cast, overload @@ -42,16 +43,24 @@ class lazy_singledispatch(Generic[P, R]): # noqa: N801 _eager: dict[type | UnionType, Callable[..., R]] = field(init=False, default_factory=dict) def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: - for typ_, fn in self._eager.items(): - if isinstance(args[0], typ_): - return fn(*args, **kwargs) + fn = self.dispatch(type(args[0])) # type: ignore[arg-type] # https://github.com/python/mypy/issues/11470 + return fn(*args, **kwargs) + + def __hash__(self) -> int: + return hash(self.fallback) + + @cache # noqa: B019 + def dispatch(self, typ: type) -> Callable[P, R]: + for cls_reg, fn in self._eager.items(): + if issubclass(typ, cls_reg): + return fn for (import_qualname, host_mod_name), fn in self._lazy.items(): - for cls in type(args[0]).mro(): + for cls in typ.mro(): if cls.__module__.startswith(host_mod_name): # can be deeper cls_reg = cast(type, import_by_qualname(import_qualname)) - if isinstance(args[0], cls_reg): - return fn(*args, **kwargs) - return self.fallback(*args, **kwargs) + if issubclass(typ, cls_reg): + return fn + return self.fallback @overload def register( From 57349fb4cc79ad6dfa73ff0eb206473b02b809c9 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 20 Feb 2025 12:25:23 +0100 Subject: [PATCH 09/10] full lazy --- src/fast_array_utils/__init__.py | 4 +--- src/fast_array_utils/_import.py | 5 +++++ src/fast_array_utils/types.py | 3 +-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/fast_array_utils/__init__.py b/src/fast_array_utils/__init__.py index 135e10e..9556bd1 100644 --- a/src/fast_array_utils/__init__.py +++ b/src/fast_array_utils/__init__.py @@ -3,9 +3,7 @@ from __future__ import annotations -from . import _patches, conv, stats, types +from . import conv, stats, types __all__ = ["conv", "stats", "types"] - -_patches.patch_dask() diff --git a/src/fast_array_utils/_import.py b/src/fast_array_utils/_import.py index e14e8be..329d2ac 100644 --- a/src/fast_array_utils/_import.py +++ b/src/fast_array_utils/_import.py @@ -24,6 +24,11 @@ def import_by_qualname(qualname: str) -> object: mod = import_module(mod_path) + if mod_path == "dask" or mod_path.startswith("dask."): + from ._patches import patch_dask + + patch_dask() + # get object obj = mod for name in obj_path.split("."): diff --git a/src/fast_array_utils/types.py b/src/fast_array_utils/types.py index 14faac2..0f7b6e9 100644 --- a/src/fast_array_utils/types.py +++ b/src/fast_array_utils/types.py @@ -4,7 +4,6 @@ from __future__ import annotations from functools import cache -from importlib.util import find_spec from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast, runtime_checkable from ._import import import_by_qualname @@ -107,7 +106,7 @@ def _get_cs_base() -> UnionType: _REGISTRY["H5Dataset"] = "h5py:Dataset" -if TYPE_CHECKING or find_spec("zarr"): +if TYPE_CHECKING: from zarr import Array as ZarrArray else: _REGISTRY["ZarrArray"] = "zarr:Array" From 6fb1583e72db7bd9c3279d4d861e12e27e0869e7 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 20 Feb 2025 13:18:34 +0100 Subject: [PATCH 10/10] fix docs --- docs/conf.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index e6ce354..956099f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -57,11 +57,17 @@ "ArrayLike": "numpy.typing.ArrayLike", "DTypeLike": "numpy.typing.DTypeLike", "NDArray": "numpy.typing.NDArray", - "CSBase": "scipy.sparse.spmatrix", - "CupyArray": "cupy.ndarray", - "CupySparseMatrix": "cupyx.scipy.sparse.spmatrix", - "DaskArray": "dask.array.Array", - "H5Dataset": "h5py.Dataset", + **{ + k: v + for k_plain, v in { + "CSBase": "scipy.sparse.spmatrix", + "CupyArray": "cupy.ndarray", + "CupySparseMatrix": "cupyx.scipy.sparse.spmatrix", + "DaskArray": "dask.array.Array", + "H5Dataset": "h5py.Dataset", + }.items() + for k in (k_plain, f"types.{k_plain}") + }, } # If that doesn’t work, ignore them nitpick_ignore = {