Skip to content

Commit 9bf746a

Browse files
moving array copies to gt4py
1 parent d8538cb commit 9bf746a

File tree

1 file changed

+18
-8
lines changed

1 file changed

+18
-8
lines changed

swm_python/swm_cartesian.py

+18-8
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import numpy as np
22
import argparse
33
import matplotlib.pyplot as plt
4+
import numpy as np
5+
import gt4py.next as gtx
6+
import gt4py.cartesian.gtscript as gtscript
47
#import cupy
58
#import gt4py
69

@@ -120,10 +123,6 @@ def live_plot3(fu, fv, fp, title=''):
120123
print(" Initial p:\n", p[:,:,0].diagonal()[:-1])
121124
print(" Initial u:\n", u[:,:,0].diagonal()[:-1])
122125
print(" Initial v:\n", v[:,:,0].diagonal()[:-1])
123-
124-
import numpy as np
125-
import gt4py.next as gtx
126-
import gt4py.cartesian.gtscript as gtscript
127126

128127
nx = M
129128
ny = N
@@ -282,6 +281,14 @@ def calc_vold(
282281
):
283282
with computation(PARALLEL), interval(...):
284283
vold = v + alpha * (vnew - 2 * v + vold)
284+
285+
@gtscript.stencil(backend=cartesian_backend)
286+
def copy_var(
287+
inp: gtscript.Field[dtype],
288+
out: gtscript.Field[dtype]
289+
):
290+
with computation(PARALLEL), interval(...):
291+
out = inp
285292

286293
time = 0.0
287294
# Main time loop
@@ -363,10 +370,13 @@ def calc_vold(
363370
#uold[...] = u + alpha * (unew - 2 * u + uold)
364371
#vold[...] = v + alpha * (vnew - 2 * v + vold)
365372
#pold[...] = p + alpha * (pnew - 2 * p + pold)
366-
367-
u[...] = unew
368-
v[...] = vnew
369-
p[...] = pnew
373+
374+
copy_var(unew_gt, u_gt, origin=(0,0,0), domain=(nx,ny,nz))
375+
copy_var(vnew_gt, v_gt, origin=(0,0,0), domain=(nx,ny,nz))
376+
copy_var(pnew_gt, p_gt, origin=(0,0,0), domain=(nx,ny,nz))
377+
#u[...] = unew
378+
#v[...] = vnew
379+
#p[...] = pnew
370380

371381
else:
372382
tdt = tdt+tdt

0 commit comments

Comments
 (0)