Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions packages/common/src/weathergen/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
_REPO_ROOT = Path(
__file__
).parent.parent.parent.parent.parent.parent # TODO use importlib for resources
_DEFAULT_CONFIG_PTH = _REPO_ROOT / "config" / "default_config.yml"
DEFAULT_CONFIG = _REPO_ROOT / "config" / "default_config.yml"

_DATETIME_TYPE_NAME = "datetime" # Names for custom resolvers used in Omegaconf
_TIMEDELTA_TYPE_NAME = "timedelta"
Expand Down Expand Up @@ -344,7 +344,7 @@ def load_merge_configs(
*overwrites: Additional overwrites from different sources

Note: The order of precedence for merging the final config is in ascending order:
- base config (either default config or loaded from previous run)
- base config (either `base` or loaded from previous run)
- private config
- overwrites (also in ascending order)

Expand Down Expand Up @@ -530,8 +530,8 @@ def _load_private_conf(private_home: Path | None = None) -> DictConfig:
return private_cf


def _load_base_conf(base: Path | Config | None) -> Config:
"""Return the base configuration"""
def _load_base_conf(base: Path | Config) -> Config:
"""Deserialize base config into a proper config instance."""
match base:
case Path():
_logger.info(f"Loading specified base config from file: {base}.")
Expand All @@ -540,8 +540,9 @@ def _load_base_conf(base: Path | Config | None) -> Config:
_logger.info(f"Using existing config as base: {base}.")
conf = base
case _:
_logger.info("Deserialize default configuration.")
conf = OmegaConf.load(_DEFAULT_CONFIG_PTH)
msg = f"Cannot load base config: {base}"
raise ValueError(msg)

assert isinstance(conf, Config)
return conf

Expand Down
4 changes: 2 additions & 2 deletions src/weathergen/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def inference_from_args(argl: list[str]):
args.private_config,
args.from_run_id,
args.mini_epoch,
args.base_config,
None,
*args.config,
inference_overwrite,
cli_overwrite,
Expand Down Expand Up @@ -108,7 +108,7 @@ def train_continue_from_args(argl: list[str]):
args.private_config,
args.from_run_id,
args.mini_epoch,
args.base_config,
None,
*args.config,
{},
cli_overwrite,
Expand Down
23 changes: 13 additions & 10 deletions src/weathergen/utils/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,23 @@

import pandas as pd

import weathergen.common.config as config


def get_train_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(allow_abbrev=False)
_add_general_arguments(parser)
parser.add_argument(
"--base-config",
type=Path,
nargs="?",
help=(
"Path to the base configuration file."
"If not provided, ./config/default_config.yml is used."
),
default=config.DEFAULT_CONFIG, # --base-config missing entirely
const=config.DEFAULT_CONFIG, # --base-config given, but no value
)

return parser

Expand Down Expand Up @@ -113,16 +126,6 @@ def _add_general_arguments(parser: argparse.ArgumentParser):
" Individual items should be of the form: parent_obj.nested_obj=value"
),
)
parser.add_argument(
"--base-config",
type=Path,
nargs="?",
help=(
"Path to the base configuration file."
"If not provided, ./config/default_config.yml is used."
),
default=None,
)


def _add_model_loading_params(parser: argparse.ArgumentParser):
Expand Down
Loading