Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Acly committed Oct 10, 2024
1 parent 0e9d24e commit 0e5dafc
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 6 deletions.
8 changes: 5 additions & 3 deletions ai_diffusion/custom_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from enum import Enum
from copy import copy
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, NamedTuple, Literal
from typing import Any, Awaitable, Callable, NamedTuple, Literal, TYPE_CHECKING
from pathlib import Path
from PyQt5.QtCore import Qt, QObject, QUuid, QAbstractListModel, QSortFilterProxyModel, QModelIndex
from PyQt5.QtCore import pyqtSignal
Expand All @@ -13,14 +13,16 @@
from .comfy_workflow import ComfyWorkflow
from .connection import Connection
from .image import Bounds, Image
from .layer import LayerManager
from .jobs import Job, JobParams, JobQueue, JobKind
from .properties import Property, ObservableProperties
from .style import Styles
from .util import user_data_dir, client_logger as log
from .ui import theme
from . import eventloop

if TYPE_CHECKING:
from .layer import LayerManager


class WorkflowSource(Enum):
document = 0
Expand Down Expand Up @@ -390,7 +392,7 @@ def graph(self):
def metadata(self):
return self._metadata

def collect_parameters(self, layers: LayerManager, bounds: Bounds):
def collect_parameters(self, layers: "LayerManager", bounds: Bounds):
params = copy(self.params)
for md in self.metadata:
param = params.get(md.name)
Expand Down
6 changes: 4 additions & 2 deletions ai_diffusion/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@
from dataclasses import dataclass, fields, field
from datetime import datetime
from enum import Enum, Flag
from typing import Any, Deque, NamedTuple
from typing import Any, Deque, NamedTuple, TYPE_CHECKING
from PyQt5.QtCore import QObject, pyqtSignal

from .image import Bounds, ImageCollection
from .settings import settings
from .style import Style
from .util import ensure
from . import control

if TYPE_CHECKING:
from . import control


class JobState(Flag):
Expand Down
8 changes: 7 additions & 1 deletion tests/test_custom_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ai_diffusion.custom_workflow import SortedWorkflows, CustomWorkspace
from ai_diffusion.custom_workflow import CustomParam, ParamKind, workflow_parameters
from ai_diffusion.image import Image, Extent
from ai_diffusion.jobs import JobQueue
from ai_diffusion.style import Style
from ai_diffusion.resources import Arch
from ai_diffusion import workflow
Expand Down Expand Up @@ -141,13 +142,18 @@ def test_files(tmp_path: Path):
collection.import_file(bad_file)


async def dummy_generate(workflow_input):
return None


def test_workspace():
connection = Connection()
connection_workflows = {"connection1": make_dummy_graph(42)}
connection._workflows = connection_workflows
workflows = WorkflowCollection(connection)

workspace = CustomWorkspace(workflows)
jobs = JobQueue()
workspace = CustomWorkspace(workflows, dummy_generate, jobs)
assert workspace.workflow_id == "connection1"
assert workspace.workflow and workspace.workflow.id == "connection1"
assert workspace.graph and workspace.graph.node(0).type == "ETN_Parameter"
Expand Down

0 comments on commit 0e5dafc

Please sign in to comment.