diff --git a/arkouda/dataframe.py b/arkouda/dataframe.py
index bdeed1b7a..a6deade21 100644
--- a/arkouda/dataframe.py
+++ b/arkouda/dataframe.py
@@ -4,6 +4,7 @@
 import os
 import random
 from collections import UserDict
+from functools import reduce
 from typing import Callable, Dict, List, Optional, Tuple, Union, cast
 from warnings import warn
 
@@ -24,10 +25,11 @@
 from arkouda.join import inner_join
 from arkouda.numpy import cast as akcast
 from arkouda.numpy import cumsum, where
-from arkouda.numpy.dtypes import bigint
+from arkouda.numpy.dtypes import _is_dtype_in_union, bigint
 from arkouda.numpy.dtypes import bool_ as akbool
 from arkouda.numpy.dtypes import float64 as akfloat64
 from arkouda.numpy.dtypes import int64 as akint64
+from arkouda.numpy.dtypes import numeric_scalars
 from arkouda.numpy.dtypes import uint64 as akuint64
 from arkouda.pdarrayclass import RegistrationError, pdarray
 from arkouda.pdarraycreation import arange, array, create_pdarray, full, zeros
@@ -105,6 +107,7 @@ class DataFrameGroupBy:
     """
 
     def __init__(self, gb, df, gb_key_names=None, as_index=True):
+
         self.gb = gb
         self.df = df
         self.gb_key_names = gb_key_names
@@ -112,6 +115,39 @@ def __init__(self, gb, df, gb_key_names=None, as_index=True):
         for attr in ["nkeys", "permutation", "unique_keys", "segments"]:
             setattr(self, attr, getattr(gb, attr))
 
+        self.dropna = self.gb.dropna
+        self.where_not_nan = None
+        self.all_non_nan = False
+
+        if self.dropna:
+            from arkouda import all as ak_all
+            from arkouda import isnan
+
+            # compute ~isnan on each key column, then AND the masks together;
+            # track whether all keys are non-NaN, so we can skip indexing later
+            key_cols = (
+                [df[k] for k in gb_key_names] if isinstance(gb_key_names, List) else [df[gb_key_names]]
+            )
+            where_key_not_nan = [
+                ~isnan(col)
+                for col in key_cols
+                if isinstance(col, pdarray) and _is_dtype_in_union(col.dtype, numeric_scalars)
+            ]
+
+            if len(where_key_not_nan) == 0:
+                # if empty, no key is a numeric pdarray, so no key can be NaN
+                self.all_non_nan = True
+            else:
+                self.where_not_nan = reduce(lambda x, y: x & y, where_key_not_nan)
+                self.all_non_nan = ak_all(self.where_not_nan)
+
+    def _get_df_col(self, c):
+        # helper to mask out values whose keys are NaN when dropna is True
+        if not self.dropna or self.all_non_nan:
+            return self.df.data[c]
+        else:
+            return self.df.data[c][self.where_not_nan]
+
     @classmethod
     def _make_aggop(cls, opname):
         numerical_dtypes = [akfloat64, akint64, akuint64]
@@ -148,18 +184,18 @@ def aggop(self, colnames=None):
             if isinstance(colnames, List):
                 if isinstance(self.gb_key_names, str):
                     return DataFrame(
-                        {c: self.gb.aggregate(self.df.data[c], opname)[1] for c in colnames},
+                        {c: self.gb.aggregate(self._get_df_col(c), opname)[1] for c in colnames},
                         index=Index(self.gb.unique_keys, name=self.gb_key_names),
                     )
                 elif isinstance(self.gb_key_names, list) and len(self.gb_key_names) == 1:
                     return DataFrame(
-                        {c: self.gb.aggregate(self.df.data[c], opname)[1] for c in colnames},
+                        {c: self.gb.aggregate(self._get_df_col(c), opname)[1] for c in colnames},
                         index=Index(self.gb.unique_keys, name=self.gb_key_names[0]),
                     )
                 elif isinstance(self.gb_key_names, list):
                     column_dict = dict(zip(self.gb_key_names, self.unique_keys))
                     for c in colnames:
-                        column_dict[c] = self.gb.aggregate(self.df.data[c], opname)[1]
+                        column_dict[c] = self.gb.aggregate(self._get_df_col(c), opname)[1]
                     return DataFrame(column_dict)
                 else:
                     return None
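For reference, the dropna handling added to DataFrameGroupBy.__init__ and _get_df_col reduces to one small masking pattern: build a ~isnan mask per numeric key column, AND the masks together with functools.reduce, and index value columns through the combined mask before aggregating. Below is a minimal NumPy sketch of the same idea, illustrative only: the column names and data are invented, and np.isnan stands in for arkouda's isnan on pdarrays.

    from functools import reduce

    import numpy as np

    keys = {
        "key1": np.array([1.0, 2.0, np.nan, 2.0]),
        "key2": np.array([10.0, np.nan, 30.0, 40.0]),
    }
    values = np.array([3.0, 4.0, 5.0, 6.0])

    # ~isnan per key column, ANDed together, mirroring __init__ above
    where_key_not_nan = [~np.isnan(col) for col in keys.values()]
    where_not_nan = reduce(lambda x, y: x & y, where_key_not_nan)

    # rows where any key is NaN are dropped before aggregation,
    # mirroring _get_df_col
    print(values[where_not_nan])  # [3. 6.]

The all_non_nan short-circuit in the patch simply skips this indexing step when the combined mask is all True, avoiding a needless gather.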
diff --git a/tests/dataframe_test.py b/tests/dataframe_test.py
index 693d91f01..fd6c1f6b9 100644
--- a/tests/dataframe_test.py
+++ b/tests/dataframe_test.py
@@ -1,7 +1,5 @@
-import glob
 import itertools
 import os
-import tempfile
 
 import numpy as np
 import pandas as pd
@@ -652,16 +650,47 @@ def test_gb_aggregations_example_numeric_types(self, agg):
         pd_result = getattr(pd_df.groupby(group_on), agg)()
         assert_frame_equal(ak_result.to_pandas(retain_index=True), pd_result)
 
+    @pytest.mark.parametrize("dropna", [True, False])
     @pytest.mark.parametrize("agg", ["count", "max", "mean", "median", "min", "std", "sum", "var"])
-    def test_gb_aggregations_with_nans(self, agg):
+    def test_gb_aggregations_with_nans(self, agg, dropna):
         df = self.build_ak_df_with_nans()
         # @TODO handle bool columns correctly
         df.drop("bools", axis=1, inplace=True)
         pd_df = df.to_pandas()
 
         group_on = ["key1", "key2"]
-        ak_result = getattr(df.groupby(group_on), agg)()
-        pd_result = getattr(pd_df.groupby(group_on, as_index=False), agg)()
+        ak_result = getattr(df.groupby(group_on, dropna=dropna), agg)()
+        pd_result = getattr(pd_df.groupby(group_on, as_index=False, dropna=dropna), agg)()
+        assert_frame_equal(ak_result.to_pandas(retain_index=True), pd_result)
+
+        # TODO: aggregations of string columns are not currently supported (even for count)
+        df.drop("key1", axis=1, inplace=True)
+        df.drop("key2", axis=1, inplace=True)
+        pd_df = df.to_pandas()
+
+        group_on = ["nums1", "nums2"]
+        ak_result = getattr(df.groupby(group_on, dropna=dropna), agg)()
+        pd_result = getattr(pd_df.groupby(group_on, as_index=False, dropna=dropna), agg)()
+        assert_frame_equal(ak_result.to_pandas(retain_index=True), pd_result)
+
+        # TODO: aggregations mishandle NaN; see issue #3765
+        df.drop("nums2", axis=1, inplace=True)
+        pd_df = df.to_pandas()
+        group_on = "nums1"
+        ak_result = getattr(df.groupby(group_on, dropna=dropna), agg)()
+        pd_result = getattr(pd_df.groupby(group_on, dropna=dropna), agg)()
+        assert_frame_equal(ak_result.to_pandas(retain_index=True), pd_result)
+
+    @pytest.mark.parametrize("dropna", [True, False])
+    def test_count_nan_bug(self, dropna):
+        # verify the reproducer for #3762 is fixed
+        df = ak.DataFrame({"A": [1, 2, 2, np.nan], "B": [3, 4, 5, 6], "C": [1, np.nan, 2, 3]})
+        ak_result = df.groupby("A", dropna=dropna).count()
+        pd_result = df.to_pandas().groupby("A", dropna=dropna).count()
+        assert_frame_equal(ak_result.to_pandas(retain_index=True), pd_result)
+
+        ak_result = df.groupby(["A", "C"], as_index=False, dropna=dropna).count()
+        pd_result = df.to_pandas().groupby(["A", "C"], as_index=False, dropna=dropna).count()
         assert_frame_equal(ak_result.to_pandas(retain_index=True), pd_result)
 
     def test_gb_aggregations_return_dataframe(self):
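With the patch applied, the reproducer from #3762 agrees with pandas for both dropna settings, as the new test_count_nan_bug asserts. A usage sketch of the fixed behavior (assumes a reachable arkouda server; the data mirrors the test above):

    import numpy as np

    import arkouda as ak

    ak.connect()  # assumes an arkouda server is running locally

    df = ak.DataFrame({"A": [1, 2, 2, np.nan], "B": [3, 4, 5, 6]})

    # rows whose key is NaN are excluded from the per-group counts
    print(df.groupby("A", dropna=True).count())

    # NaN keys form their own group, matching pandas semantics
    print(df.groupby("A", dropna=False).count())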