Skip to content

Commit

Permalink
moving init code into class
Browse files Browse the repository at this point in the history
  • Loading branch information
drowe67 committed Sep 25, 2024
1 parent 301fb93 commit a106bed
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions embed/radae_rx.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
from radae import RADAE,complex_bpf,acquisition,receiver_one

# Hard code all this for now to avoid arg passing complexities TODO: consider a way to pass in at init-time
model_name = "../model19_check3/checkpoints/checkpoint_epoch_100.pth"
latent_dim = 80
auxdata = True
bottleneck = 3
Expand All @@ -57,19 +56,15 @@
device = torch.device("cpu")

nb_total_features = 36
num_features = 20
num_used_features = 20
num_features = 20
if auxdata:
num_features += 1

# load model from a checkpoint file
model = RADAE(num_features, latent_dim, EbNodB=100, rate_Fs=True,
pilots=True, pilot_eq=True, eq_mean6 = False, cyclic_prefix=0.004,
coarse_mag=True,time_offset=-16, bottleneck=bottleneck)
checkpoint = torch.load(model_name, map_location='cpu',weights_only=True)
model.load_state_dict(checkpoint['state_dict'], strict=False)
# Stateful decoder wasn't present during training, so we need to load weights from existing decoder
model.core_decoder_statefull_load_state_dict()
model.eval()

# check a bunch of model options we rely on for receiver to work
Expand Down Expand Up @@ -112,11 +107,10 @@
# extra Ncp at end so we can handle timing slips
rx_buf = np.zeros(2*Nmf+M+Ncp,np.csingle)
rx = np.zeros(0,np.csingle)
rx_phase_vec = np.zeros(Nmf+M+Ncp,np.csingle)
z_hat_log = torch.zeros(0,model.Nzmf,model.latent_dim)

class radae_rx:
def __init__(self):
def __init__(self,model_name):
self.nin = Nmf
self.state = "search"
self.tmax_candidate = 0
Expand All @@ -126,6 +120,11 @@ def __init__(self):
self.synced_count = 0
self.rx_phase = 1 + 1j*0

checkpoint = torch.load(model_name, map_location='cpu',weights_only=True)
model.load_state_dict(checkpoint['state_dict'], strict=False)
# Stateful decoder wasn't present during training, so we need to load weights from existing decoder
model.core_decoder_statefull_load_state_dict()

def do_radae_rx(self, buffer_complex, features_out):
with torch.inference_mode():
prev_state = self.state
Expand Down Expand Up @@ -168,12 +167,13 @@ def do_radae_rx(self, buffer_complex, features_out):
# correct frequency offset, note we preserve state of phase
# TODO do we need preserve state of phase? We're passing entire vector and there isn't any memory (I think)
w = 2*np.pi*self.fmax/Fs
rx_phase_vec = np.zeros(Nmf+M+Ncp,np.csingle)
for n in range(Nmf+M+Ncp):
self.rx_phase = self.rx_phase*np.exp(-1j*w)
rx_phase_vec[n] = self.rx_phase
rx1 = rx_buf[self.tmax-Ncp:self.tmax-Ncp+Nmf+M+Ncp]
#print(tmax-Ncp, tmax-Ncp+Nmf+M+Ncp,rx_buf.shape, rx1.shape, rx_phase_vec.shape, file=sys.stderr)
rx = torch.tensor(rx1*rx_phase_vec, dtype=torch.complex64)

# run through RADAE receiver DSP
z_hat = receiver.receiver_one(rx)
# decode z_hat to features
Expand All @@ -185,7 +185,8 @@ def do_radae_rx(self, buffer_complex, features_out):
aux_bits = 1*(aux_symb[0,::symb_repeat] > 0)
features_hat = features_hat[:,:,0:20]
self.uw_errors += np.sum(aux_bits)
# add unused features and send to stdout

# add unused features and output
features_hat = torch.cat([features_hat, torch.zeros_like(features_hat)[:,:,:16]], dim=-1)
features_hat = features_hat.cpu().detach().numpy().flatten().astype('float32')
np.copyto(features_out, features_hat)
Expand Down Expand Up @@ -243,16 +244,15 @@ def do_radae_rx(self, buffer_complex, features_out):
return valid_output

if __name__ == '__main__':
rx = radae_rx()
rx = radae_rx(model_name = "../model19_check3/checkpoints/checkpoint_epoch_100.pth")

# TODO put this size in a getter func
features_out = np.zeros(model.Nzmf*model.dec_stride*nb_total_features,dtype=np.float32)
while True:
# TODO put this size in a getter func
features_out = np.zeros(model.Nzmf*model.dec_stride*nb_total_features,dtype=np.float32)
while True:
buffer = sys.stdin.buffer.read(rx.nin*struct.calcsize("ff"))
if len(buffer) != rx.nin*struct.calcsize("ff"):
break
buffer_complex = np.frombuffer(buffer,np.csingle)
valid_output = rx.do_radae_rx(buffer_complex, features_out)
if valid_output:
print(features_out.shape, file=sys.stderr)
sys.stdout.buffer.write(features_out)

0 comments on commit a106bed

Please sign in to comment.