Skip to content

Commit da465e9

Browse files
committed
fix: kservev2 batching issue and missing parameters
Fix the KServe v2 request envelope picking up only the first item from the request list; as a result, it can now handle dynamically batched requests. If the model returns a dictionary, its keys are now used as the names of the outputs. The envelope is also updated to handle request-level parameters and per-input parameters.
1 parent 3182443 commit da465e9

File tree

1 file changed

+69
-42
lines changed

1 file changed

+69
-42
lines changed

ts/torch_handler/request_envelope/kservev2.py

+69-42
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
"""
55
import json
66
import logging
7+
from typing import Optional
78

89
import numpy as np
910

@@ -104,24 +105,43 @@ def _from_json(self, body_list):
104105
logger.debug("Bytes array is %s", body_list)
105106

106107
input_names = []
107-
for index, input in enumerate(body_list[0]["inputs"]):
108-
if input["datatype"] == "BYTES":
109-
body_list[0]["inputs"][index]["data"] = input["data"][0]
110-
else:
111-
body_list[0]["inputs"][index]["data"] = (
112-
np.array(input["data"]).reshape(tuple(input["shape"])).tolist()
113-
)
114-
input_names.append(input["name"])
108+
parameters = []
109+
ids = []
110+
input_parameters = []
111+
data_list = []
112+
113+
for body in body_list:
114+
id = body.get("id")
115+
ids.append(id)
116+
params = body.get("parameters")
117+
if params:
118+
parameters.append(params)
119+
inp_names = []
120+
inp_params = []
121+
for i, input in enumerate(body["inputs"]):
122+
params = input.get("parameters")
123+
if params:
124+
inp_params.append(params)
125+
if input["datatype"] == "BYTES":
126+
body["inputs"][i]["data"] = input["data"][0]
127+
else:
128+
body["inputs"][i]["data"] = (
129+
np.array(input["data"]).reshape(tuple(input["shape"])).tolist()
130+
)
131+
inp_names.append(input["name"])
132+
data = body["inputs"] if len(body["inputs"]) > 1 else body["inputs"][0]
133+
data_list.append(data)
134+
135+
input_parameters.append(inp_params)
136+
input_names.append(inp_names)
137+
138+
setattr(self.context, "input_request_id", ids)
115139
setattr(self.context, "input_names", input_names)
116-
logger.debug("Bytes array is %s", body_list)
117-
id = body_list[0].get("id")
118-
if id and id.strip():
119-
setattr(self.context, "input_request_id", body_list[0]["id"])
120-
# TODO: Add parameters support
121-
# parameters = body_list[0].get("parameters")
122-
# if parameters:
123-
# setattr(self.context, "input_parameters", body_list[0]["parameters"])
124-
data_list = [inputs_list.get("inputs") for inputs_list in body_list][0]
140+
setattr(self.context, "request_parameters", parameters)
141+
setattr(self.context, "input_parameters", input_parameters)
142+
logger.debug("Data array is %s", data_list)
143+
logger.debug("Request paraemeters array is %s", parameters)
144+
logger.debug("Input parameters is %s", input_parameters)
125145
return data_list
126146

127147
def format_output(self, data):
@@ -145,41 +165,48 @@ def format_output(self, data):
145165
146166
"""
147167
logger.debug("The Response of KServe v2 format %s", data)
148-
response = {}
149-
if hasattr(self.context, "input_request_id"):
150-
response["id"] = getattr(self.context, "input_request_id")
151-
delattr(self.context, "input_request_id")
152-
else:
153-
response["id"] = self.context.get_request_id(0)
154-
# TODO: Add parameters support
155-
# if hasattr(self.context, "input_parameters"):
156-
# response["parameters"] = getattr(self.context, "input_parameters")
157-
# delattr(self.context, "input_parameters")
158-
response["model_name"] = self.context.manifest.get("model").get("modelName")
159-
response["model_version"] = self.context.manifest.get("model").get(
160-
"modelVersion"
161-
)
162-
response["outputs"] = self._batch_to_json(data)
163-
return [response]
164-
165-
def _batch_to_json(self, data):
168+
return self._batch_to_json(data)
169+
170+
def _batch_to_json(self, batch: dict):
166171
"""
167172
Splits batch output to json objects
168173
"""
169-
output = []
170-
input_names = getattr(self.context, "input_names")
174+
parameters = getattr(self.context, "request_parameters")
175+
ids = getattr(self.context, "input_request_id")
176+
input_parameters = getattr(self.context, "input_parameters")
177+
responses = []
178+
for index, data in enumerate(batch):
179+
response = {}
180+
response["id"] = ids[index] or self.context.get_request_id(index)
181+
if parameters and parameters[index]:
182+
response["parameters"] = parameters[index]
183+
response["model_name"] = self.context.manifest.get("model").get("modelName")
184+
response["model_version"] = self.context.manifest.get("model").get(
185+
"modelVersion"
186+
)
187+
outputs = []
188+
if isinstance(data, dict):
189+
for key, item in data.items():
190+
outputs.append(self._to_json(item, key, input_parameters))
191+
else:
192+
outputs.append(self._to_json(data, "predictions", input_parameters))
193+
response["outputs"] = outputs
194+
responses.append(response)
171195
delattr(self.context, "input_names")
172-
for index, item in enumerate(data):
173-
output.append(self._to_json(item, input_names[index]))
174-
return output
196+
delattr(self.context, "input_request_id")
197+
delattr(self.context, "input_parameters")
198+
delattr(self.context, "request_parameters")
199+
return responses
175200

176-
def _to_json(self, data, input_name):
201+
def _to_json(self, data, output_name, parameters: Optional[list] = None):
177202
"""
178203
Constructs JSON object from data
179204
"""
180205
output_data = {}
181206
data_ndarray = np.array(data).flatten()
182-
output_data["name"] = input_name
207+
output_data["name"] = output_name
208+
if parameters:
209+
output_data["parameters"] = parameters
183210
output_data["datatype"] = _to_datatype(data_ndarray.dtype)
184211
output_data["data"] = data_ndarray.tolist()
185212
output_data["shape"] = data_ndarray.flatten().shape

0 commit comments

Comments
 (0)