Skip to content

Commit

Permalink
Change set_value CPO to use member functions instead of tag_invoke
Browse files Browse the repository at this point in the history
  • Loading branch information
msimberg committed Nov 12, 2024
1 parent b252c0b commit 67a1976
Show file tree
Hide file tree
Showing 28 changed files with 132 additions and 198 deletions.
65 changes: 25 additions & 40 deletions libs/pika/async_cuda/include/pika/async_cuda/then_with_stream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,19 +249,21 @@ namespace pika::cuda::experimental::then_with_stream_detail {
}

template <typename... Ts>
auto set_value(Ts&&... ts) noexcept
auto set_value(Ts&&... ts) && noexcept
-> decltype(PIKA_INVOKE(PIKA_MOVE(f), op_state.sched, stream.value(), ts...),
void())
{
auto r = PIKA_MOVE(*this);
pika::detail::try_catch_exception_ptr(
[&]() mutable {
using ts_element_type = std::tuple<std::decay_t<Ts>...>;
op_state.ts.template emplace<ts_element_type>(PIKA_FORWARD(Ts, ts)...);
[[maybe_unused]] auto& t = std::get<ts_element_type>(op_state.ts);
r.op_state.ts.template emplace<ts_element_type>(
PIKA_FORWARD(Ts, ts)...);
[[maybe_unused]] auto& t = std::get<ts_element_type>(r.op_state.ts);

if (!op_state.stream)
if (!r.op_state.stream)
{
op_state.stream.emplace(op_state.sched.get_next_stream());
r.op_state.stream.emplace(r.op_state.sched.get_next_stream());
}

// If the next receiver is also a
Expand All @@ -272,11 +274,11 @@ namespace pika::cuda::experimental::then_with_stream_detail {
if constexpr (is_then_with_cuda_stream_receiver<
std::decay_t<Receiver>>::value)
{
if (op_state.sched == op_state.receiver.op_state.sched)
if (r.op_state.sched == r.op_state.receiver.op_state.sched)
{
PIKA_ASSERT(op_state.stream);
PIKA_ASSERT(!op_state.receiver.op_state.stream);
op_state.receiver.op_state.stream = op_state.stream;
PIKA_ASSERT(r.op_state.stream);
PIKA_ASSERT(!r.op_state.receiver.op_state.stream);
r.op_state.receiver.op_state.stream = r.op_state.stream;

successor_uses_same_stream = true;
}
Expand All @@ -290,8 +292,8 @@ namespace pika::cuda::experimental::then_with_stream_detail {
{
std::apply(
[&](auto&... ts) mutable {
PIKA_INVOKE(PIKA_MOVE(op_state.f), op_state.sched,
op_state.stream.value(), ts...);
PIKA_INVOKE(PIKA_MOVE(r.op_state.f), r.op_state.sched,
r.op_state.stream.value(), ts...);
},
t);

Expand All @@ -307,14 +309,14 @@ namespace pika::cuda::experimental::then_with_stream_detail {
// stream when a
// non-then_with_cuda_stream receiver is
// connected.
set_value_immediate_void(op_state);
set_value_immediate_void(r.op_state);
}
else
{
// When the streams are different, we
// add a callback which will call
// set_value on the receiver.
set_value_event_callback_void(op_state);
set_value_event_callback_void(r.op_state);
}
}
else
Expand All @@ -323,16 +325,16 @@ namespace pika::cuda::experimental::then_with_stream_detail {
// then_with_cuda_stream_receiver, we add a
// callback which will call set_value on the
// receiver.
set_value_event_callback_void(op_state);
set_value_event_callback_void(r.op_state);
}
}
else
{
std::apply(
[&](auto&... ts) mutable {
op_state.result.template emplace<invoke_result_type>(
PIKA_INVOKE(PIKA_MOVE(op_state.f), op_state.sched,
op_state.stream.value(), ts...));
r.op_state.result.template emplace<invoke_result_type>(
PIKA_INVOKE(PIKA_MOVE(r.op_state.f), r.op_state.sched,
r.op_state.stream.value(), ts...));
},
t);

Expand All @@ -348,15 +350,16 @@ namespace pika::cuda::experimental::then_with_stream_detail {
// stream when a
// non-then_with_cuda_stream receiver is
// connected.
set_value_immediate_non_void<invoke_result_type>(op_state);
set_value_immediate_non_void<invoke_result_type>(
r.op_state);
}
else
{
// When the streams are different, we
// add a callback which will call
// set_value on the receiver.
set_value_event_callback_non_void<invoke_result_type>(
op_state);
r.op_state);
}
}
else
Expand All @@ -365,13 +368,14 @@ namespace pika::cuda::experimental::then_with_stream_detail {
// then_with_cuda_stream_receiver, we add a
// callback which will call set_value on the
// receiver.
set_value_event_callback_non_void<invoke_result_type>(op_state);
set_value_event_callback_non_void<invoke_result_type>(
r.op_state);
}
}
},
[&](std::exception_ptr ep) mutable {
pika::execution::experimental::set_error(
PIKA_MOVE(op_state.receiver), PIKA_MOVE(ep));
PIKA_MOVE(r.op_state.receiver), PIKA_MOVE(ep));
});
}

Expand All @@ -383,25 +387,6 @@ namespace pika::cuda::experimental::then_with_stream_detail {
}
};

// This should be a hidden friend in then_with_cuda_stream_receiver.
// However, nvcc does not know how to compile it with some argument
// types ("error: no instance of overloaded function std::forward
// matches the argument list").
template <typename... Ts>
friend auto tag_invoke(pika::execution::experimental::set_value_t,
then_with_cuda_stream_receiver&& r, Ts&&... ts) noexcept
-> decltype(r.set_value(PIKA_FORWARD(Ts, ts)...))
{
// nvcc fails to compile this with std::forward<Ts>(ts)... or
// static_cast<Ts&&>(ts)... so we explicitly use
// static_cast<decltype(ts)>(ts)... as a workaround.
#if defined(PIKA_HAVE_CUDA)
r.set_value(static_cast<decltype(ts)&&>(ts)...);
#else
r.set_value(PIKA_FORWARD(Ts, ts)...);
#endif
}

using operation_state_type =
pika::execution::experimental::connect_result_t<std::decay_t<Sender>,
then_with_cuda_stream_receiver>;
Expand Down
4 changes: 2 additions & 2 deletions libs/pika/async_mpi/include/pika/async_mpi/dispatch_mpi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ namespace pika::mpi::experimental::detail {
// otherwise return the request by passing it to set_value
template <typename... Ts,
typename = std::enable_if_t<is_mpi_request_invocable_v<F, Ts...>>>
friend constexpr void
tag_invoke(ex::set_value_t, dispatch_mpi_receiver r, Ts&&... ts) noexcept
constexpr void set_value(Ts&&... ts) && noexcept
{
auto r = PIKA_MOVE(*this);
pika::detail::try_catch_exception_ptr(
[&]() mutable {
using invoke_result_type = mpi_request_invoke_result_t<F, Ts...>;
Expand Down
5 changes: 3 additions & 2 deletions libs/pika/async_mpi/include/pika/async_mpi/trigger_mpi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,10 @@ namespace pika::mpi::experimental::detail {

// receive the MPI Request and set a callback to be
// triggered when the mpi request completes
friend constexpr void tag_invoke(
ex::set_value_t, trigger_mpi_receiver r, MPI_Request request) noexcept
constexpr void set_value(MPI_Request request) && noexcept
{
auto r = PIKA_MOVE(*this);

// early exit check
if (request == MPI_REQUEST_NULL)
{
Expand Down
19 changes: 3 additions & 16 deletions libs/pika/execution/include/pika/execution/algorithms/bulk.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,12 @@ namespace pika::bulk_detail {
}

template <typename... Ts>
void set_value(Ts&&... ts)
void set_value(Ts&&... ts) && noexcept
{
auto r = PIKA_MOVE(*this);
pika::detail::try_catch_exception_ptr(
[&]() {
for (auto const& s : shape) { PIKA_INVOKE(f, s, ts...); }
for (auto const& s : r.shape) { PIKA_INVOKE(r.f, s, ts...); }
pika::execution::experimental::set_value(
PIKA_MOVE(receiver), PIKA_FORWARD(Ts, ts)...);
},
Expand All @@ -108,20 +109,6 @@ namespace pika::bulk_detail {
PIKA_MOVE(receiver), PIKA_MOVE(ep));
});
}

template <typename... Ts>
friend auto tag_invoke(
pika::execution::experimental::set_value_t, bulk_receiver&& r, Ts&&... ts) noexcept
-> decltype(pika::execution::experimental::set_value(
std::declval<std::decay_t<Receiver>&&>(), PIKA_FORWARD(Ts, ts)...),
void())
{
// set_value is in a member function only because of a
// compiler bug in GCC 7. When the body of set_value is
// inlined here compilation fails with an internal compiler
// error.
r.set_value(PIKA_FORWARD(Ts, ts)...);
}
};

template <typename Receiver>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,10 @@ namespace pika::drop_op_state_detail {
};

template <typename... Ts>
friend void tag_invoke(pika::execution::experimental::set_value_t,
drop_op_state_receiver_type r, Ts&&... ts) noexcept
void set_value(Ts&&... ts) && noexcept
{
auto r = PIKA_MOVE(*this);

PIKA_ASSERT(r.op_state != nullptr);
PIKA_ASSERT(r.op_state->op_state.has_value());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ namespace pika::drop_value_detail {
}

template <typename... Ts>
friend void tag_invoke(pika::execution::experimental::set_value_t,
drop_value_receiver_type&& r, Ts&&...) noexcept
void set_value(Ts&&...) && noexcept
{
auto r = PIKA_MOVE(*this);
pika::execution::experimental::set_value(PIKA_MOVE(r.receiver));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,14 +236,14 @@ namespace pika::ensure_started_detail {
#endif

template <typename... Ts>
friend auto tag_invoke(pika::execution::experimental::set_value_t,
ensure_started_receiver r, Ts&&... ts) noexcept
auto set_value(Ts&&... ts) && noexcept
-> decltype(std::declval<
pika::detail::variant<pika::detail::monostate, value_type>>()
.template emplace<value_type>(
std::make_tuple<>(PIKA_FORWARD(Ts, ts)...)),
void())
{
auto r = PIKA_MOVE(*this);
r.state->v.template emplace<value_type>(
std::make_tuple<>(PIKA_FORWARD(Ts, ts)...));
r.state->set_predecessor_done();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,9 @@ namespace pika::let_error_detail {
template <typename... Ts,
typename = std::enable_if_t<std::is_invocable_v<
pika::execution::experimental::set_value_t, Receiver&&, Ts...>>>
friend void tag_invoke(pika::execution::experimental::set_value_t,
let_error_predecessor_receiver&& r, Ts&&... ts) noexcept
void set_value(Ts&&... ts) && noexcept
{
auto r = PIKA_MOVE(*this);
pika::execution::experimental::set_value(
PIKA_MOVE(r.receiver), PIKA_FORWARD(Ts, ts)...);
}
Expand Down
30 changes: 10 additions & 20 deletions libs/pika/execution/include/pika/execution/algorithms/let_value.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,37 +231,27 @@ namespace pika::let_value_detail {
pika::detail::monostate>;

template <typename... Ts>
void set_value(Ts&&... ts)
auto set_value(Ts&&... ts) && noexcept
-> decltype(std::declval<predecessor_ts_type>()
.template emplace<std::tuple<std::decay_t<Ts>...>>(
PIKA_FORWARD(Ts, ts)...),
void())
{
auto r = PIKA_MOVE(*this);
pika::detail::try_catch_exception_ptr(
[&]() {
op_state.predecessor_ts
r.op_state.predecessor_ts
.template emplace<std::tuple<std::decay_t<Ts>...>>(
PIKA_FORWARD(Ts, ts)...);
pika::detail::visit(
set_value_visitor{PIKA_MOVE(receiver), PIKA_MOVE(f), op_state},
op_state.predecessor_ts);
set_value_visitor{PIKA_MOVE(r.receiver), PIKA_MOVE(f), r.op_state},
r.op_state.predecessor_ts);
},
[&](std::exception_ptr ep) {
pika::execution::experimental::set_error(
PIKA_MOVE(receiver), PIKA_MOVE(ep));
PIKA_MOVE(r.receiver), PIKA_MOVE(ep));
});
}

template <typename... Ts>
friend auto tag_invoke(pika::execution::experimental::set_value_t,
let_value_predecessor_receiver&& r, Ts&&... ts) noexcept
-> decltype(std::declval<predecessor_ts_type>()
.template emplace<std::tuple<std::decay_t<Ts>...>>(
PIKA_FORWARD(Ts, ts)...),
void())
{
// set_value is in a member function only because of a
// compiler bug in GCC 7. When the body of set_value is
// inlined here compilation fails with an internal
// compiler error.
r.set_value(PIKA_FORWARD(Ts, ts)...);
}
};

template <typename PredecessorSender_, typename Receiver_, typename F_>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,9 @@ namespace pika {
};

template <typename... Ts>
friend void tag_invoke(pika::execution::experimental::set_value_t,
require_started_receiver_type r, Ts&&... ts) noexcept
void set_value(Ts&&... ts) && noexcept
{
auto r = PIKA_MOVE(*this);
PIKA_ASSERT(r.op_state != nullptr);
pika::execution::experimental::set_value(
PIKA_MOVE(r.op_state->receiver), PIKA_FORWARD(Ts, ts)...);
Expand Down Expand Up @@ -381,8 +381,7 @@ namespace pika {

s.connected = true;
return
{
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
{ // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
*std::exchange(s.sender, std::nullopt), PIKA_FORWARD(Receiver, receiver)
#if defined(PIKA_DETAIL_HAVE_REQUIRE_STARTED_MODE)
,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,13 +181,13 @@ namespace pika::schedule_from_detail {
pika::detail::monostate>;

template <typename... Ts>
friend auto tag_invoke(pika::execution::experimental::set_value_t,
predecessor_sender_receiver&& r, Ts&&... ts) noexcept
auto set_value(Ts&&... ts) && noexcept
-> decltype(std::declval<value_type>()
.template emplace<std::tuple<std::decay_t<Ts>...>>(
PIKA_FORWARD(Ts, ts)...),
void())
{
auto r = PIKA_MOVE(*this);
// nvcc fails to compile this with std::forward<Ts>(ts)...
// or static_cast<Ts&&>(ts)... so we explicitly use
// static_cast<decltype(ts)>(ts)... as a workaround.
Expand Down Expand Up @@ -252,9 +252,9 @@ namespace pika::schedule_from_detail {
r.op_state.set_stopped_scheduler_sender();
}

friend void tag_invoke(pika::execution::experimental::set_value_t,
scheduler_sender_receiver&& r) noexcept
void set_value() && noexcept
{
auto r = PIKA_MOVE(*this);
r.op_state.set_value_scheduler_sender();
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,14 +190,14 @@ namespace pika::split_detail {
value_type_helper>;

template <typename... Ts>
friend auto tag_invoke(pika::execution::experimental::set_value_t, split_receiver r,
Ts&&... ts) noexcept
auto set_value(Ts&&... ts) && noexcept
-> decltype(std::declval<
pika::detail::variant<pika::detail::monostate, value_type>>()
.template emplace<value_type>(
std::make_tuple<>(PIKA_FORWARD(Ts, ts)...)),
void())
{
auto r = PIKA_MOVE(*this);
r.state->v.template emplace<value_type>(
std::make_tuple<>(PIKA_FORWARD(Ts, ts)...));

Expand Down
Loading

0 comments on commit 67a1976

Please sign in to comment.