Skip to content

Commit

Permalink
make interface changes backwards compatible
Browse files Browse the repository at this point in the history
  • Loading branch information
lgray committed Mar 6, 2024
1 parent 56164c0 commit c5a53c7
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 4 deletions.
23 changes: 19 additions & 4 deletions src/coffea/dataset_tools/apply_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,11 @@ def apply_to_dataset(
metadata: dict[Hashable, Any] = {},
uproot_options: dict[str, Any] = {},
parallelize_with_dask: bool = False,
) -> DaskOutputType | tuple[DaskOutputType, dask_awkward.Array]:
) -> (
DaskOutputType
| tuple[DaskOutputType, dask_awkward.Array]
| tuple[dask_awkward.Array, DaskOutputType, dask_awkward.Array]
):
"""
Apply the supplied function or processor to the supplied dataset.
Parameters
Expand All @@ -103,6 +107,8 @@ def apply_to_dataset(
Options to pass to uproot. Pass at least {"allow_read_errors_with_report": True} to turn on file access reports.
parallelize_with_dask: bool, default False
        Create dask.delayed objects that will return the computable dask collections for the analysis when computed.
    return_events: bool, default False
        Return the created events object, or not.
Returns
-------
Expand Down Expand Up @@ -165,7 +171,16 @@ def apply_to_fileset(
uproot_options: dict[str, Any] = {},
parallelize_with_dask: bool = False,
scheduler: Callable | str | None = None,
) -> dict[str, DaskOutputType] | tuple[dict[str, DaskOutputType], dask_awkward.Array]:
return_events: bool = False,
) -> (
dict[str, DaskOutputType]
| tuple[dict[str, DaskOutputType], dict[str, dask_awkward.Array]]
| tuple[
dict[str, dask_awkward.Array],
dict[str, DaskOutputType],
dict[str, dask_awkward.Array],
]
):
"""
Apply the supplied function or processor to the supplied fileset (set of datasets).
Parameters
Expand Down Expand Up @@ -242,8 +257,8 @@ def apply_to_fileset(
out[name] = out[name][0]

if len(report) > 0:
return events, out, report
return events, out
return (events, out, report) if return_events else (out, report)
return (events, out) if return_events else out


def save_taskgraph(filename, events, *data_products, optimize_graph=False):
Expand Down
7 changes: 7 additions & 0 deletions tests/test_dataset_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def test_tuple_data_manipulation_output(allow_read_errors_with_report):
_my_analysis_output_2,
_runnable_result,
uproot_options={"allow_read_errors_with_report": allow_read_errors_with_report},
return_events=True,
)

if allow_read_errors_with_report:
Expand Down Expand Up @@ -253,6 +254,7 @@ def test_tuple_data_manipulation_output(allow_read_errors_with_report):
_my_analysis_output_3,
_runnable_result,
uproot_options={"allow_read_errors_with_report": allow_read_errors_with_report},
return_events=True,
)

if allow_read_errors_with_report:
Expand Down Expand Up @@ -301,6 +303,7 @@ def test_apply_to_fileset(proc_and_schema, delayed_taskgraph_calc):
_runnable_result,
schemaclass=schemaclass,
parallelize_with_dask=delayed_taskgraph_calc,
return_events=True,
)
out = dask.compute(to_compute)[0]

Expand All @@ -314,6 +317,7 @@ def test_apply_to_fileset(proc_and_schema, delayed_taskgraph_calc):
max_chunks(_runnable_result, 1),
schemaclass=schemaclass,
parallelize_with_dask=delayed_taskgraph_calc,
return_events=True,
)
out = dask.compute(to_compute)[0]

Expand All @@ -338,6 +342,7 @@ def test_apply_to_fileset_hinted_form():
NanoEventsProcessor(),
dataset_runnable,
schemaclass=NanoAODSchema,
return_events=True,
)
out = dask.compute(to_compute)[0]

Expand Down Expand Up @@ -554,6 +559,7 @@ def test_recover_failed_chunks(delayed_taskgraph_calc):
schemaclass=NanoAODSchema,
uproot_options={"allow_read_errors_with_report": True},
parallelize_with_dask=delayed_taskgraph_calc,
return_events=True,
)
out, reports = dask.compute(to_compute, reports)

Expand Down Expand Up @@ -597,6 +603,7 @@ def test_task_graph_serialization(proc_and_schema, with_report):
schemaclass=schemaclass,
parallelize_with_dask=False,
uproot_options={"allow_read_errors_with_report": with_report},
return_events=True,
)

events = output[0]
Expand Down

0 comments on commit c5a53c7

Please sign in to comment.