Skip to content

Commit

Permalink
ressurect tests after tuple-out fix
Browse files Browse the repository at this point in the history
  • Loading branch information
lgray committed Feb 15, 2024
1 parent e26f51f commit 6b9cda8
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 17 deletions.
25 changes: 14 additions & 11 deletions src/coffea/dataset_tools/apply_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,16 +139,14 @@ def apply_to_dataset(
out = None
if parallelize_with_dask:
(wired_events,) = _pack_meta_to_wire(events)
out = (
dask.delayed(
lambda: lz4.frame.compress(
cloudpickle.dumps(
partial(_apply_analysis_wire, analysis, wired_events)()
),
compression_level=6,
)
)(),
)
out = dask.delayed(
lambda: lz4.frame.compress(
cloudpickle.dumps(
partial(_apply_analysis_wire, analysis, wired_events)()
),
compression_level=6,
)
)()
dask.base.function_cache.clear()
else:
out = analysis(events)
Expand Down Expand Up @@ -217,14 +215,15 @@ def apply_to_fileset(
events[name], analyses_to_compute[name], report[name] = dataset_out
elif len(dataset_out) == 2:
events[name], analyses_to_compute[name] = dataset_out
print(dataset_out)
else:
raise ValueError(
"apply_to_dataset only returns (events, outputs) or (events, outputs, reports)"
)
elif isinstance(dataset_out, tuple) and len(dataset_out) == 3:
events[name], out[name], report[name] = dataset_out
elif isinstance(dataset_out, tuple) and len(dataset_out) == 2:
events[name], out[name] = dataset_out[0]
events[name], out[name] = dataset_out
else:
raise ValueError(
"apply_to_dataset only returns (events, outputs) or (events, outputs, reports)"
Expand All @@ -238,6 +237,10 @@ def apply_to_fileset(
)
(out[name],) = _unpack_meta_from_wire(*dataset_out_wire)

for name in out:
if isinstance(out[name], tuple) and len(out[name]) == 1:
out[name] = out[name][0]

if len(report) > 0:
return events, out, report
return events, out
Expand Down
16 changes: 10 additions & 6 deletions tests/test_dataset_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,8 @@ def test_tuple_data_manipulation_output(allow_read_errors_with_report):

if allow_read_errors_with_report:
assert isinstance(out, tuple)
assert len(out) == 2
out, report = out
assert len(out) == 3
_, out, report = out
assert isinstance(out, dict)
assert isinstance(report, dict)
assert out.keys() == {"ZJets", "Data"}
Expand All @@ -236,8 +236,10 @@ def test_tuple_data_manipulation_output(allow_read_errors_with_report):
assert isinstance(report["ZJets"], dask_awkward.Array)
assert isinstance(report["Data"], dask_awkward.Array)
else:
assert isinstance(out, dict)
assert isinstance(out, tuple)
assert len(out) == 2
_, out = out
assert isinstance(out, dict)
assert out.keys() == {"ZJets", "Data"}
assert isinstance(out["ZJets"], tuple)
assert isinstance(out["Data"], tuple)
Expand All @@ -255,8 +257,8 @@ def test_tuple_data_manipulation_output(allow_read_errors_with_report):

if allow_read_errors_with_report:
assert isinstance(out, tuple)
assert len(out) == 2
out, report = out
assert len(out) == 3
_, out, report = out
assert isinstance(out, dict)
assert isinstance(report, dict)
assert out.keys() == {"ZJets", "Data"}
Expand All @@ -271,8 +273,10 @@ def test_tuple_data_manipulation_output(allow_read_errors_with_report):
assert isinstance(report["ZJets"], dask_awkward.Array)
assert isinstance(report["Data"], dask_awkward.Array)
else:
assert isinstance(out, dict)
assert isinstance(out, tuple)
assert len(out) == 2
_, out = out
assert isinstance(out, dict)
assert out.keys() == {"ZJets", "Data"}
assert isinstance(out["ZJets"], tuple)
assert isinstance(out["Data"], tuple)
Expand Down

0 comments on commit 6b9cda8

Please sign in to comment.