|
14 | 14 | import scipy.sparse
|
15 | 15 |
|
16 | 16 |
|
17 |
| -def unit_box(name='a', shape=(11, 11), grid=None): |
| 17 | +def unit_box(name='a', shape=(11, 11), grid=None, space_order=1): |
18 | 18 | """Create a field with value 0. to 1. in each dimension"""
|
19 | 19 | grid = grid or Grid(shape=shape)
|
20 |
| - a = Function(name=name, grid=grid) |
| 20 | + a = Function(name=name, grid=grid, space_order=space_order) |
21 | 21 | dims = tuple([np.linspace(0., 1., d) for d in shape])
|
22 | 22 | a.data[:] = np.meshgrid(*dims)[1]
|
23 | 23 | return a
|
24 | 24 |
|
25 | 25 |
|
26 |
| -def unit_box_time(name='a', shape=(11, 11)): |
| 26 | +def unit_box_time(name='a', shape=(11, 11), space_order=1): |
27 | 27 | """Create a field with value 0. to 1. in each dimension"""
|
28 | 28 | grid = Grid(shape=shape)
|
29 |
| - a = TimeFunction(name=name, grid=grid, time_order=1) |
| 29 | + a = TimeFunction(name=name, grid=grid, time_order=1, space_order=space_order) |
30 | 30 | dims = tuple([np.linspace(0., 1., d) for d in shape])
|
31 | 31 | a.data[0, :] = np.meshgrid(*dims)[1]
|
32 | 32 | a.data[1, :] = np.meshgrid(*dims)[1]
|
@@ -117,16 +117,15 @@ def test_precomputed_interpolation(r):
|
117 | 117 | origin = (0, 0)
|
118 | 118 |
|
119 | 119 | grid = Grid(shape=shape, origin=origin)
|
120 |
| - r = 2 # Constant for linear interpolation |
121 |
| - # because we interpolate across 2 neighbouring points in each dimension |
122 | 120 |
|
123 | 121 | def init(data):
|
| 122 | + # This is data with halo so need to shift to match the m.data expectations |
124 | 123 | for i in range(data.shape[0]):
|
125 | 124 | for j in range(data.shape[1]):
|
126 |
| - data[i, j] = sin(grid.spacing[0]*i) + sin(grid.spacing[1]*j) |
| 125 | + data[i, j] = sin(grid.spacing[0]*(i-r)) + sin(grid.spacing[1]*(j-r)) |
127 | 126 | return data
|
128 | 127 |
|
129 |
| - m = Function(name='m', grid=grid, initializer=init, space_order=0) |
| 128 | + m = Function(name='m', grid=grid, initializer=init, space_order=r) |
130 | 129 |
|
131 | 130 | gridpoints, interpolation_coeffs = precompute_linear_interpolation(points,
|
132 | 131 | grid, origin,
|
@@ -154,10 +153,8 @@ def test_precomputed_interpolation_time(r):
|
154 | 153 | origin = (0, 0)
|
155 | 154 |
|
156 | 155 | grid = Grid(shape=shape, origin=origin)
|
157 |
| - r = 2 # Constant for linear interpolation |
158 |
| - # because we interpolate across 2 neighbouring points in each dimension |
159 | 156 |
|
160 |
| - u = TimeFunction(name='u', grid=grid, space_order=0, save=5) |
| 157 | + u = TimeFunction(name='u', grid=grid, space_order=r, save=5) |
161 | 158 | for it in range(5):
|
162 | 159 | u.data[it, :] = it
|
163 | 160 |
|
@@ -190,11 +187,7 @@ def test_precomputed_injection(r):
|
190 | 187 | origin = (0, 0)
|
191 | 188 | result = 0.25
|
192 | 189 |
|
193 |
| - # Constant for linear interpolation |
194 |
| - # because we interpolate across 2 neighbouring points in each dimension |
195 |
| - r = 2 |
196 |
| - |
197 |
| - m = unit_box(shape=shape) |
| 190 | + m = unit_box(shape=shape, space_order=r) |
198 | 191 | m.data[:] = 0.
|
199 | 192 |
|
200 | 193 | gridpoints, interpolation_coeffs = precompute_linear_interpolation(coords,
|
@@ -228,11 +221,7 @@ def test_precomputed_injection_time(r):
|
228 | 221 | result = 0.25
|
229 | 222 | nt = 20
|
230 | 223 |
|
231 |
| - # Constant for linear interpolation |
232 |
| - # because we interpolate across 2 neighbouring points in each dimension |
233 |
| - r = 2 |
234 |
| - |
235 |
| - m = unit_box_time(shape=shape) |
| 224 | + m = unit_box_time(shape=shape, space_order=r) |
236 | 225 | m.data[:] = 0.
|
237 | 226 |
|
238 | 227 | gridpoints, interpolation_coeffs = precompute_linear_interpolation(coords,
|
@@ -761,3 +750,16 @@ def test_inject_function():
|
761 | 750 | for i in [0, 1, 3, 4]:
|
762 | 751 | for j in [0, 1, 3, 4]:
|
763 | 752 | assert u.data[1, i, j] == 0
|
| 753 | + |
| 754 | + |
| 755 | +def test_interpolation_radius(): |
| 756 | + nt = 11 |
| 757 | + |
| 758 | + grid = Grid(shape=(5, 5)) |
| 759 | + u = TimeFunction(name="u", grid=grid, space_order=0) |
| 760 | + src = SparseTimeFunction(name="src", grid=grid, nt=nt, npoint=1) |
| 761 | + try: |
| 762 | + src.interpolate(u) |
| 763 | + assert False |
| 764 | + except ValueError: |
| 765 | + assert True |
0 commit comments