diff --git a/ai_diffusion/custom_workflow.py b/ai_diffusion/custom_workflow.py index 40eb76849..100b41230 100644 --- a/ai_diffusion/custom_workflow.py +++ b/ai_diffusion/custom_workflow.py @@ -4,7 +4,7 @@ from enum import Enum from copy import copy from dataclasses import dataclass -from typing import Any, Awaitable, Callable, NamedTuple, Literal +from typing import Any, Awaitable, Callable, NamedTuple, Literal, TYPE_CHECKING from pathlib import Path from PyQt5.QtCore import Qt, QObject, QUuid, QAbstractListModel, QSortFilterProxyModel, QModelIndex from PyQt5.QtCore import pyqtSignal @@ -13,7 +13,6 @@ from .comfy_workflow import ComfyWorkflow from .connection import Connection from .image import Bounds, Image -from .layer import LayerManager from .jobs import Job, JobParams, JobQueue, JobKind from .properties import Property, ObservableProperties from .style import Styles @@ -21,6 +20,9 @@ from .ui import theme from . import eventloop +if TYPE_CHECKING: + from .layer import LayerManager + class WorkflowSource(Enum): document = 0 @@ -390,7 +392,7 @@ def graph(self): def metadata(self): return self._metadata - def collect_parameters(self, layers: LayerManager, bounds: Bounds): + def collect_parameters(self, layers: "LayerManager", bounds: Bounds): params = copy(self.params) for md in self.metadata: param = params.get(md.name) diff --git a/ai_diffusion/jobs.py b/ai_diffusion/jobs.py index 962145be1..412d2d30d 100644 --- a/ai_diffusion/jobs.py +++ b/ai_diffusion/jobs.py @@ -3,14 +3,16 @@ from dataclasses import dataclass, fields, field from datetime import datetime from enum import Enum, Flag -from typing import Any, Deque, NamedTuple +from typing import Any, Deque, NamedTuple, TYPE_CHECKING from PyQt5.QtCore import QObject, pyqtSignal from .image import Bounds, ImageCollection from .settings import settings from .style import Style from .util import ensure -from . import control + +if TYPE_CHECKING: + from . import control class JobState(Flag): diff --git a/tests/test_custom_workflow.py b/tests/test_custom_workflow.py index e3fd9b177..38d52dbd0 100644 --- a/tests/test_custom_workflow.py +++ b/tests/test_custom_workflow.py @@ -11,6 +11,7 @@ from ai_diffusion.custom_workflow import SortedWorkflows, CustomWorkspace from ai_diffusion.custom_workflow import CustomParam, ParamKind, workflow_parameters from ai_diffusion.image import Image, Extent +from ai_diffusion.jobs import JobQueue from ai_diffusion.style import Style from ai_diffusion.resources import Arch from ai_diffusion import workflow @@ -141,13 +142,18 @@ def test_files(tmp_path: Path): collection.import_file(bad_file) +async def dummy_generate(workflow_input): + return None + + def test_workspace(): connection = Connection() connection_workflows = {"connection1": make_dummy_graph(42)} connection._workflows = connection_workflows workflows = WorkflowCollection(connection) - workspace = CustomWorkspace(workflows) + jobs = JobQueue() + workspace = CustomWorkspace(workflows, dummy_generate, jobs) assert workspace.workflow_id == "connection1" assert workspace.workflow and workspace.workflow.id == "connection1" assert workspace.graph and workspace.graph.node(0).type == "ETN_Parameter"