Skip to content

Commit f23c65d

Browse files
Add support for SAGEMAKER_MODEL_SERVER_TIMEOUT_SECONDS variable (#129)
1 parent 6b33274 commit f23c65d

File tree

5 files changed

+37
-12
lines changed

5 files changed

+37
-12
lines changed

setup.cfg

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
universal=1
33

44
[metadata]
5-
description-file = README.md
5+
description_file = README.md

src/sagemaker_inference/environment.py

+28-11
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from __future__ import absolute_import
1717

1818
import os
19+
from typing import Optional
1920

2021
from sagemaker_inference import content_types, logging, parameters
2122

@@ -51,7 +52,9 @@ class Environment(object):
5152
5253
Attributes:
5354
module_name (str): The name of the user-provided module. Default is inference.py.
54-
model_server_timeout (int): Timeout in seconds for the model server. Default is 60.
55+
model_server_timeout (int): Timeout for the model server. Default is 60.
56+
model_server_timeout_seconds (Optional[int]): Timeout in seconds for the model server.
57+
Default is None.
5558
model_server_workers (str): Number of worker processes the model server will use.
5659
5760
default_accept (str): The desired default MIME type of the inference in the response
@@ -71,7 +74,13 @@ def __init__(self):
7174
self._model_server_timeout = int(
7275
os.environ.get(parameters.MODEL_SERVER_TIMEOUT_ENV, DEFAULT_MODEL_SERVER_TIMEOUT)
7376
)
77+
timeout_seconds_var = os.environ.get(parameters.MODEL_SERVER_TIMEOUT_SECONDS_ENV)
78+
self._model_server_timeout_seconds = (
79+
int(timeout_seconds_var) if timeout_seconds_var is not None else None
80+
)
81+
7482
self._model_server_workers = os.environ.get(parameters.MODEL_SERVER_WORKERS_ENV)
83+
7584
self._startup_timeout = int(
7685
os.environ.get(parameters.STARTUP_TIMEOUT_ENV, DEFAULT_STARTUP_TIMEOUT)
7786
)
@@ -107,53 +116,61 @@ def module_name(self): # type: () -> str
107116
return self._parse_module_name(self._module_name)
108117

109118
@property
110-
def model_server_timeout(self): # type: () -> int
119+
def model_server_timeout(self) -> int:
120+
"""int: Timeout used for model server's backend workers before they are
121+
deemed unresponsive and rebooted.
122+
123+
"""
124+
return self._model_server_timeout
125+
126+
@property
127+
def model_server_timeout_seconds(self) -> Optional[int]:
111128
"""int: Timeout, in seconds, used for model server's backend workers before
112129
they are deemed unresponsive and rebooted.
113130
"""
114-
return self._model_server_timeout
131+
return self._model_server_timeout_seconds
115132

116133
@property
117-
def model_server_workers(self): # type: () -> str
134+
def model_server_workers(self) -> Optional[str]:
118135
"""str: Number of worker processes the model server is configured to use."""
119136
return self._model_server_workers
120137

121138
@property
122-
def startup_timeout(self): # type () -> int
139+
def startup_timeout(self) -> int:
123140
"""int: Timeout, in seconds, used for starting up the model server and fetching
124141
its process id, before giving up and throwing error.
125142
"""
126143
return self._startup_timeout
127144

128145
@property
129-
def default_accept(self): # type: () -> str
146+
def default_accept(self) -> str:
130147
"""str: The desired default MIME type of the inference in the response."""
131148
return self._default_accept
132149

133150
@property
134-
def inference_http_port(self): # type: () -> str
151+
def inference_http_port(self) -> str:
135152
"""str: HTTP port that SageMaker uses to handle invocations and pings."""
136153
return self._inference_http_port
137154

138155
@property
139-
def management_http_port(self): # type: () -> str
156+
def management_http_port(self) -> str:
140157
"""str: HTTP port that SageMaker uses to handle model management requests."""
141158
return self._management_http_port
142159

143160
@property
144-
def safe_port_range(self): # type: () -> str
161+
def safe_port_range(self) -> Optional[str]:
145162
"""str: HTTP port range that can be used by users to avoid collisions with the HTTP port
146163
specified by SageMaker for handling pings and invocations.
147164
"""
148165
return self._safe_port_range
149166

150167
@property
151-
def vmargs(self): # type: () -> str
168+
def vmargs(self) -> str:
152169
"""str: vmargs can be provided for the JVM, to be overriden"""
153170
return self._vmargs
154171

155172
@property
156-
def max_request_size(self): # type: () -> str
173+
def max_request_size(self) -> Optional[int]:
157174
"""str: max request size set by Sagemaker platform in bytes"""
158175
if self._max_request_size_in_mb is not None:
159176
return int(self._max_request_size_in_mb) * 1024 * 1024

src/sagemaker_inference/model_server.py

+5
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,11 @@ def _generate_mms_config_properties(env, handler_service=None):
166166
if handler_service:
167167
user_defined_configuration["default_service_handler"] = handler_service
168168

169+
if env.model_server_timeout_seconds:
170+
user_defined_configuration[
171+
"default_response_timeout_seconds"
172+
] = env.model_server_timeout_seconds
173+
169174
custom_configuration = str()
170175

171176
for key in user_defined_configuration:

src/sagemaker_inference/parameters.py

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
DEFAULT_INVOCATIONS_ACCEPT_ENV = "SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT" # type: str
2121
MODEL_SERVER_WORKERS_ENV = "SAGEMAKER_MODEL_SERVER_WORKERS" # type: str
2222
MODEL_SERVER_TIMEOUT_ENV = "SAGEMAKER_MODEL_SERVER_TIMEOUT" # type: str
23+
MODEL_SERVER_TIMEOUT_SECONDS_ENV = "SAGEMAKER_MODEL_SERVER_TIMEOUT_SECONDS" # type: str
2324
MODEL_SERVER_VMARGS = "SAGEMAKER_MODEL_SERVER_VMARGS" # type: str
2425
STARTUP_TIMEOUT_ENV = "SAGEMAKER_STARTUP_TIMEOUT" # type: str
2526
BIND_TO_PORT_ENV = "SAGEMAKER_BIND_TO_PORT" # type: str

test/unit/test_environment.py

+2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
{
2424
parameters.USER_PROGRAM_ENV: "main.py",
2525
parameters.MODEL_SERVER_TIMEOUT_ENV: "20",
26+
parameters.MODEL_SERVER_TIMEOUT_SECONDS_ENV: "30",
2627
parameters.MODEL_SERVER_WORKERS_ENV: "8",
2728
parameters.STARTUP_TIMEOUT_ENV: "50",
2829
parameters.DEFAULT_INVOCATIONS_ACCEPT_ENV: "text/html",
@@ -41,6 +42,7 @@ def test_env():
4142
assert environment.code_dir.endswith("opt/ml/model/code")
4243
assert env.module_name == "main"
4344
assert env.model_server_timeout == 20
45+
assert env.model_server_timeout_seconds == 30
4446
assert env.startup_timeout == 50
4547
assert env.model_server_workers == "8"
4648
assert env.default_accept == "text/html"

0 commit comments

Comments
 (0)