diff --git a/.env.template b/.env.template
index 5d3d862..84167fa 100644
--- a/.env.template
+++ b/.env.template
@@ -1,5 +1,6 @@
 # To set API keys, copy this as .env and fill in the values
 # API keys are required to rewrite the VCR cassettes for the tests
 ANTHROPIC_API_KEY="anthropic_api_key"
+GEMINI_API_KEY="gemini_api_key"
 MISTRAL_API_KEY="mistral_api_key"
 OPENAI_API_KEY="openai_api_key"
diff --git a/pyproject.toml b/pyproject.toml
index 4326802..da3d8eb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -91,6 +91,7 @@ markers = [
     "litellm_openai: Tests that query the OpenAI API via litellm. Requires the OPENAI_API_KEY environment variable to be set.",
     "mistral: Tests that query the Mistral API (via openai). Requires the MISTRAL_API_KEY environment variable to be set.",
     "openai: Tests that query the OpenAI API. Requires the OPENAI_API_KEY environment variable to be set.",
+    "openai_gemini: Tests that query the Gemini API via openai. Requires the GEMINI_API_KEY environment variable to be set.",
     "openai_ollama: Tests that query Ollama via openai. Requires ollama to be installed and running on localhost:11434.",
 ]
diff --git a/src/magentic/chat_model/openai_chat_model.py b/src/magentic/chat_model/openai_chat_model.py
index 468fd30..35bac9a 100644
--- a/src/magentic/chat_model/openai_chat_model.py
+++ b/src/magentic/chat_model/openai_chat_model.py
@@ -350,7 +350,10 @@ def update(self, item: ChatCompletionChunk) -> None:
                 tool_call_chunk.index = self._current_tool_call_index
         self._chat_completion_stream_state.handle_chunk(item)
         if item.usage:
-            assert not self.usage_ref
+            # Only keep the last usage
+            # Gemini openai-compatible API includes usage in all streamed chunks
+            # but OpenAI only includes this in the last chunk
+            self.usage_ref.clear()
             self.usage_ref.append(
                 Usage(
                     input_tokens=item.usage.prompt_tokens,
diff --git a/tests/chat_model/test_openai_chat_model_gemini.py b/tests/chat_model/test_openai_chat_model_gemini.py
new file mode 100644
index 0000000..c2ab7b0
--- /dev/null
+++ b/tests/chat_model/test_openai_chat_model_gemini.py
@@ -0,0 +1,52 @@
+import os
+
+import pytest
+
+from magentic.chat_model.message import UserMessage
+from magentic.chat_model.openai_chat_model import OpenaiChatModel
+
+
+@pytest.mark.parametrize(
+    ("prompt", "output_types", "expected_output_type"),
+    [
+        ("Say hello!", [str], str),
+        ("Return True.", [bool], bool),
+        ("Return [1, 2, 3, 4, 5]", [list[int]], list),
+        ("Return a list of fruit", [list[str]], list),
+    ],
+)
+@pytest.mark.openai_gemini
+def test_openai_chat_model_complete_gemini(prompt, output_types, expected_output_type):
+    chat_model = OpenaiChatModel(
+        "gemini-1.5-flash",
+        api_key=os.environ["GEMINI_API_KEY"],
+        base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
+    )
+    message = chat_model.complete(
+        messages=[UserMessage(prompt)], output_types=output_types
+    )
+    assert isinstance(message.content, expected_output_type)
+
+
+@pytest.mark.parametrize(
+    ("prompt", "output_types", "expected_output_type"),
+    [
+        ("Say hello!", [str], str),
+        ("Return True.", [bool], bool),
+        ("Return [1, 2, 3, 4, 5]", [list[int]], list),
+        ("Return a list of fruit", [list[str]], list),
+    ],
+)
+@pytest.mark.openai_gemini
+async def test_openai_chat_model_acomplete_gemini(
+    prompt, output_types, expected_output_type
+):
+    chat_model = OpenaiChatModel(
+        "gemini-1.5-flash",
+        api_key=os.environ["GEMINI_API_KEY"],
+        base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
+    )
+    message = await chat_model.acomplete(
+        messages=[UserMessage(prompt)], output_types=output_types
+    )
+    assert isinstance(message.content, expected_output_type)
diff --git a/tests/conftest.py b/tests/conftest.py
index 060f782..c4205e5 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -75,6 +75,7 @@ def pytest_collection_modifyitems(
         "litellm_openai",
         "mistral",
        "openai",
+        "openai_gemini",
         "openai_ollama",
     ]
     for item in items:
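For reference, a minimal usage sketch of what this change enables, mirroring the new tests; the model name, prompt, and output type are illustrative, and GEMINI_API_KEY must be set in the environment:

import os

from magentic.chat_model.message import UserMessage
from magentic.chat_model.openai_chat_model import OpenaiChatModel

# Point OpenaiChatModel at Gemini's OpenAI-compatible endpoint
chat_model = OpenaiChatModel(
    "gemini-1.5-flash",
    api_key=os.environ["GEMINI_API_KEY"],
    base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
)
# Request a plain-string completion; other output types work the same way
message = chat_model.complete(
    messages=[UserMessage("Say hello!")], output_types=[str]
)
print(message.content)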