Skip to content

Commit e7a2791

Browse files
committed
compiler: fix alias dtype with complex numbers
1 parent 014ef2c commit e7a2791

File tree

9 files changed

+86
-46
lines changed

9 files changed

+86
-46
lines changed

devito/__init__.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ def reinit_compiler(val):
5252
"""
5353
Re-initialize the Compiler.
5454
"""
55-
configuration['compiler'].__init__(suffix=configuration['compiler'].suffix,
55+
configuration['compiler'].__init__(name=configuration['compiler'].name,
56+
suffix=configuration['compiler'].suffix,
5657
mpi=configuration['mpi'])
5758
return val
5859

@@ -61,7 +62,7 @@ def reinit_compiler(val):
6162
configuration.add('platform', 'cpu64', list(platform_registry),
6263
callback=lambda i: platform_registry[i]())
6364
configuration.add('compiler', 'custom', list(compiler_registry),
64-
callback=lambda i: compiler_registry[i]())
65+
callback=lambda i: compiler_registry[i](name=i))
6566

6667
# Setup language for shared-memory parallelism
6768
preprocessor = lambda i: {0: 'C', 1: 'openmp'}.get(i, i) # Handles DEVITO_OPENMP deprec

devito/arch/compiler.py

+22-2
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,8 @@ def __init__(self):
178178
_cpp = False
179179

180180
def __init__(self, **kwargs):
181+
self._name = kwargs.pop('name', self.__class__.__name__)
182+
181183
super().__init__(**kwargs)
182184

183185
self.__lookup_cmds__()
@@ -221,13 +223,13 @@ def __new_with__(self, **kwargs):
221223
Create a new Compiler from an existing one, inherenting from it
222224
the flags that are not specified via ``kwargs``.
223225
"""
224-
return self.__class__(suffix=kwargs.pop('suffix', self.suffix),
226+
return self.__class__(name=self.name, suffix=kwargs.pop('suffix', self.suffix),
225227
mpi=kwargs.pop('mpi', configuration['mpi']),
226228
**kwargs)
227229

228230
@property
229231
def name(self):
230-
return self.__class__.__name__
232+
return self._name
231233

232234
@property
233235
def version(self):
@@ -243,6 +245,20 @@ def version(self):
243245

244246
return version
245247

248+
@property
249+
def _complex_ctype(self):
250+
"""
251+
Type definition for complex numbers. THese two cases cover 99% of the cases since
252+
- Hip is now using std::complex
253+
https://rocm.docs.amd.com/en/docs-5.1.3/CHANGELOG.html#hip-api-deprecations-and-warnings
254+
- Sycl supports std::complex
255+
- C's _Complex is part of C99
256+
"""
257+
if self._cpp:
258+
return lambda dtype: 'std::complex<%s>' % str(dtype)
259+
else:
260+
return lambda dtype: '%s _Complex' % str(dtype)
261+
246262
def get_version(self):
247263
result, stdout, stderr = call_capture_output((self.cc, "--version"))
248264
if result != 0:
@@ -697,6 +713,10 @@ def __lookup_cmds__(self):
697713
self.MPICC = 'nvcc'
698714
self.MPICXX = 'nvcc'
699715

716+
@property
717+
def _complex_ctype(self):
718+
return lambda dtype: 'thrust::complex<%s>' % str(dtype)
719+
700720

701721
class HipCompiler(Compiler):
702722

devito/ir/iet/visitors.py

+14-20
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,10 @@
1010
import ctypes
1111

1212
import cgen as c
13-
import numpy as np
1413
from sympy import IndexedBase
1514
from sympy.core.function import Application
1615

17-
from devito.parameters import configuration
16+
from devito.parameters import configuration, switchconfig
1817
from devito.exceptions import VisitorException
1918
from devito.ir.iet.nodes import (Node, Iteration, Expression, ExpressionBundle,
2019
Call, Lambda, BlankLine, Section, ListMajor)
@@ -190,20 +189,15 @@ def __init__(self, *args, compiler=None, **kwargs):
190189
}
191190
_restrict_keyword = 'restrict'
192191

193-
def _complex_type(self, ctypestr, dtype):
194-
# Not complex
195-
try:
196-
if not np.issubdtype(dtype, np.complexfloating):
197-
return ctypestr
198-
except TypeError:
199-
return ctypestr
200-
# Complex only supported for float and double
201-
if ctypestr not in ('float', 'double'):
202-
return ctypestr
203-
if self._compiler._cpp:
204-
return 'std::complex<%s>' % ctypestr
205-
else:
206-
return '%s _Complex' % ctypestr
192+
@property
193+
def compiler(self):
194+
return self._compiler
195+
196+
def visit(self, o, *args, **kwargs):
197+
# Make sure the visitor always is within the generating compiler
198+
# in case the configuration is accessed
199+
with switchconfig(compiler=self.compiler.name):
200+
return super().visit(o, *args, **kwargs)
207201

208202
def _gen_struct_decl(self, obj, masked=()):
209203
"""
@@ -260,10 +254,10 @@ def _gen_value(self, obj, mode=1, masked=()):
260254
if getattr(obj.function, k, False) and v not in masked]
261255

262256
if (obj._mem_stack or obj._mem_constant) and mode == 1:
263-
strtype = self._complex_type(obj._C_typedata, obj.dtype)
257+
strtype = obj._C_typedata
264258
strshape = ''.join('[%s]' % ccode(i) for i in obj.symbolic_shape)
265259
else:
266-
strtype = self._complex_type(ctypes_to_cstr(obj._C_ctype), obj.dtype)
260+
strtype = ctypes_to_cstr(obj._C_ctype)
267261
strshape = ''
268262
if isinstance(obj, (AbstractFunction, IndexedData)) and mode >= 1:
269263
if not obj._mem_stack:
@@ -393,7 +387,7 @@ def visit_tuple(self, o):
393387
def visit_PointerCast(self, o):
394388
f = o.function
395389
i = f.indexed
396-
cstr = self._complex_type(i._C_typedata, i.dtype)
390+
cstr = i._C_typedata
397391

398392
if f.is_PointerArray:
399393
# lvalue
@@ -448,7 +442,7 @@ def visit_Dereference(self, o):
448442
a0, a1 = o.functions
449443
if a1.is_PointerArray or a1.is_TempFunction:
450444
i = a1.indexed
451-
cstr = self._complex_type(i._C_typedata, i.dtype)
445+
cstr = i._C_typedata
452446
if o.flat is None:
453447
shape = ''.join("[%s]" % ccode(i) for i in a0.symbolic_shape[1:])
454448
rvalue = '(%s (*)%s) %s[%s]' % (cstr, shape, a1.name,

devito/operator/operator.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1220,7 +1220,8 @@ def parse_kwargs(**kwargs):
12201220
raise InvalidOperator("Illegal `compiler=%s`" % str(compiler))
12211221
kwargs['compiler'] = compiler_registry[compiler](platform=kwargs['platform'],
12221222
language=kwargs['language'],
1223-
mpi=configuration['mpi'])
1223+
mpi=configuration['mpi'],
1224+
name=compiler)
12241225
elif any([platform, language]):
12251226
kwargs['compiler'] =\
12261227
configuration['compiler'].__new_with__(platform=kwargs['platform'],

devito/passes/iet/misc.py

+26-11
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from devito.passes.iet.engine import iet_pass
1212
from devito.ir.iet.efunc import DeviceFunction, EntryFunction
1313
from devito.symbolics import ValueLimit, evalrel, has_integer_args, limits_mapper
14-
from devito.tools import as_mapper, filter_ordered, split
14+
from devito.tools import as_mapper, filter_ordered, split, dtype_to_cstr
1515

1616
__all__ = ['avoid_denormals', 'hoist_prodders', 'relax_incr_dimensions',
1717
'generate_macros', 'minimize_symbols', 'complex_include']
@@ -192,22 +192,28 @@ def minimize_symbols(iet):
192192
return iet, {}
193193

194194

195-
_complex_lib = {'cuda': 'cuComplex.h', 'hip': 'hip/hip_complex.h'}
195+
_complex_lib = {'cuda': 'thrust/complex.h', 'hip': 'hip/hip_complex.h'}
196196

197197

198198
@iet_pass
199199
def complex_include(iet, language, compiler):
200200
"""
201201
Add headers for complex arithmetic
202202
"""
203+
# Check if there is complex numbers that always take dtype precedence
204+
max_dtype = np.result_type(*[f.dtype for f in FindSymbols().visit(iet)])
205+
if not np.issubdtype(max_dtype, np.complexfloating):
206+
return iet, {}
207+
203208
lib = (_complex_lib.get(language, 'complex' if compiler._cpp else 'complex.h'),)
204209

205210
headers = {}
206211

207212
# For openacc (cpp) need to define constant _Complex_I that isn't found otherwise
208213
if compiler._cpp:
214+
c_str = dtype_to_cstr(max_dtype.type(0).real.dtype.type)
209215
# Constant I
210-
headers = {('_Complex_I', ('std::complex<float>(0.0f, 1.0f)'))}
216+
headers = {('_Complex_I', ('std::complex<%s>(0.0, 1.0)' % c_str))}
211217
# Mix arithmetic definitions
212218
dest = compiler.get_jit_dir()
213219
hfile = dest.joinpath('stdcomplex_arith.h')
@@ -216,14 +222,7 @@ def complex_include(iet, language, compiler):
216222
ff.write(str(_stdcomplex_defs))
217223
lib += (str(hfile),)
218224

219-
for f in FindSymbols().visit(iet):
220-
try:
221-
if np.issubdtype(f.dtype, np.complexfloating):
222-
return iet, {'includes': lib, 'headers': headers}
223-
except TypeError:
224-
pass
225-
226-
return iet, {}
225+
return iet, {'includes': lib, 'headers': headers}
227226

228227

229228
def remove_redundant_moddims(iet):
@@ -314,13 +313,29 @@ def _rename_subdims(target, dimensions):
314313
return std::complex<_Tp>(b.real() * a, b.imag() * a);
315314
}
316315
316+
template<typename _Tp, typename _Ti>
317+
std::complex<_Tp> operator * (const std::complex<_Tp> & b, const _Ti & a){
318+
return std::complex<_Tp>(b.real() * a, b.imag() * a);
319+
}
320+
317321
template<typename _Tp, typename _Ti>
318322
std::complex<_Tp> operator / (const _Ti & a, const std::complex<_Tp> & b){
323+
_Tp denom = b.real() * b.real () + b.imag() * b.imag()
324+
return std::complex<_Tp>(b.real() * a / denom, - b.imag() * a / denom);
325+
}
326+
327+
template<typename _Tp, typename _Ti>
328+
std::complex<_Tp> operator / (const std::complex<_Tp> & b, const _Ti & a){
319329
return std::complex<_Tp>(b.real() / a, b.imag() / a);
320330
}
321331
322332
template<typename _Tp, typename _Ti>
323333
std::complex<_Tp> operator + (const _Ti & a, const std::complex<_Tp> & b){
324334
return std::complex<_Tp>(b.real() + a, b.imag());
325335
}
336+
337+
template<typename _Tp, typename _Ti>
338+
std::complex<_Tp> operator + (const std::complex<_Tp> & b, const _Ti & a){
339+
return std::complex<_Tp>(b.real() + a, b.imag());
340+
}
326341
"""

devito/symbolics/inspection.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -304,10 +304,12 @@ def sympy_dtype(expr, base=None):
304304
dtypes.add(i.dtype)
305305
except AttributeError:
306306
pass
307+
307308
dtype = infer_dtype(dtypes)
308309

309310
# Promote if complex
310-
if expr.has(ImaginaryUnit):
311+
is_im = np.issubdtype(dtype, np.complexfloating)
312+
if expr.has(ImaginaryUnit) and not is_im:
311313
dtype = np.promote_types(dtype, np.complex64).type
312314

313315
return dtype

devito/tools/dtypes_lowering.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,12 @@ def dtype_to_ctype(dtype):
139139
# Complex data
140140
if np.issubdtype(dtype, np.complexfloating):
141141
rtype = dtype(0).real.__class__
142-
return dtype_to_ctype(rtype)
142+
from devito import configuration
143+
make = configuration['compiler']._complex_ctype
144+
ctname = make(dtype_to_cstr(rtype))
145+
ctype = dtype_to_ctype(rtype)
146+
r = type(ctname, (ctype,), {})
147+
return r
143148

144149
try:
145150
return ctypes_vector_mapper[dtype]
@@ -308,7 +313,8 @@ def infer_dtype(dtypes):
308313
# Resolve the vector types, if any
309314
dtypes = {dtypes_vector_mapper.get_base_dtype(i, i) for i in dtypes}
310315

311-
fdtypes = {i for i in dtypes if np.issubdtype(i, np.floating)}
316+
fdtypes = {i for i in dtypes if np.issubdtype(i, np.floating) or
317+
np.issubdtype(i, np.complexfloating)}
312318
if len(fdtypes) > 1:
313319
return max(fdtypes, key=lambda i: np.dtype(i).itemsize)
314320
elif len(fdtypes) == 1:

tests/test_gpu_common.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -67,15 +67,16 @@ def test_maxpar_option(self):
6767
assert trees[0][0] is trees[1][0]
6868
assert trees[0][1] is not trees[1][1]
6969

70-
def test_complex(self):
70+
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
71+
def test_complex(self, dtype):
7172
grid = Grid((5, 5))
7273
x, y = grid.dimensions
73-
# Float32 complex is called complex64 in numpy
74-
u = Function(name="u", grid=grid, dtype=np.complex64)
74+
u = Function(name="u", grid=grid, dtype=dtype)
7575

7676
eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing))
7777
# Currently wrong alias type
7878
op = Operator(eq)
79+
print(op)
7980
op()
8081

8182
# Check against numpy

tests/test_operator.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -640,22 +640,22 @@ def test_tensor(self, func1):
640640
op2 = Operator([Eq(f, f.dx) for f in f1.values()])
641641
assert str(op1.ccode) == str(op2.ccode)
642642

643-
def test_complex(self):
643+
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
644+
def test_complex(self, dtype):
644645
grid = Grid((5, 5))
645646
x, y = grid.dimensions
646-
# Float32 complex is called complex64 in numpy
647-
u = Function(name="u", grid=grid, dtype=np.complex64)
647+
u = Function(name="u", grid=grid, dtype=dtype)
648648

649649
eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing))
650650
# Currently wrong alias type
651651
op = Operator(eq)
652+
# print(op)
652653
op()
653654

654655
# Check against numpy
655656
dx = grid.spacing_map[x.spacing]
656657
xx, yy = np.meshgrid(np.linspace(0, 4, 5), np.linspace(0, 4, 5))
657658
npres = xx + 1j*yy + np.exp(1j + dx)
658-
print(op)
659659

660660
assert np.allclose(u.data, npres.T, rtol=1e-7, atol=0)
661661

0 commit comments

Comments
 (0)