diff --git a/src/concurrency/sync.rs b/src/concurrency/sync.rs index ef4034cc0c..b9b8932523 100644 --- a/src/concurrency/sync.rs +++ b/src/concurrency/sync.rs @@ -422,7 +422,9 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { mutex_ref: MutexRef, retval_dest: Option<(Scalar, MPlaceTy<'tcx>)>, } - @unblock = |this| { + @unblock = |this, unblock: UnblockKind| { + assert_eq!(unblock, UnblockKind::Ready); + assert!(!this.mutex_is_locked(&mutex_ref)); this.mutex_lock(&mutex_ref); @@ -538,7 +540,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { retval: Scalar, dest: MPlaceTy<'tcx>, } - @unblock = |this| { + @unblock = |this, unblock: UnblockKind| { + assert_eq!(unblock, UnblockKind::Ready); this.rwlock_reader_lock(id); this.write_scalar(retval, &dest)?; interp_ok(()) @@ -623,7 +626,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { retval: Scalar, dest: MPlaceTy<'tcx>, } - @unblock = |this| { + @unblock = |this, unblock: UnblockKind| { + assert_eq!(unblock, UnblockKind::Ready); this.rwlock_writer_lock(id); this.write_scalar(retval, &dest)?; interp_ok(()) @@ -677,25 +681,29 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { retval_timeout: Scalar, dest: MPlaceTy<'tcx>, } - @unblock = |this| { - // The condvar was signaled. Make sure we get the clock for that. - if let Some(data_race) = &this.machine.data_race { - data_race.acquire_clock( - &this.machine.sync.condvars[condvar].clock, - &this.machine.threads, - ); + @unblock = |this, unblock: UnblockKind| { + match unblock { + UnblockKind::Ready => { + // The condvar was signaled. Make sure we get the clock for that. + if let Some(data_race) = &this.machine.data_race { + data_race.acquire_clock( + &this.machine.sync.condvars[condvar].clock, + &this.machine.threads, + ); + } + // Try to acquire the mutex. + // The timeout only applies to the first wait (until the signal), not for mutex acquisition. + this.condvar_reacquire_mutex(&mutex_ref, retval_succ, dest) + } + UnblockKind::TimedOut => { + // We have to remove the waiter from the queue again. + let thread = this.active_thread(); + let waiters = &mut this.machine.sync.condvars[condvar].waiters; + waiters.retain(|waiter| *waiter != thread); + // Now get back the lock. + this.condvar_reacquire_mutex(&mutex_ref, retval_timeout, dest) + } } - // Try to acquire the mutex. - // The timeout only applies to the first wait (until the signal), not for mutex acquisition. - this.condvar_reacquire_mutex(&mutex_ref, retval_succ, dest) - } - @timeout = |this| { - // We have to remove the waiter from the queue again. - let thread = this.active_thread(); - let waiters = &mut this.machine.sync.condvars[condvar].waiters; - waiters.retain(|waiter| *waiter != thread); - // Now get back the lock. - this.condvar_reacquire_mutex(&mutex_ref, retval_timeout, dest) } ), ); @@ -752,25 +760,29 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { dest: MPlaceTy<'tcx>, errno_timeout: IoError, } - @unblock = |this| { - let futex = futex_ref.0.borrow(); - // Acquire the clock of the futex. - if let Some(data_race) = &this.machine.data_race { - data_race.acquire_clock(&futex.clock, &this.machine.threads); + @unblock = |this, unblock: UnblockKind| { + match unblock { + UnblockKind::Ready => { + let futex = futex_ref.0.borrow(); + // Acquire the clock of the futex. + if let Some(data_race) = &this.machine.data_race { + data_race.acquire_clock(&futex.clock, &this.machine.threads); + } + // Write the return value. + this.write_scalar(retval_succ, &dest)?; + interp_ok(()) + }, + UnblockKind::TimedOut => { + // Remove the waiter from the futex. + let thread = this.active_thread(); + let mut futex = futex_ref.0.borrow_mut(); + futex.waiters.retain(|waiter| waiter.thread != thread); + // Set errno and write return value. + this.set_last_error(errno_timeout)?; + this.write_scalar(retval_timeout, &dest)?; + interp_ok(()) + }, } - // Write the return value. - this.write_scalar(retval_succ, &dest)?; - interp_ok(()) - } - @timeout = |this| { - // Remove the waiter from the futex. - let thread = this.active_thread(); - let mut futex = futex_ref.0.borrow_mut(); - futex.waiters.retain(|waiter| waiter.thread != thread); - // Set errno and write return value. - this.set_last_error(errno_timeout)?; - this.write_scalar(retval_timeout, &dest)?; - interp_ok(()) } ), ); diff --git a/src/concurrency/thread.rs b/src/concurrency/thread.rs index 730c27d016..decffebca7 100644 --- a/src/concurrency/thread.rs +++ b/src/concurrency/thread.rs @@ -38,71 +38,20 @@ pub enum TlsAllocAction { Leak, } -/// Trait for callbacks that are executed when a thread gets unblocked. -pub trait UnblockCallback<'tcx>: VisitProvenance { - /// Will be invoked when the thread was unblocked the "regular" way, - /// i.e. whatever event it was blocking on has happened. - fn unblock(self: Box, ecx: &mut InterpCx<'tcx, MiriMachine<'tcx>>) -> InterpResult<'tcx>; - - /// Will be invoked when the timeout ellapsed without the event the - /// thread was blocking on having occurred. - fn timeout(self: Box, _ecx: &mut InterpCx<'tcx, MiriMachine<'tcx>>) - -> InterpResult<'tcx>; +/// The argument type for the "unblock" callback, indicating why the thread got unblocked. +#[derive(Debug, PartialEq)] +pub enum UnblockKind { + /// Operation completed successfully, thread continues normal execution. + Ready, + /// The operation did not complete within its specified duration. + TimedOut, } -pub type DynUnblockCallback<'tcx> = Box + 'tcx>; - -#[macro_export] -macro_rules! callback { - ( - @capture<$tcx:lifetime $(,)? $($lft:lifetime),*> { $($name:ident: $type:ty),* $(,)? } - @unblock = |$this:ident| $unblock:block - ) => { - callback!( - @capture<$tcx, $($lft),*> { $($name: $type),* } - @unblock = |$this| $unblock - @timeout = |_this| { - unreachable!( - "timeout on a thread that was blocked without a timeout (or someone forgot to overwrite this method)" - ) - } - ) - }; - ( - @capture<$tcx:lifetime $(,)? $($lft:lifetime),*> { $($name:ident: $type:ty),* $(,)? } - @unblock = |$this:ident| $unblock:block - @timeout = |$this_timeout:ident| $timeout:block - ) => {{ - struct Callback<$tcx, $($lft),*> { - $($name: $type,)* - _phantom: std::marker::PhantomData<&$tcx ()>, - } - - impl<$tcx, $($lft),*> VisitProvenance for Callback<$tcx, $($lft),*> { - #[allow(unused_variables)] - fn visit_provenance(&self, visit: &mut VisitWith<'_>) { - $( - self.$name.visit_provenance(visit); - )* - } - } - - impl<$tcx, $($lft),*> UnblockCallback<$tcx> for Callback<$tcx, $($lft),*> { - fn unblock(self: Box, $this: &mut MiriInterpCx<$tcx>) -> InterpResult<$tcx> { - #[allow(unused_variables)] - let Callback { $($name,)* _phantom } = *self; - $unblock - } - fn timeout(self: Box, $this_timeout: &mut MiriInterpCx<$tcx>) -> InterpResult<$tcx> { - #[allow(unused_variables)] - let Callback { $($name,)* _phantom } = *self; - $timeout - } - } +/// Type alias for boxed machine callbacks with generic argument type. +pub type DyMachineCallback<'tcx, T> = Box + 'tcx>; - Box::new(Callback { $($name,)* _phantom: std::marker::PhantomData }) - }} -} +/// Type alias for unblock callbacks using UnblockKind argument. +pub type DynUnblockCallback<'tcx> = DyMachineCallback<'tcx, UnblockKind>; /// A thread identifier. #[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq, Hash)] @@ -656,7 +605,8 @@ impl<'tcx> ThreadManager<'tcx> { @capture<'tcx> { joined_thread_id: ThreadId, } - @unblock = |this| { + @unblock = |this, unblock: UnblockKind| { + assert_eq!(unblock, UnblockKind::Ready); if let Some(data_race) = &mut this.machine.data_race { data_race.thread_joined(&this.machine.threads, joined_thread_id); } @@ -842,7 +792,7 @@ trait EvalContextPrivExt<'tcx>: MiriInterpCxExt<'tcx> { // 2. Make the scheduler the only place that can change the active // thread. let old_thread = this.machine.threads.set_active_thread_id(thread); - callback.timeout(this)?; + callback.call(this, UnblockKind::TimedOut)?; this.machine.threads.set_active_thread_id(old_thread); } // found_callback can remain None if the computer's clock @@ -1084,7 +1034,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { }; // The callback must be executed in the previously blocked thread. let old_thread = this.machine.threads.set_active_thread_id(thread); - callback.unblock(this)?; + callback.call(this, UnblockKind::Ready)?; this.machine.threads.set_active_thread_id(old_thread); interp_ok(()) } diff --git a/src/lib.rs b/src/lib.rs index e02d51afce..ebfc1cd801 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -126,8 +126,8 @@ pub use crate::concurrency::sync::{ CondvarId, EvalContextExt as _, MutexRef, RwLockId, SynchronizationObjects, }; pub use crate::concurrency::thread::{ - BlockReason, EvalContextExt as _, StackEmptyCallback, ThreadId, ThreadManager, TimeoutAnchor, - TimeoutClock, UnblockCallback, + BlockReason, DynUnblockCallback, EvalContextExt as _, StackEmptyCallback, ThreadId, + ThreadManager, TimeoutAnchor, TimeoutClock, UnblockKind, }; pub use crate::diagnostics::{ EvalContextExt as _, NonHaltingDiagnostic, TerminationInfo, report_error, @@ -139,8 +139,8 @@ pub use crate::eval::{ pub use crate::helpers::{AccessKind, EvalContextExt as _}; pub use crate::intrinsics::EvalContextExt as _; pub use crate::machine::{ - AllocExtra, FrameExtra, MemoryKind, MiriInterpCx, MiriInterpCxExt, MiriMachine, MiriMemoryKind, - PrimitiveLayouts, Provenance, ProvenanceExtra, + AllocExtra, FrameExtra, MachineCallback, MemoryKind, MiriInterpCx, MiriInterpCxExt, + MiriMachine, MiriMemoryKind, PrimitiveLayouts, Provenance, ProvenanceExtra, }; pub use crate::mono_hash_map::MonoHashMap; pub use crate::operator::EvalContextExt as _; diff --git a/src/machine.rs b/src/machine.rs index 5e8f616a37..4fe7ae5366 100644 --- a/src/machine.rs +++ b/src/machine.rs @@ -1723,3 +1723,77 @@ impl<'tcx> Machine<'tcx> for MiriMachine<'tcx> { Cow::Borrowed(ecx.machine.union_data_ranges.entry(ty).or_insert_with(compute_range)) } } + +/// Trait for callbacks handling asynchronous machine operations. +/// +/// Callbacks receive a completion state and can perform follow-up actions while +/// maintaining interpreter invariants. They are executed with mutable access to +/// the interpreter context. +/// +/// # Type Parameters +/// - `'tcx`: Typing context lifetime for the interpreter. +/// - `T`: Type of argument passed to the callback on completion. +pub trait MachineCallback<'tcx, T>: VisitProvenance { + /// Executes the callback when an operation completes. + /// + /// # Arguments + /// - `self`: Owned callback, boxed for dynamic dispatch + /// - `ecx`: Mutable interpreter context + /// - `arg`: Operation-specific completion argument + /// + /// # Returns + /// Success or error of the callback execution. + fn call( + self: Box, + ecx: &mut InterpCx<'tcx, MiriMachine<'tcx>>, + arg: T, + ) -> InterpResult<'tcx>; +} + +/// Creates machine callbacks with captured variables and generic argument types. +/// +/// This macro generates the boilerplate needed to create type-safe callbacks: +/// - Creates a struct to hold captured variables +/// - Implements required traits (VisitProvenance, MachineCallback) +/// - Handles proper lifetime and type parameters +/// +/// The callback body receives the interpreter context and completion argument. +#[macro_export] +macro_rules! callback { + (@capture<$tcx:lifetime $(,)? $($lft:lifetime),*> + { $($name:ident: $type:ty),* $(,)? } + @unblock = |$this:ident, $arg:ident: $arg_ty:ty| $body:expr $(,)?) => {{ + // Create callback struct with the captured variables and generic type T + struct Callback<$tcx, $($lft),*> { + $($name: $type,)* + _phantom: std::marker::PhantomData<&$tcx ()>, + } + + // Implement VisitProvenance trait for the callback + impl<$tcx, $($lft),*> VisitProvenance for Callback<$tcx, $($lft),*> { + fn visit_provenance(&self, _visit: &mut VisitWith<'_>) { + $( + self.$name.visit_provenance(_visit); + )* + } + } + + // Implement MachineCallback trait with the specified argument type + impl<$tcx, $($lft),*> MachineCallback<$tcx, $arg_ty> for Callback<$tcx, $($lft),*> { + fn call( + self: Box, + $this: &mut MiriInterpCx<$tcx>, + $arg: $arg_ty + ) -> InterpResult<$tcx> { + #[allow(unused_variables)] + let Callback { $($name,)* _phantom } = *self; + $body + } + } + + Box::new(Callback { + $($name,)* + _phantom: std::marker::PhantomData + }) + }}; +} diff --git a/src/shims/time.rs b/src/shims/time.rs index 72d98bc1c4..4a7728f2ac 100644 --- a/src/shims/time.rs +++ b/src/shims/time.rs @@ -331,8 +331,14 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { Some((TimeoutClock::Monotonic, TimeoutAnchor::Relative, duration)), callback!( @capture<'tcx> {} - @unblock = |_this| { panic!("sleeping thread unblocked before time is up") } - @timeout = |_this| { interp_ok(()) } + @unblock = |_this, unblock: UnblockKind| { + match unblock { + UnblockKind::Ready => { + panic!("sleeping thread unblocked before time is up") + }, + UnblockKind::TimedOut => { interp_ok(()) }, + } + } ), ); interp_ok(Scalar::from_i32(0)) @@ -353,8 +359,14 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { Some((TimeoutClock::Monotonic, TimeoutAnchor::Relative, duration)), callback!( @capture<'tcx> {} - @unblock = |_this| { panic!("sleeping thread unblocked before time is up") } - @timeout = |_this| { interp_ok(()) } + @unblock = |_this, unblock: UnblockKind| { + match unblock { + UnblockKind::Ready => { + panic!("sleeping thread unblocked before time is up") + }, + UnblockKind::TimedOut => { interp_ok(()) }, + } + } ), ); interp_ok(()) diff --git a/src/shims/unix/linux_like/epoll.rs b/src/shims/unix/linux_like/epoll.rs index 5b240351c2..0465e281fe 100644 --- a/src/shims/unix/linux_like/epoll.rs +++ b/src/shims/unix/linux_like/epoll.rs @@ -496,22 +496,26 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { dest: MPlaceTy<'tcx>, event: MPlaceTy<'tcx>, } - @unblock = |this| { - return_ready_list(epfd_value, weak_epfd, &dest, &event, this)?; - interp_ok(()) - } - @timeout = |this| { - // No notification after blocking timeout. - let Some(epfd) = weak_epfd.upgrade() else { - throw_unsup_format!("epoll FD {epfd_value} got closed while blocking.") - }; - // Remove the current active thread_id from the blocked thread_id list. - epfd.downcast::() - .ok_or_else(|| err_unsup_format!("non-epoll FD passed to `epoll_wait`"))? - .thread_id.borrow_mut() - .retain(|&id| id != this.active_thread()); - this.write_int(0, &dest)?; - interp_ok(()) + @unblock = |this, unblock: UnblockKind| { + match unblock { + UnblockKind::Ready => { + return_ready_list(epfd_value, weak_epfd, &dest, &event, this)?; + interp_ok(()) + }, + UnblockKind::TimedOut => { + // No notification after blocking timeout. + let Some(epfd) = weak_epfd.upgrade() else { + throw_unsup_format!("epoll FD {epfd_value} got closed while blocking.") + }; + // Remove the current active thread_id from the blocked thread_id list. + epfd.downcast::() + .ok_or_else(|| err_unsup_format!("non-epoll FD passed to `epoll_wait`"))? + .thread_id.borrow_mut() + .retain(|&id| id != this.active_thread()); + this.write_int(0, &dest)?; + interp_ok(()) + }, + } } ), ); diff --git a/src/shims/unix/linux_like/eventfd.rs b/src/shims/unix/linux_like/eventfd.rs index 4bbe417ea8..b20079b6a8 100644 --- a/src/shims/unix/linux_like/eventfd.rs +++ b/src/shims/unix/linux_like/eventfd.rs @@ -254,7 +254,8 @@ fn eventfd_write<'tcx>( dest: MPlaceTy<'tcx>, weak_eventfd: WeakFileDescriptionRef, } - @unblock = |this| { + @unblock = |this, unblock: UnblockKind| { + assert_eq!(unblock, UnblockKind::Ready); // When we get unblocked, try again. eventfd_write(num, buf_place, &dest, weak_eventfd, this) } @@ -302,7 +303,8 @@ fn eventfd_read<'tcx>( dest: MPlaceTy<'tcx>, weak_eventfd: WeakFileDescriptionRef, } - @unblock = |this| { + @unblock = |this, unblock: UnblockKind| { + assert_eq!(unblock, UnblockKind::Ready); // When we get unblocked, try again. eventfd_read(buf_place, &dest, weak_eventfd, this) } diff --git a/src/shims/unix/macos/sync.rs b/src/shims/unix/macos/sync.rs index f66a57ae70..8ae19c760f 100644 --- a/src/shims/unix/macos/sync.rs +++ b/src/shims/unix/macos/sync.rs @@ -64,7 +64,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { None, callback!( @capture<'tcx> {} - @unblock = |_this| { + @unblock = |_this, _unblock: UnblockKind| { panic!("we shouldn't wake up ever") } ), diff --git a/src/shims/unix/unnamed_socket.rs b/src/shims/unix/unnamed_socket.rs index 86ebe95762..250230dce2 100644 --- a/src/shims/unix/unnamed_socket.rs +++ b/src/shims/unix/unnamed_socket.rs @@ -199,7 +199,8 @@ fn anonsocket_write<'tcx>( len: usize, dest: MPlaceTy<'tcx>, } - @unblock = |this| { + @unblock = |this, unblock: UnblockKind| { + assert_eq!(unblock, UnblockKind::Ready); anonsocket_write(weak_self_ref, ptr, len, dest, this) } ), @@ -273,7 +274,8 @@ fn anonsocket_read<'tcx>( ptr: Pointer, dest: MPlaceTy<'tcx>, } - @unblock = |this| { + @unblock = |this, unblock: UnblockKind| { + assert_eq!(unblock, UnblockKind::Ready); anonsocket_read(weak_self_ref, len, ptr, dest, this) } ), diff --git a/src/shims/windows/sync.rs b/src/shims/windows/sync.rs index a394e0430b..dbceb9756e 100644 --- a/src/shims/windows/sync.rs +++ b/src/shims/windows/sync.rs @@ -111,7 +111,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { pending_place: MPlaceTy<'tcx>, dest: MPlaceTy<'tcx>, } - @unblock = |this| { + @unblock = |this, unblock: UnblockKind| { + assert_eq!(unblock, UnblockKind::Ready); let ret = this.init_once_try_begin(id, &pending_place, &dest)?; assert!(ret, "we were woken up but init_once_try_begin still failed"); interp_ok(())