Skip to content

Commit fdf1b36

Browse files
committed
compiler: rework dtype lowering
1 parent 3b0d7e5 commit fdf1b36

18 files changed

+326
-177
lines changed

devito/arch/compiler.py

-18
Original file line numberDiff line numberDiff line change
@@ -245,20 +245,6 @@ def version(self):
245245

246246
return version
247247

248-
@property
249-
def _complex_ctype(self):
250-
"""
251-
Type definition for complex numbers. These two cases cover 99% of the cases since
252-
- Hip is now using std::complex
253-
https://rocm.docs.amd.com/en/docs-5.1.3/CHANGELOG.html#hip-api-deprecations-and-warnings
254-
- Sycl supports std::complex
255-
- C's _Complex is part of C99
256-
"""
257-
if self._cpp:
258-
return lambda dtype: 'std::complex<%s>' % str(dtype)
259-
else:
260-
return lambda dtype: '%s _Complex' % str(dtype)
261-
262248
def get_version(self):
263249
result, stdout, stderr = call_capture_output((self.cc, "--version"))
264250
if result != 0:
@@ -713,10 +699,6 @@ def __lookup_cmds__(self):
713699
self.MPICC = 'nvcc'
714700
self.MPICXX = 'nvcc'
715701

716-
@property
717-
def _complex_ctype(self):
718-
return lambda dtype: 'thrust::complex<%s>' % str(dtype)
719-
720702

721703
class HipCompiler(Compiler):
722704

devito/core/gpu.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ def _make_iet_passes_mapper(cls, **kwargs):
318318
'blocking', 'tasking', 'streaming', 'factorize', 'fission', 'fuse', 'lift',
319319
'cire-sops', 'cse', 'opt-pows', 'topofuse',
320320
# IET
321-
'orchestrate', 'pthreadify', 'parallel', 'mpi', 'linearize', 'prodders'
321+
'orchestrate', 'pthreadify', 'parallel', 'mpi', 'linearize', 'prodders', 'dtypes'
322322
)
323323
_known_passes_disabled = ('denormals', 'simd')
324324
assert not (set(_known_passes) & set(_known_passes_disabled))

devito/operator/operator.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from devito.parameters import configuration
2323
from devito.passes import (Graph, lower_index_derivatives, generate_implicit,
2424
generate_macros, minimize_symbols, unevaluate,
25-
error_mapper, include_complex)
25+
error_mapper)
2626
from devito.symbolics import estimate_cost
2727
from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_tuple, flatten,
2828
filter_sorted, frozendict, is_integer, split, timed_pass,
@@ -466,10 +466,6 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs):
466466
# Lower IET to a target-specific IET
467467
graph = Graph(iet, **kwargs)
468468

469-
# Complex header if needed. Needs to be done before specialization
470-
# as some specific cases require complex to be loaded first
471-
include_complex(graph, language=kwargs['language'], compiler=kwargs['compiler'])
472-
473469
# Specialize
474470
graph = cls._specialize_iet(graph, **kwargs)
475471

devito/passes/iet/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@
88
from .instrument import * # noqa
99
from .languages import * # noqa
1010
from .errors import * # noqa
11-
from .complex import * # noqa
11+
from .dtypes import * # noqa

devito/passes/iet/definitions.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from devito.ir import (Block, Call, Definition, DummyExpr, Return, EntryFunction,
1313
FindSymbols, MapExprStmts, Transformer, make_callable)
1414
from devito.passes import is_gpu_create
15+
from devito.passes.iet.dtypes import lower_complex
1516
from devito.passes.iet.engine import iet_pass
1617
from devito.passes.iet.langbase import LangBB
1718
from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer,
@@ -73,10 +74,12 @@ class DataManager:
7374
The language used to express data allocations, deletions, and host-device transfers.
7475
"""
7576

76-
def __init__(self, rcompile=None, sregistry=None, platform=None, **kwargs):
77+
def __init__(self, rcompile=None, sregistry=None, platform=None,
78+
compiler=None, **kwargs):
7779
self.rcompile = rcompile
7880
self.sregistry = sregistry
7981
self.platform = platform
82+
self.compiler = compiler
8083

8184
def _alloc_object_on_low_lat_mem(self, site, obj, storage):
8285
"""
@@ -409,12 +412,18 @@ def place_casts(self, iet, **kwargs):
409412

410413
return iet, {}
411414

415+
@iet_pass
416+
def make_langtypes(self, iet):
417+
iet, metadata = lower_complex(iet, self.lang, self.compiler)
418+
return iet, metadata
419+
412420
def process(self, graph):
413421
"""
414422
Apply the `place_definitions` and `place_casts` passes.
415423
"""
416424
self.place_definitions(graph, globs=set())
417425
self.place_casts(graph)
426+
self.make_langtypes(graph)
418427

419428

420429
class DeviceAwareDataManager(DataManager):
@@ -564,6 +573,7 @@ def process(self, graph):
564573
self.place_devptr(graph)
565574
self.place_bundling(graph, writes_input=graph.writes_input)
566575
self.place_casts(graph)
576+
self.make_langtypes(graph)
567577

568578

569579
def make_zero_init(obj):

devito/passes/iet/dtypes.py

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import numpy as np
2+
3+
from devito.ir import FindSymbols, Uxreplace
4+
5+
__all__ = ['lower_complex']
6+
7+
8+
def lower_complex(iet, lang, compiler):
9+
"""
10+
Add headers for complex arithmetic
11+
"""
12+
# Check if there is complex numbers that always take dtype precedence
13+
max_dtype = np.result_type(*[f.dtype for f in FindSymbols().visit(iet)])
14+
if not np.issubdtype(max_dtype, np.complexfloating):
15+
return iet, {}
16+
17+
lib = (lang['header-complex'],)
18+
headers = lang.get('I-def')
19+
20+
# Some languges such as c++11 need some extra arithmetic definitions
21+
if lang.get('def-complex'):
22+
dest = compiler.get_jit_dir()
23+
hfile = dest.joinpath('complex_arith.h')
24+
with open(str(hfile), 'w') as ff:
25+
ff.write(str(lang['def-complex']))
26+
lib += (str(hfile),)
27+
28+
iet = _complex_dtypes(iet, lang)
29+
30+
return iet, {'includes': lib, 'headers': headers}
31+
32+
33+
def _complex_dtypes(iet, lang):
34+
"""
35+
Lower dtypes to language specific types
36+
"""
37+
mapper = {}
38+
39+
for s in FindSymbols('indexeds').visit(iet):
40+
if s.dtype in lang['types']:
41+
mapper[s] = s._rebuild(dtype=lang['types'][s.dtype])
42+
43+
for s in FindSymbols().visit(iet):
44+
if s.dtype in lang['types']:
45+
mapper[s] = s._rebuild(dtype=lang['types'][s.dtype])
46+
47+
body = Uxreplace(mapper).visit(iet.body)
48+
params = Uxreplace(mapper).visit(iet.parameters)
49+
iet = iet._rebuild(body=body, parameters=params)
50+
51+
return iet

devito/passes/iet/langbase.py

+11
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ def __getitem__(self, k):
3131
raise NotImplementedError("Missing required mapping for `%s`" % k)
3232
return self.mapper[k]
3333

34+
def get(self, k):
35+
return self.mapper.get(k)
36+
3437

3538
class LangBB(metaclass=LangMeta):
3639

@@ -200,6 +203,14 @@ def initialize(self, iet, options=None):
200203
"""
201204
return iet, {}
202205

206+
@iet_pass
207+
def make_langtypes(self, iet):
208+
"""
209+
An `iet_pass` which transforms an IET such that the target language
210+
types are introduced.
211+
"""
212+
return iet, {}
213+
203214
@property
204215
def Region(self):
205216
return self.lang.Region

devito/passes/iet/languages/C.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
1+
import numpy as np
2+
13
from devito.ir import Call
24
from devito.passes.iet.definitions import DataManager
35
from devito.passes.iet.orchestration import Orchestrator
46
from devito.passes.iet.langbase import LangBB
7+
from devito.tools import CustomNpType
58

69
__all__ = ['CBB', 'CDataManager', 'COrchestrator']
710

811

12+
CCFloat = CustomNpType('_Complex float', np.complex64)
13+
CCDouble = CustomNpType('_Complex double', np.complex128)
14+
15+
916
class CBB(LangBB):
1017

1118
mapper = {
@@ -19,7 +26,11 @@ class CBB(LangBB):
1926
'host-free-pin': lambda i:
2027
Call('free', (i,)),
2128
'alloc-global-symbol': lambda i, j, k:
22-
Call('memcpy', (i, j, k))
29+
Call('memcpy', (i, j, k)),
30+
# Complex
31+
'header-complex': 'complex.h',
32+
'types': {np.complex128: CCDouble, np.complex64: CCFloat},
33+
'I-def': None
2334
}
2435

2536

devito/passes/iet/languages/CXX.py

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import numpy as np
2+
3+
from devito.ir import Call
4+
from devito.passes.iet.langbase import LangBB
5+
from devito.tools import CustomNpType
6+
7+
__all__ = ['CXXBB']
8+
9+
10+
std_arith = """
11+
#include <complex>
12+
13+
template<typename _Tp, typename _Ti>
14+
std::complex<_Tp> operator * (const _Ti & a, const std::complex<_Tp> & b){
15+
return std::complex<_Tp>(b.real() * a, b.imag() * a);
16+
}
17+
18+
template<typename _Tp, typename _Ti>
19+
std::complex<_Tp> operator * (const std::complex<_Tp> & b, const _Ti & a){
20+
return std::complex<_Tp>(b.real() * a, b.imag() * a);
21+
}
22+
23+
template<typename _Tp, typename _Ti>
24+
std::complex<_Tp> operator / (const _Ti & a, const std::complex<_Tp> & b){
25+
_Tp denom = b.real() * b.real () + b.imag() * b.imag()
26+
return std::complex<_Tp>(b.real() * a / denom, - b.imag() * a / denom);
27+
}
28+
29+
template<typename _Tp, typename _Ti>
30+
std::complex<_Tp> operator / (const std::complex<_Tp> & b, const _Ti & a){
31+
return std::complex<_Tp>(b.real() / a, b.imag() / a);
32+
}
33+
34+
template<typename _Tp, typename _Ti>
35+
std::complex<_Tp> operator + (const _Ti & a, const std::complex<_Tp> & b){
36+
return std::complex<_Tp>(b.real() + a, b.imag());
37+
}
38+
39+
template<typename _Tp, typename _Ti>
40+
std::complex<_Tp> operator + (const std::complex<_Tp> & b, const _Ti & a){
41+
return std::complex<_Tp>(b.real() + a, b.imag());
42+
}
43+
"""
44+
45+
CXXCFloat = CustomNpType('std::complex', np.complex64, template='float')
46+
CXXCDouble = CustomNpType('std::complex', np.complex128, template='double')
47+
48+
49+
class CXXBB(LangBB):
50+
51+
mapper = {
52+
'header-memcpy': 'string.h',
53+
'host-alloc': lambda i, j, k:
54+
Call('posix_memalign', (i, j, k)),
55+
'host-alloc-pin': lambda i, j, k:
56+
Call('posix_memalign', (i, j, k)),
57+
'host-free': lambda i:
58+
Call('free', (i,)),
59+
'host-free-pin': lambda i:
60+
Call('free', (i,)),
61+
'alloc-global-symbol': lambda i, j, k:
62+
Call('memcpy', (i, j, k)),
63+
# Complex
64+
'header-complex': '<complex>',
65+
'I-def': (('_Complex_I', ('std::complex<float>(0.0, 1.0)')),),
66+
'def-complex': std_arith,
67+
'types': {np.complex128: CXXCDouble, np.complex64: CXXCFloat},
68+
}

devito/passes/iet/languages/openacc.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from devito.passes.iet.orchestration import Orchestrator
1111
from devito.passes.iet.parpragma import (PragmaDeviceAwareTransformer, PragmaLangBB,
1212
PragmaIteration, PragmaTransfer)
13-
from devito.passes.iet.languages.C import CBB
13+
from devito.passes.iet.languages.CXX import CXXBB
1414
from devito.passes.iet.languages.openmp import OmpRegion, OmpIteration
1515
from devito.symbolics import FieldFromPointer, Macro, cast_mapper
1616
from devito.tools import filter_ordered, UnboundTuple
@@ -118,7 +118,8 @@ class AccBB(PragmaLangBB):
118118
'device-free': lambda i, *a:
119119
Call('acc_free', (i,))
120120
}
121-
mapper.update(CBB.mapper)
121+
122+
mapper.update(CXXBB.mapper)
122123

123124
Region = OmpRegion
124125
HostIteration = OmpIteration # Host parallelism still goes via OpenMP

devito/passes/iet/misc.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from devito.ir.iet.efunc import DeviceFunction, EntryFunction
1313
from devito.symbolics import (ValueLimit, evalrel, has_integer_args, limits_mapper,
1414
ccode)
15-
from devito.tools import Bunch, as_mapper, filter_ordered, split, dtype_to_cstr
15+
from devito.tools import Bunch, as_mapper, filter_ordered, split
1616
from devito.types import FIndexed
1717

1818
__all__ = ['avoid_denormals', 'hoist_prodders', 'relax_incr_dimensions',

devito/symbolics/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from devito.symbolics.extended_sympy import * # noqa
2+
from devito.symbolics.extended_dtypes import * # noqa
23
from devito.symbolics.queries import * # noqa
34
from devito.symbolics.search import * # noqa
45
from devito.symbolics.printer import * # noqa

0 commit comments

Comments
 (0)