-
Notifications
You must be signed in to change notification settings - Fork 0
/
load_data.py
112 lines (88 loc) · 3.95 KB
/
load_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class getReader():
    """Read a knowledge-tracing sequence file laid out in groups of three lines.

    Each group is: a header line (ignored by this reader), a comma-separated
    line of integer problem ids, and a comma-separated line of answers.
    Answers are parsed as floats and truncated to ``int`` so that tokens such
    as ``"1.0"`` are accepted.
    """

    def __init__(self, path):
        # Path of the text file to parse.
        self.path = path

    def readData(self):
        """Parse the file at ``self.path``.

        Returns:
            (problem_list, ans_list): two parallel lists holding one inner
            list of ints per sequence (problem ids and answers respectively).

        Raises:
            OSError: if the file cannot be opened.
            ValueError: if a non-empty token is not numeric.
        """
        problem_list = []
        ans_list = []
        split_char = ','
        # `with` guarantees the handle is closed even if parsing raises.
        with open(self.path, 'r') as read:
            for index, line in enumerate(read):
                position = index % 3
                if position == 0:
                    # Header line of each group; not used here.
                    continue
                # Drop empty tokens (e.g. from a trailing comma) instead of
                # crashing on int('') / float('').
                tokens = [tok for tok in line.strip().split(split_char) if tok.strip()]
                if position == 1:
                    # Problem ids are read as text and must become ints.
                    problem_list.append([int(tok) for tok in tokens])
                else:
                    # Answers may be written as floats ("1.0"); convert to int.
                    ans_list.append([int(float(tok)) for tok in tokens])
        return problem_list, ans_list
import torch.utils.data as data
import numpy as np
class KT_Dataset(data.Dataset):
    """Map-style dataset of fixed-length, left-padded interaction windows.

    Sequences shorter than ``min_problem_num`` are dropped. Sequences longer
    than ``max_problem_num`` are cut into ``max_problem_num``-sized chunks
    aligned to the END of the sequence; any leftover prefix is kept as its
    own, shorter sample.
    """

    def __init__(self, problem_max, problem_list, ans_list, min_problem_num, max_problem_num):
        """
        Args:
            problem_max: stored on the instance; not otherwise used here.
            problem_list: per-sequence lists of problem ids.
            ans_list: per-sequence lists of answers, parallel to problem_list.
            min_problem_num: sequences with fewer entries are discarded.
            max_problem_num: window length every sample is cut/padded to.
        """
        self.problem_max = problem_max
        self.min_problem_num = min_problem_num
        self.max_problem_num = max_problem_num
        self.problem_list, self.ans_list = [], []
        for problem, ans in zip(problem_list, ans_list):
            num = len(problem)
            if num < min_problem_num:
                continue
            if num <= max_problem_num:
                self.problem_list.append(problem)
                self.ans_list.append(ans)
                continue
            # Leading remainder that does not fill a whole window.
            # NOTE(review): this remainder is not checked against
            # min_problem_num and may be as short as one entry.
            head = num % max_problem_num
            if head:
                self.problem_list.append(problem[:head])
                self.ans_list.append(ans[:head])
            # Full windows, aligned so the last one ends at the sequence end.
            for start in range(head, num, max_problem_num):
                self.problem_list.append(problem[start:start + max_problem_num])
                self.ans_list.append(ans[start:start + max_problem_num])

    def __len__(self):
        return len(self.problem_list)

    def __getitem__(self, index):
        """Return ``(last_problem, last_ans, next_problem, next_ans, mask)``.

        The sample is right-aligned (left-padded with 0) to length
        ``max_problem_num`` and then shifted by one step: ``last_*`` covers
        positions ``[0, L-1)`` and ``next_*`` covers ``[1, L)``. ``mask`` is a
        bool tensor that is True exactly at the ``num - 1`` positions where
        both the ``last_*`` and ``next_*`` entries come from real data.
        All tensors are moved to CUDA when available, otherwise to CPU.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        length = self.max_problem_num
        now_problem = np.array(self.problem_list[index])
        now_ans = self.ans_list[index]
        num = len(now_problem)

        # Pad to a uniform length. An explicit start index is used because the
        # negative form `[-num:]` turns into `[0:]` when num == 0.
        start = length - num
        use_problem = np.zeros(length, dtype=int)
        use_ans = np.zeros(length, dtype=int)
        use_problem[start:] = now_problem
        use_ans[start:] = now_ans

        # num entries yield num - 1 (last -> next) transitions, located at the
        # tail of the shifted window. Avoid `[-num + 1:]`: for num == 1 it
        # becomes `[0:]` and would mark every padded position as valid.
        mask = np.zeros(length - 1, dtype=bool)
        mask[start:] = True

        last_problem = torch.from_numpy(use_problem[:-1]).to(device).long()
        next_problem = torch.from_numpy(use_problem[1:]).to(device).long()
        last_ans = torch.from_numpy(use_ans[:-1]).to(device).long()
        next_ans = torch.from_numpy(use_ans[1:]).to(device).float()
        return last_problem, last_ans, next_problem, next_ans, torch.from_numpy(mask).to(device)
def getLoader(problem_max, pro_path, batch_size, is_train, min_problem_num, max_problem_num):
    """Build a DataLoader over the sequences stored in ``pro_path``.

    The file is parsed with ``getReader`` and wrapped in ``KT_Dataset``.
    Batches are shuffled only when ``is_train`` is truthy.
    """
    reader = getReader(pro_path)
    problems, answers = reader.readData()
    kt_dataset = KT_Dataset(problem_max, problems, answers, min_problem_num, max_problem_num)
    return data.DataLoader(kt_dataset, batch_size=batch_size, shuffle=is_train)