Skip to content

Commit

Permalink
Treat Function Wrapper creation as customer code
Browse files Browse the repository at this point in the history
It turned out that Function Wrapper constructor calls constructor
for customer function which is customer code so Function Wrapper
creation failures need to be treat as customer code failures.
  • Loading branch information
eabatalov committed Jan 21, 2025
1 parent bc55a3b commit d5d41c1
Showing 1 changed file with 19 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@ def __init__(
invocation_state: InvocationState,
logger: Any,
):
self._invocation_id: str = request.graph_invocation_id
self._graph_name: str = graph_name
self._graph_version: str = graph_version
self._function_name: str = function_name
self._function: Union[TensorlakeCompute, TensorlakeCompute] = function
self._invocation_state: InvocationState = invocation_state
self._logger = logger.bind(
graph_invocation_id=request.graph_invocation_id,
task_id=request.task_id,
Expand All @@ -42,16 +47,6 @@ def __init__(
self._func_stdout: io.StringIO = io.StringIO()
self._func_stderr: io.StringIO = io.StringIO()

self._function_wrapper: TensorlakeFunctionWrapper = TensorlakeFunctionWrapper(
indexify_function=function,
context=GraphInvocationContext(
invocation_id=request.graph_invocation_id,
graph_name=graph_name,
graph_version=graph_version,
invocation_state=invocation_state,
),
)

def run(self) -> RunTaskResponse:
"""Runs the task.
Expand Down Expand Up @@ -81,8 +76,18 @@ def _run_func_safe_and_captured(self, inputs: FunctionInputs) -> RunTaskResponse
)

def _run_func(self, inputs: FunctionInputs) -> RunTaskResponse:
if _is_router(self._function_wrapper):
result: RouterCallResult = self._function_wrapper.invoke_router(
func_wrapper = TensorlakeFunctionWrapper(
indexify_function=self._function,
context=GraphInvocationContext(
invocation_id=self._invocation_id,
graph_name=self._graph_name,
graph_version=self._graph_version,
invocation_state=self._invocation_state,
),
)

if _is_router(func_wrapper):
result: RouterCallResult = func_wrapper.invoke_router(
self._function_name, inputs.input
)
return self._response_helper.router_response(
Expand All @@ -91,12 +96,12 @@ def _run_func(self, inputs: FunctionInputs) -> RunTaskResponse:
stderr=self._func_stderr.getvalue(),
)
else:
result: FunctionCallResult = self._function_wrapper.invoke_fn_ser(
result: FunctionCallResult = func_wrapper.invoke_fn_ser(
self._function_name, inputs.input, inputs.init_value
)
return self._response_helper.function_response(
result=result,
is_reducer=_func_is_reducer(self._function_wrapper),
is_reducer=_func_is_reducer(func_wrapper),
stdout=self._func_stdout.getvalue(),
stderr=self._func_stderr.getvalue(),
)
Expand Down

0 comments on commit d5d41c1

Please sign in to comment.