Skip to content

Commit

Permalink
mleq_cvurves.sh set up for different frames, loss_phase working again
Browse files Browse the repository at this point in the history
  • Loading branch information
drowe67 committed Jan 31, 2025
1 parent b4482a8 commit a77e1c9
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 18 deletions.
14 changes: 7 additions & 7 deletions ml_eq.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,13 @@
parser.add_argument('--load_model', type=str, default="", help='before inference, load model using this filename')
parser.add_argument('--curve', type=str, default="", help='before inference, load model using this filename')
parser.add_argument('--framer', type=int, default=1, help='framer design')
parser.add_argument('--batch_size', type=int, help="batch size, default: 32", default=32)
parser.set_defaults(train=True)
parser.set_defaults(plots=True)
args = parser.parse_args()
n_syms = args.n_syms

bps = 2
batch_size = 16
w1 = 32

# Get cpu, gpu or mps device for training.
device = (
Expand Down Expand Up @@ -237,10 +236,10 @@ def forward(self, tx_data):

# sym and sym hat in float format (real,imag pairs)
def loss_phase_mse(sym_hat, sym):
sym = sym[:,0] + 1j*sym[:,1]
sym_hat = sym_hat[:,0] + 1j*sym_hat[:,1]
sym = sym[:,0::2] + 1j*sym[:,1::2]
sym_hat = sym_hat[:,0::2] + 1j*sym_hat[:,1::2]
error = torch.angle(sym*torch.conj(sym_hat))
loss = torch.sum(error**2)
loss = torch.mean(error**2)
return loss

if args.framer == 1:
Expand All @@ -259,11 +258,12 @@ def loss_phase_mse(sym_hat, sym):
if args.loss_phase:
loss_fn = loss_phase_mse
else:
loss_fn = nn.MSELoss(reduction='sum')
loss_fn = nn.MSELoss(reduction='mean')

optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)

dataset = aQPSKDataset(n_syms, model.n_data)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size)

# Train model
for epoch in range(args.epochs):
Expand Down
33 changes: 22 additions & 11 deletions mleq_curves.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,36 @@
# ML EQ experiments - generates sets of curves

epochs=100
n_syms=100000
batch_size=128

function frame01_train() {
function train() {
f=$1
# train with mse and phase loss functions
python3 ml_eq.py --EbNodB 4 --phase_offset --lr 0.001 --epochs ${epochs} --save_model mleq01_mse.model --noplots
python3 ml_eq.py --EbNodB 4 --phase_offset --lr 0.001 --epochs ${epochs} --loss_phase --save_model mleq01_phase.model --noplots
python3 ml_eq.py --frame ${f} --EbNodB 4 --phase_offset \
--lr 0.05 --n_syms ${n_syms} --epochs ${epochs} --batch_size ${batch_size} --save_model mleq0${f}_mse.model --noplots
python3 ml_eq.py --frame ${f} --loss_phase --EbNodB 4 --phase_offset \
--lr 0.05 --n_syms ${n_syms} --epochs ${epochs} --batch_size ${batch_size} --save_model mleq0${f}_phase.model --noplots
}

function frame01_curve() {
function curve() {
f=$1
# run BER v Eb/No curves, including DSP lin as control
python3 ml_eq.py --eq dsp --notrain --phase_offset --curve mleq01_ber_lin.txt
python3 ml_eq.py --notrain --curve mleq01_ber_mse.txt --phase_offset --load_model mleq01_mse.model
python3 ml_eq.py --notrain --curve mleq01_ber_phase.txt --phase_offset --load_model mleq01_phase.model
python3 ml_eq.py --frame ${f} --notrain --eq dsp --phase_offset --curve mleq0${f}_ber_lin.txt
python3 ml_eq.py --frame ${f} --notrain --load_model mleq0${f}_mse.model --phase_offset --curve mleq0${f}_ber_mse.txt
python3 ml_eq.py --frame ${f} --notrain --load_model mleq0${f}_phase.model --phase_offset --curve mleq0${f}_ber_phase.txt

# generate EPS plots for paper
echo "radae_plots; plot_ber_EbNodB( \
'mleq01_ber_lin.txt','mleq01_ber_mse.txt','mleq01_ber_phase.txt', \
'mleq01_ber_EbNodB.png','mleq01_ber_EbNodB.eps')" | octave -qf
'mleq0${f}_ber_lin.txt','mleq0${f}_ber_mse.txt','mleq0${f}_ber_phase.txt', \
'mleq0${f}_ber_EbNodB.png','mleq0${f}_ber_EbNodB.eps')" | octave -qf
}

frame01_train
frame01_curve
if [ $# -ne 1 ]; then
echo "usage mleq_cruves.sh framer[1|2]"
exit 1
fi

train $1
curve $1

0 comments on commit a77e1c9

Please sign in to comment.