Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hack (don't merge unless really needed): Overwrite task version id in manual function executor management mode #1201

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,17 @@ class FunctionExecutorState:
under the lock.
"""

def __init__(self, function_id_with_version: str, function_id_without_version: str):
def __init__(
self,
function_id_with_version: str,
function_id_without_version: str,
function_version: str,
):
self.function_id_with_version: str = function_id_with_version
self.function_id_without_version: str = function_id_without_version
# All the fields below are protected by the lock.
self.lock: asyncio.Lock = asyncio.Lock()
self.function_version: str = function_version
self.is_shutdown: bool = False
self.function_executor: Optional[FunctionExecutor] = None
self.running_tasks: int = 0
Expand Down
3 changes: 3 additions & 0 deletions indexify/src/indexify/executor/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ async def _get_or_create_state(self, task: Task) -> FunctionExecutorState:
state = FunctionExecutorState(
function_id_with_version=_function_id_with_version(task),
function_id_without_version=id,
function_version=task.graph_version,
)
self._function_executor_states[id] = state
return self._function_executor_states[id]
Expand All @@ -76,6 +77,8 @@ async def _run_task_policy(self, state: FunctionExecutorState, task: Task) -> No
await state.wait_running_tasks_less(1)

if self._disable_automatic_function_executor_management:
# Hack: lie to task executor about the function version so it won't raise errors.
task.graph_version = state.function_version
return # Disable Function Executor destroy in manual management mode.

if state.function_id_with_version != _function_id_with_version(task):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,18 @@ def test_two_executors_only_one_function_executor_destroy(self):

# As invocations might land on dev Executor, we need to run the graph multiple times
# to ensure that we catch function executor getting destroyed on Executor A if it ever happens.
fe_ids_v1 = []
fe_ids_v1 = set()
for _ in range(10):
invocation_id = graph_v1.run(block_until_done=True)
output = graph_v1.output(invocation_id, "get_function_executor_id")
self.assertEqual(len(output), 1)
if output[0] not in fe_ids_v1:
fe_ids_v1.append(output[0])
fe_ids_v1.add(output[0])

self.assertGreaterEqual(len(fe_ids_v1), 1)
not_destroyable_fe_id = None
self.assertLessEqual(len(fe_ids_v1), 2)
if len(fe_ids_v1) == 2:
self.assertIn(destroyable_fe_id, fe_ids_v1)
not_destroyable_fe_id = fe_ids_v1.difference([destroyable_fe_id]).pop()

graph_v2 = Graph(
name=test_graph_name(self),
Expand All @@ -69,16 +72,17 @@ def test_two_executors_only_one_function_executor_destroy(self):
)
graph_v2 = RemoteGraph.deploy(graph_v2)

success_fe_ids_v2 = []
fe_ids_v2 = set()
for _ in range(10):
invocation_id = graph_v2.run(block_until_done=True)
output = graph_v2.output(invocation_id, "get_function_executor_id")
if len(output) == 1 and output[0] not in success_fe_ids_v2:
success_fe_ids_v2.append(output[0])
self.assertEqual(len(output), 1)
fe_ids_v2.add(output[0])

# Executor A should fail in all v2 invokes because it won't destroy its function executor and
# the function executor will raise an error because the task version must be v1 not v2.
self.assertEqual(len(success_fe_ids_v2), 1)
self.assertLessEqual(len(fe_ids_v2), 2)
self.assertTrue(destroyable_fe_id not in fe_ids_v2)
if len(fe_ids_v2) == 2 and not_destroyable_fe_id is not None:
self.assertIn(not_destroyable_fe_id, fe_ids_v2)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion tensorlake
Loading