Skip to content

Commit b4bbae6

Browse files
authored
Allow dataset config to be passed for dataset init (#400)
* used with precedence over `path / config.json` * bit of validation logic moved out into pydantic model
1 parent 861047b commit b4bbae6

File tree

6 files changed

+70
-32
lines changed

6 files changed

+70
-32
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "rslearn"
3-
version = "0.0.18"
3+
version = "0.0.19"
44
description = "A library for developing remote sensing datasets and models"
55
authors = [
66
{ name = "OlmoEarth Team" },

rslearn/config/dataset.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,3 +635,12 @@ class DatasetConfig(BaseModel):
635635
default_factory=lambda: StorageConfig(),
636636
description="jsonargparse configuration for the WindowStorageFactory.",
637637
)
638+
639+
@field_validator("layers", mode="after")
640+
@classmethod
641+
def layer_names_validator(cls, v: dict[str, LayerConfig]) -> dict[str, LayerConfig]:
642+
"""Ensure layer names don't contain periods, since we use periods to distinguish different materialized groups within a layer."""
643+
for layer_name in v.keys():
644+
if "." in layer_name:
645+
raise ValueError(f"layer names must not contain periods: {layer_name}")
646+
return v

rslearn/dataset/dataset.py

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class Dataset:
2323
.. code-block:: none
2424
2525
dataset/
26-
config.json
26+
config.json # optional, if config provided as runtime object
2727
windows/
2828
group1/
2929
epsg:3857_10_623565_1528020/
@@ -40,37 +40,43 @@ class Dataset:
4040
materialize.
4141
"""
4242

43-
def __init__(self, path: UPath, disabled_layers: list[str] = []) -> None:
43+
def __init__(
44+
self,
45+
path: UPath,
46+
disabled_layers: list[str] = [],
47+
dataset_config: DatasetConfig | None = None,
48+
) -> None:
4449
"""Initializes a new Dataset.
4550
4651
Args:
4752
path: the root directory of the dataset
4853
disabled_layers: list of layers to disable
54+
dataset_config: optional dataset configuration to use instead of loading from the dataset directory
4955
"""
5056
self.path = path
5157

52-
# Load dataset configuration.
53-
with (self.path / "config.json").open("r") as f:
54-
config_content = f.read()
55-
config_content = substitute_env_vars_in_string(config_content)
56-
config = DatasetConfig.model_validate(json.loads(config_content))
57-
58-
self.layers = {}
59-
for layer_name, layer_config in config.layers.items():
60-
# Layer names must not contain period, since we use period to
61-
# distinguish different materialized groups within a layer.
62-
assert "." not in layer_name, "layer names must not contain periods"
63-
if layer_name in disabled_layers:
64-
logger.warning(f"Layer {layer_name} is disabled")
65-
continue
66-
self.layers[layer_name] = layer_config
67-
68-
self.tile_store_config = config.tile_store
69-
self.storage = (
70-
config.storage.instantiate_window_storage_factory().get_storage(
71-
self.path
58+
if dataset_config is None:
59+
# Load dataset configuration from the dataset directory.
60+
with (self.path / "config.json").open("r") as f:
61+
config_content = f.read()
62+
config_content = substitute_env_vars_in_string(config_content)
63+
dataset_config = DatasetConfig.model_validate(
64+
json.loads(config_content)
7265
)
66+
67+
self.layers = {}
68+
for layer_name, layer_config in dataset_config.layers.items():
69+
if layer_name in disabled_layers:
70+
logger.warning(f"Layer {layer_name} is disabled")
71+
continue
72+
self.layers[layer_name] = layer_config
73+
74+
self.tile_store_config = dataset_config.tile_store
75+
self.storage = (
76+
dataset_config.storage.instantiate_window_storage_factory().get_storage(
77+
self.path
7378
)
79+
)
7480

7581
def load_windows(
7682
self,

rslearn/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ def apply_on_windows(
380380

381381
def apply_on_windows_args(f: Callable[..., Any], args: argparse.Namespace) -> None:
382382
"""Call apply_on_windows with arguments passed via command-line interface."""
383-
dataset = Dataset(UPath(args.root), args.disabled_layers)
383+
dataset = Dataset(UPath(args.root), disabled_layers=args.disabled_layers)
384384
apply_on_windows(
385385
f=f,
386386
dataset=dataset,

tests/unit/dataset/test_dataset.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pytest
88
from upath import UPath
99

10+
from rslearn.config import BandSetConfig, DatasetConfig, DType, LayerConfig, LayerType
1011
from rslearn.dataset import Dataset
1112

1213

@@ -110,3 +111,25 @@ def test_no_template_variables(self) -> None:
110111
dataset.tile_store_config["init_args"]["path_suffix"]
111112
== "/static/path/to/tiles"
112113
)
114+
115+
def test_load_from_config_object(self) -> None:
116+
"""Test that Dataset can be initialized with a pre-built DatasetConfig object."""
117+
with tempfile.TemporaryDirectory() as temp_dir:
118+
dataset_path = Path(temp_dir)
119+
120+
# Create a DatasetConfig object directly (no config.json file needed)
121+
config = DatasetConfig(
122+
layers={
123+
"images": LayerConfig(
124+
type=LayerType.RASTER,
125+
band_sets=[
126+
BandSetConfig(dtype=DType.UINT8, bands=["R", "G", "B"])
127+
],
128+
),
129+
"labels": LayerConfig(type=LayerType.VECTOR),
130+
},
131+
)
132+
133+
dataset = Dataset(UPath(dataset_path), dataset_config=config)
134+
135+
assert set(dataset.layers.keys()) == {"images", "labels"}

uv.lock

Lines changed: 8 additions & 8 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)