Skip to content

Commit

Permalink
add in hooks for delayed calculation of task graph
Browse files Browse the repository at this point in the history
  • Loading branch information
lgray committed Jan 31, 2024
1 parent 41f0162 commit d9b7403
Showing 1 changed file with 72 additions and 8 deletions.
80 changes: 72 additions & 8 deletions src/coffea/dataset_tools/apply_processor.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import copy
from functools import partial
from typing import Any, Callable, Dict, Hashable, List, Set, Tuple, Union

import awkward
import dask.base
import dask.delayed
import dask_awkward

from coffea.dataset_tools.preprocess import (
Expand All @@ -31,12 +33,66 @@
GenericHEPAnalysis = Callable[[dask_awkward.Array], DaskOutputType]


def _pack_meta_to_wire(*collections):
unpacked, repacker = dask.base.unpack_collections(*collections)

output = []
for i in range(len(unpacked)):
output.append(unpacked[i])
if isinstance(
unpacked[i], (dask_awkward.Array, dask_awkward.Record, dask_awkward.Scalar)
):
output[-1]._meta = awkward.Array(
unpacked[i]._meta.layout.form.length_zero_array(),
behavior=unpacked[i]._meta.behavior,
attrs=unpacked[i]._meta.attrs,
)
packed_out = repacker(output)
if len(packed_out) == 1:
return packed_out[0]
return packed_out


def _unpack_meta_from_wire(*collections):
unpacked, repacker = dask.base.unpack_collections(*collections)

output = []
for i in range(len(unpacked)):
output.append(unpacked[i])
if isinstance(
unpacked[i], (dask_awkward.Array, dask_awkward.Record, dask_awkward.Scalar)
):
output[-1]._meta = awkward.Array(
unpacked[i]._meta.layout.to_typetracer(forget_length=True),
behavior=unpacked[i]._meta.behavior,
attrs=unpacked[i]._meta.attrs,
)
packed_out = repacker(output)
if len(packed_out) == 1:
return packed_out[0]
return packed_out


def _apply_analysis(analysis, events_and_maybe_report):
events = events_and_maybe_report
report = None
if isinstance(events_and_maybe_report, tuple):
events, report = events_and_maybe_report

out = analysis(events)

if report is not None:
return out, report
return out


def apply_to_dataset(
data_manipulation: ProcessorABC | GenericHEPAnalysis,
dataset: DatasetSpec | DatasetSpecOptional,
schemaclass: BaseSchema = NanoAODSchema,
metadata: dict[Hashable, Any] = {},
uproot_options: dict[str, Any] = {},
parallelize_with_dask: bool = False,
) -> DaskOutputType | tuple[DaskOutputType, dask_awkward.Array]:
"""
Apply the supplied function or processor to the supplied dataset.
Expand All @@ -52,6 +108,8 @@ def apply_to_dataset(
Metadata for the dataset that is accessible by the input analysis. Should also be dask-serializable.
uproot_options: dict[str, Any], default {}
Options to pass to uproot. Pass at least {"allow_read_errors_with_report": True} to turn on file access reports.
parallelize_with_dask: bool, default False
Create dask.delayed objects that will return the the computable dask collections for the analysis when computed.
Returns
-------
Expand All @@ -64,26 +122,32 @@ def apply_to_dataset(
if maybe_base_form is not None:
maybe_base_form = awkward.forms.from_json(decompress_form(maybe_base_form))
files = dataset["files"]
events = NanoEventsFactory.from_root(
events_and_maybe_report = NanoEventsFactory.from_root(
files,
metadata=metadata,
schemaclass=schemaclass,
known_base_form=maybe_base_form,
uproot_options=uproot_options,
).events()

report = None
if isinstance(events, tuple):
events, report = events

out = None
analysis = None
if isinstance(data_manipulation, ProcessorABC):
out = data_manipulation.process(events)
analysis = data_manipulation.process
elif isinstance(data_manipulation, Callable):
out = data_manipulation(events)
out = data_manipulation
else:
raise ValueError("data_manipulation must either be a ProcessorABC or Callable")

out = None
if parallelize_with_dask:
out = dask.delayed(partial(_apply_analysis, analysis, events_and_maybe_report))
else:
out = _apply_analysis(analysis, events_and_maybe_report)

report = None
if isinstance(out, tuple):
out, report = out

if report is not None:
return out, report
return out
Expand Down

0 comments on commit d9b7403

Please sign in to comment.