Skip to content

Commit

Permalink
Limit scope of changes
Browse files Browse the repository at this point in the history
  • Loading branch information
amol- committed Oct 11, 2024
1 parent e4a2964 commit aa154d8
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 15 deletions.
6 changes: 5 additions & 1 deletion great_tables/_gt_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,11 @@ def render_formats(self, data_tbl: TblData, formats: list[FormatInfo], context:
# TODO: I think that this is very inefficient with polars, so
# we could either accumulate results and set them per column, or
# could always use a pandas DataFrame inside Body?
self.body = _set_cell(self.body, row, col, result)
new_body = _set_cell(self.body, row, col, result)
if new_body is not None:
# Some backends do not support inplace operations, but return a new dataframe
# TODO: Consolidate the behaviour of _set_cell
self.body = new_body

return self

Expand Down
21 changes: 7 additions & 14 deletions great_tables/_tbl_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,23 +246,16 @@ def _set_cell(data: DataFrameLike, row: int, column: str, value: Any):


@_set_cell.register(PdDataFrame)
def _(data, row: int, column: str, value: Any) -> PdDataFrame:
def _(data, row: int, column: str, value: Any) -> None:
# TODO: This assumes column names are unique
# if this is violated, get_loc will return a mask
data_new = data.copy(deep=False) # make a shallow copy and only update the specific column.
data_new[column] = data_new[column].copy()
data_new.at[row, column] = value
return data_new
col_indx = data.columns.get_loc(column)
data.iloc[row, col_indx] = value


@_set_cell.register(PlDataFrame)
def _(data, row: int, column: str, value: Any) -> PlDataFrame:
# While using scatter is considered an antipattern,
# it is easier to read than a when.then.otherwise expression,
# and it is generally better performing.
col_series_modified = data[column].scatter(row, value)
data_new = data.with_columns(col_series_modified)
return data_new
def _(data, row: int, column: str, value: Any) -> None:
data[row, column] = value


@_set_cell.register(PyArrowTable)
Expand All @@ -273,8 +266,8 @@ def _(data: PyArrowTable, row: int, column: str, value: Any) -> PyArrowTable:
col = data.column(column)
pylist = col.to_pylist()
pylist[row] = value
data_new = data.set_column(colindex, column, pa.array(pylist))
return data_new
data = data.set_column(colindex, column, pa.array(pylist))
return data


# _get_column_dtype ----
Expand Down

0 comments on commit aa154d8

Please sign in to comment.