From b958664e83bb177ac798aa0b88ae00f2373f7a94 Mon Sep 17 00:00:00 2001 From: Lindsey Gray Date: Thu, 15 Feb 2024 16:57:42 -0600 Subject: [PATCH] resurrect tests after tuple-out fix --- src/coffea/dataset_tools/apply_processor.py | 25 ++++++++++++--------- tests/test_dataset_tools.py | 16 ++++++++----- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/src/coffea/dataset_tools/apply_processor.py b/src/coffea/dataset_tools/apply_processor.py index 9293db42c..b93b4c6fd 100644 --- a/src/coffea/dataset_tools/apply_processor.py +++ b/src/coffea/dataset_tools/apply_processor.py @@ -139,16 +139,14 @@ def apply_to_dataset( out = None if parallelize_with_dask: (wired_events,) = _pack_meta_to_wire(events) - out = ( - dask.delayed( - lambda: lz4.frame.compress( - cloudpickle.dumps( - partial(_apply_analysis_wire, analysis, wired_events)() - ), - compression_level=6, - ) - )(), - ) + out = dask.delayed( + lambda: lz4.frame.compress( + cloudpickle.dumps( + partial(_apply_analysis_wire, analysis, wired_events)() + ), + compression_level=6, + ) + )() dask.base.function_cache.clear() else: out = analysis(events) @@ -217,6 +215,7 @@ def apply_to_fileset( events[name], analyses_to_compute[name], report[name] = dataset_out elif len(dataset_out) == 2: events[name], analyses_to_compute[name] = dataset_out + print(dataset_out) else: raise ValueError( "apply_to_dataset only returns (events, outputs) or (events, outputs, reports)" @@ -224,7 +223,7 @@ def apply_to_fileset( elif isinstance(dataset_out, tuple) and len(dataset_out) == 3: events[name], out[name], report[name] = dataset_out elif isinstance(dataset_out, tuple) and len(dataset_out) == 2: - events[name], out[name] = dataset_out[0] + events[name], out[name] = dataset_out else: raise ValueError( "apply_to_dataset only returns (events, outputs) or (events, outputs, reports)" @@ -238,6 +237,10 @@ def apply_to_fileset( ) (out[name],) = _unpack_meta_from_wire(*dataset_out_wire) + for name in out: + if isinstance(out[name], 
tuple) and len(out[name]) == 1: + out[name] = out[name][0] + if len(report) > 0: return events, out, report return events, out diff --git a/tests/test_dataset_tools.py b/tests/test_dataset_tools.py index 324c287a6..8fd9235c1 100644 --- a/tests/test_dataset_tools.py +++ b/tests/test_dataset_tools.py @@ -220,8 +220,8 @@ def test_tuple_data_manipulation_output(allow_read_errors_with_report): if allow_read_errors_with_report: assert isinstance(out, tuple) - assert len(out) == 2 - out, report = out + assert len(out) == 3 + _, out, report = out assert isinstance(out, dict) assert isinstance(report, dict) assert out.keys() == {"ZJets", "Data"} @@ -236,8 +236,10 @@ def test_tuple_data_manipulation_output(allow_read_errors_with_report): assert isinstance(report["ZJets"], dask_awkward.Array) assert isinstance(report["Data"], dask_awkward.Array) else: - assert isinstance(out, dict) + assert isinstance(out, tuple) assert len(out) == 2 + _, out = out + assert isinstance(out, dict) assert out.keys() == {"ZJets", "Data"} assert isinstance(out["ZJets"], tuple) assert isinstance(out["Data"], tuple) @@ -255,8 +257,8 @@ def test_tuple_data_manipulation_output(allow_read_errors_with_report): if allow_read_errors_with_report: assert isinstance(out, tuple) - assert len(out) == 2 - out, report = out + assert len(out) == 3 + _, out, report = out assert isinstance(out, dict) assert isinstance(report, dict) assert out.keys() == {"ZJets", "Data"} @@ -271,8 +273,10 @@ def test_tuple_data_manipulation_output(allow_read_errors_with_report): assert isinstance(report["ZJets"], dask_awkward.Array) assert isinstance(report["Data"], dask_awkward.Array) else: - assert isinstance(out, dict) + assert isinstance(out, tuple) assert len(out) == 2 + _, out = out + assert isinstance(out, dict) assert out.keys() == {"ZJets", "Data"} assert isinstance(out["ZJets"], tuple) assert isinstance(out["Data"], tuple)