diff --git a/crates/burn-jit/src/kernel/index/mod.rs b/crates/burn-jit/src/kernel/index/mod.rs index be1b814b70..8ac9ac8298 100644 --- a/crates/burn-jit/src/kernel/index/mod.rs +++ b/crates/burn-jit/src/kernel/index/mod.rs @@ -9,7 +9,7 @@ mod slice_assign; pub use flip::*; pub use repeat_dim::*; -pub use select::*; +pub(crate) use select::*; pub(crate) use select_assign::*; pub use slice::*; pub use slice_assign::*; diff --git a/crates/burn-jit/src/kernel/index/select.rs b/crates/burn-jit/src/kernel/index/select.rs index 979adafe31..2746b1c1f5 100644 --- a/crates/burn-jit/src/kernel/index/select.rs +++ b/crates/burn-jit/src/kernel/index/select.rs @@ -1,118 +1,33 @@ use crate::{ element::JitElement, kernel::Kernel, ops::numeric::empty_device, tensor::JitTensor, JitRuntime, }; -use cubecl::{ - cpa, - frontend::TensorHandleRef, - ir::{Elem, IntKind, Item, KernelDefinition, Scope, Variable, Visibility}, - CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, - OutputInfo, -}; -use std::marker::PhantomData; - -#[derive(new)] -struct SelectEagerKernel { - dim: usize, - _runtime: PhantomData, - _elem: PhantomData, -} - -pub struct SelectComputeShader { - input: Variable, - indices: Variable, - output: Variable, - dim: usize, -} - -impl SelectComputeShader { - pub fn expand(self, scope: &mut Scope) { - let input = self.input; - let indices = self.indices; - let output = self.output; - let id = Variable::AbsolutePos; - let offset_input = scope.zero(Elem::UInt); - - cpa!( - scope, - range(0u32, Variable::Rank).for_each(|i, scope| { - let stride_input = scope.create_local(Elem::UInt); - let stride_output = scope.create_local(Elem::UInt); - let shape_output = scope.create_local(Elem::UInt); - - cpa!(scope, stride_input = stride(input, i)); - cpa!(scope, stride_output = stride(output, i)); - cpa!(scope, shape_output = shape(output, i)); - - let offset_local = scope.create_local(Elem::UInt); - cpa!(scope, offset_local = id / stride_output); - cpa!(scope, offset_local = offset_local % shape_output); - - let dim_index = scope.create_local(Elem::Bool); - cpa!(scope, dim_index = i == self.dim); - - cpa!(scope, if(dim_index).then(|scope| { - cpa!(scope, offset_local = indices[offset_local]); - cpa!(scope, offset_local = offset_local * stride_input); - }).else(|scope| { - cpa!(scope, offset_local = offset_local * stride_input); - })); - - cpa!(scope, offset_input += offset_local); - }) - ); - - let value = scope.create_local(input.item()); - cpa!(scope, value = input[offset_input]); - cpa!(scope, output[id] = value); +use cubecl::prelude::*; +use cubecl::{calculate_cube_count_elemwise, CubeDim}; + +#[cube(launch_unchecked)] +fn select_kernel( + input: &Tensor, + indices: &Tensor, + output: &mut Tensor, + dim: UInt, +) { + if ABSOLUTE_POS >= output.len() { + return; } -} -impl Kernel for SelectEagerKernel { - fn define(&self) -> KernelDefinition { - let mut scope = Scope::root(); - let item = E::cube_elem().into(); - let item_indices: Item = Elem::Int(IntKind::I32).into(); + let mut offset_input = UInt::new(0); - let input = Variable::GlobalInputArray { id: 0, item }; - let indices = Variable::GlobalInputArray { - id: 1, - item: item_indices, - }; - let output = Variable::GlobalOutputArray { id: 0, item }; + for i in range(0u32, output.rank(), Comptime::new(false)) { + let mut offset_local = ABSOLUTE_POS / output.stride(i) % output.shape(i); - scope.write_global_custom(output); - - SelectComputeShader { - input, - indices, - output, - dim: self.dim, + if i == dim { + offset_local = UInt::cast_from(indices[offset_local]); } - .expand(&mut scope); - - let input = InputInfo::Array { - item, - visibility: Visibility::Read, - }; - let indices = InputInfo::Array { - item: item_indices, - visibility: Visibility::Read, - }; - let output = OutputInfo::Array { item }; - - let info = KernelExpansion { - inputs: vec![input, indices], - outputs: vec![output], - scope, - }; - let settings = KernelSettings::default(); - KernelIntegrator::new(info).integrate(settings) + offset_input += offset_local * input.stride(i); } - fn id(&self) -> cubecl::KernelId { - cubecl::KernelId::new::().info(self.dim) - } + output[ABSOLUTE_POS] = input[offset_input]; } pub(crate) fn select( @@ -122,26 +37,25 @@ pub(crate) fn select JitTensor { let mut shape_output = tensor.shape.clone(); shape_output.dims[dim] = indices.shape.dims[0]; + let total_elem = shape_output.num_elements(); let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output); - let kernel = SelectEagerKernel::::new(dim); - - let num_elems = indices.shape.dims[0]; - let mut shapes = [1; D]; - let mut strides = [num_elems; D]; - shapes[D - 1] = num_elems; - strides[D - 1] = 1; - Execution::start(kernel, tensor.client.clone()) - .inputs(&[ - tensor.as_handle_ref(), - // This is a current hacks because the info buffer that contains the strides and shapes is - // hardcoded to only contains information about tensors of the same rank. However, since - // we don't rely on the shape and stride of the indices tensors, it doesn't matter - // which value we put, it just needs to be of the same rank. - unsafe { TensorHandleRef::from_raw_parts(&indices.handle, &strides, &shapes) }, - ]) - .outputs(&[output.as_handle_ref()]) - .execute(CubeCountSettings::Output { pos: 0 }); + let dummy_array = [1; D]; + let cube_dim = CubeDim::default(); + let cube_count = calculate_cube_count_elemwise(total_elem, cube_dim); + + unsafe { + select_kernel::launch_unchecked::( + &tensor.client, + cube_count, + cube_dim, + tensor.as_tensor_arg(1), + // Ignore shape and stride + TensorArg::from_raw_parts(&indices.handle, &dummy_array, &dummy_array, 1), + output.as_tensor_arg(1), + ScalarArg::new(dim as u32), + ) + }; output }