Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 3f0b8e2

Browse files
committedJun 21, 2024
compiler: rework dtype lowering
1 parent 1ca3e46 commit 3f0b8e2

18 files changed

+326
-177
lines changed
 

‎devito/arch/compiler.py

-18
Original file line numberDiff line numberDiff line change
@@ -245,20 +245,6 @@ def version(self):
245245

246246
return version
247247

248-
@property
249-
def _complex_ctype(self):
250-
"""
251-
Type definition for complex numbers. These two cases cover 99% of the cases since
252-
- Hip is now using std::complex
253-
https://rocm.docs.amd.com/en/docs-5.1.3/CHANGELOG.html#hip-api-deprecations-and-warnings
254-
- Sycl supports std::complex
255-
- C's _Complex is part of C99
256-
"""
257-
if self._cpp:
258-
return lambda dtype: 'std::complex<%s>' % str(dtype)
259-
else:
260-
return lambda dtype: '%s _Complex' % str(dtype)
261-
262248
def get_version(self):
263249
result, stdout, stderr = call_capture_output((self.cc, "--version"))
264250
if result != 0:
@@ -713,10 +699,6 @@ def __lookup_cmds__(self):
713699
self.MPICC = 'nvcc'
714700
self.MPICXX = 'nvcc'
715701

716-
@property
717-
def _complex_ctype(self):
718-
return lambda dtype: 'thrust::complex<%s>' % str(dtype)
719-
720702

721703
class HipCompiler(Compiler):
722704

‎devito/core/gpu.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ def _make_iet_passes_mapper(cls, **kwargs):
318318
'blocking', 'tasking', 'streaming', 'factorize', 'fission', 'fuse', 'lift',
319319
'cire-sops', 'cse', 'opt-pows', 'topofuse',
320320
# IET
321-
'orchestrate', 'pthreadify', 'parallel', 'mpi', 'linearize', 'prodders'
321+
'orchestrate', 'pthreadify', 'parallel', 'mpi', 'linearize', 'prodders', 'dtypes'
322322
)
323323
_known_passes_disabled = ('denormals', 'simd')
324324
assert not (set(_known_passes) & set(_known_passes_disabled))

‎devito/operator/operator.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from devito.parameters import configuration
2323
from devito.passes import (Graph, lower_index_derivatives, generate_implicit,
2424
generate_macros, minimize_symbols, unevaluate,
25-
error_mapper, include_complex)
25+
error_mapper)
2626
from devito.symbolics import estimate_cost
2727
from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_tuple, flatten,
2828
filter_sorted, frozendict, is_integer, split, timed_pass,
@@ -466,10 +466,6 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs):
466466
# Lower IET to a target-specific IET
467467
graph = Graph(iet, **kwargs)
468468

469-
# Complex header if needed. Needs to be done before specialization
470-
# as some specific cases require complex to be loaded first
471-
include_complex(graph, language=kwargs['language'], compiler=kwargs['compiler'])
472-
473469
# Specialize
474470
graph = cls._specialize_iet(graph, **kwargs)
475471

‎devito/passes/iet/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@
88
from .instrument import * # noqa
99
from .languages import * # noqa
1010
from .errors import * # noqa
11-
from .complex import * # noqa
11+
from .dtypes import * # noqa

‎devito/passes/iet/definitions.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from devito.ir import (Block, Call, Definition, DummyExpr, Return, EntryFunction,
1313
FindSymbols, MapExprStmts, Transformer, make_callable)
1414
from devito.passes import is_gpu_create
15+
from devito.passes.iet.dtypes import lower_complex
1516
from devito.passes.iet.engine import iet_pass
1617
from devito.passes.iet.langbase import LangBB
1718
from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer,
@@ -73,10 +74,12 @@ class DataManager:
7374
The language used to express data allocations, deletions, and host-device transfers.
7475
"""
7576

76-
def __init__(self, rcompile=None, sregistry=None, platform=None, **kwargs):
77+
def __init__(self, rcompile=None, sregistry=None, platform=None,
78+
compiler=None, **kwargs):
7779
self.rcompile = rcompile
7880
self.sregistry = sregistry
7981
self.platform = platform
82+
self.compiler = compiler
8083

8184
def _alloc_object_on_low_lat_mem(self, site, obj, storage):
8285
"""
@@ -409,12 +412,18 @@ def place_casts(self, iet, **kwargs):
409412

410413
return iet, {}
411414

415+
@iet_pass
416+
def make_langtypes(self, iet):
417+
iet, metadata = lower_complex(iet, self.lang, self.compiler)
418+
return iet, metadata
419+
412420
def process(self, graph):
413421
"""
414422
Apply the `place_definitions` and `place_casts` passes.
415423
"""
416424
self.place_definitions(graph, globs=set())
417425
self.place_casts(graph)
426+
self.make_langtypes(graph)
418427

419428

420429
class DeviceAwareDataManager(DataManager):
@@ -564,6 +573,7 @@ def process(self, graph):
564573
self.place_devptr(graph)
565574
self.place_bundling(graph, writes_input=graph.writes_input)
566575
self.place_casts(graph)
576+
self.make_langtypes(graph)
567577

568578

569579
def make_zero_init(obj):

‎devito/passes/iet/dtypes.py

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import numpy as np
2+
3+
from devito.ir import FindSymbols, Uxreplace
4+
5+
__all__ = ['lower_complex']
6+
7+
8+
def lower_complex(iet, lang, compiler):
9+
"""
10+
Add headers for complex arithmetic
11+
"""
12+
# Check if there is complex numbers that always take dtype precedence
13+
max_dtype = np.result_type(*[f.dtype for f in FindSymbols().visit(iet)])
14+
if not np.issubdtype(max_dtype, np.complexfloating):
15+
return iet, {}
16+
17+
lib = (lang['header-complex'],)
18+
headers = lang.get('I-def')
19+
20+
# Some languges such as c++11 need some extra arithmetic definitions
21+
if lang.get('def-complex'):
22+
dest = compiler.get_jit_dir()
23+
hfile = dest.joinpath('complex_arith.h')
24+
with open(str(hfile), 'w') as ff:
25+
ff.write(str(lang['def-complex']))
26+
lib += (str(hfile),)
27+
28+
iet = _complex_dtypes(iet, lang)
29+
30+
return iet, {'includes': lib, 'headers': headers}
31+
32+
33+
def _complex_dtypes(iet, lang):
34+
"""
35+
Lower dtypes to language specific types
36+
"""
37+
mapper = {}
38+
39+
for s in FindSymbols('indexeds').visit(iet):
40+
if s.dtype in lang['types']:
41+
mapper[s] = s._rebuild(dtype=lang['types'][s.dtype])
42+
43+
for s in FindSymbols().visit(iet):
44+
if s.dtype in lang['types']:
45+
mapper[s] = s._rebuild(dtype=lang['types'][s.dtype])
46+
47+
body = Uxreplace(mapper).visit(iet.body)
48+
params = Uxreplace(mapper).visit(iet.parameters)
49+
iet = iet._rebuild(body=body, parameters=params)
50+
51+
return iet

‎devito/passes/iet/langbase.py

+11
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ def __getitem__(self, k):
3131
raise NotImplementedError("Missing required mapping for `%s`" % k)
3232
return self.mapper[k]
3333

34+
def get(self, k):
35+
return self.mapper.get(k)
36+
3437

3538
class LangBB(metaclass=LangMeta):
3639

@@ -200,6 +203,14 @@ def initialize(self, iet, options=None):
200203
"""
201204
return iet, {}
202205

206+
@iet_pass
207+
def make_langtypes(self, iet):
208+
"""
209+
An `iet_pass` which transforms an IET such that the target language
210+
types are introduced.
211+
"""
212+
return iet, {}
213+
203214
@property
204215
def Region(self):
205216
return self.lang.Region

‎devito/passes/iet/languages/C.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
1+
import numpy as np
2+
13
from devito.ir import Call
24
from devito.passes.iet.definitions import DataManager
35
from devito.passes.iet.orchestration import Orchestrator
46
from devito.passes.iet.langbase import LangBB
7+
from devito.tools import CustomNpType
58

69
__all__ = ['CBB', 'CDataManager', 'COrchestrator']
710

811

12+
CCFloat = CustomNpType('_Complex float', np.complex64)
13+
CCDouble = CustomNpType('_Complex double', np.complex128)
14+
15+
916
class CBB(LangBB):
1017

1118
mapper = {
@@ -19,7 +26,11 @@ class CBB(LangBB):
1926
'host-free-pin': lambda i:
2027
Call('free', (i,)),
2128
'alloc-global-symbol': lambda i, j, k:
22-
Call('memcpy', (i, j, k))
29+
Call('memcpy', (i, j, k)),
30+
# Complex
31+
'header-complex': 'complex.h',
32+
'types': {np.complex128: CCDouble, np.complex64: CCFloat},
33+
'I-def': None
2334
}
2435

2536

‎devito/passes/iet/languages/CXX.py

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import numpy as np
2+
3+
from devito.ir import Call
4+
from devito.passes.iet.langbase import LangBB
5+
from devito.tools import CustomNpType
6+
7+
__all__ = ['CXXBB']
8+
9+
10+
std_arith = """
11+
#include <complex>
12+
13+
template<typename _Tp, typename _Ti>
14+
std::complex<_Tp> operator * (const _Ti & a, const std::complex<_Tp> & b){
15+
return std::complex<_Tp>(b.real() * a, b.imag() * a);
16+
}
17+
18+
template<typename _Tp, typename _Ti>
19+
std::complex<_Tp> operator * (const std::complex<_Tp> & b, const _Ti & a){
20+
return std::complex<_Tp>(b.real() * a, b.imag() * a);
21+
}
22+
23+
template<typename _Tp, typename _Ti>
24+
std::complex<_Tp> operator / (const _Ti & a, const std::complex<_Tp> & b){
25+
_Tp denom = b.real() * b.real () + b.imag() * b.imag()
26+
return std::complex<_Tp>(b.real() * a / denom, - b.imag() * a / denom);
27+
}
28+
29+
template<typename _Tp, typename _Ti>
30+
std::complex<_Tp> operator / (const std::complex<_Tp> & b, const _Ti & a){
31+
return std::complex<_Tp>(b.real() / a, b.imag() / a);
32+
}
33+
34+
template<typename _Tp, typename _Ti>
35+
std::complex<_Tp> operator + (const _Ti & a, const std::complex<_Tp> & b){
36+
return std::complex<_Tp>(b.real() + a, b.imag());
37+
}
38+
39+
template<typename _Tp, typename _Ti>
40+
std::complex<_Tp> operator + (const std::complex<_Tp> & b, const _Ti & a){
41+
return std::complex<_Tp>(b.real() + a, b.imag());
42+
}
43+
"""
44+
45+
CXXCFloat = CustomNpType('std::complex', np.complex64, template='float')
46+
CXXCDouble = CustomNpType('std::complex', np.complex128, template='double')
47+
48+
49+
class CXXBB(LangBB):
50+
51+
mapper = {
52+
'header-memcpy': 'string.h',
53+
'host-alloc': lambda i, j, k:
54+
Call('posix_memalign', (i, j, k)),
55+
'host-alloc-pin': lambda i, j, k:
56+
Call('posix_memalign', (i, j, k)),
57+
'host-free': lambda i:
58+
Call('free', (i,)),
59+
'host-free-pin': lambda i:
60+
Call('free', (i,)),
61+
'alloc-global-symbol': lambda i, j, k:
62+
Call('memcpy', (i, j, k)),
63+
# Complex
64+
'header-complex': '<complex>',
65+
'I-def': (('_Complex_I', ('std::complex<float>(0.0, 1.0)')),),
66+
'def-complex': std_arith,
67+
'types': {np.complex128: CXXCDouble, np.complex64: CXXCFloat},
68+
}

‎devito/passes/iet/languages/openacc.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from devito.passes.iet.orchestration import Orchestrator
1111
from devito.passes.iet.parpragma import (PragmaDeviceAwareTransformer, PragmaLangBB,
1212
PragmaIteration, PragmaTransfer)
13-
from devito.passes.iet.languages.C import CBB
13+
from devito.passes.iet.languages.CXX import CXXBB
1414
from devito.passes.iet.languages.openmp import OmpRegion, OmpIteration
1515
from devito.symbolics import FieldFromPointer, Macro, cast_mapper
1616
from devito.tools import filter_ordered, UnboundTuple
@@ -118,7 +118,8 @@ class AccBB(PragmaLangBB):
118118
'device-free': lambda i, *a:
119119
Call('acc_free', (i,))
120120
}
121-
mapper.update(CBB.mapper)
121+
122+
mapper.update(CXXBB.mapper)
122123

123124
Region = OmpRegion
124125
HostIteration = OmpIteration # Host parallelism still goes via OpenMP

‎devito/passes/iet/misc.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from devito.ir.iet.efunc import DeviceFunction, EntryFunction
1313
from devito.symbolics import (ValueLimit, evalrel, has_integer_args, limits_mapper,
1414
ccode)
15-
from devito.tools import Bunch, as_mapper, filter_ordered, split, dtype_to_cstr
15+
from devito.tools import Bunch, as_mapper, filter_ordered, split
1616
from devito.types import FIndexed
1717

1818
__all__ = ['avoid_denormals', 'hoist_prodders', 'relax_incr_dimensions',

‎devito/symbolics/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from devito.symbolics.extended_sympy import * # noqa
2+
from devito.symbolics.extended_dtypes import * # noqa
23
from devito.symbolics.queries import * # noqa
34
from devito.symbolics.search import * # noqa
45
from devito.symbolics.printer import * # noqa

‎devito/symbolics/extended_dtypes.py

+123
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
import numpy as np
2+
3+
from devito.symbolics.extended_sympy import ReservedWord, Cast, CastStar, ValueLimit
4+
from devito.tools import (Bunch, float2, float3, float4, double2, double3, double4, # noqa
5+
int2, int3, int4)
6+
7+
__all__ = ['cast_mapper', 'limits_mapper', 'INT', 'FLOAT', 'DOUBLE', 'VOID'] # noqa
8+
9+
10+
limits_mapper = {
11+
np.int32: Bunch(min=ValueLimit('INT_MIN'), max=ValueLimit('INT_MAX')),
12+
np.int64: Bunch(min=ValueLimit('LONG_MIN'), max=ValueLimit('LONG_MAX')),
13+
np.float32: Bunch(min=-ValueLimit('FLT_MAX'), max=ValueLimit('FLT_MAX')),
14+
np.float64: Bunch(min=-ValueLimit('DBL_MAX'), max=ValueLimit('DBL_MAX')),
15+
}
16+
17+
18+
class CustomType(ReservedWord):
19+
pass
20+
21+
22+
# Dynamically create INT, INT2, .... INTP, INT2P, ... FLOAT, ...
23+
for base_name in ['int', 'float', 'double']:
24+
for i in ['', '2', '3', '4']:
25+
v = '%s%s' % (base_name, i)
26+
cls = type(v.upper(), (Cast,), {'_base_typ': v})
27+
globals()[cls.__name__] = cls
28+
29+
clsp = type('%sP' % v.upper(), (CastStar,), {'base': cls})
30+
globals()[clsp.__name__] = clsp
31+
32+
33+
class CHAR(Cast):
34+
_base_typ = 'char'
35+
36+
37+
class SHORT(Cast):
38+
_base_typ = 'short'
39+
40+
41+
class USHORT(Cast):
42+
_base_typ = 'unsigned short'
43+
44+
45+
class UCHAR(Cast):
46+
_base_typ = 'unsigned char'
47+
48+
49+
class LONG(Cast):
50+
_base_typ = 'long'
51+
52+
53+
class ULONG(Cast):
54+
_base_typ = 'unsigned long'
55+
56+
57+
class CFLOAT(Cast):
58+
_base_typ = 'float'
59+
60+
61+
class CDOUBLE(Cast):
62+
_base_typ = 'double'
63+
64+
65+
class VOID(Cast):
66+
_base_typ = 'void'
67+
68+
69+
class CHARP(CastStar):
70+
base = CHAR
71+
72+
73+
class UCHARP(CastStar):
74+
base = UCHAR
75+
76+
77+
class SHORTP(CastStar):
78+
base = SHORT
79+
80+
81+
class USHORTP(CastStar):
82+
base = USHORT
83+
84+
85+
class CFLOATP(CastStar):
86+
base = CFLOAT
87+
88+
89+
class CDOUBLEP(CastStar):
90+
base = CDOUBLE
91+
92+
93+
cast_mapper = {
94+
np.int8: CHAR,
95+
np.uint8: UCHAR,
96+
np.int16: SHORT, # noqa
97+
np.uint16: USHORT, # noqa
98+
int: INT, # noqa
99+
np.int32: INT, # noqa
100+
np.int64: LONG,
101+
np.uint64: ULONG,
102+
np.float32: FLOAT, # noqa
103+
float: DOUBLE, # noqa
104+
np.float64: DOUBLE, # noqa
105+
106+
(np.int8, '*'): CHARP,
107+
(np.uint8, '*'): UCHARP,
108+
(int, '*'): INTP, # noqa
109+
(np.uint16, '*'): USHORTP, # noqa
110+
(np.int16, '*'): SHORTP, # noqa
111+
(np.int32, '*'): INTP, # noqa
112+
(np.int64, '*'): INTP, # noqa
113+
(np.float32, '*'): FLOATP, # noqa
114+
(float, '*'): DOUBLEP, # noqa
115+
(np.float64, '*'): DOUBLEP, # noqa
116+
}
117+
118+
for base_name in ['int', 'float', 'double']:
119+
for i in [2, 3, 4]:
120+
v = '%s%d' % (base_name, i)
121+
cls = locals()[v]
122+
cast_mapper[cls] = locals()[v.upper()]
123+
cast_mapper[(cls, '*')] = locals()['%sP' % v.upper()]

‎devito/symbolics/extended_sympy.py

+1-125
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from sympy import Expr, Function, Number, Tuple, sympify
88
from sympy.core.decorators import call_highest_priority
99

10-
from devito import configuration
1110
from devito.finite_differences.elementary import Min, Max
1211
from devito.tools import (Pickable, Bunch, as_tuple, is_integer, float2, # noqa
1312
float3, float4, double2, double3, double4, int2, int3,
@@ -20,8 +19,7 @@
2019
'ListInitializer', 'Byref', 'IndexedPointer', 'Cast', 'DefFunction',
2120
'MathFunction', 'InlineIf', 'ReservedWord', 'Keyword', 'String',
2221
'Macro', 'Class', 'MacroArgument', 'CustomType', 'Deref', 'Namespace',
23-
'Rvalue', 'INT', 'FLOAT', 'DOUBLE', 'VOID', 'Null', 'SizeOf', 'rfunc',
24-
'cast_mapper', 'BasicWrapperMixin', 'ValueLimit', 'limits_mapper']
22+
'Rvalue', 'Null', 'SizeOf', 'rfunc', 'BasicWrapperMixin', 'ValueLimit']
2523

2624

2725
class CondEq(sympy.Eq):
@@ -548,14 +546,6 @@ class ValueLimit(ReservedWord, sympy.Expr):
548546
pass
549547

550548

551-
limits_mapper = {
552-
np.int32: Bunch(min=ValueLimit('INT_MIN'), max=ValueLimit('INT_MAX')),
553-
np.int64: Bunch(min=ValueLimit('LONG_MIN'), max=ValueLimit('LONG_MAX')),
554-
np.float32: Bunch(min=-ValueLimit('FLT_MAX'), max=ValueLimit('FLT_MAX')),
555-
np.float64: Bunch(min=-ValueLimit('DBL_MAX'), max=ValueLimit('DBL_MAX')),
556-
}
557-
558-
559549
class DefFunction(Function, Pickable):
560550

561551
"""
@@ -773,120 +763,6 @@ def __new__(cls, base=''):
773763
return cls.base(base, '*')
774764

775765

776-
# Dynamically create INT, INT2, .... INTP, INT2P, ... FLOAT, ...
777-
for base_name in ['int', 'float', 'double']:
778-
for i in ['', '2', '3', '4']:
779-
v = '%s%s' % (base_name, i)
780-
cls = type(v.upper(), (Cast,), {'_base_typ': v})
781-
globals()[cls.__name__] = cls
782-
783-
clsp = type('%sP' % v.upper(), (CastStar,), {'base': cls})
784-
globals()[clsp.__name__] = clsp
785-
786-
787-
class CHAR(Cast):
788-
_base_typ = 'char'
789-
790-
791-
class SHORT(Cast):
792-
_base_typ = 'short'
793-
794-
795-
class USHORT(Cast):
796-
_base_typ = 'unsigned short'
797-
798-
799-
class UCHAR(Cast):
800-
_base_typ = 'unsigned char'
801-
802-
803-
class LONG(Cast):
804-
_base_typ = 'long'
805-
806-
807-
class ULONG(Cast):
808-
_base_typ = 'unsigned long'
809-
810-
811-
class VOID(Cast):
812-
_base_typ = 'void'
813-
814-
815-
class CFLOAT(Cast):
816-
817-
@property
818-
def _base_typ(self):
819-
return configuration['compiler']._complex_ctype('float')
820-
821-
822-
class CDOUBLE(Cast):
823-
824-
@property
825-
def _base_typ(self):
826-
return configuration['compiler']._complex_ctype('double')
827-
828-
829-
class CHARP(CastStar):
830-
base = CHAR
831-
832-
833-
class UCHARP(CastStar):
834-
base = UCHAR
835-
836-
837-
class SHORTP(CastStar):
838-
base = SHORT
839-
840-
841-
class USHORTP(CastStar):
842-
base = USHORT
843-
844-
845-
class CFLOATP(CastStar):
846-
base = CFLOAT
847-
848-
849-
class CDOUBLEP(CastStar):
850-
base = CDOUBLE
851-
852-
853-
cast_mapper = {
854-
np.int8: CHAR,
855-
np.uint8: UCHAR,
856-
np.int16: SHORT, # noqa
857-
np.uint16: USHORT, # noqa
858-
int: INT, # noqa
859-
np.int32: INT, # noqa
860-
np.int64: LONG,
861-
np.uint64: ULONG,
862-
np.float32: FLOAT, # noqa
863-
float: DOUBLE, # noqa
864-
np.float64: DOUBLE, # noqa
865-
np.complex64: CFLOAT, # noqa
866-
np.complex128: CDOUBLE, # noqa
867-
868-
(np.int8, '*'): CHARP,
869-
(np.uint8, '*'): UCHARP,
870-
(int, '*'): INTP, # noqa
871-
(np.uint16, '*'): USHORTP, # noqa
872-
(np.int16, '*'): SHORTP, # noqa
873-
(np.int32, '*'): INTP, # noqa
874-
(np.int64, '*'): INTP, # noqa
875-
(np.float32, '*'): FLOATP, # noqa
876-
(float, '*'): DOUBLEP, # noqa
877-
(np.float64, '*'): DOUBLEP, # noqa
878-
(np.complex64, '*'): CFLOATP, # noqa
879-
(np.complex128, '*'): CDOUBLEP, # noqa
880-
}
881-
882-
for base_name in ['int', 'float', 'double']:
883-
for i in [2, 3, 4]:
884-
v = '%s%d' % (base_name, i)
885-
cls = locals()[v]
886-
cast_mapper[cls] = locals()[v.upper()]
887-
cast_mapper[(cls, '*')] = locals()['%sP' % v.upper()]
888-
889-
890766
# Some other utility objects
891767
Null = Macro('NULL')
892768

‎devito/symbolics/inspection.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
from devito.finite_differences import Derivative
99
from devito.finite_differences.differentiable import IndexDerivative
1010
from devito.logger import warning
11-
from devito.symbolics.extended_sympy import (INT, CallFromPointer, Cast,
11+
from devito.symbolics.extended_dtypes import INT
12+
from devito.symbolics.extended_sympy import (CallFromPointer, Cast,
1213
DefFunction, ReservedWord)
1314
from devito.symbolics.queries import q_routine
1415
from devito.tools import as_tuple, prod

‎devito/tools/dtypes_lowering.py

+13-11
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
'double3', 'double4', 'dtypes_vector_mapper', 'dtype_to_mpidtype',
1414
'dtype_to_cstr', 'dtype_to_ctype', 'dtype_to_mpitype', 'dtype_len',
1515
'ctypes_to_cstr', 'c_restrict_void_p', 'ctypes_vector_mapper',
16-
'is_external_ctype', 'infer_dtype', 'CustomDtype']
16+
'is_external_ctype', 'infer_dtype', 'CustomDtype', 'CustomNpType']
1717

1818

1919
# *** Custom np.dtypes
@@ -123,6 +123,18 @@ def __repr__(self):
123123
__str__ = __repr__
124124

125125

126+
class CustomNpType(CustomDtype):
127+
"""
128+
Custom dtype for underlying numpy type.
129+
"""
130+
131+
def __init__(self, name, nptype, template=None, modifier=None):
132+
self.nptype = nptype
133+
super().__init__(name, template, modifier)
134+
135+
def __call__(self, val):
136+
return self.nptype(val)
137+
126138
# *** np.dtypes lowering
127139

128140

@@ -136,16 +148,6 @@ def dtype_to_ctype(dtype):
136148
if isinstance(dtype, CustomDtype):
137149
return dtype
138150

139-
# Complex data
140-
if np.issubdtype(dtype, np.complexfloating):
141-
rtype = dtype(0).real.__class__
142-
from devito import configuration
143-
make = configuration['compiler']._complex_ctype
144-
ctname = make(dtype_to_cstr(rtype))
145-
ctype = dtype_to_ctype(rtype)
146-
r = type(ctname, (ctype,), {})
147-
return r
148-
149151
try:
150152
return ctypes_vector_mapper[dtype]
151153
except KeyError:

‎devito/types/basic.py

+25-9
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from devito.data import default_allocator
1414
from devito.parameters import configuration
1515
from devito.tools import (Pickable, as_tuple, ctypes_to_cstr, dtype_to_ctype,
16-
frozendict, memoized_meth, sympy_mutex)
16+
frozendict, memoized_meth, sympy_mutex, CustomDtype)
1717
from devito.types.args import ArgProvider
1818
from devito.types.caching import Cached, Uncached
1919
from devito.types.lazy import Evaluable
@@ -82,6 +82,9 @@ def _C_typedata(self):
8282
The type of the object in the generated code as a `str`.
8383
"""
8484
_type = self._C_ctype
85+
if isinstance(_type, CustomDtype):
86+
return _type
87+
8588
while issubclass(_type, _Pointer):
8689
_type = _type._type_
8790

@@ -858,14 +861,16 @@ def __new__(cls, *args, **kwargs):
858861
name = kwargs.get('name')
859862
alias = kwargs.get('alias')
860863
function = kwargs.get('function')
864+
dtype = kwargs.get('dtype')
861865
if alias or (function and function.name != name):
862866
function = kwargs['function'] = None
863867

864868
# If same name/indices and `function` isn't None, then it's
865869
# definitely a reconstruction
866870
if function is not None and \
867871
function.name == name and \
868-
function.indices == indices:
872+
function.indices == indices and \
873+
function.dtype == dtype:
869874
# Special case: a syntactically identical alias of `function`, so
870875
# let's just return `function` itself
871876
return function
@@ -1170,7 +1175,8 @@ def bound_symbols(self):
11701175
@cached_property
11711176
def indexed(self):
11721177
"""The wrapped IndexedData object."""
1173-
return IndexedData(self.name, shape=self._shape, function=self.function)
1178+
return IndexedData(self.name, shape=self._shape, function=self.function,
1179+
dtype=self.dtype)
11741180

11751181
@cached_property
11761182
def dmap(self):
@@ -1355,7 +1361,7 @@ def _data_alignment(self):
13551361
def indexify(self, indices=None, subs=None):
13561362
"""Create a types.Indexed from the current object."""
13571363
if indices is not None:
1358-
return Indexed(self.indexed, *indices)
1364+
return Indexed(self.indexed, indices)
13591365

13601366
# Substitution for each index (spacing only used in own dimension)
13611367
subs = subs or {}
@@ -1414,20 +1420,21 @@ class IndexedBase(sympy.IndexedBase, Basic, Pickable):
14141420
__rargs__ = ('label', 'shape')
14151421
__rkwargs__ = ('function',)
14161422

1417-
def __new__(cls, label, shape, function=None):
1423+
def __new__(cls, label, shape, function=None, dtype=None):
14181424
# Make sure `label` is a devito.Symbol, not a sympy.Symbol
14191425
if isinstance(label, str):
14201426
label = Symbol(name=label, dtype=None)
14211427
with sympy_mutex:
14221428
obj = sympy.IndexedBase.__new__(cls, label, shape)
14231429
obj.function = function
1430+
obj._dtype = dtype
14241431
return obj
14251432

14261433
func = Pickable._rebuild
14271434

14281435
def __getitem__(self, indices, **kwargs):
14291436
"""Produce a types.Indexed, rather than a sympy.Indexed."""
1430-
return Indexed(self, *as_tuple(indices))
1437+
return Indexed(self, as_tuple(indices))
14311438

14321439
def _hashable_content(self):
14331440
return super()._hashable_content() + (self.function,)
@@ -1454,7 +1461,7 @@ def indices(self):
14541461

14551462
@property
14561463
def dtype(self):
1457-
return self.function.dtype
1464+
return self._dtype
14581465

14591466
@cached_property
14601467
def free_symbols(self):
@@ -1516,7 +1523,7 @@ def _C_ctype(self):
15161523
return self.function._C_ctype
15171524

15181525

1519-
class Indexed(sympy.Indexed):
1526+
class Indexed(sympy.Indexed, Pickable):
15201527

15211528
# The two type flags have changed in upstream sympy as of version 1.1,
15221529
# but the below interpretation is used throughout the compiler to
@@ -1528,6 +1535,15 @@ class Indexed(sympy.Indexed):
15281535

15291536
is_Dimension = False
15301537

1538+
__rargs__ = ('base', 'indices')
1539+
__rkwargs__ = ('dtype',)
1540+
1541+
def __new__(cls, indexed, indices, dtype=None, **kwargs):
1542+
newobj = sympy.Indexed.__new__(cls, indexed, *indices)
1543+
newobj._dtype = dtype or indexed.dtype
1544+
1545+
return newobj
1546+
15311547
@memoized_meth
15321548
def __str__(self):
15331549
return super().__str__()
@@ -1549,7 +1565,7 @@ def function(self):
15491565

15501566
@property
15511567
def dtype(self):
1552-
return self.function.dtype
1568+
return self._dtype
15531569

15541570
@property
15551571
def name(self):

‎devito/types/misc.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ class FIndexed(Indexed, Pickable):
7979
__rkwargs__ = ('strides_map', 'accessor')
8080

8181
def __new__(cls, base, *args, strides_map=None, accessor=None):
82-
obj = super().__new__(cls, base, *args)
82+
obj = super().__new__(cls, base, args)
8383
obj.strides_map = frozendict(strides_map or {})
8484
obj.accessor = accessor
8585

0 commit comments

Comments
 (0)
Please sign in to comment.