diff --git a/tensorflow/core/kernels/depthwise_conv_grad_op.cc b/tensorflow/core/kernels/depthwise_conv_grad_op.cc
index 00d7f5640829c3..4af37443cbb0a5 100644
--- a/tensorflow/core/kernels/depthwise_conv_grad_op.cc
+++ b/tensorflow/core/kernels/depthwise_conv_grad_op.cc
@@ -45,6 +45,9 @@ namespace tensorflow {
 typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif  // TENSORFLOW_USE_SYCL
 
 // Common code between the two backward pass kernels: verifies that the
 // dimensions all match and extract the padded rows and columns.
 
@@ -612,6 +615,239 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel {
   TF_DISALLOW_COPY_AND_ASSIGN(DepthwiseConv2dNativeBackpropInputOp);
 };
 
+#ifdef TENSORFLOW_USE_SYCL
+// SYCL kernel computing the gradient of a depthwise convolution with
+// respect to its input, for the NHWC data layout. One work-item computes
+// one element of the in_backprop tensor by accumulating, over the output
+// positions whose receptive field covers the input element, the product
+// of out_backprop and the (reversed) filter.
+//
+// kKnownFilterWidth / kKnownFilterHeight / kKnownDepthMultiplier may be
+// fixed at compile time for better code generation; a value of -1 means
+// "read it from args at runtime".
+template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
+          int kKnownDepthMultiplier>
+class DepthwiseConv2dBackpropInputSYCLKernelNHWC {
+ public:
+  using write_accessor =
+      cl::sycl::accessor<uint8_t, 1, cl::sycl::access::mode::write,
+                         cl::sycl::access::target::global_buffer>;
+  using read_accessor =
+      cl::sycl::accessor<uint8_t, 1, cl::sycl::access::mode::read,
+                         cl::sycl::access::target::global_buffer>;
+
+  DepthwiseConv2dBackpropInputSYCLKernelNHWC(
+      const DepthwiseArgs args, read_accessor out_backprop_accessor,
+      read_accessor filter_accessor, write_accessor in_backprop_accessor)
+      : args_(args),
+        out_backprop_accessor_(out_backprop_accessor),
+        filter_accessor_(filter_accessor),
+        in_backprop_accessor_(in_backprop_accessor) {}
+
+  void operator()(cl::sycl::item<1> item) {
+    // The accessors are raw byte buffers; reinterpret them as T.
+    T* out_backprop_data = ConvertToActualTypeSycl(T, out_backprop_accessor_);
+    T* filter_data = ConvertToActualTypeSycl(T, filter_accessor_);
+    T* in_backprop_data = ConvertToActualTypeSycl(T, in_backprop_accessor_);
+
+    const int thread_id = item.get_linear_id();
+
+    // Use the compile-time constants when available (see class comment).
+    const int filter_rows =
+        kKnownFilterHeight < 0 ? args_.filter_rows : kKnownFilterHeight;
+    const int filter_cols =
+        kKnownFilterWidth < 0 ? args_.filter_cols : kKnownFilterWidth;
+    const int depth_multiplier = kKnownDepthMultiplier < 0
+                                     ? args_.depth_multiplier
+                                     : kKnownDepthMultiplier;
+
+    // Decompose the linear thread id into NHWC coordinates of the
+    // in_backprop element this work-item produces (depth is innermost).
+    const int in_d = thread_id % args_.in_depth;
+    const int in_c = (thread_id / args_.in_depth) % args_.in_cols;
+    const int in_r =
+        (thread_id / args_.in_depth / args_.in_cols) % args_.in_rows;
+    const int b = thread_id / args_.in_depth / args_.in_cols / args_.in_rows;
+
+    T sum = T(0);
+
+    // Range of output rows/cols whose filter window touches (in_r, in_c).
+    const int out_r_start = std::max(
+        0, (in_r - filter_rows + args_.pad_rows + args_.stride) / args_.stride);
+    const int out_r_end =
+        std::min(args_.out_rows - 1, (in_r + args_.pad_rows) / args_.stride);
+    const int out_c_start = std::max(
+        0, (in_c - filter_cols + args_.pad_cols + args_.stride) / args_.stride);
+    const int out_c_end =
+        std::min(args_.out_cols - 1, (in_c + args_.pad_cols) / args_.stride);
+
+    for (int out_r = out_r_start; out_r <= out_r_end; ++out_r) {
+      // f_r is the filter row that maps out_r back onto in_r.
+      const int f_r = in_r + args_.pad_rows - out_r * args_.stride;
+      const int temp_out_backprop_offset =
+          args_.out_depth * args_.out_cols * (out_r + args_.out_rows * b);
+      const int temp_filter_offset = filter_cols * f_r;
+      for (int out_c = out_c_start; out_c <= out_c_end; ++out_c) {
+        const int f_c = in_c + args_.pad_cols - out_c * args_.stride;
+        int filter_offset = depth_multiplier *
+                            (in_d + args_.in_depth * (f_c + temp_filter_offset));
+        const int out_backprop_offset =
+            args_.out_depth * out_c + temp_out_backprop_offset;
+#pragma unroll 6
+        for (int i = 0; i < depth_multiplier; ++i) {
+          sum += out_backprop_data[out_backprop_offset +
+                                   in_d * depth_multiplier + i] *
+                 filter_data[filter_offset + i];
+        }
+      }
+    }
+    const int in_backprop_offset =
+        in_d +
+        args_.in_depth * (in_c + args_.in_cols * (in_r + args_.in_rows * b));
+    in_backprop_data[in_backprop_offset] = sum;
+  }
+
+ private:
+  const DepthwiseArgs args_;
+  const read_accessor out_backprop_accessor_;
+  const read_accessor filter_accessor_;
+  write_accessor in_backprop_accessor_;
+};
+
+// Enqueues the NHWC backprop-to-input kernel above, one work-item per
+// element of in_backprop. Only FORMAT_NHWC is supported.
+template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
+          int kKnownDepthMultiplier>
+void LaunchDepthwiseConv2dBackpropInputSYCL(const SYCLDevice& d,
+                                            const DepthwiseArgs args,
+                                            const Tensor& out_backprop,
+                                            const Tensor& filter,
+                                            Tensor* in_backprop,
+                                            TensorFormat data_format) {
+  const int num_threads = in_backprop->NumElements();
+
+  auto out_backprop_buffer =
+      d.get_sycl_buffer(out_backprop.template flat<T>().data());
+  auto filter_buffer = d.get_sycl_buffer(filter.template flat<T>().data());
+  auto in_backprop_buffer =
+      d.get_sycl_buffer(in_backprop->template flat<T>().data());
+
+  d.sycl_queue().submit([&](cl::sycl::handler& cgh) {
+    auto out_backprop_access =
+        out_backprop_buffer
+            .template get_access<cl::sycl::access::mode::read>(cgh);
+    auto filter_access =
+        filter_buffer.template get_access<cl::sycl::access::mode::read>(cgh);
+    auto in_backprop_access =
+        in_backprop_buffer
+            .template get_access<cl::sycl::access::mode::write>(cgh);
+
+    if (data_format == FORMAT_NHWC) {
+      DepthwiseConv2dBackpropInputSYCLKernelNHWC<
+          T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>
+          functor(args, out_backprop_access, filter_access,
+                  in_backprop_access);
+      cgh.parallel_for(cl::sycl::range<1>(num_threads), functor);
+    } else {
+      assert(false && "Incorrect data format");
+      return;
+    }
+  });
+}
+
+// Dispatches on depth_multiplier so the common multiplier==1 case gets a
+// specialized (compile-time constant) kernel.
+template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
+void LaunchDepthwiseConv2dBackpropInputSYCL(const SYCLDevice& d,
+                                            const DepthwiseArgs args,
+                                            const Tensor& out_backprop,
+                                            const Tensor& filter,
+                                            Tensor* in_backprop,
+                                            TensorFormat data_format) {
+  if (args.depth_multiplier == 1) {
+    LaunchDepthwiseConv2dBackpropInputSYCL<T, kKnownFilterWidth,
+                                           kKnownFilterHeight, 1>(
+        d, args, out_backprop, filter, in_backprop, data_format);
+  } else {
+    LaunchDepthwiseConv2dBackpropInputSYCL<T, kKnownFilterWidth,
+                                           kKnownFilterHeight, -1>(
+        d, args, out_backprop, filter, in_backprop, data_format);
+  }
+}
+
+// SYCL device specialization of the launch functor; specializes the 3x3
+// filter case, which is by far the most common.
+template <typename T>
+struct LaunchDepthwiseConvBackpropInputOp<SYCLDevice, T> {
+  static void launch(OpKernelContext* ctx, const DepthwiseArgs args,
+                     const Tensor& out_backprop, const Tensor& filter,
+                     Tensor* in_backprop, TensorFormat data_format) {
+    const SYCLDevice& d = ctx->eigen_device<SYCLDevice>();
+    if (args.filter_rows == 3 && args.filter_cols == 3) {
+      LaunchDepthwiseConv2dBackpropInputSYCL<T, 3, 3>(
+          d, args, out_backprop, filter, in_backprop, data_format);
+    } else {
+      LaunchDepthwiseConv2dBackpropInputSYCL<T, -1, -1>(
+          d, args, out_backprop, filter, in_backprop, data_format);
+    }
+  }
+};
+
+// SYCL-device specialization of the backprop-to-input op: validates the
+// attrs/inputs exactly like the generic version, then launches the SYCL
+// kernels above.
+template <typename T>
+class DepthwiseConv2dNativeBackpropInputOp<SYCLDevice, T> : public OpKernel {
+ public:
+  explicit DepthwiseConv2dNativeBackpropInputOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
+    OP_REQUIRES(context, strides_.size() == 4,
+                errors::InvalidArgument("Sliding window strides field must "
+                                        "specify 4 dimensions"));
+
+    string data_format;
+    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+                errors::InvalidArgument("Invalid data format"));
+
+    stride_ = GetTensorDim(strides_, data_format_, 'H');
+    const int64 stride_w = GetTensorDim(strides_, data_format_, 'W');
+    const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
+    const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
+
+    OP_REQUIRES(context, stride_ == stride_w,
+                errors::InvalidArgument(
+                    "Current implementation only supports equal length "
+                    "strides in the row and column dimensions."));
+    OP_REQUIRES(
+        context, (stride_n == 1 && stride_c == 1),
+        errors::InvalidArgument("Current implementation does not yet support "
+                                "strides in the batch and depth dimensions."));
+    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& input_sizes = context->input(0);
+    const Tensor& filter = context->input(1);
+    OP_REQUIRES(
+        context, TensorShapeUtils::IsVector(input_sizes.shape()),
+        errors::InvalidArgument(
+            "Conv2DBackpropInput: input_sizes input must be 1-dim, not ",
+            input_sizes.dims()));
+    TensorShape input_shape;
+    const int32* in_sizes_data = input_sizes.template flat<int32>().data();
+    for (int i = 0; i < input_sizes.NumElements(); ++i) {
+      OP_REQUIRES(context, in_sizes_data[i] >= 0,
+                  errors::InvalidArgument("Dimension ", i,
+                                          " of input_sizes must be >= 0"));
+      input_shape.AddDim(in_sizes_data[i]);
+    }
+    const TensorShape& filter_shape = filter.shape();
+    EXTRACT_AND_VERIFY_DIMENSIONS("DepthwiseConv2DBackpropInput");
+    Tensor* in_backprop = nullptr;
+    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
+                                {0}, 0, input_shape, &in_backprop));
+    // Nothing to compute for an empty input gradient.
+    if (input_shape.num_elements() == 0) {
+      return;
+    }
+    LaunchDepthwiseConvBackpropInputOp<SYCLDevice, T>::launch(
+        context, args, out_backprop, filter, in_backprop, data_format_);
+  }
+
+ private:
+  std::vector<int32> strides_;
+  Padding padding_;
+  TensorFormat data_format_;
+  int64 stride_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(DepthwiseConv2dNativeBackpropInputOp);
+};
+
+REGISTER_KERNEL_BUILDER(
+    Name("DepthwiseConv2dNativeBackpropInput")
+        .Device(DEVICE_SYCL)
+        .TypeConstraint<float>("T")
+        .HostMemory("input_sizes"),
+    DepthwiseConv2dNativeBackpropInputOp<SYCLDevice, float>);
+
+REGISTER_KERNEL_BUILDER(
+    Name("DepthwiseConv2dNativeBackpropInput")
+        .Device(DEVICE_SYCL)
+        .TypeConstraint<double>("T")
+        .HostMemory("input_sizes"),
+    DepthwiseConv2dNativeBackpropInputOp<SYCLDevice, double>);
+#endif  // TENSORFLOW_USE_SYCL
+
 #define REGISTER_CPU_KERNEL(T)                                       \
   REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNativeBackpropInput") \
                               .Device(DEVICE_CPU)                    \