Skip to content

Commit 464327e

Browse files
committed
compiler: rename lang option to langbb for clarity
1 parent 2362257 commit 464327e

File tree

9 files changed

+1170
-184
lines changed

9 files changed

+1170
-184
lines changed

devito/arch/compiler.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -185,10 +185,11 @@ def __init__(self):
185185
_cstd = 'c99'
186186

187187
def __init__(self, **kwargs):
188-
name = kwargs.pop('name', self.__class__.__name__)
189-
if isinstance(name, Compiler):
190-
name = name.name
191-
self._name = name
188+
maybe_name = kwargs.pop('name', self.__class__.__name__)
189+
if isinstance(maybe_name, Compiler):
190+
self._name = maybe_name.name
191+
else:
192+
self._name = maybe_name
192193

193194
super().__init__(**kwargs)
194195

@@ -658,7 +659,7 @@ def __init_finalize__(self, **kwargs):
658659

659660
self.cflags.remove('-Wall')
660661
self.cflags.remove('-fPIC')
661-
self.cflags.append('-Xcompiler')
662+
self.cflags.extend(['-Xcompiler', '-fPIC'])
662663

663664
if configuration['mpi']:
664665
# We rather use `nvcc` to compile MPI, but for this we have to

devito/ir/iet/utils.py

+19
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import numpy as np
2+
13
from devito.ir.iet import FindSections, FindSymbols
24
from devito.symbolics import Keyword, Macro
35
from devito.tools import filter_ordered
@@ -166,3 +168,20 @@ def maybe_alias(obj, candidate):
166168
# the __rkwargs__ except for e.g. the name
167169

168170
return False
171+
172+
173+
def has_dtype(iet, dtype):
174+
"""
175+
Check if the given IET has at least one symbol with the given dtype or
176+
dtype kind.
177+
"""
178+
for f in FindSymbols().visit(iet):
179+
try:
180+
# Check if the dtype matches exactly (dtype input)
181+
# or matches the generic kind (dtype generic input)
182+
if np.issubdtype(f.dtype, dtype) or f.dtype == dtype:
183+
return True
184+
except TypeError:
185+
continue
186+
else:
187+
return False

devito/operator/operator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def _lower(cls, expressions, **kwargs):
261261
# Create a symbol registry
262262
kwargs.setdefault('sregistry', SymbolRegistry())
263263
# Add lang-base kwargs
264-
kwargs.setdefault('lang', cls._Target.lang())
264+
kwargs.setdefault('langbb', cls._Target.lang())
265265
kwargs.setdefault('printer', cls._Target.Printer)
266266

267267
expressions = as_tuple(expressions)

devito/passes/clusters/cse.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,10 @@ def cse(cluster, sregistry=None, options=None, **kwargs):
7070
"""
7171
min_cost = options['cse-min-cost']
7272
mode = options['cse-algo']
73-
dtype = np.promote_types(options['scalar-min-type'], cluster.dtype).type
73+
try:
74+
dtype = np.promote_types(options['scalar-min-type'], cluster.dtype).type
75+
except TypeError:
76+
dtype = cluster.dtype
7477

7578
if cluster.is_fence:
7679
return cluster

devito/passes/iet/dtypes.py

+13-16
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import numpy as np
22

33
from devito.arch.compiler import Compiler
4-
from devito.ir import Callable, FindSymbols, SymbolRegistry
4+
from devito.ir import Callable, SymbolRegistry
5+
from devito.ir.iet.utils import has_dtype
56
from devito.passes.iet.engine import iet_pass
67
from devito.passes.iet.langbase import LangBB
78
from devito.tools import as_tuple
@@ -10,33 +11,27 @@
1011

1112

1213
@iet_pass
13-
def _complex_includes(iet: Callable, lang: type[LangBB], compiler: Compiler,
14+
def _complex_includes(iet: Callable, langbb: type[LangBB], compiler: Compiler,
1415
sregistry: SymbolRegistry) -> tuple[Callable, dict]:
1516
"""
1617
Includes complex arithmetic headers for the given language, if needed.
1718
"""
1819
# Check if there are complex numbers that always take dtype precedence
19-
for f in FindSymbols().visit(iet):
20-
try:
21-
if np.issubdtype(f.dtype, np.complexfloating):
22-
break
23-
except TypeError:
24-
continue
25-
else:
20+
if not has_dtype(iet, np.complexfloating):
2621
return iet, {}
2722

2823
metadata = {}
29-
lib = as_tuple(lang['includes-complex'])
24+
lib = as_tuple(langbb['includes-complex'])
3025

31-
if lang.get('complex-namespace') is not None:
32-
metadata['namespaces'] = lang['complex-namespace']
26+
if langbb.get('complex-namespace') is not None:
27+
metadata['namespaces'] = langbb['complex-namespace']
3328

3429
# Some languges such as c++11 need some extra arithmetic definitions
35-
if lang.get('def-complex'):
30+
if langbb.get('def-complex'):
3631
dest = compiler.get_jit_dir()
3732
hfile = dest.joinpath('complex_arith.h')
3833
with open(str(hfile), 'w') as ff:
39-
ff.write(str(lang['def-complex']))
34+
ff.write(str(langbb['def-complex']))
4035
lib += (str(hfile),)
4136

4237
metadata['includes'] = lib
@@ -47,12 +42,14 @@ def _complex_includes(iet: Callable, lang: type[LangBB], compiler: Compiler,
4742
dtype_passes = [_complex_includes]
4843

4944

50-
def lower_dtypes(graph: Callable, lang: type[LangBB] = None, compiler: Compiler = None,
45+
def lower_dtypes(graph: Callable,
46+
langbb: type[LangBB] = None,
47+
compiler: Compiler = None,
5148
sregistry: SymbolRegistry = None, **kwargs) -> tuple[Callable, dict]:
5249
"""
5350
Lowers float16 scalar types to pointers since we can't directly pass their
5451
value. Also includes headers for complex arithmetic if needed.
5552
"""
5653

5754
for dtype_pass in dtype_passes:
58-
dtype_pass(graph, lang=lang, compiler=compiler, sregistry=sregistry)
55+
dtype_pass(graph, langbb=langbb, compiler=compiler, sregistry=sregistry)

devito/passes/iet/instrument.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -121,12 +121,12 @@ def instrument_sections(iet, **kwargs):
121121

122122

123123
@iet_pass
124-
def sync_sections(iet, lang=None, profiler=None, **kwargs):
124+
def sync_sections(iet, langbb=None, profiler=None, **kwargs):
125125
"""
126126
Wrap sections within global barriers if deemed necessary by the profiler.
127127
"""
128128
try:
129-
sync = lang['map-wait']
129+
sync = langbb['map-wait']
130130
except (KeyError, NotImplementedError):
131131
return iet, {}
132132

@@ -137,7 +137,7 @@ def sync_sections(iet, lang=None, profiler=None, **kwargs):
137137
for tl in FindNodes(TimedList).visit(iet):
138138
symbols = FindSymbols().visit(tl)
139139

140-
queues = [i for i in symbols if isinstance(i, lang.AsyncQueue)]
140+
queues = [i for i in symbols if isinstance(i, langbb.AsyncQueue)]
141141
unnecessary = any(FindNodes(BusyWait).visit(tl))
142142
if queues and not unnecessary:
143143
waits = tuple(sync(i) for i in queues)

devito/passes/iet/misc.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def generate_macros(graph, **kwargs):
144144

145145

146146
@iet_pass
147-
def _generate_macros(iet, tracker=None, lang=None, **kwargs):
147+
def _generate_macros(iet, tracker=None, langbb=None, **kwargs):
148148
# Derive the Macros necessary for the FIndexeds
149149
iet = _generate_macros_findexeds(iet, tracker=tracker, **kwargs)
150150

@@ -155,7 +155,7 @@ def _generate_macros(iet, tracker=None, lang=None, **kwargs):
155155
for define, expr in headers)
156156

157157
# Generate Macros from higher-level SymPy objects
158-
mheaders, includes = _generate_macros_math(iet, lang=lang)
158+
mheaders, includes = _generate_macros_math(iet, langbb=langbb)
159159
includes = sorted(includes, key=str)
160160
headers.extend(sorted(mheaders, key=str))
161161

@@ -199,42 +199,42 @@ def _generate_macros_findexeds(iet, sregistry=None, tracker=None, **kwargs):
199199
return iet
200200

201201

202-
def _generate_macros_math(iet, lang=None):
202+
def _generate_macros_math(iet, langbb=None):
203203
headers = []
204204
includes = []
205205
for i in FindApplications().visit(iet):
206-
header, include = _lower_macro_math(i, lang)
206+
header, include = _lower_macro_math(i, langbb)
207207
headers.extend(header)
208208
includes.extend(include)
209209

210210
return headers, set(includes) - {None}
211211

212212

213213
@singledispatch
214-
def _lower_macro_math(expr, lang):
214+
def _lower_macro_math(expr, langbb):
215215
return (), {}
216216

217217

218218
@_lower_macro_math.register(Min)
219219
@_lower_macro_math.register(sympy.Min)
220-
def _(expr, lang):
220+
def _(expr, langbb):
221221
if has_integer_args(*expr.args):
222222
return (('MIN(a,b)', ('(((a) < (b)) ? (a) : (b))')),), {}
223223
else:
224-
return (), as_tuple(lang.get('header-algorithm'))
224+
return (), as_tuple(langbb.get('header-algorithm'))
225225

226226

227227
@_lower_macro_math.register(Max)
228228
@_lower_macro_math.register(sympy.Max)
229-
def _(expr, lang):
229+
def _(expr, langbb):
230230
if has_integer_args(*expr.args):
231231
return (('MAX(a,b)', ('(((a) > (b)) ? (a) : (b))')),), {}
232232
else:
233-
return (), as_tuple(lang.get('header-algorithm'))
233+
return (), as_tuple(langbb.get('header-algorithm'))
234234

235235

236236
@_lower_macro_math.register(SafeInv)
237-
def _(expr, lang):
237+
def _(expr, langbb):
238238
try:
239239
eps = np.finfo(expr.base.dtype).resolution**2
240240
except ValueError:

devito/tools/dtypes_lowering.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,16 @@ def is_external_ctype(ctype, includes):
325325
return False
326326

327327

328+
def is_numpy_dtype(dtype):
329+
"""
330+
True if `dtype` is a numpy dtype, False otherwise.
331+
"""
332+
try:
333+
return issubclass(dtype, np.generic)
334+
except TypeError:
335+
return False
336+
337+
328338
def infer_dtype(dtypes):
329339
"""
330340
Given a set of np.dtypes, return the "winning" dtype:
@@ -335,7 +345,9 @@ def infer_dtype(dtypes):
335345
"""
336346
# Resolve the vector types, if any
337347
dtypes = {dtypes_vector_mapper.get_base_dtype(i, i) for i in dtypes}
338-
348+
# Only keep number types
349+
dtypes = {i for i in dtypes if is_numpy_dtype(i)}
350+
# Separate floating point types from the rest
339351
fdtypes = {i for i in dtypes if np.issubdtype(i, np.floating) or
340352
np.issubdtype(i, np.complexfloating)}
341353
if len(fdtypes) > 1:

0 commit comments

Comments
 (0)