@@ -468,9 +468,11 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
468
468
return HaloUpdate ('haloupdate%s' % key , iet , parameters )
469
469
470
470
def _make_basic_mapper (self , f , fixed ):
471
- # Build a mapper `(dim, side, region) -> (size, ofs)` for `f`. `size` and
472
- # `ofs` are symbolic objects. This mapper tells what data values should be
473
- # sent (OWNED) or received (HALO) given dimension and side
471
+ """
472
+ Build a mapper `(dim, side, region) -> (size, ofs)` for `f`. `size` and
473
+ `ofs` are symbolic objects. This mapper tells what data values should be
474
+ sent (OWNED) or received (HALO) given dimension and side
475
+ """
474
476
mapper = {}
475
477
for d0 , side , region in product (f .dimensions , (LEFT , RIGHT ), (OWNED , HALO )):
476
478
if d0 in fixed :
@@ -636,13 +638,9 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
636
638
637
639
iet = List (body = body )
638
640
639
- parameters = list (f .handles ) + [comm , nb ] + list (fixed .values ())
640
-
641
- node = HaloUpdate ('haloupdate%s' % key , iet , parameters )
642
-
643
- node = node ._rebuild (parameters = node .parameters + (kwargs ['msg' ],))
641
+ parameters = list (f .handles ) + [comm , nb ] + list (fixed .values ()) + [kwargs ['msg' ]]
644
642
645
- return node
643
+ return HaloUpdate ( 'haloupdate%s' % key , iet , parameters )
646
644
647
645
def _call_haloupdate (self , name , f , hse , msg ):
648
646
call = super ()._call_haloupdate (name , f , hse )
@@ -1298,6 +1296,17 @@ def _as_number(self, v, args):
1298
1296
assert args is not None
1299
1297
return int (subs_op_args (v , args ))
1300
1298
1299
+ def _allocate_buffers (self , f , shape , entry ):
1300
+ entry .sizes = (c_int * len (shape ))(* shape )
1301
+ size = reduce (mul , shape )* dtype_len (self .target .dtype )
1302
+ ctype = dtype_to_ctype (f .dtype )
1303
+ entry .bufg , bufg_memfree_args = self ._allocator ._alloc_C_libcall (size , ctype )
1304
+ entry .bufs , bufs_memfree_args = self ._allocator ._alloc_C_libcall (size , ctype )
1305
+ # The `memfree_args` will be used to deallocate the buffer upon
1306
+ # returning from C-land
1307
+ self ._memfree_args .extend ([bufg_memfree_args , bufs_memfree_args ])
1308
+ return
1309
+
1301
1310
def _arg_defaults (self , allocator , alias , args = None ):
1302
1311
# Lazy initialization if `allocator` is necessary as the `allocator`
1303
1312
# type isn't really known until an Operator is constructed
@@ -1315,17 +1324,9 @@ def _arg_defaults(self, allocator, alias, args=None):
1315
1324
except AttributeError :
1316
1325
assert side is CENTER
1317
1326
shape .append (self ._as_number (f ._size_domain [dim ], args ))
1318
- entry .sizes = (c_int * len (shape ))(* shape )
1319
1327
1320
1328
# Allocate the send/recv buffers
1321
- size = reduce (mul , shape )* dtype_len (self .target .dtype )
1322
- ctype = dtype_to_ctype (f .dtype )
1323
- entry .bufg , bufg_memfree_args = allocator ._alloc_C_libcall (size , ctype )
1324
- entry .bufs , bufs_memfree_args = allocator ._alloc_C_libcall (size , ctype )
1325
-
1326
- # The `memfree_args` will be used to deallocate the buffer upon
1327
- # returning from C-land
1328
- self ._memfree_args .extend ([bufg_memfree_args , bufs_memfree_args ])
1329
+ self ._allocate_buffers (f , shape , entry )
1329
1330
1330
1331
return {self .name : self .value }
1331
1332
@@ -1376,17 +1377,9 @@ def _arg_defaults(self, allocator, alias, args=None):
1376
1377
except AttributeError :
1377
1378
assert side is CENTER
1378
1379
shape .append (self ._as_number (f ._size_domain [dim ], args ))
1379
- entry .sizes = (c_int * len (shape ))(* shape )
1380
1380
1381
1381
# Allocate the send/recv buffers
1382
- size = reduce (mul , shape )* dtype_len (self .target .dtype )
1383
- ctype = dtype_to_ctype (f .dtype )
1384
- entry .bufg , bufg_memfree_args = allocator ._alloc_C_libcall (size , ctype )
1385
- entry .bufs , bufs_memfree_args = allocator ._alloc_C_libcall (size , ctype )
1386
-
1387
- # The `memfree_args` will be used to deallocate the buffer upon
1388
- # returning from C-land
1389
- self ._memfree_args .extend ([bufg_memfree_args , bufs_memfree_args ])
1382
+ self ._allocate_buffers (f , shape , entry )
1390
1383
1391
1384
return {self .name : self .value }
1392
1385
@@ -1444,31 +1437,15 @@ def _arg_defaults(self, allocator, alias, args=None):
1444
1437
# Sending to left, receiving from right
1445
1438
shape = mapper [(d , LEFT , OWNED )]
1446
1439
# Allocate the send/recv buffers
1447
- entry .sizes = (c_int * len (shape ))(* shape )
1448
- size = reduce (mul , shape )* dtype_len (self .target .dtype )
1449
- ctype = dtype_to_ctype (f .dtype )
1450
- entry .bufg , bufg_memfree_args = allocator ._alloc_C_libcall (size , ctype )
1451
- entry .bufs , bufs_memfree_args = allocator ._alloc_C_libcall (size , ctype )
1452
-
1453
- # The `memfree_args` will be used to deallocate the buffer upon
1454
- # returning from C-land
1455
- self ._memfree_args .extend ([bufg_memfree_args , bufs_memfree_args ])
1440
+ self ._allocate_buffers (f , shape , entry )
1456
1441
1457
1442
if (d , RIGHT ) in self .halos :
1458
1443
entry = self .value [i ]
1459
1444
i = i + 1
1460
1445
# Sending to right, receiving from left
1461
1446
shape = mapper [(d , RIGHT , OWNED )]
1462
1447
# Allocate the send/recv buffers
1463
- entry .sizes = (c_int * len (shape ))(* shape )
1464
- size = reduce (mul , shape )* dtype_len (self .target .dtype )
1465
- ctype = dtype_to_ctype (f .dtype )
1466
- entry .bufg , bufg_memfree_args = allocator ._alloc_C_libcall (size , ctype )
1467
- entry .bufs , bufs_memfree_args = allocator ._alloc_C_libcall (size , ctype )
1468
-
1469
- # The `memfree_args` will be used to deallocate the buffer upon
1470
- # returning from C-land
1471
- self ._memfree_args .extend ([bufg_memfree_args , bufs_memfree_args ])
1448
+ self ._allocate_buffers (f , shape , entry )
1472
1449
1473
1450
return {self .name : self .value }
1474
1451
0 commit comments