
naïve pipeline support and update plugins to new arch
nickzoic committed Jul 25, 2023
1 parent dca810c commit cd6a89a
Showing 8 changed files with 105 additions and 69 deletions.
53 changes: 35 additions & 18 deletions countess/core/pipeline.py
@@ -1,13 +1,15 @@
from dataclasses import dataclass, field
from typing import Any, Optional

from more_itertools import interleave_longest

from queue import Empty
import multiprocessing
import time
import psutil

from countess.core.logger import Logger
from countess.core.plugins import BasePlugin, get_plugin_classes
from countess.core.plugins import BasePlugin, get_plugin_classes, FileInputMixin

PRERUN_ROW_LIMIT = 100000

@@ -57,7 +59,6 @@ def __next__(self):
raise StopIteration



@dataclass
class PipelineNode:
name: str
@@ -85,22 +86,38 @@ def is_descendant_of(self, node):

def execute(self, logger: Logger, row_limit: Optional[int] = None):
assert row_limit is None or isinstance(row_limit, int)
inputs = {p.name: p.result for p in self.parent_nodes}
self.result = []
if self.plugin:
try:
self.result = self.plugin.process_inputs(inputs, logger, row_limit)
if isinstance(self.result, (bytes, str)):
pass
elif row_limit is not None:
self.result = list(self.result)
#if not isinstance(self.result, (bytes, str)) and (row_limit is not None or len(self.child_nodes) > 1):
# XXX freeze to handle fan-out and reloading.
# self.result = list(self.result)
elif hasattr(self.result, '__iter__'):
self.result = MultiprocessingProxy(self.result, self.name)
except Exception as exc: # pylint: disable=broad-exception-caught
logger.exception(exc)

if self.plugin is None:
self.result = []
elif self.result and not self.is_dirty:
return
elif isinstance(self.plugin, FileInputMixin):
num_files = self.plugin.num_files()
gg = [
self.plugin.load_file(n, logger, row_limit // num_files)
for n in range(0, num_files)
]
self.result = list(interleave_longest(*gg))
else:
self.plugin.prepare([p.name for p in self.parent_nodes])
input_source_data = interleave_longest(*(
[ (p.name, r) for r in p.result ]
for p in self.parent_nodes
))
self.result = list(self.plugin.collect(
d
for source, data in input_source_data
for d in self.plugin.process(data, source, logger)
))
for p in self.parent_nodes:
r = self.plugin.finished(p.name, logger)
if r:
self.result += list(r)
r = self.plugin.finalize(logger)
if r:
self.result += list(r)

self.is_dirty = False

def load_config(self, logger: Logger):
assert isinstance(self.plugin, BasePlugin)
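The rewritten execute() above drives a four-step plugin lifecycle: prepare() gets the parent node names, process() is called per value along with its source's name, finished() fires once per parent, and finalize() runs last; interleave_longest() round-robins values across parents, skipping any that run out early. A minimal sketch of a plugin satisfying that contract — the class is hypothetical, and collect() is assumed to be a straight pass-through, since BasePlugin's version is not part of this diff:

from typing import Any, Iterable, Optional

from countess.core.logger import Logger

class PassThroughPlugin:
    """Hypothetical plugin showing the hooks PipelineNode.execute() calls."""

    def prepare(self, sources: list[str]) -> None:
        # Called once, before any data, with the names of all parent nodes.
        self.sources = sources

    def collect(self, data: Iterable[Any]) -> Iterable[Any]:
        # Assumed to aggregate processed values; here, a pass-through.
        return data

    def process(self, data: Any, source: str, logger: Logger) -> Iterable[Any]:
        # Called for each value interleaved from the parents' results.
        yield data

    def finished(self, source: str, logger: Logger) -> Optional[Iterable[Any]]:
        # Called once per parent after all values are processed.
        return None

    def finalize(self, logger: Logger) -> Optional[Iterable[Any]]:
        # Called last; may return any remaining buffered output.
        return None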
7 changes: 7 additions & 0 deletions countess/plugins/csv.py
@@ -143,6 +143,13 @@ def read_file_to_dataframe(self, file_params, logger, row_limit=None):

return df

def num_files(self):
return len(self.parameters["files"])

def load_file(self, file_number: int, logger: Logger, row_limit: Optional[int] = None) -> Iterable:
file_params = self.parameters["files"][file_number]
yield self.read_file_to_dataframe(file_params, logger, row_limit)


class SaveCsvPlugin(PandasOutputPlugin):
name = "CSV Save"
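These two methods give the CSV loader the per-file protocol that execute() now uses for FileInputMixin plugins (the FASTQ plugin below adds the identical pair). A sketch of how the pipeline drives them, assuming plugin, logger and an integer row_limit are in scope:

from more_itertools import interleave_longest

# Split the row limit evenly across files, then round-robin the per-file
# generators so a limited preview samples every file, not just the first.
num_files = plugin.num_files()
generators = [
    plugin.load_file(n, logger, row_limit // num_files)
    for n in range(num_files)
]
result = list(interleave_longest(*generators))

# interleave_longest() skips iterables as they run out, e.g.:
# list(interleave_longest([1, 4], [2, 5, 6], [3])) == [1, 2, 3, 4, 5, 6]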
1 change: 1 addition & 0 deletions countess/plugins/expression.py
@@ -10,6 +10,7 @@ def process(df: pd.DataFrame, codes, logger: Logger):
for code in codes:
if not code:
continue

try:
result = df.eval(code)
except Exception as exc: # pylint: disable=W0718
11 changes: 10 additions & 1 deletion countess/plugins/fastq.py
@@ -1,13 +1,14 @@
import gzip
from itertools import islice
from typing import Optional, Iterable

import pandas as pd
from fqfa.fastq.fastq import parse_fastq_reads # type: ignore

from countess import VERSION
from countess.core.parameters import BooleanParam, FloatParam
from countess.core.plugins import PandasInputPlugin

from countess.core.logger import Logger

def _file_reader(file_handle, min_avg_quality, row_limit=None):
for fastq_read in islice(parse_fastq_reads(file_handle), 0, row_limit):
@@ -54,3 +55,11 @@ def read_file_to_dataframe(self, file_params, logger, row_limit=None):
dataframe = dataframe.groupby("sequence").count()

return dataframe

def num_files(self):
return len(self.parameters["files"])

def load_file(self, file_number: int, logger: Logger, row_limit: Optional[int] = None) -> Iterable:
file_params = self.parameters["files"][file_number]
yield self.read_file_to_dataframe(file_params, logger, row_limit)

81 changes: 41 additions & 40 deletions countess/plugins/join.py
@@ -2,12 +2,11 @@
from typing import Iterable, Optional, Union

import pandas as pd
from moore_itertools import product

from countess import VERSION
from countess.core.logger import Logger
from countess.core.parameters import ArrayParam, BooleanParam, ColumnOrIndexChoiceParam, MultiParam
from countess.core.plugins import PandasBasePlugin
from countess.core.plugins import PandasBasePlugin, PandasProductPlugin
from countess.utils.pandas import get_all_columns


@@ -18,7 +17,7 @@ def _join_how(left_required: bool, right_required: bool) -> str:
return "right" if right_required else "outer"


class JoinPlugin(PandasBasePlugin):
class JoinPlugin(PandasProductPlugin):
"""Joins Pandas Dataframes"""

name = "Join"
@@ -41,22 +40,51 @@ class JoinPlugin(PandasBasePlugin):
max_size=2,
),
}
join_params = None
input_columns_1 = None
input_columns_2 = None

def prepare(self, sources: list[str]):
super().prepare(sources)

assert isinstance(self.parameters["inputs"], ArrayParam)
assert len(self.parameters["inputs"]) == 2
ip1, ip2 = self.parameters["inputs"]
ip1.label = f"Input 1: {sources[0]}"
ip2.label = f"Input 2: {sources[1]}"

self.join_params = {
"how": _join_how(ip1.required.value, ip2.required.value),
"left_index": ip1.join_on.is_index(),
"right_index": ip2.join_on.is_index(),
"left_on": None if ip1.join_on.is_index() else ip1.join_on.value,
"right_on": None if ip2.join_on.is_index() else ip2.join_on.value,
}
self.input_columns_1 = {}
self.input_columns_2 = {}

def process_dataframes(self, dataframe1: pd.DataFrame, dataframe2: pd.DataFrame, logger: Logger) -> pd.DataFrame:

# Update columns on inputs. These changes won't propagate back in multiprocess runs,
# but they do work in preview mode, where we only run this in a single thread.

self.input_columns_1.update(get_all_columns(dataframe1))
self.input_columns_2.update(get_all_columns(dataframe2))

def join_dataframes(self, dataframe1: pd.DataFrame, dataframe2: pd.DataFrame, join_params) -> pd.DataFrame:
# "left_on" and "right_on" don't seem to mind if the column
# is an index, but don't seem to work correctly if the column
# is part of a multiindex: the other multiindex columns go missing.
join1 = join_params.get("left_on")
join1 = self.join_params.get("left_on")
if join1 and dataframe1.index.name != join1:
drop_index = dataframe1.index.name is None and dataframe1.index.names[0] is None
dataframe1 = dataframe1.reset_index(drop=drop_index)

join2 = join_params.get("right_on")
join2 = self.join_params.get("right_on")
if join2 and dataframe2.index.name != join2:
drop_index = dataframe2.index.name is None and dataframe2.index.names[0] is None
dataframe2 = dataframe2.reset_index(drop=drop_index)

dataframe = dataframe1.merge(dataframe2, **join_params)
dataframe = dataframe1.merge(dataframe2, **self.join_params)

if self.parameters["inputs"][0]["drop"].value and join1 in dataframe.columns:
dataframe.drop(columns=join1, inplace=True)
@@ -65,36 +93,9 @@ def join_dataframes(self, dataframe1: pd.DataFrame, dataframe2: pd.DataFrame, join_params) -> pd.DataFrame:

return dataframe

def process_inputs(
self, inputs: Mapping[str, Iterable[pd.DataFrame]], logger: Logger, row_limit: Optional[int]
) -> Iterable[pd.DataFrame]:
try:
input1, input2 = inputs.values()
except ValueError:
raise NotImplementedError("Only two-way joins implemented at this time") # pylint: disable=raise-missing-from

inputs_param = self.parameters["inputs"]
assert isinstance(inputs_param, ArrayParam)
assert len(inputs_param) == 2
ip1, ip2 = inputs_param

join_params = {
"how": _join_how(ip1.required.value, ip2.required.value),
"left_index": ip1.join_on.is_index(),
"right_index": ip2.join_on.is_index(),
"left_on": None if ip1.join_on.is_index() else ip1.join_on.value,
"right_on": None if ip2.join_on.is_index() else ip2.join_on.value,
}

input_columns_1 = {}
input_columns_2 = {}

for df_in1, df_in2 in product(input1, input2):
input_columns_1.update(get_all_columns(df_in1))
input_columns_2.update(get_all_columns(df_in2))
df_out = self.join_dataframes(df_in1, df_in2, join_params)
if len(df_out):
yield df_out

ip1.set_column_choices(input_columns_1.keys())
ip2.set_column_choices(input_columns_2.keys())
def finalize(self, logger: Logger) -> None:
assert isinstance(self.parameters["inputs"], ArrayParam)
assert len(self.parameters["inputs"]) == 2
ip1, ip2 = self.parameters["inputs"]
ip1.set_column_choices(self.input_columns_1.keys())
ip2.set_column_choices(self.input_columns_2.keys())
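JoinPlugin's streaming glue moves into PandasProductPlugin, whose implementation is not part of this diff. From the hooks used here, it is presumably what pairs every dataframe from one input with every dataframe from the other and calls process_dataframes() on each pair (which is in turn assumed to finish by returning self.join_dataframes(...), though the diff truncates its body). A rough sketch of that assumed behaviour, with a hypothetical driver function:

from itertools import product

def drive_product(plugin, dfs1, dfs2, logger):
    # Cross product of the two input streams: join every pairing and
    # drop empty results. Note that itertools.product buffers its
    # inputs, whereas the removed moore_itertools version presumably
    # streamed them.
    for df1, df2 in product(dfs1, dfs2):
        df_out = plugin.process_dataframes(df1, df2, logger)
        if df_out is not None and len(df_out):
            yield df_out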
7 changes: 7 additions & 0 deletions countess/plugins/python.py
@@ -14,6 +14,9 @@

SIMPLE_TYPES = set((bool, int, float, str, tuple, list))

# XXX should probably actually be based on
# PandasTransformDictToDictPlugin
# which is a bit more efficient.

class PythonPlugin(PandasTransformRowToDictPlugin):
name = "Python Code"
@@ -37,6 +40,10 @@ def process_row(self, row: pd.Series, logger: Logger):
return dict((k, v) for k, v in row_dict.items() if type(v) in SIMPLE_TYPES)

def process_dataframe(self, dataframe: pd.DataFrame, logger: Logger) -> pd.DataFrame:
"""Override parent class because we a) want to reset
the indexes so we can use their values easily and
b) we don't need to merge afterwards"""

dataframe = dataframe.reset_index(drop=False)
series = self.dataframe_to_series(dataframe, logger)
dataframe = self.series_to_dataframe(series)
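The docstring above is the whole story: resetting to a plain RangeIndex first turns the old index levels into ordinary columns, so the user's row-level code can read them. A toy illustration of reset_index(drop=False):

import pandas as pd

df = pd.DataFrame({"count": [1, 2]},
                  index=pd.Index(["a", "b"], name="sequence"))
df = df.reset_index(drop=False)
# df now has columns ["sequence", "count"], so per-row code can read
# row["sequence"] like any other value.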
13 changes: 4 additions & 9 deletions countess/plugins/regex.py
@@ -46,6 +46,10 @@ class RegexToolPlugin(PandasTransformSingleToTuplePlugin):

compiled_re = None

def prepare(self, sources: list[str]):

self.compiled_re = re.compile(self.parameters["regex"].value)

def process_dataframe(self, dataframe: pd.DataFrame, logger: Logger) -> pd.DataFrame:
df = super().process_dataframe(dataframe, logger)

@@ -66,15 +70,6 @@ def process_dataframe(self, dataframe: pd.DataFrame, logger: Logger) -> pd.DataFrame:

return df

def process_inputs(
self, inputs: Mapping[str, Iterable[pd.DataFrame]], logger: Logger, row_limit: Optional[int]
) -> Iterable[pd.DataFrame]:
self.compiled_re = re.compile(self.parameters["regex"].value)
while self.compiled_re.groups > len(self.parameters["output"].params):
self.parameters["output"].add_row()

return super().process_inputs(inputs, logger, row_limit)

def process_value(self, value: str, logger: Logger) -> Tuple[str]:
if value is not None:
try:
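Compiling the pattern once in prepare() keeps the per-value path cheap; note that the removed process_inputs() also grew the output parameter list to match compiled_re.groups, and that resizing is not re-added here. For reference, groups on a compiled pattern counts its capture groups:

import re

compiled_re = re.compile(r"(\w+)=(\d+)")
assert compiled_re.groups == 2      # one capture group per output column
assert compiled_re.match("x=17").groups() == ("x", "17")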
1 change: 0 additions & 1 deletion pyproject.toml
@@ -22,7 +22,6 @@ classifiers = [
dependencies = [
'fqfa~=1.2.3',
'more_itertools~=9.1.0',
'moore-itertools',
'numpy~=1.24.2',
'pandas~=2.0.0',
'rapidfuzz~=2.15.1',
