Skip to content

Commit c534d34

Browse files
authoredSep 19, 2023
Merge pull request #48 from roboflow/fix/license-server-bug
License Server Bug Fix
2 parents 2f8f773 + ef3eca3 commit c534d34

File tree

5 files changed

+41
-44
lines changed

5 files changed

+41
-44
lines changed
 

‎.github/workflows/test.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ jobs:
3636
with:
3737
push: false
3838
tags: roboflow/roboflow-inference-server-cpu:test
39-
cache-from: type=registry,ref=roboflow/roboflow-inference-server-cpu:test
40-
cache-to: type=registry,ref=roboflow/roboflow-inference-server-cpu:test,mode=max
39+
cache-from: type=registry,ref=roboflow/roboflow-inference-server-cpu:test-cache
40+
cache-to: type=registry,ref=roboflow/roboflow-inference-server-cpu:test-cache,mode=max
4141
platforms: linux/amd64
4242
file: ./docker/dockerfiles/Dockerfile.onnx.cpu
4343
outputs: type=docker

‎inference/core/models/roboflow.py

+10-38
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
AWS_SECRET_ACCESS_KEY,
2929
CORE_MODEL_BUCKET,
3030
INFER_BUCKET,
31-
LICENSE_SERVER,
3231
MODEL_CACHE_DIR,
3332
ONNXRUNTIME_EXECUTION_PROVIDERS,
3433
REQUIRED_ONNX_PROVIDERS,
@@ -43,6 +42,7 @@
4342
from inference.core.utils.image_utils import load_image
4443
from inference.core.utils.onnx import get_onnxruntime_execution_providers
4544
from inference.core.utils.preprocess import prepare
45+
from inference.core.utils.url_utils import ApiUrl
4646

4747
if AWS_ACCESS_KEY_ID and AWS_ACCESS_KEY_ID:
4848
import boto3
@@ -337,12 +337,9 @@ def get_model_artifacts(self) -> None:
337337
else:
338338
self.log("Downloading model artifacts from Roboflow API")
339339
# AWS Keys are not available so we use the API Key to hit the Roboflow API which returns a signed link for downloading model artifacts
340-
self.api_url = f"{API_BASE_URL}/ort/{self.endpoint}?api_key={self.api_key}&device={self.device_id}&nocache=true&dynamic=true"
341-
if LICENSE_SERVER:
342-
self.api_url = (
343-
f"http://{LICENSE_SERVER}/proxy?url="
344-
+ urllib.parse.quote(self.api_url, safe="~()*!'")
345-
)
340+
self.api_url = ApiUrl(
341+
f"{API_BASE_URL}/ort/{self.endpoint}?api_key={self.api_key}&device={self.device_id}&nocache=true&dynamic=true"
342+
)
346343
api_data = get_api_data(self.api_url)
347344
if "ort" not in api_data.keys():
348345
raise TensorrtRoboflowAPIError(
@@ -359,24 +356,8 @@ def get_model_artifacts(self) -> None:
359356
if "colors" in api_data:
360357
self.colors = api_data["colors"]
361358

362-
if LICENSE_SERVER:
363-
license_server_base_url = f"http://{LICENSE_SERVER}/proxy?url="
364-
weights_url = license_server_base_url + urllib.parse.quote(
365-
api_data["model"], safe="~()*!'"
366-
)
367-
368-
def get_env_url(api_data):
369-
return license_server_base_url + urllib.parse.quote(
370-
api_data["environment"], safe="~()*!'"
371-
)
372-
373-
else:
374-
weights_url = api_data["model"]
375-
376-
def get_env_url(api_data):
377-
return api_data["environment"]
378-
379359
t1 = perf_counter()
360+
weights_url = ApiUrl(api_data["model"])
380361
r = requests.get(weights_url)
381362
with self.open_cache(self.weights_file, "wb") as f:
382363
f.write(r.content)
@@ -385,7 +366,7 @@ def get_env_url(api_data):
385366
"Weights download took longer than 120 seconds, refreshing API request"
386367
)
387368
api_data = get_api_data(self.api_url)
388-
env_url = get_env_url(api_data)
369+
env_url = ApiUrl(api_data["environment"])
389370
self.environment = requests.get(env_url).json()
390371
with open(self.cache_file("environment.json"), "w") as f:
391372
json.dump(self.environment, f)
@@ -625,12 +606,9 @@ def download_weights(self) -> None:
625606
raise Exception(f"Failed to download model artifacts.")
626607
else:
627608
# AWS Keys are not available so we use the API Key to hit the Roboflow API which returns a signed link for downloading model artifacts
628-
self.api_url = f"{API_BASE_URL}/core_model/{self.endpoint}?api_key={self.api_key}&device={self.device_id}&nocache=true"
629-
if LICENSE_SERVER:
630-
self.api_url = (
631-
f"http://{LICENSE_SERVER}/proxy?url="
632-
+ urllib.parse.quote(self.api_url, safe="~()*!'")
633-
)
609+
self.api_url = ApiUrl(
610+
f"{API_BASE_URL}/core_model/{self.endpoint}?api_key={self.api_key}&device={self.device_id}&nocache=true"
611+
)
634612
api_data = get_api_data(self.api_url)
635613
if "weights" not in api_data.keys():
636614
raise TensorrtRoboflowAPIError(
@@ -640,13 +618,7 @@ def download_weights(self) -> None:
640618
weights_url_keys = api_data["weights"].keys()
641619

642620
for weights_url_key in weights_url_keys:
643-
if LICENSE_SERVER:
644-
license_server_base_url = f"http://{LICENSE_SERVER}/proxy?url="
645-
weights_url = license_server_base_url + urllib.parse.quote(
646-
api_data["weights"][weights_url_key], safe="~()*!'"
647-
)
648-
else:
649-
weights_url = api_data["weights"][weights_url_key]
621+
weights_url = ApiUrl(api_data["weights"][weights_url_key])
650622
t1 = perf_counter()
651623
attempts = 0
652624
success = False

‎inference/core/registries/roboflow.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from inference.core.exceptions import DatasetLoadError, WorkspaceLoadError
88
from inference.core.models.base import Model
99
from inference.core.registries.base import ModelRegistry
10+
from inference.core.utils.url_utils import ApiUrl
1011

1112
MODEL_TYPE_DEFAULTS = {
1213
"object-detection": "yolov5v2s",
@@ -72,7 +73,8 @@ def get_model_type(model_id: str, api_key: str) -> str:
7273
model_type = cache_data["model_type"]
7374
return project_task_type, model_type
7475

75-
api_key_info = requests.get("/".join([API_BASE_URL, f"?api_key={api_key}"]))
76+
api_url = ApiUrl("/".join([API_BASE_URL, f"?api_key={api_key}"]))
77+
api_key_info = requests.get(api_url)
7678
try:
7779
api_key_info.raise_for_status()
7880
except requests.exceptions.HTTPError as e:
@@ -84,12 +86,13 @@ def get_model_type(model_id: str, api_key: str) -> str:
8486
if workspace_id is None:
8587
raise WorkspaceLoadError(f"Empty workspace, check your API key")
8688

87-
dataset_info = requests.get(
89+
api_url = ApiUrl(
8890
"/".join(
8991
[API_BASE_URL, workspace_id, dataset_id, f"?api_key={api_key}&nocache=true"]
9092
)
9193
)
92-
version_info = requests.get(
94+
dataset_info = requests.get(api_url)
95+
api_url = ApiUrl(
9396
"/".join(
9497
[
9598
API_BASE_URL,
@@ -100,6 +103,7 @@ def get_model_type(model_id: str, api_key: str) -> str:
100103
]
101104
)
102105
)
106+
version_info = requests.get(api_url)
103107
try:
104108
dataset_info.raise_for_status()
105109
version_info.raise_for_status()

‎inference/core/utils/url_utils.py

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import urllib
2+
3+
from inference.core.env import LICENSE_SERVER
4+
5+
6+
def ProxyUrl(url):
7+
"""Returns a proxied URL if according to LICENSE_SERVER settings"""
8+
return f"http://{LICENSE_SERVER}/proxy?url=" + urllib.parse.quote(
9+
url, safe="~()*!'"
10+
)
11+
12+
13+
def RawUrl(url):
14+
"""Returns a raw URL"""
15+
return url
16+
17+
18+
if LICENSE_SERVER:
19+
ApiUrl = ProxyUrl
20+
else:
21+
ApiUrl = RawUrl

‎inference/core/version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.8.4"
1+
__version__ = "0.8.5"
22

33
if __name__ == "__main__":
44
print(__version__)

0 commit comments

Comments
 (0)