Skip to content

Commit 0816a49

Browse files
feature: Write handler to config file and serve model directly from original model artifact directory to save disk space (#115)
* feature: Write handler to config file and serve model directly from original model artifact directory to save disk space.
* Remove Python 2.7 support
1 parent ad746db commit 0816a49

File tree

6 files changed

+43
-30
lines changed

6 files changed

+43
-30
lines changed

buildspec-release.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ phases:
2222
- tox -e black-check
2323

2424
# run unit tests
25-
- tox -e py27,py36,py37 -- test/unit
25+
- tox -e py36,py37 -- test/unit
2626

2727
# build dummy container
2828
- python3 setup.py sdist
@@ -37,7 +37,7 @@ phases:
3737
- cd ../..
3838

3939
# run local integration tests
40-
- IGNORE_COVERAGE=- tox -e py27,py36,py37 -- test/integration/local
40+
- IGNORE_COVERAGE=- tox -e py36,py37 -- test/integration/local
4141

4242
# publish the release to github
4343
- git-release --publish

buildspec.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ phases:
2020
- tox -e twine
2121

2222
# run unit tests
23-
- tox -e py27,py36,py37 -- test/unit
23+
- tox -e py36,py37 -- test/unit
2424

2525
# build dummy container
2626
- python setup.py sdist
@@ -35,4 +35,4 @@ phases:
3535
- cd ../..
3636

3737
# run local integration tests
38-
- IGNORE_COVERAGE=- tox -e py27,py36,py37 -- test/integration/local
38+
- IGNORE_COVERAGE=- tox -e py36,py37 -- test/integration/local

setup.py

-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ def read_version():
5656
"Natural Language :: English",
5757
"License :: OSI Approved :: Apache Software License",
5858
"Programming Language :: Python",
59-
"Programming Language :: Python :: 2.7",
6059
"Programming Language :: Python :: 3.6",
6160
"Programming Language :: Python :: 3.7",
6261
],

src/sagemaker_inference/model_server.py

+24-15
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,11 @@
4141
DEFAULT_MMS_LOG_FILE = pkg_resources.resource_filename(
4242
sagemaker_inference.__name__, "/etc/log4j2.xml"
4343
)
44-
DEFAULT_MMS_MODEL_DIRECTORY = os.path.join(os.getcwd(), ".sagemaker/mms/models")
44+
DEFAULT_MMS_MODEL_EXPORT_DIRECTORY = os.path.join(os.getcwd(), ".sagemaker/mms/models")
4545
DEFAULT_MMS_MODEL_NAME = "model"
4646

4747
ENABLE_MULTI_MODEL = os.getenv("SAGEMAKER_MULTI_MODEL", "false") == "true"
48-
MODEL_STORE = "/" if ENABLE_MULTI_MODEL else DEFAULT_MMS_MODEL_DIRECTORY
48+
MODEL_STORE = "/"
4949

5050
PYTHON_PATH_ENV = "PYTHONPATH"
5151
REQUIREMENTS_PATH = os.path.join(code_dir, "requirements.txt")
@@ -68,15 +68,16 @@ def start_model_server(handler_service=DEFAULT_HANDLER_SERVICE):
6868
6969
"""
7070

71-
if ENABLE_MULTI_MODEL:
72-
if not os.getenv("SAGEMAKER_HANDLER"):
73-
os.environ["SAGEMAKER_HANDLER"] = handler_service
74-
_set_python_path()
75-
else:
76-
_adapt_to_mms_format(handler_service)
71+
if ENABLE_MULTI_MODEL and not os.getenv("SAGEMAKER_HANDLER"):
72+
os.environ["SAGEMAKER_HANDLER"] = handler_service
73+
74+
_set_python_path()
7775

7876
env = environment.Environment()
79-
_create_model_server_config_file(env)
77+
78+
# Note: multi-model default config already sets default_service_handler
79+
handler_service_for_config = None if ENABLE_MULTI_MODEL else handler_service
80+
_create_model_server_config_file(env, handler_service_for_config)
8081

8182
if os.path.exists(REQUIREMENTS_PATH):
8283
_install_requirements()
@@ -91,6 +92,8 @@ def start_model_server(handler_service=DEFAULT_HANDLER_SERVICE):
9192
"--log-config",
9293
DEFAULT_MMS_LOG_FILE,
9394
]
95+
if not ENABLE_MULTI_MODEL:
96+
multi_model_server_cmd += ["--models", DEFAULT_MMS_MODEL_NAME + "=" + environment.model_dir]
9497

9598
logger.info(multi_model_server_cmd)
9699
subprocess.Popen(multi_model_server_cmd)
@@ -104,9 +107,12 @@ def start_model_server(handler_service=DEFAULT_HANDLER_SERVICE):
104107
mms_process.wait()
105108

106109

110+
# Note: this legacy function is still here for backwards compatibility.
111+
# It should not normally need to be used, since the model artifact can be used
112+
# straight from the original model directory
107113
def _adapt_to_mms_format(handler_service):
108-
if not os.path.exists(DEFAULT_MMS_MODEL_DIRECTORY):
109-
os.makedirs(DEFAULT_MMS_MODEL_DIRECTORY)
114+
if not os.path.exists(DEFAULT_MMS_MODEL_EXPORT_DIRECTORY):
115+
os.makedirs(DEFAULT_MMS_MODEL_EXPORT_DIRECTORY)
110116

111117
model_archiver_cmd = [
112118
"model-archiver",
@@ -117,7 +123,7 @@ def _adapt_to_mms_format(handler_service):
117123
"--model-path",
118124
environment.model_dir,
119125
"--export-path",
120-
DEFAULT_MMS_MODEL_DIRECTORY,
126+
DEFAULT_MMS_MODEL_EXPORT_DIRECTORY,
121127
"--archive-format",
122128
"no-archive",
123129
]
@@ -141,20 +147,23 @@ def _set_python_path():
141147
os.environ[PYTHON_PATH_ENV] = code_dir_path
142148

143149

144-
def _create_model_server_config_file(env):
145-
configuration_properties = _generate_mms_config_properties(env)
150+
def _create_model_server_config_file(env, handler_service=None):
151+
configuration_properties = _generate_mms_config_properties(env, handler_service)
146152

147153
utils.write_file(MMS_CONFIG_FILE, configuration_properties)
148154

149155

150-
def _generate_mms_config_properties(env):
156+
def _generate_mms_config_properties(env, handler_service=None):
151157
user_defined_configuration = {
152158
"default_response_timeout": env.model_server_timeout,
153159
"default_workers_per_model": env.model_server_workers,
154160
"inference_address": "http://0.0.0.0:{}".format(env.inference_http_port),
155161
"management_address": "http://0.0.0.0:{}".format(env.management_http_port),
156162
"vmargs": "-XX:-UseContainerSupport",
157163
}
164+
# If provided, add handler service to user config
165+
if handler_service:
166+
user_defined_configuration["default_service_handler"] = handler_service
158167

159168
custom_configuration = str()
160169

test/unit/test_model_server.py

+14-9
Original file line numberDiff line numberDiff line change
@@ -49,20 +49,23 @@ def test_start_model_server_default_service_handler(
4949

5050
model_server.start_model_server()
5151

52-
adapt.assert_called_once_with(model_server.DEFAULT_HANDLER_SERVICE)
53-
create_config.assert_called_once_with(env.return_value)
52+
adapt.assert_not_called()
53+
54+
create_config.assert_called_once_with(env.return_value, model_server.DEFAULT_HANDLER_SERVICE)
5455
exists.assert_called_once_with(REQUIREMENTS_PATH)
5556
install_requirements.assert_called_once_with()
5657

5758
multi_model_server_cmd = [
5859
"multi-model-server",
5960
"--start",
6061
"--model-store",
61-
model_server.DEFAULT_MMS_MODEL_DIRECTORY,
62+
model_server.MODEL_STORE,
6263
"--mms-config",
6364
model_server.MMS_CONFIG_FILE,
6465
"--log-config",
6566
model_server.DEFAULT_MMS_LOG_FILE,
67+
"--models",
68+
"{}={}".format(model_server.DEFAULT_MMS_MODEL_NAME, environment.model_dir),
6669
]
6770

6871
subprocess_popen.assert_called_once_with(multi_model_server_cmd)
@@ -75,14 +78,16 @@ def test_start_model_server_default_service_handler(
7578
@patch("sagemaker_inference.model_server._add_sigterm_handler")
7679
@patch("sagemaker_inference.model_server._create_model_server_config_file")
7780
@patch("sagemaker_inference.model_server._adapt_to_mms_format")
81+
@patch("sagemaker_inference.environment.Environment")
7882
def test_start_model_server_custom_handler_service(
79-
adapt, create_config, sigterm, retrieve, subprocess_popen, subprocess_call
83+
env, adapt, create_config, sigterm, retrieve, subprocess_popen, subprocess_call
8084
):
8185
handler_service = Mock()
8286

8387
model_server.start_model_server(handler_service)
8488

85-
adapt.assert_called_once_with(handler_service)
89+
adapt.assert_not_called()
90+
create_config.assert_called_once_with(env.return_value, handler_service)
8691

8792

8893
@patch("sagemaker_inference.model_server._set_python_path")
@@ -94,8 +99,8 @@ def test_adapt_to_mms_format(path_exists, make_dir, subprocess_check_call, set_p
9499

95100
model_server._adapt_to_mms_format(handler_service)
96101

97-
path_exists.assert_called_once_with(model_server.DEFAULT_MMS_MODEL_DIRECTORY)
98-
make_dir.assert_called_once_with(model_server.DEFAULT_MMS_MODEL_DIRECTORY)
102+
path_exists.assert_called_once_with(model_server.DEFAULT_MMS_MODEL_EXPORT_DIRECTORY)
103+
make_dir.assert_called_once_with(model_server.DEFAULT_MMS_MODEL_EXPORT_DIRECTORY)
99104

100105
model_archiver_cmd = [
101106
"model-archiver",
@@ -106,7 +111,7 @@ def test_adapt_to_mms_format(path_exists, make_dir, subprocess_check_call, set_p
106111
"--model-path",
107112
environment.model_dir,
108113
"--export-path",
109-
model_server.DEFAULT_MMS_MODEL_DIRECTORY,
114+
model_server.DEFAULT_MMS_MODEL_EXPORT_DIRECTORY,
110115
"--archive-format",
111116
"no-archive",
112117
]
@@ -126,7 +131,7 @@ def test_adapt_to_mms_format_existing_path(
126131

127132
model_server._adapt_to_mms_format(handler_service)
128133

129-
path_exists.assert_called_once_with(model_server.DEFAULT_MMS_MODEL_DIRECTORY)
134+
path_exists.assert_called_once_with(model_server.DEFAULT_MMS_MODEL_EXPORT_DIRECTORY)
130135
make_dir.assert_not_called()
131136

132137

tox.ini

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# and then run "tox" from this directory.
55

66
[tox]
7-
envlist = black-format,flake8,pylint,twine,py27,py36,py37
7+
envlist = black-format,flake8,pylint,twine,py36,py37
88

99
skip_missing_interpreters = False
1010

0 commit comments

Comments
 (0)