Skip to content

Commit a22fdcf

Browse files
mvdbeek and nsoranzo
authored and committed
Keep directory parameters in job parameters
and set LoadListingRequirement as supported requirement, since cwltool handles this for us.
1 parent 34e5701 commit a22fdcf

File tree

4 files changed

+61
-44
lines changed

4 files changed

+61
-44
lines changed

lib/galaxy/tool_util/cwl/parser.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
"EnvVarRequirement",
7171
"InitialWorkDirRequirement",
7272
"InlineJavascriptRequirement",
73+
"LoadListingRequirement",
7374
"ResourceRequirement",
7475
"ShellCommandRequirement",
7576
"ScatterFeatureRequirement",
@@ -346,15 +347,16 @@ def _ensure_cwl_job_initialized(self):
346347
self._is_command_line_job = hasattr(self._cwl_job, "command_line")
347348

348349
def _normalize_job(self):
350+
runtime_context = RuntimeContext({})
351+
make_fs_access = getdefault(runtime_context.make_fs_access, StdFsAccess)
352+
fs_access = make_fs_access(runtime_context.basedir)
353+
349354
# Somehow reuse whatever causes validate in cwltool... maybe?
350355
def pathToLoc(p):
351356
if "location" not in p and "path" in p:
352357
p["location"] = p["path"]
353358
del p["path"]
354359

355-
runtime_context = RuntimeContext({})
356-
make_fs_access = getdefault(runtime_context.make_fs_access, StdFsAccess)
357-
fs_access = make_fs_access(runtime_context.basedir)
358360
process.fill_in_defaults(self._tool_proxy._tool.tool["inputs"], self._input_dict, fs_access)
359361
visit_class(self._input_dict, ("File", "Directory"), pathToLoc)
360362
# TODO: Why doesn't fillInDefault fill in locations instead of paths?

lib/galaxy/tool_util/cwl/representation.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from enum import Enum
88
from typing import (
99
Any,
10+
Dict,
1011
NamedTuple,
1112
Optional,
1213
)
@@ -206,6 +207,9 @@ def dataset_wrapper_to_directory_json(inputs_dir, dataset_wrapper):
206207
except Exception:
207208
archive_location = None
208209

210+
extra_params = getattr(dataset_wrapper.unsanitized, "extra_params", {})
211+
# We need to resolve path to location if there is a listing
212+
209213
directory_json = {
210214
"location": dataset_wrapper.extra_files_path,
211215
"class": "Directory",
@@ -214,8 +218,20 @@ def dataset_wrapper_to_directory_json(inputs_dir, dataset_wrapper):
214218
"archive_nameext": nameext,
215219
"archive_nameroot": nameroot,
216220
}
217-
218-
return directory_json
221+
extra_params.update(directory_json)
222+
entry_to_location(extra_params, extra_params["location"])
223+
return extra_params
224+
225+
226+
def entry_to_location(entry: Dict[str, Any], parent_location: str):
    """Resolve a relative ``path`` on a CWL File/Directory entry to a ``location``.

    Mutates ``entry`` in place: when the entry carries a ``path`` but no
    ``location``, the ``path`` key is popped and replaced by a ``location``
    joined onto ``parent_location``.  Directory entries with a ``listing``
    are processed recursively, each child resolved against the directory's
    own (possibly just-computed) ``location``.
    """
    # TODO unit test
    entry_class = entry["class"]
    missing_location = "location" not in entry and "path" in entry
    if entry_class == "File":
        if missing_location:
            entry["location"] = os.path.join(parent_location, entry.pop("path"))
    elif entry_class == "Directory" and "listing" in entry:
        if missing_location:
            entry["location"] = os.path.join(parent_location, entry.pop("path"))
        # NOTE(review): assumes the directory has a "location" by this point
        # (either pre-existing or derived from "path") — confirm callers
        # always provide one of the two.
        for child in entry["listing"]:
            entry_to_location(child, parent_location=entry["location"])
219235

220236

221237
def collection_wrapper_to_array(inputs_dir, wrapped_value):

lib/galaxy/tool_util/cwl/util.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,9 @@ def replacement_directory(value):
312312
finally:
313313
if temp_dir:
314314
shutil.rmtree(temp_dir)
315-
return upload_tar(tmp.name)
315+
upload_response = upload_tar(tmp.name)
316+
upload_response.update(value)
317+
return upload_response
316318

317319
def replacement_list(value) -> Dict[str, str]:
318320
collection_element_identifiers = []

lib/galaxy/tools/parameters/basic.py

+35-38
Original file line numberDiff line numberDiff line change
@@ -1899,7 +1899,10 @@ def single_to_json(value):
18991899
src = "hda"
19001900
if src is not None:
19011901
object_id = cached_id(value)
1902-
return {"id": app.security.encode_id(object_id) if use_security else object_id, "src": src}
1902+
new_val = getattr(value, "extra_params", {})
1903+
new_val["id"] = app.security.encode_id(object_id) if use_security else object_id
1904+
new_val["src"] = src
1905+
return new_val
19031906

19041907
if value not in [None, "", "None"]:
19051908
if isinstance(value, list) and len(value) > 0:
@@ -1912,15 +1915,9 @@ def single_to_json(value):
19121915
def to_python(self, value, app):
19131916
def single_to_python(value):
19141917
if isinstance(value, dict) and "src" in value:
1915-
id = value["id"] if isinstance(value["id"], int) else app.security.decode_id(value["id"])
1916-
if value["src"] == "dce":
1917-
return app.model.context.query(DatasetCollectionElement).get(id)
1918-
elif value["src"] == "hdca":
1919-
return app.model.context.query(HistoryDatasetCollectionAssociation).get(id)
1920-
elif value["src"] == "ldda":
1921-
return app.model.context.query(LibraryDatasetDatasetAssociation).get(id)
1922-
else:
1923-
return app.model.context.query(HistoryDatasetAssociation).get(id)
1918+
if not value["src"] in ("hda", "dce", "ldda", "hdca"):
1919+
raise ParameterValueError(f"Invalid value {value}", self.name)
1920+
return src_id_to_item(sa_session=app.model.context, security=app.security, value=value)
19241921

19251922
if isinstance(value, dict) and "values" in value:
19261923
if hasattr(self, "multiple") and self.multiple is True:
@@ -1996,6 +1993,23 @@ def do_validate(v):
19961993
raise ValueError("At most %d datasets are required for %s" % (self.max, self.name))
19971994

19981995

1996+
def src_id_to_item(sa_session, value, security):
    """Load the model object referenced by a ``{"src": ..., "id": ...}`` payload.

    ``value["src"]`` selects the model class (hda/ldda/dce/hdca) and
    ``value["id"]`` is either an integer database id or an encoded id that
    ``security.decode_id`` can reverse.  Any remaining keys of ``value`` are
    stashed on the loaded item as ``extra_params`` so downstream code (e.g.
    CWL directory handling) can recover them.

    Raises ``ValueError`` for an unrecognized ``src``.
    """
    src_to_class = {
        "hda": HistoryDatasetAssociation,
        "ldda": LibraryDatasetDatasetAssociation,
        "dce": DatasetCollectionElement,
        "hdca": HistoryDatasetCollectionAssociation,
    }
    # Keep the try narrow: only the src lookup should be reported as an
    # unknown source — a KeyError escaping the ORM query must not be
    # swallowed and mislabeled.
    try:
        model_class = src_to_class[value["src"]]
    except KeyError:
        raise ValueError(f"Unknown input source {value['src']} passed to job submission API.") from None
    id_value = value["id"]
    decoded_id = id_value if isinstance(id_value, int) else security.decode_id(id_value)
    item = sa_session.query(model_class).get(decoded_id)
    item.extra_params = {k: v for k, v in value.items() if k not in ("src", "id")}
    return item
2011+
2012+
19992013
class DataToolParameter(BaseDataToolParameter):
20002014
# TODO, Nate: Make sure the following unit tests appropriately test the dataset security
20012015
# components. Add as many additional tests as necessary.
@@ -2063,21 +2077,13 @@ def from_json(self, value, trans, other_values=None):
20632077
value = [int(value_part) for value_part in value.split(",")]
20642078
rval = []
20652079
if isinstance(value, list):
2066-
found_hdca = False
2080+
found_srcs = set()
20672081
for single_value in value:
20682082
if isinstance(single_value, dict) and "src" in single_value and "id" in single_value:
2069-
if single_value["src"] == "hda":
2070-
decoded_id = trans.security.decode_id(single_value["id"])
2071-
rval.append(trans.sa_session.query(HistoryDatasetAssociation).get(decoded_id))
2072-
elif single_value["src"] == "hdca":
2073-
found_hdca = True
2074-
decoded_id = trans.security.decode_id(single_value["id"])
2075-
rval.append(trans.sa_session.query(HistoryDatasetCollectionAssociation).get(decoded_id))
2076-
elif single_value["src"] == "ldda":
2077-
decoded_id = trans.security.decode_id(single_value["id"])
2078-
rval.append(trans.sa_session.query(LibraryDatasetDatasetAssociation).get(decoded_id))
2079-
else:
2080-
raise ValueError(f"Unknown input source {single_value['src']} passed to job submission API.")
2083+
found_srcs.add(single_value["src"])
2084+
rval.append(
2085+
src_id_to_item(sa_session=trans.sa_session, value=single_value, security=trans.security)
2086+
)
20812087
elif isinstance(
20822088
single_value,
20832089
(
@@ -2095,24 +2101,15 @@ def from_json(self, value, trans, other_values=None):
20952101
log.warning("Encoded ID where unencoded ID expected.")
20962102
single_value = trans.security.decode_id(single_value)
20972103
rval.append(trans.sa_session.query(HistoryDatasetAssociation).get(single_value))
2098-
if found_hdca:
2099-
for val in rval:
2100-
if not isinstance(val, HistoryDatasetCollectionAssociation):
2101-
raise ParameterValueError(
2102-
"if collections are supplied to multiple data input parameter, only collections may be used",
2103-
self.name,
2104-
)
2104+
if len(found_srcs) > 1 and "hdca" in found_srcs:
2105+
raise ParameterValueError(
2106+
"if collections are supplied to multiple data input parameter, only collections may be used",
2107+
self.name,
2108+
)
21052109
elif isinstance(value, (HistoryDatasetAssociation, LibraryDatasetDatasetAssociation)):
21062110
rval.append(value)
21072111
elif isinstance(value, dict) and "src" in value and "id" in value:
2108-
if value["src"] == "hda":
2109-
decoded_id = trans.security.decode_id(value["id"])
2110-
rval.append(trans.sa_session.query(HistoryDatasetAssociation).get(decoded_id))
2111-
elif value["src"] == "hdca":
2112-
decoded_id = trans.security.decode_id(value["id"])
2113-
rval.append(trans.sa_session.query(HistoryDatasetCollectionAssociation).get(decoded_id))
2114-
else:
2115-
raise ValueError(f"Unknown input source {value['src']} passed to job submission API.")
2112+
rval.append(src_id_to_item(sa_session=trans.sa_session, value=value, security=trans.security))
21162113
elif str(value).startswith("__collection_reduce__|"):
21172114
encoded_ids = [v[len("__collection_reduce__|") :] for v in str(value).split(",")]
21182115
decoded_ids = map(trans.security.decode_id, encoded_ids)

0 commit comments

Comments (0)