@@ -55,7 +55,6 @@ def caplog(caplog: LogCaptureFixture):
5555 logger .remove (handler_id )
5656
5757
58- @pytest .fixture
5958def sample_config ():
6059 """Provide a sample configuration for testing."""
6160 return {
@@ -85,91 +84,34 @@ def test_intensity_config_validation_logging(caplog):
8584
8685def test_to_sleap_nn_cfg ():
8786 """Test serializing a TrainingJobConfig to YAML."""
87+ cfg = sample_config ()
8888 config_dict = {
89- "name" : sample_config ["name" ],
90- "description" : sample_config ["description" ],
89+ "name" : cfg ["name" ],
90+ "description" : cfg ["description" ],
9191 "data_config" : {
92- "train_labels_path" : sample_config ["data_config" ].train_labels_path ,
93- "val_labels_path" : sample_config ["data_config" ].val_labels_path ,
94- "provider" : sample_config ["data_config" ].provider ,
92+ "train_labels_path" : cfg ["data_config" ].train_labels_path ,
93+ "val_labels_path" : cfg ["data_config" ].val_labels_path ,
94+ "provider" : cfg ["data_config" ].provider ,
9595 },
9696 "model_config" : {
97- "init_weights" : sample_config ["model_config" ].init_weights ,
97+ "init_weights" : cfg ["model_config" ].init_weights ,
9898 },
99- "trainer_config" : sample_config [
100- "trainer_config"
101- ], # Include full trainer config
99+ "trainer_config" : cfg ["trainer_config" ], # Include full trainer config
102100 }
103101 yaml_data = OmegaConf .to_yaml (config_dict )
104102 parsed_yaml = OmegaConf .create (yaml_data )
105103
106- assert parsed_yaml .name == sample_config ["name" ]
107- assert parsed_yaml .description == sample_config ["description" ]
104+ assert parsed_yaml .name == cfg ["name" ]
105+ assert parsed_yaml .description == cfg ["description" ]
108106 assert (
109107 parsed_yaml .data_config .train_labels_path
110- == sample_config ["data_config" ].train_labels_path
111- )
112- assert (
113- parsed_yaml .data_config .val_labels_path
114- == sample_config ["data_config" ].val_labels_path
115- )
116- assert parsed_yaml .data_config .provider == sample_config ["data_config" ].provider
117-
118- assert (
119- parsed_yaml .model_config .init_weights
120- == sample_config ["model_config" ].init_weights
121- )
122- assert parsed_yaml .trainer_config == sample_config ["trainer_config" ]
123-
124-
125- def test_load_yaml (sample_config ):
126- """Test loading a TrainingJobConfig from a YAML file."""
127- # Create proper config objects
128- data_config = DataConfig (
129- train_labels_path = sample_config ["data_config" ].train_labels_path ,
130- val_labels_path = sample_config ["data_config" ].val_labels_path ,
131- provider = sample_config ["data_config" ].provider ,
132- )
133-
134- model_config = ModelConfig (
135- init_weights = sample_config ["model_config" ].init_weights ,
136- )
137-
138- trainer_config = TrainerConfig (
139- early_stopping = sample_config ["trainer_config" ].early_stopping
140- )
141-
142- config = TrainingJobConfig (
143- name = sample_config ["name" ],
144- description = sample_config ["description" ],
145- data_config = data_config ,
146- model_config = model_config ,
147- trainer_config = trainer_config ,
108+ == cfg ["data_config" ].train_labels_path
148109 )
110+ assert parsed_yaml .data_config .val_labels_path == cfg ["data_config" ].val_labels_path
111+ assert parsed_yaml .data_config .provider == cfg ["data_config" ].provider
149112
150- with tempfile .TemporaryDirectory () as tmpdir :
151- file_path = os .path .join (tmpdir , "test_config.yaml" )
152-
153- # Use the to_yaml method to save the file
154- config .to_yaml (filename = file_path )
155-
156- # Load from file
157- loaded_config = TrainingJobConfig .load_yaml (file_path )
158- assert loaded_config .name == config .name
159- assert loaded_config .description == config .description
160- # Use dictionary access for loaded config
161- assert (
162- loaded_config .data_config .train_labels_path
163- == config .data_config .train_labels_path
164- )
165- assert (
166- loaded_config .data_config .val_labels_path
167- == config .data_config .val_labels_path
168- )
169- assert (
170- loaded_config .trainer_config .early_stopping .patience
171- == config .trainer_config .early_stopping .patience
172- )
113+ assert parsed_yaml .model_config .init_weights == cfg ["model_config" ].init_weights
114+ assert parsed_yaml .trainer_config == cfg ["trainer_config" ]
173115
174116
175117# def test_missing_attributes(sample_config):
@@ -348,6 +290,3 @@ def test_load_topdown_training_config_from_file(topdown_training_config_path):
348290 omegacfg = cfg .to_sleap_nn_cfg ()
349291 assert isinstance (omegacfg , DictConfig )
350292 assert omegacfg .data_config .train_labels_path == "test.slp"
351-
352- with pytest .raises (MissingMandatoryValue ):
353- config = TrainingJobConfig ().to_sleap_nn_cfg ()
0 commit comments