Mercurial > repos > shellac > sam_consensus_v3
view env/lib/python3.9/site-packages/cwltool/workflow.py @ 0:4f3585e2f14b draft default tip
"planemo upload commit 60cee0fc7c0cda8592644e1aad72851dec82c959"
author | shellac |
---|---|
date | Mon, 22 Mar 2021 18:12:50 +0000 |
parents | |
children |
line wrap: on
line source
import copy
import datetime
import functools
import logging
import random
from typing import (
    Callable,
    Dict,
    List,
    Mapping,
    MutableMapping,
    MutableSequence,
    Optional,
    cast,
)
from uuid import UUID

from ruamel.yaml.comments import CommentedMap
from schema_salad.exceptions import ValidationException
from schema_salad.sourceline import SourceLine, indent

from . import command_line_tool, context, procgenerator
from .checker import static_checker
from .context import LoadingContext, RuntimeContext, getdefault
from .errors import WorkflowException
from .load_tool import load_tool
from .loghandler import _logger
from .process import Process, get_overrides, shortname
from .provenance_profile import ProvenanceProfile
from .utils import (
    CWLObjectType,
    JobsGeneratorType,
    OutputCallbackType,
    StepType,
    aslist,
)
from .workflow_job import WorkflowJob


def default_make_tool(
    toolpath_object: CommentedMap, loadingContext: LoadingContext
) -> Process:
    """
    Construct the Process subclass selected by the document's ``class`` field.

    :param toolpath_object: the tool document; must be a mutable mapping.
    :param loadingContext: passed unchanged to the selected constructor.
    :returns: a CommandLineTool, ExpressionTool, Workflow, ProcessGenerator
        or AbstractOperation instance.
    :raises WorkflowException: if ``toolpath_object`` is not a mapping, or if
        its ``class`` field is missing or not one of the handled values.
    """
    if not isinstance(toolpath_object, MutableMapping):
        raise WorkflowException("Not a dict: '%s'" % toolpath_object)
    if "class" in toolpath_object:
        if toolpath_object["class"] == "CommandLineTool":
            return command_line_tool.CommandLineTool(toolpath_object, loadingContext)
        if toolpath_object["class"] == "ExpressionTool":
            return command_line_tool.ExpressionTool(toolpath_object, loadingContext)
        if toolpath_object["class"] == "Workflow":
            return Workflow(toolpath_object, loadingContext)
        if toolpath_object["class"] == "ProcessGenerator":
            return procgenerator.ProcessGenerator(toolpath_object, loadingContext)
        if toolpath_object["class"] == "Operation":
            return command_line_tool.AbstractOperation(toolpath_object, loadingContext)

    # Use .get() so that a document without an "id" still produces the
    # intended WorkflowException rather than an unrelated KeyError.
    raise WorkflowException(
        "Missing or invalid 'class' field in "
        "%s, expecting one of: CommandLineTool, ExpressionTool, Workflow"
        % toolpath_object.get("id", "<no id>")
    )


# Install this factory as the ``default_make_tool`` attribute of the
# ``context`` module.
context.default_make_tool = default_make_tool


class Workflow(Process):
    """A Process whose ``steps`` field is turned into WorkflowStep objects."""

    def __init__(
        self,
        toolpath_object: CommentedMap,
        loadingContext: LoadingContext,
    ) -> None:
        """
        Initialize this Workflow.

        Builds one step per entry of ``steps`` (collecting every
        ValidationException before raising a combined one), and then, unless
        ``loadingContext.do_validate`` is false, runs ``static_checker`` over
        the workflow and step parameters.

        :raises ValidationException: if any step failed validation.
        """
        super().__init__(toolpath_object, loadingContext)
        self.provenance_object = None  # type: Optional[ProvenanceProfile]
        if loadingContext.research_obj is not None:
            run_uuid = None  # type: Optional[UUID]
            is_main = not loadingContext.prov_obj  # Not yet set
            if is_main:
                run_uuid = loadingContext.research_obj.ro_uuid

            self.provenance_object = ProvenanceProfile(
                loadingContext.research_obj,
                full_name=loadingContext.cwl_full_name,
                host_provenance=loadingContext.host_provenance,
                user_provenance=loadingContext.user_provenance,
                orcid=loadingContext.orcid,
                run_uuid=run_uuid,
                fsaccess=loadingContext.research_obj.fsaccess,
            )  # inherit RO UUID for main wf run
            # TODO: Is Workflow(..) only called when we are the main workflow?
            self.parent_wf = self.provenance_object

        # FIXME: Won't this overwrite prov_obj for nested workflows?
        # Note: this assignment is made on the caller's context object, before
        # the private copy below is taken.
        loadingContext.prov_obj = self.provenance_object
        loadingContext = loadingContext.copy()
        loadingContext.requirements = self.requirements
        loadingContext.hints = self.hints

        self.steps = []  # type: List[WorkflowStep]
        validation_errors = []
        for index, step in enumerate(self.tool.get("steps", [])):
            try:
                self.steps.append(
                    self.make_workflow_step(
                        step, index, loadingContext, loadingContext.prov_obj
                    )
                )
            except ValidationException as vexc:
                if _logger.isEnabledFor(logging.DEBUG):
                    _logger.exception("Validation failed at")
                # Keep going so that all failing steps are reported together.
                validation_errors.append(vexc)

        if validation_errors:
            raise ValidationException("\n".join(str(v) for v in validation_errors))

        # NOTE(review): the step list is shuffled on purpose here; presumably
        # so that nothing downstream relies on document order -- confirm.
        random.shuffle(self.steps)

        # statically validate data links instead of doing it at runtime.
        workflow_inputs = self.tool["inputs"]
        workflow_outputs = self.tool["outputs"]

        step_inputs = []  # type: List[CWLObjectType]
        step_outputs = []  # type: List[CWLObjectType]
        # Maps every step input/output id to the step document declaring it.
        param_to_step = {}  # type: Dict[str, CWLObjectType]
        for step in self.steps:
            step_inputs.extend(step.tool["inputs"])
            step_outputs.extend(step.tool["outputs"])
            for s in step.tool["inputs"]:
                param_to_step[s["id"]] = step.tool
            for s in step.tool["outputs"]:
                param_to_step[s["id"]] = step.tool

        if getdefault(loadingContext.do_validate, True):
            static_checker(
                workflow_inputs,
                workflow_outputs,
                step_inputs,
                step_outputs,
                param_to_step,
            )

    def make_workflow_step(
        self,
        toolpath_object: CommentedMap,
        pos: int,
        loadingContext: LoadingContext,
        parentworkflowProv: Optional[ProvenanceProfile] = None,
    ) -> "WorkflowStep":
        """Build a WorkflowStep from the given arguments, passed through unchanged."""
        return WorkflowStep(toolpath_object, pos, loadingContext, parentworkflowProv)

    def job(
        self,
        job_order: CWLObjectType,
        output_callbacks: Optional[OutputCallbackType],
        runtimeContext: RuntimeContext,
    ) -> JobsGeneratorType:
        """
        Yield a WorkflowJob for this workflow, followed by the jobs it generates.

        The nested jobs receive a copy of ``runtimeContext`` with ``part_of``
        set to the workflow job's name and ``toplevel`` set to False.
        """
        builder = self._init_job(job_order, runtimeContext)

        if runtimeContext.research_obj is not None:
            if runtimeContext.toplevel:
                # Record primary-job.json
                runtimeContext.research_obj.fsaccess = runtimeContext.make_fs_access("")
                runtimeContext.research_obj.create_job(builder.job)

        job = WorkflowJob(self, runtimeContext)
        yield job

        runtimeContext = runtimeContext.copy()
        runtimeContext.part_of = "workflow %s" % job.name
        runtimeContext.toplevel = False

        yield from job.job(builder.job, output_callbacks, runtimeContext)

    def visit(self, op: Callable[[CommentedMap], None]) -> None:
        """Apply ``op`` to this workflow's document, then visit every step."""
        op(self.tool)
        for step in self.steps:
            step.visit(op)


def used_by_step(step: StepType, shortinputid: str) -> bool:
    """
    Tell whether the text ``inputs.<shortinputid>`` occurs in the step.

    The check is a plain substring test against the ``valueFrom`` of each
    entry of ``step["in"]`` and against the step's ``when`` field.

    :param step: the step document; must have an ``in`` field.
    :param shortinputid: the short name of the input to look for.
    :returns: True if the reference text was found, otherwise False.
    """
    reference = "inputs.%s" % shortinputid
    for st in cast(MutableSequence[CWLObjectType], step["in"]):
        # WorkflowStep.__init__ accepts "in" entries that are bare strings; a
        # string cannot carry a valueFrom, so skip it rather than calling
        # .get() on it.
        if not isinstance(st, Mapping):
            continue
        value_from = st.get("valueFrom")
        if value_from and reference in cast(str, value_from):
            return True
    when = step.get("when")
    if when and reference in cast(str, when):
        return True
    return False


class WorkflowStep(Process):
    """A workflow step: a Process wrapping the tool named by its ``run`` field."""

    def __init__(
        self,
        toolpath_object: CommentedMap,
        pos: int,
        loadingContext: LoadingContext,
        parentworkflowProv: Optional[ProvenanceProfile] = None,
    ) -> None:
        """
        Initialize this WorkflowStep.

        :param toolpath_object: the step document; it is deep-copied before
            the ``inputs`` and ``outputs`` fields are filled in.
        :param pos: index of the step, used to build the id ``#step<pos>``
            when the document has no ``id``.
        :param loadingContext: copied, then extended with the step's own
            requirements, overrides and hints.
        :param parentworkflowProv: stored as ``prov_obj`` when the loading
            context has a research object.
        :raises WorkflowException: if the ``run`` tool fails validation, or a
            sub-workflow or scatter is used without the matching feature
            requirement.
        :raises ValidationException: for a step output that matches no tool
            output, a missing required input, or an invalid scatter setup.
        """
        if "id" in toolpath_object:
            self.id = toolpath_object["id"]
        else:
            self.id = "#step" + str(pos)

        loadingContext = loadingContext.copy()

        # Deep copies keep the step-specific additions below out of the
        # lists held by the context this one was copied from.
        loadingContext.requirements = copy.deepcopy(
            getdefault(loadingContext.requirements, [])
        )
        assert loadingContext.requirements is not None  # nosec
        loadingContext.requirements.extend(toolpath_object.get("requirements", []))
        loadingContext.requirements.extend(
            cast(
                List[CWLObjectType],
                get_overrides(
                    getdefault(loadingContext.overrides_list, []), self.id
                ).get("requirements", []),
            )
        )

        hints = copy.deepcopy(getdefault(loadingContext.hints, []))
        hints.extend(toolpath_object.get("hints", []))
        loadingContext.hints = hints

        try:
            if isinstance(toolpath_object["run"], CommentedMap):
                # "run" is an inline document: construct the tool directly.
                self.embedded_tool = loadingContext.construct_tool_object(
                    toolpath_object["run"], loadingContext
                )  # type: Process
            else:
                # Otherwise hand "run" to load_tool with fresh metadata.
                loadingContext.metadata = {}
                self.embedded_tool = load_tool(toolpath_object["run"], loadingContext)
        except ValidationException as vexc:
            if loadingContext.debug:
                _logger.exception("Validation exception")
            raise WorkflowException(
                "Tool definition %s failed validation:\n%s"
                % (toolpath_object["run"], indent(str(vexc)))
            ) from vexc

        validation_errors = []
        self.tool = toolpath_object = copy.deepcopy(toolpath_object)
        # Short names of embedded-tool parameters matched by a step entry.
        bound = set()
        for stepfield, toolfield in (("in", "inputs"), ("out", "outputs")):
            toolpath_object[toolfield] = []
            for index, step_entry in enumerate(toolpath_object[stepfield]):
                if isinstance(step_entry, str):
                    param = CommentedMap()  # type: CommentedMap
                    inputid = step_entry
                else:
                    param = CommentedMap(step_entry.items())
                    inputid = step_entry["id"]

                shortinputid = shortname(inputid)
                found = False
                for tool_entry in self.embedded_tool.tool[toolfield]:
                    frag = shortname(tool_entry["id"])
                    if frag == shortinputid:
                        # In the case that the step has a default for a
                        # parameter, we do not want the default of the tool
                        # to override it.
                        step_default = None
                        if "default" in param and "default" in tool_entry:
                            step_default = param["default"]
                        param.update(tool_entry)
                        param["_tool_entry"] = tool_entry
                        if step_default is not None:
                            param["default"] = step_default
                        found = True
                        bound.add(frag)
                        break
                if not found:
                    if stepfield == "in":
                        # A step input with no counterpart on the tool is
                        # kept, typed "Any" and flagged as not connected.
                        param["type"] = "Any"
                        param["used_by_step"] = used_by_step(self.tool, shortinputid)
                        param["not_connected"] = True
                    else:
                        # A step output with no counterpart is an error.
                        if isinstance(step_entry, Mapping):
                            step_entry_name = step_entry["id"]
                        else:
                            step_entry_name = step_entry
                        validation_errors.append(
                            SourceLine(self.tool["out"], index).makeError(
                                "Workflow step output '%s' does not correspond to"
                                % shortname(step_entry_name)
                            )
                            + "\n"
                            + SourceLine(self.embedded_tool.tool, "outputs").makeError(
                                " tool output (expected '%s')"
                                % (
                                    "', '".join(
                                        [
                                            shortname(tool_entry["id"])
                                            for tool_entry in self.embedded_tool.tool[
                                                "outputs"
                                            ]
                                        ]
                                    )
                                )
                            )
                        )
                param["id"] = inputid
                # Carry the source position of the step entry over to param.
                param.lc.line = toolpath_object[stepfield].lc.data[index][0]
                param.lc.col = toolpath_object[stepfield].lc.data[index][1]
                param.lc.filename = toolpath_object[stepfield].lc.filename
                toolpath_object[toolfield].append(param)

        # A tool input is required when its type does not contain "null" and
        # it has no default; every such input must have been bound above.
        missing_values = []
        for tool_entry in self.embedded_tool.tool["inputs"]:
            if shortname(tool_entry["id"]) not in bound:
                if "null" not in tool_entry["type"] and "default" not in tool_entry:
                    missing_values.append(shortname(tool_entry["id"]))

        if missing_values:
            validation_errors.append(
                SourceLine(self.tool, "in").makeError(
                    "Step is missing required parameter%s '%s'"
                    % (
                        "s" if len(missing_values) > 1 else "",
                        "', '".join(missing_values),
                    )
                )
            )

        if validation_errors:
            raise ValidationException("\n".join(validation_errors))

        super().__init__(toolpath_object, loadingContext)

        if self.embedded_tool.tool["class"] == "Workflow":
            (feature, _) = self.get_requirement("SubworkflowFeatureRequirement")
            if not feature:
                raise WorkflowException(
                    "Workflow contains embedded workflow but "
                    "SubworkflowFeatureRequirement not in requirements"
                )

        if "scatter" in self.tool:
            (feature, _) = self.get_requirement("ScatterFeatureRequirement")
            if not feature:
                raise WorkflowException(
                    "Workflow contains scatter but ScatterFeatureRequirement "
                    "not in requirements"
                )

            inputparms = copy.deepcopy(self.tool["inputs"])
            outputparms = copy.deepcopy(self.tool["outputs"])
            scatter = aslist(self.tool["scatter"])

            method = self.tool.get("scatterMethod")
            if method is None and len(scatter) != 1:
                raise ValidationException(
                    "Must specify scatterMethod when scattering over multiple inputs"
                )

            inp_map = {i["id"]: i for i in inputparms}
            for inp in scatter:
                if inp not in inp_map:
                    raise ValidationException(
                        SourceLine(self.tool, "scatter").makeError(
                            "Scatter parameter '%s' does not correspond to "
                            "an input parameter of this step, expecting '%s'"
                            % (
                                shortname(inp),
                                "', '".join(shortname(k) for k in inp_map.keys()),
                            )
                        )
                    )

                # Each scattered input becomes an array of its original type.
                inp_map[inp]["type"] = {"type": "array", "items": inp_map[inp]["type"]}

            if self.tool.get("scatterMethod") == "nested_crossproduct":
                nesting = len(scatter)
            else:
                nesting = 1

            # Wrap every output type in one array level per nesting level.
            for _ in range(0, nesting):
                for oparam in outputparms:
                    oparam["type"] = {"type": "array", "items": oparam["type"]}
            self.tool["inputs"] = inputparms
            self.tool["outputs"] = outputparms
        self.prov_obj = None  # type: Optional[ProvenanceProfile]
        if loadingContext.research_obj is not None:
            self.prov_obj = parentworkflowProv
            if self.embedded_tool.tool["class"] == "Workflow":
                self.parent_wf = self.embedded_tool.parent_wf
            else:
                self.parent_wf = self.prov_obj

    def receive_output(
        self,
        output_callback: OutputCallbackType,
        jobout: CWLObjectType,
        processStatus: str,
    ) -> None:
        """
        Re-key the embedded tool's outputs by step output id and forward them.

        If any step output's short name is absent from ``jobout``, the status
        passed to ``output_callback`` becomes ``"permanentFail"``.
        """
        output = {}
        for i in self.tool["outputs"]:
            field = shortname(i["id"])
            if field in jobout:
                output[i["id"]] = jobout[field]
            else:
                processStatus = "permanentFail"
        output_callback(output, processStatus)

    def job(
        self,
        job_order: CWLObjectType,
        output_callbacks: Optional[OutputCallbackType],
        runtimeContext: RuntimeContext,
    ) -> JobsGeneratorType:
        """
        Initialize sub-workflow as a step in the parent profile.

        Yields the jobs of the embedded tool, giving it the connected step
        inputs keyed by short name and ``receive_output`` as its callback.

        :raises WorkflowException: re-raised from the embedded tool, or
            wrapping any other exception it raises.
        """
        if (
            self.embedded_tool.tool["class"] == "Workflow"
            and runtimeContext.research_obj
            and self.prov_obj
            and self.embedded_tool.provenance_object
        ):
            self.embedded_tool.parent_wf = self.prov_obj
            # The process name is the part of the step id after the first "#".
            process_name = self.tool["id"].split("#")[1]
            self.prov_obj.start_process(
                process_name,
                datetime.datetime.now(),
                self.embedded_tool.provenance_object.workflow_run_uri,
            )

        step_input = {}
        for inp in self.tool["inputs"]:
            field = shortname(inp["id"])
            # Inputs flagged "not_connected" in __init__ are not passed on.
            if not inp.get("not_connected"):
                step_input[field] = job_order[inp["id"]]

        try:
            yield from self.embedded_tool.job(
                step_input,
                functools.partial(self.receive_output, output_callbacks),
                runtimeContext,
            )
        except WorkflowException:
            _logger.error("Exception on step '%s'", runtimeContext.name)
            raise
        except Exception as exc:
            _logger.exception("Unexpected exception")
            raise WorkflowException(str(exc)) from exc

    def visit(self, op: Callable[[CommentedMap], None]) -> None:
        """Visit the embedded tool; ``op`` is not applied to the step document."""
        self.embedded_tool.visit(op)