Skip to content

Commit

Permalink
fix collate plugin
Browse files: browse the repository at this point in the history
  • Loading branch information
nickzoic committed Oct 23, 2023
1 parent cd6b7e0 commit 0b22bb6
Showing 1 changed file with 15 additions and 27 deletions.
42 changes: 15 additions & 27 deletions countess/plugins/collate.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
from typing import Iterable, List

import numpy as np
import pandas as pd

from countess import VERSION
from countess.core.logger import Logger
from countess.core.parameters import (
ChoiceParam,
IntegerParam,
PerColumnArrayParam,
)
from countess.core.parameters import ChoiceParam, IntegerParam, PerColumnArrayParam
from countess.core.plugins import PandasProcessPlugin
from countess.utils.pandas import get_all_columns

Expand All @@ -21,8 +16,6 @@ class CollatePlugin(PandasProcessPlugin):
description = "Collate and sort records by column(s), taking the first N"
version = VERSION

input_columns: dict[str, np.dtype]

parameters = {
"columns": PerColumnArrayParam(
"Columns", ChoiceParam("Role", choices=["—", "Group", "Sort (Asc)", "Sort (Desc)"])
Expand All @@ -42,35 +35,30 @@ def process(self, data: pd.DataFrame, source: str, logger: Logger) -> Iterable:
self.dataframes.append(data)
return []

def sort_and_limit(self, df):
    """Sort *df* by the configured sort columns and keep the first N rows.

    Each input column is paired positionally with its entry in
    ``self.parameters["columns"]``.  Columns whose role starts with
    ``"Sort"`` become sort keys, ascending when the role ends with
    ``"(Asc)"`` and descending otherwise.  When the ``"limit"`` parameter
    is greater than zero, only the first ``limit`` rows of the sorted
    frame are kept.

    Args:
        df: the dataframe (or one group of it) to sort and truncate.

    Returns:
        The sorted, optionally truncated, dataframe.
    """
    # NOTE(review): `self.input_columns` is only annotated in the visible
    # class body; confirm it is assigned before this method is called.
    #
    # Derive the sort keys and their directions from the *same* zipped
    # pairs.  `zip` truncates to the shorter input, so computing directions
    # from the full parameter list separately could give `by` and
    # `ascending` different lengths and make `sort_values` raise ValueError.
    sort_spec = [
        (col, param.value.endswith("(Asc)"))
        for col, param in zip(self.input_columns, self.parameters["columns"])
        if param.value.startswith("Sort")
    ]
    sort_cols = [col for col, _ in sort_spec]
    sort_dirs = [ascending for _, ascending in sort_spec]

    df = df.sort_values(by=sort_cols, ascending=sort_dirs)

    limit = self.parameters["limit"].value
    if limit > 0:
        df = df.head(limit)
    return df

def finalize(self, logger: Logger) -> Iterable[pd.DataFrame]:
assert isinstance(self.parameters["columns"], PerColumnArrayParam)
assert self.dataframes

df = pd.concat(self.dataframes)
self.parameters["columns"].set_column_choices(get_all_columns(df).keys())

column_parameters = list(zip(self.input_columns, self.parameters["columns"]))
input_columns = get_all_columns(df).keys()
self.parameters["columns"].set_column_choices(input_columns)
column_parameters = list(zip(input_columns, self.parameters["columns"]))
group_cols = [col for col, param in column_parameters if param.value == "Group"]
sort_cols = {
col: param.value.endswith("(Asc)") for col, param in column_parameters if param.value.startswith("Sort")
}

def sort_and_limit(df: pd.DataFrame) -> pd.DataFrame:
df = df.sort_values(by=list(sort_cols.keys()), ascending=list(sort_cols.values()))
if self.parameters["limit"].value > 0:
df = df.head(self.parameters["limit"].value)
return df

try:
if group_cols:
df = df.groupby(group_cols, group_keys=False).apply(self.sort_and_limit)
df = df.groupby(group_cols, group_keys=False).apply(sort_and_limit)
else:
df = self.sort_and_limit(df)

print(f">>>>> {type(df)}\n{df}")
df = sort_and_limit(df)

yield df
except ValueError as exc:
Expand Down

0 comments on commit 0b22bb6

Please sign in to comment.