Skip to content

Commit

Permalink
improve NMS implementation (#1173)
Browse files Browse the repository at this point in the history
  • Loading branch information
borongyuan committed Nov 28, 2023
1 parent 3a7f88d commit 0876603
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions corelib/src/util2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2118,6 +2118,7 @@ void NMS(
cv::Mat inds = cv::Mat(cv::Size(img_width, img_height), CV_16UC1);

cv::Mat confidence = cv::Mat(cv::Size(img_width, img_height), CV_32FC1);
cv::Mat dilated_conf = cv::Mat(cv::Size(img_width, img_height), CV_32FC1);

grid.setTo(0);
inds.setTo(0);
Expand All @@ -2134,11 +2135,8 @@ void NMS(
confidence.at<float>(vv, uu) = ptsIn[i].response;
}

// debug
//cv::Mat confidenceVis = confidence.clone() * 255;
//confidenceVis.convertTo(confidenceVis, CV_8UC1);
//cv::imwrite("confidence.bmp", confidenceVis);
//cv::imwrite("grid_in.bmp", grid);
cv::dilate(confidence, dilated_conf, cv::Mat());
cv::Mat peaks = confidence == dilated_conf;

cv::copyMakeBorder(grid, grid, dist_thresh, dist_thresh, dist_thresh, dist_thresh, cv::BORDER_CONSTANT, 0);

Expand All @@ -2151,20 +2149,25 @@ void NMS(

if (grid.at<unsigned char>(vv, uu) == 100) // If not yet suppressed.
{
for(int k = -dist_thresh; k < (dist_thresh+1); k++)
if (peaks.at<unsigned char>(vv-dist_thresh, uu-dist_thresh) == 255)
{
for(int j = -dist_thresh; j < (dist_thresh+1); j++)
for(int k = -dist_thresh; k < (dist_thresh+1); k++)
{
if((j==0 && k==0) || grid.at<unsigned char>(vv + k, uu + j) == 0)
continue;

if ( confidence.at<float>(vv + k - dist_thresh, uu + j - dist_thresh) <= c )
for(int j = -dist_thresh; j < (dist_thresh+1); j++)
{
grid.at<unsigned char>(vv + k, uu + j) = 0;
if ((j==0 && k==0) || grid.at<unsigned char>(vv + k, uu + j) == 0)
continue;

if (confidence.at<float>(vv + k - dist_thresh, uu + j - dist_thresh) <= c)
grid.at<unsigned char>(vv + k, uu + j) = 0;
}
}
grid.at<unsigned char>(vv, uu) = 255;
}
else
{
grid.at<unsigned char>(vv, uu) = 0;
}
grid.at<unsigned char>(vv, uu) = 255;
}
}

Expand All @@ -2173,9 +2176,6 @@ void NMS(

grid = cv::Mat(grid, cv::Rect(dist_thresh, dist_thresh, img_width, img_height));

//debug
//cv::imwrite("grid_nms.bmp", grid);

for (int v = 0; v < img_height; v++)
{
for (int u = 0; u < img_width; u++)
Expand Down

0 comments on commit 0876603

Please sign in to comment.