Skip to content

Commit b6f7308

Browse files
Merge pull request #2234 from devitocodes/sradius
api: enforce interpolation radius to be smaller than any input space …
2 parents 074df11 + 5e60d91 commit b6f7308

File tree

7 files changed

+53
-32
lines changed

7 files changed

+53
-32
lines changed

devito/operations/interpolators.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from abc import ABC, abstractmethod
2+
from functools import wraps
23

34
import sympy
45
from cached_property import cached_property
56

67
from devito.finite_differences.differentiable import Mul
78
from devito.finite_differences.elementary import floor
8-
from devito.symbolics import retrieve_function_carriers, INT
9+
from devito.symbolics import retrieve_function_carriers, retrieve_functions, INT
910
from devito.tools import as_tuple, flatten
1011
from devito.types import (ConditionalDimension, Eq, Inc, Evaluable, Symbol,
1112
CustomDimension)
@@ -14,6 +15,18 @@
1415
__all__ = ['LinearInterpolator', 'PrecomputedInterpolator']
1516

1617

18+
def check_radius(func):
19+
@wraps(func)
20+
def wrapper(interp, *args, **kwargs):
21+
r = interp.sfunction.r
22+
funcs = set().union(*[retrieve_functions(a) for a in args])
23+
so = min({f.space_order for f in funcs if not f.is_SparseFunction} or {r})
24+
if so < r:
25+
raise ValueError("Space order %d smaller than interpolation r %d" % (so, r))
26+
return func(interp, *args, **kwargs)
27+
return wrapper
28+
29+
1730
class UnevaluatedSparseOperation(sympy.Expr, Evaluable):
1831

1932
"""
@@ -209,6 +222,7 @@ def _interp_idx(self, variables, implicit_dims=None):
209222

210223
return idx_subs, temps
211224

225+
@check_radius
212226
def interpolate(self, expr, increment=False, self_subs={}, implicit_dims=None):
213227
"""
214228
Generate equations interpolating an arbitrary expression into ``self``.
@@ -226,6 +240,7 @@ def interpolate(self, expr, increment=False, self_subs={}, implicit_dims=None):
226240
"""
227241
return Interpolation(expr, increment, implicit_dims, self_subs, self)
228242

243+
@check_radius
229244
def inject(self, field, expr, implicit_dims=None):
230245
"""
231246
Generate equations injecting an arbitrary expression into a field.

devito/types/sparse.py

+4
Original file line numberDiff line numberDiff line change
@@ -1008,6 +1008,8 @@ class PrecomputedSparseFunction(AbstractSparseFunction):
10081008
uses `*args` to (re-)create the Dimension arguments of the symbolic object.
10091009
"""
10101010

1011+
is_SparseFunction = True
1012+
10111013
_sub_functions = ('gridpoints', 'coordinates', 'interpolation_coeffs')
10121014

10131015
__rkwargs__ = (AbstractSparseFunction.__rkwargs__ +
@@ -1173,6 +1175,8 @@ class PrecomputedSparseTimeFunction(AbstractSparseTimeFunction,
11731175
uses ``*args`` to (re-)create the Dimension arguments of the symbolic object.
11741176
"""
11751177

1178+
is_SparseTimeFunction = True
1179+
11761180
__rkwargs__ = tuple(filter_ordered(AbstractSparseTimeFunction.__rkwargs__ +
11771181
PrecomputedSparseFunction.__rkwargs__))
11781182

tests/test_dle.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -709,7 +709,7 @@ def test_scheduling(self):
709709
"""
710710
grid = Grid(shape=(11, 11))
711711

712-
u = TimeFunction(name='u', grid=grid, time_order=2, save=5, space_order=0)
712+
u = TimeFunction(name='u', grid=grid, time_order=2, save=5, space_order=1)
713713
sf1 = SparseTimeFunction(name='s', grid=grid, npoint=1, nt=5)
714714

715715
eqns = [Eq(u.forward, u + 1)]

tests/test_gpu_common.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1424,7 +1424,7 @@ def test_empty_arrays(self):
14241424
"""
14251425
grid = Grid(shape=(4, 4), extent=(3.0, 3.0))
14261426

1427-
f = TimeFunction(name='f', grid=grid, space_order=0)
1427+
f = TimeFunction(name='f', grid=grid, space_order=1)
14281428
f.data[:] = 1.
14291429
sf1 = SparseTimeFunction(name='sf1', grid=grid, npoint=0, nt=10)
14301430
sf2 = SparseTimeFunction(name='sf2', grid=grid, npoint=0, nt=10,

tests/test_interpolation.py

+23-21
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,19 @@
1414
import scipy.sparse
1515

1616

17-
def unit_box(name='a', shape=(11, 11), grid=None):
17+
def unit_box(name='a', shape=(11, 11), grid=None, space_order=1):
1818
"""Create a field with value 0. to 1. in each dimension"""
1919
grid = grid or Grid(shape=shape)
20-
a = Function(name=name, grid=grid)
20+
a = Function(name=name, grid=grid, space_order=space_order)
2121
dims = tuple([np.linspace(0., 1., d) for d in shape])
2222
a.data[:] = np.meshgrid(*dims)[1]
2323
return a
2424

2525

26-
def unit_box_time(name='a', shape=(11, 11)):
26+
def unit_box_time(name='a', shape=(11, 11), space_order=1):
2727
"""Create a field with value 0. to 1. in each dimension"""
2828
grid = Grid(shape=shape)
29-
a = TimeFunction(name=name, grid=grid, time_order=1)
29+
a = TimeFunction(name=name, grid=grid, time_order=1, space_order=space_order)
3030
dims = tuple([np.linspace(0., 1., d) for d in shape])
3131
a.data[0, :] = np.meshgrid(*dims)[1]
3232
a.data[1, :] = np.meshgrid(*dims)[1]
@@ -117,16 +117,15 @@ def test_precomputed_interpolation(r):
117117
origin = (0, 0)
118118

119119
grid = Grid(shape=shape, origin=origin)
120-
r = 2 # Constant for linear interpolation
121-
# because we interpolate across 2 neighbouring points in each dimension
122120

123121
def init(data):
122+
# This is data with halo so need to shift to match the m.data expectations
124123
for i in range(data.shape[0]):
125124
for j in range(data.shape[1]):
126-
data[i, j] = sin(grid.spacing[0]*i) + sin(grid.spacing[1]*j)
125+
data[i, j] = sin(grid.spacing[0]*(i-r)) + sin(grid.spacing[1]*(j-r))
127126
return data
128127

129-
m = Function(name='m', grid=grid, initializer=init, space_order=0)
128+
m = Function(name='m', grid=grid, initializer=init, space_order=r)
130129

131130
gridpoints, interpolation_coeffs = precompute_linear_interpolation(points,
132131
grid, origin,
@@ -154,10 +153,8 @@ def test_precomputed_interpolation_time(r):
154153
origin = (0, 0)
155154

156155
grid = Grid(shape=shape, origin=origin)
157-
r = 2 # Constant for linear interpolation
158-
# because we interpolate across 2 neighbouring points in each dimension
159156

160-
u = TimeFunction(name='u', grid=grid, space_order=0, save=5)
157+
u = TimeFunction(name='u', grid=grid, space_order=r, save=5)
161158
for it in range(5):
162159
u.data[it, :] = it
163160

@@ -190,11 +187,7 @@ def test_precomputed_injection(r):
190187
origin = (0, 0)
191188
result = 0.25
192189

193-
# Constant for linear interpolation
194-
# because we interpolate across 2 neighbouring points in each dimension
195-
r = 2
196-
197-
m = unit_box(shape=shape)
190+
m = unit_box(shape=shape, space_order=r)
198191
m.data[:] = 0.
199192

200193
gridpoints, interpolation_coeffs = precompute_linear_interpolation(coords,
@@ -228,11 +221,7 @@ def test_precomputed_injection_time(r):
228221
result = 0.25
229222
nt = 20
230223

231-
# Constant for linear interpolation
232-
# because we interpolate across 2 neighbouring points in each dimension
233-
r = 2
234-
235-
m = unit_box_time(shape=shape)
224+
m = unit_box_time(shape=shape, space_order=r)
236225
m.data[:] = 0.
237226

238227
gridpoints, interpolation_coeffs = precompute_linear_interpolation(coords,
@@ -761,3 +750,16 @@ def test_inject_function():
761750
for i in [0, 1, 3, 4]:
762751
for j in [0, 1, 3, 4]:
763752
assert u.data[1, i, j] == 0
753+
754+
755+
def test_interpolation_radius():
756+
nt = 11
757+
758+
grid = Grid(shape=(5, 5))
759+
u = TimeFunction(name="u", grid=grid, space_order=0)
760+
src = SparseTimeFunction(name="src", grid=grid, nt=nt, npoint=1)
761+
try:
762+
src.interpolate(u)
763+
assert False
764+
except ValueError:
765+
assert True

tests/test_mpi.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1501,7 +1501,7 @@ def test_injection_wodup(self):
15011501
"""
15021502
grid = Grid(shape=(4, 4), extent=(3.0, 3.0))
15031503

1504-
f = Function(name='f', grid=grid, space_order=0)
1504+
f = Function(name='f', grid=grid, space_order=1)
15051505
f.data[:] = 0.
15061506
coords = np.array([(0.5, 0.5), (0.5, 2.5), (2.5, 0.5), (2.5, 2.5)])
15071507
sf = SparseFunction(name='sf', grid=grid, npoint=len(coords), coordinates=coords)
@@ -1536,7 +1536,7 @@ def test_injection_wodup_wtime(self):
15361536
grid = Grid(shape=(4, 4), extent=(3.0, 3.0))
15371537

15381538
save = 3
1539-
f = TimeFunction(name='f', grid=grid, save=save, space_order=0)
1539+
f = TimeFunction(name='f', grid=grid, save=save, space_order=1)
15401540
f.data[:] = 0.
15411541
coords = np.array([(0.5, 0.5), (0.5, 2.5), (2.5, 0.5), (2.5, 2.5)])
15421542
sf = SparseTimeFunction(name='sf', grid=grid, nt=save,
@@ -1611,7 +1611,7 @@ def test_injection_dup(self):
16111611
def test_interpolation_wodup(self):
16121612
grid = Grid(shape=(4, 4), extent=(3.0, 3.0))
16131613

1614-
f = Function(name='f', grid=grid, space_order=0)
1614+
f = Function(name='f', grid=grid, space_order=1)
16151615
f.data[:] = 4.
16161616
coords = [(0.5, 0.5), (0.5, 2.5), (2.5, 0.5), (2.5, 2.5)]
16171617
sf = SparseFunction(name='sf', grid=grid, npoint=len(coords), coordinates=coords)

tests/test_operator.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,7 @@ def test_sparsefunction_inject(self):
521521
Test injection of a SparseFunction into a Function
522522
"""
523523
grid = Grid(shape=(11, 11))
524-
u = Function(name='u', grid=grid, space_order=0)
524+
u = Function(name='u', grid=grid, space_order=1)
525525

526526
sf1 = SparseFunction(name='s', grid=grid, npoint=1)
527527
op = Operator(sf1.inject(u, expr=sf1))
@@ -542,7 +542,7 @@ def test_sparsefunction_interp(self):
542542
Test interpolation of a SparseFunction from a Function
543543
"""
544544
grid = Grid(shape=(11, 11))
545-
u = Function(name='u', grid=grid, space_order=0)
545+
u = Function(name='u', grid=grid, space_order=1)
546546

547547
sf1 = SparseFunction(name='s', grid=grid, npoint=1)
548548
op = Operator(sf1.interpolate(u))
@@ -563,7 +563,7 @@ def test_sparsetimefunction_interp(self):
563563
Test injection of a SparseTimeFunction into a TimeFunction
564564
"""
565565
grid = Grid(shape=(11, 11))
566-
u = TimeFunction(name='u', grid=grid, time_order=2, save=5, space_order=0)
566+
u = TimeFunction(name='u', grid=grid, time_order=2, save=5, space_order=1)
567567

568568
sf1 = SparseTimeFunction(name='s', grid=grid, npoint=1, nt=5)
569569
op = Operator(sf1.interpolate(u))
@@ -586,7 +586,7 @@ def test_sparsetimefunction_inject(self):
586586
Test injection of a SparseTimeFunction from a TimeFunction
587587
"""
588588
grid = Grid(shape=(11, 11))
589-
u = TimeFunction(name='u', grid=grid, time_order=2, save=5, space_order=0)
589+
u = TimeFunction(name='u', grid=grid, time_order=2, save=5, space_order=1)
590590

591591
sf1 = SparseTimeFunction(name='s', grid=grid, npoint=1, nt=5)
592592
op = Operator(sf1.inject(u, expr=3*sf1))
@@ -611,7 +611,7 @@ def test_sparsetimefunction_inject_dt(self):
611611
Test injection of the time deivative of a SparseTimeFunction into a TimeFunction
612612
"""
613613
grid = Grid(shape=(11, 11))
614-
u = TimeFunction(name='u', grid=grid, time_order=2, save=5, space_order=0)
614+
u = TimeFunction(name='u', grid=grid, time_order=2, save=5, space_order=1)
615615

616616
sf1 = SparseTimeFunction(name='s', grid=grid, npoint=1, nt=5, time_order=2)
617617

0 commit comments

Comments
 (0)