Skip to content

Commit bc9c32e

Browse files
committed
pause receiving while submitting tasks
results coming in during submission causes thread contention while submitting tasks pause receiving messages while we are preparing tasks to be submitted
1 parent 3093386 commit bc9c32e

File tree

2 files changed

+81
-22
lines changed

2 files changed

+81
-22
lines changed

ipyparallel/client/client.py

+51-8
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import warnings
1313
from collections.abc import Iterable
1414
from concurrent.futures import Future
15+
from contextlib import contextmanager
1516
from getpass import getpass
1617
from pprint import pprint
1718
from threading import current_thread
@@ -990,21 +991,39 @@ def _stop_io_thread(self):
990991
self._io_thread.join()
991992

992993
def _setup_streams(self):
993-
self._query_stream = ZMQStream(self._query_socket, self._io_loop)
994+
self._streams = []
995+
self._query_stream = s = ZMQStream(self._query_socket, self._io_loop)
996+
self._streams.append(s)
997+
998+
self._control_stream = s = ZMQStream(self._control_socket, self._io_loop)
999+
self._streams.append(s)
1000+
self._mux_stream = s = ZMQStream(self._mux_socket, self._io_loop)
1001+
self._streams.append(s)
1002+
self._task_stream = s = ZMQStream(self._task_socket, self._io_loop)
1003+
self._streams.append(s)
1004+
self._broadcast_stream = s = ZMQStream(self._broadcast_socket, self._io_loop)
1005+
self._streams.append(s)
1006+
self._iopub_stream = s = ZMQStream(self._iopub_socket, self._io_loop)
1007+
self._streams.append(s)
1008+
self._notification_stream = s = ZMQStream(
1009+
self._notification_socket, self._io_loop
1010+
)
1011+
self._streams.append(s)
1012+
self._start_receiving()
1013+
1014+
def _start_receiving(self):
9941015
self._query_stream.on_recv(self._dispatch_single_reply, copy=False)
995-
self._control_stream = ZMQStream(self._control_socket, self._io_loop)
9961016
self._control_stream.on_recv(self._dispatch_single_reply, copy=False)
997-
self._mux_stream = ZMQStream(self._mux_socket, self._io_loop)
9981017
self._mux_stream.on_recv(self._dispatch_reply, copy=False)
999-
self._task_stream = ZMQStream(self._task_socket, self._io_loop)
10001018
self._task_stream.on_recv(self._dispatch_reply, copy=False)
1001-
self._iopub_stream = ZMQStream(self._iopub_socket, self._io_loop)
1019+
self._broadcast_stream.on_recv(self._dispatch_reply, copy=False)
10021020
self._iopub_stream.on_recv(self._dispatch_iopub, copy=False)
1003-
self._notification_stream = ZMQStream(self._notification_socket, self._io_loop)
10041021
self._notification_stream.on_recv(self._dispatch_notification, copy=False)
10051022

1006-
self._broadcast_stream = ZMQStream(self._broadcast_socket, self._io_loop)
1007-
self._broadcast_stream.on_recv(self._dispatch_reply, copy=False)
1023+
def _stop_receiving(self):
1024+
"""Stop receiving on engine streams"""
1025+
for s in self._streams:
1026+
s.stop_on_recv()
10081027

10091028
def _start_io_thread(self):
10101029
"""Start IOLoop in a background thread."""
@@ -1034,6 +1053,30 @@ def _io_main(self, start_evt=None):
10341053
self._io_loop.start()
10351054
self._io_loop.close()
10361055

1056+
@contextmanager
1057+
def _pause_results(self):
1058+
"""Context manager to pause receiving results
1059+
1060+
When submitting lots of tasks,
1061+
the arrival of results can disrupt the processing
1062+
of new submissions.
1063+
1064+
Threadsafe.
1065+
"""
1066+
f = Future()
1067+
1068+
def _stop():
1069+
self._stop_receiving()
1070+
f.set_result(None)
1071+
1072+
# use add_callback to make it threadsafe
1073+
self._io_loop.add_callback(_stop)
1074+
f.result()
1075+
try:
1076+
yield
1077+
finally:
1078+
self._io_loop.add_callback(self._start_receiving)
1079+
10371080
@unpack_message
10381081
def _dispatch_single_reply(self, msg):
10391082
"""Dispatch single (non-execution) replies"""

ipyparallel/client/view.py

+30-14
Original file line numberDiff line numberDiff line change
@@ -578,11 +578,12 @@ def _really_apply(
578578
pargs = [PrePickled(arg) for arg in args]
579579
pkwargs = {k: PrePickled(v) for k, v in kwargs.items()}
580580

581-
for ident in _idents:
582-
future = self.client.send_apply_request(
583-
self._socket, pf, pargs, pkwargs, track=track, ident=ident
584-
)
585-
futures.append(future)
581+
with self.client._pause_results():
582+
for ident in _idents:
583+
future = self.client.send_apply_request(
584+
self._socket, pf, pargs, pkwargs, track=track, ident=ident
585+
)
586+
futures.append(future)
586587
if track:
587588
trackers = [_.tracker for _ in futures]
588589
else:
@@ -641,9 +642,16 @@ def map(self, f, *sequences, block=None, track=False, return_exceptions=False):
641642

642643
assert len(sequences) > 0, "must have some sequences to map onto!"
643644
pf = ParallelFunction(
644-
self, f, block=block, track=track, return_exceptions=return_exceptions
645+
self, f, block=False, track=track, return_exceptions=return_exceptions
645646
)
646-
return pf.map(*sequences)
647+
with self.client._pause_results():
648+
ar = pf.map(*sequences)
649+
if block:
650+
try:
651+
return ar.get()
652+
except KeyboardInterrupt:
653+
return ar
654+
return ar
647655

648656
@sync_results
649657
@save_ids
@@ -665,11 +673,12 @@ def execute(self, code, silent=True, targets=None, block=None):
665673

666674
_idents, _targets = self.client._build_targets(targets)
667675
futures = []
668-
for ident in _idents:
669-
future = self.client.send_execute_request(
670-
self._socket, code, silent=silent, ident=ident
671-
)
672-
futures.append(future)
676+
with self.client._pause_results():
677+
for ident in _idents:
678+
future = self.client.send_execute_request(
679+
self._socket, code, silent=silent, ident=ident
680+
)
681+
futures.append(future)
673682
if isinstance(targets, int):
674683
futures = futures[0]
675684
ar = AsyncResult(
@@ -1292,12 +1301,19 @@ def map(
12921301
pf = ParallelFunction(
12931302
self,
12941303
f,
1295-
block=block,
1304+
block=False,
12961305
chunksize=chunksize,
12971306
ordered=ordered,
12981307
return_exceptions=return_exceptions,
12991308
)
1300-
return pf.map(*sequences)
1309+
with self.client._pause_results():
1310+
ar = pf.map(*sequences)
1311+
if block:
1312+
try:
1313+
return ar.get()
1314+
except KeyboardInterrupt:
1315+
return ar
1316+
return ar
13011317

13021318
def imap(
13031319
self,

0 commit comments

Comments
 (0)