Skip to content

Commit 4620033

Browse files
gt4py port of pnew,vnew,unew and also move the swap calculation to numpy format
1 parent ed7606a commit 4620033

File tree

1 file changed

+74
-24
lines changed

1 file changed

+74
-24
lines changed

swm_python/swm_cartesian.py

+74-24
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,12 @@ def live_plot3(fu, fv, fp, title=''):
146146
z_gt = gtx.as_field(domain,z,allocator=allocator)
147147
cu_gt = gtx.as_field(domain,cu,allocator=allocator)
148148
cv_gt = gtx.as_field(domain,cv,allocator=allocator)
149+
pnew_gt = gtx.as_field(domain,pnew,allocator=allocator)
150+
unew_gt = gtx.as_field(domain,unew,allocator=allocator)
151+
vnew_gt = gtx.as_field(domain,vnew,allocator=allocator)
152+
uold_gt = gtx.as_field(domain,uold,allocator=allocator)
153+
vold_gt = gtx.as_field(domain,vold,allocator=allocator)
154+
pold_gt = gtx.as_field(domain,pold,allocator=allocator)
149155

150156
cartesian_backend = "numpy"
151157
next_backend = gtx.itir_python
@@ -201,7 +207,51 @@ def calc_cv(
201207
cv: gtscript.Field[dtype]
202208
):
203209
with computation(PARALLEL), interval(...):
204-
cv = .5 * (p + p) * v
210+
cv = .5 * (p + p) * v
211+
212+
#recheck this section
213+
#pnew[i,j,0] = pold[i,j,0] - tdtsdx * (cu[i+1,j,0] - cu[i,j,0]) - tdtsdy * (cv[i,j+1,0] - cv[i,j,0])
214+
@gtscript.stencil(backend=cartesian_backend)
215+
def calc_pnew(
216+
tdtsdx: float,
217+
tdtsdy: float,
218+
pold: gtscript.Field[dtype],
219+
cu: gtscript.Field[dtype],
220+
cv: gtscript.Field[dtype],
221+
pnew: gtscript.Field[dtype]
222+
):
223+
with computation(PARALLEL), interval(...):
224+
pnew = pold - tdtsdx * (cu[1,0,0] - cu) - tdtsdy * (cv[0,1,0] - cv)
225+
226+
#unew[i+1,j,0] = uold[i+1,j,0] + tdts8 * (z[i+1,j+1,0] + z[i+1,j,0]) * (cv[i+1,j+1,0] + cv[i+1,j,0] + cv[i,j+1,0] + cv[i,j,0]) - tdtsdx * (h[i+1,j,0] - h[i,j,0])
227+
@gtscript.stencil(backend=cartesian_backend)
228+
def calc_unew(
229+
tdts8: float,
230+
tdtsdx: float,
231+
uold: gtscript.Field[dtype],
232+
cu: gtscript.Field[dtype],
233+
cv: gtscript.Field[dtype],
234+
z: gtscript.Field[dtype],
235+
h: gtscript.Field[dtype],
236+
unew: gtscript.Field[dtype]
237+
):
238+
with computation(PARALLEL), interval(...):
239+
unew = uold + tdts8 * (z + z[0,1,0]) * (cv[0,1,0] + cv + cv[-1,1,0] + cv[-1,0,0]) - tdtsdx * (h - h[-1,0,0])
240+
241+
#vnew[i,j+1,0] = vold[i,j+1,0] - tdts8 * (z[i+1,j+1,0] + z[i,j+1,0]) * (cu[i+1,j+1,0] + cu[i+1,j,0] + cu[i,j+1,0] + cu[i,j,0]) - tdtsdy * (h[i,j+1,0] - h[i,j,0])
242+
@gtscript.stencil(backend=cartesian_backend)
243+
def calc_vnew(
244+
tdts8: float,
245+
tdtsdy: float,
246+
vold: gtscript.Field[dtype],
247+
cu: gtscript.Field[dtype],
248+
cv: gtscript.Field[dtype],
249+
z: gtscript.Field[dtype],
250+
h: gtscript.Field[dtype],
251+
vnew: gtscript.Field[dtype]
252+
):
253+
with computation(PARALLEL), interval(...):
254+
vnew = vold - tdts8 * (z[1,0,0] + z) * (cu[1,0,0] + cu[1,-1,0] + cu + cu[0,-1,0]) - tdtsdy * (h - h[0,-1,0])
205255

206256
time = 0.0
207257
# Main time loop
@@ -215,10 +265,10 @@ def calc_cv(
215265
calc_z(fsdx=fsdx, fsdy=fsdy, u=u_gt, v=v_gt, p=p_gt, z=z_gt, origin=(1,1,0), domain=(nx,ny,nz)) # domain(nx+1,ny+1,nz) gives error why?
216266
z = z_gt.asnumpy()
217267

218-
calc_cu(u=u_gt, p=p_gt, cu=cu_gt, origin=(1,0,0), domain=(nx,ny+1,nz)) # domain(nx+1,ny+1,nz) gives error why? try removing ny+1
268+
calc_cu(u=u_gt, p=p_gt, cu=cu_gt, origin=(1,0,0), domain=(nx,ny,nz)) # (nx,ny+1,nz)-->works domain(nx+1,ny+1,nz) gives error why? try removing ny+1
219269
cu = cu_gt.asnumpy()
220270

221-
calc_cv(v=v_gt, p=p_gt, cv=cv_gt, origin=(0,1,0), domain=(nx+1,ny,nz)) # domain(nx+1,ny+1,nz) gives error why?
271+
calc_cv(v=v_gt, p=p_gt, cv=cv_gt, origin=(0,1,0), domain=(nx,ny,nz)) #(nx+1,ny,nz)--> works domain(nx+1,ny+1,nz) gives error why?
222272
cv = cv_gt.asnumpy()
223273

224274
# # Periodic Boundary conditions
@@ -244,14 +294,21 @@ def calc_cv(
244294
tdtsdy = tdt / dy
245295
#print(tdts8, tdtsdx, tdtsdy)
246296

247-
248-
for i in range(M):
249-
for j in range(N):
250-
unew[i+1,j,0] = uold[i+1,j,0] + tdts8 * (z[i+1,j+1,0] + z[i+1,j,0]) * (cv[i+1,j+1,0] + cv[i+1,j,0] + cv[i,j+1,0] + cv[i,j,0]) - tdtsdx * (h[i+1,j,0] - h[i,j,0])
251-
vnew[i,j+1,0] = vold[i,j+1,0] - tdts8 * (z[i+1,j+1,0] + z[i,j+1,0]) * (cu[i+1,j+1,0] + cu[i+1,j,0] + cu[i,j+1,0] + cu[i,j,0]) - tdtsdy * (h[i,j+1,0] - h[i,j,0])
252-
pnew[i,j,0] = pold[i,j,0] - tdtsdx * (cu[i+1,j,0] - cu[i,j,0]) - tdtsdy * (cv[i,j+1,0] - cv[i,j,0])
253-
297+
calc_unew(tdts8=tdts8, tdtsdx=tdtsdx, uold=uold_gt, cu=cu_gt, cv=cv_gt, z=z_gt, h=h_gt, unew=unew_gt, origin=(1,0,0), domain=(nx,ny,nz))
298+
unew = unew_gt.asnumpy()
299+
300+
calc_vnew(tdts8=tdts8, tdtsdy=tdtsdy, vold=vold_gt, cu=cu_gt, cv=cv_gt, z=z_gt, h=h_gt, vnew=vnew_gt, origin=(0,1,0), domain=(nx,ny,nz))
301+
vnew = vnew_gt.asnumpy()
254302

303+
calc_pnew(tdtsdx=tdtsdx, tdtsdy=tdtsdy, pold=pold_gt, cu=cu_gt, cv=cv_gt, pnew=pnew_gt, origin=(0,0,0), domain=(nx,ny,nz))
304+
pnew = pnew_gt.asnumpy()
305+
306+
# for i in range(M):
307+
# for j in range(N):
308+
# unew[i+1,j,0] = uold[i+1,j,0] + tdts8 * (z[i+1,j+1,0] + z[i+1,j,0]) * (cv[i+1,j+1,0] + cv[i+1,j,0] + cv[i,j+1,0] + cv[i,j,0]) - tdtsdx * (h[i+1,j,0] - h[i,j,0])
309+
# vnew[i,j+1,0] = vold[i,j+1,0] - tdts8 * (z[i+1,j+1,0] + z[i,j+1,0]) * (cu[i+1,j+1,0] + cu[i+1,j,0] + cu[i,j+1,0] + cu[i,j,0]) - tdtsdy * (h[i,j+1,0] - h[i,j,0])
310+
# pnew[i,j,0] = pold[i,j,0] - tdtsdx * (cu[i+1,j,0] - cu[i,j,0]) - tdtsdy * (cv[i,j+1,0] - cv[i,j,0])
311+
255312
# Periodic Boundary conditions
256313
unew[0, :,0] = unew[M, :,0]
257314
pnew[M, :,0] = pnew[0, :,0]
@@ -267,20 +324,13 @@ def calc_cv(
267324
time = time + dt
268325

269326
if(ncycle > 0):
270-
for i in range(M_LEN):
271-
for j in range(N_LEN):
272-
uoldtemp=uold[i,j,0]
273-
voldtemp=vold[i,j,0]
274-
poldtemp=pold[i,j,0]
275-
uold[i,j,0] = u[i,j,0] + alpha * (unew[i,j,0] - 2. * u[i,j,0] + uoldtemp)
276-
vold[i,j,0] = v[i,j,0] + alpha * (vnew[i,j,0] - 2. * v[i,j,0] + voldtemp)
277-
pold[i,j,0] = p[i,j,0] + alpha * (pnew[i,j,0] - 2. * p[i,j,0] + poldtemp)
278-
279-
for i in range(M_LEN):
280-
for j in range(N_LEN):
281-
u[i,j,0] = unew[i,j,0]
282-
v[i,j,0] = vnew[i,j,0]
283-
p[i,j,0] = pnew[i,j,0]
327+
uold[...] = u + alpha * (unew - 2 * u + uold)
328+
vold[...] = v + alpha * (vnew - 2 * v + vold)
329+
pold[...] = p + alpha * (pnew - 2 * p + pold)
330+
331+
u[...] = unew
332+
v[...] = vnew
333+
p[...] = pnew
284334

285335
else:
286336
tdt = tdt+tdt

0 commit comments

Comments
 (0)