Skip to content

Commit

Permalink
Only ask for confirmation once (#39)
Browse files Browse the repository at this point in the history
* Only ask for confirmation once

Signed-off-by: Marc Romeijn <[email protected]>

* Fix linting

Signed-off-by: Marc Romeyn <[email protected]>

* Fix failing test

Signed-off-by: Marc Romeyn <[email protected]>

* Detect torchrun

Signed-off-by: Marc Romeyn <[email protected]>

* Trying to fix failing test

Signed-off-by: Marc Romeyn <[email protected]>

* Fix failing tests

Signed-off-by: Marc Romeyn <[email protected]>

---------

Signed-off-by: Marc Romeijn <[email protected]>
Signed-off-by: Marc Romeyn <[email protected]>
  • Loading branch information
marcromeyn authored Oct 2, 2024
1 parent 7217163 commit 0efe5d3
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 1 deletion.
27 changes: 26 additions & 1 deletion src/nemo_run/cli/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,12 @@
DEFAULT_NAME = "default"
EXECUTOR_CLASSES = [Executor, LocalExecutor, SkypilotExecutor, SlurmExecutor]
PLUGIN_CLASSES = [Plugin, List[Plugin]]
NEMORUN_SKIP_CONFIRMATION: Optional[bool] = None

INCLUDE_WORKSPACE_FILE = os.environ.get("INCLUDE_WORKSPACE_FILE", "true").lower() == "true"

logger = logging.getLogger(__name__)


def entrypoint(
fn: Optional[F] = None,
Expand Down Expand Up @@ -918,7 +921,24 @@ def _should_continue(self, skip_confirmation: bool) -> bool:
Returns:
bool: True if execution should continue, False otherwise.
"""
return skip_confirmation or typer.confirm("Continue?")
global NEMORUN_SKIP_CONFIRMATION

# If we're running under torchrun, always continue
if _is_torchrun():
logger.info("Detected torchrun environment. Skipping confirmation.")
return True

if NEMORUN_SKIP_CONFIRMATION is not None:
return NEMORUN_SKIP_CONFIRMATION

# If skip_confirmation is True or user confirms, continue
if skip_confirmation or typer.confirm("Continue?"):
NEMORUN_SKIP_CONFIRMATION = True
return True

# Otherwise, don't continue
NEMORUN_SKIP_CONFIRMATION = False
return False

def parse_fn(self, fn: T, args: List[str], **default_kwargs) -> Partial[T]:
"""
Expand Down Expand Up @@ -1357,6 +1377,11 @@ class MissingRequiredOptionError(RunContextError):
"""Raised when a required option is missing."""


def _is_torchrun() -> bool:
"""Check if running under torchrun."""
return "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1


if __name__ == "__main__":
app = create_cli()
app()
2 changes: 2 additions & 0 deletions test/cli/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import nemo_run as run
from nemo_run import cli, config
from nemo_run.cli import api as cli_api
from nemo_run.cli.api import Entrypoint, RunContext, create_cli

_RUN_FACTORIES_ENTRYPOINT: str = """
Expand Down Expand Up @@ -187,6 +188,7 @@ def test_run_context_execute_task_with_confirmation_denied(
self, mock_confirm, mock_run, mock_dryrun_fn, sample_function
):
ctx = RunContext(name="test_run")
cli_api.NEMORUN_SKIP_CONFIRMATION = None
ctx.cli_execute(sample_function, ["a=10", "b=hello"])
mock_dryrun_fn.assert_called_once()
mock_confirm.assert_called_once()
Expand Down
12 changes: 12 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import os
from typing import Any, Optional

import pytest
from invoke.config import Config
from invoke.context import Context

Expand All @@ -33,3 +34,14 @@ def run(self, command: str, **kwargs: Any):
kwargs["in_stream"] = False
runner = self.config.runners.local(self)
return self._run(runner, command, **kwargs)


@pytest.fixture(autouse=True)
def reset_nemorun_skip_confirmation():
from nemo_run.cli import api

"""Reset NEMORUN_SKIP_CONFIRMATION to None before each test."""
api.NEMORUN_SKIP_CONFIRMATION = None
yield
# Optionally, reset after the test as well
api.NEMORUN_SKIP_CONFIRMATION = None

0 comments on commit 0efe5d3

Please sign in to comment.