Skip to content

Commit

Permalink
[fiber] Implement std concurrency interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
salkinium committed May 6, 2024
1 parent 03f5696 commit 72ab66d
Show file tree
Hide file tree
Showing 20 changed files with 1,682 additions and 142 deletions.
4 changes: 4 additions & 0 deletions ext/gcc/assert.cpp.in
Original file line number Diff line number Diff line change
Expand Up @@ -76,5 +76,9 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
void
__throw_bad_any_cast()
{ __modm_stdcpp_failure("bad_any_cast"); }

void
__throw_system_error(int errc __attribute__((unused)))
{ __modm_stdcpp_failure("system_error"); }
_GLIBCXX_END_NAMESPACE_VERSION
} // namespace
85 changes: 85 additions & 0 deletions src/modm/processing/fiber/barrier.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright (c) 2024, Niklas Hauser
*
* This file is part of the modm project.
*
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/.
*/
// ----------------------------------------------------------------------------

#pragma once

#include "fiber.hpp"

namespace modm::fiber
{

/// @ingroup modm_processing_fiber
/// @{

/// Implements the `std::barrier` interface for fibers.
/// @warning This implementation is not interrupt-safe!
/// @see https://en.cppreference.com/w/cpp/thread/barrier
template< class CompletionFunction = decltype([]{}) >
class barrier
{
barrier(const barrier&) = delete;
barrier& operator=(const barrier&) = delete;
using count_t = uint16_t;

const CompletionFunction completion;
count_t expected;
count_t count;
volatile count_t sequence{};

public:
using arrival_token = count_t;

constexpr explicit
barrier(std::ptrdiff_t expected, CompletionFunction f = CompletionFunction())
: completion(std::move(f)), expected(expected), count(expected) {}

[[nodiscard]]
static constexpr std::ptrdiff_t
max() { return count_t(-1); }

[[nodiscard]]
arrival_token
arrive(count_t n=1)
{
count_t last_arrival{sequence};
if (n < count) count -= n;
else
{
count = expected;
sequence++;
completion();
}
return last_arrival;
}

void
wait(arrival_token arrival) const
{
while (arrival == sequence) modm::this_fiber::yield();
}

void
arrive_and_wait()
{
wait(arrive());
}

void
arrive_and_drop()
{
if (expected) expected--;
(void) arrive();
}
};

/// @}

}
178 changes: 178 additions & 0 deletions src/modm/processing/fiber/condition_variable.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
/*
* Copyright (c) 2024, Niklas Hauser
*
* This file is part of the modm project.
*
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/.
*/
// ----------------------------------------------------------------------------

#pragma once

#include "fiber.hpp"
#include "stop_token.hpp"
#include <atomic>


namespace modm::fiber
{

/// @ingroup modm_processing_fiber
/// @{

enum class
cv_status
{
no_timeout,
timeout
};

/// Implements the `std::condition_variable_any` interface for fibers.
/// @see https://en.cppreference.com/w/cpp/thread/condition_variable
class condition_variable_any
{
condition_variable_any(const condition_variable_any&) = delete;
condition_variable_any& operator=(const condition_variable_any&) = delete;

std::atomic<uint16_t> sequence{};

const auto inline condition()
{
return [this, poll_sequence = sequence.load(std::memory_order_relaxed)]
{ return poll_sequence != sequence.load(std::memory_order_relaxed); };
}
public:
constexpr condition_variable_any() = default;

/// @note This function can be called from an interrupt.
void inline
notify_one()
{
sequence.fetch_add(1, std::memory_order_acquire);
}

/// @note This function can be called from an interrupt.
void inline
notify_any()
{
notify_one();
}


template< class Lock >
void
wait(Lock& lock)
{
lock.unlock();
this_fiber::poll(condition());
lock.lock();
}

template< class Lock, class Predicate >
requires requires { std::is_invocable_r_v<bool, Predicate, void>; }
void
wait(Lock& lock, Predicate&& pred)
{
while (not std::forward<Predicate>(pred)()) wait(lock);
}

template< class Lock, class Predicate >
requires requires { std::is_invocable_r_v<bool, Predicate, void>; }
bool
wait(Lock& lock, stop_token stoken, Predicate&& pred)
{
while (not stoken.stop_requested())
{
if (std::forward<Predicate>(pred)()) return true;
wait(lock);
}
return std::forward<Predicate>(pred)();
}


template< class Lock, class Rep, class Period >
cv_status
wait_for(Lock& lock, std::chrono::duration<Rep, Period> rel_time)
{
lock.unlock();
const bool result = this_fiber::poll_for(rel_time, condition());
lock.lock();
return result ? cv_status::no_timeout : cv_status::timeout;
}

template< class Lock, class Rep, class Period, class Predicate >
requires requires { std::is_invocable_r_v<bool, Predicate, void>; }
bool
wait_for(Lock& lock, std::chrono::duration<Rep, Period> rel_time, Predicate&& pred)
{
while (not std::forward<Predicate>(pred)())
{
if (wait_for(lock, rel_time) == cv_status::timeout)
return std::forward<Predicate>(pred)();
}
return true;
}

template< class Lock, class Rep, class Period, class Predicate >
requires requires { std::is_invocable_r_v<bool, Predicate, void>; }
bool
wait_for(Lock& lock, stop_token stoken,
std::chrono::duration<Rep, Period> rel_time, Predicate&& pred)
{
while (not stoken.stop_requested())
{
if (std::forward<Predicate>(pred)()) return true;
if (wait_for(lock, rel_time) == cv_status::timeout)
return std::forward<Predicate>(pred)();
}
return std::forward<Predicate>(pred)();
}


template< class Lock, class Clock, class Duration >
cv_status
wait_until(Lock& lock, std::chrono::time_point<Clock, Duration> abs_time)
{
lock.unlock();
const bool result = this_fiber::poll_until(abs_time, condition());
lock.lock();
return result ? cv_status::no_timeout : cv_status::timeout;
}

template< class Lock, class Clock, class Duration, class Predicate >
requires requires { std::is_invocable_r_v<bool, Predicate, void>; }
bool
wait_until(Lock& lock, std::chrono::time_point<Clock, Duration> abs_time, Predicate&& pred)
{
while (not std::forward<Predicate>(pred)())
{
if (wait_until(lock, abs_time) == cv_status::timeout)
return std::forward<Predicate>(pred)();
}
return true;
}

template< class Lock, class Clock, class Duration, class Predicate >
requires requires { std::is_invocable_r_v<bool, Predicate, void>; }
bool
wait_until(Lock& lock, stop_token stoken,
std::chrono::time_point<Clock, Duration> abs_time, Predicate&& pred)
{
while (not stoken.stop_requested())
{
if (std::forward<Predicate>(pred)()) return true;
if (wait_until(lock, abs_time) == cv_status::timeout)
return std::forward<Predicate>(pred)();
}
return std::forward<Predicate>(pred)();
}
};

/// There is no specialization for `std::unique_lock<fiber::mutex>`.
using condition_variable = condition_variable_any;

/// @}

}
77 changes: 77 additions & 0 deletions src/modm/processing/fiber/latch.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright (c) 2024, Niklas Hauser
*
* This file is part of the modm project.
*
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/.
*/
// ----------------------------------------------------------------------------

#pragma once

#include "fiber.hpp"
#include <atomic>

namespace modm::fiber
{

/// @ingroup modm_processing_fiber
/// @{

/// Implements the `std::latch` interface for fibers.
/// @see https://en.cppreference.com/w/cpp/thread/latch
class latch
{
latch(const latch&) = delete;
latch& operator=(const latch&) = delete;

using count_t = uint16_t;
std::atomic<count_t> count;

public:
constexpr explicit
latch(std::ptrdiff_t expected)
: count(expected) {}

[[nodiscard]]
static constexpr std::ptrdiff_t
max() { return count_t(-1); }

/// @note This function can be called from an interrupt.
void inline
count_down(count_t n=1)
{
// ensure we do not underflow the counter!
count_t value = count.load(std::memory_order_relaxed);
do if (value == 0) return;
while (not count.compare_exchange_weak(value, value >= n ? value - n : 0,
std::memory_order_acquire, std::memory_order_relaxed));
}

/// @note This function can be called from an interrupt.
[[nodiscard]]
bool inline
try_wait() const
{
return count.load(std::memory_order_relaxed) == 0;
}

void inline
wait() const
{
while(not try_wait()) modm::this_fiber::yield();
}

void inline
arrive_and_wait(std::ptrdiff_t n=1)
{
count_down(n);
wait();
}
};

/// @}

}
11 changes: 10 additions & 1 deletion src/modm/processing/fiber/module.lb
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def is_enabled(env):
not env.has_module(":processing:protothread")

def prepare(module, options):
module.depends(":processing:timer")
module.depends(":processing:timer", ":architecture:atomic")

module.add_query(
EnvironmentQuery(name="__enabled", factory=is_enabled))
Expand All @@ -47,6 +47,7 @@ def build(env):
"with_fpu": with_fpu,
"target": env[":target"].identifier,
"multicore": env.has_module(":platform:multicore"),
"num_cores": 1,
}
if env.has_module(":platform:multicore"):
cores = int(env[":target"].identifier.cores)
Expand Down Expand Up @@ -78,3 +79,11 @@ def build(env):
env.copy("task.hpp")
env.copy("functions.hpp")
env.copy("fiber.hpp")

env.copy("mutex.hpp")
env.copy("shared_mutex.hpp")
env.copy("semaphore.hpp")
env.copy("latch.hpp")
env.copy("barrier.hpp")
env.copy("stop_token.hpp")
env.copy("condition_variable.hpp")
Loading

0 comments on commit 72ab66d

Please sign in to comment.