Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 0268781

Browse files
committedMay 28, 2024··
compiler: generate std:complex for cpp compilers
1 parent d2dc9ee commit 0268781

File tree

3 files changed

+31
-20
lines changed

3 files changed

+31
-20
lines changed
 

‎devito/ir/iet/visitors.py

+27-13
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import ctypes
1111

1212
import cgen as c
13+
import numpy as np
1314
from sympy import IndexedBase
1415
from sympy.core.function import Application
1516

@@ -188,6 +189,18 @@ def __init__(self, *args, compiler=None, **kwargs):
188189
}
189190
_restrict_keyword = 'restrict'
190191

192+
def _complex_type(self, ctypestr, dtype):
193+
# Not complex
194+
if not np.issubdtype(dtype, np.complexfloating):
195+
return ctypestr
196+
# Complex only supported for float and double
197+
if ctypestr not in ('float', 'double'):
198+
return ctypestr
199+
if self._compiler._cpp:
200+
return 'std::complex<%s>' % ctypestr
201+
else:
202+
return '%s _Complex' % ctypestr
203+
191204
def _gen_struct_decl(self, obj, masked=()):
192205
"""
193206
Convert ctypes.Struct -> cgen.Structure.
@@ -243,10 +256,10 @@ def _gen_value(self, obj, mode=1, masked=()):
243256
if getattr(obj.function, k, False) and v not in masked]
244257

245258
if (obj._mem_stack or obj._mem_constant) and mode == 1:
246-
strtype = obj._C_typedata
259+
strtype = self._complex_type(obj._C_typedata, obj.dtype)
247260
strshape = ''.join('[%s]' % ccode(i) for i in obj.symbolic_shape)
248261
else:
249-
strtype = ctypes_to_cstr(obj._C_ctype)
262+
strtype = self._complex_type(ctypes_to_cstr(obj._C_ctype), obj.dtype)
250263
strshape = ''
251264
if isinstance(obj, (AbstractFunction, IndexedData)) and mode >= 1:
252265
if not obj._mem_stack:
@@ -376,10 +389,11 @@ def visit_tuple(self, o):
376389
def visit_PointerCast(self, o):
377390
f = o.function
378391
i = f.indexed
392+
cstr = self._complex_type(i._C_typedata, i.dtype)
379393

380394
if f.is_PointerArray:
381395
# lvalue
382-
lvalue = c.Value(i._C_typedata, '**%s' % f.name)
396+
lvalue = c.Value(cstr, '**%s' % f.name)
383397

384398
# rvalue
385399
if isinstance(o.obj, ArrayObject):
@@ -388,7 +402,7 @@ def visit_PointerCast(self, o):
388402
v = f._C_name
389403
else:
390404
assert False
391-
rvalue = '(%s**) %s' % (i._C_typedata, v)
405+
rvalue = '(%s**) %s' % (cstr, v)
392406

393407
else:
394408
# lvalue
@@ -399,10 +413,10 @@ def visit_PointerCast(self, o):
399413
if o.flat is None:
400414
shape = ''.join("[%s]" % ccode(i) for i in o.castshape)
401415
rshape = '(*)%s' % shape
402-
lvalue = c.Value(i._C_typedata, '(*restrict %s)%s' % (v, shape))
416+
lvalue = c.Value(cstr, '(*restrict %s)%s' % (v, shape))
403417
else:
404418
rshape = '*'
405-
lvalue = c.Value(i._C_typedata, '*%s' % v)
419+
lvalue = c.Value(cstr, '*%s' % v)
406420
if o.alignment and f._data_alignment:
407421
lvalue = c.AlignedAttribute(f._data_alignment, lvalue)
408422

@@ -415,30 +429,30 @@ def visit_PointerCast(self, o):
415429
else:
416430
assert False
417431

418-
rvalue = '(%s %s) %s->%s' % (i._C_typedata, rshape, f._C_name, v)
432+
rvalue = '(%s %s) %s->%s' % (cstr, rshape, f._C_name, v)
419433
else:
420434
if isinstance(o.obj, Pointer):
421435
v = o.obj.name
422436
else:
423437
v = f._C_name
424438

425-
rvalue = '(%s %s) %s' % (i._C_typedata, rshape, v)
439+
rvalue = '(%s %s) %s' % (cstr, rshape, v)
426440

427441
return c.Initializer(lvalue, rvalue)
428442

429443
def visit_Dereference(self, o):
430444
a0, a1 = o.functions
431445
if a1.is_PointerArray or a1.is_TempFunction:
432446
i = a1.indexed
447+
cstr = self._complex_type(i._C_typedata, i.dtype)
433448
if o.flat is None:
434449
shape = ''.join("[%s]" % ccode(i) for i in a0.symbolic_shape[1:])
435-
rvalue = '(%s (*)%s) %s[%s]' % (i._C_typedata, shape, a1.name,
450+
rvalue = '(%s (*)%s) %s[%s]' % (cstr, shape, a1.name,
436451
a1.dim.name)
437-
lvalue = c.Value(i._C_typedata,
438-
'(*restrict %s)%s' % (a0.name, shape))
452+
lvalue = c.Value(cstr, '(*restrict %s)%s' % (a0.name, shape))
439453
else:
440-
rvalue = '(%s *) %s[%s]' % (i._C_typedata, a1.name, a1.dim.name)
441-
lvalue = c.Value(i._C_typedata, '*restrict %s' % a0.name)
454+
rvalue = '(%s *) %s[%s]' % (cstr, a1.name, a1.dim.name)
455+
lvalue = c.Value(cstr, '*restrict %s' % a0.name)
442456
if a0._data_alignment:
443457
lvalue = c.AlignedAttribute(a0._data_alignment, lvalue)
444458
else:

‎devito/passes/iet/misc.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -200,12 +200,12 @@ def complex_include(iet, language, compiler):
200200
"""
201201
Add headers for complex arithmetic
202202
"""
203-
lib = _complex_lib.get(language, 'complex.h')
203+
lib = _complex_lib.get(language, 'complex' if compiler._cpp else 'complex.h')
204204

205205
headers = {}
206206
# For openacc (cpp) need to define constant _Complex_I that isn't found otherwise
207207
if compiler._cpp:
208-
headers = {('_Complex_I', ('1.0fi'))}
208+
headers = {('_Complex_I', ('std::complex<float>(0.0f, 1.0f)'))}
209209

210210
for f in FindSymbols().visit(iet):
211211
try:

‎devito/tools/dtypes_lowering.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -139,10 +139,7 @@ def dtype_to_ctype(dtype):
139139
# Complex data
140140
if np.issubdtype(dtype, np.complexfloating):
141141
rtype = dtype(0).real.__class__
142-
ctname = '%s _Complex' % dtype_to_cstr(rtype)
143-
ctype = dtype_to_ctype(rtype)
144-
r = type(ctname, (ctype,), {})
145-
return r
142+
return dtype_to_ctype(rtype)
146143

147144
try:
148145
return ctypes_vector_mapper[dtype]
@@ -217,7 +214,7 @@ class c_restrict_void_p(ctypes.c_void_p):
217214
# *** ctypes lowering
218215

219216

220-
def ctypes_to_cstr(ctype, toarray=None):
217+
def ctypes_to_cstr(ctype, toarray=None, cpp=False):
221218
"""Translate ctypes types into C strings."""
222219
if ctype in ctypes_vector_mapper.values():
223220
retval = ctype.__name__

0 commit comments

Comments
 (0)
Please sign in to comment.