Skip to content

Commit 68631a5

Browse files
committedFeb 25, 2025··
compiler: Reduce some code after reviews
1 parent 6950fa0 commit 68631a5

File tree

3 files changed

+21
-27
lines changed

3 files changed

+21
-27
lines changed
 

‎benchmarks/user/README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,8 @@ and run with `mpirun -n number_of_processes python benchmark.py ...`
9595

9696
Devito supports multiple MPI schemes for halo exchange.
9797

98-
* Devito's most prevalent MPI modes are three: `basic`, `diag2` and `full`.
99-
and are respectively activated e.g., via `DEVITO_MPI=basic`.
98+
* Devito's most prevalent MPI modes are three: `basic2`, `diag2` and `full`.
99+
and are respectively activated e.g., via `DEVITO_MPI=basic2`.
100100
These modes may perform better under different factors such as arithmetic intensity,
101101
or number of fields used in the computation.
102102

‎devito/mpi/routines.py

+18-24
Original file line numberDiff line numberDiff line change
@@ -443,10 +443,7 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
443443
if d in fixed:
444444
continue
445445

446-
name = ''.join('r' if i is d else 'c' for i in distributor.dimensions)
447-
rpeer = FieldFromPointer(name, nb)
448-
name = ''.join('l' if i is d else 'c' for i in distributor.dimensions)
449-
lpeer = FieldFromPointer(name, nb)
446+
rpeer, lpeer = self._make_peers(d, distributor, nb)
450447

451448
if (d, LEFT) in hse.halos:
452449
# Sending to left, receiving from right
@@ -491,6 +488,14 @@ def _make_basic_mapper(self, f, fixed):
491488

492489
return mapper
493490

491+
def _make_peers(self, d, distributor, nb):
492+
rname = ''.join('r' if i is d else 'c' for i in distributor.dimensions)
493+
rpeer = FieldFromPointer(rname, nb)
494+
lname = ''.join('l' if i is d else 'c' for i in distributor.dimensions)
495+
lpeer = FieldFromPointer(lname, nb)
496+
497+
return rpeer, lpeer
498+
494499
def _call_haloupdate(self, name, f, hse, *args):
495500
comm = f.grid.distributor._obj_comm
496501
nb = f.grid.distributor._obj_neighborhood
@@ -537,7 +542,7 @@ def _make_body(self, callcompute, remainder, haloupdates, halowaits):
537542
class Basic2HaloExchangeBuilder(BasicHaloExchangeBuilder):
538543

539544
"""
540-
A BasicHaloExchangeBuilder making use of pre-allocated buffers for
545+
A BasicHaloExchangeBuilder using pre-allocated buffers for
541546
message size.
542547
543548
Generates:
@@ -616,10 +621,7 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
616621
if d in fixed:
617622
continue
618623

619-
name = ''.join('r' if i is d else 'c' for i in distributor.dimensions)
620-
rpeer = FieldFromPointer(name, nb)
621-
name = ''.join('l' if i is d else 'c' for i in distributor.dimensions)
622-
lpeer = FieldFromPointer(name, nb)
624+
rpeer, lpeer = self._make_peers(d, distributor, nb)
623625

624626
if (d, LEFT) in hse.halos:
625627
# Sending to left, receiving from right
@@ -1297,6 +1299,7 @@ def _as_number(self, v, args):
12971299
return int(subs_op_args(v, args))
12981300

12991301
def _allocate_buffers(self, f, shape, entry):
1302+
# Allocate the send/recv buffers
13001303
entry.sizes = (c_int*len(shape))(*shape)
13011304
size = reduce(mul, shape)*dtype_len(self.target.dtype)
13021305
ctype = dtype_to_ctype(f.dtype)
@@ -1429,21 +1432,12 @@ def _arg_defaults(self, allocator, alias, args=None):
14291432
if d in fixed:
14301433
continue
14311434

1432-
if (d, LEFT) in self.halos:
1433-
entry = self.value[i]
1434-
i = i + 1
1435-
# Sending to left, receiving from right
1436-
shape = mapper[(d, LEFT, OWNED)]
1437-
# Allocate the send/recv buffers
1438-
self._allocate_buffers(f, shape, entry)
1439-
1440-
if (d, RIGHT) in self.halos:
1441-
entry = self.value[i]
1442-
i = i + 1
1443-
# Sending to right, receiving from left
1444-
shape = mapper[(d, RIGHT, OWNED)]
1445-
# Allocate the send/recv buffers
1446-
self._allocate_buffers(f, shape, entry)
1435+
for side in (LEFT, RIGHT):
1436+
if (d, side) in self.halos:
1437+
entry = self.value[i]
1438+
i += 1
1439+
shape = mapper[(d, side, OWNED)]
1440+
self._allocate_buffers(f, shape, entry)
14471441

14481442
return {self.name: self.value}
14491443

‎tests/test_mpi.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2289,7 +2289,7 @@ def test_haloupdate_issue_1613(self, mode):
22892289
assert dims[0].is_Modulo
22902290
assert dims[0].origin is t
22912291

2292-
@pytest.mark.parallel(mode=[(4, 'basic'), (4, 'diag2'), (4, 'overlap2')])
2292+
@pytest.mark.parallel(mode=[(4, 'basic2'), (4, 'diag2'), (4, 'overlap2')])
22932293
def test_cire(self, mode):
22942294
"""
22952295
Check correctness when the DSE extracts aliases and places them

0 commit comments

Comments
 (0)
Please sign in to comment.