
Commit 80fbb5d

add gemini support via openai
1 parent 030dc45 commit 80fbb5d

File tree

5 files changed: +365 -4 lines changed

examples/example.ts

+22 -2

```diff
@@ -5,15 +5,35 @@
  * npx create-browser-app@latest my-browser-app
  */
 
-import { Stagehand } from "@/dist";
+import { AvailableModel, Stagehand } from "@/dist";
 import StagehandConfig from "@/stagehand.config";
-
+import { z } from "zod";
 async function example() {
+  const modelName = "cerebras-llama-3.3-70b";
+  // const modelName = "gemini-2.0-flash";
   const stagehand = new Stagehand({
     ...StagehandConfig,
+    env: "LOCAL",
+    modelName,
+    modelClientOptions: {
+      apiKey:
+        modelName === ("gemini-2.0-flash" as AvailableModel)
+          ? process.env.GOOGLE_API_KEY
+          : process.env.CEREBRAS_API_KEY,
+    },
   });
   await stagehand.init();
   await stagehand.page.goto("https://docs.stagehand.dev");
+  await stagehand.page.act("Click the quickstart");
+  const { text } = await stagehand.page.extract({
+    instruction: "Extract the title",
+    schema: z.object({
+      text: z.string(),
+    }),
+    useTextExtract: true,
+  });
+  console.log(text);
+  await stagehand.close();
 }
 
 (async () => {
```
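The example defaults to the Cerebras model; the commented-out line switches it to Gemini. A minimal sketch of that variant, assuming GOOGLE_API_KEY is exported and "gemini-2.0-flash" is registered as an AvailableModel:

```typescript
// Gemini variant of the config above (a sketch; assumes GOOGLE_API_KEY is set).
const modelName = "gemini-2.0-flash" as AvailableModel;
const stagehand = new Stagehand({
  ...StagehandConfig,
  env: "LOCAL",
  modelName,
  modelClientOptions: { apiKey: process.env.GOOGLE_API_KEY },
});
```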

lib/llm/GoogleClient.ts

+328 (new file)

```typescript
import OpenAI from "openai";
import type { ClientOptions } from "openai";
import { zodToJsonSchema } from "zod-to-json-schema";
import { LogLine } from "../../types/log";
import { AvailableModel } from "../../types/model";
import { LLMCache } from "../cache/LLMCache";
import {
  ChatMessage,
  CreateChatCompletionOptions,
  LLMClient,
  LLMResponse,
} from "./LLMClient";

export class GoogleClient extends LLMClient {
  public type = "google" as const;
  private client: OpenAI;
  private cache: LLMCache | undefined;
  private enableCaching: boolean;
  public clientOptions: ClientOptions;
  public hasVision = false;

  constructor({
    enableCaching = false,
    cache,
    modelName,
    clientOptions,
    userProvidedInstructions,
  }: {
    logger: (message: LogLine) => void;
    enableCaching?: boolean;
    cache?: LLMCache;
    modelName: AvailableModel;
    clientOptions?: ClientOptions;
    userProvidedInstructions?: string;
  }) {
    super(modelName, userProvidedInstructions);

    // Create OpenAI client with the base URL set to Google API
    this.client = new OpenAI({
      baseURL: "https://generativelanguage.googleapis.com/v1beta/openai/",
      apiKey: clientOptions?.apiKey || process.env.GEMINI_API_KEY,
      ...clientOptions,
    });

    this.cache = cache;
    this.enableCaching = enableCaching;
    this.modelName = modelName;
    this.clientOptions = clientOptions;
  }

  async createChatCompletion<T = LLMResponse>({
    options,
    retries,
    logger,
  }: CreateChatCompletionOptions): Promise<T> {
    const optionsWithoutImage = { ...options };
    delete optionsWithoutImage.image;

    logger({
      category: "google",
      message: "creating chat completion",
      level: 1,
      auxiliary: {
        options: {
          value: JSON.stringify(optionsWithoutImage),
          type: "object",
        },
      },
    });

    // Try to get cached response
    const cacheOptions = {
      model: this.modelName,
      messages: options.messages,
      temperature: options.temperature,
      response_model: options.response_model,
      tools: options.tools,
      retries: retries,
    };

    if (this.enableCaching) {
      const cachedResponse = await this.cache.get<T>(
        cacheOptions,
        options.requestId,
      );
      if (cachedResponse) {
        logger({
          category: "llm_cache",
          message: "LLM cache hit - returning cached response",
          level: 1,
          auxiliary: {
            cachedResponse: {
              value: JSON.stringify(cachedResponse),
              type: "object",
            },
            requestId: {
              value: options.requestId,
              type: "string",
            },
            cacheOptions: {
              value: JSON.stringify(cacheOptions),
              type: "object",
            },
          },
        });
        return cachedResponse as T;
      }
    }

    // Format messages for Google API (using OpenAI format)
    const formattedMessages = options.messages.map((msg: ChatMessage) => {
      const baseMessage = {
        content:
          typeof msg.content === "string"
            ? msg.content
            : Array.isArray(msg.content) &&
                msg.content.length > 0 &&
                "text" in msg.content[0]
              ? msg.content[0].text
              : "",
      };

      // Google only supports system, user, and assistant roles
      if (msg.role === "system") {
        return { ...baseMessage, role: "system" as const };
      } else if (msg.role === "assistant") {
        return { ...baseMessage, role: "assistant" as const };
      } else {
        // Default to user for any other role
        return { ...baseMessage, role: "user" as const };
      }
    });

    // Format tools if provided
    let tools = options.tools?.map((tool) => ({
      type: "function" as const,
      function: {
        name: tool.name,
        description: tool.description,
        parameters: {
          type: "object",
          properties: tool.parameters.properties,
          required: tool.parameters.required,
        },
      },
    }));

    // Add response model as a tool if provided
    if (options.response_model) {
      const jsonSchema = zodToJsonSchema(options.response_model.schema) as {
        properties?: Record<string, unknown>;
        required?: string[];
      };
      const schemaProperties = jsonSchema.properties || {};
      const schemaRequired = jsonSchema.required || [];

      const responseTool = {
        type: "function" as const,
        function: {
          name: "print_extracted_data",
          description: "Prints the extracted data based on the provided schema.",
          parameters: {
            type: "object",
            properties: schemaProperties,
            required: schemaRequired,
          },
        },
      };

      tools = tools ? [...tools, responseTool] : [responseTool];
    }

    try {
      // Use OpenAI client with Google API
      const apiResponse = await this.client.chat.completions.create({
        model: this.modelName,
        messages: [
          ...formattedMessages,
          // Add explicit instruction to return JSON if we have a response model
          ...(options.response_model
            ? [
                {
                  role: "system" as const,
                  content: `IMPORTANT: Your response must be valid JSON that matches this schema: ${JSON.stringify(options.response_model.schema)}`,
                },
              ]
            : []),
        ],
        temperature: options.temperature || 0.7,
        max_tokens: options.maxTokens,
        tools: tools,
        tool_choice: options.tool_choice || "auto",
      });

      // Format the response to match the expected LLMResponse format
      const response: LLMResponse = {
        id: apiResponse.id,
        object: "chat.completion",
        created: Date.now(),
        model: this.modelName,
        choices: [
          {
            index: 0,
            message: {
              role: "assistant",
              content: apiResponse.choices[0]?.message?.content || null,
              tool_calls: apiResponse.choices[0]?.message?.tool_calls || [],
            },
            finish_reason: apiResponse.choices[0]?.finish_reason || "stop",
          },
        ],
        usage: {
          prompt_tokens: apiResponse.usage?.prompt_tokens || 0,
          completion_tokens: apiResponse.usage?.completion_tokens || 0,
          total_tokens: apiResponse.usage?.total_tokens || 0,
        },
      };

      logger({
        category: "google",
        message: "response",
        level: 1,
        auxiliary: {
          response: {
            value: JSON.stringify(response),
            type: "object",
          },
          requestId: {
            value: options.requestId,
            type: "string",
          },
        },
      });

      if (options.response_model) {
        // First try standard function calling format
        const toolCall = response.choices[0]?.message?.tool_calls?.[0];
        if (toolCall?.function?.arguments) {
          try {
            const result = JSON.parse(toolCall.function.arguments);
            if (this.enableCaching) {
              this.cache.set(cacheOptions, result, options.requestId);
            }
            return result as T;
          } catch (e) {
            // If JSON parse fails, the model might be returning a different format
            logger({
              category: "google",
              message: "failed to parse tool call arguments as JSON, retrying",
              level: 1,
              auxiliary: {
                error: {
                  value: e.message,
                  type: "string",
                },
              },
            });
          }
        }

        // If we have content but no tool calls, try to parse the content as JSON
        const content = response.choices[0]?.message?.content;
        if (content) {
          try {
            // Try to extract JSON from the content
            const jsonMatch = content.match(/\{[\s\S]*\}/);
            if (jsonMatch) {
              const result = JSON.parse(jsonMatch[0]);
              if (this.enableCaching) {
                this.cache.set(cacheOptions, result, options.requestId);
              }
              return result as T;
            }
          } catch (e) {
            logger({
              category: "google",
              message: "failed to parse content as JSON",
              level: 1,
              auxiliary: {
                error: {
                  value: e.message,
                  type: "string",
                },
              },
            });
          }
        }

        // If we still haven't found valid JSON and have retries left, try again
        if (!retries || retries < 5) {
          return this.createChatCompletion({
            options,
            logger,
            retries: (retries ?? 0) + 1,
          });
        }

        throw new Error(
          "Create Chat Completion Failed: Could not extract valid JSON from response",
        );
      }

      if (this.enableCaching) {
        this.cache.set(cacheOptions, response, options.requestId);
      }

      return response as T;
    } catch (error) {
      logger({
        category: "google",
        message: "error creating chat completion",
        level: 1,
        auxiliary: {
          error: {
            value: error.message,
            type: "string",
          },
          requestId: {
            value: options.requestId,
            type: "string",
          },
        },
      });
      throw error;
    }
  }
}
```
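The client never touches a Google SDK: it leans on Google's OpenAI-compatibility endpoint, so the stock openai package does all the work. A minimal standalone sketch of that call path, assuming GEMINI_API_KEY is set:

```typescript
import OpenAI from "openai";

// Gemini through Google's OpenAI-compatible endpoint: only the baseURL and
// API key differ from a stock OpenAI client.
const gemini = new OpenAI({
  baseURL: "https://generativelanguage.googleapis.com/v1beta/openai/",
  apiKey: process.env.GEMINI_API_KEY,
});

(async () => {
  const completion = await gemini.chat.completions.create({
    model: "gemini-2.0-flash",
    messages: [{ role: "user", content: "Reply with one word." }],
  });
  console.log(completion.choices[0]?.message?.content);
})();
```

For structured extraction, the caller's Zod schema is flattened into a print_extracted_data tool via zodToJsonSchema. A small sketch of that conversion, using a hypothetical schema mirroring the extract() call in examples/example.ts:

```typescript
import { z } from "zod";
import { zodToJsonSchema } from "zod-to-json-schema";

// Hypothetical schema, matching the example's extract() call.
const schema = z.object({ text: z.string() });
const jsonSchema = zodToJsonSchema(schema) as {
  properties?: Record<string, unknown>;
  required?: string[];
};

console.log(jsonSchema.properties); // { text: { type: "string" } }
console.log(jsonSchema.required); // [ "text" ]
```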

lib/llm/LLMClient.ts

+1 -1

```diff
@@ -81,7 +81,7 @@ export interface CreateChatCompletionOptions {
 }
 
 export abstract class LLMClient {
-  public type: "openai" | "anthropic" | "cerebras" | string;
+  public type: "openai" | "anthropic" | "cerebras" | "google" | string;
   public modelName: AvailableModel;
   public hasVision: boolean;
   public clientOptions: ClientOptions;
```
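A side note on this union: because it still ends in plain `string`, TypeScript widens the whole type to `string`, so the literal members act as documentation rather than a constraint. If literal autocomplete mattered, the usual workaround looks like this sketch:

```typescript
// `(string & {})` keeps the literals visible to the language service while
// still admitting arbitrary strings.
type LLMClientType =
  | "openai"
  | "anthropic"
  | "cerebras"
  | "google"
  | (string & {});
```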
