Skip to content

Commit ed7606a

Browse files
Moving the gt4py stencil operation functions def to outside the loop
1 parent fe4b93a commit ed7606a

File tree

1 file changed

+54
-70
lines changed

1 file changed

+54
-70
lines changed

swm_python/swm_cartesian.py

+54-70
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,6 @@ def live_plot3(fu, fv, fp, title=''):
109109
vold = np.copy(v[...])
110110
pold = np.copy(p[...])
111111

112-
113-
# In[6]:
114-
115-
116112
# Print initial conditions
117113
if L_OUT:
118114
print(" Number of points in the x direction: ", M)
@@ -155,89 +151,77 @@ def live_plot3(fu, fv, fp, title=''):
155151
next_backend = gtx.itir_python
156152

157153
if gt4py_type == "cartesian":
154+
#for i in range(M):
155+
# for j in range(N):
156+
# h[i, j,0] = p[i, j,0] + 0.25 * (u[i + 1, j,0] * u[i + 1, j,0] + u[i, j,0] * u[i, j,0] +
157+
# v[i, j + 1,0] * v[i, j + 1,0] + v[i, j,0] * v[i, j,0])
158+
# i --> 0,M
159+
#j --> 0,N
160+
# at nx+1 its the boundary region, mask one region and
161+
@gtscript.stencil(backend=cartesian_backend)
162+
def calc_h(
163+
p: gtscript.Field[dtype],
164+
u: gtscript.Field[dtype],
165+
v: gtscript.Field[dtype],
166+
h: gtscript.Field[dtype]
167+
):
168+
with computation(PARALLEL), interval(...):
169+
h = p + 0.25 * u[1,0,0] * u[1,0,0] + u * u + v[0,1,0] * v[0,1,0] + v * v
170+
171+
#nx = M
172+
#ny = N
173+
#nz = 1
174+
# i --> 1,M+1 (1,1,0) (nx,ny,nz)
175+
#j --> 1,N+1
176+
@gtscript.stencil(backend=cartesian_backend)
177+
def calc_z(
178+
fsdx: float,
179+
fsdy: float,
180+
u: gtscript.Field[dtype],
181+
v: gtscript.Field[dtype],
182+
p: gtscript.Field[dtype],
183+
z: gtscript.Field[dtype]
184+
):
185+
with computation(PARALLEL), interval(...):
186+
z = (fsdx * (v - v[-1,0,0]) - fsdy * (u - u[0,-1,0])) / (p[-1,-1,0] + p[0,-1,0] + p + p[-1,0,0])
187+
188+
@gtscript.stencil(backend=cartesian_backend)
189+
def calc_cu(
190+
u: gtscript.Field[dtype],
191+
p: gtscript.Field[dtype],
192+
cu: gtscript.Field[dtype]
193+
):
194+
with computation(PARALLEL), interval(...):
195+
cu = .5 * (p + p) * u
196+
197+
@gtscript.stencil(backend=cartesian_backend)
198+
def calc_cv(
199+
v: gtscript.Field[dtype],
200+
p: gtscript.Field[dtype],
201+
cv: gtscript.Field[dtype]
202+
):
203+
with computation(PARALLEL), interval(...):
204+
cv = .5 * (p + p) * v
205+
158206
time = 0.0
159207
# Main time loop
160208
for ncycle in range(ITMAX):
161209
if((ncycle%100==0) & (VIS==False)):
162210
print("cycle number ", ncycle)
163211
# Calculate cu, cv, z, and h
164-
#for i in range(M):
165-
# for j in range(N):
166-
# h[i, j,0] = p[i, j,0] + 0.25 * (u[i + 1, j,0] * u[i + 1, j,0] + u[i, j,0] * u[i, j,0] +
167-
# v[i, j + 1,0] * v[i, j + 1,0] + v[i, j,0] * v[i, j,0])
168-
# i --> 0,M
169-
#j --> 0,N
170-
# at nx+1 its the boundary region, mask one region and
171-
@gtscript.stencil(backend=cartesian_backend)
172-
def calc_h(
173-
p: gtscript.Field[dtype],
174-
u: gtscript.Field[dtype],
175-
v: gtscript.Field[dtype],
176-
h: gtscript.Field[dtype]
177-
):
178-
with computation(PARALLEL), interval(...):
179-
h = p + 0.25 * u[1,0,0] * u[1,0,0] + u * u + v[0,1,0] * v[0,1,0] + v * v
180-
181-
calc_h(p=p_gt, u=u_gt, v=v_gt, h=h_gt, origin=(0,0,0), domain=(nx,ny,nz))
182-
212+
calc_h(p=p_gt, u=u_gt, v=v_gt, h=h_gt, origin=(0,0,0), domain=(nx,ny,nz))
183213
h = h_gt.asnumpy()
184214

185-
#nx = M
186-
#ny = N
187-
#nz = 1
188-
# i --> 1,M+1 (1,1,0) (nx,ny,nz)
189-
#j --> 1,N+1
190-
@gtscript.stencil(backend=cartesian_backend)
191-
def calc_z(
192-
fsdx: float,
193-
fsdy: float,
194-
u: gtscript.Field[dtype],
195-
v: gtscript.Field[dtype],
196-
p: gtscript.Field[dtype],
197-
z: gtscript.Field[dtype]
198-
):
199-
with computation(PARALLEL), interval(...):
200-
z = (fsdx * (v - v[-1,0,0]) - fsdy * (u - u[0,-1,0])) / (p[-1,-1,0] + p[0,-1,0] + p + p[-1,0,0])
201-
202215
calc_z(fsdx=fsdx, fsdy=fsdy, u=u_gt, v=v_gt, p=p_gt, z=z_gt, origin=(1,1,0), domain=(nx,ny,nz)) # domain(nx+1,ny+1,nz) gives error why?
203216
z = z_gt.asnumpy()
204217

205-
@gtscript.stencil(backend=cartesian_backend)
206-
def calc_cu(
207-
u: gtscript.Field[dtype],
208-
p: gtscript.Field[dtype],
209-
cu: gtscript.Field[dtype]
210-
):
211-
with computation(PARALLEL), interval(...):
212-
cu = .5 * (p + p) * u
213-
214-
#for i in range(1,M+1):
215-
# for j in range(N):
216-
# cu2[i, j,0] = .5 * (p[i, j,0] + p[i, j,0]) * u[i, j,0]
217218
calc_cu(u=u_gt, p=p_gt, cu=cu_gt, origin=(1,0,0), domain=(nx,ny+1,nz)) # domain(nx+1,ny+1,nz) gives error why? try removing ny+1
218219
cu = cu_gt.asnumpy()
219220

220-
@gtscript.stencil(backend=cartesian_backend)
221-
def calc_cv(
222-
v: gtscript.Field[dtype],
223-
p: gtscript.Field[dtype],
224-
cv: gtscript.Field[dtype]
225-
):
226-
with computation(PARALLEL), interval(...):
227-
cv = .5 * (p + p) * v
228-
229221
calc_cv(v=v_gt, p=p_gt, cv=cv_gt, origin=(0,1,0), domain=(nx+1,ny,nz)) # domain(nx+1,ny+1,nz) gives error why?
230222
cv = cv_gt.asnumpy()
231223

232-
#for i in range(M):
233-
# for j in range(N):
234-
# #cu[i + 1, j,0] = .5 * (p[i + 1, j,0] + p[i, j,0]) * u[i + 1, j,0]
235-
# cv[i, j + 1,0] = .5 * (p[i, j + 1,0] + p[i, j,0]) * v[i, j + 1,0]
236-
# #z[i + 1, j + 1,0] = (fsdx * (v[i + 1, j + 1,0] - v[i, j + 1,0]) -
237-
# # fsdy * (u[i + 1, j + 1,0] - u[i+1, j,0] )
238-
# # ) / (p[i, j,0] + p[i + 1, j,0] + p[i + 1, j + 1,0] + p[i, j + 1,0])
239-
240-
# # Periodic Boundary conditions
224+
# # Periodic Boundary conditions
241225
#try region
242226
cu[0, :,0] = cu[M, :,0]
243227
h[M, :,0] = h[0, :,0]

0 commit comments

Comments
 (0)