From f8cab888dc3803cf43788d246e3c82c77e332139 Mon Sep 17 00:00:00 2001
From: Jeffrey
Date: Wed, 19 Jul 2017 17:53:50 +0100
Subject: [PATCH] [OpenCL] Registers Dilation2D

---
 tensorflow/core/kernels/dilation_ops.cc | 207 ++++++++++++++++++++++++
 1 file changed, 207 insertions(+)

diff --git a/tensorflow/core/kernels/dilation_ops.cc b/tensorflow/core/kernels/dilation_ops.cc
index 6f5c0e91569eb5..d0d59704fb9870 100644
--- a/tensorflow/core/kernels/dilation_ops.cc
+++ b/tensorflow/core/kernels/dilation_ops.cc
@@ -39,6 +39,9 @@ namespace tensorflow {
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif  // TENSORFLOW_USE_SYCL
 
 void ParseAttributes(OpKernelConstruction* context, std::vector<int32>* strides,
                      std::vector<int32>* rates, Padding* padding) {
@@ -208,6 +211,197 @@ struct Dilation {
 };
 }  // namespace functor
 
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T>
+class Dilation2DSYCL {
+  using write_accessor =
+      cl::sycl::accessor<uint8_t, 1, cl::sycl::access::mode::write,
+                         cl::sycl::access::target::global_buffer>;
+  using read_accessor =
+      cl::sycl::accessor<uint8_t, 1, cl::sycl::access::mode::read,
+                         cl::sycl::access::target::global_buffer>;
+
+ public:
+  Dilation2DSYCL(int batch, int input_rows, int input_cols, int depth,
+                 int filter_rows, int filter_cols, int output_rows,
+                 int output_cols, int stride_rows, int stride_cols,
+                 int rate_rows, int rate_cols, int pad_top, int pad_left,
+                 const read_accessor input_accessor,
+                 const read_accessor filter_accessor,
+                 write_accessor output_accessor)
+      : batch_(batch),
+        input_rows_(input_rows),
+        input_cols_(input_cols),
+        depth_(depth),
+        filter_rows_(filter_rows),
+        filter_cols_(filter_cols),
+        output_rows_(output_rows),
+        output_cols_(output_cols),
+        stride_rows_(stride_rows),
+        stride_cols_(stride_cols),
+        rate_rows_(rate_rows),
+        rate_cols_(rate_cols),
+        pad_top_(pad_top),
+        pad_left_(pad_left),
+        input_accessor_(input_accessor),
+        filter_accessor_(filter_accessor),
+        output_accessor_(output_accessor) {}
+
+  void operator()(cl::sycl::item<1> item) {
+    T* input_data = ConvertToActualTypeSycl(T, input_accessor_);
+    T* filter_data = ConvertToActualTypeSycl(T, filter_accessor_);
+    T* output_data = ConvertToActualTypeSycl(T, output_accessor_);
+
+    int out_idx = item.get_linear_id();
+    const int d = out_idx % depth_;
+    const int out_idx2 = out_idx / depth_;
+    const int w_out = out_idx2 % output_cols_;
+    const int out_idx3 = out_idx2 / output_cols_;
+    const int h_out = out_idx3 % output_rows_;
+    const int b = out_idx3 / output_rows_;
+    int h_beg = h_out * stride_rows_ - pad_top_;
+    int w_beg = w_out * stride_cols_ - pad_left_;
+    T cur_val = Eigen::NumTraits<T>::lowest();
+    for (int h = 0; h < filter_rows_; ++h) {
+      const int h_in = h_beg + h * rate_rows_;
+      if (h_in >= 0 && h_in < input_rows_) {
+        for (int w = 0; w < filter_cols_; ++w) {
+          const int w_in = w_beg + w * rate_cols_;
+          if (w_in >= 0 && w_in < input_cols_) {
+            const T val =
+                input_data[d + depth_ * (w_in + input_cols_ *
+                                                    (h_in + input_rows_ * b))] +
+                filter_data[d + depth_ * (w + filter_cols_ * h)];
+            if (val > cur_val) {
+              cur_val = val;
+            }
+          }
+        }
+      }
+    }
+    output_data[out_idx] = cur_val;
+  }
+
+ private:
+  int batch_;
+  int input_rows_;
+  int input_cols_;
+  int depth_;
+  int filter_rows_;
+  int filter_cols_;
+  int output_rows_;
+  int output_cols_;
+  int stride_rows_;
+  int stride_cols_;
+  int rate_rows_;
+  int rate_cols_;
+  int pad_top_;
+  int pad_left_;
+  const read_accessor input_accessor_;
+  const read_accessor filter_accessor_;
+  write_accessor output_accessor_;
+};
+
+namespace functor {
+
+template <typename T>
+struct Dilation<SYCLDevice, T> {
+  void operator()(const SYCLDevice& device, const Tensor& input,
+                  const Tensor& filter, int stride_rows, int stride_cols,
+                  int rate_rows, int rate_cols, int pad_top, int pad_left,
+                  Tensor* output) {
+    auto input_tensor = input.tensor<T, 4>();
+    auto filter_tensor = filter.tensor<T, 3>();
+    auto output_tensor = output->tensor<T, 4>();
+
+    const int batch = input_tensor.dimension(0);
+    const int input_rows = input_tensor.dimension(1);
+    const int input_cols = input_tensor.dimension(2);
+    const int depth = input_tensor.dimension(3);
+
+    const int filter_rows = filter_tensor.dimension(0);
+    const int filter_cols = filter_tensor.dimension(1);
+
+    const int output_rows = output_tensor.dimension(1);
+    const int output_cols = output_tensor.dimension(2);
+
+    const int num_threads = output->NumElements();
+
+    auto input_buffer =
+        device.get_sycl_buffer(input.template flat<T>().data());
+    auto filter_buffer =
+        device.get_sycl_buffer(filter.template flat<T>().data());
+    auto output_buffer =
+        device.get_sycl_buffer(output->template flat<T>().data());
+
+    device.sycl_queue().submit([&](cl::sycl::handler& cgh) {
+      auto input_access =
+          input_buffer.template get_access<cl::sycl::access::mode::read>(cgh);
+      auto filter_access =
+          filter_buffer.template get_access<cl::sycl::access::mode::read>(cgh);
+      auto output_access =
+          output_buffer.template get_access<cl::sycl::access::mode::write>(cgh);
+      Dilation2DSYCL<T> dilation(batch, input_rows, input_cols, depth,
+                                 filter_rows, filter_cols, output_rows,
+                                 output_cols, stride_rows, stride_cols,
+                                 rate_rows, rate_cols, pad_top, pad_left,
+                                 input_access, filter_access, output_access);
+
+      cgh.parallel_for(cl::sycl::range<1>(num_threads), dilation);
+    });
+  }
+};
+}  // namespace functor
+
+template <typename T>
+class DilationOp<SYCLDevice, T> : public OpKernel {
+ public:
+  explicit DilationOp(OpKernelConstruction* context) : OpKernel(context) {
+    ParseAttributes(context, &strides_, &rates_, &padding_);
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& input = context->input(0);
+    const Tensor& filter = context->input(1);
+
+    // Determine relevant sizes from input and filters.
+    int stride_rows = 0, stride_cols = 0;
+    int rate_rows = 0, rate_cols = 0;
+    int64 pad_top = 0, pad_left = 0;
+    int64 out_rows = 0, out_cols = 0;
+    ParseSizes(context, strides_, rates_, padding_, &stride_rows, &stride_cols,
+               &rate_rows, &rate_cols, &pad_top, &pad_left, &out_rows,
+               &out_cols);
+
+    // Output tensor is of the following dimensions:
+    // [ batch, out_rows, out_cols, depth ]
+    const int batch = input.dim_size(0);
+    const int depth = input.dim_size(3);
+    const std::vector<int64> out_sizes = {batch, out_rows, out_cols, depth};
+    TensorShape out_shape(out_sizes);
+
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
+
+    // If there is nothing to compute, return.
+    if (out_shape.num_elements() == 0) {
+      return;
+    }
+
+    functor::Dilation<SYCLDevice, T>()(
+        context->eigen_device<SYCLDevice>(), input, filter, stride_rows,
+        stride_cols, rate_rows, rate_cols, pad_top, pad_left, output);
+  }
+
+  std::vector<int32> strides_;
+  std::vector<int32> rates_;
+  Padding padding_;
+};
+#endif  // TENSORFLOW_USE_SYCL
+
 template <typename Device, typename T>
 class DilationBackpropInputOp : public OpKernel {
  public:
@@ -488,4 +682,17 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER);
 
 #endif  // GOOGLE_CUDA
 
+#ifdef TENSORFLOW_USE_SYCL
+
+#define REGISTER(T)                                                  \
+  REGISTER_KERNEL_BUILDER(                                           \
+      Name("Dilation2D").Device(DEVICE_SYCL).TypeConstraint<T>("T"), \
+      DilationOp<SYCLDevice, T>);
+
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER);
+
+#undef REGISTER
+
+#endif  // TENSORFLOW_USE_SYCL
+
 }  // namespace tensorflow