6 changes: 3 additions & 3 deletions ocf_data_sampler/numpy_sample/__init__.py
@@ -1,7 +1,7 @@
"""Conversion from Xarray to NumpySample"""

from .convert import convert_to_numpy_sample
from .datetime_features import encode_datetimes, get_t0_embedding
from .generation import convert_generation_to_numpy_sample, GenerationSampleKey
from .nwp import convert_nwp_to_numpy_sample, NWPSampleKey
from .satellite import convert_satellite_to_numpy_sample, SatelliteSampleKey
from .collate import stack_np_samples_into_batch
from .common_types import NumpySample, NumpyBatch, TensorBatch
from .sun_position import make_sun_position_numpy_sample
60 changes: 60 additions & 0 deletions ocf_data_sampler/numpy_sample/convert.py
@@ -0,0 +1,60 @@
"""Convert a dictionary of xarray objects to a NumpySample."""

import numpy as np
import xarray as xr

from ocf_data_sampler.numpy_sample.common_types import NumpySample


def convert_to_numpy_sample(
sample: dict[str, xr.DataArray | dict[str, xr.DataArray]],
t0_idx: int,
) -> NumpySample:
"""Convert a dictionary of xarray objects to a NumpySample.

Args:
sample: Dictionary of xarray DataArrays, with the same structure as used inside
the PVNet Dataset classes. Expected keys are any of the following:
- "generation": DataArray of generation data
- "sat": DataArray of satellite data
- "nwp": dict of DataArrays by provider name (e.g. {"ukv": da, "ecmwf": da})
t0_idx: Index of t0 within the generation time dimension

Returns:
NumpySample dictionary with all modalities merged
"""
numpy_sample: NumpySample = {}

if "generation" in sample:
da = sample["generation"]
numpy_sample.update({
"generation": da.values,
"capacity_mwp": da.capacity_mwp.values[0],
"time_utc": da["time_utc"].values.astype(float),
"t0_idx": int(t0_idx),
"longitude": float(da.longitude.values),
"latitude": float(da.latitude.values),
})

if "sat" in sample:
da = sample["sat"]
numpy_sample.update({
"satellite_actual": da.values,
"satellite_time_utc": da.time_utc.values.astype(float),
"satellite_x_geostationary": da.x_geostationary.values,
"satellite_y_geostationary": da.y_geostationary.values,
})

if "nwp" in sample:
numpy_sample["nwp"] = {}
for provider, da in sample["nwp"].items():
target_time_utc = da.init_time_utc.values + da.step.values
numpy_sample["nwp"][provider] = {
"nwp": da.values,
"nwp_channel_names": da.channel.values,
"nwp_init_time_utc": da.init_time_utc.values.astype(float),
"nwp_step": (da.step.values / np.timedelta64(1, "h")).astype(int),
"nwp_target_time_utc": target_time_utc.astype(float),
}

return numpy_sample
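
A minimal usage sketch of convert_to_numpy_sample with a synthetic generation DataArray. The coordinate names follow the docstring above; the values, shapes, and time range are illustrative and not taken from the PR:

```python
import numpy as np
import pandas as pd
import xarray as xr

from ocf_data_sampler.numpy_sample import convert_to_numpy_sample

# Synthetic generation DataArray carrying the coordinates the converter reads.
times = pd.date_range("2024-06-01 09:00", periods=5, freq="30min")
da_gen = xr.DataArray(
    np.random.rand(5).astype(np.float32),
    dims=["time_utc"],
    coords={
        "time_utc": times,
        "capacity_mwp": ("time_utc", np.full(5, 10.0)),
        "longitude": -1.5,
        "latitude": 52.0,
    },
)

sample = convert_to_numpy_sample({"generation": da_gen}, t0_idx=2)
print(sample["generation"].shape)  # (5,)
print(sample["capacity_mwp"])      # 10.0
print(sample["t0_idx"])            # 2
```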
36 changes: 0 additions & 36 deletions ocf_data_sampler/numpy_sample/generation.py

This file was deleted.

37 changes: 0 additions & 37 deletions ocf_data_sampler/numpy_sample/nwp.py

This file was deleted.

35 changes: 0 additions & 35 deletions ocf_data_sampler/numpy_sample/satellite.py

This file was deleted.

84 changes: 26 additions & 58 deletions ocf_data_sampler/torch_datasets/pvnet_dataset.py
@@ -14,17 +14,13 @@
from ocf_data_sampler.config import Configuration, load_yaml_configuration
from ocf_data_sampler.load.load_dataset import get_dataset_dict
from ocf_data_sampler.numpy_sample import (
convert_generation_to_numpy_sample,
convert_nwp_to_numpy_sample,
convert_satellite_to_numpy_sample,
convert_to_numpy_sample,
encode_datetimes,
get_t0_embedding,
make_sun_position_numpy_sample,
)
from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
from ocf_data_sampler.numpy_sample.common_types import NumpyBatch, NumpySample
from ocf_data_sampler.numpy_sample.generation import GenerationSampleKey
from ocf_data_sampler.numpy_sample.nwp import NWPSampleKey
from ocf_data_sampler.select import (
Location,
fill_time_periods,
@@ -37,7 +33,6 @@
diff_nwp_data,
fill_nans_in_arrays,
find_valid_time_periods,
merge_dicts,
slice_datasets_by_space,
slice_datasets_by_time,
)
@@ -196,92 +191,66 @@ def process_and_combine_datasets(
t0: init-time for sample
location: location of the sample
"""
numpy_modalities = [{"t0": t0.timestamp()}]

# Normalise NWP
if "nwp" in dataset_dict:
nwp_numpy_modalities = {}

for nwp_key, da_nwp in dataset_dict["nwp"].items():
# Standardise and convert to NumpyBatch
channel_means = self.means_dict["nwp"][nwp_key]
channel_stds = self.stds_dict["nwp"][nwp_key]
dataset_dict["nwp"][nwp_key] = (da_nwp - channel_means) / channel_stds

da_nwp = (da_nwp - channel_means) / channel_stds

nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_sample(da_nwp)

# Combine the NWPs into NumpyBatch
numpy_modalities.append({NWPSampleKey.nwp: nwp_numpy_modalities})

# Normalise satellite
if "sat" in dataset_dict:
da_sat = dataset_dict["sat"]

# Standardise and convert to NumpyBatch
channel_means = self.means_dict["sat"]
channel_stds = self.stds_dict["sat"]
dataset_dict["sat"] = (dataset_dict["sat"] - channel_means) / channel_stds

da_sat = (da_sat - channel_means) / channel_stds

numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat))

# Normalise generation by capacity
if "generation" in dataset_dict:
da_generation = dataset_dict["generation"]
da_generation = da_generation / da_generation.capacity_mwp.values
dataset_dict["generation"] = da_generation / da_generation.capacity_mwp.values

# Convert to NumpyBatch
numpy_modalities.append(
convert_generation_to_numpy_sample(
da_generation,
t0_idx=self.t0_idx,
),
)
# Convert all xarray modalities to a single NumpySample
sample = convert_to_numpy_sample(dataset_dict, self.t0_idx)

numpy_modalities.append(
{
GenerationSampleKey.location_id: location.id,
GenerationSampleKey.longitude: da_generation.longitude.values,
GenerationSampleKey.latitude: da_generation.latitude.values,
},
)
# Add location metadata not present on the DataArray
if "generation" in dataset_dict:
sample["location_id"] = location.id

# Add datetime features
# Add datetime encodings over the full generation time range
generation_config = self.config.input_data.generation
datetimes = pd.date_range(
t0 + minutes(generation_config.interval_start_minutes),
t0 + minutes(generation_config.interval_end_minutes),
freq=minutes(generation_config.time_resolution_minutes),
)
numpy_modalities.append(encode_datetimes(datetimes=datetimes))
sample.update(encode_datetimes(datetimes=datetimes))

# Add t0 embedding if configured
if self.config.input_data.t0_embedding is not None:
numpy_modalities.append(
get_t0_embedding(t0, self.config.input_data.t0_embedding.embeddings),
)
sample.update(get_t0_embedding(t0, self.config.input_data.t0_embedding.embeddings))

# Only add solar position if explicitly configured
# Add solar position if configured
if self.config.input_data.solar_position is not None:
solar_config = self.config.input_data.solar_position

# Create datetime range for solar position calculation
datetimes = pd.date_range(
t0 + minutes(solar_config.interval_start_minutes),
t0 + minutes(solar_config.interval_end_minutes),
freq=minutes(solar_config.time_resolution_minutes),
)

numpy_modalities.append(
sample.update(
make_sun_position_numpy_sample(
datetimes,
da_generation.longitude.values,
da_generation.latitude.values,
dataset_dict["generation"].longitude.values,
dataset_dict["generation"].latitude.values,
),
)

# Combine all the modalities and fill NaNs
combined_sample = merge_dicts(numpy_modalities)
combined_sample = fill_nans_in_arrays(combined_sample, config=self.config)
sample["t0"] = t0.timestamp()

return combined_sample
# Fill NaNs
sample = fill_nans_in_arrays(sample, config=self.config)

return sample

@staticmethod
def find_valid_t0_times(datasets_dict: dict, config: Configuration) -> pd.DatetimeIndex:
@@ -461,7 +430,6 @@ def _get_sample(self, t0: pd.Timestamp) -> NumpyBatch:
t0: init-time for sample
"""
# Slice by time then load to avoid loading the data multiple times from disk

sample_dict = slice_datasets_by_time(self.datasets_dict, t0, self.config)
sample_dict = load_data_dict(sample_dict)
sample_dict = diff_nwp_data(sample_dict, self.config)
@@ -493,7 +461,7 @@ def get_sample(self, t0: pd.Timestamp) -> NumpyBatch:
Args:
t0: init-time for sample
"""
# Check data is availablle for init-time t0
# Check data is available for init-time t0
if t0 not in self.valid_t0_times:
raise ValueError(f"Input init time '{t0!s}' not in valid times")
return self._get_sample(t0)
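
For orientation, a condensed sketch of the flow process_and_combine_datasets follows after this refactor. It is a standalone function, not the real method: the datetime and solar-position steps are omitted, the argument names are invented for illustration, and the fill_nans_in_arrays keyword call is inferred from the diff above.

```python
from ocf_data_sampler.numpy_sample import convert_to_numpy_sample
from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import fill_nans_in_arrays


def process_and_combine_sketch(dataset_dict, means_dict, stds_dict, t0, t0_idx, location_id, config=None):
    """Abbreviated illustration of the refactored flow (not the real method)."""
    # 1. Normalise each modality on the xarray objects in dataset_dict.
    for nwp_key, da_nwp in dataset_dict.get("nwp", {}).items():
        dataset_dict["nwp"][nwp_key] = (
            (da_nwp - means_dict["nwp"][nwp_key]) / stds_dict["nwp"][nwp_key]
        )
    if "sat" in dataset_dict:
        dataset_dict["sat"] = (dataset_dict["sat"] - means_dict["sat"]) / stds_dict["sat"]
    if "generation" in dataset_dict:
        da_gen = dataset_dict["generation"]
        dataset_dict["generation"] = da_gen / da_gen.capacity_mwp.values

    # 2. A single call now converts every modality into one NumpySample dict.
    sample = convert_to_numpy_sample(dataset_dict, t0_idx)

    # 3. Attach metadata not carried on the DataArrays, then fill NaNs.
    if "generation" in dataset_dict:
        sample["location_id"] = location_id
    sample["t0"] = t0.timestamp()
    return fill_nans_in_arrays(sample, config=config)
```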
9 changes: 3 additions & 6 deletions ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py
@@ -3,9 +3,6 @@
import numpy as np

from ocf_data_sampler.config.model import Configuration
from ocf_data_sampler.numpy_sample.generation import GenerationSampleKey
from ocf_data_sampler.numpy_sample.nwp import NWPSampleKey
from ocf_data_sampler.numpy_sample.satellite import SatelliteSampleKey


def merge_dicts(list_of_dicts: list[dict]) -> dict:
@@ -31,11 +28,11 @@ def fill_nans_in_arrays(
if np.isnan(v).any():
fill_value = 0.0
if config is not None:
if k == GenerationSampleKey.generation:
if k == "generation":
fill_value = config.input_data.generation.dropout_value
elif k == SatelliteSampleKey.satellite_actual:
elif k == "satellite_actual":
fill_value = config.input_data.satellite.dropout_value
elif k == NWPSampleKey.nwp and nwp_provider in config.input_data.nwp:
elif k == "nwp" and nwp_provider in config.input_data.nwp:
fill_value = config.input_data.nwp[nwp_provider].dropout_value

sample[k] = np.nan_to_num(v, copy=False, nan=fill_value)
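
A small sketch of how the string-keyed NaN filling behaves after this change. With no Configuration passed, the default fill value of 0.0 applies; the arrays are illustrative and the keyword-style call is inferred from the pvnet_dataset.py diff above.

```python
import numpy as np

from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import fill_nans_in_arrays

sample = {
    "generation": np.array([0.1, np.nan, 0.3]),
    "satellite_actual": np.array([[np.nan, 1.0], [2.0, 3.0]]),
}

# Without a config, every NaN is replaced with the default fill value of 0.0.
filled = fill_nans_in_arrays(sample, config=None)
print(filled["generation"])        # [0.1 0.  0.3]
print(filled["satellite_actual"])  # [[0. 1.] [2. 3.]]
```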
6 changes: 3 additions & 3 deletions tests/load/test_load_nwp.py
@@ -40,7 +40,7 @@ def test_load_cloudcasting(nwp_cloudcasting_zarr_path):
def test_load_ukv_new_coords(tmp_path):
"""Test UKV loader handles data that already has new coordinate names."""
zarr_path = tmp_path / "new_ukv_coords.zarr"

# Create a mock dataset that already uses the new naming convention
new_coords_array = DataArray(
np.random.rand(1, 1, 1, 1, 1).astype(np.float32),
@@ -54,10 +54,10 @@ def test_load_ukv_new_coords(tmp_path):
},
)
new_coords_array.to_zarr(zarr_path)

# This should succeed without KeyError
da = open_nwp(zarr_path=zarr_path, provider="ukv")

assert isinstance(da, DataArray)
assert "x_osgb" in da.coords
assert "y_osgb" in da.coords