From c0704e92fe1ad829a151dc0ad496f8138c0be701 Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Thu, 30 Jan 2025 18:44:26 -0800 Subject: [PATCH] retry on layer above Signed-off-by: Ruiyang Wang --- python/ray/dashboard/modules/job/job_head.py | 114 ++++++++++--------- 1 file changed, 61 insertions(+), 53 deletions(-) diff --git a/python/ray/dashboard/modules/job/job_head.py b/python/ray/dashboard/modules/job/job_head.py index bca0be8c727d..ab6ccc99a200 100644 --- a/python/ray/dashboard/modules/job/job_head.py +++ b/python/ray/dashboard/modules/job/job_head.py @@ -195,6 +195,23 @@ async def _pick_random_agent(self) -> Optional[JobAgentSubmissionClient]: 2. if not, randomly select one agent from all available agents, it is possible that the selected one already exists in `self._agents`. + + If there's no agent available at all, or there's exception, it will retry every + `TRY_TO_GET_AGENT_INFO_INTERVAL_SECONDS` seconds indefinitely. + """ + while True: + try: + return await self._pick_random_agent_once() + except Exception: + logger.exception( + f"Failed to fetch all agent infos, retrying in {TRY_TO_GET_AGENT_INFO_INTERVAL_SECONDS} seconds..." + ) + await asyncio.sleep(TRY_TO_GET_AGENT_INFO_INTERVAL_SECONDS) + + async def _pick_random_agent_once(self) -> JobAgentSubmissionClient: + """ + Query the internal kv for all agent infos, and pick agents randomly. May raise + exception if there's no agent available at all or there's network error. """ # NOTE: Following call will block until there's at least 1 agent info # being populated from GCS @@ -214,21 +231,20 @@ async def _pick_random_agent(self) -> Optional[JobAgentSubmissionClient]: node_id = choice(list(agent_node_ids)) if node_id not in self._agents: - # Fetch agent info from InternalKV, and create a new JobAgentSubmissionClient. + # Fetch agent info from InternalKV, and create a new + # JobAgentSubmissionClient. May raise if the node_id is removed in + # InternalKV after the _fetch_all_agent_node_ids, though unlikely. ip, http_port, grpc_port = await self._fetch_agent_info(node_id) agent_http_address = f"http://{ip}:{http_port}" self._agents[node_id] = JobAgentSubmissionClient(agent_http_address) return self._agents[node_id] - async def _get_head_node_agent(self) -> Optional[JobAgentSubmissionClient]: - """Retrieves HTTP client for `JobAgent` running on the Head node""" - + async def _get_head_node_agent_once(self) -> JobAgentSubmissionClient: head_node_id_hex = await get_head_node_id(self.gcs_aio_client) if not head_node_id_hex: - logger.warning("Head node id has not yet been persisted in GCS") - return None + raise Exception("Head node id has not yet been persisted in GCS") head_node_id = NodeID.from_hex(head_node_id_hex) @@ -239,66 +255,58 @@ async def _get_head_node_agent(self) -> Optional[JobAgentSubmissionClient]: return self._agents[head_node_id] - async def _fetch_all_agent_node_ids(self) -> List[NodeID]: - """ - Fetches all NodeIDs with agent infos in the cluster. - - If there's no agent available at all, or there's exception, it will retry every + async def _get_head_node_agent(self) -> JobAgentSubmissionClient: + """Retrieves HTTP client for `JobAgent` running on the Head node. If the head + node does not have an agent, it will retry every `TRY_TO_GET_AGENT_INFO_INTERVAL_SECONDS` seconds indefinitely. - - Returns: List[NodeID] """ while True: try: - keys = await self.gcs_aio_client.internal_kv_keys( - f"{DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX}".encode(), - namespace=KV_NAMESPACE_DASHBOARD, - timeout=GCS_RPC_TIMEOUT_SECONDS, - ) - if not keys: - # No agent keys found, retry - raise Exception() - return [ - NodeID.from_hex( - key[len(DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX) :].decode() - ) - for key in keys - ] - + return await self._get_head_node_agent_once() except Exception: - logger.info( - f"Failed to fetch all agent infos, retrying in {TRY_TO_GET_AGENT_INFO_INTERVAL_SECONDS} seconds..." + logger.exception( + f"Failed to get head node agent, retrying in {TRY_TO_GET_AGENT_INFO_INTERVAL_SECONDS} seconds..." ) await asyncio.sleep(TRY_TO_GET_AGENT_INFO_INTERVAL_SECONDS) - async def _fetch_agent_info(self, target_node_id: NodeID) -> Tuple[str, int, int]: + async def _fetch_all_agent_node_ids(self) -> List[NodeID]: """ - Fetches agent info by the Node ID. + Fetches all NodeIDs with agent infos in the cluster. - If the agent info is not found, it will retry every - `TRY_TO_GET_AGENT_INFO_INTERVAL_SECONDS` seconds indefinitely. + May raise exception if there's no agent available at all or there's network error. + Returns: List[NodeID] + """ + keys = await self.gcs_aio_client.internal_kv_keys( + f"{DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX}".encode(), + namespace=KV_NAMESPACE_DASHBOARD, + timeout=GCS_RPC_TIMEOUT_SECONDS, + ) + if not keys: + # No agent keys found, retry + raise Exception("No agents found in InternalKV.") + return [ + NodeID.from_hex(key[len(DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX) :].decode()) + for key in keys + ] - Returns: (ip, http_port, grpc_port) + async def _fetch_agent_info(self, target_node_id: NodeID) -> Tuple[str, int, int]: """ + Fetches agent info by the Node ID. May raise exception if there's network error or the + agent info is not found. - while True: - try: - key = f"{DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX}{target_node_id.hex()}" - value = await self.gcs_aio_client.internal_kv_get( - key, - namespace=KV_NAMESPACE_DASHBOARD, - timeout=GCS_RPC_TIMEOUT_SECONDS, - ) - if not value: - # Agent info not found, retry - raise Exception("Agent info not found in internal kv") - return json.loads(value.decode()) - - except Exception as e: - logger.info( - f"Failed to fetch agent info for node {target_node_id}: {e}. Retrying in {TRY_TO_GET_AGENT_INFO_INTERVAL_SECONDS} seconds..." - ) - await asyncio.sleep(TRY_TO_GET_AGENT_INFO_INTERVAL_SECONDS) + Returns: (ip, http_port, grpc_port) + """ + key = f"{DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX}{target_node_id.hex()}" + value = await self.gcs_aio_client.internal_kv_get( + key, + namespace=KV_NAMESPACE_DASHBOARD, + timeout=GCS_RPC_TIMEOUT_SECONDS, + ) + if not value: + raise KeyError( + f"Agent info not found in internal kv for node {target_node_id}" + ) + return json.loads(value.decode()) @routes.get("/api/version") async def get_version(self, req: Request) -> Response: