Skip to content

Commit

Permalink
Add DelayedVariable class
Browse files Browse the repository at this point in the history
  • Loading branch information
mhthies committed Dec 27, 2023
1 parent 1eaa7e9 commit ca2badd
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 7 deletions.
6 changes: 6 additions & 0 deletions docs/variables_expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ Equally, they can be retrieved via the :meth:`VariableField.field` method.
Use :class:`shc.misc.UpdateExchange` to split up NamedTuple-based value updates in a stateless way:
It provides an equal way for subscribing to fields of the NamedTuple via the :meth:`shc.misc.UpdateExchange.field` method but does not store the latest value and does not suppress value updates with unchanged values.

DelayedVariable
^^^^^^^^^^^^^^^

.. autoclass:: shc.variables.DelayedVariable


.. _expressions:

Expressions
Expand Down
2 changes: 2 additions & 0 deletions shc/timer.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,8 @@ class RateLimitedSubscription(Subscribable[T], Generic[T]):
A transparent wrapper for `Subscribable` objects, that delays and drops values to make sure that a given maximum
rate of new values is not exceeded.
See also :class:`shc.variables.DelayedVariable` for a similar (but slightly different) behaviour.
:param wrapped: The Subscribable object to be wrapped
:param min_interval: The minimal allowed interval between published values in seconds
"""
Expand Down
67 changes: 60 additions & 7 deletions shc/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
# specific language governing permissions and limitations under the License.

import asyncio
import datetime
import logging
import warnings
from typing import Generic, Type, Optional, List, Any, Union, Dict

from . import timer
from .base import Writable, T, Readable, Subscribable, UninitializedError, Reading
from .expressions import ExpressionWrapper

Expand Down Expand Up @@ -46,7 +48,7 @@ def __init__(self, type_: Type[T], name: Optional[str] = None, initial_value: Op
self._value: Optional[T] = initial_value
self._variable_fields: Dict[str, "VariableField"] = {}

# Create VariableFields for each typeannotated field of the type if it is typing.NamedTuple-based.
# Create VariableFields for each type-annotated field of the type if it is typing.NamedTuple-based.
if issubclass(type_, tuple) and type_.__annotations__:
for name, field_type in type_.__annotations__.items():
variable_field = VariableField(self, name, field_type)
Expand All @@ -73,11 +75,16 @@ async def _write(self, value: T, origin: List[Any]) -> None:
self._value = value
if old_value != value: # if a single field is different, the full value will also be different
logger.info("New value %s for Variable %s from %s", value, self, origin[:1])
self._publish(value, origin)
for name, field in self._variable_fields.items():
field._recursive_publish(getattr(value, name),
None if old_value is None else getattr(old_value, name),
origin)
self._do_all_publish(old_value, origin)

def _do_all_publish(self, old_value: Optional[T], origin: List[Any]) -> None:
logger.debug("Publishing value %s for Variable %s", self._value, self)
assert self._value is not None
self._publish(self._value, origin)
for name, field in self._variable_fields.items():
field._recursive_publish(getattr(self._value, name),
None if old_value is None else getattr(old_value, name),
origin)

async def read(self) -> T:
if self._value is None:
Expand All @@ -96,7 +103,7 @@ def EX(self) -> ExpressionWrapper:

def __repr__(self) -> str:
if self.name:
return "<Variable \"{}\">".format(self.name)
return "<{} \"{}\">".format(self.__class__.__name__, self.name)
else:
return super().__repr__()

Expand Down Expand Up @@ -158,3 +165,49 @@ async def read(self) -> T:
@property
def EX(self) -> ExpressionWrapper:
return ExpressionWrapper(self)


class DelayedVariable(Variable[T], Generic[T]):
"""
A Variable object, which delays the updates to avoid publishing half-updated values
This is achieved by delaying the publishing of a newly received value by a configurable amount of time
(`publish_delay`). If more value updates are received while a previous update publishing is still pending, the
latest value will be published at the originally scheduled publishing time. There will be no publishing of the
intermediate values. The next value update received after the publishing will be delayed by the configured delay
time again, resulting in a maximum update interval of the specified delay time.
This is similar (but slightly different) to the behaviour of :class:`shc.misc.RateLimitedSubscription`.
:param type_: The Variable's value type (used for its ``.type`` attribute, i.e. for the *Connectable* type
checking mechanism)
:param name: An optional name of the variable. Used for logging and future displaying purposes.
:param initial_value: An optional initial value for the Variable. If not provided and no default provider is
set via :meth:`set_provider`, the Variable is initialized with a None value and any :meth:`read` request
will raise an :exc:`shc.base.UninitializedError` until the first value update is received.
:param publish_delay: Amount of time to delay the publishing of a new value.
"""
def __init__(self, type_: Type[T], name: Optional[str] = None, initial_value: Optional[T] = None,
publish_delay: datetime.timedelta = datetime.timedelta(seconds=0.25)):
super().__init__(type_, name, initial_value)
self._publish_delay = publish_delay
self._pending_publish_task: Optional[asyncio.Task] = None
self._latest_origin: List[Any] = []

async def _write(self, value: T, origin: List[Any]) -> None:
old_value = self._value
self._value = value
self._latest_origin = origin
if old_value != value: # if a single field is different, the full value will also be different
logger.info("New value %s for Variable %s from %s", value, self, origin[:1])
if not self._pending_publish_task:
self._pending_publish_task = asyncio.create_task(self._wait_and_publish(old_value))
timer.timer_supervisor.add_temporary_task(self._pending_publish_task)

async def _wait_and_publish(self, old_value: Optional[T]) -> None:
try:
await asyncio.sleep(self._publish_delay.total_seconds())
except asyncio.CancelledError:

Check warning on line 210 in shc/variables.py

View check run for this annotation

Codecov / codecov/patch

shc/variables.py#L210

Added line #L210 was not covered by tests
pass
self._do_all_publish(old_value, self._latest_origin)
self._pending_publish_task = None
36 changes: 36 additions & 0 deletions test/test_variables.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import datetime
import unittest
import unittest.mock
import warnings
Expand Down Expand Up @@ -292,6 +293,41 @@ async def test_concurrent_field_update_publishing(self) -> None:
self.assertEqual(writable1._write.call_args[0][0], writable3._write.call_args[0][0])


class DelayedVariableTest(unittest.TestCase):
@async_test
async def test_simple(self):
var = variables.DelayedVariable(int, name="A test variable", publish_delay=datetime.timedelta(seconds=0.02))
subscriber = ExampleWritable(int)
var.subscribe(subscriber)

await var.write(5, [])
self.assertEqual(5, await var.read())
await asyncio.sleep(0)
await var.write(42, [self])
self.assertEqual(42, await var.read())
await asyncio.sleep(0.025)
subscriber._write.assert_called_once_with(42, [self, var])

@async_test
async def test_field_update(self):
var = variables.DelayedVariable(ExampleTupleType,
name="A test variable",
initial_value=ExampleTupleType(0, 0.0),
publish_delay=datetime.timedelta(seconds=0.02))
field_subscriber = ExampleWritable(int)
subscriber = ExampleWritable(ExampleTupleType)
var.subscribe(subscriber)
var.field('a').subscribe(field_subscriber)

await var.field('a').write(21, [self])
await asyncio.sleep(0)
await var.field('b').write(3.1416, [self])
self.assertEqual(ExampleTupleType(21, 3.1416), await var.read())
await asyncio.sleep(0.025)
subscriber._write.assert_called_once_with(ExampleTupleType(21, 3.1416), [self, var.field('b'), var])
field_subscriber._write.assert_called_once_with(21, [self, var.field('b'), var.field('a')])


class MyPyPluginTest(unittest.TestCase):
def test_mypy_plugin_variable(self) -> None:
asset_dir = Path(__file__).parent / 'assets' / 'mypy_plugin_test'
Expand Down

0 comments on commit ca2badd

Please sign in to comment.