@@ -158,6 +158,7 @@ def get_scene(scene):
158158 testbed .nerf .training .train_mode = ngp .TrainMode .RflRelax
159159 else :
160160 raise ValueError (f"Unknown train mode: { args .train_mode } " )
161+
161162 testbed .nerf .training .rfl_warmup_steps = args .rfl_warmup_steps
162163
163164 if args .nerf_compatibility :
@@ -194,12 +195,19 @@ def get_scene(scene):
194195 n_steps = 35000
195196
196197 original_train_mode = ngp .TrainMode (testbed .nerf .training .train_mode )
198+ prev_train_mode = original_train_mode
199+
197200 tqdm_last_update = 0
198201 if n_steps > 0 :
199202 with tqdm (desc = "Training" , total = n_steps , unit = "steps" ) as t :
200203 while testbed .frame ():
204+ if prev_train_mode != testbed .nerf .training .train_mode and not args .no_rflrelax_training_schedule :
205+ print ("Disabling RflRelax training schedule due to UI train mode change" )
206+ args .no_rflrelax_training_schedule = True
207+
201208 if testbed .want_repl ():
202209 repl (testbed )
210+
203211 # What will happen when training is done?
204212 if testbed .training_step >= n_steps :
205213 if args .gui :
@@ -214,8 +222,7 @@ def get_scene(scene):
214222
215223 # Rfl-relax training schedule
216224 progress_fraction = float (testbed .training_step ) / n_steps
217- if (original_train_mode == ngp .TrainMode .RflRelax and
218- not args .no_rflrelax_training_schedule ):
225+ if original_train_mode == ngp .TrainMode .RflRelax and not args .no_rflrelax_training_schedule :
219226 # By default only enable RflRelax mode between 15k and 30k steps
220227 if 3 / 7 <= progress_fraction < 6 / 7 :
221228 testbed .nerf .training .train_mode = ngp .TrainMode .RflRelax
@@ -229,6 +236,8 @@ def get_scene(scene):
229236 old_training_step = testbed .training_step
230237 tqdm_last_update = now
231238
239+ prev_train_mode = ngp .TrainMode (testbed .nerf .training .train_mode )
240+
232241 if args .save_snapshot :
233242 os .makedirs (os .path .dirname (args .save_snapshot ), exist_ok = True )
234243 testbed .save_snapshot (args .save_snapshot , False )
@@ -317,7 +326,7 @@ def get_scene(scene):
317326 cam_matrix = f ['transform_matrix_start' ]
318327 else :
319328 raise KeyError ("Missing both 'transform_matrix' and 'transform_matrix_start'" )
320-
329+
321330 testbed .set_nerf_camera_matrix (np .matrix (cam_matrix )[:- 1 ,:])
322331 outname = os .path .join (args .screenshot_dir , os .path .basename (f ["file_path" ]))
323332
0 commit comments