Skip to content

Commit

Permalink
First draft CUDA runtime (#1685)
Browse files Browse the repository at this point in the history
Initial cuda runtime crate with a WIP compiler.
  • Loading branch information
nathanielsimard authored Apr 30, 2024
1 parent ab50143 commit 587b8f8
Show file tree
Hide file tree
Showing 49 changed files with 2,905 additions and 103 deletions.
198 changes: 140 additions & 58 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion backend-comparison/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ version.workspace = true
# we depend on wgpu and autotune by default because we use the burn-wgpu crate to get system information
default = ["burn/std", "burn/autodiff", "burn/wgpu", "burn/autotune"]
candle-cpu = ["burn/candle"]
candle-cuda = ["burn/candle", "burn/cuda"]
candle-cuda = ["burn/candle-cuda"]
candle-metal = ["burn/candle", "burn/metal"]
candle-accelerate = ["burn/candle", "burn/accelerate"]
ndarray = ["burn/ndarray"]
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-compute/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ pub struct Handle<Server: ComputeServer> {
}

/// Binding of a [tensor handle](Handle) to execute a kernel.
#[derive(new)]
#[derive(new, Debug)]
pub struct Binding<Server: ComputeServer> {
/// Memory binding.
pub memory: <Server::MemoryManagement as MemoryManagement<Server::Storage>>::Binding,
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ autodiff = ["burn-autodiff"]
fusion = ["burn-wgpu?/fusion"]

## Backend features
cuda = ["burn-candle?/cuda"]
metal = ["burn-candle?/metal"]
accelerate = ["burn-candle?/accelerate", "burn-ndarray?/blas-accelerate"]
openblas = ["burn-ndarray?/blas-openblas"]
Expand All @@ -84,6 +83,7 @@ template = ["burn-wgpu?/template"]
ndarray = ["burn-ndarray"]
tch = ["burn-tch"]
candle = ["burn-candle"]
candle-cuda = ["candle", "burn-candle/cuda"]
wgpu = ["burn-wgpu"]

# Custom deserializer for Record that is helpful for importing data, such as PyTorch pt files.
Expand Down
40 changes: 40 additions & 0 deletions crates/burn-cuda/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
[package]
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
categories = ["science"]
description = "CUDA backend for the Burn framework"
edition.workspace = true
keywords = ["deep-learning", "machine-learning", "gpu", "cuda"]
license.workspace = true
name = "burn-cuda"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/burn-cuda"
version.workspace = true

[features]
default = ["fusion", "burn-jit/default"]
fusion = ["burn-fusion", "burn-jit/fusion"]
autotune = ["burn-jit/autotune"]
doc = ["burn-jit/doc"]
std = ["burn-jit/std"]

[dependencies]
burn-jit = { path = "../burn-jit", version = "0.14.0", default-features = false }
burn-compute = { path = "../burn-compute", version = "0.14.0" }
burn-tensor = { path = "../burn-tensor", version = "0.14.0" }
burn-common = { path = "../burn-common", version = "0.14.0" }
burn-fusion = { path = "../burn-fusion", version = "0.14.0", optional = true }
half = { workspace = true }

bytemuck = { workspace = true }
cudarc = "0.10.0"

log = { workspace = true }
derive-new = { workspace = true }

[dev-dependencies]
burn-jit = { path = "../burn-jit", version = "0.14.0", default-features = false, features = [
"export_tests",
] }

[package.metadata.docs.rs]
features = ["doc"]
5 changes: 5 additions & 0 deletions crates/burn-cuda/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Burn-Cuda

This backend is still a work in progress and not ready to be used.

See #1525
Loading

0 comments on commit 587b8f8

Please sign in to comment.