Skip to content

Commit

Permalink
Removed num_traits::Num requirement from Zeros.
Browse files Browse the repository at this point in the history
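Had to figure out a way to store zeros in place: the buffer is now filled by copying zero bytes (`vec![0u8; size]`) to the device instead of building a vector of typed `E` zeros.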
  • Loading branch information
favilo committed Nov 28, 2023
1 parent e600b98 commit 3b25249
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 14 deletions.
16 changes: 4 additions & 12 deletions dfdx-core/src/tensor/webgpu/allocate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,28 +52,20 @@ impl Webgpu {
}
}

impl<E: Unit + SafeZeros + From<bool>> ZerosTensor<E> for Webgpu {
impl<E: Unit + SafeZeros> ZerosTensor<E> for Webgpu {
fn try_zeros_like<S: HasShape>(&self, src: &S) -> Result<Tensor<S::Shape, E, Self>, Error> {
let shape = *src.shape();
let strides = shape.strides();
let data = unsafe { self.alloc_empty::<E>(shape.num_elements()) }?;
data.copy_to_device(
&self.dev,
&self.queue,
&vec![E::from(false); shape.num_elements()],
);
data.copy_to_device(&self.dev, &self.queue, &vec![0u8; data.size()]);

Ok(self.build_tensor(shape, strides, data))
}
}

impl<E: Unit + SafeZeros + From<bool>> ZeroFillStorage<E> for Webgpu {
impl<E: Unit + SafeZeros> ZeroFillStorage<E> for Webgpu {
fn try_fill_with_zeros(&self, storage: &mut Self::Vec) -> Result<(), Error> {
storage.copy_to_device(
&self.dev,
&self.queue,
&vec![E::from(false); storage.size() as usize / std::mem::size_of::<E>()],
);
storage.copy_to_device(&self.dev, &self.queue, &vec![0u8; storage.size()]);

Ok(())
}
Expand Down
4 changes: 2 additions & 2 deletions dfdx-core/src/tensor_ops/utilities/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,9 @@ impl Device<f32> for crate::tensor::Cuda {}
impl Device<f64> for crate::tensor::Cuda {}

#[cfg(all(feature = "webgpu", feature = "f16"))]
impl Device<f16> for crate::tensor::Cuda {}
impl Device<f16> for crate::tensor::Webgpu {}
#[cfg(all(feature = "webgpu", feature = "f16"))]
impl Device<AMP<f16>> for crate::tensor::Cuda {}
impl Device<AMP<f16>> for crate::tensor::Webgpu {}
#[cfg(feature = "webgpu")]
impl Device<f32> for crate::tensor::Webgpu {}
#[cfg(feature = "webgpu")]
Expand Down

0 comments on commit 3b25249

Please sign in to comment.