We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents 3615abc + 33dd717 commit 496e924Copy full SHA for 496e924
bonito/crf/model.py
@@ -118,11 +118,17 @@ def logZ(self, scores, S:semiring=Log):
118
beta_T = Ms.new_full((N, self.n_base**(self.state_len)), S.one)
119
return seqdist.sparse.logZ(Ms, self.idx, alpha_0, beta_T, S)
120
121
+ def forward_scores(self, scores, S: semiring=Log):
122
+ T, N, _ = scores.shape
123
+ Ms = scores.reshape(T, N, -1, self.n_base + 1)
124
+ alpha_0 = Ms.new_full((N, self.n_base**(self.state_len)), S.one)
125
+ return seqdist.sparse.fwd_scores_cupy(Ms, self.idx, alpha_0, S, K=1)
126
+
127
def backward_scores(self, scores, S: semiring=Log):
128
T, N, _ = scores.shape
129
Ms = scores.reshape(T, N, -1, self.n_base + 1)
130
- return seqdist.sparse.logZ_bwd_cupy(Ms, self.idx, beta_T, S, K=1)
131
+ return seqdist.sparse.bwd_scores_cupy(Ms, self.idx, beta_T, S, K=1)
132
133
def viterbi(self, scores):
134
traceback = self.posteriors(scores, Max)
0 commit comments