Skip to content

Commit

Permalink
add fwd_scores
Browse files Browse the repository at this point in the history
  • Loading branch information
iiSeymour committed Jan 26, 2021
2 parents 3615abc + 33dd717 commit 496e924
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion bonito/crf/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,17 @@ def logZ(self, scores, S:semiring=Log):
beta_T = Ms.new_full((N, self.n_base**(self.state_len)), S.one)
return seqdist.sparse.logZ(Ms, self.idx, alpha_0, beta_T, S)

def forward_scores(self, scores, S: semiring=Log):
T, N, _ = scores.shape
Ms = scores.reshape(T, N, -1, self.n_base + 1)
alpha_0 = Ms.new_full((N, self.n_base**(self.state_len)), S.one)
return seqdist.sparse.fwd_scores_cupy(Ms, self.idx, alpha_0, S, K=1)

def backward_scores(self, scores, S: semiring=Log):
T, N, _ = scores.shape
Ms = scores.reshape(T, N, -1, self.n_base + 1)
beta_T = Ms.new_full((N, self.n_base**(self.state_len)), S.one)
return seqdist.sparse.logZ_bwd_cupy(Ms, self.idx, beta_T, S, K=1)
return seqdist.sparse.bwd_scores_cupy(Ms, self.idx, beta_T, S, K=1)

def viterbi(self, scores):
traceback = self.posteriors(scores, Max)
Expand Down

0 comments on commit 496e924

Please sign in to comment.