@@ -114,7 +114,7 @@ def evaluate(model, test_loader, model_name, rel=None, save_path=""):
114
114
elif model_name == "saint" :
115
115
y = model (cq .long (), cc .long (), r .long ())
116
116
y = y [:, 1 :]
117
- elif model_name in ["akt" ,"extrakt" ,"folibikt" , "akt_vector" , "akt_norasch" , "akt_mono" , "akt_attn" , "aktattn_pos" , "aktmono_pos" , "akt_raschx" , "akt_raschy" , "aktvec_raschx" , "lefokt_akt" , "fluckt" ]:
117
+ elif model_name in ["akt" ,"extrakt" ,"folibikt" , "robustkt" , " akt_vector" , "akt_norasch" , "akt_mono" , "akt_attn" , "aktattn_pos" , "aktmono_pos" , "akt_raschx" , "akt_raschy" , "aktvec_raschx" , "lefokt_akt" , "fluckt" ]:
118
118
y , reg_loss = model (cc .long (), cr .long (), cq .long ())
119
119
y = y [:,1 :]
120
120
elif model_name in ["dtransformer" ]:
@@ -183,7 +183,7 @@ def early_fusion(curhs, model, model_name):
183
183
que_diff = model .diff_layer (curhs [1 ])#equ 13
184
184
p = torch .sigmoid (3.0 * stu_ability - que_diff )#equ 14
185
185
p = p .squeeze (- 1 )
186
- elif model_name in ["akt" ,"extrakt" , "folibikt" ,"dtransformer" ,"simplekt" ,"stablekt" ,"cskt" , "fluckt" , "bakt_time" , "sparsekt" , "lefokt_akt" , "ukt" , "hcgkt" , "akt_vector" , "akt_norasch" , "akt_mono" , "akt_attn" , "aktattn_pos" , "aktmono_pos" , "akt_raschx" , "akt_raschy" , "aktvec_raschx" ]:
186
+ elif model_name in ["akt" ,"extrakt" , "folibikt" ,"robustkt" , " dtransformer" ,"simplekt" ,"stablekt" ,"cskt" , "fluckt" , "bakt_time" , "sparsekt" , "lefokt_akt" , "ukt" , "hcgkt" , "akt_vector" , "akt_norasch" , "akt_mono" , "akt_attn" , "aktattn_pos" , "aktmono_pos" , "akt_raschx" , "akt_raschy" , "aktvec_raschx" ]:
187
187
output = model .out (curhs [0 ]).squeeze (- 1 )
188
188
m = nn .Sigmoid ()
189
189
p = m (output )
@@ -229,7 +229,7 @@ def effective_fusion(df, model, model_name, fusion_type):
229
229
230
230
curhs , curr = [[], []], []
231
231
dcur = {"late_trues" : [], "qidxs" : [], "questions" : [], "concepts" : [], "row" : [], "concept_preds" : []}
232
- hasearly = ["dkvmn" ,"deep_irt" , "skvmn" , "kqn" , "akt" ,"extrakt" , "folibikt" , "dtransformer" , "simplekt" ,"stablekt" ,"cskt" ,"fluckt" , "ukt" , "hcgkt" , "bakt_time" , "sparsekt" ,"lefokt_akt" , "saint" , "sakt" , "hawkes" , "akt_vector" , "akt_norasch" , "akt_mono" , "akt_attn" , "aktattn_pos" , "aktmono_pos" , "akt_raschx" , "akt_raschy" , "aktvec_raschx" , "lpkt" ]
232
+ hasearly = ["dkvmn" ,"deep_irt" , "skvmn" , "kqn" , "akt" ,"extrakt" , "folibikt" , "robustkt" , " dtransformer" , "simplekt" ,"stablekt" ,"cskt" ,"fluckt" , "ukt" , "hcgkt" , "bakt_time" , "sparsekt" ,"lefokt_akt" , "saint" , "sakt" , "hawkes" , "akt_vector" , "akt_norasch" , "akt_mono" , "akt_attn" , "aktattn_pos" , "aktmono_pos" , "akt_raschx" , "akt_raschy" , "aktvec_raschx" , "lpkt" ]
233
233
for ui in df :
234
234
# 一题一题处理
235
235
curdf = ui [1 ]
@@ -277,7 +277,7 @@ def group_fusion(dmerge, model, model_name, fusion_type, fout):
277
277
if cq .shape [1 ] == 0 :
278
278
cq = cc
279
279
280
- hasearly = ["dkvmn" ,"deep_irt" , "skvmn" , "kqn" , "dtransformer" , "akt" ,"extrakt" , "folibikt" ,"simplekt" ,"stablekt" ,"cskt" , "fluckt" , "ukt" , "hcgkt" , "bakt_time" , "sparsekt" ,"lefokt_akt" , "saint" , "sakt" , "hawkes" , "akt_vector" , "akt_norasch" , "akt_mono" , "akt_attn" , "aktattn_pos" , "aktmono_pos" , "akt_raschx" , "akt_raschy" , "aktvec_raschx" , "lpkt" ]
280
+ hasearly = ["dkvmn" ,"deep_irt" , "skvmn" , "kqn" , "dtransformer" , "akt" ,"robustkt" , " extrakt" , "folibikt" ,"simplekt" ,"stablekt" ,"cskt" , "fluckt" , "ukt" , "hcgkt" , "bakt_time" , "sparsekt" ,"lefokt_akt" , "saint" , "sakt" , "hawkes" , "akt_vector" , "akt_norasch" , "akt_mono" , "akt_attn" , "aktattn_pos" , "aktmono_pos" , "akt_raschx" , "akt_raschy" , "aktvec_raschx" , "lpkt" ]
281
281
282
282
alldfs , drest = [], dict () # not predict infos!
283
283
# print(f"real bz in group fusion: {rs.shape[0]}")
@@ -374,7 +374,7 @@ def evaluate_question(model, test_loader, model_name, fusion_type=["early_fusion
374
374
# dkvmn / akt / saint: give cur -> predict cur
375
375
# sakt: give past+cur -> predict cur
376
376
# kqn: give past+cur -> predict cur
377
- hasearly = ["dkvmn" ,"deep_irt" , "skvmn" , "kqn" , "dtransformer" , "akt" ,"extrakt" ,"folibikt" , "simplekt" ,"cskt" ,"fluckt" , "stablekt" , "ukt" , "hcgkt" , "bakt_time" , "sparsekt" , "lefokt_akt" , "saint" , "sakt" , "hawkes" , "akt_vector" , "akt_norasch" , "akt_mono" , "akt_attn" , "aktattn_pos" , "aktmono_pos" , "akt_raschx" , "akt_raschy" , "aktvec_raschx" , "lpkt" ]
377
+ hasearly = ["dkvmn" ,"deep_irt" , "skvmn" , "kqn" , "dtransformer" , "akt" ,"extrakt" ,"folibikt" , "robustkt" , " simplekt" ,"cskt" ,"fluckt" , "stablekt" , "ukt" , "hcgkt" , "bakt_time" , "sparsekt" , "lefokt_akt" , "saint" , "sakt" , "hawkes" , "akt_vector" , "akt_norasch" , "akt_mono" , "akt_attn" , "aktattn_pos" , "aktmono_pos" , "akt_raschx" , "akt_raschy" , "aktvec_raschx" , "lpkt" ]
378
378
if save_path != "" :
379
379
fout = open (save_path , "w" , encoding = "utf8" )
380
380
if model_name in hasearly :
@@ -433,7 +433,7 @@ def evaluate_question(model, test_loader, model_name, fusion_type=["early_fusion
433
433
y = y [:,1 :]
434
434
elif model_name in ["rekt" ]:
435
435
y , h = model (dcurori , qtest = True , train = False )
436
- elif model_name in ["akt" ,"extrakt" , "folibikt" ,"fluckt" ,"lefokt_akt" , "akt_vector" , "akt_norasch" , "akt_mono" , "akt_attn" , "aktattn_pos" , "aktmono_pos" , "akt_raschx" , "akt_raschy" , "aktvec_raschx" ]:
436
+ elif model_name in ["akt" ,"extrakt" , "folibikt" ,"fluckt" ,"robustkt" , " lefokt_akt" , "akt_vector" , "akt_norasch" , "akt_mono" , "akt_attn" , "aktattn_pos" , "aktmono_pos" , "akt_raschx" , "akt_raschy" , "aktvec_raschx" ]:
437
437
y , reg_loss , h = model (cc .long (), cr .long (), cq .long (), True )
438
438
y = y [:,1 :]
439
439
elif model_name in ["dtransformer" ]:
@@ -934,7 +934,7 @@ def predict_each_group(dtotal, dcur, dforget, curdforget, is_repeat, qidx, uid,
934
934
# 应该用预测的r更新memory value,但是这里一个知识点一个知识点预测,所以curr不起作用!
935
935
y = model (cin .long (), rin .long ())
936
936
pred = y [0 ][- 1 ]
937
- elif model_name in ["akt" ,"extrakt" ,"folibikt" ,"fluckt" , "lefokt_akt" , "akt_vector" , "akt_norasch" , "akt_mono" , "akt_attn" , "aktattn_pos" , "aktmono_pos" , "akt_raschx" , "akt_raschy" , "aktvec_raschx" ]:
937
+ elif model_name in ["akt" ,"extrakt" ,"folibikt" ,"fluckt" , "robustkt" , " lefokt_akt" , "akt_vector" , "akt_norasch" , "akt_mono" , "akt_attn" , "aktattn_pos" , "aktmono_pos" , "akt_raschx" , "akt_raschy" , "aktvec_raschx" ]:
938
938
#### 输入有question!
939
939
if qout != None :
940
940
curq = torch .tensor ([[qout .item ()]]).to (device )
@@ -1318,7 +1318,7 @@ def predict_each_group2(dtotal, dcur, dforget, curdforget, is_repeat, qidx, uid,
1318
1318
elif model_name == "saint" :
1319
1319
y = model (ccq .long (), ccc .long (), curr .long ())
1320
1320
y = y [:, 1 :]
1321
- elif model_name in ["akt" ,"extrakt" ,"folibikt" , "cakt" ,"fluckt" ,"lefokt_akt" , "akt_vector" , "akt_norasch" , "akt_mono" , "akt_attn" , "aktattn_pos" , "aktmono_pos" , "akt_raschx" , "akt_raschy" , "aktvec_raschx" ]:
1321
+ elif model_name in ["akt" ,"extrakt" ,"folibikt" , "robustkt" , " cakt" ,"fluckt" ,"lefokt_akt" , "akt_vector" , "akt_norasch" , "akt_mono" , "akt_attn" , "aktattn_pos" , "aktmono_pos" , "akt_raschx" , "akt_raschy" , "aktvec_raschx" ]:
1322
1322
y , reg_loss = model (ccc .long (), ccr .long (), ccq .long ())
1323
1323
y = y [:,1 :]
1324
1324
elif model_name in ["dtransformer" ]:
0 commit comments