Support for valid examples #158

Open
@ljstrnadiii

Is your feature request related to a problem?

There is currently no support for serving only the batches that satisfy some validity criterion. It would be nice to filter batches based on criteria such as:

  • does an example contain a valid value in a target variable?
  • does an example contain a valid value at the center of a target variable?

Consider this dataset:

import numpy as np
import xarray as xr
import xbatcher

w = 100
da = xr.DataArray(np.random.rand(2, w, w), name='foo', dims=['variable', 'y', 'x'])

# simulate sparse, expensive target data: about 10% valid values, about 90% NaN
percent_nans = .90
# variable 0 is the target; keep a value only where variable 1 is below the threshold
da[0] = xr.where(da[1] < 1 - percent_nans, da[1], np.nan)

bgen = xbatcher.BatchGenerator(
    da, 
    {'variable': 2, 'x':10, 'y': 10}, 
    input_overlap={'x': 0, 'y': 0}, 
    batch_dims={'x': 100, 'y': 100}, 
    concat_input_dims=True
)

for batch in bgen:
    pass

If we are serving this to a machine learning process and only care about locations where we have target data, many of these examples will not be valid, i.e. there will be no target value to use for training.

Describe the solution you'd like

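I would like to see something like the following, where valid_example is a proposed argument (it does not exist yet) that takes a callable and keeps only the examples for which it returns True: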

w = 100
da = xr.DataArray(np.random.rand(2, w, w), name='foo', dims=['variable','y', 'x'])

# simulate sparse, expensive target data: about 10% valid values, about 90% NaN
percent_nans = .90
da[0] = xr.where(da[1] < 1 - percent_nans, da[1], np.nan)

bgen = xbatcher.BatchGenerator(
    da, 
    {'variable': 2, 'x':10, 'y': 10}, 
    input_overlap={'x': 0, 'y': 0}, 
    batch_dims={'x': 100, 'y': 100}, 
    concat_input_dims=True,
    valid_example=lambda x: ~np.isnan(x[0][5,5])
)

for batch in bgen:
    pass

so that every batch satisfies: np.all(~np.isnan(batch[:,0,5,5]))
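
To make the requested behavior concrete, here is a minimal sketch of the filtering idea written against plain NumPy arrays rather than xbatcher itself. The helpers iter_patches, iter_valid_examples and center_is_valid are hypothetical names chosen for illustration; none of them are part of xbatcher, and the sketch assumes non-overlapping 10x10 patches with the target stored in variable 0.

```python
import numpy as np


def iter_patches(arr, size):
    """Yield non-overlapping (variable, size, size) patches of a (variable, y, x) array."""
    _, ny, nx = arr.shape
    for y0 in range(0, ny - size + 1, size):
        for x0 in range(0, nx - size + 1, size):
            yield arr[:, y0:y0 + size, x0:x0 + size]


def iter_valid_examples(examples, is_valid):
    """Yield only the examples for which is_valid(example) is truthy."""
    for example in examples:
        if is_valid(example):
            yield example


def center_is_valid(example):
    """Hypothetical criterion: the target (variable 0) is not NaN at pixel (5, 5)."""
    return not np.isnan(example[0, 5, 5])


# Same kind of data as in the issue: variable 0 is a sparse target, variable 1 is dense.
rng = np.random.default_rng(0)
data = rng.random((2, 100, 100))
data[0] = np.where(data[1] < 0.1, data[1], np.nan)

valid = list(iter_valid_examples(iter_patches(data, 10), center_is_valid))
```

Filtering lazily like this still reads every patch, so invalid patches cost I/O even though they are never served; that trade-off is one reason to consider precomputing the valid set instead.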

Describe alternatives you've considered

see: https://discourse.pangeo.io/t/efficiently-slicing-random-windows-for-reduced-xarray-dataset/2447

I typically extract all valid "chips" or "patches" in advance and persist them as a "training dataset" to get all the computation out of the way up front. The dims would look something like {'i': number of valid chips, 'variable': 2, 'x': 10, 'y': 10}. I could then simply use xbatcher to batch along the i dimension.
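
A rough sketch of this precompute-then-batch alternative, again using plain NumPy rather than xarray or xbatcher. The function tile_into_chips is a hypothetical helper, and the sketch assumes that the spatial sizes are exact multiples of the chip size and that validity means a non-NaN target at the chip center.

```python
import numpy as np


def tile_into_chips(arr, size):
    """Reshape a (variable, y, x) array into (i, variable, size, size) chips.

    Assumes y and x are exact multiples of size.
    """
    v, ny, nx = arr.shape
    # Split y and x into (block, within-block) axes.
    tiled = arr.reshape(v, ny // size, size, nx // size, size)
    # Reorder to (block_y, block_x, variable, y, x), then merge the block axes into i.
    return tiled.transpose(1, 3, 0, 2, 4).reshape(-1, v, size, size)


rng = np.random.default_rng(0)
data = rng.random((2, 100, 100))
data[0] = np.where(data[1] < 0.1, data[1], np.nan)

size = 10
chips = tile_into_chips(data, size)

# Keep only chips whose target (variable 0) is valid at the center pixel.
keep = ~np.isnan(chips[:, 0, size // 2, size // 2])
valid_chips = chips[keep]

# Batch along the leading i axis; the last batch may be smaller.
batch_size = 4
batches = [
    valid_chips[start:start + batch_size]
    for start in range(0, len(valid_chips), batch_size)
]
```

In practice valid_chips would be persisted once (for example to Zarr) and then batched along i on every epoch, so the filtering cost is paid a single time.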

Additional context

No response
