@@ -200,17 +200,26 @@ def complex_include(iet, language, compiler):
200
200
"""
201
201
Add headers for complex arithmetic
202
202
"""
203
- lib = _complex_lib .get (language , 'complex' if compiler ._cpp else 'complex.h' )
203
+ lib = ( _complex_lib .get (language , 'complex' if compiler ._cpp else 'complex.h' ), )
204
204
205
205
headers = {}
206
+
206
207
# For openacc (cpp) need to define constant _Complex_I that isn't found otherwise
207
208
if compiler ._cpp :
209
+ # Constant I
208
210
headers = {('_Complex_I' , ('std::complex<float>(0.0f, 1.0f)' ))}
211
+ # Mix arithmetic definitions
212
+ dest = compiler .get_jit_dir ()
213
+ hfile = dest .joinpath ('stdcomplex_arith.h' )
214
+ if not hfile .is_file ():
215
+ with open (str (hfile ), 'w' ) as ff :
216
+ ff .write (str (_stdcomplex_defs ))
217
+ lib += (str (hfile ),)
209
218
210
219
for f in FindSymbols ().visit (iet ):
211
220
try :
212
221
if np .issubdtype (f .dtype , np .complexfloating ):
213
- return iet , {'includes' : ( lib ,) , 'headers' : headers }
222
+ return iet , {'includes' : lib , 'headers' : headers }
214
223
except TypeError :
215
224
pass
216
225
@@ -295,3 +304,23 @@ def _rename_subdims(target, dimensions):
295
304
return {d : d ._rebuild (d .root .name ) for d in dims
296
305
if d .root not in dimensions
297
306
and names .count (d .root .name ) < 2 }
307
+
308
+
309
+ _stdcomplex_defs = """
310
+ #include <complex>
311
+
312
+ template<typename _Tp, typename _Ti>
313
+ std::complex<_Tp> operator * (const _Ti & a, const std::complex<_Tp> & b){
314
+ return std::complex<_Tp>(b.real() * a, b.imag() * a);
315
+ }
316
+
317
+ template<typename _Tp, typename _Ti>
318
+ std::complex<_Tp> operator / (const _Ti & a, const std::complex<_Tp> & b){
319
+ return std::complex<_Tp>(b.real() / a, b.imag() / a);
320
+ }
321
+
322
+ template<typename _Tp, typename _Ti>
323
+ std::complex<_Tp> operator + (const _Ti & a, const std::complex<_Tp> & b){
324
+ return std::complex<_Tp>(b.real() + a, b.imag());
325
+ }
326
+ """
0 commit comments