Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
dfulu committed Mar 7, 2025
1 parent 779334e commit 96292a7
Show file tree
Hide file tree
Showing 10 changed files with 18 additions and 22 deletions.
2 changes: 1 addition & 1 deletion ocf_data_sampler/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
"""

from collections.abc import Iterator
from typing_extensions import override

from pydantic import BaseModel, Field, RootModel, field_validator, model_validator
from typing_extensions import override

from ocf_data_sampler.constants import NWP_PROVIDERS

Expand Down
3 changes: 1 addition & 2 deletions ocf_data_sampler/constants.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
"""Constants for the package."""

from typing_extensions import override

import numpy as np
import xarray as xr
from typing_extensions import override

NWP_PROVIDERS = [
"ukv",
Expand Down
2 changes: 1 addition & 1 deletion ocf_data_sampler/select/fill_time_periods.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ def fill_time_periods(time_periods: pd.DataFrame, freq: pd.Timedelta) -> pd.Date
end_dts = pd.to_datetime(time_periods["end_dt"].values)
date_ranges = [
pd.date_range(start_dt, end_dt, freq=freq)
for start_dt, end_dt in zip(start_dts, end_dts)
for start_dt, end_dt in zip(start_dts, end_dts, strict=False)
]
return pd.DatetimeIndex(np.concatenate(date_ranges))
7 changes: 2 additions & 5 deletions ocf_data_sampler/select/find_contiguous_time_periods.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from ocf_data_sampler.load.utils import check_time_unique_increasing


ZERO_TDELTA = pd.Timedelta(0)


Expand Down Expand Up @@ -80,7 +79,6 @@ def trim_contiguous_time_periods(
Returns:
The contiguous_time_periods pd.DataFrame with the `start_dt` and `end_dt` columns updated.
"""

# Make a copy so the data is not edited in place.
trimmed_time_periods = contiguous_time_periods.copy()
trimmed_time_periods["start_dt"] -= interval_start
Expand Down Expand Up @@ -109,9 +107,8 @@ def find_contiguous_t0_periods(
pd.DataFrame where each row represents a single time period. The pd.DataFrame
has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
"""

check_time_unique_increasing(datetimes)

total_duration = interval_end - interval_start

contiguous_time_periods = find_contiguous_time_periods(
Expand Down Expand Up @@ -163,7 +160,7 @@ def find_contiguous_t0_periods_nwp(
# Sanity checks.
if len(init_times) == 0:
raise ValueError("No init-times to use")

check_time_unique_increasing(init_times)

if max_staleness < pd.Timedelta(0):
Expand Down
4 changes: 2 additions & 2 deletions ocf_data_sampler/select/geospatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
"""

import numpy as np
import xarray as xr
import pyproj
import pyresample
import xarray as xr

# Coordinate Reference System (CRS) identifiers
# OSGB36: UK Ordnance Survey National Grid (easting/northing in meters)
Expand Down Expand Up @@ -118,7 +118,7 @@ def coordinates_to_geostationary_area_coords(
Geostationary coords: x, y
"""
if crs_from not in [OSGB36, WGS84]:
raise ValueError(f"Unrecognized coordinate system: {crs_from}")
raise ValueError(f"Unrecognized coordinate system: {crs_from}")

area_definition_yaml = xr_data.attrs["area"]

Expand Down
6 changes: 3 additions & 3 deletions ocf_data_sampler/select/location.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
class Location(BaseModel):
"""Represent a spatial location."""

coordinate_system: str = Field(...,
description="Coordinate system for the location must be lon_lat, osgb, or geostationary"
coordinate_system: str = Field(...,
description="Coordinate system for the location must be lon_lat, osgb, or geostationary",
)

x: float = Field(..., description="x coordinate - i.e. east-west position")
y: float = Field(..., description="y coordinate - i.e. north-south position")
id: int | None = Field(None, description="ID of the location - e.g. GSP ID")
Expand Down
10 changes: 5 additions & 5 deletions ocf_data_sampler/select/select_time_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def select_time_slice(
time_resolution: pd.Timedelta,
) -> xr.DataArray:
"""Select a time slice from a DataArray.
Args:
da: The DataArray to slice from
t0: The init-time
Expand Down Expand Up @@ -41,7 +41,7 @@ def select_time_slice_nwp(
accum_channels: list[str] | None = None,
) -> xr.DataArray:
"""Select a time slice from an NWP DataArray.
Args:
da: The DataArray to slice from
t0: The init-time
Expand All @@ -54,16 +54,16 @@ def select_time_slice_nwp(
"""
if accum_channels is None:
accum_channels = []

if dropout_timedeltas is not None:
if not all(t < pd.Timedelta(0) for t in dropout_timedeltas):
raise ValueError("dropout timedeltas must be negative")
if len(dropout_timedeltas) < 1:
raise ValueError("dropout timedeltas must have at least one element")

if not (0 <= dropout_frac <= 1):
raise ValueError("dropout_frac must be between 0 and 1")

consider_dropout = (dropout_timedeltas is not None) and dropout_frac > 0

# The accumatated and non-accumulated channels
Expand Down
2 changes: 1 addition & 1 deletion ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""Torch dataset for UK PVNet."""

from importlib.resources import files
from typing_extensions import override

import numpy as np
import pandas as pd
import xarray as xr
from torch.utils.data import Dataset
from typing_extensions import override

from ocf_data_sampler.config import Configuration, load_yaml_configuration
from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD
Expand Down
2 changes: 1 addition & 1 deletion ocf_data_sampler/torch_datasets/datasets/site.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""Torch dataset for sites."""

import logging
from typing_extensions import override

import numpy as np
import pandas as pd
import xarray as xr
from torch.utils.data import Dataset
from typing_extensions import override

from ocf_data_sampler.config import Configuration, load_yaml_configuration
from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD
Expand Down
2 changes: 1 addition & 1 deletion tests/test_sample/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
"""

from pathlib import Path
from typing_extensions import override

import numpy as np
import pytest
import torch
from typing_extensions import override

from ocf_data_sampler.sample.base import SampleBase, batch_to_tensor, copy_batch_to_device

Expand Down

0 comments on commit 96292a7

Please sign in to comment.