Modify zarr chunking as suggested in pydata#4496 (pydata#4646)
* modify get_chunks to align zarr chunking as described in issue pydata#4496

* fix: maintain old open_zarr chunking interface

* add and fix tests

* black

* bugfix

* add documentation on open_dataset chunking

* in tests: re-add xfails for negative steps without dask

* Specify in reason that only zarr is expected to fail

* unify the backend negative_step test with and without dask

* Add comment on has_dask usage

Co-authored-by: Alessandro Amici <[email protected]>
aurghs and alexamici authored Dec 9, 2020
1 parent ff6b1f5 commit 9802411
Showing 5 changed files with 156 additions and 48 deletions.
12 changes: 7 additions & 5 deletions xarray/backends/api.py
@@ -377,10 +377,12 @@ def open_dataset(
"netcdf4".
chunks : int or dict, optional
If chunks is provided, it is used to load the new dataset into dask
arrays. ``chunks={}`` loads the dataset with dask using a single
chunk for all arrays. When using ``engine="zarr"``, setting
``chunks='auto'`` will create dask chunks based on the variable's zarr
chunks.
arrays. ``chunks=-1`` loads the dataset with dask using a single
chunk for all arrays. ``chunks={}`` loads the dataset with dask using
the engine's preferred chunks if exposed by the backend, otherwise with
a single chunk for all arrays.
``chunks='auto'`` will use dask ``auto`` chunking, taking the engine's
preferred chunks into account. See the dask chunking documentation for more details.
lock : False or lock-like, optional
Resource lock to use when reading data from disk. Only relevant when
using dask or another form of parallelism. By default, appropriate
@@ -536,7 +538,7 @@ def maybe_decode_store(store, chunks):
k: _maybe_chunk(
k,
v,
_get_chunk(k, v, chunks),
_get_chunk(v, chunks),
overwrite_encoded_chunks=overwrite_encoded_chunks,
)
for k, v in ds.variables.items()
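A usage sketch of the ``chunks`` semantics documented above (the store path "example.zarr" is a hypothetical placeholder, not part of this change):

import xarray as xr

# one dask chunk per array
ds = xr.open_dataset("example.zarr", engine="zarr", chunks=-1)
# backend-preferred chunks, here the variable's zarr chunks
ds = xr.open_dataset("example.zarr", engine="zarr", chunks={})
# dask "auto" chunking, aligned with the backend-preferred chunks
ds = xr.open_dataset("example.zarr", engine="zarr", chunks="auto")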
12 changes: 7 additions & 5 deletions xarray/backends/apiv2.py
@@ -53,7 +53,7 @@ def _chunk_ds(

variables = {}
for k, v in backend_ds.variables.items():
var_chunks = _get_chunk(k, v, chunks)
var_chunks = _get_chunk(v, chunks)
variables[k] = _maybe_chunk(
k,
v,
@@ -146,10 +146,12 @@ def open_dataset(
"pynio", "cfgrib", "pseudonetcdf", "zarr"}.
chunks : int or dict, optional
If chunks is provided, it is used to load the new dataset into dask
arrays. ``chunks={}`` loads the dataset with dask using a single
chunk for all arrays. When using ``engine="zarr"``, setting
``chunks='auto'`` will create dask chunks based on the variable's zarr
chunks.
arrays. ``chunks=-1`` loads the dataset with dask using a single
chunk for all arrays. ``chunks={}`` loads the dataset with dask using
the engine's preferred chunks if exposed by the backend, otherwise with
a single chunk for all arrays.
``chunks='auto'`` will use dask ``auto`` chunking, taking the engine's
preferred chunks into account. See the dask chunking documentation for more details.
cache : bool, optional
If True, cache data is loaded from the underlying datastore in memory as
NumPy arrays when accessed to avoid reading from the underlying data-
8 changes: 8 additions & 0 deletions xarray/backends/zarr.py
@@ -611,6 +611,14 @@ def open_zarr(
"""
from .api import open_dataset

if chunks == "auto":
try:
import dask.array # noqa

chunks = {}
except ImportError:
chunks = None

if kwargs:
raise TypeError(
"open_zarr() got unexpected keyword arguments " + ",".join(kwargs.keys())
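This keeps ``open_zarr``'s historical default of ``chunks="auto"``: when dask is importable it now behaves like ``chunks={}`` (backend-preferred chunks), otherwise it falls back to eager loading. A minimal sketch, again with a hypothetical store path:

import xarray as xr

# with dask installed, the default maps to chunks={} and variables are
# returned as dask arrays aligned with the on-disk zarr chunks
ds_lazy = xr.open_zarr("example.zarr")

# chunks=None skips dask entirely and yields plain numpy arrays
ds_eager = xr.open_zarr("example.zarr", chunks=None)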
69 changes: 46 additions & 23 deletions xarray/core/dataset.py
@@ -359,32 +359,55 @@ def _assert_empty(args: tuple, msg: str = "%s") -> None:
raise ValueError(msg % args)


def _get_chunk(name, var, chunks):
chunk_spec = dict(zip(var.dims, var.encoding.get("chunks")))
def _check_chunks_compatibility(var, chunks, chunk_spec):
for dim in var.dims:
if dim not in chunks or (dim not in chunk_spec):
continue

chunk_spec_dim = chunk_spec.get(dim)
chunks_dim = chunks.get(dim)

if isinstance(chunks_dim, int):
chunks_dim = (chunks_dim,)
if any(s % chunk_spec_dim for s in chunks_dim):
warnings.warn(
f"Specified Dask chunks {chunks[dim]} would split "
f"the on-disk chunk shape {chunk_spec[dim]} for dimension {dim}. "
"This could degrade performance. "
"Consider rechunking after loading instead.",
stacklevel=2,
)

# Coordinate labels aren't chunked
if var.ndim == 1 and var.dims[0] == name:
return chunk_spec

if chunks == "auto":
return chunk_spec
def _get_chunk(var, chunks):
# chunks need to be explicitly computed to correctly take into account
# the backend's preferred chunking
import dask.array as da

for dim in var.dims:
if dim in chunks:
spec = chunks[dim]
if isinstance(spec, int):
spec = (spec,)
if isinstance(spec, (tuple, list)) and chunk_spec[dim]:
if any(s % chunk_spec[dim] for s in spec):
warnings.warn(
f"Specified Dask chunks {chunks[dim]} would separate "
f"on disks chunk shape {chunk_spec[dim]} for dimension {dim}. "
"This could degrade performance. "
"Consider rechunking after loading instead.",
stacklevel=2,
)
chunk_spec[dim] = chunks[dim]
return chunk_spec
if isinstance(var, IndexVariable):
return {}

if isinstance(chunks, int) or (chunks == "auto"):
chunks = dict.fromkeys(var.dims, chunks)

preferred_chunks_list = var.encoding.get("chunks", {})
preferred_chunks = dict(zip(var.dims, var.encoding.get("chunks", {})))

chunks_list = [
chunks.get(dim, None) or preferred_chunks.get(dim, None) for dim in var.dims
]

output_chunks_list = da.core.normalize_chunks(
chunks_list,
shape=var.shape,
dtype=var.dtype,
previous_chunks=preferred_chunks_list,
)

output_chunks = dict(zip(var.dims, output_chunks_list))
_check_chunks_compatibility(var, output_chunks, preferred_chunks)

return output_chunks


def _maybe_chunk(
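The new ``_get_chunk`` delegates the expansion of ``-1``/``"auto"`` specs to ``dask.array.core.normalize_chunks``, passing the backend-preferred chunks as ``previous_chunks`` so dask prefers multiples of the on-disk chunk size; ``_check_chunks_compatibility`` then warns if the result would still split an on-disk chunk. A standalone sketch of the dask call (illustrative shapes and sizes, assuming dask is installed):

import dask.array as da

# a (500, 500) float64 variable stored in 100x100 zarr chunks
output_chunks = da.core.normalize_chunks(
    ("auto", "auto"),
    shape=(500, 500),
    dtype="float64",
    previous_chunks=(100, 100),
)
# under dask's default chunk-size limit the whole 2 MB array fits in a
# single chunk per axis, e.g. ((500,), (500,)), a clean multiple of 100
print(output_chunks)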
103 changes: 88 additions & 15 deletions xarray/tests/test_backends.py
@@ -608,10 +608,6 @@ def test_orthogonal_indexing(self):
actual = on_disk.isel(**indexers)
assert_identical(expected, actual)

@pytest.mark.xfail(
not has_dask,
reason="the code for indexing without dask handles negative steps in slices incorrectly",
)
def test_vectorized_indexing(self):
in_memory = create_test_data()
with self.roundtrip(in_memory) as on_disk:
@@ -676,6 +672,29 @@ def multiple_indexing(indexers):
]
multiple_indexing(indexers)

@pytest.mark.xfail(
reason="zarr without dask handles negative steps in slices incorrectly",
)
def test_vectorized_indexing_negative_step(self):
# use dask explicitly when present
if has_dask:
open_kwargs = {"chunks": {}}
else:
open_kwargs = None
in_memory = create_test_data()

def multiple_indexing(indexers):
# make sure a sequence of lazy indexing operations works.
with self.roundtrip(in_memory, open_kwargs=open_kwargs) as on_disk:
actual = on_disk["var3"]
expected = in_memory["var3"]
for ind in indexers:
actual = actual.isel(**ind)
expected = expected.isel(**ind)
# make sure the array is not yet loaded into memory
assert not actual.variable._in_memory
assert_identical(expected, actual.load())

# with negative step slice.
indexers = [
{
@@ -1567,7 +1586,7 @@ def roundtrip(
if save_kwargs is None:
save_kwargs = {}
if open_kwargs is None:
open_kwargs = {"chunks": "auto"}
open_kwargs = {}
with self.create_zarr_target() as store_target:
self.save(data, store_target, **save_kwargs)
with self.open(store_target, **open_kwargs) as ds:
@@ -1604,7 +1623,7 @@ def test_auto_chunk(self):
# there should be no chunks
assert v.chunks is None

with self.roundtrip(original, open_kwargs={"chunks": "auto"}) as actual:
with self.roundtrip(original, open_kwargs={"chunks": {}}) as actual:
for k, v in actual.variables.items():
# only index variables should be in memory
assert v._in_memory == (k in actual.dims)
@@ -1701,7 +1720,7 @@ def test_deprecate_auto_chunk(self):
def test_write_uneven_dask_chunks(self):
# regression for GH#2225
original = create_test_data().chunk({"dim1": 3, "dim2": 4, "dim3": 3})
with self.roundtrip(original, open_kwargs={"chunks": "auto"}) as actual:
with self.roundtrip(original, open_kwargs={"chunks": {}}) as actual:
for k, v in actual.data_vars.items():
print(k)
assert v.chunks == actual[k].chunks
@@ -1850,9 +1869,7 @@ def test_write_persistence_modes(self, group):
ds.to_zarr(store_target, mode="w", group=group)
ds_to_append.to_zarr(store_target, append_dim="time", group=group)
original = xr.concat([ds, ds_to_append], dim="time")
actual = xr.open_dataset(
store_target, group=group, chunks="auto", engine="zarr"
)
actual = xr.open_dataset(store_target, group=group, engine="zarr")
assert_identical(original, actual)

def test_compressor_encoding(self):
@@ -1941,11 +1958,11 @@ def test_check_encoding_is_consistent_after_append(self):
encoding = {"da": {"compressor": compressor}}
ds.to_zarr(store_target, mode="w", encoding=encoding)
ds_to_append.to_zarr(store_target, append_dim="time")
actual_ds = xr.open_dataset(store_target, chunks="auto", engine="zarr")
actual_ds = xr.open_dataset(store_target, engine="zarr")
actual_encoding = actual_ds["da"].encoding["compressor"]
assert actual_encoding.get_config() == compressor.get_config()
assert_identical(
xr.open_dataset(store_target, chunks="auto", engine="zarr").compute(),
xr.open_dataset(store_target, engine="zarr").compute(),
xr.concat([ds, ds_to_append], dim="time"),
)

@@ -1960,9 +1977,7 @@ def test_append_with_new_variable(self):
ds_with_new_var.to_zarr(store_target, mode="a")
combined = xr.concat([ds, ds_to_append], dim="time")
combined["new_var"] = ds_with_new_var["new_var"]
assert_identical(
combined, xr.open_dataset(store_target, chunks="auto", engine="zarr")
)
assert_identical(combined, xr.open_dataset(store_target, engine="zarr"))

@requires_dask
def test_to_zarr_compute_false_roundtrip(self):
@@ -4803,3 +4818,61 @@ def test_load_single_value_h5netcdf(tmp_path):
ds.to_netcdf(tmp_path / "test.nc")
with xr.open_dataset(tmp_path / "test.nc", engine="h5netcdf") as ds2:
ds2["test"][0].load()


@requires_zarr
@requires_dask
@pytest.mark.parametrize(
"chunks", ["auto", -1, {}, {"x": "auto"}, {"x": -1}, {"x": "auto", "y": -1}]
)
def test_open_dataset_chunking_zarr(chunks, tmp_path):
encoded_chunks = 100
dask_arr = da.from_array(
np.ones((500, 500), dtype="float64"), chunks=encoded_chunks
)
ds = xr.Dataset(
{
"test": xr.DataArray(
dask_arr,
dims=("x", "y"),
)
}
)
ds["test"].encoding["chunks"] = encoded_chunks
ds.to_zarr(tmp_path / "test.zarr")

with dask.config.set({"array.chunk-size": "1MiB"}):
expected = ds.chunk(chunks)
actual = xr.open_dataset(tmp_path / "test.zarr", engine="zarr", chunks=chunks)
xr.testing.assert_chunks_equal(actual, expected)


@requires_zarr
@requires_dask
@pytest.mark.parametrize(
"chunks", ["auto", -1, {}, {"x": "auto"}, {"x": -1}, {"x": "auto", "y": -1}]
)
def test_chunking_consistency(chunks, tmp_path):
encoded_chunks = {}
dask_arr = da.from_array(
np.ones((500, 500), dtype="float64"), chunks=encoded_chunks
)
ds = xr.Dataset(
{
"test": xr.DataArray(
dask_arr,
dims=("x", "y"),
)
}
)
ds["test"].encoding["chunks"] = encoded_chunks
ds.to_zarr(tmp_path / "test.zarr")
ds.to_netcdf(tmp_path / "test.nc")

with dask.config.set({"array.chunk-size": "1MiB"}):
expected = ds.chunk(chunks)
actual = xr.open_dataset(tmp_path / "test.zarr", engine="zarr", chunks=chunks)
xr.testing.assert_chunks_equal(actual, expected)

actual = xr.open_dataset(tmp_path / "test.nc", chunks=chunks)
xr.testing.assert_chunks_equal(actual, expected)
