Skip to content

Commit 01ed5e2

Browse files
committed
compiler: Reduce some code after reviews
1 parent c6b6a42 commit 01ed5e2

File tree

1 file changed

+17
-23
lines changed

1 file changed

+17
-23
lines changed

devito/mpi/routines.py

+17-23
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,14 @@ def _call_sendrecv(self, name, *args, **kwargs):
429429
args = list(args[0].handles) + flatten(args[1:])
430430
return Call(name, args)
431431

432+
def _make_peers(self, d, distributor, nb):
433+
rname = ''.join('r' if i is d else 'c' for i in distributor.dimensions)
434+
rpeer = FieldFromPointer(rname, nb)
435+
lname = ''.join('l' if i is d else 'c' for i in distributor.dimensions)
436+
lpeer = FieldFromPointer(lname, nb)
437+
438+
return rpeer, lpeer
439+
432440
def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
433441
distributor = f.grid.distributor
434442
nb = distributor._obj_neighborhood
@@ -443,10 +451,7 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
443451
if d in fixed:
444452
continue
445453

446-
name = ''.join('r' if i is d else 'c' for i in distributor.dimensions)
447-
rpeer = FieldFromPointer(name, nb)
448-
name = ''.join('l' if i is d else 'c' for i in distributor.dimensions)
449-
lpeer = FieldFromPointer(name, nb)
454+
rpeer, lpeer = self._make_peers(d, distributor, nb)
450455

451456
if (d, LEFT) in hse.halos:
452457
# Sending to left, receiving from right
@@ -616,10 +621,7 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
616621
if d in fixed:
617622
continue
618623

619-
name = ''.join('r' if i is d else 'c' for i in distributor.dimensions)
620-
rpeer = FieldFromPointer(name, nb)
621-
name = ''.join('l' if i is d else 'c' for i in distributor.dimensions)
622-
lpeer = FieldFromPointer(name, nb)
624+
rpeer, lpeer = self._make_peers(d, distributor, nb)
623625

624626
if (d, LEFT) in hse.halos:
625627
# Sending to left, receiving from right
@@ -1297,6 +1299,7 @@ def _as_number(self, v, args):
12971299
return int(subs_op_args(v, args))
12981300

12991301
def _allocate_buffers(self, f, shape, entry):
1302+
# Allocate the send/recv buffers
13001303
entry.sizes = (c_int*len(shape))(*shape)
13011304
size = reduce(mul, shape)*dtype_len(self.target.dtype)
13021305
ctype = dtype_to_ctype(f.dtype)
@@ -1429,21 +1432,12 @@ def _arg_defaults(self, allocator, alias, args=None):
14291432
if d in fixed:
14301433
continue
14311434

1432-
if (d, LEFT) in self.halos:
1433-
entry = self.value[i]
1434-
i = i + 1
1435-
# Sending to left, receiving from right
1436-
shape = mapper[(d, LEFT, OWNED)]
1437-
# Allocate the send/recv buffers
1438-
self._allocate_buffers(f, shape, entry)
1439-
1440-
if (d, RIGHT) in self.halos:
1441-
entry = self.value[i]
1442-
i = i + 1
1443-
# Sending to right, receiving from left
1444-
shape = mapper[(d, RIGHT, OWNED)]
1445-
# Allocate the send/recv buffers
1446-
self._allocate_buffers(f, shape, entry)
1435+
for side in (LEFT, RIGHT):
1436+
if (d, side) in self.halos:
1437+
entry = self.value[i]
1438+
i += 1
1439+
shape = mapper[(d, side, OWNED)]
1440+
self._allocate_buffers(f, shape, entry)
14471441

14481442
return {self.name: self.value}
14491443

0 commit comments

Comments
 (0)