Skip to content

Commit 89387c8

Browse files
committed
mpi: add make_basic_mapper
1 parent 3b30632 commit 89387c8

File tree

1 file changed

+142
-155
lines changed

1 file changed

+142
-155
lines changed

devito/mpi/routines.py

+142-155
Original file line numberDiff line numberDiff line change
@@ -435,23 +435,7 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
435435

436436
fixed = {d: Symbol(name="o%s" % d.root) for d in hse.loc_indices}
437437

438-
# Build a mapper `(dim, side, region) -> (size, ofs)` for `f`. `size` and
439-
# `ofs` are symbolic objects. This mapper tells what data values should be
440-
# sent (OWNED) or received (HALO) given dimension and side
441-
mapper = {}
442-
for d0, side, region in product(f.dimensions, (LEFT, RIGHT), (OWNED, HALO)):
443-
if d0 in fixed:
444-
continue
445-
sizes = []
446-
ofs = []
447-
for d1 in f.dimensions:
448-
if d1 in fixed:
449-
ofs.append(fixed[d1])
450-
else:
451-
meta = f._C_get_field(region if d0 is d1 else NOPAD, d1, side)
452-
ofs.append(meta.offset)
453-
sizes.append(meta.size)
454-
mapper[(d0, side, region)] = (sizes, ofs)
438+
mapper = self._make_basic_mapper(f, fixed)
455439

456440
body = []
457441
for d in f.dimensions:
@@ -483,6 +467,27 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
483467

484468
return HaloUpdate('haloupdate%s' % key, iet, parameters)
485469

470+
def _make_basic_mapper(self, f, fixed):
471+
# Build a mapper `(dim, side, region) -> (size, ofs)` for `f`. `size` and
472+
# `ofs` are symbolic objects. This mapper tells what data values should be
473+
# sent (OWNED) or received (HALO) given dimension and side
474+
mapper = {}
475+
for d0, side, region in product(f.dimensions, (LEFT, RIGHT), (OWNED, HALO)):
476+
if d0 in fixed:
477+
continue
478+
sizes = []
479+
ofs = []
480+
for d1 in f.dimensions:
481+
if d1 in fixed:
482+
ofs.append(fixed[d1])
483+
else:
484+
meta = f._C_get_field(region if d0 is d1 else NOPAD, d1, side)
485+
ofs.append(meta.offset)
486+
sizes.append(meta.size)
487+
mapper[(d0, side, region)] = (sizes, ofs)
488+
489+
return mapper
490+
486491
def _call_haloupdate(self, name, f, hse, *args):
487492
comm = f.grid.distributor._obj_comm
488493
nb = f.grid.distributor._obj_neighborhood
@@ -526,6 +531,125 @@ def _make_body(self, callcompute, remainder, haloupdates, halowaits):
526531
return List(body=body)
527532

528533

534+
class Basic2HaloExchangeBuilder(BasicHaloExchangeBuilder):
535+
536+
"""
537+
A BasicHaloExchangeBuilder making use of pre-allocated buffers for
538+
message size.
539+
540+
Generates:
541+
542+
haloupdate()
543+
compute()
544+
"""
545+
546+
def _make_msg(self, f, hse, key):
547+
# Pass the fixed mapper e.g. {t: otime}
548+
fixed = {d: Symbol(name="o%s" % d.root) for d in hse.loc_indices}
549+
550+
return MPIMsgBasic2('msg%d' % key, f, hse.halos, fixed)
551+
552+
def _make_sendrecv(self, f, hse, key, msg=None):
553+
cast = cast_mapper[(f.c0.dtype, '*')]
554+
comm = f.grid.distributor._obj_comm
555+
556+
bufg = FieldFromPointer(msg._C_field_bufg, msg)
557+
bufs = FieldFromPointer(msg._C_field_bufs, msg)
558+
559+
ofsg = [Symbol(name='og%s' % d.root) for d in f.dimensions]
560+
ofss = [Symbol(name='os%s' % d.root) for d in f.dimensions]
561+
562+
fromrank = Symbol(name='fromrank')
563+
torank = Symbol(name='torank')
564+
565+
sizes = [FieldFromPointer('%s[%d]' % (msg._C_field_sizes, i), msg)
566+
for i in range(len(f._dist_dimensions))]
567+
568+
arguments = [cast(bufg)] + sizes + list(f.handles) + ofsg
569+
gather = Gather('gather%s' % key, arguments)
570+
# The `gather` is unnecessary if sending to MPI.PROC_NULL
571+
gather = Conditional(CondNe(torank, Macro('MPI_PROC_NULL')), gather)
572+
573+
arguments = [cast(bufs)] + sizes + list(f.handles) + ofss
574+
scatter = Scatter('scatter%s' % key, arguments)
575+
# The `scatter` must be guarded as we must not alter the halo values along
576+
# the domain boundary, where the sender is actually MPI.PROC_NULL
577+
scatter = Conditional(CondNe(fromrank, Macro('MPI_PROC_NULL')), scatter)
578+
579+
count = reduce(mul, sizes, 1)*dtype_len(f.dtype)
580+
rrecv = Byref(FieldFromPointer(msg._C_field_rrecv, msg))
581+
rsend = Byref(FieldFromPointer(msg._C_field_rsend, msg))
582+
recv = IrecvCall([bufs, count, Macro(dtype_to_mpitype(f.dtype)),
583+
fromrank, Integer(13), comm, rrecv])
584+
send = IsendCall([bufg, count, Macro(dtype_to_mpitype(f.dtype)),
585+
torank, Integer(13), comm, rsend])
586+
587+
waitrecv = Call('MPI_Wait', [rrecv, Macro('MPI_STATUS_IGNORE')])
588+
waitsend = Call('MPI_Wait', [rsend, Macro('MPI_STATUS_IGNORE')])
589+
590+
iet = List(body=[recv, gather, send, waitsend, waitrecv, scatter])
591+
592+
parameters = (list(f.handles) + ofsg + ofss + [fromrank, torank, comm, msg])
593+
594+
return SendRecv('sendrecv%s' % key, iet, parameters, bufg, bufs)
595+
596+
def _call_sendrecv(self, name, *args, msg=None, haloid=None):
597+
# Drop `sizes` as this HaloExchangeBuilder conveys them through `msg`
598+
f, _, ofsg, ofss, fromrank, torank, comm = args
599+
msg = Byref(IndexedPointer(msg, haloid))
600+
return Call(name, list(f.handles) + ofsg + ofss + [fromrank, torank, comm, msg])
601+
602+
def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
603+
distributor = f.grid.distributor
604+
nb = distributor._obj_neighborhood
605+
comm = distributor._obj_comm
606+
607+
fixed = {d: Symbol(name="o%s" % d.root) for d in hse.loc_indices}
608+
609+
mapper = self._make_basic_mapper(f, fixed)
610+
611+
body = []
612+
for d in f.dimensions:
613+
if d in fixed:
614+
continue
615+
616+
name = ''.join('r' if i is d else 'c' for i in distributor.dimensions)
617+
rpeer = FieldFromPointer(name, nb)
618+
name = ''.join('l' if i is d else 'c' for i in distributor.dimensions)
619+
lpeer = FieldFromPointer(name, nb)
620+
621+
if (d, LEFT) in hse.halos:
622+
# Sending to left, receiving from right
623+
lsizes, lofs = mapper[(d, LEFT, OWNED)]
624+
rsizes, rofs = mapper[(d, RIGHT, HALO)]
625+
args = [f, lsizes, lofs, rofs, rpeer, lpeer, comm]
626+
kwargs['haloid'] = len(body)
627+
body.append(self._call_sendrecv(sendrecv.name, *args, **kwargs))
628+
629+
if (d, RIGHT) in hse.halos:
630+
# Sending to right, receiving from left
631+
rsizes, rofs = mapper[(d, RIGHT, OWNED)]
632+
lsizes, lofs = mapper[(d, LEFT, HALO)]
633+
args = [f, rsizes, rofs, lofs, lpeer, rpeer, comm]
634+
kwargs['haloid'] = len(body)
635+
body.append(self._call_sendrecv(sendrecv.name, *args, **kwargs))
636+
637+
iet = List(body=body)
638+
639+
parameters = list(f.handles) + [comm, nb] + list(fixed.values())
640+
641+
node = HaloUpdate('haloupdate%s' % key, iet, parameters)
642+
643+
node = node._rebuild(parameters=node.parameters + (kwargs['msg'],))
644+
645+
return node
646+
647+
def _call_haloupdate(self, name, f, hse, msg):
648+
call = super()._call_haloupdate(name, f, hse)
649+
call = call._rebuild(arguments=call.arguments + (msg,))
650+
return call
651+
652+
529653
class DiagHaloExchangeBuilder(BasicHaloExchangeBuilder):
530654

531655
"""
@@ -741,141 +865,6 @@ def _call_remainder(self, remainder):
741865
return call
742866

743867

744-
class Basic2HaloExchangeBuilder(BasicHaloExchangeBuilder):
745-
746-
"""
747-
A BasicHaloExchangeBuilder making use of pre-allocated buffers for
748-
message size.
749-
750-
Generates:
751-
752-
haloupdate()
753-
compute()
754-
"""
755-
756-
def _make_msg(self, f, hse, key):
757-
# Pass the fixed mapper e.g. {t: otime}
758-
fixed = {d: Symbol(name="o%s" % d.root) for d in hse.loc_indices}
759-
760-
return MPIMsgBasic2('msg%d' % key, f, hse.halos, fixed)
761-
762-
def _make_sendrecv(self, f, hse, key, msg=None):
763-
cast = cast_mapper[(f.c0.dtype, '*')]
764-
comm = f.grid.distributor._obj_comm
765-
766-
bufg = FieldFromPointer(msg._C_field_bufg, msg)
767-
bufs = FieldFromPointer(msg._C_field_bufs, msg)
768-
769-
ofsg = [Symbol(name='og%s' % d.root) for d in f.dimensions]
770-
ofss = [Symbol(name='os%s' % d.root) for d in f.dimensions]
771-
772-
fromrank = Symbol(name='fromrank')
773-
torank = Symbol(name='torank')
774-
775-
sizes = [FieldFromPointer('%s[%d]' % (msg._C_field_sizes, i), msg)
776-
for i in range(len(f._dist_dimensions))]
777-
778-
arguments = [cast(bufg)] + sizes + list(f.handles) + ofsg
779-
gather = Gather('gather%s' % key, arguments)
780-
# The `gather` is unnecessary if sending to MPI.PROC_NULL
781-
gather = Conditional(CondNe(torank, Macro('MPI_PROC_NULL')), gather)
782-
783-
arguments = [cast(bufs)] + sizes + list(f.handles) + ofss
784-
scatter = Scatter('scatter%s' % key, arguments)
785-
# The `scatter` must be guarded as we must not alter the halo values along
786-
# the domain boundary, where the sender is actually MPI.PROC_NULL
787-
scatter = Conditional(CondNe(fromrank, Macro('MPI_PROC_NULL')), scatter)
788-
789-
count = reduce(mul, sizes, 1)*dtype_len(f.dtype)
790-
rrecv = Byref(FieldFromPointer(msg._C_field_rrecv, msg))
791-
rsend = Byref(FieldFromPointer(msg._C_field_rsend, msg))
792-
recv = IrecvCall([bufs, count, Macro(dtype_to_mpitype(f.dtype)),
793-
fromrank, Integer(13), comm, rrecv])
794-
send = IsendCall([bufg, count, Macro(dtype_to_mpitype(f.dtype)),
795-
torank, Integer(13), comm, rsend])
796-
797-
waitrecv = Call('MPI_Wait', [rrecv, Macro('MPI_STATUS_IGNORE')])
798-
waitsend = Call('MPI_Wait', [rsend, Macro('MPI_STATUS_IGNORE')])
799-
800-
iet = List(body=[recv, gather, send, waitsend, waitrecv, scatter])
801-
802-
parameters = (list(f.handles) + ofsg + ofss + [fromrank, torank, comm, msg])
803-
804-
return SendRecv('sendrecv%s' % key, iet, parameters, bufg, bufs)
805-
806-
def _call_sendrecv(self, name, *args, msg=None, haloid=None):
807-
# Drop `sizes` as this HaloExchangeBuilder conveys them through `msg`
808-
f, _, ofsg, ofss, fromrank, torank, comm = args
809-
msg = Byref(IndexedPointer(msg, haloid))
810-
return Call(name, list(f.handles) + ofsg + ofss + [fromrank, torank, comm, msg])
811-
812-
def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
813-
distributor = f.grid.distributor
814-
nb = distributor._obj_neighborhood
815-
comm = distributor._obj_comm
816-
817-
fixed = {d: Symbol(name="o%s" % d.root) for d in hse.loc_indices}
818-
819-
# Build a mapper `(dim, side, region) -> (size, ofs)` for `f`. `size` and
820-
# `ofs` are symbolic objects. This mapper tells what data values should be
821-
# sent (OWNED) or received (HALO) given dimension and side
822-
mapper = {}
823-
for d0, side, region in product(f.dimensions, (LEFT, RIGHT), (OWNED, HALO)):
824-
if d0 in fixed:
825-
continue
826-
sizes = []
827-
ofs = []
828-
for d1 in f.dimensions:
829-
if d1 in fixed:
830-
ofs.append(fixed[d1])
831-
else:
832-
meta = f._C_get_field(region if d0 is d1 else NOPAD, d1, side)
833-
ofs.append(meta.offset)
834-
sizes.append(meta.size)
835-
mapper[(d0, side, region)] = (sizes, ofs)
836-
837-
body = []
838-
for d in f.dimensions:
839-
if d in fixed:
840-
continue
841-
842-
name = ''.join('r' if i is d else 'c' for i in distributor.dimensions)
843-
rpeer = FieldFromPointer(name, nb)
844-
name = ''.join('l' if i is d else 'c' for i in distributor.dimensions)
845-
lpeer = FieldFromPointer(name, nb)
846-
847-
if (d, LEFT) in hse.halos:
848-
# Sending to left, receiving from right
849-
lsizes, lofs = mapper[(d, LEFT, OWNED)]
850-
rsizes, rofs = mapper[(d, RIGHT, HALO)]
851-
args = [f, lsizes, lofs, rofs, rpeer, lpeer, comm]
852-
kwargs['haloid'] = len(body)
853-
body.append(self._call_sendrecv(sendrecv.name, *args, **kwargs))
854-
855-
if (d, RIGHT) in hse.halos:
856-
# Sending to right, receiving from left
857-
rsizes, rofs = mapper[(d, RIGHT, OWNED)]
858-
lsizes, lofs = mapper[(d, LEFT, HALO)]
859-
args = [f, rsizes, rofs, lofs, lpeer, rpeer, comm]
860-
kwargs['haloid'] = len(body)
861-
body.append(self._call_sendrecv(sendrecv.name, *args, **kwargs))
862-
863-
iet = List(body=body)
864-
865-
parameters = list(f.handles) + [comm, nb] + list(fixed.values())
866-
867-
node = HaloUpdate('haloupdate%s' % key, iet, parameters)
868-
869-
node = node._rebuild(parameters=node.parameters + (kwargs['msg'],))
870-
871-
return node
872-
873-
def _call_haloupdate(self, name, f, hse, msg):
874-
call = super()._call_haloupdate(name, f, hse)
875-
call = call._rebuild(arguments=call.arguments + (msg,))
876-
return call
877-
878-
879868
class Overlap2HaloExchangeBuilder(OverlapHaloExchangeBuilder):
880869

881870
"""
@@ -1425,9 +1414,7 @@ def _arg_defaults(self, allocator, alias, args=None):
14251414

14261415
fixed = self._fixed
14271416

1428-
# Build a mapper `(dim, side, region) -> (size, ofs)` for `f`. `size` and
1429-
# `ofs` are symbolic objects. This mapper tells what data values should be
1430-
# sent (OWNED) or received (HALO) given dimension and side
1417+
# Build a mapper `(dim, side, region) -> (size)` for `f`.
14311418
mapper = {}
14321419
for d0, side, region in product(f.dimensions, (LEFT, RIGHT), (OWNED, HALO)):
14331420
if d0 in fixed:

0 commit comments

Comments
 (0)