Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: integrate open-set detection tools with RaiNode #281

Draft
wants to merge 9 commits into
base: development
Choose a base branch
from
27 changes: 18 additions & 9 deletions examples/rosbot-xl-generic-node-demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,23 @@

import rclpy
import rclpy.executors
from rai_open_set_vision.tools import GetDetectionTool, GetDistanceToObjectsTool
from rai_open_set_vision import GetDetectionTool

from rai.node import RaiStateBasedLlmNode
from rai.tools.ros.manipulation import GetObjectPositionsTool
from rai.tools.ros.native import (
GetCameraImage,
GetMsgFromTopic,
Ros2PubMessageTool,
Ros2ShowMsgInterfaceTool,
)

# from rai.tools.ros.native_actions import Ros2RunActionSync
from rai.tools.ros.native_actions import (
GetTransformTool,
Ros2CancelAction,
Ros2GetActionResult,
Ros2GetLastActionFeedback,
Ros2IsActionComplete,
Ros2RunActionAsync,
)
from rai.tools.ros.tools import GetCurrentPositionTool
from rai.tools.time import WaitForSecondsTool


Expand All @@ -50,11 +48,13 @@ def main():
topics_whitelist = [
"/rosout",
"/camera/camera/color/image_raw",
"/camera/camera/color/camera_info",
"/camera/camera/depth/image_rect_raw",
"/camera/camera/depth/camera_info",
"/map",
"/scan",
"/diagnostics",
"/cmd_vel",
# "/cmd_vel",
"/led_strip",
]

Expand All @@ -81,6 +81,7 @@ def main():

<rule> use /cmd_vel topic very carefully. Obstacle detection works only with nav2 stack, so be careful when it is not used. </rule>>
<rule> be patient with running ros2 actions. usually the take some time to run. </rule>
<rule> Always check your transform before and after you perform ros2 actions, so that you can verify if it worked. </rule>

Navigation tips:
- it's good to start finding objects by rotating, then navigating to some diverse location with occasional rotations. Remember to frequency detect objects.
Expand Down Expand Up @@ -136,12 +137,20 @@ def main():
Ros2GetActionResult,
Ros2GetLastActionFeedback,
Ros2ShowMsgInterfaceTool,
GetCurrentPositionTool,
WaitForSecondsTool,
GetMsgFromTopic,
GetCameraImage,
GetDetectionTool,
GetDistanceToObjectsTool,
GetTransformTool,
(
GetObjectPositionsTool,
dict(
target_frame="odom",
source_frame="sensor_frame",
camera_topic="/camera/camera/color/image_raw",
depth_topic="/camera/camera/depth/image_rect_raw",
camera_info_topic="/camera/camera/color/camera_info",
),
),
],
)

Expand Down
61 changes: 44 additions & 17 deletions src/rai/rai/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from action_msgs.msg import GoalStatus
from langchain.tools import BaseTool
from langchain.tools.render import render_text_description_and_args
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langgraph.graph.graph import CompiledGraph
from rclpy.action.client import ActionClient
from rclpy.action.graph import get_action_names_and_types
Expand Down Expand Up @@ -226,7 +226,7 @@ def _is_task_complete(self):
)
return True
else:
self.get_logger().info("There is not result")
self.get_logger().info("There is no result")
# Timed out, still processing, not complete yet
return False

Expand All @@ -250,7 +250,7 @@ def __init__(
):
super().__init__(*args, **kwargs)

self.robot_state = dict()
self.robot_state: Dict[str, Any] = dict() # where Any is ROS 2 message type

self.DISCOVERY_FREQ = 2.0
self.DISCOVERY_DEPTH = 5
Expand All @@ -270,6 +270,8 @@ def __init__(

self.state_subscribers = dict()

# ------- TF subscriber -------

# ------- ROS2 actions handling -------
self._async_tool_node = RaiAsyncToolsNode()

Expand All @@ -288,9 +290,9 @@ def discovery(self):
)

def get_raw_message_from_topic(self, topic: str, timeout_sec: int = 1) -> Any:
self.get_logger().info(f"Getting msg from topic: {topic}")
self.get_logger().debug(f"Getting msg from topic: {topic}")
if topic in self.state_subscribers and topic in self.robot_state:
self.get_logger().info("Returning cached message")
self.get_logger().debug("Returning cached message")
return self.robot_state[topic]
else:
msg_type = self.get_msg_type(topic)
Expand All @@ -303,7 +305,7 @@ def get_raw_message_from_topic(self, topic: str, timeout_sec: int = 1) -> Any:
)

if success:
self.get_logger().info(
self.get_logger().debug(
f"Received message of type {msg_type.__class__.__name__} from topic {topic}"
)
return msg
Expand Down Expand Up @@ -342,7 +344,7 @@ def __init__(
self,
system_prompt: str,
observe_topics: Optional[List[str]] = None,
observe_postprocessors: Optional[Dict[str, Callable]] = None,
observe_postprocessors: Optional[Dict[str, Callable[[Any], Any]]] = None,
whitelist: Optional[List[str]] = None,
tools: Optional[List[Type[BaseTool]]] = None,
*args,
Expand Down Expand Up @@ -399,18 +401,27 @@ def __init__(
state_retriever=self.get_robot_state,
logger=self.get_logger(),
)
self.simple_llm = get_llm_model(model_type="simple_model")

def _initialize_tools(self, tools: List[Type[BaseTool]]):
initialized_tools = list()
initialized_tools: List[BaseTool] = list()
for tool_cls in tools:
args = None
if type(tool_cls) is tuple:
tool_cls, args = tool_cls
if issubclass(tool_cls, Ros2BaseTool):
if (
issubclass(tool_cls, Ros2BaseActionTool)
or "DetectionTool" in tool_cls.__name__
or "GetDistance" in tool_cls.__name__
or "GetTransformTool" in tool_cls.__name__
or "GetObjectPositionsTool" in tool_cls.__name__
or "GetDetectionTool" in tool_cls.__name__
): # TODO(boczekbartek): develop a way to handle all mutially
tool = tool_cls(node=self._async_tool_node)
if args:
tool = tool_cls(node=self._async_tool_node, **args)
else:
tool = tool_cls(node=self._async_tool_node)
else:
tool = tool_cls(node=self)
else:
Expand All @@ -426,7 +437,7 @@ def _initialize_system_prompt(self, prompt: str):
)
return system_prompt

def _initialize_robot_state_interfaces(self, topics):
def _initialize_robot_state_interfaces(self, topics: List[str]):
self.rosout_buffer = RosoutBuffer(get_llm_model(model_type="simple_model"))

for topic in topics:
Expand Down Expand Up @@ -467,7 +478,7 @@ async def agent_loop(self, goal_handle: ServerGoalHandle):
self.get_logger().info(f"Received task: {task}")

# ---- LLM Task Handling ----
self.get_logger().info(f'This is system prompt: "{self.system_prompt}"')
self.get_logger().debug(f'This is system prompt: "{self.system_prompt}"')

messages = [
SystemMessage(content=self.system_prompt),
Expand All @@ -484,19 +495,36 @@ async def agent_loop(self, goal_handle: ServerGoalHandle):
payload, {"recursion_limit": self.AGENT_RECURSION_LIMIT}
):

print(state.keys())
graph_node_name = list(state.keys())[0]
if graph_node_name == "reporter":
continue

msg = state[graph_node_name]["messages"][-1]

if isinstance(msg, HumanMultimodalMessage):
last_msg = msg.text
elif isinstance(msg, BaseMessage):
if isinstance(msg.content, list):
assert len(msg.content) == 1
last_msg = msg.content[0].get("text", "")
else:
last_msg = msg.content
else:
last_msg = msg.content

if len(str(last_msg)) > 0 and "Retrieved state: {}" not in last_msg:
raise ValueError(f"Unexpected type of message: {type(msg)}")

last_msg = self.simple_llm.invoke(
[
SystemMessage(
content=(
"You are an experienced reporter deployed on a autonomous robot. " # type: ignore
"Your task is to summarize the message in a way that is easy for other agents to understand. "
"Do not use markdown formatting. Keep it short and concise. If the message is empty, please return empty string."
)
),
HumanMessage(content=last_msg),
]
).content

if len(str(last_msg)) > 0 and graph_node_name != "state_retriever":
feedback_msg = TaskAction.Feedback()
feedback_msg.current_status = f"{graph_node_name}: {last_msg}"

Expand All @@ -505,7 +533,6 @@ async def agent_loop(self, goal_handle: ServerGoalHandle):
# ---- Share Action Result ----
if state is None:
raise ValueError("No output from LLM")
print(state)

graph_node_name = list(state.keys())[0]
if graph_node_name != "reporter":
Expand Down
12 changes: 6 additions & 6 deletions src/rai/rai/tools/ros/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
from rclpy.node import Node
from tf2_geometry_msgs import do_transform_pose

from rai.tools.utils import TF2TransformFetcher
from rai.tools.ros.native_actions import Ros2BaseActionTool
from rai.tools.ros.utils import get_transform
from rai_interfaces.srv import ManipulatorMoveTo


Expand Down Expand Up @@ -141,7 +142,7 @@ class GetObjectPositionsToolInput(BaseModel):
)


class GetObjectPositionsTool(BaseTool):
class GetObjectPositionsTool(Ros2BaseActionTool):
name: str = "get_object_positions"
description: str = (
"Retrieve the positions of all objects of a specified type within the manipulator's frame of reference. "
Expand All @@ -154,7 +155,6 @@ class GetObjectPositionsTool(BaseTool):
camera_topic: str # rgb camera topic
depth_topic: str
camera_info_topic: str # rgb camera info topic
node: Node
get_grabbing_point_tool: GetGrabbingPointTool

def __init__(self, node: Node, **kwargs):
Expand All @@ -169,16 +169,16 @@ def format_pose(pose: Pose):
return f"Centroid(x={pose.position.x:.2f}, y={pose.position.y:2f}, z={pose.position.z:2f})"

def _run(self, object_name: str):
transform = TF2TransformFetcher(
target_frame=self.target_frame, source_frame=self.source_frame
).get_data()
transform = get_transform(self.node, self.source_frame, self.target_frame)
self.logger.info("Got transform: {transform}")

results = self.get_grabbing_point_tool._run(
camera_topic=self.camera_topic,
depth_topic=self.depth_topic,
camera_info_topic=self.camera_info_topic,
object_name=object_name,
)
self.logger.info("Got result: {results}")

poses = []
for result in results:
Expand Down
4 changes: 2 additions & 2 deletions src/rai/rai/tools/ros/native_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,14 +184,14 @@ class Ros2GetLastActionFeedback(Ros2BaseActionTool):
args_schema: Type[Ros2BaseInput] = Ros2BaseInput

def _run(self) -> str:
return str(self.node.action_feedback)
return str(self.node.feedback)


class GetTransformTool(Ros2BaseActionTool):
name: str = "GetTransform"
description: str = "Get transform between two frames"

def _run(self, target_frame="map", source_frame="body_link") -> dict:
def _run(self, target_frame="odom", source_frame="body_link") -> dict:
return message_to_ordereddict(
get_transform(self.node, target_frame, source_frame)
)
14 changes: 11 additions & 3 deletions src/rai/rai/tools/ros/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#

import base64
import time
from typing import Type, Union, cast

import cv2
Expand Down Expand Up @@ -155,17 +156,24 @@ def wait_for_message(


def get_transform(
node: rclpy.node.Node, target_frame: str, source_frame: str
node: rclpy.node.Node, target_frame: str, source_frame: str, timeout=30
) -> TransformStamped:
node.get_logger().info(
"Waiting for transform from {} to {}".format(source_frame, target_frame)
)
tf_buffer = Buffer(node=node)
tf_listener = TransformListener(tf_buffer, node)

transform = None
while transform is None:
rclpy.spin_once(node)
for _ in range(timeout * 10):
rclpy.spin_once(node, timeout_sec=0.1)
if tf_buffer.can_transform(target_frame, source_frame, rclpy.time.Time()):
transform = tf_buffer.lookup_transform(
target_frame, source_frame, rclpy.time.Time()
)
break
else:
time.sleep(0.1)

tf_listener.unregister()
return transform
1 change: 0 additions & 1 deletion src/rai/rai/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
from rai.messages import ToolMultimodalMessage


# Copied from https://github.com/ros2/rclpy/blob/jazzy/rclpy/rclpy/wait_for_message.py, to support humble
def wait_for_message(
msg_type,
node: "Node",
Expand Down
4 changes: 4 additions & 0 deletions src/rai/rai/utils/ros.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,15 @@ def __init__(self, llm: BaseChatModel, bufsize: int = 100) -> None:
)
llm = llm
self.llm = self.template | llm
self.filter_out = ["rviz", "rai"]

def clear(self):
self._buffer.clear()

def append(self, line: str):
for w in self.filter_out:
if w in line:
return
self._buffer.append(line)
if len(self._buffer) > self.bufsize:
self._buffer.popleft()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def _call_gdino_node(
future = cli.call_async(req)
return future

def get_img_from_topic(self, topic: str, timeout_sec: int = 2):
def get_img_from_topic(self, topic: str, timeout_sec: int = 4):
success, msg = wait_for_message(
sensor_msgs.msg.Image,
self.node,
Expand Down
Loading