Skip to content

Commit 6ac5b99

Browse files
committed
api: move complex ctype to dtype lowering
1 parent ed90efb commit 6ac5b99

File tree

6 files changed

+26
-29
lines changed

6 files changed

+26
-29
lines changed

devito/passes/clusters/factorization.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from collections import defaultdict
22

33
from sympy import Add, Mul, S, collect
4-
from sympy.core import NumberKind
54

65
from devito.ir import cluster_pass
76
from devito.symbolics import BasicWrapperMixin, estimate_cost, retrieve_symbols
@@ -174,7 +173,7 @@ def _collect_nested(expr):
174173
Recursion helper for `collect_nested`.
175174
"""
176175
# Return semantic (rebuilt expression, factorization candidates)
177-
if expr.kind is NumberKind:
176+
if expr.is_Number:
178177
return expr, {'coeffs': expr}
179178
elif expr.is_Function:
180179
return expr, {'funcs': expr}

devito/symbolics/printer.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,11 @@ def _print_math_func(self, expr, nest=False, known=None):
101101
return super()._print_math_func(expr, nest=nest, known=known)
102102

103103
dtype = sympy_dtype(expr)
104-
if dtype is np.float32:
105-
cname += 'f'
106-
if np.issubdtype(self.dtype, np.complexfloating):
104+
if np.issubdtype(dtype, np.complexfloating):
107105
cname = 'c%s' % cname
106+
dtype = self.dtype(0).real.dtype
107+
if dtype is np.float32:
108+
cname = '%sf' % cname
108109
args = ', '.join((self._print(arg) for arg in expr.args))
109110

110111
return '%s(%s)' % (cname, args)
@@ -198,6 +199,9 @@ def _print_Float(self, expr):
198199

199200
return rv
200201

202+
def _print_ImaginaryUnit(self, expr):
203+
return '_Complex_I'
204+
201205
def _print_Differentiable(self, expr):
202206
return "(%s)" % self._print(expr._expr)
203207

@@ -249,10 +253,12 @@ def _print_ComponentAccess(self, expr):
249253

250254
def _print_TrigonometricFunction(self, expr):
251255
func_name = str(expr.func)
252-
if self.dtype == np.float32:
253-
func_name += 'f'
254-
if np.issubdtype(self.dtype, np.complexfloating):
256+
dtype = self.dtype
257+
if np.issubdtype(dtype, np.complexfloating):
255258
func_name = 'c%s' % func_name
259+
dtype = self.dtype(0).real.dtype
260+
if dtype == np.float32:
261+
func_name = '%sf' % func_name
256262
return '%s(%s)' % (func_name, self._print(*expr.args))
257263

258264
def _print_DefFunction(self, expr):

devito/tools/dtypes_lowering.py

+8
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,14 @@ def dtype_to_ctype(dtype):
136136
if isinstance(dtype, CustomDtype):
137137
return dtype
138138

139+
# Complex data
140+
if np.issubdtype(dtype, np.complexfloating):
141+
rtype = dtype(0).real.__class__
142+
ctname = '%s _Complex' % dtype_to_cstr(rtype)
143+
ctype = dtype_to_ctype(rtype)
144+
r = type(ctname, (ctype,), {})
145+
return r
146+
139147
try:
140148
return ctypes_vector_mapper[dtype]
141149
except KeyError:

devito/types/basic.py

+3-20
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@
1313
from devito.data import default_allocator
1414
from devito.parameters import configuration
1515
from devito.tools import (Pickable, as_tuple, ctypes_to_cstr, dtype_to_ctype,
16-
frozendict, memoized_meth, sympy_mutex, dtype_to_cstr,
17-
CustomDtype)
16+
frozendict, memoized_meth, sympy_mutex)
1817
from devito.types.args import ArgProvider
1918
from devito.types.caching import Cached, Uncached
2019
from devito.types.lazy import Evaluable
@@ -432,16 +431,7 @@ def _C_name(self):
432431

433432
@property
434433
def _C_ctype(self):
435-
if isinstance(self.dtype, CustomDtype):
436-
return self.dtype
437-
elif np.issubdtype(self.dtype, np.complexfloating):
438-
rtype = self.dtype(0).real.__class__
439-
ctname = '%s complex' % dtype_to_cstr(rtype)
440-
ctype = dtype_to_ctype(rtype)
441-
r = type(ctname, (ctype,), {})
442-
return r
443-
else:
444-
return dtype_to_ctype(self.dtype)
434+
return dtype_to_ctype(self.dtype)
445435

446436
def _subs(self, old, new, **hints):
447437
"""
@@ -1438,14 +1428,7 @@ def _C_name(self):
14381428
@cached_property
14391429
def _C_ctype(self):
14401430
try:
1441-
if np.issubdtype(self.dtype, np.complexfloating):
1442-
rtype = self.dtype(0).real.__class__
1443-
ctname = '%s complex' % dtype_to_cstr(rtype)
1444-
ctype = dtype_to_ctype(rtype)
1445-
r = type(ctname, (ctype,), {})
1446-
return POINTER(r)
1447-
else:
1448-
return POINTER(dtype_to_ctype(self.dtype))
1431+
return POINTER(dtype_to_ctype(self.dtype))
14491432
except TypeError:
14501433
# `dtype` is a ctypes-derived type!
14511434
return self.dtype

tests/test_gpu_common.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from conftest import assert_structure
88
from devito import (Constant, Eq, Inc, Grid, Function, ConditionalDimension,
99
Dimension, MatrixSparseTimeFunction, SparseTimeFunction,
10-
SubDimension, SubDomain, SubDomainSet, TimeFunction,
10+
SubDimension, SubDomain, SubDomainSet, TimeFunction, exp,
1111
Operator, configuration, switchconfig, TensorTimeFunction)
1212
from devito.arch import get_gpu_info
1313
from devito.exceptions import InvalidArgument

tests/test_operator.py

+1
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,7 @@ def test_complex(self):
655655
dx = grid.spacing_map[x.spacing]
656656
xx, yy = np.meshgrid(np.linspace(0, 4, 5), np.linspace(0, 4, 5))
657657
npres = xx + 1j*yy + np.exp(1j + dx)
658+
print(op)
658659

659660
assert np.allclose(u.data, npres.T, rtol=1e-7, atol=0)
660661

0 commit comments

Comments
 (0)