diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md
index 3565854d53..127f92bf95 100644
--- a/docs/en/understand_mmcv/ops.md
+++ b/docs/en/understand_mmcv/ops.md
@@ -57,3 +57,4 @@ We implement common ops used in detection, segmentation, etc.
 | TINShift | | √ | √ | |
 | UpFirDn2d | | √ | | |
 | Voxelization | √ | √ | | |
+| PrRoIPool | | √ | | |
diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md
index 94a77218d0..82c9eb4fca 100644
--- a/docs/zh_cn/understand_mmcv/ops.md
+++ b/docs/zh_cn/understand_mmcv/ops.md
@@ -57,3 +57,4 @@ MMCV 提供了检测、分割等任务中常用的算子
 | TINShift | | √ | √ | |
 | UpFirDn2d | | √ | | |
 | Voxelization | √ | √ | | |
+| PrRoIPool | | √ | | |
diff --git a/mmcv/ops/__init__.py b/mmcv/ops/__init__.py
index 87ad6f3d73..a65f14fff5 100755
--- a/mmcv/ops/__init__.py
+++ b/mmcv/ops/__init__.py
@@ -46,6 +46,7 @@
                               points_in_boxes_part)
 from .points_in_polygons import points_in_polygons
 from .points_sampler import PointsSampler
+from .prroi_pool import PrRoIPool, prroi_pool
 from .psa_mask import PSAMask
 from .riroi_align_rotated import RiRoIAlignRotated, riroi_align_rotated
 from .roi_align import RoIAlign, roi_align
@@ -100,5 +101,6 @@
     'SparseConvTensor', 'scatter_nd', 'points_in_boxes_part',
     'points_in_boxes_cpu', 'points_in_boxes_all', 'points_in_polygons',
     'min_area_polygons', 'active_rotated_filter', 'convex_iou', 'convex_giou',
-    'diff_iou_rotated_2d', 'diff_iou_rotated_3d', 'chamfer_distance'
+    'diff_iou_rotated_2d', 'diff_iou_rotated_3d', 'chamfer_distance',
+    'PrRoIPool', 'prroi_pool'
 ]
diff --git a/mmcv/ops/csrc/common/cuda/prroi_pool_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/prroi_pool_cuda_kernel.cuh
new file mode 100644
index 0000000000..ea8c37e22a
--- /dev/null
+++ b/mmcv/ops/csrc/common/cuda/prroi_pool_cuda_kernel.cuh
@@ -0,0 +1,408 @@
+// Copyright (c) OpenMMLab. All rights reserved
+// Modified from
+// https://github.com/vacancy/PreciseRoIPooling/blob/master/src/prroi_pooling_gpu_impl.cu
+// Distributed under terms of the MIT license.
+#ifndef PRROI_POOL_CUDA_KERNEL_CUH
+#define PRROI_POOL_CUDA_KERNEL_CUH
+
+#ifdef MMCV_USE_PARROTS
+#include "parrots_cuda_helper.hpp"
+#else
+#include "pytorch_cuda_helper.hpp"
+#endif
+
+// Returns data[h][w] for a height x width map, or 0 when (h, w) lies outside
+// the map.
+template <typename T>
+__device__ static __forceinline__ T PrRoIPoolingGetData(const T *data,
+                                                        const int h,
+                                                        const int w,
+                                                        const int height,
+                                                        const int width) {
+  bool overflow = (h < 0) || (w < 0) || (h >= height) || (w >= width);
+  T retVal = overflow ? 0.0f : data[h * width + w];
+  return retVal;
+}
+
+// Bilinear interpolation weight for the offsets (dh, dw) from a grid point.
+template <typename T>
+__device__ static __forceinline__ T PrRoIPoolingGetCoeff(T dh, T dw) {
+  return (1.0f - abs(dh)) * (1.0f - abs(dw));
+}
+
+// Integrates over [s, t] the linear function that takes the value c1 at 0
+// and c2 at 1.
+template <typename T>
+__device__ static __forceinline__ T PrRoIPoolingSingleCoorIntegral(T s, T t,
+                                                                   T c1, T c2) {
+  return 0.5 * (t * t - s * s) * (c2 - c1) + (t - s) * c1;
+}
+
+// Bilinearly interpolates `data` at the real-valued location (h, w); grid
+// points outside the map contribute 0.
+template <typename T>
+__device__ static T PrRoIPoolingInterpolation(const T *data, const T h,
+                                              const T w, const int height,
+                                              const int width) {
+  T retVal = 0.0f;
+  int h1 = floorf(h);
+  int w1 = floorf(w);
+  retVal += PrRoIPoolingGetData(data, h1, w1, height, width) *
+            PrRoIPoolingGetCoeff(h - T(h1), w - T(w1));
+  h1 = floorf(h) + 1;
+  w1 = floorf(w);
+  retVal += PrRoIPoolingGetData(data, h1, w1, height, width) *
+            PrRoIPoolingGetCoeff(h - T(h1), w - T(w1));
+  h1 = floorf(h);
+  w1 = floorf(w) + 1;
+  retVal += PrRoIPoolingGetData(data, h1, w1, height, width) *
+            PrRoIPoolingGetCoeff(h - T(h1), w - T(w1));
+  h1 = floorf(h) + 1;
+  w1 = floorf(w) + 1;
+  retVal += PrRoIPoolingGetData(data, h1, w1, height, width) *
+            PrRoIPoolingGetCoeff(h - T(h1), w - T(w1));
+  return retVal;
+}
+
+// Integrates the bilinear interpolant of the unit cell with corners
+// (s_h, s_w) and (e_h, e_w) over the rectangle [y0, y1] x [x0, x1].
+// h0 and w0 are the height and width of the map behind `this_data`.
+template <typename T>
+__device__ static T PrRoIPoolingMatCalculation(const T *this_data,
+                                               const int s_h, const int s_w,
+                                               const int e_h, const int e_w,
+                                               const T y0, const T x0,
+                                               const T y1, const T x1,
+                                               const int h0, const int w0) {
+  T alpha, beta, lim_alpha, lim_beta, tmp;
+  T sum_out = 0;
+
+  alpha = x0 - T(s_w);
+  beta = y0 - T(s_h);
+  lim_alpha = x1 - T(s_w);
+  lim_beta = y1 - T(s_h);
+  tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
+         0.5f * alpha * alpha) *
+        (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
+  sum_out += PrRoIPoolingGetData(this_data, s_h, s_w, h0, w0) * tmp;
+
+  alpha = T(e_w) - x1;
+  lim_alpha = T(e_w) - x0;
+  tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
+         0.5f * alpha * alpha) *
+        (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
+  sum_out += PrRoIPoolingGetData(this_data, s_h, e_w, h0, w0) * tmp;
+
+  alpha = x0 - T(s_w);
+  beta = T(e_h) - y1;
+  lim_alpha = x1 - T(s_w);
+  lim_beta = T(e_h) - y0;
+  tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
+         0.5f * alpha * alpha) *
+        (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
+  sum_out += PrRoIPoolingGetData(this_data, e_h, s_w, h0, w0) * tmp;
+
+  alpha = T(e_w) - x1;
+  lim_alpha = T(e_w) - x0;
+  tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
+         0.5f * alpha * alpha) *
+        (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
+  sum_out += PrRoIPoolingGetData(this_data, e_h, e_w, h0, w0) * tmp;
+
+  return sum_out;
+}
+
+// Atomically adds top_diff * coeff to diff[h][w]; does nothing when (h, w)
+// lies outside the height x width map.
+template <typename T>
+__device__ static void PrRoIPoolingDistributeDiff(T *diff, const T top_diff,
+                                                  const int h, const int w,
+                                                  const int height,
+                                                  const int width,
+                                                  const T coeff) {
+  bool overflow = (h < 0) || (w < 0) || (h >= height) || (w >= width);
+  if (!overflow) atomicAdd(diff + h * width + w, top_diff * coeff);
+}
+
+// Gradient counterpart of PrRoIPoolingMatCalculation: scatters top_diff to
+// the four corners of the unit cell, weighted by the same per-corner
+// integration coefficients.
+template <typename T>
+__device__ static void PrRoIPoolingMatDistributeDiff(
+    T *diff, const T top_diff, const int s_h, const int s_w, const int e_h,
+    const int e_w, const T y0, const T x0, const T y1, const T x1, const int h0,
+    const int w0) {
+  T alpha, beta, lim_alpha, lim_beta, tmp;
+
+  alpha = x0 - T(s_w);
+  beta = y0 - T(s_h);
+  lim_alpha = x1 - T(s_w);
+  lim_beta = y1 - T(s_h);
+  tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
+         0.5f * alpha * alpha) *
+        (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
+  PrRoIPoolingDistributeDiff(diff, top_diff, s_h, s_w, h0, w0, tmp);
+
+  alpha = T(e_w) - x1;
+  lim_alpha = T(e_w) - x0;
+  tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
+         0.5f * alpha * alpha) *
+        (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
+  PrRoIPoolingDistributeDiff(diff, top_diff, s_h, e_w, h0, w0, tmp);
+
+  alpha = x0 - T(s_w);
+  beta = T(e_h) - y1;
+  lim_alpha = x1 - T(s_w);
+  lim_beta = T(e_h) - y0;
+  tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
+         0.5f * alpha * alpha) *
+        (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
+  PrRoIPoolingDistributeDiff(diff, top_diff, e_h, s_w, h0, w0, tmp);
+
+  alpha = T(e_w) - x1;
+  lim_alpha = T(e_w) - x0;
+  tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha +
+         0.5f * alpha * alpha) *
+        (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
+  PrRoIPoolingDistributeDiff(diff, top_diff, e_h, e_w, h0, w0, tmp);
+}
+
+// Forward pass: each loop iteration fills one output element with the
+// integral of the interpolated feature over its bin divided by the bin area.
+// Each row of `rois` holds (batch_index, x1, y1, x2, y2).
+template <typename T>
+__global__ void prroi_pool_forward_cuda_kernel(
+    const int nthreads, const T *input, const T *rois, T *output,
+    const int pooled_height, const int pooled_width, const T spatial_scale,
+    const int channels, const int height, const int width) {
+  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+
+    const T *offset_rois = rois + n * 5;
+    int roi_batch_ind = offset_rois[0];
+
+    T roi_x1 = offset_rois[1] * spatial_scale;
+    T roi_y1 = offset_rois[2] * spatial_scale;
+    T roi_x2 = offset_rois[3] * spatial_scale;
+    T roi_y2 = offset_rois[4] * spatial_scale;
+
+    T roi_width = max(roi_x2 - roi_x1, ((T)0.0));
+    T roi_height = max(roi_y2 - roi_y1, ((T)0.0));
+    T bin_size_h = roi_height / static_cast<T>(pooled_height);
+    T bin_size_w = roi_width / static_cast<T>(pooled_width);
+
+    const T *this_data =
+        input + (roi_batch_ind * channels + c) * height * width;
+    T *this_out = output + index;
+
+    T bin_x1 = roi_x1 + bin_size_w * pw;
+    T bin_y1 = roi_y1 + bin_size_h * ph;
+    T bin_x2 = bin_x1 + bin_size_w;
+    T bin_y2 = bin_y1 + bin_size_h;
+
+    T bin_size = max(T(0.0), bin_size_w * bin_size_h);
+    if (bin_size == 0) {
+      *this_out = 0;
+      continue;
+    }
+
+    T sum_out = 0;
+
+    int start_x, start_y, end_x, end_y;
+
+    start_x = floorf(bin_x1);
+    end_x = ceilf(bin_x2);
+    start_y = floorf(bin_y1);
+    end_y = ceilf(bin_y2);
+
+    for (int bin_x = start_x; bin_x < end_x; ++bin_x)
+      for (int bin_y = start_y; bin_y < end_y; ++bin_y)
+        sum_out += PrRoIPoolingMatCalculation(
+            this_data, bin_y, bin_x, bin_y + 1, bin_x + 1,
+            max(bin_y1, T(bin_y)), max(bin_x1, T(bin_x)),
+            min(bin_y2, T(bin_y) + 1.0f), min(bin_x2, T(bin_x + 1.0f)), height,
+            width);
+    *this_out = sum_out / bin_size;
+  }
+}
+
+// Backward pass w.r.t. the input feature map: scatters each output gradient,
+// divided by the bin area, onto the grid points covered by its bin.
+template <typename T>
+__global__ void prroi_pool_backward_cuda_kernel(
+    const int nthreads, const T *grad_output, const T *rois, T *grad_input,
+    const int pooled_height, const int pooled_width, const T spatial_scale,
+    const int channels, const int height, const int width) {
+  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+    // Index the RoI through a local pointer so that `rois` itself stays
+    // unchanged between iterations of the kernel loop.
+    const T *offset_rois = rois + n * 5;
+
+    int roi_batch_ind = offset_rois[0];
+    T roi_x1 = offset_rois[1] * spatial_scale;
+    T roi_y1 = offset_rois[2] * spatial_scale;
+    T roi_x2 = offset_rois[3] * spatial_scale;
+    T roi_y2 = offset_rois[4] * spatial_scale;
+
+    T roi_width = max(roi_x2 - roi_x1, (T)0);
+    T roi_height = max(roi_y2 - roi_y1, (T)0);
+    T bin_size_h = roi_height / static_cast<T>(pooled_height);
+    T bin_size_w = roi_width / static_cast<T>(pooled_width);
+
+    const T *this_out_grad = grad_output + index;
+    T *this_data_grad =
+        grad_input + (roi_batch_ind * channels + c) * height * width;
+
+    T bin_x1 = roi_x1 + bin_size_w * pw;
+    T bin_y1 = roi_y1 + bin_size_h * ph;
+    T bin_x2 = bin_x1 + bin_size_w;
+    T bin_y2 = bin_y1 + bin_size_h;
+
+    T bin_size = max(T(0.0), bin_size_w * bin_size_h);
+
+    T sum_out = bin_size == T(0) ? T(0) : *this_out_grad / bin_size;
+
+    int start_x, start_y, end_x, end_y;
+
+    start_x = floorf(bin_x1);
+    end_x = ceilf(bin_x2);
+    start_y = floorf(bin_y1);
+    end_y = ceilf(bin_y2);
+
+    for (int bin_x = start_x; bin_x < end_x; ++bin_x)
+      for (int bin_y = start_y; bin_y < end_y; ++bin_y)
+        PrRoIPoolingMatDistributeDiff(
+            this_data_grad, sum_out, bin_y, bin_x, bin_y + 1, bin_x + 1,
+            max(bin_y1, T(bin_y)), max(bin_x1, T(bin_x)),
+            min(bin_y2, T(bin_y) + 1.0f), min(bin_x2, T(bin_x + 1.0f)), height,
+            width);
+  }
+}
+
+// Backward pass w.r.t. the RoI coordinates: accumulates into grad_rois each
+// output element's contribution to the gradient of (x1, y1, x2, y2).
+template <typename T>
+__global__ void prroi_pool_coor_backward_cuda_kernel(
+    const int nthreads, const T *output, const T *grad_output, const T *input,
+    const T *rois, T *grad_rois, const int pooled_height,
+    const int pooled_width, const T spatial_scale, const int channels,
+    const int height, const int width) {
+  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+    // Index the RoI through a local pointer so that `rois` itself stays
+    // unchanged between iterations of the kernel loop.
+    const T *offset_rois = rois + n * 5;
+
+    int roi_batch_ind = offset_rois[0];
+    T roi_x1 = offset_rois[1] * spatial_scale;
+    T roi_y1 = offset_rois[2] * spatial_scale;
+    T roi_x2 = offset_rois[3] * spatial_scale;
+    T roi_y2 = offset_rois[4] * spatial_scale;
+
+    T roi_width = max(roi_x2 - roi_x1, (T)0);
+    T roi_height = max(roi_y2 - roi_y1, (T)0);
+    T bin_size_h = roi_height / static_cast<T>(pooled_height);
+    T bin_size_w = roi_width / static_cast<T>(pooled_width);
+
+    const T output_grad_val = grad_output[index];
+    const T *this_input_data =
+        input + (roi_batch_ind * channels + c) * height * width;
+    const T output_val = output[index];
+    T *this_rois_grad = grad_rois + n * 5;
+
+    T bin_x1 = roi_x1 + bin_size_w * pw;
+    T bin_y1 = roi_y1 + bin_size_h * ph;
+    T bin_x2 = bin_x1 + bin_size_w;
+    T bin_y2 = bin_y1 + bin_size_h;
+
+    T bin_size = max(T(0.0), bin_size_w * bin_size_h);
+
+    T sum_out = bin_size == T(0) ? T(0) : output_grad_val / bin_size;
+
+    // WARNING: to be discussed
+    // `continue` rather than `return`, so later loop iterations still run.
+    if (sum_out == 0) continue;
+
+    int start_x, start_y, end_x, end_y;
+
+    start_x = floorf(bin_x1);
+    end_x = ceilf(bin_x2);
+    start_y = floorf(bin_y1);
+    end_y = ceilf(bin_y2);
+
+    T grad_x1_y = 0, grad_x2_y = 0, grad_x_y1 = 0, grad_x_y2 = 0;
+    for (int bin_y = start_y; bin_y < end_y; ++bin_y) {
+      grad_x1_y += PrRoIPoolingSingleCoorIntegral(
+          max(bin_y1, T(bin_y)) - bin_y, min(bin_y2, T(bin_y + 1)) - bin_y,
+          PrRoIPoolingInterpolation(this_input_data, float(bin_y), bin_x1,
+                                    height, width),
+          PrRoIPoolingInterpolation(this_input_data, float(bin_y + 1), bin_x1,
+                                    height, width));
+
+      grad_x2_y += PrRoIPoolingSingleCoorIntegral(
+          max(bin_y1, T(bin_y)) - bin_y, min(bin_y2, T(bin_y + 1)) - bin_y,
+          PrRoIPoolingInterpolation(this_input_data, float(bin_y), bin_x2,
+                                    height, width),
+          PrRoIPoolingInterpolation(this_input_data, float(bin_y + 1), bin_x2,
+                                    height, width));
+    }
+
+    for (int bin_x = start_x; bin_x < end_x; ++bin_x) {
+      grad_x_y1 += PrRoIPoolingSingleCoorIntegral(
+          max(bin_x1, T(bin_x)) - bin_x, min(bin_x2, T(bin_x + 1)) - bin_x,
+          PrRoIPoolingInterpolation(this_input_data, bin_y1, float(bin_x),
+                                    height, width),
+          PrRoIPoolingInterpolation(this_input_data, bin_y1, float(bin_x + 1),
+                                    height, width));
+
+      grad_x_y2 += PrRoIPoolingSingleCoorIntegral(
+          max(bin_x1, T(bin_x)) - bin_x, min(bin_x2, T(bin_x + 1)) - bin_x,
+          PrRoIPoolingInterpolation(this_input_data, bin_y2, float(bin_x),
+                                    height, width),
+          PrRoIPoolingInterpolation(this_input_data, bin_y2, float(bin_x + 1),
+                                    height, width));
+    }
+
+    T partial_x1 = -grad_x1_y + (bin_y2 - bin_y1) * output_val;
+    T partial_y1 = -grad_x_y1 + (bin_x2 - bin_x1) * output_val;
+    T partial_x2 = grad_x2_y - (bin_y2 - bin_y1) * output_val;
+    T partial_y2 = grad_x_y2 - (bin_x2 - bin_x1) * output_val;
+
+    partial_x1 = partial_x1 / bin_size * spatial_scale;
+    partial_x2 = partial_x2 / bin_size * spatial_scale;
+    partial_y1 = partial_y1 / bin_size * spatial_scale;
+    partial_y2 = partial_y2 / bin_size * spatial_scale;
+
+    // (index, x1, y1, x2, y2)
+    this_rois_grad[0] = 0;
+    atomicAdd(this_rois_grad + 1,
+              (partial_x1 * (1.0f - T(pw) / pooled_width) +
+               partial_x2 * (1.0f - T(pw + 1) / pooled_width)) *
+                  output_grad_val);
+    atomicAdd(this_rois_grad + 2,
+              (partial_y1 * (1.0f - T(ph) / pooled_height) +
+               partial_y2 * (1.0f - T(ph + 1) / pooled_height)) *
+                  output_grad_val);
+    atomicAdd(this_rois_grad + 3, (partial_x2 * T(pw + 1) / pooled_width +
+                                   partial_x1 * T(pw) / pooled_width) *
+                                      output_grad_val);
+    atomicAdd(this_rois_grad + 4, (partial_y2 * T(ph + 1) / pooled_height +
+                                   partial_y1 * T(ph) / pooled_height) *
+                                      output_grad_val);
+  }
+}
+
+#endif  // PRROI_POOL_CUDA_KERNEL_CUH
diff --git a/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp b/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp
index 093e383ef2..12cf7afdc2 100644
--- a/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp
+++ b/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp
@@ -1737,3 +1737,55 @@ REGISTER_DEVICE_IMPL(chamfer_distance_forward_impl, CUDA,
                      chamfer_distance_forward_cuda);
 REGISTER_DEVICE_IMPL(chamfer_distance_backward_impl, CUDA,
                      chamfer_distance_backward_cuda);
+
+// Precise RoI Pooling (PrRoIPool): CUDA launchers and device registration.
+void PrROIPoolForwardCUDAKernelLauncher(Tensor input, Tensor rois,
+                                        Tensor output, int pooled_height,
+                                        int pooled_width, float spatial_scale);
+
+void PrROIPoolBackwardCUDAKernelLauncher(Tensor grad_output, Tensor rois,
+                                         Tensor grad_input, int pooled_height,
+                                         int pooled_width, float spatial_scale);
+
+void PrROIPoolCoorBackwardCUDAKernelLauncher(
+    Tensor output, Tensor grad_output, Tensor input, Tensor rois,
+    Tensor grad_rois, int pooled_height, int pooled_width, float spatial_scale);
+
+void prroi_pool_forward_cuda(Tensor input, Tensor rois, Tensor output,
+                             int pooled_height, int pooled_width,
+                             float spatial_scale) {
+  PrROIPoolForwardCUDAKernelLauncher(input, rois, output, pooled_height,
+                                     pooled_width, spatial_scale);
+}
+
+void prroi_pool_backward_cuda(Tensor grad_output, Tensor rois,
+                              Tensor grad_input, int pooled_height,
+                              int pooled_width, float spatial_scale) {
+  PrROIPoolBackwardCUDAKernelLauncher(grad_output, rois, grad_input,
+                                      pooled_height, pooled_width,
+                                      spatial_scale);
+}
+
+void prroi_pool_coor_backward_cuda(Tensor output, Tensor grad_output,
+                                   Tensor input, Tensor rois, Tensor grad_rois,
+                                   int pooled_height, int pooled_width,
+                                   float spatial_scale) {
+  PrROIPoolCoorBackwardCUDAKernelLauncher(output, grad_output, input, rois,
+                                          grad_rois, pooled_height,
+                                          pooled_width, spatial_scale);
+}
+
+void prroi_pool_forward_impl(Tensor input, Tensor rois, Tensor output,
+                             int pooled_height, int pooled_width,
+                             float spatial_scale);
+void prroi_pool_backward_impl(Tensor grad_output, Tensor rois,
+                              Tensor grad_input, int pooled_height,
+                              int pooled_width, float spatial_scale);
+void prroi_pool_coor_backward_impl(Tensor output, Tensor grad_output,
+                                   Tensor input, Tensor rois, Tensor grad_rois,
+                                   int pooled_height, int pooled_width,
+                                   float spatial_scale);
+REGISTER_DEVICE_IMPL(prroi_pool_forward_impl, CUDA, prroi_pool_forward_cuda);
+REGISTER_DEVICE_IMPL(prroi_pool_backward_impl, CUDA, prroi_pool_backward_cuda);
+REGISTER_DEVICE_IMPL(prroi_pool_coor_backward_impl, CUDA,
+                     prroi_pool_coor_backward_cuda);
diff --git a/mmcv/ops/csrc/pytorch/cuda/prroi_pool_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/prroi_pool_cuda.cu
new file mode 100644
index 0000000000..e0636098b1
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/cuda/prroi_pool_cuda.cu
@@ -0,0 +1,68 @@
+// Copyright (c) OpenMMLab. All rights reserved
+#include "prroi_pool_cuda_kernel.cuh"
+#include "pytorch_cuda_helper.hpp"
+
+// Launches the forward kernel with one loop iteration per output element.
+void PrROIPoolForwardCUDAKernelLauncher(Tensor input, Tensor rois,
+                                        Tensor output, int pooled_height,
+                                        int pooled_width, float spatial_scale) {
+  int output_size = output.numel();
+  int channels = input.size(1);
+  int height = input.size(2);
+  int width = input.size(3);
+
+  at::cuda::CUDAGuard device_guard(input.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  prroi_pool_forward_cuda_kernel<float>
+      <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+          output_size, input.data_ptr<float>(), rois.data_ptr<float>(),
+          output.data_ptr<float>(), pooled_height, pooled_width,
+          static_cast<float>(spatial_scale), channels, height, width);
+
+  AT_CUDA_CHECK(cudaGetLastError());
+}
+
+// Launches the kernel that accumulates gradients into grad_input.
+void PrROIPoolBackwardCUDAKernelLauncher(Tensor grad_output, Tensor rois,
+                                         Tensor grad_input, int pooled_height,
+                                         int pooled_width,
+                                         float spatial_scale) {
+  int output_size = grad_output.numel();
+  int channels = grad_input.size(1);
+  int height = grad_input.size(2);
+  int width = grad_input.size(3);
+
+  at::cuda::CUDAGuard device_guard(grad_output.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  prroi_pool_backward_cuda_kernel<float>
+      <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+          output_size, grad_output.data_ptr<float>(), rois.data_ptr<float>(),
+          grad_input.data_ptr<float>(), pooled_height, pooled_width,
+          static_cast<float>(spatial_scale), channels, height, width);
+
+  AT_CUDA_CHECK(cudaGetLastError());
+}
+
+// Launches the kernel that accumulates gradients into grad_rois.
+void PrROIPoolCoorBackwardCUDAKernelLauncher(Tensor output, Tensor grad_output,
+                                             Tensor input, Tensor rois,
+                                             Tensor grad_rois,
+                                             int pooled_height,
+                                             int pooled_width,
+                                             float spatial_scale) {
+  int output_size = grad_output.numel();
+  int channels = input.size(1);
+  int height = input.size(2);
+  int width = input.size(3);
+
+  at::cuda::CUDAGuard device_guard(grad_output.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  prroi_pool_coor_backward_cuda_kernel<float>
+      <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+          output_size, output.data_ptr<float>(), grad_output.data_ptr<float>(),
+          input.data_ptr<float>(), rois.data_ptr<float>(),
+          grad_rois.data_ptr<float>(), pooled_height, pooled_width,
+          static_cast<float>(spatial_scale), channels, height, width);
+
+  AT_CUDA_CHECK(cudaGetLastError());
+}
diff --git a/mmcv/ops/csrc/pytorch/prroi_pool.cpp b/mmcv/ops/csrc/pytorch/prroi_pool.cpp
new file mode 100644
index 0000000000..00db84a154
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/prroi_pool.cpp
@@ -0,0 +1,47 @@
+// Copyright (c) OpenMMLab. All rights reserved
+#include "pytorch_cpp_helper.hpp"
+#include "pytorch_device_registry.hpp"
+
+void prroi_pool_forward_impl(Tensor input, Tensor rois, Tensor output,
+                             int pooled_height, int pooled_width,
+                             float spatial_scale) {
+  DISPATCH_DEVICE_IMPL(prroi_pool_forward_impl, input, rois, output,
+                       pooled_height, pooled_width, spatial_scale);
+}
+
+void prroi_pool_backward_impl(Tensor grad_output, Tensor rois,
+                              Tensor grad_input, int pooled_height,
+                              int pooled_width, float spatial_scale) {
+  DISPATCH_DEVICE_IMPL(prroi_pool_backward_impl, grad_output, rois, grad_input,
+                       pooled_height, pooled_width, spatial_scale);
+}
+
+void prroi_pool_coor_backward_impl(Tensor output, Tensor grad_output,
+                                   Tensor input, Tensor rois, Tensor grad_rois,
+                                   int pooled_height, int pooled_width,
+                                   float spatial_scale) {
+  DISPATCH_DEVICE_IMPL(prroi_pool_coor_backward_impl, output, grad_output,
+                       input, rois, grad_rois, pooled_height, pooled_width,
+                       spatial_scale);
+}
+
+void prroi_pool_forward(Tensor input, Tensor rois, Tensor output,
+                        int pooled_height, int pooled_width,
+                        float spatial_scale) {
+  prroi_pool_forward_impl(input, rois, output, pooled_height, pooled_width,
+                          spatial_scale);
+}
+
+void prroi_pool_backward(Tensor grad_output, Tensor rois, Tensor grad_input,
+                         int pooled_height, int pooled_width,
+                         float spatial_scale) {
+  prroi_pool_backward_impl(grad_output, rois, grad_input, pooled_height,
+                           pooled_width, spatial_scale);
+}
+
+void prroi_pool_coor_backward(Tensor output, Tensor grad_output, Tensor input,
+                              Tensor rois, Tensor grad_rois, int pooled_height,
+                              int pooled_width, float spatial_scale) {
+  prroi_pool_coor_backward_impl(output, grad_output, input, rois, grad_rois,
+                                pooled_height, pooled_width, spatial_scale);
+}
diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp
index afac6f426a..c134090871 100644
--- a/mmcv/ops/csrc/pytorch/pybind.cpp
+++ b/mmcv/ops/csrc/pytorch/pybind.cpp
@@ -240,6 +240,18 @@ void ball_query_forward(Tensor new_xyz_tensor, Tensor xyz_tensor,
                         Tensor idx_tensor, int b, int n, int m,
                         float min_radius, float max_radius, int nsample);
 
+void prroi_pool_forward(Tensor input, Tensor rois, Tensor output,
+                        int pooled_height, int pooled_width,
+                        float spatial_scale);
+
+void prroi_pool_backward(Tensor grad_output, Tensor rois, Tensor grad_input,
+                         int pooled_height, int pooled_width,
+                         float spatial_scale);
+
+void prroi_pool_coor_backward(Tensor output, Tensor grad_output, Tensor input,
+                              Tensor rois, Tensor grad_rois, int pooled_height,
+                              int pooled_width, float spatial_scale);
+
 template <unsigned NDim>
 std::vector<torch::Tensor> get_indice_pairs_forward(
     torch::Tensor indices, int64_t batchSize,
@@ -828,4 +840,17 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "chamfer_distance_backward", py::arg("xyz1"), py::arg("xyz2"),
         py::arg("gradxyz1"), py::arg("gradxyz2"), py::arg("graddist1"),
         py::arg("graddist2"), py::arg("idx1"), py::arg("idx2"));
+  m.def("prroi_pool_forward", &prroi_pool_forward, "prroi_pool forward",
+        py::arg("input"), py::arg("rois"), py::arg("output"),
+        py::arg("pooled_height"), py::arg("pooled_width"),
+        py::arg("spatial_scale"));
+  m.def("prroi_pool_backward", &prroi_pool_backward, "prroi_pool_backward",
+        py::arg("grad_output"), py::arg("rois"), py::arg("grad_input"),
+        py::arg("pooled_height"), py::arg("pooled_width"),
+        py::arg("spatial_scale"));
+  m.def("prroi_pool_coor_backward", &prroi_pool_coor_backward,
+        "prroi_pool_coor_backward", py::arg("output"), py::arg("grad_output"),
+        py::arg("input"), py::arg("rois"), py::arg("grad_rois"),
+        py::arg("pooled_height"), py::arg("pooled_width"),
+        py::arg("spatial_scale"));
 }
diff --git a/mmcv/ops/prroi_pool.py b/mmcv/ops/prroi_pool.py
new file mode 100644
index 0000000000..47c223aa58
--- /dev/null
+++ b/mmcv/ops/prroi_pool.py
@@ -0,0 +1,140 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple, Union
+
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+    '_ext',
+    ['prroi_pool_forward', 'prroi_pool_backward', 'prroi_pool_coor_backward'])
+
+
+class PrRoIPoolFunction(Function):
+    """Autograd function wrapping the PrRoIPool extension kernels."""
+
+    @staticmethod
+    def symbolic(g, features, rois, output_size, spatial_scale):
+        return g.op(
+            'mmcv::PrRoIPool',
+            features,
+            rois,
+            pooled_height_i=int(output_size[0]),
+            pooled_width_i=int(output_size[1]),
+            spatial_scale_f=float(spatial_scale))
+
+    @staticmethod
+    def forward(ctx,
+                features: torch.Tensor,
+                rois: torch.Tensor,
+                output_size: Tuple,
+                spatial_scale: float = 1.0) -> torch.Tensor:
+        if 'FloatTensor' not in features.type(
+        ) or 'FloatTensor' not in rois.type():
+            raise ValueError(
+                'Precise RoI Pooling only takes float input, got '
+                f'{features.type()} for features and {rois.type()} for rois.')
+        if rois.dim() != 2 or rois.size(1) != 5:
+            raise ValueError(
+                'rois must be of shape (n, 5): (batch_idx, x1, y1, x2, y2), '
+                f'got {tuple(rois.shape)}.')
+
+        pooled_height = int(output_size[0])
+        pooled_width = int(output_size[1])
+        spatial_scale = float(spatial_scale)
+
+        features = features.contiguous()
+        rois = rois.contiguous()
+        output_shape = (rois.size(0), features.size(1), pooled_height,
+                        pooled_width)
+        output = features.new_zeros(output_shape)
+        params = (pooled_height, pooled_width, spatial_scale)
+
+        ext_module.prroi_pool_forward(features, rois, output, *params)
+        ctx.params = params
+        # everything here is contiguous.
+        ctx.save_for_backward(features, rois, output)
+
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(
+        ctx, grad_output: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, None, None]:
+        features, rois, output = ctx.saved_tensors
+        grad_input = grad_output.new_zeros(*features.shape)
+        grad_coor = grad_output.new_zeros(*rois.shape)
+
+        grad_output = grad_output.contiguous()
+        if features.requires_grad:
+            ext_module.prroi_pool_backward(grad_output, rois, grad_input,
+                                           *ctx.params)
+        if rois.requires_grad:
+            ext_module.prroi_pool_coor_backward(output, grad_output, features,
+                                                rois, grad_coor, *ctx.params)
+
+        # One gradient per forward() input: features, rois, output_size and
+        # spatial_scale.
+        return grad_input, grad_coor, None, None
+
+
+prroi_pool = PrRoIPoolFunction.apply
+
+
+class PrRoIPool(nn.Module):
+    """The operation of precise RoI pooling. The implementation of PrRoIPool
+    is modified from https://github.com/vacancy/PreciseRoIPooling/
+
+    Precise RoI Pooling (PrRoIPool) is an integration-based (bilinear
+    interpolation) average pooling method for RoI Pooling. It avoids any
+    quantization and has a continuous gradient on bounding box coordinates.
+    It is:
+
+    1. different from the original RoI Pooling proposed in Fast R-CNN. PrRoI
+    Pooling uses average pooling instead of max pooling for each bin and has a
+    continuous gradient on bounding box coordinates. That is, one can take the
+    derivatives of some loss function w.r.t the coordinates of each RoI and
+    optimize the RoI coordinates.
+    2. different from the RoI Align proposed in Mask R-CNN. PrRoI Pooling uses
+    a full integration-based average pooling instead of sampling a constant
+    number of points. This makes the gradient w.r.t. the coordinates
+    continuous.
+
+    Args:
+        output_size (Union[int, tuple]): h, w.
+        spatial_scale (float, optional): scale the input boxes by this number.
+            Defaults to 1.0.
+    """
+
+    def __init__(self,
+                 output_size: Union[int, tuple],
+                 spatial_scale: float = 1.0):
+        super().__init__()
+
+        self.output_size = _pair(output_size)
+        self.spatial_scale = float(spatial_scale)
+
+    def forward(self, features: torch.Tensor,
+                rois: torch.Tensor) -> torch.Tensor:
+        """Forward function.
+
+        Args:
+            features (torch.Tensor): The feature map.
+            rois (torch.Tensor): The RoIs with shape (n, 5), each row being
+                (batch_index, tl_x, tl_y, br_x, br_y).
+
+        Returns:
+            torch.Tensor: The pooled results.
+        """
+        return prroi_pool(features, rois, self.output_size, self.spatial_scale)
+
+    def __repr__(self):
+        s = self.__class__.__name__
+        s += f'(output_size={self.output_size}, '
+        s += f'spatial_scale={self.spatial_scale})'
+        return s
diff --git a/tests/test_ops/test_prroi_pool.py b/tests/test_ops/test_prroi_pool.py
new file mode 100644
index 0000000000..6ee471e828
--- /dev/null
+++ b/tests/test_ops/test_prroi_pool.py
@@ -0,0 +1,99 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import pytest
+import torch
+
+from mmcv.utils import IS_CUDA_AVAILABLE
+
+_USING_PARROTS = True
+try:
+    from parrots.autograd import gradcheck
+except ImportError:
+    from torch.autograd import gradcheck
+
+    _USING_PARROTS = False
+
+inputs = [([[[[1., 2.], [3., 4.]]]], [[0., 0., 0., 1., 1.]]),
+          ([[[[1., 2.], [3., 4.]], [[4., 3.], [2.,
+                                               1.]]]], [[0., 0., 0., 1., 1.]]),
+          ([[[[1., 2., 5., 6.], [3., 4., 7., 8.], [9., 10., 13., 14.],
+              [11., 12., 15., 16.]]]], [[0., 0., 0., 3., 3.]])]
+outputs = [
+    ([[[[1.75, 2.25], [2.75, 3.25]]]], [[[[1., 1.],
+                                          [1., 1.]]]], [[0., 2., 4., 2., 4.]]),
+    ([[[[1.75, 2.25], [2.75, 3.25]],
+       [[3.25, 2.75], [2.25, 1.75]]]], [[[[1., 1.], [1., 1.]],
+                                         [[1., 1.],
+                                          [1., 1.]]]], [[0., 0., 0., 0., 0.]]),
+    ([[[[3.75, 6.91666651],
+        [10.08333302,
+         13.25]]]], [[[[0.11111111, 0.22222224, 0.22222222, 0.11111111],
+                       [0.22222224, 0.44444448, 0.44444448, 0.22222224],
+                       [0.22222224, 0.44444448, 0.44444448, 0.22222224],
+                       [0.11111111, 0.22222224, 0.22222224, 0.11111111]]]],
+     [[0.0, 3.33333302, 6.66666603, 3.33333349, 6.66666698]])
+]
+
+
+class TestPrRoiPool:
+
+    @pytest.mark.parametrize('device', [
+        pytest.param(
+            'cuda',
+            marks=pytest.mark.skipif(
+                not IS_CUDA_AVAILABLE, reason='requires CUDA support'))
+    ])
+    def test_roipool_gradcheck(self, device):
+        from mmcv.ops import PrRoIPool
+        pool_h = 2
+        pool_w = 2
+        spatial_scale = 1.0
+
+        for case in inputs:
+            np_input = np.array(case[0], dtype=np.float32)
+            np_rois = np.array(case[1], dtype=np.float32)
+
+            x = torch.tensor(np_input, device=device, requires_grad=True)
+            rois = torch.tensor(np_rois, device=device)
+
+            froipool = PrRoIPool((pool_h, pool_w), spatial_scale)
+
+            if _USING_PARROTS:
+                pass
+                # gradcheck(froipool, (x, rois), no_grads=[rois])
+            else:
+                gradcheck(froipool, (x, rois), eps=1e-2, atol=1e-2)
+
+    def _test_roipool_allclose(self, device, dtype=torch.float):
+        from mmcv.ops import prroi_pool
+        pool_h = 2
+        pool_w = 2
+        spatial_scale = 1.0
+
+        for case, expected in zip(inputs, outputs):
+            np_input = np.array(case[0], dtype=np.float32)
+            np_rois = np.array(case[1], dtype=np.float32)
+            np_output = np.array(expected[0], dtype=np.float32)
+            np_input_grad = np.array(expected[1], dtype=np.float32)
+            np_rois_grad = np.array(expected[2], dtype=np.float32)
+
+            x = torch.tensor(
+                np_input, dtype=dtype, device=device, requires_grad=True)
+            rois = torch.tensor(
+                np_rois, dtype=dtype, device=device, requires_grad=True)
+
+            output = prroi_pool(x, rois, (pool_h, pool_w), spatial_scale)
+            output.backward(torch.ones_like(output))
+            assert np.allclose(output.data.cpu().numpy(), np_output, 1e-3)
+            assert np.allclose(x.grad.data.cpu().numpy(), np_input_grad, 1e-3)
+            assert np.allclose(rois.grad.data.cpu().numpy(), np_rois_grad,
+                               1e-3)
+
+    @pytest.mark.parametrize('device', [
+        pytest.param(
+            'cuda',
+            marks=pytest.mark.skipif(
+                not IS_CUDA_AVAILABLE, reason='requires CUDA support'))
+    ])
+    def test_roipool_allclose_float(self, device):
+        self._test_roipool_allclose(device, dtype=torch.float)