new parallelizable plugin architecture
nickzoic committed Jul 25, 2023
1 parent d7fe546 commit dca810c
Showing 1 changed file with 154 additions and 34 deletions.
188 changes: 154 additions & 34 deletions countess/core/plugins.py
@@ -21,7 +21,7 @@
 import os.path
 import sys
 from collections.abc import Mapping, MutableMapping
-from typing import Any, Dict, Iterable, Optional, Union
+from typing import Dict, Iterable, Optional, Union, List

 import numpy as np
 import pandas as pd
@@ -39,6 +39,7 @@
 )
 from countess.utils.pandas import get_all_columns

+
 PRERUN_ROW_LIMIT = 100000

@@ -137,11 +138,25 @@ def hash(self):
         """Returns a hex digest of the hash of all configuration parameters"""
         return self.get_parameter_hash().hexdigest()

-    def prepare(self):
+    def prepare(self, sources: List[str]):
         pass

-    def process_inputs(self, inputs: Mapping[str, Iterable[Any]], logger: Logger, row_limit: Optional[int]) -> Iterable[Any]:
-        raise NotImplementedError(f"{self.__class__}.process_inputs()")
+    def process(self, data, source: str, logger: Logger) -> Optional[Iterable[pd.DataFrame]]:
+        """Called with each `data` input from `source`; returns an iterable of
+        dataframes to be passed on to the next plugin."""
+        raise NotImplementedError(f"{self.__class__}.process")
+
+    def finished(self, source: str, logger: Logger) -> Optional[Iterable[pd.DataFrame]]:
+        """Called when a `source` is finished and will not send any more
+        messages. Can be ignored by most plugins."""
+        # override this if you need to do anything
+
+    def finalize(self, logger: Logger) -> Optional[Iterable[pd.DataFrame]]:
+        """Called when all sources are finished. Can be ignored by most
+        plugins. This should reset the plugin, ready for another use."""
+        # override this if you need to do anything
+        return None
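
To see how the new lifecycle fits together, here is a minimal hypothetical
subclass (name and behaviour invented for illustration, not part of this
commit; it reuses the module's own imports): it accumulates during process()
and emits once in finalize().

    class TotalRowsPlugin(BasePlugin):
        """Counts rows across all sources and emits a one-row summary."""

        def prepare(self, sources: List[str]):
            self.total = 0

        def process(self, data: pd.DataFrame, source: str, logger: Logger) -> Optional[Iterable[pd.DataFrame]]:
            # called once per incoming dataframe; nothing to emit yet
            self.total += len(data)
            return None

        def finalize(self, logger: Logger) -> Optional[Iterable[pd.DataFrame]]:
            # all sources are done: emit the summary, then reset for reuse
            total, self.total = self.total, 0
            yield pd.DataFrame({"total_rows": [total]})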


@@ -155,59 +170,162 @@ class FileInputMixin:
     file_types = [("Any", "*")]
     file_params: MutableMapping[str, BaseParam] = {}

-    def load_files(self, logger: Logger, row_limit: Optional[int] = None) -> Iterable[Any]:
-        raise NotImplementedError("FileInputMixin.load_files")
-
-    def process_inputs(self, inputs: Mapping[str, Iterable[Any]], logger: Logger, row_limit: Optional[int]) -> Iterable[Any]:
-        if len(inputs) > 0:
-            logger.warning(f"{self.name} doesn't take inputs")
-            raise ValueError(f"{self.name} doesn't take inputs")
-
-        return self.load_files(logger, row_limit)
+    def num_files(self) -> int:
+        """Return the number of 'files' which are to be loaded. The pipeline
+        will call code equivalent to
+        `[ p.load_file(n, logger, row_limit) for n in range(0, p.num_files()) ]`
+        although potentially using threads, multiprocessing, etc."""
+        return 0
+
+    def load_file(self, file_number: int, logger: Logger, row_limit: Optional[int] = None) -> Iterable:
+        """Called potentially from multiple processes; see FileInputMixin.num_files()"""
+        raise NotImplementedError("FileInputMixin.load_file")
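
The num_files()/load_file() split is what makes input loading parallelizable.
As a sketch of the driving loop (illustrative only; `plugin` and `logger` are
assumed to exist, and this is not the scheduler shipped by this commit), a
pipeline could fan the calls out across worker threads:

    from concurrent.futures import ThreadPoolExecutor

    def load_all(plugin, logger, row_limit=None):
        # one load_file() call per worker; threads sidestep the pickling
        # restrictions a multiprocessing pool would add for generators
        with ThreadPoolExecutor() as executor:
            results = executor.map(
                lambda n: list(plugin.load_file(n, logger, row_limit)),
                range(plugin.num_files()),
            )
            for dataframes in results:
                yield from dataframes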


 class PandasBasePlugin(BasePlugin):

-    def process_inputs(
-        self, inputs: Mapping[str, Iterable[pd.DataFrame]], logger: Logger, row_limit: Optional[int]
-    ) -> Iterable[pd.DataFrame]:
-        raise NotImplementedError(f"{self.__class__}.process_inputs()")
+    DATAFRAME_BUFFER_SIZE = 1000000
+
+    def process(self, data: pd.DataFrame, source: str, logger: Logger) -> pd.DataFrame:
+        raise NotImplementedError(f"{self.__class__}.process")

+    def collect(self, dataframes: Iterable[pd.DataFrame]) -> Iterable[pd.DataFrame]:
+        buffer = None
+        for dataframe in dataframes:
+            if dataframe is None or len(dataframe) == 0:
+                continue
+            if len(dataframe) > self.DATAFRAME_BUFFER_SIZE:
+                yield dataframe
+            elif buffer is None:
+                buffer = dataframe
+            elif len(buffer) + len(dataframe) > self.DATAFRAME_BUFFER_SIZE:
+                yield buffer
+                buffer = dataframe
+            else:
+                # XXX catch errors?
+                buffer = pd.concat([buffer, dataframe])
+        if buffer is not None and len(buffer) > 0:
+            yield buffer

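collect() coalesces a stream of small dataframes into chunks of up to
DATAFRAME_BUFFER_SIZE rows, so downstream plugins see a few large frames
instead of many tiny ones. A hypothetical demonstration (assumes `plugin`
is an instance of some concrete PandasBasePlugin subclass):

    import pandas as pd

    small_frames = (pd.DataFrame({"x": range(1000)}) for _ in range(5000))
    for chunk in plugin.collect(small_frames):
        # 5,000,000 input rows arrive as five ~1,000,000-row chunks
        print(len(chunk))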

 class PandasSimplePlugin(PandasBasePlugin):
-    """Base class for plugins which accept and return pandas DataFrames"""
+    """Base class for plugins which accept and return pandas DataFrames.
+    Subclassing this hides all the distracting aspects of the pipeline
+    from the plugin implementor, who only needs to override process_dataframe"""

     input_columns: Dict[str, np.dtype] = {}

-    def process_inputs(
-        self, inputs: Mapping[str, Iterable[pd.DataFrame]], logger: Logger, row_limit: Optional[int]
-    ) -> Iterable[pd.DataFrame]:
+    def prepare(self, sources: list[str]):
         self.input_columns = {}
-        iterators = set(iter(input) for input in inputs.values())
-        while iterators:
-            for it in list(iterators):
-                try:
-                    df_in = next(it)
-                    assert isinstance(df_in, pd.DataFrame)
-                    self.input_columns.update(get_all_columns(df_in))
-
-                    df_out = self.process_dataframe(df_in, logger)
-                    assert isinstance(df_out, pd.DataFrame)
-                    yield df_out
-                except StopIteration:
-                    iterators.remove(it)
+
+    def process(self, data: pd.DataFrame, source: str, logger: Logger) -> Iterable[pd.DataFrame]:
+        """Just deal with each dataframe as it comes. PandasSimplePlugins don't care about `source`."""
+        assert isinstance(data, pd.DataFrame)

-        for p in self.parameters.values():
-            p.set_column_choices(self.input_columns.keys())
+        self.input_columns.update(get_all_columns(data))

+        result = self.process_dataframe(data, logger)
+        if result is not None:
+            assert isinstance(result, pd.DataFrame)
+            yield result

     def process_dataframe(self, dataframe: pd.DataFrame, logger: Logger) -> pd.DataFrame:
         """Override this to process a single dataframe"""
         raise NotImplementedError(f"{self.__class__}.process_dataframe()")

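A concrete PandasSimplePlugin can therefore be very small. A hypothetical
example (the name and column are invented for illustration):

    class DoubleScorePlugin(PandasSimplePlugin):
        name = "Double Score"

        def process_dataframe(self, dataframe: pd.DataFrame, logger: Logger) -> pd.DataFrame:
            # return the transformed frame; returning None drops this chunk
            return dataframe.assign(score=dataframe["score"] * 2)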

+# XXX this might be excessively DRY but we'll see.
+class PandasProductPlugin(PandasBasePlugin):
+    """Some plugins need to have all the data from two sources presented to them,
+    which is tricky in a pipelined environment. This superclass handles the two
+    input sources and calls .process_dataframes with pairs of dataframes.
+    It is currently only used by JoinPlugin."""
+
+    source1 = None
+    source2 = None
+    mem1 = None
+    mem2 = None
+
+    def prepare(self, sources: list[str]):
+        if len(sources) != 2:
+            raise ValueError(f"{self.__class__} requires exactly two inputs")
+        self.source1, self.source2 = sources
+
+        super().prepare(sources)
+
+        self.mem1 = []
+        self.mem2 = []
+
+    def process(self, data: pd.DataFrame, source: str, logger: Logger) -> Iterable[pd.DataFrame]:
+        if source == self.source1:
+            if self.mem1 is not None:
+                self.mem1.append(data)
+            for val2 in self.mem2:
+                df = self.process_dataframes(data, val2, logger)
+                if len(df):
+                    yield df
+
+        elif source == self.source2:
+            if self.mem2 is not None:
+                self.mem2.append(data)
+            for val1 in self.mem1:
+                df = self.process_dataframes(val1, data, logger)
+                if len(df):
+                    yield df
+
+        else:
+            raise ValueError(f"Unknown source {source}")
+
+    def finished(self, source: str, logger: Logger) -> None:
+        if source == self.source1:
+            # source1 is finished, so nothing more will be paired against
+            # mem2: mem2 is no longer needed
+            self.mem2 = None
+        elif source == self.source2:
+            # source2 is finished, so mem1 is no longer needed
+            self.mem1 = None
+        else:
+            raise ValueError(f"Unknown source {source}")
+
+    def finalize(self, logger: Logger) -> None:
+        # free up any memory taken up by memoization
+        self.mem1 = None
+        self.mem2 = None
+
+    def process_dataframes(self, dataframe1: pd.DataFrame, dataframe2: pd.DataFrame, logger: Logger) -> pd.DataFrame:
+        raise NotImplementedError(f"{self.__class__}.process_dataframes")
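
Because each incoming chunk is paired against every memoized chunk from the
other source exactly once, the union of the pairwise results equals the result
of processing the two concatenated inputs, for operations like inner joins
which distribute over concatenation. A sketch of a subclass, roughly the shape
JoinPlugin takes (details invented here):

    class SimpleJoinPlugin(PandasProductPlugin):
        name = "Simple Join"

        def process_dataframes(self, dataframe1: pd.DataFrame, dataframe2: pd.DataFrame, logger: Logger) -> pd.DataFrame:
            # inner-join on the index; empty results are filtered by process()
            return dataframe1.join(dataframe2, how="inner", rsuffix="_2")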


 # XXX this might be excessively DRY but we'll see.
 #
 #                              --> PandasTransformAToXMixin --
 #                             /      dataframe_to_series     \
 #   PandasTransformBasePlugin --                              --> PandasTransformAToBMixin
 #                             \      series_to_dataframe     /
 #                              --> PandasTransformXToBMixin --
 #
 # A is one of Single (takes a value), Row (takes a pd.Series) or Dict (takes a dict)
 # B is one of Single (returns a value), Tuple (returns a tuple) or Dict (returns a dict)
 #
 # Looked at from the point of view of the data it looks like:
 #
 #   P_T_BasePlugin         P_T_AToXMixin          P_T_AToBPlugin    P_T_XToBPlugin           PTBP
 #
 #                                              /-> process_row ---\
 #   process_dataframe -> dataframe_to_series ----> process_value ----> series_to_dataframe -> merge
 #                                              \-> process_dict --/
 #
 # Which probably seems overcomplicated but it also seems to work.
 # Most plugins just need to pick a PandasTransform class and run with it. A lot of the time it's
 # going to be PandasTransformSingleToSinglePlugin.

 class PandasTransformBasePlugin(PandasSimplePlugin):
-    """Base class for the six (!) PandasTransformXToXPlugin superclasses."""
+    """Base class for the nine (!!) PandasTransformXToXPlugin superclasses."""

     def series_to_dataframe(self, series: pd.Series) -> pd.DataFrame:
         raise NotImplementedError(f"{self.__class__}.series_to_dataframe()")
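
As the comment block above says, most plugin authors just pick one of the nine
concrete transform classes. A hypothetical use of the Single-to-Single variant
(the class definitions fall outside this excerpt; the override name
process_value comes from the diagram above, the rest is assumed):

    class ReverseStringPlugin(PandasTransformSingleToSinglePlugin):
        name = "Reverse String"

        def process_value(self, value, logger: Logger):
            # A = Single: receives one value from the chosen input column;
            # B = Single: returns one value for the output column
            return value[::-1] if isinstance(value, str) else value
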
@@ -372,7 +490,7 @@ def process_dict(self, data, logger: Logger):

 class PandasInputPlugin(FileInputMixin, PandasBasePlugin):
     """A specialization of the PandasBasePlugin to allow it to follow nothing,
-    eg: come first."""
+    eg: come first. """

     def __init__(self, *a, **k):
         # Add in filenames
@@ -384,10 +502,12 @@ def __init__(self, *a, **k):
[("files", FileArrayParam("Files", MultiParam("File", file_params)))] + list(self.parameters.items())
)

def load_files(self, logger: Logger, row_limit: Optional[int] = None) -> Iterable[pd.DataFrame]:
def num_files(self):
return len(self.parameters["files"].params)

def load_files(self, file_number: int, logger: Logger, row_limit: Optional[int] = None) -> Iterable[pd.DataFrame]:
assert isinstance(self.parameters["files"], ArrayParam)
fps = self.parameters["files"].params

num_files = len(fps)
per_file_row_limit = int(row_limit / len(fps) + 1) if row_limit else None
logger.progress("Loading", 0)
