From 22339176d6dd1116b660730c0e1e2a92ecb24e50 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Thu, 26 Oct 2023 16:57:32 -0400 Subject: [PATCH] feat: implement basic get_row_reorder_df --- gt/utils_render_common.py | 25 +++++++++++++++++++++++++ tests/test_utils_render_common.py | 19 +++++++++++++++++++ 2 files changed, 44 insertions(+) create mode 100644 tests/test_utils_render_common.py diff --git a/gt/utils_render_common.py b/gt/utils_render_common.py index 8b1378917..8e888411a 100644 --- a/gt/utils_render_common.py +++ b/gt/utils_render_common.py @@ -1 +1,26 @@ +from __future__ import annotations +from typing import TYPE_CHECKING, Tuple + +if TYPE_CHECKING: + from ._gt_data import RowGroups, Stub + + +TupleStartFinal = Tuple[int, int] + + +def get_row_reorder_df(groups: RowGroups, stub_df: Stub) -> list[TupleStartFinal]: + if not len(groups): + indices = range(len(stub_df)) + + # TODO: is this used in indexing? If so, we may need to use + # ii + 1 for the final part? + return [(ii, ii) for ii in indices] + + # where in the group each element is + groups_pos = [groups.index(row.group_id) for row in stub_df] + # the index that when used on the rows will sort them by the order in groups + start_pos = list(range(len(groups_pos))) + sort_indx = sorted(start_pos, key=lambda ii: groups_pos[ii]) + + return list(zip(start_pos, sort_indx)) diff --git a/tests/test_utils_render_common.py b/tests/test_utils_render_common.py new file mode 100644 index 000000000..2e0ce4d3c --- /dev/null +++ b/tests/test_utils_render_common.py @@ -0,0 +1,19 @@ +from gt._gt_data import RowGroups, Stub, RowInfo +from gt.utils_render_common import get_row_reorder_df + + +def test_get_row_reorder_df_simple(): + groups = RowGroups(["b", "a"]) + stub = Stub([RowInfo(0, "a"), RowInfo(1, "b"), RowInfo(2, "a")]) + + start_end = get_row_reorder_df(groups, stub) + + assert start_end == [(0, 1), (1, 0), (2, 2)] + + +def test_get_row_reorder_df_no_groups(): + groups = RowGroups() + stub = Stub([RowInfo(0, "a"), RowInfo(1, "b")]) + + start_end = get_row_reorder_df(groups, stub) + assert start_end == [(0,0), (1,1)]