-
Notifications
You must be signed in to change notification settings - Fork 64
/
Copy pathmodel_endpoints.py
161 lines (129 loc) · 5.3 KB
/
model_endpoints.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
"""
Contains various input and output types relating to Model Endpoints for the server.
TODO figure out how to do: (or if we want to do it)
List model endpoint history: GET model-endpoints/<endpoint id>/history
Read model endpoint creation logs: GET model-endpoints/<endpoint id>/creation-logs
"""
import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from model_engine_server.common.dtos.core import HttpUrlStr
from model_engine_server.common.pydantic_types import BaseModel, ConfigDict, Field
from model_engine_server.domain.entities import (
CallbackAuth,
CpuSpecificationType,
GpuType,
ModelEndpointDeploymentState,
ModelEndpointResourceState,
ModelEndpointsSchema,
ModelEndpointStatus,
ModelEndpointType,
StorageSpecificationType,
)
class BrokerType(str, Enum):
    """Broker types available for async endpoints.

    Members are ``str`` subclasses, so each one compares equal to its
    raw string value.
    """

    REDIS = "redis"
    REDIS_24H = "redis_24h"
    SQS = "sqs"
    SERVICEBUS = "servicebus"
class BrokerName(str, Enum):
    """Broker names available for async endpoints.

    The broker name is only used in endpoint k8s annotations for the
    celery autoscaler.
    """

    REDIS = "redis-message-broker-master"
    REDIS_GCP = "redis-gcp-memorystore-message-broker-master"
    SQS = "sqs-message-broker-master"
    SERVICEBUS = "servicebus-message-broker-master"
class CreateModelEndpointV1Request(BaseModel):
    """Input type for creating a model endpoint (V1).

    Fields declared without a default are required; the remaining fields
    may be omitted by the caller.
    """

    # NOTE(review): the 63-character cap presumably mirrors a Kubernetes
    # name-length limit -- confirm before relaxing it.
    name: str = Field(..., max_length=63)
    model_bundle_id: str
    endpoint_type: ModelEndpointType
    metadata: Dict[str, Any]  # TODO: JSON type
    post_inference_hooks: Optional[List[str]] = None

    # Resource fields; gpus must be >= 0 and nodes_per_worker must be > 0.
    cpus: CpuSpecificationType
    gpus: int = Field(..., ge=0)
    memory: StorageSpecificationType
    gpu_type: Optional[GpuType] = None
    storage: StorageSpecificationType
    nodes_per_worker: int = Field(default=1, gt=0)
    optimize_costs: Optional[bool] = None

    # Worker settings; worker counts may be zero, per_worker must be > 0.
    min_workers: int = Field(..., ge=0)
    max_workers: int = Field(..., ge=0)
    per_worker: int = Field(..., gt=0)
    # Left as None here when omitted; will default to per_worker.
    concurrent_requests_per_worker: Optional[int] = Field(default=None, gt=0)

    labels: Dict[str, str]
    prewarm: Optional[bool] = None
    high_priority: Optional[bool] = None
    billing_tags: Optional[Dict[str, Any]] = None
    default_callback_url: Optional[HttpUrlStr] = None
    default_callback_auth: Optional[CallbackAuth] = None
    public_inference: Optional[bool] = False
class CreateModelEndpointV1Response(BaseModel):
    """Output type returned for a model endpoint creation request (V1)."""

    # Required string identifier of the endpoint creation task.
    endpoint_creation_task_id: str
class UpdateModelEndpointV1Request(BaseModel):
    """Input type for updating a model endpoint (V1).

    Every field is optional and defaults to ``None``.
    """

    model_bundle_id: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None  # TODO: JSON type
    post_inference_hooks: Optional[List[str]] = None

    # Resource fields; gpus, when given, must be >= 0.
    cpus: Optional[CpuSpecificationType] = None
    gpus: Optional[int] = Field(None, ge=0)
    memory: Optional[StorageSpecificationType] = None
    gpu_type: Optional[GpuType] = None
    storage: Optional[StorageSpecificationType] = None
    optimize_costs: Optional[bool] = None

    # Worker settings; when given, counts must be >= 0 and the per-worker
    # values must be > 0.
    min_workers: Optional[int] = Field(None, ge=0)
    max_workers: Optional[int] = Field(None, ge=0)
    per_worker: Optional[int] = Field(None, gt=0)
    concurrent_requests_per_worker: Optional[int] = Field(None, gt=0)

    labels: Optional[Dict[str, str]] = None
    prewarm: Optional[bool] = None
    high_priority: Optional[bool] = None
    billing_tags: Optional[Dict[str, Any]] = None
    default_callback_url: Optional[HttpUrlStr] = None
    default_callback_auth: Optional[CallbackAuth] = None
    public_inference: Optional[bool] = None
class UpdateModelEndpointV1Response(BaseModel):
    """Output type returned for a model endpoint update request (V1)."""

    # Required string task identifier; the field name matches the one used
    # by the create response.
    endpoint_creation_task_id: str
class GetModelEndpointV1Response(BaseModel):
    """Output type describing a single model endpoint (V1).

    Fields without a default are always present; the optional ones
    default to ``None``.
    """

    id: str
    name: str
    endpoint_type: ModelEndpointType
    destination: str
    deployment_name: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None  # TODO: JSON type
    bundle_name: str
    status: ModelEndpointStatus
    post_inference_hooks: Optional[List[str]] = None
    default_callback_url: Optional[HttpUrlStr] = None
    default_callback_auth: Optional[CallbackAuth] = None
    labels: Optional[Dict[str, str]] = None
    aws_role: Optional[str] = None
    results_s3_bucket: Optional[str] = None

    # Creator and timestamps.
    created_by: str
    created_at: datetime.datetime
    last_updated_at: datetime.datetime

    # State fields; each may be absent.
    deployment_state: Optional[ModelEndpointDeploymentState] = None
    resource_state: Optional[ModelEndpointResourceState] = None
    num_queued_items: Optional[int] = None
    public_inference: Optional[bool] = None
class ListModelEndpointsV1Response(BaseModel):
    """Output type wrapping a list of model endpoint descriptions (V1)."""

    # One GetModelEndpointV1Response entry per endpoint.
    model_endpoints: List[GetModelEndpointV1Response]
class DeleteModelEndpointV1Response(BaseModel):
    """Output type returned for a model endpoint delete request (V1)."""

    # Required boolean flag reported back to the caller.
    deleted: bool
class RestartModelEndpointV1Response(BaseModel):
    """Output type returned for a model endpoint restart request (V1)."""

    # Required boolean flag reported back to the caller.
    restarted: bool
class ModelEndpointOrderBy(str, Enum):
    """The canonical list of possible orderings of Model Endpoints.

    Members are ``str`` subclasses, so each one compares equal to its
    raw string value.
    """

    NEWEST = "newest"
    OLDEST = "oldest"
    ALPHABETICAL = "alphabetical"
class GetModelEndpointsSchemaV1Response(BaseModel):
    """Output type carrying the model endpoints schema (V1)."""

    # Empty protected namespaces: the field below starts with "model_",
    # which pydantic's default protected namespace would otherwise flag.
    model_config = ConfigDict(protected_namespaces=())

    model_endpoints_schema: ModelEndpointsSchema
# TODO history + creation logs