Skip to content

Commit

Permalink
Merge pull request #531 from keetrap/dev
Browse files: browse the repository at this point in the history
Added TokenCountEstimatorMetric
  • Loading branch information
cobycloud authored Sep 25, 2024
2 parents 3f68d19 + d478869 commit c0a3282
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 1 deletion.
4 changes: 3 additions & 1 deletion pkgs/community/setup.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -64,7 +64,9 @@
"tf-keras",
"pinecone",
"neo4j",
"pinecone"
"tiktoken"


]
},
classifiers=[
Expand Down
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,28 @@
from typing import Any, Literal, Optional

import tiktoken

from swarmauri.metrics.base.MetricBase import MetricBase
from swarmauri.metrics.base.MetricCalculateMixin import MetricCalculateMixin

class TokenCountEstimatorMetric(MetricBase, MetricCalculateMixin):
    """
    A metric class to estimate the number of tokens in a given text.

    Tokenization is delegated to ``tiktoken``, so the reported count depends
    on the encoding name passed to :meth:`calculate`.
    """

    # Unit label reported for values produced by this metric.
    unit: str = "tokens"
    # Literal tag whose only permitted value is this class's name.
    type: Literal['TokenCountEstimatorMetric'] = 'TokenCountEstimatorMetric'

    def calculate(self, text: str, encoding: str = 'cl100k_base') -> Optional[int]:
        """
        Calculate the number of tokens in the given text.

        Args:
            text (str): The input text to calculate token count for.
            encoding (str): Name of the tiktoken encoding used to tokenize
                ``text``. Defaults to ``'cl100k_base'``.

        Returns:
            Optional[int]: The number of tokens in the text, or ``None`` if
            ``tiktoken.get_encoding`` raises ``ValueError`` for the requested
            encoding name (the error is printed to stdout in that case).
        """
        try:
            # Bind the resolved encoder to its own name so the ``encoding``
            # argument keeps meaning "encoding name" for the whole method.
            encoder = tiktoken.get_encoding(encoding)
        except ValueError as e:
            # Best-effort contract: report the problem and signal failure
            # with None rather than propagating the exception to callers.
            print(f"Error: {e}")
            return None

        return len(encoder.encode(text))
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,32 @@
import pytest
from swarmauri.metrics.concrete.TokenCountEstimatorMetric import TokenCountEstimatorMetric as Metric

@pytest.mark.unit
def test_ubc_resource():
    """A freshly built metric reports 'Metric' as its resource."""
    instance = Metric()
    assert instance.resource == 'Metric'

@pytest.mark.unit
def test_ubc_type():
    """A freshly built metric reports its own class name as its type."""
    expected_type = 'TokenCountEstimatorMetric'
    assert Metric().type == expected_type

@pytest.mark.unit
def test_serialization():
    """Dumping a metric to JSON and validating it back keeps the same id."""
    original = Metric()
    payload = original.model_dump_json()
    restored = Metric.model_validate_json(payload)
    assert original.id == restored.id


@pytest.mark.unit
def test_metric_value():
    """The sample sentence is counted as 11 tokens when no encoding is passed."""
    sample = "Lorem ipsum odor amet, consectetuer adipiscing elit."
    expected_count = 11
    assert Metric().calculate(sample) == expected_count


@pytest.mark.unit
def test_metric_unit():
    """A freshly built metric reports "tokens" as its unit."""
    instance = Metric()
    assert instance.unit == "tokens"

0 comments on commit c0a3282

Please sign in to comment.