Skip to content

Commit e736139

Browse files
committed
embed_dec RADAE decoder works with Embedding
1 parent e03cc5b commit e736139

File tree

3 files changed

+164
-2
lines changed

3 files changed

+164
-2
lines changed

embed/README.md

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,40 @@ Need pythonx.y-dev so C program can fine `Python.h`, adjust for your Python vers
1414

1515
`sudo apt install python3.10-dev`
1616

17-
# Build and Run demo
17+
# Test1 - move numpy arrays C<->Python, basic numpy and PyTorch
1818

19-
Adapted from [2] above, basic test of numpy, torch
19+
Adapted from [2] above, basic test of numpy, torch, and moving numpy vectors between C and Python.
2020

21+
Building on Machine 1 (Ubuntu 20):
2122
```
2223
gcc embed1.c -o embed1 $(python3-config --cflags) $(python3-config --ldflags --embed) -fPIE
24+
```
25+
Building on Machine 2 (Ubuntu 22):
26+
```
27+
gcc embed1.c -o embed1 $(python3.10-config --cflags) $(python3.10-config --ldflags --embed)
28+
```
29+
Different build cmd lines suggests we need to focus on one distro/Python version, or have some Cmake magic to work out the gcc options.
30+
31+
To run:
32+
```
2333
PYTHONPATH="." ./embed1 mult multiply 2 2
2434
```
2535

36+
# Test 2 - run RADAE in Python, C top level
37+
38+
This is a more serious test of running the RADAE decoder in a Python function, kicked off by a top level C program. Requires `features_in.f32` as input (to create see many examples in Ctests, inference.sh etc).
39+
Ubuntu 22 Build & Run:
40+
```
41+
gcc embed_dec.c -o embed_dec $(python3.10-config --cflags) $(python3.10-config --ldflags --embed)
42+
PYTHONPATH="." ./embed_dec embed_dec my_decode
43+
<snip>
44+
Rs: 50.00 Rs': 50.00 Ts': 0.020 Nsmf: 120 Ns: 6 Nc: 20 M: 160 Ncp: 0
45+
Processing: 972 feature vectors
46+
loss: 0.145
47+
```
48+
Compare with vanilla run just from Python:
49+
```
50+
python3 embed_dec.py
51+
<snip>
52+
loss: 0.145
53+
```

embed/embed_dec.c

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
// C top level loader for emebd_dec.py
2+
3+
#define PY_SSIZE_T_CLEAN
4+
#include <Python.h>
5+
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
6+
#include "numpy/arrayobject.h"
7+
8+
#define NARGS 4
9+
10+
int main(int argc, char *argv[])
11+
{
12+
PyObject *pName, *pModule, *pFunc;
13+
14+
if (argc < 3) {
15+
fprintf(stderr,"Usage: %s pythonfile funcname\n", argv[0]);
16+
return 1;
17+
}
18+
19+
Py_Initialize();
20+
// need import_array for numpy
21+
int ret = _import_array();
22+
fprintf(stderr, "import_array returned: %d\n", ret);
23+
24+
// name of Python script
25+
pName = PyUnicode_DecodeFSDefault(argv[1]);
26+
/* Error checking of pName left out */
27+
pModule = PyImport_Import(pName);
28+
29+
Py_DECREF(pName);
30+
31+
if (pModule != NULL) {
32+
pFunc = PyObject_GetAttrString(pModule, argv[2]);
33+
/* pFunc is a new reference */
34+
35+
if (pFunc && PyCallable_Check(pFunc)) {
36+
37+
// do the function call
38+
PyObject_CallObject(pFunc, NULL);
39+
}
40+
else {
41+
if (PyErr_Occurred())
42+
PyErr_Print();
43+
fprintf(stderr, "Cannot find function \"%s\"\n", argv[2]);
44+
}
45+
Py_XDECREF(pFunc);
46+
Py_DECREF(pModule);
47+
}
48+
else {
49+
PyErr_Print();
50+
fprintf(stderr, "Failed to load \"%s\"\n", argv[1]);
51+
return 1;
52+
}
53+
if (Py_FinalizeEx() < 0) {
54+
return 120;
55+
}
56+
return 0;
57+
}

embed/embed_dec.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
"""
2+
3+
Basic decoder to test operation of RADAE using Python Embedding.
4+
5+
/*
6+
Redistribution and use in source and binary forms, with or without
7+
modification, are permitted provided that the following conditions
8+
are met:
9+
10+
- Redistributions of source code must retain the above copyright
11+
notice, this list of conditions and the following disclaimer.
12+
13+
- Redistributions in binary form must reproduce the above copyright
14+
notice, this list of conditions and the following disclaimer in the
15+
documentation and/or other materials provided with the distribution.
16+
17+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
18+
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
19+
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
20+
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
21+
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
22+
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
23+
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
24+
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
25+
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
26+
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
27+
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28+
*/
29+
"""
30+
31+
import os, sys
32+
import numpy as np
33+
import torch
34+
sys.path.append("../")
35+
from radae import RADAE, distortion_loss
36+
37+
# Hard code all this for now to avodiu arg poassing complexities
38+
model_name = "../model05/checkpoints/checkpoint_epoch_100.pth"
39+
features_in_fn = "features_in.f32"
40+
features_out_fn = "features_out.f32"
41+
latent_dim = 80
42+
auxdata = False
43+
44+
os.environ['CUDA_VISIBLE_DEVICES'] = ""
45+
device = torch.device("cpu")
46+
nb_total_features = 36
47+
num_features = 20
48+
num_used_features = 20
49+
if auxdata:
50+
num_features += 1
51+
52+
# load model from a checkpoint file
53+
model = RADAE(num_features, latent_dim, 100.0,)
54+
checkpoint = torch.load(model_name, map_location='cpu')
55+
model.load_state_dict(checkpoint['state_dict'], strict=False)
56+
57+
def my_decode():
58+
# dataloader
59+
features_in = np.reshape(np.fromfile(features_in_fn, dtype=np.float32), (1, -1, nb_total_features))
60+
nb_features_rounded = model.num_10ms_times_steps_rounded_to_modem_frames(features_in.shape[1])
61+
features = torch.tensor(features_in[:,:nb_features_rounded,:num_used_features])
62+
print(f"Processing: {nb_features_rounded} feature vectors")
63+
64+
model.to(device)
65+
features = features.to(device)
66+
z = model.core_encoder(features)
67+
features_hat = model.core_decoder(z)
68+
69+
loss = distortion_loss(features,features_hat).cpu().detach().numpy()[0]
70+
print(f"loss: {loss:5.3f}")
71+
72+
features_hat = torch.cat([features_hat, torch.zeros_like(features_hat)[:,:,:16]], dim=-1)
73+
features_hat = features_hat.cpu().detach().numpy().flatten().astype('float32')
74+
features_hat.tofile(features_out_fn)
75+
76+
if __name__ == '__main__':
77+
my_decode()

0 commit comments

Comments
 (0)