Skip to content

Commit 2cae0e6

Browse files
committed
callbacks: wrap non-dvc callbacks passed via fsspec
1 parent fdb3851 commit 2cae0e6

File tree

2 files changed

+78
-16
lines changed

2 files changed

+78
-16
lines changed

src/dvc_objects/fs/callbacks.py

+55-15
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from contextlib import ExitStack
22
from functools import wraps
3-
from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar, cast, overload
3+
from typing import TYPE_CHECKING, Any, Dict, Optional, Protocol, TypeVar, cast, overload
44

55
import fsspec
66

@@ -17,9 +17,21 @@
1717
_R = TypeVar("_R")
1818

1919

20-
class Callback(fsspec.Callback):
21-
"""Callback usable as a context manager, and a few helper methods."""
20+
class _CallbackProtocol(Protocol):
21+
def relative_update(self, inc: int = 1) -> None:
22+
...
23+
24+
def branch(
25+
self,
26+
path_1: "Union[str, BinaryIO]",
27+
path_2: str,
28+
kwargs: Dict[str, Any],
29+
child: Optional["Callback"] = None,
30+
) -> "Callback":
31+
...
2232

33+
34+
class _DVCCallbackMixin(_CallbackProtocol):
2335
@overload
2436
def wrap_attr(self, fobj: "BinaryIO", method: str = "read") -> "BinaryIO":
2537
...
@@ -66,7 +78,7 @@ def wrap_and_branch(self, fn: "Callable") -> "Callable":
6678
@wraps(fn)
6779
def func(path1: "Union[str, BinaryIO]", path2: str, **kwargs):
6880
kw: Dict[str, Any] = dict(kwargs)
69-
with self.branch(path1, path2, kw):
81+
with self.branch(path1, path2, kw): # pylint: disable=not-context-manager
7082
return wrapped(path1, path2, **kw)
7183

7284
return func
@@ -81,7 +93,7 @@ def wrap_and_branch_coro(self, fn: "Callable") -> "Callable":
8193
@wraps(fn)
8294
async def func(path1: "Union[str, BinaryIO]", path2: str, **kwargs):
8395
kw: Dict[str, Any] = dict(kwargs)
84-
with self.branch(path1, path2, kw):
96+
with self.branch(path1, path2, kw): # pylint: disable=not-context-manager
8597
return await wrapped(path1, path2, **kw)
8698

8799
return func
@@ -95,6 +107,22 @@ def __exit__(self, *exc_args):
95107
def close(self):
96108
"""Handle here on exit."""
97109

110+
@classmethod
111+
def as_tqdm_callback(
112+
cls,
113+
callback: Optional[fsspec.callbacks.Callback] = None,
114+
**tqdm_kwargs: Any,
115+
) -> "Callback":
116+
if callback is None:
117+
return TqdmCallback(**tqdm_kwargs)
118+
if isinstance(callback, Callback):
119+
return callback
120+
return cast("Callback", _FsspecCallbackWrapper(callback))
121+
122+
123+
class Callback(fsspec.Callback, _DVCCallbackMixin):
124+
"""Callback usable as a context manager, and a few helper methods."""
125+
98126
def relative_update(self, inc: int = 1) -> None:
99127
inc = inc if inc is not None else 0
100128
return super().relative_update(inc)
@@ -104,18 +132,14 @@ def absolute_update(self, value: int) -> None:
104132
return super().absolute_update(value)
105133

106134
@classmethod
107-
def as_callback(cls, maybe_callback: Optional["Callback"] = None) -> "Callback":
135+
def as_callback(
136+
cls, maybe_callback: Optional[fsspec.callbacks.Callback] = None
137+
) -> "Callback":
108138
if maybe_callback is None:
109139
return DEFAULT_CALLBACK
110-
return maybe_callback
111-
112-
@classmethod
113-
def as_tqdm_callback(
114-
cls,
115-
callback: Optional["Callback"] = None,
116-
**tqdm_kwargs: Any,
117-
) -> "Callback":
118-
return callback or TqdmCallback(**tqdm_kwargs)
140+
if isinstance(maybe_callback, Callback):
141+
return maybe_callback
142+
return _FsspecCallbackWrapper(maybe_callback)
119143

120144
def branch( # pylint: disable=arguments-differ
121145
self,
@@ -185,4 +209,20 @@ def branch(
185209
return super().branch(path_1, path_2, kwargs, child=child)
186210

187211

212+
class _FsspecCallbackWrapper(fsspec.callbacks.Callback, _DVCCallbackMixin):
213+
def __init__( # pylint: disable=super-init-not-called
214+
self, callback: fsspec.callbacks.Callback
215+
):
216+
object.__setattr__(self, "_callback", callback)
217+
218+
def __getattr__(self, name: str):
219+
return getattr(self._callback, name)
220+
221+
def __setattr__(self, name: str, value: Any):
222+
setattr(self._callback, name, value)
223+
224+
def branch(self, *args, **kwargs):
225+
return _FsspecCallbackWrapper(self._callback.branch(*args, **kwargs))
226+
227+
188228
DEFAULT_CALLBACK = NoOpCallback()

tests/fs/test_callbacks.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
from typing import Optional
2+
3+
import fsspec
14
import pytest
25

3-
from dvc_objects.fs.callbacks import DEFAULT_CALLBACK, Callback
6+
from dvc_objects.fs.callbacks import DEFAULT_CALLBACK, Callback, TqdmCallback
47

58

69
@pytest.mark.parametrize("api", ["set_size", "relative_update", "absolute_update"])
@@ -29,3 +32,22 @@ def test_callback_with_none(request, api, callback_factory, kwargs, mocker):
2932
if callback is not DEFAULT_CALLBACK:
3033
assert callback.size is None
3134
assert callback.value == 0
35+
36+
37+
def test_wrap_fsspec():
38+
def _branch_fn(*args, callback: Optional["Callback"] = None, **kwargs):
39+
pass
40+
41+
callback = fsspec.callbacks.Callback()
42+
assert callback.value == 0
43+
with Callback.as_tqdm_callback(callback) as cb:
44+
assert not isinstance(cb, TqdmCallback)
45+
assert cb.value == 0
46+
cb.relative_update()
47+
assert cb.value == 1
48+
assert callback.value == 1
49+
50+
fn = cb.wrap_and_branch(_branch_fn)
51+
fn("foo", "bar", callback=callback)
52+
assert cb.value == 2
53+
assert callback.value == 2

0 commit comments

Comments
 (0)