Skip to content

Commit

Permalink
Add vstack implementation (Bears-R-Us#3305)
Browse files Browse the repository at this point in the history
* add vstack implementation

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* fix bad numpy import

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* fix mypy type error

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* move vstack test to PROTO_tests. Fix bug in vstack's use of np.common_type

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* fix mypy error

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* fix undefined function in base_test

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

---------

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>
  • Loading branch information
jeremiah-corrado committed Jun 10, 2024
1 parent afc6cdf commit c7ffa24
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 0 deletions.
24 changes: 24 additions & 0 deletions PROTO_tests/tests/array_manipulation_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import arkouda as ak
import numpy as np

import pytest
import json


def get_server_max_array_dims():
try:
return json.load(open('serverConfig.json', 'r'))['max_array_dims']
except (ValueError, FileNotFoundError, TypeError, KeyError):
return 1


class TestManipulationFunctions:
@pytest.mark.skipif(get_server_max_array_dims() < 2, reason="vstack requires server with 'max_array_dims' >= 2")
def test_vstack(self):
a = [ak.random.randint(0, 10, 25) for _ in range(4)]
n = [x.to_ndarray() for x in a]

n_vstack = np.vstack(n)
a_vstack = ak.vstack(a)

assert n_vstack.tolist() == a_vstack.to_list()
1 change: 1 addition & 0 deletions ServerModules.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ IndexingMsg
JoinEqWithDTMsg
KExtremeMsg
LogMsg
# ManipulationMsg
OperatorMsg
ParquetMsg
RandMsg
Expand Down
1 change: 1 addition & 0 deletions arkouda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from arkouda.sorting import *
from arkouda.pdarraysetops import *
from arkouda.pdarraycreation import *
from arkouda.pdarraymanipulation import *
from arkouda.numeric import *
from arkouda.groupbyclass import *
from arkouda.strings import *
Expand Down
73 changes: 73 additions & 0 deletions arkouda/pdarraymanipulation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from typing import Tuple, List, Literal, Union, Optional
from typeguard import typechecked

from arkouda.client import generic_msg
from arkouda.pdarrayclass import pdarray, create_pdarray
from arkouda.dtypes import dtype as akdtype

import numpy as np

__all__ = ["vstack"]


@typechecked
def vstack(
tup: Union[Tuple[pdarray], List[pdarray]],
*,
dtype: Optional[Union[type, str]] = None,
casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind",
) -> pdarray:
"""
Stack a sequence of arrays vertically (row-wise).
This is equivalent to concatenation along the first axis after 1-D arrays of
shape `(N,)` have been reshaped to `(1,N)`.
Parameters
----------
tup : Tuple[pdarray]
The arrays to be stacked
dtype : Optional[Union[type, str]], optional
The data-type of the output array. If not provided, the output
array will be determined using `np.common_type` on the
input arrays Defaults to None
casting : {"no", "equiv", "safe", "same_kind", "unsafe"], optional
Controls what kind of data casting may occur - currently unused
Returns
-------
pdarray
The stacked array
"""

if casting != "same_kind":
# TODO: wasn't clear from the docs what each of the casting options does
raise NotImplementedError(f"casting={casting} is not yet supported")

# ensure all arrays have the same number of dimensions
ndim = tup[0].ndim
for a in tup:
if a.ndim != ndim:
raise ValueError("all input arrays must have the same number of dimensions")

# establish the dtype of the output array
if dtype is None:
dtype_ = np.common_type(*[np.empty(0, dtype=a.dtype) for a in tup])
else:
dtype_ = akdtype(dtype)

# cast the input arrays to the output dtype if necessary
arrays = [a.astype(dtype_) if a.dtype != dtype_ else a for a in tup]

# stack the arrays along the first axis
return create_pdarray(
generic_msg(
cmd=f"stack{ndim}D",
args={
"names": list(arrays),
"n": len(arrays),
"axis": 0,
},
)
)
1 change: 1 addition & 0 deletions tests/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
stop_arkouda_server,
)


"""
ArkoudaTest defines the base Arkouda test logic for starting up the arkouda_server at the
launch of a unittest TestCase and shutting down the arkouda_server at the completion of
Expand Down

0 comments on commit c7ffa24

Please sign in to comment.