Skip to content

Commit

Permalink
Add Prefect Cloud IP Allowlist CLI (#15087)
Browse files Browse the repository at this point in the history
  • Loading branch information
collincchoy authored Aug 29, 2024
1 parent a428078 commit e81f363
Show file tree
Hide file tree
Showing 9 changed files with 1,393 additions and 486 deletions.
1 change: 1 addition & 0 deletions src/prefect/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import prefect.cli.artifact
import prefect.cli.block
import prefect.cli.cloud
import prefect.cli.cloud.ip_allowlist
import prefect.cli.cloud.webhook
import prefect.cli.shell
import prefect.cli.concurrency_limit
Expand Down
256 changes: 256 additions & 0 deletions src/prefect/cli/cloud/ip_allowlist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
import asyncio
from typing import Annotated, Optional

import typer
from pydantic import BaseModel, IPvAnyNetwork
from rich.panel import Panel
from rich.table import Table

from prefect.cli._types import PrefectTyper
from prefect.cli._utilities import exit_with_error, exit_with_success
from prefect.cli.cloud import cloud_app, confirm_logged_in
from prefect.cli.root import app
from prefect.client.cloud import get_cloud_client
from prefect.client.schemas.objects import IPAllowlist, IPAllowlistEntry
from prefect.exceptions import PrefectHTTPStatusError
from prefect.logging.loggers import get_logger

ip_allowlist_app = PrefectTyper(
name="ip-allowlist", help="Manage Prefect Cloud IP Allowlists"
)
cloud_app.add_typer(ip_allowlist_app, aliases=["ip-allowlists"])

logger = get_logger(__name__)


@ip_allowlist_app.callback()
def require_access_to_ip_allowlisting(ctx: typer.Context):
"""Enforce access to IP allowlisting for all subcommands."""
asyncio.run(_require_access_to_ip_allowlisting(ctx))


async def _require_access_to_ip_allowlisting(ctx: typer.Context):
"""Check if the account has access to IP allowlisting.
Exits with an error if the account does not have access to IP allowlisting.
On success, sets Typer context meta["enforce_ip_allowlist"] to
True if the account has IP allowlist enforcement enabled, False otherwise.
"""
confirm_logged_in()

async with get_cloud_client(infer_cloud_url=True) as client:
account_settings = await client.read_account_settings()

if "enforce_ip_allowlist" not in account_settings:
return exit_with_error("IP allowlisting is not available for this account.")

enforce_ip_allowlist = account_settings.get("enforce_ip_allowlist", False)
ctx.meta["enforce_ip_allowlist"] = enforce_ip_allowlist


@ip_allowlist_app.command()
async def enable(ctx: typer.Context):
"""Enable the IP allowlist for your account. When enabled, if the allowlist is non-empty, then access to your Prefect Cloud account will be restricted to only those IP addresses on the allowlist."""
enforcing_ip_allowlist = ctx.meta["enforce_ip_allowlist"]
if enforcing_ip_allowlist:
exit_with_success("IP allowlist is already enabled.")

async with get_cloud_client(infer_cloud_url=True) as client:
my_access_if_enabled = await client.check_ip_allowlist_access()
if not my_access_if_enabled.allowed:
exit_with_error(
f"Error enabling IP allowlist: {my_access_if_enabled.detail}"
)

logger.debug(my_access_if_enabled.detail)

if not typer.confirm(
"Enabling the IP allowlist will restrict Prefect Cloud API and UI access to only the IP addresses on the list. "
"Continue?"
):
exit_with_error("Aborted.")
await client.update_account_settings({"enforce_ip_allowlist": True})

exit_with_success("IP allowlist enabled.")


@ip_allowlist_app.command()
async def disable():
"""Disable the IP allowlist for your account. When disabled, all IP addresses will be allowed to access your Prefect Cloud account."""
async with get_cloud_client(infer_cloud_url=True) as client:
await client.update_account_settings({"enforce_ip_allowlist": False})

exit_with_success("IP allowlist disabled.")


@ip_allowlist_app.command()
async def ls(ctx: typer.Context):
"""Fetch and list all IP allowlist entries in your account."""
async with get_cloud_client(infer_cloud_url=True) as client:
ip_allowlist = await client.read_account_ip_allowlist()

_print_ip_allowlist_table(
ip_allowlist, enabled=ctx.meta["enforce_ip_allowlist"]
)


class IPNetworkArg(BaseModel):
raw: str
parsed: IPvAnyNetwork


def parse_ip_network_argument(val: str) -> IPNetworkArg:
return IPNetworkArg(
raw=val,
parsed=val, # type: ignore
)


IP_ARGUMENT = Annotated[
IPNetworkArg,
typer.Argument(
parser=parse_ip_network_argument,
help="An IP address or range in CIDR notation. E.g. 192.168.1.0 or 192.168.1.0/24",
metavar="IP address or range",
),
]


@ip_allowlist_app.command()
async def add(
ctx: typer.Context,
ip_address_or_range: IP_ARGUMENT,
description: Optional[str] = typer.Option(
None,
"--description",
"-d",
help="A short description to annotate the entry with.",
),
):
"""Add a new IP entry to your account IP allowlist."""
new_entry = IPAllowlistEntry(
ip_network=ip_address_or_range.parsed, description=description, enabled=True
)

async with get_cloud_client(infer_cloud_url=True) as client:
ip_allowlist = await client.read_account_ip_allowlist()

existing_entry_with_same_ip = None
for entry in ip_allowlist.entries:
if entry.ip_network == ip_address_or_range.parsed:
existing_entry_with_same_ip = entry
break

if existing_entry_with_same_ip:
if not typer.confirm(
f"There's already an entry for this IP ({ip_address_or_range.raw}). Do you want to overwrite it?"
):
exit_with_error("Aborted.")
ip_allowlist.entries.remove(existing_entry_with_same_ip)

ip_allowlist.entries.append(new_entry)

try:
await client.update_account_ip_allowlist(ip_allowlist)
except PrefectHTTPStatusError as exc:
_handle_update_error(exc)

updated_ip_allowlist = await client.read_account_ip_allowlist()
_print_ip_allowlist_table(
updated_ip_allowlist, enabled=ctx.meta["enforce_ip_allowlist"]
)


@ip_allowlist_app.command()
async def remove(ctx: typer.Context, ip_address_or_range: IP_ARGUMENT):
"""Remove an IP entry from your account IP allowlist."""
async with get_cloud_client(infer_cloud_url=True) as client:
ip_allowlist = await client.read_account_ip_allowlist()
ip_allowlist.entries = [
entry
for entry in ip_allowlist.entries
if entry.ip_network != ip_address_or_range.parsed
]

try:
await client.update_account_ip_allowlist(ip_allowlist)
except PrefectHTTPStatusError as exc:
_handle_update_error(exc)

updated_ip_allowlist = await client.read_account_ip_allowlist()
_print_ip_allowlist_table(
updated_ip_allowlist, enabled=ctx.meta["enforce_ip_allowlist"]
)


@ip_allowlist_app.command()
async def toggle(ctx: typer.Context, ip_address_or_range: IP_ARGUMENT):
"""Toggle the enabled status of an individual IP entry in your account IP allowlist."""
async with get_cloud_client(infer_cloud_url=True) as client:
ip_allowlist = await client.read_account_ip_allowlist()

found_matching_entry = False
for entry in ip_allowlist.entries:
if entry.ip_network == ip_address_or_range.parsed:
entry.enabled = not entry.enabled
found_matching_entry = True
break

if not found_matching_entry:
exit_with_error(
f"No entry found with IP address `{ip_address_or_range.raw}`."
)

try:
await client.update_account_ip_allowlist(ip_allowlist)
except PrefectHTTPStatusError as exc:
_handle_update_error(exc)

updated_ip_allowlist = await client.read_account_ip_allowlist()
_print_ip_allowlist_table(
updated_ip_allowlist, enabled=ctx.meta["enforce_ip_allowlist"]
)


def _print_ip_allowlist_table(ip_allowlist: IPAllowlist, enabled: bool):
if not ip_allowlist.entries:
app.console.print(
Panel(
"IP allowlist is empty. Add an entry to secure access to your Prefect Cloud account.",
expand=False,
)
)
return

red_asterisk_if_not_enabled = "[red]*[/red]" if enabled is False else ""

table = Table(
title="IP Allowlist " + red_asterisk_if_not_enabled,
caption=f"{red_asterisk_if_not_enabled} Enforcement is "
f"[bold]{'ENABLED' if enabled else '[red]DISABLED[/red]'}[/bold].",
caption_style="not dim",
)

table.add_column("IP Address", style="cyan", no_wrap=True)
table.add_column("Description", style="blue", no_wrap=False)
table.add_column("Enabled", style="green", justify="right", no_wrap=True)
table.add_column("Last Seen", style="magenta", justify="right", no_wrap=True)

for entry in ip_allowlist.entries:
table.add_row(
str(entry.ip_network),
entry.description,
str(entry.enabled),
entry.last_seen or "Never",
style="dim" if not entry.enabled else None,
)

app.console.print(table)


def _handle_update_error(error: PrefectHTTPStatusError):
if error.response.status_code == 422 and (
details := error.response.json().get("detail")
):
exit_with_error(f"Error updating allowlist: {details}")
63 changes: 56 additions & 7 deletions src/prefect/client/cloud.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, cast

import anyio
import httpx
Expand All @@ -9,7 +9,11 @@
import prefect.context
import prefect.settings
from prefect.client.base import PrefectHttpxAsyncClient
from prefect.client.schemas.objects import Workspace
from prefect.client.schemas.objects import (
IPAllowlist,
IPAllowlistMyAccessResponse,
Workspace,
)
from prefect.exceptions import ObjectNotFound, PrefectException
from prefect.settings import (
PREFECT_API_KEY,
Expand Down Expand Up @@ -69,6 +73,26 @@ def __init__(
**httpx_settings, enable_csrf_support=False
)

if match := (
re.search(PARSE_API_URL_REGEX, host)
or re.search(PARSE_API_URL_REGEX, prefect.settings.PREFECT_API_URL.value())
):
self.account_id, self.workspace_id = match.groups()

@property
def account_base_url(self) -> str:
if not self.account_id:
raise ValueError("Account ID not set")

return f"accounts/{self.account_id}"

@property
def workspace_base_url(self) -> str:
if not self.workspace_id:
raise ValueError("Workspace ID not set")

return f"{self.account_base_url}/workspaces/{self.workspace_id}"

async def api_healthcheck(self):
"""
Attempts to connect to the Cloud API and raises the encountered exception if not
Expand All @@ -86,11 +110,36 @@ async def read_workspaces(self) -> List[Workspace]:
return workspaces

async def read_worker_metadata(self) -> Dict[str, Any]:
configured_url = prefect.settings.PREFECT_API_URL.value()
account_id, workspace_id = re.findall(PARSE_API_URL_REGEX, configured_url)[0]
return await self.get(
f"accounts/{account_id}/workspaces/{workspace_id}/collections/work_pool_types"
response = await self.get(
f"{self.workspace_base_url}/collections/work_pool_types"
)
return cast(Dict[str, Any], response)

async def read_account_settings(self) -> Dict[str, Any]:
response = await self.get(f"{self.account_base_url}/settings")
return cast(Dict[str, Any], response)

async def update_account_settings(self, settings: Dict[str, Any]):
await self.request(
"PATCH",
f"{self.account_base_url}/settings",
json=settings,
)

async def read_account_ip_allowlist(self) -> IPAllowlist:
response = await self.get(f"{self.account_base_url}/ip_allowlist")
return IPAllowlist.model_validate(response)

async def update_account_ip_allowlist(self, updated_allowlist: IPAllowlist):
await self.request(
"PUT",
f"{self.account_base_url}/ip_allowlist",
json=updated_allowlist.model_dump(mode="json"),
)

async def check_ip_allowlist_access(self) -> IPAllowlistMyAccessResponse:
response = await self.get(f"{self.account_base_url}/ip_allowlist/my_access")
return IPAllowlistMyAccessResponse.model_validate(response)

async def __aenter__(self):
await self._client.__aenter__()
Expand Down Expand Up @@ -120,7 +169,7 @@ async def request(self, method, route, **kwargs):
status.HTTP_401_UNAUTHORIZED,
status.HTTP_403_FORBIDDEN,
):
raise CloudUnauthorizedError
raise CloudUnauthorizedError(str(exc)) from exc
elif exc.response.status_code == status.HTTP_404_NOT_FOUND:
raise ObjectNotFound(http_exc=exc) from exc
else:
Expand Down
2 changes: 1 addition & 1 deletion src/prefect/client/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,6 @@ def get_collections_metadata_client(
"""
orchestration_client = get_client(httpx_settings=httpx_settings)
if orchestration_client.server_type == ServerType.CLOUD:
return get_cloud_client(httpx_settings=httpx_settings)
return get_cloud_client(httpx_settings=httpx_settings, infer_cloud_url=True)
else:
return orchestration_client
Loading

0 comments on commit e81f363

Please sign in to comment.