41
41
DEFAULT_MMS_LOG_FILE = pkg_resources .resource_filename (
42
42
sagemaker_inference .__name__ , "/etc/log4j2.xml"
43
43
)
44
- DEFAULT_MMS_MODEL_DIRECTORY = os .path .join (os .getcwd (), ".sagemaker/mms/models" )
44
+ DEFAULT_MMS_MODEL_EXPORT_DIRECTORY = os .path .join (os .getcwd (), ".sagemaker/mms/models" )
45
45
DEFAULT_MMS_MODEL_NAME = "model"
46
46
47
47
ENABLE_MULTI_MODEL = os .getenv ("SAGEMAKER_MULTI_MODEL" , "false" ) == "true"
48
- MODEL_STORE = "/" if ENABLE_MULTI_MODEL else DEFAULT_MMS_MODEL_DIRECTORY
48
+ MODEL_STORE = "/"
49
49
50
50
PYTHON_PATH_ENV = "PYTHONPATH"
51
51
REQUIREMENTS_PATH = os .path .join (code_dir , "requirements.txt" )
@@ -68,15 +68,16 @@ def start_model_server(handler_service=DEFAULT_HANDLER_SERVICE):
68
68
69
69
"""
70
70
71
- if ENABLE_MULTI_MODEL :
72
- if not os .getenv ("SAGEMAKER_HANDLER" ):
73
- os .environ ["SAGEMAKER_HANDLER" ] = handler_service
74
- _set_python_path ()
75
- else :
76
- _adapt_to_mms_format (handler_service )
71
+ if ENABLE_MULTI_MODEL and not os .getenv ("SAGEMAKER_HANDLER" ):
72
+ os .environ ["SAGEMAKER_HANDLER" ] = handler_service
73
+
74
+ _set_python_path ()
77
75
78
76
env = environment .Environment ()
79
- _create_model_server_config_file (env )
77
+
78
+ # Note: multi-model default config already sets default_service_handler
79
+ handler_service_for_config = None if ENABLE_MULTI_MODEL else handler_service
80
+ _create_model_server_config_file (env , handler_service_for_config )
80
81
81
82
if os .path .exists (REQUIREMENTS_PATH ):
82
83
_install_requirements ()
@@ -91,6 +92,8 @@ def start_model_server(handler_service=DEFAULT_HANDLER_SERVICE):
91
92
"--log-config" ,
92
93
DEFAULT_MMS_LOG_FILE ,
93
94
]
95
+ if not ENABLE_MULTI_MODEL :
96
+ multi_model_server_cmd += ["--models" , DEFAULT_MMS_MODEL_NAME + "=" + environment .model_dir ]
94
97
95
98
logger .info (multi_model_server_cmd )
96
99
subprocess .Popen (multi_model_server_cmd )
@@ -104,9 +107,12 @@ def start_model_server(handler_service=DEFAULT_HANDLER_SERVICE):
104
107
mms_process .wait ()
105
108
106
109
110
+ # Note: this legacy function is still here for backwards compatibility.
111
+ # It should not normally need to be used, since the model artifact can be used
112
+ # straight from the original model directory
107
113
def _adapt_to_mms_format (handler_service ):
108
- if not os .path .exists (DEFAULT_MMS_MODEL_DIRECTORY ):
109
- os .makedirs (DEFAULT_MMS_MODEL_DIRECTORY )
114
+ if not os .path .exists (DEFAULT_MMS_MODEL_EXPORT_DIRECTORY ):
115
+ os .makedirs (DEFAULT_MMS_MODEL_EXPORT_DIRECTORY )
110
116
111
117
model_archiver_cmd = [
112
118
"model-archiver" ,
@@ -117,7 +123,7 @@ def _adapt_to_mms_format(handler_service):
117
123
"--model-path" ,
118
124
environment .model_dir ,
119
125
"--export-path" ,
120
- DEFAULT_MMS_MODEL_DIRECTORY ,
126
+ DEFAULT_MMS_MODEL_EXPORT_DIRECTORY ,
121
127
"--archive-format" ,
122
128
"no-archive" ,
123
129
]
@@ -141,20 +147,23 @@ def _set_python_path():
141
147
os .environ [PYTHON_PATH_ENV ] = code_dir_path
142
148
143
149
144
- def _create_model_server_config_file (env ):
145
- configuration_properties = _generate_mms_config_properties (env )
150
+ def _create_model_server_config_file (env , handler_service = None ):
151
+ configuration_properties = _generate_mms_config_properties (env , handler_service )
146
152
147
153
utils .write_file (MMS_CONFIG_FILE , configuration_properties )
148
154
149
155
150
- def _generate_mms_config_properties (env ):
156
+ def _generate_mms_config_properties (env , handler_service = None ):
151
157
user_defined_configuration = {
152
158
"default_response_timeout" : env .model_server_timeout ,
153
159
"default_workers_per_model" : env .model_server_workers ,
154
160
"inference_address" : "http://0.0.0.0:{}" .format (env .inference_http_port ),
155
161
"management_address" : "http://0.0.0.0:{}" .format (env .management_http_port ),
156
162
"vmargs" : "-XX:-UseContainerSupport" ,
157
163
}
164
+ # If provided, add handler service to user config
165
+ if handler_service :
166
+ user_defined_configuration ["default_service_handler" ] = handler_service
158
167
159
168
custom_configuration = str ()
160
169
0 commit comments