Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add wrong task routing tests for Function Executor #1172

Merged
merged 1 commit into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 19 additions & 15 deletions indexify/src/indexify/function_executor/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,33 +86,37 @@ def run_task(
# If our code raises an exception the grpc framework converts it into GRPC_STATUS_UNKNOWN
# error with the exception message. Differentiating errors is not needed for now.
RunTaskRequestValidator(request=request).check()
self._check_task_routed_correctly(request)

return RunTaskHandler(
request=request,
graph_name=self._graph_name,
graph_version=self._graph_version,
function_name=self._function_name,
function=self._function,
invocation_state=ProxiedInvocationState(
request.task_id, self._invocation_state_proxy_server
),
logger=self._logger,
).run()

def _check_task_routed_correctly(self, request: RunTaskRequest):
# Compares the namespace, graph name, graph version and function name of the request
# with the values stored on this Server and raises ValueError on the first mismatch.
#
# Fail with internal error as this happened due to wrong task routing to this Server.
# If we run the wrongly routed task then it can steal data from this Server if it belongs
# to a different customer.
#
# NOTE(review): three of the raise statements below contain two adjacent f-string literals,
# which Python concatenates into a single message. This text comes from a diff view, so the
# first literal of each pair is presumably the removed line and the second the added one;
# confirm against the real file.
if request.namespace != self._namespace:
raise ValueError(
f"This Function Executor is not initialized for this namespace {request.namespace}"
)
if request.graph_name != self._graph_name:
raise ValueError(
f"This Function Executor is not initialized for this graph {request.graph_name}"
f"This Function Executor is not initialized for this graph_name {request.graph_name}"
)
if request.graph_version != self._graph_version:
raise ValueError(
f"This Function Executor is not initialized for this graph version {request.graph_version}"
f"This Function Executor is not initialized for this graph_version {request.graph_version}"
)
if request.function_name != self._function_name:
raise ValueError(
f"This Function Executor is not initialized for this function {request.function_name}"
f"This Function Executor is not initialized for this function_name {request.function_name}"
)

return RunTaskHandler(
request=request,
graph_name=self._graph_name,
graph_version=self._graph_version,
function_name=self._function_name,
function=self._function,
invocation_state=ProxiedInvocationState(
request.task_id, self._invocation_state_proxy_server
),
logger=self._logger,
).run()
52 changes: 52 additions & 0 deletions indexify/tests/function_executor/test_run_task.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
import unittest
from typing import List, Mapping

from grpc import RpcError
from pydantic import BaseModel
from tensorlake import Graph
from tensorlake.functions_sdk.data_objects import File
from tensorlake.functions_sdk.functions import tensorlake_function
from tensorlake.functions_sdk.object_serializer import CloudPickleSerializer
from testing import (
FunctionExecutorServerTestCase,
copy_and_modify_request,
deserialized_function_output,
run_task,
)

from indexify.function_executor.proto.function_executor_pb2 import (
InitializeRequest,
InitializeResponse,
RunTaskRequest,
RunTaskResponse,
SerializedObject,
)
Expand Down Expand Up @@ -143,6 +146,55 @@ def test_function_raises_error(self):
"this extractor throws an exception." in run_task_response.stderr
)

def test_wrong_task_routing(self):
    """Check that run_task fails for requests routed to the wrong executor.

    The Function Executor is initialized for a single namespace, graph name,
    graph version and function name. Each request sent below differs from
    that identity in exactly one field and must fail with an RpcError.
    """
    with self._rpc_channel() as channel:
        stub: FunctionExecutorStub = FunctionExecutorStub(channel)
        initialize_response: InitializeResponse = stub.initialize(
            InitializeRequest(
                namespace="test",
                graph_name="test",
                graph_version="1",
                function_name="extractor_b",
                graph=SerializedObject(
                    bytes=CloudPickleSerializer.serialize(
                        create_graph_a().serialize(
                            additional_modules=[],
                        )
                    ),
                    content_type=CloudPickleSerializer.content_type,
                ),
            )
        )
        self.assertTrue(initialize_response.success)

        # Template request matching the identity used in initialize() above.
        # It is never sent itself; only the modified copies are.
        valid_request: RunTaskRequest = RunTaskRequest(
            namespace="test",
            graph_name="test",
            graph_version="1",
            function_name="extractor_b",
            graph_invocation_id="123",
            task_id="test-task",
            function_input=SerializedObject(
                # Explicit placeholder payload. A bare `input` here would
                # resolve to the Python builtin function, because this method
                # has no local variable of that name.
                bytes=CloudPickleSerializer.serialize("unused-function-input"),
                content_type=CloudPickleSerializer.content_type,
            ),
        )

        # One routing field is broken per request so that a failure points
        # at the exact check that did not reject the task.
        wrong_field_values = {
            "namespace": "wrong-namespace",
            "graph_name": "wrong-graph-name",
            "graph_version": "wrong-graph-version",
            "function_name": "wrong-function-name",
        }
        for field_name, wrong_value in wrong_field_values.items():
            # subTest keeps the loop going after a failure and reports
            # which field was accepted when it should have been rejected.
            with self.subTest(field=field_name):
                wrong_request: RunTaskRequest = copy_and_modify_request(
                    valid_request, {field_name: wrong_value}
                )
                with self.assertRaises(RpcError):
                    stub.run_task(wrong_request)


# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
unittest.main()
12 changes: 11 additions & 1 deletion indexify/tests/function_executor/testing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import subprocess
import unittest
from typing import Any, List
from typing import Any, Dict, List

import grpc
from tensorlake.functions_sdk.object_serializer import CloudPickleSerializer
Expand Down Expand Up @@ -89,3 +89,13 @@ def deserialized_function_output(
test_case.assertEqual(output.content_type, CloudPickleSerializer.content_type)
outputs.append(CloudPickleSerializer.deserialize(output.bytes))
return outputs


def copy_and_modify_request(
    src: RunTaskRequest, modifications: Dict[str, Any]
) -> RunTaskRequest:
    """Build a new RunTaskRequest from `src` with selected fields overwritten.

    Args:
        src: Request whose contents are copied into the new request.
        modifications: Mapping of attribute name to the value assigned to
            that attribute on the copy.

    Returns:
        The newly created request with all modifications applied.
    """
    modified_request = RunTaskRequest()
    # Populate the fresh message from the source before applying overrides.
    modified_request.CopyFrom(src)
    for field_name, new_value in modifications.items():
        setattr(modified_request, field_name, new_value)
    return modified_request
Loading