Skip to content

Commit e6b2fe3

Browse files
committed
api: fix printer for complex dtype
1 parent 34c8314 commit e6b2fe3

File tree

6 files changed

+29
-11
lines changed

6 files changed

+29
-11
lines changed

Diff for: devito/finite_differences/differentiable.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
from devito.logger import warning
1515
from devito.tools import (as_tuple, filter_ordered, flatten, frozendict,
1616
infer_dtype, is_integer, split)
17-
from devito.types import (Array, DimensionTuple, Evaluable, Indexed,
18-
StencilDimension)
17+
from devito.types import Array, DimensionTuple, Evaluable, StencilDimension
1918

2019
__all__ = ['Differentiable', 'DiffDerivative', 'IndexDerivative', 'EvalDerivative',
2120
'Weights']

Diff for: devito/passes/iet/misc.py

-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import cgen
44
import numpy as np
55
import sympy
6-
import numpy as np
76

87
from devito import configuration
98
from devito.finite_differences import Max, Min

Diff for: devito/symbolics/inspection.py

-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import numpy as np
44
from sympy import (Function, Indexed, Integer, Mul, Number,
55
Pow, S, Symbol, Tuple)
6-
from sympy.core.operations import AssocOp
76
from sympy.core.numbers import ImaginaryUnit
87

98
from devito.finite_differences import Derivative

Diff for: devito/symbolics/printer.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@ def _print_math_func(self, expr, nest=False, known=None):
103103
dtype = sympy_dtype(expr)
104104
if dtype is np.float32:
105105
cname += 'f'
106-
106+
if np.issubdtype(self.dtype, np.complexfloating):
107+
cname = 'c%s' % cname
107108
args = ', '.join((self._print(arg) for arg in expr.args))
108109

109110
return '%s(%s)' % (cname, args)
@@ -250,6 +251,8 @@ def _print_TrigonometricFunction(self, expr):
250251
func_name = str(expr.func)
251252
if self.dtype == np.float32:
252253
func_name += 'f'
254+
if np.issubdtype(self.dtype, np.complexfloating):
255+
func_name = 'c%s' % func_name
253256
return '%s(%s)' % (func_name, self._print(*expr.args))
254257

255258
def _print_DefFunction(self, expr):

Diff for: devito/types/basic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1440,7 +1440,7 @@ def _C_ctype(self):
14401440
try:
14411441
if np.issubdtype(self.dtype, np.complexfloating):
14421442
rtype = self.dtype(0).real.__class__
1443-
ctname = '%s _Complex' % dtype_to_cstr(rtype)
1443+
ctname = '%s complex' % dtype_to_cstr(rtype)
14441444
ctype = dtype_to_ctype(rtype)
14451445
r = type(ctname, (ctype,), {})
14461446
return POINTER(r)

Diff for: tests/test_operator.py

+23-5
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
SparseFunction, SparseTimeFunction, Dimension, error, SpaceDimension,
1010
NODE, CELL, dimensions, configuration, TensorFunction,
1111
TensorTimeFunction, VectorFunction, VectorTimeFunction,
12-
div, grad, switchconfig)
12+
div, grad, switchconfig, exp)
1313
from devito import Inc, Le, Lt, Ge, Gt # noqa
1414
from devito.exceptions import InvalidOperator
1515
from devito.finite_differences.differentiable import diff2sympy
@@ -640,6 +640,24 @@ def test_tensor(self, func1):
640640
op2 = Operator([Eq(f, f.dx) for f in f1.values()])
641641
assert str(op1.ccode) == str(op2.ccode)
642642

643+
def test_complex(self):
644+
grid = Grid((5, 5))
645+
x, y = grid.dimensions
646+
# Float32 complex is called complex64 in numpy
647+
u = Function(name="u", grid=grid, dtype=np.complex64)
648+
649+
eq = Eq(u, x + 1j*y + exp(1j + x.spacing))
650+
# Currently wrong alias type
651+
op = Operator(eq, opt='noop')
652+
op()
653+
654+
# Check against numpy
655+
dx = grid.spacing_map[x.spacing]
656+
xx, yy = np.meshgrid(np.linspace(0, 4, 5), np.linspace(0, 4, 5))
657+
npres = xx + 1j*yy + np.exp(1j + dx)
658+
659+
assert np.allclose(u.data, npres.T, rtol=1e-7, atol=0)
660+
643661

644662
class TestAllocation(object):
645663

@@ -724,10 +742,10 @@ def verify_parameters(self, parameters, expected):
724742
"""
725743
boilerplate = ['timers']
726744
parameters = [p.name for p in parameters]
727-
for exp in expected:
728-
if exp not in parameters + boilerplate:
729-
error("Missing parameter: %s" % exp)
730-
assert exp in parameters + boilerplate
745+
for expi in expected:
746+
if expi not in parameters + boilerplate:
747+
error("Missing parameter: %s" % expi)
748+
assert expi in parameters + boilerplate
731749
extra = [p for p in parameters if p not in expected and p not in boilerplate]
732750
if len(extra) > 0:
733751
error("Redundant parameters: %s" % str(extra))

0 commit comments

Comments
 (0)