Skip to content

Commit 68caaed

Browse files
Jung SeunghwanJung Seunghwan
Jung Seunghwan
authored and
Jung Seunghwan
committed
update codes
1 parent 921c35f commit 68caaed

25 files changed

+584
-507
lines changed

Diff for: BeamSearch.py

+36-15
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#!/usr/bin/env python3
2+
13
# Licensed to the Apache Software Foundation (ASF) under one
24
# or more contributor license agreements. See the NOTICE file
35
# distributed with this work for additional information
@@ -15,12 +17,19 @@
1517
# specific language governing permissions and limitations
1618
# under the License.
1719

20+
"""
21+
Module : this module to decode using beam search
22+
https://github.com/ThomasDelteil/HandwrittenTextRecognition_MXNet/blob/master/utils/CTCDecoder/BeamSearch.py
23+
"""
24+
1825
from __future__ import division
1926
from __future__ import print_function
2027
import numpy as np
2128

2229
class BeamEntry:
23-
"information about one single beam at specific time-step"
30+
"""
31+
information about one single beam at specific time-step
32+
"""
2433
def __init__(self):
2534
self.prTotal = 0 # blank and non-blank
2635
self.prNonBlank = 0 # non-blank
@@ -30,24 +39,32 @@ def __init__(self):
3039
self.labeling = () # beam-labeling
3140

3241
class BeamState:
33-
"information about the beams at specific time-step"
42+
"""
43+
information about the beams at specific time-step
44+
"""
3445
def __init__(self):
3546
self.entries = {}
3647

3748
def norm(self):
38-
"length-normalise LM score"
49+
"""
50+
length-normalise LM score
51+
"""
3952
for (k, _) in self.entries.items():
4053
labelingLen = len(self.entries[k].labeling)
4154
self.entries[k].prText = self.entries[k].prText ** (1.0 / (labelingLen if labelingLen else 1.0))
4255

4356
def sort(self):
44-
"return beam-labelings, sorted by probability"
57+
"""
58+
return beam-labelings, sorted by probability
59+
"""
4560
beams = [v for (_, v) in self.entries.items()]
4661
sortedBeams = sorted(beams, reverse=True, key=lambda x: x.prTotal*x.prText)
4762
return [x.labeling for x in sortedBeams]
4863

4964
def applyLM(parentBeam, childBeam, classes, lm):
50-
"calculate LM score of child beam by taking score from parent beam and bigram probability of last two chars"
65+
"""
66+
calculate LM score of child beam by taking score from parent beam and bigram probability of last two chars
67+
"""
5168
if lm and not childBeam.lmApplied:
5269
c1 = classes[parentBeam.labeling[-1] if parentBeam.labeling else classes.index(' ')] # first char
5370
c2 = classes[childBeam.labeling[-1]] # second char
@@ -57,12 +74,16 @@ def applyLM(parentBeam, childBeam, classes, lm):
5774
childBeam.lmApplied = True # only apply LM once per beam entry
5875

5976
def addBeam(beamState, labeling):
60-
"add beam if it does not yet exist"
77+
"""
78+
add beam if it does not yet exist
79+
"""
6180
if labeling not in beamState.entries:
6281
beamState.entries[labeling] = BeamEntry()
6382

6483
def ctcBeamSearch(mat, classes, lm, k, beamWidth):
65-
"beam search as described by the paper of Hwang et al. and the paper of Graves et al."
84+
"""
85+
beam search as described by the paper of Hwang et al. and the paper of Graves et al.
86+
"""
6687

6788
blankIdx = len(classes)
6889
maxT, maxC = mat.shape
@@ -81,14 +102,14 @@ def ctcBeamSearch(mat, classes, lm, k, beamWidth):
81102
# get beam-labelings of best beams
82103
bestLabelings = last.sort()[0:beamWidth]
83104

84-
# go over best beams
105+
# go over best beams
85106
for labeling in bestLabelings:
86107

87-
# probability of paths ending with a non-blank
108+
# probability of paths ending with a non-blank
88109
prNonBlank = 0
89-
# in case of non-empty beam
110+
# in case of non-empty beam
90111
if labeling:
91-
# probability of paths with repeated last char at the end
112+
# probability of paths with repeated last char at the end
92113
try:
93114
prNonBlank = last.entries[labeling].prNonBlank * mat[t, labeling[-1]]
94115
except FloatingPointError:
@@ -119,15 +140,15 @@ def ctcBeamSearch(mat, classes, lm, k, beamWidth):
119140
else:
120141
prNonBlank = mat[t, c] * last.entries[labeling].prTotal
121142

122-
# add beam at current time-step if needed
143+
# add beam at current time-step if needed
123144
addBeam(curr, newLabeling)
124145

125-
# fill in data
146+
# fill in data
126147
curr.entries[newLabeling].labeling = newLabeling
127148
curr.entries[newLabeling].prNonBlank += prNonBlank
128149
curr.entries[newLabeling].prTotal += prNonBlank
129150

130-
# apply LM
151+
# apply LM
131152
applyLM(curr.entries[labeling], curr.entries[newLabeling], classes, lm)
132153

133154
# set new beam state
@@ -146,4 +167,4 @@ def ctcBeamSearch(mat, classes, lm, k, beamWidth):
146167
for l in bestLabeling:
147168
res += classes[l]
148169
output.append(res)
149-
return output
170+
return output

Diff for: LICENSE

-21
This file was deleted.

0 commit comments

Comments
 (0)