diff --git a/countess/core/pipeline.py b/countess/core/pipeline.py index 0d236c8..1cdb10d 100644 --- a/countess/core/pipeline.py +++ b/countess/core/pipeline.py @@ -1,5 +1,4 @@ from dataclasses import dataclass, field -from itertools import chain from multiprocessing import Process, Queue from os import cpu_count from queue import Empty @@ -101,9 +100,7 @@ def execute(self, logger: Logger, row_limit: Optional[int] = None): ) elif isinstance(self.plugin, ProcessPlugin): self.plugin.prepare([p.name for p in self.parent_nodes], row_limit) - self.result = chain( - self.plugin.collect(self.process_parent_iterables(logger)), self.plugin.finalize(logger) - ) + self.result = self.process_parent_iterables(logger) if row_limit is not None or len(self.child_nodes) != 1: self.result = list(self.result) diff --git a/countess/core/plugins.py b/countess/core/plugins.py index e4dbc00..36cfeaf 100644 --- a/countess/core/plugins.py +++ b/countess/core/plugins.py @@ -162,9 +162,6 @@ def finalize(self, logger: Logger) -> Iterable[pd.DataFrame]: # override this if you need to do anything return [] - def collect(self, data: Iterable) -> Iterable: - return data - class FileInputPlugin(BasePlugin): """Mixin class to indicate that this plugin can read files from local @@ -196,24 +193,6 @@ class PandasProcessPlugin(ProcessPlugin): def process(self, data: pd.DataFrame, source: str, logger: Logger) -> Iterable[pd.DataFrame]: raise NotImplementedError(f"{self.__class__}.process") - def collect(self, data: Iterable[pd.DataFrame]) -> Iterable[pd.DataFrame]: - buffer = None - for dataframe in data: - if dataframe is None or len(dataframe) == 0: - continue - if len(dataframe) > self.DATAFRAME_BUFFER_SIZE: - yield dataframe - elif buffer is None: - buffer = dataframe - elif len(buffer) + len(dataframe) > self.DATAFRAME_BUFFER_SIZE: - yield buffer - buffer = dataframe - else: - # XXX catch errors? - buffer = pd.concat([buffer, dataframe]) - if buffer is not None and len(buffer) > 0: - yield buffer - class PandasSimplePlugin(PandasProcessPlugin): """Base class for plugins which accept and return pandas DataFrames. @@ -234,7 +213,8 @@ def process(self, data: pd.DataFrame, source: str, logger: Logger) -> Iterable[p result = self.process_dataframe(data, logger) if result is not None: assert isinstance(result, pd.DataFrame) - yield result + if len(result) > 0: + yield result def process_dataframe(self, dataframe: pd.DataFrame, logger: Logger) -> Optional[pd.DataFrame]: """Override this to process a single dataframe""" diff --git a/countess/utils/pandas.py b/countess/utils/pandas.py index 4a1b85d..0d52ea3 100644 --- a/countess/utils/pandas.py +++ b/countess/utils/pandas.py @@ -5,6 +5,25 @@ import pandas as pd +def collect_dataframes(data: Iterable[pd.DataFrame], preferred_size: int=100000) -> Iterable[pd.DataFrame]: + buffer = None + for dataframe in data: + if dataframe is None or len(dataframe) == 0: + continue + if len(dataframe) > preferred_size: + yield dataframe + elif buffer is None: + buffer = dataframe + elif len(buffer) + len(dataframe) > preferred_size: + yield buffer + buffer = dataframe + else: + # XXX catch errors? + buffer = pd.concat([buffer, dataframe]) + if buffer is not None and len(buffer) > 0: + yield buffer + + def get_all_indexes(dataframe: pd.DataFrame) -> Dict[str, Any]: if dataframe.index.name: return {str(dataframe.index.name): dataframe.index.dtype}