1
1
from contextlib import ExitStack
2
2
from functools import wraps
3
- from typing import TYPE_CHECKING , Any , Dict , Optional , TypeVar , cast , overload
3
+ from typing import TYPE_CHECKING , Any , Dict , Optional , Protocol , TypeVar , cast , overload
4
4
5
5
import fsspec
6
6
17
17
_R = TypeVar ("_R" )
18
18
19
19
20
- class Callback (fsspec .Callback ):
21
- """Callback usable as a context manager, and a few helper methods."""
20
+ class _CallbackProtocol (Protocol ):
21
+ def relative_update (self , inc : int = 1 ) -> None :
22
+ ...
23
+
24
+ def branch (
25
+ self ,
26
+ path_1 : "Union[str, BinaryIO]" ,
27
+ path_2 : str ,
28
+ kwargs : Dict [str , Any ],
29
+ child : Optional ["Callback" ] = None ,
30
+ ) -> "Callback" :
31
+ ...
22
32
33
+
34
+ class _DVCCallbackMixin (_CallbackProtocol ):
23
35
@overload
24
36
def wrap_attr (self , fobj : "BinaryIO" , method : str = "read" ) -> "BinaryIO" :
25
37
...
@@ -66,7 +78,7 @@ def wrap_and_branch(self, fn: "Callable") -> "Callable":
66
78
@wraps (fn )
67
79
def func (path1 : "Union[str, BinaryIO]" , path2 : str , ** kwargs ):
68
80
kw : Dict [str , Any ] = dict (kwargs )
69
- with self .branch (path1 , path2 , kw ):
81
+ with self .branch (path1 , path2 , kw ): # pylint: disable=not-context-manager
70
82
return wrapped (path1 , path2 , ** kw )
71
83
72
84
return func
@@ -81,7 +93,7 @@ def wrap_and_branch_coro(self, fn: "Callable") -> "Callable":
81
93
@wraps (fn )
82
94
async def func (path1 : "Union[str, BinaryIO]" , path2 : str , ** kwargs ):
83
95
kw : Dict [str , Any ] = dict (kwargs )
84
- with self .branch (path1 , path2 , kw ):
96
+ with self .branch (path1 , path2 , kw ): # pylint: disable=not-context-manager
85
97
return await wrapped (path1 , path2 , ** kw )
86
98
87
99
return func
@@ -95,6 +107,22 @@ def __exit__(self, *exc_args):
95
107
def close (self ):
96
108
"""Handle here on exit."""
97
109
110
+ @classmethod
111
+ def as_tqdm_callback (
112
+ cls ,
113
+ callback : Optional [fsspec .callbacks .Callback ] = None ,
114
+ ** tqdm_kwargs : Any ,
115
+ ) -> "Callback" :
116
+ if callback is None :
117
+ return TqdmCallback (** tqdm_kwargs )
118
+ if isinstance (callback , Callback ):
119
+ return callback
120
+ return cast ("Callback" , _FsspecCallbackWrapper (callback ))
121
+
122
+
123
+ class Callback (fsspec .Callback , _DVCCallbackMixin ):
124
+ """Callback usable as a context manager, and a few helper methods."""
125
+
98
126
def relative_update (self , inc : int = 1 ) -> None :
99
127
inc = inc if inc is not None else 0
100
128
return super ().relative_update (inc )
@@ -104,18 +132,14 @@ def absolute_update(self, value: int) -> None:
104
132
return super ().absolute_update (value )
105
133
106
134
@classmethod
107
- def as_callback (cls , maybe_callback : Optional ["Callback" ] = None ) -> "Callback" :
135
+ def as_callback (
136
+ cls , maybe_callback : Optional [fsspec .callbacks .Callback ] = None
137
+ ) -> "Callback" :
108
138
if maybe_callback is None :
109
139
return DEFAULT_CALLBACK
110
- return maybe_callback
111
-
112
- @classmethod
113
- def as_tqdm_callback (
114
- cls ,
115
- callback : Optional ["Callback" ] = None ,
116
- ** tqdm_kwargs : Any ,
117
- ) -> "Callback" :
118
- return callback or TqdmCallback (** tqdm_kwargs )
140
+ if isinstance (maybe_callback , Callback ):
141
+ return maybe_callback
142
+ return _FsspecCallbackWrapper (maybe_callback )
119
143
120
144
def branch ( # pylint: disable=arguments-differ
121
145
self ,
@@ -185,4 +209,20 @@ def branch(
185
209
return super ().branch (path_1 , path_2 , kwargs , child = child )
186
210
187
211
212
+ class _FsspecCallbackWrapper (fsspec .callbacks .Callback , _DVCCallbackMixin ):
213
+ def __init__ ( # pylint: disable=super-init-not-called
214
+ self , callback : fsspec .callbacks .Callback
215
+ ):
216
+ object .__setattr__ (self , "_callback" , callback )
217
+
218
+ def __getattr__ (self , name : str ):
219
+ return getattr (self ._callback , name )
220
+
221
+ def __setattr__ (self , name : str , value : Any ):
222
+ setattr (self ._callback , name , value )
223
+
224
+ def branch (self , * args , ** kwargs ):
225
+ return _FsspecCallbackWrapper (self ._callback .branch (* args , ** kwargs ))
226
+
227
+
188
228
DEFAULT_CALLBACK = NoOpCallback ()
0 commit comments