Skip to content

Commit

Permalink
Adding yaml + lazy execution
Browse files Browse the repository at this point in the history
Signed-off-by: Marc Romeyn <[email protected]>
  • Loading branch information
marcromeyn committed Oct 28, 2024
1 parent c06ac86 commit d68f9e8
Show file tree
Hide file tree
Showing 12 changed files with 1,256 additions and 281 deletions.
4 changes: 4 additions & 0 deletions src/nemo_run/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,17 @@
from nemo_run.core.tunnel.client import LocalTunnel, SSHTunnel
from nemo_run.devspace.base import DevSpace
from nemo_run.help import help
from nemo_run.lazy import LazyEntrypoint, lazy_imports
from nemo_run.run.api import run
from nemo_run.run.experiment import Experiment
from nemo_run.run.plugin import ExperimentPlugin as Plugin

__all__ = [
"autoconvert",
"cli",
"dryrun_fn",
"lazy_imports",
"LazyEntrypoint",
"Config",
"ConfigurableMixin",
"DevSpace",
Expand Down
171 changes: 138 additions & 33 deletions src/nemo_run/cli/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
from nemo_run.core.execution import LocalExecutor, SkypilotExecutor, SlurmExecutor
from nemo_run.core.execution.base import Executor
from nemo_run.core.frontend.console.styles import BOX_STYLE, TABLE_STYLES
from nemo_run.lazy import LazyEntrypoint

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
nemo_run.lazy
begins an import cycle.
from nemo_run.run.experiment import Experiment
from nemo_run.run.plugin import ExperimentPlugin as Plugin

Expand All @@ -78,6 +79,7 @@
INCLUDE_WORKSPACE_FILE = os.environ.get("INCLUDE_WORKSPACE_FILE", "true").lower() == "true"

logger = logging.getLogger(__name__)
MAIN_ENTRYPOINT = None


def entrypoint(
Expand Down Expand Up @@ -234,9 +236,26 @@ def my_cli_function():
if __name__ == "__main__":
main(my_cli_function, default_factory=my_custom_defaults)
"""
lazy_cli = os.environ.get("LAZY_CLI", "false").lower() == "true"

if not isinstance(fn, EntrypointProtocol):
# Wrap the function with the default entrypoint
fn = entrypoint(**kwargs)(fn)
if getattr(fn, "__is_lazy__", False):
if lazy_cli:
fn = fn._import()
else:
app = typer.Typer()
RunContext.cli_command(
app,
sys.argv[1] if len(sys.argv) > 1 else "default",
LazyEntrypoint(" ".join(sys.argv)),
type="task",
default_factory=default_factory,
default_executor=default_executor,
default_plugins=default_plugins,
)
return app(standalone_mode=False)

_original_default_factory = fn.cli_entrypoint.default_factory
if default_factory:
Expand All @@ -256,6 +275,11 @@ def my_cli_function():
raise ValueError("default_plugins must be a Config object")
fn.cli_entrypoint.default_plugins = default_plugins

if lazy_cli:
global MAIN_ENTRYPOINT
MAIN_ENTRYPOINT = fn.cli_entrypoint
return

fn.cli_entrypoint.main()

fn.cli_entrypoint.default_factory = _original_default_factory
Expand Down Expand Up @@ -491,24 +515,47 @@ def create_cli(
metadata.entry_points().select(group="nemo_run.cli")
for ep in entrypoints:
_get_or_add_typer(app, name=ep.name)
is_lazy = "--lazy" in sys.argv

if not nested_entrypoints_creation or (len(sys.argv) > 1 and sys.argv[1] in entrypoints.names):
_add_typer_nested(app, list_entrypoints())
app: Typer = Typer(add_completion=not is_lazy)
if is_lazy:
if len(sys.argv) > 1 and sys.argv[1] in ["devspace", "experiment"]:
raise ValueError("Lazy CLI does not support devspace and experiment commands.")

app.add_typer(
devspace_cli.create(),
name="devspace",
help="[Module] Manage devspaces",
cls=GeneralCommand,
context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
app.add_typer(
experiment_cli.create(),
name="experiment",
help="[Module] Manage Experiments",
cls=GeneralCommand,
context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
# remove --lazy from sys.argv
sys.argv = [arg for arg in sys.argv if arg != "--lazy"]

RunContext.cli_command(
app,
sys.argv[1],
LazyEntrypoint(" ".join(sys.argv)),
type="task",
)
else:
entrypoints = metadata.entry_points().select(group="nemo_run.cli")
metadata.entry_points().select(group="nemo_run.cli")
for ep in entrypoints:
_get_or_add_typer(app, name=ep.name)

if not nested_entrypoints_creation or (
len(sys.argv) > 1 and sys.argv[1] in entrypoints.names
):
_add_typer_nested(app, list_entrypoints())

app.add_typer(
devspace_cli.create(),
name="devspace",
help="[Module] Manage devspaces",
cls=GeneralCommand,
context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
app.add_typer(
experiment_cli.create(),
name="experiment",
help="[Module] Manage Experiments",
cls=GeneralCommand,
context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)

if add_verbose_callback:
app.callback()(global_options)
Expand Down Expand Up @@ -711,6 +758,7 @@ class RunContext:
detach: bool = False
skip_confirmation: bool = False
tail_logs: bool = False
yaml: Optional[str] = None

executor: Optional[Executor] = field(init=False)
plugins: List[Plugin] = field(init=False)
Expand All @@ -720,7 +768,7 @@ def cli_command(
cls,
parent: typer.Typer,
name: str,
fn: Callable,
fn: Callable | LazyEntrypoint,
default_factory: Optional[Callable] = None,
default_executor: Optional[Executor] = None,
default_plugins: Optional[List[Plugin]] = None,
Expand All @@ -747,7 +795,7 @@ def cli_command(
**command_kwargs,
)
def command(
name: str = typer.Option(None, "--name", "-n", help="Name of the run"),
run_name: str = typer.Option(None, "--name", "-n", help="Name of the run"),
direct: bool = typer.Option(
False, "--direct/--no-direct", help="Execute the run directly"
),
Expand All @@ -760,6 +808,9 @@ def command(
load: Optional[str] = typer.Option(
None, "--load", "-l", help="Load a factory from a directory"
),
yaml: Optional[str] = typer.Option(
None, "--yaml", "-y", help="Path to a YAML file to load"
),
repl: bool = typer.Option(False, "--repl", "-r", help="Enter interactive mode"),
detach: bool = typer.Option(False, "--detach", help="Detach from the run"),
skip_confirmation: bool = typer.Option(
Expand All @@ -771,11 +822,12 @@ def command(
ctx: typer.Context = typer.Context,
):
self = cls(
name=name,
name=run_name or name,
direct=direct,
dryrun=dryrun,
factory=factory or default_factory,
load=load,
yaml=yaml,
repl=repl,
detach=detach,
skip_confirmation=skip_confirmation,
Expand All @@ -788,9 +840,20 @@ def command(
if default_plugins:
self.plugins = default_plugins

_load_entrypoints()
_load_workspace()
self.cli_execute(fn, ctx.args, type)
if isinstance(fn, LazyEntrypoint):
self.execute_lazy(fn, sys.argv, name)
return

try:
_load_entrypoints()
_load_workspace()
self.cli_execute(fn, ctx.args, type)
except RunContextError as e:
typer.echo(f"Error: {str(e)}", err=True, color=True)
raise typer.Exit(code=1)
except Exception as e:
typer.echo(f"Unexpected error: {str(e)}", err=True, color=True)
raise typer.Exit(code=1)

return command

Expand Down Expand Up @@ -894,6 +957,51 @@ def run_task():

run_task()

def execute_lazy(self, entrypoint: LazyEntrypoint, args: List[str], name: str):
console = Console()

import nemo_run as run

if self.dryrun:
raise ValueError("Dry run is not supported for lazy execution")

if self.repl:
raise ValueError("Interactive mode is not supported for lazy execution")

if self.direct:
raise ValueError("Direct execution is not supported for lazy execution")

_, run_args, args = _parse_prefixed_args(args, "run")
self.parse_args(run_args, lazy=True)

cmd, cmd_args, i_self = "", [], 0
for i, arg in enumerate(sys.argv):
if arg == name:
i_self = i
if i_self == 0:
cmd += f" {arg}"

elif "=" not in arg and not arg.startswith("--"):
cmd += f" {arg}"
elif "=" in arg and not arg.startswith("--"):
cmd_args.append(arg)

to_run = LazyEntrypoint(cmd, factory=self.factory)
to_run._add_overwrite(*cmd_args)

if self._should_continue(self.skip_confirmation):
console.print(f"[bold cyan]Launching {self.name}...[/bold cyan]")
run.run(
fn_or_script=to_run,
name=self.name,
executor=self.executor,
plugins=self.plugins,
direct=False,
detach=self.detach,
)
else:
console.print("[bold cyan]Exiting...[/bold cyan]")

def _execute_experiment(self, fn: Callable, experiment_args: List[str]):
"""
Execute an experiment.
Expand Down Expand Up @@ -952,16 +1060,11 @@ def parse_fn(self, fn: T, args: List[str], **default_kwargs) -> Partial[T]:
Returns:
Partial[T]: A Partial object representing the parsed function and arguments.
"""
if self.factory:
if isinstance(self.factory, Callable):
output = self.factory()
else:
output = parse_factory(fn, "factory", fn, self.factory)
parse_cli_args(output, args, output)
else:
output = self._parse_partial(fn, args, **default_kwargs)
output = LazyEntrypoint(fn, factory=self.factory, yaml=self.yaml)
if args:
output._add_overwrite(*args)

return output
return output.resolve()

def _parse_partial(self, fn: Callable, args: List[str], **default_args) -> Partial[T]:
"""
Expand All @@ -980,7 +1083,7 @@ def _parse_partial(self, fn: Callable, args: List[str], **default_args) -> Parti
setattr(config, key, value)
return config

def parse_args(self, args: List[str]):
def parse_args(self, args: List[str], lazy: bool = False):
"""
Parse the given arguments and update the RunContext accordingly.
Expand Down Expand Up @@ -1016,9 +1119,11 @@ def parse_args(self, args: List[str]):
if self.plugins:
self.plugins = fdl.build(self.plugins)

if args:
if not lazy and args:
parse_cli_args(self, args, self)

return args

def parse_executor(self, name: str, *args: str) -> Partial[Executor]:
"""
Parse the executor configuration.
Expand Down
Loading

0 comments on commit d68f9e8

Please sign in to comment.