Skip to content

Commit

Permalink
Changed observer for initial std
Browse files Browse the repository at this point in the history
  • Loading branch information
SamueleSandrini committed Aug 26, 2024
1 parent 6bb78db commit b252d63
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ class ActionObservedCostClient : public plansys2::ActionExecutorClient
state_observer_params_loader_;
state_observer::StateObserverParam::SharedPtr state_observer_params_;
std::shared_ptr<pluginlib::ClassLoader<state_observer::StateObserver>> state_observer_loader_;
std::shared_ptr<state_observer::StateObserver> state_observer_;
std::string state_observer_plugin_name_;

// parameters
Expand All @@ -86,6 +85,8 @@ class ActionObservedCostClient : public plansys2::ActionExecutorClient
void save_updated_fluent(const std::string & updated_fluent);
void save_updated_problem(const std::string & updated_problem);

std::shared_ptr<state_observer::StateObserver> load_state_observer();


// plansys2::msg::ActionExecutionDataCollectionPtr data_collection_;
};
Expand Down
70 changes: 44 additions & 26 deletions src/action_observed_cost_client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,32 +171,16 @@ ActionObservedCostClient::finish(
auto arguments_hash = get_arguments_hash();

if (observed_action_cost_.find(arguments_hash) == observed_action_cost_.end()) {
std::shared_ptr<state_observer::StateObserver> state_observer;
try {
state_observer =
state_observer_loader_->createSharedInstance(
state_observer_plugin_name_);
} catch (pluginlib::PluginlibException & ex) {
RCLCPP_ERROR(
get_logger(), "The plugin: %s, failed to load for some reason. Error: %s", state_observer_plugin_name_,
ex.what());
throw std::runtime_error("The plugin failed to load for some reason.");
}

try {
state_observer->set_parameters(
state_observer_params_);
} catch (const std::exception & e) {
RCLCPP_ERROR(get_logger(), "Exception caught in state observer set_parameters: %s", e.what());
throw std::runtime_error("The plugin failed to set parameters.");
}

observed_action_cost_[arguments_hash] = state_observer;
observed_action_cost_[arguments_hash]->initialize(residual);
RCLCPP_INFO(get_logger(), "Initialized observer");
RCLCPP_WARN(get_logger(), "Observer not found for arguments hash %s", arguments_hash.c_str());
observed_action_cost_[arguments_hash] = load_state_observer();
} else {
observed_action_cost_[arguments_hash]->update(residual);
RCLCPP_INFO(get_logger(), "Updated observer");
if (observed_action_cost_[arguments_hash]->is_initialized()) {
observed_action_cost_[arguments_hash]->update(residual);
RCLCPP_INFO(get_logger(), "Updated observer");
} else {
observed_action_cost_[arguments_hash]->initialize(residual);
RCLCPP_INFO(get_logger(), "Observer initialized");
}
RCLCPP_INFO(get_logger(), "State %f", observed_action_cost_[arguments_hash]->get_state()[0]);
RCLCPP_INFO(get_logger(), "Output %f", observed_action_cost_[arguments_hash]->get_output()[0]);
}
Expand All @@ -206,7 +190,6 @@ ActionObservedCostClient::finish(
// check if arguments_hash is in observed_action_cost_
if (observed_action_cost_.find(arguments_hash) == observed_action_cost_.end()) {
RCLCPP_INFO(get_logger(), "Observer not found for arguments hash %s", arguments_hash.c_str());

}
//check dim of get_state()
if (action_cost_) {
Expand Down Expand Up @@ -275,7 +258,16 @@ ActionObservedCostClient::send_response(
msg_resp.action_cost.std_dev_cost =
observed_action_cost_[arguments_hash]->get_state_variance()[0];
} else {

observed_action_cost_[arguments_hash] = load_state_observer();

msg_resp.action_cost = *action_cost_;
try {
msg_resp.action_cost.std_dev_cost =
observed_action_cost_[arguments_hash]->get_state_variance()[0];
} catch (const std::exception & e) {
RCLCPP_INFO(get_logger(), "This state observer does not provide state variance");
}
}
action_hub_pub_->publish(msg_resp);

Expand Down Expand Up @@ -437,4 +429,30 @@ ActionObservedCostClient::should_execute(
return true;
}

std::shared_ptr<state_observer::StateObserver>
ActionObservedCostClient::load_state_observer()
{
std::shared_ptr<state_observer::StateObserver> state_observer;
try {
state_observer =
state_observer_loader_->createSharedInstance(
state_observer_plugin_name_);
} catch (pluginlib::PluginlibException & ex) {
RCLCPP_ERROR(
get_logger(), "The plugin: %s, failed to load for some reason. Error: %s", state_observer_plugin_name_,
ex.what());
throw std::runtime_error("The plugin failed to load for some reason.");
}

try {
state_observer->set_parameters(
state_observer_params_);
} catch (const std::exception & e) {
RCLCPP_ERROR(get_logger(), "Exception caught in state observer set_parameters: %s", e.what());
throw std::runtime_error("The plugin failed to set parameters.");
}

return state_observer;
}

} // namespace plansys2_actions_clients

0 comments on commit b252d63

Please sign in to comment.