Skip to content

Commit 4547cf8

Browse files
committed
update
1 parent 3bf55ab commit 4547cf8

File tree

4 files changed

+36
-19
lines changed

4 files changed

+36
-19
lines changed

.gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,5 @@ dist/
1313
*.pyc
1414
docs/
1515
*.pkl
16-
saved_model
16+
saved_model
17+
build/

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ PyKT is a python library build upon PyTorch to train deep learning based knowled
77
Use the following command to install PyKT:
88

99
```
10-
pip install -U pykt-toolkit
10+
pip install -U pykt-toolkit -i https://pypi.python.org/simple
1111
```
1212

1313
<!--

build.sh

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
rm -r dist
21
python -m build
3-
twine upload dist/*
2+
twine upload dist/*
3+
rm -r dist

pykt/datasets/data_loader.py

+31-15
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,29 @@
88
from torch.cuda import FloatTensor, LongTensor
99
import numpy as np
1010

11+
12+
1113
class KTDataset(Dataset):
1214
"""Dataset for KT
1315
can be used to init dataset for: (for models except dkt_forget)
1416
train data, valid data
1517
common test data(concept level evaluation), real educational scenario test data(question level evaluation).
18+
19+
Args:
20+
file_path (str): train_valid/test file path
21+
input_type (list[str]): the input type of the dataset, values are in ["questions", "concepts"]
22+
folds (set(int)): the folds used to generate dataset, -1 for test data
23+
qtest (bool, optional): is question evaluation or not. Defaults to False.
1624
"""
1725
def __init__(self, file_path, input_type, folds, qtest=False):
18-
"""init KTDataset
19-
20-
Args:
21-
file_path (str): train_valid/test file path
22-
input_type (list[str]): the input type of the dataset, values are in ["questions", "concepts"]
23-
folds (set(int)): the folds used to generate dataset, -1 for test data
24-
qtest (bool, optional): is question evaluation or not. Defaults to False.
25-
"""
26+
# """init KTDataset
27+
28+
# Args:
29+
# file_path (str): train_valid/test file path
30+
# input_type (list[str]): the input type of the dataset, values are in ["questions", "concepts"]
31+
# folds (set(int)): the folds used to generate dataset, -1 for test data
32+
# qtest (bool, optional): is question evaluation or not. Defaults to False.
33+
# """
2634
super(KTDataset, self).__init__()
2735
sequence_path = file_path
2836
self.input_type = input_type
@@ -65,6 +73,7 @@ def __getitem__(self, index):
6573
index (int): the index of the data to get
6674
6775
Returns:
76+
(tuple): tuple containing:
6877
q_seqs (torch.tensor): question id sequence of the 0~seqlen-2 interactions
6978
c_seqs (torch.tensor): knowledge concept id sequence of the 0~seqlen-2 interactions
7079
r_seqs (torch.tensor): response id sequence of the 0~seqlen-2 interactions
@@ -95,6 +104,8 @@ def __getitem__(self, index):
95104
dcur[key] = self.dqtest[key][index]
96105
return q_seqs, c_seqs, r_seqs, qshft_seqs, cshft_seqs, rshft_seqs, mask_seqs, select_masks, dcur
97106

107+
108+
98109
def load_data(self, sequence_path, folds, pad_val=-1):
99110
"""load data
100111
@@ -103,14 +114,19 @@ def load_data(self, sequence_path, folds, pad_val=-1):
103114
folds (list[int]):
104115
pad_val (int, optional): pad value. Defaults to -1.
105116
106-
Returns:
107-
q_seqs (torch.tensor): question id sequence of the 0~seqlen-1 interactions
108-
c_seqs (torch.tensor): knowledge concept id sequence of the 0~seqlen-1 interactions
109-
r_seqs (torch.tensor): response id sequence of the 0~seqlen-1 interactions
110-
mask_seqs (torch.tensor): masked value sequence, shape is seqlen-1
111-
select_masks (torch.tensor): is select to calculate the performance or not, 0 is not selected, 1 is selected, only available for 1~seqlen-1, shape is seqlen-1
112-
dqtest (dict): not null only self.qtest is True, for question level evaluation
117+
Returns:
118+
(tuple): tuple containing
119+
120+
- **q_seqs (torch.tensor)**: question id sequence of the 0~seqlen-1 interactions
121+
- c_seqs (torch.tensor): knowledge concept id sequence of the 0~seqlen-1 interactions
122+
- r_seqs (torch.tensor): response id sequence of the 0~seqlen-1 interactions
123+
- mask_seqs (torch.tensor): masked value sequence, shape is seqlen-1
124+
- select_masks (torch.tensor): is select to calculate the performance or not, 0 is not selected, 1 is selected, only available for 1~seqlen-1, shape is seqlen-1
125+
- dqtest (dict): not null only when self.qtest is True, for question level evaluation
126+
127+
113128
"""
129+
114130
seq_qids, seq_cids, seq_rights, seq_mask = [], [], [], []
115131
df = pd.read_csv(sequence_path)
116132
df = df[df["fold"].isin(folds)]

0 commit comments

Comments
 (0)