Skip to content

Commit 2c80bf8

Browse files
committed
compiler: generate std:complex for cpp compilers
1 parent d2dc9ee commit 2c80bf8

File tree

3 files changed

+34
-20
lines changed

3 files changed

+34
-20
lines changed

devito/ir/iet/visitors.py

+30-13
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import ctypes
1111

1212
import cgen as c
13+
import numpy as np
1314
from sympy import IndexedBase
1415
from sympy.core.function import Application
1516

@@ -188,6 +189,21 @@ def __init__(self, *args, compiler=None, **kwargs):
188189
}
189190
_restrict_keyword = 'restrict'
190191

192+
def _complex_type(self, ctypestr, dtype):
193+
# Not complex
194+
try:
195+
if not np.issubdtype(dtype, np.complexfloating):
196+
return ctypestr
197+
except TypeError:
198+
return ctypestr
199+
# Complex only supported for float and double
200+
if ctypestr not in ('float', 'double'):
201+
return ctypestr
202+
if self._compiler._cpp:
203+
return 'std::complex<%s>' % ctypestr
204+
else:
205+
return '%s _Complex' % ctypestr
206+
191207
def _gen_struct_decl(self, obj, masked=()):
192208
"""
193209
Convert ctypes.Struct -> cgen.Structure.
@@ -243,10 +259,10 @@ def _gen_value(self, obj, mode=1, masked=()):
243259
if getattr(obj.function, k, False) and v not in masked]
244260

245261
if (obj._mem_stack or obj._mem_constant) and mode == 1:
246-
strtype = obj._C_typedata
262+
strtype = self._complex_type(obj._C_typedata, obj.dtype)
247263
strshape = ''.join('[%s]' % ccode(i) for i in obj.symbolic_shape)
248264
else:
249-
strtype = ctypes_to_cstr(obj._C_ctype)
265+
strtype = self._complex_type(ctypes_to_cstr(obj._C_ctype), obj.dtype)
250266
strshape = ''
251267
if isinstance(obj, (AbstractFunction, IndexedData)) and mode >= 1:
252268
if not obj._mem_stack:
@@ -376,10 +392,11 @@ def visit_tuple(self, o):
376392
def visit_PointerCast(self, o):
377393
f = o.function
378394
i = f.indexed
395+
cstr = self._complex_type(i._C_typedata, i.dtype)
379396

380397
if f.is_PointerArray:
381398
# lvalue
382-
lvalue = c.Value(i._C_typedata, '**%s' % f.name)
399+
lvalue = c.Value(cstr, '**%s' % f.name)
383400

384401
# rvalue
385402
if isinstance(o.obj, ArrayObject):
@@ -388,7 +405,7 @@ def visit_PointerCast(self, o):
388405
v = f._C_name
389406
else:
390407
assert False
391-
rvalue = '(%s**) %s' % (i._C_typedata, v)
408+
rvalue = '(%s**) %s' % (cstr, v)
392409

393410
else:
394411
# lvalue
@@ -399,10 +416,10 @@ def visit_PointerCast(self, o):
399416
if o.flat is None:
400417
shape = ''.join("[%s]" % ccode(i) for i in o.castshape)
401418
rshape = '(*)%s' % shape
402-
lvalue = c.Value(i._C_typedata, '(*restrict %s)%s' % (v, shape))
419+
lvalue = c.Value(cstr, '(*restrict %s)%s' % (v, shape))
403420
else:
404421
rshape = '*'
405-
lvalue = c.Value(i._C_typedata, '*%s' % v)
422+
lvalue = c.Value(cstr, '*%s' % v)
406423
if o.alignment and f._data_alignment:
407424
lvalue = c.AlignedAttribute(f._data_alignment, lvalue)
408425

@@ -415,30 +432,30 @@ def visit_PointerCast(self, o):
415432
else:
416433
assert False
417434

418-
rvalue = '(%s %s) %s->%s' % (i._C_typedata, rshape, f._C_name, v)
435+
rvalue = '(%s %s) %s->%s' % (cstr, rshape, f._C_name, v)
419436
else:
420437
if isinstance(o.obj, Pointer):
421438
v = o.obj.name
422439
else:
423440
v = f._C_name
424441

425-
rvalue = '(%s %s) %s' % (i._C_typedata, rshape, v)
442+
rvalue = '(%s %s) %s' % (cstr, rshape, v)
426443

427444
return c.Initializer(lvalue, rvalue)
428445

429446
def visit_Dereference(self, o):
430447
a0, a1 = o.functions
431448
if a1.is_PointerArray or a1.is_TempFunction:
432449
i = a1.indexed
450+
cstr = self._complex_type(i._C_typedata, i.dtype)
433451
if o.flat is None:
434452
shape = ''.join("[%s]" % ccode(i) for i in a0.symbolic_shape[1:])
435-
rvalue = '(%s (*)%s) %s[%s]' % (i._C_typedata, shape, a1.name,
453+
rvalue = '(%s (*)%s) %s[%s]' % (cstr, shape, a1.name,
436454
a1.dim.name)
437-
lvalue = c.Value(i._C_typedata,
438-
'(*restrict %s)%s' % (a0.name, shape))
455+
lvalue = c.Value(cstr, '(*restrict %s)%s' % (a0.name, shape))
439456
else:
440-
rvalue = '(%s *) %s[%s]' % (i._C_typedata, a1.name, a1.dim.name)
441-
lvalue = c.Value(i._C_typedata, '*restrict %s' % a0.name)
457+
rvalue = '(%s *) %s[%s]' % (cstr, a1.name, a1.dim.name)
458+
lvalue = c.Value(cstr, '*restrict %s' % a0.name)
442459
if a0._data_alignment:
443460
lvalue = c.AlignedAttribute(a0._data_alignment, lvalue)
444461
else:

devito/passes/iet/misc.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -200,12 +200,12 @@ def complex_include(iet, language, compiler):
200200
"""
201201
Add headers for complex arithmetic
202202
"""
203-
lib = _complex_lib.get(language, 'complex.h')
203+
lib = _complex_lib.get(language, 'complex' if compiler._cpp else 'complex.h')
204204

205205
headers = {}
206206
# For openacc (cpp) need to define constant _Complex_I that isn't found otherwise
207207
if compiler._cpp:
208-
headers = {('_Complex_I', ('1.0fi'))}
208+
headers = {('_Complex_I', ('std::complex<float>(0.0f, 1.0f)'))}
209209

210210
for f in FindSymbols().visit(iet):
211211
try:

devito/tools/dtypes_lowering.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -139,10 +139,7 @@ def dtype_to_ctype(dtype):
139139
# Complex data
140140
if np.issubdtype(dtype, np.complexfloating):
141141
rtype = dtype(0).real.__class__
142-
ctname = '%s _Complex' % dtype_to_cstr(rtype)
143-
ctype = dtype_to_ctype(rtype)
144-
r = type(ctname, (ctype,), {})
145-
return r
142+
return dtype_to_ctype(rtype)
146143

147144
try:
148145
return ctypes_vector_mapper[dtype]
@@ -217,7 +214,7 @@ class c_restrict_void_p(ctypes.c_void_p):
217214
# *** ctypes lowering
218215

219216

220-
def ctypes_to_cstr(ctype, toarray=None):
217+
def ctypes_to_cstr(ctype, toarray=None, cpp=False):
221218
"""Translate ctypes types into C strings."""
222219
if ctype in ctypes_vector_mapper.values():
223220
retval = ctype.__name__

0 commit comments

Comments
 (0)