Skip to content

Commit

Permalink
removing collect phase from plugins
Browse files Browse the repository at this point in the history
  • Loading branch information
nickzoic committed Sep 1, 2023
1 parent 7fb11dd commit 74d912c
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 26 deletions.
5 changes: 1 addition & 4 deletions countess/core/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from dataclasses import dataclass, field
from itertools import chain
from multiprocessing import Process, Queue
from os import cpu_count
from queue import Empty
Expand Down Expand Up @@ -101,9 +100,7 @@ def execute(self, logger: Logger, row_limit: Optional[int] = None):
)
elif isinstance(self.plugin, ProcessPlugin):
self.plugin.prepare([p.name for p in self.parent_nodes], row_limit)
self.result = chain(
self.plugin.collect(self.process_parent_iterables(logger)), self.plugin.finalize(logger)
)
self.result = self.process_parent_iterables(logger)

if row_limit is not None or len(self.child_nodes) != 1:
self.result = list(self.result)
Expand Down
24 changes: 2 additions & 22 deletions countess/core/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,6 @@ def finalize(self, logger: Logger) -> Iterable[pd.DataFrame]:
# override this if you need to do anything
return []

def collect(self, data: Iterable) -> Iterable:
return data


class FileInputPlugin(BasePlugin):
"""Mixin class to indicate that this plugin can read files from local
Expand Down Expand Up @@ -196,24 +193,6 @@ class PandasProcessPlugin(ProcessPlugin):
def process(self, data: pd.DataFrame, source: str, logger: Logger) -> Iterable[pd.DataFrame]:
raise NotImplementedError(f"{self.__class__}.process")

def collect(self, data: Iterable[pd.DataFrame]) -> Iterable[pd.DataFrame]:
buffer = None
for dataframe in data:
if dataframe is None or len(dataframe) == 0:
continue
if len(dataframe) > self.DATAFRAME_BUFFER_SIZE:
yield dataframe
elif buffer is None:
buffer = dataframe
elif len(buffer) + len(dataframe) > self.DATAFRAME_BUFFER_SIZE:
yield buffer
buffer = dataframe
else:
# XXX catch errors?
buffer = pd.concat([buffer, dataframe])
if buffer is not None and len(buffer) > 0:
yield buffer


class PandasSimplePlugin(PandasProcessPlugin):
"""Base class for plugins which accept and return pandas DataFrames.
Expand All @@ -234,7 +213,8 @@ def process(self, data: pd.DataFrame, source: str, logger: Logger) -> Iterable[p
result = self.process_dataframe(data, logger)
if result is not None:
assert isinstance(result, pd.DataFrame)
yield result
if len(result) > 0:
yield result

def process_dataframe(self, dataframe: pd.DataFrame, logger: Logger) -> Optional[pd.DataFrame]:
"""Override this to process a single dataframe"""
Expand Down
19 changes: 19 additions & 0 deletions countess/utils/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,25 @@
import pandas as pd


def collect_dataframes(data: Iterable[pd.DataFrame], preferred_size: int=100000) -> Iterable[pd.DataFrame]:
buffer = None
for dataframe in data:
if dataframe is None or len(dataframe) == 0:
continue
if len(dataframe) > preferred_size:
yield dataframe
elif buffer is None:
buffer = dataframe
elif len(buffer) + len(dataframe) > preferred_size:
yield buffer
buffer = dataframe
else:
# XXX catch errors?
buffer = pd.concat([buffer, dataframe])
if buffer is not None and len(buffer) > 0:
yield buffer


def get_all_indexes(dataframe: pd.DataFrame) -> Dict[str, Any]:
if dataframe.index.name:
return {str(dataframe.index.name): dataframe.index.dtype}
Expand Down

0 comments on commit 74d912c

Please sign in to comment.