Skip to content

Commit 440bbba

Browse files
Move lpr_patterns to config (open-edge-platform#209)
1 parent f4cc817 commit 440bbba

File tree

4 files changed

+23
-22
lines changed

4 files changed

+23
-22
lines changed

tensorflow_toolkit/lpr/chinese_lp/config.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,14 @@
2121
max_lp_length = 20
2222
rnn_cells_num = 128
2323

24+
# Licens plate patterns
25+
lpr_patterns = [
26+
'^<[^>]*>[A-Z][0-9A-Z]{5}$',
27+
'^<[^>]*>[A-Z][0-9A-Z][0-9]{3}<police>$',
28+
'^<[^>]*>[A-Z][0-9A-Z]{4}<[^>]*>$', # <Guangdong>, <Hebei>
29+
'^WJ<[^>]*>[0-9]{4}[0-9A-Z]$',
30+
]
31+
2432
# Path to the folder where all training and evaluation artifacts will be located
2533
model_dir = os.path.realpath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'model'))
2634
if not os.path.exists(model_dir):
@@ -38,7 +46,7 @@ class train:
3846
opt_type = 'Adam'
3947

4048
save_checkpoints_steps = 1000 # Number of training steps when checkpoint should be saved
41-
display_iter = 10
49+
display_iter = 100
4250

4351
apply_basic_aug = False
4452
apply_stn_aug = True

tensorflow_toolkit/lpr/chinese_lp/config_test.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,14 @@
2121
max_lp_length = 20
2222
rnn_cells_num = 128
2323

24+
# Licens plate patterns
25+
lpr_patterns = [
26+
'^<[^>]*>[A-Z][0-9A-Z]{5}$',
27+
'^<[^>]*>[A-Z][0-9A-Z][0-9]{3}<police>$',
28+
'^<[^>]*>[A-Z][0-9A-Z]{4}<[^>]*>$', # <Guangdong>, <Hebei>
29+
'^WJ<[^>]*>[0-9]{4}[0-9A-Z]$',
30+
]
31+
2432
# Path to the folder where all training and evaluation artifacts will be located
2533
model_dir = os.path.realpath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'model'))
2634
if not os.path.exists(model_dir):
@@ -37,7 +45,7 @@ class train:
3745
grad_noise_scale = 0.001
3846
opt_type = 'Adam'
3947

40-
save_checkpoints_steps = 1000 # Number of training steps when checkpoint should be saved
48+
save_checkpoints_steps = 500 # Number of training steps when checkpoint should be saved
4149
display_iter = 10
4250

4351
apply_basic_aug = False

tensorflow_toolkit/lpr/lpr/utils.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,27 +23,12 @@ def dataset_size(fname):
2323
count += 1
2424
return count
2525

26-
27-
LPR_PATTERNS = [
28-
'^<[^>]*>[A-Z][0-9A-Z]{5}$',
29-
'^<[^>]*>[A-Z][0-9A-Z][0-9]{3}<police>$',
30-
'^<[^>]*>[A-Z][0-9A-Z]{4}<[^>]*>$', # <Guangdong>, <Hebei>
31-
'^WJ<[^>]*>[0-9]{4}[0-9A-Z]$',
32-
]
33-
34-
def lpr_pattern_check(label):
35-
for pattern in LPR_PATTERNS:
26+
def lpr_pattern_check(label, lpr_patterns):
27+
for pattern in lpr_patterns:
3628
if re.match(pattern, label):
3729
return True
3830
return False
3931

40-
def find_best(predictions):
41-
for prediction in predictions:
42-
if lpr_pattern_check(prediction):
43-
return prediction
44-
return predictions[0] # fallback
45-
46-
4732
def edit_distance(string1, string2):
4833
len1 = len(string1) + 1
4934
len2 = len(string2) + 1
@@ -60,13 +45,13 @@ def edit_distance(string1, string2):
6045
return tbl[i, j]
6146

6247

63-
def accuracy(label, val, vocab, r_vocab):
48+
def accuracy(label, val, vocab, r_vocab, lpr_patterns):
6449
pred = decode_beams(val, r_vocab)
6550
label_len = len(label)
6651
acc, acc1 = 0, 0
6752
num = 0
6853
for i in range(label_len):
69-
if not lpr_pattern_check(label[i].decode('utf-8')): # GT label fails
54+
if not lpr_pattern_check(label[i].decode('utf-8'), lpr_patterns): # GT label fails
7055
print('GT label fails: ' + label[i].decode('utf-8'))
7156
continue
7257
best = pred[i]

tensorflow_toolkit/lpr/tools/eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def validate(config):
131131
num = 0
132132
for _ in range(steps):
133133
val, slabel, _ = sess.run([d_predictions, label_val, file_names])
134-
acc, acc1, num_ = accuracy(slabel, val, config.vocab, config.r_vocab)
134+
acc, acc1, num_ = accuracy(slabel, val, config.vocab, config.r_vocab, config.lpr_patterns)
135135
mean_accuracy += acc
136136
mean_accuracy_minus_1 += acc1
137137
num += num_

0 commit comments

Comments
 (0)