Skip to content

Commit

Permalink
Merge pull request #31 from drowe67/dr-reset
Browse files Browse the repository at this point in the history
Reset decoder states on resync
  • Loading branch information
drowe67 authored Oct 29, 2024
2 parents 63950bf + 4438249 commit 516f4e4
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 2 deletions.
2 changes: 1 addition & 1 deletion radae/dsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def do_pilot_eq_one(self, num_modem_frames, rx_sym_pilots):

# est ampl across one just two sets of pilots seems to work OK (loss isn't impacted)
if self.coarse_mag:
mag = torch.mean(torch.abs(rx_pilots)**2)**0.5
mag = torch.mean(torch.abs(rx_pilots)**2)**0.5 + 1E-6
if self.bottleneck == 3:
mag = mag*torch.abs(self.P[0])/self.pilot_gain
#print(f"coarse mag: {mag:f}", file=sys.stderr)
Expand Down
21 changes: 21 additions & 0 deletions radae/radae.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,14 @@ def __init__(self, input_dim, hidden_dim, batch_first):
self.hidden_dim = hidden_dim
self.states = torch.zeros(1,1,self.hidden_dim)
self.gru = nn.GRU(input_dim, hidden_dim, batch_first=batch_first)

def forward(self, x):
gru_out,self.states = self.gru(x,self.states)
return gru_out

def reset(self):
self.states = torch.zeros(1,1,self.hidden_dim)

# Wrapper for conv1D layer that maintains state internally
class Conv1DStatefull(nn.Module):
def __init__(self, input_dim, output_dim, dilation=1):
Expand All @@ -140,6 +144,9 @@ def forward(self, x):
conv_in = conv_in.permute(0, 2, 1)
return torch.tanh(self.conv(conv_in)).permute(0, 2, 1)

def reset(self):
self.states = torch.zeros(1,self.states_len,self.input_dim)

#Gated Linear Unit activation
class GLU(nn.Module):
def __init__(self, feat_size):
Expand Down Expand Up @@ -427,6 +434,20 @@ def forward(self, z):
features = torch.reshape(x,(1,z.shape[1]*self.FRAMES_PER_STEP,self.output_dim))
return features

def reset(self):
self.conv1.reset()
self.conv2.reset()
self.conv3.reset()
self.conv4.reset()
self.conv5.reset()

self.gru1.reset()
self.gru1.reset()
self.gru2.reset()
self.gru3.reset()
self.gru4.reset()
self.gru5.reset()

class RADAE(nn.Module):
def __init__(self,
feature_dim,
Expand Down
3 changes: 2 additions & 1 deletion radae_rxe.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def get_nin(self):

def get_sync(self):
return self.state == "sync"

def do_radae_rx(self, buffer_complex, features_out):
acq = self.acq
bpf = self.bpf
Expand Down Expand Up @@ -247,6 +247,7 @@ def do_radae_rx(self, buffer_complex, features_out):
self.valid_count = self.valid_count + 1
if self.valid_count > 3:
next_state = "sync"
model.core_decoder_statefull.module.reset()
self.synced_count = 0
uw_fail = False
if auxdata:
Expand Down

0 comments on commit 516f4e4

Please sign in to comment.