From 4dd725041cdb6de336a4abe699e05d86b16c9756 Mon Sep 17 00:00:00 2001
From: Alex Thewsey <thewsey@amazon.com>
Date: Wed, 17 Aug 2022 22:15:29 +0800
Subject: [PATCH] feat: CustomAttributes/context in handler fns

Initial draft of a non-breaking option to support receiving an extra
request context parameter in handler override functions (e.g.
transform_fn, input_fn, predict_fn, output_fn) when the signature of the
provided function includes the extra parameter. This should allow script
mode users to access additional request context (such as the SageMaker
CustomAttributes header) without breaking implementations using the
existing function APIs.
---
 setup.py                               |  3 ++
 src/sagemaker_inference/transformer.py | 35 +++++++++++--
 test/unit/test_transfomer.py           | 68 +++++++++++++++++++++++++-
 3 files changed, 100 insertions(+), 6 deletions(-)

diff --git a/setup.py b/setup.py
index 369e568..3b99339 100644
--- a/setup.py
+++ b/setup.py
@@ -34,6 +34,9 @@ def read_version():
 # enum is introduced in Python 3.4. Installing enum back port
 if sys.version_info < (3, 4):
     required_packages.append("enum34 >= 1.1.6")
+# inspect.signature is introduced in Python 3.3. Installing funcsigs back port
+if sys.version_info < (3, 3):
+    required_packages.append("funcsigs >= 1.0.2")
 
 PKG_NAME = "sagemaker_inference"
 
diff --git a/src/sagemaker_inference/transformer.py b/src/sagemaker_inference/transformer.py
index 70fde95..b3b6690 100644
--- a/src/sagemaker_inference/transformer.py
+++ b/src/sagemaker_inference/transformer.py
@@ -16,6 +16,12 @@
 """
 from __future__ import absolute_import
 
+try:
+    from inspect import signature
+except ImportError:
+    # Python<3.3 backport:
+    from funcsigs import signature
+
 import importlib
 import traceback
 
@@ -125,7 +131,17 @@ def transform(self, data, context):
             if content_type in content_types.UTF8_TYPES:
                 input_data = input_data.decode("utf-8")
 
-            result = self._transform_fn(self._model, input_data, content_type, accept)
+            # If the configured transform_fn supports the optional request context argument,
+            # pass it through:
+            transform_args = [self._model, input_data, content_type, accept]
+            n_fn_args = len(signature(self._transform_fn).parameters)
+            if n_fn_args > 4:
+                # TODO: The raw model-server 'context' object may not be the right thing to
+                # expose here; a narrower, curated object could be preferable. The immediate
+                # goal is to let users' override functions read the CustomAttributes header
+                # (and perhaps modify it on the response).
+                transform_args.append(context)
+            result = self._transform_fn(*transform_args)
 
             response = result
             response_content_type = accept
@@ -214,7 +230,7 @@ def _validate_user_module_and_set_functions(self):
 
         self._transform_fn = self._default_transform_fn
 
-    def _default_transform_fn(self, model, input_data, content_type, accept):
+    def _default_transform_fn(self, model, input_data, content_type, accept, context):
         """Make predictions against the model and return a serialized response.
 
         This serves as the default implementation of transform_fn, used when the user has not
         provided an implementation.
@@ -224,13 +240,22 @@ def _default_transform_fn(self, model, input_data, content_type, accept):
 
         Args:
             model (obj): model loaded by model_fn.
            input_data (obj): the request data.
             content_type (str): the request content type.
             accept (str): accept header expected by the client.
+            context (obj): the request context forwarded from the model server.
 
         Returns:
             obj: the serialized prediction result or a tuple of the form
                 (response_data, content_type)
 
         """
-        data = self._input_fn(input_data, content_type)
-        prediction = self._predict_fn(data, model)
-        result = self._output_fn(prediction, accept)
+        # If configured handler functions support the extra context arg, pass it through:
+        n_input_args = len(signature(self._input_fn).parameters)
+        n_predict_args = len(signature(self._predict_fn).parameters)
+        n_output_args = len(signature(self._output_fn).parameters)
+
+        input_args = [input_data, content_type] + ([context] if n_input_args > 2 else [])
+        data = self._input_fn(*input_args)
+        predict_args = [data, model] + ([context] if n_predict_args > 2 else [])
+        prediction = self._predict_fn(*predict_args)
+        output_args = [prediction, accept] + ([context] if n_output_args > 2 else [])
+        result = self._output_fn(*output_args)
 
         return result
 
diff --git a/test/unit/test_transfomer.py b/test/unit/test_transfomer.py
index be0cfa1..ab87b7c 100644
--- a/test/unit/test_transfomer.py
+++ b/test/unit/test_transfomer.py
@@ -18,6 +18,12 @@
 except ImportError:
     import httplib as http_client
 
+try:
+    from inspect import signature
+except ImportError:
+    # Python<3.3 backport:
+    from funcsigs import signature
+
 from sagemaker_inference import content_types, environment
 from sagemaker_inference.default_inference_handler import DefaultInferenceHandler
 from sagemaker_inference.errors import BaseInferenceToolkitError
@@ -29,6 +35,7 @@
 DEFAULT_ACCEPT = "default_accept"
 RESULT = "result"
 MODEL = "foo"
+REQ_CONTEXT = object()
 
 PREPROCESSED_DATA = "preprocessed_data"
 PREDICT_RESULT = "prediction_result"
@@ -90,12 +97,46 @@ def test_transform(validate, retrieve_content_type_header, accept_key):
 
     validate.assert_called_once()
     retrieve_content_type_header.assert_called_once_with(request_property)
+    # Since Mock()'s callable signature has only 2 args ('args', 'kwargs'), the extra 'context'
+    # arg will not be added in this case (see test_transform_with_context below):
     transform_fn.assert_called_once_with(MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT)
     context.set_response_content_type.assert_called_once_with(0, ACCEPT)
     assert isinstance(result, list)
     assert result[0] == RESULT
 
 
+@pytest.mark.parametrize("accept_key", ["Accept", "accept"])
+@patch("sagemaker_inference.utils.retrieve_content_type_header", return_value=CONTENT_TYPE)
+@patch("sagemaker_inference.transformer.Transformer.validate_and_initialize")
+def test_transform_with_context(validate, retrieve_content_type_header, accept_key):
+    data = [{"body": INPUT_DATA}]
+    context = Mock()
+    request_processor = Mock()
+
+    # Simulate setting a transform_fn that supports the additional 'context' argument:
+    transform_fn = Mock(return_value=RESULT)
+    transform_fn.__signature__ = signature(
+        lambda model, input_data, content_type, accept, context: None
+    )
+
+    context.request_processor = [request_processor]
+    request_property = {accept_key: ACCEPT}
+    request_processor.get_request_properties.return_value = request_property
+
+    transformer = Transformer()
+    transformer._model = MODEL
+    transformer._transform_fn = transform_fn
+
+    result = transformer.transform(data, context)
+
+    validate.assert_called_once()
+    retrieve_content_type_header.assert_called_once_with(request_property)
+    transform_fn.assert_called_once_with(MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT, context)
+    context.set_response_content_type.assert_called_once_with(0, ACCEPT)
+    assert isinstance(result, list)
+    assert result[0] == RESULT
+
+
@patch("sagemaker_inference.utils.retrieve_content_type_header", return_value=CONTENT_TYPE) @patch("sagemaker_inference.transformer.Transformer.validate_and_initialize") def test_transform_no_accept(validate, retrieve_content_type_header): @@ -417,6 +458,8 @@ def test_validate_user_module_error(find_spec, import_module, user_module): def test_default_transform_fn(): transformer = Transformer() + # Default Mock.__call__ signature has 2 args anyway (args, kwargs) so no signature hacking + # necessary for the context-free case: input_fn = Mock(return_value=PREPROCESSED_DATA) predict_fn = Mock(return_value=PREDICT_RESULT) output_fn = Mock(return_value=PROCESSED_RESULT) @@ -425,9 +468,32 @@ def test_default_transform_fn(): transformer._predict_fn = predict_fn transformer._output_fn = output_fn - result = transformer._default_transform_fn(MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT) + result = transformer._default_transform_fn(MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT, REQ_CONTEXT) input_fn.assert_called_once_with(INPUT_DATA, CONTENT_TYPE) predict_fn.assert_called_once_with(PREPROCESSED_DATA, MODEL) output_fn.assert_called_once_with(PREDICT_RESULT, ACCEPT) assert result == PROCESSED_RESULT + + +def test_default_transform_with_contextual_fns(): + transformer = Transformer() + + # Set the __signature__ on the mock functions to indicate the user overrides want req context: + input_fn = Mock(return_value=PREPROCESSED_DATA) + input_fn.__signature__ = signature(lambda input_data, content_type, context: None) + predict_fn = Mock(return_value=PREDICT_RESULT) + predict_fn.__signature__ = signature(lambda data, model, context: None) + output_fn = Mock(return_value=PROCESSED_RESULT) + output_fn.__signature__ = signature(lambda prediction, accept, context: None) + + transformer._input_fn = input_fn + transformer._predict_fn = predict_fn + transformer._output_fn = output_fn + + result = transformer._default_transform_fn(MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT, REQ_CONTEXT) + + input_fn.assert_called_once_with(INPUT_DATA, CONTENT_TYPE, REQ_CONTEXT) + predict_fn.assert_called_once_with(PREPROCESSED_DATA, MODEL, REQ_CONTEXT) + output_fn.assert_called_once_with(PREDICT_RESULT, ACCEPT, REQ_CONTEXT) + assert result == PROCESSED_RESULT