@@ -435,23 +435,7 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
435
435
436
436
fixed = {d : Symbol (name = "o%s" % d .root ) for d in hse .loc_indices }
437
437
438
- # Build a mapper `(dim, side, region) -> (size, ofs)` for `f`. `size` and
439
- # `ofs` are symbolic objects. This mapper tells what data values should be
440
- # sent (OWNED) or received (HALO) given dimension and side
441
- mapper = {}
442
- for d0 , side , region in product (f .dimensions , (LEFT , RIGHT ), (OWNED , HALO )):
443
- if d0 in fixed :
444
- continue
445
- sizes = []
446
- ofs = []
447
- for d1 in f .dimensions :
448
- if d1 in fixed :
449
- ofs .append (fixed [d1 ])
450
- else :
451
- meta = f ._C_get_field (region if d0 is d1 else NOPAD , d1 , side )
452
- ofs .append (meta .offset )
453
- sizes .append (meta .size )
454
- mapper [(d0 , side , region )] = (sizes , ofs )
438
+ mapper = self ._make_basic_mapper (f , fixed )
455
439
456
440
body = []
457
441
for d in f .dimensions :
@@ -483,6 +467,27 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
483
467
484
468
return HaloUpdate ('haloupdate%s' % key , iet , parameters )
485
469
470
+ def _make_basic_mapper (self , f , fixed ):
471
+ # Build a mapper `(dim, side, region) -> (size, ofs)` for `f`. `size` and
472
+ # `ofs` are symbolic objects. This mapper tells what data values should be
473
+ # sent (OWNED) or received (HALO) given dimension and side
474
+ mapper = {}
475
+ for d0 , side , region in product (f .dimensions , (LEFT , RIGHT ), (OWNED , HALO )):
476
+ if d0 in fixed :
477
+ continue
478
+ sizes = []
479
+ ofs = []
480
+ for d1 in f .dimensions :
481
+ if d1 in fixed :
482
+ ofs .append (fixed [d1 ])
483
+ else :
484
+ meta = f ._C_get_field (region if d0 is d1 else NOPAD , d1 , side )
485
+ ofs .append (meta .offset )
486
+ sizes .append (meta .size )
487
+ mapper [(d0 , side , region )] = (sizes , ofs )
488
+
489
+ return mapper
490
+
486
491
def _call_haloupdate (self , name , f , hse , * args ):
487
492
comm = f .grid .distributor ._obj_comm
488
493
nb = f .grid .distributor ._obj_neighborhood
@@ -526,6 +531,125 @@ def _make_body(self, callcompute, remainder, haloupdates, halowaits):
526
531
return List (body = body )
527
532
528
533
534
+ class Basic2HaloExchangeBuilder (BasicHaloExchangeBuilder ):
535
+
536
+ """
537
+ A BasicHaloExchangeBuilder making use of pre-allocated buffers for
538
+ message size.
539
+
540
+ Generates:
541
+
542
+ haloupdate()
543
+ compute()
544
+ """
545
+
546
+ def _make_msg (self , f , hse , key ):
547
+ # Pass the fixed mapper e.g. {t: otime}
548
+ fixed = {d : Symbol (name = "o%s" % d .root ) for d in hse .loc_indices }
549
+
550
+ return MPIMsgBasic2 ('msg%d' % key , f , hse .halos , fixed )
551
+
552
+ def _make_sendrecv (self , f , hse , key , msg = None ):
553
+ cast = cast_mapper [(f .c0 .dtype , '*' )]
554
+ comm = f .grid .distributor ._obj_comm
555
+
556
+ bufg = FieldFromPointer (msg ._C_field_bufg , msg )
557
+ bufs = FieldFromPointer (msg ._C_field_bufs , msg )
558
+
559
+ ofsg = [Symbol (name = 'og%s' % d .root ) for d in f .dimensions ]
560
+ ofss = [Symbol (name = 'os%s' % d .root ) for d in f .dimensions ]
561
+
562
+ fromrank = Symbol (name = 'fromrank' )
563
+ torank = Symbol (name = 'torank' )
564
+
565
+ sizes = [FieldFromPointer ('%s[%d]' % (msg ._C_field_sizes , i ), msg )
566
+ for i in range (len (f ._dist_dimensions ))]
567
+
568
+ arguments = [cast (bufg )] + sizes + list (f .handles ) + ofsg
569
+ gather = Gather ('gather%s' % key , arguments )
570
+ # The `gather` is unnecessary if sending to MPI.PROC_NULL
571
+ gather = Conditional (CondNe (torank , Macro ('MPI_PROC_NULL' )), gather )
572
+
573
+ arguments = [cast (bufs )] + sizes + list (f .handles ) + ofss
574
+ scatter = Scatter ('scatter%s' % key , arguments )
575
+ # The `scatter` must be guarded as we must not alter the halo values along
576
+ # the domain boundary, where the sender is actually MPI.PROC_NULL
577
+ scatter = Conditional (CondNe (fromrank , Macro ('MPI_PROC_NULL' )), scatter )
578
+
579
+ count = reduce (mul , sizes , 1 )* dtype_len (f .dtype )
580
+ rrecv = Byref (FieldFromPointer (msg ._C_field_rrecv , msg ))
581
+ rsend = Byref (FieldFromPointer (msg ._C_field_rsend , msg ))
582
+ recv = IrecvCall ([bufs , count , Macro (dtype_to_mpitype (f .dtype )),
583
+ fromrank , Integer (13 ), comm , rrecv ])
584
+ send = IsendCall ([bufg , count , Macro (dtype_to_mpitype (f .dtype )),
585
+ torank , Integer (13 ), comm , rsend ])
586
+
587
+ waitrecv = Call ('MPI_Wait' , [rrecv , Macro ('MPI_STATUS_IGNORE' )])
588
+ waitsend = Call ('MPI_Wait' , [rsend , Macro ('MPI_STATUS_IGNORE' )])
589
+
590
+ iet = List (body = [recv , gather , send , waitsend , waitrecv , scatter ])
591
+
592
+ parameters = (list (f .handles ) + ofsg + ofss + [fromrank , torank , comm , msg ])
593
+
594
+ return SendRecv ('sendrecv%s' % key , iet , parameters , bufg , bufs )
595
+
596
+ def _call_sendrecv (self , name , * args , msg = None , haloid = None ):
597
+ # Drop `sizes` as this HaloExchangeBuilder conveys them through `msg`
598
+ f , _ , ofsg , ofss , fromrank , torank , comm = args
599
+ msg = Byref (IndexedPointer (msg , haloid ))
600
+ return Call (name , list (f .handles ) + ofsg + ofss + [fromrank , torank , comm , msg ])
601
+
602
+ def _make_haloupdate (self , f , hse , key , sendrecv , ** kwargs ):
603
+ distributor = f .grid .distributor
604
+ nb = distributor ._obj_neighborhood
605
+ comm = distributor ._obj_comm
606
+
607
+ fixed = {d : Symbol (name = "o%s" % d .root ) for d in hse .loc_indices }
608
+
609
+ mapper = self ._make_basic_mapper (f , fixed )
610
+
611
+ body = []
612
+ for d in f .dimensions :
613
+ if d in fixed :
614
+ continue
615
+
616
+ name = '' .join ('r' if i is d else 'c' for i in distributor .dimensions )
617
+ rpeer = FieldFromPointer (name , nb )
618
+ name = '' .join ('l' if i is d else 'c' for i in distributor .dimensions )
619
+ lpeer = FieldFromPointer (name , nb )
620
+
621
+ if (d , LEFT ) in hse .halos :
622
+ # Sending to left, receiving from right
623
+ lsizes , lofs = mapper [(d , LEFT , OWNED )]
624
+ rsizes , rofs = mapper [(d , RIGHT , HALO )]
625
+ args = [f , lsizes , lofs , rofs , rpeer , lpeer , comm ]
626
+ kwargs ['haloid' ] = len (body )
627
+ body .append (self ._call_sendrecv (sendrecv .name , * args , ** kwargs ))
628
+
629
+ if (d , RIGHT ) in hse .halos :
630
+ # Sending to right, receiving from left
631
+ rsizes , rofs = mapper [(d , RIGHT , OWNED )]
632
+ lsizes , lofs = mapper [(d , LEFT , HALO )]
633
+ args = [f , rsizes , rofs , lofs , lpeer , rpeer , comm ]
634
+ kwargs ['haloid' ] = len (body )
635
+ body .append (self ._call_sendrecv (sendrecv .name , * args , ** kwargs ))
636
+
637
+ iet = List (body = body )
638
+
639
+ parameters = list (f .handles ) + [comm , nb ] + list (fixed .values ())
640
+
641
+ node = HaloUpdate ('haloupdate%s' % key , iet , parameters )
642
+
643
+ node = node ._rebuild (parameters = node .parameters + (kwargs ['msg' ],))
644
+
645
+ return node
646
+
647
+ def _call_haloupdate (self , name , f , hse , msg ):
648
+ call = super ()._call_haloupdate (name , f , hse )
649
+ call = call ._rebuild (arguments = call .arguments + (msg ,))
650
+ return call
651
+
652
+
529
653
class DiagHaloExchangeBuilder (BasicHaloExchangeBuilder ):
530
654
531
655
"""
@@ -741,141 +865,6 @@ def _call_remainder(self, remainder):
741
865
return call
742
866
743
867
744
- class Basic2HaloExchangeBuilder (BasicHaloExchangeBuilder ):
745
-
746
- """
747
- A BasicHaloExchangeBuilder making use of pre-allocated buffers for
748
- message size.
749
-
750
- Generates:
751
-
752
- haloupdate()
753
- compute()
754
- """
755
-
756
- def _make_msg (self , f , hse , key ):
757
- # Pass the fixed mapper e.g. {t: otime}
758
- fixed = {d : Symbol (name = "o%s" % d .root ) for d in hse .loc_indices }
759
-
760
- return MPIMsgBasic2 ('msg%d' % key , f , hse .halos , fixed )
761
-
762
- def _make_sendrecv (self , f , hse , key , msg = None ):
763
- cast = cast_mapper [(f .c0 .dtype , '*' )]
764
- comm = f .grid .distributor ._obj_comm
765
-
766
- bufg = FieldFromPointer (msg ._C_field_bufg , msg )
767
- bufs = FieldFromPointer (msg ._C_field_bufs , msg )
768
-
769
- ofsg = [Symbol (name = 'og%s' % d .root ) for d in f .dimensions ]
770
- ofss = [Symbol (name = 'os%s' % d .root ) for d in f .dimensions ]
771
-
772
- fromrank = Symbol (name = 'fromrank' )
773
- torank = Symbol (name = 'torank' )
774
-
775
- sizes = [FieldFromPointer ('%s[%d]' % (msg ._C_field_sizes , i ), msg )
776
- for i in range (len (f ._dist_dimensions ))]
777
-
778
- arguments = [cast (bufg )] + sizes + list (f .handles ) + ofsg
779
- gather = Gather ('gather%s' % key , arguments )
780
- # The `gather` is unnecessary if sending to MPI.PROC_NULL
781
- gather = Conditional (CondNe (torank , Macro ('MPI_PROC_NULL' )), gather )
782
-
783
- arguments = [cast (bufs )] + sizes + list (f .handles ) + ofss
784
- scatter = Scatter ('scatter%s' % key , arguments )
785
- # The `scatter` must be guarded as we must not alter the halo values along
786
- # the domain boundary, where the sender is actually MPI.PROC_NULL
787
- scatter = Conditional (CondNe (fromrank , Macro ('MPI_PROC_NULL' )), scatter )
788
-
789
- count = reduce (mul , sizes , 1 )* dtype_len (f .dtype )
790
- rrecv = Byref (FieldFromPointer (msg ._C_field_rrecv , msg ))
791
- rsend = Byref (FieldFromPointer (msg ._C_field_rsend , msg ))
792
- recv = IrecvCall ([bufs , count , Macro (dtype_to_mpitype (f .dtype )),
793
- fromrank , Integer (13 ), comm , rrecv ])
794
- send = IsendCall ([bufg , count , Macro (dtype_to_mpitype (f .dtype )),
795
- torank , Integer (13 ), comm , rsend ])
796
-
797
- waitrecv = Call ('MPI_Wait' , [rrecv , Macro ('MPI_STATUS_IGNORE' )])
798
- waitsend = Call ('MPI_Wait' , [rsend , Macro ('MPI_STATUS_IGNORE' )])
799
-
800
- iet = List (body = [recv , gather , send , waitsend , waitrecv , scatter ])
801
-
802
- parameters = (list (f .handles ) + ofsg + ofss + [fromrank , torank , comm , msg ])
803
-
804
- return SendRecv ('sendrecv%s' % key , iet , parameters , bufg , bufs )
805
-
806
- def _call_sendrecv (self , name , * args , msg = None , haloid = None ):
807
- # Drop `sizes` as this HaloExchangeBuilder conveys them through `msg`
808
- f , _ , ofsg , ofss , fromrank , torank , comm = args
809
- msg = Byref (IndexedPointer (msg , haloid ))
810
- return Call (name , list (f .handles ) + ofsg + ofss + [fromrank , torank , comm , msg ])
811
-
812
- def _make_haloupdate (self , f , hse , key , sendrecv , ** kwargs ):
813
- distributor = f .grid .distributor
814
- nb = distributor ._obj_neighborhood
815
- comm = distributor ._obj_comm
816
-
817
- fixed = {d : Symbol (name = "o%s" % d .root ) for d in hse .loc_indices }
818
-
819
- # Build a mapper `(dim, side, region) -> (size, ofs)` for `f`. `size` and
820
- # `ofs` are symbolic objects. This mapper tells what data values should be
821
- # sent (OWNED) or received (HALO) given dimension and side
822
- mapper = {}
823
- for d0 , side , region in product (f .dimensions , (LEFT , RIGHT ), (OWNED , HALO )):
824
- if d0 in fixed :
825
- continue
826
- sizes = []
827
- ofs = []
828
- for d1 in f .dimensions :
829
- if d1 in fixed :
830
- ofs .append (fixed [d1 ])
831
- else :
832
- meta = f ._C_get_field (region if d0 is d1 else NOPAD , d1 , side )
833
- ofs .append (meta .offset )
834
- sizes .append (meta .size )
835
- mapper [(d0 , side , region )] = (sizes , ofs )
836
-
837
- body = []
838
- for d in f .dimensions :
839
- if d in fixed :
840
- continue
841
-
842
- name = '' .join ('r' if i is d else 'c' for i in distributor .dimensions )
843
- rpeer = FieldFromPointer (name , nb )
844
- name = '' .join ('l' if i is d else 'c' for i in distributor .dimensions )
845
- lpeer = FieldFromPointer (name , nb )
846
-
847
- if (d , LEFT ) in hse .halos :
848
- # Sending to left, receiving from right
849
- lsizes , lofs = mapper [(d , LEFT , OWNED )]
850
- rsizes , rofs = mapper [(d , RIGHT , HALO )]
851
- args = [f , lsizes , lofs , rofs , rpeer , lpeer , comm ]
852
- kwargs ['haloid' ] = len (body )
853
- body .append (self ._call_sendrecv (sendrecv .name , * args , ** kwargs ))
854
-
855
- if (d , RIGHT ) in hse .halos :
856
- # Sending to right, receiving from left
857
- rsizes , rofs = mapper [(d , RIGHT , OWNED )]
858
- lsizes , lofs = mapper [(d , LEFT , HALO )]
859
- args = [f , rsizes , rofs , lofs , lpeer , rpeer , comm ]
860
- kwargs ['haloid' ] = len (body )
861
- body .append (self ._call_sendrecv (sendrecv .name , * args , ** kwargs ))
862
-
863
- iet = List (body = body )
864
-
865
- parameters = list (f .handles ) + [comm , nb ] + list (fixed .values ())
866
-
867
- node = HaloUpdate ('haloupdate%s' % key , iet , parameters )
868
-
869
- node = node ._rebuild (parameters = node .parameters + (kwargs ['msg' ],))
870
-
871
- return node
872
-
873
- def _call_haloupdate (self , name , f , hse , msg ):
874
- call = super ()._call_haloupdate (name , f , hse )
875
- call = call ._rebuild (arguments = call .arguments + (msg ,))
876
- return call
877
-
878
-
879
868
class Overlap2HaloExchangeBuilder (OverlapHaloExchangeBuilder ):
880
869
881
870
"""
@@ -1425,9 +1414,7 @@ def _arg_defaults(self, allocator, alias, args=None):
1425
1414
1426
1415
fixed = self ._fixed
1427
1416
1428
- # Build a mapper `(dim, side, region) -> (size, ofs)` for `f`. `size` and
1429
- # `ofs` are symbolic objects. This mapper tells what data values should be
1430
- # sent (OWNED) or received (HALO) given dimension and side
1417
+ # Build a mapper `(dim, side, region) -> (size)` for `f`.
1431
1418
mapper = {}
1432
1419
for d0 , side , region in product (f .dimensions , (LEFT , RIGHT ), (OWNED , HALO )):
1433
1420
if d0 in fixed :
0 commit comments