diff --git a/countess/core/pipeline.py b/countess/core/pipeline.py
index c658d1f..bb50f65 100644
--- a/countess/core/pipeline.py
+++ b/countess/core/pipeline.py
@@ -1,13 +1,15 @@
 from dataclasses import dataclass, field
 from typing import Any, Optional
+from more_itertools import interleave_longest
+
 from queue import Empty
 import multiprocessing
 import time
 
 import psutil
 
 from countess.core.logger import Logger
-from countess.core.plugins import BasePlugin, get_plugin_classes
+from countess.core.plugins import BasePlugin, get_plugin_classes, FileInputMixin
 
 PRERUN_ROW_LIMIT = 100000
@@ -57,7 +59,6 @@ def __next__(self):
         raise StopIteration
 
 
-
 @dataclass
 class PipelineNode:
     name: str
@@ -85,22 +86,38 @@ def is_descendant_of(self, node):
 
     def execute(self, logger: Logger, row_limit: Optional[int] = None):
         assert row_limit is None or isinstance(row_limit, int)
-        inputs = {p.name: p.result for p in self.parent_nodes}
-        self.result = []
-        if self.plugin:
-            try:
-                self.result = self.plugin.process_inputs(inputs, logger, row_limit)
-                if isinstance(self.result, (bytes, str)):
-                    pass
-                elif row_limit is not None:
-                    self.result = list(self.result)
-                #if not isinstance(self.result, (bytes, str)) and (row_limit is not None or len(self.child_nodes) > 1):
-                #    # XXX freeze to handle fan-out and reloading.
-                #    self.result = list(self.result)
-                elif hasattr(self.result, '__iter__'):
-                    self.result = MultiprocessingProxy(self.result, self.name)
-            except Exception as exc:  # pylint: disable=broad-exception-caught
-                logger.exception(exc)
+
+        if self.plugin is None:
+            self.result = []
+        elif self.result and not self.is_dirty:
+            return
+        elif isinstance(self.plugin, FileInputMixin):
+            num_files = self.plugin.num_files()
+            gg = [
+                self.plugin.load_file(n, logger, row_limit // num_files if row_limit else None)
+                for n in range(num_files)
+            ]
+            self.result = list(interleave_longest(*gg))
+        else:
+            self.plugin.prepare([p.name for p in self.parent_nodes])
+            input_source_data = interleave_longest(*(
+                [(p.name, r) for r in p.result]
+                for p in self.parent_nodes
+            ))
+            self.result = list(self.plugin.collect(
+                d
+                for source, data in input_source_data
+                for d in self.plugin.process(data, source, logger)
+            ))
+            for p in self.parent_nodes:
+                r = self.plugin.finished(p.name, logger)
+                if r:
+                    self.result += list(r)
+            r = self.plugin.finalize(logger)
+            if r:
+                self.result += list(r)
+
+        self.is_dirty = False
 
     def load_config(self, logger: Logger):
         assert isinstance(self.plugin, BasePlugin)
diff --git a/countess/plugins/csv.py b/countess/plugins/csv.py
index 69a81a2..12d9f72 100644
--- a/countess/plugins/csv.py
+++ b/countess/plugins/csv.py
@@ -143,6 +143,13 @@ def read_file_to_dataframe(self, file_params, logger, row_limit=None):
 
         return df
 
+    def num_files(self):
+        return len(self.parameters["files"])
+
+    def load_file(self, file_number: int, logger: Logger, row_limit: Optional[int] = None) -> Iterable:
+        file_params = self.parameters["files"][file_number]
+        yield self.read_file_to_dataframe(file_params, logger, row_limit)
+
 
 class SaveCsvPlugin(PandasOutputPlugin):
     name = "CSV Save"
diff --git a/countess/plugins/expression.py b/countess/plugins/expression.py
index c7f1540..32aff3f 100644
--- a/countess/plugins/expression.py
+++ b/countess/plugins/expression.py
@@ -10,6 +10,7 @@ def process(df: pd.DataFrame, codes, logger: Logger):
     for code in codes:
         if not code:
             continue
+
         try:
             result = df.eval(code)
         except Exception as exc:  # pylint: disable=W0718
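
For context on the fan-in behaviour introduced in PipelineNode.execute()
above: more_itertools.interleave_longest() round-robins across its input
iterables until all of them are exhausted, so chunks are consumed
alternately across files (or parent nodes) rather than one source at a
time. A quick self-contained illustration:

    from more_itertools import interleave_longest

    chunks_a = ["a1", "a2", "a3"]
    chunks_b = ["b1"]
    chunks_c = ["c1", "c2"]

    # round-robin, continuing until the longest iterable is exhausted
    assert list(interleave_longest(chunks_a, chunks_b, chunks_c)) == [
        "a1", "b1", "c1", "a2", "c2", "a3",
    ]
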
diff --git a/countess/plugins/fastq.py b/countess/plugins/fastq.py
index 18f1414..4c50d28 100644
--- a/countess/plugins/fastq.py
+++ b/countess/plugins/fastq.py
@@ -1,5 +1,6 @@
 import gzip
 from itertools import islice
+from typing import Optional, Iterable
 
 import pandas as pd
 from fqfa.fastq.fastq import parse_fastq_reads  # type: ignore
@@ -7,7 +8,7 @@
 from countess import VERSION
 from countess.core.parameters import BooleanParam, FloatParam
 from countess.core.plugins import PandasInputPlugin
-
+from countess.core.logger import Logger
 
 def _file_reader(file_handle, min_avg_quality, row_limit=None):
     for fastq_read in islice(parse_fastq_reads(file_handle), 0, row_limit):
@@ -54,3 +55,11 @@ def read_file_to_dataframe(self, file_params, logger, row_limit=None):
             dataframe = dataframe.groupby("sequence").count()
 
         return dataframe
+
+    def num_files(self):
+        return len(self.parameters["files"])
+
+    def load_file(self, file_number: int, logger: Logger, row_limit: Optional[int] = None) -> Iterable:
+        file_params = self.parameters["files"][file_number]
+        yield self.read_file_to_dataframe(file_params, logger, row_limit)
+
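
CSV and FASTQ inputs now share the same two-method contract that
execute() drives: num_files() reports how many files are configured, and
load_file() is a generator yielding zero or more dataframes for a single
file. A minimal sketch of a conforming input plugin (the class and its
filenames attribute are illustrative, not part of CountESS):

    import pandas as pd

    class ExampleInputPlugin:  # hypothetical FileInputMixin-style plugin
        def __init__(self, filenames):
            self.filenames = filenames

        def num_files(self) -> int:
            return len(self.filenames)

        def load_file(self, file_number, logger, row_limit=None):
            # a generator, so execute() can interleave several files
            # lazily via interleave_longest()
            yield pd.read_csv(self.filenames[file_number], nrows=row_limit)
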
diff --git a/countess/plugins/join.py b/countess/plugins/join.py
index cc6cf7f..b5fd06b 100644
--- a/countess/plugins/join.py
+++ b/countess/plugins/join.py
@@ -2,12 +2,11 @@
 from typing import Iterable, Optional, Union
 
 import pandas as pd
-from moore_itertools import product
 
 from countess import VERSION
 from countess.core.logger import Logger
 from countess.core.parameters import ArrayParam, BooleanParam, ColumnOrIndexChoiceParam, MultiParam
-from countess.core.plugins import PandasBasePlugin
+from countess.core.plugins import PandasBasePlugin, PandasProductPlugin
 from countess.utils.pandas import get_all_columns
@@ -18,7 +17,7 @@ def _join_how(left_required: bool, right_required: bool) -> str:
     return "right" if right_required else "outer"
 
 
-class JoinPlugin(PandasBasePlugin):
+class JoinPlugin(PandasProductPlugin):
     """Joins Pandas Dataframes"""
 
     name = "Join"
@@ -41,22 +40,51 @@ class JoinPlugin(PandasBasePlugin):
             max_size=2,
         ),
     }
+    join_params = None
+    input_columns_1 = None
+    input_columns_2 = None
+
+    def prepare(self, sources: list[str]):
+        super().prepare(sources)
+
+        assert isinstance(self.parameters["inputs"], ArrayParam)
+        assert len(self.parameters["inputs"]) == 2
+        ip1, ip2 = self.parameters["inputs"]
+        ip1.label = f"Input 1: {sources[0]}"
+        ip2.label = f"Input 2: {sources[1]}"
+
+        self.join_params = {
+            "how": _join_how(ip1.required.value, ip2.required.value),
+            "left_index": ip1.join_on.is_index(),
+            "right_index": ip2.join_on.is_index(),
+            "left_on": None if ip1.join_on.is_index() else ip1.join_on.value,
+            "right_on": None if ip2.join_on.is_index() else ip2.join_on.value,
+        }
+        self.input_columns_1 = {}
+        self.input_columns_2 = {}
+
+    def process_dataframes(self, dataframe1: pd.DataFrame, dataframe2: pd.DataFrame, logger: Logger) -> pd.DataFrame:
+
+        # update column choices on the inputs; these won't propagate back in
+        # multiprocess runs, but they do work in preview mode (single thread).
+        self.input_columns_1.update(get_all_columns(dataframe1))
+        self.input_columns_2.update(get_all_columns(dataframe2))
 
-    def join_dataframes(self, dataframe1: pd.DataFrame, dataframe2: pd.DataFrame, join_params) -> pd.DataFrame:
         # "left_on" and "right_on" don't seem to mind if the column
         # is an index, but don't seem to work correctly if the column
         # is part of a multiindex: the other multiindex columns go missing.
-        join1 = join_params.get("left_on")
+        join1 = self.join_params.get("left_on")
         if join1 and dataframe1.index.name != join1:
             drop_index = dataframe1.index.name is None and dataframe1.index.names[0] is None
             dataframe1 = dataframe1.reset_index(drop=drop_index)
 
-        join2 = join_params.get("right_on")
+        join2 = self.join_params.get("right_on")
         if join2 and dataframe2.index.name != join2:
             drop_index = dataframe2.index.name is None and dataframe2.index.names[0] is None
             dataframe2 = dataframe2.reset_index(drop=drop_index)
 
-        dataframe = dataframe1.merge(dataframe2, **join_params)
+        dataframe = dataframe1.merge(dataframe2, **self.join_params)
 
         if self.parameters["inputs"][0]["drop"].value and join1 in dataframe.columns:
             dataframe.drop(columns=join1, inplace=True)
@@ -65,36 +93,9 @@ def join_dataframes(self, dataframe1: pd.DataFrame, dataframe2: pd.DataFrame, jo
 
         return dataframe
 
-    def process_inputs(
-        self, inputs: Mapping[str, Iterable[pd.DataFrame]], logger: Logger, row_limit: Optional[int]
-    ) -> Iterable[pd.DataFrame]:
-        try:
-            input1, input2 = inputs.values()
-        except ValueError:
-            raise NotImplementedError("Only two-way joins implemented at this time")  # pylint: disable=raise-missing-from
-
-        inputs_param = self.parameters["inputs"]
-        assert isinstance(inputs_param, ArrayParam)
-        assert len(inputs_param) == 2
-        ip1, ip2 = inputs_param
-
-        join_params = {
-            "how": _join_how(ip1.required.value, ip2.required.value),
-            "left_index": ip1.join_on.is_index(),
-            "right_index": ip2.join_on.is_index(),
-            "left_on": None if ip1.join_on.is_index() else ip1.join_on.value,
-            "right_on": None if ip2.join_on.is_index() else ip2.join_on.value,
-        }
-
-        input_columns_1 = {}
-        input_columns_2 = {}
-
-        for df_in1, df_in2 in product(input1, input2):
-            input_columns_1.update(get_all_columns(df_in1))
-            input_columns_2.update(get_all_columns(df_in2))
-            df_out = self.join_dataframes(df_in1, df_in2, join_params)
-            if len(df_out):
-                yield df_out
-
-        ip1.set_column_choices(input_columns_1.keys())
-        ip2.set_column_choices(input_columns_2.keys())
+    def finalize(self, logger: Logger) -> None:
+        assert isinstance(self.parameters["inputs"], ArrayParam)
+        assert len(self.parameters["inputs"]) == 2
+        ip1, ip2 = self.parameters["inputs"]
+        ip1.set_column_choices(self.input_columns_1.keys())
+        ip2.set_column_choices(self.input_columns_2.keys())
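
The reason JoinPlugin moves onto PandasProductPlugin: each input now
arrives as a stream of dataframe chunks, and a correct chunkwise join has
to visit every (left chunk, right chunk) pair, which the old
moore_itertools.product loop did by hand. A self-contained sketch of that
pairing using only pandas and the stdlib:

    from itertools import product

    import pandas as pd

    left = [pd.DataFrame({"k": [1], "a": ["x"]}), pd.DataFrame({"k": [2], "a": ["y"]})]
    right = [pd.DataFrame({"k": [1, 2], "b": ["p", "q"]})]

    # every left chunk must meet every right chunk, or matches are lost
    joined = pd.concat(l.merge(r, on="k") for l, r in product(left, right))
    assert sorted(joined["b"]) == ["p", "q"]
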
diff --git a/countess/plugins/python.py b/countess/plugins/python.py
index 8431012..bc7b2e2 100644
--- a/countess/plugins/python.py
+++ b/countess/plugins/python.py
@@ -14,6 +14,9 @@
 SIMPLE_TYPES = set((bool, int, float, str, tuple, list))
 
 
+# XXX should probably actually be based on
+# PandasTransformDictToDictPlugin
+# which is a bit more efficient.
 class PythonPlugin(PandasTransformRowToDictPlugin):
     name = "Python Code"
@@ -37,6 +40,10 @@ def process_row(self, row: pd.Series, logger: Logger):
         return dict((k, v) for k, v in row_dict.items() if type(v) in SIMPLE_TYPES)
 
     def process_dataframe(self, dataframe: pd.DataFrame, logger: Logger) -> pd.DataFrame:
+        """Override the parent class because (a) we want to reset the
+        indexes so we can use their values easily, and (b) we don't
+        need to merge afterwards."""
+
         dataframe = dataframe.reset_index(drop=False)
         series = self.dataframe_to_series(dataframe, logger)
         dataframe = self.series_to_dataframe(series)
diff --git a/countess/plugins/regex.py b/countess/plugins/regex.py
index d2786f6..888e5f4 100644
--- a/countess/plugins/regex.py
+++ b/countess/plugins/regex.py
@@ -46,6 +46,9 @@ class RegexToolPlugin(PandasTransformSingleToTuplePlugin):
 
     compiled_re = None
 
+    def prepare(self, sources: list[str]):
+        self.compiled_re = re.compile(self.parameters["regex"].value)
+
     def process_dataframe(self, dataframe: pd.DataFrame, logger: Logger) -> pd.DataFrame:
         df = super().process_dataframe(dataframe, logger)
@@ -66,15 +70,6 @@ def process_dataframe(self, dataframe: pd.DataFrame, logger: Logger) -> pd.DataF
 
         return df
 
-    def process_inputs(
-        self, inputs: Mapping[str, Iterable[pd.DataFrame]], logger: Logger, row_limit: Optional[int]
-    ) -> Iterable[pd.DataFrame]:
-        self.compiled_re = re.compile(self.parameters["regex"].value)
-        while self.compiled_re.groups > len(self.parameters["output"].params):
-            self.parameters["output"].add_row()
-
-        return super().process_inputs(inputs, logger, row_limit)
-
     def process_value(self, value: str, logger: Logger) -> Tuple[str]:
         if value is not None:
             try:
diff --git a/pyproject.toml b/pyproject.toml
index 0ea99ea..e348fc6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,7 +22,6 @@ classifiers = [
 dependencies = [
     'fqfa~=1.2.3',
     'more_itertools~=9.1.0',
-    'moore-itertools',
     'numpy~=1.24.2',
     'pandas~=2.0.0',
     'rapidfuzz~=2.15.1',
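
Taken together, these changes converge the plugins on a single lifecycle
driven by PipelineNode.execute(): prepare() once with the source names,
process() per chunk, finished() per exhausted source, and finalize() at
the end. Roughly, as a sketch of the driver loop (the interleaved()
helper is hypothetical, and collect() and error handling are omitted):

    def run_plugin(plugin, parents, logger):
        plugin.prepare([p.name for p in parents])
        result = []
        # chunks arrive round-robin across parents, tagged with their source
        for source, chunk in interleaved(parents):  # hypothetical helper
            result.extend(plugin.process(chunk, source, logger))
        for p in parents:
            result.extend(plugin.finished(p.name, logger) or [])
        result.extend(plugin.finalize(logger) or [])
        return result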