Skip to content

Commit

Permalink
update margins in pivot_table
Browse files Browse the repository at this point in the history
  • Loading branch information
mahaalbashir committed Oct 4, 2023
1 parent 42d708c commit 805036a
Show file tree
Hide file tree
Showing 3 changed files with 874 additions and 1,990 deletions.
166 changes: 107 additions & 59 deletions acro/acro_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,17 +171,18 @@ def crosstab( # pylint: disable=too-many-arguments,too-many-locals
)
else:
table = crosstab_with_totals(
masks,
aggfunc,
index,
columns,
values,
rownames,
colnames,
margins,
margins_name,
dropna,
normalize,
masks=masks,
aggfunc=agg_func,
index=index,
columns=columns,
values=values,
margins=margins,
margins_name=margins_name,
dropna=dropna,
crosstab=True,
rownames=rownames,
colnames=colnames,
normalize=normalize,
)

# record output
Expand Down Expand Up @@ -210,6 +211,7 @@ def pivot_table( # pylint: disable=too-many-arguments,too-many-locals
margins_name: str = "All",
observed: bool = False,
sort: bool = True,
show_suppressed: bool = False,
) -> DataFrame:
"""Create a spreadsheet-style pivot table as a DataFrame.
Expand Down Expand Up @@ -253,6 +255,8 @@ def pivot_table( # pylint: disable=too-many-arguments,too-many-locals
all values for categorical groupers.
sort : bool, default True
Specifies if the result should be sorted.
show_suppressed : bool, default False
Determines how the totals are calculated when suppression is true.
Returns
-------
Expand Down Expand Up @@ -330,6 +334,26 @@ def pivot_table( # pylint: disable=too-many-arguments,too-many-locals
safe_table, outcome = apply_suppression(table, masks)
if self.suppress:
table = safe_table
if margins:
logger.info(
"Disclosive cells were deleted from the dataframe "
"before calculating the pivot table"
)
table = crosstab_with_totals(
masks=masks,
aggfunc=aggfunc,
index=index,
columns=columns,
values=values,
margins=margins,
margins_name=margins_name,
dropna=dropna,
crosstab=False,
data=data,
fill_value=fill_value,
observed=observed,
sort=sort,
)
# record output
self.results.add(
status=status,
Expand Down Expand Up @@ -940,7 +964,8 @@ def get_queries(masks, aggfunc) -> list[str]:
for _, mask in masks.items():
# drop the name of the mask
if aggfunc is not None:
mask = mask.droplevel(0, axis=1)
if mask.columns.nlevels > 1:
mask = mask.droplevel(0, axis=1)
# identify level names for rows and columns
index_level_names = mask.index.names
column_level_names = mask.columns.names
Expand Down Expand Up @@ -1059,12 +1084,17 @@ def crosstab_with_totals( # pylint: disable=too-many-arguments,too-many-locals
index,
columns,
values,
rownames,
colnames,
margins,
margins_name,
dropna,
normalize,
crosstab,
rownames=None,
colnames=None,
normalize=False,
data=None,
fill_value=None,
observed=False,
sort=False,
) -> DataFrame:
"""Recalculate the crosstab table when margins are true and suppression is true.
Expand All @@ -1080,60 +1110,76 @@ def crosstab_with_totals( # pylint: disable=too-many-arguments,too-many-locals
Values to group by in the columns.
index : array-like, Series, or list of arrays/Series
Values to group by in the rows.
columns : array-like, Series, or list of arrays/Series
Values to group by in the columns.
values : array-like, optional
Array of values to aggregate according to the factors.
Requires `aggfunc` be specified.
rownames : sequence, default None
If passed, must match number of row arrays passed.
colnames : sequence, default None
If passed, must match number of column arrays passed.
aggfunc : str, optional
If specified, requires `values` be specified as well.
margins : bool, default False
Add row/column margins (subtotals).
margins_name : str, default 'All'
Name of the row/column that will contain the totals
when margins is True.
dropna : bool, default True
Do not include columns whose entries are all NaN.
normalize : bool, {'all', 'index', 'columns'}, or {0,1}, default False
Normalize by dividing all values by the sum of values.
- If passed 'all' or `True`, will normalize over all values.
- If passed 'index' will normalize over each row.
- If passed 'columns' will normalize over each column.
- If margins is `True`, will also normalize margin values.
columns : array-like, Series, or list of arrays/Series
Values to group by in the columns.
values : array-like, optional
Array of values to aggregate according to the factors.
Requires `aggfunc` be specified.
rownames : sequence, default None
If passed, must match number of row arrays passed.
colnames : sequence, default None
If passed, must match number of column arrays passed.
aggfunc : str, optional
If specified, requires `values` be specified as well.
margins : bool, default False
Add row/column margins (subtotals).
margins_name : str, default 'All'
Name of the row/column that will contain the totals
when margins is True.
dropna : bool, default True
Do not include columns whose entries are all NaN.
normalize : bool, {'all', 'index', 'columns'}, or {0,1}, default False
Normalize by dividing all values by the sum of values.
- If passed 'all' or `True`, will normalize over all values.
- If passed 'index' will normalize over each row.
- If passed 'columns' will normalize over each column.
- If margins is `True`, will also normalize margin values.
Returns
-------
DataFrame
Crosstabulation of data
"""
true_cell_queries = get_queries(masks, aggfunc)
data = create_dataframe(index, columns)

if crosstab:
data = create_dataframe(index, columns)
# apply the queries to the data
for query in true_cell_queries:
query = str(query).replace("['", "").replace("']", "")
data = data.query(f"not ({query})")

# get the index and columns from the data after the queries are applied
try:
index_new, columns_new = get_index_columns(index, columns, data)
# apply the crosstab with the new index and columns
table = pd.crosstab( # type: ignore
index_new,
columns_new,
values=values,
rownames=rownames,
colnames=colnames,
aggfunc=aggfunc,
margins=margins,
margins_name=margins_name,
dropna=dropna,
normalize=normalize,
)
if crosstab:
index_new, columns_new = get_index_columns(index, columns, data)
# apply the crosstab with the new index and columns
table = pd.crosstab( # type: ignore
index_new,
columns_new,
values=values,
rownames=rownames,
colnames=colnames,
aggfunc=aggfunc,
margins=margins,
margins_name=margins_name,
dropna=dropna,
normalize=normalize,
)
else:
table = pd.pivot_table( # type: ignore
data=data,
values=values,
index=index,
columns=columns,
aggfunc=aggfunc,
fill_value=fill_value,
margins=margins,
dropna=dropna,
margins_name=margins_name,
observed=observed,
sort=sort,
)

except ValueError:
logger.warning(
"All the cells in this data are disclosive."
Expand Down Expand Up @@ -1205,8 +1251,8 @@ def manual_crossstab_with_totals( # pylint: disable=too-many-arguments,too-many

elif aggfunc == "mean":
count_table = pd.crosstab( # type: ignore
index,
columns,
index=index,
columns=columns,
values=values,
rownames=rownames,
colnames=colnames,
Expand All @@ -1221,14 +1267,16 @@ def manual_crossstab_with_totals( # pylint: disable=too-many-arguments,too-many
# delete any columns from the count_table that are not in the table
columns_to_keep = table.columns
count_table = count_table[columns_to_keep]
count_table = count_table.sort_index(axis=1)
if count_table.index.is_numeric():
count_table = count_table.sort_index(axis=1)

Check warning on line 1271 in acro/acro_tables.py

View check run for this annotation

Codecov / codecov/patch

acro/acro_tables.py#L1271

Added line #L1271 was not covered by tests
# recalculate the margins considering the nan values
count_table = recalculate_margin(count_table, margins_name)
# multiply the table by the count table
table[margins_name] = 1
table.loc[margins_name, :] = 1
multip_table = count_table * table
multip_table = multip_table.sort_index(axis=1)
if multip_table.index.is_numeric():
multip_table = multip_table.sort_index(axis=1)

Check warning on line 1279 in acro/acro_tables.py

View check run for this annotation

Codecov / codecov/patch

acro/acro_tables.py#L1279

Added line #L1279 was not covered by tests
# calculate the margins columns
table[margins_name] = (
multip_table.drop(margins_name, axis=1).sum(axis=1)
Expand Down
Loading

0 comments on commit 805036a

Please sign in to comment.