|
| 1 | +import numpy as np |
| 2 | + |
| 3 | +from devito.ir import Call |
| 4 | +from devito.passes.iet.langbase import LangBB |
| 5 | +from devito.tools import CustomNpType |
| 6 | + |
| 7 | +__all__ = ['CXXBB'] |
| 8 | + |
| 9 | + |
| 10 | +std_arith = """ |
| 11 | +#include <complex> |
| 12 | +
|
| 13 | +template<typename _Tp, typename _Ti> |
| 14 | +std::complex<_Tp> operator * (const _Ti & a, const std::complex<_Tp> & b){ |
| 15 | + return std::complex<_Tp>(b.real() * a, b.imag() * a); |
| 16 | +} |
| 17 | +
|
| 18 | +template<typename _Tp, typename _Ti> |
| 19 | +std::complex<_Tp> operator * (const std::complex<_Tp> & b, const _Ti & a){ |
| 20 | + return std::complex<_Tp>(b.real() * a, b.imag() * a); |
| 21 | +} |
| 22 | +
|
| 23 | +template<typename _Tp, typename _Ti> |
| 24 | +std::complex<_Tp> operator / (const _Ti & a, const std::complex<_Tp> & b){ |
| 25 | + _Tp denom = b.real() * b.real () + b.imag() * b.imag() |
| 26 | + return std::complex<_Tp>(b.real() * a / denom, - b.imag() * a / denom); |
| 27 | +} |
| 28 | +
|
| 29 | +template<typename _Tp, typename _Ti> |
| 30 | +std::complex<_Tp> operator / (const std::complex<_Tp> & b, const _Ti & a){ |
| 31 | + return std::complex<_Tp>(b.real() / a, b.imag() / a); |
| 32 | +} |
| 33 | +
|
| 34 | +template<typename _Tp, typename _Ti> |
| 35 | +std::complex<_Tp> operator + (const _Ti & a, const std::complex<_Tp> & b){ |
| 36 | + return std::complex<_Tp>(b.real() + a, b.imag()); |
| 37 | +} |
| 38 | +
|
| 39 | +template<typename _Tp, typename _Ti> |
| 40 | +std::complex<_Tp> operator + (const std::complex<_Tp> & b, const _Ti & a){ |
| 41 | + return std::complex<_Tp>(b.real() + a, b.imag()); |
| 42 | +} |
| 43 | +""" |
| 44 | + |
| 45 | +CXXCFloat = CustomNpType('std::complex', np.complex64, template='float') |
| 46 | +CXXCDouble = CustomNpType('std::complex', np.complex128, template='double') |
| 47 | + |
| 48 | + |
| 49 | +class CXXBB(LangBB): |
| 50 | + |
| 51 | + mapper = { |
| 52 | + 'header-memcpy': 'string.h', |
| 53 | + 'host-alloc': lambda i, j, k: |
| 54 | + Call('posix_memalign', (i, j, k)), |
| 55 | + 'host-alloc-pin': lambda i, j, k: |
| 56 | + Call('posix_memalign', (i, j, k)), |
| 57 | + 'host-free': lambda i: |
| 58 | + Call('free', (i,)), |
| 59 | + 'host-free-pin': lambda i: |
| 60 | + Call('free', (i,)), |
| 61 | + 'alloc-global-symbol': lambda i, j, k: |
| 62 | + Call('memcpy', (i, j, k)), |
| 63 | + # Complex |
| 64 | + 'header-complex': '<complex>', |
| 65 | + 'I-def': (('_Complex_I', ('std::complex<float>(0.0, 1.0)')),), |
| 66 | + 'def-complex': std_arith, |
| 67 | + 'types': {np.complex128: CXXCDouble, np.complex64: CXXCFloat}, |
| 68 | + } |
0 commit comments