|
1 | 1 | import numpy as np
|
2 | 2 | import argparse
|
3 | 3 | import matplotlib.pyplot as plt
|
| 4 | +import numpy as np |
| 5 | +import gt4py.next as gtx |
| 6 | +import gt4py.cartesian.gtscript as gtscript |
4 | 7 | #import cupy
|
5 | 8 | #import gt4py
|
6 | 9 |
|
@@ -120,10 +123,6 @@ def live_plot3(fu, fv, fp, title=''):
|
120 | 123 | print(" Initial p:\n", p[:,:,0].diagonal()[:-1])
|
121 | 124 | print(" Initial u:\n", u[:,:,0].diagonal()[:-1])
|
122 | 125 | print(" Initial v:\n", v[:,:,0].diagonal()[:-1])
|
123 |
| - |
124 |
| -import numpy as np |
125 |
| -import gt4py.next as gtx |
126 |
| -import gt4py.cartesian.gtscript as gtscript |
127 | 126 |
|
128 | 127 | nx = M
|
129 | 128 | ny = N
|
@@ -282,6 +281,14 @@ def calc_vold(
|
282 | 281 | ):
|
283 | 282 | with computation(PARALLEL), interval(...):
|
284 | 283 | vold = v + alpha * (vnew - 2 * v + vold)
|
| 284 | + |
| 285 | + @gtscript.stencil(backend=cartesian_backend) |
| 286 | + def copy_var( |
| 287 | + inp: gtscript.Field[dtype], |
| 288 | + out: gtscript.Field[dtype] |
| 289 | + ): |
| 290 | + with computation(PARALLEL), interval(...): |
| 291 | + out = inp |
285 | 292 |
|
286 | 293 | time = 0.0
|
287 | 294 | # Main time loop
|
@@ -363,10 +370,13 @@ def calc_vold(
|
363 | 370 | #uold[...] = u + alpha * (unew - 2 * u + uold)
|
364 | 371 | #vold[...] = v + alpha * (vnew - 2 * v + vold)
|
365 | 372 | #pold[...] = p + alpha * (pnew - 2 * p + pold)
|
366 |
| - |
367 |
| - u[...] = unew |
368 |
| - v[...] = vnew |
369 |
| - p[...] = pnew |
| 373 | + |
| 374 | + copy_var(unew_gt, u_gt, origin=(0,0,0), domain=(nx,ny,nz)) |
| 375 | + copy_var(vnew_gt, v_gt, origin=(0,0,0), domain=(nx,ny,nz)) |
| 376 | + copy_var(pnew_gt, p_gt, origin=(0,0,0), domain=(nx,ny,nz)) |
| 377 | + #u[...] = unew |
| 378 | + #v[...] = vnew |
| 379 | + #p[...] = pnew |
370 | 380 |
|
371 | 381 | else:
|
372 | 382 | tdt = tdt+tdt
|
|
0 commit comments