diff --git a/dagshub/common/api/repo.py b/dagshub/common/api/repo.py index 69b20253..75624a97 100644 --- a/dagshub/common/api/repo.py +++ b/dagshub/common/api/repo.py @@ -81,7 +81,7 @@ def __init__(self, repo: str, host: Optional[str] = None, auth: Optional[Any] = @retry(retry=retry_if_exception_type(LSInitializingError), wait=wait_fixed(3), stop=stop_after_attempt(5)) def _tenacious_ls_request(self, *args, **kwargs): - res = self.http_request(*args, **kwargs) + res = http_request(*args, **kwargs) if res.text.startswith(""): raise LSInitializingError() elif res.status_code // 100 != 2: diff --git a/dagshub/ls_client.py b/dagshub/ls_client.py index f6d5cbf6..9322645c 100644 --- a/dagshub/ls_client.py +++ b/dagshub/ls_client.py @@ -1,14 +1,51 @@ from tenacity import retry, wait_fixed, stop_after_attempt, retry_if_exception_type +from dagshub.data_engine.model.errors import LSInitializingError +from contextlib import _GeneratorContextManager +from dagshub.common.util import lazy_load from json import JSONDecodeError from typing import Optional +from itertools import tee import importlib.util import semver +import types from dagshub.common.api.repo import RepoAPI from dagshub.auth import get_token from dagshub.common import config +ls_sdk = lazy_load("label_studio_sdk") + + +class _TenaciousLSCLientWrapper: + def __init__(self, func): + self.func = func + + @retry( + retry=retry_if_exception_type((LSInitializingError, JSONDecodeError, ls_sdk.core.ApiError)), + wait=wait_fixed(3), + stop=stop_after_attempt(5), + ) + def wrapped_func(self, *args, **kwargs): + res = self.func(*args, **kwargs) + + if isinstance(res, types.GeneratorType): + proxy, res = tee(res) + if next(proxy).startswith(b""): + raise LSInitializingError() + elif isinstance(res, _GeneratorContextManager): + return res + elif isinstance(res, bytes): + if res.startswith(""): + raise LSInitializingError() + else: + if res.text.startswith(""): + raise LSInitializingError() + elif res.status_code // 100 != 2: + raise RuntimeError(f"Process failed! Server Response: {res.text}") + return res + + def _use_legacy_client(): """ https://github.com/HumanSignal/label-studio/releases/tag/1.13.0, \ @@ -61,4 +98,15 @@ def get_label_studio_client( "api_key": token if token is not None else get_token(host=host), } - return LabelStudio(**kwargs) + ls_client = LabelStudio(**kwargs) + if legacy_client: + ls_client.make_request = _TenaciousLSCLientWrapper(ls_client.make_request).wrapped_func + else: + ls_client._client_wrapper.httpx_client.request = _TenaciousLSCLientWrapper( + ls_client._client_wrapper.httpx_client.request + ).wrapped_func + ls_client.projects.exports.create_export = _TenaciousLSCLientWrapper( + ls_client.projects.exports.create_export + ).wrapped_func + + return ls_client diff --git a/requirements-dev.txt b/requirements-dev.txt index 44a2d4d7..e2ea4b50 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -7,3 +7,4 @@ pytest-mock==3.14.0 fiftyone==0.23.8 datasets==2.19.1 ultralytics==8.3.47 +label-studio-sdk==1.0.8