
code cleanup & passing tests
nickzoic committed Aug 7, 2023
1 parent 8aed68d commit fe042d4
Showing 13 changed files with 131 additions and 228 deletions.
11 changes: 5 additions & 6 deletions countess/core/pipeline.py
@@ -5,7 +5,7 @@
 from typing import Any, Iterable, Optional

 from countess.core.logger import Logger
-from countess.core.plugins import BasePlugin, FileInputMixin, get_plugin_classes
+from countess.core.plugins import BasePlugin, FileInputPlugin, ProcessPlugin, get_plugin_classes

 PRERUN_ROW_LIMIT = 100000

@@ -92,16 +92,15 @@ def execute(self, logger: Logger, row_limit: Optional[int] = None):
             return
         elif self.result and not self.is_dirty:
             return
-        elif isinstance(self.plugin, FileInputMixin):
+        elif isinstance(self.plugin, FileInputPlugin):
             num_files = self.plugin.num_files()
             row_limit_each_file = row_limit // num_files if row_limit is not None else None
             self.result = multi_iterator_map(self.plugin.load_file, range(0, num_files), args=(logger, row_limit_each_file))
-        else:
+        elif isinstance(self.plugin, ProcessPlugin):
             self.plugin.prepare([p.name for p in self.parent_nodes], row_limit)
-            self.result = self.process_parent_iterables(logger)
+            self.result = self.plugin.collect(self.process_parent_iterables(logger))

-        self.result = self.plugin.collect(self.result)
-        if len(self.child_nodes) != 1:
+        if row_limit is not None or len(self.child_nodes) != 1:
             self.result = list(self.result)

         self.is_dirty = False
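The changed materialisation condition matters because node results are lazy iterables: a generator can be consumed only once, so fan-out to several child nodes (or a preview re-run under a row limit) needs a list. A minimal sketch, independent of CountESS:

def rows():
    yield from range(3)

stream = rows()
assert list(stream) == [0, 1, 2]
assert list(stream) == []                # a second consumer sees nothing

materialized = list(rows())
assert list(materialized) == [0, 1, 2]
assert list(materialized) == [0, 1, 2]   # safe to hand to multiple children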
52 changes: 28 additions & 24 deletions countess/core/plugins.py
@@ -137,6 +137,10 @@ def hash(self):
"""Returns a hex digest of the hash of all configuration parameters"""
return self.get_parameter_hash().hexdigest()


class ProcessPlugin(BasePlugin):
"""A plugin which accepts data from one or more sources"""

def prepare(self, sources: List[str], row_limit: Optional[int]):
pass

@@ -162,7 +166,7 @@ def collect(self, data: Iterable) -> Iterable:
         return data


-class FileInputMixin:
+class FileInputPlugin(BasePlugin):
     """Mixin class to indicate that this plugin can read files from local
     storage."""

@@ -186,8 +190,8 @@ def load_file(self, file_number: int, logger: Logger, row_limit: Optional[int] =
         raise NotImplementedError("FileInputMixin.load_file")


-class PandasBasePlugin(BasePlugin):
-    DATAFRAME_BUFFER_SIZE = 1000000
+class PandasProcessPlugin(ProcessPlugin):
+    DATAFRAME_BUFFER_SIZE = 100000

     def process(self, data: pd.DataFrame, source: str, logger: Logger) -> Iterable[pd.DataFrame]:
         raise NotImplementedError(f"{self.__class__}.process")
@@ -211,7 +215,7 @@ def collect(self, data: Iterable[pd.DataFrame]) -> Iterable[pd.DataFrame]:
             yield buffer


-class PandasSimplePlugin(PandasBasePlugin):
+class PandasSimplePlugin(PandasProcessPlugin):
     """Base class for plugins which accept and return pandas DataFrames.
     Subclassing this hides all the distracting aspects of the pipeline
     from the plugin implementor, who only needs to override process_dataframe"""
@@ -232,12 +236,12 @@ def process(self, data: pd.DataFrame, source: str, logger: Logger) -> Iterable[pd.DataFrame]:
         assert isinstance(result, pd.DataFrame)
         yield result

-    def process_dataframe(self, dataframe: pd.DataFrame, logger: Logger) -> pd.DataFrame:
+    def process_dataframe(self, dataframe: pd.DataFrame, logger: Logger) -> Optional[pd.DataFrame]:
         """Override this to process a single dataframe"""
         raise NotImplementedError(f"{self.__class__}.process_dataframe()")


-class PandasProductPlugin(PandasBasePlugin):
+class PandasProductPlugin(PandasProcessPlugin):
     """Some plugins need to have all the data from two sources presented to them,
     which is tricky in a pipelined environment. This superclass handles the two
     input sources and calls .process_dataframes with pairs of dataframes.
@@ -391,10 +395,8 @@ def dataframe_to_series(self, dataframe: pd.DataFrame, logger: Logger) -> pd.Series:
             self.process_raw,
             axis=1,
             raw=True,
-            kwargs={
-                "columns": list(dataframe.columns),
-                "logger": logger,
-            },
+            columns=list(dataframe.columns),
+            logger=logger,
         )

     def process_dict(self, data, logger: Logger):
@@ -505,7 +507,7 @@ def process_dict(self, data, logger: Logger):
         raise NotImplementedError(f"{self.__class__}.process_dict()")


-class PandasInputPlugin(FileInputMixin,BasePlugin):
+class PandasInputPlugin(FileInputPlugin):
     """A specialization of the PandasBasePlugin to allow it to follow nothing,
     eg: come first."""

@@ -522,23 +524,25 @@ def __init__(self, *a, **k):
     def num_files(self):
         return len(self.parameters["files"].params)

-    def load_files(self, file_number: int, logger: Logger, row_limit: Optional[int] = None) -> Iterable[pd.DataFrame]:
-        assert isinstance(self.parameters["files"], ArrayParam)
-        fps = self.parameters["files"].params
-        num_files = len(fps)
-        per_file_row_limit = int(row_limit / len(fps) + 1) if row_limit else None
-        logger.progress("Loading", 0)
-        for num, fp in enumerate(fps):
-            assert isinstance(fp, MultiParam)
-            yield self.read_file_to_dataframe(fp, logger, per_file_row_limit)
-            logger.progress("Loading", 100 * (num + 1) // (num_files + 1))
-        logger.progress("Done", 100)
-
-    def read_file_to_dataframe(self, file_params, logger, row_limit=None) -> pd.DataFrame:
-        raise NotImplementedError(f"{self.__class__}.read_file_to_dataframe")
+    def load_file(self, file_number: int, logger: Logger, row_limit: Optional[int] = None) -> Iterable[pd.DataFrame]:
+        raise NotImplementedError(f"{self.__class__}.load_file()")
+
+
+class PandasInputFilesPlugin(PandasInputPlugin):
+
+    def num_files(self):
+        return len(self.parameters["files"])
+
+    def load_file(self, file_number: int, logger: Logger, row_limit: Optional[int] = None) -> Iterable[pd.DataFrame]:
+        assert isinstance(self.parameters["files"], ArrayParam)
+        file_params = self.parameters["files"][file_number]
+        yield self.read_file_to_dataframe(file_params, logger, row_limit)
+
+    def read_file_to_dataframe(self, file_params: MultiParam, logger: Logger, row_limit: Optional[int] = None) -> pd.DataFrame:
+        raise NotImplementedError(f"Implement {self.__class__.__name__}.read_file_to_dataframe")


-class PandasOutputPlugin(PandasBasePlugin):
+class PandasOutputPlugin(PandasProcessPlugin):
     def process_inputs(self, inputs: Mapping[str, Iterable[pd.DataFrame]], logger: Logger, row_limit: Optional[int]):
         iterators = set(iter(input) for input in inputs.values())
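Under the new hierarchy a file-based loader only has to override read_file_to_dataframe(); num_files() and load_file() are inherited from PandasInputFilesPlugin. A hedged sketch of such a subclass (the TsvLoadPlugin name and the "filename" sub-parameter are illustrative assumptions, not part of this commit):

import pandas as pd

from countess.core.logger import Logger
from countess.core.plugins import PandasInputFilesPlugin


class TsvLoadPlugin(PandasInputFilesPlugin):
    name = "TSV Load (example)"

    def read_file_to_dataframe(self, file_params, logger: Logger, row_limit=None) -> pd.DataFrame:
        # file_params is one MultiParam out of the "files" ArrayParam;
        # assumes a "filename" sub-parameter as in the bundled loaders
        return pd.read_csv(file_params["filename"].value, sep="\t", nrows=row_limit)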
36 changes: 27 additions & 9 deletions countess/plugins/correlation.py
@@ -1,12 +1,14 @@
+from typing import Optional, Iterable
+
 import pandas as pd

 from countess import VERSION
 from countess.core.logger import Logger
 from countess.core.parameters import ChoiceParam, ColumnChoiceParam, ColumnOrNoneChoiceParam
-from countess.core.plugins import PandasTransformPlugin
+from countess.core.plugins import PandasSimplePlugin


-class CorrelationPlugin(PandasTransformPlugin):
+class CorrelationPlugin(PandasSimplePlugin):
     """Correlations"""

     name = "Correlation Tool"
@@ -20,17 +22,33 @@ class CorrelationPlugin(PandasTransformPlugin):
         "column1": ColumnChoiceParam("Column 1"),
         "column2": ColumnChoiceParam("Column 2"),
     }
+    columns : list[str] = []
+    dataframes : list[pd.DataFrame] = []

-    def run_df(self, df: pd.DataFrame, logger: Logger) -> pd.DataFrame:
+    def prepare(self, sources: list[str], row_limit: Optional[int]):
         assert isinstance(self.parameters["group"], ColumnOrNoneChoiceParam)
         column1 = self.parameters["column1"].value
         column2 = self.parameters["column2"].value
+        self.columns = [ column1, column2 ]
+        if not self.parameters["group"].is_none():
+            self.columns.append(self.parameters["group"].value)
+        self.dataframes = []
+
+    def process_dataframe(self, dataframe: pd.DataFrame, logger: Logger) -> None:
+        self.dataframes.append(dataframe[self.columns])
+
+    def finalize(self, logger: Logger) -> Iterable[pd.DataFrame]:
+        assert isinstance(self.parameters["group"], ColumnOrNoneChoiceParam)
+        column1 = self.parameters["column1"].value
+        column2 = self.parameters["column2"].value
+        groupby = None if self.parameters["group"].is_none() else self.parameters["group"].value

         method = self.parameters["method"].value

-        if self.parameters["group"].is_none():
-            corr = df[column1].corr(df[column2], method=method)
-            return pd.DataFrame([[corr]], columns=["correlation"])
+        dataframe = pd.concat(self.dataframes)
+        if groupby:
+            ds = dataframe.groupby(groupby)[column1]
+            yield ds.corr(dataframe[column2], method=method).to_frame(name="correlation")
         else:
-            ds = df.groupby(self.parameters["group"].value)[column1]
-            return ds.corr(df[column2], method=method).to_frame(name="correlation")
+            corr = dataframe[column1].corr(dataframe[column2], method=method)
+            yield pd.DataFrame([[corr]], columns=["correlation"])
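The rewritten finalize() reduces to two pandas idioms, shown standalone with made-up data (not CountESS code):

import pandas as pd

df = pd.concat([
    pd.DataFrame({"grp": ["a", "a"], "x": [1.0, 2.0], "y": [1.1, 2.1]}),
    pd.DataFrame({"grp": ["b", "b"], "x": [3.0, 4.0], "y": [4.2, 3.1]}),
])

# ungrouped: a single coefficient wrapped in a one-cell DataFrame
corr = df["x"].corr(df["y"], method="pearson")
print(pd.DataFrame([[corr]], columns=["correlation"]))

# grouped: one coefficient per group, as the plugin yields it
ds = df.groupby("grp")["x"]
print(ds.corr(df["y"], method="pearson").to_frame(name="correlation"))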
30 changes: 11 additions & 19 deletions countess/plugins/csv.py
@@ -1,7 +1,7 @@
 import csv
 import gzip
 from io import BufferedWriter, BytesIO
-from typing import Iterable, Optional, Union
+from typing import Optional, Union

 import pandas as pd

@@ -16,7 +16,7 @@
     MultiParam,
     StringParam,
 )
-from countess.core.plugins import PandasInputPlugin, PandasOutputPlugin
+from countess.core.plugins import PandasInputFilesPlugin, PandasProcessPlugin

 # XXX it would be better to do the same this Regex Tool does and get the user to assign
 # data types to each column
@@ -41,7 +41,7 @@ def clean_row(row):
     return [maybe_number(x) for x in row]


-class LoadCsvPlugin(PandasInputPlugin):
+class LoadCsvPlugin(PandasInputFilesPlugin):
     """Load CSV files"""

     name = "CSV Load"
@@ -143,16 +143,8 @@ def read_file_to_dataframe(self, file_params, logger, row_limit=None):

return df

def num_files(self):
return len(self.parameters["files"])

def load_file(self, file_number: int, logger: Logger, row_limit: Optional[int] = None) -> Iterable:
assert isinstance(self.parameters["files"], ArrayParam)
file_params = self.parameters["files"][file_number]
yield self.read_file_to_dataframe(file_params, logger, row_limit)


class SaveCsvPlugin(PandasOutputPlugin):
class SaveCsvPlugin(PandasProcessPlugin):
name = "CSV Save"
description = "Save data as CSV or similar delimited text files"
link = "https://countess-project.github.io/CountESS/plugins/#csv-writer"
@@ -185,30 +177,30 @@ def prepare(self, sources: list[str], row_limit: Optional[int] = None):

self.csv_columns = None

def process(self, dataframe: pd.DataFrame, source: str, logger: Logger):
def process(self, data: pd.DataFrame, source: str, logger: Logger):
# reset indexes so we can treat all columns equally.
# if there's just a nameless index then we don't care about it, drop it.

drop_index = dataframe.index.name is None and dataframe.index.names[0] is None
dataframe = dataframe.reset_index(drop=drop_index)
drop_index = data.index.name is None and data.index.names[0] is None
dataframe = data.reset_index(drop=drop_index)

# if this is our first dataframe to write then decide whether to
# include the header or not.
if self.csv_columns is None:
self.csv_columns = list(dataframe.columns)
self.csv_columns = list(data.columns)
emit_header = bool(self.parameters["header"].value)
else:
# add in any columns we haven't seen yet in previous dataframes.
for c in dataframe.columns:
for c in data.columns:
if c not in self.csv_columns:
self.csv_columns.append(c)
logger.warning(f"Added CSV Column {repr(c)} with no header")
# fill in blanks for any columns which are in previous dataframes but not
# in this one.
dataframe = dataframe.assign(**{c: None for c in self.csv_columns if c not in dataframe.columns})
dataframe = data.assign(**{c: None for c in self.csv_columns if c not in dataframe.columns})
emit_header = False

dataframe.to_csv(
data.to_csv(
self.filehandle,
header=emit_header,
columns=self.csv_columns,
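SaveCsvPlugin's column handling is easiest to follow in isolation: columns first seen in a later chunk are appended (with a warning, since the header row has already gone out) and columns missing from a chunk are blank-filled. A sketch with hypothetical data:

import pandas as pd

csv_columns = ["a", "b"]                  # columns already written with the header
chunk = pd.DataFrame({"b": [2], "c": [3]})

for c in chunk.columns:
    if c not in csv_columns:
        csv_columns.append(c)             # "c" will get no header in the file
chunk = chunk.assign(**{c: None for c in csv_columns if c not in chunk.columns})
print(chunk[csv_columns])                 # a=None, b=2, c=3, in header order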
16 changes: 9 additions & 7 deletions countess/plugins/data_table.py
@@ -1,14 +1,14 @@
-from typing import Any, Iterable, Mapping, Optional
+from typing import Iterable, Optional

 import pandas as pd

 from countess import VERSION
 from countess.core.logger import Logger
 from countess.core.parameters import ArrayParam, DataTypeChoiceParam, MultiParam, StringParam, TabularMultiParam
-from countess.core.plugins import PandasBasePlugin
+from countess.core.plugins import PandasInputPlugin


-class DataTablePlugin(PandasBasePlugin):
+class DataTablePlugin(PandasInputPlugin):
     """DataTable"""

     name = "DataTable"
@@ -73,11 +73,13 @@ def set_parameter(self, key: str, *a, **k):
         self.fix_columns()
         super().set_parameter(key, *a, **k)

-    def process_inputs(self, inputs: Mapping[str, Iterable[Any]], logger: Logger, row_limit: Optional[int]) -> Iterable[Any]:
-        if len(inputs) > 0:
-            logger.warning(f"{self.name} doesn't take inputs")
-            raise ValueError(f"{self.name} doesn't take inputs")
+    def num_files(self):
+        return 1
+
+    def load_file(self, file_number: int, logger: Logger, row_limit: Optional[int] = None) -> Iterable[pd.DataFrame]:
+        assert file_number == 0
+        assert isinstance(self.parameters["rows"], ArrayParam)
+        assert isinstance(self.parameters["columns"], ArrayParam)
         self.fix_columns()
         values = []
         for row in self.parameters["rows"]:
