@@ -443,10 +443,7 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
443
443
if d in fixed :
444
444
continue
445
445
446
- name = '' .join ('r' if i is d else 'c' for i in distributor .dimensions )
447
- rpeer = FieldFromPointer (name , nb )
448
- name = '' .join ('l' if i is d else 'c' for i in distributor .dimensions )
449
- lpeer = FieldFromPointer (name , nb )
446
+ rpeer , lpeer = self ._make_peers (d , distributor , nb )
450
447
451
448
if (d , LEFT ) in hse .halos :
452
449
# Sending to left, receiving from right
@@ -491,6 +488,14 @@ def _make_basic_mapper(self, f, fixed):
491
488
492
489
return mapper
493
490
491
+ def _make_peers (self , d , distributor , nb ):
492
+ rname = '' .join ('r' if i is d else 'c' for i in distributor .dimensions )
493
+ rpeer = FieldFromPointer (rname , nb )
494
+ lname = '' .join ('l' if i is d else 'c' for i in distributor .dimensions )
495
+ lpeer = FieldFromPointer (lname , nb )
496
+
497
+ return rpeer , lpeer
498
+
494
499
def _call_haloupdate (self , name , f , hse , * args ):
495
500
comm = f .grid .distributor ._obj_comm
496
501
nb = f .grid .distributor ._obj_neighborhood
@@ -616,10 +621,7 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
616
621
if d in fixed :
617
622
continue
618
623
619
- name = '' .join ('r' if i is d else 'c' for i in distributor .dimensions )
620
- rpeer = FieldFromPointer (name , nb )
621
- name = '' .join ('l' if i is d else 'c' for i in distributor .dimensions )
622
- lpeer = FieldFromPointer (name , nb )
624
+ rpeer , lpeer = self ._make_peers (d , distributor , nb )
623
625
624
626
if (d , LEFT ) in hse .halos :
625
627
# Sending to left, receiving from right
@@ -1297,6 +1299,7 @@ def _as_number(self, v, args):
1297
1299
return int (subs_op_args (v , args ))
1298
1300
1299
1301
def _allocate_buffers (self , f , shape , entry ):
1302
+ # Allocate the send/recv buffers
1300
1303
entry .sizes = (c_int * len (shape ))(* shape )
1301
1304
size = reduce (mul , shape )* dtype_len (self .target .dtype )
1302
1305
ctype = dtype_to_ctype (f .dtype )
@@ -1429,21 +1432,12 @@ def _arg_defaults(self, allocator, alias, args=None):
1429
1432
if d in fixed :
1430
1433
continue
1431
1434
1432
- if (d , LEFT ) in self .halos :
1433
- entry = self .value [i ]
1434
- i = i + 1
1435
- # Sending to left, receiving from right
1436
- shape = mapper [(d , LEFT , OWNED )]
1437
- # Allocate the send/recv buffers
1438
- self ._allocate_buffers (f , shape , entry )
1439
-
1440
- if (d , RIGHT ) in self .halos :
1441
- entry = self .value [i ]
1442
- i = i + 1
1443
- # Sending to right, receiving from left
1444
- shape = mapper [(d , RIGHT , OWNED )]
1445
- # Allocate the send/recv buffers
1446
- self ._allocate_buffers (f , shape , entry )
1435
+ for side in (LEFT , RIGHT ):
1436
+ if (d , side ) in self .halos :
1437
+ entry = self .value [i ]
1438
+ i += 1
1439
+ shape = mapper [(d , side , OWNED )]
1440
+ self ._allocate_buffers (f , shape , entry )
1447
1441
1448
1442
return {self .name : self .value }
1449
1443
0 commit comments