Skip to content

Commit

Permalink
unit test ok
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolas-f committed Nov 15, 2023
1 parent 43a5744 commit 36bbfd0
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,23 @@ import kotlin.math.log2
import kotlin.math.pow
import kotlin.math.sin

fun realFFT(length: Int, realArray: DoubleArray) {
val m = log2(length.toDouble()).toInt()
/**
* The bitwise AND operation of the number and its predecessor (number - 1) should result in 0.
* This is because powers of two in binary have only one bit set, and subtracting 1 from them flips
* that bit and sets all the lower bits to 1. So, the bitwise AND with the original number should
* be 0 if it is a power of two.
*/
fun isPowerOfTwo(number: Int): Boolean {
return number > 0 && (number and (number - 1)) == 0
}

fun realFFT(realArray: DoubleArray): DoubleArray {
require(isPowerOfTwo(realArray.size))
val outArray = DoubleArray(realArray.size + 2)
realArray.copyInto(outArray)
val m = log2(realArray.size.toDouble()).toInt()
val n = (2.0.pow(m) + 0.5).toInt()
require(n <= realArray.size)
fft(n/2, realArray)
fft(n/2, outArray)
val a = DoubleArray(n)
val b = DoubleArray(n)

Expand All @@ -24,17 +36,25 @@ fun realFFT(length: Int, realArray: DoubleArray) {
for (k in 1 until n / 4 + 1) {
val k2 = 2 * k
val xr =
realArray[k2] * a[k2] - realArray[k2 + 1] * a[k2 + 1] + realArray[n - k2] * b[k2] + realArray[n - k2 + 1] * b[k2 + 1]
outArray[k2] * a[k2] - outArray[k2 + 1] * a[k2 + 1] + outArray[n - k2] * b[k2] + outArray[n - k2 + 1] * b[k2 + 1]
val xi =
realArray[k2] * a[k2 + 1] + realArray[k2 + 1] * a[k2] + realArray[n - k2] * b[k2 + 1] - realArray[n - k2 + 1] * b[k2]
outArray[k2] * a[k2 + 1] + outArray[k2 + 1] * a[k2] + outArray[n - k2] * b[k2 + 1] - outArray[n - k2 + 1] * b[k2]
val xrN =
realArray[n - k2] * a[n - k2] - realArray[n - k2 + 1] * a[n - k2 + 1] + realArray[k2] * b[n - k2] + realArray[k2 + 1] * b[n - k2 + 1]
outArray[n - k2] * a[n - k2] - outArray[n - k2 + 1] * a[n - k2 + 1] + outArray[k2] * b[n - k2] + outArray[k2 + 1] * b[n - k2 + 1]
val xiN =
realArray[n - k2] * a[n - k2 + 1] + realArray[n - k2 + 1] * a[n - k2] + realArray[k2] * b[n - k2 + 1] - realArray[k2 + 1] * b[n - k2]
realArray[k2] = xr
realArray[k2 + 1] = xi
realArray[n - k2] = xrN
realArray[n - k2 + 1] = xiN
outArray[n - k2] * a[n - k2 + 1] + outArray[n - k2 + 1] * a[n - k2] + outArray[k2] * b[n - k2 + 1] - outArray[k2 + 1] * b[n - k2]
outArray[k2] = xr
outArray[k2 + 1] = xi
outArray[n - k2] = xrN
outArray[n - k2 + 1] = xiN
}

val tmp = outArray[0]
outArray[n] = outArray[0] - outArray[1]
outArray[0] = tmp + outArray[1]
outArray[1] = 0.0
outArray[n+1] = 0.0

return outArray
}

Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ import kotlin.math.log2
import kotlin.math.pow
import kotlin.math.sin

fun realIFFT(length: Int, realArray: DoubleArray) {
val m = log2(length.toDouble()).toInt()
fun realIFFT(realArray: DoubleArray) : DoubleArray {
require(isPowerOfTwo(realArray.size - 2))
val outArray = realArray.copyOf()
val m = log2((realArray.size - 2).toDouble()).toInt()
val n = (2.0.pow(m) + 0.5).toInt()
require(n <= realArray.size)
val a = DoubleArray(n)
val b = DoubleArray(n)

Expand All @@ -23,21 +24,22 @@ fun realIFFT(length: Int, realArray: DoubleArray) {
for (k in 1 until n / 4 + 1) {
val k2 = 2 * k
val xr =
realArray[k2] * a[k2] + realArray[k2 + 1] * a[k2 + 1] + realArray[n - k2] * b[k2] - realArray[n - k2 + 1] * b[k2 + 1]
outArray[k2] * a[k2] + outArray[k2 + 1] * a[k2 + 1] + outArray[n - k2] * b[k2] - outArray[n - k2 + 1] * b[k2 + 1]
val xi =
-realArray[k2] * a[k2 + 1] + realArray[k2 + 1] * a[k2] - realArray[n - k2] * b[k2 + 1] - realArray[n - k2 + 1] * b[k2]
-outArray[k2] * a[k2 + 1] + outArray[k2 + 1] * a[k2] - outArray[n - k2] * b[k2 + 1] - outArray[n - k2 + 1] * b[k2]
val xrN =
realArray[n - k2] * a[n - k2] + realArray[n - k2 + 1] * a[n - k2 + 1] + realArray[k2] * b[n - k2] - realArray[k2 + 1] * b[n - k2 + 1]
outArray[n - k2] * a[n - k2] + outArray[n - k2 + 1] * a[n - k2 + 1] + outArray[k2] * b[n - k2] - outArray[k2 + 1] * b[n - k2 + 1]
val xiN =
-realArray[n - k2] * a[n - k2 + 1] + realArray[n - k2 + 1] * a[n - k2] - realArray[k2] * b[n - k2 + 1] - realArray[k2 + 1] * b[n - k2]
realArray[k2] = xr
realArray[k2 + 1] = xi
realArray[n - k2] = xrN
realArray[n - k2 + 1] = xiN
-outArray[n - k2] * a[n - k2 + 1] + outArray[n - k2 + 1] * a[n - k2] - outArray[k2] * b[n - k2 + 1] - outArray[k2 + 1] * b[n - k2]
outArray[k2] = xr
outArray[k2 + 1] = xi
outArray[n - k2] = xrN
outArray[n - k2 + 1] = xiN
}

val temp = realArray[0]
realArray[0] = 0.5 * realArray[0] + 0.5 * realArray[n-1]
realArray[1] = 0.5 * temp - 0.5 * realArray[n-1]
iFFT(n/2, realArray)
val temp = outArray[0]
outArray[0] = 0.5 * outArray[0] + 0.5 * outArray[n]
outArray[1] = 0.5 * temp - 0.5 * outArray[n]
iFFT(n/2, outArray)
return outArray.copyOf(outArray.size - 2)
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,28 @@ class TestFFT {
val sampleRate = 64.0 // Hz
val duration = 2.0.pow(ceil(log2(sampleRate))) / sampleRate
val samples = generateSinusoidalSignal(frequency, sampleRate, duration)
val fftResult = samples.copyOf()
realFFT(samples.size, fftResult)
realIFFT(samples.size, fftResult)
fftResult.forEachIndexed {
index, value -> assertEquals(value, samples[index], 1e-8)
val spectrum = realFFT(samples)
val result = realIFFT(spectrum)
assertEquals(samples.size, result.size)
samples.forEachIndexed {
index, value -> assertEquals(value, result[index], 1e-8)
}
}

@Test
fun testRFFTIncremental() {
val samples = doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0)
val expected =
doubleArrayOf(36.0, 0.0, -4.0, 9.65685425, -4.0, 4.0, -4.0, 1.65685425, -4.0, 0.0)
val result = realFFT(samples)
assertEquals(expected.size, result.size)
expected.forEachIndexed { index, value ->
assertEquals(value, result[index], 1e-8)
}
val origin = realIFFT(result)
assertEquals(samples.size, origin.size)
samples.forEachIndexed { index, value ->
assertEquals(value, origin[index], 1e-8)
}
}
}

0 comments on commit 36bbfd0

Please sign in to comment.