Skip to content

Commit

Permalink
Add Mutex back, since evidently it was causing issues.
Browse files Browse the repository at this point in the history
Hopefully I can figure out a way to remove it again.
  • Loading branch information
favilo committed Nov 27, 2023
1 parent 32a60b4 commit e600b98
Showing 1 changed file with 16 additions and 9 deletions.
25 changes: 16 additions & 9 deletions dfdx-core/src/tensor/webgpu/device.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use wgpu::{
util::DownloadBuffer, Adapter, Backends, BufferDescriptor, BufferUsages, Device, Instance,
InstanceDescriptor, Maintain, Queue, RequestDeviceError, COPY_BUFFER_ALIGNMENT,
Adapter, BufferDescriptor, BufferUsages, Device, Instance, InstanceDescriptor, Maintain, Queue,
RequestDeviceError,
};

use crate::{
Expand All @@ -11,12 +11,7 @@ use crate::{
},
};

use core::sync::atomic::AtomicPtr;
use std::{
marker::PhantomData,
sync::{atomic::Ordering, Arc},
vec::Vec,
};
use std::{marker::PhantomData, sync::Arc, vec::Vec};

use super::allocate::round_to_buffer_alignment;

Expand Down Expand Up @@ -45,7 +40,12 @@ impl Buffer {
}

pub(crate) fn copy_to_device<E: Unit>(&self, dev: &Device, queue: &Queue, slice: &[E]) {
let slice = unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const u8, self.size()) };
let slice = unsafe {
std::slice::from_raw_parts(
slice.as_ptr() as *const u8,
slice.len() * std::mem::size_of::<E>(),
)
};
queue.write_buffer(&self.data, 0, slice);
queue.submit(std::iter::empty());
dev.poll(Maintain::Wait);
Expand Down Expand Up @@ -80,13 +80,17 @@ impl Buffer {
)
};
buf.copy_from_slice(slice);
drop(data);
buffer.unmap();
}
}

#[derive(Clone, Debug)]
pub struct Webgpu {
pub(crate) cpu: Cpu,
#[allow(unused)]
pub(crate) instance: Arc<Instance>,
#[allow(unused)]
pub(crate) adapter: Arc<Adapter>,
pub(crate) dev: Arc<Device>,
pub(crate) queue: Arc<Queue>,
Expand All @@ -106,12 +110,15 @@ impl Default for Webgpu {
}
}

static CONSTRUCTOR_MUTEX: std::sync::Mutex<()> = std::sync::Mutex::new(());

impl Webgpu {
pub fn seed_from_u64(seed: u64) -> Self {
Self::try_build(seed).unwrap()
}

pub fn try_build(seed: u64) -> Result<Self, Error> {
let _lock = CONSTRUCTOR_MUTEX.lock().unwrap();
let cpu = Cpu::seed_from_u64(seed);
let instance = Arc::new(Instance::new(InstanceDescriptor::default()));
let adapter = futures_lite::future::block_on(instance.request_adapter(&Default::default()))
Expand Down

0 comments on commit e600b98

Please sign in to comment.