1077 update dtypes to better support scalars (#1492)
* added updated dtypes logic, updated unit tests #1077

* fixed mypy error

* updates for mypy

* mypy updates

* mypy updates

* added tests for #1077

* fixed flake8 errors #1077

* updated formatting per PR review feedback #1077

* updated per PR feedback #1077

* fixed flake8 error #1077
hokiegeek2 committed Jun 13, 2022
1 parent ed79d1a commit f9a7fa2
Showing 8 changed files with 410 additions and 165 deletions.
68 changes: 59 additions & 9 deletions arkouda/dtypes.py
@@ -56,12 +56,34 @@
 # Union aliases used for static and runtime type checking
 bool_scalars = Union[builtins.bool, np.bool_]
 float_scalars = Union[float, np.float64]
-int_scalars = Union[int, np.int64, np.uint64]
-numeric_scalars = Union[float, np.float64, int, np.int64, np.uint8, np.uint64]
+int_scalars = Union[
+    int,
+    np.int8,
+    np.int16,
+    np.int32,
+    np.int64,
+    np.uint8,
+    np.uint16,
+    np.uint32,
+    np.uint64,
+]
+numeric_scalars = Union[float_scalars, int_scalars]
 numeric_and_bool_scalars = Union[bool_scalars, numeric_scalars]
-numpy_scalars = Union[np.float64, np.int64, np.bool_, np.uint8, np.str_, np.uint64]
+numpy_scalars = Union[
+    np.float64,
+    np.int8,
+    np.int16,
+    np.int32,
+    np.int64,
+    np.bool_,
+    np.str_,
+    np.uint8,
+    np.uint16,
+    np.uint32,
+    np.uint64,
+]
 str_scalars = Union[str, np.str_]
-all_scalars = Union[float, np.float64, int, np.int64, np.uint64, builtins.bool, np.bool_, str, np.str_]
+all_scalars = Union[bool_scalars, numeric_scalars, numpy_scalars, str_scalars]
 
 """
 The DType enum defines the supported Arkouda data types in string form.
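
Illustration (not part of the diff): with the widened unions, a parameter annotated with int_scalars now accepts any NumPy integer width, signed or unsigned. A minimal sketch, assuming numpy and this version of arkouda are importable:

import numpy as np
from arkouda.dtypes import int_scalars

def shift_left(x: int_scalars, by: int_scalars) -> int:
    # Any of int, np.int8..np.int64, np.uint8..np.uint64 now satisfies int_scalars
    return int(x) << int(by)

print(shift_left(np.int8(3), np.uint16(2)))  # 12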
@@ -94,10 +116,34 @@ def __repr__(self) -> str:  # type: ignore
         return self.value
 
 
-ARKOUDA_SUPPORTED_INTS = (int, np.int64, np.uint64)
+ARKOUDA_SUPPORTED_INTS = (
+    int,
+    np.int8,
+    np.int16,
+    np.int32,
+    np.int64,
+    np.uint8,
+    np.uint16,
+    np.uint32,
+    np.uint64,
+)
 ARKOUDA_SUPPORTED_FLOATS = (float, np.float64)
-ARKOUDA_SUPPORTED_NUMBERS = (int, np.int64, float, np.float64, np.uint64)
-ARKOUDA_SUPPORTED_DTYPES = frozenset([member.value for _, member in DType.__members__.items()])
+ARKOUDA_SUPPORTED_NUMBERS = (
+    int,
+    np.int8,
+    np.int16,
+    np.int32,
+    np.int64,
+    float,
+    np.float64,
+    np.uint8,
+    np.uint16,
+    np.uint32,
+    np.uint64,
+)
+ARKOUDA_SUPPORTED_DTYPES = frozenset(
+    [member.value for _, member in DType.__members__.items()]
+)
 
 DTypes = frozenset([member.value for _, member in DType.__members__.items()])
 DTypeObjects = frozenset([bool, float, float64, int, int64, str, str_, uint8, uint64])
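
Illustration (not part of the diff): the ARKOUDA_SUPPORTED_* tuples are the runtime mirrors of the unions above, so isinstance checks now accept the same widened set of scalar types:

import numpy as np
from arkouda.dtypes import ARKOUDA_SUPPORTED_INTS, ARKOUDA_SUPPORTED_NUMBERS

print(isinstance(np.int16(7), ARKOUDA_SUPPORTED_INTS))         # True (newly supported)
print(isinstance(np.float64(7.0), ARKOUDA_SUPPORTED_NUMBERS))  # True
print(isinstance("7", ARKOUDA_SUPPORTED_NUMBERS))              # False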
@@ -182,13 +228,17 @@ def resolve_scalar_dtype(val: object) -> str:  # type: ignore
     ):
         return "bool"
     # Python int or np.int* or np.uint*
-    elif isinstance(val, int) or (hasattr(val, "dtype") and cast(np.uint, val).dtype.kind in "ui"):
+    elif isinstance(val, int) or (
+        hasattr(val, "dtype") and cast(np.uint, val).dtype.kind in "ui"
+    ):
         if isinstance(val, np.uint64):
             return "uint64"
         else:
             return "int64"
     # Python float or np.float*
-    elif isinstance(val, float) or (hasattr(val, "dtype") and cast(np.float_, val).dtype.kind == "f"):
+    elif isinstance(val, float) or (
+        hasattr(val, "dtype") and cast(np.float_, val).dtype.kind == "f"
+    ):
         return "float64"
     elif isinstance(val, builtins.str) or isinstance(val, np.str_):
         return "str"
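
Illustration (not part of the diff): under the branch logic above, every signed NumPy width resolves to "int64", only np.uint64 resolves to "uint64" (the smaller unsigned widths fold into "int64"), and any float kind resolves to "float64":

import numpy as np
from arkouda.dtypes import resolve_scalar_dtype

print(resolve_scalar_dtype(np.int32(1)))    # 'int64'
print(resolve_scalar_dtype(np.uint64(1)))   # 'uint64'
print(resolve_scalar_dtype(np.uint8(1)))    # 'int64'
print(resolve_scalar_dtype(np.float64(1)))  # 'float64'
print(resolve_scalar_dtype(True))           # 'bool'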
100 changes: 76 additions & 24 deletions arkouda/groupbyclass.py
@@ -87,7 +87,9 @@ def unique(
     if hasattr(pda, "_get_grouping_keys"):
         # Single groupable array
         nkeys = 1
-        grouping_keys = cast(list, cast(groupable_element_type, pda)._get_grouping_keys())
+        grouping_keys = cast(
+            list, cast(groupable_element_type, pda)._get_grouping_keys()
+        )
     else:
         # Sequence of groupable arrays
         nkeys = len(pda)
@@ -108,7 +110,11 @@ def unique(
     repMsg = generic_msg(
         cmd="unique",
         args="{} {} {:n} {} {}".format(
-            return_groups, assume_sorted, effectiveKeys, " ".join(keynames), " ".join(keytypes)
+            return_groups,
+            assume_sorted,
+            effectiveKeys,
+            " ".join(keynames),
+            " ".join(keytypes),
         ),
     )
     if return_groups:
@@ -217,7 +223,11 @@ class GroupBy:
     Reductions = GROUPBY_REDUCTION_TYPES
 
     def __init__(
-        self, keys: Optional[groupable], assume_sorted: bool = False, hash_strings: bool = True, **kwargs
+        self,
+        keys: Optional[groupable],
+        assume_sorted: bool = False,
+        hash_strings: bool = True,
+        **kwargs,
     ) -> None:
         # Type Checks required because @typechecked was removed for causing other issues
         # This prevents non-bool values that can be evaluated to true (ie non-empty arrays)
@@ -340,27 +350,36 @@ def aggregate(
 
         operator = operator.lower()
         if operator not in self.Reductions:
-            raise ValueError(f"Unsupported reduction: {operator}\nMust be one of {self.Reductions}")
+            raise ValueError(
+                f"Unsupported reduction: {operator}\nMust be one of {self.Reductions}"
+            )
 
         # TO DO: remove once logic is ported over to Chapel
         if operator == "nunique":
             return self.nunique(values)
 
         # All other aggregations operate on pdarray
         if cast(pdarray, values).size != self.size:
-            raise ValueError("Attempt to group array using key array of different length")
+            raise ValueError(
+                "Attempt to group array using key array of different length"
+            )
 
         if self.assume_sorted:
             permuted_values = cast(pdarray, values)
         else:
             permuted_values = cast(pdarray, values)[cast(pdarray, self.permutation)]
 
         cmd = "segmentedReduction"
-        args = "{} {} {} {}".format(permuted_values.name, self.segments.name, operator, skipna)
+        args = "{} {} {} {}".format(
+            permuted_values.name, self.segments.name, operator, skipna
+        )
         repMsg = generic_msg(cmd=cmd, args=args)
         self.logger.debug(repMsg)
         if operator.startswith("arg"):
-            return (self.unique_keys, cast(pdarray, self.permutation[create_pdarray(repMsg)]))
+            return (
+                self.unique_keys,
+                cast(pdarray, self.permutation[create_pdarray(repMsg)]),
+            )
         else:
             return self.unique_keys, create_pdarray(repMsg)
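
Usage sketch (not part of the diff), assuming a running arkouda server: aggregate validates the operator and the value-array length, then issues a segmentedReduction to the server; nunique is still intercepted client-side, per the TO DO above.

import arkouda as ak

ak.connect()  # assumes arkouda_server is running
keys = ak.array([2, 3, 2, 3, 4])
vals = ak.array([1, 4, 2, 1, 3])
g = ak.GroupBy(keys)
unique_keys, sums = g.aggregate(vals, "sum")
print(unique_keys, sums)  # groups 2, 3, 4 -> sums 3, 5, 3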

@@ -547,7 +566,9 @@ def min(self, values: pdarray, skipna: bool = True) -> Tuple[groupable, pdarray]:
         (array([2, 3, 4]), array([1, 1, 3]))
         """
         if values.dtype == bool:
-            raise TypeError("min is only supported for pdarrays of dtype float64, uint64, and int64")
+            raise TypeError(
+                "min is only supported for pdarrays of dtype float64, uint64, and int64"
+            )
         return self.aggregate(values, "min", skipna)
 
     def max(self, values: pdarray, skipna: bool = True) -> Tuple[groupable, pdarray]:
@@ -594,7 +615,9 @@ def max(self, values: pdarray, skipna: bool = True) -> Tuple[groupable, pdarray]:
         (array([2, 3, 4]), array([4, 4, 3]))
         """
         if values.dtype == bool:
-            raise TypeError("max is only supported for pdarrays of dtype float64, uint64, and int64")
+            raise TypeError(
+                "max is only supported for pdarrays of dtype float64, uint64, and int64"
+            )
         return self.aggregate(values, "max", skipna)
 
     def argmin(self, values: pdarray) -> Tuple[groupable, pdarray]:
@@ -751,7 +774,10 @@ def nunique(self, values: groupable) -> Tuple[groupable, pdarray]:
         # or Categorical (the last two have a .group() method).
         # Can't directly test Categorical due to circular import.
         if isinstance(values, pdarray):
-            if cast(pdarray, values).dtype != akint64 and cast(pdarray, values).dtype != akuint64:
+            if (
+                cast(pdarray, values).dtype != akint64
+                and cast(pdarray, values).dtype != akuint64
+            ):
                 raise TypeError("nunique unsupported for this dtype")
             togroup = [ukidx, values]
         elif hasattr(values, "group"):
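
Usage sketch (not part of the diff): per the guard above, nunique accepts int64 or uint64 pdarray values and returns the per-group count of distinct values; a float64 values array raises TypeError.

import arkouda as ak

ak.connect()  # assumes arkouda_server is running
keys = ak.array([0, 0, 1, 1, 1])
vals = ak.array([9, 9, 5, 6, 5])
unique_keys, counts = ak.GroupBy(keys).nunique(vals)
print(unique_keys, counts)  # group 0 has 1 distinct value, group 1 has 2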
@@ -774,7 +800,9 @@ def nunique(self, values: groupable) -> Tuple[groupable, pdarray]:
         # Re-join unique counts with original keys (sorting guarantees same order)
         return self.unique_keys, nuniq
 
-    def any(self, values: pdarray) -> Tuple[Union[pdarray, List[Union[pdarray, Strings]]], pdarray]:
+    def any(
+        self, values: pdarray
+    ) -> Tuple[Union[pdarray, List[Union[pdarray, Strings]]], pdarray]:
         """
         Using the permutation stored in the GroupBy instance, group another
         array of values and perform an "or" reduction on each group.
@@ -804,7 +832,9 @@ def any(self, values: pdarray) -> Tuple[Union[pdarray, List[Union[pdarray, Strings]]], pdarray]:
             raise TypeError("any is only supported for pdarrays of dtype bool")
         return self.aggregate(values, "any")  # type: ignore
 
-    def all(self, values: pdarray) -> Tuple[Union[pdarray, List[Union[pdarray, Strings]]], pdarray]:
+    def all(
+        self, values: pdarray
+    ) -> Tuple[Union[pdarray, List[Union[pdarray, Strings]]], pdarray]:
         """
         Using the permutation stored in the GroupBy instance, group
         another array of values and perform an "and" reduction on
@@ -838,7 +868,9 @@ def all(self, values: pdarray) -> Tuple[Union[pdarray, List[Union[pdarray, Strings]]], pdarray]:
 
         return self.aggregate(values, "all")  # type: ignore
 
-    def OR(self, values: pdarray) -> Tuple[Union[pdarray, List[Union[pdarray, Strings]]], pdarray]:
+    def OR(
+        self, values: pdarray
+    ) -> Tuple[Union[pdarray, List[Union[pdarray, Strings]]], pdarray]:
         """
         Bitwise OR of values in each segment.
@@ -870,11 +902,15 @@ def OR(self, values: pdarray) -> Tuple[Union[pdarray, List[Union[pdarray, Strings]]], pdarray]:
             Raised if all is not supported for the values dtype
         """
         if values.dtype != akint64 and values.dtype != akuint64:
-            raise TypeError("OR is only supported for pdarrays of dtype int64 or uint64")
+            raise TypeError(
+                "OR is only supported for pdarrays of dtype int64 or uint64"
+            )
 
         return self.aggregate(values, "or")  # type: ignore
 
-    def AND(self, values: pdarray) -> Tuple[Union[pdarray, List[Union[pdarray, Strings]]], pdarray]:
+    def AND(
+        self, values: pdarray
+    ) -> Tuple[Union[pdarray, List[Union[pdarray, Strings]]], pdarray]:
         """
         Bitwise AND of values in each segment.
@@ -906,11 +942,15 @@ def AND(self, values: pdarray) -> Tuple[Union[pdarray, List[Union[pdarray, Strings]]], pdarray]:
             Raised if all is not supported for the values dtype
         """
         if values.dtype != akint64 and values.dtype != akuint64:
-            raise TypeError("AND is only supported for pdarrays of dtype int64 or uint64")
+            raise TypeError(
+                "AND is only supported for pdarrays of dtype int64 or uint64"
+            )
 
         return self.aggregate(values, "and")  # type: ignore
 
-    def XOR(self, values: pdarray) -> Tuple[Union[pdarray, List[Union[pdarray, Strings]]], pdarray]:
+    def XOR(
+        self, values: pdarray
+    ) -> Tuple[Union[pdarray, List[Union[pdarray, Strings]]], pdarray]:
         """
         Bitwise XOR of values in each segment.
@@ -942,7 +982,9 @@ def XOR(self, values: pdarray) -> Tuple[Union[pdarray, List[Union[pdarray, Strings]]], pdarray]:
             Raised if all is not supported for the values dtype
         """
         if values.dtype != akint64 and values.dtype != akuint64:
-            raise TypeError("XOR is only supported for pdarrays of dtype int64 or uint64")
+            raise TypeError(
+                "XOR is only supported for pdarrays of dtype int64 or uint64"
+            )
 
         return self.aggregate(values, "xor")  # type: ignore
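
Usage sketch (not part of the diff): OR, AND, and XOR share the same int64/uint64 guard and all route through aggregate:

import arkouda as ak

ak.connect()  # assumes arkouda_server is running
keys = ak.array([0, 0, 1, 1])
vals = ak.array([1, 2, 3, 5])
g = ak.GroupBy(keys)
print(g.OR(vals)[1])   # 1|2=3,  3|5=7
print(g.AND(vals)[1])  # 1&2=0,  3&5=1
print(g.XOR(vals)[1])  # 1^2=3,  3^5=6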

@@ -1053,7 +1095,9 @@ def build_from_components(user_defined_name: str = None, **kwargs) -> GroupBy:
             if "segments" not in kwargs:
                 missingKeys.append("segments")
 
-            raise ValueError(f"Can't build GroupBy. kwargs is missing required keys: {missingKeys}.")
+            raise ValueError(
+                f"Can't build GroupBy. kwargs is missing required keys: {missingKeys}."
+            )
 
     def _get_groupby_required_pieces(self) -> Dict:
         """
@@ -1114,13 +1158,17 @@ def register(self, user_defined_name: str) -> GroupBy:
 
         if isinstance(self.keys, (Strings, pdarray, Categorical)):
             self.keys.register(f"{user_defined_name}_{self.keys.objtype}.keys")
-            self.unique_keys.register(f"{user_defined_name}_{self.keys.objtype}.unique_keys")
+            self.unique_keys.register(
+                f"{user_defined_name}_{self.keys.objtype}.unique_keys"
+            )
         elif isinstance(self.keys, Sequence):
             for x in range(len(self.keys)):
                 # Possible for multiple types in a sequence, so we have to check each key's
                 # type individually
                 if isinstance(self.keys[x], (Strings, pdarray, Categorical)):
-                    self.keys[x].register(f"{x}_{user_defined_name}_{self.keys[x].objtype}.keys")
+                    self.keys[x].register(
+                        f"{x}_{user_defined_name}_{self.keys[x].objtype}.keys"
+                    )
                     self.unique_keys[x].register(
                         f"{x}_{user_defined_name}_{self.keys[x].objtype}.unique_keys"
                     )
@@ -1211,7 +1259,9 @@ def is_registered(self) -> bool:
             f"^\\d+_{self.name}_.+\\.keys$|^\\d+_{self.name}_.+\\.unique_keys$|"
             f"^\\d+_{self.name}_.+\\.unique_keys(?=\\.categories$)"
         )
-        cat_regEx = compile(f"^\\d+_{self.name}_{Categorical.objtype}\\.keys(?=\\.codes$)")
+        cat_regEx = compile(
+            f"^\\d+_{self.name}_{Categorical.objtype}\\.keys(?=\\.codes$)"
+        )
 
         simple_registered = list(filter(regEx.match, registry))
         cat_registered = list(filter(cat_regEx.match, registry))
@@ -1296,7 +1346,9 @@ def attach(user_defined_name: str) -> GroupBy:
         matches.sort()
 
         if len(matches) == 0:
-            raise RegistrationError(f"No registered elements with name '{user_defined_name}'")
+            raise RegistrationError(
+                f"No registered elements with name '{user_defined_name}'"
+            )
 
         for name in matches:
             # Parse the name for the dtype and use the proper create method to create the element
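
Usage sketch (not part of the diff): register persists the GroupBy components under prefixed names built from the f-strings above (e.g. "my_gb_pdarray.keys" for a single pdarray key), and the static attach rebuilds the object from the registry. The name "my_gb" is just an example.

import arkouda as ak

ak.connect()  # assumes arkouda_server is running
g = ak.GroupBy(ak.array([1, 1, 2]))
g.register("my_gb")

g2 = ak.GroupBy.attach("my_gb")
print(g2.unique_keys)  # array([1 2])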
@@ -1513,7 +1565,7 @@ def broadcast(
     else:
         pname = permutation.name
         permute = True
-        size = permutation.size
+        size = cast(Union[int, np.int64, np.uint64], permutation.size)
     if size < 1:
         raise ValueError("result size must be greater than zero")
     cmd = "broadcast"
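
Usage sketch (not part of the diff), assuming this module-level broadcast is exported as ak.broadcast: it expands one value per segment back out to an array of the given size (which, per the changed line, is taken from permutation.size when a permutation is supplied).

import arkouda as ak

ak.connect()  # assumes arkouda_server is running
segments = ak.array([0, 3])  # segment starts: [0, 3) and [3, 5)
values = ak.array([10, 20])  # one value per segment
print(ak.broadcast(segments, values, size=5))  # 10 10 10 20 20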