Skip to content

Commit

Permalink
Add weighting function for several scenarios (#567)
Browse files Browse the repository at this point in the history
* implement weighting for several scenarios and members

* implement tests

* work around datatree.testing

* extend tests to root dt

* docs

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Mathias Hauser <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
  • Loading branch information
4 people authored Dec 3, 2024
1 parent da5803b commit 6552d66
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ Weighted operations: calculate global mean
~core.weighted.global_mean
~core.weighted.lat_weights
~core.weighted.weighted_mean
~core.weighted.equal_scenario_weights_from_datatree

Geospatial
----------
Expand Down
87 changes: 87 additions & 0 deletions mesmer/core/weighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import xarray as xr
from datatree import DataTree, map_over_subtree


def _weighted_if_dim(obj, weights, dims):
Expand Down Expand Up @@ -106,3 +107,89 @@ def global_mean(data, weights=None, x_dim="lon", y_dim="lat"):
weights = lat_weights(data[y_dim])

return weighted_mean(data, weights, [x_dim, y_dim])


def equal_scenario_weights_from_datatree(
    dt: DataTree, ens_dim: str = "member", time_dim: str = "time"
) -> DataTree:
    """
    Create a DataTree isomorphic to ``dt`` holding equal-scenario weights.

    The weights are meant to weight the ensemble members of each scenario such
    that each scenario contributes equally to some fitting procedure.
    The weight of each member = 1 / number of members in the scenario, so
    weights = 1 / ds[ens_dim].size.
    Thus, if all scenarios have the same number of members, all weights will be
    equal. If one scenario has more members than the others, its weights will
    be smaller.
    Weights are always along the time and ens dim, if there are more dimensions
    in a dataset, they will be dropped.

    Parameters
    ----------
    dt : DataTree
        DataTree holding the ``xr.Datasets`` for which the weights should be
        created. Each dataset must have at least ens_dim and time_dim as
        dimensions, but can have more dimensions.
    ens_dim : str
        Name of the dimension along which the weights should be created.
        Default is "member".
    time_dim : str
        Name of the time dimension, will be filled with equal values for each
        ensemble member. Default is "time".

    Returns
    -------
    DataTree
        DataTree holding the weights for each scenario isomorphic to dt, where
        each dataset has dimensions (time_dim, ens_dim).

    Raises
    ------
    ValueError
        If ``dt`` does not have a depth of 1, if a dataset lacks ``ens_dim`` or
        ``time_dim``, or if a dataset does not contain exactly one data
        variable.

    Examples
    --------
    >>> dt = DataTree()
    >>> dt["ssp119"] = DataTree(xr.Dataset({"tas": xr.DataArray(np.ones((20, 3)), dims=("time", "member"))}))
    >>> dt["ssp585"] = DataTree(xr.Dataset({"tas": xr.DataArray(np.ones((20, 2)), dims=("time", "member"))}))
    >>> weights = equal_scenario_weights_from_datatree(dt)
    >>> weights
    DataTree('None', parent=None)
    ├── DataTree('ssp119')
    │       Dimensions:  (time: 20, member: 3)
    │       Coordinates:
    │         * time     (time) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
    │         * member   (member) int64 0 1 2
    │       Data variables:
    │           weights  (time, member) float64 0.3333 0.3333 0.3333 ... 0.3333 0.3333
    └── DataTree('ssp585')
            Dimensions:  (time: 20, member: 2)
            Coordinates:
              * time     (time) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
              * member   (member) int64 0 1
            Data variables:
                weights  (time, member) float64 0.5 0.5 0.5 0.5 0.5 ... 0.5 0.5 0.5 0.5 0.5
    """
    if dt.depth != 1:
        raise ValueError(f"DataTree must have a depth of 1, not {dt.depth}.")

    def _create_weights(ds: xr.Dataset) -> xr.Dataset:
        # Build the (time_dim, ens_dim) weights dataset for a single scenario.
        dims = set(ds.dims)
        if ens_dim not in dims:
            raise ValueError(f"Member dimension '{ens_dim}' not found in dataset.")
        if time_dim not in dims:
            raise ValueError(f"Time dimension '{time_dim}' not found in dataset.")

        # Count the data variables explicitly: unpacking an empty mapping would
        # only produce an opaque "not enough values to unpack" error.
        n_vars = len(ds.data_vars)
        if n_vars == 0:
            raise ValueError("Dataset must contain one data variable, found none.")
        if n_vars > 1:
            raise ValueError("Dataset must only contain one data variable.")

        # create weights: every (time, member) entry is 1 / number of members
        n_members = ds[ens_dim].size
        data = np.full((ds[time_dim].size, n_members), fill_value=1 / n_members)
        weights = xr.DataArray(data, dims=[time_dim, ens_dim])

        # add back coords if they were there on ds
        for dim in (time_dim, ens_dim):
            coords = ds[dim].coords
            if coords:
                weights = weights.assign_coords(coords)

        return xr.Dataset({"weights": weights})

    return map_over_subtree(_create_weights)(dt)
113 changes: 113 additions & 0 deletions tests/unit/test_weighted.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pytest
import xarray as xr
from datatree import DataTree

import mesmer

Expand Down Expand Up @@ -170,3 +171,115 @@ def test_global_mean_weights_passed(as_dataset):
expected = data.mean(("lat", "lon"))

xr.testing.assert_allclose(result, expected)


def test_equal_sceanrio_weights_from_datatree():
    """Check the weights are 1 / n_members per scenario, with and without an extra dim."""
    n_members_ssp119 = 3
    n_members_ssp585 = 2
    n_gridcells = 3
    n_ts = 30

    # ssp119 only carries a time coordinate and ssp585 only a member coordinate,
    # so the expected weights below (built with full_like) carry exactly the
    # coordinates of the respective input.
    ssp119 = xr.Dataset(
        {
            "tas": xr.DataArray(
                np.ones((n_ts, n_members_ssp119)), dims=("time", "member")
            )
        }
    )
    ssp119 = ssp119.assign_coords(time=np.arange(n_ts))
    ssp585 = xr.Dataset(
        {
            "tas": xr.DataArray(
                np.ones((n_ts, n_members_ssp585)), dims=("time", "member")
            )
        }
    )
    ssp585 = ssp585.assign_coords(member=np.arange(n_members_ssp585))
    dt = DataTree()
    dt["ssp119"] = DataTree(ssp119)
    dt["ssp585"] = DataTree(ssp585)

    result1 = mesmer.weighted.equal_scenario_weights_from_datatree(dt)
    expected = DataTree.from_dict(
        {
            "ssp119": DataTree(
                xr.full_like(ssp119, fill_value=1 / n_members_ssp119).rename(
                    {"tas": "weights"}
                )
            ),
            "ssp585": DataTree(
                xr.full_like(ssp585, fill_value=1 / n_members_ssp585).rename(
                    {"tas": "weights"}
                )
            ),
        }
    )

    # TODO: replace with datatree testing funcs when switching to xarray internal DataTree
    assert result1.equals(expected)

    # Add a gridcell dimension to the inputs; the same ``expected`` is reused
    # because the weights should only span (time, member).
    dt["ssp119"] = DataTree(
        dt.ssp119.ds.expand_dims(gridcell=np.arange(n_gridcells), axis=1)
    )
    dt["ssp585"] = DataTree(
        dt.ssp585.ds.expand_dims(gridcell=np.arange(n_gridcells), axis=1)
    )

    result2 = mesmer.weighted.equal_scenario_weights_from_datatree(
        dt, ens_dim="member", time_dim="time"
    )
    # TODO: replace with datatree testing funcs when switching to xarray internal DataTree
    assert result2.equals(expected)


def test_create_equal_sceanrio_weights_from_datatree_checks():
    """Check that each kind of invalid input tree raises the documented ValueError."""
    ssp119 = xr.Dataset(
        {"tas": xr.DataArray(np.ones((20, 2)), dims=("time", "member"))}
    )
    ssp585 = xr.Dataset(
        {"tas": xr.DataArray(np.ones((20, 3)), dims=("time", "member"))}
    )
    # valid base tree; every case below works on its own modified copy
    dt = DataTree()
    dt["ssp119"] = DataTree(ssp119)
    dt["ssp585"] = DataTree(ssp585)

    # too deep
    dt_too_deep = dt.copy()
    dt_too_deep["ssp585/1"] = DataTree(
        xr.Dataset({"tas": xr.DataArray([4, 5], dims="member")})
    )
    with pytest.raises(ValueError, match="DataTree must have a depth of 1, not 2."):
        mesmer.weighted.equal_scenario_weights_from_datatree(dt_too_deep)

    # missing member dimension (selecting a single member drops the dimension)
    dt_no_member = dt.copy()
    dt_no_member["ssp119"] = DataTree(dt_no_member.ssp119.ds.sel(member=1))
    with pytest.raises(
        ValueError, match="Member dimension 'member' not found in dataset."
    ):
        mesmer.weighted.equal_scenario_weights_from_datatree(dt_no_member)

    # missing time dimension (selecting a single time step drops the dimension)
    dt_no_time = dt.copy()
    dt_no_time["ssp119"] = DataTree(dt_no_time.ssp119.ds.sel(time=1))
    with pytest.raises(ValueError, match="Time dimension 'time' not found in dataset."):
        mesmer.weighted.equal_scenario_weights_from_datatree(dt_no_time)

    # multiple data variables
    dt_multiple_vars = dt.copy()
    dt_multiple_vars["ssp119"] = DataTree(
        xr.Dataset(
            {
                "tas": xr.DataArray(np.ones((20, 2)), dims=("time", "member")),
                "tas2": xr.DataArray(np.ones((20, 2)), dims=("time", "member")),
            }
        )
    )
    with pytest.raises(
        ValueError, match="Dataset must only contain one data variable."
    ):
        mesmer.weighted.equal_scenario_weights_from_datatree(dt_multiple_vars)

0 comments on commit 6552d66

Please sign in to comment.