Skip to content

modify the FormatMetaparameters function in train_lm.py #52

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 65 additions & 12 deletions scripts/train_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,15 +144,68 @@ def WriteMetaparameters(metaparameters, ngram_order, num_train_sets, out_file):
i += 1
f.close()

def FormatMetaparameters(metaparameters):
def FormatMetaparameters(metaparameters, ngram_order, num_train_sets):
assert num_train_sets == len(metaparameters) - (ngram_order - 1) * 4
out = []
for param in metaparameters:
x = '{:.3f}'.format(param)
if x == '0.00':
x = '{:.3g}'.format(param)
out_param = []

for idx, param in enumerate(metaparameters):
# check whether metaparameters contain 0.0 or 1.0
if param in [0.0, 1.0]:
LogMessage("train_lm.py: Warning: The {0}th parameter is exact {1}. metaparameters should be in range (0, 1).".format(idx + 1, param))
round_idx = 3
x = round(param, round_idx)
# check whether 0.000 or 1.000 occurs because of rounding
while x in [0.0, 1.0]:
round_idx += 1
x = round(param, round_idx)
if round_idx >= 20:
LogMessage("train_lm.py: Warning: The {0}th parameter is too close to {1}. Double check it.".format(idx + 1, x))
break
out.append(x)

return ','.join(out)
# check repeating data-type parameters
(marker, repeat_values) = FindRepeatingValues(out[:num_train_sets])
while marker == True:
round_idx += 1
for idx, param in enumerate(metaparameters[:num_train_sets]):
out[idx] = round(param, round_idx)
(marker, repeat_values) = FindRepeatingValues(out[:num_train_sets])
# terminate the loop and print out repeating values:
if round_idx >= 20:
LogMessage("train_lm.py: Warning: There are repeating parameters {0}.".format(repeat_values))
break

# check repeating parameters of a certain order
for order in range(2, ngram_order + 1):
(marker, repeat_values) = FindRepeatingValues(out[num_train_sets + (order - 2) * 4: num_train_sets + (order - 1) * 4])
round_idx = 3
while marker == True:
round_idx += 1
for idx, param in enumerate(metaparameters[num_train_sets + (order - 2) * 4: num_train_sets + (order - 1) * 4]):
out[idx + num_train_sets + (order - 2) * 4] = round(param, round_idx)
(marker, repeat_values) = FindRepeatingValues(out[num_train_sets + (order - 2) * 4: num_train_sets + (order - 1) * 4])
# terminate the loop and print out repeating values:
if round_idx >= 20:
LogMessage("train_lm.py: Warning: There are repeating values {0} of order {1}.".format(repeat_values, order))
break
order += 1

for param in out:
out_param.append(str(param))
return ','.join(out_param)

def FindRepeatingValues(parameters):
marker = False
seen = set()
repeat_values = ""
for idx, item in enumerate(parameters):
if item not in seen and item not in [0.0, 1.0]:
seen.add(item)
else:
marker = True
repeat_values += " " + str(item)
return (marker, repeat_values)

def ParseMetaparameters(encoded_str, ngram_order, num_train_sets):
metaparameters = encoded_str.split(',')
Expand Down Expand Up @@ -275,15 +328,15 @@ def ParseMetaparameters(encoded_str, ngram_order, num_train_sets):
done_file = os.path.join(int_dir, '.done')
os.remove(done_file)

for name in [ 'ngram_order', 'num_train_sets' ]:
f = open(os.path.join(counts_dir, name))
globals()[name] = int(f.readline())
f.close()

metaparam_file = ''
if args.bypass_metaparameter_optimization != None:
LogMessage("Bypass optimization steps")

for name in [ 'ngram_order', 'num_train_sets' ]:
f = open(os.path.join(counts_dir, name))
globals()[name] = int(f.readline())
f.close()

metaparameters = ParseMetaparameters(args.bypass_metaparameter_optimization,
ngram_order, num_train_sets)
metaparam_file = os.path.join(work_dir, 'bypass.metaparams')
Expand Down Expand Up @@ -342,7 +395,7 @@ def ParseMetaparameters(encoded_str, ngram_order, num_train_sets):
metaparameters = ReadMetaparameters(metaparam_file)
LogMessage("You can set --bypass-metaparameter-optimization='{0}' "
"to get equivalent results".format(
FormatMetaparameters(metaparameters)))
FormatMetaparameters(metaparameters, ngram_order, num_train_sets)))

# make lm dir
lm_dir = os.path.join(args.lm_dir, lm_name + '.pocolm')
Expand Down