[DRAFT] Support CustomAttributes / request context in handler function overrides #111

Draft: wants to merge 1 commit into master
3 changes: 3 additions & 0 deletions setup.py
@@ -34,6 +34,9 @@ def read_version():
 # enum is introduced in Python 3.4. Installing enum back port
 if sys.version_info < (3, 4):
     required_packages.append("enum34 >= 1.1.6")
+# inspect.signature is introduced in Python 3.3. Installing funcsigs back port
+if sys.version_info < (3, 3):
+    required_packages.append("funcsigs >= 1.0.2")

 PKG_NAME = "sagemaker_inference"

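Note: the arity checks added to transformer.py below rely on signature(fn).parameters, an ordered mapping of parameter names, which the funcsigs backport exposes with the same API on older interpreters. A minimal standalone sketch of the pattern (not part of the diff):

try:
    from inspect import signature  # available on Python >= 3.3
except ImportError:
    from funcsigs import signature  # PyPI backport with the same API

def handler(input_data, content_type, context=None):
    return input_data

# signature(...).parameters maps name -> Parameter, so len() gives the declared arity
print(len(signature(handler).parameters))  # 3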
35 changes: 30 additions & 5 deletions src/sagemaker_inference/transformer.py
@@ -16,6 +16,12 @@
 """
 from __future__ import absolute_import

+try:
+    from inspect import signature
+except ImportError:
+    # Python<3.3 backport:
+    from funcsigs import signature
+
 import importlib
 import traceback

@@ -125,7 +131,17 @@ def transform(self, data, context):
             if content_type in content_types.UTF8_TYPES:
                 input_data = input_data.decode("utf-8")

-            result = self._transform_fn(self._model, input_data, content_type, accept)
+            # If the configured transform_fn supports the optional request context argument, pass it
+            # through:
+            transform_args = [self._model, input_data, content_type, accept]
+            n_fn_args = len(signature(self._transform_fn).parameters)
+            if n_fn_args > 4:
+                # TODO: Probably wouldn't use the actual 'context' obj here?
+                # Some more curated, useful object could be preferred maybe... Just something so
+                # that users' override functions can access CustomAttributes header! (Maybe even
+                # modify it on the response?)
+                transform_args.append(context)
+            result = self._transform_fn(*transform_args)

             response = result
             response_content_type = accept
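The dispatch above boils down to: pass the extra argument only when the callable declares room for it. A standalone sketch of that pattern (call_with_optional_context is a hypothetical helper name, not part of this PR):

from inspect import signature  # funcsigs.signature on Python 2

def call_with_optional_context(fn, args, context):
    # Append the context only if fn declares more parameters than the base call supplies.
    if len(signature(fn).parameters) > len(args):
        return fn(*(list(args) + [context]))
    return fn(*args)

def legacy_fn(model, data, content_type, accept):
    return data

def context_fn(model, data, content_type, accept, context):
    return data, context

# Both call shapes work through the same dispatch:
assert call_with_optional_context(legacy_fn, ("m", "d", "c/t", "a"), "ctx") == "d"
assert call_with_optional_context(context_fn, ("m", "d", "c/t", "a"), "ctx") == ("d", "ctx")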
@@ -214,7 +230,7 @@ def _validate_user_module_and_set_functions(self):

         self._transform_fn = self._default_transform_fn

-    def _default_transform_fn(self, model, input_data, content_type, accept):
+    def _default_transform_fn(self, model, input_data, content_type, accept, context):
         """Make predictions against the model and return a serialized response.
         This serves as the default implementation of transform_fn, used when the
         user has not provided an implementation.
@@ -224,13 +240,22 @@ def _default_transform_fn(self, model, input_data, content_type, accept):
             input_data (obj): the request data.
             content_type (str): the request content type.
             accept (str): accept header expected by the client.
+            context (?): TODO

         Returns:
             obj: the serialized prediction result or a tuple of the form
             (response_data, content_type)

         """
-        data = self._input_fn(input_data, content_type)
-        prediction = self._predict_fn(data, model)
-        result = self._output_fn(prediction, accept)
+        # If configured handler functions support the extra context arg, pass it through:
+        n_input_args = len(signature(self._input_fn).parameters)
+        n_predict_args = len(signature(self._predict_fn).parameters)
+        n_output_args = len(signature(self._output_fn).parameters)
+
+        input_args = [input_data, content_type] + ([context] if n_input_args > 2 else [])
+        data = self._input_fn(*input_args)
+        predict_args = [data, model] + ([context] if n_predict_args > 2 else [])
+        prediction = self._predict_fn(*predict_args)
+        output_args = [prediction, accept] + ([context] if n_output_args > 2 else [])
+        result = self._output_fn(*output_args)
         return result
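On the user side, a handler opts in simply by declaring the extra parameter. A hypothetical inference script (a sketch: the request_processor access pattern mirrors the unit tests below, and the CustomAttributes header name is an assumption about the serving stack, not something this PR defines):

def input_fn(input_data, content_type, context):
    # Read the CustomAttributes header from the request properties, if present
    # (assumed header name; verify against your serving stack).
    headers = context.request_processor[0].get_request_properties()
    custom_attributes = headers.get("X-Amzn-SageMaker-Custom-Attributes")
    return {"data": input_data, "attrs": custom_attributes}

def predict_fn(data, model):
    # A handler that does not declare the extra parameter keeps working unchanged.
    return data

def output_fn(prediction, accept, context):
    return prediction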
68 changes: 67 additions & 1 deletion test/unit/test_transfomer.py
@@ -18,6 +18,12 @@
 except ImportError:
     import httplib as http_client

+try:
+    from inspect import signature
+except ImportError:
+    # Python<3.3 backport:
+    from funcsigs import signature
+
 from sagemaker_inference import content_types, environment
 from sagemaker_inference.default_inference_handler import DefaultInferenceHandler
 from sagemaker_inference.errors import BaseInferenceToolkitError
@@ -29,6 +35,7 @@
 DEFAULT_ACCEPT = "default_accept"
 RESULT = "result"
 MODEL = "foo"
+REQ_CONTEXT = object()

 PREPROCESSED_DATA = "preprocessed_data"
 PREDICT_RESULT = "prediction_result"
@@ -90,12 +97,46 @@ def test_transform(validate, retrieve_content_type_header, accept_key):

     validate.assert_called_once()
     retrieve_content_type_header.assert_called_once_with(request_property)
+    # Since Mock()'s callable signature has only 2 args ('args', 'kwargs'), the extra 'context' arg
[Inline review comment from a Contributor on the line above] Nit: this comment can be removed, the test is self-explanatory.
+    # will not be added in this case (see test_transform_with_context below):
     transform_fn.assert_called_once_with(MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT)
     context.set_response_content_type.assert_called_once_with(0, ACCEPT)
     assert isinstance(result, list)
     assert result[0] == RESULT


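Why the plain-Mock test above stays at four arguments: signature() of a bare Mock reports only the generic (*args, **kwargs) pair, i.e. two parameters, so transform()'s arity check never appends the context. A quick standalone check (Python 3 shown):

from inspect import signature
from unittest.mock import Mock

print(len(signature(Mock()).parameters))  # 2: (*args, **kwargs)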
+@pytest.mark.parametrize("accept_key", ["Accept", "accept"])
+@patch("sagemaker_inference.utils.retrieve_content_type_header", return_value=CONTENT_TYPE)
+@patch("sagemaker_inference.transformer.Transformer.validate_and_initialize")
+def test_transform_with_context(validate, retrieve_content_type_header, accept_key):
+    data = [{"body": INPUT_DATA}]
+    context = Mock()
+    request_processor = Mock()
+
+    # Simulate setting a transform_fn that supports the additional 'context' argument:
+    transform_fn = Mock(return_value=RESULT)
+    transform_fn.__signature__ = signature(
+        lambda model, input_data, content_type, accept, context: None
+    )
+
+    context.request_processor = [request_processor]
+    request_property = {accept_key: ACCEPT}
+    request_processor.get_request_properties.return_value = request_property
+
+    transformer = Transformer()
+    transformer._model = MODEL
+    transformer._transform_fn = transform_fn
+
+    result = transformer.transform(data, context)
+
+    validate.assert_called_once()
+    retrieve_content_type_header.assert_called_once_with(request_property)
+    transform_fn.assert_called_once_with(MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT, context)
+    context.set_response_content_type.assert_called_once_with(0, ACCEPT)
+    assert isinstance(result, list)
+    assert result[0] == RESULT
+
+
 @patch("sagemaker_inference.utils.retrieve_content_type_header", return_value=CONTENT_TYPE)
 @patch("sagemaker_inference.transformer.Transformer.validate_and_initialize")
 def test_transform_no_accept(validate, retrieve_content_type_header):
@@ -417,6 +458,8 @@ def test_validate_user_module_error(find_spec, import_module, user_module):
 def test_default_transform_fn():
     transformer = Transformer()

+    # Default Mock.__call__ signature has 2 args anyway (args, kwargs), so no signature hacking
+    # is necessary for the context-free case:
     input_fn = Mock(return_value=PREPROCESSED_DATA)
     predict_fn = Mock(return_value=PREDICT_RESULT)
     output_fn = Mock(return_value=PROCESSED_RESULT)
@@ -425,9 +468,32 @@ def test_default_transform_fn():
     transformer._input_fn = input_fn
     transformer._predict_fn = predict_fn
     transformer._output_fn = output_fn

-    result = transformer._default_transform_fn(MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT)
+    result = transformer._default_transform_fn(MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT, REQ_CONTEXT)

     input_fn.assert_called_once_with(INPUT_DATA, CONTENT_TYPE)
     predict_fn.assert_called_once_with(PREPROCESSED_DATA, MODEL)
     output_fn.assert_called_once_with(PREDICT_RESULT, ACCEPT)
     assert result == PROCESSED_RESULT
+
+
+def test_default_transform_with_contextual_fns():
+    transformer = Transformer()
+
+    # Set the __signature__ on the mock functions to indicate the user overrides want req context:
+    input_fn = Mock(return_value=PREPROCESSED_DATA)
+    input_fn.__signature__ = signature(lambda input_data, content_type, context: None)
+    predict_fn = Mock(return_value=PREDICT_RESULT)
+    predict_fn.__signature__ = signature(lambda data, model, context: None)
+    output_fn = Mock(return_value=PROCESSED_RESULT)
+    output_fn.__signature__ = signature(lambda prediction, accept, context: None)
+
+    transformer._input_fn = input_fn
+    transformer._predict_fn = predict_fn
+    transformer._output_fn = output_fn
+
+    result = transformer._default_transform_fn(MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT, REQ_CONTEXT)
+
+    input_fn.assert_called_once_with(INPUT_DATA, CONTENT_TYPE, REQ_CONTEXT)
+    predict_fn.assert_called_once_with(PREPROCESSED_DATA, MODEL, REQ_CONTEXT)
+    output_fn.assert_called_once_with(PREDICT_RESULT, ACCEPT, REQ_CONTEXT)
+    assert result == PROCESSED_RESULT
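The __signature__ trick these tests rely on works because inspect.signature (and the funcsigs backport) consults a callable's __signature__ attribute before introspecting it, so a Mock can impersonate any arity while still accepting any call. A minimal standalone check:

from inspect import signature
from unittest.mock import Mock

fake_handler = Mock(return_value="ok")
fake_handler.__signature__ = signature(lambda data, model, context: None)

assert len(signature(fake_handler).parameters) == 3
fake_handler("data", "model", "ctx")  # the Mock itself still accepts any call shape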