Sbachmei/mic 5753/embarrassingly parallel basic step (#153)
stevebachmeier authored Feb 26, 2025
1 parent bebce5b commit 664a4a5
Showing 30 changed files with 1,515 additions and 272 deletions.
1 change: 1 addition & 0 deletions docs/source/api_reference/cli.rst
@@ -0,0 +1 @@
.. automodule:: easylink.cli
2 changes: 2 additions & 0 deletions docs/source/api_reference/utilities/aggregator_utils.rst
@@ -0,0 +1,2 @@
.. automodule:: easylink.utilities.aggregator_utils

2 changes: 2 additions & 0 deletions docs/source/api_reference/utilities/splitter_utils.rst
@@ -0,0 +1,2 @@
.. automodule:: easylink.utilities.splitter_utils

26 changes: 17 additions & 9 deletions src/easylink/cli.py
@@ -87,6 +87,16 @@
default=False,
help="Do not save the results in a timestamped sub-directory of ``--output-dir``.",
),
click.option(
"-v", "--verbose", count=True, help="Increase logging verbosity.", hidden=True
),
click.option(
"--pdb",
"with_debugger",
is_flag=True,
help="Drop into python debugger if an error occurs.",
hidden=True,
),
]


@@ -129,14 +139,6 @@ def easylink():
"the pipeline will be run locally."
),
)
@click.option("-v", "--verbose", count=True, help="Increase logging verbosity.", hidden=True)
@click.option(
"--pdb",
"with_debugger",
is_flag=True,
help="Drop into python debugger if an error occurs.",
hidden=True,
)
def run(
pipeline_specification: str,
input_data: str,
@@ -178,17 +180,23 @@ def generate_dag(
input_data: str,
output_dir: str | None,
no_timestamp: bool,
verbose: int,
with_debugger: bool,
) -> None:
"""Generates an image of the proposed pipeline directed acyclic graph (DAG).
This command only generates the DAG image of the pipeline; it does not actually
run it. To run the pipeline, use the ``easylink run`` command.
"""
configure_logging_to_terminal(verbose)
logger.info("Generating DAG")
results_dir = get_results_directory(output_dir, no_timestamp).as_posix()
logger.info(f"Results directory: {results_dir}")
# TODO [MIC-4493]: Add configuration validation
runner.main(
main = handle_exceptions(
func=runner.main, exceptions_logger=logger, with_debugger=with_debugger
)
main(
command="generate_dag",
pipeline_specification=pipeline_specification,
input_data=input_data,
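The cli.py change above hoists the --verbose and --pdb options out of the per-command decorators into the shared option list and routes runner.main through handle_exceptions. For orientation, here is a minimal, self-contained sketch of that shared-options pattern with click; the _pass_common_options helper and the run command body are hypothetical stand-ins, not EasyLink's actual code.

import click

# Shared options applied to every subcommand (mirrors the style of the diff above).
_common_options = [
    click.option(
        "-v", "--verbose", count=True, help="Increase logging verbosity.", hidden=True
    ),
    click.option(
        "--pdb",
        "with_debugger",
        is_flag=True,
        help="Drop into python debugger if an error occurs.",
        hidden=True,
    ),
]


def _pass_common_options(func):
    # Apply each shared click.option decorator; reversed() keeps --help ordering
    # consistent with the list above.
    for option in reversed(_common_options):
        func = option(func)
    return func


@click.command()
@_pass_common_options
def run(verbose: int, with_debugger: bool) -> None:
    click.echo(f"verbosity={verbose}, debugger={with_debugger}")


if __name__ == "__main__":
    run()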
14 changes: 12 additions & 2 deletions src/easylink/graph_components.py
@@ -13,7 +13,7 @@
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import networkx as nx

@@ -45,8 +45,13 @@ class InputSlot:
"""A function that validates the input data being passed into the pipeline via
this ``InputSlot``. If the data is invalid, the function should raise an exception
with a descriptive error message which will then be reported to the user.
**Note that the function must be defined in the** :mod:`easylink.utilities.validation_utils`
**Note that the function *must* be defined in the** :mod:`easylink.utilities.validation_utils`
**module!**"""
splitter: Callable[[list[str], str, Any], None] | None = None
"""A function that splits the incoming data to this ``InputSlot`` into smaller
pieces. The primary purpose of this functionality is to run sections of the
pipeline in an embarrassingly parallel manner. **Note that the function *must*
be defined in the** :mod:`easylink.utilities.splitter_utils` **module!**"""


@dataclass(frozen=True)
@@ -70,6 +75,11 @@ class OutputSlot:

name: str
"""The name of the ``OutputSlot``."""
aggregator: Callable[[list[str], str], None] = None
"""A function that aggregates all of the generated data to be passed out via this
``OutputSlot``. The primary purpose of this functionality is to run sections
of the pipeline in an embarrassingly parallel manner. **Note that the function
*must* be defined in the** :py:mod:`easylink.utilities.aggregator_utils` **module!**"""


@dataclass(frozen=True)
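The new splitter and aggregator slot attributes expect functions with the signatures Callable[[list[str], str, Any], None] and Callable[[list[str], str], None], respectively. Below is a hedged sketch of what such a pair could look like for simple CSV-style inputs; the function names and the line-count chunking strategy are illustrative assumptions, not the implementations that live in easylink.utilities.splitter_utils and easylink.utilities.aggregator_utils.

from pathlib import Path
from typing import Any


def split_by_line_count(input_files: list[str], output_dir: str, chunk_size: Any) -> None:
    # Hypothetical splitter: break each CSV-like input into chunks of
    # `chunk_size` data rows, one sub-directory per chunk under `output_dir`.
    chunk_size = int(chunk_size)
    chunk_idx = 0
    for filepath in input_files:
        header, *rows = Path(filepath).read_text().splitlines()
        for start in range(0, len(rows), chunk_size):
            chunk_dir = Path(output_dir) / f"chunk_{chunk_idx}"
            chunk_dir.mkdir(parents=True, exist_ok=True)
            chunk_lines = [header, *rows[start : start + chunk_size]]
            (chunk_dir / Path(filepath).name).write_text("\n".join(chunk_lines) + "\n")
            chunk_idx += 1


def concatenate_chunks(input_files: list[str], output_filepath: str) -> None:
    # Hypothetical aggregator: stitch the per-chunk results back together,
    # keeping the header from the first chunk only.
    lines: list[str] = []
    for i, filepath in enumerate(sorted(input_files)):
        header, *rows = Path(filepath).read_text().splitlines()
        if i == 0:
            lines.append(header)
        lines.extend(rows)
    Path(output_filepath).write_text("\n".join(lines) + "\n")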
2 changes: 2 additions & 0 deletions src/easylink/implementation.py
@@ -45,6 +45,7 @@ def __init__(
implementation_config: LayeredConfigTree,
input_slots: Iterable["InputSlot"] = (),
output_slots: Iterable["OutputSlot"] = (),
is_embarrassingly_parallel: bool = False,
):
self.name = implementation_config.name
"""The name of this ``Implementation``."""
@@ -63,6 +64,7 @@
implemented by this particular ``Implementation``."""
self.requires_spark = self._metadata.get("requires_spark", False)
"""Whether this ``Implementation`` requires a Spark environment."""
self.is_embarrassingly_parallel = is_embarrassingly_parallel

def __repr__(self) -> str:
return f"Implementation.{self.name}"
126 changes: 92 additions & 34 deletions src/easylink/pipeline.py
@@ -16,7 +16,13 @@

from easylink.configuration import Config
from easylink.pipeline_graph import PipelineGraph
from easylink.rule import ImplementedRule, InputValidationRule, TargetRule
from easylink.rule import (
AggregationRule,
CheckpointRule,
ImplementedRule,
InputValidationRule,
TargetRule,
)
from easylink.utilities.general_utils import exit_with_validation_error
from easylink.utilities.paths import SPARK_SNAKEFILE
from easylink.utilities.validation_utils import validate_input_file_dummy
@@ -40,13 +46,17 @@ class Pipeline:
The :class:`~easylink.pipeline_graph.PipelineGraph` object.
spark_is_required
A boolean indicating whether the pipeline requires Spark.
any_embarrassingly_parallel
A boolean indicating whether any implementation in the pipeline is to be
run in an embarrassingly parallel manner.
"""

def __init__(self, config: Config):
self.config = config
self.pipeline_graph = PipelineGraph(config)
self.spark_is_required = self.pipeline_graph.spark_is_required()
self.spark_is_required = self.pipeline_graph.spark_is_required
self.any_embarrassingly_parallel = self.pipeline_graph.any_embarrassingly_parallel

# TODO [MIC-4880]: refactor into validation object
self._validate()
@@ -79,10 +89,10 @@ def build_snakefile(self) -> Path:
logger.warning("Snakefile already exists, overwriting.")
self.snakefile_path.unlink()
self._write_imports()
self._write_config()
self._write_wildcard_constraints()
self._write_spark_config()
self._write_target_rules()
if self.spark_is_required:
self._write_spark_module()
self._write_spark_module()
for node in self.pipeline_graph.implementation_nodes:
self._write_implementation_rules(node)
return self.snakefile_path
@@ -121,26 +131,35 @@ def _validate_implementations(self) -> dict:
return errors

def _write_imports(self) -> None:
"""Writes the necessary imports to the Snakefile."""
with open(self.snakefile_path, "a") as f:
f.write("from easylink.utilities import validation_utils")
if not self.any_embarrassingly_parallel:
imports = "from easylink.utilities import validation_utils\n"
else:
imports = """import glob
import os
def _write_config(self) -> None:
"""Writes configuration settings to the Snakefile.
from snakemake.exceptions import IncompleteCheckpointException
from snakemake.io import checkpoint_target
Notes
-----
This is currently only applicable for spark-dependent pipelines.
"""
from easylink.utilities import aggregator_utils, splitter_utils, validation_utils\n"""
with open(self.snakefile_path, "a") as f:
if self.spark_is_required:
f.write(imports)

def _write_wildcard_constraints(self) -> None:
if self.any_embarrassingly_parallel:
with open(self.snakefile_path, "a") as f:
f.write(
f"\nscattergather:\n\tnum_workers={self.config.spark_resources['num_workers']},"
"""
wildcard_constraints:
# never include '/' since those are reserved for filepaths
chunk="[^/]+",\n"""
)

def _write_target_rules(self) -> None:
"""Writes the rule for the final output and its validation."""
## The "input" files to the result node/the target rule are the final output themselves.
"""Writes the rule for the final output and its validation.
The input files to the target rule (i.e. the result node) are the final
output themselves.
"""
final_output, _ = self.pipeline_graph.get_io_filepaths("results")
validator_file = str("input_validations/final_validator")
# Snakemake resolves the DAG based on the first rule, so we put the target
@@ -152,29 +171,43 @@ def _write_target_rules(self) -> None:
)
final_validation = InputValidationRule(
name="results",
slot_name="main_input",
input_slot_name="main_input",
input=final_output,
output=validator_file,
validator=validate_input_file_dummy,
)
target_rule.write_to_snakefile(self.snakefile_path)
final_validation.write_to_snakefile(self.snakefile_path)

def _write_spark_config(self) -> None:
"""Writes configuration settings to the Snakefile.
Notes
-----
This is currently only applicable for spark-dependent pipelines.
"""
if self.spark_is_required:
with open(self.snakefile_path, "a") as f:
f.write(
f"\nscattergather:\n\tnum_workers={self.config.spark_resources['num_workers']},"
)

def _write_spark_module(self) -> None:
"""Inserts the ``easylink.utilities.spark.smk`` Snakemake module into the Snakefile."""
if not self.spark_is_required:
return
slurm_resources = self.config.slurm_resources
spark_resources = self.config.spark_resources
with open(self.snakefile_path, "a") as f:
module = f"""
module = f"""
module spark_cluster:
snakefile: '{SPARK_SNAKEFILE}'
config: config
use rule * from spark_cluster
use rule terminate_spark from spark_cluster with:
input: rules.all.input.final_output"""
if self.config.computing_environment == "slurm":
module += f"""
if self.config.computing_environment == "slurm":
module += f"""
use rule start_spark_master from spark_cluster with:
resources:
slurm_account={slurm_resources['slurm_account']},
@@ -195,30 +228,57 @@ def _write_spark_module(self) -> None:
terminate_file_name=rules.terminate_spark.output,
user=os.environ["USER"],
cores={spark_resources['cpus_per_task']},
memory={spark_resources['mem_mb']}
"""
memory={spark_resources['mem_mb']}"""

with open(self.snakefile_path, "a") as f:
f.write(module)

def _write_implementation_rules(self, node_name: str) -> None:
"""Writes the rules for each :class:`~easylink.implementation.Implementation`.
This method writes *all* rules required for a given ``Implementation``,
e.g. splitters and aggregators (if necessary), validations, and the actual
rule to run the container itself.
Parameters
----------
node_name
The name of the ``Implementation`` to write the rule(s) for.
"""
implementation = self.pipeline_graph.nodes[node_name]["implementation"]

input_slots, output_slots = self.pipeline_graph.get_io_slot_attributes(node_name)
validation_files, validation_rules = self._get_validations(node_name, input_slots)
for validation_rule in validation_rules:
validation_rule.write_to_snakefile(self.snakefile_path)

_input_files, output_files = self.pipeline_graph.get_io_filepaths(node_name)
input_slots = self.pipeline_graph.get_input_slot_attributes(node_name)
is_embarrassingly_parallel = self.pipeline_graph.get_whether_embarrassingly_parallel(
node_name
)
if is_embarrassingly_parallel:
CheckpointRule(
name=node_name,
input_slots=input_slots,
validations=validation_files,
output=output_files,
).write_to_snakefile(self.snakefile_path)
for name, attrs in output_slots.items():
AggregationRule(
name=node_name,
input_slots=input_slots,
output_slot_name=name,
output_slot=attrs,
).write_to_snakefile(self.snakefile_path)

implementation = self.pipeline_graph.nodes[node_name]["implementation"]
diagnostics_dir = Path("diagnostics") / node_name
diagnostics_dir.mkdir(parents=True, exist_ok=True)
resources = (
self.config.slurm_resources
if self.config.computing_environment == "slurm"
else None
)
validation_files, validation_rules = self._get_validations(node_name, input_slots)
implementation_rule = ImplementedRule(
ImplementedRule(
name=node_name,
step_name=" and ".join(implementation.metadata_steps),
implementation_name=implementation.name,
Expand All @@ -231,10 +291,8 @@ def _write_implementation_rules(self, node_name: str) -> None:
image_path=implementation.singularity_image_path,
script_cmd=implementation.script_cmd,
requires_spark=implementation.requires_spark,
)
for validation_rule in validation_rules:
validation_rule.write_to_snakefile(self.snakefile_path)
implementation_rule.write_to_snakefile(self.snakefile_path)
is_embarrassingly_parallel=is_embarrassingly_parallel,
).write_to_snakefile(self.snakefile_path)

@staticmethod
def _get_validations(
@@ -262,7 +320,7 @@ def _get_validations(
validation_rules.append(
InputValidationRule(
name=node_name,
slot_name=input_slot_name,
input_slot_name=input_slot_name,
input=input_slot_attrs["filepaths"],
output=validation_file,
validator=input_slot_attrs["validator"],
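Taken together, the CheckpointRule and AggregationRule additions mean an embarrassingly parallel node is rendered as a Snakemake checkpoint that splits its input, per-chunk processing keyed by the new chunk wildcard, and a gather rule that aggregates the chunk outputs. As rough orientation only, a sketch of that kind of Snakemake text is embedded below as a Python string, matching how pipeline.py itself emits Snakemake; every rule name, path, and helper function in it is a hypothetical placeholder, not the text EasyLink actually generates.

# Illustrative only: the flavor of Snakemake scaffolding an embarrassingly
# parallel step adds -- a checkpoint that splits the input into chunks, a
# gather function keyed on the `chunk` wildcard, and an aggregation rule.
GENERATED_SNAKEMAKE_SKETCH = """
checkpoint split_step_1_main_input:
    input: "input_data/records.csv"
    output: directory("intermediate/step_1/input_chunks")
    run:
        splitter_utils.split_by_line_count(list(input), output[0], 10_000)

def gather_step_1_results(wildcards):
    chunk_dir = checkpoints.split_step_1_main_input.get(**wildcards).output[0]
    chunks = glob_wildcards(os.path.join(chunk_dir, "{chunk}", "records.csv")).chunk
    return expand("intermediate/step_1/processed/{chunk}/result.csv", chunk=chunks)

rule aggregate_step_1_main_output:
    input: gather_step_1_results
    output: "intermediate/step_1/result.csv"
    run:
        aggregator_utils.concatenate_chunks(list(input), output[0])
"""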