-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* ✨ Allow `base.c()` to handle groupby data * 🚑 Allow `base.diff()` to work with groupby data * ✨ Allow `forcats.fct_inorder()` to work with groupby data * Add SeriesGroupBy as available type for forcats verbs * 🚑 Fix `base.diff()` not keep empty groups ✨Allow `base.rep()`'s arguments `length` and `each` to work with grouped data ✨Allow `base.c()` to work with grouped data 🐛 Fix recycling non-ordered grouped data 🐛 Force `&/|` operators to return boolean data 🚑 Make `dplyr.n()` return groupoed data 🩹 Fix `dplyr.count()/tally()`'s warning about the new name 🐛 Make `dplyr.slice()` work better with rows/indices from grouped data * ✨ Add `datar.attrgetter()`, `datar.pd_str()`, `datar.pd_cat()` and `datar.pd_dt()` * 🚑 Fix `base.c()` with grouped data * 📝 Update docs for `datar.datar` * 🔖 0.6.3 * Update readme.ipynb
- Loading branch information
Showing
32 changed files
with
1,339 additions
and
303 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
from functools import singledispatch | ||
|
||
import numpy as np | ||
import pandas as pd | ||
from pandas import DataFrame, Series, Categorical | ||
from pandas.api.types import is_scalar, is_integer | ||
from pandas.core.groupby import SeriesGroupBy | ||
from pipda import register_func | ||
|
||
from ..core.contexts import Context | ||
from ..core.tibble import TibbleGrouped, reconstruct_tibble | ||
from ..core.utils import ensure_nparray, logger | ||
|
||
|
||
def _rep(x, times, length, each): | ||
"""Repeat sequence x""" | ||
x = ensure_nparray(x) | ||
times = ensure_nparray(times) | ||
length = ensure_nparray(length) | ||
each = ensure_nparray(each) | ||
if times.size == 1: | ||
times = times[0] | ||
if length.size >= 1: | ||
if length.size > 1: | ||
logger.warning( | ||
"In rep(...) : first element used of 'length' argument" | ||
) | ||
length = length[0] | ||
if each.size == 1: | ||
each = each[0] | ||
|
||
if not is_scalar(times): | ||
if times.size != x.size: | ||
raise ValueError( | ||
"Invalid times argument, expect length " | ||
f"{x.size}, got {times.size}" | ||
) | ||
|
||
if not is_integer(each) or each != 1: | ||
raise ValueError( | ||
"Unexpected each argument when times is an iterable." | ||
) | ||
|
||
if is_integer(times) and is_scalar(times): | ||
x = np.tile(np.repeat(x, each), times) | ||
else: | ||
x = np.repeat(x, times) | ||
|
||
if length is None: | ||
return x | ||
|
||
repeats = length // x.size + 1 | ||
x = np.tile(x, repeats) | ||
|
||
return x[:length] | ||
|
||
|
||
@singledispatch | ||
def _rep_dispatched(x, times, length, each): | ||
"""Repeat sequence x""" | ||
times_sgb = isinstance(times, SeriesGroupBy) | ||
length_sgb = isinstance(length, SeriesGroupBy) | ||
each_sgb = isinstance(each, SeriesGroupBy) | ||
values = {} | ||
if times_sgb: | ||
values["times"] = times | ||
if length_sgb: | ||
values["length"] = length | ||
if each_sgb: | ||
values["each"] = each | ||
|
||
if values: | ||
from ..tibble import tibble | ||
df = tibble(**values) | ||
out = df._datar["grouped"].apply( | ||
lambda subdf: _rep( | ||
x, | ||
times=subdf["times"] if times_sgb else times, | ||
length=subdf["length"] if length_sgb else length, | ||
each=subdf["each"] if each_sgb else each, | ||
) | ||
) | ||
non_na_out = out[out.transform(len) > 0] | ||
non_na_out = non_na_out.explode() | ||
grouping = Categorical(non_na_out.index, categories=out.index.unique()) | ||
return ( | ||
non_na_out.explode() | ||
.reset_index(drop=True) | ||
.groupby(grouping, observed=False) | ||
) | ||
|
||
return _rep(x, times, length, each) | ||
|
||
|
||
@_rep_dispatched.register(Series) | ||
def _(x, times, length, each): | ||
return _rep_dispatched.dispatch(object)(x.values, times, length, each) | ||
|
||
|
||
@_rep_dispatched.register(SeriesGroupBy) | ||
def _(x, times, length, each): | ||
from ..tibble import tibble | ||
df = tibble(x=x) | ||
times_sgb = isinstance(times, SeriesGroupBy) | ||
length_sgb = isinstance(length, SeriesGroupBy) | ||
each_sgb = isinstance(each, SeriesGroupBy) | ||
if times_sgb: | ||
df["times"] = times | ||
if length_sgb: | ||
df["length"] = length | ||
if each_sgb: | ||
df["each"] = each | ||
|
||
out = df._datar["grouped"].apply( | ||
lambda subdf: _rep( | ||
subdf["x"], | ||
times=subdf["times"] if times_sgb else times, | ||
length=subdf["length"] if length_sgb else length, | ||
each=subdf["each"] if each_sgb else each, | ||
) | ||
).explode().astype(x.obj.dtype) | ||
grouping = out.index | ||
return out.reset_index(drop=True).groupby(grouping) | ||
|
||
|
||
@_rep_dispatched.register(DataFrame) | ||
def _(x, times, length, each): | ||
if not is_integer(each) or each != 1: | ||
raise ValueError( | ||
"`each` has to be 1 to replicate a data frame." | ||
) | ||
|
||
out = pd.concat([x] * times, ignore_index=True) | ||
if length is not None: | ||
out = out.iloc[:length, :] | ||
|
||
return out | ||
|
||
|
||
@_rep_dispatched.register(TibbleGrouped) | ||
def _(x, times, length, each): | ||
out = _rep_dispatched.dispatch(DataFrame)(x, times, length, each) | ||
return reconstruct_tibble(x, out) | ||
|
||
|
||
@register_func(None, context=Context.EVAL) | ||
def rep( | ||
x, | ||
times=1, | ||
length=None, | ||
each=1, | ||
): | ||
"""replicates the values in x | ||
Args: | ||
x: a vector or scaler | ||
times: number of times to repeat each element if of length len(x), | ||
or to repeat the whole vector if of length 1 | ||
length: non-negative integer. The desired length of the output vector | ||
each: non-negative integer. Each element of x is repeated each times. | ||
Returns: | ||
An array of repeated elements in x. | ||
""" | ||
return _rep_dispatched(x, times, length, each) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.