From 4dd725041cdb6de336a4abe699e05d86b16c9756 Mon Sep 17 00:00:00 2001
From: Alex Thewsey <thewsey@amazon.com>
Date: Wed, 17 Aug 2022 22:15:29 +0800
Subject: [PATCH] feat: CustomAttributes/context in handler fns

Initial draft of a non-breaking option for handler override functions
(e.g. transform_fn, input_fn, predict_fn, output_fn) to receive an
extra request context parameter when the provided function's signature
declares it. This should let script mode users access additional
request context (such as the SageMaker CustomAttributes header)
without breaking implementations that use the existing function APIs.
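
For example, a script mode user could then opt in simply by declaring
the extra parameter. The sketch below is illustrative only: it assumes
the raw model server context is what ultimately gets passed through
(see the TODO in the diff), and that the CustomAttributes HTTP header
shows up in the request properties exposed by the request processor,
as in the existing transform() code:

    def input_fn(input_data, content_type, context):
        # 'context' is passed only because this signature declares a
        # third parameter; existing 2-argument input_fn implementations
        # are called exactly as before.
        headers = context.request_processor[0].get_request_properties()
        custom_attrs = headers.get("X-Amzn-SageMaker-Custom-Attributes")
        ...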
---
 setup.py                               |  3 ++
 src/sagemaker_inference/transformer.py | 35 +++++++++++--
 test/unit/test_transformer.py          | 68 +++++++++++++++++++++++++-
 3 files changed, 100 insertions(+), 6 deletions(-)

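Reviewer note (kept below the diffstat so it stays out of the commit
message): the dispatch is purely arity-based via inspect.signature. It
is roughly equivalent to the hypothetical helper below, which could
also be a way to factor out the repeated signature checks if this
approach lands:

    from inspect import signature

    def _call_with_optional_context(fn, args, context, base_arity):
        # Pass the request context only when fn declares more than the
        # legacy number of parameters; legacy handlers (and plain
        # Mocks) report at most base_arity parameters.
        if len(signature(fn).parameters) > base_arity:
            args = list(args) + [context]
        return fn(*args)

e.g. _call_with_optional_context(self._input_fn,
[input_data, content_type], context, base_arity=2).
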
diff --git a/setup.py b/setup.py
index 369e568..3b99339 100644
--- a/setup.py
+++ b/setup.py
@@ -34,6 +34,9 @@ def read_version():
 # enum is introduced in Python 3.4. Installing enum back port
 if sys.version_info < (3, 4):
     required_packages.append("enum34 >= 1.1.6")
+# inspect.signature is introduced in Python 3.3. Installing funcsigs back port
+if sys.version_info < (3, 3):
+    required_packages.append("funcsigs >= 1.0.2")
 
 PKG_NAME = "sagemaker_inference"
 
diff --git a/src/sagemaker_inference/transformer.py b/src/sagemaker_inference/transformer.py
index 70fde95..b3b6690 100644
--- a/src/sagemaker_inference/transformer.py
+++ b/src/sagemaker_inference/transformer.py
@@ -16,6 +16,12 @@
 """
 from __future__ import absolute_import
 
+try:
+    from inspect import signature
+except ImportError:
+    # Python<3.3 backport:
+    from funcsigs import signature
+
 import importlib
 import traceback
 
@@ -125,7 +131,17 @@ def transform(self, data, context):
             if content_type in content_types.UTF8_TYPES:
                 input_data = input_data.decode("utf-8")
 
-            result = self._transform_fn(self._model, input_data, content_type, accept)
+            # If the configured transform_fn supports the optional request context argument, pass it
+            # through:
+            transform_args = [self._model, input_data, content_type, accept]
+            n_fn_args = len(signature(self._transform_fn).parameters)
+            if n_fn_args > 4:
+                # TODO: Decide whether to pass the raw model server 'context' object here, or
+                # some more curated request-context object. The immediate goal is just to let
+                # users' override functions read the CustomAttributes header (and maybe even
+                # modify it on the response).
+                transform_args.append(context)
+            result = self._transform_fn(*transform_args)
 
             response = result
             response_content_type = accept
@@ -214,7 +230,7 @@ def _validate_user_module_and_set_functions(self):
 
             self._transform_fn = self._default_transform_fn
 
-    def _default_transform_fn(self, model, input_data, content_type, accept):
+    def _default_transform_fn(self, model, input_data, content_type, accept, context):
         """Make predictions against the model and return a serialized response.
         This serves as the default implementation of transform_fn, used when the
         user has not provided an implementation.
@@ -224,13 +240,22 @@ def _default_transform_fn(self, model, input_data, content_type, accept):
             input_data (obj): the request data.
             content_type (str): the request content type.
             accept (str): accept header expected by the client.
+            context (obj): the model server request context, exposing request metadata such as the SageMaker CustomAttributes header.
 
         Returns:
             obj: the serialized prediction result or a tuple of the form
                 (response_data, content_type)
 
         """
-        data = self._input_fn(input_data, content_type)
-        prediction = self._predict_fn(data, model)
-        result = self._output_fn(prediction, accept)
+        # If configured handler functions support the extra context arg, pass it through:
+        n_input_args = len(signature(self._input_fn).parameters)
+        n_predict_args = len(signature(self._predict_fn).parameters)
+        n_output_args = len(signature(self._output_fn).parameters)
+
+        input_args = [input_data, content_type] + ([context] if n_input_args > 2 else [])
+        data = self._input_fn(*input_args)
+        predict_args = [data, model] + ([context] if n_predict_args > 2 else [])
+        prediction = self._predict_fn(*predict_args)
+        output_args = [prediction, accept] + ([context] if n_output_args > 2 else [])
+        result = self._output_fn(*output_args)
         return result
diff --git a/test/unit/test_transformer.py b/test/unit/test_transformer.py
index be0cfa1..ab87b7c 100644
--- a/test/unit/test_transformer.py
+++ b/test/unit/test_transformer.py
@@ -18,6 +18,12 @@
 except ImportError:
     import httplib as http_client
 
+try:
+    from inspect import signature
+except ImportError:
+    # Python<3.3 backport:
+    from funcsigs import signature
+
 from sagemaker_inference import content_types, environment
 from sagemaker_inference.default_inference_handler import DefaultInferenceHandler
 from sagemaker_inference.errors import BaseInferenceToolkitError
@@ -29,6 +35,7 @@
 DEFAULT_ACCEPT = "default_accept"
 RESULT = "result"
 MODEL = "foo"
+REQ_CONTEXT = object()
 
 PREPROCESSED_DATA = "preprocessed_data"
 PREDICT_RESULT = "prediction_result"
@@ -90,12 +97,46 @@ def test_transform(validate, retrieve_content_type_header, accept_key):
 
     validate.assert_called_once()
     retrieve_content_type_header.assert_called_once_with(request_property)
+    # A plain Mock()'s callable signature exposes only 2 parameters (*args, **kwargs), so the
+    # extra 'context' arg is not passed in this case (see test_transform_with_context below):
     transform_fn.assert_called_once_with(MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT)
     context.set_response_content_type.assert_called_once_with(0, ACCEPT)
     assert isinstance(result, list)
     assert result[0] == RESULT
 
 
+@pytest.mark.parametrize("accept_key", ["Accept", "accept"])
+@patch("sagemaker_inference.utils.retrieve_content_type_header", return_value=CONTENT_TYPE)
+@patch("sagemaker_inference.transformer.Transformer.validate_and_initialize")
+def test_transform_with_context(validate, retrieve_content_type_header, accept_key):
+    data = [{"body": INPUT_DATA}]
+    context = Mock()
+    request_processor = Mock()
+
+    # Simulate setting a transform_fn that supports additional 'context' argument:
+    transform_fn = Mock(return_value=RESULT)
+    transform_fn.__signature__ = signature(
+        lambda model, input_data, content_type, accept, context: None
+    )
+
+    context.request_processor = [request_processor]
+    request_property = {accept_key: ACCEPT}
+    request_processor.get_request_properties.return_value = request_property
+
+    transformer = Transformer()
+    transformer._model = MODEL
+    transformer._transform_fn = transform_fn
+
+    result = transformer.transform(data, context)
+
+    validate.assert_called_once()
+    retrieve_content_type_header.assert_called_once_with(request_property)
+    transform_fn.assert_called_once_with(MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT, context)
+    context.set_response_content_type.assert_called_once_with(0, ACCEPT)
+    assert isinstance(result, list)
+    assert result[0] == RESULT
+
+
 @patch("sagemaker_inference.utils.retrieve_content_type_header", return_value=CONTENT_TYPE)
 @patch("sagemaker_inference.transformer.Transformer.validate_and_initialize")
 def test_transform_no_accept(validate, retrieve_content_type_header):
@@ -417,6 +458,8 @@ def test_validate_user_module_error(find_spec, import_module, user_module):
 def test_default_transform_fn():
     transformer = Transformer()
 
+    # A default Mock's callable signature exposes only 2 parameters (*args, **kwargs), so no
+    # __signature__ override is necessary for the context-free case:
     input_fn = Mock(return_value=PREPROCESSED_DATA)
     predict_fn = Mock(return_value=PREDICT_RESULT)
     output_fn = Mock(return_value=PROCESSED_RESULT)
@@ -425,9 +468,32 @@ def test_default_transform_fn():
     transformer._predict_fn = predict_fn
     transformer._output_fn = output_fn
 
-    result = transformer._default_transform_fn(MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT)
+    result = transformer._default_transform_fn(MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT, REQ_CONTEXT)
 
     input_fn.assert_called_once_with(INPUT_DATA, CONTENT_TYPE)
     predict_fn.assert_called_once_with(PREPROCESSED_DATA, MODEL)
     output_fn.assert_called_once_with(PREDICT_RESULT, ACCEPT)
     assert result == PROCESSED_RESULT
+
+
+def test_default_transform_with_contextual_fns():
+    transformer = Transformer()
+
+    # Set the __signature__ on the mock functions to indicate the user overrides want req context:
+    input_fn = Mock(return_value=PREPROCESSED_DATA)
+    input_fn.__signature__ = signature(lambda input_data, content_type, context: None)
+    predict_fn = Mock(return_value=PREDICT_RESULT)
+    predict_fn.__signature__ = signature(lambda data, model, context: None)
+    output_fn = Mock(return_value=PROCESSED_RESULT)
+    output_fn.__signature__ = signature(lambda prediction, accept, context: None)
+
+    transformer._input_fn = input_fn
+    transformer._predict_fn = predict_fn
+    transformer._output_fn = output_fn
+
+    result = transformer._default_transform_fn(MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT, REQ_CONTEXT)
+
+    input_fn.assert_called_once_with(INPUT_DATA, CONTENT_TYPE, REQ_CONTEXT)
+    predict_fn.assert_called_once_with(PREPROCESSED_DATA, MODEL, REQ_CONTEXT)
+    output_fn.assert_called_once_with(PREDICT_RESULT, ACCEPT, REQ_CONTEXT)
+    assert result == PROCESSED_RESULT