4
4
"""
5
5
import json
6
6
import logging
7
+ from typing import Optional
7
8
8
9
import numpy as np
9
10
@@ -104,24 +105,43 @@ def _from_json(self, body_list):
104
105
logger .debug ("Bytes array is %s" , body_list )
105
106
106
107
input_names = []
107
- for index , input in enumerate (body_list [0 ]["inputs" ]):
108
- if input ["datatype" ] == "BYTES" :
109
- body_list [0 ]["inputs" ][index ]["data" ] = input ["data" ][0 ]
110
- else :
111
- body_list [0 ]["inputs" ][index ]["data" ] = (
112
- np .array (input ["data" ]).reshape (tuple (input ["shape" ])).tolist ()
113
- )
114
- input_names .append (input ["name" ])
108
+ parameters = []
109
+ ids = []
110
+ input_parameters = []
111
+ data_list = []
112
+
113
+ for body in body_list :
114
+ id = body .get ("id" )
115
+ ids .append (id )
116
+ params = body .get ("parameters" )
117
+ if params :
118
+ parameters .append (params )
119
+ inp_names = []
120
+ inp_params = []
121
+ for i , input in enumerate (body ["inputs" ]):
122
+ params = input .get ("parameters" )
123
+ if params :
124
+ inp_params .append (params )
125
+ if input ["datatype" ] == "BYTES" :
126
+ body ["inputs" ][i ]["data" ] = input ["data" ][0 ]
127
+ else :
128
+ body ["inputs" ][i ]["data" ] = (
129
+ np .array (input ["data" ]).reshape (tuple (input ["shape" ])).tolist ()
130
+ )
131
+ inp_names .append (input ["name" ])
132
+ data = body ["inputs" ] if len (body ["inputs" ]) > 1 else body ["inputs" ][0 ]
133
+ data_list .append (data )
134
+
135
+ input_parameters .append (inp_params )
136
+ input_names .append (inp_names )
137
+
138
+ setattr (self .context , "input_request_id" , ids )
115
139
setattr (self .context , "input_names" , input_names )
116
- logger .debug ("Bytes array is %s" , body_list )
117
- id = body_list [0 ].get ("id" )
118
- if id and id .strip ():
119
- setattr (self .context , "input_request_id" , body_list [0 ]["id" ])
120
- # TODO: Add parameters support
121
- # parameters = body_list[0].get("parameters")
122
- # if parameters:
123
- # setattr(self.context, "input_parameters", body_list[0]["parameters"])
124
- data_list = [inputs_list .get ("inputs" ) for inputs_list in body_list ][0 ]
140
+ setattr (self .context , "request_parameters" , parameters )
141
+ setattr (self .context , "input_parameters" , input_parameters )
142
+ logger .debug ("Data array is %s" , data_list )
143
+ logger .debug ("Request paraemeters array is %s" , parameters )
144
+ logger .debug ("Input parameters is %s" , input_parameters )
125
145
return data_list
126
146
127
147
def format_output (self , data ):
@@ -145,41 +165,48 @@ def format_output(self, data):
145
165
146
166
"""
147
167
logger .debug ("The Response of KServe v2 format %s" , data )
148
- response = {}
149
- if hasattr (self .context , "input_request_id" ):
150
- response ["id" ] = getattr (self .context , "input_request_id" )
151
- delattr (self .context , "input_request_id" )
152
- else :
153
- response ["id" ] = self .context .get_request_id (0 )
154
- # TODO: Add parameters support
155
- # if hasattr(self.context, "input_parameters"):
156
- # response["parameters"] = getattr(self.context, "input_parameters")
157
- # delattr(self.context, "input_parameters")
158
- response ["model_name" ] = self .context .manifest .get ("model" ).get ("modelName" )
159
- response ["model_version" ] = self .context .manifest .get ("model" ).get (
160
- "modelVersion"
161
- )
162
- response ["outputs" ] = self ._batch_to_json (data )
163
- return [response ]
164
-
165
- def _batch_to_json (self , data ):
168
+ return self ._batch_to_json (data )
169
+
170
+ def _batch_to_json (self , batch : dict ):
166
171
"""
167
172
Splits batch output to json objects
168
173
"""
169
- output = []
170
- input_names = getattr (self .context , "input_names" )
174
+ parameters = getattr (self .context , "request_parameters" )
175
+ ids = getattr (self .context , "input_request_id" )
176
+ input_parameters = getattr (self .context , "input_parameters" )
177
+ responses = []
178
+ for index , data in enumerate (batch ):
179
+ response = {}
180
+ response ["id" ] = ids [index ] or self .context .get_request_id (index )
181
+ if parameters and parameters [index ]:
182
+ response ["parameters" ] = parameters [index ]
183
+ response ["model_name" ] = self .context .manifest .get ("model" ).get ("modelName" )
184
+ response ["model_version" ] = self .context .manifest .get ("model" ).get (
185
+ "modelVersion"
186
+ )
187
+ outputs = []
188
+ if isinstance (data , dict ):
189
+ for key , item in data .items ():
190
+ outputs .append (self ._to_json (item , key , input_parameters ))
191
+ else :
192
+ outputs .append (self ._to_json (data , "predictions" , input_parameters ))
193
+ response ["outputs" ] = outputs
194
+ responses .append (response )
171
195
delattr (self .context , "input_names" )
172
- for index , item in enumerate (data ):
173
- output .append (self ._to_json (item , input_names [index ]))
174
- return output
196
+ delattr (self .context , "input_request_id" )
197
+ delattr (self .context , "input_parameters" )
198
+ delattr (self .context , "request_parameters" )
199
+ return responses
175
200
176
- def _to_json (self , data , input_name ):
201
+ def _to_json (self , data , output_name , parameters : Optional [ list ] = None ):
177
202
"""
178
203
Constructs JSON object from data
179
204
"""
180
205
output_data = {}
181
206
data_ndarray = np .array (data ).flatten ()
182
- output_data ["name" ] = input_name
207
+ output_data ["name" ] = output_name
208
+ if parameters :
209
+ output_data ["parameters" ] = parameters
183
210
output_data ["datatype" ] = _to_datatype (data_ndarray .dtype )
184
211
output_data ["data" ] = data_ndarray .tolist ()
185
212
output_data ["shape" ] = data_ndarray .flatten ().shape
0 commit comments