1
1
import numpy as np
2
2
3
3
from devito .arch .compiler import Compiler
4
- from devito .ir import Callable , FindSymbols , SymbolRegistry
4
+ from devito .ir import Callable , SymbolRegistry
5
+ from devito .ir .iet .utils import has_dtype
5
6
from devito .passes .iet .engine import iet_pass
6
7
from devito .passes .iet .langbase import LangBB
7
8
from devito .tools import as_tuple
10
11
11
12
12
13
@iet_pass
13
- def _complex_includes (iet : Callable , lang : type [LangBB ], compiler : Compiler ,
14
+ def _complex_includes (iet : Callable , langbb : type [LangBB ], compiler : Compiler ,
14
15
sregistry : SymbolRegistry ) -> tuple [Callable , dict ]:
15
16
"""
16
17
Includes complex arithmetic headers for the given language, if needed.
17
18
"""
18
19
# Check if there are complex numbers that always take dtype precedence
19
- for f in FindSymbols ().visit (iet ):
20
- try :
21
- if np .issubdtype (f .dtype , np .complexfloating ):
22
- break
23
- except TypeError :
24
- continue
25
- else :
20
+ if not has_dtype (iet , np .complexfloating ):
26
21
return iet , {}
27
22
28
23
metadata = {}
29
- lib = as_tuple (lang ['includes-complex' ])
24
+ lib = as_tuple (langbb ['includes-complex' ])
30
25
31
- if lang .get ('complex-namespace' ) is not None :
32
- metadata ['namespaces' ] = lang ['complex-namespace' ]
26
+ if langbb .get ('complex-namespace' ) is not None :
27
+ metadata ['namespaces' ] = langbb ['complex-namespace' ]
33
28
34
29
# Some languges such as c++11 need some extra arithmetic definitions
35
- if lang .get ('def-complex' ):
30
+ if langbb .get ('def-complex' ):
36
31
dest = compiler .get_jit_dir ()
37
32
hfile = dest .joinpath ('complex_arith.h' )
38
33
with open (str (hfile ), 'w' ) as ff :
39
- ff .write (str (lang ['def-complex' ]))
34
+ ff .write (str (langbb ['def-complex' ]))
40
35
lib += (str (hfile ),)
41
36
42
37
metadata ['includes' ] = lib
@@ -47,12 +42,14 @@ def _complex_includes(iet: Callable, lang: type[LangBB], compiler: Compiler,
47
42
dtype_passes = [_complex_includes ]
48
43
49
44
50
- def lower_dtypes (graph : Callable , lang : type [LangBB ] = None , compiler : Compiler = None ,
45
+ def lower_dtypes (graph : Callable ,
46
+ langbb : type [LangBB ] = None ,
47
+ compiler : Compiler = None ,
51
48
sregistry : SymbolRegistry = None , ** kwargs ) -> tuple [Callable , dict ]:
52
49
"""
53
50
Lowers float16 scalar types to pointers since we can't directly pass their
54
51
value. Also includes headers for complex arithmetic if needed.
55
52
"""
56
53
57
54
for dtype_pass in dtype_passes :
58
- dtype_pass (graph , lang = lang , compiler = compiler , sregistry = sregistry )
55
+ dtype_pass (graph , langbb = langbb , compiler = compiler , sregistry = sregistry )
0 commit comments