train with phase based loss function; classical lin DSP option for comparison
drowe67 committed Jan 27, 2025
1 parent 6e65d7a commit 772c237
Showing 1 changed file with 35 additions and 13 deletions.
48 changes: 35 additions & 13 deletions ml_eq.py
@@ -15,15 +15,16 @@
parser.add_argument('--EbNodB', type=float, default=100, help='energy per bit over spectral noise density in dB')
parser.add_argument('--epochs', type=int, default=10, help='number of training epochs')
parser.add_argument('--lr', type=float, default=5E-2, help='learning rate')
parser.add_argument('--loss_phase', action='store_true', help='train with phase based loss function')
parser.add_argument('--phase_offset', action='store_true', help='insert random phase offset')
parser.add_argument('--bypass_eq', action='store_true', help='bypass equaliser')
parser.add_argument('--eq', type=str, default='ml', help='equaliser ml/bypass/lin (default ml)')
parser.add_argument('--notrain', action='store_false', dest='train', help='bypass training (default train, then inference)')
parser.set_defaults(train=True)
args = parser.parse_args()
n_syms = args.n_syms

bps = 2
batch_size = 4
batch_size = 16
w1 = 32
n_pilots = 2
n_data = 1
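
With the new --eq option the three equaliser paths are selected from the command line. As an illustrative sketch only: the flags below are the ones visible in this hunk, the EbNodB value is arbitrary, and the leading 10000 assumes n_syms is the positional argument implied by args.n_syms above (not shown in this diff).

python3 ml_eq.py 10000 --eq lin --phase_offset --EbNodB 4
python3 ml_eq.py 10000 --eq ml --loss_phase --epochs 20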
@@ -53,9 +54,9 @@ def __getitem__(self, index):
# Generalised network for equalisation, we provide n_pilots and n_data symbols,
# hopefully network will determine which pilots EQ which data symbols. We
# feed the data symbols into each layer so they are available at the end for
# final EQ, DenseNet style. Hopefully pilot and data symbol information will be
# final EQ, DenseNet style. Hopefully both pilot and data symbol information will be
# used in EQ process. Give it a few layers to approximate non-linear functions
# like trig and arg[].
# like cos, sin and arg[] that are used in classical DSP estimation.
class EQ(nn.Module):
def __init__(self, n_pilots, n_data, EbNodB):
super().__init__()
@@ -70,7 +71,7 @@ def __init__(self, n_pilots, n_data, EbNodB):
self.dense5 = nn.Linear(w1+self.n_data, self.n_data)

# note complex values passed in as real,imag pairs
def equaliser(self, pilots,data):
def equaliser(self, pilots, data):
x = torch.relu(self.dense1(torch.cat([pilots, data],-1)))
x = torch.relu(self.dense2(torch.cat([x, data],-1)))
x = torch.relu(self.dense3(torch.cat([x, data],-1)))
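
To make the DenseNet-style structure concrete, a minimal sketch of the whole network under stated assumptions: only dense1-dense3 and dense5 are visible in this diff, so dense4 and the dense1 input width are guesses, and the EbNodB handling is omitted.

import torch
import torch.nn as nn

class EQSketch(nn.Module):
    def __init__(self, n_pilots, n_data, w1=32):
        super().__init__()
        # n_pilots and n_data count real values (complex symbols as real,imag pairs)
        self.dense1 = nn.Linear(n_pilots + n_data, w1)
        self.dense2 = nn.Linear(w1 + n_data, w1)
        self.dense3 = nn.Linear(w1 + n_data, w1)
        self.dense4 = nn.Linear(w1 + n_data, w1)        # assumed hidden layer
        self.dense5 = nn.Linear(w1 + n_data, n_data)    # linear output assumed

    def forward(self, pilots, data):
        # the data symbols are re-concatenated at every layer, DenseNet style,
        # so the final layer still sees them directly when forming the output
        x = torch.relu(self.dense1(torch.cat([pilots, data], -1)))
        x = torch.relu(self.dense2(torch.cat([x, data], -1)))
        x = torch.relu(self.dense3(torch.cat([x, data], -1)))
        x = torch.relu(self.dense4(torch.cat([x, data], -1)))
        return self.dense5(torch.cat([x, data], -1))

model = EQSketch(n_pilots=4, n_data=2)               # 2 pilots, 1 data symbol
out = model(torch.randn(8, 4), torch.randn(8, 2))    # -> (8, 2) equalised real,imag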
@@ -107,26 +108,47 @@ def forward(self, tx_data):
rx_data[:,1] = rx_frame[:,1].imag

# run equaliser
if args.bypass_eq:
if args.eq == "bypass":
rx_data_eq = rx_data
else:
if args.eq == "ml":
rx_data_eq = self.equaliser(rx_pilots, rx_data)

if args.eq == "lin":
# classical linear DSP: average the received pilots to estimate the channel phase
pilot_sum = torch.sum(rx_pilots[:,::2] + 1j*rx_pilots[:,1::2],dim=1)
phase_est = torch.angle(pilot_sum)
# de-rotate the data symbol by the estimated phase
x = rx_frame[:,1]*torch.exp(-1j*phase_est)

rx_data_eq = torch.zeros((batch_size,2*n_data), device=tx_data.device)
rx_data_eq[:,0] = x.real
rx_data_eq[:,1] = x.imag

# real version of tx_data symbol for loss function
tx_data_real = torch.zeros((batch_size,2*n_data), device=tx_data.device)
tx_data_real[:,0] = tx_data[:,0].real
tx_data_real[:,1] = tx_data[:,0].imag
tx_data_real = tx_data_real.to(device)

return tx_data_real,rx_data, rx_data_eq
return tx_data_real,rx_data,rx_data_eq

# sym and sym_hat are complex symbols passed as real,imag pairs
def loss_phase_mse(sym_hat, sym):
sym = sym[:,0] + 1j*sym[:,1]
sym_hat = sym_hat[:,0] + 1j*sym_hat[:,1]
error = torch.angle(sym*torch.conj(sym_hat))
loss = torch.sum(error**2)
return loss

model = EQ(2*n_pilots, 2*n_data, args.EbNodB).to(device)
print(model)
nb_params = sum(p.numel() for p in model.parameters())
print(f" {nb_params} weights")

if args.train:
loss_fn = nn.MSELoss(reduction='sum')
if args.loss_phase:
loss_fn = loss_phase_mse
else:
loss_fn = nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)

dataset = aQPSKDataset(n_syms)
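
A standalone sanity check of the --eq lin phase estimator added above. The true phases are hypothetical, and the pilots are assumed to be transmitted as 1+0j, which this diff does not show.

import torch

batch_size, n_pilots = 4, 2
true_phase = torch.tensor([0.3, -1.2, 2.5, 0.0])

# received pilots: transmitted 1+0j rotated by the channel phase
rx_pilots = torch.exp(1j * true_phase).unsqueeze(1).repeat(1, n_pilots)

# summing before angle() averages the pilots, which lowers the variance
# of the phase estimate when noise is present
phase_est = torch.angle(torch.sum(rx_pilots, dim=1))
print(phase_est)    # tensor([ 0.3000, -1.2000,  2.5000,  0.0000])

# de-rotate a data symbol with the estimate, as in the lin branch above
tx_sym = torch.tensor(0.7071 + 0.7071j)
rx_sym_eq = tx_sym * torch.exp(1j * true_phase) * torch.exp(-1j * phase_est)
print(rx_sym_eq)    # approximately 0.7071+0.7071j in every row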
@@ -138,7 +160,7 @@ def forward(self, tx_data):
for batch,(tx_data) in enumerate(dataloader):
tx_data = tx_data.to(device)
tx_data_real,rx_data,rx_data_eq = model(tx_data)
loss = loss_fn(tx_data_real, rx_data_eq)
loss = loss_fn(rx_data_eq,tx_data_real)
loss.backward()
optimizer.step()
optimizer.zero_grad()
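
A minimal check of what the new phase loss measures, with loss_phase_mse copied from the patch above and hypothetical symbols: a pure magnitude error gives zero phase loss but non-zero MSE.

import torch

def loss_phase_mse(sym_hat, sym):    # as defined in the patch above
    sym = sym[:,0] + 1j*sym[:,1]
    sym_hat = sym_hat[:,0] + 1j*sym_hat[:,1]
    return torch.sum(torch.angle(sym*torch.conj(sym_hat))**2)

sym     = torch.tensor([[0.7071, 0.7071]])   # target QPSK symbol (real,imag)
sym_hat = torch.tensor([[1.4142, 1.4142]])   # correct phase, twice the magnitude

print(torch.nn.MSELoss(reduction='sum')(sym_hat, sym))   # tensor(1.) -- magnitude error penalised
print(loss_phase_mse(sym_hat, sym))                      # tensor(0.) -- same phase, zero phase loss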
@@ -154,7 +176,7 @@ def forward(self, tx_data):


# Inference using trained model (or non-ML sim if --eq bypass/lin)
n_syms_inf = 1000
n_syms_inf = args.n_syms
model.eval()
bits = torch.sign(torch.rand(n_syms_inf, bps)-0.5)
tx_data = (bits[:,::2] + 1j*bits[:,1::2])/np.sqrt(2.0)
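
The bit-to-symbol mapping in isolation, as a minimal sketch with 4 symbols at bps=2:

import torch
import numpy as np

bits = torch.sign(torch.rand(4, 2) - 0.5)    # random +/-1 bits, 2 bits per symbol
tx_data = (bits[:,::2] + 1j*bits[:,1::2]) / np.sqrt(2.0)
print(tx_data.abs())    # all 1.0: each +/-1,+/-1 pair becomes a unit-energy QPSK symbol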
@@ -169,7 +191,7 @@ def forward(self, tx_data):
n_errors = np.sum(-tx_data_real.flatten()*rx_data_eq.flatten()>0)
n_bits = n_syms_inf*bps
BER = n_errors/n_bits
print(f"n_bits: {n_bits:d} BER: {BER:5.3f}")
print(f"n_bits: {n_bits:d} n_errors: {n_errors:d} BER: {BER:5.3f}")
plt.plot(rx_data[:,0],rx_data[:,1],'+')
plt.plot(rx_data_eq[:,0],rx_data_eq[:,1],'+')
plt.axis([-2,2,-2,2])
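
And a toy check of the sign-based error count used for the BER, with hypothetical received values:

import numpy as np

tx = np.array([ 0.7071, -0.7071,  0.7071, -0.7071])   # transmitted real/imag components
rx = np.array([ 0.6,    -0.5,    -0.1,     0.4   ])   # last two flipped by noise

# a bit error occurs when transmitted and equalised components disagree in sign
n_errors = np.sum(-tx*rx > 0)
print(n_errors)    # 2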