Skip to content

Commit ed90efb

Browse files
committed
compiler: fix alias dtype with complex numbers
1 parent e6b2fe3 commit ed90efb

File tree

4 files changed

+27
-3
lines changed

4 files changed

+27
-3
lines changed

devito/symbolics/inspection.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -304,4 +304,10 @@ def sympy_dtype(expr, base=None):
304304
dtypes.add(i.dtype)
305305
except AttributeError:
306306
pass
307-
return infer_dtype(dtypes)
307+
dtype = infer_dtype(dtypes)
308+
309+
# Promote if complex
310+
if expr.has(ImaginaryUnit):
311+
dtype = np.promote_types(dtype, np.complex64).type
312+
313+
return dtype

devito/types/basic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,7 @@ def _C_ctype(self):
436436
return self.dtype
437437
elif np.issubdtype(self.dtype, np.complexfloating):
438438
rtype = self.dtype(0).real.__class__
439-
ctname = '%s _Complex' % dtype_to_cstr(rtype)
439+
ctname = '%s complex' % dtype_to_cstr(rtype)
440440
ctype = dtype_to_ctype(rtype)
441441
r = type(ctname, (ctype,), {})
442442
return r

tests/test_gpu_common.py

+18
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,24 @@ def test_maxpar_option(self):
6666
assert trees[0][0] is trees[1][0]
6767
assert trees[0][1] is not trees[1][1]
6868

69+
def test_complex(self):
70+
grid = Grid((5, 5))
71+
x, y = grid.dimensions
72+
# Float32 complex is called complex64 in numpy
73+
u = Function(name="u", grid=grid, dtype=np.complex64)
74+
75+
eq = Eq(u, x + 1j*y + exp(1j + x.spacing))
76+
# Currently wrong alias type
77+
op = Operator(eq)
78+
op()
79+
80+
# Check against numpy
81+
dx = grid.spacing_map[x.spacing]
82+
xx, yy = np.meshgrid(np.linspace(0, 4, 5), np.linspace(0, 4, 5))
83+
npres = xx + 1j*yy + np.exp(1j + dx)
84+
85+
assert np.allclose(u.data, npres.T, rtol=1e-7, atol=0)
86+
6987

7088
class TestPassesOptional(object):
7189

tests/test_operator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,7 @@ def test_complex(self):
648648

649649
eq = Eq(u, x + 1j*y + exp(1j + x.spacing))
650650
# Currently wrong alias type
651-
op = Operator(eq, opt='noop')
651+
op = Operator(eq)
652652
op()
653653

654654
# Check against numpy

0 commit comments

Comments
 (0)