Skip to content

Commit b4429a4

Browse files
committed
mpi: Drop redundant code, come cleanup
1 parent a62ccd9 commit b4429a4

File tree

2 files changed

+11
-23
lines changed

2 files changed

+11
-23
lines changed

devito/mpi/routines.py

+11-17
Original file line numberDiff line numberDiff line change
@@ -753,8 +753,10 @@ class Basic2HaloExchangeBuilder(BasicHaloExchangeBuilder):
753753
"""
754754

755755
def _make_msg(self, f, hse, key):
756-
# Pass the whole of hse
757-
return MPIMsgBasic('msg%d' % key, f, hse.halos, hse)
756+
# Pass the fixed mapper e.g. {t: otime}
757+
fixed = {d: Symbol(name="o%s" % d.root) for d in hse.loc_indices}
758+
759+
return MPIMsgBasic2('msg%d' % key, f, hse.halos, fixed)
758760

759761
def _make_sendrecv(self, f, hse, key, msg=None):
760762
cast = cast_mapper[(f.c0.dtype, '*')]
@@ -1312,7 +1314,6 @@ def _arg_defaults(self, allocator, alias, args=None):
13121314
self._allocator = allocator
13131315

13141316
f = alias or self.target.c0
1315-
13161317
for i, halo in enumerate(self.halos):
13171318
entry = self.value[i]
13181319

@@ -1374,7 +1375,6 @@ def _arg_defaults(self, allocator, alias, args=None):
13741375
self._allocator = allocator
13751376

13761377
f = alias or self.target.c0
1377-
13781378
for i, halo in enumerate(self.halos):
13791379
entry = self.value[i]
13801380

@@ -1401,32 +1401,28 @@ def _arg_defaults(self, allocator, alias, args=None):
14011401
return {self.name: self.value}
14021402

14031403

1404-
class MPIMsgBasic(MPIMsgBase):
1404+
class MPIMsgBasic2(MPIMsgBase):
14051405

1406-
def __init__(self, name, target, halos, hse=None):
1406+
def __init__(self, name, target, halos, fixed=None):
14071407
self._target = target
14081408
self._halos = halos
14091409

14101410
super().__init__(name, 'msg', self.fields)
14111411

14121412
# Required for buffer allocation/deallocation before/after jumping/returning
14131413
# to/from C-land
1414-
self._hse = hse
1414+
self._fixed = fixed
14151415
self._allocator = None
14161416
self._memfree_args = []
14171417

1418-
@property
1419-
def hse(self):
1420-
return self._hse
1421-
14221418
def _arg_defaults(self, allocator, alias, args=None):
14231419
# Lazy initialization if `allocator` is necessary as the `allocator`
14241420
# type isn't really known until an Operator is constructed
14251421
self._allocator = allocator
14261422

14271423
f = alias or self.target.c0
14281424

1429-
fixed = {d: Symbol(name="o%s" % d.root) for d in self.hse.loc_indices}
1425+
fixed = self._fixed
14301426

14311427
# Build a mapper `(dim, side, region) -> (size, ofs)` for `f`. `size` and
14321428
# `ofs` are symbolic objects. This mapper tells what data values should be
@@ -1440,7 +1436,6 @@ def _arg_defaults(self, allocator, alias, args=None):
14401436
if d1 in fixed:
14411437
continue
14421438
else:
1443-
# meta = f._C_get_field(region if d0 is d1 else NOPAD, d1, side)
14441439
if d0 is d1:
14451440
if region is OWNED:
14461441
sizes.append(getattr(f._size_owned[d0], side.name))
@@ -1455,7 +1450,7 @@ def _arg_defaults(self, allocator, alias, args=None):
14551450
if d in fixed:
14561451
continue
14571452

1458-
if (d, LEFT) in self.hse.halos:
1453+
if (d, LEFT) in self.halos:
14591454
entry = self.value[i]
14601455
i = i + 1
14611456
# Sending to left, receiving from right
@@ -1471,14 +1466,13 @@ def _arg_defaults(self, allocator, alias, args=None):
14711466
# returning from C-land
14721467
self._memfree_args.extend([bufg_memfree_args, bufs_memfree_args])
14731468

1474-
if (d, RIGHT) in self.hse.halos:
1469+
if (d, RIGHT) in self.halos:
14751470
entry = self.value[i]
14761471
i = i + 1
14771472
# Sending to right, receiving from left
14781473
shape = mapper[(d, RIGHT, OWNED)]
1479-
entry.sizes = (c_int*len(shape))(*shape)
1480-
14811474
# Allocate the send/recv buffers
1475+
entry.sizes = (c_int*len(shape))(*shape)
14821476
size = reduce(mul, shape)*dtype_len(self.target.dtype)
14831477
ctype = dtype_to_ctype(f.dtype)
14841478
entry.bufg, bufg_memfree_args = allocator._alloc_C_libcall(size, ctype)

tests/test_mpi.py

-6
Original file line numberDiff line numberDiff line change
@@ -2647,17 +2647,11 @@ def run_adjoint_F(self, nd):
26472647
assert np.isclose(norm(u) / Eu, 1.0)
26482648
assert np.isclose(norm(rec) / Erec, 1.0)
26492649

2650-
print(norm(rec))
2651-
print("Erec is:", Erec)
2652-
2653-
print("----------------------------==============----------------------")
26542650
# Run adjoint operator
26552651
srca, v, _ = solver.adjoint(rec=rec)
26562652
assert np.isclose(norm(v) / Ev, 1.0)
26572653
assert np.isclose(norm(srca) / Esrca, 1.0)
26582654

2659-
print("----------------------------==============----------------------")
2660-
26612655
# Adjoint test: Verify <Ax,y> matches <x, A^Ty> closely
26622656
term1 = inner(srca, solver.geometry.src)
26632657
term2 = norm(rec)**2

0 commit comments

Comments
 (0)