Managed to get built SPIR-V working as long as we go through the
non-passthrough route.

Can't get sum_to working until wgpu supports atomic operations, which is
super unfortunate.

Maybe I'll work on that soon...
favilo committed Dec 24, 2023
1 parent 0118dee commit c1b440b
Showing 34 changed files with 487 additions and 209 deletions.
12 changes: 10 additions & 2 deletions dfdx-core/Cargo.toml
@@ -38,7 +38,8 @@ half = { version = "2.3.1", optional = true, features = ["num-traits", "rand_dis
gemm = { version = "0.16.14", default-features = false, optional = true, features = ["rayon"] }
rayon = { version = "1.7.0", optional = true }
libm = { workspace = true }
wgpu = { version = "0.18.0", optional = true }
wgpu = { version = "0.18.0", features = ["glsl", "spirv"], optional = true }
naga = { version = "0.14.1", optional = true }
futures-lite = { version = "2.0.1", optional = true }
thingbuf = { version = "0.1.4", optional = true }

@@ -62,7 +63,14 @@ fast-alloc = ["std"]

cuda = ["dep:cudarc", "dep:glob"]
cudnn = ["cuda", "cudarc?/cudnn"]
webgpu = ["dep:wgpu", "dep:futures-lite", "dep:thingbuf", "wgpu/expose-ids"]
webgpu = [
"dep:wgpu",
"dep:futures-lite",
"dep:thingbuf",
"dep:naga",
"dep:glob",
"wgpu/expose-ids",
]

f16 = ["dep:half", "cudarc?/f16", "gemm?/f16"]

52 changes: 52 additions & 0 deletions dfdx-core/build.rs
@@ -9,6 +9,9 @@ fn main() {

#[cfg(feature = "cuda")]
cuda::build_ptx();

#[cfg(feature = "webgpu")]
webgpu::build_spv();
}

fn maybe_enable_nightly() {
@@ -210,3 +213,52 @@ mod cuda {
}
}
}

#[cfg(feature = "webgpu")]
mod webgpu {
pub fn build_spv() {
let out_dir = std::env::var("OUT_DIR").unwrap();
let kernel_paths: Vec<std::path::PathBuf> = glob::glob("src/**/*.glsl")
.unwrap()
.map(|p| p.unwrap())
.collect();
for path in &kernel_paths {
println!("cargo:rerun-if-changed={}", path.display());
}

let children = kernel_paths
.iter()
.map(|p| {
// TODO: we need to build this for both float and double
let out_path: std::path::PathBuf = out_dir.clone().into();
let base = p.file_stem().unwrap();
let new_name = format!("{}.float.spv", base.to_str().unwrap());
let out_file = out_path.join(new_name);
eprintln!("out_file: {:?}", out_file);
std::process::Command::new("glslc")
.args(["-std=460core"])
.args(["-fshader-stage=compute"])
.args(["-DTYPENAME=float"])
.args(["-o", &out_file.as_os_str().to_str().unwrap()])
.arg(p)
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.spawn()
.expect("glslc failed to start. Ensure that you have shaderc installed and that `glslc` is in your PATH.")
})
.collect::<Vec<_>>();
for (kernel_path, child) in kernel_paths.iter().zip(children.into_iter()) {
let output = child.wait_with_output().expect("glslc failed to run. Ensure that you have shaderc installed and that `glslc` is in your PATH.");
assert!(
output.status.success(),
"glslc error while compiling {kernel_path:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
);
}
}
}
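
The TODO in the closure above notes that each kernel eventually needs both a float and a double build. A minimal sketch of that extension, assuming the GLSL sources use TYPENAME throughout (as the abs kernels below do) and that the double variant needs nothing beyond a different define and output suffix (compilation serialized here for brevity, unlike the parallel spawn above):

    // Hypothetical per-type loop replacing the hard-coded "float" build in
    // the `map` closure above; `base`, `out_path`, and `p` are the bindings
    // already defined there.
    for ty in ["float", "double"] {
        let new_name = format!("{}.{}.spv", base.to_str().unwrap(), ty);
        let out_file = out_path.join(new_name);
        let output = std::process::Command::new("glslc")
            .args(["-std=460core", "-fshader-stage=compute"])
            .arg(format!("-DTYPENAME={}", ty))
            .args(["-o", out_file.to_str().unwrap()])
            .arg(p)
            .output()
            .expect("glslc failed to start");
        assert!(
            output.status.success(),
            "glslc error: {}",
            String::from_utf8_lossy(&output.stderr)
        );
    }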
3 changes: 3 additions & 0 deletions dfdx-core/src/tensor/error.rs
@@ -22,6 +22,9 @@ pub enum Error {

#[cfg(feature = "webgpu")]
WebgpuRequestDeviceError(wgpu::RequestDeviceError),

#[cfg(feature = "webgpu")]
WebgpuSourceLoadError,
}

impl std::fmt::Display for Error {
4 changes: 2 additions & 2 deletions dfdx-core/src/tensor/webgpu/allocate.rs
@@ -22,7 +22,7 @@ impl Webgpu {
shape: S,
buf: Vec<E>,
) -> Result<Tensor<S, E, Self>, Error> {
let buffer = unsafe { self.alloc_empty::<E>(buf.len()) }?;
let buffer = self.alloc_empty::<E>(buf.len())?;
buffer.copy_to_device::<E>(&self.dev, &self.queue, &buf);

Ok(self.build_tensor(shape, shape.strides(), buffer))
@@ -56,7 +56,7 @@ impl<E: Unit + SafeZeros> ZerosTensor<E> for Webgpu {
fn try_zeros_like<S: HasShape>(&self, src: &S) -> Result<Tensor<S::Shape, E, Self>, Error> {
let shape = *src.shape();
let strides = shape.strides();
let data = unsafe { self.alloc_empty::<E>(shape.num_elements()) }?;
let data = self.alloc_empty::<E>(shape.num_elements())?;
data.copy_to_device(&self.dev, &self.queue, &vec![0u8; data.size()]);

Ok(self.build_tensor(shape, strides, data))
31 changes: 25 additions & 6 deletions dfdx-core/src/tensor/webgpu/device.rs
@@ -1,10 +1,12 @@
use wgpu::{
util::{BufferInitDescriptor, DeviceExt},
Adapter, BufferDescriptor, BufferUsages, Device, Instance, InstanceDescriptor, Maintain, Queue,
RequestDeviceError, ShaderModule, ShaderModuleDescriptor,
util::{make_spirv, make_spirv_raw, BufferInitDescriptor, DeviceExt},
Adapter, BufferDescriptor, BufferUsages, Device, DeviceDescriptor, Features, Instance,
InstanceDescriptor, Maintain, Queue, RequestDeviceError, ShaderModule, ShaderModuleDescriptor,
ShaderModuleDescriptorSpirV,
};

use crate::{
prelude::webgpu_kernels::HasGlslType,
shapes::{Shape, Unit},
tensor::{
cache::TensorCache, cpu::Cpu, Cache, Error, NoneTape, RandomU64, Storage, Synchronize,
@@ -141,8 +143,13 @@ impl Webgpu {
let adapter = futures_lite::future::block_on(instance.request_adapter(&Default::default()))
.ok_or(Error::WebgpuAdapterNotFound)?;
let adapter = Arc::new(adapter);
let descriptor = DeviceDescriptor {
label: None,
features: Features::default() | Features::SPIRV_SHADER_PASSTHROUGH,
limits: Default::default(),
};
let (dev, queue) =
futures_lite::future::block_on(adapter.request_device(&Default::default(), None))?;
futures_lite::future::block_on(adapter.request_device(&descriptor, None))?;
let dev = Arc::new(dev);
let queue = Arc::new(queue);

@@ -214,10 +221,22 @@ impl Webgpu {
self.cs_cache.read().contains_key(&name)
}

pub(crate) fn load_shader_module(&self, name: TypeId, source: &str) {
pub(crate) fn load_shader_module<E>(&self, name: TypeId, source: &[u8])
where
E: HasGlslType,
{
// TODO: Get raw SpirV working. I am guessing that is how we are going
// to have to implement atomic stuff with `wgpu`.
//
// let module = Arc::new(unsafe {
// self.dev.create_shader_module_spirv(&ShaderModuleDescriptorSpirV {
// label: None,
// source: make_spirv_raw(source),
// })
// });
let module = Arc::new(self.dev.create_shader_module(ShaderModuleDescriptor {
label: None,
source: wgpu::ShaderSource::Wgsl(source.into()),
source: make_spirv(source),
}));
#[cfg(not(feature = "no-std"))]
self.cs_cache.write().unwrap().insert(name, module);
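
For reference, the two routes the TODO contrasts look roughly like this against the wgpu 0.18 API (a sketch only; SPV_FWD stands in for any embedded kernel bytes):

    // Non-passthrough (the route used above): make_spirv wraps the bytes in
    // ShaderSource::SpirV and naga re-parses them, so anything naga cannot
    // translate -- e.g. the atomics sum_to needs -- is rejected up front.
    let module = dev.create_shader_module(ShaderModuleDescriptor {
        label: None,
        source: make_spirv(SPV_FWD),
    });

    // Passthrough (the commented-out route): the driver consumes the SPIR-V
    // directly. It requires Features::SPIRV_SHADER_PASSTHROUGH, requested in
    // the DeviceDescriptor above, and is `unsafe` because validation is skipped.
    let module = unsafe {
        dev.create_shader_module_spirv(&ShaderModuleDescriptorSpirV {
            label: None,
            source: make_spirv_raw(SPV_FWD),
        })
    };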
28 changes: 28 additions & 0 deletions dfdx-core/src/tensor_ops/abs/abs.bwd.glsl
@@ -0,0 +1,28 @@
#version 460 core

#extension GL_ARB_compute_shader: enable
#extension GL_ARB_shader_storage_buffer_object: enable

layout(local_size_x = 128) in;

layout(std430, binding = 1) buffer inpBlock {
TYPENAME inp[];
};

layout(std430, binding = 2) buffer outpBlock {
TYPENAME outp[];
};

layout(std430, binding = 3) buffer input_gradBlock {
TYPENAME input_grad[];
};

layout(std430, binding = 4) buffer output_gradBlock {
TYPENAME output_grad[];
};

void main() {
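// d|x|/dx = sign(x); scale the incoming gradient by it.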
TYPENAME dx = sign(inp[gl_GlobalInvocationID.x]);

input_grad[gl_GlobalInvocationID.x] = dx * output_grad[gl_GlobalInvocationID.x];
}
22 changes: 22 additions & 0 deletions dfdx-core/src/tensor_ops/abs/abs.fwd.glsl
@@ -0,0 +1,22 @@
#version 460 core

#extension GL_ARB_compute_shader: enable
#extension GL_ARB_shader_storage_buffer_object: enable

layout(local_size_x = 128) in;

layout(std430, binding = 1) buffer inpBlock {
TYPENAME inp[];
};

layout(std430, binding = 2) buffer outpBlock{
TYPENAME outp[];
};

void main() {
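// Assumption: an empty inp buffer signals an in-place op, so outp already
// holds the input values (presumably mirroring dfdx's CUDA kernel convention).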
if (inp.length() == 0) {
outp[gl_GlobalInvocationID.x] = abs(outp[gl_GlobalInvocationID.x]);
} else {
outp[gl_GlobalInvocationID.x] = abs(inp[gl_GlobalInvocationID.x]);
}
}
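
Both kernels fix local_size_x = 128, so for n elements the host must dispatch ceil(n / 128) workgroups (dispatch_workgroups is the wgpu 0.18 compute-pass call); a sketch of that arithmetic:

    // One thread per element, 128 threads per workgroup.
    let numel: u32 = 5; // e.g. the 5-element tensor in the abs tests below
    let workgroups = (numel + 127) / 128; // ceil(numel / 128)
    // compute_pass.dispatch_workgroups(workgroups, 1, 1);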
48 changes: 0 additions & 48 deletions dfdx-core/src/tensor_ops/abs/abs.wgsl

This file was deleted.

12 changes: 0 additions & 12 deletions dfdx-core/src/tensor_ops/abs/mod.rs
@@ -57,16 +57,4 @@ mod tests {
let g = r.mean().backward();
assert_close_to_literal!(g.get(&x), [-0.2, -0.2, 0.0, 0.2, 0.2]);
}

#[cfg(feature = "webgpu")]
#[test]
fn test_webgpu_abs() {
let dev: Webgpu = Default::default();
let x = dev.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]);
let r = x.leaky_trace().abs();
assert_close_to_literal!(r, [2.0, 1.0, 0.0, 1.0, 2.0]);
// TODO: Add mean back in
// let g = r.mean().backward();
// assert_close_to_literal!(g.get(&x), [-0.2, -0.2, 0.0, 0.2, 0.2]);
}
}
23 changes: 21 additions & 2 deletions dfdx-core/src/tensor_ops/abs/webgpu_kernel.rs
@@ -1,6 +1,25 @@
use super::AbsKernelOp;
use crate::tensor_ops::webgpu_kernels::webgpu_unary;

const WGSL: &str = include_str!("abs.wgsl");
const GLSL_FWD: &str = include_str!("abs.fwd.glsl");
const GLSL_BWD: &str = include_str!("abs.bwd.glsl");
const SPV_FWD: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/abs.fwd.float.spv"));
const SPV_BWD: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/abs.bwd.float.spv"));

webgpu_unary!(AbsKernelOp, f32, WGSL, "abs_fwd_f32", "abs_bwd_f32");
webgpu_unary!(AbsKernelOp, f32, SPV_FWD, SPV_BWD);
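
With SPIR-V modules, the per-op entry-point names are gone from the macro call: glslc compiles each GLSL compute shader with a single main entry point, so webgpu_unary! now takes just the op, the dtype, and the forward/backward SPIR-V bytes.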

#[cfg(test)]
mod tests {
use crate::{tensor::*, tensor_ops::*, tests::*};

#[test]
fn test_webgpu_abs() {
let dev: Webgpu = Default::default();
let x = dev.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]);
let r = x.leaky_trace().abs();
assert_close_to_literal!(r, [2.0, 1.0, 0.0, 1.0, 2.0]);
// TODO: Add mean back in
// let g = r.mean().backward();
// assert_close_to_literal!(g.get(&x), [-0.2, -0.2, 0.0, 0.2, 0.2]);
}
}
10 changes: 2 additions & 8 deletions dfdx-core/src/tensor_ops/accurate_gelu/webgpu_kernel.rs
@@ -1,11 +1,5 @@
use crate::prelude::webgpu_kernels::webgpu_unary;

const WGSL: &str = "TODO";
const WGSL: &[u8] = b"TODO";

webgpu_unary!(
super::AccurateGeLUKernelOp,
f32,
WGSL,
"gelu_fwd_f32",
"gelu_bwd_f32"
);
webgpu_unary!(super::AccurateGeLUKernelOp, f32, WGSL, WGSL);
4 changes: 2 additions & 2 deletions dfdx-core/src/tensor_ops/add/webgpu_kernel.rs
@@ -7,9 +7,9 @@ use crate::prelude::{
Dtype, Webgpu,
};

const WGSL: &str = "TODO";
const WGSL: &[u8] = b"TODO";

webgpu_unary!(Scalar<f32>, f32, WGSL, "scalar_fwd_f32", "scalar_bwd_f32");
webgpu_unary!(Scalar<f32>, f32, WGSL, WGSL);

impl<E: Dtype> BinaryKernel<super::BinaryAddKernelOp, E> for Webgpu {
const BACKWARD_WITHOUT_DATA: bool = true;
10 changes: 2 additions & 8 deletions dfdx-core/src/tensor_ops/clamp/webgpu_kernel.rs
@@ -1,11 +1,5 @@
use crate::prelude::webgpu_kernels::webgpu_unary;

const WGSL: &str = "TODO";
const WGSL: &[u8] = b"TODO";

webgpu_unary!(
super::ClampKernelOp<f32>,
f32,
WGSL,
"clamp_fwd_f32",
"clamp_bwd_f32"
);
webgpu_unary!(super::ClampKernelOp<f32>, f32, WGSL, WGSL);
4 changes: 2 additions & 2 deletions dfdx-core/src/tensor_ops/cos/webgpu_kernel.rs
@@ -1,5 +1,5 @@
use crate::prelude::webgpu_kernels::webgpu_unary;

const WGSL: &str = "TODO";
const WGSL: &[u8] = b"TODO";

webgpu_unary!(super::CosKernelOp, f32, WGSL, "cos_fwd_f32", "cos_bwd_f32");
webgpu_unary!(super::CosKernelOp, f32, WGSL, WGSL);
4 changes: 2 additions & 2 deletions dfdx-core/src/tensor_ops/div/webgpu_kernel.rs
@@ -3,9 +3,9 @@ use std::borrow::Cow;

use crate::prelude::{ops::BinaryKernel, webgpu_kernels::webgpu_unary, Dtype, Webgpu};

const WGSL: &str = "TODO";
const WGSL: &[u8] = b"TODO";

webgpu_unary!(const_df() Scalar<f32>, f32, WGSL, "scalar_sub_fwd", "scalar_sub_bwd");
webgpu_unary!(const_df() Scalar<f32>, f32, WGSL, WGSL);

impl<E: Dtype> BinaryKernel<super::BinaryDivKernelOp, E> for Webgpu {
const BACKWARD_WITHOUT_DATA: bool = true;
4 changes: 2 additions & 2 deletions dfdx-core/src/tensor_ops/exp/webgpu_kernel.rs
@@ -1,5 +1,5 @@
use crate::prelude::webgpu_kernels::webgpu_unary;

const WGSL: &str = "TODO";
const WGSL: &[u8] = b"TODO";

webgpu_unary!(super::ExpKernelOp, f32, WGSL, "exp_fwd_f32", "exp_bwd_f32");
webgpu_unary!(super::ExpKernelOp, f32, WGSL, WGSL);