1
+ #!/usr/bin/env python3
2
+
1
3
# Licensed to the Apache Software Foundation (ASF) under one
2
4
# or more contributor license agreements. See the NOTICE file
3
5
# distributed with this work for additional information
15
17
# specific language governing permissions and limitations
16
18
# under the License.
17
19
20
+ """
21
+ Module : this module to decode using beam search
22
+ https://github.com/ThomasDelteil/HandwrittenTextRecognition_MXNet/blob/master/utils/CTCDecoder/BeamSearch.py
23
+ """
24
+
18
25
from __future__ import division
19
26
from __future__ import print_function
20
27
import numpy as np
21
28
22
29
class BeamEntry :
23
- "information about one single beam at specific time-step"
30
+ """
31
+ information about one single beam at specific time-step
32
+ """
24
33
def __init__ (self ):
25
34
self .prTotal = 0 # blank and non-blank
26
35
self .prNonBlank = 0 # non-blank
@@ -30,24 +39,32 @@ def __init__(self):
30
39
self .labeling = () # beam-labeling
31
40
32
41
class BeamState :
33
- "information about the beams at specific time-step"
42
+ """
43
+ information about the beams at specific time-step
44
+ """
34
45
def __init__ (self ):
35
46
self .entries = {}
36
47
37
48
def norm (self ):
38
- "length-normalise LM score"
49
+ """
50
+ length-normalise LM score
51
+ """
39
52
for (k , _ ) in self .entries .items ():
40
53
labelingLen = len (self .entries [k ].labeling )
41
54
self .entries [k ].prText = self .entries [k ].prText ** (1.0 / (labelingLen if labelingLen else 1.0 ))
42
55
43
56
def sort (self ):
44
- "return beam-labelings, sorted by probability"
57
+ """
58
+ return beam-labelings, sorted by probability
59
+ """
45
60
beams = [v for (_ , v ) in self .entries .items ()]
46
61
sortedBeams = sorted (beams , reverse = True , key = lambda x : x .prTotal * x .prText )
47
62
return [x .labeling for x in sortedBeams ]
48
63
49
64
def applyLM (parentBeam , childBeam , classes , lm ):
50
- "calculate LM score of child beam by taking score from parent beam and bigram probability of last two chars"
65
+ """
66
+ calculate LM score of child beam by taking score from parent beam and bigram probability of last two chars
67
+ """
51
68
if lm and not childBeam .lmApplied :
52
69
c1 = classes [parentBeam .labeling [- 1 ] if parentBeam .labeling else classes .index (' ' )] # first char
53
70
c2 = classes [childBeam .labeling [- 1 ]] # second char
@@ -57,12 +74,16 @@ def applyLM(parentBeam, childBeam, classes, lm):
57
74
childBeam .lmApplied = True # only apply LM once per beam entry
58
75
59
76
def addBeam (beamState , labeling ):
60
- "add beam if it does not yet exist"
77
+ """
78
+ add beam if it does not yet exist
79
+ """
61
80
if labeling not in beamState .entries :
62
81
beamState .entries [labeling ] = BeamEntry ()
63
82
64
83
def ctcBeamSearch (mat , classes , lm , k , beamWidth ):
65
- "beam search as described by the paper of Hwang et al. and the paper of Graves et al."
84
+ """
85
+ beam search as described by the paper of Hwang et al. and the paper of Graves et al.
86
+ """
66
87
67
88
blankIdx = len (classes )
68
89
maxT , maxC = mat .shape
@@ -81,14 +102,14 @@ def ctcBeamSearch(mat, classes, lm, k, beamWidth):
81
102
# get beam-labelings of best beams
82
103
bestLabelings = last .sort ()[0 :beamWidth ]
83
104
84
- # go over best beams
105
+ # go over best beams
85
106
for labeling in bestLabelings :
86
107
87
- # probability of paths ending with a non-blank
108
+ # probability of paths ending with a non-blank
88
109
prNonBlank = 0
89
- # in case of non-empty beam
110
+ # in case of non-empty beam
90
111
if labeling :
91
- # probability of paths with repeated last char at the end
112
+ # probability of paths with repeated last char at the end
92
113
try :
93
114
prNonBlank = last .entries [labeling ].prNonBlank * mat [t , labeling [- 1 ]]
94
115
except FloatingPointError :
@@ -119,15 +140,15 @@ def ctcBeamSearch(mat, classes, lm, k, beamWidth):
119
140
else :
120
141
prNonBlank = mat [t , c ] * last .entries [labeling ].prTotal
121
142
122
- # add beam at current time-step if needed
143
+ # add beam at current time-step if needed
123
144
addBeam (curr , newLabeling )
124
145
125
- # fill in data
146
+ # fill in data
126
147
curr .entries [newLabeling ].labeling = newLabeling
127
148
curr .entries [newLabeling ].prNonBlank += prNonBlank
128
149
curr .entries [newLabeling ].prTotal += prNonBlank
129
150
130
- # apply LM
151
+ # apply LM
131
152
applyLM (curr .entries [labeling ], curr .entries [newLabeling ], classes , lm )
132
153
133
154
# set new beam state
@@ -146,4 +167,4 @@ def ctcBeamSearch(mat, classes, lm, k, beamWidth):
146
167
for l in bestLabeling :
147
168
res += classes [l ]
148
169
output .append (res )
149
- return output
170
+ return output
0 commit comments