Skip to content

Commit 0616e0d

Browse files
committed
fix: disable RflRelax training schedule after UI train mode change
1 parent 3b10de7 commit 0616e0d

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

scripts/run.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ def get_scene(scene):
158158
testbed.nerf.training.train_mode = ngp.TrainMode.RflRelax
159159
else:
160160
raise ValueError(f"Unknown train mode: {args.train_mode}")
161+
161162
testbed.nerf.training.rfl_warmup_steps = args.rfl_warmup_steps
162163

163164
if args.nerf_compatibility:
@@ -194,12 +195,19 @@ def get_scene(scene):
194195
n_steps = 35000
195196

196197
original_train_mode = ngp.TrainMode(testbed.nerf.training.train_mode)
198+
prev_train_mode = original_train_mode
199+
197200
tqdm_last_update = 0
198201
if n_steps > 0:
199202
with tqdm(desc="Training", total=n_steps, unit="steps") as t:
200203
while testbed.frame():
204+
if prev_train_mode != testbed.nerf.training.train_mode and not args.no_rflrelax_training_schedule:
205+
print("Disabling RflRelax training schedule due to UI train mode change")
206+
args.no_rflrelax_training_schedule = True
207+
201208
if testbed.want_repl():
202209
repl(testbed)
210+
203211
# What will happen when training is done?
204212
if testbed.training_step >= n_steps:
205213
if args.gui:
@@ -214,8 +222,7 @@ def get_scene(scene):
214222

215223
# Rfl-relax training schedule
216224
progress_fraction = float(testbed.training_step) / n_steps
217-
if (original_train_mode == ngp.TrainMode.RflRelax and
218-
not args.no_rflrelax_training_schedule):
225+
if original_train_mode == ngp.TrainMode.RflRelax and not args.no_rflrelax_training_schedule:
219226
# By default only enable RflRelax mode between 15k and 30k steps
220227
if 3/7 <= progress_fraction < 6/7:
221228
testbed.nerf.training.train_mode = ngp.TrainMode.RflRelax
@@ -229,6 +236,8 @@ def get_scene(scene):
229236
old_training_step = testbed.training_step
230237
tqdm_last_update = now
231238

239+
prev_train_mode = ngp.TrainMode(testbed.nerf.training.train_mode)
240+
232241
if args.save_snapshot:
233242
os.makedirs(os.path.dirname(args.save_snapshot), exist_ok=True)
234243
testbed.save_snapshot(args.save_snapshot, False)
@@ -317,7 +326,7 @@ def get_scene(scene):
317326
cam_matrix = f['transform_matrix_start']
318327
else:
319328
raise KeyError("Missing both 'transform_matrix' and 'transform_matrix_start'")
320-
329+
321330
testbed.set_nerf_camera_matrix(np.matrix(cam_matrix)[:-1,:])
322331
outname = os.path.join(args.screenshot_dir, os.path.basename(f["file_path"]))
323332

0 commit comments

Comments
 (0)