
Start sampler code for Met Office #59

Closed · wants to merge 1 commit
81 changes: 80 additions & 1 deletion src/open_data_pvnet/main.py
@@ -7,13 +7,15 @@
    merge_hours_to_day,
    process_month_by_days,
    merge_days_to_month,
    get_zarr_groups,
)
from pathlib import Path
import concurrent.futures
from typing import List, Tuple, Optional
from open_data_pvnet.utils.data_uploader import upload_monthly_zarr, upload_to_huggingface
from open_data_pvnet.scripts.archive import handle_archive
from open_data_pvnet.nwp.met_office import CONFIG_PATHS
from open_data_pvnet.utils.data_sampler import prepare_nwp_dataset_for_ocf

logger = logging.getLogger(__name__)

@@ -188,6 +190,20 @@ def configure_parser():
    consolidate_parser = operation_subparsers.add_parser("consolidate", help="Consolidate data")
    _add_common_arguments(consolidate_parser, provider)

    # Add only the new sample operation to the existing operation subparsers
    sample = operation_subparsers.add_parser("sample", help="Sample data for ML training")
    sample.add_argument("--year", type=int, required=True, help="Year of data")
    sample.add_argument("--month", type=int, required=True, help="Month of data")
    sample.add_argument("--day", type=int, required=True, help="Day of data")
    sample.add_argument(
        "--region", choices=["uk", "global"], default="global", help="Region to process"
    )
    sample.add_argument("--chunks", type=str, help="Chunk sizes (e.g., 'time:24,latitude:100')")
    sample.add_argument(
        "--output", type=str, required=True, help="Output path for sampled data"
    )
    sample.add_argument("--remote", action="store_true", help="Use remote data source")

    return parser


@@ -370,6 +386,55 @@ def archive_to_hf(provider: str, year: int, month: int, day: int = None, **kwargs
        raise


def handle_sample(
    provider: str,
    year: int,
    month: int,
    day: int,
    region: str = "global",
    chunks: Optional[str] = None,
    output: Optional[str] = None,
    remote: bool = False,
) -> None:
    """
    Handle the sample operation.

    Args:
        provider: Data provider (e.g., 'metoffice').
        year: Year of data.
        month: Month of data.
        day: Day of data.
        region: Region to process ('uk' or 'global').
        chunks: Optional chunk sizes for dask, e.g. 'time:24,latitude:100'.
        output: Output path for the sampled data.
        remote: Whether to use a remote data source.
    """
    logger.info(f"Loading data for {year}-{month:02d}-{day:02d}")

    # Parse chunks if provided; dask expects integer chunk sizes
    chunk_dict = None
    if chunks:
        chunk_dict = {
            dim: int(size)
            for dim, size in (chunk.split(":") for chunk in chunks.split(","))
        }

    # Load the dataset
    store, ds = load_zarr_data(year, month, day, chunks=chunk_dict, remote=remote)

    try:
        # Get Zarr groups and prepare dataset for OCF
        zarr_groups = get_zarr_groups(store)
        ds_ocf = prepare_nwp_dataset_for_ocf(ds, zarr_groups, store, chunk_dict)

        # Save the prepared dataset
        logger.info(f"Saving prepared dataset to {output}")
        ds_ocf.to_netcdf(output)

    finally:
        # Clean up
        store.close()

    logger.info("Sampling operation completed successfully")


def main():
    """Entry point for the Open Data PVNet CLI tool.

@@ -397,6 +462,9 @@ def main():
# Consolidate specific day
open-data-pvnet metoffice consolidate --year 2023 --month 12 --day 1

# Sample data for ML training
open-data-pvnet metoffice sample --year 2023 --month 12 --day 1 --region uk --chunks "time:24,latitude:100" --output /path/to/output

GFS Data:
Partially implemented

@@ -475,5 +543,16 @@
"archive_type": getattr(args, "archive_type", "zarr.zip"),
}
archive_to_hf(**archive_kwargs)
elif args.operation == "sample":
handle_sample(
provider="metoffice",
year=args.year,
month=args.month,
day=args.day,
region=args.region,
chunks=args.chunks,
output=args.output,
remote=args.remote,
)

return 0
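
For reference, the --chunks string is parsed into a dict of integer chunk sizes before the dataset is opened; a minimal sketch of that round trip (values are illustrative):

chunks = "time:24,latitude:100"
chunk_dict = {
    dim: int(size)
    for dim, size in (chunk.split(":") for chunk in chunks.split(","))
}
assert chunk_dict == {"time": 24, "latitude": 100}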
98 changes: 98 additions & 0 deletions src/open_data_pvnet/utils/data_sampler.py
@@ -0,0 +1,98 @@
import re
import logging
import xarray as xr
from typing import Dict, List, Optional
from ocf_data_sampler.numpy_sample.nwp import NWPSampleKey
from zarr.storage import ZipStore

logger = logging.getLogger(__name__)

# Variable mapping for Met Office data
METOFFICE_VARIABLE_MAP = {
    "cloud_amount_of_high_cloud": "high_type_cloud_area_fraction",
    "cloud_amount_of_low_cloud": "low_type_cloud_area_fraction",
    "cloud_amount_of_medium_cloud": "medium_type_cloud_area_fraction",
    "cloud_amount_of_total_cloud": "cloud_area_fraction",
    "radiation_flux_in_longwave_downward_at_surface": "surface_downwelling_longwave_flux_in_air",
    "radiation_flux_in_shortwave_total_downward_at_surface": "surface_downwelling_shortwave_flux_in_air",
    "radiation_flux_in_uv_downward_at_surface": "surface_downwelling_ultraviolet_flux_in_air",
    "snow_depth_water_equivalent": "lwe_thickness_of_surface_snow_amount",
    "temperature_at_screen_level": "air_temperature",
    "wind_direction_at_10m": "wind_from_direction",
    "wind_speed_at_10m": "wind_speed",
}
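
# Example: a group named "PT0001H00M-temperature_at_screen_level.zarr"
# (hypothetical name) yields the file variable "temperature_at_screen_level",
# which this map renames to the CF standard name "air_temperature".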


def create_dynamic_variable_mapping(
    zarr_groups: List[str], store: ZipStore, chunks: Optional[Dict], consolidated: bool
) -> Dict[str, str]:
    """
    Dynamically maps variable names from Zarr groups to NWPSampleKey-formatted names.

    Args:
        zarr_groups: List of Zarr group names extracted from the dataset.
        store: The Zarr store containing the dataset.
        chunks: Chunking configuration for opening Zarr files.
        consolidated: Whether the Zarr dataset is consolidated.

    Returns:
        Dictionary mapping internal dataset variables to NWPSampleKey-formatted names.
    """
    from open_data_pvnet.utils.data_downloader import open_zarr_group
Contributor comment: We could put this import at the top of the file.


    variable_mapping = {}

    for group in zarr_groups:
        # Group names follow the pattern "PT<hours>H<minutes>M-<variable>.zarr"
        match = re.search(r"PT\d+H\d+M-(.*)\.zarr", group)
        if match:
            file_var_name = match.group(1)
            target_var_name = METOFFICE_VARIABLE_MAP.get(file_var_name, file_var_name)

            try:
                group_ds = open_zarr_group(store, group, chunks, consolidated)
                for var in group_ds.variables:
                    if target_var_name in var:
                        variable_mapping[var] = f"{NWPSampleKey.nwp}.{target_var_name}"
                        break
                else:
                    logger.warning(f"No match found for '{file_var_name}' in {group}")

            except Exception as e:
                logger.error(f"Could not open group {group}: {e}")

    return variable_mapping


def prepare_nwp_dataset_for_ocf(
    ds: xr.Dataset,
    zarr_groups: List[str],
    store: ZipStore,
    chunks: Optional[Dict] = None,
    consolidated: bool = True,
) -> xr.Dataset:
    """
    Prepares the merged NWP dataset for use with ocf-data-sampler.

    Args:
        ds: The merged xarray dataset containing NWP data.
        zarr_groups: List of Zarr group names extracted from the dataset.
        store: The Zarr store containing the dataset.
        chunks: Chunking configuration for opening Zarr files.
        consolidated: Whether the Zarr dataset is consolidated.

    Returns:
        The transformed dataset compatible with ocf-data-sampler.

    Raises:
        ValueError: If required coordinates are missing.
    """
    variable_mapping = create_dynamic_variable_mapping(zarr_groups, store, chunks, consolidated)
    ds = ds.rename(variable_mapping)

    required_coords = ["projection_x_coordinate", "projection_y_coordinate", "time"]
    missing_coords = [coord for coord in required_coords if coord not in ds.coords]

    if missing_coords:
        raise ValueError(f"Missing required coordinates: {', '.join(missing_coords)}")

    return ds
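
To illustrate the group-name parsing and renaming above, a small sketch (the group name is hypothetical; METOFFICE_VARIABLE_MAP is the module-level map defined in this file):

import re
from open_data_pvnet.utils.data_sampler import METOFFICE_VARIABLE_MAP

group = "PT0001H00M-temperature_at_screen_level.zarr"
match = re.search(r"PT\d+H\d+M-(.*)\.zarr", group)
file_var = match.group(1)                  # "temperature_at_screen_level"
target = METOFFICE_VARIABLE_MAP[file_var]  # "air_temperature"
# create_dynamic_variable_mapping would then rename a matching dataset
# variable to f"{NWPSampleKey.nwp}.{target}", and prepare_nwp_dataset_for_ocf
# checks that projection_x_coordinate, projection_y_coordinate, and time are
# present before returning the dataset.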
73 changes: 73 additions & 0 deletions tests/test_main.py
@@ -192,3 +192,76 @@ def test_main_metoffice_load_remote(mock_load_env, mock_handle_load):
chunks="time:24,latitude:100",
remote=True,
)


@patch("open_data_pvnet.main.archive_to_hf")
@patch("open_data_pvnet.main.load_env_and_setup_logger")
def test_main_metoffice_monthly_archive(mock_load_env, mock_archive_to_hf):
"""Test monthly archive command (no day specified)"""
test_args = [
"metoffice",
"archive",
"--year",
"2024",
"--month",
"3",
"--region",
"uk",
"--overwrite",
]
with patch("sys.argv", ["script"] + test_args):
main()
mock_archive_to_hf.assert_called_once_with(
provider="metoffice",
year=2024,
month=3,
day=None,
hour=None,
region="uk",
overwrite=True,
archive_type="zarr.zip",
)


@patch("open_data_pvnet.main.handle_monthly_consolidation")
@patch("open_data_pvnet.main.load_env_and_setup_logger")
def test_main_metoffice_consolidate(mock_load_env, mock_consolidate):
"""Test consolidation operation"""
test_args = [
"metoffice",
"consolidate",
"--year",
"2024",
"--month",
"3",
"--day",
"1",
"--region",
"uk",
]
with patch("sys.argv", ["script"] + test_args):
main()
mock_consolidate.assert_called_once_with(
provider="metoffice",
year=2024,
month=3,
day=1,
region="uk",
overwrite=False,
)


def test_main_invalid_provider():
    """Test invalid provider"""
    test_args = ["invalid_provider", "archive", "--year", "2024", "--month", "3"]
    with patch("sys.argv", ["script"] + test_args):
        with pytest.raises(SystemExit):
            main()


def test_main_invalid_operation():
    """Test invalid operation"""
    test_args = ["metoffice", "invalid_op", "--year", "2024", "--month", "3"]
    with patch("sys.argv", ["script"] + test_args):
        with pytest.raises(SystemExit):
            main()
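

A follow-up test for the new sample operation could mirror the existing style; a sketch (not part of this PR), assuming handle_sample is patched at open_data_pvnet.main.handle_sample:

@patch("open_data_pvnet.main.handle_sample")
@patch("open_data_pvnet.main.load_env_and_setup_logger")
def test_main_metoffice_sample(mock_load_env, mock_handle_sample):
    """Test the sample operation (sketch)"""
    test_args = [
        "metoffice",
        "sample",
        "--year",
        "2024",
        "--month",
        "3",
        "--day",
        "1",
        "--region",
        "uk",
        "--chunks",
        "time:24,latitude:100",
        "--output",
        "/tmp/sample.nc",
    ]
    with patch("sys.argv", ["script"] + test_args):
        main()
    mock_handle_sample.assert_called_once_with(
        provider="metoffice",
        year=2024,
        month=3,
        day=1,
        region="uk",
        chunks="time:24,latitude:100",
        output="/tmp/sample.nc",
        remote=False,
    )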