10
10
import ctypes
11
11
12
12
import cgen as c
13
+ import numpy as np
13
14
from sympy import IndexedBase
14
15
from sympy .core .function import Application
15
16
@@ -188,6 +189,21 @@ def __init__(self, *args, compiler=None, **kwargs):
188
189
}
189
190
_restrict_keyword = 'restrict'
190
191
192
+ def _complex_type (self , ctypestr , dtype ):
193
+ # Not complex
194
+ try :
195
+ if not np .issubdtype (dtype , np .complexfloating ):
196
+ return ctypestr
197
+ except TypeError :
198
+ return ctypestr
199
+ # Complex only supported for float and double
200
+ if ctypestr not in ('float' , 'double' ):
201
+ return ctypestr
202
+ if self ._compiler ._cpp :
203
+ return 'std::complex<%s>' % ctypestr
204
+ else :
205
+ return '%s _Complex' % ctypestr
206
+
191
207
def _gen_struct_decl (self , obj , masked = ()):
192
208
"""
193
209
Convert ctypes.Struct -> cgen.Structure.
@@ -243,10 +259,10 @@ def _gen_value(self, obj, mode=1, masked=()):
243
259
if getattr (obj .function , k , False ) and v not in masked ]
244
260
245
261
if (obj ._mem_stack or obj ._mem_constant ) and mode == 1 :
246
- strtype = obj ._C_typedata
262
+ strtype = self . _complex_type ( obj ._C_typedata , obj . dtype )
247
263
strshape = '' .join ('[%s]' % ccode (i ) for i in obj .symbolic_shape )
248
264
else :
249
- strtype = ctypes_to_cstr (obj ._C_ctype )
265
+ strtype = self . _complex_type ( ctypes_to_cstr (obj ._C_ctype ), obj . dtype )
250
266
strshape = ''
251
267
if isinstance (obj , (AbstractFunction , IndexedData )) and mode >= 1 :
252
268
if not obj ._mem_stack :
@@ -376,10 +392,11 @@ def visit_tuple(self, o):
376
392
def visit_PointerCast (self , o ):
377
393
f = o .function
378
394
i = f .indexed
395
+ cstr = self ._complex_type (i ._C_typedata , i .dtype )
379
396
380
397
if f .is_PointerArray :
381
398
# lvalue
382
- lvalue = c .Value (i . _C_typedata , '**%s' % f .name )
399
+ lvalue = c .Value (cstr , '**%s' % f .name )
383
400
384
401
# rvalue
385
402
if isinstance (o .obj , ArrayObject ):
@@ -388,7 +405,7 @@ def visit_PointerCast(self, o):
388
405
v = f ._C_name
389
406
else :
390
407
assert False
391
- rvalue = '(%s**) %s' % (i . _C_typedata , v )
408
+ rvalue = '(%s**) %s' % (cstr , v )
392
409
393
410
else :
394
411
# lvalue
@@ -399,10 +416,10 @@ def visit_PointerCast(self, o):
399
416
if o .flat is None :
400
417
shape = '' .join ("[%s]" % ccode (i ) for i in o .castshape )
401
418
rshape = '(*)%s' % shape
402
- lvalue = c .Value (i . _C_typedata , '(*restrict %s)%s' % (v , shape ))
419
+ lvalue = c .Value (cstr , '(*restrict %s)%s' % (v , shape ))
403
420
else :
404
421
rshape = '*'
405
- lvalue = c .Value (i . _C_typedata , '*%s' % v )
422
+ lvalue = c .Value (cstr , '*%s' % v )
406
423
if o .alignment and f ._data_alignment :
407
424
lvalue = c .AlignedAttribute (f ._data_alignment , lvalue )
408
425
@@ -415,30 +432,30 @@ def visit_PointerCast(self, o):
415
432
else :
416
433
assert False
417
434
418
- rvalue = '(%s %s) %s->%s' % (i . _C_typedata , rshape , f ._C_name , v )
435
+ rvalue = '(%s %s) %s->%s' % (cstr , rshape , f ._C_name , v )
419
436
else :
420
437
if isinstance (o .obj , Pointer ):
421
438
v = o .obj .name
422
439
else :
423
440
v = f ._C_name
424
441
425
- rvalue = '(%s %s) %s' % (i . _C_typedata , rshape , v )
442
+ rvalue = '(%s %s) %s' % (cstr , rshape , v )
426
443
427
444
return c .Initializer (lvalue , rvalue )
428
445
429
446
def visit_Dereference (self , o ):
430
447
a0 , a1 = o .functions
431
448
if a1 .is_PointerArray or a1 .is_TempFunction :
432
449
i = a1 .indexed
450
+ cstr = self ._complex_type (i ._C_typedata , i .dtype )
433
451
if o .flat is None :
434
452
shape = '' .join ("[%s]" % ccode (i ) for i in a0 .symbolic_shape [1 :])
435
- rvalue = '(%s (*)%s) %s[%s]' % (i . _C_typedata , shape , a1 .name ,
453
+ rvalue = '(%s (*)%s) %s[%s]' % (cstr , shape , a1 .name ,
436
454
a1 .dim .name )
437
- lvalue = c .Value (i ._C_typedata ,
438
- '(*restrict %s)%s' % (a0 .name , shape ))
455
+ lvalue = c .Value (cstr , '(*restrict %s)%s' % (a0 .name , shape ))
439
456
else :
440
- rvalue = '(%s *) %s[%s]' % (i . _C_typedata , a1 .name , a1 .dim .name )
441
- lvalue = c .Value (i . _C_typedata , '*restrict %s' % a0 .name )
457
+ rvalue = '(%s *) %s[%s]' % (cstr , a1 .name , a1 .dim .name )
458
+ lvalue = c .Value (cstr , '*restrict %s' % a0 .name )
442
459
if a0 ._data_alignment :
443
460
lvalue = c .AlignedAttribute (a0 ._data_alignment , lvalue )
444
461
else :
0 commit comments