From 5c532ec5dc51cd17cd4bb9ae940ecf2c9baf89f6 Mon Sep 17 00:00:00 2001 From: rainiwu Date: Fri, 26 Jan 2024 00:29:35 -0800 Subject: [PATCH 01/20] remove deprecated ftz intrinsics --- dfdx-core/src/lib.rs | 38 -------------------------------------- dfdx/examples/12-mnist.rs | 3 --- 2 files changed, 41 deletions(-) diff --git a/dfdx-core/src/lib.rs b/dfdx-core/src/lib.rs index 31e61643..c126db2c 100644 --- a/dfdx-core/src/lib.rs +++ b/dfdx-core/src/lib.rs @@ -128,44 +128,6 @@ pub mod prelude { pub use crate::tensor_ops::*; } -/// Sets a CPU `sse` flag to flush denormal floating point numbers to zero. The opposite of this is [keep_denormals()]. -/// -/// Some resources: -/// 1. [Effects of Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/the-effects-of-using-flush-to-zero-mode?lang=en) -/// 2. [When to use Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/when-to-use-flush-to-zero-mode?lang=en) -pub fn flush_denormals_to_zero() { - #[cfg(all(target_arch = "x86", target_feature = "sse"))] - { - use std::arch::x86::{_MM_FLUSH_ZERO_ON, _MM_SET_FLUSH_ZERO_MODE}; - unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON) } - } - - #[cfg(all(target_arch = "x86_64", target_feature = "sse"))] - { - use std::arch::x86_64::{_MM_FLUSH_ZERO_ON, _MM_SET_FLUSH_ZERO_MODE}; - unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON) } - } -} - -/// Sets a CPU flag to keep denormal floating point numbers. The opposite of this is [flush_denormals_to_zero()]. -/// -/// Some resources: -/// 1. [Effects of Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/the-effects-of-using-flush-to-zero-mode?lang=en) -/// 2. [When to use Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/when-to-use-flush-to-zero-mode?lang=en) -pub fn keep_denormals() { - #[cfg(all(target_arch = "x86", target_feature = "sse"))] - { - use std::arch::x86::{_MM_FLUSH_ZERO_OFF, _MM_SET_FLUSH_ZERO_MODE}; - unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_OFF) } - } - - #[cfg(all(target_arch = "x86_64", target_feature = "sse"))] - { - use std::arch::x86_64::{_MM_FLUSH_ZERO_OFF, _MM_SET_FLUSH_ZERO_MODE}; - unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_OFF) } - } -} - #[cfg(test)] pub(crate) mod tests { pub use num_traits::{Float, NumCast, Zero}; diff --git a/dfdx/examples/12-mnist.rs b/dfdx/examples/12-mnist.rs index 705d14c8..00d43452 100644 --- a/dfdx/examples/12-mnist.rs +++ b/dfdx/examples/12-mnist.rs @@ -62,9 +62,6 @@ type Mlp = ( const BATCH_SIZE: usize = 32; fn main() { - // ftz substantially improves performance - dfdx::flush_denormals_to_zero(); - let mnist_path = std::env::args() .nth(1) .unwrap_or_else(|| "./datasets/MNIST/raw".to_string()); From fb91f13314fb24a67c2d8e14ad40345d2d334805 Mon Sep 17 00:00:00 2001 From: rainiwu Date: Fri, 26 Jan 2024 00:55:48 -0800 Subject: [PATCH 02/20] suppress spurious cargo clippy warning --- dfdx-core/src/data/collate.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/dfdx-core/src/data/collate.rs b/dfdx-core/src/data/collate.rs index d38a2a67..5f52d636 100644 --- a/dfdx-core/src/data/collate.rs +++ b/dfdx-core/src/data/collate.rs @@ -55,6 +55,7 @@ impl Collate for Vec<(A, B)> { impl<'a, A, B> Collate for Vec<&'a (A, B)> { type Collated = (Vec<&'a A>, Vec<&'a B>); fn collated(self) -> Self::Collated { + #[allow(clippy::map_identity)] self.into_iter().map(|(a, b)| (a, b)).unzip() } } From a832f51acf89ceafa161c87c6de8bb7e6cb56296 Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Wed, 31 Jan 2024 01:01:29 -0500 Subject: [PATCH 03/20] Update safetensors module and naming - Makes the safetensors module private. - Doesn't get exported on the preamble, avoiding a naming clash with the safetensors external crate. - Change how and when the period is inserted. - This should make it closer to how the fields are accessed in the code. --- dfdx-core/src/nn_traits/tuples.rs | 4 ++-- dfdx-core/src/nn_traits/vecs.rs | 4 ++-- dfdx-core/src/tensor/mod.rs | 2 +- dfdx-derives/src/lib.rs | 20 ++++++++++++++++---- 4 files changed, 21 insertions(+), 9 deletions(-) diff --git a/dfdx-core/src/nn_traits/tuples.rs b/dfdx-core/src/nn_traits/tuples.rs index 97e8c7de..205c0419 100644 --- a/dfdx-core/src/nn_traits/tuples.rs +++ b/dfdx-core/src/nn_traits/tuples.rs @@ -25,7 +25,7 @@ macro_rules! tuple_impls { location: &str, tensors: &mut Vec<(String, safetensors::Dtype, Vec, Vec)>, ) { - $(self.$idx.write_safetensors(&format!("{location}{}.", $idx), tensors);)+ + $(self.$idx.write_safetensors(&format!("{location}.{}", $idx), tensors);)+ } } @@ -36,7 +36,7 @@ macro_rules! tuple_impls { location: &str, tensors: &safetensors::SafeTensors, ) -> Result<(), safetensors::SafeTensorError> { - $(self.$idx.read_safetensors(&format!("{location}{}.", $idx), tensors)?;)+ + $(self.$idx.read_safetensors(&format!("{location}.{}", $idx), tensors)?;)+ Ok(()) } } diff --git a/dfdx-core/src/nn_traits/vecs.rs b/dfdx-core/src/nn_traits/vecs.rs index 803a07d8..593b1a55 100644 --- a/dfdx-core/src/nn_traits/vecs.rs +++ b/dfdx-core/src/nn_traits/vecs.rs @@ -66,7 +66,7 @@ impl crate::nn_traits::SaveSafeTensors for tensors: &mut Vec<(String, safetensors::Dtype, Vec, Vec)>, ) { for (i, t) in self.iter().enumerate() { - t.write_safetensors(&format!("{location}{i}."), tensors); + t.write_safetensors(&format!("{location}.{i}"), tensors); } } } @@ -79,7 +79,7 @@ impl crate::nn_traits::LoadSafeTensors for tensors: &safetensors::SafeTensors, ) -> Result<(), safetensors::SafeTensorError> { for (i, t) in self.iter_mut().enumerate() { - t.read_safetensors(&format!("{location}{i}."), tensors)?; + t.read_safetensors(&format!("{location}.{i}"), tensors)?; } Ok(()) } diff --git a/dfdx-core/src/tensor/mod.rs b/dfdx-core/src/tensor/mod.rs index acc4074a..0163480a 100644 --- a/dfdx-core/src/tensor/mod.rs +++ b/dfdx-core/src/tensor/mod.rs @@ -151,7 +151,7 @@ pub(crate) mod webgpu; pub use numpy::NumpyDtype; mod error; #[cfg(feature = "safetensors")] -pub mod safetensors; +mod safetensors; mod tensorlike; mod unique_id; diff --git a/dfdx-derives/src/lib.rs b/dfdx-derives/src/lib.rs index 4eca0d82..7af885f9 100644 --- a/dfdx-derives/src/lib.rs +++ b/dfdx-derives/src/lib.rs @@ -850,7 +850,10 @@ pub fn save_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre where_clause .predicates .push(parse_quote!(#ty: ::dfdx::nn_traits::SaveSafeTensors)); - quote_spanned!(f.span()=>self.#name.write_safetensors(&format!("{location}{}", #name_str), tensors);) + quote_spanned!(f.span()=>self.#name.write_safetensors( + &format!("{location}{}{}", if location.is_empty() { "" } else { "." }, #name_str), + tensors + );) } else { Default::default() } @@ -866,7 +869,10 @@ pub fn save_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre where_clause .predicates .push(parse_quote!(#ty: ::dfdx::nn_traits::SaveSafeTensors)); - quote_spanned!(f.span()=>self.#index.write_safetensors(&format!("{location}{}", #index), tensors);) + quote_spanned!(f.span()=>self.#index.write_safetensors( + &format!("{location}{}{}", if location.is_empty() { "" } else { "." }, #index), + tensors + );) } else { Default::default() } @@ -913,7 +919,10 @@ pub fn load_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre where_clause .predicates .push(parse_quote!(#ty: ::dfdx::nn_traits::LoadSafeTensors)); - quote_spanned!(f.span()=>self.#name.read_safetensors(&format!("{location}{}", #name_str), tensors)?;) + quote_spanned!(f.span()=>self.#name.read_safetensors( + &format!("{location}{}{}", if location.is_empty() { "" } else { "." }, #name_str), + tensors + )?;) } else { Default::default() } @@ -928,7 +937,10 @@ pub fn load_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre where_clause .predicates .push(parse_quote!(#ty: ::dfdx::nn_traits::LoadSafeTensors)); - quote_spanned!(f.span()=>self.#index.read_safetensors(&format!("{location}{}", #index), tensors)?;) + quote_spanned!(f.span()=>self.#index.read_safetensors( + &format!("{location}{}{}", if location.is_empty() { "" } else { "." }, #index), + tensors + )?;) } else { Default::default() } From 901cfe43d634bbc9688f8040b57c2acc3e42c8f7 Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Wed, 31 Jan 2024 01:05:32 -0500 Subject: [PATCH 04/20] impl core::ops::Sub for Dim types --- dfdx-core/src/shapes/shape.rs | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/dfdx-core/src/shapes/shape.rs b/dfdx-core/src/shapes/shape.rs index 184337cd..c3e27121 100644 --- a/dfdx-core/src/shapes/shape.rs +++ b/dfdx-core/src/shapes/shape.rs @@ -69,6 +69,30 @@ where } } +impl core::ops::Sub> for usize { + type Output = usize; + fn sub(self, _: Const) -> Self::Output { + self.size() - N + } +} +impl core::ops::Sub for Const { + type Output = usize; + fn sub(self, rhs: usize) -> Self::Output { + N - rhs.size() + } +} + +#[cfg(feature = "nightly")] +impl core::ops::Sub> for Const +where + Const<{ M - N }>: Sized, +{ + type Output = Const<{ M - N }>; + fn sub(self, _: Const) -> Self::Output { + Const + } +} + impl core::ops::Mul> for usize { type Output = usize; fn mul(self, _: Const) -> Self::Output { From a14b40b6e3a214c716c83a5fbc244c208a767def Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Wed, 31 Jan 2024 01:10:46 -0500 Subject: [PATCH 05/20] add SiLU activation function --- dfdx-core/src/tensor_ops/mod.rs | 2 + dfdx-core/src/tensor_ops/silu/cpu_kernel.rs | 20 ++++++ dfdx-core/src/tensor_ops/silu/cuda_kernel.rs | 15 +++++ dfdx-core/src/tensor_ops/silu/mod.rs | 62 +++++++++++++++++++ dfdx-core/src/tensor_ops/silu/silu.cu | 32 ++++++++++ .../src/tensor_ops/silu/webgpu_kernel.rs | 28 +++++++++ dfdx-core/src/tensor_ops/utilities/device.rs | 1 + 7 files changed, 160 insertions(+) create mode 100644 dfdx-core/src/tensor_ops/silu/cpu_kernel.rs create mode 100644 dfdx-core/src/tensor_ops/silu/cuda_kernel.rs create mode 100644 dfdx-core/src/tensor_ops/silu/mod.rs create mode 100644 dfdx-core/src/tensor_ops/silu/silu.cu create mode 100644 dfdx-core/src/tensor_ops/silu/webgpu_kernel.rs diff --git a/dfdx-core/src/tensor_ops/mod.rs b/dfdx-core/src/tensor_ops/mod.rs index 453457f4..1cb0f38c 100644 --- a/dfdx-core/src/tensor_ops/mod.rs +++ b/dfdx-core/src/tensor_ops/mod.rs @@ -197,6 +197,7 @@ mod roll; mod select_and_gather; mod sgd; mod sigmoid; +mod silu; mod sin; mod slice; mod softmax; @@ -264,6 +265,7 @@ pub use roll::Roll; pub use select_and_gather::{GatherTo, SelectTo}; pub use sgd::SgdConfig; pub use sigmoid::sigmoid; +pub use silu::silu; pub use sin::sin; pub use slice::slice; pub use softmax::softmax; diff --git a/dfdx-core/src/tensor_ops/silu/cpu_kernel.rs b/dfdx-core/src/tensor_ops/silu/cpu_kernel.rs new file mode 100644 index 00000000..f6f05752 --- /dev/null +++ b/dfdx-core/src/tensor_ops/silu/cpu_kernel.rs @@ -0,0 +1,20 @@ +use crate::tensor_ops::cpu_kernels::UnaryDerivative; + +impl UnaryDerivative for super::SiLUKernelOp { + const DF_USES_FX: bool = false; + const HAS_CONST_DF: bool = false; + + // x / (1 + e^-x) + #[inline(always)] + fn f(&self, x: &F) -> F { + *x / (F::one() + x.neg().exp()) + } + + // (1 + e^-x + x * e^-x) / (1 + e^-x)^2 + // alternative: (e^x (1 + e^x + x)) / (1 + e^x)^2 + #[inline(always)] + fn df(&self, x: &F) -> F { + let exp_nx = x.neg().exp(); + F::one() + exp_nx + *x * exp_nx / (F::one() + exp_nx).powi(2) + } +} diff --git a/dfdx-core/src/tensor_ops/silu/cuda_kernel.rs b/dfdx-core/src/tensor_ops/silu/cuda_kernel.rs new file mode 100644 index 00000000..45bf1385 --- /dev/null +++ b/dfdx-core/src/tensor_ops/silu/cuda_kernel.rs @@ -0,0 +1,15 @@ +use super::SiLUKernelOp; +#[allow(unused_imports)] +use crate::dtypes::*; +use crate::tensor_ops::cuda_kernels::cuda_unary; + +unsafe impl cudarc::driver::DeviceRepr for SiLUKernelOp {} + +const PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/silu.ptx")); + +#[cfg(feature = "f16")] +cuda_unary!(SiLUKernelOp, f16, PTX, "silu_fwd_f16", "silu_bwd_f16"); +#[cfg(feature = "f16")] +cuda_unary!(SiLUKernelOp, AMP, PTX, "silu_fwd_f16", "silu_bwd_f16"); +cuda_unary!(SiLUKernelOp, f32, PTX, "silu_fwd_f32", "silu_bwd_f32"); +cuda_unary!(SiLUKernelOp, f64, PTX, "silu_fwd_f64", "silu_bwd_f64"); diff --git a/dfdx-core/src/tensor_ops/silu/mod.rs b/dfdx-core/src/tensor_ops/silu/mod.rs new file mode 100644 index 00000000..97bcce10 --- /dev/null +++ b/dfdx-core/src/tensor_ops/silu/mod.rs @@ -0,0 +1,62 @@ +mod cpu_kernel; + +#[cfg(feature = "cuda")] +mod cuda_kernel; + +#[cfg(feature = "webgpu")] +mod webgpu_kernel; + +use super::ops::{try_unary_op, UnaryKernel}; +use crate::{shapes::*, tensor::*}; + +#[repr(C)] +#[derive(Debug, Default, Copy, Clone)] +pub struct SiLUKernelOp; + +/// [Sigmoid-Weighted Linear Unit (SiLU)](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)). `x * x.sigmoid()` +/// +/// The derivative is `x * sigmoid'(x) + sigmoid(x)`. +/// +/// Examples: +/// ```rust +/// # use dfdx_core::prelude::*; +/// # let dev: Cpu = Default::default(); +/// let t = dev.tensor([-1.0, 0.0, 1.0, 2.0]); +/// let r = t.silu(); +/// ``` +pub fn silu, T: Tape>( + t: Tensor, +) -> Tensor { + t.silu() +} + +impl, T: Tape> Tensor { + /// See [silu] + pub fn silu(self) -> Self { + self.try_silu().unwrap() + } + /// See [silu] + pub fn try_silu(self) -> Result { + try_unary_op(SiLUKernelOp, self) + } +} + +#[cfg(test)] +mod tests { + use crate::{tensor::*, tensor_ops::*, tests::*}; + + #[test] + fn test_silu() { + let dev: TestDevice = Default::default(); + let x = dev + .tensor([-2.0, -1.0, 0.0, 1.0, 2.0]) + .to_dtype::(); + let r = x.leaky_trace().silu(); + assert_close_to_literal!(r, [-0.23840584, -0.26894143, 0.0, 0.7310586, 1.761594]); + let g = r.mean().backward(); + assert_close_to_literal!( + g.get(&x), + [1.635814, 0.70433396, 0.4, 0.31289828, 0.26906452] + ); + } +} diff --git a/dfdx-core/src/tensor_ops/silu/silu.cu b/dfdx-core/src/tensor_ops/silu/silu.cu new file mode 100644 index 00000000..d3b01a7e --- /dev/null +++ b/dfdx-core/src/tensor_ops/silu/silu.cu @@ -0,0 +1,32 @@ +#include "unary_op_macros.cuh" + +struct SiLUKernelOp {}; + +// x / (1 + e^-x) +template +__device__ __forceinline__ T silu_fwd(T x) { + T one = 1.0; + return x / (one + expg(-x)); +} + +// (1 + e^-x + x * e^-x) / (1 + e^-x)^2 +// alternative: (e^x (1 + e^x + x)) / (1 + e^x)^2 +template +__device__ __forceinline__ T silu_bwd(T x) { + T one = 1.0; + T exp_nx = expg(-x); + T denom_sqrt = (one + exp_nx); + return (one + exp_nx + x * exp_nx) / (denom_sqrt * denom_sqrt); +} + +UNARY_OP(__half, silu_fwd_f16, silu_bwd_f16, SiLUKernelOp, + silu_fwd(x), + silu_bwd(x)) + +UNARY_OP(float, silu_fwd_f32, silu_bwd_f32, SiLUKernelOp, + silu_fwd(x), + silu_bwd(x)) + +UNARY_OP(double, silu_fwd_f64, silu_bwd_f64, SiLUKernelOp, + silu_fwd(x), + silu_bwd(x)) diff --git a/dfdx-core/src/tensor_ops/silu/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/silu/webgpu_kernel.rs new file mode 100644 index 00000000..438850e1 --- /dev/null +++ b/dfdx-core/src/tensor_ops/silu/webgpu_kernel.rs @@ -0,0 +1,28 @@ +use std::borrow::Cow; + +use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu}; + +impl UnaryKernel for Webgpu { + const BACKWARD_WITHOUT_INP: bool = false; + + const BACKWARD_WITHOUT_DATA: bool = false; + + fn forward( + &self, + op: super::SiLUKernelOp, + inp: Cow>, + ) -> Result, crate::prelude::Error> { + todo!() + } + + fn backward( + &self, + op: super::SiLUKernelOp, + inp: &impl crate::prelude::Tensorlike, + grad_inp: &mut Self::Vec, + out: &impl crate::prelude::Tensorlike, + grad_out: &Self::Vec, + ) -> Result<(), crate::prelude::Error> { + todo!() + } +} diff --git a/dfdx-core/src/tensor_ops/utilities/device.rs b/dfdx-core/src/tensor_ops/utilities/device.rs index 8cbc2137..6e9b6ec4 100644 --- a/dfdx-core/src/tensor_ops/utilities/device.rs +++ b/dfdx-core/src/tensor_ops/utilities/device.rs @@ -92,6 +92,7 @@ pub trait Device: + UnaryKernel + UnaryKernel + UnaryKernel + + UnaryKernel + UnaryKernel + UnaryKernel + UnaryKernel From b52932ccd6942965a9de18cf581b96b243812a92 Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Wed, 31 Jan 2024 01:54:19 -0500 Subject: [PATCH 06/20] add RMS normalization - Add the try_normalize_rms related functions. - Add the `LayerRMSNorm1D` module. --- dfdx-core/src/tensor_ops/mod.rs | 2 + dfdx-core/src/tensor_ops/normalize_rms.rs | 136 +++++++++++++++++ dfdx/src/nn/layers/layer_rms_norm1d.rs | 169 ++++++++++++++++++++++ dfdx/src/nn/layers/mod.rs | 2 + 4 files changed, 309 insertions(+) create mode 100644 dfdx-core/src/tensor_ops/normalize_rms.rs create mode 100644 dfdx/src/nn/layers/layer_rms_norm1d.rs diff --git a/dfdx-core/src/tensor_ops/mod.rs b/dfdx-core/src/tensor_ops/mod.rs index 453457f4..a649196c 100644 --- a/dfdx-core/src/tensor_ops/mod.rs +++ b/dfdx-core/src/tensor_ops/mod.rs @@ -184,6 +184,7 @@ mod mul; mod nans_to; mod negate; mod normalize; +mod normalize_rms; pub(super) mod optim; mod permute_to; mod pow; @@ -251,6 +252,7 @@ pub use mul::{mul, TryMul}; pub use nans_to::nans_to; pub use negate::negate; pub use normalize::normalize; +pub use normalize_rms::normalize_rms; pub use optim::*; pub use permute_to::PermuteTo; pub use pow::{powf, powi}; diff --git a/dfdx-core/src/tensor_ops/normalize_rms.rs b/dfdx-core/src/tensor_ops/normalize_rms.rs new file mode 100644 index 00000000..eb70302a --- /dev/null +++ b/dfdx-core/src/tensor_ops/normalize_rms.rs @@ -0,0 +1,136 @@ +use crate::{ + shapes::{Axes, Dtype, ReduceShape, Shape}, + tensor::{Error, Tape, Tensor}, +}; + +use super::{BroadcastTo, Device, MeanTo, TryAdd, TryMul}; + +/// Normalizes `t` to have stddev `1.0` along `Ax`. `epsilon` is used during stddev. +/// Computes `t / (t.square().mean() + epsilon).sqrt()`. +/// +/// Normalizing a single axis: +/// ```rust +/// # use dfdx_core::prelude::*; +/// # let dev: Cpu = Default::default(); +/// let t: Tensor, f32, _> = dev.zeros(); +/// let _ = t.normalize_rms::>(1e-5); +/// ``` +pub fn normalize_rms< + Ax: Axes, + S: Shape + ReduceShape, + E: Dtype, + D: Device, + T: Tape, +>( + t: Tensor, + epsilon: impl Into, +) -> Tensor { + t.normalize_rms::(epsilon) +} + +impl, T: Tape> Tensor { + /// See [normalize_rms] + pub fn normalize_rms(self, epsilon: impl Into) -> Self + where + S: ReduceShape, + { + self.try_normalize_rms::(epsilon).unwrap() + } + + /// See [normalize_rms] + pub fn try_normalize_rms(self, epsilon: impl Into) -> Result + where + S: ReduceShape, + { + let shape = self.shape; + let sq = self.retaped::().try_square()?; + let sq_mean = sq.try_mean::<_, Ax>()?; + let rsqrt = sq_mean + .try_add(epsilon)? + .try_sqrt()? + .try_recip()? + .try_broadcast_like(&shape)?; + self.try_mul(rsqrt) + } +} + +#[cfg(test)] +mod tests { + use crate::tests::*; + use crate::{shapes::*, tensor::*, tensor_ops::*}; + + #[test] + fn test_1d_normalize_rms_axis_last() { + let dev: TestDevice = Default::default(); + let a = dev.tensor([-2.0, 0.0, 5.0]).to_dtype::(); + let r = a.leaky_trace().normalize_rms(1e-5); + assert_close_to_literal!(&r, [-0.64326715, 0.0, 1.6081679]); + // NOTE: .exp() so we can make sure normalize is using result grad properly + let g = r.exp().mean().backward(); + assert_close_to_literal!(&g.get(&a), [0.23318729, 0.107211195, 0.09327549]); + } + + #[test] + fn test_2d_normalize_rms_axis_last() { + let dev: TestDevice = Default::default(); + let a = dev + .tensor([[-2.0, 0.0, 5.0], [1.0, 2.0, 3.0]]) + .to_dtype::(); + let r = a.leaky_trace().normalize_rms::>(1e-5); + assert_close_to_literal!( + r, + [ + [-0.64326715, 0.0, 1.6081679], + [0.46290955, 0.9258191, 1.3887286] + ] + ); + let g = r.exp().mean().backward(); + assert_close_to_literal!( + g.get(&a), + [ + [0.116593644, 0.053605597, 0.046637744], + [0.019706108, -0.011002079, 0.0007670224] + ] + ); + } + + #[test] + fn test_2d_normalize_rms_axis_first() { + let dev: TestDevice = Default::default(); + let a = dev + .tensor([[-2.0, 0.0], [1.0, 2.0], [4.0, 5.0]]) + .to_dtype::(); + let r = a.leaky_trace().normalize_rms::>(1e-5); + assert_close_to_literal!( + r, + [ + [-0.7559284, 0.0], + [0.3779642, 0.64326715], + [1.5118568, 1.6081679] + ] + ); + let g = r.exp().mean().backward(); + assert_close_to_literal!( + g.get(&a), + [ + [0.14153406, 0.053605597], + [0.03595103, -0.0043795705], + [0.061779693, 0.0017521679] + ] + ); + } + + #[test] + fn test_3d_normalize_rms_axis_last() { + let dev: TestDevice = Default::default(); + let a: Tensor, TestDtype, _> = dev.ones(); + let r = a.leaky_trace().normalize_rms::>(1e-5); + assert_close_to_literal!(r, [[[1.0; 3]; 2]; 4], 1e-5); + let g = r.exp().mean().backward(); + assert_close_to_literal!(g.get(&a), [[[0.0; 3]; 2]; 4], 1e-5); + } +} + +// Implementation references: +// - https://github.com/johnma2006/mamba-minimal/blob/03de542a36d873f6e6c4057ad687278cc6ae944d/model.py#L328 +// - https://github.com/kroggen/mamba.c/blob/7387f49e352f86a0c22041c0f66fd2a40b58a207/mamba.c#L222 diff --git a/dfdx/src/nn/layers/layer_rms_norm1d.rs b/dfdx/src/nn/layers/layer_rms_norm1d.rs new file mode 100644 index 00000000..17143aef --- /dev/null +++ b/dfdx/src/nn/layers/layer_rms_norm1d.rs @@ -0,0 +1,169 @@ +use crate::prelude::*; + +/// Implements RMS layer normalization as described in [Root Mean Square Layer Normalization](https://arxiv.org/abs/1910.07467). +/// +/// This calls [normalize_rms()] on the last axis of the input to normalize to unit std dev, and then does an element-wise +/// affine transform using learnable parameters. +/// +/// Epsilon is passed to [normalize_rms()] and added to the variance to ensure big enough numbers. It defaults to `1e-5`. +/// +/// Generics: +/// - `M` The size of the affine transform tensors. +/// +/// # Examples +/// ```rust +/// # use dfdx::prelude::*; +/// # use dfdx::*; +/// # let dev: Cpu = Default::default(); +/// type Model = LayerRMSNorm1DConstConfig<5>; +/// let model = dev.build_module::(Model::default()); +/// let _: Tensor, f32, _> = model.forward(dev.zeros::>()); +/// ``` +#[derive(Default, Clone, Copy, Debug)] +#[repr(transparent)] +pub struct LayerRMSNorm1DConfig(pub M); + +/// Compile time sugar alias around [LayerRMSNorm1DConfig] +pub type LayerRMSNorm1DConstConfig = LayerRMSNorm1DConfig>; + +impl> BuildOnDevice for LayerRMSNorm1DConfig { + type Built = LayerRMSNorm1D; + fn try_build_on_device(&self, device: &D) -> Result { + Ok(LayerRMSNorm1D { + gamma: device.try_ones_like(&(self.0,))?, + beta: device.try_zeros_like(&(self.0,))?, + epsilon: 1e-5, + }) + } +} + +/// See [LayerRMSNorm1DConfig] +#[derive(Clone, Debug, UpdateParams, ZeroGrads, WithGrads)] +#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] +pub struct LayerRMSNorm1D> { + #[param] + #[cfg_attr(feature = "safetensors", serialize)] + pub gamma: Tensor<(M,), Elem, Dev>, + #[param] + #[cfg_attr(feature = "safetensors", serialize)] + pub beta: Tensor<(M,), Elem, Dev>, + #[cfg_attr(feature = "safetensors", serialize)] + pub epsilon: f64, +} + +impl> ResetParams for LayerRMSNorm1D { + fn try_reset_params(&mut self) -> Result<(), crate::tensor::Error> { + self.gamma.try_fill_with_ones()?; + self.beta.try_fill_with_zeros()?; + Ok(()) + } +} + +impl, T: Tape> Module> + for LayerRMSNorm1D +{ + type Output = Tensor<(M,), E, D, T>; + fn try_forward(&self, x: Tensor<(M,), E, D, T>) -> Result { + let x = x.try_normalize_rms::>(self.epsilon)?; + let x = self.gamma.retaped::().try_mul(x)?; + self.beta.retaped::().try_add(x) + } +} + +impl, T: Tape> Module> + for LayerRMSNorm1D +{ + type Output = Tensor<(Batch, M), E, D, T>; + fn try_forward(&self, x: Tensor<(Batch, M), E, D, T>) -> Result { + let x = x.try_normalize_rms::>(self.epsilon)?; + let x = self.gamma.retaped::().broadcast_like(&x).try_mul(x)?; + self.beta.retaped::().broadcast_like(&x).try_add(x) + } +} + +impl, T: Tape> + Module> for LayerRMSNorm1D +{ + type Output = Tensor<(Batch, Seq, M), E, D, T>; + fn try_forward(&self, x: Tensor<(Batch, Seq, M), E, D, T>) -> Result { + let x = x.try_normalize_rms::>(self.epsilon)?; + let x = self.gamma.retaped::().broadcast_like(&x).try_mul(x)?; + self.beta.retaped::().broadcast_like(&x).try_add(x) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tests::*; + + #[test] + fn test_layer_rms_norm_reset() { + let dev: TestDevice = Default::default(); + + let mut m = dev.build_module::(>::default()); + assert_close_to_literal!(m.gamma, [1.0; 5]); + assert_close_to_literal!(m.beta, [0.0; 5]); + + m.gamma = dev.sample_normal(); + m.beta = dev.sample_normal(); + + assert_ne!(m.gamma.array(), [TestDtype::ONE; 5]); + assert_ne!(m.beta.array(), [TestDtype::default(); 5]); + + m.reset_params(); + + assert_close_to_literal!(m.gamma, [1.0; 5]); + assert_close_to_literal!(m.beta, [0.0; 5]); + } + + #[test] + fn test_layer_rms_norm_1d_forward() { + let dev: TestDevice = Default::default(); + let mut m = dev.build_module::(>::default()); + let x = dev.sample_normal::>(); + let r = m.forward_mut(x.leaky_trace()); + assert_close_to_literal!( + r, + [0.53631353, 0.6458002, -1.8330059, 0.12289862, -0.9593052] + ); + let g = r.mean().backward(); + assert_close_to_literal!( + g.get(&m.gamma), + [0.10726271, 0.12916003, -0.3666012, 0.024579724, -0.19186105] + ); + assert_close_to_literal!(g.get(&m.beta), [0.2; 5]); + } + + #[test] + fn test_layer_rms_norm_2d_forward() { + let dev: TestDevice = Default::default(); + let m = dev.build_module::(>::default()); + let x = dev.sample_normal::>(); + let r = m.forward(x.leaky_trace()); + assert_close_to_literal!( + r, + [ + [0.53631353, 0.6458002, -1.8330059, 0.12289862, -0.9593052], + [1.0418473, -1.199064, 0.49583954, 0.5000605, 1.4074267], + [0.90727454, -1.6644237, -0.5176145, 1.0127299, -0.33612955] + ] + ); + let g = r.mean().backward(); + assert_close_to_literal!( + g.get(&m.gamma), + [ + 0.16569571, + -0.14784585, + -0.123652056, + 0.10904594, + 0.0074661337 + ] + ); + assert_close_to_literal!(g.get(&m.beta), [0.2; 5]); + } +} + +// Implementation references: +// - https://github.com/johnma2006/mamba-minimal/blob/03de542a36d873f6e6c4057ad687278cc6ae944d/model.py#L328 +// - https://github.com/kroggen/mamba.c/blob/7387f49e352f86a0c22041c0f66fd2a40b58a207/mamba.c#L222 diff --git a/dfdx/src/nn/layers/mod.rs b/dfdx/src/nn/layers/mod.rs index 828b1e97..062b9f08 100644 --- a/dfdx/src/nn/layers/mod.rs +++ b/dfdx/src/nn/layers/mod.rs @@ -20,6 +20,7 @@ mod gelu; mod generalized_add; mod generalized_mul; mod layer_norm1d; +mod layer_rms_norm1d; mod leaky_relu; mod linear; mod ln; @@ -73,6 +74,7 @@ pub use gelu::{AccurateGeLU, FastGeLU}; pub use generalized_add::GeneralizedAdd; pub use generalized_mul::GeneralizedMul; pub use layer_norm1d::{LayerNorm1D, LayerNorm1DConfig, LayerNorm1DConstConfig}; +pub use layer_rms_norm1d::{LayerRMSNorm1D, LayerRMSNorm1DConfig, LayerRMSNorm1DConstConfig}; pub use leaky_relu::LeakyReLU; pub use linear::{Linear, LinearConfig, LinearConstConfig}; pub use ln::Ln; From 693b699e37d22d1f02c70c90c0721a6b8dd5a69d Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Thu, 1 Feb 2024 08:04:45 -0500 Subject: [PATCH 07/20] Add split_tensor_along method - Add `TrySplitShapeAlong` and `TrySplitTensorAlong`. - Minor linting and docs fix. TODO - Check if the tape should be returned. If not, it can be removed from the interface. - Add cuda kernel. - Consider a different interface, where it could get split in more than two tensors - possibly stated on a vec. In this way it could get closer to the pytorch interface (chunks). --- .../concat_tensor_along/cpu_kernel.rs | 8 +- .../src/tensor_ops/concat_tensor_along/mod.rs | 8 +- dfdx-core/src/tensor_ops/mod.rs | 4 + .../src/tensor_ops/split_shape_along/mod.rs | 158 ++++++++++ .../split_tensor_along/cpu_kernel.rs | 99 +++++++ .../split_tensor_along/cuda_kernel.rs | 31 ++ .../src/tensor_ops/split_tensor_along/mod.rs | 275 ++++++++++++++++++ .../split_tensor_along/webgpu_kernel.rs | 26 ++ dfdx-core/src/tensor_ops/utilities/device.rs | 3 + 9 files changed, 604 insertions(+), 8 deletions(-) create mode 100644 dfdx-core/src/tensor_ops/split_shape_along/mod.rs create mode 100644 dfdx-core/src/tensor_ops/split_tensor_along/cpu_kernel.rs create mode 100644 dfdx-core/src/tensor_ops/split_tensor_along/cuda_kernel.rs create mode 100644 dfdx-core/src/tensor_ops/split_tensor_along/mod.rs create mode 100644 dfdx-core/src/tensor_ops/split_tensor_along/webgpu_kernel.rs diff --git a/dfdx-core/src/tensor_ops/concat_tensor_along/cpu_kernel.rs b/dfdx-core/src/tensor_ops/concat_tensor_along/cpu_kernel.rs index e6ab2eb2..25efc27e 100644 --- a/dfdx-core/src/tensor_ops/concat_tensor_along/cpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/concat_tensor_along/cpu_kernel.rs @@ -26,11 +26,11 @@ impl super::ConcatAlongKernel for Cpu { let buf = std::sync::Arc::get_mut(&mut c.data).unwrap(); while i < n { for _ in 0..a_n { - buf[i] = a.data[a_idx.next().unwrap()]; + (*buf)[i] = a.data[a_idx.next().unwrap()]; i += 1; } for _ in 0..b_n { - buf[i] = b.data[b_idx.next().unwrap()]; + (*buf)[i] = b.data[b_idx.next().unwrap()]; i += 1; } } @@ -59,11 +59,11 @@ impl super::ConcatAlongKernel for Cpu { let n = grad_out.len(); while i < n { for _ in 0..a_n { - grad_a[a_idx.next().unwrap()] += grad_out[i]; + (*grad_a)[a_idx.next().unwrap()] += grad_out[i]; i += 1; } for _ in 0..b_n { - grad_b[b_idx.next().unwrap()] += grad_out[i]; + (*grad_b)[b_idx.next().unwrap()] += grad_out[i]; i += 1; } } diff --git a/dfdx-core/src/tensor_ops/concat_tensor_along/mod.rs b/dfdx-core/src/tensor_ops/concat_tensor_along/mod.rs index 7462fd2b..9165efba 100644 --- a/dfdx-core/src/tensor_ops/concat_tensor_along/mod.rs +++ b/dfdx-core/src/tensor_ops/concat_tensor_along/mod.rs @@ -19,7 +19,7 @@ mod webgpu_kernel; /// # let dev: Cpu = Default::default(); /// let a: Tensor, f32, _> = dev.zeros(); /// let b: Tensor, f32, _> = dev.zeros(); -/// let _: Tensor, f32, _> = (a, b).concat_along(Axis::<0>); +/// let _: Tensor, f32, _> = (a, b).concat_tensor_along(Axis::<0>); /// ``` /// /// Along Axis 1: @@ -28,7 +28,7 @@ mod webgpu_kernel; /// # let dev: Cpu = Default::default(); /// let a: Tensor, f32, _> = dev.zeros(); /// let b: Tensor, f32, _> = dev.zeros(); -/// let _: Tensor, f32, _> = (a, b).concat_along(Axis::<1>); +/// let _: Tensor, f32, _> = (a, b).concat_tensor_along(Axis::<1>); /// ``` /// /// # [usize] dims @@ -38,7 +38,7 @@ mod webgpu_kernel; /// # let dev: Cpu = Default::default(); /// let a: Tensor<(usize, Const<3>), f32, _> = dev.zeros_like(&(2, Const)); /// let b: Tensor<(usize, Const<3>), f32, _> = dev.zeros_like(&(4, Const)); -/// let _: Tensor, f32, _> = (a, b).concat_along(Axis::<0>).realize(); +/// let _: Tensor, f32, _> = (a, b).concat_tensor_along(Axis::<0>).realize(); /// ``` /// /// Along Axis 1: @@ -47,7 +47,7 @@ mod webgpu_kernel; /// # let dev: Cpu = Default::default(); /// let a: Tensor<(Const<2>, usize), f32, _> = dev.zeros_like(&(Const, 2)); /// let b: Tensor<(Const<2>, usize), f32, _> = dev.zeros_like(&(Const, 4)); -/// let _: Tensor, f32, _> = (a, b).concat_along(Axis::<1>).realize(); +/// let _: Tensor, f32, _> = (a, b).concat_tensor_along(Axis::<1>).realize(); /// ``` pub trait TryConcatTensorAlong: Sized { type Output; diff --git a/dfdx-core/src/tensor_ops/mod.rs b/dfdx-core/src/tensor_ops/mod.rs index 453457f4..38a03d14 100644 --- a/dfdx-core/src/tensor_ops/mod.rs +++ b/dfdx-core/src/tensor_ops/mod.rs @@ -200,6 +200,8 @@ mod sigmoid; mod sin; mod slice; mod softmax; +mod split_shape_along; +mod split_tensor_along; mod sqrt; mod square; mod stack; @@ -267,6 +269,8 @@ pub use sigmoid::sigmoid; pub use sin::sin; pub use slice::slice; pub use softmax::softmax; +pub use split_shape_along::TrySplitShapeAlong; +pub use split_tensor_along::TrySplitTensorAlong; pub use sqrt::sqrt; pub use square::square; pub use stack::{AddDim, TryStack}; diff --git a/dfdx-core/src/tensor_ops/split_shape_along/mod.rs b/dfdx-core/src/tensor_ops/split_shape_along/mod.rs new file mode 100644 index 00000000..1421e12f --- /dev/null +++ b/dfdx-core/src/tensor_ops/split_shape_along/mod.rs @@ -0,0 +1,158 @@ +use crate::{shapes::*, tensor::*}; + +/// Split a shape in two along a given axis. +/// +/// # [Const] dims **requires nightly** +/// +/// Along Axis 0: +/// ```ignore +/// # use dfdx_core::prelude::*; +/// # let dev: Cpu = Default::default(); +/// let (a, b): (Rank2<3, 3>, Rank2<4, 3>) = (Const::<7>, Const::<3>).split_shape_along(Axis::<0>, Const::<3>, Const::<4>); +/// ``` +/// +/// Along Axis 1: +/// ```ignore +/// # use dfdx_core::prelude::*; +/// # let dev: Cpu = Default::default(); +/// let (a, b): (Rank2<7, 2>, Rank2<7, 1>) = (Const::<7>, Const::<3>).split_shape_along(Axis::<1>, Const::<2>, Const::<1>); +/// ``` +/// +/// # [usize] dims +/// Along Axis 0: +/// ```rust +/// # use dfdx_core::prelude::*; +/// # let dev: Cpu = Default::default(); +/// let (a, b) = (7, Const::<3>).split_shape_along(Axis::<0>, 3, 4); +/// assert_eq!(a, (3, Const::<3>)); +/// assert_eq!(b, (4, Const::<3>)); +/// ``` +/// +/// Along Axis 1: +/// ```rust +/// # use dfdx_core::prelude::*; +/// # let dev: Cpu = Default::default(); +/// let (a, b) = (Const::<7>, 3).split_shape_along(Axis::<1>, 2, 1); +/// assert_eq!(a, (Const::<7>, 2)); +/// assert_eq!(b, (Const::<7>, 1)); +/// ``` +pub trait TrySplitShapeAlong: Shape { + type Output; + + /// Splits self along the given axis. + fn split_shape_along(self, ax: Ax, a: A, b: B) -> Self::Output { + self.try_split_shape_along(ax, a, b).unwrap() + } + /// Fallibly splits self along the given axis. + fn try_split_shape_along(self, ax: Ax, a: A, b: B) -> Result; +} + +macro_rules! impl_split { + ($Ax:expr, $NumDims:expr, [$($Head:tt),*], [$($Tail:tt),*]) => { + impl TrySplitShapeAlong, A, B> + for + ( + $($Head, )* + AB, + $($Tail, )* + ) + where + ($($Head, )* A, $($Tail, )*): Shape::Concrete>, + ($($Head, )* B, $($Tail, )*): Shape::Concrete>, + { + type Output = + ( + ($($Head, )* A, $($Tail, )*), + ($($Head, )* B, $($Tail, )*), + ); + + fn try_split_shape_along(self, _: Axis<$Ax>, a: A, b: B) -> Result { + let dims = self.concrete(); + let mut lhs_dims = dims; + let mut rhs_dims = dims; + lhs_dims[$Ax] = a.size(); + rhs_dims[$Ax] = b.size(); + assert_eq!(dims[$Ax], lhs_dims[$Ax] + rhs_dims[$Ax]); + + Ok(( + <($($Head, )* A, $($Tail, )*)>::from_concrete(&lhs_dims).unwrap(), + <($($Head, )* B, $($Tail, )*)>::from_concrete(&rhs_dims).unwrap(), + )) + } + } + }; +} + +impl_split!(0, 1, [], []); +impl_split!(0, 2, [], [D1]); +impl_split!(0, 3, [], [D1, D2]); +impl_split!(0, 4, [], [D1, D2, D3]); +impl_split!(0, 5, [], [D1, D2, D3, D4]); +impl_split!(0, 6, [], [D1, D2, D3, D4, D5]); + +impl_split!(1, 2, [D0], []); +impl_split!(1, 3, [D0], [D2]); +impl_split!(1, 4, [D0], [D2, D3]); +impl_split!(1, 5, [D0], [D2, D3, D4]); +impl_split!(1, 6, [D0], [D2, D3, D4, D5]); + +impl_split!(2, 3, [D0, D1], []); +impl_split!(2, 4, [D0, D1], [D3]); +impl_split!(2, 5, [D0, D1], [D3, D4]); +impl_split!(2, 6, [D0, D1], [D3, D4, D5]); + +impl_split!(3, 4, [D0, D1, D2], []); +impl_split!(3, 5, [D0, D1, D2], [D4]); +impl_split!(3, 6, [D0, D1, D2], [D4, D5]); + +impl_split!(4, 5, [D0, D1, D2, D3], []); +impl_split!(4, 6, [D0, D1, D2, D3], [D5]); + +impl_split!(5, 6, [D0, D1, D2, D3, D4], []); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_split_shape() { + let a: (usize, Const<5>) = (5, Const); + let b: (usize, Const<5>) = (3, Const); + assert_eq!( + (8, Const::<5>).split_shape_along(Axis::<0>, a.0, b.0), + (a, b) + ); + + let a: (Const<5>, Const<5>) = (Const, Const); + let b: (usize, Const<5>) = (3, Const); + assert_eq!( + (8, Const::<5>).split_shape_along(Axis::<0>, a.0, b.0), + (a, b) + ); + + let a: (usize, Const<5>) = (5, Const); + let b: (Const<3>, Const<5>) = (Const, Const); + assert_eq!( + (8, Const::<5>).split_shape_along(Axis::<0>, a.0, b.0), + (a, b) + ); + + #[cfg(feature = "nightly")] + { + let a: (Const<5>, Const<5>) = (Const, Const); + let b: (Const<3>, Const<5>) = (Const, Const); + assert_eq!( + (Const::<8>, Const::<5>).split_shape_along(Axis::<0>, a.0, b.0), + (a, b) + ); + } + } + + #[test] + #[should_panic = "left: 8\n right: 7"] + fn test_split_shape_fails() { + let a: (usize, Const<5>) = (4, Const); + let b: (usize, Const<5>) = (3, Const); + (8, Const::<5>).split_shape_along(Axis::<0>, a.0, b.0); + } +} diff --git a/dfdx-core/src/tensor_ops/split_tensor_along/cpu_kernel.rs b/dfdx-core/src/tensor_ops/split_tensor_along/cpu_kernel.rs new file mode 100644 index 00000000..3e2fa5e1 --- /dev/null +++ b/dfdx-core/src/tensor_ops/split_tensor_along/cpu_kernel.rs @@ -0,0 +1,99 @@ +use super::AorB; +use crate::{ + shapes::*, + tensor::{cpu::NdIndex, *}, +}; + +impl super::SplitAlongKernel for Cpu { + fn forward( + &self, + ax: usize, + ab: &Tensor, + a: &mut Tensor, + b: &mut Tensor, + ) -> Result<(), Error> { + let mut a_n = 1; + let mut b_n = 1; + { + let a_idx = NdIndex::new(a.shape, a.strides); + let b_idx = NdIndex::new(b.shape, b.strides); + for i in ax..A::NUM_DIMS { + a_n *= a_idx.shape[i]; + b_n *= b_idx.shape[i]; + } + } + + let n_ab = ab.data.len(); + + let buf_a = std::sync::Arc::get_mut(&mut a.data).unwrap(); + let buf_b = std::sync::Arc::get_mut(&mut b.data).unwrap(); + + let mut i = 0; + let mut k = 0; + let mut ab_idx = NdIndex::new(ab.shape, ab.strides); + while i < n_ab { + for j in 0..a_n { + (*buf_a)[j + k * a_n] = ab.data[ab_idx.next().unwrap()]; + i += 1; + } + for j in 0..b_n { + (*buf_b)[j + k * b_n] = ab.data[ab_idx.next().unwrap()]; + i += 1; + } + k += 1; + } + Ok(()) + } + + fn backward( + &self, + ax: usize, + ab: &GhostTensor, + grad_ab: &mut Self::Vec, + a: &GhostTensor, + b: &GhostTensor, + a_or_b: AorB, + grad_out: &Self::Vec, + ) -> Result<(), Error> { + let a_idx = NdIndex::new(a.shape, a.strides); + let b_idx = NdIndex::new(b.shape, b.strides); + + let mut a_n = 1; + let mut b_n = 1; + for i in ax..A::NUM_DIMS { + a_n *= a_idx.shape[i]; + b_n *= b_idx.shape[i]; + } + + let mut i = 0; + let mut j = 0; + let n = grad_ab.len(); + let mut ab_idx = NdIndex::new(ab.shape, ab.strides); + while i + j < n { + match a_or_b { + AorB::A => { + for _ in 0..a_n { + (*grad_ab)[ab_idx.next().unwrap()] = grad_out[i]; + i += 1; + } + for _ in 0..b_n { + ab_idx.next().unwrap(); + j += 1; + } + } + AorB::B => { + for _ in 0..a_n { + ab_idx.next().unwrap(); + j += 1; + } + for _ in 0..b_n { + (*grad_ab)[ab_idx.next().unwrap()] = grad_out[i]; + i += 1; + } + } + }; + } + + Ok(()) + } +} diff --git a/dfdx-core/src/tensor_ops/split_tensor_along/cuda_kernel.rs b/dfdx-core/src/tensor_ops/split_tensor_along/cuda_kernel.rs new file mode 100644 index 00000000..515f0365 --- /dev/null +++ b/dfdx-core/src/tensor_ops/split_tensor_along/cuda_kernel.rs @@ -0,0 +1,31 @@ +use super::AorB; +use crate::{ + shapes::*, + tensor::{Cuda, Error, GhostTensor, Tensor}, +}; +use cudarc::types::CudaTypeName; + +impl super::SplitAlongKernel for Cuda { + fn forward( + &self, + _ax: usize, + _ab: &Tensor, + _a: &mut Tensor, + _b: &mut Tensor, + ) -> Result<(), Error> { + todo!() + } + + fn backward( + &self, + _ax: usize, + _ab: &GhostTensor, + _grad_ab: &mut Self::Vec, + _a: &GhostTensor, + _b: &GhostTensor, + _a_or_b: AorB, + _grad_out: &Self::Vec, + ) -> Result<(), Error> { + todo!() + } +} diff --git a/dfdx-core/src/tensor_ops/split_tensor_along/mod.rs b/dfdx-core/src/tensor_ops/split_tensor_along/mod.rs new file mode 100644 index 00000000..ac619301 --- /dev/null +++ b/dfdx-core/src/tensor_ops/split_tensor_along/mod.rs @@ -0,0 +1,275 @@ +use super::split_shape_along::TrySplitShapeAlong; +use crate::{shapes::*, tensor::*}; + +pub(crate) mod cpu_kernel; +#[cfg(feature = "cuda")] +pub(crate) mod cuda_kernel; +#[cfg(feature = "webgpu")] +mod webgpu_kernel; + +/// Split a tensor in two along a given axis. +/// +/// This is the reverse of [TryConcatTensorAlong::concat_tensor_along]. +/// +/// # [Const] dims **requires nightly** +/// +/// Along Axis 0: +/// ```ignore +/// # use dfdx_core::prelude::*; +/// # let dev: Cpu = Default::default(); +/// let ab: Tensor, f32, _> = dev.zeros(); +/// let (a, b, _tape): ( +/// Tensor, f32, _>, +/// Tensor, f32, _>, +/// _ +/// ) = ab.split_tensor_along(Axis::<0>, Const::<2>, Const::<3>); +/// ``` +/// +/// Along Axis 1: +/// ```ignore +/// # use dfdx_core::prelude::*; +/// # let dev: Cpu = Default::default(); +/// let ab: Tensor, f32, _> = dev.zeros(); +/// let (a, b, _tape): ( +/// Tensor, f32, _>, +/// Tensor, f32, _>, +/// _ +/// ) = ab.split_tensor_along(Axis::<1>, Const::<2>, Const::<3>); +/// ``` +/// +/// # [usize] dims +/// Along Axis 0: +/// ```rust +/// # use dfdx_core::prelude::*; +/// # let dev: Cpu = Default::default(); +/// let ab: Tensor<(usize, Const::<4>), f32, _> = dev.zeros_like(&(5, Const)); +/// let (a, b, _tape): ( +/// Tensor<(usize, Const::<4>), f32, _>, +/// Tensor<(usize, Const::<4>), f32, _>, +/// _ +/// ) = ab.split_tensor_along(Axis::<0>, 2, 3); +/// let a: Tensor, f32, _> = a.realize(); +/// let b: Tensor, f32, _> = b.realize(); +/// ``` +/// +/// Along Axis 1: +/// ```rust +/// # use dfdx_core::prelude::*; +/// # let dev: Cpu = Default::default(); +/// let ab: Tensor<(Const::<4>, usize), f32, _> = dev.zeros_like(&(Const, 5)); +/// let (a, b, _tape): ( +/// Tensor<(Const::<4>, usize), f32, _>, +/// Tensor<(Const::<4>, usize), f32, _>, +/// _ +/// ) = ab.split_tensor_along(Axis::<1>, 2, 3); +/// let a: Tensor, f32, _> = a.realize(); +/// let b: Tensor, f32, _> = b.realize(); +/// ``` +pub trait TrySplitTensorAlong: Sized { + type Output; + + /// Splits self along the given axis. + fn split_tensor_along(self, ax: Ax, a: A, b: B) -> Self::Output { + self.try_split_tensor_along(ax, a, b).unwrap() + } + /// Fallibly splits self along the given axis. + fn try_split_tensor_along(self, ax: Ax, a: A, b: B) -> Result; +} + +#[derive(Debug, Clone)] +pub enum AorB { + A, + B, +} + +pub trait SplitAlongKernel: Storage { + fn forward( + &self, + ax: usize, + ab: &Tensor, + a: &mut Tensor, + b: &mut Tensor, + ) -> Result<(), Error>; + + #[allow(clippy::too_many_arguments)] + fn backward( + &self, + ax: usize, + ab: &GhostTensor, + grad_ab: &mut Self::Vec, + a: &GhostTensor, + b: &GhostTensor, + a_or_b: AorB, + grad_out: &Self::Vec, + ) -> Result<(), Error>; +} + +impl> TrySplitTensorAlong + for Tensor +where + Ax: Axes, + A: Dim, + B: Dim, + AS: Shape, + BS: Shape, + AB: Shape + TrySplitShapeAlong, + D: SplitAlongKernel + ZerosTensor, +{ + type Output = (Tensor, Tensor, T); + + fn try_split_tensor_along(self, ax: Ax, a: A, b: B) -> Result { + let device = self.device.clone(); + let (a_shape, b_shape) = (*self.shape()).try_split_shape_along(ax, a, b)?; + let ax = Ax::as_array()[0] as usize; + + let (ab, tape) = self.split_tape(); + + let mut at: Tensor = device.try_zeros_like(&a_shape)?; + let mut bt: Tensor = device.try_zeros_like(&b_shape)?; + + ab.device.forward(ax, &ab, &mut at, &mut bt)?; + + let mut ta = T::default(); + let mut tb = T::default(); + + let device_b = device.clone(); + + let ab_ghost = ab.ghost(); + let a_ghost = at.ghost(); + let b_ghost = bt.ghost(); + ta.add_backward_op(move |grads| { + grads.try_alloc_for(&ab_ghost)?; + grads.try_alloc_for(&a_ghost)?; + let (ab_grad, a_grad) = grads.mut_and_ref(&ab_ghost, &a_ghost); + device.backward(ax, &ab_ghost, ab_grad, &a_ghost, &b_ghost, AorB::A, a_grad) + }); + + let ab_ghost = ab.ghost(); + let a_ghost = at.ghost(); + let b_ghost = bt.ghost(); + tb.add_backward_op(move |grads| { + grads.try_alloc_for(&ab_ghost)?; + grads.try_alloc_for(&b_ghost)?; + let (ab_grad, b_grad) = grads.mut_and_ref(&ab_ghost, &b_ghost); + device_b.backward(ax, &ab_ghost, ab_grad, &a_ghost, &b_ghost, AorB::B, b_grad) + }); + + let at = at.put_tape(ta); + let bt = bt.put_tape(tb); + Ok((at, bt, tape)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{tensor_ops::*, tests::*}; + + #[test] + fn test_split_ax_0() { + let dev: TestDevice = Default::default(); + let ab: Tensor, TestDtype, _> = dev.sample_normal(); + let ab_dyn = ab + .leaky_trace() + .try_realize::<(usize, Const<3>, Const<4>)>() + .unwrap(); + let (a, b, _tape) = ab_dyn.split_tensor_along(Axis::<0>, 2, 3); + let a = a.try_realize::<(Const<2>, Const<3>, Const<4>)>().unwrap(); + let b = b.try_realize::<(Const<3>, Const<3>, Const<4>)>().unwrap(); + let ab_arr = ab.array(); + let a_arr = a.array(); + let b_arr = b.array(); + println!("{a_arr:?}"); + println!("{b_arr:?}"); + println!("{ab_arr:?}"); + + assert_eq!(ab_arr[0], a_arr[0]); + assert_eq!(ab_arr[1], a_arr[1]); + assert_eq!(ab_arr[2], b_arr[0]); + assert_eq!(ab_arr[3], b_arr[1]); + assert_eq!(ab_arr[4], b_arr[2]); + + let ab_concat = (a, b).concat_tensor_along(Axis::<0>); + assert_eq!(ab.array(), ab_concat.array()); + let concat_grads = ab_concat.exp().sum().backward(); + let ab_grads = ab.leaky_trace().exp().sum().backward(); + + assert_close_to_tensor!(concat_grads.get(&ab), ab_grads.get(&ab)); + } + + #[test] + fn test_split_ax_1() { + let dev: TestDevice = Default::default(); + let ab: Tensor, TestDtype, _> = dev.sample_normal(); + let ab_dyn = ab + .leaky_trace() + .try_realize::<(Const<2>, usize, Const<4>)>() + .unwrap(); + let (a, b, _tape) = ab_dyn.split_tensor_along(Axis::<1>, 2, 3); + let a = a.try_realize::<(Const<2>, Const<2>, Const<4>)>().unwrap(); + let b = b.try_realize::<(Const<2>, Const<3>, Const<4>)>().unwrap(); + let ab_arr = ab.array(); + let a_arr = a.array(); + let b_arr = b.array(); + println!("{a_arr:?}"); + println!("{b_arr:?}"); + println!("{ab_arr:?}"); + + for i in 0..2 { + assert_eq!(ab_arr[i][0], a_arr[i][0]); + assert_eq!(ab_arr[i][1], a_arr[i][1]); + assert_eq!(ab_arr[i][2], b_arr[i][0]); + assert_eq!(ab_arr[i][3], b_arr[i][1]); + assert_eq!(ab_arr[i][4], b_arr[i][2]); + } + + let ab_concat = (a, b).concat_tensor_along(Axis::<1>); + assert_eq!(ab.array(), ab_concat.array()); + let concat_grads = ab_concat.exp().sum().backward(); + let ab_grads = ab.leaky_trace().exp().sum().backward(); + + println!("{:?}", concat_grads.get(&ab).array()); + println!("{:?}", ab_grads.get(&ab).array()); + + assert_close_to_tensor!(concat_grads.get(&ab), ab_grads.get(&ab)); + } + + #[test] + fn test_split_ax_2() { + let dev: TestDevice = Default::default(); + let ab: Tensor, TestDtype, _> = dev.sample_normal(); + let ab_dyn = ab + .leaky_trace() + .try_realize::<(Const<2>, Const<3>, usize)>() + .unwrap(); + let (a, b, _tape) = ab_dyn.split_tensor_along(Axis::<2>, 2, 3); + let a = a.try_realize::<(Const<2>, Const<3>, Const<2>)>().unwrap(); + let b = b.try_realize::<(Const<2>, Const<3>, Const<3>)>().unwrap(); + let ab_arr = ab.array(); + let a_arr = a.array(); + let b_arr = b.array(); + println!("{a_arr:?}"); + println!("{b_arr:?}"); + println!("{ab_arr:?}"); + + for i in 0..2 { + for j in 0..3 { + assert_eq!(ab_arr[i][j][0], a_arr[i][j][0]); + assert_eq!(ab_arr[i][j][1], a_arr[i][j][1]); + assert_eq!(ab_arr[i][j][2], b_arr[i][j][0]); + assert_eq!(ab_arr[i][j][3], b_arr[i][j][1]); + assert_eq!(ab_arr[i][j][4], b_arr[i][j][2]); + } + } + + let ab_concat = (a, b).concat_tensor_along(Axis::<2>); + assert_eq!(ab.array(), ab_concat.array()); + let concat_grads = ab_concat.exp().sum().backward(); + let ab_grads = ab.leaky_trace().exp().sum().backward(); + + println!("{:?}", concat_grads.get(&ab).array()); + println!("{:?}", ab_grads.get(&ab).array()); + + assert_close_to_tensor!(concat_grads.get(&ab), ab_grads.get(&ab)); + } +} diff --git a/dfdx-core/src/tensor_ops/split_tensor_along/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/split_tensor_along/webgpu_kernel.rs new file mode 100644 index 00000000..be1923dd --- /dev/null +++ b/dfdx-core/src/tensor_ops/split_tensor_along/webgpu_kernel.rs @@ -0,0 +1,26 @@ +use crate::{shapes::*, tensor::*}; + +impl super::ConcatAlongKernel for Webgpu { + fn forward( + &self, + _ax: usize, + _ab: &Tensor, + _a: &mut Tensor, + _b: &mut Tensor, + ) -> Result<(), Error> { + todo!() + } + + fn backward( + &self, + _ax: usize, + _ab: &GhostTensor, + _grad_ab: &mut Self::Vec, + _a: &GhostTensor, + _b: &GhostTensor, + _a_or_b: AorB, + _grad_out: &Self::Vec, + ) -> Result<(), Error> { + todo!() + } +} diff --git a/dfdx-core/src/tensor_ops/utilities/device.rs b/dfdx-core/src/tensor_ops/utilities/device.rs index 8cbc2137..0869b6c1 100644 --- a/dfdx-core/src/tensor_ops/utilities/device.rs +++ b/dfdx-core/src/tensor_ops/utilities/device.rs @@ -21,6 +21,9 @@ pub trait Device: + super::super::concat_tensor_along::ConcatAlongKernel + super::super::concat_tensor_along::ConcatAlongKernel + // splits + + super::super::split_tensor_along::SplitAlongKernel + // optimizers + super::super::adam::AdamKernel + super::super::sgd::SgdKernel From de5556737ec560990c4278418f1f31f16d532d21 Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Thu, 1 Feb 2024 11:25:07 -0500 Subject: [PATCH 08/20] rm unrelated derive --- dfdx/src/nn/layers/layer_rms_norm1d.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dfdx/src/nn/layers/layer_rms_norm1d.rs b/dfdx/src/nn/layers/layer_rms_norm1d.rs index 17143aef..a62fffb9 100644 --- a/dfdx/src/nn/layers/layer_rms_norm1d.rs +++ b/dfdx/src/nn/layers/layer_rms_norm1d.rs @@ -38,7 +38,7 @@ impl> BuildOnDevice for LayerRMSNorm1DConfi } /// See [LayerRMSNorm1DConfig] -#[derive(Clone, Debug, UpdateParams, ZeroGrads, WithGrads)] +#[derive(Clone, Debug, UpdateParams, ZeroGrads)] #[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] pub struct LayerRMSNorm1D> { #[param] From ea424c3dc8c19fcb72651a781a6b1251453956cb Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Tue, 6 Feb 2024 17:10:30 -0500 Subject: [PATCH 09/20] Added `TryUnstack` for tensors. - Also added `from_fn` for Arrays. Note: the interface currently requires two passes for construction, one for creating a list of tensors with NoneTape and another for putting tapes into those tensors. --- dfdx-core/src/shapes/shape.rs | 15 + dfdx-core/src/tensor_ops/mod.rs | 2 + .../src/tensor_ops/unstack/cpu_kernel.rs | 63 +++++ .../src/tensor_ops/unstack/cuda_kernel.rs | 27 ++ dfdx-core/src/tensor_ops/unstack/mod.rs | 258 ++++++++++++++++++ .../src/tensor_ops/unstack/webgpu_kernel.rs | 20 ++ dfdx-core/src/tensor_ops/utilities/device.rs | 4 + 7 files changed, 389 insertions(+) create mode 100644 dfdx-core/src/tensor_ops/unstack/cpu_kernel.rs create mode 100644 dfdx-core/src/tensor_ops/unstack/cuda_kernel.rs create mode 100644 dfdx-core/src/tensor_ops/unstack/mod.rs create mode 100644 dfdx-core/src/tensor_ops/unstack/webgpu_kernel.rs diff --git a/dfdx-core/src/shapes/shape.rs b/dfdx-core/src/shapes/shape.rs index 184337cd..0a22d547 100644 --- a/dfdx-core/src/shapes/shape.rs +++ b/dfdx-core/src/shapes/shape.rs @@ -121,18 +121,33 @@ where pub trait Array: IntoIterator { type Dim: Dim; fn dim(&self) -> Self::Dim; + fn from_fn(cb: F, len: Self::Dim) -> Self + where + F: FnMut(usize) -> T; } impl Array for [T; N] { type Dim = Const; fn dim(&self) -> Self::Dim { Const } + fn from_fn(cb: F, _len: Self::Dim) -> Self + where + F: FnMut(usize) -> T, + { + std::array::from_fn(cb) + } } impl Array for std::vec::Vec { type Dim = usize; fn dim(&self) -> Self::Dim { self.len() } + fn from_fn(cb: F, len: Self::Dim) -> Self + where + F: FnMut(usize) -> T, + { + (0..len).map(cb).collect() + } } /// A collection of dimensions ([Dim]) that change how a multi-dimensional diff --git a/dfdx-core/src/tensor_ops/mod.rs b/dfdx-core/src/tensor_ops/mod.rs index 453457f4..eaf86eaa 100644 --- a/dfdx-core/src/tensor_ops/mod.rs +++ b/dfdx-core/src/tensor_ops/mod.rs @@ -209,6 +209,7 @@ mod sum_to; mod tanh; mod to_dtype; mod tri; +mod unstack; mod upscale2d; mod var_to; @@ -276,6 +277,7 @@ pub use sum_to::SumTo; pub use tanh::tanh; pub use to_dtype::{to_dtype, ToDtypeKernel}; pub use tri::{lower_tri, upper_tri}; +pub use unstack::{SubDim, TryUnstack}; pub use upscale2d::{ Bilinear, GenericUpscale2D, NearestNeighbor, TryUpscale2D, Upscale2DKernel, UpscaleMethod, }; diff --git a/dfdx-core/src/tensor_ops/unstack/cpu_kernel.rs b/dfdx-core/src/tensor_ops/unstack/cpu_kernel.rs new file mode 100644 index 00000000..75d8ed0e --- /dev/null +++ b/dfdx-core/src/tensor_ops/unstack/cpu_kernel.rs @@ -0,0 +1,63 @@ +use crate::{ + prelude::NoneTape, + shapes::*, + tensor::{unique_id, Cpu, Error, Tensor}, +}; + +// note: in order to return NoneTape items and not require a tape type information T, +// each element must be optional. +impl super::UnstackKernel for Cpu { + fn forward( + &self, + stack: Tensor, + ) -> Result + where + S: super::SubDim, + OptionalItems: Array>, Dim = S::Head>, + { + let (head, tail) = stack.shape().sub_dim(); + let stack_data = stack.data.as_slice(); + let unstack_num_elements = tail.num_elements(); + Ok(OptionalItems::from_fn( + |i| { + let mut data = self + .try_alloc_elem(unstack_num_elements, E::default()) + // TODO: remove unwrap (needs try_from_fn) + // https://github.com/rust-lang/rust/issues/89379 + .unwrap(); + + data.copy_from_slice( + &stack_data[i * unstack_num_elements..(i + 1) * unstack_num_elements], + ); + + Some(Tensor { + id: unique_id(), + data: std::sync::Arc::new(data), + shape: *tail.shape(), + strides: tail.strides(), + device: self.clone(), + tape: NoneTape, + }) + }, + head, + )) + } + fn backward( + &self, + grad_stack: &mut Self::Vec, + grad_unstack: &Self::Vec, + unstack_idx: usize, + ) -> Result<(), Error> { + let unstack_num_elements = grad_unstack.len(); + for (i, stacked) in grad_stack + .iter_mut() + .skip(unstack_idx * unstack_num_elements) + .take(unstack_num_elements) + .enumerate() + { + *stacked += grad_unstack[i]; + } + + Ok(()) + } +} diff --git a/dfdx-core/src/tensor_ops/unstack/cuda_kernel.rs b/dfdx-core/src/tensor_ops/unstack/cuda_kernel.rs new file mode 100644 index 00000000..bd9e7039 --- /dev/null +++ b/dfdx-core/src/tensor_ops/unstack/cuda_kernel.rs @@ -0,0 +1,27 @@ +use crate::{ + prelude::NoneTape, + shapes::*, + tensor::{Cuda, Error, Tensor}, +}; +use cudarc::types::CudaTypeName; + +impl super::UnstackKernel for Cuda { + fn forward( + &self, + _stack: Tensor, + ) -> Result + where + S: super::SubDim, + OptionalItems: Array>, Dim = S::Head>, + { + todo!() + } + fn backward( + &self, + _grad_stack: &mut Self::Vec, + _grad_unstack: &Self::Vec, + _unstack_idx: usize, + ) -> Result<(), Error> { + todo!() + } +} diff --git a/dfdx-core/src/tensor_ops/unstack/mod.rs b/dfdx-core/src/tensor_ops/unstack/mod.rs new file mode 100644 index 00000000..0155e555 --- /dev/null +++ b/dfdx-core/src/tensor_ops/unstack/mod.rs @@ -0,0 +1,258 @@ +use crate::{shapes::*, tensor::*}; + +mod cpu_kernel; +#[cfg(feature = "cuda")] +mod cuda_kernel; + +#[cfg(feature = "webgpu")] +mod webgpu_kernel; + +/// Unstack a tensor along the first dimension into an array or vec of tensors. +/// +/// This is the opposite of [crate::prelude::TryStack]. +/// +/// A [Const] dim will be turned into an array of tensors, and +/// a [usize] dim will be turned into a `Vec` of tensors. +/// +/// **Pytorch equivalent** `torch.unbind` with `dim=0`. +/// +/// Unstacking to an array: +/// ```rust +/// # use dfdx_core::prelude::*; +/// # let dev: Cpu = Default::default(); +/// let stack: Tensor, f32, _> = dev.zeros(); +/// let [a, b]: [Tensor, f32, _>; 2] = stack.unstack(); +/// ``` +/// +/// Unstacking to a vec: +/// ```rust +/// # use dfdx_core::prelude::*; +/// # let dev: Cpu = Default::default(); +/// let stack: Tensor<(usize, Const::<3>, Const::<4>>, f32, _> = dev.zeros_like(&(2, Const, Const)); +/// let unstack: Vec, f32, _>> = stack.unstack(); +/// ``` +pub trait TryUnstack: Sized { + type Unstacked; + + /// Unstack a tensor along the first dimension into an array or vec of tensors. + fn unstack(self) -> Self::Unstacked { + self.try_unstack().unwrap() + } + /// Fallible version of [TryUnstack::unstack] + fn try_unstack(self) -> Result; +} + +impl, T, const N: usize> TryUnstack> + for Tensor +where + S: SubDim>, + T: Tape, +{ + type Unstacked = ([Tensor; N], T); + fn try_unstack(self) -> Result { + try_unstack::<[Option>; N], _, S, E, D, T>(self) + } +} + +impl, T> TryUnstack for Tensor +where + S: SubDim, + T: Tape, +{ + type Unstacked = (Vec>, T); + fn try_unstack(self) -> Result { + try_unstack::>>, _, S, E, D, T>(self) + } +} + +pub trait SubDim: Shape { + type Head: Dim; + type Tail: Shape; + fn sub_dim(&self) -> (Self::Head, Self::Tail); +} + +impl SubDim for (D1,) { + type Head = D1; + type Tail = (); + fn sub_dim(&self) -> (Self::Head, Self::Tail) { + (self.0, ()) + } +} +impl SubDim for (D1, D2) { + type Head = D1; + type Tail = (D2,); + fn sub_dim(&self) -> (Self::Head, Self::Tail) { + (self.0, (self.1,)) + } +} +impl SubDim for (D1, D2, D3) { + type Head = D1; + type Tail = (D2, D3); + fn sub_dim(&self) -> (Self::Head, Self::Tail) { + (self.0, (self.1, self.2)) + } +} +impl SubDim for (D1, D2, D3, D4) { + type Head = D1; + type Tail = (D2, D3, D4); + fn sub_dim(&self) -> (Self::Head, Self::Tail) { + (self.0, (self.1, self.2, self.3)) + } +} +impl SubDim for (D1, D2, D3, D4, D5) { + type Head = D1; + type Tail = (D2, D3, D4, D5); + fn sub_dim(&self) -> (Self::Head, Self::Tail) { + (self.0, (self.1, self.2, self.3, self.4)) + } +} +impl SubDim for (D1, D2, D3, D4, D5, D6) { + type Head = D1; + type Tail = (D2, D3, D4, D5, D6); + fn sub_dim(&self) -> (Self::Head, Self::Tail) { + (self.0, (self.1, self.2, self.3, self.4, self.5)) + } +} + +pub trait UnstackKernel: Storage { + fn forward( + &self, + stack: Tensor, + ) -> Result + where + S: SubDim, + OptionalItems: Array>, Dim = S::Head>; + fn backward( + &self, + grad_stack: &mut Self::Vec, + grad_unstack: &Self::Vec, + unstack_idx: usize, + ) -> Result<(), Error>; +} + +fn try_unstack, T>( + stack: Tensor, +) -> Result<(Items, T), crate::tensor::Error> +where + S: SubDim, + T: Tape, + OptionalItems: Array>, Dim = S::Head> + + std::ops::IndexMut>>, + Items: Array, Dim = S::Head>, +{ + let device = stack.device.clone(); + let (head, _tail) = stack.shape().sub_dim(); + let (stack, stack_tape) = stack.split_tape(); + + let stack_ghost = stack.ghost(); + + // list of optional tensors (all are Some) + let mut unstacks = device.forward::<_, OptionalItems>(stack)?; + + // tensors from unstacks must get tapes inserted into them. + // to do this, from_fn is re-utilized, but this time without optionals + let unstacks = Items::from_fn( + |i| { + let unstack = std::mem::take(&mut unstacks[i]).unwrap(); + let device = device.clone(); + let stack_ghost = stack_ghost.clone(); + let unstack_ghost = unstack.ghost(); + let mut unstack_tape = T::default(); + + unstack_tape.add_backward_op(move |grads| { + grads.try_alloc_for(&stack_ghost)?; + grads.try_alloc_for(&unstack_ghost)?; + let (grad_stack, grad_unstack) = grads.mut_and_ref(&stack_ghost, &unstack_ghost); + device.backward(grad_stack, grad_unstack, i) + }); + unstack.put_tape(unstack_tape) + }, + head, + ); + + Ok((unstacks, stack_tape)) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{tensor_ops::*, tests::*}; + + // note: based on a stack test + #[test] + fn test_valid_unstacks() { + let dev: TestDevice = Default::default(); + + { + let stack: Tensor, TestDtype, _> = dev.sample_normal(); + let ([_x, _y, _z], _tape): ([Tensor<(), TestDtype, _>; 3], _) = stack.unstack(); + } + + { + let stack: Tensor, TestDtype, _> = dev.sample_normal(); + let ([_x, _y, _z], _tape): ([Tensor, TestDtype, _>; 3], _) = stack.unstack(); + } + + { + let stack: Tensor<(usize, Const<2>), TestDtype, _> = + dev.sample_normal_like(&(3, Const)); + let (unstacks, _tape): (Vec, _, _, _>>, _) = stack.unstack(); + assert_eq!(unstacks.len(), 3); + } + } + + // note: based on a stack test + #[test] + fn test_unstack_backwards() { + let dev: TestDevice = Default::default(); + let stack: Tensor, TestDtype, _> = dev.sample_normal(); + + let ([x, y, z], _tape): ([Tensor, TestDtype, _, _>; 3], _) = + stack.leaky_trace().unstack(); // r + assert_eq!(stack.array(), [x.array(), y.array(), z.array()]); + + let x1 = x.retaped::(); // r1 + let y1 = y.retaped::(); // r1 + let z1 = z.retaped::(); // r1 + let x1g = x1.leaky_trace().exp().mean().backward(); // g1 + let y1g = y1.leaky_trace().exp().mean().backward(); // g1 + let z1g = z1.leaky_trace().exp().mean().backward(); // g1 + + let xg = x.exp().mean().backward(); // g + let yg = y.exp().mean().backward(); // g + let zg = z.exp().mean().backward(); // g + + let x1_grad = x1g.get(&x1).array(); // r_grad + let y1_grad = y1g.get(&y1).array(); // r_grad + let z1_grad = z1g.get(&z1).array(); // r_grad + + assert_eq!( + [x1_grad, [TestDtype::zero(); 2], [TestDtype::zero(); 2]], + xg.get(&stack).array() + ); + assert_eq!( + [[TestDtype::zero(); 2], y1_grad, [TestDtype::zero(); 2]], + yg.get(&stack).array() + ); + assert_eq!( + [[TestDtype::zero(); 2], [TestDtype::zero(); 2], z1_grad], + zg.get(&stack).array() + ); + + // extra check + let stack_g = stack + .leaky_trace() + .exp() + .mean::<_, Axis<1>>() + .sum() + .backward(); + assert_eq!( + stack_g.get(&stack).array(), + [ + xg.get(&stack).array()[0], + yg.get(&stack).array()[1], + zg.get(&stack).array()[2] + ] + ); + } +} diff --git a/dfdx-core/src/tensor_ops/unstack/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/unstack/webgpu_kernel.rs new file mode 100644 index 00000000..d61512eb --- /dev/null +++ b/dfdx-core/src/tensor_ops/unstack/webgpu_kernel.rs @@ -0,0 +1,20 @@ +use crate::{prelude::NoneTape, shapes::*, tensor::Webgpu}; +use std::vec::Vec; + +impl super::UnstackKernel for Webgpu { + fn forward(&self, stack: Tensor) -> Result + where + S: super::SubDim, + Items: Array>, Dim = S::Head>, + { + todo!() + } + fn backward( + &self, + grad_stack: &mut Self::Vec, + grad_unstack: &Self::Vec, + unstack_idx: usize, + ) -> Result<(), Error> { + todo!() + } +} diff --git a/dfdx-core/src/tensor_ops/utilities/device.rs b/dfdx-core/src/tensor_ops/utilities/device.rs index 8cbc2137..95f4c985 100644 --- a/dfdx-core/src/tensor_ops/utilities/device.rs +++ b/dfdx-core/src/tensor_ops/utilities/device.rs @@ -21,6 +21,10 @@ pub trait Device: + super::super::concat_tensor_along::ConcatAlongKernel + super::super::concat_tensor_along::ConcatAlongKernel + // splits + + super::super::unstack::UnstackKernel + + super::super::unstack::UnstackKernel + // optimizers + super::super::adam::AdamKernel + super::super::sgd::SgdKernel From 5994ac5ea46e1080bb40391516d7b34f11418fc6 Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Tue, 6 Feb 2024 18:35:44 -0500 Subject: [PATCH 10/20] fix wgpu signature --- dfdx-core/src/tensor_ops/unstack/webgpu_kernel.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/dfdx-core/src/tensor_ops/unstack/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/unstack/webgpu_kernel.rs index d61512eb..26e15937 100644 --- a/dfdx-core/src/tensor_ops/unstack/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/unstack/webgpu_kernel.rs @@ -1,11 +1,14 @@ use crate::{prelude::NoneTape, shapes::*, tensor::Webgpu}; use std::vec::Vec; -impl super::UnstackKernel for Webgpu { - fn forward(&self, stack: Tensor) -> Result +impl super::UnstackKernel for Webgpu { + fn forward>( + &self, + stack: Tensor, + ) -> Result where S: super::SubDim, - Items: Array>, Dim = S::Head>, + Items: Array, Dim = S::Head>, { todo!() } From 5ffff2da7c7dd4697c29da6dda22e59aed3aaca7 Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Wed, 7 Feb 2024 21:52:22 -0500 Subject: [PATCH 11/20] add continuity requirement for unstack --- dfdx-core/src/tensor_ops/unstack/mod.rs | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/dfdx-core/src/tensor_ops/unstack/mod.rs b/dfdx-core/src/tensor_ops/unstack/mod.rs index 0155e555..21204f99 100644 --- a/dfdx-core/src/tensor_ops/unstack/mod.rs +++ b/dfdx-core/src/tensor_ops/unstack/mod.rs @@ -46,6 +46,7 @@ impl, T, const N: usize> TryUnstack where S: SubDim>, + D: super::reshape_to::ReshapeKernel, T: Tape, { type Unstacked = ([Tensor; N], T); @@ -57,6 +58,7 @@ where impl, T> TryUnstack for Tensor where S: SubDim, + D: super::reshape_to::ReshapeKernel, T: Tape, { type Unstacked = (Vec>, T); @@ -136,6 +138,7 @@ fn try_unstack, T> where S: SubDim, T: Tape, + D: super::reshape_to::ReshapeKernel, OptionalItems: Array>, Dim = S::Head> + std::ops::IndexMut>>, Items: Array, Dim = S::Head>, @@ -144,10 +147,17 @@ where let (head, _tail) = stack.shape().sub_dim(); let (stack, stack_tape) = stack.split_tape(); + // TODO: remove this overhead, and panic on a non-contiguous condition + let stack = { + use super::reshape_to::ReshapeTo; + stack.try_contiguous()? + }; + let stack_ghost = stack.ghost(); // list of optional tensors (all are Some) - let mut unstacks = device.forward::<_, OptionalItems>(stack)?; + + let mut unstacks = UnstackKernel::forward::<_, OptionalItems>(&device, stack)?; // tensors from unstacks must get tapes inserted into them. // to do this, from_fn is re-utilized, but this time without optionals @@ -163,7 +173,7 @@ where grads.try_alloc_for(&stack_ghost)?; grads.try_alloc_for(&unstack_ghost)?; let (grad_stack, grad_unstack) = grads.mut_and_ref(&stack_ghost, &unstack_ghost); - device.backward(grad_stack, grad_unstack, i) + UnstackKernel::backward(&device, grad_stack, grad_unstack, i) }); unstack.put_tape(unstack_tape) }, From e883b28f43f46d00d06b2dee69bfbc1a973a373f Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Fri, 9 Feb 2024 11:29:05 -0500 Subject: [PATCH 12/20] Added {load/read/save/write}_safetensor_with methods This alternative method: - Requires load/read to decide whether it should skip missing tensors; - Requires load/read/save/write to decide how should keys be mapped. --- dfdx-core/src/nn_traits/mod.rs | 75 +++++++++++++++++++++++------ dfdx-core/src/nn_traits/tuples.rs | 17 +++++-- dfdx-core/src/nn_traits/vecs.rs | 13 +++-- dfdx-core/src/tensor/safetensors.rs | 13 ++++- dfdx-derives/src/lib.rs | 36 +++++++++----- 5 files changed, 118 insertions(+), 36 deletions(-) diff --git a/dfdx-core/src/nn_traits/mod.rs b/dfdx-core/src/nn_traits/mod.rs index 20c55da2..52203373 100644 --- a/dfdx-core/src/nn_traits/mod.rs +++ b/dfdx-core/src/nn_traits/mod.rs @@ -116,12 +116,13 @@ pub trait ZeroGrads> { #[cfg(feature = "safetensors")] /// Something that can be saved to a .safetensors file. pub trait SaveSafeTensors { - fn save_safetensors>( + fn save_safetensors_with, F: FnMut(String) -> String>( &self, path: P, + key_map: &mut F, ) -> Result<(), safetensors::SafeTensorError> { let mut tensors = Vec::new(); - self.write_safetensors("", &mut tensors); + self.write_safetensors_with("", &mut tensors, key_map); let data = tensors.iter().map(|(k, dtype, shape, data)| { ( k.clone(), @@ -131,53 +132,88 @@ pub trait SaveSafeTensors { safetensors::serialize_to_file(data, &None, path.as_ref()) } - fn write_safetensors( + fn save_safetensors>( + &self, + path: P, + ) -> Result<(), safetensors::SafeTensorError> { + self.save_safetensors_with(path, &mut core::convert::identity) + } + fn write_safetensors_with String>( &self, location: &str, tensors: &mut Vec<(String, safetensors::Dtype, Vec, Vec)>, + key_map: &mut F, ); + fn write_safetensors( + &self, + location: &str, + tensors: &mut Vec<(String, safetensors::Dtype, Vec, Vec)>, + ) { + self.write_safetensors_with(location, tensors, &mut core::convert::identity) + } } #[cfg(feature = "safetensors")] /// Something that can be loaded from a .safetensors file. pub trait LoadSafeTensors { - fn load_safetensors>( + fn load_safetensors_with, F: FnMut(String) -> String>( &mut self, path: P, + skip_missing: bool, + key_map: &mut F, ) -> Result<(), safetensors::SafeTensorError> { let f = std::fs::File::open(path)?; let buffer = unsafe { memmap2::MmapOptions::new().map(&f)? }; let tensors = safetensors::SafeTensors::deserialize(&buffer)?; - self.read_safetensors("", &tensors) + self.read_safetensors_with("", &tensors, skip_missing, key_map) + } + fn load_safetensors>( + &mut self, + path: P, + ) -> Result<(), safetensors::SafeTensorError> { + self.load_safetensors_with(path, false, &mut core::convert::identity) } - fn read_safetensors( + fn read_safetensors_with String>( &mut self, location: &str, tensors: &safetensors::SafeTensors, + skip_missing: bool, + key_map: &mut F, ) -> Result<(), safetensors::SafeTensorError>; + fn read_safetensors( + &mut self, + location: &str, + tensors: &safetensors::SafeTensors, + ) -> Result<(), safetensors::SafeTensorError> { + self.read_safetensors_with(location, tensors, false, &mut core::convert::identity) + } } #[cfg(feature = "safetensors")] impl, T> LoadSafeTensors for Tensor { - fn read_safetensors( + fn read_safetensors_with String>( &mut self, location: &str, tensors: &safetensors::SafeTensors, + skip_missing: bool, + key_map: &mut F, ) -> Result<(), safetensors::SafeTensorError> { - self.load_safetensor(tensors, location) + self.load_safetensor(tensors, location, skip_missing, key_map) } } #[cfg(feature = "safetensors")] impl, T> SaveSafeTensors for Tensor { - fn write_safetensors( + fn write_safetensors_with String>( &self, location: &str, tensors: &mut Vec<(String, safetensors::Dtype, Vec, Vec)>, + key_map: &mut F, ) { + let location = key_map(location.to_string()); tensors.push(( - location.to_string(), + location, ::DTYPE, self.shape.concrete().into(), self.as_vec().iter().flat_map(|e| e.to_le_bytes()).collect(), @@ -189,15 +225,17 @@ macro_rules! unit_safetensors { ($Ty:ty) => { #[cfg(feature = "safetensors")] impl SaveSafeTensors for $Ty { - fn write_safetensors( + fn write_safetensors_with String>( &self, location: &str, tensors: &mut Vec<(String, safetensors::Dtype, Vec, Vec)>, + key_map: &mut F, ) { + let location = key_map(location.to_string()); #[allow(unused_imports)] use crate::dtypes::ToLeBytes; tensors.push(( - location.to_string(), + location, <$Ty as crate::dtypes::SafeTensorsDtype>::DTYPE, Vec::new(), self.to_le_bytes().to_vec(), @@ -207,14 +245,23 @@ macro_rules! unit_safetensors { #[cfg(feature = "safetensors")] impl LoadSafeTensors for $Ty { - fn read_safetensors( + fn read_safetensors_with String>( &mut self, location: &str, tensors: &safetensors::SafeTensors, + skip_missing: bool, + key_map: &mut F, ) -> Result<(), safetensors::SafeTensorError> { + let location = key_map(location.to_string()); #[allow(unused_imports)] use crate::dtypes::FromLeBytes; - let view = tensors.tensor(location)?; + let view = match tensors.tensor(&location) { + Ok(ok) => ok, + Err(safetensors::SafeTensorError::TensorNotFound(_name)) if skip_missing => { + return Ok(()); + } + Err(e) => return Err(e), + }; *self = Self::from_le_bytes(view.data().try_into().unwrap()); Ok(()) } diff --git a/dfdx-core/src/nn_traits/tuples.rs b/dfdx-core/src/nn_traits/tuples.rs index 205c0419..7f267482 100644 --- a/dfdx-core/src/nn_traits/tuples.rs +++ b/dfdx-core/src/nn_traits/tuples.rs @@ -20,23 +20,32 @@ macro_rules! tuple_impls { #[cfg(feature = "safetensors")] impl<$($name: crate::nn_traits::SaveSafeTensors, )+> crate::nn_traits::SaveSafeTensors for ($($name,)+) { - fn write_safetensors( + fn write_safetensors_with String>( &self, location: &str, tensors: &mut Vec<(String, safetensors::Dtype, Vec, Vec)>, + key_map: &mut F, ) { - $(self.$idx.write_safetensors(&format!("{location}.{}", $idx), tensors);)+ + $( + let name = &format!("{location}.{}", $idx); + self.$idx.write_safetensors_with(name, tensors, key_map); + )+ } } #[cfg(feature = "safetensors")] impl<$($name: crate::nn_traits::LoadSafeTensors, )+> crate::nn_traits::LoadSafeTensors for ($($name,)+) { - fn read_safetensors( + fn read_safetensors_with String>( &mut self, location: &str, tensors: &safetensors::SafeTensors, + skip_missing: bool, + key_map: &mut F, ) -> Result<(), safetensors::SafeTensorError> { - $(self.$idx.read_safetensors(&format!("{location}.{}", $idx), tensors)?;)+ + $( + let name = &format!("{location}.{}", $idx); + self.$idx.read_safetensors_with(name, tensors, skip_missing, key_map)?; + )+ Ok(()) } } diff --git a/dfdx-core/src/nn_traits/vecs.rs b/dfdx-core/src/nn_traits/vecs.rs index 593b1a55..201dd932 100644 --- a/dfdx-core/src/nn_traits/vecs.rs +++ b/dfdx-core/src/nn_traits/vecs.rs @@ -60,26 +60,31 @@ impl, T: crate::nn_traits::ZeroGrads> crate::nn_tra #[cfg(feature = "safetensors")] impl crate::nn_traits::SaveSafeTensors for Vec { - fn write_safetensors( + fn write_safetensors_with String>( &self, location: &str, tensors: &mut Vec<(String, safetensors::Dtype, Vec, Vec)>, + key_map: &mut F, ) { for (i, t) in self.iter().enumerate() { - t.write_safetensors(&format!("{location}.{i}"), tensors); + let name = &format!("{location}.{i}"); + t.write_safetensors_with(name, tensors, key_map); } } } #[cfg(feature = "safetensors")] impl crate::nn_traits::LoadSafeTensors for Vec { - fn read_safetensors( + fn read_safetensors_with String>( &mut self, location: &str, tensors: &safetensors::SafeTensors, + skip_missing: bool, + key_map: &mut F, ) -> Result<(), safetensors::SafeTensorError> { for (i, t) in self.iter_mut().enumerate() { - t.read_safetensors(&format!("{location}.{i}"), tensors)?; + let name = &format!("{location}.{i}"); + t.read_safetensors_with(name, tensors, skip_missing, key_map)?; } Ok(()) } diff --git a/dfdx-core/src/tensor/safetensors.rs b/dfdx-core/src/tensor/safetensors.rs index c0566c40..626eaeaa 100644 --- a/dfdx-core/src/tensor/safetensors.rs +++ b/dfdx-core/src/tensor/safetensors.rs @@ -5,12 +5,21 @@ use std::vec::Vec; impl, T> Tensor { /// Loads data from the [SafeTensors] `Storage` with the given `key` - pub fn load_safetensor( + pub fn load_safetensor String>( &mut self, tensors: &SafeTensors, key: &str, + skip_missing: bool, + key_map: &mut F, ) -> Result<(), SafeTensorError> { - let tensor_view = tensors.tensor(key)?; + let key = key_map(key.to_string()); + let tensor_view = match tensors.tensor(&key) { + Ok(ok) => ok, + Err(safetensors::SafeTensorError::TensorNotFound(_name)) if skip_missing => { + return Ok(()); + } + Err(e) => return Err(e), + }; let v = tensor_view.data(); let num_bytes = std::mem::size_of::(); assert_eq!( diff --git a/dfdx-derives/src/lib.rs b/dfdx-derives/src/lib.rs index 7af885f9..3c68fcb3 100644 --- a/dfdx-derives/src/lib.rs +++ b/dfdx-derives/src/lib.rs @@ -196,18 +196,21 @@ pub fn custom_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream let safetensors_impls = if cfg!(feature = "safetensors") { quote! { impl #built_impl ::dfdx::nn_traits::SaveSafeTensors for #builder_name #built_ty #built_where { - fn write_safetensors( + fn write_safetensors_with String>( &self, location: &str, tensors: &mut Vec<(String, ::dfdx::safetensors::Dtype, Vec, Vec)>, + key_map: &mut KeyMap, ) {} } impl #built_impl ::dfdx::nn_traits::LoadSafeTensors for #builder_name #built_ty #built_where { - fn read_safetensors<'a>( + fn read_safetensors_with<'a, KeyMap: FnMut(String) -> String>( &mut self, location: &str, tensors: &::dfdx::safetensors::SafeTensors<'a>, + skip_missing: bool, + key_map: &mut KeyMap, ) -> Result<(), ::dfdx::safetensors::SafeTensorError> { Ok(()) } @@ -850,9 +853,10 @@ pub fn save_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre where_clause .predicates .push(parse_quote!(#ty: ::dfdx::nn_traits::SaveSafeTensors)); - quote_spanned!(f.span()=>self.#name.write_safetensors( + quote_spanned!(f.span()=>self.#name.write_safetensors_with( &format!("{location}{}{}", if location.is_empty() { "" } else { "." }, #name_str), - tensors + tensors, + key_map );) } else { Default::default() @@ -869,9 +873,10 @@ pub fn save_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre where_clause .predicates .push(parse_quote!(#ty: ::dfdx::nn_traits::SaveSafeTensors)); - quote_spanned!(f.span()=>self.#index.write_safetensors( + quote_spanned!(f.span()=>self.#index.write_safetensors_with( &format!("{location}{}{}", if location.is_empty() { "" } else { "." }, #index), - tensors + tensors, + key_map );) } else { Default::default() @@ -890,10 +895,11 @@ pub fn save_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre proc_macro::TokenStream::from(quote! { // note: SaveSafeTensors definition is already gated by the safetensors feature impl #impl_generics ::dfdx::nn_traits::SaveSafeTensors for #name #ty_generics #where_clause { - fn write_safetensors( + fn write_safetensors_with String>( &self, location: &str, tensors: &mut Vec<(String, ::dfdx::safetensors::Dtype, Vec, Vec)>, + key_map: &mut KeyMap, ) { #save_fields } @@ -919,9 +925,11 @@ pub fn load_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre where_clause .predicates .push(parse_quote!(#ty: ::dfdx::nn_traits::LoadSafeTensors)); - quote_spanned!(f.span()=>self.#name.read_safetensors( + quote_spanned!(f.span()=>self.#name.read_safetensors_with( &format!("{location}{}{}", if location.is_empty() { "" } else { "." }, #name_str), - tensors + tensors, + skip_missing, + key_map )?;) } else { Default::default() @@ -937,9 +945,11 @@ pub fn load_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre where_clause .predicates .push(parse_quote!(#ty: ::dfdx::nn_traits::LoadSafeTensors)); - quote_spanned!(f.span()=>self.#index.read_safetensors( + quote_spanned!(f.span()=>self.#index.read_safetensors_with( &format!("{location}{}{}", if location.is_empty() { "" } else { "." }, #index), - tensors + tensors, + skip_missing, + key_map )?;) } else { Default::default() @@ -958,10 +968,12 @@ pub fn load_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre proc_macro::TokenStream::from(quote! { // note: LoadSafeTensors definition is already gated by the safetensors feature impl #impl_generics ::dfdx::nn_traits::LoadSafeTensors for #name #ty_generics #where_clause { - fn read_safetensors<'a>( + fn read_safetensors_with<'a, KeyMap: FnMut(String) -> String>( &mut self, location: &str, tensors: &::dfdx::safetensors::SafeTensors<'a>, + skip_missing: bool, + key_map: &mut KeyMap, ) -> Result<(), ::dfdx::safetensors::SafeTensorError> { #load_fields Ok(()) From c695a15eb3472157ea4467884e33b69dce0756a6 Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Fri, 9 Feb 2024 12:37:46 -0500 Subject: [PATCH 13/20] unstack fixes --- dfdx-core/src/tensor_ops/unstack/mod.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/dfdx-core/src/tensor_ops/unstack/mod.rs b/dfdx-core/src/tensor_ops/unstack/mod.rs index 21204f99..528ee1b7 100644 --- a/dfdx-core/src/tensor_ops/unstack/mod.rs +++ b/dfdx-core/src/tensor_ops/unstack/mod.rs @@ -1,4 +1,5 @@ use crate::{shapes::*, tensor::*}; +use std::vec::Vec; mod cpu_kernel; #[cfg(feature = "cuda")] @@ -21,15 +22,15 @@ mod webgpu_kernel; /// # use dfdx_core::prelude::*; /// # let dev: Cpu = Default::default(); /// let stack: Tensor, f32, _> = dev.zeros(); -/// let [a, b]: [Tensor, f32, _>; 2] = stack.unstack(); +/// let ([a, b], _tape): ([Tensor, f32, _>; 2], _) = stack.unstack(); /// ``` /// /// Unstacking to a vec: /// ```rust /// # use dfdx_core::prelude::*; /// # let dev: Cpu = Default::default(); -/// let stack: Tensor<(usize, Const::<3>, Const::<4>>, f32, _> = dev.zeros_like(&(2, Const, Const)); -/// let unstack: Vec, f32, _>> = stack.unstack(); +/// let stack: Tensor<(usize, Const::<3>, Const::<4>), f32, _> = dev.zeros_like(&(2, Const, Const)); +/// let (unstack, _tape): (Vec, f32, _>>, _) = stack.unstack(); /// ``` pub trait TryUnstack: Sized { type Unstacked; From 93202ad2b7dd16a915783334eb959d7bbb51aeef Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Mon, 19 Feb 2024 21:47:22 -0500 Subject: [PATCH 14/20] silu: fix cpu df --- dfdx-core/src/tensor_ops/silu/cpu_kernel.rs | 2 +- dfdx-core/src/tensor_ops/silu/mod.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dfdx-core/src/tensor_ops/silu/cpu_kernel.rs b/dfdx-core/src/tensor_ops/silu/cpu_kernel.rs index f6f05752..2fcba2fb 100644 --- a/dfdx-core/src/tensor_ops/silu/cpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/silu/cpu_kernel.rs @@ -15,6 +15,6 @@ impl UnaryDerivative for super::SiLUKernelOp { #[inline(always)] fn df(&self, x: &F) -> F { let exp_nx = x.neg().exp(); - F::one() + exp_nx + *x * exp_nx / (F::one() + exp_nx).powi(2) + (F::one() + exp_nx + *x * exp_nx) / (F::one() + exp_nx).powi(2) } } diff --git a/dfdx-core/src/tensor_ops/silu/mod.rs b/dfdx-core/src/tensor_ops/silu/mod.rs index 97bcce10..53881079 100644 --- a/dfdx-core/src/tensor_ops/silu/mod.rs +++ b/dfdx-core/src/tensor_ops/silu/mod.rs @@ -56,7 +56,7 @@ mod tests { let g = r.mean().backward(); assert_close_to_literal!( g.get(&x), - [1.635814, 0.70433396, 0.4, 0.31289828, 0.26906452] + [-0.018156849, 0.014465898, 0.1, 0.1855341, 0.21815684] ); } } From eb70a8805bbbd5c15f516b50e907cefc6de171bb Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Mon, 19 Feb 2024 21:42:32 -0500 Subject: [PATCH 15/20] allow to load safetensors from a byte array --- dfdx-core/src/nn_traits/mod.rs | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/dfdx-core/src/nn_traits/mod.rs b/dfdx-core/src/nn_traits/mod.rs index 52203373..869e1047 100644 --- a/dfdx-core/src/nn_traits/mod.rs +++ b/dfdx-core/src/nn_traits/mod.rs @@ -173,6 +173,21 @@ pub trait LoadSafeTensors { ) -> Result<(), safetensors::SafeTensorError> { self.load_safetensors_with(path, false, &mut core::convert::identity) } + fn load_safetensors_from_bytes_with String>( + &mut self, + bytes: &[u8], + skip_missing: bool, + key_map: &mut F, + ) -> Result<(), safetensors::SafeTensorError> { + let tensors = safetensors::SafeTensors::deserialize(&bytes)?; + self.read_safetensors_with("", &tensors, skip_missing, key_map) + } + fn load_safetensors_from_bytes( + &mut self, + bytes: &[u8], + ) -> Result<(), safetensors::SafeTensorError> { + self.load_safetensors_from_bytes_with(bytes, false, &mut core::convert::identity) + } fn read_safetensors_with String>( &mut self, From fde7a40f5ebce6ea076b2697a931a79791e2bcad Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Tue, 6 Feb 2024 18:27:46 -0500 Subject: [PATCH 16/20] avoid conv1d bound for cudnn --- dfdx-core/src/tensor_ops/utilities/device.rs | 50 +++++++++++++++----- 1 file changed, 39 insertions(+), 11 deletions(-) diff --git a/dfdx-core/src/tensor_ops/utilities/device.rs b/dfdx-core/src/tensor_ops/utilities/device.rs index 1a9997fd..c9bfa355 100644 --- a/dfdx-core/src/tensor_ops/utilities/device.rs +++ b/dfdx-core/src/tensor_ops/utilities/device.rs @@ -120,25 +120,49 @@ pub trait Device: + crate::tensor_ops::axpy::AxpyKernel // conv1d - + super::super::conv1d::Conv1DKernel + + NonCudnnCuda +{ +} + +#[cfg(feature = "cudnn")] +pub trait NonCudnnCuda {} + +#[cfg(not(feature = "cudnn"))] +pub trait NonCudnnCuda: + // conv1d + super::super::conv1d::Conv1DKernel { } #[cfg(feature = "f16")] -impl Device for crate::tensor::Cpu {} -#[cfg(feature = "f16")] -impl Device> for crate::tensor::Cpu {} +mod f16_ { + use super::*; + impl Device for crate::tensor::Cpu {} + impl NonCudnnCuda for crate::tensor::Cpu {} + impl Device> for crate::tensor::Cpu {} + impl NonCudnnCuda> for crate::tensor::Cpu {} +} impl Device for crate::tensor::Cpu {} +impl NonCudnnCuda for crate::tensor::Cpu {} impl Device for crate::tensor::Cpu {} +impl NonCudnnCuda for crate::tensor::Cpu {} #[cfg(all(feature = "cuda", feature = "f16"))] -impl Device for crate::tensor::Cuda {} -#[cfg(all(feature = "cuda", feature = "f16"))] -impl Device> for crate::tensor::Cuda {} -#[cfg(feature = "cuda")] -impl Device for crate::tensor::Cuda {} +mod cuda_f16 { + use super::*; + impl Device for crate::tensor::Cuda {} + impl NonCudnnCuda for crate::tensor::Cuda {} + impl Device> for crate::tensor::Cuda {} + impl NonCudnnCuda> for crate::tensor::Cuda {} +} #[cfg(feature = "cuda")] -impl Device for crate::tensor::Cuda {} +mod cuda { + use super::*; + impl Device for crate::tensor::Cuda {} + impl NonCudnnCuda for crate::tensor::Cuda {} + impl Device for crate::tensor::Cuda {} + impl NonCudnnCuda for crate::tensor::Cuda {} +} // TODO: How can we implement this for f16 when WGSL doesn't support f16 yet? // #[cfg(all(feature = "webgpu", feature = "f16"))] @@ -146,7 +170,11 @@ impl Device for crate::tensor::Cuda {} // #[cfg(all(feature = "webgpu", feature = "f16"))] // impl Device> for crate::tensor::Webgpu {} #[cfg(feature = "webgpu")] -impl Device for crate::tensor::Webgpu {} +mod webgpu { + use super::*; + impl Device for crate::tensor::Webgpu {} + impl NonCudnnCuda for crate::tensor::Webgpu {} +} // TODO: How can we implement this for f64 when WGSL doesn't support f64 yet? // #[cfg(feature = "webgpu")] From 75d63cd4599e5144ea4502c3183c8e50972a1a3c Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Fri, 9 Feb 2024 11:53:40 -0500 Subject: [PATCH 17/20] bump gemm --- dfdx-core/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dfdx-core/Cargo.toml b/dfdx-core/Cargo.toml index 5309ef7c..0f6cd5c6 100644 --- a/dfdx-core/Cargo.toml +++ b/dfdx-core/Cargo.toml @@ -35,7 +35,7 @@ num-traits = { workspace = true } safetensors = { workspace = true, optional = true } memmap2 = { workspace = true, optional = true } half = { version = "2.3.1", optional = true, features = ["num-traits", "rand_distr"] } -gemm = { version = "0.16.14", default-features = false, optional = true, features = ["rayon"] } +gemm = { version = "0.17.1", default-features = false, optional = true, features = ["rayon"] } rayon = { version = "1.7.0", optional = true } libm = { workspace = true } wgpu = { version = "0.18.0", features = ["glsl", "spirv"], optional = true } From f0bcb9a2920c0e375851015131e51a6a64bdd26c Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Fri, 9 Feb 2024 12:52:05 -0500 Subject: [PATCH 18/20] clippy fix --- dfdx-core/src/tensor/gradients.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dfdx-core/src/tensor/gradients.rs b/dfdx-core/src/tensor/gradients.rs index 86974ec6..d24e2e32 100644 --- a/dfdx-core/src/tensor/gradients.rs +++ b/dfdx-core/src/tensor/gradients.rs @@ -153,7 +153,7 @@ impl> Gradients { #[inline] pub(crate) fn many_and_ref( &mut self, - ls: &Vec>, + ls: &[impl Tensorlike], r: &impl Tensorlike, ) -> (Vec<&mut D::Vec>, &D::Vec) { for i in 0..ls.len() { From cac2f335b994fd3e7184e8c707cdeb9db5eb4285 Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Thu, 1 Feb 2024 19:12:20 -0500 Subject: [PATCH 19/20] Add mamba-minimal - Add stateless forward impl. - Efficient for training (but training is not yet implemented). - Input requires the entire sequence, and requires no state cache. - Generates one output for each input sequence. - Add stateful forward impl. - Efficient for inference. - Input requires the last single sequence point, and requires the last state cache. - Generates a single output referring to the last input. --- dfdx/src/nn/layers/mamba_minimal.rs | 1103 +++++++++++++++++++++++++++ dfdx/src/nn/layers/mod.rs | 5 + 2 files changed, 1108 insertions(+) create mode 100644 dfdx/src/nn/layers/mamba_minimal.rs diff --git a/dfdx/src/nn/layers/mamba_minimal.rs b/dfdx/src/nn/layers/mamba_minimal.rs new file mode 100644 index 00000000..465aa97f --- /dev/null +++ b/dfdx/src/nn/layers/mamba_minimal.rs @@ -0,0 +1,1103 @@ +// references: +// - https://github.com/huggingface/candle/blob/fd7c8565646039e35925b8730d27ddad195d7e73/candle-examples/examples/mamba-minimal/ +// - https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/ + +#![allow(clippy::type_complexity)] + +use dfdx::nn::{ + Bias1D, Bias1DConfig, Conv1D, Conv1DConfig, Linear, LinearConfig, MatMul, MatMulConfig, +}; +use dfdx::prelude::{ + Axes3, Axes4, Axis, BuildOnDevice, Const, Device, Dim, Dtype, Error, HasShape, Module, + NoneTape, PutTape, Tape, Tensor, +}; +use dfdx::tensor_ops::{ + BroadcastTo, PermuteTo, RealizeTo, ReshapeTo, SumTo, TryAdd, TryConcatTensorAlong, TryMatMul, + TryMul, TrySplitTensorAlong, TryStack, TryUnstack, +}; +#[cfg(feature = "safetensors")] +use dfdx::{LoadSafeTensors, SaveSafeTensors}; +use dfdx::{ResetParams, UpdateParams, ZeroGrads}; +use std::ops::{Add, Div, Mul, Sub}; + +pub type C1 = Const<1>; +pub type C2 = Const<2>; +pub type C4 = Const<4>; +pub type C15 = Const<15>; +pub type C16 = Const<16>; + +// +/// A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper. +#[derive(Clone, Debug, Default)] +pub struct MambaBlockConfig< + // Hidden dimension. + DModel: Dim, + // latent state dimension (`N` in Algorithm 2 from the Mamba paper). + // + // Default: 16 + DState: Dim = C16, + // Rank of Δ (See Section 3.6 "Parameterization of ∆" from the Mamba paper). + // Δ or delta: input-dependent step size. + // + // Default: (DModel + 15) / 16 + DtRank: Dim = <>::Output as Div>::Output, + // Default: 4 + DConv: Dim = C4, + // DModel * expand (`D` in Algorithm 2 from the Mamba paper). + // + // Default: DModel * 2 + DInner: Dim = >::Output, +> where + // DModel + 15 + DModel: Add, + >::Output: Dim, + // (DModel + 15) / 16 + >::Output: Div, + <>::Output as Div>::Output: Dim, + // DModel * 2 + DModel: Mul, + >::Output: Dim, + // DInner * 2 + DInner: Mul, + >::Output: Dim + Default, + // DConv - 1 + DConv: Sub, + >::Output: Dim + Default, + // DState * 2 + DState: Mul, + >::Output: Dim, + // DtRank + DState * 2 + DtRank: Add<>::Output>, + >::Output>>::Output: Dim + Default, +{ + /// Input: DModel. + /// Output: DInner * 2. + pub in_proj: MatMulConfig>::Output>, + + /// Input channel: DInner. + /// Output channel: DInner. + pub conv1d: Conv1DConfig>::Output, C1, DInner>, + + /// Input channel: DInner. + /// Output channel: DInner. + pub conv1d_bias: Bias1DConfig, + + /// Takes in the state and outputs the input-specific Δ, B, C. + /// + /// Input: DInner. + /// Output: DtRank + DState * 2. + pub x_proj: MatMulConfig>::Output>>::Output>, + + /// Projects Δ from DT_RANK to D_INNER + /// + /// Input: DtRank. + /// Output: DInner. + pub dt_proj: LinearConfig, + + pub a_log: (DInner, DState), + + pub d: (DInner,), + + // TODO: this could have a bias (becoming a Linear layer) + // ref: https://github.com/johnma2006/mamba-minimal/blob/03de542a36d873f6e6c4057ad687278cc6ae944d/model.py#L203 + // + /// Input: DInner. + /// Output: DModel. + pub out_proj: MatMulConfig, +} + +// +/// A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper. +pub type MambaBlockConstConfig< + // Hidden dimension. + const D_MODEL: usize, + // latent state dimension (`N` in Algorithm 2 from the Mamba paper). + const D_STATE: usize = 16, + const DT_RANK: usize = { (D_MODEL + 15) / 16 }, + const D_CONV: usize = 4, + const D_INNER: usize = { D_MODEL * 2 }, +> = MambaBlockConfig< + // + Const, + Const, + Const, + Const, + Const, +>; + +impl< + // Hidden dimension. + DModel: Dim, + // latent state dimension (`N` in Algorithm 2 from the Mamba paper). + // + // Default: 16 + DState: Dim, + // Rank of Δ (See Section 3.6 "Parameterization of ∆" from the Mamba paper). + // Δ or delta: input-dependent step size. + // + // Default: (DModel + 15) / 16 + DtRank: Dim, + // Default: 4 + DConv: Dim, + // DModel * expand (`D` in Algorithm 2 from the Mamba paper). + // + // Default: DModel * 2 + DInner: Dim, + > MambaBlockConfig +where + // DModel + 15 + DModel: Add, + >::Output: Dim, + // (DModel + 15) / 16 + >::Output: Div, + <>::Output as Div>::Output: Dim, + // DModel * 2 + DModel: Mul, + >::Output: Dim, + // DInner * 2 + DInner: Mul, + >::Output: Dim + Default, + // DConv - 1 + DConv: Sub, + >::Output: Dim + Default, + // DState * 2 + DState: Mul, + >::Output: Dim, + // DtRank + DState * 2 + DtRank: Add<>::Output>, + >::Output>>::Output: Dim + Default, +{ + pub fn new( + d_model: DModel, + d_state: DState, + dt_rank: DtRank, + d_conv: DConv, + d_inner: DInner, + ) -> Self { + MambaBlockConfig { + in_proj: MatMulConfig { + inp: d_model, + out: d_inner * Const::<2>, + }, + conv1d: Conv1DConfig { + in_chan: d_inner, + out_chan: d_inner, + kernel_size: d_conv, + stride: Const::<1>, + padding: d_conv - Const::<1>, + dilation: Const::<1>, + groups: d_inner, + }, + conv1d_bias: Bias1DConfig(d_inner), + x_proj: MatMulConfig { + inp: d_inner, + out: dt_rank + d_state * Const::<2>, + }, + dt_proj: LinearConfig { + inp: dt_rank, + out: d_inner, + }, + a_log: (d_inner, d_state), + d: (d_inner,), + out_proj: MatMulConfig { + inp: d_inner, + out: d_model, + }, + } + } +} + +// +/// A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper. +#[derive(Clone, Debug, ResetParams, UpdateParams, ZeroGrads)] +#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] +pub struct MambaBlock< + // Hidden dimension. + DModel: Dim, + // latent state dimension (`N` in Algorithm 2 from the Mamba paper). + DState: Dim, + // Rank of Δ (See Section 3.6 "Parameterization of ∆" from the Mamba paper). + // Δ or delta: input-dependent step size. + DtRank: Dim, + DConv: Dim, + // DModel * expand (`D` in Algorithm 2 from the Mamba paper). + // By default, expand is implicitly `2`. + DInner: Dim, + Elem: Dtype, + Dev: Device, +> where + // DInner can be divided by itself + DInner: Div, + >::Output: Dim, + // DInner * 2 + DInner: Mul, + >::Output: Dim, + // DConv - 1 + DConv: Sub, + >::Output: Dim + Default, + // DState * 2 + DState: Mul, + >::Output: Dim, + // DtRank + DState * 2 + DtRank: Add<>::Output>, + >::Output>>::Output: Dim + Default, +{ + /// Input: DModel. + /// Output: DInner * 2. + #[module] + #[cfg_attr(feature = "safetensors", serialize)] + pub in_proj: MatMul>::Output, Elem, Dev>, + + // TODO: is the padding correct? (DConv - 1) + // is it different in here? + // https://github.com/kroggen/mamba-cpu/blob/d12b23b059d249b7077ad080679ae918c9a45caf/mamba_ssm/modules/mamba_simple.py#L103 + // + /// Input channel: DInner. + /// Output channel: DInner. + #[module] + #[cfg_attr(feature = "safetensors", serialize)] + pub conv1d: + Conv1D>::Output, C1, DInner, Elem, Dev>, + + #[module] + #[cfg_attr(feature = "safetensors", serialize)] + pub conv1d_bias: Bias1D, + + /// Takes in the state and outputs the input-specific Δ, B, C. + /// + /// Input: DInner. + /// Output: DtRank + DState * 2. + #[module] + #[cfg_attr(feature = "safetensors", serialize)] + pub x_proj: MatMul>::Output>>::Output, Elem, Dev>, + + /// Projects Δ. + /// + /// Input: DtRank. + /// Output: DInner. + #[module] + #[cfg_attr(feature = "safetensors", serialize)] + pub dt_proj: Linear, + + #[param] + #[cfg_attr(feature = "safetensors", serialize)] + pub a_log: Tensor<(DInner, DState), Elem, Dev>, + + #[param] + #[cfg_attr(feature = "safetensors", serialize)] + pub d: Tensor<(DInner,), Elem, Dev>, + + // TODO: this could have a bias (becoming a Linear layer) + // ref: https://github.com/johnma2006/mamba-minimal/blob/03de542a36d873f6e6c4057ad687278cc6ae944d/model.py#L203 + /// Input: DInner. + /// Output: DModel. + #[module] + #[cfg_attr(feature = "safetensors", serialize)] + pub out_proj: MatMul, +} + +impl> + BuildOnDevice for MambaBlockConfig +where + // DModel + 15 + DModel: Add, + >::Output: Dim, + // (DModel + 15) / 16 + >::Output: Div, + <>::Output as Div>::Output: Dim, + // DModel * 2 + DModel: Mul, + >::Output: Dim, + // DInner can be divided by itself + DInner: Div, + >::Output: Dim, + // DInner * 2 + DInner: Mul, + >::Output: Dim + Default, + // DConv - 1 + DConv: Sub, + >::Output: Dim + Default, + // DState * 2 + DState: Mul, + >::Output: Dim, + // DtRank + DState * 2 + DtRank: Add<>::Output>, + >::Output>>::Output: Dim + Default, +{ + type Built = MambaBlock; + fn try_build_on_device(&self, device: &D) -> Result { + Ok(MambaBlock { + in_proj: self.in_proj.try_build_on_device(device)?, + conv1d: self.conv1d.try_build_on_device(device)?, + conv1d_bias: self.conv1d_bias.try_build_on_device(device)?, + x_proj: self.x_proj.try_build_on_device(device)?, + dt_proj: self.dt_proj.try_build_on_device(device)?, + a_log: device.try_zeros_like(&self.a_log)?, + d: device.try_zeros_like(&self.d)?, + out_proj: self.out_proj.try_build_on_device(device)?, + }) + } +} + +pub mod stateless { + use super::*; + + #[allow(clippy::let_unit_value)] + impl< + // Batch size (`B` in Algorithm 2 from the Mamba paper). + Batch: Dim, + // Sequence length (`L` in Algorithm 2 from the Mamba paper). + Sequence: Dim, + // Hidden dimension. + DModel: Dim, + // latent state dimension (`N` in Algorithm 2 from the Mamba paper). + DState: Dim, + // Rank of Δ (See Section 3.6 "Parameterization of ∆" from the Mamba paper). + // Δ or delta: input-dependent step size. + DtRank: Dim, + DConv: Dim, + // DModel * expand (`D` in Algorithm 2 from the Mamba paper). + DInner: Dim, + E: Dtype, + D: Device, + T: Tape, + > Module> + for MambaBlock + where + // DInner can be divided by itself + DInner: Div, + >::Output: Dim, + // DInner * 2 + DInner: Mul, + >::Output: Dim, + // DInner * 2 / 2 = DInner + >::Output: Div, + // DConv - 1 + DConv: Sub, + >::Output: Dim + Default, + // DState * 2 + DState: Mul, + >::Output: Dim, + // DtRank + DState * 2 + DtRank: Add<>::Output>, + >::Output>>::Output: Dim + Default, + // layer 2 (conv1d) + // used to truncate back to Sequence: Sequence + DConv + Sequence: Add, + >::Output: Dim, + // used to truncate back to Sequence: Sequencec + DConv - 1 + >::Output: Sub, + <>::Output as Sub>::Output: Dim, + Conv1D< + // in channel + DInner, + // out chanel + DInner, + // kernel + DConv, + // stride + C1, + // padding = DConv - 1 + >::Output, + // dillation + C1, + // groups + DInner, + E, + D, + >: Module< + Tensor<(Batch, DInner, Sequence), E, D, T>, + Output = Tensor< + // (Batch, DInner, Sequence + DConv - 1) + // but this is later truncated back to (Batch, DInner, Sequence) + ( + Batch, + DInner, + <>::Output as Sub>::Output, + ), + E, + D, + T, + >, + >, + // conv1d bias + Bias1D: Module< + Tensor<(Batch, DInner, Sequence), E, D, T>, + Output = Tensor<(Batch, DInner, Sequence), E, D, T>, + >, + // dt_proj bias + // (this needs to be defined otherwise Rust thinks this should behave the same as conv1d bias) + Bias1D: Module< + Tensor<(Batch, Sequence, DtRank), E, D, T>, + Output = Tensor<(Batch, Sequence, DtRank), E, D, T>, + >, + { + type Output = Tensor<(Batch, Sequence, DModel), E, D, T>; + + // + /// Mamba block forward. + /// This looks the same as Figure 3 in Section 3.4 in the Mamba paper. + fn try_forward( + &self, + x: Tensor<(Batch, Sequence, DModel), E, D, T>, + ) -> Result { + let (batch, sequence, _d_model) = *x.shape(); + let (d_inner,) = *self.d.shape(); + + // layer 1 (in_proj) + let (xs, res): ( + Tensor<(Batch, Sequence, DInner), _, _, _>, + Tensor<(Batch, Sequence, DInner), _, _, _>, + ) = { + // projects the input DModel into 2*DInner + let xs_and_res: Tensor<(Batch, Sequence, >::Output), _, _, _> = + self.in_proj.try_forward(x)?; + + // splits xs_and_res into (xs, res) + let (xs, res, _tape) = + xs_and_res.try_split_tensor_along(Axis::<2>, d_inner, d_inner)?; + + (xs, res) + }; + + // layer 2 (conv1d) + let xs: Tensor<(Batch, Sequence, DInner), _, _, _> = { + let xs: Tensor<(Batch, DInner, Sequence), _, _, _> = + xs.try_permute::<_, Axes3<0, 2, 1>>()?; + let xs: Tensor<(Batch, DInner, _), _, _, _> = + self.conv1d.try_forward(xs.try_contiguous()?)?; + + // assert shape + { + let (_, _, d_conv) = self.conv1d.weight.shape(); + let xs_shape = xs.shape(); + debug_assert_eq!( + ( + batch.size(), + d_inner.size(), + sequence.size() + d_conv.size() - 1 + ), + (xs_shape.0.size(), xs_shape.1.size(), xs_shape.2.size()) + ); + } + + // make the last axis be limited to the size of 0..sequence + let (_d_inner, _, d_conv) = *self.conv1d.weight.shape(); + let (xs, _tail, _tape): (Tensor<(Batch, DInner, Sequence), _, _, _>, _, _) = + xs.try_split_tensor_along(Axis::<2>, sequence, d_conv - Const::<1>)?; + + // conv1d bias, and restore original positioning as per before the layer 2 + let xs: Tensor<(Batch, Sequence, DInner), _, _, _> = + xs.try_permute::<_, Axes3<0, 2, 1>>()?; + let xs = self.conv1d_bias.try_forward(xs)?; + + // activation + xs.try_silu()? + }; + + let ss = ss( + self.a_log.retaped::(), + self.d.retaped::(), + xs, + &self.x_proj, + &self.dt_proj, + )?; + + let ys = ss.try_mul(res.try_silu()?)?; + + let y: Tensor<(Batch, Sequence, DModel), _, _, _> = self.out_proj.try_forward(ys)?; + Ok(y) + } + } + + /// Runs the SSM. See: + /// - Algorithm 2 in Section 3.2 from the Mamba paper; + /// - run_SSM(A, B, C, u) from The Annotated S4. + /// + pub fn ss< + // Batch size (`B` in Algorithm 2 from the Mamba paper). + Batch: Dim, + // Sequence length (`L` in Algorithm 2 from the Mamba paper). + Sequence: Dim, + // latent state dimension (`N` in Algorithm 2 from the Mamba paper). + DState: Dim, + // Rank of Δ (See Section 3.6 "Parameterization of ∆" from the Mamba paper). + // Δ or delta: input-dependent step size. + DtRank: Dim, + // DModel * expand (`D` in Algorithm 2 from the Mamba paper). + DInner: Dim, + E: Dtype, + D: Device, + T: Tape, + >( + a: Tensor<(DInner, DState), E, D, T>, + d: Tensor<(DInner,), E, D, T>, + u: Tensor<(Batch, Sequence, DInner), E, D, T>, + x_proj: &MatMul>::Output>>::Output, E, D>, + dt_proj: &Linear, + ) -> Result, dfdx::tensor::Error> + where + // used to truncate back to DtRank: DState * 2 + DState: Mul, + >::Output: Dim, + // used to truncate back to DtRank: DtRank + DState * 2 + DtRank: Add<>::Output>, + >::Output>>::Output: Dim + Default, + { + let device = u.device().clone(); + + let (_d_inner, d_state) = *a.shape(); + let (_d_inner, dt_rank) = *dt_proj.weight.shape(); + + // Compute ∆ A B C D, the state space parameters. + + // A + // this is input independent (see Section 3.5.2 "Interpretation of A" form the Mamba paper for why A isn't selective) + let a: Tensor<(DInner, DState), _, _, _> = a.try_exp()?.try_negate()?; + + // (Batch, Sequence, DtRank + DState * 2) + let x_dbl: Tensor<(Batch, Sequence, _), _, _, _> = x_proj.try_forward(u.retaped::())?; + + // ∆ (part 1/2) + // ∆ is input-dependent + let (delta, x_dbl_tail, _tape): (Tensor<(Batch, Sequence, DtRank), _, _, _>, _, _) = + x_dbl.try_split_tensor_along(Axis::<2>, dt_rank, d_state * Const::<2>)?; + + // B and C + // B and C are input-dependent + let (b, c, _tape): ( + Tensor<(Batch, Sequence, DState), _, _, _>, + Tensor<(Batch, Sequence, DState), _, _, _>, + _, + ) = x_dbl_tail.try_split_tensor_along(Axis::<2>, d_state, d_state)?; + + // ∆ (part 2/2) + // ∆ is input-dependent + let delta: Tensor<(Batch, Sequence, DInner), _, _, _> = { + let delta = dt_proj.try_forward(delta)?; + // softplus without threshold + // TODO: consider the threshold + let one = device.ones_like(&delta); + (delta.try_exp()?.try_add(one)?).try_ln()? + }; + + selective_scan( + delta.try_permute::<_, Axes3<0, 2, 1>>()?, + a, + b, + c.try_permute::<_, Axes3<1, 0, 2>>()?, + d, + u, + ) + } + + /// Selective Scan. + /// + /// Does selective scan algorithm. See: + /// - Section 2 State Space Models from the Mamba paper; + /// - Algorithm 2 in Section 3.2 from the Mamba paper; + /// - run_SSM(A, B, C, u) from The Annotated S4. + pub fn selective_scan< + // Batch size (`B` in Algorithm 2 from the Mamba paper). + Batch: Dim, + // Sequence length (`L` in Algorithm 2 from the Mamba paper). + Sequence: Dim, + // latent state dimension (`N` in Algorithm 2 from the Mamba paper). + DState: Dim, + // DModel * expand (`D` in Algorithm 2 from the Mamba paper). + DInner: Dim, + E: Dtype, + D: Device, + T: Tape, + >( + delta: Tensor<(Batch, DInner, Sequence), E, D, T>, + a: Tensor<(DInner, DState), E, D, T>, + b: Tensor<(Batch, Sequence, DState), E, D, T>, + c: Tensor<(Sequence, Batch, DState), E, D, T>, + d: Tensor<(DInner,), E, D, T>, + u: Tensor<(Batch, Sequence, DInner), E, D, T>, + ) -> Result, dfdx::tensor::Error> { + let device = delta.device().clone(); + + let (batch, d_inner, sequence) = *delta.shape(); + let (_d_inner, d_state) = *a.shape(); + + // Discretize continuous parameters (A, B) + // - A is discretized using zero-order hold (ZOH) discretization (see Section 2 Equation 4 in the Mamba paper) + // - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors: + // "A is the more important term and the performance doesn't change much with the simplification on B" + let (delta_a, delta_bu): ( + Tensor<(Batch, DInner, Sequence, DState), _, _, _>, + Tensor<(Batch, DInner, Sequence, DState), _, _, _>, + ) = { + let target_shape = (batch, d_inner, sequence, d_state); + + let delta_shape = delta.try_broadcast_like(&target_shape)?; + + let a = a.try_broadcast_like(&target_shape)?; + let delta_a: Tensor<(Batch, DInner, Sequence, DState), _, _, _> = + delta_shape.retaped::().try_mul(a)?.try_exp()?; + + let b = b.try_broadcast_like(&target_shape)?; + let delta_bu = delta_shape.try_mul(b)?; + + let u_bu = u + .retaped::() + .try_permute::<_, Axes3<0, 2, 1>>()? + .try_broadcast_like(&target_shape)?; + let delta_bu = delta_bu.try_mul(u_bu)?; + + (delta_a, delta_bu) + }; + + // Perform selective scan (see scan_SSM() from The Annotated S4) + // Note that the below is sequential, while the official implementation does a much faster parallel scan that + // is additionally hardware-aware (like FlashAttention). + + let mut xs: Tensor<(Batch, DInner, DState), E, _, _> = device + .zeros_like(&(batch, d_inner, d_state)) + .put_tape(T::default()); + let mut ys: Vec> = Vec::with_capacity(sequence.size()); + + // permute so that the Sequence refers to the first axis + let delta_a: Tensor<(Sequence, Batch, DInner, DState), _, _, _> = + delta_a.try_permute::<_, Axes4<2, 0, 1, 3>>()?; + let delta_bu: Tensor<(Sequence, Batch, DInner, DState), _, _, _> = + delta_bu.try_permute::<_, Axes4<2, 0, 1, 3>>()?; + + // unstack the Sequence axis + // + // delta A + let delta_a: Tensor<(usize, Batch, DInner, DState), _, _, _> = match delta_a.try_realize() { + Ok(delta_a) => delta_a, + Err(_delta_a) => unreachable!(), + }; + let (delta_a, _delta_a_tape): (Vec>, _) = + delta_a.try_contiguous()?.try_unstack()?; + // + // delta B + let delta_bu: Tensor<(usize, Batch, DInner, DState), _, _, _> = match delta_bu.try_realize() + { + Ok(delta_bu) => delta_bu, + Err(_delta_bu) => unreachable!(), + }; + let (delta_bu, _delta_bu_tape): (Vec>, _) = + delta_bu.try_contiguous()?.try_unstack()?; + // + // C + let c: Tensor<(usize, Batch, DState, C1), _, _, _> = match c + .try_broadcast_like(&(sequence, batch, d_state, Const::<1>))? + .try_realize() + { + Ok(c) => c, + Err(_c) => unreachable!(), + }; + let (c, _c_tape): (Vec>, _) = c.try_unstack()?; + + // loop over the sequence + for ((delta_a, delta_bu), c) in delta_a + .into_iter() + .zip(delta_bu.into_iter()) + .zip(c.into_iter()) + { + xs = xs.retaped::().try_mul(delta_a)?.try_add(delta_bu)?; + let y: Tensor<(Batch, DInner), _, _, _> = xs + .retaped::() + .try_matmul(c)? + .try_reshape_like(&(batch, d_inner))?; + ys.push(y); + } + + let ys: Tensor<(Batch, Sequence, DInner), _, _, _> = if let Ok(ys) = ys + .try_stack()? + .try_permute::<_, Axes3<1, 0, 2>>()? + .try_realize::<(Batch, Sequence, DInner)>() + { + ys + } else { + // TODO + // try_realize whould never fail in this case? + todo!(); + }; + + // D + let d: Tensor<(Batch, Sequence, DInner), _, _, T> = + d.try_broadcast_like(&(batch, sequence, d_inner))?; + let u = u; + let du = d.try_mul(u)?; + + let ys: Tensor<(Batch, Sequence, DInner), _, _, _> = ys.try_add(du)?; + Ok(ys) + } +} + +pub mod stateful { + // additional references: + // - https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py + // - https://github.com/kroggen/mamba.c/blob/learning/mamba.c + // - https://github.com/kroggen/mamba-cpu/blob/recurrent-only/mamba_ssm/mamba_simple.py + + use super::*; + + #[derive(Clone, Debug, ResetParams)] + pub struct MambaStateCacheConfig< + Batch: Dim, + // latent state dimension (`N` in Algorithm 2 from the Mamba paper). + DState: Dim, + DConv: Dim, + // DModel * expand (`D` in Algorithm 2 from the Mamba paper). + DInner: Dim, + > { + pub conv_state: (Batch, DInner, DConv), + pub ssm_state: (Batch, DInner, DState), + } + + #[derive(Debug, Clone, ResetParams)] + pub struct MambaStateCache< + Batch: Dim, + // latent state dimension (`N` in Algorithm 2 from the Mamba paper). + DState: Dim, + DConv: Dim, + // DModel * expand (`D` in Algorithm 2 from the Mamba paper). + DInner: Dim, + E: Dtype, + D: Device, + T: Tape, + > { + pub conv_state: Tensor<(Batch, DInner, DConv), E, D, T>, + pub ssm_state: Tensor<(Batch, DInner, DState), E, D, T>, + } + + impl + MambaStateCacheConfig + { + pub fn new(batch: Batch, d_state: DState, d_conv: DConv, d_inner: DInner) -> Self { + Self { + conv_state: (batch, d_inner, d_conv), + ssm_state: (batch, d_inner, d_state), + } + } + } + + impl> + BuildOnDevice for MambaStateCacheConfig + { + type Built = MambaStateCache; + fn try_build_on_device(&self, device: &D) -> Result { + Ok(MambaStateCache { + conv_state: device.try_zeros_like(&self.conv_state)?, + ssm_state: device.try_zeros_like(&self.ssm_state)?, + }) + } + } + + #[allow(clippy::let_unit_value)] + impl< + // Batch size (`B` in Algorithm 2 from the Mamba paper). + Batch: Dim, + // Hidden dimension. + DModel: Dim, + // latent state dimension (`N` in Algorithm 2 from the Mamba paper). + DState: Dim, + // Rank of Δ (See Section 3.6 "Parameterization of ∆" from the Mamba paper). + // Δ or delta: input-dependent step size. + DtRank: Dim, + DConv: Dim, + // DModel * expand (`D` in Algorithm 2 from the Mamba paper). + DInner: Dim, + E: Dtype, + D: Device, + T: Tape, + > + Module<( + Tensor<(Batch, DModel), E, D, T>, + MambaStateCache, + )> for MambaBlock + where + // DInner can be divided by itself + DInner: Div, + >::Output: Dim, + // DInner * 2 + DInner: Mul, + >::Output: Dim, + // DInner * 2 / 2 = DInner + >::Output: Div, + // DConv - 1 + DConv: Sub, + >::Output: Dim + Default, + // DState * 2 + DState: Mul, + >::Output: Dim, + // DtRank + DState * 2 + DtRank: Add<>::Output>, + >::Output>>::Output: Dim + Default, + // layer 2 (conv1d) + ( + ( + Batch, + DInner, + >>::Output, + ), + (Batch, DInner, dfdx_core::shapes::Const<1>), + ): dfdx_core::tensor_ops::TryConcatShapeAlong, Output = (Batch, DInner, DConv)>, + { + type Output = ( + Tensor<(Batch, DModel), E, D, T>, + MambaStateCache, + ); + + /// Mamba block forward. + fn try_forward( + &self, + x: ( + Tensor<(Batch, DModel), E, D, T>, + MambaStateCache, + ), + ) -> Result { + let (x, mut cache) = x; + + let (batch, d_inner, d_conv) = *cache.conv_state.shape(); + + // layer 1 (in_proj) + let (xs, res): ( + Tensor<(Batch, DInner), _, _, _>, + Tensor<(Batch, DInner), _, _, _>, + ) = { + // projects the input DModel into 2*DInner + let xs_and_res: Tensor<(Batch, >::Output), _, _, _> = + self.in_proj.try_forward(x)?; + + // splits xs_and_res into (xs, res) + let (xs, res, _tape) = + xs_and_res.try_split_tensor_along(Axis::<1>, d_inner, d_inner)?; + + (xs, res) + }; + + // layer 2 (conv1d) + // + // needs to replace the first column of cache.conv_state with + // the new input and roll it so it's the last column + cache.conv_state = { + // not sure if there is a way to directly replace just a single column, + // so the workaround is first to split away the first column (by the left side) + let (_head, conv_state, _tape): ( + _, + Tensor<(Batch, DInner, >::Output), _, _, _>, + _, + ) = cache.conv_state.try_split_tensor_along( + Axis::<2>, + Const::<1>, + d_conv - Const::<1>, + )?; + // then concat with the xs as the last column (by the right side) + let xs: Tensor<(Batch, DInner, C1), _, _, _> = + xs.try_reshape_like(&(batch, d_inner, Const::<1>))?; + (conv_state, xs).try_concat_tensor_along(Axis::<2>)? + }; + + let xs: Tensor<(Batch, DInner), E, _, _> = { + let conv1d = self + .conv1d + .weight + .clone() + .try_reshape_like(&(d_inner, d_conv))? + .try_broadcast_like(&(batch, d_inner, d_conv))?; + let xs: Tensor<(Batch, DInner, DConv), _, _, _> = + cache.conv_state.retaped::().try_mul(conv1d)?; + let xs: Tensor<(Batch, DInner), _, _, _> = xs.try_sum::<_, Axis<2>>()?; + + // conv1d bias + let xs = self.conv1d_bias.try_forward(xs)?; + + // activation + xs.try_silu()? + }; + + let (ss, cache_ssm_state) = ss_step::( + // + self.a_log.retaped::(), + self.d.retaped::(), + xs, + &self.x_proj, + &self.dt_proj, + cache.ssm_state, + )?; + + let ys = ss.try_mul(res.try_silu()?)?; + let y: Tensor<(Batch, DModel), _, _, _> = self.out_proj.try_forward(ys)?; + + cache.ssm_state = cache_ssm_state; + + Ok((y, cache)) + } + } + + /// Runs the SSM. See: + /// - Algorithm 2 in Section 3.2 from the Mamba paper; + /// - run_SSM(A, B, C, u) from The Annotated S4. + pub fn ss_step< + // Batch size (`B` in Algorithm 2 from the Mamba paper). + Batch: Dim, + // latent state dimension (`N` in Algorithm 2 from the Mamba paper). + DState: Dim, + // Rank of Δ (See Section 3.6 "Parameterization of ∆" from the Mamba paper). + // Δ or delta: input-dependent step size. + DtRank: Dim, + // DModel * expand (`D` in Algorithm 2 from the Mamba paper). + DInner: Dim, + E: Dtype, + D: Device, + T: Tape, + >( + // + a: Tensor<(DInner, DState), E, D, T>, + d: Tensor<(DInner,), E, D, T>, + u: Tensor<(Batch, DInner), E, D, T>, + x_proj: &MatMul>::Output>>::Output, E, D>, + dt_proj: &Linear, + ssm_state_cache: Tensor<(Batch, DInner, DState), E, D, T>, + ) -> Result< + ( + Tensor<(Batch, DInner), E, D, T>, + Tensor<(Batch, DInner, DState), E, D, T>, + ), + dfdx::tensor::Error, + > + where + // used to truncate back to DtRank: DState * 2 + DState: Mul, + >::Output: Dim, + // used to truncate back to DtRank: DtRank + DState * 2 + DtRank: Add<>::Output>, + >::Output>>::Output: Dim + Default, + { + let device = u.device().clone(); + + let (_d_inner, dt_rank) = *dt_proj.weight.shape(); + let (batch, d_inner, d_state) = *ssm_state_cache.shape(); + + // Compute ∆ A B C D, the state space parameters. + + // A + // this is input independent (see Section 3.5.2 "Interpretation of A" form the Mamba paper for why A isn't selective) + let a: Tensor<(DInner, DState), _, _, _> = a.try_exp()?.try_negate()?; + + // (Batch, DtRank + DState * 2) + let x_dbl: Tensor<(Batch, _), _, _, _> = x_proj.try_forward(u.retaped::())?; + + // ∆ (part 1/2) + // ∆ is input-dependent + let (delta, x_dbl_tail, _tape): (Tensor<(Batch, DtRank), _, _, _>, _, _) = + x_dbl.try_split_tensor_along(Axis::<1>, dt_rank, d_state * Const::<2>)?; + + // B and C + // B and C are input-dependent + let (b, c, _tape): ( + Tensor<(Batch, DState), _, _, _>, + Tensor<(Batch, DState), _, _, _>, + _, + ) = x_dbl_tail.try_split_tensor_along(Axis::<1>, d_state, d_state)?; + + // ∆ (part 2/2) + // ∆ is input-dependent + let delta: Tensor<(Batch, DInner), _, _, _> = { + // note: don't add dt_proj bias + let delta = delta.try_matmul( + dt_proj + .weight + .retaped::() + .try_permute::<_, dfdx::prelude::Axes2<1, 0>>()?, + )?; + // softplus without threshold + // TODO: consider the threshold + let one = device.ones_like(&delta); + (delta + .try_add( + dt_proj + .bias + .retaped::() + .try_broadcast_like(&(batch, d_inner))?, + )? + .try_exp()? + .try_add(one)?) + .try_ln()? + }; + + selective_scan_step::(delta, a, b, c, d, u, ssm_state_cache) + } + + // Selective Scan. + /// + /// Does selective scan algorithm. See: + /// - Section 2 State Space Models from the Mamba paper; + /// - Algorithm 2 in Section 3.2 from the Mamba paper; + /// - run_SSM(A, B, C, u) from The Annotated S4. + /// + pub fn selective_scan_step< + // Batch size (`B` in Algorithm 2 from the Mamba paper). + Batch: Dim, + // latent state dimension (`N` in Algorithm 2 from the Mamba paper). + DState: Dim, + // DModel * expand (`D` in Algorithm 2 from the Mamba paper). + DInner: Dim, + E: Dtype, + D: Device, + T: Tape, + >( + delta: Tensor<(Batch, DInner), E, D, T>, + a: Tensor<(DInner, DState), E, D, T>, + b: Tensor<(Batch, DState), E, D, T>, + c: Tensor<(Batch, DState), E, D, T>, + d: Tensor<(DInner,), E, D, T>, + u: Tensor<(Batch, DInner), E, D, T>, + mut ssm_state_cache: Tensor<(Batch, DInner, DState), E, D, T>, + ) -> Result< + ( + Tensor<(Batch, DInner), E, D, T>, + Tensor<(Batch, DInner, DState), E, D, T>, + ), + dfdx::tensor::Error, + > { + let (batch, d_inner, d_state) = *ssm_state_cache.shape(); + + // Discretize continuous parameters (A, B) + // - A is discretized using zero-order hold (ZOH) discretization (see Section 2 Equation 4 in the Mamba paper) + // - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors: + // "A is the more important term and the performance doesn't change much with the simplification on B" + let (delta_a, delta_bu): ( + Tensor<(Batch, DInner, DState), _, _, _>, + Tensor<(Batch, DInner, DState), _, _, _>, + ) = { + let target_shape = (batch, d_inner, d_state); + + let delta_broadcasted = delta.try_broadcast_like(&target_shape)?; + + let a = a.try_broadcast_like(&target_shape)?; + let delta_a: Tensor<(Batch, DInner, DState), _, _, _> = + delta_broadcasted.retaped::().try_mul(a)?.try_exp()?; + + let b = b.try_broadcast_like(&target_shape)?; + let delta_bu = delta_broadcasted.try_mul(b)?; + + (delta_a, delta_bu) + }; + + ssm_state_cache = ssm_state_cache + .try_mul(delta_a.try_reshape_like(&(batch, d_inner, d_state))?)? + .try_add( + u.retaped::() + .try_reshape_like(&(batch, d_inner))? + .try_broadcast_like(&(batch, d_inner, d_state))? + .try_mul(delta_bu.try_reshape_like(&(batch, d_inner, d_state))?)?, + )?; + + let y = ssm_state_cache + .retaped::() + .try_matmul(c.try_reshape_like(&(batch, d_state, Const::<1>))?)?; + let du = d.try_broadcast_like(&(batch, d_inner))?.try_mul(u)?; + let y = y.try_reshape_like(&(batch, d_inner))?.try_add(du)?; + + Ok((y, ssm_state_cache)) + } +} diff --git a/dfdx/src/nn/layers/mod.rs b/dfdx/src/nn/layers/mod.rs index 062b9f08..21e3f7fb 100644 --- a/dfdx/src/nn/layers/mod.rs +++ b/dfdx/src/nn/layers/mod.rs @@ -25,6 +25,7 @@ mod leaky_relu; mod linear; mod ln; mod log_softmax; +mod mamba_minimal; mod matmul; mod multi_head_attention; #[cfg(feature = "nightly")] @@ -79,6 +80,10 @@ pub use leaky_relu::LeakyReLU; pub use linear::{Linear, LinearConfig, LinearConstConfig}; pub use ln::Ln; pub use log_softmax::LogSoftmax; +pub use mamba_minimal::{ + stateful::{MambaStateCache, MambaStateCacheConfig}, + MambaBlock, MambaBlockConfig, MambaBlockConstConfig, +}; pub use matmul::{MatMul, MatMulConfig, MatMulConstConfig}; pub use multi_head_attention::{MultiHeadAttention, MultiHeadAttentionConfig}; #[cfg(feature = "nightly")] From bff1b658aa3b91b0af57f037c7cc704d216d1f03 Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Fri, 9 Feb 2024 12:55:11 -0500 Subject: [PATCH 20/20] add nightly requirement for mamba-minimal --- dfdx/src/nn/layers/mod.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dfdx/src/nn/layers/mod.rs b/dfdx/src/nn/layers/mod.rs index 21e3f7fb..fb5c1382 100644 --- a/dfdx/src/nn/layers/mod.rs +++ b/dfdx/src/nn/layers/mod.rs @@ -25,6 +25,7 @@ mod leaky_relu; mod linear; mod ln; mod log_softmax; +#[cfg(feature = "nightly")] mod mamba_minimal; mod matmul; mod multi_head_attention; @@ -80,6 +81,7 @@ pub use leaky_relu::LeakyReLU; pub use linear::{Linear, LinearConfig, LinearConstConfig}; pub use ln::Ln; pub use log_softmax::LogSoftmax; +#[cfg(feature = "nightly")] pub use mamba_minimal::{ stateful::{MambaStateCache, MambaStateCacheConfig}, MambaBlock, MambaBlockConfig, MambaBlockConstConfig,