From e0fea25b0631812a9a6501aaf24e2cd568d573e6 Mon Sep 17 00:00:00 2001 From: daxpedda Date: Tue, 26 Dec 2023 01:22:10 +0100 Subject: [PATCH] Make `canvas` in `WindowBuilder` safe (#3320) --- src/platform/web.rs | 3 +- src/platform_impl/web/async/dispatcher.rs | 4 +- src/platform_impl/web/async/waker.rs | 6 +- src/platform_impl/web/async/wrapper.rs | 53 ++----- src/platform_impl/web/cursor.rs | 165 +++++++-------------- src/platform_impl/web/event_loop/runner.rs | 10 +- src/platform_impl/web/main_thread.rs | 96 ++++++++++++ src/platform_impl/web/mod.rs | 1 + src/platform_impl/web/web_sys/canvas.rs | 10 +- src/platform_impl/web/window.rs | 35 ++++- 10 files changed, 214 insertions(+), 169 deletions(-) create mode 100644 src/platform_impl/web/main_thread.rs diff --git a/src/platform/web.rs b/src/platform/web.rs index 7fcfe240dd..baab78cfab 100644 --- a/src/platform/web.rs +++ b/src/platform/web.rs @@ -34,7 +34,6 @@ use crate::event_loop::EventLoopWindowTarget; use crate::platform_impl::PlatformCustomCursorBuilder; use crate::window::CustomCursor; use crate::window::{Window, WindowBuilder}; -use crate::SendSyncWrapper; use web_sys::HtmlCanvasElement; @@ -105,7 +104,7 @@ pub trait WindowBuilderExtWebSys { impl WindowBuilderExtWebSys for WindowBuilder { fn with_canvas(mut self, canvas: Option) -> Self { - self.platform_specific.canvas = SendSyncWrapper(canvas); + self.platform_specific.set_canvas(canvas); self } diff --git a/src/platform_impl/web/async/dispatcher.rs b/src/platform_impl/web/async/dispatcher.rs index 5fb63d5456..9ec9b6a365 100644 --- a/src/platform_impl/web/async/dispatcher.rs +++ b/src/platform_impl/web/async/dispatcher.rs @@ -1,3 +1,4 @@ +use super::super::main_thread::MainThreadMarker; use super::{channel, AsyncReceiver, AsyncSender, Wrapper}; use std::{ cell::Ref, @@ -10,10 +11,11 @@ struct Closure(Box); impl Dispatcher { #[track_caller] - pub fn new(value: T) -> Option<(Self, DispatchRunner)> { + pub fn new(main_thread: MainThreadMarker, value: T) -> Option<(Self, DispatchRunner)> { let (sender, receiver) = channel::>(); Wrapper::new( + main_thread, value, |value, Closure(closure)| { // SAFETY: The given `Closure` here isn't really `'static`, so we shouldn't do anything diff --git a/src/platform_impl/web/async/waker.rs b/src/platform_impl/web/async/waker.rs index 2c27c474c9..79062d8456 100644 --- a/src/platform_impl/web/async/waker.rs +++ b/src/platform_impl/web/async/waker.rs @@ -1,3 +1,4 @@ +use super::super::main_thread::MainThreadMarker; use super::Wrapper; use atomic_waker::AtomicWaker; use std::future; @@ -19,7 +20,7 @@ struct Sender(Arc); impl WakerSpawner { #[track_caller] - pub fn new(value: T, handler: fn(&T, usize)) -> Option { + pub fn new(main_thread: MainThreadMarker, value: T, handler: fn(&T, usize)) -> Option { let inner = Arc::new(Inner { counter: AtomicUsize::new(0), waker: AtomicWaker::new(), @@ -31,6 +32,7 @@ impl WakerSpawner { let sender = Sender(Arc::clone(&inner)); let wrapper = Wrapper::new( + main_thread, handler, |handler, count| { let handler = handler.borrow(); @@ -86,7 +88,7 @@ impl WakerSpawner { pub fn fetch(&self) -> usize { debug_assert!( - self.0.is_main_thread(), + MainThreadMarker::new().is_some(), "this should only be called from the main thread" ); diff --git a/src/platform_impl/web/async/wrapper.rs b/src/platform_impl/web/async/wrapper.rs index 22088ef1fc..d8b05c713c 100644 --- a/src/platform_impl/web/async/wrapper.rs +++ b/src/platform_impl/web/async/wrapper.rs @@ -1,9 +1,8 @@ +use super::super::main_thread::MainThreadMarker; use std::cell::{Ref, RefCell}; use std::future::Future; use std::marker::PhantomData; use std::sync::Arc; -use wasm_bindgen::prelude::wasm_bindgen; -use wasm_bindgen::{JsCast, JsValue}; // Unsafe wrapper type that allows us to use `T` when it's not `Send` from other threads. // `value` **must** only be accessed on the main thread. @@ -34,36 +33,15 @@ unsafe impl Send for Value {} unsafe impl Sync for Value {} impl Wrapper { - thread_local! { - static MAIN_THREAD: bool = { - #[wasm_bindgen] - extern "C" { - #[derive(Clone)] - type Global; - - #[wasm_bindgen(method, getter, js_name = Window)] - fn window(this: &Global) -> JsValue; - } - - let global: Global = js_sys::global().unchecked_into(); - !global.window().is_undefined() - }; - } - #[track_caller] pub fn new>( + _: MainThreadMarker, value: V, handler: fn(&RefCell>, E), receiver: impl 'static + FnOnce(Arc>>) -> R, sender_data: S, sender_handler: fn(&S, E), ) -> Option { - Self::MAIN_THREAD.with(|safe| { - if !safe { - panic!("only callable from inside the `Window`") - } - }); - let value = Arc::new(RefCell::new(Some(value))); wasm_bindgen_futures::spawn_local({ @@ -86,29 +64,16 @@ impl Wrapper { } pub fn send(&self, event: E) { - Self::MAIN_THREAD.with(|is_main_thread| { - if *is_main_thread { - (self.handler)(&self.value.value, event) - } else { - (self.sender_handler)(&self.sender_data, event) - } - }) - } - - pub fn is_main_thread(&self) -> bool { - Self::MAIN_THREAD.with(|is_main_thread| *is_main_thread) + if MainThreadMarker::new().is_some() { + (self.handler)(&self.value.value, event) + } else { + (self.sender_handler)(&self.sender_data, event) + } } pub fn value(&self) -> Option> { - Self::MAIN_THREAD.with(|is_main_thread| { - if *is_main_thread { - Some(Ref::map(self.value.value.borrow(), |value| { - value.as_ref().unwrap() - })) - } else { - None - } - }) + MainThreadMarker::new() + .map(|_| Ref::map(self.value.value.borrow(), |value| value.as_ref().unwrap())) } pub fn with_sender_data(&self, f: impl FnOnce(&S) -> T) -> T { diff --git a/src/platform_impl/web/cursor.rs b/src/platform_impl/web/cursor.rs index 4d57c20079..7287c25fb6 100644 --- a/src/platform_impl/web/cursor.rs +++ b/src/platform_impl/web/cursor.rs @@ -9,12 +9,8 @@ use std::{ task::{Poll, Waker}, }; -use crate::{ - cursor::{BadImage, Cursor, CursorImage}, - platform_impl::platform::r#async, -}; +use crate::cursor::{BadImage, Cursor, CursorImage}; use cursor_icon::CursorIcon; -use once_cell::sync::Lazy; use wasm_bindgen::{closure::Closure, JsCast}; use wasm_bindgen_futures::JsFuture; use web_sys::{ @@ -22,9 +18,9 @@ use web_sys::{ ImageBitmapRenderingContext, ImageData, PremultiplyAlpha, Url, Window, }; -use self::thread_safe::ThreadSafe; - -use super::{backend::Style, r#async::AsyncSender, EventLoopWindowTarget}; +use super::backend::Style; +use super::main_thread::{MainThreadMarker, MainThreadSafe}; +use super::EventLoopWindowTarget; #[derive(Debug)] pub(crate) enum CustomCursorBuilder { @@ -51,7 +47,7 @@ impl CustomCursorBuilder { } #[derive(Clone, Debug)] -pub struct CustomCursor(Arc); +pub struct CustomCursor(Arc>>); impl Hash for CustomCursor { fn hash(&self, state: &mut H) { @@ -68,14 +64,22 @@ impl PartialEq for CustomCursor { impl Eq for CustomCursor {} impl CustomCursor { + fn new(main_thread: MainThreadMarker) -> Self { + Self(Arc::new(MainThreadSafe::new( + main_thread, + RefCell::new(ImageState::Loading(None)), + ))) + } + pub(crate) fn build( builder: CustomCursorBuilder, window_target: &EventLoopWindowTarget, ) -> Self { - Lazy::force(&DROP_HANDLER); + let main_thread = window_target.runner.main_thread(); - Self(match builder { + match builder { CustomCursorBuilder::Image(image) => ImageState::from_rgba( + main_thread, window_target.runner.window(), window_target.runner.document().clone(), &image, @@ -84,47 +88,7 @@ impl CustomCursor { url, hotspot_x, hotspot_y, - } => ImageState::from_url(url, hotspot_x, hotspot_y), - }) - } -} - -#[derive(Debug)] -struct Inner(Option>>); - -static DROP_HANDLER: Lazy>>> = Lazy::new(|| { - let (sender, receiver) = r#async::channel(); - wasm_bindgen_futures::spawn_local(async move { while receiver.next().await.is_ok() {} }); - - sender -}); - -impl Inner { - fn new() -> Arc { - Arc::new(Inner(Some(ThreadSafe::new(RefCell::new( - ImageState::Loading(None), - ))))) - } - - fn get(&self) -> &RefCell { - self.0 - .as_ref() - .expect("value has accidently already been dropped") - .get() - } -} - -impl Drop for Inner { - fn drop(&mut self) { - let value = self - .0 - .take() - .expect("value has accidently already been dropped"); - - if !value.in_origin_thread() { - DROP_HANDLER - .send(value) - .expect("sender dropped in main thread") + } => ImageState::from_url(main_thread, url, hotspot_x, hotspot_y), } } } @@ -133,8 +97,9 @@ impl Drop for Inner { pub struct CursorState(Rc>); impl CursorState { - pub fn new(style: Style) -> Self { + pub fn new(main_thread: MainThreadMarker, style: Style) -> Self { Self(Rc::new(RefCell::new(State { + main_thread, style, visible: true, cursor: SelectedCursor::default(), @@ -147,7 +112,9 @@ impl CursorState { match cursor { Cursor::Icon(icon) => { if let SelectedCursor::ImageLoading { state, .. } = &this.cursor { - if let ImageState::Loading(state) = state.get().borrow_mut().deref_mut() { + if let ImageState::Loading(state) = + state.0.get(this.main_thread).borrow_mut().deref_mut() + { state.take(); } } @@ -155,10 +122,16 @@ impl CursorState { this.cursor = SelectedCursor::Named(icon); this.set_style(); } - Cursor::Custom(cursor) => match cursor.inner.0.get().borrow_mut().deref_mut() { + Cursor::Custom(cursor) => match cursor + .inner + .0 + .get(this.main_thread) + .borrow_mut() + .deref_mut() + { ImageState::Loading(state) => { this.cursor = SelectedCursor::ImageLoading { - state: cursor.inner.0.clone(), + state: cursor.inner.clone(), previous: mem::take(&mut this.cursor).into(), }; *state = Some(Rc::downgrade(&self.0)); @@ -187,6 +160,7 @@ impl CursorState { #[derive(Debug)] struct State { + main_thread: MainThreadMarker, style: Style, visible: bool, cursor: SelectedCursor, @@ -210,7 +184,7 @@ impl State { enum SelectedCursor { Named(CursorIcon), ImageLoading { - state: Arc, + state: CustomCursor, previous: Previous, }, ImageReady(Rc), @@ -264,7 +238,12 @@ enum ImageState { } impl ImageState { - fn from_rgba(window: &Window, document: Document, image: &CursorImage) -> Arc { + fn from_rgba( + main_thread: MainThreadMarker, + window: &Window, + document: Document, + image: &CursorImage, + ) -> CustomCursor { // 1. Create an `ImageData` from the RGBA data. // 2. Create an `ImageBitmap` from the `ImageData`. // 3. Draw `ImageBitmap` on an `HTMLCanvasElement`. @@ -316,10 +295,11 @@ impl ImageState { .expect("unexpected exception in `createImageBitmap()`"), ); - let this = Inner::new(); + #[allow(clippy::arc_with_non_send_sync)] + let this = CustomCursor::new(main_thread); wasm_bindgen_futures::spawn_local({ - let weak = Arc::downgrade(&this); + let weak = Arc::downgrade(&this.0); let CursorImage { width, height, @@ -394,7 +374,7 @@ impl ImageState { let Some(this) = weak.upgrade() else { return; }; - let mut this = this.get().borrow_mut(); + let mut this = this.get(main_thread).borrow_mut(); let Some(blob) = blob else { log::error!("creating custom cursor failed"); @@ -422,17 +402,24 @@ impl ImageState { .expect("unexpected exception in `URL.createObjectURL()`") }; - Self::decode(weak, url, true, hotspot_x, hotspot_y).await; + Self::decode(main_thread, weak, url, true, hotspot_x, hotspot_y).await; } }); this } - fn from_url(url: String, hotspot_x: u16, hotspot_y: u16) -> Arc { - let this = Inner::new(); + fn from_url( + main_thread: MainThreadMarker, + url: String, + hotspot_x: u16, + hotspot_y: u16, + ) -> CustomCursor { + #[allow(clippy::arc_with_non_send_sync)] + let this = CustomCursor::new(main_thread); wasm_bindgen_futures::spawn_local(Self::decode( - Arc::downgrade(&this), + main_thread, + Arc::downgrade(&this.0), url, false, hotspot_x, @@ -443,7 +430,8 @@ impl ImageState { } async fn decode( - weak: sync::Weak, + main_thread: MainThreadMarker, + weak: sync::Weak>>, url: String, object: bool, hotspot_x: u16, @@ -462,7 +450,7 @@ impl ImageState { let Some(this) = weak.upgrade() else { return; }; - let mut this = this.get().borrow_mut(); + let mut this = this.get(main_thread).borrow_mut(); let ImageState::Loading(state) = this.deref_mut() else { unreachable!("found invalid state"); @@ -533,46 +521,3 @@ impl Image { }) } } - -mod thread_safe { - use std::mem; - use std::thread::{self, ThreadId}; - - #[derive(Debug)] - pub struct ThreadSafe { - origin_thread: ThreadId, - value: T, - } - - impl ThreadSafe { - pub fn new(value: T) -> Self { - Self { - origin_thread: thread::current().id(), - value, - } - } - - pub fn get(&self) -> &T { - if self.origin_thread == thread::current().id() { - &self.value - } else { - panic!("value not accessible outside its origin thread") - } - } - - pub fn in_origin_thread(&self) -> bool { - self.origin_thread == thread::current().id() - } - } - - impl Drop for ThreadSafe { - fn drop(&mut self) { - if mem::needs_drop::() && self.origin_thread != thread::current().id() { - panic!("value can't be dropped outside its origin thread") - } - } - } - - unsafe impl Send for ThreadSafe {} - unsafe impl Sync for ThreadSafe {} -} diff --git a/src/platform_impl/web/event_loop/runner.rs b/src/platform_impl/web/event_loop/runner.rs index b029ab516f..f7932ff76e 100644 --- a/src/platform_impl/web/event_loop/runner.rs +++ b/src/platform_impl/web/event_loop/runner.rs @@ -1,3 +1,4 @@ +use super::super::main_thread::MainThreadMarker; use super::super::DeviceId; use super::{backend, state::State}; use crate::dpi::PhysicalSize; @@ -37,6 +38,7 @@ impl Clone for Shared { type OnEventHandle = RefCell>>; pub struct Execution { + main_thread: MainThreadMarker, proxy_spawner: WakerSpawner>, control_flow: Cell, poll_strategy: Cell, @@ -143,13 +145,14 @@ impl Runner { impl Shared { pub fn new() -> Self { + let main_thread = MainThreadMarker::new().expect("only callable from inside the `Window`"); #[allow(clippy::disallowed_methods)] let window = web_sys::window().expect("only callable from inside the `Window`"); #[allow(clippy::disallowed_methods)] let document = window.document().expect("Failed to obtain document"); Shared(Rc::::new_cyclic(|weak| { - let proxy_spawner = WakerSpawner::new(weak.clone(), |runner, count| { + let proxy_spawner = WakerSpawner::new(main_thread, weak.clone(), |runner, count| { if let Some(runner) = runner.upgrade() { Shared(runner).send_events(iter::repeat(Event::UserEvent(())).take(count)) } @@ -157,6 +160,7 @@ impl Shared { .expect("`EventLoop` has to be created in the main thread"); Execution { + main_thread, proxy_spawner, control_flow: Cell::new(ControlFlow::default()), poll_strategy: Cell::new(PollStrategy::default()), @@ -184,6 +188,10 @@ impl Shared { })) } + pub fn main_thread(&self) -> MainThreadMarker { + self.0.main_thread + } + pub fn window(&self) -> &web_sys::Window { &self.0.window } diff --git a/src/platform_impl/web/main_thread.rs b/src/platform_impl/web/main_thread.rs new file mode 100644 index 0000000000..8c8598f476 --- /dev/null +++ b/src/platform_impl/web/main_thread.rs @@ -0,0 +1,96 @@ +use std::fmt::{self, Debug, Formatter}; +use std::marker::PhantomData; +use std::mem; +use std::sync::OnceLock; + +use wasm_bindgen::prelude::wasm_bindgen; +use wasm_bindgen::{JsCast, JsValue}; + +use super::r#async::{self, AsyncSender}; + +thread_local! { + static MAIN_THREAD: bool = { + #[wasm_bindgen] + extern "C" { + #[derive(Clone)] + type Global; + + #[wasm_bindgen(method, getter, js_name = Window)] + fn window(this: &Global) -> JsValue; + } + + let global: Global = js_sys::global().unchecked_into(); + !global.window().is_undefined() + }; +} + +#[derive(Clone, Copy, Debug)] +pub struct MainThreadMarker(PhantomData<*const ()>); + +impl MainThreadMarker { + pub fn new() -> Option { + MAIN_THREAD.with(|is| is.then_some(Self(PhantomData))) + } +} + +pub struct MainThreadSafe(Option); + +impl MainThreadSafe { + pub fn new(_: MainThreadMarker, value: T) -> Self { + DROP_HANDLER.get_or_init(|| { + let (sender, receiver) = r#async::channel(); + wasm_bindgen_futures::spawn_local( + async move { while receiver.next().await.is_ok() {} }, + ); + + sender + }); + + Self(Some(value)) + } + + pub fn into_inner(mut self, _: MainThreadMarker) -> T { + self.0.take().expect("already taken or dropped") + } + + pub fn get(&self, _: MainThreadMarker) -> &T { + self.0.as_ref().expect("already taken or dropped") + } +} + +impl Debug for MainThreadSafe { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + if MainThreadMarker::new().is_some() { + f.debug_tuple("MainThreadSafe").field(&self.0).finish() + } else { + f.debug_struct("MainThreadSafe").finish_non_exhaustive() + } + } +} + +impl Drop for MainThreadSafe { + fn drop(&mut self) { + if let Some(value) = self.0.take() { + if mem::needs_drop::() && MainThreadMarker::new().is_none() { + DROP_HANDLER + .get() + .expect("drop handler not initialized when setting canvas") + .send(DropBox(Box::new(value))) + .expect("sender dropped in main thread") + } + } + } +} + +unsafe impl Send for MainThreadSafe {} +unsafe impl Sync for MainThreadSafe {} + +static DROP_HANDLER: OnceLock> = OnceLock::new(); + +struct DropBox(Box); + +unsafe impl Send for DropBox {} +unsafe impl Sync for DropBox {} + +trait Any {} +impl Any for T {} diff --git a/src/platform_impl/web/mod.rs b/src/platform_impl/web/mod.rs index 56ca0eaf47..f3c063e218 100644 --- a/src/platform_impl/web/mod.rs +++ b/src/platform_impl/web/mod.rs @@ -23,6 +23,7 @@ mod device; mod error; mod event_loop; mod keyboard; +mod main_thread; mod monitor; mod window; diff --git a/src/platform_impl/web/web_sys/canvas.rs b/src/platform_impl/web/web_sys/canvas.rs index 360c9895d6..390a54569c 100644 --- a/src/platform_impl/web/web_sys/canvas.rs +++ b/src/platform_impl/web/web_sys/canvas.rs @@ -17,6 +17,7 @@ use crate::keyboard::{Key, KeyLocation, ModifiersState, PhysicalKey}; use crate::platform_impl::{OsError, PlatformSpecificWindowBuilderAttributes}; use crate::window::{WindowAttributes, WindowId as RootWindowId}; +use super::super::main_thread::MainThreadMarker; use super::super::WindowId; use super::animation_frame::AnimationFrameHandler; use super::event_handle::EventListenerHandle; @@ -66,13 +67,18 @@ pub struct Style { impl Canvas { pub fn create( + main_thread: MainThreadMarker, id: WindowId, window: web_sys::Window, document: Document, attr: &WindowAttributes, - platform_attr: PlatformSpecificWindowBuilderAttributes, + mut platform_attr: PlatformSpecificWindowBuilderAttributes, ) -> Result { - let canvas = match platform_attr.canvas.0 { + let canvas = match platform_attr.canvas.take().map(|canvas| { + Arc::try_unwrap(canvas) + .map(|canvas| canvas.into_inner(main_thread)) + .unwrap_or_else(|canvas| canvas.get(main_thread).clone()) + }) { Some(canvas) => canvas, None => document .create_element("canvas") diff --git a/src/platform_impl/web/window.rs b/src/platform_impl/web/window.rs index c47fc3c6be..645f2e8548 100644 --- a/src/platform_impl/web/window.rs +++ b/src/platform_impl/web/window.rs @@ -5,9 +5,9 @@ use crate::window::{ Cursor, CursorGrabMode, ImePurpose, ResizeDirection, Theme, UserAttentionType, WindowAttributes, WindowButtons, WindowId as RootWI, WindowLevel, }; -use crate::SendSyncWrapper; use super::cursor::CursorState; +use super::main_thread::{MainThreadMarker, MainThreadSafe}; use super::r#async::Dispatcher; use super::{backend, monitor::MonitorHandle, EventLoopWindowTarget, Fullscreen}; use web_sys::HtmlCanvasElement; @@ -15,6 +15,7 @@ use web_sys::HtmlCanvasElement; use std::cell::RefCell; use std::collections::VecDeque; use std::rc::Rc; +use std::sync::Arc; pub struct Window { inner: Dispatcher, @@ -38,10 +39,16 @@ impl Window { let window = target.runner.window(); let document = target.runner.document(); - let canvas = - backend::Canvas::create(id, window.clone(), document.clone(), &attr, platform_attr)?; + let canvas = backend::Canvas::create( + target.runner.main_thread(), + id, + window.clone(), + document.clone(), + &attr, + platform_attr, + )?; let canvas = Rc::new(RefCell::new(canvas)); - let cursor = CursorState::new(canvas.borrow().style().clone()); + let cursor = CursorState::new(target.runner.main_thread(), canvas.borrow().style().clone()); target.register(&canvas, id); @@ -62,7 +69,7 @@ impl Window { inner.set_window_icon(attr.window_icon); let canvas = Rc::downgrade(&inner.canvas); - let (dispatcher, runner) = Dispatcher::new(inner).unwrap(); + let (dispatcher, runner) = Dispatcher::new(target.runner.main_thread(), inner).unwrap(); target.runner.add_canvas(RootWI(id), canvas, runner); Ok(Window { inner: dispatcher }) @@ -465,16 +472,30 @@ impl From for WindowId { #[derive(Clone)] pub struct PlatformSpecificWindowBuilderAttributes { - pub(crate) canvas: SendSyncWrapper>, + pub(crate) canvas: Option>>, pub(crate) prevent_default: bool, pub(crate) focusable: bool, pub(crate) append: bool, } +impl PlatformSpecificWindowBuilderAttributes { + pub(crate) fn set_canvas(&mut self, canvas: Option) { + let Some(canvas) = canvas else { + self.canvas = None; + return; + }; + + let main_thread = MainThreadMarker::new() + .expect("received a `HtmlCanvasElement` outside the window context"); + + self.canvas = Some(Arc::new(MainThreadSafe::new(main_thread, canvas))); + } +} + impl Default for PlatformSpecificWindowBuilderAttributes { fn default() -> Self { Self { - canvas: SendSyncWrapper(None), + canvas: None, prevent_default: true, focusable: true, append: false,