diff --git a/dfdx-core/Cargo.toml b/dfdx-core/Cargo.toml
index 7f4041fa..0eeac1f4 100644
--- a/dfdx-core/Cargo.toml
+++ b/dfdx-core/Cargo.toml
@@ -40,6 +40,7 @@ rayon = { version = "1.7.0", optional = true }
 libm = { workspace = true }
 wgpu = { version = "0.18.0", optional = true }
 futures-lite = { version = "2.0.1", optional = true }
+thingbuf = { version = "0.1.4", optional = true }
 
 [dev-dependencies]
 tempfile = "3.3.0"
@@ -61,7 +62,7 @@ fast-alloc = ["std"]
 
 cuda = ["dep:cudarc", "dep:glob"]
 cudnn = ["cuda", "cudarc?/cudnn"]
-webgpu = ["dep:wgpu", "dep:futures-lite", "wgpu/expose-ids"]
+webgpu = ["dep:wgpu", "dep:futures-lite", "dep:thingbuf", "wgpu/expose-ids"]
 
 f16 = ["dep:half", "cudarc?/f16", "gemm?/f16"]
 
diff --git a/dfdx-core/src/tensor/webgpu/allocate.rs b/dfdx-core/src/tensor/webgpu/allocate.rs
index 52b22724..49162381 100644
--- a/dfdx-core/src/tensor/webgpu/allocate.rs
+++ b/dfdx-core/src/tensor/webgpu/allocate.rs
@@ -74,7 +74,7 @@ impl<E: Unit> ZeroFillStorage<E> for Webgpu {
 impl<E: Unit> OnesTensor<E> for Webgpu {
     fn try_ones_like<S: HasShape>(&self, src: &S) -> Result<Tensor<S::Shape, E, Self>, Error> {
         let shape = *src.shape();
-        let buf = std::vec![E::ONE; shape.num_elements()];
+        let buf = vec![E::ONE; shape.num_elements()];
         self.tensor_from_host_buf(shape, buf)
     }
 }
@@ -90,7 +90,7 @@ where
         diagonal: impl Into<Option<isize>>,
     ) -> Result<Tensor<S::Shape, E, Self>, Error> {
         let shape = *src.shape();
-        let mut data = std::vec![val; shape.num_elements()];
+        let mut data = vec![val; shape.num_elements()];
         let offset = diagonal.into().unwrap_or(0);
         triangle_mask(&mut data, &shape, true, offset);
         self.tensor_from_host_buf(shape, data)
@@ -103,7 +103,7 @@ where
         diagonal: impl Into<Option<isize>>,
     ) -> Result<Tensor<S::Shape, E, Self>, Error> {
         let shape = *src.shape();
-        let mut data = std::vec![val; shape.num_elements()];
+        let mut data = vec![val; shape.num_elements()];
         let offset = diagonal.into().unwrap_or(0);
         triangle_mask(&mut data, &shape, false, offset);
         self.tensor_from_host_buf(shape, data)
@@ -113,7 +113,7 @@ where
 impl<E: Unit> OneFillStorage<E> for Webgpu {
     fn try_fill_with_ones(&self, storage: &mut Self::Vec) -> Result<(), Error> {
         let len = storage.size() as usize / std::mem::size_of::<E>();
-        let buf = std::vec![E::ONE; len];
+        let buf = vec![E::ONE; len];
         storage
             .data
             .copy_to_device::<E>(&self.dev, &self.queue, &buf);
diff --git a/dfdx-core/src/tensor/webgpu/device.rs b/dfdx-core/src/tensor/webgpu/device.rs
index a82e7019..3cba06c7 100644
--- a/dfdx-core/src/tensor/webgpu/device.rs
+++ b/dfdx-core/src/tensor/webgpu/device.rs
@@ -11,6 +11,12 @@ use crate::{
     },
 };
 
+#[cfg(feature = "no-std")]
+use spin::Mutex;
+
+#[cfg(not(feature = "no-std"))]
+use std::sync::Mutex;
+
 use std::{marker::PhantomData, sync::Arc, vec::Vec};
 
 use super::allocate::round_to_buffer_alignment;
@@ -52,7 +58,7 @@ impl Buffer {
     }
 
     pub(crate) fn copy_to_host<E>(&self, dev: &Device, queue: &Queue, buf: &mut [E]) {
-        let (sender, receiver) = std::sync::mpsc::channel();
+        let (sender, receiver) = thingbuf::mpsc::channel(1);
         let buffer = dev.create_buffer(&BufferDescriptor {
             label: None,
             size: self.size() as u64,
@@ -66,11 +72,11 @@ impl Buffer {
         }
         let slice = buffer.slice(..self.size() as u64);
         slice.map_async(wgpu::MapMode::Read, move |_| {
-            sender.send(()).unwrap();
+            futures_lite::future::block_on(sender.send(())).unwrap();
         });
         dev.poll(Maintain::Wait);
-        let _ = receiver.recv().unwrap();
+        let _ = futures_lite::future::block_on(receiver.recv());
         let data = slice.get_mapped_range();
         // TODO: How are we sure this is safe?
         let slice = unsafe {
@@ -110,7 +116,7 @@ impl Default for Webgpu {
     }
 }
 
-static CONSTRUCTOR_MUTEX: std::sync::Mutex<()> = std::sync::Mutex::new(());
+static CONSTRUCTOR_MUTEX: Mutex<()> = Mutex::new(());
 
 impl Webgpu {
     pub fn seed_from_u64(seed: u64) -> Self {
@@ -118,7 +124,11 @@ impl Webgpu {
     }
 
     pub fn try_build(seed: u64) -> Result<Self, Error> {
-        let _lock = CONSTRUCTOR_MUTEX.lock().unwrap();
+        #[cfg(feature = "no-std")]
+        let _lock = { CONSTRUCTOR_MUTEX.lock() };
+        #[cfg(not(feature = "no-std"))]
+        let _lock = { CONSTRUCTOR_MUTEX.lock().unwrap() };
+
         let cpu = Cpu::seed_from_u64(seed);
         let instance = Arc::new(Instance::new(InstanceDescriptor::default()));
         let adapter = futures_lite::future::block_on(instance.request_adapter(&Default::default()))
@@ -332,7 +342,7 @@ impl<E: Unit> Storage<E> for Webgpu {
             device: self.cpu.clone(),
             tape: NoneTape,
         };
-        let buf = std::sync::Arc::get_mut(&mut cpu_tensor.data).unwrap();
+        let buf = Arc::get_mut(&mut cpu_tensor.data).unwrap();
         tensor.data.copy_to_host::<E>(&self.dev, &self.queue, buf);
         self.cpu.tensor_to_vec::<E>(&cpu_tensor)
     }
diff --git a/dfdx-core/src/tensor_ops/stack/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/stack/webgpu_kernel.rs
index d2aebfc8..113aae1a 100644
--- a/dfdx-core/src/tensor_ops/stack/webgpu_kernel.rs
+++ b/dfdx-core/src/tensor_ops/stack/webgpu_kernel.rs
@@ -1,4 +1,5 @@
 use crate::{shapes::*, tensor::Webgpu};
+use std::vec::Vec;
 
 impl<E: Dtype> super::StackKernel<E> for Webgpu {
     fn forward(