Skip to content

Commit

Permalink
reprocess record types
Browse files Browse the repository at this point in the history
  • Loading branch information
mr-c committed Oct 4, 2016
1 parent 278a62b commit 8644b63
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 9 deletions.
29 changes: 21 additions & 8 deletions cwltool/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,6 @@ def output_callback(out, processStatus):

return final_output[0]


class FSAction(argparse.Action):
objclass = None # type: Text

Expand Down Expand Up @@ -294,8 +293,9 @@ class DirectoryAppendAction(FSAppendAction):
objclass = "Directory"


def add_argument(toolparser, name, inptype, description="", default=None):
# type: (argparse.ArgumentParser, Text, Any, Text, Any) -> None
def add_argument(toolparser, name, inptype, records, description="",
default=None):
# type: (argparse.ArgumentParser, Text, Any, List[Text], Text, Any) -> None
if len(name) == 1:
flag = "-"
else:
Expand Down Expand Up @@ -329,12 +329,14 @@ def add_argument(toolparser, name, inptype, description="", default=None):
elif isinstance(inptype, dict) and inptype["type"] == "enum":
atype = Text
elif isinstance(inptype, dict) and inptype["type"] == "record":
records.append(name)
for field in inptype['fields']:
fieldname = name+"."+shortname(field['name'])
fieldtype = field['type']
fielddescription = field.get("doc", "")
add_argument(
toolparser, fieldname, fieldtype, fielddescription)
toolparser, fieldname, fieldtype, records,
fielddescription)
return
if inptype == "string":
atype = Text
Expand Down Expand Up @@ -364,8 +366,8 @@ def add_argument(toolparser, name, inptype, description="", default=None):
default=default, **typekw)


def generate_parser(toolparser, tool, namemap):
# type: (argparse.ArgumentParser, Process, Dict[Text, Text]) -> argparse.ArgumentParser
def generate_parser(toolparser, tool, namemap, records):
# type: (argparse.ArgumentParser, Process, Dict[Text, Text], List[Text]) -> argparse.ArgumentParser
toolparser.add_argument("job_order", nargs="?", help="Job input json file")
namemap["job_order"] = "job_order"

Expand All @@ -375,7 +377,7 @@ def generate_parser(toolparser, tool, namemap):
inptype = inp["type"]
description = inp.get("doc", "")
default = inp.get("default", None)
add_argument(toolparser, name, inptype, description, default)
add_argument(toolparser, name, inptype, records, description, default)

return toolparser

Expand Down Expand Up @@ -418,12 +420,23 @@ def load_job_order(args, t, stdin, print_input_deps=False, relative_deps=False,
else:
input_basedir = args.basedir if args.basedir else os.getcwd()
namemap = {} # type: Dict[Text, Text]
toolparser = generate_parser(argparse.ArgumentParser(prog=args.workflow), t, namemap)
records = [] # type: List[Text]
toolparser = generate_parser(
argparse.ArgumentParser(prog=args.workflow), t, namemap, records)
if toolparser:
if args.tool_help:
toolparser.print_help()
return 0
cmd_line = vars(toolparser.parse_args(args.job_order))
for record_name in records:
record = {}
record_items = {
k:v for k,v in cmd_line.iteritems()
if k.startswith(record_name)}
for key, value in record_items.iteritems():
record[key[len(record_name)+1:]] = value
del cmd_line[key]
cmd_line[str(record_name)] = record

if cmd_line["job_order"]:
try:
Expand Down
15 changes: 14 additions & 1 deletion tests/test_toolargparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ class ToolArgparse(unittest.TestCase):
one: File
two: string
expression: $(inputs.foo.two)
outputs: []
'''

Expand All @@ -77,7 +79,7 @@ def test_bool(self):
except SystemExit as e:
self.assertEquals(e.code, 0)

def test_record(self):
def test_record_help(self):
with NamedTemporaryFile() as f:
f.write(self.script3)
f.flush()
Expand All @@ -86,6 +88,17 @@ def test_record(self):
except SystemExit as e:
self.assertEquals(e.code, 0)

def test_record(self):
with NamedTemporaryFile() as f:
f.write(self.script3)
f.flush()
try:
self.assertEquals(main([f.name, '--foo.one', 'README.rst',
'--foo.two', 'test']), 0)
except SystemExit as e:
self.assertEquals(e.code, 0)



if __name__ == '__main__':
unittest.main()

0 comments on commit 8644b63

Please sign in to comment.