Skip to content

Commit 7ce0240

Browse files
committed
only valiate on M=64
1 parent 518927a commit 7ce0240

File tree

2 files changed

+85
-54
lines changed

2 files changed

+85
-54
lines changed

swm_python/swm.py

+80-51
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,32 @@
11
import numpy as np
22
import argparse
33

4+
def calculate_cu_cv_z_h_numpy(u, v, p, fsdx, fsdy):
5+
cu=np.zeros_like(u)
6+
cv=np.zeros_like(v)
7+
h=np.zeros_like(u)
8+
z=np.zeros_like(u)
9+
cu[1:,:-1] = 0.5 * (p[1:, :-1] + p[:-1, :-1]) * u[1:, :-1]
10+
cv[:-1,1:] = 0.5 * (p[:-1, 1:] + p[:-1, :-1]) * v[:-1, 1:]
11+
z[1:,1:] = fsdx * (v[1:, 1:] - v[:-1, 1:]) - fsdy * (u[1:, 1:] - u[1:, :-1]) / (
12+
p[:-1, :-1] + p[1:, :-1] + p[1:, 1:] + p[:-1, 1:]
13+
)
14+
h[:-1,:-1] = p[:-1, :-1] + 0.25 * (
15+
u[1:, :-1] * u[1:, :-1]
16+
+ u[:-1, :-1] * u[:-1, :-1]
17+
+ v[:-1, 1:] * v[:-1, 1:]
18+
+ v[:-1, :-1] * v[:-1, :-1]
19+
)
20+
return cu, cv, z, h
21+
def update_u_v_p_numpy(u,v,p,cu,cv,h,z,tdts8,tdtsdx,tdtsdy):
22+
unew=np.zeros_like(u)
23+
vnew=np.zeros_like(v)
24+
pnew=np.zeros_like(u)
25+
unew[1:,:-1] = (u[1:,:-1] + tdts8 * (z[1:,1:]+z[1:,:-1]) * (cv[1:,1:]+cv[1:,:-1] + cv[:-1,1:]+cv[:-1,:-1])-tdtsdx *(h[1:,:-1]-h[:-1,:-1]))
26+
vnew[:-1,1:] = (v[:-1,1:]-tdts8*(z[1:,1:]+z[:-1,1:])*(cu[1:,1:]+cu[1:,:-1]+cu[:-1,1:]+cu[:-1,:-1])-tdtsdy*(h[:-1,1:]-h[:-1,:-1]))
27+
pnew[:-1,:-1] = (p[:-1,:-1]-tdtsdx*(cu[1:,:-1]-cu[:-1,:-1])-tdtsdy*(cv[:-1,1:]-cv[:-1,:-1]))
28+
return unew,vnew,pnew
29+
430
def main():
531
parser = argparse.ArgumentParser(description="Shallow Water Model")
632
parser.add_argument('--M', type=int, default=64, help='Number of points in the x direction')
@@ -45,10 +71,10 @@ def main():
4571
uold = np.zeros((M_LEN, N_LEN))
4672
vold = np.zeros((M_LEN, N_LEN))
4773
pold = np.zeros((M_LEN, N_LEN))
48-
cu = np.zeros((M_LEN, N_LEN))
49-
cv = np.zeros((M_LEN, N_LEN))
50-
z = np.zeros((M_LEN, N_LEN))
51-
h = np.zeros((M_LEN, N_LEN))
74+
#cu = np.zeros((M_LEN, N_LEN))
75+
#cv = np.zeros((M_LEN, N_LEN))
76+
#z = np.zeros((M_LEN, N_LEN))
77+
#h = np.zeros((M_LEN, N_LEN))
5278
psi = np.zeros((M_LEN, N_LEN))
5379

5480
# Initial values of the stream function and p
@@ -108,6 +134,7 @@ def main():
108134
if ncycle % 100 == 0:
109135
print("cycle number ", ncycle)
110136
# Calculate cu, cv, z, and h
137+
'''
111138
for i in range(1, M):
112139
for j in range(N):
113140
cu[i, j] = .5 * (p[i, j] + p[i - 1, j]) * u[i, j]
@@ -126,7 +153,8 @@ def main():
126153
for j in range(N):
127154
h[i, j] = p[i, j] + 0.25 * (u[i + 1, j] * u[i + 1, j] + u[i, j] * u[i, j] +
128155
v[i, j + 1] * v[i, j + 1] + v[i, j] * v[i, j])
129-
156+
'''
157+
cu,cv,z,h=calculate_cu_cv_z_h_numpy(u,v,p,fsdx,fsdy)
130158
# for i in range(M):
131159
# for j in range(N):
132160
# cu[i + 1, j] = .5 * (p[i + 1, j] + p[i, j]) * u[i + 1, j]
@@ -136,38 +164,27 @@ def main():
136164
# ) / (p[i, j] + p[i + 1, j] + p[i + 1, j + 1] + p[i, j + 1])
137165
# h[i, j] = p[i, j] + 0.25 * (u[i + 1, j] * u[i + 1, j] + u[i, j] * u[i, j] +
138166
# v[i, j + 1] * v[i, j + 1] + v[i, j] * v[i, j])
139-
167+
168+
169+
140170
# # Periodic Boundary conditions
141-
for j in range(N):
142-
cu[0, j] = cu[M, j]
143-
h[M, j] = h[0, j]
144-
# for j in range(N):
145-
# cv[M, j + 1] = cv[0, j + 1]
146-
for j in range(1, N):
147-
cv[M, j] = cv[0, j]
148-
# for j in range(N):
149-
# z[0,j + 1] = z[M, j + 1]
150-
for j in range(1, N):
151-
z[0, j] = z[M, j]
152-
171+
cu[0, :] = cu[M, :]
172+
h[M, :] = h[0, :]
173+
cv[M, 1:] = cv[0, 1:]
174+
z[0, 1:] = z[M, 1:]
153175

154-
for i in range(M):
155-
cv[i, 0] = cv[i, N]
156-
h[i, N] = h[i, 0]
157-
# for i in range(M):
158-
# cu[i + 1, N] = cu[i + 1, 0]
159-
for i in range(1, M):
160-
cu[i, N] = cu[i, 0]
161-
# for i in range(M):
162-
# z[i + 1, N] = z[i + 1, 0]
163-
for i in range(1, M):
164-
z[i, N] = z[i, 0]
165-
176+
cv[:, 0] = cv[:, N]
177+
h[:, N] = h[:, 0]
178+
cu[1:, N] = cu[1:, 0]
179+
z[1:, N] = z[1:, 0]
166180

167-
cu[0, 0] = cu[0, N]
168-
cv[M, 0] = cv[0, 0]
181+
cu[0, N] = cu[M, 0]
182+
cv[M, 0] = cv[0, N]
169183
z[0, 0] = z[M, N]
170184
h[M, N] = h[0, 0]
185+
186+
187+
171188

172189
# Calclulate new values of u,v, and p
173190
tdts8 = tdt / 8.
@@ -180,30 +197,42 @@ def main():
180197
# (cv[i + 1, j + 1] + cv[i + 1, j] + cv[i, j + 1] + cv[i, j]) -
181198
# tdtsdx * (h[i + 1, j] - h[i, j])
182199
# )
183-
for i in range(1, M):
184-
for j in range(N):
185-
unew[i, j] = (uold[i, j] + tdts8 * (z[i, j + 1] + z[i, j]) *
186-
(cv[i, j + 1] + cv[i, j] + cv[i - 1, j + 1] + cv[i - 1, j]) -
187-
tdtsdx * (h[i, j] - h[i - 1, j])
188-
)
200+
##for i in range(1, M):
201+
## for j in range(N):
202+
## unew[i, j] = (uold[i, j] + tdts8 * (z[i, j + 1] + z[i, j]) *
203+
## (cv[i, j + 1] + cv[i, j] + cv[i - 1, j + 1] + cv[i - 1, j]) -
204+
## tdtsdx * (h[i, j] - h[i - 1, j])
205+
## )
189206
# for i in range(M):
190207
# for j in range(N):
191208
# vnew[i, j + 1] = (vold[i, j + 1] - tdts8 * (z[i + 1, j + 1] + z[i, j + 1]) *
192209
# (cu[i + 1, j + 1] + cu[i + 1, j] + cu[i, j + 1] + cu[i, j]) -
193210
# tdtsdy * (h[i, j + 1] - h[i, j])
194211
# )
195-
for i in range(M):
196-
for j in range(1, N):
197-
vnew[i, j] = (vold[i, j] - tdts8 * (z[i + 1, j] + z[i, j]) *
198-
(cu[i + 1, j] + cu[i + 1, j - 1] + cu[i, j] + cu[i, j - 1]) -
199-
tdtsdy * (h[i, j] - h[i, j - 1])
200-
)
201-
for i in range(M):
202-
for j in range(N):
203-
pnew[i, j] = (pold[i, j] - tdtsdx * (cu[i + 1, j] - cu[i, j]) -
204-
tdtsdy * (cv[i, j + 1] - cv[i, j])
205-
)
206-
212+
##for i in range(M):
213+
## for j in range(1, N):
214+
## vnew[i, j] = (vold[i, j] - tdts8 * (z[i + 1, j] + z[i, j]) *
215+
## (cu[i + 1, j] + cu[i + 1, j - 1] + cu[i, j] + cu[i, j - 1]) -
216+
## tdtsdy * (h[i, j] - h[i, j - 1])
217+
## )
218+
##for i in range(M):
219+
## for j in range(N):
220+
## pnew[i, j] = (pold[i, j] - tdtsdx * (cu[i + 1, j] - cu[i, j]) -
221+
## tdtsdy * (cv[i, j + 1] - cv[i, j])
222+
## )
223+
224+
unew,vnew,pnew=update_u_v_p_numpy(uold,vold,pold,cu,cv,h,z,tdts8,tdtsdx,tdtsdy)
225+
unew[0, :] = unew[M, :]
226+
pnew[M, :] = pnew[0, :]
227+
vnew[M, 1:] = vnew[0, 1:]
228+
unew[1:, N] = unew[1:, 0]
229+
vnew[:, 0] = vnew[:, N]
230+
pnew[:, N] = pnew[:, 0]
231+
232+
unew[0, N] = unew[M, 0]
233+
vnew[M, 0] = vnew[0, N]
234+
pnew[M, N] = pnew[0, 0]
235+
'''
207236
# Periodic Boundary conditions
208237
for j in range(N):
209238
unew[0, j] = unew[M, j]
@@ -220,10 +249,10 @@ def main():
220249
for i in range(M):
221250
vnew[i, 0] = vnew[i, N]
222251
pnew[i, N] = pnew[i, 0]
223-
224252
unew[0, 0] = unew[0, N]
225253
vnew[M, 0] = vnew[0, 0]
226254
pnew[0, 0] = pnew[0, 0]
255+
'''
227256

228257
# Print initial conditions
229258
if L_OUT:

swm_python/swm_cartesian.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
#import gt4py
1010

1111
# Initialize model parameters
12-
M = 64 # args.M
13-
N = 64 # args.N
12+
M = 512 # args.M
13+
N = 512 # args.N
1414
M_LEN = M + 1
1515
N_LEN = N + 1
1616
L_OUT = True # args.L_OUT
@@ -169,6 +169,7 @@ def live_plot_val(fu, fv, fp, title=''):
169169
gt4py_type = "cartesian"
170170
#gt4py_type = "next"
171171
allocator = gtx.itir_python
172+
allocator = gtx.gtfn_gpu
172173

173174
I = gtx.Dimension("I")
174175
J = gtx.Dimension("J")
@@ -191,6 +192,7 @@ def live_plot_val(fu, fv, fp, title=''):
191192
pold_gt = gtx.as_field(domain,pold,allocator=allocator)
192193

193194
cartesian_backend = "numpy"
195+
cartesian_backend = "gt:gpu"
194196
next_backend = gtx.itir_python
195197

196198
if gt4py_type == "cartesian":
@@ -452,7 +454,7 @@ def copy_var(
452454
print("t200: ",dt2)
453455
print("t300: ",dt3)
454456

455-
if VAL:
457+
if VAL and M==64:
456458

457459
u_val_f = 'ref/u.64.64.IT4000.txt'
458460
v_val_f = 'ref/v.64.64.IT4000.txt'

0 commit comments

Comments
 (0)