diff --git a/pyproject.toml b/pyproject.toml index 727493b..41be859 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "FSRS-Optimizer" -version = "5.2.3" +version = "5.3.0" readme = "README.md" dependencies = [ "matplotlib>=3.7.0", diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py index 3685b4f..4c8640e 100644 --- a/src/fsrs_optimizer/fsrs_optimizer.py +++ b/src/fsrs_optimizer/fsrs_optimizer.py @@ -219,6 +219,7 @@ def __init__( batch_size: int = 0, sort_by_length: bool = True, max_seq_len: int = math.inf, + device: str = "cpu", ): if dataframe.empty: raise ValueError("Training data is inadequate.") @@ -248,10 +249,10 @@ def __init__( max_seq_len = max(seq_lens) sequences_truncated = sequences[:, :max_seq_len] self.batches[i] = ( - sequences_truncated.transpose(0, 1), - self.t_train[start_index:end_index], - self.y_train[start_index:end_index], - seq_lens, + sequences_truncated.transpose(0, 1).to(device), + self.t_train[start_index:end_index].to(device), + self.y_train[start_index:end_index].to(device), + seq_lens.to(device), ) def __getitem__(self, idx):