forked from Bears-R-Us/arkouda
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add vstack implementation (Bears-R-Us#3305)
* add vstack implementation Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com> * fix bad numpy import Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com> * fix mypy type error Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com> * move vstack test to PROTO_tests. Fix bug in vstack's use of np.common_type Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com> * fix mypy error Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com> * fix undefined function in base_test Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com> --------- Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>
- Loading branch information
1 parent
afc6cdf
commit c7ffa24
Showing
5 changed files
with
100 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import arkouda as ak | ||
import numpy as np | ||
|
||
import pytest | ||
import json | ||
|
||
|
||
def get_server_max_array_dims(): | ||
try: | ||
return json.load(open('serverConfig.json', 'r'))['max_array_dims'] | ||
except (ValueError, FileNotFoundError, TypeError, KeyError): | ||
return 1 | ||
|
||
|
||
class TestManipulationFunctions: | ||
@pytest.mark.skipif(get_server_max_array_dims() < 2, reason="vstack requires server with 'max_array_dims' >= 2") | ||
def test_vstack(self): | ||
a = [ak.random.randint(0, 10, 25) for _ in range(4)] | ||
n = [x.to_ndarray() for x in a] | ||
|
||
n_vstack = np.vstack(n) | ||
a_vstack = ak.vstack(a) | ||
|
||
assert n_vstack.tolist() == a_vstack.to_list() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,7 @@ IndexingMsg | |
JoinEqWithDTMsg | ||
KExtremeMsg | ||
LogMsg | ||
# ManipulationMsg | ||
OperatorMsg | ||
ParquetMsg | ||
RandMsg | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
from typing import Tuple, List, Literal, Union, Optional | ||
from typeguard import typechecked | ||
|
||
from arkouda.client import generic_msg | ||
from arkouda.pdarrayclass import pdarray, create_pdarray | ||
from arkouda.dtypes import dtype as akdtype | ||
|
||
import numpy as np | ||
|
||
__all__ = ["vstack"] | ||
|
||
|
||
@typechecked | ||
def vstack( | ||
tup: Union[Tuple[pdarray], List[pdarray]], | ||
*, | ||
dtype: Optional[Union[type, str]] = None, | ||
casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", | ||
) -> pdarray: | ||
""" | ||
Stack a sequence of arrays vertically (row-wise). | ||
This is equivalent to concatenation along the first axis after 1-D arrays of | ||
shape `(N,)` have been reshaped to `(1,N)`. | ||
Parameters | ||
---------- | ||
tup : Tuple[pdarray] | ||
The arrays to be stacked | ||
dtype : Optional[Union[type, str]], optional | ||
The data-type of the output array. If not provided, the output | ||
array will be determined using `np.common_type` on the | ||
input arrays Defaults to None | ||
casting : {"no", "equiv", "safe", "same_kind", "unsafe"], optional | ||
Controls what kind of data casting may occur - currently unused | ||
Returns | ||
------- | ||
pdarray | ||
The stacked array | ||
""" | ||
|
||
if casting != "same_kind": | ||
# TODO: wasn't clear from the docs what each of the casting options does | ||
raise NotImplementedError(f"casting={casting} is not yet supported") | ||
|
||
# ensure all arrays have the same number of dimensions | ||
ndim = tup[0].ndim | ||
for a in tup: | ||
if a.ndim != ndim: | ||
raise ValueError("all input arrays must have the same number of dimensions") | ||
|
||
# establish the dtype of the output array | ||
if dtype is None: | ||
dtype_ = np.common_type(*[np.empty(0, dtype=a.dtype) for a in tup]) | ||
else: | ||
dtype_ = akdtype(dtype) | ||
|
||
# cast the input arrays to the output dtype if necessary | ||
arrays = [a.astype(dtype_) if a.dtype != dtype_ else a for a in tup] | ||
|
||
# stack the arrays along the first axis | ||
return create_pdarray( | ||
generic_msg( | ||
cmd=f"stack{ndim}D", | ||
args={ | ||
"names": list(arrays), | ||
"n": len(arrays), | ||
"axis": 0, | ||
}, | ||
) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters