Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 3b0d7e5

Browse files
committedJun 21, 2024
compiler: fix internal language specific types and cast
wip
1 parent abf483c commit 3b0d7e5

File tree

7 files changed

+34
-78
lines changed

7 files changed

+34
-78
lines changed
 

‎devito/arch/compiler.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ def version(self):
248248
@property
249249
def _complex_ctype(self):
250250
"""
251-
Type definition for complex numbers. THese two cases cover 99% of the cases since
251+
Type definition for complex numbers. These two cases cover 99% of the cases since
252252
- Hip is now using std::complex
253253
https://rocm.docs.amd.com/en/docs-5.1.3/CHANGELOG.html#hip-api-deprecations-and-warnings
254254
- Sycl supports std::complex
@@ -996,6 +996,7 @@ def __new_with__(self, **kwargs):
996996
'nvc++': NvidiaCompiler,
997997
'nvidia': NvidiaCompiler,
998998
'cuda': CudaCompiler,
999+
'nvcc': CudaCompiler,
9991000
'osx': ClangCompiler,
10001001
'intel': OneapiCompiler,
10011002
'icx': OneapiCompiler,

‎devito/operator/operator.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from devito.parameters import configuration
2323
from devito.passes import (Graph, lower_index_derivatives, generate_implicit,
2424
generate_macros, minimize_symbols, unevaluate,
25-
error_mapper, complex_include)
25+
error_mapper, include_complex)
2626
from devito.symbolics import estimate_cost
2727
from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_tuple, flatten,
2828
filter_sorted, frozendict, is_integer, split, timed_pass,
@@ -468,7 +468,7 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs):
468468

469469
# Complex header if needed. Needs to be done before specialization
470470
# as some specific cases require complex to be loaded first
471-
complex_include(graph, language=kwargs['language'], compiler=kwargs['compiler'])
471+
include_complex(graph, language=kwargs['language'], compiler=kwargs['compiler'])
472472

473473
# Specialize
474474
graph = cls._specialize_iet(graph, **kwargs)

‎devito/passes/iet/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@
88
from .instrument import * # noqa
99
from .languages import * # noqa
1010
from .errors import * # noqa
11+
from .complex import * # noqa

‎devito/passes/iet/misc.py

+1-70
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from devito.types import FIndexed
1717

1818
__all__ = ['avoid_denormals', 'hoist_prodders', 'relax_incr_dimensions',
19-
'generate_macros', 'minimize_symbols', 'complex_include']
19+
'generate_macros', 'minimize_symbols']
2020

2121

2222
@iet_pass
@@ -240,39 +240,6 @@ def minimize_symbols(iet):
240240
return iet, {}
241241

242242

243-
_complex_lib = {'cuda': 'thrust/complex.h'}
244-
245-
246-
@iet_pass
247-
def complex_include(iet, language, compiler):
248-
"""
249-
Add headers for complex arithmetic
250-
"""
251-
# Check if there is complex numbers that always take dtype precedence
252-
max_dtype = np.result_type(*[f.dtype for f in FindSymbols().visit(iet)])
253-
if not np.issubdtype(max_dtype, np.complexfloating):
254-
return iet, {}
255-
256-
lib = (_complex_lib.get(language, 'complex' if compiler._cpp else 'complex.h'),)
257-
258-
headers = {}
259-
260-
# For openacc (cpp) need to define constant _Complex_I that isn't found otherwise
261-
if compiler._cpp:
262-
c_str = dtype_to_cstr(max_dtype.type(0).real.dtype.type)
263-
# Constant I
264-
headers = {('_Complex_I', ('std::complex<%s>(0.0, 1.0)' % c_str))}
265-
# Mix arithmetic definitions
266-
dest = compiler.get_jit_dir()
267-
hfile = dest.joinpath('stdcomplex_arith.h')
268-
if not hfile.is_file():
269-
with open(str(hfile), 'w') as ff:
270-
ff.write(str(_stdcomplex_defs))
271-
lib += (str(hfile),)
272-
273-
return iet, {'includes': lib, 'headers': headers}
274-
275-
276243
def remove_redundant_moddims(iet):
277244
key = lambda d: d.is_Modulo and d.origin is not None
278245
mds = [d for d in FindSymbols('dimensions').visit(iet) if key(d)]
@@ -351,39 +318,3 @@ def _rename_subdims(target, dimensions):
351318
return {d: d._rebuild(d.root.name) for d in dims
352319
if d.root not in dimensions
353320
and names.count(d.root.name) < 2}
354-
355-
356-
_stdcomplex_defs = """
357-
#include <complex>
358-
359-
template<typename _Tp, typename _Ti>
360-
std::complex<_Tp> operator * (const _Ti & a, const std::complex<_Tp> & b){
361-
return std::complex<_Tp>(b.real() * a, b.imag() * a);
362-
}
363-
364-
template<typename _Tp, typename _Ti>
365-
std::complex<_Tp> operator * (const std::complex<_Tp> & b, const _Ti & a){
366-
return std::complex<_Tp>(b.real() * a, b.imag() * a);
367-
}
368-
369-
template<typename _Tp, typename _Ti>
370-
std::complex<_Tp> operator / (const _Ti & a, const std::complex<_Tp> & b){
371-
_Tp denom = b.real() * b.real () + b.imag() * b.imag()
372-
return std::complex<_Tp>(b.real() * a / denom, - b.imag() * a / denom);
373-
}
374-
375-
template<typename _Tp, typename _Ti>
376-
std::complex<_Tp> operator / (const std::complex<_Tp> & b, const _Ti & a){
377-
return std::complex<_Tp>(b.real() / a, b.imag() / a);
378-
}
379-
380-
template<typename _Tp, typename _Ti>
381-
std::complex<_Tp> operator + (const _Ti & a, const std::complex<_Tp> & b){
382-
return std::complex<_Tp>(b.real() + a, b.imag());
383-
}
384-
385-
template<typename _Tp, typename _Ti>
386-
std::complex<_Tp> operator + (const std::complex<_Tp> & b, const _Ti & a){
387-
return std::complex<_Tp>(b.real() + a, b.imag());
388-
}
389-
"""

‎devito/symbolics/extended_sympy.py

+28-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from sympy import Expr, Function, Number, Tuple, sympify
88
from sympy.core.decorators import call_highest_priority
99

10+
from devito import configuration
1011
from devito.finite_differences.elementary import Min, Max
1112
from devito.tools import (Pickable, Bunch, as_tuple, is_integer, float2, # noqa
1213
float3, float4, double2, double3, double4, int2, int3,
@@ -811,6 +812,20 @@ class VOID(Cast):
811812
_base_typ = 'void'
812813

813814

815+
class CFLOAT(Cast):
816+
817+
@property
818+
def _base_typ(self):
819+
return configuration['compiler']._complex_ctype('float')
820+
821+
822+
class CDOUBLE(Cast):
823+
824+
@property
825+
def _base_typ(self):
826+
return configuration['compiler']._complex_ctype('double')
827+
828+
814829
class CHARP(CastStar):
815830
base = CHAR
816831

@@ -827,6 +842,14 @@ class USHORTP(CastStar):
827842
base = USHORT
828843

829844

845+
class CFLOATP(CastStar):
846+
base = CFLOAT
847+
848+
849+
class CDOUBLEP(CastStar):
850+
base = CDOUBLE
851+
852+
830853
cast_mapper = {
831854
np.int8: CHAR,
832855
np.uint8: UCHAR,
@@ -839,6 +862,8 @@ class USHORTP(CastStar):
839862
np.float32: FLOAT, # noqa
840863
float: DOUBLE, # noqa
841864
np.float64: DOUBLE, # noqa
865+
np.complex64: CFLOAT, # noqa
866+
np.complex128: CDOUBLE, # noqa
842867

843868
(np.int8, '*'): CHARP,
844869
(np.uint8, '*'): UCHARP,
@@ -849,7 +874,9 @@ class USHORTP(CastStar):
849874
(np.int64, '*'): INTP, # noqa
850875
(np.float32, '*'): FLOATP, # noqa
851876
(float, '*'): DOUBLEP, # noqa
852-
(np.float64, '*'): DOUBLEP # noqa
877+
(np.float64, '*'): DOUBLEP, # noqa
878+
(np.complex64, '*'): CFLOATP, # noqa
879+
(np.complex128, '*'): CDOUBLEP, # noqa
853880
}
854881

855882
for base_name in ['int', 'float', 'double']:

‎tests/test_gpu_common.py

-2
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,7 @@ def test_complex(self, dtype):
7474
u = Function(name="u", grid=grid, dtype=dtype)
7575

7676
eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing))
77-
# Currently wrong alias type
7877
op = Operator(eq)
79-
print(op)
8078
op()
8179

8280
# Check against numpy

‎tests/test_operator.py

-2
Original file line numberDiff line numberDiff line change
@@ -647,9 +647,7 @@ def test_complex(self, dtype):
647647
u = Function(name="u", grid=grid, dtype=dtype)
648648

649649
eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing))
650-
# Currently wrong alias type
651650
op = Operator(eq)
652-
# print(op)
653651
op()
654652

655653
# Check against numpy

0 commit comments

Comments
 (0)
Please sign in to comment.