Skip to content

Commit

Permalink
fixes for eko
Browse files Browse the repository at this point in the history
  • Loading branch information
gaa-cifasis committed Oct 17, 2015
1 parent 3d1cf7d commit ab9224d
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 12 deletions.
8 changes: 4 additions & 4 deletions tcreator
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ concatenate = lambda *lists: reduce((lambda a,b: a.extend(b) or a),lists,[])
if __name__ == "__main__":

# Arguments
parser = argparse.ArgumentParser(description='A script to create new test cases using a name and a command line')
parser.add_argument("--name", help="A csv with the features to train or predict", type=str, default=None)
parser = argparse.ArgumentParser(description='A small utility to create new test cases using a name and a command line')
parser.add_argument("--name", help="The name of the ", type=str, default=None)
parser.add_argument("--cmd", help="Command-line to execute", type=str, default=None)
parser.add_argument("--batch", help="A csv with the command lines", type=str, default=None)

parser.add_argument("--copy", help="A csv with the features to train or predict", action='store_true', default=False)
parser.add_argument("--copy", help="Force the copy of the files in command lines instead of symbolic linking", action='store_true', default=False)

parser.add_argument("outdir", help="Output directory to write testcases", type=str, default=None)

Expand Down Expand Up @@ -79,7 +79,7 @@ if __name__ == "__main__":
if arg <> '':
pargs = pargs + arg
#args = concatenate(args)
print pargs
print "Procesing '" + " ".join(pargs) + "'"
#args = filter(lambda x: x is not '', cmd.split(" "))
WriteTestcase(name,pargs[0],pargs[1:], copy)

9 changes: 9 additions & 0 deletions vdiscover/Pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def fit_transform(self, X, y=None, **fit_params):
def fit(self, X, y=None, **fit_params):
    """No-op fit step.

    Nothing is learned from ``X``, ``y`` or ``fit_params``; the instance
    itself is returned so that calls can be chained.
    """
    return self

def get_params(self, deep=True):
    """Return the parameters of this transformer as a dict.

    No parameters are exposed, so the mapping is empty. An empty dict is
    returned rather than a list because scikit-learn's estimator contract
    expects a mapping of parameter names to values.

    :param deep: accepted for API compatibility; unused here.
    :return: an empty dict.
    """
    return {}


class ItemSelector(BaseEstimator, TransformerMixin):

Expand All @@ -59,6 +62,9 @@ def fit(self, x, y=None):
def transform(self, data_dict):
    """Return the entry of ``data_dict`` stored under ``self.key``.

    A missing key propagates whatever error ``data_dict`` raises on lookup.
    """
    selected = data_dict[self.key]
    return selected

def get_params(self, deep=True):
    """Return the parameters of this transformer as a dict.

    No parameters are exposed, so the mapping is empty. An empty dict is
    returned rather than a list because scikit-learn's estimator contract
    expects a mapping of parameter names to values.

    :param deep: accepted for API compatibility; unused here.
    :return: an empty dict.
    """
    return {}

class CutoffMax(BaseEstimator, TransformerMixin):

def __init__(self, maxv):
Expand All @@ -72,6 +78,9 @@ def transform(self, X, y=None, **fit_params):
X[self.pos] = self.maxv
return X

def get_params(self, deep=True):
    """Return the parameters of this transformer as a dict.

    No parameters are exposed, so the mapping is empty. An empty dict is
    returned rather than a list because scikit-learn's estimator contract
    expects a mapping of parameter names to values.

    :param deep: accepted for API compatibility; unused here.
    :return: an empty dict.
    """
    return {}



def make_train_pipeline(ftype):
Expand Down
8 changes: 3 additions & 5 deletions vdiscover/Recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@
def Recall(model_file, in_file, in_type, out_file, test_mode, probability=False):

model = load_model(model_file)
#csvreader = open_csv(in_file)

outfile = open_csv(out_file)
csvwriter = csv.writer(outfile, delimiter='\t')
csvwriter = write_csv(out_file)

x = dict()

Expand All @@ -39,8 +36,9 @@ def Recall(model_file, in_file, in_type, out_file, test_mode, probability=False)
else:
err = recall_score(test_classes, predicted_classes, average=None)

print err[0], err[1], sum(err)/2.0
print classification_report(test_classes, predicted_classes)
print "Errors per class:", err[0], err[1]
print "Average error:", sum(err)/2.0

elif test_mode == "aggregated":

Expand Down
6 changes: 3 additions & 3 deletions vdiscover/Train.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ def TrainScikitLearn(model_file, train_file, valid_file, ftype, nsamples):
model = make_train_pipeline(ftype)
model.fit(train_dict,train_classes)

print "Resulting model:"
print model
print confusion_matrix(train_classes, model.predict(train_dict))
print "Done!"
#print model
#print confusion_matrix(train_classes, model.predict(train_dict))

print "Saving model to",model_file
modelfile.write(pickle.dumps(model))
Expand Down
8 changes: 8 additions & 0 deletions vdiscover/Utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ def load_csv(in_file):

return csv.reader(infile, delimiter='\t')

def write_csv(in_file):
    """Open ``in_file`` for writing and return a tab-delimited csv writer.

    The path is opened through ``gzip.open`` when it contains ".gz" and
    through the builtin ``open`` otherwise, in mode "w" in both cases.

    Only the writer is returned: the underlying file object is never closed
    explicitly by this function.
    """
    opener = gzip.open if ".gz" in in_file else open
    handle = opener(in_file, "w")
    return csv.writer(handle, delimiter='\t')

def open_csv(in_file):

Expand Down

0 comments on commit ab9224d

Please sign in to comment.