# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
import os
import string
from pydantic import field_validator, model_validator, Field
from typing import List, Optional, Dict, Any
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
from deepspeed.inference.config import DtypeEnum
from deepspeed.launcher.runner import DLTS_HOSTFILE, fetch_hostfile
import mii.legacy as mii
from .constants import DeploymentType, TaskType, ModelProvider, MII_MODEL_PATH_DEFAULT
class ReplicaConfig(DeepSpeedConfigModel):
hostname: str = ""
tensor_parallel_ports: List[int] = []
torch_dist_port: Optional[int] = None
gpu_indices: List[int] = []
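# Illustrative example (not part of the original module; the hostname is
# hypothetical): a replica pinned to GPUs 0-1 on one host, as
# generate_replica_configs() further below would produce it:
#   ReplicaConfig(hostname="worker-0",
#                 tensor_parallel_ports=[50051, 50052],
#                 torch_dist_port=29500,
#                 gpu_indices=[0, 1])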
class ModelConfig(DeepSpeedConfigModel):
    model: str
    """
    Name of a supported model for the task. Models in MII are sourced from
    multiple open-source projects such as Hugging Face Transformers, FairSeq,
    EleutherAI, etc. For the list of supported models for each task, please
    see here [TODO].
    """
    task: TaskType
    """
    Name of the machine learning task to be deployed. Currently MII supports
    the following tasks: ``['text-generation', 'text-classification',
    'question-answering', 'fill-mask', 'token-classification',
    'text-to-image']``
    """
    dtype: torch.dtype = torch.float32
    """
    Desired model data type; the model will be converted to this type.
    Supported target types: `torch.half`, `torch.float`, `torch.int8` (for
    BLOOM models).
    """
    model_path: str = ""
    """
    In LOCAL deployments this is the local path where model checkpoints are
    available. In AML deployments this is an optional path relative to
    AZUREML_MODEL_DIR for the deployment.
    """
    load_with_sys_mem: bool = False
    """
    Loads the model into system memory instead of GPU memory. This can help
    avoid OOM errors when sharding a model across several GPUs, because MII
    initially tries to load a full copy of the model onto each GPU.
    """
    meta_tensor: bool = False
    """
    Loads the initial HuggingFace model using meta tensors, which use no
    memory. This can dramatically improve load time and reduce memory
    requirements for supported models. Supported for GPT-J, GPT-NeoX, OPT,
    and BLOOM when kernel injection is enabled; supported for all models when
    kernel injection is disabled.
    """
deploy_rank: Optional[List[int]] = None
"""
GPU indices a model is deployed on. Note that CUDA_VISIBLE_DEVICES does not
work with DeepSpeed-MII.
"""
torch_dist_port: int = 29500
"""
Torch distributed port.
"""
replica_num: int = 1
"""
Number of model replicas. Enables easy data parallelism.
"""
replica_configs: List[ReplicaConfig] = []
"""
Configuration details for each replica. This will be automatically
generated, but you can provide a set of custom configs.
"""
profile_model_time: bool = False
"""
Enable profiling of model times (i.e., without communication overhead).
"""
skip_model_check: bool = False
"""
Skip validation that a model supports a given task.
"""
hf_auth_token: Optional[str] = Field(
None,
json_schema_extra={
"deprecated":
True,
"deprecated_msg":
"Parameter will be removed. Please use the `pipeline_kwargs` field to pass kwargs to the HuggingFace pipeline creation."
},
)
"""
HuggingFace authentication token for accessing models. Will be propagated
to all ModelConfig if none are provided there.
"""
trust_remote_code: bool = Field(
False,
json_schema_extra={
"deprecated":
True,
"deprecated_msg":
"Parameter will be removed. Please use the `pipeline_kwargs` field to pass kwargs to the HuggingFace pipeline creation."
},
)
"""
    HuggingFace `transformers.pipeline` option for `trust_remote_code`.
"""
pipeline_kwargs: Dict[str, Any] = {}
"""
    kwargs to be passed to HuggingFace's `transformers.pipeline`.
"""
# TODO: Replace with DeepSpeedInferenceConfig
enable_deepspeed: bool = True
"""
Enable DeepSpeed-Inference.
"""
enable_zero: bool = False
"""
Enable Zero-Inference.
"""
ds_config: Dict[str, Any] = {}
"""
DeepSpeed config to use when Zero-Inference is enabled.
"""
tensor_parallel: int = 1
"""
Tensor parallelism to use for a model (i.e., how many GPUs to shard a model across).
"""
enable_cuda_graph: bool = False
"""
Enables CUDA Graph captures with DeepSpeed-Inference.
"""
replace_with_kernel_inject: bool = True
"""
Enable custom kernel injection with DeepSpeed-Inference.
"""
checkpoint_dict: Optional[Dict[str, Any]] = None
"""
DeepSpeed model checkpoint dict.
"""
    max_tokens: int = 1024
    """
    The maximum number of tokens DeepSpeed-Inference can work with, including
    the input and output tokens. Please consider increasing it to the maximum
    token length required for your use case.
    """
@property
def provider(self):
return mii.utils.get_provider(self.model, self.task)
@field_validator("checkpoint_dict", mode="after")
@classmethod
def checkpoint_dict_valid(cls, field_value):
if field_value is None:
return field_value
for k in ["checkpoints", "version", "type", "base_dir"]:
if not field_value.get(k, ""):
raise ValueError(f"Missing key={k} in checkpoint_dict")
return field_value
@field_validator("deploy_rank", mode="before")
@classmethod
def deploy_rank_to_list(cls, field_value):
if field_value and not isinstance(field_value, list):
field_value = [field_value]
return field_value
    @field_validator("dtype", mode="before")
    @classmethod
    def validate_dtype(cls, field_value, values):
if isinstance(field_value, str):
return DtypeEnum.from_str(field_value).value[0]
if isinstance(field_value, torch.dtype):
return field_value
raise TypeError(f"Invalid type for dtype: {type(field_value)}")
@model_validator(mode="after")
def zero_or_meta(self):
if self.enable_zero:
assert not self.meta_tensor, "ZeRO-Inference does not support meta tensors."
return self
@model_validator(mode="after")
def bloom_model_valid(self):
if "bigscience/bloom" in self.model:
            # TODO: Should be able to use DtypeEnum here
assert self.dtype in [
torch.int8,
torch.float16,
], "Bloom models only support fp16/int8."
assert not self.enable_cuda_graph, "Bloom models do not support CUDA Graph."
return self
@model_validator(mode="after")
def deploy_rank_valid(self):
deploy_rank = self.deploy_rank
# if deploy rank is not given, default to align with TP value
if deploy_rank is None:
deploy_rank = list(range(self.tensor_parallel))
# number of ranks provided must be equal to TP size, DP is handled outside MII currently
assert self.tensor_parallel == len(
deploy_rank
), f"{len(deploy_rank)} rank(s) provided in 'deploy_rank' does not align with tensor_parallel size of {self.tensor_parallel}"
self.__dict__["deploy_rank"] = deploy_rank
return self
@model_validator(mode="before")
@classmethod
def set_model_path(cls, values):
model_path = values.get("model_path")
if not model_path:
if values.get("deployment_type") == DeploymentType.AML:
model_path = "model"
else:
model_path = MII_MODEL_PATH_DEFAULT
aml_model_dir = os.environ.get("AZUREML_MODEL_DIR", None)
if aml_model_dir and not model_path.startswith(aml_model_dir):
            assert os.path.isabs(
                aml_model_dir
            ), f"AZUREML_MODEL_DIR={aml_model_dir} must be an absolute path."
            assert not os.path.isabs(
                model_path
            ), f"model_path={model_path} must be a relative path, to be appended to the AML model directory."
model_path = os.path.join(aml_model_dir, model_path)
values["model_path"] = model_path
return values
@model_validator(mode="after")
def validate_model_and_task(self):
if not self.skip_model_check:
mii.utils.check_if_task_and_model_is_valid(self.task, self.model)
mii.utils.check_if_task_and_model_is_supported(self.task, self.model)
return self
@model_validator(mode="after")
def meta_tensor_or_sys_mem(self):
if self.meta_tensor and self.load_with_sys_mem:
raise ValueError(
"`meta_tensor` and `load_with_sys_mem` cannot be active at the same time."
)
return self
@model_validator(mode="after")
def sys_mem_and_diffusers(self):
if self.load_with_sys_mem:
            assert mii.utils.get_provider(
                self.model, self.task
            ) != ModelProvider.DIFFUSERS, "`load_with_sys_mem` is not supported with Stable Diffusion."
return self
@model_validator(mode="after")
def zero_dtype_valid(self):
if self.enable_zero:
if self.ds_config.get("fp16", {}).get("enabled", False):
# TODO: We should be able to use DtypeEnum instead of torch.float
assert (
self.dtype == torch.float16
), "ZeRO FP16 enabled, `dtype` must be set to `torch.float16`"
else:
assert (
self.dtype == torch.float32
), "ZeRO FP16 disabled, `dtype` must be set to `torch.float32`"
return self
@model_validator(mode="after")
def deepspeed_or_zero(self):
assert not (
self.enable_deepspeed and self.enable_zero
), "DeepSpeed and ZeRO cannot both be enabled, select only one"
return self
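# A minimal ModelConfig sketch (illustrative values, not taken from this
# module; "bigscience/bloom-560m" is only an example model name):
#   ModelConfig(model="bigscience/bloom-560m", task="text-generation",
#               dtype=torch.float16, tensor_parallel=2, replica_num=2)
# With deploy_rank left unset, deploy_rank_valid() fills it in as [0, 1] to
# match tensor_parallel=2.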
class MIIConfig(DeepSpeedConfigModel):
deployment_name: str
"""
    Name of the deployment. Used as an identifier for obtaining an inference
    server client and posting queries.
"""
    deployment_type: DeploymentType = DeploymentType.LOCAL
    """
    One of the `DeploymentType` enum values: `[LOCAL, NON_PERSISTENT, AML]`.
    * `LOCAL` uses a gRPC server to create a local deployment.
    * `NON_PERSISTENT` creates a local deployment that will end when the process exits.
    * `AML` will generate the assets necessary to deploy on AML resources.
    """
model_conf: ModelConfig
"""
Configuration for the deployed model(s).
"""
port_number: int = 50050
"""
Port number to use for the load balancer process.
"""
enable_restful_api: bool = False
"""
    Enables a RESTful API that can be queried via HTTP POST requests.
"""
restful_api_port: int = 51080
"""
Port number to use for the RESTful API.
"""
hostfile: str = DLTS_HOSTFILE
"""
DeepSpeed hostfile. Will be autogenerated if None is provided.
"""
# TODO: Place AML-related configs in subconfig
version: int = 1
"""
Version number to pass to AML deployments.
"""
instance_type: str = "Standard_NC12s_v3"
"""
    AML instance type to use when creating AML deployment assets.
"""
@model_validator(mode="after")
def AML_name_valid(self):
if self.deployment_type == DeploymentType.AML:
allowed_chars = set(string.ascii_lowercase + string.ascii_uppercase +
string.digits + "-")
assert (
set(self.deployment_name) <= allowed_chars
), "AML deployment names can only contain a-z, A-Z, 0-9, and '-'."
return self
def generate_replica_configs(self):
# TODO: refactor this function
hostfile = self.hostfile
port_number = self.port_number
torch_dist_port = self.model_conf.torch_dist_port
tensor_parallel = self.model_conf.tensor_parallel
replica_num = self.model_conf.replica_num
replica_pool = _allocate_processes(hostfile, tensor_parallel, replica_num)
replica_configs = []
for i, (hostname, gpu_indices) in enumerate(replica_pool):
            # Reserve a port for the load balancer proxy when replication is enabled
port_offset = 1
base_port = port_number + i * tensor_parallel + port_offset
tensor_parallel_ports = list(range(base_port, base_port + tensor_parallel))
replica_torch_dist_port = torch_dist_port + (100 * i)
replica_configs.append(
ReplicaConfig(
hostname=hostname,
tensor_parallel_ports=tensor_parallel_ports,
torch_dist_port=replica_torch_dist_port,
gpu_indices=gpu_indices,
))
self.model_conf.replica_configs = replica_configs
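# Worked example for generate_replica_configs() (illustrative numbers,
# assuming port_number=50050, tensor_parallel=2, torch_dist_port=29500,
# replica_num=2): replica 0 gets tensor_parallel_ports=[50051, 50052] and
# torch_dist_port=29500; replica 1 gets [50053, 50054] and 29600. Port 50050
# itself stays free for the load balancer proxy.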
def _allocate_processes(hostfile_path, tensor_parallel, replica_num):
resource_pool = fetch_hostfile(hostfile_path)
assert (
resource_pool is not None and len(resource_pool) > 0
), f"No hosts found in {hostfile_path}"
replica_pool = []
allocated_num = 0
for host, slots in resource_pool.items():
available_on_host = slots
while available_on_host >= tensor_parallel:
if allocated_num >= replica_num:
break
if slots < tensor_parallel:
raise ValueError(
f"Host {host} has {slots} slot(s), but {tensor_parallel} slot(s) are required"
)
allocated_num_on_host = slots - available_on_host
replica_pool.append((
host,
[
i for i in range(
allocated_num_on_host,
allocated_num_on_host + tensor_parallel,
)
],
))
allocated_num += 1
available_on_host -= tensor_parallel
if allocated_num < replica_num:
        raise ValueError(
            f"Insufficient GPUs for {replica_num} replica(s); only {allocated_num} replica(s) can be deployed"
        )
return replica_pool
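# Worked example for _allocate_processes() (illustrative): with a hostfile
# exposing {"worker-0": 4} slots, tensor_parallel=2, and replica_num=2, the
# returned replica_pool is [("worker-0", [0, 1]), ("worker-0", [2, 3])].

# Usage sketch (an assumption for illustration, not part of the original
# module). The model name is only an example; `skip_model_check=True` keeps
# the sketch self-contained by skipping the model/task validation lookup.
if __name__ == "__main__":
    example_model_conf = ModelConfig(
        model="bigscience/bloom-560m",  # hypothetical model choice
        task="text-generation",
        dtype=torch.float16,
        tensor_parallel=1,
        skip_model_check=True,
    )
    example_mii_conf = MIIConfig(
        deployment_name="example-deployment",
        model_conf=example_model_conf,
    )
    # generate_replica_configs() is not called here because it reads a
    # DeepSpeed hostfile from disk.
    print(example_mii_conf.model_conf.model_path)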