diff --git a/Cargo.lock b/Cargo.lock index 5218d9a65..945d754cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4576,6 +4576,16 @@ dependencies = [ "riot-rs-boards", ] +[[package]] +name = "threading-mutex" +version = "0.1.0" +dependencies = [ + "embassy-executor", + "portable-atomic", + "riot-rs", + "riot-rs-boards", +] + [[package]] name = "time" version = "0.3.36" diff --git a/Cargo.toml b/Cargo.toml index 2f02471f0..00c2c3bcf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ members = [ "tests/i2c-controller", "tests/threading-dynamic-prios", "tests/threading-lock", + "tests/threading-mutex", ] exclude = ["src/lib"] diff --git a/src/riot-rs-threads/src/lib.rs b/src/riot-rs-threads/src/lib.rs index 7de6111f0..83aa54716 100644 --- a/src/riot-rs-threads/src/lib.rs +++ b/src/riot-rs-threads/src/lib.rs @@ -24,6 +24,7 @@ #![cfg_attr(not(test), no_std)] #![feature(naked_functions)] #![feature(used_with_arg)] +#![feature(negative_impls)] #![cfg_attr(target_arch = "xtensa", feature(asm_experimental_arch))] // Disable indexing lints for now, possible panics are documented or rely on internally-enforced // invariants diff --git a/src/riot-rs-threads/src/sync/mod.rs b/src/riot-rs-threads/src/sync/mod.rs index eadb8fee0..7a9376573 100644 --- a/src/riot-rs-threads/src/sync/mod.rs +++ b/src/riot-rs-threads/src/sync/mod.rs @@ -1,6 +1,8 @@ //! Synchronization primitives. mod channel; mod lock; +mod mutex; pub use channel::Channel; pub use lock::Lock; +pub use mutex::{Mutex, MutexGuard}; diff --git a/src/riot-rs-threads/src/sync/mutex.rs b/src/riot-rs-threads/src/sync/mutex.rs new file mode 100644 index 000000000..31d9664e0 --- /dev/null +++ b/src/riot-rs-threads/src/sync/mutex.rs @@ -0,0 +1,205 @@ +use core::{ + cell::UnsafeCell, + ops::{Deref, DerefMut}, +}; + +use critical_section::CriticalSection; +use riot_rs_runqueue::{RunqueueId, ThreadId}; + +use crate::{thread::ThreadState, threadlist::ThreadList, THREADS}; + +/// A basic mutex with priority inheritance. 
+pub struct Mutex<T> {
+    state: UnsafeCell<LockState>,
+    inner: UnsafeCell<T>,
+}
+
+/// State of a [`Mutex`].
+enum LockState {
+    Unlocked,
+    Locked {
+        /// The current owner of the lock.
+        owner_id: ThreadId,
+        /// The original priority of the current owner (without priority inheritance).
+        owner_prio: RunqueueId,
+        /// Waiters for the mutex.
+        waiters: ThreadList,
+    },
+}
+
+impl LockState {
+    /// Returns a [`LockState::Locked`] with the current thread as the owner
+    /// and an empty waitlist.
+    ///
+    /// # Panics
+    ///
+    /// Panics if called outside of a thread context.
+    fn locked_with_current(cs: CriticalSection) -> Self {
+        let (owner_id, owner_prio) = THREADS.with_mut_cs(cs, |mut threads| {
+            let current = threads
+                .current()
+                .expect("Function should be called inside a thread context.");
+            (current.pid, current.prio)
+        });
+        LockState::Locked {
+            waiters: ThreadList::new(),
+            owner_id,
+            owner_prio,
+        }
+    }
+}
+
+impl<T> Mutex<T> {
+    /// Creates a new **unlocked** [`Mutex`].
+    pub const fn new(value: T) -> Self {
+        Self {
+            state: UnsafeCell::new(LockState::Unlocked),
+            inner: UnsafeCell::new(value),
+        }
+    }
+}
+
+impl<T> Mutex<T> {
+    /// Returns whether the mutex is locked.
+    pub fn is_locked(&self) -> bool {
+        critical_section::with(|_| {
+            let state = unsafe { &*self.state.get() };
+            !matches!(state, LockState::Unlocked)
+        })
+    }
+
+    /// Acquires a mutex, blocking the current thread until it is able to do so.
+    ///
+    /// If the mutex was unlocked, it will be locked and a [`MutexGuard`] is returned.
+    /// If the mutex is locked, this function will block the current thread until the mutex gets
+    /// unlocked elsewhere.
+    ///
+    /// If the current owner of the mutex has a lower priority than the current thread, it will inherit
+    /// the waiting thread's priority.
+    /// The priority is reset once the mutex is released. This means that a **user can not change a thread's
+    /// priority while it holds the lock**, because it will be changed back after release!
+    ///
+    /// # Panics
+    ///
+    /// Panics if called outside of a thread context.
+    pub fn lock(&self) -> MutexGuard<T> {
+        critical_section::with(|cs| {
+            // SAFETY: access to the state only happens in critical sections, so it's always unique.
+            let state = unsafe { &mut *self.state.get() };
+            match state {
+                LockState::Unlocked => {
+                    *state = LockState::locked_with_current(cs);
+                }
+                LockState::Locked {
+                    waiters,
+                    owner_id,
+                    owner_prio,
+                } => {
+                    // Insert thread in waitlist, which also triggers the scheduler.
+                    match waiters.put_current(cs, ThreadState::LockBlocked) {
+                        // `Some` when the inserted thread is the highest priority
+                        // thread in the waitlist.
+                        Some(waiter_prio) if waiter_prio > *owner_prio => {
+                            // Current mutex owner inherits the priority.
+                            THREADS.with_mut_cs(cs, |mut threads| {
+                                threads.set_priority(*owner_id, waiter_prio)
+                            });
+                        }
+                        _ => {}
+                    }
+                    // Context switch happens here as soon as we leave the critical section.
+                }
+            }
+        });
+        // Mutex was either directly acquired because it was unlocked, or the current thread was entered
+        // to the waitlist. In the latter case, it only continues running here after it was popped again
+        // from the waitlist and the thread acquired the mutex.
+
+        MutexGuard { mutex: self }
+    }
+
+    /// Attempts to acquire this lock, in a non-blocking fashion.
+    ///
+    /// If the mutex was unlocked, it will be locked and a [`MutexGuard`] is returned.
+    /// If the mutex was locked `None` is returned.
+    pub fn try_lock(&self) -> Option<MutexGuard<T>> {
+        critical_section::with(|cs| {
+            // SAFETY: access to the state only happens in critical sections, so it's always unique.
+            let state = unsafe { &mut *self.state.get() };
+            if let LockState::Unlocked = *state {
+                *state = LockState::locked_with_current(cs);
+                Some(MutexGuard { mutex: self })
+            } else {
+                None
+            }
+        })
+    }
+
+    /// Releases the mutex.
+    ///
+    /// If there are waiters, the first waiter will be woken up.
+    fn release(&self) {
+        critical_section::with(|cs| {
+            // SAFETY: access to the state only happens in critical sections, so it's always unique.
+            let state = unsafe { &mut *self.state.get() };
+            if let LockState::Locked {
+                waiters,
+                owner_id,
+                owner_prio,
+            } = state
+            {
+                // Reset original priority of owner.
+                THREADS.with_mut_cs(cs, |mut threads| {
+                    threads.set_priority(*owner_id, *owner_prio)
+                });
+                // Pop next thread from waitlist so that it can acquire the mutex.
+                if let Some((pid, _)) = waiters.pop(cs) {
+                    THREADS.with_mut_cs(cs, |threads| {
+                        *owner_id = pid;
+                        *owner_prio = threads.get_unchecked(pid).prio;
+                    })
+                } else {
+                    // Unlock if waitlist was empty.
+                    *state = LockState::Unlocked
+                }
+            }
+        })
+    }
+}
+
+unsafe impl<T> Sync for Mutex<T> {}
+
+/// Grants access to the [`Mutex`] inner data.
+///
+/// Dropping the [`MutexGuard`] will unlock the [`Mutex`].
+pub struct MutexGuard<'a, T> {
+    mutex: &'a Mutex<T>,
+}
+
+impl<'a, T> Deref for MutexGuard<'a, T> {
+    type Target = T;
+
+    fn deref(&self) -> &Self::Target {
+        // SAFETY: MutexGuard always has unique access.
+        unsafe { &*self.mutex.inner.get() }
+    }
+}
+
+impl<'a, T> DerefMut for MutexGuard<'a, T> {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        // SAFETY: MutexGuard always has unique access.
+        unsafe { &mut *self.mutex.inner.get() }
+    }
+}
+
+impl<'a, T> Drop for MutexGuard<'a, T> {
+    fn drop(&mut self) {
+        // Unlock the mutex when the guard is dropped.
+        self.mutex.release()
+    }
+}
+
+// The [`MutexGuard`] is tied to a thread, it must not be possible to `Send` it to another thread.
+impl<T> !Send for MutexGuard<'_, T> {}
+
+unsafe impl<T> Sync for MutexGuard<'_, T> {}
diff --git a/src/riot-rs-threads/src/threadlist.rs b/src/riot-rs-threads/src/threadlist.rs
index 4dc9495a1..49c473579 100644
--- a/src/riot-rs-threads/src/threadlist.rs
+++ b/src/riot-rs-threads/src/threadlist.rs
@@ -1,12 +1,12 @@
 use critical_section::CriticalSection;
 
-use crate::{thread::Thread, ThreadId, ThreadState, THREADS};
+use crate::{thread::Thread, RunqueueId, ThreadId, ThreadState, THREADS};
 
 /// Manages blocked [`super::Thread`]s for a resource, and triggering the scheduler when needed.
 #[derive(Debug, Default)]
 pub struct ThreadList {
     /// Next thread to run once the resource is available.
-    pub head: Option<ThreadId>,
+    head: Option<ThreadId>,
 }
 
 impl ThreadList {
@@ -17,10 +17,12 @@ impl ThreadList {
 
     /// Puts the current (blocked) thread into this [`ThreadList`] and triggers the scheduler.
     ///
+    /// Returns a `RunqueueId` if the highest priority among the waiters in the list has changed.
+    ///
     /// # Panics
     ///
    /// Panics if this is called outside of a thread context.
-    pub fn put_current(&mut self, cs: CriticalSection, state: ThreadState) {
+    pub fn put_current(&mut self, cs: CriticalSection, state: ThreadState) -> Option<RunqueueId> {
         THREADS.with_mut_cs(cs, |mut threads| {
             let &mut Thread { pid, prio, .. } = threads
                 .current()
@@ -35,12 +37,19 @@ impl ThreadList {
                 next = threads.thread_blocklist[usize::from(n)];
             }
             threads.thread_blocklist[usize::from(pid)] = next;
-            match curr {
-                Some(curr) => threads.thread_blocklist[usize::from(curr)] = Some(pid),
-                _ => self.head = Some(pid),
-            }
+            let inherit_priority = match curr {
+                Some(curr) => {
+                    threads.thread_blocklist[usize::from(curr)] = Some(pid);
+                    None
+                }
+                None => {
+                    self.head = Some(pid);
+                    Some(prio)
+                }
+            };
             threads.set_state(pid, state);
-        });
+            inherit_priority
+        })
     }
 
     /// Removes the head from this [`ThreadList`].
diff --git a/tests/laze.yml b/tests/laze.yml
index c8792e3ff..755c01a7d 100644
--- a/tests/laze.yml
+++ b/tests/laze.yml
@@ -6,3 +6,4 @@ subdirs:
   - i2c-controller
   - threading-dynamic-prios
   - threading-lock
+  - threading-mutex
diff --git a/tests/threading-mutex/Cargo.toml b/tests/threading-mutex/Cargo.toml
new file mode 100644
index 000000000..1db1a1690
--- /dev/null
+++ b/tests/threading-mutex/Cargo.toml
@@ -0,0 +1,13 @@
+[package]
+name = "threading-mutex"
+version = "0.1.0"
+authors = ["Elena Frank "]
+license.workspace = true
+edition.workspace = true
+publish = false
+
+[dependencies]
+embassy-executor = { workspace = true }
+riot-rs = { path = "../../src/riot-rs", features = ["threading"] }
+riot-rs-boards = { path = "../../src/riot-rs-boards" }
+portable-atomic = "1.6.0"
diff --git a/tests/threading-mutex/laze.yml b/tests/threading-mutex/laze.yml
new file mode 100644
index 000000000..93222643f
--- /dev/null
+++ b/tests/threading-mutex/laze.yml
@@ -0,0 +1,5 @@
+apps:
+  - name: threading-mutex
+    selects:
+      - ?release
+      - sw/threading
diff --git a/tests/threading-mutex/src/main.rs b/tests/threading-mutex/src/main.rs
new file mode 100644
index 000000000..93ea07176
--- /dev/null
+++ b/tests/threading-mutex/src/main.rs
@@ -0,0 +1,106 @@
+#![no_main]
+#![no_std]
+#![feature(type_alias_impl_trait)]
+#![feature(used_with_arg)]
+
+use portable_atomic::{AtomicUsize, Ordering};
+use riot_rs::thread::{sync::Mutex, thread_flags, RunqueueId, ThreadId};
+
+static MUTEX: Mutex<usize> = Mutex::new(0);
+static RUN_ORDER: AtomicUsize = AtomicUsize::new(0);
+
+#[riot_rs::thread(autostart, priority = 1)]
+fn thread0() {
+    let pid = riot_rs::thread::current_pid().unwrap();
+    assert_eq!(riot_rs::thread::get_priority(pid), Some(RunqueueId::new(1)));
+
+    assert_eq!(RUN_ORDER.fetch_add(1, Ordering::AcqRel), 0);
+
+    let mut counter = MUTEX.lock();
+
+    // Unblock other threads in the order of their IDs.
+ // + // Because all other threads have higher priorities, setting + // a flag will each time cause a context switch and give each + // thread the chance to run and try acquire the lock. + thread_flags::set(ThreadId::new(1), 0b1); + // Inherit prio of higher prio waiting thread. + assert_eq!( + riot_rs::thread::get_priority(pid), + riot_rs::thread::get_priority(ThreadId::new(1)), + ); + thread_flags::set(ThreadId::new(2), 0b1); + // Inherit prio of highest waiting thread. + assert_eq!( + riot_rs::thread::get_priority(pid), + riot_rs::thread::get_priority(ThreadId::new(2)), + ); + thread_flags::set(ThreadId::new(3), 0b1); + // Still has priority of highest waiting thread. + assert_eq!( + riot_rs::thread::get_priority(pid), + riot_rs::thread::get_priority(ThreadId::new(2)), + ); + + assert_eq!(*counter, 0); + *counter += 1; + + drop(counter); + + // Return to old prio. + assert_eq!(riot_rs::thread::get_priority(pid), Some(RunqueueId::new(1))); + + // Wait for other threads to complete. + thread_flags::wait_all(0b111); + + assert_eq!(*MUTEX.lock(), 4); + riot_rs::debug::log::info!("Test passed!"); +} + +#[riot_rs::thread(autostart, priority = 2)] +fn thread1() { + let pid = riot_rs::thread::current_pid().unwrap(); + assert_eq!(riot_rs::thread::get_priority(pid), Some(RunqueueId::new(2))); + + thread_flags::wait_one(0b1); + assert_eq!(RUN_ORDER.fetch_add(1, Ordering::AcqRel), 1); + + let mut counter = MUTEX.lock(); + assert_eq!(*counter, 2); + *counter += 1; + + thread_flags::set(ThreadId::new(0), 0b1); +} + +#[riot_rs::thread(autostart, priority = 3)] +fn thread2() { + let pid = riot_rs::thread::current_pid().unwrap(); + assert_eq!(riot_rs::thread::get_priority(pid), Some(RunqueueId::new(3))); + + thread_flags::wait_one(0b1); + assert_eq!(RUN_ORDER.fetch_add(1, Ordering::AcqRel), 2); + + let mut counter = MUTEX.lock(); + assert_eq!(*counter, 1); + // Priority didn't change because this thread has higher prio + // than all waiting threads. 
+ assert_eq!(riot_rs::thread::get_priority(pid), Some(RunqueueId::new(3)),); + *counter += 1; + + thread_flags::set(ThreadId::new(0), 0b10); +} + +#[riot_rs::thread(autostart, priority = 2)] +fn thread3() { + let pid = riot_rs::thread::current_pid().unwrap(); + assert_eq!(riot_rs::thread::get_priority(pid), Some(RunqueueId::new(2))); + + thread_flags::wait_one(0b1); + assert_eq!(RUN_ORDER.fetch_add(1, Ordering::AcqRel), 3); + + let mut counter = MUTEX.lock(); + assert_eq!(*counter, 3); + *counter += 1; + + thread_flags::set(ThreadId::new(0), 0b100); +}