Skip to content

Commit

Permalink
plot S2_genie v S2_est, ability to add offset
Browse files Browse the repository at this point in the history
  • Loading branch information
drowe67 committed Jan 5, 2025
1 parent 9481435 commit c9c306c
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 30 deletions.
102 changes: 73 additions & 29 deletions est_snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,21 +62,23 @@ def snr_est_test(model, snr_target, h, Nw, test_S1=False):
# matrix of received pilots plus noise samples
Pcn_hat = h*Pcn + n

Rcn_hat_genie = np.abs(h)*Pcn + n
Ns = model.Ns + 1
rx_sym_pilots = torch.zeros((1,1,Nw*Ns,Nc), dtype=torch.complex64)
rx_sym_pilots[0,0,::Ns,:] = torch.tensor(Pcn_hat)
rx_pilots = receiver.est_pilots(rx_sym_pilots, Nw-1, Nc, Ns)
rx_pilots = rx_pilots.cpu().detach().numpy()
rx_phase = np.angle(rx_pilots)
Rcn_hat_est = Pcn_hat *np.exp(-1j*rx_phase)

# phase corrected received pilots
genie_phase = not args.eq_ls

if genie_phase:
Rcn_hat = np.abs(h)*Pcn + n
Rcn_hat = Rcn_hat_genie
else:
Ns = model.Ns + 1
rx_sym_pilots = torch.zeros((1,1,Nw*Ns,Nc), dtype=torch.complex64)
rx_sym_pilots[0,0,::Ns,:] = torch.tensor(Pcn_hat)
rx_pilots = receiver.est_pilots(rx_sym_pilots, Nw-1, Nc, Ns)
rx_pilots = rx_pilots.cpu().detach().numpy()
rx_phase = np.angle(rx_pilots)
#print(rx_phase.shape)
#print(rx_phase)
Rcn_hat = Pcn_hat *np.exp(-1j*rx_phase)

Rcn_hat = Rcn_hat_est

if args.plots:
plt.figure(1)
plt.plot(Rcn_hat.real, Rcn_hat.imag,'b+')
Expand All @@ -96,17 +98,25 @@ def snr_est_test(model, snr_target, h, Nw, test_S1=False):
print(f"S1: {S1:5.2f} S1_sum: {S1_sum:5.2f}")

# calculate S2 and SNR est
S2 = np.sum(np.abs(Rcn_hat.imag)**2)
snr_est = S1/(2*S2) - 1

# actual snr as check, for AWGN should be close to snr_target, for non untity h
# it can be quite different

S2_genie = np.sum(np.abs(Rcn_hat_genie.imag)**2)
S2_est = np.sum(np.abs(Rcn_hat_est.imag)**2)
if genie_phase:
snr_est = S1/(2*S2_genie) - 1
else:
snr_est = S1/(2*S2_est) - 1

# remove occasional illegal values
if snr_est <= 0:
snr_est = 0.1

# actual snr for this time window as check, for AWGN should be
# close to snr_target, for non untity h it can be quite different
snr_check = np.sum(np.abs(h*Pcn)**2)/np.sum(np.abs(n)**2)
#print(f"S: {np.sum(np.abs(h*Pcn)**2):f} N: {np.sum(np.abs(n)**2)}")
#print(f"snr:target {snr_target:5.2f} snr_check: {snr_check.real:5.2f} snr_est: {snr_est:5.2f}")

return snr_est,snr_check
# user supplied correction factor
snr_est *= 10 ** (args.offset/10)

return snr_est, snr_check, S2_genie, S2_est

# Bring up a RADAE model
latent_dim = 80
Expand All @@ -122,47 +132,79 @@ def snr_est_test(model, snr_target, h, Nw, test_S1=False):

# single timestep test
def single(snrdB, h, Nw, test_S1):
snr_est, snr_check = snr_est_test(model, 10**(snrdB/10), h, Nw, test_S1)
snr_est, snr_check, S2_genie, S2_est = snr_est_test(model, 10**(snrdB/10), h, Nw, test_S1)
print(f"snrdB: {snrdB:5.2f} snrdB_check: {10*np.log10(snr_check):5.2f} snrdB_est: {10*np.log10(snr_est):5.2f}")

# run over a sequence of timesteps, and return lists of each each est
def sequence(Ntimesteps, snrdB, h, Nw):
snrdB_est_list = []
snrdB_check_list = []
NdB_genie_list = []
NdB_est_list = []

for i in range(Ntimesteps):
snr_est, snr_check = snr_est_test(model, 10**(snrdB/10), h[i*Nw:(i+1)*Nw,:], Nw)
snr_est, snr_check, S2_genie, S2_est = snr_est_test(model, 10**(snrdB/10), h[i*Nw:(i+1)*Nw,:], Nw)
snrdB_check = 10*np.log10(snr_check)
snrdB_est = 10*np.log10(snr_est)
print(f"snrdB: {snrdB:5.2f} snrdB_check: {snrdB_check:5.2f} snrdB_est: {snrdB_est:5.2f}")
NdB_genie = 10*np.log10(2*S2_genie)
NdB_est = 10*np.log10(2*S2_est)

print(f"snrdB: {snrdB:5.2f} snrdB_check: {snrdB_check:5.2f} snrdB_est: {snrdB_est:5.2f} NdB: {NdB_genie:5.2f} {NdB_est:5.2f}")

snrdB_est_list = np.append(snrdB_est_list, snrdB_est)
snrdB_check_list = np.append(snrdB_check_list, snrdB_check)
NdB_genie_list = np.append(NdB_genie_list, NdB_genie)
NdB_est_list = np.append(NdB_est_list, NdB_est)

return snrdB_est_list, snrdB_check_list
return snrdB_est_list, snrdB_check_list, NdB_genie,NdB_est

# sweep across SNRs
def sweep(Ntimesteps, h, Nw):

EsNodB_check = []
EsNodB_est = []
NdB_genie = []
NdB_est = []

r = range(-5,20)
for aEsNodB in r:
aEsNodB_check, aEsNodB_est = sequence(Ntimesteps, aEsNodB, h, Nw)
aEsNodB_check, aEsNodB_est, aNdB_genie, aNdB_est = sequence(Ntimesteps, aEsNodB, h, Nw)
EsNodB_check = np.append(EsNodB_check, aEsNodB_check)
EsNodB_est = np.append(EsNodB_est, aEsNodB_est)
NdB_genie = np.append(NdB_genie, aNdB_genie)
NdB_est = np.append(NdB_est, aNdB_est)

z = np.polyfit(EsNodB_check, EsNodB_est, 1)
print(z)
EsNodB_est_fit = z[0]*EsNodB_check + z[1]

plt.figure(1)
plt.plot(EsNodB_check, EsNodB_est,'b+')
plt.plot(EsNodB_check, EsNodB_est_fit,'r')
plt.plot(r,r)
plt.axis([-5, 20, -5, 20])
plt.grid()
plt.xlabel('SNR (dB)')
plt.ylabel('SNR est (dB)')

z = np.polyfit(NdB_genie, NdB_est, 1)
print(z)
NdB_est_fit = z[0]*NdB_genie + z[1]
print(len(NdB_est))
plt.figure(2)
plt.plot(NdB_genie, NdB_est,'b+')
plt.plot(NdB_genie,NdB_genie)
plt.plot(NdB_genie, NdB_est_fit,'r')
plt.grid()
plt.xlabel('N_genie (dB)')
plt.ylabel('N_est (dB)')

plt.show()

# save test file of test points for Latex plotting in Octave radae_plots.m:est_snr_plot()
test_points = np.transpose(np.array((EsNodB_check,EsNodB_est)))
np.savetxt('est_snr.txt',test_points,delimiter='\t')
if args.save_text:
# save test file of test points for Latex plotting in Octave radae_plots.m:est_snr_plot()
test_points = np.transpose(np.array((EsNodB_check,EsNodB_est)))
np.savetxt(args.save_text,test_points,delimiter='\t')

parser = argparse.ArgumentParser()
parser.add_argument('--snrdB', type=float, default=10.0, help='snrdB set point')
Expand All @@ -172,8 +214,10 @@ def sweep(Ntimesteps, h, Nw):
parser.add_argument('-T', type=float, default=1.0, help='length of time window for estimate (default 1.0 sec)')
parser.add_argument('--Nt', type=int, default=1, help='number of analysis time windows to test across (default 1)')
parser.add_argument('--test_S1', action='store_true', help='calculate S1 two ways to check S1 expression')
parser.add_argument('--eq_ls', action='store_true', help='est phase from received pilots usin least square (default genie phase)')
parser.add_argument('--eq_ls', action='store_true', help='est phase from received pilots using least square (default genie phase)')
parser.add_argument('--plots', action='store_true', help='debug plots (default off)')
parser.add_argument('--save_text', type=str, default="", help='path to text file to save test points')
parser.add_argument('--offset', type=float, default=0.0, help='y offset correction in dB (default 0)')
args = parser.parse_args()

Nw = int(args.T // model.Tmf)
Expand Down
3 changes: 2 additions & 1 deletion multipath_samples.m
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ function multipath_samples(ch, Fs, Rs, Nc, Nseconds, H_fn, G_fn="",H_complex=0)
else
bytes_per_sample = 4
end
printf("H file size is Nseconds*Rs*Nc*(%d bytes/sample) = %d*%d*%d*%d = %d bytes\n", bytes_per_sample, Nseconds,Rs,Nc,Nseconds*Rs*Nc*bytes_per_sample)
printf("H file size is Nseconds*Rs*Nc*(%d bytes/sample) = %d*%d*%d*%d = %d bytes\n", bytes_per_sample,
Nseconds,Rs,Nc,bytes_per_sample, Nseconds*Rs*Nc*bytes_per_sample)
f=fopen(H_fn,"wb");
[r c] = size(H);
Hflat = reshape(H', 1, r*c);
Expand Down

0 comments on commit c9c306c

Please sign in to comment.