comparison env/lib/python3.9/site-packages/cwltool/workflow.py @ 0:4f3585e2f14b draft default tip

"planemo upload commit 60cee0fc7c0cda8592644e1aad72851dec82c959"
author shellac
date Mon, 22 Mar 2021 18:12:50 +0000
import copy
import datetime
import functools
import logging
import random
from typing import (
    Callable,
    Dict,
    List,
    Mapping,
    MutableMapping,
    MutableSequence,
    Optional,
    cast,
)
from uuid import UUID

from ruamel.yaml.comments import CommentedMap
from schema_salad.exceptions import ValidationException
from schema_salad.sourceline import SourceLine, indent

from . import command_line_tool, context, procgenerator
from .checker import static_checker
from .context import LoadingContext, RuntimeContext, getdefault
from .errors import WorkflowException
from .load_tool import load_tool
from .loghandler import _logger
from .process import Process, get_overrides, shortname
from .provenance_profile import ProvenanceProfile
from .utils import (
    CWLObjectType,
    JobsGeneratorType,
    OutputCallbackType,
    StepType,
    aslist,
)
from .workflow_job import WorkflowJob


def default_make_tool(
    toolpath_object: CommentedMap, loadingContext: LoadingContext
) -> Process:
    """Instantiate the given CWL document fragment as the matching Process subclass."""
    if not isinstance(toolpath_object, MutableMapping):
        raise WorkflowException("Not a dict: '%s'" % toolpath_object)
    if "class" in toolpath_object:
        if toolpath_object["class"] == "CommandLineTool":
            return command_line_tool.CommandLineTool(toolpath_object, loadingContext)
        if toolpath_object["class"] == "ExpressionTool":
            return command_line_tool.ExpressionTool(toolpath_object, loadingContext)
        if toolpath_object["class"] == "Workflow":
            return Workflow(toolpath_object, loadingContext)
        if toolpath_object["class"] == "ProcessGenerator":
            return procgenerator.ProcessGenerator(toolpath_object, loadingContext)
        if toolpath_object["class"] == "Operation":
            return command_line_tool.AbstractOperation(toolpath_object, loadingContext)

    raise WorkflowException(
        "Missing or invalid 'class' field in "
        "%s, expecting one of: CommandLineTool, ExpressionTool, Workflow, "
        "ProcessGenerator, Operation" % toolpath_object["id"]
    )


context.default_make_tool = default_make_tool


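# A minimal sketch of how the dispatcher above is typically reached (hedged;
# the variable names below are illustrative, not part of this module): loaders
# call loadingContext.construct_tool_object, which defaults to the function
# registered here.
#
#     process = default_make_tool(parsed_tool_map, loading_context)
#     # -> CommandLineTool / ExpressionTool / Workflow / ... depending on the
#     #    "class" field; an unrecognized or missing class raises WorkflowException.

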
class Workflow(Process):
    def __init__(
        self,
        toolpath_object: CommentedMap,
        loadingContext: LoadingContext,
    ) -> None:
        """Initialize this Workflow."""
        super().__init__(toolpath_object, loadingContext)
        self.provenance_object = None  # type: Optional[ProvenanceProfile]
        if loadingContext.research_obj is not None:
            run_uuid = None  # type: Optional[UUID]
            is_main = not loadingContext.prov_obj  # Not yet set
            if is_main:
                run_uuid = loadingContext.research_obj.ro_uuid

            self.provenance_object = ProvenanceProfile(
                loadingContext.research_obj,
                full_name=loadingContext.cwl_full_name,
                host_provenance=loadingContext.host_provenance,
                user_provenance=loadingContext.user_provenance,
                orcid=loadingContext.orcid,
                run_uuid=run_uuid,
                fsaccess=loadingContext.research_obj.fsaccess,
            )  # inherit RO UUID for main wf run
            # TODO: Is Workflow(..) only called when we are the main workflow?
            self.parent_wf = self.provenance_object

        # FIXME: Won't this overwrite prov_obj for nested workflows?
        loadingContext.prov_obj = self.provenance_object
        loadingContext = loadingContext.copy()
        loadingContext.requirements = self.requirements
        loadingContext.hints = self.hints

        self.steps = []  # type: List[WorkflowStep]
        validation_errors = []
        for index, step in enumerate(self.tool.get("steps", [])):
            try:
                self.steps.append(
                    self.make_workflow_step(
                        step, index, loadingContext, loadingContext.prov_obj
                    )
                )
            except ValidationException as vexc:
                if _logger.isEnabledFor(logging.DEBUG):
                    _logger.exception("Validation failed at")
                validation_errors.append(vexc)

        if validation_errors:
            raise ValidationException("\n".join(str(v) for v in validation_errors))

        # Shuffle the step list, presumably so nothing downstream relies on the
        # order in which steps are declared in the document.
        random.shuffle(self.steps)

        # statically validate data links instead of doing it at runtime.
        workflow_inputs = self.tool["inputs"]
        workflow_outputs = self.tool["outputs"]

        step_inputs = []  # type: List[CWLObjectType]
        step_outputs = []  # type: List[CWLObjectType]
        param_to_step = {}  # type: Dict[str, CWLObjectType]
        for step in self.steps:
            step_inputs.extend(step.tool["inputs"])
            step_outputs.extend(step.tool["outputs"])
            for s in step.tool["inputs"]:
                param_to_step[s["id"]] = step.tool
            for s in step.tool["outputs"]:
                param_to_step[s["id"]] = step.tool

        if getdefault(loadingContext.do_validate, True):
            static_checker(
                workflow_inputs,
                workflow_outputs,
                step_inputs,
                step_outputs,
                param_to_step,
            )
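        # For example (hypothetical IDs): a step "#main/sort" declaring an input
        # "#main/sort/infile" contributes the entry
        #     param_to_step["#main/sort/infile"] -> that step's tool dict
        # so static_checker can trace every data link back to the step that
        # declares the parameter.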

    def make_workflow_step(
        self,
        toolpath_object: CommentedMap,
        pos: int,
        loadingContext: LoadingContext,
        parentworkflowProv: Optional[ProvenanceProfile] = None,
    ) -> "WorkflowStep":
        return WorkflowStep(toolpath_object, pos, loadingContext, parentworkflowProv)

    def job(
        self,
        job_order: CWLObjectType,
        output_callbacks: Optional[OutputCallbackType],
        runtimeContext: RuntimeContext,
    ) -> JobsGeneratorType:
        builder = self._init_job(job_order, runtimeContext)

        if runtimeContext.research_obj is not None:
            if runtimeContext.toplevel:
                # Record primary-job.json
                runtimeContext.research_obj.fsaccess = runtimeContext.make_fs_access("")
                runtimeContext.research_obj.create_job(builder.job)

        job = WorkflowJob(self, runtimeContext)
        yield job

        runtimeContext = runtimeContext.copy()
        runtimeContext.part_of = "workflow %s" % job.name
        runtimeContext.toplevel = False

        yield from job.job(builder.job, output_callbacks, runtimeContext)

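    # Note on Workflow.job() above: it yields the WorkflowJob object itself first
    # and only then delegates to job.job(...) for the individual runnable step
    # jobs. A minimal sketch of consumption, assuming the caller drives each
    # yielded job (as cwltool's executor does):
    #
    #     for runnable in wf.job(job_order, callback, runtime_ctx):
    #         if runnable is not None:
    #             runnable.run(runtime_ctx)  # assumption: executor calls run()
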
    def visit(self, op: Callable[[CommentedMap], None]) -> None:
        op(self.tool)
        for step in self.steps:
            step.visit(op)


def used_by_step(step: StepType, shortinputid: str) -> bool:
    for st in cast(MutableSequence[CWLObjectType], step["in"]):
        if st.get("valueFrom"):
            if ("inputs.%s" % shortinputid) in cast(str, st.get("valueFrom")):
                return True
    if step.get("when"):
        if ("inputs.%s" % shortinputid) in cast(str, step.get("when")):
            return True
    return False

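# used_by_step() above performs a purely textual check. For example, given a
# hypothetical step "in" entry
#     {"id": "#main/echo/msg", "valueFrom": "$(inputs.msg.toUpperCase())"}
# used_by_step(step, "msg") returns True because the substring "inputs.msg"
# occurs in the valueFrom expression; the step-level "when" expression is
# tested the same way.

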
class WorkflowStep(Process):
    def __init__(
        self,
        toolpath_object: CommentedMap,
        pos: int,
        loadingContext: LoadingContext,
        parentworkflowProv: Optional[ProvenanceProfile] = None,
    ) -> None:
        """Initialize this WorkflowStep."""
        if "id" in toolpath_object:
            self.id = toolpath_object["id"]
        else:
            self.id = "#step" + str(pos)

        loadingContext = loadingContext.copy()

        loadingContext.requirements = copy.deepcopy(
            getdefault(loadingContext.requirements, [])
        )
        assert loadingContext.requirements is not None  # nosec
        loadingContext.requirements.extend(toolpath_object.get("requirements", []))
        loadingContext.requirements.extend(
            cast(
                List[CWLObjectType],
                get_overrides(
                    getdefault(loadingContext.overrides_list, []), self.id
                ).get("requirements", []),
            )
        )

        hints = copy.deepcopy(getdefault(loadingContext.hints, []))
        hints.extend(toolpath_object.get("hints", []))
        loadingContext.hints = hints

        try:
            if isinstance(toolpath_object["run"], CommentedMap):
                self.embedded_tool = loadingContext.construct_tool_object(
                    toolpath_object["run"], loadingContext
                )  # type: Process
            else:
                loadingContext.metadata = {}
                self.embedded_tool = load_tool(toolpath_object["run"], loadingContext)
        except ValidationException as vexc:
            if loadingContext.debug:
                _logger.exception("Validation exception")
            raise WorkflowException(
                "Tool definition %s failed validation:\n%s"
                % (toolpath_object["run"], indent(str(vexc)))
            ) from vexc

        validation_errors = []
        self.tool = toolpath_object = copy.deepcopy(toolpath_object)
        bound = set()
        for stepfield, toolfield in (("in", "inputs"), ("out", "outputs")):
            toolpath_object[toolfield] = []
            for index, step_entry in enumerate(toolpath_object[stepfield]):
                if isinstance(step_entry, str):
                    param = CommentedMap()  # type: CommentedMap
                    inputid = step_entry
                else:
                    param = CommentedMap(step_entry.items())
                    inputid = step_entry["id"]

                shortinputid = shortname(inputid)
                found = False
                for tool_entry in self.embedded_tool.tool[toolfield]:
                    frag = shortname(tool_entry["id"])
                    if frag == shortinputid:
                        # in the case that the step has a default for a parameter,
                        # we do not want the default of the tool to override it
                        step_default = None
                        if "default" in param and "default" in tool_entry:
                            step_default = param["default"]
                        param.update(tool_entry)
                        param["_tool_entry"] = tool_entry
                        if step_default is not None:
                            param["default"] = step_default
                        found = True
                        bound.add(frag)
                        break
                if not found:
                    if stepfield == "in":
                        param["type"] = "Any"
                        param["used_by_step"] = used_by_step(self.tool, shortinputid)
                        param["not_connected"] = True
                    else:
                        if isinstance(step_entry, Mapping):
                            step_entry_name = step_entry["id"]
                        else:
                            step_entry_name = step_entry
                        validation_errors.append(
                            SourceLine(self.tool["out"], index).makeError(
                                "Workflow step output '%s' does not correspond to"
                                % shortname(step_entry_name)
                            )
                            + "\n"
                            + SourceLine(self.embedded_tool.tool, "outputs").makeError(
                                "  tool output (expected '%s')"
                                % (
                                    "', '".join(
                                        [
                                            shortname(tool_entry["id"])
                                            for tool_entry in self.embedded_tool.tool[
                                                "outputs"
                                            ]
                                        ]
                                    )
                                )
                            )
                        )
                param["id"] = inputid
                param.lc.line = toolpath_object[stepfield].lc.data[index][0]
                param.lc.col = toolpath_object[stepfield].lc.data[index][1]
                param.lc.filename = toolpath_object[stepfield].lc.filename
                toolpath_object[toolfield].append(param)

        missing_values = []
        for _, tool_entry in enumerate(self.embedded_tool.tool["inputs"]):
            if shortname(tool_entry["id"]) not in bound:
                if "null" not in tool_entry["type"] and "default" not in tool_entry:
                    missing_values.append(shortname(tool_entry["id"]))

        if missing_values:
            validation_errors.append(
                SourceLine(self.tool, "in").makeError(
                    "Step is missing required parameter%s '%s'"
                    % (
                        "s" if len(missing_values) > 1 else "",
                        "', '".join(missing_values),
                    )
                )
            )

        if validation_errors:
            raise ValidationException("\n".join(validation_errors))

        super().__init__(toolpath_object, loadingContext)

        if self.embedded_tool.tool["class"] == "Workflow":
            (feature, _) = self.get_requirement("SubworkflowFeatureRequirement")
            if not feature:
                raise WorkflowException(
                    "Workflow contains embedded workflow but "
                    "SubworkflowFeatureRequirement not in requirements"
                )

        if "scatter" in self.tool:
            (feature, _) = self.get_requirement("ScatterFeatureRequirement")
            if not feature:
                raise WorkflowException(
                    "Workflow contains scatter but ScatterFeatureRequirement "
                    "not in requirements"
                )

            inputparms = copy.deepcopy(self.tool["inputs"])
            outputparms = copy.deepcopy(self.tool["outputs"])
            scatter = aslist(self.tool["scatter"])

            method = self.tool.get("scatterMethod")
            if method is None and len(scatter) != 1:
                raise ValidationException(
                    "Must specify scatterMethod when scattering over multiple inputs"
                )

            inp_map = {i["id"]: i for i in inputparms}
            for inp in scatter:
                if inp not in inp_map:
                    raise ValidationException(
                        SourceLine(self.tool, "scatter").makeError(
                            "Scatter parameter '%s' does not correspond to "
                            "an input parameter of this step, expecting '%s'"
                            % (
                                shortname(inp),
                                "', '".join(shortname(k) for k in inp_map.keys()),
                            )
                        )
                    )

                inp_map[inp]["type"] = {"type": "array", "items": inp_map[inp]["type"]}

            if self.tool.get("scatterMethod") == "nested_crossproduct":
                nesting = len(scatter)
            else:
                nesting = 1

            for _ in range(0, nesting):
                for oparam in outputparms:
                    oparam["type"] = {"type": "array", "items": oparam["type"]}
            self.tool["inputs"] = inputparms
            self.tool["outputs"] = outputparms
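            # For example (hypothetical step): scattering over an input of type
            # "File" rewrites that input's type to
            #     {"type": "array", "items": "File"}
            # and each output type is wrapped in an array once per nesting level,
            # so a "string" output under nested_crossproduct over two inputs becomes
            #     {"type": "array", "items": {"type": "array", "items": "string"}}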
        self.prov_obj = None  # type: Optional[ProvenanceProfile]
        if loadingContext.research_obj is not None:
            self.prov_obj = parentworkflowProv
            if self.embedded_tool.tool["class"] == "Workflow":
                self.parent_wf = self.embedded_tool.parent_wf
            else:
                self.parent_wf = self.prov_obj

    def receive_output(
        self,
        output_callback: OutputCallbackType,
        jobout: CWLObjectType,
        processStatus: str,
    ) -> None:
        """Collect this step's outputs, marking any missing field as a permanent failure."""
        output = {}
        for i in self.tool["outputs"]:
            field = shortname(i["id"])
            if field in jobout:
                output[i["id"]] = jobout[field]
            else:
                processStatus = "permanentFail"
        output_callback(output, processStatus)

    def job(
        self,
        job_order: CWLObjectType,
        output_callbacks: Optional[OutputCallbackType],
        runtimeContext: RuntimeContext,
    ) -> JobsGeneratorType:
        """Initialize sub-workflow as a step in the parent profile."""
        if (
            self.embedded_tool.tool["class"] == "Workflow"
            and runtimeContext.research_obj
            and self.prov_obj
            and self.embedded_tool.provenance_object
        ):
            self.embedded_tool.parent_wf = self.prov_obj
            process_name = self.tool["id"].split("#")[1]
            self.prov_obj.start_process(
                process_name,
                datetime.datetime.now(),
                self.embedded_tool.provenance_object.workflow_run_uri,
            )

        step_input = {}
        for inp in self.tool["inputs"]:
            field = shortname(inp["id"])
            if not inp.get("not_connected"):
                step_input[field] = job_order[inp["id"]]

        try:
            yield from self.embedded_tool.job(
                step_input,
                functools.partial(self.receive_output, output_callbacks),
                runtimeContext,
            )
        except WorkflowException:
            _logger.error("Exception on step '%s'", runtimeContext.name)
            raise
        except Exception as exc:
            _logger.exception("Unexpected exception")
            raise WorkflowException(str(exc)) from exc

    def visit(self, op: Callable[[CommentedMap], None]) -> None:
        self.embedded_tool.visit(op)
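

# A minimal usage sketch (hedged; assumes a local "wf.cwl" exists and that the
# surrounding cwltool APIs are used as in cwltool.main; names below are
# illustrative, not part of this module):
#
#     from cwltool.context import LoadingContext
#     from cwltool.load_tool import load_tool
#
#     loading_ctx = LoadingContext({"construct_tool_object": default_make_tool})
#     tool = load_tool("wf.cwl", loading_ctx)  # a Workflow for class: Workflow
#     # each element of tool.steps is a WorkflowStep wrapping its embedded tool
#     for step in tool.steps:
#         print(step.id, step.embedded_tool.tool["class"])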