Skip to content

Commit

Permalink
fix gemini (#1191)
Browse files Browse the repository at this point in the history
  • Loading branch information
karanataryn authored Feb 20, 2025
1 parent 5c1ce95 commit 49f6bf9
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions lib/sycamore/sycamore/llms/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from enum import Enum
from typing import Any, Optional, Union
import os
import io

from sycamore.llms.llms import LLM
from sycamore.llms.prompts.prompts import RenderedPrompt
Expand All @@ -22,10 +23,10 @@ class GeminiModels(Enum):
"""Represents available Gemini models. More info: https://googleapis.github.io/python-genai/"""

# Note that the models available on a given Gemini account may vary.
GEMINI_2_FLASH = GeminiModel(name="gemini-2.0-flash-exp", is_chat=True)
GEMINI_2_FLASH = GeminiModel(name="gemini-2.0-flash", is_chat=True)
GEMINI_2_FLASH_LITE = GeminiModel(name="gemini-2.0-flash-lite-preview-02-05", is_chat=True)
GEMINI_2_FLASH_THINKING = GeminiModel(name="gemini-2.0-flash-thinking-exp", is_chat=True)
GEMINI_2_PRO = GeminiModel(name="gemini-2.0-pro-exp", is_chat=True)
GEMINI_2_FLASH_THINKING = GeminiModel(name="gemini-2.0-flash-thinking-exp-01-21", is_chat=True)
GEMINI_2_PRO = GeminiModel(name="gemini-2.0-pro-exp-02-05", is_chat=True)

@classmethod
def from_name(cls, name: str):
Expand Down Expand Up @@ -86,7 +87,7 @@ def get_generate_kwargs(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict]
if prompt.response_format:
config["response_mime_type"] = "application/json"
config["response_schema"] = prompt.response_format
content_list = []
content_list: list[types.Content] = []
for message in prompt.messages:
if message.role == "system":
config["system_message"] = message.content
Expand All @@ -95,13 +96,15 @@ def get_generate_kwargs(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict]
content = types.Content(parts=[types.Part.from_text(text=message.content)], role=role)
if message.images:
for image in message.images:
image_bytes = image.convert("RGB").tobytes()
content.parts.append(types.Part.from_bytes(image_bytes, media_type="image/png"))
buffered = io.BytesIO()
image.save(buffered, format="PNG")
image_bytes = buffered.getvalue()
content.parts.append(types.Part.from_bytes(data=image_bytes, mime_type="image/png"))
content_list.append(content)
kwargs["config"] = None
if config:
kwargs["config"] = types.GenerateContentConfig(**config)
kwargs["content"] = content
kwargs["content"] = content_list
return kwargs

def generate_metadata(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> dict:
Expand Down

0 comments on commit 49f6bf9

Please sign in to comment.