-
Notifications
You must be signed in to change notification settings - Fork 1.4k
/
run_nerf.py
928 lines (763 loc) · 37.1 KB
/
run_nerf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
import os
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
import sys
import tensorflow as tf
import numpy as np
import imageio
import json
import random
import time
from run_nerf_helpers import *
from load_llff import load_llff_data
from load_deepvoxels import load_dv_data
from load_blender import load_blender_data
# Switch TensorFlow to eager mode at import time; code below calls `.numpy()`
# on tensors and uses `tf.GradientTape`, which rely on ops executing eagerly.
tf.compat.v1.enable_eager_execution()
def batchify(fn, chunk):
    """Constructs a version of 'fn' that applies to smaller batches.

    Args:
      fn: callable applied to slices of the input taken along axis 0.
      chunk: int or None. Maximum number of rows handed to 'fn' per call.
        If None, 'fn' is returned unchanged.

    Returns:
      A callable taking 'inputs'. It evaluates 'fn' on consecutive slices of
      at most 'chunk' rows and concatenates the results along axis 0.
    """
    if chunk is None:
        return fn

    def ret(inputs):
        # Slice along the leading axis, evaluate each slice, then stitch the
        # partial outputs back together in the original order.
        return tf.concat(
            [fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
    return ret
def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
    """Prepares inputs and applies network 'fn'.

    Args:
      inputs: tensor whose last axis holds the point coordinates. All leading
        axes are flattened before embedding and restored on the output.
      viewdirs: tensor of per-ray directions, or None. When given, it is
        broadcast to the shape of 'inputs', embedded with 'embeddirs_fn' and
        concatenated to the embedded points.
      fn: network applied to the embedded inputs.
      embed_fn: embedding function for the flattened points.
      embeddirs_fn: embedding function for the flattened directions. Only
        used when 'viewdirs' is not None.
      netchunk: int or None. Maximum number of rows sent through 'fn' at once.

    Returns:
      Tensor with the leading axes of 'inputs' and the last axis of the
      network output.
    """
    inputs_flat = tf.reshape(inputs, [-1, inputs.shape[-1]])

    embedded = embed_fn(inputs_flat)
    if viewdirs is not None:
        # Repeat each ray's direction for every sample along that ray.
        input_dirs = tf.broadcast_to(viewdirs[:, None], inputs.shape)
        input_dirs_flat = tf.reshape(input_dirs, [-1, input_dirs.shape[-1]])
        embedded_dirs = embeddirs_fn(input_dirs_flat)
        embedded = tf.concat([embedded, embedded_dirs], -1)

    outputs_flat = batchify(fn, netchunk)(embedded)
    outputs = tf.reshape(outputs_flat, list(
        inputs.shape[:-1]) + [outputs_flat.shape[-1]])
    return outputs
def render_rays(ray_batch,
                network_fn,
                network_query_fn,
                N_samples,
                retraw=False,
                lindisp=False,
                perturb=0.,
                N_importance=0,
                network_fine=None,
                white_bkgd=False,
                raw_noise_std=0.,
                verbose=False):
    """Volumetric rendering.

    Args:
      ray_batch: array of shape [batch_size, ...]. All information necessary
        for sampling along a ray, including: ray origin, ray direction, min
        dist, max dist, and unit-magnitude viewing direction.
      network_fn: function. Model for predicting RGB and density at each point
        in space.
      network_query_fn: function used for passing queries to network_fn.
      N_samples: int. Number of different times to sample along each ray.
      retraw: bool. If True, include model's raw, unprocessed predictions.
      lindisp: bool. If True, sample linearly in inverse depth rather than in depth.
      perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified
        random points in time.
      N_importance: int. Number of additional times to sample along each ray.
        These samples are only passed to network_fine.
      network_fine: "fine" network with same spec as network_fn.
      white_bkgd: bool. If True, assume a white background.
      raw_noise_std: float. If greater than zero, standard deviation of the
        normal noise added to the raw density prediction before it is
        converted to alpha.
      verbose: bool. If True, print more debugging info. Currently not read
        anywhere in this function.

    Returns:
      rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model.
      disp_map: [num_rays]. Disparity map. 1 / depth.
      acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model.
      raw: [num_rays, num_samples, 4]. Raw predictions from model.
      rgb0: See rgb_map. Output for coarse model.
      disp0: See disp_map. Output for coarse model.
      acc0: See acc_map. Output for coarse model.
      z_std: [num_rays]. Standard deviation of distances along ray for each
        sample.
    """

    def raw2outputs(raw, z_vals, rays_d):
        """Transforms model's predictions to semantically meaningful values.

        Args:
          raw: [num_rays, num_samples along ray, 4]. Prediction from model.
          z_vals: [num_rays, num_samples along ray]. Integration time.
          rays_d: [num_rays, 3]. Direction of each ray.

        Returns:
          rgb_map: [num_rays, 3]. Estimated RGB color of a ray.
          disp_map: [num_rays]. Disparity map. Inverse of depth map.
          acc_map: [num_rays]. Sum of weights along each ray.
          weights: [num_rays, num_samples]. Weights assigned to each sampled color.
          depth_map: [num_rays]. Estimated distance to object.
        """
        # Function for computing density from model prediction. This value is
        # strictly between [0, 1].
        def raw2alpha(raw, dists, act_fn=tf.nn.relu): return 1.0 - \
            tf.exp(-act_fn(raw) * dists)

        # Compute 'distance' (in time) between each integration time along a ray.
        dists = z_vals[..., 1:] - z_vals[..., :-1]

        # The 'distance' from the last integration time is infinity.
        dists = tf.concat(
            [dists, tf.broadcast_to([1e10], dists[..., :1].shape)],
            axis=-1)  # [N_rays, N_samples]

        # Multiply each distance by the norm of its corresponding direction ray
        # to convert to real world distance (accounts for non-unit directions).
        dists = dists * tf.linalg.norm(rays_d[..., None, :], axis=-1)

        # Extract RGB of each sample position along each ray.
        rgb = tf.math.sigmoid(raw[..., :3])  # [N_rays, N_samples, 3]

        # Add noise to model's predictions for density. Can be used to
        # regularize network during training (prevents floater artifacts).
        noise = 0.
        if raw_noise_std > 0.:
            noise = tf.random.normal(raw[..., 3].shape) * raw_noise_std

        # Predict density of each sample along each ray. Higher values imply
        # higher likelihood of being absorbed at this point.
        alpha = raw2alpha(raw[..., 3] + noise, dists)  # [N_rays, N_samples]

        # Compute weight for RGB of each sample along each ray.  A cumprod() is
        # used to express the idea of the ray not having reflected up to this
        # sample yet.
        # [N_rays, N_samples]
        weights = alpha * \
            tf.math.cumprod(1.-alpha + 1e-10, axis=-1, exclusive=True)

        # Computed weighted color of each sample along each ray.
        rgb_map = tf.reduce_sum(
            weights[..., None] * rgb, axis=-2)  # [N_rays, 3]

        # Estimated depth map is expected distance.
        depth_map = tf.reduce_sum(weights * z_vals, axis=-1)

        # Disparity map is inverse depth.
        disp_map = 1./tf.maximum(1e-10, depth_map /
                                 tf.reduce_sum(weights, axis=-1))

        # Sum of weights along each ray. This value is in [0, 1] up to numerical error.
        acc_map = tf.reduce_sum(weights, -1)

        # To composite onto a white background, use the accumulated alpha map.
        if white_bkgd:
            rgb_map = rgb_map + (1.-acc_map[..., None])

        return rgb_map, disp_map, acc_map, weights, depth_map

    ###############################
    # batch size
    N_rays = ray_batch.shape[0]

    # Extract ray origin, direction.
    rays_o, rays_d = ray_batch[:, 0:3], ray_batch[:, 3:6]  # [N_rays, 3] each

    # Extract unit-normalized viewing direction.
    viewdirs = ray_batch[:, -3:] if ray_batch.shape[-1] > 8 else None

    # Extract lower, upper bound for ray distance.
    bounds = tf.reshape(ray_batch[..., 6:8], [-1, 1, 2])
    near, far = bounds[..., 0], bounds[..., 1]  # [-1,1]

    # Decide where to sample along each ray. Under the logic, all rays will be sampled at
    # the same times.
    t_vals = tf.linspace(0., 1., N_samples)
    if not lindisp:
        # Space integration times linearly between 'near' and 'far'. Same
        # integration points will be used for all rays.
        z_vals = near * (1.-t_vals) + far * (t_vals)
    else:
        # Sample linearly in inverse depth (disparity).
        z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals))
    z_vals = tf.broadcast_to(z_vals, [N_rays, N_samples])

    # Perturb sampling time along each ray.
    if perturb > 0.:
        # get intervals between samples
        mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
        upper = tf.concat([mids, z_vals[..., -1:]], -1)
        lower = tf.concat([z_vals[..., :1], mids], -1)
        # stratified samples in those intervals
        t_rand = tf.random.uniform(z_vals.shape)
        z_vals = lower + (upper - lower) * t_rand

    # Points in space to evaluate model at.
    pts = rays_o[..., None, :] + rays_d[..., None, :] * \
        z_vals[..., :, None]  # [N_rays, N_samples, 3]

    # Evaluate model at each point.
    raw = network_query_fn(pts, viewdirs, network_fn)  # [N_rays, N_samples, 4]
    rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(
        raw, z_vals, rays_d)

    if N_importance > 0:
        # Keep the coarse outputs; the names below are overwritten by the
        # second (fine) pass.
        rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map

        # Obtain additional integration times to evaluate based on the weights
        # assigned to colors in the coarse model.
        z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
        z_samples = sample_pdf(
            z_vals_mid, weights[..., 1:-1], N_importance, det=(perturb == 0.))
        z_samples = tf.stop_gradient(z_samples)

        # Obtain all points to evaluate color, density at.
        z_vals = tf.sort(tf.concat([z_vals, z_samples], -1), -1)
        pts = rays_o[..., None, :] + rays_d[..., None, :] * \
            z_vals[..., :, None]  # [N_rays, N_samples + N_importance, 3]

        # Make predictions with network_fine.
        run_fn = network_fn if network_fine is None else network_fine
        raw = network_query_fn(pts, viewdirs, run_fn)
        rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(
            raw, z_vals, rays_d)

    ret = {'rgb_map': rgb_map, 'disp_map': disp_map, 'acc_map': acc_map}
    if retraw:
        ret['raw'] = raw
    if N_importance > 0:
        ret['rgb0'] = rgb_map_0
        ret['disp0'] = disp_map_0
        ret['acc0'] = acc_map_0
        ret['z_std'] = tf.math.reduce_std(z_samples, -1)  # [N_rays]

    # Run a numerics check on every returned tensor.
    for k in ret:
        tf.debugging.check_numerics(ret[k], 'output {}'.format(k))

    return ret
def batchify_rays(rays_flat, chunk=1024*32, **kwargs):
    """Render rays in smaller minibatches to avoid OOM.

    Args:
      rays_flat: ray batch, sliced along axis 0 in steps of 'chunk'.
      chunk: int. Maximum number of rays passed to render_rays() per call.
      **kwargs: forwarded unchanged to render_rays().

    Returns:
      dict with the same keys render_rays() returns, where each value is the
      concatenation along axis 0 of the per-chunk results.
    """
    all_ret = {}
    for i in range(0, rays_flat.shape[0], chunk):
        ret = render_rays(rays_flat[i:i+chunk], **kwargs)
        # Collect each output under its key, in chunk order.
        for k in ret:
            all_ret.setdefault(k, []).append(ret[k])

    all_ret = {k: tf.concat(all_ret[k], 0) for k in all_ret}
    return all_ret
def render(H, W, focal,
           chunk=1024*32, rays=None, c2w=None, ndc=True,
           near=0., far=1.,
           use_viewdirs=False, c2w_staticcam=None,
           **kwargs):
    """Render rays

    Args:
      H: int. Height of image in pixels.
      W: int. Width of image in pixels.
      focal: float. Focal length of pinhole camera.
      chunk: int. Maximum number of rays to process simultaneously. Used to
        control maximum memory usage. Does not affect final results.
      rays: array of shape [2, batch_size, 3]. Ray origin and direction for
        each example in batch.
      c2w: array of shape [3, 4]. Camera-to-world transformation matrix.
      ndc: bool. If True, represent ray origin, direction in NDC coordinates.
      near: float or array of shape [batch_size]. Nearest distance for a ray.
      far: float or array of shape [batch_size]. Farthest distance for a ray.
      use_viewdirs: bool. If True, use viewing direction of a point in space in model.
      c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for
       camera while using other c2w argument for viewing directions.
      **kwargs: forwarded to render_rays() through batchify_rays().

    Returns:
      A list [rgb_map, disp_map, acc_map, extras]:
      rgb_map: [batch_size, 3]. Predicted RGB values for rays.
      disp_map: [batch_size]. Disparity map. Inverse of depth.
      acc_map: [batch_size]. Accumulated opacity (alpha) along a ray.
      extras: dict with every other entry returned by render_rays().
    """

    if c2w is not None:
        # special case to render full image
        rays_o, rays_d = get_rays(H, W, focal, c2w)
    else:
        # use provided ray batch
        rays_o, rays_d = rays

    if use_viewdirs:
        # provide ray directions as input
        viewdirs = rays_d
        if c2w_staticcam is not None:
            # special case to visualize effect of viewdirs
            rays_o, rays_d = get_rays(H, W, focal, c2w_staticcam)

        # Make all directions unit magnitude.
        # shape: [batch_size, 3]
        viewdirs = viewdirs / tf.linalg.norm(viewdirs, axis=-1, keepdims=True)
        viewdirs = tf.cast(tf.reshape(viewdirs, [-1, 3]), dtype=tf.float32)

    # Remember the incoming shape so outputs can be reshaped to match it.
    sh = rays_d.shape  # [..., 3]
    if ndc:
        # for forward facing scenes
        rays_o, rays_d = ndc_rays(
            H, W, focal, tf.cast(1., tf.float32), rays_o, rays_d)

    # Create ray batch
    rays_o = tf.cast(tf.reshape(rays_o, [-1, 3]), dtype=tf.float32)
    rays_d = tf.cast(tf.reshape(rays_d, [-1, 3]), dtype=tf.float32)
    near, far = near * \
        tf.ones_like(rays_d[..., :1]), far * tf.ones_like(rays_d[..., :1])

    # (ray origin, ray direction, min dist, max dist) for each ray
    rays = tf.concat([rays_o, rays_d, near, far], axis=-1)
    if use_viewdirs:
        # (ray origin, ray direction, min dist, max dist, normalized viewing direction)
        rays = tf.concat([rays, viewdirs], axis=-1)

    # Render and reshape
    all_ret = batchify_rays(rays, chunk, **kwargs)
    for k in all_ret:
        k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
        all_ret[k] = tf.reshape(all_ret[k], k_sh)

    k_extract = ['rgb_map', 'disp_map', 'acc_map']
    ret_list = [all_ret[k] for k in k_extract]
    ret_dict = {k: all_ret[k] for k in all_ret if k not in k_extract}
    return ret_list + [ret_dict]
def render_path(render_poses, hwf, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0):
    """Render one image per pose and optionally save each as a PNG.

    Args:
      render_poses: iterable of camera poses. The [:3, :4] slice of each pose
        is passed to render() as 'c2w'.
      hwf: sequence (H, W, focal).
      chunk: int. Forwarded to render() as 'chunk'.
      render_kwargs: dict of extra keyword arguments for render().
      gt_imgs: optional indexable of reference images. When given and
        render_factor is 0, a PSNR against gt_imgs[i] is printed per frame.
      savedir: optional directory. When given, frame i is written to
        '<savedir>/{i:03d}.png'. The directory is created if missing.
      render_factor: int. If non-zero, H, W and focal are divided by it
        before rendering.

    Returns:
      rgbs: numpy array stacking the rendered RGB images along axis 0.
      disps: numpy array stacking the rendered disparity maps along axis 0.
    """
    H, W, focal = hwf

    if render_factor != 0:
        # Render downsampled for speed
        H = H//render_factor
        W = W//render_factor
        focal = focal/render_factor

    if savedir is not None:
        # Create the output directory up front so a missing directory does
        # not fail only after the first frame has been rendered.
        os.makedirs(savedir, exist_ok=True)

    rgbs = []
    disps = []

    t = time.time()
    for i, c2w in enumerate(render_poses):
        # Prints the time spent on the previous frame.
        print(i, time.time() - t)
        t = time.time()
        rgb, disp, acc, _ = render(
            H, W, focal, chunk=chunk, c2w=c2w[:3, :4], **render_kwargs)
        rgbs.append(rgb.numpy())
        disps.append(disp.numpy())
        if i == 0:
            print(rgb.shape, disp.shape)

        if gt_imgs is not None and render_factor == 0:
            p = -10. * np.log10(np.mean(np.square(rgb - gt_imgs[i])))
            print(p)

        if savedir is not None:
            rgb8 = to8b(rgbs[-1])
            filename = os.path.join(savedir, '{:03d}.png'.format(i))
            imageio.imwrite(filename, rgb8)

    rgbs = np.stack(rgbs, 0)
    disps = np.stack(disps, 0)

    return rgbs, disps
def create_nerf(args):
    """Instantiate NeRF's MLP model.

    Builds the coarse model (and a fine model when args.N_importance > 0),
    assembles the keyword arguments used for rendering at train and test
    time, and reloads the most recent checkpoint unless args.no_reload is set.

    Args:
      args: parsed command-line namespace as produced by config_parser().

    Returns:
      render_kwargs_train: dict of keyword arguments for render() in training.
      render_kwargs_test: copy of the above with 'perturb' set to False and
        'raw_noise_std' set to 0.
      start: int. Step to resume from; 0 when no checkpoint was loaded.
      grad_vars: list of trainable variables of all created models.
      models: dict with key 'model' and, if created, 'model_fine'.
    """

    embed_fn, input_ch = get_embedder(args.multires, args.i_embed)

    input_ch_views = 0
    embeddirs_fn = None
    if args.use_viewdirs:
        embeddirs_fn, input_ch_views = get_embedder(
            args.multires_views, args.i_embed)
    output_ch = 4
    skips = [4]
    model = init_nerf_model(
        D=args.netdepth, W=args.netwidth,
        input_ch=input_ch, output_ch=output_ch, skips=skips,
        input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs)
    grad_vars = model.trainable_variables
    models = {'model': model}

    model_fine = None
    if args.N_importance > 0:
        model_fine = init_nerf_model(
            D=args.netdepth_fine, W=args.netwidth_fine,
            input_ch=input_ch, output_ch=output_ch, skips=skips,
            input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs)
        grad_vars += model_fine.trainable_variables
        models['model_fine'] = model_fine

    def network_query_fn(inputs, viewdirs, network_fn): return run_network(
        inputs, viewdirs, network_fn,
        embed_fn=embed_fn,
        embeddirs_fn=embeddirs_fn,
        netchunk=args.netchunk)

    render_kwargs_train = {
        'network_query_fn': network_query_fn,
        'perturb': args.perturb,
        'N_importance': args.N_importance,
        'network_fine': model_fine,
        'N_samples': args.N_samples,
        'network_fn': model,
        'use_viewdirs': args.use_viewdirs,
        'white_bkgd': args.white_bkgd,
        'raw_noise_std': args.raw_noise_std,
    }

    # NDC only good for LLFF-style forward facing data
    if args.dataset_type != 'llff' or args.no_ndc:
        print('Not ndc!')
        render_kwargs_train['ndc'] = False
        render_kwargs_train['lindisp'] = args.lindisp

    render_kwargs_test = {
        k: render_kwargs_train[k] for k in render_kwargs_train}
    render_kwargs_test['perturb'] = False
    render_kwargs_test['raw_noise_std'] = 0.

    start = 0
    basedir = args.basedir
    expname = args.expname

    if args.ft_path is not None and args.ft_path != 'None':
        ckpts = [args.ft_path]
    else:
        ckpt_dir = os.path.join(basedir, expname)
        # A fresh experiment may not have a log directory yet; treat that as
        # "no checkpoints" instead of letting os.listdir() raise.
        names = sorted(os.listdir(ckpt_dir)) if os.path.isdir(ckpt_dir) else []
        # Keep only coarse-model weight files; the fine-model file name is
        # derived from the coarse one below.
        ckpts = [os.path.join(ckpt_dir, f) for f in names if
                 ('model_' in f and 'fine' not in f and 'optimizer' not in f)]
    print('Found ckpts', ckpts)
    if ckpts and not args.no_reload:
        ft_weights = ckpts[-1]
        print('Reloading from', ft_weights)
        # NOTE: allow_pickle=True unpickles the file's contents; only load
        # checkpoints that come from a trusted source.
        model.set_weights(np.load(ft_weights, allow_pickle=True))
        # The step is read from the six characters before the 4-character
        # extension, matching the '{}_{:06d}.npy' names written by training.
        start = int(ft_weights[-10:-4]) + 1
        print('Resetting step to', start)

        if model_fine is not None:
            ft_weights_fine = '{}_fine_{}'.format(
                ft_weights[:-11], ft_weights[-10:])
            print('Reloading fine from', ft_weights_fine)
            model_fine.set_weights(np.load(ft_weights_fine, allow_pickle=True))

    return render_kwargs_train, render_kwargs_test, start, grad_vars, models
def config_parser():
    """Build the command-line / config-file argument parser.

    Returns:
      A configargparse.ArgumentParser with all options registered. Arguments
      are not parsed here.
    """

    # Imported here so the dependency is only needed when a parser is built.
    import configargparse
    parser = configargparse.ArgumentParser()
    parser.add_argument('--config', is_config_file=True,
                        help='config file path')
    parser.add_argument("--expname", type=str, help='experiment name')
    parser.add_argument("--basedir", type=str, default='./logs/',
                        help='where to store ckpts and logs')
    parser.add_argument("--datadir", type=str,
                        default='./data/llff/fern', help='input data directory')

    # training options
    parser.add_argument("--netdepth", type=int, default=8,
                        help='layers in network')
    parser.add_argument("--netwidth", type=int, default=256,
                        help='channels per layer')
    parser.add_argument("--netdepth_fine", type=int,
                        default=8, help='layers in fine network')
    parser.add_argument("--netwidth_fine", type=int, default=256,
                        help='channels per layer in fine network')
    parser.add_argument("--N_rand", type=int, default=32*32*4,
                        help='batch size (number of random rays per gradient step)')
    parser.add_argument("--lrate", type=float,
                        default=5e-4, help='learning rate')
    parser.add_argument("--lrate_decay", type=int, default=250,
                        help='exponential learning rate decay (in 1000s)')
    parser.add_argument("--chunk", type=int, default=1024*32,
                        help='number of rays processed in parallel, decrease if running out of memory')
    parser.add_argument("--netchunk", type=int, default=1024*64,
                        help='number of pts sent through network in parallel, decrease if running out of memory')
    parser.add_argument("--no_batching", action='store_true',
                        help='only take random rays from 1 image at a time')
    parser.add_argument("--no_reload", action='store_true',
                        help='do not reload weights from saved ckpt')
    parser.add_argument("--ft_path", type=str, default=None,
                        help='specific weights npy file to reload for coarse network')
    parser.add_argument("--random_seed", type=int, default=None,
                        help='fix random seed for repeatability')

    # pre-crop options
    parser.add_argument("--precrop_iters", type=int, default=0,
                        help='number of steps to train on central crops')
    parser.add_argument("--precrop_frac", type=float,
                        default=.5, help='fraction of img taken for central crops')

    # rendering options
    parser.add_argument("--N_samples", type=int, default=64,
                        help='number of coarse samples per ray')
    parser.add_argument("--N_importance", type=int, default=0,
                        help='number of additional fine samples per ray')
    parser.add_argument("--perturb", type=float, default=1.,
                        help='set to 0. for no jitter, 1. for jitter')
    parser.add_argument("--use_viewdirs", action='store_true',
                        help='use full 5D input instead of 3D')
    parser.add_argument("--i_embed", type=int, default=0,
                        help='set 0 for default positional encoding, -1 for none')
    parser.add_argument("--multires", type=int, default=10,
                        help='log2 of max freq for positional encoding (3D location)')
    parser.add_argument("--multires_views", type=int, default=4,
                        help='log2 of max freq for positional encoding (2D direction)')
    parser.add_argument("--raw_noise_std", type=float, default=0.,
                        help='std dev of noise added to regularize sigma_a output, 1e0 recommended')

    parser.add_argument("--render_only", action='store_true',
                        help='do not optimize, reload weights and render out render_poses path')
    parser.add_argument("--render_test", action='store_true',
                        help='render the test set instead of render_poses path')
    parser.add_argument("--render_factor", type=int, default=0,
                        help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')

    # dataset options
    parser.add_argument("--dataset_type", type=str, default='llff',
                        help='options: llff / blender / deepvoxels')
    parser.add_argument("--testskip", type=int, default=8,
                        help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')

    # deepvoxels flags
    parser.add_argument("--shape", type=str, default='greek',
                        help='options : armchair / cube / greek / vase')

    # blender flags
    parser.add_argument("--white_bkgd", action='store_true',
                        help='set to render synthetic data on a white bkgd (always use for dvoxels)')
    parser.add_argument("--half_res", action='store_true',
                        help='load blender synthetic data at 400x400 instead of 800x800')

    # llff flags
    parser.add_argument("--factor", type=int, default=8,
                        help='downsample factor for LLFF images')
    parser.add_argument("--no_ndc", action='store_true',
                        help='do not use normalized device coordinates (set for non-forward facing scenes)')
    parser.add_argument("--lindisp", action='store_true',
                        help='sampling linearly in disparity rather than depth')
    parser.add_argument("--spherify", action='store_true',
                        help='set for spherical 360 scenes')
    parser.add_argument("--llffhold", type=int, default=8,
                        help='will take every 1/N images as LLFF test set, paper uses 8')

    # logging/saving options
    parser.add_argument("--i_print",   type=int, default=100,
                        help='frequency of console printout and metric loggin')
    parser.add_argument("--i_img",     type=int, default=500,
                        help='frequency of tensorboard image logging')
    parser.add_argument("--i_weights", type=int, default=10000,
                        help='frequency of weight ckpt saving')
    parser.add_argument("--i_testset", type=int, default=50000,
                        help='frequency of testset saving')
    parser.add_argument("--i_video",   type=int, default=50000,
                        help='frequency of render_poses video saving')

    return parser
def train():
    """Entry point: parse arguments, load data, then render or train.

    Steps, in order: parse the command line, optionally fix the numpy and
    TensorFlow seeds, load the selected dataset and its near/far bounds,
    write 'args.txt' (and a copy of the config file) into the experiment
    directory, build the models via create_nerf(), and then either render
    and return (args.render_only) or run the optimization loop with periodic
    checkpointing, video/test-set rendering and summary logging.
    """

    parser = config_parser()
    args = parser.parse_args()

    if args.random_seed is not None:
        print('Fixing random seed', args.random_seed)
        np.random.seed(args.random_seed)
        tf.compat.v1.set_random_seed(args.random_seed)

    # Load data

    if args.dataset_type == 'llff':
        images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor,
                                                                  recenter=True, bd_factor=.75,
                                                                  spherify=args.spherify)
        hwf = poses[0, :3, -1]
        poses = poses[:, :3, :4]
        print('Loaded llff', images.shape,
              render_poses.shape, hwf, args.datadir)
        if not isinstance(i_test, list):
            i_test = [i_test]

        if args.llffhold > 0:
            print('Auto LLFF holdout,', args.llffhold)
            i_test = np.arange(images.shape[0])[::args.llffhold]

        i_val = i_test
        i_train = np.array([i for i in np.arange(int(images.shape[0])) if
                            (i not in i_test and i not in i_val)])

        print('DEFINING BOUNDS')
        if args.no_ndc:
            near = tf.reduce_min(bds) * .9
            far = tf.reduce_max(bds) * 1.
        else:
            near = 0.
            far = 1.
        print('NEAR FAR', near, far)

    elif args.dataset_type == 'blender':
        images, poses, render_poses, hwf, i_split = load_blender_data(
            args.datadir, args.half_res, args.testskip)
        print('Loaded blender', images.shape,
              render_poses.shape, hwf, args.datadir)
        i_train, i_val, i_test = i_split

        near = 2.
        far = 6.

        if args.white_bkgd:
            # Composite RGB over white using the last channel as alpha.
            images = images[..., :3]*images[..., -1:] + (1.-images[..., -1:])
        else:
            images = images[..., :3]

    elif args.dataset_type == 'deepvoxels':

        images, poses, render_poses, hwf, i_split = load_dv_data(scene=args.shape,
                                                                 basedir=args.datadir,
                                                                 testskip=args.testskip)

        print('Loaded deepvoxels', images.shape,
              render_poses.shape, hwf, args.datadir)
        i_train, i_val, i_test = i_split

        # Mean distance of the camera positions from the origin.
        hemi_R = np.mean(np.linalg.norm(poses[:, :3, -1], axis=-1))
        near = hemi_R-1.
        far = hemi_R+1.

    else:
        print('Unknown dataset type', args.dataset_type, 'exiting')
        return

    # Cast intrinsics to right types
    H, W, focal = hwf
    H, W = int(H), int(W)
    hwf = [H, W, focal]

    if args.render_test:
        render_poses = np.array(poses[i_test])

    # Create log dir and copy the config file
    basedir = args.basedir
    expname = args.expname
    os.makedirs(os.path.join(basedir, expname), exist_ok=True)
    f = os.path.join(basedir, expname, 'args.txt')
    with open(f, 'w') as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))
    if args.config is not None:
        f = os.path.join(basedir, expname, 'config.txt')
        # Both handles are context-managed so the source file is closed too.
        with open(args.config, 'r') as src, open(f, 'w') as file:
            file.write(src.read())

    # Create nerf model
    render_kwargs_train, render_kwargs_test, start, grad_vars, models = create_nerf(
        args)

    bds_dict = {
        'near': tf.cast(near, tf.float32),
        'far': tf.cast(far, tf.float32),
    }
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    # Short circuit if only rendering out from trained model
    if args.render_only:
        print('RENDER ONLY')
        if args.render_test:
            # render_test switches to test poses
            images = images[i_test]
        else:
            # Default is smoother render_poses path
            images = None

        testsavedir = os.path.join(basedir, expname, 'renderonly_{}_{:06d}'.format(
            'test' if args.render_test else 'path', start))
        os.makedirs(testsavedir, exist_ok=True)
        print('test poses shape', render_poses.shape)

        rgbs, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test,
                              gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor)
        print('Done rendering', testsavedir)
        imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'),
                         to8b(rgbs), fps=30, quality=8)

        return

    # Create optimizer
    lrate = args.lrate
    if args.lrate_decay > 0:
        lrate = tf.keras.optimizers.schedules.ExponentialDecay(lrate,
                                                               decay_steps=args.lrate_decay * 1000, decay_rate=0.1)
    optimizer = tf.keras.optimizers.Adam(lrate)
    # Stored alongside the networks so its weights are checkpointed with them.
    models['optimizer'] = optimizer

    global_step = tf.compat.v1.train.get_or_create_global_step()
    global_step.assign(start)

    # Prepare raybatch tensor if batching random rays
    N_rand = args.N_rand
    use_batching = not args.no_batching
    if use_batching:
        # For random ray batching.
        #
        # Constructs an array 'rays_rgb' of shape [N*H*W, 3, 3] where axis=1 is
        # interpreted as,
        #   axis=0: ray origin in world space
        #   axis=1: ray direction in world space
        #   axis=2: observed RGB color of pixel
        print('get rays')
        # get_rays_np() returns rays_origin=[H, W, 3], rays_direction=[H, W, 3]
        # for each pixel in the image. This stack() adds a new dimension.
        rays = [get_rays_np(H, W, focal, p) for p in poses[:, :3, :4]]
        rays = np.stack(rays, axis=0)  # [N, ro+rd, H, W, 3]
        print('done, concats')
        # [N, ro+rd+rgb, H, W, 3]
        rays_rgb = np.concatenate([rays, images[:, None, ...]], 1)
        # [N, H, W, ro+rd+rgb, 3]
        rays_rgb = np.transpose(rays_rgb, [0, 2, 3, 1, 4])
        rays_rgb = np.stack([rays_rgb[i]
                             for i in i_train], axis=0)  # train images only
        # [(N-1)*H*W, ro+rd+rgb, 3]
        rays_rgb = np.reshape(rays_rgb, [-1, 3, 3])
        rays_rgb = rays_rgb.astype(np.float32)
        print('shuffle rays')
        np.random.shuffle(rays_rgb)
        print('done')
        i_batch = 0

    N_iters = 1000000
    print('Begin')
    print('TRAIN views are', i_train)
    print('TEST views are', i_test)
    print('VAL views are', i_val)

    # Summary writers
    writer = tf.contrib.summary.create_file_writer(
        os.path.join(basedir, 'summaries', expname))
    writer.set_as_default()

    def save_weights(net, prefix, step):
        """Save net.get_weights() to '<basedir>/<expname>/<prefix>_<step>.npy'."""
        path = os.path.join(
            basedir, expname, '{}_{:06d}.npy'.format(prefix, step))
        np.save(path, net.get_weights())
        print('saved weights at', path)

    for i in range(start, N_iters):
        time0 = time.time()

        # Sample random ray batch

        if use_batching:
            # Random over all images
            batch = rays_rgb[i_batch:i_batch+N_rand]  # [B, 2+1, 3*?]
            batch = tf.transpose(batch, [1, 0, 2])

            # batch_rays[i, n, xyz] = ray origin or direction, example_id, 3D position
            # target_s[n, rgb] = example_id, observed color.
            batch_rays, target_s = batch[:2], batch[2]

            i_batch += N_rand
            if i_batch >= rays_rgb.shape[0]:
                # All rays consumed: reshuffle and start over.
                np.random.shuffle(rays_rgb)
                i_batch = 0

        else:
            # Random from one image
            img_i = np.random.choice(i_train)
            target = images[img_i]
            pose = poses[img_i, :3, :4]

            if N_rand is not None:
                rays_o, rays_d = get_rays(H, W, focal, pose)
                if i < args.precrop_iters:
                    # Restrict candidate pixels to a window around the
                    # image center for the first precrop_iters steps.
                    dH = int(H//2 * args.precrop_frac)
                    dW = int(W//2 * args.precrop_frac)
                    coords = tf.stack(tf.meshgrid(
                        tf.range(H//2 - dH, H//2 + dH),
                        tf.range(W//2 - dW, W//2 + dW),
                        indexing='ij'), -1)
                    if i < 10:
                        print('precrop', dH, dW, coords[0,0], coords[-1,-1])
                else:
                    coords = tf.stack(tf.meshgrid(
                        tf.range(H), tf.range(W), indexing='ij'), -1)
                coords = tf.reshape(coords, [-1, 2])
                select_inds = np.random.choice(
                    coords.shape[0], size=[N_rand], replace=False)
                select_inds = tf.gather_nd(coords, select_inds[:, tf.newaxis])
                rays_o = tf.gather_nd(rays_o, select_inds)
                rays_d = tf.gather_nd(rays_d, select_inds)
                batch_rays = tf.stack([rays_o, rays_d], 0)
                target_s = tf.gather_nd(target, select_inds)

        #####  Core optimization loop  #####

        with tf.GradientTape() as tape:

            # Make predictions for color, disparity, accumulated opacity.
            rgb, disp, acc, extras = render(
                H, W, focal, chunk=args.chunk, rays=batch_rays,
                verbose=i < 10, retraw=True, **render_kwargs_train)

            # Compute MSE loss between predicted and true RGB.
            img_loss = img2mse(rgb, target_s)
            trans = extras['raw'][..., -1]
            loss = img_loss
            psnr = mse2psnr(img_loss)

            # Add MSE loss for coarse-grained model
            if 'rgb0' in extras:
                img_loss0 = img2mse(extras['rgb0'], target_s)
                loss += img_loss0
                psnr0 = mse2psnr(img_loss0)

        gradients = tape.gradient(loss, grad_vars)
        optimizer.apply_gradients(zip(gradients, grad_vars))

        dt = time.time()-time0

        #####           end            #####

        # Rest is logging

        if i % args.i_weights == 0:
            for k in models:
                save_weights(models[k], k, i)

        if i % args.i_video == 0 and i > 0:

            rgbs, disps = render_path(
                render_poses, hwf, args.chunk, render_kwargs_test)
            print('Done, saving', rgbs.shape, disps.shape)
            moviebase = os.path.join(
                basedir, expname, '{}_spiral_{:06d}_'.format(expname, i))
            imageio.mimwrite(moviebase + 'rgb.mp4',
                             to8b(rgbs), fps=30, quality=8)
            imageio.mimwrite(moviebase + 'disp.mp4',
                             to8b(disps / np.max(disps)), fps=30, quality=8)

            if args.use_viewdirs:
                # Second video: camera rays fixed to the first pose, with
                # view directions still taken from each render pose.
                render_kwargs_test['c2w_staticcam'] = render_poses[0][:3, :4]
                rgbs_still, _ = render_path(
                    render_poses, hwf, args.chunk, render_kwargs_test)
                render_kwargs_test['c2w_staticcam'] = None
                imageio.mimwrite(moviebase + 'rgb_still.mp4',
                                 to8b(rgbs_still), fps=30, quality=8)

        if i % args.i_testset == 0 and i > 0:
            testsavedir = os.path.join(
                basedir, expname, 'testset_{:06d}'.format(i))
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', poses[i_test].shape)
            render_path(poses[i_test], hwf, args.chunk, render_kwargs_test,
                        gt_imgs=images[i_test], savedir=testsavedir)
            print('Saved test set')

        if i % args.i_print == 0 or i < 10:

            print(expname, i, psnr.numpy(), loss.numpy(), global_step.numpy())
            print('iter time {:.05f}'.format(dt))
            with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_print):
                tf.contrib.summary.scalar('loss', loss)
                tf.contrib.summary.scalar('psnr', psnr)
                tf.contrib.summary.histogram('tran', trans)
                if args.N_importance > 0:
                    tf.contrib.summary.scalar('psnr0', psnr0)

            if i % args.i_img == 0:

                # Log a rendered validation view to Tensorboard
                img_i = np.random.choice(i_val)
                target = images[img_i]
                pose = poses[img_i, :3, :4]

                rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk, c2w=pose,
                                                **render_kwargs_test)

                psnr = mse2psnr(img2mse(rgb, target))

                # Save out the validation image for Tensorboard-free monitoring
                testimgdir = os.path.join(basedir, expname, 'tboard_val_imgs')
                # Always ensure the directory exists: a run resumed from a
                # checkpoint never passes through step 0.
                os.makedirs(testimgdir, exist_ok=True)
                imageio.imwrite(os.path.join(testimgdir, '{:06d}.png'.format(i)), to8b(rgb))

                with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img):

                    tf.contrib.summary.image('rgb', to8b(rgb)[tf.newaxis])
                    tf.contrib.summary.image(
                        'disp', disp[tf.newaxis, ..., tf.newaxis])
                    tf.contrib.summary.image(
                        'acc', acc[tf.newaxis, ..., tf.newaxis])

                    tf.contrib.summary.scalar('psnr_holdout', psnr)
                    tf.contrib.summary.image('rgb_holdout', target[tf.newaxis])

                if args.N_importance > 0:

                    with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img):

                        tf.contrib.summary.image(
                            'rgb0', to8b(extras['rgb0'])[tf.newaxis])
                        tf.contrib.summary.image(
                            'disp0', extras['disp0'][tf.newaxis, ..., tf.newaxis])
                        tf.contrib.summary.image(
                            'z_std', extras['z_std'][tf.newaxis, ..., tf.newaxis])

        global_step.assign_add(1)
# Run training only when executed as a script, not when imported.
if __name__ == '__main__':
    train()