From 23e4347b4a2e26c807a28453c413753317b30e55 Mon Sep 17 00:00:00 2001 From: Josh Holmer Date: Wed, 4 May 2022 03:38:25 -0400 Subject: [PATCH] Autovectorization pass --- src/denoise.rs | 107 ++++++++++++++++++++++++++++--------------------- 1 file changed, 62 insertions(+), 45 deletions(-) diff --git a/src/denoise.rs b/src/denoise.rs index bd2ecf3d35..93cfb499ae 100644 --- a/src/denoise.rs +++ b/src/denoise.rs @@ -235,38 +235,45 @@ where .map(|f| f[p].data_origin()) .collect::>(); - for y in (0..effective_height).step_by(INC) { - for x in (0..=(pad_width - SB_SIZE)).step_by(INC) { - for z in 0..TB_SIZE { - self.proc0( - &src_planes[z][x..], - &self.hw[(BLOCK_AREA * z)..], - &mut dftr[(BLOCK_AREA * z)..], - src_stride, + // SAFETY: We know the size of the planes we're working on, + // so we can safely ensure we are not out of bounds. + // There are a fair number of unsafe function calls here + // which are unsafe for optimization purposes. + // All are safe as long as we do not pass out-of-bounds parameters. + unsafe { + for y in (0..effective_height).step_by(INC) { + for x in (0..=(pad_width - SB_SIZE)).step_by(INC) { + for z in 0..TB_SIZE { + self.proc0( + &src_planes[z][x..], + &self.hw[(BLOCK_AREA * z)..], + &mut dftr[(BLOCK_AREA * z)..], + src_stride, + SB_SIZE, + self.src_scale, + ); + } + + self.real_to_complex_3d(&dftr, &mut dftc); + self.remove_mean(&mut dftc, &self.dftgc, &mut means); + + self.filter_coeffs(&mut dftc); + + self.add_mean(&mut dftc, &means); + self.complex_to_real_3d(&dftc, &mut dftr); + + self.proc1( + &dftr[(TB_MIDPOINT * BLOCK_AREA)..], + &self.hw[(TB_MIDPOINT * BLOCK_AREA)..], + &mut ebuff[(y * ebuff_stride + x)..], SB_SIZE, - self.src_scale, + ebuff_stride, ); } - self.real_to_complex_3d(&dftr, &mut dftc); - self.remove_mean(&mut dftc, &self.dftgc, &mut means); - - self.filter_coeffs(&mut dftc); - - self.add_mean(&mut dftc, &means); - self.complex_to_real_3d(&dftc, &mut dftr); - - self.proc1( - &dftr[(TB_MIDPOINT * BLOCK_AREA)..], - &self.hw[(TB_MIDPOINT * BLOCK_AREA)..], - &mut ebuff[(y * ebuff_stride + x)..], - SB_SIZE, - ebuff_stride, - ); - } - - for q in 0..TB_SIZE { - src_planes[q] = &src_planes[q][(INC * src_stride)..]; + for q in 0..TB_SIZE { + src_planes[q] = &src_planes[q][(INC * src_stride)..]; + } } } @@ -313,6 +320,7 @@ where hw } + #[inline(always)] // Hanning windowing fn spatial_window(n: f64) -> f64 { 0.5 - 0.5 * (2.0 * PI * n / SB_SIZE as f64).cos() @@ -345,35 +353,44 @@ where } } - fn proc0( + #[inline] + unsafe fn proc0( &self, s0: &[T], s1: &[f32], dest: &mut [f32], p0: usize, p1: usize, src_scale: f32, ) { - let s0 = s0.chunks(p0); - let s1 = s1.chunks(p1); - let dest = dest.chunks_mut(p1); + let s0 = s0.as_ptr(); + let s1 = s1.as_ptr(); + let dest = dest.as_mut_ptr(); - for (s0, (s1, dest)) in s0.zip(s1.zip(dest)).take(p1) { + for u in 0..p1 { for v in 0..p1 { - dest[v] = u16::cast_from(s0[v]) as f32 * src_scale * s1[v]; + let s0 = s0.add(u * p0 + v); + let s1 = s1.add(u * p1 + v); + let dest = dest.add(u * p1 + v); + dest.write(u16::cast_from(s0.read()) as f32 * src_scale * s1.read()) } } } - fn proc1( + #[inline] + unsafe fn proc1( &self, s0: &[f32], s1: &[f32], dest: &mut [f32], p0: usize, p1: usize, ) { - let s0 = s0.chunks(p0); - let s1 = s1.chunks(p0); - let dest = dest.chunks_mut(p1); + let s0 = s0.as_ptr(); + let s1 = s1.as_ptr(); + let dest = dest.as_mut_ptr(); - for (s0, (s1, dest)) in s0.zip(s1.zip(dest)).take(p0) { + for u in 0..p0 { for v in 0..p0 { - dest[v] += s0[v] * s1[v]; + let s0 = s0.add(u * p0 + v); + let s1 = s1.add(u * p0 + v); + let dest = dest.add(u * p1 + v); + dest.write(s0.read().mul_add(s1.read(), dest.read())); } } } + #[inline] fn remove_mean( &self, dftc: &mut [Complex; COMPLEX_COUNT], dftgc: &[Complex; COMPLEX_COUNT], @@ -389,6 +406,7 @@ where } } + #[inline] fn add_mean( &self, dftc: &mut [Complex; COMPLEX_COUNT], means: &[Complex; COMPLEX_COUNT], @@ -399,6 +417,7 @@ where } } + #[inline] // Applies a generalized wiener filter fn filter_coeffs(&self, dftc: &mut [Complex; COMPLEX_COUNT]) { for h in 0..COMPLEX_COUNT { @@ -495,11 +514,8 @@ where for (ebuff, dest) in ebuff.zip(dest).take(dest_height) { for x in 0..dest_width { let fval = ebuff[x].mul_add(self.dest_scale, 0.5); - dest[x] = clamp( - T::cast_from(fval.round() as u16), - T::cast_from(0u16), - self.peak, - ); + dest[x] = + clamp(T::cast_from(fval as u16), T::cast_from(0u16), self.peak); } } } @@ -544,6 +560,7 @@ where } } +#[inline(always)] fn extra(a: usize, b: usize) -> usize { if a % b > 0 { b - (a % b)