10
10
import ctypes
11
11
12
12
import cgen as c
13
+ import numpy as np
13
14
from sympy import IndexedBase
14
15
from sympy .core .function import Application
15
16
@@ -188,6 +189,18 @@ def __init__(self, *args, compiler=None, **kwargs):
188
189
}
189
190
_restrict_keyword = 'restrict'
190
191
192
+ def _complex_type (self , ctypestr , dtype ):
193
+ # Not complex
194
+ if not np .issubdtype (dtype , np .complexfloating ):
195
+ return ctypestr
196
+ # Complex only supported for float and double
197
+ if ctypestr not in ('float' , 'double' ):
198
+ return ctypestr
199
+ if self ._compiler ._cpp :
200
+ return 'std::complex<%s>' % ctypestr
201
+ else :
202
+ return '%s _Complex' % ctypestr
203
+
191
204
def _gen_struct_decl (self , obj , masked = ()):
192
205
"""
193
206
Convert ctypes.Struct -> cgen.Structure.
@@ -243,10 +256,10 @@ def _gen_value(self, obj, mode=1, masked=()):
243
256
if getattr (obj .function , k , False ) and v not in masked ]
244
257
245
258
if (obj ._mem_stack or obj ._mem_constant ) and mode == 1 :
246
- strtype = obj ._C_typedata
259
+ strtype = self . _complex_type ( obj ._C_typedata , obj . dtype )
247
260
strshape = '' .join ('[%s]' % ccode (i ) for i in obj .symbolic_shape )
248
261
else :
249
- strtype = ctypes_to_cstr (obj ._C_ctype )
262
+ strtype = self . _complex_type ( ctypes_to_cstr (obj ._C_ctype ), obj . dtype )
250
263
strshape = ''
251
264
if isinstance (obj , (AbstractFunction , IndexedData )) and mode >= 1 :
252
265
if not obj ._mem_stack :
@@ -376,10 +389,11 @@ def visit_tuple(self, o):
376
389
def visit_PointerCast (self , o ):
377
390
f = o .function
378
391
i = f .indexed
392
+ cstr = self ._complex_type (i ._C_typedata , i .dtype )
379
393
380
394
if f .is_PointerArray :
381
395
# lvalue
382
- lvalue = c .Value (i . _C_typedata , '**%s' % f .name )
396
+ lvalue = c .Value (cstr , '**%s' % f .name )
383
397
384
398
# rvalue
385
399
if isinstance (o .obj , ArrayObject ):
@@ -388,7 +402,7 @@ def visit_PointerCast(self, o):
388
402
v = f ._C_name
389
403
else :
390
404
assert False
391
- rvalue = '(%s**) %s' % (i . _C_typedata , v )
405
+ rvalue = '(%s**) %s' % (cstr , v )
392
406
393
407
else :
394
408
# lvalue
@@ -399,10 +413,10 @@ def visit_PointerCast(self, o):
399
413
if o .flat is None :
400
414
shape = '' .join ("[%s]" % ccode (i ) for i in o .castshape )
401
415
rshape = '(*)%s' % shape
402
- lvalue = c .Value (i . _C_typedata , '(*restrict %s)%s' % (v , shape ))
416
+ lvalue = c .Value (cstr , '(*restrict %s)%s' % (v , shape ))
403
417
else :
404
418
rshape = '*'
405
- lvalue = c .Value (i . _C_typedata , '*%s' % v )
419
+ lvalue = c .Value (cstr , '*%s' % v )
406
420
if o .alignment and f ._data_alignment :
407
421
lvalue = c .AlignedAttribute (f ._data_alignment , lvalue )
408
422
@@ -415,30 +429,30 @@ def visit_PointerCast(self, o):
415
429
else :
416
430
assert False
417
431
418
- rvalue = '(%s %s) %s->%s' % (i . _C_typedata , rshape , f ._C_name , v )
432
+ rvalue = '(%s %s) %s->%s' % (cstr , rshape , f ._C_name , v )
419
433
else :
420
434
if isinstance (o .obj , Pointer ):
421
435
v = o .obj .name
422
436
else :
423
437
v = f ._C_name
424
438
425
- rvalue = '(%s %s) %s' % (i . _C_typedata , rshape , v )
439
+ rvalue = '(%s %s) %s' % (cstr , rshape , v )
426
440
427
441
return c .Initializer (lvalue , rvalue )
428
442
429
443
def visit_Dereference (self , o ):
430
444
a0 , a1 = o .functions
431
445
if a1 .is_PointerArray or a1 .is_TempFunction :
432
446
i = a1 .indexed
447
+ cstr = self ._complex_type (i ._C_typedata , i .dtype )
433
448
if o .flat is None :
434
449
shape = '' .join ("[%s]" % ccode (i ) for i in a0 .symbolic_shape [1 :])
435
- rvalue = '(%s (*)%s) %s[%s]' % (i . _C_typedata , shape , a1 .name ,
450
+ rvalue = '(%s (*)%s) %s[%s]' % (cstr , shape , a1 .name ,
436
451
a1 .dim .name )
437
- lvalue = c .Value (i ._C_typedata ,
438
- '(*restrict %s)%s' % (a0 .name , shape ))
452
+ lvalue = c .Value (cstr , '(*restrict %s)%s' % (a0 .name , shape ))
439
453
else :
440
- rvalue = '(%s *) %s[%s]' % (i . _C_typedata , a1 .name , a1 .dim .name )
441
- lvalue = c .Value (i . _C_typedata , '*restrict %s' % a0 .name )
454
+ rvalue = '(%s *) %s[%s]' % (cstr , a1 .name , a1 .dim .name )
455
+ lvalue = c .Value (cstr , '*restrict %s' % a0 .name )
442
456
if a0 ._data_alignment :
443
457
lvalue = c .AlignedAttribute (a0 ._data_alignment , lvalue )
444
458
else :
0 commit comments