Skip to content

Commit a98a4fa

Browse files
committed
pause receiving while submitting tasks
results coming in during submission causes thread contention while submitting tasks pause receiving messages while we are preparing tasks to be submitted
1 parent 3093386 commit a98a4fa

File tree

2 files changed

+103
-24
lines changed

2 files changed

+103
-24
lines changed

ipyparallel/client/client.py

+73-10
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import warnings
1313
from collections.abc import Iterable
1414
from concurrent.futures import Future
15+
from contextlib import contextmanager
1516
from getpass import getpass
1617
from pprint import pprint
1718
from threading import current_thread
@@ -990,21 +991,59 @@ def _stop_io_thread(self):
990991
self._io_thread.join()
991992

992993
def _setup_streams(self):
993-
self._query_stream = ZMQStream(self._query_socket, self._io_loop)
994-
self._query_stream.on_recv(self._dispatch_single_reply, copy=False)
995-
self._control_stream = ZMQStream(self._control_socket, self._io_loop)
994+
self._streams = [] # all streams
995+
self._engine_streams = [] # streams that talk to engines
996+
self._query_stream = s = ZMQStream(self._query_socket, self._io_loop)
997+
self._streams.append(s)
998+
self._notification_stream = s = ZMQStream(
999+
self._notification_socket, self._io_loop
1000+
)
1001+
self._streams.append(s)
1002+
1003+
self._control_stream = s = ZMQStream(self._control_socket, self._io_loop)
1004+
self._streams.append(s)
1005+
self._engine_streams.append(s)
1006+
self._mux_stream = s = ZMQStream(self._mux_socket, self._io_loop)
1007+
self._streams.append(s)
1008+
self._engine_streams.append(s)
1009+
self._task_stream = s = ZMQStream(self._task_socket, self._io_loop)
1010+
self._streams.append(s)
1011+
self._engine_streams.append(s)
1012+
self._broadcast_stream = s = ZMQStream(self._broadcast_socket, self._io_loop)
1013+
self._streams.append(s)
1014+
self._engine_streams.append(s)
1015+
self._iopub_stream = s = ZMQStream(self._iopub_socket, self._io_loop)
1016+
self._streams.append(s)
1017+
self._engine_streams.append(s)
1018+
self._start_receiving(all=True)
1019+
1020+
def _start_receiving(self, all=False):
1021+
"""Start receiving on streams
1022+
1023+
default: only engine streams
1024+
1025+
if all: include hub streams
1026+
"""
1027+
if all:
1028+
self._query_stream.on_recv(self._dispatch_single_reply, copy=False)
1029+
self._notification_stream.on_recv(self._dispatch_notification, copy=False)
9961030
self._control_stream.on_recv(self._dispatch_single_reply, copy=False)
997-
self._mux_stream = ZMQStream(self._mux_socket, self._io_loop)
9981031
self._mux_stream.on_recv(self._dispatch_reply, copy=False)
999-
self._task_stream = ZMQStream(self._task_socket, self._io_loop)
10001032
self._task_stream.on_recv(self._dispatch_reply, copy=False)
1001-
self._iopub_stream = ZMQStream(self._iopub_socket, self._io_loop)
1033+
self._broadcast_stream.on_recv(self._dispatch_reply, copy=False)
10021034
self._iopub_stream.on_recv(self._dispatch_iopub, copy=False)
1003-
self._notification_stream = ZMQStream(self._notification_socket, self._io_loop)
1004-
self._notification_stream.on_recv(self._dispatch_notification, copy=False)
10051035

1006-
self._broadcast_stream = ZMQStream(self._broadcast_socket, self._io_loop)
1007-
self._broadcast_stream.on_recv(self._dispatch_reply, copy=False)
1036+
def _stop_receiving(self, all=False):
1037+
"""Stop receiving on engine streams
1038+
1039+
If all: include hub streams
1040+
"""
1041+
if all:
1042+
streams = self._streams
1043+
else:
1044+
streams = self._engine_streams
1045+
for s in streams:
1046+
s.stop_on_recv()
10081047

10091048
def _start_io_thread(self):
10101049
"""Start IOLoop in a background thread."""
@@ -1034,6 +1073,30 @@ def _io_main(self, start_evt=None):
10341073
self._io_loop.start()
10351074
self._io_loop.close()
10361075

1076+
@contextmanager
1077+
def _pause_results(self):
1078+
"""Context manager to pause receiving results
1079+
1080+
When submitting lots of tasks,
1081+
the arrival of results can disrupt the processing
1082+
of new submissions.
1083+
1084+
Threadsafe.
1085+
"""
1086+
f = Future()
1087+
1088+
def _stop():
1089+
self._stop_receiving()
1090+
f.set_result(None)
1091+
1092+
# use add_callback to make it threadsafe
1093+
self._io_loop.add_callback(_stop)
1094+
f.result()
1095+
try:
1096+
yield
1097+
finally:
1098+
self._io_loop.add_callback(self._start_receiving)
1099+
10371100
@unpack_message
10381101
def _dispatch_single_reply(self, msg):
10391102
"""Dispatch single (non-execution) replies"""

ipyparallel/client/view.py

+30-14
Original file line numberDiff line numberDiff line change
@@ -578,11 +578,12 @@ def _really_apply(
578578
pargs = [PrePickled(arg) for arg in args]
579579
pkwargs = {k: PrePickled(v) for k, v in kwargs.items()}
580580

581-
for ident in _idents:
582-
future = self.client.send_apply_request(
583-
self._socket, pf, pargs, pkwargs, track=track, ident=ident
584-
)
585-
futures.append(future)
581+
with self.client._pause_results():
582+
for ident in _idents:
583+
future = self.client.send_apply_request(
584+
self._socket, pf, pargs, pkwargs, track=track, ident=ident
585+
)
586+
futures.append(future)
586587
if track:
587588
trackers = [_.tracker for _ in futures]
588589
else:
@@ -641,9 +642,16 @@ def map(self, f, *sequences, block=None, track=False, return_exceptions=False):
641642

642643
assert len(sequences) > 0, "must have some sequences to map onto!"
643644
pf = ParallelFunction(
644-
self, f, block=block, track=track, return_exceptions=return_exceptions
645+
self, f, block=False, track=track, return_exceptions=return_exceptions
645646
)
646-
return pf.map(*sequences)
647+
with self.client._pause_results():
648+
ar = pf.map(*sequences)
649+
if block:
650+
try:
651+
return ar.get()
652+
except KeyboardInterrupt:
653+
return ar
654+
return ar
647655

648656
@sync_results
649657
@save_ids
@@ -665,11 +673,12 @@ def execute(self, code, silent=True, targets=None, block=None):
665673

666674
_idents, _targets = self.client._build_targets(targets)
667675
futures = []
668-
for ident in _idents:
669-
future = self.client.send_execute_request(
670-
self._socket, code, silent=silent, ident=ident
671-
)
672-
futures.append(future)
676+
with self.client._pause_results():
677+
for ident in _idents:
678+
future = self.client.send_execute_request(
679+
self._socket, code, silent=silent, ident=ident
680+
)
681+
futures.append(future)
673682
if isinstance(targets, int):
674683
futures = futures[0]
675684
ar = AsyncResult(
@@ -1292,12 +1301,19 @@ def map(
12921301
pf = ParallelFunction(
12931302
self,
12941303
f,
1295-
block=block,
1304+
block=False,
12961305
chunksize=chunksize,
12971306
ordered=ordered,
12981307
return_exceptions=return_exceptions,
12991308
)
1300-
return pf.map(*sequences)
1309+
with self.client._pause_results():
1310+
ar = pf.map(*sequences)
1311+
if block:
1312+
try:
1313+
return ar.get()
1314+
except KeyboardInterrupt:
1315+
return ar
1316+
return ar
13011317

13021318
def imap(
13031319
self,

0 commit comments

Comments
 (0)