Skip to content

Commit dd51311

Browse files
authored
Compatibility for xgboost>=1.7.0, fix master CI (#242)
* Update requirements-test.txt Signed-off-by: Antoni Baum <[email protected]> * Update requirements-test.txt Signed-off-by: Antoni Baum <[email protected]> * Update requirements-test.txt Signed-off-by: Antoni Baum <[email protected]> * Update requirements-test.txt Signed-off-by: Antoni Baum <[email protected]> * Compatibility for xgboost 1.7.0 Signed-off-by: Antoni Baum <[email protected]> * Fix MRO Signed-off-by: Antoni Baum <[email protected]> Signed-off-by: Antoni Baum <[email protected]>
1 parent d0647bc commit dd51311

File tree

3 files changed

+45
-15
lines changed

3 files changed

+45
-15
lines changed

requirements-test.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@ packaging
66
petastorm
77
pytest
88
pyarrow
9-
ray[tune]
9+
ray[tune, data]
1010
scikit-learn
1111
modin
1212
dask
1313

1414
#workaround for now
15+
protobuf<4.0.0
1516
tensorboardX==2.2

xgboost_ray/main.py

+38-13
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,17 @@
2525
class EarlyStopException(XGBoostError):
2626
pass
2727

28+
29+
# From xgboost>=1.7.0, rabit is replaced by a collective communicator
30+
try:
31+
from xgboost.collective import CommunicatorContext
32+
rabit = None
33+
HAS_COLLECTIVE = True
34+
except ImportError:
35+
from xgboost import rabit # noqa
36+
CommunicatorContext = None
37+
HAS_COLLECTIVE = False
38+
2839
from xgboost_ray.callback import DistributedCallback, \
2940
DistributedCallbackContainer
3041
from xgboost_ray.compat import TrainingCallback, RabitTracker, LEGACY_CALLBACK
@@ -66,7 +77,7 @@ def inner_f(*args, **kwargs):
6677
RayDeviceQuantileDMatrix, RayDataIter, concat_dataframes, \
6778
LEGACY_MATRIX
6879
from xgboost_ray.session import init_session, put_queue, \
69-
set_session_queue
80+
set_session_queue, get_rabit_rank
7081

7182

7283
def _get_environ(item: str, old_val: Any):
@@ -237,25 +248,40 @@ def _stop_rabit_tracker(rabit_process: multiprocessing.Process):
237248
rabit_process.terminate()
238249

239250

240-
class _RabitContext:
251+
class _RabitContextBase:
241252
"""This context is used by local training actors to connect to the
242253
Rabit tracker.
243254
244255
Args:
245256
actor_id (str): Unique actor ID
246-
args (list): Arguments for Rabit initialisation. These are
257+
args (dict): Arguments for Rabit initialisation. These are
247258
environment variables to configure Rabit clients.
248259
"""
249260

250-
def __init__(self, actor_id, args):
261+
def __init__(self, actor_id: int, args: dict):
262+
args["DMLC_TASK_ID"] = "[xgboost.ray]:" + actor_id
251263
self.args = args
252-
self.args.append(("DMLC_TASK_ID=[xgboost.ray]:" + actor_id).encode())
253264

254-
def __enter__(self):
255-
xgb.rabit.init(self.args)
256265

257-
def __exit__(self, *args):
258-
xgb.rabit.finalize()
266+
# From xgboost>=1.7.0, rabit is replaced by a collective communicator
267+
if HAS_COLLECTIVE:
268+
269+
class _RabitContext(_RabitContextBase, CommunicatorContext):
270+
pass
271+
272+
else:
273+
274+
class _RabitContext(_RabitContextBase):
275+
def __init__(self, actor_id: int, args: dict):
276+
super().__init__(actor_id, args)
277+
self._list_args = [("%s=%s" % item).encode()
278+
for item in self.args.items()]
279+
280+
def __enter__(self):
281+
xgb.rabit.init(self._list_args)
282+
283+
def __exit__(self, *args):
284+
xgb.rabit.finalize()
259285

260286

261287
def _ray_get_actor_cpus():
@@ -517,12 +543,12 @@ def _save_checkpoint_callback(self):
517543

518544
class _SaveInternalCheckpointCallback(TrainingCallback):
519545
def after_iteration(self, model, epoch, evals_log):
520-
if xgb.rabit.get_rank() == 0 and \
546+
if get_rabit_rank() == 0 and \
521547
epoch % this.checkpoint_frequency == 0:
522548
put_queue(_Checkpoint(epoch, pickle.dumps(model)))
523549

524550
def after_training(self, model):
525-
if xgb.rabit.get_rank() == 0:
551+
if get_rabit_rank() == 0:
526552
put_queue(_Checkpoint(-1, pickle.dumps(model)))
527553
return model
528554

@@ -1054,8 +1080,7 @@ def handle_actor_failure(actor_id):
10541080
maybe_log("[RayXGBoost] Starting XGBoost training.")
10551081

10561082
# Start Rabit tracker for gradient sharing
1057-
rabit_process, env = _start_rabit_tracker(alive_actors)
1058-
rabit_args = [("%s=%s" % item).encode() for item in env.items()]
1083+
rabit_process, rabit_args = _start_rabit_tracker(alive_actors)
10591084

10601085
# Load checkpoint if we have one. In that case we need to adjust the
10611086
# number of training rounds.

xgboost_ray/session.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,11 @@ def get_actor_rank() -> int:
6363
@PublicAPI
6464
def get_rabit_rank() -> int:
6565
import xgboost as xgb
66-
return xgb.rabit.get_rank()
66+
try:
67+
# From xgboost>=1.7.0, rabit is replaced by a collective communicator
68+
return xgb.collective.get_rank()
69+
except (ImportError, AttributeError):
70+
return xgb.rabit.get_rank()
6771

6872

6973
@PublicAPI

0 commit comments

Comments
 (0)