Fix: LLMs test cases and added new xfail test case
Signed-off-by: RafaelJohn9 <[email protected]>
RafaelJohn9 committed Sep 26, 2024
1 parent fc9a7df commit 96ac9d4
Showing 4 changed files with 33 additions and 4 deletions.
3 changes: 1 addition & 2 deletions pkgs/swarmauri/swarmauri/llms/concrete/PerplexityModel.py
@@ -11,7 +11,7 @@
 class PerplexityModel(LLMBase):
     """
     Provider resources: https://docs.perplexity.ai/guides/model-cards
-    Link to depreciated models: https://docs.perplexity.ai/changelog/changelog#model-deprecation-notice
+    Link to deprecated models: https://docs.perplexity.ai/changelog/changelog#model-deprecation-notice
     """
 
     api_key: str
@@ -54,7 +54,6 @@ def predict(
         presence_penalty: Optional[float] = None,
         frequency_penalty: Optional[float] = None,
     ):
-
         if top_p and top_k:
             raise ValueError("Do not set top_p and top_k")
 
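For context, the predict hunk above shows the guard that keeps nucleus sampling (top_p) and top-k sampling mutually exclusive, matching the Perplexity API's expectation of one or the other. Below is a self-contained sketch of the same validation pattern, not the swarmauri implementation itself:

from typing import Optional

def validate_sampling(top_p: Optional[float] = None, top_k: Optional[int] = None) -> None:
    # Mirrors the check in PerplexityModel.predict: pass either nucleus
    # sampling (top_p) or top-k sampling, never both.
    if top_p and top_k:
        raise ValueError("Do not set top_p and top_k")

validate_sampling(top_p=0.9)   # fine: nucleus sampling only
validate_sampling(top_k=40)    # fine: top-k sampling only
# validate_sampling(top_p=0.9, top_k=40)  # would raise ValueError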
4 changes: 2 additions & 2 deletions pkgs/swarmauri/tests/integration/llms/CohereModel_i9n_test.py
@@ -40,7 +40,7 @@ def test_nonpreamble_system_context(cohere_model):
     human_message = HumanMessage(content=input_data)
     conversation.add_message(human_message)
 
-    model.predict(conversation=conversation)
+    model.predict(conversation=conversation, temperature=0)
     prediction = conversation.get_last().content
     assert "Jeff" in prediction
 
@@ -68,7 +68,7 @@ def test_multiple_system_contexts(cohere_model):
     human_message = HumanMessage(content=input_data_2)
     conversation.add_message(human_message)
 
-    model.predict(conversation=conversation)
+    model.predict(conversation=conversation, temperature=0)
     prediction = conversation.get_last().content
     assert type(prediction) == str
     assert "Ben" in prediction
30 changes: 30 additions & 0 deletions pkgs/swarmauri/tests/unit/llms/GroqModel_unit_test.py
@@ -24,6 +24,7 @@ def get_allowed_models():
     if not API_KEY:
         return []
     llm = LLM(api_key=API_KEY)
+    llm.allowed_models.remove("llava-v1.5-7b-4096-preview")
     return llm.allowed_models
 
 
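Dropping llava-v1.5-7b-4096-preview from get_allowed_models() keeps it out of whatever parametrized tests consume that list, leaving it to the dedicated test added in the next hunk. A sketch of the usual pytest pattern such a helper feeds; the decorator usage here is an assumption, since the rest of the module is not shown in this diff:

import pytest

@pytest.mark.unit
@pytest.mark.parametrize("model_name", get_allowed_models())
def test_shared_behaviour(groq_model, model_name):
    # groq_model is the module's LLM fixture; after this change, every allowed
    # model except llava-v1.5-7b-4096-preview flows through tests like this one.
    groq_model.name = model_name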
@@ -82,3 +83,32 @@ def test_preamble_system_context(groq_model, model_name):
     prediction = conversation.get_last().content
     assert type(prediction) == str
     assert "Jeff" in prediction
+
+
+@pytest.mark.unit
+def test_preamble_system_context_custom():
+    """
+    specifically for llava-v1.5-7b-4096-preview
+    """
+    if not API_KEY:
+        return []
+
+    groq_model = LLM(api_key=API_KEY)
+    model_name = "llava-v1.5-7b-4096-preview"
+
+    model = groq_model
+    model.name = model_name
+    conversation = Conversation()
+
+    system_context = 'You only respond with the following phrase, "Jeff"'
+    human_message = SystemMessage(content=system_context)
+    conversation.add_message(human_message)
+
+    input_data = "Hi"
+    human_message = HumanMessage(content=input_data)
+    conversation.add_message(human_message)
+
+    model.predict(conversation=conversation)
+    prediction = conversation.get_last().content
+    assert type(prediction) == str
+    assert "Jeff" in prediction
