Skip to content

Commit

Permalink
🩹 Make base.ntile() labels 1-based (#92)
Browse files Browse the repository at this point in the history
* 🩹 Make `base.ntile()` labels 1-based

* 🚑 Allow `base.paste/paste0()` to work with grouped data

* Update CHANGLOG
  • Loading branch information
pwwang authored Mar 16, 2022
1 parent 4db4f0a commit e67e8e0
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 46 deletions.
46 changes: 30 additions & 16 deletions datar/base/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import numpy as np
import pandas as pd
from pandas import Series
from pandas.core.base import PandasObject
from pandas.core.groupby import SeriesGroupBy
from pandas.api.types import is_string_dtype, is_scalar
from pipda import register_func


from ..core.tibble import TibbleRowwise
from ..core.tibble import TibbleGrouped, TibbleRowwise
from ..core.contexts import Context
from ..core.factory import func_factory, dispatching
from ..core.utils import (
Expand All @@ -21,7 +22,6 @@
from .casting import _as_type
from .testing import _register_type_testing
from .logical import as_logical
from .seq import lengths


def _recycle_value(value, size, name=None):
Expand Down Expand Up @@ -396,6 +396,9 @@ def _nchar_scalar(x, retn, allow_na, keep_na, na_len):


# paste and paste0 --------------------
_is_empty = lambda x: (
(is_scalar(x) and not x) or (not is_scalar(x) and len(x) == 0)
)


@register_func(None, context=Context.EVAL)
Expand All @@ -411,24 +414,35 @@ def paste(*args, sep=" ", collapse=None):
A single string if collapse is given, otherwise an array of strings.
"""
if len(args) == 1 and isinstance(args[0], TibbleRowwise):
return args[0].apply(
lambda row: paste(*row, sep=sep, collapse=collapse), axis=1
out = args[0].apply(
lambda row: row.astype(str).str.cat(sep=sep), axis=1
)
return collapse.join(out) if collapse else out

maxlen = max(regcall(lengths, args))
args = zip(
*(
_recycle_value(arg, maxlen, f"{i}th value")
for i, arg in enumerate(args)
)
)
from ..tibble import tibble

args = [as_character(arg, _na="NA") for arg in args]
out = [sep.join(arg) for arg in args]
if collapse is not None:
return collapse.join(out)
if all(_is_empty(arg) for arg in args):
df = tibble(*args, _name_repair="minimal")
else:
df = tibble(
*("" if _is_empty(arg) else arg for arg in args),
_name_repair="minimal",
)

return np.array(out, dtype=object)
if not isinstance(df, TibbleGrouped):
out = df.apply(lambda col: col.astype(str).str.cat(sep=sep), axis=1)
if collapse:
return collapse.join(out)
if any(isinstance(x, PandasObject) for x in args):
return out
return np.array(out, dtype=object)

out = df.apply(
lambda row: row.astype(str).str.cat(sep=sep), axis=1
).groupby(df._datar["grouped"].grouper)
if collapse:
out = out.agg(lambda x: x.str.cat(sep=collapse))
return out


@register_func(None, context=Context.EVAL)
Expand Down
4 changes: 2 additions & 2 deletions datar/dplyr/_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def _(x, n):
return Categorical([np.nan] * x.size)

n = min(n, x.size)
return pd.cut(x, n, labels=range(n))
return pd.cut(x, n, labels=np.arange(n) + 1)


@_ntile.register(GroupBy)
Expand All @@ -132,7 +132,7 @@ def _(x, n):
lambda grup: pd.cut(
grup,
min(n, len(grup)),
labels=range(min(n, len(grup))),
labels=np.arange(min(n, len(grup))) + 1,
)
)

Expand Down
6 changes: 3 additions & 3 deletions datar/dplyr/rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,13 @@ def ntile(
"""A rough rank, which breaks the input vector into `n` buckets.
Note:
The output tiles are 0-based.
The output tiles are 1-based.
The result is slightly different from dplyr's ntile.
>>> ntile(c(1,2,NA,1,0,NA), 2) # dplyr
>>> # 1 2 NA 2 1 NA
>>> ntile([1,2,NA,1,0,NA], n=2) # datar
>>> # [0, 1, NA, 0, 0, NA]
>>> # Categories (2, int64): [0 < 1]
>>> # [1, 2, NA, 1, 1, NA]
>>> # Categories (2, int64): [1 < 2]
"""
if isinstance(x, int) and n is None:
n = x
Expand Down
2 changes: 2 additions & 0 deletions docs/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
- ✨ Allow `forcats.fct_inorder()` to work with groupby data
- ✨ Allow `base.rep()`'s arguments `length` and `each` to work with grouped data
- ✨ Allow `base.c()` to work with grouped data
- ✨ Allow `base.paste()`/`base.paste0()` to work with grouped data
- 🐛 Force `&/|` operators to return boolean data
- 🚑 Fix `base.diff()` not keep empty groups
- 🐛 Fix recycling non-ordered grouped data
- 🩹 Fix `dplyr.count()/tally()`'s warning about the new name
- 🚑 Make `dplyr.n()` return groupoed data
- 🐛 Make `dplyr.slice()` work better with rows/indices from grouped data
- 🩹 Make `dplyr.ntile()` labels 1-based
- ✨ Add `datar.attrgetter()`, `datar.pd_str()`, `datar.pd_cat()` and `datar.pd_dt()`

## 0.6.2
Expand Down
10 changes: 9 additions & 1 deletion tests/base/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,15 @@ def test_paste():
assert_iterable_equal(out, [])

out = paste0([], ["a"])
assert_iterable_equal(out, ["NAa"])
assert_iterable_equal(out, ["a"])

df = tibble(x=[1, 2, 3], y=[4, 5, 5])
out = paste(df)
assert_iterable_equal(out, ["1 4", "2 5", "3 5"])

gf = df.group_by("y")
out = paste0(gf, collapse="|")
assert_iterable_equal(out, ['14', '25|35'])


def test_sprintf():
Expand Down
4 changes: 2 additions & 2 deletions tests/dplyr/test_mutate_windowed.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def test_rank_functions_deal_correctly_with_na():
assert res.min_rank[[0, 1, 3, 4]].tolist() == c(2, 4, 2, 1)
assert res.dense_rank[[0, 1, 3, 4]].tolist() == c(2, 3, 2, 1)
assert res.cume_dist[[0, 1, 3, 4]].tolist() == c(0.75, 1.0, 0.75, 0.25)
assert res.ntile[[0, 1, 3, 4]].tolist() == c(0, 1, 0, 0)
assert res.ntile[[0, 1, 3, 4]].tolist() == c(1, 2, 1, 1)
assert res.row_number[[0, 1, 3, 4]].tolist() == c(2, 4, 3, 1)

data = tibble(x=rep(c(1, 2, NA, 1, 0, NA), 2), g=rep([1, 2], each=6))
Expand Down Expand Up @@ -224,7 +224,7 @@ def test_rank_functions_deal_correctly_with_na():
)
assert (
res.ntile.obj[[0, 1, 3, 4, 6, 7, 9, 10]].tolist()
== rep(c(0, 1, 0, 0), 2).tolist()
== rep(c(1, 2, 1, 1), 2).tolist()
)
assert (
res.row_number.obj[[0, 1, 3, 4, 6, 7, 9, 10]].tolist()
Expand Down
45 changes: 23 additions & 22 deletions tests/dplyr/test_rank.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# tests grabbed from:
# https://github.com/tidyverse/dplyr/blob/master/tests/testthat/test-rank.r
import numpy as np
import pytest

from datar import f
Expand Down Expand Up @@ -29,16 +30,16 @@ def ntile_h(x, n):
def test_ntile_ignores_number_of_nas():
x = c(1, 2, 3, NA, NA)
out = ntile(x, 3)
assert_iterable_equal(out, [0, 1, 2, NA, NA])
assert_iterable_equal(out, [1, 2, 3, NA, NA])

out = ntile_h(x, 3)
assert_iterable_equal(out, [0, 1, 2, NA, NA])
assert_iterable_equal(out, [1, 2, 3, NA, NA])

x1 = c(1, 1, 1, NA, NA, NA)
out = ntile(x1, n=1)
assert_iterable_equal(out, [0, 0, 0, NA, NA, NA])
assert_iterable_equal(out, [1, 1, 1, NA, NA, NA])
out = ntile_h(x1, 1)
assert_iterable_equal(out, [0, 0, 0, NA, NA, NA])
assert_iterable_equal(out, [1, 1, 1, NA, NA, NA])


def test_ntile_always_returns_an_integer():
Expand Down Expand Up @@ -107,20 +108,20 @@ def test_lead_lag_inside_mutates_handles_expressions_as_value_for_default():

def test_ntile_puts_large_groups_first():

assert_iterable_equal(ntile(range(1), n=5), [0])
assert_iterable_equal(ntile(range(2), n=5), list(range(2)))
assert_iterable_equal(ntile(range(3), n=5), list(range(3)))
assert_iterable_equal(ntile(range(4), n=5), list(range(4)))
assert_iterable_equal(ntile(range(5), n=5), list(range(5)))
assert_iterable_equal(ntile(range(6), n=5), c(0, range(5)))
assert_iterable_equal(ntile(range(1), n=7), [0])
assert_iterable_equal(ntile(range(2), n=7), list(range(2)))
assert_iterable_equal(ntile(range(3), n=7), list(range(3)))
assert_iterable_equal(ntile(range(4), n=7), list(range(4)))
assert_iterable_equal(ntile(range(5), n=7), list(range(5)))
assert_iterable_equal(ntile(range(6), n=7), list(range(6)))
assert_iterable_equal(ntile(range(7), n=7), list(range(7)))
assert_iterable_equal(ntile(range(8), n=7), c(0, range(7)))
assert_iterable_equal(ntile(range(1), n=5), [1])
assert_iterable_equal(ntile(range(2), n=5), np.arange(2) + 1)
assert_iterable_equal(ntile(range(3), n=5), np.arange(3) + 1)
assert_iterable_equal(ntile(range(4), n=5), np.arange(4) + 1)
assert_iterable_equal(ntile(range(5), n=5), np.arange(5) + 1)
assert_iterable_equal(ntile(range(6), n=5), c(1, np.arange(5) + 1))
assert_iterable_equal(ntile(range(1), n=7), [1])
assert_iterable_equal(ntile(range(2), n=7), np.arange(2) + 1)
assert_iterable_equal(ntile(range(3), n=7), np.arange(3) + 1)
assert_iterable_equal(ntile(range(4), n=7), np.arange(4) + 1)
assert_iterable_equal(ntile(range(5), n=7), np.arange(5) + 1)
assert_iterable_equal(ntile(range(6), n=7), np.arange(6) + 1)
assert_iterable_equal(ntile(range(7), n=7), np.arange(7) + 1)
assert_iterable_equal(ntile(range(8), n=7), c(1, np.arange(7) + 1))


def test_plain_arrays():
Expand All @@ -129,9 +130,9 @@ def test_plain_arrays():
out = row_number([1, 1, 2])
assert_iterable_equal(out, [1, 2, 3])
out = ntile(1, 1)
assert_iterable_equal(out, [0])
assert_iterable_equal(out, [1])
out = ntile((i for i in range(1)), 1)
assert_iterable_equal(out, [0])
assert_iterable_equal(out, [1])
out = cume_dist(1)
assert_iterable_equal(out, [1])
out = cume_dist([])
Expand All @@ -155,11 +156,11 @@ def test_row_number_with_groups():
def test_ntile_with_groups():
df = tibble(x=f[1:9], y=[1] * 4 + [2] * 4)
out = ntile(df.x, 2)
assert out.tolist() == [0, 0, 0, 0, 1, 1, 1, 1]
assert out.tolist() == [1, 1, 1, 1, 2, 2, 2, 2]

df = df.groupby("y")
out = ntile(df.x, 2)
assert out.tolist() == [0, 0, 1, 1, 0, 0, 1, 1]
assert out.tolist() == [1, 1, 2, 2, 1, 1, 2, 2]


def test_min_rank_with_groups():
Expand Down

0 comments on commit e67e8e0

Please sign in to comment.