Skip to content

Commit

Permalink
Feat/cube/array assign ops (#1914)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Jun 25, 2024
1 parent 0f8dd57 commit 2fbc462
Show file tree
Hide file tree
Showing 8 changed files with 323 additions and 30 deletions.
8 changes: 8 additions & 0 deletions crates/burn-cube-macros/src/codegen_function/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,20 @@ pub(crate) fn codegen_block(
pub(crate) struct Codegen {
pub tokens: proc_macro2::TokenStream,
pub is_comptime: bool,
pub array_indexing: Option<ArrayIndexing>,
}

pub(crate) struct ArrayIndexing {
pub array: proc_macro2::TokenStream,
pub index: proc_macro2::TokenStream,
}

impl From<proc_macro2::TokenStream> for Codegen {
fn from(tokens: proc_macro2::TokenStream) -> Self {
Self {
tokens,
is_comptime: false,
array_indexing: None,
}
}
}
Expand All @@ -65,6 +72,7 @@ impl Codegen {
Self {
tokens: tokens.into(),
is_comptime,
array_indexing: None,
}
}

Expand Down
11 changes: 9 additions & 2 deletions crates/burn-cube-macros/src/codegen_function/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub(crate) fn codegen_expr(
syn::Expr::Call(call) => codegen_call(call, loop_level, variable_tracker),
syn::Expr::Paren(paren) => codegen_expr(&paren.expr, loop_level, variable_tracker),
_ => {
let mut array_indexing = None;
let tokens = match expr {
syn::Expr::Path(path) => {
return codegen_path_var(path, loop_level, variable_tracker)
Expand All @@ -50,7 +51,11 @@ pub(crate) fn codegen_expr(
syn::Expr::MethodCall(call) => {
codegen_expr_method_call(call, loop_level, variable_tracker)
}
syn::Expr::Index(index) => codegen_index(index, loop_level, variable_tracker),
syn::Expr::Index(index) => {
let codegen = codegen_index(index, loop_level, variable_tracker);
array_indexing = codegen.array_indexing;
codegen.tokens
}
syn::Expr::Array(array) => codegen_array_lit(array),
syn::Expr::Reference(reference) => {
codegen_ref(reference, loop_level, variable_tracker)
Expand All @@ -67,7 +72,9 @@ pub(crate) fn codegen_expr(
}
};

Codegen::new(tokens, false)
let mut codegen = Codegen::new(tokens, false);
codegen.array_indexing = array_indexing;
codegen
}
}
}
Expand Down
111 changes: 86 additions & 25 deletions crates/burn-cube-macros/src/codegen_function/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ pub(crate) fn codegen_binary(
loop_level: usize,
variable_tracker: &mut VariableTracker,
) -> Codegen {
let (lhs, is_comptime_lhs) = codegen_expr(&binary.left, loop_level, variable_tracker).split();
let lhs = codegen_expr(&binary.left, loop_level, variable_tracker);
let (lhs, is_comptime_lhs, lhs_array) = (lhs.tokens, lhs.is_comptime, lhs.array_indexing);
let (rhs, is_comptime_rhs) = codegen_expr(&binary.right, loop_level, variable_tracker).split();

if is_comptime_lhs && is_comptime_rhs {
Expand Down Expand Up @@ -99,34 +100,94 @@ pub(crate) fn codegen_binary(
burn_cube::frontend::eq::expand(context, _lhs, _rhs)
}
},
syn::BinOp::AddAssign(_) => quote::quote! {
{
let _lhs = #lhs;
let _rhs = #rhs;
burn_cube::frontend::add_assign_op::expand(context, _lhs, _rhs)
syn::BinOp::AddAssign(_) => {
if let Some(array) = lhs_array {
let (array, index) = (array.array, array.index);

quote::quote! {
{
let _array = #array;
let _index = #index;
let _value = #rhs;
burn_cube::frontend::add_assign_array_op::expand(context, _array, _index, _value)
}
}
} else {
quote::quote! {
{
let _lhs = #lhs;
let _rhs = #rhs;
burn_cube::frontend::add_assign_op::expand(context, _lhs, _rhs)
}
}
}
},
syn::BinOp::SubAssign(_) => quote::quote! {
{
let _lhs = #lhs;
let _rhs = #rhs;
burn_cube::frontend::sub_assign_op::expand(context, _lhs, _rhs)
}
syn::BinOp::SubAssign(_) => {
if let Some(array) = lhs_array {
let (array, index) = (array.array, array.index);

quote::quote! {
{
let _array = #array;
let _index = #index;
let _value = #rhs;
burn_cube::frontend::sub_assign_array_op::expand(context, _array, _index, _value)
}
}
} else {
quote::quote! {
{
let _lhs = #lhs;
let _rhs = #rhs;
burn_cube::frontend::sub_assign_op::expand(context, _lhs, _rhs)
}
}
}
},
syn::BinOp::MulAssign(_) => quote::quote! {
{
let _lhs = #lhs;
let _rhs = #rhs;
burn_cube::frontend::mul_assign_op::expand(context, _lhs, _rhs)
}
syn::BinOp::MulAssign(_) => {
if let Some(array) = lhs_array {
let (array, index) = (array.array, array.index);

quote::quote! {
{
let _array = #array;
let _index = #index;
let _value = #rhs;
burn_cube::frontend::mul_assign_array_op::expand(context, _array, _index, _value)
}
}
} else {
quote::quote! {
{
let _lhs = #lhs;
let _rhs = #rhs;
burn_cube::frontend::mul_assign_op::expand(context, _lhs, _rhs)
}
}
}
},
syn::BinOp::DivAssign(_) => quote::quote! {
{
let _lhs = #lhs;
let _rhs = #rhs;
burn_cube::frontend::div_assign_op::expand(context, _lhs, _rhs)
}
syn::BinOp::DivAssign(_) => {
if let Some(array) = lhs_array {
let (array, index) = (array.array, array.index);

quote::quote! {
{
let _array = #array;
let _index = #index;
let _value = #rhs;
burn_cube::frontend::div_assign_array_op::expand(context, _array, _index, _value)
}
}
} else {
quote::quote! {
{
let _lhs = #lhs;
let _rhs = #rhs;
burn_cube::frontend::div_assign_op::expand(context, _lhs, _rhs)
}
}
}
},
}
syn::BinOp::And(_) => quote::quote! {
{

Expand Down
14 changes: 11 additions & 3 deletions crates/burn-cube-macros/src/codegen_function/variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,17 +99,25 @@ pub(crate) fn codegen_index(
index: &syn::ExprIndex,
loop_level: usize,
variable_tracker: &mut VariableTracker,
) -> TokenStream {
) -> Codegen {
let array = codegen_expr(&index.expr, loop_level, variable_tracker);
let index = codegen_expr(&index.index, loop_level, variable_tracker);

quote::quote! {
let tokens = quote::quote! {
{
let _array = #array;
let _index = #index;
burn_cube::frontend::index::expand(context, _array, _index)
}
}
};

let mut codegen = Codegen::new(tokens, false);
codegen.array_indexing = Some(super::base::ArrayIndexing {
array: array.tokens,
index: index.tokens,
});

codegen
}

/// Codegen for assignation
Expand Down
84 changes: 84 additions & 0 deletions crates/burn-cube/src/frontend/operation/assignation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,90 @@ pub mod index {
impl_index!(SharedMemory);
}

pub mod add_assign_array_op {
use crate::prelude::array_assign_binary_op_expand;

use self::ir::Operator;

use super::*;

pub fn expand<
Array: Into<ExpandElement>,
Index: Into<ExpandElement>,
Value: Into<ExpandElement>,
>(
context: &mut CubeContext,
array: Array,
index: Index,
value: Value,
) {
array_assign_binary_op_expand(context, array, index, value, Operator::Add);
}
}

pub mod sub_assign_array_op {
use crate::prelude::array_assign_binary_op_expand;

use self::ir::Operator;

use super::*;

pub fn expand<
Array: Into<ExpandElement>,
Index: Into<ExpandElement>,
Value: Into<ExpandElement>,
>(
context: &mut CubeContext,
array: Array,
index: Index,
value: Value,
) {
array_assign_binary_op_expand(context, array, index, value, Operator::Sub);
}
}

pub mod mul_assign_array_op {
use crate::prelude::array_assign_binary_op_expand;

use self::ir::Operator;

use super::*;

pub fn expand<
Array: Into<ExpandElement>,
Index: Into<ExpandElement>,
Value: Into<ExpandElement>,
>(
context: &mut CubeContext,
array: Array,
index: Index,
value: Value,
) {
array_assign_binary_op_expand(context, array, index, value, Operator::Mul);
}
}

pub mod div_assign_array_op {
use crate::prelude::array_assign_binary_op_expand;

use self::ir::Operator;

use super::*;

pub fn expand<
Array: Into<ExpandElement>,
Index: Into<ExpandElement>,
Value: Into<ExpandElement>,
>(
context: &mut CubeContext,
array: Array,
index: Index,
value: Value,
) {
array_assign_binary_op_expand(context, array, index, value, Operator::Div);
}
}

pub mod add_assign_op {
use crate::frontend::{operation::base::assign_op_expand, BF16, F16, F32, F64, I32, I64};

Expand Down
40 changes: 40 additions & 0 deletions crates/burn-cube/src/frontend/operation/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,43 @@ fn check_vectorization(lhs: Vectorization, rhs: Vectorization) -> Vectorization

output
}

pub fn array_assign_binary_op_expand<
Array: Into<ExpandElement>,
Index: Into<ExpandElement>,
Value: Into<ExpandElement>,
F: Fn(BinaryOperator) -> Operator,
>(
context: &mut CubeContext,
array: Array,
index: Index,
value: Value,
func: F,
) {
let array: ExpandElement = array.into();
let index: ExpandElement = index.into();
let value: ExpandElement = value.into();

let tmp = context.create_local(array.item());

let read = Operator::Index(BinaryOperator {
lhs: *array,
rhs: *index,
out: *tmp,
});
let calculate = func(BinaryOperator {
lhs: *tmp,
rhs: *value,
out: *tmp,
});

let write = Operator::IndexAssign(BinaryOperator {
lhs: *index,
rhs: *tmp,
out: *array,
});

context.register(read);
context.register(calculate);
context.register(write);
}
Loading

0 comments on commit 2fbc462

Please sign in to comment.