diff --git a/data_utils.py b/data_utils.py index ae9073f2e..13b1dced9 100644 --- a/data_utils.py +++ b/data_utils.py @@ -366,7 +366,7 @@ def _sample_mask(sp, seg, reverse=False, max_gram=5, goal_num_predict=None): end = beg + 1 cnt_ngram = 1 while end < seg_len: - if _is_start_piece(sp.IdToPiece(seg[beg].item())): + if _is_start_piece(sp.IdToPiece(seg[end].item())): cnt_ngram += 1 if cnt_ngram > n: break