Skip to content

Commit 5200adc

Browse files
committed
compiler: fix internal language specific types and cast
1 parent be9e55e commit 5200adc

File tree

5 files changed

+44
-17
lines changed

5 files changed

+44
-17
lines changed

devito/arch/compiler.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ def version(self):
248248
@property
249249
def _complex_ctype(self):
250250
"""
251-
Type definition for complex numbers. THese two cases cover 99% of the cases since
251+
Type definition for complex numbers. These two cases cover 99% of the cases since
252252
- Hip is now using std::complex
253253
https://rocm.docs.amd.com/en/docs-5.1.3/CHANGELOG.html#hip-api-deprecations-and-warnings
254254
- Sycl supports std::complex
@@ -996,6 +996,7 @@ def __new_with__(self, **kwargs):
996996
'nvc++': NvidiaCompiler,
997997
'nvidia': NvidiaCompiler,
998998
'cuda': CudaCompiler,
999+
'nvcc': CudaCompiler,
9991000
'osx': ClangCompiler,
10001001
'intel': OneapiCompiler,
10011002
'icx': OneapiCompiler,

devito/passes/iet/misc.py

+14-11
Original file line numberDiff line numberDiff line change
@@ -192,9 +192,6 @@ def minimize_symbols(iet):
192192
return iet, {}
193193

194194

195-
_complex_lib = {'cuda': 'thrust/complex.h'}
196-
197-
198195
@iet_pass
199196
def complex_include(iet, language, compiler):
200197
"""
@@ -205,22 +202,28 @@ def complex_include(iet, language, compiler):
205202
if not np.issubdtype(max_dtype, np.complexfloating):
206203
return iet, {}
207204

208-
lib = (_complex_lib.get(language, 'complex' if compiler._cpp else 'complex.h'),)
205+
is_cuda = language == 'cuda'
206+
if is_cuda:
207+
lib = ('thrust/complex.h',)
208+
else:
209+
lib = ('complex' if compiler._cpp else 'complex.h',)
209210

210211
headers = {}
211212

212-
# For openacc (cpp) need to define constant _Complex_I that isn't found otherwise
213+
# For (cpp), need to define constant _Complex_I and missing mix-type
214+
# std::complex arithmetic
213215
if compiler._cpp:
216+
namespace = 'thrust' if is_cuda else 'std'
214217
c_str = dtype_to_cstr(max_dtype.type(0).real.dtype.type)
215218
# Constant I
216-
headers = {('_Complex_I', ('std::complex<%s>(0.0, 1.0)' % c_str))}
217-
# Mix arithmetic definitions
218-
dest = compiler.get_jit_dir()
219-
hfile = dest.joinpath('stdcomplex_arith.h')
220-
if not hfile.is_file():
219+
headers = {('_Complex_I', ('%s::complex<%s>(0.0, 1.0)' % (namespace, c_str)))}
220+
# Mix arithmetic definitions, only for std, thrust has it defined
221+
if not is_cuda:
222+
dest = compiler.get_jit_dir()
223+
hfile = dest.joinpath('stdcomplex_arith.h')
221224
with open(str(hfile), 'w') as ff:
222225
ff.write(str(_stdcomplex_defs))
223-
lib += (str(hfile),)
226+
lib += (str(hfile),)
224227

225228
return iet, {'includes': lib, 'headers': headers}
226229

devito/symbolics/extended_sympy.py

+28-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from sympy import Expr, Function, Number, Tuple, sympify
88
from sympy.core.decorators import call_highest_priority
99

10+
from devito import configuration
1011
from devito.finite_differences.elementary import Min, Max
1112
from devito.tools import (Pickable, Bunch, as_tuple, is_integer, float2, # noqa
1213
float3, float4, double2, double3, double4, int2, int3,
@@ -811,6 +812,20 @@ class VOID(Cast):
811812
_base_typ = 'void'
812813

813814

815+
class CFLOAT(Cast):
816+
817+
@property
818+
def _base_typ(self):
819+
return configuration['compiler']._complex_ctype('float')
820+
821+
822+
class CDOUBLE(Cast):
823+
824+
@property
825+
def _base_typ(self):
826+
return configuration['compiler']._complex_ctype('double')
827+
828+
814829
class CHARP(CastStar):
815830
base = CHAR
816831

@@ -827,6 +842,14 @@ class USHORTP(CastStar):
827842
base = USHORT
828843

829844

845+
class CFLOATP(CastStar):
846+
base = CFLOAT
847+
848+
849+
class CDOUBLEP(CastStar):
850+
base = CDOUBLE
851+
852+
830853
cast_mapper = {
831854
np.int8: CHAR,
832855
np.uint8: UCHAR,
@@ -839,6 +862,8 @@ class USHORTP(CastStar):
839862
np.float32: FLOAT, # noqa
840863
float: DOUBLE, # noqa
841864
np.float64: DOUBLE, # noqa
865+
np.complex64: CFLOAT, # noqa
866+
np.complex128: CDOUBLE, # noqa
842867

843868
(np.int8, '*'): CHARP,
844869
(np.uint8, '*'): UCHARP,
@@ -849,7 +874,9 @@ class USHORTP(CastStar):
849874
(np.int64, '*'): INTP, # noqa
850875
(np.float32, '*'): FLOATP, # noqa
851876
(float, '*'): DOUBLEP, # noqa
852-
(np.float64, '*'): DOUBLEP # noqa
877+
(np.float64, '*'): DOUBLEP, # noqa
878+
(np.complex64, '*'): CFLOATP, # noqa
879+
(np.complex128, '*'): CDOUBLEP, # noqa
853880
}
854881

855882
for base_name in ['int', 'float', 'double']:

tests/test_gpu_common.py

-2
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,7 @@ def test_complex(self, dtype):
7474
u = Function(name="u", grid=grid, dtype=dtype)
7575

7676
eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing))
77-
# Currently wrong alias type
7877
op = Operator(eq)
79-
print(op)
8078
op()
8179

8280
# Check against numpy

tests/test_operator.py

-2
Original file line numberDiff line numberDiff line change
@@ -647,9 +647,7 @@ def test_complex(self, dtype):
647647
u = Function(name="u", grid=grid, dtype=dtype)
648648

649649
eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing))
650-
# Currently wrong alias type
651650
op = Operator(eq)
652-
# print(op)
653651
op()
654652

655653
# Check against numpy

0 commit comments

Comments
 (0)