
Commit

Merge branch 'dr-bbfm' into dr-cport
drowe67 committed Nov 7, 2024
2 parents 64d4618 + 3cd792b commit 8bafe95
Showing 9 changed files with 255 additions and 4 deletions.
43 changes: 42 additions & 1 deletion BBFM.md
@@ -27,6 +27,24 @@ A version of the Radio Autoencoder (RADE) designed for the baseband FM channel p
octave:4> radae_plots; do_plots_bbfm('z_hat.f32')
```

# Fading channel simulation

The HF channel simulator (two-path Rayleigh) is quite close to the TIA-102.CAAA-E 1.6.33 Faded Channel Simulator. The measured level crossing rate (LCR) appears to meet requirement (f) for v=60 km/hr, f=450 MHz, and P=1 when measured over a 10 second sample. We've used Rs=2000 symbols/s here, so the x-axis of the plot spans 1 second.

![LMR 60](doc/lmr_60.png)

```
octave:39> multipath_samples("lmr60",8000, 2000, 1, 10, "h_lmr60.f32")
Generating Doppler spreading samples...
fd = 25.000
path_delay_s = 2.0000e-04
Nsecplot = 1
Pav = 1.0366
P = 1
LCR_theory = 23.457
LCR_meas = 24.400
```
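
As a cross-check, the theoretical LCR of a Rayleigh envelope is sqrt(2*pi)*fd*rho*exp(-rho^2), where rho is the crossing threshold normalised to the RMS envelope level. A small Python sketch (not part of the repo) using only the numbers printed above:

```
# Cross-check of LCR_theory for a Rayleigh fading envelope (sketch only)
import numpy as np
fd = 25.0              # Doppler spread for v = 60 km/hr at f = 450 MHz
P, Pav = 1.0, 1.0366   # threshold power and measured average power from the run above
rho = np.sqrt(P / Pav) # threshold normalised to RMS envelope level
LCR = np.sqrt(2 * np.pi) * fd * rho * np.exp(-rho ** 2)
print(f"LCR_theory = {LCR:.3f}")   # approx 23.5 crossings/second, matching the Octave output
```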

# Single Carrier PSK Modem

A single carrier PSK modem "back end" that connects the ML symbols to the radio. This particular modem is written in Python and can work with DC-coupled and passband BBFM radios. It uses classical DSP rather than ML. Unlike the HF RADE waveform, which used OFDM, this modem is single carrier.
@@ -37,6 +55,29 @@ A single carrier PSK modem "back end" that connects the ML symbols to the radio.
```
1. Run a suite of tests:
```
python3 -c "from radae import single_carrier; s=single_carrier(); s.run_tests()"
ctest -V -R bbfm_sc
```
1. Create a file of BBFM symbols (80 symbols every 40 ms); this also plays the expected output speech:
```
./bbfm_inference.sh model_bbfm_01/checkpoints/checkpoint_epoch_100.pth wav/brian_g8sez.wav - --write_latent z.f32
```
2. Sanity check the modem with a BER test using digital BPSK symbols; the symbols in z.f32 are replaced with BPSK symbols. `t.int16` is a real-valued Fs=9600 Hz sample file that could be played into an FM radio.
```
cat z.f32 | python3 sc_tx.py --ber_test > t.int16
cat t.int16 | python3 sc_rx.py --ber_test --plots > /dev/null
```
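
The `--ber_test` machinery lives inside `sc_tx.py`/`sc_rx.py`; as a standalone illustration of the measurement itself (a toy BPSK-over-AWGN sketch, not the modem code):

```
# Toy BER measurement with BPSK symbols (sketch only, not the sc_tx/sc_rx code)
import numpy as np
rng = np.random.default_rng(0)
tx_bits = rng.integers(0, 2, 10000)
tx_sym = 1.0 - 2.0 * tx_bits                              # BPSK mapping: bit 0 -> +1, bit 1 -> -1
rx_sym = tx_sym + 0.5 * rng.standard_normal(tx_sym.size)  # stand-in AWGN channel
rx_bits = (rx_sym < 0).astype(int)                        # hard decision
print(f"BER: {np.mean(rx_bits != tx_bits):.4f}")
```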
3. Send the BBFM symbols over the modem, and listen to the results:
```
cat z.f32 | python3 sc_tx.py > t.int16
cat t.int16 | python3 sc_rx.py > z_hat.f32
./bbfm_rx.sh model_bbfm_01/checkpoints/checkpoint_epoch_100.pth z_hat.f32 -
```
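
As a back-of-envelope check of the framing, using only figures quoted above (80 symbols every 40 ms, real Fs=9600 Hz samples):

```
# Framing arithmetic for the single carrier modem (sketch; numbers from the text above)
symbols_per_frame = 80
frame_s = 0.040
Rs = symbols_per_frame / frame_s   # 2000 symbols/s, the same rate used in the fading sim above
Fs = 9600
print(f"Rs = {Rs:.0f} symb/s, {Fs/Rs:.1f} samples/symbol at Fs = {Fs} Hz")
```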
4. Compare the MSE of the features passed through the system, first with z == z_hat, then with z passed through the modem to get z_hat:
```
python3 loss.py features_in.f32 features_out.f32
loss: 0.033
python3 loss.py features_in.f32 features_rx_out.f32
loss: 0.035
```
This is a really good result; the difference is likely inaudible. The `feature*.f32` files are produced as intermediate outputs by the `bbfm_inference.sh` and `bbfm_rx.sh` scripts.
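
`loss.py` reports a feature-domain distortion; a minimal sketch of that kind of comparison is below (assuming the `.f32` files hold flat float32 vectors of 36 features per 10 ms frame, with the first 20 used by the model, as in `bbfm_inference.py`; the exact weighting in `loss.py` may differ):

```
# Sketch of a feature-domain MSE comparison (loss.py itself may weight features differently)
import numpy as np
a = np.fromfile("features_in.f32", dtype=np.float32).reshape(-1, 36)
b = np.fromfile("features_out.f32", dtype=np.float32).reshape(-1, 36)
n = min(len(a), len(b))                        # files can differ in length by a frame or two
mse = np.mean((a[:n, :20] - b[:n, :20]) ** 2)  # compare only the features the model uses
print(f"loss: {mse:.3f}")
```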

142 changes: 142 additions & 0 deletions bbfm_inference.py
@@ -0,0 +1,142 @@
"""
/* Copyright (c) 2024 modifications for radio autoencoder project
by David Rowe */
/* Copyright (c) 2022 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

import os
import argparse

import numpy as np
import torch

from radae import BBFM, distortion_loss

parser = argparse.ArgumentParser()

parser.add_argument('model_name', type=str, help='path to model in .pth format')
parser.add_argument('features', type=str, help='path to input feature file in .f32 format')
parser.add_argument('features_hat', type=str, help='path to output feature file in .f32 format')
parser.add_argument('--latent-dim', type=int, help="number of symbols produced by the encoder, default: 80", default=80)
parser.add_argument('--cuda-visible-devices', type=str, help="set to 0 to run using GPU rather than CPU", default="")
parser.add_argument('--write_latent', type=str, default="", help='path to output file of latent vectors z[latent_dim] in .f32 format')
parser.add_argument('--CNRdB', type=float, default=100, help='FM demod input CNR in dB')
parser.add_argument('--passthru', action='store_true', help='copy features in to feature out, bypassing ML network')
parser.add_argument('--h_file', type=str, default="", help='path to rate Rs fading channel magnitude samples, rate Rs time steps by Nc=1 carriers .f32 format')
parser.add_argument('--write_CNRdB', type=str, default="", help='path to output file of CNRdB per sample after fading in .f32 format')
parser.add_argument('--loss_test', type=float, default=0.0, help='compare loss to arg, print PASS/FAIL')
args = parser.parse_args()

# set visible devices
os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices

# device
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

latent_dim = args.latent_dim

# not exposed
nb_total_features = 36
num_features = 20
num_used_features = 20

# load model from a checkpoint file
model = BBFM(num_features, latent_dim, args.CNRdB)
checkpoint = torch.load(args.model_name, map_location='cpu')
model.load_state_dict(checkpoint['state_dict'], strict=False)
checkpoint['state_dict'] = model.state_dict()

# load features from file
feature_file = args.features
features_in = np.reshape(np.fromfile(feature_file, dtype=np.float32), (1, -1, nb_total_features))
nb_features_rounded = model.num_10ms_times_steps_rounded_to_modem_frames(features_in.shape[1])
features = features_in[:,:nb_features_rounded,:]
features = features[:, :, :num_used_features]
features = torch.tensor(features)
print(f"Processing: {nb_features_rounded} feature vectors")

# default rate Rb multipath model H=1
Rb = model.Rb
Nc = 1
num_timesteps_at_rate_Rs = model.num_timesteps_at_rate_Rs(nb_features_rounded)
H = torch.ones((1,num_timesteps_at_rate_Rs,Nc))

# user supplied rate Rs multipath model, sequence of H magnitude samples
if args.h_file:
    H = np.reshape(np.fromfile(args.h_file, dtype=np.float32), (1, -1, Nc))
    print(H.shape, num_timesteps_at_rate_Rs)
    if H.shape[1] < num_timesteps_at_rate_Rs:
        print("Multipath H file too short")
        quit()
    H = H[:,:num_timesteps_at_rate_Rs,:]
    H = torch.tensor(H)

if __name__ == '__main__':

    if args.passthru:
        features_hat = features_in.flatten()
        features_hat.tofile(args.features_hat)
        quit()

    # push model to device and run test
    model.to(device)
    features = features.to(device)
    H = H.to(device)
    output = model(features,H)

    # Let's check the actual SNR at the output of the FM demod
    tx_sym = output["z_hat"].cpu().detach().numpy()
    S = np.mean(np.abs(tx_sym)**2)
    N = np.mean(output["sigma"].cpu().detach().numpy()**2)
    SNRdB_meas = 10*np.log10(S/N)
    print(f"SNRdB Measured: {SNRdB_meas:6.2f}")

    features_hat = output["features_hat"][:,:,:num_used_features]
    features_hat = torch.cat([features_hat, torch.zeros_like(features_hat)[:,:,:16]], dim=-1)
    features_hat = features_hat.cpu().detach().numpy().flatten().astype('float32')
    features_hat.tofile(args.features_hat)

    loss = distortion_loss(features,output['features_hat']).cpu().detach().numpy()[0]
    print(f"loss: {loss:5.3f}")
    if args.loss_test > 0.0:
        if loss < args.loss_test:
            print("PASS")
        else:
            print("FAIL")

    # write output symbols (latent vectors)
    if len(args.write_latent):
        z_hat = output["z_hat"].cpu().detach().numpy().flatten().astype('float32')
        z_hat.tofile(args.write_latent)

    # write CNRdB after fading
    if len(args.write_CNRdB):
        CNRdB = output["CNRdB"].cpu().detach().numpy().flatten().astype('float32')
        CNRdB.tofile(args.write_CNRdB)

47 changes: 47 additions & 0 deletions bbfm_inference.sh
@@ -0,0 +1,47 @@
#!/bin/bash -x
#
# Some automation around bbfm_inference.py to help with testing

OPUS=build/src
PATH=${PATH}:${OPUS}

if [ $# -lt 3 ]; then
echo "usage (write output to file):"
echo " ./bbfm_inference.sh model in.s16 out.wav [optional bbfm_inference.py args]"
echo "usage (play output with aplay):"
echo " ./bbfm_inference.sh model in.s16 - [optional bbfm_inference.py args]"
exit 1
fi

if [ ! -f $1 ]; then
echo "can't find $1"
exit 1
fi
if [ ! -f $2 ]; then
echo "can't find $2"
exit 1
fi

model=$1
input_speech=$2
output_speech=$3
features_in=features_in.f32
features_out=features_out.f32

# eat first 3 args before passing rest to inference.py in $@
shift; shift; shift

lpcnet_demo -features ${input_speech} ${features_in}
python3 ./bbfm_inference.py ${model} ${features_in} ${features_out} "$@"
if [ $? -ne 0 ]; then
exit 1
fi
if [ $output_speech == "-" ]; then
tmp=$(mktemp)
lpcnet_demo -fargan-synthesis ${features_out} ${tmp}
aplay $tmp -r 16000 -f S16_LE 2>/dev/null
elif [ $output_speech != "/dev/null" ]; then
tmp=$(mktemp)
lpcnet_demo -fargan-synthesis ${features_out} ${tmp}
sox -t .s16 -r 16000 -c 1 ${tmp} ${output_speech}
fi
Binary file added doc/lmr_60.png
Binary file added doc/rade_intro_waveform_jp_r5.pdf
3 changes: 2 additions & 1 deletion multipath_samples.m
@@ -41,6 +41,7 @@ function multipath_samples(ch, Fs, Rs, Nc, Nseconds, H_fn, G_fn="")
if Nc > 1
mesh(H(1:10*Rs,:))
else
# single carrier case
Nsecplot=1
subplot(211); plot(H(1:Nsecplot*Rs,:)); xlabel('Symbols'); ylabel('|H|')
subplot(212); plot(20*log10(H(1:Nsecplot*Rs,:))); xlabel('Symbols'); ylabel('|H| (dB)')
@@ -58,8 +59,8 @@ function multipath_samples(ch, Fs, Rs, Nc, Nseconds, H_fn, G_fn="")
end
end
LCR_meas = LC/Nseconds
subplot(211); hold on; stem(LC_log,sqrt(P)*ones(length(LC_log))); hold off; axis([0 Nsecplot*Rs 0 3]);
end
subplot(211); hold on; stem(LC_log,sqrt(P)*ones(length(LC_log))); hold off; axis([0 Nsecplot*Rs 0 3]);
printf("H file size is Nseconds*Rs*Nc*(4 bytes/sample) = %d*%d*%d*4 = %d bytes\n", Nseconds,Rs,Nc,Nseconds*Rs*Nc*4)
f=fopen(H_fn,"wb");
[r c] = size(H);
2 changes: 1 addition & 1 deletion radae/dsp.py
@@ -403,7 +403,7 @@ def do_pilot_eq_one(self, num_modem_frames, rx_sym_pilots):

        # est ampl across one just two sets of pilots seems to work OK (loss isn't impacted)
        if self.coarse_mag:
            mag = torch.mean(torch.abs(rx_pilots)**2)**0.5
            mag = torch.mean(torch.abs(rx_pilots)**2)**0.5 + 1E-6
            if self.bottleneck == 3:
                mag = mag*torch.abs(self.P[0])/self.pilot_gain
            #print(f"coarse mag: {mag:f}", file=sys.stderr)
19 changes: 19 additions & 0 deletions radae/radae_base.py
@@ -104,6 +104,8 @@ def __init__(self, input_dim, hidden_dim, batch_first):
    def forward(self, x):
        gru_out,self.states = self.gru(x,self.states)
        return gru_out
    def reset(self):
        self.states = torch.zeros(1,1,self.hidden_dim)

# Wrapper for conv1D layer that maintains state internally
class Conv1DStatefull(nn.Module):
@@ -122,6 +124,9 @@ def forward(self, x):
        self.states = conv_in[:,-self.states_len:,:]
        conv_in = conv_in.permute(0, 2, 1)
        return torch.tanh(self.conv(conv_in)).permute(0, 2, 1)

    def reset(self):
        self.states = torch.zeros(1,self.states_len,self.input_dim)

#Gated Linear Unit activation
class GLU(nn.Module):
@@ -410,3 +415,17 @@ def forward(self, z):
        features = torch.reshape(x,(1,z.shape[1]*self.FRAMES_PER_STEP,self.output_dim))
        return features

    def reset(self):
        self.conv1.reset()
        self.conv2.reset()
        self.conv3.reset()
        self.conv4.reset()
        self.conv5.reset()

        self.gru1.reset()
        self.gru2.reset()
        self.gru3.reset()
        self.gru4.reset()
        self.gru5.reset()
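
The new `reset()` methods clear the state these wrappers carry between calls; `radae_rxe.py` below calls `model.core_decoder_statefull.module.reset()` when the receiver enters sync, so stale decoder state from a previous acquisition is discarded. A standalone sketch of the pattern (illustrative class name, not the code above):

```
# Illustration of the stateful-wrapper + reset() pattern (standalone sketch)
import torch
import torch.nn as nn

class TinyStatefulGRU(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.gru = nn.GRU(input_dim, hidden_dim, batch_first=True)
        self.reset()

    def forward(self, x):
        out, self.states = self.gru(x, self.states)   # state carries over between calls
        return out

    def reset(self):
        self.states = torch.zeros(1, 1, self.hidden_dim)

g = TinyStatefulGRU(4, 8)
y1 = g(torch.randn(1, 10, 4))
g.reset()                        # e.g. on entering sync, start from a clean state
y2 = g(torch.randn(1, 10, 4))
```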

3 changes: 2 additions & 1 deletion radae_rxe.py
@@ -140,7 +140,7 @@ def get_nin(self):

    def get_sync(self):
        return self.state == "sync"

    def do_radae_rx(self, buffer_complex, features_out):
        acq = self.acq
        bpf = self.bpf
@@ -247,6 +247,7 @@ def do_radae_rx(self, buffer_complex, features_out):
            self.valid_count = self.valid_count + 1
            if self.valid_count > 3:
                next_state = "sync"
                model.core_decoder_statefull.module.reset()
                self.synced_count = 0
                uw_fail = False
                if auxdata:
