Commit

Updates / changes
felix-e-h-p committed Mar 8, 2025
1 parent 4192643 · commit 93e8106
Showing 3 changed files with 99 additions and 12 deletions.
13 changes: 7 additions & 6 deletions ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py
@@ -110,23 +110,24 @@ def process_and_combine_datasets(
     )
 
     if has_solar_config:
+        solar_config = config.input_data.solar_position
+
         # Create datetime range for solar position calculation
         datetimes = pd.date_range(
-            t0 + minutes(gsp_config.interval_start_minutes),
-            t0 + minutes(gsp_config.interval_end_minutes),
-            freq=minutes(gsp_config.time_resolution_minutes),
+            t0 + minutes(solar_config.interval_start_minutes),
+            t0 + minutes(solar_config.interval_end_minutes),
+            freq=minutes(solar_config.time_resolution_minutes),
         )
 
         # We already have lon, lat if target_key is "gsp", otherwise calculate them
         if target_key != "gsp":
             lon, lat = osgb_to_lon_lat(location.x, location.y)
 
         # Calculate solar positions and add to modalities
         solar_positions = make_sun_position_numpy_sample(datetimes, lon, lat)
-        prefixed_solar_positions = {
+        solar_positions = {
             f"solar_position_{key}": value for key, value in solar_positions.items()
         }
-        numpy_modalities.append(prefixed_solar_positions)
+        numpy_modalities.append(solar_positions)
 
     # Combine all the modalities and fill NaNs
     combined_sample = merge_dicts(numpy_modalities)
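
The net effect of this hunk: the solar-position datetime range is now driven by the solar_position config block rather than the GSP window, so the two intervals can be configured independently. A minimal sketch of the resulting behaviour, assuming the minutes helper simply wraps pd.Timedelta (the interval values below are illustrative, not taken from any repo config):

import pandas as pd

def minutes(m: int) -> pd.Timedelta:
    # Stand-in for ocf_data_sampler's minutes helper (assumed behaviour)
    return pd.Timedelta(minutes=m)

t0 = pd.Timestamp("2024-01-01 12:00")

# Driven purely by the solar_position config, independent of the GSP window
solar_datetimes = pd.date_range(
    t0 + minutes(-60),   # solar_config.interval_start_minutes
    t0 + minutes(120),   # solar_config.interval_end_minutes
    freq=minutes(30),    # solar_config.time_resolution_minutes
)
print(len(solar_datetimes))  # 7 timestamps: 11:00 through 14:00 inclusive
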
28 changes: 23 additions & 5 deletions ocf_data_sampler/torch_datasets/datasets/site.py
@@ -112,7 +112,7 @@ def _get_sample(self, t0: pd.Timestamp, location: Location) -> dict:
         sample_dict = slice_datasets_by_space(self.datasets_dict, location, self.config)
         sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
 
-        sample = self.process_and_combine_site_sample_dict(sample_dict)
+        sample = self.process_and_combine_site_sample_dict(sample_dict, t0)
         sample = sample.compute()
         return sample
@@ -218,12 +218,14 @@ def get_locations(self, site_xr: xr.Dataset) -> list[Location]:
     def process_and_combine_site_sample_dict(
         self,
         dataset_dict: dict,
+        t0: pd.Timestamp,
     ) -> xr.Dataset:
         """Normalize and combine data into a single xr Dataset.
 
         Args:
             dataset_dict: dict containing sliced xr DataArrays
             config: Configuration for the model
+            t0: The initial timestamp of the sample
 
         Returns:
             xr.Dataset: A merged Dataset with nans filled in.
@@ -268,19 +270,35 @@ def process_and_combine_site_sample_dict(
             )
 
         if has_solar_config:
-            # add sun features
+            solar_config = self.config.input_data.solar_position
+
+            # Get site timestamps - determine start and length
+            site_timestamps = combined_sample_dataset.site__time_utc.values
+            site_time_length = len(site_timestamps)
+
+            # Create datetime range using solar config params
+            solar_datetimes = pd.date_range(
+                t0 + minutes(solar_config.interval_start_minutes),
+                t0 + minutes(solar_config.interval_end_minutes),
+                freq=minutes(solar_config.time_resolution_minutes),
+            )
+
+            # Ensure matching of site time dimension length
+            solar_datetimes = solar_datetimes[:site_time_length]
+
+            # Calculate sun position features
             sun_position_features = make_sun_position_numpy_sample(
-                datetimes=datetimes,
+                datetimes=solar_datetimes,
                 lon=combined_sample_dataset.site__longitude.values,
                 lat=combined_sample_dataset.site__latitude.values,
             )
 
+            # Assign to existing site time dimension
             combined_sample_dataset = combined_sample_dataset.assign_coords(
                 {f"solar_position_{key}": ("site__time_utc", v)
                  for key, v in sun_position_features.items()},
             )
 
         # TODO include t0_index in xr dataset?
 
         # Fill any nan values
         return combined_sample_dataset.fillna(0.0)
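
The site branch above follows the same pattern as pvnet_uk.py, with one extra alignment step: the generated solar datetimes are truncated to the length of the existing site__time_utc axis before the sun features are attached as coordinates. A standalone sketch of that pattern, with made-up values standing in for make_sun_position_numpy_sample output:

import numpy as np
import pandas as pd
import xarray as xr

# Hypothetical site time axis: 4 steps at 30-minute resolution
times = pd.date_range("2024-01-01 00:00", periods=4, freq="30min")
ds = xr.Dataset(coords={"site__time_utc": times})
site_time_length = len(ds.site__time_utc.values)

# If the configured interval yields more steps than the site axis
# (e.g. an inclusive end bound), truncate to match
solar_datetimes = pd.date_range("2024-01-01 00:00", periods=5, freq="30min")
solar_datetimes = solar_datetimes[:site_time_length]

# Illustrative values in place of make_sun_position_numpy_sample output
sun_position_features = {
    "azimuth": np.array([170.0, 178.5, 186.9, 195.2]),
    "elevation": np.array([10.2, 12.7, 13.1, 12.0]),
}

# Same assign_coords pattern as the diff: each feature becomes a
# coordinate on the existing site time dimension
ds = ds.assign_coords(
    {f"solar_position_{key}": ("site__time_utc", v)
     for key, v in sun_position_features.items()},
)
assert "solar_position_azimuth" in ds.coords
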
70 changes: 69 additions & 1 deletion tests/torch_datasets/test_site.py
@@ -10,6 +10,7 @@
     SitesDataset,
     coarsen_data,
     convert_from_dataset_to_dict_datasets,
+    convert_netcdf_to_numpy_sample,
 )


@@ -161,8 +162,10 @@ def test_process_and_combine_site_sample_dict(sites_dataset: xr.Dataset) -> None
         ),
     }
 
+    t0 = pd.Timestamp("2024-01-01 00:00")
+
     # Call function
-    result = sites_dataset.process_and_combine_site_sample_dict(site_dict)
+    result = sites_dataset.process_and_combine_site_sample_dict(site_dict, t0)
 
     # Assert to validate output structure
     assert isinstance(result, xr.Dataset), "Result should be an xarray.Dataset"
@@ -235,3 +238,68 @@ def test_solar_position_decoupling_site(tmp_path, site_config_filename):
     # Sample with solar config should have solar position data
     for key in solar_keys:
         assert key in sample_with_solar.coords, f"Solar key {key} should be in sample"
+
+
+def test_convert_from_dataset_to_dict_solar_handling(tmp_path, site_config_filename):
+    """Test that function handles solar position coordinates correctly."""
+
+    config = load_yaml_configuration(site_config_filename)
+    config.input_data.solar_position = SolarPosition(
+        time_resolution_minutes=30,
+        interval_start_minutes=0,
+        interval_end_minutes=180,
+    )
+
+    config_with_solar_path = tmp_path / "site_config_with_solar_for_dict.yaml"
+    save_yaml_configuration(config, config_with_solar_path)
+
+    # Create dataset and obtain sample with solar
+    dataset_with_solar = SitesDataset(config_with_solar_path)
+    sample_with_solar = dataset_with_solar[0]
+
+    # Verify solar position data exists in original sample
+    solar_keys = ["solar_position_azimuth", "solar_position_elevation"]
+    for key in solar_keys:
+        assert key in sample_with_solar.coords, f"Solar key {key} not found in original sample"
+
+    # Conversion and subsequent verification
+    converted_dict = convert_from_dataset_to_dict_datasets(sample_with_solar)
+    assert isinstance(converted_dict, dict)
+    assert "site" in converted_dict
+
+
+def test_convert_netcdf_to_numpy_solar_handling(tmp_path, site_config_filename):
+    """Test that convert_netcdf_to_numpy_sample handles solar position data correctly."""
+
+    config = load_yaml_configuration(site_config_filename)
+    config.input_data.solar_position = SolarPosition(
+        time_resolution_minutes=30,
+        interval_start_minutes=0,
+        interval_end_minutes=180,
+    )
+
+    config_with_solar_path = tmp_path / "site_config_with_solar_for_numpy.yaml"
+    save_yaml_configuration(config, config_with_solar_path)
+
+    # Create dataset and obtain sample with solar
+    dataset_with_solar = SitesDataset(config_with_solar_path)
+    sample_with_solar = dataset_with_solar[0]
+
+    # Save to netCDF and load back
+    netcdf_path = tmp_path / "sample_with_solar.nc"
+    sample_with_solar.to_netcdf(netcdf_path)
+    loaded_sample = xr.open_dataset(netcdf_path)
+
+    # Verify solar position data exists in sample
+    solar_keys = ["solar_position_azimuth", "solar_position_elevation"]
+    for key in solar_keys:
+        assert key in loaded_sample.coords, f"Solar key {key} not found in loaded netCDF"
+
+    # Conversion and subsequent assertion
+    numpy_sample = convert_netcdf_to_numpy_sample(loaded_sample)
+    assert isinstance(numpy_sample, dict)
+
+    # Explicitly verify what is in sample
+    assert "nwp" in numpy_sample
+    assert "satellite_actual" in numpy_sample or "sat" in numpy_sample
+    assert "site" in numpy_sample

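The second new test relies on solar-position coordinates surviving a netCDF round trip. That property can be checked in isolation; the sketch below uses illustrative names and values rather than anything from the repo:

import numpy as np
import pandas as pd
import xarray as xr

# Hypothetical sample: one data variable on the site time axis, plus a
# solar coordinate attached the same way as in the diff above
times = pd.date_range("2024-01-01 00:00", periods=4, freq="30min")
ds = xr.Dataset(
    data_vars={"site": ("site__time_utc", np.ones(4))},
    coords={"site__time_utc": times},
)
ds = ds.assign_coords(
    solar_position_elevation=("site__time_utc", np.linspace(0.0, 30.0, 4)),
)

# Write and re-open; the solar variable should come back as a coordinate,
# which is what the loaded-sample assertions in the test check
ds.to_netcdf("roundtrip.nc")
loaded = xr.open_dataset("roundtrip.nc")
assert "solar_position_elevation" in loaded.coords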