Skip to content

Commit 5d75c52

Browse files
committed
WIP cluster API
1 parent 0fa696a commit 5d75c52

File tree

3 files changed

+444
-0
lines changed

3 files changed

+444
-0
lines changed

Diff for: ipyparallel/cluster/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .cluster import *

Diff for: ipyparallel/cluster/cluster.py

+333
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,333 @@
1+
"""Cluster class
2+
3+
defines the basic interface to a single IPython Parallel cluster
4+
5+
starts/stops/polls controllers, engines, etc.
6+
"""
7+
import asyncio
8+
import inspect
9+
import logging
10+
import random
11+
import socket
12+
import string
13+
import sys
14+
import time
15+
from functools import partial
16+
from multiprocessing import cpu_count
17+
18+
import IPython
19+
import traitlets.log
20+
from IPython.core.profiledir import ProfileDir
21+
from traitlets import Any
22+
from traitlets import default
23+
from traitlets import Dict
24+
from traitlets import Integer
25+
from traitlets import Type
26+
from traitlets import Unicode
27+
from traitlets.config import LoggingConfigurable
28+
29+
from .._async import AsyncFirst
30+
from ..apps import launcher
31+
32+
_suffix_chars = string.ascii_lowercase + string.digits
33+
34+
35+
class Cluster(AsyncFirst, LoggingConfigurable):
    """Class representing an IPP cluster

    i.e. one controller and a group of engines

    Can start/stop/monitor/poll cluster resources.

    Usable as a context manager (both ``with`` and ``async with``):
    entering starts the controller and one engine set and returns a
    connected client; exiting closes that client and stops everything.
    """

    controller_launcher_class = Type(
        default_value=launcher.LocalControllerLauncher,
        klass=launcher.BaseLauncher,
        help="""Launcher class for controllers""",
    )
    engine_launcher_class = Type(
        default_value=launcher.LocalEngineSetLauncher,
        klass=launcher.BaseLauncher,
        help="""Launcher class for sets of engines""",
    )

    cluster_id = Unicode(help="The id of the cluster (default: random string)")

    @default("cluster_id")
    def _default_cluster_id(self):
        """Generate an id of the form ``<hostname>-<unix time>-<4 random chars>``"""
        suffix = ''.join(random.choice(_suffix_chars) for _ in range(4))
        return f"{socket.gethostname()}-{int(time.time())}-{suffix}"

    profile_dir = Unicode(
        help="""The profile directory.

        Default priority:

        - specified explicitly
        - current IPython session
        - use profile name (default: 'default')

        """
    )

    @default("profile_dir")
    def _default_profile_dir(self):
        """Resolve the profile directory when it was not given explicitly"""
        if not self.profile:
            # no profile name requested: prefer the running IPython session's profile
            ip = IPython.get_ipython()
            if ip is not None:
                return ip.profile_dir.location
        # otherwise look the profile up by name inside the IPython directory
        return ProfileDir.find_profile_dir_by_name(
            IPython.paths.get_ipython_dir(), name=self.profile or 'default'
        ).location

    profile = Unicode(
        "",
        help="""The profile name,
        a shortcut for specifying profile_dir within $IPYTHONDIR.""",
    )

    n = Integer(None, allow_none=True, help="The number of engines to start")

    @default("parent")
    def _default_parent(self):
        """Default to inheriting config from current IPython session"""
        return IPython.get_ipython()

    log_level = Integer(
        logging.INFO,
        help="Level applied to the per-cluster logger created inside an IPython session",
    )

    @default("log")
    def _default_log(self):
        """Pick a logger: a stdout logger inside IPython, the traitlets logger otherwise"""
        if self.parent and self.parent is IPython.get_ipython():
            # log to stdout in an IPython session
            log = logging.getLogger(f"{__name__}.{self.cluster_id}")
            log.setLevel(self.log_level)

            handler = logging.StreamHandler(sys.stdout)
            # replace (rather than append to) any handlers already on this logger
            log.handlers = [handler]
            return log
        else:
            return traitlets.log.get_logger()

    # private state
    _controller = Any()
    _engine_sets = Dict()

    @classmethod
    def from_json(cls, json_dict):
        """Construct a Cluster from serialized state

        Not implemented yet: always raises NotImplementedError.
        """
        raise NotImplementedError()

    def to_json(self):
        """Serialize a Cluster object for later reconstruction

        Not implemented yet: always raises NotImplementedError.
        """
        raise NotImplementedError()

    async def start_controller(self):
        """Start the controller

        Raises RuntimeError if a controller launcher is already set;
        call stop_controller() first in that case.
        """
        # start controller
        # retrieve connection info
        # webhook?
        if self._controller is not None:
            raise RuntimeError(
                "controller is already running. Call stop_controller() first."
            )
        self._controller = self.controller_launcher_class(
            work_dir='.',
            parent=self,
            log=self.log,
            profile_dir=self.profile_dir,
            cluster_id=self.cluster_id,
        )
        self._controller.on_stop(self._controller_stopped)
        r = self._controller.start()
        # launchers may be sync or async: only await when there is something to await
        if inspect.isawaitable(r):
            await r
        # TODO: retrieve connection info

    def _controller_stopped(self, stop_data=None):
        """Callback when a controller stops"""
        self.log.info(f"Controller stopped: {stop_data}")

    async def start_engines(self, n=None, engine_set_id=None):
        """Start an engine set

        n defaults to self.n; a launcher-provided ``engine_count`` attribute
        takes precedence when present, and cpu_count() is the last resort.

        Returns an engine set id which can be used in stop_engines
        """
        # TODO: send engines connection info
        if engine_set_id is None:
            suffix = ''.join(random.choice(_suffix_chars) for _ in range(4))
            engine_set_id = f"{int(time.time())}-{suffix}"
        engine_set = self._engine_sets[engine_set_id] = self.engine_launcher_class(
            work_dir='.',
            parent=self,
            log=self.log,
            profile_dir=self.profile_dir,
            cluster_id=self.cluster_id,
        )
        if n is None:
            n = self.n
        n = getattr(engine_set, 'engine_count', n)
        if n is None:
            n = cpu_count()
        self.log.info(f"Starting {n or ''} engines with {self.engine_launcher_class}")
        # The stop callback is registered like _controller_stopped, which takes
        # only stop_data; bind the engine_set_id here so _engines_stopped gets
        # both of its arguments.  Register before start(), as for the controller.
        engine_set.on_stop(partial(self._engines_stopped, engine_set_id))
        r = engine_set.start(n)
        if inspect.isawaitable(r):
            await r
        return engine_set_id

    def _engines_stopped(self, engine_set_id, stop_data):
        """Callback when an engine set stops"""
        self.log.warning(f"engine set stopped {engine_set_id}: {stop_data}")

    async def start_cluster(self, n=None):
        """Start a cluster

        starts one controller and n engines (default: self.n)
        """
        await self.start_controller()
        await self.start_engines(n)

    async def stop_engines(self, engine_set_id=None):
        """Stop an engine set

        If engine_set_id is not given,
        all engines are stopped

        Raises KeyError for an unknown engine_set_id.
        """
        if engine_set_id is None:
            # iterate over a copy of the keys: each stop removes its entry
            for engine_set_id in list(self._engine_sets):
                await self.stop_engines(engine_set_id)
            return

        engine_set = self._engine_sets[engine_set_id]
        r = engine_set.stop()
        if inspect.isawaitable(r):
            await r
        self._engine_sets.pop(engine_set_id)

    async def stop_engine(self, engine_id):
        """Stop one engine

        *May* stop all engines in a set,
        depending on EngineSet features (e.g. mpiexec)
        """
        raise NotImplementedError("How do we find an engine by id?")

    async def restart_engine_set(self, engine_set_id):
        """Restart an engine set

        Stops the set, then starts a new one under the same id
        with the previous set's ``n``.
        """
        engine_set = self._engine_sets[engine_set_id]
        # read n before stopping: stop_engines drops the set from _engine_sets
        n = engine_set.n
        await self.stop_engines(engine_set_id)
        await self.start_engines(n, engine_set_id)

    async def restart_engine(self, engine_id):
        """Restart one engine

        *May* stop all engines in a set,
        depending on EngineSet features (e.g. mpiexec)
        """
        raise NotImplementedError("How do we find an engine by id?")

    async def signal_engine(self, engine_id, signum):
        """Signal one engine

        *May* signal all engines in a set,
        depending on EngineSet features (e.g. mpiexec)
        """
        raise NotImplementedError("How do we find an engine by id?")

    async def signal_engines(self, engine_set_id, signum):
        """Signal all engines in a set

        Raises KeyError for an unknown engine_set_id.
        """
        engine_set = self._engine_sets[engine_set_id]
        r = engine_set.signal(signum)
        # same sync-or-async handling as start()/stop()
        if inspect.isawaitable(r):
            await r

    async def stop_controller(self):
        """Stop the controller

        Safe to call when no controller is set; always clears the launcher.
        """
        if self._controller and self._controller.running:
            r = self._controller.stop()
            if inspect.isawaitable(r):
                await r

        self._controller = None

    async def stop_cluster(self):
        """Stop the controller and all engines"""
        await self.stop_engines()
        await self.stop_controller()

    def connect_client(self):
        """Return a client connected to the cluster"""
        # TODO: get connect info directly from controller
        # this assumes local files exist
        # imported here rather than at module level (presumably to avoid a
        # circular import with the top-level package -- TODO confirm)
        from ipyparallel import Client

        return Client(
            parent=self, profile_dir=self.profile_dir, cluster_id=self.cluster_id
        )

    # context managers (both async and sync)
    _context_client = None

    async def __aenter__(self):
        """Start controller and engines; return a client once self.n engines registered"""
        await self.start_controller()
        await self.start_engines()
        client = self._context_client = self.connect_client()
        if self.n:
            # wait for engine registration
            # TODO: timeout
            while len(client) < self.n:
                await asyncio.sleep(0.1)
        return client

    async def __aexit__(self, *args):
        """Close the context client, then stop engines and controller"""
        if self._context_client is not None:
            self._context_client.close()
            self._context_client = None
        await self.stop_engines()
        await self.stop_controller()

    def __enter__(self):
        """Synchronous counterpart of __aenter__"""
        self.start_controller_sync()
        self.start_engines_sync()
        client = self._context_client = self.connect_client()
        if self.n:
            # wait for engine registration
            # TODO: timeout
            while len(client) < self.n:
                time.sleep(0.1)
        return client

    def __exit__(self, *args):
        """Synchronous counterpart of __aexit__"""
        # Compare against None, as __aexit__ does: the client supports len()
        # (see __enter__), so a client with no engines could be falsy and
        # would otherwise never be closed.
        if self._context_client is not None:
            self._context_client.close()
            self._context_client = None
        self.stop_engines_sync()
        self.stop_controller_sync()
300+
301+
302+
class ClusterManager(LoggingConfigurable):
    """A manager of clusters

    Wraps Cluster, adding dispatch by cluster id:
    any name defined on Cluster can be accessed on the manager
    and called as ``manager.name(cluster_id, ...)``.
    """

    _clusters = Dict(help="My cluster objects")

    def load_clusters(self):
        """Load serialized cluster state

        Not implemented yet: always raises NotImplementedError.
        """
        raise NotImplementedError()

    def list_clusters(self):
        """List current clusters"""
        # TODO: not implemented yet; currently returns None

    def new_cluster(self, cluster_cls):
        """Create a new cluster"""
        # TODO: not implemented yet; currently returns None

    def _cluster_method(self, method_name, cluster_id, *args, **kwargs):
        """Wrapper around single-cluster methods

        Defines ClusterManager.method(cluster_id, ...)

        which returns ClusterManager._clusters[cluster_id].method(...)

        Raises KeyError when cluster_id is not a known cluster.
        """
        cluster = self._clusters[cluster_id]
        method = getattr(cluster, method_name)
        return method(*args, **kwargs)

    def __getattr__(self, key):
        """Expose names defined on Cluster as manager-level callables

        Only invoked when normal attribute lookup fails.
        Raises AttributeError for names Cluster does not define.
        """
        if key in Cluster.__dict__:
            return partial(self._cluster_method, key)
        # Delegate to a parent __getattr__ hook only if one exists; the bound
        # super() method already carries self, so only the key is passed.
        parent_getattr = getattr(super(), "__getattr__", None)
        if parent_getattr is not None:
            return parent_getattr(key)
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute {key!r}"
        )

0 commit comments

Comments
 (0)