@@ -14,6 +14,9 @@
     PartialQuestionAnswer,
     PartialQuestions,
     QuestionAnswer,
+    TokenUsage,
+    TokenUsageAggregate,
+    TokenUsageByModel,
 )
 from codegate.db.connection import alert_queue
 from codegate.db.models import Alert, GetPromptWithOutputsRow
@@ -57,16 +60,17 @@ async def _is_system_prompt(message: str) -> bool:
     return False


-async def parse_request(request_str: str) -> Optional[str]:
+async def parse_request(request_str: str) -> Tuple[Optional[List[str]], str]:
     """
-    Parse the request string from the pipeline and return the message.
+    Parse the request string from the pipeline and return the messages and the model.
     """
     try:
         request = json.loads(request_str)
     except Exception as e:
         logger.warning(f"Error parsing request: {request_str}. {e}")
-        return None
+        return None, ""

+    model = request.get("model", "")
     messages = []
     for message in request.get("messages", []):
         role = message.get("role")
@@ -91,57 +95,60 @@ async def parse_request(request_str: str) -> Optional[str]:
         if message_prompt and not await _is_system_prompt(message_prompt):
             messages.append(message_prompt)

-    # If still we don't have anything, return empty string
+    # If we still don't have anything, return None
     if not messages:
-        return None
+        return None, model

-    # Only respond with the latest message
-    return messages
+    # Respond with the messages and the model
+    return messages, model


-async def parse_output(output_str: str) -> Optional[str]:
+async def parse_output(output_str: str) -> Tuple[Optional[str], TokenUsage]:
     """
     Parse the output string from the pipeline and return the message.
     """
     try:
         if output_str is None:
-            return None
+            return None, TokenUsage()

         output = json.loads(output_str)
     except Exception as e:
         logger.warning(f"Error parsing output: {output_str}. {e}")
-        return None
+        return None, TokenUsage()

-    def _parse_single_output(single_output: dict) -> str:
+    def _parse_single_output(single_output: dict) -> Tuple[str, TokenUsage]:
         single_output_message = ""
         for choice in single_output.get("choices", []):
             if not isinstance(choice, dict):
                 continue
             content_dict = choice.get("delta", {}) or choice.get("message", {})
             single_output_message += content_dict.get("content", "")
-        return single_output_message
+        return single_output_message, TokenUsage.from_dict(single_output.get("usage", {}))

     full_output_message = ""
+    full_token_usage = TokenUsage()
     if isinstance(output, list):
         for output_chunk in output:
             output_message = ""
+            token_usage = TokenUsage()
             if isinstance(output_chunk, dict):
-                output_message = _parse_single_output(output_chunk)
+                output_message, token_usage = _parse_single_output(output_chunk)
             elif isinstance(output_chunk, str):
                 try:
                     output_decoded = json.loads(output_chunk)
-                    output_message = _parse_single_output(output_decoded)
+                    output_message, token_usage = _parse_single_output(output_decoded)
                 except Exception:
                     logger.error(f"Error reading chunk: {output_chunk}")
             else:
                 logger.warning(
                     f"Could not handle output: {output_chunk}", out_type=type(output_chunk)
                 )
             full_output_message += output_message
+            full_token_usage += token_usage
     elif isinstance(output, dict):
-        full_output_message = _parse_single_output(output)
+        full_output_message, full_token_usage = _parse_single_output(output)

-    return full_output_message
+    return full_output_message, full_token_usage


 async def _get_question_answer(row: GetPromptWithOutputsRow) -> Optional[PartialQuestionAnswer]:
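For reference, the changed code leans on a TokenUsage model that is never shown in this diff. Below is a minimal sketch of what it might look like, written as a pydantic model and assuming OpenAI-style "usage" payloads (prompt_tokens / completion_tokens); the real class in codegate may differ.

from typing import Any, Dict

import pydantic


class TokenUsage(pydantic.BaseModel):
    # Hypothetical sketch; field names for input_tokens are confirmed by the
    # `token_usage_agg.token_usage.input_tokens == 0` check later in the diff,
    # the rest is assumed.
    input_tokens: int = 0
    output_tokens: int = 0
    input_cost: float = 0.0
    output_cost: float = 0.0

    @classmethod
    def from_dict(cls, usage_dict: Dict[str, Any]) -> "TokenUsage":
        # An absent or empty "usage" entry collapses to a zeroed TokenUsage,
        # which is why parse_output can unconditionally return one.
        return cls(
            input_tokens=usage_dict.get("prompt_tokens", 0),
            output_tokens=usage_dict.get("completion_tokens", 0),
        )

    def __add__(self, other: "TokenUsage") -> "TokenUsage":
        # Backs the `full_token_usage += token_usage` accumulation over
        # streamed output chunks in parse_output.
        return TokenUsage(
            input_tokens=self.input_tokens + other.input_tokens,
            output_tokens=self.output_tokens + other.output_tokens,
            input_cost=self.input_cost + other.input_cost,
            output_cost=self.output_cost + other.output_cost,
        )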
@@ -154,8 +161,8 @@ async def _get_question_answer(row: GetPromptWithOutputsRow) -> Optional[PartialQuestionAnswer]:
         request_task = tg.create_task(parse_request(row.request))
         output_task = tg.create_task(parse_output(row.output))

-    request_user_msgs = request_task.result()
-    output_msg_str = output_task.result()
+    request_user_msgs, model = request_task.result()
+    output_msg_str, token_usage = output_task.result()

     # If we couldn't parse the request, return None
     if not request_user_msgs:
@@ -176,7 +183,23 @@ async def _get_question_answer(row: GetPromptWithOutputsRow) -> Optional[PartialQuestionAnswer]:
         )
     else:
         output_message = None
-    return PartialQuestionAnswer(partial_questions=request_message, answer=output_message)
+
+    # Use the model to update the token cost
+    token_usage.update_token_cost(model)
+    provider = row.provider
+    # TODO: This should come from the database. For now, we are manually changing copilot to openai
+    # Change copilot provider to openai
+    if provider == "copilot":
+        provider = "openai"
+    model_token_usage = TokenUsageByModel(
+        model=model, token_usage=token_usage, provider_type=provider
+    )
+
+    return PartialQuestionAnswer(
+        partial_questions=request_message,
+        answer=output_message,
+        model_token_usage=model_token_usage,
+    )


 def parse_question_answer(input_text: str) -> str:
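The hunk above also calls token_usage.update_token_cost(model) and wraps the result in TokenUsageByModel, neither of which appears in the diff. A hedged sketch, reusing the TokenUsage sketch above; the price table, its key names, and its values are invented for illustration.

from typing import Dict

import pydantic

# Hypothetical per-token price table; the real prices, and where codegate
# stores them, are assumptions rather than facts from this diff.
MODEL_COST: Dict[str, Dict[str, float]] = {
    "gpt-4o-mini": {"input": 1.5e-07, "output": 6.0e-07},
}


class TokenUsageByModel(pydantic.BaseModel):
    # Field names mirror the keyword arguments used in the diff:
    # TokenUsageByModel(model=..., token_usage=..., provider_type=...)
    provider_type: str
    model: str
    token_usage: TokenUsage  # the TokenUsage sketched earlier


def update_token_cost(usage: TokenUsage, model: str) -> None:
    # Plausible body for TokenUsage.update_token_cost: price the counted
    # tokens, leaving zero cost for models missing from the table.
    prices = MODEL_COST.get(model, {})
    usage.input_cost = usage.input_tokens * prices.get("input", 0.0)
    usage.output_cost = usage.output_tokens * prices.get("output", 0.0)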
@@ -304,6 +327,7 @@ async def match_conversations(
     map_q_id_to_conversation = {}
     for group in grouped_partial_questions:
         questions_answers: List[QuestionAnswer] = []
+        token_usage_agg = TokenUsageAggregate(tokens_by_model={}, token_usage=TokenUsage())
         first_partial_qa = None
         for partial_question in sorted(group, key=lambda x: x.timestamp):
             # Partial questions don't contain the answer, so we need to find the corresponding
@@ -322,16 +346,19 @@ async def match_conversations(
             qa = _get_question_answer_from_partial(selected_partial_qa)
             qa.question.message = parse_question_answer(qa.question.message)
             questions_answers.append(qa)
+            token_usage_agg.add_model_token_usage(selected_partial_qa.model_token_usage)

         # only add conversation if we have some answers
         if len(questions_answers) > 0 and first_partial_qa is not None:
+            if token_usage_agg.token_usage.input_tokens == 0:
+                token_usage_agg = None
             conversation = Conversation(
                 question_answers=questions_answers,
                 provider=first_partial_qa.partial_questions.provider,
                 type=first_partial_qa.partial_questions.type,
                 chat_id=first_partial_qa.partial_questions.message_id,
                 conversation_timestamp=first_partial_qa.partial_questions.timestamp,
-                token_usage=None,
+                token_usage_agg=token_usage_agg,
             )
             for qa in questions_answers:
                 map_q_id_to_conversation[qa.question.message_id] = conversation
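Finally, match_conversations seeds a TokenUsageAggregate per conversation and folds each answer's model_token_usage into it. A sketch of how add_model_token_usage might roll usage up both per model and in total, again building on the sketches above; keying tokens_by_model by model name is an assumption.

from typing import Dict

import pydantic


class TokenUsageAggregate(pydantic.BaseModel):
    # Matches the constructor call in the diff:
    # TokenUsageAggregate(tokens_by_model={}, token_usage=TokenUsage())
    tokens_by_model: Dict[str, TokenUsageByModel]
    token_usage: TokenUsage

    def add_model_token_usage(self, model_token_usage: TokenUsageByModel) -> None:
        existing = self.tokens_by_model.get(model_token_usage.model)
        if existing is None:
            self.tokens_by_model[model_token_usage.model] = model_token_usage
        else:
            # `+=` falls back to TokenUsage.__add__ from the earlier sketch.
            existing.token_usage += model_token_usage.token_usage
        # Keep the rolled-up total in sync; the diff later checks
        # token_usage_agg.token_usage.input_tokens == 0 to drop empty aggregates.
        self.token_usage += model_token_usage.token_usage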