Skip to content

Commit

Permalink
Merge pull request #533 from MichaelDecent/0.4.5.dev1
Browse files Browse the repository at this point in the history
[implemented parametrize features]: llms in swarmauri pkg
  • Loading branch information
cobycloud authored Sep 25, 2024
2 parents 5a6bafe + 8d36cdb commit 2eebeac
Show file tree
Hide file tree
Showing 11 changed files with 195 additions and 45 deletions.
27 changes: 21 additions & 6 deletions pkgs/swarmauri/tests/unit/llms/AI21StudioModel_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,30 @@
from swarmauri.llms.concrete.AI21StudioModel import AI21StudioModel as LLM
from swarmauri.conversations.concrete.Conversation import Conversation

from swarmauri.messages.concrete.AgentMessage import AgentMessage
from swarmauri.messages.concrete.HumanMessage import HumanMessage
from swarmauri.messages.concrete.SystemMessage import SystemMessage
from dotenv import load_dotenv

load_dotenv()

API_KEY = os.getenv("AI21STUDIO_API_KEY")


@pytest.fixture(scope="module")
def ai21studio_model():
    """Provide a shared AI21StudioModel instance for this test module.

    Skips the whole module's dependent tests when no API key is configured.
    Uses the module-level API_KEY constant (loaded via dotenv at import time)
    instead of re-reading the environment locally — the stale local
    `API_KEY = os.getenv(...)` shadowed the module constant and duplicated
    the lookup already done for get_allowed_models().
    """
    if not API_KEY:
        pytest.skip("Skipping due to environment variable not set")
    return LLM(api_key=API_KEY)


def get_allowed_models():
    """Return the provider's permitted model names, or [] when no API key is set.

    Used at collection time to feed @pytest.mark.parametrize, so it must not
    raise when credentials are absent.
    """
    if not API_KEY:
        return []
    return LLM(api_key=API_KEY).allowed_models


@pytest.mark.unit
def test_ubc_resource(ai21studio_model):
assert ai21studio_model.resource == "LLM"
Expand All @@ -41,8 +51,10 @@ def test_default_name(ai21studio_model):


@pytest.mark.unit
def test_no_system_context(ai21studio_model):
@pytest.mark.parametrize("model_name", get_allowed_models())
def test_no_system_context(ai21studio_model, model_name):
model = ai21studio_model
model.name = model_name
conversation = Conversation()

input_data = "Hello"
Expand All @@ -51,12 +63,15 @@ def test_no_system_context(ai21studio_model):

model.predict(conversation=conversation)
prediction = conversation.get_last().content
assert type(prediction) == str
assert isinstance(prediction, str)


@pytest.mark.unit
def test_preamble_system_context(ai21studio_model):
@pytest.mark.parametrize("model_name", get_allowed_models())
def test_preamble_system_context(ai21studio_model, model_name):
model = ai21studio_model
model.name = model_name

conversation = Conversation()

system_context = 'You only respond with the following phrase, "Jeff"'
Expand All @@ -70,4 +85,4 @@ def test_preamble_system_context(ai21studio_model):
model.predict(conversation=conversation)
prediction = conversation.get_last().content
assert type(prediction) == str
assert "Jeff" in prediction
assert "Jeff" in prediction, f"Test failed for model: {model_name}"
23 changes: 19 additions & 4 deletions pkgs/swarmauri/tests/unit/llms/AnthropicModel_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,30 @@
from swarmauri.llms.concrete.AnthropicModel import AnthropicModel as LLM
from swarmauri.conversations.concrete.Conversation import Conversation

from swarmauri.messages.concrete.AgentMessage import AgentMessage
from swarmauri.messages.concrete.HumanMessage import HumanMessage
from swarmauri.messages.concrete.SystemMessage import SystemMessage
from dotenv import load_dotenv

load_dotenv()

API_KEY = os.getenv("ANTHROPIC_API_KEY")


@pytest.fixture(scope="module")
def anthropic_model():
    """Provide a shared AnthropicModel instance for this test module.

    Skips dependent tests when no API key is configured. Relies on the
    module-level API_KEY constant (loaded via dotenv) rather than the stale
    local `API_KEY = os.getenv(...)` line, which shadowed the constant and
    made the skip condition diverge from get_allowed_models().
    """
    if not API_KEY:
        pytest.skip("Skipping due to environment variable not set")
    return LLM(api_key=API_KEY)


def get_allowed_models():
    """Return the provider's permitted model names, or [] when no API key is set.

    Evaluated at pytest collection time for parametrization, so it must be
    safe to call without credentials.
    """
    if not API_KEY:
        return []
    return LLM(api_key=API_KEY).allowed_models


@pytest.mark.unit
def test_ubc_resource(anthropic_model):
assert anthropic_model.resource == "LLM"
Expand All @@ -40,9 +50,12 @@ def test_default_name(anthropic_model):
assert anthropic_model.name == "claude-3-haiku-20240307"


@pytest.mark.parametrize("model_name", get_allowed_models())
@pytest.mark.unit
def test_no_system_context(anthropic_model):
def test_no_system_context(anthropic_model, model_name):
model = anthropic_model
model.name = model_name

conversation = Conversation()

input_data = "Hello"
Expand All @@ -54,9 +67,11 @@ def test_no_system_context(anthropic_model):
assert type(prediction) == str


@pytest.mark.parametrize("model_name", get_allowed_models())
@pytest.mark.unit
def test_preamble_system_context(anthropic_model):
def test_preamble_system_context(anthropic_model, model_name):
model = anthropic_model
model.name = model_name
conversation = Conversation()

system_context = 'You only respond with the following phrase, "Jeff"'
Expand Down
24 changes: 20 additions & 4 deletions pkgs/swarmauri/tests/unit/llms/CohereModel_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,30 @@
from swarmauri.llms.concrete.CohereModel import CohereModel as LLM
from swarmauri.conversations.concrete.Conversation import Conversation

from swarmauri.messages.concrete.AgentMessage import AgentMessage
from swarmauri.messages.concrete.HumanMessage import HumanMessage
from swarmauri.messages.concrete.SystemMessage import SystemMessage
from dotenv import load_dotenv

load_dotenv()

API_KEY = os.getenv("COHERE_API_KEY")


@pytest.fixture(scope="module")
def cohere_model():
    """Provide a shared CohereModel instance for this test module.

    Skips dependent tests when no API key is configured. Uses the
    module-level API_KEY constant (loaded via dotenv) — the stale local
    `API_KEY = os.getenv(...)` line shadowed it and duplicated the lookup
    performed for get_allowed_models().
    """
    if not API_KEY:
        pytest.skip("Skipping due to environment variable not set")
    return LLM(api_key=API_KEY)


def get_allowed_models():
    """Return the provider's permitted model names, or [] when no API key is set.

    Must not raise without credentials: pytest calls it during collection
    to build the parametrize list.
    """
    if not API_KEY:
        return []
    return LLM(api_key=API_KEY).allowed_models


@pytest.mark.unit
def test_ubc_resource(cohere_model):
assert cohere_model.resource == "LLM"
Expand All @@ -37,9 +47,12 @@ def test_default_name(cohere_model):
assert cohere_model.name == "command"


@pytest.mark.parametrize("model_name", get_allowed_models())
@pytest.mark.unit
def test_no_system_context(cohere_model):
def test_no_system_context(cohere_model, model_name):
model = cohere_model
model.name = model_name

conversation = Conversation()

input_data = "Hello"
Expand All @@ -51,9 +64,12 @@ def test_no_system_context(cohere_model):
assert type(prediction) == str


@pytest.mark.parametrize("model_name", get_allowed_models())
@pytest.mark.unit
def test_preamble_system_context(cohere_model):
def test_preamble_system_context(cohere_model, model_name):
model = cohere_model
model.name = model_name

conversation = Conversation()

system_context = "Jane knows Martin."
Expand Down
22 changes: 18 additions & 4 deletions pkgs/swarmauri/tests/unit/llms/DeepInfraModel_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,30 @@
from swarmauri.llms.concrete.DeepInfraModel import DeepInfraModel as LLM
from swarmauri.conversations.concrete.Conversation import Conversation

from swarmauri.messages.concrete.AgentMessage import AgentMessage
from swarmauri.messages.concrete.HumanMessage import HumanMessage
from swarmauri.messages.concrete.SystemMessage import SystemMessage
from dotenv import load_dotenv

load_dotenv()

API_KEY = os.getenv("DEEPINFRA_API_KEY")


@pytest.fixture(scope="module")
def deepinfra_model():
    """Provide a shared DeepInfraModel instance for this test module.

    Skips dependent tests when no API key is configured. Reads the
    module-level API_KEY constant (loaded via dotenv) instead of the stale
    local `API_KEY = os.getenv(...)` line, which shadowed the constant and
    could drift from the value used by get_allowed_models().
    """
    if not API_KEY:
        pytest.skip("Skipping due to environment variable not set")
    return LLM(api_key=API_KEY)


def get_allowed_models():
    """Return the provider's permitted model names, or [] when no API key is set.

    Collection-time helper for @pytest.mark.parametrize; safe without
    credentials.
    """
    if not API_KEY:
        return []
    return LLM(api_key=API_KEY).allowed_models


@pytest.mark.unit
def test_ubc_resource(deepinfra_model):
assert deepinfra_model.resource == "LLM"
Expand All @@ -40,9 +50,11 @@ def test_default_name(deepinfra_model):
assert deepinfra_model.name == "Qwen/Qwen2-72B-Instruct"


@pytest.mark.parametrize("model_name", get_allowed_models())
@pytest.mark.unit
def test_no_system_context(deepinfra_model):
def test_no_system_context(deepinfra_model, model_name):
model = deepinfra_model
model.name = model_name
conversation = Conversation()

input_data = "Hello"
Expand All @@ -54,9 +66,11 @@ def test_no_system_context(deepinfra_model):
assert type(prediction) is str


@pytest.mark.parametrize("model_name", get_allowed_models())
@pytest.mark.unit
def test_preamble_system_context(deepinfra_model):
def test_preamble_system_context(deepinfra_model, model_name):
model = deepinfra_model
model.name = model_name
conversation = Conversation()

system_context = 'You only respond with the following phrase, "Jeff"'
Expand Down
14 changes: 12 additions & 2 deletions pkgs/swarmauri/tests/unit/llms/DeepSeekModel_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,30 @@
from swarmauri.llms.concrete.DeepSeekModel import DeepSeekModel as LLM
from swarmauri.conversations.concrete.Conversation import Conversation

from swarmauri.messages.concrete.AgentMessage import AgentMessage
from swarmauri.messages.concrete.HumanMessage import HumanMessage
from swarmauri.messages.concrete.SystemMessage import SystemMessage
from dotenv import load_dotenv

load_dotenv()

API_KEY = os.getenv("DEEPSEEK_API_KEY")


@pytest.fixture(scope="module")
def deepseek_model():
    """Provide a shared DeepSeekModel instance for this test module.

    Skips dependent tests when no API key is configured. Uses the
    module-level API_KEY constant (loaded via dotenv) — the stale local
    `API_KEY = os.getenv(...)` line shadowed it and duplicated the
    environment lookup already done for get_allowed_models().
    """
    if not API_KEY:
        pytest.skip("Skipping due to environment variable not set")
    return LLM(api_key=API_KEY)


def get_allowed_models():
    """Return the provider's permitted model names, or [] when no API key is set.

    Invoked at pytest collection time to parametrize model-name tests, so
    it must degrade gracefully without credentials.
    """
    if not API_KEY:
        return []
    return LLM(api_key=API_KEY).allowed_models


@pytest.mark.unit
def test_ubc_resource(deepseek_model):
assert deepseek_model.resource == "LLM"
Expand Down
22 changes: 18 additions & 4 deletions pkgs/swarmauri/tests/unit/llms/GeminiProModel_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,30 @@
from swarmauri.llms.concrete.GeminiProModel import GeminiProModel as LLM
from swarmauri.conversations.concrete.Conversation import Conversation

from swarmauri.messages.concrete.AgentMessage import AgentMessage
from swarmauri.messages.concrete.HumanMessage import HumanMessage
from swarmauri.messages.concrete.SystemMessage import SystemMessage
from dotenv import load_dotenv

load_dotenv()

API_KEY = os.getenv("GEMINI_API_KEY")


@pytest.fixture(scope="module")
def geminipro_model():
    """Provide a shared GeminiProModel instance for this test module.

    Skips dependent tests when no API key is configured. Uses the
    module-level API_KEY constant (loaded via dotenv) rather than the stale
    local `API_KEY = os.getenv(...)` line, which shadowed the constant and
    made the skip condition diverge from get_allowed_models().
    """
    if not API_KEY:
        pytest.skip("Skipping due to environment variable not set")
    return LLM(api_key=API_KEY)


def get_allowed_models():
    """Return the provider's permitted model names, or [] when no API key is set.

    Called during pytest collection to build the parametrize list; must not
    raise when credentials are absent.
    """
    if not API_KEY:
        return []
    return LLM(api_key=API_KEY).allowed_models


@pytest.mark.unit
def test_ubc_resource(geminipro_model):
assert geminipro_model.resource == "LLM"
Expand All @@ -40,9 +50,11 @@ def test_default_name(geminipro_model):
assert geminipro_model.name == "gemini-1.5-pro-latest"


@pytest.mark.parametrize("model_name", get_allowed_models())
@pytest.mark.unit
def test_no_system_context(geminipro_model):
def test_no_system_context(geminipro_model, model_name):
model = geminipro_model
model.name = model_name
conversation = Conversation()

input_data = "Hello"
Expand All @@ -53,9 +65,11 @@ def test_no_system_context(geminipro_model):
assert type(prediction) == str


@pytest.mark.parametrize("model_name", get_allowed_models())
@pytest.mark.unit
def test_preamble_system_context(geminipro_model):
def test_preamble_system_context(geminipro_model, model_name):
model = geminipro_model
model.name = model_name
conversation = Conversation()

system_context = 'You only respond with the following phrase, "Jeff"'
Expand Down
22 changes: 18 additions & 4 deletions pkgs/swarmauri/tests/unit/llms/GroqModel_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,30 @@
from swarmauri.llms.concrete.GroqModel import GroqModel as LLM
from swarmauri.conversations.concrete.Conversation import Conversation

from swarmauri.messages.concrete.AgentMessage import AgentMessage
from swarmauri.messages.concrete.HumanMessage import HumanMessage
from swarmauri.messages.concrete.SystemMessage import SystemMessage
from dotenv import load_dotenv

load_dotenv()

API_KEY = os.getenv("GROQ_API_KEY")


@pytest.fixture(scope="module")
def groq_model():
    """Provide a shared GroqModel instance for this test module.

    Skips dependent tests when no API key is configured. Uses the
    module-level API_KEY constant (loaded via dotenv) — the stale local
    `API_KEY = os.getenv(...)` line shadowed it and duplicated the lookup
    already performed for get_allowed_models().
    """
    if not API_KEY:
        pytest.skip("Skipping due to environment variable not set")
    return LLM(api_key=API_KEY)


def get_allowed_models():
    """Return the provider's permitted model names, or [] when no API key is set.

    Collection-time helper feeding @pytest.mark.parametrize; safe to call
    without credentials.
    """
    if not API_KEY:
        return []
    return LLM(api_key=API_KEY).allowed_models


@pytest.mark.unit
def test_ubc_resource(groq_model):
assert groq_model.resource == "LLM"
Expand All @@ -37,9 +47,11 @@ def test_default_name(groq_model):
assert groq_model.name == "gemma-7b-it"


@pytest.mark.parametrize("model_name", get_allowed_models())
@pytest.mark.unit
def test_no_system_context(groq_model):
def test_no_system_context(groq_model, model_name):
model = groq_model
model.name = model_name
conversation = Conversation()

input_data = "Hello"
Expand All @@ -51,9 +63,11 @@ def test_no_system_context(groq_model):
assert type(prediction) == str


@pytest.mark.parametrize("model_name", get_allowed_models())
@pytest.mark.unit
def test_preamble_system_context(groq_model):
def test_preamble_system_context(groq_model, model_name):
model = groq_model
model.name = model_name
conversation = Conversation()

system_context = 'You only respond with the following phrase, "Jeff"'
Expand Down
Loading

0 comments on commit 2eebeac

Please sign in to comment.