Skip to content

Commit

Permalink
[N/A] Add a mock launch argument (bdaiinstitute#199)
Browse files Browse the repository at this point in the history
This PR uses node / launch arguments to specify whether the driver should operate in "mock" mode, instead of using the sentinel name `Mock_spot`. With an explicit parameter, we can have multiple mocks, with different names, and also affect whether there is an arm.


For example, with this modification you can do

```
cd spot_driver
colcon build --symlink # (`symlink` [does not apply to launch files](colcon/colcon-core#407))
source install/setup.bash
ros2 launch spot_driver spot_driver.launch.py spot_name:=AnyNameHere mock_enable:=True mock_has_arm:=True
```

and then, for example (`spot_utilities`, specifically in `bdai` repo)

```
# source so that `spot_driver` can be found
cd spot_utilities # /workspaces/bdai/ws/src/spot_utilities
colcon build
source install/setup.bash
ros2 run spot_utilities spot_teleop.py AnyNameHere
```
  • Loading branch information
ggoretkin-bdai committed Dec 8, 2023
1 parent 3513dbf commit 8498270
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 23 deletions.
32 changes: 22 additions & 10 deletions spot_driver/launch/spot_driver.launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,21 @@ def launch_setup(context: LaunchContext, ld: LaunchDescription) -> None:
tf_prefix = LaunchConfiguration("tf_prefix").perform(context)
depth_registered_mode_config = LaunchConfiguration("depth_registered_mode")
publish_point_clouds_config = LaunchConfiguration("publish_point_clouds")

# Get parameters from Spot.

username = os.getenv("BOSDYN_CLIENT_USERNAME", "username")
password = os.getenv("BOSDYN_CLIENT_PASSWORD", "password")
hostname = os.getenv("SPOT_IP", "hostname")

spot_wrapper = SpotWrapper(username, password, hostname, spot_name, logger)
has_arm = spot_wrapper.has_arm()
mock_enable = IfCondition(LaunchConfiguration("mock_enable", default="False")).evaluate(context)

if not mock_enable:
# Get parameters from Spot.
# TODO this deviates from the `get_from_env_and_fall_back_to_param` logic in `spot_ros2.py`,
# which would pull in values in `config_file`
username = os.getenv("BOSDYN_CLIENT_USERNAME", "username")
password = os.getenv("BOSDYN_CLIENT_PASSWORD", "password")
hostname = os.getenv("SPOT_IP", "hostname")

spot_wrapper = SpotWrapper(username, password, hostname, spot_name, logger)
has_arm = spot_wrapper.has_arm()
else:
mock_has_arm = IfCondition(LaunchConfiguration("mock_has_arm")).evaluate(context)
has_arm = mock_has_arm

pkg_share = FindPackageShare("spot_description").find("spot_description")

Expand Down Expand Up @@ -154,11 +160,17 @@ def launch_setup(context: LaunchContext, ld: LaunchDescription) -> None:
# in spot_driver.
spot_driver_params = {
"spot_name": spot_name,
"mock_enable": mock_enable,
"publish_depth_registered": False,
"publish_depth": False,
"publish_rgb": False,
}

if mock_enable:
mock_spot_driver_params = {"mock_has_arm": mock_has_arm}
# Merge the two dicts
spot_driver_params = {**spot_driver_params, **mock_spot_driver_params}

spot_driver_node = launch_ros.actions.Node(
package="spot_driver",
executable="spot_ros2",
Expand Down Expand Up @@ -196,7 +208,7 @@ def launch_setup(context: LaunchContext, ld: LaunchDescription) -> None:
PathJoinSubstitution([pkg_share, "urdf", "spot.urdf.xacro"]),
" ",
"arm:=",
TextSubstitution(text=str(spot_wrapper.has_arm()).lower()),
TextSubstitution(text=str(has_arm).lower()),
" ",
"tf_prefix:=",
tf_prefix,
Expand Down
45 changes: 34 additions & 11 deletions spot_driver/spot_driver/spot_ros2.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@
from spot_wrapper.wrapper import SpotWrapper

MAX_DURATION = 1e6
MOCK_HOSTNAME = "Mock_spot"
COLOR_END = "\33[0m"
COLOR_GREEN = "\33[32m"
COLOR_YELLOW = "\33[33m"
Expand Down Expand Up @@ -194,6 +193,14 @@ class SpotImageType(str, Enum):
RegDepth = "depth_registered"


def set_node_parameter_from_parameter_list(
node: Node, parameter_list: Optional[typing.List[Parameter]], parameter_name: str
) -> None:
"""Set parameters when the node starts not from a launch file."""
if parameter_list is not None:
node.set_parameters([parameter for parameter in parameter_list if parameter.name == parameter_name])


class SpotROS(Node):
"""Parent class for using the wrapper. Defines all callbacks and keeps the wrapper alive"""

Expand Down Expand Up @@ -253,6 +260,12 @@ def __init__(self, parameter_list: Optional[typing.List[Parameter]] = None, **kw
self.declare_parameter("initialize_spot_cam", False)

self.declare_parameter("spot_name", "")
self.declare_parameter("mock_enable", False)

# If `mock_enable:=True`, then there are additional parameters. We must set this one separately.
set_node_parameter_from_parameter_list(self, parameter_list, "mock_enable")
if self.get_parameter("mock_enable").value:
self.declare_parameter("mock_has_arm", rclpy.Parameter.Type.BOOL)

# used for setting when not using launch file
if parameter_list is not None:
Expand Down Expand Up @@ -294,6 +307,10 @@ def __init__(self, parameter_list: Optional[typing.List[Parameter]] = None, **kw
self.name: Optional[str] = self.get_parameter("spot_name").value
if not self.name:
self.name = None
self.mock: bool = self.get_parameter("mock_enable").value
self.mock_has_arm: Optional[bool] = None
if self.mock:
self.mock_has_arm = self.get_parameter("mock_has_arm").value

self.motion_deadzone: Parameter = self.get_parameter("deadzone")
self.estop_timeout: Parameter = self.get_parameter("estop_timeout")
Expand Down Expand Up @@ -363,10 +380,11 @@ def __init__(self, parameter_list: Optional[typing.List[Parameter]] = None, **kw
name_str = ""
if self.name is not None:
name_str = " for " + self.name
self.get_logger().info("Starting ROS driver for Spot" + name_str)
mocking_designator = " (mocked)" if self.mock else ""
self.get_logger().info("Starting ROS driver for Spot" + name_str + mocking_designator)
# testing with Robot

if self.name == MOCK_HOSTNAME:
if self.mock:
self.spot_wrapper: Optional[SpotWrapper] = None
self.cam_wrapper: Optional[SpotCamWrapper] = None
else:
Expand Down Expand Up @@ -397,8 +415,16 @@ def __init__(self, parameter_list: Optional[typing.List[Parameter]] = None, **kw
except SystemError:
self.spot_cam_wrapper = None

if self.frame_prefix != self.spot_wrapper.frame_prefix:
error_msg = (
f"ERROR: disagreement between `self.frame_prefix` ({self.frame_prefix}) and"
f" `self.spot_wrapper.frame_prefix` ({self.spot_wrapper.frame_prefix})"
)
self.get_logger().error(error_msg)
raise ValueError(error_msg)

all_cameras = ["frontleft", "frontright", "left", "right", "back"]
has_arm = False
has_arm = self.mock_has_arm
if self.spot_wrapper is not None:
has_arm = self.spot_wrapper.has_arm()
if has_arm:
Expand Down Expand Up @@ -2365,13 +2391,10 @@ def populate_camera_static_transforms(self, image_data: image_pb2.Image) -> None
# We exclude the odometry frames from static transforms since they are not static. We can ignore the body
# frame because it is a child of odom or vision depending on the preferred_odom_frame, and will be published
# by the non-static transform publishing that is done by the state callback
frame_prefix = MOCK_HOSTNAME + "/"
if self.spot_wrapper is not None:
frame_prefix = self.spot_wrapper.frame_prefix
excluded_frames = [
self.tf_name_vision_odom.value,
self.tf_name_kinematic_odom.value,
frame_prefix + "body",
self.frame_prefix + "body",
]
excluded_frames = [f[f.rfind("/") + 1 :] for f in excluded_frames]

Expand Down Expand Up @@ -2405,8 +2428,8 @@ def populate_camera_static_transforms(self, image_data: image_pb2.Image) -> None
(transform.header.frame_id, transform.child_frame_id) for transform in self.camera_static_transforms
]
if (
frame_prefix + parent_frame,
frame_prefix + frame_name,
self.frame_prefix + parent_frame,
self.frame_prefix + frame_name,
) in existing_transforms:
# We already extracted this transform
continue
Expand All @@ -2421,7 +2444,7 @@ def populate_camera_static_transforms(self, image_data: image_pb2.Image) -> None
parent_frame,
frame_name,
transform.parent_tform_child,
frame_prefix,
self.frame_prefix,
)
self.camera_static_transforms.append(static_tf)
self.camera_static_transform_broadcaster.sendTransform(self.camera_static_transforms)
Expand Down
10 changes: 8 additions & 2 deletions spot_driver/test/test_spot_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,14 @@ def setUp(self) -> None:
self.fixture = contextlib.ExitStack()
self.ros = self.fixture.enter_context(ros_scope.top(namespace="fixture"))
# create and run spot ros2 servers
mock_param = rclpy.parameter.Parameter("spot_name", rclpy.Parameter.Type.STRING, "Mock_spot")
self.spot_ros2 = self.ros.load(spot_driver.spot_ros2.SpotROS, parameter_list=[mock_param])
self.spot_ros2 = self.ros.load(
spot_driver.spot_ros2.SpotROS,
parameter_list=[
rclpy.parameter.Parameter("spot_name", value="Mock_spot"),
rclpy.parameter.Parameter("mock_enable", value=True),
rclpy.parameter.Parameter("mock_has_arm", value=False),
],
)

# clients
self.claim_client: rclpy.node.Client = self.ros.node.create_client(Trigger, "claim")
Expand Down

0 comments on commit 8498270

Please sign in to comment.