Add tests for Gemini via openai package #382

Draft · wants to merge 6 commits into main
1 change: 1 addition & 0 deletions .env.template
@@ -1,5 +1,6 @@
 # To set API keys, copy this as .env and fill in the values
 # API keys are required to rewrite the VCR cassettes for the tests
 ANTHROPIC_API_KEY="anthropic_api_key"
+GEMINI_API_KEY="gemini_api_key"
 MISTRAL_API_KEY="mistral_api_key"
 OPENAI_API_KEY="openai_api_key"
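For context (not part of the diff): a minimal sketch of how these values typically reach the test process, assuming python-dotenv; magentic's actual test setup may load them differently.

# Hypothetical sketch, assuming python-dotenv is installed: copy
# .env.template to .env, fill in the keys, then load them at startup.
import os

from dotenv import load_dotenv

load_dotenv()  # reads key=value pairs from ./.env into os.environ
assert os.environ.get("GEMINI_API_KEY"), "GEMINI_API_KEY missing from .env"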
1 change: 1 addition & 0 deletions pyproject.toml
@@ -91,6 +91,7 @@ markers = [
"litellm_openai: Tests that query the OpenAI API via litellm. Requires the OPENAI_API_KEY environment variable to be set.",
"mistral: Tests that query the Mistral API (via openai). Requires the MISTRAL_API_KEY environment variable to be set.",
"openai: Tests that query the OpenAI API. Requires the OPENAI_API_KEY environment variable to be set.",
"openai_gemini: Tests that query the Gemini API via openai. Requires the GEMINI_API_KEY environment variable to be set.",
"openai_ollama: Tests that query Ollama via openai. Requires ollama to be installed and running on localhost:11434.",
]

5 changes: 4 additions & 1 deletion src/magentic/chat_model/openai_chat_model.py
@@ -350,7 +350,10 @@ def update(self, item: ChatCompletionChunk) -> None:
             tool_call_chunk.index = self._current_tool_call_index
         self._chat_completion_stream_state.handle_chunk(item)
         if item.usage:
-            assert not self.usage_ref
+            # Only keep the last usage
+            # Gemini openai-compatible API includes usage in all streamed chunks
+            # but OpenAI only includes this in the last chunk
+            self.usage_ref.clear()
             self.usage_ref.append(
                 Usage(
                     input_tokens=item.usage.prompt_tokens,
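For context (not part of the diff): a minimal sketch of the streaming behaviour the new comment describes, using the openai package directly against Gemini's OpenAI-compatible endpoint. The `stream_options` flag is OpenAI's opt-in for usage chunks; the claim that Gemini attaches usage to every chunk is this PR's observation, not verified here.

import os

from openai import OpenAI

client = OpenAI(
    api_key=os.environ["GEMINI_API_KEY"],
    base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
)
stream = client.chat.completions.create(
    model="gemini-1.5-flash",
    messages=[{"role": "user", "content": "Say hello!"}],
    stream=True,
    stream_options={"include_usage": True},  # required for usage with OpenAI
)
usage = None
for chunk in stream:
    if chunk.usage:
        usage = chunk.usage  # keep only the most recent usage, as the patch does
print(usage)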
52 changes: 52 additions & 0 deletions tests/chat_model/test_openai_chat_model_gemini.py
@@ -0,0 +1,52 @@
+import os
+
+import pytest
+
+from magentic.chat_model.message import UserMessage
+from magentic.chat_model.openai_chat_model import OpenaiChatModel
+
+
+@pytest.mark.parametrize(
+    ("prompt", "output_types", "expected_output_type"),
+    [
+        ("Say hello!", [str], str),
+        ("Return True.", [bool], bool),
+        ("Return [1, 2, 3, 4, 5]", [list[int]], list),
+        ("Return a list of fruit", [list[str]], list),
+    ],
+)
+@pytest.mark.openai_gemini
+def test_openai_chat_model_complete_gemini(prompt, output_types, expected_output_type):
+    chat_model = OpenaiChatModel(
+        "gemini-1.5-flash",
+        api_key=os.environ["GEMINI_API_KEY"],
+        base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
+    )
+    message = chat_model.complete(
+        messages=[UserMessage(prompt)], output_types=output_types
+    )
+    assert isinstance(message.content, expected_output_type)
+
+
+@pytest.mark.parametrize(
+    ("prompt", "output_types", "expected_output_type"),
+    [
+        ("Say hello!", [str], str),
+        ("Return True.", [bool], bool),
+        ("Return [1, 2, 3, 4, 5]", [list[int]], list),
+        ("Return a list of fruit", [list[str]], list),
+    ],
+)
+@pytest.mark.openai_gemini
+async def test_openai_chat_model_acomplete_gemini(
+    prompt, output_types, expected_output_type
+):
+    chat_model = OpenaiChatModel(
+        "gemini-1.5-flash",
+        api_key=os.environ["GEMINI_API_KEY"],
+        base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
+    )
+    message = await chat_model.acomplete(
+        messages=[UserMessage(prompt)], output_types=output_types
+    )
+    assert isinstance(message.content, expected_output_type)
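For context (not part of the diff): the same configuration in ordinary magentic usage. A sketch assuming magentic's public `prompt` decorator; the model name and endpoint mirror the tests above, and `list_fruits` is a hypothetical example function.

import os

from magentic import prompt
from magentic.chat_model.openai_chat_model import OpenaiChatModel

gemini_model = OpenaiChatModel(
    "gemini-1.5-flash",
    api_key=os.environ["GEMINI_API_KEY"],
    base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
)


@prompt("Return a list of three fruits.", model=gemini_model)
def list_fruits() -> list[str]: ...


print(list_fruits())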
1 change: 1 addition & 0 deletions tests/conftest.py
@@ -75,6 +75,7 @@ def pytest_collection_modifyitems(
"litellm_openai",
"mistral",
"openai",
"openai_gemini",
"openai_ollama",
]
for item in items:
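For context (not part of the diff): the conftest.py hunk above is truncated, so here is a plausible reconstruction of what a hook like this usually does, namely skipping live-API tests unless their marker is explicitly selected. This is a guess at the common pattern, not magentic's actual implementation.

import pytest

LIVE_API_MARKERS = [
    "anthropic",  # assumed; only the markers visible in the hunk are certain
    "litellm_openai",
    "mistral",
    "openai",
    "openai_gemini",
    "openai_ollama",
]


def pytest_collection_modifyitems(config, items):
    # Skip any test carrying a live-API marker unless that marker was
    # requested via `pytest -m <marker>`.
    selected = config.getoption("-m") or ""
    for item in items:
        for marker in LIVE_API_MARKERS:
            if marker in item.keywords and marker not in selected:
                item.add_marker(
                    pytest.mark.skip(reason=f"live API test; run with -m {marker}")
                )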