Skip to content

Commit

Permalink
Fix issue #58: use weak_ptr to avoid circular shared_ptr ownership
Browse files Browse the repository at this point in the history
  • Loading branch information
facontidavide committed Apr 30, 2024
1 parent 784d4e9 commit 374edcf
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 24 deletions.
26 changes: 20 additions & 6 deletions behaviortree_ros2/include/behaviortree_ros2/bt_action_node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,20 @@ class RosActionNode : public BT::ActionNodeBase

rclcpp::Logger logger()
{
return node_->get_logger();
if(auto node = node_.lock())
{
return node->get_logger();
}
return rclcpp::get_logger("RosActionNode");
}

rclcpp::Time now()
{
if(auto node = node_.lock())
{
return node->now();
}
return rclcpp::Clock(RCL_ROS_TIME).now();
}

using ClientsRegistry =
Expand All @@ -200,7 +213,7 @@ class RosActionNode : public BT::ActionNodeBase
return action_clients_registry;
}

std::shared_ptr<rclcpp::Node> node_;
std::weak_ptr<rclcpp::Node> node_;
std::shared_ptr<ActionClientInstance> client_instance_;
std::string action_name_;
bool action_name_may_change_ = false;
Expand Down Expand Up @@ -302,13 +315,14 @@ inline bool RosActionNode<T>::createClient(const std::string& action_name)
}

std::unique_lock lk(getMutex());
action_client_key_ = std::string(node_->get_fully_qualified_name()) + "/" + action_name;
auto node = node_.lock();
action_client_key_ = std::string(node->get_fully_qualified_name()) + "/" + action_name;

auto& registry = getRegistry();
auto it = registry.find(action_client_key_);
if(it == registry.end() || it->second.expired())
{
client_instance_ = std::make_shared<ActionClientInstance>(node_, action_name);
client_instance_ = std::make_shared<ActionClientInstance>(node, action_name);
registry.insert({ action_client_key_, client_instance_ });
}
else
Expand Down Expand Up @@ -421,7 +435,7 @@ inline NodeStatus RosActionNode<T>::tick()
}

future_goal_handle_ = action_client->async_send_goal(goal, goal_options);
time_goal_sent_ = node_->now();
time_goal_sent_ = now();

return NodeStatus::RUNNING;
}
Expand All @@ -442,7 +456,7 @@ inline NodeStatus RosActionNode<T>::tick()
future_goal_handle_, nodelay);
if(ret != rclcpp::FutureReturnCode::SUCCESS)
{
if((node_->now() - time_goal_sent_) > timeout)
if((now() - time_goal_sent_) > timeout)
{
return CheckStatus(onFailure(SEND_GOAL_TIMEOUT));
}
Expand Down
30 changes: 22 additions & 8 deletions behaviortree_ros2/include/behaviortree_ros2/bt_service_node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,20 @@ class RosServiceNode : public BT::ActionNodeBase

rclcpp::Logger logger()
{
return node_->get_logger();
if(auto node = node_.lock())
{
return node->get_logger();
}
return rclcpp::get_logger("RosServiceNode");
}

rclcpp::Time now()
{
if(auto node = node_.lock())
{
return node->now();
}
return rclcpp::Clock(RCL_ROS_TIME).now();
}

using ClientsRegistry =
Expand All @@ -171,7 +184,7 @@ class RosServiceNode : public BT::ActionNodeBase
return clients_registry;
}

std::shared_ptr<rclcpp::Node> node_;
std::weak_ptr<rclcpp::Node> node_;
std::string service_name_;
bool service_name_may_change_ = false;
const std::chrono::milliseconds service_timeout_;
Expand Down Expand Up @@ -268,13 +281,14 @@ inline bool RosServiceNode<T>::createClient(const std::string& service_name)
}

std::unique_lock lk(getMutex());
auto client_key = std::string(node_->get_fully_qualified_name()) + "/" + service_name;
auto node = node_.lock();
auto client_key = std::string(node->get_fully_qualified_name()) + "/" + service_name;

auto& registry = getRegistry();
auto it = registry.find(client_key);
if(it == registry.end() || it->second.expired())
{
srv_instance_ = std::make_shared<ServiceClientInstance>(node_, service_name);
srv_instance_ = std::make_shared<ServiceClientInstance>(node, service_name);
registry.insert({ client_key, srv_instance_ });

RCLCPP_INFO(logger(), "Node [%s] created service client [%s]", name().c_str(),
Expand All @@ -289,8 +303,8 @@ inline bool RosServiceNode<T>::createClient(const std::string& service_name)
bool found = srv_instance_->service_client->wait_for_service(wait_for_service_timeout_);
if(!found)
{
RCLCPP_ERROR(node_->get_logger(), "%s: Service with name '%s' is not reachable.",
name().c_str(), service_name_.c_str());
RCLCPP_ERROR(logger(), "%s: Service with name '%s' is not reachable.", name().c_str(),
service_name_.c_str());
}
return found;
}
Expand Down Expand Up @@ -350,7 +364,7 @@ inline NodeStatus RosServiceNode<T>::tick()
}

future_response_ = srv_instance_->service_client->async_send_request(request).share();
time_request_sent_ = node_->now();
time_request_sent_ = now();

return NodeStatus::RUNNING;
}
Expand All @@ -371,7 +385,7 @@ inline NodeStatus RosServiceNode<T>::tick()

if(ret != rclcpp::FutureReturnCode::SUCCESS)
{
if((node_->now() - time_request_sent_) > timeout)
if((now() - time_request_sent_) > timeout)
{
return CheckStatus(onFailure(SERVICE_TIMEOUT));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class RosTopicPubNode : public BT::ConditionNode
/** You are not supposed to instantiate this class directly, the factory will do it.
* To register this class into the factory, use:
*
* RegisterRosAction<DerivedClasss>(factory, params)
* RegisterRosAction<DerivedClass>(factory, params)
*
* Note that if the external_action_client is not set, the constructor will build its own.
* */
Expand Down
15 changes: 11 additions & 4 deletions behaviortree_ros2/include/behaviortree_ros2/bt_topic_sub_node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class RosTopicSubNode : public BT::ConditionNode
return subscribers_registry;
}

std::shared_ptr<rclcpp::Node> node_;
std::weak_ptr<rclcpp::Node> node_;
std::shared_ptr<SubscriberInstance> sub_instance_;
std::shared_ptr<TopicT> last_msg_;
std::string topic_name_;
Expand All @@ -85,7 +85,11 @@ class RosTopicSubNode : public BT::ConditionNode

rclcpp::Logger logger()
{
return node_->get_logger();
if(auto node = node_.lock())
{
return node->get_logger();
}
return rclcpp::get_logger("RosTopicSubNode");
}

public:
Expand Down Expand Up @@ -244,13 +248,16 @@ inline bool RosTopicSubNode<T>::createSubscriber(const std::string& topic_name)

// find SubscriberInstance in the registry
std::unique_lock lk(registryMutex());
subscriber_key_ = std::string(node_->get_fully_qualified_name()) + "/" + topic_name;

auto shared_node = node_.lock();
subscriber_key_ =
std::string(shared_node->get_fully_qualified_name()) + "/" + topic_name;

auto& registry = getRegistry();
auto it = registry.find(subscriber_key_);
if(it == registry.end() || it->second.expired())
{
sub_instance_ = std::make_shared<SubscriberInstance>(node_, topic_name);
sub_instance_ = std::make_shared<SubscriberInstance>(shared_node, topic_name);
registry.insert({ subscriber_key_, sub_instance_ });

RCLCPP_INFO(logger(), "Node [%s] created Subscriber to topic [%s]", name().c_str(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace BT

struct RosNodeParams
{
std::shared_ptr<rclcpp::Node> nh;
std::weak_ptr<rclcpp::Node> nh;

// This has different meaning based on the context:
//
Expand Down
7 changes: 3 additions & 4 deletions btcpp_ros2_samples/src/sleep_action.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,21 @@ bool SleepAction::setGoal(RosActionNode::Goal& goal)

NodeStatus SleepAction::onResultReceived(const RosActionNode::WrappedResult& wr)
{
RCLCPP_INFO(node_->get_logger(), "%s: onResultReceived. Done = %s", name().c_str(),
RCLCPP_INFO(logger(), "%s: onResultReceived. Done = %s", name().c_str(),
wr.result->done ? "true" : "false");

return wr.result->done ? NodeStatus::SUCCESS : NodeStatus::FAILURE;
}

NodeStatus SleepAction::onFailure(ActionNodeErrorCode error)
{
RCLCPP_ERROR(node_->get_logger(), "%s: onFailure with error: %s", name().c_str(),
toStr(error));
RCLCPP_ERROR(logger(), "%s: onFailure with error: %s", name().c_str(), toStr(error));
return NodeStatus::FAILURE;
}

void SleepAction::onHalt()
{
RCLCPP_INFO(node_->get_logger(), "%s: onHalt", name().c_str());
RCLCPP_INFO(logger(), "%s: onHalt", name().c_str());
}

// Plugin registration.
Expand Down

0 comments on commit 374edcf

Please sign in to comment.