@@ -25,6 +25,17 @@
 class EarlyStopException(XGBoostError):
     pass
 
+
+# From xgboost>=1.7.0, rabit is replaced by a collective communicator
+try:
+    from xgboost.collective import CommunicatorContext
+    rabit = None
+    HAS_COLLECTIVE = True
+except ImportError:
+    from xgboost import rabit  # noqa
+    CommunicatorContext = None
+    HAS_COLLECTIVE = False
+
 from xgboost_ray.callback import DistributedCallback, \
     DistributedCallbackContainer
 from xgboost_ray.compat import TrainingCallback, RabitTracker, LEGACY_CALLBACK
@@ -66,7 +77,7 @@ def inner_f(*args, **kwargs):
     RayDeviceQuantileDMatrix, RayDataIter, concat_dataframes, \
     LEGACY_MATRIX
 from xgboost_ray.session import init_session, put_queue, \
-    set_session_queue
+    set_session_queue, get_rabit_rank
 
 
 def _get_environ(item: str, old_val: Any):
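
The `get_rabit_rank` helper imported here is what lets the callback code
further down work on both backends. As a rough sketch (the real helper lives
in `xgboost_ray/session.py`; only the public `get_rank()` calls are assumed):

    # Hypothetical sketch: dispatch the rank lookup based on the
    # HAS_COLLECTIVE flag set by the import shim above.
    def get_rabit_rank() -> int:
        if HAS_COLLECTIVE:
            from xgboost import collective
            return collective.get_rank()  # xgboost>=1.7 communicator API
        from xgboost import rabit
        return rabit.get_rank()  # legacy rabit API (xgboost<1.7)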
@@ -237,25 +248,40 @@ def _stop_rabit_tracker(rabit_process: multiprocessing.Process):
     rabit_process.terminate()
 
 
-class _RabitContext:
+class _RabitContextBase:
     """This context is used by local training actors to connect to the
     Rabit tracker.
 
     Args:
         actor_id (str): Unique actor ID
-        args (list): Arguments for Rabit initialisation. These are
+        args (dict): Arguments for Rabit initialisation. These are
             environment variables to configure Rabit clients.
     """
 
-    def __init__(self, actor_id, args):
+    def __init__(self, actor_id: str, args: dict):
+        args["DMLC_TASK_ID"] = "[xgboost.ray]:" + actor_id
         self.args = args
-        self.args.append(("DMLC_TASK_ID=[xgboost.ray]:" + actor_id).encode())
 
-    def __enter__(self):
-        xgb.rabit.init(self.args)
 
-    def __exit__(self, *args):
-        xgb.rabit.finalize()
+# From xgboost>=1.7.0, rabit is replaced by a collective communicator
+if HAS_COLLECTIVE:
+
+    class _RabitContext(_RabitContextBase, CommunicatorContext):
+        pass
+
+else:
+
+    class _RabitContext(_RabitContextBase):
+        def __init__(self, actor_id: str, args: dict):
+            super().__init__(actor_id, args)
+            self._list_args = [("%s=%s" % item).encode()
+                               for item in self.args.items()]
+
+        def __enter__(self):
+            xgb.rabit.init(self._list_args)
+
+        def __exit__(self, *args):
+            xgb.rabit.finalize()
 
 
 def _ray_get_actor_cpus():
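
For context, a training actor would wrap its `xgb.train()` call in this
context manager. On xgboost>=1.7 the inherited `CommunicatorContext` is
expected to consume `self.args` as environment variables when initialising
the collective communicator; on older versions the legacy subclass
byte-encodes the dict into the `KEY=VALUE` list that `rabit.init()` expects.
An illustrative sketch (`actor_rank`, `params` and `dtrain` are made up):

    # rabit_args is the env dict returned by _start_rabit_tracker (below);
    # str(actor_rank) becomes part of the DMLC_TASK_ID.
    with _RabitContext(str(actor_rank), rabit_args):
        bst = xgb.train(params, dtrain, num_boost_round=10)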
@@ -517,12 +543,12 @@ def _save_checkpoint_callback(self):
 
         class _SaveInternalCheckpointCallback(TrainingCallback):
             def after_iteration(self, model, epoch, evals_log):
-                if xgb.rabit.get_rank() == 0 and \
+                if get_rabit_rank() == 0 and \
                         epoch % this.checkpoint_frequency == 0:
                     put_queue(_Checkpoint(epoch, pickle.dumps(model)))
 
             def after_training(self, model):
-                if xgb.rabit.get_rank() == 0:
+                if get_rabit_rank() == 0:
                     put_queue(_Checkpoint(-1, pickle.dumps(model)))
                 return model
 
@@ -1054,8 +1080,7 @@ def handle_actor_failure(actor_id):
     maybe_log("[RayXGBoost] Starting XGBoost training.")
 
     # Start Rabit tracker for gradient sharing
-    rabit_process, env = _start_rabit_tracker(alive_actors)
-    rabit_args = [("%s=%s" % item).encode() for item in env.items()]
+    rabit_process, rabit_args = _start_rabit_tracker(alive_actors)
 
     # Load checkpoint if we have one. In that case we need to adjust the
     # number of training rounds.
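
With this change the tracker helper hands back the environment as a plain
dict, and the byte-encoding that used to happen at this call site now only
runs on the legacy rabit path (in `_RabitContext.__init__` above). For
illustration, with made-up values:

    # Made-up tracker env; the real keys come from the started RabitTracker.
    env = {"DMLC_TRACKER_URI": "10.0.0.1", "DMLC_TRACKER_PORT": 9091}
    legacy_args = [("%s=%s" % item).encode() for item in env.items()]
    # -> [b'DMLC_TRACKER_URI=10.0.0.1', b'DMLC_TRACKER_PORT=9091']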