diff --git a/src/sagemaker_inference/transformer.py b/src/sagemaker_inference/transformer.py index ed4fab5..8b0ca84 100644 --- a/src/sagemaker_inference/transformer.py +++ b/src/sagemaker_inference/transformer.py @@ -122,40 +122,33 @@ def transform(self, data, context): model_dir = properties.get("model_dir") self.validate_and_initialize(model_dir=model_dir, context=context) - response_list = [] + input_data = [req.get("body") for req in data] - for i in range(len(data)): - input_data = data[i].get("body") + request_processor = context.request_processor[0] - request_processor = context.request_processor[0] + request_property = request_processor.get_request_properties() + content_type = utils.retrieve_content_type_header(request_property) + accept = request_property.get("Accept") or request_property.get("accept") - request_property = request_processor.get_request_properties() - content_type = utils.retrieve_content_type_header(request_property) - accept = request_property.get("Accept") or request_property.get("accept") + if not accept or accept == content_types.ANY: + accept = self._environment.default_accept - if not accept or accept == content_types.ANY: - accept = self._environment.default_accept - - if content_type in content_types.UTF8_TYPES: - input_data = input_data.decode("utf-8") - - result = self._run_handler_function( - self._transform_fn, *(self._model, input_data, content_type, accept) - ) - - response = result - response_content_type = accept - - if isinstance(result, tuple): - # handles tuple for backwards compatibility - response = result[0] - response_content_type = result[1] + if content_type in content_types.UTF8_TYPES: + input_data = [item.decode("utf-8") for item in input_data] + result = self._run_handler_function( + self._transform_fn, *(self._model, input_data, content_type, accept) + ) - context.set_response_content_type(0, response_content_type) + response = result + response_content_type = accept - response_list.append(response) + if isinstance(result, tuple): + # handles tuple for backwards compatibility + response = result[0] + response_content_type = result[1] - return response_list + context.set_response_content_type(0, response_content_type) + return [response] except Exception as e: # pylint: disable=broad-except trace = traceback.format_exc() if isinstance(e, BaseInferenceToolkitError): diff --git a/test/unit/test_transfomer.py b/test/unit/test_transfomer.py index 3a6cd16..fec911a 100644 --- a/test/unit/test_transfomer.py +++ b/test/unit/test_transfomer.py @@ -96,7 +96,7 @@ def test_transform(validate, retrieve_content_type_header, run_handler, accept_k validate.assert_called_once() retrieve_content_type_header.assert_called_once_with(request_property) run_handler.assert_called_once_with( - transformer._transform_fn, MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT + transformer._transform_fn, MODEL, [INPUT_DATA], CONTENT_TYPE, ACCEPT ) context.set_response_content_type.assert_called_once_with(0, ACCEPT) assert isinstance(result, list) @@ -125,16 +125,13 @@ def test_batch_transform(validate, retrieve_content_type_header, run_handler, ac result = transformer.transform(data, context) validate.assert_called_once() - retrieve_content_type_header.assert_called_with(request_property) - assert retrieve_content_type_header.call_count == 2 - run_handler.assert_called_with( - transformer._transform_fn, MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT + retrieve_content_type_header.assert_called_once_with(request_property) + run_handler.assert_called_once_with( + transformer._transform_fn, MODEL, [INPUT_DATA, INPUT_DATA], CONTENT_TYPE, ACCEPT ) - assert run_handler.call_count == 2 - context.set_response_content_type.assert_called_with(0, ACCEPT) - assert context.set_response_content_type.call_count == 2 + context.set_response_content_type.assert_called_once_with(0, ACCEPT) assert isinstance(result, list) - assert result == [RESULT, RESULT] + assert result[0] == RESULT @patch("sagemaker_inference.transformer.Transformer._run_handler_function") @@ -161,7 +158,7 @@ def test_transform_no_accept(validate, retrieve_content_type_header, run_handler validate.assert_called_once() run_handler.assert_called_once_with( - transformer._transform_fn, MODEL, INPUT_DATA, CONTENT_TYPE, DEFAULT_ACCEPT + transformer._transform_fn, MODEL, [INPUT_DATA], CONTENT_TYPE, DEFAULT_ACCEPT ) @@ -189,7 +186,7 @@ def test_transform_any_accept(validate, retrieve_content_type_header, run_handle validate.assert_called_once() run_handler.assert_called_once_with( - transformer._transform_fn, MODEL, INPUT_DATA, CONTENT_TYPE, DEFAULT_ACCEPT + transformer._transform_fn, MODEL, [INPUT_DATA], CONTENT_TYPE, DEFAULT_ACCEPT ) @@ -218,7 +215,7 @@ def test_transform_decode(validate, retrieve_content_type_header, run_handler, c input_data.decode.assert_called_once_with("utf-8") run_handler.assert_called_once_with( - transformer._transform_fn, MODEL, INPUT_DATA, content_type, ACCEPT + transformer._transform_fn, MODEL, [INPUT_DATA], content_type, ACCEPT ) @@ -245,7 +242,7 @@ def test_transform_tuple(validate, retrieve_content_type_header, run_handler): result = transformer.transform(data, context) run_handler.assert_called_once_with( - transformer._transform_fn, MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT + transformer._transform_fn, MODEL, [INPUT_DATA], CONTENT_TYPE, ACCEPT ) context.set_response_content_type.assert_called_once_with(0, run_handler()[1]) assert isinstance(result, list)