Skip to content

Commit

Permalink
embed_dec RADAE decoder works with Embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
drowe67 committed Sep 19, 2024
1 parent e03cc5b commit e736139
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 2 deletions.
32 changes: 30 additions & 2 deletions embed/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,40 @@ Need pythonx.y-dev so C program can fine `Python.h`, adjust for your Python vers

`sudo apt install python3.10-dev`

# Build and Run demo
# Test1 - move numpy arrays C<->Python, basic numpy and PyTorch

Adapted from [2] above, basic test of numpy, torch
Adapted from [2] above, basic test of numpy, torch, and moving numpy vectors between C and Python.

Building on Machine 1 (Ubuntu 20):
```
gcc embed1.c -o embed1 $(python3-config --cflags) $(python3-config --ldflags --embed) -fPIE
```
Building on Machine 2 (Ubuntu 22):
```
gcc embed1.c -o embed1 $(python3.10-config --cflags) $(python3.10-config --ldflags --embed)
```
Different build cmd lines suggests we need to focus on one distro/Python version, or have some Cmake magic to work out the gcc options.

To run:
```
PYTHONPATH="." ./embed1 mult multiply 2 2
```

# Test 2 - run RADAE in Python, C top level

This is a more serious test of running the RADAE decoder in a Python function, kicked off by a top level C program. Requires `features_in.f32` as input (to create see many examples in Ctests, inference.sh etc).
Ubuntu 22 Build & Run:
```
gcc embed_dec.c -o embed_dec $(python3.10-config --cflags) $(python3.10-config --ldflags --embed)
PYTHONPATH="." ./embed_dec embed_dec my_decode
<snip>
Rs: 50.00 Rs': 50.00 Ts': 0.020 Nsmf: 120 Ns: 6 Nc: 20 M: 160 Ncp: 0
Processing: 972 feature vectors
loss: 0.145
```
Compare with vanilla run just from Python:
```
python3 embed_dec.py
<snip>
loss: 0.145
```
57 changes: 57 additions & 0 deletions embed/embed_dec.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// C top level loader for emebd_dec.py

#define PY_SSIZE_T_CLEAN
#include <Python.h>
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include "numpy/arrayobject.h"

#define NARGS 4

int main(int argc, char *argv[])
{
PyObject *pName, *pModule, *pFunc;

if (argc < 3) {
fprintf(stderr,"Usage: %s pythonfile funcname\n", argv[0]);
return 1;
}

Py_Initialize();
// need import_array for numpy
int ret = _import_array();
fprintf(stderr, "import_array returned: %d\n", ret);

// name of Python script
pName = PyUnicode_DecodeFSDefault(argv[1]);
/* Error checking of pName left out */
pModule = PyImport_Import(pName);

Py_DECREF(pName);

if (pModule != NULL) {
pFunc = PyObject_GetAttrString(pModule, argv[2]);
/* pFunc is a new reference */

if (pFunc && PyCallable_Check(pFunc)) {

// do the function call
PyObject_CallObject(pFunc, NULL);
}
else {
if (PyErr_Occurred())
PyErr_Print();
fprintf(stderr, "Cannot find function \"%s\"\n", argv[2]);
}
Py_XDECREF(pFunc);
Py_DECREF(pModule);
}
else {
PyErr_Print();
fprintf(stderr, "Failed to load \"%s\"\n", argv[1]);
return 1;
}
if (Py_FinalizeEx() < 0) {
return 120;
}
return 0;
}
77 changes: 77 additions & 0 deletions embed/embed_dec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""
Basic decoder to test operation of RADAE using Python Embedding.
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

import os, sys
import numpy as np
import torch
sys.path.append("../")
from radae import RADAE, distortion_loss

# Hard code all this for now to avodiu arg poassing complexities
model_name = "../model05/checkpoints/checkpoint_epoch_100.pth"
features_in_fn = "features_in.f32"
features_out_fn = "features_out.f32"
latent_dim = 80
auxdata = False

os.environ['CUDA_VISIBLE_DEVICES'] = ""
device = torch.device("cpu")
nb_total_features = 36
num_features = 20
num_used_features = 20
if auxdata:
num_features += 1

# load model from a checkpoint file
model = RADAE(num_features, latent_dim, 100.0,)
checkpoint = torch.load(model_name, map_location='cpu')
model.load_state_dict(checkpoint['state_dict'], strict=False)

def my_decode():
# dataloader
features_in = np.reshape(np.fromfile(features_in_fn, dtype=np.float32), (1, -1, nb_total_features))
nb_features_rounded = model.num_10ms_times_steps_rounded_to_modem_frames(features_in.shape[1])
features = torch.tensor(features_in[:,:nb_features_rounded,:num_used_features])
print(f"Processing: {nb_features_rounded} feature vectors")

model.to(device)
features = features.to(device)
z = model.core_encoder(features)
features_hat = model.core_decoder(z)

loss = distortion_loss(features,features_hat).cpu().detach().numpy()[0]
print(f"loss: {loss:5.3f}")

features_hat = torch.cat([features_hat, torch.zeros_like(features_hat)[:,:,:16]], dim=-1)
features_hat = features_hat.cpu().detach().numpy().flatten().astype('float32')
features_hat.tofile(features_out_fn)

if __name__ == '__main__':
my_decode()

0 comments on commit e736139

Please sign in to comment.