Skip to content

Commit

Permalink
sequences of multiple time windows and curves working
Browse files Browse the repository at this point in the history
  • Loading branch information
drowe67 committed Jan 3, 2025
1 parent e904e18 commit c1a4458
Showing 1 changed file with 44 additions and 66 deletions.
110 changes: 44 additions & 66 deletions est_snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
os.environ['CUDA_VISIBLE_DEVICES'] = ""
device = torch.device("cpu")

def snr_est_test(model, snr_target, h, Nw):
def snr_est_test(model, snr_target, h, Nw, test_S1=False):

Nc = model.Nc
Pc = np.array(model.pilot_gain*model.P)
Expand All @@ -58,35 +58,34 @@ def snr_est_test(model, snr_target, h, Nw):
No = Es/snr_target # noise per symbol
sigma = np.sqrt(No)/(2**0.5)
n = sigma*(np.random.normal(size=(Nw,Nc)) + 1j*np.random.normal(size=(Nw,Nc)))
print(Es,No,sigma,np.var(n),np.var(Pcn), np.sum(np.abs(Pcn)**2))

# matrix of received pilots plus noise samples
Pcn_hat = h*Pcn + n

# phase corrected received pilots
Rcn_hat = np.abs(h)*Pcn + n

# calculate S1 two ways to test expression
# calculate S1 two ways to test expression, observe second term is small

S1 = np.sum(np.abs(Pcn_hat)**2)
S1_first = np.sum(np.abs(h*Pcn)**2)
S1_second = np.sum(2*(h*Pcn*n).real)
#S1_second = np.sum(h*Pcn*np.conj(n) + np.conj(h*Pcn)*n)
S1_third = np.sum(np.abs(n)**2)
S1_sum = S1_first + S1_second + S1_third
print(f"S1_first: {S1_first:5.2f} S1_second: {S1_second:5.2f} S1_third: {S1_third:5.2f}")
print(f"S1: {S1:5.2f} S1_sum: {S1_sum:5.2f}")

# calculate SNR est
S2 = np.sum(np.abs(Rcn_hat.imag)**2)
print(S1, S2)
#print(f"S1: {S1:f} S2: {S2:f}")

if test_S1:
S1_first = np.sum(np.abs(h*Pcn)**2)
S1_second = np.sum(2*(h*Pcn*n).real)
S1_third = np.sum(np.abs(n)**2)
S1_sum = S1_first + S1_second + S1_third
print(f"S1_first: {S1_first:5.2f} S1_second: {S1_second:5.2f} S1_third: {S1_third:5.2f}")
print(f"S1: {S1:5.2f} S1_sum: {S1_sum:5.2f}")

# calculate S2 and SNR est
S2 = np.sum(np.abs(Rcn_hat.imag)**2)
snr_est = S1/(2*S2) - 1

# actual snr as check, should be same as snr_target
# actual snr as check, for AWGN should be close to snr_target, for non untity h
# it can be quite different

snr_check = np.sum(np.abs(h*Pcn)**2)/np.sum(np.abs(n)**2)
print(f"S: {np.sum(np.abs(h*Pcn)**2):f} N: {np.sum(np.abs(n)**2)}")
print(f"snr:target {snr_target:5.2f} snr_check: {snr_check.real:5.2f} snr_est: {snr_est:5.2f}")
#print(f"S: {np.sum(np.abs(h*Pcn)**2):f} N: {np.sum(np.abs(n)**2)}")
#print(f"snr:target {snr_target:5.2f} snr_check: {snr_check.real:5.2f} snr_est: {snr_est:5.2f}")

return snr_est,snr_check

Expand All @@ -97,56 +96,44 @@ def snr_est_test(model, snr_target, h, Nw):
model = RADAE(num_features, latent_dim, EbNodB=100, rate_Fs=True, pilots=True, cyclic_prefix=0.004, bottleneck=3)

# single timestep test
def single(snrdB, h, Nw):
snr_est, snr_check = snr_est_test(model, 10**(snrdB/10), h, Nw)
#print(f"snrdB: {snrdB:5.2f} snrdB_check: {10*np.log10(snr_check):5.2f} snrdB_est: {10*np.log10(snr_est):5.2f}")
def single(snrdB, h, Nw, test_S1):
snr_est, snr_check = snr_est_test(model, 10**(snrdB/10), h, Nw, test_S1)
print(f"snrdB: {snrdB:5.2f} snrdB_check: {10*np.log10(snr_check):5.2f} snrdB_est: {10*np.log10(snr_est):5.2f}")

# run over a sequence of timesteps
def sequence(Ntimesteps, EsNodB, h):
rx_sym = np.zeros((Ntimesteps,model.Nc),dtype=np.csingle)

print(np.mean(h[:Ntimesteps,:]**2))
#sum_EsNodB = 0
#sum_EsNodB_est = 0
sum_Ct_sq = 0
# run over a sequence of timesteps, and return mean
def sequence(Ntimesteps, snrdB, h, Nw):
sum_snrdB_est = 0
sum_snrdB_check = 0

for i in range(Ntimesteps):
Ct_sq, arx_sym = snr_est_test(model, 10**(EsNodB/10), h[i,:])
#print(f"EsNodB_actual: {10*np.log10(EsNo_actual):5.2f} EsNodB_est: {10*np.log10(EsNo_est):5.2f}")
#sum_EsNodB += 10*np.log10(EsNo_actual)
#sum_EsNodB_est += 10*np.log10(EsNo_est)
sum_Ct_sq += Ct_sq.real
rx_sym[i,:] = arx_sym
snr_est, snr_check = snr_est_test(model, 10**(snrdB/10), h[i*Nw:(i+1)*Nw,:], Nw)
snrdB_check = 10*np.log10(snr_check)
snrdB_est = 10*np.log10(snr_est)
print(f"snrdB: {snrdB:5.2f} snrdB_check: {snrdB_check:5.2f} snrdB_est: {snrdB_est:5.2f}")
sum_snrdB_est += snrdB_est
sum_snrdB_check += snrdB_check

Ct_sq = sum_Ct_sq/Ntimesteps
P = np.array(model.pilot_gain*model.P)
EsNo_est = Ct_sq/(np.dot(np.conj(P),P) - Ct_sq)
print(Ct_sq, EsNo_est)
EsNodB_est = 10*np.log10(EsNo_est)
print(f"EsNodB: {EsNodB:5.2f} EsNodB_est: {EsNodB_est:5.2f}")

return EsNodB_est, rx_sym
return sum_snrdB_check/Ntimesteps, sum_snrdB_est/Ntimesteps

# sweep across SNRs
def sweep(Ntimesteps, h):
def sweep(Ntimesteps, h, Nw):

EsNodB = []
EsNodB_check = []
EsNodB_est = []
r = range(-5,15)
for aEsNodB in r:
aEsNodB_est, tx_sym = sequence(Ntimesteps, aEsNodB, h)
EsNodB = np.append(EsNodB, aEsNodB)
aEsNodB_check, aEsNodB_est = sequence(Ntimesteps, aEsNodB, h, Nw)
EsNodB_check = np.append(EsNodB_check, aEsNodB_check)
EsNodB_est = np.append(EsNodB_est, aEsNodB_est)

plt.figure(1)
plt.plot(EsNodB, EsNodB_est,'b+')
plt.plot(EsNodB_check, EsNodB_est,'b+')
plt.plot(r,r)
plt.grid()
plt.show()

# save test file of test points for Latex plotting in Octave radae_plots.m:est_snr_plot()
test_points = np.transpose(np.array((EsNodB,EsNodB_est)))
test_points = np.transpose(np.array((EsNodB_check,EsNodB_est)))
np.savetxt('est_snr.txt',test_points,delimiter='\t')

parser = argparse.ArgumentParser()
Expand All @@ -155,33 +142,24 @@ def sweep(Ntimesteps, h):
parser.add_argument('--sequence', action='store_true', help='run over a sequence of timesteps')
parser.add_argument('--h_file', type=str, default="", help='path to rate Rs multipath samples, rate Rs time steps by Nc carriers .f32 format')
parser.add_argument('-T', type=float, default=1.0, help='length of time window for estimate (default 1.0 sec)')
parser.add_argument('--Nt', type=int, default=1, help='number of analysis time windows to average (default 1)')
parser.add_argument('--test_S1', action='store_true', help='calculate S1 two ways to check S1 expression')
args = parser.parse_args()

Nw = int(args.T // model.Tmf)

if len(args.h_file):
h = np.fromfile(args.h_file,dtype=np.float32)
h = h.reshape((-1,model.Nc))
print(h.shape)
# sample once every modem frame
h = h[::model.Ns+1,:]
h = h[:Nw,:]
print(h.shape)
h = h[:Nw*args.Nt,:]
else:
h = 0.5*np.ones((Nw,model.Nc))
print(h.shape)
h = np.ones((Nw*args.Nt,model.Nc))

if args.single:
single(args.snrdB,h, Nw)
single(args.snrdB, h, Nw, args.test_S1)
elif args.sequence:
snrdB_est, rx_sym = sequence(args.snrdB, h, Nw)
#print(rx_sym.dtype, rx_sym.shape)

plt.figure(1)
plt.plot(rx_sym.real, rx_sym.imag, 'b+')
mx = np.max(np.abs(rx_sym))
mx = 10*np.ceil(mx/10)
plt.axis([-mx,mx,-mx,mx])
plt.show()
sequence(args.Nt, args.snrdB, h, Nw)
else:
sweep(args.Nt,h)
sweep(args.Nt,h, Nw)

0 comments on commit c1a4458

Please sign in to comment.