diff --git a/indexify/src/indexify/function_executor/handlers/run_function/handler.py b/indexify/src/indexify/function_executor/handlers/run_function/handler.py index 2af816cae..240e60528 100644 --- a/indexify/src/indexify/function_executor/handlers/run_function/handler.py +++ b/indexify/src/indexify/function_executor/handlers/run_function/handler.py @@ -30,7 +30,12 @@ def __init__( invocation_state: InvocationState, logger: Any, ): + self._invocation_id: str = request.graph_invocation_id + self._graph_name: str = graph_name + self._graph_version: str = graph_version self._function_name: str = function_name + self._function: Union[TensorlakeCompute, TensorlakeCompute] = function + self._invocation_state: InvocationState = invocation_state self._logger = logger.bind( graph_invocation_id=request.graph_invocation_id, task_id=request.task_id, @@ -42,16 +47,6 @@ def __init__( self._func_stdout: io.StringIO = io.StringIO() self._func_stderr: io.StringIO = io.StringIO() - self._function_wrapper: TensorlakeFunctionWrapper = TensorlakeFunctionWrapper( - indexify_function=function, - context=GraphInvocationContext( - invocation_id=request.graph_invocation_id, - graph_name=graph_name, - graph_version=graph_version, - invocation_state=invocation_state, - ), - ) - def run(self) -> RunTaskResponse: """Runs the task. @@ -81,8 +76,18 @@ def _run_func_safe_and_captured(self, inputs: FunctionInputs) -> RunTaskResponse ) def _run_func(self, inputs: FunctionInputs) -> RunTaskResponse: - if _is_router(self._function_wrapper): - result: RouterCallResult = self._function_wrapper.invoke_router( + func_wrapper = TensorlakeFunctionWrapper( + indexify_function=self._function, + context=GraphInvocationContext( + invocation_id=self._invocation_id, + graph_name=self._graph_name, + graph_version=self._graph_version, + invocation_state=self._invocation_state, + ), + ) + + if _is_router(func_wrapper): + result: RouterCallResult = func_wrapper.invoke_router( self._function_name, inputs.input ) return self._response_helper.router_response( @@ -91,12 +96,12 @@ def _run_func(self, inputs: FunctionInputs) -> RunTaskResponse: stderr=self._func_stderr.getvalue(), ) else: - result: FunctionCallResult = self._function_wrapper.invoke_fn_ser( + result: FunctionCallResult = func_wrapper.invoke_fn_ser( self._function_name, inputs.input, inputs.init_value ) return self._response_helper.function_response( result=result, - is_reducer=_func_is_reducer(self._function_wrapper), + is_reducer=_func_is_reducer(func_wrapper), stdout=self._func_stdout.getvalue(), stderr=self._func_stderr.getvalue(), )