Skip to content

Commit 9434d33

Browse files
committed
Moved UpdateIntermediateVariables into kernel.
1 parent 92a7312 commit 9434d33

6 files changed

+39
-17
lines changed

swm_AMReX/Make.package

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
CEXE_sources += main.cpp swm_mini_app_utils.cpp
2-
CEXE_headers += swm_mini_app_utils.h
2+
CEXE_headers += swm_mini_app_utils.h swm_mini_app_kernels.h
33

swm_AMReX/README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
Based on the NCAR SWM mini-app found [here](https://github.com/NCAR/SWM).
44

55
## Prerequisites
6-
- g++ (a version with support for c++20)
7-
- Other compilers should also work fine but I have been using gcc/g++ for testing so far.
6+
- A compiler for C, C++, and Fortran
7+
- Just about any compilers should work but I have been using gcc/g++/gfortran for testing so far.
88
- make
99
- [AMReX](https://github.com/AMReX-Codes/amrex)
1010
- [yt](https://yt-project.org/) (Only needed if you want to run the postprocessing script [plot_with_yt.py](plot_with_yt.py))

swm_AMReX/main.cpp

+5-7
Original file line numberDiff line numberDiff line change
@@ -133,15 +133,13 @@ int main (int argc, char* argv[])
133133
Copy(p, p_old);
134134

135135
// Constants used in time stepping loop
136-
const double fsdx = 4.0/dx;
137-
const double fsdy = 4.0/dy;
138136
double tdt = dt;
139137
const double alpha = 0.001;
140138

141139
for (int time_step = 1; time_step <= n_time_steps; ++time_step)
142140
{
143141
// Sets: cu, cv, h, z
144-
UpdateIntermediateVariables(fsdx, fsdy, geom,
142+
UpdateIntermediateVariables(dx, dy, geom,
145143
p, u, v,
146144
cu, cv, h, z);
147145

@@ -161,10 +159,6 @@ int main (int argc, char* argv[])
161159
// Sets: p, u, v
162160
UpdateVariables(geom, u_new, v_new, p_new, u, v, p);
163161

164-
if (time_step == 0) {
165-
tdt = tdt + tdt;
166-
}
167-
168162
time = time + dt;
169163

170164
// Write a plotfile of the current data (plot_interval was defined in the inputs file)
@@ -173,6 +167,10 @@ int main (int argc, char* argv[])
173167
WriteOutput(psi, p, u, v, geom, time, time_step, output_values);
174168
}
175169

170+
if (time_step == 0) {
171+
tdt = tdt + tdt;
172+
}
173+
176174
}
177175

178176
//amrex::Print() << "Final: " << std::endl;

swm_AMReX/swm_mini_app_kernels.h

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#ifndef SWM_MINI_APP_KERNELS_H_
2+
#define SWM_MINI_APP_KERNELS_H_
3+
4+
#include <AMReX.H>
5+
6+
AMREX_GPU_DEVICE AMREX_FORCE_INLINE
7+
void UpdateIntermediateVariablesKernel( const int i, const int j, const int k, const double fsdx, const double fsdy,
8+
const amrex::Array4<amrex::Real const>& p,
9+
const amrex::Array4<amrex::Real const>& u,
10+
const amrex::Array4<amrex::Real const>& v,
11+
const amrex::Array4<amrex::Real>& cu,
12+
const amrex::Array4<amrex::Real>& cv,
13+
const amrex::Array4<amrex::Real>& h,
14+
const amrex::Array4<amrex::Real>& z)
15+
{
16+
cu(i,j,k) = 0.5*(p(i,j,k) + p(i+1,j,k))*u(i,j,k);
17+
cv(i,j,k) = 0.5*(p(i,j,k) + p(i,j+1,k))*v(i,j,k);
18+
z(i,j,k) = (fsdx*(v(i+1,j,k)-v(i,j,k)) + fsdy*(u(i,j+1,k)-u(i,j,k)))/(p(i,j,k)+p(i+1,j,k)+p(i,j+1,k)+p(i+1,j+1,k));
19+
h(i,j,k) = p(i,j,k) + 0.25*(u(i-1,j,k)*u(i-1,j,k) + u(i,j,k)*u(i,j,k) + v(i,j-1,k)*v(i,j-1,k) + v(i,j,k)*v(i,j,k));
20+
}
21+
22+
#endif // SWM_MINI_APP_KERNELS_H_

swm_AMReX/swm_mini_app_utils.cpp

+8-6
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
#include <AMReX_Array.H>
99

1010
#include "swm_mini_app_utils.h"
11-
11+
#include "swm_mini_app_kernels.h"
1212

1313
void ParseInput(int & nx, int & ny,
1414
amrex::Real & dx, amrex::Real & dy,
@@ -320,10 +320,13 @@ void Copy(const amrex::MultiFab & src, amrex::MultiFab & dest)
320320
return;
321321
}
322322

323-
void UpdateIntermediateVariables(amrex::Real fsdx, amrex::Real fsdy, const amrex::Geometry& geom,
323+
void UpdateIntermediateVariables(amrex::Real dx, amrex::Real dy, const amrex::Geometry& geom,
324324
const amrex::MultiFab& p, const amrex::MultiFab& u, const amrex::MultiFab& v,
325325
amrex::MultiFab& cu, amrex::MultiFab& cv, amrex::MultiFab& h, amrex::MultiFab& z)
326326
{
327+
const double fsdx = 4.0/dx;
328+
const double fsdy = 4.0/dy;
329+
327330
for (amrex::MFIter mfi(p); mfi.isValid(); ++mfi)
328331
{
329332
const amrex::Box& bx = mfi.validbox();
@@ -341,10 +344,9 @@ void UpdateIntermediateVariables(amrex::Real fsdx, amrex::Real fsdy, const amrex
341344

342345
amrex::ParallelFor(bx, [=] AMREX_GPU_DEVICE(int i, int j, int k)
343346
{
344-
cu_array(i,j,k) = 0.5*(p_array(i,j,k) + p_array(i+1,j,k))*u_array(i,j,k);
345-
cv_array(i,j,k) = 0.5*(p_array(i,j,k) + p_array(i,j+1,k))*v_array(i,j,k);
346-
z_array(i,j,k) = (fsdx*(v_array(i+1,j,k)-v_array(i,j,k)) + fsdy*(u_array(i,j+1,k)-u_array(i,j,k)))/(p_array(i,j,k)+p_array(i+1,j,k)+p_array(i,j+1,k)+p_array(i+1,j+1,k));
347-
h_array(i,j,k) = p_array(i,j,k) + 0.25*(u_array(i-1,j,k)*u_array(i-1,j,k) + u_array(i,j,k)*u_array(i,j,k) + v_array(i,j-1,k)*v_array(i,j-1,k) + v_array(i,j,k)*v_array(i,j,k));
347+
UpdateIntermediateVariablesKernel(i, j, k, fsdx, fsdy,
348+
p_array, u_array, v_array,
349+
cu_array, cv_array, h_array, z_array);
348350
});
349351
}
350352

swm_AMReX/swm_mini_app_utils.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ amrex::MultiFab CreateMultiFab(const amrex::MultiFab & mf);
5050

5151
void Copy(const amrex::MultiFab & src, amrex::MultiFab & dest);
5252

53-
void UpdateIntermediateVariables(amrex::Real fsdx, amrex::Real fsdy, const amrex::Geometry& geom,
53+
void UpdateIntermediateVariables(amrex::Real dx, amrex::Real dy, const amrex::Geometry& geom,
5454
const amrex::MultiFab& p, const amrex::MultiFab& u, const amrex::MultiFab& v,
5555
amrex::MultiFab& cu, amrex::MultiFab& cv, amrex::MultiFab& h, amrex::MultiFab& z);
5656

0 commit comments

Comments
 (0)