Skip to content

Commit 542f604

Browse files
authored
cleanup callback (#276)
* cleanup callback Structured callbacks are now available in fsspec>=2024.2.0, so we can remove most of the code here. Almost all of `dvc_objects.fs.callbacks` is deprecated and will be removed within a 4.x minor release, including TqdmCallback. TqdmCallback can be replaced with fsspec's TqdmCallback, or can be overridden very easily. This will likely be moved to dvc-data for now. Similarly, the `dvc_objects._tqdm` module is deprecated and is slated for removal within a 4.x minor release. The code will likely move over to dvc-data. * get rid of callback wrappers, add tests for generic transfers * bump fsspec requirement
1 parent edc151c commit 542f604

File tree

9 files changed

+213
-332
lines changed

9 files changed

+213
-332
lines changed

pyproject.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ requires-python = ">=3.8"
2323
dynamic = ["version"]
2424
dependencies = [
2525
"funcy>=1.14; python_version < '3.12'",
26-
"fsspec>=2022.10.0",
26+
"fsspec>=2024.2.0",
2727
]
2828

2929
[project.urls]
@@ -104,7 +104,7 @@ module = [
104104
]
105105

106106
[tool.codespell]
107-
ignore-words-list = " "
107+
ignore-words-list = "cachable,"
108108

109109
[tool.ruff]
110110
ignore = [

src/dvc_objects/db.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@
1616
cast,
1717
)
1818

19+
from fsspec.callbacks import DEFAULT_CALLBACK
20+
1921
from .errors import ObjectDBPermissionError
20-
from .fs.callbacks import DEFAULT_CALLBACK
2122
from .obj import Object
2223

2324
if TYPE_CHECKING:
25+
from fsspec import Callback
26+
2427
from .fs.base import AnyFSPath, FileSystem
25-
from .fs.callbacks import Callback
2628

2729

2830
logger = logging.getLogger(__name__)

src/dvc_objects/executors.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
TypeVar,
1919
)
2020

21-
from .fs.callbacks import Callback
21+
from fsspec import Callback
2222

2323
_T = TypeVar("_T")
2424

src/dvc_objects/fs/base.py

+16-12
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,12 @@
2727

2828
import fsspec
2929
from fsspec.asyn import get_loop
30+
from fsspec.callbacks import DEFAULT_CALLBACK
3031

3132
from dvc_objects.compat import cached_property
3233
from dvc_objects.executors import ThreadPoolExecutor, batch_coros
3334

34-
from .callbacks import (
35-
DEFAULT_CALLBACK,
36-
wrap_and_branch_callback,
37-
wrap_file,
38-
)
35+
from .callbacks import wrap_file
3936
from .errors import RemoteMissingDepsError
4037

4138
if TYPE_CHECKING:
@@ -706,9 +703,14 @@ def put(
706703

707704
callback.set_size(len(from_infos))
708705
executor = ThreadPoolExecutor(max_workers=jobs, cancel_on_error=True)
706+
707+
def put_file(from_path, to_path):
708+
with callback.branched(from_path, to_path) as child:
709+
return self.put_file(from_path, to_path, callback=child)
710+
709711
with executor:
710-
put_file = wrap_and_branch_callback(callback, self.put_file)
711-
list(executor.imap_unordered(put_file, from_infos, to_infos))
712+
it = executor.imap_unordered(put_file, from_infos, to_infos)
713+
list(callback.wrap(it))
712714

713715
def get(
714716
self,
@@ -724,9 +726,8 @@ def get(
724726

725727
def get_file(rpath, lpath, **kwargs):
726728
localfs.makedirs(localfs.parent(lpath), exist_ok=True)
727-
self.fs.get_file(rpath, lpath, **kwargs)
728-
729-
get_file = wrap_and_branch_callback(callback, get_file)
729+
with callback.branched(rpath, lpath) as child:
730+
self.fs.get_file(rpath, lpath, callback=child, **kwargs)
730731

731732
if isinstance(from_info, list) and isinstance(to_info, list):
732733
from_infos: List[AnyFSPath] = from_info
@@ -737,7 +738,9 @@ def get_file(rpath, lpath, **kwargs):
737738

738739
if not self.isdir(from_info):
739740
callback.set_size(1)
740-
return get_file(from_info, to_info)
741+
get_file(from_info, to_info)
742+
callback.relative_update()
743+
return
741744

742745
from_infos = list(self.find(from_info))
743746
if not from_infos:
@@ -760,7 +763,8 @@ def get_file(rpath, lpath, **kwargs):
760763
callback.set_size(len(from_infos))
761764
executor = ThreadPoolExecutor(max_workers=jobs, cancel_on_error=True)
762765
with executor:
763-
list(executor.imap_unordered(get_file, from_infos, to_infos))
766+
it = executor.imap_unordered(get_file, from_infos, to_infos)
767+
list(callback.wrap(it))
764768

765769
def ukey(self, path: AnyFSPath) -> str:
766770
return self.fs.ukey(path)

src/dvc_objects/fs/callbacks.py

+19-140
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
1-
import asyncio
21
from functools import wraps
3-
from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Dict, Optional, TypeVar, cast
2+
from typing import TYPE_CHECKING, BinaryIO, Optional, Type, cast
43

54
import fsspec
5+
from fsspec.callbacks import DEFAULT_CALLBACK, Callback, NoOpCallback
66

77
if TYPE_CHECKING:
88
from typing import Union
99

10-
from dvc_objects._tqdm import Tqdm
10+
from tqdm import tqdm
1111

12-
F = TypeVar("F", bound=Callable)
12+
13+
__all__ = ["Callback", "NoOpCallback", "TqdmCallback", "DEFAULT_CALLBACK"]
1314

1415

1516
class CallbackStream:
16-
def __init__(self, stream, callback: fsspec.Callback):
17+
def __init__(self, stream, callback: Callback):
1718
self.stream = stream
1819

1920
@wraps(stream.read)
@@ -28,151 +29,29 @@ def __getattr__(self, attr):
2829
return getattr(self.stream, attr)
2930

3031

31-
class ScopedCallback(fsspec.Callback):
32-
def __enter__(self):
33-
return self
34-
35-
def __exit__(self, *exc_args):
36-
self.close()
37-
38-
def close(self):
39-
"""Handle here on exit."""
40-
41-
def branch(
42-
self,
43-
path_1: "Union[str, BinaryIO]",
44-
path_2: str,
45-
kwargs: Dict[str, Any],
46-
child: Optional["Callback"] = None,
47-
) -> "Callback":
48-
child = kwargs["callback"] = child or DEFAULT_CALLBACK
49-
return child
50-
51-
52-
class Callback(ScopedCallback):
53-
def absolute_update(self, value: int) -> None:
54-
value = value if value is not None else self.value
55-
return super().absolute_update(value)
56-
57-
@classmethod
58-
def as_callback(
59-
cls, maybe_callback: Optional[fsspec.Callback] = None
60-
) -> "Callback":
61-
if maybe_callback is None:
62-
return DEFAULT_CALLBACK
63-
if isinstance(maybe_callback, Callback):
64-
return maybe_callback
65-
return FsspecCallbackWrapper(maybe_callback)
66-
67-
68-
class NoOpCallback(Callback, fsspec.callbacks.NoOpCallback):
69-
pass
70-
71-
72-
class TqdmCallback(Callback):
32+
class TqdmCallback(fsspec.callbacks.TqdmCallback):
7333
def __init__(
7434
self,
7535
size: Optional[int] = None,
7636
value: int = 0,
77-
progress_bar: Optional["Tqdm"] = None,
37+
progress_bar: Optional["tqdm"] = None,
38+
tqdm_cls: Optional[Type["tqdm"]] = None,
7839
**tqdm_kwargs,
7940
):
8041
from dvc_objects._tqdm import Tqdm
8142

8243
tqdm_kwargs.pop("total", None)
83-
self._tqdm_kwargs = tqdm_kwargs
84-
self._tqdm_cls = Tqdm
85-
self.tqdm = progress_bar
86-
super().__init__(size=size, value=value)
87-
88-
def close(self):
89-
if self.tqdm is not None:
90-
self.tqdm.close()
91-
self.tqdm = None
92-
93-
def call(self, hook_name=None, **kwargs):
94-
if self.tqdm is None:
95-
self.tqdm = self._tqdm_cls(**self._tqdm_kwargs, total=self.size or -1)
96-
self.tqdm.update_to(self.value, total=self.size)
97-
98-
def branch(
99-
self,
100-
path_1: "Union[str, BinaryIO]",
101-
path_2: str,
102-
kwargs: Dict[str, Any],
103-
child: Optional[Callback] = None,
104-
):
44+
tqdm_cls = tqdm_cls or Tqdm
45+
super().__init__(
46+
tqdm_kwargs=tqdm_kwargs, tqdm_cls=tqdm_cls, size=size, value=value
47+
)
48+
if progress_bar is None:
49+
self.tqdm = progress_bar
50+
51+
def branched(self, path_1: "Union[str, BinaryIO]", path_2: str, **kwargs):
10552
desc = path_1 if isinstance(path_1, str) else path_2
106-
child = child or TqdmCallback(bytes=True, desc=desc)
107-
return super().branch(path_1, path_2, kwargs, child=child)
108-
109-
110-
class FsspecCallbackWrapper(Callback):
111-
def __init__(self, callback: fsspec.Callback):
112-
object.__setattr__(self, "_callback", callback)
113-
114-
def __getattr__(self, name: str):
115-
return getattr(self._callback, name)
53+
return TqdmCallback(bytes=True, desc=desc)
11654

117-
def __setattr__(self, name: str, value: Any):
118-
setattr(self._callback, name, value)
11955

120-
def absolute_update(self, value: int) -> None:
121-
value = value if value is not None else self.value
122-
return self._callback.absolute_update(value)
123-
124-
def branch(
125-
self,
126-
path_1: "Union[str, BinaryIO]",
127-
path_2: str,
128-
kwargs: Dict[str, Any],
129-
child: Optional["Callback"] = None,
130-
) -> "Callback":
131-
if not child:
132-
self._callback.branch(path_1, path_2, kwargs)
133-
child = self.as_callback(kwargs.get("callback"))
134-
return super().branch(path_1, path_2, kwargs, child=child)
135-
136-
137-
def wrap_fn(callback: fsspec.Callback, fn: F) -> F:
138-
@wraps(fn)
139-
async def async_wrapper(*args, **kwargs):
140-
res = await fn(*args, **kwargs)
141-
callback.relative_update()
142-
return res
143-
144-
@wraps(fn)
145-
def sync_wrapper(*args, **kwargs):
146-
res = fn(*args, **kwargs)
147-
callback.relative_update()
148-
return res
149-
150-
return async_wrapper if asyncio.iscoroutinefunction(fn) else sync_wrapper # type: ignore[return-value]
151-
152-
153-
def branch_callback(callback: fsspec.Callback, fn: F) -> F:
154-
callback = Callback.as_callback(callback)
155-
156-
@wraps(fn)
157-
async def async_wrapper(path1: "Union[str, BinaryIO]", path2: str, **kwargs):
158-
with callback.branch(path1, path2, kwargs):
159-
return await fn(path1, path2, **kwargs)
160-
161-
@wraps(fn)
162-
def sync_wrapper(path1: "Union[str, BinaryIO]", path2: str, **kwargs):
163-
with callback.branch(path1, path2, kwargs):
164-
return fn(path1, path2, **kwargs)
165-
166-
return async_wrapper if asyncio.iscoroutinefunction(fn) else sync_wrapper # type: ignore[return-value]
167-
168-
169-
def wrap_and_branch_callback(callback: fsspec.Callback, fn: F) -> F:
170-
branch_wrapper = branch_callback(callback, fn)
171-
return wrap_fn(callback, branch_wrapper)
172-
173-
174-
def wrap_file(file, callback: fsspec.Callback) -> BinaryIO:
56+
def wrap_file(file, callback: Callback) -> BinaryIO:
17557
return cast(BinaryIO, CallbackStream(file, callback))
176-
177-
178-
DEFAULT_CALLBACK = NoOpCallback()

0 commit comments

Comments
 (0)