Skip to content

Commit

Permalink
wip refactoring radae_rx as a class
Browse files Browse the repository at this point in the history
  • Loading branch information
drowe67 committed Sep 25, 2024
1 parent 69b79ac commit 96b977a
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 122 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ add_test(NAME radae_rx_embed
COMMAND sh -c "cd ${CMAKE_SOURCE_DIR}; \
./inference.sh model19_check3/checkpoints/checkpoint_epoch_100.pth wav/brian_g8sez.wav /dev/null \
--rate_Fs --pilots --pilot_eq --eq_ls --cp 0.004 --bottleneck 3 --auxdata --write_rx rx.f32 --correct_freq_offset; \
cd embed; cat ../rx.f32 | python3 radae_rx.py > ../features_out.f32; cd ..;
cd embed; cat ../rx.f32 | PYTHONPATH='../' python3 radae_rx.py > ../features_out.f32; cd ..;
python3 loss.py features_in.f32 features_out.f32 --loss_test 0.15 --acq_time_test 0.5")
set_tests_properties(radae_rx_embed PROPERTIES PASS_REGULAR_EXPRESSION "PASS")

256 changes: 135 additions & 121 deletions embed/radae_rx.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
import numpy as np
from matplotlib import pyplot as plt
import torch
sys.path.append("../")
from radae import RADAE,complex_bpf,acquisition,receiver_one

# Hard code all this for now to avoid arg passing complexities TODO: consider a way to pass in at init-time
Expand Down Expand Up @@ -102,143 +101,158 @@

acq = acquisition(Fs,Rs,M,Ncp,Nmf,p,model.pend)

tmax_candidate = 0
acquired = False
state = "search"
prev_state = state
mf = 1
valid_count = 0
Tunsync = 3.0 # allow some time before lossing sync to ride over fades
Nmf_unsync = int(Tunsync*Fs/Nmf)
endofover = False
uw_errors = 0
uw_error_thresh = 7 # P(reject|correct) = 1 - binocdf(8,24,0.1) = 4.5E-4
# P(accept|false) = binocdf(8,24,0.5) = 3.2E-3
synced_count = 0
synced_count_one_sec = Fs//Nmf

# P DDD P DDD P Ncp
# extra Ncp at end so we can handle timing slips
rx_buf = np.zeros(2*Nmf+M+Ncp,np.csingle)
rx = np.zeros(0,np.csingle)
rx_phase = 1 + 1j*0
rx_phase_vec = np.zeros(Nmf+M+Ncp,np.csingle)
z_hat_log = torch.zeros(0,model.Nzmf,model.latent_dim)

nin = Nmf
with torch.inference_mode():
while True:
buffer = sys.stdin.buffer.read(nin*struct.calcsize("ff"))
if len(buffer) != nin*struct.calcsize("ff"):
break
buffer_complex = np.frombuffer(buffer,np.csingle)
if bpf:
buffer_complex = bpf.bpf(buffer_complex)
rx_buf[:-nin] = rx_buf[nin:] # out with the old
rx_buf[-nin:] = buffer_complex # in with the new
if state == "search" or state == "candidate":
candidate, tmax, fmax = acq.detect_pilots(rx_buf)
else:
# we're in sync, so check we can still see pilots and run receiver
ffine_range = np.arange(fmax-1,fmax+1,0.1)
tfine_range = np.arange(max(0,tmax-8),tmax+8)
tmax,fmax_hat = acq.refine(rx_buf, tmax, fmax, tfine_range, ffine_range)
fmax = 0.9*fmax + 0.1*fmax_hat
candidate,endofover = acq.check_pilots(rx_buf,tmax,fmax)

# handle timing slip when rx sample clock > tx sample clock
nin = Nmf
if tmax >= Nmf-M:
nin = Nmf + M
tmax -= M
#print("slip+", file=sys.stderr)
# handle timing slip when rx sample clock < tx sample clock
if tmax < M:
nin = Nmf - M
tmax += M
#print("slip-", file=sys.stderr)

synced_count += 1
if synced_count % synced_count_one_sec == 0:
if uw_errors > uw_error_thresh:
uw_fail = True
uw_errors = 0

if not endofover:
# correct frequency offset, note we preserve state of phase
# TODO do we need preserve state of phase? We're passing entire vector and there isn't any memory (I think)
w = 2*np.pi*fmax/Fs
for n in range(Nmf+M+Ncp):
rx_phase = rx_phase*np.exp(-1j*w)
rx_phase_vec[n] = rx_phase
rx1 = rx_buf[tmax-Ncp:tmax-Ncp+Nmf+M+Ncp]
#print(tmax-Ncp, tmax-Ncp+Nmf+M+Ncp,rx_buf.shape, rx1.shape, rx_phase_vec.shape, file=sys.stderr)
rx = torch.tensor(rx1*rx_phase_vec, dtype=torch.complex64)
# run through RADAE receiver DSP
z_hat = receiver.receiver_one(rx)
# decode z_hat to features
assert(z_hat.shape[1] == model.Nzmf)
features_hat = model.core_decoder_statefull(z_hat)
if auxdata:
symb_repeat = 4
aux_symb = features_hat[:,:,20].detach().numpy()
aux_bits = 1*(aux_symb[0,::symb_repeat] > 0)
features_hat = features_hat[:,:,0:20]
uw_errors += np.sum(aux_bits)
# add unused features and send to stdout
features_hat = torch.cat([features_hat, torch.zeros_like(features_hat)[:,:,:16]], dim=-1)
features_hat = features_hat.cpu().detach().numpy().flatten().astype('float32')
if use_stdout:
sys.stdout.buffer.write(features_hat)
#sys.stdout.flush()


if v == 2 or (v == 1 and (state == "search" or state == "candidate" or prev_state == "candidate")):
print(f"{mf:3d} state: {state:10s} valid: {candidate:d} {endofover:d} {valid_count:2d} Dthresh: {acq.Dthresh:8.2f} ", end='', file=sys.stderr)
print(f"Dtmax12: {acq.Dtmax12:8.2f} {acq.Dtmax12_eoo:8.2f} tmax: {tmax:4d} fmax: {fmax:6.2f}", end='', file=sys.stderr)
if auxdata and state == "sync":
print(f" aux: {aux_bits:} uw_err: {uw_errors:d}", file=sys.stderr)
class radae_rx:
def __init__(self):
self.nin = Nmf
self.state = "search"
self.tmax_candidate = 0
self.mf = 1
self.valid_count = 0
self.uw_errors = 0
self.synced_count = 0
self.rx_phase = 1 + 1j*0

def do_radae_rx(self, buffer_complex, features_out):
with torch.inference_mode():
prev_state = self.state
valid_output = False
endofover = False
uw_fail = False
if bpf:
buffer_complex = bpf.bpf(buffer_complex)
rx_buf[:-self.nin] = rx_buf[self.nin:] # out with the old
rx_buf[-self.nin:] = buffer_complex # in with the new
if self.state == "search" or self.state == "candidate":
candidate, self.tmax, self.fmax = acq.detect_pilots(rx_buf)
else:
print("",file=sys.stderr)

# iterate state machine
next_state = state
prev_state = state
if state == "search":
if candidate:
next_state = "candidate"
tmax_candidate = tmax
valid_count = 1
elif state == "candidate":
# look for 3 consecutive matches with about the same timing offset
if candidate and np.abs(tmax-tmax_candidate) < 0.02*M:
valid_count = valid_count + 1
if valid_count > 3:
next_state = "sync"
acquired = True
synced_count = 0
uw_fail = False
# we're in sync, so check we can still see pilots and run receiver
ffine_range = np.arange(self.fmax-1,self.fmax+1,0.1)
tfine_range = np.arange(max(0,self.tmax-8),self.tmax+8)
self.tmax,fmax_hat = acq.refine(rx_buf, self.tmax, self.fmax, tfine_range, ffine_range)
self.fmax = 0.9*self.fmax + 0.1*fmax_hat
candidate,endofover = acq.check_pilots(rx_buf,self.tmax,self.fmax)

# handle timing slip when rx sample clock > tx sample clock
self.nin = Nmf
if self.tmax >= Nmf-M:
self.nin = Nmf + M
self.tmax -= M
#print("slip+", file=sys.stderr)
# handle timing slip when rx sample clock < tx sample clock
if self.tmax < M:
self.nin = Nmf - M
self.tmax += M
#print("slip-", file=sys.stderr)

self.synced_count += 1
if self.synced_count % synced_count_one_sec == 0:
if self.uw_errors > uw_error_thresh:
uw_fail = True
self.uw_errors = 0

if not endofover:
# correct frequency offset, note we preserve state of phase
# TODO do we need preserve state of phase? We're passing entire vector and there isn't any memory (I think)
w = 2*np.pi*self.fmax/Fs
for n in range(Nmf+M+Ncp):
self.rx_phase = self.rx_phase*np.exp(-1j*w)
rx_phase_vec[n] = self.rx_phase
rx1 = rx_buf[self.tmax-Ncp:self.tmax-Ncp+Nmf+M+Ncp]
#print(tmax-Ncp, tmax-Ncp+Nmf+M+Ncp,rx_buf.shape, rx1.shape, rx_phase_vec.shape, file=sys.stderr)
rx = torch.tensor(rx1*rx_phase_vec, dtype=torch.complex64)
# run through RADAE receiver DSP
z_hat = receiver.receiver_one(rx)
# decode z_hat to features
assert(z_hat.shape[1] == model.Nzmf)
features_hat = model.core_decoder_statefull(z_hat)
if auxdata:
uw_errors = 0
valid_count = Nmf_unsync
ffine_range = np.arange(fmax-10,fmax+10,0.25)
tfine_range = np.arange(tmax-1,tmax+2)
tmax,fmax = acq.refine(rx_buf, tmax, fmax, tfine_range, ffine_range)
else:
next_state = "search"
elif state == "sync":
# during some tests it's useful to disable these unsync features
unsync_enable = True
symb_repeat = 4
aux_symb = features_hat[:,:,20].detach().numpy()
aux_bits = 1*(aux_symb[0,::symb_repeat] > 0)
features_hat = features_hat[:,:,0:20]
self.uw_errors += np.sum(aux_bits)
# add unused features and send to stdout
features_hat = torch.cat([features_hat, torch.zeros_like(features_hat)[:,:,:16]], dim=-1)
features_hat = features_hat.cpu().detach().numpy().flatten().astype('float32')
np.copyto(features_out, features_hat)
valid_output = True

if v == 2 or (v == 1 and (self.state == "search" or self.state == "candidate" or prev_state == "candidate")):
print(f"{self.mf:3d} state: {self.state:10s} valid: {candidate:d} {endofover:d} {self.valid_count:2d} Dthresh: {acq.Dthresh:8.2f} ", end='', file=sys.stderr)
print(f"Dtmax12: {acq.Dtmax12:8.2f} {acq.Dtmax12_eoo:8.2f} tmax: {self.tmax:4d} fmax: {self.fmax:6.2f}", end='', file=sys.stderr)
if auxdata and self.state == "sync":
print(f" aux: {aux_bits:} uw_err: {self.uw_errors:d}", file=sys.stderr)
else:
print("",file=sys.stderr)

if candidate:
valid_count = Nmf_unsync
else:
valid_count -= 1
if unsync_enable and valid_count == 0:
# iterate state machine
next_state = self.state
prev_state = self.state
if self.state == "search":
if candidate:
next_state = "candidate"
self.tmax_candidate = self.tmax
self.valid_count = 1
elif self.state == "candidate":
# look for 3 consecutive matches with about the same timing offset
if candidate and np.abs(self.tmax-self.tmax_candidate) < 0.02*M:
self.valid_count = self.valid_count + 1
if self.valid_count > 3:
next_state = "sync"
self.synced_count = 0
uw_fail = False
if auxdata:
self.uw_errors = 0
self.valid_count = Nmf_unsync
ffine_range = np.arange(self.fmax-10,self.fmax+10,0.25)
tfine_range = np.arange(self.tmax-1,self.tmax+2)
self.tmax,self.fmax = acq.refine(rx_buf, self.tmax, self.fmax, tfine_range, ffine_range)
else:
next_state = "search"
elif self.state == "sync":
# during some tests it's useful to disable these unsync features
unsync_enable = True

if unsync_enable and (endofover or uw_fail):
next_state = "search"
if candidate:
self.valid_count = Nmf_unsync
else:
self.valid_count -= 1
if unsync_enable and self.valid_count == 0:
next_state = "search"

state = next_state
mf += 1
if unsync_enable and (endofover or uw_fail):
next_state = "search"

self.state = next_state
self.mf += 1

return valid_output

if __name__ == '__main__':
rx = radae_rx()

# TODO put this size in a getter func
features_out = np.zeros(model.Nzmf*model.dec_stride*nb_total_features,dtype=np.float32)
while True:
buffer = sys.stdin.buffer.read(rx.nin*struct.calcsize("ff"))
if len(buffer) != rx.nin*struct.calcsize("ff"):
break
buffer_complex = np.frombuffer(buffer,np.csingle)
valid_output = rx.do_radae_rx(buffer_complex, features_out)
if valid_output:
print(features_out.shape, file=sys.stderr)
sys.stdout.buffer.write(features_out)

0 comments on commit 96b977a

Please sign in to comment.