diff --git a/gt/_body.py b/gt/_body.py index 1a7de8a6b..13b08b4b2 100644 --- a/gt/_body.py +++ b/gt/_body.py @@ -1,5 +1,13 @@ from __future__ import annotations +from typing import TYPE_CHECKING +from .utils_render_common import get_row_reorder_df -class BodyAPI: - pass + +if TYPE_CHECKING: + from ._gt_data import Body, RowGroups, Stub, Boxhead + + +def body_reassemble(body: Body, row_groups: RowGroups, stub_df: Stub, boxhead: Boxhead) -> Body: + cols = [col_info.var for col_info in boxhead] + start_final = get_row_reorder_df(row_groups, stub_df) diff --git a/gt/_boxhead.py b/gt/_boxhead.py index b0e50ef6c..754960c67 100644 --- a/gt/_boxhead.py +++ b/gt/_boxhead.py @@ -59,9 +59,9 @@ def cols_label(self, **kwargs: str): def _print_boxhead(self) -> pd.DataFrame: boxhead_list = list( zip( - [x.var for x in self._boxhead._boxhead], - [x.visible for x in self._boxhead._boxhead], - [x.column_label for x in self._boxhead._boxhead], + [x.var for x in self._boxhead], + [x.visible for x in self._boxhead], + [x.column_label for x in self._boxhead], ) ) return pd.DataFrame(boxhead_list, columns=["var", "visible", "column_label"]) diff --git a/gt/_gt_data.py b/gt/_gt_data.py index 017e42c10..651da4738 100644 --- a/gt/_gt_data.py +++ b/gt/_gt_data.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import overload, TypeVar +from typing_extensions import Self from dataclasses import dataclass # Note that we replace with with collections.abc after python 3.8 @@ -20,12 +21,19 @@ def __getitem__(self, ii: int) -> T: ... @overload - def __getitem__(self, ii: slice) -> _Sequence[T]: + def __getitem__(self, ii: slice) -> Self[T]: ... - def __getitem__(self, ii: int | slice) -> T | _Sequence[T]: + @overload + def __getitem__(self, ii: list[int]) -> Self[T]: + ... + + + def __getitem__(self, ii: int | slice | list[int]) -> T | Self[T]: if isinstance(ii, slice): return self.__class__(self._d[ii]) + elif isinstance(ii, list): + return self.__class__([self._d[el] for el in ii]) return self._d[ii] @@ -46,17 +54,16 @@ def __repr__(self): # from ._formats import FormatInfo +# TODO: it seems like this could just be a DataFrameLike object? +# Similar to TblData now being a DataFrame, rather than its own class +# I've left for now, and have just implemented concretes for it in +# _tbl_data.py class Body: body: TblData data: Any - def __init__(self, body: Union[pd.DataFrame, TblData], data: Any = None): - if isinstance(body, DataFrameLike): - self.body = pd.DataFrame( - pd.NA, index=body.index, columns=body.columns, dtype="string" - ) - else: - raise NotImplementedError() + def __init__(self, body: Union[pd.DataFrame, TblData]): + self.body = body def render_formats( self, data_tbl: TblData, formats: List[FormatInfo], context: Context @@ -74,6 +81,14 @@ def render_formats( return self + @classmethod + def from_empty(cls, body: DataFrameLike): + empty_df = pd.DataFrame( + pd.NA, index=body.index, columns=body.columns, dtype="string" + ) + + return cls(empty_df) + # Boxhead ---- __Boxhead = None @@ -124,26 +139,29 @@ def __init__( self.column_width = column_width -class Boxhead: - _boxhead: List[ColInfo] +class Boxhead(_Sequence[ColInfo]): + _d: List[ColInfo] - def __init__(self, data: TblData): - # Obtain the column names from the data and initialize the - # `_boxhead` from that - column_names = get_column_names(data) - self._boxhead = [ColInfo(col) for col in column_names] + def __init__(self, data: TblData | list[ColInfo]): + if isinstance(data, list): + self._d = data + else: + # Obtain the column names from the data and initialize the + # `_boxhead` from that + column_names = get_column_names(data) + self._d = [ColInfo(col) for col in column_names] # Get a list of columns def _get_columns(self) -> List[str]: - return [x.var for x in self._boxhead] + return [x.var for x in self._d] # Get a list of column labels def _get_column_labels(self) -> List[str]: - return [x.column_label for x in self._boxhead] + return [x.column_label for x in self._d] # Set column label def _set_column_label(self, column: str, label: str): - for x in self._boxhead: + for x in self._d: if x.var == column: x.column_label = label @@ -151,7 +169,7 @@ def _set_column_label(self, column: str, label: str): # Get a list of visible columns def _get_visible_columns(self) -> List[str]: - visible_columns = [x.var for x in self._boxhead if x.visible is True] + visible_columns = [x.var for x in self._d if x.visible is True] return visible_columns # Get the number of columns for the visible (not hidden) data; this @@ -763,7 +781,7 @@ def from_data(cls, data: TblData, locale: str | None = None): return cls( _tbl_data=data, - _body=Body(data, data), + _body=Body(data), _boxhead=Boxhead(data), # uses get_tbl_data() _stub=stub, # uses get_tbl_data _row_groups=RowGroups(row_groups), diff --git a/gt/_stub.py b/gt/_stub.py index 4f909e4b6..523f09815 100644 --- a/gt/_stub.py +++ b/gt/_stub.py @@ -1,5 +1,12 @@ from __future__ import annotations +from ._gt_data import Stub, RowGroups +from .utils_render_common import get_row_reorder_df -class StubAPI(): - pass + +def reorder_stub_df(stub_df: Stub, row_groups: RowGroups) -> Stub: + start_final = get_row_reorder_df(row_groups, stub_df) + + stub_df = stub_df[[final for _, final in start_final]] + + return stub_df diff --git a/gt/_tbl_data.py b/gt/_tbl_data.py index 0ca94f10c..719c43438 100644 --- a/gt/_tbl_data.py +++ b/gt/_tbl_data.py @@ -145,3 +145,24 @@ def _(data, row: int, column: str, value: Any): def _get_column_dtype(data: DataFrameLike, column: str) -> str: """Get the data type for a single column in the input data table""" return data[column].dtype + + +# reorder ---- + +@singledispatch +def reorder(data: DataFrameLike, rows: list[int], columns: list[str]) -> DataFrameLike: + """Return a re-ordered DataFrame.""" + _raise_not_implemented(data) + + +@reorder.register +def _(data: PdDataFrame, rows: list[int], columns: list[str]) -> PdDataFrame: + # note that because loc is label based, we need + # reset index to allow us to use integer indexing on the rows + # note that this means the index is not preserved when reordering pandas + return data.iloc[rows, :].loc[:, columns] + + +@reorder.register +def _(data: PlDataFrame, rows: list[int], columns: list[str]) -> PlDataFrame: + return data[rows, columns] diff --git a/gt/gt.py b/gt/gt.py index d21644822..1502ee333 100644 --- a/gt/gt.py +++ b/gt/gt.py @@ -19,7 +19,7 @@ # Rewrite main gt imports to use relative imports of APIs ---- from gt._tbl_data import TblDataAPI -from gt._body import BodyAPI +from gt._body import body_reassemble from gt._boxhead import BoxheadAPI from gt._footnotes import FootnotesAPI from gt._formats import FormatsAPI, fmt_number, fmt_integer @@ -29,7 +29,7 @@ from gt._row_groups import RowGroupsAPI from gt._source_notes import SourceNotesAPI from gt._spanners import SpannersAPI -from gt._stub import StubAPI +from gt._stub import reorder_stub_df from gt._stubhead import StubheadAPI from gt._styles import StylesAPI @@ -60,9 +60,7 @@ class GT( GTData, TblDataAPI, - BodyAPI, BoxheadAPI, - StubAPI, RowGroupsAPI, SpannersAPI, HeadingAPI, @@ -113,17 +111,17 @@ def _render_formats(self, context: str) -> GT: def _build_data(self, context: str): # Build the body of the table by generating a dictionary # of lists with cells initially set to nan values - self = copy.copy(self) - self._body = self._body.__class__(self._tbl_data) - self._render_formats(context) - # self._body = _migrate_unformatted_to_output(body) + built = copy.copy(self) + built._body = self._body.__class__(self._tbl_data) + built._render_formats(context) + # built._body = _migrate_unformatted_to_output(body) - # self._perform_col_merge() - # self._body_reassemble() + # built._perform_col_merge() + built._body = body_reassemble(built._body, self._row_groups, self._stub, self._boxhead) # Reordering of the metadata elements of the table - # self = self.reorder_stub_df() + built._stub = reorder_stub_df(self._stub, self._row_groups) # self = self.reorder_footnotes() # self = self.reorder_styles() @@ -225,7 +223,7 @@ def _set_has_built(gt: GT, value: bool) -> GT: def _get_column_labels(gt: GT, context: str) -> List[str]: gt_built = gt._build_data(context=context) - column_labels = [x.column_label for x in gt_built._boxhead._boxhead] + column_labels = [x.column_label for x in gt_built._boxhead] return column_labels diff --git a/tests/test_tbl_data.py b/tests/test_tbl_data.py index 74a91e3e3..53c22297e 100644 --- a/tests/test_tbl_data.py +++ b/tests/test_tbl_data.py @@ -3,7 +3,7 @@ import polars.testing import pytest -from gt._tbl_data import _get_cell, _get_column_dtype, _set_cell, get_column_names, DataFrameLike +from gt._tbl_data import _get_cell, _get_column_dtype, _set_cell, get_column_names, DataFrameLike, reorder params_frames = [ @@ -50,3 +50,13 @@ def test_set_cell(df: DataFrameLike): }) _set_cell(df, 1, 'col2', 'x') assert_frame_equal(df, expected) + + +def test_reorder(df: DataFrameLike): + res = reorder(df, [0, 2], ["col2"]) + dst = df.__class__({"col2": ["a", "c"]}) + + if isinstance(dst, pd.DataFrame): + dst.index = pd.Index([0, 2]) + + assert_frame_equal(res, dst)