Skip to content

Commit

Permalink
Address review feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexey-Kamenev committed Jun 29, 2023
1 parent eb257d4 commit ff14cc0
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 11 deletions.
13 changes: 11 additions & 2 deletions python/kvikio/nvcomp_codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,16 +104,21 @@ def encode_batch(self, bufs: List[Any]) -> List[Any]:
temp_buf = cp.empty(temp_size, dtype=cp.uint8)

# Includes header with the original buffer size,
# same as in numcodecs codec.
# same as in numcodecs codec. This enables data compatibility between
# numcodecs default codecs and this nvCOMP batch codec.
# TODO(akamenev): consider using a single contiguous buffer that stores all chunks.
comp_chunks_header = [
cp.empty(self.HEADER_SIZE_BYTES + comp_chunk_size, dtype=cp.uint8)
for _ in range(num_chunks)
]
# comp_chunks is used as a container that stores pointers to actual chunks.
# nvCOMP requires this container to be in GPU memory.
comp_chunks = cp.array(
[c.data.ptr + self.HEADER_SIZE_BYTES for c in comp_chunks_header],
dtype=cp.uint64,
)
# Similar to comp_chunks, comp_chunk_sizes is an array that contains
# chunk sizes and is required by nvCOMP to be in GPU memory.
comp_chunk_sizes = cp.empty(num_chunks, dtype=cp.uint64)

self._algo.compress(
Expand Down Expand Up @@ -213,12 +218,16 @@ def decode_batch(
temp_buf = cp.empty(temp_size, dtype=cp.uint8)

# Prepare uncompressed chunks buffers.
# First, allocate chunks of appropriate sizes and then
# copy the pointers to a pointer array in GPU memory as required by nvCOMP.
# TODO(akamenev): could probably allocate a single contiguous buffer instead.
uncomp_chunks = [cp.empty(size, dtype=cp.uint8) for size in uncomp_chunk_sizes]
uncomp_chunk_ptrs = cp.array(
[c.data.ptr for c in uncomp_chunks], dtype=cp.uint64
)

# Sizes array must be in GPU memory.
uncomp_chunk_sizes = cp.array(uncomp_chunk_sizes, dtype=cp.uint64)

# TODO(akamenev): currently we provide the following 2 buffers to decompress()
# but do not check/use them afterwards since some of the algos
# (e.g. LZ4 and Gdeflate) do not require it and run faster
Expand Down
33 changes: 24 additions & 9 deletions python/tests/test_nvcomp_codec.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
# See file LICENSE for terms.

import itertools as it
import json

import numcodecs
Expand All @@ -26,15 +27,31 @@ def _get_codec(algo: str):
return numcodecs.registry.get_codec(codec_args)


@pytest.fixture(params=[(16,), (8, 16), (16, 16)])
def shape(request):
    """Provide one array shape per parametrized run (1-D and 2-D cases)."""
    array_shape = request.param
    return array_shape


# Shapes and chunks are paired in a dedicated fixture because the chunks
# tuple must have the same rank as the data array it is applied to.
@pytest.fixture(
    params=[
        # 1-D data paired with 1-D chunk sizes.
        *it.product([(32,)], [(16,), (32,), (40,)]),
        # 2-D data paired with 2-D chunk sizes.
        *it.product([(16, 8), (16, 16)], [(8, 16), (16, 16), (40, 12)]),
    ]
)
def shape_chunks(request):
    """Provide one ``(shape, chunks)`` pair of matching rank per run."""
    shape_and_chunks = request.param
    return shape_and_chunks


@pytest.mark.parametrize("algo", SUPPORTED_CODECS)
def test_codec_registry(algo: str):
    """The codec returned by ``_get_codec`` for ``algo`` is a numcodecs ``Codec``."""
    assert isinstance(_get_codec(algo), numcodecs.abc.Codec)


@pytest.mark.parametrize("algo", SUPPORTED_CODECS)
def test_basic(algo: str):
shape = (16, 16)
def test_basic(algo: str, shape):
codec = NvCompBatchCodec(algo)

# Create data.
Expand All @@ -49,9 +66,8 @@ def test_basic(algo: str):


@pytest.mark.parametrize("algo", SUPPORTED_CODECS)
def test_basic_zarr(algo: str):
shape = (16, 16)
chunks = (8, 8)
def test_basic_zarr(algo: str, shape_chunks):
shape, chunks = shape_chunks

codec = NvCompBatchCodec(algo)

Expand Down Expand Up @@ -86,14 +102,13 @@ def test_batch_comp_decomp(algo: str):


@pytest.mark.parametrize("algo", SUPPORTED_CODECS)
def test_comp_decomp(algo: str):
def test_comp_decomp(algo: str, shape_chunks):
shape, chunks = shape_chunks

codec = _get_codec(algo)

np.random.seed(1)

shape = (16, 16)
chunks = (8, 8)

data = np.random.randn(*shape).astype(np.float32)

z1 = zarr.array(data, chunks=chunks, compressor=codec)
Expand Down

0 comments on commit ff14cc0

Please sign in to comment.