-
Notifications
You must be signed in to change notification settings - Fork 45
/
stencil2d-gt4py-v0.py
158 lines (131 loc) · 4.29 KB
/
stencil2d-gt4py-v0.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# ******************************************************
# Program: stencil2d-gt4py
# Author: Stefano Ubbiali
# Email: subbiali@phys.ethz.ch
# Date: 04.06.2020
# Description: GT4Py implementation of 4th-order diffusion
# ******************************************************
import click
import gt4py as gt
from gt4py import gtscript
import matplotlib.pyplot as plt
import numpy as np
import time
# Five-point horizontal Laplacian of `field`, written as a GTScript function
# so it can be called from a stencil definition (see the `externals` passed to
# `gtscript.stencil` in `main`).
#
# `field[di, dj, dk]` reads the value at a relative offset from the current
# grid point, so the expression below combines the centre value with its four
# neighbours along the first two axes; the third axis is not coupled.
#
# NOTE: documented with comments rather than a docstring on purpose, so that
# the function body contains nothing but GTScript statements.
@gtscript.function
def laplacian(field):
    lap_field = (
        -4.0 * field[0, 0, 0]
        + field[-1, 0, 0]
        + field[1, 0, 0]
        + field[0, -1, 0]
        + field[0, 1, 0]
    )
    return lap_field
def diffusion_defs(
    in_field: gtscript.Field[float], out_field: gtscript.Field[float], *, alpha: float
):
    # GTScript definition of one 4th-order diffusion step:
    #
    #     out_field = in_field - alpha * laplacian(laplacian(in_field))
    #
    # in_field  : field to be diffused (read at horizontal offsets, so the
    #             caller must provide valid halo values around the domain)
    # out_field : field receiving the result
    # alpha     : diffusion coefficient (run-time scalar)
    #
    # Comments are used instead of a docstring so that the definition body
    # contains nothing but GTScript statements.

    # `laplacian` is injected at compile time through the `externals`
    # dictionary given to `gtscript.stencil`.
    from __externals__ import laplacian
    from __gtscript__ import PARALLEL, computation, interval

    # No vertical dependency: every level can be processed independently,
    # over the whole vertical extent.
    with computation(PARALLEL), interval(...):
        lap1 = laplacian(in_field)
        lap2 = laplacian(lap1)
        out_field = in_field - alpha * lap2
def update_halo(field, num_halo):
    """Fill the halo zone of `field` with periodic boundary values, in place.

    The first two axes of `field` are treated as the horizontal directions;
    each carries `num_halo` halo points on either side of the interior. Any
    further axes are copied through unchanged.

    Parameters
    ----------
    field : array-like supporting NumPy-style slice assignment
        Field whose halo points are overwritten. The interior must be at
        least `num_halo` points wide along each of the first two axes.
    num_halo : int
        Number of halo points on each side. `0` is a no-op.

    Raises
    ------
    ValueError
        If `num_halo` is negative.
    """
    if num_halo < 0:
        raise ValueError("num_halo must be non-negative")
    if num_halo == 0:
        # Nothing to fill. Also avoids the slices below degenerating,
        # since `-0` is just `0` and `field[0:-0]` would be empty.
        return

    h = num_halo

    # Edges along the second axis, interior rows only (no corners yet):
    # the low halo mirrors the last interior points, the high halo the first.
    field[h:-h, :h] = field[h:-h, -2 * h : -h]
    field[h:-h, -h:] = field[h:-h, h : 2 * h]

    # Edges along the first axis, full rows. The rows copied here already
    # contain the halo values written above, so the corners are filled too.
    field[:h, :] = field[-2 * h : -h, :]
    field[-h:, :] = field[h : 2 * h, :]
def apply_diffusion(
    diffusion_stencil, in_field, out_field, alpha, num_halo, num_iter=1
):
    """Run `num_iter` diffusion steps; the result ends up in `out_field`.

    Parameters
    ----------
    diffusion_stencil : callable
        Compiled stencil, called with the keyword arguments `in_field`,
        `out_field`, `alpha`, `origin` and `domain`.
    in_field : array-like
        Input field including halo points. Used as a work buffer: its
        contents are overwritten.
    out_field : array-like
        Output field with the same shape as `in_field`. Holds the final
        state, with updated halo, on return.
    alpha : float
        Diffusion coefficient passed through to the stencil.
    num_halo : int
        Number of halo points on each side along the first two axes.
    num_iter : int, optional
        Number of diffusion steps. Nothing happens if it is not positive.
    """
    # origin and extent of the computational domain: the interior of the
    # field along the first two axes, every level along the third
    origin = (num_halo, num_halo, 0)
    domain = (
        in_field.shape[0] - 2 * num_halo,
        in_field.shape[1] - 2 * num_halo,
        in_field.shape[2],
    )

    # The swap below only rebinds the local names, so remember which buffer
    # the caller actually expects the result in.
    result_field = out_field

    for n in range(num_iter):
        # the stencil reads neighbours of the interior points, so the halo
        # of the input must be valid before every step
        update_halo(in_field, num_halo)

        # run the stencil
        diffusion_stencil(
            in_field=in_field,
            out_field=out_field,
            alpha=alpha,
            origin=origin,
            domain=domain,
        )

        if n < num_iter - 1:
            # output of this step is the input of the next one
            in_field, out_field = out_field, in_field
        else:
            update_halo(out_field, num_halo)

    # After an odd number of swaps (even `num_iter`) the last step wrote into
    # the caller's *input* buffer. Copy it over so that the caller's
    # `out_field` always holds the final state.
    # NOTE(review): assumes the fields support NumPy-style `[...]` assignment
    # on every backend in use - confirm for GPU storages.
    if out_field is not result_field:
        result_field[...] = out_field
@click.command()
@click.option(
    "--nx", type=int, required=True, help="Number of gridpoints in x-direction"
)
@click.option(
    "--ny", type=int, required=True, help="Number of gridpoints in y-direction"
)
@click.option(
    "--nz", type=int, required=True, help="Number of gridpoints in z-direction"
)
@click.option("--num_iter", type=int, required=True, help="Number of iterations")
@click.option(
    "--num_halo",
    type=int,
    default=2,
    help="Number of halo-points in x- and y-direction",
)
@click.option(
    "--backend", type=str, required=False, default="numpy", help="GT4Py backend."
)
@click.option(
    "--plot_result", type=bool, default=False, help="Make a plot of the result?"
)
def main(nx, ny, nz, num_iter, num_halo=2, backend="numpy", plot_result=False):
    """Driver for apply_diffusion that sets up fields and does timings."""
    # NOTE: the docstring above is what click prints for `--help`, so further
    # documentation lives in comments.
    #
    # Steps: validate the options, allocate and initialise the fields, save
    # (and optionally plot) the input, compile the stencil, run one warm-up
    # step, time `num_iter` steps, then save (and optionally plot) the output.

    # Validate with explicit exceptions rather than `assert`, so the checks
    # are still performed when Python runs with `-O`.
    if not 0 < nx <= 1024 * 1024:
        raise click.BadParameter(
            "You have to specify a reasonable value for nx", param_hint="--nx"
        )
    if not 0 < ny <= 1024 * 1024:
        raise click.BadParameter(
            "You have to specify a reasonable value for ny", param_hint="--ny"
        )
    if not 0 < nz <= 1024:
        raise click.BadParameter(
            "You have to specify a reasonable value for nz", param_hint="--nz"
        )
    if not 0 < num_iter <= 1024 * 1024:
        raise click.BadParameter(
            "You have to specify a reasonable value for num_iter",
            param_hint="--num_iter",
        )
    if not 2 <= num_halo <= 256:
        raise click.BadParameter(
            "You have to specify a reasonable number of halo points",
            param_hint="--num_halo",
        )
    if backend not in (
        "numpy",
        "gt:cpu_ifirst",
        "gt:cpu_kfirst",
        "gt:gpu",
        "cuda",
    ):
        raise click.BadParameter(
            "You have to specify a reasonable value for backend",
            param_hint="--backend",
        )

    alpha = 1.0 / 32.0

    # default origin
    dorigin = (num_halo, num_halo, 0)

    # allocate input and output fields (interior plus halo in x and y)
    # NOTE(review): uses the `zeros(backend, default_origin, shape, dtype)`
    # storage signature implied by `dorigin` - confirm against the installed
    # GT4Py version.
    shape = (nx + 2 * num_halo, ny + 2 * num_halo, nz)
    in_field = gt.storage.zeros(backend, dorigin, shape, dtype=float)
    out_field = gt.storage.zeros(backend, dorigin, shape, dtype=float)

    # prepare input field
    in_field[
        num_halo + nx // 4 : num_halo + 3 * nx // 4,
        num_halo + ny // 4 : num_halo + 3 * ny // 4,
        nz // 4 : 3 * nz // 4,
    ] = 1.0

    # Level shown in the plots. Level 0 lies outside the initial blob as soon
    # as nz >= 4, which would give blank images; this index falls inside
    # [nz // 4, 3 * nz // 4) for every nz >= 2.
    k_plot = (nz - 1) // 2

    # write input field to file
    # swap first and last axes for compatibility with day1/stencil2d.py
    np.save("in_field", np.swapaxes(in_field, 0, 2))

    if plot_result:
        # plot initial field
        plt.ioff()
        plt.imshow(in_field[:, :, k_plot], origin="lower")
        plt.colorbar()
        plt.savefig("in_field.png")
        plt.close()

    # compile diffusion stencil
    # Verbose build output for the compiled GridTools backends. The names
    # must match the backend identifiers accepted above, otherwise the
    # option is never enabled.
    # NOTE(review): assumes these backends accept the `verbose` build
    # option - confirm against the installed GT4Py version.
    kwargs = (
        {"verbose": True}
        if backend in ("gt:cpu_ifirst", "gt:cpu_kfirst", "gt:gpu")
        else {}
    )
    diffusion_stencil = gtscript.stencil(
        definition=diffusion_defs,
        backend=backend,
        externals={"laplacian": laplacian},
        rebuild=False,
        **kwargs,
    )

    # warmup caches
    apply_diffusion(diffusion_stencil, in_field, out_field, alpha, num_halo)

    # time the actual work (perf_counter is the clock meant for intervals)
    tic = time.perf_counter()
    apply_diffusion(
        diffusion_stencil, in_field, out_field, alpha, num_halo, num_iter=num_iter
    )
    toc = time.perf_counter()
    print(f"Elapsed time for work = {toc - tic} s")

    # save output field
    # swap first and last axes for compatibility with day1/stencil2d.py
    np.save("out_field", np.swapaxes(out_field, 0, 2))

    if plot_result:
        # plot the output field
        plt.ioff()
        plt.imshow(out_field[:, :, k_plot], origin="lower")
        plt.colorbar()
        plt.savefig("out_field.png")
        plt.close()


# Run the command-line driver only when executed as a script.
if __name__ == "__main__":
    main()