1
- import asyncio
2
1
from functools import wraps
3
- from typing import TYPE_CHECKING , Any , BinaryIO , Callable , Dict , Optional , TypeVar , cast
2
+ from typing import TYPE_CHECKING , BinaryIO , Optional , Type , cast
4
3
5
4
import fsspec
5
+ from fsspec .callbacks import DEFAULT_CALLBACK , Callback , NoOpCallback
6
6
7
7
if TYPE_CHECKING :
8
8
from typing import Union
9
9
10
- from dvc_objects . _tqdm import Tqdm
10
+ from tqdm import tqdm
11
11
12
- F = TypeVar ("F" , bound = Callable )
12
+
13
+ __all__ = ["Callback" , "NoOpCallback" , "TqdmCallback" , "DEFAULT_CALLBACK" ]
13
14
14
15
15
16
class CallbackStream :
16
- def __init__ (self , stream , callback : fsspec . Callback ):
17
+ def __init__ (self , stream , callback : Callback ):
17
18
self .stream = stream
18
19
19
20
@wraps (stream .read )
@@ -28,151 +29,29 @@ def __getattr__(self, attr):
28
29
return getattr (self .stream , attr )
29
30
30
31
31
- class ScopedCallback (fsspec .Callback ):
32
- def __enter__ (self ):
33
- return self
34
-
35
- def __exit__ (self , * exc_args ):
36
- self .close ()
37
-
38
- def close (self ):
39
- """Handle here on exit."""
40
-
41
- def branch (
42
- self ,
43
- path_1 : "Union[str, BinaryIO]" ,
44
- path_2 : str ,
45
- kwargs : Dict [str , Any ],
46
- child : Optional ["Callback" ] = None ,
47
- ) -> "Callback" :
48
- child = kwargs ["callback" ] = child or DEFAULT_CALLBACK
49
- return child
50
-
51
-
52
- class Callback (ScopedCallback ):
53
- def absolute_update (self , value : int ) -> None :
54
- value = value if value is not None else self .value
55
- return super ().absolute_update (value )
56
-
57
- @classmethod
58
- def as_callback (
59
- cls , maybe_callback : Optional [fsspec .Callback ] = None
60
- ) -> "Callback" :
61
- if maybe_callback is None :
62
- return DEFAULT_CALLBACK
63
- if isinstance (maybe_callback , Callback ):
64
- return maybe_callback
65
- return FsspecCallbackWrapper (maybe_callback )
66
-
67
-
68
- class NoOpCallback (Callback , fsspec .callbacks .NoOpCallback ):
69
- pass
70
-
71
-
72
- class TqdmCallback (Callback ):
32
+ class TqdmCallback (fsspec .callbacks .TqdmCallback ):
73
33
def __init__ (
74
34
self ,
75
35
size : Optional [int ] = None ,
76
36
value : int = 0 ,
77
- progress_bar : Optional ["Tqdm" ] = None ,
37
+ progress_bar : Optional ["tqdm" ] = None ,
38
+ tqdm_cls : Optional [Type ["tqdm" ]] = None ,
78
39
** tqdm_kwargs ,
79
40
):
80
41
from dvc_objects ._tqdm import Tqdm
81
42
82
43
tqdm_kwargs .pop ("total" , None )
83
- self ._tqdm_kwargs = tqdm_kwargs
84
- self ._tqdm_cls = Tqdm
85
- self .tqdm = progress_bar
86
- super ().__init__ (size = size , value = value )
87
-
88
- def close (self ):
89
- if self .tqdm is not None :
90
- self .tqdm .close ()
91
- self .tqdm = None
92
-
93
- def call (self , hook_name = None , ** kwargs ):
94
- if self .tqdm is None :
95
- self .tqdm = self ._tqdm_cls (** self ._tqdm_kwargs , total = self .size or - 1 )
96
- self .tqdm .update_to (self .value , total = self .size )
97
-
98
- def branch (
99
- self ,
100
- path_1 : "Union[str, BinaryIO]" ,
101
- path_2 : str ,
102
- kwargs : Dict [str , Any ],
103
- child : Optional [Callback ] = None ,
104
- ):
44
+ tqdm_cls = tqdm_cls or Tqdm
45
+ super ().__init__ (
46
+ tqdm_kwargs = tqdm_kwargs , tqdm_cls = tqdm_cls , size = size , value = value
47
+ )
48
+ if progress_bar is None :
49
+ self .tqdm = progress_bar
50
+
51
+ def branched (self , path_1 : "Union[str, BinaryIO]" , path_2 : str , ** kwargs ):
105
52
desc = path_1 if isinstance (path_1 , str ) else path_2
106
- child = child or TqdmCallback (bytes = True , desc = desc )
107
- return super ().branch (path_1 , path_2 , kwargs , child = child )
108
-
109
-
110
- class FsspecCallbackWrapper (Callback ):
111
- def __init__ (self , callback : fsspec .Callback ):
112
- object .__setattr__ (self , "_callback" , callback )
113
-
114
- def __getattr__ (self , name : str ):
115
- return getattr (self ._callback , name )
53
+ return TqdmCallback (bytes = True , desc = desc )
116
54
117
- def __setattr__ (self , name : str , value : Any ):
118
- setattr (self ._callback , name , value )
119
55
120
- def absolute_update (self , value : int ) -> None :
121
- value = value if value is not None else self .value
122
- return self ._callback .absolute_update (value )
123
-
124
- def branch (
125
- self ,
126
- path_1 : "Union[str, BinaryIO]" ,
127
- path_2 : str ,
128
- kwargs : Dict [str , Any ],
129
- child : Optional ["Callback" ] = None ,
130
- ) -> "Callback" :
131
- if not child :
132
- self ._callback .branch (path_1 , path_2 , kwargs )
133
- child = self .as_callback (kwargs .get ("callback" ))
134
- return super ().branch (path_1 , path_2 , kwargs , child = child )
135
-
136
-
137
- def wrap_fn (callback : fsspec .Callback , fn : F ) -> F :
138
- @wraps (fn )
139
- async def async_wrapper (* args , ** kwargs ):
140
- res = await fn (* args , ** kwargs )
141
- callback .relative_update ()
142
- return res
143
-
144
- @wraps (fn )
145
- def sync_wrapper (* args , ** kwargs ):
146
- res = fn (* args , ** kwargs )
147
- callback .relative_update ()
148
- return res
149
-
150
- return async_wrapper if asyncio .iscoroutinefunction (fn ) else sync_wrapper # type: ignore[return-value]
151
-
152
-
153
- def branch_callback (callback : fsspec .Callback , fn : F ) -> F :
154
- callback = Callback .as_callback (callback )
155
-
156
- @wraps (fn )
157
- async def async_wrapper (path1 : "Union[str, BinaryIO]" , path2 : str , ** kwargs ):
158
- with callback .branch (path1 , path2 , kwargs ):
159
- return await fn (path1 , path2 , ** kwargs )
160
-
161
- @wraps (fn )
162
- def sync_wrapper (path1 : "Union[str, BinaryIO]" , path2 : str , ** kwargs ):
163
- with callback .branch (path1 , path2 , kwargs ):
164
- return fn (path1 , path2 , ** kwargs )
165
-
166
- return async_wrapper if asyncio .iscoroutinefunction (fn ) else sync_wrapper # type: ignore[return-value]
167
-
168
-
169
- def wrap_and_branch_callback (callback : fsspec .Callback , fn : F ) -> F :
170
- branch_wrapper = branch_callback (callback , fn )
171
- return wrap_fn (callback , branch_wrapper )
172
-
173
-
174
- def wrap_file (file , callback : fsspec .Callback ) -> BinaryIO :
56
+ def wrap_file (file , callback : Callback ) -> BinaryIO :
175
57
return cast (BinaryIO , CallbackStream (file , callback ))
176
-
177
-
178
- DEFAULT_CALLBACK = NoOpCallback ()
0 commit comments