// model.go (124 lines, 3.5 KB), forked from pemistahl/lingua-go.
/*
* Copyright © 2021-present Peter M. Stahl [email protected]
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either expressed or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package lingua
import (
"fmt"
"regexp"
"strings"
)
// trainingDataLanguageModel bundles the n-gram statistics that were computed
// from the training text of one language.
type trainingDataLanguageModel struct {
	// language is the language the statistics belong to.
	language Language

	// absoluteFrequencies maps every n-gram to the number of times it was
	// counted in the training text.
	absoluteFrequencies map[ngram]uint32

	// relativeFrequencies maps every n-gram to the relative frequency that
	// was derived from its absolute count.
	relativeFrequencies map[ngram]float64
}
// testDataLanguageModel holds the n-grams that were extracted from the words
// of a text under examination.
type testDataLanguageModel struct {
	// ngrams has one entry per distinct n-gram found in the input; each entry
	// is the slice returned by that n-gram's rangeOfLowerOrderNgrams method.
	ngrams [][]ngram
}
// newTrainingDataLanguageModel builds the training model for language from
// the given lines of text.
//
// The absolute n-gram counts are computed first via computeAbsoluteFrequencies
// (using ngramLength and charClass); the relative frequencies are then derived
// from those counts and from lowerNgramAbsoluteFrequencies via
// computeRelativeFrequencies.
func newTrainingDataLanguageModel(
	text []string,
	language Language,
	ngramLength int,
	charClass string,
	lowerNgramAbsoluteFrequencies map[ngram]uint32,
) trainingDataLanguageModel {
	model := trainingDataLanguageModel{language: language}

	// The relative frequencies depend on the absolute ones, so the order of
	// these two assignments matters.
	model.absoluteFrequencies = computeAbsoluteFrequencies(text, ngramLength, charClass)
	model.relativeFrequencies = computeRelativeFrequencies(
		ngramLength,
		model.absoluteFrequencies,
		lowerNgramAbsoluteFrequencies,
	)

	return model
}
// newTestDataLanguageModel collects the distinct n-grams of ngramLength runes
// that occur in words and stores, for each of them, the slice returned by its
// rangeOfLowerOrderNgrams method.
//
// It panics if ngramLength is greater than maxNgramLength. Words with fewer
// runes than ngramLength contribute nothing. The entries of the result follow
// map iteration order, which is unspecified.
func newTestDataLanguageModel(words []string, ngramLength int) testDataLanguageModel {
	if ngramLength > maxNgramLength {
		panic(fmt.Sprintf("ngram length %v is greater than %v", ngramLength, maxNgramLength))
	}

	// A set keeps each n-gram once, however often it occurs in the words.
	unique := make(map[ngram]struct{})
	for _, word := range words {
		runes := []rune(word)
		// Last index at which a full window still fits; negative for words
		// that are too short, in which case the loop body never runs.
		lastStart := len(runes) - ngramLength
		for start := 0; start <= lastStart; start++ {
			window := string(runes[start : start+ngramLength])
			unique[newNgram(window)] = struct{}{}
		}
	}

	result := make([][]ngram, 0, len(unique))
	for entry := range unique {
		result = append(result, entry.rangeOfLowerOrderNgrams())
	}

	return testDataLanguageModel{ngrams: result}
}
// computeAbsoluteFrequencies counts how often each n-gram of ngramLength runes
// occurs in text.
//
// Every line is lower-cased first. A window of ngramLength runes is then moved
// over the line one rune at a time, and a window is counted only if it
// consists entirely of characters from charClass, which is inserted between
// the brackets of a regular-expression character class.
//
// It panics if charClass does not produce a valid regular expression.
func computeAbsoluteFrequencies(
	text []string,
	ngramLength int,
	charClass string,
) map[ngram]uint32 {
	pattern, err := regexp.Compile(fmt.Sprintf("^[%v]+$", charClass))
	if err != nil {
		panic(fmt.Sprintf("The character class '%v' cannot be compiled to a valid regular expression", charClass))
	}

	counts := make(map[ngram]uint32)
	for _, line := range text {
		runes := []rune(strings.ToLower(line))
		// Last index at which a full window still fits into the line.
		lastStart := len(runes) - ngramLength
		for start := 0; start <= lastStart; start++ {
			candidate := string(runes[start : start+ngramLength])
			if !pattern.MatchString(candidate) {
				continue
			}
			counts[newNgram(candidate)]++
		}
	}

	return counts
}
// computeRelativeFrequencies converts absolute n-gram counts into relative
// frequencies.
//
// For unigrams (ngramLength == 1), or when lowerNgramAbsoluteFrequencies is
// empty, every count is divided by the sum of all counts in
// absoluteFrequencies. Otherwise every count is divided by the count that
// lowerNgramAbsoluteFrequencies holds for the n-gram's first ngramLength-1
// runes.
//
// The sum is accumulated in a uint64 so that it cannot wrap around when the
// individual uint32 counts add up to more than math.MaxUint32.
//
// NOTE(review): a prefix that is missing from lowerNgramAbsoluteFrequencies
// yields a zero denominator and therefore +Inf for that n-gram. Callers are
// presumably expected to pass counts derived from the same text; confirm.
func computeRelativeFrequencies(
	ngramLength int,
	absoluteFrequencies map[ngram]uint32,
	lowerNgramAbsoluteFrequencies map[ngram]uint32,
) map[ngram]float64 {
	relativeFrequencies := make(map[ngram]float64, len(absoluteFrequencies))

	// Which denominator applies does not depend on the individual n-gram, so
	// the decision is made once instead of on every iteration.
	if ngramLength == 1 || len(lowerNgramAbsoluteFrequencies) == 0 {
		var total uint64
		for _, frequency := range absoluteFrequencies {
			total += uint64(frequency)
		}
		for entry, frequency := range absoluteFrequencies {
			relativeFrequencies[entry] = float64(frequency) / float64(total)
		}
		return relativeFrequencies
	}

	for entry, frequency := range absoluteFrequencies {
		// The lower-order n-gram is the current one without its last rune.
		prefix := string([]rune(entry.value)[:ngramLength-1])
		denominator := lowerNgramAbsoluteFrequencies[newNgram(prefix)]
		relativeFrequencies[entry] = float64(frequency) / float64(denominator)
	}
	return relativeFrequencies
}