Skip to content

Commit b08476e

Browse files
authored
Merge pull request #239 from pykt-team/youh_dev
add new
2 parents f4bf7d2 + 4ff5d51 commit b08476e

File tree

9 files changed

+452
-13
lines changed

9 files changed

+452
-13
lines changed

README.md

+6-2
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,12 @@ The hyper parameter tuning results of our experiments about all the DLKT models
7676
22. DTransformer: Tracing Knowledge Instead of Patterns: Stable Knowledge Tracing with Diagnostic Transformer
7777
23. stableKT: Enhancing Length Generalization for Attention Based Knowledge Tracing Models with Linear Biases
7878
24. extraKT: Extending Context Window of Attention Based Knowledge Tracing Models via Length Extrapolation
79-
25. ReKT: Revisiting Knowledge Tracing: A Simple and Powerful Model
80-
79+
25. csKT: Addressing Cold-start Problem in Knowledge Tracing via Kernel Bias and Cone Attention
80+
26. LefoKT: Rethinking and Improving Student Learning and Forgetting Processes for Attention Based Knowledge Tracing Models
81+
27. FlucKT: Cognitive Fluctuations Enhanced Attention Network for Knowledge Tracing
82+
28. UKT: Uncertainty-aware Knowledge Tracing
83+
29. HCGKT: Hierarchical Contrastive Graph Knowledge Tracing with Multi-level Feature Learning
84+
30. RobustKT: Enhancing Knowledge Tracing through Decoupling Cognitive Pattern from Error-Prone Data
8185

8286
## Citation
8387

docs/pics/robustkt.png

49.8 KB
Loading

examples/seedwandb/robustkt.yaml

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# wandb Bayesian-optimization sweep for the robustkt model.
# Maximizes validation AUC over the xes dataset.
program: wandb_robustkt_train.py
method: bayes
metric:
  goal: maximize
  name: validauc
parameters:
  # Fixed experiment identity.
  model_name:
    values: ["robustkt"]
  dataset_name:
    values: ["xes"]
  emb_type:
    values: ["qid"]
  save_dir:
    values: ["models/akt_tiaocan"]
  # Architecture search space.
  d_model:
    values: [64, 256]
  d_ff:
    values: [64, 256]
  num_attn_heads:
    values: [4, 8]
  n_blocks:
    values: [1, 2, 4]
  kernel_size:
    values: [4, 5, 8, 16, 32]
  # Optimization search space.
  dropout:
    values: [0.05, 0.1, 0.3, 0.5]
  learning_rate:
    values: [1e-3, 1e-4, 1e-5]
  # Reproducibility / cross-validation.
  seed:
    values: [42, 3407]
  fold:
    values: [0, 1, 2, 3, 4]

examples/wandb_robustkt_train.py

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import argparse

from wandb_train import main

if __name__ == "__main__":
    # CLI entry point for training the robustkt model; all arguments are
    # collected into a params dict and handed to the shared wandb_train.main.
    parser = argparse.ArgumentParser()

    # Experiment setup.
    parser.add_argument("--dataset_name", type=str, default="assist2009")
    parser.add_argument("--model_name", type=str, default="robustkt")
    parser.add_argument("--emb_type", type=str, default="qid")
    parser.add_argument("--save_dir", type=str, default="saved_model")
    parser.add_argument("--seed", type=int, default=3407)
    parser.add_argument("--fold", type=int, default=0)
    parser.add_argument("--dropout", type=float, default=0.2)

    # Model hyper-parameters.
    parser.add_argument("--d_model", type=int, default=256)
    parser.add_argument("--d_ff", type=int, default=512)
    parser.add_argument("--num_attn_heads", type=int, default=8)
    parser.add_argument("--n_blocks", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    # Kernel size is an integer count, so type=int (was float).
    # "--kernel_size" alias matches examples/seedwandb/robustkt.yaml, which
    # sweeps a "kernel_size" parameter; dest stays "ks" so the params dict
    # key seen by downstream code is unchanged.
    # NOTE(review): confirm wandb_train consumes this under key "ks".
    parser.add_argument("--ks", "--kernel_size", dest="ks", type=int, default=5)

    # Logging / bookkeeping flags (0/1 ints, matching sibling train scripts).
    parser.add_argument("--use_wandb", type=int, default=0)
    parser.add_argument("--add_uuid", type=int, default=0)

    args = parser.parse_args()

    params = vars(args)
    main(params)

examples/wandb_train.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def main(params):
3939
with open("../configs/kt_config.json") as f:
4040
config = json.load(f)
4141
train_config = config["train_config"]
42-
if model_name in ["dkvmn","deep_irt", "sakt", "saint","saint++", "akt","folibikt", "atkt", "lpkt", "skvmn", "dimkt"]:
42+
if model_name in ["dkvmn","deep_irt", "sakt", "saint","saint++", "akt", "robustkt", "folibikt", "atkt", "lpkt", "skvmn", "dimkt"]:
4343
train_config["batch_size"] = 64 ## because of OOM
4444
if model_name in ["simplekt","stablekt", "bakt_time", "sparsekt"]:
4545
train_config["batch_size"] = 64 ## because of OOM

pykt/models/evaluate_model.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def evaluate(model, test_loader, model_name, rel=None, save_path=""):
114114
elif model_name == "saint":
115115
y = model(cq.long(), cc.long(), r.long())
116116
y = y[:, 1:]
117-
elif model_name in ["akt","extrakt","folibikt", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx", "lefokt_akt", "fluckt"]:
117+
elif model_name in ["akt","extrakt","folibikt", "robustkt", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx", "lefokt_akt", "fluckt"]:
118118
y, reg_loss = model(cc.long(), cr.long(), cq.long())
119119
y = y[:,1:]
120120
elif model_name in ["dtransformer"]:
@@ -183,7 +183,7 @@ def early_fusion(curhs, model, model_name):
183183
que_diff = model.diff_layer(curhs[1])#equ 13
184184
p = torch.sigmoid(3.0*stu_ability-que_diff)#equ 14
185185
p = p.squeeze(-1)
186-
elif model_name in ["akt","extrakt", "folibikt","dtransformer","simplekt","stablekt","cskt", "fluckt", "bakt_time", "sparsekt", "lefokt_akt", "ukt", "hcgkt", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx"]:
186+
elif model_name in ["akt","extrakt", "folibikt","robustkt", "dtransformer","simplekt","stablekt","cskt", "fluckt", "bakt_time", "sparsekt", "lefokt_akt", "ukt", "hcgkt", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx"]:
187187
output = model.out(curhs[0]).squeeze(-1)
188188
m = nn.Sigmoid()
189189
p = m(output)
@@ -229,7 +229,7 @@ def effective_fusion(df, model, model_name, fusion_type):
229229

230230
curhs, curr = [[], []], []
231231
dcur = {"late_trues": [], "qidxs": [], "questions": [], "concepts": [], "row": [], "concept_preds": []}
232-
hasearly = ["dkvmn","deep_irt", "skvmn", "kqn", "akt","extrakt", "folibikt", "dtransformer", "simplekt","stablekt","cskt","fluckt", "ukt", "hcgkt", "bakt_time", "sparsekt","lefokt_akt", "saint", "sakt", "hawkes", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx", "lpkt"]
232+
hasearly = ["dkvmn","deep_irt", "skvmn", "kqn", "akt","extrakt", "folibikt", "robustkt", "dtransformer", "simplekt","stablekt","cskt","fluckt", "ukt", "hcgkt", "bakt_time", "sparsekt","lefokt_akt", "saint", "sakt", "hawkes", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx", "lpkt"]
233233
for ui in df:
234234
# 一题一题处理
235235
curdf = ui[1]
@@ -277,7 +277,7 @@ def group_fusion(dmerge, model, model_name, fusion_type, fout):
277277
if cq.shape[1] == 0:
278278
cq = cc
279279

280-
hasearly = ["dkvmn","deep_irt", "skvmn", "kqn", "dtransformer", "akt","extrakt", "folibikt","simplekt","stablekt","cskt", "fluckt", "ukt", "hcgkt", "bakt_time", "sparsekt","lefokt_akt", "saint", "sakt", "hawkes", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx", "lpkt"]
280+
hasearly = ["dkvmn","deep_irt", "skvmn", "kqn", "dtransformer", "akt","robustkt", "extrakt", "folibikt","simplekt","stablekt","cskt", "fluckt", "ukt", "hcgkt", "bakt_time", "sparsekt","lefokt_akt", "saint", "sakt", "hawkes", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx", "lpkt"]
281281

282282
alldfs, drest = [], dict() # not predict infos!
283283
# print(f"real bz in group fusion: {rs.shape[0]}")
@@ -374,7 +374,7 @@ def evaluate_question(model, test_loader, model_name, fusion_type=["early_fusion
374374
# dkvmn / akt / saint: give cur -> predict cur
375375
# sakt: give past+cur -> predict cur
376376
# kqn: give past+cur -> predict cur
377-
hasearly = ["dkvmn","deep_irt", "skvmn", "kqn", "dtransformer", "akt","extrakt","folibikt", "simplekt","cskt","fluckt", "stablekt", "ukt", "hcgkt", "bakt_time", "sparsekt", "lefokt_akt", "saint", "sakt", "hawkes", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx", "lpkt"]
377+
hasearly = ["dkvmn","deep_irt", "skvmn", "kqn", "dtransformer", "akt","extrakt","folibikt", "robustkt", "simplekt","cskt","fluckt", "stablekt", "ukt", "hcgkt", "bakt_time", "sparsekt", "lefokt_akt", "saint", "sakt", "hawkes", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx", "lpkt"]
378378
if save_path != "":
379379
fout = open(save_path, "w", encoding="utf8")
380380
if model_name in hasearly:
@@ -433,7 +433,7 @@ def evaluate_question(model, test_loader, model_name, fusion_type=["early_fusion
433433
y = y[:,1:]
434434
elif model_name in ["rekt"]:
435435
y, h = model(dcurori, qtest=True, train=False)
436-
elif model_name in ["akt","extrakt", "folibikt","fluckt","lefokt_akt", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx"]:
436+
elif model_name in ["akt","extrakt", "folibikt","fluckt","robustkt", "lefokt_akt", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx"]:
437437
y, reg_loss, h = model(cc.long(), cr.long(), cq.long(), True)
438438
y = y[:,1:]
439439
elif model_name in ["dtransformer"]:
@@ -934,7 +934,7 @@ def predict_each_group(dtotal, dcur, dforget, curdforget, is_repeat, qidx, uid,
934934
# 应该用预测的r更新memory value,但是这里一个知识点一个知识点预测,所以curr不起作用!
935935
y = model(cin.long(), rin.long())
936936
pred = y[0][-1]
937-
elif model_name in ["akt","extrakt","folibikt","fluckt", "lefokt_akt", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx"]:
937+
elif model_name in ["akt","extrakt","folibikt","fluckt", "robustkt","lefokt_akt", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx"]:
938938
#### 输入有question!
939939
if qout != None:
940940
curq = torch.tensor([[qout.item()]]).to(device)
@@ -1318,7 +1318,7 @@ def predict_each_group2(dtotal, dcur, dforget, curdforget, is_repeat, qidx, uid,
13181318
elif model_name == "saint":
13191319
y = model(ccq.long(), ccc.long(), curr.long())
13201320
y = y[:, 1:]
1321-
elif model_name in ["akt","extrakt","folibikt", "cakt","fluckt","lefokt_akt", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx"]:
1321+
elif model_name in ["akt","extrakt","folibikt", "robustkt", "cakt","fluckt","lefokt_akt", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx"]:
13221322
y, reg_loss = model(ccc.long(), ccr.long(), ccq.long())
13231323
y = y[:,1:]
13241324
elif model_name in ["dtransformer"]:

pykt/models/init_model.py

+3
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from .lefokt_akt import LEFOKT_AKT
3838
from .ukt import UKT
3939
from .hcgkt import HCGKT
40+
from .robustkt import Robustkt
4041

4142
device = "cpu" if not torch.cuda.is_available() else "cuda"
4243

@@ -135,6 +136,8 @@ def init_model(model_name, model_config, data_config, emb_type):
135136
model = UKT(data_config["num_c"], data_config["num_q"], **model_config, emb_type=emb_type, emb_path=data_config["emb_path"]).to(device)
136137
elif model_name == "hcgkt":
137138
model = HCGKT(data_config["num_c"], data_config["num_q"], **model_config, emb_type=emb_type, emb_path=data_config["emb_path"]).to(device)
139+
elif model_name == "robustkt":
140+
model = Robustkt(data_config["num_c"], data_config["num_q"], **model_config, emb_type=emb_type, emb_path=data_config["emb_path"]).to(device)
138141
elif model_name == "dtransformer":
139142
model = DTransformer(data_config["num_c"], data_config["num_q"], **model_config, emb_type=emb_type,
140143
emb_path=data_config["emb_path"]).to(device)

0 commit comments

Comments
 (0)