|
| 1 | +"""Cluster class |
| 2 | +
|
| 3 | +defines the basic interface to a single IPython Parallel cluster |
| 4 | +
|
| 5 | +starts/stops/polls controllers, engines, etc. |
| 6 | +""" |
| 7 | +import asyncio |
| 8 | +import inspect |
| 9 | +import logging |
| 10 | +import random |
| 11 | +import socket |
| 12 | +import string |
| 13 | +import sys |
| 14 | +import time |
| 15 | +from functools import partial |
| 16 | +from multiprocessing import cpu_count |
| 17 | + |
| 18 | +import IPython |
| 19 | +import traitlets.log |
| 20 | +from IPython.core.profiledir import ProfileDir |
| 21 | +from traitlets import Any |
| 22 | +from traitlets import default |
| 23 | +from traitlets import Dict |
| 24 | +from traitlets import Integer |
| 25 | +from traitlets import Type |
| 26 | +from traitlets import Unicode |
| 27 | +from traitlets.config import LoggingConfigurable |
| 28 | + |
| 29 | +from .._async import AsyncFirst |
| 30 | +from ..apps import launcher |
| 31 | + |
| 32 | +_suffix_chars = string.ascii_lowercase + string.digits |
| 33 | + |
| 34 | + |
class Cluster(AsyncFirst, LoggingConfigurable):
    """Class representing an IPP cluster

    i.e. one controller and a group of engines

    Can start/stop/monitor/poll cluster resources
    """

    controller_launcher_class = Type(
        default_value=launcher.LocalControllerLauncher,
        klass=launcher.BaseLauncher,
        help="""Launcher class for controllers""",
    )
    engine_launcher_class = Type(
        default_value=launcher.LocalEngineSetLauncher,
        klass=launcher.BaseLauncher,
        help="""Launcher class for sets of engines""",
    )

    cluster_id = Unicode(help="The id of the cluster (default: random string)")

    @default("cluster_id")
    def _default_cluster_id(self):
        # hostname-timestamp-suffix keeps ids unique across hosts, runs,
        # and clusters created within the same second
        return f"{socket.gethostname()}-{int(time.time())}-{''.join(random.choice(_suffix_chars) for i in range(4))}"

    profile_dir = Unicode(
        help="""The profile directory.

        Default priority:

        - specified explicitly
        - current IPython session
        - use profile name (default: 'default')

        """
    )

    @default("profile_dir")
    def _default_profile_dir(self):
        # no explicit profile name: prefer the running IPython session's
        # profile directory, if we are inside one
        if not self.profile:
            ip = IPython.get_ipython()
            if ip is not None:
                return ip.profile_dir.location
        # otherwise resolve the named profile under $IPYTHONDIR
        return ProfileDir.find_profile_dir_by_name(
            IPython.paths.get_ipython_dir(), name=self.profile or 'default'
        ).location

    profile = Unicode(
        "",
        help="""The profile name,
        a shortcut for specifying profile_dir within $IPYTHONDIR.""",
    )

    n = Integer(None, allow_none=True, help="The number of engines to start")

    @default("parent")
    def _default_parent(self):
        """Default to inheriting config from current IPython session"""
        return IPython.get_ipython()

    log_level = Integer(logging.INFO)

    @default("log")
    def _default_log(self):
        if self.parent and self.parent is IPython.get_ipython():
            # log to stdout in an IPython session
            log = logging.getLogger(f"{__name__}.{self.cluster_id}")
            log.setLevel(self.log_level)

            handler = logging.StreamHandler(sys.stdout)
            log.handlers = [handler]
            return log
        else:
            return traitlets.log.get_logger()

    # private state
    _controller = Any()  # the controller launcher instance, or None
    _engine_sets = Dict()  # engine_set_id -> engine set launcher instance

    @classmethod
    def from_json(cls, json_dict):
        """Construct a Cluster from serialized state"""
        # classmethods receive the class as `cls` (was incorrectly `self`)
        raise NotImplementedError()

    def to_json(self):
        """Serialize a Cluster object for later reconstruction"""
        raise NotImplementedError()

    async def start_controller(self):
        """Start the controller

        Raises RuntimeError if the controller is already running.
        """
        # start controller
        # retrieve connection info
        # webhook?
        if self._controller is not None:
            raise RuntimeError(
                "controller is already running. Call stop_controller() first."
            )
        self._controller = self.controller_launcher_class(
            work_dir=u'.',
            parent=self,
            log=self.log,
            profile_dir=self.profile_dir,
            cluster_id=self.cluster_id,
        )
        self._controller.on_stop(self._controller_stopped)
        # launchers may be sync or async; await only when needed
        r = self._controller.start()
        if inspect.isawaitable(r):
            await r
        # TODO: retrieve connection info

    def _controller_stopped(self, stop_data=None):
        """Callback when a controller stops"""
        self.log.info(f"Controller stopped: {stop_data}")

    async def start_engines(self, n=None, engine_set_id=None):
        """Start an engine set

        Returns an engine set id which can be used in stop_engines
        """
        # TODO: send engines connection info
        if engine_set_id is None:
            engine_set_id = f"{int(time.time())}-{''.join(random.choice(_suffix_chars) for i in range(4))}"
        engine_set = self._engine_sets[engine_set_id] = self.engine_launcher_class(
            work_dir=u'.',
            parent=self,
            log=self.log,
            profile_dir=self.profile_dir,
            cluster_id=self.cluster_id,
        )
        # engine count priority: explicit arg > self.n > launcher's own
        # engine_count attribute > cpu_count()
        if n is None:
            n = self.n
        n = getattr(engine_set, 'engine_count', n)
        if n is None:
            n = cpu_count()
        self.log.info(f"Starting {n or ''} engines with {self.engine_launcher_class}")
        # register the stop callback *before* starting (as start_controller
        # does), so a set that exits immediately is still reported;
        # bind engine_set_id because on_stop callbacks receive only stop_data
        # (see _controller_stopped)
        engine_set.on_stop(partial(self._engines_stopped, engine_set_id))
        r = engine_set.start(n)
        if inspect.isawaitable(r):
            await r
        return engine_set_id

    def _engines_stopped(self, engine_set_id, stop_data):
        """Callback when an engine set stops"""
        self.log.warning(f"engine set stopped {engine_set_id}: {stop_data}")

    async def start_cluster(self, n=None):
        """Start a cluster

        starts one controller and n engines (default: self.n)
        """
        await self.start_controller()
        await self.start_engines(n)

    async def stop_engines(self, engine_set_id=None):
        """Stop an engine set

        If engine_set_id is not given,
        all engines are stopped"""
        if engine_set_id is None:
            # iterate over a copy, since each recursive call removes an entry
            for engine_set_id in list(self._engine_sets):
                await self.stop_engines(engine_set_id)
            return

        engine_set = self._engine_sets[engine_set_id]
        r = engine_set.stop()
        if inspect.isawaitable(r):
            await r
        self._engine_sets.pop(engine_set_id)

    async def stop_engine(self, engine_id):
        """Stop one engine

        *May* stop all engines in a set,
        depending on EngineSet features (e.g. mpiexec)
        """
        raise NotImplementedError("How do we find an engine by id?")

    async def restart_engine_set(self, engine_set_id):
        """Restart an engine set"""
        engine_set = self._engine_sets[engine_set_id]
        # assumes the launcher records its engine count as `.n` — TODO confirm
        n = engine_set.n
        await self.stop_engines(engine_set_id)
        await self.start_engines(n, engine_set_id)

    async def restart_engine(self, engine_id):
        """Restart one engine

        *May* stop all engines in a set,
        depending on EngineSet features (e.g. mpiexec)
        """
        raise NotImplementedError("How do we find an engine by id?")

    async def signal_engine(self, engine_id, signum):
        """Signal one engine

        *May* signal all engines in a set,
        depending on EngineSet features (e.g. mpiexec)
        """
        raise NotImplementedError("How do we find an engine by id?")

    async def signal_engines(self, engine_set_id, signum):
        """Signal all engines in a set"""
        engine_set = self._engine_sets[engine_set_id]
        engine_set.signal(signum)

    async def stop_controller(self):
        """Stop the controller"""
        if self._controller and self._controller.running:
            r = self._controller.stop()
            if inspect.isawaitable(r):
                await r

        self._controller = None

    async def stop_cluster(self):
        """Stop the controller and all engines"""
        await self.stop_engines()
        await self.stop_controller()

    def connect_client(self):
        """Return a client connected to the cluster"""
        # TODO: get connect info directly from controller
        # this assumes local files exist
        from ipyparallel import Client

        return Client(
            parent=self, profile_dir=self.profile_dir, cluster_id=self.cluster_id
        )

    # context managers (both async and sync)
    _context_client = None  # Client created by __enter__/__aenter__, if any

    async def __aenter__(self):
        await self.start_controller()
        await self.start_engines()
        client = self._context_client = self.connect_client()
        if self.n:
            # wait for engine registration
            # TODO: timeout
            while len(client) < self.n:
                await asyncio.sleep(0.1)
        return client

    async def __aexit__(self, *args):
        if self._context_client is not None:
            self._context_client.close()
            self._context_client = None
        await self.stop_engines()
        await self.stop_controller()

    def __enter__(self):
        # *_sync variants are provided by AsyncFirst
        self.start_controller_sync()
        self.start_engines_sync()
        client = self._context_client = self.connect_client()
        if self.n:
            # wait for engine registration
            while len(client) < self.n:
                time.sleep(0.1)
        return client

    def __exit__(self, *args):
        if self._context_client:
            self._context_client.close()
            self._context_client = None
        self.stop_engines_sync()
        self.stop_controller_sync()
| 301 | + |
class ClusterManager(LoggingConfigurable):
    """A manager of clusters

    Wraps Cluster, adding dispatch of Cluster methods by cluster_id."""

    _clusters = Dict(help="My cluster objects")

    def load_clusters(self):
        """Load serialized cluster state"""
        raise NotImplementedError()

    def list_clusters(self):
        """List current clusters"""

    def new_cluster(self, cluster_cls):
        """Create a new cluster"""

    def _cluster_method(self, method_name, cluster_id, *args, **kwargs):
        """Wrapper around single-cluster methods

        Defines ClusterManager.method(cluster_id, ...)

        which returns ClusterManager.clusters[cluster_id].method(...)
        """
        cluster = self._clusters[cluster_id]
        method = getattr(cluster, method_name)
        return method(*args, **kwargs)

    def __getattr__(self, key):
        # expose every Cluster method as manager.method(cluster_id, ...)
        if key in Cluster.__dict__:
            return partial(self._cluster_method, key)
        # no base class defines __getattr__, so fall back to standard
        # missing-attribute semantics (the original
        # `super().__getattr__(self, key)` passed `self` to a bound method
        # and would itself fail)
        raise AttributeError(key)
0 commit comments