diff --git a/dfdx-core/Cargo.toml b/dfdx-core/Cargo.toml index 0eeac1f4..5309ef7c 100644 --- a/dfdx-core/Cargo.toml +++ b/dfdx-core/Cargo.toml @@ -38,7 +38,8 @@ half = { version = "2.3.1", optional = true, features = ["num-traits", "rand_dis gemm = { version = "0.16.14", default-features = false, optional = true, features = ["rayon"] } rayon = { version = "1.7.0", optional = true } libm = { workspace = true } -wgpu = { version = "0.18.0", optional = true } +wgpu = { version = "0.18.0", features = ["glsl", "spirv"], optional = true } +naga = { version = "0.14.1", optional = true } futures-lite = { version = "2.0.1", optional = true } thingbuf = { version = "0.1.4", optional = true } @@ -62,7 +63,14 @@ fast-alloc = ["std"] cuda = ["dep:cudarc", "dep:glob"] cudnn = ["cuda", "cudarc?/cudnn"] -webgpu = ["dep:wgpu", "dep:futures-lite", "dep:thingbuf", "wgpu/expose-ids"] +webgpu = [ + "dep:wgpu", + "dep:futures-lite", + "dep:thingbuf", + "dep:naga", + "dep:glob", + "wgpu/expose-ids", +] f16 = ["dep:half", "cudarc?/f16", "gemm?/f16"] diff --git a/dfdx-core/build.rs b/dfdx-core/build.rs index 1048382a..76d33682 100644 --- a/dfdx-core/build.rs +++ b/dfdx-core/build.rs @@ -9,6 +9,9 @@ fn main() { #[cfg(feature = "cuda")] cuda::build_ptx(); + + #[cfg(feature = "webgpu")] + webgpu::build_spv(); } fn maybe_enable_nightly() { @@ -210,3 +213,52 @@ mod cuda { } } } + +#[cfg(feature = "webgpu")] +mod webgpu { + pub fn build_spv() { + let out_dir = std::env::var("OUT_DIR").unwrap(); + let kernel_paths: Vec = glob::glob("src/**/*.glsl") + .unwrap() + .map(|p| p.unwrap()) + .collect(); + for path in &kernel_paths { + println!("cargo:rerun-if-changed={}", path.display()); + } + + kernel_paths + .iter() + .for_each(|p| println!("cargo:rerun-if-changed={}", p.display())); + + let children = kernel_paths + .iter() + .map(|p| { + // TODO: we need to build this for both float and double + let out_path: std::path::PathBuf = out_dir.clone().into(); + let base = p.file_stem().unwrap(); + let new_name = format!("{}.float.spv", base.to_str().unwrap()); + let out_file = &out_path.join(new_name); + eprintln!("out_file: {:?}", out_file); + std::process::Command::new("glslc") + .args(["-std=460core"]) + .args(["-fshader-stage=compute"]) + .args(["-DTYPENAME=float"]) + .args(["-o", &out_file.as_os_str().to_str().unwrap()]) + .arg(p) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::piped()) + .spawn() + .expect("glslc failed to start. Ensure that you have shaderc installed and that `glslc` is in your PATH.") + }) + .collect::>(); + for (kernel_path, child) in kernel_paths.iter().zip(children.into_iter()) { + let output = child.wait_with_output().expect("glslc failed to run. Ensure that you have shaderc installed and that `glslc` is in your PATH."); + assert!( + output.status.success(), + "glslc error while compiling {kernel_path:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + } + } +} diff --git a/dfdx-core/src/tensor/error.rs b/dfdx-core/src/tensor/error.rs index 906c474c..6ab35ac6 100644 --- a/dfdx-core/src/tensor/error.rs +++ b/dfdx-core/src/tensor/error.rs @@ -22,6 +22,9 @@ pub enum Error { #[cfg(feature = "webgpu")] WebgpuRequestDeviceError(wgpu::RequestDeviceError), + + #[cfg(feature = "webgpu")] + WebgpuSourceLoadError, } impl std::fmt::Display for Error { diff --git a/dfdx-core/src/tensor/webgpu/allocate.rs b/dfdx-core/src/tensor/webgpu/allocate.rs index 49162381..4e4a2692 100644 --- a/dfdx-core/src/tensor/webgpu/allocate.rs +++ b/dfdx-core/src/tensor/webgpu/allocate.rs @@ -22,7 +22,7 @@ impl Webgpu { shape: S, buf: Vec, ) -> Result, Error> { - let buffer = unsafe { self.alloc_empty::(buf.len()) }?; + let buffer = self.alloc_empty::(buf.len())?; buffer.copy_to_device::(&self.dev, &self.queue, &buf); Ok(self.build_tensor(shape, shape.strides(), buffer)) @@ -56,7 +56,7 @@ impl ZerosTensor for Webgpu { fn try_zeros_like(&self, src: &S) -> Result, Error> { let shape = *src.shape(); let strides = shape.strides(); - let data = unsafe { self.alloc_empty::(shape.num_elements()) }?; + let data = self.alloc_empty::(shape.num_elements())?; data.copy_to_device(&self.dev, &self.queue, &vec![0u8; data.size()]); Ok(self.build_tensor(shape, strides, data)) diff --git a/dfdx-core/src/tensor/webgpu/device.rs b/dfdx-core/src/tensor/webgpu/device.rs index 086eb4e4..23d73060 100644 --- a/dfdx-core/src/tensor/webgpu/device.rs +++ b/dfdx-core/src/tensor/webgpu/device.rs @@ -1,10 +1,12 @@ use wgpu::{ - util::{BufferInitDescriptor, DeviceExt}, - Adapter, BufferDescriptor, BufferUsages, Device, Instance, InstanceDescriptor, Maintain, Queue, - RequestDeviceError, ShaderModule, ShaderModuleDescriptor, + util::{make_spirv, make_spirv_raw, BufferInitDescriptor, DeviceExt}, + Adapter, BufferDescriptor, BufferUsages, Device, DeviceDescriptor, Features, Instance, + InstanceDescriptor, Maintain, Queue, RequestDeviceError, ShaderModule, ShaderModuleDescriptor, + ShaderModuleDescriptorSpirV, }; use crate::{ + prelude::webgpu_kernels::HasGlslType, shapes::{Shape, Unit}, tensor::{ cache::TensorCache, cpu::Cpu, Cache, Error, NoneTape, RandomU64, Storage, Synchronize, @@ -141,8 +143,13 @@ impl Webgpu { let adapter = futures_lite::future::block_on(instance.request_adapter(&Default::default())) .ok_or(Error::WebgpuAdapterNotFound)?; let adapter = Arc::new(adapter); + let descriptor = DeviceDescriptor { + label: None, + features: Features::default() | Features::SPIRV_SHADER_PASSTHROUGH, + limits: Default::default(), + }; let (dev, queue) = - futures_lite::future::block_on(adapter.request_device(&Default::default(), None))?; + futures_lite::future::block_on(adapter.request_device(&descriptor, None))?; let dev = Arc::new(dev); let queue = Arc::new(queue); @@ -214,10 +221,22 @@ impl Webgpu { self.cs_cache.read().contains_key(&name) } - pub(crate) fn load_shader_module(&self, name: TypeId, source: &str) { + pub(crate) fn load_shader_module(&self, name: TypeId, source: &[u8]) + where + E: HasGlslType, + { + // TODO: Get raw SpirV working. I am guessing that is how we are going + // to have to implement atomic stuff with `wgpu`. + // + // let module = Arc::new(unsafe { + // self.dev.create_shader_module_spirv(&ShaderModuleDescriptorSpirV { + // label: None, + // source: make_spirv_raw(source), + // }) + // }); let module = Arc::new(self.dev.create_shader_module(ShaderModuleDescriptor { label: None, - source: wgpu::ShaderSource::Wgsl(source.into()), + source: make_spirv(source), })); #[cfg(not(feature = "no-std"))] self.cs_cache.write().unwrap().insert(name, module); diff --git a/dfdx-core/src/tensor_ops/abs/abs.bwd.glsl b/dfdx-core/src/tensor_ops/abs/abs.bwd.glsl new file mode 100644 index 00000000..cd765ec8 --- /dev/null +++ b/dfdx-core/src/tensor_ops/abs/abs.bwd.glsl @@ -0,0 +1,28 @@ +#version 460 core + +#extension GL_ARB_compute_shader: enable +#extension GL_ARB_shader_storage_buffer_object: enable + +layout(local_size_x = 128) in; + +layout(std430, binding = 1) buffer inpBlock { + TYPENAME inp[]; +}; + +layout(std430, binding = 2) buffer outpBlock { + TYPENAME outp[]; +}; + +layout(std430, binding = 3) buffer input_gradBlock { + TYPENAME input_grad[]; +}; + +layout(std430, binding = 4) buffer output_gradBlock { + TYPENAME output_grad[]; +}; + +void main() { + TYPENAME dx = sign(inp[gl_GlobalInvocationID.x]); + + input_grad[gl_GlobalInvocationID.x] = dx * output_grad[gl_GlobalInvocationID.x]; +} diff --git a/dfdx-core/src/tensor_ops/abs/abs.fwd.glsl b/dfdx-core/src/tensor_ops/abs/abs.fwd.glsl new file mode 100644 index 00000000..00f4c5a8 --- /dev/null +++ b/dfdx-core/src/tensor_ops/abs/abs.fwd.glsl @@ -0,0 +1,22 @@ +#version 460 core + +#extension GL_ARB_compute_shader: enable +#extension GL_ARB_shader_storage_buffer_object: enable + +layout(local_size_x = 128) in; + +layout(std430, binding = 1) buffer inpBlock { + TYPENAME inp[]; +}; + +layout(std430, binding = 2) buffer outpBlock{ + TYPENAME outp[]; +}; + +void main() { + if (inp.length() == 0) { + outp[gl_GlobalInvocationID.x] = abs(outp[gl_GlobalInvocationID.x]); + } else { + outp[gl_GlobalInvocationID.x] = abs(inp[gl_GlobalInvocationID.x]); + } +} diff --git a/dfdx-core/src/tensor_ops/abs/abs.wgsl b/dfdx-core/src/tensor_ops/abs/abs.wgsl deleted file mode 100644 index cd812370..00000000 --- a/dfdx-core/src/tensor_ops/abs/abs.wgsl +++ /dev/null @@ -1,48 +0,0 @@ -// TODO: We need to figure out how to represent empty structs in wgsl -// struct AbsKernelOp { -// empty: u32, -// } - -@group(0) -@binding(0) -var op: array; - -@group(0) -@binding(1) -var inp: array; - -@group(0) -@binding(2) -var out: array; - -@group(0) -@binding(3) -var inp_grad: array; - -@group(0) -@binding(4) -var out_grad: array; - -@compute -@workgroup_size(1) -fn abs_fwd_f32(@builtin(global_invocation_id) global_id: vec3) { - // let length: u32 = arrayLength(&inp); - // if (length > 1) { - // out[global_id.x] = abs(inp[global_id.x]); - // } else { - // out[global_id.x] = abs(out[global_id.x]); - // } - out[global_id.x] = abs(inp[global_id.x]); -} - -@compute -@workgroup_size(1) -fn abs_bwd_f32(@builtin(global_invocation_id) global_id: vec3) { - // Not needed for Abs, but if we can figure out a template system, we can leave it all in. - // let x = if arrayLength(inp) > 0 { inp[global_id] } else { 0.0 }; - // let y = if arrayLength(out) > 0 { out[global_id] } else { 0.0 }; - var dx: f32; - dx = sign(inp[global_id.x]); - - inp_grad[global_id.x] += dx * out_grad[global_id.x]; -} diff --git a/dfdx-core/src/tensor_ops/abs/mod.rs b/dfdx-core/src/tensor_ops/abs/mod.rs index e2ce6982..45c7794d 100644 --- a/dfdx-core/src/tensor_ops/abs/mod.rs +++ b/dfdx-core/src/tensor_ops/abs/mod.rs @@ -57,16 +57,4 @@ mod tests { let g = r.mean().backward(); assert_close_to_literal!(g.get(&x), [-0.2, -0.2, 0.0, 0.2, 0.2]); } - - #[cfg(feature = "webgpu")] - #[test] - fn test_webgpu_abs() { - let dev: Webgpu = Default::default(); - let x = dev.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]); - let r = x.leaky_trace().abs(); - assert_close_to_literal!(r, [2.0, 1.0, 0.0, 1.0, 2.0]); - // TODO: Add mean back in - // let g = r.mean().backward(); - // assert_close_to_literal!(g.get(&x), [-0.2, -0.2, 0.0, 0.2, 0.2]); - } } diff --git a/dfdx-core/src/tensor_ops/abs/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/abs/webgpu_kernel.rs index a5d9059e..130c9e3f 100644 --- a/dfdx-core/src/tensor_ops/abs/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/abs/webgpu_kernel.rs @@ -1,6 +1,25 @@ use super::AbsKernelOp; use crate::tensor_ops::webgpu_kernels::webgpu_unary; -const WGSL: &str = include_str!("abs.wgsl"); +const GLSL_FWD: &str = include_str!("abs.fwd.glsl"); +const GLSL_BWD: &str = include_str!("abs.bwd.glsl"); +const SPV_FWD: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/abs.fwd.float.spv")); +const SPV_BWD: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/abs.bwd.float.spv")); -webgpu_unary!(AbsKernelOp, f32, WGSL, "abs_fwd_f32", "abs_bwd_f32"); +webgpu_unary!(AbsKernelOp, f32, SPV_FWD, SPV_BWD); + +#[cfg(test)] +mod tests { + use crate::{tensor::*, tensor_ops::*, tests::*}; + + #[test] + fn test_webgpu_abs() { + let dev: Webgpu = Default::default(); + let x = dev.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]); + let r = x.leaky_trace().abs(); + assert_close_to_literal!(r, [2.0, 1.0, 0.0, 1.0, 2.0]); + // TODO: Add mean back in + // let g = r.mean().backward(); + // assert_close_to_literal!(g.get(&x), [-0.2, -0.2, 0.0, 0.2, 0.2]); + } +} diff --git a/dfdx-core/src/tensor_ops/accurate_gelu/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/accurate_gelu/webgpu_kernel.rs index de36720c..50a14d55 100644 --- a/dfdx-core/src/tensor_ops/accurate_gelu/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/accurate_gelu/webgpu_kernel.rs @@ -1,11 +1,5 @@ use crate::prelude::webgpu_kernels::webgpu_unary; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!( - super::AccurateGeLUKernelOp, - f32, - WGSL, - "gelu_fwd_f32", - "gelu_bwd_f32" -); +webgpu_unary!(super::AccurateGeLUKernelOp, f32, WGSL, WGSL); diff --git a/dfdx-core/src/tensor_ops/add/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/add/webgpu_kernel.rs index 204a4357..4caf717b 100644 --- a/dfdx-core/src/tensor_ops/add/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/add/webgpu_kernel.rs @@ -7,9 +7,9 @@ use crate::prelude::{ Dtype, Webgpu, }; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!(Scalar, f32, WGSL, "scalar_fwd_f32", "scalar_bwd_f32"); +webgpu_unary!(Scalar, f32, WGSL, WGSL); impl BinaryKernel for Webgpu { const BACKWARD_WITHOUT_DATA: bool = true; diff --git a/dfdx-core/src/tensor_ops/clamp/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/clamp/webgpu_kernel.rs index 82485d5a..ced36298 100644 --- a/dfdx-core/src/tensor_ops/clamp/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/clamp/webgpu_kernel.rs @@ -1,11 +1,5 @@ use crate::prelude::webgpu_kernels::webgpu_unary; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!( - super::ClampKernelOp, - f32, - WGSL, - "clamp_fwd_f32", - "clamp_bwd_f32" -); +webgpu_unary!(super::ClampKernelOp, f32, WGSL, WGSL); diff --git a/dfdx-core/src/tensor_ops/cos/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/cos/webgpu_kernel.rs index a352702f..88737c1b 100644 --- a/dfdx-core/src/tensor_ops/cos/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/cos/webgpu_kernel.rs @@ -1,5 +1,5 @@ use crate::prelude::webgpu_kernels::webgpu_unary; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!(super::CosKernelOp, f32, WGSL, "cos_fwd_f32", "cos_bwd_f32"); +webgpu_unary!(super::CosKernelOp, f32, WGSL, WGSL); diff --git a/dfdx-core/src/tensor_ops/div/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/div/webgpu_kernel.rs index 158a05fd..2ba21757 100644 --- a/dfdx-core/src/tensor_ops/div/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/div/webgpu_kernel.rs @@ -3,9 +3,9 @@ use std::borrow::Cow; use crate::prelude::{ops::BinaryKernel, webgpu_kernels::webgpu_unary, Dtype, Webgpu}; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!(const_df() Scalar, f32, WGSL, "scalar_sub_fwd", "scalar_sub_bwd"); +webgpu_unary!(const_df() Scalar, f32, WGSL, WGSL); impl BinaryKernel for Webgpu { const BACKWARD_WITHOUT_DATA: bool = true; diff --git a/dfdx-core/src/tensor_ops/exp/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/exp/webgpu_kernel.rs index 8670fc23..5438fd96 100644 --- a/dfdx-core/src/tensor_ops/exp/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/exp/webgpu_kernel.rs @@ -1,5 +1,5 @@ use crate::prelude::webgpu_kernels::webgpu_unary; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!(super::ExpKernelOp, f32, WGSL, "exp_fwd_f32", "exp_bwd_f32"); +webgpu_unary!(super::ExpKernelOp, f32, WGSL, WGSL); diff --git a/dfdx-core/src/tensor_ops/fast_gelu/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/fast_gelu/webgpu_kernel.rs index 4abc2738..438e3066 100644 --- a/dfdx-core/src/tensor_ops/fast_gelu/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/fast_gelu/webgpu_kernel.rs @@ -1,11 +1,5 @@ use crate::prelude::webgpu_kernels::webgpu_unary; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!( - super::FastGeLUKernelOp, - f32, - WGSL, - "sigmoid_fwd_f32", - "sigmoid_bwd_f32" -); +webgpu_unary!(super::FastGeLUKernelOp, f32, WGSL, WGSL); diff --git a/dfdx-core/src/tensor_ops/ln/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/ln/webgpu_kernel.rs index bf08f71f..e0d574a6 100644 --- a/dfdx-core/src/tensor_ops/ln/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/ln/webgpu_kernel.rs @@ -1,5 +1,5 @@ use crate::prelude::webgpu_kernels::webgpu_unary; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!(super::LnKernelOp, f32, WGSL, "ln_fwd_f32", "ln_bwd_f32"); +webgpu_unary!(super::LnKernelOp, f32, WGSL, WGSL); diff --git a/dfdx-core/src/tensor_ops/mul/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/mul/webgpu_kernel.rs index 6b778db4..f3176043 100644 --- a/dfdx-core/src/tensor_ops/mul/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/mul/webgpu_kernel.rs @@ -3,9 +3,9 @@ use std::borrow::Cow; use crate::prelude::{ops::BinaryKernel, webgpu_kernels::webgpu_unary, Dtype, Webgpu}; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!(const_df() Scalar, f32, WGSL, "scalar_mul_fwd", "scalar_mul_bwd"); +webgpu_unary!(const_df() Scalar, f32, WGSL, WGSL); impl BinaryKernel for Webgpu { const BACKWARD_WITHOUT_DATA: bool = true; diff --git a/dfdx-core/src/tensor_ops/nans_to/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/nans_to/webgpu_kernel.rs index 5721f4a0..a86c5ab3 100644 --- a/dfdx-core/src/tensor_ops/nans_to/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/nans_to/webgpu_kernel.rs @@ -1,12 +1,6 @@ use super::NansToKernelOp; use crate::prelude::webgpu_kernels::webgpu_unary; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!( - NansToKernelOp, - f32, - WGSL, - "nans_to_fwd_f32", - "nans_to_bwd_f32" -); +webgpu_unary!(NansToKernelOp, f32, WGSL, WGSL); diff --git a/dfdx-core/src/tensor_ops/negate/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/negate/webgpu_kernel.rs index be7690f6..0cce0688 100644 --- a/dfdx-core/src/tensor_ops/negate/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/negate/webgpu_kernel.rs @@ -1,11 +1,5 @@ use crate::prelude::webgpu_kernels::webgpu_unary; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!( - super::NegateKernelOp, - f32, - WGSL, - "negate_fwd_f32", - "negate_bwd_f32" -); +webgpu_unary!(super::NegateKernelOp, f32, WGSL, WGSL); diff --git a/dfdx-core/src/tensor_ops/pow/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/pow/webgpu_kernel.rs index d21a15b6..83af122e 100644 --- a/dfdx-core/src/tensor_ops/pow/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/pow/webgpu_kernel.rs @@ -2,15 +2,9 @@ use std::borrow::Cow; use crate::prelude::{ops::UnaryKernel, webgpu_kernels::webgpu_unary, Dtype, Webgpu}; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!( - super::PowfKernelOp, - f32, - WGSL, - "powf_fwd_f32", - "powf_bwd_f32" -); +webgpu_unary!(super::PowfKernelOp, f32, WGSL, WGSL); // TODO: Conflicting implementations of trait `UnaryKernel` for type `Webgpu`: impl UnaryKernel for Webgpu diff --git a/dfdx-core/src/tensor_ops/recip/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/recip/webgpu_kernel.rs index d3a14fe0..7f31f22d 100644 --- a/dfdx-core/src/tensor_ops/recip/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/recip/webgpu_kernel.rs @@ -1,5 +1,5 @@ use crate::prelude::webgpu_kernels::webgpu_unary; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!(df(f(x)) super::RecipKernelOp, f32, WGSL, "recip_fwd_f32", "recip_bwd_f32"); +webgpu_unary!(df(f(x)) super::RecipKernelOp, f32, WGSL, WGSL); diff --git a/dfdx-core/src/tensor_ops/relu/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/relu/webgpu_kernel.rs index c986c73d..d1917a44 100644 --- a/dfdx-core/src/tensor_ops/relu/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/relu/webgpu_kernel.rs @@ -1,11 +1,5 @@ use crate::prelude::webgpu_kernels::webgpu_unary; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!( - super::ReLUKernelOp, - f32, - WGSL, - "relu_fwd_f32", - "relu_bwd_f32" -); +webgpu_unary!(super::ReLUKernelOp, f32, WGSL, WGSL); diff --git a/dfdx-core/src/tensor_ops/sigmoid/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/sigmoid/webgpu_kernel.rs index 377c4453..bbe904eb 100644 --- a/dfdx-core/src/tensor_ops/sigmoid/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/sigmoid/webgpu_kernel.rs @@ -1,5 +1,5 @@ use crate::prelude::webgpu_kernels::webgpu_unary; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!(df(f(x)) super::SigmoidKernelOp, f32, WGSL, "sigmoid_fwd_f32", "sigmoid_bwd_f32"); +webgpu_unary!(df(f(x)) super::SigmoidKernelOp, f32, WGSL, WGSL); diff --git a/dfdx-core/src/tensor_ops/sin/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/sin/webgpu_kernel.rs index befc2da0..ee975e34 100644 --- a/dfdx-core/src/tensor_ops/sin/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/sin/webgpu_kernel.rs @@ -1,5 +1,5 @@ use crate::prelude::webgpu_kernels::webgpu_unary; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!(super::SinKernelOp, f32, WGSL, "sin_fwd_f32", "sin_bwd_f32"); +webgpu_unary!(super::SinKernelOp, f32, WGSL, WGSL); diff --git a/dfdx-core/src/tensor_ops/sqrt/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/sqrt/webgpu_kernel.rs index e2b0b032..86fb7809 100644 --- a/dfdx-core/src/tensor_ops/sqrt/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/sqrt/webgpu_kernel.rs @@ -1,5 +1,5 @@ use crate::prelude::webgpu_kernels::webgpu_unary; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!(df(f(x)) super::SqrtKernelOp, f32, WGSL, "sqrt_fwd_f32", "sqrt_bwd_f32"); +webgpu_unary!(df(f(x)) super::SqrtKernelOp, f32, WGSL, WGSL); diff --git a/dfdx-core/src/tensor_ops/square/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/square/webgpu_kernel.rs index f16523a2..c5a1805a 100644 --- a/dfdx-core/src/tensor_ops/square/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/square/webgpu_kernel.rs @@ -1,11 +1,5 @@ use crate::prelude::webgpu_kernels::webgpu_unary; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!( - super::SquareKernelOp, - f32, - WGSL, - "square_fwd_f32", - "square_bwd_f32" -); +webgpu_unary!(super::SquareKernelOp, f32, WGSL, WGSL); diff --git a/dfdx-core/src/tensor_ops/sub/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/sub/webgpu_kernel.rs index 0d476c94..d2a86789 100644 --- a/dfdx-core/src/tensor_ops/sub/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/sub/webgpu_kernel.rs @@ -4,9 +4,9 @@ use super::{BinarySubKernelOp as Binary, ScalarSubKernelOp as Scalar}; use crate::prelude::{ops::BinaryKernel, webgpu_kernels::webgpu_unary, Dtype, Webgpu}; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!(const_df() Scalar, f32, WGSL, "scalar_sub_fwd", "scalar_sub_bwd"); +webgpu_unary!(const_df() Scalar, f32, WGSL, WGSL); impl BinaryKernel for Webgpu { const BACKWARD_WITHOUT_DATA: bool = true; diff --git a/dfdx-core/src/tensor_ops/sum_to/sum_to.fwd.glsl b/dfdx-core/src/tensor_ops/sum_to/sum_to.fwd.glsl new file mode 100644 index 00000000..a8647f4b --- /dev/null +++ b/dfdx-core/src/tensor_ops/sum_to/sum_to.fwd.glsl @@ -0,0 +1,108 @@ +#version 460 core + +#extension GL_EXT_shader_atomic_float: enable +#extension SPV_EXT_shader_atomic_float_add: enable +#extension GL_ARB_compute_shader: enable +#extension GL_ARB_shader_storage_buffer_object: enable +#extension ARB_shader_atomic_counter_ops: enable +#extension VK_EXT_shader_atomic_float: enable + +layout(local_size_x = 128) in; + +layout(std430, binding = 1) buffer inpBlock { + TYPENAME inp[]; +}; + +layout(std430, binding = 2) buffer outpBlock { + TYPENAME outp[]; +}; + +layout(std430, binding = 3) buffer params { + uint chunk_len; + TYPENAME elems_per_thread; +}; + +layout(std430, binding = 4) buffer dimsBlock { + uint dims[]; +}; + +layout(std430, binding = 5) buffer stridesBlock { + uint strides[]; +}; + +uint next_power_of_two(uint v) { + v--; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v++; + return v; +} + +uint get_strided_index(uint idx) { + uint strided_i = 0; + for (uint d = 0; d < dims.length(); d++) { + uint dim_idx = dims.length() - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; +} + +void chunk_sum( + uint chunk_len, + TYPENAME data +) { + TYPENAME buf[1024]; + + // assumes that threads where i >= numel have already exited + uint i = gl_GlobalInvocationID.x; + uint block_i = gl_WorkGroupID.x; + + // Fall back to atomicAdd if chunk_len is small to reduce overhead + if (chunk_len <= 2) { + atomicAdd(outp[i / chunk_len], data); + return; + } + buf[block_i] = data; + + uint chunk_i = i % chunk_len; + uint chunk_start = max(int(block_i - chunk_i), 0); + uint chunk_end = min(uint(block_i + chunk_len - chunk_i), gl_WorkGroupSize.x); + + chunk_i = block_i - chunk_start; + + uint max_chunk_len = min(chunk_end - chunk_start, gl_WorkGroupSize.x); + uint incr = next_power_of_two(max_chunk_len) >> 1; + + barrier(); + + // Uses sequential addressing as discussed in + // https://developer.download.nvidia.com/assets/cuda/files/reduction.pdf + for (; incr > 0; incr >>= 1) { + uint block_i_2 = block_i + incr; + + if (block_i_2 < chunk_end && chunk_i < incr) { + // This is sound because __syncthreads and the conditions above + // ensure that no data races occur + buf[block_i] += buf[block_i_2]; + } + + barrier(); + } + + if (block_i == chunk_start) { + atomicAdd(outp[i / chunk_len], buf[block_i]); + } +} + +void main() { + if (gl_GlobalInvocationID.x >= inp.length()) { + return; + } + + uint inp_idx = get_strided_index(gl_GlobalInvocationID.x); + + chunk_sum(chunk_len, inp[inp_idx] * elems_per_thread); +} diff --git a/dfdx-core/src/tensor_ops/sum_to/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/sum_to/webgpu_kernel.rs index 29247ea7..131e2b75 100644 --- a/dfdx-core/src/tensor_ops/sum_to/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/sum_to/webgpu_kernel.rs @@ -1,6 +1,31 @@ -use crate::prelude::{Dtype, Webgpu}; +use core::any::TypeId; -impl super::SumKernel for Webgpu { +use wgpu::ComputePipelineDescriptor; + +use crate::{ + prelude::{ + webgpu_kernels::{Forward, HasGlslType}, + Dtype, Webgpu, + }, + tensor_ops::reduction_utils::*, +}; + +struct WebgpuSumKernel; + +trait HasWebgpuKernel { + const MOD: &'static str; + const FNS: &'static [&'static str]; +} + +impl HasWebgpuKernel for Webgpu { + const MOD: &'static str = "sum_f32"; + const FNS: &'static [&'static str] = &["sum_to_fwd_f32", "sum_to_bwd_f32"]; +} + +impl super::SumKernel for Webgpu +where + Self: HasWebgpuKernel, +{ fn forward( &self, dst: Dst, @@ -9,7 +34,72 @@ impl super::SumKernel for Webgpu { where Src: crate::prelude::ReduceShapeTo, { - todo!() + if !self.shader_module_loaded(TypeId::of::>()) { + self.load_shader_module::( + TypeId::of::>(), + include_bytes!(concat!(env!("OUT_DIR"), "/sum_to.fwd.float.spv")), + ); + } + + let cs_module = self + .get_shader_module(TypeId::of::>()) + .expect("shader module not loaded"); + let pipeline = self + .dev + .create_compute_pipeline(&ComputePipelineDescriptor { + label: None, + layout: None, + module: &cs_module, + entry_point: "main", + }); + + let (dims, strides) = permute_for_reductions::<_, Ax>(inp.shape.concrete(), inp.strides); + let num_dims = dims.len(); + + let mut info = Vec::with_capacity(num_dims * 2); + info.extend(dims); + info.extend(strides); + let info_buffer = self.alloc_empty::(num_dims * 2)?; + info_buffer.copy_to_device(&self.dev, &self.queue, &info); + + let elems_per_thread = E::from_usize(reduction_elems_per_thread::<_, Src>( + inp.shape.concrete(), + inp.strides, + Ax::as_array(), + )) + .unwrap(); + + let physical_numel = inp.data.len::(); + let physical_num_blocks = (physical_numel + 128 - 1) / 128; + let (dst_physical_numel, dst_strides) = + reduction_output_strides::(inp.strides, dst); + let chunk_len = physical_numel / dst_physical_numel; + + let bind_group_layout = pipeline.get_bind_group_layout(0); + let storage = self.alloc_empty::(dst_physical_numel)?; + let mut entries = Vec::new(); + + todo!("add buffers to entries, but we need to get atomic operations working"); + + let binding_group = self.dev.create_bind_group(&wgpu::BindGroupDescriptor { + label: None, + layout: &bind_group_layout, + entries: &entries, + }); + let mut encoder = self + .dev + .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); + { + let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: None, + timestamp_writes: None, + }); + cpass.set_pipeline(&pipeline); + cpass.set_bind_group(0, &binding_group, &[]); + cpass.dispatch_workgroups(physical_num_blocks as u32, 1, 1); + } + self.queue.submit(Some(encoder.finish())); + Ok(self.build_tensor(dst, dst_strides, storage)) } fn backward( @@ -25,3 +115,24 @@ impl super::SumKernel for Webgpu { todo!() } } + +#[cfg(test)] +mod tests { + use crate::prelude::*; + use crate::tensor_ops::*; + use crate::tests::*; + + #[ignore] + #[test] + fn test_sum_1d() { + let dev: Webgpu = Webgpu::default(); + let t = dev.tensor([1.0, 2.0, 3.0]); + let r = t.leaky_trace().sum::(); + let e = 6.0f64; + assert_close_to_literal!(r, e); + // TODO: Add exp back in + // NOTE: .exp() to make sure its using result grad properly + // let g = r.exp().backward(); + // assert_close_to_literal!(g.get(&t), [e.exp(); 3]); + } +} diff --git a/dfdx-core/src/tensor_ops/tanh/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/tanh/webgpu_kernel.rs index aa8742d8..d29f2a07 100644 --- a/dfdx-core/src/tensor_ops/tanh/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/tanh/webgpu_kernel.rs @@ -1,6 +1,6 @@ use super::TanhKernelOp; use crate::prelude::webgpu_kernels::webgpu_unary; -const WGSL: &str = "TODO"; +const WGSL: &[u8] = b"TODO"; -webgpu_unary!(TanhKernelOp, f32, WGSL, "tanh_fwd_f32", "tanh_bwd_f32"); +webgpu_unary!(TanhKernelOp, f32, WGSL, WGSL); diff --git a/dfdx-core/src/tensor_ops/utilities/reduction_utils.rs b/dfdx-core/src/tensor_ops/utilities/reduction_utils.rs index fefb9811..378a6655 100644 --- a/dfdx-core/src/tensor_ops/utilities/reduction_utils.rs +++ b/dfdx-core/src/tensor_ops/utilities/reduction_utils.rs @@ -46,7 +46,7 @@ pub(crate) fn index_for_reductions( /// Moves all axes in Ax to the end of dims and strides and removes broadcasted dimensions /// so that a cuda kernel called for each physical element of the input tensor will place elements /// to be reduced with each other next to each other in memory. -#[cfg(feature = "cuda")] +#[cfg(any(feature = "cuda", feature = "webgpu"))] pub(crate) fn permute_for_reductions(dims: I, strides: I) -> (Vec, Vec) where I: IntoIterator, @@ -74,7 +74,7 @@ where /// Returns the physical number of elements and strides of dst so that broadcasted dimensions in /// src are also broadcasted in dst -#[cfg(feature = "cuda")] +#[cfg(any(feature = "cuda", feature = "webgpu"))] #[inline(always)] pub(crate) fn reduction_output_strides( src_strides: Src::Concrete, @@ -101,7 +101,7 @@ pub(crate) fn reduction_output_strides( } /// Gives the product of all dimensions that are being reduced and are broadcasted. -#[cfg(feature = "cuda")] +#[cfg(any(feature = "cuda", feature = "webgpu"))] #[inline(always)] pub(crate) fn reduction_elems_per_thread, S: Shape>( dims: S::Concrete, diff --git a/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs b/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs index 619afb1c..1579134c 100644 --- a/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs +++ b/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs @@ -4,64 +4,80 @@ use crate::{ tensor_ops::ops::{BinaryKernel, UnaryKernel}, }; use core::any::TypeId; -use std::{borrow::Cow, sync::Arc, vec::Vec}; +use std::{borrow::Cow, marker::PhantomData, sync::Arc, vec::Vec}; pub(crate) trait UnaryOpWebgpuKernel { const DF_USES_FX: bool; const HAS_CONST_DF: bool; - /// Compiled by build.rs - const WGSL_SRC: &'static str; + // /// Unique name for the kernel + // const MODULE_NAME: &'static str; - /// Unique name for the kernel - const MODULE_NAME: &'static str; + /// Glsl source code for the forward pass + const GLSL_FWD_SPV: &'static [u8]; - /// Name of function in the .wgsl file - const FWD_FN_NAME: &'static str; - - /// Name of function in the .wgsl file - const BWD_FN_NAME: &'static str; - - const ALL_FN_NAMES: [&'static str; 2] = [Self::FWD_FN_NAME, Self::BWD_FN_NAME]; + /// Glsl source code for the backward pass + const GLSL_BWD_SPV: &'static [u8]; } macro_rules! webgpu_unary { - ($Op:path, $TypeName:ty, $Wgsl:tt, $Fwd:tt, $Bwd:tt) => { + ($Op:path, $TypeName:ty, $Fwd:tt, $Bwd:tt) => { impl crate::tensor_ops::webgpu_kernels::UnaryOpWebgpuKernel<$TypeName> for $Op { const DF_USES_FX: bool = false; const HAS_CONST_DF: bool = false; - const WGSL_SRC: &'static str = $Wgsl; - const MODULE_NAME: &'static str = stringify!($Op); - const FWD_FN_NAME: &'static str = $Fwd; - const BWD_FN_NAME: &'static str = $Bwd; + // const MODULE_NAME: &'static str = stringify!($Op); + const GLSL_FWD_SPV: &'static [u8] = $Fwd; + const GLSL_BWD_SPV: &'static [u8] = $Bwd; } }; - (df(f(x)) $Op:path, $TypeName:ty, $Wgsl:tt, $Fwd:tt, $Bwd:tt) => { + (df(f(x)) $Op:path, $TypeName:ty, $Fwd:tt, $Bwd:tt) => { impl crate::tensor_ops::webgpu_kernels::UnaryOpWebgpuKernel<$TypeName> for $Op { const DF_USES_FX: bool = true; const HAS_CONST_DF: bool = false; - const WGSL_SRC: &'static str = $Wgsl; - const MODULE_NAME: &'static str = $Fwd; - const FWD_FN_NAME: &'static str = $Fwd; - const BWD_FN_NAME: &'static str = $Bwd; + // const MODULE_NAME: &'static str = $Fwd; + const GLSL_FWD_SPV: &'static [u8] = $Fwd; + const GLSL_BWD_SPV: &'static [u8] = $Bwd; } }; - (const_df() $Op:path, $TypeName:ty, $Wgsl:tt, $Fwd:tt, $Bwd:tt) => { + (const_df() $Op:path, $TypeName:ty, $Fwd:tt, $Bwd:tt) => { impl crate::tensor_ops::webgpu_kernels::UnaryOpWebgpuKernel<$TypeName> for $Op { const DF_USES_FX: bool = false; const HAS_CONST_DF: bool = true; - const WGSL_SRC: &'static str = $Wgsl; - const MODULE_NAME: &'static str = $Fwd; - const FWD_FN_NAME: &'static str = $Fwd; - const BWD_FN_NAME: &'static str = $Bwd; + // const MODULE_NAME: &'static str = $Fwd; + const GLSL_FWD_SPV: &'static [u8] = $Fwd; + const GLSL_BWD_SPV: &'static [u8] = $Bwd; } }; } +/// Zero-sized marker type for forward pass TypeId +#[derive(Debug, Default)] +pub(crate) struct Forward { + _phantom: PhantomData<(E, K)>, +} + +/// Zero-sized marker type for backward pass TypeId +#[derive(Debug, Default)] +pub(crate) struct Backward { + _phantom: PhantomData<(E, K)>, +} + +pub(crate) trait HasGlslType { + const TYPE: &'static str; +} + +impl HasGlslType for f32 { + const TYPE: &'static str = "float"; +} + +impl HasGlslType for f64 { + const TYPE: &'static str = "double"; +} + pub(crate) use webgpu_unary; use wgpu::ComputePipelineDescriptor; -impl + 'static> UnaryKernel for Webgpu { +impl + 'static> UnaryKernel for Webgpu { const BACKWARD_WITHOUT_INP: bool = K::DF_USES_FX; const BACKWARD_WITHOUT_DATA: bool = K::HAS_CONST_DF; @@ -70,27 +86,28 @@ impl + 'static> UnaryKernel for Webgpu op: K, inp: Cow>, ) -> Result, Error> { - if !self.shader_module_loaded(TypeId::of::()) { - self.load_shader_module(TypeId::of::(), K::WGSL_SRC); + if !self.shader_module_loaded(TypeId::of::>()) { + self.load_shader_module::(TypeId::of::>(), K::GLSL_FWD_SPV); } let cs_module = self - .get_shader_module(TypeId::of::()) - .expect("shader module not loaded"); + .get_shader_module(TypeId::of::>()) + .ok_or(Error::WebgpuSourceLoadError)?; let pipeline = self .dev .create_compute_pipeline(&ComputePipelineDescriptor { label: None, layout: None, module: &cs_module, - entry_point: K::FWD_FN_NAME, + entry_point: "main", }); let bind_group_layout = pipeline.get_bind_group_layout(0); let op_storage = self.alloc_init::(&[op])?; let numel = inp.data.len::(); + let num_blocks = (numel + 128 - 1) / 128; let storage = self.alloc_empty::(numel)?; let empty = self.alloc_empty::(0)?; - let mut entries = vec![]; + let mut entries = Vec::new(); // WGSL doesn't support empty structs, so don't bind the empty buffer if std::mem::size_of::() > 0 { entries.push(wgpu::BindGroupEntry { @@ -124,7 +141,7 @@ impl + 'static> UnaryKernel for Webgpu }); cpass.set_pipeline(&pipeline); cpass.set_bind_group(0, &binding_group, &[]); - cpass.dispatch_workgroups(numel as u32, 1, 1); + cpass.dispatch_workgroups(num_blocks as u32, 1, 1); } self.queue.submit(Some(encoder.finish())); Ok(self.build_tensor(inp.shape, inp.strides, storage)) @@ -155,7 +172,7 @@ impl + 'static> UnaryKernel for Webgpu }); cpass.set_pipeline(&pipeline); cpass.set_bind_group(0, &binding_group, &[]); - cpass.dispatch_workgroups(numel as u32, 1, 1); + cpass.dispatch_workgroups(num_blocks as u32, 1, 1); } self.queue.submit(Some(encoder.finish())); Ok(inp) @@ -171,28 +188,27 @@ impl + 'static> UnaryKernel for Webgpu out: &impl Tensorlike, grad_out: &Self::Vec, ) -> Result<(), Error> { - if !self.shader_module_loaded(TypeId::of::()) { - self.load_shader_module(TypeId::of::(), K::WGSL_SRC); + if !self.shader_module_loaded(TypeId::of::>()) { + self.load_shader_module::(TypeId::of::>(), K::GLSL_BWD_SPV); } let cs_module = self - .get_shader_module(TypeId::of::()) - .expect("shader module not loaded"); + .get_shader_module(TypeId::of::>()) + .ok_or(Error::WebgpuSourceLoadError)?; let pipeline = self .dev .create_compute_pipeline(&ComputePipelineDescriptor { label: None, layout: None, module: &cs_module, - entry_point: K::BWD_FN_NAME, + entry_point: "main", }); let bind_group_layout = pipeline.get_bind_group_layout(0); let op_storage = self.alloc_init::(&[op])?; let numel = inp.len(); - let storage = self.alloc_empty::(numel)?; let empty_inp = self.alloc_empty::(0)?; let empty_out = self.alloc_empty::(0)?; - let mut entries = vec![]; + let mut entries = Vec::new(); // WGSL doesn't support empty structs, so don't bind the empty buffer if std::mem::size_of::() > 0 { entries.push(wgpu::BindGroupEntry {