Skip to content

Commit

Permalink
Update std driver
Browse files Browse the repository at this point in the history
  • Loading branch information
bugadani committed Dec 5, 2024
1 parent 2818f17 commit 21dff6a
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 88 deletions.
2 changes: 1 addition & 1 deletion embassy-time/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ target = "x86_64-unknown-linux-gnu"
features = ["defmt", "std"]

[features]
std = ["tick-hz-1_000_000", "critical-section/std"]
std = ["tick-hz-1_000_000", "critical-section/std", "embassy-time-queue-driver/generic-queue"]
wasm = ["dep:wasm-bindgen", "dep:js-sys", "dep:wasm-timer", "tick-hz-1_000_000"]

## Display the time since startup next to defmt log messages.
Expand Down
167 changes: 80 additions & 87 deletions embassy-time/src/driver_std.rs
Original file line number Diff line number Diff line change
@@ -1,51 +1,35 @@
use core::sync::atomic::{AtomicU8, Ordering};
use std::cell::{RefCell, UnsafeCell};
use std::mem::MaybeUninit;
use std::sync::{Condvar, Mutex, Once};
use std::time::{Duration as StdDuration, Instant as StdInstant};
use std::{mem, ptr, thread};
use std::{ptr, thread};

use critical_section::Mutex as CsMutex;
use embassy_time_driver::{AlarmHandle, Driver};

const ALARM_COUNT: usize = 4;
use embassy_time_driver::Driver;

struct AlarmState {
timestamp: u64,

// This is really a Option<(fn(*mut ()), *mut ())>
// but fn pointers aren't allowed in const yet
callback: *const (),
ctx: *mut (),
}

unsafe impl Send for AlarmState {}

impl AlarmState {
const fn new() -> Self {
Self {
timestamp: u64::MAX,
callback: ptr::null(),
ctx: ptr::null_mut(),
}
Self { timestamp: u64::MAX }
}
}

struct TimeDriver {
alarm_count: AtomicU8,

once: Once,
// The STD Driver implementation requires the alarms' mutex to be reentrant, which the STD Mutex isn't
// Fortunately, mutexes based on the `critical-section` crate are reentrant, because the critical sections
// themselves are reentrant
alarms: UninitCell<CsMutex<RefCell<[AlarmState; ALARM_COUNT]>>>,
alarms: UninitCell<CsMutex<RefCell<AlarmState>>>,
zero_instant: UninitCell<StdInstant>,
signaler: UninitCell<Signaler>,
}

embassy_time_driver::time_driver_impl!(static DRIVER: TimeDriver = TimeDriver {
alarm_count: AtomicU8::new(0),

once: Once::new(),
alarms: UninitCell::uninit(),
zero_instant: UninitCell::uninit(),
Expand All @@ -56,7 +40,7 @@ impl TimeDriver {
fn init(&self) {
self.once.call_once(|| unsafe {
self.alarms
.write(CsMutex::new(RefCell::new([const { AlarmState::new() }; ALARM_COUNT])));
.write(CsMutex::new(RefCell::new(const { AlarmState::new() })));
self.zero_instant.write(StdInstant::now());
self.signaler.write(Signaler::new());

Expand All @@ -70,36 +54,13 @@ impl TimeDriver {
let now = DRIVER.now();

let next_alarm = critical_section::with(|cs| {
let alarms = unsafe { DRIVER.alarms.as_ref() }.borrow(cs);
loop {
let pending = alarms
.borrow_mut()
.iter_mut()
.find(|alarm| alarm.timestamp <= now)
.map(|alarm| {
alarm.timestamp = u64::MAX;

(alarm.callback, alarm.ctx)
});

if let Some((callback, ctx)) = pending {
// safety:
// - we can ignore the possiblity of `f` being unset (null) because of the safety contract of `allocate_alarm`.
// - other than that we only store valid function pointers into alarm.callback
let f: fn(*mut ()) = unsafe { mem::transmute(callback) };
f(ctx);
} else {
// No alarm due
break;
}
}
let mut alarm = unsafe { DRIVER.alarms.as_ref() }.borrow_ref_mut(cs);
if alarm.timestamp <= now {
alarm.timestamp = u64::MAX;

alarms
.borrow()
.iter()
.map(|alarm| alarm.timestamp)
.min()
.unwrap_or(u64::MAX)
TIMER_QUEUE_DRIVER.dispatch();
}
alarm.timestamp
});

// Ensure we don't overflow
Expand All @@ -110,46 +71,11 @@ impl TimeDriver {
unsafe { DRIVER.signaler.as_ref() }.wait_until(until);
}
}
}

impl Driver for TimeDriver {
fn now(&self) -> u64 {
self.init();

let zero = unsafe { self.zero_instant.read() };
StdInstant::now().duration_since(zero).as_micros() as u64
}

unsafe fn allocate_alarm(&self) -> Option<AlarmHandle> {
let id = self.alarm_count.fetch_update(Ordering::AcqRel, Ordering::Acquire, |x| {
if x < ALARM_COUNT as u8 {
Some(x + 1)
} else {
None
}
});

match id {
Ok(id) => Some(AlarmHandle::new(id)),
Err(_) => None,
}
}

fn set_alarm_callback(&self, alarm: AlarmHandle, callback: fn(*mut ()), ctx: *mut ()) {
fn set_alarm(&self, timestamp: u64) -> bool {
self.init();
critical_section::with(|cs| {
let mut alarms = unsafe { self.alarms.as_ref() }.borrow_ref_mut(cs);
let alarm = &mut alarms[alarm.id() as usize];
alarm.callback = callback as *const ();
alarm.ctx = ctx;
});
}

fn set_alarm(&self, alarm: AlarmHandle, timestamp: u64) -> bool {
self.init();
critical_section::with(|cs| {
let mut alarms = unsafe { self.alarms.as_ref() }.borrow_ref_mut(cs);
let alarm = &mut alarms[alarm.id() as usize];
let mut alarm = unsafe { self.alarms.as_ref() }.borrow_ref_mut(cs);
alarm.timestamp = timestamp;
unsafe { self.signaler.as_ref() }.signal();
});
Expand All @@ -158,6 +84,15 @@ impl Driver for TimeDriver {
}
}

impl Driver for TimeDriver {
fn now(&self) -> u64 {
self.init();

let zero = unsafe { self.zero_instant.read() };
StdInstant::now().duration_since(zero).as_micros() as u64
}
}

struct Signaler {
mutex: Mutex<bool>,
condvar: Condvar,
Expand Down Expand Up @@ -228,3 +163,61 @@ impl<T: Copy> UninitCell<T> {
ptr::read(self.as_mut_ptr())
}
}

struct RawQueue {
inner: core::cell::RefCell<embassy_time_queue_driver::queue_generic::Queue>,
}

impl RawQueue {
const fn new() -> Self {
Self {
inner: core::cell::RefCell::new(embassy_time_queue_driver::queue_generic::Queue::new()),
}
}

fn schedule_wake(&self, waker: &core::task::Waker, at: u64) -> bool {
self.inner.borrow_mut().schedule_wake(at, waker)
}

fn next_expiration(&self, now: u64) -> u64 {
self.inner.borrow_mut().next_expiration(now)
}
}

struct TimerQueueDriver {
inner: Mutex<RawQueue>,
}

embassy_time_queue_driver::timer_queue_impl!(static TIMER_QUEUE_DRIVER: TimerQueueDriver = TimerQueueDriver::new());

impl embassy_time_queue_driver::TimerQueue for TimerQueueDriver {
fn schedule_wake(&'static self, at: u64, waker: &core::task::Waker) {
let q = self.inner.lock().unwrap();
if q.schedule_wake(waker, at) {
self.arm_alarm(at);
}
}
}

impl TimerQueueDriver {
const fn new() -> Self {
Self {
inner: Mutex::new(RawQueue::new()),
}
}

pub fn dispatch(&self) {
let now = DRIVER.now();
let q = self.inner.lock().unwrap();
let next_expiration = q.next_expiration(now);
self.arm_alarm(next_expiration);
}

fn arm_alarm(&self, mut next_expiration: u64) {
while !DRIVER.set_alarm(next_expiration) {
// next_expiration is in the past, dequeue and find a new expiration
let q = self.inner.lock().unwrap();
next_expiration = q.next_expiration(next_expiration);
}
}
}

0 comments on commit 21dff6a

Please sign in to comment.