1
1
import numpy as np
2
- import numpy as np
3
- import gt4py .next as gtx
4
2
import gt4py .cartesian .gtscript as gtscript
3
+ import gt4py .storage as gts
4
+ from gt4py .storage .cartesian import utils as gts_utils
5
5
from time import perf_counter
6
- # import cupy
7
- # import gt4py
8
6
import initial_conditions
9
7
import utils
10
8
import config
11
9
12
- I = gtx .Dimension ("I" )
13
- J = gtx .Dimension ("J" )
14
- K = gtx .Dimension ("K" , kind = gtx .DimensionKind .VERTICAL )
15
-
16
10
dtype = np .float64
17
11
18
12
cartesian_backend = config .backend
19
- allocator = gtx .gtfn_cpu
20
- if cartesian_backend in ("gt:gpu" , "cuda" , "dace:gpu" ):
21
- allocator = gtx .gtfn_gpu
22
13
23
- print (f"Using { cartesian_backend } backend with { allocator . __name__ } allocator ." )
14
+ print (f"Using { cartesian_backend } backend." )
24
15
25
16
@gtscript .stencil (backend = cartesian_backend )
26
- def calc_cucvzh (u : gtscript .Field [dtype ], v : gtscript .Field [dtype ], p : gtscript .Field [dtype ], cu : gtscript .Field [dtype ], cv : gtscript .Field [dtype ], z : gtscript .Field [dtype ], h : gtscript .Field [dtype ], fsdx :float , fsdy : float ):
17
+ def calc_cucvzh (u : gtscript .Field [dtype ],
18
+ v : gtscript .Field [dtype ],
19
+ p : gtscript .Field [dtype ],
20
+ cu : gtscript .Field [dtype ],
21
+ cv : gtscript .Field [dtype ],
22
+ z : gtscript .Field [dtype ],
23
+ h : gtscript .Field [dtype ],
24
+ fsdx :float ,
25
+ fsdy : float ):
26
+
27
27
with computation (PARALLEL ), interval (...):
28
28
cu = .5 * (p [1 ,0 ,0 ] + p ) * u
29
29
cv = .5 * (p [0 ,1 ,0 ] + p ) * v
@@ -44,8 +44,8 @@ def calc_uvp(
44
44
h : gtscript .Field [dtype ],
45
45
unew : gtscript .Field [dtype ],
46
46
vnew : gtscript .Field [dtype ],
47
- pnew : gtscript .Field [dtype ]
48
- ):
47
+ pnew : gtscript .Field [dtype ]):
48
+
49
49
with computation (PARALLEL ), interval (...):
50
50
unew = uold + tdts8 * (z + z [0 ,- 1 ,0 ]) * (cv [1 ,0 ,0 ] + cv + cv [0 ,- 1 ,0 ] + cv [1 ,- 1 ,0 ]) - tdtsdx * (h [1 ,0 ,0 ] - h )
51
51
vnew = vold - tdts8 * (z + z [- 1 ,0 ,0 ]) * (cu [- 1 ,0 ,0 ] + cu [- 1 ,1 ,0 ] + cu + cu [0 ,1 ,0 ]) - tdtsdy * (h [0 ,1 ,0 ] - h )
@@ -62,16 +62,22 @@ def calc_uvp_old(
62
62
uold : gtscript .Field [dtype ],
63
63
p : gtscript .Field [dtype ],
64
64
pnew : gtscript .Field [dtype ],
65
- pold : gtscript .Field [dtype ],
66
- ):
65
+ pold : gtscript .Field [dtype ]):
66
+
67
67
with computation (PARALLEL ), interval (...):
68
68
uold = u + alpha * (unew - 2 * u + uold )
69
69
vold = v + alpha * (vnew - 2 * v + vold )
70
70
pold = p + alpha * (pnew - 2 * p + pold )
71
71
72
72
73
73
@gtscript .stencil (backend = cartesian_backend )
74
- def copy_3var (inp0 : gtscript .Field [dtype ], inp1 : gtscript .Field [dtype ], inp2 : gtscript .Field [dtype ], out0 : gtscript .Field [dtype ], out1 : gtscript .Field [dtype ], out2 : gtscript .Field [dtype ]):
74
+ def copy_3var (inp0 : gtscript .Field [dtype ],
75
+ inp1 : gtscript .Field [dtype ],
76
+ inp2 : gtscript .Field [dtype ],
77
+ out0 : gtscript .Field [dtype ],
78
+ out1 : gtscript .Field [dtype ],
79
+ out2 : gtscript .Field [dtype ]):
80
+
75
81
with computation (PARALLEL ), interval (...):
76
82
out0 = inp0
77
83
out1 = inp1
@@ -86,35 +92,28 @@ def main():
86
92
dt25 = 0.
87
93
dt3 = 0.
88
94
89
-
90
-
91
- M_LEN = config .M_LEN
92
- N_LEN = config .N_LEN
93
95
M = config .M
94
96
N = config .N
95
-
96
97
97
98
_u , _v , _p = initial_conditions .initialize (M , N , config .dx , config .dy , config .a )
98
99
_u = _u [:,:,np .newaxis ]
99
100
_v = _v [:,:,np .newaxis ]
100
101
_p = _p [:,:,np .newaxis ]
101
102
102
- domain = gtx .domain ({I :M + 1 , J :N + 1 , K :1 })
103
-
104
- h_gt = gtx .empty (domain ,dtype = dtype ,allocator = allocator )
105
- z_gt = gtx .empty (domain ,dtype = dtype ,allocator = allocator )
106
- cu_gt = gtx .empty (domain ,dtype = dtype ,allocator = allocator )
107
- cv_gt = gtx .empty (domain ,dtype = dtype ,allocator = allocator )
108
- pnew_gt = gtx .empty (domain ,dtype = dtype ,allocator = allocator )
109
- unew_gt = gtx .empty (domain ,dtype = dtype ,allocator = allocator )
110
- vnew_gt = gtx .empty (domain ,dtype = dtype ,allocator = allocator )
111
- uold_gt = gtx .empty (domain ,dtype = dtype ,allocator = allocator )
112
- vold_gt = gtx .empty (domain ,dtype = dtype ,allocator = allocator )
113
- pold_gt = gtx .empty (domain ,dtype = dtype ,allocator = allocator )
114
-
115
- u_gt = gtx .as_field (domain ,_u ,allocator = allocator )
116
- v_gt = gtx .as_field (domain ,_v ,allocator = allocator )
117
- p_gt = gtx .as_field (domain ,_p ,allocator = allocator )
103
+ shape = (M + 1 ,N + 1 ,1 )
104
+ h_gt = gts .empty (dtype = dtype ,backend = cartesian_backend ,shape = shape )
105
+ z_gt = gts .empty (dtype = dtype ,backend = cartesian_backend ,shape = shape )
106
+ cu_gt = gts .empty (dtype = dtype ,backend = cartesian_backend ,shape = shape )
107
+ cv_gt = gts .empty (dtype = dtype ,backend = cartesian_backend ,shape = shape )
108
+ pnew_gt = gts .empty (dtype = dtype ,backend = cartesian_backend ,shape = shape )
109
+ unew_gt = gts .empty (dtype = dtype ,backend = cartesian_backend ,shape = shape )
110
+ vnew_gt = gts .empty (dtype = dtype ,backend = cartesian_backend ,shape = shape )
111
+ uold_gt = gts .empty (dtype = dtype ,backend = cartesian_backend ,shape = shape )
112
+ vold_gt = gts .empty (dtype = dtype ,backend = cartesian_backend ,shape = shape )
113
+ pold_gt = gts .empty (dtype = dtype ,backend = cartesian_backend ,shape = shape )
114
+ u_gt = gts .from_array (_u ,dtype = dtype ,backend = cartesian_backend )
115
+ v_gt = gts .from_array (_v ,dtype = dtype ,backend = cartesian_backend )
116
+ p_gt = gts .from_array (_p ,dtype = dtype ,backend = cartesian_backend )
118
117
119
118
# Save initial conditions
120
119
uold_gt [...] = u_gt [...]
@@ -129,15 +128,10 @@ def main():
129
128
print (" grid spacing in the y direction: " , config .dy )
130
129
print (" time step: " , config .dt )
131
130
print (" time filter coefficient: " , config .alpha )
132
- print (" Initial p:\n " , p_gt .asnumpy ()[:,:,0 ].diagonal ()[:- 1 ])
133
- print (" Initial u:\n " , u_gt .asnumpy ()[:,:,0 ].diagonal ()[:- 1 ])
134
- print (" Initial v:\n " , v_gt .asnumpy ()[:,:,0 ].diagonal ()[:- 1 ])
135
-
136
-
137
- u_gt = gtx .as_field (domain ,u ,allocator = allocator )
138
- p_gt = gtx .as_field (domain ,p ,allocator = allocator )
139
- v_gt = gtx .as_field (domain ,v ,allocator = allocator )
140
-
131
+
132
+ print (" Initial p:\n " , p_gt [:,:,0 ].diagonal ()[:- 1 ])
133
+ print (" Initial u:\n " , u_gt [:,:,0 ].diagonal ()[:- 1 ])
134
+ print (" Initial v:\n " , v_gt [:,:,0 ].diagonal ()[:- 1 ])
141
135
142
136
t0_start = perf_counter ()
143
137
time = 0.0
@@ -148,13 +142,13 @@ def main():
148
142
p_origin = (0 ,0 ,0 )
149
143
z_origin = (1 ,1 ,0 )
150
144
# Main time loop
151
- for ncycle in range (ITMAX ):
145
+ for ncycle in range (config . ITMAX ):
152
146
153
147
if ((ncycle % 100 == 0 ) & (config .VIS == False )):
154
148
print (f"cycle number{ ncycle } and gt4py type cartesian" )
155
149
156
150
if config .VAL_DEEP and ncycle <= 3 :
157
- utils .validate_uvp (u_gt . asnumpy ( ), v_gt . asnumpy ( ), p_gt . asnumpy ( ), M , N , ncycle , 'init' )
151
+ utils .validate_uvp (gts_utils . cpu_copy ( u_gt ), gts_utils . cpu_copy ( v_gt ), gts_utils . cpu_copy ( p_gt ), M , N , ncycle , 'init' )
158
152
159
153
t1_start = perf_counter ()
160
154
@@ -200,7 +194,7 @@ def main():
200
194
dt15 = dt15 + (t15_stop - t15_start )
201
195
202
196
if config .VAL_DEEP and ncycle <= 1 :
203
- utils .validate_cucvzh (cu_gt . asnumpy ( ), cv_gt . asnumpy ( ), z_gt . asnumpy ( ), h_gt . asnumpy ( ), M , N , ncycle , 't100' )
197
+ utils .validate_cucvzh (gts_utils . cpu_copy ( cu_gt ), gts_utils . cpu_copy ( cv_gt ), gts_utils . cpu_copy ( z_gt ), gts_utils . cpu_copy ( h_gt ), M , N , ncycle , 't100' )
204
198
205
199
# Calclulate new values of u,v, and p
206
200
tdts8 = tdt / 8.
@@ -260,7 +254,7 @@ def main():
260
254
261
255
262
256
if config .VAL_DEEP and ncycle <= 1 :
263
- utils .validate_uvp (unew_gt . asnumpy ( ), vnew_gt . asnumpy ( ), pnew_gt . asnumpy ( ), M , N , ncycle , 't200' )
257
+ utils .validate_uvp (gts_utils . cpu_copy ( unew_gt ), gts_utils . cpu_copy ( vnew_gt ), gts_utils . cpu_copy ( pnew_gt ), M , N , ncycle , 't200' )
264
258
265
259
time = time + config .dt
266
260
@@ -272,14 +266,6 @@ def main():
272
266
t3_stop = perf_counter ()
273
267
dt3 = dt3 + (t3_stop - t3_start )
274
268
275
- t35_start = perf_counter ()
276
- # u = u_gt.asnumpy()
277
- # v = v_gt.asnumpy()
278
- # p = p_gt.asnumpy()
279
- t35_stop = perf_counter ()
280
- dt35 = dt35 + (t35_stop - t35_start )
281
-
282
-
283
269
else :
284
270
tdt = tdt + tdt
285
271
@@ -290,28 +276,26 @@ def main():
290
276
v_gt [...] = vnew_gt [...]
291
277
p_gt [...] = pnew_gt [...]
292
278
293
- if ((config .VIS == True ) & (ncycle % config .VIS_DT == 0 )):
294
- utils .live_plot3 (u_gt . asnumpy ( ), v_gt . asnumpy ( ), p_gt . asnumpy ( ), "ncycle: " + str (ncycle ))
279
+ if ((config .VIS ) & (ncycle % config .VIS_DT == 0 )):
280
+ utils .live_plot3 (gts_utils . cpu_copy ( u_gt ), gts_utils . cpu_copy ( v_gt ), gts_utils . cpu_copy ( p_gt ), "ncycle: " + str (ncycle ))
295
281
296
282
t0_stop = perf_counter ()
297
283
dt0 = dt0 + (t0_stop - t0_start )
298
284
# Print initial conditions
299
285
if config .L_OUT :
300
- print ("cycle number " , ITMAX )
301
- print (" diagonal elements of p:\n " , pnew_gt . asnumpy () [:,:,0 ].diagonal ()[:- 1 ])
302
- print (" diagonal elements of u:\n " , unew_gt . asnumpy () [:,:,0 ].diagonal ()[:- 1 ])
303
- print (" diagonal elements of v:\n " , vnew_gt . asnumpy () [:,:,0 ].diagonal ()[:- 1 ])
286
+ print ("cycle number " , config . ITMAX )
287
+ print (" diagonal elements of p:\n " , p_gt [:,:,0 ].diagonal ()[:- 1 ])
288
+ print (" diagonal elements of u:\n " , u_gt [:,:,0 ].diagonal ()[:- 1 ])
289
+ print (" diagonal elements of v:\n " , v_gt [:,:,0 ].diagonal ()[:- 1 ])
304
290
print ("total: " ,dt0 )
305
- print ("t050: " ,dt05 )
306
291
print ("t100: " ,dt1 )
307
292
print ("t150: " ,dt15 )
308
293
print ("t200: " ,dt2 )
309
294
print ("t250: " ,dt25 )
310
295
print ("t300: " ,dt3 )
311
- print ("t350: " ,dt35 )
312
296
313
297
if config .VAL :
314
- utils .final_validation (u_gt . asnumpy ( ), v_gt . asnumpy ( ), p_gt . asnumpy ( ), ITMAX = ITMAX , M = M , N = N )
298
+ utils .final_validation (gts_utils . cpu_copy ( u_gt ), gts_utils . cpu_copy ( v_gt ), gts_utils . cpu_copy ( p_gt ), ITMAX = config . ITMAX , M = M , N = N )
315
299
316
300
if __name__ == "__main__" :
317
301
main ()
0 commit comments