Skip to content

Commit 1d4e767

Browse files
authored
support both Float32 and Float64 (#69)
1 parent ff75abc commit 1d4e767

20 files changed

+320
-288
lines changed

src/Biogeochemistry/nutrient_fields.jl

+18-15
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
function nutrients_init(arch, g)
2-
fields = (Field(arch, g), Field(arch, g),
3-
Field(arch, g), Field(arch, g),
4-
Field(arch, g), Field(arch, g),
5-
Field(arch, g), Field(arch, g),
6-
Field(arch, g), Field(arch, g))
1+
function nutrients_init(arch, g, FT = Float32)
2+
fields = (Field(arch, g, FT), Field(arch, g, FT),
3+
Field(arch, g, FT), Field(arch, g, FT),
4+
Field(arch, g, FT), Field(arch, g, FT),
5+
Field(arch, g, FT), Field(arch, g, FT),
6+
Field(arch, g, FT), Field(arch, g, FT))
77

88
nut = NamedTuple{nut_names}(fields)
99
return nut
@@ -20,19 +20,22 @@ function default_nut_init()
2020
end
2121

2222
"""
23-
generate_nutrients(arch, grid, source)
23+
generate_nutrients(arch, grid, source, FT)
2424
Set up initial nutrient fields according to `grid`.
2525
26-
Keyword Arguments
26+
Arguments
2727
=================
2828
- `arch`: `CPU()` or `GPU()`. The computer architecture used to time-step `model`.
2929
- `grid`: The resolution and discrete geometry on which nutrient fields are solved.
30-
- `source`: A `NamedTuple` containing 10 numbers each of which is the uniform initial condition of one tracer,
31-
or a `Dict` containing the file paths pointing to the files of nutrient initial conditions.
30+
- `source`: A `NamedTuple` containing 10 numbers each of which is the uniform initial
31+
condition of one tracer, or a `Dict` containing the file paths pointing to
32+
the files of nutrient initial conditions.
33+
- `FT`: Floating point data type. Default: `Float32`.
3234
"""
33-
function generate_nutrients(arch, g, source::Union{Dict,NamedTuple})
35+
function generate_nutrients(arch::Architecture, g::AbstractGrid,
36+
source::Union{Dict,NamedTuple}, FT::DataType)
3437
total_size = (g.Nx+g.Hx*2, g.Ny+g.Hy*2, g.Nz+g.Hz*2)
35-
nut = nutrients_init(arch, g)
38+
nut = nutrients_init(arch, g, FT)
3639
pathkeys = collect(keys(source))
3740

3841
if typeof(source) <: NamedTuple
@@ -58,9 +61,9 @@ function generate_nutrients(arch, g, source::Union{Dict,NamedTuple})
5861
if source.initial_condition[name] < 0.0
5962
throw(ArgumentError("NUT_INIT: The initial condition should be none-negetive."))
6063
end
61-
lower = 1.0 - source.rand_noise[name]
62-
upper = 1.0 + source.rand_noise[name]
63-
nut[name].data .= fill(source.initial_condition[name],total_size) .* rand(lower:1e-4:upper, total_size) |> array_type(arch)
64+
lower = FT(1.0 - source.rand_noise[name])
65+
upper = FT(1.0 + source.rand_noise[name])
66+
nut[name].data .= fill(FT(source.initial_condition[name]),total_size) .* rand(lower:1e-4:upper, total_size) |> array_type(arch)
6467
end
6568
end
6669

src/Diagnostics/diagnostics_struct.jl

+8-7
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
mutable struct PlanktonDiagnostics
22
plankton::NamedTuple # for each species
33
tracer::NamedTuple # for tracers
4-
iteration_interval::Int64 # time interval that the diagnostics is time averaged
4+
iteration_interval::Int # time interval that the diagnostics is time averaged
55
end
66

77
"""
@@ -19,7 +19,7 @@ Keyword Arguments (Optional)
1919
"""
2020
function PlanktonDiagnostics(model; tracer=(),
2121
plankton=(:num, :graz, :mort, :dvid),
22-
iteration_interval::Int64 = 1)
22+
iteration_interval::Int = 1)
2323

2424
@assert isa(tracer, Tuple)
2525
@assert isa(plankton, Tuple)
@@ -30,15 +30,16 @@ function PlanktonDiagnostics(model; tracer=(),
3030
nproc = length(plankton)
3131
trs = []
3232
procs = []
33+
FT = model.FT
3334

3435
total_size = (model.grid.Nx+model.grid.Hx*2, model.grid.Ny+model.grid.Hy*2, model.grid.Nz+model.grid.Hz*2)
3536

3637
for i in 1:ntr
37-
tr = zeros(total_size) |> array_type(model.arch)
38+
tr = zeros(FT, total_size) |> array_type(model.arch)
3839
push!(trs, tr)
3940
end
40-
tr_d1 = zeros(total_size) |> array_type(model.arch)
41-
tr_d2 = zeros(total_size) |> array_type(model.arch)
41+
tr_d1 = zeros(FT, total_size) |> array_type(model.arch)
42+
tr_d2 = zeros(FT, total_size) |> array_type(model.arch)
4243
tr_default = (PAR = tr_d1, T = tr_d2)
4344

4445
diag_tr = NamedTuple{tracer}(trs)
@@ -50,14 +51,14 @@ function PlanktonDiagnostics(model; tracer=(),
5051
for j in 1:Nsp
5152
procs_sp = []
5253
for k in 1:nproc
53-
proc = zeros(total_size) |> array_type(model.arch)
54+
proc = zeros(FT, total_size) |> array_type(model.arch)
5455
push!(procs_sp, proc)
5556
end
5657
diag_proc = NamedTuple{plankton}(procs_sp)
5758

5859
procs_sp_d = []
5960
for l in 1:4
60-
proc = zeros(total_size) |> array_type(model.arch)
61+
proc = zeros(FT, total_size) |> array_type(model.arch)
6162
push!(procs_sp_d, proc)
6263
end
6364
diag_proc_default = NamedTuple{(:num, :graz, :mort, :dvid)}(procs_sp_d)

src/Fields/Fields.jl

+7-6
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,20 @@ include("halo_regions.jl")
1919
include("boundary_conditions.jl")
2020
include("apply_bcs.jl")
2121

22-
struct Field
23-
data::AbstractArray{Float64,3}
22+
struct Field{FT}
23+
data::AbstractArray{FT,3}
2424
bc::BoundaryConditions
2525
end
2626

2727
"""
28-
Field(arch::Architecture, grid::AbstractGrid; bcs = default_bcs())
28+
Field(arch::Architecture, grid::AbstractGrid, FT::DataType; bcs = default_bcs())
2929
Construct a `Field` on `grid` with data and boundary conditions on architecture `arch`
30+
with DataType `FT`.
3031
"""
31-
function Field(arch::Architecture, grid::AbstractGrid; bcs = default_bcs())
32+
function Field(arch::Architecture, grid::AbstractGrid, FT::DataType; bcs = default_bcs())
3233
total_size = (grid.Nx+grid.Hx*2, grid.Ny+grid.Hy*2, grid.Nz+grid.Hz*2)
33-
data = zeros(total_size) |> array_type(arch)
34-
return Field(data,bcs)
34+
data = zeros(FT, total_size) |> array_type(arch)
35+
return Field{FT}(data,bcs)
3536
end
3637

3738
@inline interior(c, grid) = c[grid.Hx+1:grid.Hx+grid.Nx, grid.Hy+1:grid.Hy+grid.Ny, grid.Hz+1:grid.Hz+grid.Nz]

src/Fields/boundary_conditions.jl

+22-9
Original file line numberDiff line numberDiff line change
@@ -19,35 +19,48 @@ function default_bcs()
1919
end
2020

2121
"""
22-
set_bc!(model, tracer::Symbol, pos::Symbol, bc_value::Union{Number, AbstractArray})
23-
Set the boundary condition of `tracer` on `pos` with `bc_value`.
22+
set_bc!(model; tracer::Symbol, pos::Symbol, bc_value::Union{Number, AbstractArray})
23+
Set the boundary condition of `tracer` on `pos` with `bc_value` of DataType `FT`.
24+
25+
Keyword Arguments
26+
=================
27+
- `tracer`: the tracer of which the boundary condition will be set.
28+
- `pos`: the position of the bounday condition to be set, e.g., `:east`, `:top` etc.
29+
- `bc_value`: the value that will be used to set the boundary condition.
2430
"""
25-
function set_bc!(model, tracer::Symbol, pos::Symbol, bc_value::Union{Number, AbstractArray})
31+
function set_bc!(model; tracer::Symbol, pos::Symbol, bc_value::Union{Number, AbstractArray})
2632
@assert tracer in nut_names
2733

34+
FT = model.FT
2835
bc_value_d = bc_value
2936
if isa(bc_value, AbstractArray)
30-
bc_value_d = bc_value |> array_type(model.arch)
37+
bc_value_d = FT.(bc_value) |> array_type(model.arch)
3138
end
3239
setproperty!(model.nutrients[tracer].bc, pos, bc_value_d)
3340
return nothing
3441
end
3542

3643
# get boundary condition at each grid point
37-
@inline getbc(bc::Number, i, j, t) = bc
38-
@inline getbc(bc::AbstractArray{Float64,2}, i, j, t) = bc[i,j]
39-
@inline getbc(bc::AbstractArray{Float64,3}, i, j, t) = bc[i,j,t]
44+
@inline function getbc(bc::Union{Number, AbstractArray}, i, j, t)
45+
if typeof(bc) <: Number
46+
return bc
47+
elseif typeof(bc) <: AbstractArray{eltype(bc),2}
48+
return bc[i,j]
49+
elseif typeof(bc) <: AbstractArray{eltype(bc),3}
50+
return bc[i,j,t]
51+
end
52+
end
4053

4154
# validate boundary conditions, check if the grid information is compatible with nutrient field
4255
function validate_bc(bc, bc_size, nΔT)
43-
if typeof(bc) <: AbstractArray{Float64,2}
56+
if typeof(bc) <: AbstractArray{eltype(bc),2}
4457
if size(bc) == bc_size
4558
return nothing
4659
else
4760
throw(ArgumentError("BC west: grid mismatch, size(bc) must equal to $(bc_size) for a constant flux boundary condition."))
4861
end
4962
end
50-
if typeof(bc) <: AbstractArray{Float64,3}
63+
if typeof(bc) <: AbstractArray{eltype(bc),3}
5164
if size(bc) == (bc_size..., nΔT)
5265
return nothing
5366
else

src/Fields/halo_regions.jl

+29-29
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,37 @@
11
##### fill halo points based on topology
2-
@inline fill_halo_west!(c, H::Int64, N::Int64, ::Periodic) = @views @. c[1:H, :, :] = c[N+1:N+H, :, :]
3-
@inline fill_halo_south!(c, H::Int64, N::Int64, ::Periodic) = @views @. c[:, 1:H, :] = c[:, N+1:N+H, :]
4-
@inline fill_halo_top!(c, H::Int64, N::Int64, ::Periodic) = @views @. c[:, :, 1:H] = c[:, :, N+1:N+H]
2+
@inline fill_halo_west!(c, H::Int, N::Int, ::Periodic) = @views @. c[1:H, :, :] = c[N+1:N+H, :, :]
3+
@inline fill_halo_south!(c, H::Int, N::Int, ::Periodic) = @views @. c[:, 1:H, :] = c[:, N+1:N+H, :]
4+
@inline fill_halo_top!(c, H::Int, N::Int, ::Periodic) = @views @. c[:, :, 1:H] = c[:, :, N+1:N+H]
55

6-
@inline fill_halo_east!(c, H::Int64, N::Int64, ::Periodic) = @views @. c[N+H+1:N+2H, :, :] = c[1+H:2H, :, :]
7-
@inline fill_halo_north!(c, H::Int64, N::Int64, ::Periodic) = @views @. c[:, N+H+1:N+2H, :] = c[:, 1+H:2H, :]
8-
@inline fill_halo_bottom!(c, H::Int64, N::Int64, ::Periodic) = @views @. c[:, :, N+H+1:N+2H] = c[:, :, 1+H:2H]
6+
@inline fill_halo_east!(c, H::Int, N::Int, ::Periodic) = @views @. c[N+H+1:N+2H, :, :] = c[1+H:2H, :, :]
7+
@inline fill_halo_north!(c, H::Int, N::Int, ::Periodic) = @views @. c[:, N+H+1:N+2H, :] = c[:, 1+H:2H, :]
8+
@inline fill_halo_bottom!(c, H::Int, N::Int, ::Periodic) = @views @. c[:, :, N+H+1:N+2H] = c[:, :, 1+H:2H]
99

10-
@inline fill_halo_west!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[1:H, :, :] = c[H+1:H+1, :, :]
11-
@inline fill_halo_south!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[:, 1:H, :] = c[:, H+1:H+1, :]
12-
@inline fill_halo_top!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[:, :, 1:H] = c[:, :, H+1:H+1]
10+
@inline fill_halo_west!(c, H::Int, N::Int, ::Bounded) = @views @. c[1:H, :, :] = c[H+1:H+1, :, :]
11+
@inline fill_halo_south!(c, H::Int, N::Int, ::Bounded) = @views @. c[:, 1:H, :] = c[:, H+1:H+1, :]
12+
@inline fill_halo_top!(c, H::Int, N::Int, ::Bounded) = @views @. c[:, :, 1:H] = c[:, :, H+1:H+1]
1313

14-
@inline fill_halo_east!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[N+H+1:N+2H, :, :] = c[N+H:N+H, :, :]
15-
@inline fill_halo_north!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[:, N+H+1:N+2H, :] = c[:, N+H:N+H, :]
16-
@inline fill_halo_bottom!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[:, :, N+H+1:N+2H] = c[:, :, N+H:N+H]
14+
@inline fill_halo_east!(c, H::Int, N::Int, ::Bounded) = @views @. c[N+H+1:N+2H, :, :] = c[N+H:N+H, :, :]
15+
@inline fill_halo_north!(c, H::Int, N::Int, ::Bounded) = @views @. c[:, N+H+1:N+2H, :] = c[:, N+H:N+H, :]
16+
@inline fill_halo_bottom!(c, H::Int, N::Int, ::Bounded) = @views @. c[:, :, N+H+1:N+2H] = c[:, :, N+H:N+H]
1717

18-
@inline fill_halo_east_vel!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[N+H+2:N+2H, :, :] = c[N+H+1:N+H+1, :, :]
19-
@inline fill_halo_north_vel!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[:, N+H+2:N+2H, :] = c[:, N+H+1:N+H+1, :]
20-
@inline fill_halo_bottom_vel!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[:, :, N+H+2:N+2H] = c[:, :, N+H+1:N+H+1]
18+
@inline fill_halo_east_vel!(c, H::Int, N::Int, ::Bounded) = @views @. c[N+H+2:N+2H, :, :] = c[N+H+1:N+H+1, :, :]
19+
@inline fill_halo_north_vel!(c, H::Int, N::Int, ::Bounded) = @views @. c[:, N+H+2:N+2H, :] = c[:, N+H+1:N+H+1, :]
20+
@inline fill_halo_bottom_vel!(c, H::Int, N::Int, ::Bounded) = @views @. c[:, :, N+H+2:N+2H] = c[:, :, N+H+1:N+H+1]
2121

22-
@inline fill_halo_east_Gc!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[N+H+1:N+2H, :, :] = 0.0
23-
@inline fill_halo_north_Gc!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[:, N+H+1:N+2H, :] = 0.0
24-
@inline fill_halo_bottom_Gc!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[:, :, N+H+1:N+2H] = 0.0
22+
@inline fill_halo_east_Gc!(c, H::Int, N::Int, ::Bounded) = @views @. c[N+H+1:N+2H, :, :] = 0.0
23+
@inline fill_halo_north_Gc!(c, H::Int, N::Int, ::Bounded) = @views @. c[:, N+H+1:N+2H, :] = 0.0
24+
@inline fill_halo_bottom_Gc!(c, H::Int, N::Int, ::Bounded) = @views @. c[:, :, N+H+1:N+2H] = 0.0
2525

26-
fill_halo_east_vel!(c, H::Int64, N::Int64, TX::Periodic) = fill_halo_east!(c, H, N, TX)
27-
fill_halo_north_vel!(c, H::Int64, N::Int64, TY::Periodic) = fill_halo_north!(c, H, N, TY)
28-
fill_halo_bottom_vel!(c, H::Int64, N::Int64, TZ::Periodic) = fill_halo_bottom!(c, H, N, TZ)
26+
fill_halo_east_vel!(c, H::Int, N::Int, TX::Periodic) = fill_halo_east!(c, H, N, TX)
27+
fill_halo_north_vel!(c, H::Int, N::Int, TY::Periodic) = fill_halo_north!(c, H, N, TY)
28+
fill_halo_bottom_vel!(c, H::Int, N::Int, TZ::Periodic) = fill_halo_bottom!(c, H, N, TZ)
2929

30-
fill_halo_east_Gc!(c, H::Int64, N::Int64, TX::Periodic) = fill_halo_east!(c, H, N, TX)
31-
fill_halo_north_Gc!(c, H::Int64, N::Int64, TY::Periodic) = fill_halo_north!(c, H, N, TY)
32-
fill_halo_bottom_Gc!(c, H::Int64, N::Int64, TZ::Periodic) = fill_halo_bottom!(c, H, N, TZ)
30+
fill_halo_east_Gc!(c, H::Int, N::Int, TX::Periodic) = fill_halo_east!(c, H, N, TX)
31+
fill_halo_north_Gc!(c, H::Int, N::Int, TY::Periodic) = fill_halo_north!(c, H, N, TY)
32+
fill_halo_bottom_Gc!(c, H::Int, N::Int, TZ::Periodic) = fill_halo_bottom!(c, H, N, TZ)
3333

34-
@inline function fill_halo_nut!(nuts::NamedTuple, g::AbstractGrid{TX, TY, TZ}) where {TX, TY, TZ}
34+
@inline function fill_halo_nut!(nuts::NamedTuple, g::AbstractGrid{FT, TX, TY, TZ}) where {FT, TX, TY, TZ}
3535
for nut in nuts
3636
fill_halo_west!(nut.data, g.Hx, g.Nx, TX())
3737
fill_halo_east!(nut.data, g.Hx, g.Nx, TX())
@@ -43,7 +43,7 @@ fill_halo_bottom_Gc!(c, H::Int64, N::Int64, TZ::Periodic) = fill_halo_bottom!(c,
4343
return nothing
4444
end
4545

46-
@inline function fill_halo_Gcs!(nuts::NamedTuple, g::AbstractGrid{TX, TY, TZ}) where {TX, TY, TZ}
46+
@inline function fill_halo_Gcs!(nuts::NamedTuple, g::AbstractGrid{FT, TX, TY, TZ}) where {FT, TX, TY, TZ}
4747
for nut in nuts
4848
fill_halo_east_Gc!(nut.data, g.Hx, g.Nx, TX())
4949
fill_halo_north_Gc!(nut.data, g.Hy, g.Ny, TY())
@@ -52,7 +52,7 @@ end
5252
return nothing
5353
end
5454

55-
@inline function fill_halo_u!(u, g::AbstractGrid{TX, TY, TZ}) where {TX, TY, TZ}
55+
@inline function fill_halo_u!(u, g::AbstractGrid{FT, TX, TY, TZ}) where {FT, TX, TY, TZ}
5656
fill_halo_east_vel!(u, g.Hx, g.Nx, TX())
5757

5858
fill_halo_west!(u, g.Hx, g.Nx, TX())
@@ -62,7 +62,7 @@ end
6262
fill_halo_bottom!(u, g.Hz, g.Nz, TZ())
6363
end
6464

65-
@inline function fill_halo_v!(v, g::AbstractGrid{TX, TY, TZ}) where {TX, TY, TZ}
65+
@inline function fill_halo_v!(v, g::AbstractGrid{FT, TX, TY, TZ}) where {FT, TX, TY, TZ}
6666
fill_halo_north_vel!(v, g.Hy, g.Ny, TY())
6767

6868
fill_halo_west!(v, g.Hx, g.Nx, TX())
@@ -72,7 +72,7 @@ end
7272
fill_halo_bottom!(v, g.Hz, g.Nz, TZ())
7373
end
7474

75-
@inline function fill_halo_w!(w, g::AbstractGrid{TX, TY, TZ}) where {TX, TY, TZ}
75+
@inline function fill_halo_w!(w, g::AbstractGrid{FT, TX, TY, TZ}) where {FT, TX, TY, TZ}
7676
fill_halo_bottom_vel!(w, g.Hz, g.Nz, TZ())
7777

7878
fill_halo_west!(w, g.Hx, g.Nx, TX())

src/Grids/Grids.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ using Adapt
1414
using PlanktonIndividuals.Architectures
1515

1616
"""
17-
AbstractGrid{TX, TY, TZ}
18-
Abstract type for grids with elements of type `Float64` and topology `{TX, TY, TZ}`.
17+
AbstractGrid{FT, TX, TY, TZ}
18+
Abstract type for grids with elements of type `FT` and topology `{TX, TY, TZ}`.
1919
"""
20-
abstract type AbstractGrid{TX, TY, YZ} end
20+
abstract type AbstractGrid{FT, TX, TY, YZ} end
2121

2222
"""
2323
AbstractTopology

0 commit comments

Comments
 (0)