# wikisql_prediction_simple.py (forked from ygxw0909/MST-SQL)
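"""Run inference with a trained HydraNet model on ESQL/WikiSQL-style data and
report logical-form accuracy plus per-component accuracies (select column,
aggregation, and where-clause number/column/operator/value)."""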
import os
import sys
import json
import pickle

sys.path.append("..")  # make the parent project's modules importable

from utils import read_conf
from modeling.torch_model import HydraTorch
from wikisql_lib.query import Query
from featurizer import HydraFeaturizer, SQLDataset
from wikisql_lib.dbengine import DBEngine  # only used by the execution-guided path, currently commented out
def print_metric(label_file, pred_file):
    # Pair each gold query with its prediction, one JSON object per line.
    with open(label_file) as fs, open(pred_file) as fp:
        sp = []
        for ls, lp in zip(fs, fp):
            gold = json.loads(ls)
            sp.append((
                {"sel": gold["select"], "agg": gold["agg"], "conds": gold["conditions"]},
                json.loads(lp)["query"],
            ))

    # Component accuracies: select column, aggregation op, where-clause count.
    sel_acc = sum(p["sel"] == s["sel"] for s, p in sp) / len(sp)
    agg_acc = sum(p["agg"] == s["agg"] for s, p in sp) / len(sp)
    wcn_acc = sum(len(p["conds"]) == len(s["conds"]) for s, p in sp) / len(sp)

    def wcc_match(a, b):
        # Where-clause columns match, compared order-insensitively.
        a = sorted(a, key=lambda k: k[0])
        b = sorted(b, key=lambda k: k[0])
        return [c[0] for c in a] == [c[0] for c in b]

    def wco_match(a, b):
        # Where-clause operators match, compared order-insensitively.
        a = sorted(a, key=lambda k: k[0])
        b = sorted(b, key=lambda k: k[0])
        return [c[1] for c in a] == [c[1] for c in b]

    def wcv_match(a, b):
        # Where-clause values match, case- and order-insensitively.
        a = sorted(a, key=lambda k: k[0])
        b = sorted(b, key=lambda k: k[0])
        return [str(c[2]).lower() for c in a] == [str(c[2]).lower() for c in b]

    wcc_acc = sum(wcc_match(p["conds"], s["conds"]) for s, p in sp) / len(sp)
    wco_acc = sum(wco_match(p["conds"], s["conds"]) for s, p in sp) / len(sp)
    wcv_acc = sum(wcv_match(p["conds"], s["conds"]) for s, p in sp) / len(sp)

    # Logical-form exact match over full queries (condition order ignored).
    # engine = DBEngine(db_file)
    exact_match = []
    # grades = []
    with open(label_file) as fs, open(pred_file) as fp:
        for ls, lp in zip(fs, fp):
            eg = json.loads(ls)
            ep = json.loads(lp)
            qg = Query.from_dict({"sel": eg["select"], "agg": eg["agg"], "conds": eg["conditions"]}, ordered=False)
            # gold = engine.execute_query(eg['table_id'], qg, lower=True)
            qp = Query.from_dict(ep["query"], ordered=False)
            # try:
            #     pred = engine.execute_query(eg['table_id'], qp, lower=True)
            # except Exception as e:
            #     pred = repr(e)
            # correct = pred == gold
            match = qp == qg
            # grades.append(correct)
            exact_match.append(match)

    res = "lf_acc: {}\nsel_acc: {}\nagg_acc: {}\nwcn_acc: {}\nwcc_acc: {}\nwco_acc: {}\nwcv_acc: {}\n" \
        .format(sum(exact_match) / len(exact_match), sel_acc, agg_acc, wcn_acc, wcc_acc, wco_acc, wcv_acc)
    print(res)
    return res
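# Note: the wc*_match helpers above compare conditions order-insensitively
# (sorted by column index) and compare values case-insensitively, so e.g. gold
# conds [[0, 2, "Us"], [3, 0, "10"]] match predicted [[3, 0, "10"], [0, 2, "us"]].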
def execute_one_test(dataset, model_moment, epoch):
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"

    model_path = "output/" + model_moment
    in_file = "data/esql/esql_{}_content.jsonl".format(dataset)
    # in_file = "data/wikisql/wiki{}_content.jsonl".format(dataset)
    # db_file = "data/wikisql/{}.db".format(dataset)
    # label_file = "data/wikisql/{}.jsonl".format(dataset)
    out_path = "predictions/{}_{}_{}".format(model_moment, epoch, dataset)
    if not os.path.exists(out_path):
        os.makedirs(out_path)
    out_file = os.path.join(out_path, "out.jsonl")
    eg_out_file = os.path.join(out_path, "out_eg.jsonl")
    model_out_file = os.path.join(out_path, "model_out.pkl")
    test_result_file = os.path.join(out_path, "result.txt")
    # engine = DBEngine(db_file)

    config = read_conf(os.path.join(model_path, "model.conf"))
    # config = read_conf("../conf/wikisql_content.conf")
    # config["DEBUG"] = 1

    featurizer = HydraFeaturizer(config)
    pred_data = SQLDataset(in_file, config, featurizer, False)
    print("num of samples: {0}".format(len(pred_data.input_features)))

    model = HydraTorch(config)
    model.load(model_path, epoch)

    if "DEBUG" in config:
        model_out_file = model_out_file + ".partial"  # mark debug runs as partial
    model_outputs = model.dataset_inference(pred_data)
    with open(model_out_file, "wb") as f:
        pickle.dump(model_outputs, f)
    # model_outputs = pickle.load(open(model_out_file, "rb"))

    print("===HydraNet===")
    pred_sqls = model.predict_SQL(pred_data, model_outputs=model_outputs)
    with open(out_file, "w") as g:
        for pred_sql in pred_sqls:
            # print(pred_sql)
            # Serialize each prediction as {"query": {"agg", "sel", "conds"}}.
            result = {"query": {}}
            result["query"]["agg"] = int(pred_sql[0])
            result["query"]["sel"] = int(pred_sql[1])
            result["query"]["conds"] = [(int(cond[0]), int(cond[1]), str(cond[2])) for cond in pred_sql[2]]
            g.write(json.dumps(result) + "\n")
    normal_res = print_metric(in_file, out_file)

    # print("===HydraNet+EG===")
    # pred_sqls = model.predict_SQL_with_EG(engine, pred_data, model_outputs=model_outputs)
    # with open(eg_out_file, "w") as g:
    #     for pred_sql in pred_sqls:
    #         result = {"query": {}}
    #         result["query"]["agg"] = int(pred_sql[0])
    #         result["query"]["sel"] = int(pred_sql[1])
    #         result["query"]["conds"] = [(int(cond[0]), int(cond[1]), str(cond[2])) for cond in pred_sql[2]]
    #         g.write(json.dumps(result) + "\n")
    # eg_res = print_metric(label_file, eg_out_file)

    with open(test_result_file, "w") as g:
        g.write("normal results:\n" + normal_res)
if __name__ == "__main__":
    splits = ["dev", "test"]
    # models = [("20211022_063251", 29), ("20211023_115645", 25)]
    models = [("20211023_115427", 31)]
    # models = [("20211024_072729", 4)]
    for moment, epoch in models:
        for split in splits:
            execute_one_test(split, moment, epoch)
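# Invocation sketch (assumed layout, not verified here): run from the directory
# containing this script, whose parent is added to sys.path for the
# utils/modeling/featurizer imports, with a trained checkpoint under
# output/<model_moment>/ and inputs under data/esql/:
#   python wikisql_prediction_simple.py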