Skip to content

Commit

Permalink
[behaviours] add ProbabilisticBehaviour(Behaviour)
Browse files Browse the repository at this point in the history
  • Loading branch information
gitpushoriginmaster authored Feb 21, 2023
1 parent ff5d051 commit 6e97964
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 0 deletions.
45 changes: 45 additions & 0 deletions py_trees/behaviours.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import copy
import functools
import operator
import random
import typing

from . import behaviour, blackboard, common, meta
Expand Down Expand Up @@ -736,3 +737,47 @@ def update(self) -> common.Status:
"|".join(["T" if result else "F" for result in results])
)
return common.Status.FAILURE


class ProbabilisticBehaviour(behaviour.Behaviour):
"""
Return a status based on a probability distribution. If unspecified - a uniform distribution will be used.
Args:
name: name of the behaviour
weights: 3 probabilities that correspond to returning :data:`~py_trees.common.Status.SUCCESS`,
:data:`~py_trees.common.Status.FAILURE` and :data:`~py_trees.common.Status.RUNNING` respectively.
.. note:: Probability distribution does not need to be normalised, it will be normalised internally.
Raises:
ValueError if only some probabilities are specified
"""

def __init__(self, name: str, weights: typing.Optional[typing.List[float]] = None):
if weights is not None and (type(weights) is not list or len(weights) != 3):
raise ValueError(
"Either all or none of the probabilities must be specified"
)

super(ProbabilisticBehaviour, self).__init__(name=name)

self._population = [
common.Status.SUCCESS,
common.Status.FAILURE,
common.Status.RUNNING,
]
self._weights = weights if weights is not None else [1.0, 1.0, 1.0]

def update(self) -> common.Status:
"""
Return a status based on a probability distribution.
Returns:
:data:`~py_trees.common.Status.SUCCESS` with probability weights[0],
:data:`~py_trees.common.Status.FAILURE` with probability weights[1] and
:data:`~py_trees.common.Status.RUNNING` with probability weights[2].
"""
self.logger.debug("%s.update()" % self.__class__.__name__)
return random.choices(self._population, self._weights, k=1)[0]
57 changes: 57 additions & 0 deletions tests/test_probabilistic_behaviour.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#!/usr/bin/env python
#
# License: BSD
# https://raw.githubusercontent.com/splintered-reality/py_trees/devel/LICENSE
#

##############################################################################
# Imports
##############################################################################

import py_trees
import py_trees.console as console
import py_trees.tests
import pytest

##############################################################################
# Logging Level
##############################################################################

py_trees.logging.level = py_trees.logging.Level.DEBUG
logger = py_trees.logging.Logger("Tests")

##############################################################################
# Tests
##############################################################################


def test_probabilistic_behaviour_workflow() -> None:
console.banner("Probabilistic Behaviour")

with pytest.raises(ValueError) as context: # if raised, context survives
# intentional error -> silence mypy
unused_root = py_trees.behaviours.ProbabilisticBehaviour( # noqa: F841 [unused]
name="ProbabilisticBehaviour", weights="invalid_type" # type: ignore[arg-type]
)
py_trees.tests.print_assert_details("ValueError raised", "raised", "not raised")
py_trees.tests.print_assert_details("ValueError raised", "yes", "yes")
assert "ValueError" == context.typename

root = py_trees.behaviours.ProbabilisticBehaviour(
name="ProbabilisticBehaviour", weights=[0.0, 0.0, 1.0]
)

py_trees.tests.print_assert_details(
text="task not yet ticked",
expected=py_trees.common.Status.INVALID,
result=root.status,
)
assert root.status == py_trees.common.Status.INVALID

root.tick_once()
py_trees.tests.print_assert_details(
text="task ticked once",
expected=py_trees.common.Status.RUNNING,
result=root.status,
)
assert root.status == py_trees.common.Status.RUNNING

0 comments on commit 6e97964

Please sign in to comment.