Skip to content

Commit b6730ee

Browse files
authored
fix: Changing update function wrapper to not wrap provided parameters. (#15)
Refs: #12
1 parent 765f4df commit b6730ee

File tree

5 files changed: +40 -48 lines changed

zcollection/collection/__init__.py

+15 -14
Original file line number | Diff line number | Diff line change
@@ -556,24 +556,21 @@ def update(
556556
lock=True))
557557

558558
local_func: WrappedPartitionCallable = _wrap_update_func(
559-
*args,
560559
delayed=delayed,
561560
func=func,
562561
fs=self.fs,
563562
immutable=self._immutable,
563+
selected_variables=selected_variables
564+
) if depth == 0 else _wrap_update_func_with_overlap(
565+
delayed=delayed,
566+
depth=depth,
567+
dim=self.partition_properties.dim,
568+
func=func,
569+
fs=self.fs,
570+
immutable=self._immutable,
571+
selected_partitions=selected_partitions,
564572
selected_variables=selected_variables,
565-
**kwargs) if depth == 0 else _wrap_update_func_with_overlap(
566-
*args,
567-
delayed=delayed,
568-
depth=depth,
569-
dim=self.partition_properties.dim,
570-
func=func,
571-
fs=self.fs,
572-
immutable=self._immutable,
573-
selected_partitions=selected_partitions,
574-
selected_variables=selected_variables,
575-
trim=trim,
576-
**kwargs)
573+
trim=trim)
577574

578575
client: dask.distributed.Client = dask_utils.get_client()
579576

@@ -582,7 +579,11 @@ def update(
582579
or dask_utils.dask_workers(client, cores_only=True))
583580
storage.execute_transaction(
584581
client, self.synchronizer,
585-
client.map(local_func, tuple(batches), key=func.__name__))
582+
client.map(local_func,
583+
tuple(batches),
584+
key=func.__name__,
585+
func_args=args,
586+
func_kwargs=kwargs))
586587
tuple(map(self.fs.invalidate_cache, selected_partitions))
587588

588589
def drop_variable(

zcollection/collection/callable_objects.py

+2 -1
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,8 @@
1616

1717
#: Function type to load and call a callback function of type
1818
#: :class:`PartitionCallable`.
19-
WrappedPartitionCallable = Callable[[Sequence[str]], None]
19+
WrappedPartitionCallable = Callable[[Sequence[str], list[Any], dict[str, Any]],
20+
None]
2021

2122

2223
#: pylint: disable=too-few-public-methods

zcollection/collection/detail.py

+8 -13
Original file line number | Diff line number | Diff line change
@@ -252,13 +252,11 @@ def calculate_slice(
252252

253253

254254
def _wrap_update_func(
255-
*args,
256255
delayed: bool,
257256
func: UpdateCallable,
258257
fs: fsspec.AbstractFileSystem,
259258
immutable: str | None,
260259
selected_variables: Iterable[str] | None,
261-
**kwargs,
262260
) -> WrappedPartitionCallable:
263261
"""Wrap an update function taking a partition's dataset as input and
264262
returning variable's values as a numpy array.
@@ -271,21 +269,21 @@ def _wrap_update_func(
271269
selected_variables: Name of the variables to load from the dataset.
272270
If None, all variables are loaded.
273271
trim: Whether to trim the overlap.
274-
*args: Positional arguments to pass to the function.
275-
**kwargs: Keyword arguments to pass to the function.
276272
277273
Returns:
278274
The wrapped function that takes a set of dataset partitions and the
279275
variable name as input and returns the variable's values as a numpy
280276
array.
281277
"""
282278

283-
def wrap_function(partitions: Iterable[str]) -> None:
279+
def wrap_function(partitions: Iterable[str], func_args: list[Any],
280+
func_kwargs: dict[str, Any]) -> None:
284281
# Applying function for each partition's data
285282
for partition in partitions:
286283
zds: dataset.Dataset = _load_dataset(delayed, fs, immutable,
287284
partition, selected_variables)
288-
dictionary: dict[str, ArrayLike] = func(zds, *args, **kwargs)
285+
dictionary: dict[str, ArrayLike] = func(zds, *func_args,
286+
**func_kwargs)
289287
tuple(
290288
update_zarr_array( # type: ignore[func-returns-value]
291289
dirname=join_path(partition, varname),
@@ -297,7 +295,6 @@ def wrap_function(partitions: Iterable[str]) -> None:
297295

298296

299297
def _wrap_update_func_with_overlap(
300-
*args,
301298
delayed: bool,
302299
depth: int,
303300
dim: str,
@@ -307,7 +304,6 @@ def _wrap_update_func_with_overlap(
307304
selected_partitions: Sequence[str],
308305
selected_variables: Iterable[str] | None,
309306
trim: bool,
310-
**kwargs,
311307
) -> WrappedPartitionCallable:
312308
"""Wrap an update function taking a partition's dataset as input and
313309
returning variable's values as a numpy array.
@@ -323,8 +319,6 @@ def _wrap_update_func_with_overlap(
323319
selected_variables: Name of the variables to load from the dataset.
324320
If None, all variables are loaded.
325321
trim: Whether to trim the overlap.
326-
*args: Positional arguments to pass to the function.
327-
**kwargs: Keyword arguments to pass to the function.
328322
329323
Returns:
330324
The wrapped function that takes a set of dataset partitions and the
@@ -334,7 +328,8 @@ def _wrap_update_func_with_overlap(
334328
if depth < 0:
335329
raise ValueError('Depth must be non-negative.')
336330

337-
def wrap_function(partitions: Sequence[str]) -> None:
331+
def wrap_function(partitions: Sequence[str], func_args: list[Any],
332+
func_kwargs: dict[str, Any]) -> None:
338333
# Applying function for each partition's data
339334
for partition in partitions:
340335

@@ -353,15 +348,15 @@ def wrap_function(partitions: Sequence[str]) -> None:
353348
selected_variables=selected_variables)
354349
# pylint: enable=duplicate-code
355350

356-
_update_with_overlap(*args,
351+
_update_with_overlap(*func_args,
357352
func=func,
358353
zds=zds,
359354
indices=indices,
360355
dim=dim,
361356
fs=fs,
362357
path=partition,
363358
trim=trim,
364-
**kwargs)
359+
**func_kwargs)
365360

366361
return wrap_function
367362

zcollection/view/__init__.py

+3 -5
Original file line number | Diff line number | Diff line change
@@ -510,8 +510,6 @@ def update(
510510
wrap_function = _wrap_update_func(
511511
func,
512512
self.fs,
513-
*args,
514-
**kwargs,
515513
)
516514
else:
517515
if selected_variables is not None and len(
@@ -527,8 +525,6 @@ def update(
527525
self.fs,
528526
self.view_ref,
529527
trim,
530-
*args,
531-
**kwargs,
532528
)
533529

534530
batchs: Iterator[Sequence[Any]] = dask_utils.split_sequence(
@@ -538,7 +534,9 @@ def update(
538534
wrap_function,
539535
tuple(batchs),
540536
key=func.__name__,
541-
base_dir=self.base_dir)
537+
base_dir=self.base_dir,
538+
func_args=args,
539+
func_kwargs=kwargs)
542540
storage.execute_transaction(client, self.synchronizer, awaitables)
543541

544542
# pylint: disable=duplicate-code

zcollection/view/detail.py

+12 -15
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
"""
99
from __future__ import annotations
1010

11+
from typing import Any
1112
import base64
1213
from collections.abc import Callable, Iterable, Iterator, Sequence
1314
import dataclasses
@@ -37,8 +38,9 @@
3738
from ..type_hints import ArrayLike, NDArray
3839

3940
#: Type of the function used to update a view.
40-
ViewUpdateCallable = \
41-
Callable[[Iterable[tuple[dataset.Dataset, str]], str], None]
41+
ViewUpdateCallable = Callable[
42+
[Iterable[tuple[dataset.Dataset, str]], str, list[Any], dict[str,
43+
Any]], None]
4244

4345
#: Name of the file that contains the checksum of the view.
4446
CHECKSUM_FILE = '.checksum'
@@ -350,28 +352,26 @@ def calculate_slice(
350352
def _wrap_update_func(
351353
func: collection.UpdateCallable,
352354
fs: fsspec.AbstractFileSystem,
353-
*args,
354-
**kwargs,
355355
) -> ViewUpdateCallable:
356356
"""Wrap an update function taking a list of partition's dataset and
357357
partition's path as input and returning None.
358358
359359
Args:
360360
func: The update function.
361361
fs: The file system used to access the variables in the view.
362-
*args: The arguments of the update function.
363-
**kwargs: The keyword arguments of the update function.
364362
365363
Returns:
366364
The wrapped function.
367365
"""
368366

369367
def wrap_function(parameters: Iterable[tuple[dataset.Dataset, str]],
370-
base_dir: str) -> None:
368+
base_dir: str, func_args: list[Any],
369+
func_kwargs: dict[str, Any]) -> None:
371370
"""Wrap the function to be applied to the dataset."""
372371
for zds, partition in parameters:
373372
# Applying function on partition's data
374-
dictionary: dict[str, ArrayLike] = func(zds, *args, **kwargs)
373+
dictionary: dict[str, ArrayLike] = func(zds, *func_args,
374+
**func_kwargs)
375375
tuple(
376376
update_zarr_array( # type: ignore[func-returns-value]
377377
dirname=join_path(base_dir, partition, varname),
@@ -389,8 +389,6 @@ def _wrap_update_func_overlap(
389389
fs: fsspec.AbstractFileSystem,
390390
view_ref: collection.Collection,
391391
trim: bool,
392-
*args,
393-
**kwargs,
394392
) -> ViewUpdateCallable:
395393
"""Wrap an update function taking a list of partition's dataset and
396394
partition's path as input and returning None.
@@ -402,8 +400,6 @@ def _wrap_update_func_overlap(
402400
fs: The file system used to access the variables in the view.
403401
view_ref: The view reference.
404402
trim: If True, trim the dataset to the overlap.
405-
*args: The arguments of the update function.
406-
**kwargs: The keyword arguments of the update function.
407403
408404
Returns:
409405
The wrapped function.
@@ -414,7 +410,8 @@ def _wrap_update_func_overlap(
414410
raise ValueError('The depth must be positive')
415411

416412
def wrap_function(parameters: Iterable[tuple[dataset.Dataset, str]],
417-
base_dir: str) -> None:
413+
base_dir: str, func_args: list[Any],
414+
func_kwargs: dict[str, Any]) -> None:
418415
"""Wrap the function to be applied to the dataset."""
419416
zds: dataset.Dataset
420417
indices: slice
@@ -425,15 +422,15 @@ def wrap_function(parameters: Iterable[tuple[dataset.Dataset, str]],
425422
# pylint: disable=duplicate-code
426423
# False positive with the function _wrap_update_func_with_overlap
427424
# defined in the module zcollection.collection.detail
428-
_update_with_overlap(*args,
425+
_update_with_overlap(*func_args,
429426
func=func,
430427
zds=zds,
431428
indices=indices,
432429
dim=dim,
433430
fs=fs,
434431
path=join_path(base_dir, partition),
435432
trim=trim,
436-
**kwargs)
433+
**func_kwargs)
437434
# pylint: enable=duplicate-code
438435

439436
return wrap_function

0 commit comments

Comments
 (0)