From 2fbc4628f3ddbd25e223084985c94d2c46dccb58 Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Tue, 25 Jun 2024 09:55:55 -0400 Subject: [PATCH] Feat/cube/array assign ops (#1914) --- .../src/codegen_function/base.rs | 8 ++ .../src/codegen_function/expr.rs | 11 +- .../src/codegen_function/operation.rs | 111 ++++++++++++++---- .../src/codegen_function/variable.rs | 14 ++- .../src/frontend/operation/assignation.rs | 84 +++++++++++++ .../burn-cube/src/frontend/operation/base.rs | 40 +++++++ crates/burn-cube/tests/frontend/array.rs | 84 +++++++++++++ crates/burn-cube/tests/frontend/mod.rs | 1 + 8 files changed, 323 insertions(+), 30 deletions(-) create mode 100644 crates/burn-cube/tests/frontend/array.rs diff --git a/crates/burn-cube-macros/src/codegen_function/base.rs b/crates/burn-cube-macros/src/codegen_function/base.rs index f6f82621fe..e0e25c8c1d 100644 --- a/crates/burn-cube-macros/src/codegen_function/base.rs +++ b/crates/burn-cube-macros/src/codegen_function/base.rs @@ -49,6 +49,12 @@ pub(crate) fn codegen_block( pub(crate) struct Codegen { pub tokens: proc_macro2::TokenStream, pub is_comptime: bool, + pub array_indexing: Option, +} + +pub(crate) struct ArrayIndexing { + pub array: proc_macro2::TokenStream, + pub index: proc_macro2::TokenStream, } impl From for Codegen { @@ -56,6 +62,7 @@ impl From for Codegen { Self { tokens, is_comptime: false, + array_indexing: None, } } } @@ -65,6 +72,7 @@ impl Codegen { Self { tokens: tokens.into(), is_comptime, + array_indexing: None, } } diff --git a/crates/burn-cube-macros/src/codegen_function/expr.rs b/crates/burn-cube-macros/src/codegen_function/expr.rs index a509642abf..2b46f1eb80 100644 --- a/crates/burn-cube-macros/src/codegen_function/expr.rs +++ b/crates/burn-cube-macros/src/codegen_function/expr.rs @@ -25,6 +25,7 @@ pub(crate) fn codegen_expr( syn::Expr::Call(call) => codegen_call(call, loop_level, variable_tracker), syn::Expr::Paren(paren) => codegen_expr(&paren.expr, loop_level, variable_tracker), _ => { + let mut array_indexing = None; let tokens = match expr { syn::Expr::Path(path) => { return codegen_path_var(path, loop_level, variable_tracker) @@ -50,7 +51,11 @@ pub(crate) fn codegen_expr( syn::Expr::MethodCall(call) => { codegen_expr_method_call(call, loop_level, variable_tracker) } - syn::Expr::Index(index) => codegen_index(index, loop_level, variable_tracker), + syn::Expr::Index(index) => { + let codegen = codegen_index(index, loop_level, variable_tracker); + array_indexing = codegen.array_indexing; + codegen.tokens + } syn::Expr::Array(array) => codegen_array_lit(array), syn::Expr::Reference(reference) => { codegen_ref(reference, loop_level, variable_tracker) @@ -67,7 +72,9 @@ pub(crate) fn codegen_expr( } }; - Codegen::new(tokens, false) + let mut codegen = Codegen::new(tokens, false); + codegen.array_indexing = array_indexing; + codegen } } } diff --git a/crates/burn-cube-macros/src/codegen_function/operation.rs b/crates/burn-cube-macros/src/codegen_function/operation.rs index 8a9da3e60d..fa6235c045 100644 --- a/crates/burn-cube-macros/src/codegen_function/operation.rs +++ b/crates/burn-cube-macros/src/codegen_function/operation.rs @@ -8,7 +8,8 @@ pub(crate) fn codegen_binary( loop_level: usize, variable_tracker: &mut VariableTracker, ) -> Codegen { - let (lhs, is_comptime_lhs) = codegen_expr(&binary.left, loop_level, variable_tracker).split(); + let lhs = codegen_expr(&binary.left, loop_level, variable_tracker); + let (lhs, is_comptime_lhs, lhs_array) = (lhs.tokens, lhs.is_comptime, lhs.array_indexing); let (rhs, is_comptime_rhs) = codegen_expr(&binary.right, loop_level, variable_tracker).split(); if is_comptime_lhs && is_comptime_rhs { @@ -99,34 +100,94 @@ pub(crate) fn codegen_binary( burn_cube::frontend::eq::expand(context, _lhs, _rhs) } }, - syn::BinOp::AddAssign(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::frontend::add_assign_op::expand(context, _lhs, _rhs) + syn::BinOp::AddAssign(_) => { + if let Some(array) = lhs_array { + let (array, index) = (array.array, array.index); + + quote::quote! { + { + let _array = #array; + let _index = #index; + let _value = #rhs; + burn_cube::frontend::add_assign_array_op::expand(context, _array, _index, _value) + } + } + } else { + quote::quote! { + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::frontend::add_assign_op::expand(context, _lhs, _rhs) + } + } } - }, - syn::BinOp::SubAssign(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::frontend::sub_assign_op::expand(context, _lhs, _rhs) + } + syn::BinOp::SubAssign(_) => { + if let Some(array) = lhs_array { + let (array, index) = (array.array, array.index); + + quote::quote! { + { + let _array = #array; + let _index = #index; + let _value = #rhs; + burn_cube::frontend::sub_assign_array_op::expand(context, _array, _index, _value) + } + } + } else { + quote::quote! { + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::frontend::sub_assign_op::expand(context, _lhs, _rhs) + } + } } - }, - syn::BinOp::MulAssign(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::frontend::mul_assign_op::expand(context, _lhs, _rhs) + } + syn::BinOp::MulAssign(_) => { + if let Some(array) = lhs_array { + let (array, index) = (array.array, array.index); + + quote::quote! { + { + let _array = #array; + let _index = #index; + let _value = #rhs; + burn_cube::frontend::mul_assign_array_op::expand(context, _array, _index, _value) + } + } + } else { + quote::quote! { + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::frontend::mul_assign_op::expand(context, _lhs, _rhs) + } + } } - }, - syn::BinOp::DivAssign(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::frontend::div_assign_op::expand(context, _lhs, _rhs) + } + syn::BinOp::DivAssign(_) => { + if let Some(array) = lhs_array { + let (array, index) = (array.array, array.index); + + quote::quote! { + { + let _array = #array; + let _index = #index; + let _value = #rhs; + burn_cube::frontend::div_assign_array_op::expand(context, _array, _index, _value) + } + } + } else { + quote::quote! { + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::frontend::div_assign_op::expand(context, _lhs, _rhs) + } + } } - }, + } syn::BinOp::And(_) => quote::quote! { { diff --git a/crates/burn-cube-macros/src/codegen_function/variable.rs b/crates/burn-cube-macros/src/codegen_function/variable.rs index fe8a15b402..2ec75d9c1b 100644 --- a/crates/burn-cube-macros/src/codegen_function/variable.rs +++ b/crates/burn-cube-macros/src/codegen_function/variable.rs @@ -99,17 +99,25 @@ pub(crate) fn codegen_index( index: &syn::ExprIndex, loop_level: usize, variable_tracker: &mut VariableTracker, -) -> TokenStream { +) -> Codegen { let array = codegen_expr(&index.expr, loop_level, variable_tracker); let index = codegen_expr(&index.index, loop_level, variable_tracker); - quote::quote! { + let tokens = quote::quote! { { let _array = #array; let _index = #index; burn_cube::frontend::index::expand(context, _array, _index) } - } + }; + + let mut codegen = Codegen::new(tokens, false); + codegen.array_indexing = Some(super::base::ArrayIndexing { + array: array.tokens, + index: index.tokens, + }); + + codegen } /// Codegen for assignation diff --git a/crates/burn-cube/src/frontend/operation/assignation.rs b/crates/burn-cube/src/frontend/operation/assignation.rs index 5b1922774d..86edb022ee 100644 --- a/crates/burn-cube/src/frontend/operation/assignation.rs +++ b/crates/burn-cube/src/frontend/operation/assignation.rs @@ -113,6 +113,90 @@ pub mod index { impl_index!(SharedMemory); } +pub mod add_assign_array_op { + use crate::prelude::array_assign_binary_op_expand; + + use self::ir::Operator; + + use super::*; + + pub fn expand< + Array: Into, + Index: Into, + Value: Into, + >( + context: &mut CubeContext, + array: Array, + index: Index, + value: Value, + ) { + array_assign_binary_op_expand(context, array, index, value, Operator::Add); + } +} + +pub mod sub_assign_array_op { + use crate::prelude::array_assign_binary_op_expand; + + use self::ir::Operator; + + use super::*; + + pub fn expand< + Array: Into, + Index: Into, + Value: Into, + >( + context: &mut CubeContext, + array: Array, + index: Index, + value: Value, + ) { + array_assign_binary_op_expand(context, array, index, value, Operator::Sub); + } +} + +pub mod mul_assign_array_op { + use crate::prelude::array_assign_binary_op_expand; + + use self::ir::Operator; + + use super::*; + + pub fn expand< + Array: Into, + Index: Into, + Value: Into, + >( + context: &mut CubeContext, + array: Array, + index: Index, + value: Value, + ) { + array_assign_binary_op_expand(context, array, index, value, Operator::Mul); + } +} + +pub mod div_assign_array_op { + use crate::prelude::array_assign_binary_op_expand; + + use self::ir::Operator; + + use super::*; + + pub fn expand< + Array: Into, + Index: Into, + Value: Into, + >( + context: &mut CubeContext, + array: Array, + index: Index, + value: Value, + ) { + array_assign_binary_op_expand(context, array, index, value, Operator::Div); + } +} + pub mod add_assign_op { use crate::frontend::{operation::base::assign_op_expand, BF16, F16, F32, F64, I32, I64}; diff --git a/crates/burn-cube/src/frontend/operation/base.rs b/crates/burn-cube/src/frontend/operation/base.rs index b7263db485..4d0c705486 100644 --- a/crates/burn-cube/src/frontend/operation/base.rs +++ b/crates/burn-cube/src/frontend/operation/base.rs @@ -203,3 +203,43 @@ fn check_vectorization(lhs: Vectorization, rhs: Vectorization) -> Vectorization output } + +pub fn array_assign_binary_op_expand< + Array: Into, + Index: Into, + Value: Into, + F: Fn(BinaryOperator) -> Operator, +>( + context: &mut CubeContext, + array: Array, + index: Index, + value: Value, + func: F, +) { + let array: ExpandElement = array.into(); + let index: ExpandElement = index.into(); + let value: ExpandElement = value.into(); + + let tmp = context.create_local(array.item()); + + let read = Operator::Index(BinaryOperator { + lhs: *array, + rhs: *index, + out: *tmp, + }); + let calculate = func(BinaryOperator { + lhs: *tmp, + rhs: *value, + out: *tmp, + }); + + let write = Operator::IndexAssign(BinaryOperator { + lhs: *index, + rhs: *tmp, + out: *array, + }); + + context.register(read); + context.register(calculate); + context.register(write); +} diff --git a/crates/burn-cube/tests/frontend/array.rs b/crates/burn-cube/tests/frontend/array.rs new file mode 100644 index 0000000000..44a7fca4fa --- /dev/null +++ b/crates/burn-cube/tests/frontend/array.rs @@ -0,0 +1,84 @@ +use burn_cube::prelude::*; + +#[cube] +fn array_add_assign_simple(mut array: Array) { + array[UInt::new(1)] += UInt::new(1); +} + +#[cube] +fn array_add_assign_expr(mut array: Array) { + array[UInt::new(1) + UInt::new(5)] += UInt::new(1); +} + +mod tests { + use super::*; + use burn_cube::{ + cpa, + ir::{Elem, Item, Variable}, + }; + + #[test] + fn array_add_assign() { + let mut context = CubeContext::root(); + let array = context.input(0, Item::new(Elem::UInt)); + + array_add_assign_simple_expand(&mut context, array); + let scope = context.into_scope(); + + assert_eq!( + format!("{:?}", scope.operations), + inline_macro_array_add_assign_simple() + ); + } + + #[test] + fn array_add_assign_expr() { + let mut context = CubeContext::root(); + let array = context.input(0, Item::new(Elem::UInt)); + + array_add_assign_expr_expand(&mut context, array); + let scope = context.into_scope(); + + assert_eq!( + format!("{:?}", scope.operations), + inline_macro_array_add_assign_expr() + ); + } + + fn inline_macro_array_add_assign_simple() -> String { + let context = CubeContext::root(); + + let mut scope = context.into_scope(); + let local = scope.create_local(Item::new(Elem::UInt)); + + let array = Variable::GlobalInputArray(0, Item::new(Elem::UInt)); + let index = Variable::ConstantScalar(1., Elem::UInt); + let value = Variable::ConstantScalar(1., Elem::UInt); + + cpa!(scope, local = array[index]); + cpa!(scope, local += value); + cpa!(scope, array[index] = local); + + format!("{:?}", scope.operations) + } + + fn inline_macro_array_add_assign_expr() -> String { + let context = CubeContext::root(); + + let mut scope = context.into_scope(); + let index = scope.create_local(Item::new(Elem::UInt)); + let local = scope.create_local(Item::new(Elem::UInt)); + + let array = Variable::GlobalInputArray(0, Item::new(Elem::UInt)); + let const1 = Variable::ConstantScalar(1., Elem::UInt); + let const2 = Variable::ConstantScalar(5., Elem::UInt); + let value = Variable::ConstantScalar(1., Elem::UInt); + + cpa!(scope, index = const1 + const2); + cpa!(scope, local = array[index]); + cpa!(scope, local += value); + cpa!(scope, array[index] = local); + + format!("{:?}", scope.operations) + } +} diff --git a/crates/burn-cube/tests/frontend/mod.rs b/crates/burn-cube/tests/frontend/mod.rs index c13c1300b3..f9b433e0ab 100644 --- a/crates/burn-cube/tests/frontend/mod.rs +++ b/crates/burn-cube/tests/frontend/mod.rs @@ -1,3 +1,4 @@ +mod array; mod assign; mod cast_elem; mod cast_kind;