diff --git a/BBFM.md b/BBFM.md
index 72dbf20..c77bef0 100644
--- a/BBFM.md
+++ b/BBFM.md
@@ -27,6 +27,24 @@ A version of the Radio Autoencoder (RADE) designed for the baseband FM channel p
 octave:4> radae_plots; do_plots_bbfm('z_hat.f32')
 ```
 
+# Fading channel simulation
+
+Our HF channel simulation (two-path Rayleigh) is quite close to the TIA-102.CAAA-E 1.6.33 Faded Channel Simulator. The measured level crossing rate (LCR) appears to meet requirement (f) for v=60 km/hr, f=450 MHz, and P=1 when measured over a 10 second sample. We've used Rs=2000 symbols/s here, so the x-axis of the plot spans 1 second.
+
+![LMR 60](doc/lmr_60.png)
+
+```
+octave:39> multipath_samples("lmr60",8000, 2000, 1, 10, "h_lmr60.f32")
+Generating Doppler spreading samples...
+fd = 25.000
+path_delay_s = 2.0000e-04
+Nsecplot = 1
+Pav = 1.0366
+P = 1
+LCR_theory = 23.457
+LCR_meas = 24.400
+```
+
 # Single Carrier PSK Modem
 
 A single carrier PSK modem "back end" that connects the ML symbols to the radio. This particular modem is written in Python, and can work with DC coupled and passband BBFM radios. It uses classical DSP, rather than ML. Unlike the HF RADE waveform which used OFDM, this modem is single carrier.
@@ -37,6 +55,29 @@ A single carrier PSK modem "back end" that connects the ML symbols to the radio.
    ```
 1. Run a suite of tests:
    ```
-   python3 -c "from radae import single_carrier; s=single_carrier(); s.run_tests()"
+   ctest -V -R bbfm_sc
+   ```
+2. Create a file of BBFM symbols (80 symbols every 40 ms) and play the expected output speech:
+   ```
+   ./bbfm_inference.sh model_bbfm_01/checkpoints/checkpoint_epoch_100.pth wav/brian_g8sez.wav - --write_latent z.f32
+   ```
+3. Sanity check the modem with a BER test using digital BPSK symbols; the symbols in z.f32 are replaced with BPSK symbols. `t.int16` is a real valued Fs=9600 Hz sample file that could be played into an FM radio.
+   ```
+   cat z.f32 | python3 sc_tx.py --ber_test > t.int16
+   cat t.int16 | python3 sc_rx.py --ber_test --plots > /dev/null
+   ```
+4. Send the BBFM symbols over the modem, and listen to the results:
+   ```
+   cat z.f32 | python3 sc_tx.py > t.int16
+   cat t.int16 | python3 sc_rx.py > z_hat.f32
+   ./bbfm_rx.sh model_bbfm_01/checkpoints/checkpoint_epoch_100.pth z_hat.f32 -
+   ```
+5. Compare the MSE of the features passed through the system, first with z == z_hat, then with z passed through the modem to obtain z_hat:
+   ```
+   python3 loss.py features_in.f32 features_out.f32
+   loss: 0.033
+   python3 loss.py features_in.f32 features_rx_out.f32
+   loss: 0.035
    ```
+
+   This is a very good result, and the difference is likely inaudible. The `feature*.f32` files are produced as intermediate outputs by the `bbfm_inference.sh` and `bbfm_rx.sh` scripts.
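The LCR figures in the README section above can be cross-checked outside Octave. The following rough Python sketch (not part of this patch) counts upward crossings of the sqrt(P) threshold in the |H| samples written to `h_lmr60.f32`, and compares them with the classical Rayleigh-fading estimate sqrt(2*pi)*fd*rho*exp(-rho^2). The sample rate, Doppler spread, and file name follow the `multipath_samples()` run above; the exact normalisation inside the Octave code may differ slightly.

```
# Rough LCR cross-check, assuming h_lmr60.f32 holds float32 |H| magnitude
# samples at Rs = 2000 symbols/s, as generated by the Octave run above.
import numpy as np

Rs = 2000     # symbol rate of the fading samples
fd = 25.0     # Doppler spread (Hz) for v = 60 km/hr at f = 450 MHz
P = 1.0       # threshold power

H = np.fromfile("h_lmr60.f32", dtype=np.float32)
Nseconds = len(H) / Rs

# count positive-going crossings of the sqrt(P) threshold
thresh = np.sqrt(P)
LCR_meas = np.sum((H[:-1] < thresh) & (H[1:] >= thresh)) / Nseconds

# classical Rayleigh LCR estimate, with rho taken relative to the mean power
Pav = np.mean(H**2)
rho = np.sqrt(P / Pav)
LCR_theory = np.sqrt(2 * np.pi) * fd * rho * np.exp(-rho**2)

print(f"LCR_meas: {LCR_meas:5.1f}  LCR_theory: {LCR_theory:5.1f} crossings/s")
```

Over the 10 second sample both figures should land close to the LCR_theory and LCR_meas values printed by the Octave simulation.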
diff --git a/bbfm_inference.py b/bbfm_inference.py
new file mode 100644
index 0000000..8a10b62
--- /dev/null
+++ b/bbfm_inference.py
@@ -0,0 +1,142 @@
+"""
+/* Copyright (c) 2024 modifications for radio autoencoder project
+   by David Rowe */
+
+/* Copyright (c) 2022 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+     notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+     notice, this list of conditions and the following disclaimer in the
+     documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import os
+import argparse
+
+import numpy as np
+import torch
+
+from radae import BBFM, distortion_loss
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('model_name', type=str, help='path to model in .pth format')
+parser.add_argument('features', type=str, help='path to input feature file in .f32 format')
+parser.add_argument('features_hat', type=str, help='path to output feature file in .f32 format')
+parser.add_argument('--latent-dim', type=int, help="number of symbols produced by encoder, default: 80", default=80)
+parser.add_argument('--cuda-visible-devices', type=str, help="set to 0 to run using GPU rather than CPU", default="")
+parser.add_argument('--write_latent', type=str, default="", help='path to output file of latent vectors z[latent_dim] in .f32 format')
+parser.add_argument('--CNRdB', type=float, default=100, help='FM demod input CNR in dB')
+parser.add_argument('--passthru', action='store_true', help='copy features in to features out, bypassing ML network')
+parser.add_argument('--h_file', type=str, default="", help='path to rate Rs fading channel magnitude samples, rate Rs time steps by Nc=1 carriers .f32 format')
+parser.add_argument('--write_CNRdB', type=str, default="", help='path to output file of CNRdB per sample after fading in .f32 format')
+parser.add_argument('--loss_test', type=float, default=0.0, help='compare loss to arg, print PASS/FAIL')
+args = parser.parse_args()
+
+# set visible devices
+os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices
+
+# device
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+latent_dim = args.latent_dim
+
+# not exposed
+nb_total_features = 36
+num_features = 20
+num_used_features = 20
+
+# load model from a checkpoint file
+model = BBFM(num_features, latent_dim, args.CNRdB)
+checkpoint = torch.load(args.model_name, map_location='cpu')
+model.load_state_dict(checkpoint['state_dict'], strict=False)
+checkpoint['state_dict'] = model.state_dict()
+
+# load features from file
+feature_file = args.features
+features_in = np.reshape(np.fromfile(feature_file, dtype=np.float32), (1, -1, nb_total_features))
+nb_features_rounded = model.num_10ms_times_steps_rounded_to_modem_frames(features_in.shape[1])
+features = features_in[:,:nb_features_rounded,:]
+features = features[:, :, :num_used_features]
+features = torch.tensor(features)
+print(f"Processing: {nb_features_rounded} feature vectors")
+
+# default rate Rb multipath model H=1
+Rb = model.Rb
+Nc = 1
+num_timesteps_at_rate_Rs = model.num_timesteps_at_rate_Rs(nb_features_rounded)
+H = torch.ones((1,num_timesteps_at_rate_Rs,Nc))
+
+# user supplied rate Rs multipath model, sequence of H magnitude samples
+if args.h_file:
+    H = np.reshape(np.fromfile(args.h_file, dtype=np.float32), (1, -1, Nc))
+    print(H.shape, num_timesteps_at_rate_Rs)
+    if H.shape[1] < num_timesteps_at_rate_Rs:
+        print("Multipath H file too short")
+        quit()
+    H = H[:,:num_timesteps_at_rate_Rs,:]
+    H = torch.tensor(H)
+
+if __name__ == '__main__':
+
+    if args.passthru:
+        features_hat = features_in.flatten()
+        features_hat.tofile(args.features_hat)
+        quit()
+
+    # push model to device and run test
+    model.to(device)
+    features = features.to(device)
+    H = H.to(device)
+    output = model(features,H)
+
+    # Let's check the actual SNR at the output of the FM demod
+    tx_sym = output["z_hat"].cpu().detach().numpy()
+    S = np.mean(np.abs(tx_sym)**2)
+    N = np.mean(output["sigma"].cpu().detach().numpy()**2)
+    SNRdB_meas = 10*np.log10(S/N)
+    print(f"SNRdB Measured: {SNRdB_meas:6.2f}")
+
+    features_hat = output["features_hat"][:,:,:num_used_features]
+    features_hat = torch.cat([features_hat, torch.zeros_like(features_hat)[:,:,:16]], dim=-1)
+    features_hat = features_hat.cpu().detach().numpy().flatten().astype('float32')
+    features_hat.tofile(args.features_hat)
+
+    loss = distortion_loss(features,output['features_hat']).cpu().detach().numpy()[0]
+    print(f"loss: {loss:5.3f}")
+    if args.loss_test > 0.0:
+        if loss < args.loss_test:
+            print("PASS")
+        else:
+            print("FAIL")
+
+    # write output symbols (latent vectors)
+    if len(args.write_latent):
+        z_hat = output["z_hat"].cpu().detach().numpy().flatten().astype('float32')
+        z_hat.tofile(args.write_latent)
+
+    # write CNRdB after fading
+    if len(args.write_CNRdB):
+        CNRdB = output["CNRdB"].cpu().detach().numpy().flatten().astype('float32')
+        CNRdB.tofile(args.write_CNRdB)
+
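The `--h_file` option of bbfm_inference.py expects a raw .f32 file of |H| magnitude samples, one float32 per rate Rs time step with Nc=1 carriers, and aborts if the file is shorter than the run requires. A minimal sketch of writing such a file for a flat (non-fading) channel; Rs = 2000 symbols/s is assumed from the LMR example and `h_flat.f32` is just an illustrative name:

```
# Write a flat (all-ones) |H| magnitude file usable with --h_file.
# Assumes one float32 per rate Rs time step, Nc = 1 carrier.
import numpy as np

Rs = 2000        # assumed rate Rs time steps per second
Nseconds = 10    # make it comfortably longer than the speech file

H = np.ones(Rs * Nseconds, dtype=np.float32)
H.tofile("h_flat.f32")
```

The file can then be passed as `--h_file h_flat.f32`; anything longer than the run needs is simply truncated.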
diff --git a/bbfm_inference.sh b/bbfm_inference.sh
new file mode 100755
index 0000000..04bd3ae
--- /dev/null
+++ b/bbfm_inference.sh
@@ -0,0 +1,47 @@
+#!/bin/bash -x
+#
+# Some automation around bbfm_inference.py to help with testing
+
+OPUS=build/src
+PATH=${PATH}:${OPUS}
+
+if [ $# -lt 3 ]; then
+    echo "usage (write output to file):"
+    echo "  ./bbfm_inference.sh model in.s16 out.wav [optional bbfm_inference.py args]"
+    echo "usage (play output with aplay):"
+    echo "  ./bbfm_inference.sh model in.s16 - [optional bbfm_inference.py args]"
+    exit 1
+fi
+
+if [ ! -f $1 ]; then
+    echo "can't find $1"
+    exit 1
+fi
+if [ ! -f $2 ]; then
+    echo "can't find $2"
+    exit 1
+fi
+
+model=$1
+input_speech=$2
+output_speech=$3
+features_in=features_in.f32
+features_out=features_out.f32
+
+# eat first 3 args before passing rest to bbfm_inference.py in $@
+shift; shift; shift
+
+lpcnet_demo -features ${input_speech} ${features_in}
+python3 ./bbfm_inference.py ${model} ${features_in} ${features_out} "$@"
+if [ $? -ne 0 ]; then
+    exit 1
+fi
+if [ $output_speech == "-" ]; then
+    tmp=$(mktemp)
+    lpcnet_demo -fargan-synthesis ${features_out} ${tmp}
+    aplay $tmp -r 16000 -f S16_LE 2>/dev/null
+elif [ $output_speech != "/dev/null" ]; then
+    tmp=$(mktemp)
+    lpcnet_demo -fargan-synthesis ${features_out} ${tmp}
+    sox -t .s16 -r 16000 -c 1 ${tmp} ${output_speech}
+fi
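bbfm_inference.sh leaves `features_in.f32` and `features_out.f32` behind, which is what `loss.py` consumes in the README steps above. A small sketch (the name `check_features.py` is illustrative), assuming the layout used by bbfm_inference.py of 36 float32 values per 10 ms frame, that reports how many feature vectors a file holds:

```
# Usage: python3 check_features.py features_in.f32
# Assumes 36 float32 values per 10 ms feature vector, as in bbfm_inference.py.
import sys
import numpy as np

nb_total_features = 36
feats = np.fromfile(sys.argv[1], dtype=np.float32)
assert len(feats) % nb_total_features == 0, "not a whole number of feature vectors"
n_frames = len(feats) // nb_total_features
print(f"{n_frames} feature vectors, {n_frames*0.01:.2f} seconds of speech")
```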
diff --git a/doc/lmr_60.png b/doc/lmr_60.png
new file mode 100644
index 0000000..fdd3fae
Binary files /dev/null and b/doc/lmr_60.png differ
diff --git a/doc/rade_intro_waveform_jp_r5.pdf b/doc/rade_intro_waveform_jp_r5.pdf
new file mode 100644
index 0000000..ef80827
Binary files /dev/null and b/doc/rade_intro_waveform_jp_r5.pdf differ
diff --git a/multipath_samples.m b/multipath_samples.m
index f589f2b..208f711 100644
--- a/multipath_samples.m
+++ b/multipath_samples.m
@@ -41,6 +41,7 @@ function multipath_samples(ch, Fs, Rs, Nc, Nseconds, H_fn, G_fn="")
   if Nc > 1
     mesh(H(1:10*Rs,:))
   else
+    # single carrier case
     Nsecplot=1
     subplot(211); plot(H(1:Nsecplot*Rs,:)); xlabel('Symbols'); ylabel('|H|')
     subplot(212); plot(20*log10(H(1:Nsecplot*Rs,:))); xlabel('Symbols'); ylabel('|H| (dB)')
@@ -58,8 +59,8 @@ function multipath_samples(ch, Fs, Rs, Nc, Nseconds, H_fn, G_fn="")
       end
     end
     LCR_meas = LC/Nseconds
+    subplot(211); hold on; stem(LC_log,sqrt(P)*ones(length(LC_log))); hold off; axis([0 Nsecplot*Rs 0 3]);
   end
-  subplot(211); hold on; stem(LC_log,sqrt(P)*ones(length(LC_log))); hold off; axis([0 Nsecplot*Rs 0 3]);
   printf("H file size is Nseconds*Rs*Nc*(4 bytes/sample) = %d*%d*%d*4 = %d bytes\n", Nseconds,Rs,Nc,Nseconds*Rs*Nc*4)
   f=fopen(H_fn,"wb");
   [r c] = size(H);
diff --git a/radae/dsp.py b/radae/dsp.py
index 7b6584a..60e81c0 100644
--- a/radae/dsp.py
+++ b/radae/dsp.py
@@ -403,7 +403,7 @@ def do_pilot_eq_one(self, num_modem_frames, rx_sym_pilots):
 
         # est ampl across one just two sets of pilots seems to work OK (loss isn't impacted)
         if self.coarse_mag:
-            mag = torch.mean(torch.abs(rx_pilots)**2)**0.5
+            mag = torch.mean(torch.abs(rx_pilots)**2)**0.5 + 1E-6
            if self.bottleneck == 3:
                mag = mag*torch.abs(self.P[0])/self.pilot_gain
            #print(f"coarse mag: {mag:f}", file=sys.stderr)
diff --git a/radae/radae_base.py b/radae/radae_base.py
index 532ca43..95d8790 100644
--- a/radae/radae_base.py
+++ b/radae/radae_base.py
@@ -104,6 +104,8 @@ def __init__(self, input_dim, hidden_dim, batch_first):
     def forward(self, x):
         gru_out,self.states = self.gru(x,self.states)
         return gru_out
+    def reset(self):
+        self.states = torch.zeros(1,1,self.hidden_dim)
 
 # Wrapper for conv1D layer that maintains state internally
 class Conv1DStatefull(nn.Module):
@@ -122,6 +124,9 @@ def forward(self, x):
         self.states = conv_in[:,-self.states_len:,:]
         conv_in = conv_in.permute(0, 2, 1)
         return torch.tanh(self.conv(conv_in)).permute(0, 2, 1)
+
+    def reset(self):
+        self.states = torch.zeros(1,self.states_len,self.input_dim)
 
 #Gated Linear Unit activation
 class GLU(nn.Module):
@@ -410,3 +415,16 @@ def forward(self, z):
 
         features = torch.reshape(x,(1,z.shape[1]*self.FRAMES_PER_STEP,self.output_dim))
         return features
+    def reset(self):
+        self.conv1.reset()
+        self.conv2.reset()
+        self.conv3.reset()
+        self.conv4.reset()
+        self.conv5.reset()
+
+        self.gru1.reset()
+        self.gru2.reset()
+        self.gru3.reset()
+        self.gru4.reset()
+        self.gru5.reset()
+        
\ No newline at end of file
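The `reset()` methods added to the stateful GRU and Conv1D wrappers above let the decoder clear its internal state without rebuilding the model; the radae_rxe.py change below calls the core decoder's `reset()` as the receiver re-enters sync. A simplified sketch of the pattern (not the actual radae_base.py classes):

```
# Simplified illustration of a stateful GRU wrapper with a reset() method.
import torch
import torch.nn as nn

class GRUStatefullSketch(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.gru = nn.GRU(input_dim, hidden_dim, batch_first=True)
        self.states = torch.zeros(1, 1, hidden_dim)

    def forward(self, x):
        # state is carried across calls, so streaming frame-by-frame works
        gru_out, self.states = self.gru(x, self.states)
        return gru_out

    def reset(self):
        # zero the hidden state, e.g. when the receiver re-enters sync
        self.states = torch.zeros(1, 1, self.hidden_dim)

layer = GRUStatefullSketch(20, 32)
y1 = layer(torch.randn(1, 5, 20))
layer.reset()
y2 = layer(torch.randn(1, 5, 20))
```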
diff --git a/radae_rxe.py b/radae_rxe.py
index 94b6bef..58d8291 100644
--- a/radae_rxe.py
+++ b/radae_rxe.py
@@ -140,7 +140,7 @@ def get_nin(self):
 
     def get_sync(self):
         return self.state == "sync"
-    
+
     def do_radae_rx(self, buffer_complex, features_out):
         acq = self.acq
         bpf = self.bpf
@@ -247,6 +247,7 @@ def do_radae_rx(self, buffer_complex, features_out):
                     self.valid_count = self.valid_count + 1
                     if self.valid_count > 3:
                         next_state = "sync"
+                        model.core_decoder_statefull.module.reset()
                         self.synced_count = 0
                         uw_fail = False
                         if auxdata: