Skip to content

Commit cacdf10

Browse files
Merge pull request #247 from WorksApplications/fix/reset-unknode
Reset `unkNodes` for each lattice position
2 parents b5d8753 + 417370c commit cacdf10

File tree

1 file changed

+23
-10
lines changed

1 file changed

+23
-10
lines changed

src/main/java/com/worksap/nlp/sudachi/JapaneseTokenizer.java

+23-10
Original file line numberDiff line numberDiff line change
@@ -205,12 +205,13 @@ MorphemeList tokenizeSentence(Tokenizer.SplitMode mode, UTF8InputText input) {
205205
LatticeImpl buildLattice(UTF8InputText input) {
206206
byte[] bytes = input.getByteText();
207207
lattice.resize(bytes.length);
208-
ArrayList<LatticeNodeImpl> unkNodes = new ArrayList<>(64);
208+
ArrayList<LatticeNodeImpl> crrNodes = new ArrayList<>(64);
209209
WordLookup wordLookup = lexicon.makeLookup();
210210
for (int byteBoundary = 0; byteBoundary < bytes.length; byteBoundary++) {
211211
if (!input.canBow(byteBoundary) || !lattice.hasPreviousNode(byteBoundary)) {
212212
continue;
213213
}
214+
crrNodes.clear();
214215
wordLookup.reset(bytes, byteBoundary, bytes.length);
215216
long wordMask = 0L;
216217
while (wordLookup.next()) {
@@ -224,7 +225,7 @@ LatticeImpl buildLattice(UTF8InputText input) {
224225
int wordId = wordIds[word];
225226
LatticeNodeImpl n = new LatticeNodeImpl(lexicon, lexicon.parameters(wordId), wordId);
226227
lattice.insert(byteBoundary, end, n);
227-
unkNodes.add(n);
228+
crrNodes.add(n);
228229
wordMask = WordMask.addNth(wordMask, end - byteBoundary);
229230
}
230231
}
@@ -233,11 +234,11 @@ LatticeImpl buildLattice(UTF8InputText input) {
233234
// OOV
234235
if (!input.getCharCategoryTypes(byteBoundary).contains(CategoryType.NOOOVBOW)) {
235236
for (OovProviderPlugin plugin : oovProviderPlugins) {
236-
wordMaskWithOov = provideOovs(plugin, input, unkNodes, byteBoundary, wordMaskWithOov);
237+
wordMaskWithOov = provideOovs(plugin, input, byteBoundary, wordMaskWithOov, crrNodes);
237238
}
238239
}
239240
if (wordMaskWithOov == 0 && defaultOovProvider != null) {
240-
wordMaskWithOov = provideOovs(defaultOovProvider, input, unkNodes, byteBoundary, wordMaskWithOov);
241+
wordMaskWithOov = provideOovs(defaultOovProvider, input, byteBoundary, wordMaskWithOov, crrNodes);
241242
}
242243
if (wordMaskWithOov == 0) {
243244
throw new IllegalStateException("failed to found any morpheme candidate at boundary " + byteBoundary);
@@ -249,19 +250,31 @@ LatticeImpl buildLattice(UTF8InputText input) {
249250
}
250251

251252
/**
252-
* Create OOV nodes using plugin and add them to the lattice and unkNodes.
253+
* Create OOV nodes using plugin at the given position and update crrNodes and
254+
* wordMask.
253255
*
256+
* @param plugin
257+
* OOVProviderPlugin to use
258+
* @param input
259+
* Full inputText
260+
* @param boundary
261+
* Byte index of inputText where OOV nodes should start from
262+
* @param crrNodes
263+
* Nodes already provided by dict or other plugins. Provided nodes
264+
* should be appended to this
265+
* @param wordMask
266+
* Word mask based on crrNodes
254267
* @return wordMask updated based on created OOV nodes.
255268
*/
256-
private long provideOovs(OovProviderPlugin plugin, UTF8InputText input, ArrayList<LatticeNodeImpl> unkNodes,
257-
int boundary, long wordMask) {
258-
int initialSize = unkNodes.size();
259-
int created = plugin.provideOOV(input, boundary, wordMask, unkNodes);
269+
private long provideOovs(OovProviderPlugin plugin, UTF8InputText input, int boundary, long wordMask,
270+
ArrayList<LatticeNodeImpl> crrNodes) {
271+
int initialSize = crrNodes.size();
272+
int created = plugin.provideOOV(input, boundary, wordMask, crrNodes);
260273
if (created == 0) {
261274
return wordMask;
262275
}
263276
for (int i = initialSize; i < initialSize + created; ++i) {
264-
LatticeNodeImpl node = unkNodes.get(i);
277+
LatticeNodeImpl node = crrNodes.get(i);
265278
lattice.insert(node.getBegin(), node.getEnd(), node);
266279
wordMask = WordMask.addNth(wordMask, node.getEnd() - node.getBegin());
267280
}

0 commit comments

Comments
 (0)