Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Baseband FM POC #2 #34

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
9 changes: 9 additions & 0 deletions BBFM.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,12 @@ A single carrier PSK modem "back end" that connects the ML symbols to the radio.
```
This is a really good result, and likely inaudible. The `feature*.f32` files are produced as intermediate outputs from the `bbfm_inference.sh` and `bbfm_rx.sh` scripts.

5. Playing samples over a USB sounds card connected to a radio, note selection of sample rate:
```
aplay --device="plughw:CARD=Audio,DEV=0" -r 9600 -f S16_LE t1.int16
```

6. Feeding samples from an off air wave file captured from a Rx to demod. Note `sc_xx` tools default to a centre freq of 1500Hz
```
sox ~/Desktop/sc-ber-003.wav -t .s16 -r 9600 -c 1 - highpass 100 | python3 sc_rx.py --plots > z_hat.f32
```
20 changes: 10 additions & 10 deletions bbfm_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@
parser.add_argument('--latent-dim', type=int, help="number of symbols produces by encoder, default: 80", default=80)
parser.add_argument('--cuda-visible-devices', type=str, help="set to 0 to run using GPU rather than CPU", default="")
parser.add_argument('--write_latent', type=str, default="", help='path to output file of latent vectors z[latent_dim] in .f32 format')
parser.add_argument('--CNRdB', type=float, default=100, help='FM demod input CNR in dB')
parser.add_argument('--RdBm', type=float, default=-100, help='Receive level set point in dBm')
parser.add_argument('--passthru', action='store_true', help='copy features in to feature out, bypassing ML network')
parser.add_argument('--h_file', type=str, default="", help='path to rate Rs fading channel magnitude samples, rate Rs time steps by Nc=1 carriers .f32 format')
parser.add_argument('--write_CNRdB', type=str, default="", help='path to output file of CNRdB per sample after fading in .f32 format')
parser.add_argument('--write_RdBm', type=str, default="", help='path to output file of RdBm per sample after fading in .f32 format')
parser.add_argument('--loss_test', type=float, default=0.0, help='compare loss to arg, print PASS/FAIL')
args = parser.parse_args()

Expand All @@ -66,8 +66,8 @@
num_features = 20
num_used_features = 20

# load model from a checkpoint file
model = BBFM(num_features, latent_dim, args.CNRdB)
model = BBFM(num_features, latent_dim, args.RdBm)
# load model weights from a checkpoint file
checkpoint = torch.load(args.model_name, map_location='cpu', weights_only=True)
model.load_state_dict(checkpoint['state_dict'], strict=False)
checkpoint['state_dict'] = model.state_dict()
Expand All @@ -81,10 +81,10 @@
features = torch.tensor(features)
print(f"Processing: {nb_features_rounded} feature vectors")

# default rate Rb multipath model H=1
Rb = model.Rb
Nc = 1
num_timesteps_at_rate_Rs = model.num_timesteps_at_rate_Rs(nb_features_rounded)
# default AWGN channel (H=1)
H = torch.ones((1,num_timesteps_at_rate_Rs,Nc))

# user supplied rate Rs multipath model, sequence of H magnitude samples
Expand Down Expand Up @@ -130,13 +130,13 @@
else:
print("FAIL")

# write output symbols (latent vectors)
# optionally write output symbols (latent vectors)
if len(args.write_latent):
z_hat = output["z_hat"].cpu().detach().numpy().flatten().astype('float32')
z_hat.tofile(args.write_latent)

# write CNRdB after fading
if len(args.write_CNRdB):
CNRdB = output["CNRdB"].cpu().detach().numpy().flatten().astype('float32')
CNRdB.tofile(args.write_CNRdB)
# optionally write RdBm after fading
if len(args.write_RdBm):
RdBm = output["RdBm"].cpu().detach().numpy().flatten().astype('float32')
RdBm.tofile(args.write_RdBm)

42 changes: 23 additions & 19 deletions radae/bbfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,17 @@ class BBFM(nn.Module):
def __init__(self,
feature_dim,
latent_dim,
CNRdB,
fd_Hz=5000,
fm_Hz=3000,
RdBm,
fd_Hz=1800,
fm_Hz=2880,
stateful_decoder = False
):

super(BBFM, self).__init__()

self.feature_dim = feature_dim
self.latent_dim = latent_dim
self.CNRdB = CNRdB
self.RdBm = RdBm
self.fd_Hz = fd_Hz
self.fm_Hz = fm_Hz
self.stateful_decoder = stateful_decoder
Expand All @@ -75,11 +75,14 @@ def __init__(self,
self.Rz = 1/self.Tz
self.Rb = latent_dim/self.Tz # payload data BPSK symbol rate (symbols/s or Hz)

self.beta = self.fd_Hz/self.fm_Hz # deviation
self.BWfm = 2*(self.fd_Hz + self.fm_Hz) # BW estimate using Carsons rule
self.Gfm = 10*m.log10(3*(self.beta**2)*(self.beta+1))

print(f"Rb: {self.Rb:5.2f} Deviation: {self.fd_Hz}Hz Max Modn freq: {self.fm_Hz}Hz Beta: {self.beta:3.2f}", file=sys.stderr)
x_bar = 1 # average power of modulating symbols wrt peak deviation
k = 1.38E-23; T=274; NFdB = 5
self.beta = self.fd_Hz/self.fm_Hz
self.Gfm = 10*m.log10(3*(self.beta**2)*x_bar/(1E3*k*T*self.fm_Hz)) - NFdB
self.TdBm = 12 - self.Gfm

print(f"Rb: {self.Rb:5.2f} Deviation: {self.fd_Hz} Hz Max Modn freq: {self.fm_Hz} Hz Beta: {self.beta:3.2f}", file=sys.stderr)
print(f"x_bar: {x_bar:5.2f} Gfm: {self.Gfm:5.2f} dB TdB: {self.TdBm:5.2f} dB RdBm: {self.RdBm:5.2f}", file=sys.stderr)

# Stateful decoder wasn't present during training, so we need to load weights from existing decoder
def core_decoder_statefull_load_state_dict(self):
Expand Down Expand Up @@ -171,27 +174,28 @@ def forward(self, features, H):
z_shape = z.shape
z_hat = torch.reshape(z,(num_batches,num_timesteps_at_rate_Rs,1))

# determine FM demod SNR using piecewise approximation implemented with relus to be torch-friendly
# note SNR is a vector, 1 sample for symbol as SNR evolves with H
CNRdB = 20*torch.log10(H) + self.CNRdB
print(H.shape,CNRdB.shape)
SNRdB_relu = torch.relu(CNRdB-12) + 12 + self.Gfm
SNRdB_relu += -torch.relu(-(CNRdB-12))*(1 + self.Gfm/3)
SNR = 10**(SNRdB_relu/10)
# determine FM demod SNR using piecewise approximation expressed as sum of
# heaviside step functions for efficient implementation during training.
# Note SNR is a vector, 1 sample per symbol as SNR evolves with H
values = torch.zeros(1, device=H.device)
RdBm = 20*torch.log10(H) + self.RdBm
SNRdB = (RdBm+self.Gfm)*torch.heaviside(RdBm-self.TdBm, values) \
+ (3*RdBm+self.Gfm-2*self.TdBm)*torch.heaviside(-RdBm+self.TdBm, values)
SNR = 10**(SNRdB/10)

# note sigma is a vector, noise power evolves across each symbol with H
sigma = 1/(SNR**0.5)
n = sigma*torch.randn_like(z_hat)
z_hat = torch.clamp(z_hat + n, min=-1.0,max=1.0)

z_hat = torch.reshape(z_hat,z_shape)
#print(z.shape, z_hat.shape)

features_hat = self.core_decoder(z_hat)

return {
"features_hat" : features_hat,
"z_hat" : z_hat,
"sigma" : sigma,
"CNRdB" : CNRdB
}
"SNRdB" : SNRdB,
"RdBm" : RdBm
}
82 changes: 56 additions & 26 deletions radae_plots.m
Original file line number Diff line number Diff line change
Expand Up @@ -72,20 +72,26 @@ function do_plots(z_fn='l.f32',rx_fn='', png_fn='', epslatex='')
end
endfunction

function do_plots_bbfm(z1_fn, z2_fn="", png_fn='')
function do_plots_bbfm(z1_fn, z2_fn='', png_fn='', epslatex='')
if length(epslatex)
[textfontsize linewidth] = set_fonts(20);
end
z1=load_f32(z1_fn,1);
figure(1); clf;
stem(z1(1:40),'g');
stem(z1(1:80),'g');
axis([0 80 -1.2 1.2]);
if length(z2_fn)
z2=load_f32(z2_fn,1);
hold on;
stem(z2(1:40),'r');
hold off;
end
title('Rx Symbols');
if length(png_fn)
print("-dpng",sprintf("%s.png",png_fn));
end
if length(epslatex)
print_eps_restore(sprintf("%s.eps",epslatex),"-S300,200",textfontsize,linewidth);
end
endfunction


Expand Down Expand Up @@ -383,42 +389,66 @@ function test_rayleigh(epslatex="")
y(find(x<0)) = 0;
end

% Plot SNR v CNR for FM demod model
function plot_SNR_CNR(epslatex="")
function y = heaviside(x)
y = x>0;
end

% Plot SNR v R for FM demod model
function bbfm_plot_SNR_R(epslatex="")
if length(epslatex)
[textfontsize linewidth] = set_fonts();
[textfontsize linewidth] = set_fonts(20);
end
figure(1); clf; hold on;
fd=5000; fm=3000;
beta= fd/fm;
Gfm=10*log10(3*(beta^2)*(beta+1))
BWfm = 2*(fd+fm);

fd=2500; fm=3000; A = 1; k=1.38E-23; T=274; NFdB=5;
beta = fd/fm;
x_bar = A^2/2;
Gfm=10*log10(3*(beta^2)*x_bar/(1E3*k*T*fm)) - NFdB;
TdBm = 12 - Gfm;
printf("fd: %6.0f fm: %6.0f Beta: %f A: %5.2f x_bar: %5.2f Gfm: %5.2f dB TdBm: %5.2f\n", fd, fm, beta, A, x_bar, Gfm, TdBm);

% vanilla implementation of curve
CNRdB=0:20;
for i=1:length(CNRdB)
if CNRdB(i) >= 12
SNRdB(i) = CNRdB(i) + Gfm;
RdBm=-130:-105;
for i=1:length(RdBm)
if RdBm(i) >= TdBm
SNRdB(i) = RdBm(i) + Gfm;
else
SNRdB(i) = (1+Gfm/3)*CNRdB(i) - 3*Gfm;
SNRdB(i) = 3*RdBm(i) + Gfm - 2*TdBm;
end
end

% implementation using relus (suitable for PyTorch)
SNRdB_relu = relu(CNRdB-12) + 12 + Gfm;
SNRdB_relu += -relu(-(CNRdB-12))*(1+Gfm/3);

plot(CNRdB,SNRdB,'g;FM;');
plot(CNRdB,SNRdB_relu,'r+;FM relu;');
SSBdB = CNRdB + 10*log10(BWfm) - 10*log10(fm);
plot(CNRdB,SSBdB,'b;SSB;');
axis([min(CNRdB) max(CNRdB) 10 30]);
hold off; grid('minor'); xlabel('CNR (dB)'); ylabel('SNR (dB)'); legend('boxoff'); legend('location','northwest');
% implementation using common ML toolkit non-linearity rather than if/then for efficiency in training
SNRdB_heaviside = (RdBm+Gfm).*heaviside(RdBm-TdBm) + (3*RdBm+Gfm-2*TdBm).*heaviside(-RdBm+TdBm);

figure(1); clf; hold on;
plot(RdBm,SNRdB,'g+-');
if length(epslatex) == 0
hold on; plot(RdBm,SNRdB_heaviside,'bx'); hold off;
end
grid('minor'); xlabel('R (dBm)'); ylabel('SNR (dB)'); legend('off');
if length(epslatex)
print_eps_restore(epslatex,"-S300,300",textfontsize,linewidth);
end
endfunction

% test expression derived from Carslon (17)
function bbfm_carlson()
fd=2500; fm=3000;
beta = fd/fm;
Sx = 0.5;
Bt = 9E3;
CNR_dB = 0:20; CNR = 10.^(CNR_dB/10);
SNR1 = 3*(beta^2)*Sx*CNR;
num = 3*(beta^2)*Sx*CNR;
denom = (1+(12*beta/pi)*CNR.*exp(-fm*CNR/Bt));
SNR3 = num./denom;
SNR1_dB = 10*log10(SNR1);
SNR3_dB = 10*log10(SNR3);
figure(1); clf; hold on;
plot(CNR_dB,SNR1_dB,'b;Eq (1);');
plot(CNR_dB,SNR3_dB,'r;Eq (4);');
hold off; grid;
endfunction

% test handling of single sample per symbol phase jumps
function test_phase_est
theta = 0:0.01:2*pi;
Expand Down
10 changes: 3 additions & 7 deletions train_bbfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
parser.add_argument('output', type=str, help='path to output folder')
parser.add_argument('--cuda-visible-devices', type=str, help="comma separates list of cuda visible device indices, default: ''", default="")
parser.add_argument('--latent-dim', type=int, help="number of symbols produced by encoder, default: 80", default=80)
parser.add_argument('--CNRdB', type=float, default=0, help='FM demod input CNR in dB')
parser.add_argument('--RdBm', type=float, default=-120.0, help='Receive level set point in dBm (default -120)')
parser.add_argument('--h_file', type=str, default="", help='path to rate Rs multipath file, rate Rs time steps by 1 carriers .f32 format')

training_group = parser.add_argument_group(title="training parameters")
Expand All @@ -61,8 +61,6 @@

training_group.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None)
training_group.add_argument('--plot_loss', action='store_true', help='plot loss versus epoch as we train')
training_group.add_argument('--plot_EqNo', type=str, default="", help='plot loss versus Eq/No for final epoch')
training_group.add_argument('--auxdata', action='store_true', help='inject auxillary data symbol')

args = parser.parse_args()

Expand Down Expand Up @@ -103,15 +101,13 @@
latent_dim = args.latent_dim

num_features = 20
if args.auxdata:
num_features += 1

# training data
feature_file = args.features

# model
checkpoint['model_args'] = (num_features, latent_dim, args.CNRdB)
model = BBFM(num_features, latent_dim, args.CNRdB)
checkpoint['model_args'] = (num_features, latent_dim, args.RdBm)
model = BBFM(num_features, latent_dim, args.RdBm)

if type(args.initial_checkpoint) != type(None):
print(f"Loading from checkpoint: {args.initial_checkpoint}")
Expand Down
Loading