Skip to content

Commit bb8345e

Browse files
committed
add a gt4py cartesian example
1 parent 4620033 commit bb8345e

File tree

1 file changed

+113
-0
lines changed

1 file changed

+113
-0
lines changed

swm_python/js_gt4py_cartesian.py

+113
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import gt4py.cartesian.gtscript as gtscript
2+
import gt4py.next as gtx
3+
from gt4py.next import Field
4+
import numpy as np
5+
import time
6+
7+
dtype = np.float64
8+
9+
M = 3000 # row dimension
10+
N = 4000 # column dimension
11+
nx = M + 1
12+
ny = N + 1
13+
nz = 1
14+
coeff1 = dtype(0.5)
15+
coeff2 = dtype(0.25)
16+
dx = 100000.
17+
dy = 100000.
18+
fsdx = dtype(4. / dx)
19+
fsdy = dtype(4. / dy)
20+
21+
allocator = gtx.gtfn_cpu # gtx.gtfn_cpu, gtx.gtfn_gpu
22+
backend = "gt:cpu_ifirst" # "gt:gpu" # "gt:cpu_ifirst"
23+
24+
###############################
25+
# Naive Python implementation #
26+
###############################
27+
28+
start_time = time.perf_counter()
29+
30+
p = np.ones((nx, ny))
31+
u = np.ones((nx, ny))
32+
v = np.ones((nx, ny))
33+
cu_ori = np.zeros((nx, ny))
34+
cv_ori = np.zeros((nx, ny))
35+
z_ori = np.ones((nx, ny))
36+
h_ori = np.zeros((nx, ny))
37+
38+
for i in range(M):
39+
for j in range(N):
40+
cu_ori[i + 1, j] = coeff1 * (p[i + 1, j] + p[i, j]) * u[i + 1, j]
41+
cv_ori[i, j + 1] = coeff1 * (p[i, j + 1] + p[i, j]) * v[i, j + 1]
42+
z_ori[i + 1, j + 1] = ( fsdx * (v[i + 1, j + 1] - v[i, j + 1]) -
43+
fsdy * (u[i + 1, j + 1] - u[i, j + 1])
44+
) / (p[i, j] + p[i + 1, j] + p[i + 1, j + 1] + p[i, j + 1])
45+
h_ori[i, j] = p[i, j] + coeff2 * (u[i + 1, j] * u[i + 1, j] +
46+
u[i, j] * u[i, j] +
47+
v[i, j + 1] * v[i, j + 1] +
48+
v[i, j] * v[i, j])
49+
50+
end_time = time.perf_counter()
51+
elapsed_time = end_time - start_time
52+
print(f"Naive python, elapsed time: {elapsed_time} seconds")
53+
54+
##################################
55+
# GT4PY Cartesian implementation #
56+
##################################
57+
58+
start_time = time.perf_counter()
59+
60+
I = gtx.Dimension("I")
61+
J = gtx.Dimension("J")
62+
K = gtx.Dimension("K", kind=gtx.DimensionKind.VERTICAL)
63+
64+
domain = gtx.domain({I: nx, J: ny, K: nz})
65+
66+
p_cart = gtx.ones(domain, dtype, allocator=allocator)
67+
u_cart = gtx.as_field(domain, np.ones((nx,ny,nz),dtype=dtype), dtype, allocator=allocator)
68+
v_cart = gtx.as_field(domain, np.ones((nx,ny,nz),dtype=dtype), dtype, allocator=allocator)
69+
cu_cart = gtx.zeros(domain, dtype, allocator=allocator)
70+
cv_cart = gtx.zeros(domain, dtype, allocator=allocator)
71+
z_cart = gtx.ones(domain, dtype, allocator=allocator)
72+
h_cart = gtx.zeros(domain, dtype, allocator=allocator)
73+
74+
@gtscript.stencil(backend=backend, rebuild=True)
75+
def cart_calc_region(
76+
p: gtscript.Field[dtype],
77+
u: gtscript.Field[dtype],
78+
v: gtscript.Field[dtype],
79+
cu: gtscript.Field[dtype],
80+
cv: gtscript.Field[dtype],
81+
z: gtscript.Field[dtype],
82+
h: gtscript.Field[dtype],
83+
coeff1: dtype,
84+
fsdx: dtype,
85+
fsdy: dtype,
86+
coeff2: dtype
87+
):
88+
with computation(PARALLEL), interval(...):
89+
with horizontal(region[1:,:-1]):
90+
cu = coeff1 * (p + p[I-1]) * u
91+
with horizontal(region[:-1,1:]):
92+
cv = coeff1 * (p + p[J-1]) * v
93+
with horizontal(region[1:,1:]):
94+
z = (fsdx * (v - v[I-1]) - fsdy * (u - u[I-1])) / (p[I-1,J-1] + p[J-1] + p + p[I-1])
95+
with horizontal(region[:-1,:-1]):
96+
h = p + coeff2 * (u[I+1] * u[I+1] + u * u + v[J+1] * v[J+1] + v * v)
97+
98+
# compute cu, cv, z and h; domain here refers to the number of points in each dimension to be computed
99+
cart_calc_region(p_cart, u_cart, v_cart, cu_cart, cv_cart, z_cart, h_cart,
100+
coeff1, fsdx, fsdy, coeff2, origin=(0,0,0), domain=(nx,ny,nz))
101+
102+
end_time = time.perf_counter()
103+
elapsed_time = end_time - start_time
104+
print(f"GT4PY Cartesian, elapsed time: {elapsed_time} seconds")
105+
106+
####################################
107+
# Check the GT4PY Cartesian result #
108+
####################################
109+
110+
assert np.array_equal(cu_cart[:,:,0].asnumpy(), cu_ori)
111+
assert np.array_equal(cv_cart[:,:,0].asnumpy(), cv_ori)
112+
assert np.array_equal(z_cart[:,:,0].asnumpy(), z_ori)
113+
assert np.array_equal(h_cart[:,:,0].asnumpy(), h_ori)

0 commit comments

Comments
 (0)