Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

docs: Update (Async)TransitionConfig(Dict) and adjusted tests #685

Merged
merged 2 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

- Bug #683: Typing wrongly suggested that `Transition` instances can be passed to `Machine.__init__` and/or `Machine.add_transition(s)` (thanks @antonio-antuan)
- Typing should be more precise now
- Made `transitions.core.(Async)TransitionConfigDict` a `TypedDict` which can be used to spot parameter errors during static analysis
- `Machine.add_transitions` and `Machine.__init__` expect a `Sequence` of configurations for transitions now
- Added 'async' callbacks to types in `asyncio` extension

Expand Down
50 changes: 48 additions & 2 deletions tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

try:
import asyncio
from transitions.extensions.asyncio import AsyncMachine, HierarchicalAsyncMachine, AsyncEventData
from transitions.extensions.asyncio import AsyncMachine, HierarchicalAsyncMachine, AsyncEventData, \
AsyncTransition

except (ImportError, SyntaxError):
asyncio = None # type: ignore
Expand All @@ -17,7 +18,7 @@
from .test_pygraphviz import pgv

if TYPE_CHECKING:
from typing import Type, Sequence
from typing import Type, Sequence, List
from transitions.extensions.asyncio import AsyncTransitionConfig


Expand Down Expand Up @@ -584,6 +585,51 @@ async def run():

asyncio.run(run())

def test_custom_transition(self):

class MyTransition(self.machine_cls.transition_cls): # type: ignore

def __init__(self, source, dest, conditions=None, unless=None, before=None,
after=None, prepare=None, my_int=None, my_none=None, my_str=None, my_dict=None):
super(MyTransition, self).__init__(source, dest, conditions, unless, before, after, prepare)
self.my_int = my_int
self.my_none = my_none
self.my_str = my_str
self.my_dict = my_dict

class MyMachine(self.machine_cls): # type: ignore
transition_cls = MyTransition

a_transition = {
"trigger": "go", "source": "B", "dest": "A",
"my_int": 42, "my_str": "foo", "my_dict": {"bar": "baz"}
}
transitions = [
["go", "A", "B"],
a_transition
]

m = MyMachine(states=["A", "B"], transitions=transitions, initial="A")
m.add_transition("reset", "*", "A",
my_int=23, my_str="foo2", my_none=None, my_dict={"baz": "bar"})

async def run():
assert await m.go()
trans = m.get_transitions("go", "B") # type: List[MyTransition]
assert len(trans) == 1
assert trans[0].my_str == a_transition["my_str"]
assert trans[0].my_int == a_transition["my_int"]
assert trans[0].my_dict == a_transition["my_dict"]
assert trans[0].my_none is None
trans = m.get_transitions("reset", "A")
assert len(trans) == 1
assert trans[0].my_str == "foo2"
assert trans[0].my_int == 23
assert trans[0].my_dict == {"baz": "bar"}
assert trans[0].my_none is None

asyncio.run(run())


@skipIf(asyncio is None or (pgv is None and gv is None), "AsyncGraphMachine requires asyncio and (py)gaphviz")
class TestAsyncGraphMachine(TestAsync):
Expand Down
49 changes: 45 additions & 4 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
pass

import sys
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, List
from functools import partial
from unittest import TestCase, skipIf
import weakref

from transitions import Machine, MachineError, State, EventData
from transitions.core import listify, _prep_ordered_arg
from transitions.core import listify, _prep_ordered_arg, Transition

from .utils import InheritedStuff
from .utils import Stuff, DummyModel
Expand All @@ -23,7 +23,7 @@

if TYPE_CHECKING:
from typing import Sequence
from transitions.core import TransitionConfig, StateConfig
from transitions.core import TransitionConfig, StateConfig, TransitionConfigDict


def on_exit_A(event):
Expand Down Expand Up @@ -570,7 +570,7 @@ def test_pickle(self):
{'trigger': 'walk', 'source': 'A', 'dest': 'B'},
{'trigger': 'run', 'source': 'B', 'dest': 'C'},
{'trigger': 'sprint', 'source': 'C', 'dest': 'D'}
]
] # type: Sequence[TransitionConfigDict]
m = Machine(states=states, transitions=transitions, initial='A')
m.walk()
dump = pickle.dumps(m)
Expand Down Expand Up @@ -1350,3 +1350,44 @@ def test_on_final(self):
self.assertEqual(1, final_mock.call_count)
machine.to_B()
self.assertEqual(2, final_mock.call_count)

def test_custom_transition(self):

class MyTransition(self.machine_cls.transition_cls): # type: ignore

def __init__(self, source, dest, conditions=None, unless=None, before=None,
after=None, prepare=None, my_int=None, my_none=None, my_str=None, my_dict=None):
super(MyTransition, self).__init__(source, dest, conditions, unless, before, after, prepare)
self.my_int = my_int
self.my_none = my_none
self.my_str = my_str
self.my_dict = my_dict

class MyMachine(self.machine_cls): # type: ignore
transition_cls = MyTransition

a_transition = {
"trigger": "go", "source": "B", "dest": "A",
"my_int": 42, "my_str": "foo", "my_dict": {"bar": "baz"}
}
transitions = [
["go", "A", "B"],
a_transition
]

m = MyMachine(states=["A", "B"], transitions=transitions, initial="A")
m.add_transition("reset", "*", "A",
my_int=23, my_str="foo2", my_none=None, my_dict={"baz": "bar"})
assert m.go()
trans = m.get_transitions("go", "B") # type: List[MyTransition]
assert len(trans) == 1
assert trans[0].my_str == a_transition["my_str"]
assert trans[0].my_int == a_transition["my_int"]
assert trans[0].my_dict == a_transition["my_dict"]
assert trans[0].my_none is None
trans = m.get_transitions("reset", "A")
assert len(trans) == 1
assert trans[0].my_str == "foo2"
assert trans[0].my_int == 23
assert trans[0].my_dict == {"baz": "bar"}
assert trans[0].my_none is None
7 changes: 2 additions & 5 deletions tests/test_experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,8 @@ class Model:
def is_B(self) -> bool:
return False

transition_config = [["A", "B"], "C"] # type: TransitionConfig

@add_transitions(transition(source="A", dest="B"))
@add_transitions(transition_config)
@add_transitions([["A", "B"], "C"])
def go(self) -> bool:
raise RuntimeError("Should be overridden!")

Expand Down Expand Up @@ -196,8 +194,7 @@ class Model:
def is_B(self) -> bool:
return False

transition_config = [["A", "B"], "C"] # type: TransitionConfig
go = event(transition(source="A", dest="B"), transition_config)
go = event(transition(source="A", dest="B"), [["A", "B"], "C"], {"source": "*", "dest": None})

model = Model()
machine = self.trigger_machine(model, states=["A", "B", "C"], initial="A")
Expand Down
6 changes: 3 additions & 3 deletions tests/test_graphviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

if TYPE_CHECKING:
from typing import Type, List, Collection, Union, Literal, Sequence, Dict, Optional
from transitions.core import TransitionConfig
from transitions.core import TransitionConfig, TransitionConfigDict


class TestDiagramsImport(TestCase):
Expand Down Expand Up @@ -75,7 +75,7 @@ def setUp(self):
{'trigger': 'run', 'source': 'B', 'dest': 'C'},
{'trigger': 'sprint', 'source': 'C', 'dest': 'D', 'conditions': 'is_fast'},
{'trigger': 'sprint', 'source': 'C', 'dest': 'B'}
] # type: Sequence[Dict[str, str]]
] # type: Sequence[TransitionConfigDict]

def test_diagram(self):
m = self.machine_cls(states=self.states, transitions=self.transitions, initial='A', auto_transitions=False,
Expand Down Expand Up @@ -327,7 +327,7 @@ def setUp(self):
'conditions': 'is_fast'},
{'trigger': 'sprint', 'source': 'C', 'dest': 'B'}, # + 1 edge
{'trigger': 'reset', 'source': '*', 'dest': 'A'} # + 4 edges (from base state) = 8
] # type: Sequence[Dict[str, str]]
] # type: Sequence[TransitionConfigDict]

def test_diagram(self):
m = self.machine_cls(states=self.states, transitions=self.transitions, initial='A', auto_transitions=False,
Expand Down
11 changes: 5 additions & 6 deletions tests/test_reuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@

if TYPE_CHECKING:
from typing import List, Union, Dict, Any, Sequence
from transitions.core import TransitionConfig

from transitions.core import TransitionConfig, TransitionConfigDict

test_states = ['A', 'B', {'name': 'C', 'children': ['1', '2', {'name': '3', 'children': ['a', 'b', 'c']}]},
'D', 'E', 'F']
Expand Down Expand Up @@ -91,7 +90,7 @@ def test_blueprint_reuse(self):
{'trigger': 'decrease', 'source': '3', 'dest': '2'},
{'trigger': 'decrease', 'source': '1', 'dest': '1'},
{'trigger': 'reset', 'source': '*', 'dest': '1'},
]
] # type: Sequence[TransitionConfigDict]

counter = self.machine_cls(states=states, transitions=transitions, before_state_change='check',
after_state_change='clear', initial='1')
Expand All @@ -103,7 +102,7 @@ def test_blueprint_reuse(self):
{'trigger': 'backward', 'source': 'C', 'dest': 'B'},
{'trigger': 'backward', 'source': 'B', 'dest': 'A'},
{'trigger': 'calc', 'source': '*', 'dest': 'C'},
]
] # type: Sequence[TransitionConfigDict]

walker = self.machine_cls(states=new_states, transitions=new_transitions, before_state_change='watch',
after_state_change='look_back', initial='A')
Expand Down Expand Up @@ -144,7 +143,7 @@ def test_blueprint_remap(self):
{'trigger': 'decrease', 'source': '1', 'dest': '1'},
{'trigger': 'reset', 'source': '*', 'dest': '1'},
{'trigger': 'done', 'source': '3', 'dest': 'finished'}
]
] # type: Sequence[TransitionConfigDict]

counter = self.machine_cls(states=states, transitions=transitions, initial='1')

Expand All @@ -158,7 +157,7 @@ def test_blueprint_remap(self):
{'trigger': 'backward', 'source': 'C', 'dest': 'B'},
{'trigger': 'backward', 'source': 'B', 'dest': 'A'},
{'trigger': 'calc', 'source': '*', 'dest': 'C%s1' % State.separator},
]
] # type: Sequence[TransitionConfigDict]

walker = self.machine_cls(states=new_states, transitions=new_transitions, before_state_change='watch',
after_state_change='look_back', initial='A')
Expand Down
3 changes: 2 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ class Stuff(object):
is_false = False
is_True = True

def __init__(self, states=None, machine_cls=Machine, extra_kwargs={}):
def __init__(self, states=None, machine_cls=Machine, extra_kwargs=None):
extra_kwargs = extra_kwargs if extra_kwargs is not None else {}

self.state = None
self.message = None
Expand Down
33 changes: 22 additions & 11 deletions transitions/core.pyi
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from logging import Logger
from typing import (
Any, Optional, Callable, Sequence, Union, Iterable, List, Dict, DefaultDict,
Type, Deque, OrderedDict, Tuple, Literal, Collection, TypedDict, Mapping
Type, Deque, OrderedDict, Tuple, Literal, Collection, TypedDict, Required
)

# Enums are supported for Python 3.4+ and Python 2.7 with enum34 package installed
Expand Down Expand Up @@ -91,8 +91,19 @@ TransitionConfigList = Union[
List[str], List[Sequence[str]], List[Optional[str]],
List[Union[str, Enum]], List[Optional[Union[str, Enum]]]
]
TransitionConfigDict = Mapping[str, Union[None, StateConfig, Callback, Iterable[Callback]]]
TransitionConfig = Union[TransitionConfigList, TransitionConfigDict]

class TransitionConfigDict(TypedDict, total=False):
trigger: Required[str]
source: Required[Union[str, Enum, Sequence[Union[str, Enum]]]]
dest: Required[Optional[Union[str, Enum]]]
prepare: CallbacksArg
before: CallbacksArg
after: CallbacksArg
conditions: CallbacksArg
unless: CallbacksArg

# For backwards compatibility we also accept generic collections
TransitionConfig = Union[TransitionConfigList, TransitionConfigDict, Collection[str]]

class EventData:
state: State
Expand All @@ -115,7 +126,7 @@ class Event:
transitions: DefaultDict[str, List[Transition]]
def __init__(self, name: str, machine: Machine) -> None: ...
def add_transition(self, transition: Transition) -> None: ...
def trigger(self, model: object, *args: List[Any], **kwargs: Dict[str, Any]) -> bool: ...
def trigger(self, model: object, *args: Any, **kwargs: Any) -> bool: ...
def _trigger(self, event_data: EventData) -> bool: ...
def _process(self, event_data: EventData) -> bool: ...
def _is_valid_source(self, state: State) -> bool: ...
Expand Down Expand Up @@ -157,7 +168,7 @@ class Machine:
name: str = ..., queued: bool = ...,
prepare_event: CallbacksArg = ..., finalize_event: CallbacksArg = ...,
model_attribute: str = ..., model_override: bool = ...,
on_exception: CallbacksArg = ..., on_final: CallbacksArg = ..., **kwargs: Dict[str, Any]) -> None: ...
on_exception: CallbacksArg = ..., on_final: CallbacksArg = ..., **kwargs: Any) -> None: ...
def add_model(self, model: ModelParameter,
initial: Optional[StateIdentifier] = ...) -> None: ...
def remove_model(self, model: ModelParameter) -> None: ...
Expand Down Expand Up @@ -201,21 +212,21 @@ class Machine:
def set_state(self, state: StateIdentifier, model: Optional[object] = ...) -> None: ...
def add_state(self, states: Union[Sequence[StateConfig], StateConfig],
on_enter: CallbacksArg = ..., on_exit: CallbacksArg = ...,
ignore_invalid_triggers: Optional[bool] = ..., **kwargs: Dict[str, Any]) -> None: ...
ignore_invalid_triggers: Optional[bool] = ..., **kwargs: Any) -> None: ...
def add_states(self, states: Union[Sequence[StateConfig], StateConfig],
on_enter: CallbacksArg = ..., on_exit: CallbacksArg = ...,
ignore_invalid_triggers: Optional[bool] = ..., **kwargs: Dict[str, Any]) -> None: ...
ignore_invalid_triggers: Optional[bool] = ..., **kwargs: Any) -> None: ...
def _add_model_to_state(self, state: State, model: object) -> None: ...
def _checked_assignment(self, model: object, name: str, func: CallbackFunc) -> None: ...
def _add_trigger_to_model(self, trigger: str, model: object) -> None: ...
def _get_trigger(self, model: object, trigger_name: str, *args: List[Any], **kwargs: Dict[str, Any]) -> bool: ...
def _get_trigger(self, model: object, trigger_name: str, *args: Any, **kwargs: Any) -> bool: ...
def get_triggers(self, *args: Union[str, Enum, State]) -> List[str]: ...
def add_transition(self, trigger: str,
source: Union[StateIdentifier, List[StateIdentifier]],
dest: Optional[StateIdentifier] = ...,
conditions: CallbacksArg = ..., unless: CallbacksArg = ...,
before: CallbacksArg = ..., after: CallbacksArg = ..., prepare: CallbacksArg = ...,
**kwargs: Dict[str, Any]) -> None: ...
**kwargs: Any) -> None: ...
def add_transitions(self, transitions: Sequence[TransitionConfig]) -> None: ...
def add_ordered_transitions(self, states: Optional[Sequence[Union[str, State]]] = ...,
trigger: str = ..., loop: bool = ...,
Expand All @@ -225,11 +236,11 @@ class Machine:
before: Optional[Sequence[Union[Callback, None]]] = ...,
after: Optional[Sequence[Union[Callback, None]]] = ...,
prepare: CallbacksArg = ...,
**kwargs: Dict[str, Any]) -> None: ...
**kwargs: Any) -> None: ...
def get_transitions(self, trigger: str = ...,
source: StateIdentifier = ..., dest: StateIdentifier = ...) -> List[Transition]: ...
def remove_transition(self, trigger: str, source: str = ..., dest: str = ...) -> None: ...
def dispatch(self, trigger: str, *args: List[Any], **kwargs: Dict[str, Any]) -> bool: ...
def dispatch(self, trigger: str, *args: Any, **kwargs: Any) -> bool: ...
def callbacks(self, funcs: Iterable[Callback], event_data: EventData) -> None: ...
def callback(self, func: Callback, event_data: EventData) -> None: ...
@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions transitions/experimental/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def generate_base_model(config):
f" def may_{trigger_name}(self) -> bool: {_placeholder_body}\n"
)

extra_params = "event_data: EventData" if m.send_event else "*args: List[Any], **kwargs: Dict[str, Any]"
extra_params = "event_data: EventData" if m.send_event else "*args: Any, **kwargs: Any"
for callback_name in callbacks:
if isinstance(callback_name, str):
callback_block += (f" @abstractmethod\n"
Expand Down Expand Up @@ -98,7 +98,7 @@ def add_model_override(self, model, initial=None):
self.model_override = True
for model in listify(model):
model = self if model == "self" else model
for name, specs in TriggerPlaceholder.definitions.get(model.__class__).items():
for name, specs in TriggerPlaceholder.definitions.get(model.__class__, {}).items():
for spec in specs:
if isinstance(spec, list):
self.add_transition(name, *spec)
Expand Down
Loading