@@ -429,6 +429,14 @@ def _call_sendrecv(self, name, *args, **kwargs):
429
429
args = list (args [0 ].handles ) + flatten (args [1 :])
430
430
return Call (name , args )
431
431
432
+ def _make_peers (self , d , distributor , nb ):
433
+ rname = '' .join ('r' if i is d else 'c' for i in distributor .dimensions )
434
+ rpeer = FieldFromPointer (rname , nb )
435
+ lname = '' .join ('l' if i is d else 'c' for i in distributor .dimensions )
436
+ lpeer = FieldFromPointer (lname , nb )
437
+
438
+ return rpeer , lpeer
439
+
432
440
def _make_haloupdate (self , f , hse , key , sendrecv , ** kwargs ):
433
441
distributor = f .grid .distributor
434
442
nb = distributor ._obj_neighborhood
@@ -443,10 +451,7 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
443
451
if d in fixed :
444
452
continue
445
453
446
- name = '' .join ('r' if i is d else 'c' for i in distributor .dimensions )
447
- rpeer = FieldFromPointer (name , nb )
448
- name = '' .join ('l' if i is d else 'c' for i in distributor .dimensions )
449
- lpeer = FieldFromPointer (name , nb )
454
+ rpeer , lpeer = self ._make_peers (d , distributor , nb )
450
455
451
456
if (d , LEFT ) in hse .halos :
452
457
# Sending to left, receiving from right
@@ -616,10 +621,7 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
616
621
if d in fixed :
617
622
continue
618
623
619
- name = '' .join ('r' if i is d else 'c' for i in distributor .dimensions )
620
- rpeer = FieldFromPointer (name , nb )
621
- name = '' .join ('l' if i is d else 'c' for i in distributor .dimensions )
622
- lpeer = FieldFromPointer (name , nb )
624
+ rpeer , lpeer = self ._make_peers (d , distributor , nb )
623
625
624
626
if (d , LEFT ) in hse .halos :
625
627
# Sending to left, receiving from right
@@ -1297,6 +1299,7 @@ def _as_number(self, v, args):
1297
1299
return int (subs_op_args (v , args ))
1298
1300
1299
1301
def _allocate_buffers (self , f , shape , entry ):
1302
+ # Allocate the send/recv buffers
1300
1303
entry .sizes = (c_int * len (shape ))(* shape )
1301
1304
size = reduce (mul , shape )* dtype_len (self .target .dtype )
1302
1305
ctype = dtype_to_ctype (f .dtype )
@@ -1429,21 +1432,12 @@ def _arg_defaults(self, allocator, alias, args=None):
1429
1432
if d in fixed :
1430
1433
continue
1431
1434
1432
- if (d , LEFT ) in self .halos :
1433
- entry = self .value [i ]
1434
- i = i + 1
1435
- # Sending to left, receiving from right
1436
- shape = mapper [(d , LEFT , OWNED )]
1437
- # Allocate the send/recv buffers
1438
- self ._allocate_buffers (f , shape , entry )
1439
-
1440
- if (d , RIGHT ) in self .halos :
1441
- entry = self .value [i ]
1442
- i = i + 1
1443
- # Sending to right, receiving from left
1444
- shape = mapper [(d , RIGHT , OWNED )]
1445
- # Allocate the send/recv buffers
1446
- self ._allocate_buffers (f , shape , entry )
1435
+ for side in (LEFT , RIGHT ):
1436
+ if (d , side ) in self .halos :
1437
+ entry = self .value [i ]
1438
+ i += 1
1439
+ shape = mapper [(d , side , OWNED )]
1440
+ self ._allocate_buffers (f , shape , entry )
1447
1441
1448
1442
return {self .name : self .value }
1449
1443
0 commit comments