Skip to content

Commit 96d38df

Browse files
Merge branch 'main' of github.com:NCAR/SWM
2 parents 9bf746a + b744f3b commit 96d38df

File tree

3 files changed

+349
-81
lines changed

3 files changed

+349
-81
lines changed

swm_python/lap_cartesian_vs_next.py

+122
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
#!/usr/bin/env python
2+
# coding: utf-8
3+
4+
# GT4Py - GridTools for Python
5+
#
6+
# Copyright (c) 2014-2023, ETH Zurich
7+
# All rights reserved.
8+
#
9+
# This file is part the GT4Py project and the GridTools framework.
10+
# GT4Py is free software: you can redistribute it and/or modify it under
11+
# the terms of the GNU General Public License as published by the
12+
# Free Software Foundation, either version 3 of the License, or any later
13+
# version. See the LICENSE.txt file at the top-level directory of this
14+
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
15+
#
16+
# SPDX-License-Identifier: GPL-3.0-or-later
17+
18+
# # Demonstrates gt4py.cartesian with gt4py.next compatibility
19+
20+
# Imports
21+
22+
# In[19]:
23+
24+
25+
import numpy as np
26+
27+
nx = 32
28+
ny = 32
29+
nz = 1
30+
dtype = np.float64
31+
32+
33+
# Storages
34+
# --
35+
#
36+
# We create fields using the gt4py.next constructors. These fields are compatible with gt4py.cartesian when we use "I", "J", "K" as the dimension names.
37+
38+
# In[8]:
39+
40+
41+
import gt4py.next as gtx
42+
43+
#allocator = gtx.itir_python # should match the executor
44+
allocator = gtx.gtfn_cpu
45+
# allocator = gtx.gtfn_gpu
46+
47+
# Note: for gt4py.next, names don't matter, for gt4py.cartesian they have to be "I", "J", "K"
48+
I = gtx.Dimension("I")
49+
J = gtx.Dimension("J")
50+
K = gtx.Dimension("K", kind=gtx.DimensionKind.VERTICAL)
51+
52+
domain = gtx.domain({I: nx, J: ny, K: nz})
53+
54+
inp = gtx.as_field(domain, np.fromfunction(lambda x, y, z: x**2+y**2, shape=(nx, ny, nz)), dtype, allocator=allocator)
55+
out_cartesian = gtx.zeros(domain, dtype, allocator=allocator)
56+
out_next = gtx.zeros(domain, dtype, allocator=allocator)
57+
58+
59+
# In[25]:
60+
61+
62+
#get_ipython().system('module load gcc')
63+
#get_ipython().system('export BOOST_HOME=/glade/derecho/scratch/haiyingx/boost_1_84_0/include/boost')
64+
65+
66+
# gt4py.cartesian
67+
# --
68+
69+
# In[26]:
70+
71+
72+
import gt4py.cartesian.gtscript as gtscript
73+
74+
#cartesian_backend = "numpy"
75+
#cartesian_backend = "gt:cpu_ifirst"
76+
cartesian_backend = "gt:gpu"
77+
78+
@gtscript.stencil(backend=cartesian_backend)
79+
def lap_cartesian(
80+
inp: gtscript.Field[dtype],
81+
out: gtscript.Field[dtype],
82+
):
83+
with computation(PARALLEL), interval(...):
84+
out = -4.0 * inp[0, 0, 0] + inp[-1, 0, 0] + inp[1, 0, 0] + inp[0, -1, 0] + inp[0, 1, 0]
85+
86+
#lap_cartesian(inp=inp, out=out_cartesian, origin=(1, 1, 0), domain=(nx-2, ny-2, nz))
87+
88+
89+
# In[ ]:
90+
91+
92+
from gt4py.next import Field
93+
94+
#next_backend = gtx.itir_python
95+
# next_backend = gtx.gtfn_cpu
96+
next_backend = gtx.gtfn_gpu
97+
98+
Ioff = gtx.FieldOffset("I", source=I, target=(I,))
99+
Joff = gtx.FieldOffset("J", source=J, target=(J,))
100+
101+
@gtx.field_operator
102+
def lap_next(inp: Field[[I, J, K], dtype]) -> Field[[I, J, K], dtype]:
103+
return -4.0 * inp + inp(Ioff[-1]) + inp(Ioff[1]) + inp(Joff[-1]) + inp(Joff[1])
104+
105+
@gtx.program(backend=next_backend)
106+
def lap_next_program(inp: Field[[I, J, K], dtype], out: Field[[I, J, K], dtype]):
107+
lap_next(inp, out=out[1:-1, 1:-1, :])
108+
109+
lap_next_program(inp, out_next, offset_provider={"Ioff": I, "Joff": J})
110+
111+
112+
# In[ ]:
113+
114+
115+
assert np.allclose(out_cartesian.asnumpy(), out_next.asnumpy())
116+
117+
118+
# In[ ]:
119+
120+
121+
122+

swm_python/swm.ipynb

+175-80
Large diffs are not rendered by default.

swm_python/swm_numpy.ipynb

+52-1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
"L_OUT = True # args.L_OUT\n",
3636
"VIS = True\n",
3737
"VIS_DT=100\n",
38+
"VAL = True\n",
3839
"ITMAX = 1000\n",
3940
"dt = 90.\n",
4041
"tdt = dt\n",
@@ -98,6 +99,28 @@
9899
" fig.suptitle(title)\n",
99100
" #plt.xlabel('x')\n",
100101
" #plt.ylabel('y')\n",
102+
" plt.show()\n",
103+
"\n",
104+
"def live_plot_val(fu, fv, fp, title=''):\n",
105+
" mxu = fu.max()\n",
106+
" mxv = fv.max()\n",
107+
" mxp = fp.max()\n",
108+
" clear_output(wait=True)\n",
109+
" fig, (ax1, ax2, ax3) = plt.subplots(figsize=(13, 3), ncols=3)\n",
110+
"\n",
111+
" pos1 = ax1.imshow(fp, cmap='Blues', vmin=-mxp, vmax=mxp,interpolation='none')\n",
112+
" ax1.set_title('p')\n",
113+
" plt.colorbar(pos1,ax=ax1)\n",
114+
" pos2 = ax2.imshow(fu, cmap='Reds', vmin=-mxu, vmax=mxu,interpolation='none')\n",
115+
" ax2.set_title('u')\n",
116+
" plt.colorbar(pos2,ax=ax2)\n",
117+
" pos3 = ax3.imshow(fv, cmap='Greens',vmin=-mxv, vmax=mxv,interpolation='none')\n",
118+
" ax3.set_title('v')\n",
119+
" plt.colorbar(pos3, ax=ax3)\n",
120+
"\n",
121+
" fig.suptitle(title)\n",
122+
" #plt.xlabel('x')\n",
123+
" #plt.ylabel('y')\n",
101124
" plt.show()\n"
102125
]
103126
},
@@ -293,7 +316,35 @@
293316
"id": "78ef5751-4589-4fbc-abfc-a316f672b3b9",
294317
"metadata": {},
295318
"outputs": [],
296-
"source": []
319+
"source": [
320+
"if VAL:\n",
321+
"\n",
322+
" u_val_f = 'ref/u.64.64.IT4000.txt'\n",
323+
" v_val_f = 'ref/v.64.64.IT4000.txt'\n",
324+
" p_val_f = 'ref/p.64.64.IT4000.txt'\n",
325+
" uval = np.zeros((M_LEN, N_LEN))\n",
326+
" vval = np.zeros((M_LEN, N_LEN))\n",
327+
" pval = np.zeros((M_LEN, N_LEN))\n",
328+
"\n",
329+
" uref, vref, pref = read_arrays(v_val_f, u_val_f, p_val_f)\n",
330+
" uval = uref-unew\n",
331+
" vval = vref-vnew\n",
332+
" pval = pref-pnew\n",
333+
" \n",
334+
" uLinfN= np.linalg.norm(uval, np.inf)\n",
335+
" vLinfN= np.linalg.norm(vval, np.inf)\n",
336+
" pLinfN= np.linalg.norm(pval, np.inf)\n",
337+
"\n",
338+
" \n",
339+
"\n",
340+
" live_plot_val(uval, vval, pval, \"Val\")\n",
341+
" print(\"uLinfN: \", uLinfN)\n",
342+
" print(\"vLinfN: \", vLinfN)\n",
343+
" print(\"pLinfN: \", pLinfN)\n",
344+
" print(\"udiff max: \",uval.max())\n",
345+
" print(\"vdiff max: \",vval.max())\n",
346+
" print(\"pdiff max: \",pval.max())"
347+
]
297348
}
298349
],
299350
"metadata": {

0 commit comments

Comments
 (0)