|
| 1 | +import os |
| 2 | +import re |
| 3 | +import argparse |
| 4 | +import random |
| 5 | + |
| 6 | + |
| 7 | +def crossval_files(prefix, numfolds): |
| 8 | + cvfiles = [] |
| 9 | + for i in range(numfolds): |
| 10 | + trainfile = '{}train{}.types'.format(prefix, i) |
| 11 | + testfile = '{}test{}.types'.format(prefix, i) |
| 12 | + cvfiles.append((trainfile, testfile)) |
| 13 | + return cvfiles |
| 14 | + |
| 15 | + |
| 16 | +def reduced_file(file): |
| 17 | + match = re.match('(.*?)(((train|test)[0-9]+)?.types)', file) |
| 18 | + return match.group(1) + '_reduced' + match.group(2) |
| 19 | + |
| 20 | + |
| 21 | +def read_lines(file): |
| 22 | + with open(file, 'r') as f: |
| 23 | + lines = f.readlines() |
| 24 | + return lines |
| 25 | + |
| 26 | + |
| 27 | +def write_reduced_lines(file, lines, factor): |
| 28 | + random.shuffle(lines) |
| 29 | + reduced = lines[:int(len(lines)/factor)] |
| 30 | + with open(file, 'w') as f: |
| 31 | + f.write(''.join(reduced)) |
| 32 | + |
| 33 | + |
| 34 | +def parse_args(argv=None): |
| 35 | + parser = argparse.ArgumentParser() |
| 36 | + parser.add_argument('-p', '--prefix', required=True) |
| 37 | + parser.add_argument('-n', '--numfolds', type=int, default=3) |
| 38 | + parser.add_argument('-a', '--allfolds', default=False, action='store_true') |
| 39 | + parser.add_argument('-f', '--factor', required=True, type=float) |
| 40 | + parser.add_argument('-s', '--random_seed', type=int, default=0) |
| 41 | + return parser.parse_args(argv) |
| 42 | + |
| 43 | + |
| 44 | +if __name__ == '__main__': |
| 45 | + args = parse_args() |
| 46 | + random.seed(args.random_seed) |
| 47 | + cvfiles = crossval_files(args.prefix, args.numfolds) |
| 48 | + for i, (trainfile, testfile) in enumerate(cvfiles): |
| 49 | + train = read_lines(trainfile) |
| 50 | + reduced_trainfile = reduced_file(trainfile) |
| 51 | + write_reduced_lines(reduced_trainfile, train, args.factor) |
| 52 | + print(reduced_trainfile) |
| 53 | + test = read_lines(testfile) |
| 54 | + reduced_testfile = reduced_file(testfile) |
| 55 | + write_reduced_lines(reduced_testfile, test, args.factor) |
| 56 | + print(reduced_testfile) |
| 57 | + if args.allfolds: |
| 58 | + allfile = '{}.types'.format(args.prefix) |
| 59 | + all = read_lines(allfile) |
| 60 | + reduced_allfile = reduced_file(allfile) |
| 61 | + write_reduced_lines(reduced_allfile, all, args.factor) |
| 62 | + print(reduced_allfile) |
| 63 | + |
0 commit comments