diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..c29b638 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,2 @@ +[run] +omit = src/akimbo/cudf.py \ No newline at end of file diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml index 0a1296c..914f094 100644 --- a/.github/workflows/pypi.yml +++ b/.github/workflows/pypi.yml @@ -34,4 +34,4 @@ jobs: pip list - name: test run: | - python -m pytest -v --cov akimbo + python -m pytest -v --cov-config=.coveragerc --cov akimbo diff --git a/.gitignore b/.gitignore index ad2c9bf..b7f5bbc 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ # emacs .dir-locals.el +.idea # setuptools_scm src/akimbo/version.py diff --git a/example/cuda_env.yaml b/example/cuda_env.yaml new file mode 100644 index 0000000..c6b12b2 --- /dev/null +++ b/example/cuda_env.yaml @@ -0,0 +1,22 @@ +name: cuda +channels: + - conda-forge +dependencies: + - python=3.10 + - cuda-cudart + - cuda-version=12.2 + - pycuda + - cupy + - numba + - awkward + - rapidsai::cudf + - ipython + - numba + - pyarrow + - pandas + - polars + - pytest + - distributed + - dask-awkward + - pytest + - rox diff --git a/example/cudf-ak.ipynb b/example/cudf-ak.ipynb new file mode 100644 index 0000000..db5008d --- /dev/null +++ b/example/cudf-ak.ipynb @@ -0,0 +1,626 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "58d18a3a-45b1-425a-b822-e8be0a6c0bc0", + "metadata": {}, + "source": [ + "\n", + "```python\n", + "import awkward as ak\n", + "\n", + "def make_data(fn):\n", + " part = [[[1, 2, 3], [], [4, 5]],\n", + " [[6, 7]]] * 1000000\n", + " arr = ak.Array({\"a\": part})\n", + " ak.to_parquet(arr, fn, extensionarray=False)\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "cefd8e53-a56f-4b0c-88d2-d662d59849a7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "('2.6.9', '2024.8.1.dev29+g9b9f27f.d20240927')" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + 
} + ], + "source": [ + "import awkward as ak\n", + "import cupy as cp\n", + "import cudf\n", + "import numpy as np\n", + "import akimbo.cudf\n", + "import subprocess\n", + "\n", + "def gpu_mem():\n", + " print(subprocess.check_output(\"nvidia-smi | grep py\", shell=True).split()[-2].decode())\n", + "\n", + "ak.__version__, akimbo.__version__" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "0490043a-564a-4c11-bb0d-a54fb4c6fb10", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "160MiB\n" + ] + } + ], + "source": [ + "df = cudf.read_parquet(\"/floppy/code/awkward/s.parquet\")\n", + "gpu_mem()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "e29ff9a4-60e4-4260-9a44-c135ad6d7d6b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "a list\n", + "dtype: object" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.dtypes" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "58d16a80-041e-4260-8c56-9de932dde557", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "a [[1, 2, 3], [], [4, 5]]\n", + "Name: 0, dtype: list" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.iloc[0] # each element is list-of-lists" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "c7b65320-e1fa-44b2-a232-6ffb97ba1d18", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "['Mask',\n", + " 'all',\n", + " 'almost_equal',\n", + " 'angle',\n", + " 'annotations',\n", + " 'any',\n", + " 'apply',\n", + " 'argcartesian',\n", + " 'argcombinations',\n", + " 'argmax',\n", + " 'argmin',\n", + " 'argsort',\n", + " 'array',\n", + " 'array_equal',\n", + " 'attrs',\n", + " 'awkward',\n", + " 'backend',\n", + " 'behavior',\n", + " 'behaviors',\n", + " 
'broadcast_arrays',\n", + " 'broadcast_fields',\n", + " 'builder',\n", + " 'cartesian',\n", + " 'categories',\n", + " 'combinations',\n", + " 'concatenate',\n", + " 'contents',\n", + " 'copy',\n", + " 'corr',\n", + " 'count',\n", + " 'count_nonzero',\n", + " 'covar',\n", + " 'cpp_type',\n", + " 'cppyy',\n", + " 'drop_none',\n", + " 'dt',\n", + " 'enforce_type',\n", + " 'errors',\n", + " 'explode',\n", + " 'fields',\n", + " 'fill_none',\n", + " 'firsts',\n", + " 'flatten',\n", + " 'forms',\n", + " 'forth',\n", + " 'from_arrow',\n", + " 'from_arrow_schema',\n", + " 'from_avro_file',\n", + " 'from_buffers',\n", + " 'from_categorical',\n", + " 'from_cupy',\n", + " 'from_dlpack',\n", + " 'from_feather',\n", + " 'from_iter',\n", + " 'from_jax',\n", + " 'from_json',\n", + " 'from_numpy',\n", + " 'from_parquet',\n", + " 'from_raggedtensor',\n", + " 'from_rdataframe',\n", + " 'from_regular',\n", + " 'from_torch',\n", + " 'full_like',\n", + " 'highlevel',\n", + " 'imag',\n", + " 'index',\n", + " 'is_categorical',\n", + " 'is_none',\n", + " 'is_tuple',\n", + " 'is_valid',\n", + " 'isclose',\n", + " 'jax',\n", + " 'layout',\n", + " 'linear_fit',\n", + " 'local_index',\n", + " 'mask',\n", + " 'max',\n", + " 'mean',\n", + " 'merge_option_of_records',\n", + " 'merge_union_of_records',\n", + " 'metadata_from_parquet',\n", + " 'min',\n", + " 'mixin_class',\n", + " 'mixin_class_method',\n", + " 'moment',\n", + " 'nan_to_none',\n", + " 'nan_to_num',\n", + " 'nanargmax',\n", + " 'nanargmin',\n", + " 'nanmax',\n", + " 'nanmean',\n", + " 'nanmin',\n", + " 'nanprod',\n", + " 'nanstd',\n", + " 'nansum',\n", + " 'nanvar',\n", + " 'nbytes',\n", + " 'ndim',\n", + " 'num',\n", + " 'numba',\n", + " 'numba_type',\n", + " 'ones_like',\n", + " 'operations',\n", + " 'pad_none',\n", + " 'parameters',\n", + " 'prettyprint',\n", + " 'prod',\n", + " 'ptp',\n", + " 'ravel',\n", + " 'real',\n", + " 'record',\n", + " 'round',\n", + " 'run_lengths',\n", + " 'show',\n", + " 'singletons',\n", + " 
'softmax',\n", + " 'sort',\n", + " 'std',\n", + " 'str',\n", + " 'strings_astype',\n", + " 'sum',\n", + " 'to_arrow',\n", + " 'to_arrow_table',\n", + " 'to_backend',\n", + " 'to_buffers',\n", + " 'to_cudf',\n", + " 'to_cupy',\n", + " 'to_dataframe',\n", + " 'to_feather',\n", + " 'to_jax',\n", + " 'to_json',\n", + " 'to_layout',\n", + " 'to_list',\n", + " 'to_numpy',\n", + " 'to_packed',\n", + " 'to_parquet',\n", + " 'to_parquet_dataset',\n", + " 'to_parquet_row_groups',\n", + " 'to_raggedtensor',\n", + " 'to_rdataframe',\n", + " 'to_regular',\n", + " 'to_torch',\n", + " 'tolist',\n", + " 'transform',\n", + " 'type',\n", + " 'types',\n", + " 'typestr',\n", + " 'typetracer',\n", + " 'unflatten',\n", + " 'unmerge',\n", + " 'unzip',\n", + " 'validity_error',\n", + " 'values_astype',\n", + " 'var',\n", + " 'where',\n", + " 'with_field',\n", + " 'with_name',\n", + " 'with_parameter',\n", + " 'without_field',\n", + " 'without_parameters',\n", + " 'zeros_like',\n", + " 'zip']" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# allows all ak.* namespace, many identical to numpy equivalents\n", + "dir(df.a.ak)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "8ff11e13-8503-4d79-a64c-993028709ca4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array(28000000)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.a.ak.sum(axis=None)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "2dd99fe5-0523-46c9-87ec-1392070f5139", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "cupy.ndarray" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# if output was array-like, it stays on the GPU\n", + "type(_)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "9d8e55cf-8cf1-40a0-8733-24b7719f431d", + "metadata": 
{}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4.83 ms ± 16 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + ] + } + ], + "source": [ + "# fast reduction across three levels of nesting\n", + "%timeit df.a.ak.sum(axis=None)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "fae94aea-d9cf-4228-bcab-f843c7cc9c98", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0 [[-1, -2, -3], [], [-4, -5]]\n", + "1 [[-6, -7]]\n", + "2 [[-1, -2, -3], [], [-4, -5]]\n", + "3 [[-6, -7]]\n", + "4 [[-1, -2, -3], [], [-4, -5]]\n", + " ... \n", + "1999995 [[-6, -7]]\n", + "1999996 [[-1, -2, -3], [], [-4, -5]]\n", + "1999997 [[-6, -7]]\n", + "1999998 [[-1, -2, -3], [], [-4, -5]]\n", + "1999999 [[-6, -7]]\n", + "Length: 2000000, dtype: list" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# ufunc maintains structure\n", + "np.negative(df.a.ak)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "1b83da2c-5e15-42f6-b594-f2ebaece5ac8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "256MiB\n" + ] + } + ], + "source": [ + "gpu_mem() # created new arrays on GPU, made new cuDF series" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "558ca2c3-d6c7-4404-bcab-557b9b03f795", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0 [[2, 3, 4], [], [5, 6]]\n", + "1 [[7, 8]]\n", + "2 [[2, 3, 4], [], [5, 6]]\n", + "3 [[7, 8]]\n", + "4 [[2, 3, 4], [], [5, 6]]\n", + "dtype: list\n" + ] + } + ], + "source": [ + "# operator overload\n", + "print((df.a.ak + 1).head())" + ] + }, + { + "cell_type": "markdown", + "id": "bb51c8c3-42cf-4999-b688-67703f7311d2", + "metadata": {}, + "source": [ + "#### numba" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "d240ea54-87b4-4b99-b67f-b2f885a4bf5e", + "metadata": { + 
"scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "array([15, 13, 15, ..., 13, 15, 13], dtype=int32)" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import numba.cuda\n", + "ak.numba.register_and_check()\n", + "\n", + "@numba.cuda.jit(extensions=[ak.numba.cuda])\n", + "def inner_sum(array, out):\n", + " tid = numba.cuda.grid(1)\n", + " if tid < len(array):\n", + " out[tid] = 0\n", + " for x in array[tid]:\n", + " for y in x:\n", + " out[tid] += y\n", + "\n", + "out = cp.empty(len(df.a), dtype=\"int32\")\n", + "blocksize = 256\n", + "numblocks = (len(df.a) + blocksize - 1) // blocksize\n", + "\n", + "df.a.ak.apply(lambda x: inner_sum[numblocks, blocksize](ak.drop_none(x, axis=0), out))\n", + "out\n" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "73a35144-292f-4b1d-bbc0-4ebba2a84b0d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6.17 ms ± 118 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + ] + } + ], + "source": [ + "%timeit df.a.ak.apply(lambda x: inner_sum[numblocks, blocksize](ak.drop_none(x, axis=0), out))" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "bb781ca6-bdbd-4659-9885-8c634f490fca", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "264MiB\n" + ] + } + ], + "source": [ + "gpu_mem() " + ] + }, + { + "cell_type": "markdown", + "id": "6d1ffd1a-b53b-4657-bab6-9c9223c28808", + "metadata": {}, + "source": [ + "**slice**" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "d039a508-e77c-4e23-a583-ec7997a88bb1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0 [[1], [], [4]]\n", + "1 [[6]]\n", + "2 [[1], [], [4]]\n", + "3 [[6]]\n", + "4 [[1], [], [4]]\n", + " ... 
\n", + "1999995 [[6]]\n", + "1999996 [[1], [], [4]]\n", + "1999997 [[6]]\n", + "1999998 [[1], [], [4]]\n", + "1999999 [[6]]\n", + "Length: 2000000, dtype: list" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# pick the first number of the innermost lists, if there is one\n", + "df.a.ak[:, :, :1]" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "f149dfaf-c01e-4d0a-8e01-2d20623d216f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0 [1, 2, 3]\n", + "1 [6, 7]\n", + "2 [1, 2, 3]\n", + "3 [6, 7]\n", + "4 [1, 2, 3]\n", + " ... \n", + "1999995 [6, 7]\n", + "1999996 [1, 2, 3]\n", + "1999997 [6, 7]\n", + "1999998 [1, 2, 3]\n", + "1999999 [6, 7]\n", + "Length: 2000000, dtype: list" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# pick the first inner list of each row\n", + "df.a.ak[:, 0, :]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5aaf1903-6a6a-456f-89a7-3dedb01520ad", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python [conda env:cuda] *", + "language": "python", + "name": "conda-env-cuda-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyproject.toml b/pyproject.toml index 7d9bda7..eba4bf8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ docs = [ "pandas", "polars", "dask", "pyarrow", "pandoc", "nbsphinx" ] test = [ - "pandas", "polars", "dask", "pyarrow", "pytest", "pytest-cov", "numba", "dask-awkward" + "pandas", "polars", "dask", "pyarrow", "pytest", "pytest-cov", "numba", "dask-awkward", "distributed" ] 
def revertable(modified, original):
    """Attach ``original`` to ``modified`` so the wrapping can be undone later.

    The un-wrapped layout is stored on ``modified`` under the attribute name
    ``__pyarrow_original``; ``remove_optiontype`` reads that same attribute.

    Returns:
        ``modified`` itself (the same object, now carrying the extra attribute).
    """
    setattr(modified, "__pyarrow_original", original)
    return modified
def popbuffers(paarray, arrow_type, buffers, generate_bitmasks):
    """Recursively convert a pyarrow array into an Awkward layout.

    Buffers are consumed from the front of ``buffers`` (via ``buffers.pop(0)``)
    in the order each Arrow type lays them out, so the list is mutated in place;
    ``handle_arrow`` asserts that it is empty once the top-level call returns.

    Args:
        paarray: the pyarrow array (or child array) being converted.
        arrow_type: the pyarrow type describing ``paarray``.
        buffers: flat list of the array's buffers; consumed destructively.
        generate_bitmasks: forwarded to ``popbuffers_finalize`` and to nested
            calls.

    Returns:
        An Awkward content node. Every branch returns an option-type wrapper
        (made "revertable" so the caller can strip it again), except the
        ``null`` branch which returns an ``IndexedOptionArray`` directly.

    Raises:
        NotImplementedError: for Arrow map types.
        ValueError: for decimal128 / decimal256 types.
        TypeError: when ``arrow_type`` is not a recognised Arrow type.
    """
    ### Beginning of the big if-elif-elif chain!

    if isinstance(arrow_type, pyarrow.lib.DictionaryType):
        masked_index = popbuffers(
            paarray.indices,
            arrow_type.index_type,
            buffers,
            generate_bitmasks,
        )
        index = masked_index.content.data

        if not isinstance(masked_index, ak.contents.UnmaskedArray):
            mask = masked_index.mask_as_bool(valid_when=False)
            if mask.any():
                # numpy.array(..., copy=True) instead of numpy.asarray(..., copy=True):
                # the ``copy`` keyword of ``asarray`` is not available on NumPy 1.x.
                index = numpy.array(index, copy=True)
                # missing entries are encoded as -1 for IndexedOptionArray
                index[mask] = -1

        content = handle_arrow(paarray.dictionary, generate_bitmasks)

        parameters = {"__array__": "categorical"}

        return revertable(
            ak.contents.IndexedOptionArray.simplified(
                ak.index.Index(index),
                content,
                parameters=parameters,
            ),
            ak.contents.IndexedArray(
                ak.index.Index(index),
                remove_optiontype(content) if content.is_option else content,
                parameters=parameters,
            ),
        )

    elif isinstance(arrow_type, pyarrow.lib.FixedSizeListType):
        assert arrow_type.num_buffers == 1
        validbits = buffers.pop(0)

        akcontent = popbuffers(
            paarray.values, arrow_type.value_type, buffers, generate_bitmasks
        )

        if not arrow_type.value_field.nullable:
            # strip the dummy option-type node
            akcontent = remove_optiontype(akcontent)

        out = ak.contents.RegularArray(akcontent, arrow_type.list_size, parameters=None)
        return popbuffers_finalize(out, paarray, validbits, generate_bitmasks)

    elif isinstance(arrow_type, (pyarrow.lib.LargeListType, pyarrow.lib.ListType)):
        assert arrow_type.num_buffers == 2
        validbits = buffers.pop(0)
        paoffsets = buffers.pop(0)

        # large lists carry 64-bit offsets, ordinary lists 32-bit offsets
        if isinstance(arrow_type, pyarrow.lib.LargeListType):
            akoffsets = ak.index.Index64(numpy.frombuffer(paoffsets, dtype=numpy.int64))
        else:
            akoffsets = ak.index.Index32(numpy.frombuffer(paoffsets, dtype=numpy.int32))

        akcontent = popbuffers(
            paarray.values, arrow_type.value_type, buffers, generate_bitmasks
        )

        if not arrow_type.value_field.nullable:
            # strip the dummy option-type node
            akcontent = remove_optiontype(akcontent)

        out = ak.contents.ListOffsetArray(akoffsets, akcontent, parameters=None)
        return popbuffers_finalize(out, paarray, validbits, generate_bitmasks)

    elif isinstance(arrow_type, pyarrow.lib.MapType):
        # FIXME: make a ListOffsetArray of 2-tuples with __array__ == "sorted_map".
        # (Make sure the keys are sorted).
        raise NotImplementedError

    elif isinstance(
        arrow_type, (pyarrow.lib.Decimal128Type, pyarrow.lib.Decimal256Type)
    ):
        # Note: Decimal128Type and Decimal256Type are subtypes of FixedSizeBinaryType,
        # so this check has to come before the FixedSizeBinaryType branch.
        # NumPy doesn't support decimal: https://github.com/numpy/numpy/issues/9789
        raise ValueError(
            "Arrow arrays containing pyarrow.decimal128 or pyarrow.decimal256 types can't be converted into Awkward Arrays"
        )

    elif isinstance(arrow_type, pyarrow.lib.FixedSizeBinaryType):
        assert arrow_type.num_buffers == 2
        validbits = buffers.pop(0)
        pacontent = buffers.pop(0)

        parameters = {"__array__": "bytestring"}
        sub_parameters = {"__array__": "byte"}

        out = ak.contents.RegularArray(
            ak.contents.NumpyArray(
                numpy.frombuffer(pacontent, dtype=numpy.uint8),
                parameters=sub_parameters,
                backend=NumpyBackend.instance(),
            ),
            arrow_type.byte_width,
            parameters=parameters,
        )
        return popbuffers_finalize(out, paarray, validbits, generate_bitmasks)

    elif arrow_type in _string_like:
        assert arrow_type.num_buffers == 3
        validbits = buffers.pop(0)
        paoffsets = buffers.pop(0)
        pacontent = buffers.pop(0)

        # _string_like is ordered (string, large_string, binary, large_binary):
        # even positions are the 32-bit-offset variants ...
        if arrow_type in _string_like[::2]:
            akoffsets = ak.index.Index32(numpy.frombuffer(paoffsets, dtype=numpy.int32))
        else:
            akoffsets = ak.index.Index64(numpy.frombuffer(paoffsets, dtype=numpy.int64))

        # ... and the first two entries are text, the last two raw bytes.
        if arrow_type in _string_like[:2]:
            parameters = {"__array__": "string"}
            sub_parameters = {"__array__": "char"}
        else:
            parameters = {"__array__": "bytestring"}
            sub_parameters = {"__array__": "byte"}

        out = ak.contents.ListOffsetArray(
            akoffsets,
            ak.contents.NumpyArray(
                numpy.frombuffer(pacontent, dtype=numpy.uint8),
                parameters=sub_parameters,
                backend=NumpyBackend.instance(),
            ),
            parameters=parameters,
        )
        return popbuffers_finalize(out, paarray, validbits, generate_bitmasks)

    elif isinstance(arrow_type, pyarrow.lib.StructType):
        assert arrow_type.num_buffers == 1
        validbits = buffers.pop(0)

        keys = []
        contents = []
        for i in range(arrow_type.num_fields):
            field = arrow_type[i]
            field_name = field.name
            keys.append(field_name)

            akcontent = popbuffers(
                paarray.field(field_name), field.type, buffers, generate_bitmasks
            )
            if not field.nullable:
                # strip the dummy option-type node
                akcontent = remove_optiontype(akcontent)
            contents.append(akcontent)

        out = ak.contents.RecordArray(
            contents, keys, length=len(paarray), parameters=None
        )
        return popbuffers_finalize(
            out, paarray, validbits, generate_bitmasks, fix_offsets=False
        )

    elif isinstance(arrow_type, pyarrow.lib.UnionType):
        if isinstance(arrow_type, pyarrow.lib.SparseUnionType):
            assert arrow_type.num_buffers == 2
            validbits = buffers.pop(0)
            nptags = numpy.frombuffer(buffers.pop(0), dtype=numpy.int8)
            # sparse unions have no index buffer; synthesise 0..len-1
            npindex = numpy.arange(len(nptags), dtype=numpy.int32)
        else:
            assert arrow_type.num_buffers == 3
            validbits = buffers.pop(0)
            nptags = numpy.frombuffer(buffers.pop(0), dtype=numpy.int8)
            npindex = numpy.frombuffer(buffers.pop(0), dtype=numpy.int32)

        akcontents = []
        for i in range(arrow_type.num_fields):
            field = arrow_type[i]
            akcontent = popbuffers(
                paarray.field(i), field.type, buffers, generate_bitmasks
            )

            if not field.nullable:
                # strip the dummy option-type node
                akcontent = remove_optiontype(akcontent)
            akcontents.append(akcontent)

        out = ak.contents.UnionArray.simplified(
            ak.index.Index8(nptags),
            ak.index.Index32(npindex),
            akcontents,
            parameters=None,
        )
        # the union's own validity buffer is popped above but deliberately not applied
        return popbuffers_finalize(out, paarray, None, generate_bitmasks)

    elif arrow_type == pyarrow.null():
        # the buffer must still be consumed so later siblings see the right one
        validbits = buffers.pop(0)
        assert arrow_type.num_fields == 0

        # This is already an option-type and offsets-corrected, so no popbuffers_finalize.
        return ak.contents.IndexedOptionArray(
            ak.index.Index64(numpy.full(len(paarray), -1, dtype=numpy.int64)),
            ak.contents.EmptyArray(parameters=None),
            parameters=None,
        )

    elif arrow_type == pyarrow.bool_():
        assert arrow_type.num_buffers == 2
        validbits = buffers.pop(0)
        bitdata = buffers.pop(0)

        # Arrow packs booleans one per bit (little-endian); expand to one per byte
        bytedata = numpy.unpackbits(
            numpy.frombuffer(bitdata, dtype=numpy.uint8), bitorder="little"
        )

        out = ak.contents.NumpyArray(
            bytedata.view(numpy.bool_),
            parameters=None,
            backend=NumpyBackend.instance(),
        )
        return popbuffers_finalize(out, paarray, validbits, generate_bitmasks)

    elif isinstance(arrow_type, pyarrow.lib.DataType):
        assert arrow_type.num_buffers == 2
        validbits = buffers.pop(0)
        data = buffers.pop(0)

        # NOTE(review): the table is keyed by pyarrow type objects but queried with
        # str(arrow_type); this presumably relies on pyarrow types hashing and
        # comparing equal to their string alias -- confirm before changing.
        to64, dt = _pyarrow_to_numpy_dtype.get(str(arrow_type), (False, None))
        if to64:
            # Widen 32-bit storage to 64 bits. Uses the ndarray method because the
            # module-level ``numpy.astype`` function does not exist on NumPy 1.x.
            data = numpy.frombuffer(data, dtype=numpy.int32).astype(numpy.int64)
        if dt is None:
            dt = arrow_type.to_pandas_dtype()

        out = ak.contents.NumpyArray(
            numpy.frombuffer(data, dtype=dt),
            parameters=None,
            backend=NumpyBackend.instance(),
        )
        return popbuffers_finalize(out, paarray, validbits, generate_bitmasks)

    else:
        raise TypeError(f"unrecognized Arrow array type: {arrow_type!r}")
def recurse(
    column: cudf.core.column.column.ColumnBase,
    arrow_type: pyarrow.lib.DataType,
    generate_bitmasks: bool,
):
    """Recursively convert a cuDF column into an Awkward layout backed by CuPy.

    This mirrors ``popbuffers`` but reads buffers from the column's attributes
    (``base_mask``, ``base_data``, ``base_children``, ``offsets``) instead of
    popping them from a flat list.

    Args:
        column: the cuDF column (or child column) to convert.
        arrow_type: the pyarrow type describing ``column``; it is not consulted
            for categorical columns, which are detected from ``column`` itself.
        generate_bitmasks: forwarded to ``recurse_finalize`` and nested calls.

    Returns:
        An Awkward content node wrapped as option-type and made "revertable",
        except for the ``null`` branch which returns an ``IndexedOptionArray``
        directly.

    Raises:
        NotImplementedError: for Arrow map and union types.
        ValueError: for decimal128 / decimal256 types.
        TypeError: when ``arrow_type`` is not a recognised Arrow type.
    """
    if isinstance(column, cudf.core.column.CategoricalColumn):
        # NOTE(review): the categorical column's own base_mask is not consulted;
        # missing values are taken only from the mask of the codes child below.
        # Confirm that this is sufficient.
        paindex = column.base_children[-1]
        masked_index = recurse(paindex, arrow_type_of(paindex), generate_bitmasks)
        index = masked_index.content.data

        if not isinstance(masked_index, ak.contents.UnmaskedArray):
            mask = masked_index.mask_as_bool(valid_when=False)
            if mask.any():
                # cupy.array(..., copy=True) is used for the copy so this does not
                # depend on ``cupy.asarray`` accepting a ``copy`` keyword.
                index = cupy.array(index, copy=True)
                # missing entries are encoded as -1 for IndexedOptionArray
                index[mask] = -1

        pacats = column.categories
        content = recurse(pacats, arrow_type_of(pacats), generate_bitmasks)

        # akindex1 feeds the option-type node, akindex2 the plain IndexedArray;
        # they differ only for uint32 codes.
        if index.dtype == cupy.dtype(cupy.int64):
            akindex1 = ak.index.Index64(index)
            akindex2 = akindex1
        elif index.dtype == cupy.dtype(cupy.uint32):
            akindex1 = ak.index.Index64(index.astype(cupy.int64))
            akindex2 = ak.index.IndexU32(index)
        elif index.dtype == cupy.dtype(cupy.int32):
            akindex1 = ak.index.Index32(index)
            akindex2 = akindex1
        else:
            akindex1 = ak.index.Index64(index.astype(cupy.int64))
            akindex2 = akindex1

        return revertable(
            ak.contents.IndexedOptionArray.simplified(
                akindex1,
                content,
                parameters={"__array__": "categorical"},
            ),
            ak.contents.IndexedArray(
                akindex2,
                remove_optiontype(content) if content.is_option else content,
                parameters={"__array__": "categorical"},
            ),
        )

    elif isinstance(arrow_type, pyarrow.lib.FixedSizeListType):
        validbits = column.base_mask

        akcontent = recurse(
            column.base_children[-1], arrow_type.value_type, generate_bitmasks
        )

        if not arrow_type.value_field.nullable:
            # strip the dummy option-type node
            akcontent = remove_optiontype(akcontent)

        out = ak.contents.RegularArray(akcontent, arrow_type.list_size, parameters=None)
        return recurse_finalize(out, column, validbits, generate_bitmasks)

    elif isinstance(arrow_type, (pyarrow.lib.LargeListType, pyarrow.lib.ListType)):
        validbits = column.base_mask
        paoffsets = column.offsets.base_data

        # large lists carry 64-bit offsets, ordinary lists 32-bit offsets
        if isinstance(arrow_type, pyarrow.lib.LargeListType):
            akoffsets = ak.index.Index64(cupy.asarray(paoffsets).view(cupy.int64))
        else:
            akoffsets = ak.index.Index32(cupy.asarray(paoffsets).view(cupy.int32))

        akcontent = recurse(
            column.base_children[-1], arrow_type.value_type, generate_bitmasks
        )

        if not arrow_type.value_field.nullable:
            # strip the dummy option-type node
            akcontent = remove_optiontype(akcontent)

        out = ak.contents.ListOffsetArray(akoffsets, akcontent, parameters=None)
        return recurse_finalize(out, column, validbits, generate_bitmasks)

    elif isinstance(arrow_type, pyarrow.lib.MapType):
        # FIXME: make a ListOffsetArray of 2-tuples with __array__ == "sorted_map".
        # (Make sure the keys are sorted).
        raise NotImplementedError

    elif isinstance(
        arrow_type, (pyarrow.lib.Decimal128Type, pyarrow.lib.Decimal256Type)
    ):
        # Note: Decimal128Type and Decimal256Type are subtypes of FixedSizeBinaryType,
        # so this check has to come before the FixedSizeBinaryType branch.
        # NumPy doesn't support decimal: https://github.com/numpy/numpy/issues/9789
        raise ValueError(
            "Arrow arrays containing pyarrow.decimal128 or pyarrow.decimal256 types can't be converted into Awkward Arrays"
        )

    elif isinstance(arrow_type, pyarrow.lib.FixedSizeBinaryType):
        validbits = column.base_mask
        pacontent = column.base_data

        parameters = {"__array__": "bytestring"}
        sub_parameters = {"__array__": "byte"}

        out = ak.contents.RegularArray(
            ak.contents.NumpyArray(
                cupy.asarray(pacontent),
                parameters=sub_parameters,
                backend=CupyBackend.instance(),
            ),
            arrow_type.byte_width,
            parameters=parameters,
        )
        return recurse_finalize(out, column, validbits, generate_bitmasks)

    elif arrow_type in _string_like:
        validbits = column.base_mask

        paoffsets = column.base_children[-1]
        pacontent = column.base_data

        # _string_like is ordered (string, large_string, binary, large_binary):
        # even positions are the 32-bit-offset variants ...
        if arrow_type in _string_like[::2]:
            akoffsets = ak.index.Index32(cupy.asarray(paoffsets).view(cupy.int32))
        else:
            akoffsets = ak.index.Index64(cupy.asarray(paoffsets).view(cupy.int64))

        # ... and the first two entries are text, the last two raw bytes.
        if arrow_type in _string_like[:2]:
            parameters = {"__array__": "string"}
            sub_parameters = {"__array__": "char"}
        else:
            parameters = {"__array__": "bytestring"}
            sub_parameters = {"__array__": "byte"}

        out = ak.contents.ListOffsetArray(
            akoffsets,
            ak.contents.NumpyArray(
                cupy.asarray(pacontent),
                parameters=sub_parameters,
                backend=CupyBackend.instance(),
            ),
            parameters=parameters,
        )
        return recurse_finalize(out, column, validbits, generate_bitmasks)

    elif isinstance(arrow_type, pyarrow.lib.StructType):
        validbits = column.base_mask

        keys = []
        contents = []
        for i in range(arrow_type.num_fields):
            field = arrow_type[i]
            field_name = field.name
            keys.append(field_name)

            akcontent = recurse(column.base_children[i], field.type, generate_bitmasks)
            if not field.nullable:
                # strip the dummy option-type node
                akcontent = remove_optiontype(akcontent)
            contents.append(akcontent)

        out = ak.contents.RecordArray(
            contents, keys, length=len(column), parameters=None
        )
        return recurse_finalize(out, column, validbits, generate_bitmasks)

    elif isinstance(arrow_type, pyarrow.lib.UnionType):
        raise NotImplementedError

    elif arrow_type == pyarrow.null():
        # This is already an option-type and offsets-corrected, so no recurse_finalize.
        return ak.contents.IndexedOptionArray(
            ak.index.Index64(cupy.full(len(column), -1, dtype=cupy.int64)),
            ak.contents.EmptyArray(parameters=None),
            parameters=None,
        )

    elif arrow_type == pyarrow.bool_():
        validbits = column.base_mask

        ## boolean data from CuDF differs from Arrow: it's represented as bytes, not bits!
        # so, unlike popbuffers, no unpackbits step is needed here
        bytedata = cupy.asarray(column.base_data)

        out = ak.contents.NumpyArray(
            cupy.asarray(bytedata).view(cupy.bool_),
            parameters=None,
            backend=CupyBackend.instance(),
        )
        return recurse_finalize(out, column, validbits, generate_bitmasks)

    elif isinstance(arrow_type, pyarrow.lib.DataType):
        validbits = column.base_mask

        to64, dt = _pyarrow_to_numpy_dtype.get(str(arrow_type), (False, None))
        # Bind the raw device buffer first: previously ``data`` was read before
        # assignment in the ``to64`` path and the widened result was never used.
        data = cupy.asarray(column.base_data)
        if to64:
            # widen 32-bit storage to 64 bits before reinterpreting as ``dt``
            data = data.view(cupy.int32).astype(cupy.int64)
        if dt is None:
            dt = arrow_type.to_pandas_dtype()

        out = ak.contents.NumpyArray(
            data.view(dt),
            parameters=None,
            backend=CupyBackend.instance(),
        )
        return recurse_finalize(out, column, validbits, generate_bitmasks)

    else:
        raise TypeError(f"unrecognized Arrow array type: {arrow_type!r}")
TypeError("Python object type encountered in CuDF Series") + else: + return pyarrow.from_numpy_dtype(dtype) + + else: + return dtype.to_arrow() + + +def handle_cudf(cudf_series: cudf.core.series.Series, generate_bitmasks): + column = cudf_series._data[cudf_series.name] + return recurse(column, arrow_type_of(column), generate_bitmasks) + + +def cudf_to_awkward( + cudf_series: cudf.core.series.Series, + generate_bitmasks=False, + highlevel=True, + behavior=None, + attrs=None, +): + ctx = ak._layout.HighLevelContext(behavior=behavior, attrs=attrs).finalize() + + out = handle_cudf(cudf_series, generate_bitmasks) + if isinstance(out, ak.contents.UnmaskedArray): + out = remove_optiontype(out) + + def remove_revertable(layout, **kwargs): + if hasattr(layout, "__pyarrow_original"): + del layout.__pyarrow_original + + ak._do.recursively_apply(out, remove_revertable) + + return ctx.wrap(out, highlevel=highlevel) + + +######################### testing + + +if __name__ == "__main__": + # tests numerics, lists, records, and option-type, but not union-type + examples = [ + [False, True, True], # booleans are special (1-bit) + [1.1, 2.2, 3.3], + [[False, True, True], [], [True, False]], + [[1, 2, 3], [], [4, 5]], + [[[1, 2], [3]], [], [[]], [[4], [], [5, 6, 7]], [[8, 9]]], + [{"x": 1}, {"x": 2}, {"x": 3}], + [{"x": 1.1, "y": []}, {"x": 2.2, "y": [1]}, {"x": 3.3, "y": [1, 2]}], + [[{"x": 1}, {"x": 2}, {"x": 3}], [], [{"x": 4}, {"x": 5}]], + ["This", "is", "a", "string", "array", ".", ""], + [["This", "is", "a"], ["nested"], ["string", "array", ".", ""]], + [None, None, None, None, None], + [False, True, None, True], + [1.1, 2.2, None, 3.3], + [[False, True, None, True], [], [True, False]], + [[False, True, True], None, [], [True, False]], + [[1, 2, None, 3], [], [4, 5]], + [[1, 2, 3], None, [], [4, 5]], + [[[1, 2, None], [3]], [], [[]], [[4], [], [5, 6, 7]], [[8, 9]]], + [[[1, 2], None, [3]], [], [[]], [[4], [], [5, 6, 7]], [[8, 9]]], + [[[1, 2], [3]], None, [], [[]], [[4], [], 
[5, 6, 7]], [[8, 9]]],
+        [{"x": 1}, {"x": None}, {"x": 3}],
+        [{"x": 1}, {"x": 2}, None, {"x": 3}],
+        [{"x": 1.1, "y": []}, {"x": None, "y": [1]}, {"x": 3.3, "y": [1, 2]}],
+        [{"x": 1.1, "y": []}, {"x": 2.2, "y": [1, None]}, {"x": 3.3, "y": [1, 2]}],
+        [{"x": 1.1, "y": []}, {"x": 2.2, "y": [1]}, None, {"x": 3.3, "y": [1, 2]}],
+        [[{"x": 1}, {"x": None}, {"x": 3}], [], [{"x": 4}, {"x": 5}]],
+        [[{"x": 1}, {"x": 2}, None, {"x": 3}], [], [{"x": 4}, {"x": 5}]],
+        [[{"x": 1}, {"x": 2}, {"x": 3}], None, [], [{"x": 4}, {"x": 5}]],
+        ["This", "is", "a", None, "string", "array", ".", ""],
+        [["This", "is", "a", None], ["nested"], ["string", "array", ".", ""]],
+        [["This", "is", "a"], None, ["nested"], ["string", "array", ".", ""]],
+        numpy.array(["2024-01-01", "2024-01-02"], dtype="datetime64[s]"),
+        numpy.array([1, 2, 3], dtype="timedelta64[s]"),
+    ]
+
+    for example in examples:
+        print(f"---- {example}")
+        df = cudf.DataFrame({"column": example})
+
+        awkward_array = cudf_to_awkward(df["column"])
+        assert ak.backend(awkward_array) == "cuda"
+        assert awkward_array.tolist() == list(example), awkward_array.show(type=True)
diff --git a/src/akimbo/cudf.py b/src/akimbo/cudf.py
new file mode 100644
index 0000000..abcbeef
--- /dev/null
+++ b/src/akimbo/cudf.py
@@ -0,0 +1,151 @@
+import functools
+from typing import Callable
+
+import awkward as ak
+import cudf
+from cudf import DataFrame, Series, _lib as libcudf
+from cudf.core.column.string import StringMethods
+from cudf.core.column.datetime import DatetimeColumn
+
+from akimbo.ak_from_cudf import cudf_to_awkward as from_cudf
+from akimbo.mixin import Accessor
+from akimbo.datetimes import DatetimeAccessor, match as match_t
+from akimbo.strings import StringAccessor
+from akimbo.apply_tree import dec, leaf
+
+
+def match_string(arr):
+    """Return True when the layout node carries the ``"string"`` array parameter."""
+    return arr.parameters.get("__array__", "") == "string"
+
+
+class CudfStringAccessor(StringAccessor):
+    """String operations on nested/var-length data"""
+
+    def decode(self, encoding: str = "utf-8"):
+        raise NotImplementedError("cudf does not support bytearray type, so we can't automatically identify them")
+
+    def encode(self, encoding: str = "utf-8"):
+        raise NotImplementedError("cudf does not support bytearray type")
+
+
+def dec_cu(op, match=match_string):
+    """Wrap a column->column cudf function so it runs on matching layout nodes.
+
+    ``op`` receives the node converted with ``_to_cudf`` plus any keyword
+    arguments; its result is wrapped in a ``cudf.Series`` and converted back
+    to an awkward layout. The wrapper is registered through ``dec`` with the
+    given ``match`` predicate.
+    """
+
+    @functools.wraps(op)
+    def f(lay, **kwargs):
+        # op(column, ...)->column
+        col = op(lay._to_cudf(cudf, None, len(lay)), **kwargs)
+        return from_cudf(cudf.Series(col)).layout
+
+    return dec(func=f, match=match, inmode="ak")
+
+
+# Expose every public attribute of cudf's StringMethods on CudfStringAccessor.
+for meth in dir(StringMethods):
+    if meth.startswith("_"):
+        continue
+
+    @functools.wraps(getattr(StringMethods, meth))
+    def f(lay, method=meth, **kwargs):
+        # this is different from dec_cu, because we need to instantiate StringMethods
+        # before getting the method from it
+        # (method=meth binds the current name; a plain closure would see the last one)
+        col = getattr(StringMethods(cudf.Series(lay._to_cudf(cudf, None, len(lay)))), method)(**kwargs)
+        return from_cudf(col).layout
+
+    setattr(CudfStringAccessor, meth, dec(func=f, match=match_string, inmode="ak"))
+
+
+class CudfDatetimeAccessor(DatetimeAccessor):
+    """Datetime operations on nested/var-length data, filled in from DatetimeColumn below"""
+
+
+# Expose the public attributes of cudf's DatetimeColumn on CudfDatetimeAccessor.
+for meth in dir(DatetimeColumn):
+    if meth.startswith("_") or meth == "strptime":
+        # strptime belongs in .str, not here!
+        continue
+
+    @functools.wraps(getattr(DatetimeColumn, meth))
+    def f(lay, method=meth, **kwargs):
+        # unlike dec_cu, the attribute is looked up on the converted column itself,
+        # because it may be either a method or a plain property
+        m = getattr(lay._to_cudf(cudf, None, len(lay)), method)
+        if callable(m):
+            col = m(**kwargs)
+        else:
+            # attributes giving components
+            col = m
+        return from_cudf(cudf.Series(col)).layout
+
+    if isinstance(getattr(DatetimeColumn, meth), property):
+        setattr(CudfDatetimeAccessor, meth, property(dec(func=f, match=match_t, inmode="ak")))
+    else:
+        setattr(CudfDatetimeAccessor, meth, dec(func=f, match=match_t, inmode="ak"))
+
+
+class CudfAwkwardAccessor(Accessor):
+    """Awkward-array accessor (``.ak``) for cudf Series"""
+
+    series_type = Series
+    dataframe_type = DataFrame
+
+    @classmethod
+    def _to_output(cls, arr):
+        """Convert an awkward Array or layout to cudf; return anything else unchanged."""
+        if isinstance(arr, ak.Array):
+            return ak.to_cudf(arr)
+        elif isinstance(arr, ak.contents.Content):
+            return arr._to_cudf(cudf, None, len(arr))
+        return arr
+
+    @classmethod
+    def to_array(cls, data) -> ak.Array:
+        """Convert cudf data to awkward via ``cudf_to_awkward``."""
+        return from_cudf(data)
+
+    @property
+    def array(self) -> ak.Array:
+        """Awkward view of the wrapped cudf object"""
+        return self.to_array(self._obj)
+
+    @property
+    def str(self):
+        """Nested string operations"""
+        # need to find string ops within cudf
+        return CudfStringAccessor(self)
+
+    cast = dec_cu(libcudf.unary.cast, match=leaf)
+
+    @property
+    def dt(self):
+        """Nested datetime operations"""
+        # need to find datetime ops within cudf
+        return CudfDatetimeAccessor(self)
+
+    def apply(self, fn: Callable, *args, **kwargs):
+        """Apply ``fn`` via the base accessor and return its result.
+
+        Raises NotImplementedError for functions whose string form contains
+        "CPUDispatcher" (GPU auto-wrapping is not implemented).
+        """
+        if "CPUDispatcher" in str(fn):
+            # auto wrap original function for GPU
+            raise NotImplementedError
+        return super().apply(fn, *args, **kwargs)
+
+
+@property  # type:ignore
+def ak_property(self):
+    """Accessor factory installed as ``Series.ak``."""
+    return CudfAwkwardAccessor(self)
+
+
+Series.ak = ak_property  # no official register function?
diff --git a/src/akimbo/io.py b/src/akimbo/io.py
index 24528d4..9c75230 100644
--- a/src/akimbo/io.py
+++ b/src/akimbo/io.py
@@ -19,9 +19,13 @@ def ak_to_series(ds, backend="pandas", extract=True):
 
         # TODO: actually don't use this, use dask-awkward, or dask.dataframe
         s = akimbo.polars.PolarsAwkwardAccessor._to_output(ds)
+    elif backend == "cudf":
+        import akimbo.cudf
+
+        s = akimbo.cudf.CudfAwkwardAccessor._to_output(ds)
     else:
-        raise ValueError("Backend must be in {'pandas', 'polars', 'dask'}")
+        raise ValueError("Backend must be in {'pandas', 'polars', 'dask', 'cudf'}")
-    if extract:
+    if extract and ds.fields:
         return s.ak.unmerge()
     return s
 
diff --git a/src/akimbo/strings.py b/src/akimbo/strings.py
index b9eb73b..6e7fcb2 100644
--- a/src/akimbo/strings.py
+++ b/src/akimbo/strings.py
@@ -59,6 +59,8 @@ def _decode(layout):
 
 
 class StringAccessor:
+    """String operations on nested/var-length data"""
+
     def __init__(self, accessor):
         self.accessor = accessor
 
diff --git a/tests/test_cudf.py b/tests/test_cudf.py
new file mode 100644
index 0000000..e573df4
--- /dev/null
+++ b/tests/test_cudf.py
@@ -0,0 +1,90 @@
+import datetime
+
+import pytest
+
+import pyarrow as pa
+import awkward as ak
+
+pytest.importorskip("akimbo.cudf")
+
+import akimbo.io
+import cudf
+
+
+def test_operator_overload():
+    """Arithmetic through ``.ak`` stays on the cuda backend and returns a cudf Series."""
+    s = pa.array([[1, 2, 3], [], [4, 5]], type=pa.list_(pa.int32()))
+    series = cudf.Series(s)
+    assert ak.backend(series.ak.array) == "cuda"
+    s2 = series.ak + 1
+    assert ak.backend(s2.ak.array) == "cuda"
+    assert isinstance(s2, cudf.Series)
+    assert s2.ak.to_list() == [[2, 3, 4], [], [5, 6]]
+
+
+def test_inner_slicing():
+    """Slicing inside each list via ``.ak[...]`` keeps a cuda-backed cudf Series."""
+    s = pa.array([[1, 2, 3], [0], [4, 5]], type=pa.list_(pa.int32()))
+    series = cudf.Series(s)
+    assert ak.backend(series.ak.array) == "cuda"
+    s2 = series.ak[:, 0]
+    assert ak.backend(s2.ak.array) == "cuda"
+    assert isinstance(s2, cudf.Series)
+    assert s2.ak.to_list() == [1, 0, 4]
+    s2 = series.ak[:, :2]
+    assert ak.backend(s2.ak.array) == "cuda"
+    assert isinstance(s2, cudf.Series)
+    assert s2.ak.to_list() == [[1, 2], [0], [4, 5]]
+    s2 = series.ak[:, ::2]
+    assert ak.backend(s2.ak.array) == "cuda"
+    assert isinstance(s2, cudf.Series)
+    assert s2.ak.to_list() == [[1, 3], [0], [4]]
+
+
+def test_string_methods():
+    """String methods act on the string leaves of a struct and leave the int field as is."""
+    s = pa.array([{"s": ["hey", "Ho"], "i": [0]}, {"s": ["Gar", "go"], "i": [2]}],
+                 type=pa.struct([("s", pa.list_(pa.string())), ("i", pa.list_(pa.int32()))]))
+    series = cudf.Series(s)
+    s2 = series.ak.str.upper()
+    assert s2.ak.to_list() == [{"s": ["HEY", "HO"], "i": [0]}, {"s": ["GAR", "GO"], "i": [2]}]
+
+    assert series.ak.str.upper.__doc__
+    # kwargs
+    s2 = series.ak.str.replace(pat="h", repl="B")
+    assert s2.ak.to_list() == [{"s": ["Bey", "Ho"], "i": [0]}, {"s": ["Gar", "go"], "i": [2]}]
+
+    # positional args
+    s2 = series.ak.str.replace("h", "B")
+    assert s2.ak.to_list() == [{"s": ["Bey", "Ho"], "i": [0]}, {"s": ["Gar", "go"], "i": [2]}]
+
+    # non-str output
+    s2 = series.ak.str.len()
+    assert s2.ak.to_list() == [{"s": [3, 2], "i": [0]}, {"s": [3, 2], "i": [2]}]
+
+
+def test_cast():
+    """Integers cast to timedelta and then to datetime (seconds) through ``.ak.cast``."""
+    s = cudf.Series([0, 1, 2])
+    # shows that cast to timestamp needs to be two-step in cudf
+    s2 = s.ak.cast('m8[s]').ak.cast('M8[s]')
+    out = s2.ak.to_list()
+    assert out == [
+        datetime.datetime(1970, 1, 1, 0, 0),
+        datetime.datetime(1970, 1, 1, 0, 0, 1),
+        datetime.datetime(1970, 1, 1, 0, 0, 2)
+    ]
+
+
+def test_times():
+    """``.ak.dt.second`` reads the seconds of nested datetimes and keeps None entries."""
+    data = [
+        datetime.datetime(1970, 1, 1, 0, 0),
+        datetime.datetime(1970, 1, 1, 0, 0, 1),
+        None,
+        datetime.datetime(1970, 1, 1, 0, 0, 2)
+    ]
+    arr = ak.Array([[data], [], [data]])
+    s = akimbo.io.ak_to_series(arr, "cudf")
+    s2 = s.ak.dt.second
+    assert s2.ak.to_list() == [[[0, 1, None, 2]], [], [[0, 1, None, 2]]]