|
1 | 1 | import asyncio
|
2 |
| -import queue |
3 |
| -import sys |
4 | 2 | from collections.abc import Coroutine, Iterable, Iterator, Sequence
|
5 | 3 | from concurrent import futures
|
6 | 4 | from itertools import islice
|
7 |
| -from typing import ( |
8 |
| - Any, |
9 |
| - Callable, |
10 |
| - Optional, |
11 |
| - TypeVar, |
12 |
| -) |
| 5 | +from typing import Any, Callable, Optional, TypeVar |
13 | 6 |
|
14 | 7 | from fsspec import Callback
|
15 | 8 |
|
16 | 9 | _T = TypeVar("_T")
|
17 | 10 |
|
18 | 11 |
|
19 | 12 | class ThreadPoolExecutor(futures.ThreadPoolExecutor):
|
20 |
| - _max_workers: int |
21 |
| - |
22 | 13 | def __init__(
|
23 | 14 | self, max_workers: Optional[int] = None, cancel_on_error: bool = False, **kwargs
|
24 | 15 | ):
|
25 | 16 | super().__init__(max_workers=max_workers, **kwargs)
|
26 | 17 | self._cancel_on_error = cancel_on_error
|
27 | 18 |
|
28 |
| - @property |
29 |
| - def max_workers(self) -> int: |
30 |
| - return self._max_workers |
31 |
| - |
32 | 19 | def imap_unordered(
|
33 | 20 | self, fn: Callable[..., _T], *iterables: Iterable[Any]
|
34 | 21 | ) -> Iterator[_T]:
|
35 | 22 | """Lazier version of map that does not preserve ordering of results.
|
36 | 23 |
|
37 | 24 | It does not create all the futures at once to reduce memory usage.
|
38 | 25 | """
|
39 |
| - |
40 | 26 | it = zip(*iterables)
|
41 |
| - if self.max_workers == 1: |
| 27 | + if self._max_workers == 1: |
42 | 28 | for args in it:
|
43 | 29 | yield fn(*args)
|
44 | 30 | return
|
45 | 31 |
|
46 | 32 | def create_taskset(n: int) -> set[futures.Future]:
|
47 | 33 | return {self.submit(fn, *args) for args in islice(it, n)}
|
48 | 34 |
|
49 |
| - tasks = create_taskset(self.max_workers * 5) |
| 35 | + tasks = create_taskset(self._max_workers * 5) |
50 | 36 | while tasks:
|
51 | 37 | done, tasks = futures.wait(tasks, return_when=futures.FIRST_COMPLETED)
|
52 | 38 | for fut in done:
|
53 | 39 | yield fut.result()
|
54 | 40 | tasks.update(create_taskset(len(done)))
|
55 | 41 |
|
56 |
| - def shutdown(self, wait=True, *, cancel_futures=False): |
57 |
| - if sys.version_info > (3, 9): |
58 |
| - return super().shutdown(wait=wait, cancel_futures=cancel_futures) |
59 |
| - else: # noqa: RET505 |
60 |
| - with self._shutdown_lock: |
61 |
| - self._shutdown = True |
62 |
| - if cancel_futures: |
63 |
| - # Drain all work items from the queue, and then cancel their |
64 |
| - # associated futures. |
65 |
| - while True: |
66 |
| - try: |
67 |
| - work_item = self._work_queue.get_nowait() |
68 |
| - except queue.Empty: |
69 |
| - break |
70 |
| - if work_item is not None: |
71 |
| - work_item.future.cancel() |
72 |
| - |
73 |
| - # Send a wake-up to prevent threads calling |
74 |
| - # _work_queue.get(block=True) from permanently blocking. |
75 |
| - self._work_queue.put(None) # type: ignore[arg-type] |
76 |
| - if wait: |
77 |
| - for t in self._threads: |
78 |
| - t.join() |
79 |
| - |
80 | 42 | def __exit__(self, exc_type, exc_val, exc_tb):
|
81 |
| - if self._cancel_on_error: |
82 |
| - self.shutdown(wait=True, cancel_futures=exc_val is not None) |
83 |
| - else: |
84 |
| - self.shutdown(wait=True) |
| 43 | + cancel_futures = self._cancel_on_error and exc_val is not None |
| 44 | + self.shutdown(wait=True, cancel_futures=cancel_futures) |
85 | 45 | return False
|
86 | 46 |
|
87 | 47 |
|
|
0 commit comments