Skip to content

Commit 3a4d30a

Browse files
committed
add reduce_data.py
1 parent adb1f1b commit 3a4d30a

File tree

1 file changed

+63
-0
lines changed

1 file changed

+63
-0
lines changed

Diff for: reduce_data.py

+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import os
2+
import re
3+
import argparse
4+
import random
5+
6+
7+
def crossval_files(prefix, numfolds):
    """Build the (train, test) file-name pair for every cross-validation fold.

    Fold ``k`` (numbered from 0 to ``numfolds - 1``) maps to
    ``'<prefix>train<k>.types'`` and ``'<prefix>test<k>.types'``.

    Returns:
        A list of ``(trainfile, testfile)`` tuples in fold order; empty when
        ``numfolds`` is 0.
    """
    return [
        (
            '{}train{}.types'.format(prefix, fold),
            '{}test{}.types'.format(prefix, fold),
        )
        for fold in range(numfolds)
    ]
14+
15+
16+
def reduced_file(file):
    """Return the name of the reduced counterpart of a ``.types`` file.

    ``'_reduced'`` is inserted immediately before the optional
    ``train<N>`` / ``test<N>`` fold marker and the ``.types`` extension, e.g.
    ``'settrain0.types' -> 'set_reducedtrain0.types'`` and
    ``'set.types' -> 'set_reduced.types'``.

    Args:
        file: Path of the original file; must end in ``.types``.

    Returns:
        The path with ``'_reduced'`` inserted.

    Raises:
        ValueError: If ``file`` does not end in ``.types``.
    """
    # The dot is escaped and the pattern is anchored at the end of the name.
    # Otherwise '.types' matches any character followed by 'types' anywhere
    # in the path (e.g. inside 'mytypes_train0.types') and the remainder of
    # the name is silently dropped.
    match = re.match(r'(.*?)((?:(?:train|test)[0-9]+)?\.types)\Z', file)
    if match is None:
        raise ValueError(
            "expected a file name ending in '.types', got {!r}".format(file)
        )
    stem, suffix = match.groups()
    return stem + '_reduced' + suffix
19+
20+
21+
def read_lines(file):
    """Read a text file and return its lines as a list.

    The file is opened in text mode; every returned element keeps its line
    terminator, exactly as produced by iterating over the file object.
    """
    with open(file, 'r') as handle:
        return list(handle)
25+
26+
27+
def write_reduced_lines(file, lines, factor):
    """Write a random ``1/factor`` subset of ``lines`` to ``file``.

    The lines are shuffled with the module-level ``random`` generator (seed
    it beforehand for reproducible output) and the first
    ``int(len(lines) / factor)`` of them are written, overwriting ``file``.
    A ``factor`` below 1 keeps every line, since the slice is clipped to the
    list length.

    Args:
        file: Path of the output file.
        lines: Sequence of line strings, each carrying its own terminator.
            It is not modified.
        factor: Positive reduction factor.

    Raises:
        ValueError: If ``factor`` is not greater than zero.
    """
    # Reject bad factors before touching the output file: zero would divide
    # by zero and a negative value would produce a negative slice bound that
    # silently selects the wrong lines.
    if factor <= 0:
        raise ValueError('factor must be positive, got {!r}'.format(factor))

    # Shuffle a copy so the caller's list keeps its order; the random stream
    # consumed is the same as shuffling the original in place.
    shuffled = list(lines)
    random.shuffle(shuffled)

    keep = int(len(shuffled) / factor)
    with open(file, 'w') as f:
        f.writelines(shuffled[:keep])
32+
33+
34+
def parse_args(argv=None):
    """Parse the command-line options of the script.

    Options:
        -p / --prefix       required string.
        -n / --numfolds     int, default 3.
        -a / --allfolds     flag, default False.
        -f / --factor       required float.
        -s / --random_seed  int, default 0.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    # (flags, keyword settings) in the order they are registered.
    option_specs = (
        (('-p', '--prefix'), {'required': True}),
        (('-n', '--numfolds'), {'type': int, 'default': 3}),
        (('-a', '--allfolds'), {'default': False, 'action': 'store_true'}),
        (('-f', '--factor'), {'required': True, 'type': float}),
        (('-s', '--random_seed'), {'type': int, 'default': 0}),
    )
    parser = argparse.ArgumentParser()
    for flags, settings in option_specs:
        parser.add_argument(*flags, **settings)
    return parser.parse_args(argv)
42+
43+
44+
if __name__ == '__main__':
    # Script entry point: seed the RNG, then write a randomly reduced copy of
    # every fold's train and test file (train before test, in fold order) and,
    # when --allfolds is given, of '<prefix>.types' as well. The name of each
    # file written is printed.
    args = parse_args()
    random.seed(args.random_seed)

    sources = [
        path
        for fold_pair in crossval_files(args.prefix, args.numfolds)
        for path in fold_pair
    ]
    if args.allfolds:
        sources.append('{}.types'.format(args.prefix))

    for source in sources:
        content = read_lines(source)
        target = reduced_file(source)
        write_reduced_lines(target, content, args.factor)
        print(target)
63+

0 commit comments

Comments
 (0)