|
| 1 | +import torch |
| 2 | +import numpy as np |
| 3 | +import argparse |
| 4 | + |
| 5 | +if __name__ == '__main__': |
| 6 | + |
| 7 | + parser = argparse.ArgumentParser() |
| 8 | + parser.add_argument('--src', type=str) |
| 9 | + parser.add_argument('--dst', type=str, default='./model.npy') |
| 10 | + args = parser.parse_args() |
| 11 | + |
| 12 | + state_dict = torch.load(args.src, map_location='cpu')['state_dict'] |
| 13 | + |
| 14 | + # padding = torch.zeros(13, 16) |
| 15 | + # rgb_out = state_dict['model.rgb_net.output_layer.weight'] |
| 16 | + # print(rgb_out.shape) |
| 17 | + # rgb_out = torch.cat([rgb_out, padding], dim=0) |
| 18 | + |
| 19 | + |
| 20 | + model_keys = { |
| 21 | + 'per_level_scale', 'n_neurons', |
| 22 | + 'sigma_n_input', 'sigma_n_output', |
| 23 | + 'rgb_depth', 'rgb_n_input', 'rgb_n_output', |
| 24 | + 'cascade', 'box_scale', |
| 25 | + } |
| 26 | + |
| 27 | + new_dict = { |
| 28 | + # 'camera_angle_x': meta['camera_angle_x'], |
| 29 | + 'K': state_dict['K'].numpy(), |
| 30 | + 'poses': state_dict['poses'].numpy(), |
| 31 | + 'directions': state_dict['directions'].numpy(), |
| 32 | + 'model.density_bitfield': state_dict['model.density_bitfield'].numpy(), |
| 33 | + 'model.hash_encoder.params': state_dict['model.hash_encoder.params'].numpy(), |
| 34 | + # 'model.xyz_encoder.params': |
| 35 | + # torch.cat( |
| 36 | + # [state_dict['model.xyz_encoder.hidden_layers.0.weight'].reshape(-1), |
| 37 | + # state_dict['model.xyz_encoder.output_layer.weight'].reshape(-1)] |
| 38 | + # ).numpy(), |
| 39 | + # 'model.rgb_net.params': |
| 40 | + # torch.cat( |
| 41 | + # [state_dict['model.rgb_net.hidden_layers.0.weight'].reshape(-1), |
| 42 | + # rgb_out.reshape(-1)] |
| 43 | + # ).numpy(), |
| 44 | + 'model.xyz_encoder.params': state_dict['model.xyz_encoder.params'].numpy(), |
| 45 | + # 'model.xyz_sigmas.params': state_dict['model.xyz_sigmas.params'].numpy(), |
| 46 | + 'model.rgb_net.params': state_dict['model.rgb_net.params'].numpy(), |
| 47 | + } |
| 48 | + for key in model_keys: |
| 49 | + new_dict[f'model.{key}'] = state_dict[f'model.{key}'].item() |
| 50 | + np.save(args.dst, new_dict) |
0 commit comments