-
Notifications
You must be signed in to change notification settings - Fork 2
/
kernels.cu
114 lines (86 loc) · 2.04 KB
/
kernels.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#ifndef __KERNELS_CU__
#define __KERNELS_CU__
#include "FFTGPU_Kernels.cuh"
/**
 * Cut predicate used by the low-pass filter.
 *
 * @param centerX  horizontal offset of the bin from the spectrum center
 * @param centerY  vertical offset of the bin from the spectrum center
 * @param sqradius squared cut-off radius
 * @return true when the bin lies on or outside the radius (squared distance
 *         >= sqradius), i.e. the bin should be zeroed so only the inner disc survives.
 */
__device__ bool lowPassK(int centerX, int centerY, size_t sqradius) {
    // Square in 64-bit: the int sum overflows (UB) once both offsets approach 2^15.
    const long long dx = centerX;
    const long long dy = centerY;
    const unsigned long long sqdist = static_cast<unsigned long long>(dx * dx + dy * dy);
    return sqdist >= sqradius;
}
/**
 * Cut predicate used by the high-pass filter.
 *
 * @param centerX  horizontal offset of the bin from the spectrum center
 * @param centerY  vertical offset of the bin from the spectrum center
 * @param sqradius squared cut-off radius
 * @return true when the bin lies strictly inside the radius (squared distance
 *         < sqradius), i.e. the bin should be zeroed so only the outer region survives.
 */
__device__ bool highPassK(int centerX, int centerY, size_t sqradius) {
    // Square in 64-bit: the int sum overflows (UB) once both offsets approach 2^15.
    const long long dx = centerX;
    const long long dy = centerY;
    const unsigned long long sqdist = static_cast<unsigned long long>(dx * dx + dy * dy);
    return sqdist < sqradius;
}
/**
 * Zeroes the spectrum bin owned by the calling thread when `func` says so.
 *
 * Each thread handles one element at (x, y) of a row-major array whose row
 * stride is fftW. Bin coordinates are measured relative to (fftW >> 1, fftH >> 1).
 * NOTE(review): this presumably expects a centered spectrum (e.g. after fftShift);
 * confirm against the caller.
 *
 * Launch: 2-D grid whose x/y extents cover at least fftW x fftH threads; any
 * surplus threads from rounding the grid up are discarded by the bounds guard.
 *
 * @param dData    device array of fftH * fftW complex bins, modified in place
 * @param fftH     spectrum height (rows)
 * @param fftW     spectrum width (columns, also the row stride)
 * @param imageW   unused; kept for interface compatibility
 * @param imageH   unused; kept for interface compatibility
 * @param sqradius squared cut-off radius forwarded to `func`
 * @param func     predicate (offsetX, offsetY, sqradius) -> true if the bin is cut
 */
__device__ void cutFrequencies(
    fComplex *dData,
    int fftH,
    int fftW,
    int imageW,
    int imageH,
    const size_t sqradius,
    bool(*func)(int, int, size_t)
)
{
    (void)imageW;
    (void)imageH;

    const int y = blockDim.y * blockIdx.y + threadIdx.y;
    const int x = blockDim.x * blockIdx.x + threadIdx.x;

    // The grid is normally rounded up to whole blocks; without this guard the
    // tail threads would write past the end of dData.
    if (x >= fftW || y >= fftH) {
        return;
    }

    const int centerX = fftW >> 1;
    const int centerY = fftH >> 1;
    const int coorY = y - centerY;
    const int coorX = x - centerX;

    if (func(coorX, coorY, sqradius)) {
        const int ind = y * fftW + x;
        dData[ind].x = 0.0f;
        dData[ind].y = 0.0f;
    }
}
/**
 * Low-pass filter kernel: zeroes every bin whose squared distance from
 * (fftW >> 1, fftH >> 1) is >= sqradius, leaving only the inner disc.
 *
 * Operates in place on a row-major fftH x fftW array (row stride fftW).
 * Launch with a 2-D grid whose x/y extents cover fftW x fftH; one thread per bin.
 * imageW / imageH are forwarded unchanged to the shared cutting routine.
 */
__global__ void lowPassFilter(
    fComplex *dData,
    int fftH,
    int fftW,
    int imageW,
    int imageH,
    const size_t sqradius
)
{
    // lowPassK answers "cut this bin?" with true for bins outside the radius.
    bool (*const cutOutsideRadius)(int, int, size_t) = lowPassK;
    cutFrequencies(dData, fftH, fftW, imageW, imageH, sqradius, cutOutsideRadius);
}
/**
 * High-pass filter kernel: zeroes every bin whose squared distance from
 * (fftW >> 1, fftH >> 1) is < sqradius, leaving only the region outside the disc.
 *
 * Operates in place on a row-major fftH x fftW array (row stride fftW).
 * Launch with a 2-D grid whose x/y extents cover fftW x fftH; one thread per bin.
 * imageW / imageH are forwarded unchanged to the shared cutting routine.
 */
__global__ void highPassFilter(
    fComplex *dData,
    int fftH,
    int fftW,
    int imageW,
    int imageH,
    const size_t sqradius
)
{
    // highPassK answers "cut this bin?" with true for bins inside the radius.
    bool (*const cutInsideRadius)(int, int, size_t) = highPassK;
    cutFrequencies(dData, fftH, fftW, imageW, imageH, sqradius, cutInsideRadius);
}
/**
 * In-place quadrant swap of a row-major fftH x fftW complex array (row stride fftW):
 * top-left <-> bottom-right and top-right <-> bottom-left.
 *
 * Only threads with row < fftH / 2 and col < fftW do work; each performs exactly
 * one element swap with its partner in the bottom half. Launch with a 2-D grid
 * covering at least fftW x (fftH / 2) threads.
 *
 * NOTE(review): correct only for even fftW and fftH. With odd fftW, columns 0 and
 * fftW-1 both pick partner column fftW/2 in the same row (a data race); with odd
 * fftH, the last row is never touched.
 */
__global__ void fftShift(
    fComplex *dData,
    int fftH,
    int fftW
)
{
    const int halfW = fftW / 2;
    const int halfH = fftH / 2;
    const int col = blockIdx.x * blockDim.x + threadIdx.x;
    const int row = blockIdx.y * blockDim.y + threadIdx.y;

    // Bottom-half and out-of-range threads have nothing to do.
    if (col >= fftW || row >= halfH) {
        return;
    }

    // Left half moves right, right half moves left; always into the bottom half.
    const int partnerCol = (col < halfW) ? col + halfW : col - halfW;
    const int partnerRow = row + halfH;

    const int mine = row * fftW + col;
    const int theirs = partnerRow * fftW + partnerCol;

    const float savedRe = dData[mine].x;
    const float savedIm = dData[mine].y;
    dData[mine].x = dData[theirs].x;
    dData[mine].y = dData[theirs].y;
    dData[theirs].x = savedRe;
    dData[theirs].y = savedIm;
}
#endif