Skip to content

Commit

Permalink
Merge branch 'nick/30-multiple-io', at least the tree tidying feature…
Browse files Browse the repository at this point in the history
…s are

useful already
  • Loading branch information
nickzoic committed Mar 27, 2024
2 parents ac506c9 + 8efba82 commit 8e88373
Show file tree
Hide file tree
Showing 8 changed files with 57 additions and 32 deletions.
11 changes: 7 additions & 4 deletions countess/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,10 @@ def load_config(self, logger: Logger):
self.config = None

def prerun(self, logger: Logger, row_limit=PRERUN_ROW_LIMIT):
if not self.plugin:
return
self.load_config(logger)
if self.is_dirty and self.plugin:
if self.is_dirty:
assert isinstance(self.plugin, (ProcessPlugin, FileInputPlugin))
self.result = []
self.plugin.prepare([node.name for node in self.parent_nodes], row_limit)
Expand All @@ -209,9 +211,10 @@ def mark_dirty(self):
child_node.mark_dirty()

def add_parent(self, parent):
self.parent_nodes.add(parent)
parent.child_nodes.add(self)
self.mark_dirty()
if (not self.plugin or self.plugin.num_inputs) and (not parent.plugin or parent.plugin.num_outputs):
self.parent_nodes.add(parent)
parent.child_nodes.add(self)
self.mark_dirty()

def del_parent(self, parent):
self.parent_nodes.discard(parent)
Expand Down
19 changes: 5 additions & 14 deletions countess/core/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ class BasePlugin:
description: str = ""
additional: str = ""
link: Optional[str] = None
num_inputs: int = 1
num_outputs: int = 1

parameters: MutableMapping[str, BaseParam] = {}
show_preview: bool = True
Expand Down Expand Up @@ -182,6 +184,7 @@ class FileInputPlugin(BasePlugin):
file_number = 0
name = ""
row_limit = None
num_inputs = 0

# used by the GUI file dialog
file_types: List[tuple[str, Union[str, list[str]]]] = [("Any", "*")]
Expand Down Expand Up @@ -288,6 +291,7 @@ class PandasProductPlugin(PandasProcessPlugin):
source2 = None
mem1: Optional[List] = None
mem2: Optional[List] = None
num_inputs = 2

def prepare(self, sources: list[str], row_limit: Optional[int] = None):
if len(sources) != 2:
Expand Down Expand Up @@ -655,17 +659,4 @@ def read_file_to_dataframe(self, file_params, logger, row_limit=None) -> pd.Data


class PandasOutputPlugin(PandasProcessPlugin):
def process_inputs(self, inputs: Mapping[str, Iterable[pd.DataFrame]], logger: Logger, row_limit: Optional[int]):
iterators = set(iter(input) for input in inputs.values())

while iterators:
for it in list(iterators):
try:
df_in = next(it)
assert isinstance(df_in, pd.DataFrame)
self.output_dataframe(df_in, logger)
except StopIteration:
iterators.remove(it)

def output_dataframe(self, dataframe: pd.DataFrame, logger: Logger):
raise NotImplementedError(f"{self.__class__}.output_dataframe")
num_outputs = 0
Binary file added countess/gui/icons/hbar.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added countess/gui/icons/redbar.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added countess/gui/icons/vbar.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
15 changes: 13 additions & 2 deletions countess/gui/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@


class PluginChooserFrame(tk.Frame):
def __init__(self, master, title, callback, *a, **k):
def __init__(self, master, title, callback, has_parents, has_children, *a, **k):
super().__init__(master, *a, **k)

self.columnconfigure(0, weight=1)
Expand All @@ -37,6 +37,13 @@ def __init__(self, master, title, callback, *a, **k):
label_frame.grid(row=1, column=0, sticky=tk.EW, padx=10, pady=10)

for n, plugin_class in enumerate(plugin_classes):
if (
(has_parents and plugin_class.num_inputs == 0)
or (not has_parents and plugin_class.num_inputs > 0)
or (has_children and plugin_class.num_outputs == 0)
):
continue

label_text = plugin_class.description
tk.Button(
label_frame,
Expand Down Expand Up @@ -121,7 +128,11 @@ def show_config_subframe(self):
self.frame.rowconfigure(4, weight=1)

else:
self.config_subframe = PluginChooserFrame(self.config_canvas, "Choose Plugin", self.choose_plugin)
has_parents = len(self.node.parent_nodes) > 0
has_children = len(self.node.child_nodes) > 0
self.config_subframe = PluginChooserFrame(
self.config_canvas, "Choose Plugin", self.choose_plugin, has_parents, has_children
)
self.config_subframe.grid(sticky=tk.NSEW)
self.frame.rowconfigure(3, weight=1)
self.frame.rowconfigure(4, weight=0)
Expand Down
40 changes: 30 additions & 10 deletions countess/gui/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from functools import partial

from countess.core.pipeline import PipelineNode

from countess.gui.widgets import get_icon

def _limit(value, min_value, max_value):
return max(min_value, min(max_value, value))
Expand Down Expand Up @@ -254,8 +254,21 @@ class DraggableLabel(DraggableMixin, FixedUnbindMixin, tk.Label):
pass


class DraggableMessage(DraggableMixin, FixedUnbindMixin, tk.Message):
pass
class NodeWrapper(DraggableLabel):

def update_node(self, node, vertical=False):
input_bar = node.plugin and node.plugin.num_inputs == 0
output_bar = node.plugin and node.plugin.num_outputs == 0
if not input_bar and not output_bar:
image = None
compound = tk.NONE
elif vertical:
image=get_icon(self, "hbar")
compound = tk.TOP if input_bar else tk.BOTTOM
else:
image=get_icon(self, "vbar")
compound = tk.LEFT if input_bar else tk.RIGHT
self.configure( text=node.name, image=image, compound=compound)


class GraphWrapper:
Expand Down Expand Up @@ -284,7 +297,7 @@ def __init__(self, canvas, graph, node_select_callback):
self.canvas.bind("<Key-BackSpace>", self.on_canvas_delete)

def label_for_node(self, node):
label = DraggableLabel(self.canvas, text=node.name, wraplength=125, cursor="hand1", takefocus=True)
label = NodeWrapper(self.canvas, wraplength=125, cursor="hand1", takefocus=True)
if not node.position:
node.position = (random.random() * 0.8 + 0.1, random.random() * 0.8 + 0.1)
# XXX should be more elegant way of answering the question "are we flipped?"
Expand Down Expand Up @@ -330,14 +343,16 @@ def on_mousedown(self, node, event):

def on_configure(self, node, label, event):
"""Stores the updated position of the label in node.position"""
xx = float(label.place_info()["relx"]) * self.canvas.winfo_width()
yy = float(label.place_info()["rely"]) * self.canvas.winfo_height()
height = self.canvas.winfo_height()
width = self.canvas.winfo_width()

xx = float(label.place_info()["relx"]) * width
yy = float(label.place_info()["rely"]) * height
node.position = self.new_node_position(xx, yy)
label.update_node(node, width < height)

# Adapt label sizes to suit the window size, as best we can ...
# XXX very arbitrary and definitely open to tweaking
height = self.canvas.winfo_height()
width = self.canvas.winfo_width()
if height > width:
label_max_width = max(width // 9, 25)
label_font_size = int(math.sqrt(width) / 3)
Expand Down Expand Up @@ -499,7 +514,11 @@ def on_ghost_release(self, start_node, event):
self.node_select_callback(other_node)

def add_parent(self, parent_node, child_node):
if parent_node not in child_node.parent_nodes:
if (
(not parent_node.plugin or parent_node.plugin.num_outputs)
and (not child_node.plugin or child_node.plugin.num_inputs)
and parent_node not in child_node.parent_nodes
):
child_node.add_parent(parent_node)
connecting_line = ConnectingLine(self.canvas, self.labels[parent_node], self.labels[child_node])
self.lines[child_node][parent_node] = connecting_line
Expand All @@ -518,7 +537,8 @@ def del_parent(self, parent_node, child_node):
def node_changed(self, node):
"""Called when something external updates the node's name, status
or configuration."""
self.labels[node]["text"] = node.name
flipped = self.canvas.winfo_width() >= self.canvas.winfo_height()
self.labels[node].update_node(node, not flipped)

def destroy(self):
for node_lines in self.lines.values():
Expand Down
4 changes: 2 additions & 2 deletions countess/plugins/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
MultiParam,
StringParam,
)
from countess.core.plugins import PandasInputFilesPlugin, PandasProcessPlugin
from countess.core.plugins import PandasInputFilesPlugin, PandasOutputPlugin
from countess.utils.pandas import flatten_columns


Expand Down Expand Up @@ -119,7 +119,7 @@ def read_file_to_dataframe(self, file_params, logger, row_limit=None):
return df


class SaveCsvPlugin(PandasProcessPlugin):
class SaveCsvPlugin(PandasOutputPlugin):
name = "CSV Save"
description = "Save data as CSV or similar delimited text files"
link = "https://countess-project.github.io/CountESS/included-plugins/#csv-writer"
Expand Down

0 comments on commit 8e88373

Please sign in to comment.