
Commit 604f28c

Include the token usage for every conversation and workspace
Related: #418

This PR introduces the changes necessary to track the tokens used per request and then process them so they can be returned by the API. Specific changes:

- Make sure we process the whole stream and record to the database at the very end.
- Include the flag `"stream_options": {"include_usage": True}` so the providers respond with the token counts (see the example request below).
- Add the necessary processing for the API.
- Modify the initial API models to display the tokens and their price correctly.
1 parent f11e28d commit 604f28c
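For reference, `stream_options` is the OpenAI-style request parameter that asks a streaming provider to append a final chunk with token counts. A minimal sketch of the kind of payload involved (illustrative model name and values, not code taken from this PR):

```python
# Illustrative only: the real request construction happens inside the provider pipeline.
request_body = {
    "model": "gpt-4o-mini",  # example model name
    "messages": [{"role": "user", "content": "Hello"}],
    "stream": True,
    # The flag added by this PR; providers then report usage in the final streamed chunk.
    "stream_options": {"include_usage": True},
}

# With the flag set, the last chunk of the stream typically carries a usage block such as:
# {"usage": {"prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46}}
```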

15 files changed, +169 -72 lines changed

src/codegate/api/v1.py (+2 −2)

@@ -477,11 +477,11 @@ def version_check():
     tags=["Workspaces", "Token Usage"],
     generate_unique_id_function=uniq_name,
 )
-async def get_workspace_token_usage(workspace_name: str) -> v1_models.TokenUsage:
+async def get_workspace_token_usage(workspace_name: str) -> v1_models.TokenUsageAggregate:
     """Get the token usage of a workspace."""
     # TODO: This is a dummy implementation. In the future, we should have a proper
     # implementation that fetches the token usage from the database.
-    return v1_models.TokenUsage(
+    return v1_models.TokenUsageAggregate(
         used_tokens=50,
         tokens_by_model=[
             v1_models.TokenUsageByModel(

src/codegate/api/v1_models.py (+70 −15)

@@ -1,12 +1,17 @@
 import datetime
 from enum import Enum
-from typing import Any, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union

 import pydantic
+import requests
+from cachetools import TTLCache

 from codegate.db import models as db_models
 from codegate.pipeline.base import CodeSnippet

+# 1 day cache. Not keep all the models in the cache. Just the ones we have used recently.
+model_cost_cache = TTLCache(maxsize=2000, ttl=1 * 24 * 60 * 60)
+

 class Workspace(pydantic.BaseModel):
     name: str
@@ -105,15 +110,6 @@ class PartialQuestions(pydantic.BaseModel):
     type: QuestionType


-class PartialQuestionAnswer(pydantic.BaseModel):
-    """
-    Represents a partial conversation.
-    """
-
-    partial_questions: PartialQuestions
-    answer: Optional[ChatMessage]
-
-
 class ProviderType(str, Enum):
     """
     Represents the different types of providers we support.
@@ -124,24 +120,83 @@ class ProviderType(str, Enum):
     vllm = "vllm"


+class TokenUsage(pydantic.BaseModel):
+    input_tokens: int = 0
+    output_tokens: int = 0
+    input_cost: float = 0
+    output_cost: float = 0
+
+    @classmethod
+    def from_dict(cls, usage_dict: Dict) -> "TokenUsage":
+        return cls(
+            input_tokens=usage_dict.get("prompt_tokens", 0) or usage_dict.get("input_tokens", 0),
+            output_tokens=usage_dict.get("completion_tokens", 0)
+            or usage_dict.get("output_tokens", 0),
+            input_cost=0,
+            output_cost=0,
+        )
+
+    def __add__(self, other: "TokenUsage") -> "TokenUsage":
+        return TokenUsage(
+            input_tokens=self.input_tokens + other.input_tokens,
+            output_tokens=self.output_tokens + other.output_tokens,
+            input_cost=self.input_cost + other.input_cost,
+            output_cost=self.output_cost + other.output_cost,
+        )
+
+    def update_token_cost(self, model: str) -> None:
+        if not model_cost_cache:
+            model_cost = requests.get(
+                "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
+            )
+            model_cost_cache.update(model_cost.json())
+        model_cost = model_cost_cache.get(model, {})
+        input_cost_per_token = model_cost.get("input_cost_per_token", 0)
+        output_cost_per_token = model_cost.get("output_cost_per_token", 0)
+        self.input_cost = self.input_tokens * input_cost_per_token
+        self.output_cost = self.output_tokens * output_cost_per_token
+
+    def update_costs_based_on_model(self, model: str):
+        pass
+
+
 class TokenUsageByModel(pydantic.BaseModel):
     """
     Represents the tokens used by a model.
     """

     provider_type: ProviderType
     model: str
-    used_tokens: int
+    token_usage: TokenUsage


-class TokenUsage(pydantic.BaseModel):
+class TokenUsageAggregate(pydantic.BaseModel):
     """
     Represents the tokens used. Includes the information of the tokens used by model.
     `used_tokens` are the total tokens used in the `tokens_by_model` list.
     """

-    tokens_by_model: List[TokenUsageByModel]
-    used_tokens: int
+    tokens_by_model: Dict[str, TokenUsageByModel]
+    token_usage: TokenUsage
+
+    def add_model_token_usage(self, model_token_usage: TokenUsageByModel) -> None:
+        if model_token_usage.model in self.tokens_by_model:
+            self.tokens_by_model[
+                model_token_usage.model
+            ].token_usage += model_token_usage.token_usage
+        else:
+            self.tokens_by_model[model_token_usage.model] = model_token_usage
+        self.token_usage += model_token_usage.token_usage
+
+
+class PartialQuestionAnswer(pydantic.BaseModel):
+    """
+    Represents a partial conversation.
+    """
+
+    partial_questions: PartialQuestions
+    answer: Optional[ChatMessage]
+    model_token_usage: TokenUsageByModel


 class Conversation(pydantic.BaseModel):
@@ -154,7 +209,7 @@ class Conversation(pydantic.BaseModel):
     type: QuestionType
     chat_id: str
     conversation_timestamp: datetime.datetime
-    token_usage: Optional[TokenUsage]
+    token_usage_agg: Optional[TokenUsageAggregate]


 class AlertConversation(pydantic.BaseModel):
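A short usage sketch of the new models above (built only from the definitions in this diff; `"openai"` as a `ProviderType` value and the model name are assumptions for illustration):

```python
from codegate.api import v1_models

# Two usage payloads for the same model, in OpenAI-style and Anthropic-style key naming.
usage_a = v1_models.TokenUsage.from_dict({"prompt_tokens": 10, "completion_tokens": 20})
usage_b = v1_models.TokenUsage.from_dict({"input_tokens": 5, "output_tokens": 7})

by_model = v1_models.TokenUsageByModel(
    provider_type="openai",         # assumed ProviderType value
    model="gpt-4o-mini",            # example model name
    token_usage=usage_a + usage_b,  # __add__ sums tokens and costs field by field
)

agg = v1_models.TokenUsageAggregate(tokens_by_model={}, token_usage=v1_models.TokenUsage())
agg.add_model_token_usage(by_model)  # keyed by model name; repeated models accumulate
```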

src/codegate/api/v1_processing.py (+47 −20)

@@ -14,6 +14,9 @@
     PartialQuestionAnswer,
     PartialQuestions,
     QuestionAnswer,
+    TokenUsage,
+    TokenUsageAggregate,
+    TokenUsageByModel,
 )
 from codegate.db.connection import alert_queue
 from codegate.db.models import Alert, GetPromptWithOutputsRow
@@ -57,16 +60,17 @@ async def _is_system_prompt(message: str) -> bool:
     return False


-async def parse_request(request_str: str) -> Optional[str]:
+async def parse_request(request_str: str) -> Tuple[Optional[List[str]], str]:
     """
-    Parse the request string from the pipeline and return the message.
+    Parse the request string from the pipeline and return the message and the model.
     """
     try:
         request = json.loads(request_str)
     except Exception as e:
         logger.warning(f"Error parsing request: {request_str}. {e}")
-        return None
+        return None, ""

+    model = request.get("model", "")
     messages = []
     for message in request.get("messages", []):
         role = message.get("role")
@@ -91,57 +95,60 @@ async def parse_request(request_str: str) -> Optional[str]:
         if message_prompt and not await _is_system_prompt(message_prompt):
             messages.append(message_prompt)

-    # If still we don't have anything, return empty string
+    # If still we don't have anything, return None string
     if not messages:
-        return None
+        return None, model

-    # Only respond with the latest message
-    return messages
+    # Respond with the messages and the model
+    return messages, model


-async def parse_output(output_str: str) -> Optional[str]:
+async def parse_output(output_str: str) -> Tuple[Optional[str], TokenUsage]:
     """
     Parse the output string from the pipeline and return the message.
     """
     try:
         if output_str is None:
-            return None
+            return None, TokenUsage()

         output = json.loads(output_str)
     except Exception as e:
         logger.warning(f"Error parsing output: {output_str}. {e}")
-        return None
+        return None, TokenUsage()

-    def _parse_single_output(single_output: dict) -> str:
+    def _parse_single_output(single_output: dict) -> Tuple[str, TokenUsage]:
         single_output_message = ""
         for choice in single_output.get("choices", []):
             if not isinstance(choice, dict):
                 continue
             content_dict = choice.get("delta", {}) or choice.get("message", {})
             single_output_message += content_dict.get("content", "")
-        return single_output_message
+        return single_output_message, TokenUsage.from_dict(single_output.get("usage", {}))

     full_output_message = ""
+    full_token_usage = TokenUsage()
     if isinstance(output, list):
         for output_chunk in output:
             output_message = ""
+            token_usage = TokenUsage()
             if isinstance(output_chunk, dict):
-                output_message = _parse_single_output(output_chunk)
+                output_message, token_usage = _parse_single_output(output_chunk)
             elif isinstance(output_chunk, str):
                 try:
                     output_decoded = json.loads(output_chunk)
-                    output_message = _parse_single_output(output_decoded)
+                    output_message, token_usage = _parse_single_output(output_decoded)
                 except Exception:
                     logger.error(f"Error reading chunk: {output_chunk}")
             else:
                 logger.warning(
                     f"Could not handle output: {output_chunk}", out_type=type(output_chunk)
                 )
             full_output_message += output_message
+            full_token_usage += token_usage
     elif isinstance(output, dict):
-        full_output_message = _parse_single_output(output)
+        full_output_message, full_token_usage = _parse_single_output(output)

-    return full_output_message
+    return full_output_message, full_token_usage


 async def _get_question_answer(row: GetPromptWithOutputsRow) -> Optional[PartialQuestionAnswer]:
@@ -154,8 +161,8 @@ async def _get_question_answer(row: GetPromptWithOutputsRow) -> Optional[Partial
         request_task = tg.create_task(parse_request(row.request))
         output_task = tg.create_task(parse_output(row.output))

-    request_user_msgs = request_task.result()
-    output_msg_str = output_task.result()
+    request_user_msgs, model = request_task.result()
+    output_msg_str, token_usage = output_task.result()

     # If we couldn't parse the request, return None
     if not request_user_msgs:
@@ -176,7 +183,23 @@
         )
     else:
         output_message = None
-    return PartialQuestionAnswer(partial_questions=request_message, answer=output_message)
+
+    # Use the model to update the token cost
+    token_usage.update_token_cost(model)
+    provider = row.provider
+    # TODO: This should come from the database. For now, we are manually changing copilot to openai
+    # Change copilot provider to openai
+    if provider == "copilot":
+        provider = "openai"
+    model_token_usage = TokenUsageByModel(
+        model=model, token_usage=token_usage, provider_type=provider
+    )
+
+    return PartialQuestionAnswer(
+        partial_questions=request_message,
+        answer=output_message,
+        model_token_usage=model_token_usage,
+    )


 def parse_question_answer(input_text: str) -> str:
@@ -304,6 +327,7 @@ async def match_conversations(
     map_q_id_to_conversation = {}
     for group in grouped_partial_questions:
         questions_answers: List[QuestionAnswer] = []
+        token_usage_agg = TokenUsageAggregate(tokens_by_model={}, token_usage=TokenUsage())
         first_partial_qa = None
         for partial_question in sorted(group, key=lambda x: x.timestamp):
             # Partial questions don't contain the answer, so we need to find the corresponding
@@ -322,16 +346,19 @@
             qa = _get_question_answer_from_partial(selected_partial_qa)
             qa.question.message = parse_question_answer(qa.question.message)
             questions_answers.append(qa)
+            token_usage_agg.add_model_token_usage(selected_partial_qa.model_token_usage)

         # only add conversation if we have some answers
         if len(questions_answers) > 0 and first_partial_qa is not None:
+            if token_usage_agg.token_usage.input_tokens == 0:
+                token_usage_agg = None
             conversation = Conversation(
                 question_answers=questions_answers,
                 provider=first_partial_qa.partial_questions.provider,
                 type=first_partial_qa.partial_questions.type,
                 chat_id=first_partial_qa.partial_questions.message_id,
                 conversation_timestamp=first_partial_qa.partial_questions.timestamp,
-                token_usage=None,
+                token_usage_agg=token_usage_agg,
             )
             for qa in questions_answers:
                 map_q_id_to_conversation[qa.question.message_id] = conversation
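To make the new `parse_output` contract concrete, a rough sketch of how a stored streaming output would now be parsed (example data; assumes the module is importable as `codegate.api.v1_processing`):

```python
import asyncio
import json

from codegate.api.v1_processing import parse_output

# Example of a stored stream: two content chunks plus a final chunk carrying usage.
stored_output = json.dumps(
    [
        {"choices": [{"delta": {"content": "Hello"}}]},
        {"choices": [{"delta": {"content": " world"}}]},
        {"choices": [], "usage": {"prompt_tokens": 12, "completion_tokens": 2}},
    ]
)

message, usage = asyncio.run(parse_output(stored_output))
# message == "Hello world"; usage.input_tokens == 12, usage.output_tokens == 2
```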

src/codegate/inference/inference_engine.py (+1 −1)

@@ -35,7 +35,7 @@ def _close_models(self):
             model._sampler.close()
             model.close()

-    async def __get_model(self, model_path, embedding=False, n_ctx=512, n_gpu_layers=0):
+    async def __get_model(self, model_path, embedding=False, n_ctx=512, n_gpu_layers=0) -> Llama:
        """
        Returns Llama model object from __models if present. Otherwise, the model
        is loaded and added to __models and returned.

src/codegate/pipeline/output.py (+5 −11)

@@ -113,9 +113,6 @@ def _store_chunk_content(self, chunk: ModelResponse) -> None:
             if choice.delta is not None and choice.delta.content is not None:
                 self._context.processed_content.append(choice.delta.content)

-    async def _record_to_db(self):
-        await self._db_recorder.record_context(self._input_context)
-
     async def process_stream(
         self, stream: AsyncIterator[ModelResponse], cleanup_sensitive: bool = True
     ) -> AsyncIterator[ModelResponse]:
@@ -144,13 +141,6 @@ async def process_stream(

                 current_chunks = processed_chunks

-                # **Needed for Copilot**. This is a hacky way of recording in DB the context
-                # when we see the last chunk. Ideally this should be done in a `finally` or
-                # `StopAsyncIteration` but Copilot streams in an infite while loop so is not
-                # possible
-                if len(chunk.choices) > 0 and chunk.choices[0].get("finish_reason", "") == "stop":
-                    await self._record_to_db()
-
                 # Yield all processed chunks
                 for c in current_chunks:
                     self._store_chunk_content(c)
@@ -164,12 +154,13 @@ async def process_stream(
         finally:
             # Don't flush the buffer if we assume we'll call the pipeline again
             if cleanup_sensitive is False:
+                await self._db_recorder.record_context(self._input_context)
                 return

             # Process any remaining content in buffer when stream ends
             if self._context.buffer:
                 final_content = "".join(self._context.buffer)
-                yield ModelResponse(
+                chunk = ModelResponse(
                     id=self._buffered_chunk.id,
                     choices=[
                         StreamingChoices(
@@ -185,8 +176,11 @@ async def process_stream(
                     model=self._buffered_chunk.model,
                     object="chat.completion.chunk",
                 )
+                self._input_context.add_output(chunk)
+                yield chunk
                 self._context.buffer.clear()

+            await self._db_recorder.record_context(self._input_context)
            # Cleanup sensitive data through the input context
            if cleanup_sensitive and self._input_context and self._input_context.sensitive:
                self._input_context.sensitive.secure_cleanup()
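The change above moves database recording out of the `finish_reason == "stop"` check and into the `finally` path, once the stream has been fully consumed. This matters for token tracking because, with `include_usage` set, the usage-bearing chunk generally arrives after the "stop" chunk, so recording on "stop" would miss it. A minimal sketch of the pattern (simplified names, not the actual pipeline classes):

```python
from typing import Any, AsyncIterator


async def process_stream(stream: AsyncIterator[Any], recorder, context) -> AsyncIterator[Any]:
    """Simplified sketch: consume the whole stream, then record once at the very end."""
    try:
        async for chunk in stream:
            context.add_output(chunk)  # the trailing usage chunk is captured as well
            yield chunk
    finally:
        # Recording here, rather than on the "stop" chunk, ensures the usage chunk has
        # already been added to the context before it is persisted.
        await recorder.record_context(context)
```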

src/codegate/providers/base.py (−2)

@@ -193,8 +193,6 @@ async def _cleanup_after_streaming(
                 yield item
         finally:
             if context:
-                # Record to DB the objects captured during the stream
-                await self._db_recorder.record_context(context)
                 # Ensure sensitive data is cleaned up after the stream is consumed
                 if context.sensitive:
                     context.sensitive.secure_cleanup()
