1
+ """
2
+
3
+ Basic decoder to test operation of RADAE using Python Embedding.
4
+
5
+ /*
6
+ Redistribution and use in source and binary forms, with or without
7
+ modification, are permitted provided that the following conditions
8
+ are met:
9
+
10
+ - Redistributions of source code must retain the above copyright
11
+ notice, this list of conditions and the following disclaimer.
12
+
13
+ - Redistributions in binary form must reproduce the above copyright
14
+ notice, this list of conditions and the following disclaimer in the
15
+ documentation and/or other materials provided with the distribution.
16
+
17
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
18
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
19
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
20
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
21
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
22
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
23
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
24
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
25
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
26
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
27
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28
+ */
29
+ """
30
+
31
+ import os , sys
32
+ import numpy as np
33
+ import torch
34
+ sys .path .append ("../" )
35
+ from radae import RADAE , distortion_loss
36
+
37
+ # Hard code all this for now to avodiu arg poassing complexities
38
+ model_name = "../model05/checkpoints/checkpoint_epoch_100.pth"
39
+ features_in_fn = "features_in.f32"
40
+ features_out_fn = "features_out.f32"
41
+ latent_dim = 80
42
+ auxdata = False
43
+
44
+ os .environ ['CUDA_VISIBLE_DEVICES' ] = ""
45
+ device = torch .device ("cpu" )
46
+ nb_total_features = 36
47
+ num_features = 20
48
+ num_used_features = 20
49
+ if auxdata :
50
+ num_features += 1
51
+
52
+ # load model from a checkpoint file
53
+ model = RADAE (num_features , latent_dim , 100.0 ,)
54
+ checkpoint = torch .load (model_name , map_location = 'cpu' )
55
+ model .load_state_dict (checkpoint ['state_dict' ], strict = False )
56
+
57
+ def my_decode ():
58
+ # dataloader
59
+ features_in = np .reshape (np .fromfile (features_in_fn , dtype = np .float32 ), (1 , - 1 , nb_total_features ))
60
+ nb_features_rounded = model .num_10ms_times_steps_rounded_to_modem_frames (features_in .shape [1 ])
61
+ features = torch .tensor (features_in [:,:nb_features_rounded ,:num_used_features ])
62
+ print (f"Processing: { nb_features_rounded } feature vectors" )
63
+
64
+ model .to (device )
65
+ features = features .to (device )
66
+ z = model .core_encoder (features )
67
+ features_hat = model .core_decoder (z )
68
+
69
+ loss = distortion_loss (features ,features_hat ).cpu ().detach ().numpy ()[0 ]
70
+ print (f"loss: { loss :5.3f} " )
71
+
72
+ features_hat = torch .cat ([features_hat , torch .zeros_like (features_hat )[:,:,:16 ]], dim = - 1 )
73
+ features_hat = features_hat .cpu ().detach ().numpy ().flatten ().astype ('float32' )
74
+ features_hat .tofile (features_out_fn )
75
+
76
+ if __name__ == '__main__' :
77
+ my_decode ()
0 commit comments