diff --git a/src/coffea/dataset_tools/__init__.py b/src/coffea/dataset_tools/__init__.py index 647fd336c..cc9488672 100644 --- a/src/coffea/dataset_tools/__init__.py +++ b/src/coffea/dataset_tools/__init__.py @@ -1,4 +1,9 @@ -from coffea.dataset_tools.apply_processor import apply_to_dataset, apply_to_fileset +from coffea.dataset_tools.apply_processor import ( + apply_to_dataset, + apply_to_fileset, + load_taskgraph, + save_taskgraph, +) from coffea.dataset_tools.manipulations import ( filter_files, get_failed_steps_for_dataset, @@ -14,6 +19,8 @@ "preprocess", "apply_to_dataset", "apply_to_fileset", + "save_taskgraph", + "load_taskgraph", "max_chunks", "slice_chunks", "filter_files", diff --git a/src/coffea/dataset_tools/apply_processor.py b/src/coffea/dataset_tools/apply_processor.py index 2c135a393..cb592ae94 100644 --- a/src/coffea/dataset_tools/apply_processor.py +++ b/src/coffea/dataset_tools/apply_processor.py @@ -17,7 +17,7 @@ ) from coffea.nanoevents import BaseSchema, NanoAODSchema, NanoEventsFactory from coffea.processor import ProcessorABC -from coffea.util import decompress_form +from coffea.util import decompress_form, load, save DaskOutputBaseType = Union[ dask.base.DaskMethodsMixin, @@ -48,8 +48,6 @@ def _pack_meta_to_wire(*collections): attrs=unpacked[i]._meta.attrs, ) packed_out = repacker(output) - if len(packed_out) == 1: - return packed_out[0] return packed_out @@ -68,21 +66,13 @@ def _unpack_meta_from_wire(*collections): attrs=unpacked[i]._meta.attrs, ) packed_out = repacker(output) - if len(packed_out) == 1: - return packed_out[0] return packed_out -def _apply_analysis_wire(analysis, events_and_maybe_report_wire): - events = _unpack_meta_from_wire(events_and_maybe_report_wire) - report = None - if isinstance(events, tuple): - events, report = events +def _apply_analysis_wire(analysis, events_wire): + (events,) = _unpack_meta_from_wire(events_wire) events._meta.attrs["@original_array"] = events - out = analysis(events) - if report is not None: - return _pack_meta_to_wire(out, report) return _pack_meta_to_wire(out) @@ -145,16 +135,14 @@ def apply_to_dataset( out = None if parallelize_with_dask: - if not isinstance(events_and_maybe_report, tuple): - events_and_maybe_report = (events_and_maybe_report,) - wired_events = _pack_meta_to_wire(*events_and_maybe_report) + (wired_events,) = _pack_meta_to_wire(events) out = dask.delayed(partial(_apply_analysis_wire, analysis, wired_events))() else: out = analysis(events) if report is not None: - return out, report - return out + return events, out, report + return events, out def apply_to_fileset( @@ -184,11 +172,14 @@ def apply_to_fileset( Returns ------- + events: dict[str, dask_awkward.Array] + The NanoEvents objects the analysis function was applied to. out : dict[str, DaskOutputType] The output of the analysis workflow applied to the datasets, keyed by dataset name. report : dask_awkward.Array, optional The file access report for running the analysis on the input dataset. Needs to be computed in simultaneously with the analysis to be accurate. """ + events = {} out = {} analyses_to_compute = {} report = {} @@ -206,24 +197,92 @@ def apply_to_fileset( parallelize_with_dask, ) if parallelize_with_dask: - analyses_to_compute[name] = dataset_out - elif isinstance(dataset_out, tuple): - out[name], report[name] = dataset_out + if len(dataset_out) == 3: + events[name], analyses_to_compute[name], report[name] = dataset_out + elif len(dataset_out) == 2: + events[name], analyses_to_compute[name] = dataset_out + else: + raise ValueError( + "apply_to_dataset only returns (events, outputs) or (events, outputs, reports)" + ) + elif isinstance(dataset_out, tuple) and len(dataset_out) == 3: + events[name], out[name], report[name] = dataset_out + elif isinstance(dataset_out, tuple) and len(dataset_out) == 2: + events[name], out[name] = dataset_out else: - out[name] = dataset_out + raise ValueError( + "apply_to_dataset only returns (events, outputs) or (events, outputs, reports)" + ) if parallelize_with_dask: (calculated_graphs,) = dask.compute(analyses_to_compute, scheduler=scheduler) for name, dataset_out_wire in calculated_graphs.items(): - to_unwire = dataset_out_wire - if not isinstance(dataset_out_wire, tuple): - to_unwire = (dataset_out_wire,) - dataset_out = _unpack_meta_from_wire(*to_unwire) - if isinstance(dataset_out, tuple): - out[name], report[name] = dataset_out - else: - out[name] = dataset_out + (out[name],) = _unpack_meta_from_wire(*dataset_out_wire) if len(report) > 0: - return out, report - return out + return events, out, report + return events, out + + +def save_taskgraph(filename, events, *data_products, optimize_graph=False): + """ + Save a task graph and its originating nanoevents to a file + Parameters + ---------- + filename: str + Where to save the resulting serialized taskgraph and nanoevents. + Suggested postfix ".hlg", after dask's HighLevelGraph object. + events: dict[str, dask_awkward.Array] + A dictionary of nanoevents objects. + data_products: dict[str, DaskOutputBaseType] + The data products resulting from applying an analysis to + a NanoEvents object. This may include report objects. + optimize_graph: bool, default False + Whether or not to save the task graph in its optimized form. + + Returns + ------- + """ + (events_wire,) = _pack_meta_to_wire(events) + + if len(data_products) == 0: + raise ValueError( + "You must supply at least one analysis data product to save a task graph!" + ) + + data_products_out = data_products + if optimize_graph: + data_products_out = dask.optimize(data_products) + + data_products_wire = _pack_meta_to_wire(*data_products_out) + + save( + { + "events": events_wire, + "data_products": data_products_wire, + "optimized": optimize_graph, + }, + filename, + ) + + +def load_taskgraph(filename): + """ + Load a task graph and its originating nanoevents from a file. + Parameters + ---------- + filename: str + The file from which to load the task graph. + Returns + _______ + """ + graph_information_wire = load(filename) + + (events,) = _unpack_meta_from_wire(graph_information_wire["events"]) + (data_products,) = _unpack_meta_from_wire(*graph_information_wire["data_products"]) + optimized = graph_information_wire["optimized"] + + for dataset_name in events: + events[dataset_name]._meta.attrs["@original_array"] = events[dataset_name] + + return events, data_products, optimized diff --git a/src/coffea/util.py b/src/coffea/util.py index bfb0b5119..55a99232a 100644 --- a/src/coffea/util.py +++ b/src/coffea/util.py @@ -36,21 +36,20 @@ import lz4.frame -def load(filename): +def load(filename, mode="rb"): """Load a coffea file from disk""" - with lz4.frame.open(filename) as fin: + with lz4.frame.open(filename, mode) as fin: output = cloudpickle.load(fin) return output -def save(output, filename): +def save(output, filename, mode="wb"): """Save a coffea object or collection thereof to disk This function can accept any picklable object. Suggested suffix: ``.coffea`` """ - with lz4.frame.open(filename, "wb") as fout: - thepickle = cloudpickle.dumps(output) - fout.write(thepickle) + with lz4.frame.open(filename, mode) as fout: + cloudpickle.dump(output, fout) def _hex(string): diff --git a/tests/test_dataset_tools.py b/tests/test_dataset_tools.py index 627a51811..e68800d5f 100644 --- a/tests/test_dataset_tools.py +++ b/tests/test_dataset_tools.py @@ -7,9 +7,11 @@ apply_to_fileset, filter_files, get_failed_steps_for_fileset, + load_taskgraph, max_chunks, max_files, preprocess, + save_taskgraph, slice_chunks, slice_files, ) @@ -202,7 +204,7 @@ def test_apply_to_fileset(proc_and_schema, delayed_taskgraph_calc): proc, schemaclass = proc_and_schema with Client() as _: - to_compute = apply_to_fileset( + _, to_compute = apply_to_fileset( proc(), _runnable_result, schemaclass=schemaclass, @@ -215,7 +217,7 @@ def test_apply_to_fileset(proc_and_schema, delayed_taskgraph_calc): assert out["Data"]["cutflow"]["Data_pt"] == 84 assert out["Data"]["cutflow"]["Data_mass"] == 66 - to_compute = apply_to_fileset( + _, to_compute = apply_to_fileset( proc(), max_chunks(_runnable_result, 1), schemaclass=schemaclass, @@ -240,7 +242,7 @@ def test_apply_to_fileset_hinted_form(): save_form=True, ) - to_compute = apply_to_fileset( + _, to_compute = apply_to_fileset( NanoEventsProcessor(), dataset_runnable, schemaclass=NanoAODSchema, @@ -445,14 +447,14 @@ def test_slice_chunks(): @pytest.mark.parametrize("delayed_taskgraph_calc", [True, False]) def test_recover_failed_chunks(delayed_taskgraph_calc): with Client() as _: - to_compute = apply_to_fileset( + _, to_compute, reports = apply_to_fileset( NanoEventsProcessor(), _starting_fileset_with_steps, schemaclass=NanoAODSchema, uproot_options={"allow_read_errors_with_report": True}, parallelize_with_dask=delayed_taskgraph_calc, ) - out, reports = dask.compute(*to_compute) + out, reports = dask.compute(to_compute, reports) failed_fset = get_failed_steps_for_fileset(_starting_fileset_with_steps, reports) assert failed_fset == { @@ -474,3 +476,50 @@ def test_recover_failed_chunks(delayed_taskgraph_calc): } } } + + +@pytest.mark.parametrize( + "proc_and_schema", + [(NanoTestProcessor, BaseSchema), (NanoEventsProcessor, NanoAODSchema)], +) +@pytest.mark.parametrize( + "with_report", + [True, False], +) +def test_task_graph_serialization(proc_and_schema, with_report): + proc, schemaclass = proc_and_schema + + with Client() as _: + output = apply_to_fileset( + proc(), + _runnable_result, + schemaclass=schemaclass, + parallelize_with_dask=False, + uproot_options={"allow_read_errors_with_report": with_report}, + ) + + events = output[0] + to_compute = output[1:] + + save_taskgraph( + "./test_task_graph_serialization.hlg", + events, + to_compute, + optimize_graph=False, + ) + + _, to_compute_serdes, is_optimized = load_taskgraph( + "./test_task_graph_serialization.hlg" + ) + + print(to_compute_serdes) + + if len(to_compute_serdes) > 1: + (out, _) = dask.compute(*to_compute_serdes) + else: + (out,) = dask.compute(*to_compute_serdes) + + assert out["ZJets"]["cutflow"]["ZJets_pt"] == 18 + assert out["ZJets"]["cutflow"]["ZJets_mass"] == 6 + assert out["Data"]["cutflow"]["Data_pt"] == 84 + assert out["Data"]["cutflow"]["Data_mass"] == 66