Speedup: Add optional LightDataArray backend and migrate datetimes to np.datetime64#397
dfulu wants to merge 1 commit into dev_feb2026_speedups from
Conversation
Force-pushed e114c2e to 3126fae
  else:
      # Get the coordinates of the sample
-     t0, location_id = self.valid_t0_and_location_ids.iloc[idx]
+     t0 = self.valid_t0_and_location_ids["t0"].values[idx]
This change is required since self.valid_t0_and_location_ids.iloc[idx] wasn't working well for returning a np.datetime64 object. It stubbornly returned a pandas Timestamp.
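The boxing difference can be reproduced standalone; the DataFrame below is a stand-in for valid_t0_and_location_ids:

```python
import numpy as np
import pandas as pd

# Stand-in for self.valid_t0_and_location_ids
df = pd.DataFrame({
    "t0": pd.date_range("2024-01-01", periods=3, freq="h"),
    "location_id": [101, 102, 103],
})

# .iloc boxes the value back into a pandas Timestamp
boxed = df["t0"].iloc[0]

# Going through .values gives the raw numpy array, so indexing
# yields a plain np.datetime64
raw = df["t0"].values[0]
```

Both values compare equal; only the type differs, which is what the diff is about.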
- def tensorstore_read(xarray_dict: dict) -> dict:
+ def read_data_dict(xarray_dict: dict) -> dict:
      """Start reading a nested dictionary of xarray-tensorstore DataArrays."""
This function still does the same thing; it has just been renamed for clarity.
  ds_2 = ds_nwp_ecmwf.copy(deep=True)
  ds_2["init_time"] = pd.date_range(
-     start=ds_nwp_ecmwf.init_time.max().values + pd.Timedelta("6h"),
+     start=ds_nwp_ecmwf.init_time.values.max() + pd.Timedelta("6h"),
Calling .values first returns an array of np.datetime64 objects, so start is cast to np.datetime64. I didn't change the pandas Timedelta here; maybe I should have, but I wanted to keep the number of changes low, so some pandas timestamps currently remain in the tests.
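A minimal illustration of the difference; init_times here is a stand-in for ds_nwp_ecmwf.init_time:

```python
import numpy as np
import pandas as pd

init_times = pd.date_range("2023-01-01", periods=4, freq="6h")

# .max() on the index boxes the result as a pd.Timestamp
boxed_max = init_times.max()

# .values.max() stays in numpy, returning np.datetime64
raw_max = init_times.values.max()

# Adding a pd.Timedelta still works on the numpy value
start = raw_max + pd.Timedelta("6h")
```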
  data_vars={
      "generation_mw": (("time_utc", "location_id"), np.random.randint(0, 100, (10, 2))),
-     "capacity_mwp": (("location_id",), [90.0, 110.0]),
+     "capacity_mwp": (("location_id",), [90, 110]),
This test checks that an error is raised if the dtype of the generation is int. Since we now concat the generation and capacities together into a single array, the generation would be promoted to float if the capacities were floats, so this test would fail.
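The promotion behaviour can be reproduced with plain numpy; this is a sketch of why float capacities would mask an int generation dtype:

```python
import numpy as np

gen_int = np.random.randint(0, 100, (10, 2), dtype=np.int64)   # int generation
cap_float = np.broadcast_to([90.0, 110.0], (1, 2))             # float capacities
cap_int = np.broadcast_to(np.array([90, 110], dtype=np.int64), (1, 2))

# Concatenating promotes to the common dtype, so float capacities would
# silently turn int generation into float and the dtype check would pass
promoted = np.concatenate([gen_int, cap_float], axis=0)
kept = np.concatenate([gen_int, cap_int], axis=0)
```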
Force-pushed 3126fae to 8b78420
  ds = ds.assign_coords(capacity_mwp=ds.capacity_mwp)

- da = ds.generation_mw
+ da = ds.to_dataarray("gen_param").transpose("time_utc", "location_id", "gen_param")
I wanted to move the capacities out of the coords. This makes the DataArray simpler and allows us to enforce that all coords are 1-dimensional, which in turn lets the LightDataArray be simpler by reducing its scope.
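The resulting layout can be sketched with plain numpy: generation and capacity become slices along a new trailing gen_param axis (a sketch of the shape, not the actual implementation):

```python
import numpy as np

generation = np.random.rand(10, 2) * 100  # (time_utc, location_id)
capacity = np.array([90.0, 110.0])        # (location_id,)

# Broadcast the capacities along time and stack as a trailing axis, mirroring
# ds.to_dataarray("gen_param").transpose("time_utc", "location_id", "gen_param")
cap_2d = np.broadcast_to(capacity, generation.shape)
stacked = np.stack([generation, cap_2d], axis=-1)
```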
  # We only look at the dimensional coords. It is possible that other coordinate systems are
  # included as non-dimensional coords
- dimensional_coords = set(da.xindexes)
+ dimensional_coords = set(da.dims)
This is functionally the same. I just switched it so that the LightDataArray only needs .dims and doesn't have to have .xindexes.
  # This is the max staleness we can use considering the max step of the input data
  max_possible_staleness = (
-     pd.Timedelta(da["step"].max().item())
+     da["step"].values.max()
This change is just so it returns a np.timedelta64 object.
  # Find the first forecast step
- first_forecast_step = pd.Timedelta(da["step"].min().item())
+ first_forecast_step = da["step"].values[0]
We assume the steps are sorted from low to high here. I think we implicitly assume that in our slicing anyway, and we should probably make it explicit when loading datasets (we don't right now).
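Both changes keep the steps as numpy timedelta64 values, assuming the step axis is sorted ascending:

```python
import numpy as np
import pandas as pd

steps = pd.timedelta_range("0h", "48h", freq="6h").values  # timedelta64[ns]

# Stays a np.timedelta64 rather than being boxed into pd.Timedelta
max_possible_staleness = steps.max()

# Relies on steps being sorted low to high
first_forecast_step = steps[0]
```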
  channel_stds = self.stds_dict["nwp"][nwp_key]

- da_nwp = (da_nwp - channel_means) / channel_stds
+ da_nwp.data = (da_nwp.data - channel_means) / channel_stds
Replacing the xarray .data directly is faster than the original line, and it also works with the LightDataArray.
      Args:
          datetimes: the datetimes to get POSIX timestamps for
      """
+     return datetimes.astype("datetime64[ns]").astype(np.float64) * 1e-9
To match the pd.Timestamp.timestamp() method, this needs to use float64 to avoid losing accuracy. I made the get_day_fraction() function also use float64 to match this.
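A sketch of the precision argument: the helper below mirrors the line in the diff (the name is illustrative), and its output matches pd.Timestamp.timestamp(); a float32 intermediate would only carry about 7 significant digits, i.e. errors on the order of minutes at current epochs:

```python
import numpy as np
import pandas as pd

def posix_timestamps(datetimes: np.ndarray) -> np.ndarray:
    # float64 keeps sub-microsecond accuracy for ns-since-epoch values
    return datetimes.astype("datetime64[ns]").astype(np.float64) * 1e-9

times = pd.date_range("2024-06-01 12:34:56", periods=3, freq="s")
expected = np.array([t.timestamp() for t in times])
```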
  from xarray_tensorstore import read as xtr_read

  def minutes(minutes: int | list[float]) -> pd.Timedelta | pd.TimedeltaIndex:
This was just moved to time_utils.py
  if "generation" in dataset_dict:
      da_generation = dataset_dict["generation"]
-     da_generation = da_generation / da_generation.capacity_mwp.values
The normalisation here was hard to emulate with the new generation data structure, where the generation and capacity are in the same array but on different indices: it would be hard to normalise the generation values without also normalising the capacity values. I moved the normalisation into the convert_generation_to_numpy_sample() function instead.
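With the stacked layout, normalising just the generation slice looks roughly like this inside the conversion step (a numpy sketch; the gen_param index order is an assumption):

```python
import numpy as np

# stacked: (time_utc, location_id, gen_param) with
# gen_param = ["generation_mw", "capacity_mwp"] (assumed order)
generation = np.full((5, 2), 50.0)
capacity = np.broadcast_to(np.array([100.0, 200.0]), (5, 2))
stacked = np.stack([generation, capacity], axis=-1)

# Normalise only the generation slice; dividing the whole array by the
# capacities would rescale the capacity slice too
generation_norm = stacked[..., 0] / stacked[..., 1]
```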
      da: Xarray DataArray containing generation data
      t0_idx: Index of the t0 timestamp in the time dimension of the generation data
      """
+     generation_values = da.sel(gen_param="generation_mw").values
I moved the normalisation in here instead of the main body of the process_and_combine_datasets() method. See the comment on that method for the reasoning.
Force-pushed 8b78420 to 21c5289
Pull Request
Description
There are two main changes in this PR which combine to speed up the data-sampler. These changes were motivated by running code profiling with pyinstrument.

1. Getting rid of pandas timestamps

When profiling the code I found that we were spending a silly amount of time doing things with timestamps. This is just because the pandas Timestamp, DatetimeIndex and date_range() are slow. I have moved us to use numpy datetime64 objects everywhere to handle datetimes. The pandas timestamp machinery is built on top of numpy datetime64 anyway, so we have just removed a layer of overhead. Unfortunately this costs us some simplicity in the code: pandas Timestamps have those nice .hour, .day_of_year and .ceil() attributes and methods, which numpy does not. I have added a new file, time_utils.py, with a bunch of helper functions to make working with numpy datetime64 objects easier and to replace some of the utility we were getting from pandas. I've added tests for these utility functions which check them against the pandas equivalents.
2. Replacing the xarray DataArray

This is the biggest part of the speedup. In profiling, we were spending a lot of time in xarray's internal methods which we don't really need. I've added a new class, LightDataArray, which replaces the xarray DataArray in the code path under _get_sample(). This means we can use xarray for the initialisation of the torch Dataset (where the speed of xarray doesn't matter) but use the new class when we are slicing out samples. We do not need all of xarray's functionality when slicing these samples, so we can do it faster with the custom class.

I have based the custom class on the xarray DataArray class. For example, I've created isel(), sel() and load() methods, and it has a .values property. This means we get away with fewer code changes here, and it should be pretty trivial to switch back to xarray if we want to. In the current code I have left using the new LightDataArray class as an option when calling _get_sample(), and it works pulling from either xr.DataArrays or LightDataArrays. This has the benefit of making it easy to test.

The LightDataArray object has a from_xarray() method which instantiates a LightDataArray from an xr.DataArray, and a to_xarray() method which instantiates an xr.DataArray from a LightDataArray instance. This makes it easy to test that the methods of LightDataArray match the xarray methods, so we use the xarray behaviour as the expected behaviour for our tests.
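A minimal sketch of the idea, using only numpy so it is self-contained (illustrative only; the real LightDataArray has more methods plus the from_xarray()/to_xarray() conversions):

```python
import numpy as np

class LightDataArraySketch:
    """Plain numpy data plus 1-D coords, mimicking the xarray surface we use."""

    def __init__(self, data, dims, coords):
        self.data = data
        self.dims = tuple(dims)
        self.coords = {name: np.asarray(c) for name, c in coords.items()}

    @property
    def values(self):
        return self.data

    def isel(self, **indexers):
        # slice-only indexing for brevity; the real class handles more cases
        key = tuple(indexers.get(d, slice(None)) for d in self.dims)
        coords = {d: (c[indexers[d]] if d in indexers else c)
                  for d, c in self.coords.items()}
        return LightDataArraySketch(self.data[key], self.dims, coords)

da = LightDataArraySketch(
    np.arange(12).reshape(3, 4),
    dims=("time_utc", "location_id"),
    coords={
        "time_utc": np.array(["2024-01-01", "2024-01-02", "2024-01-03"],
                             dtype="datetime64[D]"),
        "location_id": np.arange(4),
    },
)
sub = da.isel(time_utc=slice(0, 2))
```

In the tests, the same operation would be run through xarray and through this class, and the results compared.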
Checklist: