Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/3655.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed a bug in the sharding codec that prevented nested shard reads in certain cases.
13 changes: 10 additions & 3 deletions src/zarr/codecs/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
)
from zarr.core.metadata.v3 import parse_codecs
from zarr.registry import get_ndbuffer_class, get_pipeline_class
from zarr.storage._utils import _normalize_byte_range_index

if TYPE_CHECKING:
from collections.abc import Iterator
Expand Down Expand Up @@ -86,11 +87,16 @@ class _ShardingByteGetter(ByteGetter):
async def get(
self, prototype: BufferPrototype, byte_range: ByteRequest | None = None
) -> Buffer | None:
assert byte_range is None, "byte_range is not supported within shards"
assert prototype == default_buffer_prototype(), (
f"prototype is not supported within shards currently. diff: {prototype} != {default_buffer_prototype()}"
)
return self.shard_dict.get(self.chunk_coords)
value = self.shard_dict.get(self.chunk_coords)
if value is None:
return None
if byte_range is None:
return value
start, stop = _normalize_byte_range_index(value, byte_range)
return value[start:stop]


@dataclass(frozen=True)
Expand Down Expand Up @@ -597,7 +603,8 @@ async def _decode_shard_index(
)
)
)
assert index_array is not None
# This cannot be None because we have the bytes already
index_array = cast(NDBuffer, index_array)
return _ShardIndex(index_array.as_numpy_array())

async def _encode_shard_index(self, index: _ShardIndex) -> Buffer:
Expand Down
47 changes: 22 additions & 25 deletions tests/test_codecs/test_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
TransposeCodec,
)
from zarr.core.buffer import NDArrayLike, default_buffer_prototype
from zarr.errors import ZarrUserWarning
from zarr.storage import StorePath, ZipStore

from ..conftest import ArrayRequest
Expand Down Expand Up @@ -239,12 +238,14 @@ def test_sharding_partial_overwrite(
assert np.array_equal(data, read_data)


# Zip storage raises a warning about a duplicate name, which we ignore.
@pytest.mark.filterwarnings("ignore:Duplicate name.*:UserWarning")
@pytest.mark.parametrize(
"array_fixture",
[
ArrayRequest(shape=(128,) * 3, dtype="uint16", order="F"),
ArrayRequest(shape=(127, 128, 129), dtype="uint16", order="F"),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By making the shape of the array irregular with respect to the chunk shape, we hit the partial decode path, which is required for reproducing the bug reported in #3652.

],
indirect=["array_fixture"],
indirect=True,
)
@pytest.mark.parametrize(
"outer_index_location",
Expand All @@ -263,24 +264,23 @@ def test_nested_sharding(
) -> None:
data = array_fixture
spath = StorePath(store)
msg = "Combining a `sharding_indexed` codec disables partial reads and writes, which may lead to inefficient performance."
with pytest.warns(ZarrUserWarning, match=msg):
a = zarr.create_array(
spath,
shape=data.shape,
chunks=(64, 64, 64),
dtype=data.dtype,
fill_value=0,
serializer=ShardingCodec(
chunk_shape=(32, 32, 32),
codecs=[
ShardingCodec(chunk_shape=(16, 16, 16), index_location=inner_index_location)
],
index_location=outer_index_location,
),
)
# compressors=None ensures no BytesBytesCodec is added, which keeps
# supports_partial_decode=True and exercises the partial decode path
a = zarr.create_array(
spath,
data=data,
chunks=(64,) * data.ndim,
compressors=None,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Setting `compressors` to `None` here is also required to trigger the partial decode path.

serializer=ShardingCodec(
chunk_shape=(32,) * data.ndim,
codecs=[
ShardingCodec(chunk_shape=(16,) * data.ndim, index_location=inner_index_location)
],
index_location=outer_index_location,
),
)

a[:, :, :] = data
a[:] = data

read_data = a[0 : data.shape[0], 0 : data.shape[1], 0 : data.shape[2]]
assert isinstance(read_data, NDArrayLike)
Expand Down Expand Up @@ -326,13 +326,10 @@ def test_nested_sharding_create_array(
filters=None,
compressors=None,
)
print(a.metadata.to_dict())

a[:, :, :] = data
a[:] = data

read_data = a[0 : data.shape[0], 0 : data.shape[1], 0 : data.shape[2]]
assert isinstance(read_data, NDArrayLike)
assert data.shape == read_data.shape
read_data = a[:]
assert np.array_equal(data, read_data)


Expand Down