@@ -753,8 +753,10 @@ class Basic2HaloExchangeBuilder(BasicHaloExchangeBuilder):
753
753
"""
754
754
755
755
def _make_msg (self , f , hse , key ):
756
- # Pass the whole of hse
757
- return MPIMsgBasic ('msg%d' % key , f , hse .halos , hse )
756
+ # Pass the fixed mapper e.g. {t: otime}
757
+ fixed = {d : Symbol (name = "o%s" % d .root ) for d in hse .loc_indices }
758
+
759
+ return MPIMsgBasic2 ('msg%d' % key , f , hse .halos , fixed )
758
760
759
761
def _make_sendrecv (self , f , hse , key , msg = None ):
760
762
cast = cast_mapper [(f .c0 .dtype , '*' )]
@@ -1312,7 +1314,6 @@ def _arg_defaults(self, allocator, alias, args=None):
1312
1314
self ._allocator = allocator
1313
1315
1314
1316
f = alias or self .target .c0
1315
-
1316
1317
for i , halo in enumerate (self .halos ):
1317
1318
entry = self .value [i ]
1318
1319
@@ -1374,7 +1375,6 @@ def _arg_defaults(self, allocator, alias, args=None):
1374
1375
self ._allocator = allocator
1375
1376
1376
1377
f = alias or self .target .c0
1377
-
1378
1378
for i , halo in enumerate (self .halos ):
1379
1379
entry = self .value [i ]
1380
1380
@@ -1401,32 +1401,28 @@ def _arg_defaults(self, allocator, alias, args=None):
1401
1401
return {self .name : self .value }
1402
1402
1403
1403
1404
- class MPIMsgBasic (MPIMsgBase ):
1404
+ class MPIMsgBasic2 (MPIMsgBase ):
1405
1405
1406
- def __init__ (self , name , target , halos , hse = None ):
1406
+ def __init__ (self , name , target , halos , fixed = None ):
1407
1407
self ._target = target
1408
1408
self ._halos = halos
1409
1409
1410
1410
super ().__init__ (name , 'msg' , self .fields )
1411
1411
1412
1412
# Required for buffer allocation/deallocation before/after jumping/returning
1413
1413
# to/from C-land
1414
- self ._hse = hse
1414
+ self ._fixed = fixed
1415
1415
self ._allocator = None
1416
1416
self ._memfree_args = []
1417
1417
1418
- @property
1419
- def hse (self ):
1420
- return self ._hse
1421
-
1422
1418
def _arg_defaults (self , allocator , alias , args = None ):
1423
1419
# Lazy initialization if `allocator` is necessary as the `allocator`
1424
1420
# type isn't really known until an Operator is constructed
1425
1421
self ._allocator = allocator
1426
1422
1427
1423
f = alias or self .target .c0
1428
1424
1429
- fixed = { d : Symbol ( name = "o%s" % d . root ) for d in self .hse . loc_indices }
1425
+ fixed = self ._fixed
1430
1426
1431
1427
# Build a mapper `(dim, side, region) -> (size, ofs)` for `f`. `size` and
1432
1428
# `ofs` are symbolic objects. This mapper tells what data values should be
@@ -1440,7 +1436,6 @@ def _arg_defaults(self, allocator, alias, args=None):
1440
1436
if d1 in fixed :
1441
1437
continue
1442
1438
else :
1443
- # meta = f._C_get_field(region if d0 is d1 else NOPAD, d1, side)
1444
1439
if d0 is d1 :
1445
1440
if region is OWNED :
1446
1441
sizes .append (getattr (f ._size_owned [d0 ], side .name ))
@@ -1455,7 +1450,7 @@ def _arg_defaults(self, allocator, alias, args=None):
1455
1450
if d in fixed :
1456
1451
continue
1457
1452
1458
- if (d , LEFT ) in self .hse . halos :
1453
+ if (d , LEFT ) in self .halos :
1459
1454
entry = self .value [i ]
1460
1455
i = i + 1
1461
1456
# Sending to left, receiving from right
@@ -1471,14 +1466,13 @@ def _arg_defaults(self, allocator, alias, args=None):
1471
1466
# returning from C-land
1472
1467
self ._memfree_args .extend ([bufg_memfree_args , bufs_memfree_args ])
1473
1468
1474
- if (d , RIGHT ) in self .hse . halos :
1469
+ if (d , RIGHT ) in self .halos :
1475
1470
entry = self .value [i ]
1476
1471
i = i + 1
1477
1472
# Sending to right, receiving from left
1478
1473
shape = mapper [(d , RIGHT , OWNED )]
1479
- entry .sizes = (c_int * len (shape ))(* shape )
1480
-
1481
1474
# Allocate the send/recv buffers
1475
+ entry .sizes = (c_int * len (shape ))(* shape )
1482
1476
size = reduce (mul , shape )* dtype_len (self .target .dtype )
1483
1477
ctype = dtype_to_ctype (f .dtype )
1484
1478
entry .bufg , bufg_memfree_args = allocator ._alloc_C_libcall (size , ctype )
0 commit comments