Skip to content

Commit

Permalink
Use a more complex chunks header (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
andersy005 authored Feb 2, 2023
1 parent 411fe3a commit 38507bd
Show file tree
Hide file tree
Showing 5 changed files with 267 additions and 53 deletions.
7 changes: 0 additions & 7 deletions setup.cfg

This file was deleted.

65 changes: 64 additions & 1 deletion tests/test_logic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,41 @@
from zarr_proxy.logic import chunk_id_to_slice
import pytest

from zarr_proxy.logic import (
chunk_id_to_slice,
parse_chunks_header,
validate_chunks_info,
)


@pytest.mark.parametrize(
"shape, chunks, chunk_index, chunks_block_shape",
[
((5, 5), (2, 2), (0, 0), (2, 2)),
((5, 5), (2, 2), (0, 1), (2, 2)),
((5, 5), (2, 2), (1, 0), (2, 2)),
((5, 5), (2, 2), (1, 1), (2, 2)),
],
)
def test_validate_chunks_info_valid(shape, chunks, chunk_index, chunks_block_shape):
validate_chunks_info(
shape=shape, chunks=chunks, chunk_index=chunk_index, chunks_block_shape=chunks_block_shape
)


@pytest.mark.parametrize(
"shape, chunks, chunk_index, chunks_block_shape",
[
((5, 5), (2, 2), (0, 5), (2, 2)),
((5, 5), (2, 2), (5, 0), (2, 2)),
((5, 5), (2, 2), (5, 5), (2, 2)),
((5, 5), (3, 3), (2, 2), (2, 2)),
],
)
def test_validate_chunks_info_invalid(shape, chunks, chunk_index, chunks_block_shape):
with pytest.raises(IndexError):
validate_chunks_info(
shape=shape, chunks=chunks, chunk_index=chunk_index, chunks_block_shape=chunks_block_shape
)


def test_chunk_id_to_slice():
Expand All @@ -9,3 +46,29 @@ def test_chunk_id_to_slice():
assert chunk_id_to_slice("0.0", shape=shape, chunks=chunks) == (slice(0, 1), slice(0, 2))
assert chunk_id_to_slice("1.0", shape=shape, chunks=chunks) == (slice(1, 2), slice(0, 2))
assert chunk_id_to_slice("1.1", shape=shape, chunks=chunks) == (slice(1, 2), slice(2, 4))


@pytest.mark.parametrize(
"chunks, expected_output",
[
("bed=10,10,prec=20,20,lat=5", {"bed": (10, 10), "prec": (20, 20), "lat": (5,)}),
("foo=1,1,bar=2,2", {"foo": (1, 1), "bar": (2, 2)}),
("Foo=1,1,BAR=2,2", {"Foo": (1, 1), "BAR": (2, 2)}),
("a=1,b=2,c=3,d=4", {"a": (1,), "b": (2,), "c": (3,), "d": (4,)}),
("", {}),
("bed=10,10,prec=20,20,lat=five", {'bed': (10, 10), 'prec': (20, 20)}),
(
'time=6443, analysed_sst=30,100,100, analysis_error=30,100,100, mask=30,100,100, sea_ice_fraction=30,100,100',
{
'time': (6443,),
'sst': (30, 100, 100),
'error': (30, 100, 100),
'mask': (30, 100, 100),
'fraction': (30, 100, 100),
},
),
],
)
def test_parse_chunks_header(chunks, expected_output):
output = parse_chunks_header(chunks)
assert output == expected_output
47 changes: 36 additions & 11 deletions tests/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
[
(
'storage.googleapis.com/carbonplan-maps/ncview/demo/single_timestep/air_temperature.zarr/air',
'10,10',
'lat=10,air=10,10',
),
(
'storage.googleapis.com/carbonplan-maps/ncview/demo/single_timestep/air_temperature.zarr/lat',
'5',
'lat=5,air=10,10',
),
(
'storage.googleapis.com/carbonplan-maps/ncview/demo/single_timestep/air_temperature.zarr/lon',
'10',
'lat=10,air=10,10',
),
],
)
Expand Down Expand Up @@ -49,14 +49,29 @@ def test_store_zattrs(test_app, store):


@pytest.mark.parametrize(
'store',
['storage.googleapis.com/carbonplan-maps/ncview/demo/single_timestep/air_temperature.zarr'],
'store,chunks,expected_air_chunks',
[
(
'storage.googleapis.com/carbonplan-maps/ncview/demo/single_timestep/air_temperature.zarr',
'',
[25, 53],
),
(
'storage.googleapis.com/carbonplan-maps/ncview/demo/single_timestep/air_temperature.zarr',
'lat=10,air=10,10',
[10, 10],
),
],
)
def test_store_zmetadata(test_app, store):
response = test_app.get(f"/{store}/.zmetadata")
def test_store_zmetadata(test_app, store, chunks, expected_air_chunks):
response = test_app.get(f"/{store}/.zmetadata", headers={'chunks': chunks})
assert response.status_code == 200

assert {'.zattrs', '.zgroup'}.issubset(response.json()['metadata'].keys())
zmetadata = response.json()['metadata']

assert {'.zattrs', '.zgroup'}.issubset(zmetadata.keys())

assert zmetadata['air/.zarray']['chunks'] == expected_air_chunks


@pytest.mark.parametrize(
Expand All @@ -65,17 +80,17 @@ def test_store_zmetadata(test_app, store):
(
'storage.googleapis.com/carbonplan-maps/ncview/demo/single_timestep/air_temperature.zarr/air',
'0.0',
'10,10',
'lat=10,air=10,10',
),
(
'storage.googleapis.com/carbonplan-maps/ncview/demo/single_timestep/air_temperature.zarr/lat',
'0',
'5',
'lat=10,air=10,10',
),
(
'storage.googleapis.com/carbonplan-maps/ncview/demo/single_timestep/air_temperature.zarr/lon',
'1',
'10',
'lat=10,air=10,10,lon=10',
),
],
)
Expand All @@ -87,3 +102,13 @@ def test_store_array_chunk(test_app, store, chunk_key, chunks):
# get the bytes from the response
data = response.content
assert isinstance(data, bytes)


def test_store_array_chunk_out_of_bounds(test_app):
array = 'storage.googleapis.com/carbonplan-maps/ncview/demo/single_timestep/air_temperature.zarr/lon'
chunk_key = 1
chunks = 'lat=10,air=10,10'
response = test_app.get(f"/{array}/{chunk_key}", headers={'chunks': chunks})

assert response.status_code == 400
assert 'The chunk_index: (1,) must be less than the chunks block shape: (1,)' in response.json()['detail']
126 changes: 114 additions & 12 deletions zarr_proxy/logic.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,120 @@
"""Logic for the zarr proxy"""

import functools
import math
import operator
import re
import typing


def parse_chunks_header(chunks: str) -> dict[str, tuple[int, ...]]:

"""Parse the chunks header into a dictionary of chunk keys and chunk sizes.
This turns a string like "bed=10,10,prec=20,20,lat=5" into a dictionary like
{"bed": (10, 10), "prec": (20, 20), "lat": (5,)}.
Parameters
----------
chunks: str
e.g. "bed=10,10,prec=20,20,lat=5"
Returns
-------
dict[str, tuple[int, ...]]
e.g. {"bed": (10, 10), "prec": (20, 20), "lat": (5,)}
"""

# ([a-zA-Z]+=[\d,]+) regular expression looks for one or more letters (both uppercase and lowercase)
# followed by an equal sign and one or more digits and commas. The letters are grouped
# together using square brackets and the "+" symbol specifies that there needs to be one
# or more of them. The "A-Z" and "a-z" inside the square brackets specify that it should
# match uppercase and lowercase letters. The equal sign and digits and commas are also matched
# with the "+" symbol to match one or more of them. The entire match is wrapped in a capturing group,
# indicated by the parentheses, which can be used to access the match in the code.

parsed_list = re.findall(r'([a-zA-Z]+=[\d,]+)', chunks)
parsed_dict = {}
for item in parsed_list:
key, value = item.split('=')
parsed_dict[key] = tuple(map(int, value.strip(',').split(',')))
return parsed_dict


def virtual_array_info(
*, shape: tuple[int, ...], chunks: tuple[int, ...]
) -> dict[str, typing.Union[tuple[int, ...], int]]:
"""
Return a dictionary of information about a virtual array
Parameters
----------
shape: tuple[int]
The shape of the array
chunks: tuple[int]
The desired chunking
Returns
-------
dict[str, tuple[int] | int]
def chunks_from_string(chunks: str) -> tuple[int]:
"""
Given a string of comma-separated integers, return a tuple of ints
chunks_block_shape = tuple(math.ceil(s / c) for s, c in zip(shape, chunks)) if shape else 1
nchunks = functools.reduce(operator.mul, chunks_block_shape, 1)

return {"chunks_block_shape": chunks_block_shape, "num_chunks": nchunks}


def validate_chunks_info(
*,
shape=tuple[int, ...],
chunks=tuple[int, ...],
chunk_index=tuple[int, ...],
chunks_block_shape=tuple[int, ...],
):
"""Validate the chunks info. Raise an error if the chunk_index is out of bounds
Parameters
----------
chunks: e.g. "1,2,3"
shape: tuple[int]
The shape of the array
chunks: tuple[int]
The desired chunking
chunk_index: tuple[int]
The chunk index
chunks_block_shape: tuple[int]
The chunks block shape
Returns
-------
None
Raises
------
IndexError
If the chunk_index is out of bounds
"""
return tuple(int(c) for c in chunks.split(","))

if len(chunk_index) != len(chunks):
raise IndexError(f"The length of chunk_index: {chunk_index} and chunks: {chunks} must be the same.")

if len(shape) != len(chunks):
raise IndexError(f"The length of shape: {shape} and chunks: {chunks} must be the same.")

for index, dim_size in enumerate(chunks_block_shape):
if chunk_index[index] < 0 or chunk_index[index] >= dim_size:
raise IndexError(
f"The chunk_index: {chunk_index} must be less than the chunks block shape: {chunks_block_shape}"
)


def chunk_id_to_slice(
chunk_key: str, *, chunks: tuple[int], shape: tuple[int], delimiter: str = "."
chunk_key: str, *, chunks: tuple[int, ...], shape: tuple[int, ...], delimiter: str = "."
) -> tuple[slice]:
"""
Given a Zarr chunk key, return the slice of an array to extract that data
Expand All @@ -25,15 +126,16 @@ def chunk_id_to_slice(
shape: the total array shape
delimiter: chunk separator character
"""
chunk_index = tuple(int(c) for c in chunk_key.split("."))

if len(chunk_index) != len(chunks):
raise ValueError(f"The length of chunk_index: {chunk_index} and chunks: {chunks} must be the same.")

if len(shape) != len(chunks):
raise ValueError(f"The length of shape: {shape} and chunks: {chunks} must be the same.")
chunk_index = tuple(int(c) for c in chunk_key.split(delimiter))

# TODO: more error handling maybe?
array_info = virtual_array_info(shape=shape, chunks=chunks)
validate_chunks_info(
shape=shape,
chunks=chunks,
chunk_index=chunk_index,
chunks_block_shape=array_info["chunks_block_shape"],
)
slices = tuple(
(slice(min(c * ci, s), min(c * (ci + 1), s)) for c, s, ci in zip(chunks, shape, chunk_index))
)
Expand Down
Loading

0 comments on commit 38507bd

Please sign in to comment.