def get_executor(name: str, file_path: Optional[str] = None) -> Executor:
    """
    Retrieve an executor instance by name from a Python file of executor definitions.

    The file must define a module-level dict called ``EXECUTOR_MAP`` that maps executor
    names to executor instances. The file is imported dynamically from ``file_path``
    and the value stored under ``name`` is returned.

    This lets you keep all executor definitions in a single file that lives outside
    your codebase (similar in spirit to ``~/.ssh/config``) and reuse them across projects.

    Note:
        The file is executed as regular Python code, so only point this function at
        files you trust. The module is not registered in ``sys.modules``; the file is
        executed again on every call.

    Example:
        executor = get_executor("local", file_path="path/to/executors.py")
        executor = get_executor("gpu")  # Uses os.path.join(NEMORUN_HOME, "executors.py")

    Args:
        name (str): The name of the executor to retrieve.
        file_path (Optional[str]): Path to the Python file (``.py``) containing the
            ``EXECUTOR_MAP`` dictionary. If None or empty, defaults to
            ``os.path.join(NEMORUN_HOME, "executors.py")``.

    Returns:
        Executor: The executor instance registered under ``name``.

    Raises:
        AttributeError: If the file does not define ``EXECUTOR_MAP``.
        AssertionError: If ``name`` is not a key of ``EXECUTOR_MAP``, or if no import
            spec/loader can be created for ``file_path`` (e.g. it is not a ``.py`` file).
        FileNotFoundError: If ``file_path`` does not exist (raised by the import
            machinery while executing the module).
    """
    if not file_path:
        file_path = os.path.join(NEMORUN_HOME, "executors.py")

    # Validation uses explicit raises rather than ``assert`` statements: asserts are
    # stripped under ``python -O``, which would turn these checks into unrelated
    # errors (AttributeError on None, bare KeyError). AssertionError is kept as the
    # exception type so existing callers that catch it continue to work.
    spec = importlib.util.spec_from_file_location("executors", file_path)
    if spec is None or spec.loader is None:
        raise AssertionError(
            f"Could not load executors from {file_path}; expected a path to a Python (.py) file."
        )

    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    try:
        executor_map = module.EXECUTOR_MAP
    except AttributeError as err:
        # Same exception type as a plain getattr, but naming the offending file.
        raise AttributeError(f"{file_path} does not define an EXECUTOR_MAP dictionary.") from err

    if name not in executor_map:
        raise AssertionError(f"Executor {name} not found.")
    return executor_map[name]