Skip to content

Commit ff919e2

Browse files
committed
mpi: Add _allocate_buffers func, drop redundant determinism
1 parent c85bceb commit ff919e2

File tree

2 files changed

+22
-47
lines changed

2 files changed

+22
-47
lines changed

devito/mpi/routines.py

+22-45
Original file line numberDiff line numberDiff line change
@@ -468,9 +468,11 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
468468
return HaloUpdate('haloupdate%s' % key, iet, parameters)
469469

470470
def _make_basic_mapper(self, f, fixed):
471-
# Build a mapper `(dim, side, region) -> (size, ofs)` for `f`. `size` and
472-
# `ofs` are symbolic objects. This mapper tells what data values should be
473-
# sent (OWNED) or received (HALO) given dimension and side
471+
"""
472+
Build a mapper `(dim, side, region) -> (size, ofs)` for `f`. `size` and
473+
`ofs` are symbolic objects. This mapper tells what data values should be
474+
sent (OWNED) or received (HALO) given dimension and side
475+
"""
474476
mapper = {}
475477
for d0, side, region in product(f.dimensions, (LEFT, RIGHT), (OWNED, HALO)):
476478
if d0 in fixed:
@@ -636,13 +638,9 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
636638

637639
iet = List(body=body)
638640

639-
parameters = list(f.handles) + [comm, nb] + list(fixed.values())
640-
641-
node = HaloUpdate('haloupdate%s' % key, iet, parameters)
642-
643-
node = node._rebuild(parameters=node.parameters + (kwargs['msg'],))
641+
parameters = list(f.handles) + [comm, nb] + list(fixed.values()) + [kwargs['msg']]
644642

645-
return node
643+
return HaloUpdate('haloupdate%s' % key, iet, parameters)
646644

647645
def _call_haloupdate(self, name, f, hse, msg):
648646
call = super()._call_haloupdate(name, f, hse)
@@ -1298,6 +1296,17 @@ def _as_number(self, v, args):
12981296
assert args is not None
12991297
return int(subs_op_args(v, args))
13001298

1299+
def _allocate_buffers(self, f, shape, entry):
1300+
entry.sizes = (c_int*len(shape))(*shape)
1301+
size = reduce(mul, shape)*dtype_len(self.target.dtype)
1302+
ctype = dtype_to_ctype(f.dtype)
1303+
entry.bufg, bufg_memfree_args = self._allocator._alloc_C_libcall(size, ctype)
1304+
entry.bufs, bufs_memfree_args = self._allocator._alloc_C_libcall(size, ctype)
1305+
# The `memfree_args` will be used to deallocate the buffer upon
1306+
# returning from C-land
1307+
self._memfree_args.extend([bufg_memfree_args, bufs_memfree_args])
1308+
return
1309+
13011310
def _arg_defaults(self, allocator, alias, args=None):
13021311
# Lazy initialization if `allocator` is necessary as the `allocator`
13031312
# type isn't really known until an Operator is constructed
@@ -1315,17 +1324,9 @@ def _arg_defaults(self, allocator, alias, args=None):
13151324
except AttributeError:
13161325
assert side is CENTER
13171326
shape.append(self._as_number(f._size_domain[dim], args))
1318-
entry.sizes = (c_int*len(shape))(*shape)
13191327

13201328
# Allocate the send/recv buffers
1321-
size = reduce(mul, shape)*dtype_len(self.target.dtype)
1322-
ctype = dtype_to_ctype(f.dtype)
1323-
entry.bufg, bufg_memfree_args = allocator._alloc_C_libcall(size, ctype)
1324-
entry.bufs, bufs_memfree_args = allocator._alloc_C_libcall(size, ctype)
1325-
1326-
# The `memfree_args` will be used to deallocate the buffer upon
1327-
# returning from C-land
1328-
self._memfree_args.extend([bufg_memfree_args, bufs_memfree_args])
1329+
self._allocate_buffers(f, shape, entry)
13291330

13301331
return {self.name: self.value}
13311332

@@ -1376,17 +1377,9 @@ def _arg_defaults(self, allocator, alias, args=None):
13761377
except AttributeError:
13771378
assert side is CENTER
13781379
shape.append(self._as_number(f._size_domain[dim], args))
1379-
entry.sizes = (c_int*len(shape))(*shape)
13801380

13811381
# Allocate the send/recv buffers
1382-
size = reduce(mul, shape)*dtype_len(self.target.dtype)
1383-
ctype = dtype_to_ctype(f.dtype)
1384-
entry.bufg, bufg_memfree_args = allocator._alloc_C_libcall(size, ctype)
1385-
entry.bufs, bufs_memfree_args = allocator._alloc_C_libcall(size, ctype)
1386-
1387-
# The `memfree_args` will be used to deallocate the buffer upon
1388-
# returning from C-land
1389-
self._memfree_args.extend([bufg_memfree_args, bufs_memfree_args])
1382+
self._allocate_buffers(f, shape, entry)
13901383

13911384
return {self.name: self.value}
13921385

@@ -1444,31 +1437,15 @@ def _arg_defaults(self, allocator, alias, args=None):
14441437
# Sending to left, receiving from right
14451438
shape = mapper[(d, LEFT, OWNED)]
14461439
# Allocate the send/recv buffers
1447-
entry.sizes = (c_int*len(shape))(*shape)
1448-
size = reduce(mul, shape)*dtype_len(self.target.dtype)
1449-
ctype = dtype_to_ctype(f.dtype)
1450-
entry.bufg, bufg_memfree_args = allocator._alloc_C_libcall(size, ctype)
1451-
entry.bufs, bufs_memfree_args = allocator._alloc_C_libcall(size, ctype)
1452-
1453-
# The `memfree_args` will be used to deallocate the buffer upon
1454-
# returning from C-land
1455-
self._memfree_args.extend([bufg_memfree_args, bufs_memfree_args])
1440+
self._allocate_buffers(f, shape, entry)
14561441

14571442
if (d, RIGHT) in self.halos:
14581443
entry = self.value[i]
14591444
i = i + 1
14601445
# Sending to right, receiving from left
14611446
shape = mapper[(d, RIGHT, OWNED)]
14621447
# Allocate the send/recv buffers
1463-
entry.sizes = (c_int*len(shape))(*shape)
1464-
size = reduce(mul, shape)*dtype_len(self.target.dtype)
1465-
ctype = dtype_to_ctype(f.dtype)
1466-
entry.bufg, bufg_memfree_args = allocator._alloc_C_libcall(size, ctype)
1467-
entry.bufs, bufs_memfree_args = allocator._alloc_C_libcall(size, ctype)
1468-
1469-
# The `memfree_args` will be used to deallocate the buffer upon
1470-
# returning from C-land
1471-
self._memfree_args.extend([bufg_memfree_args, bufs_memfree_args])
1448+
self._allocate_buffers(f, shape, entry)
14721449

14731450
return {self.name: self.value}
14741451

devito/passes/iet/misc.py

-2
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,6 @@ def generate_macros(iet):
140140
# Generate Macros from higher-level SymPy objects
141141
applications = FindApplications().visit(iet)
142142
headers = set().union(*[_generate_macros(i) for i in applications])
143-
# Sort for deterministic code generation
144-
headers = sorted(headers)
145143

146144
# Some special Symbols may represent Macros defined in standard libraries,
147145
# so we need to include the respective includes

0 commit comments

Comments
 (0)