diff --git a/setup.py b/setup.py
index 369e568..3b99339 100644
--- a/setup.py
+++ b/setup.py
@@ -34,6 +34,9 @@ def read_version():
 # enum is introduced in Python 3.4. Installing enum back port
 if sys.version_info < (3, 4):
     required_packages.append("enum34 >= 1.1.6")
+# inspect.signature is introduced in Python 3.3. Installing funcsigs back port
+if sys.version_info < (3, 3):
+    required_packages.append("funcsigs >= 1.0.2")
 
 PKG_NAME = "sagemaker_inference"
 
diff --git a/src/sagemaker_inference/transformer.py b/src/sagemaker_inference/transformer.py
index 70fde95..b3b6690 100644
--- a/src/sagemaker_inference/transformer.py
+++ b/src/sagemaker_inference/transformer.py
@@ -16,6 +16,12 @@
 """
 from __future__ import absolute_import
 
+try:
+    from inspect import signature
+except ImportError:
+    # Python<3.3 backport:
+    from funcsigs import signature
+
 import importlib
 import traceback
 
@@ -125,7 +131,17 @@ def transform(self, data, context):
             if content_type in content_types.UTF8_TYPES:
                 input_data = input_data.decode("utf-8")
 
-            result = self._transform_fn(self._model, input_data, content_type, accept)
+            # If the configured transform_fn supports the optional request context argument, pass it
+            # through:
+            transform_args = [self._model, input_data, content_type, accept]
+            n_fn_args = len(signature(self._transform_fn).parameters)
+            if n_fn_args > 4:
+                # TODO: Probably wouldn't use the actual 'context' obj here?
+                # Some more curated, useful object could be preferred maybe... Just something so
+                # that users' override functions can access CustomAttributes header! (Maybe even
+                # modify it on the response?)
+                transform_args.append(context)
+            result = self._transform_fn(*transform_args)
 
             response = result
             response_content_type = accept
@@ -214,7 +230,7 @@ def _validate_user_module_and_set_functions(self):
 
             self._transform_fn = self._default_transform_fn
 
-    def _default_transform_fn(self, model, input_data, content_type, accept):
+    def _default_transform_fn(self, model, input_data, content_type, accept, context=None):
         """Make predictions against the model and return a serialized response.
         This serves as the default implementation of transform_fn, used when the
         user has not provided an implementation.
@@ -224,13 +240,22 @@ def _default_transform_fn(self, model, input_data, content_type, accept):
             input_data (obj): the request data.
             content_type (str): the request content type.
             accept (str): accept header expected by the client.
+            context (obj): optional request context, forwarded to handlers that accept it.
 
         Returns:
             obj: the serialized prediction result or a tuple of the form
                 (response_data, content_type)
         """
-        data = self._input_fn(input_data, content_type)
-        prediction = self._predict_fn(data, model)
-        result = self._output_fn(prediction, accept)
+        # If configured handler functions support the extra context arg, pass it through:
+        n_input_args = len(signature(self._input_fn).parameters)
+        n_predict_args = len(signature(self._predict_fn).parameters)
+        n_output_args = len(signature(self._output_fn).parameters)
+
+        input_args = [input_data, content_type] + ([context] if n_input_args > 2 else [])
+        data = self._input_fn(*input_args)
+        predict_args = [data, model] + ([context] if n_predict_args > 2 else [])
+        prediction = self._predict_fn(*predict_args)
+        output_args = [prediction, accept] + ([context] if n_output_args > 2 else [])
+        result = self._output_fn(*output_args)
 
         return result
diff --git a/test/unit/test_transfomer.py b/test/unit/test_transfomer.py
index be0cfa1..ab87b7c 100644
--- a/test/unit/test_transfomer.py
+++ b/test/unit/test_transfomer.py
@@ -18,6 +18,12 @@
 except ImportError:
     import httplib as http_client
 
+try:
+    from inspect import signature
+except ImportError:
+    # Python<3.3 backport:
+    from funcsigs import signature
+
 from sagemaker_inference import content_types, environment
 from sagemaker_inference.default_inference_handler import DefaultInferenceHandler
 from sagemaker_inference.errors import BaseInferenceToolkitError
@@ -29,6 +35,7 @@
 DEFAULT_ACCEPT = "default_accept"
 RESULT = "result"
 MODEL = "foo"
+REQ_CONTEXT = object()
 
 PREPROCESSED_DATA = "preprocessed_data"
 PREDICT_RESULT = "prediction_result"
@@ -90,12 +97,46 @@ def test_transform(validate, retrieve_content_type_header, accept_key):
 
     validate.assert_called_once()
     retrieve_content_type_header.assert_called_once_with(request_property)
+    # Since Mock()'s callable signature has only 2 args ('args', 'kwargs'), the extra 'context' arg
+    # will not be added in this case (see test_transform_with_context below):
     transform_fn.assert_called_once_with(MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT)
     context.set_response_content_type.assert_called_once_with(0, ACCEPT)
     assert isinstance(result, list)
     assert result[0] == RESULT
 
 
+@pytest.mark.parametrize("accept_key", ["Accept", "accept"])
+@patch("sagemaker_inference.utils.retrieve_content_type_header", return_value=CONTENT_TYPE)
+@patch("sagemaker_inference.transformer.Transformer.validate_and_initialize")
+def test_transform_with_context(validate, retrieve_content_type_header, accept_key):
+    data = [{"body": INPUT_DATA}]
+    context = Mock()
+    request_processor = Mock()
+
+    # Simulate setting a transform_fn that supports additional 'context' argument:
+    transform_fn = Mock(return_value=RESULT)
+    transform_fn.__signature__ = signature(
+        lambda model, input_data, content_type, accept, context: None
+    )
+
+    context.request_processor = [request_processor]
+    request_property = {accept_key: ACCEPT}
+    request_processor.get_request_properties.return_value = request_property
+
+    transformer = Transformer()
+    transformer._model = MODEL
+    transformer._transform_fn = transform_fn
+
+    result = transformer.transform(data, context)
+
+    validate.assert_called_once()
+    retrieve_content_type_header.assert_called_once_with(request_property)
+    transform_fn.assert_called_once_with(MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT, context)
+    context.set_response_content_type.assert_called_once_with(0, ACCEPT)
+    assert isinstance(result, list)
+    assert result[0] == RESULT
+
+
 @patch("sagemaker_inference.utils.retrieve_content_type_header", return_value=CONTENT_TYPE)
 @patch("sagemaker_inference.transformer.Transformer.validate_and_initialize")
 def test_transform_no_accept(validate, retrieve_content_type_header):
@@ -417,6 +458,8 @@ def test_validate_user_module_error(find_spec, import_module, user_module):
 def test_default_transform_fn():
     transformer = Transformer()
 
+    # Default Mock.__call__ signature has 2 args anyway (args, kwargs) so no signature hacking
+    # necessary for the context-free case:
     input_fn = Mock(return_value=PREPROCESSED_DATA)
     predict_fn = Mock(return_value=PREDICT_RESULT)
     output_fn = Mock(return_value=PROCESSED_RESULT)
@@ -425,9 +468,32 @@ def test_default_transform_fn():
     transformer._predict_fn = predict_fn
     transformer._output_fn = output_fn
 
-    result = transformer._default_transform_fn(MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT)
+    result = transformer._default_transform_fn(MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT, REQ_CONTEXT)
 
     input_fn.assert_called_once_with(INPUT_DATA, CONTENT_TYPE)
     predict_fn.assert_called_once_with(PREPROCESSED_DATA, MODEL)
     output_fn.assert_called_once_with(PREDICT_RESULT, ACCEPT)
     assert result == PROCESSED_RESULT
+
+
+def test_default_transform_with_contextual_fns():
+    transformer = Transformer()
+
+    # Set the __signature__ on the mock functions to indicate the user overrides want req context:
+    input_fn = Mock(return_value=PREPROCESSED_DATA)
+    input_fn.__signature__ = signature(lambda input_data, content_type, context: None)
+    predict_fn = Mock(return_value=PREDICT_RESULT)
+    predict_fn.__signature__ = signature(lambda data, model, context: None)
+    output_fn = Mock(return_value=PROCESSED_RESULT)
+    output_fn.__signature__ = signature(lambda prediction, accept, context: None)
+
+    transformer._input_fn = input_fn
+    transformer._predict_fn = predict_fn
+    transformer._output_fn = output_fn
+
+    result = transformer._default_transform_fn(MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT, REQ_CONTEXT)
+
+    input_fn.assert_called_once_with(INPUT_DATA, CONTENT_TYPE, REQ_CONTEXT)
+    predict_fn.assert_called_once_with(PREPROCESSED_DATA, MODEL, REQ_CONTEXT)
+    output_fn.assert_called_once_with(PREDICT_RESULT, ACCEPT, REQ_CONTEXT)
+    assert result == PROCESSED_RESULT