From ae7575cf5180d1ab631a9970cc73bcfb7a723f62 Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Fri, 1 Nov 2024 14:07:26 +0800 Subject: [PATCH] Feat/support to set device for batch dataset (#147) * Feat/support to set device for BatchDataset * bump version --- pyproject.toml | 2 +- src/fsrs_optimizer/fsrs_optimizer.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 727493b..41be859 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "FSRS-Optimizer" -version = "5.2.3" +version = "5.3.0" readme = "README.md" dependencies = [ "matplotlib>=3.7.0", diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py index 3685b4f..4c8640e 100644 --- a/src/fsrs_optimizer/fsrs_optimizer.py +++ b/src/fsrs_optimizer/fsrs_optimizer.py @@ -219,6 +219,7 @@ def __init__( batch_size: int = 0, sort_by_length: bool = True, max_seq_len: int = math.inf, + device: str = "cpu", ): if dataframe.empty: raise ValueError("Training data is inadequate.") @@ -248,10 +249,10 @@ def __init__( max_seq_len = max(seq_lens) sequences_truncated = sequences[:, :max_seq_len] self.batches[i] = ( - sequences_truncated.transpose(0, 1), - self.t_train[start_index:end_index], - self.y_train[start_index:end_index], - seq_lens, + sequences_truncated.transpose(0, 1).to(device), + self.t_train[start_index:end_index].to(device), + self.y_train[start_index:end_index].to(device), + seq_lens.to(device), ) def __getitem__(self, idx):