Skip to content

Commit 582d864

Browse files
Moving more loops to gt4py
1 parent 4620033 commit 582d864

File tree

1 file changed

+39
-3
lines changed

1 file changed

+39
-3
lines changed

swm_python/swm_cartesian.py

+39-3
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,36 @@ def calc_vnew(
252252
):
253253
with computation(PARALLEL), interval(...):
254254
vnew = vold - tdts8 * (z[1,0,0] + z) * (cu[1,0,0] + cu[1,-1,0] + cu + cu[0,-1,0]) - tdtsdy * (h - h[0,-1,0])
255+
256+
@gtscript.stencil(backend=cartesian_backend)
257+
def calc_pold(
258+
p: gtscript.Field[dtype],
259+
alpha: float,
260+
pnew: gtscript.Field[dtype],
261+
pold: gtscript.Field[dtype]
262+
):
263+
with computation(PARALLEL), interval(...):
264+
pold = p + alpha * (pnew - 2 * p + pold)
265+
266+
@gtscript.stencil(backend=cartesian_backend)
267+
def calc_uold(
268+
u: gtscript.Field[dtype],
269+
alpha: float,
270+
unew: gtscript.Field[dtype],
271+
uold: gtscript.Field[dtype]
272+
):
273+
with computation(PARALLEL), interval(...):
274+
uold[...] = u + alpha * (unew - 2 * u + uold)
275+
276+
@gtscript.stencil(backend=cartesian_backend)
277+
def calc_vold(
278+
v: gtscript.Field[dtype],
279+
alpha: float,
280+
vnew: gtscript.Field[dtype],
281+
vold: gtscript.Field[dtype]
282+
):
283+
with computation(PARALLEL), interval(...):
284+
vold[...] = v + alpha * (vnew - 2 * v + vold)
255285

256286
time = 0.0
257287
# Main time loop
@@ -324,9 +354,15 @@ def calc_vnew(
324354
time = time + dt
325355

326356
if(ncycle > 0):
327-
uold[...] = u + alpha * (unew - 2 * u + uold)
328-
vold[...] = v + alpha * (vnew - 2 * v + vold)
329-
pold[...] = p + alpha * (pnew - 2 * p + pold)
357+
calc_pold(p=p_gt, alpha=alpha, pnew=pnew_gt, pold=pold_gt, origin=(0,0,0), domain=(nx,ny,nz))
358+
pold = pold_gt.asnumpy()
359+
calc_uold(u=u_gt, alpha=alpha, unew=unew_gt, uold=uold_gt, origin=(0,0,0), domain=(nx,ny,nz))
360+
uold = uold_gt.asnumpy()
361+
calc_vold(v=v_gt, alpha=alpha, vnew=vnew_gt, vold=vold_gt, origin=(0,0,0), domain=(nx,ny,nz))
362+
vold = vold_gt.asnumpy()
363+
#uold[...] = u + alpha * (unew - 2 * u + uold)
364+
#vold[...] = v + alpha * (vnew - 2 * v + vold)
365+
#pold[...] = p + alpha * (pnew - 2 * p + pold)
330366

331367
u[...] = unew
332368
v[...] = vnew

0 commit comments

Comments
 (0)