Commit

Merge branch 'main' into nick/columnorstringparam

nickzoic committed Apr 5, 2024
2 parents b963325 + 5f617f8 commit 1c32a04
Showing 66 changed files with 1,309 additions and 305 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -1,4 +1,4 @@
# CountESS 0.0.51
# CountESS 0.0.57

This is CountESS, a modular, Python 3 reimplementation of Enrich2.

@@ -10,7 +10,8 @@ Source code is available at [https://github.com/CountESS-Project/CountESS](https

## Installing

The latest version of CountESS can be installed from pypi:
The latest version of
[CountESS can be installed from pypi](https://pypi.org/project/countess/):
```
pip install CountESS
```
2 changes: 1 addition & 1 deletion countess/__init__.py
@@ -1,3 +1,3 @@
"""CountESS Project"""

VERSION = "0.0.51"
VERSION = "0.0.57"
148 changes: 81 additions & 67 deletions countess/core/config.py
@@ -1,4 +1,5 @@
import ast
import io
import os.path
import re
import sys
@@ -17,6 +18,45 @@ def default_output_callback(output):
sys.stderr.write(repr(output))


def read_config_dict(name: str, base_dir: str, config_dict: dict) -> PipelineNode:
if "_module" in config_dict:
module_name = config_dict["_module"]
class_name = config_dict["_class"]
# XXX version = config_dict.get("_version")
# XXX hash_digest = config_dict.get("_hash")
plugin = load_plugin(module_name, class_name)
else:
plugin = None

position_str = config_dict.get("_position")
notes = config_dict.get("_notes")

position = None
if position_str:
position_match = re.match(r"(\d+) (\d+)$", position_str)
if position_match:
position = (
int(position_match.group(1)) / 1000,
int(position_match.group(2)) / 1000,
)

sort = config_dict.get("_sort", "0 0").split()

# XXX check version and hash_digest and emit warnings.

config = [(key, ast.literal_eval(val), base_dir) for key, val in config_dict.items() if not key.startswith("_")]

return PipelineNode(
name=name,
plugin=plugin,
config=config,
position=position,
notes=notes,
sort_column=int(sort[0]),
sort_descending=bool(int(sort[1])),
)


def read_config(
filename: str,
logger: Logger = ConsoleLogger(),
@@ -32,46 +72,14 @@ def read_config(
nodes_by_name: dict[str, PipelineNode] = {}

for section_name in cp.sections():
config_dict = cp[section_name]

if "_module" in config_dict:
module_name = config_dict["_module"]
class_name = config_dict["_class"]
# XXX version = config_dict.get("_version")
# XXX hash_digest = config_dict.get("_hash")
plugin = load_plugin(module_name, class_name)
else:
plugin = None

position_str = config_dict.get("_position")
notes = config_dict.get("_notes")

position = None
if position_str:
position_match = re.match(r"(\d+) (\d+)$", position_str)
if position_match:
position = (
int(position_match.group(1)) / 1000,
int(position_match.group(2)) / 1000,
)

# XXX check version and hash_digest and emit warnings.

config = [(key, ast.literal_eval(val), base_dir) for key, val in config_dict.items() if not key.startswith("_")]

node = PipelineNode(
name=section_name,
plugin=plugin,
config=config,
position=position,
notes=notes,
)
pipeline_graph.nodes.append(node)
config_dict = dict(cp[section_name])
node = read_config_dict(section_name, base_dir, config_dict)

for key, val in config_dict.items():
if key.startswith("_parent."):
node.add_parent(nodes_by_name[val])

pipeline_graph.nodes.append(node)
nodes_by_name[section_name] = node

return pipeline_graph
@@ -80,47 +88,53 @@ def read_config(
def write_config(pipeline_graph: PipelineGraph, filename: str):
"""Write `pipeline_graph`'s configuration out to `filename`"""

pipeline_graph.reset_node_names()

cp = ConfigParser()
base_dir = os.path.dirname(filename)

node_names_seen = set()
for node in pipeline_graph.traverse_nodes():
while node.name in node_names_seen:
num = 0
if match := re.match(r"(.*?)\s+(\d+)$", node.name):
node.name = match.group(1)
num = int(match.group(2))
node.name += f" {num + 1}"
node_names_seen.add(node.name)

cp.add_section(node.name)
if node.plugin:
cp[node.name].update(
{
"_module": node.plugin.__module__,
"_class": node.plugin.__class__.__name__,
"_version": node.plugin.version,
"_hash": node.plugin.hash(),
}
)
if node.position:
xx, yy = node.position
cp[node.name]["_position"] = "%d %d" % (xx * 1000, yy * 1000)
if node.notes:
cp[node.name]["_notes"] = node.notes
for n, parent in enumerate(node.parent_nodes):
cp[node.name][f"_parent.{n}"] = parent.name
if node.config:
for k, v, _ in node.config:
cp[node.name][k] = repr(v)
elif node.plugin:
for k, v in node.plugin.get_parameters(base_dir):
cp[node.name][k] = repr(v)
write_config_node(node, cp, base_dir)

with open(filename, "w", encoding="utf-8") as fh:
cp.write(fh)


def write_config_node_string(node: PipelineNode, base_dir: str = ""):
cp = ConfigParser()
write_config_node(node, cp, base_dir)
buf = io.StringIO()
cp.write(buf)
return buf.getvalue()


def write_config_node(node: PipelineNode, cp: ConfigParser, base_dir: str):
cp.add_section(node.name)
if node.plugin:
cp[node.name].update(
{
"_module": node.plugin.__module__,
"_class": node.plugin.__class__.__name__,
"_version": node.plugin.version,
"_hash": node.plugin.hash(),
"_sort": "%d %d" % (node.sort_column, 1 if node.sort_descending else 0),
}
)
if node.position:
xx, yy = node.position
cp[node.name]["_position"] = "%d %d" % (xx * 1000, yy * 1000)
if node.notes:
cp[node.name]["_notes"] = node.notes
for n, parent in enumerate(node.parent_nodes):
cp[node.name][f"_parent.{n}"] = parent.name
if node.config:
for k, v, _ in node.config:
cp[node.name][k] = repr(v)
elif node.plugin:
for k, v in node.plugin.get_parameters(base_dir):
cp[node.name][k] = repr(v)


def export_config_graphviz(pipeline_graph: PipelineGraph, filename: str):
with open(filename, "w", encoding="utf-8") as fh:
fh.write("digraph {\n")
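
For reference, a minimal sketch of how a node section in a saved `.ini` file maps onto the new `read_config_dict` above. The section name, parameter values and paths are hypothetical placeholders, and no `_module`/`_class` is given so the node is built with `plugin=None`; only the underscore-prefixed keys themselves follow the code in this diff.

```python
# Illustrative sketch only -- section name and values are placeholders;
# the "_"-prefixed keys are the ones handled by read_config_dict() /
# write_config_node() above.
from configparser import ConfigParser

from countess.core.config import read_config_dict

cp = ConfigParser()
cp.read_string("""
[Example Node]
_position = 250 100
_sort = 2 1
_notes = just an example
filename = 'example.csv'
""")

# No _module/_class keys, so the node is created with plugin=None.
node = read_config_dict("Example Node", base_dir=".", config_dict=dict(cp["Example Node"]))

print(node.position)         # (0.25, 0.1)
print(node.sort_column)      # 2
print(node.sort_descending)  # True
print(node.config)           # [('filename', 'example.csv', '.')]
```

Going the other way, the new `write_config_node_string(node)` helper serialises a single node back into this INI form.
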
34 changes: 7 additions & 27 deletions countess/core/parameters.py
@@ -158,29 +158,6 @@ def copy(self) -> "StringCharacterSetParam":
return self.__class__(self.label, self.value, self.read_only, character_set=self.character_set)


def clean_file_types(file_types):
# MacOS in particular is crashy if file_types is not to it's liking.
# This leads to very confusing errors. Better to throw an assertion
# error here than bomb out later. See #27.

# XXX MacOS also doesn't seem to handle multiple file
# extensions, eg: .csv.gz, whereas Linux can.
# So maybe this function could accept those and censor
# them if running MacOS.

assert type(file_types) is list
for ft in file_types:
assert type(ft) in (tuple, list)
assert type(ft[0]) is str
assert type(ft[1]) in (tuple, list, str)
if type(ft[1]) in (tuple, list):
for ext in ft[1]:
assert type(ext) is str
assert ext == "*" or ext == "" or re.match(r"\.\w+$", ext), f"Invalid FileType Extension {ext}"

return file_types


class FileParam(StringParam):
"""A StringParam for holding a filename. Defaults to `read_only` because
it really should be populated from a file dialog or simiar."""
@@ -192,7 +169,7 @@ class FileParam(StringParam):
def __init__(self, label: str, value=None, read_only: bool = True, file_types=None):
super().__init__(label, value, read_only)
if file_types is not None:
self.file_types = clean_file_types(file_types)
self.file_types = file_types

def get_file_hash(self):
if not self.value:
@@ -215,11 +192,14 @@ def get_file_hash(self):

def get_parameters(self, key, base_dir="."):
if self.value:
relpath = os.path.relpath(self.value, base_dir)
if base_dir:
path = os.path.relpath(self.value, base_dir)
else:
path = os.path.abspath(self.value)
else:
relpath = None
path = None

return [(key, relpath)]
return [(key, path)]

def copy(self) -> "FileParam":
return self.__class__(self.label, self.value, self.read_only, file_types=self.file_types)
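
A minimal sketch of the changed `FileParam.get_parameters` behaviour above, assuming CountESS 0.0.57 on a POSIX system and using a hypothetical file path: with a non-empty `base_dir` the filename is serialised relative to it, and with an empty `base_dir` it falls back to an absolute path.

```python
# Sketch of the FileParam.get_parameters() change above; the path is a
# made-up example and does not need to exist on disk.
from countess.core.parameters import FileParam

fp = FileParam("filename", "/data/project/reads.fastq.gz")

# Non-empty base_dir: stored relative to it.
print(fp.get_parameters("filename", base_dir="/data/project"))
# -> [('filename', 'reads.fastq.gz')]

# Empty base_dir: stored as an absolute path.
print(fp.get_parameters("filename", base_dir=""))
# -> [('filename', '/data/project/reads.fastq.gz')]
```
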
39 changes: 34 additions & 5 deletions countess/core/pipeline.py
@@ -1,3 +1,4 @@
import re
import time
from queue import Empty, Queue
from threading import Thread
@@ -59,6 +60,8 @@ class PipelineNode:
name: str
plugin: Optional[BasePlugin] = None
position: Optional[tuple[float, float]] = None
sort_column: int = 0
sort_descending: bool = False
notes: Optional[str] = None
parent_nodes: set["PipelineNode"]
child_nodes: set["PipelineNode"]
@@ -74,11 +77,13 @@ class PipelineNode:
# at config load time, if it is present it is loaded the
# first time the plugin is prerun.

def __init__(self, name, plugin=None, config=None, position=None, notes=None):
def __init__(self, name, plugin=None, config=None, position=None, notes=None, sort_column=0, sort_descending=0):
self.name = name
self.plugin = plugin
self.config = config or []
self.position = position
self.sort_column = sort_column
self.sort_descending = sort_descending
self.notes = notes
self.parent_nodes = set()
self.child_nodes = set()
@@ -182,8 +187,10 @@ def load_config(self, logger: Logger):
self.config = None

def prerun(self, logger: Logger, row_limit=PRERUN_ROW_LIMIT):
if not self.plugin:
return
self.load_config(logger)
if self.is_dirty and self.plugin:
if self.is_dirty:
assert isinstance(self.plugin, (ProcessPlugin, FileInputPlugin))
self.result = []
self.plugin.prepare([node.name for node in self.parent_nodes], row_limit)
@@ -205,9 +212,10 @@ def mark_dirty(self):
child_node.mark_dirty()

def add_parent(self, parent):
self.parent_nodes.add(parent)
parent.child_nodes.add(self)
self.mark_dirty()
if (not self.plugin or self.plugin.num_inputs) and (not parent.plugin or parent.plugin.num_outputs):
self.parent_nodes.add(parent)
parent.child_nodes.add(self)
self.mark_dirty()

def del_parent(self, parent):
self.parent_nodes.discard(parent)
@@ -248,7 +256,17 @@ def __init__(self):
self.plugin_classes = get_plugin_classes()
self.nodes = []

def reset_node_name(self, node):
node_names_seen = set(n.name for n in self.nodes if n != node)
while node.name in node_names_seen:
num = 1
if match := re.match(r"(.*?)\s+(\d+)$", node.name):
node.name = match.group(1)
num = int(match.group(2))
node.name += f" {num + 1}"

def add_node(self, node):
self.reset_node_name(node)
self.nodes.append(node)

def del_node(self, node):
@@ -298,6 +316,17 @@ def reset(self):
node.result = None
node.is_dirty = True

def reset_node_names(self):
node_names_seen = set()
for node in self.traverse_nodes():
while node.name in node_names_seen:
num = 0
if match := re.match(r"(.*?)\s+(\d+)$", node.name):
node.name = match.group(1)
num = int(match.group(2))
node.name += f" {num + 1}"
node_names_seen.add(node.name)

def tidy(self):
"""Tidies the graph (sets all the node positions)"""

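
A small sketch of the new node-name de-duplication above (`PipelineGraph.add_node` now calls `reset_node_name`); the node names are placeholders and no plugins are attached.

```python
# Sketch of reset_node_name() via add_node(): clashing names pick up a
# numeric suffix.
from countess.core.pipeline import PipelineGraph, PipelineNode

graph = PipelineGraph()
graph.add_node(PipelineNode("Counter"))
graph.add_node(PipelineNode("Counter"))  # renamed on add
graph.add_node(PipelineNode("Counter"))  # renamed again

print([node.name for node in graph.nodes])
# -> ['Counter', 'Counter 2', 'Counter 3']
```
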
(Diffs for the remaining changed files are not shown here.)
