Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 248 additions & 0 deletions ocf_data_sampler/lightarray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
"""A lightweight DataArray-like class."""

from typing import Any

import numpy as np
import xarray as xr
from tensorstore import Future as TensorStoreFuture
from tensorstore import TensorStore
from xarray_tensorstore import _TensorStoreAdapter


class LightDataArray:
"""A lightweight DataArray-like class."""

__slots__ = ["attrs", "coord_dims", "coords", "data", "dims", "future"]

def __init__(
self,
data: np.ndarray | TensorStore,
dims: tuple[str, ...],
coords: dict[str, np.ndarray],
coord_dims: dict[str, tuple[str, ...]],
attrs: None | dict = None,
) -> None:
"""A lightweight DataArray-like class."""
self.data = data
self.dims = dims
self.coords = coords
self.coord_dims = coord_dims
self.attrs = attrs or {}
self.future: None | TensorStoreFuture = None

@classmethod
def from_xarray(cls, da: xr.DataArray) -> "LightDataArray":
"""Create a LightDataArray from an Xarray DataArray."""
# Get raw data handle which can be a numpy array or TensorStore
data: TensorStore | np.ndarray
if isinstance(da.variable._data, _TensorStoreAdapter):
data = da.variable._data.array
elif isinstance(da.variable._data, np.ndarray):
data = da.variable._data
else:
raise ValueError(f"Data backend of type {type(da.variable._data)} not supported.")

coord_values: dict[str, np.ndarray] = {}
coord_dims: dict[str, tuple[str, ...]] = {}

for k, v in da.coords.items():
if v.ndim <= 1:
coord_values[k] = v.values
coord_dims[k] = v.dims
else:
raise ValueError(
"Coordinates with more than 1 dimension not supported. "
f"Found coord '{k}' with shape {v.shape}.",
)

return cls(
data=data,
dims=da.dims,
coords=coord_values,
coord_dims=coord_dims,
attrs=da.attrs,
)

def to_xarray(self) -> xr.DataArray:
"""Convert to an Xarray DataArray.

Note this loads the data eagerly.
"""
coords_dict = {}
for c, v in self.coords.items():
cdims = self.coord_dims.get(c, ())

# If it's a 1D array and the dimension is still in our dims list
if np.ndim(v) == 1 and cdims[0] in self.dims:
coords_dict[c] = (cdims, v)
else:
# It's a scalar or a non-indexed coordinate
coords_dict[c] = v

return xr.DataArray(
data=self.values,
dims=self.dims,
coords=coords_dict,
attrs=self.attrs,
)

def isel(
self,
indexers: None | dict[str, int | slice | list] = None,
**indexers_kwargs: object,
) -> "LightDataArray":
"""Select data by integer index along specified dimensions.

Args:
indexers: A dict with keys matching dimensions and values given by integers, slice
objects or arrays. `indexer` can be an integer, slice or array-like.
**indexers_kwargs: The keyword arguments form of indexers.
"""
if indexers is not None:
indexers_kwargs.update(indexers)

axis_indexers = [slice(None)] * len(self.dims)
new_coords = self.coords.copy()
dims_to_remove = []

for dim, indexer in indexers_kwargs.items():
if dim not in self.dims:
raise KeyError(
f"'{dim}' is not a valid dimension or coordinate for data with dimensions"
f"{self.dims}",
)

axis_indexers[self.dims.index(dim)] = indexer

# Slice the coords which depend on this dimension
for c_name, c_dim_name in self.coord_dims.items():
if c_dim_name == (dim,):
new_coords[c_name] = new_coords[c_name][indexer]

# Check if this dimension is being collapsed (e.g. an integer index like .isel(time=0))
if isinstance(indexer, int | np.integer):
dims_to_remove.append(dim)

# Slice the underlying dta
sliced_data = self.data[tuple(axis_indexers)]

# Remove dims that have been reduced to points
remaining_dims = tuple(d for d in self.dims if d not in dims_to_remove)

# Remove dims from coords that have been reduced to points
new_coord_dims = self.coord_dims.copy()
for dim in dims_to_remove:
for c_name, c_dim_name in self.coord_dims.items():
if c_dim_name == (dim,):
new_coord_dims[c_name] = ()

return LightDataArray(
data=sliced_data,
dims=remaining_dims,
coords=new_coords,
coord_dims=new_coord_dims,
attrs=self.attrs,
)

def _to_index(self, dim: str, label: object) -> slice | int:
coord = self.coords[dim]
if isinstance(label, slice):
# start: find first index >= label.start
start = None
if label.start is not None:
start = np.searchsorted(coord, label.start, side="left")

# stop: find first index > label.stop to ensure slice includes endpoints
stop = None
if label.stop is not None:
stop = np.searchsorted(coord, label.stop, side="right")

return slice(start, stop)
else:
return np.searchsorted(coord, label, side="left")

def sel(
self,
indexers: None | dict[str, Any | slice | list] = None,
**indexers_kwargs: object,
) -> "LightDataArray":
"""Select data by coordinate labels, converting them to indices.

Args:
indexers: A dict with keys matching dimensions and values given by scalars, slices or
arrays of tick labels. For dimensions with multi-index, the indexer may also be a
dict-like object with keys matching index level names.
**indexers_kwargs: The keyword arguments form of indexers.
"""
if indexers is not None:
indexers_kwargs.update(indexers)

isel_kwargs = {dim: self._to_index(dim, val) for dim, val in indexers_kwargs.items()}
return self.isel(**isel_kwargs)

def read(self) -> None:
"""Trigger reading of the data if it's a lazy handle."""
if isinstance(self.data, TensorStore):
self.future = self.data.read()

def load(self) -> "LightDataArray":
"""Load data in-place and return self."""
self.data = self.values
self.future = None
return self

@property
def values(self) -> np.ndarray:
"""Get the underlying data as numpy array, loading it if necessary."""
if isinstance(self.data, TensorStore):
# If TensorStore handle reading
if self.future is None:
return np.asarray(self.data.read().result())
else:
return np.asarray(self.future.result())
else:
return np.asarray(self.data)


def __getattr__(self, name: str) -> "LightDataArray":
"""Allow access to coordinates via attribute syntax, e.g., da.time."""
if name in self.coords:
return self[name]
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

def __getitem__(self, key: str) -> "LightDataArray":
"""Allow access to coordinates via indexing syntax, e.g., da['time']."""
if key in self.coords:
return LightDataArray(
data=self.coords[key],
dims=self.coord_dims[key],
coords={key: self.coords[key]},
coord_dims={key: self.coord_dims[key]},
)
raise KeyError(f"Coordinate '{key}' not found.")

def __getstate__(self) -> dict:
"""Prepare state for pickling, excluding un-picklable attributes."""
return {
"data": self.data,
"dims": self.dims,
"coords": self.coords,
"attrs": self.attrs,
"coord_dims": self.coord_dims,
}

def __setstate__(self, state: dict) -> None:
"""Restore state after unpickling."""
for k, v in state.items():
setattr(self, k, v)
# Restore the un-picklable attribute to a default state
self.future = None

@property
def shape(self) -> tuple[int, ...]:
"""Return the shape of the underlying data array."""
return self.data.shape

def __len__(self) -> int:
"""Return the length of the underlying data array."""
return self.shape[0]
7 changes: 2 additions & 5 deletions ocf_data_sampler/load/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,15 @@ def open_generation(zarr_path: str, public: bool = False) -> xr.DataArray:
backend_kwargs=backend_kwargs,
)

ds = ds.assign_coords(capacity_mwp=ds.capacity_mwp)

da = ds.generation_mw
da = ds.to_dataarray("gen_param").transpose("time_utc", "location_id", "gen_param")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wanted to move the capacities out of the coords. This makes the dataarray simpler and allows us to enforce that all coords are 1-dimensional. This allows the LightDataArray to be simpler by reducing its scope


# Validate data types
if not np.issubdtype(da.dtype, np.floating):
raise TypeError(f"generation_mw should be floating, not {da.dtype}")
raise TypeError(f"generation should be floating, not {da.dtype}")

coord_dtypes = {
"time_utc": np.datetime64,
"location_id": np.integer,
"capacity_mwp": np.floating,
"longitude": np.floating,
"latitude": np.floating,
}
Expand Down
31 changes: 21 additions & 10 deletions ocf_data_sampler/numpy_sample/datetime_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,32 @@
from typing import Literal

import numpy as np
import pandas as pd
from numpy.typing import NDArray

from ocf_data_sampler.numpy_sample.common_types import NumpySample
from ocf_data_sampler.time_utils import (
get_day_fraction,
get_day_of_year,
get_hour,
get_is_leap_year,
get_minute,
get_year,
)


def encode_datetimes(datetimes: pd.DatetimeIndex) -> NumpySample:
def encode_datetimes(datetimes: NDArray[np.datetime64]) -> NumpySample:
"""Creates dictionary of sin and cos datetime embeddings.

Args:
datetimes: DatetimeIndex to create radian embeddings for
datetimes: datetime array to create radian embeddings for

Returns:
Dictionary of datetime encodings
"""
day_of_year = datetimes.dayofyear
minute_of_day = datetimes.minute + datetimes.hour * 60
day_fraction = get_day_fraction(datetimes)
day_of_year = get_day_of_year(datetimes)

time_in_radians = (2 * np.pi) * (minute_of_day / (24 * 60))
time_in_radians = (2 * np.pi) * day_fraction
date_in_radians = (2 * np.pi) * (day_of_year / 365)

return {
Expand All @@ -32,7 +40,7 @@ def encode_datetimes(datetimes: pd.DatetimeIndex) -> NumpySample:


def get_t0_embedding(
t0: pd.Timestamp,
t0: np.datetime64,
embeddings: list[tuple[str, Literal["cyclic", "linear"]]],
) -> dict[str, np.ndarray]:
"""Creates dictionary of t0 time embeddings.
Expand All @@ -50,12 +58,15 @@ def get_t0_embedding(

if period_str.endswith("h"):
period_hours = int(period_str.removesuffix("h"))
frac = (t0.hour + t0.minute / 60) / period_hours
frac = (get_hour(t0) + get_minute(t0) / 60) / period_hours

elif period_str.endswith("y"):
period_years = int(period_str.removesuffix("y"))
days_in_year = 366 if t0.is_leap_year else 365
frac = (((t0.dayofyear-1) / days_in_year) + t0.year % period_years) / period_years
days_in_year = 366 if get_is_leap_year(t0) else 365
frac = (
(((get_day_of_year(t0)-1) / days_in_year) + get_year(t0) % period_years)
/ period_years
)

if embedding_type=="cyclic":
radians = 2 * np.pi * frac
Expand Down
10 changes: 8 additions & 2 deletions ocf_data_sampler/numpy_sample/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,15 @@ def convert_generation_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = No
da: Xarray DataArray containing generation data
t0_idx: Index of the t0 timestamp in the time dimension of the generation data
"""
generation_values = da.sel(gen_param="generation_mw").values
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved the normalisation in here instead of the main body of the .process_and_combine_datasets() method. See comment on that method for the reasoning

capacity_value = da.sel(gen_param="capacity_mwp").values[0]

if capacity_value!=0:
generation_values = generation_values/capacity_value

sample = {
GenerationSampleKey.generation: da.values,
GenerationSampleKey.capacity_mwp: da.capacity_mwp.values[0],
GenerationSampleKey.generation: generation_values,
GenerationSampleKey.capacity_mwp: capacity_value,
GenerationSampleKey.time_utc: da["time_utc"].values.astype(float),
}

Expand Down
6 changes: 3 additions & 3 deletions ocf_data_sampler/numpy_sample/sun_position.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""Module for calculating solar position."""

import numpy as np
import pandas as pd
import pvlib
from numpy.typing import NDArray

from ocf_data_sampler.numpy_sample.common_types import NumpySample


def calculate_azimuth_and_elevation(
datetimes: pd.DatetimeIndex,
datetimes: NDArray[np.datetime64],
lon: float,
lat: float,
) -> tuple[np.ndarray, np.ndarray]:
Expand All @@ -34,7 +34,7 @@ def calculate_azimuth_and_elevation(


def make_sun_position_numpy_sample(
datetimes: pd.DatetimeIndex,
datetimes: NDArray[np.datetime64],
lon: float,
lat: float,
) -> NumpySample:
Expand Down
2 changes: 1 addition & 1 deletion ocf_data_sampler/select/diff_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ def diff_channels(da: xr.DataArray, accum_channels: list[str]) -> xr.DataArray:
# Make a copy of the values to avoid changing the underlying numpy array
vals = da.values.copy()
vals[:-1, accum_channel_inds] = np.diff(vals[:, accum_channel_inds], axis=0)
da.values = vals
da.data = vals

return da.isel(step=slice(0, -1))
Loading