Skip to content

Commit 7f6fc1f

Browse files
change: do init checking and marking in one func
1 parent 62067cc commit 7f6fc1f

File tree

2 files changed

+41
-63
lines changed

2 files changed

+41
-63
lines changed

deepspeed/__init__.py

+36-38
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
import sys
77
import types
88
import json
9-
from typing import Optional, Union
9+
from typing import Callable, Optional, Union
1010
import torch
1111
from torch.optim import Optimizer
1212
from torch.optim.lr_scheduler import _LRScheduler
@@ -27,6 +27,8 @@
2727

2828
from .accelerator import get_accelerator
2929
from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT
30+
from .runtime.base_optimizer import DeepSpeedOptimizer
31+
from .runtime.dataloader import DeepSpeedDataLoader, RepeatingLoader
3032
from .runtime.engine import DeepSpeedEngine, DeepSpeedOptimizerCallable, DeepSpeedSchedulerCallable
3133
from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
3234
from .runtime.hybrid_engine import DeepSpeedHybridEngine
@@ -65,46 +67,44 @@ def _parse_version(version_str):
6567
# Set to torch's distributed package or deepspeed.comm based inside DeepSpeedEngine init
6668
dist = None
6769

70+
DS_PRIM_TYPES = (DeepSpeedEngine, DeepSpeedHybridEngine, DeepSpeedOptimizer, DeepSpeedDataLoader, RepeatingLoader)
71+
6872

6973
def _mark_ds_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
7074
"""Mark a trainobj as initialized by setting the ds_is_inited attribute to True."""
71-
trainobj.ds_is_inited = True
75+
if not isinstance(trainobj, DS_PRIM_TYPES): # only mark non-DeepSpeed objects
76+
trainobj.ds_is_inited = True
7277

7378

7479
def _is_ds_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
7580
"""Check if a trainobj has been initialized by checking the ds_is_inited attribute."""
76-
return getattr(trainobj, 'ds_is_inited', False)
77-
78-
79-
def _assert_trainobjs_not_inited(model: torch.nn.Module, optimizer: Optional[Union[Optimizer,
80-
DeepSpeedOptimizerCallable]],
81-
lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]]):
82-
"""Enforce the model, optimizer, and lr_scheduler have not been used in a previous deepspeed.initialize call."""
83-
if _is_ds_initialized(model):
84-
raise ValueError(
85-
"Model has already been initialized, please make sure to only call deepspeed.initialize on a model once.")
86-
if optimizer is not None and isinstance(optimizer, Optimizer) and _is_ds_initialized(optimizer):
87-
raise ValueError(
88-
"Optimizer has already been initialized, please make sure to only call deepspeed.initialize on an optimizer once."
89-
)
90-
if lr_scheduler is not None and isinstance(lr_scheduler, _LRScheduler) and _is_ds_initialized(lr_scheduler):
91-
raise ValueError(
92-
"LR scheduler has already been initialized, please make sure to only call deepspeed.initialize on an LR scheduler once."
93-
)
94-
95-
96-
def _mark_trainobjs_initialized(model: torch.nn.Module, optimizer: Optional[Union[Optimizer,
97-
DeepSpeedOptimizerCallable]],
98-
lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]]):
99-
"""Mark the model, optimizer, and lr_scheduler as initialized.
100-
Note that callables of type DeepSpeedOptimizerCallable and DeepSpeedSchedulerCallable are not marked
101-
as they are not stateful and reuse should be permissible.
102-
"""
103-
_mark_ds_initialized(model)
104-
if optimizer is not None and isinstance(optimizer, Optimizer):
105-
_mark_ds_initialized(optimizer)
106-
if lr_scheduler is not None and isinstance(lr_scheduler, _LRScheduler):
107-
_mark_ds_initialized(lr_scheduler)
81+
if isinstance(trainobj, DS_PRIM_TYPES):
82+
return True
83+
else:
84+
return getattr(trainobj, 'ds_is_inited', False)
85+
86+
87+
def _ensure_and_mark_trainobjs_inited(
88+
model: torch.nn.Module,
89+
optimizer: Optional[Union[Optimizer, DeepSpeedOptimizerCallable]],
90+
lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]],
91+
ensures_not_inited: bool = False,
92+
):
93+
trainobjs = {"model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler}
94+
95+
for name, trainobj in trainobjs.items():
96+
print(f"Checking {name}")
97+
if trainobj is None:
98+
continue
99+
if name in ("optimizer", "lr_scheduler") and not isinstance(trainobj, (Optimizer, _LRScheduler)):
100+
# skipping DeepSpeedOptimizerCallable and DeepSpeedSchedulerCallable
101+
continue
102+
if ensures_not_inited:
103+
if _is_ds_initialized(trainobj):
104+
raise ValueError(
105+
f"{name} has already been initialized, please make sure to only call deepspeed.initialize on a {name} once."
106+
)
107+
_mark_ds_initialized(trainobj)
108108

109109

110110
def initialize(args=None,
@@ -179,9 +179,7 @@ def initialize(args=None,
179179

180180
assert model is not None, "deepspeed.initialize requires a model"
181181
# enforce that model, optimizer, and lr_scheduler have not been used in a previous deepspeed.initialize call
182-
_assert_trainobjs_not_inited(model, optimizer, lr_scheduler)
183-
# mark model, optimizer, and lr_scheduler as initialized
184-
_mark_trainobjs_initialized(model, optimizer, lr_scheduler)
182+
_ensure_and_mark_trainobjs_inited(model, optimizer, lr_scheduler, ensures_not_inited=True)
185183

186184
global dist
187185
from deepspeed import comm as dist
@@ -267,7 +265,7 @@ def initialize(args=None,
267265
zero.partition_parameters.restore_init_context()
268266

269267
# mark engine, optimizer, and lr_scheduler as initialized
270-
_mark_trainobjs_initialized(engine, engine.optimizer, engine.lr_scheduler)
268+
_ensure_and_mark_trainobjs_inited(engine, engine.optimizer, engine.lr_scheduler, ensures_not_inited=False)
271269

272270
return_items = [
273271
engine,

tests/unit/runtime/test_ds_initialize.py

+5-25
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@
2121
from deepspeed.utils.torch import required_torch_version
2222
from deepspeed.accelerator import get_accelerator
2323
from deepspeed.ops.op_builder import FusedAdamBuilder
24-
from deepspeed import _assert_trainobjs_not_inited, _is_ds_initialized
24+
from deepspeed import _is_ds_initialized
2525

2626

2727
@pytest.mark.parametrize('zero_stage', [0, 3])
@@ -459,7 +459,6 @@ def _optimizer_callable(params) -> Optimizer:
459459
client_optimizer = _optimizer_callable
460460

461461
# Initialize DeepSpeed engine
462-
_assert_trainobjs_not_inited(model=model, optimizer=client_optimizer, lr_scheduler=None)
463462
model_engine, optim, _, _ = deepspeed.initialize(model=model,
464463
optimizer=client_optimizer,
465464
config_params=config_dict)
@@ -473,33 +472,14 @@ def _optimizer_callable(params) -> Optimizer:
473472
assert _is_ds_initialized(model_engine), "Model engine should be marked as initialized"
474473
assert _is_ds_initialized(optim), "Optimizer should be marked as initialized"
475474

476-
exception_raised = False
477-
try:
475+
with pytest.raises(ValueError):
478476
deepspeed.initialize(model=model, optimizer=client_optimizer, config_params=config_dict)
479-
except ValueError:
480-
exception_raised = True
481477

482-
assert exception_raised, "Repeated initialization should raise an exception"
483-
484-
exception_raised = False
485-
try:
478+
with pytest.raises(ValueError):
486479
deepspeed.initialize(model=model_engine, optimizer=client_optimizer, config_params=config_dict)
487-
except ValueError:
488-
exception_raised = True
489-
490-
assert exception_raised, "Initialization on ds types should raise an exception"
491480

492-
exception_raised = False
493-
try:
481+
with pytest.raises(ValueError):
494482
deepspeed.initialize(model=model, optimizer=optim, config_params=config_dict)
495-
except ValueError:
496-
exception_raised = True
497-
498-
assert exception_raised, "Initialization on ds types should raise an exception"
499483

500-
exception_raised = False
501-
try:
484+
with pytest.raises(ValueError):
502485
deepspeed.initialize(model=model_engine, optimizer=optim, config_params=config_dict)
503-
except ValueError:
504-
exception_raised = True
505-
assert exception_raised, "Initialization on ds types should raise an exception"

0 commit comments

Comments (0)