Skip to content

Commit

Permalink
Use FMA in separable resampling. (#5711)
Browse files Browse the repository at this point in the history
Use the explicit FMA intrinsic (__fmaf_rn) so the result does not depend on whether the compiler's optimizer chooses to contract multiply-add sequences (and to increase precision).

Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
  • Loading branch information
mzient authored Nov 15, 2024
1 parent 2aaa1a3 commit a2d1bb1
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 41 deletions.
10 changes: 5 additions & 5 deletions dali/kernels/imgproc/resample/bilinear_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ __device__ void LinearHorz_Channels(
for (int c = 0; c < channels; c++) {
float a = __ldg(&in0[c]);
float b = __ldg(&in1[c]);
float tmp = fmaf(b-a, q, a);
float tmp = __fmaf_rn(b-a, q, a);
out_row[cx + c] = ConvertSat<Dst>(tmp);
}
}
Expand Down Expand Up @@ -139,7 +139,7 @@ __device__ void LinearVert(
float src_y0, float scale,
Dst *__restrict__ out, ptrdiff_vec<1> out_strides,
const Src *__restrict__ in, ptrdiff_vec<1> in_strides, ivec2 in_shape, int channels) {
src_y0 += 0.5f * scale - 0.5f;
src_y0 += __fmaf_rn(scale, 0.5f, -0.5f);

ptrdiff_t out_stride = out_strides[0];
ptrdiff_t in_stride = in_strides[0];
Expand All @@ -150,7 +150,7 @@ __device__ void LinearVert(
const int j1 = hi.x * channels;

for (int y = lo.y + threadIdx.y; y < hi.y; y += blockDim.y) {
const float sy0f = y * scale + src_y0;
const float sy0f = __fmaf_rn(y, scale, src_y0);
const int sy0i = __float2int_rd(sy0f);
const float q = sy0f - sy0i;
const int sy0 = clamp(sy0i, 0, in_h-1);
Expand All @@ -163,7 +163,7 @@ __device__ void LinearVert(
for (int j = j0 + threadIdx.x; j < j1; j += blockDim.x) {
float a = __ldg(&in0[j]);
float b = __ldg(&in1[j]);
float tmp = fmaf(b-a, q, a);
float tmp = __fmaf_rn(b-a, q, a);
out_row[j] = ConvertSat<Dst>(tmp);
}
}
Expand Down Expand Up @@ -288,7 +288,7 @@ __device__ void LinearDepth(
for (int j = j0 + threadIdx.x; j < j1; j += blockDim.x) {
float a = __ldg(&in0[j]);
float b = __ldg(&in1[j]);
float tmp = fmaf(b-a, q, a);
float tmp = __fmaf_rn(b-a, q, a);
out_row[j] = ConvertSat<Dst>(tmp);
}
}
Expand Down
12 changes: 6 additions & 6 deletions dali/kernels/imgproc/resample/nearest_impl.cuh
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
// Copyright (c) 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -44,14 +44,14 @@ __device__ void NNResample(
const Src *__restrict__ in, vec<1, ptrdiff_t> in_stride, ivec2 in_size, int channels) {
origin += 0.5f * scale;
for (int y = lo.y + threadIdx.y; y < hi.y; y += blockDim.y) {
int ysrc = floor_int(y * scale.y + origin.y);
int ysrc = floor_int(__fmaf_rn(y, scale.y, origin.y));
ysrc = clamp(ysrc, 0, in_size.y-1);

Dst *out_row = &out[y * out_stride.x];
const Src *in_row = &in[ysrc * in_stride.x];

for (int x = lo.x + threadIdx.x; x < hi.x; x += blockDim.x) {
int xsrc = floor_int(x * scale.x + origin.x);
int xsrc = floor_int(__fmaf_rn(x, scale.x, origin.x));
xsrc = clamp(xsrc, 0, in_size.x-1);
const Src *src_px = &in_row[xsrc * channels];
for (int c = 0; c < channels; c++)
Expand Down Expand Up @@ -82,21 +82,21 @@ __device__ void NNResample(
origin += 0.5f * scale;

for (int z = lo.z + threadIdx.z; z < hi.z; z += blockDim.z) {
int zsrc = floor_int(z * scale.z + origin.z);
int zsrc = floor_int(__fmaf_rn(z, scale.z, origin.z));
zsrc = clamp(zsrc, 0, in_size.z-1);

Dst *out_plane = &out[z * out_stride.y];
const Src *in_plane = &in[zsrc * in_stride.y];

for (int y = lo.y + threadIdx.y; y < hi.y; y += blockDim.y) {
int ysrc = floor_int(y * scale.y + origin.y);
int ysrc = floor_int(__fmaf_rn(y, scale.y, origin.y));
ysrc = clamp(ysrc, 0, in_size.y-1);

Dst *out_row = &out_plane[y * out_stride.x];
const Src *in_row = &in_plane[ysrc * in_stride.x];

for (int x = lo.x + threadIdx.x; x < hi.x; x += blockDim.x) {
int xsrc = floor_int(x * scale.x + origin.x);
int xsrc = floor_int(__fmaf_rn(x, scale.x, origin.x));
xsrc = clamp(xsrc, 0, in_size.x-1);
const Src *src_px = &in_row[xsrc * channels];
for (int c = 0; c < channels; c++)
Expand Down
60 changes: 30 additions & 30 deletions dali/kernels/imgproc/resample/resampling_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ __device__ void ResampleHorz_Channels(
ptrdiff_t in_stride = in_strides.x;
int in_w = in_shape.x;

src_x0 += 0.5f * scale - 0.5f - filter.anchor;
src_x0 += __fmaf_rn(scale, 0.5f, -0.5f) - filter.anchor;

const float filter_step = filter.scale;

Expand All @@ -80,18 +80,18 @@ __device__ void ResampleHorz_Channels(

for (int j = lo.x; j < hi.x; j += blockDim.x) {
int dx = j + threadIdx.x;
const float sx0f = dx * scale + src_x0;
const float sx0f = __fmaf_rn(dx, scale, src_x0);
const int sx0 = huge_kernel ? __float2int_rn(sx0f) : __float2int_ru(sx0f);
float f = (sx0 - sx0f) * filter_step;
__syncthreads();
if (huge_kernel) {
for (int k = threadIdx.x + blockDim.x*threadIdx.y; k < support; k += blockDim.x*blockDim.y) {
float flt = filter(f + k*filter_step);
float flt = filter(__fmaf_rn(k, filter_step, f));
coeffs[k] = flt;
}
} else {
for (int k = threadIdx.y; k < support; k += blockDim.y) {
float flt = filter(f + k*filter_step);
float flt = filter(__fmaf_rn(k, filter_step, f));
coeffs[coeff_base + coeff_stride*k] = flt;
}
}
Expand Down Expand Up @@ -119,7 +119,7 @@ __device__ void ResampleHorz_Channels(
int xsample = x < 0 ? 0 : x >= in_w-1 ? in_w-1 : x;
float flt = coeffs[coeff_idx];
Src px = __ldg(in_row + channels * xsample + c);
tmp = fmaf(px, flt, tmp);
tmp = __fmaf_rn(px, flt, tmp);
}

out_row[channels * dx + c] = ConvertSat<Dst>(tmp * norm);
Expand All @@ -136,7 +136,7 @@ __device__ void ResampleHorz_Channels(
float flt = coeffs[coeff_idx];
for (int c = 0; c < channels; c++) {
Src px = __ldg(in_row + channels * xsample + c);
tmp[c] = fmaf(px, flt, tmp[c]);
tmp[c] = __fmaf_rn(px, flt, tmp[c]);
}
}

Expand Down Expand Up @@ -186,7 +186,7 @@ __device__ void ResampleHorz_Channels(
ptrdiff_t in_stride_z = in_strides.y;
int in_w = in_shape.x;

src_x0 += 0.5f * scale - 0.5f - filter.anchor;
src_x0 += __fmaf_rn(scale, 0.5f, -0.5f) - filter.anchor;

const float filter_step = filter.scale;

Expand All @@ -198,18 +198,18 @@ __device__ void ResampleHorz_Channels(

for (int j = lo.x; j < hi.x; j += blockDim.x) {
int dx = j + threadIdx.x;
const float sx0f = dx * scale + src_x0;
const float sx0f = __fmaf_rn(dx, scale, src_x0);
const int sx0 = huge_kernel ? __float2int_rn(sx0f) : __float2int_ru(sx0f);
float f = (sx0 - sx0f) * filter_step;
__syncthreads();
if (huge_kernel) {
for (int k = threadIdx.x + blockDim.x*threadIdx.y; k < support; k += blockDim.x*blockDim.y) {
float flt = filter(f + k*filter_step);
float flt = filter(__fmaf_rn(k, filter_step, f));
coeffs[k] = flt;
}
} else {
for (int k = threadIdx.y; k < support; k += blockDim.y) {
float flt = filter(f + k*filter_step);
float flt = filter(__fmaf_rn(k, filter_step, f));
coeffs[coeff_base + coeff_stride*k] = flt;
}
}
Expand Down Expand Up @@ -241,7 +241,7 @@ __device__ void ResampleHorz_Channels(
int xsample = x < 0 ? 0 : x >= in_w-1 ? in_w-1 : x;
float flt = coeffs[coeff_idx];
Src px = __ldg(in_row + channels * xsample + c);
tmp = fmaf(px, flt, tmp);
tmp = __fmaf_rn(px, flt, tmp);
}

out_row[channels * dx + c] = ConvertSat<Dst>(tmp * norm);
Expand All @@ -258,7 +258,7 @@ __device__ void ResampleHorz_Channels(
float flt = coeffs[coeff_idx];
for (int c = 0; c < channels; c++) {
Src px = __ldg(in_row + channels * xsample + c);
tmp[c] = fmaf(px, flt, tmp[c]);
tmp[c] = __fmaf_rn(px, flt, tmp[c]);
}
}

Expand Down Expand Up @@ -297,7 +297,7 @@ __device__ void ResampleVert_Channels(
ptrdiff_t in_stride = in_strides.x;
int in_h = in_shape.y;

src_y0 += 0.5f * scale - 0.5f - filter.anchor;
src_y0 += __fmaf_rn(scale, 0.5f, -0.5f) - filter.anchor;

const float filter_step = filter.scale;

Expand All @@ -307,18 +307,18 @@ __device__ void ResampleVert_Channels(

for (int i = lo.y; i < hi.y; i+=blockDim.y) {
int dy = i + threadIdx.y;
const float sy0f = dy * scale + src_y0;
const float sy0f = __fmaf_rn(dy, scale, src_y0);
const int sy0 = huge_kernel ? __float2int_rn(sy0f) : __float2int_ru(sy0f);
float f = (sy0 - sy0f) * filter_step;
__syncthreads();
if (huge_kernel) {
for (int k = threadIdx.x + blockDim.x*threadIdx.y; k < support; k += blockDim.x*blockDim.y) {
float flt = filter(f + k*filter_step);
float flt = filter(__fmaf_rn(k, filter_step, f));
coeffs[k] = flt;
}
} else {
for (int k = threadIdx.x; k < support; k += blockDim.x) {
float flt = filter(f + k*filter_step);
float flt = filter(__fmaf_rn(k, filter_step, f));
coeffs[coeff_base + k] = flt;
}
}
Expand Down Expand Up @@ -348,7 +348,7 @@ __device__ void ResampleVert_Channels(
int ysample = y < 0 ? 0 : y >= in_h-1 ? in_h-1 : y;
float flt = coeffs[coeff_base + k];
Src px = __ldg(in_col + in_stride * ysample + c);
tmp = fmaf(px, flt, tmp);
tmp = __fmaf_rn(px, flt, tmp);
}

out_col[c] = ConvertSat<Dst>(tmp * norm);
Expand All @@ -364,7 +364,7 @@ __device__ void ResampleVert_Channels(
float flt = coeffs[coeff_base + k];
for (int c = 0; c < channels; c++) {
Src px = __ldg(in_col + in_stride * ysample + c);
tmp[c] = fmaf(px, flt, tmp[c]);
tmp[c] = __fmaf_rn(px, flt, tmp[c]);
}
}

Expand Down Expand Up @@ -403,7 +403,7 @@ __device__ void ResampleVert_Channels(
ptrdiff_t in_stride_z = in_strides.y;
int in_h = in_shape.y;

src_y0 += 0.5f * scale - 0.5f - filter.anchor;
src_y0 += __fmaf_rn(scale, 0.5f, -0.5f) - filter.anchor;

const float filter_step = filter.scale;

Expand All @@ -413,18 +413,18 @@ __device__ void ResampleVert_Channels(

for (int i = lo.y; i < hi.y; i+=blockDim.y) {
int dy = i + threadIdx.y;
const float sy0f = dy * scale + src_y0;
const float sy0f = __fmaf_rn(dy, scale, src_y0);
const int sy0 = huge_kernel ? __float2int_rn(sy0f) : __float2int_ru(sy0f);
float f = (sy0 - sy0f) * filter_step;
__syncthreads();
if (huge_kernel) {
for (int k = threadIdx.x + blockDim.x*threadIdx.y; k < support; k += blockDim.x*blockDim.y) {
float flt = filter(f + k*filter_step);
float flt = filter(__fmaf_rn(k, filter_step, f));
coeffs[k] = flt;
}
} else {
for (int k = threadIdx.x; k < support; k += blockDim.x) {
float flt = filter(f + k*filter_step);
float flt = filter(__fmaf_rn(k, filter_step, f));
coeffs[coeff_base + k] = flt;
}
}
Expand Down Expand Up @@ -457,7 +457,7 @@ __device__ void ResampleVert_Channels(
int ysample = y < 0 ? 0 : y >= in_h-1 ? in_h-1 : y;
float flt = coeffs[coeff_base + k];
Src px = __ldg(in_col + in_stride_y * ysample + c);
tmp = fmaf(px, flt, tmp);
tmp = __fmaf_rn(px, flt, tmp);
}

out_col[c] = ConvertSat<Dst>(tmp * norm);
Expand All @@ -473,7 +473,7 @@ __device__ void ResampleVert_Channels(
float flt = coeffs[coeff_base + k];
for (int c = 0; c < channels; c++) {
Src px = __ldg(in_col + in_stride_y * ysample + c);
tmp[c] = fmaf(px, flt, tmp[c]);
tmp[c] = __fmaf_rn(px, flt, tmp[c]);
}
}

Expand Down Expand Up @@ -520,7 +520,7 @@ __device__ void ResampleDepth_Channels(
ptrdiff_t in_stride_z = in_strides[1];
int in_d = in_shape.z;

src_z0 += 0.5f * scale - 0.5f - filter.anchor;
src_z0 += __fmaf_rn(scale, 0.5f, -0.5f) - filter.anchor;

const float filter_step = filter.scale;

Expand All @@ -531,18 +531,18 @@ __device__ void ResampleDepth_Channels(
// threadIdx.y is used to traverse Z axis
for (int i = lo.z; i < hi.z; i+=blockDim.y) {
int dz = i + threadIdx.y;
const float sz0f = dz * scale + src_z0;
const float sz0f = __fmaf_rn(dz, scale, src_z0);
const int sz0 = huge_kernel ? __float2int_rn(sz0f) : __float2int_ru(sz0f);
float f = (sz0 - sz0f) * filter_step;
__syncthreads();
if (huge_kernel) {
for (int k = threadIdx.x + blockDim.x*threadIdx.y; k < support; k += blockDim.x*blockDim.y) {
float flt = filter(f + k*filter_step);
float flt = filter(__fmaf_rn(k, filter_step, f));
coeffs[k] = flt;
}
} else {
for (int k = threadIdx.x; k < support; k += blockDim.x) {
float flt = filter(f + k*filter_step);
float flt = filter(__fmaf_rn(k, filter_step, f));
coeffs[coeff_base + k] = flt;
}
}
Expand Down Expand Up @@ -576,7 +576,7 @@ __device__ void ResampleDepth_Channels(
int zsample = z < 0 ? 0 : z >= in_d-1 ? in_d-1 : z;
float flt = coeffs[coeff_base + l];
Src px = __ldg(in_col + in_stride_z * zsample + c);
tmp = fmaf(px, flt, tmp);
tmp = __fmaf_rn(px, flt, tmp);
}

out_col[c] = ConvertSat<Dst>(tmp * norm);
Expand All @@ -592,7 +592,7 @@ __device__ void ResampleDepth_Channels(
float flt = coeffs[coeff_base + l];
for (int c = 0; c < channels; c++) {
Src px = __ldg(in_col + in_stride_z * zsample + c);
tmp[c] = fmaf(px, flt, tmp[c]);
tmp[c] = __fmaf_rn(px, flt, tmp[c]);
}
}

Expand Down

0 comments on commit a2d1bb1

Please sign in to comment.