11
11
from devito .passes .iet .engine import iet_pass
12
12
from devito .ir .iet .efunc import DeviceFunction , EntryFunction
13
13
from devito .symbolics import ValueLimit , evalrel , has_integer_args , limits_mapper
14
- from devito .tools import as_mapper , filter_ordered , split
14
+ from devito .tools import as_mapper , filter_ordered , split , dtype_to_cstr
15
15
16
16
__all__ = ['avoid_denormals' , 'hoist_prodders' , 'relax_incr_dimensions' ,
17
17
'generate_macros' , 'minimize_symbols' , 'complex_include' ]
@@ -192,22 +192,28 @@ def minimize_symbols(iet):
192
192
return iet , {}
193
193
194
194
195
- _complex_lib = {'cuda' : 'cuComplex .h' , 'hip' : 'hip/hip_complex.h' }
195
+ _complex_lib = {'cuda' : 'thrust/complex .h' , 'hip' : 'hip/hip_complex.h' }
196
196
197
197
198
198
@iet_pass
199
199
def complex_include (iet , language , compiler ):
200
200
"""
201
201
Add headers for complex arithmetic
202
202
"""
203
+ # Check if there is complex numbers that always take dtype precedence
204
+ max_dtype = np .result_type (* [f .dtype for f in FindSymbols ().visit (iet )])
205
+ if not np .issubdtype (max_dtype , np .complexfloating ):
206
+ return iet , {}
207
+
203
208
lib = (_complex_lib .get (language , 'complex' if compiler ._cpp else 'complex.h' ),)
204
209
205
210
headers = {}
206
211
207
212
# For openacc (cpp) need to define constant _Complex_I that isn't found otherwise
208
213
if compiler ._cpp :
214
+ c_str = dtype_to_cstr (max_dtype .type (0 ).real .dtype .type )
209
215
# Constant I
210
- headers = {('_Complex_I' , ('std::complex<float >(0.0f , 1.0f)' ))}
216
+ headers = {('_Complex_I' , ('std::complex<%s >(0.0 , 1.0)' % c_str ))}
211
217
# Mix arithmetic definitions
212
218
dest = compiler .get_jit_dir ()
213
219
hfile = dest .joinpath ('stdcomplex_arith.h' )
@@ -216,14 +222,7 @@ def complex_include(iet, language, compiler):
216
222
ff .write (str (_stdcomplex_defs ))
217
223
lib += (str (hfile ),)
218
224
219
- for f in FindSymbols ().visit (iet ):
220
- try :
221
- if np .issubdtype (f .dtype , np .complexfloating ):
222
- return iet , {'includes' : lib , 'headers' : headers }
223
- except TypeError :
224
- pass
225
-
226
- return iet , {}
225
+ return iet , {'includes' : lib , 'headers' : headers }
227
226
228
227
229
228
def remove_redundant_moddims (iet ):
@@ -314,13 +313,29 @@ def _rename_subdims(target, dimensions):
314
313
return std::complex<_Tp>(b.real() * a, b.imag() * a);
315
314
}
316
315
316
+ template<typename _Tp, typename _Ti>
317
+ std::complex<_Tp> operator * (const std::complex<_Tp> & b, const _Ti & a){
318
+ return std::complex<_Tp>(b.real() * a, b.imag() * a);
319
+ }
320
+
317
321
template<typename _Tp, typename _Ti>
318
322
std::complex<_Tp> operator / (const _Ti & a, const std::complex<_Tp> & b){
323
+ _Tp denom = b.real() * b.real () + b.imag() * b.imag()
324
+ return std::complex<_Tp>(b.real() * a / denom, - b.imag() * a / denom);
325
+ }
326
+
327
+ template<typename _Tp, typename _Ti>
328
+ std::complex<_Tp> operator / (const std::complex<_Tp> & b, const _Ti & a){
319
329
return std::complex<_Tp>(b.real() / a, b.imag() / a);
320
330
}
321
331
322
332
template<typename _Tp, typename _Ti>
323
333
std::complex<_Tp> operator + (const _Ti & a, const std::complex<_Tp> & b){
324
334
return std::complex<_Tp>(b.real() + a, b.imag());
325
335
}
336
+
337
+ template<typename _Tp, typename _Ti>
338
+ std::complex<_Tp> operator + (const std::complex<_Tp> & b, const _Ti & a){
339
+ return std::complex<_Tp>(b.real() + a, b.imag());
340
+ }
326
341
"""
0 commit comments