Skip to content

Commit

Permalink
Merge pull request #28 from aai-institute/improve-evaluation
Browse files Browse the repository at this point in the history
Improve evaluation
  • Loading branch information
fariedabuzaid authored Sep 29, 2023
2 parents 06ce63e + 70ce4dc commit e7e0549
Show file tree
Hide file tree
Showing 16 changed files with 205 additions and 115 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
*.csv
*.onnx
*.pkl
*.pt
#
.idea
config_local.json
Expand Down
4 changes: 2 additions & 2 deletions experiments/cfair/config.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
---
__object__: src.experiments.base.ExperimentCollection
__object__: src.explib.base.ExperimentCollection
name: cfair10_basedist_comparison
experiments:
- &exp_laplace
__object__: src.experiments.hyperopt.HyperoptExperiment
__object__: src.explib.hyperopt.HyperoptExperiment
name: mnist_laplace
scheduler: &scheduler
__object__: ray.tune.schedulers.ASHAScheduler
Expand Down
6 changes: 3 additions & 3 deletions experiments/fashion/config.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
---
__object__: src.experiments.base.ExperimentCollection
__object__: src.explib.base.ExperimentCollection
name: fashion_basedist_comparison
experiments:
- &exp_laplace
__object__: src.experiments.hyperopt.HyperoptExperiment
__object__: src.explib.hyperopt.HyperoptExperiment
name: mnist_laplace
scheduler: &scheduler
__object__: ray.tune.schedulers.ASHAScheduler
Expand All @@ -18,7 +18,7 @@ experiments:
mode: min
trial_config:
dataset: &dataset
__object__: src.experiments.datasets.FashionMnistSplit
__object__: src.explib.datasets.FashionMnistSplit
label: 0
epochs: &epochs 10000
patience: &patience 500
Expand Down
66 changes: 43 additions & 23 deletions experiments/mnist/config.yaml
Original file line number Diff line number Diff line change
@@ -1,63 +1,83 @@
---
__object__: src.experiments.base.ExperimentCollection
__object__: src.explib.base.ExperimentCollection
name: mnist_basedist_comparison
experiments:
- &exp_laplace
__object__: src.experiments.hyperopt.HyperoptExperiment
name: mnist_laplace
- &exp_nice_lu_laplace
__object__: src.explib.hyperopt.HyperoptExperiment
name: mnist_nice_lu_laplace
scheduler: &scheduler
__object__: ray.tune.schedulers.ASHAScheduler
max_t: 10000
grace_period: 10000
max_t: 1000000
grace_period: 1000000
reduction_factor: 2
num_hyperopt_samples: &num_hyperopt_samples 25
num_hyperopt_samples: &num_hyperopt_samples 50
gpus_per_trial: &gpus_per_trial 0
cpus_per_trial: &cpus_per_trial 1
tuner_params: &tuner_params
metric: val_loss
mode: min
trial_config:
dataset: &dataset
__object__: src.veriflow.experiments.datasets.MnistSplit
__object__: src.explib.datasets.MnistSplit
digit: 0
epochs: &epochs 20000
epochs: &epochs 200000
patience: &patience 50
batch_size: &batch_size
__eval__: tune.choice([8, 16, 32, 64])
__eval__: tune.choice([32])
optim_cfg: &optim
optimizer:
__class__: torch.optim.Adam
params:
lr:
__eval__: tune.loguniform(1e-2, 5e-4)
__eval__: tune.loguniform(1e-4, 1e-2)
weight_decay: 0.0

model_cfg:
type:
__class__: &model src.veriflow.flows.NiceFlow
params:
coupling_layers: &coupling_layers
__eval__: tune.choice([ 2, 3, 4, 5])
__eval__: tune.choice([2, 3, 4, 5, 6, 7, 8, 9, 10])
coupling_nn_layers: &coupling_nn_layers
__eval__: tune.choice([[w]*l for w in [50, 100] for l in range(1, 3)])
__eval__: tune.choice([[w]*l for l in [1, 2, 3, 4] for w in [10, 20, 50, 100, 200]])
nonlinearity: &nonlinearity
__eval__: tune.choice([torch.nn.ReLU()])
split_dim: &split_dim 50
split_dim:
__eval__: tune.choice([i for i in range(1, 51)])
base_distribution:
__object__: pyro.distributions.Laplace
loc:
__eval__: torch.zeros(100)
scale:
__eval__: torch.ones(100)
permutation: &permutation LU
- &exp_normal
__overwrites__: *exp_laplace
name: mnist_normal
- &exp_nice_lu_normal
__overwrites__: *exp_nice_lu_laplace
name: mnist_nice_lu_normal
model_cfg:
params:
base_distribution:
__object__: pyro.distributions.Normal
loc:
__eval__: torch.zeros(100)
scale:
__eval__: torch.ones(100)
- &exp_nice_rand_laplace
__overwrites__: *exp_nice_lu_laplace
name: mnist_nice_rand_laplace
model_cfg:
params:
base_distribution:
__object__: pyro.distributions.Normal
loc:
__eval__: torch.zeros(100)
scale:
__eval__: torch.ones(100)
permutation: random
- &exp_nice_rand_normal
__overwrites__: *exp_nice_lu_laplace
name: mnist_nice_rand_normal
model_cfg:
params:
permutation: random
base_distribution:
__object__: pyro.distributions.Normal
loc:
__eval__: torch.zeros(100)
scale:
__eval__: torch.ones(100)

6 changes: 3 additions & 3 deletions experiments/mnist/config_best.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
---
__object__: src.experiments.base.ExperimentCollection
__object__: src.explib.base.ExperimentCollection
name: mnist_basedist_comparison
experiments:
- &exp_laplace_best
__object__: src.experiments.hyperopt.HyperoptExperiment
__object__: src.explib.hyperopt.HyperoptExperiment
name: mnist_normal_best
scheduler: &scheduler
__object__: ray.tune.schedulers.ASHAScheduler
Expand All @@ -18,7 +18,7 @@ experiments:
mode: min
trial_config:
dataset: &dataset
__object__: src.experiments.datasets.MnistSplit
__object__: src.explib.datasets.MnistSplit
digit: 0
epochs: &epochs 10000
patience: &patience 50
Expand Down
6 changes: 3 additions & 3 deletions experiments/mnist/config_lu.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
---
__object__: src.experiments.baseExperimentCollection
__object__: src.explib.baseExperimentCollection
name: mnist_basedist_comparison
experiments:
- &exp_laplace
__object__: src.experiments.hyperopt.HyperoptExperiment
__object__: src.explib.hyperopt.HyperoptExperiment
name: mnist_laplace
scheduler: &scheduler
__object__: ray.tune.schedulers.ASHAScheduler
Expand All @@ -18,7 +18,7 @@ experiments:
mode: min
trial_config:
dataset: &dataset
__object__: src.experiments.datasets.MnistSplit
__object__: src.explib.datasets.MnistSplit
digit: 0
epochs: &epochs 20000
patience: &patience 50
Expand Down
6 changes: 3 additions & 3 deletions experiments/synthetic/config.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
---
__object__: laplace_flows.experiments.base.ExperimentCollection
__object__: src.explib.base.ExperimentCollection
name: mnist_basedist_comparison
experiments:
- &main
__object__: laplace_flows.experiments.hyperopt.HyperoptExperiment
__object__: src.explib.hyperopt.HyperoptExperiment
name: normal_moons
scheduler: &scheduler
__object__: ray.tune.schedulers.ASHAScheduler
Expand All @@ -18,7 +18,7 @@ experiments:
mode: min
trial_config:
dataset: &dataset
__object__: laplace_flows.experiments.datasets.SyntheticSplit
__object__: src.explib.datasets.SyntheticSplit
generator: make_moons
params_train: &params_train
n_samples: 100000
Expand Down
2 changes: 1 addition & 1 deletion scripts/run-expreiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import click

from src.experiments.config_parser import read_config
from src.explib.config_parser import read_config

Pathable = T.Union[str, os.PathLike] # In principle one can cast it to os.path.Path

Expand Down
File renamed without changes.
File renamed without changes.
59 changes: 47 additions & 12 deletions src/experiments/config_parser.py → src/explib/config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from importlib import import_module
from pathlib import Path
from typing import Any, Dict, List, Union
from pickle import load

import yaml

Expand Down Expand Up @@ -113,17 +114,29 @@ def read_config(yaml_path: Union[str, Path]) -> dict:
additional functionality:
Special keys:
__class__: The value of this key is interpreted as the class name of the object.
The class is imported and stored in the result dictionary under the key <key>.
__class__: The value is interpreted as a class name and the corresponding class is imported.
__object__: The value is interpreted as a class, all other keys are interpreted as constructor arguments.
The key indicates that this (sub-)dictionary is interpreted as on object specification.
__eval__: The value is evaluated. All other keys in the (sub-)dictionary are ignored.
The keywords supports the core python languages. Additionally, tune and torch are already imported for convenience.
Example:
entry in yaml: __class__model: laplace_flows.flows.NiceFlow)
entry in result: model: __import__("laplace_flows.flows.NiceFlow")
__tune__<key>: The value of this key is interpreted as a dictionary that contains the
configuration for the hyperparameter optimization using tune sample methods.
the directive is evaluated and the result in the result dictionary under the key <key>.
Example:
entry in yaml: __tune__lr: loguniform(1e-4, 1e-1)
entry in result: lr: eval("tune.loguniform(1e-4, 1e-1)")
---
entry in yaml:
model:
__class__: src.verfiflow.flows.NiceFlow
entry in result: model: <src.verfiflow.flows.NiceFlow>
---
entry in yaml:
model:
__object__: src.verfiflow.flows.NiceFlow
p1: 1
p2: 2
entry in result: model: <src.verfiflow.flows.NiceFlow(p1=1, p2=2)>
---
entry in yaml:
lr:
__eval__: tune.loguniform(1e-4, 1e-1)
entry in result: lr: <tune.loguniform(1e-4, 1e-1)>
:param yaml_path: Path to the yaml file.
"""
Expand All @@ -137,15 +150,18 @@ def read_config(yaml_path: Union[str, Path]) -> dict:
return config


def parse_raw_config(d: dict):
def parse_raw_config(d: dict) -> Any:
"""Parses an unfolded raw config dictionary and returns the corresponding dictionary.
Parsing includes the following steps:
- Overwrites are applied (see apply_overwrite)
- The "__object__" key is interpreted as a class name and the corresponding class is imported.
- The "__eval__" key is evaluated.
- The "__class__" key is interpreted as a class name and the corresponding class is imported.
:param d: The raw config dictionary.
Args:
d: The raw config dictionary.
Returns:
The result after all semantics have been applied.
"""
if isinstance(d, dict):
d = apply_overwrite(d, recurse=False)
Expand All @@ -172,3 +188,22 @@ def parse_raw_config(d: dict):
return result
else:
return d

def from_checkpoint(params: str, state_dict: str) -> Any:
"""Loads a model from a checkpoint.
Args:
params: Path to the file containing the model specification.
state_dict: Path to the file containing the state dict.
Returns:
The loaded model.
"""
spec = load(open(params, "rb"))["model_cfg"]
model = spec["type"](**spec["params"])

state_dict = torch.load(state_dict)
model.load_state_dict(state_dict)

return model


1 change: 0 additions & 1 deletion src/experiments/datasets.py → src/explib/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,6 @@ def __getitem__(self, index: int):
x = self.transform(x)
return x, 0


class DataSplitFromCSV(DataSplit):
def __init__(self, train: os.PathLike, test: os.PathLike, val: os.PathLike):
self.train = train
Expand Down
Loading

0 comments on commit e7e0549

Please sign in to comment.