diff --git a/ocf_data_sampler/numpy_sample/__init__.py b/ocf_data_sampler/numpy_sample/__init__.py index bd871496..f6063d8a 100644 --- a/ocf_data_sampler/numpy_sample/__init__.py +++ b/ocf_data_sampler/numpy_sample/__init__.py @@ -1,7 +1,7 @@ """Conversion from Xarray to NumpySample""" +from .convert import convert_to_numpy_sample from .datetime_features import encode_datetimes, get_t0_embedding -from .generation import convert_generation_to_numpy_sample, GenerationSampleKey -from .nwp import convert_nwp_to_numpy_sample, NWPSampleKey -from .satellite import convert_satellite_to_numpy_sample, SatelliteSampleKey +from .collate import stack_np_samples_into_batch +from .common_types import NumpySample, NumpyBatch, TensorBatch from .sun_position import make_sun_position_numpy_sample \ No newline at end of file diff --git a/ocf_data_sampler/numpy_sample/convert.py b/ocf_data_sampler/numpy_sample/convert.py new file mode 100644 index 00000000..c1a4825d --- /dev/null +++ b/ocf_data_sampler/numpy_sample/convert.py @@ -0,0 +1,60 @@ +"""Convert a dictionary of xarray objects to a NumpySample.""" + +import numpy as np +import xarray as xr + +from ocf_data_sampler.numpy_sample.common_types import NumpySample + + +def convert_to_numpy_sample( + sample: dict[str, xr.DataArray | dict[str, xr.DataArray]], + t0_idx: int, +) -> NumpySample: + """Convert a dictionary of xarray objects to a NumpySample. + + Args: + sample: Dictionary of xarray DataArrays, with same structure as used inside + PVNet Dataset classes. Expected keys are any of following: + - "generation": DataArray of generation data + - "sat": DataArray of satellite data + - "nwp": dict of DataArrays by provider name (e.g. 
{"ukv": da, "ecmwf": da}) + t0_idx: Index of t0 within generation + + Returns: + NumpySample dictionary with all modalities merged + """ + numpy_sample: NumpySample = {} + + if "generation" in sample: + da = sample["generation"] + numpy_sample.update({ + "generation": da.values, + "capacity_mwp": da.capacity_mwp.values[0], + "time_utc": da["time_utc"].values.astype(float), + "t0_idx": int(t0_idx), + "longitude": float(da.longitude.values), + "latitude": float(da.latitude.values), + }) + + if "sat" in sample: + da = sample["sat"] + numpy_sample.update({ + "satellite_actual": da.values, + "satellite_time_utc": da.time_utc.values.astype(float), + "satellite_x_geostationary": da.x_geostationary.values, + "satellite_y_geostationary": da.y_geostationary.values, + }) + + if "nwp" in sample: + numpy_sample["nwp"] = {} + for provider, da in sample["nwp"].items(): + target_time_utc = da.init_time_utc.values + da.step.values + numpy_sample["nwp"][provider] = { + "nwp": da.values, + "nwp_channel_names": da.channel.values, + "nwp_init_time_utc": da.init_time_utc.values.astype(float), + "nwp_step": (da.step.values / np.timedelta64(1, "h")).astype(int), + "nwp_target_time_utc": target_time_utc.astype(float), + } + + return numpy_sample diff --git a/ocf_data_sampler/numpy_sample/generation.py b/ocf_data_sampler/numpy_sample/generation.py deleted file mode 100644 index f49f4104..00000000 --- a/ocf_data_sampler/numpy_sample/generation.py +++ /dev/null @@ -1,36 +0,0 @@ -"""Convert Generation data to Numpy Sample.""" - -import xarray as xr - -from ocf_data_sampler.numpy_sample.common_types import NumpySample - - -class GenerationSampleKey: - """Keys for the Generation sample dictionary.""" - - generation = "generation" - capacity_mwp = "capacity_mwp" - time_utc = "time_utc" - t0_idx = "t0_idx" - location_id = "location_id" - longitude = "longitude" - latitude = "latitude" - - -def convert_generation_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> NumpySample: - 
"""Convert from Xarray to NumpySample. - - Args: - da: Xarray DataArray containing generation data - t0_idx: Index of the t0 timestamp in the time dimension of the generation data - """ - sample = { - GenerationSampleKey.generation: da.values, - GenerationSampleKey.capacity_mwp: da.capacity_mwp.values[0], - GenerationSampleKey.time_utc: da["time_utc"].values.astype(float), - } - - if t0_idx is not None: - sample[GenerationSampleKey.t0_idx] = t0_idx - - return sample diff --git a/ocf_data_sampler/numpy_sample/nwp.py b/ocf_data_sampler/numpy_sample/nwp.py deleted file mode 100644 index e3ee86c2..00000000 --- a/ocf_data_sampler/numpy_sample/nwp.py +++ /dev/null @@ -1,37 +0,0 @@ -"""Convert NWP to NumpySample.""" - -import xarray as xr - -from ocf_data_sampler.numpy_sample.common_types import NumpySample - - -class NWPSampleKey: - """Keys for NWP NumpySample.""" - - nwp = "nwp" - channel_names = "nwp_channel_names" - init_time_utc = "nwp_init_time_utc" - step = "nwp_step" - target_time_utc = "nwp_target_time_utc" - t0_idx = "nwp_t0_idx" - - -def convert_nwp_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> NumpySample: - """Convert from Xarray to NWP NumpySample. 
- - Args: - da: Xarray DataArray containing NWP data - t0_idx: Index of the t0 timestamp in the time dimension of the NWP - """ - sample = { - NWPSampleKey.nwp: da.values, - NWPSampleKey.channel_names: da.channel.values, - NWPSampleKey.init_time_utc: da.init_time_utc.values.astype(float), - NWPSampleKey.step: (da.step.values / 3600).astype(int), - NWPSampleKey.target_time_utc: (da.init_time_utc.values + da.step.values).astype(float), - } - - if t0_idx is not None: - sample[NWPSampleKey.t0_idx] = t0_idx - - return sample diff --git a/ocf_data_sampler/numpy_sample/satellite.py b/ocf_data_sampler/numpy_sample/satellite.py deleted file mode 100644 index d7045151..00000000 --- a/ocf_data_sampler/numpy_sample/satellite.py +++ /dev/null @@ -1,35 +0,0 @@ -"""Convert Satellite to NumpySample.""" - -import xarray as xr - -from ocf_data_sampler.numpy_sample.common_types import NumpySample - - -class SatelliteSampleKey: - """Keys for the SatelliteSample dictionary.""" - - satellite_actual = "satellite_actual" - time_utc = "satellite_time_utc" - x_geostationary = "satellite_x_geostationary" - y_geostationary = "satellite_y_geostationary" - t0_idx = "satellite_t0_idx" - - -def convert_satellite_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> NumpySample: - """Convert from Xarray to NumpySample. 
- - Args: - da: xarray DataArray containing satellite data - t0_idx: Index of the t0 timestamp in the time dimension of the satellite data - """ - sample = { - SatelliteSampleKey.satellite_actual: da.values, - SatelliteSampleKey.time_utc: da.time_utc.values.astype(float), - SatelliteSampleKey.x_geostationary: da.x_geostationary.values, - SatelliteSampleKey.y_geostationary: da.y_geostationary.values, - } - - if t0_idx is not None: - sample[SatelliteSampleKey.t0_idx] = t0_idx - - return sample diff --git a/ocf_data_sampler/torch_datasets/pvnet_dataset.py b/ocf_data_sampler/torch_datasets/pvnet_dataset.py index 0cbce80f..c7adff72 100644 --- a/ocf_data_sampler/torch_datasets/pvnet_dataset.py +++ b/ocf_data_sampler/torch_datasets/pvnet_dataset.py @@ -14,17 +14,13 @@ from ocf_data_sampler.config import Configuration, load_yaml_configuration from ocf_data_sampler.load.load_dataset import get_dataset_dict from ocf_data_sampler.numpy_sample import ( - convert_generation_to_numpy_sample, - convert_nwp_to_numpy_sample, - convert_satellite_to_numpy_sample, + convert_to_numpy_sample, encode_datetimes, get_t0_embedding, make_sun_position_numpy_sample, ) from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch from ocf_data_sampler.numpy_sample.common_types import NumpyBatch, NumpySample -from ocf_data_sampler.numpy_sample.generation import GenerationSampleKey -from ocf_data_sampler.numpy_sample.nwp import NWPSampleKey from ocf_data_sampler.select import ( Location, fill_time_periods, @@ -37,7 +33,6 @@ diff_nwp_data, fill_nans_in_arrays, find_valid_time_periods, - merge_dicts, slice_datasets_by_space, slice_datasets_by_time, ) @@ -196,92 +191,66 @@ def process_and_combine_datasets( t0: init-time for sample location: location of the sample """ - numpy_modalities = [{"t0": t0.timestamp()}] - + # Normalise NWP if "nwp" in dataset_dict: - nwp_numpy_modalities = {} - for nwp_key, da_nwp in dataset_dict["nwp"].items(): - # Standardise and convert to NumpyBatch 
channel_means = self.means_dict["nwp"][nwp_key] channel_stds = self.stds_dict["nwp"][nwp_key] + dataset_dict["nwp"][nwp_key] = (da_nwp - channel_means) / channel_stds - da_nwp = (da_nwp - channel_means) / channel_stds - - nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_sample(da_nwp) - - # Combine the NWPs into NumpyBatch - numpy_modalities.append({NWPSampleKey.nwp: nwp_numpy_modalities}) - + # Normalise satellite if "sat" in dataset_dict: - da_sat = dataset_dict["sat"] - - # Standardise and convert to NumpyBatch channel_means = self.means_dict["sat"] channel_stds = self.stds_dict["sat"] + dataset_dict["sat"] = (dataset_dict["sat"] - channel_means) / channel_stds - da_sat = (da_sat - channel_means) / channel_stds - - numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat)) - + # Normalise generation by capacity if "generation" in dataset_dict: da_generation = dataset_dict["generation"] - da_generation = da_generation / da_generation.capacity_mwp.values + dataset_dict["generation"] = da_generation / da_generation.capacity_mwp.values - # Convert to NumpyBatch - numpy_modalities.append( - convert_generation_to_numpy_sample( - da_generation, - t0_idx=self.t0_idx, - ), - ) + # Convert all xarray modalities to a single NumpySample + sample = convert_to_numpy_sample(dataset_dict, self.t0_idx) - numpy_modalities.append( - { - GenerationSampleKey.location_id: location.id, - GenerationSampleKey.longitude: da_generation.longitude.values, - GenerationSampleKey.latitude: da_generation.latitude.values, - }, - ) + # Add location metadata not present on the DataArray + if "generation" in dataset_dict: + sample["location_id"] = location.id - # Add datetime features + # Add datetime encodings over the full generation time range generation_config = self.config.input_data.generation datetimes = pd.date_range( t0 + minutes(generation_config.interval_start_minutes), t0 + minutes(generation_config.interval_end_minutes), 
freq=minutes(generation_config.time_resolution_minutes), ) - numpy_modalities.append(encode_datetimes(datetimes=datetimes)) + sample.update(encode_datetimes(datetimes=datetimes)) + # Add t0 embedding if configured if self.config.input_data.t0_embedding is not None: - numpy_modalities.append( - get_t0_embedding(t0, self.config.input_data.t0_embedding.embeddings), - ) + sample.update(get_t0_embedding(t0, self.config.input_data.t0_embedding.embeddings)) - # Only add solar position if explicitly configured + # Add solar position if configured if self.config.input_data.solar_position is not None: solar_config = self.config.input_data.solar_position - - # Create datetime range for solar position calculation datetimes = pd.date_range( t0 + minutes(solar_config.interval_start_minutes), t0 + minutes(solar_config.interval_end_minutes), freq=minutes(solar_config.time_resolution_minutes), ) - - numpy_modalities.append( + sample.update( make_sun_position_numpy_sample( datetimes, - da_generation.longitude.values, - da_generation.latitude.values, + dataset_dict["generation"].longitude.values, + dataset_dict["generation"].latitude.values, ), ) - # Combine all the modalities and fill NaNs - combined_sample = merge_dicts(numpy_modalities) - combined_sample = fill_nans_in_arrays(combined_sample, config=self.config) + sample["t0"] = t0.timestamp() - return combined_sample + # Fill NaNs + sample = fill_nans_in_arrays(sample, config=self.config) + + return sample @staticmethod def find_valid_t0_times(datasets_dict: dict, config: Configuration) -> pd.DatetimeIndex: @@ -461,7 +430,7 @@ def _get_sample(self, t0: pd.Timestamp) -> NumpyBatch: t0: init-time for sample """ # Slice by time then load to avoid loading the data multiple times from disk sample_dict = slice_datasets_by_time(self.datasets_dict, t0, self.config) sample_dict = load_data_dict(sample_dict) sample_dict = diff_nwp_data(sample_dict, self.config) @@ -493,7 +462,7 @@ 
Args: t0: init-time for sample """ - # Check data is availablle for init-time t0 + # Check data is available for init-time t0 if t0 not in self.valid_t0_times: raise ValueError(f"Input init time '{t0!s}' not in valid times") return self._get_sample(t0) diff --git a/ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py b/ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py index 3283ff63..b565e152 100644 --- a/ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +++ b/ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py @@ -3,9 +3,6 @@ import numpy as np from ocf_data_sampler.config.model import Configuration -from ocf_data_sampler.numpy_sample.generation import GenerationSampleKey -from ocf_data_sampler.numpy_sample.nwp import NWPSampleKey -from ocf_data_sampler.numpy_sample.satellite import SatelliteSampleKey def merge_dicts(list_of_dicts: list[dict]) -> dict: @@ -31,11 +28,11 @@ def fill_nans_in_arrays( if np.isnan(v).any(): fill_value = 0.0 if config is not None: - if k == GenerationSampleKey.generation: + if k == "generation": fill_value = config.input_data.generation.dropout_value - elif k == SatelliteSampleKey.satellite_actual: + elif k == "satellite_actual": fill_value = config.input_data.satellite.dropout_value - elif k == NWPSampleKey.nwp and nwp_provider in config.input_data.nwp: + elif k == "nwp" and nwp_provider in config.input_data.nwp: fill_value = config.input_data.nwp[nwp_provider].dropout_value sample[k] = np.nan_to_num(v, copy=False, nan=fill_value) diff --git a/tests/load/test_load_nwp.py b/tests/load/test_load_nwp.py index a7e8395c..0a804547 100755 --- a/tests/load/test_load_nwp.py +++ b/tests/load/test_load_nwp.py @@ -40,7 +40,7 @@ def test_load_cloudcasting(nwp_cloudcasting_zarr_path): def test_load_ukv_new_coords(tmp_path): """Test UKV loader handles data that already has new coordinate names.""" zarr_path = tmp_path / "new_ukv_coords.zarr" - + # Create a mock dataset that already uses the new naming 
convention new_coords_array = DataArray( np.random.rand(1, 1, 1, 1, 1).astype(np.float32), @@ -54,10 +54,10 @@ def test_load_ukv_new_coords(tmp_path): }, ) new_coords_array.to_zarr(zarr_path) - + # This should succeed without KeyError da = open_nwp(zarr_path=zarr_path, provider="ukv") - + assert isinstance(da, DataArray) assert "x_osgb" in da.coords assert "y_osgb" in da.coords diff --git a/tests/numpy_sample/test_generation.py b/tests/numpy_sample/test_generation.py index 1b843ffd..5eeba992 100644 --- a/tests/numpy_sample/test_generation.py +++ b/tests/numpy_sample/test_generation.py @@ -1,32 +1,25 @@ import numpy as np from ocf_data_sampler.load.generation import open_generation -from ocf_data_sampler.numpy_sample import GenerationSampleKey, convert_generation_to_numpy_sample +from ocf_data_sampler.numpy_sample import convert_to_numpy_sample def test_convert_generation_to_numpy_sample(generation_zarr_path): da = open_generation(generation_zarr_path).isel(time_utc=slice(0, 10)).sel(location_id=1) - numpy_sample = convert_generation_to_numpy_sample(da) + t0_idx = 0 + numpy_sample = convert_to_numpy_sample({"generation": da}, t0_idx=t0_idx) # Assert structure - expected_keys = { - GenerationSampleKey.generation, - GenerationSampleKey.capacity_mwp, - GenerationSampleKey.time_utc, - } assert isinstance(numpy_sample, dict) - assert set(numpy_sample) == expected_keys + assert "generation" in numpy_sample + assert "capacity_mwp" in numpy_sample + assert "time_utc" in numpy_sample # Assert content and capacity values - assert np.array_equal(numpy_sample[GenerationSampleKey.generation], da.values) - assert isinstance(numpy_sample[GenerationSampleKey.time_utc], np.ndarray) - assert numpy_sample[GenerationSampleKey.time_utc].dtype == float + assert np.array_equal(numpy_sample["generation"], da.values) + assert isinstance(numpy_sample["time_utc"], np.ndarray) + assert numpy_sample["time_utc"].dtype == float + assert numpy_sample["capacity_mwp"] == 
da.capacity_mwp.isel(time_utc=0).values - assert numpy_sample[GenerationSampleKey.capacity_mwp] == ( - da.capacity_mwp.isel(time_utc=0).values - ) - - # With t0_idx - t0_idx = 5 - numpy_sample_with_t0 = convert_generation_to_numpy_sample(da, t0_idx=t0_idx) - assert numpy_sample_with_t0[GenerationSampleKey.t0_idx] == t0_idx + # Assert t0_idx is passed through + assert numpy_sample["t0_idx"] == 0 diff --git a/tests/numpy_sample/test_nwp.py b/tests/numpy_sample/test_nwp.py index 26cda095..c5cdc022 100644 --- a/tests/numpy_sample/test_nwp.py +++ b/tests/numpy_sample/test_nwp.py @@ -1,9 +1,12 @@ -from ocf_data_sampler.numpy_sample import NWPSampleKey, convert_nwp_to_numpy_sample +from ocf_data_sampler.numpy_sample import convert_to_numpy_sample def test_convert_nwp_to_numpy_sample(ds_nwp_ukv_time_sliced): - numpy_sample = convert_nwp_to_numpy_sample(ds_nwp_ukv_time_sliced) + t0_idx = 0 + numpy_sample = convert_to_numpy_sample( + {"nwp": {"ukv": ds_nwp_ukv_time_sliced}}, + t0_idx=t0_idx, + ) - # Assert output type and shape of sample assert isinstance(numpy_sample, dict) - assert (numpy_sample[NWPSampleKey.nwp] == ds_nwp_ukv_time_sliced.values).all() + assert (numpy_sample["nwp"]["ukv"]["nwp"] == ds_nwp_ukv_time_sliced.values).all() diff --git a/tests/numpy_sample/test_satellite.py b/tests/numpy_sample/test_satellite.py index 199bbe6b..022b4d51 100644 --- a/tests/numpy_sample/test_satellite.py +++ b/tests/numpy_sample/test_satellite.py @@ -1,10 +1,9 @@ - -from ocf_data_sampler.numpy_sample import SatelliteSampleKey, convert_satellite_to_numpy_sample +from ocf_data_sampler.numpy_sample import convert_to_numpy_sample def test_convert_satellite_to_numpy_sample(da_sat_like): - numpy_sample = convert_satellite_to_numpy_sample(da_sat_like) + t0_idx = 0 + numpy_sample = convert_to_numpy_sample({"sat": da_sat_like}, t0_idx=t0_idx) - # Assert output type and shape of sample assert isinstance(numpy_sample, dict) - assert (numpy_sample[SatelliteSampleKey.satellite_actual] == 
da_sat_like.values).all() + assert (numpy_sample["satellite_actual"] == da_sat_like.values).all() diff --git a/tests/torch_datasets/utils/test_merge_and_fill_utils.py b/tests/torch_datasets/utils/test_merge_and_fill_utils.py index 757b28c3..8802e6b7 100644 --- a/tests/torch_datasets/utils/test_merge_and_fill_utils.py +++ b/tests/torch_datasets/utils/test_merge_and_fill_utils.py @@ -1,9 +1,6 @@ import numpy as np from ocf_data_sampler.config import load_yaml_configuration -from ocf_data_sampler.numpy_sample.generation import GenerationSampleKey -from ocf_data_sampler.numpy_sample.nwp import NWPSampleKey -from ocf_data_sampler.numpy_sample.satellite import SatelliteSampleKey from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import ( fill_nans_in_arrays, merge_dicts, @@ -51,18 +48,16 @@ def test_fill_nans_on_numpy_samples(config_filename): array_with_nans = np.array([1.0, np.nan, 3.0, np.nan]) # we use copy() to ensure separate arrays for each key - dict = { - GenerationSampleKey.generation: array_with_nans.copy(), - SatelliteSampleKey.satellite_actual: array_with_nans.copy(), + sample = { + "generation": array_with_nans.copy(), + "satellite_actual": array_with_nans.copy(), "ukv": { - NWPSampleKey.nwp: np.array([np.nan, 2.0, np.nan, 4.0]), + "nwp": np.array([np.nan, 2.0, np.nan, 4.0]), }, } - result = fill_nans_in_arrays(dict, config=configuration) + result = fill_nans_in_arrays(sample, config=configuration) - assert np.array_equal(result[GenerationSampleKey.generation], np.array([1.0, 0.0, 3.0, 0.0])) - assert np.array_equal( - result[SatelliteSampleKey.satellite_actual], np.array([1.0, -1.0, 3.0, -1.0]), - ) - assert np.array_equal(result["ukv"][NWPSampleKey.nwp], np.array([-2.0, 2.0, -2.0, 4.0])) + assert np.array_equal(result["generation"], np.array([1.0, 0.0, 3.0, 0.0])) + assert np.array_equal(result["satellite_actual"], np.array([1.0, -1.0, 3.0, -1.0])) + assert np.array_equal(result["ukv"]["nwp"], np.array([-2.0, 2.0, -2.0, 4.0]))