From dd3c7f65517c56ea23de24a6ea83241f91139304 Mon Sep 17 00:00:00 2001
From: Jeffrey
Date: Mon, 17 Jul 2017 17:24:38 +0100
Subject: [PATCH 1/2] [OpenCL] Registers DepthwiseConv2dNative

---
 tensorflow/core/kernels/depthwise_conv_op.cc | 474 +++++++++++++++++++
 1 file changed, 474 insertions(+)

diff --git a/tensorflow/core/kernels/depthwise_conv_op.cc b/tensorflow/core/kernels/depthwise_conv_op.cc
index ccd33c08612b27..3eebb7d071b582 100644
--- a/tensorflow/core/kernels/depthwise_conv_op.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op.cc
@@ -53,6 +53,9 @@ namespace tensorflow {
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif  // TENSORFLOW_USE_SYCL
 
 template <typename Device, typename T>
 struct LaunchDepthwiseConvOp;
@@ -435,6 +438,477 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
   TF_DISALLOW_COPY_AND_ASSIGN(DepthwiseConv2dNativeOp);
 };
 
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
+          int kKnownDepthMultiplier>
+class DepthwiseConv2dSYCLKernelNHWC {
+  using write_accessor =
+      cl::sycl::accessor<uint8_t, 1, cl::sycl::access::mode::write,
+                         cl::sycl::access::target::global_buffer>;
+  using read_accessor =
+      cl::sycl::accessor<uint8_t, 1, cl::sycl::access::mode::read,
+                         cl::sycl::access::target::global_buffer>;
+public:
+  DepthwiseConv2dSYCLKernelNHWC(const DepthwiseArgs args,
+                                read_accessor input_data_accessor,
+                                read_accessor filter_data_accessor,
+                                write_accessor output_data_accessor)
+      : args_(args),
+        input_data_accessor_(input_data_accessor),
+        filter_data_accessor_(filter_data_accessor),
+        output_data_accessor_(output_data_accessor){}
+  void operator()(cl::sycl::item<1> item){
+    T* input_data = ConvertToActualTypeSycl(T, input_data_accessor_);
+    T* filter_data = ConvertToActualTypeSycl(T, filter_data_accessor_);
+    T* output_data = ConvertToActualTypeSycl(T, output_data_accessor_);
+
+    const int thread_id = item.get_linear_id();
+
+    const int filter_rows =
+        kKnownFilterHeight < 0 ? args_.filter_rows : kKnownFilterHeight;
+    const int filter_cols =
+        kKnownFilterWidth < 0 ? args_.filter_cols : kKnownFilterWidth;
+    const int depth_multiplier =
+        kKnownDepthMultiplier < 0 ? args_.depth_multiplier : kKnownDepthMultiplier;
+
+    // Compute the indexes of this thread in the output.
+    const int OD = thread_id % args_.out_depth;
+    const int OC = (thread_id / args_.out_depth) % args_.out_cols;
+    const int OR = (thread_id / args_.out_depth / args_.out_cols) % args_.out_rows;
+    const int OB = thread_id / args_.out_depth / args_.out_cols / args_.out_rows;
+    // Compute the input depth and the index of depth multiplier.
+    const int in_d = OD / depth_multiplier;
+    const int multiplier = OD % depth_multiplier;
+
+    // Decide if all input is valid, if yes, we can skip the boundary checks
+    // for each input.
+    const int input_row_start = OR * args_.stride - args_.pad_rows;
+    const int input_col_start = OC * args_.stride - args_.pad_cols;
+    const int input_row_end = input_row_start + filter_rows;
+    const int input_col_end = input_col_start + filter_cols;
+
+    T sum = T(0);
+
+    const int input_offset_temp = args_.in_rows * OB;
+    if (input_row_start >= 0 && input_col_start >= 0 &&
+        input_row_end < args_.in_rows && input_col_end < args_.in_cols) {
+      for (int f_r = 0; f_r < filter_rows; ++f_r) {
+        const int in_r = input_row_start + f_r;
+        const int filter_offset_temp = filter_cols * f_r;
+        for (int f_c = 0; f_c < filter_cols; ++f_c) {
+          const int in_c = input_col_start + f_c;
+
+          const int input_offset =
+              in_d + args_.in_depth * (in_c + args_.in_cols * (in_r + input_offset_temp));
+          const int filter_offset =
+              multiplier +
+              depth_multiplier * (in_d + args_.in_depth * (f_c + filter_offset_temp));
+          sum += input_data[input_offset] * filter_data[filter_offset];
+        }
+      }
+    } else {
+      for (int f_r = 0; f_r < filter_rows; ++f_r) {
+        const int in_r = input_row_start + f_r;
+        const int filter_offset_temp = filter_cols * f_r;
+        for (int f_c = 0; f_c < filter_cols; ++f_c) {
+          const int in_c = input_col_start + f_c;
+          if (in_r >= 0 && in_r < args_.in_rows && in_c >= 0 && in_c < args_.in_cols) {
+            const int in_c = input_col_start + f_c;
+
+            const int input_offset =
+                in_d + args_.in_depth * (in_c + args_.in_cols * (in_r + input_offset_temp));
+            const int filter_offset =
+                multiplier + depth_multiplier *
+                    (in_d + args_.in_depth * (f_c + filter_offset_temp));
+            sum += input_data[input_offset] * filter_data[filter_offset];
+          }
+        }
+      }
+    }
+    output_data[thread_id] = sum;
+  }
+private:
+  const DepthwiseArgs args_;
+  const read_accessor input_data_accessor_;
+  const read_accessor filter_data_accessor_;
+  write_accessor output_data_accessor_;
+};
+
+template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
+          int kKnownDepthMultiplier>
+class DepthwiseConv2dSYCLKernelNCHW {
+  using write_accessor =
+      cl::sycl::accessor<uint8_t, 1, cl::sycl::access::mode::write,
+                         cl::sycl::access::target::global_buffer>;
+  using read_accessor =
+      cl::sycl::accessor<uint8_t, 1, cl::sycl::access::mode::read,
+                         cl::sycl::access::target::global_buffer>;
+public:
+  DepthwiseConv2dSYCLKernelNCHW(const DepthwiseArgs args,
+                                read_accessor input_data_accessor,
+                                read_accessor filter_data_accessor,
+                                write_accessor output_data_accessor)
+      : args_(args),
+        input_data_accessor_(input_data_accessor),
+        filter_data_accessor_(filter_data_accessor),
+        output_data_accessor_(output_data_accessor){}
+  void operator()(cl::sycl::item<1> item){
+    T* input_data = ConvertToActualTypeSycl(T, input_data_accessor_);
+    T* filter_data = ConvertToActualTypeSycl(T, filter_data_accessor_);
+    T* output_data = ConvertToActualTypeSycl(T, output_data_accessor_);
+
+    const int thread_id = item.get_linear_id();
+
+    const int filter_rows =
+        kKnownFilterHeight < 0 ? args_.filter_rows : kKnownFilterHeight;
+    const int filter_cols =
+        kKnownFilterWidth < 0 ? args_.filter_cols : kKnownFilterWidth;
+    const int depth_multiplier =
+        kKnownDepthMultiplier < 0 ? args_.depth_multiplier : kKnownDepthMultiplier;
+    // Compute the indexes of this thread in the output.
+    //
+    // We want coalesced reads so we make sure that each warp reads
+    // a contiguous chunk of memory.
+    //
+    // THIS IS PROBABLY WRONG, we are not doing coalesced reads
+    // into the input, because of the depth multiplier division...
+    const int OC = thread_id % args_.out_cols;
+    const int OR = (thread_id / args_.out_cols) % args_.out_rows;
+    const int OD = (thread_id / args_.out_cols / args_.out_rows) % args_.out_depth;
+    const int OB = thread_id / args_.out_cols / args_.out_rows / args_.out_depth;
+
+    // Compute the input depth and the index of depth multiplier
+    // based off the output depth index that this thread is
+    // computing n.
+    const int in_d = OD / depth_multiplier;
+    const int multiplier = OD % depth_multiplier;
+
+    // Data is stored in the following format (let's assume we
+    // flatten the height and width into one contiguous dimension
+    // called "P".
+    //
+    // B1C1P1 B1C1P2 ..... B1C2P1 B1C2P2 ....
+    // B2C1P1 B2C1P2 ..... B2C2P1 B2C2P2 ....
+    //
+    // Each row contains in_depth * in_rows * in_cols values
+    // for each sample in the batch.
+    //
+    // We can further flatten it into:
+    //
+    // B1C1P1 B1C1P2 .....
+    // B1C2P1 B1C2P2 ....
+    // B2C1P1 B2C1P2 .....
+    // B2C2P1 B2C2P2 ....
+    //
+    // where each row is a contiguous array of all of the spatial
+    // pixels for a given batch and input depth. The following
+    // loop unrolls across the filter dimensions for a given thread,
+    // indexing into the filter value and the corresponding input
+    // patch.
+    //
+    // We can compute the index into the patch once right here.
+    const int input_offset_temp = (OB * args_.in_depth + in_d) * (args_.in_rows * args_.in_cols);
+
+    // Finally, we can iterate over the spatial dimensions and perform the
+    // convolution, writing into the output at the end.
+    //
+    // We perform an additional optimization, where we can determine
+    // whether the patch fits within the image indices statically, and
+    // avoid boundary checking within the loop.
+    const int input_row_start = OR * args_.stride - args_.pad_rows;
+    const int input_col_start = OC * args_.stride - args_.pad_cols;
+    const int input_row_end = input_row_start + filter_rows;
+    const int input_col_end = input_col_start + filter_cols;
+
+    T sum = T(0);
+    if (input_row_start >= 0 && input_col_start >= 0 &&
+        input_row_end < args_.in_rows && input_col_end < args_.in_cols) {
+      // Loop that doesn't need to check for boundary conditions.
+      for (int f_r = 0; f_r < filter_rows; ++f_r) {
+        const int in_r = input_row_start + f_r;
+        const int filter_offset_temp = filter_cols * f_r;
+        for (int f_c = 0; f_c < filter_cols; ++f_c) {
+          const int in_c = input_col_start + f_c;
+
+          const int input_offset =
+              (input_offset_temp) + (in_r * args_.in_cols) + in_c;
+          const int filter_offset =
+              multiplier +
+              depth_multiplier * (in_d + args_.in_depth * (f_c + filter_offset_temp));
+          sum += input_data[input_offset] * filter_data[filter_offset];
+        }
+      }
+    } else {
+      // Loop that needs to check for boundary conditions.
+      for (int f_r = 0; f_r < filter_rows; ++f_r) {
+        const int in_r = input_row_start + f_r;
+        const int filter_offset_temp = filter_cols * f_r;
+        for (int f_c = 0; f_c < filter_cols; ++f_c) {
+          const int in_c = input_col_start + f_c;
+          // TODO(vrv): the in_r check can be done outside of this loop;
+          // benchmark both methods to determine the better decision.
+          if (in_r >= 0 && in_r < args_.in_rows && in_c >= 0 && in_c < args_.in_cols) {
+            const int in_c = input_col_start + f_c;
+
+            // input_offset_temp indexes into the start of memory
+            // where the spatial data starts.
+            const int input_offset =
+                (input_offset_temp) + (in_r * args_.in_cols) + in_c;
+
+            const int filter_offset =
+                multiplier + depth_multiplier *
+                    (in_d + args_.in_depth * (f_c + filter_offset_temp));
+            sum += input_data[input_offset] * filter_data[filter_offset];
+          }
+        }
+      }
+    }
+    output_data[thread_id] = sum;
+  }
+private:
+  const DepthwiseArgs args_;
+  const read_accessor input_data_accessor_;
+  const read_accessor filter_data_accessor_;
+  write_accessor output_data_accessor_;
+};
+
+template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
+          int kKnownDepthMultiplier>
+void LaunchDepthwiseConv2dSYCL(const SYCLDevice& d, const DepthwiseArgs args,
+                               const Tensor& input, const Tensor& filter, Tensor* output,
+                               TensorFormat data_format) {
+  const int num_threads = output->NumElements();
+
+  auto input_data_buffer = d.get_sycl_buffer(input.template flat<T>().data());
+  auto filter_data_buffer = d.get_sycl_buffer(filter.template flat<T>().data());
+  auto output_data_buffer = d.get_sycl_buffer(output->template flat<T>().data());
+
+  d.sycl_queue().submit([&](cl::sycl::handler& cgh) {
+    auto input_data_access =
+        input_data_buffer
+            .template get_access<cl::sycl::access::mode::read>(cgh);
+    auto filter_data_access =
+        filter_data_buffer
+            .template get_access<cl::sycl::access::mode::read>(cgh);
+    auto output_data_access =
+        output_data_buffer
+            .template get_access<cl::sycl::access::mode::write>(cgh);
+
+    if(data_format == FORMAT_NHWC){
+      DepthwiseConv2dSYCLKernelNHWC<T, kKnownFilterWidth, kKnownFilterHeight,
+                                    kKnownDepthMultiplier> functor(
+          args, input_data_access, filter_data_access, output_data_access);
+      cgh.parallel_for(cl::sycl::range<1>(num_threads), functor);
+    } else if (data_format == FORMAT_NCHW) {
+      DepthwiseConv2dSYCLKernelNCHW<T, kKnownFilterWidth, kKnownFilterHeight,
+                                    kKnownDepthMultiplier> functor(
+          args, input_data_access, filter_data_access, output_data_access);
+      cgh.parallel_for(cl::sycl::range<1>(num_threads), functor);
+    } else {
+      assert(false && "Incorrect data format");
+      return;
+    }
+  });
+}
+
+template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
+void LaunchDepthwiseConv2dSYCL(const SYCLDevice& d, const DepthwiseArgs args,
+                               const Tensor& input, const Tensor& filter,
+                               Tensor* output, TensorFormat data_format) {
+  if (args.depth_multiplier == 1) {
+    LaunchDepthwiseConv2dSYCL<T, kKnownFilterWidth, kKnownFilterHeight, 1>(
+        d, args, input, filter, output, data_format);
+  } else {
+    LaunchDepthwiseConv2dSYCL<T, kKnownFilterWidth, kKnownFilterHeight, -1>(
+        d, args, input, filter, output, data_format);
+  }
+}
+
+template <typename T>
+struct LaunchDepthwiseConvOp<SYCLDevice, T> {
+  static void launch(OpKernelContext* ctx, const DepthwiseArgs args,
+                     const Tensor& input, const Tensor& filter, Tensor* output,
+                     TensorFormat data_format) {
+    const SYCLDevice& d = ctx->eigen_device<SYCLDevice>();
+    if (args.filter_rows == 3 && args.filter_cols == 3) {
+      LaunchDepthwiseConv2dSYCL<T, 3, 3>(d, args, input, filter, output,
+                                         data_format);
+    } else {
+      LaunchDepthwiseConv2dSYCL<T, -1, -1>(d, args, input, filter, output,
+                                           data_format);
+    }
+  }
+};
+
+// Extern template instantiated in conv_ops.cc.
+extern template class LaunchConv2DOp<SYCLDevice, float>;
+
+template <typename T>
+class DepthwiseConv2dNativeOp<SYCLDevice, T> : public BinaryOp<T> {
+ public:
+  explicit DepthwiseConv2dNativeOp(OpKernelConstruction* context)
+      : BinaryOp<T>(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
+    string data_format;
+    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+                errors::InvalidArgument("Invalid data format"));
+
+    OP_REQUIRES(context, strides_.size() == 4,
+                errors::InvalidArgument("Sliding window strides field must "
+                                        "specify 4 dimensions"));
+    stride_ = GetTensorDim(strides_, data_format_, 'H');
+    const int64 stride_w = GetTensorDim(strides_, data_format_, 'W');
+    const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
+    const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
+
+    OP_REQUIRES(context, stride_ == stride_w,
+                errors::InvalidArgument(
+                    "Current implementation only supports equal length "
+                    "strides in the row and column dimensions."));
+    OP_REQUIRES(
+        context, (stride_n == 1 && stride_c == 1),
+        errors::InvalidArgument("Current implementation does not yet support "
+                                "strides in the batch and depth dimensions."));
+    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+
+    // For special case when in_depth == 1.
+    use_cudnn_ = CanUseCudnn();
+    cudnn_use_autotune_ = CudnnUseAutotune();
+  }
+  void Compute(OpKernelContext* context) override {
+    // Input tensor is of the following dimensions:
+    // [ batch, in_rows, in_cols, in_depth ]
+    const Tensor& input = context->input(0);
+
+    // Input filter is of the following dimensions:
+    // [ filter_rows, filter_cols, in_depth, depth_multiplier]
+    const Tensor& filter = context->input(1);
+
+    // For 2D convolution, there should be 4 dimensions.
+    OP_REQUIRES(context, input.dims() == 4,
+                errors::InvalidArgument("input must be 4-dimensional",
+                                        input.shape().DebugString()));
+    OP_REQUIRES(context, filter.dims() == 4,
+                errors::InvalidArgument("filter must be 4-dimensional: ",
+                                        filter.shape().DebugString()));
+
+    // in_depth for input and filter must match.
+    const int64 in_depth = GetTensorDim(input, data_format_, 'C');
+    OP_REQUIRES(
+        context, in_depth == filter.dim_size(2),
+        errors::InvalidArgument("input and filter must have the same depth: ",
+                                in_depth, " vs ", filter.dim_size(2)));
+
+    // The last dimension for filter is depth multiplier.
+    const int32 depth_multiplier = filter.dim_size(3);
+
+    // The output depth is input depth x depth multiplier.
+    const int32 out_depth = in_depth * depth_multiplier;
+
+    const int64 input_rows_raw = GetTensorDim(input, data_format_, 'H');
+    OP_REQUIRES(
+        context,
+        FastBoundsCheck(input_rows_raw, std::numeric_limits<int32>::max()),
+        errors::InvalidArgument("Input rows too large"));
+    const int32 input_rows = static_cast<int32>(input_rows_raw);
+    const int32 filter_rows = filter.dim_size(0);
+
+    const int64 input_cols_raw = GetTensorDim(input, data_format_, 'W');
+    OP_REQUIRES(
+        context,
+        FastBoundsCheck(input_cols_raw, std::numeric_limits<int32>::max()),
+        errors::InvalidArgument("Input cols too large"));
+    const int32 input_cols = static_cast<int32>(input_cols_raw);
+    const int32 filter_cols = filter.dim_size(1);
+
+    // The first dimension for input is batch.
+    const int32 batch = input.dim_size(0);
+
+    int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
+    OP_REQUIRES_OK(context,
+                   GetWindowedOutputSize(input_rows, filter_rows, stride_,
+                                         padding_, &out_rows, &pad_rows));
+    OP_REQUIRES_OK(context,
+                   GetWindowedOutputSize(input_cols, filter_cols, stride_,
+                                         padding_, &out_cols, &pad_cols));
+    TensorShape out_shape =
+        ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth);
+    OP_REQUIRES(
+        context, out_shape.num_elements() <= 2147483647,
+        errors::InvalidArgument("total number of outputs should be within the "
+                                "range of int, which is used for indexing in "
+                                "the SYCL kernel"));
+
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
+
+    VLOG(2) << "DepthwiseConv2dNative: "
+            << " Input: [" << batch << ", " << input_rows << ", " << input_cols
+            << ", " << in_depth << "]; Filter: [" << filter_rows << ", "
+            << filter_cols << ", " << in_depth << ", " << depth_multiplier
+            << "]; stride = " << stride_ << ", pad_rows = " << pad_rows
+            << ", pad_cols = " << pad_cols << ", output: [" << batch << ", "
+            << out_rows << ", " << out_cols << ", " << out_depth << "]";
+
+    // If there is nothing to compute, return.
+    if (out_shape.num_elements() == 0) {
+      return;
+    }
+
+    // If in_depth==1, this operation is just a standard convolution, so
+    // invoke that op.
+    if (std::is_same<T, float>::value && in_depth == 1) {
+      launcher_.launch(context, use_cudnn_, cudnn_use_autotune_, input, filter,
+                       stride_, stride_, BrainPadding2EigenPadding(padding_),
+                       output, data_format_);
+      return;
+    }
+
+    DepthwiseArgs args;
+    args.batch = batch;
+    args.in_rows = input_rows;
+    args.in_cols = input_cols;
+    args.in_depth = in_depth;
+    args.filter_rows = filter_rows;
+    args.filter_cols = filter_cols;
+    args.depth_multiplier = depth_multiplier;
+    args.stride = stride_;
+    args.pad_rows = pad_rows;
+    args.pad_cols = pad_cols;
+    args.out_rows = out_rows;
+    args.out_cols = out_cols;
+    args.out_depth = out_depth;
+
+    LaunchDepthwiseConvOp<SYCLDevice, T>::launch(
+        context, args, input, filter, output, data_format_);
+  }
+
+ private:
+  std::vector<int32> strides_;
+  Padding padding_;
+  TensorFormat data_format_;
+
+  int64 stride_;  // in height/width dimension.
+
+  // For the case in_depth == 1.
+  LaunchConv2DOp<SYCLDevice, T> launcher_;
+  bool use_cudnn_;
+  bool cudnn_use_autotune_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(DepthwiseConv2dNativeOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNative")
+                            .Device(DEVICE_SYCL).TypeConstraint<float>("T"),
+                        DepthwiseConv2dNativeOp<SYCLDevice, float>);
+
+REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNative")
+                            .Device(DEVICE_SYCL)
+                            .TypeConstraint<double>("T"),
+                        DepthwiseConv2dNativeOp<SYCLDevice, double>);
+#endif  // TENSORFLOW_USE_SYCL
+
 #define REGISTER_CPU_KERNEL(T)                                                \
   REGISTER_KERNEL_BUILDER(                                                    \
       Name("DepthwiseConv2dNative").Device(DEVICE_CPU).TypeConstraint<T>("T"), \

From fbe58629ad9c5356844e0f06f5c78bade8d790af Mon Sep 17 00:00:00 2001
From: Jeffrey
Date: Wed, 19 Jul 2017 14:56:45 +0100
Subject: [PATCH 2/2] Removing SYCL kernel for NCHW

---
 tensorflow/core/kernels/depthwise_conv_op.cc | 143 -------------------
 1 file changed, 143 deletions(-)

diff --git a/tensorflow/core/kernels/depthwise_conv_op.cc b/tensorflow/core/kernels/depthwise_conv_op.cc
index 3eebb7d071b582..8c901a94e18e42 100644
--- a/tensorflow/core/kernels/depthwise_conv_op.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op.cc
@@ -534,144 +534,6 @@ class DepthwiseConv2dSYCLKernelNHWC {
   write_accessor output_data_accessor_;
 };
 
-template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
-          int kKnownDepthMultiplier>
-class DepthwiseConv2dSYCLKernelNCHW {
-  using write_accessor =
-      cl::sycl::accessor<uint8_t, 1, cl::sycl::access::mode::write,
-                         cl::sycl::access::target::global_buffer>;
-  using read_accessor =
-      cl::sycl::accessor<uint8_t, 1, cl::sycl::access::mode::read,
-                         cl::sycl::access::target::global_buffer>;
-public:
-  DepthwiseConv2dSYCLKernelNCHW(const DepthwiseArgs args,
-                                read_accessor input_data_accessor,
-                                read_accessor filter_data_accessor,
-                                write_accessor output_data_accessor)
-      : args_(args),
-        input_data_accessor_(input_data_accessor),
-        filter_data_accessor_(filter_data_accessor),
-        output_data_accessor_(output_data_accessor){}
-  void operator()(cl::sycl::item<1> item){
-    T* input_data = ConvertToActualTypeSycl(T, input_data_accessor_);
-    T* filter_data = ConvertToActualTypeSycl(T, filter_data_accessor_);
-    T* output_data = ConvertToActualTypeSycl(T, output_data_accessor_);
-
-    const int thread_id = item.get_linear_id();
-
-    const int filter_rows =
-        kKnownFilterHeight < 0 ? args_.filter_rows : kKnownFilterHeight;
-    const int filter_cols =
-        kKnownFilterWidth < 0 ? args_.filter_cols : kKnownFilterWidth;
-    const int depth_multiplier =
-        kKnownDepthMultiplier < 0 ? args_.depth_multiplier : kKnownDepthMultiplier;
-    // Compute the indexes of this thread in the output.
-    //
-    // We want coalesced reads so we make sure that each warp reads
-    // a contiguous chunk of memory.
-    //
-    // THIS IS PROBABLY WRONG, we are not doing coalesced reads
-    // into the input, because of the depth multiplier division...
-    const int OC = thread_id % args_.out_cols;
-    const int OR = (thread_id / args_.out_cols) % args_.out_rows;
-    const int OD = (thread_id / args_.out_cols / args_.out_rows) % args_.out_depth;
-    const int OB = thread_id / args_.out_cols / args_.out_rows / args_.out_depth;
-
-    // Compute the input depth and the index of depth multiplier
-    // based off the output depth index that this thread is
-    // computing n.
-    const int in_d = OD / depth_multiplier;
-    const int multiplier = OD % depth_multiplier;
-
-    // Data is stored in the following format (let's assume we
-    // flatten the height and width into one contiguous dimension
-    // called "P".
-    //
-    // B1C1P1 B1C1P2 ..... B1C2P1 B1C2P2 ....
-    // B2C1P1 B2C1P2 ..... B2C2P1 B2C2P2 ....
-    //
-    // Each row contains in_depth * in_rows * in_cols values
-    // for each sample in the batch.
-    //
-    // We can further flatten it into:
-    //
-    // B1C1P1 B1C1P2 .....
-    // B1C2P1 B1C2P2 ....
-    // B2C1P1 B2C1P2 .....
-    // B2C2P1 B2C2P2 ....
-    //
-    // where each row is a contiguous array of all of the spatial
-    // pixels for a given batch and input depth. The following
-    // loop unrolls across the filter dimensions for a given thread,
-    // indexing into the filter value and the corresponding input
-    // patch.
-    //
-    // We can compute the index into the patch once right here.
-    const int input_offset_temp = (OB * args_.in_depth + in_d) * (args_.in_rows * args_.in_cols);
-
-    // Finally, we can iterate over the spatial dimensions and perform the
-    // convolution, writing into the output at the end.
-    //
-    // We perform an additional optimization, where we can determine
-    // whether the patch fits within the image indices statically, and
-    // avoid boundary checking within the loop.
-    const int input_row_start = OR * args_.stride - args_.pad_rows;
-    const int input_col_start = OC * args_.stride - args_.pad_cols;
-    const int input_row_end = input_row_start + filter_rows;
-    const int input_col_end = input_col_start + filter_cols;
-
-    T sum = T(0);
-    if (input_row_start >= 0 && input_col_start >= 0 &&
-        input_row_end < args_.in_rows && input_col_end < args_.in_cols) {
-      // Loop that doesn't need to check for boundary conditions.
-      for (int f_r = 0; f_r < filter_rows; ++f_r) {
-        const int in_r = input_row_start + f_r;
-        const int filter_offset_temp = filter_cols * f_r;
-        for (int f_c = 0; f_c < filter_cols; ++f_c) {
-          const int in_c = input_col_start + f_c;
-
-          const int input_offset =
-              (input_offset_temp) + (in_r * args_.in_cols) + in_c;
-          const int filter_offset =
-              multiplier +
-              depth_multiplier * (in_d + args_.in_depth * (f_c + filter_offset_temp));
-          sum += input_data[input_offset] * filter_data[filter_offset];
-        }
-      }
-    } else {
-      // Loop that needs to check for boundary conditions.
-      for (int f_r = 0; f_r < filter_rows; ++f_r) {
-        const int in_r = input_row_start + f_r;
-        const int filter_offset_temp = filter_cols * f_r;
-        for (int f_c = 0; f_c < filter_cols; ++f_c) {
-          const int in_c = input_col_start + f_c;
-          // TODO(vrv): the in_r check can be done outside of this loop;
-          // benchmark both methods to determine the better decision.
-          if (in_r >= 0 && in_r < args_.in_rows && in_c >= 0 && in_c < args_.in_cols) {
-            const int in_c = input_col_start + f_c;
-
-            // input_offset_temp indexes into the start of memory
-            // where the spatial data starts.
-            const int input_offset =
-                (input_offset_temp) + (in_r * args_.in_cols) + in_c;
-
-            const int filter_offset =
-                multiplier + depth_multiplier *
-                    (in_d + args_.in_depth * (f_c + filter_offset_temp));
-            sum += input_data[input_offset] * filter_data[filter_offset];
-          }
-        }
-      }
-    }
-    output_data[thread_id] = sum;
-  }
-private:
-  const DepthwiseArgs args_;
-  const read_accessor input_data_accessor_;
-  const read_accessor filter_data_accessor_;
-  write_accessor output_data_accessor_;
-};
-
 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
           int kKnownDepthMultiplier>
 void LaunchDepthwiseConv2dSYCL(const SYCLDevice& d, const DepthwiseArgs args,
@@ -699,11 +561,6 @@ void LaunchDepthwiseConv2dSYCL(const SYCLDevice& d, const DepthwiseArgs args,
                                     kKnownDepthMultiplier> functor(
           args, input_data_access, filter_data_access, output_data_access);
       cgh.parallel_for(cl::sycl::range<1>(num_threads), functor);
-    } else if (data_format == FORMAT_NCHW) {
-      DepthwiseConv2dSYCLKernelNCHW<T, kKnownFilterWidth, kKnownFilterHeight,
-                                    kKnownDepthMultiplier> functor(
-          args, input_data_access, filter_data_access, output_data_access);
-      cgh.parallel_for(cl::sycl::range<1>(num_threads), functor);
-    } else {
-      assert(false && "Incorrect data format");
-      return;
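Reviewer note, not part of the patch: the NHWC kernel kept by this series maps each flat SYCL work-item id to one output element, with depth varying fastest, then columns, rows, and batch. The standalone C++ sketch below uses hypothetical output dimensions and has no TensorFlow or SYCL dependencies; it only illustrates that this index decomposition round-trips, and is not code from the patch.

#include <cassert>
#include <cstdio>

int main() {
  // Hypothetical output shape used only for illustration.
  const int batch = 2, out_rows = 3, out_cols = 4, out_depth = 6;
  const int num_threads = batch * out_rows * out_cols * out_depth;

  for (int thread_id = 0; thread_id < num_threads; ++thread_id) {
    // Same arithmetic as DepthwiseConv2dSYCLKernelNHWC: depth fastest (NHWC).
    const int OD = thread_id % out_depth;
    const int OC = (thread_id / out_depth) % out_cols;
    const int OR = (thread_id / out_depth / out_cols) % out_rows;
    const int OB = thread_id / out_depth / out_cols / out_rows;

    // Re-flattening the coordinates must recover the original work-item id.
    const int flat = ((OB * out_rows + OR) * out_cols + OC) * out_depth + OD;
    assert(flat == thread_id);
  }
  std::printf("NHWC index decomposition round-trips for %d outputs\n",
              num_threads);
  return 0;
}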