Skip to content

Commit 014ef2c

Browse files
committed
compiler: add std::complex arithmetic defs for unsupported types
1 parent 5a6b169 commit 014ef2c

File tree

3 files changed

+39
-5
lines changed

3 files changed

+39
-5
lines changed

devito/ir/iet/visitors.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from sympy import IndexedBase
1515
from sympy.core.function import Application
1616

17+
from devito.parameters import configuration
1718
from devito.exceptions import VisitorException
1819
from devito.ir.iet.nodes import (Node, Iteration, Expression, ExpressionBundle,
1920
Call, Lambda, BlankLine, Section, ListMajor)
@@ -177,7 +178,7 @@ class CGen(Visitor):
177178

178179
def __init__(self, *args, compiler=None, **kwargs):
179180
super().__init__(*args, **kwargs)
180-
self._compiler = compiler
181+
self._compiler = compiler or configuration['compiler']
181182

182183
# The following mappers may be customized by subclasses (that is,
183184
# backend-specific CGen-erators)

devito/passes/iet/misc.py

+31-2
Original file line numberDiff line numberDiff line change
@@ -200,17 +200,26 @@ def complex_include(iet, language, compiler):
200200
"""
201201
Add headers for complex arithmetic
202202
"""
203-
lib = _complex_lib.get(language, 'complex' if compiler._cpp else 'complex.h')
203+
lib = (_complex_lib.get(language, 'complex' if compiler._cpp else 'complex.h'),)
204204

205205
headers = {}
206+
206207
# For openacc (cpp) need to define constant _Complex_I that isn't found otherwise
207208
if compiler._cpp:
209+
# Constant I
208210
headers = {('_Complex_I', ('std::complex<float>(0.0f, 1.0f)'))}
211+
# Mix arithmetic definitions
212+
dest = compiler.get_jit_dir()
213+
hfile = dest.joinpath('stdcomplex_arith.h')
214+
if not hfile.is_file():
215+
with open(str(hfile), 'w') as ff:
216+
ff.write(str(_stdcomplex_defs))
217+
lib += (str(hfile),)
209218

210219
for f in FindSymbols().visit(iet):
211220
try:
212221
if np.issubdtype(f.dtype, np.complexfloating):
213-
return iet, {'includes': (lib,), 'headers': headers}
222+
return iet, {'includes': lib, 'headers': headers}
214223
except TypeError:
215224
pass
216225

@@ -295,3 +304,23 @@ def _rename_subdims(target, dimensions):
295304
return {d: d._rebuild(d.root.name) for d in dims
296305
if d.root not in dimensions
297306
and names.count(d.root.name) < 2}
307+
308+
309+
_stdcomplex_defs = """
310+
#include <complex>
311+
312+
template<typename _Tp, typename _Ti>
313+
std::complex<_Tp> operator * (const _Ti & a, const std::complex<_Tp> & b){
314+
return std::complex<_Tp>(b.real() * a, b.imag() * a);
315+
}
316+
317+
template<typename _Tp, typename _Ti>
318+
std::complex<_Tp> operator / (const _Ti & a, const std::complex<_Tp> & b){
319+
return std::complex<_Tp>(b.real() / a, b.imag() / a);
320+
}
321+
322+
template<typename _Tp, typename _Ti>
323+
std::complex<_Tp> operator + (const _Ti & a, const std::complex<_Tp> & b){
324+
return std::complex<_Tp>(b.real() + a, b.imag());
325+
}
326+
"""

devito/symbolics/printer.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ def dtype(self):
3939
def compiler(self):
4040
return self._settings['compiler']
4141

42+
@property
43+
def cpp(self):
44+
return self.compiler._cpp
45+
4246
def parenthesize(self, item, level, strict=False):
4347
if isinstance(item, BooleanFunction):
4448
return "(%s)" % self._print(item)
@@ -101,7 +105,7 @@ def _print_math_func(self, expr, nest=False, known=None):
101105
return super()._print_math_func(expr, nest=nest, known=known)
102106

103107
dtype = sympy_dtype(expr)
104-
if np.issubdtype(dtype, np.complexfloating):
108+
if np.issubdtype(dtype, np.complexfloating) and not self.cpp:
105109
cname = 'c%s' % cname
106110
dtype = self.dtype(0).real.dtype.type
107111

@@ -255,7 +259,7 @@ def _print_ComponentAccess(self, expr):
255259
def _print_TrigonometricFunction(self, expr):
256260
func_name = str(expr.func)
257261
dtype = self.dtype
258-
if np.issubdtype(dtype, np.complexfloating):
262+
if np.issubdtype(dtype, np.complexfloating) and not self.cpp:
259263
func_name = 'c%s' % func_name
260264
dtype = self.dtype(0).real.dtype.type
261265
if dtype == np.float32:

0 commit comments

Comments
 (0)