|
6 | 6 | import sys
|
7 | 7 | import types
|
8 | 8 | import json
|
9 |
| -from typing import Optional, Union |
| 9 | +from typing import Callable, Optional, Union |
10 | 10 | import torch
|
11 | 11 | from torch.optim import Optimizer
|
12 | 12 | from torch.optim.lr_scheduler import _LRScheduler
|
|
27 | 27 |
|
28 | 28 | from .accelerator import get_accelerator
|
29 | 29 | from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT
|
| 30 | +from .runtime.base_optimizer import DeepSpeedOptimizer |
| 31 | +from .runtime.dataloader import DeepSpeedDataLoader, RepeatingLoader |
30 | 32 | from .runtime.engine import DeepSpeedEngine, DeepSpeedOptimizerCallable, DeepSpeedSchedulerCallable
|
31 | 33 | from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
|
32 | 34 | from .runtime.hybrid_engine import DeepSpeedHybridEngine
|
@@ -65,46 +67,44 @@ def _parse_version(version_str):
|
# Set to torch's distributed package or deepspeed.comm based inside DeepSpeedEngine init
dist = None

# DeepSpeed's own runtime wrapper/primitive types. Objects of these types are treated as
# already DeepSpeed-managed: _is_ds_initialized reports True for them and
# _mark_ds_initialized never tags them with the ds_is_inited attribute.
DS_PRIM_TYPES = (DeepSpeedEngine, DeepSpeedHybridEngine, DeepSpeedOptimizer, DeepSpeedDataLoader, RepeatingLoader)
68 | 72 |
|
def _mark_ds_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
    """Mark a trainobj as initialized by setting the ds_is_inited attribute to True."""
    # DeepSpeed's own wrapper/primitive types are implicitly initialized; never tag them.
    if isinstance(trainobj, DS_PRIM_TYPES):
        return
    trainobj.ds_is_inited = True
72 | 77 |
|
73 | 78 |
|
def _is_ds_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
    """Check if a trainobj has been initialized by checking the ds_is_inited attribute."""
    # DeepSpeed primitives count as initialized by construction; anything else is
    # initialized only if it carries the ds_is_inited marker attribute.
    return isinstance(trainobj, DS_PRIM_TYPES) or getattr(trainobj, 'ds_is_inited', False)
| 85 | + |
| 86 | + |
def _ensure_and_mark_trainobjs_inited(
    model: torch.nn.Module,
    optimizer: Optional[Union[Optimizer, DeepSpeedOptimizerCallable]],
    lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]],
    ensures_not_inited: bool = False,
):
    """Mark the model, optimizer, and lr_scheduler as initialized by deepspeed.

    Callables of type DeepSpeedOptimizerCallable and DeepSpeedSchedulerCallable are
    skipped entirely: they are stateless, so reusing them across multiple
    deepspeed.initialize calls is permissible.

    Args:
        model: the model passed to (or produced by) deepspeed.initialize.
        optimizer: optimizer instance or optimizer-producing callable; only
            Optimizer instances are checked/marked.
        lr_scheduler: scheduler instance or scheduler-producing callable; only
            _LRScheduler instances are checked/marked.
        ensures_not_inited: when True, raise if any object was already used in a
            previous deepspeed.initialize call before marking it.

    Raises:
        ValueError: if ensures_not_inited is True and an object is already initialized.
    """
    trainobjs = {"model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler}

    for name, trainobj in trainobjs.items():
        if trainobj is None:
            continue
        if name in ("optimizer", "lr_scheduler") and not isinstance(trainobj, (Optimizer, _LRScheduler)):
            # skipping DeepSpeedOptimizerCallable and DeepSpeedSchedulerCallable
            continue
        if ensures_not_inited and _is_ds_initialized(trainobj):
            raise ValueError(
                f"{name} has already been initialized, please make sure to only call deepspeed.initialize on a {name} once."
            )
        _mark_ds_initialized(trainobj)
108 | 108 |
|
109 | 109 |
|
110 | 110 | def initialize(args=None,
|
@@ -179,9 +179,7 @@ def initialize(args=None,
|
179 | 179 |
|
180 | 180 | assert model is not None, "deepspeed.initialize requires a model"
|
181 | 181 | # enforce that model, optimizer, and lr_scheduler have not been used in a previous deepspeed.initialize call
|
182 |
| - _assert_trainobjs_not_inited(model, optimizer, lr_scheduler) |
183 |
| - # mark model, optimizer, and lr_scheduler as initialized |
184 |
| - _mark_trainobjs_initialized(model, optimizer, lr_scheduler) |
| 182 | + _ensure_and_mark_trainobjs_inited(model, optimizer, lr_scheduler, ensures_not_inited=True) |
185 | 183 |
|
186 | 184 | global dist
|
187 | 185 | from deepspeed import comm as dist
|
@@ -267,7 +265,7 @@ def initialize(args=None,
|
267 | 265 | zero.partition_parameters.restore_init_context()
|
268 | 266 |
|
269 | 267 | # mark engine, optimizer, and lr_scheduler as initialized
|
270 |
| - _mark_trainobjs_initialized(engine, engine.optimizer, engine.lr_scheduler) |
| 268 | + _ensure_and_mark_trainobjs_inited(engine, engine.optimizer, engine.lr_scheduler, ensures_not_inited=False) |
271 | 269 |
|
272 | 270 | return_items = [
|
273 | 271 | engine,
|
|
0 commit comments