Skip to content

Commit 496e924

Browse files
committed
add fwd_scores
2 parents 3615abc + 33dd717 commit 496e924

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

bonito/crf/model.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,17 @@ def logZ(self, scores, S:semiring=Log):
118118
beta_T = Ms.new_full((N, self.n_base**(self.state_len)), S.one)
119119
return seqdist.sparse.logZ(Ms, self.idx, alpha_0, beta_T, S)
120120

121+
def forward_scores(self, scores, S: semiring=Log):
122+
T, N, _ = scores.shape
123+
Ms = scores.reshape(T, N, -1, self.n_base + 1)
124+
alpha_0 = Ms.new_full((N, self.n_base**(self.state_len)), S.one)
125+
return seqdist.sparse.fwd_scores_cupy(Ms, self.idx, alpha_0, S, K=1)
126+
121127
def backward_scores(self, scores, S: semiring=Log):
122128
T, N, _ = scores.shape
123129
Ms = scores.reshape(T, N, -1, self.n_base + 1)
124130
beta_T = Ms.new_full((N, self.n_base**(self.state_len)), S.one)
125-
return seqdist.sparse.logZ_bwd_cupy(Ms, self.idx, beta_T, S, K=1)
131+
return seqdist.sparse.bwd_scores_cupy(Ms, self.idx, beta_T, S, K=1)
126132

127133
def viterbi(self, scores):
128134
traceback = self.posteriors(scores, Max)

0 commit comments

Comments
 (0)